"""Non-empirical water-quality retrieval from a coordinate + spectrum CSV.

Layout expected by ``safe_load_spectral``: a header row whose first
``SPEC_START_COL`` names are coordinate/pixel columns (x_coord, y_coord,
pixel_x, pixel_y) and whose remaining names are the band wavelengths; every
row after the header is one sample.

Each ``retrieval_*`` function returns an ``(n_samples, 3)`` array of
(longitude, latitude, prediction), where longitude/latitude are copied from
the first two input columns.
"""
import os
import re  # noqa: F401 -- unused here; kept so the module namespace is unchanged
import sys
import warnings

import numpy as np
import pandas as pd

# Project helpers; expected to provide ``load_numpy_dict_from_json``.
from src.utils.util import *  # noqa: F401,F403

# Index of the first spectral column. Columns 0-3 hold coordinate and pixel
# information (x_coord, y_coord, pixel_x, pixel_y).
SPEC_START_COL = 4


class RetrievalError(Exception):
    """User-facing error; ``non_empirical_retrieval`` prints it and exits with status 2."""


def ensure_file_exists(path, name):
    """Raise RetrievalError unless *path* is a non-empty string naming an existing path.

    Args:
        path: Filesystem path to check.
        name: Human-readable label used in the error message.

    Raises:
        RetrievalError: if *path* is empty / not a string, or does not exist.
    """
    if not isinstance(path, str) or not path:
        raise RetrievalError(f"{name} 路径为空。")
    if not os.path.exists(path):
        raise RetrievalError(f"{name} 不存在:{path}")


def safe_load_model(model_info_path):
    """Load a model JSON file and validate its coefficients.

    Returns:
        ``(model_type, model_info, accuracy_)`` as returned by
        ``load_numpy_dict_from_json``, with ``model_info`` converted to a
        NumPy array.

    Raises:
        RetrievalError: if the file is missing, cannot be parsed, or carries
            no coefficients.
    """
    ensure_file_exists(model_info_path, "模型信息文件")
    try:
        model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
    except Exception as e:
        raise RetrievalError(f"无法读取/解析模型文件:{model_info_path}\n原因:{e}") from e
    if model_info is None:
        raise RetrievalError("模型文件缺少 'model_info'。")
    model_info = np.asarray(model_info)
    if model_info.ndim == 0 or model_info.size == 0:
        raise RetrievalError("模型系数为空。")
    return model_type, model_info, accuracy_


def safe_load_spectral(coor_spectral_path):
    """Load the coordinate + spectrum CSV.

    Returns:
        coor_spectral: 2-D float array holding every data row (the header row
            is NOT included; row 0 is the first sample).
        wavelengths: 1-D float array parsed from the column names starting at
            ``SPEC_START_COL``.

    Raises:
        RetrievalError: on an unreadable file, fewer than one data row, too few
            columns, or non-numeric / non-finite wavelength names.
    """
    ensure_file_exists(coor_spectral_path, "坐标-光谱文件")
    try:
        # header=0 consumes the column-name (wavelength) row, so every row of
        # ``df`` is a sample; dtype=float rejects non-numeric cells.
        df = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0, dtype=float)
    except Exception as e:
        raise RetrievalError(f"无法读取坐标-光谱文件:{coor_spectral_path}\n原因:{e}") from e
    coor_spectral = df.to_numpy()

    if coor_spectral.ndim != 2 or coor_spectral.shape[0] < 1:
        raise RetrievalError("坐标-光谱文件维度异常:需要至少一行数据。")
    if coor_spectral.shape[1] <= SPEC_START_COL:
        raise RetrievalError(f"坐标-光谱文件列数不足(至少需要 {SPEC_START_COL+1} 列,含 4 列坐标信息 + ≥1 列光谱)。")

    # The wavelengths are the spectral column names; reuse the header that was
    # already parsed instead of reading the file a second time.
    try:
        wavelengths = df.columns[SPEC_START_COL:].astype(float).to_numpy()
    except (TypeError, ValueError) as e:
        raise RetrievalError(f"无法解析波长信息:{e}") from e

    if not np.all(np.isfinite(wavelengths)):
        raise RetrievalError("波长数据包含 NaN/Inf。")
    # Non-increasing wavelengths are tolerated, but nearest-band matching may
    # then pick an unexpected column, so warn.
    if np.any(np.diff(wavelengths) <= 0):
        warnings.warn("波长非严格递增,这可能导致波段匹配误差。", RuntimeWarning)
    return coor_spectral, wavelengths


def find_index(wavelength, array):
    """Return the position in *array* of the value closest to *wavelength*.

    *array* may be any 1-D array-like of numbers. Ties resolve to the first
    occurrence (``np.argmin`` semantics); an empty *array* raises ValueError.
    """
    differences = np.abs(np.asarray(array, dtype=float) - wavelength)
    return int(np.argmin(differences))


def _clamp_window(index_abs, window, ncols, spec_start_col=SPEC_START_COL):
    """Return the half-open column range ``[left, right)`` around *index_abs*.

    The range spans ``index_abs - window`` .. ``index_abs + window`` (after
    truncating *window* to int) and is clamped to ``[spec_start_col, ncols)``
    so coordinate columns are never included.

    Raises:
        RetrievalError: if *window* is None or negative, or the clamped range
            is empty.
    """
    if window is None:
        raise RetrievalError("window 为空。")
    window = int(window)
    if window < 0:
        raise RetrievalError(f"window 必须为非负整数,收到:{window}")
    left = max(spec_start_col, index_abs - window)
    right = min(ncols, index_abs + window + 1)
    if right - left <= 0:
        raise RetrievalError(f"窗口无有效光谱列(left={left}, right={right}, ncols={ncols})。")
    return left, right


def get_mean_value(index_abs, array, window):
    """Average every sample row over the spectral columns around *index_abs*.

    Args:
        index_abs: Absolute column index, i.e. it counts the
            ``SPEC_START_COL`` leading coordinate/pixel columns.
        array: 2-D array as returned by ``safe_load_spectral`` (all rows are
            samples).
        window: Half-width of the averaging window, in columns.

    Returns:
        1-D array with one mean per row of *array*.
    """
    left, right = _clamp_window(index_abs, window, array.shape[1], SPEC_START_COL)
    # safe_load_spectral already strips the header (wavelength) row, so every
    # row is a sample and none may be skipped here.
    result = array[:, left:right].mean(axis=1)
    if not np.all(np.isfinite(result)):
        warnings.warn("均值结果包含 NaN/Inf,可能是窗口内存在异常值。", RuntimeWarning)
    return result


def calculate(x1, x2, coefficients):
    """Evaluate the linear model ``c0 * x1 + c1 * x2 + c2`` element-wise.

    Args:
        x1, x2: Array-likes of equal length (flattened to 1-D).
        coefficients: Exactly three numbers ``(c0, c1, c2)``; ``c2`` is the
            intercept.

    Returns:
        1-D float64 array of predictions; non-finite inputs propagate.

    Raises:
        RetrievalError: on a length mismatch or a coefficient count other than 3.
    """
    x1 = np.asarray(x1, dtype=np.float64).ravel()
    x2 = np.asarray(x2, dtype=np.float64).ravel()
    coeffs = np.asarray(coefficients, dtype=np.float64).reshape(-1)

    if x1.shape[0] != x2.shape[0]:
        raise RetrievalError(f"x1 与 x2 长度不一致: {x1.shape[0]} vs {x2.shape[0]}")
    if coeffs.size != 3:
        raise RetrievalError(f"线性模型系数应为 3 个(x1, x2, 截距),收到 {coeffs.size} 个。")

    # Report non-finite inputs through the same warnings channel as the rest
    # of the module (previously printed to stdout).
    n_bad = int((~np.isfinite(x1) | ~np.isfinite(x2)).sum())
    if n_bad:
        warnings.warn(f"x 含 {n_bad} 个非有限值,将产生 NaN。", RuntimeWarning)

    # Plain element-wise arithmetic; the original author deliberately avoided
    # the dot/BLAS path here.
    return x1 * coeffs[0] + x2 * coeffs[1] + coeffs[2]


def _safe_polyval(coeffs, x, name):
    """Evaluate a polynomial with ``np.polyval`` (highest power first).

    *name* labels the error message.

    Raises:
        RetrievalError: if *coeffs* is empty or ``np.polyval`` fails.
    """
    coeffs = np.asarray(coeffs).reshape(-1)
    if coeffs.size < 1:
        raise RetrievalError(f"{name} 的多项式系数非法。")
    try:
        return np.polyval(coeffs, x)
    except Exception as e:
        raise RetrievalError(f"{name} 计算失败(polyval):{e}") from e


def _abs_band_index(wave, wavelengths):
    """Absolute column index of the band whose wavelength is closest to *wave*."""
    return SPEC_START_COL + find_index(wave, wavelengths)


def _result_frame(rows, prediction):
    """Build the output frame: columns 0/1 of *rows* as longitude/latitude plus *prediction*."""
    return pd.DataFrame({
        'longitude': rows[:, 0],
        'latitude': rows[:, 1],
        'prediction': prediction,
    })


def _write_result(result_df, output_path):
    """Write *result_df* to CSV at *output_path*, wrapping failures in RetrievalError."""
    try:
        result_df.to_csv(output_path, index=False, float_format='%.8f')
    except Exception as e:
        raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}") from e


def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
    """chl_a retrieval: polynomial in ``(B651 - B707) / (B707 - B670)``.

    ``Bxxx`` is the window-averaged value of the band nearest to ``xxx``; the
    polynomial coefficients come from ``model_info`` (``np.polyval`` order).
    The result is written to *output_path* as CSV.

    Returns:
        ``(n_samples, 3)`` array of (longitude, latitude, prediction).
    """
    _model_type, model_info, _accuracy = safe_load_model(model_info_path)
    coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)

    try:
        idx_651 = _abs_band_index(651, wavelengths)
        idx_707 = _abs_band_index(707, wavelengths)
        idx_670 = _abs_band_index(670, wavelengths)
    except Exception as e:
        raise RetrievalError(f"波段索引计算失败:{e}") from e

    band_651 = get_mean_value(idx_651, coor_spectral, window)
    band_707 = get_mean_value(idx_707, coor_spectral, window)
    band_670 = get_mean_value(idx_670, coor_spectral, window)

    # Division by ~0 is reported below instead of via numpy warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        x = (band_651 - band_707) / (band_707 - band_670)
    bad = ~np.isfinite(x)
    if bad.any():
        warnings.warn(f"chl_a 自变量出现 {bad.sum()} 个无效比值(分母≈0 或含 NaN),这些位置结果将为 NaN。", RuntimeWarning)

    retrieval_result = _safe_polyval(model_info, x, "chl_a")

    result_df = _result_frame(coor_spectral, retrieval_result)
    _write_result(result_df, output_path)
    return result_df.values


def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Linear two-feature retrieval shared by nh3 / mno4 / tn / tp (and tss).

    Features: ``x13 = ln(B500 / B850)`` and ``x23 = exp(B600 / B500)``;
    prediction ``= c0 * x13 + c1 * x23 + c2`` with coefficients from
    ``model_info``. The CSV is written only when *output_path* is not None.

    Returns:
        ``(n_samples, 3)`` array of (longitude, latitude, prediction).
    """
    _model_type, model_info, _accuracy = safe_load_model(model_info_path)
    coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)

    idx_600 = _abs_band_index(600, wavelengths)
    idx_500 = _abs_band_index(500, wavelengths)
    idx_850 = _abs_band_index(850, wavelengths)

    band_600 = get_mean_value(idx_600, coor_spectral, window)
    band_500 = get_mean_value(idx_500, coor_spectral, window)
    band_850 = get_mean_value(idx_850, coor_spectral, window)

    # log of 0/negative, division by 0 and exp overflow all yield non-finite
    # values, which the explicit check below reports.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        x13 = np.log(band_500 / band_850)
        x23 = np.exp(band_600 / band_500)
    invalid = ~np.isfinite(x13) | ~np.isfinite(x23)
    if invalid.any():
        warnings.warn(f"nh3 自变量出现 {invalid.sum()} 个无效值(0/负数/NaN),对应位置结果将为 NaN。", RuntimeWarning)

    retrieval_result = calculate(x13, x23, model_info)

    result_df = _result_frame(coor_spectral, retrieval_result)
    if output_path is not None:
        _write_result(result_df, output_path)
    return result_df.values


def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
    """tss retrieval: ``exp`` of the nh3-style linear model prediction.

    The result is written to *output_path* as CSV.

    Returns:
        ``(n_samples, 3)`` array of (longitude, latitude, prediction).
    """
    # Same model form as nh3 (per the original logic), then exponentiate.
    position_content = retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=window)
    # Overflow yields inf, which the finiteness check below reports.
    with np.errstate(over='ignore'):
        predictions = np.exp(position_content[:, -1])

    result_df = _result_frame(position_content, predictions)
    if not np.all(np.isfinite(result_df['prediction'])):
        warnings.warn("tss 结果包含非有限值(可能因指数溢出),已保留为 NaN。", RuntimeWarning)

    _write_result(result_df, output_path)
    return result_df.values


def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius=5.0):
    """Dispatch *algorithm* to its retrieval function (CLI-style entry point).

    *wave_radius* is passed as the averaging half-width in columns (truncated
    to int by ``_clamp_window``).

    Returns the selected retrieval's ``(n_samples, 3)`` array. On a
    RetrievalError the message is printed to stderr and the process exits with
    status 2; on any other exception it exits with status 3.
    """
    try:
        if algorithm == "chl_a":
            return retrieval_chl_a(model_info_path, coor_spectral_path, output_path, wave_radius)
        elif algorithm in ["nh3", "mno4", "tn", "tp"]:
            return retrieval_nh3(model_info_path, coor_spectral_path, output_path, wave_radius)
        elif algorithm == "tss":
            return retrieval_tss(model_info_path, coor_spectral_path, output_path, wave_radius)
        else:
            raise RetrievalError(f"未知算法:{algorithm}(可选:chl_a / nh3 / mno4 / tn / tp / tss)")
    except RetrievalError as e:
        # Expected, user-facing failure.
        print(f"[错误] {e}", file=sys.stderr)
        sys.exit(2)
    except Exception as e:
        # Unexpected failure: report the exception type for diagnosis.
        print(f"[致命错误] {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(3)


if __name__ == "__main__":
    algorithm = "chl_a"
    model_info_path = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
    coor_spectral_path = r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
    output_path = r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
    wave_radius = 5.0
    non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius)