147 lines
6.3 KiB
Python
147 lines
6.3 KiB
Python
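# -*- coding: utf-8 -*-
"""
PipelineContext: an in-memory data carrier that passes paths and metadata
across the 14 pipeline steps.

Design principles:
- Every path field ends with the `_path` suffix (matching the parameter-naming
  convention of the step methods).
- Field values may be left unset (None); they are injected at scheduling time
  according to StepSpec.requires.
- dataclass + field(default_factory=dict) so containers can be mutated in place.
- No GUI state is stored here (avoids circular dependencies).
- Not bound to any concrete step method (duck-typed cancellation / log append).
"""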
|
||
|
||
from __future__ import annotations
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, List, Optional, Set
|
||
|
||
|
||
# ============================================================
# Step-id renaming table (defined in this leaf module to break a
# circular dependency)
# ============================================================

# Legacy step id -> new-format step id, in declaration order.
# NOTE: "step8" and "step9" occur both as legacy keys and as new-format
# values, so a forward lookup always treats them as legacy ids.
STEP_MAP_OLD_TO_NEW: Dict[str, str] = dict(
    step5_5="step8",
    step6_5="step8_non_empirical_modeling",
    step6_75="step9",
    step8_5="step11",
    step8_75="step12",
    step7="step10",
    step8="step11_ml",
    step9="step14",
)

# Reverse table: new-format step id -> legacy step id.
STEP_MAP_NEW_TO_OLD: Dict[str, str] = {
    new_id: old_id for old_id, new_id in STEP_MAP_OLD_TO_NEW.items()
}

# Every id that appears on either side of the renaming table.
ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW).union(
    STEP_MAP_OLD_TO_NEW.values()
)
|
||
|
||
|
||
def resolve_step_id(step_id: str) -> str:
    """Translate a legacy step id into its new-format id.

    Args:
        step_id: A step identifier, either legacy or already new-format.

    Returns:
        The value mapped in ``STEP_MAP_OLD_TO_NEW`` when ``step_id`` is one
        of its keys; otherwise ``step_id`` unchanged.

    Note:
        Only one lookup is applied. Ids that are both a legacy key and a
        new-format value (e.g. ``"step8"``) are treated as legacy, so the
        function is not idempotent for them.
    """
    # Single dictionary lookup with the input itself as the fallback.
    return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)
|
||
|
||
|
||
@dataclass
class PipelineContext:
    """In-memory run context handed from step to step of the pipeline.

    Field naming conventions (as stated by the original author):
      - A path field's name equals the panel key and the step parameter
        name, so no translation is needed anywhere along the chain.
      - Training/artifact files use the ``_path`` suffix
        (e.g. ``training_csv_path`` / ``water_mask_path``).
      - Input imagery/CSV keep the panel's original names
        (``img_path`` / ``csv_path``).
      - Directory fields carry no ``_path`` suffix
        (e.g. ``models_dir`` / ``prediction_dir``).
      - Metadata fields carry no suffix
        (e.g. ``user_config`` / ``status`` / ``log``).

    All path and directory fields default to ``None``; every container
    field is created per instance through ``default_factory``.
    """

    # -- Per-step inputs/artifacts, listed in step order ------------------
    # NOTE(review): the original banner said "11 steps" while the docstrings
    # say 14 -- confirm which count is current.
    img_path: Optional[str] = None  # Step 1/2/3 input: raw image
    water_mask_path: Optional[str] = None  # Step 1 output -> Step 2/3/7 input
    glint_mask_path: Optional[str] = None  # Step 2 output -> Step 3/7 input
    deglint_img_path: Optional[str] = None  # Step 3 output -> Step 5/7 input
    csv_path: Optional[str] = None  # Step 4/5/6_5/6_75 input: raw/training CSV
    processed_csv_path: Optional[str] = None  # Step 4 output -> Step 5 input
    training_csv_path: Optional[str] = None  # Step 5 output -> Step 5_5/6/6_5/6_75 input
    boundary_path: Optional[str] = None  # Step 5 input: boundary SHP (panel step5 name)
    indices_path: Optional[str] = None  # Step 5.5 output
    sampling_csv_path: Optional[str] = None  # Step 7 output -> Step 8/8_5/8_75/9 input
    prediction_csv_path: Optional[str] = None  # Step 8 output -> Step 9 input
    distribution_map_path: Optional[str] = None  # Step 9 output
    boundary_shp_path: Optional[str] = None  # Step 9 input: boundary SHP (panel step9 name)
    formula_csv_path: Optional[str] = None  # Step 8_75 input: formula CSV

    # -- Directories (named without `_path` to tell them apart) -----------
    models_dir: Optional[str] = None
    prediction_dir: Optional[str] = None
    work_dir: Optional[str] = None

    # -- Step 6 training artifacts (filled in AutoML mode, empty otherwise)
    model_files: List[str] = field(default_factory=list)

    # -- Metadata: user-supplied config / cancel token / status -----------
    user_config: Dict[str, Any] = field(default_factory=dict)
    cancel_event: Optional[Any] = None  # duck-typed threading.Event / asyncio.Event
    status: Dict[str, str] = field(default_factory=dict)  # {step_id: 'start'/'completed'/'skipped'/'error'}
    log: List[str] = field(default_factory=list)

    # -- Diagnostics -------------------------------------------------------
    step_timings: Dict[str, float] = field(default_factory=dict)
    pipeline_start_time: Optional[float] = None
    pipeline_end_time: Optional[float] = None
    last_error: Optional[str] = None

    # -- Error summary (available once the whole pipeline has finished) ---
    error_summary: List[tuple[str, str]] = field(default_factory=list)
    # -- Stop the whole pipeline on the first error (default False: the
    #    remaining steps keep running) ------------------------------------
    breakpoint_on_error: bool = False
    # -- Steps locked by smart auto-fill (those switched on automatically
    #    by _auto_fill_missing_steps). The GUI layer reads this field to
    #    disable the matching panels' enable checkboxes during a run. -----
    locked_steps: List[str] = field(default_factory=list)

    # ============================================================
    # Read/write helpers
    # ============================================================

    def step_id(self, step_id: str) -> str:
        """Translate a (possibly legacy) step id into its new-format id.

        Args:
            step_id: A step identifier, legacy or new-format.

        Returns:
            The value mapped in ``STEP_MAP_OLD_TO_NEW`` when ``step_id`` is
            one of its keys; otherwise ``step_id`` unchanged.

        Example:
            ctx.status[ctx.step_id('step6_5')]       # 'step8_non_empirical_modeling'
            ctx.user_config[ctx.step_id('step8_5')]  # 'step11'
        """
        # Single lookup; unknown ids fall through untouched.
        return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)

    def set(self, key: str, value: Any) -> None:
        """Write ``value`` onto the attribute named ``key`` in place.

        Names that are not declared fields (e.g. ``'report_path'``) are
        accepted too and simply land on the instance, so callers never hit
        an ``AttributeError`` for a missing static field.

        Args:
            key: Attribute name to write.
            value: Value to store.
        """
        object.__setattr__(self, key, value)

    def get(self, key: str, default: Any = None) -> Any:
        """Read the attribute named ``key`` without raising when it is absent.

        Args:
            key: Attribute name to read.
            default: Value returned when the attribute does not exist.

        Returns:
            The attribute value, or ``default`` if there is no such attribute.
        """
        return getattr(self, key, default)

    def is_cancelled(self) -> bool:
        """Single entry point for the duck-typed soft-cancellation check.

        Supported cancel tokens:
          - objects exposing a callable ``is_set()``, such as
            ``threading.Event`` or ``asyncio.Event``;
          - objects exposing ``cancelled`` either as a plain attribute or
            as a zero-argument method.

        Returns:
            ``False`` when no token is configured; otherwise the token's
            state coerced to ``bool``.
        """
        token = self.cancel_event
        if token is None:
            return False

        probe = getattr(token, "is_set", None)
        if callable(probe):
            return bool(probe())

        flag = getattr(token, "cancelled", False)
        # A bound method is always truthy, so a method-style ``cancelled()``
        # has to be called instead of being coerced directly.
        if callable(flag):
            flag = flag()
        return bool(flag)

    def append_log(self, msg: str) -> None:
        """Append ``msg`` to the in-memory ``log`` list.

        Args:
            msg: The log line to record.
        """
        self.log.append(msg)
|