253 lines
10 KiB
Python
253 lines
10 KiB
Python
import sys
|
||
from src.utils.util import *
|
||
import warnings
|
||
import pandas as pd
|
||
import re # Added for regex parsing in safe_load_spectral
|
||
# Configuration: index of the first spectral column. The first four columns hold
# coordinate and pixel information (x_coord, y_coord, pixel_x, pixel_y); every
# column from this index onward is a spectral band.
SPEC_START_COL = 4
|
||
|
||
class RetrievalError(Exception):
    """User-facing error for retrieval failures.

    Its message is intended to be shown to the end user as-is, rather than
    as a raw traceback.
    """
|
||
|
||
def ensure_file_exists(path, name):
    """Validate that *path* names an existing filesystem entry.

    Args:
        path: Path to check; a ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``).
        name: Human-readable label for the file, used in error messages.

    Raises:
        RetrievalError: if *path* is empty or not a text path, or does not exist.
    """
    # Accept pathlib.Path and other path-like objects in addition to plain strings;
    # previously they were rejected with a misleading "path is empty" message.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    if not isinstance(path, str) or not path:
        raise RetrievalError(f"{name} 路径为空。")
    if not os.path.exists(path):
        raise RetrievalError(f"{name} 不存在:{path}")
|
||
|
||
def safe_load_model(model_info_path):
    """Load a fitted model from its JSON file and validate the coefficients.

    Args:
        model_info_path: Path of the model-info JSON file.

    Returns:
        ``(model_type, model_info, accuracy_)`` as produced by
        ``load_numpy_dict_from_json``, with ``model_info`` converted to a
        NumPy array.

    Raises:
        RetrievalError: if the file is missing, cannot be read/parsed, or
            holds no usable coefficients.
    """
    ensure_file_exists(model_info_path, "模型信息文件")
    try:
        # A loader failure or a return value that does not unpack into three
        # items is reported the same way.
        model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
    except Exception as e:
        raise RetrievalError(f"无法读取/解析模型文件:{model_info_path}\n原因:{e}") from e

    if model_info is None:
        raise RetrievalError("模型文件缺少 'model_info'。")

    try:
        model_info = np.asarray(model_info)
    except (TypeError, ValueError) as e:
        # e.g. ragged nested lists, which recent NumPy refuses to convert;
        # report it as a user-facing error instead of an unexpected crash.
        raise RetrievalError(f"模型系数格式非法:{e}") from e

    # A 0-d array (bare scalar) or an empty array is not a usable coefficient vector.
    if model_info.ndim == 0 or model_info.size == 0:
        raise RetrievalError("模型系数为空。")
    return model_type, model_info, accuracy_
|
||
|
||
def safe_load_spectral(coor_spectral_path):
    """Load the coordinate/spectrum CSV.

    The CSV's header row holds the column names. Columns from
    ``SPEC_START_COL`` onward are named by their wavelength, and every data
    row is one sample.

    Args:
        coor_spectral_path: Path of the CSV file (UTF-8, BOM tolerated).

    Returns:
        ``(coor_spectral, wavelengths)``:

        - ``coor_spectral``: 2-D float array. Row 0 is a header row holding
          the wavelengths in the spectral columns (NaN in the first
          ``SPEC_START_COL`` columns). Rows ``1..`` are the samples, which is
          how downstream code reads them (``coor_spectral[1:]``).
        - ``wavelengths``: 1-D float array, one entry per spectral column.

    Raises:
        RetrievalError: if the file is missing or unreadable, has no data
            rows or too few columns, or its wavelength headers are invalid.
    """
    ensure_file_exists(coor_spectral_path, "坐标-光谱文件")

    # Read the file once: the header supplies the wavelengths, the body the samples.
    try:
        df = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0, dtype=float)
    except Exception as e:
        raise RetrievalError(f"无法读取坐标-光谱文件:{coor_spectral_path}\n原因:{e}") from e

    samples = df.to_numpy(dtype=float)
    if samples.ndim != 2 or samples.shape[0] < 1:
        raise RetrievalError("坐标-光谱文件维度异常:需要至少一行数据。")
    if samples.shape[1] <= SPEC_START_COL:
        raise RetrievalError(f"坐标-光谱文件列数不足(至少需要 {SPEC_START_COL+1} 列,含 4 列坐标信息 + ≥1 列光谱)。")

    # Wavelengths come from the spectral column names.
    try:
        wavelengths = df.columns[SPEC_START_COL:].astype(float).to_numpy()
    except Exception as e:
        raise RetrievalError(f"无法解析波长信息:{e}") from e

    if not np.all(np.isfinite(wavelengths)):
        raise RetrievalError("波长数据包含 NaN/Inf。")
    # Non-monotonic wavelengths are tolerated, but band matching may then pick
    # an unexpected column, so warn.
    if np.any(np.diff(wavelengths) <= 0):
        warnings.warn("波长非严格递增,这可能导致波段匹配误差。", RuntimeWarning)

    # Prepend a header row so that row 0 is NOT a sample. Downstream code
    # (get_mean_value, retrieval_*) indexes samples as coor_spectral[1:];
    # without this row the first sample would be silently dropped.
    header_row = np.full((1, samples.shape[1]), np.nan)
    header_row[0, SPEC_START_COL:] = wavelengths
    coor_spectral = np.vstack([header_row, samples])
    return coor_spectral, wavelengths
|
||
|
||
def find_index(wavelength, array):
    """Return the position of the element of *array* closest to *wavelength*.

    Args:
        wavelength: Target value (scalar).
        array: 1-D array-like of candidate values (list, tuple or ndarray).

    Returns:
        ``int`` index of the nearest element; on ties the first one wins
        (``np.argmin`` semantics).

    Raises:
        ValueError: if *array* is empty (raised by ``np.argmin``).
    """
    # np.asarray lets plain lists/tuples work; subtracting a scalar from a
    # list would otherwise raise TypeError.
    differences = np.abs(np.asarray(array) - wavelength)
    return int(np.argmin(differences))
|
||
|
||
def _clamp_window(index_abs, window, ncols, spec_start_col=SPEC_START_COL):
    """Return the half-open column range ``[left, right)`` centred on *index_abs*.

    The range extends ``window`` columns on each side of *index_abs* and is
    clipped to the spectral columns ``[spec_start_col, ncols)``.

    Args:
        index_abs: Absolute column index of the window centre.
        window: Half-width in columns; converted with ``int()``, so a
            non-integral float is truncated (e.g. ``5.0 -> 5``).
        ncols: Total number of columns in the data array.
        spec_start_col: First spectral column.

    Raises:
        RetrievalError: if *window* is missing, not convertible to a finite
            int, negative, or the clipped range is empty.
    """
    if window is None:
        raise RetrievalError("window 为空。")
    try:
        window = int(window)
    except (TypeError, ValueError, OverflowError) as e:
        # e.g. a string, NaN (ValueError) or inf (OverflowError): report it
        # as a user-facing error instead of an unexpected crash.
        raise RetrievalError(f"window 必须为非负整数,收到:{window}") from e
    if window < 0:
        raise RetrievalError(f"window 必须为非负整数,收到:{window}")

    left = max(spec_start_col, index_abs - window)
    right = min(ncols, index_abs + window + 1)
    if right <= left:
        raise RetrievalError(f"窗口无有效光谱列(left={left}, right={right}, ncols={ncols})。")
    return left, right
|
||
|
||
def get_mean_value(index_abs, array, window):
    """Average each sample row over a window of spectral columns.

    Args:
        index_abs: Absolute column index of the window centre. It already
            includes the ``SPEC_START_COL`` leading non-spectral columns.
        array: 2-D data array. Row 0 is skipped and only rows ``1..`` are
            averaged (row 0 is assumed to be a non-sample header row —
            confirm against ``safe_load_spectral``).
        window: Half-width of the window in columns; clipped by ``_clamp_window``.

    Returns:
        1-D array with one mean per row of ``array[1:]``. Warns with a
        RuntimeWarning if any mean is NaN/Inf.
    """
    col_lo, col_hi = _clamp_window(index_abs, window, array.shape[1], SPEC_START_COL)
    band_block = array[1:, col_lo:col_hi]
    means = band_block.mean(axis=1)
    if not np.isfinite(means).all():
        warnings.warn("均值结果包含 NaN/Inf,可能是窗口内存在异常值。", RuntimeWarning)
    return means
|
||
|
||
def calculate(x1, x2, coefficients):
    """Evaluate the linear model ``y = c0 * x1 + c1 * x2 + c2`` element-wise.

    Args:
        x1: Array-like of predictor values (flattened to 1-D float64).
        x2: Array-like of the same length as *x1* (flattened to 1-D float64).
        coefficients: Exactly three values ``(c0, c1, c2)``: the x1 weight,
            the x2 weight and the intercept.

    Returns:
        1-D float64 array of predictions. Positions where *x1* or *x2* is
        non-finite produce non-finite results, and a RuntimeWarning is issued.

    Raises:
        RetrievalError: on a length mismatch or a wrong coefficient count.
    """
    x1 = np.asarray(x1, dtype=np.float64).ravel()
    x2 = np.asarray(x2, dtype=np.float64).ravel()
    coeffs = np.asarray(coefficients, dtype=np.float64).reshape(-1)
    if x1.shape[0] != x2.shape[0]:
        raise RetrievalError(f"x1 与 x2 长度不一致: {x1.shape[0]} vs {x2.shape[0]}")
    if coeffs.size != 3:
        raise RetrievalError(f"线性模型系数应为 3 个(x1, x2, 截距),收到 {coeffs.size} 个。")

    # Diagnostic: report non-finite inputs through the warnings machinery,
    # like the rest of this module, rather than a bare print() to stdout.
    n_bad = int((~np.isfinite(x1) | ~np.isfinite(x2)).sum())
    if n_bad:
        warnings.warn(f"x 含 {n_bad} 个非有限值,对应结果将为 NaN/Inf。", RuntimeWarning)

    # Plain element-wise arithmetic instead of a dot product (deliberately avoids BLAS).
    return x1 * coeffs[0] + x2 * coeffs[1] + coeffs[2]
|
||
|
||
|
||
def _safe_polyval(coeffs, x, name):
|
||
coeffs = np.asarray(coeffs).reshape(-1)
|
||
if coeffs.ndim != 1 or coeffs.size < 1:
|
||
raise RetrievalError(f"{name} 的多项式系数非法。")
|
||
try:
|
||
y = np.polyval(coeffs, x)
|
||
except Exception as e:
|
||
raise RetrievalError(f"{name} 计算失败(polyval):{e}")
|
||
return y
|
||
|
||
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
    """Retrieve chlorophyll-a with a polynomial in a three-band ratio.

    Each band value R651/R707/R670 is the mean over ``window`` columns on
    either side of the column whose wavelength is nearest to that value.
    The predictor ``x = (R651 - R707) / (R707 - R670)`` is then fed to the
    model coefficients, which are evaluated as a polynomial.

    Args:
        model_info_path: Model JSON file (see ``safe_load_model``).
        coor_spectral_path: Coordinate/spectrum CSV (see ``safe_load_spectral``).
        output_path: CSV path for the result. ``None`` skips writing,
            consistent with ``retrieval_nh3``.
        window: Half-width of the band-averaging window, in columns.

    Returns:
        ndarray with columns ``(longitude, latitude, prediction)``, one row
        per row of ``coor_spectral[1:]``.

    Raises:
        RetrievalError: on any user-facing failure.
    """
    _model_type, model_info, _accuracy = safe_load_model(model_info_path)
    coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)

    def idx_abs_for(wave):
        # find_index returns a position within the spectral columns; shift it
        # to an absolute column index of coor_spectral.
        return SPEC_START_COL + find_index(wave, wavelengths)

    try:
        idx_651 = idx_abs_for(651)
        idx_707 = idx_abs_for(707)
        idx_670 = idx_abs_for(670)
    except Exception as e:
        raise RetrievalError(f"波段索引计算失败:{e}") from e

    band_651 = get_mean_value(idx_651, coor_spectral, window)
    band_707 = get_mean_value(idx_707, coor_spectral, window)
    band_670 = get_mean_value(idx_670, coor_spectral, window)

    # A zero denominator is reported below rather than raising a NumPy warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        x = (band_651 - band_707) / (band_707 - band_670)

    bad = ~np.isfinite(x)
    if bad.any():
        warnings.warn(f"chl_a 计算出现 {bad.sum()} 个无效比值(分母≈0 或含 NaN),这些位置结果将为 NaN。", RuntimeWarning)

    retrieval_result = _safe_polyval(model_info, x, "chl_a")

    result_df = pd.DataFrame({
        'longitude': coor_spectral[1:, 0],
        'latitude': coor_spectral[1:, 1],
        'prediction': retrieval_result,
    })

    if output_path is not None:
        try:
            result_df.to_csv(output_path, index=False, float_format='%.8f')
        except Exception as e:
            raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}") from e

    return result_df.values
|
||
|
||
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Retrieve NH3 with a two-predictor linear model.

    ``non_empirical_retrieval`` also routes mno4, tn and tp here, and
    ``retrieval_tss`` builds on it. The predictors are derived from the
    window-averaged bands R500, R600 and R850:

        x13 = ln(R500 / R850)
        x23 = exp(R600 / R500)

    and ``prediction = c0 * x13 + c1 * x23 + c2`` via ``calculate``.

    Args:
        model_info_path: Model JSON file (see ``safe_load_model``).
        coor_spectral_path: Coordinate/spectrum CSV (see ``safe_load_spectral``).
        output_path: CSV path for the result; ``None`` skips writing.
        window: Half-width of the band-averaging window, in columns.

    Returns:
        ndarray with columns ``(longitude, latitude, prediction)``, one row
        per row of ``coor_spectral[1:]``.

    Raises:
        RetrievalError: on any user-facing failure.
    """
    _model_type, model_info, _accuracy = safe_load_model(model_info_path)
    coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)

    def idx_abs_for(wave):
        # Position within the spectral columns -> absolute column index.
        return SPEC_START_COL + find_index(wave, wavelengths)

    # Wrapped like retrieval_chl_a so failures reach the user as RetrievalError.
    try:
        idx_600 = idx_abs_for(600)
        idx_500 = idx_abs_for(500)
        idx_850 = idx_abs_for(850)
    except Exception as e:
        raise RetrievalError(f"波段索引计算失败:{e}") from e

    band_600 = get_mean_value(idx_600, coor_spectral, window)
    band_500 = get_mean_value(idx_500, coor_spectral, window)
    band_850 = get_mean_value(idx_850, coor_spectral, window)

    # log of 0/negative, division by zero and exp overflow all yield non-finite
    # values; they are counted and reported once below instead of as raw NumPy warnings.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        x13 = np.log(band_500 / band_850)
        x23 = np.exp(band_600 / band_500)

    invalid = ~np.isfinite(x13) | ~np.isfinite(x23)
    if invalid.any():
        warnings.warn(f"nh3 自变量出现 {invalid.sum()} 个无效值(0/负数/NaN),对应位置结果将为 NaN。", RuntimeWarning)

    retrieval_result = calculate(x13, x23, model_info)

    result_df = pd.DataFrame({
        'longitude': coor_spectral[1:, 0],
        'latitude': coor_spectral[1:, 1],
        'prediction': retrieval_result,
    })

    if output_path is not None:
        try:
            result_df.to_csv(output_path, index=False, float_format='%.8f')
        except Exception as e:
            raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}") from e

    return result_df.values
|
||
|
||
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
    """Retrieve TSS by exponentiating the output of the NH3-form linear model.

    Runs ``retrieval_nh3`` (without writing a file) and applies ``exp`` to
    its prediction column. The model is presumably fitted on a log scale —
    confirm against training.

    Args:
        model_info_path: Model JSON file (see ``safe_load_model``).
        coor_spectral_path: Coordinate/spectrum CSV (see ``safe_load_spectral``).
        output_path: CSV path for the result. ``None`` skips writing,
            consistent with ``retrieval_nh3``.
        window: Half-width of the band-averaging window, in columns.

    Returns:
        ndarray with columns ``(longitude, latitude, prediction)``. Non-finite
        predictions (e.g. from ``exp`` overflow) are stored as NaN.

    Raises:
        RetrievalError: on any user-facing failure.
    """
    position_content = retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=window)

    # Overflow produces inf; it is normalised to NaN and reported below,
    # instead of emitting a raw NumPy overflow warning.
    with np.errstate(over='ignore', invalid='ignore'):
        predictions = np.exp(position_content[:, -1])

    nonfinite = ~np.isfinite(predictions)
    if nonfinite.any():
        # The warning promises NaN, so make the data match (inf -> NaN).
        predictions[nonfinite] = np.nan
        warnings.warn("tss 结果包含非有限值(可能因指数溢出),已保留为 NaN。", RuntimeWarning)

    result_df = pd.DataFrame({
        'longitude': position_content[:, 0],
        'latitude': position_content[:, 1],
        'prediction': predictions,
    })

    if output_path is not None:
        try:
            result_df.to_csv(output_path, index=False, float_format='%.8f')
        except Exception as e:
            raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}") from e

    return result_df.values
|
||
|
||
def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius=5.0):
    """Run the retrieval routine registered for *algorithm*.

    ``wave_radius`` is forwarded as the routine's ``window`` argument, i.e. a
    half-width in spectral columns (not in wavelength units).

    This is a process-level entry point: instead of raising, it prints to
    stderr and exits with status 2 for a ``RetrievalError`` and status 3 for
    any other exception.
    """
    try:
        # (accepted algorithm names, routine); nh3, mno4, tn and tp share one model form.
        routes = (
            (("chl_a",), retrieval_chl_a),
            (("nh3", "mno4", "tn", "tp"), retrieval_nh3),
            (("tss",), retrieval_tss),
        )
        for names, routine in routes:
            if algorithm in names:
                return routine(model_info_path, coor_spectral_path, output_path, wave_radius)
        raise RetrievalError(f"未知算法:{algorithm}(可选:chl_a / nh3 / mno4 / tn / tp / tss)")
    except RetrievalError as err:
        # Expected, user-facing failure.
        print(f"[错误] {err}", file=sys.stderr)
        sys.exit(2)
    except Exception as err:
        # Unexpected failure: include the exception type for diagnosis.
        print(f"[致命错误] {type(err).__name__}: {err}", file=sys.stderr)
        sys.exit(3)
|
||
|
||
if __name__ == "__main__":
    import argparse

    # The defaults reproduce the original hard-coded development run; each
    # value can now be overridden from the command line.
    parser = argparse.ArgumentParser(description="Non-empirical water-quality retrieval.")
    parser.add_argument("--algorithm", default="chl_a",
                        help="chl_a / nh3 / mno4 / tn / tp / tss")
    parser.add_argument("--model-info-path",
                        default=r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json")
    parser.add_argument("--coor-spectral-path",
                        default=r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv")
    parser.add_argument("--output-path",
                        default=r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv")
    parser.add_argument("--wave-radius", type=float, default=5.0,
                        help="half-width of the band-averaging window, in spectral columns")
    args = parser.parse_args()

    non_empirical_retrieval(args.algorithm, args.model_info_path, args.coor_spectral_path,
                            args.output_path, args.wave_radius)