Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

View File

@ -0,0 +1,253 @@
import sys
from src.utils.util import *
import warnings
import pandas as pd
import re # Added for regex parsing in safe_load_spectral
# 配置光谱起始列前四列是坐标和像素信息x_coord,y_coord,pixel_x,pixel_y
SPEC_START_COL = 4
class RetrievalError(Exception):
"""面向用户的友好错误。"""
pass
def ensure_file_exists(path, name):
if not isinstance(path, str) or not path:
raise RetrievalError(f"{name} 路径为空。")
if not os.path.exists(path):
raise RetrievalError(f"{name} 不存在:{path}")
def safe_load_model(model_info_path):
ensure_file_exists(model_info_path, "模型信息文件")
try:
model_type, model_info, accuracy_ = load_numpy_dict_from_json(model_info_path)
except Exception as e:
raise RetrievalError(f"无法读取/解析模型文件:{model_info_path}\n原因:{e}")
if model_info is None:
raise RetrievalError("模型文件缺少 'model_info'")
model_info = np.asarray(model_info)
if model_info.ndim == 0 or model_info.size == 0:
raise RetrievalError("模型系数为空。")
return model_type, model_info, accuracy_
def safe_load_spectral(coor_spectral_path):
ensure_file_exists(coor_spectral_path, "坐标-光谱文件")
# 使用 pandas 读取文件
try:
# 读取为 DataFrame跳过第一行列名明确指定数据类型为 float
df = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0, dtype=float)
# 转换为 numpy 数组以保持原有格式
coor_spectral = df.values
except Exception as e:
raise RetrievalError(f"无法读取坐标-光谱文件:{coor_spectral_path}\n原因:{e}")
if coor_spectral.ndim != 2 or coor_spectral.shape[0] < 1:
raise RetrievalError("坐标-光谱文件维度异常:需要至少一行数据。")
if coor_spectral.shape[1] <= SPEC_START_COL:
raise RetrievalError(f"坐标-光谱文件列数不足(至少需要 {SPEC_START_COL+1} 列,含 4 列坐标信息 + ≥1 列光谱)。")
# 由于第一行已经是数据,不再需要提取波长行
# 波长信息需要从列名中提取
try:
# 读取列名来获取波长信息
df_with_header = pd.read_csv(coor_spectral_path, encoding="utf-8-sig", header=0)
wavelengths = df_with_header.columns[SPEC_START_COL:].astype(float).values
except Exception as e:
raise RetrievalError(f"无法解析波长信息:{e}")
if not np.all(np.isfinite(wavelengths)):
raise RetrievalError("波长数据包含 NaN/Inf。")
# 非严格单调也可,但给出警告
if np.any(np.diff(wavelengths) <= 0):
warnings.warn("波长非严格递增,这可能导致波段匹配误差。", RuntimeWarning)
return coor_spectral, wavelengths
def find_index(wavelength, array):
differences = np.abs(array - wavelength)
min_position = int(np.argmin(differences))
return min_position
def _clamp_window(index_abs, window, ncols, spec_start_col=SPEC_START_COL):
if window is None:
raise RetrievalError("window 为空。")
window = int(window)
if window < 0:
raise RetrievalError(f"window 必须为非负整数,收到:{window}")
left = max(spec_start_col, index_abs - window)
right = min(ncols, index_abs + window + 1)
if right - left <= 0:
raise RetrievalError(f"窗口无有效光谱列left={left}, right={right}, ncols={ncols})。")
return left, right
def get_mean_value(index_abs, array, window):
"""index_abs 为绝对列索引(含前两列坐标),这里会夹紧窗口。"""
left, right = _clamp_window(index_abs, window, array.shape[1], SPEC_START_COL)
# 仅在样本行上取平均
result = array[1:, left:right].mean(axis=1)
if not np.all(np.isfinite(result)):
warnings.warn("均值结果包含 NaN/Inf可能是窗口内存在异常值。", RuntimeWarning)
return result
def calculate(x1, x2, coefficients):
x1 = np.asarray(x1, dtype=np.float64).ravel()
x2 = np.asarray(x2, dtype=np.float64).ravel()
coeffs = np.asarray(coefficients, dtype=np.float64).reshape(-1)
if x1.shape[0] != x2.shape[0]:
raise RetrievalError(f"x1 与 x2 长度不一致: {x1.shape[0]} vs {x2.shape[0]}")
if coeffs.size != 3:
raise RetrievalError(f"线性模型系数应为 3 个x1, x2, 截距),收到 {coeffs.size} 个。")
# 诊断:检查 NaN/Inf
n_bad = (~np.isfinite(x1) | ~np.isfinite(x2)).sum()
if n_bad:
print(f"[警告] x 含 {n_bad} 个非有限值,将产生 NaN。")
# 避免 dot/blAS直接逐元素计算
y_pred = x1 * coeffs[0] + x2 * coeffs[1] + coeffs[2]
return y_pred
def _safe_polyval(coeffs, x, name):
coeffs = np.asarray(coeffs).reshape(-1)
if coeffs.ndim != 1 or coeffs.size < 1:
raise RetrievalError(f"{name} 的多项式系数非法。")
try:
y = np.polyval(coeffs, x)
except Exception as e:
raise RetrievalError(f"{name} 计算失败polyval{e}")
return y
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
def idx_abs_for(wave):
idx_rel = find_index(wave, wavelengths) # 相对光谱起始列的索引
return SPEC_START_COL + idx_rel # 转为绝对列索引
try:
idx_651 = idx_abs_for(651)
idx_707 = idx_abs_for(707)
idx_670 = idx_abs_for(670)
except Exception as e:
raise RetrievalError(f"波段索引计算失败:{e}")
band_651 = get_mean_value(idx_651, coor_spectral, window)
band_707 = get_mean_value(idx_707, coor_spectral, window)
band_670 = get_mean_value(idx_670, coor_spectral, window)
with np.errstate(divide='ignore', invalid='ignore'):
denom = (band_707 - band_670)
x = (band_651 - band_707) / denom
bad = ~np.isfinite(x)
if bad.any():
warnings.warn(f"chl_a 极速出现 {bad.sum()} 个无效比值分母≈0 或含 NaN这些位置结果将为 NaN。", RuntimeWarning)
retrieval_result = _safe_polyval(model_info, x, "chl_a")
# 创建DataFrame并保存为CSV
result_df = pd.DataFrame({
'longitude': coor_spectral[1:, 0],
'latitude': coor_spectral[1:, 1],
'prediction': retrieval_result
})
try:
result_df.to_csv(output_path, index=False, float_format='%.8f')
except Exception as e:
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
return result_df.values
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
model_type, model_info, accuracy_ = safe_load_model(model_info_path)
coor_spectral, wavelengths = safe_load_spectral(coor_spectral_path)
def idx_abs_for(wave):
return SPEC_START_COL + find_index(wave, wavelengths)
idx_600 = idx_abs_for(600)
idx_500 = idx_abs_for(500)
idx_850 = idx_abs_for(850)
band_600 = get_mean_value(idx_600, coor_spectral, window)
band_500 = get_mean_value(idx_500, coor_spectral, window)
band_850 = get_mean_value(idx_850, coor_spectral, window)
with np.errstate(divide='ignore', invalid='ignore'):
x13 = np.log(band_500 / band_850)
x23 = np.exp(band_600 / band_500)
invalid = ~np.isfinite(x13) | ~np.isfinite(x23)
if invalid.any():
warnings.warn(f"nh3 自变量出现 {invalid.sum()} 个无效值0/负数/NaN对应位置结果将为 NaN。", RuntimeWarning)
retrieval_result = calculate(x13, x23, model_info)
# 创建DataFrame
result_df = pd.DataFrame({
'longitude': coor_spectral[1:, 0],
'latitude': coor_spectral[1:, 1],
'prediction': retrieval_result
})
if output_path is not None:
try:
result_df.to_csv(output_path, index=False, float_format='%.8f')
except Exception as e:
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
return result_df.values
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
# 先跑 nh3 的同型模型(按你的原逻辑)
position_content = retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=window)
# 对结果进行指数变换
predictions = np.exp(position_content[:, -1])
# 创建DataFrame
result_df = pd.DataFrame({
'longitude': position_content[:, 0],
'latitude': position_content[:, 1],
'prediction': predictions
})
if not np.all(np.isfinite(result_df['prediction'])):
warnings.warn("tss 结果包含非有限值(可能因指数溢出),已保留为 NaN。", RuntimeWarning)
try:
result_df.to_csv(output_path, index=False, float_format='%.8f')
except Exception as e:
raise RetrievalError(f"写出结果失败:{output_path}\n原因:{e}")
return result_df.values
def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius=5.0):
try:
if algorithm == "chl_a":
return retrieval_chl_a(model_info_path, coor_spectral_path, output_path, wave_radius)
elif algorithm in ["nh3", "mno4", "tn", "tp"]:
return retrieval_nh3(model_info_path, coor_spectral_path, output_path, wave_radius)
elif algorithm == "tss":
return retrieval_tss(model_info_path, coor_spectral_path, output_path, wave_radius)
else:
raise RetrievalError(f"未知算法:{algorithm}可选chl_a / nh3 / mno4 / tn / tp / tss")
except RetrievalError as e:
# 面向用户的友好错误
print(f"[错误] {e}", file=sys.stderr)
sys.exit(2)
except Exception as e:
# 未预料的异常,附带类型与少量上下文
print(f"[致命错误] {type(e).__name__}: {e}", file=sys.stderr)
sys.exit(3)
if __name__ == "__main__":
algorithm= "chl_a"
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\6_5_non_empirical_models\SS\SS_chl_a.json"
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\7_sampling\sampling_spectra.csv"
output_path= r"E:\code\WQ\pipeline_result\work_dir\8_predictions\SS_chl_a.csv"
wave_radius=5.0
non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius)