feat(gui): 全流程面板合并 + 一键式运行 GUI 入口集成

This commit is contained in:
DXC
2026-06-09 11:30:42 +08:00
parent aefc9d5aac
commit 28394f2eda
20 changed files with 2843 additions and 2432 deletions

View File

@ -8,7 +8,17 @@ Pipeline 调度核心:基于 Context 的内存级依赖注入。
- 不绑定具体 Pipeline 实现(duck-typed),WorkerThread / Web API / 单测可共用
"""
from .context import PipelineContext
from .runner import StepSpec, PIPELINE_STEPS, PipelineRunner
from .context import (
PipelineContext,
STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD,
resolve_step_id, ALL_STEP_IDS,
)
from .runner import (
StepSpec, PIPELINE_STEPS, PipelineRunner, PipelineHalt,
)
__all__ = ["PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner"]
__all__ = [
"PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner", "PipelineHalt",
"STEP_MAP_OLD_TO_NEW", "STEP_MAP_NEW_TO_OLD",
"resolve_step_id", "ALL_STEP_IDS",
]

View File

@ -12,7 +12,34 @@ PipelineContext内存级数据载体跨 14 个 step 传递路径与元信
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
# ============================================================
# 步骤命名映射(定义在叶子节点,打破循环依赖)
# ============================================================
# Legacy step ID -> current step ID.
# NOTE(review): "step8" and "step9" occur both as legacy keys and as current
# values, so a lookup cannot tell a legacy "step8" from a current "step8";
# resolving an already-current "step8" yields "step11_ml". Callers should
# translate an ID at most once — confirm all call sites respect this.
# NOTE(review): there is no entry for a legacy "step6"; verify whether the
# step6 -> step7 rename needs to be represented here.
STEP_MAP_OLD_TO_NEW: Dict[str, str] = {
    "step5_5": "step8",
    "step6_5": "step8_non_empirical_modeling",
    "step6_75": "step9",
    "step8_5": "step11",
    "step8_75": "step12",
    "step7": "step10",
    "step8": "step11_ml",
    "step9": "step14",
}
# Reverse lookup (current ID -> legacy ID), derived from the forward table.
# Every value in the forward table is distinct, so no entry is lost.
STEP_MAP_NEW_TO_OLD: Dict[str, str] = {v: k for k, v in STEP_MAP_OLD_TO_NEW.items()}
# Union of every identifier that appears in the table under either scheme.
ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW.keys()) | set(STEP_MAP_OLD_TO_NEW.values())
def resolve_step_id(step_id: str) -> str:
    """Translate a step identifier through the legacy-to-current table.

    Args:
        step_id: A step identifier under either naming scheme.

    Returns:
        The mapped identifier when ``step_id`` is a key of
        ``STEP_MAP_OLD_TO_NEW``; otherwise ``step_id`` itself, unchanged.
    """
    # Unknown identifiers pass straight through via the default value.
    return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)
@dataclass
@ -63,10 +90,29 @@ class PipelineContext:
pipeline_end_time: Optional[float] = None
last_error: Optional[str] = None
# ── 错误汇总(全流程结束后可用) ──
error_summary: List[tuple[str, str]] = field(default_factory=list)
# ── 出错时立即停止全流程(默认 False继续后续步骤 ──
breakpoint_on_error: bool = False
# ── ★ 智能补全锁定步骤列表(由 _auto_fill_missing_steps 自动开启的步骤) ──
# GUI 层读取此字段,在运行期间禁用对应面板的启用复选框
locked_steps: List[str] = field(default_factory=list)
# ============================================================
# 读写辅助
# ============================================================
def step_id(self, step_id: str) -> str:
    """Translate a possibly-legacy step identifier to the current scheme.

    Convenience wrapper so code holding a context can normalise IDs inline.

    Examples:
        ctx.status[ctx.step_id('step6_5')]       # 'step8_non_empirical_modeling'
        ctx.user_config[ctx.step_id('step8_5')]  # 'step11'

    Args:
        step_id: A step identifier under either naming scheme.

    Returns:
        The identifier produced by the module-level ``resolve_step_id``.
    """
    # Delegate instead of repeating the table lookup, so the method and the
    # module-level resolver can never drift apart.
    return resolve_step_id(step_id)
def set(self, key: str, value: Any) -> None:
"""原地写入任意属性。

View File

@ -660,7 +660,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"训练光谱数据已保存: {result}")
return result
def step5_5_calculate_water_quality_indices(self,
def step8_water_quality_indices(self,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
@ -704,7 +704,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"水质指数已保存: {result}")
return result
def step6_train_models(self, feature_start_column: str = "374.285004",
def step7_ml_modeling(self, feature_start_column: str = "374.285004",
preprocessing_methods: List[str] = None,
model_names: List[str] = None,
split_methods: List[str] = None,
@ -747,7 +747,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"模型训练完成,结果保存在: {result}")
return result
def step7_generate_sampling_points(self, deglint_img_path: Optional[str] = None,
def step10_sampling(self, deglint_img_path: Optional[str] = None,
interval: int = 50,
sample_radius: int = 5,
chunk_size: int = 1000,
@ -756,7 +756,7 @@ class WaterQualityInversionPipeline:
use_adaptive_sampling: bool = True,
skip_dependency_check: bool = False, **kwargs) -> str:
"""
步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
步骤10: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
Args:
deglint_img_path: 去除耀斑后的影像文件路径如果为None使用步骤3的结果
@ -779,7 +779,7 @@ class WaterQualityInversionPipeline:
if water_mask_path is None and self.water_mask_path is not None:
water_mask_path = self.water_mask_path
self._notify("started", "步骤7: 生成预测采样点")
self._notify("started", "步骤10: 生成预测采样点")
result = PredictionStep.generate_sampling_points(
deglint_img_path=img_path,
interval=interval,
@ -790,21 +790,21 @@ class WaterQualityInversionPipeline:
output_dir=str(self.sampling_dir),
use_adaptive_sampling=use_adaptive_sampling,
)
self._record_step_time("步骤7: 生成预测采样点", 0, 0)
self._record_step_time("步骤10: 生成预测采样点", 0, 0)
self._notify("completed", f"采样点光谱数据已保存: {result}")
return result
def step8_predict_water_quality(self, sampling_csv_path: str,
def step11_ml_prediction(self, sampling_csv_path: str,
models_dir: Optional[str] = None,
metric: str = 'test_r2',
prediction_column: str = 'prediction',
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
步骤8: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数
步骤11: 将训练好的最佳机器学习模型应用到采样点的平均光谱上,预测水质参数
Args:
sampling_csv_path: 采样点光谱数据CSV路径
models_dir: 模型保存目录如果为None使用步骤6的结果)
models_dir: 模型保存目录如果为None使用步骤7的结果)
metric: 选择最佳模型的指标
prediction_column: 预测结果列名
@ -818,7 +818,7 @@ class WaterQualityInversionPipeline:
print(f"[Pipeline] 收到字典: {'Yes' if _external_models_dict else 'No'}"
f", 收到单模型: {'Yes' if _external_model else 'No'}")
self._notify("started", "步骤8: 预测水质参数")
self._notify("started", "步骤11: 预测水质参数")
result = PredictionStep.predict_water_quality(
sampling_csv_path=sampling_csv_path,
models_dir=models_dir if models_dir else str(self.models_dir),
@ -831,11 +831,11 @@ class WaterQualityInversionPipeline:
_external_models_dict=_external_models_dict,
_external_model_dir=_external_model_dir,
)
self._record_step_time("步骤8: 预测水质参数", 0, 0)
self._record_step_time("步骤11: 预测水质参数", 0, 0)
self._notify("completed", f"预测完成,结果保存在: {self.prediction_dir}")
return result
def step9_generate_distribution_map(self, prediction_csv_path: str,
def step14_distribution_map(self, prediction_csv_path: str,
boundary_shp_path: str,
output_image_path: Optional[str] = None,
resolution: float = 30,
@ -1524,99 +1524,99 @@ class WaterQualityInversionPipeline:
else:
self._notify("步骤5: 光谱提取", "skipped", "未配置")
# 步骤5.5: 计算水质指数
if 'step5_5' in config:
self._notify("步骤5.5: 水质指数计算", "start")
self.step5_5_calculate_water_quality_indices(**config['step5_5'])
self._notify("步骤5.5: 水质指数计算", "completed", f"(输出: {self.indices_path})")
# 步骤8: 计算水质指数
if 'step8' in config:
self._notify("步骤8: 水质指数计算", "start")
self.step8_water_quality_indices(**config['step8'])
self._notify("步骤8: 水质指数计算", "completed", f"(输出: {self.indices_path})")
else:
self._notify("步骤5.5: 水质指数计算", "skipped", "未配置")
self._notify("步骤8: 水质指数计算", "skipped", "未配置")
# 步骤6: 训练模型
if 'step6' in config:
self._notify("步骤6: 模型训练", "start")
self.step6_train_models(**config['step6'])
self._notify("步骤6: 模型训练", "completed", f"(输出: {self.models_dir})")
else:
self._notify("步骤6: 模型训练", "skipped", "未配置")
# 步骤6.5: 非经验统计回归模型训练
if 'step6_5' in config:
self._notify("步骤6.5: 非经验模型训练", "start")
self.step6_5_non_empirical_modeling(**config['step6_5'])
self._notify("步骤6.5: 非经验模型训练", "completed", f"(输出: {self.models_dir})")
else:
self._notify("步骤6.5: 非经验模型训练", "skipped", "未配置")
# 步骤6.75: 自定义回归分析
if 'step6_75' in config:
self._notify("步骤6.75: 自定义回归", "start")
self.step6_75_custom_regression(**config['step6_75'])
self._notify("步骤6.75: 自定义回归", "completed", f"(输出: {self.custom_regression_path})")
else:
self._notify("步骤6.75: 自定义回归", "skipped", "未配置")
# 步骤7: 生成预测采样点
# 步骤7: 训练模型
if 'step7' in config:
self._notify("步骤7: 采样点生成", "start")
sampling_csv_path = self.step7_generate_sampling_points(**config['step7'])
self._notify("步骤7: 采样点生成", "completed", f"(输出: {sampling_csv_path})")
self._notify("步骤7: 模型训练", "start")
self.step7_ml_modeling(**config['step7'])
self._notify("步骤7: 模型训练", "completed", f"(输出: {self.models_dir})")
else:
self._notify("步骤7: 模型训练", "skipped", "未配置")
# 步骤8_non_empirical_modeling: 非经验统计回归模型训练
if 'step8_non_empirical_modeling' in config:
self._notify("步骤8: 非经验模型训练", "start")
self.step8_non_empirical_modeling(**config['step8_non_empirical_modeling'])
self._notify("步骤8: 非经验模型训练", "completed", f"(输出: {self.models_dir})")
else:
self._notify("步骤8: 非经验模型训练", "skipped", "未配置")
# 步骤9: 自定义回归分析
if 'step9' in config:
self._notify("步骤9: 自定义回归", "start")
self.step9_custom_regression(**config['step9'])
self._notify("步骤9: 自定义回归", "completed", f"(输出: {self.custom_regression_path})")
else:
self._notify("步骤9: 自定义回归", "skipped", "未配置")
# 步骤10: 生成预测采样点
if 'step10' in config:
self._notify("步骤10: 采样点生成", "start")
sampling_csv_path = self.step10_sampling(**config['step10'])
self._notify("步骤10: 采样点生成", "completed", f"(输出: {sampling_csv_path})")
else:
sampling_csv_path = None
self._notify("步骤7: 采样点生成", "skipped", "未配置")
self._notify("步骤10: 采样点生成", "skipped", "未配置")
# 步骤8: 预测水质参数
if 'step8' in config and sampling_csv_path:
self._notify("步骤8: 参数预测", "start")
step8_config = config['step8'].copy()
step8_config['sampling_csv_path'] = sampling_csv_path
prediction_files = self.step8_predict_water_quality(**step8_config)
self._notify("步骤8: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)")
# 步骤11_ml: 预测水质参数
if 'step11_ml' in config and sampling_csv_path:
self._notify("步骤11: 参数预测", "start")
step11_ml_config = config['step11_ml'].copy()
step11_ml_config['sampling_csv_path'] = sampling_csv_path
prediction_files = self.step11_ml_prediction(**step11_ml_config)
self._notify("步骤11: 参数预测", "completed", f"(生成{len(prediction_files)}个预测文件)")
else:
prediction_files = {}
self._notify("步骤8: 参数预测", "skipped", "未配置或缺少采样点")
self._notify("步骤11: 参数预测", "skipped", "未配置或缺少采样点")
# 步骤8.5: 使用非经验模型进行参数预测
# 步骤11: 使用非经验模型进行参数预测
non_empirical_prediction_files = {}
if 'step8_5' in config and sampling_csv_path:
self._notify("步骤8.5: 非经验模型预测", "start")
step8_5_config = config['step8_5'].copy()
step8_5_config['sampling_csv_path'] = sampling_csv_path
non_empirical_prediction_files = self.step8_5_predict_with_non_empirical_models(**step8_5_config)
self._notify("步骤8.5: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)")
if 'step11' in config and sampling_csv_path:
self._notify("步骤11: 非经验模型预测", "start")
step11_config = config['step11'].copy()
step11_config['sampling_csv_path'] = sampling_csv_path
non_empirical_prediction_files = self.step11_non_empirical_prediction(**step11_config)
self._notify("步骤11: 非经验模型预测", "completed", f"(生成{len(non_empirical_prediction_files)}个预测文件)")
else:
self._notify("步骤8.5: 非经验模型预测", "skipped", "未配置或缺少采样点")
self._notify("步骤11: 非经验模型预测", "skipped", "未配置或缺少采样点")
# 步骤8.75: 使用自定义回归模型进行参数预测
# 步骤12: 使用自定义回归模型进行参数预测
custom_regression_prediction_files = {}
if 'step8_75' in config and sampling_csv_path:
self._notify("步骤8.75: 自定义回归预测", "start")
step8_75_config = config['step8_75'].copy()
step8_75_config['sampling_csv_path'] = sampling_csv_path
custom_regression_prediction_files = self.step8_75_predict_with_custom_regression(**step8_75_config)
self._notify("步骤8.75: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)")
if 'step12' in config and sampling_csv_path:
self._notify("步骤12: 自定义回归预测", "start")
step12_config = config['step12'].copy()
step12_config['sampling_csv_path'] = sampling_csv_path
custom_regression_prediction_files = self.step12_custom_regression_prediction(**step12_config)
self._notify("步骤12: 自定义回归预测", "completed", f"(生成{len(custom_regression_prediction_files)}个预测文件)")
else:
self._notify("步骤8.75: 自定义回归预测", "skipped", "未配置或缺少采样点")
self._notify("步骤12: 自定义回归预测", "skipped", "未配置或缺少采样点")
# 合并机器学习预测、非经验模型预测和自定义回归预测结果
all_prediction_files = {**prediction_files, **non_empirical_prediction_files, **custom_regression_prediction_files}
# 步骤9: 生成分布图
# 步骤14: 生成分布图
distribution_maps = {}
if 'step9' in config and all_prediction_files:
self._notify("步骤9: 分布图生成", "start")
if 'step14' in config and all_prediction_files:
self._notify("步骤14: 分布图生成", "start")
for target_name, pred_file in all_prediction_files.items():
step9_config = config['step9'].copy()
step14_config = config['step14'].copy()
for _k in ('step9_batch_mode', 'prediction_csv_dir', 'recursive_csv_scan'):
step9_config.pop(_k, None)
step9_config['prediction_csv_path'] = pred_file
if 'output_image_path' not in step9_config:
step9_config['output_image_path'] = None
dist_map_path = self.step9_generate_distribution_map(**step9_config)
step14_config.pop(_k, None)
step14_config['prediction_csv_path'] = pred_file
if 'output_image_path' not in step14_config:
step14_config['output_image_path'] = None
dist_map_path = self.step14_distribution_map(**step14_config)
distribution_maps[target_name] = dist_map_path
self._notify("步骤9: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)")
self._notify("步骤14: 分布图生成", "completed", f"(生成{len(distribution_maps)}个分布图)")
else:
self._notify("步骤9: 分布图生成", "skipped", "未配置或缺少预测结果")
self._notify("步骤14: 分布图生成", "skipped", "未配置或缺少预测结果")
# 生成可视化图表
output_files = {}
@ -1716,10 +1716,10 @@ class WaterQualityInversionPipeline:
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
pipeline_info['training_params'] = config.get('step6', {})
pipeline_info['step8'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
pipeline_info['step7'] = {'status': 'completed', 'output_file': str(self.models_dir)}
pipeline_info['step9'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
pipeline_info['training_params'] = config.get('step7', {})
summary_path = self.report_generator.generate_batch_inference_summary(pipeline_info)
print(f"批量处理摘要已生成: {summary_path}")
@ -1769,7 +1769,7 @@ class WaterQualityInversionPipeline:
traceback.print_exc()
raise
def step6_5_non_empirical_modeling(self, csv_path: Optional[str] = None,
def step8_non_empirical_modeling(self, csv_path: Optional[str] = None,
preprocessing_methods: List[str] = None,
algorithms: List[str] = None,
value_cols: Union[int, Dict[str, int]] = 0,
@ -1819,7 +1819,7 @@ class WaterQualityInversionPipeline:
self._notify("completed", f"非经验模型训练完成")
return result
def step6_75_custom_regression(self,
def step9_custom_regression(self,
csv_path: Optional[str] = None,
x_columns: Optional[Union[str, List[str]]] = None,
y_columns: Optional[Union[str, List[str]]] = None,
@ -1999,7 +1999,7 @@ class WaterQualityInversionPipeline:
return summary_path
def step8_5_predict_with_non_empirical_models(self, sampling_csv_path: str,
def step11_non_empirical_prediction(self, sampling_csv_path: str,
non_empirical_models_dir: Optional[str] = None,
output_path: Optional[str] = None,
metric: str = 'Average Accuracy(%)',
@ -2007,13 +2007,13 @@ class WaterQualityInversionPipeline:
enabled: bool = True,
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
步骤8.5: 使用非经验统计回归模型进行参数预测
步骤11: 使用非经验统计回归模型进行参数预测
根据非经验模型训练结果汇总CSV筛选给定方法的准确率最高的模型使用该模型进行预测
Args:
sampling_csv_path: 采样点光谱数据CSV路径
non_empirical_models_dir: 非经验模型保存目录如果为None使用步骤6.5的结果)
non_empirical_models_dir: 非经验模型保存目录如果为None使用步骤8的结果)
output_path: 输出目录路径如果为None使用默认目录
metric: 选择最佳模型的指标(默认使用平均准确率)
prediction_column: 预测结果列名
@ -2021,7 +2021,7 @@ class WaterQualityInversionPipeline:
Returns:
预测结果文件路径字典(键为算法名)
"""
self._notify("started", "步骤8.5: 使用非经验模型进行参数预测")
self._notify("started", "步骤11: 使用非经验模型进行参数预测")
result = PredictionStep.predict_with_non_empirical_models(
sampling_csv_path=sampling_csv_path,
non_empirical_models_dir=non_empirical_models_dir,
@ -2031,11 +2031,11 @@ class WaterQualityInversionPipeline:
enabled=enabled,
work_dir=str(self.work_dir),
)
self._record_step_time("步骤8.5: 非经验模型预测", 0, 0)
self._record_step_time("步骤11: 非经验模型预测", 0, 0)
self._notify("completed", f"非经验模型预测完成,结果保存在: {self.prediction_dir}")
return result
def step8_75_predict_with_custom_regression(self, sampling_csv_path: str,
def step12_custom_regression_prediction(self, sampling_csv_path: str,
custom_regression_dir: Optional[str] = None,
formula_csv_path: Optional[str] = None,
coordinate_columns: Optional[List[str]] = None,
@ -2044,13 +2044,13 @@ class WaterQualityInversionPipeline:
enabled: bool = True,
skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:
"""
步骤8.75: 使用自定义回归模型进行参数预测
步骤12: 使用自定义回归模型进行参数预测
使用新的CustomRegressionPredictor模块基于9_Custom_Regression_Modeling文件夹中的CSV
根据r_squared选择最佳模型批量预测水质参数
Args:
sampling_csv_path: 采样点光谱数据CSV路径来自步骤7
sampling_csv_path: 采样点光谱数据CSV路径来自步骤10
custom_regression_dir: 自定义回归模型目录9_Custom_Regression_Modeling
formula_csv_path: 公式CSV文件路径用于查找index_formula
coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']或自动识别
@ -2058,11 +2058,11 @@ class WaterQualityInversionPipeline:
filename_prefix: 输出文件名前缀
enabled: 是否启用
skip_dependency_check: 是否跳过依赖检查
Returns:
预测结果文件路径字典(键为参数名)
"""
self._notify("started", "步骤8.75: 使用自定义回归模型进行参数预测")
self._notify("started", "步骤12: 使用自定义回归模型进行参数预测")
result = PredictionStep.predict_with_custom_regression(
sampling_csv_path=sampling_csv_path,
custom_regression_dir=custom_regression_dir,
@ -2073,7 +2073,7 @@ class WaterQualityInversionPipeline:
enabled=enabled,
work_dir=str(self.work_dir),
)
self._record_step_time("步骤8.75: 自定义回归模型预测", 0, 0)
self._record_step_time("步骤12: 自定义回归模型预测", 0, 0)
self._notify("completed", f"自定义回归预测完成")
return result
@ -2161,20 +2161,20 @@ def main():
# 单步运行时建议显式指定完整流程中可省略将使用步骤2输出的耀斑掩膜
# 'glint_mask_path': r"path/to/severe_glint_area.dat",
},
'step5_5': {
'step8': {
'formula_csv_file': 'path/to/water_quality_formulas.csv', # 公式CSV文件路径
'formula_names': ['Al10SABI', 'TurbBe16RedOverViolet'], # 要计算的公式名称列表
'output_filename': 'water_quality_indices.csv',
'enabled': True # 是否启用水质指数计算
},
'step6': {
'step7': {
'feature_start_column': '374.285004',
'preprocessing_methods': ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'],
'model_names': ['SVR', 'RF', 'Ridge', 'Lasso'],
'split_methods': ['spxy', 'ks', 'random'],
'cv_folds': 3
},
'step6_5': {
'step8_non_empirical_modeling': {
'preprocessing_methods': ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT'],
'algorithms': ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss'],
'value_cols': 0, # 可以是单个整数或字典,如 {'chl_a': 0, 'nh3': 1, 'mno4': 2, 'tn': 3, 'tp': 4, 'tss': 5}
@ -2182,14 +2182,14 @@ def main():
'window': 5,
'enabled': True # 是否启用非经验模型训练
},
'step6_75': {
'step9': {
'x_columns': ['NDWI', 'NDVI'], # 自变量列名列表
'y_columns': ['chl_a', 'tn', 'tp'], # 因变量列名列表
'methods': 'all', # 回归方法
'output_dir': 'custom_regression_results', # 输出目录
'enabled': True # 是否启用自定义回归分析
},
'step7': {
'step10': {
'interval': 50,
'sample_radius': 5,
'chunk_size': 1000,
@ -2197,16 +2197,16 @@ def main():
# 可选耀斑掩膜文件dat若不提供将使用步骤2结果需要外部指定时取消注释
# 'glint_mask_path': r"D:\path\to\severe_glint_area.dat",
},
'step8': {
'step11_ml': {
'metric': 'test_r2',
'prediction_column': 'prediction'
},
'step8_5': {
'step11': {
'metric': 'Average Accuracy(%)', # 选择最佳模型的指标
'prediction_column': 'prediction',
'enabled': True # 是否启用非经验模型预测
},
'step8_75': {
'step12': {
'custom_regression_dir': None, # 自定义回归模型目录None表示使用9_Custom_Regression_Modeling
'formula_csv_path': None, # 公式CSV文件路径用于查找index_formula如water_quality_formulas.csv
'coordinate_columns': None, # 坐标列名None表示自动识别
@ -2214,7 +2214,7 @@ def main():
'filename_prefix': 'custom_regression_prediction', # 输出文件名前缀
'enabled': True # 是否启用自定义回归预测
},
'step9': {
'step14': {
'boundary_shp_path': r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp" ,
'resolution': 30,
'input_crs': 'EPSG:32651',
@ -2345,41 +2345,41 @@ def example_independent_steps():
except Exception as e:
print(f"步骤6失败: {e}")
# 示例6: 独立运行步骤7 - 采样点生成
print("\n示例6: 独立运行步骤7 - 采样点生成")
# 示例6: 独立运行步骤10 - 采样点生成
print("\n示例6: 独立运行步骤10 - 采样点生成")
try:
sampling_csv = pipeline.step7_generate_sampling_points(
sampling_csv = pipeline.step10_sampling(
deglint_img_path="path/to/deglint_image.bsq",
water_mask_path="path/to/water_mask.dat",
skip_dependency_check=True
)
print(f"采样点数据: {sampling_csv}")
except Exception as e:
print(f"步骤7失败: {e}")
print(f"步骤10失败: {e}")
# 示例7: 独立运行步骤8 - 水质预测
print("\n示例7: 独立运行步骤8 - 水质预测")
# 示例7: 独立运行步骤11 - 水质预测
print("\n示例7: 独立运行步骤11 - 水质预测")
try:
predictions = pipeline.step8_predict_water_quality(
predictions = pipeline.step11_ml_prediction(
sampling_csv_path="path/to/sampling_spectra.csv",
models_dir="path/to/models_directory",
skip_dependency_check=True
)
print(f"预测结果: {predictions}")
except Exception as e:
print(f"步骤8失败: {e}")
print(f"步骤11失败: {e}")
# 示例8: 独立运行步骤9 - 分布图生成
print("\n示例8: 独立运行步骤9 - 分布图生成")
# 示例8: 独立运行步骤14 - 分布图生成
print("\n示例8: 独立运行步骤14 - 分布图生成")
try:
distribution_map = pipeline.step9_generate_distribution_map(
distribution_map = pipeline.step14_distribution_map(
prediction_csv_path="path/to/prediction_results.csv",
boundary_shp_path="path/to/boundary.shp",
skip_dependency_check=True
)
print(f"分布图: {distribution_map}")
except Exception as e:
print(f"步骤9失败: {e}")
print(f"步骤14失败: {e}")
print("\n" + "="*80)
print("独立步骤运行示例完成")

View File

@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
"""
PipelineModeDialog全流程运行前的模式选择弹窗。
用户点击"运行完整流程"后,首先弹出此弹窗选择执行模式:
- 选项 A训练新模型并预测执行完整建模与预测流程需要实测水质 CSV
- 选项 B使用已有模型直接预测跳过训练步骤直接使用外部模型目录进行预测
弹窗结果:
- QDialog.Accepted + self.selected_mode = "training""prediction_only"
- QDialog.Rejected → 调用方中止 run_full_pipeline
"""
import os
from typing import Optional
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import (
QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
QRadioButton, QGroupBox, QButtonGroup, QMessageBox, QSizePolicy,
)
def _is_valid_model_dir(path: str) -> bool:
    """Report whether ``path`` is a directory holding a file at any depth.

    Args:
        path: Candidate model directory; an empty value is treated as invalid.

    Returns:
        True if ``path`` is an existing directory and the recursive walk finds
        at least one file; False for an empty value, a non-directory, or a
        tree that contains only (possibly nested) empty folders.
    """
    if path and os.path.isdir(path):
        # any() stops the walk at the first level that lists a file.
        return any(filenames for _, _, filenames in os.walk(path))
    return False
class PipelineModeDialog(QDialog):
    """Modal dialog that asks which full-pipeline mode to run.

    Two mutually exclusive radio buttons are offered:
      - A: train new models and predict (full flow).
      - B: predict only, using an existing model directory.

    Attributes:
        main_window: Optional object exposing ``get_current_config()``; the
            configured model directory is read from it.
        selected_mode: ``"training"`` or ``"prediction_only"`` once the user
            confirms; ``None`` until then.
    """

    # Stylesheet fragment shared by every hint colour.
    _HINT_STYLE_TAIL = "font-size: 11px; padding: 4px 0;"

    # Dialog-wide stylesheet that draws the radio indicator as a filled circle.
    _RADIO_QSS = (
        "\n"
        "QRadioButton::indicator {\n"
        "width: 14px;\n"
        "height: 14px;\n"
        "}\n"
        "QRadioButton::indicator:checked {\n"
        "background-color: #0078D7;\n"
        "border: 2px solid #0078D7;\n"
        "border-radius: 7px;\n"
        "}\n"
        "QRadioButton::indicator:unchecked {\n"
        "background-color: white;\n"
        "border: 2px solid #A0A0A0;\n"
        "border-radius: 7px;\n"
        "}\n"
    )

    def __init__(self, main_window=None, parent=None):
        """Store the collaborators, configure the window and build the UI."""
        super().__init__(parent)
        self.main_window = main_window
        self.selected_mode: Optional[str] = None

        self.setWindowTitle("选择运行模式")
        self.setMinimumSize(560, 340)
        self.setModal(True)
        self._setup_ui()

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------

    def _add_mode_option(self, layout, group_name, radio_text, radio_name,
                         desc_text, desc_name, checked=False):
        """Append one framed option (radio button + description) to ``layout``.

        Returns the created radio button so the caller can keep a reference.
        """
        frame = QGroupBox()
        frame.setObjectName(group_name)
        frame.setMinimumHeight(100)
        layout.addWidget(frame)

        radio = QRadioButton(radio_text)
        if checked:
            radio.setChecked(True)
        radio.setObjectName(radio_name)

        description = QLabel(desc_text)
        description.setWordWrap(True)
        description.setStyleSheet("color: #555555; background: transparent;")
        description.setObjectName(desc_name)

        inner = QVBoxLayout(frame)
        inner.setContentsMargins(16, 20, 16, 14)
        inner.setSpacing(8)
        inner.addWidget(radio)
        inner.addWidget(description)
        return radio

    def _setup_ui(self):
        """Create the title, both options, the hint line and the buttons."""
        root = QVBoxLayout(self)
        root.setContentsMargins(28, 24, 28, 20)
        root.setSpacing(14)

        # Title
        heading_font = QFont()
        heading_font.setPointSize(13)
        heading_font.setBold(True)
        heading = QLabel("请选择全流程运行模式")
        heading.setFont(heading_font)
        heading.setAlignment(Qt.AlignCenter)
        root.addWidget(heading)
        root.addSpacing(4)

        # Option A: train new models (selected by default)
        self.radio_a = self._add_mode_option(
            root,
            group_name="groupA",
            radio_text="【训练新模型并预测】",
            radio_name="radioTraining",
            desc_text=(
                "需要提供实测水质数据 (CSV),将执行完整建模与预测流程。\n"
                "包括:水域掩膜 → 耀斑去除 → 光谱特征提取 → 模型训练 → 密集采样 → 预测 → 专题图"
            ),
            desc_name="descA",
            checked=True,
        )

        # Option B: prediction only
        self.radio_b = self._add_mode_option(
            root,
            group_name="groupB",
            radio_text="【使用已有模型直接预测】",
            radio_name="radioPrediction",
            desc_text=(
                "跳过模型训练步骤,直接使用导入的外部模型目录进行预测。\n"
                "前提条件:请在「监督预测」或「回归预测」面板中指定模型目录。\n"
                "适用范围:已有预训练模型、或其他来源模型目录。"
            ),
            desc_name="descB",
        )

        # The radios live in different group boxes, so a QButtonGroup is what
        # ties them together.
        self.mode_group = QButtonGroup(self)
        for radio in (self.radio_a, self.radio_b):
            self.mode_group.addButton(radio)

        # Hint line reflecting the state of the configured model directory
        self.models_hint = QLabel()
        self.models_hint.setObjectName("modelsHint")
        self.models_hint.setWordWrap(True)
        self.models_hint.setStyleSheet("color: #888888; font-size: 11px; padding: 4px 0;")
        root.addWidget(self.models_hint)

        self.setStyleSheet(self._RADIO_QSS)

        # Buttons (confirm is placed to the left of cancel)
        buttons = QHBoxLayout()
        buttons.addStretch()

        cancel_btn = QPushButton("取消")
        cancel_btn.setObjectName("cancelBtn")
        cancel_btn.setMinimumWidth(90)
        cancel_btn.clicked.connect(self.reject)

        self.btn_confirm = QPushButton("确认")
        self.btn_confirm.setObjectName("confirmBtn")
        self.btn_confirm.setMinimumWidth(90)
        self.btn_confirm.setDefault(True)
        self.btn_confirm.clicked.connect(self._on_confirm)

        buttons.addWidget(self.btn_confirm)
        buttons.addWidget(cancel_btn)
        root.addLayout(buttons)

        # Re-render the hint and button state whenever either radio toggles.
        for radio in (self.radio_a, self.radio_b):
            radio.toggled.connect(self._update_models_hint)

        # Initial render
        self._update_models_hint()

    # ------------------------------------------------------------------
    # State handling
    # ------------------------------------------------------------------

    def _configured_models_dir(self):
        """Return ``models_dir`` from the main-window config.

        The ``step11_ml`` section is consulted first and ``step11`` second;
        an empty string is returned when no main window was supplied.
        """
        if not self.main_window:
            return ""
        config = self.main_window.get_current_config()
        candidate = ""
        for section in ("step11_ml", "step11"):
            candidate = config.get(section, {}).get("models_dir", "")
            if candidate:
                break
        return candidate

    @staticmethod
    def _hint_for_mode(training, has_files, dir_exists, models_dir):
        """Pick ``(text, colour, confirm_enabled)`` for the current state."""
        if training:
            # Training can always proceed; only warn about overwriting.
            if has_files:
                return (
                    f"⚠ 注意:当前模型目录已包含文件,继续训练将会【覆盖】原有模型!\n路径:{models_dir}",
                    "#e65100",
                    True,
                )
            if dir_exists:
                text = f"✓ 模型将保存至该目录(当前为空,安全)。\n路径:{models_dir}"
            else:
                text = "✓ 尚未指定模型目录,将使用默认路径创建新模型。"
            return text, "#2e7d32", True

        # Prediction-only needs a directory that actually contains files.
        if has_files:
            return (
                f"✓ 已检测到有效模型目录,可以直接预测。\n路径:{models_dir}",
                "#2e7d32",
                True,
            )
        if dir_exists:
            text = f"❌ 错误:模型目录为空(未找到任何文件),无法进行预测!\n路径:{models_dir}"
        else:
            text = "❌ 错误:模型目录为空或不存在!请先返回对应面板配置有效路径。"
        return text, "#c62828", False

    def _update_models_hint(self, checked=False, *args) -> None:
        """Refresh the hint text/colour and the confirm button's enabled state.

        The extra parameters only absorb the arguments of the ``toggled``
        signal; the state is always re-read from the widgets.
        """
        training = self.radio_a.isChecked()
        models_dir = self._configured_models_dir()
        has_files = bool(models_dir and _is_valid_model_dir(models_dir))
        dir_exists = bool(models_dir and os.path.isdir(models_dir))

        text, colour, allow_confirm = self._hint_for_mode(
            training, has_files, dir_exists, models_dir
        )
        self.models_hint.setText(text)
        self.models_hint.setStyleSheet(f"color: {colour}; {self._HINT_STYLE_TAIL}")

        confirm = getattr(self, 'btn_confirm', None)
        if confirm is not None:
            confirm.setEnabled(allow_confirm)

    def _on_confirm(self) -> None:
        """Record the chosen mode and accept the dialog.

        No validation happens here: the confirm button is already disabled by
        ``_update_models_hint`` when prediction-only mode has no usable models.
        """
        self.selected_mode = "training" if self.radio_a.isChecked() else "prediction_only"
        self.accept()

View File

@ -349,6 +349,12 @@ class AISettingsDialog(QDialog):
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
import matplotlib.pyplot as _plt
# Global font setup: intended to keep Chinese labels from rendering as
# garbled glyphs and to make the minus sign display correctly.
_plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS']
_plt.rcParams['axes.unicode_minus'] = False
class SamplingViewerDialog(QDialog):
"""交互式采样点与光谱查看器
@ -381,8 +387,8 @@ class SamplingViewerDialog(QDialog):
self._fig = Figure(figsize=(6, 5))
self._canvas = FigureCanvasQTAgg(self._fig)
self._ax_scatter = self._fig.add_subplot(111)
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_xlabel("像素 X")
self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
self._fig.tight_layout()
@ -392,18 +398,20 @@ class SamplingViewerDialog(QDialog):
# --- 右侧:信息面板 + 光谱子图 ---
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
right_layout.setContentsMargins(0, 0, 0, 0)
# 坐标信息面板(多行中文清晰显示)
self._info_label = QLabel("点击左侧散点图选择采样点")
self._info_label.setStyleSheet(
"QLabel { background-color: #f0f0f0; padding: 6px; "
"border: 1px solid #ccc; font-size: 13px; }"
"QLabel { background-color: #f0f0f0; padding: 8px; "
"border: 1px solid #ccc; border-radius: 4px; font-size: 13px; }"
)
right_layout.addWidget(self._info_label)
# 采样点列表迷你表格
self._point_table = QTableWidget()
self._point_table.setColumnCount(3)
self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
self._point_table.setHorizontalHeaderLabels(["像素 X", "像素 Y", "序号"])
self._point_table.setMaximumHeight(120)
self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
@ -413,8 +421,8 @@ class SamplingViewerDialog(QDialog):
# 光谱曲线子图
self._fig_right = Figure(figsize=(5, 3))
self._ax_spectrum = self._fig_right.add_subplot(111)
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_xlabel("波段序号")
self._ax_spectrum.set_ylabel("反射率")
self._ax_spectrum.set_title("光谱曲线")
self._fig_right.tight_layout()
@ -440,8 +448,8 @@ class SamplingViewerDialog(QDialog):
def _draw_scatter(self):
"""绘制散点图"""
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_xlabel("像素 X")
self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
@ -497,14 +505,14 @@ class SamplingViewerDialog(QDialog):
self._info_label.setText(
f"<b>选中的采样点 #{nearest_idx}</b><br>"
f"pixel_x = {pixel_x} &nbsp; pixel_y = {pixel_y}<br>"
f"x_coord = {x_coord} &nbsp; y_coord = {y_coord}"
f"图像像素坐标: X = {pixel_x}, Y = {pixel_y}<br>"
f"地理真实坐标: 经度(X) = {x_coord}, 纬度(Y) = {y_coord}"
)
# 高亮散点图
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_xlabel("像素 X")
self._ax_scatter.set_ylabel("像素 Y")
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._ax_scatter.invert_yaxis()
self._ax_scatter.scatter(
@ -523,8 +531,8 @@ class SamplingViewerDialog(QDialog):
def _draw_spectrum(self, row: pd.Series):
"""从一行数据中提取纯波段数值并绘图"""
self._ax_spectrum.clear()
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_xlabel("波段序号")
self._ax_spectrum.set_ylabel("反射率")
self._ax_spectrum.set_title("光谱曲线")
exclude_patterns = (

View File

@ -0,0 +1,252 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step10 面板 - 采样点生成
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QSpinBox, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet
class Step10Panel(QWidget):
"""步骤10采样点生成"""
def __init__(self, parent=None):
    """Initialise the base widget with ``parent`` and build the panel UI."""
    super().__init__(parent)
    self.init_ui()
def init_ui(self):
    """Build the panel layout.

    Creates, top to bottom: the deglinted-image selector, the water-mask
    selector, the sampling-parameter group, the output CSV selector, the
    enable checkbox, the standalone-run button and the preview button.
    """
    layout = QVBoxLayout()
    # Deglinted image file (used when the step is run standalone)
    self.deglint_img_file = FileSelectWidget(
        "去耀斑影像:",
        "Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
    )
    layout.addWidget(self.deglint_img_file)
    # Water mask file (optional, used when the step is run standalone)
    self.water_mask_file = FileSelectWidget(
        "水域掩膜:",
        "Mask Files (*.dat *.tif);;All Files (*.*)"
    )
    self.water_mask_file.label.setText("水域掩膜:")
    layout.addWidget(self.water_mask_file)
    # Parameter settings
    params_group = QGroupBox("采样参数")
    params_layout = QFormLayout()
    self.interval = QSpinBox()
    self.interval.setRange(10, 500)
    self.interval.setValue(50)
    params_layout.addRow("采样点间隔(像素):", self.interval)
    self.sample_radius = QSpinBox()
    self.sample_radius.setRange(1, 50)
    self.sample_radius.setValue(5)
    params_layout.addRow("采样半径(像素):", self.sample_radius)
    self.chunk_size = QSpinBox()
    self.chunk_size.setRange(100, 10000)
    self.chunk_size.setValue(1000)
    params_layout.addRow("处理块大小:", self.chunk_size)
    self.use_adaptive_sampling = QCheckBox("启用自适应采样")
    self.use_adaptive_sampling.setChecked(True)
    params_layout.addRow("采样模式:", self.use_adaptive_sampling)
    params_group.setLayout(params_layout)
    layout.addWidget(params_group)
    # Output file path
    self.output_file = FileSelectWidget(
        "输出采样点:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
    layout.addWidget(self.output_file)
    # Enable-step checkbox
    self.enable_checkbox = QCheckBox("启用此步骤")
    self.enable_checkbox.setChecked(True)
    layout.addWidget(self.enable_checkbox)
    # Standalone-run button
    self.run_btn = QPushButton("独立运行此步骤")
    self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
    self.run_btn.clicked.connect(self.run_step)
    layout.addWidget(self.run_btn)
    # Interactive preview button; starts disabled and is enabled by
    # _check_csv_exists once the output CSV is found on disk.
    self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
    self.preview_btn.setEnabled(False)
    self.preview_btn.clicked.connect(self._open_sampling_viewer)
    layout.addWidget(self.preview_btn)
    layout.addStretch()
    self.setLayout(layout)
    # Watch the output path so the preview button state follows it live
    self.output_file.line_edit.textChanged.connect(self._on_output_changed)
def get_config(self):
    """Collect the sampling parameters and any selected input paths.

    Returns:
        dict: always contains ``interval``, ``sample_radius``, ``chunk_size``
        and ``use_adaptive_sampling``. ``deglint_img_path`` and
        ``water_mask_path`` are added only when the matching selector
        returns a non-empty path.
    """
    settings = dict(
        interval=self.interval.value(),
        sample_radius=self.sample_radius.value(),
        chunk_size=self.chunk_size.value(),
        use_adaptive_sampling=self.use_adaptive_sampling.isChecked(),
    )
    # Optional inputs are emitted only when the user (or auto-fill) set them.
    optional_inputs = (
        ('deglint_img_path', self.deglint_img_file),
        ('water_mask_path', self.water_mask_file),
    )
    for key, selector in optional_inputs:
        chosen = selector.get_path()
        if chosen:
            settings[key] = chosen
    return settings
def set_config(self, config):
    """Apply a saved configuration to the widgets.

    Only keys present in ``config`` are applied; everything else keeps its
    current value.

    Args:
        config: Mapping that may contain ``interval``, ``sample_radius``,
            ``chunk_size``, ``use_adaptive_sampling``, ``deglint_img_path``,
            ``water_mask_path`` and ``glint_mask_path``.
    """
    appliers = (
        ('interval', self.interval.setValue),
        ('sample_radius', self.sample_radius.setValue),
        ('chunk_size', self.chunk_size.setValue),
        ('use_adaptive_sampling', self.use_adaptive_sampling.setChecked),
        ('deglint_img_path', self.deglint_img_file.set_path),
        ('water_mask_path', self.water_mask_file.set_path),
    )
    for key, apply_value in appliers:
        if key in config:
            apply_value(config[key])
    # This panel's init_ui creates no glint-mask selector, so a config that
    # still carries 'glint_mask_path' must not raise AttributeError. Apply it
    # only when such a selector actually exists on the instance.
    glint_selector = getattr(self, 'glint_mask_file', None)
    if glint_selector is not None and 'glint_mask_path' in config:
        glint_selector.set_path(config['glint_mask_path'])
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the deglinted image, water mask and output paths.

    Args:
        work_dir: Working directory. When falsy, a previously cached
            ``self.work_dir`` is kept; if there is none, ``self.work_dir``
            becomes ``None``.
        pipeline: Optional pipeline object. If it has a ``step_outputs``
            mapping, the ``step3`` / ``step1`` entries are consulted first.

    Side effects:
        Creates ``<work_dir>/10_sampling`` when ``work_dir`` is known, writes
        into the three path selectors, and refreshes the preview button via
        ``_check_csv_exists``.
    """
    # Keep a previously cached work_dir when no new one is supplied.
    self.work_dir = work_dir or getattr(self, 'work_dir', None) or None
    main_window = self.window()
    # Tolerate a pipeline whose step_outputs is missing or None.
    step_outputs = (getattr(pipeline, 'step_outputs', None) or {}) if pipeline else {}

    def _absolute(path):
        # Relative paths are anchored at work_dir; absolute ones pass through.
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    def _dat_files(directory):
        # os.listdir order is arbitrary; sort so the auto-pick is repeatable.
        if not os.path.isdir(directory):
            return []
        return sorted(f for f in os.listdir(directory) if f.lower().endswith('.dat'))

    # 1. Deglinted image: pipeline step3 outputs first, then the step3 panel.
    step3_outputs = step_outputs.get('step3', {})
    deglint_path = (
        step3_outputs.get('deglint_image')
        or step3_outputs.get('output_path')
        or step3_outputs.get('output_file')
        or step3_outputs.get('deglint_img_path')
    )
    if not deglint_path and hasattr(main_window, 'step3_panel'):
        # The panel widget may hold a relative path.
        deglint_path = main_window.step3_panel.output_file.get_path()
    if deglint_path:
        self.deglint_img_file.set_path(_absolute(deglint_path))

    # 2. Water mask. Priority: pipeline step1 outputs > step1 panel >
    #    <work_dir>/1_water_mask > <work_dir>/input-test.
    step1_outputs = step_outputs.get('step1', {})
    water_mask_path = (
        step1_outputs.get('water_mask')
        or step1_outputs.get('output_path')
        or step1_outputs.get('output_file')
    )
    if not water_mask_path and hasattr(main_window, 'step1_panel'):
        water_mask_path = main_window.step1_panel.output_file.get_path()
    if not water_mask_path and self.work_dir:
        mask_dir = os.path.join(self.work_dir, "1_water_mask")
        candidates = _dat_files(mask_dir)
        if candidates:
            water_mask_path = os.path.join(mask_dir, candidates[0]).replace('\\', '/')
    if not water_mask_path and self.work_dir:
        input_test_dir = os.path.join(self.work_dir, "input-test")
        candidates = _dat_files(input_test_dir)
        # Prefer a file named like water_mask_from_shp.dat, else the first .dat.
        preferred = [f for f in candidates if 'water_mask_from_shp' in f.lower()]
        pick = (preferred or candidates)[:1]
        if pick:
            water_mask_path = os.path.join(input_test_dir, pick[0]).replace('\\', '/')
    if water_mask_path:
        self.water_mask_file.set_path(_absolute(water_mask_path))

    # 3. Output CSV path (absolute); its directory is created up front.
    if self.work_dir:
        output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        self.output_file.set_path(output_path.replace('\\', '/'))

    # 4. The output path may just have been filled in: refresh the preview button.
    self._check_csv_exists()
def run_step(self):
    """Run step 10 on its own through the main window.

    Warns and aborts when no deglinted image is selected. Otherwise, if the
    top-level window exposes ``run_single_step``, it is called with
    ``'step10'`` and this panel's configuration.
    """
    if not self.deglint_img_file.get_path():
        QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
        return
    host = self.window()
    if not hasattr(host, 'run_single_step'):
        return
    host.run_single_step('step10', {'step10': self.get_config()})
def _check_csv_exists(self):
    """Enable the preview button only when the output CSV is on disk.

    The path must be non-empty, absolute and existing.

    Returns:
        bool: the state that was applied to the preview button.
    """
    candidate = self.output_file.get_path()
    if candidate and os.path.isabs(candidate) and os.path.exists(candidate):
        available = True
    else:
        available = False
    self.preview_btn.setEnabled(available)
    return available
def _on_output_changed(self, _text=None):
    """Slot for the output line edit's ``textChanged`` signal.

    Args:
        _text: The new text delivered by the signal; unused, the path is
            re-read from the widget by ``_check_csv_exists``.
    """
    self._check_csv_exists()
def _open_sampling_viewer(self):
    """Open the interactive sampling-point viewer for the output CSV.

    Shows a warning instead when the CSV path is empty or does not exist.
    After the modal dialog closes, the preview-button state is re-checked.
    """
    csv_path = self.output_file.get_path()
    if csv_path and os.path.exists(csv_path):
        viewer = SamplingViewerDialog(csv_path, self)
        viewer.exec_()
        # The file may have changed while the dialog was open.
        self._check_csv_exists()
        return
    QMessageBox.warning(
        self, "文件不存在",
        f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤10生成数据。"
    )

View File

@ -0,0 +1,462 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step11_ML panel - machine learning prediction
"""
import os
from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
QListWidgetItem,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step11MlPanel(QWidget):
"""步骤11机器学习预测"""
def __init__(self, parent=None):
    """Initialise the external-model state and build the widgets.

    Args:
        parent: Optional parent widget, forwarded to the base-class constructor.
    """
    super().__init__(parent)
    self.external_models_dict = {}  # {subdir_name: model_obj, ...}
    self.external_model_dir = ""  # parent folder path of the scanned models (hidden)
    self.init_ui()
def init_ui(self):
    """Build the panel layout.

    Creates, top to bottom: the model-source radio group, the external
    model folder selector and checkable model list (both hidden until the
    external source is chosen), the sampling CSV and models-directory
    selectors, the prediction-parameter group, the output selector, the
    enable checkbox and the standalone-run button.
    """
    layout = QVBoxLayout()
    # -------- Model source selection (radio button group) --------
    source_group = QGroupBox("模型来源")
    source_layout = QVBoxLayout()
    self.use_trained_model = QRadioButton("使用当前训练流程的模型")
    self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
    self.use_trained_model.setChecked(True)
    source_layout.addWidget(self.use_trained_model)
    source_layout.addWidget(self.use_external_model)
    self.use_trained_model.toggled.connect(self._on_model_source_changed)
    self.use_external_model.toggled.connect(self._on_model_source_changed)
    source_group.setStyleSheet("""
QRadioButton {
font-size: 13px;
spacing: 8px;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border-radius: 9px;
border: 2px solid #A0A0A0;
background-color: #FFFFFF;
}
QRadioButton::indicator:hover {
border: 2px solid #0078D7;
}
QRadioButton::indicator:checked {
background-color: #0078D7;
border: 2px solid #0078D7;
}
""")
    source_group.setLayout(source_layout)
    layout.addWidget(source_group)
    # -------- External model folder selector (shown conditionally) --------
    self.external_model_widget = FileSelectWidget(
        "模型母文件夹:",
        "Directories"
    )
    # Replace the widget's existing browse handler with the folder scan.
    self.external_model_widget.browse_btn.clicked.disconnect()
    self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
    self.external_model_widget.setVisible(False)
    layout.addWidget(self.external_model_widget)
    # -------- Scanned model list (shown conditionally) --------
    self.model_list_group = QGroupBox("选择参与预测的模型")
    self.model_list_group.setVisible(False)
    model_list_layout = QVBoxLayout()
    self.model_list = QListWidget()
    self.model_list.setMaximumHeight(130)
    # Rows are chosen through their check boxes, not through list selection.
    self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
    self.model_list.setStyleSheet("""
QListWidget {
border: 1px solid #C0C0C0;
border-radius: 4px;
background-color: #FFFFFF;
font-size: 12px;
}
QListWidget::item {
padding: 4px 6px;
border-bottom: 1px solid #F0F0F0;
}
QListWidget::item:selected {
background-color: transparent;
}
""")
    model_list_layout.addWidget(self.model_list)
    btn_row = QHBoxLayout()
    self.btn_select_all = QPushButton("全选")
    self.btn_select_all.setMaximumWidth(80)
    self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
    self.btn_select_all.clicked.connect(self._select_all_models)
    self.btn_select_none = QPushButton("全不选")
    self.btn_select_none.setMaximumWidth(80)
    self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
    self.btn_select_none.clicked.connect(self._select_none_models)
    btn_row.addWidget(self.btn_select_all)
    btn_row.addWidget(self.btn_select_none)
    btn_row.addStretch()
    model_list_layout.addLayout(btn_row)
    self.model_list_group.setLayout(model_list_layout)
    layout.addWidget(self.model_list_group)
    # -------- Sampling spectra CSV file (used for standalone runs) --------
    self.sampling_csv_file = FileSelectWidget(
        "采样光谱CSV:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    layout.addWidget(self.sampling_csv_file)
    # Models directory (used for standalone runs)
    self.models_dir_file = FileSelectWidget(
        "模型目录:",
        "Directories;;All Files (*.*)"
    )
    self.models_dir_file.label.setText("模型目录:")
    # Replace the widget's existing browse handler with a directory chooser.
    self.models_dir_file.browse_btn.clicked.disconnect()
    self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
    layout.addWidget(self.models_dir_file)
    # Parameter settings
    params_group = QGroupBox("预测参数")
    params_layout = QFormLayout()
    self.metric = QComboBox()
    self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
    params_layout.addRow("模型选择指标:", self.metric)
    self.prediction_column = QLineEdit()
    self.prediction_column.setText("prediction")
    params_layout.addRow("预测列名:", self.prediction_column)
    params_group.setLayout(params_layout)
    layout.addWidget(params_group)
    # Output path
    self.output_file = FileSelectWidget(
        "输出路径:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    layout.addWidget(self.output_file)
    # Enable-step checkbox
    self.enable_checkbox = QCheckBox("启用此步骤")
    self.enable_checkbox.setChecked(True)
    layout.addWidget(self.enable_checkbox)
    # Standalone-run button
    self.run_btn = QPushButton("独立运行此步骤")
    self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
    self.run_btn.clicked.connect(self.run_step)
    layout.addWidget(self.run_btn)
    layout.addStretch()
    self.setLayout(layout)
def _on_model_source_changed(self, checked: bool):
    """React to a model-source radio button being toggled.

    Only the "checked" half of a toggle is handled. The external-model
    widgets are shown exactly when the external source is selected; when
    it is not, any previously scanned models are discarded.

    Args:
        checked: The new state of the radio button that emitted the signal.
    """
    if checked:
        show_external = self.use_external_model.isChecked()
        for widget in (self.external_model_widget, self.model_list_group):
            widget.setVisible(show_external)
        if not show_external:
            self.external_models_dict = {}
            self.external_model_dir = ""
            self._clear_model_list()
def _scan_external_model_dir(self):
    """Pick a parent folder and load one ``.joblib`` model per sub-folder.

    Each immediate sub-directory of the chosen folder contributes at most
    one model: its alphabetically first ``.joblib`` file. A loaded object is
    accepted if it is a dict with a ``"model"`` key or has a ``predict``
    attribute. Results are stored in ``self.external_models_dict`` keyed by
    sub-directory name and shown in the checkable model list.

    Load failures for individual sub-directories are collected and
    reported; a failure while walking the folder aborts the scan with a
    warning.
    """
    default = self._get_default_work_dir()
    if default:
        default = os.path.join(default, "7_Supervised_Model_Training")
    dir_path = QFileDialog.getExistingDirectory(
        self,
        "选择模型母文件夹",
        default,
    )
    if not dir_path:
        return
    self.external_model_dir = dir_path
    models_found = {}
    errors = []
    try:
        import joblib
        # os.scandir yields entries in arbitrary order; sort by name so the
        # scan (and which file counts as "first") is repeatable.
        for subentry in sorted(os.scandir(dir_path), key=lambda entry: entry.name):
            if not subentry.is_dir():
                continue
            subdir_name = subentry.name
            joblib_files = sorted(
                (
                    f for f in os.scandir(subentry.path)
                    if f.is_file() and f.name.lower().endswith(".joblib")
                ),
                key=lambda entry: entry.name,
            )
            if not joblib_files:
                continue
            # Only the first .joblib file of each sub-directory is used.
            joblib_path = joblib_files[0].path
            try:
                # SECURITY NOTE: joblib.load unpickles the file and can run
                # arbitrary code; only point this at model folders you trust.
                loaded = joblib.load(joblib_path)
                if isinstance(loaded, dict) and "model" in loaded:
                    model_obj = loaded["model"]
                elif hasattr(loaded, "predict"):
                    model_obj = loaded
                else:
                    errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
                    continue
                models_found[subdir_name] = model_obj
            except Exception as e:
                # Best effort: one bad file must not abort the whole scan.
                errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
    except Exception as e:
        QMessageBox.warning(
            self,
            "扫描失败",
            f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
        )
        return
    if not models_found:
        detail = (
            f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
            "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。"
        )
        if errors:
            # Files were found but none could be loaded: report that rather
            # than claiming no .joblib file exists.
            detail = (
                f"在「{dir_path}」的子目录中发现了 .joblib 文件,但全部加载失败:\n"
                + "\n".join(errors)
            )
        QMessageBox.warning(self, "未找到模型", detail)
        self.external_model_widget.set_path("")
        self.external_models_dict = {}
        self._clear_model_list()
        return
    self.external_models_dict = models_found
    self._populate_model_list(models_found)
    names = sorted(models_found.keys())
    display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
    # The line edit shows a summary, not a path; the real folder is kept in
    # self.external_model_dir.
    self.external_model_widget.set_path(display)
    self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
    err_lines = "\n".join(errors) if errors else ""
    QMessageBox.information(
        self,
        "模型扫描完成",
        f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
        f"加载失败 {len(errors)} 个:\n{err_lines}",
    )
def _populate_model_list(self, models_dict):
    """Fill the model list with one checkable, pre-checked row per model.

    Args:
        models_dict: Mapping whose keys are the model names to list; rows
            are added in sorted key order.
    """
    self.model_list.clear()
    for label in sorted(models_dict):
        entry = QListWidgetItem(label)
        entry.setFlags(entry.flags() | Qt.ItemIsUserCheckable)
        entry.setCheckState(Qt.Checked)
        self.model_list.addItem(entry)
def _clear_model_list(self):
    """Remove every row from the model list widget."""
    self.model_list.clear()
def _select_all_models(self):
    """Check every row of the model list."""
    rows = (self.model_list.item(row) for row in range(self.model_list.count()))
    for entry in rows:
        entry.setCheckState(Qt.Checked)
def _select_none_models(self):
    """Uncheck every row of the model list."""
    rows = (self.model_list.item(row) for row in range(self.model_list.count()))
    for entry in rows:
        entry.setCheckState(Qt.Unchecked)
def _get_checked_models_dict(self):
    """Return the loaded models whose list rows are currently checked.

    Returns:
        dict: ``{name: model}`` for each checked row whose text is a key of
        ``self.external_models_dict``, in list order.
    """
    rows = (self.model_list.item(i) for i in range(self.model_list.count()))
    ticked = (row.text() for row in rows if row.checkState() == Qt.Checked)
    return {
        label: self.external_models_dict[label]
        for label in ticked
        if label in self.external_models_dict
    }
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the sampling CSV, models directory and output path.

    Values are taken from sibling panels on the top-level window
    (``step10_panel.output_file`` and ``step7_panel.output_dir``) and are
    only written into selectors that are still blank. Any exception is
    printed with its traceback and otherwise ignored.

    Args:
        work_dir: Working directory. When falsy, a previously cached
            ``self.work_dir`` is kept; if there is none it becomes ``None``.
        pipeline: Unused; kept for interface compatibility.
    """
    def _widget_path(widget):
        # Selector widgets expose get_path(); plain line edits expose text().
        if hasattr(widget, 'get_path'):
            return widget.get_path() or ""
        if hasattr(widget, 'text'):
            return widget.text() or ""
        return ""

    def _fill_if_blank(selector, value):
        # Never overwrite something the user already entered.
        current = selector.get_path()
        if not current or not current.strip():
            selector.set_path(value)

    try:
        self.work_dir = work_dir or getattr(self, 'work_dir', None) or None
        host = self.window()
        # (panel on main window, widget on that panel, selector to fill)
        upstream_sources = (
            ('step10_panel', 'output_file', self.sampling_csv_file),
            ('step7_panel', 'output_dir', self.models_dir_file),
        )
        for panel_attr, widget_attr, target in upstream_sources:
            if not (host and hasattr(host, panel_attr)):
                continue
            upstream = _widget_path(getattr(getattr(host, panel_attr), widget_attr, None))
            if not upstream:
                continue
            # Relative paths are anchored at work_dir.
            if not os.path.isabs(upstream):
                upstream = os.path.join(self.work_dir or '', upstream).replace('\\', '/')
            _fill_if_blank(target, upstream)
        # Output location: the machine-learning prediction directory.
        if self.work_dir:
            output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
            os.makedirs(output_dir, exist_ok=True)
            _fill_if_blank(self.output_file, output_dir)
    except Exception as e:
        import traceback
        print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
        traceback.print_exc()
def _get_default_work_dir(self):
    """Return the working directory as a string, or ``""`` if unknown.

    The panel's own cached ``work_dir`` wins; otherwise the top-level
    window's ``work_dir`` is used when it is set.
    """
    cached = getattr(self, 'work_dir', None)
    if cached:
        return str(cached)
    host = self.window()
    host_dir = getattr(host, 'work_dir', None) if host else None
    return str(host_dir) if host_dir else ""
def browse_models_dir(self):
    """Open a directory chooser and store the result in the models selector.

    The dialog starts in ``<work_dir>/7_Supervised_Model_Training`` when a
    working directory is known.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "7_Supervised_Model_Training")
    chosen = QFileDialog.getExistingDirectory(self, "选择模型目录", start_dir)
    if not chosen:
        return
    self.models_dir_file.set_path(chosen)
def get_config(self):
    """Collect the prediction parameters and any selected paths.

    Returns:
        dict: always contains ``metric`` and ``prediction_column``;
        ``sampling_csv_path``, ``models_dir`` and ``output_path`` are added
        only when the matching selector returns a non-empty path.
    """
    settings = dict(
        metric=self.metric.currentText(),
        prediction_column=self.prediction_column.text(),
    )
    optional_paths = (
        ('sampling_csv_path', self.sampling_csv_file),
        ('models_dir', self.models_dir_file),
        ('output_path', self.output_file),
    )
    for key, selector in optional_paths:
        chosen = selector.get_path()
        if chosen:
            settings[key] = chosen
    return settings
def set_config(self, config):
    """Apply a saved configuration to the widgets.

    Only keys present in ``config`` are applied. An unknown ``metric``
    value (not among the combo box items) is ignored.

    Args:
        config: Mapping that may contain ``metric``, ``prediction_column``,
            ``sampling_csv_path``, ``models_dir`` and ``output_path``.
    """
    if 'metric' in config:
        position = self.metric.findText(config['metric'])
        if position >= 0:
            self.metric.setCurrentIndex(position)
    if 'prediction_column' in config:
        self.prediction_column.setText(config['prediction_column'])
    path_targets = (
        ('sampling_csv_path', self.sampling_csv_file),
        ('models_dir', self.models_dir_file),
        ('output_path', self.output_file),
    )
    for key, selector in path_targets:
        if key in config:
            selector.set_path(config[key])
def run_step(self):
    """Run the ``step11_ml`` prediction step on its own.

    Validation, in order: a sampling CSV must be selected; with the
    external-model source, models must have been scanned and at least one
    must be checked; with the default source, a models directory must be
    selected. Each failure shows a warning and aborts.

    On success, and if the top-level window exposes ``run_single_step``,
    it is called with ``'step11_ml'`` and a config holding this panel's
    settings. For external models the config also carries
    ``_external_models_dict`` (checked models only) and
    ``_external_model_dir``.
    """
    if not self.sampling_csv_file.get_path():
        QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
        return
    extras = {}
    if self.use_external_model.isChecked():
        # External models take precedence over the models directory.
        if not self.external_models_dict:
            QMessageBox.warning(
                self,
                "模型未加载",
                "请先点击「浏览...」按钮选择模型母文件夹!",
            )
            return
        # Only the models the user ticked are forwarded.
        chosen_models = self._get_checked_models_dict()
        if not chosen_models:
            QMessageBox.warning(
                self,
                "未选择模型",
                "请至少勾选一个模型参与预测!",
            )
            return
        extras = {
            '_external_models_dict': chosen_models,
            '_external_model_dir': self.external_model_dir,
        }
    elif not self.models_dir_file.get_path():
        # Default flow: models come from the models directory.
        QMessageBox.warning(self, "输入错误", "请选择模型目录!")
        return
    host = self.window()
    if hasattr(host, 'run_single_step'):
        payload = {'step11_ml': self.get_config()}
        payload.update(extras)
        host.run_single_step('step11_ml', payload)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8_5 面板 - 非经验模型预测
Step11 面板 - 非经验模型预测
"""
import os
@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8_5Panel(QWidget):
"""步骤8.5:非经验模型预测"""
class Step11Panel(QWidget):
"""步骤11:非经验模型预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@ -118,22 +118,22 @@ class Step8_5Panel(QWidget):
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
# 2. 尝试从 Step6.5 界面读取回归模型目录
if main_window and hasattr(main_window, 'step6_5_panel'):
step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
step6_5_models_dir = ""
if hasattr(step6_5_widget, 'get_path'):
step6_5_models_dir = step6_5_widget.get_path() or ""
elif hasattr(step6_5_widget, 'text'):
step6_5_models_dir = step6_5_widget.text() or ""
# 2. 尝试从 Step8_Non_Empirical 界面读取回归模型目录
if main_window and hasattr(main_window, 'step8_non_empirical_panel'):
step8_non_empirical_widget = getattr(main_window.step8_non_empirical_panel, 'output_dir', None)
step8_non_empirical_models_dir = ""
if hasattr(step8_non_empirical_widget, 'get_path'):
step8_non_empirical_models_dir = step8_non_empirical_widget.get_path() or ""
elif hasattr(step8_non_empirical_widget, 'text'):
step8_non_empirical_models_dir = step8_non_empirical_widget.text() or ""
if step6_5_models_dir:
if step8_non_empirical_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step6_5_models_dir):
step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
if not os.path.isabs(step8_non_empirical_models_dir):
step8_non_empirical_models_dir = os.path.join(self.work_dir or '', step8_non_empirical_models_dir).replace('\\', '/')
existing_models = self.models_dir_file.get_path()
if not existing_models or not existing_models.strip():
self.models_dir_file.set_path(step6_5_models_dir)
self.models_dir_file.set_path(step8_non_empirical_models_dir)
# 3. 自动填充输出路径(非经验模型预测目录)
if self.work_dir:
@ -208,7 +208,7 @@ class Step8_5Panel(QWidget):
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
"""独立运行步骤8.5"""
"""独立运行步骤11"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
@ -221,6 +221,6 @@ class Step8_5Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step8_5', {'step8_5': config})
parent.run_single_step('step11', {'step11': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8_75 面板 - 自定义回归预测
Step12 面板 - 自定义回归预测
"""
import os
@ -15,8 +15,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8_75Panel(QWidget):
"""步骤8.75:自定义回归预测"""
class Step12Panel(QWidget):
"""步骤12:自定义回归预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@ -111,25 +111,25 @@ class Step8_75Panel(QWidget):
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
# 2. 尝试从 Step6.75 界面读取自定义回归模型目录
if main_window and hasattr(main_window, 'step6_75_panel'):
step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
step6_75_models_dir = ""
if hasattr(step6_75_widget, 'get_path'):
step6_75_models_dir = step6_75_widget.get_path() or ""
elif hasattr(step6_75_widget, 'text'):
step6_75_models_dir = step6_75_widget.text() or ""
step6_75_models_dir = step6_75_models_dir.strip()
# 2. 尝试从 Step9 界面读取自定义回归模型目录
if main_window and hasattr(main_window, 'step12_panel'):
step9_widget = getattr(main_window.step9_panel, 'output_dir', None)
step9_models_dir = ""
if hasattr(step9_widget, 'get_path'):
step9_models_dir = step9_widget.get_path() or ""
elif hasattr(step9_widget, 'text'):
step9_models_dir = step9_widget.text() or ""
step9_models_dir = step9_models_dir.strip()
if step6_75_models_dir:
if step9_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step6_75_models_dir):
step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
if not os.path.isabs(step9_models_dir):
step9_models_dir = os.path.join(self.work_dir or '', step9_models_dir).replace('\\', '/')
existing_models = self.regression_models_dir.get_path()
if not existing_models or not existing_models.strip():
self.regression_models_dir.set_path(step6_75_models_dir)
self.regression_models_dir.set_path(step9_models_dir)
# 3. 自动填充回归模型目录(如果 step6_75 未提供)
# 3. 自动填充回归模型目录(如果 step9 未提供)
if self.work_dir:
models_dir = self.regression_models_dir.get_path().strip()
if not models_dir:
@ -208,7 +208,7 @@ class Step8_75Panel(QWidget):
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
"""独立运行步骤8.75"""
"""独立运行步骤12"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
@ -225,6 +225,6 @@ class Step8_75Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step8_75', {'step8_75': config})
parent.run_single_step('step12', {'step12': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -0,0 +1,533 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step14 面板 - 分布图生成
"""
import os
import traceback
from pathlib import Path
from typing import List, Optional
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
try:
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
PIPELINE_AVAILABLE = True
except ImportError:
PIPELINE_AVAILABLE = False
class Step14BatchThread(QThread):
    """Worker thread that renders one distribution map per prediction CSV.

    Signals:
        finished_ok(int): emitted with the number of CSVs processed.
        failed(str): emitted with the error message and traceback.
        log_message(str, str): progress text and a level string ("info").
    """

    finished_ok = pyqtSignal(int)
    failed = pyqtSignal(str)
    log_message = pyqtSignal(str, str)

    def __init__(self, work_dir: str, csv_paths: List[str], step14_kwargs: dict, output_dir_optional: Optional[str]):
        """Store the batch parameters.

        Args:
            work_dir: Working directory passed to the pipeline constructor.
            csv_paths: Prediction CSV files to process, in order.
            step14_kwargs: Keyword arguments shared by every map-generation
                call.
            output_dir_optional: Directory for the PNGs. Blank or ``None``
                means ``output_image_path=None`` is passed to the pipeline.
        """
        super().__init__()
        self.work_dir = work_dir
        self.csv_paths = csv_paths
        self.step14_kwargs = step14_kwargs
        self.output_dir_optional = (output_dir_optional or "").strip() or None

    def run(self):
        """Generate a map for each CSV, then emit ``finished_ok`` or ``failed``.

        The matplotlib backend is switched to "Agg" for the duration of the
        run and restored afterwards when the previous backend is known.
        """
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            # The backend was not switched, so there is nothing to restore.
            mpl_prev = None
        try:
            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
            pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
            n = len(self.csv_paths)
            # File names already handed out in this batch (lower-cased so the
            # check also holds on case-insensitive file systems).
            used_names = set()
            for i, csv_p in enumerate(self.csv_paths):
                self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
                kw = {**self.step14_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
                if self.output_dir_optional:
                    stem = Path(csv_p).stem
                    file_name = f"{stem}_distribution.png"
                    # CSVs from different sub-folders can share a stem; add a
                    # numeric suffix so a later map never overwrites an earlier one.
                    suffix = 1
                    while file_name.lower() in used_names:
                        suffix += 1
                        file_name = f"{stem}_{suffix}_distribution.png"
                    used_names.add(file_name.lower())
                    kw["output_image_path"] = str(Path(self.output_dir_optional) / file_name)
                else:
                    kw["output_image_path"] = None
                pipeline.step9_generate_distribution_map(**kw)
            self.finished_ok.emit(n)
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    # Best effort: failing to restore the backend must not
                    # mask the outcome of the batch.
                    pass
class Step14Panel(QWidget):
"""步骤14分布图生成"""
def __init__(self, parent=None):
    """Create the panel and build its widgets.

    Args:
        parent: Optional parent widget, forwarded to the base-class constructor.
    """
    super().__init__(parent)
    self._batch_thread = None  # no batch worker is attached at construction
    self.init_ui()
def init_ui(self):
    """Build the panel layout.

    Creates, top to bottom: a usage hint, the single/folder mode radio
    buttons, the single-CSV selector, the CSV-folder row with its
    recursive option, the boundary selector, the generation-parameter
    group, the output-directory selector, the enable checkbox and the
    standalone-run button. Finishes by wiring the mode radio buttons and
    applying the default (single CSV) mode.
    """
    layout = QVBoxLayout()
    hint = QLabel(
        "独立运行:可选「单个 CSV」或「文件夹批量」扫描目录下所有 .csv"
        "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
    )
    hint.setWordWrap(True)
    hint.setStyleSheet(
        f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
    )
    layout.addWidget(hint)
    mode_row = QHBoxLayout()
    self.mode_single_rb = QRadioButton("单个 CSV 文件")
    self.mode_folder_rb = QRadioButton("文件夹批量")
    self._mode_group = QButtonGroup(self)
    self._mode_group.addButton(self.mode_single_rb, 0)
    self._mode_group.addButton(self.mode_folder_rb, 1)
    mode_row.addWidget(self.mode_single_rb)
    mode_row.addWidget(self.mode_folder_rb)
    mode_row.addStretch()
    layout.addLayout(mode_row)
    # ---------- Radio button styling (checked state is a solid square block, matching the main UI) ----------
    radio_style = """
QRadioButton {
font-size: 14px;
spacing: 8px;
color: #333333;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border: 2px solid #999999;
border-radius: 3px;
background-color: white;
}
QRadioButton::indicator:checked {
border: 2px solid #0078d4;
background-color: #0078d4;
image: none;
}
QRadioButton::indicator:hover {
border: 2px solid #005a9e;
}
"""
    self.mode_single_rb.setStyleSheet(radio_style)
    self.mode_folder_rb.setStyleSheet(radio_style)
    self.prediction_csv_file = FileSelectWidget(
        "预测结果CSV:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    layout.addWidget(self.prediction_csv_file)
    folder_row = QHBoxLayout()
    self.prediction_csv_dir_label = QLabel("预测CSV目录:")
    self.prediction_csv_dir_label.setMinimumWidth(120)
    self.prediction_csv_dir_edit = QLineEdit()
    self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
    pred_dir_btn = QPushButton("浏览…")
    pred_dir_btn.setMaximumWidth(80)
    pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
    folder_row.addWidget(self.prediction_csv_dir_label)
    folder_row.addWidget(self.prediction_csv_dir_edit, 1)
    folder_row.addWidget(pred_dir_btn)
    # Wrap the row in a widget so the whole row can be shown or hidden at once.
    self._folder_row_widget = QWidget()
    self._folder_row_widget.setLayout(folder_row)
    layout.addWidget(self._folder_row_widget)
    self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv")
    layout.addWidget(self.recursive_csv_cb)
    self.boundary_file = FileSelectWidget(
        "边界文件:",
        "Shapefiles (*.shp);;All Files (*.*)"
    )
    layout.addWidget(self.boundary_file)
    # Parameter settings
    params_group = QGroupBox("生成参数")
    params_layout = QFormLayout()
    self.resolution = QDoubleSpinBox()
    self.resolution.setRange(1, 1000)
    self.resolution.setValue(30)
    params_layout.addRow("分辨率(米):", self.resolution)
    self.input_crs = QLineEdit()
    self.input_crs.setText("EPSG:32651")
    params_layout.addRow("输入坐标系:", self.input_crs)
    self.output_crs = QLineEdit()
    self.output_crs.setText("EPSG:4326")
    params_layout.addRow("输出坐标系:", self.output_crs)
    self.show_points = QCheckBox("显示采样点")
    params_layout.addRow("", self.show_points)
    self.use_diffusion = QCheckBox("启用距离扩散")
    self.use_diffusion.setChecked(True)
    params_layout.addRow("", self.use_diffusion)
    params_group.setLayout(params_layout)
    layout.addWidget(params_group)
    # Output directory
    self.output_dir = FileSelectWidget(
        "输出分布图目录:",
        "Directories;;All Files (*.*)"
    )
    self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
    # Replace the widget's existing browse handler with this panel's own.
    self.output_dir.browse_btn.clicked.disconnect()
    self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
    layout.addWidget(self.output_dir)
    # Enable-step checkbox
    self.enable_checkbox = QCheckBox("启用此步骤")
    self.enable_checkbox.setChecked(True)
    layout.addWidget(self.enable_checkbox)
    # Standalone-run button
    self.run_button = QPushButton("独立运行此步骤")
    self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
    self.run_button.clicked.connect(self.run_step)
    layout.addWidget(self.run_button)
    layout.addStretch()
    self.setLayout(layout)
    # Signal wiring and initial state
    self.mode_single_rb.toggled.connect(self._toggle_input_mode)
    self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
    self.mode_single_rb.setChecked(True)  # "single CSV" is selected by default
    self._toggle_input_mode()  # apply the initial visibility for the default mode
def _toggle_input_mode(self):
    """Slot: show the input widgets that belong to the selected mode.

    Single-CSV mode shows only the single-file selector; folder mode shows
    the folder row and the recursive-scan checkbox instead.
    """
    batch = self.mode_folder_rb.isChecked()
    visibility = (
        (self.prediction_csv_file, not batch),
        (self._folder_row_widget, batch),
        (self.recursive_csv_cb, batch),
    )
    for widget, shown in visibility:
        widget.setVisible(shown)
def _get_default_work_dir(self):
    """Return the working directory as a string, or ``""`` if unknown.

    The panel's own cached ``work_dir`` wins; otherwise the top-level
    window's ``work_dir`` is used when it is set.
    """
    cached = getattr(self, 'work_dir', None)
    if cached:
        return str(cached)
    host = self.window()
    host_dir = getattr(host, 'work_dir', None) if host else None
    return str(host_dir) if host_dir else ""
def browse_prediction_csv_dir(self):
    """Open a directory chooser and put the result in the CSV-folder edit.

    The dialog starts in ``<work_dir>/11_12_13_predictions`` when a working
    directory is known.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "11_12_13_predictions")
    chosen = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", start_dir)
    if not chosen:
        return
    self.prediction_csv_dir_edit.setText(chosen)
def _collect_csv_paths_from_folder(self) -> List[str]:
folder = (self.prediction_csv_dir_edit.text() or "").strip()
if not folder or not os.path.isdir(folder):
return []
root = Path(folder)
if self.recursive_csv_cb.isChecked():
files = sorted(root.rglob("*.csv"))
else:
files = sorted(root.glob("*.csv"))
return [str(p) for p in files if p.is_file()]
def _step14_base_pipeline_kwargs(self) -> dict:
return {
'boundary_shp_path': self.boundary_file.get_path(),
'resolution': self.resolution.value(),
'input_crs': self.input_crs.text(),
'output_crs': self.output_crs.text(),
'show_sample_points': self.show_points.isChecked(),
'use_distance_diffusion': self.use_diffusion.isChecked(),
}
def get_config(self):
    """Build the step-14 configuration dict from the current widget state.

    ``prediction_csv_path`` is only reported in single-CSV mode, and
    ``output_image_path`` is only derived (as ``<out_dir>/<csv stem>_distribution.png``)
    when single-CSV mode has both a CSV and an output directory; otherwise
    both are ``None``.
    """
    single_csv = (self.prediction_csv_file.get_path() or "").strip()
    batch = self.mode_folder_rb.isChecked()
    csv_dir = (self.prediction_csv_dir_edit.text() or "").strip()

    config = dict(
        step14_batch_mode='folder' if batch else 'single',
        prediction_csv_dir=csv_dir or None,
        recursive_csv_scan=self.recursive_csv_cb.isChecked(),
        prediction_csv_path=None if batch else (single_csv or None),
        boundary_shp_path=self.boundary_file.get_path(),
        resolution=self.resolution.value(),
        input_crs=self.input_crs.text(),
        output_crs=self.output_crs.text(),
        show_sample_points=self.show_points.isChecked(),
        use_distance_diffusion=self.use_diffusion.isChecked(),
    )

    image_path = None
    target_dir = (self.output_dir.get_path() or "").strip()
    if single_csv and target_dir and not batch:
        image_name = f"{Path(single_csv).stem}_distribution.png"
        image_path = str(Path(target_dir) / image_name)
    config['output_image_path'] = image_path
    return config
def set_config(self, config):
    """Restore the widget state from a configuration dict.

    Missing keys leave the corresponding widget unchanged. The output
    directory comes from ``output_dir`` when present and non-empty,
    otherwise from the parent folder of ``output_image_path`` (unless that
    parent is the bare current directory ``.``).
    """
    folder_mode = config.get('step14_batch_mode', 'single') == 'folder'
    (self.mode_folder_rb if folder_mode else self.mode_single_rb).setChecked(True)

    csv_dir = config.get('prediction_csv_dir')
    if csv_dir:
        self.prediction_csv_dir_edit.setText(str(csv_dir))
    if 'recursive_csv_scan' in config:
        self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
    csv_path = config.get('prediction_csv_path')
    if csv_path:
        self.prediction_csv_file.set_path(str(csv_path))

    # Keys that are handed to their widget unchanged, in the original order.
    pass_through = (
        ('boundary_shp_path', self.boundary_file.set_path),
        ('resolution', self.resolution.setValue),
        ('input_crs', self.input_crs.setText),
        ('output_crs', self.output_crs.setText),
        ('show_sample_points', self.show_points.setChecked),
        ('use_distance_diffusion', self.use_diffusion.setChecked),
    )
    for key, apply in pass_through:
        if key in config:
            apply(config[key])

    out_dir = config.get('output_dir')
    if out_dir:
        self.output_dir.set_path(str(out_dir))
    elif config.get('output_image_path'):
        image_parent = Path(str(config['output_image_path'])).parent
        if str(image_parent) != '.':
            self.output_dir.set_path(str(image_parent))
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the prediction folder, output folder and boundary file.

    The prediction folder is taken from the first of these main-window
    panels that yields a non-empty output path:

    1. ``step11_prediction_panel.output_file`` - the parent folder of the
       path; its ``Machine_Learning_Prediction`` sub-folder is preferred
       when that sub-folder exists;
    2. ``step11_panel.output_file`` - the parent folder of the path;
    3. ``step12_panel.output_dir_widget`` - the value itself.

    Relative paths from sources 1 and 2 are resolved against the working
    directory. Fields the user has already filled in are never overwritten.
    Any exception is printed and swallowed so that auto-fill cannot break
    the caller.

    Args:
        work_dir: Working directory. When falsy, a previously cached
            ``self.work_dir`` is kept (or ``None`` is stored if there is none).
        pipeline: Unused; kept for interface compatibility.
    """

    def _read_path(panel_name, widget_name):
        """Read the text of ``main_window.<panel>.<widget>``, or "" if absent."""
        panel = getattr(main_window, panel_name, None)
        widget = getattr(panel, widget_name, None)
        if hasattr(widget, 'get_path'):
            return widget.get_path() or ""
        if hasattr(widget, 'text'):
            return widget.text() or ""
        return ""

    def _anchor(path):
        """Resolve a relative path against the working directory."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    try:
        if work_dir:
            self.work_dir = work_dir
        elif not getattr(self, 'work_dir', None):
            self.work_dir = None

        main_window = self.window()
        if not main_window:
            return

        # 1. Preferred source: output file of the step-11 prediction panel.
        pred_dir = None
        ml_output = _read_path('step11_prediction_panel', 'output_file')
        if ml_output:
            base_dir = Path(_anchor(ml_output)).parent
            ml_dir = base_dir / "Machine_Learning_Prediction"
            pred_dir = str(ml_dir) if ml_dir.exists() else str(base_dir)

        # 2. Fallback: output file of the step-11 panel.
        if not pred_dir:
            step11_output = _read_path('step11_panel', 'output_file')
            if step11_output:
                pred_dir = str(Path(_anchor(step11_output)).parent)

        # 3. Fallback: output directory of the step-12 panel (used as-is).
        if not pred_dir:
            step12_output = _read_path('step12_panel', 'output_dir_widget')
            if step12_output:
                pred_dir = step12_output

        # Fill the CSV folder only when the user has not set one, then
        # switch the panel to folder (batch) mode.
        if pred_dir:
            if not (self.prediction_csv_dir_edit.text() or "").strip():
                self.prediction_csv_dir_edit.setText(pred_dir)
                self.mode_folder_rb.setChecked(True)

        if self.work_dir:
            # 4. Default output folder; it is created on disk right away.
            output_dir = os.path.join(self.work_dir, "14_visualization")
            os.makedirs(output_dir, exist_ok=True)
            existing_out = self.output_dir.get_path()
            if not existing_out or not existing_out.strip():
                self.output_dir.set_path(output_dir)

            # 5. Look for a boundary shapefile in a few conventional places.
            root = Path(self.work_dir)
            candidates = (
                root.parent / "input-test" / "roi.shp",
                root / "roi.shp",
                root.parent / "roi.shp",
            )
            found = next(
                (c for c in candidates if c.exists() and c.suffix.lower() == ".shp"),
                None,
            )
            if not (self.boundary_file.get_path() or "").strip():
                if found:
                    self.boundary_file.set_path(str(found))
                else:
                    # Nothing found: keep the field empty and ask for a manual pick.
                    self.boundary_file.set_path("")
                    print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
    except Exception as e:
        # Best-effort auto-fill: report the problem and carry on.
        import traceback
        print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
        traceback.print_exc()
def browse_output_dir(self):
    """Let the user pick the folder that receives the distribution maps.

    The dialog starts in ``<work_dir>/14_visualization`` when a working
    directory is known. A cancelled dialog leaves the field untouched.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "14_visualization")
    chosen = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", start_dir)
    if not chosen:
        return
    self.output_dir.set_path(chosen)
def run_step(self):
    """Run step 14 on its own, for one CSV or for a whole folder.

    The boundary file is validated first, then the nearest ancestor widget
    that offers ``run_single_step`` is located. In single-CSV mode the
    configuration is handed to that ancestor; in folder mode a
    ``Step14BatchThread`` is started over every CSV found, and the run
    button stays disabled until the thread finishes.
    """
    worker = self._batch_thread
    if worker and worker.isRunning():
        QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
        return

    boundary_shp_path = self.boundary_file.get_path()
    if not boundary_shp_path:
        QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
        return
    if not os.path.exists(boundary_shp_path):
        QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
        return

    # Walk up the parent chain; the loop ends on a host or on no parent.
    host = self.parent()
    while host and not hasattr(host, 'run_single_step'):
        host = host.parent()
    if not host:
        QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
        return

    if not self.mode_folder_rb.isChecked():
        # Single-CSV mode: validate the file and delegate to the host.
        csv_path = (self.prediction_csv_file.get_path() or "").strip()
        if not csv_path:
            QMessageBox.warning(
                self,
                "输入验证失败",
                "请选择「预测结果 CSV」文件或切换到「文件夹批量」。",
            )
            return
        if not os.path.isfile(csv_path):
            QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
            return
        host.run_single_step('step14', {'step14': self.get_config()})
        return

    # Folder mode: gather the CSV files and process them on a worker thread.
    csv_list = self._collect_csv_paths_from_folder()
    if not csv_list:
        QMessageBox.warning(
            self,
            "输入验证失败",
            "所选文件夹中未找到 .csv 文件,或目录无效。\n"
            "可勾选「包含子文件夹」以递归扫描。",
        )
        return
    if not PIPELINE_AVAILABLE:
        QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
        return

    work_dir = str(getattr(host, "work_dir", None) or "./work_dir")
    base_kw = self._step14_base_pipeline_kwargs()
    out_dir_opt = (self.output_dir.get_path() or "").strip() or None

    self.run_button.setEnabled(False)
    thread = Step14BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
    self._batch_thread = thread

    def _forward_log(text, level):
        # Relay worker log lines to the host window when it can log.
        if hasattr(host, "log_message"):
            host.log_message(text, level)

    thread.log_message.connect(_forward_log, Qt.QueuedConnection)
    thread.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection)
    thread.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection)
    thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
    thread.start()

    if hasattr(host, "log_message"):
        host.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV工作目录 {work_dir}", "info")
def _on_step14_batch_ok(self, n: int):
    """Slot: announce a finished batch and log it on the nearest logging ancestor.

    Args:
        n: Number of files reported by the batch thread.
    """
    QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
    logger_host = self.parent()
    while logger_host and not hasattr(logger_host, "log_message"):
        logger_host = logger_host.parent()
    if not logger_host:
        return
    logger_host.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
def _on_step14_batch_fail(self, err: str):
    """Slot: report an aborted batch and log the error on the nearest logging ancestor.

    The dialog shows at most the first 900 characters of ``err``; the log
    receives the full text.
    """
    QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
    logger_host = self.parent()
    while logger_host and not hasattr(logger_host, "log_message"):
        logger_host = logger_host.parent()
    if not logger_host:
        return
    logger_host.log_message(err, "error")

View File

@ -1,225 +0,0 @@
import os
import sys
import pandas as pd
from pathlib import Path
from typing import Dict, List, Union
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QGridLayout,
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
def get_resource_path(relative_path: str) -> str:
    """Locate a bundled resource in frozen and development runs.

    Lookup order:

    1. If ``sys._MEIPASS`` exists: ``<_MEIPASS>/_internal/<relative_path>``
       when that file exists, otherwise ``<_MEIPASS>/<relative_path>``
       (returned without an existence check).
    2. ``<folder of sys.executable>/_internal/<relative_path>`` when it exists.
    3. Development fallback: ``model/<basename of relative_path>`` beside
       the parent folder of this file. Only the base name is kept here.
    """
    if hasattr(sys, '_MEIPASS'):
        bundle_root = sys._MEIPASS
        nested = os.path.join(bundle_root, '_internal', relative_path)
        return nested if os.path.exists(nested) else os.path.join(bundle_root, relative_path)

    beside_exe = os.path.join(os.path.dirname(sys.executable), '_internal', relative_path)
    if os.path.exists(beside_exe):
        return beside_exe

    model_dir = Path(__file__).resolve().parent.parent / "model"
    return str(model_dir / os.path.basename(relative_path))
class Step5_5Panel(QWidget):
    """Panel for step 5.5: choose which built-in index formulas to compute.

    The formula names come from the ``Formula_Name`` column of the bundled
    ``waterindex.csv`` and are shown as a grid of checkboxes. The panel also
    holds the input CSV, the output CSV and an enable switch.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Cached working directory; set by update_from_config(). Initialised
        # here so that reading it before the first update cannot fail.
        self.work_dir = None
        self.index_checkboxes: Dict[str, QCheckBox] = {}
        # Location of the bundled formula table (see get_resource_path).
        self.builtin_formula_path = get_resource_path("waterindex.csv")
        self.init_ui()
        # Populate the checkbox grid right away (synchronous, no delay).
        self._auto_load_formulas()

    def init_ui(self):
        """Build the widget tree: source path, input, formula grid, output, run."""
        main_layout = QVBoxLayout()
        main_layout.setContentsMargins(20, 20, 20, 20)
        main_layout.setSpacing(10)

        # 1. Read-only display of the built-in formula file.
        path_group = QGroupBox("公式配置源 (内置)")
        path_layout = QVBoxLayout()
        self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
        self.formula_csv_widget.set_path(self.builtin_formula_path)
        self.formula_csv_widget.set_read_only(True)
        # Greyed-out look to signal that the path is not editable.
        self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
        path_layout.addWidget(self.formula_csv_widget)
        path_group.setLayout(path_layout)
        main_layout.addWidget(path_group)

        # 2. Input sample data.
        input_group = QGroupBox("输入样本数据")
        input_layout = QVBoxLayout()
        self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
        input_layout.addWidget(self.training_data_widget)
        input_group.setLayout(input_layout)
        main_layout.addWidget(input_group)

        # 3. Formula selection area.
        self.formula_group = QGroupBox("待计算水质指数勾选")
        formula_outer_layout = QVBoxLayout()
        btn_layout = QHBoxLayout()
        self.select_all_btn = QPushButton("全选")
        self.deselect_all_btn = QPushButton("清空")
        self.select_all_btn.clicked.connect(self.select_all_formulas)
        self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
        btn_layout.addWidget(self.select_all_btn)
        btn_layout.addWidget(self.deselect_all_btn)
        btn_layout.addStretch()
        self.refresh_button = QPushButton("手动重新加载公式")
        self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
        btn_layout.addWidget(self.refresh_button)
        formula_outer_layout.addLayout(btn_layout)

        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setMinimumHeight(300)  # keep the list from collapsing
        self.scroll_content = QWidget()
        self.formula_layout = QGridLayout(self.scroll_content)
        self.formula_layout.setAlignment(Qt.AlignTop)
        scroll.setWidget(self.scroll_content)
        formula_outer_layout.addWidget(scroll)
        self.formula_group.setLayout(formula_outer_layout)
        main_layout.addWidget(self.formula_group)

        # 4. Output file, enable switch and run button.
        output_group = QGroupBox("结果输出")
        output_layout = QVBoxLayout()
        self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
        output_layout.addWidget(self.output_file_widget)
        output_group.setLayout(output_layout)
        main_layout.addWidget(output_group)

        self.enable_checkbox = QCheckBox("启用计算流程")
        self.enable_checkbox.setChecked(True)
        main_layout.addWidget(self.enable_checkbox)

        self.run_button = QPushButton("立即执行计算")
        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_button.setMinimumHeight(40)
        self.run_button.clicked.connect(self.run_step)
        main_layout.addWidget(self.run_button)
        self.setLayout(main_layout)

    def _auto_load_formulas(self):
        """Load the formula list at start-up without showing any dialog."""
        if os.path.exists(self.builtin_formula_path):
            self.refresh_formulas(silent=True)
        else:
            print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")

    def refresh_formulas(self, silent=False):
        """Rebuild the checkbox grid from the built-in formula CSV.

        The file is tried with several encodings until one yields a
        ``Formula_Name`` column. Every distinct, non-blank name becomes a
        checked checkbox, laid out three per row.

        Args:
            silent: When true, problems are not reported with message boxes.
        """
        path = self.builtin_formula_path
        if not os.path.exists(path):
            if not silent:
                QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
            return
        try:
            # Drop the previous checkboxes.
            for i in reversed(range(self.formula_layout.count())):
                widget = self.formula_layout.itemAt(i).widget()
                if widget:
                    widget.deleteLater()
            self.index_checkboxes.clear()

            # Try each encoding; stop at the first read exposing the column.
            df = None
            for encoding in ('utf-8', 'gbk', 'utf-8-sig'):
                try:
                    df = pd.read_csv(path, encoding=encoding)
                except Exception:
                    # Wrong encoding or unparsable content: try the next one.
                    continue
                if 'Formula_Name' in df.columns:
                    break

            # df stays None when no encoding could read the file at all.
            if df is None or 'Formula_Name' not in df.columns:
                if not silent:
                    QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name'")
                return

            for raw_name in df['Formula_Name'].dropna().unique().tolist():
                name = str(raw_name).strip()
                # Skip blanks and names that only differ by surrounding
                # whitespace, which would otherwise add an untracked box.
                if not name or name in self.index_checkboxes:
                    continue
                row, col = divmod(len(self.index_checkboxes), 3)
                cb = QCheckBox(name)
                cb.setChecked(True)
                self.index_checkboxes[name] = cb
                self.formula_layout.addWidget(cb, row, col)

            # Let the scroll content pick up its new size.
            self.scroll_content.adjustSize()
            print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
        except Exception as e:
            if not silent:
                QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")

    def select_all_formulas(self):
        """Tick every formula checkbox."""
        for cb in self.index_checkboxes.values():
            cb.setChecked(True)

    def deselect_all_formulas(self):
        """Untick every formula checkbox."""
        for cb in self.index_checkboxes.values():
            cb.setChecked(False)

    def get_config(self):
        """Return the panel state as a configuration dict."""
        selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
        return {
            'training_csv_path': self.training_data_widget.get_path(),
            'formula_csv_file': self.builtin_formula_path,
            'formula_names': selected,
            'output_file': self.output_file_widget.get_path(),
            'enabled': self.enable_checkbox.isChecked()
        }

    def set_config(self, config):
        """Restore the panel state from a configuration dict.

        Missing keys leave their widget unchanged, except ``enabled`` which
        defaults to ``True``.
        """
        if 'training_csv_path' in config:
            self.training_data_widget.set_path(config['training_csv_path'])
        if 'formula_names' in config:
            sel = set(config['formula_names'])
            for n, cb in self.index_checkboxes.items():
                cb.setChecked(n in sel)
        if 'output_file' in config:
            self.output_file_widget.set_path(config['output_file'])
        self.enable_checkbox.setChecked(config.get('enabled', True))

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the input and output paths.

        The input is taken from ``step5_panel.output_file`` of the top-level
        window when it is non-empty; the output is set to
        ``<work_dir>/6_water_quality_indices/training_spectra_indices.csv``.

        Args:
            work_dir: Working directory; when falsy the cached one is kept.
            pipeline: Unused by this method.
        """
        if work_dir:
            self.work_dir = work_dir
        main = self.window()
        if hasattr(main, 'step5_panel'):
            p5 = main.step5_panel.output_file.get_path()
            if p5:
                # Relative paths are resolved against the working directory.
                if not os.path.isabs(p5):
                    p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
                self.training_data_widget.set_path(p5)
        if self.work_dir:
            out = os.path.join(
                self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv"
            ).replace('\\', '/')
            self.output_file_widget.set_path(out)

    def run_step(self):
        """Run this step on its own through the ancestor offering ``run_single_step``."""
        config = self.get_config()
        if not config['training_csv_path']:
            QMessageBox.warning(self, "提示", "请先选择输入数据")
            return
        parent = self.parent()
        while parent and not hasattr(parent, 'run_single_step'):
            parent = parent.parent()
        if parent:
            parent.run_single_step('step5_5', {'step5_5': config})
        else:
            # Tell the user instead of silently doing nothing.
            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -1,374 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6_75 面板 - 自定义回归分析
"""
import os
from typing import Dict
import pandas as pd
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
QScrollArea, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step6_75Panel(QWidget):
    """Step 6.75 panel: custom regression analysis.

    The user picks a CSV, ticks independent (X) and dependent (Y) columns
    and one or more regression methods; the resulting configuration is
    handed to the ancestor widget that offers ``run_single_step``.
    """

    # Case-insensitive substrings that pre-tick a column as X or as Y.
    _X_KEYWORDS = ('index', 'ratio', 'normalized', 'nd', 'b')
    _Y_KEYWORDS = ('chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity')
    # Offered regression methods and the subset ticked by default.
    _REGRESSION_METHODS = (
        'linear', 'exponential', 'power', 'logarithmic',
        'polynomial', 'hyperbolic', 'sigmoidal',
    )
    _DEFAULT_METHODS = ('linear', 'exponential', 'power', 'logarithmic')

    def __init__(self, parent=None):
        super().__init__(parent)
        self.x_column_checkboxes: Dict[str, QCheckBox] = {}
        self.y_column_checkboxes: Dict[str, QCheckBox] = {}
        self.method_checkboxes: Dict[str, QCheckBox] = {}
        self.csv_columns = []
        self.init_ui()

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------
    def _build_toggle_row(self, boxes):
        """Create a "select all / select none" button row for ``boxes``.

        Returns:
            ``(row_layout, select_button, deselect_button)``.
        """
        row = QHBoxLayout()
        select_btn = QPushButton("全选")
        deselect_btn = QPushButton("全不选")
        select_btn.clicked.connect(lambda: self.toggle_checkboxes(boxes, True))
        deselect_btn.clicked.connect(lambda: self.toggle_checkboxes(boxes, False))
        row.addWidget(select_btn)
        row.addWidget(deselect_btn)
        row.addStretch()
        return row, select_btn, deselect_btn

    def _build_column_group(self, title, min_height, max_height, boxes):
        """Create a scrollable checkbox group with its toggle buttons.

        Returns:
            ``(group_box, grid_layout, select_button, deselect_button)``.
        """
        group = QGroupBox(title)
        box_layout = QVBoxLayout()
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setMinimumHeight(min_height)
        scroll.setMaximumHeight(max_height)
        holder = QWidget()
        grid = QGridLayout()
        holder.setLayout(grid)
        scroll.setWidget(holder)
        box_layout.addWidget(scroll)
        row, select_btn, deselect_btn = self._build_toggle_row(boxes)
        box_layout.addLayout(row)
        group.setLayout(box_layout)
        return group, grid, select_btn, deselect_btn

    def init_ui(self):
        """Build the widget tree of the panel."""
        layout = QVBoxLayout()
        hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
        hint.setStyleSheet("color: #666; font-size: 11px;")
        layout.addWidget(hint)

        # Input CSV; editing the path refreshes the column lists.
        csv_group = QGroupBox("数据文件")
        csv_layout = QVBoxLayout()
        self.csv_file = FileSelectWidget(
            "输入CSV文件:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
        csv_layout.addWidget(self.csv_file)
        self.refresh_btn = QPushButton("刷新列信息")
        self.refresh_btn.clicked.connect(self.refresh_csv_columns)
        csv_layout.addWidget(self.refresh_btn)
        csv_group.setLayout(csv_layout)
        layout.addWidget(csv_group)

        # Independent-variable columns.
        x_group, self.x_columns_layout, self.x_select_all, self.x_deselect_all = \
            self._build_column_group("自变量列选择 (可多选)", 250, 350, self.x_column_checkboxes)
        layout.addWidget(x_group)

        # Dependent-variable columns.
        y_group, self.y_columns_layout, self.y_select_all, self.y_deselect_all = \
            self._build_column_group("因变量列选择 (可多选)", 200, 300, self.y_column_checkboxes)
        layout.addWidget(y_group)

        # Regression methods, three per row.
        method_group = QGroupBox("回归方法选择 (可多选)")
        method_layout = QVBoxLayout()
        method_grid = QGridLayout()
        for i, method in enumerate(self._REGRESSION_METHODS):
            checkbox = QCheckBox(method)
            if method in self._DEFAULT_METHODS:
                checkbox.setChecked(True)
            self.method_checkboxes[method] = checkbox
            method_grid.addWidget(checkbox, i // 3, i % 3)
        method_layout.addLayout(method_grid)
        method_row, self.method_select_all, self.method_deselect_all = \
            self._build_toggle_row(self.method_checkboxes)
        method_layout.addLayout(method_row)
        method_group.setLayout(method_layout)
        layout.addWidget(method_group)

        # Output directory; filled by update_from_config() from work_dir.
        output_group = QGroupBox("输出设置")
        output_layout = QFormLayout()
        self.output_dir = QLineEdit()
        self.output_dir.setText("")
        output_layout.addRow("输出目录名:", self.output_dir)
        output_group.setLayout(output_layout)
        layout.addWidget(output_group)

        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        self.run_button = QPushButton("独立运行此步骤")
        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_button.clicked.connect(self.run_step)
        layout.addWidget(self.run_button)
        layout.addStretch()
        self.setLayout(layout)

    # ------------------------------------------------------------------
    # Checkbox helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _checked_keys(boxes):
        """Return the keys of ``boxes`` whose checkbox is ticked, in order."""
        return [key for key, checkbox in boxes.items() if checkbox.isChecked()]

    def toggle_checkboxes(self, checkboxes_dict, checked):
        """Set every checkbox in ``checkboxes_dict`` to ``checked``."""
        for checkbox in checkboxes_dict.values():
            checkbox.setChecked(checked)

    # ------------------------------------------------------------------
    # Column discovery
    # ------------------------------------------------------------------
    def on_csv_file_changed(self):
        """Slot: refresh the column lists whenever the CSV path text changes."""
        self.refresh_csv_columns()

    def refresh_csv_columns(self):
        """Re-read the CSV header and rebuild the column checkboxes.

        A missing, non-existent or unreadable file yields empty lists.
        """
        csv_path = self.csv_file.get_path()
        columns = []
        if csv_path and os.path.exists(csv_path):
            try:
                # nrows=0 reads the header only.
                columns = list(pd.read_csv(csv_path, nrows=0).columns)
            except Exception as e:
                print(f"读取CSV列信息失败: {e}")
        self.csv_columns = columns
        self.update_column_widgets()

    def update_column_widgets(self):
        """Replace the X and Y checkboxes with one pair per CSV column.

        Columns whose lower-cased name contains one of the keyword
        substrings are ticked by default.
        """
        for boxes in (self.x_column_checkboxes, self.y_column_checkboxes):
            for checkbox in boxes.values():
                checkbox.setParent(None)
                # Schedule the detached widget for deletion explicitly.
                checkbox.deleteLater()
            boxes.clear()
        if not self.csv_columns:
            return

        for i, col in enumerate(self.csv_columns):
            checkbox = QCheckBox(col)
            if any(keyword in col.lower() for keyword in self._X_KEYWORDS):
                checkbox.setChecked(True)
            self.x_column_checkboxes[col] = checkbox
            self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)

        for i, col in enumerate(self.csv_columns):
            checkbox = QCheckBox(col)
            if any(keyword in col.lower() for keyword in self._Y_KEYWORDS):
                checkbox.setChecked(True)
            self.y_column_checkboxes[col] = checkbox
            self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)

        self.x_columns_layout.update()
        self.y_columns_layout.update()

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    def get_config(self):
        """Return the panel state as a configuration dict.

        ``methods`` is the string ``'all'`` when no method is ticked;
        empty ``csv_path`` and ``output_dir`` are reported as ``None``.
        """
        selected_methods = self._checked_keys(self.method_checkboxes)
        return {
            'csv_path': self.csv_file.get_path() or None,
            'x_columns': self._checked_keys(self.x_column_checkboxes),
            'y_columns': self._checked_keys(self.y_column_checkboxes),
            'methods': selected_methods if selected_methods else 'all',
            'output_dir': self.output_dir.text().strip() or None,
            'enabled': self.enable_checkbox.isChecked()
        }

    def set_config(self, config):
        """Restore the panel state from a configuration dict.

        Missing keys leave their widgets unchanged.
        """
        if 'csv_path' in config:
            # get_config() reports an empty path as None; hand the widget a
            # string in that case. NOTE(review): FileSelectWidget.set_path is
            # not visible here - assumed to expect a str.
            self.csv_file.set_path(config['csv_path'] or "")
            self.refresh_csv_columns()
        if 'x_columns' in config:
            selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
            for col, checkbox in self.x_column_checkboxes.items():
                checkbox.setChecked(col in selected_x)
        if 'y_columns' in config:
            selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
            for col, checkbox in self.y_column_checkboxes.items():
                checkbox.setChecked(col in selected_y)
        if 'methods' in config:
            methods = config['methods']
            if isinstance(methods, list):
                selected_methods = set(methods)
            elif methods == 'all':
                selected_methods = set(self.method_checkboxes.keys())
            else:
                selected_methods = set()
            for method, checkbox in self.method_checkboxes.items():
                checkbox.setChecked(method in selected_methods)
        if 'output_dir' in config:
            self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
        if 'enabled' in config:
            self.enable_checkbox.setChecked(config['enabled'])

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the input CSV and the output directory.

        The input comes from ``step5_panel.output_file`` of the top-level
        window; the output directory is ``<work_dir>/9_Custom_Regression_Modeling``
        and is created on disk. Fields already filled in are kept. Any
        exception is printed and swallowed.

        Args:
            work_dir: Working directory; when falsy the cached one is kept.
            pipeline: Unused; kept for interface compatibility.
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None

            # 1. Input CSV from the step-5 panel.
            main_window = self.window()
            if main_window and hasattr(main_window, 'step5_panel'):
                step5_widget = getattr(main_window.step5_panel, 'output_file', None)
                step5_output_path = ""
                if hasattr(step5_widget, 'get_path'):
                    step5_output_path = step5_widget.get_path() or ""
                elif hasattr(step5_widget, 'text'):
                    step5_output_path = step5_widget.text() or ""
                if step5_output_path:
                    # Relative paths are resolved against the working directory.
                    if not os.path.isabs(step5_output_path):
                        step5_output_path = os.path.join(
                            self.work_dir or '', step5_output_path
                        ).replace('\\', '/')
                    existing = self.csv_file.get_path()
                    if not existing or not existing.strip():
                        self.csv_file.set_path(step5_output_path)

            # 2. Default output directory.
            if self.work_dir:
                output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
                os.makedirs(output_dir, exist_ok=True)
                if not self.output_dir.text().strip():
                    self.output_dir.setText(output_dir)
        except Exception as e:
            # Best-effort auto-fill: report the problem and carry on.
            import traceback
            print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    def run_step(self):
        """Validate the inputs and run step 6.75 on its own."""
        csv_path = self.csv_file.get_path()
        if not csv_path:
            QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
            return
        if not os.path.exists(csv_path):
            QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
            return
        if not self._checked_keys(self.x_column_checkboxes):
            QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
            return
        if not self._checked_keys(self.y_column_checkboxes):
            QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
            return
        # Checked directly because get_config() maps "nothing ticked" to 'all'.
        if not self._checked_keys(self.method_checkboxes):
            QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
            return

        config = self.get_config()
        parent = self.parent()
        while parent and not hasattr(parent, 'run_single_step'):
            parent = parent.parent()
        if parent and hasattr(parent, 'run_single_step'):
            parent.run_single_step('step6_75', {'step6_75': config})
        else:
            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -1,415 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6 面板 - 机器学习建模
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
QPushButton, QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# ============================================================
# Display-name tables (internal key -> text shown in the UI)
# ============================================================
# Preprocessing methods: internal key -> display text
PREPROC_CHINESE = {
'None': '无 (None)',
'MMS': '最小-最大归一化 (MMS)',
'SS': '标度化 (SS)',
'SNV': '标准正态变换 (SNV)',
'MA': '移动平均 (MA)',
'SG': 'Savitzky-Golay (SG)',
'MSC': '多元散射校正 (MSC)',
'D1': '一阶导数 (D1)',
'D2': '二阶导数 (D2)',
'DT': '去趋势 (DT)',
'CT': '中心化 (CT)',
}
# Model types: internal key -> display text
MODEL_CHINESE = {
# Linear models
'LinearRegression': '多元线性回归 (MLR)',
'Ridge': '岭回归 (Ridge)',
'Lasso': '套索回归 (Lasso)',
'ElasticNet': '弹性网络 (ElasticNet)',
'PLS': '偏最小二乘 (PLSR)',
# Tree models
'DecisionTree': '决策树 (CART)',
'RF': '随机森林 (RF)',
'ExtraTrees': '极端随机树 (ET)',
'XGBoost': '极值梯度提升 (XGBoost)',
'LightGBM': '轻量梯度提升 (LightGBM)',
'CatBoost': '类别梯度提升 (CatBoost)',
# Ensemble learning
'GradientBoosting': '梯度提升树 (GBDT)',
'AdaBoost': '自适应提升 (AdaBoost)',
# Other models
'SVR': '支持向量回归 (SVR)',
'KNN': 'K近邻回归 (KNN)',
'MLP': '多层感知机 (BP神经网络)',
}
# Data-split methods: internal key -> display text
SPLIT_CHINESE = {
'spxy': 'SPXY 算法 (考量X-Y空间)',
'ks': 'KS 算法 (考量X空间)',
'random': '随机划分 (Random)',
}
class Step6Panel(QWidget):
"""步骤6机器学习建模"""
def __init__(self, parent=None):
    """Create the panel under ``parent`` and build its widgets."""
    super().__init__(parent)
    self.init_ui()
def init_ui(self):
    """Assemble the panel: training CSV, model page, output file, switches.

    The widgets are stacked top to bottom in the order they are created,
    followed by a stretch that pushes them to the top.
    """
    # Training data, needed when the step is run on its own.
    self.training_csv_file = FileSelectWidget(
        "训练数据:",
        "CSV Files (*.csv);;All Files (*.*)"
    )

    # Page holding all machine-learning options.
    self.ml_page = QWidget()
    self.create_ml_page()

    # Output file; the browse button is rewired to this panel's own dialog.
    self.output_path = FileSelectWidget(
        "输出文件:",
        "CSV Files (*.csv);;All Files (*.*)",
        mode="save"
    )
    self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
    browse = self.output_path.browse_btn.clicked
    browse.disconnect()
    browse.connect(self.browse_output_path)

    # Step switch, off by default.
    self.enable_checkbox = QCheckBox("启用此步骤")
    self.enable_checkbox.setChecked(False)

    # Stand-alone run button.
    self.run_btn = QPushButton("独立运行此步骤")
    self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
    self.run_btn.clicked.connect(self.run_step)

    column = QVBoxLayout()
    for part in (
        self.training_csv_file,
        self.ml_page,
        self.output_path,
        self.enable_checkbox,
        self.run_btn,
    ):
        column.addWidget(part)
    column.addStretch()
    self.setLayout(column)
def create_ml_page(self):
    """Fill ``self.ml_page`` with the training options.

    Four groups are created in order: training parameters, preprocessing
    methods, model types (under category headings) and data-split methods.
    Each multi-select group gets "select all / select none" buttons, and
    every checkbox starts unticked.
    """

    def toggle_row(attr_name):
        # Button row acting on the checkbox dict stored under attr_name;
        # the dict is looked up when the button is clicked.
        bar = QHBoxLayout()
        pick_all = QPushButton("全选")
        pick_none = QPushButton("全不选")
        pick_all.clicked.connect(lambda: self._toggle_checkboxes(getattr(self, attr_name), True))
        pick_none.clicked.connect(lambda: self._toggle_checkboxes(getattr(self, attr_name), False))
        bar.addWidget(pick_all)
        bar.addWidget(pick_none)
        bar.addStretch()
        return bar

    page = QVBoxLayout()

    # --- Training parameters -------------------------------------------
    params_group = QGroupBox("训练参数")
    form = QFormLayout()
    self.feature_start = QLineEdit()
    self.feature_start.setText("374.285004")
    form.addRow("特征起始列:", self.feature_start)
    self.cv_folds = QSpinBox()
    self.cv_folds.setRange(2, 10)
    self.cv_folds.setValue(3)
    form.addRow("交叉验证折数:", self.cv_folds)
    params_group.setLayout(form)
    page.addWidget(params_group)

    # --- Preprocessing methods, four per row ---------------------------
    preproc_group = QGroupBox("预处理方法 (可多选)")
    preproc_box = QVBoxLayout()
    preproc_grid = QGridLayout()
    self.preproc_checkboxes = {}
    for pos, key in enumerate(
        ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
    ):
        box = QCheckBox(PREPROC_CHINESE.get(key, key))
        box.setChecked(False)
        self.preproc_checkboxes[key] = box
        grid_row, grid_col = divmod(pos, 4)
        preproc_grid.addWidget(box, grid_row, grid_col)
    preproc_buttons = toggle_row('preproc_checkboxes')
    preproc_box.addLayout(preproc_grid)
    preproc_box.addLayout(preproc_buttons)
    preproc_group.setLayout(preproc_box)
    page.addWidget(preproc_group)

    # --- Model types, grouped under category headings ------------------
    model_group = QGroupBox("模型类型 (可多选)")
    model_box = QVBoxLayout()
    model_grid = QGridLayout()
    self.model_checkboxes = {}
    categories = [
        ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
        ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
        ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
        ("【其他模型】", ['SVR', 'KNN', 'MLP'])
    ]
    next_row = 0
    for heading, keys in categories:
        heading_label = QLabel(f"<b>{heading}</b>")
        heading_label.setStyleSheet(
            f"background-color: {ModernStylesheet.COLORS['hover']}; "
            f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
            f"border-radius: 3px;"
        )
        # The heading spans all four columns.
        model_grid.addWidget(heading_label, next_row, 0, 1, 4)
        next_row += 1
        for pos, key in enumerate(keys):
            box = QCheckBox(MODEL_CHINESE.get(key, key))
            box.setChecked(False)
            self.model_checkboxes[key] = box
            model_grid.addWidget(box, next_row + pos // 4, pos % 4)
        # Skip the rows filled by complete groups of four plus one more row.
        next_row += len(keys) // 4 + 1
    model_buttons = toggle_row('model_checkboxes')
    model_box.addLayout(model_grid)
    model_box.addLayout(model_buttons)
    model_group.setLayout(model_box)
    page.addWidget(model_group)

    # --- Data-split methods, all on one row ----------------------------
    split_group = QGroupBox("数据划分方法 (可多选)")
    split_box = QVBoxLayout()
    split_grid = QGridLayout()
    self.split_checkboxes = {}
    for pos, key in enumerate(['spxy', 'ks', 'random']):
        box = QCheckBox(SPLIT_CHINESE.get(key, key))
        box.setChecked(False)
        self.split_checkboxes[key] = box
        split_grid.addWidget(box, 0, pos)
    split_buttons = toggle_row('split_checkboxes')
    split_box.addLayout(split_grid)
    split_box.addLayout(split_buttons)
    split_group.setLayout(split_box)
    page.addWidget(split_group)

    self.ml_page.setLayout(page)
def _toggle_checkboxes(self, checkboxes_dict, checked):
    """Apply a single checked state to every checkbox in a mapping.

    Args:
        checkboxes_dict: Mapping whose values expose ``setChecked``.
        checked: State handed to each ``setChecked`` call.
    """
    for box in checkboxes_dict.values():
        box.setChecked(checked)
def _get_default_work_dir(self):
    """Return the working directory as a string, or "" when none is known.

    The panel's own ``work_dir`` attribute wins when it is set and truthy;
    otherwise the ``work_dir`` of the object returned by ``self.window()``
    is used.
    """
    own = getattr(self, 'work_dir', None)
    if own:
        return str(own)
    # Only consult the top-level window when the panel has no value itself.
    top = self.window()
    inherited = getattr(top, 'work_dir', None) if top else None
    return str(inherited) if inherited else ""
def browse_output_path(self):
    """Open a save-file dialog and store the chosen path in ``output_path``.

    The dialog starts at the directory/file name currently shown in the
    output widget. When that directory is empty or does not exist, the
    ``6_water_quality_indices`` folder under the default working directory
    is used instead (and created if missing). Nothing is stored when the
    dialog returns an empty path.
    """
    current = self.output_path.get_path().strip()
    # os.path.split yields the same (dirname, basename) pair, and ("", "")
    # for an empty string.
    start_dir, start_name = os.path.split(current)

    if not (start_dir and os.path.isdir(start_dir)):
        # Fall back to the indices directory under the working directory.
        base = self._get_default_work_dir()
        if base:
            start_dir = os.path.join(base, "6_water_quality_indices")
            if not os.path.isdir(start_dir):
                os.makedirs(start_dir, exist_ok=True)
        else:
            start_dir = ""

    suggestion = os.path.join(start_dir, start_name) if start_name else start_dir
    chosen, _selected_filter = QFileDialog.getSaveFileName(
        self, "保存输出文件", suggestion,
        "CSV Files (*.csv);;All Files (*.*)"
    )
    if chosen:
        self.output_path.set_path(chosen)
def get_config(self):
    """Collect the panel's current settings into a dict.

    Returns:
        A dict with ``feature_start_column`` (raw text), ``cv_folds`` and the
        three selection lists. An empty selection is replaced by a default:
        ``['None']`` for preprocessing, ``['SVR']`` for models and
        ``['random']`` for split methods. ``training_csv_path`` and
        ``output_path`` are included only when their widgets hold a
        non-empty path.
    """
    def _ticked(boxes):
        # Keys whose checkbox is currently checked, in mapping order.
        return [key for key, box in boxes.items() if box.isChecked()]

    chosen_preproc = _ticked(self.preproc_checkboxes)
    chosen_models = _ticked(self.model_checkboxes)
    chosen_splits = _ticked(self.split_checkboxes)

    config = {
        'feature_start_column': self.feature_start.text(),
        'preprocessing_methods': chosen_preproc or ['None'],
        'model_names': chosen_models or ['SVR'],
        'split_methods': chosen_splits or ['random'],
        'cv_folds': self.cv_folds.value()
    }

    # Optional paths are only emitted when the user (or auto-fill) set them.
    train_path = self.training_csv_file.get_path()
    if train_path:
        config['training_csv_path'] = train_path
    out_path = self.output_path.get_path()
    if out_path:
        config['output_path'] = out_path
    return config
def set_config(self, config):
    """Push values from a config dict into the panel widgets.

    Only keys present in ``config`` are applied. For the three selection
    lists, every checkbox in the group is set to whether its key appears in
    the configured list.

    Args:
        config: Dict using the same keys that ``get_config`` produces.
    """
    if 'feature_start_column' in config:
        self.feature_start.setText(str(config['feature_start_column']))
    if 'cv_folds' in config:
        self.cv_folds.setValue(config['cv_folds'])

    # (config key, attribute holding the checkbox mapping); the attribute is
    # looked up lazily so absent keys touch nothing.
    checkbox_groups = (
        ('preprocessing_methods', 'preproc_checkboxes'),
        ('model_names', 'model_checkboxes'),
        ('split_methods', 'split_checkboxes'),
    )
    for key, attr_name in checkbox_groups:
        if key not in config:
            continue
        wanted = config[key]
        for name, box in getattr(self, attr_name).items():
            box.setChecked(name in wanted)

    if 'training_csv_path' in config:
        self.training_csv_file.set_path(config['training_csv_path'])
    if 'output_path' in config:
        self.output_path.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the training-data and output paths from the surrounding GUI.

    The training CSV is taken from the main window's ``step5_panel``: first
    from its ``output_file`` widget and, when that yields nothing, from the
    dict returned by its ``get_config()`` (keys ``training_csv_path``,
    ``output_file``, ``csv_path``, ``output_csv``, in that order). The
    original code placed this fallback in an ``elif`` repeating the ``if``
    condition, so it could never run; here it runs whenever the widget gives
    no path. Relative paths are joined onto the working directory.

    The output path is then derived as
    ``{work_dir}/6_water_quality_indices/{training basename}_indices.csv``
    (or ``water_quality_indices.csv`` when no training CSV is set); the
    directory is created if needed. Without a working directory the output
    path is cleared.

    Args:
        work_dir: Working directory; when falsy the previously cached
            ``self.work_dir`` is kept, or set to ``None`` if there is none.
        pipeline: Unused; kept for interface compatibility.
    """
    if work_dir:
        self.work_dir = work_dir
    elif not getattr(self, 'work_dir', None):
        self.work_dir = None

    # 1. Training data path, read from the Step5 panel when it is available.
    main_window = self.window()
    step5_panel = getattr(main_window, 'step5_panel', None)
    if step5_panel is not None:
        step5_csv = ""
        # Preferred source: the Step5 output widget.
        output_widget = getattr(step5_panel, 'output_file', None)
        if output_widget is not None:
            step5_csv = output_widget.get_path()
        # Fallback: probe the Step5 config dict for the likely key names.
        if not step5_csv and hasattr(step5_panel, 'get_config'):
            step5_cfg = step5_panel.get_config()
            step5_csv = (
                step5_cfg.get('training_csv_path')
                or step5_cfg.get('output_file')
                or step5_cfg.get('csv_path')
                or step5_cfg.get('output_csv')
            )
        if step5_csv:
            # Relative paths are resolved against the working directory.
            if not os.path.isabs(step5_csv):
                step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
            self.training_csv_file.set_path(step5_csv)

    # 2. Output path, derived from the working directory and the input name:
    #    training_spectra.csv -> {work_dir}/6_water_quality_indices/training_spectra_indices.csv
    if self.work_dir:
        indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
        os.makedirs(indices_dir, exist_ok=True)
        training_csv = self.training_csv_file.get_path()
        if training_csv:
            basename = os.path.splitext(os.path.basename(training_csv))[0]
            output_file = f"{basename}_indices.csv"
        else:
            output_file = "water_quality_indices.csv"
        output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
        self.output_path.set_path(output_path)
    else:
        self.output_path.set_path("")
def run_step(self):
    """Run step 6 on its own through the main window.

    Shows a warning and returns when no training CSV is selected. Otherwise
    the panel config is handed to the main window's ``run_single_step``;
    nothing happens when the window does not provide that method.
    """
    if not self.training_csv_file.get_path():
        QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件")
        return

    host = self.window()
    if not hasattr(host, 'run_single_step'):
        return
    step_config = {'step6': self.get_config()}
    host.run_single_step('step6', step_config)
def get_training_params(self):
    """Return the model-training parameters currently set in the panel.

    Unlike ``get_config``, the feature start value is converted with
    ``float`` (so non-numeric text raises ``ValueError``) and empty
    selections stay empty lists.
    """
    def _ticked(boxes):
        # Keys whose checkbox is currently checked, in mapping order.
        return [key for key, cb in boxes.items() if cb.isChecked()]

    params = {'pipeline_type': 'machine_learning'}
    params['feature_start'] = float(self.feature_start.text())
    params['cv_folds'] = self.cv_folds.value()
    params['preprocess_methods'] = _ticked(self.preproc_checkboxes)
    params['model_types'] = _ticked(self.model_checkboxes)
    params['split_methods'] = _ticked(self.split_checkboxes)
    return params

View File

@ -1,23 +1,75 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step7 面板 - 采样点生成
Step7 面板 - 机器学习建模
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QSpinBox, QMessageBox,
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
QPushButton, QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet
# ============================================================
# Display-text lookup tables (internal key -> label shown in the GUI)
# ============================================================
# Preprocessing methods: internal key -> display text
PREPROC_CHINESE = {
    'None': '无 (None)',
    'MMS': '最小-最大归一化 (MMS)',
    'SS': '标度化 (SS)',
    'SNV': '标准正态变换 (SNV)',
    'MA': '移动平均 (MA)',
    'SG': 'Savitzky-Golay (SG)',
    'MSC': '多元散射校正 (MSC)',
    'D1': '一阶导数 (D1)',
    'D2': '二阶导数 (D2)',
    'DT': '去趋势 (DT)',
    'CT': '中心化 (CT)',
}
# Model types: internal key -> display text
MODEL_CHINESE = {
    # Linear models
    'LinearRegression': '多元线性回归 (MLR)',
    'Ridge': '岭回归 (Ridge)',
    'Lasso': '套索回归 (Lasso)',
    'ElasticNet': '弹性网络 (ElasticNet)',
    'PLS': '偏最小二乘 (PLSR)',
    # Tree models
    'DecisionTree': '决策树 (CART)',
    'RF': '随机森林 (RF)',
    'ExtraTrees': '极端随机树 (ET)',
    'XGBoost': '极值梯度提升 (XGBoost)',
    'LightGBM': '轻量梯度提升 (LightGBM)',
    'CatBoost': '类别梯度提升 (CatBoost)',
    # Ensemble learning
    'GradientBoosting': '梯度提升树 (GBDT)',
    'AdaBoost': '自适应提升 (AdaBoost)',
    # Other models
    'SVR': '支持向量回归 (SVR)',
    'KNN': 'K近邻回归 (KNN)',
    'MLP': '多层感知机 (BP神经网络)',
}
# Data split methods: internal key -> display text
SPLIT_CHINESE = {
    'spxy': 'SPXY 算法 (考量X-Y空间)',
    'ks': 'KS 算法 (考量X空间)',
    'random': '随机划分 (Random)',
}
class Step7Panel(QWidget):
"""步骤7采样点生成"""
"""步骤7机器学习建模"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@ -25,58 +77,35 @@ class Step7Panel(QWidget):
def init_ui(self):
layout = QVBoxLayout()
# 去耀斑影像文件(用于独立运行)
self.deglint_img_file = FileSelectWidget(
"去耀斑影像:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.deglint_img_file)
# 标题
# 水域掩膜文件(可选,用于独立运行)
self.water_mask_file = FileSelectWidget(
"水域掩膜:",
"Mask Files (*.dat *.tif);;All Files (*.*)"
)
self.water_mask_file.label.setText("水域掩膜:")
layout.addWidget(self.water_mask_file)
# 参数设置
params_group = QGroupBox("采样参数")
params_layout = QFormLayout()
self.interval = QSpinBox()
self.interval.setRange(10, 500)
self.interval.setValue(50)
params_layout.addRow("采样点间隔(像素):", self.interval)
self.sample_radius = QSpinBox()
self.sample_radius.setRange(1, 50)
self.sample_radius.setValue(5)
params_layout.addRow("采样半径(像素):", self.sample_radius)
self.chunk_size = QSpinBox()
self.chunk_size.setRange(100, 10000)
self.chunk_size.setValue(1000)
params_layout.addRow("处理块大小:", self.chunk_size)
self.use_adaptive_sampling = QCheckBox("启用自适应采样")
self.use_adaptive_sampling.setChecked(True)
params_layout.addRow("采样模式:", self.use_adaptive_sampling)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出文件路径
self.output_file = FileSelectWidget(
"输出采样点:",
# 训练数据文件(用于独立运行)
self.training_csv_file = FileSelectWidget(
"训练数据:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
layout.addWidget(self.output_file)
layout.addWidget(self.training_csv_file)
# 机器学习模型页面
self.ml_page = QWidget()
self.create_ml_page()
layout.addWidget(self.ml_page)
# 输出文件路径
self.output_path = FileSelectWidget(
"输出文件:",
"CSV Files (*.csv);;All Files (*.*)",
mode="save"
)
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
self.output_path.browse_btn.clicked.disconnect()
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
layout.addWidget(self.output_path)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
self.enable_checkbox.setChecked(False)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
@ -85,57 +114,233 @@ class Step7Panel(QWidget):
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
# 交互式预览按钮
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
self.preview_btn.setEnabled(False)
self.preview_btn.clicked.connect(self._open_sampling_viewer)
layout.addWidget(self.preview_btn)
layout.addStretch()
self.setLayout(layout)
# 监听输出路径变化,实时更新预览按钮状态
self.output_file.line_edit.textChanged.connect(self._on_output_changed)
def create_ml_page(self):
    """Build the machine-learning settings page and install it on ``self.ml_page``.

    Creates four group boxes in order: training parameters, preprocessing
    methods, model types and data split methods. The checkboxes of the three
    multi-select groups are stored in ``self.preproc_checkboxes``,
    ``self.model_checkboxes`` and ``self.split_checkboxes`` keyed by their
    internal names; all start unchecked. Each multi-select group gets
    "select all" / "deselect all" buttons wired to ``_toggle_checkboxes``.
    """
    layout = QVBoxLayout()
    # Training parameters
    params_group = QGroupBox("训练参数")
    params_layout = QFormLayout()
    self.feature_start = QLineEdit()
    self.feature_start.setText("374.285004")
    params_layout.addRow("特征起始列:", self.feature_start)
    self.cv_folds = QSpinBox()
    self.cv_folds.setRange(2, 10)
    self.cv_folds.setValue(3)
    params_layout.addRow("交叉验证折数:", self.cv_folds)
    params_group.setLayout(params_layout)
    layout.addWidget(params_group)
    # Preprocessing methods - multi-select
    preproc_group = QGroupBox("预处理方法 (可多选)")
    preproc_layout = QVBoxLayout()
    preproc_grid = QGridLayout()
    self.preproc_checkboxes = {}
    preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
    for i, method in enumerate(preproc_methods):
        # Label falls back to the raw key when no display text is defined.
        checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
        checkbox.setChecked(False)
        self.preproc_checkboxes[method] = checkbox
        # Four checkboxes per grid row.
        preproc_grid.addWidget(checkbox, i // 4, i % 4)
    button_layout = QHBoxLayout()
    select_all_btn = QPushButton("全选")
    deselect_all_btn = QPushButton("全不选")
    select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
    deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
    button_layout.addWidget(select_all_btn)
    button_layout.addWidget(deselect_all_btn)
    button_layout.addStretch()
    preproc_layout.addLayout(preproc_grid)
    preproc_layout.addLayout(button_layout)
    preproc_group.setLayout(preproc_layout)
    layout.addWidget(preproc_group)
    # Model selection - multi-select
    model_group = QGroupBox("模型类型 (可多选)")
    model_layout = QVBoxLayout()
    model_grid = QGridLayout()
    self.model_checkboxes = {}
    model_groups = [
        ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
        ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
        ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
        ("【其他模型】", ['SVR', 'KNN', 'MLP'])
    ]
    # `row` is the running grid row shared by all model groups.
    row = 0
    for group_name, models in model_groups:
        group_label = QLabel(f"<b>{group_name}</b>")
        group_label.setStyleSheet(
            f"background-color: {ModernStylesheet.COLORS['hover']}; "
            f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
            f"border-radius: 3px;"
        )
        # Group header spans all four columns.
        model_grid.addWidget(group_label, row, 0, 1, 4)
        row += 1
        for i, model in enumerate(models):
            checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
            checkbox.setChecked(False)
            self.model_checkboxes[model] = checkbox
            model_grid.addWidget(checkbox, row, i % 4)
            # Advance to the next grid row after every fourth checkbox.
            if (i + 1) % 4 == 0:
                row += 1
        # Move past the last (possibly partial) row before the next group.
        row += 1
    model_button_layout = QHBoxLayout()
    model_select_all = QPushButton("全选")
    model_deselect_all = QPushButton("全不选")
    model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
    model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
    model_button_layout.addWidget(model_select_all)
    model_button_layout.addWidget(model_deselect_all)
    model_button_layout.addStretch()
    model_layout.addLayout(model_grid)
    model_layout.addLayout(model_button_layout)
    model_group.setLayout(model_layout)
    layout.addWidget(model_group)
    # Data split methods - multi-select
    split_group = QGroupBox("数据划分方法 (可多选)")
    split_layout = QVBoxLayout()
    split_grid = QGridLayout()
    self.split_checkboxes = {}
    split_methods = ['spxy', 'ks', 'random']
    for i, method in enumerate(split_methods):
        checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
        checkbox.setChecked(False)
        self.split_checkboxes[method] = checkbox
        # All split methods sit on a single row.
        split_grid.addWidget(checkbox, 0, i)
    split_button_layout = QHBoxLayout()
    split_select_all = QPushButton("全选")
    split_deselect_all = QPushButton("全不选")
    split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
    split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
    split_button_layout.addWidget(split_select_all)
    split_button_layout.addWidget(split_deselect_all)
    split_button_layout.addStretch()
    split_layout.addLayout(split_grid)
    split_layout.addLayout(split_button_layout)
    split_group.setLayout(split_layout)
    layout.addWidget(split_group)
    self.ml_page.setLayout(layout)
def _toggle_checkboxes(self, checkboxes_dict, checked):
    """Set every checkbox held in a mapping to the same checked state.

    Args:
        checkboxes_dict: Mapping whose values expose ``setChecked``.
        checked: State handed to each ``setChecked`` call.
    """
    for box in checkboxes_dict.values():
        box.setChecked(checked)
def _get_default_work_dir(self):
    """Return the working directory as a string, or "" when none is known.

    A truthy ``work_dir`` cached on the panel takes precedence; otherwise
    the ``work_dir`` of the object returned by ``self.window()`` is used.
    """
    cached = getattr(self, 'work_dir', None)
    if cached:
        return str(cached)
    # Ask the top-level window only when the panel itself has nothing.
    top = self.window()
    from_window = getattr(top, 'work_dir', None) if top else None
    return str(from_window) if from_window else ""
def browse_output_path(self):
    """Open a save-file dialog and store the chosen path in ``output_path``.

    The dialog starts at the directory/file name currently shown in the
    output widget. When that directory is empty or does not exist, the
    ``6_water_quality_indices`` folder under the default working directory
    is used instead (and created if missing). Nothing is stored when the
    dialog returns an empty path.
    """
    current = self.output_path.get_path().strip()
    # os.path.split yields the same (dirname, basename) pair, and ("", "")
    # for an empty string.
    start_dir, start_name = os.path.split(current)

    if not (start_dir and os.path.isdir(start_dir)):
        # Fall back to the indices directory under the working directory.
        base = self._get_default_work_dir()
        if base:
            start_dir = os.path.join(base, "6_water_quality_indices")
            if not os.path.isdir(start_dir):
                os.makedirs(start_dir, exist_ok=True)
        else:
            start_dir = ""

    suggestion = os.path.join(start_dir, start_name) if start_name else start_dir
    chosen, _selected_filter = QFileDialog.getSaveFileName(
        self, "保存输出文件", suggestion,
        "CSV Files (*.csv);;All Files (*.*)"
    )
    if chosen:
        self.output_path.set_path(chosen)
def get_config(self):
"""获取配置"""
preprocessing_methods = [
method for method, checkbox in self.preproc_checkboxes.items()
if checkbox.isChecked()
]
model_names = [
model for model, checkbox in self.model_checkboxes.items()
if checkbox.isChecked()
]
split_methods = [
method for method, checkbox in self.split_checkboxes.items()
if checkbox.isChecked()
]
config = {
'interval': self.interval.value(),
'sample_radius': self.sample_radius.value(),
'chunk_size': self.chunk_size.value(),
'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
'feature_start_column': self.feature_start.text(),
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
'model_names': model_names if model_names else ['SVR'],
'split_methods': split_methods if split_methods else ['random'],
'cv_folds': self.cv_folds.value()
}
deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path:
config['deglint_img_path'] = deglint_img_path
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask_path'] = water_mask_path
training_csv_path = self.training_csv_file.get_path()
if training_csv_path:
config['training_csv_path'] = training_csv_path
output_path = self.output_path.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'interval' in config:
self.interval.setValue(config['interval'])
if 'sample_radius' in config:
self.sample_radius.setValue(config['sample_radius'])
if 'chunk_size' in config:
self.chunk_size.setValue(config['chunk_size'])
if 'use_adaptive_sampling' in config:
self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path'])
if 'water_mask_path' in config:
self.water_mask_file.set_path(config['water_mask_path'])
if 'glint_mask_path' in config:
self.glint_mask_file.set_path(config['glint_mask_path'])
if 'feature_start_column' in config:
self.feature_start.setText(str(config['feature_start_column']))
if 'cv_folds' in config:
self.cv_folds.setValue(config['cv_folds'])
if 'preprocessing_methods' in config:
methods = config['preprocessing_methods']
for method, checkbox in self.preproc_checkboxes.items():
checkbox.setChecked(method in methods)
if 'model_names' in config:
models = config['model_names']
for model, checkbox in self.model_checkboxes.items():
checkbox.setChecked(model in models)
if 'split_methods' in config:
methods = config['split_methods']
for method, checkbox in self.split_checkboxes.items():
checkbox.setChecked(method in methods)
if 'training_csv_path' in config:
self.training_csv_file.set_path(config['training_csv_path'])
if 'output_path' in config:
self.output_path.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充去耀斑影像和掩膜路径
"""从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径
pipeline: Pipeline 实例(未使用,保留接口兼容性
"""
if work_dir:
self.work_dir = work_dir
@ -144,81 +349,53 @@ class Step7Panel(QWidget):
else:
self.work_dir = None
# 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
main_window = self.window()
if hasattr(main_window, 'step5_panel'):
# 优先直接从 Step5 的输出 widget 读取
step5_output = main_window.step5_panel.output_file.get_path()
if step5_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_output):
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
self.training_csv_file.set_path(step5_output)
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
# 回退:从 Step5 的 config 字典中查找可能的键名
step5_cfg = main_window.step5_panel.get_config()
step5_csv = (
step5_cfg.get('training_csv_path')
or step5_cfg.get('output_file')
or step5_cfg.get('csv_path')
or step5_cfg.get('output_csv')
)
if step5_csv:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_csv):
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
self.training_csv_file.set_path(step5_csv)
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径
deglint_path = None
if pipeline and hasattr(pipeline, 'step_outputs'):
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
deglint_path = (
step3_outputs.get('deglint_image')
or step3_outputs.get('output_path')
or step3_outputs.get('output_file')
or step3_outputs.get('deglint_img_path')
)
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
if not deglint_path and hasattr(main_window, 'step3_panel'):
deglint_path = main_window.step3_panel.output_file.get_path()
if deglint_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(deglint_path):
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
self.deglint_img_file.set_path(deglint_path)
# 2. 填充水域掩膜路径优先级pipeline.step_outputs > step1_panel > 1_water_mask > input-test
water_mask_path = None
if pipeline and hasattr(pipeline, 'step_outputs'):
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
water_mask_path = (
step1_outputs.get('water_mask')
or step1_outputs.get('output_path')
or step1_outputs.get('output_file')
)
# 回退:从 step1 面板 widget 直接读取
if not water_mask_path and hasattr(main_window, 'step1_panel'):
water_mask_path = main_window.step1_panel.output_file.get_path()
# 备选:扫描 1_water_mask 目录下的 .dat 文件
if not water_mask_path and self.work_dir:
mask_dir = os.path.join(self.work_dir, "1_water_mask")
if os.path.isdir(mask_dir):
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
if dat_files:
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
# 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat
if not water_mask_path and self.work_dir:
input_test_dir = os.path.join(self.work_dir, "input-test")
if os.path.isdir(input_test_dir):
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
# 优先匹配 water_mask_from_shp.dat
for f in dat_files:
if 'water_mask_from_shp' in f.lower():
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
break
# 否则取第一个 .dat 文件
if not water_mask_path and dat_files:
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
if water_mask_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(water_mask_path):
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
self.water_mask_file.set_path(water_mask_path)
# 3. 自动填充输出路径(绝对路径)
# 2. 自动填充输出文件路径(基于工作目录和输入文件名
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
if self.work_dir:
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
self.output_file.set_path(output_path.replace('\\', '/'))
# 4. 同步更新预览按钮状态(路径可能已自动填充)
self._check_csv_exists()
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
os.makedirs(indices_dir, exist_ok=True)
training_csv = self.training_csv_file.get_path()
if training_csv:
basename = os.path.splitext(os.path.basename(training_csv))[0]
output_file = f"{basename}_indices.csv"
else:
output_file = "water_quality_indices.csv"
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
self.output_path.set_path(output_path)
else:
self.output_path.set_path("")
def run_step(self):
"""独立运行步骤7"""
deglint_img_path = self.deglint_img_file.get_path()
if not deglint_img_path:
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
training_csv_path = self.training_csv_file.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
return
main_window = self.window()
@ -226,27 +403,13 @@ class Step7Panel(QWidget):
config = {'step7': self.get_config()}
main_window.run_single_step('step7', config)
def _check_csv_exists(self):
    """Enable the preview button only when the output CSV is usable.

    The path from ``output_file`` counts as usable when it is non-empty,
    absolute and exists on disk.

    Returns:
        The boolean that was applied to ``preview_btn``.
    """
    candidate = self.output_file.get_path()
    if candidate:
        usable = bool(os.path.isabs(candidate) and os.path.exists(candidate))
    else:
        usable = False
    self.preview_btn.setEnabled(usable)
    return usable
def _on_output_changed(self, _text=None):
    """Re-evaluate the preview button when the output path text changes.

    Args:
        _text: Value delivered by the ``textChanged`` signal; ignored.
    """
    self._check_csv_exists()
def _open_sampling_viewer(self):
    """Show the sampling-point viewer dialog for the output CSV.

    Warns and returns when the output path is empty or the file does not
    exist. After the dialog closes, the preview button state is refreshed
    via ``_check_csv_exists``.
    """
    csv_path = self.output_file.get_path()
    if csv_path and os.path.exists(csv_path):
        viewer = SamplingViewerDialog(csv_path, self)
        viewer.exec_()
        # The file may have changed while the dialog was open.
        self._check_csv_exists()
        return
    QMessageBox.warning(
        self, "文件不存在",
        f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
    )
def get_training_params(self):
    """Return the model-training parameters currently set in the panel.

    The feature start value is converted with ``float`` (so non-numeric
    text raises ``ValueError``); selection lists contain only the checked
    keys and may be empty.
    """
    def _ticked(boxes):
        # Keys whose checkbox is currently checked, in mapping order.
        return [key for key, cb in boxes.items() if cb.isChecked()]

    params = {'pipeline_type': 'machine_learning'}
    params['feature_start'] = float(self.feature_start.text())
    params['cv_folds'] = self.cv_folds.value()
    params['preprocess_methods'] = _ticked(self.preproc_checkboxes)
    params['model_types'] = _ticked(self.model_checkboxes)
    params['split_methods'] = _ticked(self.split_checkboxes)
    return params

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6_5 面板 - 非经验统计回归建模
Step8 面板 - 非经验统计回归建模
"""
import os
@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step6_5Panel(QWidget):
"""步骤6.5:非经验统计回归建模"""
class Step8NonEmpiricalPanel(QWidget):
"""步骤8:非经验统计回归建模"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
@ -280,7 +280,7 @@ class Step6_5Panel(QWidget):
self.output_dir.set_path(dir_path)
def run_step(self):
"""独立运行步骤6.5"""
"""独立运行步骤8"""
training_csv_path = self.training_csv_file.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件")
@ -297,7 +297,7 @@ class Step6_5Panel(QWidget):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step6_5', {'step6_5': config})
parent.run_single_step('step8_non_empirical_modeling', {'step8_non_empirical_modeling': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -1,462 +1,225 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8 面板 - 机器学习预测
"""
import os
import sys
import pandas as pd
from pathlib import Path
from typing import Dict, List, Union
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
QListWidgetItem,
QWidget, QVBoxLayout, QGroupBox, QGridLayout,
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
def get_resource_path(relative_path: str) -> str:
    """Resolve a bundled resource path for PyInstaller and development runs.

    Lookup order:
      1. When ``sys._MEIPASS`` exists: ``<_MEIPASS>/_internal/<relative_path>``
         if that exists, else ``<_MEIPASS>/<relative_path>``.
      2. ``<directory of sys.executable>/_internal/<relative_path>`` if that
         exists.
      3. Otherwise the ``model`` directory two levels above this file, using
         only the base name of ``relative_path``.
    """
    if hasattr(sys, '_MEIPASS'):
        bundle_root = sys._MEIPASS
        nested = os.path.join(bundle_root, '_internal', relative_path)
        return nested if os.path.exists(nested) else os.path.join(bundle_root, relative_path)

    # An `_internal` folder next to the executable.
    beside_exe = os.path.join(os.path.dirname(sys.executable), '_internal', relative_path)
    if os.path.exists(beside_exe):
        return beside_exe

    # Development layout: sibling `model` directory of this file's parent.
    package_dir = Path(__file__).resolve().parents[1]
    return str(package_dir / "model" / os.path.basename(relative_path))
class Step8Panel(QWidget):
"""步骤8机器学习预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.external_models_dict = {} # {subdir_name: model_obj, ...}
self.external_model_dir = "" # 母文件夹路径(隐藏)
self.index_checkboxes: Dict[str, QCheckBox] = {}
# 标识为 waterindex.csv目录跳转逻辑在 get_resource_path 中
self.builtin_formula_path = get_resource_path("waterindex.csv")
self.init_ui()
# 延迟一小会儿加载确保UI框架已就绪
self._auto_load_formulas()
def init_ui(self):
layout = QVBoxLayout()
main_layout = QVBoxLayout()
main_layout.setContentsMargins(20, 20, 20, 20)
main_layout.setSpacing(10)
# -------- 模型来源选择(单选按钮组) --------
source_group = QGroupBox("模型来源")
source_layout = QVBoxLayout()
# 1. 路径展示区 (半透明只读)
path_group = QGroupBox("公式配置源 (内置)")
path_layout = QVBoxLayout()
self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
self.formula_csv_widget.set_path(self.builtin_formula_path)
self.formula_csv_widget.set_read_only(True)
# 视觉微调:提示用户这是内置的
self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
path_layout.addWidget(self.formula_csv_widget)
path_group.setLayout(path_layout)
main_layout.addWidget(path_group)
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
self.use_trained_model.setChecked(True)
source_layout.addWidget(self.use_trained_model)
source_layout.addWidget(self.use_external_model)
# 2. 训练数据输入
input_group = QGroupBox("输入样本数据")
input_layout = QVBoxLayout()
self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
input_layout.addWidget(self.training_data_widget)
input_group.setLayout(input_layout)
main_layout.addWidget(input_group)
self.use_trained_model.toggled.connect(self._on_model_source_changed)
self.use_external_model.toggled.connect(self._on_model_source_changed)
# 3. 公式选择区
self.formula_group = QGroupBox("待计算水质指数勾选")
formula_outer_layout = QVBoxLayout()
source_group.setStyleSheet("""
QRadioButton {
font-size: 13px;
spacing: 8px;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border-radius: 9px;
border: 2px solid #A0A0A0;
background-color: #FFFFFF;
}
QRadioButton::indicator:hover {
border: 2px solid #0078D7;
}
QRadioButton::indicator:checked {
background-color: #0078D7;
border: 2px solid #0078D7;
}
""")
btn_layout = QHBoxLayout()
self.select_all_btn = QPushButton("全选")
self.deselect_all_btn = QPushButton("清空")
self.select_all_btn.clicked.connect(self.select_all_formulas)
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
btn_layout.addWidget(self.select_all_btn)
btn_layout.addWidget(self.deselect_all_btn)
btn_layout.addStretch()
source_group.setLayout(source_layout)
layout.addWidget(source_group)
self.refresh_button = QPushButton("手动重新加载公式")
self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
btn_layout.addWidget(self.refresh_button)
# -------- 外部模型文件选择(条件显示) --------
self.external_model_widget = FileSelectWidget(
"模型母文件夹:",
"Directories"
)
self.external_model_widget.browse_btn.clicked.disconnect()
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
self.external_model_widget.setVisible(False)
layout.addWidget(self.external_model_widget)
formula_outer_layout.addLayout(btn_layout)
# -------- 已扫描模型列表(条件显示) --------
self.model_list_group = QGroupBox("选择参与预测的模型")
self.model_list_group.setVisible(False)
model_list_layout = QVBoxLayout()
# 核心滚动区
scroll = QScrollArea()
scroll.setWidgetResizable(True)
scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
self.scroll_content = QWidget()
self.formula_layout = QGridLayout(self.scroll_content)
self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
scroll.setWidget(self.scroll_content)
formula_outer_layout.addWidget(scroll)
self.model_list = QListWidget()
self.model_list.setMaximumHeight(130)
self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
self.model_list.setStyleSheet("""
QListWidget {
border: 1px solid #C0C0C0;
border-radius: 4px;
background-color: #FFFFFF;
font-size: 12px;
}
QListWidget::item {
padding: 4px 6px;
border-bottom: 1px solid #F0F0F0;
}
QListWidget::item:selected {
background-color: transparent;
}
""")
model_list_layout.addWidget(self.model_list)
self.formula_group.setLayout(formula_outer_layout)
main_layout.addWidget(self.formula_group)
btn_row = QHBoxLayout()
self.btn_select_all = QPushButton("全选")
self.btn_select_all.setMaximumWidth(80)
self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
self.btn_select_all.clicked.connect(self._select_all_models)
# 4. 输出与运行
output_group = QGroupBox("结果输出")
output_layout = QVBoxLayout()
self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
output_layout.addWidget(self.output_file_widget)
output_group.setLayout(output_layout)
main_layout.addWidget(output_group)
self.btn_select_none = QPushButton("全不选")
self.btn_select_none.setMaximumWidth(80)
self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
self.btn_select_none.clicked.connect(self._select_none_models)
btn_row.addWidget(self.btn_select_all)
btn_row.addWidget(self.btn_select_none)
btn_row.addStretch()
model_list_layout.addLayout(btn_row)
self.model_list_group.setLayout(model_list_layout)
layout.addWidget(self.model_list_group)
# -------- 采样光谱CSV文件用于独立运行--------
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.sampling_csv_file)
# 模型目录(用于独立运行)
self.models_dir_file = FileSelectWidget(
"模型目录:",
"Directories;;All Files (*.*)"
)
self.models_dir_file.label.setText("模型目录:")
self.models_dir_file.browse_btn.clicked.disconnect()
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
layout.addWidget(self.models_dir_file)
# 参数设置
params_group = QGroupBox("预测参数")
params_layout = QFormLayout()
self.metric = QComboBox()
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
params_layout.addRow("模型选择指标:", self.metric)
self.prediction_column = QLineEdit()
self.prediction_column.setText("prediction")
params_layout.addRow("预测列名:", self.prediction_column)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出路径
self.output_file = FileSelectWidget(
"输出路径:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox = QCheckBox("启用计算流程")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
main_layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
self.run_button = QPushButton("立即执行计算")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.setMinimumHeight(40)
self.run_button.clicked.connect(self.run_step)
main_layout.addWidget(self.run_button)
layout.addStretch()
self.setLayout(layout)
self.setLayout(main_layout)
def _on_model_source_changed(self, checked: bool):
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
if not checked:
def _auto_load_formulas(self):
"""启动时自动加载逻辑"""
if os.path.exists(self.builtin_formula_path):
self.refresh_formulas(silent=True)
else:
print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
def refresh_formulas(self, silent=False):
path = self.builtin_formula_path
if not os.path.exists(path):
if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
return
is_external = self.use_external_model.isChecked()
self.external_model_widget.setVisible(is_external)
self.model_list_group.setVisible(is_external)
if not is_external:
self.external_models_dict = {}
self.external_model_dir = ""
self._clear_model_list()
def _scan_external_model_dir(self):
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "7_Supervised_Model_Training")
dir_path = QFileDialog.getExistingDirectory(
self,
"选择模型母文件夹",
default,
)
if not dir_path:
return
self.external_model_dir = dir_path
models_found = {}
errors = []
try:
import joblib
# 清理旧列表
for i in reversed(range(self.formula_layout.count())):
widget = self.formula_layout.itemAt(i).widget()
if widget: widget.deleteLater()
self.index_checkboxes.clear()
for subentry in os.scandir(dir_path):
if not subentry.is_dir():
continue
subdir_name = subentry.name
joblib_files = [
f for f in os.scandir(subentry.path)
if f.is_file() and f.name.lower().endswith(".joblib")
]
if not joblib_files:
continue
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
joblib_path = joblib_files[0].path
# 鲁棒性读取:尝试不同编码
for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
try:
loaded = joblib.load(joblib_path)
if isinstance(loaded, dict) and "model" in loaded:
model_obj = loaded["model"]
elif hasattr(loaded, "predict"):
model_obj = loaded
else:
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
continue
models_found[subdir_name] = model_obj
except Exception as e:
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
df = pd.read_csv(path, encoding=encoding)
if 'Formula_Name' in df.columns: break
except: continue
if 'Formula_Name' not in df.columns:
if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name'")
return
names = df['Formula_Name'].dropna().unique().tolist()
row, col = 0, 0
for name in names:
name = str(name).strip()
if not name: continue
cb = QCheckBox(name)
cb.setChecked(True)
self.index_checkboxes[name] = cb
self.formula_layout.addWidget(cb, row, col)
col += 1
if col >= 3:
col = 0
row += 1
# 强制UI更新
self.scroll_content.adjustSize()
print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
except Exception as e:
QMessageBox.warning(
self,
"扫描失败",
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
)
return
if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
if not models_found:
QMessageBox.warning(
self,
"未找到模型",
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
)
self.external_model_widget.set_path("")
self.external_models_dict = {}
self._clear_model_list()
return
def select_all_formulas(self):
    """Tick every checkbox stored in ``self.index_checkboxes``."""
    for checkbox in self.index_checkboxes.values():
        checkbox.setChecked(True)
self.external_models_dict = models_found
self._populate_model_list(models_found)
names = sorted(models_found.keys())
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
self.external_model_widget.set_path(display)
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
err_lines = "\n".join(errors) if errors else ""
QMessageBox.information(
self,
"模型扫描完成",
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
f"加载失败 {len(errors)} 个:\n{err_lines}",
)
def _populate_model_list(self, models_dict):
"""将扫描到的模型填充到 QListWidget每个条目可勾选默认全选"""
self.model_list.clear()
for name in sorted(models_dict.keys()):
item = QListWidgetItem(name)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Checked)
self.model_list.addItem(item)
def _clear_model_list(self):
"""清空模型列表"""
self.model_list.clear()
def _select_all_models(self):
"""全选:设置所有条目为 Checked"""
for i in range(self.model_list.count()):
self.model_list.item(i).setCheckState(Qt.Checked)
def _select_none_models(self):
"""全不选:设置所有条目为 Unchecked"""
for i in range(self.model_list.count()):
self.model_list.item(i).setCheckState(Qt.Unchecked)
def _get_checked_models_dict(self):
"""从列表中提取用户勾选的模型,组装成字典返回"""
result = {}
for i in range(self.model_list.count()):
item = self.model_list.item(i)
if item.checkState() == Qt.Checked:
name = item.text()
if name in self.external_models_dict:
result[name] = self.external_models_dict[name]
return result
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the sampling CSV, model directory and output path fields.

    Source values are read from ``step7_panel.output_file`` and
    ``step6_panel.output_dir`` on the top-level window; the output path is
    derived from the work directory. A field is only written while it is
    still blank. Any exception is printed together with its traceback and
    otherwise swallowed.

    Args:
        work_dir: Work directory path. When falsy, an already cached
            ``self.work_dir`` is kept, or ``self.work_dir`` becomes ``None``
            if nothing usable is cached.
        pipeline: Unused; kept for interface compatibility.
    """
    def read_widget_path(widget):
        # The source widget may expose get_path() or text(); anything else
        # (including None) counts as "no path".
        if hasattr(widget, 'get_path'):
            return widget.get_path() or ""
        if hasattr(widget, 'text'):
            return widget.text() or ""
        return ""

    def fill_if_blank(target, value):
        # Never overwrite something that is already in the field.
        current = target.get_path()
        if not current or not current.strip():
            target.set_path(value)

    try:
        if work_dir:
            self.work_dir = work_dir
        elif not (hasattr(self, 'work_dir') and self.work_dir):
            self.work_dir = None

        main_window = self.window()

        # (panel on main window, widget on that panel, target field on self)
        links = (
            ('step7_panel', 'output_file', 'sampling_csv_file'),
            ('step6_panel', 'output_dir', 'models_dir_file'),
        )
        for panel_name, widget_name, target_name in links:
            if not (main_window and hasattr(main_window, panel_name)):
                continue
            panel = getattr(main_window, panel_name)
            source_path = read_widget_path(getattr(panel, widget_name, None))
            if not source_path:
                continue
            # Relative paths are resolved against the work directory.
            if not os.path.isabs(source_path):
                source_path = os.path.join(self.work_dir or '', source_path).replace('\\', '/')
            fill_if_blank(getattr(self, target_name), source_path)

        # Output path: the machine-learning prediction directory, created on demand.
        if self.work_dir:
            output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
            os.makedirs(output_dir, exist_ok=True)
            fill_if_blank(self.output_file, output_dir)
    except Exception as e:
        import traceback
        print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
        traceback.print_exc()
def _get_default_work_dir(self):
"""获取 work_dir优先用 panel 自身缓存的,否则尝试从主窗口取"""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_models_dir(self):
    """Pick the model directory through a folder dialog.

    The dialog starts in ``<work_dir>/7_Supervised_Model_Training`` when a
    work directory is known. An empty dialog result leaves the field as is.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "7_Supervised_Model_Training")
    chosen = QFileDialog.getExistingDirectory(self, "选择模型目录", start_dir)
    if not chosen:
        return
    self.models_dir_file.set_path(chosen)
def deselect_all_formulas(self):
    """Untick every checkbox stored in ``self.index_checkboxes``."""
    for checkbox in self.index_checkboxes.values():
        checkbox.setChecked(False)
def get_config(self):
"""获取配置"""
config = {
'metric': self.metric.currentText(),
'prediction_column': self.prediction_column.text(),
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
return {
'training_csv_path': self.training_data_widget.get_path(),
'formula_csv_file': self.builtin_formula_path,
'formula_names': selected,
'output_file': self.output_file_widget.get_path(),
'enabled': self.enable_checkbox.isChecked()
}
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
models_dir = self.models_dir_file.get_path()
if models_dir:
config['models_dir'] = models_dir
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'metric' in config:
idx = self.metric.findText(config['metric'])
if idx >= 0:
self.metric.setCurrentIndex(idx)
if 'prediction_column' in config:
self.prediction_column.setText(config['prediction_column'])
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
if 'models_dir' in config:
self.models_dir_file.set_path(config['models_dir'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
if 'formula_names' in config:
sel = set(config['formula_names'])
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
self.enable_checkbox.setChecked(config.get('enabled', True))
def update_from_config(self, work_dir=None, pipeline=None):
    """Refresh the training-data input and the output path from shared state.

    The training CSV is taken from ``step5_panel.output_file`` on the
    top-level window (relative paths are joined onto the work directory).
    The output path is set to
    ``<work_dir>/6_water_quality_indices/training_spectra_indices.csv``.

    Args:
        work_dir: Work directory path; when truthy it is cached on
            ``self.work_dir``.
        pipeline: Unused by this method.
    """
    if work_dir:
        self.work_dir = work_dir
    # work_dir may never have been assigned on this panel; read it
    # defensively so a call without work_dir cannot raise AttributeError.
    current_work_dir = getattr(self, 'work_dir', None)

    main = self.window()
    if hasattr(main, 'step5_panel'):
        p5 = main.step5_panel.output_file.get_path()
        if p5:
            if not os.path.isabs(p5):
                p5 = os.path.join(current_work_dir or '', p5).replace('\\', '/')
            self.training_data_widget.set_path(p5)

    if current_work_dir:
        out = os.path.join(
            current_work_dir, "6_water_quality_indices", "training_spectra_indices.csv"
        ).replace('\\', '/')
        self.output_file_widget.set_path(out)
def run_step(self):
"""独立运行步骤8"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
config = self.get_config()
if not config['training_csv_path']:
QMessageBox.warning(self, "提示", "请先选择输入数据")
return
# 外部模型优先:用户选择了"导入本地预训练模型"
if self.use_external_model.isChecked():
if not self.external_models_dict:
QMessageBox.warning(
self,
"模型未加载",
"请先点击「浏览...」按钮选择模型母文件夹!",
)
return
# 只传递用户勾选的模型
checked_dict = self._get_checked_models_dict()
if not checked_dict:
QMessageBox.warning(
self,
"未选择模型",
"请至少勾选一个模型参与预测!",
)
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {
'step8': self.get_config(),
'_external_models_dict': checked_dict,
'_external_model_dir': self.external_model_dir,
}
main_window.run_single_step('step8', config)
return
# 默认流程:使用模型目录
models_dir = self.models_dir_file.get_path()
if not models_dir:
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step8': self.get_config()}
main_window.run_single_step('step8', config)
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
if parent: parent.run_single_step('step8', {'step8': config})

View File

@ -1,206 +1,158 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step9 面板 - 分布图生成
Step9 面板 - 自定义回归分析
"""
import os
import traceback
from pathlib import Path
from typing import List, Optional
from typing import Dict
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import pandas as pd
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
QScrollArea, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
try:
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
PIPELINE_AVAILABLE = True
except ImportError:
PIPELINE_AVAILABLE = False
class Step9BatchThread(QThread):
    """Thematic maps: batch-generate one distribution map per prediction CSV."""

    # Emitted with the number of CSV paths once the whole loop has finished.
    finished_ok = pyqtSignal(int)
    # Emitted with the error text plus formatted traceback when the batch fails.
    failed = pyqtSignal(str)
    # Emitted as (message, level) before each CSV is processed.
    log_message = pyqtSignal(str, str)

    def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
        super().__init__()
        self.work_dir = work_dir
        self.csv_paths = csv_paths
        self.step9_kwargs = step9_kwargs
        # Blank / whitespace-only output directories are normalised to None.
        cleaned_dir = (output_dir_optional or "").strip()
        self.output_dir_optional = cleaned_dir if cleaned_dir else None

    def run(self):
        """Run the pipeline's distribution-map step for every CSV path."""
        previous_backend = None
        try:
            import matplotlib
            previous_backend = matplotlib.get_backend()
        except Exception:
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            # The switch did not happen, so there is nothing to restore later.
            previous_backend = None

        try:
            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
            pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
            total = len(self.csv_paths)
            for position, csv_p in enumerate(self.csv_paths, start=1):
                self.log_message.emit(f"专题图 [{position}/{total}] {csv_p}", "info")
                call_kwargs = dict(self.step9_kwargs)
                call_kwargs["prediction_csv_path"] = csv_p
                call_kwargs["skip_dependency_check"] = True
                image_path = None
                if self.output_dir_optional:
                    image_name = f"{Path(csv_p).stem}_distribution.png"
                    image_path = str(Path(self.output_dir_optional) / image_name)
                call_kwargs["output_image_path"] = image_path
                pipeline.step9_generate_distribution_map(**call_kwargs)
            self.finished_ok.emit(total)
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if previous_backend:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(previous_backend)
                except Exception:
                    pass
class Step9Panel(QWidget):
"""步骤9分布图生成"""
"""步骤9自定义回归分析"""
def __init__(self, parent=None):
super().__init__(parent)
self._batch_thread = None
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
self.method_checkboxes: Dict[str, QCheckBox] = {}
self.csv_columns = []
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
hint = QLabel(
"独立运行:可选「单个 CSV」或「文件夹批量」扫描目录下所有 .csv"
"完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
)
hint.setWordWrap(True)
hint.setStyleSheet(
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
)
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
hint.setStyleSheet("color: #666; font-size: 11px;")
layout.addWidget(hint)
mode_row = QHBoxLayout()
self.mode_single_rb = QRadioButton("单个 CSV 文件")
self.mode_folder_rb = QRadioButton("文件夹批量")
self._mode_group = QButtonGroup(self)
self._mode_group.addButton(self.mode_single_rb, 0)
self._mode_group.addButton(self.mode_folder_rb, 1)
mode_row.addWidget(self.mode_single_rb)
mode_row.addWidget(self.mode_folder_rb)
mode_row.addStretch()
layout.addLayout(mode_row)
# CSV文件选择
csv_group = QGroupBox("数据文件")
csv_layout = QVBoxLayout()
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
radio_style = """
QRadioButton {
font-size: 14px;
spacing: 8px;
color: #333333;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border: 2px solid #999999;
border-radius: 3px;
background-color: white;
}
QRadioButton::indicator:checked {
border: 2px solid #0078d4;
background-color: #0078d4;
image: none;
}
QRadioButton::indicator:hover {
border: 2px solid #005a9e;
}
"""
self.mode_single_rb.setStyleSheet(radio_style)
self.mode_folder_rb.setStyleSheet(radio_style)
self.prediction_csv_file = FileSelectWidget(
"预测结果CSV:",
self.csv_file = FileSelectWidget(
"输入CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.prediction_csv_file)
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
csv_layout.addWidget(self.csv_file)
folder_row = QHBoxLayout()
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
self.prediction_csv_dir_label.setMinimumWidth(120)
self.prediction_csv_dir_edit = QLineEdit()
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
pred_dir_btn = QPushButton("浏览…")
pred_dir_btn.setMaximumWidth(80)
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
folder_row.addWidget(self.prediction_csv_dir_label)
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
folder_row.addWidget(pred_dir_btn)
self._folder_row_widget = QWidget()
self._folder_row_widget.setLayout(folder_row)
layout.addWidget(self._folder_row_widget)
self.refresh_btn = QPushButton("刷新列信息")
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
csv_layout.addWidget(self.refresh_btn)
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv")
layout.addWidget(self.recursive_csv_cb)
csv_group.setLayout(csv_layout)
layout.addWidget(csv_group)
self.boundary_file = FileSelectWidget(
"边界文件:",
"Shapefiles (*.shp);;All Files (*.*)"
)
layout.addWidget(self.boundary_file)
# 自变量选择
x_group = QGroupBox("自变量列选择 (可多选)")
x_layout = QVBoxLayout()
# 参数设置
params_group = QGroupBox("生成参数")
params_layout = QFormLayout()
x_scroll = QScrollArea()
x_scroll.setWidgetResizable(True)
x_scroll.setMinimumHeight(250)
x_scroll.setMaximumHeight(350)
self.resolution = QDoubleSpinBox()
self.resolution.setRange(1, 1000)
self.resolution.setValue(30)
params_layout.addRow("分辨率(米):", self.resolution)
x_widget = QWidget()
self.x_columns_layout = QGridLayout()
x_widget.setLayout(self.x_columns_layout)
self.input_crs = QLineEdit()
self.input_crs.setText("EPSG:32651")
params_layout.addRow("输入坐标系:", self.input_crs)
x_scroll.setWidget(x_widget)
x_layout.addWidget(x_scroll)
self.output_crs = QLineEdit()
self.output_crs.setText("EPSG:4326")
params_layout.addRow("输出坐标系:", self.output_crs)
x_btn_layout = QHBoxLayout()
self.x_select_all = QPushButton("全选")
self.x_deselect_all = QPushButton("全不选")
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
x_btn_layout.addWidget(self.x_select_all)
x_btn_layout.addWidget(self.x_deselect_all)
x_btn_layout.addStretch()
x_layout.addLayout(x_btn_layout)
self.show_points = QCheckBox("显示采样点")
params_layout.addRow("", self.show_points)
x_group.setLayout(x_layout)
layout.addWidget(x_group)
self.use_diffusion = QCheckBox("启用距离扩散")
self.use_diffusion.setChecked(True)
params_layout.addRow("", self.use_diffusion)
# 因变量选择
y_group = QGroupBox("因变量列选择 (可多选)")
y_layout = QVBoxLayout()
params_group.setLayout(params_layout)
layout.addWidget(params_group)
y_scroll = QScrollArea()
y_scroll.setWidgetResizable(True)
y_scroll.setMinimumHeight(200)
y_scroll.setMaximumHeight(300)
y_widget = QWidget()
self.y_columns_layout = QGridLayout()
y_widget.setLayout(self.y_columns_layout)
y_scroll.setWidget(y_widget)
y_layout.addWidget(y_scroll)
y_btn_layout = QHBoxLayout()
self.y_select_all = QPushButton("全选")
self.y_deselect_all = QPushButton("全不选")
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
y_btn_layout.addWidget(self.y_select_all)
y_btn_layout.addWidget(self.y_deselect_all)
y_btn_layout.addStretch()
y_layout.addLayout(y_btn_layout)
y_group.setLayout(y_layout)
layout.addWidget(y_group)
# 回归方法选择
method_group = QGroupBox("回归方法选择 (可多选)")
method_layout = QVBoxLayout()
method_grid = QGridLayout()
regression_methods = [
'linear', 'exponential', 'power', 'logarithmic',
'polynomial', 'hyperbolic', 'sigmoidal'
]
for i, method in enumerate(regression_methods):
checkbox = QCheckBox(method)
if method in ['linear', 'exponential', 'power', 'logarithmic']:
checkbox.setChecked(True)
self.method_checkboxes[method] = checkbox
method_grid.addWidget(checkbox, i // 3, i % 3)
method_layout.addLayout(method_grid)
method_btn_layout = QHBoxLayout()
self.method_select_all = QPushButton("全选")
self.method_deselect_all = QPushButton("全不选")
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
method_btn_layout.addWidget(self.method_select_all)
method_btn_layout.addWidget(self.method_deselect_all)
method_btn_layout.addStretch()
method_layout.addLayout(method_btn_layout)
method_group.setLayout(method_layout)
layout.addWidget(method_group)
# 输出目录
self.output_dir = FileSelectWidget(
"输出分布图目录:",
"Directories;;All Files (*.*)"
)
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
self.output_dir.browse_btn.clicked.disconnect()
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
layout.addWidget(self.output_dir)
output_group = QGroupBox("输出设置")
output_layout = QFormLayout()
self.output_dir = QLineEdit()
self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
output_layout.addRow("输出目录名:", self.output_dir)
output_group.setLayout(output_layout)
layout.addWidget(output_group)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
@ -216,119 +168,120 @@ class Step9Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
# 信号绑定与初始状态
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
self._toggle_input_mode() # 根据默认值设置初始显示状态
def toggle_checkboxes(self, checkboxes_dict, checked):
    """Apply the same checked state to every checkbox in ``checkboxes_dict``."""
    for key in checkboxes_dict:
        checkboxes_dict[key].setChecked(checked)
def _toggle_input_mode(self):
"""槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
folder_mode = self.mode_folder_rb.isChecked()
# 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
self.prediction_csv_file.setVisible(not folder_mode)
# 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
self._folder_row_widget.setVisible(folder_mode)
self.recursive_csv_cb.setVisible(folder_mode)
def on_csv_file_changed(self):
    """Refresh the column information whenever the CSV file changes."""
    refresh = self.refresh_csv_columns
    refresh()
def _get_default_work_dir(self):
"""获取 work_dir优先用 panel 自身缓存的,否则尝试从主窗口取"""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def refresh_csv_columns(self):
"""刷新CSV文件的列信息"""
csv_path = self.csv_file.get_path()
if not csv_path or not os.path.exists(csv_path):
self.csv_columns = []
self.update_column_widgets()
return
def browse_prediction_csv_dir(self):
    """Pick the folder holding the prediction CSVs through a folder dialog.

    The dialog starts in ``<work_dir>/11_12_13_predictions`` when a work
    directory is known. An empty dialog result leaves the field as is.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "11_12_13_predictions")
    chosen = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", start_dir)
    if not chosen:
        return
    self.prediction_csv_dir_edit.setText(chosen)
try:
df = pd.read_csv(csv_path, nrows=0)
self.csv_columns = list(df.columns)
self.update_column_widgets()
except Exception as e:
self.csv_columns = []
self.update_column_widgets()
print(f"读取CSV列信息失败: {e}")
def _collect_csv_paths_from_folder(self) -> List[str]:
folder = (self.prediction_csv_dir_edit.text() or "").strip()
if not folder or not os.path.isdir(folder):
return []
root = Path(folder)
if self.recursive_csv_cb.isChecked():
files = sorted(root.rglob("*.csv"))
else:
files = sorted(root.glob("*.csv"))
return [str(p) for p in files if p.is_file()]
def update_column_widgets(self):
"""更新列选择组件"""
for checkbox in self.x_column_checkboxes.values():
checkbox.setParent(None)
self.x_column_checkboxes.clear()
def _step9_base_pipeline_kwargs(self) -> dict:
return {
'boundary_shp_path': self.boundary_file.get_path(),
'resolution': self.resolution.value(),
'input_crs': self.input_crs.text(),
'output_crs': self.output_crs.text(),
'show_sample_points': self.show_points.isChecked(),
'use_distance_diffusion': self.use_diffusion.isChecked(),
}
for checkbox in self.y_column_checkboxes.values():
checkbox.setParent(None)
self.y_column_checkboxes.clear()
if not self.csv_columns:
return
for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col)
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
checkbox.setChecked(True)
self.x_column_checkboxes[col] = checkbox
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col)
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
checkbox.setChecked(True)
self.y_column_checkboxes[col] = checkbox
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
self.x_columns_layout.update()
self.y_columns_layout.update()
def get_config(self):
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
folder_mode = self.mode_folder_rb.isChecked()
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
config = {
'step9_batch_mode': 'folder' if folder_mode else 'single',
'prediction_csv_dir': pred_dir if pred_dir else None,
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
'boundary_shp_path': self.boundary_file.get_path(),
'resolution': self.resolution.value(),
'input_crs': self.input_crs.text(),
'output_crs': self.output_crs.text(),
'show_sample_points': self.show_points.isChecked(),
'use_distance_diffusion': self.use_diffusion.isChecked(),
selected_x_columns = [
col for col, checkbox in self.x_column_checkboxes.items()
if checkbox.isChecked()
]
selected_y_columns = [
col for col, checkbox in self.y_column_checkboxes.items()
if checkbox.isChecked()
]
selected_methods = [
method for method, checkbox in self.method_checkboxes.items()
if checkbox.isChecked()
]
if not selected_methods:
selected_methods = 'all'
return {
'csv_path': self.csv_file.get_path() or None,
'x_columns': selected_x_columns,
'y_columns': selected_y_columns,
'methods': selected_methods,
'output_dir': self.output_dir.text().strip() or None,
'enabled': self.enable_checkbox.isChecked()
}
out_dir = (self.output_dir.get_path() or "").strip()
if not folder_mode and pred_csv and out_dir:
stem = Path(pred_csv).stem
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
else:
config['output_image_path'] = None
return config
def set_config(self, config):
mode = config.get('step9_batch_mode', 'single')
if mode == 'folder':
self.mode_folder_rb.setChecked(True)
else:
self.mode_single_rb.setChecked(True)
if config.get('prediction_csv_dir'):
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
if 'recursive_csv_scan' in config:
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
if 'prediction_csv_path' in config and config['prediction_csv_path']:
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
if 'boundary_shp_path' in config:
self.boundary_file.set_path(config['boundary_shp_path'])
if 'resolution' in config:
self.resolution.setValue(config['resolution'])
if 'input_crs' in config:
self.input_crs.setText(config['input_crs'])
if 'output_crs' in config:
self.output_crs.setText(config['output_crs'])
if 'show_sample_points' in config:
self.show_points.setChecked(config['show_sample_points'])
if 'use_distance_diffusion' in config:
self.use_diffusion.setChecked(config['use_distance_diffusion'])
if 'output_dir' in config and config['output_dir']:
self.output_dir.set_path(str(config['output_dir']))
elif config.get('output_image_path'):
p = Path(str(config['output_image_path']))
if p.parent and str(p.parent) != '.':
self.output_dir.set_path(str(p.parent))
if 'csv_path' in config:
self.csv_file.set_path(config['csv_path'])
self.refresh_csv_columns()
if 'x_columns' in config:
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
for col, checkbox in self.x_column_checkboxes.items():
checkbox.setChecked(col in selected_x)
if 'y_columns' in config:
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
for col, checkbox in self.y_column_checkboxes.items():
checkbox.setChecked(col in selected_y)
if 'methods' in config:
methods = config['methods']
if isinstance(methods, list):
selected_methods = set(methods)
elif methods == 'all':
selected_methods = set(self.method_checkboxes.keys())
else:
selected_methods = set()
for method, checkbox in self.method_checkboxes.items():
checkbox.setChecked(method in selected_methods)
if 'output_dir' in config:
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充预测结果目录
优先使用 Step8机器学习预测的输出目录作为待预测 CSV 目录;
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
"""从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
@ -344,190 +297,78 @@ class Step9Panel(QWidget):
else:
self.work_dir = None
# 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
main_window = self.window()
if not main_window:
return
if main_window and hasattr(main_window, 'step5_panel'):
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
step5_output_path = ""
if hasattr(step5_widget, 'get_path'):
step5_output_path = step5_widget.get_path() or ""
elif hasattr(step5_widget, 'text'):
step5_output_path = step5_widget.text() or ""
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
pred_dir = None
if hasattr(main_window, 'step8_panel'):
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
step8_output = ""
if hasattr(step8_widget, 'get_path'):
step8_output = step8_widget.get_path() or ""
elif hasattr(step8_widget, 'text'):
step8_output = step8_widget.text() or ""
if step8_output:
if step5_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step8_output):
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
# 提取父目录后追加 Machine_Learning_Prediction最底层真实子目录
base_pred_dir = str(Path(step8_output).parent)
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
if not os.path.isabs(step5_output_path):
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
existing = self.csv_file.get_path()
if not existing or not existing.strip():
self.csv_file.set_path(step5_output_path)
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
step8_5_output = ""
if hasattr(step8_5_widget, 'get_path'):
step8_5_output = step8_5_widget.get_path() or ""
elif hasattr(step8_5_widget, 'text'):
step8_5_output = step8_5_widget.text() or ""
if step8_5_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step8_5_output):
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
pred_dir = str(Path(step8_5_output).parent)
# 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录
if not pred_dir and hasattr(main_window, 'step8_75_panel'):
step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
step8_75_output = ""
if hasattr(step8_75_widget, 'get_path'):
step8_75_output = step8_75_widget.get_path() or ""
elif hasattr(step8_75_widget, 'text'):
step8_75_output = step8_75_widget.text() or ""
if step8_75_output:
pred_dir = step8_75_output
# 自动填入"预测CSV目录"(文件夹批量模式)
if pred_dir:
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
if not existing_dir:
self.prediction_csv_dir_edit.setText(pred_dir)
# 切换到文件夹批量模式
self.mode_folder_rb.setChecked(True)
# 4. 自动填充输出目录14_visualization
# 2. 自动填充输出目录9_Custom_Regression_Modeling
if self.work_dir:
output_dir = os.path.join(self.work_dir, "14_visualization")
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_dir.get_path()
if not existing_out or not existing_out.strip():
self.output_dir.set_path(output_dir)
# 5. 自动探测原始矢量边界文件(.shp作为专题图底图
# 优先回溯 input-test/roi.shpgeopandas.read_file 仅支持矢量格式
if self.work_dir:
possible_shp = None
candidates = [
Path(self.work_dir).parent / "input-test" / "roi.shp",
Path(self.work_dir) / "roi.shp",
Path(self.work_dir).parent / "roi.shp",
]
for candidate in candidates:
if candidate.exists() and candidate.suffix.lower() == ".shp":
possible_shp = candidate
break
existing_boundary = (self.boundary_file.get_path() or "").strip()
if not existing_boundary and possible_shp:
self.boundary_file.set_path(str(possible_shp))
elif not existing_boundary:
# 未找到 .shp 时清空并提示用户手动选择矢量文件
self.boundary_file.set_path("")
print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
existing_out = self.output_dir.text().strip()
if not existing_out:
self.output_dir.setText(output_dir)
except Exception as e:
import traceback
print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def browse_output_dir(self):
    """Pick the output directory through a folder dialog.

    The dialog starts in ``<work_dir>/14_visualization`` when a work
    directory is known. An empty dialog result leaves the field as is.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "14_visualization")
    chosen = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", start_dir)
    if not chosen:
        return
    self.output_dir.set_path(chosen)
def run_step(self):
"""独立运行步骤9"""
if self._batch_thread and self._batch_thread.isRunning():
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
csv_path = self.csv_file.get_path()
if not csv_path:
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
return
if not os.path.exists(csv_path):
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
return
boundary_shp_path = self.boundary_file.get_path()
if not boundary_shp_path:
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
selected_x_columns = [
col for col, checkbox in self.x_column_checkboxes.items()
if checkbox.isChecked()
]
if not selected_x_columns:
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
return
if not os.path.exists(boundary_shp_path):
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
selected_y_columns = [
col for col, checkbox in self.y_column_checkboxes.items()
if checkbox.isChecked()
]
if not selected_y_columns:
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
return
selected_methods = [
method for method, checkbox in self.method_checkboxes.items()
if checkbox.isChecked()
]
if not selected_methods:
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
return
config = self.get_config()
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if not parent or not hasattr(parent, 'run_single_step'):
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step9', {'step9': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
return
if self.mode_folder_rb.isChecked():
csv_list = self._collect_csv_paths_from_folder()
if not csv_list:
QMessageBox.warning(
self,
"输入验证失败",
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
"可勾选「包含子文件夹」以递归扫描。",
)
return
if not PIPELINE_AVAILABLE:
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
return
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
work_dir = str(work_dir)
base_kw = self._step9_base_pipeline_kwargs()
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
self.run_button.setEnabled(False)
self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
main_win = parent
def _batch_log(msg, lvl):
if hasattr(main_win, "log_message"):
main_win.log_message(msg, lvl)
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
self._batch_thread.start()
if hasattr(parent, "log_message"):
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV工作目录 {work_dir}", "info")
return
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
if not prediction_csv_path:
QMessageBox.warning(
self,
"输入验证失败",
"请选择「预测结果 CSV」文件或切换到「文件夹批量」。",
)
return
if not os.path.isfile(prediction_csv_path):
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
return
config = self.get_config()
parent.run_single_step('step9', {'step9': config})
def _on_step9_batch_ok(self, n: int):
    """Announce a finished batch run and log it on the hosting window.

    Args:
        n: Number of distribution maps reported by the batch thread.
    """
    QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
    # Climb the parent chain until an ancestor exposing ``log_message`` is
    # found; stop silently when the chain ends without one.
    ancestor = self.parent()
    while ancestor:
        if hasattr(ancestor, "log_message"):
            ancestor.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
            break
        ancestor = ancestor.parent()
def _on_step9_batch_fail(self, err: str):
    """Show a batch failure in a dialog and forward it to the host log.

    Args:
        err: Error text from the batch thread. Only the first 900 characters
            are shown in the dialog; the full text goes to the log.
    """
    QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
    # Climb the parent chain until an ancestor exposing ``log_message`` is
    # found; stop silently when the chain ends without one.
    ancestor = self.parent()
    while ancestor:
        if hasattr(ancestor, "log_message"):
            ancestor.log_message(err, "error")
            break
        ancestor = ancestor.parent()

View File

@ -1567,12 +1567,12 @@ class VisualizationPanel(QWidget):
ml_dir.mkdir(parents=True, exist_ok=True)
reg_dir.mkdir(parents=True, exist_ok=True)
custom_dir.mkdir(parents=True, exist_ok=True)
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
self.step8_panel.output_file.set_path(str(ml_dir))
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
self.step8_5_panel.output_file.set_path(str(reg_dir))
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
if hasattr(self, 'step11_ml_panel') and hasattr(self.step11_ml_panel, 'output_file'):
self.step11_ml_panel.output_file.set_path(str(ml_dir))
if hasattr(self, 'step11_panel') and hasattr(self.step11_panel, 'output_file'):
self.step11_panel.output_file.set_path(str(reg_dir))
if hasattr(self, 'step12_panel') and hasattr(self.step12_panel, 'output_dir_widget'):
self.step12_panel.output_dir_widget.set_path(str(custom_dir))
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
except Exception as e:
print(f"设置预测输出目录失败: {e}")

View File

@ -119,19 +119,22 @@ from src.gui.panels.step2_panel import Step2Panel
from src.gui.panels.step3_panel import Step3Panel
from src.gui.panels.step4_panel import Step4Panel
from src.gui.panels.step5_panel import Step5Panel
from src.gui.panels.step5_5_panel import Step5_5Panel
from src.gui.panels.step6_panel import Step6Panel
from src.gui.panels.step6_5_panel import Step6_5Panel
from src.gui.panels.step6_75_panel import Step6_75Panel
from src.gui.panels.step7_panel import Step7Panel
from src.gui.panels.step8_panel import Step8Panel
from src.gui.panels.step8_5_panel import Step8_5Panel
from src.gui.panels.step8_75_panel import Step8_75Panel
from src.gui.panels.step9_panel import Step9Panel
from src.gui.panels.step8_panel import Step8Panel # was step5_5_panel
from src.gui.panels.step7_panel import Step7Panel # was step6_panel
from src.gui.panels.step8_non_empirical_panel import Step8NonEmpiricalPanel # was step6_5_panel
from src.gui.panels.step9_panel import Step9Panel # was step6_75_panel
from src.gui.panels.step10_panel import Step10Panel # was step7_panel
from src.gui.panels.step11_ml_panel import Step11MlPanel # ML prediction (step11_ml)
from src.gui.panels.step11_panel import Step11Panel # was step8_5_panel
from src.gui.panels.step12_panel import Step12Panel # was step8_75_panel
from src.gui.panels.step14_panel import Step14Panel # was step9_panel
from src.gui.dialogs import BandConfirmDialog, AISettingsDialog
from src.gui.panels.visualization_panel import VisualizationPanel
from src.gui.panels.report_generation_panel import ReportGenerationPanel
# Pipeline 核心异常(用于预检弹窗)
from src.core.pipeline.runner import PipelineHalt
# Matplotlib相关导入 (推迟并加入底层防爆保护)
import matplotlib
try:
@ -152,6 +155,9 @@ from src.gui.core.worker_thread import (
check_pipeline_dependencies,
diagnose_pipeline_import_error,
)
# 预检交互对话框
from src.gui.core.preflight_dialog import PreflightDialog
from src.gui.core.pipeline_mode_dialog import PipelineModeDialog
def _viz_training_spectra_csv_path(work_path: Path) -> Path:
@ -1384,31 +1390,31 @@ class WaterQualityGUI(QMainWindow):
'step5': {
'training_spectra': '5_training_spectra/training_spectra.csv'
},
'step5_5': {
'step8': {
'water_indices': '6_water_quality_indices/water_quality_indices.csv'
},
'step6': {
'step7': {
'models': '7_Supervised_Model_Training/' # 目录,包含各参数子目录
},
'step6_5': {
'step8_non_empirical_modeling': {
'regression_models': '8_Regression_Modeling/' # 目录,包含各参数子目录
},
'step6_75': {
'step9': {
'custom_regression_models': '9_Custom_Regression_Modeling/' # 目录
},
'step7': {
'step10': {
'sampling_points': '10_sampling/sampling_spectra.csv'
},
'step8': {
'step11_ml': {
'predictions': '11_12_13_predictions/Machine_Learning_Prediction/' # 目录,包含机器学习预测结果
},
'step8_5': {
'step11': {
'regression_predictions': '11_12_13_predictions/Non_Empirical_Prediction/' # 目录,包含非经验模型预测结果
},
'step8_75': {
'step12': {
'custom_predictions': '11_12_13_predictions/Custom_Regression_Prediction/' # 目录,包含自定义回归预测结果
},
'step9': {
'step14': {
'distribution_maps': '14_visualization/' # 目录,包含专题图
}
}
@ -1432,37 +1438,37 @@ class WaterQualityGUI(QMainWindow):
'boundary_mask_path': ('step1', 'water_mask', 'boundary_mask_file'), # 步骤5可选水体掩膜
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜
},
'step5_5': {
'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
},
'step6': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据
},
'step6_5': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.5需要训练光谱数据
},
'step6_75': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6.75需要训练光谱数据
'step8': {
'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤8需要步骤5输出的训练光谱
},
'step7': {
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤7需要去耀斑影像
'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤7需要水域掩膜
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤7可选耀斑掩膜
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤7需要训练光谱数据
},
'step8': {
'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8需要采样点
'models_dir': ('step6', 'models', 'models_dir_file') # 步骤8需要训练好的模型
},
'step8_5': {
'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.5需要采样点
'models_dir': ('step6_5', 'regression_models', 'models_dir') # 步骤8.5需要回归模型
},
'step8_75': {
'sampling_csv_path': ('step7', 'sampling_points', 'sampling_csv_file'), # 步骤8.75需要采样点
'models_dir': ('step6_75', 'custom_regression_models', 'models_dir') # 步骤8.75需要自定义回归模型
'step8_non_empirical_modeling': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤8非经验建模需要训练光谱数据
},
'step9': {
'prediction_csv_path': ('step8', 'predictions', 'prediction_csv_file') # 步骤9需要预测结果CSV
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤9需要训练光谱数据
},
'step10': {
'deglint_img_path': ('step3', 'deglint_image', 'deglint_img_file'), # 步骤10需要去耀斑影像
'water_mask_path': ('step1', 'water_mask', 'water_mask_file'), # 步骤10需要水域掩膜
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤10可选耀斑掩膜
},
'step11_ml': {
'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤11ML需要采样点
'models_dir': ('step7', 'models', 'models_dir_file') # 步骤11ML需要训练好的模型
},
'step11': {
'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤11需要采样点
'models_dir': ('step8_non_empirical_modeling', 'regression_models', 'models_dir') # 步骤11需要回归模型
},
'step12': {
'sampling_csv_path': ('step10', 'sampling_points', 'sampling_csv_file'), # 步骤12需要采样点
'models_dir': ('step9', 'custom_regression_models', 'models_dir') # 步骤12需要自定义回归模型
},
'step14': {
'prediction_csv_path': ('step11_ml', 'predictions', 'prediction_csv_file') # 步骤14需要预测结果CSV
}
}
@ -1545,7 +1551,7 @@ class WaterQualityGUI(QMainWindow):
def init_ui(self):
"""初始化UI"""
self.setWindowTitle("MegaCube-Water Quality V1.1")
self.setWindowTitle("MegaCube-Water Quality V1.2")
# 获取屏幕可用区域(排除任务栏)
screen_geometry = QApplication.primaryScreen().availableGeometry()
@ -1730,7 +1736,7 @@ class WaterQualityGUI(QMainWindow):
def create_banner_widget(self):
"""创建横幅区域 - 支持自适应等比缩放"""
# 横幅标题文字(方便后续直接修改版本号)
self._APP_TITLE = "MegaCube-Water Quality V1.1"
self._APP_TITLE = "MegaCube-Water Quality V1.2"
# 创建横幅容器
banner_widget = QWidget()
@ -1844,19 +1850,19 @@ class WaterQualityGUI(QMainWindow):
"阶段二:样本数据准备 ": [
("step4", "4. 数据标准化处理"),
("step5", "5. 光谱特征提取"),
("step5_5", "6. 水质参数指数计算"),
("step8", "6. 水质参数指数计算"),
],
"阶段三:模型构建与训练": [
("step6", "7. 机器学习模型训练"),
("step6_5", "8. 回归模型训练"),
("step6_75", "9. 自定义回归模型训练"),
("step7", "7. 机器学习模型训练"),
("step8_non_empirical_modeling", "8. 回归模型训练"),
("step9", "9. 自定义回归模型训练"),
],
"阶段四:预测与成果输出 ": [
("step7", "10. 采样点布设"),
("step8", "11. 机器学习预测"),
("step8_5", "12. 回归预测"),
("step8_75", "13. 自定义回归预测"),
("step9", "14. 专题图生成"),
("step10", "10. 采样点布设"),
("step11_ml", "11. 机器学习预测"),
("step11", "12. 回归预测"),
("step12", "13. 自定义回归预测"),
("step14", "14. 专题图生成"),
("step9_viz", "15. 可视化分析"),
("step_report", "16. 分析报告生成"),
]
@ -1878,7 +1884,7 @@ class WaterQualityGUI(QMainWindow):
self.step_list.addItem(stage_item)
# 添加该阶段的所有步骤
HIDDEN_STEP_IDS = {"step6_5", "step6_75", "step8_5", "step8_75"}
HIDDEN_STEP_IDS = {"step8_non_empirical_modeling", "step9", "step11", "step12"}
for step_id, step_display in steps:
if step_id in HIDDEN_STEP_IDS:
continue
@ -1958,36 +1964,36 @@ class WaterQualityGUI(QMainWindow):
self.step5_panel = Step5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step5_panel), QIcon(self.get_icon_path("5.png")), "特征构建")
self.step5_5_panel = Step5_5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step5_5_panel), QIcon(self.get_icon_path("5.png")), "水质指数")
self.step6_panel = Step6Panel()
self.step_stack.addTab(self.create_scroll_area(self.step6_panel), QIcon(self.get_icon_path("6.png")), "监督建模")
self.step6_5_panel = Step6_5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
self.step6_75_panel = Step6_75Panel()
self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
self.step8_panel = Step8Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("5.png")), "水质指数")
self.step7_panel = Step7Panel()
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
self.step8_panel = Step8Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_panel), QIcon(self.get_icon_path("8.png")), "监督预测")
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("6.png")), "监督建模")
self.step8_5_panel = Step8_5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
self.step8_75_panel = Step8_75Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
self.step8_non_empirical_panel = Step8NonEmpiricalPanel()
self.step_stack.addTab(self.create_scroll_area(self.step8_non_empirical_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
self.step9_panel = Step9Panel()
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
self.step10_panel = Step10Panel()
self.step_stack.addTab(self.create_scroll_area(self.step10_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
self.step11_ml_panel = Step11MlPanel() # ML prediction panel (step11_ml)
self.step_stack.addTab(self.create_scroll_area(self.step11_ml_panel), QIcon(self.get_icon_path("8.png")), "监督预测")
self.step11_panel = Step11Panel()
self.step_stack.addTab(self.create_scroll_area(self.step11_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
self.step12_panel = Step12Panel()
self.step_stack.addTab(self.create_scroll_area(self.step12_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
self.step14_panel = Step14Panel()
self.step_stack.addTab(self.create_scroll_area(self.step14_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
self.viz_panel = VisualizationPanel()
self.step_stack.addTab(self.create_scroll_area(self.viz_panel), QIcon(self.get_icon_path("9.png")), "可视化")
@ -2137,15 +2143,15 @@ class WaterQualityGUI(QMainWindow):
'step3': 2,
'step4': 3,
'step5': 4,
'step5_5': 5,
'step6': 6,
'step6_5': 7,
'step6_75': 8,
'step7': 9,
'step8': 10,
'step8_5': 11,
'step8_75': 12,
'step9': 13,
'step8': 5,
'step7': 6,
'step8_non_empirical_modeling': 7,
'step9': 8,
'step10': 9,
'step11_ml': 10,
'step11': 11,
'step12': 12,
'step14': 13,
'step9_viz': 14,
'step_report': 15,
}
@ -2168,15 +2174,15 @@ class WaterQualityGUI(QMainWindow):
2: 'step3',
3: 'step4',
4: 'step5',
5: 'step5_5',
6: 'step6',
7: 'step6_5',
8: 'step6_75',
9: 'step7',
10: 'step8',
11: 'step8_5',
12: 'step8_75',
13: 'step9',
5: 'step8',
6: 'step7',
7: 'step8_non_empirical_modeling',
8: 'step9',
9: 'step10',
10: 'step11_ml',
11: 'step11',
12: 'step12',
13: 'step14',
14: 'step9_viz',
15: 'step_report',
}
@ -2213,41 +2219,41 @@ class WaterQualityGUI(QMainWindow):
elif index == 4:
self.step5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step5_5 切换时自动填充输出路径
# Step8水质指数切换时自动填充输出路径
elif index == 5:
self.step5_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step6 切换时自动填充训练数据和输出路径
# Step7监督建模切换时自动填充训练数据和输出路径
elif index == 6:
self.step6_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step6.5(非经验回归建模)切换时自动填充训练数据和模型目录
elif index == 7:
self.step6_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step6.75(自定义回归建模)切换时自动填充训练数据和模型目录
elif index == 8:
self.step6_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step7采样点布设切换时自动填充掩膜和输出路径
elif index == 9:
self.step7_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step8非经验建模切换时自动填充训练数据和模型目录
elif index == 7:
self.step8_non_empirical_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step9自定义回归建模切换时自动填充训练数据和模型目录
elif index == 8:
self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step10采样点布设切换时自动填充掩膜和输出路径
elif index == 9:
self.step10_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step11_ml (machine-learning prediction): auto-fill the sampled spectra and model directory on tab switch
elif index == 10:
self.step8_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
self.step11_ml_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step8.5(非经验模型预测)切换时自动填充采样光谱和回归模型目录
# Step11 (regression prediction): auto-fill the sampled spectra and regression model directory on tab switch
elif index == 11:
self.step8_5_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
self.step11_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step8.75(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录
# Step12(自定义回归预测)切换时自动填充采样光谱和自定义回归模型目录
elif index == 12:
self.step8_75_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
self.step12_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step9(专题图生成)切换时自动填充预测结果目录
# Step14(专题图生成)切换时自动填充预测结果目录
elif index == 13:
self.step9_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
self.step14_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# 可视化分析面板切换时自动推断图像目录并加载目录树
elif index == 14:
@ -2294,22 +2300,22 @@ class WaterQualityGUI(QMainWindow):
self.step4_panel.set_config(config['step4'])
if 'step5' in config:
self.step5_panel.set_config(config['step5'])
if 'step5_5' in config:
self.step5_5_panel.set_config(config['step5_5'])
if 'step6' in config:
self.step6_panel.set_config(config['step6'])
if 'step6_5' in config:
self.step6_5_panel.set_config(config['step6_5'])
if 'step6_75' in config:
self.step6_75_panel.set_config(config['step6_75'])
if 'step7' in config:
self.step7_panel.set_config(config['step7'])
if 'step8' in config:
self.step8_panel.set_config(config['step8'])
if 'step8_5' in config:
self.step8_5_panel.set_config(config['step8_5'])
if 'step7' in config:
self.step7_panel.set_config(config['step7'])
if 'step8_non_empirical_modeling' in config:
self.step8_non_empirical_panel.set_config(config['step8_non_empirical_modeling'])
if 'step9' in config:
self.step9_panel.set_config(config['step9'])
if 'step10' in config:
self.step10_panel.set_config(config['step10'])
if 'step11_ml' in config:
self.step11_ml_panel.set_config(config['step11_ml'])
if 'step11' in config:
self.step11_panel.set_config(config['step11'])
if 'step14' in config:
self.step14_panel.set_config(config['step14'])
if 'visualization' in config:
self.viz_panel.set_config(config['visualization'])
if 'report_generation' in config:
@ -2352,14 +2358,14 @@ class WaterQualityGUI(QMainWindow):
'step3': self.step3_panel.get_config(),
'step4': self.step4_panel.get_config(),
'step5': self.step5_panel.get_config(),
'step5_5': self.step5_5_panel.get_config(),
'step6': self.step6_panel.get_config(),
'step6_5': self.step6_5_panel.get_config(),
'step6_75': self.step6_75_panel.get_config(),
'step7': self.step7_panel.get_config(),
'step8': self.step8_panel.get_config(),
'step8_5': self.step8_5_panel.get_config(),
'step7': self.step7_panel.get_config(),
'step8_non_empirical_modeling': self.step8_non_empirical_panel.get_config(),
'step9': self.step9_panel.get_config(),
'step10': self.step10_panel.get_config(),
'step11_ml': self.step11_ml_panel.get_config(),
'step11': self.step11_panel.get_config(),
'step14': self.step14_panel.get_config(),
'visualization': self.viz_panel.get_config(),
'report_generation': self.report_panel.get_config(),
}
@ -2410,15 +2416,15 @@ class WaterQualityGUI(QMainWindow):
'step3': self.step3_panel,
'step4': self.step4_panel,
'step5': self.step5_panel,
'step5_5': self.step5_5_panel,
'step6': self.step6_panel,
'step6_5': self.step6_5_panel,
'step6_75': self.step6_75_panel,
'step7': self.step7_panel,
'step8': self.step8_panel,
'step8_5': self.step8_5_panel,
'step8_75': self.step8_75_panel,
'step7': self.step7_panel,
'step8_non_empirical_modeling': self.step8_non_empirical_panel,
'step9': self.step9_panel,
'step10': self.step10_panel,
'step11_ml': self.step11_ml_panel,
'step11': self.step11_panel,
'step12': self.step12_panel,
'step14': self.step14_panel,
}
return panel_map.get(step_id)
@ -2426,15 +2432,38 @@ class WaterQualityGUI(QMainWindow):
"""查找指定步骤的输出文件"""
if step_id not in self.step_default_outputs:
return None
step_outputs = self.step_default_outputs[step_id]
# ★ 掩膜类型列表:这些类型只接受科学数据格式
mask_types = {'water_mask', 'glint_mask', 'boundary_mask'}
# ★ 白名单机制:只允许 .dat .tif .tiff .shp拒绝其他一切格式
scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
# ★ 临时文件关键词黑名单
tmp_keywords = ('__tmp', '_tmp')
def _is_scientific_mask(path_str):
"""白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式"""
p = Path(path_str)
name_lower = str(path_str).lower()
# 拒绝临时文件
if any(kw in name_lower for kw in tmp_keywords):
return False
# 白名单校验
return p.suffix.lower() in scientific_extensions
# 特殊处理从step_outputs记录中查找实际输出路径
if step_id in self.step_outputs:
actual_outputs = self.step_outputs[step_id]
if output_type in actual_outputs:
return actual_outputs[output_type]
candidate = actual_outputs[output_type]
# ★ 掩膜类型白名单二次校验:不在白名单内的一律拒绝
if output_type in mask_types and not _is_scientific_mask(candidate):
# 非科学格式被拒绝,不使用 step_outputs 中的值
pass
else:
return candidate
# 根据输出类型查找对应的文件
if output_type == 'water_mask':
# 水域掩膜优先查找NDWI生成的其次是shp生成的
@ -2485,19 +2514,19 @@ class WaterQualityGUI(QMainWindow):
# 扫描各个子目录
subdirs = {
'1_water_mask': 'step1',
'2_glint': 'step2',
'2_glint': 'step2',
'3_deglint': 'step3',
'4_processed_data': 'step4',
'5_training_spectra': 'step5',
'6_water_quality_indices': 'step5_5',
'7_Supervised_Model_Training': 'step6',
'8_Regression_Modeling': 'step6_5',
'9_Custom_Regression_Modeling': 'step6_75',
'10_sampling': 'step7',
'11_12_13_predictions/Machine_Learning_Prediction': 'step8',
'11_12_13_predictions/Non_Empirical_Prediction': 'step8_5',
'11_12_13_predictions/Custom_Regression_Prediction': 'step8_75',
'14_visualization': 'step9'
'6_water_quality_indices': 'step8',
'7_Supervised_Model_Training': 'step7',
'8_Regression_Modeling': 'step8_non_empirical_modeling',
'9_Custom_Regression_Modeling': 'step9',
'10_sampling': 'step10',
'11_12_13_predictions/Machine_Learning_Prediction': 'step11_ml',
'11_12_13_predictions/Non_Empirical_Prediction': 'step11',
'11_12_13_predictions/Custom_Regression_Prediction': 'step12',
'14_visualization': 'step14'
}
for subdir, step_ids in subdirs.items():
@ -2517,23 +2546,37 @@ class WaterQualityGUI(QMainWindow):
for step_id in step_ids:
if step_id not in discovered_outputs:
discovered_outputs[step_id] = {}
# ★ 掩膜文件白名单过滤:只有 .dat .tif .tiff .shp 才通过,拒绝 .hdr .xml .png 等
scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
tmp_keywords = ('__tmp', '_tmp')
def _is_scientific_mask(path_str):
"""白名单判断:拒绝 .hdr .xml 临时文件等,只接受科学数据格式"""
p = Path(path_str)
name_lower = str(path_str).lower()
if any(kw in name_lower for kw in tmp_keywords):
return False
return p.suffix.lower() in scientific_extensions
# 匹配不同的文件类型
if 'water_mask' in file_name and step_id == 'step1':
discovered_outputs[step_id]['water_mask'] = str(file_path)
if _is_scientific_mask(file_path):
discovered_outputs[step_id]['water_mask'] = str(file_path)
elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2':
discovered_outputs[step_id]['glint_mask'] = str(file_path)
if _is_scientific_mask(file_path):
discovered_outputs[step_id]['glint_mask'] = str(file_path)
elif 'deglint' in file_name and step_id == 'step3':
discovered_outputs[step_id]['deglint_image'] = str(file_path)
elif 'processed_data' in file_name and step_id == 'step4':
discovered_outputs[step_id]['processed_data'] = str(file_path)
elif 'training_spectra' in file_name and step_id == 'step5':
discovered_outputs[step_id]['training_spectra'] = str(file_path)
elif 'water_quality_indices' in file_name and step_id == 'step5_5':
elif 'water_quality_indices' in file_name and step_id == 'step8':
discovered_outputs[step_id]['water_indices'] = str(file_path)
elif 'sampling_spectra' in file_name and step_id == 'step7':
elif 'sampling_spectra' in file_name and step_id == 'step10':
discovered_outputs[step_id]['sampling_points'] = str(file_path)
elif file_name.endswith('.csv') and step_id in ['step8', 'step8_5', 'step8_75']:
elif file_name.endswith('.csv') and step_id in ['step11_ml', 'step11', 'step12']:
discovered_outputs[step_id]['predictions'] = str(file_path)
# 更新内部记录
@ -2556,8 +2599,8 @@ class WaterQualityGUI(QMainWindow):
# 首先扫描工作目录发现已有的输出文件
self.scan_work_directory_for_files(work_path)
step_order = ['step2', 'step3', 'step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75',
'step7', 'step8', 'step8_5', 'step8_75', 'step9']
step_order = ['step2', 'step3', 'step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9',
'step10', 'step11_ml', 'step11', 'step12', 'step14']
filled_count = 0
for step_id in step_order:
@ -2579,15 +2622,15 @@ class WaterQualityGUI(QMainWindow):
('step2', self.step2_panel),
('step3', self.step3_panel),
('step5', self.step5_panel),
('step5_5', self.step5_5_panel),
('step6', self.step6_panel),
('step6_5', self.step6_5_panel),
('step6_75', self.step6_75_panel),
('step7', self.step7_panel),
('step8', self.step8_panel),
('step8_5', self.step8_5_panel),
('step8_75', self.step8_75_panel),
('step9', self.step9_panel)
('step7', self.step7_panel),
('step8_non_empirical_modeling', self.step8_non_empirical_panel),
('step9', self.step9_panel),
('step10', self.step10_panel),
('step11_ml', self.step11_ml_panel),
('step11', self.step11_panel),
('step12', self.step12_panel),
('step14', self.step14_panel)
]
for step_id, panel in panels_with_dependencies:
@ -2735,7 +2778,7 @@ class WaterQualityGUI(QMainWindow):
"""显示关于对话框"""
QMessageBox.about(
self, "关于",
"MegaCube-Water Quality V1.1\n\n"
"MegaCube-Water Quality V1.2\n\n"
"一个完整的水质参数反演工作流程工具\n\n"
"功能包括:\n"
"- 水域掩膜生成\n"
@ -2858,6 +2901,41 @@ class WaterQualityGUI(QMainWindow):
return True
# ------------------------------------------------------------------
# ★ 全流程模式动态裁剪
# ------------------------------------------------------------------
def _prune_config_for_prediction_mode(self, config: dict) -> dict:
"""Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。
被禁用的 step dict 中统一写入 'enabled': False
这些配置最终传给 PipelineRunnerRunner 会跳过它们。
同时,被跳过的步骤的 required_input_files 在 build_missing_items
中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。
Args:
config: 完整配置字典(来自 get_current_config
Returns:
裁剪后的 config深拷贝原 config 不被修改)
"""
cfg = copy.deepcopy(config)
# 在每个训练相关步骤的 dict 中写入 enabled=False
training_steps = [
"step4", # CSV 实测数据清洗
"step5", # 实测点光谱提取(→ training_csv_path
"step7", # ML 监督建模
"step8", # 水质指数计算(辅助训练)
"step8_non_empirical_modeling", # 非经验回归建模
"step9", # 自定义回归建模
]
for step_id in training_steps:
step_cfg = cfg.setdefault(step_id, {})
step_cfg["enabled"] = False
return cfg
def run_full_pipeline(self):
"""运行完整流程"""
if not PIPELINE_AVAILABLE:
@ -2867,8 +2945,14 @@ class WaterQualityGUI(QMainWindow):
)
return
# ── 0) 强制获取 work_dir禁止依赖外部或全局变量 ──
work_dir = getattr(self, 'work_dir', None)
if not work_dir:
QMessageBox.warning(self, "警告", "未选择工作目录,请先设置工作目录。")
return
# ── 1) 运行前智能预检与自动回填(硬盘已有产物自动跳过) ──
work_path = Path(getattr(self, 'work_dir', './work_dir'))
work_path = Path(work_dir)
self.log_message("正在进行运行前环境预检与自动扫描...", "info")
self.scan_work_directory_for_files(work_path)
self.auto_populate_all_steps()
@ -2878,31 +2962,52 @@ class WaterQualityGUI(QMainWindow):
if not self._precheck_step3_bands():
return # 用户点"取消运行"
# ── 1.6) ★ 全流程模式选择弹窗 ──
mode_dlg = PipelineModeDialog(main_window=self, parent=self)
if mode_dlg.exec() != QDialog.Accepted:
return # 用户点"取消"
selected_mode = mode_dlg.selected_mode
self.log_message(f"[模式选择] 选定模式: {'训练新模型' if selected_mode == 'training' else '使用已有模型直接预测'}", "info")
# ── 2) 刷新配置(拿到自动填充后的"满血版" config ──
config = self.get_current_config()
# ── 3) 根基数据校验step1.img_path参考影像 ──
if not config['step1'].get('img_path'):
QMessageBox.warning(self, "警告", "缺失核心数据:请先在步骤 1 中上传【参考影像】!")
for i in range(self.step_list.count()):
item = self.step_list.item(i)
if item.data(Qt.UserRole) == 'step1':
self.step_list.setCurrentRow(i)
break
return
# ── 2.1) ★ 根据模式动态裁剪配置 ──
if selected_mode == "prediction_only":
config = self._prune_config_for_prediction_mode(config)
self.log_message("[模式选择] 已裁剪训练相关步骤step4/5/7/8进入仅预测模式", "info")
# ── 4) 软提示csv_path 缺失 → 模型训练步骤会被静默跳过(不阻断) ──
csv_path = config.get('step4', {}).get('csv_path') or config.get('step5', {}).get('csv_path')
if not csv_path:
QMessageBox.information(
self,
"提示:模型训练将被跳过",
"未检测到实测水质数据 (CSV)。\n"
"流程将自动跳过模型训练(步骤 4-6仅执行预测与制图。\n"
"如果需要训练新模型,请先在步骤 4 中上传水质数据。",
)
# ── 3) ★ 一次性全预检 + 用户交互式决策 ──
missing_items = PreflightDialog.build_missing_items(config)
if missing_items:
critical_items = [it for it in missing_items if it.is_critical]
if critical_items:
lines = "\n".join(f" - [{it.step_name}] {it.reason}" for it in critical_items)
QMessageBox.critical(
self, "预检失败(阻断性错误)",
f"以下为阻断性缺失,流程无法启动:\n\n{lines}\n\n"
"请填写后重新运行。"
)
return
dialog = PreflightDialog(missing_items, parent=self)
if dialog.exec() != QDialog.Accepted:
return
result = dialog.get_result()
if result is None:
return
action, *payload = result
if action == "fill":
_, step_id, tab_index = result
self.step_stack.setCurrentIndex(tab_index)
self.log_message(f"[预检] 用户选择填写 {step_id},已切换到对应面板。", "info")
return
skip_list: List[str] = payload[0] if payload else []
if skip_list:
self.log_message(f"[预检] 用户强制跳过 {len(skip_list)} 个步骤: {skip_list}", "info")
else:
skip_list = []
# 确认执行
# ── 4) 确认执行
reply = QMessageBox.question(
self, "确认",
"是否开始执行完整流程?\n\n这可能需要较长时间,请确保配置正确。",
@ -2913,19 +3018,18 @@ class WaterQualityGUI(QMainWindow):
return
# 创建pipeline实例
work_dir = getattr(self, 'work_dir', './work_dir')
self.log_message(f"初始化pipeline工作目录: {work_dir}", "info")
# 准备实际运行配置(排除未启用的步骤)
worker_config = copy.deepcopy(config)
step5_5_cfg = worker_config.get('step5_5')
if step5_5_cfg:
enabled = step5_5_cfg.pop('enabled', True)
step8_cfg = worker_config.get('step8')
if step8_cfg:
enabled = step8_cfg.pop('enabled', True)
if not enabled:
worker_config.pop('step5_5', None)
worker_config.pop('step8', None)
# 工作线程内创建 Pipeline避免主线程阻塞及 Qt5Agg 子线程绘图卡死
self.worker = WorkerThread(work_dir, worker_config, mode='full')
self.worker = WorkerThread(work_dir, worker_config, mode='full', skip_list=skip_list)
self.worker.log_message.connect(self.log_message, Qt.QueuedConnection)
self.worker.progress_update.connect(self.update_progress, Qt.QueuedConnection)
self.worker.step_completed.connect(self.on_step_completed, Qt.QueuedConnection)
@ -3152,14 +3256,14 @@ class WaterQualityGUI(QMainWindow):
def update_ui_for_training_mode(self):
"""根据训练数据模式更新UI状态"""
# 需要禁用的步骤ID对应无训练数据模式下需要禁用的步骤
disabled_step_ids = ['step4', 'step5', 'step5_5', 'step6', 'step6_5', 'step6_75']
disabled_step_ids = ['step4', 'step5', 'step8', 'step7', 'step8_non_empirical_modeling', 'step9']
# 更新标签页的启用/禁用状态
step_id_to_tab = {
'step1': 0, 'step2': 1, 'step3': 2, 'step4': 3,
'step5': 4, 'step5_5': 5, 'step6': 6, 'step6_5': 7,
'step6_75': 8, 'step7': 9, 'step8': 10, 'step8_5': 11,
'step8_75': 12, 'step9': 13, 'step9_viz': 14
'step5': 4, 'step8': 5, 'step7': 6, 'step8_non_empirical_modeling': 7,
'step9': 8, 'step10': 9, 'step11_ml': 10, 'step11': 11,
'step12': 12, 'step14': 13, 'step9_viz': 14
}
for step_id in disabled_step_ids: