Files
WQ_GUI/src/core/pipeline/context.py

147 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
PipelineContext：内存级数据载体，跨 14 个 step 传递路径与元信息。
设计原则:
- 所有路径字段以 `_path` 为后缀(与 step 方法形参命名约定一致)
- 字段值可缺省（None），由 StepSpec.requires 在调度时注入
- dataclass + field(default_factory=dict) 支持原地增删
- 不放 GUI 状态(避免循环依赖)
- 不绑具体 step 方法（duck-typed cancellation / log append）
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
# ============================================================
# 步骤命名映射(定义在叶子节点,打破循环依赖)
# ============================================================
# Legacy step id -> new-format step id (insertion order is significant only
# in that the derived reverse map below follows it).
# NOTE(review): 'step8' and 'step9' occur both as legacy keys and as
# new-format values ('step5_5' -> 'step8', 'step6_75' -> 'step9'), so the two
# naming schemes overlap for those ids — confirm this is intended.
STEP_MAP_OLD_TO_NEW: Dict[str, str] = dict(
    step5_5="step8",
    step6_5="step8_non_empirical_modeling",
    step6_75="step9",
    step8_5="step11",
    step8_75="step12",
    step7="step10",
    step8="step11_ml",
    step9="step14",
)
# New-format step id -> legacy step id (the values above are all distinct,
# so inverting the mapping loses nothing).
STEP_MAP_NEW_TO_OLD: Dict[str, str] = dict(
    zip(STEP_MAP_OLD_TO_NEW.values(), STEP_MAP_OLD_TO_NEW.keys())
)
# Every step id known under either naming scheme.
ALL_STEP_IDS: Set[str] = {*STEP_MAP_OLD_TO_NEW.keys(), *STEP_MAP_OLD_TO_NEW.values()}
def resolve_step_id(step_id: str) -> str:
    """Normalize a step id to the new naming scheme.

    Ids found as keys of STEP_MAP_OLD_TO_NEW (legacy ids) are translated to
    their new-format id; every other id is returned unchanged.

    NOTE(review): 'step8' and 'step9' are legacy keys as well as new-format
    values, so they are always treated as legacy ids here ('step8' ->
    'step11_ml', 'step9' -> 'step14') — confirm callers never pass the
    new-format 'step8'/'step9'.
    """
    return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)
@dataclass
class PipelineContext:
    """Pipeline run context: the in-memory record passed between the 14 steps.

    Field naming conventions:
    - Path field name == panel key == step parameter name (no renaming
      anywhere along the chain).
    - Training/product files use the `_path` suffix
      (e.g. training_csv_path / water_mask_path).
    - Input image/CSV fields keep the panel's original names as-is
      (img_path / csv_path).
    - Directory fields carry no `_path` suffix (e.g. models_dir / prediction_dir).
    - Metadata fields carry no suffix (e.g. user_config / status / log).
    """
    # -- Inputs/outputs of the 11 steps (in step order; field name == panel key
    #    == step parameter). Step numbers in the comments below appear to use
    #    the legacy numbering (5_5, 6_5, 8_75 ... are STEP_MAP_OLD_TO_NEW keys).
    # NOTE(review): this header says 11 steps while the class docstring says
    # 14 — confirm which count is current.
    img_path: Optional[str] = None  # Step 1/2/3 input: raw image
    water_mask_path: Optional[str] = None  # Step 1 out -> Step 2/3/7 in
    glint_mask_path: Optional[str] = None  # Step 2 out -> Step 3/7 in
    deglint_img_path: Optional[str] = None  # Step 3 out -> Step 5/7 in
    csv_path: Optional[str] = None  # Step 4/5/6_5/6_75 input: raw/training CSV
    processed_csv_path: Optional[str] = None  # Step 4 out -> Step 5 in
    training_csv_path: Optional[str] = None  # Step 5 out -> Step 5_5/6/6_5/6_75 in
    boundary_path: Optional[str] = None  # Step 5 input: boundary SHP (panel step5 name)
    indices_path: Optional[str] = None  # Step 5.5 out
    sampling_csv_path: Optional[str] = None  # Step 7 out -> Step 8/8_5/8_75/9 in
    prediction_csv_path: Optional[str] = None  # Step 8 out -> Step 9 in
    distribution_map_path: Optional[str] = None  # Step 9 out
    boundary_shp_path: Optional[str] = None  # Step 9 input: boundary SHP (panel step9 name)
    formula_csv_path: Optional[str] = None  # Step 8_75 input: formula CSV
    # -- Directories (named without `_path` to set them apart) --
    models_dir: Optional[str] = None
    prediction_dir: Optional[str] = None
    work_dir: Optional[str] = None
    # -- Step 6 training artifacts (populated in AutoML mode, empty in regular mode) --
    model_files: List[str] = field(default_factory=list)
    # -- Metadata (user-supplied config / cancel event / status) --
    user_config: Dict[str, Any] = field(default_factory=dict)
    # Duck-typed cancel signal; see is_cancelled() for the accepted shapes.
    cancel_event: Optional[Any] = None
    status: Dict[str, str] = field(default_factory=dict)  # {step_id: 'start'/'completed'/'skipped'/'error'}
    log: List[str] = field(default_factory=list)
    # -- Diagnostics --
    step_timings: Dict[str, float] = field(default_factory=dict)
    pipeline_start_time: Optional[float] = None
    pipeline_end_time: Optional[float] = None
    last_error: Optional[str] = None
    # -- Error summary (available once the whole run has finished) --
    error_summary: List[tuple[str, str]] = field(default_factory=list)
    # -- Stop the whole pipeline immediately on error
    #    (default False: keep running the remaining steps) --
    breakpoint_on_error: bool = False
    # -- Smart-fill locked steps (steps auto-enabled by _auto_fill_missing_steps) --
    # The GUI layer reads this field to disable the matching panels'
    # "enabled" checkboxes while the pipeline is running.
    locked_steps: List[str] = field(default_factory=list)

    # ============================================================
    # Read/write helpers
    # ============================================================
    def step_id(self, step_id: str) -> str:
        """Normalize a step id (possibly a legacy name) to the new-format id.

        Mirrors the module-level resolve_step_id(): ids that are keys of
        STEP_MAP_OLD_TO_NEW are translated, anything else is returned as-is.

        Usage:
            ctx.status[ctx.step_id('step6_5')]       # 'step8_non_empirical_modeling'
            ctx.user_config[ctx.step_id('step8_5')]  # 'step11'

        NOTE(review): 'step8' and 'step9' are both legacy keys and new-format
        values in STEP_MAP_OLD_TO_NEW, so they are always remapped
        ('step8' -> 'step11_ml', 'step9' -> 'step14').
        """
        return STEP_MAP_OLD_TO_NEW.get(step_id, step_id)

    def set(self, key: str, value: Any) -> None:
        """Write an arbitrary attribute in place.

        Keys that are not declared fields (e.g. 'report_path') are allowed:
        they land on the instance __dict__ instead of raising AttributeError.
        """
        object.__setattr__(self, key, value)

    def get(self, key: str, default: Any = None) -> Any:
        """Read attribute `key`, returning `default` instead of raising if it is missing."""
        return getattr(self, key, default)

    def is_cancelled(self) -> bool:
        """Unified soft-cancellation check (duck-typed).

        Accepted `cancel_event` shapes, checked in this order:
        - anything with a callable `is_set()` (threading.Event, asyncio.Event —
          the latter is loop-bound but its `is_set()` is synchronous);
        - anything with a `cancelled` attribute (truthiness is used) or a
          `cancelled()` method (e.g. concurrent.futures.Future, asyncio.Future).

        Returns:
            False when no cancel_event is attached, otherwise whether the
            attached object reports cancellation.
        """
        ev = self.cancel_event
        if ev is None:
            return False
        is_set = getattr(ev, "is_set", None)
        if callable(is_set):
            return bool(is_set())
        cancelled = getattr(ev, "cancelled", False)
        # A bound method is always truthy: call it rather than bool()-ing it,
        # otherwise every Future-like object would read as permanently cancelled.
        if callable(cancelled):
            cancelled = cancelled()
        return bool(cancelled)

    def append_log(self, msg: str) -> None:
        """Append `msg` to the in-memory `log` list (nothing is printed here)."""
        self.log.append(msg)