# -*- coding: utf-8 -*- """ PipelineContext:内存级数据载体,跨 14 个 step 传递路径与元信息。 设计原则: - 所有路径字段以 `_path` 为后缀(与 step 方法形参命名约定一致) - 字段值可缺省(None),由 StepSpec.requires 在调度时注入 - dataclass + field(default_factory=dict) 支持原地增删 - 不放 GUI 状态(避免循环依赖) - 不绑具体 step 方法(duck-typed cancellation / log append) """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set # ============================================================ # 步骤命名映射(定义在叶子节点,打破循环依赖) # ============================================================ STEP_MAP_OLD_TO_NEW: Dict[str, str] = { "step5_5": "step8", "step6_5": "step8_non_empirical_modeling", "step6_75": "step9", "step8_5": "step11", "step8_75": "step12", "step7": "step10", "step8": "step11_ml", "step9": "step14", } STEP_MAP_NEW_TO_OLD: Dict[str, str] = {v: k for k, v in STEP_MAP_OLD_TO_NEW.items()} ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW.keys()) | set(STEP_MAP_OLD_TO_NEW.values()) def resolve_step_id(step_id: str) -> str: """将任意 step_id 转换为标准新格式。""" if step_id in STEP_MAP_OLD_TO_NEW: return STEP_MAP_OLD_TO_NEW[step_id] return step_id @dataclass class PipelineContext: """流水线运行上下文(在 14 个 step 之间传递的内存字典) 字段命名约定: - 路径类字段名 = panel key 名 = step 形参名(全链路无翻译) - 训练/产物 CSV 用 `_path` 后缀(如 training_csv_path / water_mask_path) - 入参影像/CSV 沿用 panel 原名(img_path / csv_path),无 `_path` 后缀 - 目录类字段无 `_path` 后缀(如 models_dir / prediction_dir) - 元信息字段无后缀(如 user_config / status / log) """ # ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参) ── img_path: Optional[str] = None # Step 1/2/3 入参:原始影像 water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入 glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入 deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入 csv_path: Optional[str] = None # Step 4/5/6_5/6_75 入参:原始/训练 CSV processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入 training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75 入 boundary_path: Optional[str] = None # Step 5 入参:边界 SHP(panel step5 名) indices_path: Optional[str] = None # Step 5.5 出 sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入 prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入 distribution_map_path: Optional[str] = None # Step 9 出 boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHP(panel step9 名) formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV # ── 目录类(命名不带 _path 以示区别) ── models_dir: Optional[str] = None prediction_dir: Optional[str] = None work_dir: Optional[str] = None # ── Step 6 训练产物(AutoML 模式有,常规模式为空) ── model_files: List[str] = field(default_factory=list) # ── 元信息(三件套:用户传的配置 / 取消事件 / 状态) ── user_config: Dict[str, Any] = field(default_factory=dict) cancel_event: Optional[Any] = None # duck-typed threading.Event / asyncio.Event status: Dict[str, str] = field(default_factory=dict) # {step_id: 'start'/'completed'/'skipped'/'error'} log: List[str] = field(default_factory=list) # ── 诊断 ── step_timings: Dict[str, float] = field(default_factory=dict) pipeline_start_time: Optional[float] = None pipeline_end_time: Optional[float] = None last_error: Optional[str] = None # ── 错误汇总(全流程结束后可用) ── error_summary: List[tuple[str, str]] = field(default_factory=list) # ── 出错时立即停止全流程(默认 False:继续后续步骤) ── breakpoint_on_error: bool = False # ── ★ 智能补全锁定步骤列表(由 _auto_fill_missing_steps 自动开启的步骤) ── # GUI 层读取此字段,在运行期间禁用对应面板的启用复选框 locked_steps: List[str] = field(default_factory=list) # ============================================================ # 读写辅助 # ============================================================ def step_id(self, step_id: str) -> str: """将任意 step_id(可能是旧名)转换为标准新格式。 用法示例: ctx.status[ctx.step_id('step6_5')] # 'step8_non_empirical_modeling' ctx.user_config[ctx.step_id('step8_5')] # 'step11' """ if step_id in STEP_MAP_OLD_TO_NEW: return STEP_MAP_OLD_TO_NEW[step_id] return step_id def set(self, key: str, value: Any) -> None: """原地写入任意属性。 允许动态字段(如 'report_path')直接挂在 __dict__ 上, 避免因静态字段缺失而抛 AttributeError。 """ object.__setattr__(self, key, value) def get(self, key: str, default: Any = None) -> Any: """原地读出,缺 key 不抛错。""" return getattr(self, key, default) def is_cancelled(self) -> bool: """统一软取消检查入口(duck-typed)。 支持: - threading.Event(.is_set()) - asyncio.Event(loop-bound,is_set 同步接口存在) - 自定义 .is_set() / .cancelled 属性 """ ev = self.cancel_event if ev is None: return False is_set = getattr(ev, "is_set", None) if callable(is_set): return bool(is_set()) return bool(getattr(ev, "cancelled", False)) def append_log(self, msg: str) -> None: """写入日志列表(也用于主进程 stdout 调试)。""" self.log.append(msg)