refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名
This commit is contained in:
@ -20,23 +20,28 @@ class PipelineContext:
|
||||
"""流水线运行上下文(在 14 个 step 之间传递的内存字典)
|
||||
|
||||
字段命名约定:
|
||||
- 主路径字段统一 `_path` 后缀(如 water_mask_path)
|
||||
- 目录类字段无 `_path` 后缀(如 models_dir)
|
||||
- 路径类字段名 = panel key 名 = step 形参名(全链路无翻译)
|
||||
- 训练/产物 CSV 用 `_path` 后缀(如 training_csv_path / water_mask_path)
|
||||
- 入参影像/CSV 沿用 panel 原名(img_path / csv_path),无 `_path` 后缀
|
||||
- 目录类字段无 `_path` 后缀(如 models_dir / prediction_dir)
|
||||
- 元信息字段无后缀(如 user_config / status / log)
|
||||
"""
|
||||
|
||||
# ── 9 步主路径(按 step 输出顺序排列) ──
|
||||
raw_img_path: Optional[str] = None # Step 1 入参:原始影像
|
||||
# ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参) ──
|
||||
img_path: Optional[str] = None # Step 1/2/3 入参:原始影像
|
||||
water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入
|
||||
glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入
|
||||
deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入
|
||||
raw_csv_path: Optional[str] = None # Step 4 入:原始 CSV
|
||||
csv_path: Optional[str] = None # Step 4/5/6_5/6_75 入参:原始/训练 CSV
|
||||
processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入
|
||||
training_spectra_path: Optional[str] = None # Step 5 出 → Step 6 入
|
||||
training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75 入
|
||||
boundary_path: Optional[str] = None # Step 5 入参:边界 SHP(panel step5 名)
|
||||
indices_path: Optional[str] = None # Step 5.5 出
|
||||
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/9 入
|
||||
prediction_csv_path: Optional[str] = None # Step 8 出
|
||||
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入
|
||||
prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入
|
||||
distribution_map_path: Optional[str] = None # Step 9 出
|
||||
boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHP(panel step9 名)
|
||||
formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV
|
||||
|
||||
# ── 目录类(命名不带 _path 以示区别) ──
|
||||
models_dir: Optional[str] = None
|
||||
|
||||
@ -4,10 +4,8 @@ PipelineRunner:基于 StepSpec 声明式调度 14 个 step。
|
||||
|
||||
设计要点:
|
||||
- StepSpec 声明 requires(ctx 字段名列表)+ produces(ctx 字段名列表)
|
||||
- 默认约定:ctx 字段名去掉 `_path` 后缀 = step 方法形参名
|
||||
例:ctx.water_mask_path → 形参 water_mask
|
||||
例:ctx.raw_img_path → 形参 raw_img
|
||||
- 可被 spec.parameter_map 覆盖
|
||||
- 命名约定:ctx 字段名 == panel key 名 == step 形参名(全链路无翻译)
|
||||
- 保留 spec.parameter_map 字段骨架供极少数特例覆盖(默认空 dict)
|
||||
- 调度顺序:按 PIPELINE_STEPS 列表顺序,requires 缺则 skip
|
||||
- 软取消:在每个 step 前检查 ctx.is_cancelled()
|
||||
- duck-typed pipeline:runner 只调 getattr(pipeline, method_name),不强依赖类层级
|
||||
@ -48,101 +46,76 @@ class StepSpec:
|
||||
PIPELINE_STEPS: List[StepSpec] = [
|
||||
StepSpec(
|
||||
step_id="step1", method_name="step1_generate_water_mask",
|
||||
requires=["raw_img_path"], produces=["water_mask_path"],
|
||||
# ctx.raw_img_path → 形参 img_path(老 step1 形参名是 img_path,不是 raw_img)
|
||||
parameter_map={"raw_img_path": "img_path"},
|
||||
requires=["img_path"], produces=["water_mask_path"],
|
||||
description="水域掩膜生成(NDWI 或 SHP)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step2", method_name="step2_find_glint_area",
|
||||
requires=["raw_img_path", "water_mask_path"], produces=["glint_mask_path"],
|
||||
# raw_img_path→img_path;water_mask_path 不变
|
||||
parameter_map={"raw_img_path": "img_path"},
|
||||
requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
|
||||
description="耀斑区域检测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step3", method_name="step3_remove_glint",
|
||||
requires=["deglint_img_path"], produces=["deglint_img_path"],
|
||||
# deglint_img_path→img_path(老 step3 形参名是 img_path)
|
||||
# 注意:glint_mask_path 不在 requires 中——step3 形参表无该参数,内部走 self.glint_mask_path 回退
|
||||
parameter_map={"deglint_img_path": "img_path"},
|
||||
requires=["img_path", "water_mask_path", "glint_mask_path"],
|
||||
produces=["deglint_img_path"],
|
||||
description="耀斑去除",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step4", method_name="step4_process_csv",
|
||||
requires=["raw_csv_path"], produces=["processed_csv_path"],
|
||||
# raw_csv_path→csv_path(老 step4 形参名是 csv_path)
|
||||
parameter_map={"raw_csv_path": "csv_path"},
|
||||
requires=["csv_path"], produces=["processed_csv_path"],
|
||||
description="CSV 异常值清洗",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step5", method_name="step5_extract_training_spectra",
|
||||
requires=["deglint_img_path", "processed_csv_path"], produces=["training_spectra_path"],
|
||||
# processed_csv_path→csv_path(老 step5 形参名是 csv_path);deglint_img_path 不变
|
||||
parameter_map={"processed_csv_path": "csv_path"},
|
||||
requires=["deglint_img_path", "csv_path", "boundary_path", "glint_mask_path"],
|
||||
produces=["training_csv_path"],
|
||||
description="实测样本点光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step5_5", method_name="step5_5_calculate_water_quality_indices",
|
||||
requires=["training_spectra_path"], produces=["indices_path"],
|
||||
# 老 step5.5 形参是 training_spectra_path;ctx 字段同名,无需映射
|
||||
parameter_map={},
|
||||
requires=["training_csv_path"], produces=["indices_path"],
|
||||
description="水质光谱指数计算(optional)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6", method_name="step6_train_models",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→training_csv_path(老 step6 形参名是 training_csv_path)
|
||||
parameter_map={"training_spectra_path": "training_csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="ML 建模(GridSearchCV / AutoML)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6_5", method_name="step6_5_non_empirical_modeling",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→csv_path(老 step6.5 形参名是 csv_path)
|
||||
parameter_map={"training_spectra_path": "csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="非经验统计回归",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6_75", method_name="step6_75_custom_regression",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→csv_path(老 step6.75 形参名是 csv_path)
|
||||
parameter_map={"training_spectra_path": "csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="自定义回归分析",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step7", method_name="step7_generate_sampling_points",
|
||||
requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"],
|
||||
# 老 step7 形参是 deglint_img_path / water_mask_path;ctx 字段同名
|
||||
parameter_map={},
|
||||
description="整景密集采样点生成 + 光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8", method_name="step8_predict_water_quality",
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"],
|
||||
parameter_map={},
|
||||
description="ML 模型预测(采样点)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8_5", method_name="step8_5_predict_with_non_empirical_models",
|
||||
requires=["sampling_csv_path"], produces=["prediction_dir"],
|
||||
parameter_map={},
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_dir"],
|
||||
description="非经验模型预测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8_75", method_name="step8_75_predict_with_custom_regression",
|
||||
requires=["sampling_csv_path"], produces=["prediction_dir"],
|
||||
parameter_map={},
|
||||
requires=["sampling_csv_path", "models_dir", "formula_csv_path"],
|
||||
produces=["prediction_dir"],
|
||||
description="自定义回归预测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step9", method_name="step9_generate_distribution_map",
|
||||
requires=["prediction_csv_path"],
|
||||
requires=["prediction_csv_path", "boundary_shp_path"],
|
||||
produces=["distribution_map_path"],
|
||||
# 老 step9 形参是 prediction_csv_path / boundary_shp_path;ctx 字段同名
|
||||
# 注意:sampling_csv_path / water_mask_path 不在 requires 中——step9 形参表无该参数,
|
||||
# 内部走 self.sampling_csv_path / self.water_mask_path 回退
|
||||
parameter_map={},
|
||||
description="克里金插值成图",
|
||||
),
|
||||
]
|
||||
@ -157,7 +130,7 @@ class PipelineRunner:
|
||||
|
||||
用法:
|
||||
runner = PipelineRunner(pipeline_instance)
|
||||
ctx = PipelineContext(raw_img_path=..., ...)
|
||||
ctx = PipelineContext(img_path=..., ...)
|
||||
result_ctx = runner.run(ctx)
|
||||
"""
|
||||
|
||||
|
||||
544
src/core/prediction/automl_trainer.py
Normal file
544
src/core/prediction/automl_trainer.py
Normal file
@ -0,0 +1,544 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Optuna + 智能子采样 AutoML 训练器(路线 B 防爆引擎)。
|
||||
|
||||
为什么需要这个:
|
||||
- 老路径:11 预处理 × 4 模型 × 3 划分 = 132 组 GridSearchCV
|
||||
对中小数据集 10 分钟+,对大数据集 5w+ 行 直接 OOM
|
||||
- AutoML 路径:1 预处理 × N 模型(Optuna 调超参),用智能子采样避开 OOM
|
||||
再用最优超参在**全量数据**上 refit,最终保存单一模型
|
||||
|
||||
设计要点:
|
||||
- 入口 train_with_automl(csv, feature_start_column, model_names, ...)
|
||||
- AutoMLResult dataclass 返回(每个目标列一份)
|
||||
- smart_subsample:N > max_samples 时随机下采样
|
||||
- 失败兜底:optuna 未装 / 全 trial 失败 → fallback 到 WaterQualityModelingBatch
|
||||
- 文件命名规范:{target}_{preprocess}_{model}_AUTOML.joblib
|
||||
- save_data["metadata"]["automl"] = True 标记
|
||||
|
||||
调用:
|
||||
from src.core.prediction.automl_trainer import train_with_automl
|
||||
results = train_with_automl(
|
||||
training_csv_path=".../training_spectra.csv",
|
||||
feature_start_column="374.285004",
|
||||
model_names=["RF", "SVR", "Ridge"],
|
||||
n_trials=20,
|
||||
timeout_sec=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 常量
|
||||
# ============================================================
|
||||
|
||||
# AutoML 寻优阶段允许的最大样本数(避免 OOM)
|
||||
# 5000 样本对 RF/SVR/Ridge 的 Optuna 寻优足够给出稳定 CV
|
||||
DEFAULT_MAX_SAMPLES = 5000
|
||||
|
||||
# 单次 Optuna trial 的默认超时(秒)
|
||||
DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
# 默认 trial 数
|
||||
DEFAULT_N_TRIALS = 20
|
||||
|
||||
# AutoML 输出目录名后缀
|
||||
AUTOML_DIR_SUFFIX = "_AutoML"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 数据类
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class AutoMLResult:
|
||||
"""单个目标列的 AutoML 训练结果"""
|
||||
success: bool = False
|
||||
model_path: Optional[str] = None
|
||||
cv_score: float = -float("inf")
|
||||
best_params: Optional[Dict[str, Any]] = None
|
||||
target_column: str = ""
|
||||
preprocessing: str = ""
|
||||
model_name: str = ""
|
||||
n_trials_done: int = 0
|
||||
n_samples_used: int = 0
|
||||
fallback_used: bool = False
|
||||
elapsed_sec: float = 0.0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 智能子采样
|
||||
# ============================================================
|
||||
|
||||
def smart_subsample(
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
) -> Tuple[np.ndarray, np.ndarray, bool]:
|
||||
"""当 N > max_samples 时随机下采样;否则原样返回。
|
||||
|
||||
Returns:
|
||||
(X_sub, y_sub, was_subsampled)
|
||||
"""
|
||||
n = X.shape[0]
|
||||
if n <= max_samples:
|
||||
return X, y, False
|
||||
rng = np.random.default_rng(random_state)
|
||||
idx = rng.choice(n, size=max_samples, replace=False)
|
||||
return X[idx], y[idx], True
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 模型工厂
|
||||
# ============================================================
|
||||
|
||||
def _build_model(model_name: str, random_state: int = 42):
|
||||
"""根据英文模型键名构造 sklearn-compatible 模型实例(factory)。"""
|
||||
from sklearn.ensemble import (
|
||||
AdaBoostRegressor, ExtraTreesRegressor, GradientBoostingRegressor,
|
||||
RandomForestRegressor,
|
||||
)
|
||||
from sklearn.linear_model import (
|
||||
ElasticNet, Lasso, LinearRegression, Ridge,
|
||||
)
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
|
||||
factory = {
|
||||
"RF": lambda **kw: RandomForestRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"ET": lambda **kw: ExtraTreesRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"GradientBoosting": lambda **kw: GradientBoostingRegressor(random_state=random_state, **kw),
|
||||
"AdaBoost": lambda **kw: AdaBoostRegressor(random_state=random_state, **kw),
|
||||
"Ridge": lambda **kw: Ridge(**kw),
|
||||
"Lasso": lambda **kw: Lasso(max_iter=5000, **kw),
|
||||
"ElasticNet": lambda **kw: ElasticNet(max_iter=5000, **kw),
|
||||
"LinearRegression": lambda **kw: LinearRegression(**kw),
|
||||
"SVR": lambda **kw: SVR(**kw),
|
||||
"KNN": lambda **kw: KNeighborsRegressor(n_jobs=1, **kw),
|
||||
"MLP": lambda **kw: MLPRegressor(max_iter=500, random_state=random_state, **kw),
|
||||
"DecisionTree": lambda **kw: DecisionTreeRegressor(random_state=random_state, **kw),
|
||||
"PLS": None, # sklearn.cross_decomposition.PLSRegression 暂未集成
|
||||
}
|
||||
builder = factory.get(model_name)
|
||||
if builder is None:
|
||||
return None
|
||||
return builder
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Optuna 超参 search space
|
||||
# ============================================================
|
||||
|
||||
def _get_search_space(model_name: str, trial) -> Dict[str, Any]:
|
||||
"""按模型名返回 Optuna 超参 search space。"""
|
||||
sp: Dict[str, Any] = {}
|
||||
if model_name == "RF":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
sp["min_samples_leaf"] = trial.suggest_int("min_samples_leaf", 1, 5)
|
||||
elif model_name == "ET":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
elif model_name == "GradientBoosting":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 8)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
|
||||
elif model_name == "SVR":
|
||||
sp["C"] = trial.suggest_float("C", 0.1, 100.0, log=True)
|
||||
sp["epsilon"] = trial.suggest_float("epsilon", 0.001, 1.0, log=True)
|
||||
sp["kernel"] = trial.suggest_categorical("kernel", ["rbf", "linear"])
|
||||
elif model_name == "KNN":
|
||||
sp["n_neighbors"] = trial.suggest_int("n_neighbors", 3, 20)
|
||||
sp["weights"] = trial.suggest_categorical("weights", ["uniform", "distance"])
|
||||
elif model_name in ("Ridge", "Lasso", "ElasticNet"):
|
||||
sp["alpha"] = trial.suggest_float("alpha", 0.01, 100.0, log=True)
|
||||
if model_name == "ElasticNet":
|
||||
sp["l1_ratio"] = trial.suggest_float("l1_ratio", 0.0, 1.0)
|
||||
elif model_name == "MLP":
|
||||
sp["hidden_layer_sizes"] = trial.suggest_categorical(
|
||||
"hidden_layer_sizes", [(50,), (100,), (50, 50), (100, 50)]
|
||||
)
|
||||
sp["alpha"] = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
|
||||
sp["learning_rate_init"] = trial.suggest_float("learning_rate_init", 1e-4, 1e-2, log=True)
|
||||
elif model_name == "DecisionTree":
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
elif model_name == "AdaBoost":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 30, 200, step=30)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 1.0, log=True)
|
||||
else:
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 200, step=50)
|
||||
return sp
|
||||
|
||||
|
||||
def _make_objective(model_name: str, X: np.ndarray, y: np.ndarray,
|
||||
cv_folds: int, random_state: int):
|
||||
"""构造 Optuna objective(5 折 CV R²)。"""
|
||||
from sklearn.model_selection import KFold, cross_val_score
|
||||
|
||||
def objective(trial):
|
||||
params = _get_search_space(model_name, trial)
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return -1.0
|
||||
model = builder(**params)
|
||||
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
scores = cross_val_score(model, X, y, cv=kf, scoring="r2", n_jobs=1)
|
||||
return float(np.mean(scores))
|
||||
except Exception:
|
||||
return -1.0
|
||||
|
||||
return objective
|
||||
|
||||
|
||||
def _refit_full(model_name: str, best_params: Dict[str, Any],
|
||||
X: np.ndarray, y: np.ndarray, random_state: int):
|
||||
"""用 best params 在**全量数据**上 refit。"""
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return None
|
||||
model = builder(**best_params)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 失败兜底(回退到老 GridSearchCV 路径)
|
||||
# ============================================================
|
||||
|
||||
def _fallback_train(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing: str,
|
||||
model_name: str,
|
||||
split_method: str,
|
||||
cv_folds: int,
|
||||
output_dir: Path,
|
||||
target_column: str,
|
||||
) -> AutoMLResult:
|
||||
"""AutoML 失败时调老 WaterQualityModelingBatch。
|
||||
|
||||
返回的 AutoMLResult.fallback_used=True。
|
||||
"""
|
||||
try:
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
except ImportError as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 导入失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
try:
|
||||
out_dir = output_dir / preprocessing
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
modeler = WaterQualityModelingBatch(str(out_dir))
|
||||
modeler.train_models_batch(
|
||||
csv_path=training_csv_path,
|
||||
feature_start_column=feature_start_column,
|
||||
preprocessing_methods=[preprocessing],
|
||||
model_names=[model_name],
|
||||
split_methods=[split_method],
|
||||
cv_folds=cv_folds,
|
||||
)
|
||||
# 找产出
|
||||
candidates = list(out_dir.rglob(f"{target_column}_{preprocessing}_{model_name}.joblib"))
|
||||
model_path = str(candidates[0]) if candidates else None
|
||||
return AutoMLResult(
|
||||
success=model_path is not None,
|
||||
model_path=model_path,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
fallback_used=True,
|
||||
metadata={"source": "WaterQualityModelingBatch"},
|
||||
)
|
||||
except Exception as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 主入口
|
||||
# ============================================================
|
||||
|
||||
def train_with_automl(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
model_names: Optional[List[str]] = None,
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
n_trials: int = DEFAULT_N_TRIALS,
|
||||
timeout_sec: float = DEFAULT_TIMEOUT,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
callback: Optional[Callable[[str, str, str], None]] = None,
|
||||
) -> List[AutoMLResult]:
|
||||
"""用 Optuna + 子采样跑 AutoML。失败时自动回退到 GridSearchCV。
|
||||
|
||||
Args:
|
||||
training_csv_path: 训练用 CSV(Step 5 产物 training_spectra.csv)
|
||||
feature_start_column: 特征起始列名或索引(之前所有列视为目标 y)
|
||||
preprocessing_methods: 候选预处理列表(**仅用第 1 个**,避免笛卡尔爆炸)
|
||||
model_names: 候选模型列表(每个都会跑一遍 Optuna)
|
||||
split_methods: 候选数据划分列表(AutoML 仅用第 1 个)
|
||||
cv_folds: 交叉验证折数
|
||||
output_dir: 输出目录(默认 <models_dir>_AutoML)
|
||||
n_trials: 单模型 Optuna trial 数
|
||||
timeout_sec: 单模型超时(秒),到时强制停止
|
||||
max_samples: 寻优阶段允许的最大样本数
|
||||
callback: 状态回调 callback(step_name, status, message)
|
||||
|
||||
Returns:
|
||||
List[AutoMLResult],每个目标列一份结果
|
||||
"""
|
||||
def notify(status: str, msg: str = "") -> None:
|
||||
if callback:
|
||||
callback("步骤6_AutoML", status, msg)
|
||||
|
||||
# ---- 1) 参数默认值 ----
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["MMS"]
|
||||
if model_names is None:
|
||||
model_names = ["RF", "SVR", "Ridge"]
|
||||
if split_methods is None:
|
||||
split_methods = ["spxy"]
|
||||
|
||||
# 决策:仅用第一个预处理 + 第一个划分,避免笛卡尔爆炸
|
||||
preproc = preprocessing_methods[0]
|
||||
split_method = split_methods[0]
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = "./7_Supervised_Model_Training_AutoML"
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
preproc_dir = out_dir / preproc
|
||||
preproc_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- 2) 加载数据 ----
|
||||
notify("start", f"AutoML 训练开始 (n_trials={n_trials}, timeout={timeout_sec}s, max_samples={max_samples})")
|
||||
if not Path(training_csv_path).exists():
|
||||
return [AutoMLResult(success=False, error=f"训练 CSV 不存在: {training_csv_path}")]
|
||||
|
||||
df = pd.read_csv(training_csv_path)
|
||||
|
||||
# 提取目标列(feature_start_column 之前所有数值列)
|
||||
if isinstance(feature_start_column, int):
|
||||
y_cols = [c for c in df.columns[:feature_start_column]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
else:
|
||||
try:
|
||||
idx = list(df.columns).index(feature_start_column)
|
||||
y_cols = [c for c in df.columns[:idx]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
except ValueError:
|
||||
y_cols = []
|
||||
|
||||
if not y_cols:
|
||||
notify("error", "AutoML: 未识别出目标列(feature_start_column 之前的所有数值列)")
|
||||
return [AutoMLResult(success=False, error="未识别出目标列")]
|
||||
|
||||
feat_cols = [c for c in df.columns if c not in y_cols]
|
||||
X_all = df[feat_cols].values.astype(np.float64)
|
||||
|
||||
# ---- 3) 预处理(仅第一项) ----
|
||||
if preproc != "None":
|
||||
try:
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
processed = Preprocessing(preproc, df[feat_cols])
|
||||
if isinstance(processed, pd.DataFrame):
|
||||
X_all = processed.values.astype(np.float64)
|
||||
else:
|
||||
X_all = np.asarray(processed, dtype=np.float64)
|
||||
except Exception as e:
|
||||
notify("warning", f"预处理 {preproc} 失败: {e!r},改用 None")
|
||||
preproc = "None"
|
||||
|
||||
# ---- 4) 检查 Optuna 是否可用 ----
|
||||
try:
|
||||
import optuna
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
optuna_available = True
|
||||
except ImportError:
|
||||
optuna_available = False
|
||||
notify("warning", "optuna 未安装,全目标列回退到 GridSearchCV(pip install \"optuna>=3.6\")")
|
||||
|
||||
# ---- 5) 逐 target 跑 ----
|
||||
results: List[AutoMLResult] = []
|
||||
total = len(y_cols)
|
||||
per_model_timeout = max(10.0, timeout_sec / max(1, len(model_names)))
|
||||
|
||||
for ti, tgt in enumerate(y_cols, 1):
|
||||
t0 = time.time()
|
||||
yv = df[tgt].values.astype(np.float64)
|
||||
mask = ~np.isnan(yv)
|
||||
X_t = X_all[mask]
|
||||
y_t = yv[mask]
|
||||
|
||||
if X_t.shape[0] < cv_folds * 2:
|
||||
notify("warning", f"目标 {tgt}: 有效样本 {X_t.shape[0]} 不足,跳过")
|
||||
results.append(AutoMLResult(
|
||||
success=False, target_column=tgt, error=f"样本不足({X_t.shape[0]})",
|
||||
preprocessing=preproc,
|
||||
))
|
||||
continue
|
||||
|
||||
X_sub, y_sub, was_sub = smart_subsample(X_t, y_t, max_samples=max_samples, random_state=random_state)
|
||||
if was_sub:
|
||||
notify("info", f"目标 {tgt}: {X_t.shape[0]} 样本 → 子采样 {X_sub.shape[0]}(寻优用)")
|
||||
|
||||
best_overall = AutoMLResult(success=False, target_column=tgt, preprocessing=preproc)
|
||||
|
||||
if not optuna_available:
|
||||
# 全目标列一次性 fallback
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
else:
|
||||
for model_name in model_names:
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
notify("warning", f"模型 {model_name} 暂不支持 AutoML 寻优")
|
||||
continue
|
||||
|
||||
study = optuna.create_study(
|
||||
direction="maximize",
|
||||
sampler=optuna.samplers.TPESampler(seed=random_state),
|
||||
)
|
||||
study.optimize(
|
||||
_make_objective(model_name, X_sub, y_sub, cv_folds, random_state),
|
||||
n_trials=n_trials,
|
||||
timeout=per_model_timeout,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
if study.best_value is None or study.best_value <= -1.0:
|
||||
notify("warning", f"{tgt}/{model_name}: 全部 trial 失败(CV 全部 <= -1)")
|
||||
continue
|
||||
|
||||
# refit on FULL
|
||||
final_model = _refit_full(model_name, study.best_params, X_t, y_t, random_state)
|
||||
if final_model is None:
|
||||
continue
|
||||
|
||||
# 保存
|
||||
import joblib
|
||||
fname = f"{tgt}_{preproc}_{model_name}_AUTOML.joblib"
|
||||
fpath = preproc_dir / fname
|
||||
joblib.dump({
|
||||
"model": final_model,
|
||||
"target_column_name": tgt,
|
||||
"preprocess_method": preproc,
|
||||
"model_name": model_name,
|
||||
"metadata": {
|
||||
"automl": True,
|
||||
"best_params": study.best_params,
|
||||
"cv_score": float(study.best_value),
|
||||
"n_trials_done": len(study.trials),
|
||||
"n_samples_used_full": int(X_t.shape[0]),
|
||||
"n_samples_used_for_search": int(X_sub.shape[0]),
|
||||
"was_subsampled": was_sub,
|
||||
"split_method": split_method,
|
||||
},
|
||||
}, fpath)
|
||||
|
||||
cand = AutoMLResult(
|
||||
success=True,
|
||||
model_path=str(fpath),
|
||||
cv_score=float(study.best_value),
|
||||
best_params=study.best_params,
|
||||
target_column=tgt,
|
||||
preprocessing=preproc,
|
||||
model_name=model_name,
|
||||
n_trials_done=len(study.trials),
|
||||
n_samples_used=int(X_sub.shape[0]),
|
||||
metadata={"refit_on_full": True, "n_samples_full": int(X_t.shape[0])},
|
||||
)
|
||||
if cand.cv_score > best_overall.cv_score:
|
||||
best_overall = cand
|
||||
except Exception as e:
|
||||
notify("warning", f"目标 {tgt} / 模型 {model_name} 失败: {e!r}")
|
||||
continue
|
||||
|
||||
if not best_overall.success:
|
||||
notify("warning", f"目标 {tgt} 全部 Optuna trial 失败,回退 GridSearchCV")
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
|
||||
best_overall.elapsed_sec = time.time() - t0
|
||||
results.append(best_overall)
|
||||
notify("info", f"AutoML 目标 {tgt} 完成 ({ti}/{total}) cv={best_overall.cv_score:.4f}")
|
||||
|
||||
# ---- 6) 汇总 json ----
|
||||
summary_path = out_dir / "automl_summary.json"
|
||||
try:
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(r) for r in results], f, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
notify("warning", f"写 automl_summary.json 失败: {e!r}")
|
||||
|
||||
success_n = sum(1 for r in results if r.success)
|
||||
fallback_n = sum(1 for r in results if r.fallback_used)
|
||||
notify("completed", f"AutoML 训练完成 {success_n}/{len(results)} 成功({fallback_n} 走 fallback),汇总 {summary_path}")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI 自测
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(description="AutoML 训练器 CLI 自测")
|
||||
p.add_argument("--csv", required=True, help="训练用 CSV(feature_start_column 之前的列为目标 y)")
|
||||
p.add_argument("--feature-start", default="0", help="特征起始列名或索引(默认 0)")
|
||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
|
||||
args = p.parse_args()
|
||||
|
||||
# 智能推断 feature_start_column 类型
|
||||
fsc: Any = args.feature_start
|
||||
try:
|
||||
fsc = int(fsc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
res = train_with_automl(
|
||||
training_csv_path=args.csv,
|
||||
feature_start_column=fsc,
|
||||
n_trials=args.n_trials,
|
||||
timeout_sec=args.timeout,
|
||||
max_samples=args.max_samples,
|
||||
output_dir=args.out,
|
||||
)
|
||||
print(f"\n训练完成 {len(res)} 个目标")
|
||||
for r in res:
|
||||
marker = "✓" if r.success else "✗"
|
||||
fb = " [fallback]" if r.fallback_used else ""
|
||||
print(f" {marker} {r.target_column}: cv={r.cv_score:.4f} path={r.model_path}{fb}")
|
||||
@ -126,7 +126,7 @@ class DataPreparationStep:
|
||||
|
||||
@staticmethod
|
||||
def calculate_water_quality_indices(
|
||||
training_spectra_path: Optional[str] = None,
|
||||
training_csv_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
@ -153,8 +153,8 @@ class DataPreparationStep:
|
||||
notify("skipped", "跳过水质指数计算")
|
||||
return None
|
||||
|
||||
if training_spectra_path is None:
|
||||
raise ValueError("必须提供 training_spectra_path 参数")
|
||||
if training_csv_path is None:
|
||||
raise ValueError("必须提供 training_csv_path 参数")
|
||||
if formula_csv_file is None:
|
||||
raise ValueError("必须提供 formula_csv_file 参数")
|
||||
|
||||
@ -170,7 +170,7 @@ class DataPreparationStep:
|
||||
|
||||
from src.utils.band_math import BandMathCalculator
|
||||
|
||||
calculator = BandMathCalculator(training_spectra_path)
|
||||
calculator = BandMathCalculator(training_csv_path)
|
||||
result_df = calculator.process_formulas_from_csv(
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
|
||||
@ -173,7 +173,7 @@ class WaterQualityInversionPipeline:
|
||||
self.interpolated_img_path = None # 存储插值后的影像路径
|
||||
self.deglint_img_path = None
|
||||
self.processed_csv_path = None
|
||||
self.training_spectra_path = None
|
||||
self.training_csv_path = None
|
||||
self.indices_path = None
|
||||
self.custom_regression_path = None
|
||||
|
||||
@ -511,7 +511,7 @@ class WaterQualityInversionPipeline:
|
||||
left_shoulder_wave: Optional[float] = None,
|
||||
valley_wave: Optional[float] = None,
|
||||
right_shoulder_wave: Optional[float] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
water_mask_path: Optional[Union[str, np.ndarray]] = None,
|
||||
interpolate_zeros: bool = False,
|
||||
interpolation_method: str = 'nearest',
|
||||
enabled: bool = True,
|
||||
@ -546,7 +546,7 @@ class WaterQualityInversionPipeline:
|
||||
left_shoulder_wave=left_shoulder_wave,
|
||||
valley_wave=valley_wave,
|
||||
right_shoulder_wave=right_shoulder_wave,
|
||||
water_mask=water_mask,
|
||||
water_mask=water_mask_path,
|
||||
interpolate_zeros=interpolate_zeros,
|
||||
interpolation_method=interpolation_method,
|
||||
enabled=enabled,
|
||||
@ -655,13 +655,13 @@ class WaterQualityInversionPipeline:
|
||||
water_mask_path=self.water_mask_path,
|
||||
output_dir=str(self.training_spectra_dir),
|
||||
)
|
||||
self.training_spectra_path = result
|
||||
self.training_csv_path = result
|
||||
self._record_step_time("步骤5: 提取训练样本点光谱", 0, 0)
|
||||
self._notify("completed", f"训练光谱数据已保存: {result}")
|
||||
return result
|
||||
|
||||
def step5_5_calculate_water_quality_indices(self,
|
||||
training_spectra_path: Optional[str] = None,
|
||||
training_csv_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
@ -669,29 +669,29 @@ class WaterQualityInversionPipeline:
|
||||
skip_dependency_check: bool = False) -> str:
|
||||
"""
|
||||
步骤5.5: 根据训练光谱计算水质光谱指数
|
||||
|
||||
|
||||
使用band_math.py中的方法实现,支持从公式CSV文件中批量计算指定公式
|
||||
|
||||
|
||||
Args:
|
||||
training_spectra_path: 训练光谱数据CSV路径(如果为None,使用步骤5的结果)
|
||||
training_csv_path: 训练光谱数据CSV路径(如果为None,使用步骤5的结果)
|
||||
formula_csv_file: 公式CSV文件路径,包含公式名称和具体公式
|
||||
formula_names: 要计算的公式名称列表,如果为None则计算所有公式
|
||||
output_file: 输出文件完整路径(支持绝对路径),如果为None则使用默认路径
|
||||
|
||||
|
||||
Returns:
|
||||
包含计算结果的新CSV文件路径
|
||||
"""
|
||||
# 参数解析(保留原逻辑)
|
||||
if training_spectra_path is not None:
|
||||
csv_path = training_spectra_path
|
||||
elif self.training_spectra_path is not None:
|
||||
csv_path = self.training_spectra_path
|
||||
if training_csv_path is not None:
|
||||
csv_path = training_csv_path
|
||||
elif self.training_csv_path is not None:
|
||||
csv_path = self.training_csv_path
|
||||
else:
|
||||
csv_path = None
|
||||
|
||||
self._notify("started", "步骤5.5: 计算水质光谱指数")
|
||||
result = DataPreparationStep.calculate_water_quality_indices(
|
||||
training_spectra_path=csv_path,
|
||||
training_csv_path=csv_path,
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
output_file=output_file,
|
||||
@ -727,8 +727,8 @@ class WaterQualityInversionPipeline:
|
||||
# 参数解析(保留原逻辑)
|
||||
if training_csv_path is not None:
|
||||
final_csv_path = training_csv_path
|
||||
elif self.training_spectra_path is not None:
|
||||
final_csv_path = self.training_spectra_path
|
||||
elif self.training_csv_path is not None:
|
||||
final_csv_path = self.training_csv_path
|
||||
else:
|
||||
final_csv_path = None
|
||||
|
||||
@ -911,7 +911,7 @@ class WaterQualityInversionPipeline:
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
training_csv_path = self.training_spectra_path
|
||||
training_csv_path = self.training_csv_path
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供训练数据CSV路径,或先执行步骤5")
|
||||
|
||||
@ -1033,7 +1033,7 @@ class WaterQualityInversionPipeline:
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
csv_path = self.training_spectra_path
|
||||
csv_path = self.training_csv_path
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供CSV文件路径,或先执行步骤5")
|
||||
|
||||
@ -1506,7 +1506,7 @@ class WaterQualityInversionPipeline:
|
||||
if 'step5' in config:
|
||||
self._notify("步骤5: 光谱提取", "start")
|
||||
self.step5_extract_training_spectra(**config['step5'])
|
||||
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_spectra_path})")
|
||||
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_csv_path})")
|
||||
else:
|
||||
self._notify("步骤5: 光谱提取", "skipped", "未配置")
|
||||
|
||||
@ -1615,7 +1615,7 @@ class WaterQualityInversionPipeline:
|
||||
|
||||
# 生成散点图
|
||||
if 'visualization' in config and config['visualization'].get('generate_scatter', True):
|
||||
if self.training_spectra_path and self.models_dir.exists():
|
||||
if self.training_csv_path and self.models_dir.exists():
|
||||
try:
|
||||
self._notify("可视化", "info", "生成模型评估散点图...")
|
||||
scatter_config = config['visualization'].get('scatter_config', {})
|
||||
@ -1653,7 +1653,7 @@ class WaterQualityInversionPipeline:
|
||||
|
||||
# 生成光谱曲线图
|
||||
if 'visualization' in config and config['visualization'].get('generate_spectrum', True):
|
||||
if self.training_spectra_path:
|
||||
if self.training_csv_path:
|
||||
try:
|
||||
self._notify("可视化", "info", "生成光谱曲线对比图...")
|
||||
spectrum_paths = self.generate_spectrum_comparison_plots(
|
||||
@ -1701,7 +1701,7 @@ class WaterQualityInversionPipeline:
|
||||
pipeline_info['step2'] = {'status': 'completed', 'output_file': str(self.glint_mask_path) if self.glint_mask_path else 'N/A'}
|
||||
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
|
||||
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
|
||||
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_spectra_path) if self.training_spectra_path else 'N/A'}
|
||||
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
|
||||
pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
|
||||
pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
|
||||
pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
|
||||
@ -1784,8 +1784,8 @@ class WaterQualityInversionPipeline:
|
||||
# 参数解析(保留原逻辑)
|
||||
if csv_path is not None:
|
||||
final_csv_path = csv_path
|
||||
elif self.training_spectra_path is not None:
|
||||
final_csv_path = self.training_spectra_path
|
||||
elif self.training_csv_path is not None:
|
||||
final_csv_path = self.training_csv_path
|
||||
else:
|
||||
final_csv_path = None
|
||||
|
||||
@ -2109,7 +2109,7 @@ def main():
|
||||
'interpolation_method': 'bilinear', # 插值方法: 'nearest'(邻近), 'bilinear'(双线性),
|
||||
# 'spline'(样条), 'kriging'(克里金)
|
||||
# 水域掩膜参数(可选):
|
||||
'water_mask':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜,也可以提供:
|
||||
'water_mask_path':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜,也可以提供:
|
||||
# # - numpy数组
|
||||
# # - 栅格文件路径(.dat/.tif)
|
||||
# # - shapefile路径(.shp)
|
||||
|
||||
430
src/gui/components/chart_dialogs.py
Normal file
430
src/gui/components/chart_dialogs.py
Normal file
@ -0,0 +1,430 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
图表与交互弹窗模块
|
||||
|
||||
包含 ChartViewerDialog、ChartBrowserDialog 和 InteractiveViewerDialog 类。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PyQt5.QtWidgets import (
|
||||
QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||
QSizePolicy, QFileDialog, QMessageBox, QGroupBox,
|
||||
QListWidget, QLabel, QComboBox, QCheckBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
class ChartViewerDialog(QDialog):
|
||||
"""图表查看器对话框"""
|
||||
def __init__(self, title="图表查看器", parent=None):
|
||||
super().__init__(parent)
|
||||
self.setWindowTitle(title)
|
||||
self.resize(1000, 700)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
self.figure = Figure(figsize=(10, 7))
|
||||
self.canvas = FigureCanvas(self.figure)
|
||||
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
|
||||
|
||||
self.toolbar = NavigationToolbar(self.canvas, self)
|
||||
|
||||
layout.addWidget(self.toolbar)
|
||||
layout.addWidget(self.canvas)
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
|
||||
self.save_btn = QPushButton("保存图表")
|
||||
self.save_btn.clicked.connect(self.save_chart)
|
||||
btn_layout.addWidget(self.save_btn)
|
||||
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.close_btn = QPushButton("关闭")
|
||||
self.close_btn.clicked.connect(self.close)
|
||||
btn_layout.addWidget(self.close_btn)
|
||||
|
||||
layout.addLayout(btn_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
def display_image(self, image_path):
|
||||
"""显示图片"""
|
||||
self.figure.clear()
|
||||
ax = self.figure.add_subplot(111)
|
||||
|
||||
try:
|
||||
import matplotlib.image as mpimg
|
||||
img = mpimg.imread(image_path)
|
||||
ax.imshow(img)
|
||||
ax.axis('off')
|
||||
self.figure.tight_layout()
|
||||
self.canvas.draw()
|
||||
self.current_image_path = image_path
|
||||
except Exception as e:
|
||||
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
def display_custom_plot(self, plot_func):
|
||||
"""显示自定义绘图函数"""
|
||||
self.figure.clear()
|
||||
try:
|
||||
plot_func(self.figure)
|
||||
self.canvas.draw()
|
||||
except Exception as e:
|
||||
ax = self.figure.add_subplot(111)
|
||||
ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
def save_chart(self):
|
||||
"""保存图表"""
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存图表", "",
|
||||
"PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
try:
|
||||
self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
|
||||
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
|
||||
|
||||
|
||||
class ChartBrowserDialog(QDialog):
|
||||
"""图表浏览器对话框"""
|
||||
def __init__(self, chart_files, parent=None):
|
||||
super().__init__(parent)
|
||||
self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True)
|
||||
self.current_index = 0
|
||||
self.setWindowTitle("图表浏览器")
|
||||
self.resize(1200, 800)
|
||||
self.init_ui()
|
||||
self.show_chart(0)
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
|
||||
list_layout = QHBoxLayout()
|
||||
|
||||
self.chart_list = QListWidget()
|
||||
self.chart_list.setMaximumHeight(150)
|
||||
for chart_file in self.chart_files:
|
||||
self.chart_list.addItem(chart_file.name)
|
||||
self.chart_list.currentRowChanged.connect(self.show_chart)
|
||||
|
||||
list_layout.addWidget(self.chart_list)
|
||||
list_group.setLayout(list_layout)
|
||||
layout.addWidget(list_group)
|
||||
|
||||
self.figure = Figure(figsize=(12, 8))
|
||||
self.canvas = FigureCanvas(self.figure)
|
||||
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
|
||||
|
||||
self.toolbar = NavigationToolbar(self.canvas, self)
|
||||
layout.addWidget(self.toolbar)
|
||||
layout.addWidget(self.canvas, 1)
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
|
||||
self.prev_btn = QPushButton("◀ 上一个")
|
||||
self.prev_btn.clicked.connect(self.prev_chart)
|
||||
btn_layout.addWidget(self.prev_btn)
|
||||
|
||||
self.next_btn = QPushButton("下一个 >")
|
||||
self.next_btn.clicked.connect(self.next_chart)
|
||||
btn_layout.addWidget(self.next_btn)
|
||||
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.save_btn = QPushButton("💾 保存当前图表")
|
||||
self.save_btn.clicked.connect(self.save_current_chart)
|
||||
btn_layout.addWidget(self.save_btn)
|
||||
|
||||
self.close_btn = QPushButton("关闭")
|
||||
self.close_btn.clicked.connect(self.close)
|
||||
btn_layout.addWidget(self.close_btn)
|
||||
|
||||
layout.addLayout(btn_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
def show_chart(self, index):
|
||||
"""显示指定索引的图表"""
|
||||
if 0 <= index < len(self.chart_files):
|
||||
self.current_index = index
|
||||
self.chart_list.setCurrentRow(index)
|
||||
|
||||
chart_file = self.chart_files[index]
|
||||
self.figure.clear()
|
||||
ax = self.figure.add_subplot(111)
|
||||
|
||||
try:
|
||||
import matplotlib.image as mpimg
|
||||
img = mpimg.imread(str(chart_file))
|
||||
ax.imshow(img)
|
||||
ax.axis('off')
|
||||
ax.set_title(chart_file.name, fontsize=12, pad=10)
|
||||
self.figure.tight_layout()
|
||||
self.canvas.draw()
|
||||
except Exception as e:
|
||||
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
self.prev_btn.setEnabled(index > 0)
|
||||
self.next_btn.setEnabled(index < len(self.chart_files) - 1)
|
||||
|
||||
def prev_chart(self):
|
||||
"""上一个图表"""
|
||||
if self.current_index > 0:
|
||||
self.show_chart(self.current_index - 1)
|
||||
|
||||
def next_chart(self):
|
||||
"""下一个图表"""
|
||||
if self.current_index < len(self.chart_files) - 1:
|
||||
self.show_chart(self.current_index + 1)
|
||||
|
||||
def save_current_chart(self):
|
||||
"""保存当前图表"""
|
||||
if 0 <= self.current_index < len(self.chart_files):
|
||||
current_file = self.chart_files[self.current_index]
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存图表", current_file.name,
|
||||
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
try:
|
||||
import shutil
|
||||
shutil.copy(str(current_file), file_path)
|
||||
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
|
||||
|
||||
|
||||
class InteractiveViewerDialog(QDialog):
|
||||
"""交互式影像预览对话框:显示影像、参考点散点图、点击查询坐标/值"""
|
||||
|
||||
def __init__(self, parent, img_path, ref_csv=None):
|
||||
super().__init__(parent)
|
||||
self.img_path = img_path
|
||||
self.ref_csv = ref_csv
|
||||
self.geotransform = None
|
||||
self.fig = None
|
||||
self.canvas = None
|
||||
self.ax = None
|
||||
self.status_label = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
self.setWindowTitle("👁️ 交互式影像预览")
|
||||
self.setMinimumSize(900, 700)
|
||||
|
||||
layout = QVBoxLayout()
|
||||
|
||||
toolbar = QHBoxLayout()
|
||||
self.band_combo = QComboBox()
|
||||
self.band_combo.currentIndexChanged.connect(self.on_band_changed)
|
||||
toolbar.addWidget(QLabel("显示波段:"))
|
||||
toolbar.addWidget(self.band_combo)
|
||||
|
||||
self.gray_check = QCheckBox("灰度显示")
|
||||
self.gray_check.stateChanged.connect(self.on_band_changed)
|
||||
toolbar.addWidget(self.gray_check)
|
||||
toolbar.addStretch()
|
||||
layout.addLayout(toolbar)
|
||||
|
||||
try:
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.figure import Figure
|
||||
import matplotlib
|
||||
matplotlib.use('Qt5Agg')
|
||||
|
||||
self.fig = Figure(figsize=(10, 8))
|
||||
self.canvas = FigureCanvas(self.fig)
|
||||
self.ax = self.fig.add_subplot(111)
|
||||
self.fig.tight_layout()
|
||||
layout.addWidget(self.canvas)
|
||||
|
||||
self.load_and_display()
|
||||
|
||||
except ImportError as e:
|
||||
layout.addWidget(QLabel(f"Matplotlib 未安装: {e}"))
|
||||
|
||||
self.status_label = QLabel("点击影像查看像素坐标和经纬度")
|
||||
self.status_label.setStyleSheet("background:#f0f0f0;padding:4px;font-size:12px;")
|
||||
self.status_label.setWordWrap(True)
|
||||
layout.addWidget(self.status_label)
|
||||
|
||||
close_btn = QPushButton("关闭")
|
||||
close_btn.clicked.connect(self.close)
|
||||
layout.addWidget(close_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def load_and_display(self):
|
||||
"""加载影像并显示"""
|
||||
from osgeo import gdal
|
||||
|
||||
dataset = gdal.Open(self.img_path)
|
||||
if dataset is None:
|
||||
self.status_label.setText(f"无法打开影像: {self.img_path}")
|
||||
return
|
||||
|
||||
self.geotransform = dataset.GetGeoTransform()
|
||||
self.projection = dataset.GetProjection()
|
||||
n_bands = dataset.RasterCount
|
||||
self.height = dataset.RasterYSize
|
||||
self.width = dataset.RasterXSize
|
||||
|
||||
self.band_combo.clear()
|
||||
if n_bands >= 3:
|
||||
for i in range(1, n_bands + 1):
|
||||
self.band_combo.addItem(f"RGB (B{i-0}, G{i-1}, R{i-2})" if i >= 3 else f"波段 {i}", i)
|
||||
self.band_combo.addItem(f"单波段 (B1)", 0)
|
||||
else:
|
||||
for i in range(1, n_bands + 1):
|
||||
self.band_combo.addItem(f"波段 {i}", i - 1)
|
||||
self.band_combo.setCurrentIndex(0)
|
||||
|
||||
self.dataset = dataset
|
||||
self.display_band(0, is_gray=False)
|
||||
self.load_ref_points()
|
||||
|
||||
def display_band(self, band_idx, is_gray=False):
|
||||
"""显示指定波段组合"""
|
||||
from osgeo import gdal
|
||||
import numpy as np
|
||||
|
||||
dataset = self.dataset
|
||||
self.ax.clear()
|
||||
|
||||
if is_gray or (self.band_combo.currentData() == 0 and dataset.RasterCount == 1):
|
||||
band = dataset.GetRasterBand(1 if band_idx == 0 else band_idx + 1)
|
||||
data = band.ReadAsArray()
|
||||
data = np.nan_to_num(data, nan=0.0)
|
||||
self.ax.imshow(data, cmap='gray')
|
||||
self.ax.set_title(f"波段 {band_idx + 1} (灰度)")
|
||||
else:
|
||||
n = min(3, dataset.RasterCount)
|
||||
bands_data = []
|
||||
for i in range(n):
|
||||
b = dataset.GetRasterBand(i + 1)
|
||||
bd = b.ReadAsArray()
|
||||
bd = np.nan_to_num(bd, nan=0.0)
|
||||
bands_data.append(bd)
|
||||
rgb = np.dstack(bands_data)
|
||||
|
||||
for i in range(rgb.shape[2]):
|
||||
p2, p98 = np.percentile(rgb[:, :, i], [2, 98])
|
||||
if p98 > p2:
|
||||
rgb[:, :, i] = np.clip((rgb[:, :, i] - p2) / (p98 - p2), 0, 1)
|
||||
else:
|
||||
rgb[:, :, i] = np.clip(rgb[:, :, i] / (p98 + 1e-6), 0, 1)
|
||||
|
||||
self.ax.imshow(rgb)
|
||||
self.ax.set_title(f"RGB 显示")
|
||||
|
||||
self.ax.set_xlabel("列 (Column)")
|
||||
self.ax.set_ylabel("行 (Row)")
|
||||
self.fig.tight_layout()
|
||||
self.canvas.draw()
|
||||
|
||||
self.cid = self.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
|
||||
def on_band_changed(self):
|
||||
"""波段选择变化时更新显示"""
|
||||
if not hasattr(self, 'dataset'):
|
||||
return
|
||||
is_gray = self.gray_check.isChecked()
|
||||
band_data = self.band_combo.currentData()
|
||||
self.display_band(band_data if band_data != 0 else 0, is_gray=is_gray)
|
||||
|
||||
def load_ref_points(self):
|
||||
"""加载并显示参考点"""
|
||||
import os
|
||||
if not self.ref_csv or not os.path.isfile(self.ref_csv):
|
||||
return
|
||||
|
||||
try:
|
||||
import csv
|
||||
lon_list, lat_list = [], []
|
||||
with open(self.ref_csv, 'r', encoding='utf-8-sig') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
try:
|
||||
lon = float(row.get('Lon', row.get('lon', row.get('LON', 0))))
|
||||
lat = float(row.get('Lat', row.get('lat', row.get('LAT', 0))))
|
||||
if lon and lat:
|
||||
lon_list.append(lon)
|
||||
lat_list.append(lat)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if not lon_list:
|
||||
return
|
||||
|
||||
px_list, py_list = [], []
|
||||
gt = self.geotransform
|
||||
if gt and (gt[1] != 0 or gt[5] != 0):
|
||||
for lon, lat in zip(lon_list, lat_list):
|
||||
px = (lon - gt[0]) / gt[1]
|
||||
py = (lat - gt[3]) / gt[5]
|
||||
if 0 <= px < self.width and 0 <= py < self.height:
|
||||
px_list.append(px)
|
||||
py_list.append(py)
|
||||
|
||||
if px_list:
|
||||
self.ax.scatter(px_list, py_list, c='red', s=40, marker='o',
|
||||
edgecolors='white', linewidths=0.8, zorder=5, alpha=0.9,
|
||||
label=f'参考点 ({len(px_list)}个)')
|
||||
self.ax.legend(loc='upper right', fontsize=9)
|
||||
self.fig.tight_layout()
|
||||
self.canvas.draw()
|
||||
self.status_label.setText(
|
||||
f"已加载 {len(px_list)} 个参考点(仅显示在影像范围内的点)"
|
||||
)
|
||||
except Exception as e:
|
||||
self.status_label.setText(f"加载参考点失败: {e}")
|
||||
|
||||
def pixel_to_geo(self, px, py):
|
||||
"""像素坐标转经纬度"""
|
||||
gt = self.geotransform
|
||||
if gt is None:
|
||||
return None, None
|
||||
lon = gt[0] + px * gt[1] + py * gt[2]
|
||||
lat = gt[3] + px * gt[4] + py * gt[5]
|
||||
return lon, lat
|
||||
|
||||
def on_click(self, event):
|
||||
"""鼠标点击事件"""
|
||||
if event.inaxes != self.ax or event.xdata is None or event.ydata is None:
|
||||
return
|
||||
|
||||
px, py = int(round(event.xdata)), int(round(event.ydata))
|
||||
if not (0 <= px < self.width and 0 <= py < self.height):
|
||||
return
|
||||
|
||||
from osgeo import gdal
|
||||
import numpy as np
|
||||
dataset = self.dataset
|
||||
n_bands = dataset.RasterCount
|
||||
vals = []
|
||||
for b in range(1, n_bands + 1):
|
||||
val = dataset.GetRasterBand(b).ReadAsArray()[py, px]
|
||||
vals.append(f"{val:.4f}" if isinstance(val, float) else str(val))
|
||||
|
||||
lon, lat = self.pixel_to_geo(px, py)
|
||||
geo_str = f"Lon={lon:.6f}, Lat={lat:.6f}" if lon is not None else "无地理参考"
|
||||
|
||||
self.status_label.setText(
|
||||
f"像素: (行={py}, 列={px}) | {geo_str} | "
|
||||
f"波段值: {' | '.join(vals[:5])}" +
|
||||
(f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "")
|
||||
)
|
||||
50
src/gui/components/data_models.py
Normal file
50
src/gui/components/data_models.py
Normal file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据模型模块
|
||||
|
||||
包含 PandasTableModel 等数据模型类。
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtCore import Qt, QAbstractTableModel
|
||||
|
||||
|
||||
class PandasTableModel(QAbstractTableModel):
|
||||
"""支持DataFrame的表格模型"""
|
||||
def __init__(self, data_frame: pd.DataFrame):
|
||||
super().__init__()
|
||||
self._data = data_frame.copy()
|
||||
if self._data.empty:
|
||||
self._data = pd.DataFrame()
|
||||
self._data.fillna("", inplace=True)
|
||||
self._columns = [str(col) for col in self._data.columns]
|
||||
|
||||
def rowCount(self, parent=None):
|
||||
return len(self._data)
|
||||
|
||||
def columnCount(self, parent=None):
|
||||
return len(self._columns)
|
||||
|
||||
def data(self, index, role=Qt.DisplayRole):
|
||||
if not index.isValid() or role != Qt.DisplayRole:
|
||||
return None
|
||||
|
||||
value = self._data.iat[index.row(), index.column()]
|
||||
if pd.isna(value):
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
def headerData(self, section, orientation, role=Qt.DisplayRole):
|
||||
if role != Qt.DisplayRole:
|
||||
return None
|
||||
if orientation == Qt.Horizontal:
|
||||
if section < len(self._columns):
|
||||
return self._columns[section]
|
||||
return str(section)
|
||||
return str(section + 1)
|
||||
|
||||
def flags(self, index):
|
||||
if not index.isValid():
|
||||
return Qt.NoItemFlags
|
||||
return Qt.ItemIsEnabled | Qt.ItemIsSelectable
|
||||
351
src/gui/components/image_widgets.py
Normal file
351
src/gui/components/image_widgets.py
Normal file
@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
图像浏览组件模块
|
||||
|
||||
包含 ImageCategoryTree 和 ImageViewerWidget 类。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QTreeWidget, QTreeWidgetItem, QWidget, QVBoxLayout, QHBoxLayout,
|
||||
QPushButton, QLabel, QScrollArea, QFrame, QGroupBox,
|
||||
QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QTimer
|
||||
from PyQt5.QtGui import QPixmap
|
||||
|
||||
|
||||
class ImageCategoryTree(QTreeWidget):
|
||||
"""图像分类目录树 - 按类别组织图像文件"""
|
||||
|
||||
CATEGORIES = [
|
||||
("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"),
|
||||
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
|
||||
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
|
||||
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
|
||||
("含量分布图", [], "📁"),
|
||||
]
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setHeaderLabel("图像目录")
|
||||
self.setMaximumWidth(300)
|
||||
self.setMinimumWidth(250)
|
||||
self.setup_categories()
|
||||
self.setStyleSheet("""
|
||||
QTreeWidget {
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 5px;
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
QTreeWidget::item {
|
||||
padding: 5px;
|
||||
border-radius: 3px;
|
||||
}
|
||||
QTreeWidget::item:selected {
|
||||
background-color: #0078D4;
|
||||
color: white;
|
||||
}
|
||||
QTreeWidget::item:hover {
|
||||
background-color: #e3f2fd;
|
||||
}
|
||||
""")
|
||||
|
||||
def setup_categories(self):
|
||||
"""初始化类别节点"""
|
||||
self.category_items = {}
|
||||
for category_name, keywords, icon in self.CATEGORIES:
|
||||
item = QTreeWidgetItem(self)
|
||||
item.setText(0, f"{icon} {category_name}")
|
||||
item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name})
|
||||
item.setExpanded(True)
|
||||
self.category_items[category_name] = item
|
||||
|
||||
def clear_all_images(self):
|
||||
"""清除所有图像项"""
|
||||
for category_item in self.category_items.values():
|
||||
while category_item.childCount() > 0:
|
||||
category_item.removeChild(category_item.child(0))
|
||||
|
||||
def add_image(self, file_path: Path, display_name: str = None):
|
||||
"""添加图像到对应的类别"""
|
||||
if display_name is None:
|
||||
display_name = file_path.stem
|
||||
|
||||
category = self._determine_category(file_path.name)
|
||||
category_item = self.category_items.get(category, self.category_items["含量分布图"])
|
||||
|
||||
image_item = QTreeWidgetItem(category_item)
|
||||
image_item.setText(0, f" └─ {display_name}")
|
||||
image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)})
|
||||
image_item.setToolTip(0, str(file_path))
|
||||
|
||||
return image_item
|
||||
|
||||
def _determine_category(self, filename: str) -> str:
|
||||
"""根据文件名确定类别"""
|
||||
filename_lower = filename.lower()
|
||||
|
||||
for category_name, keywords, _ in self.CATEGORIES:
|
||||
if any(keyword in filename_lower for keyword in keywords):
|
||||
return category_name
|
||||
|
||||
return "含量分布图"
|
||||
|
||||
def scan_directory(self, work_dir: str):
|
||||
"""扫描目录中的所有图像文件"""
|
||||
self.clear_all_images()
|
||||
|
||||
work_path = Path(work_dir)
|
||||
if not work_path.exists():
|
||||
return
|
||||
|
||||
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
|
||||
scan_roots: List[Path] = []
|
||||
_viz = work_path / "14_visualization"
|
||||
if _viz.is_dir():
|
||||
scan_roots.append(_viz)
|
||||
_wm = work_path / "1_water_mask"
|
||||
if _wm.is_dir():
|
||||
scan_roots.append(_wm)
|
||||
if not scan_roots:
|
||||
scan_roots.append(work_path)
|
||||
|
||||
seen_norm: set = set()
|
||||
image_files: List[Path] = []
|
||||
for root in scan_roots:
|
||||
for ext in image_extensions:
|
||||
for p in root.glob(f"**/{ext}"):
|
||||
key = os.path.normcase(os.path.normpath(str(p.resolve())))
|
||||
if key in seen_norm:
|
||||
continue
|
||||
seen_norm.add(key)
|
||||
image_files.append(p)
|
||||
|
||||
for img_file in sorted(image_files):
|
||||
if img_file.name.startswith('.') or 'thumb' in img_file.name.lower():
|
||||
continue
|
||||
self.add_image(img_file)
|
||||
|
||||
for category_name, item in self.category_items.items():
|
||||
count = item.childCount()
|
||||
if count > 0:
|
||||
for cat_name, _, icon in self.CATEGORIES:
|
||||
if cat_name == category_name:
|
||||
item.setText(0, f"{icon} {category_name} ({count})")
|
||||
break
|
||||
|
||||
def get_selected_image_path(self) -> Optional[str]:
|
||||
"""获取当前选中的图像路径"""
|
||||
selected_item = self.currentItem()
|
||||
if not selected_item:
|
||||
return None
|
||||
|
||||
data = selected_item.data(0, Qt.UserRole)
|
||||
if data and data.get("type") == "image":
|
||||
return data.get("path")
|
||||
return None
|
||||
|
||||
|
||||
class ImageViewerWidget(QWidget):
|
||||
"""图像查看器组件 - 支持缩放、平移"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.current_image_path = None
|
||||
self.scale_factor = 1.0
|
||||
self._update_timer = QTimer()
|
||||
self._update_timer.setSingleShot(True)
|
||||
self._update_timer.timeout.connect(self._do_update_display)
|
||||
self._pending_scale = None
|
||||
self.setup_ui()
|
||||
|
||||
def setup_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
toolbar = QHBoxLayout()
|
||||
|
||||
self.refresh_btn = QPushButton("🔄 刷新目录")
|
||||
self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
|
||||
toolbar.addWidget(self.refresh_btn)
|
||||
|
||||
separator = QFrame()
|
||||
separator.setFrameShape(QFrame.VLine)
|
||||
separator.setFrameShadow(QFrame.Sunken)
|
||||
toolbar.addWidget(separator)
|
||||
|
||||
self.zoom_in_btn = QPushButton("🔍+")
|
||||
self.zoom_in_btn.setToolTip("放大")
|
||||
self.zoom_in_btn.setMaximumWidth(50)
|
||||
toolbar.addWidget(self.zoom_in_btn)
|
||||
|
||||
self.zoom_out_btn = QPushButton("🔍-")
|
||||
self.zoom_out_btn.setToolTip("缩小")
|
||||
self.zoom_out_btn.setMaximumWidth(50)
|
||||
toolbar.addWidget(self.zoom_out_btn)
|
||||
|
||||
self.fit_btn = QPushButton("⬜ 适应窗口")
|
||||
self.fit_btn.setToolTip("适应窗口大小")
|
||||
toolbar.addWidget(self.fit_btn)
|
||||
|
||||
self.original_btn = QPushButton("1:1 原始大小")
|
||||
self.original_btn.setToolTip("原始大小")
|
||||
toolbar.addWidget(self.original_btn)
|
||||
|
||||
toolbar.addStretch()
|
||||
|
||||
self.save_btn = QPushButton("💾 保存")
|
||||
self.save_btn.setToolTip("保存当前图像")
|
||||
toolbar.addWidget(self.save_btn)
|
||||
|
||||
layout.addLayout(toolbar)
|
||||
|
||||
self.scroll_area = QScrollArea()
|
||||
self.scroll_area.setWidgetResizable(True)
|
||||
self.scroll_area.setStyleSheet("background-color: white;")
|
||||
|
||||
self.image_label = QLabel()
|
||||
self.image_label.setAlignment(Qt.AlignCenter)
|
||||
self.image_label.setStyleSheet("background-color: white;")
|
||||
|
||||
self.scroll_area.setWidget(self.image_label)
|
||||
layout.addWidget(self.scroll_area, 1)
|
||||
|
||||
status_layout = QHBoxLayout()
|
||||
self.status_label = QLabel("就绪")
|
||||
self.status_label.setStyleSheet("color: #666; font-size: 11px;")
|
||||
status_layout.addWidget(self.status_label)
|
||||
status_layout.addStretch()
|
||||
layout.addLayout(status_layout)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
self.zoom_in_btn.clicked.connect(self.zoom_in)
|
||||
self.zoom_out_btn.clicked.connect(self.zoom_out)
|
||||
self.fit_btn.clicked.connect(self.fit_to_window)
|
||||
self.original_btn.clicked.connect(self.original_size)
|
||||
self.save_btn.clicked.connect(self.save_image)
|
||||
|
||||
def load_image(self, image_path: str):
|
||||
"""加载并显示图像"""
|
||||
if not image_path or not Path(image_path).exists():
|
||||
self.image_label.setText("图像不存在")
|
||||
self.status_label.setText("图像加载失败")
|
||||
return
|
||||
|
||||
self.current_image_path = image_path
|
||||
self.scale_factor = 1.0
|
||||
|
||||
pixmap = QPixmap(image_path)
|
||||
if pixmap.isNull():
|
||||
self.image_label.setText("无法加载图像")
|
||||
self.status_label.setText("图像格式不支持")
|
||||
return
|
||||
|
||||
self.original_pixmap = pixmap
|
||||
self.fit_to_window()
|
||||
|
||||
file_info = Path(image_path).stat()
|
||||
size_mb = file_info.st_size / (1024 * 1024)
|
||||
self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")
|
||||
|
||||
def update_image_display(self):
|
||||
"""更新图像显示 - 使用防抖避免频繁重绘卡顿"""
|
||||
self._update_timer.stop()
|
||||
self._pending_scale = self.scale_factor
|
||||
self._update_timer.start(50)
|
||||
|
||||
def _do_update_display(self):
|
||||
"""实际执行图像更新"""
|
||||
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
|
||||
return
|
||||
|
||||
if self._pending_scale is None:
|
||||
return
|
||||
|
||||
if self._pending_scale > 2.0 or self._pending_scale < 0.5:
|
||||
transform = Qt.FastTransformation
|
||||
else:
|
||||
transform = Qt.SmoothTransformation
|
||||
|
||||
scaled_pixmap = self.original_pixmap.scaled(
|
||||
int(self.original_pixmap.width() * self._pending_scale),
|
||||
int(self.original_pixmap.height() * self._pending_scale),
|
||||
Qt.KeepAspectRatio,
|
||||
transform
|
||||
)
|
||||
self.image_label.setPixmap(scaled_pixmap)
|
||||
self._pending_scale = None
|
||||
|
||||
def wheelEvent(self, event):
|
||||
"""鼠标滚轮缩放 - 实时响应"""
|
||||
delta = event.angleDelta().y()
|
||||
|
||||
if delta > 0:
|
||||
if self.scale_factor < 5.0:
|
||||
self.scale_factor = min(self.scale_factor * 1.1, 5.0)
|
||||
self.update_image_display()
|
||||
else:
|
||||
if self.scale_factor > 0.1:
|
||||
self.scale_factor = max(self.scale_factor / 1.1, 0.1)
|
||||
self.update_image_display()
|
||||
|
||||
event.accept()
|
||||
|
||||
def zoom_in(self):
|
||||
"""放大"""
|
||||
if self.scale_factor < 5.0:
|
||||
self.scale_factor = min(self.scale_factor * 1.25, 5.0)
|
||||
self.update_image_display()
|
||||
|
||||
def zoom_out(self):
|
||||
"""缩小"""
|
||||
if self.scale_factor > 0.1:
|
||||
self.scale_factor = max(self.scale_factor / 1.25, 0.1)
|
||||
self.update_image_display()
|
||||
|
||||
def fit_to_window(self):
|
||||
"""适应窗口"""
|
||||
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
|
||||
return
|
||||
|
||||
view_size = self.scroll_area.viewport().size()
|
||||
img_size = self.original_pixmap.size()
|
||||
|
||||
scale_w = view_size.width() / img_size.width()
|
||||
scale_h = view_size.height() / img_size.height()
|
||||
|
||||
self._fit_scale = min(scale_w, scale_h)
|
||||
self.scale_factor = self._fit_scale
|
||||
|
||||
self.update_image_display()
|
||||
self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")
|
||||
|
||||
def original_size(self):
|
||||
"""原始大小"""
|
||||
self.scale_factor = 1.0
|
||||
self._fit_scale = None
|
||||
self.update_image_display()
|
||||
self.status_label.setText("原始大小 | 缩放: 100%")
|
||||
|
||||
def save_image(self):
|
||||
"""保存图像"""
|
||||
if not self.current_image_path:
|
||||
return
|
||||
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存图像", Path(self.current_image_path).name,
|
||||
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
|
||||
)
|
||||
|
||||
if file_path:
|
||||
try:
|
||||
import shutil
|
||||
shutil.copy(self.current_image_path, file_path)
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"保存失败: {e}")
|
||||
112
src/gui/core/test_modeling.py
Normal file
112
src/gui/core/test_modeling.py
Normal file
@ -0,0 +1,112 @@
|
||||
import time
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.datasets import make_regression
|
||||
|
||||
# 屏蔽烦人的 sklearn 警告
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
print("====== 🚀 启动 Mega Water 模型终极体检脚本 ======")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 1. 完美复刻侦察报告中的 CSV 数据结构
|
||||
# 报告指出: 目标值(y)在左边,光谱特征(X)在右边
|
||||
# ---------------------------------------------------------
|
||||
print("📦 正在生成符合系统结构的模拟测试数据...")
|
||||
X_raw, y_raw = make_regression(n_samples=200, n_features=50, noise=0.1, random_state=42)
|
||||
|
||||
# 模拟真实的 CSV 列名:前2列是水质参数,后面是 50 个光谱波段
|
||||
columns = ['Chla', 'SS'] + [f"Band_{i}" for i in range(50)]
|
||||
# 拼装成一整张大表
|
||||
data = pd.DataFrame(np.hstack((y_raw.reshape(-1, 1), (y_raw * 0.5).reshape(-1, 1), X_raw)), columns=columns)
|
||||
|
||||
# 按照 load_data_batch 的逻辑进行切割
|
||||
feature_start_index = 2
|
||||
X = data.iloc[:, feature_start_index:] # 截取光谱作为 X
|
||||
y = data['Chla'] # 提取一个目标参数作为 y
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"✅ 数据切割完毕! 模拟波段数: {X.shape[1]}, 训练集样本数: {X_train.shape[0]}\n")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 2. 严格装载侦察报告中的 16 个真实模型
|
||||
# ---------------------------------------------------------
|
||||
print("🔍 正在加载底层真实配置库中的模型...")
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
|
||||
# 将参数压至极低,实施“降维打击”,确保 1 秒内跑完
|
||||
models = {
|
||||
'SVR': SVR(),
|
||||
'RF': RandomForestRegressor(n_estimators=10, max_depth=5, n_jobs=-1),
|
||||
'KNN': KNeighborsRegressor(),
|
||||
'LinearRegression': LinearRegression(),
|
||||
'Ridge': Ridge(),
|
||||
'Lasso': Lasso(),
|
||||
'ElasticNet': ElasticNet(),
|
||||
'PLS': PLSRegression(),
|
||||
'GradientBoosting': GradientBoostingRegressor(n_estimators=10, max_depth=5),
|
||||
'AdaBoost': AdaBoostRegressor(n_estimators=10),
|
||||
'DecisionTree': DecisionTreeRegressor(max_depth=5),
|
||||
'MLP': MLPRegressor(max_iter=50),
|
||||
'ExtraTrees': ExtraTreesRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
|
||||
}
|
||||
|
||||
# 针对报告中发现的 3 个“被禁用”的第三方强力库,进行刺探测试
|
||||
try:
|
||||
from xgboost import XGBRegressor
|
||||
|
||||
models['XGBoost'] = XGBRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
|
||||
except ImportError:
|
||||
models['XGBoost'] = "IMPORT_ERROR"
|
||||
|
||||
try:
|
||||
from lightgbm import LGBMRegressor
|
||||
|
||||
models['LightGBM'] = LGBMRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
|
||||
except ImportError:
|
||||
models['LightGBM'] = "IMPORT_ERROR"
|
||||
|
||||
try:
|
||||
from catboost import CatBoostRegressor
|
||||
|
||||
models['CatBoost'] = CatBoostRegressor(iterations=10, depth=5, verbose=0)
|
||||
except ImportError:
|
||||
models['CatBoost'] = "IMPORT_ERROR"
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 3. 开始残酷的体检循环
|
||||
# ---------------------------------------------------------
|
||||
print("\n================ 开始跑分测试 ================")
|
||||
results = []
|
||||
|
||||
for name, model in models.items():
|
||||
if model == "IMPORT_ERROR":
|
||||
results.append(f"⚠️ [缺库] {name:<16} : 环境未安装此库 (建议: pip install {name.lower()})")
|
||||
continue
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 极速拟合与评分
|
||||
model.fit(X_train, y_train)
|
||||
score = model.score(X_test, y_test)
|
||||
cost_time = time.time() - start_time
|
||||
results.append(f"✅ [成功] {name:<16} : 耗时 {cost_time:.3f} 秒 (R2: {score:.2f})")
|
||||
except Exception as e:
|
||||
error_msg = str(e).split('\n')[0]
|
||||
results.append(f"❌ [崩溃] {name:<16} : {error_msg}")
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 4. 打印最终体检报告
|
||||
# ---------------------------------------------------------
|
||||
print("\n=============== 🏥 最终体检报告 ===============")
|
||||
for res in results:
|
||||
print(res)
|
||||
print("===============================================")
|
||||
346
src/gui/core/viz_thread.py
Normal file
346
src/gui/core/viz_thread.py
Normal file
@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化后台线程模块
|
||||
|
||||
包含 VisualizationWorkerThread 后台线程类和辅助函数。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
|
||||
"""推断光谱起始列(training_spectra 通常以波长数值为列名,未必含 UTM_Y)。"""
|
||||
import pandas as pd
|
||||
for i, col in enumerate(df.columns):
|
||||
name = str(col).strip().lstrip("\ufeff")
|
||||
try:
|
||||
v = float(name)
|
||||
except ValueError:
|
||||
continue
|
||||
if 200.0 <= v <= 3000.0:
|
||||
return i
|
||||
if "UTM_Y" in df.columns:
|
||||
return "UTM_Y"
|
||||
return 0
|
||||
|
||||
|
||||
class VisualizationWorkerThread(QThread):
|
||||
"""可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。"""
|
||||
|
||||
finished_ok = pyqtSignal(object)
|
||||
failed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
|
||||
super().__init__()
|
||||
self.task = task
|
||||
self.work_dir = str(work_dir)
|
||||
self.extra = extra or {}
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
wp = Path(self.work_dir)
|
||||
if self.task == "mask_glint":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
cnt = len(preview_paths) if preview_paths else 0
|
||||
self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})
|
||||
elif self.task == "sampling_map":
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
if deglint_dir.exists():
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||||
if not hyperspectral_files:
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
|
||||
if not hyperspectral_files:
|
||||
self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。")
|
||||
return
|
||||
hyperspectral_path = str(hyperspectral_files[0])
|
||||
csv_files = []
|
||||
processed_dir = wp / "4_processed_data"
|
||||
if processed_dir.exists():
|
||||
csv_files = list(processed_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
csv_files = (
|
||||
list(wp.glob("**/*sampling*.csv"))
|
||||
+ list(wp.glob("**/*point*.csv"))
|
||||
+ list(wp.glob("**/*.csv"))
|
||||
)
|
||||
if not csv_files:
|
||||
self.failed.emit("未找到采样点 CSV 文件。")
|
||||
return
|
||||
csv_path = str(csv_files[0])
|
||||
from src.postprocessing.point_map import SamplingPointMap
|
||||
map_generator = SamplingPointMap(
|
||||
output_dir=str(wp / "14_visualization" / "sampling_maps"),
|
||||
fast_mode=True,
|
||||
)
|
||||
map_path = map_generator.create_sampling_point_map(
|
||||
hyperspectral_path=hyperspectral_path,
|
||||
csv_path=csv_path,
|
||||
point_color="red",
|
||||
point_size=100,
|
||||
point_alpha=0.9,
|
||||
show_north_arrow=True,
|
||||
show_scale_bar=True,
|
||||
show_legend=True,
|
||||
downsample=True,
|
||||
dpi=180,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{
|
||||
"task": "sampling_map",
|
||||
"map_path": map_path,
|
||||
"hyperspectral_path": hyperspectral_path,
|
||||
"csv_path": csv_path,
|
||||
}
|
||||
)
|
||||
elif self.task == "spectrum":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
csv_file = self.extra.get("csv_path")
|
||||
wl = self.extra.get("wavelength_start_column", "UTM_Y")
|
||||
n_groups = int(self.extra.get("n_groups", 5))
|
||||
param_cols = self.extra.get("param_cols") or []
|
||||
if param_cols:
|
||||
output_paths: List[str] = []
|
||||
err_lines: List[str] = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
out = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(csv_file),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl,
|
||||
n_groups=n_groups,
|
||||
)
|
||||
output_paths.append(out)
|
||||
except Exception as _ex:
|
||||
err_lines.append(f"{param_col}: {_ex}")
|
||||
if not output_paths:
|
||||
self.failed.emit(
|
||||
"所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
|
||||
)
|
||||
return
|
||||
self.finished_ok.emit(
|
||||
{
|
||||
"task": "spectrum",
|
||||
"output_paths": output_paths,
|
||||
"errors": err_lines,
|
||||
}
|
||||
)
|
||||
else:
|
||||
param_col = self.extra.get("param_col")
|
||||
out = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(csv_file),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl,
|
||||
n_groups=n_groups,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{"task": "spectrum", "output_path": out, "param_col": param_col}
|
||||
)
|
||||
elif self.task == "statistics":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
csv_file = self.extra.get("csv_path")
|
||||
param_cols = self.extra.get("param_cols") or []
|
||||
output_paths = viz.plot_statistical_charts(
|
||||
csv_path=str(csv_file),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{"task": "statistics", "output_paths": output_paths}
|
||||
)
|
||||
elif self.task == "scatter":
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
|
||||
training_csv_path = (self.extra.get("training_csv_path") or "").strip()
|
||||
models_dir = (self.extra.get("models_dir") or "").strip()
|
||||
if not training_csv_path or not Path(training_csv_path).is_file():
|
||||
self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
|
||||
return
|
||||
if not models_dir or not Path(models_dir).is_dir():
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
|
||||
return
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=training_csv_path,
|
||||
models_dir=models_dir,
|
||||
)
|
||||
self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})
|
||||
elif self.task == "generate_all_selected":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
parts = []
|
||||
|
||||
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
|
||||
|
||||
if self.extra.get("gen_scatter"):
|
||||
if training_csv.is_file():
|
||||
models_dir = wp / "7_Supervised_Model_Training"
|
||||
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=str(training_csv),
|
||||
models_dir=str(models_dir),
|
||||
)
|
||||
count = len(scatter_paths) if scatter_paths else 0
|
||||
parts.append(f"散点图: {count} 个")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无模型目录)")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_spectrum"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
wl_col = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl_col, str):
|
||||
idx = int(df.columns.get_loc(wl_col)) + 1
|
||||
else:
|
||||
idx = int(wl_col)
|
||||
param_cols = []
|
||||
if idx > 0 and idx < len(df.columns):
|
||||
param_cols = [
|
||||
c for c in df.columns[:idx]
|
||||
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
|
||||
]
|
||||
if param_cols:
|
||||
spectrum_paths = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
path = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(training_csv),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl_col,
|
||||
n_groups=5,
|
||||
)
|
||||
if path:
|
||||
spectrum_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"生成光谱图失败 ({param_col}): {e}")
|
||||
count = len(spectrum_paths)
|
||||
parts.append(f"光谱图: {count} 个")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无可用参数列)")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_boxplots"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
|
||||
param_cols = [
|
||||
c for c in df.select_dtypes(include=[np.number]).columns
|
||||
if not any(exc in c.lower() for exc in exclude_cols)
|
||||
]
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
idx = int(df.columns.get_loc(wl)) + 1
|
||||
else:
|
||||
idx = int(wl)
|
||||
if 0 < idx < len(df.columns):
|
||||
meta_set = set(df.columns[:idx])
|
||||
param_cols = [c for c in param_cols if c in meta_set]
|
||||
|
||||
if param_cols:
|
||||
output_dict = viz.plot_statistical_charts(
|
||||
csv_path=str(training_csv),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
count = len([v for v in output_dict.values() if v]) if output_dict else 0
|
||||
parts.append(f"统计图: {count} 个")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无可用水质参数列)")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_mask_glint"):
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
|
||||
|
||||
if self.extra.get("gen_sampling_map"):
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
if deglint_dir.exists():
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||||
if not hyperspectral_files:
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
|
||||
if hyperspectral_files:
|
||||
hyperspectral_path = str(hyperspectral_files[0])
|
||||
csv_files = []
|
||||
processed_dir = wp / "4_processed_data"
|
||||
if processed_dir.exists():
|
||||
csv_files = list(processed_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
csv_files = (
|
||||
list(wp.glob("**/*sampling*.csv"))
|
||||
+ list(wp.glob("**/*point*.csv"))
|
||||
+ list(wp.glob("**/*.csv"))
|
||||
)
|
||||
if csv_files:
|
||||
csv_path = str(csv_files[0])
|
||||
from src.postprocessing.point_map import SamplingPointMap
|
||||
map_generator = SamplingPointMap(
|
||||
output_dir=str(wp / "14_visualization" / "sampling_maps"),
|
||||
fast_mode=True,
|
||||
)
|
||||
map_path = map_generator.create_sampling_point_map(
|
||||
hyperspectral_path=hyperspectral_path,
|
||||
csv_path=csv_path,
|
||||
point_color="red",
|
||||
point_size=100,
|
||||
point_alpha=0.9,
|
||||
show_north_arrow=True,
|
||||
show_scale_bar=True,
|
||||
show_legend=True,
|
||||
downsample=True,
|
||||
dpi=180,
|
||||
)
|
||||
parts.append(f"采样点图: {Path(map_path).name}")
|
||||
else:
|
||||
parts.append("采样点图: 跳过(无CSV)")
|
||||
else:
|
||||
parts.append("采样点图: 跳过(无影像)")
|
||||
self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})
|
||||
else:
|
||||
self.failed.emit(f"未知可视化任务: {self.task}")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
93
src/gui/crash_dump.txt
Normal file
93
src/gui/crash_dump.txt
Normal file
@ -0,0 +1,93 @@
|
||||
|
||||
============================================================
|
||||
[2026-05-12 11:14:51]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 130, in <module>
|
||||
from src.gui.panels.step9_panel import Step9Panel
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\panels\step9_panel.py", line 24, in <module>
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py", line 45, in <module>
|
||||
from src.preprocessing.process_water_quality_data import process_water_quality_data
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\preprocessing\process_water_quality_data.py", line 9, in <module>
|
||||
from scipy import stats
|
||||
File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\__init__.py", line 143, in __getattr__
|
||||
return _importlib.import_module(f'scipy.{name}')
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\importlib\__init__.py", line 90, in import_module
|
||||
return _bootstrap._gcd_import(name[level:], package, level)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\__init__.py", line 632, in <module>
|
||||
from ._multicomp import *
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_multicomp.py", line 11, in <module>
|
||||
from scipy.stats._qmc import check_random_state
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_qmc.py", line 26, in <module>
|
||||
from scipy.sparse.csgraph import minimum_spanning_tree
|
||||
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\sparse\csgraph\__init__.py", line 188, in <module>
|
||||
from ._shortest_path import (
|
||||
File "scipy/sparse/csgraph/_shortest_path.pyx", line 21, in init scipy.sparse.csgraph._shortest_path
|
||||
File "<frozen importlib._bootstrap>", line 1349, in _find_and_load
|
||||
KeyboardInterrupt
|
||||
|
||||
============================================================
|
||||
[2026-05-12 11:57:28]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3093, in main
|
||||
_dialog.exec_()
|
||||
KeyboardInterrupt
|
||||
|
||||
============================================================
|
||||
[2026-05-28 15:45:11]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
|
||||
window = WaterQualityGUI()
|
||||
^^^^^^^^^^^^^^^^^
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
|
||||
self.init_ui()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
|
||||
self.create_content_area()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
|
||||
self.step2_panel = Step2Panel()
|
||||
^^^^^^^^^^^^
|
||||
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
|
||||
|
||||
============================================================
|
||||
[2026-05-28 15:45:19]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
|
||||
window = WaterQualityGUI()
|
||||
^^^^^^^^^^^^^^^^^
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
|
||||
self.init_ui()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
|
||||
self.create_content_area()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
|
||||
self.step2_panel = Step2Panel()
|
||||
^^^^^^^^^^^^
|
||||
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
|
||||
|
||||
============================================================
|
||||
[2026-05-28 16:00:53]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2149, in on_step_changed
|
||||
self.auto_populate_step_inputs(item_data)
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2362, in auto_populate_step_inputs
|
||||
if step_id not in self.step_dependencies:
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
AttributeError: 'WaterQualityGUI' object has no attribute 'step_dependencies'. Did you mean: '_init_step_dependencies'?
|
||||
|
||||
============================================================
|
||||
[2026-06-03 13:56:59]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3354, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3331, in main
|
||||
sys.exit(app.exec_())
|
||||
^^^^^^^^^^^
|
||||
KeyboardInterrupt
|
||||
@ -325,7 +325,7 @@ class Step3Panel(QWidget):
|
||||
}
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask'] = water_mask_path
|
||||
config['water_mask_path'] = water_mask_path
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
@ -366,8 +366,8 @@ class Step3Panel(QWidget):
|
||||
"""设置配置"""
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'water_mask' in config:
|
||||
self.water_mask_file.set_path(config['water_mask'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'reference_csv' in config:
|
||||
|
||||
@ -187,7 +187,7 @@ class Step5_5Panel(QWidget):
|
||||
def get_config(self):
|
||||
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
|
||||
return {
|
||||
'training_spectra_path': self.training_data_widget.get_path(),
|
||||
'training_csv_path': self.training_data_widget.get_path(),
|
||||
'formula_csv_file': self.builtin_formula_path,
|
||||
'formula_names': selected,
|
||||
'output_file': self.output_file_widget.get_path(),
|
||||
@ -195,7 +195,7 @@ class Step5_5Panel(QWidget):
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'training_spectra_path' in config: self.training_data_widget.set_path(config['training_spectra_path'])
|
||||
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
|
||||
if 'formula_names' in config:
|
||||
sel = set(config['formula_names'])
|
||||
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
|
||||
@ -217,7 +217,7 @@ class Step5_5Panel(QWidget):
|
||||
|
||||
def run_step(self):
|
||||
config = self.get_config()
|
||||
if not config['training_spectra_path']:
|
||||
if not config['training_csv_path']:
|
||||
QMessageBox.warning(self, "提示", "请先选择输入数据")
|
||||
return
|
||||
parent = self.parent()
|
||||
|
||||
@ -124,7 +124,7 @@ class Step5Panel(QWidget):
|
||||
glint_mask_path = self.glint_mask_file.get_path()
|
||||
if glint_mask_path:
|
||||
config['glint_mask_path'] = glint_mask_path
|
||||
# 注意:step5_extract_training_spectra 不接受 output_path / training_spectra_path
|
||||
# 注意:step5_extract_training_spectra 不接受 output_path / training_csv_path
|
||||
# 参数,输出路径由 pipeline 内部根据 training_spectra_dir 自动生成。
|
||||
return config
|
||||
|
||||
|
||||
@ -363,7 +363,7 @@ class Step6Panel(QWidget):
|
||||
# 回退:从 Step5 的 config 字典中查找可能的键名
|
||||
step5_cfg = main_window.step5_panel.get_config()
|
||||
step5_csv = (
|
||||
step5_cfg.get('training_spectra_path')
|
||||
step5_cfg.get('training_csv_path')
|
||||
or step5_cfg.get('output_file')
|
||||
or step5_cfg.get('csv_path')
|
||||
or step5_cfg.get('output_csv')
|
||||
|
||||
BIN
src/gui/scaler_params.pkl
Normal file
BIN
src/gui/scaler_params.pkl
Normal file
Binary file not shown.
@ -1432,7 +1432,7 @@ class WaterQualityGUI(QMainWindow):
|
||||
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜
|
||||
},
|
||||
'step5_5': {
|
||||
'training_spectra_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
|
||||
'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
|
||||
},
|
||||
'step6': {
|
||||
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据
|
||||
|
||||
Reference in New Issue
Block a user