Initial commit of WQ_GUI
This commit is contained in:
253
src/core/non_empirical_retrieval.py
Normal file
253
src/core/non_empirical_retrieval.py
Normal file
@ -0,0 +1,253 @@
|
||||
import sys
|
||||
from src.utils.util import *
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import re # Added for regex parsing in safe_load_spectral
|
||||
# 配置:光谱起始列(前四列是坐标和像素信息:x_coord,y_coord,pixel_x,pixel_y)
|
||||
|
||||
SPEC_START_COL = 4
|
||||
|
||||
class RetrievalError(Exception):
|
||||
"""面向用户的友好错误。"""
|
||||
pass
|
||||
|
||||
def ensure_file_exists(path, name):
|
||||
if not isinstance(path, str) or not path:
|
||||
raise RetrievalError(f"{name} 路径为空。")
|
||||
if not os.path.exists(path):
|
||||
raise RetrievalError(f"{name} 不存在:{path}")
|
||||
|
||||
def safe_load_model(model_info_path):
|
||||
ensure_file_exists(model_info_path, "模型信息文件")
|
||||
try:
|
||||
model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"无法读取/解析模型文件:{model_info_path}\n原因:{e}")
|
||||
if model_info is None:
|
||||
raise RetrievalError("模型文件缺少 'model_info'。")
|
||||
model_info = np.asarray(model_info)
|
||||
if model_info.ndim == 0 or model_info.size == 0:
|
||||
raise RetrievalError("模型系数为空。")
|
||||
return model_type, model_info, accuracy_
|
||||
|
||||
def safe_load_spectral(coor_spectral_path):
|
||||
ensure_file_exists(coor_spectral_path, "坐标-光谱文件")
|
||||
|
||||
# 使用 pandas 读取文件
|
||||
try:
|
||||
# 读取为 DataFrame,跳过第一行(列名),明确指定数据类型为 float
|
||||
df = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0, dtype=float)
|
||||
# 转换为 numpy 数组以保持原有格式
|
||||
coor_spectral = df.values
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"无法读取坐标-光谱文件:{coor_spectral_path}\n原因:{e}")
|
||||
|
||||
if coor_spectral.ndim != 2 or coor_spectral.shape[0] < 1:
|
||||
raise RetrievalError("坐标-光谱文件维度异常:需要至少一行数据。")
|
||||
|
||||
if coor_spectral.shape[1] <= SPEC_START_COL:
|
||||
raise RetrievalError(f"坐标-光谱文件列数不足(至少需要 {SPEC_START_COL+1} 列,含 4 列坐标信息 + ≥1 列光谱)。")
|
||||
|
||||
# 由于第一行已经是数据,不再需要提取波长行
|
||||
# 波长信息需要从列名中提取
|
||||
try:
|
||||
# 读取列名来获取波长信息
|
||||
df_with_header = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0)
|
||||
wavelengths = df_with_header.columns[SPEC_START_COL:].astype(float).values
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"无法解析波长信息:{e}")
|
||||
|
||||
if not np.all(np.isfinite(wavelengths)):
|
||||
raise RetrievalError("波长数据包含 NaN/Inf。")
|
||||
# 非严格单调也可,但给出警告
|
||||
if np.any(np.diff(wavelengths) <= 0):
|
||||
warnings.warn("波长非严格递增,这可能导致波段匹配误差。", RuntimeWarning)
|
||||
|
||||
return coor_spectral, wavelengths
|
||||
|
||||
def find_index(wavelength, array):
|
||||
differences = np.abs(array - wavelength)
|
||||
min_position = int(np.argmin(differences))
|
||||
return min_position
|
||||
|
||||
def _clamp_window(index_abs, window, ncols, spec_start_col=SPEC_START_COL):
|
||||
if window is None:
|
||||
raise RetrievalError("window 为空。")
|
||||
window = int(window)
|
||||
if window < 0:
|
||||
raise RetrievalError(f"window 必须为非负整数,收到:{window}")
|
||||
left = max(spec_start_col, index_abs - window)
|
||||
right = min(ncols, index_abs + window + 1)
|
||||
if right - left <= 0:
|
||||
raise RetrievalError(f"窗口无有效光谱列(left={left}, right={right}, ncols={ncols})。")
|
||||
return left, right
|
||||
|
||||
def get_mean_value(index_abs, array, window):
|
||||
"""index_abs 为绝对列索引(含前两列坐标),这里会夹紧窗口。"""
|
||||
left, right = _clamp_window(index_abs, window, array.shape[1], SPEC_START_COL)
|
||||
# 仅在样本行上取平均
|
||||
result = array[1:, left:right].mean(axis=1)
|
||||
if not np.all(np.isfinite(result)):
|
||||
warnings.warn("均值结果包含 NaN/Inf,可能是窗口内存在异常值。", RuntimeWarning)
|
||||
return result
|
||||
|
||||
def calculate(x1, x2, coefficients):
|
||||
x1 = np.asarray(x1, dtype=np.float64).ravel()
|
||||
x2 = np.asarray(x2, dtype=np.float64).ravel()
|
||||
coeffs = np.asarray(coefficients, dtype=np.float64).reshape(-1)
|
||||
if x1.shape[0] != x2.shape[0]:
|
||||
raise RetrievalError(f"x1 与 x2 长度不一致: {x1.shape[0]} vs {x2.shape[0]}")
|
||||
if coeffs.size != 3:
|
||||
raise RetrievalError(f"线性模型系数应为 3 个(x1, x2, 截距),收到 {coeffs.size} 个。")
|
||||
|
||||
# 诊断:检查 NaN/Inf
|
||||
n_bad = (~np.isfinite(x1) | ~np.isfinite(x2)).sum()
|
||||
if n_bad:
|
||||
print(f"[警告] x 含 {n_bad} 个非有限值,将产生 NaN。")
|
||||
|
||||
# 避免 dot/blAS,直接逐元素计算
|
||||
y_pred = x1 * coeffs[0] + x2 * coeffs[1] + coeffs[2]
|
||||
return y_pred
|
||||
|
||||
|
||||
def _safe_polyval(coeffs, x, name):
|
||||
coeffs = np.asarray(coeffs).reshape(-1)
|
||||
if coeffs.ndim != 1 or coeffs.size < 1:
|
||||
raise RetrievalError(f"{name} 的多项式系数非法。")
|
||||
try:
|
||||
y = np.polyval(coeffs, x)
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"{name} 计算失败(polyval):{e}")
|
||||
return y
|
||||
|
||||
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
|
||||
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
|
||||
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
|
||||
|
||||
def idx_abs_for(wave):
|
||||
idx_rel = find_index(wave, wavelengths) # 相对光谱起始列的索引
|
||||
return SPEC_START_COL + idx_rel # 转为绝对列索引
|
||||
|
||||
try:
|
||||
idx_651 = idx_abs_for(651)
|
||||
idx_707 = idx_abs_for(707)
|
||||
idx_670 = idx_abs_for(670)
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"波段索引计算失败:{e}")
|
||||
|
||||
band_651 = get_mean_value(idx_651, coor_spectral, window)
|
||||
band_707 = get_mean_value(idx_707, coor_spectral, window)
|
||||
band_670 = get_mean_value(idx_670, coor_spectral, window)
|
||||
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
denom = (band_707 - band_670)
|
||||
x = (band_651 - band_707) / denom
|
||||
bad = ~np.isfinite(x)
|
||||
if bad.any():
|
||||
warnings.warn(f"chl_a 极速出现 {bad.sum()} 个无效比值(分母≈0 或含 NaN),这些位置结果将为 NaN。", RuntimeWarning)
|
||||
|
||||
retrieval_result = _safe_polyval(model_info, x, "chl_a")
|
||||
|
||||
# 创建DataFrame并保存为CSV
|
||||
result_df = pd.DataFrame({
|
||||
'longitude': coor_spectral[1:, 0],
|
||||
'latitude': coor_spectral[1:, 1],
|
||||
'prediction': retrieval_result
|
||||
})
|
||||
|
||||
try:
|
||||
result_df.to_csv(output_path, index=False, float_format='%.8f')
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
|
||||
|
||||
return result_df.values
|
||||
|
||||
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
|
||||
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
|
||||
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
|
||||
|
||||
def idx_abs_for(wave):
|
||||
return SPEC_START_COL + find_index(wave, wavelengths)
|
||||
|
||||
idx_600 = idx_abs_for(600)
|
||||
idx_500 = idx_abs_for(500)
|
||||
idx_850 = idx_abs_for(850)
|
||||
|
||||
band_600 = get_mean_value(idx_600, coor_spectral, window)
|
||||
band_500 = get_mean_value(idx_500, coor_spectral, window)
|
||||
band_850 = get_mean_value(idx_850, coor_spectral, window)
|
||||
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
x13 = np.log(band_500 / band_850)
|
||||
x23 = np.exp(band_600 / band_500)
|
||||
invalid = ~np.isfinite(x13) | ~np.isfinite(x23)
|
||||
if invalid.any():
|
||||
warnings.warn(f"nh3 自变量出现 {invalid.sum()} 个无效值(0/负数/NaN),对应位置结果将为 NaN。", RuntimeWarning)
|
||||
|
||||
retrieval_result = calculate(x13, x23, model_info)
|
||||
|
||||
# 创建DataFrame
|
||||
result_df = pd.DataFrame({
|
||||
'longitude': coor_spectral[1:, 0],
|
||||
'latitude': coor_spectral[1:, 1],
|
||||
'prediction': retrieval_result
|
||||
})
|
||||
|
||||
if output_path is not None:
|
||||
try:
|
||||
result_df.to_csv(output_path, index=False, float_format='%.8f')
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
|
||||
|
||||
return result_df.values
|
||||
|
||||
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
|
||||
# 先跑 nh3 的同型模型(按你的原逻辑)
|
||||
position_content = retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=window)
|
||||
|
||||
# 对结果进行指数变换
|
||||
predictions = np.exp(position_content[:, -1])
|
||||
|
||||
# 创建DataFrame
|
||||
result_df = pd.DataFrame({
|
||||
'longitude': position_content[:, 0],
|
||||
'latitude': position_content[:, 1],
|
||||
'prediction': predictions
|
||||
})
|
||||
|
||||
if not np.all(np.isfinite(result_df['prediction'])):
|
||||
warnings.warn("tss 结果包含非有限值(可能因指数溢出),已保留为 NaN。", RuntimeWarning)
|
||||
|
||||
try:
|
||||
result_df.to_csv(output_path, index=False, float_format='%.8f')
|
||||
except Exception as e:
|
||||
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
|
||||
|
||||
return result_df.values
|
||||
|
||||
def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius=5.0):
|
||||
try:
|
||||
if algorithm == "chl_a":
|
||||
return retrieval_chl_a(model_info_path, coor_spectral_path, output_path, wave_radius)
|
||||
elif algorithm in ["nh3", "mno4", "tn", "tp"]:
|
||||
return retrieval_nh3(model_info_path, coor_spectral_path, output_path, wave_radius)
|
||||
elif algorithm == "tss":
|
||||
return retrieval_tss(model_info_path, coor_spectral_path, output_path, wave_radius)
|
||||
else:
|
||||
raise RetrievalError(f"未知算法:{algorithm}(可选:chl_a / nh3 / mno4 / tn / tp / tss)")
|
||||
except RetrievalError as e:
|
||||
# 面向用户的友好错误
|
||||
print(f"[错误] {e}", file=sys.stderr)
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
# 未预料的异常,附带类型与少量上下文
|
||||
print(f"[致命错误] {type(e).__name__}: {e}", file=sys.stderr)
|
||||
sys.exit(3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
algorithm= "chl_a"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\6_5_non_empirical_models\SS\SS_chl_a.json"
|
||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\7_sampling\sampling_spectra.csv"
|
||||
output_path= r"E:\code\WQ\pipeline_result\work_dir\8_predictions\SS_chl_a.csv"
|
||||
wave_radius=5.0
|
||||
non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius)
|
||||
Reference in New Issue
Block a user