Files
WQ_GUI/src/core/pipeline/runner.py

650 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
PipelineRunner基于 StepSpec 声明式调度 14 个 step。
设计要点:
- StepSpec 声明 requiresctx 字段名列表)+ producesctx 字段名列表)
- 命名约定ctx 字段名 == panel key 名 == step 形参名(全链路无翻译)
- 步骤命名step_id 格式为 stepN 或 stepN_suffix无小数位method_name 与 step_id 对齐
- 调度顺序:按 PIPELINE_STEPS 列表顺序requires 缺则 skip
- 软取消:在每个 step 前检查 ctx.is_cancelled()
- 断点续跑spec.output_file 已落盘则跳过执行
- 错误汇总:全流程结束后 error_summary 记录所有 step 的异常
- 预检run() 入口硬校验 step1 img_path其余依赖通过智能补全 + 软警告处理
- PipelineHalt外层 run() 不 catch触发循环 break实现硬终止
- STEP_MAP旧 step_id → 新 step_id 双向映射,供 GUI 配置兼容使用
- duck-typed pipelinerunner 只调 getattr(pipeline, method_name),不强依赖类层级
"""
from __future__ import annotations
import inspect
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence
from .context import PipelineContext, STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD, resolve_step_id
logger = logging.getLogger(__name__)
# ============================================================
# 终止异常(run() 循环中由专门的 except 分支捕获,触发 break,实现硬终止)
# ============================================================
class PipelineHalt(Exception):
    """Unrecoverable hard-stop signal for the pipeline.

    ``PipelineRunner.run()`` handles it in a dedicated ``except`` branch placed
    ahead of the generic ``Exception`` handler, and stops the step loop instead
    of treating it as an ordinary step failure.

    Intended use (per the original design): the GUI layer raises it, e.g. after
    intercepting a ``_notify`` popup, to terminate the run.
    """
# ============================================================
# StepSpec 声明式描述
# ============================================================
@dataclass
class StepSpec:
    """Declarative metadata describing one pipeline step (avoids hard-coding).

    Field order is part of the positional ``__init__`` signature, so it must
    stay stable.
    """

    # Identifier used for ctx.status, GUI config keys and log lines (e.g. "step1").
    step_id: str
    # Name of the pipeline method to call; resolved with getattr() at run time.
    method_name: str
    # PipelineContext field names this step reads as inputs.
    requires: List[str]
    # PipelineContext field names this step writes.
    produces: List[str] = field(default_factory=list)
    # Step enabled flag.
    enabled: bool = True
    # Maps a ctx field name -> the method parameter name that receives it,
    # for inputs whose names differ between ctx and the method signature.
    parameter_map: Dict[str, str] = field(default_factory=dict)
    # Whether to skip the step when any ``requires`` field is None (default: skip).
    skip_when_missing: bool = True
    # Free-form note, used only for documentation / debug output.
    description: str = ""
    # Product path used for checkpoint resume; may contain a ``{work_dir}``
    # placeholder that is expanded at run time.
    output_file: Optional[str] = None
    # ctx keys whose values must exist as files on disk (checked by preflight).
    required_input_files: List[str] = field(default_factory=list)
# ============================================================
# step 声明表(编号 step1–step14,部分编号空缺;列表顺序即调度顺序)
# step_id / method_name 均不含小数位,与前端显示对齐
# output_file / required_input_files 使用 {work_dir} 占位符,由 _resolve_path 展开
# ============================================================
PIPELINE_STEPS: List[StepSpec] = [
    # --- Image preprocessing chain: water mask -> glint mask -> deglinted image ---
    StepSpec(
        step_id="step1",
        method_name="step1_generate_water_mask",
        description="水域掩膜生成NDWI 或 SHP",
        requires=["img_path"],
        produces=["water_mask_path"],
        required_input_files=["img_path"],
        output_file="{work_dir}/1_water_mask/water_mask.dat",
    ),
    StepSpec(
        step_id="step2",
        method_name="step2_find_glint_area",
        description="耀斑区域检测",
        requires=["img_path", "water_mask_path"],
        produces=["glint_mask_path"],
        required_input_files=["img_path", "water_mask_path"],
        output_file="{work_dir}/2_glint/glint_mask.dat",
    ),
    StepSpec(
        step_id="step3",
        method_name="step3_remove_glint",
        description="耀斑去除",
        requires=["img_path", "water_mask_path", "glint_mask_path"],
        produces=["deglint_img_path"],
        required_input_files=["img_path", "water_mask_path", "glint_mask_path"],
        output_file="{work_dir}/3_deglint/deglint.bsq",
    ),
    # --- Field-measurement chain: cleaned CSV -> training spectra ---
    StepSpec(
        step_id="step4",
        method_name="step4_process_csv",
        description="CSV 异常值清洗",
        requires=["csv_path"],
        produces=["processed_csv_path"],
        required_input_files=["csv_path"],
        output_file="{work_dir}/4_processed_data/processed_data.csv",
    ),
    StepSpec(
        step_id="step5",
        method_name="step5_extract_training_spectra",
        description="实测样本点光谱提取",
        requires=["deglint_img_path", "processed_csv_path", "csv_path", "boundary_path", "glint_mask_path"],
        produces=["training_csv_path"],
        # The cleaned CSV is fed to the method's `csv_path` parameter; the raw
        # ctx `csv_path` is routed to a throw-away parameter name.
        parameter_map={
            "processed_csv_path": "csv_path",
            "csv_path": "_raw_csv_ignored",
        },
        skip_when_missing=False,
        required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
        output_file="{work_dir}/5_training_spectra/training_spectra.csv",
    ),
    # --- Modelling on the training spectra ---
    StepSpec(
        step_id="step7",
        method_name="step7_water_quality_indices",
        description="水质参数指数计算双轨输出A轨宽表 + B轨单文件",
        requires=["training_csv_path"],
        produces=["indices_path", "trad_indices_dir"],
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/6_water_quality_indices/training_spectra_indices.csv",
    ),
    # NOTE(review): step8 and step8_non_empirical_modeling both produce
    # `models_dir`, and both step11 variants consume it; confirm the prediction
    # steps really expect the same artefact.
    StepSpec(
        step_id="step8",
        method_name="step8_ml_modeling",
        description="ML 建模GridSearchCV / AutoML",
        requires=["training_csv_path"],
        produces=["models_dir"],
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/7_Supervised_Model_Training/best_models.pkl",
    ),
    StepSpec(
        step_id="step8_non_empirical_modeling",
        method_name="step8_non_empirical_modeling",
        description="非经验统计回归",
        requires=["training_csv_path"],
        produces=["models_dir"],
        parameter_map={"training_csv_path": "csv_path"},
        required_input_files=["training_csv_path"],
        output_file="{work_dir}/8_Regression_Modeling/non_empirical_models.pkl",
    ),
    # --- Whole-scene products from the deglinted image ---
    StepSpec(
        step_id="step9",
        method_name="step9_watercolor_inversion",
        description="水色指数反演BSQ 影像直接处理)",
        requires=["deglint_img_path", "water_mask_path"],
        produces=["watercolor_index_dir"],
        required_input_files=["deglint_img_path"],
        output_file="{work_dir}/9_WaterColor_Index_Images",
    ),
    StepSpec(
        step_id="step10",
        method_name="step10_sampling",
        description="整景密集采样点生成 + 光谱提取",
        requires=["deglint_img_path", "water_mask_path"],
        produces=["sampling_csv_path"],
        required_input_files=["deglint_img_path", "water_mask_path"],
        output_file="{work_dir}/4_sampling/sampling_spectra.csv",
    ),
    # --- Prediction on the sampling points ---
    StepSpec(
        step_id="step11_ml",
        method_name="step11_ml_prediction",
        description="ML 模型预测(采样点)",
        requires=["sampling_csv_path", "models_dir"],
        produces=["prediction_csv_path"],
        required_input_files=["sampling_csv_path", "models_dir"],
        output_file="{work_dir}/11_12_13_predictions/prediction_results.csv",
    ),
    StepSpec(
        step_id="step11",
        method_name="step11_non_empirical_prediction",
        description="非经验模型预测",
        requires=["sampling_csv_path", "models_dir"],
        produces=["prediction_dir"],
        parameter_map={"models_dir": "non_empirical_models_dir"},
        required_input_files=["sampling_csv_path", "models_dir"],
        output_file="{work_dir}/11_12_13_predictions/non_empirical_predictions",
    ),
    # --- Final map ---
    StepSpec(
        step_id="step14",
        method_name="step14_distribution_map",
        description="克里金插值成图",
        requires=["prediction_csv_path", "boundary_shp_path"],
        produces=["distribution_map_path"],
        required_input_files=["prediction_csv_path", "boundary_shp_path"],
        output_file="{work_dir}/distribution_map.png",
    ),
]
# ============================================================
# PipelineRunner:执行者
# ============================================================
class PipelineRunner:
    """Schedule the pipeline's step methods according to ``StepSpec`` declarations.

    Supports soft cancellation, checkpoint resume, and error collection:

    - Soft cancellation: ``ctx.is_cancelled()`` is checked before every step.
    - Checkpoint resume: a step whose resolved ``output_file`` already exists is
      skipped, and its products are restored into ctx.
    - Error collection: failures are recorded in ``ctx.error_summary``.

    Usage::

        ctx = PipelineContext(img_path=..., work_dir=..., user_config=config)
        runner = PipelineRunner(pipeline_instance)
        result_ctx = runner.run(ctx, config=config)  # runs after the preflight check
        print(result_ctx.error_summary)  # [(step_id, error_msg), ...]
    """

    def __init__(self, pipeline, steps: Optional[Sequence[StepSpec]] = None):
        """
        Args:
            pipeline: Object exposing the step methods named by ``method_name``.
                Methods are looked up with ``getattr``, so no base class is
                required. It may optionally expose a
                ``_notify(step_id, status, message)`` GUI callback.
            steps: Step declarations in scheduling order. Falls back to
                ``PIPELINE_STEPS`` when omitted or empty. The list is copied, but
                the ``StepSpec`` objects are shared with the caller / module.
        """
        self.pipeline = pipeline
        self.steps: List[StepSpec] = list(steps) if steps else list(PIPELINE_STEPS)
        # Per-step GUI config ({step_id: {param: value, 'enabled': bool}}),
        # replaced on every run().
        self.config: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def run(
        self,
        ctx: PipelineContext,
        config=None,
        skip_list: Optional[List[str]] = None,
    ) -> PipelineContext:
        """Execute the pipeline on ``ctx`` and return the same ``ctx``.

        Args:
            ctx: Shared pipeline context. Inputs are read from it; products,
                ``status``, logs, timings and ``error_summary`` are written to it.
            config: GUI config mapping ``step_id`` -> dict of parameters plus an
                optional ``'enabled'`` flag. The dict is mutated: missing step
                entries are created, and providers woken up by the dependency
                engine get ``enabled = True``.
            skip_list: step_ids the user chose to ignore in the preflight dialog.
                These steps are never executed or auto-woken.

        Returns:
            ``ctx`` with ``pipeline_end_time`` and ``error_summary`` set.

        Stopping behavior:
            - The first step that raises stops the run. Its error is recorded in
              ``error_summary`` and ``last_error``.
            - A ``PipelineHalt`` also stops the run, whether a step raises it or
              the GUI ``_notify`` callback raises it.
        """
        self.config = config or {}
        skip_set = set(skip_list or [])
        logger.info("开始运行完整流程 (Runner 调度模式)...")
        ctx.pipeline_start_time = time.time()
        error_summary: List[tuple[str, str]] = []

        try:
            if not ctx.get("img_path"):
                # Hard preflight check: without the reference image nothing can run.
                msg = "【全流程预检失败】缺少参考影像路径 (img_path),流程无法启动。"
                ctx.append_log(f"[RUNNER] {msg}")
                ctx.last_error = msg
                self._notify_step("全流程", "error", msg)
            else:
                self._prepare(ctx, skip_set)
                self._run_steps(ctx, skip_set, error_summary)
        except PipelineHalt as halt:
            # Raised from a runner-level notification (preflight / skip notices).
            # Halts raised while a step executes are handled in _execute_step.
            error_summary.append(("全流程", str(halt) or "PipelineHalt"))
            ctx.last_error = f"全流程: {halt!r}"
            ctx.append_log("[RUNNER] PipelineHalt 硬终止 @ 全流程")

        ctx.pipeline_end_time = time.time()
        ctx.error_summary = error_summary
        return ctx

    def _prepare(self, ctx: PipelineContext, skip_set: set) -> None:
        """Pre-execution phase: backfill ctx, warn, inject config, wake providers."""
        # Backfill ctx from default product paths already present in work_dir.
        self._scan_workdir_outputs(ctx)
        # Re-enable disabled steps whose products already exist on disk.
        self._auto_fill_missing_steps(ctx, skip_set=skip_set)
        # Soft preflight warnings: logged only, never blocking.
        self._preflight_warnings(ctx)
        # Diagnostic log of products already present in ctx.
        self._restore_outputs_from_ctx(ctx)
        # Copy every non-empty GUI parameter onto ctx so nothing gets lost.
        self._inject_config_into_ctx(ctx)
        # Enable upstream providers of inputs that downstream steps still lack.
        self._wake_up_providers(ctx, skip_set)

    def _run_steps(
        self, ctx: PipelineContext, skip_set: set, error_summary: List[tuple]
    ) -> None:
        """Walk ``self.steps`` in order, skipping, restoring or executing each one."""
        for step in self.steps:
            if ctx.is_cancelled():
                ctx.append_log(f"[RUNNER] 收到取消信号,提前终止 @ {step.step_id}")
                break

            if step.step_id in skip_set:
                ctx.status[step.step_id] = "user_skipped"
                ctx.append_log(
                    f"\n{'='*60}\n"
                    f" ⚠ 用户强制跳过: {step.step_id}{step.description}\n"
                    f" 原因:用户在预检弹窗中勾选「忽略」,已确认跳过\n"
                    f"{'='*60}\n"
                )
                self._notify_step(step.step_id, "skipped", "用户强制跳过(预检弹窗)")
                continue

            if not self._step_cfg(step.step_id).get("enabled", True):
                continue

            # Checkpoint resume: the product already exists on disk, so restore
            # ctx and skip. This is logged, never silent. The template must be
            # resolved first, because a raw "{work_dir}/..." path never exists.
            out_path = self._resolve_path(step.output_file, ctx)
            if out_path and os.path.exists(out_path):
                for prod in step.produces:
                    if not self._ctx_has(ctx, prod):
                        setattr(ctx, prod, out_path)
                ctx.status[step.step_id] = "skipped"
                ctx.append_log(f"[CACHE] 产物已存在,跳过运行并恢复上下文: {step.step_id}")
                self._notify_step(step.step_id, "skipped", "产物已存在(断点续跑)")
                continue

            # Dependency gate. NOTE: every entry of ``requires`` is treated as
            # mandatory here; ``StepSpec.skip_when_missing`` is not consulted.
            missing = [req for req in step.requires if not self._ctx_has(ctx, req)]
            if missing:
                ctx.status[step.step_id] = "skipped"
                reason = f"缺少必要的上下文参数,自动跳过: {missing}"
                ctx.append_log(f"[RUNNER] 跳过 {step.step_id},仍缺少必要参数: {missing}")
                self._notify_step(step.step_id, "skipped", reason)
                continue

            if not self._execute_step(step, ctx, error_summary):
                break

    def _execute_step(
        self, step: StepSpec, ctx: PipelineContext, error_summary: List[tuple]
    ) -> bool:
        """Call one step method and relay its products into ctx.

        Returns:
            ``True`` if the step succeeded. ``False`` if the run must stop
            because the step raised (including ``PipelineHalt``).
        """
        ctx.append_log(f"[START] 正在执行步骤: {step.step_id}")
        started = time.time()
        try:
            # This call is inside the try so that a PipelineHalt raised by the
            # GUI callback is attributed to this step.
            self._notify_step(step.step_id, "running", f"正在执行: {step.description}")
            method = getattr(self.pipeline, step.method_name)
            result = method(**self._build_kwargs(step, method, ctx))

            # Relay 1: a dict return value is merged into ctx as-is.
            if isinstance(result, dict):
                for key, value in result.items():
                    setattr(ctx, key, value)

            # Relay 2: products that are still unset point at the declared output_file.
            out_path = self._resolve_path(step.output_file, ctx)
            if out_path:
                for prod in step.produces:
                    if not self._ctx_has(ctx, prod):
                        setattr(ctx, prod, out_path)
                        logger.info("[产物接力] 登记 %s = %s", prod, out_path)
        except PipelineHalt as halt:
            ctx.status[step.step_id] = "error"
            error_summary.append((step.step_id, str(halt) or "PipelineHalt"))
            ctx.last_error = f"{step.step_id}: {halt!r}"
            ctx.append_log(f"[RUNNER] PipelineHalt 硬终止 @ {step.step_id}")
            self._notify_step(step.step_id, "error", "预检失败,硬终止", propagate_halt=False)
            return False
        except Exception as exc:
            logger.exception("步骤 %s 执行崩溃", step.step_id)
            ctx.status[step.step_id] = "error"
            error_summary.append((step.step_id, str(exc)))
            ctx.last_error = f"{step.step_id}: {exc!r}"
            ctx.append_log(f"[ERROR] 步骤 {step.step_id} 执行崩溃: {str(exc)}")
            self._notify_step(step.step_id, "error", str(exc), propagate_halt=False)
            return False

        ctx.status[step.step_id] = "completed"
        timings = getattr(ctx, "step_timings", None)
        if isinstance(timings, dict):
            timings[step.step_id] = time.time() - started
        ctx.append_log(f"[DONE] 步骤完成: {step.step_id}")
        self._notify_step(step.step_id, "completed", step.description)
        return True

    def _build_kwargs(self, step: StepSpec, method, ctx: PipelineContext) -> Dict[str, Any]:
        """Resolve keyword arguments for ``method`` by parameter name.

        Priority for each parameter:
          1. A value in this step's own config, unless it is ``None`` or ``""``.
          2. For ``output_file``: this step's resolved ``output_file`` template,
             so a same-named value from another step cannot leak in.
          3. The ctx field mapped by ``parameter_map`` (ctx_key -> param_name;
             the first match wins). Otherwise, the ctx field with the
             parameter's own name.

        Parameters with no source are omitted, so the method's default applies.
        ``*args`` / ``**kwargs`` parameters are never filled.
        """
        step_cfg = self._step_cfg(step.step_id)
        param_to_ctx_key: Dict[str, str] = {}
        for ctx_key, param_name in (getattr(step, "parameter_map", None) or {}).items():
            param_to_ctx_key.setdefault(param_name, ctx_key)
        template = getattr(step, "output_file", None)

        kwargs: Dict[str, Any] = {}
        for name, param in inspect.signature(method).parameters.items():
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            if name in step_cfg and self._has_value(step_cfg[name]):
                kwargs[name] = step_cfg[name]
            elif name == "output_file" and template:
                kwargs[name] = self._resolve_path(template, ctx)
            else:
                ctx_key = param_to_ctx_key.get(name, name)
                if hasattr(ctx, ctx_key):
                    kwargs[name] = getattr(ctx, ctx_key)
        return kwargs

    # ------------------------------------------------------------------
    # Config injection and dependency wake-up
    # ------------------------------------------------------------------
    def _step_cfg(self, step_id: str, create: bool = False) -> Dict[str, Any]:
        """Return the GUI config dict for ``step_id``.

        Returns ``{}`` when the entry is absent or is not a dict. With
        ``create=True``, a missing entry is inserted into ``self.config`` so
        that later writes (e.g. ``enabled = True``) persist.
        """
        if create:
            cfg = self.config.setdefault(step_id, {})
        else:
            cfg = self.config.get(step_id, {})
        return cfg if isinstance(cfg, dict) else {}

    def _inject_config_into_ctx(self, ctx: PipelineContext) -> None:
        """Set every truthy, non-``enabled`` parameter of every step config onto ctx."""
        for cfg in self.config.values():
            if not isinstance(cfg, dict):
                continue
            for key, value in cfg.items():
                if key != "enabled" and value:
                    setattr(ctx, key, value)

    def _wake_up_providers(self, ctx: PipelineContext, skip_set: set) -> List[str]:
        """Re-enable disabled producers of inputs that enabled steps still lack.

        Repeats until no further step is woken, so transitive providers are
        enabled too. When several steps produce the same key, the last one
        declared is used. Steps in ``skip_set`` are never woken. Returns the
        woken step_ids.
        """
        provider_map = {prod: step for step in self.steps for prod in step.produces}
        woken: List[str] = []
        changed = True
        while changed:
            changed = False
            for step in self.steps:
                if step.step_id in skip_set:
                    continue  # explicitly skipped by the user: never woken
                if not self._step_cfg(step.step_id, create=True).get("enabled", True):
                    continue
                for req in step.requires:
                    if self._ctx_has(ctx, req):
                        continue
                    provider = provider_map.get(req)
                    if provider is None or provider.step_id in skip_set:
                        continue
                    prov_cfg = self._step_cfg(provider.step_id, create=True)
                    if not prov_cfg.get("enabled", True):
                        prov_cfg["enabled"] = True
                        changed = True
                        woken.append(provider.step_id)
                        logger.info("[*] 自动唤醒: %s (为下游提供 %s)", provider.step_id, req)
        if woken:
            logger.info("★ 依赖唤醒完成,共唤醒 %d 个次/步骤", len(woken))
        return woken

    # ------------------------------------------------------------------
    # Smart completion: scan work_dir products
    # ------------------------------------------------------------------
    def _backfill_from_disk(self, spec: StepSpec, ctx: PipelineContext, label: str) -> None:
        """Point every still-empty product of ``spec`` at its existing output_file."""
        missing_keys = [key for key in spec.produces if not ctx.get(key)]
        if not missing_keys:
            return
        resolved = self._resolve_path(spec.output_file, ctx)
        if not (resolved and os.path.exists(resolved)):
            return
        for key in missing_keys:
            ctx.set(key, resolved)
            ctx.append_log(f"[AUTO_FILL] {label},回填 {key} = {resolved}")

    def _scan_workdir_outputs(self, ctx: PipelineContext) -> None:
        """Backfill ctx from each step's default product path under work_dir.

        ``output_file`` templates are expanded with ``{work_dir}``. Products that
        exist on disk are written to the step's ``produces`` fields. Fields that
        already have a value are never overwritten. This does nothing when
        ``work_dir`` is empty.
        """
        if not (ctx.get("work_dir") or ""):
            return
        for spec in self.steps:
            self._backfill_from_disk(spec, ctx, "检测到已有产物")

    # ------------------------------------------------------------------
    # Smart completion: force-enable disabled steps that have products
    # ------------------------------------------------------------------
    def _auto_fill_missing_steps(
        self, ctx: PipelineContext, skip_set: Optional[set] = None
    ) -> None:
        """Re-enable disabled steps whose ``output_file`` already exists on disk.

        Only ``StepSpec.enabled`` is inspected here. For each such step:

        - ``spec.enabled`` is set back to ``True``.
        - The step_id is recorded in ``ctx.locked_steps``.
        - Its empty ``produces`` fields are backfilled with the existing path,
          so downstream steps receive their inputs.

        Steps in ``skip_set`` are left alone. When ``skip_set`` is not given, it
        falls back to ``ctx._skip_set``. The hard img_path check happens in
        ``run()``, not here.
        """
        if skip_set is None:
            skip_set = getattr(ctx, "_skip_set", None) or set()
        newly_locked: List[str] = []
        for spec in self.steps:
            if spec.enabled or spec.step_id in skip_set:
                continue  # user-enabled, or explicitly ignored in the PreflightDialog
            resolved = self._resolve_path(spec.output_file, ctx)
            if not (resolved and os.path.exists(resolved)):
                continue
            spec.enabled = True
            if spec.step_id not in ctx.locked_steps:
                ctx.locked_steps.append(spec.step_id)
            newly_locked.append(spec.step_id)
            for produce_key in spec.produces:
                if not ctx.get(produce_key):
                    ctx.set(produce_key, resolved)
                    ctx.append_log(
                        f"[AUTO_FILL] 强制开启并回填 {spec.step_id} 产物 {produce_key} = {resolved}"
                    )
            ctx.append_log(
                f"\n{'='*60}\n"
                f" ⚡ 智能补全:步骤 {spec.step_id}{spec.description}\n"
                f" 原因:该步骤在 work_dir 中已有产物但被您在 GUI 中禁用了。\n"
                f" 操作:系统已自动开启该步骤,产物路径已回填。\n"
                f" 注意:运行期间该步骤已被锁定,您无法临时关闭。\n"
                f"{'='*60}\n"
            )
        if newly_locked:
            self._notify_step(
                "全流程",
                "info",
                f"智能补全已自动开启 {len(newly_locked)} 个步骤:{newly_locked}"
            )

    def _resolve_output_for_key(
        self, produce_key: str, ctx: PipelineContext
    ) -> Optional[str]:
        """Return the expanded output_file of the first step producing ``produce_key``."""
        for spec in self.steps:
            if produce_key in spec.produces:
                return self._resolve_path(spec.output_file, ctx)
        return None

    def _scan_single_step_outputs(
        self, spec: StepSpec, ctx: PipelineContext
    ) -> None:
        """Backfill ctx from one step's on-disk product (never overwrites values)."""
        self._backfill_from_disk(spec, ctx, "依赖唤醒后检测到产物")

    # ------------------------------------------------------------------
    # Soft preflight warnings (never blocking)
    # ------------------------------------------------------------------
    def _preflight_warnings(self, ctx: PipelineContext) -> None:
        """Log foreseeable runtime skips of enabled steps as warnings.

        Nothing is raised and nothing is blocked. Collected warnings are
        appended to the ctx log and sent to the GUI through ``_notify_step``.
        """
        messages: List[str] = []
        for spec in self.steps:
            if not spec.enabled:
                continue
            # step4 without measured data: downstream training steps will skip.
            if spec.step_id == "step4" and not ctx.get("csv_path"):
                messages.append(
                    f"[{spec.step_id}] 缺少实测水质数据 (csv_path)"
                    "步骤 5-9 将被自动跳过"
                )
            # Value present in ctx, but the file itself is missing on disk.
            for ctx_key in spec.required_input_files:
                value = ctx.get(ctx_key)
                if value and not os.path.exists(value):
                    messages.append(
                        f"[{spec.step_id}] 磁盘文件缺失(但 ctx 已回填): {ctx_key} = {value}"
                    )
        if messages:
            detail = "\n".join(f" - {w}" for w in messages)
            ctx.append_log(
                f"[RUNNER] 【软预检警告】(流程将继续执行,缺失项将被自动跳过)\n{detail}"
            )
            self._notify_step("全流程", "warning", f"预检警告:{len(messages)}\n{detail}")

    # ------------------------------------------------------------------
    # Single-step invocation (not used by run())
    # ------------------------------------------------------------------
    def _invoke(self, spec: StepSpec, ctx: PipelineContext) -> None:
        """Call one step method: ctx fields become kwargs, the return value feeds ctx.

        ``ctx.user_config[step_id]`` entries override the ctx-derived kwargs,
        except when the override value is ``None`` or ``""``. Exceptions
        propagate to the caller.
        """
        ctx.append_log(
            f"[DEBUG] Step {spec.step_id} requires: {spec.requires}, "
            f"actual ctx data: {[ctx.get(k) for k in spec.requires]}"
        )
        method = getattr(self.pipeline, spec.method_name, None)
        if method is None:
            ctx.append_log(f"[RUNNER] 步骤方法缺失: {spec.method_name}(跳过)")
            ctx.status[spec.step_id] = "skipped"
            return
        # 1) Inject ctx values under their (possibly remapped) parameter names.
        kwargs: Dict[str, Any] = {
            spec.parameter_map.get(key, self._default_param_name(key)): ctx.get(key)
            for key in spec.requires
        }
        # 2) Apply non-empty per-step user overrides.
        user_overrides = ctx.user_config.get(spec.step_id) or {}
        if isinstance(user_overrides, dict):
            for key, value in user_overrides.items():
                if self._has_value(value):
                    kwargs[key] = value
        # 3) Mark start.
        ctx.append_log(
            f"[RUNNER] -> {spec.method_name}({list(kwargs.keys())})"
        )
        ctx.status[spec.step_id] = "start"
        self._notify_step(spec.step_id, "start", spec.method_name)
        # 4) Execute.
        t0 = time.time()
        result = method(**kwargs)
        ctx.status[spec.step_id] = "completed"
        ctx.step_timings[spec.step_id] = time.time() - t0
        # 5) Harvest products.
        self._harvest(spec, result, ctx)
        self._notify_step(
            spec.step_id, "completed",
            str(result)[:200] if result is not None else "",
        )

    # ------------------------------------------------------------------
    # Product harvesting
    # ------------------------------------------------------------------
    def _harvest(self, spec: StepSpec, result: Any, ctx: PipelineContext) -> None:
        """Write a step's return value into its ``produces`` ctx fields.

        If ``result`` is a dict, matching keys are copied. Otherwise a non-None
        value goes to the first product.
        """
        if not spec.produces:
            return
        if isinstance(result, dict):
            for produce_key in spec.produces:
                if produce_key in result:
                    ctx.set(produce_key, result[produce_key])
        elif result is not None:
            ctx.set(spec.produces[0], result)

    # ------------------------------------------------------------------
    # Checkpoint-resume helpers
    # ------------------------------------------------------------------
    def _resolve_path(
        self, template: Optional[str], ctx: PipelineContext
    ) -> Optional[str]:
        """Expand ``{work_dir}`` in ``template``.

        Returns ``None`` for an empty template. If the template contains other
        named or positional placeholders, it is returned unchanged.
        """
        if not template:
            return None
        work_dir = ctx.get("work_dir") or ""
        try:
            return template.format(work_dir=work_dir)
        except (KeyError, IndexError, ValueError):
            return template

    def _restore_outputs_from_ctx(self, ctx: PipelineContext) -> None:
        """Diagnostic log of products of enabled steps that already have a value in ctx."""
        for spec in self.steps:
            if not (spec.enabled and spec.produces):
                continue
            for key in spec.produces:
                val = ctx.get(key)
                if val:
                    ctx.append_log(
                        f"[RUNNER] 断点续跑检测: {spec.step_id} 已有 {key} = {val}"
                    )

    def _restore_ctx_from_output(
        self, spec: StepSpec, resolved_path: str, ctx: PipelineContext
    ) -> None:
        """Write ``resolved_path`` into every ``produces`` field of ``spec``.

        Every field is overwritten, so that no downstream key is left unset.
        """
        for produce_key in spec.produces:
            ctx.set(produce_key, resolved_path)

    # ------------------------------------------------------------------
    # Utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _default_param_name(ctx_key: str) -> str:
        """Return the ctx key unchanged; renames are declared via ``parameter_map``."""
        return ctx_key

    @staticmethod
    def _has_value(value: Any) -> bool:
        """Return ``True`` unless ``value`` is ``None`` or the empty string."""
        return value is not None and not (isinstance(value, str) and value == "")

    @staticmethod
    def _ctx_has(ctx: Any, key: str) -> bool:
        """Return ``True`` when ``ctx`` has attribute ``key`` with a truthy value."""
        return bool(getattr(ctx, key, None))

    def _notify_step(
        self, step_id: str, status: str, message: str, *, propagate_halt: bool = True
    ) -> None:
        """Forward a status update to ``pipeline._notify`` (the GUI), if present.

        Callback failures are logged rather than silently dropped. A
        ``PipelineHalt`` raised by the callback is the GUI's hard-stop signal,
        so it is re-raised unless ``propagate_halt`` is ``False``. The False
        case is used when the run is already stopping.
        """
        notify = getattr(self.pipeline, "_notify", None)
        if not callable(notify):
            return
        try:
            notify(step_id, status, message)
        except PipelineHalt:
            if propagate_halt:
                raise
            logger.debug("[RUNNER] 忽略停止过程中的 PipelineHalt (step=%s)", step_id)
        except Exception:
            logger.warning(
                "[RUNNER] _notify 回调异常 (step=%s, status=%s)",
                step_id, status, exc_info=True,
            )