diff --git a/src/core/pipeline/__init__.py b/src/core/pipeline/__init__.py new file mode 100644 index 0000000..3a3d22c --- /dev/null +++ b/src/core/pipeline/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +""" +Pipeline 调度核心:基于 Context 的内存级依赖注入。 + +设计目标: + - 用 PipelineContext 替代 dict 散落传参(9 步主路径 + 14 个 step 共享同一份 ctx) + - 14 个 step 声明式描述(StepSpec),便于 Web / 异步 / 单元测试复用 + - 不绑定具体 Pipeline 实现(duck-typed),WorkerThread / Web API / 单测可共用 +""" + +from .context import PipelineContext +from .runner import StepSpec, PIPELINE_STEPS, PipelineRunner + +__all__ = ["PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner"] diff --git a/src/core/pipeline/context.py b/src/core/pipeline/context.py new file mode 100644 index 0000000..2bf8d29 --- /dev/null +++ b/src/core/pipeline/context.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +""" +PipelineContext:内存级数据载体,跨 14 个 step 传递路径与元信息。 + +设计原则: + - 所有路径字段以 `_path` 为后缀(与 step 方法形参命名约定一致) + - 字段值可缺省(None),由 StepSpec.requires 在调度时注入 + - dataclass + field(default_factory=dict) 支持原地增删 + - 不放 GUI 状态(避免循环依赖) + - 不绑具体 step 方法(duck-typed cancellation / log append) +""" + +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class PipelineContext: + """流水线运行上下文(在 14 个 step 之间传递的内存字典) + + 字段命名约定: + - 主路径字段统一 `_path` 后缀(如 water_mask_path) + - 目录类字段无 `_path` 后缀(如 models_dir) + - 元信息字段无后缀(如 user_config / status / log) + """ + + # ── 9 步主路径(按 step 输出顺序排列) ── + raw_img_path: Optional[str] = None # Step 1 入参:原始影像 + water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入 + glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入 + deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入 + raw_csv_path: Optional[str] = None # Step 4 入:原始 CSV + processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入 + training_spectra_path: Optional[str] = None # Step 5 出 → Step 6 入 + indices_path: Optional[str] = None # Step 5.5 出 + sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/9 入 + prediction_csv_path: Optional[str] = None # Step 8 出 + distribution_map_path: Optional[str] = None # Step 9 出 + + # ── 目录类(命名不带 _path 以示区别) ── + models_dir: Optional[str] = None + prediction_dir: Optional[str] = None + work_dir: Optional[str] = None + + # ── Step 6 训练产物(AutoML 模式有,常规模式为空) ── + model_files: List[str] = field(default_factory=list) + + # ── 元信息(三件套:用户传的配置 / 取消事件 / 状态) ── + user_config: Dict[str, Any] = field(default_factory=dict) + cancel_event: Optional[Any] = None # duck-typed threading.Event / asyncio.Event + status: Dict[str, str] = field(default_factory=dict) # {step_id: 'start'/'completed'/'skipped'/'error'} + log: List[str] = field(default_factory=list) + + # ── 诊断 ── + step_timings: Dict[str, float] = field(default_factory=dict) + pipeline_start_time: Optional[float] = None + pipeline_end_time: Optional[float] = None + last_error: Optional[str] = None + + # ============================================================ + # 读写辅助 + # ============================================================ + + def set(self, key: str, value: Any) -> None: + """原地写入任意属性。 + + 允许动态字段(如 'report_path')直接挂在 __dict__ 上, + 避免因静态字段缺失而抛 AttributeError。 + """ + object.__setattr__(self, key, value) + + def get(self, key: str, default: Any = None) -> Any: + """原地读出,缺 key 不抛错。""" + return getattr(self, key, default) + + def is_cancelled(self) -> bool: + """统一软取消检查入口(duck-typed)。 + + 支持: + - threading.Event(.is_set()) + - asyncio.Event(loop-bound,is_set 同步接口存在) + - 自定义 .is_set() / .cancelled 属性 + """ + ev = self.cancel_event + if ev is None: + return False + is_set = getattr(ev, "is_set", None) + if callable(is_set): + return bool(is_set()) + return bool(getattr(ev, "cancelled", False)) + + def append_log(self, msg: str) -> None: + """写入日志列表(也用于主进程 stdout 调试)。""" + self.log.append(msg) diff --git a/src/core/pipeline/runner.py b/src/core/pipeline/runner.py new file mode 100644 index 0000000..0d758cd --- /dev/null +++ b/src/core/pipeline/runner.py @@ -0,0 +1,283 @@ +# -*- coding: utf-8 -*- +""" +PipelineRunner:基于 StepSpec 声明式调度 14 个 step。 + +设计要点: + - StepSpec 声明 requires(ctx 字段名列表)+ produces(ctx 字段名列表) + - 默认约定:ctx 字段名去掉 `_path` 后缀 = step 方法形参名 + 例:ctx.water_mask_path → 形参 water_mask + 例:ctx.raw_img_path → 形参 raw_img + - 可被 spec.parameter_map 覆盖 + - 调度顺序:按 PIPELINE_STEPS 列表顺序,requires 缺则 skip + - 软取消:在每个 step 前检查 ctx.is_cancelled() + - duck-typed pipeline:runner 只调 getattr(pipeline, method_name),不强依赖类层级 +""" + +from __future__ import annotations +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence + +from .context import PipelineContext + + +# ============================================================ +# StepSpec 声明式描述 +# ============================================================ + +@dataclass +class StepSpec: + """单个 step 的元信息(声明式,避免硬编码)""" + step_id: str + method_name: str + requires: List[str] # PipelineContext 字段名列表 + produces: List[str] = field(default_factory=list) # 写入 ctx 的字段名列表 + enabled: bool = True + parameter_map: Dict[str, str] = field(default_factory=dict) + # 当 requires 中任一字段为 None 时是否跳过;默认 True(缺输入就 skip) + skip_when_missing: bool = True + # 备注(仅用于文档生成 / 调试输出) + description: str = "" + + +# ============================================================ +# 14 个 step 的声明表(顺序即调度顺序) +# 注:本表是"权威描述",与 WorkerThread.step_method_map / 旧 run_full_pipeline 保持一致 +# ============================================================ + +PIPELINE_STEPS: List[StepSpec] = [ + StepSpec( + step_id="step1", method_name="step1_generate_water_mask", + requires=["raw_img_path"], produces=["water_mask_path"], + # ctx.raw_img_path → 形参 img_path(老 step1 形参名是 img_path,不是 raw_img) + parameter_map={"raw_img_path": "img_path"}, + description="水域掩膜生成(NDWI 或 SHP)", + ), + StepSpec( + step_id="step2", method_name="step2_find_glint_area", + requires=["raw_img_path", "water_mask_path"], produces=["glint_mask_path"], + # raw_img_path→img_path;water_mask_path 不变 + parameter_map={"raw_img_path": "img_path"}, + description="耀斑区域检测", + ), + StepSpec( + step_id="step3", method_name="step3_remove_glint", + requires=["deglint_img_path"], produces=["deglint_img_path"], + # deglint_img_path→img_path(老 step3 形参名是 img_path) + # 注意:glint_mask_path 不在 requires 中——step3 形参表无该参数,内部走 self.glint_mask_path 回退 + parameter_map={"deglint_img_path": "img_path"}, + description="耀斑去除", + ), + StepSpec( + step_id="step4", method_name="step4_process_csv", + requires=["raw_csv_path"], produces=["processed_csv_path"], + # raw_csv_path→csv_path(老 step4 形参名是 csv_path) + parameter_map={"raw_csv_path": "csv_path"}, + description="CSV 异常值清洗", + ), + StepSpec( + step_id="step5", method_name="step5_extract_training_spectra", + requires=["deglint_img_path", "processed_csv_path"], produces=["training_spectra_path"], + # processed_csv_path→csv_path(老 step5 形参名是 csv_path);deglint_img_path 不变 + parameter_map={"processed_csv_path": "csv_path"}, + description="实测样本点光谱提取", + ), + StepSpec( + step_id="step5_5", method_name="step5_5_calculate_water_quality_indices", + requires=["training_spectra_path"], produces=["indices_path"], + # 老 step5.5 形参是 training_spectra_path;ctx 字段同名,无需映射 + parameter_map={}, + description="水质光谱指数计算(optional)", + ), + StepSpec( + step_id="step6", method_name="step6_train_models", + requires=["training_spectra_path"], produces=["models_dir"], + # training_spectra_path→training_csv_path(老 step6 形参名是 training_csv_path) + parameter_map={"training_spectra_path": "training_csv_path"}, + description="ML 建模(GridSearchCV / AutoML)", + ), + StepSpec( + step_id="step6_5", method_name="step6_5_non_empirical_modeling", + requires=["training_spectra_path"], produces=["models_dir"], + # training_spectra_path→csv_path(老 step6.5 形参名是 csv_path) + parameter_map={"training_spectra_path": "csv_path"}, + description="非经验统计回归", + ), + StepSpec( + step_id="step6_75", method_name="step6_75_custom_regression", + requires=["training_spectra_path"], produces=["models_dir"], + # training_spectra_path→csv_path(老 step6.75 形参名是 csv_path) + parameter_map={"training_spectra_path": "csv_path"}, + description="自定义回归分析", + ), + StepSpec( + step_id="step7", method_name="step7_generate_sampling_points", + requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"], + # 老 step7 形参是 deglint_img_path / water_mask_path;ctx 字段同名 + parameter_map={}, + description="整景密集采样点生成 + 光谱提取", + ), + StepSpec( + step_id="step8", method_name="step8_predict_water_quality", + requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"], + parameter_map={}, + description="ML 模型预测(采样点)", + ), + StepSpec( + step_id="step8_5", method_name="step8_5_predict_with_non_empirical_models", + requires=["sampling_csv_path"], produces=["prediction_dir"], + parameter_map={}, + description="非经验模型预测", + ), + StepSpec( + step_id="step8_75", method_name="step8_75_predict_with_custom_regression", + requires=["sampling_csv_path"], produces=["prediction_dir"], + parameter_map={}, + description="自定义回归预测", + ), + StepSpec( + step_id="step9", method_name="step9_generate_distribution_map", + requires=["prediction_csv_path"], + produces=["distribution_map_path"], + # 老 step9 形参是 prediction_csv_path / boundary_shp_path;ctx 字段同名 + # 注意:sampling_csv_path / water_mask_path 不在 requires 中——step9 形参表无该参数, + # 内部走 self.sampling_csv_path / self.water_mask_path 回退 + parameter_map={}, + description="克里金插值成图", + ), +] + + +# ============================================================ +# PipelineRunner:执行者 +# ============================================================ + +class PipelineRunner: + """按 StepSpec 调度 14 个 step 方法,支持软取消 + 路径 ctx 注入。 + + 用法: + runner = PipelineRunner(pipeline_instance) + ctx = PipelineContext(raw_img_path=..., ...) + result_ctx = runner.run(ctx) + """ + + def __init__(self, pipeline, steps: Optional[Sequence[StepSpec]] = None): + self.pipeline = pipeline + self.steps: List[StepSpec] = list(steps) if steps else list(PIPELINE_STEPS) + + def run(self, ctx: PipelineContext) -> PipelineContext: + """主入口:按顺序执行 14 步。软取消时已完成的 step 保留结果。""" + ctx.pipeline_start_time = time.time() + for spec in self.steps: + if ctx.is_cancelled(): + ctx.append_log(f"[RUNNER] 收到取消信号,提前终止 @ {spec.step_id}") + break + if not spec.enabled: + ctx.status[spec.step_id] = "skipped" + ctx.append_log(f"[RUNNER] {spec.step_id} 标记为 disabled,跳过") + continue + if spec.skip_when_missing: + missing = [k for k in spec.requires if not ctx.get(k)] + if missing: + ctx.status[spec.step_id] = "skipped" + ctx.append_log( + f"[RUNNER] {spec.step_id} 缺少入参: {missing},跳过" + ) + continue + self._invoke(spec, ctx) + ctx.pipeline_end_time = time.time() + return ctx + + # ------------------------------------------------------------------ + def _invoke(self, spec: StepSpec, ctx: PipelineContext) -> None: + """调一个 step 方法:ctx 路径 → 形参;产出 → ctx 字段。""" + # DEBUG: 诊断"停在 step4"问题——每步打印 requires + ctx 实际数据 + # 看到 requires=[] 但 actual=[None,...] 就说明 ctx 缺料,step 会被 skip + ctx.append_log( + f"[DEBUG] Step {spec.step_id} requires: {spec.requires}, " + f"actual ctx data: {[ctx.get(k) for k in spec.requires]}" + ) + method = getattr(self.pipeline, spec.method_name, None) + if method is None: + ctx.append_log(f"[RUNNER] 步骤方法缺失: {spec.method_name}(跳过)") + ctx.status[spec.step_id] = "skipped" + return + + # 1) 把 ctx 路径作为形参注入(默认约定:去 _path 后缀) + kwargs: Dict[str, Any] = {} + for ctx_key in spec.requires: + param_name = spec.parameter_map.get(ctx_key, self._default_param_name(ctx_key)) + kwargs[param_name] = ctx.get(ctx_key) + + # 2) 允许用户在 ctx.user_config[step_id] 覆盖/补充 + user_overrides = ctx.user_config.get(spec.step_id) or {} + if isinstance(user_overrides, dict): + kwargs.update(user_overrides) + + # 3) 状态置 start + ctx.append_log( + f"[RUNNER] -> {spec.method_name}({list(kwargs.keys())})" + ) + ctx.status[spec.step_id] = "start" + notify = getattr(self.pipeline, "_notify", None) + if callable(notify): + try: + notify(f"步骤{spec.step_id[-1]}", "start", spec.method_name) + except Exception: + pass + + # 4) 执行 + 捕获异常(不让单步崩溃拖垮 runner) + t0 = time.time() + try: + result = method(**kwargs) + ctx.status[spec.step_id] = "completed" + ctx.step_timings[spec.step_id] = time.time() - t0 + + # 5) 产出收割 + self._harvest(spec, result, ctx) + + if callable(notify): + try: + notify( + f"步骤{spec.step_id[-1]}", + "completed", + str(result)[:200] if result is not None else "", + ) + except Exception: + pass + except Exception as exc: + ctx.status[spec.step_id] = "error" + ctx.last_error = f"{spec.step_id}: {exc!r}" + ctx.append_log(f"[RUNNER] {spec.step_id} 异常: {exc!r}") + if callable(notify): + try: + notify(f"步骤{spec.step_id[-1]}", "error", str(exc)) + except Exception: + pass + + # ------------------------------------------------------------------ + def _harvest(self, spec: StepSpec, result: Any, ctx: PipelineContext) -> None: + """把 step 方法返回值灌入 ctx 的 produces 字段。 + + 规则: + - 若 result 是 dict 且 key 匹配 produce_key:ctx.set(produce_key, result[key]) + - 若 result 非 dict 且 produces 非空:第一个 produces 字段接 result + - 若 produces 为空:result 仅记录到 log,不写 ctx + """ + if not spec.produces: + return + if isinstance(result, dict): + for produce_key in spec.produces: + if produce_key in result: + ctx.set(produce_key, result[produce_key]) + elif result is not None: + ctx.set(spec.produces[0], result) + + # ------------------------------------------------------------------ + @staticmethod + def _default_param_name(ctx_key: str) -> str: + """ + 废弃有毒的去 _path 后缀逻辑。 + 默认原样返回 ctx 键名作为形参名。遇到特殊缩写时,由各个 step 的 parameter_map 显式处理。 + """ + return ctx_key