# -*- coding: utf-8 -*- """ PipelineRunner:基于 StepSpec 声明式调度 14 个 step。 设计要点: - StepSpec 声明 requires(ctx 字段名列表)+ produces(ctx 字段名列表) - 命名约定:ctx 字段名 == panel key 名 == step 形参名(全链路无翻译) - 步骤命名:step_id 格式为 stepN 或 stepN_suffix(无小数位),method_name 与 step_id 对齐 - 调度顺序:按 PIPELINE_STEPS 列表顺序,requires 缺则 skip - 软取消:在每个 step 前检查 ctx.is_cancelled() - 断点续跑:spec.output_file 已落盘则跳过执行 - 错误汇总:全流程结束后 error_summary 记录所有 step 的异常 - 预检:run() 入口硬校验 step1 img_path;其余依赖通过智能补全 + 软警告处理 - PipelineHalt:外层 run() 不 catch,触发循环 break,实现硬终止 - STEP_MAP:旧 step_id → 新 step_id 双向映射,供 GUI 配置兼容使用 - duck-typed pipeline:runner 只调 getattr(pipeline, method_name),不强依赖类层级 """ from __future__ import annotations import inspect import logging import os import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence from .context import PipelineContext, STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD, resolve_step_id logger = logging.getLogger(__name__) # ============================================================ # 终止异常(外层 run() 不 catch,触发循环 break) # ============================================================ class PipelineHalt(Exception): """不可恢复的错误,在 run() 循环中抛出后直接 break,不走 Exception 处理分支。 适用场景: - GUI 层通过 _notify 弹窗拦截后主动抛出的硬终止信号 """ pass # ============================================================ # StepSpec 声明式描述 # ============================================================ @dataclass class StepSpec: """单个 step 的元信息(声明式,避免硬编码)""" step_id: str method_name: str requires: List[str] # PipelineContext 字段名列表 produces: List[str] = field(default_factory=list) # 写入 ctx 的字段名列表 enabled: bool = True parameter_map: Dict[str, str] = field(default_factory=dict) # 当 requires 中任一字段为 None 时是否跳过;默认 True(缺输入就 skip) skip_when_missing: bool = True # 备注(仅用于文档生成 / 调试输出) description: str = "" # ★ 断点续跑:产物文件路径,支持 {work_dir} 占位符(运行时解析) output_file: Optional[str] = None # ★ 预检用:需要验证磁盘文件实际存在的 ctx key 列表 required_input_files: List[str] = field(default_factory=list) # 
# ============================================================
# Declaration table of the pipeline steps (list order == scheduling order).
# step_id / method_name contain no decimals, matching what the front end shows.
# output_file uses the {work_dir} placeholder, expanded by
# PipelineRunner._resolve_path; required_input_files lists ctx keys.
# NOTE(review): the original header speaks of 14 steps; 13 specs are declared.
# ============================================================
PIPELINE_STEPS: List[StepSpec] = [
    StepSpec(
        step_id="step1",
        method_name="step1_generate_water_mask",
        requires=["img_path"],
        produces=["water_mask_path"],
        required_input_files=["img_path"],
        output_file="{work_dir}/1_water_mask/water_mask.dat",
        description="水域掩膜生成(NDWI 或 SHP)",
    ),
    StepSpec(
        step_id="step2",
        method_name="step2_find_glint_area",
        requires=["img_path", "water_mask_path"],
        produces=["glint_mask_path"],
        required_input_files=["img_path", "water_mask_path"],
        output_file="{work_dir}/2_glint/glint_mask.dat",
        description="耀斑区域检测",
    ),
    StepSpec(
        step_id="step3",
        method_name="step3_remove_glint",
        requires=["img_path", "water_mask_path", "glint_mask_path"],
        produces=["deglint_img_path"],
        required_input_files=["img_path", "water_mask_path", "glint_mask_path"],
        output_file="{work_dir}/3_deglint/deglint.bsq",
        description="耀斑去除",
    ),
    StepSpec(
        step_id="step4",
        method_name="step4_process_csv",
        requires=["csv_path"],
        produces=["processed_csv_path"],
        required_input_files=["csv_path"],
        output_file="{work_dir}/4_processed_data/processed_data.csv",
        description="CSV 异常值清洗",
    ),
    StepSpec(
        step_id="step5",
        method_name="step5_extract_training_spectra",
        requires=["deglint_img_path", "processed_csv_path", "csv_path", "boundary_path", "glint_mask_path"],
        produces=["training_csv_path"],
        parameter_map={
            "processed_csv_path": "csv_path",
            "csv_path": "_raw_csv_ignored",
        },
        skip_when_missing=False,
        required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
        output_file="{work_dir}/5_training_spectra/training_spectra.csv",
        description="实测样本点光谱提取",
    ),
    StepSpec(
        step_id="step7",
        method_name="step7_water_quality_indices",
        requires=["training_csv_path"],
        produces=["indices_path", "trad_indices_dir"],
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
        description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
    ),
    StepSpec(
        step_id="step8",
        method_name="step8_ml_modeling",
        requires=["training_csv_path"],
        produces=["models_dir"],
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
        description="ML 建模(GridSearchCV / AutoML)",
    ),
    StepSpec(
        step_id="step8_non_empirical_modeling",
        method_name="step8_non_empirical_modeling",
        requires=["training_csv_path"],
        produces=["models_dir"],
        parameter_map={"training_csv_path": "csv_path"},
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
        description="非经验统计回归",
    ),
    StepSpec(
        step_id="step9",
        method_name="step9_watercolor_inversion",
        requires=["deglint_img_path", "water_mask_path"],
        produces=["watercolor_index_dir"],
        required_input_files=["deglint_img_path"],
        output_file="{work_dir}/9_WaterColor_Index_Images",
        description="水色指数反演(BSQ 影像直接处理)",
    ),
    StepSpec(
        step_id="step10",
        method_name="step10_sampling",
        requires=["deglint_img_path", "water_mask_path"],
        produces=["sampling_csv_path"],
        required_input_files=["deglint_img_path", "water_mask_path"],
        output_file="{work_dir}/4_sampling/sampling_spectra.csv",
        description="整景密集采样点生成 + 光谱提取",
    ),
    StepSpec(
        step_id="step11_ml",
        method_name="step11_ml_prediction",
        requires=["sampling_csv_path", "models_dir"],
        produces=["prediction_csv_path"],
        required_input_files=["sampling_csv_path", "models_dir"],
        output_file="{work_dir}/11_12_13_predictions/prediction_results.csv",
        description="ML 模型预测(采样点)",
    ),
    StepSpec(
        step_id="step11",
        method_name="step11_non_empirical_prediction",
        requires=["sampling_csv_path", "models_dir"],
        produces=["prediction_dir"],
        parameter_map={"models_dir": "non_empirical_models_dir"},
        required_input_files=["sampling_csv_path", "models_dir"],
        output_file="{work_dir}/11_12_13_predictions/non_empirical_predictions",
        description="非经验模型预测",
    ),
    StepSpec(
        step_id="step14",
        method_name="step14_distribution_map",
        requires=["prediction_csv_path", "boundary_shp_path"],
        produces=["distribution_map_path"],
        required_input_files=["prediction_csv_path", "boundary_shp_path"],
        output_file="{work_dir}/distribution_map.png",
        description="克里金插值成图",
    ),
]


# ============================================================
# PipelineRunner: the executor
# ============================================================
class PipelineRunner:
    """Schedule the pipeline's step methods according to their StepSpec.

    Supports soft cancellation, checkpoint resume and an error summary.

    Usage::

        ctx = PipelineContext(img_path=..., work_dir=..., user_config=config)
        runner = PipelineRunner(pipeline_instance)
        result_ctx = runner.run(ctx, config=config)  # runs once preflight passes
        print(result_ctx.error_summary)  # [(step_id, error_msg), ...]
    """

    def __init__(self, pipeline, steps: Optional[Sequence[StepSpec]] = None):
        """Bind the runner to a pipeline object and a list of step specs.

        Args:
            pipeline: Duck-typed object; step methods are looked up on it by
                ``StepSpec.method_name`` and an optional ``_notify`` callable
                is used for status notifications.
            steps: Specs to schedule, in order. When omitted or empty, a
                shallow copy of ``PIPELINE_STEPS`` is used.
        """
        self.pipeline = pipeline
        # NOTE: this is a shallow copy; the StepSpec objects themselves are
        # shared with the caller / PIPELINE_STEPS, and
        # _auto_fill_missing_steps mutates their `enabled` flag.
        self.steps: List[StepSpec] = list(steps) if steps else list(PIPELINE_STEPS)
        # Per-step configuration; replaced on every run() call.
        self.config: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def run(self, ctx: PipelineContext, config=None, skip_list: Optional[List[str]] = None) -> PipelineContext:
        """Run the scheduled steps against ``ctx`` and return ``ctx``.

        Args:
            ctx: Shared pipeline context; read and mutated in place.
            config: Mapping ``{step_id: {param: value, 'enabled': bool}}``.
                It is stored on ``self.config`` and may be mutated (missing
                step entries are created, disabled providers are re-enabled).
            skip_list: step_ids the user explicitly chose to skip; these are
                never executed and never auto-woken.

        Returns:
            The same ``ctx``, with ``pipeline_start_time`` /
            ``pipeline_end_time`` set and, unless the img_path preflight
            fails, ``error_summary`` attached.
        """
        self.config = config or {}
        skip_set = set(skip_list or ())

        logger.info("开始运行完整流程 (Runner 调度模式)...")
        ctx.pipeline_start_time = time.time()
        error_summary: List[tuple[str, str]] = []

        # Hard preflight: without img_path the whole pipeline cannot start.
        if not ctx.get("img_path"):
            msg = "【全流程预检失败】缺少参考影像路径 (img_path),流程无法启动。"
            ctx.append_log(f"[RUNNER] {msg}")
            self._notify_step("全流程", "error", msg)
            ctx.last_error = msg
            ctx.pipeline_end_time = time.time()
            return ctx

        # Smart auto-fill: back-fill ctx from artifacts already in work_dir.
        self._scan_workdir_outputs(ctx)
        # Re-enable disabled steps whose artifact already exists on disk.
        self._auto_fill_missing_steps(ctx)
        # Soft preflight warnings (logged only, never blocking).
        self._preflight_warnings(ctx)
        # Diagnostic log of artifacts already present in ctx.
        self._restore_outputs_from_ctx(ctx)

        # 1. Push every truthy GUI config value into ctx.
        self._inject_config_into_ctx(ctx)
        # 2./3. Re-enable disabled steps that provide a missing requirement.
        self._wake_up_providers(ctx, skip_set)

        # 4. Execute the pipeline.
        for step in self.steps:
            # Soft cancel.
            if ctx.is_cancelled():
                ctx.append_log(f"[RUNNER] 收到取消信号,提前终止 @ {step.step_id}")
                break

            if step.step_id in skip_set:
                ctx.status[step.step_id] = "user_skipped"
                ctx.append_log(
                    f"\n{'='*60}\n"
                    f" ⚠ 用户强制跳过: {step.step_id}({step.description})\n"
                    f" 原因:用户在预检弹窗中勾选「忽略」,已确认跳过\n"
                    f"{'='*60}\n"
                )
                self._notify_step(step.step_id, "skipped", "用户强制跳过(预检弹窗)")
                continue

            step_cfg = self.config.get(step.step_id, {})
            if not step_cfg.get('enabled', True):
                continue

            # 4.1 Checkpoint resume. The template must be expanded first:
            # testing the raw "{work_dir}/..." string never matches a real
            # file, and storing it in ctx would hand downstream steps an
            # unusable path.
            resolved_output = self._resolve_path(step.output_file, ctx)
            if resolved_output and os.path.exists(resolved_output):
                for prod in step.produces:
                    if not self._has_value(ctx, prod):
                        setattr(ctx, prod, resolved_output)
                ctx.status[step.step_id] = "skipped"
                ctx.append_log(f"[CACHE] 产物已存在,跳过运行并恢复上下文: {step.step_id}")
                self._notify_step(step.step_id, "skipped", "产物已存在(断点续跑)")
                continue

            # 4.2 Dependency check.
            # NOTE(review): StepSpec.skip_when_missing is not consulted here;
            # every step with a missing requirement is skipped.
            missing = [req for req in step.requires if not self._has_value(ctx, req)]
            if missing:
                ctx.status[step.step_id] = "skipped"
                reason = f"缺少必要的上下文参数,自动跳过: {missing}"
                ctx.append_log(f"[RUNNER] 跳过 {step.step_id},仍缺少必要参数: {missing}")
                self._notify_step(step.step_id, "skipped", reason)
                continue

            # 4.3 Actual execution.
            ctx.append_log(f"[START] 正在执行步骤: {step.step_id}")
            self._notify_step(step.step_id, "running", f"正在执行: {step.description}")
            try:
                method = getattr(self.pipeline, step.method_name)
                result = method(**self._build_call_kwargs(step, method, ctx))

                # Relay 1: a dict returned by the step is merged into ctx.
                if isinstance(result, dict):
                    for key, value in result.items():
                        setattr(ctx, key, value)

                # Relay 2: any produced field still empty is registered with
                # the step's expanded output_file path.
                actual_out_path = self._resolve_path(step.output_file, ctx)
                if actual_out_path:
                    for prod in step.produces:
                        if not self._has_value(ctx, prod):
                            setattr(ctx, prod, actual_out_path)
                            logger.info(f"[产物接力] 登记 {prod} = {actual_out_path}")
            except PipelineHalt:
                ctx.status[step.step_id] = "error"
                ctx.append_log(f"[RUNNER] PipelineHalt 硬终止 @ {step.step_id}")
                self._notify_step(step.step_id, "error", "预检失败,硬终止")
                break
            except Exception as e:
                # The loop stops at the first failing step.
                ctx.status[step.step_id] = "error"
                error_summary.append((step.step_id, str(e)))
                ctx.last_error = f"{step.step_id}: {e!r}"
                ctx.append_log(f"[ERROR] 步骤 {step.step_id} 执行崩溃: {str(e)}")
                self._notify_step(step.step_id, "error", str(e))
                break
            else:
                # Record success so the step does not stay in "running"
                # (mirrors what _invoke reports).
                ctx.status[step.step_id] = "completed"
                self._notify_step(
                    step.step_id,
                    "completed",
                    str(result)[:200] if result is not None else "",
                )

        ctx.pipeline_end_time = time.time()
        ctx.error_summary = error_summary
        return ctx

    # ------------------------------------------------------------------
    # run() helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _has_value(ctx: Any, name: str) -> bool:
        """Return True when ``ctx`` has attribute ``name`` with a truthy value."""
        return bool(getattr(ctx, name, None))

    def _inject_config_into_ctx(self, ctx: PipelineContext) -> None:
        """Copy every truthy, non-'enabled' value of each step config onto ctx."""
        for cfg in self.config.values():
            if not isinstance(cfg, dict):
                continue
            for key, value in cfg.items():
                if key != 'enabled' and value:
                    setattr(ctx, key, value)

    def _wake_up_providers(self, ctx: PipelineContext, skip_set: set) -> List[str]:
        """Re-enable disabled steps that produce a requirement missing from ctx.

        Iterates to a fixed point so that a freshly woken provider can in
        turn wake its own providers. Steps in ``skip_set`` are never woken
        and never trigger a wake-up.

        Returns:
            The step_ids that were re-enabled, in wake-up order.
        """
        # When several steps produce the same key, the last declared one wins.
        provider_map = {prod: step for step in self.steps for prod in step.produces}

        woke_up_steps: List[str] = []
        changed = True
        while changed:
            changed = False
            for step in self.steps:
                if step.step_id in skip_set:
                    continue  # explicitly skipped by the user: never wake
                step_cfg = self.config.setdefault(step.step_id, {})
                if not step_cfg.get('enabled', True):
                    continue
                for req in step.requires:
                    if self._has_value(ctx, req):
                        continue
                    provider = provider_map.get(req)
                    if not provider or provider.step_id in skip_set:
                        continue
                    prov_cfg = self.config.setdefault(provider.step_id, {})
                    if not prov_cfg.get('enabled', True):
                        prov_cfg['enabled'] = True
                        changed = True
                        woke_up_steps.append(provider.step_id)
                        logger.info(f"[*] 自动唤醒: {provider.step_id} (为下游提供 {req})")

        if woke_up_steps:
            logger.info(f"★ 依赖唤醒完成,共唤醒 {len(woke_up_steps)} 个次/步骤")
        return woke_up_steps

    def _build_call_kwargs(self, step: StepSpec, method: Any, ctx: PipelineContext) -> Dict[str, Any]:
        """Build the keyword arguments for ``method`` from config, spec and ctx.

        For each parameter in the method's signature, in priority order:
        1. the value from this step's own config entry;
        2. for a parameter named ``output_file``, the step's expanded
           output_file template (keeps it isolated from same-named values
           that other steps left in ctx);
        3. the ctx attribute named by ``parameter_map`` (ctx key -> parameter
           name), falling back to the attribute with the parameter's own name.
        Parameters that match none of these are left out.
        """
        step_cfg = self.config.get(step.step_id, {})
        parameter_map = getattr(step, 'parameter_map', None) or {}
        resolved_output = self._resolve_path(getattr(step, 'output_file', None), ctx)

        kwargs: Dict[str, Any] = {}
        for param_name in inspect.signature(method).parameters:
            if param_name in step_cfg:
                kwargs[param_name] = step_cfg[param_name]
                continue
            if param_name == 'output_file' and resolved_output:
                kwargs[param_name] = resolved_output
                continue
            # First ctx key mapped onto this parameter, else the name itself.
            ctx_key = next(
                (key for key, mapped in parameter_map.items() if mapped == param_name),
                param_name,
            )
            if hasattr(ctx, ctx_key):
                kwargs[param_name] = getattr(ctx, ctx_key)
        return kwargs

    # ------------------------------------------------------------------
    # Smart auto-fill: scan work_dir for existing artifacts
    # ------------------------------------------------------------------
    def _scan_workdir_outputs(self, ctx: PipelineContext) -> None:
        """Back-fill ctx with the default artifacts already present in work_dir.

        For every step, its ``output_file`` template is expanded; if the
        resulting path exists, it is written to each of the step's
        ``produces`` fields that is still empty. Fields that already have a
        value are never overwritten. Does nothing when ctx has no work_dir.
        """
        if not (ctx.get("work_dir") or ""):
            return
        for spec in self.steps:
            if not spec.produces:
                continue
            resolved = self._resolve_path(spec.output_file, ctx)
            if not resolved:
                continue
            for produce_key in spec.produces:
                if ctx.get(produce_key):
                    continue  # a value was already supplied; keep it
                if os.path.exists(resolved):
                    ctx.set(produce_key, resolved)
                    ctx.append_log(
                        f"[AUTO_FILL] 检测到已有产物,回填 {produce_key} = {resolved}"
                    )

    # ------------------------------------------------------------------
    # Smart auto-fill: force-enable steps that were silently disabled
    # ------------------------------------------------------------------
    def _auto_fill_missing_steps(self, ctx: PipelineContext) -> None:
        """Re-enable disabled specs whose artifact already exists on disk.

        For each spec with ``enabled == False`` whose expanded output_file
        exists, the spec is switched back on, its step_id is appended to
        ``ctx.locked_steps``, and the artifact path is written to every
        still-empty ``produces`` field so downstream steps get their input.
        Steps listed in ``ctx._skip_set`` (when that attribute exists) are
        left alone.

        The blocking img_path check is done in run(), not here.

        NOTE(review): run() in this file never sets ``ctx._skip_set``, and
        ``spec.enabled = True`` mutates the spec object shared with
        PIPELINE_STEPS — confirm both are intended.
        """
        skip_set = getattr(ctx, '_skip_set', set())
        newly_locked: List[str] = []
        for spec in self.steps:
            if spec.enabled:
                continue  # already-enabled steps are not affected
            if spec.step_id in skip_set:
                continue  # manually ignored by the user: do not auto-fill
            resolved = self._resolve_path(spec.output_file, ctx)
            if not (resolved and os.path.exists(resolved)):
                continue

            # The step has an artifact but is disabled -> switch it back on.
            spec.enabled = True
            ctx.locked_steps.append(spec.step_id)
            newly_locked.append(spec.step_id)
            # Back-fill every produced field that is still empty.
            for produce_key in spec.produces:
                if not ctx.get(produce_key):
                    ctx.set(produce_key, resolved)
                    ctx.append_log(
                        f"[AUTO_FILL] 强制开启并回填 {spec.step_id} 产物 {produce_key} = {resolved}"
                    )
            ctx.append_log(
                f"\n{'='*60}\n"
                f" ⚡ 智能补全:步骤 {spec.step_id}({spec.description})\n"
                f" 原因:该步骤在 work_dir 中已有产物但被您在 GUI 中禁用了。\n"
                f" 操作:系统已自动开启该步骤,产物路径已回填。\n"
                f" 注意:运行期间该步骤已被锁定,您无法临时关闭。\n"
                f"{'='*60}\n"
            )

        if newly_locked:
            self._notify_step(
                "全流程", "info",
                f"智能补全已自动开启 {len(newly_locked)} 个步骤:{newly_locked}"
            )

    def _resolve_output_for_key(
        self, produce_key: str, ctx: PipelineContext
    ) -> Optional[str]:
        """Return the expanded output_file of the first step producing ``produce_key``.

        Returns None when no step produces the key or that step has no
        output_file.
        """
        for spec in self.steps:
            if produce_key in spec.produces:
                return self._resolve_path(spec.output_file, ctx)
        return None

    def _scan_single_step_outputs(
        self, spec: StepSpec, ctx: PipelineContext
    ) -> None:
        """Back-fill ctx from one step's artifact (existing values are kept)."""
        if not spec.produces:
            return
        resolved = self._resolve_path(spec.output_file, ctx)
        if not resolved:
            return
        for produce_key in spec.produces:
            if ctx.get(produce_key):
                continue
            if os.path.exists(resolved):
                ctx.set(produce_key, resolved)
                ctx.append_log(
                    f"[AUTO_FILL] 依赖唤醒后检测到产物,回填 {produce_key} = {resolved}"
                )

    # ------------------------------------------------------------------
    # Soft preflight warnings (non-blocking)
    # ------------------------------------------------------------------
    def _preflight_warnings(self, ctx: PipelineContext) -> None:
        """Log foreseeable run-time skips without blocking execution.

        For every enabled spec: warns when step4 has no csv_path, and when a
        ``required_input_files`` key has a value in ctx that does not exist
        on disk. Warnings are appended to the ctx log and sent once through
        ``_notify_step``; nothing is raised.
        """
        issues: List[str] = []
        for spec in self.steps:
            if not spec.enabled:
                continue

            # step4: missing csv_path.
            if spec.step_id == "step4" and not ctx.get("csv_path"):
                issues.append(
                    f"[{spec.step_id}] 缺少实测水质数据 (csv_path),"
                    "步骤 5-9 将被自动跳过"
                )

            # ctx holds a path, but the file is not actually on disk.
            for ctx_key in spec.required_input_files:
                value = ctx.get(ctx_key)
                if value and not os.path.exists(value):
                    issues.append(
                        f"[{spec.step_id}] 磁盘文件缺失(但 ctx 已回填): {ctx_key} = {value}"
                    )

        if issues:
            detail = "\n".join(f" - {w}" for w in issues)
            ctx.append_log(
                f"[RUNNER] 【软预检警告】(流程将继续执行,缺失项将被自动跳过)\n{detail}"
            )
            self._notify_step("全流程", "warning", f"预检警告:{len(issues)} 项\n{detail}")

    # ------------------------------------------------------------------
    # Single-step invocation
    # ------------------------------------------------------------------
    def _invoke(self, spec: StepSpec, ctx: PipelineContext) -> None:
        """Call one step method: ctx values -> parameters, result -> ctx fields.

        Exceptions raised by the step method propagate to the caller.
        NOTE(review): run() in this file does not call this method.
        """
        ctx.append_log(
            f"[DEBUG] Step {spec.step_id} requires: {spec.requires}, "
            f"actual ctx data: {[ctx.get(k) for k in spec.requires]}"
        )
        method = getattr(self.pipeline, spec.method_name, None)
        if method is None:
            ctx.append_log(f"[RUNNER] 步骤方法缺失: {spec.method_name}(跳过)")
            ctx.status[spec.step_id] = "skipped"
            return

        # 1) Inject the required ctx values as keyword arguments.
        kwargs: Dict[str, Any] = {}
        for ctx_key in spec.requires:
            param_name = spec.parameter_map.get(ctx_key, self._default_param_name(ctx_key))
            kwargs[param_name] = ctx.get(ctx_key)

        # 2) ctx.user_config[step_id] may override / extend them
        #    (only values that are neither None nor "").
        user_overrides = ctx.user_config.get(spec.step_id) or {}
        if isinstance(user_overrides, dict):
            for key, value in user_overrides.items():
                if value is not None and value != "":
                    kwargs[key] = value

        # 3) Mark the step as started.
        ctx.append_log(
            f"[RUNNER] -> {spec.method_name}({list(kwargs.keys())})"
        )
        ctx.status[spec.step_id] = "start"
        self._notify_step(spec.step_id, "start", spec.method_name)

        # 4) Execute and time the call.
        t0 = time.time()
        result = method(**kwargs)
        ctx.status[spec.step_id] = "completed"
        ctx.step_timings[spec.step_id] = time.time() - t0

        # 5) Harvest the result into ctx.
        self._harvest(spec, result, ctx)
        self._notify_step(
            spec.step_id,
            "completed",
            str(result)[:200] if result is not None else "",
        )

    # ------------------------------------------------------------------
    # Result harvesting
    # ------------------------------------------------------------------
    def _harvest(self, spec: StepSpec, result: Any, ctx: PipelineContext) -> None:
        """Store a step's return value into the ctx fields it produces.

        A dict result fills each ``produces`` key it contains; any other
        non-None result is stored under the first ``produces`` key. Nothing
        happens when the spec produces nothing or the result is None.
        """
        if not spec.produces:
            return
        if isinstance(result, dict):
            for produce_key in spec.produces:
                if produce_key in result:
                    ctx.set(produce_key, result[produce_key])
        elif result is not None:
            ctx.set(spec.produces[0], result)

    # ------------------------------------------------------------------
    # Checkpoint-resume helpers
    # ------------------------------------------------------------------
    def _resolve_path(
        self, template: Optional[str], ctx: PipelineContext
    ) -> Optional[str]:
        """Expand the ``{work_dir}`` placeholder of ``template``.

        Args:
            template: Path template, or None/empty.
            ctx: Object exposing ``get("work_dir")``.

        Returns:
            None for an empty template; otherwise the template with
            ``{work_dir}`` replaced by ctx's work_dir ("" when unset). The
            path is not normalised or made absolute. A template that cannot
            be formatted (unknown or positional placeholder, malformed
            braces) is returned unchanged.
        """
        if not template:
            return None
        work_dir = ctx.get("work_dir") or ""
        try:
            return template.format(work_dir=work_dir)
        except (KeyError, IndexError, ValueError):
            # IndexError covers positional placeholders such as "{}" / "{0}".
            return template

    def _restore_outputs_from_ctx(self, ctx: PipelineContext) -> None:
        """Diagnostic log: record every truthy produced value already in ctx."""
        for spec in self.steps:
            if not (spec.enabled and spec.produces):
                continue
            for key in spec.produces:
                val = ctx.get(key)
                if val:
                    ctx.append_log(
                        f"[RUNNER] 断点续跑检测: {spec.step_id} 已有 {key} = {val}"
                    )

    def _restore_ctx_from_output(
        self, spec: StepSpec, resolved_path: str, ctx: PipelineContext
    ) -> None:
        """Write ``resolved_path`` into every ``produces`` field of ``spec``.

        Used when a step is skipped because its artifact exists: every
        produced key is registered (existing values are overwritten) so no
        downstream dependency is left without a value.
        """
        if not spec.produces:
            return
        for produce_key in spec.produces:
            ctx.set(produce_key, resolved_path)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _default_param_name(ctx_key: str) -> str:
        """Return ``ctx_key`` unchanged as the parameter name.

        Special abbreviations are handled explicitly through parameter_map.
        """
        return ctx_key

    def _notify_step(self, step_id: str, status: str, message: str) -> None:
        """Report a step status through ``pipeline._notify``, if it is callable.

        Notification is best-effort: a failing callback never interrupts the
        pipeline; the failure is logged at debug level.

        NOTE(review): this also swallows a PipelineHalt raised by the
        callback, so a hard stop only takes effect when raised by the step
        method itself — confirm this is intended.
        """
        notify = getattr(self.pipeline, "_notify", None)
        if not callable(notify):
            return
        try:
            notify(step_id, status, message)
        except Exception:
            logger.debug(
                "notify callback failed for %s (%s)", step_id, status, exc_info=True
            )