Files
WQ_GUI/src/core/non_empirical_retrieval.py

253 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
from src.utils.util import *
import warnings
import pandas as pd
import re # Added for regex parsing in safe_load_spectral
# Index of the first spectral column. The first four CSV columns hold
# coordinate and pixel information: x_coord, y_coord, pixel_x, pixel_y.
SPEC_START_COL: int = 4
class RetrievalError(Exception):
    """User-facing error whose message is meant to be shown to the end user.

    Raised for invalid inputs (missing files, malformed data, bad model
    coefficients) and failed retrieval steps.
    """
def ensure_file_exists(path, name):
    """Validate that *path* names an existing regular file.

    Args:
        path: Path to check, as a ``str`` or any ``os.PathLike`` object
            (e.g. ``pathlib.Path``).
        name: Human-readable description of the file, used in error messages.

    Raises:
        RetrievalError: If *path* is empty or not a string/path-like object,
            does not exist, or exists but is not a regular file.
    """
    # Accept pathlib.Path and other path-like objects as well as plain strings.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    if not isinstance(path, str) or not path:
        raise RetrievalError(f"{name} 路径为空。")
    if not os.path.exists(path):
        raise RetrievalError(f"{name} 不存在:{path}")
    # A directory passes os.path.exists but cannot be read as a data file;
    # report it here instead of failing later with a less clear read error.
    if not os.path.isfile(path):
        raise RetrievalError(f"{name} 不是文件:{path}")
def safe_load_model(model_info_path):
    """Load a fitted model description from a JSON file and validate it.

    Args:
        model_info_path: Path to the model JSON file.

    Returns:
        tuple: ``(model_type, model_info, accuracy_)`` as returned by
        ``load_numpy_dict_from_json`` (presumably provided by
        ``src.utils.util``), with ``model_info`` converted to a NumPy array.

    Raises:
        RetrievalError: If the file is missing, cannot be read/parsed, lacks
            ``model_info``, or the coefficients are scalar/empty.
    """
    ensure_file_exists(model_info_path, "模型信息文件")
    try:
        model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
    except Exception as e:
        # Chain the original exception so the traceback keeps the root cause.
        raise RetrievalError(f"无法读取/解析模型文件:{model_info_path}\n原因:{e}") from e
    if model_info is None:
        raise RetrievalError("模型文件缺少 'model_info'")
    model_info = np.asarray(model_info)
    # Callers evaluate this as a coefficient vector, so scalars and empty
    # arrays are rejected up front.
    if model_info.ndim == 0 or model_info.size == 0:
        raise RetrievalError("模型系数为空。")
    return model_type, model_info, accuracy_
def safe_load_spectral(coor_spectral_path):
    """Load a coordinate + spectrum CSV and its wavelength axis.

    The first CSV line is a header. The first ``SPEC_START_COL`` columns hold
    coordinate/pixel information; each remaining column is one spectral band
    whose header is its wavelength. Every following line is one data row.

    Args:
        coor_spectral_path: Path to the CSV file.

    Returns:
        tuple: ``(coor_spectral, wavelengths)`` where ``coor_spectral`` is a
        2-D float array of all data rows (header excluded) and
        ``wavelengths`` is a 1-D float array parsed from the spectral column
        headers.

    Raises:
        RetrievalError: If the file is missing or unreadable, has no data
            rows or no spectral columns, or its wavelength headers are
            non-numeric or non-finite.

    Warns:
        RuntimeWarning: If the wavelengths are not strictly increasing.
    """
    ensure_file_exists(coor_spectral_path, "坐标-光谱文件")
    try:
        # header=0: the first line provides the column names (wavelengths for
        # the spectral columns); every data cell must parse as float.
        df = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0, dtype=float)
    except Exception as e:
        raise RetrievalError(f"无法读取坐标-光谱文件:{coor_spectral_path}\n原因:{e}") from e
    coor_spectral = df.to_numpy()
    if coor_spectral.ndim != 2 or coor_spectral.shape[0] < 1:
        raise RetrievalError("坐标-光谱文件维度异常:需要至少一行数据。")
    if coor_spectral.shape[1] <= SPEC_START_COL:
        raise RetrievalError(
            f"坐标-光谱文件列数不足(至少需要 {SPEC_START_COL+1} 列,"
            f"含 {SPEC_START_COL} 列坐标信息 + ≥1 列光谱)。"
        )
    # The header was already parsed by the read above; reuse it rather than
    # reading the whole file a second time.
    try:
        wavelengths = df.columns[SPEC_START_COL:].astype(float).to_numpy()
    except Exception as e:
        raise RetrievalError(f"无法解析波长信息:{e}") from e
    if not np.all(np.isfinite(wavelengths)):
        raise RetrievalError("波长数据包含 NaN/Inf。")
    # Non-monotonic wavelengths are tolerated, but nearest-band lookup may
    # then pick an unexpected column, so warn.
    if np.any(np.diff(wavelengths) <= 0):
        warnings.warn("波长非严格递增,这可能导致波段匹配误差。", RuntimeWarning)
    return coor_spectral, wavelengths
def find_index(wavelength, array):
    """Return the position of the entry in *array* nearest to *wavelength*.

    Ties resolve to the lowest position, following ``argmin`` semantics.
    The result is a plain Python ``int``.
    """
    distance = np.abs(array - wavelength)
    return int(distance.argmin())
def _clamp_window(index_abs, window, ncols, spec_start_col=SPEC_START_COL):
    """Return the half-open column range ``(left, right)`` of a band window.

    The window covers ``index_abs - window`` through ``index_abs + window``
    inclusive. It is clipped on the left to ``spec_start_col``, so the leading
    coordinate/pixel columns are never included, and on the right to ``ncols``.

    Args:
        index_abs: Absolute column index of the window centre.
        window: Half-width in columns; converted with ``int()``.
        ncols: Total number of columns in the array.
        spec_start_col: Index of the first spectral column.

    Returns:
        tuple[int, int]: ``(left, right)`` slice bounds.

    Raises:
        RetrievalError: If *window* is None or negative, or the clipped range
            is empty.
    """
    if window is None:
        raise RetrievalError("window 为空。")
    window = int(window)
    if window < 0:
        raise RetrievalError(f"window 必须为非负整数,收到:{window}")
    bounds = (
        max(spec_start_col, index_abs - window),
        min(ncols, index_abs + window + 1),
    )
    left, right = bounds
    if right <= left:
        raise RetrievalError(f"窗口无有效光谱列left={left}, right={right}, ncols={ncols})。")
    return bounds
def get_mean_value(index_abs, array, window):
"""index_abs 为绝对列索引(含前两列坐标),这里会夹紧窗口。"""
left, right = _clamp_window(index_abs, window, array.shape[1], SPEC_START_COL)
# 仅在样本行上取平均
result = array[1:, left:right].mean(axis=1)
if not np.all(np.isfinite(result)):
warnings.warn("均值结果包含 NaN/Inf可能是窗口内存在异常值。", RuntimeWarning)
return result
def calculate(x1, x2, coefficients):
x1 = np.asarray(x1, dtype=np.float64).ravel()
x2 = np.asarray(x2, dtype=np.float64).ravel()
coeffs = np.asarray(coefficients, dtype=np.float64).reshape(-1)
if x1.shape[0] != x2.shape[0]:
raise RetrievalError(f"x1 与 x2 长度不一致: {x1.shape[0]} vs {x2.shape[0]}")
if coeffs.size != 3:
raise RetrievalError(f"线性模型系数应为 3 个x1, x2, 截距),收到 {coeffs.size} 个。")
# 诊断:检查 NaN/Inf
n_bad = (~np.isfinite(x1) | ~np.isfinite(x2)).sum()
if n_bad:
print(f"[警告] x 含 {n_bad} 个非有限值,将产生 NaN。")
# 避免 dot/blAS直接逐元素计算
y_pred = x1 * coeffs[0] + x2 * coeffs[1] + coeffs[2]
return y_pred
def _safe_polyval(coeffs, x, name):
coeffs = np.asarray(coeffs).reshape(-1)
if coeffs.ndim != 1 or coeffs.size < 1:
raise RetrievalError(f"{name} 的多项式系数非法。")
try:
y = np.polyval(coeffs, x)
except Exception as e:
raise RetrievalError(f"{name} 计算失败polyval{e}")
return y
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
def idx_abs_for(wave):
idx_rel = find_index(wave, wavelengths) # 相对光谱起始列的索引
return SPEC_START_COL + idx_rel # 转为绝对列索引
try:
idx_651 = idx_abs_for(651)
idx_707 = idx_abs_for(707)
idx_670 = idx_abs_for(670)
except Exception as e:
raise RetrievalError(f"波段索引计算失败:{e}")
band_651 = get_mean_value(idx_651, coor_spectral, window)
band_707 = get_mean_value(idx_707, coor_spectral, window)
band_670 = get_mean_value(idx_670, coor_spectral, window)
with np.errstate(divide='ignore', invalid='ignore'):
denom = (band_707 - band_670)
x = (band_651 - band_707) / denom
bad = ~np.isfinite(x)
if bad.any():
warnings.warn(f"chl_a 极速出现 {bad.sum()} 个无效比值分母≈0 或含 NaN这些位置结果将为 NaN。", RuntimeWarning)
retrieval_result = _safe_polyval(model_info, x, "chl_a")
# 创建DataFrame并保存为CSV
result_df = pd.DataFrame({
'longitude': coor_spectral[1:, 0],
'latitude': coor_spectral[1:, 1],
'prediction': retrieval_result
})
try:
result_df.to_csv(output_path, index=False, float_format='%.8f')
except Exception as e:
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
return result_df.values
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
def idx_abs_for(wave):
return SPEC_START_COL + find_index(wave, wavelengths)
idx_600 = idx_abs_for(600)
idx_500 = idx_abs_for(500)
idx_850 = idx_abs_for(850)
band_600 = get_mean_value(idx_600, coor_spectral, window)
band_500 = get_mean_value(idx_500, coor_spectral, window)
band_850 = get_mean_value(idx_850, coor_spectral, window)
with np.errstate(divide='ignore', invalid='ignore'):
x13 = np.log(band_500 / band_850)
x23 = np.exp(band_600 / band_500)
invalid = ~np.isfinite(x13) | ~np.isfinite(x23)
if invalid.any():
warnings.warn(f"nh3 自变量出现 {invalid.sum()} 个无效值0/负数/NaN对应位置结果将为 NaN。", RuntimeWarning)
retrieval_result = calculate(x13, x23, model_info)
# 创建DataFrame
result_df = pd.DataFrame({
'longitude': coor_spectral[1:, 0],
'latitude': coor_spectral[1:, 1],
'prediction': retrieval_result
})
if output_path is not None:
try:
result_df.to_csv(output_path, index=False, float_format='%.8f')
except Exception as e:
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
return result_df.values
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
    """Retrieve total suspended solids as ``exp`` of the nh3-form linear model.

    Runs ``retrieval_nh3`` (same two-predictor model form, no file output) and
    exponentiates its prediction column. The model is therefore presumably
    fitted in log space — TODO confirm against the training code.

    Args:
        model_info_path: Path to the model JSON file.
        coor_spectral_path: Path to the coordinate + spectrum CSV.
        output_path: Destination CSV for the result.
        window: Half-width of the band-averaging window, in columns.

    Returns:
        ndarray of shape ``(n_rows, 3)``: longitude, latitude, prediction.

    Raises:
        RetrievalError: On invalid inputs or when writing the CSV fails.

    Warns:
        RuntimeWarning: If any prediction is non-finite (stored as NaN).
    """
    position_content = retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=window)
    # Overflow is reported by the module's own warning below.
    with np.errstate(over='ignore', invalid='ignore'):
        predictions = np.exp(position_content[:, -1])
    non_finite = ~np.isfinite(predictions)
    if non_finite.any():
        # exp overflow yields +inf; store NaN so the output matches the warning
        # text and the NaN convention used for invalid values elsewhere.
        predictions = np.where(non_finite, np.nan, predictions)
        warnings.warn("tss 结果包含非有限值(可能因指数溢出),已保留为 NaN。", RuntimeWarning)
    result_df = pd.DataFrame({
        'longitude': position_content[:, 0],
        'latitude': position_content[:, 1],
        'prediction': predictions,
    })
    try:
        result_df.to_csv(output_path, index=False, float_format='%.8f')
    except Exception as e:
        raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}") from e
    return result_df.values
def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius=5.0):
    """Run the retrieval routine selected by *algorithm* and return its result.

    Routing:
        ``"chl_a"``                         -> ``retrieval_chl_a``
        ``"nh3"``, ``"mno4"``, ``"tn"``, ``"tp"`` -> ``retrieval_nh3``
        ``"tss"``                           -> ``retrieval_tss``

    *wave_radius* is passed on as the band-averaging ``window`` (the window
    helper applies ``int()`` to it, so it effectively counts columns).

    This is a CLI-style entry point. On failure it prints to stderr and calls
    ``sys.exit``: code 2 for ``RetrievalError``, 3 for any other exception.
    """
    try:
        routes = (
            (("chl_a",), retrieval_chl_a),
            (("nh3", "mno4", "tn", "tp"), retrieval_nh3),
            (("tss",), retrieval_tss),
        )
        for names, handler in routes:
            if algorithm in names:
                return handler(model_info_path, coor_spectral_path, output_path, wave_radius)
        raise RetrievalError(f"未知算法:{algorithm}可选chl_a / nh3 / mno4 / tn / tp / tss")
    except RetrievalError as e:
        # Expected, user-facing failure.
        print(f"[错误] {e}", file=sys.stderr)
        sys.exit(2)
    except Exception as e:
        # Unexpected failure: include the exception type for diagnosis.
        print(f"[致命错误] {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(3)
if __name__ == "__main__":
    # Command-line entry point. Without arguments it reproduces the original
    # hard-coded developer run; every value can be overridden on the CLI.
    import argparse

    parser = argparse.ArgumentParser(description="Non-empirical water-quality retrieval.")
    parser.add_argument("--algorithm", default="chl_a")
    parser.add_argument(
        "--model-info-path",
        default=r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json",
    )
    parser.add_argument(
        "--coor-spectral-path",
        default=r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv",
    )
    parser.add_argument(
        "--output-path",
        default=r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv",
    )
    parser.add_argument("--wave-radius", type=float, default=5.0)
    args = parser.parse_args()
    non_empirical_retrieval(
        args.algorithm,
        args.model_info_path,
        args.coor_spectral_path,
        args.output_path,
        args.wave_radius,
    )