diff --git a/src/core/handlers/__init__.py b/src/core/handlers/__init__.py index 0e9d3d6..fe2a70d 100644 --- a/src/core/handlers/__init__.py +++ b/src/core/handlers/__init__.py @@ -12,9 +12,21 @@ from src.core.handlers.base import BaseStepHandler, PipelineContext from src.core.handlers.step1_water_mask import Step1WaterMaskHandler +from src.core.handlers.step2_glint_detection import Step2GlintDetectionHandler +from src.core.handlers.step3_glint_removal import Step3GlintRemovalHandler +from src.core.handlers.step4_sampling import Step4SamplingHandler +from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler +from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler +from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler __all__ = [ 'BaseStepHandler', 'PipelineContext', 'Step1WaterMaskHandler', + 'Step2GlintDetectionHandler', + 'Step3GlintRemovalHandler', + 'Step4SamplingHandler', + 'Step5ProcessCsvHandler', + 'Step6ExtractSpectraHandler', + 'Step7CalcIndicesHandler', ] diff --git a/src/core/handlers/register_handlers.py b/src/core/handlers/register_handlers.py new file mode 100644 index 0000000..2793648 --- /dev/null +++ b/src/core/handlers/register_handlers.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Handler 注册辅助函数 + +将所有步骤 Handler 一次性注册到 PipelineScheduler。 +新增步骤只需在此函数中加一行 register_handler() 调用。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from src.core.handlers.step1_water_mask import Step1WaterMaskHandler +from src.core.handlers.step2_glint_detection import Step2GlintDetectionHandler +from src.core.handlers.step3_glint_removal import Step3GlintRemovalHandler +from src.core.handlers.step4_sampling import Step4SamplingHandler +from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler +from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler +from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler + +if TYPE_CHECKING: + from src.core.handlers.pipeline_scheduler import PipelineScheduler + + +def register_all_handlers(scheduler: PipelineScheduler): + """将所有已实现的步骤 Handler 注册到调度器。 + + 用法:: + + scheduler = PipelineScheduler(work_dir="./work_dir") + register_all_handlers(scheduler) + result = scheduler.run_full_pipeline(config) + + 新增步骤时,在此函数中追加一行 register_handler() 即可。 + """ + scheduler.register_handler(Step1WaterMaskHandler()) + scheduler.register_handler(Step2GlintDetectionHandler()) + scheduler.register_handler(Step3GlintRemovalHandler()) + scheduler.register_handler(Step4SamplingHandler()) + scheduler.register_handler(Step5ProcessCsvHandler()) + scheduler.register_handler(Step6ExtractSpectraHandler()) + scheduler.register_handler(Step7CalcIndicesHandler()) diff --git a/src/core/handlers/step2_glint_detection.py b/src/core/handlers/step2_glint_detection.py new file mode 100644 index 0000000..5745c47 --- /dev/null +++ b/src/core/handlers/step2_glint_detection.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step2 处理器:耀斑区域检测 + +将原 WaterQualityInversionPipeline.step2_find_glint_area() 方法 +剥离为独立的 Step2GlintDetectionHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.glint_detection_step import GlintDetectionStep + + +class Step2GlintDetectionHandler(BaseStepHandler): + """步骤2:耀斑区域检测。 + + 对应 config key: 'step2' + 委托类: GlintDetectionStep.run() + """ + + step_key = 'step2' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + water_mask_path = self._resolve_path( + config.get('water_mask_path'), context.water_mask_path, 'water_mask' + ) + + try: + result = GlintDetectionStep.run( + img_path=config.get('img_path'), + glint_wave=config.get('glint_wave', 750.0), + method=config.get('method', 'otsu'), + z_threshold=config.get('z_threshold', 2.5), + percentile=config.get('percentile', 95.0), + iqr_multiplier=config.get('iqr_multiplier', 1.5), + window_size=config.get('window_size', 15), + multi_band_waves=config.get('multi_band_waves'), + sub_method=config.get('sub_method', 'zscore'), + weights=config.get('weights'), + max_area=config.get('max_area'), + buffer_size=config.get('buffer_size'), + water_mask_path=water_mask_path, + glint_dir=str(context.glint_dir), + callback=context.notify, + ) + + context.glint_mask_path = result + + step_end_time = time.time() + context.record_step_time( + "步骤2: 耀斑区域检测", step_start_time, step_end_time + ) + + return {'glint_mask_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤2: 耀斑区域检测", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/core/handlers/step3_glint_removal.py b/src/core/handlers/step3_glint_removal.py new file mode 100644 index 0000000..b8daaae --- /dev/null +++ b/src/core/handlers/step3_glint_removal.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step3 处理器:耀斑去除 + +将原 WaterQualityInversionPipeline.step3_remove_glint() 方法 +剥离为独立的 Step3GlintRemovalHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.glint_removal_step import GlintRemovalStep + + +class Step3GlintRemovalHandler(BaseStepHandler): + """步骤3:耀斑去除。 + + 对应 config key: 'step3' + 委托类: GlintRemovalStep.run() + """ + + step_key = 'step3' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + water_mask_path = self._resolve_path( + config.get('water_mask_path'), context.water_mask_path, 'water_mask' + ) + + try: + result = GlintRemovalStep.run( + img_path=config.get('img_path'), + method=config.get('method', 'subtract_nir'), + start_wave=config.get('start_wave'), + end_wave=config.get('end_wave'), + json_path=config.get('json_path'), + left_shoulder_wave=config.get('left_shoulder_wave'), + valley_wave=config.get('valley_wave'), + right_shoulder_wave=config.get('right_shoulder_wave'), + water_mask=water_mask_path, + interpolate_zeros=config.get('interpolate_zeros', False), + interpolation_method=config.get('interpolation_method', 'nearest'), + enabled=config.get('enabled', True), + kutser_shp_path=config.get('kutser_shp_path'), + oxy_band=config.get('oxy_band', 38), + lower_oxy=config.get('lower_oxy', 36), + upper_oxy=config.get('upper_oxy', 49), + nir_band=config.get('nir_band', 47), + nir_lower=config.get('nir_lower', 25), + nir_upper=config.get('nir_upper', 37), + goodman_A=config.get('goodman_A', 0.000019), + goodman_B=config.get('goodman_B', 0.1), + hedley_shp_path=config.get('hedley_shp_path'), + hedley_nir_band=config.get('hedley_nir_band', 47), + sugar_bounds=config.get('sugar_bounds'), + sugar_sigma=config.get('sugar_sigma', 1.0), + sugar_estimate_background=config.get('sugar_estimate_background', True), + sugar_glint_mask_method=config.get('sugar_glint_mask_method', 'cdf'), + sugar_iter=config.get('sugar_iter', 3), + sugar_termination_thresh=config.get('sugar_termination_thresh', 20.0), + deglint_dir=str(context.deglint_dir), + water_mask_dir=str(context.water_mask_dir), + callback=context.notify, + output_path=config.get('output_path'), + ) + + context.deglint_img_path = result + + step_end_time = time.time() + context.record_step_time( + "步骤3: 耀斑去除", step_start_time, step_end_time + ) + + return {'deglint_img_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤3: 耀斑去除", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/core/handlers/step4_sampling.py b/src/core/handlers/step4_sampling.py new file mode 100644 index 0000000..9a3834c --- /dev/null +++ b/src/core/handlers/step4_sampling.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step4 处理器:预测采样点生成 + +将原 WaterQualityInversionPipeline.step4_sampling() 方法 +剥离为独立的 Step4SamplingHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.prediction_step import PredictionStep + + +class Step4SamplingHandler(BaseStepHandler): + """步骤4:生成预测采样点并提取光谱。 + + 对应 config key: 'step4_sampling' + 委托类: PredictionStep.generate_sampling_points() + """ + + step_key = 'step4_sampling' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + deglint_img_path = self._resolve_path( + config.get('deglint_img_path'), context.deglint_img_path, 'deglint_img' + ) + water_mask_path = self._resolve_path( + config.get('water_mask_path'), context.water_mask_path, 'water_mask' + ) + glint_mask_path = self._resolve_path( + config.get('glint_mask_path'), context.glint_mask_path, 'glint_mask' + ) + + try: + result = PredictionStep.generate_sampling_points( + deglint_img_path=deglint_img_path, + interval=config.get('interval', 50), + sample_radius=config.get('sample_radius', 5), + chunk_size=config.get('chunk_size', 1000), + water_mask_path=water_mask_path, + glint_mask_path=glint_mask_path, + output_dir=str(context.sampling_dir), + use_adaptive_sampling=config.get('use_adaptive_sampling', True), + ) + + step_end_time = time.time() + context.record_step_time( + "步骤4: 生成预测采样点", step_start_time, step_end_time + ) + + return {'sampling_csv_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤4: 生成预测采样点", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/core/handlers/step5_process_csv.py b/src/core/handlers/step5_process_csv.py new file mode 100644 index 0000000..233613a --- /dev/null +++ b/src/core/handlers/step5_process_csv.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step5 处理器:CSV 数据处理 + +将原 WaterQualityInversionPipeline.step5_process_csv() 方法 +剥离为独立的 Step5ProcessCsvHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.data_preparation_step import DataPreparationStep + + +class Step5ProcessCsvHandler(BaseStepHandler): + """步骤5:处理 CSV 文件,筛选剔除异常值。 + + 对应 config key: 'step5_clean' + 委托类: DataPreparationStep.process_csv() + """ + + step_key = 'step5_clean' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + try: + result = DataPreparationStep.process_csv( + csv_path=config.get('csv_path'), + output_dir=str(context.processed_data_dir), + ) + + context.processed_csv_path = result + + step_end_time = time.time() + context.record_step_time( + "步骤5: 处理CSV文件", step_start_time, step_end_time + ) + + return {'processed_csv_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤5: 处理CSV文件", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/core/handlers/step6_extract_spectra.py b/src/core/handlers/step6_extract_spectra.py new file mode 100644 index 0000000..ee53208 --- /dev/null +++ b/src/core/handlers/step6_extract_spectra.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step6 处理器:训练样本点光谱提取 + +将原 WaterQualityInversionPipeline.step6_extract_spectra() 方法 +剥离为独立的 Step6ExtractSpectraHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.data_preparation_step import DataPreparationStep + + +class Step6ExtractSpectraHandler(BaseStepHandler): + """步骤6:根据采样点坐标在去耀斑影像中提取平均光谱。 + + 对应 config key: 'step6_feature' + 委托类: DataPreparationStep.extract_training_spectra() + """ + + step_key = 'step6_feature' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + deglint_img_path = self._resolve_path( + config.get('deglint_img_path'), context.deglint_img_path, 'deglint_img' + ) + csv_path = self._resolve_path( + config.get('csv_path'), context.processed_csv_path, 'csv' + ) + glint_mask_path = self._resolve_path( + config.get('glint_mask_path'), context.glint_mask_path, 'glint_mask' + ) + + try: + result = DataPreparationStep.extract_training_spectra( + deglint_img_path=deglint_img_path, + radius=config.get('radius', 5), + source_epsg=config.get('source_epsg', 4326), + csv_path=csv_path, + boundary_path=config.get('boundary_path'), + glint_mask_path=glint_mask_path, + water_mask_path=context.water_mask_path, + output_dir=str(context.training_spectra_dir), + ) + + context.training_csv_path = result + + step_end_time = time.time() + context.record_step_time( + "步骤6: 提取训练样本点光谱", step_start_time, step_end_time + ) + + return {'training_csv_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤6: 提取训练样本点光谱", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/core/handlers/step7_calc_indices.py b/src/core/handlers/step7_calc_indices.py new file mode 100644 index 0000000..cf1ddd7 --- /dev/null +++ b/src/core/handlers/step7_calc_indices.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Step7 处理器:水质光谱指数计算 + +将原 WaterQualityInversionPipeline.step7_calc_indices() 方法 +剥离为独立的 Step7CalcIndicesHandler。 +""" + +import time +from typing import Any, Dict + +from src.core.handlers.base import BaseStepHandler, PipelineContext +from src.core.steps.data_preparation_step import DataPreparationStep + + +class Step7CalcIndicesHandler(BaseStepHandler): + """步骤7:根据训练光谱计算水质光谱指数。 + + 对应 config key: 'step7_index' + 委托类: DataPreparationStep.calculate_water_quality_indices() + """ + + step_key = 'step7_index' + + def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]: + step_start_time = time.time() + + training_csv_path = self._resolve_path( + config.get('training_csv_path'), context.training_csv_path, 'training_csv' + ) + + try: + result = DataPreparationStep.calculate_water_quality_indices( + training_csv_path=training_csv_path, + formula_csv_file=config.get('formula_csv_file'), + formula_names=config.get('formula_names'), + output_file=config.get('output_file'), + enabled=config.get('enabled', True), + output_dir=str(context.indices_dir), + ) + + context.indices_path = result + + step_end_time = time.time() + context.record_step_time( + "步骤7: 计算水质光谱指数", step_start_time, step_end_time + ) + + return {'indices_path': result} + + except Exception as e: + step_end_time = time.time() + context.record_step_time( + "步骤7: 计算水质光谱指数", step_start_time, step_end_time, + status="failed", error=str(e) + ) + raise diff --git a/src/gui/core/worker_thread.py b/src/gui/core/worker_thread.py index b8c30f2..62da5b7 100644 --- a/src/gui/core/worker_thread.py +++ b/src/gui/core/worker_thread.py @@ -6,8 +6,9 @@ import os import traceback from typing import Dict, List from PyQt5.QtCore import QThread, pyqtSignal -from src.core.pipeline.runner import PipelineRunner, PipelineHalt -from src.core.pipeline.context import PipelineContext +from src.core.pipeline.runner import PipelineHalt +from src.core.handlers.pipeline_scheduler import PipelineScheduler +from src.core.handlers.register_handlers import register_all_handlers # ============================================================================= @@ -113,9 +114,10 @@ PIPELINE_ERROR_INFO = [] try: error_info = diagnose_pipeline_import_error() - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline + from src.core.handlers.pipeline_scheduler import PipelineScheduler + from src.core.handlers.register_handlers import register_all_handlers PIPELINE_AVAILABLE = True - print("[OK] 成功导入pipeline模块") + print("[OK] 成功导入 Handler 调度器模块") PIPELINE_ERROR_INFO = error_info except ImportError as e: @@ -140,12 +142,11 @@ except ImportError as e: print(" 2. 如果需要修复,可以在.spec文件中添加unittest模块:") print(" a = Analysis(..., hiddenimports=['unittest', 'unittest.mock'])") print(" 3. 或在PyInstaller命令中添加: --hidden-import unittest") - elif "water_quality_inversion_pipeline_GUI" in str(e): + elif "handlers" in str(e) or "pipeline_scheduler" in str(e): print("[INFO] 可能的解决方案:") - print(" 1. 检查src/core/water_quality_inversion_pipeline_GUI.py文件是否存在") - print(" 2. 确保Python路径设置正确") + print(" 1. 检查 src/core/handlers/ 目录是否存在") + print(" 2. 确保 Python 路径设置正确") print(" 3. 尝试重新安装依赖: pip install -r requirements.txt") - print(" 4. 检查Python版本是否兼容(推荐Python 3.8-3.11)") import traceback print("\n完整错误追踪:") @@ -257,35 +258,32 @@ class WorkerThread(QThread): except Exception: mpl_prev = None try: - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline - - self.pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir) + # ── 新架构:PipelineScheduler + Handler 注册表 ── + scheduler = PipelineScheduler(work_dir=self.work_dir) + scheduler.set_callback(self.pipeline_callback) + register_all_handlers(scheduler) + self.pipeline = scheduler # 保持兼容(stop() 等引用 self.pipeline) if self.mode == 'full': - self.log_message.emit("开始运行完整流程 (Runner 调度模式)...", "info") - if hasattr(self.pipeline, 'set_callback'): - self.pipeline.set_callback(self.pipeline_callback) + self.log_message.emit("开始运行完整流程 (Handler 调度模式)...", "info") # ── ★ 预检已由 GUI 层 perform_preflight() 完成,此处不再重复预检 ── - # 构造上下文 (Ctx),将 config 整体注入 user_config - ctx = PipelineContext( - img_path=self.config.get('step1', {}).get('img_path'), - water_mask_path=self.config.get('step1', {}).get('mask_path'), - csv_path=self.config.get('step4_sampling', {}).get('csv_path'), - boundary_path=self.config.get('step5_clean', {}).get('boundary_path'), - boundary_shp_path=self.config.get('step11_map', {}).get('boundary_shp_path'), - formula_csv_path=self.config.get('step8_non_empirical_modeling', {}).get('formula_csv_path'), - work_dir=self.work_dir, - user_config=self.config - ) + # 过滤 skip_list 中的步骤 + active_config = { + k: v for k, v in self.config.items() + if k not in self.skip_list + } - # 启动新调度器 - runner = PipelineRunner(self.pipeline) - result_ctx = runner.run(ctx, config=self.config, skip_list=self.skip_list) + result = scheduler.run_full_pipeline(active_config) - if result_ctx.last_error: - raise RuntimeError(f"流水线执行失败: {result_ctx.last_error}") + errors = result.get('errors', {}) + if errors: + error_lines = [f" {k}: {v}" for k, v in errors.items()] + raise RuntimeError( + f"流水线部分步骤执行失败 ({len(errors)} 个):\n" + + "\n".join(error_lines) + ) self.progress_update.emit(100, "流程执行完成") self.finished.emit(True, "完整流程执行成功!") @@ -293,10 +291,7 @@ class WorkerThread(QThread): self.log_message.emit(f"开始独立运行步骤: {self.step_name}", "info") self.progress_update.emit(0, f"正在执行: {self.step_name}") - if hasattr(self.pipeline, 'set_callback'): - self.pipeline.set_callback(self.pipeline_callback) - - self.run_single_step(self.step_name, self.config) + self.run_single_step(scheduler, self.step_name, self.config) self.progress_update.emit(100, f"步骤 {self.step_name} 执行完成") self.finished.emit(True, f"步骤 {self.step_name} 独立运行成功!") @@ -317,56 +312,24 @@ class WorkerThread(QThread): except Exception: pass - def run_single_step(self, step_name, config): - """运行单个步骤""" - step_method_map = { - 'step1': 'step1_generate_water_mask', - 'step2': 'step2_find_glint_area', - 'step3': 'step3_remove_glint', - 'step4_sampling': 'step4_sampling', - 'step5_clean': 'step5_process_csv', - 'step6_feature': 'step6_extract_spectra', - 'step7_index': 'step7_calc_indices', - 'step8_ml_train': 'step8_train_ml', - 'step8_non_empirical_modeling': 'step8_non_empirical_modeling', - 'step8_qaa': 'step8_qaa_inversion', - 'step9_ml_predict': 'step9_predict_ml', - 'step10_watercolor': 'step9_watercolor_inversion', - 'step11_map': 'step10_map', - } - - if step_name not in step_method_map: - raise ValueError(f"未知的步骤名称: {step_name}") - - method_name = step_method_map[step_name] + def run_single_step(self, scheduler, step_name, config): + """使用新调度器运行单个步骤。""" step_config = dict(config.get(step_name, {})) - # step8_qaa_inversion 内部使用 config.get('step8_qaa', {}) 读取内层, - # 必须透传完整 config dict(含外层 step_name key) - if step_name == 'step8_qaa': - method = getattr(self.pipeline, method_name) - result = method(**config) - return result - - # 透传面板顶层传入的外部预训练模型(GUI step11_prediction_panel 通过 config['_external_model'] 传入) - # 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则) + # 透传外部预训练模型(非空才覆盖) for key in ('_external_model', '_external_model_path', '_external_models_dict', '_external_model_dir'): val = config.get(key) if val is not None and val != "": step_config[key] = val - if key == '_external_models_dict': - print(f"[Worker] 提取到的外部字典 Keys: {list(val.keys())}") - else: - print(f"[Worker] 透传 {key}: {val}") step_config['skip_dependency_check'] = True - if step_name in ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step7_index', 'step9_ml_predict']: - step_config.pop('output_path', None) - - method = getattr(self.pipeline, method_name) - result = method(**step_config) + # step8_qaa 特殊处理:透传完整 config(含外层 step8_qaa key) + if step_name == 'step8_qaa': + result = scheduler.run_step(step_name, config) + else: + result = scheduler.run_step(step_name, step_config) return result