refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

View File

@ -20,23 +20,28 @@ class PipelineContext:
"""流水线运行上下文(在 14 个 step 之间传递的内存字典)
字段命名约定:
- 路径字段统一 `_path` 后缀(如 water_mask_path
- 目录类字段无 `_path` 后缀(如 models_dir
- 路径字段名 = panel key 名 = step 形参名(全链路无翻译
- 训练/产物 CSV 用 `_path` 后缀(如 training_csv_path / water_mask_path
- 入参影像/CSV 沿用 panel 原名img_path / csv_path无 `_path` 后缀
- 目录类字段无 `_path` 后缀(如 models_dir / prediction_dir
- 元信息字段无后缀(如 user_config / status / log
"""
# ── 9 步主路径(按 step 输出顺序排列) ──
raw_img_path: Optional[str] = None # Step 1 入参:原始影像
# ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参 ──
img_path: Optional[str] = None # Step 1/2/3 入参:原始影像
water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入
glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入
deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入
raw_csv_path: Optional[str] = None # Step 4 入:原始 CSV
csv_path: Optional[str] = None # Step 4/5/6_5/6_75:原始/训练 CSV
processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入
training_spectra_path: Optional[str] = None # Step 5 出 → Step 6
training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75
boundary_path: Optional[str] = None # Step 5 入参:边界 SHPpanel step5 名)
indices_path: Optional[str] = None # Step 5.5 出
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/9 入
prediction_csv_path: Optional[str] = None # Step 8 出
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入
prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入
distribution_map_path: Optional[str] = None # Step 9 出
boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHPpanel step9 名)
formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV
# ── 目录类(命名不带 _path 以示区别) ──
models_dir: Optional[str] = None

View File

@ -4,10 +4,8 @@ PipelineRunner基于 StepSpec 声明式调度 14 个 step。
设计要点:
- StepSpec 声明 requiresctx 字段名列表)+ producesctx 字段名列表)
- 默认约定ctx 字段名去掉 `_path` 后缀 = step 方法形参名
ctx.water_mask_path → 形参 water_mask
ctx.raw_img_path → 形参 raw_img
- 可被 spec.parameter_map 覆盖
- 命名约定ctx 字段名 == panel key 名 == step 形参名(全链路无翻译)
- 保留 spec.parameter_map 字段骨架供极少数特例覆盖(默认空 dict
- 调度顺序:按 PIPELINE_STEPS 列表顺序requires 缺则 skip
- 软取消:在每个 step 前检查 ctx.is_cancelled()
- duck-typed pipelinerunner 只调 getattr(pipeline, method_name),不强依赖类层级
@ -48,101 +46,76 @@ class StepSpec:
PIPELINE_STEPS: List[StepSpec] = [
StepSpec(
step_id="step1", method_name="step1_generate_water_mask",
requires=["raw_img_path"], produces=["water_mask_path"],
# ctx.raw_img_path → 形参 img_path老 step1 形参名是 img_path不是 raw_img
parameter_map={"raw_img_path": "img_path"},
requires=["img_path"], produces=["water_mask_path"],
description="水域掩膜生成NDWI 或 SHP",
),
StepSpec(
step_id="step2", method_name="step2_find_glint_area",
requires=["raw_img_path", "water_mask_path"], produces=["glint_mask_path"],
# raw_img_path→img_pathwater_mask_path 不变
parameter_map={"raw_img_path": "img_path"},
requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
description="耀斑区域检测",
),
StepSpec(
step_id="step3", method_name="step3_remove_glint",
requires=["deglint_img_path"], produces=["deglint_img_path"],
# deglint_img_path→img_path老 step3 形参名是 img_path
# 注意glint_mask_path 不在 requires 中——step3 形参表无该参数,内部走 self.glint_mask_path 回退
parameter_map={"deglint_img_path": "img_path"},
requires=["img_path", "water_mask_path", "glint_mask_path"],
produces=["deglint_img_path"],
description="耀斑去除",
),
StepSpec(
step_id="step4", method_name="step4_process_csv",
requires=["raw_csv_path"], produces=["processed_csv_path"],
# raw_csv_path→csv_path老 step4 形参名是 csv_path
parameter_map={"raw_csv_path": "csv_path"},
requires=["csv_path"], produces=["processed_csv_path"],
description="CSV 异常值清洗",
),
StepSpec(
step_id="step5", method_name="step5_extract_training_spectra",
requires=["deglint_img_path", "processed_csv_path"], produces=["training_spectra_path"],
# processed_csv_path→csv_path老 step5 形参名是 csv_pathdeglint_img_path 不变
parameter_map={"processed_csv_path": "csv_path"},
requires=["deglint_img_path", "csv_path", "boundary_path", "glint_mask_path"],
produces=["training_csv_path"],
description="实测样本点光谱提取",
),
StepSpec(
step_id="step5_5", method_name="step5_5_calculate_water_quality_indices",
requires=["training_spectra_path"], produces=["indices_path"],
# 老 step5.5 形参是 training_spectra_pathctx 字段同名,无需映射
parameter_map={},
requires=["training_csv_path"], produces=["indices_path"],
description="水质光谱指数计算optional",
),
StepSpec(
step_id="step6", method_name="step6_train_models",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→training_csv_path老 step6 形参名是 training_csv_path
parameter_map={"training_spectra_path": "training_csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="ML 建模GridSearchCV / AutoML",
),
StepSpec(
step_id="step6_5", method_name="step6_5_non_empirical_modeling",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→csv_path老 step6.5 形参名是 csv_path
parameter_map={"training_spectra_path": "csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="非经验统计回归",
),
StepSpec(
step_id="step6_75", method_name="step6_75_custom_regression",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→csv_path老 step6.75 形参名是 csv_path
parameter_map={"training_spectra_path": "csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="自定义回归分析",
),
StepSpec(
step_id="step7", method_name="step7_generate_sampling_points",
requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"],
# 老 step7 形参是 deglint_img_path / water_mask_pathctx 字段同名
parameter_map={},
description="整景密集采样点生成 + 光谱提取",
),
StepSpec(
step_id="step8", method_name="step8_predict_water_quality",
requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"],
parameter_map={},
description="ML 模型预测(采样点)",
),
StepSpec(
step_id="step8_5", method_name="step8_5_predict_with_non_empirical_models",
requires=["sampling_csv_path"], produces=["prediction_dir"],
parameter_map={},
requires=["sampling_csv_path", "models_dir"], produces=["prediction_dir"],
description="非经验模型预测",
),
StepSpec(
step_id="step8_75", method_name="step8_75_predict_with_custom_regression",
requires=["sampling_csv_path"], produces=["prediction_dir"],
parameter_map={},
requires=["sampling_csv_path", "models_dir", "formula_csv_path"],
produces=["prediction_dir"],
description="自定义回归预测",
),
StepSpec(
step_id="step9", method_name="step9_generate_distribution_map",
requires=["prediction_csv_path"],
requires=["prediction_csv_path", "boundary_shp_path"],
produces=["distribution_map_path"],
# 老 step9 形参是 prediction_csv_path / boundary_shp_pathctx 字段同名
# 注意sampling_csv_path / water_mask_path 不在 requires 中——step9 形参表无该参数,
# 内部走 self.sampling_csv_path / self.water_mask_path 回退
parameter_map={},
description="克里金插值成图",
),
]
@ -157,7 +130,7 @@ class PipelineRunner:
用法:
runner = PipelineRunner(pipeline_instance)
ctx = PipelineContext(raw_img_path=..., ...)
ctx = PipelineContext(img_path=..., ...)
result_ctx = runner.run(ctx)
"""

View File

@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
"""
Optuna + 智能子采样 AutoML 训练器(路线 B 防爆引擎)。
为什么需要这个:
- 老路径11 预处理 × 4 模型 × 3 划分 = 132 组 GridSearchCV
对中小数据集 10 分钟+,对大数据集 5w+ 行 直接 OOM
- AutoML 路径1 预处理 × N 模型Optuna 调超参),用智能子采样避开 OOM
再用最优超参在**全量数据**上 refit最终保存单一模型
设计要点:
- 入口 train_with_automl(csv, feature_start_column, model_names, ...)
- AutoMLResult dataclass 返回(每个目标列一份)
- smart_subsampleN > max_samples 时随机下采样
- 失败兜底optuna 未装 / 全 trial 失败 → fallback 到 WaterQualityModelingBatch
- 文件命名规范:{target}_{preprocess}_{model}_AUTOML.joblib
- save_data["metadata"]["automl"] = True 标记
调用:
from src.core.prediction.automl_trainer import train_with_automl
results = train_with_automl(
training_csv_path=".../training_spectra.csv",
feature_start_column="374.285004",
model_names=["RF", "SVR", "Ridge"],
n_trials=20,
timeout_sec=300,
)
"""
from __future__ import annotations
import json
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
# ============================================================
# 常量
# ============================================================
# AutoML 寻优阶段允许的最大样本数(避免 OOM
# 5000 样本对 RF/SVR/Ridge 的 Optuna 寻优足够给出稳定 CV
DEFAULT_MAX_SAMPLES = 5000
# 单次 Optuna trial 的默认超时(秒)
DEFAULT_TIMEOUT = 300.0
# 默认 trial 数
DEFAULT_N_TRIALS = 20
# AutoML 输出目录名后缀
AUTOML_DIR_SUFFIX = "_AutoML"
# ============================================================
# 数据类
# ============================================================
@dataclass
class AutoMLResult:
"""单个目标列的 AutoML 训练结果"""
success: bool = False
model_path: Optional[str] = None
cv_score: float = -float("inf")
best_params: Optional[Dict[str, Any]] = None
target_column: str = ""
preprocessing: str = ""
model_name: str = ""
n_trials_done: int = 0
n_samples_used: int = 0
fallback_used: bool = False
elapsed_sec: float = 0.0
error: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
# ============================================================
# 智能子采样
# ============================================================
def smart_subsample(
X: np.ndarray,
y: np.ndarray,
max_samples: int = DEFAULT_MAX_SAMPLES,
random_state: int = 42,
) -> Tuple[np.ndarray, np.ndarray, bool]:
"""当 N > max_samples 时随机下采样;否则原样返回。
Returns:
(X_sub, y_sub, was_subsampled)
"""
n = X.shape[0]
if n <= max_samples:
return X, y, False
rng = np.random.default_rng(random_state)
idx = rng.choice(n, size=max_samples, replace=False)
return X[idx], y[idx], True
# ============================================================
# 模型工厂
# ============================================================
def _build_model(model_name: str, random_state: int = 42):
"""根据英文模型键名构造 sklearn-compatible 模型实例factory"""
from sklearn.ensemble import (
AdaBoostRegressor, ExtraTreesRegressor, GradientBoostingRegressor,
RandomForestRegressor,
)
from sklearn.linear_model import (
ElasticNet, Lasso, LinearRegression, Ridge,
)
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
factory = {
"RF": lambda **kw: RandomForestRegressor(random_state=random_state, n_jobs=1, **kw),
"ET": lambda **kw: ExtraTreesRegressor(random_state=random_state, n_jobs=1, **kw),
"GradientBoosting": lambda **kw: GradientBoostingRegressor(random_state=random_state, **kw),
"AdaBoost": lambda **kw: AdaBoostRegressor(random_state=random_state, **kw),
"Ridge": lambda **kw: Ridge(**kw),
"Lasso": lambda **kw: Lasso(max_iter=5000, **kw),
"ElasticNet": lambda **kw: ElasticNet(max_iter=5000, **kw),
"LinearRegression": lambda **kw: LinearRegression(**kw),
"SVR": lambda **kw: SVR(**kw),
"KNN": lambda **kw: KNeighborsRegressor(n_jobs=1, **kw),
"MLP": lambda **kw: MLPRegressor(max_iter=500, random_state=random_state, **kw),
"DecisionTree": lambda **kw: DecisionTreeRegressor(random_state=random_state, **kw),
"PLS": None, # sklearn.cross_decomposition.PLSRegression 暂未集成
}
builder = factory.get(model_name)
if builder is None:
return None
return builder
# ============================================================
# Optuna 超参 search space
# ============================================================
def _get_search_space(model_name: str, trial) -> Dict[str, Any]:
"""按模型名返回 Optuna 超参 search space。"""
sp: Dict[str, Any] = {}
if model_name == "RF":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
sp["min_samples_leaf"] = trial.suggest_int("min_samples_leaf", 1, 5)
elif model_name == "ET":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
elif model_name == "GradientBoosting":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 8)
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
elif model_name == "SVR":
sp["C"] = trial.suggest_float("C", 0.1, 100.0, log=True)
sp["epsilon"] = trial.suggest_float("epsilon", 0.001, 1.0, log=True)
sp["kernel"] = trial.suggest_categorical("kernel", ["rbf", "linear"])
elif model_name == "KNN":
sp["n_neighbors"] = trial.suggest_int("n_neighbors", 3, 20)
sp["weights"] = trial.suggest_categorical("weights", ["uniform", "distance"])
elif model_name in ("Ridge", "Lasso", "ElasticNet"):
sp["alpha"] = trial.suggest_float("alpha", 0.01, 100.0, log=True)
if model_name == "ElasticNet":
sp["l1_ratio"] = trial.suggest_float("l1_ratio", 0.0, 1.0)
elif model_name == "MLP":
sp["hidden_layer_sizes"] = trial.suggest_categorical(
"hidden_layer_sizes", [(50,), (100,), (50, 50), (100, 50)]
)
sp["alpha"] = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
sp["learning_rate_init"] = trial.suggest_float("learning_rate_init", 1e-4, 1e-2, log=True)
elif model_name == "DecisionTree":
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
elif model_name == "AdaBoost":
sp["n_estimators"] = trial.suggest_int("n_estimators", 30, 200, step=30)
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 1.0, log=True)
else:
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 200, step=50)
return sp
def _make_objective(model_name: str, X: np.ndarray, y: np.ndarray,
cv_folds: int, random_state: int):
"""构造 Optuna objective5 折 CV R²"""
from sklearn.model_selection import KFold, cross_val_score
def objective(trial):
params = _get_search_space(model_name, trial)
try:
builder = _build_model(model_name, random_state=random_state)
if builder is None:
return -1.0
model = builder(**params)
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
scores = cross_val_score(model, X, y, cv=kf, scoring="r2", n_jobs=1)
return float(np.mean(scores))
except Exception:
return -1.0
return objective
def _refit_full(model_name: str, best_params: Dict[str, Any],
X: np.ndarray, y: np.ndarray, random_state: int):
"""用 best params 在**全量数据**上 refit。"""
builder = _build_model(model_name, random_state=random_state)
if builder is None:
return None
model = builder(**best_params)
model.fit(X, y)
return model
# ============================================================
# 失败兜底(回退到老 GridSearchCV 路径)
# ============================================================
def _fallback_train(
training_csv_path: str,
feature_start_column,
preprocessing: str,
model_name: str,
split_method: str,
cv_folds: int,
output_dir: Path,
target_column: str,
) -> AutoMLResult:
"""AutoML 失败时调老 WaterQualityModelingBatch。
返回的 AutoMLResult.fallback_used=True。
"""
try:
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
except ImportError as e:
return AutoMLResult(
success=False, error=f"fallback 导入失败: {e!r}", fallback_used=True,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
)
try:
out_dir = output_dir / preprocessing
out_dir.mkdir(parents=True, exist_ok=True)
modeler = WaterQualityModelingBatch(str(out_dir))
modeler.train_models_batch(
csv_path=training_csv_path,
feature_start_column=feature_start_column,
preprocessing_methods=[preprocessing],
model_names=[model_name],
split_methods=[split_method],
cv_folds=cv_folds,
)
# 找产出
candidates = list(out_dir.rglob(f"{target_column}_{preprocessing}_{model_name}.joblib"))
model_path = str(candidates[0]) if candidates else None
return AutoMLResult(
success=model_path is not None,
model_path=model_path,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
fallback_used=True,
metadata={"source": "WaterQualityModelingBatch"},
)
except Exception as e:
return AutoMLResult(
success=False, error=f"fallback 失败: {e!r}", fallback_used=True,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
)
# ============================================================
# 主入口
# ============================================================
def train_with_automl(
training_csv_path: str,
feature_start_column,
preprocessing_methods: Optional[List[str]] = None,
model_names: Optional[List[str]] = None,
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
output_dir: Optional[str] = None,
n_trials: int = DEFAULT_N_TRIALS,
timeout_sec: float = DEFAULT_TIMEOUT,
max_samples: int = DEFAULT_MAX_SAMPLES,
random_state: int = 42,
callback: Optional[Callable[[str, str, str], None]] = None,
) -> List[AutoMLResult]:
"""用 Optuna + 子采样跑 AutoML。失败时自动回退到 GridSearchCV。
Args:
training_csv_path: 训练用 CSVStep 5 产物 training_spectra.csv
feature_start_column: 特征起始列名或索引(之前所有列视为目标 y
preprocessing_methods: 候选预处理列表(**仅用第 1 个**,避免笛卡尔爆炸)
model_names: 候选模型列表(每个都会跑一遍 Optuna
split_methods: 候选数据划分列表AutoML 仅用第 1 个)
cv_folds: 交叉验证折数
output_dir: 输出目录(默认 <models_dir>_AutoML
n_trials: 单模型 Optuna trial 数
timeout_sec: 单模型超时(秒),到时强制停止
max_samples: 寻优阶段允许的最大样本数
callback: 状态回调 callback(step_name, status, message)
Returns:
List[AutoMLResult],每个目标列一份结果
"""
def notify(status: str, msg: str = "") -> None:
if callback:
callback("步骤6_AutoML", status, msg)
# ---- 1) 参数默认值 ----
if preprocessing_methods is None:
preprocessing_methods = ["MMS"]
if model_names is None:
model_names = ["RF", "SVR", "Ridge"]
if split_methods is None:
split_methods = ["spxy"]
# 决策:仅用第一个预处理 + 第一个划分,避免笛卡尔爆炸
preproc = preprocessing_methods[0]
split_method = split_methods[0]
if output_dir is None:
output_dir = "./7_Supervised_Model_Training_AutoML"
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
preproc_dir = out_dir / preproc
preproc_dir.mkdir(parents=True, exist_ok=True)
# ---- 2) 加载数据 ----
notify("start", f"AutoML 训练开始 (n_trials={n_trials}, timeout={timeout_sec}s, max_samples={max_samples})")
if not Path(training_csv_path).exists():
return [AutoMLResult(success=False, error=f"训练 CSV 不存在: {training_csv_path}")]
df = pd.read_csv(training_csv_path)
# 提取目标列feature_start_column 之前所有数值列)
if isinstance(feature_start_column, int):
y_cols = [c for c in df.columns[:feature_start_column]
if pd.api.types.is_numeric_dtype(df[c])]
else:
try:
idx = list(df.columns).index(feature_start_column)
y_cols = [c for c in df.columns[:idx]
if pd.api.types.is_numeric_dtype(df[c])]
except ValueError:
y_cols = []
if not y_cols:
notify("error", "AutoML: 未识别出目标列feature_start_column 之前的所有数值列)")
return [AutoMLResult(success=False, error="未识别出目标列")]
feat_cols = [c for c in df.columns if c not in y_cols]
X_all = df[feat_cols].values.astype(np.float64)
# ---- 3) 预处理(仅第一项) ----
if preproc != "None":
try:
from src.preprocessing.spectral_Preprocessing import Preprocessing
processed = Preprocessing(preproc, df[feat_cols])
if isinstance(processed, pd.DataFrame):
X_all = processed.values.astype(np.float64)
else:
X_all = np.asarray(processed, dtype=np.float64)
except Exception as e:
notify("warning", f"预处理 {preproc} 失败: {e!r},改用 None")
preproc = "None"
# ---- 4) 检查 Optuna 是否可用 ----
try:
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
optuna_available = True
except ImportError:
optuna_available = False
notify("warning", "optuna 未安装,全目标列回退到 GridSearchCVpip install \"optuna>=3.6\"")
# ---- 5) 逐 target 跑 ----
results: List[AutoMLResult] = []
total = len(y_cols)
per_model_timeout = max(10.0, timeout_sec / max(1, len(model_names)))
for ti, tgt in enumerate(y_cols, 1):
t0 = time.time()
yv = df[tgt].values.astype(np.float64)
mask = ~np.isnan(yv)
X_t = X_all[mask]
y_t = yv[mask]
if X_t.shape[0] < cv_folds * 2:
notify("warning", f"目标 {tgt}: 有效样本 {X_t.shape[0]} 不足,跳过")
results.append(AutoMLResult(
success=False, target_column=tgt, error=f"样本不足({X_t.shape[0]})",
preprocessing=preproc,
))
continue
X_sub, y_sub, was_sub = smart_subsample(X_t, y_t, max_samples=max_samples, random_state=random_state)
if was_sub:
notify("info", f"目标 {tgt}: {X_t.shape[0]} 样本 → 子采样 {X_sub.shape[0]}(寻优用)")
best_overall = AutoMLResult(success=False, target_column=tgt, preprocessing=preproc)
if not optuna_available:
# 全目标列一次性 fallback
best_overall = _fallback_train(
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
cv_folds, out_dir, tgt,
)
else:
for model_name in model_names:
try:
builder = _build_model(model_name, random_state=random_state)
if builder is None:
notify("warning", f"模型 {model_name} 暂不支持 AutoML 寻优")
continue
study = optuna.create_study(
direction="maximize",
sampler=optuna.samplers.TPESampler(seed=random_state),
)
study.optimize(
_make_objective(model_name, X_sub, y_sub, cv_folds, random_state),
n_trials=n_trials,
timeout=per_model_timeout,
show_progress_bar=False,
)
if study.best_value is None or study.best_value <= -1.0:
notify("warning", f"{tgt}/{model_name}: 全部 trial 失败CV 全部 <= -1")
continue
# refit on FULL
final_model = _refit_full(model_name, study.best_params, X_t, y_t, random_state)
if final_model is None:
continue
# 保存
import joblib
fname = f"{tgt}_{preproc}_{model_name}_AUTOML.joblib"
fpath = preproc_dir / fname
joblib.dump({
"model": final_model,
"target_column_name": tgt,
"preprocess_method": preproc,
"model_name": model_name,
"metadata": {
"automl": True,
"best_params": study.best_params,
"cv_score": float(study.best_value),
"n_trials_done": len(study.trials),
"n_samples_used_full": int(X_t.shape[0]),
"n_samples_used_for_search": int(X_sub.shape[0]),
"was_subsampled": was_sub,
"split_method": split_method,
},
}, fpath)
cand = AutoMLResult(
success=True,
model_path=str(fpath),
cv_score=float(study.best_value),
best_params=study.best_params,
target_column=tgt,
preprocessing=preproc,
model_name=model_name,
n_trials_done=len(study.trials),
n_samples_used=int(X_sub.shape[0]),
metadata={"refit_on_full": True, "n_samples_full": int(X_t.shape[0])},
)
if cand.cv_score > best_overall.cv_score:
best_overall = cand
except Exception as e:
notify("warning", f"目标 {tgt} / 模型 {model_name} 失败: {e!r}")
continue
if not best_overall.success:
notify("warning", f"目标 {tgt} 全部 Optuna trial 失败,回退 GridSearchCV")
best_overall = _fallback_train(
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
cv_folds, out_dir, tgt,
)
best_overall.elapsed_sec = time.time() - t0
results.append(best_overall)
notify("info", f"AutoML 目标 {tgt} 完成 ({ti}/{total}) cv={best_overall.cv_score:.4f}")
# ---- 6) 汇总 json ----
summary_path = out_dir / "automl_summary.json"
try:
with open(summary_path, "w", encoding="utf-8") as f:
json.dump([asdict(r) for r in results], f, ensure_ascii=False, indent=2, default=str)
except Exception as e:
notify("warning", f"写 automl_summary.json 失败: {e!r}")
success_n = sum(1 for r in results if r.success)
fallback_n = sum(1 for r in results if r.fallback_used)
notify("completed", f"AutoML 训练完成 {success_n}/{len(results)} 成功({fallback_n} 走 fallback汇总 {summary_path}")
return results
# ============================================================
# CLI 自测
# ============================================================
if __name__ == "__main__":
import argparse
p = argparse.ArgumentParser(description="AutoML 训练器 CLI 自测")
p.add_argument("--csv", required=True, help="训练用 CSVfeature_start_column 之前的列为目标 y")
p.add_argument("--feature-start", default="0", help="特征起始列名或索引(默认 0")
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
args = p.parse_args()
# 智能推断 feature_start_column 类型
fsc: Any = args.feature_start
try:
fsc = int(fsc)
except ValueError:
pass
res = train_with_automl(
training_csv_path=args.csv,
feature_start_column=fsc,
n_trials=args.n_trials,
timeout_sec=args.timeout,
max_samples=args.max_samples,
output_dir=args.out,
)
print(f"\n训练完成 {len(res)} 个目标")
for r in res:
marker = "" if r.success else ""
fb = " [fallback]" if r.fallback_used else ""
print(f" {marker} {r.target_column}: cv={r.cv_score:.4f} path={r.model_path}{fb}")

View File

@ -126,7 +126,7 @@ class DataPreparationStep:
@staticmethod
def calculate_water_quality_indices(
training_spectra_path: Optional[str] = None,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
@ -153,8 +153,8 @@ class DataPreparationStep:
notify("skipped", "跳过水质指数计算")
return None
if training_spectra_path is None:
raise ValueError("必须提供 training_spectra_path 参数")
if training_csv_path is None:
raise ValueError("必须提供 training_csv_path 参数")
if formula_csv_file is None:
raise ValueError("必须提供 formula_csv_file 参数")
@ -170,7 +170,7 @@ class DataPreparationStep:
from src.utils.band_math import BandMathCalculator
calculator = BandMathCalculator(training_spectra_path)
calculator = BandMathCalculator(training_csv_path)
result_df = calculator.process_formulas_from_csv(
formula_csv_file=formula_csv_file,
formula_names=formula_names,

View File

@ -173,7 +173,7 @@ class WaterQualityInversionPipeline:
self.interpolated_img_path = None # 存储插值后的影像路径
self.deglint_img_path = None
self.processed_csv_path = None
self.training_spectra_path = None
self.training_csv_path = None
self.indices_path = None
self.custom_regression_path = None
@ -511,7 +511,7 @@ class WaterQualityInversionPipeline:
left_shoulder_wave: Optional[float] = None,
valley_wave: Optional[float] = None,
right_shoulder_wave: Optional[float] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
water_mask_path: Optional[Union[str, np.ndarray]] = None,
interpolate_zeros: bool = False,
interpolation_method: str = 'nearest',
enabled: bool = True,
@ -546,7 +546,7 @@ class WaterQualityInversionPipeline:
left_shoulder_wave=left_shoulder_wave,
valley_wave=valley_wave,
right_shoulder_wave=right_shoulder_wave,
water_mask=water_mask,
water_mask=water_mask_path,
interpolate_zeros=interpolate_zeros,
interpolation_method=interpolation_method,
enabled=enabled,
@ -655,13 +655,13 @@ class WaterQualityInversionPipeline:
water_mask_path=self.water_mask_path,
output_dir=str(self.training_spectra_dir),
)
self.training_spectra_path = result
self.training_csv_path = result
self._record_step_time("步骤5: 提取训练样本点光谱", 0, 0)
self._notify("completed", f"训练光谱数据已保存: {result}")
return result
def step5_5_calculate_water_quality_indices(self,
training_spectra_path: Optional[str] = None,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
@ -669,29 +669,29 @@ class WaterQualityInversionPipeline:
skip_dependency_check: bool = False) -> str:
"""
步骤5.5: 根据训练光谱计算水质光谱指数
使用band_math.py中的方法实现支持从公式CSV文件中批量计算指定公式
Args:
training_spectra_path: 训练光谱数据CSV路径如果为None使用步骤5的结果
training_csv_path: 训练光谱数据CSV路径如果为None使用步骤5的结果
formula_csv_file: 公式CSV文件路径包含公式名称和具体公式
formula_names: 要计算的公式名称列表如果为None则计算所有公式
output_file: 输出文件完整路径支持绝对路径如果为None则使用默认路径
Returns:
包含计算结果的新CSV文件路径
"""
# 参数解析(保留原逻辑)
if training_spectra_path is not None:
csv_path = training_spectra_path
elif self.training_spectra_path is not None:
csv_path = self.training_spectra_path
if training_csv_path is not None:
csv_path = training_csv_path
elif self.training_csv_path is not None:
csv_path = self.training_csv_path
else:
csv_path = None
self._notify("started", "步骤5.5: 计算水质光谱指数")
result = DataPreparationStep.calculate_water_quality_indices(
training_spectra_path=csv_path,
training_csv_path=csv_path,
formula_csv_file=formula_csv_file,
formula_names=formula_names,
output_file=output_file,
@ -727,8 +727,8 @@ class WaterQualityInversionPipeline:
# 参数解析(保留原逻辑)
if training_csv_path is not None:
final_csv_path = training_csv_path
elif self.training_spectra_path is not None:
final_csv_path = self.training_spectra_path
elif self.training_csv_path is not None:
final_csv_path = self.training_csv_path
else:
final_csv_path = None
@ -911,7 +911,7 @@ class WaterQualityInversionPipeline:
print("="*80)
if training_csv_path is None:
training_csv_path = self.training_spectra_path
training_csv_path = self.training_csv_path
if training_csv_path is None:
raise ValueError("请提供训练数据CSV路径或先执行步骤5")
@ -1033,7 +1033,7 @@ class WaterQualityInversionPipeline:
print("="*80)
if csv_path is None:
csv_path = self.training_spectra_path
csv_path = self.training_csv_path
if csv_path is None:
raise ValueError("请提供CSV文件路径或先执行步骤5")
@ -1506,7 +1506,7 @@ class WaterQualityInversionPipeline:
if 'step5' in config:
self._notify("步骤5: 光谱提取", "start")
self.step5_extract_training_spectra(**config['step5'])
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_spectra_path})")
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_csv_path})")
else:
self._notify("步骤5: 光谱提取", "skipped", "未配置")
@ -1615,7 +1615,7 @@ class WaterQualityInversionPipeline:
# 生成散点图
if 'visualization' in config and config['visualization'].get('generate_scatter', True):
if self.training_spectra_path and self.models_dir.exists():
if self.training_csv_path and self.models_dir.exists():
try:
self._notify("可视化", "info", "生成模型评估散点图...")
scatter_config = config['visualization'].get('scatter_config', {})
@ -1653,7 +1653,7 @@ class WaterQualityInversionPipeline:
# 生成光谱曲线图
if 'visualization' in config and config['visualization'].get('generate_spectrum', True):
if self.training_spectra_path:
if self.training_csv_path:
try:
self._notify("可视化", "info", "生成光谱曲线对比图...")
spectrum_paths = self.generate_spectrum_comparison_plots(
@ -1701,7 +1701,7 @@ class WaterQualityInversionPipeline:
pipeline_info['step2'] = {'status': 'completed', 'output_file': str(self.glint_mask_path) if self.glint_mask_path else 'N/A'}
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_spectra_path) if self.training_spectra_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
@ -1784,8 +1784,8 @@ class WaterQualityInversionPipeline:
# 参数解析(保留原逻辑)
if csv_path is not None:
final_csv_path = csv_path
elif self.training_spectra_path is not None:
final_csv_path = self.training_spectra_path
elif self.training_csv_path is not None:
final_csv_path = self.training_csv_path
else:
final_csv_path = None
@ -2109,7 +2109,7 @@ def main():
'interpolation_method': 'bilinear', # 插值方法: 'nearest'(邻近), 'bilinear'(双线性),
# 'spline'(样条), 'kriging'(克里金)
# 水域掩膜参数(可选):
'water_mask':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜也可以提供
'water_mask_path':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜也可以提供
# # - numpy数组
# # - 栅格文件路径(.dat/.tif)
# # - shapefile路径(.shp)

View File

@ -0,0 +1,430 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
图表与交互弹窗模块
包含 ChartViewerDialog、ChartBrowserDialog 和 InteractiveViewerDialog 类。
"""
import numpy as np
from PyQt5.QtWidgets import (
QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
QSizePolicy, QFileDialog, QMessageBox, QGroupBox,
QListWidget, QLabel, QComboBox, QCheckBox,
)
from PyQt5.QtCore import Qt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
class ChartViewerDialog(QDialog):
"""图表查看器对话框"""
def __init__(self, title="图表查看器", parent=None):
super().__init__(parent)
self.setWindowTitle(title)
self.resize(1000, 700)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
self.figure = Figure(figsize=(10, 7))
self.canvas = FigureCanvas(self.figure)
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.toolbar = NavigationToolbar(self.canvas, self)
layout.addWidget(self.toolbar)
layout.addWidget(self.canvas)
btn_layout = QHBoxLayout()
self.save_btn = QPushButton("保存图表")
self.save_btn.clicked.connect(self.save_chart)
btn_layout.addWidget(self.save_btn)
btn_layout.addStretch()
self.close_btn = QPushButton("关闭")
self.close_btn.clicked.connect(self.close)
btn_layout.addWidget(self.close_btn)
layout.addLayout(btn_layout)
self.setLayout(layout)
def display_image(self, image_path):
"""显示图片"""
self.figure.clear()
ax = self.figure.add_subplot(111)
try:
import matplotlib.image as mpimg
img = mpimg.imread(image_path)
ax.imshow(img)
ax.axis('off')
self.figure.tight_layout()
self.canvas.draw()
self.current_image_path = image_path
except Exception as e:
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
def display_custom_plot(self, plot_func):
"""显示自定义绘图函数"""
self.figure.clear()
try:
plot_func(self.figure)
self.canvas.draw()
except Exception as e:
ax = self.figure.add_subplot(111)
ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
def save_chart(self):
"""保存图表"""
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图表", "",
"PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
)
if file_path:
try:
self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class ChartBrowserDialog(QDialog):
"""图表浏览器对话框"""
def __init__(self, chart_files, parent=None):
super().__init__(parent)
self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True)
self.current_index = 0
self.setWindowTitle("图表浏览器")
self.resize(1200, 800)
self.init_ui()
self.show_chart(0)
def init_ui(self):
layout = QVBoxLayout()
list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
list_layout = QHBoxLayout()
self.chart_list = QListWidget()
self.chart_list.setMaximumHeight(150)
for chart_file in self.chart_files:
self.chart_list.addItem(chart_file.name)
self.chart_list.currentRowChanged.connect(self.show_chart)
list_layout.addWidget(self.chart_list)
list_group.setLayout(list_layout)
layout.addWidget(list_group)
self.figure = Figure(figsize=(12, 8))
self.canvas = FigureCanvas(self.figure)
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.toolbar = NavigationToolbar(self.canvas, self)
layout.addWidget(self.toolbar)
layout.addWidget(self.canvas, 1)
btn_layout = QHBoxLayout()
self.prev_btn = QPushButton("◀ 上一个")
self.prev_btn.clicked.connect(self.prev_chart)
btn_layout.addWidget(self.prev_btn)
self.next_btn = QPushButton("下一个 >")
self.next_btn.clicked.connect(self.next_chart)
btn_layout.addWidget(self.next_btn)
btn_layout.addStretch()
self.save_btn = QPushButton("💾 保存当前图表")
self.save_btn.clicked.connect(self.save_current_chart)
btn_layout.addWidget(self.save_btn)
self.close_btn = QPushButton("关闭")
self.close_btn.clicked.connect(self.close)
btn_layout.addWidget(self.close_btn)
layout.addLayout(btn_layout)
self.setLayout(layout)
def show_chart(self, index):
"""显示指定索引的图表"""
if 0 <= index < len(self.chart_files):
self.current_index = index
self.chart_list.setCurrentRow(index)
chart_file = self.chart_files[index]
self.figure.clear()
ax = self.figure.add_subplot(111)
try:
import matplotlib.image as mpimg
img = mpimg.imread(str(chart_file))
ax.imshow(img)
ax.axis('off')
ax.set_title(chart_file.name, fontsize=12, pad=10)
self.figure.tight_layout()
self.canvas.draw()
except Exception as e:
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
self.prev_btn.setEnabled(index > 0)
self.next_btn.setEnabled(index < len(self.chart_files) - 1)
def prev_chart(self):
"""上一个图表"""
if self.current_index > 0:
self.show_chart(self.current_index - 1)
def next_chart(self):
"""下一个图表"""
if self.current_index < len(self.chart_files) - 1:
self.show_chart(self.current_index + 1)
def save_current_chart(self):
"""保存当前图表"""
if 0 <= self.current_index < len(self.chart_files):
current_file = self.chart_files[self.current_index]
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图表", current_file.name,
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
)
if file_path:
try:
import shutil
shutil.copy(str(current_file), file_path)
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class InteractiveViewerDialog(QDialog):
"""交互式影像预览对话框:显示影像、参考点散点图、点击查询坐标/值"""
def __init__(self, parent, img_path, ref_csv=None):
super().__init__(parent)
self.img_path = img_path
self.ref_csv = ref_csv
self.geotransform = None
self.fig = None
self.canvas = None
self.ax = None
self.status_label = None
self.init_ui()
def init_ui(self):
self.setWindowTitle("👁️ 交互式影像预览")
self.setMinimumSize(900, 700)
layout = QVBoxLayout()
toolbar = QHBoxLayout()
self.band_combo = QComboBox()
self.band_combo.currentIndexChanged.connect(self.on_band_changed)
toolbar.addWidget(QLabel("显示波段:"))
toolbar.addWidget(self.band_combo)
self.gray_check = QCheckBox("灰度显示")
self.gray_check.stateChanged.connect(self.on_band_changed)
toolbar.addWidget(self.gray_check)
toolbar.addStretch()
layout.addLayout(toolbar)
try:
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib
matplotlib.use('Qt5Agg')
self.fig = Figure(figsize=(10, 8))
self.canvas = FigureCanvas(self.fig)
self.ax = self.fig.add_subplot(111)
self.fig.tight_layout()
layout.addWidget(self.canvas)
self.load_and_display()
except ImportError as e:
layout.addWidget(QLabel(f"Matplotlib 未安装: {e}"))
self.status_label = QLabel("点击影像查看像素坐标和经纬度")
self.status_label.setStyleSheet("background:#f0f0f0;padding:4px;font-size:12px;")
self.status_label.setWordWrap(True)
layout.addWidget(self.status_label)
close_btn = QPushButton("关闭")
close_btn.clicked.connect(self.close)
layout.addWidget(close_btn)
self.setLayout(layout)
def load_and_display(self):
"""加载影像并显示"""
from osgeo import gdal
dataset = gdal.Open(self.img_path)
if dataset is None:
self.status_label.setText(f"无法打开影像: {self.img_path}")
return
self.geotransform = dataset.GetGeoTransform()
self.projection = dataset.GetProjection()
n_bands = dataset.RasterCount
self.height = dataset.RasterYSize
self.width = dataset.RasterXSize
self.band_combo.clear()
if n_bands >= 3:
for i in range(1, n_bands + 1):
self.band_combo.addItem(f"RGB (B{i-0}, G{i-1}, R{i-2})" if i >= 3 else f"波段 {i}", i)
self.band_combo.addItem(f"单波段 (B1)", 0)
else:
for i in range(1, n_bands + 1):
self.band_combo.addItem(f"波段 {i}", i - 1)
self.band_combo.setCurrentIndex(0)
self.dataset = dataset
self.display_band(0, is_gray=False)
self.load_ref_points()
def display_band(self, band_idx, is_gray=False):
"""显示指定波段组合"""
from osgeo import gdal
import numpy as np
dataset = self.dataset
self.ax.clear()
if is_gray or (self.band_combo.currentData() == 0 and dataset.RasterCount == 1):
band = dataset.GetRasterBand(1 if band_idx == 0 else band_idx + 1)
data = band.ReadAsArray()
data = np.nan_to_num(data, nan=0.0)
self.ax.imshow(data, cmap='gray')
self.ax.set_title(f"波段 {band_idx + 1} (灰度)")
else:
n = min(3, dataset.RasterCount)
bands_data = []
for i in range(n):
b = dataset.GetRasterBand(i + 1)
bd = b.ReadAsArray()
bd = np.nan_to_num(bd, nan=0.0)
bands_data.append(bd)
rgb = np.dstack(bands_data)
for i in range(rgb.shape[2]):
p2, p98 = np.percentile(rgb[:, :, i], [2, 98])
if p98 > p2:
rgb[:, :, i] = np.clip((rgb[:, :, i] - p2) / (p98 - p2), 0, 1)
else:
rgb[:, :, i] = np.clip(rgb[:, :, i] / (p98 + 1e-6), 0, 1)
self.ax.imshow(rgb)
self.ax.set_title(f"RGB 显示")
self.ax.set_xlabel("列 (Column)")
self.ax.set_ylabel("行 (Row)")
self.fig.tight_layout()
self.canvas.draw()
self.cid = self.canvas.mpl_connect('button_press_event', self.on_click)
def on_band_changed(self):
"""波段选择变化时更新显示"""
if not hasattr(self, 'dataset'):
return
is_gray = self.gray_check.isChecked()
band_data = self.band_combo.currentData()
self.display_band(band_data if band_data != 0 else 0, is_gray=is_gray)
def load_ref_points(self):
"""加载并显示参考点"""
import os
if not self.ref_csv or not os.path.isfile(self.ref_csv):
return
try:
import csv
lon_list, lat_list = [], []
with open(self.ref_csv, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
try:
lon = float(row.get('Lon', row.get('lon', row.get('LON', 0))))
lat = float(row.get('Lat', row.get('lat', row.get('LAT', 0))))
if lon and lat:
lon_list.append(lon)
lat_list.append(lat)
except (ValueError, TypeError):
continue
if not lon_list:
return
px_list, py_list = [], []
gt = self.geotransform
if gt and (gt[1] != 0 or gt[5] != 0):
for lon, lat in zip(lon_list, lat_list):
px = (lon - gt[0]) / gt[1]
py = (lat - gt[3]) / gt[5]
if 0 <= px < self.width and 0 <= py < self.height:
px_list.append(px)
py_list.append(py)
if px_list:
self.ax.scatter(px_list, py_list, c='red', s=40, marker='o',
edgecolors='white', linewidths=0.8, zorder=5, alpha=0.9,
label=f'参考点 ({len(px_list)}个)')
self.ax.legend(loc='upper right', fontsize=9)
self.fig.tight_layout()
self.canvas.draw()
self.status_label.setText(
f"已加载 {len(px_list)} 个参考点(仅显示在影像范围内的点)"
)
except Exception as e:
self.status_label.setText(f"加载参考点失败: {e}")
def pixel_to_geo(self, px, py):
"""像素坐标转经纬度"""
gt = self.geotransform
if gt is None:
return None, None
lon = gt[0] + px * gt[1] + py * gt[2]
lat = gt[3] + px * gt[4] + py * gt[5]
return lon, lat
def on_click(self, event):
"""鼠标点击事件"""
if event.inaxes != self.ax or event.xdata is None or event.ydata is None:
return
px, py = int(round(event.xdata)), int(round(event.ydata))
if not (0 <= px < self.width and 0 <= py < self.height):
return
from osgeo import gdal
import numpy as np
dataset = self.dataset
n_bands = dataset.RasterCount
vals = []
for b in range(1, n_bands + 1):
val = dataset.GetRasterBand(b).ReadAsArray()[py, px]
vals.append(f"{val:.4f}" if isinstance(val, float) else str(val))
lon, lat = self.pixel_to_geo(px, py)
geo_str = f"Lon={lon:.6f}, Lat={lat:.6f}" if lon is not None else "无地理参考"
self.status_label.setText(
f"像素: (行={py}, 列={px}) | {geo_str} | "
f"波段值: {' | '.join(vals[:5])}" +
(f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "")
)

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据模型模块
包含 PandasTableModel 等数据模型类。
"""
import pandas as pd
from PyQt5.QtCore import Qt, QAbstractTableModel
class PandasTableModel(QAbstractTableModel):
"""支持DataFrame的表格模型"""
def __init__(self, data_frame: pd.DataFrame):
super().__init__()
self._data = data_frame.copy()
if self._data.empty:
self._data = pd.DataFrame()
self._data.fillna("", inplace=True)
self._columns = [str(col) for col in self._data.columns]
def rowCount(self, parent=None):
return len(self._data)
def columnCount(self, parent=None):
return len(self._columns)
def data(self, index, role=Qt.DisplayRole):
if not index.isValid() or role != Qt.DisplayRole:
return None
value = self._data.iat[index.row(), index.column()]
if pd.isna(value):
return ""
return str(value)
def headerData(self, section, orientation, role=Qt.DisplayRole):
if role != Qt.DisplayRole:
return None
if orientation == Qt.Horizontal:
if section < len(self._columns):
return self._columns[section]
return str(section)
return str(section + 1)
def flags(self, index):
if not index.isValid():
return Qt.NoItemFlags
return Qt.ItemIsEnabled | Qt.ItemIsSelectable

View File

@ -0,0 +1,351 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
图像浏览组件模块
包含 ImageCategoryTree 和 ImageViewerWidget 类。
"""
import os
from pathlib import Path
from typing import List, Optional
from PyQt5.QtWidgets import (
QTreeWidget, QTreeWidgetItem, QWidget, QVBoxLayout, QHBoxLayout,
QPushButton, QLabel, QScrollArea, QFrame, QGroupBox,
QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QPixmap
class ImageCategoryTree(QTreeWidget):
"""图像分类目录树 - 按类别组织图像文件"""
CATEGORIES = [
("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"),
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
("含量分布图", [], "📁"),
]
def __init__(self, parent=None):
super().__init__(parent)
self.setHeaderLabel("图像目录")
self.setMaximumWidth(300)
self.setMinimumWidth(250)
self.setup_categories()
self.setStyleSheet("""
QTreeWidget {
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f8f9fa;
}
QTreeWidget::item {
padding: 5px;
border-radius: 3px;
}
QTreeWidget::item:selected {
background-color: #0078D4;
color: white;
}
QTreeWidget::item:hover {
background-color: #e3f2fd;
}
""")
def setup_categories(self):
"""初始化类别节点"""
self.category_items = {}
for category_name, keywords, icon in self.CATEGORIES:
item = QTreeWidgetItem(self)
item.setText(0, f"{icon} {category_name}")
item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name})
item.setExpanded(True)
self.category_items[category_name] = item
def clear_all_images(self):
"""清除所有图像项"""
for category_item in self.category_items.values():
while category_item.childCount() > 0:
category_item.removeChild(category_item.child(0))
def add_image(self, file_path: Path, display_name: str = None):
"""添加图像到对应的类别"""
if display_name is None:
display_name = file_path.stem
category = self._determine_category(file_path.name)
category_item = self.category_items.get(category, self.category_items["含量分布图"])
image_item = QTreeWidgetItem(category_item)
image_item.setText(0, f" └─ {display_name}")
image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)})
image_item.setToolTip(0, str(file_path))
return image_item
def _determine_category(self, filename: str) -> str:
"""根据文件名确定类别"""
filename_lower = filename.lower()
for category_name, keywords, _ in self.CATEGORIES:
if any(keyword in filename_lower for keyword in keywords):
return category_name
return "含量分布图"
def scan_directory(self, work_dir: str):
"""扫描目录中的所有图像文件"""
self.clear_all_images()
work_path = Path(work_dir)
if not work_path.exists():
return
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
scan_roots: List[Path] = []
_viz = work_path / "14_visualization"
if _viz.is_dir():
scan_roots.append(_viz)
_wm = work_path / "1_water_mask"
if _wm.is_dir():
scan_roots.append(_wm)
if not scan_roots:
scan_roots.append(work_path)
seen_norm: set = set()
image_files: List[Path] = []
for root in scan_roots:
for ext in image_extensions:
for p in root.glob(f"**/{ext}"):
key = os.path.normcase(os.path.normpath(str(p.resolve())))
if key in seen_norm:
continue
seen_norm.add(key)
image_files.append(p)
for img_file in sorted(image_files):
if img_file.name.startswith('.') or 'thumb' in img_file.name.lower():
continue
self.add_image(img_file)
for category_name, item in self.category_items.items():
count = item.childCount()
if count > 0:
for cat_name, _, icon in self.CATEGORIES:
if cat_name == category_name:
item.setText(0, f"{icon} {category_name} ({count})")
break
def get_selected_image_path(self) -> Optional[str]:
"""获取当前选中的图像路径"""
selected_item = self.currentItem()
if not selected_item:
return None
data = selected_item.data(0, Qt.UserRole)
if data and data.get("type") == "image":
return data.get("path")
return None
class ImageViewerWidget(QWidget):
"""图像查看器组件 - 支持缩放、平移"""
def __init__(self, parent=None):
super().__init__(parent)
self.current_image_path = None
self.scale_factor = 1.0
self._update_timer = QTimer()
self._update_timer.setSingleShot(True)
self._update_timer.timeout.connect(self._do_update_display)
self._pending_scale = None
self.setup_ui()
def setup_ui(self):
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
toolbar = QHBoxLayout()
self.refresh_btn = QPushButton("🔄 刷新目录")
self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
toolbar.addWidget(self.refresh_btn)
separator = QFrame()
separator.setFrameShape(QFrame.VLine)
separator.setFrameShadow(QFrame.Sunken)
toolbar.addWidget(separator)
self.zoom_in_btn = QPushButton("🔍+")
self.zoom_in_btn.setToolTip("放大")
self.zoom_in_btn.setMaximumWidth(50)
toolbar.addWidget(self.zoom_in_btn)
self.zoom_out_btn = QPushButton("🔍-")
self.zoom_out_btn.setToolTip("缩小")
self.zoom_out_btn.setMaximumWidth(50)
toolbar.addWidget(self.zoom_out_btn)
self.fit_btn = QPushButton("⬜ 适应窗口")
self.fit_btn.setToolTip("适应窗口大小")
toolbar.addWidget(self.fit_btn)
self.original_btn = QPushButton("1:1 原始大小")
self.original_btn.setToolTip("原始大小")
toolbar.addWidget(self.original_btn)
toolbar.addStretch()
self.save_btn = QPushButton("💾 保存")
self.save_btn.setToolTip("保存当前图像")
toolbar.addWidget(self.save_btn)
layout.addLayout(toolbar)
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setStyleSheet("background-color: white;")
self.image_label = QLabel()
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setStyleSheet("background-color: white;")
self.scroll_area.setWidget(self.image_label)
layout.addWidget(self.scroll_area, 1)
status_layout = QHBoxLayout()
self.status_label = QLabel("就绪")
self.status_label.setStyleSheet("color: #666; font-size: 11px;")
status_layout.addWidget(self.status_label)
status_layout.addStretch()
layout.addLayout(status_layout)
self.setLayout(layout)
self.zoom_in_btn.clicked.connect(self.zoom_in)
self.zoom_out_btn.clicked.connect(self.zoom_out)
self.fit_btn.clicked.connect(self.fit_to_window)
self.original_btn.clicked.connect(self.original_size)
self.save_btn.clicked.connect(self.save_image)
def load_image(self, image_path: str):
"""加载并显示图像"""
if not image_path or not Path(image_path).exists():
self.image_label.setText("图像不存在")
self.status_label.setText("图像加载失败")
return
self.current_image_path = image_path
self.scale_factor = 1.0
pixmap = QPixmap(image_path)
if pixmap.isNull():
self.image_label.setText("无法加载图像")
self.status_label.setText("图像格式不支持")
return
self.original_pixmap = pixmap
self.fit_to_window()
file_info = Path(image_path).stat()
size_mb = file_info.st_size / (1024 * 1024)
self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")
def update_image_display(self):
"""更新图像显示 - 使用防抖避免频繁重绘卡顿"""
self._update_timer.stop()
self._pending_scale = self.scale_factor
self._update_timer.start(50)
def _do_update_display(self):
"""实际执行图像更新"""
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
return
if self._pending_scale is None:
return
if self._pending_scale > 2.0 or self._pending_scale < 0.5:
transform = Qt.FastTransformation
else:
transform = Qt.SmoothTransformation
scaled_pixmap = self.original_pixmap.scaled(
int(self.original_pixmap.width() * self._pending_scale),
int(self.original_pixmap.height() * self._pending_scale),
Qt.KeepAspectRatio,
transform
)
self.image_label.setPixmap(scaled_pixmap)
self._pending_scale = None
def wheelEvent(self, event):
"""鼠标滚轮缩放 - 实时响应"""
delta = event.angleDelta().y()
if delta > 0:
if self.scale_factor < 5.0:
self.scale_factor = min(self.scale_factor * 1.1, 5.0)
self.update_image_display()
else:
if self.scale_factor > 0.1:
self.scale_factor = max(self.scale_factor / 1.1, 0.1)
self.update_image_display()
event.accept()
def zoom_in(self):
"""放大"""
if self.scale_factor < 5.0:
self.scale_factor = min(self.scale_factor * 1.25, 5.0)
self.update_image_display()
def zoom_out(self):
"""缩小"""
if self.scale_factor > 0.1:
self.scale_factor = max(self.scale_factor / 1.25, 0.1)
self.update_image_display()
def fit_to_window(self):
"""适应窗口"""
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
return
view_size = self.scroll_area.viewport().size()
img_size = self.original_pixmap.size()
scale_w = view_size.width() / img_size.width()
scale_h = view_size.height() / img_size.height()
self._fit_scale = min(scale_w, scale_h)
self.scale_factor = self._fit_scale
self.update_image_display()
self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")
def original_size(self):
"""原始大小"""
self.scale_factor = 1.0
self._fit_scale = None
self.update_image_display()
self.status_label.setText("原始大小 | 缩放: 100%")
def save_image(self):
"""保存图像"""
if not self.current_image_path:
return
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图像", Path(self.current_image_path).name,
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
)
if file_path:
try:
import shutil
shutil.copy(self.current_image_path, file_path)
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败: {e}")

View File

@ -0,0 +1,112 @@
import time
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression
# 屏蔽烦人的 sklearn 警告
warnings.filterwarnings("ignore")
print("====== 🚀 启动 Mega Water 模型终极体检脚本 ======")
# ---------------------------------------------------------
# 1. 完美复刻侦察报告中的 CSV 数据结构
# 报告指出: 目标值(y)在左边,光谱特征(X)在右边
# ---------------------------------------------------------
print("📦 正在生成符合系统结构的模拟测试数据...")
X_raw, y_raw = make_regression(n_samples=200, n_features=50, noise=0.1, random_state=42)
# 模拟真实的 CSV 列名前2列是水质参数后面是 50 个光谱波段
columns = ['Chla', 'SS'] + [f"Band_{i}" for i in range(50)]
# 拼装成一整张大表
data = pd.DataFrame(np.hstack((y_raw.reshape(-1, 1), (y_raw * 0.5).reshape(-1, 1), X_raw)), columns=columns)
# 按照 load_data_batch 的逻辑进行切割
feature_start_index = 2
X = data.iloc[:, feature_start_index:] # 截取光谱作为 X
y = data['Chla'] # 提取一个目标参数作为 y
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"✅ 数据切割完毕! 模拟波段数: {X.shape[1]}, 训练集样本数: {X_train.shape[0]}\n")
# ---------------------------------------------------------
# 2. 严格装载侦察报告中的 16 个真实模型
# ---------------------------------------------------------
print("🔍 正在加载底层真实配置库中的模型...")
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.cross_decomposition import PLSRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor
# 将参数压至极低,实施“降维打击”,确保 1 秒内跑完
models = {
'SVR': SVR(),
'RF': RandomForestRegressor(n_estimators=10, max_depth=5, n_jobs=-1),
'KNN': KNeighborsRegressor(),
'LinearRegression': LinearRegression(),
'Ridge': Ridge(),
'Lasso': Lasso(),
'ElasticNet': ElasticNet(),
'PLS': PLSRegression(),
'GradientBoosting': GradientBoostingRegressor(n_estimators=10, max_depth=5),
'AdaBoost': AdaBoostRegressor(n_estimators=10),
'DecisionTree': DecisionTreeRegressor(max_depth=5),
'MLP': MLPRegressor(max_iter=50),
'ExtraTrees': ExtraTreesRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
}
# 针对报告中发现的 3 个“被禁用”的第三方强力库,进行刺探测试
try:
from xgboost import XGBRegressor
models['XGBoost'] = XGBRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
models['XGBoost'] = "IMPORT_ERROR"
try:
from lightgbm import LGBMRegressor
models['LightGBM'] = LGBMRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
models['LightGBM'] = "IMPORT_ERROR"
try:
from catboost import CatBoostRegressor
models['CatBoost'] = CatBoostRegressor(iterations=10, depth=5, verbose=0)
except ImportError:
models['CatBoost'] = "IMPORT_ERROR"
# ---------------------------------------------------------
# 3. 开始残酷的体检循环
# ---------------------------------------------------------
print("\n================ 开始跑分测试 ================")
results = []
for name, model in models.items():
if model == "IMPORT_ERROR":
results.append(f"⚠️ [缺库] {name:<16} : 环境未安装此库 (建议: pip install {name.lower()})")
continue
start_time = time.time()
try:
# 极速拟合与评分
model.fit(X_train, y_train)
score = model.score(X_test, y_test)
cost_time = time.time() - start_time
results.append(f"✅ [成功] {name:<16} : 耗时 {cost_time:.3f} 秒 (R2: {score:.2f})")
except Exception as e:
error_msg = str(e).split('\n')[0]
results.append(f"❌ [崩溃] {name:<16} : {error_msg}")
# ---------------------------------------------------------
# 4. 打印最终体检报告
# ---------------------------------------------------------
print("\n=============== 🏥 最终体检报告 ===============")
for res in results:
print(res)
print("===============================================")

346
src/gui/core/viz_thread.py Normal file
View File

@ -0,0 +1,346 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
可视化后台线程模块
包含 VisualizationWorkerThread 后台线程类和辅助函数。
"""
from pathlib import Path
from typing import Optional, List, Union
from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
"""推断光谱起始列training_spectra 通常以波长数值为列名,未必含 UTM_Y"""
import pandas as pd
for i, col in enumerate(df.columns):
name = str(col).strip().lstrip("\ufeff")
try:
v = float(name)
except ValueError:
continue
if 200.0 <= v <= 3000.0:
return i
if "UTM_Y" in df.columns:
return "UTM_Y"
return 0
class VisualizationWorkerThread(QThread):
"""可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。"""
finished_ok = pyqtSignal(object)
failed = pyqtSignal(str)
def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
super().__init__()
self.task = task
self.work_dir = str(work_dir)
self.extra = extra or {}
def run(self):
mpl_prev = None
try:
import matplotlib
mpl_prev = matplotlib.get_backend()
except Exception:
pass
try:
import matplotlib.pyplot as plt
plt.switch_backend("Agg")
except Exception:
mpl_prev = None
try:
wp = Path(self.work_dir)
if self.task == "mask_glint":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
cnt = len(preview_paths) if preview_paths else 0
self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})
elif self.task == "sampling_map":
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
if deglint_dir.exists():
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
if not hyperspectral_files:
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
if not hyperspectral_files:
self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif")
return
hyperspectral_path = str(hyperspectral_files[0])
csv_files = []
processed_dir = wp / "4_processed_data"
if processed_dir.exists():
csv_files = list(processed_dir.glob("*.csv"))
if not csv_files:
csv_files = (
list(wp.glob("**/*sampling*.csv"))
+ list(wp.glob("**/*point*.csv"))
+ list(wp.glob("**/*.csv"))
)
if not csv_files:
self.failed.emit("未找到采样点 CSV 文件。")
return
csv_path = str(csv_files[0])
from src.postprocessing.point_map import SamplingPointMap
map_generator = SamplingPointMap(
output_dir=str(wp / "14_visualization" / "sampling_maps"),
fast_mode=True,
)
map_path = map_generator.create_sampling_point_map(
hyperspectral_path=hyperspectral_path,
csv_path=csv_path,
point_color="red",
point_size=100,
point_alpha=0.9,
show_north_arrow=True,
show_scale_bar=True,
show_legend=True,
downsample=True,
dpi=180,
)
self.finished_ok.emit(
{
"task": "sampling_map",
"map_path": map_path,
"hyperspectral_path": hyperspectral_path,
"csv_path": csv_path,
}
)
elif self.task == "spectrum":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
csv_file = self.extra.get("csv_path")
wl = self.extra.get("wavelength_start_column", "UTM_Y")
n_groups = int(self.extra.get("n_groups", 5))
param_cols = self.extra.get("param_cols") or []
if param_cols:
output_paths: List[str] = []
err_lines: List[str] = []
for param_col in param_cols:
try:
out = viz.plot_spectrum_by_parameter(
csv_path=str(csv_file),
parameter_column=param_col,
wavelength_start_column=wl,
n_groups=n_groups,
)
output_paths.append(out)
except Exception as _ex:
err_lines.append(f"{param_col}: {_ex}")
if not output_paths:
self.failed.emit(
"所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
)
return
self.finished_ok.emit(
{
"task": "spectrum",
"output_paths": output_paths,
"errors": err_lines,
}
)
else:
param_col = self.extra.get("param_col")
out = viz.plot_spectrum_by_parameter(
csv_path=str(csv_file),
parameter_column=param_col,
wavelength_start_column=wl,
n_groups=n_groups,
)
self.finished_ok.emit(
{"task": "spectrum", "output_path": out, "param_col": param_col}
)
elif self.task == "statistics":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
csv_file = self.extra.get("csv_path")
param_cols = self.extra.get("param_cols") or []
output_paths = viz.plot_statistical_charts(
csv_path=str(csv_file),
parameter_columns=param_cols,
)
self.finished_ok.emit(
{"task": "statistics", "output_paths": output_paths}
)
elif self.task == "scatter":
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
training_csv_path = (self.extra.get("training_csv_path") or "").strip()
models_dir = (self.extra.get("models_dir") or "").strip()
if not training_csv_path or not Path(training_csv_path).is_file():
self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
return
if not models_dir or not Path(models_dir).is_dir():
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
return
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=training_csv_path,
models_dir=models_dir,
)
self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})
elif self.task == "generate_all_selected":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
parts = []
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
if self.extra.get("gen_scatter"):
if training_csv.is_file():
models_dir = wp / "7_Supervised_Model_Training"
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=str(training_csv),
models_dir=str(models_dir),
)
count = len(scatter_paths) if scatter_paths else 0
parts.append(f"散点图: {count}")
else:
parts.append("散点图: 跳过(无模型目录)")
else:
parts.append("散点图: 跳过(无训练数据)")
if self.extra.get("gen_spectrum"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
wl_col = _viz_infer_wavelength_start_column(df)
if isinstance(wl_col, str):
idx = int(df.columns.get_loc(wl_col)) + 1
else:
idx = int(wl_col)
param_cols = []
if idx > 0 and idx < len(df.columns):
param_cols = [
c for c in df.columns[:idx]
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
]
if param_cols:
spectrum_paths = []
for param_col in param_cols:
try:
path = viz.plot_spectrum_by_parameter(
csv_path=str(training_csv),
parameter_column=param_col,
wavelength_start_column=wl_col,
n_groups=5,
)
if path:
spectrum_paths.append(path)
except Exception as e:
print(f"生成光谱图失败 ({param_col}): {e}")
count = len(spectrum_paths)
parts.append(f"光谱图: {count}")
else:
parts.append("光谱图: 跳过(无可用参数列)")
else:
parts.append("光谱图: 跳过(无训练数据)")
if self.extra.get("gen_boxplots"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
param_cols = [
c for c in df.select_dtypes(include=[np.number]).columns
if not any(exc in c.lower() for exc in exclude_cols)
]
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
idx = int(df.columns.get_loc(wl)) + 1
else:
idx = int(wl)
if 0 < idx < len(df.columns):
meta_set = set(df.columns[:idx])
param_cols = [c for c in param_cols if c in meta_set]
if param_cols:
output_dict = viz.plot_statistical_charts(
csv_path=str(training_csv),
parameter_columns=param_cols,
)
count = len([v for v in output_dict.values() if v]) if output_dict else 0
parts.append(f"统计图: {count}")
else:
parts.append("统计图: 跳过(无可用水质参数列)")
else:
parts.append("统计图: 跳过(无训练数据)")
if self.extra.get("gen_mask_glint"):
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
if self.extra.get("gen_sampling_map"):
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
if deglint_dir.exists():
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
if not hyperspectral_files:
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
if hyperspectral_files:
hyperspectral_path = str(hyperspectral_files[0])
csv_files = []
processed_dir = wp / "4_processed_data"
if processed_dir.exists():
csv_files = list(processed_dir.glob("*.csv"))
if not csv_files:
csv_files = (
list(wp.glob("**/*sampling*.csv"))
+ list(wp.glob("**/*point*.csv"))
+ list(wp.glob("**/*.csv"))
)
if csv_files:
csv_path = str(csv_files[0])
from src.postprocessing.point_map import SamplingPointMap
map_generator = SamplingPointMap(
output_dir=str(wp / "14_visualization" / "sampling_maps"),
fast_mode=True,
)
map_path = map_generator.create_sampling_point_map(
hyperspectral_path=hyperspectral_path,
csv_path=csv_path,
point_color="red",
point_size=100,
point_alpha=0.9,
show_north_arrow=True,
show_scale_bar=True,
show_legend=True,
downsample=True,
dpi=180,
)
parts.append(f"采样点图: {Path(map_path).name}")
else:
parts.append("采样点图: 跳过(无CSV)")
else:
parts.append("采样点图: 跳过(无影像)")
self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})
else:
self.failed.emit(f"未知可视化任务: {self.task}")
except Exception as e:
import traceback
self.failed.emit(f"{e}\n{traceback.format_exc()}")
finally:
if mpl_prev:
try:
import matplotlib.pyplot as plt
plt.switch_backend(mpl_prev)
except Exception:
pass

93
src/gui/crash_dump.txt Normal file
View File

@ -0,0 +1,93 @@
============================================================
[2026-05-12 11:14:51]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 130, in <module>
from src.gui.panels.step9_panel import Step9Panel
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\panels\step9_panel.py", line 24, in <module>
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py", line 45, in <module>
from src.preprocessing.process_water_quality_data import process_water_quality_data
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\preprocessing\process_water_quality_data.py", line 9, in <module>
from scipy import stats
File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\__init__.py", line 143, in __getattr__
return _importlib.import_module(f'scipy.{name}')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\importlib\__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\__init__.py", line 632, in <module>
from ._multicomp import *
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_multicomp.py", line 11, in <module>
from scipy.stats._qmc import check_random_state
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_qmc.py", line 26, in <module>
from scipy.sparse.csgraph import minimum_spanning_tree
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\sparse\csgraph\__init__.py", line 188, in <module>
from ._shortest_path import (
File "scipy/sparse/csgraph/_shortest_path.pyx", line 21, in init scipy.sparse.csgraph._shortest_path
File "<frozen importlib._bootstrap>", line 1349, in _find_and_load
KeyboardInterrupt
============================================================
[2026-05-12 11:57:28]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3093, in main
_dialog.exec_()
KeyboardInterrupt
============================================================
[2026-05-28 15:45:11]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
window = WaterQualityGUI()
^^^^^^^^^^^^^^^^^
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
self.init_ui()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
self.create_content_area()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
self.step2_panel = Step2Panel()
^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
============================================================
[2026-05-28 15:45:19]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
window = WaterQualityGUI()
^^^^^^^^^^^^^^^^^
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
self.init_ui()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
self.create_content_area()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
self.step2_panel = Step2Panel()
^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
============================================================
[2026-05-28 16:00:53]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2149, in on_step_changed
self.auto_populate_step_inputs(item_data)
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2362, in auto_populate_step_inputs
if step_id not in self.step_dependencies:
^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'WaterQualityGUI' object has no attribute 'step_dependencies'. Did you mean: '_init_step_dependencies'?
============================================================
[2026-06-03 13:56:59]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3354, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3331, in main
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt

View File

@ -325,7 +325,7 @@ class Step3Panel(QWidget):
}
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask'] = water_mask_path
config['water_mask_path'] = water_mask_path
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
@ -366,8 +366,8 @@ class Step3Panel(QWidget):
"""设置配置"""
if 'img_path' in config:
self.img_file.set_path(config['img_path'])
if 'water_mask' in config:
self.water_mask_file.set_path(config['water_mask'])
if 'water_mask_path' in config:
self.water_mask_file.set_path(config['water_mask_path'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
if 'reference_csv' in config:

View File

@ -187,7 +187,7 @@ class Step5_5Panel(QWidget):
def get_config(self):
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
return {
'training_spectra_path': self.training_data_widget.get_path(),
'training_csv_path': self.training_data_widget.get_path(),
'formula_csv_file': self.builtin_formula_path,
'formula_names': selected,
'output_file': self.output_file_widget.get_path(),
@ -195,7 +195,7 @@ class Step5_5Panel(QWidget):
}
def set_config(self, config):
if 'training_spectra_path' in config: self.training_data_widget.set_path(config['training_spectra_path'])
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
if 'formula_names' in config:
sel = set(config['formula_names'])
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
@ -217,7 +217,7 @@ class Step5_5Panel(QWidget):
def run_step(self):
config = self.get_config()
if not config['training_spectra_path']:
if not config['training_csv_path']:
QMessageBox.warning(self, "提示", "请先选择输入数据")
return
parent = self.parent()

View File

@ -124,7 +124,7 @@ class Step5Panel(QWidget):
glint_mask_path = self.glint_mask_file.get_path()
if glint_mask_path:
config['glint_mask_path'] = glint_mask_path
# 注意step5_extract_training_spectra 不接受 output_path / training_spectra_path
# 注意step5_extract_training_spectra 不接受 output_path / training_csv_path
# 参数,输出路径由 pipeline 内部根据 training_spectra_dir 自动生成。
return config

View File

@ -363,7 +363,7 @@ class Step6Panel(QWidget):
# 回退:从 Step5 的 config 字典中查找可能的键名
step5_cfg = main_window.step5_panel.get_config()
step5_csv = (
step5_cfg.get('training_spectra_path')
step5_cfg.get('training_csv_path')
or step5_cfg.get('output_file')
or step5_cfg.get('csv_path')
or step5_cfg.get('output_csv')

BIN
src/gui/scaler_params.pkl Normal file

Binary file not shown.

View File

@ -1432,7 +1432,7 @@ class WaterQualityGUI(QMainWindow):
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜
},
'step5_5': {
'training_spectra_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
},
'step6': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据