diff --git a/check_lines.py b/check_lines.py
new file mode 100644
index 0000000..22c4dac
--- /dev/null
+++ b/check_lines.py
@@ -0,0 +1,6 @@
+import sys
+with open(sys.argv[1] if len(sys.argv) > 1 else r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py', 'rb') as f:  # argv[1] overrides the developer-local default path
+    content = f.read()  # binary read: raw bytes, no newline translation
+lines = content.split(b'\r\n')  # CRLF-only split: a lone \n or \r stays inside its "line" and is visible in the repr below
+for i, line in enumerate(lines[2918:2955], start=2919):  # 1-based file lines 2919..2955
+    sys.stdout.buffer.write(f'{i}: {line[:120]!r}'.encode('utf-8') + b'\n')  # first 120 bytes per line, written as bytes to stdout.buffer
\ No newline at end of file
diff --git a/src/core/pipeline/__init__.py b/src/core/pipeline/__init__.py
index 3a3d22c..00ba2ce 100644
--- a/src/core/pipeline/__init__.py
+++ b/src/core/pipeline/__init__.py
@@ -8,7 +8,17 @@ Pipeline 调度核心:基于 Context 的内存级依赖注入。
- 不绑定具体 Pipeline 实现(duck-typed),WorkerThread / Web API / 单测可共用
"""
-from .context import PipelineContext
-from .runner import StepSpec, PIPELINE_STEPS, PipelineRunner
+from .context import (
+ PipelineContext,
+ STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD,
+ resolve_step_id, ALL_STEP_IDS,
+)
+from .runner import (
+ StepSpec, PIPELINE_STEPS, PipelineRunner, PipelineHalt,
+)
-__all__ = ["PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner"]
+__all__ = [
+ "PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner", "PipelineHalt",
+ "STEP_MAP_OLD_TO_NEW", "STEP_MAP_NEW_TO_OLD",
+ "resolve_step_id", "ALL_STEP_IDS",
+]
diff --git a/src/core/pipeline/context.py b/src/core/pipeline/context.py
index b31a0ea..d6723cd 100644
--- a/src/core/pipeline/context.py
+++ b/src/core/pipeline/context.py
@@ -12,7 +12,34 @@ PipelineContext:内存级数据载体,跨 14 个 step 传递路径与元信
from __future__ import annotations
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
+
+
+# ============================================================
+# Step-name mapping (defined in this leaf module to break a circular dependency)
+# ============================================================
+
+STEP_MAP_OLD_TO_NEW: Dict[str, str] = {
+ "step5_5": "step8",  # NOTE(review): "step8" is also a legacy key below, so resolving twice yields "step11_ml"
+ "step6_5": "step8_non_empirical_modeling",
+ "step6_75": "step9",  # NOTE(review): "step9" is also a legacy key below, so resolving twice yields "step14"
+ "step8_5": "step11",
+ "step8_75": "step12",
+ "step7": "step10",
+ "step8": "step11_ml",
+ "step9": "step14",
+}
+
+STEP_MAP_NEW_TO_OLD: Dict[str, str] = {v: k for k, v in STEP_MAP_OLD_TO_NEW.items()}  # inverse lookup: new id -> legacy id
+
+ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW.keys()) | set(STEP_MAP_OLD_TO_NEW.values())  # every id that appears as a legacy or a new name
+
+
+def resolve_step_id(step_id: str) -> str:
+    """Map a legacy step id to its new-format name.
+
+    Ids that are not keys of STEP_MAP_OLD_TO_NEW are returned unchanged."""
+    return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)
@dataclass
@@ -63,10 +90,29 @@ class PipelineContext:
pipeline_end_time: Optional[float] = None
last_error: Optional[str] = None
+ # ── 错误汇总(全流程结束后可用) ──
+ error_summary: List[tuple[str, str]] = field(default_factory=list)
+ # ── 出错时立即停止全流程(默认 False:继续后续步骤) ──
+ breakpoint_on_error: bool = False
+ # ── ★ 智能补全锁定步骤列表(由 _auto_fill_missing_steps 自动开启的步骤) ──
+ # GUI 层读取此字段,在运行期间禁用对应面板的启用复选框
+ locked_steps: List[str] = field(default_factory=list)
+
# ============================================================
# 读写辅助
# ============================================================
+ def step_id(self, step_id: str) -> str:
+     """Convert a possibly-legacy step id to its new-format name.
+
+     Delegates to the module-level ``resolve_step_id`` so there is one mapping rule.
+
+     Example:
+         ctx.status[ctx.step_id('step6_5')]   # 'step8_non_empirical_modeling'
+         ctx.user_config[ctx.step_id('step8_5')]  # 'step11'
+     """
+     return resolve_step_id(step_id)
+
def set(self, key: str, value: Any) -> None:
"""原地写入任意属性。
diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py
index ac5d639..d6f4102 100644
--- a/src/core/water_quality_inversion_pipeline_GUI.py
+++ b/src/core/water_quality_inversion_pipeline_GUI.py
@@ -660,7 +660,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"训练光谱数据已保存: {result}")
return result
- def step5_5_calculate_water_quality_indices(self,
+ def step8_water_quality_indices(self,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
@@ -704,7 +704,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"水质指数已保存: {result}")
return result
- def step6_train_models(self, feature_start_column: str = "374.285004",
+ def step7_ml_modeling(self, feature_start_column: str = "374.285004",
preprocessing_methods: List[str] = None,
model_names: List[str] = None,
split_methods: List[str] = None,
@@ -747,7 +747,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"模型训练完成,结果保存在: {result}")
return result
- def step7_generate_sampling_points(self, deglint_img_path: Optional[str] = None,
+ def step10_sampling(self, deglint_img_path: Optional[str] = None,
interval: int = 50,
sample_radius: int = 5,
chunk_size: int = 1000,
@@ -756,7 +756,7 @@ class WaterQualityInversionPipeline:
use_adaptive_sampling: bool = True,
skip_dependency_check: bool = False, **kwargs) -> str:
"""
- 步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
+ 步骤10: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
Args:
deglint_img_path: 去除耀斑后的影像文件路径(如果为None,使用步骤3的结果)
@@ -779,7 +779,7 @@ class WaterQualityInversionPipeline:
if water_mask_path is None and self.water_mask_path is not None:
water_mask_path = self.water_mask_path
- self._notify("started", "步骤7: 生成预测采样点")
+ self._notify("started", "步骤10: 生成预测采样点")
result = PredictionStep.generate_sampling_points(
deglint_img_path=img_path,
interval=interval,
@@ -790,21 +790,21 @@ class WaterQualityInversionPipeline:
output_dir=str(self.sampling_dir),
use_adaptive_sampling=use_adaptive_sampling,
)
- self._record_step_time("步骤7: 生成预测采样点", 0, 0)
+ self._record_step_time("步骤10: 生成预测采样点", 0, 0)
self._notify("completed", f"采样点光谱数据已保存: {result}")
return result
- def step8_predict_water_quality(self, sampling_csv_path: str,
+ def step11_ml_prediction(self, sampling_csv_path: str,
models_dir: Optional[str] = None,
metric: str = 'test_r2',
prediction_column: str = 'prediction',
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
- 步骤8: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数
-
+ 步骤11: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数
+
Args:
sampling_csv_path: 采样点光谱数据CSV路径
- models_dir: 模型保存目录(如果为None,使用步骤6的结果)
+ models_dir: 模型保存目录(如果为None,使用步骤7的结果)
metric: 选择最佳模型的指标
prediction_column: 预测结果列名
@@ -818,7 +818,7 @@ class WaterQualityInversionPipeline:
print(f"[Pipeline] 收到字典: {'Yes' if _external_models_dict else 'No'}"
f", 收到单模型: {'Yes' if _external_model else 'No'}")
- self._notify("started", "步骤8: 预测水质参数")
+ self._notify("started", "步骤11: 预测水质参数")
result = PredictionStep.predict_water_quality(
sampling_csv_path=sampling_csv_path,
models_dir=models_dir if models_dir else str(self.models_dir),
@@ -831,11 +831,11 @@ class WaterQualityInversionPipeline:
_external_models_dict=_external_models_dict,
_external_model_dir=_external_model_dir,
)
- self._record_step_time("步骤8: 预测水质参数", 0, 0)
+ self._record_step_time("步骤11: 预测水质参数", 0, 0)
self._notify("completed", f"预测完成,结果保存在: {self.prediction_dir}")
return result
- def step9_generate_distribution_map(self, prediction_csv_path: str,
+ def step14_distribution_map(self, prediction_csv_path: str,
boundary_shp_path: str,
output_image_path: Optional[str] = None,
resolution: float = 30,
@@ -1524,99 +1524,99 @@ class WaterQualityInversionPipeline:
else:
self._notify("步骤5: 光谱提取", "skipped", "未配置")
- # 步骤5.5: 计算水质指数
- if 'step5_5' in config:
- self._notify("步骤5.5: 水质指数计算", "start")
- self.step5_5_calculate_water_quality_indices(**config['step5_5'])
- self._notify("步骤5.5: 水质指数计算", "completed", f"(输出: {self.indices_path})")
+ # 步骤8: 计算水质指数
+ if 'step8' in config:
+ self._notify("步骤8: 水质指数计算", "start")
+ self.step8_water_quality_indices(**config['step8'])
+ self._notify("步骤8: 水质指数计算", "completed", f"(输出: {self.indices_path})")
else:
- self._notify("步骤5.5: 水质指数计算", "skipped", "未配置")
+ self._notify("步骤8: 水质指数计算", "skipped", "未配置")
- # 步骤6: 训练模型
- if 'step6' in config:
- self._notify("步骤6: 模型训练", "start")
- self.step6_train_models(**config['step6'])
- self._notify("步骤6: 模型训练", "completed", f"(输出: {self.models_dir})")
- else:
- self._notify("步骤6: 模型训练", "skipped", "未配置")
-
- # 步骤6.5: 非经验统计回归模型训练
- if 'step6_5' in config:
- self._notify("步骤6.5: 非经验模型训练", "start")
- self.step6_5_non_empirical_modeling(**config['step6_5'])
- self._notify("步骤6.5: 非经验模型训练", "completed", f"(输出: {self.models_dir})")
- else:
- self._notify("步骤6.5: 非经验模型训练", "skipped", "未配置")
-
- # 步骤6.75: 自定义回归分析
- if 'step6_75' in config:
- self._notify("步骤6.75: 自定义回归", "start")
- self.step6_75_custom_regression(**config['step6_75'])
- self._notify("步骤6.75: 自定义回归", "completed", f"(输出: {self.custom_regression_path})")
- else:
- self._notify("步骤6.75: 自定义回归", "skipped", "未配置")
-
- # 步骤7: 生成预测采样点
+ # 步骤7: 训练模型
if 'step7' in config:
- self._notify("步骤7: 采样点生成", "start")
- sampling_csv_path = self.step7_generate_sampling_points(**config['step7'])
- self._notify("步骤7: 采样点生成", "completed", f"(输出: {sampling_csv_path})")
+ self._notify("步骤7: 模型训练", "start")
+ self.step7_ml_modeling(**config['step7'])
+ self._notify("步骤7: 模型训练", "completed", f"(输出: {self.models_dir})")
+ else:
+ self._notify("步骤7: 模型训练", "skipped", "未配置")
+
+ # 步骤8_non_empirical_modeling: 非经验统计回归模型训练
+ if 'step8_non_empirical_modeling' in config:
+ self._notify("步骤8: 非经验模型训练", "start")
+ self.step8_non_empirical_modeling(**config['step8_non_empirical_modeling'])
+ self._notify("步骤8: 非经验模型训练", "completed", f"(输出: {self.models_dir})")
+ else:
+ self._notify("步骤8: 非经验模型训练", "skipped", "未配置")
+
+ # 步骤9: 自定义回归分析
+ if 'step9' in config:
+ self._notify("步骤9: 自定义回归", "start")
+ self.step9_custom_regression(**config['step9'])
+ self._notify("步骤9: 自定义回归", "completed", f"(输出: {self.custom_regression_path})")
+ else:
+ self._notify("步骤9: 自定义回归", "skipped", "未配置")
+
+ # 步骤10: 生成预测采样点
+ if 'step10' in config:
+ self._notify("步骤10: 采样点生成", "start")
+ sampling_csv_path = self.step10_sampling(**config['step10'])
+ self._notify("步骤10: 采样点生成", "completed", f"(输出: {sampling_csv_path})")
else:
sampling_csv_path = None
- self._notify("步骤7: 采样点生成", "skipped", "未配置")
+ self._notify("步骤10: 采样点生成", "skipped", "未配置")
- # 步骤8: 预测水质参数
- if 'step8' in config and sampling_csv_path:
- self._notify("步骤8: 参数预测", "start")
- step8_config = config['step8'].copy()
- step8_config['sampling_csv_path'] = sampling_csv_path
- prediction_files = self.step8_predict_water_quality(**step8_config)
- self._notify("步骤8: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)")
+ # 步骤11_ml: 预测水质参数
+ if 'step11_ml' in config and sampling_csv_path:
+ self._notify("步骤11: 参数预测", "start")
+ step11_ml_config = config['step11_ml'].copy()
+ step11_ml_config['sampling_csv_path'] = sampling_csv_path
+ prediction_files = self.step11_ml_prediction(**step11_ml_config)
+ self._notify("步骤11: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)")
else:
prediction_files = {}
- self._notify("步骤8: 参数预测", "skipped", "未配置或缺少采样点")
+ self._notify("步骤11: 参数预测", "skipped", "未配置或缺少采样点")
- # 步骤8.5: 使用非经验模型进行参数预测
+ # 步骤11: 使用非经验模型进行参数预测
non_empirical_prediction_files = {}
- if 'step8_5' in config and sampling_csv_path:
- self._notify("步骤8.5: 非经验模型预测", "start")
- step8_5_config = config['step8_5'].copy()
- step8_5_config['sampling_csv_path'] = sampling_csv_path
- non_empirical_prediction_files = self.step8_5_predict_with_non_empirical_models(**step8_5_config)
- self._notify("步骤8.5: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)")
+ if 'step11' in config and sampling_csv_path:
+ self._notify("步骤11: 非经验模型预测", "start")
+ step11_config = config['step11'].copy()
+ step11_config['sampling_csv_path'] = sampling_csv_path
+ non_empirical_prediction_files = self.step11_non_empirical_prediction(**step11_config)
+ self._notify("步骤11: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)")
else:
- self._notify("步骤8.5: 非经验模型预测", "skipped", "未配置或缺少采样点")
+ self._notify("步骤11: 非经验模型预测", "skipped", "未配置或缺少采样点")
- # 步骤8.75: 使用自定义回归模型进行参数预测
+ # 步骤12: 使用自定义回归模型进行参数预测
custom_regression_prediction_files = {}
- if 'step8_75' in config and sampling_csv_path:
- self._notify("步骤8.75: 自定义回归预测", "start")
- step8_75_config = config['step8_75'].copy()
- step8_75_config['sampling_csv_path'] = sampling_csv_path
- custom_regression_prediction_files = self.step8_75_predict_with_custom_regression(**step8_75_config)
- self._notify("步骤8.75: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)")
+ if 'step12' in config and sampling_csv_path:
+ self._notify("步骤12: 自定义回归预测", "start")
+ step12_config = config['step12'].copy()
+ step12_config['sampling_csv_path'] = sampling_csv_path
+ custom_regression_prediction_files = self.step12_custom_regression_prediction(**step12_config)
+ self._notify("步骤12: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)")
else:
- self._notify("步骤8.75: 自定义回归预测", "skipped", "未配置或缺少采样点")
+ self._notify("步骤12: 自定义回归预测", "skipped", "未配置或缺少采样点")
# 合并机器学习预测、非经验模型预测和自定义回归预测结果
all_prediction_files = {**prediction_files, **non_empirical_prediction_files, **custom_regression_prediction_files}
- # 步骤9: 生成分布图
+ # 步骤14: 生成分布图
distribution_maps = {}
- if 'step9' in config and all_prediction_files:
- self._notify("步骤9: 分布图生成", "start")
+ if 'step14' in config and all_prediction_files:
+ self._notify("步骤14: 分布图生成", "start")
for target_name, pred_file in all_prediction_files.items():
- step9_config = config['step9'].copy()
+ step14_config = config['step14'].copy()
for _k in ('step9_batch_mode', 'prediction_csv_dir', 'recursive_csv_scan'):
- step9_config.pop(_k, None)
- step9_config['prediction_csv_path'] = pred_file
- if 'output_image_path' not in step9_config:
- step9_config['output_image_path'] = None
- dist_map_path = self.step9_generate_distribution_map(**step9_config)
+ step14_config.pop(_k, None)
+ step14_config['prediction_csv_path'] = pred_file
+ if 'output_image_path' not in step14_config:
+ step14_config['output_image_path'] = None
+ dist_map_path = self.step14_distribution_map(**step14_config)
distribution_maps[target_name] = dist_map_path
- self._notify("步骤9: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)")
+ self._notify("步骤14: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)")
else:
- self._notify("步骤9: 分布图生成", "skipped", "未配置或缺少预测结果")
+ self._notify("步骤14: 分布图生成", "skipped", "未配置或缺少预测结果")
# 生成可视化图表
output_files = {}
@@ -1716,10 +1716,10 @@ class WaterQualityInversionPipeline:
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
- pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
- pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
- pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
- pipeline_info['training_params'] = config.get('step6', {})
+ pipeline_info['step8'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
+ pipeline_info['step7'] = {'status': 'completed', 'output_file': str(self.models_dir)}
+ pipeline_info['step9'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
+ pipeline_info['training_params'] = config.get('step7', {})
summary_path = self.report_generator.generate_batch_inference_summary(pipeline_info)
print(f"批量处理摘要已生成: {summary_path}")
@@ -1769,7 +1769,7 @@ class WaterQualityInversionPipeline:
traceback.print_exc()
raise
- def step6_5_non_empirical_modeling(self, csv_path: Optional[str] = None,
+ def step8_non_empirical_modeling(self, csv_path: Optional[str] = None,
preprocessing_methods: List[str] = None,
algorithms: List[str] = None,
value_cols: Union[int, Dict[str, int]] = 0,
@@ -1819,7 +1819,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"非经验模型训练完成")
return result
- def step6_75_custom_regression(self,
+ def step9_custom_regression(self,
csv_path: Optional[str] = None,
x_columns: Optional[Union[str, List[str]]] = None,
y_columns: Optional[Union[str, List[str]]] = None,
@@ -1999,7 +1999,7 @@ class WaterQualityInversionPipeline:
return summary_path
- def step8_5_predict_with_non_empirical_models(self, sampling_csv_path: str,
+ def step11_non_empirical_prediction(self, sampling_csv_path: str,
non_empirical_models_dir: Optional[str] = None,
output_path: Optional[str] = None,
metric: str = 'Average Accuracy(%)',
@@ -2007,13 +2007,13 @@ class WaterQualityInversionPipeline:
enabled: bool = True,
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
- 步骤8.5: 使用非经验统计回归模型进行参数预测
+ 步骤11: 使用非经验统计回归模型进行参数预测
根据非经验模型训练结果汇总CSV筛选给定方法的准确率最高的模型,使用该模型进行预测
Args:
sampling_csv_path: 采样点光谱数据CSV路径
- non_empirical_models_dir: 非经验模型保存目录(如果为None,使用步骤6.5的结果)
+ non_empirical_models_dir: 非经验模型保存目录(如果为None,使用步骤8的结果)
output_path: 输出目录路径(如果为None,使用默认目录)
metric: 选择最佳模型的指标(默认使用平均准确率)
prediction_column: 预测结果列名
@@ -2021,7 +2021,7 @@ class WaterQualityInversionPipeline:
Returns:
预测结果文件路径字典(键为算法名)
"""
- self._notify("started", "步骤8.5: 使用非经验模型进行参数预测")
+ self._notify("started", "步骤11: 使用非经验模型进行参数预测")
result = PredictionStep.predict_with_non_empirical_models(
sampling_csv_path=sampling_csv_path,
non_empirical_models_dir=non_empirical_models_dir,
@@ -2031,11 +2031,11 @@ class WaterQualityInversionPipeline:
enabled=enabled,
work_dir=str(self.work_dir),
)
- self._record_step_time("步骤8.5: 非经验模型预测", 0, 0)
+ self._record_step_time("步骤11: 非经验模型预测", 0, 0)
self._notify("completed", f"非经验模型预测完成,结果保存在: {self.prediction_dir}")
return result
- def step8_75_predict_with_custom_regression(self, sampling_csv_path: str,
+ def step12_custom_regression_prediction(self, sampling_csv_path: str,
custom_regression_dir: Optional[str] = None,
formula_csv_path: Optional[str] = None,
coordinate_columns: Optional[List[str]] = None,
@@ -2044,13 +2044,13 @@ class WaterQualityInversionPipeline:
enabled: bool = True,
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
- 步骤8.75: 使用自定义回归模型进行参数预测
-
+ 步骤12: 使用自定义回归模型进行参数预测
+
使用新的CustomRegressionPredictor模块,基于9_Custom_Regression_Modeling文件夹中的CSV,
根据r_squared选择最佳模型,批量预测水质参数
-
+
Args:
- sampling_csv_path: 采样点光谱数据CSV路径(来自步骤7)
+ sampling_csv_path: 采样点光谱数据CSV路径(来自步骤10)
custom_regression_dir: 自定义回归模型目录(9_Custom_Regression_Modeling)
formula_csv_path: 公式CSV文件路径,用于查找index_formula
coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']或自动识别
@@ -2058,11 +2058,11 @@ class WaterQualityInversionPipeline:
filename_prefix: 输出文件名前缀
enabled: 是否启用
skip_dependency_check: 是否跳过依赖检查
-
+
Returns:
预测结果文件路径字典(键为参数名)
"""
- self._notify("started", "步骤8.75: 使用自定义回归模型进行参数预测")
+ self._notify("started", "步骤12: 使用自定义回归模型进行参数预测")
result = PredictionStep.predict_with_custom_regression(
sampling_csv_path=sampling_csv_path,
custom_regression_dir=custom_regression_dir,
@@ -2073,7 +2073,7 @@ class WaterQualityInversionPipeline:
enabled=enabled,
work_dir=str(self.work_dir),
)
- self._record_step_time("步骤8.75: 自定义回归模型预测", 0, 0)
+ self._record_step_time("步骤12: 自定义回归模型预测", 0, 0)
self._notify("completed", f"自定义回归预测完成")
return result
@@ -2161,20 +2161,20 @@ def main():
# 单步运行时建议显式指定;完整流程中可省略,将使用步骤2输出的耀斑掩膜
# 'glint_mask_path': r"path/to/severe_glint_area.dat",
},
- 'step5_5': {
+ 'step8': {
'formula_csv_file': 'path/to/water_quality_formulas.csv', # 公式CSV文件路径
'formula_names': ['Al10SABI', 'TurbBe16RedOverViolet'], # 要计算的公式名称列表
'output_filename': 'water_quality_indices.csv',
'enabled': True # 是否启用水质指数计算
},
- 'step6': {
+ 'step7': {
'feature_start_column': '374.285004',
'preprocessing_methods': ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'],
'model_names': ['SVR', 'RF', 'Ridge', 'Lasso'],
'split_methods': ['spxy', 'ks', 'random'],
'cv_folds': 3
},
- 'step6_5': {
+ 'step8_non_empirical_modeling': {
'preprocessing_methods': ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'],
'algorithms': ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss'],
'value_cols': 0, # 可以是单个整数或字典,如 {'chl_a': 0, 'nh3': 1, 'mno4': 2, 'tn': 3, 'tp': 4, 'tss': 5}
@@ -2182,14 +2182,14 @@ def main():
'window': 5,
'enabled': True # 是否启用非经验模型训练
},
- 'step6_75': {
+ 'step9': {
'x_columns': ['NDWI', 'NDVI'], # 自变量列名列表
'y_columns': ['chl_a', 'tn', 'tp'], # 因变量列名列表
'methods': 'all', # 回归方法
'output_dir': 'custom_regression_results', # 输出目录
'enabled': True # 是否启用自定义回归分析
},
- 'step7': {
+ 'step10': {
'interval': 50,
'sample_radius': 5,
'chunk_size': 1000,
@@ -2197,16 +2197,16 @@ def main():
# 可选:耀斑掩膜文件(dat),若不提供将使用步骤2结果;需要外部指定时取消注释
# 'glint_mask_path': r"D:\path\to\severe_glint_area.dat",
},
- 'step8': {
+ 'step11_ml': {
'metric': 'test_r2',
'prediction_column': 'prediction'
},
- 'step8_5': {
+ 'step11': {
'metric': 'Average Accuracy(%)', # 选择最佳模型的指标
'prediction_column': 'prediction',
'enabled': True # 是否启用非经验模型预测
},
- 'step8_75': {
+ 'step12': {
'custom_regression_dir': None, # 自定义回归模型目录(None表示使用9_Custom_Regression_Modeling)
'formula_csv_path': None, # 公式CSV文件路径,用于查找index_formula(如water_quality_formulas.csv)
'coordinate_columns': None, # 坐标列名(None表示自动识别)
@@ -2214,7 +2214,7 @@ def main():
'filename_prefix': 'custom_regression_prediction', # 输出文件名前缀
'enabled': True # 是否启用自定义回归预测
},
- 'step9': {
+ 'step14': {
'boundary_shp_path': r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp" ,
'resolution': 30,
'input_crs': 'EPSG:32651',
@@ -2345,41 +2345,41 @@ def example_independent_steps():
except Exception as e:
print(f"步骤6失败: {e}")
- # 示例6: 独立运行步骤7 - 采样点生成
- print("\n示例6: 独立运行步骤7 - 采样点生成")
+ # 示例6: 独立运行步骤10 - 采样点生成
+ print("\n示例6: 独立运行步骤10 - 采样点生成")
try:
- sampling_csv = pipeline.step7_generate_sampling_points(
+ sampling_csv = pipeline.step10_sampling(
deglint_img_path="path/to/deglint_image.bsq",
water_mask_path="path/to/water_mask.dat",
skip_dependency_check=True
)
print(f"采样点数据: {sampling_csv}")
except Exception as e:
- print(f"步骤7失败: {e}")
+ print(f"步骤10失败: {e}")
- # 示例7: 独立运行步骤8 - 水质预测
- print("\n示例7: 独立运行步骤8 - 水质预测")
+ # 示例7: 独立运行步骤11 - 水质预测
+ print("\n示例7: 独立运行步骤11 - 水质预测")
try:
- predictions = pipeline.step8_predict_water_quality(
+ predictions = pipeline.step11_ml_prediction(
sampling_csv_path="path/to/sampling_spectra.csv",
models_dir="path/to/models_directory",
skip_dependency_check=True
)
print(f"预测结果: {predictions}")
except Exception as e:
- print(f"步骤8失败: {e}")
+ print(f"步骤11失败: {e}")
- # 示例8: 独立运行步骤9 - 分布图生成
- print("\n示例8: 独立运行步骤9 - 分布图生成")
+ # 示例8: 独立运行步骤14 - 分布图生成
+ print("\n示例8: 独立运行步骤14 - 分布图生成")
try:
- distribution_map = pipeline.step9_generate_distribution_map(
+ distribution_map = pipeline.step14_distribution_map(
prediction_csv_path="path/to/prediction_results.csv",
boundary_shp_path="path/to/boundary.shp",
skip_dependency_check=True
)
print(f"分布图: {distribution_map}")
except Exception as e:
- print(f"步骤9失败: {e}")
+ print(f"步骤14失败: {e}")
print("\n" + "="*80)
print("独立步骤运行示例完成")
diff --git a/src/gui/core/pipeline_mode_dialog.py b/src/gui/core/pipeline_mode_dialog.py
new file mode 100644
index 0000000..8fa1a7a
--- /dev/null
+++ b/src/gui/core/pipeline_mode_dialog.py
@@ -0,0 +1,237 @@
+# -*- coding: utf-8 -*-
+"""
+PipelineModeDialog:全流程运行前的模式选择弹窗。
+
+用户点击"运行完整流程"后,首先弹出此弹窗选择执行模式:
+ - 选项 A(训练新模型并预测):执行完整建模与预测流程,需要实测水质 CSV
+ - 选项 B(使用已有模型直接预测):跳过训练步骤,直接使用外部模型目录进行预测
+
+弹窗结果:
+ - QDialog.Accepted + self.selected_mode = "training" 或 "prediction_only"
+ - QDialog.Rejected → 调用方中止 run_full_pipeline
+"""
+
+import os
+from typing import Optional
+
+from PyQt5.QtCore import Qt
+from PyQt5.QtGui import QFont
+from PyQt5.QtWidgets import (
+ QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
+ QRadioButton, QGroupBox, QButtonGroup, QMessageBox, QSizePolicy,
+)
+
+
+def _is_valid_model_dir(path: str) -> bool:
+    """Tell whether *path* is a directory holding at least one file.
+
+    The whole tree below *path* is searched, so a file at any depth counts;
+    an empty or missing path, or a tree of empty folders, gives False.
+    """
+    is_dir = bool(path) and os.path.isdir(path)
+    return is_dir and any(names for _, _, names in os.walk(path))
+
+
+class PipelineModeDialog(QDialog):
+ """Modal dialog for choosing how the full pipeline should run.
+
+ Two radio buttons cover the two business scenarios:
+ - A: train new models (full pipeline, needs the step4 CSV)
+ - B: prediction only (skips step4/5/7/8 and uses an external model directory)
+
+ Attributes:
+ selected_mode: "training" | "prediction_only" (None until confirmed)
+ """
+
+ def __init__(self, main_window=None, parent=None):
+ super().__init__(parent)
+ self.main_window = main_window  # only read via get_current_config() in _update_models_hint
+ self.selected_mode: Optional[str] = None  # filled in by _on_confirm
+ self.setWindowTitle("选择运行模式")
+ self.setMinimumSize(560, 340)
+ self.setModal(True)
+ self._setup_ui()
+
+ # ------------------------------------------------------------------
+ # UI construction
+ # ------------------------------------------------------------------
+
+ def _setup_ui(self):
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(28, 24, 28, 20)
+ layout.setSpacing(14)
+
+ # ── Title ──
+ title = QLabel("请选择全流程运行模式")
+ title_font = QFont()
+ title_font.setPointSize(13)
+ title_font.setBold(True)
+ title.setFont(title_font)
+ title.setAlignment(Qt.AlignCenter)
+ layout.addWidget(title)
+
+ layout.addSpacing(4)
+
+ # ── Option A: train new models ──
+ group_a = QGroupBox()
+ group_a.setObjectName("groupA")
+ group_a.setMinimumHeight(100)
+ layout.addWidget(group_a)
+
+ self.radio_a = QRadioButton("【训练新模型并预测】")
+ self.radio_a.setChecked(True) # option A is the default
+ self.radio_a.setObjectName("radioTraining")
+
+ desc_a = QLabel(
+ "需要提供实测水质数据 (CSV),将执行完整建模与预测流程。\n"
+ "包括:水域掩膜 → 耀斑去除 → 光谱特征提取 → 模型训练 → 密集采样 → 预测 → 专题图"
+ )
+ desc_a.setWordWrap(True)
+ desc_a.setStyleSheet("color: #555555; background: transparent;")
+ desc_a.setObjectName("descA")
+
+ vbox_a = QVBoxLayout(group_a)
+ vbox_a.setContentsMargins(16, 20, 16, 14)
+ vbox_a.setSpacing(8)
+ vbox_a.addWidget(self.radio_a)
+ vbox_a.addWidget(desc_a)
+
+ # ── Option B: prediction only ──
+ group_b = QGroupBox()
+ group_b.setObjectName("groupB")
+ group_b.setMinimumHeight(100)
+ layout.addWidget(group_b)
+
+ self.radio_b = QRadioButton("【使用已有模型直接预测】")
+ self.radio_b.setObjectName("radioPrediction")
+
+ desc_b = QLabel(
+ "跳过模型训练步骤,直接使用导入的外部模型目录进行预测。\n"
+ "前提条件:请在「监督预测」或「回归预测」面板中指定模型目录。\n"
+ "适用范围:已有预训练模型、或其他来源模型目录。"
+ )
+ desc_b.setWordWrap(True)
+ desc_b.setStyleSheet("color: #555555; background: transparent;")
+ desc_b.setObjectName("descB")
+
+ vbox_b = QVBoxLayout(group_b)
+ vbox_b.setContentsMargins(16, 20, 16, 14)
+ vbox_b.setSpacing(8)
+ vbox_b.addWidget(self.radio_b)
+ vbox_b.addWidget(desc_b)
+
+ # ── Enforce mutual exclusion via QButtonGroup (the two radios sit in separate group boxes) ──
+ self.mode_group = QButtonGroup(self)
+ self.mode_group.addButton(self.radio_a)
+ self.mode_group.addButton(self.radio_b)
+
+ # ── Hint line (shows the models_dir status dynamically) ──
+ self.models_hint = QLabel()
+ self.models_hint.setObjectName("modelsHint")
+ self.models_hint.setWordWrap(True)
+ self.models_hint.setStyleSheet("color: #888888; font-size: 11px; padding: 4px 0;")
+ layout.addWidget(self.models_hint)
+
+ # ── Force the QRadioButton indicator to render as a solid round dot ──
+ self.setStyleSheet("""
+ QRadioButton::indicator {
+ width: 14px;
+ height: 14px;
+ }
+ QRadioButton::indicator:checked {
+ background-color: #0078D7;
+ border: 2px solid #0078D7;
+ border-radius: 7px;
+ }
+ QRadioButton::indicator:unchecked {
+ background-color: white;
+ border: 2px solid #A0A0A0;
+ border-radius: 7px;
+ }
+ """)
+
+ # ── Buttons ──
+ btn_layout = QHBoxLayout()
+ btn_layout.addStretch()
+
+ cancel_btn = QPushButton("取消")
+ cancel_btn.setObjectName("cancelBtn")
+ cancel_btn.setMinimumWidth(90)
+ cancel_btn.clicked.connect(self.reject)
+
+ self.btn_confirm = QPushButton("确认")
+ self.btn_confirm.setObjectName("confirmBtn")
+ self.btn_confirm.setMinimumWidth(90)
+ self.btn_confirm.setDefault(True)
+ self.btn_confirm.clicked.connect(self._on_confirm)
+
+ btn_layout.addWidget(self.btn_confirm)
+ btn_layout.addWidget(cancel_btn)
+ layout.addLayout(btn_layout)
+
+ # Signals: re-render the hint and the confirm-button state whenever either radio toggles
+ self.radio_a.toggled.connect(self._update_models_hint)
+ self.radio_b.toggled.connect(self._update_models_hint)
+
+ # Render the initial state
+ self._update_models_hint()
+
+ def _update_models_hint(self, checked=False, *args) -> None:  # checked/*args absorb the toggled(bool) payload; unused
+ """Refresh the hint text and the confirm button's enabled state from the selected mode and models_dir."""
+ training_checked = self.radio_a.isChecked()
+
+ # Read models_dir from the main window config ("step11_ml" first, then "step11")
+ models_dir = ""
+ if self.main_window:
+ config = self.main_window.get_current_config()
+ models_dir = config.get("step11_ml", {}).get("models_dir", "")
+ if not models_dir:
+ models_dir = config.get("step11", {}).get("models_dir", "")  # NOTE(review): confirm the GUI config stores the step11 directory under "models_dir"
+
+ has_files = bool(models_dir and _is_valid_model_dir(models_dir))  # directory exists and holds a file at some depth
+ dir_exists = bool(models_dir and os.path.isdir(models_dir))  # directory exists, possibly empty
+
+ if training_checked:
+ if hasattr(self, 'btn_confirm') and self.btn_confirm is not None:
+ self.btn_confirm.setEnabled(True)  # training mode can always be confirmed
+ if has_files:
+ self.models_hint.setText(
+ f"⚠ 注意:当前模型目录已包含文件,继续训练将会【覆盖】原有模型!\n路径:{models_dir}"
+ )
+ self.models_hint.setStyleSheet("color: #e65100; font-size: 11px; padding: 4px 0;")
+ else:
+ label = f"✓ 模型将保存至该目录(当前为空,安全)。\n路径:{models_dir}" if dir_exists else "✓ 尚未指定模型目录,将使用默认路径创建新模型。"
+ self.models_hint.setText(label)
+ self.models_hint.setStyleSheet("color: #2e7d32; font-size: 11px; padding: 4px 0;")
+ else:
+ if has_files:
+ self.models_hint.setText(
+ f"✓ 已检测到有效模型目录,可以直接预测。\n路径:{models_dir}"
+ )
+ self.models_hint.setStyleSheet("color: #2e7d32; font-size: 11px; padding: 4px 0;")
+ if hasattr(self, 'btn_confirm') and self.btn_confirm is not None:
+ self.btn_confirm.setEnabled(True)
+ else:
+ if dir_exists:
+ self.models_hint.setText(
+ f"❌ 错误:模型目录为空(未找到任何文件),无法进行预测!\n路径:{models_dir}"
+ )
+ else:
+ self.models_hint.setText(
+ "❌ 错误:模型目录为空或不存在!请先返回对应面板配置有效路径。"
+ )
+ self.models_hint.setStyleSheet("color: #c62828; font-size: 11px; padding: 4px 0;")
+ if hasattr(self, 'btn_confirm') and self.btn_confirm is not None:
+ self.btn_confirm.setEnabled(False)  # prediction-only mode is blocked without a usable model directory
+
+ def _on_confirm(self) -> None:
+ """Confirm-button callback: store the chosen mode and close the dialog.
+
+ Note: the button's disabled state is already handled in _update_models_hint,
+ so this method only records the result and shows no second blocking popup.
+ """
+ if self.radio_a.isChecked():
+ self.selected_mode = "training"
+ else:
+ self.selected_mode = "prediction_only"
+ self.accept()
\ No newline at end of file
diff --git a/src/gui/dialogs.py b/src/gui/dialogs.py
index 04ecc8a..3c95a6f 100644
--- a/src/gui/dialogs.py
+++ b/src/gui/dialogs.py
@@ -349,6 +349,12 @@ class AISettingsDialog(QDialog):
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
+import matplotlib.pyplot as _plt
+
+# Global font setup: CJK-capable sans-serif fallbacks (avoids garbled Chinese labels) and a plain hyphen for minus signs.
+_plt.rcParams.update({'font.sans-serif': ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS'],
+                      'axes.unicode_minus': False})
+
class SamplingViewerDialog(QDialog):
"""交互式采样点与光谱查看器
@@ -381,8 +387,8 @@ class SamplingViewerDialog(QDialog):
self._fig = Figure(figsize=(6, 5))
self._canvas = FigureCanvasQTAgg(self._fig)
self._ax_scatter = self._fig.add_subplot(111)
- self._ax_scatter.set_xlabel("pixel_x")
- self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_xlabel("像素 X")
+ self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
self._fig.tight_layout()
@@ -392,18 +398,20 @@ class SamplingViewerDialog(QDialog):
# --- 右侧:信息面板 + 光谱子图 ---
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
+ right_layout.setContentsMargins(0, 0, 0, 0)
+ # 坐标信息面板(多行中文清晰显示)
self._info_label = QLabel("点击左侧散点图选择采样点")
self._info_label.setStyleSheet(
- "QLabel { background-color: #f0f0f0; padding: 6px; "
- "border: 1px solid #ccc; font-size: 13px; }"
+ "QLabel { background-color: #f0f0f0; padding: 8px; "
+ "border: 1px solid #ccc; border-radius: 4px; font-size: 13px; }"
)
right_layout.addWidget(self._info_label)
# 采样点列表迷你表格
self._point_table = QTableWidget()
self._point_table.setColumnCount(3)
- self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
+ self._point_table.setHorizontalHeaderLabels(["像素 X", "像素 Y", "序号"])
self._point_table.setMaximumHeight(120)
self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
@@ -413,8 +421,8 @@ class SamplingViewerDialog(QDialog):
# 光谱曲线子图
self._fig_right = Figure(figsize=(5, 3))
self._ax_spectrum = self._fig_right.add_subplot(111)
- self._ax_spectrum.set_xlabel("Band Index")
- self._ax_spectrum.set_ylabel("Reflectance")
+ self._ax_spectrum.set_xlabel("波段序号")
+ self._ax_spectrum.set_ylabel("反射率")
self._ax_spectrum.set_title("光谱曲线")
self._fig_right.tight_layout()
@@ -440,8 +448,8 @@ class SamplingViewerDialog(QDialog):
def _draw_scatter(self):
"""绘制散点图"""
self._ax_scatter.clear()
- self._ax_scatter.set_xlabel("pixel_x")
- self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_xlabel("像素 X")
+ self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
@@ -497,14 +505,14 @@ class SamplingViewerDialog(QDialog):
self._info_label.setText(
f"选中的采样点 #{nearest_idx}
"
- f"pixel_x = {pixel_x} pixel_y = {pixel_y}
"
- f"x_coord = {x_coord} y_coord = {y_coord}"
+ f"图像像素坐标: X = {pixel_x}, Y = {pixel_y}
"
+ f"地理真实坐标: 经度(X) = {x_coord}, 纬度(Y) = {y_coord}"
)
# 高亮散点图
self._ax_scatter.clear()
- self._ax_scatter.set_xlabel("pixel_x")
- self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_xlabel("像素 X")
+ self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._ax_scatter.invert_yaxis()
self._ax_scatter.scatter(
@@ -523,8 +531,8 @@ class SamplingViewerDialog(QDialog):
def _draw_spectrum(self, row: pd.Series):
"""从一行数据中提取纯波段数值并绘图"""
self._ax_spectrum.clear()
- self._ax_spectrum.set_xlabel("Band Index")
- self._ax_spectrum.set_ylabel("Reflectance")
+ self._ax_spectrum.set_xlabel("波段序号")
+ self._ax_spectrum.set_ylabel("反射率")
self._ax_spectrum.set_title("光谱曲线")
exclude_patterns = (
diff --git a/src/gui/panels/step10_panel.py b/src/gui/panels/step10_panel.py
new file mode 100644
index 0000000..5e6150d
--- /dev/null
+++ b/src/gui/panels/step10_panel.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Step10 面板 - 采样点生成
+"""
+
+import os
+
+from PyQt5.QtWidgets import (
+ QWidget, QVBoxLayout, QGroupBox, QFormLayout,
+ QPushButton, QCheckBox, QSpinBox, QMessageBox,
+)
+
+from src.gui.components.custom_widgets import FileSelectWidget
+from src.gui.dialogs import SamplingViewerDialog
+from src.gui.styles import ModernStylesheet
+
+
+class Step10Panel(QWidget):
+ """步骤10:采样点生成"""
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.init_ui()
+
+ def init_ui(self):
+ layout = QVBoxLayout()
+
+ # 去耀斑影像文件(用于独立运行)
+ self.deglint_img_file = FileSelectWidget(
+ "去耀斑影像:",
+ "Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
+ )
+ layout.addWidget(self.deglint_img_file)
+
+ # 水域掩膜文件(可选,用于独立运行)
+ self.water_mask_file = FileSelectWidget(
+ "水域掩膜:",
+ "Mask Files (*.dat *.tif);;All Files (*.*)"
+ )
+ self.water_mask_file.label.setText("水域掩膜:")
+ layout.addWidget(self.water_mask_file)
+
+ # 参数设置
+ params_group = QGroupBox("采样参数")
+ params_layout = QFormLayout()
+
+ self.interval = QSpinBox()
+ self.interval.setRange(10, 500)
+ self.interval.setValue(50)
+ params_layout.addRow("采样点间隔(像素):", self.interval)
+
+ self.sample_radius = QSpinBox()
+ self.sample_radius.setRange(1, 50)
+ self.sample_radius.setValue(5)
+ params_layout.addRow("采样半径(像素):", self.sample_radius)
+
+ self.chunk_size = QSpinBox()
+ self.chunk_size.setRange(100, 10000)
+ self.chunk_size.setValue(1000)
+ params_layout.addRow("处理块大小:", self.chunk_size)
+
+ self.use_adaptive_sampling = QCheckBox("启用自适应采样")
+ self.use_adaptive_sampling.setChecked(True)
+ params_layout.addRow("采样模式:", self.use_adaptive_sampling)
+
+ params_group.setLayout(params_layout)
+ layout.addWidget(params_group)
+
+ # 输出文件路径
+ self.output_file = FileSelectWidget(
+ "输出采样点:",
+ "CSV Files (*.csv);;All Files (*.*)"
+ )
+ self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
+ layout.addWidget(self.output_file)
+
+ # 启用步骤
+ self.enable_checkbox = QCheckBox("启用此步骤")
+ self.enable_checkbox.setChecked(True)
+ layout.addWidget(self.enable_checkbox)
+
+ # 独立运行按钮
+ self.run_btn = QPushButton("独立运行此步骤")
+ self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
+ self.run_btn.clicked.connect(self.run_step)
+ layout.addWidget(self.run_btn)
+
+ # 交互式预览按钮
+ self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
+ self.preview_btn.setEnabled(False)
+ self.preview_btn.clicked.connect(self._open_sampling_viewer)
+ layout.addWidget(self.preview_btn)
+
+ layout.addStretch()
+ self.setLayout(layout)
+
+ # 监听输出路径变化,实时更新预览按钮状态
+ self.output_file.line_edit.textChanged.connect(self._on_output_changed)
+
+ def get_config(self):
+ """Return the sampling parameters as a dict; the two input paths are added only when non-empty."""
+ config = {
+ 'interval': self.interval.value(),
+ 'sample_radius': self.sample_radius.value(),
+ 'chunk_size': self.chunk_size.value(),
+ 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
+ }
+ deglint_img_path = self.deglint_img_file.get_path()
+ if deglint_img_path:
+ config['deglint_img_path'] = deglint_img_path
+ water_mask_path = self.water_mask_file.get_path()
+ if water_mask_path:
+ config['water_mask_path'] = water_mask_path
+ return config
+
+    def set_config(self, config):
+        """Apply the keys present in ``config`` to the matching widgets; absent keys are skipped."""
+        if 'interval' in config:
+            self.interval.setValue(config['interval'])
+        if 'sample_radius' in config:
+            self.sample_radius.setValue(config['sample_radius'])
+        if 'chunk_size' in config:
+            self.chunk_size.setValue(config['chunk_size'])
+        if 'use_adaptive_sampling' in config:
+            self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
+        if 'deglint_img_path' in config:
+            self.deglint_img_file.set_path(config['deglint_img_path'])
+        if 'water_mask_path' in config:
+            self.water_mask_file.set_path(config['water_mask_path'])
+        if 'glint_mask_path' in config and hasattr(self, 'glint_mask_file'):  # init_ui creates no such widget
+            self.glint_mask_file.set_path(config['glint_mask_path'])
+
+ def update_from_config(self, work_dir=None, pipeline=None):
+ """从全局配置自动填充去耀斑影像和掩膜路径
+
+ Args:
+ work_dir: 工作目录路径
+ pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
+ """
+ if work_dir:
+ self.work_dir = work_dir
+ elif hasattr(self, 'work_dir') and self.work_dir:
+ pass
+ else:
+ self.work_dir = None
+
+ main_window = self.window()
+
+ # 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
+ deglint_path = None
+ if pipeline and hasattr(pipeline, 'step_outputs'):
+ step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
+ deglint_path = (
+ step3_outputs.get('deglint_image')
+ or step3_outputs.get('output_path')
+ or step3_outputs.get('output_file')
+ or step3_outputs.get('deglint_img_path')
+ )
+ # 回退:从 step3 面板 widget 直接读取(可能是相对路径)
+ if not deglint_path and hasattr(main_window, 'step3_panel'):
+ deglint_path = main_window.step3_panel.output_file.get_path()
+
+ if deglint_path:
+ # 若为相对路径,使用 work_dir 合成为绝对路径
+ if not os.path.isabs(deglint_path):
+ deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
+ self.deglint_img_file.set_path(deglint_path)
+
+ # 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
+ water_mask_path = None
+ if pipeline and hasattr(pipeline, 'step_outputs'):
+ step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
+ water_mask_path = (
+ step1_outputs.get('water_mask')
+ or step1_outputs.get('output_path')
+ or step1_outputs.get('output_file')
+ )
+ # 回退:从 step1 面板 widget 直接读取
+ if not water_mask_path and hasattr(main_window, 'step1_panel'):
+ water_mask_path = main_window.step1_panel.output_file.get_path()
+ # 备选:扫描 1_water_mask 目录下的 .dat 文件
+ if not water_mask_path and self.work_dir:
+ mask_dir = os.path.join(self.work_dir, "1_water_mask")
+ if os.path.isdir(mask_dir):
+ dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
+ if dat_files:
+ water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
+ # 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
+ if not water_mask_path and self.work_dir:
+ input_test_dir = os.path.join(self.work_dir, "input-test")
+ if os.path.isdir(input_test_dir):
+ dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
+ # 优先匹配 water_mask_from_shp.dat
+ for f in dat_files:
+ if 'water_mask_from_shp' in f.lower():
+ water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
+ break
+ # 否则取第一个 .dat 文件
+ if not water_mask_path and dat_files:
+ water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
+
+ if water_mask_path:
+ # 若为相对路径,使用 work_dir 合成为绝对路径
+ if not os.path.isabs(water_mask_path):
+ water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
+ self.water_mask_file.set_path(water_mask_path)
+
+ # 3. 自动填充输出路径(绝对路径)
+ if self.work_dir:
+ output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ self.output_file.set_path(output_path.replace('\\', '/'))
+
+ # 4. 同步更新预览按钮状态(路径可能已自动填充)
+ self._check_csv_exists()
+
+ def run_step(self):
+ """Run step 10 on its own: require a deglint image path, then call the window's run_single_step."""
+ deglint_img_path = self.deglint_img_file.get_path()
+ if not deglint_img_path:
+ QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
+ return
+
+ main_window = self.window()
+ if hasattr(main_window, 'run_single_step'):
+ config = {'step10': self.get_config()}
+ main_window.run_single_step('step10', config)
+
+ def _check_csv_exists(self):
+ """Enable the preview button only when the output path is absolute and exists; return that flag."""
+ csv_path = self.output_file.get_path()
+ enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
+ self.preview_btn.setEnabled(enabled)
+ return enabled
+
+ def _on_output_changed(self, _text=None):
+ """Slot for the output line_edit's textChanged signal (``_text`` is its argument); refreshes the preview button."""
+ self._check_csv_exists()
+
+ def _open_sampling_viewer(self):
+ """Open the modal sampling-point viewer for the output CSV, warning first if the file is missing."""
+ csv_path = self.output_file.get_path()
+ if not csv_path or not os.path.exists(csv_path):
+ QMessageBox.warning(
+ self, "文件不存在",
+ f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤10生成数据。"
+ )
+ return
+ dialog = SamplingViewerDialog(csv_path, self)
+ dialog.exec_()
+ # Re-check after the dialog closes (the file may have been overwritten or removed meanwhile)
+ self._check_csv_exists()
diff --git a/src/gui/panels/step11_ml_panel.py b/src/gui/panels/step11_ml_panel.py
new file mode 100644
index 0000000..8881f1e
--- /dev/null
+++ b/src/gui/panels/step11_ml_panel.py
@@ -0,0 +1,462 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Step11 ML panel - machine-learning prediction
+"""
+
+import os
+from pathlib import Path
+
+from PyQt5.QtWidgets import (
+ QWidget, QVBoxLayout, QGroupBox, QFormLayout,
+ QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
+ QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
+ QListWidgetItem,
+)
+from PyQt5.QtCore import Qt
+
+from src.gui.components.custom_widgets import FileSelectWidget
+from src.gui.styles import ModernStylesheet
+
+
+class Step11MlPanel(QWidget):
+ """步骤11:机器学习预测"""
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.external_models_dict = {} # {subdir_name: model_obj, ...}
+ self.external_model_dir = "" # 母文件夹路径(隐藏)
+ self.init_ui()
+
+ def init_ui(self):
+ layout = QVBoxLayout()
+
+ # -------- 模型来源选择(单选按钮组) --------
+ source_group = QGroupBox("模型来源")
+ source_layout = QVBoxLayout()
+
+ self.use_trained_model = QRadioButton("使用当前训练流程的模型")
+ self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
+ self.use_trained_model.setChecked(True)
+ source_layout.addWidget(self.use_trained_model)
+ source_layout.addWidget(self.use_external_model)
+
+ self.use_trained_model.toggled.connect(self._on_model_source_changed)
+ self.use_external_model.toggled.connect(self._on_model_source_changed)
+
+ source_group.setStyleSheet("""
+ QRadioButton {
+ font-size: 13px;
+ spacing: 8px;
+ }
+ QRadioButton::indicator {
+ width: 16px;
+ height: 16px;
+ border-radius: 9px;
+ border: 2px solid #A0A0A0;
+ background-color: #FFFFFF;
+ }
+ QRadioButton::indicator:hover {
+ border: 2px solid #0078D7;
+ }
+ QRadioButton::indicator:checked {
+ background-color: #0078D7;
+ border: 2px solid #0078D7;
+ }
+ """)
+
+ source_group.setLayout(source_layout)
+ layout.addWidget(source_group)
+
+ # -------- 外部模型文件选择(条件显示) --------
+ self.external_model_widget = FileSelectWidget(
+ "模型母文件夹:",
+ "Directories"
+ )
+ self.external_model_widget.browse_btn.clicked.disconnect()
+ self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
+ self.external_model_widget.setVisible(False)
+ layout.addWidget(self.external_model_widget)
+
+ # -------- 已扫描模型列表(条件显示) --------
+ self.model_list_group = QGroupBox("选择参与预测的模型")
+ self.model_list_group.setVisible(False)
+ model_list_layout = QVBoxLayout()
+
+ self.model_list = QListWidget()
+ self.model_list.setMaximumHeight(130)
+ self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
+ self.model_list.setStyleSheet("""
+ QListWidget {
+ border: 1px solid #C0C0C0;
+ border-radius: 4px;
+ background-color: #FFFFFF;
+ font-size: 12px;
+ }
+ QListWidget::item {
+ padding: 4px 6px;
+ border-bottom: 1px solid #F0F0F0;
+ }
+ QListWidget::item:selected {
+ background-color: transparent;
+ }
+ """)
+ model_list_layout.addWidget(self.model_list)
+
+ btn_row = QHBoxLayout()
+ self.btn_select_all = QPushButton("全选")
+ self.btn_select_all.setMaximumWidth(80)
+ self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
+ self.btn_select_all.clicked.connect(self._select_all_models)
+
+ self.btn_select_none = QPushButton("全不选")
+ self.btn_select_none.setMaximumWidth(80)
+ self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
+ self.btn_select_none.clicked.connect(self._select_none_models)
+
+ btn_row.addWidget(self.btn_select_all)
+ btn_row.addWidget(self.btn_select_none)
+ btn_row.addStretch()
+ model_list_layout.addLayout(btn_row)
+
+ self.model_list_group.setLayout(model_list_layout)
+ layout.addWidget(self.model_list_group)
+
+ # -------- 采样光谱CSV文件(用于独立运行)--------
+ self.sampling_csv_file = FileSelectWidget(
+ "采样光谱CSV:",
+ "CSV Files (*.csv);;All Files (*.*)"
+ )
+ layout.addWidget(self.sampling_csv_file)
+
+ # 模型目录(用于独立运行)
+ self.models_dir_file = FileSelectWidget(
+ "模型目录:",
+ "Directories;;All Files (*.*)"
+ )
+ self.models_dir_file.label.setText("模型目录:")
+ self.models_dir_file.browse_btn.clicked.disconnect()
+ self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
+ layout.addWidget(self.models_dir_file)
+
+ # 参数设置
+ params_group = QGroupBox("预测参数")
+ params_layout = QFormLayout()
+
+ self.metric = QComboBox()
+ self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
+ params_layout.addRow("模型选择指标:", self.metric)
+
+ self.prediction_column = QLineEdit()
+ self.prediction_column.setText("prediction")
+ params_layout.addRow("预测列名:", self.prediction_column)
+
+ params_group.setLayout(params_layout)
+ layout.addWidget(params_group)
+
+ # 输出路径
+ self.output_file = FileSelectWidget(
+ "输出路径:",
+ "CSV Files (*.csv);;All Files (*.*)"
+ )
+ layout.addWidget(self.output_file)
+
+ # 启用步骤
+ self.enable_checkbox = QCheckBox("启用此步骤")
+ self.enable_checkbox.setChecked(True)
+ layout.addWidget(self.enable_checkbox)
+
+ # 独立运行按钮
+ self.run_btn = QPushButton("独立运行此步骤")
+ self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
+ self.run_btn.clicked.connect(self.run_step)
+ layout.addWidget(self.run_btn)
+
+ layout.addStretch()
+ self.setLayout(layout)
+
+ def _on_model_source_changed(self, checked: bool):
+ """Radio-button toggle handler: show or hide the external-model widgets, clearing them when hidden."""
+ if not checked:  # both radio buttons are connected; act only on the one that became checked
+ return
+ is_external = self.use_external_model.isChecked()
+ self.external_model_widget.setVisible(is_external)
+ self.model_list_group.setVisible(is_external)
+ if not is_external:
+ self.external_models_dict = {}
+ self.external_model_dir = ""
+ self._clear_model_list()
+
+    def _scan_external_model_dir(self):
+        """Ask for a parent model folder and load one .joblib model per sub-directory.
+
+        Loaded models are stored in ``self.external_models_dict`` (keyed by sub-directory
+        name) and shown in the checkable list; per-model load failures are always reported.
+        """
+        default = self._get_default_work_dir()
+        if default:
+            default = os.path.join(default, "7_Supervised_Model_Training")
+        dir_path = QFileDialog.getExistingDirectory(self, "选择模型母文件夹", default)
+        if not dir_path:
+            return
+
+        self.external_model_dir = dir_path
+        models_found = {}
+        errors = []
+
+        try:
+            import joblib
+
+            for subentry in os.scandir(dir_path):
+                if not subentry.is_dir():
+                    continue
+                subdir_name = subentry.name
+                joblib_files = [
+                    f for f in os.scandir(subentry.path)
+                    if f.is_file() and f.name.lower().endswith(".joblib")
+                ]
+                if not joblib_files:
+                    continue
+                # Only the first .joblib file of each sub-directory is used (same as the batch logic).
+                joblib_path = joblib_files[0].path
+                try:
+                    # NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
+                    # only scan model folders that come from a trusted source.
+                    loaded = joblib.load(joblib_path)
+                    if isinstance(loaded, dict) and "model" in loaded:
+                        model_obj = loaded["model"]
+                    elif hasattr(loaded, "predict"):
+                        model_obj = loaded
+                    else:
+                        errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
+                        continue
+                    models_found[subdir_name] = model_obj
+                except Exception as e:
+                    errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
+
+        except Exception as e:
+            QMessageBox.warning(self, "扫描失败", f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}")
+            return
+
+        err_lines = "\n".join(errors) if errors else "无"
+        if not models_found:
+            # Files that were found but failed to load are listed too, so the user sees why nothing is usable.
+            failure_note = f"\n\n加载失败 {len(errors)} 个:\n{err_lines}" if errors else ""
+            QMessageBox.warning(
+                self,
+                "未找到模型",
+                f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
+                f"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。{failure_note}",
+            )
+            self.external_model_widget.set_path("")
+            self.external_models_dict = {}
+            self._clear_model_list()
+            return
+
+        self.external_models_dict = models_found
+        self._populate_model_list(models_found)
+        names = sorted(models_found.keys())
+        display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
+        self.external_model_widget.set_path(display)
+        self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
+
+        QMessageBox.information(
+            self,
+            "模型扫描完成",
+            f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
+            f"加载失败 {len(errors)} 个:\n{err_lines}",
+        )
+
+ def _populate_model_list(self, models_dict):
+ """Refill the list widget with the model names in sorted order; every entry is checkable and starts checked."""
+ self.model_list.clear()
+ for name in sorted(models_dict.keys()):
+ item = QListWidgetItem(name)
+ item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
+ item.setCheckState(Qt.Checked)
+ self.model_list.addItem(item)
+
+ def _clear_model_list(self):
+ """Remove every entry from the model list widget."""
+ self.model_list.clear()
+
+ def _select_all_models(self):
+ """Select all: set every list entry to Checked."""
+ for i in range(self.model_list.count()):
+ self.model_list.item(i).setCheckState(Qt.Checked)
+
+ def _select_none_models(self):
+ """Select none: set every list entry to Unchecked."""
+ for i in range(self.model_list.count()):
+ self.model_list.item(i).setCheckState(Qt.Unchecked)
+
+ def _get_checked_models_dict(self):
+ """Return {name: model} for the checked list entries that are present in external_models_dict."""
+ result = {}
+ for i in range(self.model_list.count()):
+ item = self.model_list.item(i)
+ if item.checkState() == Qt.Checked:
+ name = item.text()
+ if name in self.external_models_dict:
+ result[name] = self.external_models_dict[name]
+ return result
+
+ def update_from_config(self, work_dir=None, pipeline=None):
+ """Auto-fill the sampling-spectra CSV, the model directory and the output path from sibling panels.
+
+ Args:
+ work_dir: working directory path
+ pipeline: Pipeline instance (unused; kept for interface compatibility)
+ """
+ try:
+ import traceback
+
+ if work_dir:
+ self.work_dir = work_dir
+ elif hasattr(self, 'work_dir') and self.work_dir:
+ pass
+ else:
+ self.work_dir = None
+
+ main_window = self.window()
+
+ # 1. Read the sampling-points CSV path from the step10 panel's output_file widget
+ if main_window and hasattr(main_window, 'step10_panel'):
+ step7_widget = getattr(main_window.step10_panel, 'output_file', None)
+ step7_output_path = ""
+ if hasattr(step7_widget, 'get_path'):
+ step7_output_path = step7_widget.get_path() or ""
+ elif hasattr(step7_widget, 'text'):
+ step7_output_path = step7_widget.text() or ""
+
+ if step7_output_path:
+ # A relative path is joined onto work_dir (it stays relative when work_dir is unset)
+ if not os.path.isabs(step7_output_path):
+ step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
+ existing = self.sampling_csv_file.get_path()
+ if not existing or not existing.strip():
+ self.sampling_csv_file.set_path(step7_output_path)
+
+ # 2. Read the supervised-model directory from the step7 panel's output_dir widget
+ if main_window and hasattr(main_window, 'step7_panel'):
+ step6_widget = getattr(main_window.step7_panel, 'output_dir', None)
+ step6_models_dir = ""
+ if hasattr(step6_widget, 'get_path'):
+ step6_models_dir = step6_widget.get_path() or ""
+ elif hasattr(step6_widget, 'text'):
+ step6_models_dir = step6_widget.text() or ""
+
+ if step6_models_dir:
+ # A relative path is joined onto work_dir (it stays relative when work_dir is unset)
+ if not os.path.isabs(step6_models_dir):
+ step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
+ existing_models = self.models_dir_file.get_path()
+ if not existing_models or not existing_models.strip():
+ self.models_dir_file.set_path(step6_models_dir)
+
+ # 3. Auto-fill the output path (machine-learning prediction directory) when it is still empty
+ if self.work_dir:
+ output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
+ os.makedirs(output_dir, exist_ok=True)
+ existing_out = self.output_file.get_path()
+ if not existing_out or not existing_out.strip():
+ self.output_file.set_path(output_dir)
+ except Exception as e:
+ import traceback
+ print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
+ traceback.print_exc()
+
+ def _get_default_work_dir(self):
+ """Return work_dir as a string: the panel's own cached value first, else the main window's, else ""."""
+ if hasattr(self, 'work_dir') and self.work_dir:
+ return str(self.work_dir)
+ mw = self.window()
+ if mw and hasattr(mw, 'work_dir') and mw.work_dir:
+ return str(mw.work_dir)
+ return ""
+
+ def browse_models_dir(self):
+ """Let the user pick the models directory; the dialog starts in work_dir's training folder when work_dir is known."""
+ default = self._get_default_work_dir()
+ if default:
+ default = os.path.join(default, "7_Supervised_Model_Training")
+ dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
+ if dir_path:
+ self.models_dir_file.set_path(dir_path)
+
+ def get_config(self):
+ """Return the prediction settings as a dict; the three path entries are added only when non-empty."""
+ config = {
+ 'metric': self.metric.currentText(),
+ 'prediction_column': self.prediction_column.text(),
+ }
+ sampling_csv_path = self.sampling_csv_file.get_path()
+ if sampling_csv_path:
+ config['sampling_csv_path'] = sampling_csv_path
+ models_dir = self.models_dir_file.get_path()
+ if models_dir:
+ config['models_dir'] = models_dir
+ output_path = self.output_file.get_path()
+ if output_path:
+ config['output_path'] = output_path
+ return config
+
+ def set_config(self, config):
+ """Apply the keys present in ``config`` to the widgets; a 'metric' not in the combo box is ignored."""
+ if 'metric' in config:
+ idx = self.metric.findText(config['metric'])
+ if idx >= 0:
+ self.metric.setCurrentIndex(idx)
+ if 'prediction_column' in config:
+ self.prediction_column.setText(config['prediction_column'])
+ if 'sampling_csv_path' in config:
+ self.sampling_csv_file.set_path(config['sampling_csv_path'])
+ if 'models_dir' in config:
+ self.models_dir_file.set_path(config['models_dir'])
+ if 'output_path' in config:
+ self.output_file.set_path(config['output_path'])
+
+ def run_step(self):
+ """Run this step on its own (dispatched as 'step11_ml') through the main window's run_single_step."""
+ sampling_csv_path = self.sampling_csv_file.get_path()
+ if not sampling_csv_path:
+ QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
+ return
+
+ # External models take priority when the "import local pre-trained model" radio button is checked
+ if self.use_external_model.isChecked():
+ if not self.external_models_dict:
+ QMessageBox.warning(
+ self,
+ "模型未加载",
+ "请先点击「浏览...」按钮选择模型母文件夹!",
+ )
+ return
+ # Pass along only the models the user has checked
+ checked_dict = self._get_checked_models_dict()
+ if not checked_dict:
+ QMessageBox.warning(
+ self,
+ "未选择模型",
+ "请至少勾选一个模型参与预测!",
+ )
+ return
+ main_window = self.window()
+ if hasattr(main_window, 'run_single_step'):
+ config = {
+ 'step11_ml': self.get_config(),
+ '_external_models_dict': checked_dict,
+ '_external_model_dir': self.external_model_dir,
+ }
+ main_window.run_single_step('step11_ml', config)
+ return
+
+ # Default flow: use the models directory
+ models_dir = self.models_dir_file.get_path()
+ if not models_dir:
+ QMessageBox.warning(self, "输入错误", "请选择模型目录!")
+ return
+
+ main_window = self.window()
+ if hasattr(main_window, 'run_single_step'):
+ config = {'step11_ml': self.get_config()}
+ main_window.run_single_step('step11_ml', config)
diff --git a/src/gui/panels/step8_5_panel.py b/src/gui/panels/step11_panel.py
similarity index 86%
rename from src/gui/panels/step8_5_panel.py
rename to src/gui/panels/step11_panel.py
index 74e8520..ce78fb9 100644
--- a/src/gui/panels/step8_5_panel.py
+++ b/src/gui/panels/step11_panel.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-Step8_5 面板 - 非经验模型预测
+Step11 面板 - 非经验模型预测
"""
import os
@@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
-class Step8_5Panel(QWidget):
- """步骤8.5:非经验模型预测"""
+class Step11Panel(QWidget):
+ """步骤11:非经验模型预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@@ -118,22 +118,22 @@ class Step8_5Panel(QWidget):
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
- # 2. 尝试从 Step6.5 界面读取回归模型目录
- if main_window and hasattr(main_window, 'step6_5_panel'):
- step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
- step6_5_models_dir = ""
- if hasattr(step6_5_widget, 'get_path'):
- step6_5_models_dir = step6_5_widget.get_path() or ""
- elif hasattr(step6_5_widget, 'text'):
- step6_5_models_dir = step6_5_widget.text() or ""
+ # 2. 尝试从 Step8_Non_Empirical 界面读取回归模型目录
+ if main_window and hasattr(main_window, 'step8_non_empirical_panel'):
+ step8_non_empirical_widget = getattr(main_window.step8_non_empirical_panel, 'output_dir', None)
+ step8_non_empirical_models_dir = ""
+ if hasattr(step8_non_empirical_widget, 'get_path'):
+ step8_non_empirical_models_dir = step8_non_empirical_widget.get_path() or ""
+ elif hasattr(step8_non_empirical_widget, 'text'):
+ step8_non_empirical_models_dir = step8_non_empirical_widget.text() or ""
- if step6_5_models_dir:
+ if step8_non_empirical_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step6_5_models_dir):
- step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
+ if not os.path.isabs(step8_non_empirical_models_dir):
+ step8_non_empirical_models_dir = os.path.join(self.work_dir or '', step8_non_empirical_models_dir).replace('\\', '/')
existing_models = self.models_dir_file.get_path()
if not existing_models or not existing_models.strip():
- self.models_dir_file.set_path(step6_5_models_dir)
+ self.models_dir_file.set_path(step8_non_empirical_models_dir)
# 3. 自动填充输出路径(非经验模型预测目录)
if self.work_dir:
@@ -208,7 +208,7 @@ class Step8_5Panel(QWidget):
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
- """独立运行步骤8.5"""
+ """独立运行步骤11"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
@@ -221,6 +221,6 @@ class Step8_5Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
- parent.run_single_step('step8_5', {'step8_5': config})
+ parent.run_single_step('step11', {'step11': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
diff --git a/src/gui/panels/step8_75_panel.py b/src/gui/panels/step12_panel.py
similarity index 88%
rename from src/gui/panels/step8_75_panel.py
rename to src/gui/panels/step12_panel.py
index fb37d53..ab97b71 100644
--- a/src/gui/panels/step8_75_panel.py
+++ b/src/gui/panels/step12_panel.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-Step8_75 面板 - 自定义回归预测
+Step12 面板 - 自定义回归预测
"""
import os
@@ -15,8 +15,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
-class Step8_75Panel(QWidget):
- """步骤8.75:自定义回归预测"""
+class Step12Panel(QWidget):
+ """步骤12:自定义回归预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@@ -111,25 +111,25 @@ class Step8_75Panel(QWidget):
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
- # 2. 尝试从 Step6.75 界面读取自定义回归模型目录
- if main_window and hasattr(main_window, 'step6_75_panel'):
- step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
- step6_75_models_dir = ""
- if hasattr(step6_75_widget, 'get_path'):
- step6_75_models_dir = step6_75_widget.get_path() or ""
- elif hasattr(step6_75_widget, 'text'):
- step6_75_models_dir = step6_75_widget.text() or ""
- step6_75_models_dir = step6_75_models_dir.strip()
+ # 2. Read the custom-regression model directory from the Step9 panel (guard on the attribute dereferenced below)
+ if main_window and hasattr(main_window, 'step9_panel'):
+ step9_widget = getattr(main_window.step9_panel, 'output_dir', None)
+ step9_models_dir = ""
+ if hasattr(step9_widget, 'get_path'):
+ step9_models_dir = step9_widget.get_path() or ""
+ elif hasattr(step9_widget, 'text'):
+ step9_models_dir = step9_widget.text() or ""
+ step9_models_dir = step9_models_dir.strip()
- if step6_75_models_dir:
+ if step9_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step6_75_models_dir):
- step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
+ if not os.path.isabs(step9_models_dir):
+ step9_models_dir = os.path.join(self.work_dir or '', step9_models_dir).replace('\\', '/')
existing_models = self.regression_models_dir.get_path()
if not existing_models or not existing_models.strip():
- self.regression_models_dir.set_path(step6_75_models_dir)
+ self.regression_models_dir.set_path(step9_models_dir)
- # 3. 自动填充回归模型目录(如果 step6_75 未提供)
+ # 3. 自动填充回归模型目录(如果 step9 未提供)
if self.work_dir:
models_dir = self.regression_models_dir.get_path().strip()
if not models_dir:
@@ -208,7 +208,7 @@ class Step8_75Panel(QWidget):
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
- """独立运行步骤8.75"""
+ """独立运行步骤12"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
@@ -225,6 +225,6 @@ class Step8_75Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
- parent.run_single_step('step8_75', {'step8_75': config})
+ parent.run_single_step('step12', {'step12': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
diff --git a/src/gui/panels/step14_panel.py b/src/gui/panels/step14_panel.py
new file mode 100644
index 0000000..77e9aa7
--- /dev/null
+++ b/src/gui/panels/step14_panel.py
@@ -0,0 +1,533 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Step14 面板 - 分布图生成
+"""
+
+import os
+import traceback
+from pathlib import Path
+from typing import List, Optional
+
+from PyQt5.QtCore import Qt, QThread, pyqtSignal
+from PyQt5.QtWidgets import (
+ QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
+ QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
+ QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
+)
+
+from src.gui.components.custom_widgets import FileSelectWidget
+from src.gui.styles import ModernStylesheet
+
+# Pipeline availability flag (kept consistent with core/worker_thread.py)
+try:
+    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
+    PIPELINE_AVAILABLE = True
+except ImportError:
+    PIPELINE_AVAILABLE = False
+
+
+class Step14BatchThread(QThread):
+    """Background worker that renders one distribution map per prediction CSV.
+
+    Signals:
+        finished_ok(int): emitted with the CSV count once every file was processed.
+        failed(str): emitted with the exception text plus traceback when a call raises.
+        log_message(str, str): progress text followed by its level name.
+    """
+
+    finished_ok = pyqtSignal(int)
+    failed = pyqtSignal(str)
+    log_message = pyqtSignal(str, str)
+
+    def __init__(self, work_dir: str, csv_paths: List[str], step14_kwargs: dict, output_dir_optional: Optional[str]):
+        super().__init__()
+        self.work_dir = work_dir
+        self.csv_paths = csv_paths
+        self.step14_kwargs = step14_kwargs
+        # A blank or whitespace-only directory collapses to None; run() then
+        # passes output_image_path=None to the pipeline.
+        self.output_dir_optional = (output_dir_optional or "").strip() or None
+
+    def run(self):
+        """Process the CSV files in order; the first exception aborts the batch."""
+        # Remember the active matplotlib backend so it can be restored at the end.
+        previous_backend = None
+        try:
+            import matplotlib
+            previous_backend = matplotlib.get_backend()
+        except Exception:
+            pass
+        # Switch to the Agg backend; if that fails there is nothing to restore.
+        try:
+            import matplotlib.pyplot as plt
+            plt.switch_backend("Agg")
+        except Exception:
+            previous_backend = None
+
+        try:
+            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
+            pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
+            total = len(self.csv_paths)
+            for position, csv_path in enumerate(self.csv_paths, start=1):
+                self.log_message.emit(f"专题图 [{position}/{total}] {csv_path}", "info")
+                image_path = None
+                if self.output_dir_optional:
+                    image_name = f"{Path(csv_path).stem}_distribution.png"
+                    image_path = str(Path(self.output_dir_optional) / image_name)
+                call_kwargs = dict(self.step14_kwargs)
+                call_kwargs["prediction_csv_path"] = csv_path
+                call_kwargs["skip_dependency_check"] = True
+                call_kwargs["output_image_path"] = image_path
+                # NOTE(review): the pipeline method still carries the "step9" name — confirm
+                # it is the intended entry point for step 14.
+                pipeline.step9_generate_distribution_map(**call_kwargs)
+            self.finished_ok.emit(total)
+        except Exception as exc:
+            self.failed.emit(f"{exc}\n{traceback.format_exc()}")
+        finally:
+            if previous_backend:
+                try:
+                    import matplotlib.pyplot as plt
+                    plt.switch_backend(previous_backend)
+                except Exception:
+                    pass
+
+
+class Step14Panel(QWidget):
+ """步骤14:分布图生成"""
+    def __init__(self, parent=None):
+        """Create the panel and build its widgets.
+
+        Args:
+            parent: optional parent widget, forwarded to QWidget.
+        """
+        super().__init__(parent)
+        # Handle of the folder-batch worker; stays None until run_step starts one.
+        self._batch_thread = None
+        # Declared here so the attribute always exists; update_from_config fills it in.
+        self.work_dir = None
+        self.init_ui()
+
+    def init_ui(self):
+        """Build all widgets of the panel and wire their signals.
+
+        Widgets are added top to bottom: hint label, mode radio buttons, single-CSV
+        selector, folder selector row, recursive checkbox, boundary file selector,
+        parameter group, output directory selector, enable checkbox and run button.
+        """
+        layout = QVBoxLayout()
+
+        hint = QLabel(
+            "独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
+            "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
+        )
+        hint.setWordWrap(True)
+        hint.setStyleSheet(
+            f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
+        )
+        layout.addWidget(hint)
+
+        # Input mode: button id 0 = single CSV file, id 1 = folder batch
+        mode_row = QHBoxLayout()
+        self.mode_single_rb = QRadioButton("单个 CSV 文件")
+        self.mode_folder_rb = QRadioButton("文件夹批量")
+        self._mode_group = QButtonGroup(self)
+        self._mode_group.addButton(self.mode_single_rb, 0)
+        self._mode_group.addButton(self.mode_folder_rb, 1)
+        mode_row.addWidget(self.mode_single_rb)
+        mode_row.addWidget(self.mode_folder_rb)
+        mode_row.addStretch()
+        layout.addLayout(mode_row)
+
+        # ---------- RadioButton styling (checked state drawn as a solid square block, matching the main window look) ----------
+        radio_style = """
+            QRadioButton {
+                font-size: 14px;
+                spacing: 8px;
+                color: #333333;
+            }
+            QRadioButton::indicator {
+                width: 16px;
+                height: 16px;
+                border: 2px solid #999999;
+                border-radius: 3px;
+                background-color: white;
+            }
+            QRadioButton::indicator:checked {
+                border: 2px solid #0078d4;
+                background-color: #0078d4;
+                image: none;
+            }
+            QRadioButton::indicator:hover {
+                border: 2px solid #005a9e;
+            }
+        """
+        self.mode_single_rb.setStyleSheet(radio_style)
+        self.mode_folder_rb.setStyleSheet(radio_style)
+
+        # Single-CSV input (visible in single mode only, see _toggle_input_mode)
+        self.prediction_csv_file = FileSelectWidget(
+            "预测结果CSV:",
+            "CSV Files (*.csv);;All Files (*.*)"
+        )
+        layout.addWidget(self.prediction_csv_file)
+
+        # Folder input row, wrapped in its own QWidget so the whole row can be shown or hidden at once
+        folder_row = QHBoxLayout()
+        self.prediction_csv_dir_label = QLabel("预测CSV目录:")
+        self.prediction_csv_dir_label.setMinimumWidth(120)
+        self.prediction_csv_dir_edit = QLineEdit()
+        self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
+        pred_dir_btn = QPushButton("浏览…")
+        pred_dir_btn.setMaximumWidth(80)
+        pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
+        folder_row.addWidget(self.prediction_csv_dir_label)
+        folder_row.addWidget(self.prediction_csv_dir_edit, 1)
+        folder_row.addWidget(pred_dir_btn)
+        self._folder_row_widget = QWidget()
+        self._folder_row_widget.setLayout(folder_row)
+        layout.addWidget(self._folder_row_widget)
+
+        self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
+        layout.addWidget(self.recursive_csv_cb)
+
+        self.boundary_file = FileSelectWidget(
+            "边界文件:",
+            "Shapefiles (*.shp);;All Files (*.*)"
+        )
+        layout.addWidget(self.boundary_file)
+
+        # Generation parameters
+        params_group = QGroupBox("生成参数")
+        params_layout = QFormLayout()
+
+        self.resolution = QDoubleSpinBox()
+        self.resolution.setRange(1, 1000)
+        self.resolution.setValue(30)
+        params_layout.addRow("分辨率(米):", self.resolution)
+
+        self.input_crs = QLineEdit()
+        self.input_crs.setText("EPSG:32651")
+        params_layout.addRow("输入坐标系:", self.input_crs)
+
+        self.output_crs = QLineEdit()
+        self.output_crs.setText("EPSG:4326")
+        params_layout.addRow("输出坐标系:", self.output_crs)
+
+        self.show_points = QCheckBox("显示采样点")
+        params_layout.addRow("", self.show_points)
+
+        self.use_diffusion = QCheckBox("启用距离扩散")
+        self.use_diffusion.setChecked(True)
+        params_layout.addRow("", self.use_diffusion)
+
+        params_group.setLayout(params_layout)
+        layout.addWidget(params_group)
+
+        # Output directory
+        self.output_dir = FileSelectWidget(
+            "输出分布图目录:",
+            "Directories;;All Files (*.*)"
+        )
+        self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
+        # Drop the widget's existing browse handler(s) and route the button to
+        # browse_output_dir, which opens a directory chooser.
+        self.output_dir.browse_btn.clicked.disconnect()
+        self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
+        layout.addWidget(self.output_dir)
+
+        # Enable this step
+        self.enable_checkbox = QCheckBox("启用此步骤")
+        self.enable_checkbox.setChecked(True)
+        layout.addWidget(self.enable_checkbox)
+
+        # Standalone run button
+        self.run_button = QPushButton("独立运行此步骤")
+        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
+        self.run_button.clicked.connect(self.run_step)
+        layout.addWidget(self.run_button)
+
+        layout.addStretch()
+        self.setLayout(layout)
+
+        # Signal wiring and initial state
+        self.mode_single_rb.toggled.connect(self._toggle_input_mode)
+        self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
+        self.mode_single_rb.setChecked(True)  # "single CSV" is selected by default
+        self._toggle_input_mode()  # apply the initial visibility for the default mode
+
+    def _toggle_input_mode(self):
+        """Slot: show the input widgets of the selected mode and hide the others."""
+        batch_selected = self.mode_folder_rb.isChecked()
+        visibility = (
+            (self.prediction_csv_file, not batch_selected),  # single-CSV selector
+            (self._folder_row_widget, batch_selected),       # folder selector row
+            (self.recursive_csv_cb, batch_selected),         # recursive-scan option
+        )
+        for widget, shown in visibility:
+            widget.setVisible(shown)
+
+    def _get_default_work_dir(self):
+        """Return the working directory as a string, or "" when none is known.
+
+        The panel's own ``work_dir`` wins; otherwise the top-level window's
+        ``work_dir`` is used.
+        """
+        own_dir = getattr(self, 'work_dir', None)
+        if own_dir:
+            return str(own_dir)
+        top_window = self.window()
+        window_dir = getattr(top_window, 'work_dir', None) if top_window else None
+        return str(window_dir) if window_dir else ""
+
+    def browse_prediction_csv_dir(self):
+        """Let the user pick the folder holding prediction CSVs and store the choice."""
+        base_dir = self._get_default_work_dir()
+        # Start inside the predictions sub-folder when a working directory is known.
+        start_dir = os.path.join(base_dir, "11_12_13_predictions") if base_dir else base_dir
+        chosen = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", start_dir)
+        if not chosen:
+            return
+        self.prediction_csv_dir_edit.setText(chosen)
+
+    def _collect_csv_paths_from_folder(self) -> List[str]:
+        """Return the sorted ``*.csv`` files of the chosen folder as strings.
+
+        Sub-folders are included only when the recursive checkbox is ticked. An
+        empty list is returned when the folder field is blank or not a directory.
+        """
+        folder = (self.prediction_csv_dir_edit.text() or "").strip()
+        if not (folder and os.path.isdir(folder)):
+            return []
+        root = Path(folder)
+        find = root.rglob if self.recursive_csv_cb.isChecked() else root.glob
+        return [str(entry) for entry in sorted(find("*.csv")) if entry.is_file()]
+
+    def _step14_base_pipeline_kwargs(self) -> dict:
+        """Collect the keyword arguments shared by every map-generation call."""
+        return dict(
+            boundary_shp_path=self.boundary_file.get_path(),
+            resolution=self.resolution.value(),
+            input_crs=self.input_crs.text(),
+            output_crs=self.output_crs.text(),
+            show_sample_points=self.show_points.isChecked(),
+            use_distance_diffusion=self.use_diffusion.isChecked(),
+        )
+
+    def get_config(self):
+        """Snapshot the panel state as a config dict.
+
+        ``prediction_csv_path`` and ``output_image_path`` are only filled in single
+        mode; the image path additionally needs both a CSV and an output directory.
+        """
+        folder_mode = self.mode_folder_rb.isChecked()
+        csv_path = (self.prediction_csv_file.get_path() or "").strip()
+        csv_dir = (self.prediction_csv_dir_edit.text() or "").strip()
+
+        config = {
+            'step14_batch_mode': 'folder' if folder_mode else 'single',
+            'prediction_csv_dir': csv_dir or None,
+            'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
+            'prediction_csv_path': None if folder_mode else (csv_path or None),
+        }
+        # Boundary, resolution, CRS and display options come from the shared helper.
+        config.update(self._step14_base_pipeline_kwargs())
+
+        target_dir = (self.output_dir.get_path() or "").strip()
+        if folder_mode or not (csv_path and target_dir):
+            config['output_image_path'] = None
+        else:
+            image_name = f"{Path(csv_path).stem}_distribution.png"
+            config['output_image_path'] = str(Path(target_dir) / image_name)
+        return config
+
+    def set_config(self, config):
+        """Restore the panel state from a config dict.
+
+        Keys that are absent leave the corresponding widget untouched. The output
+        directory comes from ``output_dir`` when present, otherwise from the parent
+        folder of ``output_image_path``.
+
+        Args:
+            config: mapping in the shape produced by get_config; ``enabled`` and
+                ``output_dir`` are honoured as well when present.
+        """
+        if config.get('step14_batch_mode', 'single') == 'folder':
+            self.mode_folder_rb.setChecked(True)
+        else:
+            self.mode_single_rb.setChecked(True)
+
+        if config.get('prediction_csv_dir'):
+            self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
+        if 'recursive_csv_scan' in config:
+            self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
+        if config.get('prediction_csv_path'):
+            self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
+        if 'boundary_shp_path' in config:
+            boundary = config['boundary_shp_path']
+            # A None value clears the field instead of being handed to set_path as a non-string.
+            self.boundary_file.set_path(str(boundary) if boundary else "")
+
+        # Plain one-to-one mappings from config key to widget setter.
+        direct_setters = (
+            ('resolution', self.resolution.setValue),
+            ('input_crs', self.input_crs.setText),
+            ('output_crs', self.output_crs.setText),
+            ('show_sample_points', self.show_points.setChecked),
+            ('use_distance_diffusion', self.use_diffusion.setChecked),
+        )
+        for key, apply_value in direct_setters:
+            if key in config:
+                apply_value(config[key])
+
+        # Restore the enable flag when the config carries one.
+        if 'enabled' in config:
+            self.enable_checkbox.setChecked(bool(config['enabled']))
+
+        if config.get('output_dir'):
+            self.output_dir.set_path(str(config['output_dir']))
+        elif config.get('output_image_path'):
+            image_parent = Path(str(config['output_image_path'])).parent
+            # A bare file name has parent "."; that is not a usable directory choice.
+            if str(image_parent) != '.':
+                self.output_dir.set_path(str(image_parent))
+
+    def update_from_config(self, work_dir=None, pipeline=None):
+        """Auto-fill the prediction folder, output folder and boundary file.
+
+        The prediction CSV folder is taken from the first source that yields a
+        non-empty path, in this order:
+          1. ``main_window.step11_prediction_panel.output_file`` - its parent folder,
+             or the ``Machine_Learning_Prediction`` sub-folder when that exists;
+          2. ``main_window.step11_panel.output_file`` - its parent folder;
+          3. ``main_window.step12_panel.output_dir_widget`` - used as a folder.
+        Relative paths are resolved against ``work_dir``. Fields that already hold
+        a value are never overwritten. Any exception is printed and swallowed.
+
+        Args:
+            work_dir: working directory; when falsy the cached value is kept.
+            pipeline: unused, kept for interface compatibility.
+        """
+        try:
+            if work_dir:
+                self.work_dir = work_dir
+            elif not getattr(self, 'work_dir', None):
+                self.work_dir = None
+
+            main_window = self.window()
+            if not main_window:
+                return
+
+            pred_dir = None
+
+            # 1. Preferred source: step11_prediction_panel output file
+            if hasattr(main_window, 'step11_prediction_panel'):
+                ml_output = self._read_path_widget(
+                    getattr(main_window.step11_prediction_panel, 'output_file', None)
+                )
+                if ml_output:
+                    base_dir = Path(self._resolve_in_work_dir(ml_output)).parent
+                    ml_dir = base_dir / "Machine_Learning_Prediction"
+                    pred_dir = str(ml_dir) if ml_dir.exists() else str(base_dir)
+
+            # 2. Fallback: step11_panel output file
+            if not pred_dir and hasattr(main_window, 'step11_panel'):
+                step11_output = self._read_path_widget(
+                    getattr(main_window.step11_panel, 'output_file', None)
+                )
+                if step11_output:
+                    pred_dir = str(Path(self._resolve_in_work_dir(step11_output)).parent)
+
+            # 3. Fallback: step12_panel output directory
+            if not pred_dir and hasattr(main_window, 'step12_panel'):
+                step12_output = self._read_path_widget(
+                    getattr(main_window.step12_panel, 'output_dir_widget', None)
+                )
+                if step12_output:
+                    # Resolved like the other two sources, so a relative entry does
+                    # not depend on the process's current directory.
+                    pred_dir = self._resolve_in_work_dir(step12_output)
+
+            # Fill the "prediction CSV folder" field and switch to folder-batch mode
+            if pred_dir and not (self.prediction_csv_dir_edit.text() or "").strip():
+                self.prediction_csv_dir_edit.setText(pred_dir)
+                self.mode_folder_rb.setChecked(True)
+
+            if self.work_dir:
+                # 4. Default output directory (14_visualization); created even when the
+                #    field already holds another path.
+                output_dir = os.path.join(self.work_dir, "14_visualization")
+                os.makedirs(output_dir, exist_ok=True)
+                if not (self.output_dir.get_path() or "").strip():
+                    self.output_dir.set_path(output_dir)
+
+                # 5. Probe well-known locations for the vector boundary file (.shp)
+                candidates = (
+                    Path(self.work_dir).parent / "input-test" / "roi.shp",
+                    Path(self.work_dir) / "roi.shp",
+                    Path(self.work_dir).parent / "roi.shp",
+                )
+                found_shp = next((shp for shp in candidates if shp.exists()), None)
+
+                if not (self.boundary_file.get_path() or "").strip():
+                    if found_shp:
+                        self.boundary_file.set_path(str(found_shp))
+                    else:
+                        # Nothing found: clear the field and ask for a manual choice.
+                        self.boundary_file.set_path("")
+                        print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
+        except Exception as e:
+            # Auto-fill is best effort: report the problem and carry on.
+            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
+            traceback.print_exc()
+
+    @staticmethod
+    def _read_path_widget(widget) -> str:
+        """Return the stripped text of a path widget exposing get_path() or text(), else ""."""
+        if hasattr(widget, 'get_path'):
+            return (widget.get_path() or "").strip()
+        if hasattr(widget, 'text'):
+            return (widget.text() or "").strip()
+        return ""
+
+    def _resolve_in_work_dir(self, path: str) -> str:
+        """Return *path* unchanged when absolute, otherwise joined onto work_dir with '/' separators."""
+        if os.path.isabs(path):
+            return path
+        return os.path.join(self.work_dir or '', path).replace('\\', '/')
+
+    def browse_output_dir(self):
+        """Let the user pick the output directory for the distribution maps."""
+        base_dir = self._get_default_work_dir()
+        # Start inside 14_visualization when a working directory is known.
+        start_dir = os.path.join(base_dir, "14_visualization") if base_dir else base_dir
+        chosen = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", start_dir)
+        if not chosen:
+            return
+        self.output_dir.set_path(chosen)
+
+    def run_step(self):
+        """Run step 14 on its own.
+
+        Folder mode renders every CSV in a Step14BatchThread; single mode hands
+        the config to the host window's ``run_single_step``.
+        """
+        running = self._batch_thread
+        if running and running.isRunning():
+            QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
+            return
+
+        # The boundary file is mandatory in both modes.
+        boundary_shp_path = self.boundary_file.get_path()
+        boundary_problem = None
+        if not boundary_shp_path:
+            boundary_problem = "请选择边界文件"
+        elif not os.path.exists(boundary_shp_path):
+            boundary_problem = "边界文件不存在"
+        if boundary_problem:
+            QMessageBox.warning(self, "输入验证失败", boundary_problem)
+            return
+
+        # Walk up the parent chain to the object that exposes run_single_step.
+        host = self.parent()
+        while host and not hasattr(host, 'run_single_step'):
+            host = host.parent()
+        if not (host and hasattr(host, 'run_single_step')):
+            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
+            return
+
+        if self.mode_folder_rb.isChecked():
+            self._run_folder_batch(host)
+            return
+
+        csv_path = (self.prediction_csv_file.get_path() or "").strip()
+        if not csv_path:
+            QMessageBox.warning(
+                self,
+                "输入验证失败",
+                "请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
+            )
+            return
+        if not os.path.isfile(csv_path):
+            QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
+            return
+
+        config = self.get_config()
+        host.run_single_step('step14', {'step14': config})
+
+    def _run_folder_batch(self, host):
+        """Validate the folder input and start the batch worker thread.
+
+        Args:
+            host: object found by run_step; supplies ``work_dir`` and, when it has
+                one, ``log_message``.
+        """
+        csv_list = self._collect_csv_paths_from_folder()
+        if not csv_list:
+            QMessageBox.warning(
+                self,
+                "输入验证失败",
+                "所选文件夹中未找到 .csv 文件,或目录无效。\n"
+                "可勾选「包含子文件夹」以递归扫描。",
+            )
+            return
+        if not PIPELINE_AVAILABLE:
+            QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
+            return
+
+        work_dir = str(getattr(host, "work_dir", None) or "./work_dir")
+        shared_kwargs = self._step14_base_pipeline_kwargs()
+        target_dir = (self.output_dir.get_path() or "").strip() or None
+
+        # The button stays disabled until the thread's finished signal fires.
+        self.run_button.setEnabled(False)
+        worker = Step14BatchThread(work_dir, csv_list, shared_kwargs, target_dir)
+        self._batch_thread = worker
+
+        def forward_log(message, level):
+            if hasattr(host, "log_message"):
+                host.log_message(message, level)
+
+        worker.log_message.connect(forward_log, Qt.QueuedConnection)
+        worker.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection)
+        worker.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection)
+        worker.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
+        worker.start()
+
+        if hasattr(host, "log_message"):
+            host.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
+
+    def _on_step14_batch_ok(self, n: int):
+        """Slot for a finished batch: tell the user, then log to the first ancestor with log_message."""
+        QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
+        log_host = self.parent()
+        while log_host and not hasattr(log_host, "log_message"):
+            log_host = log_host.parent()
+        if not log_host:
+            return
+        log_host.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
+
+    def _on_step14_batch_fail(self, err: str):
+        """Slot for a failed batch: show the first 900 characters, log the full text."""
+        QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
+        log_host = self.parent()
+        while log_host and not hasattr(log_host, "log_message"):
+            log_host = log_host.parent()
+        if not log_host:
+            return
+        log_host.log_message(err, "error")
diff --git a/src/gui/panels/step5_5_panel.py b/src/gui/panels/step5_5_panel.py
deleted file mode 100644
index 6a33037..0000000
--- a/src/gui/panels/step5_5_panel.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import os
-import sys
-import pandas as pd
-from pathlib import Path
-from typing import Dict, List, Union
-
-from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QGridLayout,
- QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
-)
-from PyQt5.QtCore import Qt
-from src.gui.components.custom_widgets import FileSelectWidget
-from src.gui.styles import ModernStylesheet
-
-def get_resource_path(relative_path: str) -> str:
- """适配开发与 PyInstaller 环境的路径获取逻辑。
- 支持两种打包模式:
- 1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录
- 2. --onefile 模式:文件在 sys._MEIPASS 平铺目录
- """
- # 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下)
- if hasattr(sys, '_MEIPASS'):
- internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path)
- if os.path.exists(internal_path):
- return internal_path
- return os.path.join(sys._MEIPASS, relative_path)
-
- # 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下)
- exe_dir = os.path.dirname(sys.executable)
- internal_path = os.path.join(exe_dir, '_internal', relative_path)
- if os.path.exists(internal_path):
- return internal_path
-
- # 开发环境下:基于当前文件 (step5_5_panel.py) 的绝对路径进行回溯
- # 当前在 src/gui/panels/,目标在 src/gui/model/
- base_dir = Path(__file__).resolve().parent.parent / "model"
- target_path = base_dir / os.path.basename(relative_path)
- return str(target_path)
-
-class Step5_5Panel(QWidget):
- def __init__(self, parent=None):
- super().__init__(parent)
- self.index_checkboxes: Dict[str, QCheckBox] = {}
- # 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中
- self.builtin_formula_path = get_resource_path("waterindex.csv")
-
- self.init_ui()
- # 延迟一小会儿加载,确保UI框架已就绪
- self._auto_load_formulas()
-
- def init_ui(self):
- main_layout = QVBoxLayout()
- main_layout.setContentsMargins(20, 20, 20, 20)
- main_layout.setSpacing(10)
-
- # 1. 路径展示区 (半透明只读)
- path_group = QGroupBox("公式配置源 (内置)")
- path_layout = QVBoxLayout()
- self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
- self.formula_csv_widget.set_path(self.builtin_formula_path)
- self.formula_csv_widget.set_read_only(True)
- # 视觉微调:提示用户这是内置的
- self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
- path_layout.addWidget(self.formula_csv_widget)
- path_group.setLayout(path_layout)
- main_layout.addWidget(path_group)
-
- # 2. 训练数据输入
- input_group = QGroupBox("输入样本数据")
- input_layout = QVBoxLayout()
- self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
- input_layout.addWidget(self.training_data_widget)
- input_group.setLayout(input_layout)
- main_layout.addWidget(input_group)
-
- # 3. 公式选择区
- self.formula_group = QGroupBox("待计算水质指数勾选")
- formula_outer_layout = QVBoxLayout()
-
- btn_layout = QHBoxLayout()
- self.select_all_btn = QPushButton("全选")
- self.deselect_all_btn = QPushButton("清空")
- self.select_all_btn.clicked.connect(self.select_all_formulas)
- self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
- btn_layout.addWidget(self.select_all_btn)
- btn_layout.addWidget(self.deselect_all_btn)
- btn_layout.addStretch()
-
- self.refresh_button = QPushButton("手动重新加载公式")
- self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
- btn_layout.addWidget(self.refresh_button)
-
- formula_outer_layout.addLayout(btn_layout)
-
- # 核心滚动区
- scroll = QScrollArea()
- scroll.setWidgetResizable(True)
- scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
- self.scroll_content = QWidget()
- self.formula_layout = QGridLayout(self.scroll_content)
- self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
- scroll.setWidget(self.scroll_content)
- formula_outer_layout.addWidget(scroll)
-
- self.formula_group.setLayout(formula_outer_layout)
- main_layout.addWidget(self.formula_group)
-
- # 4. 输出与运行
- output_group = QGroupBox("结果输出")
- output_layout = QVBoxLayout()
- self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
- output_layout.addWidget(self.output_file_widget)
- output_group.setLayout(output_layout)
- main_layout.addWidget(output_group)
-
- self.enable_checkbox = QCheckBox("启用计算流程")
- self.enable_checkbox.setChecked(True)
- main_layout.addWidget(self.enable_checkbox)
-
- self.run_button = QPushButton("立即执行计算")
- self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
- self.run_button.setMinimumHeight(40)
- self.run_button.clicked.connect(self.run_step)
- main_layout.addWidget(self.run_button)
-
- self.setLayout(main_layout)
-
- def _auto_load_formulas(self):
- """启动时自动加载逻辑"""
- if os.path.exists(self.builtin_formula_path):
- self.refresh_formulas(silent=True)
- else:
- print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
-
- def refresh_formulas(self, silent=False):
- path = self.builtin_formula_path
- if not os.path.exists(path):
- if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
- return
-
- try:
- # 清理旧列表
- for i in reversed(range(self.formula_layout.count())):
- widget = self.formula_layout.itemAt(i).widget()
- if widget: widget.deleteLater()
- self.index_checkboxes.clear()
-
- # 鲁棒性读取:尝试不同编码
- for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
- try:
- df = pd.read_csv(path, encoding=encoding)
- if 'Formula_Name' in df.columns: break
- except: continue
-
- if 'Formula_Name' not in df.columns:
- if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列")
- return
-
- names = df['Formula_Name'].dropna().unique().tolist()
-
- row, col = 0, 0
- for name in names:
- name = str(name).strip()
- if not name: continue
- cb = QCheckBox(name)
- cb.setChecked(True)
- self.index_checkboxes[name] = cb
- self.formula_layout.addWidget(cb, row, col)
- col += 1
- if col >= 3:
- col = 0
- row += 1
-
- # 强制UI更新
- self.scroll_content.adjustSize()
- print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
-
- except Exception as e:
- if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
-
- def select_all_formulas(self):
- for cb in self.index_checkboxes.values(): cb.setChecked(True)
-
- def deselect_all_formulas(self):
- for cb in self.index_checkboxes.values(): cb.setChecked(False)
-
- def get_config(self):
- selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
- return {
- 'training_csv_path': self.training_data_widget.get_path(),
- 'formula_csv_file': self.builtin_formula_path,
- 'formula_names': selected,
- 'output_file': self.output_file_widget.get_path(),
- 'enabled': self.enable_checkbox.isChecked()
- }
-
- def set_config(self, config):
- if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
- if 'formula_names' in config:
- sel = set(config['formula_names'])
- for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
- if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
- self.enable_checkbox.setChecked(config.get('enabled', True))
-
- def update_from_config(self, work_dir=None, pipeline=None):
- if work_dir: self.work_dir = work_dir
- main = self.window()
- if hasattr(main, 'step5_panel'):
- p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐
- if p5:
- if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
- self.training_data_widget.set_path(p5)
-
- if self.work_dir:
- out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/')
- self.output_file_widget.set_path(out)
-
- def run_step(self):
- config = self.get_config()
- if not config['training_csv_path']:
- QMessageBox.warning(self, "提示", "请先选择输入数据")
- return
- parent = self.parent()
- while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
- if parent: parent.run_single_step('step5_5', {'step5_5': config})
\ No newline at end of file
diff --git a/src/gui/panels/step6_75_panel.py b/src/gui/panels/step6_75_panel.py
deleted file mode 100644
index 130387c..0000000
--- a/src/gui/panels/step6_75_panel.py
+++ /dev/null
@@ -1,374 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-Step6_75 面板 - 自定义回归分析
-"""
-
-import os
-from typing import Dict
-
-import pandas as pd
-from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
- QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
- QScrollArea, QMessageBox,
-)
-
-from src.gui.components.custom_widgets import FileSelectWidget
-from src.gui.styles import ModernStylesheet
-
-
-class Step6_75Panel(QWidget):
- """步骤6.75:自定义回归分析"""
- def __init__(self, parent=None):
- super().__init__(parent)
- self.x_column_checkboxes: Dict[str, QCheckBox] = {}
- self.y_column_checkboxes: Dict[str, QCheckBox] = {}
- self.method_checkboxes: Dict[str, QCheckBox] = {}
- self.csv_columns = []
- self.init_ui()
-
- def init_ui(self):
- layout = QVBoxLayout()
-
- hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
- hint.setStyleSheet("color: #666; font-size: 11px;")
- layout.addWidget(hint)
-
- # CSV文件选择
- csv_group = QGroupBox("数据文件")
- csv_layout = QVBoxLayout()
-
- self.csv_file = FileSelectWidget(
- "输入CSV文件:",
- "CSV Files (*.csv);;All Files (*.*)"
- )
- self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
- csv_layout.addWidget(self.csv_file)
-
- self.refresh_btn = QPushButton("刷新列信息")
- self.refresh_btn.clicked.connect(self.refresh_csv_columns)
- csv_layout.addWidget(self.refresh_btn)
-
- csv_group.setLayout(csv_layout)
- layout.addWidget(csv_group)
-
- # 自变量选择
- x_group = QGroupBox("自变量列选择 (可多选)")
- x_layout = QVBoxLayout()
-
- x_scroll = QScrollArea()
- x_scroll.setWidgetResizable(True)
- x_scroll.setMinimumHeight(250)
- x_scroll.setMaximumHeight(350)
-
- x_widget = QWidget()
- self.x_columns_layout = QGridLayout()
- x_widget.setLayout(self.x_columns_layout)
-
- x_scroll.setWidget(x_widget)
- x_layout.addWidget(x_scroll)
-
- x_btn_layout = QHBoxLayout()
- self.x_select_all = QPushButton("全选")
- self.x_deselect_all = QPushButton("全不选")
- self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
- self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
- x_btn_layout.addWidget(self.x_select_all)
- x_btn_layout.addWidget(self.x_deselect_all)
- x_btn_layout.addStretch()
- x_layout.addLayout(x_btn_layout)
-
- x_group.setLayout(x_layout)
- layout.addWidget(x_group)
-
- # 因变量选择
- y_group = QGroupBox("因变量列选择 (可多选)")
- y_layout = QVBoxLayout()
-
- y_scroll = QScrollArea()
- y_scroll.setWidgetResizable(True)
- y_scroll.setMinimumHeight(200)
- y_scroll.setMaximumHeight(300)
-
- y_widget = QWidget()
- self.y_columns_layout = QGridLayout()
- y_widget.setLayout(self.y_columns_layout)
-
- y_scroll.setWidget(y_widget)
- y_layout.addWidget(y_scroll)
-
- y_btn_layout = QHBoxLayout()
- self.y_select_all = QPushButton("全选")
- self.y_deselect_all = QPushButton("全不选")
- self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
- self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
- y_btn_layout.addWidget(self.y_select_all)
- y_btn_layout.addWidget(self.y_deselect_all)
- y_btn_layout.addStretch()
- y_layout.addLayout(y_btn_layout)
-
- y_group.setLayout(y_layout)
- layout.addWidget(y_group)
-
- # 回归方法选择
- method_group = QGroupBox("回归方法选择 (可多选)")
- method_layout = QVBoxLayout()
-
- method_grid = QGridLayout()
- regression_methods = [
- 'linear', 'exponential', 'power', 'logarithmic',
- 'polynomial', 'hyperbolic', 'sigmoidal'
- ]
-
- for i, method in enumerate(regression_methods):
- checkbox = QCheckBox(method)
- if method in ['linear', 'exponential', 'power', 'logarithmic']:
- checkbox.setChecked(True)
- self.method_checkboxes[method] = checkbox
- method_grid.addWidget(checkbox, i // 3, i % 3)
-
- method_layout.addLayout(method_grid)
-
- method_btn_layout = QHBoxLayout()
- self.method_select_all = QPushButton("全选")
- self.method_deselect_all = QPushButton("全不选")
- self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
- self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
- method_btn_layout.addWidget(self.method_select_all)
- method_btn_layout.addWidget(self.method_deselect_all)
- method_btn_layout.addStretch()
- method_layout.addLayout(method_btn_layout)
-
- method_group.setLayout(method_layout)
- layout.addWidget(method_group)
-
- # 输出目录
- output_group = QGroupBox("输出设置")
- output_layout = QFormLayout()
-
- self.output_dir = QLineEdit()
- self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
- output_layout.addRow("输出目录名:", self.output_dir)
-
- output_group.setLayout(output_layout)
- layout.addWidget(output_group)
-
- # 启用步骤
- self.enable_checkbox = QCheckBox("启用此步骤")
- self.enable_checkbox.setChecked(True)
- layout.addWidget(self.enable_checkbox)
-
- # 独立运行按钮
- self.run_button = QPushButton("独立运行此步骤")
- self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
- self.run_button.clicked.connect(self.run_step)
- layout.addWidget(self.run_button)
-
- layout.addStretch()
- self.setLayout(layout)
-
- def toggle_checkboxes(self, checkboxes_dict, checked):
- """统一设置checkbox状态"""
- for checkbox in checkboxes_dict.values():
- checkbox.setChecked(checked)
-
- def on_csv_file_changed(self):
- """CSV文件改变时自动刷新列信息"""
- self.refresh_csv_columns()
-
- def refresh_csv_columns(self):
- """刷新CSV文件的列信息"""
- csv_path = self.csv_file.get_path()
- if not csv_path or not os.path.exists(csv_path):
- self.csv_columns = []
- self.update_column_widgets()
- return
-
- try:
- df = pd.read_csv(csv_path, nrows=0)
- self.csv_columns = list(df.columns)
- self.update_column_widgets()
- except Exception as e:
- self.csv_columns = []
- self.update_column_widgets()
- print(f"读取CSV列信息失败: {e}")
-
- def update_column_widgets(self):
- """更新列选择组件"""
- for checkbox in self.x_column_checkboxes.values():
- checkbox.setParent(None)
- self.x_column_checkboxes.clear()
-
- for checkbox in self.y_column_checkboxes.values():
- checkbox.setParent(None)
- self.y_column_checkboxes.clear()
-
- if not self.csv_columns:
- return
-
- for i, col in enumerate(self.csv_columns):
- checkbox = QCheckBox(col)
- if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
- checkbox.setChecked(True)
- self.x_column_checkboxes[col] = checkbox
- self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
-
- for i, col in enumerate(self.csv_columns):
- checkbox = QCheckBox(col)
- if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
- checkbox.setChecked(True)
- self.y_column_checkboxes[col] = checkbox
- self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
-
- self.x_columns_layout.update()
- self.y_columns_layout.update()
-
- def get_config(self):
- selected_x_columns = [
- col for col, checkbox in self.x_column_checkboxes.items()
- if checkbox.isChecked()
- ]
- selected_y_columns = [
- col for col, checkbox in self.y_column_checkboxes.items()
- if checkbox.isChecked()
- ]
- selected_methods = [
- method for method, checkbox in self.method_checkboxes.items()
- if checkbox.isChecked()
- ]
- if not selected_methods:
- selected_methods = 'all'
-
- return {
- 'csv_path': self.csv_file.get_path() or None,
- 'x_columns': selected_x_columns,
- 'y_columns': selected_y_columns,
- 'methods': selected_methods,
- 'output_dir': self.output_dir.text().strip() or None,
- 'enabled': self.enable_checkbox.isChecked()
- }
-
- def set_config(self, config):
- if 'csv_path' in config:
- self.csv_file.set_path(config['csv_path'])
- self.refresh_csv_columns()
-
- if 'x_columns' in config:
- selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
- for col, checkbox in self.x_column_checkboxes.items():
- checkbox.setChecked(col in selected_x)
-
- if 'y_columns' in config:
- selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
- for col, checkbox in self.y_column_checkboxes.items():
- checkbox.setChecked(col in selected_y)
-
- if 'methods' in config:
- methods = config['methods']
- if isinstance(methods, list):
- selected_methods = set(methods)
- elif methods == 'all':
- selected_methods = set(self.method_checkboxes.keys())
- else:
- selected_methods = set()
- for method, checkbox in self.method_checkboxes.items():
- checkbox.setChecked(method in selected_methods)
-
- if 'output_dir' in config:
- self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
- if 'enabled' in config:
- self.enable_checkbox.setChecked(config['enabled'])
-
- def update_from_config(self, work_dir=None, pipeline=None):
- """从全局配置自动填充训练数据和输出路径
-
- Args:
- work_dir: 工作目录路径
- pipeline: Pipeline 实例(未使用,保留接口兼容性)
- """
- try:
- import traceback
-
- if work_dir:
- self.work_dir = work_dir
- elif hasattr(self, 'work_dir') and self.work_dir:
- pass
- else:
- self.work_dir = None
-
- # 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
- main_window = self.window()
- if main_window and hasattr(main_window, 'step5_panel'):
- step5_widget = getattr(main_window.step5_panel, 'output_file', None)
- step5_output_path = ""
- if hasattr(step5_widget, 'get_path'):
- step5_output_path = step5_widget.get_path() or ""
- elif hasattr(step5_widget, 'text'):
- step5_output_path = step5_widget.text() or ""
-
- if step5_output_path:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step5_output_path):
- step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
- existing = self.csv_file.get_path()
- if not existing or not existing.strip():
- self.csv_file.set_path(step5_output_path)
-
- # 2. 自动填充输出目录(9_Custom_Regression_Modeling)
- if self.work_dir:
- output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
- os.makedirs(output_dir, exist_ok=True)
- existing_out = self.output_dir.text().strip()
- if not existing_out:
- self.output_dir.setText(output_dir)
- except Exception as e:
- import traceback
- print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
- traceback.print_exc()
-
- def run_step(self):
- """独立运行步骤6.75"""
- csv_path = self.csv_file.get_path()
-
- if not csv_path:
- QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
- return
- if not os.path.exists(csv_path):
- QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
- return
-
- selected_x_columns = [
- col for col, checkbox in self.x_column_checkboxes.items()
- if checkbox.isChecked()
- ]
- if not selected_x_columns:
- QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
- return
-
- selected_y_columns = [
- col for col, checkbox in self.y_column_checkboxes.items()
- if checkbox.isChecked()
- ]
- if not selected_y_columns:
- QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
- return
-
- selected_methods = [
- method for method, checkbox in self.method_checkboxes.items()
- if checkbox.isChecked()
- ]
- if not selected_methods:
- QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
- return
-
- config = self.get_config()
-
- parent = self.parent()
- while parent and not hasattr(parent, 'run_single_step'):
- parent = parent.parent()
-
- if parent and hasattr(parent, 'run_single_step'):
- parent.run_single_step('step6_75', {'step6_75': config})
- else:
- QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
diff --git a/src/gui/panels/step6_panel.py b/src/gui/panels/step6_panel.py
deleted file mode 100644
index c35121f..0000000
--- a/src/gui/panels/step6_panel.py
+++ /dev/null
@@ -1,415 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-Step6 面板 - 机器学习建模
-"""
-
-import os
-
-from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
- QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
- QPushButton, QFileDialog, QMessageBox,
-)
-from PyQt5.QtCore import Qt
-
-from src.gui.components.custom_widgets import FileSelectWidget
-from src.gui.styles import ModernStylesheet
-
-
-# ============================================================
-# 中文映射表(内部键名 -> 显示文本)
-# ============================================================
-
-# 预处理方法:内部键 -> 显示文本
-PREPROC_CHINESE = {
- 'None': '无 (None)',
- 'MMS': '最小-最大归一化 (MMS)',
- 'SS': '标度化 (SS)',
- 'SNV': '标准正态变换 (SNV)',
- 'MA': '移动平均 (MA)',
- 'SG': 'Savitzky-Golay (SG)',
- 'MSC': '多元散射校正 (MSC)',
- 'D1': '一阶导数 (D1)',
- 'D2': '二阶导数 (D2)',
- 'DT': '去趋势 (DT)',
- 'CT': '中心化 (CT)',
-}
-
-# 模型类型:内部键 -> 显示文本
-MODEL_CHINESE = {
- # 线性模型
- 'LinearRegression': '多元线性回归 (MLR)',
- 'Ridge': '岭回归 (Ridge)',
- 'Lasso': '套索回归 (Lasso)',
- 'ElasticNet': '弹性网络 (ElasticNet)',
- 'PLS': '偏最小二乘 (PLSR)',
- # 树模型
- 'DecisionTree': '决策树 (CART)',
- 'RF': '随机森林 (RF)',
- 'ExtraTrees': '极端随机树 (ET)',
- 'XGBoost': '极值梯度提升 (XGBoost)',
- 'LightGBM': '轻量梯度提升 (LightGBM)',
- 'CatBoost': '类别梯度提升 (CatBoost)',
- # 集成学习
- 'GradientBoosting': '梯度提升树 (GBDT)',
- 'AdaBoost': '自适应提升 (AdaBoost)',
- # 其他模型
- 'SVR': '支持向量回归 (SVR)',
- 'KNN': 'K近邻回归 (KNN)',
- 'MLP': '多层感知机 (BP神经网络)',
-}
-
-# 数据划分方法:内部键 -> 显示文本
-SPLIT_CHINESE = {
- 'spxy': 'SPXY 算法 (考量X-Y空间)',
- 'ks': 'KS 算法 (考量X空间)',
- 'random': '随机划分 (Random)',
-}
-
-
-class Step6Panel(QWidget):
- """步骤6:机器学习建模"""
- def __init__(self, parent=None):
- super().__init__(parent)
- self.init_ui()
-
- def init_ui(self):
- layout = QVBoxLayout()
-
- # 标题
-
-
- # 训练数据文件(用于独立运行)
- self.training_csv_file = FileSelectWidget(
- "训练数据:",
- "CSV Files (*.csv);;All Files (*.*)"
- )
- layout.addWidget(self.training_csv_file)
-
- # 机器学习模型页面
- self.ml_page = QWidget()
- self.create_ml_page()
- layout.addWidget(self.ml_page)
-
- # 输出文件路径
- self.output_path = FileSelectWidget(
- "输出文件:",
- "CSV Files (*.csv);;All Files (*.*)",
- mode="save"
- )
- self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
- self.output_path.browse_btn.clicked.disconnect()
- self.output_path.browse_btn.clicked.connect(self.browse_output_path)
- layout.addWidget(self.output_path)
-
- # 启用步骤
- self.enable_checkbox = QCheckBox("启用此步骤")
- self.enable_checkbox.setChecked(False)
- layout.addWidget(self.enable_checkbox)
-
- # 独立运行按钮
- self.run_btn = QPushButton("独立运行此步骤")
- self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
- self.run_btn.clicked.connect(self.run_step)
- layout.addWidget(self.run_btn)
-
- layout.addStretch()
- self.setLayout(layout)
-
- def create_ml_page(self):
- """创建机器学习模型页面"""
- layout = QVBoxLayout()
-
- # 参数设置
- params_group = QGroupBox("训练参数")
- params_layout = QFormLayout()
-
- self.feature_start = QLineEdit()
- self.feature_start.setText("374.285004")
- params_layout.addRow("特征起始列:", self.feature_start)
-
- self.cv_folds = QSpinBox()
- self.cv_folds.setRange(2, 10)
- self.cv_folds.setValue(3)
- params_layout.addRow("交叉验证折数:", self.cv_folds)
-
- params_group.setLayout(params_layout)
- layout.addWidget(params_group)
-
- # 预处理方法 - 多选
- preproc_group = QGroupBox("预处理方法 (可多选)")
- preproc_layout = QVBoxLayout()
-
- preproc_grid = QGridLayout()
- self.preproc_checkboxes = {}
- preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
-
- for i, method in enumerate(preproc_methods):
- checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
- checkbox.setChecked(False)
- self.preproc_checkboxes[method] = checkbox
- preproc_grid.addWidget(checkbox, i // 4, i % 4)
-
- button_layout = QHBoxLayout()
- select_all_btn = QPushButton("全选")
- deselect_all_btn = QPushButton("全不选")
- select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
- deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
- button_layout.addWidget(select_all_btn)
- button_layout.addWidget(deselect_all_btn)
- button_layout.addStretch()
-
- preproc_layout.addLayout(preproc_grid)
- preproc_layout.addLayout(button_layout)
- preproc_group.setLayout(preproc_layout)
- layout.addWidget(preproc_group)
-
- # 模型选择 - 多选
- model_group = QGroupBox("模型类型 (可多选)")
- model_layout = QVBoxLayout()
-
- model_grid = QGridLayout()
- self.model_checkboxes = {}
-
- model_groups = [
- ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
- ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
- ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
- ("【其他模型】", ['SVR', 'KNN', 'MLP'])
- ]
-
- row = 0
- for group_name, models in model_groups:
- group_label = QLabel(f"{group_name}")
- group_label.setStyleSheet(
- f"background-color: {ModernStylesheet.COLORS['hover']}; "
- f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
- f"border-radius: 3px;"
- )
- model_grid.addWidget(group_label, row, 0, 1, 4)
- row += 1
-
- for i, model in enumerate(models):
- checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
- checkbox.setChecked(False)
- self.model_checkboxes[model] = checkbox
- model_grid.addWidget(checkbox, row, i % 4)
- if (i + 1) % 4 == 0:
- row += 1
-
- row += 1
-
- model_button_layout = QHBoxLayout()
- model_select_all = QPushButton("全选")
- model_deselect_all = QPushButton("全不选")
- model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
- model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
- model_button_layout.addWidget(model_select_all)
- model_button_layout.addWidget(model_deselect_all)
- model_button_layout.addStretch()
-
- model_layout.addLayout(model_grid)
- model_layout.addLayout(model_button_layout)
- model_group.setLayout(model_layout)
- layout.addWidget(model_group)
-
- # 数据划分方法 - 多选
- split_group = QGroupBox("数据划分方法 (可多选)")
- split_layout = QVBoxLayout()
-
- split_grid = QGridLayout()
- self.split_checkboxes = {}
- split_methods = ['spxy', 'ks', 'random']
-
- for i, method in enumerate(split_methods):
- checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
- checkbox.setChecked(False)
- self.split_checkboxes[method] = checkbox
- split_grid.addWidget(checkbox, 0, i)
-
- split_button_layout = QHBoxLayout()
- split_select_all = QPushButton("全选")
- split_deselect_all = QPushButton("全不选")
- split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
- split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
- split_button_layout.addWidget(split_select_all)
- split_button_layout.addWidget(split_deselect_all)
- split_button_layout.addStretch()
-
- split_layout.addLayout(split_grid)
- split_layout.addLayout(split_button_layout)
- split_group.setLayout(split_layout)
- layout.addWidget(split_group)
-
- self.ml_page.setLayout(layout)
-
- def _toggle_checkboxes(self, checkboxes_dict, checked):
- """统一设置checkbox状态"""
- for checkbox in checkboxes_dict.values():
- checkbox.setChecked(checked)
-
- def _get_default_work_dir(self):
- """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
- if hasattr(self, 'work_dir') and self.work_dir:
- return str(self.work_dir)
- mw = self.window()
- if mw and hasattr(mw, 'work_dir') and mw.work_dir:
- return str(mw.work_dir)
- return ""
-
- def browse_output_path(self):
- """浏览输出文件路径(保存对话框)"""
- current = self.output_path.get_path().strip()
- if current:
- initial_dir = os.path.dirname(current)
- initial_file = os.path.basename(current)
- else:
- initial_dir = ""
- initial_file = ""
-
- if not initial_dir or not os.path.isdir(initial_dir):
- # 默认定位到 indices 目录
- work_dir = self._get_default_work_dir()
- initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
- if initial_dir and not os.path.isdir(initial_dir):
- os.makedirs(initial_dir, exist_ok=True)
-
- file_path, _ = QFileDialog.getSaveFileName(
- self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
- "CSV Files (*.csv);;All Files (*.*)"
- )
- if file_path:
- self.output_path.set_path(file_path)
-
- def get_config(self):
- """获取配置"""
- preprocessing_methods = [
- method for method, checkbox in self.preproc_checkboxes.items()
- if checkbox.isChecked()
- ]
- model_names = [
- model for model, checkbox in self.model_checkboxes.items()
- if checkbox.isChecked()
- ]
- split_methods = [
- method for method, checkbox in self.split_checkboxes.items()
- if checkbox.isChecked()
- ]
-
- config = {
- 'feature_start_column': self.feature_start.text(),
- 'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
- 'model_names': model_names if model_names else ['SVR'],
- 'split_methods': split_methods if split_methods else ['random'],
- 'cv_folds': self.cv_folds.value()
- }
- training_csv_path = self.training_csv_file.get_path()
- if training_csv_path:
- config['training_csv_path'] = training_csv_path
- output_path = self.output_path.get_path()
- if output_path:
- config['output_path'] = output_path
- return config
-
- def set_config(self, config):
- """设置配置"""
- if 'feature_start_column' in config:
- self.feature_start.setText(str(config['feature_start_column']))
- if 'cv_folds' in config:
- self.cv_folds.setValue(config['cv_folds'])
- if 'preprocessing_methods' in config:
- methods = config['preprocessing_methods']
- for method, checkbox in self.preproc_checkboxes.items():
- checkbox.setChecked(method in methods)
- if 'model_names' in config:
- models = config['model_names']
- for model, checkbox in self.model_checkboxes.items():
- checkbox.setChecked(model in models)
- if 'split_methods' in config:
- methods = config['split_methods']
- for method, checkbox in self.split_checkboxes.items():
- checkbox.setChecked(method in methods)
- if 'training_csv_path' in config:
- self.training_csv_file.set_path(config['training_csv_path'])
- if 'output_path' in config:
- self.output_path.set_path(config['output_path'])
-
- def update_from_config(self, work_dir=None, pipeline=None):
- """从全局配置自动填充训练数据和输出路径
-
- Args:
- work_dir: 工作目录路径
- pipeline: Pipeline 实例(未使用,保留接口兼容性)
- """
- if work_dir:
- self.work_dir = work_dir
- elif hasattr(self, 'work_dir') and self.work_dir:
- pass
- else:
- self.work_dir = None
-
- # 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
- main_window = self.window()
- if hasattr(main_window, 'step5_panel'):
- # 优先直接从 Step5 的输出 widget 读取
- step5_output = main_window.step5_panel.output_file.get_path()
- if step5_output:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step5_output):
- step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
- self.training_csv_file.set_path(step5_output)
- elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
- # 回退:从 Step5 的 config 字典中查找可能的键名
- step5_cfg = main_window.step5_panel.get_config()
- step5_csv = (
- step5_cfg.get('training_csv_path')
- or step5_cfg.get('output_file')
- or step5_cfg.get('csv_path')
- or step5_cfg.get('output_csv')
- )
- if step5_csv:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step5_csv):
- step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
- self.training_csv_file.set_path(step5_csv)
-
- # 2. 自动填充输出文件路径(基于工作目录和输入文件名)
- # 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
- # 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
- if self.work_dir:
- indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
- os.makedirs(indices_dir, exist_ok=True)
- training_csv = self.training_csv_file.get_path()
- if training_csv:
- basename = os.path.splitext(os.path.basename(training_csv))[0]
- output_file = f"{basename}_indices.csv"
- else:
- output_file = "water_quality_indices.csv"
- output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
- self.output_path.set_path(output_path)
- else:
- self.output_path.set_path("")
-
- def run_step(self):
- """独立运行步骤6"""
- training_csv_path = self.training_csv_file.get_path()
- if not training_csv_path:
- QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
- return
-
- main_window = self.window()
- if hasattr(main_window, 'run_single_step'):
- config = {'step6': self.get_config()}
- main_window.run_single_step('step6', config)
-
- def get_training_params(self):
- """获取模型训练参数"""
- return {
- 'pipeline_type': 'machine_learning',
- 'feature_start': float(self.feature_start.text()),
- 'cv_folds': self.cv_folds.value(),
- 'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
- 'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
- 'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
- }
diff --git a/src/gui/panels/step7_panel.py b/src/gui/panels/step7_panel.py
index 7c9ed67..e049fce 100644
--- a/src/gui/panels/step7_panel.py
+++ b/src/gui/panels/step7_panel.py
@@ -1,23 +1,75 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-Step7 面板 - 采样点生成
+Step7 面板 - 机器学习建模
"""
import os
from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QFormLayout,
- QPushButton, QCheckBox, QSpinBox, QMessageBox,
+ QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
+ QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
+ QPushButton, QFileDialog, QMessageBox,
)
+from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
-from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet
+# ============================================================
+# 中文映射表(内部键名 -> 显示文本)
+# ============================================================
+
+# 预处理方法:内部键 -> 显示文本
+PREPROC_CHINESE = {
+ 'None': '无 (None)',
+ 'MMS': '最小-最大归一化 (MMS)',
+ 'SS': '标度化 (SS)',
+ 'SNV': '标准正态变换 (SNV)',
+ 'MA': '移动平均 (MA)',
+ 'SG': 'Savitzky-Golay (SG)',
+ 'MSC': '多元散射校正 (MSC)',
+ 'D1': '一阶导数 (D1)',
+ 'D2': '二阶导数 (D2)',
+ 'DT': '去趋势 (DT)',
+ 'CT': '中心化 (CT)',
+}
+
+# 模型类型:内部键 -> 显示文本
+MODEL_CHINESE = {
+ # 线性模型
+ 'LinearRegression': '多元线性回归 (MLR)',
+ 'Ridge': '岭回归 (Ridge)',
+ 'Lasso': '套索回归 (Lasso)',
+ 'ElasticNet': '弹性网络 (ElasticNet)',
+ 'PLS': '偏最小二乘 (PLSR)',
+ # 树模型
+ 'DecisionTree': '决策树 (CART)',
+ 'RF': '随机森林 (RF)',
+ 'ExtraTrees': '极端随机树 (ET)',
+ 'XGBoost': '极值梯度提升 (XGBoost)',
+ 'LightGBM': '轻量梯度提升 (LightGBM)',
+ 'CatBoost': '类别梯度提升 (CatBoost)',
+ # 集成学习
+ 'GradientBoosting': '梯度提升树 (GBDT)',
+ 'AdaBoost': '自适应提升 (AdaBoost)',
+ # 其他模型
+ 'SVR': '支持向量回归 (SVR)',
+ 'KNN': 'K近邻回归 (KNN)',
+ 'MLP': '多层感知机 (BP神经网络)',
+}
+
+# 数据划分方法:内部键 -> 显示文本
+SPLIT_CHINESE = {
+ 'spxy': 'SPXY 算法 (考量X-Y空间)',
+ 'ks': 'KS 算法 (考量X空间)',
+ 'random': '随机划分 (Random)',
+}
+
+
class Step7Panel(QWidget):
- """步骤7:采样点生成"""
+ """步骤7:机器学习建模"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@@ -25,58 +77,35 @@ class Step7Panel(QWidget):
def init_ui(self):
layout = QVBoxLayout()
- # 去耀斑影像文件(用于独立运行)
- self.deglint_img_file = FileSelectWidget(
- "去耀斑影像:",
- "Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
- )
- layout.addWidget(self.deglint_img_file)
+ # 标题
- # 水域掩膜文件(可选,用于独立运行)
- self.water_mask_file = FileSelectWidget(
- "水域掩膜:",
- "Mask Files (*.dat *.tif);;All Files (*.*)"
- )
- self.water_mask_file.label.setText("水域掩膜:")
- layout.addWidget(self.water_mask_file)
- # 参数设置
- params_group = QGroupBox("采样参数")
- params_layout = QFormLayout()
-
- self.interval = QSpinBox()
- self.interval.setRange(10, 500)
- self.interval.setValue(50)
- params_layout.addRow("采样点间隔(像素):", self.interval)
-
- self.sample_radius = QSpinBox()
- self.sample_radius.setRange(1, 50)
- self.sample_radius.setValue(5)
- params_layout.addRow("采样半径(像素):", self.sample_radius)
-
- self.chunk_size = QSpinBox()
- self.chunk_size.setRange(100, 10000)
- self.chunk_size.setValue(1000)
- params_layout.addRow("处理块大小:", self.chunk_size)
-
- self.use_adaptive_sampling = QCheckBox("启用自适应采样")
- self.use_adaptive_sampling.setChecked(True)
- params_layout.addRow("采样模式:", self.use_adaptive_sampling)
-
- params_group.setLayout(params_layout)
- layout.addWidget(params_group)
-
- # 输出文件路径
- self.output_file = FileSelectWidget(
- "输出采样点:",
+ # 训练数据文件(用于独立运行)
+ self.training_csv_file = FileSelectWidget(
+ "训练数据:",
"CSV Files (*.csv);;All Files (*.*)"
)
- self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
- layout.addWidget(self.output_file)
+ layout.addWidget(self.training_csv_file)
+
+ # 机器学习模型页面
+ self.ml_page = QWidget()
+ self.create_ml_page()
+ layout.addWidget(self.ml_page)
+
+ # 输出文件路径
+ self.output_path = FileSelectWidget(
+ "输出文件:",
+ "CSV Files (*.csv);;All Files (*.*)",
+ mode="save"
+ )
+ self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
+ self.output_path.browse_btn.clicked.disconnect()
+ self.output_path.browse_btn.clicked.connect(self.browse_output_path)
+ layout.addWidget(self.output_path)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
- self.enable_checkbox.setChecked(True)
+ self.enable_checkbox.setChecked(False)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
@@ -85,57 +114,233 @@ class Step7Panel(QWidget):
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
- # 交互式预览按钮
- self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
- self.preview_btn.setEnabled(False)
- self.preview_btn.clicked.connect(self._open_sampling_viewer)
- layout.addWidget(self.preview_btn)
-
layout.addStretch()
self.setLayout(layout)
- # 监听输出路径变化,实时更新预览按钮状态
- self.output_file.line_edit.textChanged.connect(self._on_output_changed)
+ def create_ml_page(self):
+ """创建机器学习模型页面"""
+ layout = QVBoxLayout()
+
+ # 参数设置
+ params_group = QGroupBox("训练参数")
+ params_layout = QFormLayout()
+
+ self.feature_start = QLineEdit()
+ self.feature_start.setText("374.285004")
+ params_layout.addRow("特征起始列:", self.feature_start)
+
+ self.cv_folds = QSpinBox()
+ self.cv_folds.setRange(2, 10)
+ self.cv_folds.setValue(3)
+ params_layout.addRow("交叉验证折数:", self.cv_folds)
+
+ params_group.setLayout(params_layout)
+ layout.addWidget(params_group)
+
+ # 预处理方法 - 多选
+ preproc_group = QGroupBox("预处理方法 (可多选)")
+ preproc_layout = QVBoxLayout()
+
+ preproc_grid = QGridLayout()
+ self.preproc_checkboxes = {}
+ preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
+
+ for i, method in enumerate(preproc_methods):
+ checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
+ checkbox.setChecked(False)
+ self.preproc_checkboxes[method] = checkbox
+ preproc_grid.addWidget(checkbox, i // 4, i % 4)
+
+ button_layout = QHBoxLayout()
+ select_all_btn = QPushButton("全选")
+ deselect_all_btn = QPushButton("全不选")
+ select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
+ deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
+ button_layout.addWidget(select_all_btn)
+ button_layout.addWidget(deselect_all_btn)
+ button_layout.addStretch()
+
+ preproc_layout.addLayout(preproc_grid)
+ preproc_layout.addLayout(button_layout)
+ preproc_group.setLayout(preproc_layout)
+ layout.addWidget(preproc_group)
+
+ # 模型选择 - 多选
+ model_group = QGroupBox("模型类型 (可多选)")
+ model_layout = QVBoxLayout()
+
+ model_grid = QGridLayout()
+ self.model_checkboxes = {}
+
+ model_groups = [
+ ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
+ ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
+ ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
+ ("【其他模型】", ['SVR', 'KNN', 'MLP'])
+ ]
+
+ row = 0
+ for group_name, models in model_groups:
+ group_label = QLabel(f"{group_name}")
+ group_label.setStyleSheet(
+ f"background-color: {ModernStylesheet.COLORS['hover']}; "
+ f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
+ f"border-radius: 3px;"
+ )
+ model_grid.addWidget(group_label, row, 0, 1, 4)
+ row += 1
+
+ for i, model in enumerate(models):
+ checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
+ checkbox.setChecked(False)
+ self.model_checkboxes[model] = checkbox
+ model_grid.addWidget(checkbox, row, i % 4)
+ if (i + 1) % 4 == 0:
+ row += 1
+
+ row += 1
+
+ model_button_layout = QHBoxLayout()
+ model_select_all = QPushButton("全选")
+ model_deselect_all = QPushButton("全不选")
+ model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
+ model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
+ model_button_layout.addWidget(model_select_all)
+ model_button_layout.addWidget(model_deselect_all)
+ model_button_layout.addStretch()
+
+ model_layout.addLayout(model_grid)
+ model_layout.addLayout(model_button_layout)
+ model_group.setLayout(model_layout)
+ layout.addWidget(model_group)
+
+ # 数据划分方法 - 多选
+ split_group = QGroupBox("数据划分方法 (可多选)")
+ split_layout = QVBoxLayout()
+
+ split_grid = QGridLayout()
+ self.split_checkboxes = {}
+ split_methods = ['spxy', 'ks', 'random']
+
+ for i, method in enumerate(split_methods):
+ checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
+ checkbox.setChecked(False)
+ self.split_checkboxes[method] = checkbox
+ split_grid.addWidget(checkbox, 0, i)
+
+ split_button_layout = QHBoxLayout()
+ split_select_all = QPushButton("全选")
+ split_deselect_all = QPushButton("全不选")
+ split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
+ split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
+ split_button_layout.addWidget(split_select_all)
+ split_button_layout.addWidget(split_deselect_all)
+ split_button_layout.addStretch()
+
+ split_layout.addLayout(split_grid)
+ split_layout.addLayout(split_button_layout)
+ split_group.setLayout(split_layout)
+ layout.addWidget(split_group)
+
+ self.ml_page.setLayout(layout)
+
+ def _toggle_checkboxes(self, checkboxes_dict, checked):
+ """统一设置checkbox状态"""
+ for checkbox in checkboxes_dict.values():
+ checkbox.setChecked(checked)
+
+ def _get_default_work_dir(self):
+ """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
+ if hasattr(self, 'work_dir') and self.work_dir:
+ return str(self.work_dir)
+ mw = self.window()
+ if mw and hasattr(mw, 'work_dir') and mw.work_dir:
+ return str(mw.work_dir)
+ return ""
+
+ def browse_output_path(self):
+ """浏览输出文件路径(保存对话框)"""
+ current = self.output_path.get_path().strip()
+ if current:
+ initial_dir = os.path.dirname(current)
+ initial_file = os.path.basename(current)
+ else:
+ initial_dir = ""
+ initial_file = ""
+
+ if not initial_dir or not os.path.isdir(initial_dir):
+ # 默认定位到 indices 目录
+ work_dir = self._get_default_work_dir()
+ initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
+ if initial_dir and not os.path.isdir(initial_dir):
+ os.makedirs(initial_dir, exist_ok=True)
+
+ file_path, _ = QFileDialog.getSaveFileName(
+ self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
+ "CSV Files (*.csv);;All Files (*.*)"
+ )
+ if file_path:
+ self.output_path.set_path(file_path)
def get_config(self):
"""获取配置"""
+ preprocessing_methods = [
+ method for method, checkbox in self.preproc_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ model_names = [
+ model for model, checkbox in self.model_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ split_methods = [
+ method for method, checkbox in self.split_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+
config = {
- 'interval': self.interval.value(),
- 'sample_radius': self.sample_radius.value(),
- 'chunk_size': self.chunk_size.value(),
- 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
+ 'feature_start_column': self.feature_start.text(),
+ 'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
+ 'model_names': model_names if model_names else ['SVR'],
+ 'split_methods': split_methods if split_methods else ['random'],
+ 'cv_folds': self.cv_folds.value()
}
- deglint_img_path = self.deglint_img_file.get_path()
- if deglint_img_path:
- config['deglint_img_path'] = deglint_img_path
- water_mask_path = self.water_mask_file.get_path()
- if water_mask_path:
- config['water_mask_path'] = water_mask_path
+ training_csv_path = self.training_csv_file.get_path()
+ if training_csv_path:
+ config['training_csv_path'] = training_csv_path
+ output_path = self.output_path.get_path()
+ if output_path:
+ config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
- if 'interval' in config:
- self.interval.setValue(config['interval'])
- if 'sample_radius' in config:
- self.sample_radius.setValue(config['sample_radius'])
- if 'chunk_size' in config:
- self.chunk_size.setValue(config['chunk_size'])
- if 'use_adaptive_sampling' in config:
- self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
- if 'deglint_img_path' in config:
- self.deglint_img_file.set_path(config['deglint_img_path'])
- if 'water_mask_path' in config:
- self.water_mask_file.set_path(config['water_mask_path'])
- if 'glint_mask_path' in config:
- self.glint_mask_file.set_path(config['glint_mask_path'])
+ if 'feature_start_column' in config:
+ self.feature_start.setText(str(config['feature_start_column']))
+ if 'cv_folds' in config:
+ self.cv_folds.setValue(config['cv_folds'])
+ if 'preprocessing_methods' in config:
+ methods = config['preprocessing_methods']
+ for method, checkbox in self.preproc_checkboxes.items():
+ checkbox.setChecked(method in methods)
+ if 'model_names' in config:
+ models = config['model_names']
+ for model, checkbox in self.model_checkboxes.items():
+ checkbox.setChecked(model in models)
+ if 'split_methods' in config:
+ methods = config['split_methods']
+ for method, checkbox in self.split_checkboxes.items():
+ checkbox.setChecked(method in methods)
+ if 'training_csv_path' in config:
+ self.training_csv_file.set_path(config['training_csv_path'])
+ if 'output_path' in config:
+ self.output_path.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
- """从全局配置自动填充去耀斑影像和掩膜路径
+ """从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
- pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
+ pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
if work_dir:
self.work_dir = work_dir
@@ -144,81 +349,53 @@ class Step7Panel(QWidget):
else:
self.work_dir = None
+ # 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
main_window = self.window()
+ if hasattr(main_window, 'step5_panel'):
+ # 优先直接从 Step5 的输出 widget 读取
+ step5_output = main_window.step5_panel.output_file.get_path()
+ if step5_output:
+ # 若为相对路径,使用 work_dir 合成为绝对路径
+ if not os.path.isabs(step5_output):
+ step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
+ self.training_csv_file.set_path(step5_output)
+ elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
+ # 回退:从 Step5 的 config 字典中查找可能的键名
+ step5_cfg = main_window.step5_panel.get_config()
+ step5_csv = (
+ step5_cfg.get('training_csv_path')
+ or step5_cfg.get('output_file')
+ or step5_cfg.get('csv_path')
+ or step5_cfg.get('output_csv')
+ )
+ if step5_csv:
+ # 若为相对路径,使用 work_dir 合成为绝对路径
+ if not os.path.isabs(step5_csv):
+ step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
+ self.training_csv_file.set_path(step5_csv)
- # 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
- deglint_path = None
- if pipeline and hasattr(pipeline, 'step_outputs'):
- step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
- deglint_path = (
- step3_outputs.get('deglint_image')
- or step3_outputs.get('output_path')
- or step3_outputs.get('output_file')
- or step3_outputs.get('deglint_img_path')
- )
- # 回退:从 step3 面板 widget 直接读取(可能是相对路径)
- if not deglint_path and hasattr(main_window, 'step3_panel'):
- deglint_path = main_window.step3_panel.output_file.get_path()
-
- if deglint_path:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(deglint_path):
- deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
- self.deglint_img_file.set_path(deglint_path)
-
- # 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
- water_mask_path = None
- if pipeline and hasattr(pipeline, 'step_outputs'):
- step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
- water_mask_path = (
- step1_outputs.get('water_mask')
- or step1_outputs.get('output_path')
- or step1_outputs.get('output_file')
- )
- # 回退:从 step1 面板 widget 直接读取
- if not water_mask_path and hasattr(main_window, 'step1_panel'):
- water_mask_path = main_window.step1_panel.output_file.get_path()
- # 备选:扫描 1_water_mask 目录下的 .dat 文件
- if not water_mask_path and self.work_dir:
- mask_dir = os.path.join(self.work_dir, "1_water_mask")
- if os.path.isdir(mask_dir):
- dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
- if dat_files:
- water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
- # 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
- if not water_mask_path and self.work_dir:
- input_test_dir = os.path.join(self.work_dir, "input-test")
- if os.path.isdir(input_test_dir):
- dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
- # 优先匹配 water_mask_from_shp.dat
- for f in dat_files:
- if 'water_mask_from_shp' in f.lower():
- water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
- break
- # 否则取第一个 .dat 文件
- if not water_mask_path and dat_files:
- water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
-
- if water_mask_path:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(water_mask_path):
- water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
- self.water_mask_file.set_path(water_mask_path)
-
- # 3. 自动填充输出路径(绝对路径)
+ # 2. 自动填充输出文件路径(基于工作目录和输入文件名)
+ # 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
+ # 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
if self.work_dir:
- output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
- self.output_file.set_path(output_path.replace('\\', '/'))
-
- # 4. 同步更新预览按钮状态(路径可能已自动填充)
- self._check_csv_exists()
+ indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
+ os.makedirs(indices_dir, exist_ok=True)
+ training_csv = self.training_csv_file.get_path()
+ if training_csv:
+ basename = os.path.splitext(os.path.basename(training_csv))[0]
+ output_file = f"{basename}_indices.csv"
+ else:
+ output_file = "water_quality_indices.csv"
+ output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
+ self.output_path.set_path(output_path)
+ else:
+ self.output_path.set_path("")
def run_step(self):
"""独立运行步骤7"""
- deglint_img_path = self.deglint_img_file.get_path()
- if not deglint_img_path:
- QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
+ training_csv_path = self.training_csv_file.get_path()
+ if not training_csv_path:
+ QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
return
main_window = self.window()
@@ -226,27 +403,13 @@ class Step7Panel(QWidget):
config = {'step7': self.get_config()}
main_window.run_single_step('step7', config)
- def _check_csv_exists(self):
- """检查 output csv 是否存在,驱动预览按钮启停"""
- csv_path = self.output_file.get_path()
- enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
- self.preview_btn.setEnabled(enabled)
- return enabled
-
- def _on_output_changed(self, _text=None):
- """输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)"""
- self._check_csv_exists()
-
- def _open_sampling_viewer(self):
- """打开交互式采样点查看器弹窗"""
- csv_path = self.output_file.get_path()
- if not csv_path or not os.path.exists(csv_path):
- QMessageBox.warning(
- self, "文件不存在",
- f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
- )
- return
- dialog = SamplingViewerDialog(csv_path, self)
- dialog.exec_()
- # 弹窗关闭后再次检查状态(可能文件被覆盖等)
- self._check_csv_exists()
+ def get_training_params(self):
+ """获取模型训练参数"""
+ return {
+ 'pipeline_type': 'machine_learning',
+ 'feature_start': float(self.feature_start.text()),
+ 'cv_folds': self.cv_folds.value(),
+ 'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
+ 'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
+ 'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
+ }
diff --git a/src/gui/panels/step6_5_panel.py b/src/gui/panels/step8_non_empirical_panel.py
similarity index 97%
rename from src/gui/panels/step6_5_panel.py
rename to src/gui/panels/step8_non_empirical_panel.py
index 6eddab7..25dff2a 100644
--- a/src/gui/panels/step6_5_panel.py
+++ b/src/gui/panels/step8_non_empirical_panel.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-Step6_5 面板 - 非经验统计回归建模
+Step8 面板 - 非经验统计回归建模
"""
import os
@@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
-class Step6_5Panel(QWidget):
- """步骤6.5:非经验统计回归建模"""
+class Step8NonEmpiricalPanel(QWidget):
+ """步骤8:非经验统计回归建模"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@@ -280,7 +280,7 @@ class Step6_5Panel(QWidget):
self.output_dir.set_path(dir_path)
def run_step(self):
- """独立运行步骤6.5"""
+ """独立运行步骤8"""
training_csv_path = self.training_csv_file.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
@@ -297,7 +297,7 @@ class Step6_5Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
- parent.run_single_step('step6_5', {'step6_5': config})
+ parent.run_single_step('step8_non_empirical_modeling', {'step8_non_empirical_modeling': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
diff --git a/src/gui/panels/step8_panel.py b/src/gui/panels/step8_panel.py
index 727cbb1..ab34d14 100644
--- a/src/gui/panels/step8_panel.py
+++ b/src/gui/panels/step8_panel.py
@@ -1,462 +1,225 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-Step8 面板 - 机器学习预测
-"""
-
import os
+import sys
+import pandas as pd
from pathlib import Path
+from typing import Dict, List, Union
from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QFormLayout,
- QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
- QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
- QListWidgetItem,
+ QWidget, QVBoxLayout, QGroupBox, QGridLayout,
+ QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
)
from PyQt5.QtCore import Qt
-
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
+def get_resource_path(relative_path: str) -> str:
+ """适配开发与 PyInstaller 环境的路径获取逻辑。
+ 支持两种打包模式:
+ 1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录
+ 2. --onefile 模式:文件在 sys._MEIPASS 平铺目录
+ """
+ # 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下)
+ if hasattr(sys, '_MEIPASS'):
+ internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path)
+ if os.path.exists(internal_path):
+ return internal_path
+ return os.path.join(sys._MEIPASS, relative_path)
+
+ # 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下)
+ exe_dir = os.path.dirname(sys.executable)
+ internal_path = os.path.join(exe_dir, '_internal', relative_path)
+ if os.path.exists(internal_path):
+ return internal_path
+
+ # 开发环境下:基于当前文件 (step8_panel.py) 的绝对路径进行回溯
+ # 当前在 src/gui/panels/,目标在 src/gui/model/
+ base_dir = Path(__file__).resolve().parent.parent / "model"
+ target_path = base_dir / os.path.basename(relative_path)
+ return str(target_path)
class Step8Panel(QWidget):
- """步骤8:机器学习预测"""
def __init__(self, parent=None):
super().__init__(parent)
- self.external_models_dict = {} # {subdir_name: model_obj, ...}
- self.external_model_dir = "" # 母文件夹路径(隐藏)
+ self.index_checkboxes: Dict[str, QCheckBox] = {}
+ # 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中
+ self.builtin_formula_path = get_resource_path("waterindex.csv")
+
self.init_ui()
+ # 延迟一小会儿加载,确保UI框架已就绪
+ self._auto_load_formulas()
def init_ui(self):
- layout = QVBoxLayout()
+ main_layout = QVBoxLayout()
+ main_layout.setContentsMargins(20, 20, 20, 20)
+ main_layout.setSpacing(10)
- # -------- 模型来源选择(单选按钮组) --------
- source_group = QGroupBox("模型来源")
- source_layout = QVBoxLayout()
+ # 1. 路径展示区 (半透明只读)
+ path_group = QGroupBox("公式配置源 (内置)")
+ path_layout = QVBoxLayout()
+ self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
+ self.formula_csv_widget.set_path(self.builtin_formula_path)
+ self.formula_csv_widget.set_read_only(True)
+ # 视觉微调:提示用户这是内置的
+ self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
+ path_layout.addWidget(self.formula_csv_widget)
+ path_group.setLayout(path_layout)
+ main_layout.addWidget(path_group)
- self.use_trained_model = QRadioButton("使用当前训练流程的模型")
- self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
- self.use_trained_model.setChecked(True)
- source_layout.addWidget(self.use_trained_model)
- source_layout.addWidget(self.use_external_model)
+ # 2. 训练数据输入
+ input_group = QGroupBox("输入样本数据")
+ input_layout = QVBoxLayout()
+ self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
+ input_layout.addWidget(self.training_data_widget)
+ input_group.setLayout(input_layout)
+ main_layout.addWidget(input_group)
- self.use_trained_model.toggled.connect(self._on_model_source_changed)
- self.use_external_model.toggled.connect(self._on_model_source_changed)
+ # 3. 公式选择区
+ self.formula_group = QGroupBox("待计算水质指数勾选")
+ formula_outer_layout = QVBoxLayout()
- source_group.setStyleSheet("""
- QRadioButton {
- font-size: 13px;
- spacing: 8px;
- }
- QRadioButton::indicator {
- width: 16px;
- height: 16px;
- border-radius: 9px;
- border: 2px solid #A0A0A0;
- background-color: #FFFFFF;
- }
- QRadioButton::indicator:hover {
- border: 2px solid #0078D7;
- }
- QRadioButton::indicator:checked {
- background-color: #0078D7;
- border: 2px solid #0078D7;
- }
- """)
+ btn_layout = QHBoxLayout()
+ self.select_all_btn = QPushButton("全选")
+ self.deselect_all_btn = QPushButton("清空")
+ self.select_all_btn.clicked.connect(self.select_all_formulas)
+ self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
+ btn_layout.addWidget(self.select_all_btn)
+ btn_layout.addWidget(self.deselect_all_btn)
+ btn_layout.addStretch()
- source_group.setLayout(source_layout)
- layout.addWidget(source_group)
+ self.refresh_button = QPushButton("手动重新加载公式")
+ self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
+ btn_layout.addWidget(self.refresh_button)
- # -------- 外部模型文件选择(条件显示) --------
- self.external_model_widget = FileSelectWidget(
- "模型母文件夹:",
- "Directories"
- )
- self.external_model_widget.browse_btn.clicked.disconnect()
- self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
- self.external_model_widget.setVisible(False)
- layout.addWidget(self.external_model_widget)
+ formula_outer_layout.addLayout(btn_layout)
- # -------- 已扫描模型列表(条件显示) --------
- self.model_list_group = QGroupBox("选择参与预测的模型")
- self.model_list_group.setVisible(False)
- model_list_layout = QVBoxLayout()
+ # 核心滚动区
+ scroll = QScrollArea()
+ scroll.setWidgetResizable(True)
+ scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
+ self.scroll_content = QWidget()
+ self.formula_layout = QGridLayout(self.scroll_content)
+ self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
+ scroll.setWidget(self.scroll_content)
+ formula_outer_layout.addWidget(scroll)
- self.model_list = QListWidget()
- self.model_list.setMaximumHeight(130)
- self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
- self.model_list.setStyleSheet("""
- QListWidget {
- border: 1px solid #C0C0C0;
- border-radius: 4px;
- background-color: #FFFFFF;
- font-size: 12px;
- }
- QListWidget::item {
- padding: 4px 6px;
- border-bottom: 1px solid #F0F0F0;
- }
- QListWidget::item:selected {
- background-color: transparent;
- }
- """)
- model_list_layout.addWidget(self.model_list)
+ self.formula_group.setLayout(formula_outer_layout)
+ main_layout.addWidget(self.formula_group)
- btn_row = QHBoxLayout()
- self.btn_select_all = QPushButton("全选")
- self.btn_select_all.setMaximumWidth(80)
- self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
- self.btn_select_all.clicked.connect(self._select_all_models)
+ # 4. 输出与运行
+ output_group = QGroupBox("结果输出")
+ output_layout = QVBoxLayout()
+ self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
+ output_layout.addWidget(self.output_file_widget)
+ output_group.setLayout(output_layout)
+ main_layout.addWidget(output_group)
- self.btn_select_none = QPushButton("全不选")
- self.btn_select_none.setMaximumWidth(80)
- self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
- self.btn_select_none.clicked.connect(self._select_none_models)
-
- btn_row.addWidget(self.btn_select_all)
- btn_row.addWidget(self.btn_select_none)
- btn_row.addStretch()
- model_list_layout.addLayout(btn_row)
-
- self.model_list_group.setLayout(model_list_layout)
- layout.addWidget(self.model_list_group)
-
- # -------- 采样光谱CSV文件(用于独立运行)--------
- self.sampling_csv_file = FileSelectWidget(
- "采样光谱CSV:",
- "CSV Files (*.csv);;All Files (*.*)"
- )
- layout.addWidget(self.sampling_csv_file)
-
- # 模型目录(用于独立运行)
- self.models_dir_file = FileSelectWidget(
- "模型目录:",
- "Directories;;All Files (*.*)"
- )
- self.models_dir_file.label.setText("模型目录:")
- self.models_dir_file.browse_btn.clicked.disconnect()
- self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
- layout.addWidget(self.models_dir_file)
-
- # 参数设置
- params_group = QGroupBox("预测参数")
- params_layout = QFormLayout()
-
- self.metric = QComboBox()
- self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
- params_layout.addRow("模型选择指标:", self.metric)
-
- self.prediction_column = QLineEdit()
- self.prediction_column.setText("prediction")
- params_layout.addRow("预测列名:", self.prediction_column)
-
- params_group.setLayout(params_layout)
- layout.addWidget(params_group)
-
- # 输出路径
- self.output_file = FileSelectWidget(
- "输出路径:",
- "CSV Files (*.csv);;All Files (*.*)"
- )
- layout.addWidget(self.output_file)
-
- # 启用步骤
- self.enable_checkbox = QCheckBox("启用此步骤")
+ self.enable_checkbox = QCheckBox("启用计算流程")
self.enable_checkbox.setChecked(True)
- layout.addWidget(self.enable_checkbox)
+ main_layout.addWidget(self.enable_checkbox)
- # 独立运行按钮
- self.run_btn = QPushButton("独立运行此步骤")
- self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
- self.run_btn.clicked.connect(self.run_step)
- layout.addWidget(self.run_btn)
+ self.run_button = QPushButton("立即执行计算")
+ self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
+ self.run_button.setMinimumHeight(40)
+ self.run_button.clicked.connect(self.run_step)
+ main_layout.addWidget(self.run_button)
- layout.addStretch()
- self.setLayout(layout)
+ self.setLayout(main_layout)
- def _on_model_source_changed(self, checked: bool):
- """单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
- if not checked:
+ def _auto_load_formulas(self):
+ """启动时自动加载逻辑"""
+ if os.path.exists(self.builtin_formula_path):
+ self.refresh_formulas(silent=True)
+ else:
+ print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
+
+ def refresh_formulas(self, silent=False):
+ path = self.builtin_formula_path
+ if not os.path.exists(path):
+ if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
return
- is_external = self.use_external_model.isChecked()
- self.external_model_widget.setVisible(is_external)
- self.model_list_group.setVisible(is_external)
- if not is_external:
- self.external_models_dict = {}
- self.external_model_dir = ""
- self._clear_model_list()
-
- def _scan_external_model_dir(self):
- """浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
- default = self._get_default_work_dir()
- if default:
- default = os.path.join(default, "7_Supervised_Model_Training")
- dir_path = QFileDialog.getExistingDirectory(
- self,
- "选择模型母文件夹",
- default,
- )
- if not dir_path:
- return
-
- self.external_model_dir = dir_path
- models_found = {}
- errors = []
try:
- import joblib
+ # 清理旧列表
+ for i in reversed(range(self.formula_layout.count())):
+ widget = self.formula_layout.itemAt(i).widget()
+ if widget: widget.deleteLater()
+ self.index_checkboxes.clear()
- for subentry in os.scandir(dir_path):
- if not subentry.is_dir():
- continue
- subdir_name = subentry.name
- joblib_files = [
- f for f in os.scandir(subentry.path)
- if f.is_file() and f.name.lower().endswith(".joblib")
- ]
- if not joblib_files:
- continue
- # 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
- joblib_path = joblib_files[0].path
+ # 鲁棒性读取:尝试不同编码
+ for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
try:
- loaded = joblib.load(joblib_path)
- if isinstance(loaded, dict) and "model" in loaded:
- model_obj = loaded["model"]
- elif hasattr(loaded, "predict"):
- model_obj = loaded
- else:
- errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
- continue
- models_found[subdir_name] = model_obj
- except Exception as e:
- errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
+ df = pd.read_csv(path, encoding=encoding)
+ if 'Formula_Name' in df.columns: break
+ except: continue
+
+ if 'Formula_Name' not in df.columns:
+ if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列")
+ return
+
+ names = df['Formula_Name'].dropna().unique().tolist()
+
+ row, col = 0, 0
+ for name in names:
+ name = str(name).strip()
+ if not name: continue
+ cb = QCheckBox(name)
+ cb.setChecked(True)
+ self.index_checkboxes[name] = cb
+ self.formula_layout.addWidget(cb, row, col)
+ col += 1
+ if col >= 3:
+ col = 0
+ row += 1
+
+ # 强制UI更新
+ self.scroll_content.adjustSize()
+ print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
except Exception as e:
- QMessageBox.warning(
- self,
- "扫描失败",
- f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
- )
- return
+ if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
- if not models_found:
- QMessageBox.warning(
- self,
- "未找到模型",
- f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
- "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
- )
- self.external_model_widget.set_path("")
- self.external_models_dict = {}
- self._clear_model_list()
- return
+ def select_all_formulas(self):
+ for cb in self.index_checkboxes.values(): cb.setChecked(True)
- self.external_models_dict = models_found
- self._populate_model_list(models_found)
- names = sorted(models_found.keys())
- display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
- self.external_model_widget.set_path(display)
- self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
-
- err_lines = "\n".join(errors) if errors else "无"
- QMessageBox.information(
- self,
- "模型扫描完成",
- f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
- f"加载失败 {len(errors)} 个:\n{err_lines}",
- )
-
- def _populate_model_list(self, models_dict):
- """将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选"""
- self.model_list.clear()
- for name in sorted(models_dict.keys()):
- item = QListWidgetItem(name)
- item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
- item.setCheckState(Qt.Checked)
- self.model_list.addItem(item)
-
- def _clear_model_list(self):
- """清空模型列表"""
- self.model_list.clear()
-
- def _select_all_models(self):
- """全选:设置所有条目为 Checked"""
- for i in range(self.model_list.count()):
- self.model_list.item(i).setCheckState(Qt.Checked)
-
- def _select_none_models(self):
- """全不选:设置所有条目为 Unchecked"""
- for i in range(self.model_list.count()):
- self.model_list.item(i).setCheckState(Qt.Unchecked)
-
- def _get_checked_models_dict(self):
- """从列表中提取用户勾选的模型,组装成字典返回"""
- result = {}
- for i in range(self.model_list.count()):
- item = self.model_list.item(i)
- if item.checkState() == Qt.Checked:
- name = item.text()
- if name in self.external_models_dict:
- result[name] = self.external_models_dict[name]
- return result
-
- def update_from_config(self, work_dir=None, pipeline=None):
- """从全局配置自动填充采样光谱和模型目录
-
- Args:
- work_dir: 工作目录路径
- pipeline: Pipeline 实例(未使用,保留接口兼容性)
- """
- try:
- import traceback
-
- if work_dir:
- self.work_dir = work_dir
- elif hasattr(self, 'work_dir') and self.work_dir:
- pass
- else:
- self.work_dir = None
-
- main_window = self.window()
-
- # 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
- if main_window and hasattr(main_window, 'step7_panel'):
- step7_widget = getattr(main_window.step7_panel, 'output_file', None)
- step7_output_path = ""
- if hasattr(step7_widget, 'get_path'):
- step7_output_path = step7_widget.get_path() or ""
- elif hasattr(step7_widget, 'text'):
- step7_output_path = step7_widget.text() or ""
-
- if step7_output_path:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step7_output_path):
- step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
- existing = self.sampling_csv_file.get_path()
- if not existing or not existing.strip():
- self.sampling_csv_file.set_path(step7_output_path)
-
- # 2. 尝试从 Step6 界面读取监督模型目录
- if main_window and hasattr(main_window, 'step6_panel'):
- step6_widget = getattr(main_window.step6_panel, 'output_dir', None)
- step6_models_dir = ""
- if hasattr(step6_widget, 'get_path'):
- step6_models_dir = step6_widget.get_path() or ""
- elif hasattr(step6_widget, 'text'):
- step6_models_dir = step6_widget.text() or ""
-
- if step6_models_dir:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step6_models_dir):
- step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
- existing_models = self.models_dir_file.get_path()
- if not existing_models or not existing_models.strip():
- self.models_dir_file.set_path(step6_models_dir)
-
- # 3. 自动填充输出路径(机器学习预测目录)
- if self.work_dir:
- output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
- os.makedirs(output_dir, exist_ok=True)
- existing_out = self.output_file.get_path()
- if not existing_out or not existing_out.strip():
- self.output_file.set_path(output_dir)
- except Exception as e:
- import traceback
- print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
- traceback.print_exc()
-
- def _get_default_work_dir(self):
- """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
- if hasattr(self, 'work_dir') and self.work_dir:
- return str(self.work_dir)
- mw = self.window()
- if mw and hasattr(mw, 'work_dir') and mw.work_dir:
- return str(mw.work_dir)
- return ""
-
- def browse_models_dir(self):
- """浏览模型目录"""
- default = self._get_default_work_dir()
- if default:
- default = os.path.join(default, "7_Supervised_Model_Training")
- dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
- if dir_path:
- self.models_dir_file.set_path(dir_path)
+ def deselect_all_formulas(self):
+ for cb in self.index_checkboxes.values(): cb.setChecked(False)
def get_config(self):
- """获取配置"""
- config = {
- 'metric': self.metric.currentText(),
- 'prediction_column': self.prediction_column.text(),
+ selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
+ return {
+ 'training_csv_path': self.training_data_widget.get_path(),
+ 'formula_csv_file': self.builtin_formula_path,
+ 'formula_names': selected,
+ 'output_file': self.output_file_widget.get_path(),
+ 'enabled': self.enable_checkbox.isChecked()
}
- sampling_csv_path = self.sampling_csv_file.get_path()
- if sampling_csv_path:
- config['sampling_csv_path'] = sampling_csv_path
- models_dir = self.models_dir_file.get_path()
- if models_dir:
- config['models_dir'] = models_dir
- output_path = self.output_file.get_path()
- if output_path:
- config['output_path'] = output_path
- return config
def set_config(self, config):
- """设置配置"""
- if 'metric' in config:
- idx = self.metric.findText(config['metric'])
- if idx >= 0:
- self.metric.setCurrentIndex(idx)
- if 'prediction_column' in config:
- self.prediction_column.setText(config['prediction_column'])
- if 'sampling_csv_path' in config:
- self.sampling_csv_file.set_path(config['sampling_csv_path'])
- if 'models_dir' in config:
- self.models_dir_file.set_path(config['models_dir'])
- if 'output_path' in config:
- self.output_file.set_path(config['output_path'])
+ if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
+ if 'formula_names' in config:
+ sel = set(config['formula_names'])
+ for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
+ if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
+ self.enable_checkbox.setChecked(config.get('enabled', True))
+
+ def update_from_config(self, work_dir=None, pipeline=None):
+ if work_dir: self.work_dir = work_dir
+ main = self.window()
+ if hasattr(main, 'step5_panel'):
+ p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐
+ if p5:
+ if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
+ self.training_data_widget.set_path(p5)
+
+ if self.work_dir:
+ out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/')
+ self.output_file_widget.set_path(out)
def run_step(self):
- """独立运行步骤8"""
- sampling_csv_path = self.sampling_csv_file.get_path()
- if not sampling_csv_path:
- QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
+ config = self.get_config()
+ if not config['training_csv_path']:
+ QMessageBox.warning(self, "提示", "请先选择输入数据")
return
-
- # 外部模型优先:用户选择了"导入本地预训练模型"
- if self.use_external_model.isChecked():
- if not self.external_models_dict:
- QMessageBox.warning(
- self,
- "模型未加载",
- "请先点击「浏览...」按钮选择模型母文件夹!",
- )
- return
- # 只传递用户勾选的模型
- checked_dict = self._get_checked_models_dict()
- if not checked_dict:
- QMessageBox.warning(
- self,
- "未选择模型",
- "请至少勾选一个模型参与预测!",
- )
- return
- main_window = self.window()
- if hasattr(main_window, 'run_single_step'):
- config = {
- 'step8': self.get_config(),
- '_external_models_dict': checked_dict,
- '_external_model_dir': self.external_model_dir,
- }
- main_window.run_single_step('step8', config)
- return
-
- # 默认流程:使用模型目录
- models_dir = self.models_dir_file.get_path()
- if not models_dir:
- QMessageBox.warning(self, "输入错误", "请选择模型目录!")
- return
-
- main_window = self.window()
- if hasattr(main_window, 'run_single_step'):
- config = {'step8': self.get_config()}
- main_window.run_single_step('step8', config)
+ parent = self.parent()
+ while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
+ if parent: parent.run_single_step('step8', {'step8': config})
\ No newline at end of file
diff --git a/src/gui/panels/step9_panel.py b/src/gui/panels/step9_panel.py
index c35a9d1..335eb44 100644
--- a/src/gui/panels/step9_panel.py
+++ b/src/gui/panels/step9_panel.py
@@ -1,206 +1,158 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
-Step9 面板 - 分布图生成
+Step9 面板 - 自定义回归分析
"""
import os
-import traceback
-from pathlib import Path
-from typing import List, Optional
+from typing import Dict
-from PyQt5.QtCore import Qt, QThread, pyqtSignal
+import pandas as pd
from PyQt5.QtWidgets import (
- QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
- QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
- QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
+ QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
+ QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
+ QScrollArea, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
-# Pipeline 可用性(与 core/worker_thread.py 保持一致)
-try:
- from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
- PIPELINE_AVAILABLE = True
-except ImportError:
- PIPELINE_AVAILABLE = False
-
-
-class Step9BatchThread(QThread):
- """专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
-
- finished_ok = pyqtSignal(int)
- failed = pyqtSignal(str)
- log_message = pyqtSignal(str, str)
-
- def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
- super().__init__()
- self.work_dir = work_dir
- self.csv_paths = csv_paths
- self.step9_kwargs = step9_kwargs
- self.output_dir_optional = (output_dir_optional or "").strip() or None
-
- def run(self):
- mpl_prev = None
- try:
- import matplotlib
- mpl_prev = matplotlib.get_backend()
- except Exception:
- pass
- try:
- import matplotlib.pyplot as plt
- plt.switch_backend("Agg")
- except Exception:
- mpl_prev = None
- try:
- from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
- pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
- n = len(self.csv_paths)
- for i, csv_p in enumerate(self.csv_paths):
- self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
- kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
- if self.output_dir_optional:
- stem = Path(csv_p).stem
- kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
- else:
- kw["output_image_path"] = None
- pipeline.step9_generate_distribution_map(**kw)
- self.finished_ok.emit(n)
- except Exception as e:
- self.failed.emit(f"{e}\n{traceback.format_exc()}")
- finally:
- if mpl_prev:
- try:
- import matplotlib.pyplot as plt
- plt.switch_backend(mpl_prev)
- except Exception:
- pass
-
class Step9Panel(QWidget):
- """步骤9:分布图生成"""
+ """步骤9:自定义回归分析"""
def __init__(self, parent=None):
super().__init__(parent)
- self._batch_thread = None
+ self.x_column_checkboxes: Dict[str, QCheckBox] = {}
+ self.y_column_checkboxes: Dict[str, QCheckBox] = {}
+ self.method_checkboxes: Dict[str, QCheckBox] = {}
+ self.csv_columns = []
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
- hint = QLabel(
- "独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
- "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
- )
- hint.setWordWrap(True)
- hint.setStyleSheet(
- f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
- )
+ hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
+ hint.setStyleSheet("color: #666; font-size: 11px;")
layout.addWidget(hint)
- mode_row = QHBoxLayout()
- self.mode_single_rb = QRadioButton("单个 CSV 文件")
- self.mode_folder_rb = QRadioButton("文件夹批量")
- self._mode_group = QButtonGroup(self)
- self._mode_group.addButton(self.mode_single_rb, 0)
- self._mode_group.addButton(self.mode_folder_rb, 1)
- mode_row.addWidget(self.mode_single_rb)
- mode_row.addWidget(self.mode_folder_rb)
- mode_row.addStretch()
- layout.addLayout(mode_row)
+ # CSV文件选择
+ csv_group = QGroupBox("数据文件")
+ csv_layout = QVBoxLayout()
- # ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
- radio_style = """
- QRadioButton {
- font-size: 14px;
- spacing: 8px;
- color: #333333;
- }
- QRadioButton::indicator {
- width: 16px;
- height: 16px;
- border: 2px solid #999999;
- border-radius: 3px;
- background-color: white;
- }
- QRadioButton::indicator:checked {
- border: 2px solid #0078d4;
- background-color: #0078d4;
- image: none;
- }
- QRadioButton::indicator:hover {
- border: 2px solid #005a9e;
- }
- """
- self.mode_single_rb.setStyleSheet(radio_style)
- self.mode_folder_rb.setStyleSheet(radio_style)
-
- self.prediction_csv_file = FileSelectWidget(
- "预测结果CSV:",
+ self.csv_file = FileSelectWidget(
+ "输入CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
- layout.addWidget(self.prediction_csv_file)
+ self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
+ csv_layout.addWidget(self.csv_file)
- folder_row = QHBoxLayout()
- self.prediction_csv_dir_label = QLabel("预测CSV目录:")
- self.prediction_csv_dir_label.setMinimumWidth(120)
- self.prediction_csv_dir_edit = QLineEdit()
- self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
- pred_dir_btn = QPushButton("浏览…")
- pred_dir_btn.setMaximumWidth(80)
- pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
- folder_row.addWidget(self.prediction_csv_dir_label)
- folder_row.addWidget(self.prediction_csv_dir_edit, 1)
- folder_row.addWidget(pred_dir_btn)
- self._folder_row_widget = QWidget()
- self._folder_row_widget.setLayout(folder_row)
- layout.addWidget(self._folder_row_widget)
+ self.refresh_btn = QPushButton("刷新列信息")
+ self.refresh_btn.clicked.connect(self.refresh_csv_columns)
+ csv_layout.addWidget(self.refresh_btn)
- self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
- layout.addWidget(self.recursive_csv_cb)
+ csv_group.setLayout(csv_layout)
+ layout.addWidget(csv_group)
- self.boundary_file = FileSelectWidget(
- "边界文件:",
- "Shapefiles (*.shp);;All Files (*.*)"
- )
- layout.addWidget(self.boundary_file)
+ # 自变量选择
+ x_group = QGroupBox("自变量列选择 (可多选)")
+ x_layout = QVBoxLayout()
- # 参数设置
- params_group = QGroupBox("生成参数")
- params_layout = QFormLayout()
+ x_scroll = QScrollArea()
+ x_scroll.setWidgetResizable(True)
+ x_scroll.setMinimumHeight(250)
+ x_scroll.setMaximumHeight(350)
- self.resolution = QDoubleSpinBox()
- self.resolution.setRange(1, 1000)
- self.resolution.setValue(30)
- params_layout.addRow("分辨率(米):", self.resolution)
+ x_widget = QWidget()
+ self.x_columns_layout = QGridLayout()
+ x_widget.setLayout(self.x_columns_layout)
- self.input_crs = QLineEdit()
- self.input_crs.setText("EPSG:32651")
- params_layout.addRow("输入坐标系:", self.input_crs)
+ x_scroll.setWidget(x_widget)
+ x_layout.addWidget(x_scroll)
- self.output_crs = QLineEdit()
- self.output_crs.setText("EPSG:4326")
- params_layout.addRow("输出坐标系:", self.output_crs)
+ x_btn_layout = QHBoxLayout()
+ self.x_select_all = QPushButton("全选")
+ self.x_deselect_all = QPushButton("全不选")
+ self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
+ self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
+ x_btn_layout.addWidget(self.x_select_all)
+ x_btn_layout.addWidget(self.x_deselect_all)
+ x_btn_layout.addStretch()
+ x_layout.addLayout(x_btn_layout)
- self.show_points = QCheckBox("显示采样点")
- params_layout.addRow("", self.show_points)
+ x_group.setLayout(x_layout)
+ layout.addWidget(x_group)
- self.use_diffusion = QCheckBox("启用距离扩散")
- self.use_diffusion.setChecked(True)
- params_layout.addRow("", self.use_diffusion)
+ # 因变量选择
+ y_group = QGroupBox("因变量列选择 (可多选)")
+ y_layout = QVBoxLayout()
- params_group.setLayout(params_layout)
- layout.addWidget(params_group)
+ y_scroll = QScrollArea()
+ y_scroll.setWidgetResizable(True)
+ y_scroll.setMinimumHeight(200)
+ y_scroll.setMaximumHeight(300)
+
+ y_widget = QWidget()
+ self.y_columns_layout = QGridLayout()
+ y_widget.setLayout(self.y_columns_layout)
+
+ y_scroll.setWidget(y_widget)
+ y_layout.addWidget(y_scroll)
+
+ y_btn_layout = QHBoxLayout()
+ self.y_select_all = QPushButton("全选")
+ self.y_deselect_all = QPushButton("全不选")
+ self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
+ self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
+ y_btn_layout.addWidget(self.y_select_all)
+ y_btn_layout.addWidget(self.y_deselect_all)
+ y_btn_layout.addStretch()
+ y_layout.addLayout(y_btn_layout)
+
+ y_group.setLayout(y_layout)
+ layout.addWidget(y_group)
+
+ # 回归方法选择
+ method_group = QGroupBox("回归方法选择 (可多选)")
+ method_layout = QVBoxLayout()
+
+ method_grid = QGridLayout()
+ regression_methods = [
+ 'linear', 'exponential', 'power', 'logarithmic',
+ 'polynomial', 'hyperbolic', 'sigmoidal'
+ ]
+
+ for i, method in enumerate(regression_methods):
+ checkbox = QCheckBox(method)
+ if method in ['linear', 'exponential', 'power', 'logarithmic']:
+ checkbox.setChecked(True)
+ self.method_checkboxes[method] = checkbox
+ method_grid.addWidget(checkbox, i // 3, i % 3)
+
+ method_layout.addLayout(method_grid)
+
+ method_btn_layout = QHBoxLayout()
+ self.method_select_all = QPushButton("全选")
+ self.method_deselect_all = QPushButton("全不选")
+ self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
+ self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
+ method_btn_layout.addWidget(self.method_select_all)
+ method_btn_layout.addWidget(self.method_deselect_all)
+ method_btn_layout.addStretch()
+ method_layout.addLayout(method_btn_layout)
+
+ method_group.setLayout(method_layout)
+ layout.addWidget(method_group)
# 输出目录
- self.output_dir = FileSelectWidget(
- "输出分布图目录:",
- "Directories;;All Files (*.*)"
- )
- self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
- self.output_dir.browse_btn.clicked.disconnect()
- self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
- layout.addWidget(self.output_dir)
+ output_group = QGroupBox("输出设置")
+ output_layout = QFormLayout()
+
+ self.output_dir = QLineEdit()
+ self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
+ output_layout.addRow("输出目录名:", self.output_dir)
+
+ output_group.setLayout(output_layout)
+ layout.addWidget(output_group)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
@@ -216,119 +168,120 @@ class Step9Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
- # 信号绑定与初始状态
- self.mode_single_rb.toggled.connect(self._toggle_input_mode)
- self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
- self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
- self._toggle_input_mode() # 根据默认值设置初始显示状态
+ def toggle_checkboxes(self, checkboxes_dict, checked):
+ """统一设置checkbox状态"""
+ for checkbox in checkboxes_dict.values():
+ checkbox.setChecked(checked)
- def _toggle_input_mode(self):
- """槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
- folder_mode = self.mode_folder_rb.isChecked()
- # 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
- self.prediction_csv_file.setVisible(not folder_mode)
- # 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
- self._folder_row_widget.setVisible(folder_mode)
- self.recursive_csv_cb.setVisible(folder_mode)
+ def on_csv_file_changed(self):
+ """CSV文件改变时自动刷新列信息"""
+ self.refresh_csv_columns()
- def _get_default_work_dir(self):
- """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
- if hasattr(self, 'work_dir') and self.work_dir:
- return str(self.work_dir)
- mw = self.window()
- if mw and hasattr(mw, 'work_dir') and mw.work_dir:
- return str(mw.work_dir)
- return ""
+ def refresh_csv_columns(self):
+ """刷新CSV文件的列信息"""
+ csv_path = self.csv_file.get_path()
+ if not csv_path or not os.path.isfile(csv_path):  # isfile, not exists: a directory would pass exists() and make read_csv fail
+ self.csv_columns = []
+ self.update_column_widgets()
+ return
- def browse_prediction_csv_dir(self):
- default = self._get_default_work_dir()
- if default:
- default = os.path.join(default, "11_12_13_predictions")
- d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
- if d:
- self.prediction_csv_dir_edit.setText(d)
+ try:
+ df = pd.read_csv(csv_path, nrows=0)
+ self.csv_columns = list(df.columns)
+ self.update_column_widgets()
+ except Exception as e:
+ self.csv_columns = []
+ self.update_column_widgets()
+ print(f"读取CSV列信息失败: {e}")
- def _collect_csv_paths_from_folder(self) -> List[str]:
- folder = (self.prediction_csv_dir_edit.text() or "").strip()
- if not folder or not os.path.isdir(folder):
- return []
- root = Path(folder)
- if self.recursive_csv_cb.isChecked():
- files = sorted(root.rglob("*.csv"))
- else:
- files = sorted(root.glob("*.csv"))
- return [str(p) for p in files if p.is_file()]
+ def update_column_widgets(self):
+ """更新列选择组件"""
+ for checkbox in self.x_column_checkboxes.values():
+ checkbox.setParent(None)
+ self.x_column_checkboxes.clear()
- def _step9_base_pipeline_kwargs(self) -> dict:
- return {
- 'boundary_shp_path': self.boundary_file.get_path(),
- 'resolution': self.resolution.value(),
- 'input_crs': self.input_crs.text(),
- 'output_crs': self.output_crs.text(),
- 'show_sample_points': self.show_points.isChecked(),
- 'use_distance_diffusion': self.use_diffusion.isChecked(),
- }
+ for checkbox in self.y_column_checkboxes.values():
+ checkbox.setParent(None)
+ self.y_column_checkboxes.clear()
+
+ if not self.csv_columns:
+ return
+
+ for i, col in enumerate(self.csv_columns):
+ checkbox = QCheckBox(col)
+ if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
+ checkbox.setChecked(True)
+ self.x_column_checkboxes[col] = checkbox
+ self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
+
+ for i, col in enumerate(self.csv_columns):
+ checkbox = QCheckBox(col)
+ if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
+ checkbox.setChecked(True)
+ self.y_column_checkboxes[col] = checkbox
+ self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
+
+ self.x_columns_layout.update()
+ self.y_columns_layout.update()
def get_config(self):
- pred_csv = (self.prediction_csv_file.get_path() or "").strip()
- folder_mode = self.mode_folder_rb.isChecked()
- pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
- config = {
- 'step9_batch_mode': 'folder' if folder_mode else 'single',
- 'prediction_csv_dir': pred_dir if pred_dir else None,
- 'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
- 'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
- 'boundary_shp_path': self.boundary_file.get_path(),
- 'resolution': self.resolution.value(),
- 'input_crs': self.input_crs.text(),
- 'output_crs': self.output_crs.text(),
- 'show_sample_points': self.show_points.isChecked(),
- 'use_distance_diffusion': self.use_diffusion.isChecked(),
+ selected_x_columns = [
+ col for col, checkbox in self.x_column_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ selected_y_columns = [
+ col for col, checkbox in self.y_column_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ selected_methods = [
+ method for method, checkbox in self.method_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ if not selected_methods:
+ selected_methods = 'all'
+
+ return {
+ 'csv_path': self.csv_file.get_path() or None,
+ 'x_columns': selected_x_columns,
+ 'y_columns': selected_y_columns,
+ 'methods': selected_methods,
+ 'output_dir': self.output_dir.text().strip() or None,
+ 'enabled': self.enable_checkbox.isChecked()
}
- out_dir = (self.output_dir.get_path() or "").strip()
- if not folder_mode and pred_csv and out_dir:
- stem = Path(pred_csv).stem
- config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
- else:
- config['output_image_path'] = None
- return config
def set_config(self, config):
- mode = config.get('step9_batch_mode', 'single')
- if mode == 'folder':
- self.mode_folder_rb.setChecked(True)
- else:
- self.mode_single_rb.setChecked(True)
- if config.get('prediction_csv_dir'):
- self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
- if 'recursive_csv_scan' in config:
- self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
- if 'prediction_csv_path' in config and config['prediction_csv_path']:
- self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
- if 'boundary_shp_path' in config:
- self.boundary_file.set_path(config['boundary_shp_path'])
- if 'resolution' in config:
- self.resolution.setValue(config['resolution'])
- if 'input_crs' in config:
- self.input_crs.setText(config['input_crs'])
- if 'output_crs' in config:
- self.output_crs.setText(config['output_crs'])
- if 'show_sample_points' in config:
- self.show_points.setChecked(config['show_sample_points'])
- if 'use_distance_diffusion' in config:
- self.use_diffusion.setChecked(config['use_distance_diffusion'])
- if 'output_dir' in config and config['output_dir']:
- self.output_dir.set_path(str(config['output_dir']))
- elif config.get('output_image_path'):
- p = Path(str(config['output_image_path']))
- if p.parent and str(p.parent) != '.':
- self.output_dir.set_path(str(p.parent))
+ if 'csv_path' in config:
+ self.csv_file.set_path(config['csv_path'] or "")  # get_config() stores None for an empty path; never hand None to the widget
+ self.refresh_csv_columns()
+
+ if 'x_columns' in config:
+ selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
+ for col, checkbox in self.x_column_checkboxes.items():
+ checkbox.setChecked(col in selected_x)
+
+ if 'y_columns' in config:
+ selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
+ for col, checkbox in self.y_column_checkboxes.items():
+ checkbox.setChecked(col in selected_y)
+
+ if 'methods' in config:
+ methods = config['methods']
+ if isinstance(methods, list):
+ selected_methods = set(methods)
+ elif methods == 'all':
+ selected_methods = set(self.method_checkboxes.keys())
+ else:
+ selected_methods = set()
+ for method, checkbox in self.method_checkboxes.items():
+ checkbox.setChecked(method in selected_methods)
+
+ if 'output_dir' in config:
+ self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
+ if 'enabled' in config:
+ self.enable_checkbox.setChecked(config['enabled'])
def update_from_config(self, work_dir=None, pipeline=None):
- """从全局配置自动填充预测结果目录
-
- 优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录;
- 其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
+ """从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
@@ -344,190 +297,78 @@ class Step9Panel(QWidget):
else:
self.work_dir = None
+ # 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
main_window = self.window()
- if not main_window:
- return
+ if main_window and hasattr(main_window, 'step5_panel'):
+ step5_widget = getattr(main_window.step5_panel, 'output_file', None)
+ step5_output_path = ""
+ if hasattr(step5_widget, 'get_path'):
+ step5_output_path = step5_widget.get_path() or ""
+ elif hasattr(step5_widget, 'text'):
+ step5_output_path = step5_widget.text() or ""
- # 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
- pred_dir = None
- if hasattr(main_window, 'step8_panel'):
- step8_widget = getattr(main_window.step8_panel, 'output_file', None)
- step8_output = ""
- if hasattr(step8_widget, 'get_path'):
- step8_output = step8_widget.get_path() or ""
- elif hasattr(step8_widget, 'text'):
- step8_output = step8_widget.text() or ""
-
- if step8_output:
+ if step5_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step8_output):
- step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
- # 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
- base_pred_dir = str(Path(step8_output).parent)
- ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
- pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
+ if not os.path.isabs(step5_output_path):
+ step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
+ existing = self.csv_file.get_path()
+ if not existing or not existing.strip():
+ self.csv_file.set_path(step5_output_path)
- # 2. 备选:从 Step8.5 界面读取非经验预测输出目录
- if not pred_dir and hasattr(main_window, 'step8_5_panel'):
- step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
- step8_5_output = ""
- if hasattr(step8_5_widget, 'get_path'):
- step8_5_output = step8_5_widget.get_path() or ""
- elif hasattr(step8_5_widget, 'text'):
- step8_5_output = step8_5_widget.text() or ""
-
- if step8_5_output:
- # 若为相对路径,使用 work_dir 合成为绝对路径
- if not os.path.isabs(step8_5_output):
- step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
- pred_dir = str(Path(step8_5_output).parent)
-
- # 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录
- if not pred_dir and hasattr(main_window, 'step8_75_panel'):
- step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
- step8_75_output = ""
- if hasattr(step8_75_widget, 'get_path'):
- step8_75_output = step8_75_widget.get_path() or ""
- elif hasattr(step8_75_widget, 'text'):
- step8_75_output = step8_75_widget.text() or ""
-
- if step8_75_output:
- pred_dir = step8_75_output
-
- # 自动填入"预测CSV目录"(文件夹批量模式)
- if pred_dir:
- existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
- if not existing_dir:
- self.prediction_csv_dir_edit.setText(pred_dir)
- # 切换到文件夹批量模式
- self.mode_folder_rb.setChecked(True)
-
- # 4. 自动填充输出目录(14_visualization)
+ # 2. 自动填充输出目录(9_Custom_Regression_Modeling)
if self.work_dir:
- output_dir = os.path.join(self.work_dir, "14_visualization")
+ output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
os.makedirs(output_dir, exist_ok=True)
- existing_out = self.output_dir.get_path()
- if not existing_out or not existing_out.strip():
- self.output_dir.set_path(output_dir)
-
- # 5. 自动探测原始矢量边界文件(.shp)作为专题图底图
- # 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式
- if self.work_dir:
- possible_shp = None
- candidates = [
- Path(self.work_dir).parent / "input-test" / "roi.shp",
- Path(self.work_dir) / "roi.shp",
- Path(self.work_dir).parent / "roi.shp",
- ]
- for candidate in candidates:
- if candidate.exists() and candidate.suffix.lower() == ".shp":
- possible_shp = candidate
- break
-
- existing_boundary = (self.boundary_file.get_path() or "").strip()
- if not existing_boundary and possible_shp:
- self.boundary_file.set_path(str(possible_shp))
- elif not existing_boundary:
- # 未找到 .shp 时清空并提示用户手动选择矢量文件
- self.boundary_file.set_path("")
- print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
+ existing_out = self.output_dir.text().strip()
+ if not existing_out:
+ self.output_dir.setText(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
- def browse_output_dir(self):
- """浏览输出目录"""
- default = self._get_default_work_dir()
- if default:
- default = os.path.join(default, "14_visualization")
- dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
- if dir_path:
- self.output_dir.set_path(dir_path)
-
def run_step(self):
"""独立运行步骤9"""
- if self._batch_thread and self._batch_thread.isRunning():
- QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
+ csv_path = self.csv_file.get_path()
+
+ if not csv_path:
+ QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
+ return
+ if not os.path.isfile(csv_path):  # reject directories as well as missing paths
+ QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
return
- boundary_shp_path = self.boundary_file.get_path()
- if not boundary_shp_path:
- QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
+ selected_x_columns = [
+ col for col, checkbox in self.x_column_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ if not selected_x_columns:
+ QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
return
- if not os.path.exists(boundary_shp_path):
- QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
+
+ selected_y_columns = [
+ col for col, checkbox in self.y_column_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ if not selected_y_columns:
+ QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
return
+ selected_methods = [
+ method for method, checkbox in self.method_checkboxes.items()
+ if checkbox.isChecked()
+ ]
+ if not selected_methods:
+ QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
+ return
+
+ config = self.get_config()
+
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
- if not parent or not hasattr(parent, 'run_single_step'):
+ if parent and hasattr(parent, 'run_single_step'):
+ parent.run_single_step('step9', {'step9': config})
+ else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
- return
-
- if self.mode_folder_rb.isChecked():
- csv_list = self._collect_csv_paths_from_folder()
- if not csv_list:
- QMessageBox.warning(
- self,
- "输入验证失败",
- "所选文件夹中未找到 .csv 文件,或目录无效。\n"
- "可勾选「包含子文件夹」以递归扫描。",
- )
- return
- if not PIPELINE_AVAILABLE:
- QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
- return
- work_dir = getattr(parent, "work_dir", None) or "./work_dir"
- work_dir = str(work_dir)
- base_kw = self._step9_base_pipeline_kwargs()
- out_dir_opt = (self.output_dir.get_path() or "").strip() or None
- self.run_button.setEnabled(False)
- self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
- main_win = parent
-
- def _batch_log(msg, lvl):
- if hasattr(main_win, "log_message"):
- main_win.log_message(msg, lvl)
-
- self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
- self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
- self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
- self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
- self._batch_thread.start()
- if hasattr(parent, "log_message"):
- parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
- return
-
- prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
- if not prediction_csv_path:
- QMessageBox.warning(
- self,
- "输入验证失败",
- "请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
- )
- return
- if not os.path.isfile(prediction_csv_path):
- QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
- return
-
- config = self.get_config()
- parent.run_single_step('step9', {'step9': config})
-
- def _on_step9_batch_ok(self, n: int):
- QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
- parent = self.parent()
- while parent and not hasattr(parent, "log_message"):
- parent = parent.parent()
- if parent and hasattr(parent, "log_message"):
- parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
-
- def _on_step9_batch_fail(self, err: str):
- QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
- parent = self.parent()
- while parent and not hasattr(parent, "log_message"):
- parent = parent.parent()
- if parent and hasattr(parent, "log_message"):
- parent.log_message(err, "error")
diff --git a/src/gui/panels/visualization_panel.py b/src/gui/panels/visualization_panel.py
index 0c5300a..7064ed8 100644
--- a/src/gui/panels/visualization_panel.py
+++ b/src/gui/panels/visualization_panel.py
@@ -1567,12 +1567,12 @@ class VisualizationPanel(QWidget):
ml_dir.mkdir(parents=True, exist_ok=True)
reg_dir.mkdir(parents=True, exist_ok=True)
custom_dir.mkdir(parents=True, exist_ok=True)
- if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
- self.step8_panel.output_file.set_path(str(ml_dir))
- if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
- self.step8_5_panel.output_file.set_path(str(reg_dir))
- if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
- self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
+ if hasattr(self, 'step11_ml_panel') and hasattr(self.step11_ml_panel, 'output_file'):
+ self.step11_ml_panel.output_file.set_path(str(ml_dir))
+ if hasattr(self, 'step11_panel') and hasattr(self.step11_panel, 'output_file'):
+ self.step11_panel.output_file.set_path(str(reg_dir))
+ if hasattr(self, 'step12_panel') and hasattr(self.step12_panel, 'output_dir_widget'):
+ self.step12_panel.output_dir_widget.set_path(str(custom_dir))
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
except Exception as e:
print(f"设置预测输出目录失败: {e}")
diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py
index 0b88f03..6526866 100644
--- a/src/gui/water_quality_gui.py
+++ b/src/gui/water_quality_gui.py
@@ -119,19 +119,22 @@ from src.gui.panels.step2_panel import Step2Panel
from src.gui.panels.step3_panel import Step3Panel
from src.gui.panels.step4_panel import Step4Panel
from src.gui.panels.step5_panel import Step5Panel
-from src.gui.panels.step5_5_panel import Step5_5Panel
-from src.gui.panels.step6_panel import Step6Panel
-from src.gui.panels.step6_5_panel import Step6_5Panel
-from src.gui.panels.step6_75_panel import Step6_75Panel
-from src.gui.panels.step7_panel import Step7Panel
-from src.gui.panels.step8_panel import Step8Panel
-from src.gui.panels.step8_5_panel import Step8_5Panel
-from src.gui.panels.step8_75_panel import Step8_75Panel
-from src.gui.panels.step9_panel import Step9Panel
+from src.gui.panels.step8_panel import Step8Panel # was step5_5_panel
+from src.gui.panels.step7_panel import Step7Panel # was step6_panel
+from src.gui.panels.step8_non_empirical_panel import Step8NonEmpiricalPanel # was step6_5_panel
+from src.gui.panels.step9_panel import Step9Panel # was step6_75_panel
+from src.gui.panels.step10_panel import Step10Panel # was step7_panel
+from src.gui.panels.step11_ml_panel import Step11MlPanel # ML prediction (step11_ml)
+from src.gui.panels.step11_panel import Step11Panel # was step8_5_panel
+from src.gui.panels.step12_panel import Step12Panel # was step8_75_panel
+from src.gui.panels.step14_panel import Step14Panel # was step9_panel
from src.gui.dialogs import BandConfirmDialog, AISettingsDialog
from src.gui.panels.visualization_panel import VisualizationPanel
from src.gui.panels.report_generation_panel import ReportGenerationPanel
+# Pipeline 核心异常(用于预检弹窗)
+from src.core.pipeline.runner import PipelineHalt
+
# Matplotlib相关导入 (推迟并加入底层防爆保护)
import matplotlib
try:
@@ -152,6 +155,9 @@ from src.gui.core.worker_thread import (
check_pipeline_dependencies,
diagnose_pipeline_import_error,
)
+# 预检交互对话框
+from src.gui.core.preflight_dialog import PreflightDialog
+from src.gui.core.pipeline_mode_dialog import PipelineModeDialog
def _viz_training_spectra_csv_path(work_path: Path) -> Path:
@@ -1384,31 +1390,31 @@ class WaterQualityGUI(QMainWindow):
'step5': {
'training_spectra': '5_training_spectra/training_spectra.csv'
},
- 'step5_5': {
+ 'step8': {
'water_indices': '6_water_quality_indices/water_quality_indices.csv'
},
- 'step6': {
+ 'step7': {
'models': '7_Supervised_Model_Training/' # 目录,包含各参数子目录
},
- 'step6_5': {
+ 'step8_non_empirical_modeling': {
'regression_models': '8_Regression_Modeling/' # 目录,包含各参数子目录
},
- 'step6_75': {
+ 'step9': {
'custom_regression_models': '9_Custom_Regression_Modeling/' # 目录
},
- 'step7': {
+ 'step10': {
'sampling_points': '10_sampling/sampling_spectra.csv'
},
- 'step8': {
+ 'step11_ml': {
'predictions': '11_12_13_predictions/Machine_Learning_Prediction/' # 目录,包含机器学习预测结果
},
- 'step8_5': {
+ 'step11': {
'regression_predictions': '11_12_13_predictions/Non_Empirical_Prediction/' # 目录,包含非经验模型预测结果
},
- 'step8_75': {
+ 'step12': {
'custom_predictions': '11_12_13_predictions/Custom_Regression_Prediction/' # 目录,包含自定义回归预测结果
},
- 'step9': {
+ 'step14': {
'distribution_maps': '14_visualization/' # 目录,包含专题图
}
}
@@ -1432,37 +1438,37 @@ class WaterQualityGUI(QMainWindow):
'boundary_mask_path': ('step1', 'water_mask', 'boundary_mask_file'), # 步骤5可选水体掩膜
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜
},
- 'step5_5': {
- 'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
- },
- 'step6': {
- 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据
- },
- 'step6_5': {
- 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.5需要训练光谱数据
- },
- 'step6_75': {
- 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.75需要训练光谱数据
+ 'step8': {
+ 'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤8需要步骤5输出的训练光谱
},
'step7': {
- 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤7需要去耀斑影像
- 'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤7需要水域掩膜
- 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤7可选耀斑掩膜
+ 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤7需要训练光谱数据
},
- 'step8': {
- 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8需要采样点
- 'models_dir': ('step6', 'models', 'models_dir_file') # 步骤8需要训练好的模型
- },
- 'step8_5': {
- 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.5需要采样点
- 'models_dir': ('step6_5', 'regression_models', 'models_dir') # 步骤8.5需要回归模型
- },
- 'step8_75': {
- 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.75需要采样点
- 'models_dir': ('step6_75', 'custom_regression_models', 'models_dir') # 步骤8.75需要自定义回归模型
+ 'step8_non_empirical_modeling': {
+ 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤8非经验建模需要训练光谱数据
},
'step9': {
- 'prediction_csv_path': ('step8', 'predictions', 'prediction_csv_file') # 步骤9需要预测结果CSV
+ 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤9需要训练光谱数据
+ },
+ 'step10': {
+ 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤10需要去耀斑影像
+ 'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤10需要水域掩膜
+ 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤10可选耀斑掩膜
+ },
+ 'step11_ml': {
+ 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤11ML需要采样点
+ 'models_dir': ('step7', 'models', 'models_dir_file') # 步骤11ML需要训练好的模型
+ },
+ 'step11': {
+ 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤11需要采样点
+ 'models_dir': ('step8_non_empirical_modeling', 'regression_models', 'models_dir') # 步骤11需要回归模型
+ },
+ 'step12': {
+ 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤12需要采样点
+ 'models_dir': ('step9', 'custom_regression_models', 'models_dir') # 步骤12需要自定义回归模型
+ },
+ 'step14': {
+ 'prediction_csv_path': ('step11_ml', 'predictions', 'prediction_csv_file') # 步骤14需要预测结果CSV
}
}
@@ -1545,7 +1551,7 @@ class WaterQualityGUI(QMainWindow):
def init_ui(self):
"""初始化UI"""
- self.setWindowTitle("MegaCube-Water Quality V1.1")
+ self.setWindowTitle("MegaCube-Water Quality V1.2")
# 获取屏幕可用区域(排除任务栏)
screen_geometry = QApplication.primaryScreen().availableGeometry()
@@ -1730,7 +1736,7 @@ class WaterQualityGUI(QMainWindow):
def create_banner_widget(self):
"""创建横幅区域 - 支持自适应等比缩放"""
# 横幅标题文字(方便后续直接修改版本号)
- self._APP_TITLE = "MegaCube-Water Quality V1.1"
+ self._APP_TITLE = "MegaCube-Water Quality V1.2"
# 创建横幅容器
banner_widget = QWidget()
@@ -1844,19 +1850,19 @@ class WaterQualityGUI(QMainWindow):
"阶段二:样本数据准备 ": [
("step4", "4. 数据标准化处理"),
("step5", "5. 光谱特征提取"),
- ("step5_5", "6. 水质参数指数计算"),
+ ("step8", "6. 水质参数指数计算"),
],
"阶段三:模型构建与训练": [
- ("step6", "7. 机器学习模型训练"),
- ("step6_5", "8. 回归模型训练"),
- ("step6_75", "9. 自定义回归模型训练"),
+ ("step7", "7. 机器学习模型训练"),
+ ("step8_non_empirical_modeling", "8. 回归模型训练"),
+ ("step9", "9. 自定义回归模型训练"),
],
"阶段四:预测与成果输出 ": [
- ("step7", "10. 采样点布设"),
- ("step8", "11. 机器学习预测"),
- ("step8_5", "12. 回归预测"),
- ("step8_75", "13. 自定义回归预测"),
- ("step9", "14. 专题图生成"),
+ ("step10", "10. 采样点布设"),
+ ("step11_ml", "11. 机器学习预测"),
+ ("step11", "12. 回归预测"),
+ ("step12", "13. 自定义回归预测"),
+ ("step14", "14. 专题图生成"),
("step9_viz", "15. 可视化分析"),
("step_report", "16. 分析报告生成"),
]
@@ -1878,7 +1884,7 @@ class WaterQualityGUI(QMainWindow):
self.step_list.addItem(stage_item)
# 添加该阶段的所有步骤
- HIDDEN_STEP_IDS = {"step6_5", "step6_75", "step8_5", "step8_75"}
+ HIDDEN_STEP_IDS = {"step8_non_empirical_modeling", "step9", "step11", "step12"}
for step_id, step_display in steps:
if step_id in HIDDEN_STEP_IDS:
continue
@@ -1958,36 +1964,36 @@ class WaterQualityGUI(QMainWindow):
self.step5_panel = Step5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step5_panel), QIcon(self.get_icon_path("5.png")), "特征构建")
- self.step5_5_panel = Step5_5Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step5_5_panel), QIcon(self.get_icon_path("5.png")), "水质指数")
-
- self.step6_panel = Step6Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step6_panel), QIcon(self.get_icon_path("6.png")), "监督建模")
-
- self.step6_5_panel = Step6_5Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
- self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
-
- self.step6_75_panel = Step6_75Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
- self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
+ self.step8_panel = Step8Panel()
+ self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("5.png")), "水质指数")
self.step7_panel = Step7Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
-
- self.step8_panel = Step8Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("8.png")), "监督预测")
+ self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("6.png")), "监督建模")
- self.step8_5_panel = Step8_5Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
- self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
-
- self.step8_75_panel = Step8_75Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
- self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
+ self.step8_non_empirical_panel = Step8NonEmpiricalPanel()
+ self.step_stack.addTab(self.create_scroll_area(self.step8_non_empirical_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
+ self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
self.step9_panel = Step9Panel()
- self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
+ self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
+ self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
+
+ self.step10_panel = Step10Panel()
+ self.step_stack.addTab(self.create_scroll_area(self.step10_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
+
+ self.step11_ml_panel = Step11MlPanel() # ML prediction panel (step11_ml)
+ self.step_stack.addTab(self.create_scroll_area(self.step11_ml_panel), QIcon(self.get_icon_path("8.png")), "监督预测")
+
+ self.step11_panel = Step11Panel()
+ self.step_stack.addTab(self.create_scroll_area(self.step11_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
+ self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
+
+ self.step12_panel = Step12Panel()
+ self.step_stack.addTab(self.create_scroll_area(self.step12_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
+ self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
+
+ self.step14_panel = Step14Panel()
+ self.step_stack.addTab(self.create_scroll_area(self.step14_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
self.viz_panel = VisualizationPanel()
self.step_stack.addTab(self.create_scroll_area(self.viz_panel), QIcon(self.get_icon_path("9.png")), "可视化")
@@ -2137,15 +2143,15 @@ class WaterQualityGUI(QMainWindow):
'step3': 2,
'step4': 3,
'step5': 4,
- 'step5_5': 5,
- 'step6': 6,
- 'step6_5': 7,
- 'step6_75': 8,
- 'step7': 9,
- 'step8': 10,
- 'step8_5': 11,
- 'step8_75': 12,
- 'step9': 13,
+ 'step8': 5,
+ 'step7': 6,
+ 'step8_non_empirical_modeling': 7,
+ 'step9': 8,
+ 'step10': 9,
+ 'step11_ml': 10,
+ 'step11': 11,
+ 'step12': 12,
+ 'step14': 13,
'step9_viz': 14,
'step_report': 15,
}
@@ -2168,15 +2174,15 @@ class WaterQualityGUI(QMainWindow):
2: 'step3',
3: 'step4',
4: 'step5',
- 5: 'step5_5',
- 6: 'step6',
- 7: 'step6_5',
- 8: 'step6_75',
- 9: 'step7',
- 10: 'step8',
- 11: 'step8_5',
- 12: 'step8_75',
- 13: 'step9',
+ 5: 'step8',
+ 6: 'step7',
+ 7: 'step8_non_empirical_modeling',
+ 8: 'step9',
+ 9: 'step10',
+ 10: 'step11_ml',
+ 11: 'step11',
+ 12: 'step12',
+ 13: 'step14',
14: 'step9_viz',
15: 'step_report',
}
@@ -2213,41 +2219,41 @@ class WaterQualityGUI(QMainWindow):
elif index == 4:
self.step5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
- # Step5_5 切换时自动填充输出路径
+ # Step8(水质指数)切换时自动填充输出路径
elif index == 5:
- self.step5_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
- # Step6 切换时自动填充训练数据和输出路径
+ # Step7(监督建模)切换时自动填充训练数据和输出路径
elif index == 6:
- self.step6_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
-
- # Step6.5(非经验回归建模)切换时自动填充训练数据和模型目录
- elif index == 7:
- self.step6_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
-
- # Step6.75(自定义回归建模)切换时自动填充训练数据和模型目录
- elif index == 8:
- self.step6_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
-
- # Step7(采样点布设)切换时自动填充掩膜和输出路径
- elif index == 9:
self.step7_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ # Step8非经验建模切换时自动填充训练数据和模型目录
+ elif index == 7:
+ self.step8_non_empirical_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+
+ # Step9(自定义回归建模)切换时自动填充训练数据和模型目录
+ elif index == 8:
+ self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+
+ # Step10(采样点布设)切换时自动填充掩膜和输出路径
+ elif index == 9:
+ self.step10_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+
# Step8(机器学习预测)切换时自动填充采样光谱和模型目录
elif index == 10:
- self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ self.step11_ml_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
- # Step8.5(非经验模型预测)切换时自动填充采样光谱和回归模型目录
+ # Step11(回归预测)切换时自动填充采样光谱和回归模型目录
elif index == 11:
- self.step8_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ self.step11_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
- # Step8.75(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录
+ # Step12(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录
elif index == 12:
- self.step8_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ self.step12_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
- # Step9(专题图生成)切换时自动填充预测结果目录
+ # Step14(专题图生成)切换时自动填充预测结果目录
elif index == 13:
- self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
+ self.step14_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# 可视化分析面板切换时自动推断图像目录并加载目录树
elif index == 14:
@@ -2294,22 +2300,22 @@ class WaterQualityGUI(QMainWindow):
self.step4_panel.set_config(config['step4'])
if 'step5' in config:
self.step5_panel.set_config(config['step5'])
- if 'step5_5' in config:
- self.step5_5_panel.set_config(config['step5_5'])
- if 'step6' in config:
- self.step6_panel.set_config(config['step6'])
- if 'step6_5' in config:
- self.step6_5_panel.set_config(config['step6_5'])
- if 'step6_75' in config:
- self.step6_75_panel.set_config(config['step6_75'])
- if 'step7' in config:
- self.step7_panel.set_config(config['step7'])
if 'step8' in config:
self.step8_panel.set_config(config['step8'])
- if 'step8_5' in config:
- self.step8_5_panel.set_config(config['step8_5'])
+ if 'step7' in config:
+ self.step7_panel.set_config(config['step7'])
+ if 'step8_non_empirical_modeling' in config:
+ self.step8_non_empirical_panel.set_config(config['step8_non_empirical_modeling'])
if 'step9' in config:
self.step9_panel.set_config(config['step9'])
+ if 'step10' in config:
+ self.step10_panel.set_config(config['step10'])
+ if 'step11_ml' in config:
+ self.step11_ml_panel.set_config(config['step11_ml'])
+ if 'step11' in config:
+ self.step11_panel.set_config(config['step11'])
+ if 'step14' in config:
+ self.step14_panel.set_config(config['step14'])
if 'visualization' in config:
self.viz_panel.set_config(config['visualization'])
if 'report_generation' in config:
@@ -2352,14 +2358,14 @@ class WaterQualityGUI(QMainWindow):
'step3': self.step3_panel.get_config(),
'step4': self.step4_panel.get_config(),
'step5': self.step5_panel.get_config(),
- 'step5_5': self.step5_5_panel.get_config(),
- 'step6': self.step6_panel.get_config(),
- 'step6_5': self.step6_5_panel.get_config(),
- 'step6_75': self.step6_75_panel.get_config(),
- 'step7': self.step7_panel.get_config(),
'step8': self.step8_panel.get_config(),
- 'step8_5': self.step8_5_panel.get_config(),
+ 'step7': self.step7_panel.get_config(),
+ 'step8_non_empirical_modeling': self.step8_non_empirical_panel.get_config(),
'step9': self.step9_panel.get_config(),
+ 'step10': self.step10_panel.get_config(),
+ 'step11_ml': self.step11_ml_panel.get_config(),
+ 'step11': self.step11_panel.get_config(),
+ 'step14': self.step14_panel.get_config(),
'visualization': self.viz_panel.get_config(),
'report_generation': self.report_panel.get_config(),
}
@@ -2410,15 +2416,15 @@ class WaterQualityGUI(QMainWindow):
'step3': self.step3_panel,
'step4': self.step4_panel,
'step5': self.step5_panel,
- 'step5_5': self.step5_5_panel,
- 'step6': self.step6_panel,
- 'step6_5': self.step6_5_panel,
- 'step6_75': self.step6_75_panel,
- 'step7': self.step7_panel,
'step8': self.step8_panel,
- 'step8_5': self.step8_5_panel,
- 'step8_75': self.step8_75_panel,
+ 'step7': self.step7_panel,
+ 'step8_non_empirical_modeling': self.step8_non_empirical_panel,
'step9': self.step9_panel,
+ 'step10': self.step10_panel,
+ 'step11_ml': self.step11_ml_panel,
+ 'step11': self.step11_panel,
+ 'step12': self.step12_panel,
+ 'step14': self.step14_panel,
}
return panel_map.get(step_id)
@@ -2426,15 +2432,38 @@ class WaterQualityGUI(QMainWindow):
"""查找指定步骤的输出文件"""
if step_id not in self.step_default_outputs:
return None
-
+
step_outputs = self.step_default_outputs[step_id]
-
+
+ # ★ 掩膜类型列表:这些类型只接受科学数据格式
+ mask_types = {'water_mask', 'glint_mask', 'boundary_mask'}
+ # ★ 白名单机制:只允许 .dat .tif .tiff .shp,拒绝其他一切格式
+ scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
+ # ★ 临时文件关键词黑名单
+ tmp_keywords = ('__tmp', '_tmp')
+
+ def _is_scientific_mask(path_str):
+ """白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式"""
+ p = Path(path_str)
+ name_lower = str(path_str).lower()
+ # 拒绝临时文件
+ if any(kw in name_lower for kw in tmp_keywords):
+ return False
+ # 白名单校验
+ return p.suffix.lower() in scientific_extensions
+
# 特殊处理:从step_outputs记录中查找实际输出路径
if step_id in self.step_outputs:
actual_outputs = self.step_outputs[step_id]
if output_type in actual_outputs:
- return actual_outputs[output_type]
-
+ candidate = actual_outputs[output_type]
+ # ★ 掩膜类型白名单二次校验:不在白名单内的一律拒绝
+ if output_type in mask_types and not _is_scientific_mask(candidate):
+ # 非科学格式被拒绝,不使用 step_outputs 中的值
+ pass
+ else:
+ return candidate
+
# 根据输出类型查找对应的文件
if output_type == 'water_mask':
# 水域掩膜:优先查找NDWI生成的,其次是shp生成的
@@ -2485,19 +2514,19 @@ class WaterQualityGUI(QMainWindow):
# 扫描各个子目录
subdirs = {
'1_water_mask': 'step1',
- '2_glint': 'step2',
+ '2_glint': 'step2',
'3_deglint': 'step3',
'4_processed_data': 'step4',
'5_training_spectra': 'step5',
- '6_water_quality_indices': 'step5_5',
- '7_Supervised_Model_Training': 'step6',
- '8_Regression_Modeling': 'step6_5',
- '9_Custom_Regression_Modeling': 'step6_75',
- '10_sampling': 'step7',
- '11_12_13_predictions/Machine_Learning_Prediction': 'step8',
- '11_12_13_predictions/Non_Empirical_Prediction': 'step8_5',
- '11_12_13_predictions/Custom_Regression_Prediction': 'step8_75',
- '14_visualization': 'step9'
+ '6_water_quality_indices': 'step8',
+ '7_Supervised_Model_Training': 'step7',
+ '8_Regression_Modeling': 'step8_non_empirical_modeling',
+ '9_Custom_Regression_Modeling': 'step9',
+ '10_sampling': 'step10',
+ '11_12_13_predictions/Machine_Learning_Prediction': 'step11_ml',
+ '11_12_13_predictions/Non_Empirical_Prediction': 'step11',
+ '11_12_13_predictions/Custom_Regression_Prediction': 'step12',
+ '14_visualization': 'step14'
}
for subdir, step_ids in subdirs.items():
@@ -2517,23 +2546,37 @@ class WaterQualityGUI(QMainWindow):
for step_id in step_ids:
if step_id not in discovered_outputs:
discovered_outputs[step_id] = {}
-
+
+ # ★ 掩膜文件白名单过滤:只有 .dat .tif .tiff .shp 才通过,拒绝 .hdr .xml .png 等
+ scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
+ tmp_keywords = ('__tmp', '_tmp')
+
+ def _is_scientific_mask(path_str):
+ """白名单判断:拒绝 .hdr .xml 临时文件等,只接受科学数据格式"""
+ p = Path(path_str)
+ name_lower = str(path_str).lower()
+ if any(kw in name_lower for kw in tmp_keywords):
+ return False
+ return p.suffix.lower() in scientific_extensions
+
# 匹配不同的文件类型
if 'water_mask' in file_name and step_id == 'step1':
- discovered_outputs[step_id]['water_mask'] = str(file_path)
+ if _is_scientific_mask(file_path):
+ discovered_outputs[step_id]['water_mask'] = str(file_path)
elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2':
- discovered_outputs[step_id]['glint_mask'] = str(file_path)
+ if _is_scientific_mask(file_path):
+ discovered_outputs[step_id]['glint_mask'] = str(file_path)
elif 'deglint' in file_name and step_id == 'step3':
discovered_outputs[step_id]['deglint_image'] = str(file_path)
elif 'processed_data' in file_name and step_id == 'step4':
discovered_outputs[step_id]['processed_data'] = str(file_path)
elif 'training_spectra' in file_name and step_id == 'step5':
discovered_outputs[step_id]['training_spectra'] = str(file_path)
- elif 'water_quality_indices' in file_name and step_id == 'step5_5':
+ elif 'water_quality_indices' in file_name and step_id == 'step8':
discovered_outputs[step_id]['water_indices'] = str(file_path)
- elif 'sampling_spectra' in file_name and step_id == 'step7':
+ elif 'sampling_spectra' in file_name and step_id == 'step10':
discovered_outputs[step_id]['sampling_points'] = str(file_path)
- elif file_name.endswith('.csv') and step_id in ['step8', 'step8_5', 'step8_75']:
+ elif file_name.endswith('.csv') and step_id in ['step11_ml', 'step11', 'step12']:
discovered_outputs[step_id]['predictions'] = str(file_path)
# 更新内部记录
@@ -2556,8 +2599,8 @@ class WaterQualityGUI(QMainWindow):
# 首先扫描工作目录发现已有的输出文件
self.scan_work_directory_for_files(work_path)
- step_order = ['step2', 'step3', 'step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75',
- 'step7', 'step8', 'step8_5', 'step8_75', 'step9']
+ step_order = ['step2', 'step3', 'step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9',
+ 'step10', 'step11_ml', 'step11', 'step12', 'step14']
filled_count = 0
for step_id in step_order:
@@ -2579,15 +2622,15 @@ class WaterQualityGUI(QMainWindow):
('step2', self.step2_panel),
('step3', self.step3_panel),
('step5', self.step5_panel),
- ('step5_5', self.step5_5_panel),
- ('step6', self.step6_panel),
- ('step6_5', self.step6_5_panel),
- ('step6_75', self.step6_75_panel),
- ('step7', self.step7_panel),
('step8', self.step8_panel),
- ('step8_5', self.step8_5_panel),
- ('step8_75', self.step8_75_panel),
- ('step9', self.step9_panel)
+ ('step7', self.step7_panel),
+ ('step8_non_empirical_modeling', self.step8_non_empirical_panel),
+ ('step9', self.step9_panel),
+ ('step10', self.step10_panel),
+ ('step11_ml', self.step11_ml_panel),
+ ('step11', self.step11_panel),
+ ('step12', self.step12_panel),
+ ('step14', self.step14_panel)
]
for step_id, panel in panels_with_dependencies:
@@ -2735,7 +2778,7 @@ class WaterQualityGUI(QMainWindow):
"""显示关于对话框"""
QMessageBox.about(
self, "关于",
- "MegaCube-Water Quality V1.1\n\n"
+ "MegaCube-Water Quality V1.2\n\n"
"一个完整的水质参数反演工作流程工具\n\n"
"功能包括:\n"
"- 水域掩膜生成\n"
@@ -2858,6 +2901,41 @@ class WaterQualityGUI(QMainWindow):
return True
+ # ------------------------------------------------------------------
+ # ★ 全流程模式动态裁剪
+ # ------------------------------------------------------------------
+
+    def _prune_config_for_prediction_mode(self, config: dict) -> dict:
+        """Return a deep copy of ``config`` with all training-only steps disabled.
+
+        Used for the prediction-only mode of the full pipeline run: every
+        training-related step entry gets ``'enabled': False``. A step that is
+        missing, or whose entry is not a dict (e.g. ``None``), is replaced by
+        ``{'enabled': False}`` instead of raising.
+
+        Args:
+            config: Full configuration dict (as produced by get_current_config).
+
+        Returns:
+            The pruned deep copy; the ``config`` argument is left unmodified.
+        """
+        cfg = copy.deepcopy(config)
+        # Steps that only matter when a new model is being trained.
+        training_steps = (
+            "step4",  # cleaning of the measured CSV data
+            "step5",  # spectra extraction at measured points (training CSV)
+            "step7",  # supervised ML modelling
+            "step8",  # water-quality index computation (training aid)
+            "step8_non_empirical_modeling",  # non-empirical regression modelling
+            "step9",  # custom regression modelling
+        )
+        for step_id in training_steps:
+            step_cfg = cfg.get(step_id)
+            if not isinstance(step_cfg, dict):
+                step_cfg = cfg[step_id] = {}  # missing or malformed entry
+            step_cfg["enabled"] = False
+        return cfg
+
def run_full_pipeline(self):
"""运行完整流程"""
if not PIPELINE_AVAILABLE:
@@ -2867,8 +2945,14 @@ class WaterQualityGUI(QMainWindow):
)
return
+ # ── 0) 强制获取 work_dir(禁止依赖外部或全局变量) ──
+ work_dir = getattr(self, 'work_dir', None)
+ if not work_dir:
+ QMessageBox.warning(self, "警告", "未选择工作目录,请先设置工作目录。")
+ return
+
# ── 1) 运行前智能预检与自动回填(硬盘已有产物自动跳过) ──
- work_path = Path(getattr(self, 'work_dir', './work_dir'))
+ work_path = Path(work_dir)
self.log_message("正在进行运行前环境预检与自动扫描...", "info")
self.scan_work_directory_for_files(work_path)
self.auto_populate_all_steps()
@@ -2878,31 +2962,52 @@ class WaterQualityGUI(QMainWindow):
if not self._precheck_step3_bands():
return # 用户点"取消运行"
+ # ── 1.6) ★ 全流程模式选择弹窗 ──
+ mode_dlg = PipelineModeDialog(main_window=self, parent=self)
+ if mode_dlg.exec() != QDialog.Accepted:
+ return # 用户点"取消"
+ selected_mode = mode_dlg.selected_mode
+ self.log_message(f"[模式选择] 选定模式: {'训练新模型' if selected_mode == 'training' else '使用已有模型直接预测'}", "info")
+
# ── 2) 刷新配置(拿到自动填充后的"满血版" config) ──
config = self.get_current_config()
- # ── 3) 根基数据校验:step1.img_path(参考影像) ──
- if not config['step1'].get('img_path'):
- QMessageBox.warning(self, "警告", "缺失核心数据:请先在步骤 1 中上传【参考影像】!")
- for i in range(self.step_list.count()):
- item = self.step_list.item(i)
- if item.data(Qt.UserRole) == 'step1':
- self.step_list.setCurrentRow(i)
- break
- return
+ # ── 2.1) ★ 根据模式动态裁剪配置 ──
+ if selected_mode == "prediction_only":
+ config = self._prune_config_for_prediction_mode(config)
+ self.log_message("[模式选择] 已裁剪训练相关步骤(step4/5/7/8),进入仅预测模式", "info")
- # ── 4) 软提示:csv_path 缺失 → 模型训练步骤会被静默跳过(不阻断) ──
- csv_path = config.get('step4', {}).get('csv_path') or config.get('step5', {}).get('csv_path')
- if not csv_path:
- QMessageBox.information(
- self,
- "提示:模型训练将被跳过",
- "未检测到实测水质数据 (CSV)。\n"
- "流程将自动跳过模型训练(步骤 4-6),仅执行预测与制图。\n"
- "如果需要训练新模型,请先在步骤 4 中上传水质数据。",
- )
+ # ── 3) ★ 一次性全预检 + 用户交互式决策 ──
+ missing_items = PreflightDialog.build_missing_items(config)
+ if missing_items:
+ critical_items = [it for it in missing_items if it.is_critical]
+ if critical_items:
+ lines = "\n".join(f" - [{it.step_name}] {it.reason}" for it in critical_items)
+ QMessageBox.critical(
+ self, "预检失败(阻断性错误)",
+ f"以下为阻断性缺失,流程无法启动:\n\n{lines}\n\n"
+ "请填写后重新运行。"
+ )
+ return
+ dialog = PreflightDialog(missing_items, parent=self)
+ if dialog.exec() != QDialog.Accepted:
+ return
+ result = dialog.get_result()
+ if result is None:
+ return
+ action, *payload = result
+ if action == "fill":
+ _, step_id, tab_index = result
+ self.step_stack.setCurrentIndex(tab_index)
+ self.log_message(f"[预检] 用户选择填写 {step_id},已切换到对应面板。", "info")
+ return
+ skip_list: List[str] = payload[0] if payload else []
+ if skip_list:
+ self.log_message(f"[预检] 用户强制跳过 {len(skip_list)} 个步骤: {skip_list}", "info")
+ else:
+ skip_list = []
- # 确认执行
+ # ── 4) 确认执行
reply = QMessageBox.question(
self, "确认",
"是否开始执行完整流程?\n\n这可能需要较长时间,请确保配置正确。",
@@ -2913,19 +3018,18 @@ class WaterQualityGUI(QMainWindow):
return
# 创建pipeline实例
- work_dir = getattr(self, 'work_dir', './work_dir')
self.log_message(f"初始化pipeline,工作目录: {work_dir}", "info")
# 准备实际运行配置(排除未启用的步骤)
worker_config = copy.deepcopy(config)
- step5_5_cfg = worker_config.get('step5_5')
- if step5_5_cfg:
- enabled = step5_5_cfg.pop('enabled', True)
+ step8_cfg = worker_config.get('step8')
+ if step8_cfg:
+ enabled = step8_cfg.pop('enabled', True)
if not enabled:
- worker_config.pop('step5_5', None)
+ worker_config.pop('step8', None)
# 工作线程内创建 Pipeline,避免主线程阻塞及 Qt5Agg 子线程绘图卡死
- self.worker = WorkerThread(work_dir, worker_config, mode='full')
+ self.worker = WorkerThread(work_dir, worker_config, mode='full', skip_list=skip_list)
self.worker.log_message.connect(self.log_message, Qt.QueuedConnection)
self.worker.progress_update.connect(self.update_progress, Qt.QueuedConnection)
self.worker.step_completed.connect(self.on_step_completed, Qt.QueuedConnection)
@@ -3152,14 +3256,14 @@ class WaterQualityGUI(QMainWindow):
def update_ui_for_training_mode(self):
"""根据训练数据模式更新UI状态"""
# 需要禁用的步骤ID(对应无训练数据模式下需要禁用的步骤)
- disabled_step_ids = ['step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75']
+ disabled_step_ids = ['step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9']
# 更新标签页的启用/禁用状态
step_id_to_tab = {
'step1': 0, 'step2': 1, 'step3': 2, 'step4': 3,
- 'step5': 4, 'step5_5': 5, 'step6': 6, 'step6_5': 7,
- 'step6_75': 8, 'step7': 9, 'step8': 10, 'step8_5': 11,
- 'step8_75': 12, 'step9': 13, 'step9_viz': 14
+ 'step5': 4, 'step8': 5, 'step7': 6, 'step8_non_empirical_modeling': 7,
+ 'step9': 8, 'step10': 9, 'step11_ml': 10, 'step11': 11,
+ 'step12': 12, 'step14': 13, 'step9_viz': 14
}
for step_id in disabled_step_ids: