From 28394f2eda16a59a4957036c5e3b18da4eef4105 Mon Sep 17 00:00:00 2001 From: DXC Date: Tue, 9 Jun 2026 11:30:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(gui):=20=E5=85=A8=E6=B5=81=E7=A8=8B?= =?UTF-8?q?=E9=9D=A2=E6=9D=BF=E5=90=88=E5=B9=B6=20+=20=E4=B8=80=E9=94=AE?= =?UTF-8?q?=E5=BC=8F=E8=BF=90=E8=A1=8C=20GUI=20=E5=85=A5=E5=8F=A3=E9=9B=86?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- check_lines.py | 6 + src/core/pipeline/__init__.py | 16 +- src/core/pipeline/context.py | 48 +- .../water_quality_inversion_pipeline_GUI.py | 246 +++--- src/gui/core/pipeline_mode_dialog.py | 237 ++++++ src/gui/dialogs.py | 38 +- src/gui/panels/step10_panel.py | 252 +++++++ src/gui/panels/step11_ml_panel.py | 462 ++++++++++++ .../{step8_5_panel.py => step11_panel.py} | 34 +- .../{step8_75_panel.py => step12_panel.py} | 38 +- src/gui/panels/step14_panel.py | 533 +++++++++++++ src/gui/panels/step5_5_panel.py | 225 ------ src/gui/panels/step6_75_panel.py | 374 ---------- src/gui/panels/step6_panel.py | 415 ----------- src/gui/panels/step7_panel.py | 517 ++++++++----- ..._panel.py => step8_non_empirical_panel.py} | 10 +- src/gui/panels/step8_panel.py | 593 +++++---------- src/gui/panels/step9_panel.py | 699 +++++++----------- src/gui/panels/visualization_panel.py | 12 +- src/gui/water_quality_gui.py | 520 +++++++------ 20 files changed, 2843 insertions(+), 2432 deletions(-) create mode 100644 check_lines.py create mode 100644 src/gui/core/pipeline_mode_dialog.py create mode 100644 src/gui/panels/step10_panel.py create mode 100644 src/gui/panels/step11_ml_panel.py rename src/gui/panels/{step8_5_panel.py => step11_panel.py} (86%) rename src/gui/panels/{step8_75_panel.py => step12_panel.py} (88%) create mode 100644 src/gui/panels/step14_panel.py delete mode 100644 src/gui/panels/step5_5_panel.py delete mode 100644 src/gui/panels/step6_75_panel.py delete mode 100644 src/gui/panels/step6_panel.py rename 
src/gui/panels/{step6_5_panel.py => step8_non_empirical_panel.py} (97%) diff --git a/check_lines.py b/check_lines.py new file mode 100644 index 0000000..22c4dac --- /dev/null +++ b/check_lines.py @@ -0,0 +1,6 @@ +import sys +with open(r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py', 'rb') as f: + content = f.read() +lines = content.split(b'\r\n') +for i, line in enumerate(lines[2918:2955], start=2919): + sys.stdout.buffer.write(f'{i}: {repr(line[:120])}'.encode('utf-8') + b'\n') \ No newline at end of file diff --git a/src/core/pipeline/__init__.py b/src/core/pipeline/__init__.py index 3a3d22c..00ba2ce 100644 --- a/src/core/pipeline/__init__.py +++ b/src/core/pipeline/__init__.py @@ -8,7 +8,17 @@ Pipeline 调度核心:基于 Context 的内存级依赖注入。 - 不绑定具体 Pipeline 实现(duck-typed),WorkerThread / Web API / 单测可共用 """ -from .context import PipelineContext -from .runner import StepSpec, PIPELINE_STEPS, PipelineRunner +from .context import ( + PipelineContext, + STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD, + resolve_step_id, ALL_STEP_IDS, +) +from .runner import ( + StepSpec, PIPELINE_STEPS, PipelineRunner, PipelineHalt, +) -__all__ = ["PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner"] +__all__ = [ + "PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner", "PipelineHalt", + "STEP_MAP_OLD_TO_NEW", "STEP_MAP_NEW_TO_OLD", + "resolve_step_id", "ALL_STEP_IDS", +] diff --git a/src/core/pipeline/context.py b/src/core/pipeline/context.py index b31a0ea..d6723cd 100644 --- a/src/core/pipeline/context.py +++ b/src/core/pipeline/context.py @@ -12,7 +12,34 @@ PipelineContext:内存级数据载体,跨 14 个 step 传递路径与元信 from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set + + +# ============================================================ +# 步骤命名映射(定义在叶子节点,打破循环依赖) +# ============================================================ + +STEP_MAP_OLD_TO_NEW: 
Dict[str, str] = { + "step5_5": "step8", + "step6_5": "step8_non_empirical_modeling", + "step6_75": "step9", + "step8_5": "step11", + "step8_75": "step12", + "step7": "step10", + "step8": "step11_ml", + "step9": "step14", +} + +STEP_MAP_NEW_TO_OLD: Dict[str, str] = {v: k for k, v in STEP_MAP_OLD_TO_NEW.items()} + +ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW.keys()) | set(STEP_MAP_OLD_TO_NEW.values()) + + +def resolve_step_id(step_id: str) -> str: + """将任意 step_id 转换为标准新格式。""" + if step_id in STEP_MAP_OLD_TO_NEW: + return STEP_MAP_OLD_TO_NEW[step_id] + return step_id @dataclass @@ -63,10 +90,29 @@ class PipelineContext: pipeline_end_time: Optional[float] = None last_error: Optional[str] = None + # ── 错误汇总(全流程结束后可用) ── + error_summary: List[tuple[str, str]] = field(default_factory=list) + # ── 出错时立即停止全流程(默认 False:继续后续步骤) ── + breakpoint_on_error: bool = False + # ── ★ 智能补全锁定步骤列表(由 _auto_fill_missing_steps 自动开启的步骤) ── + # GUI 层读取此字段,在运行期间禁用对应面板的启用复选框 + locked_steps: List[str] = field(default_factory=list) + # ============================================================ # 读写辅助 # ============================================================ + def step_id(self, step_id: str) -> str: + """将任意 step_id(可能是旧名)转换为标准新格式。 + + 用法示例: + ctx.status[ctx.step_id('step6_5')] # 'step8_non_empirical_modeling' + ctx.user_config[ctx.step_id('step8_5')] # 'step11' + """ + if step_id in STEP_MAP_OLD_TO_NEW: + return STEP_MAP_OLD_TO_NEW[step_id] + return step_id + def set(self, key: str, value: Any) -> None: """原地写入任意属性。 diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index ac5d639..d6f4102 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -660,7 +660,7 @@ class WaterQualityInversionPipeline: self._notify("completed", f"训练光谱数据已保存: {result}") return result - def step5_5_calculate_water_quality_indices(self, + def step8_water_quality_indices(self, 
training_csv_path: Optional[str] = None, formula_csv_file: Optional[str] = None, formula_names: Optional[List[str]] = None, @@ -704,7 +704,7 @@ class WaterQualityInversionPipeline: self._notify("completed", f"水质指数已保存: {result}") return result - def step6_train_models(self, feature_start_column: str = "374.285004", + def step7_ml_modeling(self, feature_start_column: str = "374.285004", preprocessing_methods: List[str] = None, model_names: List[str] = None, split_methods: List[str] = None, @@ -747,7 +747,7 @@ class WaterQualityInversionPipeline: self._notify("completed", f"模型训练完成,结果保存在: {result}") return result - def step7_generate_sampling_points(self, deglint_img_path: Optional[str] = None, + def step10_sampling(self, deglint_img_path: Optional[str] = None, interval: int = 50, sample_radius: int = 5, chunk_size: int = 1000, @@ -756,7 +756,7 @@ class WaterQualityInversionPipeline: use_adaptive_sampling: bool = True, skip_dependency_check: bool = False, **kwargs) -> str: """ - 步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱 + 步骤10: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱 Args: deglint_img_path: 去除耀斑后的影像文件路径(如果为None,使用步骤3的结果) @@ -779,7 +779,7 @@ class WaterQualityInversionPipeline: if water_mask_path is None and self.water_mask_path is not None: water_mask_path = self.water_mask_path - self._notify("started", "步骤7: 生成预测采样点") + self._notify("started", "步骤10: 生成预测采样点") result = PredictionStep.generate_sampling_points( deglint_img_path=img_path, interval=interval, @@ -790,21 +790,21 @@ class WaterQualityInversionPipeline: output_dir=str(self.sampling_dir), use_adaptive_sampling=use_adaptive_sampling, ) - self._record_step_time("步骤7: 生成预测采样点", 0, 0) + self._record_step_time("步骤10: 生成预测采样点", 0, 0) self._notify("completed", f"采样点光谱数据已保存: {result}") return result - def step8_predict_water_quality(self, sampling_csv_path: str, + def step11_ml_prediction(self, sampling_csv_path: str, models_dir: Optional[str] = None, metric: str = 'test_r2', prediction_column: str = 'prediction', 
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]: """ - 步骤8: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数 - + 步骤11: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数 + Args: sampling_csv_path: 采样点光谱数据CSV路径 - models_dir: 模型保存目录(如果为None,使用步骤6的结果) + models_dir: 模型保存目录(如果为None,使用步骤7的结果) metric: 选择最佳模型的指标 prediction_column: 预测结果列名 @@ -818,7 +818,7 @@ class WaterQualityInversionPipeline: print(f"[Pipeline] 收到字典: {'Yes' if _external_models_dict else 'No'}" f", 收到单模型: {'Yes' if _external_model else 'No'}") - self._notify("started", "步骤8: 预测水质参数") + self._notify("started", "步骤11: 预测水质参数") result = PredictionStep.predict_water_quality( sampling_csv_path=sampling_csv_path, models_dir=models_dir if models_dir else str(self.models_dir), @@ -831,11 +831,11 @@ class WaterQualityInversionPipeline: _external_models_dict=_external_models_dict, _external_model_dir=_external_model_dir, ) - self._record_step_time("步骤8: 预测水质参数", 0, 0) + self._record_step_time("步骤11: 预测水质参数", 0, 0) self._notify("completed", f"预测完成,结果保存在: {self.prediction_dir}") return result - def step9_generate_distribution_map(self, prediction_csv_path: str, + def step14_distribution_map(self, prediction_csv_path: str, boundary_shp_path: str, output_image_path: Optional[str] = None, resolution: float = 30, @@ -1524,99 +1524,99 @@ class WaterQualityInversionPipeline: else: self._notify("步骤5: 光谱提取", "skipped", "未配置") - # 步骤5.5: 计算水质指数 - if 'step5_5' in config: - self._notify("步骤5.5: 水质指数计算", "start") - self.step5_5_calculate_water_quality_indices(**config['step5_5']) - self._notify("步骤5.5: 水质指数计算", "completed", f"(输出: {self.indices_path})") + # 步骤8: 计算水质指数 + if 'step8' in config: + self._notify("步骤8: 水质指数计算", "start") + self.step8_water_quality_indices(**config['step8']) + self._notify("步骤8: 水质指数计算", "completed", f"(输出: {self.indices_path})") else: - self._notify("步骤5.5: 水质指数计算", "skipped", "未配置") + self._notify("步骤8: 水质指数计算", "skipped", "未配置") - # 步骤6: 训练模型 - if 'step6' in config: - self._notify("步骤6: 模型训练", "start") - 
self.step6_train_models(**config['step6']) - self._notify("步骤6: 模型训练", "completed", f"(输出: {self.models_dir})") - else: - self._notify("步骤6: 模型训练", "skipped", "未配置") - - # 步骤6.5: 非经验统计回归模型训练 - if 'step6_5' in config: - self._notify("步骤6.5: 非经验模型训练", "start") - self.step6_5_non_empirical_modeling(**config['step6_5']) - self._notify("步骤6.5: 非经验模型训练", "completed", f"(输出: {self.models_dir})") - else: - self._notify("步骤6.5: 非经验模型训练", "skipped", "未配置") - - # 步骤6.75: 自定义回归分析 - if 'step6_75' in config: - self._notify("步骤6.75: 自定义回归", "start") - self.step6_75_custom_regression(**config['step6_75']) - self._notify("步骤6.75: 自定义回归", "completed", f"(输出: {self.custom_regression_path})") - else: - self._notify("步骤6.75: 自定义回归", "skipped", "未配置") - - # 步骤7: 生成预测采样点 + # 步骤7: 训练模型 if 'step7' in config: - self._notify("步骤7: 采样点生成", "start") - sampling_csv_path = self.step7_generate_sampling_points(**config['step7']) - self._notify("步骤7: 采样点生成", "completed", f"(输出: {sampling_csv_path})") + self._notify("步骤7: 模型训练", "start") + self.step7_ml_modeling(**config['step7']) + self._notify("步骤7: 模型训练", "completed", f"(输出: {self.models_dir})") + else: + self._notify("步骤7: 模型训练", "skipped", "未配置") + + # 步骤8_non_empirical_modeling: 非经验统计回归模型训练 + if 'step8_non_empirical_modeling' in config: + self._notify("步骤8: 非经验模型训练", "start") + self.step8_non_empirical_modeling(**config['step8_non_empirical_modeling']) + self._notify("步骤8: 非经验模型训练", "completed", f"(输出: {self.models_dir})") + else: + self._notify("步骤8: 非经验模型训练", "skipped", "未配置") + + # 步骤9: 自定义回归分析 + if 'step9' in config: + self._notify("步骤9: 自定义回归", "start") + self.step9_custom_regression(**config['step9']) + self._notify("步骤9: 自定义回归", "completed", f"(输出: {self.custom_regression_path})") + else: + self._notify("步骤9: 自定义回归", "skipped", "未配置") + + # 步骤10: 生成预测采样点 + if 'step10' in config: + self._notify("步骤10: 采样点生成", "start") + sampling_csv_path = self.step10_sampling(**config['step10']) + self._notify("步骤10: 采样点生成", "completed", f"(输出: 
{sampling_csv_path})") else: sampling_csv_path = None - self._notify("步骤7: 采样点生成", "skipped", "未配置") + self._notify("步骤10: 采样点生成", "skipped", "未配置") - # 步骤8: 预测水质参数 - if 'step8' in config and sampling_csv_path: - self._notify("步骤8: 参数预测", "start") - step8_config = config['step8'].copy() - step8_config['sampling_csv_path'] = sampling_csv_path - prediction_files = self.step8_predict_water_quality(**step8_config) - self._notify("步骤8: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)") + # 步骤11_ml: 预测水质参数 + if 'step11_ml' in config and sampling_csv_path: + self._notify("步骤11: 参数预测", "start") + step11_ml_config = config['step11_ml'].copy() + step11_ml_config['sampling_csv_path'] = sampling_csv_path + prediction_files = self.step11_ml_prediction(**step11_ml_config) + self._notify("步骤11: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)") else: prediction_files = {} - self._notify("步骤8: 参数预测", "skipped", "未配置或缺少采样点") + self._notify("步骤11: 参数预测", "skipped", "未配置或缺少采样点") - # 步骤8.5: 使用非经验模型进行参数预测 + # 步骤11: 使用非经验模型进行参数预测 non_empirical_prediction_files = {} - if 'step8_5' in config and sampling_csv_path: - self._notify("步骤8.5: 非经验模型预测", "start") - step8_5_config = config['step8_5'].copy() - step8_5_config['sampling_csv_path'] = sampling_csv_path - non_empirical_prediction_files = self.step8_5_predict_with_non_empirical_models(**step8_5_config) - self._notify("步骤8.5: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)") + if 'step11' in config and sampling_csv_path: + self._notify("步骤11: 非经验模型预测", "start") + step11_config = config['step11'].copy() + step11_config['sampling_csv_path'] = sampling_csv_path + non_empirical_prediction_files = self.step11_non_empirical_prediction(**step11_config) + self._notify("步骤11: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)") else: - self._notify("步骤8.5: 非经验模型预测", "skipped", "未配置或缺少采样点") + self._notify("步骤11: 非经验模型预测", "skipped", "未配置或缺少采样点") - # 步骤8.75: 使用自定义回归模型进行参数预测 + # 步骤12: 
使用自定义回归模型进行参数预测 custom_regression_prediction_files = {} - if 'step8_75' in config and sampling_csv_path: - self._notify("步骤8.75: 自定义回归预测", "start") - step8_75_config = config['step8_75'].copy() - step8_75_config['sampling_csv_path'] = sampling_csv_path - custom_regression_prediction_files = self.step8_75_predict_with_custom_regression(**step8_75_config) - self._notify("步骤8.75: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)") + if 'step12' in config and sampling_csv_path: + self._notify("步骤12: 自定义回归预测", "start") + step12_config = config['step12'].copy() + step12_config['sampling_csv_path'] = sampling_csv_path + custom_regression_prediction_files = self.step12_custom_regression_prediction(**step12_config) + self._notify("步骤12: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)") else: - self._notify("步骤8.75: 自定义回归预测", "skipped", "未配置或缺少采样点") + self._notify("步骤12: 自定义回归预测", "skipped", "未配置或缺少采样点") # 合并机器学习预测、非经验模型预测和自定义回归预测结果 all_prediction_files = {**prediction_files, **non_empirical_prediction_files, **custom_regression_prediction_files} - # 步骤9: 生成分布图 + # 步骤14: 生成分布图 distribution_maps = {} - if 'step9' in config and all_prediction_files: - self._notify("步骤9: 分布图生成", "start") + if 'step14' in config and all_prediction_files: + self._notify("步骤14: 分布图生成", "start") for target_name, pred_file in all_prediction_files.items(): - step9_config = config['step9'].copy() + step14_config = config['step14'].copy() for _k in ('step9_batch_mode', 'prediction_csv_dir', 'recursive_csv_scan'): - step9_config.pop(_k, None) - step9_config['prediction_csv_path'] = pred_file - if 'output_image_path' not in step9_config: - step9_config['output_image_path'] = None - dist_map_path = self.step9_generate_distribution_map(**step9_config) + step14_config.pop(_k, None) + step14_config['prediction_csv_path'] = pred_file + if 'output_image_path' not in step14_config: + step14_config['output_image_path'] = None + dist_map_path = 
self.step14_distribution_map(**step14_config) distribution_maps[target_name] = dist_map_path - self._notify("步骤9: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)") + self._notify("步骤14: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)") else: - self._notify("步骤9: 分布图生成", "skipped", "未配置或缺少预测结果") + self._notify("步骤14: 分布图生成", "skipped", "未配置或缺少预测结果") # 生成可视化图表 output_files = {} @@ -1716,10 +1716,10 @@ class WaterQualityInversionPipeline: pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'} pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'} pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'} - pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'} - pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)} - pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'} - pipeline_info['training_params'] = config.get('step6', {}) + pipeline_info['step8'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'} + pipeline_info['step7'] = {'status': 'completed', 'output_file': str(self.models_dir)} + pipeline_info['step9'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'} + pipeline_info['training_params'] = config.get('step7', {}) summary_path = self.report_generator.generate_batch_inference_summary(pipeline_info) print(f"批量处理摘要已生成: {summary_path}") @@ -1769,7 +1769,7 @@ class WaterQualityInversionPipeline: traceback.print_exc() raise - def step6_5_non_empirical_modeling(self, csv_path: Optional[str] = None, + def 
step8_non_empirical_modeling(self, csv_path: Optional[str] = None, preprocessing_methods: List[str] = None, algorithms: List[str] = None, value_cols: Union[int, Dict[str, int]] = 0, @@ -1819,7 +1819,7 @@ class WaterQualityInversionPipeline: self._notify("completed", f"非经验模型训练完成") return result - def step6_75_custom_regression(self, + def step9_custom_regression(self, csv_path: Optional[str] = None, x_columns: Optional[Union[str, List[str]]] = None, y_columns: Optional[Union[str, List[str]]] = None, @@ -1999,7 +1999,7 @@ class WaterQualityInversionPipeline: return summary_path - def step8_5_predict_with_non_empirical_models(self, sampling_csv_path: str, + def step11_non_empirical_prediction(self, sampling_csv_path: str, non_empirical_models_dir: Optional[str] = None, output_path: Optional[str] = None, metric: str = 'Average Accuracy(%)', @@ -2007,13 +2007,13 @@ class WaterQualityInversionPipeline: enabled: bool = True, skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]: """ - 步骤8.5: 使用非经验统计回归模型进行参数预测 + 步骤11: 使用非经验统计回归模型进行参数预测 根据非经验模型训练结果汇总CSV筛选给定方法的准确率最高的模型,使用该模型进行预测 Args: sampling_csv_path: 采样点光谱数据CSV路径 - non_empirical_models_dir: 非经验模型保存目录(如果为None,使用步骤6.5的结果) + non_empirical_models_dir: 非经验模型保存目录(如果为None,使用步骤8的结果) output_path: 输出目录路径(如果为None,使用默认目录) metric: 选择最佳模型的指标(默认使用平均准确率) prediction_column: 预测结果列名 @@ -2021,7 +2021,7 @@ class WaterQualityInversionPipeline: Returns: 预测结果文件路径字典(键为算法名) """ - self._notify("started", "步骤8.5: 使用非经验模型进行参数预测") + self._notify("started", "步骤11: 使用非经验模型进行参数预测") result = PredictionStep.predict_with_non_empirical_models( sampling_csv_path=sampling_csv_path, non_empirical_models_dir=non_empirical_models_dir, @@ -2031,11 +2031,11 @@ class WaterQualityInversionPipeline: enabled=enabled, work_dir=str(self.work_dir), ) - self._record_step_time("步骤8.5: 非经验模型预测", 0, 0) + self._record_step_time("步骤11: 非经验模型预测", 0, 0) self._notify("completed", f"非经验模型预测完成,结果保存在: {self.prediction_dir}") return result - def 
step8_75_predict_with_custom_regression(self, sampling_csv_path: str, + def step12_custom_regression_prediction(self, sampling_csv_path: str, custom_regression_dir: Optional[str] = None, formula_csv_path: Optional[str] = None, coordinate_columns: Optional[List[str]] = None, @@ -2044,13 +2044,13 @@ class WaterQualityInversionPipeline: enabled: bool = True, skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]: """ - 步骤8.75: 使用自定义回归模型进行参数预测 - + 步骤12: 使用自定义回归模型进行参数预测 + 使用新的CustomRegressionPredictor模块,基于9_Custom_Regression_Modeling文件夹中的CSV, 根据r_squared选择最佳模型,批量预测水质参数 - + Args: - sampling_csv_path: 采样点光谱数据CSV路径(来自步骤7) + sampling_csv_path: 采样点光谱数据CSV路径(来自步骤10) custom_regression_dir: 自定义回归模型目录(9_Custom_Regression_Modeling) formula_csv_path: 公式CSV文件路径,用于查找index_formula coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']或自动识别 @@ -2058,11 +2058,11 @@ class WaterQualityInversionPipeline: filename_prefix: 输出文件名前缀 enabled: 是否启用 skip_dependency_check: 是否跳过依赖检查 - + Returns: 预测结果文件路径字典(键为参数名) """ - self._notify("started", "步骤8.75: 使用自定义回归模型进行参数预测") + self._notify("started", "步骤12: 使用自定义回归模型进行参数预测") result = PredictionStep.predict_with_custom_regression( sampling_csv_path=sampling_csv_path, custom_regression_dir=custom_regression_dir, @@ -2073,7 +2073,7 @@ class WaterQualityInversionPipeline: enabled=enabled, work_dir=str(self.work_dir), ) - self._record_step_time("步骤8.75: 自定义回归模型预测", 0, 0) + self._record_step_time("步骤12: 自定义回归模型预测", 0, 0) self._notify("completed", f"自定义回归预测完成") return result @@ -2161,20 +2161,20 @@ def main(): # 单步运行时建议显式指定;完整流程中可省略,将使用步骤2输出的耀斑掩膜 # 'glint_mask_path': r"path/to/severe_glint_area.dat", }, - 'step5_5': { + 'step8': { 'formula_csv_file': 'path/to/water_quality_formulas.csv', # 公式CSV文件路径 'formula_names': ['Al10SABI', 'TurbBe16RedOverViolet'], # 要计算的公式名称列表 'output_filename': 'water_quality_indices.csv', 'enabled': True # 是否启用水质指数计算 }, - 'step6': { + 'step7': { 'feature_start_column': '374.285004', 'preprocessing_methods': ['None', 'MMS', 
'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'], 'model_names': ['SVR', 'RF', 'Ridge', 'Lasso'], 'split_methods': ['spxy', 'ks', 'random'], 'cv_folds': 3 }, - 'step6_5': { + 'step8_non_empirical_modeling': { 'preprocessing_methods': ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'], 'algorithms': ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss'], 'value_cols': 0, # 可以是单个整数或字典,如 {'chl_a': 0, 'nh3': 1, 'mno4': 2, 'tn': 3, 'tp': 4, 'tss': 5} @@ -2182,14 +2182,14 @@ def main(): 'window': 5, 'enabled': True # 是否启用非经验模型训练 }, - 'step6_75': { + 'step9': { 'x_columns': ['NDWI', 'NDVI'], # 自变量列名列表 'y_columns': ['chl_a', 'tn', 'tp'], # 因变量列名列表 'methods': 'all', # 回归方法 'output_dir': 'custom_regression_results', # 输出目录 'enabled': True # 是否启用自定义回归分析 }, - 'step7': { + 'step10': { 'interval': 50, 'sample_radius': 5, 'chunk_size': 1000, @@ -2197,16 +2197,16 @@ def main(): # 可选:耀斑掩膜文件(dat),若不提供将使用步骤2结果;需要外部指定时取消注释 # 'glint_mask_path': r"D:\path\to\severe_glint_area.dat", }, - 'step8': { + 'step11_ml': { 'metric': 'test_r2', 'prediction_column': 'prediction' }, - 'step8_5': { + 'step11': { 'metric': 'Average Accuracy(%)', # 选择最佳模型的指标 'prediction_column': 'prediction', 'enabled': True # 是否启用非经验模型预测 }, - 'step8_75': { + 'step12': { 'custom_regression_dir': None, # 自定义回归模型目录(None表示使用9_Custom_Regression_Modeling) 'formula_csv_path': None, # 公式CSV文件路径,用于查找index_formula(如water_quality_formulas.csv) 'coordinate_columns': None, # 坐标列名(None表示自动识别) @@ -2214,7 +2214,7 @@ def main(): 'filename_prefix': 'custom_regression_prediction', # 输出文件名前缀 'enabled': True # 是否启用自定义回归预测 }, - 'step9': { + 'step14': { 'boundary_shp_path': r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp" , 'resolution': 30, 'input_crs': 'EPSG:32651', @@ -2345,41 +2345,41 @@ def example_independent_steps(): except Exception as e: print(f"步骤6失败: {e}") - # 示例6: 独立运行步骤7 - 采样点生成 - print("\n示例6: 独立运行步骤7 - 采样点生成") + # 示例6: 独立运行步骤10 - 采样点生成 + print("\n示例6: 独立运行步骤10 - 采样点生成") try: - sampling_csv = 
pipeline.step7_generate_sampling_points( + sampling_csv = pipeline.step10_sampling( deglint_img_path="path/to/deglint_image.bsq", water_mask_path="path/to/water_mask.dat", skip_dependency_check=True ) print(f"采样点数据: {sampling_csv}") except Exception as e: - print(f"步骤7失败: {e}") + print(f"步骤10失败: {e}") - # 示例7: 独立运行步骤8 - 水质预测 - print("\n示例7: 独立运行步骤8 - 水质预测") + # 示例7: 独立运行步骤11 - 水质预测 + print("\n示例7: 独立运行步骤11 - 水质预测") try: - predictions = pipeline.step8_predict_water_quality( + predictions = pipeline.step11_ml_prediction( sampling_csv_path="path/to/sampling_spectra.csv", models_dir="path/to/models_directory", skip_dependency_check=True ) print(f"预测结果: {predictions}") except Exception as e: - print(f"步骤8失败: {e}") + print(f"步骤11失败: {e}") - # 示例8: 独立运行步骤9 - 分布图生成 - print("\n示例8: 独立运行步骤9 - 分布图生成") + # 示例8: 独立运行步骤14 - 分布图生成 + print("\n示例8: 独立运行步骤14 - 分布图生成") try: - distribution_map = pipeline.step9_generate_distribution_map( + distribution_map = pipeline.step14_distribution_map( prediction_csv_path="path/to/prediction_results.csv", boundary_shp_path="path/to/boundary.shp", skip_dependency_check=True ) print(f"分布图: {distribution_map}") except Exception as e: - print(f"步骤9失败: {e}") + print(f"步骤14失败: {e}") print("\n" + "="*80) print("独立步骤运行示例完成") diff --git a/src/gui/core/pipeline_mode_dialog.py b/src/gui/core/pipeline_mode_dialog.py new file mode 100644 index 0000000..8fa1a7a --- /dev/null +++ b/src/gui/core/pipeline_mode_dialog.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +""" +PipelineModeDialog:全流程运行前的模式选择弹窗。 + +用户点击"运行完整流程"后,首先弹出此弹窗选择执行模式: + - 选项 A(训练新模型并预测):执行完整建模与预测流程,需要实测水质 CSV + - 选项 B(使用已有模型直接预测):跳过训练步骤,直接使用外部模型目录进行预测 + +弹窗结果: + - QDialog.Accepted + self.selected_mode = "training" 或 "prediction_only" + - QDialog.Rejected → 调用方中止 run_full_pipeline +""" + +import os +from typing import Optional + +from PyQt5.QtCore import Qt +from PyQt5.QtGui import QFont +from PyQt5.QtWidgets import ( + QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, + QRadioButton, 
QGroupBox, QButtonGroup, QMessageBox, QSizePolicy, +) + + +def _is_valid_model_dir(path: str) -> bool: + """深层递归检测模型目录:只要任意层级存在文件即返回 True。""" + if not path or not os.path.isdir(path): + return False + for _root, _dirs, files in os.walk(path): + if files: + return True + return False + + +class PipelineModeDialog(QDialog): + """全流程模式选择对话框。 + + 两个单选按钮覆盖两种业务场景: + - A:训练新模型(完整流程,需要 step4 CSV) + - B:仅预测(跳过 step4/5/7/8,直接用外部模型目录) + + 属性: + selected_mode: "training" | "prediction_only" + """ + + def __init__(self, main_window=None, parent=None): + super().__init__(parent) + self.main_window = main_window + self.selected_mode: Optional[str] = None + self.setWindowTitle("选择运行模式") + self.setMinimumSize(560, 340) + self.setModal(True) + self._setup_ui() + + # ------------------------------------------------------------------ + # UI 构建 + # ------------------------------------------------------------------ + + def _setup_ui(self): + layout = QVBoxLayout(self) + layout.setContentsMargins(28, 24, 28, 20) + layout.setSpacing(14) + + # ── 标题 ── + title = QLabel("请选择全流程运行模式") + title_font = QFont() + title_font.setPointSize(13) + title_font.setBold(True) + title.setFont(title_font) + title.setAlignment(Qt.AlignCenter) + layout.addWidget(title) + + layout.addSpacing(4) + + # ── 选项 A:训练新模型 ── + group_a = QGroupBox() + group_a.setObjectName("groupA") + group_a.setMinimumHeight(100) + layout.addWidget(group_a) + + self.radio_a = QRadioButton("【训练新模型并预测】") + self.radio_a.setChecked(True) # 默认选项 A + self.radio_a.setObjectName("radioTraining") + + desc_a = QLabel( + "需要提供实测水质数据 (CSV),将执行完整建模与预测流程。\n" + "包括:水域掩膜 → 耀斑去除 → 光谱特征提取 → 模型训练 → 密集采样 → 预测 → 专题图" + ) + desc_a.setWordWrap(True) + desc_a.setStyleSheet("color: #555555; background: transparent;") + desc_a.setObjectName("descA") + + vbox_a = QVBoxLayout(group_a) + vbox_a.setContentsMargins(16, 20, 16, 14) + vbox_a.setSpacing(8) + vbox_a.addWidget(self.radio_a) + vbox_a.addWidget(desc_a) + + # ── 选项 B:仅预测 ── + group_b = QGroupBox() + 
group_b.setObjectName("groupB") + group_b.setMinimumHeight(100) + layout.addWidget(group_b) + + self.radio_b = QRadioButton("【使用已有模型直接预测】") + self.radio_b.setObjectName("radioPrediction") + + desc_b = QLabel( + "跳过模型训练步骤,直接使用导入的外部模型目录进行预测。\n" + "前提条件:请在「监督预测」或「回归预测」面板中指定模型目录。\n" + "适用范围:已有预训练模型、或其他来源模型目录。" + ) + desc_b.setWordWrap(True) + desc_b.setStyleSheet("color: #555555; background: transparent;") + desc_b.setObjectName("descB") + + vbox_b = QVBoxLayout(group_b) + vbox_b.setContentsMargins(16, 20, 16, 14) + vbox_b.setSpacing(8) + vbox_b.addWidget(self.radio_b) + vbox_b.addWidget(desc_b) + + # ── 强制互斥:QButtonGroup ── + self.mode_group = QButtonGroup(self) + self.mode_group.addButton(self.radio_a) + self.mode_group.addButton(self.radio_b) + + # ── 提示栏(动态显示 models_dir 状态) ── + self.models_hint = QLabel() + self.models_hint.setObjectName("modelsHint") + self.models_hint.setWordWrap(True) + self.models_hint.setStyleSheet("color: #888888; font-size: 11px; padding: 4px 0;") + layout.addWidget(self.models_hint) + + # ── 强制 QRadioButton 指示器为实心圆点 ── + self.setStyleSheet(""" + QRadioButton::indicator { + width: 14px; + height: 14px; + } + QRadioButton::indicator:checked { + background-color: #0078D7; + border: 2px solid #0078D7; + border-radius: 7px; + } + QRadioButton::indicator:unchecked { + background-color: white; + border: 2px solid #A0A0A0; + border-radius: 7px; + } + """) + + # ── 按钮 ── + btn_layout = QHBoxLayout() + btn_layout.addStretch() + + cancel_btn = QPushButton("取消") + cancel_btn.setObjectName("cancelBtn") + cancel_btn.setMinimumWidth(90) + cancel_btn.clicked.connect(self.reject) + + self.btn_confirm = QPushButton("确认") + self.btn_confirm.setObjectName("confirmBtn") + self.btn_confirm.setMinimumWidth(90) + self.btn_confirm.setDefault(True) + self.btn_confirm.clicked.connect(self._on_confirm) + + btn_layout.addWidget(self.btn_confirm) + btn_layout.addWidget(cancel_btn) + layout.addLayout(btn_layout) + + # 信号连接:任一 radio 切换时重新渲染提示 + 按钮状态 + 
self.radio_a.toggled.connect(self._update_models_hint) + self.radio_b.toggled.connect(self._update_models_hint) + + # 初始状态渲染 + self._update_models_hint() + + def _update_models_hint(self, checked=False, *args) -> None: + """根据当前选中模式和 models_dir 状态更新提示文字及确认按钮可用性。""" + training_checked = self.radio_a.isChecked() + + # 从主窗口 config 读取 models_dir(优先 ml,其次 reg) + models_dir = "" + if self.main_window: + config = self.main_window.get_current_config() + models_dir = config.get("step11_ml", {}).get("models_dir", "") + if not models_dir: + models_dir = config.get("step11", {}).get("models_dir", "") + + has_files = bool(models_dir and _is_valid_model_dir(models_dir)) + dir_exists = bool(models_dir and os.path.isdir(models_dir)) + + if training_checked: + if hasattr(self, 'btn_confirm') and self.btn_confirm is not None: + self.btn_confirm.setEnabled(True) + if has_files: + self.models_hint.setText( + f"⚠ 注意:当前模型目录已包含文件,继续训练将会【覆盖】原有模型!\n路径:{models_dir}" + ) + self.models_hint.setStyleSheet("color: #e65100; font-size: 11px; padding: 4px 0;") + else: + label = f"✓ 模型将保存至该目录(当前为空,安全)。\n路径:{models_dir}" if dir_exists else "✓ 尚未指定模型目录,将使用默认路径创建新模型。" + self.models_hint.setText(label) + self.models_hint.setStyleSheet("color: #2e7d32; font-size: 11px; padding: 4px 0;") + else: + if has_files: + self.models_hint.setText( + f"✓ 已检测到有效模型目录,可以直接预测。\n路径:{models_dir}" + ) + self.models_hint.setStyleSheet("color: #2e7d32; font-size: 11px; padding: 4px 0;") + if hasattr(self, 'btn_confirm') and self.btn_confirm is not None: + self.btn_confirm.setEnabled(True) + else: + if dir_exists: + self.models_hint.setText( + f"❌ 错误:模型目录为空(未找到任何文件),无法进行预测!\n路径:{models_dir}" + ) + else: + self.models_hint.setText( + "❌ 错误:模型目录为空或不存在!请先返回对应面板配置有效路径。" + ) + self.models_hint.setStyleSheet("color: #c62828; font-size: 11px; padding: 4px 0;") + if hasattr(self, 'btn_confirm') and self.btn_confirm is not None: + self.btn_confirm.setEnabled(False) + + def _on_confirm(self) -> None: + """确认按钮回调:直接存储模式并关闭。 + + 
注意:按钮禁用状态已在 _update_models_hint 中处理, + 此处仅负责结果存储,不再做二次弹窗拦截。 + """ + if self.radio_a.isChecked(): + self.selected_mode = "training" + else: + self.selected_mode = "prediction_only" + self.accept() \ No newline at end of file diff --git a/src/gui/dialogs.py b/src/gui/dialogs.py index 04ecc8a..3c95a6f 100644 --- a/src/gui/dialogs.py +++ b/src/gui/dialogs.py @@ -349,6 +349,12 @@ class AISettingsDialog(QDialog): # 交互式采样点与光谱查看器 # ───────────────────────────────────────────────────────────────────────────── +import matplotlib.pyplot as _plt + +# 全局字体设置(防中文乱码 + 负号显示异常) +_plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS'] +_plt.rcParams['axes.unicode_minus'] = False + class SamplingViewerDialog(QDialog): """交互式采样点与光谱查看器 @@ -381,8 +387,8 @@ class SamplingViewerDialog(QDialog): self._fig = Figure(figsize=(6, 5)) self._canvas = FigureCanvasQTAgg(self._fig) self._ax_scatter = self._fig.add_subplot(111) - self._ax_scatter.set_xlabel("pixel_x") - self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_xlabel("像素 X") + self._ax_scatter.set_ylabel("像素 Y") self._ax_scatter.set_title("采样点分布(点击查看详情)") self._ax_scatter.invert_yaxis() self._fig.tight_layout() @@ -392,18 +398,20 @@ class SamplingViewerDialog(QDialog): # --- 右侧:信息面板 + 光谱子图 --- right_widget = QWidget() right_layout = QVBoxLayout(right_widget) + right_layout.setContentsMargins(0, 0, 0, 0) + # 坐标信息面板(多行中文清晰显示) self._info_label = QLabel("点击左侧散点图选择采样点") self._info_label.setStyleSheet( - "QLabel { background-color: #f0f0f0; padding: 6px; " - "border: 1px solid #ccc; font-size: 13px; }" + "QLabel { background-color: #f0f0f0; padding: 8px; " + "border: 1px solid #ccc; border-radius: 4px; font-size: 13px; }" ) right_layout.addWidget(self._info_label) # 采样点列表迷你表格 self._point_table = QTableWidget() self._point_table.setColumnCount(3) - self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"]) + self._point_table.setHorizontalHeaderLabels(["像素 X", "像素 Y", "序号"]) 
self._point_table.setMaximumHeight(120) self._point_table.setEditTriggers(QTableWidget.NoEditTriggers) self._point_table.setSelectionBehavior(QTableWidget.SelectRows) @@ -413,8 +421,8 @@ class SamplingViewerDialog(QDialog): # 光谱曲线子图 self._fig_right = Figure(figsize=(5, 3)) self._ax_spectrum = self._fig_right.add_subplot(111) - self._ax_spectrum.set_xlabel("Band Index") - self._ax_spectrum.set_ylabel("Reflectance") + self._ax_spectrum.set_xlabel("波段序号") + self._ax_spectrum.set_ylabel("反射率") self._ax_spectrum.set_title("光谱曲线") self._fig_right.tight_layout() @@ -440,8 +448,8 @@ class SamplingViewerDialog(QDialog): def _draw_scatter(self): """绘制散点图""" self._ax_scatter.clear() - self._ax_scatter.set_xlabel("pixel_x") - self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_xlabel("像素 X") + self._ax_scatter.set_ylabel("像素 Y") self._ax_scatter.set_title("采样点分布(点击查看详情)") self._ax_scatter.invert_yaxis() @@ -497,14 +505,14 @@ class SamplingViewerDialog(QDialog): self._info_label.setText( f"选中的采样点 #{nearest_idx}
" - f"pixel_x = {pixel_x}   pixel_y = {pixel_y}
" - f"x_coord = {x_coord}   y_coord = {y_coord}" + f"图像像素坐标: X = {pixel_x}, Y = {pixel_y}
" + f"地理真实坐标: 经度(X) = {x_coord}, 纬度(Y) = {y_coord}" ) # 高亮散点图 self._ax_scatter.clear() - self._ax_scatter.set_xlabel("pixel_x") - self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_xlabel("像素 X") + self._ax_scatter.set_ylabel("像素 Y") self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)") self._ax_scatter.invert_yaxis() self._ax_scatter.scatter( @@ -523,8 +531,8 @@ class SamplingViewerDialog(QDialog): def _draw_spectrum(self, row: pd.Series): """从一行数据中提取纯波段数值并绘图""" self._ax_spectrum.clear() - self._ax_spectrum.set_xlabel("Band Index") - self._ax_spectrum.set_ylabel("Reflectance") + self._ax_spectrum.set_xlabel("波段序号") + self._ax_spectrum.set_ylabel("反射率") self._ax_spectrum.set_title("光谱曲线") exclude_patterns = ( diff --git a/src/gui/panels/step10_panel.py b/src/gui/panels/step10_panel.py new file mode 100644 index 0000000..5e6150d --- /dev/null +++ b/src/gui/panels/step10_panel.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step10 面板 - 采样点生成 +""" + +import os + +from PyQt5.QtWidgets import ( + QWidget, QVBoxLayout, QGroupBox, QFormLayout, + QPushButton, QCheckBox, QSpinBox, QMessageBox, +) + +from src.gui.components.custom_widgets import FileSelectWidget +from src.gui.dialogs import SamplingViewerDialog +from src.gui.styles import ModernStylesheet + + +class Step10Panel(QWidget): + """步骤10:采样点生成""" + def __init__(self, parent=None): + super().__init__(parent) + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + # 去耀斑影像文件(用于独立运行) + self.deglint_img_file = FileSelectWidget( + "去耀斑影像:", + "Image Files (*.bsq *.dat *.tif);;All Files (*.*)" + ) + layout.addWidget(self.deglint_img_file) + + # 水域掩膜文件(可选,用于独立运行) + self.water_mask_file = FileSelectWidget( + "水域掩膜:", + "Mask Files (*.dat *.tif);;All Files (*.*)" + ) + self.water_mask_file.label.setText("水域掩膜:") + layout.addWidget(self.water_mask_file) + + # 参数设置 + params_group = QGroupBox("采样参数") + params_layout = QFormLayout() + + self.interval = QSpinBox() + 
self.interval.setRange(10, 500) + self.interval.setValue(50) + params_layout.addRow("采样点间隔(像素):", self.interval) + + self.sample_radius = QSpinBox() + self.sample_radius.setRange(1, 50) + self.sample_radius.setValue(5) + params_layout.addRow("采样半径(像素):", self.sample_radius) + + self.chunk_size = QSpinBox() + self.chunk_size.setRange(100, 10000) + self.chunk_size.setValue(1000) + params_layout.addRow("处理块大小:", self.chunk_size) + + self.use_adaptive_sampling = QCheckBox("启用自适应采样") + self.use_adaptive_sampling.setChecked(True) + params_layout.addRow("采样模式:", self.use_adaptive_sampling) + + params_group.setLayout(params_layout) + layout.addWidget(params_group) + + # 输出文件路径 + self.output_file = FileSelectWidget( + "输出采样点:", + "CSV Files (*.csv);;All Files (*.*)" + ) + self.output_file.line_edit.setPlaceholderText("sampling_points.csv") + layout.addWidget(self.output_file) + + # 启用步骤 + self.enable_checkbox = QCheckBox("启用此步骤") + self.enable_checkbox.setChecked(True) + layout.addWidget(self.enable_checkbox) + + # 独立运行按钮 + self.run_btn = QPushButton("独立运行此步骤") + self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) + self.run_btn.clicked.connect(self.run_step) + layout.addWidget(self.run_btn) + + # 交互式预览按钮 + self.preview_btn = QPushButton("📊 交互式预览采样点与光谱") + self.preview_btn.setEnabled(False) + self.preview_btn.clicked.connect(self._open_sampling_viewer) + layout.addWidget(self.preview_btn) + + layout.addStretch() + self.setLayout(layout) + + # 监听输出路径变化,实时更新预览按钮状态 + self.output_file.line_edit.textChanged.connect(self._on_output_changed) + + def get_config(self): + """获取配置""" + config = { + 'interval': self.interval.value(), + 'sample_radius': self.sample_radius.value(), + 'chunk_size': self.chunk_size.value(), + 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(), + } + deglint_img_path = self.deglint_img_file.get_path() + if deglint_img_path: + config['deglint_img_path'] = deglint_img_path + water_mask_path = 
self.water_mask_file.get_path() + if water_mask_path: + config['water_mask_path'] = water_mask_path + return config + + def set_config(self, config): + """Apply saved configuration values to the panel widgets.""" + if 'interval' in config: + self.interval.setValue(config['interval']) + if 'sample_radius' in config: + self.sample_radius.setValue(config['sample_radius']) + if 'chunk_size' in config: + self.chunk_size.setValue(config['chunk_size']) + if 'use_adaptive_sampling' in config: + self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling']) + if 'deglint_img_path' in config: + self.deglint_img_file.set_path(config['deglint_img_path']) + if 'water_mask_path' in config: + self.water_mask_file.set_path(config['water_mask_path']) + if 'glint_mask_path' in config and hasattr(self, 'glint_mask_file'):  # init_ui creates no glint_mask_file widget; skip the key instead of raising AttributeError + self.glint_mask_file.set_path(config['glint_mask_path']) + + def update_from_config(self, work_dir=None, pipeline=None): + """Auto-fill the deglint image and mask paths from the global configuration. + + Args: + work_dir: Working directory path. + pipeline: Pipeline instance (used to read paths from its step_outputs). + """ + if work_dir: + self.work_dir = work_dir + elif hasattr(self, 'work_dir') and self.work_dir: + pass + else: + self.work_dir = None + + main_window = self.window() + + # 1. Fill the deglint image path (prefer the value from pipeline.step_outputs) + deglint_path = None + if pipeline and hasattr(pipeline, 'step_outputs'): + step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {}) + deglint_path = ( + step3_outputs.get('deglint_image') + or step3_outputs.get('output_path') + or step3_outputs.get('output_file') + or step3_outputs.get('deglint_img_path') + ) + # Fallback: read directly from the step3 panel widget (may be a relative path) + if not deglint_path and hasattr(main_window, 'step3_panel'): + deglint_path = main_window.step3_panel.output_file.get_path() + + if deglint_path: + # If the path is relative, join it onto work_dir + if not os.path.isabs(deglint_path): + deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/') + self.deglint_img_file.set_path(deglint_path) + + # 2.
填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test) + water_mask_path = None + if pipeline and hasattr(pipeline, 'step_outputs'): + step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {}) + water_mask_path = ( + step1_outputs.get('water_mask') + or step1_outputs.get('output_path') + or step1_outputs.get('output_file') + ) + # 回退:从 step1 面板 widget 直接读取 + if not water_mask_path and hasattr(main_window, 'step1_panel'): + water_mask_path = main_window.step1_panel.output_file.get_path() + # 备选:扫描 1_water_mask 目录下的 .dat 文件 + if not water_mask_path and self.work_dir: + mask_dir = os.path.join(self.work_dir, "1_water_mask") + if os.path.isdir(mask_dir): + dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')] + if dat_files: + water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/') + # 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat) + if not water_mask_path and self.work_dir: + input_test_dir = os.path.join(self.work_dir, "input-test") + if os.path.isdir(input_test_dir): + dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')] + # 优先匹配 water_mask_from_shp.dat + for f in dat_files: + if 'water_mask_from_shp' in f.lower(): + water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/') + break + # 否则取第一个 .dat 文件 + if not water_mask_path and dat_files: + water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/') + + if water_mask_path: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(water_mask_path): + water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/') + self.water_mask_file.set_path(water_mask_path) + + # 3. 自动填充输出路径(绝对路径) + if self.work_dir: + output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv") + os.makedirs(os.path.dirname(output_path), exist_ok=True) + self.output_file.set_path(output_path.replace('\\', '/')) + + # 4. 
同步更新预览按钮状态(路径可能已自动填充) + self._check_csv_exists() + + def run_step(self): + """独立运行步骤10""" + deglint_img_path = self.deglint_img_file.get_path() + if not deglint_img_path: + QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!") + return + + main_window = self.window() + if hasattr(main_window, 'run_single_step'): + config = {'step10': self.get_config()} + main_window.run_single_step('step10', config) + + def _check_csv_exists(self): + """检查 output csv 是否存在,驱动预览按钮启停""" + csv_path = self.output_file.get_path() + enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path)) + self.preview_btn.setEnabled(enabled) + return enabled + + def _on_output_changed(self, _text=None): + """输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)""" + self._check_csv_exists() + + def _open_sampling_viewer(self): + """打开交互式采样点查看器弹窗""" + csv_path = self.output_file.get_path() + if not csv_path or not os.path.exists(csv_path): + QMessageBox.warning( + self, "文件不存在", + f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤10生成数据。" + ) + return + dialog = SamplingViewerDialog(csv_path, self) + dialog.exec_() + # 弹窗关闭后再次检查状态(可能文件被覆盖等) + self._check_csv_exists() diff --git a/src/gui/panels/step11_ml_panel.py b/src/gui/panels/step11_ml_panel.py new file mode 100644 index 0000000..8881f1e --- /dev/null +++ b/src/gui/panels/step11_ml_panel.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step8 面板 - 机器学习预测 +""" + +import os +from pathlib import Path + +from PyQt5.QtWidgets import ( + QWidget, QVBoxLayout, QGroupBox, QFormLayout, + QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox, + QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout, + QListWidgetItem, +) +from PyQt5.QtCore import Qt + +from src.gui.components.custom_widgets import FileSelectWidget +from src.gui.styles import ModernStylesheet + + +class Step11MlPanel(QWidget): + """步骤11:机器学习预测""" + def __init__(self, parent=None): + super().__init__(parent) + self.external_models_dict = {} # 
{subdir_name: model_obj, ...} + self.external_model_dir = "" # 母文件夹路径(隐藏) + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + # -------- 模型来源选择(单选按钮组) -------- + source_group = QGroupBox("模型来源") + source_layout = QVBoxLayout() + + self.use_trained_model = QRadioButton("使用当前训练流程的模型") + self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)") + self.use_trained_model.setChecked(True) + source_layout.addWidget(self.use_trained_model) + source_layout.addWidget(self.use_external_model) + + self.use_trained_model.toggled.connect(self._on_model_source_changed) + self.use_external_model.toggled.connect(self._on_model_source_changed) + + source_group.setStyleSheet(""" + QRadioButton { + font-size: 13px; + spacing: 8px; + } + QRadioButton::indicator { + width: 16px; + height: 16px; + border-radius: 9px; + border: 2px solid #A0A0A0; + background-color: #FFFFFF; + } + QRadioButton::indicator:hover { + border: 2px solid #0078D7; + } + QRadioButton::indicator:checked { + background-color: #0078D7; + border: 2px solid #0078D7; + } + """) + + source_group.setLayout(source_layout) + layout.addWidget(source_group) + + # -------- 外部模型文件选择(条件显示) -------- + self.external_model_widget = FileSelectWidget( + "模型母文件夹:", + "Directories" + ) + self.external_model_widget.browse_btn.clicked.disconnect() + self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir) + self.external_model_widget.setVisible(False) + layout.addWidget(self.external_model_widget) + + # -------- 已扫描模型列表(条件显示) -------- + self.model_list_group = QGroupBox("选择参与预测的模型") + self.model_list_group.setVisible(False) + model_list_layout = QVBoxLayout() + + self.model_list = QListWidget() + self.model_list.setMaximumHeight(130) + self.model_list.setSelectionMode(QAbstractItemView.NoSelection) + self.model_list.setStyleSheet(""" + QListWidget { + border: 1px solid #C0C0C0; + border-radius: 4px; + background-color: #FFFFFF; + font-size: 12px; + } + QListWidget::item { + padding: 4px 
6px; + border-bottom: 1px solid #F0F0F0; + } + QListWidget::item:selected { + background-color: transparent; + } + """) + model_list_layout.addWidget(self.model_list) + + btn_row = QHBoxLayout() + self.btn_select_all = QPushButton("全选") + self.btn_select_all.setMaximumWidth(80) + self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) + self.btn_select_all.clicked.connect(self._select_all_models) + + self.btn_select_none = QPushButton("全不选") + self.btn_select_none.setMaximumWidth(80) + self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) + self.btn_select_none.clicked.connect(self._select_none_models) + + btn_row.addWidget(self.btn_select_all) + btn_row.addWidget(self.btn_select_none) + btn_row.addStretch() + model_list_layout.addLayout(btn_row) + + self.model_list_group.setLayout(model_list_layout) + layout.addWidget(self.model_list_group) + + # -------- 采样光谱CSV文件(用于独立运行)-------- + self.sampling_csv_file = FileSelectWidget( + "采样光谱CSV:", + "CSV Files (*.csv);;All Files (*.*)" + ) + layout.addWidget(self.sampling_csv_file) + + # 模型目录(用于独立运行) + self.models_dir_file = FileSelectWidget( + "模型目录:", + "Directories;;All Files (*.*)" + ) + self.models_dir_file.label.setText("模型目录:") + self.models_dir_file.browse_btn.clicked.disconnect() + self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir) + layout.addWidget(self.models_dir_file) + + # 参数设置 + params_group = QGroupBox("预测参数") + params_layout = QFormLayout() + + self.metric = QComboBox() + self.metric.addItems(['test_r2', 'test_rmse', 'test_mae']) + params_layout.addRow("模型选择指标:", self.metric) + + self.prediction_column = QLineEdit() + self.prediction_column.setText("prediction") + params_layout.addRow("预测列名:", self.prediction_column) + + params_group.setLayout(params_layout) + layout.addWidget(params_group) + + # 输出路径 + self.output_file = FileSelectWidget( + "输出路径:", + "CSV Files (*.csv);;All Files (*.*)" + ) + 
layout.addWidget(self.output_file) + + # 启用步骤 + self.enable_checkbox = QCheckBox("启用此步骤") + self.enable_checkbox.setChecked(True) + layout.addWidget(self.enable_checkbox) + + # 独立运行按钮 + self.run_btn = QPushButton("独立运行此步骤") + self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) + self.run_btn.clicked.connect(self.run_step) + layout.addWidget(self.run_btn) + + layout.addStretch() + self.setLayout(layout) + + def _on_model_source_changed(self, checked: bool): + """单选按钮切换:控制外部模型文件选择控件的显示/隐藏""" + if not checked: + return + is_external = self.use_external_model.isChecked() + self.external_model_widget.setVisible(is_external) + self.model_list_group.setVisible(is_external) + if not is_external: + self.external_models_dict = {} + self.external_model_dir = "" + self._clear_model_list() + + def _scan_external_model_dir(self): + """浏览模型母文件夹,自动扫描子目录中的 .joblib 文件""" + default = self._get_default_work_dir() + if default: + default = os.path.join(default, "7_Supervised_Model_Training") + dir_path = QFileDialog.getExistingDirectory( + self, + "选择模型母文件夹", + default, + ) + if not dir_path: + return + + self.external_model_dir = dir_path + models_found = {} + errors = [] + + try: + import joblib + + for subentry in os.scandir(dir_path): + if not subentry.is_dir(): + continue + subdir_name = subentry.name + joblib_files = [ + f for f in os.scandir(subentry.path) + if f.is_file() and f.name.lower().endswith(".joblib") + ] + if not joblib_files: + continue + # 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致) + joblib_path = joblib_files[0].path + try: + loaded = joblib.load(joblib_path) + if isinstance(loaded, dict) and "model" in loaded: + model_obj = loaded["model"] + elif hasattr(loaded, "predict"): + model_obj = loaded + else: + errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}") + continue + models_found[subdir_name] = model_obj + except Exception as e: + errors.append(f"{subdir_name}: {type(e).__name__}: {e}") + + except Exception as e: + 
QMessageBox.warning( + self, + "扫描失败", + f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}", + ) + return + + if not models_found: + QMessageBox.warning( + self, + "未找到模型", + f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n" + "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。", + ) + self.external_model_widget.set_path("") + self.external_models_dict = {} + self._clear_model_list() + return + + self.external_models_dict = models_found + self._populate_model_list(models_found) + names = sorted(models_found.keys()) + display = f"已识别到 {len(names)} 个模型: {', '.join(names)}" + self.external_model_widget.set_path(display) + self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;") + + err_lines = "\n".join(errors) if errors else "无" + QMessageBox.information( + self, + "模型扫描完成", + f"成功加载 {len(models_found)} 个模型:\n{display}\n\n" + f"加载失败 {len(errors)} 个:\n{err_lines}", + ) + + def _populate_model_list(self, models_dict): + """将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选""" + self.model_list.clear() + for name in sorted(models_dict.keys()): + item = QListWidgetItem(name) + item.setFlags(item.flags() | Qt.ItemIsUserCheckable) + item.setCheckState(Qt.Checked) + self.model_list.addItem(item) + + def _clear_model_list(self): + """清空模型列表""" + self.model_list.clear() + + def _select_all_models(self): + """全选:设置所有条目为 Checked""" + for i in range(self.model_list.count()): + self.model_list.item(i).setCheckState(Qt.Checked) + + def _select_none_models(self): + """全不选:设置所有条目为 Unchecked""" + for i in range(self.model_list.count()): + self.model_list.item(i).setCheckState(Qt.Unchecked) + + def _get_checked_models_dict(self): + """从列表中提取用户勾选的模型,组装成字典返回""" + result = {} + for i in range(self.model_list.count()): + item = self.model_list.item(i) + if item.checkState() == Qt.Checked: + name = item.text() + if name in self.external_models_dict: + result[name] = self.external_models_dict[name] + return result + + def update_from_config(self, work_dir=None, pipeline=None): + """从全局配置自动填充采样光谱和模型目录 + + 
Args: + work_dir: 工作目录路径 + pipeline: Pipeline 实例(未使用,保留接口兼容性) + """ + try: + import traceback + + if work_dir: + self.work_dir = work_dir + elif hasattr(self, 'work_dir') and self.work_dir: + pass + else: + self.work_dir = None + + main_window = self.window() + + # 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径 + if main_window and hasattr(main_window, 'step10_panel'): + step7_widget = getattr(main_window.step10_panel, 'output_file', None) + step7_output_path = "" + if hasattr(step7_widget, 'get_path'): + step7_output_path = step7_widget.get_path() or "" + elif hasattr(step7_widget, 'text'): + step7_output_path = step7_widget.text() or "" + + if step7_output_path: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step7_output_path): + step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/') + existing = self.sampling_csv_file.get_path() + if not existing or not existing.strip(): + self.sampling_csv_file.set_path(step7_output_path) + + # 2. 尝试从 Step6 界面读取监督模型目录 + if main_window and hasattr(main_window, 'step7_panel'): + step6_widget = getattr(main_window.step7_panel, 'output_dir', None) + step6_models_dir = "" + if hasattr(step6_widget, 'get_path'): + step6_models_dir = step6_widget.get_path() or "" + elif hasattr(step6_widget, 'text'): + step6_models_dir = step6_widget.text() or "" + + if step6_models_dir: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step6_models_dir): + step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/') + existing_models = self.models_dir_file.get_path() + if not existing_models or not existing_models.strip(): + self.models_dir_file.set_path(step6_models_dir) + + # 3. 
自动填充输出路径(机器学习预测目录) + if self.work_dir: + output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction") + os.makedirs(output_dir, exist_ok=True) + existing_out = self.output_file.get_path() + if not existing_out or not existing_out.strip(): + self.output_file.set_path(output_dir) + except Exception as e: + import traceback + print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") + traceback.print_exc() + + def _get_default_work_dir(self): + """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" + if hasattr(self, 'work_dir') and self.work_dir: + return str(self.work_dir) + mw = self.window() + if mw and hasattr(mw, 'work_dir') and mw.work_dir: + return str(mw.work_dir) + return "" + + def browse_models_dir(self): + """浏览模型目录""" + default = self._get_default_work_dir() + if default: + default = os.path.join(default, "7_Supervised_Model_Training") + dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default) + if dir_path: + self.models_dir_file.set_path(dir_path) + + def get_config(self): + """获取配置""" + config = { + 'metric': self.metric.currentText(), + 'prediction_column': self.prediction_column.text(), + } + sampling_csv_path = self.sampling_csv_file.get_path() + if sampling_csv_path: + config['sampling_csv_path'] = sampling_csv_path + models_dir = self.models_dir_file.get_path() + if models_dir: + config['models_dir'] = models_dir + output_path = self.output_file.get_path() + if output_path: + config['output_path'] = output_path + return config + + def set_config(self, config): + """设置配置""" + if 'metric' in config: + idx = self.metric.findText(config['metric']) + if idx >= 0: + self.metric.setCurrentIndex(idx) + if 'prediction_column' in config: + self.prediction_column.setText(config['prediction_column']) + if 'sampling_csv_path' in config: + self.sampling_csv_file.set_path(config['sampling_csv_path']) + if 'models_dir' in config: + self.models_dir_file.set_path(config['models_dir']) + if 'output_path' in config: + 
self.output_file.set_path(config['output_path']) + + def run_step(self): + """独立运行步骤8""" + sampling_csv_path = self.sampling_csv_file.get_path() + if not sampling_csv_path: + QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!") + return + + # 外部模型优先:用户选择了"导入本地预训练模型" + if self.use_external_model.isChecked(): + if not self.external_models_dict: + QMessageBox.warning( + self, + "模型未加载", + "请先点击「浏览...」按钮选择模型母文件夹!", + ) + return + # 只传递用户勾选的模型 + checked_dict = self._get_checked_models_dict() + if not checked_dict: + QMessageBox.warning( + self, + "未选择模型", + "请至少勾选一个模型参与预测!", + ) + return + main_window = self.window() + if hasattr(main_window, 'run_single_step'): + config = { + 'step11_ml': self.get_config(), + '_external_models_dict': checked_dict, + '_external_model_dir': self.external_model_dir, + } + main_window.run_single_step('step11_ml', config) + return + + # 默认流程:使用模型目录 + models_dir = self.models_dir_file.get_path() + if not models_dir: + QMessageBox.warning(self, "输入错误", "请选择模型目录!") + return + + main_window = self.window() + if hasattr(main_window, 'run_single_step'): + config = {'step11_ml': self.get_config()} + main_window.run_single_step('step11_ml', config) diff --git a/src/gui/panels/step8_5_panel.py b/src/gui/panels/step11_panel.py similarity index 86% rename from src/gui/panels/step8_5_panel.py rename to src/gui/panels/step11_panel.py index 74e8520..ce78fb9 100644 --- a/src/gui/panels/step8_5_panel.py +++ b/src/gui/panels/step11_panel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step8_5 面板 - 非经验模型预测 +Step11 面板 - 非经验模型预测 """ import os @@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet -class Step8_5Panel(QWidget): - """步骤8.5:非经验模型预测""" +class Step11Panel(QWidget): + """步骤11:非经验模型预测""" def __init__(self, parent=None): super().__init__(parent) self.init_ui() @@ -118,22 +118,22 @@ class Step8_5Panel(QWidget): if not existing or not existing.strip(): 
self.sampling_csv_file.set_path(step7_output_path) - # 2. 尝试从 Step6.5 界面读取回归模型目录 - if main_window and hasattr(main_window, 'step6_5_panel'): - step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None) - step6_5_models_dir = "" - if hasattr(step6_5_widget, 'get_path'): - step6_5_models_dir = step6_5_widget.get_path() or "" - elif hasattr(step6_5_widget, 'text'): - step6_5_models_dir = step6_5_widget.text() or "" + # 2. 尝试从 Step8_Non_Empirical 界面读取回归模型目录 + if main_window and hasattr(main_window, 'step8_non_empirical_panel'): + step8_non_empirical_widget = getattr(main_window.step8_non_empirical_panel, 'output_dir', None) + step8_non_empirical_models_dir = "" + if hasattr(step8_non_empirical_widget, 'get_path'): + step8_non_empirical_models_dir = step8_non_empirical_widget.get_path() or "" + elif hasattr(step8_non_empirical_widget, 'text'): + step8_non_empirical_models_dir = step8_non_empirical_widget.text() or "" - if step6_5_models_dir: + if step8_non_empirical_models_dir: # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step6_5_models_dir): - step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/') + if not os.path.isabs(step8_non_empirical_models_dir): + step8_non_empirical_models_dir = os.path.join(self.work_dir or '', step8_non_empirical_models_dir).replace('\\', '/') existing_models = self.models_dir_file.get_path() if not existing_models or not existing_models.strip(): - self.models_dir_file.set_path(step6_5_models_dir) + self.models_dir_file.set_path(step8_non_empirical_models_dir) # 3. 
自动填充输出路径(非经验模型预测目录) if self.work_dir: @@ -208,7 +208,7 @@ class Step8_5Panel(QWidget): self.enable_checkbox.setChecked(config['enabled']) def run_step(self): - """独立运行步骤8.5""" + """独立运行步骤11""" sampling_csv_path = self.sampling_csv_file.get_path() if not sampling_csv_path: QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!") @@ -221,6 +221,6 @@ class Step8_5Panel(QWidget): parent = parent.parent() if parent and hasattr(parent, 'run_single_step'): - parent.run_single_step('step8_5', {'step8_5': config}) + parent.run_single_step('step11', {'step11': config}) else: QMessageBox.critical(self, "错误", "无法找到父级GUI对象") diff --git a/src/gui/panels/step8_75_panel.py b/src/gui/panels/step12_panel.py similarity index 88% rename from src/gui/panels/step8_75_panel.py rename to src/gui/panels/step12_panel.py index fb37d53..ab97b71 100644 --- a/src/gui/panels/step8_75_panel.py +++ b/src/gui/panels/step12_panel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step8_75 面板 - 自定义回归预测 +Step12 面板 - 自定义回归预测 """ import os @@ -15,8 +15,8 @@ from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet -class Step8_75Panel(QWidget): - """步骤8.75:自定义回归预测""" +class Step12Panel(QWidget): + """步骤12:自定义回归预测""" def __init__(self, parent=None): super().__init__(parent) self.init_ui() @@ -111,25 +111,25 @@ class Step8_75Panel(QWidget): if not existing or not existing.strip(): self.sampling_csv_file.set_path(step7_output_path) - # 2. 尝试从 Step6.75 界面读取自定义回归模型目录 - if main_window and hasattr(main_window, 'step6_75_panel'): - step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None) - step6_75_models_dir = "" - if hasattr(step6_75_widget, 'get_path'): - step6_75_models_dir = step6_75_widget.get_path() or "" - elif hasattr(step6_75_widget, 'text'): - step6_75_models_dir = step6_75_widget.text() or "" - step6_75_models_dir = step6_75_models_dir.strip() + # 2. 
Try to read the custom regression model directory from the Step9 panel + if main_window and hasattr(main_window, 'step9_panel'): + step9_widget = getattr(main_window.step9_panel, 'output_dir', None) + step9_models_dir = "" + if hasattr(step9_widget, 'get_path'): + step9_models_dir = step9_widget.get_path() or "" + elif hasattr(step9_widget, 'text'): + step9_models_dir = step9_widget.text() or "" + step9_models_dir = step9_models_dir.strip() - if step6_75_models_dir: + if step9_models_dir: # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step6_75_models_dir): - step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/') + if not os.path.isabs(step9_models_dir): + step9_models_dir = os.path.join(self.work_dir or '', step9_models_dir).replace('\\', '/') existing_models = self.regression_models_dir.get_path() if not existing_models or not existing_models.strip(): - self.regression_models_dir.set_path(step6_75_models_dir) + self.regression_models_dir.set_path(step9_models_dir) - # 3. 自动填充回归模型目录(如果 step6_75 未提供) + # 3.
自动填充回归模型目录(如果 step9 未提供) if self.work_dir: models_dir = self.regression_models_dir.get_path().strip() if not models_dir: @@ -208,7 +208,7 @@ class Step8_75Panel(QWidget): self.enable_checkbox.setChecked(config['enabled']) def run_step(self): - """独立运行步骤8.75""" + """独立运行步骤12""" sampling_csv_path = self.sampling_csv_file.get_path() if not sampling_csv_path: QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!") @@ -225,6 +225,6 @@ class Step8_75Panel(QWidget): parent = parent.parent() if parent and hasattr(parent, 'run_single_step'): - parent.run_single_step('step8_75', {'step8_75': config}) + parent.run_single_step('step12', {'step12': config}) else: QMessageBox.critical(self, "错误", "无法找到父级GUI对象") diff --git a/src/gui/panels/step14_panel.py b/src/gui/panels/step14_panel.py new file mode 100644 index 0000000..77e9aa7 --- /dev/null +++ b/src/gui/panels/step14_panel.py @@ -0,0 +1,533 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step14 面板 - 分布图生成 +""" + +import os +import traceback +from pathlib import Path +from typing import List, Optional + +from PyQt5.QtCore import Qt, QThread, pyqtSignal +from PyQt5.QtWidgets import ( + QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout, + QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox, + QRadioButton, QButtonGroup, QMessageBox, QFileDialog, +) + +from src.gui.components.custom_widgets import FileSelectWidget +from src.gui.styles import ModernStylesheet + +# Pipeline 可用性(与 core/worker_thread.py 保持一致) +try: + from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline + PIPELINE_AVAILABLE = True +except ImportError: + PIPELINE_AVAILABLE = False + + +class Step14BatchThread(QThread): + """专题图:按文件夹内多个预测 CSV 批量生成分布图。""" + + finished_ok = pyqtSignal(int) + failed = pyqtSignal(str) + log_message = pyqtSignal(str, str) + + def __init__(self, work_dir: str, csv_paths: List[str], step14_kwargs: dict, output_dir_optional: Optional[str]): + super().__init__() + self.work_dir = work_dir + 
self.csv_paths = csv_paths + self.step14_kwargs = step14_kwargs + self.output_dir_optional = (output_dir_optional or "").strip() or None + + def run(self): + mpl_prev = None + try: + import matplotlib + mpl_prev = matplotlib.get_backend() + except Exception: + pass + try: + import matplotlib.pyplot as plt + plt.switch_backend("Agg") + except Exception: + mpl_prev = None + try: + from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline + pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir) + n = len(self.csv_paths) + for i, csv_p in enumerate(self.csv_paths): + self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info") + kw = {**self.step14_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True} + if self.output_dir_optional: + stem = Path(csv_p).stem + kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png") + else: + kw["output_image_path"] = None + pipeline.step9_generate_distribution_map(**kw) + self.finished_ok.emit(n) + except Exception as e: + self.failed.emit(f"{e}\n{traceback.format_exc()}") + finally: + if mpl_prev: + try: + import matplotlib.pyplot as plt + plt.switch_backend(mpl_prev) + except Exception: + pass + + +class Step14Panel(QWidget): + """步骤14:分布图生成""" + def __init__(self, parent=None): + super().__init__(parent) + self._batch_thread = None + self.init_ui() + + def init_ui(self): + layout = QVBoxLayout() + + hint = QLabel( + "独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。" + "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。" + ) + hint.setWordWrap(True) + hint.setStyleSheet( + f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};" + ) + layout.addWidget(hint) + + mode_row = QHBoxLayout() + self.mode_single_rb = QRadioButton("单个 CSV 文件") + self.mode_folder_rb = QRadioButton("文件夹批量") + self._mode_group = QButtonGroup(self) + self._mode_group.addButton(self.mode_single_rb, 0) + self._mode_group.addButton(self.mode_folder_rb, 1) + 
mode_row.addWidget(self.mode_single_rb) + mode_row.addWidget(self.mode_folder_rb) + mode_row.addStretch() + layout.addLayout(mode_row) + + # ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ---------- + radio_style = """ + QRadioButton { + font-size: 14px; + spacing: 8px; + color: #333333; + } + QRadioButton::indicator { + width: 16px; + height: 16px; + border: 2px solid #999999; + border-radius: 3px; + background-color: white; + } + QRadioButton::indicator:checked { + border: 2px solid #0078d4; + background-color: #0078d4; + image: none; + } + QRadioButton::indicator:hover { + border: 2px solid #005a9e; + } + """ + self.mode_single_rb.setStyleSheet(radio_style) + self.mode_folder_rb.setStyleSheet(radio_style) + + self.prediction_csv_file = FileSelectWidget( + "预测结果CSV:", + "CSV Files (*.csv);;All Files (*.*)" + ) + layout.addWidget(self.prediction_csv_file) + + folder_row = QHBoxLayout() + self.prediction_csv_dir_label = QLabel("预测CSV目录:") + self.prediction_csv_dir_label.setMinimumWidth(120) + self.prediction_csv_dir_edit = QLineEdit() + self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…") + pred_dir_btn = QPushButton("浏览…") + pred_dir_btn.setMaximumWidth(80) + pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir) + folder_row.addWidget(self.prediction_csv_dir_label) + folder_row.addWidget(self.prediction_csv_dir_edit, 1) + folder_row.addWidget(pred_dir_btn) + self._folder_row_widget = QWidget() + self._folder_row_widget.setLayout(folder_row) + layout.addWidget(self._folder_row_widget) + + self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)") + layout.addWidget(self.recursive_csv_cb) + + self.boundary_file = FileSelectWidget( + "边界文件:", + "Shapefiles (*.shp);;All Files (*.*)" + ) + layout.addWidget(self.boundary_file) + + # 参数设置 + params_group = QGroupBox("生成参数") + params_layout = QFormLayout() + + self.resolution = QDoubleSpinBox() + self.resolution.setRange(1, 1000) + self.resolution.setValue(30) + params_layout.addRow("分辨率(米):", 
self.resolution) + + self.input_crs = QLineEdit() + self.input_crs.setText("EPSG:32651") + params_layout.addRow("输入坐标系:", self.input_crs) + + self.output_crs = QLineEdit() + self.output_crs.setText("EPSG:4326") + params_layout.addRow("输出坐标系:", self.output_crs) + + self.show_points = QCheckBox("显示采样点") + params_layout.addRow("", self.show_points) + + self.use_diffusion = QCheckBox("启用距离扩散") + self.use_diffusion.setChecked(True) + params_layout.addRow("", self.use_diffusion) + + params_group.setLayout(params_layout) + layout.addWidget(params_group) + + # 输出目录 + self.output_dir = FileSelectWidget( + "输出分布图目录:", + "Directories;;All Files (*.*)" + ) + self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization") + self.output_dir.browse_btn.clicked.disconnect() + self.output_dir.browse_btn.clicked.connect(self.browse_output_dir) + layout.addWidget(self.output_dir) + + # 启用步骤 + self.enable_checkbox = QCheckBox("启用此步骤") + self.enable_checkbox.setChecked(True) + layout.addWidget(self.enable_checkbox) + + # 独立运行按钮 + self.run_button = QPushButton("独立运行此步骤") + self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) + self.run_button.clicked.connect(self.run_step) + layout.addWidget(self.run_button) + + layout.addStretch() + self.setLayout(layout) + + # 信号绑定与初始状态 + self.mode_single_rb.toggled.connect(self._toggle_input_mode) + self.mode_folder_rb.toggled.connect(self._toggle_input_mode) + self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV" + self._toggle_input_mode() # 根据默认值设置初始显示状态 + + def _toggle_input_mode(self): + """槽函数:根据单选框状态动态显示/隐藏对应的输入组件。""" + folder_mode = self.mode_folder_rb.isChecked() + # 单个 CSV 模式:显示单文件选择,隐藏文件夹选择 + self.prediction_csv_file.setVisible(not folder_mode) + # 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择 + self._folder_row_widget.setVisible(folder_mode) + self.recursive_csv_cb.setVisible(folder_mode) + + def _get_default_work_dir(self): + """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" + if hasattr(self, 'work_dir') and self.work_dir: 
+ return str(self.work_dir) + mw = self.window() + if mw and hasattr(mw, 'work_dir') and mw.work_dir: + return str(mw.work_dir) + return "" + + def browse_prediction_csv_dir(self): + default = self._get_default_work_dir() + if default: + default = os.path.join(default, "11_12_13_predictions") + d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default) + if d: + self.prediction_csv_dir_edit.setText(d) + + def _collect_csv_paths_from_folder(self) -> List[str]: + folder = (self.prediction_csv_dir_edit.text() or "").strip() + if not folder or not os.path.isdir(folder): + return [] + root = Path(folder) + if self.recursive_csv_cb.isChecked(): + files = sorted(root.rglob("*.csv")) + else: + files = sorted(root.glob("*.csv")) + return [str(p) for p in files if p.is_file()] + + def _step14_base_pipeline_kwargs(self) -> dict: + return { + 'boundary_shp_path': self.boundary_file.get_path(), + 'resolution': self.resolution.value(), + 'input_crs': self.input_crs.text(), + 'output_crs': self.output_crs.text(), + 'show_sample_points': self.show_points.isChecked(), + 'use_distance_diffusion': self.use_diffusion.isChecked(), + } + + def get_config(self): + pred_csv = (self.prediction_csv_file.get_path() or "").strip() + folder_mode = self.mode_folder_rb.isChecked() + pred_dir = (self.prediction_csv_dir_edit.text() or "").strip() + config = { + 'step14_batch_mode': 'folder' if folder_mode else 'single', + 'prediction_csv_dir': pred_dir if pred_dir else None, + 'recursive_csv_scan': self.recursive_csv_cb.isChecked(), + 'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None), + 'boundary_shp_path': self.boundary_file.get_path(), + 'resolution': self.resolution.value(), + 'input_crs': self.input_crs.text(), + 'output_crs': self.output_crs.text(), + 'show_sample_points': self.show_points.isChecked(), + 'use_distance_diffusion': self.use_diffusion.isChecked(), + } + out_dir = (self.output_dir.get_path() or "").strip() + if not folder_mode and 
pred_csv and out_dir: + stem = Path(pred_csv).stem + config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png") + else: + config['output_image_path'] = None + return config + + def set_config(self, config): + mode = config.get('step14_batch_mode', 'single') + if mode == 'folder': + self.mode_folder_rb.setChecked(True) + else: + self.mode_single_rb.setChecked(True) + if config.get('prediction_csv_dir'): + self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir'])) + if 'recursive_csv_scan' in config: + self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan'])) + if 'prediction_csv_path' in config and config['prediction_csv_path']: + self.prediction_csv_file.set_path(str(config['prediction_csv_path'])) + if 'boundary_shp_path' in config: + self.boundary_file.set_path(config['boundary_shp_path']) + if 'resolution' in config: + self.resolution.setValue(config['resolution']) + if 'input_crs' in config: + self.input_crs.setText(config['input_crs']) + if 'output_crs' in config: + self.output_crs.setText(config['output_crs']) + if 'show_sample_points' in config: + self.show_points.setChecked(config['show_sample_points']) + if 'use_distance_diffusion' in config: + self.use_diffusion.setChecked(config['use_distance_diffusion']) + if 'output_dir' in config and config['output_dir']: + self.output_dir.set_path(str(config['output_dir'])) + elif config.get('output_image_path'): + p = Path(str(config['output_image_path'])) + if p.parent and str(p.parent) != '.': + self.output_dir.set_path(str(p.parent)) + + def update_from_config(self, work_dir=None, pipeline=None): + """从全局配置自动填充预测结果目录 + + 优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录; + 其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。 + + Args: + work_dir: 工作目录路径 + pipeline: Pipeline 实例(未使用,保留接口兼容性) + """ + try: + import traceback + + if work_dir: + self.work_dir = work_dir + elif hasattr(self, 'work_dir') and self.work_dir: + pass + else: + self.work_dir = None + + main_window = self.window() + if not 
main_window: + return + + # 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先) + pred_dir = None + if hasattr(main_window, 'step11_prediction_panel'): + step8_widget = getattr(main_window.step11_prediction_panel, 'output_file', None) + step8_output = "" + if hasattr(step8_widget, 'get_path'): + step8_output = step8_widget.get_path() or "" + elif hasattr(step8_widget, 'text'): + step8_output = step8_widget.text() or "" + + if step8_output: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step8_output): + step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/') + # 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录) + base_pred_dir = str(Path(step8_output).parent) + ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction" + pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir + + # 2. 备选:从 Step11 界面读取非经验预测输出目录 + if not pred_dir and hasattr(main_window, 'step11_panel'): + step8_5_widget = getattr(main_window.step11_panel, 'output_file', None) + step8_5_output = "" + if hasattr(step8_5_widget, 'get_path'): + step8_5_output = step8_5_widget.get_path() or "" + elif hasattr(step8_5_widget, 'text'): + step8_5_output = step8_5_widget.text() or "" + + if step8_5_output: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step8_5_output): + step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/') + pred_dir = str(Path(step8_5_output).parent) + + # 3. 
备选:从 Step12 界面读取自定义回归预测输出目录 + if not pred_dir and hasattr(main_window, 'step12_panel'): + step8_75_widget = getattr(main_window.step12_panel, 'output_dir_widget', None) + step8_75_output = "" + if hasattr(step8_75_widget, 'get_path'): + step8_75_output = step8_75_widget.get_path() or "" + elif hasattr(step8_75_widget, 'text'): + step8_75_output = step8_75_widget.text() or "" + + if step8_75_output: + pred_dir = step8_75_output + + # 自动填入"预测CSV目录"(文件夹批量模式) + if pred_dir: + existing_dir = (self.prediction_csv_dir_edit.text() or "").strip() + if not existing_dir: + self.prediction_csv_dir_edit.setText(pred_dir) + # 切换到文件夹批量模式 + self.mode_folder_rb.setChecked(True) + + # 4. 自动填充输出目录(14_visualization) + if self.work_dir: + output_dir = os.path.join(self.work_dir, "14_visualization") + os.makedirs(output_dir, exist_ok=True) + existing_out = self.output_dir.get_path() + if not existing_out or not existing_out.strip(): + self.output_dir.set_path(output_dir) + + # 5. 自动探测原始矢量边界文件(.shp)作为专题图底图 + # 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式 + if self.work_dir: + possible_shp = None + candidates = [ + Path(self.work_dir).parent / "input-test" / "roi.shp", + Path(self.work_dir) / "roi.shp", + Path(self.work_dir).parent / "roi.shp", + ] + for candidate in candidates: + if candidate.exists() and candidate.suffix.lower() == ".shp": + possible_shp = candidate + break + + existing_boundary = (self.boundary_file.get_path() or "").strip() + if not existing_boundary and possible_shp: + self.boundary_file.set_path(str(possible_shp)) + elif not existing_boundary: + # 未找到 .shp 时清空并提示用户手动选择矢量文件 + self.boundary_file.set_path("") + print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。") + except Exception as e: + import traceback + print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") + traceback.print_exc() + + def browse_output_dir(self): + """浏览输出目录""" + default = self._get_default_work_dir() + if default: + default = os.path.join(default, "14_visualization") + dir_path = 
QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default) + if dir_path: + self.output_dir.set_path(dir_path) + + def run_step(self): + """独立运行步骤14""" + if self._batch_thread and self._batch_thread.isRunning(): + QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。") + return + + boundary_shp_path = self.boundary_file.get_path() + if not boundary_shp_path: + QMessageBox.warning(self, "输入验证失败", "请选择边界文件") + return + if not os.path.exists(boundary_shp_path): + QMessageBox.warning(self, "输入验证失败", "边界文件不存在") + return + + parent = self.parent() + while parent and not hasattr(parent, 'run_single_step'): + parent = parent.parent() + + if not parent or not hasattr(parent, 'run_single_step'): + QMessageBox.critical(self, "错误", "无法找到父级GUI对象") + return + + if self.mode_folder_rb.isChecked(): + csv_list = self._collect_csv_paths_from_folder() + if not csv_list: + QMessageBox.warning( + self, + "输入验证失败", + "所选文件夹中未找到 .csv 文件,或目录无效。\n" + "可勾选「包含子文件夹」以递归扫描。", + ) + return + if not PIPELINE_AVAILABLE: + QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。") + return + work_dir = getattr(parent, "work_dir", None) or "./work_dir" + work_dir = str(work_dir) + base_kw = self._step14_base_pipeline_kwargs() + out_dir_opt = (self.output_dir.get_path() or "").strip() or None + self.run_button.setEnabled(False) + self._batch_thread = Step14BatchThread(work_dir, csv_list, base_kw, out_dir_opt) + main_win = parent + + def _batch_log(msg, lvl): + if hasattr(main_win, "log_message"): + main_win.log_message(msg, lvl) + + self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection) + self._batch_thread.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection) + self._batch_thread.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection) + self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection) + self._batch_thread.start() + if hasattr(parent, "log_message"): + parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 
{work_dir}", "info") + return + + prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip() + if not prediction_csv_path: + QMessageBox.warning( + self, + "输入验证失败", + "请选择「预测结果 CSV」文件,或切换到「文件夹批量」。", + ) + return + if not os.path.isfile(prediction_csv_path): + QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件") + return + + config = self.get_config() + parent.run_single_step('step14', {'step14': config}) + + def _on_step14_batch_ok(self, n: int): + QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。") + parent = self.parent() + while parent and not hasattr(parent, "log_message"): + parent = parent.parent() + if parent and hasattr(parent, "log_message"): + parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info") + + def _on_step14_batch_fail(self, err: str): + QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}") + parent = self.parent() + while parent and not hasattr(parent, "log_message"): + parent = parent.parent() + if parent and hasattr(parent, "log_message"): + parent.log_message(err, "error") diff --git a/src/gui/panels/step5_5_panel.py b/src/gui/panels/step5_5_panel.py deleted file mode 100644 index 6a33037..0000000 --- a/src/gui/panels/step5_5_panel.py +++ /dev/null @@ -1,225 +0,0 @@ -import os -import sys -import pandas as pd -from pathlib import Path -from typing import Dict, List, Union - -from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QGridLayout, - QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea -) -from PyQt5.QtCore import Qt -from src.gui.components.custom_widgets import FileSelectWidget -from src.gui.styles import ModernStylesheet - -def get_resource_path(relative_path: str) -> str: - """适配开发与 PyInstaller 环境的路径获取逻辑。 - 支持两种打包模式: - 1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录 - 2. 
--onefile 模式:文件在 sys._MEIPASS 平铺目录 - """ - # 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下) - if hasattr(sys, '_MEIPASS'): - internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path) - if os.path.exists(internal_path): - return internal_path - return os.path.join(sys._MEIPASS, relative_path) - - # 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下) - exe_dir = os.path.dirname(sys.executable) - internal_path = os.path.join(exe_dir, '_internal', relative_path) - if os.path.exists(internal_path): - return internal_path - - # 开发环境下:基于当前文件 (step5_5_panel.py) 的绝对路径进行回溯 - # 当前在 src/gui/panels/,目标在 src/gui/model/ - base_dir = Path(__file__).resolve().parent.parent / "model" - target_path = base_dir / os.path.basename(relative_path) - return str(target_path) - -class Step5_5Panel(QWidget): - def __init__(self, parent=None): - super().__init__(parent) - self.index_checkboxes: Dict[str, QCheckBox] = {} - # 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中 - self.builtin_formula_path = get_resource_path("waterindex.csv") - - self.init_ui() - # 延迟一小会儿加载,确保UI框架已就绪 - self._auto_load_formulas() - - def init_ui(self): - main_layout = QVBoxLayout() - main_layout.setContentsMargins(20, 20, 20, 20) - main_layout.setSpacing(10) - - # 1. 路径展示区 (半透明只读) - path_group = QGroupBox("公式配置源 (内置)") - path_layout = QVBoxLayout() - self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)") - self.formula_csv_widget.set_path(self.builtin_formula_path) - self.formula_csv_widget.set_read_only(True) - # 视觉微调:提示用户这是内置的 - self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;") - path_layout.addWidget(self.formula_csv_widget) - path_group.setLayout(path_layout) - main_layout.addWidget(path_group) - - # 2. 
训练数据输入 - input_group = QGroupBox("输入样本数据") - input_layout = QVBoxLayout() - self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)") - input_layout.addWidget(self.training_data_widget) - input_group.setLayout(input_layout) - main_layout.addWidget(input_group) - - # 3. 公式选择区 - self.formula_group = QGroupBox("待计算水质指数勾选") - formula_outer_layout = QVBoxLayout() - - btn_layout = QHBoxLayout() - self.select_all_btn = QPushButton("全选") - self.deselect_all_btn = QPushButton("清空") - self.select_all_btn.clicked.connect(self.select_all_formulas) - self.deselect_all_btn.clicked.connect(self.deselect_all_formulas) - btn_layout.addWidget(self.select_all_btn) - btn_layout.addWidget(self.deselect_all_btn) - btn_layout.addStretch() - - self.refresh_button = QPushButton("手动重新加载公式") - self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False)) - btn_layout.addWidget(self.refresh_button) - - formula_outer_layout.addLayout(btn_layout) - - # 核心滚动区 - scroll = QScrollArea() - scroll.setWidgetResizable(True) - scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷 - self.scroll_content = QWidget() - self.formula_layout = QGridLayout(self.scroll_content) - self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐 - scroll.setWidget(self.scroll_content) - formula_outer_layout.addWidget(scroll) - - self.formula_group.setLayout(formula_outer_layout) - main_layout.addWidget(self.formula_group) - - # 4. 
输出与运行 - output_group = QGroupBox("结果输出") - output_layout = QVBoxLayout() - self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save") - output_layout.addWidget(self.output_file_widget) - output_group.setLayout(output_layout) - main_layout.addWidget(output_group) - - self.enable_checkbox = QCheckBox("启用计算流程") - self.enable_checkbox.setChecked(True) - main_layout.addWidget(self.enable_checkbox) - - self.run_button = QPushButton("立即执行计算") - self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) - self.run_button.setMinimumHeight(40) - self.run_button.clicked.connect(self.run_step) - main_layout.addWidget(self.run_button) - - self.setLayout(main_layout) - - def _auto_load_formulas(self): - """启动时自动加载逻辑""" - if os.path.exists(self.builtin_formula_path): - self.refresh_formulas(silent=True) - else: - print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}") - - def refresh_formulas(self, silent=False): - path = self.builtin_formula_path - if not os.path.exists(path): - if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}") - return - - try: - # 清理旧列表 - for i in reversed(range(self.formula_layout.count())): - widget = self.formula_layout.itemAt(i).widget() - if widget: widget.deleteLater() - self.index_checkboxes.clear() - - # 鲁棒性读取:尝试不同编码 - for encoding in ['utf-8', 'gbk', 'utf-8-sig']: - try: - df = pd.read_csv(path, encoding=encoding) - if 'Formula_Name' in df.columns: break - except: continue - - if 'Formula_Name' not in df.columns: - if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列") - return - - names = df['Formula_Name'].dropna().unique().tolist() - - row, col = 0, 0 - for name in names: - name = str(name).strip() - if not name: continue - cb = QCheckBox(name) - cb.setChecked(True) - self.index_checkboxes[name] = cb - self.formula_layout.addWidget(cb, row, col) - col += 1 - if col >= 3: - col = 0 - row += 1 - - # 强制UI更新 - self.scroll_content.adjustSize() - print(f"✅ 成功加载 
{len(self.index_checkboxes)} 个公式") - - except Exception as e: - if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}") - - def select_all_formulas(self): - for cb in self.index_checkboxes.values(): cb.setChecked(True) - - def deselect_all_formulas(self): - for cb in self.index_checkboxes.values(): cb.setChecked(False) - - def get_config(self): - selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()] - return { - 'training_csv_path': self.training_data_widget.get_path(), - 'formula_csv_file': self.builtin_formula_path, - 'formula_names': selected, - 'output_file': self.output_file_widget.get_path(), - 'enabled': self.enable_checkbox.isChecked() - } - - def set_config(self, config): - if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path']) - if 'formula_names' in config: - sel = set(config['formula_names']) - for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel) - if 'output_file' in config: self.output_file_widget.set_path(config['output_file']) - self.enable_checkbox.setChecked(config.get('enabled', True)) - - def update_from_config(self, work_dir=None, pipeline=None): - if work_dir: self.work_dir = work_dir - main = self.window() - if hasattr(main, 'step5_panel'): - p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐 - if p5: - if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/') - self.training_data_widget.set_path(p5) - - if self.work_dir: - out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/') - self.output_file_widget.set_path(out) - - def run_step(self): - config = self.get_config() - if not config['training_csv_path']: - QMessageBox.warning(self, "提示", "请先选择输入数据") - return - parent = self.parent() - while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent() - if parent: parent.run_single_step('step5_5', {'step5_5': config}) \ No newline at end of file 
diff --git a/src/gui/panels/step6_75_panel.py b/src/gui/panels/step6_75_panel.py deleted file mode 100644 index 130387c..0000000 --- a/src/gui/panels/step6_75_panel.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Step6_75 面板 - 自定义回归分析 -""" - -import os -from typing import Dict - -import pandas as pd -from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout, - QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton, - QScrollArea, QMessageBox, -) - -from src.gui.components.custom_widgets import FileSelectWidget -from src.gui.styles import ModernStylesheet - - -class Step6_75Panel(QWidget): - """步骤6.75:自定义回归分析""" - def __init__(self, parent=None): - super().__init__(parent) - self.x_column_checkboxes: Dict[str, QCheckBox] = {} - self.y_column_checkboxes: Dict[str, QCheckBox] = {} - self.method_checkboxes: Dict[str, QCheckBox] = {} - self.csv_columns = [] - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法") - hint.setStyleSheet("color: #666; font-size: 11px;") - layout.addWidget(hint) - - # CSV文件选择 - csv_group = QGroupBox("数据文件") - csv_layout = QVBoxLayout() - - self.csv_file = FileSelectWidget( - "输入CSV文件:", - "CSV Files (*.csv);;All Files (*.*)" - ) - self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed) - csv_layout.addWidget(self.csv_file) - - self.refresh_btn = QPushButton("刷新列信息") - self.refresh_btn.clicked.connect(self.refresh_csv_columns) - csv_layout.addWidget(self.refresh_btn) - - csv_group.setLayout(csv_layout) - layout.addWidget(csv_group) - - # 自变量选择 - x_group = QGroupBox("自变量列选择 (可多选)") - x_layout = QVBoxLayout() - - x_scroll = QScrollArea() - x_scroll.setWidgetResizable(True) - x_scroll.setMinimumHeight(250) - x_scroll.setMaximumHeight(350) - - x_widget = QWidget() - self.x_columns_layout = QGridLayout() - x_widget.setLayout(self.x_columns_layout) - - x_scroll.setWidget(x_widget) - x_layout.addWidget(x_scroll) - 
- x_btn_layout = QHBoxLayout() - self.x_select_all = QPushButton("全选") - self.x_deselect_all = QPushButton("全不选") - self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True)) - self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False)) - x_btn_layout.addWidget(self.x_select_all) - x_btn_layout.addWidget(self.x_deselect_all) - x_btn_layout.addStretch() - x_layout.addLayout(x_btn_layout) - - x_group.setLayout(x_layout) - layout.addWidget(x_group) - - # 因变量选择 - y_group = QGroupBox("因变量列选择 (可多选)") - y_layout = QVBoxLayout() - - y_scroll = QScrollArea() - y_scroll.setWidgetResizable(True) - y_scroll.setMinimumHeight(200) - y_scroll.setMaximumHeight(300) - - y_widget = QWidget() - self.y_columns_layout = QGridLayout() - y_widget.setLayout(self.y_columns_layout) - - y_scroll.setWidget(y_widget) - y_layout.addWidget(y_scroll) - - y_btn_layout = QHBoxLayout() - self.y_select_all = QPushButton("全选") - self.y_deselect_all = QPushButton("全不选") - self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True)) - self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False)) - y_btn_layout.addWidget(self.y_select_all) - y_btn_layout.addWidget(self.y_deselect_all) - y_btn_layout.addStretch() - y_layout.addLayout(y_btn_layout) - - y_group.setLayout(y_layout) - layout.addWidget(y_group) - - # 回归方法选择 - method_group = QGroupBox("回归方法选择 (可多选)") - method_layout = QVBoxLayout() - - method_grid = QGridLayout() - regression_methods = [ - 'linear', 'exponential', 'power', 'logarithmic', - 'polynomial', 'hyperbolic', 'sigmoidal' - ] - - for i, method in enumerate(regression_methods): - checkbox = QCheckBox(method) - if method in ['linear', 'exponential', 'power', 'logarithmic']: - checkbox.setChecked(True) - self.method_checkboxes[method] = checkbox - method_grid.addWidget(checkbox, i // 3, i % 3) - - method_layout.addLayout(method_grid) - - 
method_btn_layout = QHBoxLayout() - self.method_select_all = QPushButton("全选") - self.method_deselect_all = QPushButton("全不选") - self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True)) - self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False)) - method_btn_layout.addWidget(self.method_select_all) - method_btn_layout.addWidget(self.method_deselect_all) - method_btn_layout.addStretch() - method_layout.addLayout(method_btn_layout) - - method_group.setLayout(method_layout) - layout.addWidget(method_group) - - # 输出目录 - output_group = QGroupBox("输出设置") - output_layout = QFormLayout() - - self.output_dir = QLineEdit() - self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充 - output_layout.addRow("输出目录名:", self.output_dir) - - output_group.setLayout(output_layout) - layout.addWidget(output_group) - - # 启用步骤 - self.enable_checkbox = QCheckBox("启用此步骤") - self.enable_checkbox.setChecked(True) - layout.addWidget(self.enable_checkbox) - - # 独立运行按钮 - self.run_button = QPushButton("独立运行此步骤") - self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) - self.run_button.clicked.connect(self.run_step) - layout.addWidget(self.run_button) - - layout.addStretch() - self.setLayout(layout) - - def toggle_checkboxes(self, checkboxes_dict, checked): - """统一设置checkbox状态""" - for checkbox in checkboxes_dict.values(): - checkbox.setChecked(checked) - - def on_csv_file_changed(self): - """CSV文件改变时自动刷新列信息""" - self.refresh_csv_columns() - - def refresh_csv_columns(self): - """刷新CSV文件的列信息""" - csv_path = self.csv_file.get_path() - if not csv_path or not os.path.exists(csv_path): - self.csv_columns = [] - self.update_column_widgets() - return - - try: - df = pd.read_csv(csv_path, nrows=0) - self.csv_columns = list(df.columns) - self.update_column_widgets() - except Exception as e: - self.csv_columns = [] - self.update_column_widgets() - print(f"读取CSV列信息失败: {e}") - - 
def update_column_widgets(self): - """更新列选择组件""" - for checkbox in self.x_column_checkboxes.values(): - checkbox.setParent(None) - self.x_column_checkboxes.clear() - - for checkbox in self.y_column_checkboxes.values(): - checkbox.setParent(None) - self.y_column_checkboxes.clear() - - if not self.csv_columns: - return - - for i, col in enumerate(self.csv_columns): - checkbox = QCheckBox(col) - if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']): - checkbox.setChecked(True) - self.x_column_checkboxes[col] = checkbox - self.x_columns_layout.addWidget(checkbox, i // 3, i % 3) - - for i, col in enumerate(self.csv_columns): - checkbox = QCheckBox(col) - if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']): - checkbox.setChecked(True) - self.y_column_checkboxes[col] = checkbox - self.y_columns_layout.addWidget(checkbox, i // 2, i % 2) - - self.x_columns_layout.update() - self.y_columns_layout.update() - - def get_config(self): - selected_x_columns = [ - col for col, checkbox in self.x_column_checkboxes.items() - if checkbox.isChecked() - ] - selected_y_columns = [ - col for col, checkbox in self.y_column_checkboxes.items() - if checkbox.isChecked() - ] - selected_methods = [ - method for method, checkbox in self.method_checkboxes.items() - if checkbox.isChecked() - ] - if not selected_methods: - selected_methods = 'all' - - return { - 'csv_path': self.csv_file.get_path() or None, - 'x_columns': selected_x_columns, - 'y_columns': selected_y_columns, - 'methods': selected_methods, - 'output_dir': self.output_dir.text().strip() or None, - 'enabled': self.enable_checkbox.isChecked() - } - - def set_config(self, config): - if 'csv_path' in config: - self.csv_file.set_path(config['csv_path']) - self.refresh_csv_columns() - - if 'x_columns' in config: - selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set() - for col, checkbox in 
self.x_column_checkboxes.items(): - checkbox.setChecked(col in selected_x) - - if 'y_columns' in config: - selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set() - for col, checkbox in self.y_column_checkboxes.items(): - checkbox.setChecked(col in selected_y) - - if 'methods' in config: - methods = config['methods'] - if isinstance(methods, list): - selected_methods = set(methods) - elif methods == 'all': - selected_methods = set(self.method_checkboxes.keys()) - else: - selected_methods = set() - for method, checkbox in self.method_checkboxes.items(): - checkbox.setChecked(method in selected_methods) - - if 'output_dir' in config: - self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling") - if 'enabled' in config: - self.enable_checkbox.setChecked(config['enabled']) - - def update_from_config(self, work_dir=None, pipeline=None): - """从全局配置自动填充训练数据和输出路径 - - Args: - work_dir: 工作目录路径 - pipeline: Pipeline 实例(未使用,保留接口兼容性) - """ - try: - import traceback - - if work_dir: - self.work_dir = work_dir - elif hasattr(self, 'work_dir') and self.work_dir: - pass - else: - self.work_dir = None - - # 1. 尝试从 Step5 界面读取训练光谱 CSV 路径 - main_window = self.window() - if main_window and hasattr(main_window, 'step5_panel'): - step5_widget = getattr(main_window.step5_panel, 'output_file', None) - step5_output_path = "" - if hasattr(step5_widget, 'get_path'): - step5_output_path = step5_widget.get_path() or "" - elif hasattr(step5_widget, 'text'): - step5_output_path = step5_widget.text() or "" - - if step5_output_path: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step5_output_path): - step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/') - existing = self.csv_file.get_path() - if not existing or not existing.strip(): - self.csv_file.set_path(step5_output_path) - - # 2. 
自动填充输出目录(9_Custom_Regression_Modeling) - if self.work_dir: - output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling") - os.makedirs(output_dir, exist_ok=True) - existing_out = self.output_dir.text().strip() - if not existing_out: - self.output_dir.setText(output_dir) - except Exception as e: - import traceback - print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") - traceback.print_exc() - - def run_step(self): - """独立运行步骤6.75""" - csv_path = self.csv_file.get_path() - - if not csv_path: - QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件") - return - if not os.path.exists(csv_path): - QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在") - return - - selected_x_columns = [ - col for col, checkbox in self.x_column_checkboxes.items() - if checkbox.isChecked() - ] - if not selected_x_columns: - QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列") - return - - selected_y_columns = [ - col for col, checkbox in self.y_column_checkboxes.items() - if checkbox.isChecked() - ] - if not selected_y_columns: - QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列") - return - - selected_methods = [ - method for method, checkbox in self.method_checkboxes.items() - if checkbox.isChecked() - ] - if not selected_methods: - QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法") - return - - config = self.get_config() - - parent = self.parent() - while parent and not hasattr(parent, 'run_single_step'): - parent = parent.parent() - - if parent and hasattr(parent, 'run_single_step'): - parent.run_single_step('step6_75', {'step6_75': config}) - else: - QMessageBox.critical(self, "错误", "无法找到父级GUI对象") diff --git a/src/gui/panels/step6_panel.py b/src/gui/panels/step6_panel.py deleted file mode 100644 index c35121f..0000000 --- a/src/gui/panels/step6_panel.py +++ /dev/null @@ -1,415 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Step6 面板 - 机器学习建模 -""" - -import os - -from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout, - QHBoxLayout, 
QLabel, QLineEdit, QSpinBox, QCheckBox, - QPushButton, QFileDialog, QMessageBox, -) -from PyQt5.QtCore import Qt - -from src.gui.components.custom_widgets import FileSelectWidget -from src.gui.styles import ModernStylesheet - - -# ============================================================ -# 中文映射表(内部键名 -> 显示文本) -# ============================================================ - -# 预处理方法:内部键 -> 显示文本 -PREPROC_CHINESE = { - 'None': '无 (None)', - 'MMS': '最小-最大归一化 (MMS)', - 'SS': '标度化 (SS)', - 'SNV': '标准正态变换 (SNV)', - 'MA': '移动平均 (MA)', - 'SG': 'Savitzky-Golay (SG)', - 'MSC': '多元散射校正 (MSC)', - 'D1': '一阶导数 (D1)', - 'D2': '二阶导数 (D2)', - 'DT': '去趋势 (DT)', - 'CT': '中心化 (CT)', -} - -# 模型类型:内部键 -> 显示文本 -MODEL_CHINESE = { - # 线性模型 - 'LinearRegression': '多元线性回归 (MLR)', - 'Ridge': '岭回归 (Ridge)', - 'Lasso': '套索回归 (Lasso)', - 'ElasticNet': '弹性网络 (ElasticNet)', - 'PLS': '偏最小二乘 (PLSR)', - # 树模型 - 'DecisionTree': '决策树 (CART)', - 'RF': '随机森林 (RF)', - 'ExtraTrees': '极端随机树 (ET)', - 'XGBoost': '极值梯度提升 (XGBoost)', - 'LightGBM': '轻量梯度提升 (LightGBM)', - 'CatBoost': '类别梯度提升 (CatBoost)', - # 集成学习 - 'GradientBoosting': '梯度提升树 (GBDT)', - 'AdaBoost': '自适应提升 (AdaBoost)', - # 其他模型 - 'SVR': '支持向量回归 (SVR)', - 'KNN': 'K近邻回归 (KNN)', - 'MLP': '多层感知机 (BP神经网络)', -} - -# 数据划分方法:内部键 -> 显示文本 -SPLIT_CHINESE = { - 'spxy': 'SPXY 算法 (考量X-Y空间)', - 'ks': 'KS 算法 (考量X空间)', - 'random': '随机划分 (Random)', -} - - -class Step6Panel(QWidget): - """步骤6:机器学习建模""" - def __init__(self, parent=None): - super().__init__(parent) - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - # 标题 - - - # 训练数据文件(用于独立运行) - self.training_csv_file = FileSelectWidget( - "训练数据:", - "CSV Files (*.csv);;All Files (*.*)" - ) - layout.addWidget(self.training_csv_file) - - # 机器学习模型页面 - self.ml_page = QWidget() - self.create_ml_page() - layout.addWidget(self.ml_page) - - # 输出文件路径 - self.output_path = FileSelectWidget( - "输出文件:", - "CSV Files (*.csv);;All Files (*.*)", - mode="save" - ) - 
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...") - self.output_path.browse_btn.clicked.disconnect() - self.output_path.browse_btn.clicked.connect(self.browse_output_path) - layout.addWidget(self.output_path) - - # 启用步骤 - self.enable_checkbox = QCheckBox("启用此步骤") - self.enable_checkbox.setChecked(False) - layout.addWidget(self.enable_checkbox) - - # 独立运行按钮 - self.run_btn = QPushButton("独立运行此步骤") - self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) - self.run_btn.clicked.connect(self.run_step) - layout.addWidget(self.run_btn) - - layout.addStretch() - self.setLayout(layout) - - def create_ml_page(self): - """创建机器学习模型页面""" - layout = QVBoxLayout() - - # 参数设置 - params_group = QGroupBox("训练参数") - params_layout = QFormLayout() - - self.feature_start = QLineEdit() - self.feature_start.setText("374.285004") - params_layout.addRow("特征起始列:", self.feature_start) - - self.cv_folds = QSpinBox() - self.cv_folds.setRange(2, 10) - self.cv_folds.setValue(3) - params_layout.addRow("交叉验证折数:", self.cv_folds) - - params_group.setLayout(params_layout) - layout.addWidget(params_group) - - # 预处理方法 - 多选 - preproc_group = QGroupBox("预处理方法 (可多选)") - preproc_layout = QVBoxLayout() - - preproc_grid = QGridLayout() - self.preproc_checkboxes = {} - preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'] - - for i, method in enumerate(preproc_methods): - checkbox = QCheckBox(PREPROC_CHINESE.get(method, method)) - checkbox.setChecked(False) - self.preproc_checkboxes[method] = checkbox - preproc_grid.addWidget(checkbox, i // 4, i % 4) - - button_layout = QHBoxLayout() - select_all_btn = QPushButton("全选") - deselect_all_btn = QPushButton("全不选") - select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True)) - deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False)) - button_layout.addWidget(select_all_btn) - 
button_layout.addWidget(deselect_all_btn) - button_layout.addStretch() - - preproc_layout.addLayout(preproc_grid) - preproc_layout.addLayout(button_layout) - preproc_group.setLayout(preproc_layout) - layout.addWidget(preproc_group) - - # 模型选择 - 多选 - model_group = QGroupBox("模型类型 (可多选)") - model_layout = QVBoxLayout() - - model_grid = QGridLayout() - self.model_checkboxes = {} - - model_groups = [ - ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']), - ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']), - ("【集成学习】", ['GradientBoosting', 'AdaBoost']), - ("【其他模型】", ['SVR', 'KNN', 'MLP']) - ] - - row = 0 - for group_name, models in model_groups: - group_label = QLabel(f"{group_name}") - group_label.setStyleSheet( - f"background-color: {ModernStylesheet.COLORS['hover']}; " - f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; " - f"border-radius: 3px;" - ) - model_grid.addWidget(group_label, row, 0, 1, 4) - row += 1 - - for i, model in enumerate(models): - checkbox = QCheckBox(MODEL_CHINESE.get(model, model)) - checkbox.setChecked(False) - self.model_checkboxes[model] = checkbox - model_grid.addWidget(checkbox, row, i % 4) - if (i + 1) % 4 == 0: - row += 1 - - row += 1 - - model_button_layout = QHBoxLayout() - model_select_all = QPushButton("全选") - model_deselect_all = QPushButton("全不选") - model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True)) - model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False)) - model_button_layout.addWidget(model_select_all) - model_button_layout.addWidget(model_deselect_all) - model_button_layout.addStretch() - - model_layout.addLayout(model_grid) - model_layout.addLayout(model_button_layout) - model_group.setLayout(model_layout) - layout.addWidget(model_group) - - # 数据划分方法 - 多选 - split_group = QGroupBox("数据划分方法 (可多选)") - split_layout = QVBoxLayout() - - split_grid = QGridLayout() - 
self.split_checkboxes = {} - split_methods = ['spxy', 'ks', 'random'] - - for i, method in enumerate(split_methods): - checkbox = QCheckBox(SPLIT_CHINESE.get(method, method)) - checkbox.setChecked(False) - self.split_checkboxes[method] = checkbox - split_grid.addWidget(checkbox, 0, i) - - split_button_layout = QHBoxLayout() - split_select_all = QPushButton("全选") - split_deselect_all = QPushButton("全不选") - split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True)) - split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False)) - split_button_layout.addWidget(split_select_all) - split_button_layout.addWidget(split_deselect_all) - split_button_layout.addStretch() - - split_layout.addLayout(split_grid) - split_layout.addLayout(split_button_layout) - split_group.setLayout(split_layout) - layout.addWidget(split_group) - - self.ml_page.setLayout(layout) - - def _toggle_checkboxes(self, checkboxes_dict, checked): - """统一设置checkbox状态""" - for checkbox in checkboxes_dict.values(): - checkbox.setChecked(checked) - - def _get_default_work_dir(self): - """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" - if hasattr(self, 'work_dir') and self.work_dir: - return str(self.work_dir) - mw = self.window() - if mw and hasattr(mw, 'work_dir') and mw.work_dir: - return str(mw.work_dir) - return "" - - def browse_output_path(self): - """浏览输出文件路径(保存对话框)""" - current = self.output_path.get_path().strip() - if current: - initial_dir = os.path.dirname(current) - initial_file = os.path.basename(current) - else: - initial_dir = "" - initial_file = "" - - if not initial_dir or not os.path.isdir(initial_dir): - # 默认定位到 indices 目录 - work_dir = self._get_default_work_dir() - initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else "" - if initial_dir and not os.path.isdir(initial_dir): - os.makedirs(initial_dir, exist_ok=True) - - file_path, _ = QFileDialog.getSaveFileName( - self, "保存输出文件", 
os.path.join(initial_dir, initial_file) if initial_file else initial_dir, - "CSV Files (*.csv);;All Files (*.*)" - ) - if file_path: - self.output_path.set_path(file_path) - - def get_config(self): - """获取配置""" - preprocessing_methods = [ - method for method, checkbox in self.preproc_checkboxes.items() - if checkbox.isChecked() - ] - model_names = [ - model for model, checkbox in self.model_checkboxes.items() - if checkbox.isChecked() - ] - split_methods = [ - method for method, checkbox in self.split_checkboxes.items() - if checkbox.isChecked() - ] - - config = { - 'feature_start_column': self.feature_start.text(), - 'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'], - 'model_names': model_names if model_names else ['SVR'], - 'split_methods': split_methods if split_methods else ['random'], - 'cv_folds': self.cv_folds.value() - } - training_csv_path = self.training_csv_file.get_path() - if training_csv_path: - config['training_csv_path'] = training_csv_path - output_path = self.output_path.get_path() - if output_path: - config['output_path'] = output_path - return config - - def set_config(self, config): - """设置配置""" - if 'feature_start_column' in config: - self.feature_start.setText(str(config['feature_start_column'])) - if 'cv_folds' in config: - self.cv_folds.setValue(config['cv_folds']) - if 'preprocessing_methods' in config: - methods = config['preprocessing_methods'] - for method, checkbox in self.preproc_checkboxes.items(): - checkbox.setChecked(method in methods) - if 'model_names' in config: - models = config['model_names'] - for model, checkbox in self.model_checkboxes.items(): - checkbox.setChecked(model in models) - if 'split_methods' in config: - methods = config['split_methods'] - for method, checkbox in self.split_checkboxes.items(): - checkbox.setChecked(method in methods) - if 'training_csv_path' in config: - self.training_csv_file.set_path(config['training_csv_path']) - if 'output_path' in config: - 
self.output_path.set_path(config['output_path']) - - def update_from_config(self, work_dir=None, pipeline=None): - """从全局配置自动填充训练数据和输出路径 - - Args: - work_dir: 工作目录路径 - pipeline: Pipeline 实例(未使用,保留接口兼容性) - """ - if work_dir: - self.work_dir = work_dir - elif hasattr(self, 'work_dir') and self.work_dir: - pass - else: - self.work_dir = None - - # 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径 - main_window = self.window() - if hasattr(main_window, 'step5_panel'): - # 优先直接从 Step5 的输出 widget 读取 - step5_output = main_window.step5_panel.output_file.get_path() - if step5_output: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step5_output): - step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/') - self.training_csv_file.set_path(step5_output) - elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'): - # 回退:从 Step5 的 config 字典中查找可能的键名 - step5_cfg = main_window.step5_panel.get_config() - step5_csv = ( - step5_cfg.get('training_csv_path') - or step5_cfg.get('output_file') - or step5_cfg.get('csv_path') - or step5_cfg.get('output_csv') - ) - if step5_csv: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step5_csv): - step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/') - self.training_csv_file.set_path(step5_csv) - - # 2. 
自动填充输出文件路径(基于工作目录和输入文件名) - # 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv - # 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv - if self.work_dir: - indices_dir = os.path.join(self.work_dir, "6_water_quality_indices") - os.makedirs(indices_dir, exist_ok=True) - training_csv = self.training_csv_file.get_path() - if training_csv: - basename = os.path.splitext(os.path.basename(training_csv))[0] - output_file = f"{basename}_indices.csv" - else: - output_file = "water_quality_indices.csv" - output_path = os.path.join(indices_dir, output_file).replace('\\', '/') - self.output_path.set_path(output_path) - else: - self.output_path.set_path("") - - def run_step(self): - """独立运行步骤6""" - training_csv_path = self.training_csv_file.get_path() - if not training_csv_path: - QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!") - return - - main_window = self.window() - if hasattr(main_window, 'run_single_step'): - config = {'step6': self.get_config()} - main_window.run_single_step('step6', config) - - def get_training_params(self): - """获取模型训练参数""" - return { - 'pipeline_type': 'machine_learning', - 'feature_start': float(self.feature_start.text()), - 'cv_folds': self.cv_folds.value(), - 'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()], - 'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()], - 'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()] - } diff --git a/src/gui/panels/step7_panel.py b/src/gui/panels/step7_panel.py index 7c9ed67..e049fce 100644 --- a/src/gui/panels/step7_panel.py +++ b/src/gui/panels/step7_panel.py @@ -1,23 +1,75 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step7 面板 - 采样点生成 +Step7 面板 - 机器学习建模 """ import os from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QFormLayout, - QPushButton, QCheckBox, QSpinBox, QMessageBox, + 
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout, + QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox, + QPushButton, QFileDialog, QMessageBox, ) +from PyQt5.QtCore import Qt from src.gui.components.custom_widgets import FileSelectWidget -from src.gui.dialogs import SamplingViewerDialog from src.gui.styles import ModernStylesheet +# ============================================================ +# 中文映射表(内部键名 -> 显示文本) +# ============================================================ + +# 预处理方法:内部键 -> 显示文本 +PREPROC_CHINESE = { + 'None': '无 (None)', + 'MMS': '最小-最大归一化 (MMS)', + 'SS': '标度化 (SS)', + 'SNV': '标准正态变换 (SNV)', + 'MA': '移动平均 (MA)', + 'SG': 'Savitzky-Golay (SG)', + 'MSC': '多元散射校正 (MSC)', + 'D1': '一阶导数 (D1)', + 'D2': '二阶导数 (D2)', + 'DT': '去趋势 (DT)', + 'CT': '中心化 (CT)', +} + +# 模型类型:内部键 -> 显示文本 +MODEL_CHINESE = { + # 线性模型 + 'LinearRegression': '多元线性回归 (MLR)', + 'Ridge': '岭回归 (Ridge)', + 'Lasso': '套索回归 (Lasso)', + 'ElasticNet': '弹性网络 (ElasticNet)', + 'PLS': '偏最小二乘 (PLSR)', + # 树模型 + 'DecisionTree': '决策树 (CART)', + 'RF': '随机森林 (RF)', + 'ExtraTrees': '极端随机树 (ET)', + 'XGBoost': '极值梯度提升 (XGBoost)', + 'LightGBM': '轻量梯度提升 (LightGBM)', + 'CatBoost': '类别梯度提升 (CatBoost)', + # 集成学习 + 'GradientBoosting': '梯度提升树 (GBDT)', + 'AdaBoost': '自适应提升 (AdaBoost)', + # 其他模型 + 'SVR': '支持向量回归 (SVR)', + 'KNN': 'K近邻回归 (KNN)', + 'MLP': '多层感知机 (BP神经网络)', +} + +# 数据划分方法:内部键 -> 显示文本 +SPLIT_CHINESE = { + 'spxy': 'SPXY 算法 (考量X-Y空间)', + 'ks': 'KS 算法 (考量X空间)', + 'random': '随机划分 (Random)', +} + + class Step7Panel(QWidget): - """步骤7:采样点生成""" + """步骤7:机器学习建模""" def __init__(self, parent=None): super().__init__(parent) self.init_ui() @@ -25,58 +77,35 @@ class Step7Panel(QWidget): def init_ui(self): layout = QVBoxLayout() - # 去耀斑影像文件(用于独立运行) - self.deglint_img_file = FileSelectWidget( - "去耀斑影像:", - "Image Files (*.bsq *.dat *.tif);;All Files (*.*)" - ) - layout.addWidget(self.deglint_img_file) + # 标题 - # 水域掩膜文件(可选,用于独立运行) - self.water_mask_file = FileSelectWidget( - "水域掩膜:", - "Mask Files (*.dat 
*.tif);;All Files (*.*)" - ) - self.water_mask_file.label.setText("水域掩膜:") - layout.addWidget(self.water_mask_file) - # 参数设置 - params_group = QGroupBox("采样参数") - params_layout = QFormLayout() - - self.interval = QSpinBox() - self.interval.setRange(10, 500) - self.interval.setValue(50) - params_layout.addRow("采样点间隔(像素):", self.interval) - - self.sample_radius = QSpinBox() - self.sample_radius.setRange(1, 50) - self.sample_radius.setValue(5) - params_layout.addRow("采样半径(像素):", self.sample_radius) - - self.chunk_size = QSpinBox() - self.chunk_size.setRange(100, 10000) - self.chunk_size.setValue(1000) - params_layout.addRow("处理块大小:", self.chunk_size) - - self.use_adaptive_sampling = QCheckBox("启用自适应采样") - self.use_adaptive_sampling.setChecked(True) - params_layout.addRow("采样模式:", self.use_adaptive_sampling) - - params_group.setLayout(params_layout) - layout.addWidget(params_group) - - # 输出文件路径 - self.output_file = FileSelectWidget( - "输出采样点:", + # 训练数据文件(用于独立运行) + self.training_csv_file = FileSelectWidget( + "训练数据:", "CSV Files (*.csv);;All Files (*.*)" ) - self.output_file.line_edit.setPlaceholderText("sampling_points.csv") - layout.addWidget(self.output_file) + layout.addWidget(self.training_csv_file) + + # 机器学习模型页面 + self.ml_page = QWidget() + self.create_ml_page() + layout.addWidget(self.ml_page) + + # 输出文件路径 + self.output_path = FileSelectWidget( + "输出文件:", + "CSV Files (*.csv);;All Files (*.*)", + mode="save" + ) + self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...") + self.output_path.browse_btn.clicked.disconnect() + self.output_path.browse_btn.clicked.connect(self.browse_output_path) + layout.addWidget(self.output_path) # 启用步骤 self.enable_checkbox = QCheckBox("启用此步骤") - self.enable_checkbox.setChecked(True) + self.enable_checkbox.setChecked(False) layout.addWidget(self.enable_checkbox) # 独立运行按钮 @@ -85,57 +114,233 @@ class Step7Panel(QWidget): self.run_btn.clicked.connect(self.run_step) layout.addWidget(self.run_btn) - # 交互式预览按钮 - 
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱") - self.preview_btn.setEnabled(False) - self.preview_btn.clicked.connect(self._open_sampling_viewer) - layout.addWidget(self.preview_btn) - layout.addStretch() self.setLayout(layout) - # 监听输出路径变化,实时更新预览按钮状态 - self.output_file.line_edit.textChanged.connect(self._on_output_changed) + def create_ml_page(self): + """创建机器学习模型页面""" + layout = QVBoxLayout() + + # 参数设置 + params_group = QGroupBox("训练参数") + params_layout = QFormLayout() + + self.feature_start = QLineEdit() + self.feature_start.setText("374.285004") + params_layout.addRow("特征起始列:", self.feature_start) + + self.cv_folds = QSpinBox() + self.cv_folds.setRange(2, 10) + self.cv_folds.setValue(3) + params_layout.addRow("交叉验证折数:", self.cv_folds) + + params_group.setLayout(params_layout) + layout.addWidget(params_group) + + # 预处理方法 - 多选 + preproc_group = QGroupBox("预处理方法 (可多选)") + preproc_layout = QVBoxLayout() + + preproc_grid = QGridLayout() + self.preproc_checkboxes = {} + preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'] + + for i, method in enumerate(preproc_methods): + checkbox = QCheckBox(PREPROC_CHINESE.get(method, method)) + checkbox.setChecked(False) + self.preproc_checkboxes[method] = checkbox + preproc_grid.addWidget(checkbox, i // 4, i % 4) + + button_layout = QHBoxLayout() + select_all_btn = QPushButton("全选") + deselect_all_btn = QPushButton("全不选") + select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True)) + deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False)) + button_layout.addWidget(select_all_btn) + button_layout.addWidget(deselect_all_btn) + button_layout.addStretch() + + preproc_layout.addLayout(preproc_grid) + preproc_layout.addLayout(button_layout) + preproc_group.setLayout(preproc_layout) + layout.addWidget(preproc_group) + + # 模型选择 - 多选 + model_group = QGroupBox("模型类型 (可多选)") + model_layout = QVBoxLayout() + + model_grid = 
QGridLayout() + self.model_checkboxes = {} + + model_groups = [ + ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']), + ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']), + ("【集成学习】", ['GradientBoosting', 'AdaBoost']), + ("【其他模型】", ['SVR', 'KNN', 'MLP']) + ] + + row = 0 + for group_name, models in model_groups: + group_label = QLabel(f"{group_name}") + group_label.setStyleSheet( + f"background-color: {ModernStylesheet.COLORS['hover']}; " + f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; " + f"border-radius: 3px;" + ) + model_grid.addWidget(group_label, row, 0, 1, 4) + row += 1 + + for i, model in enumerate(models): + checkbox = QCheckBox(MODEL_CHINESE.get(model, model)) + checkbox.setChecked(False) + self.model_checkboxes[model] = checkbox + model_grid.addWidget(checkbox, row, i % 4) + if (i + 1) % 4 == 0: + row += 1 + + row += 1 + + model_button_layout = QHBoxLayout() + model_select_all = QPushButton("全选") + model_deselect_all = QPushButton("全不选") + model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True)) + model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False)) + model_button_layout.addWidget(model_select_all) + model_button_layout.addWidget(model_deselect_all) + model_button_layout.addStretch() + + model_layout.addLayout(model_grid) + model_layout.addLayout(model_button_layout) + model_group.setLayout(model_layout) + layout.addWidget(model_group) + + # 数据划分方法 - 多选 + split_group = QGroupBox("数据划分方法 (可多选)") + split_layout = QVBoxLayout() + + split_grid = QGridLayout() + self.split_checkboxes = {} + split_methods = ['spxy', 'ks', 'random'] + + for i, method in enumerate(split_methods): + checkbox = QCheckBox(SPLIT_CHINESE.get(method, method)) + checkbox.setChecked(False) + self.split_checkboxes[method] = checkbox + split_grid.addWidget(checkbox, 0, i) + + split_button_layout = QHBoxLayout() + 
split_select_all = QPushButton("全选") + split_deselect_all = QPushButton("全不选") + split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True)) + split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False)) + split_button_layout.addWidget(split_select_all) + split_button_layout.addWidget(split_deselect_all) + split_button_layout.addStretch() + + split_layout.addLayout(split_grid) + split_layout.addLayout(split_button_layout) + split_group.setLayout(split_layout) + layout.addWidget(split_group) + + self.ml_page.setLayout(layout) + + def _toggle_checkboxes(self, checkboxes_dict, checked): + """统一设置checkbox状态""" + for checkbox in checkboxes_dict.values(): + checkbox.setChecked(checked) + + def _get_default_work_dir(self): + """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" + if hasattr(self, 'work_dir') and self.work_dir: + return str(self.work_dir) + mw = self.window() + if mw and hasattr(mw, 'work_dir') and mw.work_dir: + return str(mw.work_dir) + return "" + + def browse_output_path(self): + """浏览输出文件路径(保存对话框)""" + current = self.output_path.get_path().strip() + if current: + initial_dir = os.path.dirname(current) + initial_file = os.path.basename(current) + else: + initial_dir = "" + initial_file = "" + + if not initial_dir or not os.path.isdir(initial_dir): + # 默认定位到 indices 目录 + work_dir = self._get_default_work_dir() + initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else "" + if initial_dir and not os.path.isdir(initial_dir): + os.makedirs(initial_dir, exist_ok=True) + + file_path, _ = QFileDialog.getSaveFileName( + self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir, + "CSV Files (*.csv);;All Files (*.*)" + ) + if file_path: + self.output_path.set_path(file_path) def get_config(self): """获取配置""" + preprocessing_methods = [ + method for method, checkbox in self.preproc_checkboxes.items() + if checkbox.isChecked() + ] + model_names = [ 
+ model for model, checkbox in self.model_checkboxes.items() + if checkbox.isChecked() + ] + split_methods = [ + method for method, checkbox in self.split_checkboxes.items() + if checkbox.isChecked() + ] + config = { - 'interval': self.interval.value(), - 'sample_radius': self.sample_radius.value(), - 'chunk_size': self.chunk_size.value(), - 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(), + 'feature_start_column': self.feature_start.text(), + 'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'], + 'model_names': model_names if model_names else ['SVR'], + 'split_methods': split_methods if split_methods else ['random'], + 'cv_folds': self.cv_folds.value() } - deglint_img_path = self.deglint_img_file.get_path() - if deglint_img_path: - config['deglint_img_path'] = deglint_img_path - water_mask_path = self.water_mask_file.get_path() - if water_mask_path: - config['water_mask_path'] = water_mask_path + training_csv_path = self.training_csv_file.get_path() + if training_csv_path: + config['training_csv_path'] = training_csv_path + output_path = self.output_path.get_path() + if output_path: + config['output_path'] = output_path return config def set_config(self, config): """设置配置""" - if 'interval' in config: - self.interval.setValue(config['interval']) - if 'sample_radius' in config: - self.sample_radius.setValue(config['sample_radius']) - if 'chunk_size' in config: - self.chunk_size.setValue(config['chunk_size']) - if 'use_adaptive_sampling' in config: - self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling']) - if 'deglint_img_path' in config: - self.deglint_img_file.set_path(config['deglint_img_path']) - if 'water_mask_path' in config: - self.water_mask_file.set_path(config['water_mask_path']) - if 'glint_mask_path' in config: - self.glint_mask_file.set_path(config['glint_mask_path']) + if 'feature_start_column' in config: + self.feature_start.setText(str(config['feature_start_column'])) + if 'cv_folds' in 
config: + self.cv_folds.setValue(config['cv_folds']) + if 'preprocessing_methods' in config: + methods = config['preprocessing_methods'] + for method, checkbox in self.preproc_checkboxes.items(): + checkbox.setChecked(method in methods) + if 'model_names' in config: + models = config['model_names'] + for model, checkbox in self.model_checkboxes.items(): + checkbox.setChecked(model in models) + if 'split_methods' in config: + methods = config['split_methods'] + for method, checkbox in self.split_checkboxes.items(): + checkbox.setChecked(method in methods) + if 'training_csv_path' in config: + self.training_csv_file.set_path(config['training_csv_path']) + if 'output_path' in config: + self.output_path.set_path(config['output_path']) def update_from_config(self, work_dir=None, pipeline=None): - """从全局配置自动填充去耀斑影像和掩膜路径 + """从全局配置自动填充训练数据和输出路径 Args: work_dir: 工作目录路径 - pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径) + pipeline: Pipeline 实例(未使用,保留接口兼容性) """ if work_dir: self.work_dir = work_dir @@ -144,81 +349,53 @@ class Step7Panel(QWidget): else: self.work_dir = None + # 1. 
尝试从 Step5 界面读取训练数据路径,并确保为绝对路径 main_window = self.window() + if hasattr(main_window, 'step5_panel'): + # 优先直接从 Step5 的输出 widget 读取 + step5_output = main_window.step5_panel.output_file.get_path() + if step5_output: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step5_output): + step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/') + self.training_csv_file.set_path(step5_output) + elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'): + # 回退:从 Step5 的 config 字典中查找可能的键名 + step5_cfg = main_window.step5_panel.get_config() + step5_csv = ( + step5_cfg.get('training_csv_path') + or step5_cfg.get('output_file') + or step5_cfg.get('csv_path') + or step5_cfg.get('output_csv') + ) + if step5_csv: + # 若为相对路径,使用 work_dir 合成为绝对路径 + if not os.path.isabs(step5_csv): + step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/') + self.training_csv_file.set_path(step5_csv) - # 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径) - deglint_path = None - if pipeline and hasattr(pipeline, 'step_outputs'): - step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {}) - deglint_path = ( - step3_outputs.get('deglint_image') - or step3_outputs.get('output_path') - or step3_outputs.get('output_file') - or step3_outputs.get('deglint_img_path') - ) - # 回退:从 step3 面板 widget 直接读取(可能是相对路径) - if not deglint_path and hasattr(main_window, 'step3_panel'): - deglint_path = main_window.step3_panel.output_file.get_path() - - if deglint_path: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(deglint_path): - deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/') - self.deglint_img_file.set_path(deglint_path) - - # 2. 
填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test) - water_mask_path = None - if pipeline and hasattr(pipeline, 'step_outputs'): - step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {}) - water_mask_path = ( - step1_outputs.get('water_mask') - or step1_outputs.get('output_path') - or step1_outputs.get('output_file') - ) - # 回退:从 step1 面板 widget 直接读取 - if not water_mask_path and hasattr(main_window, 'step1_panel'): - water_mask_path = main_window.step1_panel.output_file.get_path() - # 备选:扫描 1_water_mask 目录下的 .dat 文件 - if not water_mask_path and self.work_dir: - mask_dir = os.path.join(self.work_dir, "1_water_mask") - if os.path.isdir(mask_dir): - dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')] - if dat_files: - water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/') - # 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat) - if not water_mask_path and self.work_dir: - input_test_dir = os.path.join(self.work_dir, "input-test") - if os.path.isdir(input_test_dir): - dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')] - # 优先匹配 water_mask_from_shp.dat - for f in dat_files: - if 'water_mask_from_shp' in f.lower(): - water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/') - break - # 否则取第一个 .dat 文件 - if not water_mask_path and dat_files: - water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/') - - if water_mask_path: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(water_mask_path): - water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/') - self.water_mask_file.set_path(water_mask_path) - - # 3. 自动填充输出路径(绝对路径) + # 2. 
自动填充输出文件路径(基于工作目录和输入文件名) + # 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv + # 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv if self.work_dir: - output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv") - os.makedirs(os.path.dirname(output_path), exist_ok=True) - self.output_file.set_path(output_path.replace('\\', '/')) - - # 4. 同步更新预览按钮状态(路径可能已自动填充) - self._check_csv_exists() + indices_dir = os.path.join(self.work_dir, "6_water_quality_indices") + os.makedirs(indices_dir, exist_ok=True) + training_csv = self.training_csv_file.get_path() + if training_csv: + basename = os.path.splitext(os.path.basename(training_csv))[0] + output_file = f"{basename}_indices.csv" + else: + output_file = "water_quality_indices.csv" + output_path = os.path.join(indices_dir, output_file).replace('\\', '/') + self.output_path.set_path(output_path) + else: + self.output_path.set_path("") def run_step(self): """独立运行步骤7""" - deglint_img_path = self.deglint_img_file.get_path() - if not deglint_img_path: - QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!") + training_csv_path = self.training_csv_file.get_path() + if not training_csv_path: + QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!") return main_window = self.window() @@ -226,27 +403,13 @@ class Step7Panel(QWidget): config = {'step7': self.get_config()} main_window.run_single_step('step7', config) - def _check_csv_exists(self): - """检查 output csv 是否存在,驱动预览按钮启停""" - csv_path = self.output_file.get_path() - enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path)) - self.preview_btn.setEnabled(enabled) - return enabled - - def _on_output_changed(self, _text=None): - """输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)""" - self._check_csv_exists() - - def _open_sampling_viewer(self): - """打开交互式采样点查看器弹窗""" - csv_path = self.output_file.get_path() - if not csv_path or not os.path.exists(csv_path): - 
QMessageBox.warning( - self, "文件不存在", - f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。" - ) - return - dialog = SamplingViewerDialog(csv_path, self) - dialog.exec_() - # 弹窗关闭后再次检查状态(可能文件被覆盖等) - self._check_csv_exists() + def get_training_params(self): + """获取模型训练参数""" + return { + 'pipeline_type': 'machine_learning', + 'feature_start': float(self.feature_start.text()), + 'cv_folds': self.cv_folds.value(), + 'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()], + 'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()], + 'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()] + } diff --git a/src/gui/panels/step6_5_panel.py b/src/gui/panels/step8_non_empirical_panel.py similarity index 97% rename from src/gui/panels/step6_5_panel.py rename to src/gui/panels/step8_non_empirical_panel.py index 6eddab7..25dff2a 100644 --- a/src/gui/panels/step6_5_panel.py +++ b/src/gui/panels/step8_non_empirical_panel.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step6_5 面板 - 非经验统计回归建模 +Step8 面板 - 非经验统计回归建模 """ import os @@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet -class Step6_5Panel(QWidget): - """步骤6.5:非经验统计回归建模""" +class Step8NonEmpiricalPanel(QWidget): + """步骤8:非经验统计回归建模""" def __init__(self, parent=None): super().__init__(parent) self.init_ui() @@ -280,7 +280,7 @@ class Step6_5Panel(QWidget): self.output_dir.set_path(dir_path) def run_step(self): - """独立运行步骤6.5""" + """独立运行步骤8""" training_csv_path = self.training_csv_file.get_path() if not training_csv_path: QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!") @@ -297,7 +297,7 @@ class Step6_5Panel(QWidget): parent = parent.parent() if parent and hasattr(parent, 'run_single_step'): - parent.run_single_step('step6_5', {'step6_5': config}) + parent.run_single_step('step8_non_empirical_modeling', 
{'step8_non_empirical_modeling': config}) else: QMessageBox.critical(self, "错误", "无法找到父级GUI对象") diff --git a/src/gui/panels/step8_panel.py b/src/gui/panels/step8_panel.py index 727cbb1..ab34d14 100644 --- a/src/gui/panels/step8_panel.py +++ b/src/gui/panels/step8_panel.py @@ -1,462 +1,225 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Step8 面板 - 机器学习预测 -""" - import os +import sys +import pandas as pd from pathlib import Path +from typing import Dict, List, Union from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QFormLayout, - QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox, - QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout, - QListWidgetItem, + QWidget, QVBoxLayout, QGroupBox, QGridLayout, + QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea ) from PyQt5.QtCore import Qt - from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet +def get_resource_path(relative_path: str) -> str: + """适配开发与 PyInstaller 环境的路径获取逻辑。 + 支持两种打包模式: + 1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录 + 2. 
--onefile 模式:文件在 sys._MEIPASS 平铺目录 + """ + # 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下) + if hasattr(sys, '_MEIPASS'): + internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path) + if os.path.exists(internal_path): + return internal_path + return os.path.join(sys._MEIPASS, relative_path) + + # 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下) + exe_dir = os.path.dirname(sys.executable) + internal_path = os.path.join(exe_dir, '_internal', relative_path) + if os.path.exists(internal_path): + return internal_path + + # 开发环境下:基于当前文件 (step8_panel.py) 的绝对路径进行回溯 + # 当前在 src/gui/panels/,目标在 src/gui/model/ + base_dir = Path(__file__).resolve().parent.parent / "model" + target_path = base_dir / os.path.basename(relative_path) + return str(target_path) class Step8Panel(QWidget): - """步骤8:机器学习预测""" def __init__(self, parent=None): super().__init__(parent) - self.external_models_dict = {} # {subdir_name: model_obj, ...} - self.external_model_dir = "" # 母文件夹路径(隐藏) + self.index_checkboxes: Dict[str, QCheckBox] = {} + # 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中 + self.builtin_formula_path = get_resource_path("waterindex.csv") + self.init_ui() + # 延迟一小会儿加载,确保UI框架已就绪 + self._auto_load_formulas() def init_ui(self): - layout = QVBoxLayout() + main_layout = QVBoxLayout() + main_layout.setContentsMargins(20, 20, 20, 20) + main_layout.setSpacing(10) - # -------- 模型来源选择(单选按钮组) -------- - source_group = QGroupBox("模型来源") - source_layout = QVBoxLayout() + # 1. 
路径展示区 (半透明只读) + path_group = QGroupBox("公式配置源 (内置)") + path_layout = QVBoxLayout() + self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)") + self.formula_csv_widget.set_path(self.builtin_formula_path) + self.formula_csv_widget.set_read_only(True) + # 视觉微调:提示用户这是内置的 + self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;") + path_layout.addWidget(self.formula_csv_widget) + path_group.setLayout(path_layout) + main_layout.addWidget(path_group) - self.use_trained_model = QRadioButton("使用当前训练流程的模型") - self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)") - self.use_trained_model.setChecked(True) - source_layout.addWidget(self.use_trained_model) - source_layout.addWidget(self.use_external_model) + # 2. 训练数据输入 + input_group = QGroupBox("输入样本数据") + input_layout = QVBoxLayout() + self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)") + input_layout.addWidget(self.training_data_widget) + input_group.setLayout(input_layout) + main_layout.addWidget(input_group) - self.use_trained_model.toggled.connect(self._on_model_source_changed) - self.use_external_model.toggled.connect(self._on_model_source_changed) + # 3. 
公式选择区 + self.formula_group = QGroupBox("待计算水质指数勾选") + formula_outer_layout = QVBoxLayout() - source_group.setStyleSheet(""" - QRadioButton { - font-size: 13px; - spacing: 8px; - } - QRadioButton::indicator { - width: 16px; - height: 16px; - border-radius: 9px; - border: 2px solid #A0A0A0; - background-color: #FFFFFF; - } - QRadioButton::indicator:hover { - border: 2px solid #0078D7; - } - QRadioButton::indicator:checked { - background-color: #0078D7; - border: 2px solid #0078D7; - } - """) + btn_layout = QHBoxLayout() + self.select_all_btn = QPushButton("全选") + self.deselect_all_btn = QPushButton("清空") + self.select_all_btn.clicked.connect(self.select_all_formulas) + self.deselect_all_btn.clicked.connect(self.deselect_all_formulas) + btn_layout.addWidget(self.select_all_btn) + btn_layout.addWidget(self.deselect_all_btn) + btn_layout.addStretch() - source_group.setLayout(source_layout) - layout.addWidget(source_group) + self.refresh_button = QPushButton("手动重新加载公式") + self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False)) + btn_layout.addWidget(self.refresh_button) - # -------- 外部模型文件选择(条件显示) -------- - self.external_model_widget = FileSelectWidget( - "模型母文件夹:", - "Directories" - ) - self.external_model_widget.browse_btn.clicked.disconnect() - self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir) - self.external_model_widget.setVisible(False) - layout.addWidget(self.external_model_widget) + formula_outer_layout.addLayout(btn_layout) - # -------- 已扫描模型列表(条件显示) -------- - self.model_list_group = QGroupBox("选择参与预测的模型") - self.model_list_group.setVisible(False) - model_list_layout = QVBoxLayout() + # 核心滚动区 + scroll = QScrollArea() + scroll.setWidgetResizable(True) + scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷 + self.scroll_content = QWidget() + self.formula_layout = QGridLayout(self.scroll_content) + self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐 + scroll.setWidget(self.scroll_content) + 
formula_outer_layout.addWidget(scroll) - self.model_list = QListWidget() - self.model_list.setMaximumHeight(130) - self.model_list.setSelectionMode(QAbstractItemView.NoSelection) - self.model_list.setStyleSheet(""" - QListWidget { - border: 1px solid #C0C0C0; - border-radius: 4px; - background-color: #FFFFFF; - font-size: 12px; - } - QListWidget::item { - padding: 4px 6px; - border-bottom: 1px solid #F0F0F0; - } - QListWidget::item:selected { - background-color: transparent; - } - """) - model_list_layout.addWidget(self.model_list) + self.formula_group.setLayout(formula_outer_layout) + main_layout.addWidget(self.formula_group) - btn_row = QHBoxLayout() - self.btn_select_all = QPushButton("全选") - self.btn_select_all.setMaximumWidth(80) - self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) - self.btn_select_all.clicked.connect(self._select_all_models) + # 4. 输出与运行 + output_group = QGroupBox("结果输出") + output_layout = QVBoxLayout() + self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save") + output_layout.addWidget(self.output_file_widget) + output_group.setLayout(output_layout) + main_layout.addWidget(output_group) - self.btn_select_none = QPushButton("全不选") - self.btn_select_none.setMaximumWidth(80) - self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) - self.btn_select_none.clicked.connect(self._select_none_models) - - btn_row.addWidget(self.btn_select_all) - btn_row.addWidget(self.btn_select_none) - btn_row.addStretch() - model_list_layout.addLayout(btn_row) - - self.model_list_group.setLayout(model_list_layout) - layout.addWidget(self.model_list_group) - - # -------- 采样光谱CSV文件(用于独立运行)-------- - self.sampling_csv_file = FileSelectWidget( - "采样光谱CSV:", - "CSV Files (*.csv);;All Files (*.*)" - ) - layout.addWidget(self.sampling_csv_file) - - # 模型目录(用于独立运行) - self.models_dir_file = FileSelectWidget( - "模型目录:", - "Directories;;All Files (*.*)" - ) - 
self.models_dir_file.label.setText("模型目录:") - self.models_dir_file.browse_btn.clicked.disconnect() - self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir) - layout.addWidget(self.models_dir_file) - - # 参数设置 - params_group = QGroupBox("预测参数") - params_layout = QFormLayout() - - self.metric = QComboBox() - self.metric.addItems(['test_r2', 'test_rmse', 'test_mae']) - params_layout.addRow("模型选择指标:", self.metric) - - self.prediction_column = QLineEdit() - self.prediction_column.setText("prediction") - params_layout.addRow("预测列名:", self.prediction_column) - - params_group.setLayout(params_layout) - layout.addWidget(params_group) - - # 输出路径 - self.output_file = FileSelectWidget( - "输出路径:", - "CSV Files (*.csv);;All Files (*.*)" - ) - layout.addWidget(self.output_file) - - # 启用步骤 - self.enable_checkbox = QCheckBox("启用此步骤") + self.enable_checkbox = QCheckBox("启用计算流程") self.enable_checkbox.setChecked(True) - layout.addWidget(self.enable_checkbox) + main_layout.addWidget(self.enable_checkbox) - # 独立运行按钮 - self.run_btn = QPushButton("独立运行此步骤") - self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) - self.run_btn.clicked.connect(self.run_step) - layout.addWidget(self.run_btn) + self.run_button = QPushButton("立即执行计算") + self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) + self.run_button.setMinimumHeight(40) + self.run_button.clicked.connect(self.run_step) + main_layout.addWidget(self.run_button) - layout.addStretch() - self.setLayout(layout) + self.setLayout(main_layout) - def _on_model_source_changed(self, checked: bool): - """单选按钮切换:控制外部模型文件选择控件的显示/隐藏""" - if not checked: + def _auto_load_formulas(self): + """启动时自动加载逻辑""" + if os.path.exists(self.builtin_formula_path): + self.refresh_formulas(silent=True) + else: + print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}") + + def refresh_formulas(self, silent=False): + path = self.builtin_formula_path + if not os.path.exists(path): + if not silent: 
QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}") return - is_external = self.use_external_model.isChecked() - self.external_model_widget.setVisible(is_external) - self.model_list_group.setVisible(is_external) - if not is_external: - self.external_models_dict = {} - self.external_model_dir = "" - self._clear_model_list() - - def _scan_external_model_dir(self): - """浏览模型母文件夹,自动扫描子目录中的 .joblib 文件""" - default = self._get_default_work_dir() - if default: - default = os.path.join(default, "7_Supervised_Model_Training") - dir_path = QFileDialog.getExistingDirectory( - self, - "选择模型母文件夹", - default, - ) - if not dir_path: - return - - self.external_model_dir = dir_path - models_found = {} - errors = [] try: - import joblib + # 清理旧列表 + for i in reversed(range(self.formula_layout.count())): + widget = self.formula_layout.itemAt(i).widget() + if widget: widget.deleteLater() + self.index_checkboxes.clear() - for subentry in os.scandir(dir_path): - if not subentry.is_dir(): - continue - subdir_name = subentry.name - joblib_files = [ - f for f in os.scandir(subentry.path) - if f.is_file() and f.name.lower().endswith(".joblib") - ] - if not joblib_files: - continue - # 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致) - joblib_path = joblib_files[0].path + # 鲁棒性读取:尝试不同编码 + for encoding in ['utf-8', 'gbk', 'utf-8-sig']: try: - loaded = joblib.load(joblib_path) - if isinstance(loaded, dict) and "model" in loaded: - model_obj = loaded["model"] - elif hasattr(loaded, "predict"): - model_obj = loaded - else: - errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}") - continue - models_found[subdir_name] = model_obj - except Exception as e: - errors.append(f"{subdir_name}: {type(e).__name__}: {e}") + df = pd.read_csv(path, encoding=encoding) + if 'Formula_Name' in df.columns: break + except: continue + + if 'Formula_Name' not in df.columns: + if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列") + return + + names = df['Formula_Name'].dropna().unique().tolist() 
+ + row, col = 0, 0 + for name in names: + name = str(name).strip() + if not name: continue + cb = QCheckBox(name) + cb.setChecked(True) + self.index_checkboxes[name] = cb + self.formula_layout.addWidget(cb, row, col) + col += 1 + if col >= 3: + col = 0 + row += 1 + + # 强制UI更新 + self.scroll_content.adjustSize() + print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式") except Exception as e: - QMessageBox.warning( - self, - "扫描失败", - f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}", - ) - return + if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}") - if not models_found: - QMessageBox.warning( - self, - "未找到模型", - f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n" - "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。", - ) - self.external_model_widget.set_path("") - self.external_models_dict = {} - self._clear_model_list() - return + def select_all_formulas(self): + for cb in self.index_checkboxes.values(): cb.setChecked(True) - self.external_models_dict = models_found - self._populate_model_list(models_found) - names = sorted(models_found.keys()) - display = f"已识别到 {len(names)} 个模型: {', '.join(names)}" - self.external_model_widget.set_path(display) - self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;") - - err_lines = "\n".join(errors) if errors else "无" - QMessageBox.information( - self, - "模型扫描完成", - f"成功加载 {len(models_found)} 个模型:\n{display}\n\n" - f"加载失败 {len(errors)} 个:\n{err_lines}", - ) - - def _populate_model_list(self, models_dict): - """将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选""" - self.model_list.clear() - for name in sorted(models_dict.keys()): - item = QListWidgetItem(name) - item.setFlags(item.flags() | Qt.ItemIsUserCheckable) - item.setCheckState(Qt.Checked) - self.model_list.addItem(item) - - def _clear_model_list(self): - """清空模型列表""" - self.model_list.clear() - - def _select_all_models(self): - """全选:设置所有条目为 Checked""" - for i in range(self.model_list.count()): - self.model_list.item(i).setCheckState(Qt.Checked) - - def 
_select_none_models(self): - """全不选:设置所有条目为 Unchecked""" - for i in range(self.model_list.count()): - self.model_list.item(i).setCheckState(Qt.Unchecked) - - def _get_checked_models_dict(self): - """从列表中提取用户勾选的模型,组装成字典返回""" - result = {} - for i in range(self.model_list.count()): - item = self.model_list.item(i) - if item.checkState() == Qt.Checked: - name = item.text() - if name in self.external_models_dict: - result[name] = self.external_models_dict[name] - return result - - def update_from_config(self, work_dir=None, pipeline=None): - """从全局配置自动填充采样光谱和模型目录 - - Args: - work_dir: 工作目录路径 - pipeline: Pipeline 实例(未使用,保留接口兼容性) - """ - try: - import traceback - - if work_dir: - self.work_dir = work_dir - elif hasattr(self, 'work_dir') and self.work_dir: - pass - else: - self.work_dir = None - - main_window = self.window() - - # 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径 - if main_window and hasattr(main_window, 'step7_panel'): - step7_widget = getattr(main_window.step7_panel, 'output_file', None) - step7_output_path = "" - if hasattr(step7_widget, 'get_path'): - step7_output_path = step7_widget.get_path() or "" - elif hasattr(step7_widget, 'text'): - step7_output_path = step7_widget.text() or "" - - if step7_output_path: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step7_output_path): - step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/') - existing = self.sampling_csv_file.get_path() - if not existing or not existing.strip(): - self.sampling_csv_file.set_path(step7_output_path) - - # 2. 
尝试从 Step6 界面读取监督模型目录 - if main_window and hasattr(main_window, 'step6_panel'): - step6_widget = getattr(main_window.step6_panel, 'output_dir', None) - step6_models_dir = "" - if hasattr(step6_widget, 'get_path'): - step6_models_dir = step6_widget.get_path() or "" - elif hasattr(step6_widget, 'text'): - step6_models_dir = step6_widget.text() or "" - - if step6_models_dir: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step6_models_dir): - step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/') - existing_models = self.models_dir_file.get_path() - if not existing_models or not existing_models.strip(): - self.models_dir_file.set_path(step6_models_dir) - - # 3. 自动填充输出路径(机器学习预测目录) - if self.work_dir: - output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction") - os.makedirs(output_dir, exist_ok=True) - existing_out = self.output_file.get_path() - if not existing_out or not existing_out.strip(): - self.output_file.set_path(output_dir) - except Exception as e: - import traceback - print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") - traceback.print_exc() - - def _get_default_work_dir(self): - """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" - if hasattr(self, 'work_dir') and self.work_dir: - return str(self.work_dir) - mw = self.window() - if mw and hasattr(mw, 'work_dir') and mw.work_dir: - return str(mw.work_dir) - return "" - - def browse_models_dir(self): - """浏览模型目录""" - default = self._get_default_work_dir() - if default: - default = os.path.join(default, "7_Supervised_Model_Training") - dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default) - if dir_path: - self.models_dir_file.set_path(dir_path) + def deselect_all_formulas(self): + for cb in self.index_checkboxes.values(): cb.setChecked(False) def get_config(self): - """获取配置""" - config = { - 'metric': self.metric.currentText(), - 'prediction_column': self.prediction_column.text(), + selected = [n for n, cb in 
self.index_checkboxes.items() if cb.isChecked()] + return { + 'training_csv_path': self.training_data_widget.get_path(), + 'formula_csv_file': self.builtin_formula_path, + 'formula_names': selected, + 'output_file': self.output_file_widget.get_path(), + 'enabled': self.enable_checkbox.isChecked() } - sampling_csv_path = self.sampling_csv_file.get_path() - if sampling_csv_path: - config['sampling_csv_path'] = sampling_csv_path - models_dir = self.models_dir_file.get_path() - if models_dir: - config['models_dir'] = models_dir - output_path = self.output_file.get_path() - if output_path: - config['output_path'] = output_path - return config def set_config(self, config): - """设置配置""" - if 'metric' in config: - idx = self.metric.findText(config['metric']) - if idx >= 0: - self.metric.setCurrentIndex(idx) - if 'prediction_column' in config: - self.prediction_column.setText(config['prediction_column']) - if 'sampling_csv_path' in config: - self.sampling_csv_file.set_path(config['sampling_csv_path']) - if 'models_dir' in config: - self.models_dir_file.set_path(config['models_dir']) - if 'output_path' in config: - self.output_file.set_path(config['output_path']) + if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path']) + if 'formula_names' in config: + sel = set(config['formula_names']) + for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel) + if 'output_file' in config: self.output_file_widget.set_path(config['output_file']) + self.enable_checkbox.setChecked(config.get('enabled', True)) + + def update_from_config(self, work_dir=None, pipeline=None): + if work_dir: self.work_dir = work_dir + main = self.window() + if hasattr(main, 'step5_panel'): + p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐 + if p5: + if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/') + self.training_data_widget.set_path(p5) + + if self.work_dir: + out = os.path.join(self.work_dir, 
"6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/') + self.output_file_widget.set_path(out) def run_step(self): - """独立运行步骤8""" - sampling_csv_path = self.sampling_csv_file.get_path() - if not sampling_csv_path: - QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!") + config = self.get_config() + if not config['training_csv_path']: + QMessageBox.warning(self, "提示", "请先选择输入数据") return - - # 外部模型优先:用户选择了"导入本地预训练模型" - if self.use_external_model.isChecked(): - if not self.external_models_dict: - QMessageBox.warning( - self, - "模型未加载", - "请先点击「浏览...」按钮选择模型母文件夹!", - ) - return - # 只传递用户勾选的模型 - checked_dict = self._get_checked_models_dict() - if not checked_dict: - QMessageBox.warning( - self, - "未选择模型", - "请至少勾选一个模型参与预测!", - ) - return - main_window = self.window() - if hasattr(main_window, 'run_single_step'): - config = { - 'step8': self.get_config(), - '_external_models_dict': checked_dict, - '_external_model_dir': self.external_model_dir, - } - main_window.run_single_step('step8', config) - return - - # 默认流程:使用模型目录 - models_dir = self.models_dir_file.get_path() - if not models_dir: - QMessageBox.warning(self, "输入错误", "请选择模型目录!") - return - - main_window = self.window() - if hasattr(main_window, 'run_single_step'): - config = {'step8': self.get_config()} - main_window.run_single_step('step8', config) + parent = self.parent() + while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent() + if parent: parent.run_single_step('step8', {'step8': config}) \ No newline at end of file diff --git a/src/gui/panels/step9_panel.py b/src/gui/panels/step9_panel.py index c35a9d1..335eb44 100644 --- a/src/gui/panels/step9_panel.py +++ b/src/gui/panels/step9_panel.py @@ -1,206 +1,158 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -Step9 面板 - 分布图生成 +Step9 面板 - 自定义回归分析 """ import os -import traceback -from pathlib import Path -from typing import List, Optional +from typing import Dict -from PyQt5.QtCore import Qt, QThread, pyqtSignal 
+import pandas as pd from PyQt5.QtWidgets import ( - QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout, - QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox, - QRadioButton, QButtonGroup, QMessageBox, QFileDialog, + QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout, + QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton, + QScrollArea, QMessageBox, ) from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet -# Pipeline 可用性(与 core/worker_thread.py 保持一致) -try: - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline - PIPELINE_AVAILABLE = True -except ImportError: - PIPELINE_AVAILABLE = False - - -class Step9BatchThread(QThread): - """专题图:按文件夹内多个预测 CSV 批量生成分布图。""" - - finished_ok = pyqtSignal(int) - failed = pyqtSignal(str) - log_message = pyqtSignal(str, str) - - def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]): - super().__init__() - self.work_dir = work_dir - self.csv_paths = csv_paths - self.step9_kwargs = step9_kwargs - self.output_dir_optional = (output_dir_optional or "").strip() or None - - def run(self): - mpl_prev = None - try: - import matplotlib - mpl_prev = matplotlib.get_backend() - except Exception: - pass - try: - import matplotlib.pyplot as plt - plt.switch_backend("Agg") - except Exception: - mpl_prev = None - try: - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline - pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir) - n = len(self.csv_paths) - for i, csv_p in enumerate(self.csv_paths): - self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info") - kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True} - if self.output_dir_optional: - stem = Path(csv_p).stem - kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png") - else: - kw["output_image_path"] = None - 
pipeline.step9_generate_distribution_map(**kw) - self.finished_ok.emit(n) - except Exception as e: - self.failed.emit(f"{e}\n{traceback.format_exc()}") - finally: - if mpl_prev: - try: - import matplotlib.pyplot as plt - plt.switch_backend(mpl_prev) - except Exception: - pass - class Step9Panel(QWidget): - """步骤9:分布图生成""" + """步骤9:自定义回归分析""" def __init__(self, parent=None): super().__init__(parent) - self._batch_thread = None + self.x_column_checkboxes: Dict[str, QCheckBox] = {} + self.y_column_checkboxes: Dict[str, QCheckBox] = {} + self.method_checkboxes: Dict[str, QCheckBox] = {} + self.csv_columns = [] self.init_ui() def init_ui(self): layout = QVBoxLayout() - hint = QLabel( - "独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。" - "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。" - ) - hint.setWordWrap(True) - hint.setStyleSheet( - f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};" - ) + hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法") + hint.setStyleSheet("color: #666; font-size: 11px;") layout.addWidget(hint) - mode_row = QHBoxLayout() - self.mode_single_rb = QRadioButton("单个 CSV 文件") - self.mode_folder_rb = QRadioButton("文件夹批量") - self._mode_group = QButtonGroup(self) - self._mode_group.addButton(self.mode_single_rb, 0) - self._mode_group.addButton(self.mode_folder_rb, 1) - mode_row.addWidget(self.mode_single_rb) - mode_row.addWidget(self.mode_folder_rb) - mode_row.addStretch() - layout.addLayout(mode_row) + # CSV文件选择 + csv_group = QGroupBox("数据文件") + csv_layout = QVBoxLayout() - # ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ---------- - radio_style = """ - QRadioButton { - font-size: 14px; - spacing: 8px; - color: #333333; - } - QRadioButton::indicator { - width: 16px; - height: 16px; - border: 2px solid #999999; - border-radius: 3px; - background-color: white; - } - QRadioButton::indicator:checked { - border: 2px solid #0078d4; - background-color: #0078d4; - image: none; - } - QRadioButton::indicator:hover { - border: 2px solid #005a9e; - } - """ - 
self.mode_single_rb.setStyleSheet(radio_style) - self.mode_folder_rb.setStyleSheet(radio_style) - - self.prediction_csv_file = FileSelectWidget( - "预测结果CSV:", + self.csv_file = FileSelectWidget( + "输入CSV文件:", "CSV Files (*.csv);;All Files (*.*)" ) - layout.addWidget(self.prediction_csv_file) + self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed) + csv_layout.addWidget(self.csv_file) - folder_row = QHBoxLayout() - self.prediction_csv_dir_label = QLabel("预测CSV目录:") - self.prediction_csv_dir_label.setMinimumWidth(120) - self.prediction_csv_dir_edit = QLineEdit() - self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…") - pred_dir_btn = QPushButton("浏览…") - pred_dir_btn.setMaximumWidth(80) - pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir) - folder_row.addWidget(self.prediction_csv_dir_label) - folder_row.addWidget(self.prediction_csv_dir_edit, 1) - folder_row.addWidget(pred_dir_btn) - self._folder_row_widget = QWidget() - self._folder_row_widget.setLayout(folder_row) - layout.addWidget(self._folder_row_widget) + self.refresh_btn = QPushButton("刷新列信息") + self.refresh_btn.clicked.connect(self.refresh_csv_columns) + csv_layout.addWidget(self.refresh_btn) - self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)") - layout.addWidget(self.recursive_csv_cb) + csv_group.setLayout(csv_layout) + layout.addWidget(csv_group) - self.boundary_file = FileSelectWidget( - "边界文件:", - "Shapefiles (*.shp);;All Files (*.*)" - ) - layout.addWidget(self.boundary_file) + # 自变量选择 + x_group = QGroupBox("自变量列选择 (可多选)") + x_layout = QVBoxLayout() - # 参数设置 - params_group = QGroupBox("生成参数") - params_layout = QFormLayout() + x_scroll = QScrollArea() + x_scroll.setWidgetResizable(True) + x_scroll.setMinimumHeight(250) + x_scroll.setMaximumHeight(350) - self.resolution = QDoubleSpinBox() - self.resolution.setRange(1, 1000) - self.resolution.setValue(30) - params_layout.addRow("分辨率(米):", self.resolution) + x_widget = QWidget() + self.x_columns_layout = 
QGridLayout() + x_widget.setLayout(self.x_columns_layout) - self.input_crs = QLineEdit() - self.input_crs.setText("EPSG:32651") - params_layout.addRow("输入坐标系:", self.input_crs) + x_scroll.setWidget(x_widget) + x_layout.addWidget(x_scroll) - self.output_crs = QLineEdit() - self.output_crs.setText("EPSG:4326") - params_layout.addRow("输出坐标系:", self.output_crs) + x_btn_layout = QHBoxLayout() + self.x_select_all = QPushButton("全选") + self.x_deselect_all = QPushButton("全不选") + self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True)) + self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False)) + x_btn_layout.addWidget(self.x_select_all) + x_btn_layout.addWidget(self.x_deselect_all) + x_btn_layout.addStretch() + x_layout.addLayout(x_btn_layout) - self.show_points = QCheckBox("显示采样点") - params_layout.addRow("", self.show_points) + x_group.setLayout(x_layout) + layout.addWidget(x_group) - self.use_diffusion = QCheckBox("启用距离扩散") - self.use_diffusion.setChecked(True) - params_layout.addRow("", self.use_diffusion) + # 因变量选择 + y_group = QGroupBox("因变量列选择 (可多选)") + y_layout = QVBoxLayout() - params_group.setLayout(params_layout) - layout.addWidget(params_group) + y_scroll = QScrollArea() + y_scroll.setWidgetResizable(True) + y_scroll.setMinimumHeight(200) + y_scroll.setMaximumHeight(300) + + y_widget = QWidget() + self.y_columns_layout = QGridLayout() + y_widget.setLayout(self.y_columns_layout) + + y_scroll.setWidget(y_widget) + y_layout.addWidget(y_scroll) + + y_btn_layout = QHBoxLayout() + self.y_select_all = QPushButton("全选") + self.y_deselect_all = QPushButton("全不选") + self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True)) + self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False)) + y_btn_layout.addWidget(self.y_select_all) + y_btn_layout.addWidget(self.y_deselect_all) + y_btn_layout.addStretch() + 
y_layout.addLayout(y_btn_layout) + + y_group.setLayout(y_layout) + layout.addWidget(y_group) + + # 回归方法选择 + method_group = QGroupBox("回归方法选择 (可多选)") + method_layout = QVBoxLayout() + + method_grid = QGridLayout() + regression_methods = [ + 'linear', 'exponential', 'power', 'logarithmic', + 'polynomial', 'hyperbolic', 'sigmoidal' + ] + + for i, method in enumerate(regression_methods): + checkbox = QCheckBox(method) + if method in ['linear', 'exponential', 'power', 'logarithmic']: + checkbox.setChecked(True) + self.method_checkboxes[method] = checkbox + method_grid.addWidget(checkbox, i // 3, i % 3) + + method_layout.addLayout(method_grid) + + method_btn_layout = QHBoxLayout() + self.method_select_all = QPushButton("全选") + self.method_deselect_all = QPushButton("全不选") + self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True)) + self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False)) + method_btn_layout.addWidget(self.method_select_all) + method_btn_layout.addWidget(self.method_deselect_all) + method_btn_layout.addStretch() + method_layout.addLayout(method_btn_layout) + + method_group.setLayout(method_layout) + layout.addWidget(method_group) # 输出目录 - self.output_dir = FileSelectWidget( - "输出分布图目录:", - "Directories;;All Files (*.*)" - ) - self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization") - self.output_dir.browse_btn.clicked.disconnect() - self.output_dir.browse_btn.clicked.connect(self.browse_output_dir) - layout.addWidget(self.output_dir) + output_group = QGroupBox("输出设置") + output_layout = QFormLayout() + + self.output_dir = QLineEdit() + self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充 + output_layout.addRow("输出目录名:", self.output_dir) + + output_group.setLayout(output_layout) + layout.addWidget(output_group) # 启用步骤 self.enable_checkbox = QCheckBox("启用此步骤") @@ -216,119 +168,120 @@ class Step9Panel(QWidget): layout.addStretch() 
self.setLayout(layout) - # 信号绑定与初始状态 - self.mode_single_rb.toggled.connect(self._toggle_input_mode) - self.mode_folder_rb.toggled.connect(self._toggle_input_mode) - self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV" - self._toggle_input_mode() # 根据默认值设置初始显示状态 + def toggle_checkboxes(self, checkboxes_dict, checked): + """统一设置checkbox状态""" + for checkbox in checkboxes_dict.values(): + checkbox.setChecked(checked) - def _toggle_input_mode(self): - """槽函数:根据单选框状态动态显示/隐藏对应的输入组件。""" - folder_mode = self.mode_folder_rb.isChecked() - # 单个 CSV 模式:显示单文件选择,隐藏文件夹选择 - self.prediction_csv_file.setVisible(not folder_mode) - # 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择 - self._folder_row_widget.setVisible(folder_mode) - self.recursive_csv_cb.setVisible(folder_mode) + def on_csv_file_changed(self): + """CSV文件改变时自动刷新列信息""" + self.refresh_csv_columns() - def _get_default_work_dir(self): - """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" - if hasattr(self, 'work_dir') and self.work_dir: - return str(self.work_dir) - mw = self.window() - if mw and hasattr(mw, 'work_dir') and mw.work_dir: - return str(mw.work_dir) - return "" + def refresh_csv_columns(self): + """刷新CSV文件的列信息""" + csv_path = self.csv_file.get_path() + if not csv_path or not os.path.exists(csv_path): + self.csv_columns = [] + self.update_column_widgets() + return - def browse_prediction_csv_dir(self): - default = self._get_default_work_dir() - if default: - default = os.path.join(default, "11_12_13_predictions") - d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default) - if d: - self.prediction_csv_dir_edit.setText(d) + try: + df = pd.read_csv(csv_path, nrows=0) + self.csv_columns = list(df.columns) + self.update_column_widgets() + except Exception as e: + self.csv_columns = [] + self.update_column_widgets() + print(f"读取CSV列信息失败: {e}") - def _collect_csv_paths_from_folder(self) -> List[str]: - folder = (self.prediction_csv_dir_edit.text() or "").strip() - if not folder or not os.path.isdir(folder): - return [] - root = 
Path(folder) - if self.recursive_csv_cb.isChecked(): - files = sorted(root.rglob("*.csv")) - else: - files = sorted(root.glob("*.csv")) - return [str(p) for p in files if p.is_file()] + def update_column_widgets(self): + """更新列选择组件""" + for checkbox in self.x_column_checkboxes.values(): + checkbox.setParent(None) + self.x_column_checkboxes.clear() - def _step9_base_pipeline_kwargs(self) -> dict: - return { - 'boundary_shp_path': self.boundary_file.get_path(), - 'resolution': self.resolution.value(), - 'input_crs': self.input_crs.text(), - 'output_crs': self.output_crs.text(), - 'show_sample_points': self.show_points.isChecked(), - 'use_distance_diffusion': self.use_diffusion.isChecked(), - } + for checkbox in self.y_column_checkboxes.values(): + checkbox.setParent(None) + self.y_column_checkboxes.clear() + + if not self.csv_columns: + return + + for i, col in enumerate(self.csv_columns): + checkbox = QCheckBox(col) + if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']): + checkbox.setChecked(True) + self.x_column_checkboxes[col] = checkbox + self.x_columns_layout.addWidget(checkbox, i // 3, i % 3) + + for i, col in enumerate(self.csv_columns): + checkbox = QCheckBox(col) + if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']): + checkbox.setChecked(True) + self.y_column_checkboxes[col] = checkbox + self.y_columns_layout.addWidget(checkbox, i // 2, i % 2) + + self.x_columns_layout.update() + self.y_columns_layout.update() def get_config(self): - pred_csv = (self.prediction_csv_file.get_path() or "").strip() - folder_mode = self.mode_folder_rb.isChecked() - pred_dir = (self.prediction_csv_dir_edit.text() or "").strip() - config = { - 'step9_batch_mode': 'folder' if folder_mode else 'single', - 'prediction_csv_dir': pred_dir if pred_dir else None, - 'recursive_csv_scan': self.recursive_csv_cb.isChecked(), - 'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv 
else None), - 'boundary_shp_path': self.boundary_file.get_path(), - 'resolution': self.resolution.value(), - 'input_crs': self.input_crs.text(), - 'output_crs': self.output_crs.text(), - 'show_sample_points': self.show_points.isChecked(), - 'use_distance_diffusion': self.use_diffusion.isChecked(), + selected_x_columns = [ + col for col, checkbox in self.x_column_checkboxes.items() + if checkbox.isChecked() + ] + selected_y_columns = [ + col for col, checkbox in self.y_column_checkboxes.items() + if checkbox.isChecked() + ] + selected_methods = [ + method for method, checkbox in self.method_checkboxes.items() + if checkbox.isChecked() + ] + if not selected_methods: + selected_methods = 'all' + + return { + 'csv_path': self.csv_file.get_path() or None, + 'x_columns': selected_x_columns, + 'y_columns': selected_y_columns, + 'methods': selected_methods, + 'output_dir': self.output_dir.text().strip() or None, + 'enabled': self.enable_checkbox.isChecked() } - out_dir = (self.output_dir.get_path() or "").strip() - if not folder_mode and pred_csv and out_dir: - stem = Path(pred_csv).stem - config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png") - else: - config['output_image_path'] = None - return config def set_config(self, config): - mode = config.get('step9_batch_mode', 'single') - if mode == 'folder': - self.mode_folder_rb.setChecked(True) - else: - self.mode_single_rb.setChecked(True) - if config.get('prediction_csv_dir'): - self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir'])) - if 'recursive_csv_scan' in config: - self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan'])) - if 'prediction_csv_path' in config and config['prediction_csv_path']: - self.prediction_csv_file.set_path(str(config['prediction_csv_path'])) - if 'boundary_shp_path' in config: - self.boundary_file.set_path(config['boundary_shp_path']) - if 'resolution' in config: - self.resolution.setValue(config['resolution']) - if 'input_crs' in config: - 
self.input_crs.setText(config['input_crs']) - if 'output_crs' in config: - self.output_crs.setText(config['output_crs']) - if 'show_sample_points' in config: - self.show_points.setChecked(config['show_sample_points']) - if 'use_distance_diffusion' in config: - self.use_diffusion.setChecked(config['use_distance_diffusion']) - if 'output_dir' in config and config['output_dir']: - self.output_dir.set_path(str(config['output_dir'])) - elif config.get('output_image_path'): - p = Path(str(config['output_image_path'])) - if p.parent and str(p.parent) != '.': - self.output_dir.set_path(str(p.parent)) + if 'csv_path' in config: + self.csv_file.set_path(config['csv_path']) + self.refresh_csv_columns() + + if 'x_columns' in config: + selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set() + for col, checkbox in self.x_column_checkboxes.items(): + checkbox.setChecked(col in selected_x) + + if 'y_columns' in config: + selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set() + for col, checkbox in self.y_column_checkboxes.items(): + checkbox.setChecked(col in selected_y) + + if 'methods' in config: + methods = config['methods'] + if isinstance(methods, list): + selected_methods = set(methods) + elif methods == 'all': + selected_methods = set(self.method_checkboxes.keys()) + else: + selected_methods = set() + for method, checkbox in self.method_checkboxes.items(): + checkbox.setChecked(method in selected_methods) + + if 'output_dir' in config: + self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling") + if 'enabled' in config: + self.enable_checkbox.setChecked(config['enabled']) def update_from_config(self, work_dir=None, pipeline=None): - """从全局配置自动填充预测结果目录 - - 优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录; - 其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。 + """从全局配置自动填充训练数据和输出路径 Args: work_dir: 工作目录路径 @@ -344,190 +297,78 @@ class Step9Panel(QWidget): else: self.work_dir = None + # 1. 
尝试从 Step5 界面读取训练光谱 CSV 路径 main_window = self.window() - if not main_window: - return + if main_window and hasattr(main_window, 'step5_panel'): + step5_widget = getattr(main_window.step5_panel, 'output_file', None) + step5_output_path = "" + if hasattr(step5_widget, 'get_path'): + step5_output_path = step5_widget.get_path() or "" + elif hasattr(step5_widget, 'text'): + step5_output_path = step5_widget.text() or "" - # 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先) - pred_dir = None - if hasattr(main_window, 'step8_panel'): - step8_widget = getattr(main_window.step8_panel, 'output_file', None) - step8_output = "" - if hasattr(step8_widget, 'get_path'): - step8_output = step8_widget.get_path() or "" - elif hasattr(step8_widget, 'text'): - step8_output = step8_widget.text() or "" - - if step8_output: + if step5_output_path: # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step8_output): - step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/') - # 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录) - base_pred_dir = str(Path(step8_output).parent) - ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction" - pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir + if not os.path.isabs(step5_output_path): + step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/') + existing = self.csv_file.get_path() + if not existing or not existing.strip(): + self.csv_file.set_path(step5_output_path) - # 2. 
备选:从 Step8.5 界面读取非经验预测输出目录 - if not pred_dir and hasattr(main_window, 'step8_5_panel'): - step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None) - step8_5_output = "" - if hasattr(step8_5_widget, 'get_path'): - step8_5_output = step8_5_widget.get_path() or "" - elif hasattr(step8_5_widget, 'text'): - step8_5_output = step8_5_widget.text() or "" - - if step8_5_output: - # 若为相对路径,使用 work_dir 合成为绝对路径 - if not os.path.isabs(step8_5_output): - step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/') - pred_dir = str(Path(step8_5_output).parent) - - # 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录 - if not pred_dir and hasattr(main_window, 'step8_75_panel'): - step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None) - step8_75_output = "" - if hasattr(step8_75_widget, 'get_path'): - step8_75_output = step8_75_widget.get_path() or "" - elif hasattr(step8_75_widget, 'text'): - step8_75_output = step8_75_widget.text() or "" - - if step8_75_output: - pred_dir = step8_75_output - - # 自动填入"预测CSV目录"(文件夹批量模式) - if pred_dir: - existing_dir = (self.prediction_csv_dir_edit.text() or "").strip() - if not existing_dir: - self.prediction_csv_dir_edit.setText(pred_dir) - # 切换到文件夹批量模式 - self.mode_folder_rb.setChecked(True) - - # 4. 自动填充输出目录(14_visualization) + # 2. 自动填充输出目录(9_Custom_Regression_Modeling) if self.work_dir: - output_dir = os.path.join(self.work_dir, "14_visualization") + output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling") os.makedirs(output_dir, exist_ok=True) - existing_out = self.output_dir.get_path() - if not existing_out or not existing_out.strip(): - self.output_dir.set_path(output_dir) - - # 5. 
自动探测原始矢量边界文件(.shp)作为专题图底图 - # 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式 - if self.work_dir: - possible_shp = None - candidates = [ - Path(self.work_dir).parent / "input-test" / "roi.shp", - Path(self.work_dir) / "roi.shp", - Path(self.work_dir).parent / "roi.shp", - ] - for candidate in candidates: - if candidate.exists() and candidate.suffix.lower() == ".shp": - possible_shp = candidate - break - - existing_boundary = (self.boundary_file.get_path() or "").strip() - if not existing_boundary and possible_shp: - self.boundary_file.set_path(str(possible_shp)) - elif not existing_boundary: - # 未找到 .shp 时清空并提示用户手动选择矢量文件 - self.boundary_file.set_path("") - print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。") + existing_out = self.output_dir.text().strip() + if not existing_out: + self.output_dir.setText(output_dir) except Exception as e: import traceback print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") traceback.print_exc() - def browse_output_dir(self): - """浏览输出目录""" - default = self._get_default_work_dir() - if default: - default = os.path.join(default, "14_visualization") - dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default) - if dir_path: - self.output_dir.set_path(dir_path) - def run_step(self): """独立运行步骤9""" - if self._batch_thread and self._batch_thread.isRunning(): - QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。") + csv_path = self.csv_file.get_path() + + if not csv_path: + QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件") + return + if not os.path.exists(csv_path): + QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在") return - boundary_shp_path = self.boundary_file.get_path() - if not boundary_shp_path: - QMessageBox.warning(self, "输入验证失败", "请选择边界文件") + selected_x_columns = [ + col for col, checkbox in self.x_column_checkboxes.items() + if checkbox.isChecked() + ] + if not selected_x_columns: + QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列") return - if not os.path.exists(boundary_shp_path): - QMessageBox.warning(self, 
"输入验证失败", "边界文件不存在") + + selected_y_columns = [ + col for col, checkbox in self.y_column_checkboxes.items() + if checkbox.isChecked() + ] + if not selected_y_columns: + QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列") return + selected_methods = [ + method for method, checkbox in self.method_checkboxes.items() + if checkbox.isChecked() + ] + if not selected_methods: + QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法") + return + + config = self.get_config() + parent = self.parent() while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent() - if not parent or not hasattr(parent, 'run_single_step'): + if parent and hasattr(parent, 'run_single_step'): + parent.run_single_step('step9', {'step9': config}) + else: QMessageBox.critical(self, "错误", "无法找到父级GUI对象") - return - - if self.mode_folder_rb.isChecked(): - csv_list = self._collect_csv_paths_from_folder() - if not csv_list: - QMessageBox.warning( - self, - "输入验证失败", - "所选文件夹中未找到 .csv 文件,或目录无效。\n" - "可勾选「包含子文件夹」以递归扫描。", - ) - return - if not PIPELINE_AVAILABLE: - QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。") - return - work_dir = getattr(parent, "work_dir", None) or "./work_dir" - work_dir = str(work_dir) - base_kw = self._step9_base_pipeline_kwargs() - out_dir_opt = (self.output_dir.get_path() or "").strip() or None - self.run_button.setEnabled(False) - self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt) - main_win = parent - - def _batch_log(msg, lvl): - if hasattr(main_win, "log_message"): - main_win.log_message(msg, lvl) - - self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection) - self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection) - self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection) - self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection) - self._batch_thread.start() - if hasattr(parent, "log_message"): - 
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info") - return - - prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip() - if not prediction_csv_path: - QMessageBox.warning( - self, - "输入验证失败", - "请选择「预测结果 CSV」文件,或切换到「文件夹批量」。", - ) - return - if not os.path.isfile(prediction_csv_path): - QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件") - return - - config = self.get_config() - parent.run_single_step('step9', {'step9': config}) - - def _on_step9_batch_ok(self, n: int): - QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。") - parent = self.parent() - while parent and not hasattr(parent, "log_message"): - parent = parent.parent() - if parent and hasattr(parent, "log_message"): - parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info") - - def _on_step9_batch_fail(self, err: str): - QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}") - parent = self.parent() - while parent and not hasattr(parent, "log_message"): - parent = parent.parent() - if parent and hasattr(parent, "log_message"): - parent.log_message(err, "error") diff --git a/src/gui/panels/visualization_panel.py b/src/gui/panels/visualization_panel.py index 0c5300a..7064ed8 100644 --- a/src/gui/panels/visualization_panel.py +++ b/src/gui/panels/visualization_panel.py @@ -1567,12 +1567,12 @@ class VisualizationPanel(QWidget): ml_dir.mkdir(parents=True, exist_ok=True) reg_dir.mkdir(parents=True, exist_ok=True) custom_dir.mkdir(parents=True, exist_ok=True) - if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'): - self.step8_panel.output_file.set_path(str(ml_dir)) - if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'): - self.step8_5_panel.output_file.set_path(str(reg_dir)) - if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'): - self.step8_75_panel.output_dir_widget.set_path(str(custom_dir)) + if hasattr(self, 'step11_ml_panel') and hasattr(self.step11_ml_panel, 
'output_file'): + self.step11_ml_panel.output_file.set_path(str(ml_dir)) + if hasattr(self, 'step11_panel') and hasattr(self.step11_panel, 'output_file'): + self.step11_panel.output_file.set_path(str(reg_dir)) + if hasattr(self, 'step12_panel') and hasattr(self.step12_panel, 'output_dir_widget'): + self.step12_panel.output_dir_widget.set_path(str(custom_dir)) print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}") except Exception as e: print(f"设置预测输出目录失败: {e}") diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index 0b88f03..6526866 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -119,19 +119,22 @@ from src.gui.panels.step2_panel import Step2Panel from src.gui.panels.step3_panel import Step3Panel from src.gui.panels.step4_panel import Step4Panel from src.gui.panels.step5_panel import Step5Panel -from src.gui.panels.step5_5_panel import Step5_5Panel -from src.gui.panels.step6_panel import Step6Panel -from src.gui.panels.step6_5_panel import Step6_5Panel -from src.gui.panels.step6_75_panel import Step6_75Panel -from src.gui.panels.step7_panel import Step7Panel -from src.gui.panels.step8_panel import Step8Panel -from src.gui.panels.step8_5_panel import Step8_5Panel -from src.gui.panels.step8_75_panel import Step8_75Panel -from src.gui.panels.step9_panel import Step9Panel +from src.gui.panels.step8_panel import Step8Panel # was step5_5_panel +from src.gui.panels.step7_panel import Step7Panel # was step6_panel +from src.gui.panels.step8_non_empirical_panel import Step8NonEmpiricalPanel # was step6_5_panel +from src.gui.panels.step9_panel import Step9Panel # was step6_75_panel +from src.gui.panels.step10_panel import Step10Panel # was step7_panel +from src.gui.panels.step11_ml_panel import Step11MlPanel # ML prediction (step11_ml) +from src.gui.panels.step11_panel import Step11Panel # was step8_5_panel +from src.gui.panels.step12_panel import Step12Panel # was step8_75_panel +from 
src.gui.panels.step14_panel import Step14Panel # was step9_panel from src.gui.dialogs import BandConfirmDialog, AISettingsDialog from src.gui.panels.visualization_panel import VisualizationPanel from src.gui.panels.report_generation_panel import ReportGenerationPanel +# Pipeline 核心异常(用于预检弹窗) +from src.core.pipeline.runner import PipelineHalt + # Matplotlib相关导入 (推迟并加入底层防爆保护) import matplotlib try: @@ -152,6 +155,9 @@ from src.gui.core.worker_thread import ( check_pipeline_dependencies, diagnose_pipeline_import_error, ) +# 预检交互对话框 +from src.gui.core.preflight_dialog import PreflightDialog +from src.gui.core.pipeline_mode_dialog import PipelineModeDialog def _viz_training_spectra_csv_path(work_path: Path) -> Path: @@ -1384,31 +1390,31 @@ class WaterQualityGUI(QMainWindow): 'step5': { 'training_spectra': '5_training_spectra/training_spectra.csv' }, - 'step5_5': { + 'step8': { 'water_indices': '6_water_quality_indices/water_quality_indices.csv' }, - 'step6': { + 'step7': { 'models': '7_Supervised_Model_Training/' # 目录,包含各参数子目录 }, - 'step6_5': { + 'step8_non_empirical_modeling': { 'regression_models': '8_Regression_Modeling/' # 目录,包含各参数子目录 }, - 'step6_75': { + 'step9': { 'custom_regression_models': '9_Custom_Regression_Modeling/' # 目录 }, - 'step7': { + 'step10': { 'sampling_points': '10_sampling/sampling_spectra.csv' }, - 'step8': { + 'step11_ml': { 'predictions': '11_12_13_predictions/Machine_Learning_Prediction/' # 目录,包含机器学习预测结果 }, - 'step8_5': { + 'step11': { 'regression_predictions': '11_12_13_predictions/Non_Empirical_Prediction/' # 目录,包含非经验模型预测结果 }, - 'step8_75': { + 'step12': { 'custom_predictions': '11_12_13_predictions/Custom_Regression_Prediction/' # 目录,包含自定义回归预测结果 }, - 'step9': { + 'step14': { 'distribution_maps': '14_visualization/' # 目录,包含专题图 } } @@ -1432,37 +1438,37 @@ class WaterQualityGUI(QMainWindow): 'boundary_mask_path': ('step1', 'water_mask', 'boundary_mask_file'), # 步骤5可选水体掩膜 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜 
}, - 'step5_5': { - 'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱 - }, - 'step6': { - 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据 - }, - 'step6_5': { - 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.5需要训练光谱数据 - }, - 'step6_75': { - 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.75需要训练光谱数据 + 'step8': { + 'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤8需要步骤5输出的训练光谱 }, 'step7': { - 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤7需要去耀斑影像 - 'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤7需要水域掩膜 - 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤7可选耀斑掩膜 + 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤7需要训练光谱数据 }, - 'step8': { - 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8需要采样点 - 'models_dir': ('step6', 'models', 'models_dir_file') # 步骤8需要训练好的模型 - }, - 'step8_5': { - 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.5需要采样点 - 'models_dir': ('step6_5', 'regression_models', 'models_dir') # 步骤8.5需要回归模型 - }, - 'step8_75': { - 'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.75需要采样点 - 'models_dir': ('step6_75', 'custom_regression_models', 'models_dir') # 步骤8.75需要自定义回归模型 + 'step8_non_empirical_modeling': { + 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤8非经验建模需要训练光谱数据 }, 'step9': { - 'prediction_csv_path': ('step8', 'predictions', 'prediction_csv_file') # 步骤9需要预测结果CSV + 'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤9需要训练光谱数据 + }, + 'step10': { + 'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤10需要去耀斑影像 + 'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤10需要水域掩膜 + 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤10可选耀斑掩膜 + }, + 'step11_ml': { + 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 
步骤11ML需要采样点 + 'models_dir': ('step7', 'models', 'models_dir_file') # 步骤11ML需要训练好的模型 + }, + 'step11': { + 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤11需要采样点 + 'models_dir': ('step8_non_empirical_modeling', 'regression_models', 'models_dir') # 步骤11需要回归模型 + }, + 'step12': { + 'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤12需要采样点 + 'models_dir': ('step9', 'custom_regression_models', 'models_dir') # 步骤12需要自定义回归模型 + }, + 'step14': { + 'prediction_csv_path': ('step11_ml', 'predictions', 'prediction_csv_file') # 步骤14需要预测结果CSV } } @@ -1545,7 +1551,7 @@ class WaterQualityGUI(QMainWindow): def init_ui(self): """初始化UI""" - self.setWindowTitle("MegaCube-Water Quality ‌V1.1") + self.setWindowTitle("MegaCube-Water Quality ‌V1.2") # 获取屏幕可用区域(排除任务栏) screen_geometry = QApplication.primaryScreen().availableGeometry() @@ -1730,7 +1736,7 @@ class WaterQualityGUI(QMainWindow): def create_banner_widget(self): """创建横幅区域 - 支持自适应等比缩放""" # 横幅标题文字(方便后续直接修改版本号) - self._APP_TITLE = "MegaCube-Water Quality ‌V1.1" + self._APP_TITLE = "MegaCube-Water Quality ‌V1.2" # 创建横幅容器 banner_widget = QWidget() @@ -1844,19 +1850,19 @@ class WaterQualityGUI(QMainWindow): "阶段二:样本数据准备 ": [ ("step4", "4. 数据标准化处理"), ("step5", "5. 光谱特征提取"), - ("step5_5", "6. 水质参数指数计算"), + ("step8", "6. 水质参数指数计算"), ], "阶段三:模型构建与训练": [ - ("step6", "7. 机器学习模型训练"), - ("step6_5", "8. 回归模型训练"), - ("step6_75", "9. 自定义回归模型训练"), + ("step7", "7. 机器学习模型训练"), + ("step8_non_empirical_modeling", "8. 回归模型训练"), + ("step9", "9. 自定义回归模型训练"), ], "阶段四:预测与成果输出 ": [ - ("step7", "10. 采样点布设"), - ("step8", "11. 机器学习预测"), - ("step8_5", "12. 回归预测"), - ("step8_75", "13. 自定义回归预测"), - ("step9", "14. 专题图生成"), + ("step10", "10. 采样点布设"), + ("step11_ml", "11. 机器学习预测"), + ("step11", "12. 回归预测"), + ("step12", "13. 自定义回归预测"), + ("step14", "14. 专题图生成"), ("step9_viz", "15. 可视化分析"), ("step_report", "16. 
分析报告生成"), ] @@ -1878,7 +1884,7 @@ class WaterQualityGUI(QMainWindow): self.step_list.addItem(stage_item) # 添加该阶段的所有步骤 - HIDDEN_STEP_IDS = {"step6_5", "step6_75", "step8_5", "step8_75"} + HIDDEN_STEP_IDS = {"step8_non_empirical_modeling", "step9", "step11", "step12"} for step_id, step_display in steps: if step_id in HIDDEN_STEP_IDS: continue @@ -1958,36 +1964,36 @@ class WaterQualityGUI(QMainWindow): self.step5_panel = Step5Panel() self.step_stack.addTab(self.create_scroll_area(self.step5_panel), QIcon(self.get_icon_path("5.png")), "特征构建") - self.step5_5_panel = Step5_5Panel() - self.step_stack.addTab(self.create_scroll_area(self.step5_5_panel), QIcon(self.get_icon_path("5.png")), "水质指数") - - self.step6_panel = Step6Panel() - self.step_stack.addTab(self.create_scroll_area(self.step6_panel), QIcon(self.get_icon_path("6.png")), "监督建模") - - self.step6_5_panel = Step6_5Panel() - self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模") - self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab - - self.step6_75_panel = Step6_75Panel() - self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模") - self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab + self.step8_panel = Step8Panel() + self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("5.png")), "水质指数") self.step7_panel = Step7Panel() - self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设") - - self.step8_panel = Step8Panel() - self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("8.png")), "监督预测") + self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("6.png")), "监督建模") - self.step8_5_panel = Step8_5Panel() - self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测") - 
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab - - self.step8_75_panel = Step8_75Panel() - self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测") - self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab + self.step8_non_empirical_panel = Step8NonEmpiricalPanel() + self.step_stack.addTab(self.create_scroll_area(self.step8_non_empirical_panel), QIcon(self.get_icon_path("6.png")), "回归建模") + self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab self.step9_panel = Step9Panel() - self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成") + self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模") + self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab + + self.step10_panel = Step10Panel() + self.step_stack.addTab(self.create_scroll_area(self.step10_panel), QIcon(self.get_icon_path("7.png")), "采样点布设") + + self.step11_ml_panel = Step11MlPanel() # ML prediction panel (step11_ml) + self.step_stack.addTab(self.create_scroll_area(self.step11_ml_panel), QIcon(self.get_icon_path("8.png")), "监督预测") + + self.step11_panel = Step11Panel() + self.step_stack.addTab(self.create_scroll_area(self.step11_panel), QIcon(self.get_icon_path("8.png")), "回归预测") + self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab + + self.step12_panel = Step12Panel() + self.step_stack.addTab(self.create_scroll_area(self.step12_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测") + self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab + + self.step14_panel = Step14Panel() + self.step_stack.addTab(self.create_scroll_area(self.step14_panel), QIcon(self.get_icon_path("10.png")), "专题图生成") self.viz_panel = VisualizationPanel() self.step_stack.addTab(self.create_scroll_area(self.viz_panel), QIcon(self.get_icon_path("9.png")), "可视化") @@ -2137,15 +2143,15 @@ class 
WaterQualityGUI(QMainWindow): 'step3': 2, 'step4': 3, 'step5': 4, - 'step5_5': 5, - 'step6': 6, - 'step6_5': 7, - 'step6_75': 8, - 'step7': 9, - 'step8': 10, - 'step8_5': 11, - 'step8_75': 12, - 'step9': 13, + 'step8': 5, + 'step7': 6, + 'step8_non_empirical_modeling': 7, + 'step9': 8, + 'step10': 9, + 'step11_ml': 10, + 'step11': 11, + 'step12': 12, + 'step14': 13, 'step9_viz': 14, 'step_report': 15, } @@ -2168,15 +2174,15 @@ class WaterQualityGUI(QMainWindow): 2: 'step3', 3: 'step4', 4: 'step5', - 5: 'step5_5', - 6: 'step6', - 7: 'step6_5', - 8: 'step6_75', - 9: 'step7', - 10: 'step8', - 11: 'step8_5', - 12: 'step8_75', - 13: 'step9', + 5: 'step8', + 6: 'step7', + 7: 'step8_non_empirical_modeling', + 8: 'step9', + 9: 'step10', + 10: 'step11_ml', + 11: 'step11', + 12: 'step12', + 13: 'step14', 14: 'step9_viz', 15: 'step_report', } @@ -2213,41 +2219,41 @@ class WaterQualityGUI(QMainWindow): elif index == 4: self.step5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step5_5 切换时自动填充输出路径 + # Step8(水质指数)切换时自动填充输出路径 elif index == 5: - self.step5_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step6 切换时自动填充训练数据和输出路径 + # Step7(监督建模)切换时自动填充训练数据和输出路径 elif index == 6: - self.step6_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - - # Step6.5(非经验回归建模)切换时自动填充训练数据和模型目录 - elif index == 7: - self.step6_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - - # Step6.75(自定义回归建模)切换时自动填充训练数据和模型目录 - elif index == 8: - self.step6_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - - # Step7(采样点布设)切换时自动填充掩膜和输出路径 - elif index == 9: self.step7_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + # Step8非经验建模切换时自动填充训练数据和模型目录 + elif index == 7: + self.step8_non_empirical_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + + # 
Step9(自定义回归建模)切换时自动填充训练数据和模型目录 + elif index == 8: + self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + + # Step10(采样点布设)切换时自动填充掩膜和输出路径 + elif index == 9: + self.step10_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + # Step8(机器学习预测)切换时自动填充采样光谱和模型目录 elif index == 10: - self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step11_ml_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step8.5(非经验模型预测)切换时自动填充采样光谱和回归模型目录 + # Step11(回归预测)切换时自动填充采样光谱和回归模型目录 elif index == 11: - self.step8_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step11_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step8.75(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录 + # Step12(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录 elif index == 12: - self.step8_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step12_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step9(专题图生成)切换时自动填充预测结果目录 + # Step14(专题图生成)切换时自动填充预测结果目录 elif index == 13: - self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step14_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) # 可视化分析面板切换时自动推断图像目录并加载目录树 elif index == 14: @@ -2294,22 +2300,22 @@ class WaterQualityGUI(QMainWindow): self.step4_panel.set_config(config['step4']) if 'step5' in config: self.step5_panel.set_config(config['step5']) - if 'step5_5' in config: - self.step5_5_panel.set_config(config['step5_5']) - if 'step6' in config: - self.step6_panel.set_config(config['step6']) - if 'step6_5' in config: - self.step6_5_panel.set_config(config['step6_5']) - if 'step6_75' in config: - self.step6_75_panel.set_config(config['step6_75']) - if 'step7' in config: - self.step7_panel.set_config(config['step7']) if 'step8' in config: self.step8_panel.set_config(config['step8']) - if 'step8_5' in config: - 
self.step8_5_panel.set_config(config['step8_5']) + if 'step7' in config: + self.step7_panel.set_config(config['step7']) + if 'step8_non_empirical_modeling' in config: + self.step8_non_empirical_panel.set_config(config['step8_non_empirical_modeling']) if 'step9' in config: self.step9_panel.set_config(config['step9']) + if 'step10' in config: + self.step10_panel.set_config(config['step10']) + if 'step11_ml' in config: + self.step11_ml_panel.set_config(config['step11_ml']) + if 'step11' in config: + self.step11_panel.set_config(config['step11']) + if 'step14' in config: + self.step14_panel.set_config(config['step14']) if 'visualization' in config: self.viz_panel.set_config(config['visualization']) if 'report_generation' in config: @@ -2352,14 +2358,14 @@ class WaterQualityGUI(QMainWindow): 'step3': self.step3_panel.get_config(), 'step4': self.step4_panel.get_config(), 'step5': self.step5_panel.get_config(), - 'step5_5': self.step5_5_panel.get_config(), - 'step6': self.step6_panel.get_config(), - 'step6_5': self.step6_5_panel.get_config(), - 'step6_75': self.step6_75_panel.get_config(), - 'step7': self.step7_panel.get_config(), 'step8': self.step8_panel.get_config(), - 'step8_5': self.step8_5_panel.get_config(), + 'step7': self.step7_panel.get_config(), + 'step8_non_empirical_modeling': self.step8_non_empirical_panel.get_config(), 'step9': self.step9_panel.get_config(), + 'step10': self.step10_panel.get_config(), + 'step11_ml': self.step11_ml_panel.get_config(), + 'step11': self.step11_panel.get_config(), + 'step14': self.step14_panel.get_config(), 'visualization': self.viz_panel.get_config(), 'report_generation': self.report_panel.get_config(), } @@ -2410,15 +2416,15 @@ class WaterQualityGUI(QMainWindow): 'step3': self.step3_panel, 'step4': self.step4_panel, 'step5': self.step5_panel, - 'step5_5': self.step5_5_panel, - 'step6': self.step6_panel, - 'step6_5': self.step6_5_panel, - 'step6_75': self.step6_75_panel, - 'step7': self.step7_panel, 'step8': self.step8_panel, 
- 'step8_5': self.step8_5_panel, - 'step8_75': self.step8_75_panel, + 'step7': self.step7_panel, + 'step8_non_empirical_modeling': self.step8_non_empirical_panel, 'step9': self.step9_panel, + 'step10': self.step10_panel, + 'step11_ml': self.step11_ml_panel, + 'step11': self.step11_panel, + 'step12': self.step12_panel, + 'step14': self.step14_panel, } return panel_map.get(step_id) @@ -2426,15 +2432,38 @@ class WaterQualityGUI(QMainWindow): """查找指定步骤的输出文件""" if step_id not in self.step_default_outputs: return None - + step_outputs = self.step_default_outputs[step_id] - + + # ★ 掩膜类型列表:这些类型只接受科学数据格式 + mask_types = {'water_mask', 'glint_mask', 'boundary_mask'} + # ★ 白名单机制:只允许 .dat .tif .tiff .shp,拒绝其他一切格式 + scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'} + # ★ 临时文件关键词黑名单 + tmp_keywords = ('__tmp', '_tmp') + + def _is_scientific_mask(path_str): + """白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式""" + p = Path(path_str) + name_lower = str(path_str).lower() + # 拒绝临时文件 + if any(kw in name_lower for kw in tmp_keywords): + return False + # 白名单校验 + return p.suffix.lower() in scientific_extensions + # 特殊处理:从step_outputs记录中查找实际输出路径 if step_id in self.step_outputs: actual_outputs = self.step_outputs[step_id] if output_type in actual_outputs: - return actual_outputs[output_type] - + candidate = actual_outputs[output_type] + # ★ 掩膜类型白名单二次校验:不在白名单内的一律拒绝 + if output_type in mask_types and not _is_scientific_mask(candidate): + # 非科学格式被拒绝,不使用 step_outputs 中的值 + pass + else: + return candidate + # 根据输出类型查找对应的文件 if output_type == 'water_mask': # 水域掩膜:优先查找NDWI生成的,其次是shp生成的 @@ -2485,19 +2514,19 @@ class WaterQualityGUI(QMainWindow): # 扫描各个子目录 subdirs = { '1_water_mask': 'step1', - '2_glint': 'step2', + '2_glint': 'step2', '3_deglint': 'step3', '4_processed_data': 'step4', '5_training_spectra': 'step5', - '6_water_quality_indices': 'step5_5', - '7_Supervised_Model_Training': 'step6', - '8_Regression_Modeling': 'step6_5', - '9_Custom_Regression_Modeling': 'step6_75', - '10_sampling': 
'step7', - '11_12_13_predictions/Machine_Learning_Prediction': 'step8', - '11_12_13_predictions/Non_Empirical_Prediction': 'step8_5', - '11_12_13_predictions/Custom_Regression_Prediction': 'step8_75', - '14_visualization': 'step9' + '6_water_quality_indices': 'step8', + '7_Supervised_Model_Training': 'step7', + '8_Regression_Modeling': 'step8_non_empirical_modeling', + '9_Custom_Regression_Modeling': 'step9', + '10_sampling': 'step10', + '11_12_13_predictions/Machine_Learning_Prediction': 'step11_ml', + '11_12_13_predictions/Non_Empirical_Prediction': 'step11', + '11_12_13_predictions/Custom_Regression_Prediction': 'step12', + '14_visualization': 'step14' } for subdir, step_ids in subdirs.items(): @@ -2517,23 +2546,37 @@ class WaterQualityGUI(QMainWindow): for step_id in step_ids: if step_id not in discovered_outputs: discovered_outputs[step_id] = {} - + + # ★ 掩膜文件白名单过滤:只有 .dat .tif .tiff .shp 才通过,拒绝 .hdr .xml .png 等 + scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'} + tmp_keywords = ('__tmp', '_tmp') + + def _is_scientific_mask(path_str): + """白名单判断:拒绝 .hdr .xml 临时文件等,只接受科学数据格式""" + p = Path(path_str) + name_lower = str(path_str).lower() + if any(kw in name_lower for kw in tmp_keywords): + return False + return p.suffix.lower() in scientific_extensions + # 匹配不同的文件类型 if 'water_mask' in file_name and step_id == 'step1': - discovered_outputs[step_id]['water_mask'] = str(file_path) + if _is_scientific_mask(file_path): + discovered_outputs[step_id]['water_mask'] = str(file_path) elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2': - discovered_outputs[step_id]['glint_mask'] = str(file_path) + if _is_scientific_mask(file_path): + discovered_outputs[step_id]['glint_mask'] = str(file_path) elif 'deglint' in file_name and step_id == 'step3': discovered_outputs[step_id]['deglint_image'] = str(file_path) elif 'processed_data' in file_name and step_id == 'step4': discovered_outputs[step_id]['processed_data'] = str(file_path) elif 
'training_spectra' in file_name and step_id == 'step5': discovered_outputs[step_id]['training_spectra'] = str(file_path) - elif 'water_quality_indices' in file_name and step_id == 'step5_5': + elif 'water_quality_indices' in file_name and step_id == 'step8': discovered_outputs[step_id]['water_indices'] = str(file_path) - elif 'sampling_spectra' in file_name and step_id == 'step7': + elif 'sampling_spectra' in file_name and step_id == 'step10': discovered_outputs[step_id]['sampling_points'] = str(file_path) - elif file_name.endswith('.csv') and step_id in ['step8', 'step8_5', 'step8_75']: + elif file_name.endswith('.csv') and step_id in ['step11_ml', 'step11', 'step12']: discovered_outputs[step_id]['predictions'] = str(file_path) # 更新内部记录 @@ -2556,8 +2599,8 @@ class WaterQualityGUI(QMainWindow): # 首先扫描工作目录发现已有的输出文件 self.scan_work_directory_for_files(work_path) - step_order = ['step2', 'step3', 'step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75', - 'step7', 'step8', 'step8_5', 'step8_75', 'step9'] + step_order = ['step2', 'step3', 'step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9', + 'step10', 'step11_ml', 'step11', 'step12', 'step14'] filled_count = 0 for step_id in step_order: @@ -2579,15 +2622,15 @@ class WaterQualityGUI(QMainWindow): ('step2', self.step2_panel), ('step3', self.step3_panel), ('step5', self.step5_panel), - ('step5_5', self.step5_5_panel), - ('step6', self.step6_panel), - ('step6_5', self.step6_5_panel), - ('step6_75', self.step6_75_panel), - ('step7', self.step7_panel), ('step8', self.step8_panel), - ('step8_5', self.step8_5_panel), - ('step8_75', self.step8_75_panel), - ('step9', self.step9_panel) + ('step7', self.step7_panel), + ('step8_non_empirical_modeling', self.step8_non_empirical_panel), + ('step9', self.step9_panel), + ('step10', self.step10_panel), + ('step11_ml', self.step11_ml_panel), + ('step11', self.step11_panel), + ('step12', self.step12_panel), + ('step14', self.step14_panel) ] for step_id, panel in 
panels_with_dependencies: @@ -2735,7 +2778,7 @@ class WaterQualityGUI(QMainWindow): """显示关于对话框""" QMessageBox.about( self, "关于", - "MegaCube-Water Quality ‌V1.1\n\n" + "MegaCube-Water Quality ‌V1.2\n\n" "一个完整的水质参数反演工作流程工具\n\n" "功能包括:\n" "- 水域掩膜生成\n" @@ -2858,6 +2901,41 @@ class WaterQualityGUI(QMainWindow): return True + # ------------------------------------------------------------------ + # ★ 全流程模式动态裁剪 + # ------------------------------------------------------------------ + + def _prune_config_for_prediction_mode(self, config: dict) -> dict: + """Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。 + + 被禁用的 step dict 中统一写入 'enabled': False, + 这些配置最终传给 PipelineRunner,Runner 会跳过它们。 + 同时,被跳过的步骤的 required_input_files 在 build_missing_items + 中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。 + + Args: + config: 完整配置字典(来自 get_current_config) + + Returns: + 裁剪后的 config(深拷贝,原 config 不被修改) + """ + cfg = copy.deepcopy(config) + + # 在每个训练相关步骤的 dict 中写入 enabled=False + training_steps = [ + "step4", # CSV 实测数据清洗 + "step5", # 实测点光谱提取(→ training_csv_path) + "step7", # ML 监督建模 + "step8", # 水质指数计算(辅助训练) + "step8_non_empirical_modeling", # 非经验回归建模 + "step9", # 自定义回归建模 + ] + for step_id in training_steps: + step_cfg = cfg.setdefault(step_id, {}) + step_cfg["enabled"] = False + + return cfg + def run_full_pipeline(self): """运行完整流程""" if not PIPELINE_AVAILABLE: @@ -2867,8 +2945,14 @@ class WaterQualityGUI(QMainWindow): ) return + # ── 0) 强制获取 work_dir(禁止依赖外部或全局变量) ── + work_dir = getattr(self, 'work_dir', None) + if not work_dir: + QMessageBox.warning(self, "警告", "未选择工作目录,请先设置工作目录。") + return + # ── 1) 运行前智能预检与自动回填(硬盘已有产物自动跳过) ── - work_path = Path(getattr(self, 'work_dir', './work_dir')) + work_path = Path(work_dir) self.log_message("正在进行运行前环境预检与自动扫描...", "info") self.scan_work_directory_for_files(work_path) self.auto_populate_all_steps() @@ -2878,31 +2962,52 @@ class WaterQualityGUI(QMainWindow): if not self._precheck_step3_bands(): return # 用户点"取消运行" + # ── 1.6) ★ 全流程模式选择弹窗 ── + mode_dlg = 
PipelineModeDialog(main_window=self, parent=self) + if mode_dlg.exec() != QDialog.Accepted: + return # 用户点"取消" + selected_mode = mode_dlg.selected_mode + self.log_message(f"[模式选择] 选定模式: {'训练新模型' if selected_mode == 'training' else '使用已有模型直接预测'}", "info") + # ── 2) 刷新配置(拿到自动填充后的"满血版" config) ── config = self.get_current_config() - # ── 3) 根基数据校验:step1.img_path(参考影像) ── - if not config['step1'].get('img_path'): - QMessageBox.warning(self, "警告", "缺失核心数据:请先在步骤 1 中上传【参考影像】!") - for i in range(self.step_list.count()): - item = self.step_list.item(i) - if item.data(Qt.UserRole) == 'step1': - self.step_list.setCurrentRow(i) - break - return + # ── 2.1) ★ 根据模式动态裁剪配置 ── + if selected_mode == "prediction_only": + config = self._prune_config_for_prediction_mode(config) + self.log_message("[模式选择] 已裁剪训练相关步骤(step4/5/7/8),进入仅预测模式", "info") - # ── 4) 软提示:csv_path 缺失 → 模型训练步骤会被静默跳过(不阻断) ── - csv_path = config.get('step4', {}).get('csv_path') or config.get('step5', {}).get('csv_path') - if not csv_path: - QMessageBox.information( - self, - "提示:模型训练将被跳过", - "未检测到实测水质数据 (CSV)。\n" - "流程将自动跳过模型训练(步骤 4-6),仅执行预测与制图。\n" - "如果需要训练新模型,请先在步骤 4 中上传水质数据。", - ) + # ── 3) ★ 一次性全预检 + 用户交互式决策 ── + missing_items = PreflightDialog.build_missing_items(config) + if missing_items: + critical_items = [it for it in missing_items if it.is_critical] + if critical_items: + lines = "\n".join(f" - [{it.step_name}] {it.reason}" for it in critical_items) + QMessageBox.critical( + self, "预检失败(阻断性错误)", + f"以下为阻断性缺失,流程无法启动:\n\n{lines}\n\n" + "请填写后重新运行。" + ) + return + dialog = PreflightDialog(missing_items, parent=self) + if dialog.exec() != QDialog.Accepted: + return + result = dialog.get_result() + if result is None: + return + action, *payload = result + if action == "fill": + _, step_id, tab_index = result + self.step_stack.setCurrentIndex(tab_index) + self.log_message(f"[预检] 用户选择填写 {step_id},已切换到对应面板。", "info") + return + skip_list: List[str] = payload[0] if payload else [] + if skip_list: + 
self.log_message(f"[预检] 用户强制跳过 {len(skip_list)} 个步骤: {skip_list}", "info") + else: + skip_list = [] - # 确认执行 + # ── 4) 确认执行 reply = QMessageBox.question( self, "确认", "是否开始执行完整流程?\n\n这可能需要较长时间,请确保配置正确。", @@ -2913,19 +3018,18 @@ class WaterQualityGUI(QMainWindow): return # 创建pipeline实例 - work_dir = getattr(self, 'work_dir', './work_dir') self.log_message(f"初始化pipeline,工作目录: {work_dir}", "info") # 准备实际运行配置(排除未启用的步骤) worker_config = copy.deepcopy(config) - step5_5_cfg = worker_config.get('step5_5') - if step5_5_cfg: - enabled = step5_5_cfg.pop('enabled', True) + step8_cfg = worker_config.get('step8') + if step8_cfg: + enabled = step8_cfg.pop('enabled', True) if not enabled: - worker_config.pop('step5_5', None) + worker_config.pop('step8', None) # 工作线程内创建 Pipeline,避免主线程阻塞及 Qt5Agg 子线程绘图卡死 - self.worker = WorkerThread(work_dir, worker_config, mode='full') + self.worker = WorkerThread(work_dir, worker_config, mode='full', skip_list=skip_list) self.worker.log_message.connect(self.log_message, Qt.QueuedConnection) self.worker.progress_update.connect(self.update_progress, Qt.QueuedConnection) self.worker.step_completed.connect(self.on_step_completed, Qt.QueuedConnection) @@ -3152,14 +3256,14 @@ class WaterQualityGUI(QMainWindow): def update_ui_for_training_mode(self): """根据训练数据模式更新UI状态""" # 需要禁用的步骤ID(对应无训练数据模式下需要禁用的步骤) - disabled_step_ids = ['step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75'] + disabled_step_ids = ['step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9'] # 更新标签页的启用/禁用状态 step_id_to_tab = { 'step1': 0, 'step2': 1, 'step3': 2, 'step4': 3, - 'step5': 4, 'step5_5': 5, 'step6': 6, 'step6_5': 7, - 'step6_75': 8, 'step7': 9, 'step8': 10, 'step8_5': 11, - 'step8_75': 12, 'step9': 13, 'step9_viz': 14 + 'step5': 4, 'step8': 5, 'step7': 6, 'step8_non_empirical_modeling': 7, + 'step9': 8, 'step10': 9, 'step11_ml': 10, 'step11': 11, + 'step12': 12, 'step14': 13, 'step9_viz': 14 } for step_id in disabled_step_ids: