feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描

This commit is contained in:
DXC
2026-06-08 09:56:02 +08:00
parent 4efe5b871e
commit 2b76d7908f
12 changed files with 935 additions and 29 deletions

View File

@ -17,7 +17,8 @@
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
"Bash(venv\\scripts\\python.exe *)",
"Bash(findstr *)"
"Bash(findstr *)",
"Bash(select-string *)"
]
},
"$version": 4

View File

@ -16,7 +16,8 @@
"Bash(.\\venv\\scripts\\python.exe *)",
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
"Bash(venv\\scripts\\python.exe *)"
"Bash(venv\\scripts\\python.exe *)",
"Bash(findstr *)"
]
},
"$version": 4

View File

@ -0,0 +1,309 @@
---
name: PipelineRunner Facade 防御性 kwargs 兜底
description: WQ_GUI 14 个 stepX_... Facade 方法必须以 **kwargs 收尾——配合 PipelineRunner 调度模式杜绝 "unexpected keyword argument" TypeError
source: auto-skill
extracted_at: '2026-06-04T00:54:50.036Z'
---
# PipelineRunner Facade 防御性 kwargs 兜底
## 适用场景
在 WQ_GUI 中,**任何被 `PipelineRunner` 调用的 14 个 `stepX_...` Facade 方法**(位于 `src/core/water_quality_inversion_pipeline_GUI.py`),其形参表末尾**必须**带 `**kwargs`。触发信号:
- 用户报错 `TypeError: stepX_xxx() got an unexpected keyword argument 'yyy'`
- 改 PIPELINE_STEPS 的 `requires` 列表
- 新增 / 重命名一个 step 方法
- 重构 PipelineRunner 的 `_invoke` 注入逻辑
## 核心原则
**Facade 的形参表 = 显式声明的形参 + `, **kwargs`**`kwargs` 必须**严格位于形参表最后**Python 语法硬要求)。
```python
# ✅ 正确
def step3_remove_glint(self, img_path: str,
method: str = "subtract_nir",
# ... 30+ 业务形参 ...
skip_dependency_check: bool = False,
**kwargs) -> str:
...
# ❌ 错误:**kwargs 不能放中间或前面
def step3_remove_glint(self, img_path, **kwargs, skip_dependency_check): # SyntaxError
```
## 为什么需要这层防御
`PipelineRunner._invoke``src/core/pipeline/runner.py`)会向方法注入两类参数:
| 层 | 来源 | 形参 key 怎么定 |
|---|---|---|
| **L2** | ctx 字段(按 `spec.requires` 列表) | `_default_param_name(ctx_key)` 默认去 `_path` 后缀 |
| **L3** | `ctx.user_config[step_id]`14 panel dict 整体) | dict 的 key 原样注入 |
**L2 触发 TypeError 的真实场景**2026-06-04 真实发生):
- `PIPELINE_STEPS.step3.requires = ["img_path", "water_mask_path", "glint_mask_path"]`
- Runner 注入 `kwargs["glint_mask_path"] = ctx.glint_mask_path`
- `step3_remove_glint` 形参表**没有** `glint_mask_path`(虽然业务上耀斑掩膜是在子调用 `GlintRemovalStep.run` 内部用的Facade 本身不接)
-**TypeError: step3_remove_glint() got an unexpected keyword argument 'glint_mask_path'**
**L3 触发 TypeError 的场景**
- user_config 14 panel dict 里残留了旧字段名 / 跨 step 串味的字段
- 任何 `user_config[step_id][k]` 中的 `k` 都会被注入
`**kwargs` 一次性解决两类问题。
## 14 个 Facade 方法清单(截至 2026-06-04 已全部带 **kwargs
| step | method | 形参表闭合示例 |
|---|---|---|
| 1 | `step1_generate_water_mask` | `output_path: Optional[str] = None, **kwargs) -> str:` |
| 2 | `step2_find_glint_area` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 3 | `step3_remove_glint` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 4 | `step4_process_csv` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 5 | `step5_extract_training_spectra` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 5.5 | `step5_5_calculate_water_quality_indices` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 6 | `step6_train_models` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 6.5 | `step6_5_non_empirical_modeling` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
| 6.75 | `step6_75_custom_regression` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 7 | `step7_generate_sampling_points` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
| 8 | `step8_predict_water_quality` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
| 8.5 | `step8_5_predict_with_non_empirical_models` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
| 8.75 | `step8_75_predict_with_custom_regression` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
| 9 | `step9_generate_distribution_map` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
## 标准操作
### 1. 编辑(最小外科手术式)
每个方法的最后形参是 `skip_dependency_check: bool = False`,把这一行改成:
```python
skip_dependency_check: bool = False, **kwargs) -> str:
```
**注意缩进必须与原行一致**13 空格 / 35 空格 / 47 空格 / 48 空格不等,按方法原始缩进)。用 `edit` 工具的 old_string **必须含 docstring 第一行**`"""步骤X: ..."""`)作唯一标识。
### 2. 验证
写一个临时校验脚本(项目根目录运行后删掉):
```python
import ast, re
target = r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py'
src = open(target, encoding='utf-8-sig').read()
ast.parse(src) # AST 语法
expected = [
'step1_generate_water_mask', 'step2_find_glint_area', 'step3_remove_glint',
'step4_process_csv', 'step5_extract_training_spectra',
'step5_5_calculate_water_quality_indices', 'step6_train_models',
'step7_generate_sampling_points', 'step8_predict_water_quality',
'step9_generate_distribution_map', 'step6_5_non_empirical_modeling',
'step6_75_custom_regression', 'step8_5_predict_with_non_empirical_models',
'step8_75_predict_with_custom_regression',
]
pat = re.compile(r'def\s+(\w+)\s*\((?:[^()]|\([^()]*\))*\*\*kwargs\)\s*->\s*[^:]+:', re.DOTALL)
found = set(pat.findall(src))
print('Missing:', [m for m in expected if m not in found])
print('Extra :', [m for m in found if m not in expected])
```
期望输出:两个列表都为空。
### 3. Windows 执行
```bat
cd /d D:\111\office\ZHLduijie\1.WQ\WQ_GUI
py _check.py
del _check.py
```
> 必加 `utf-8-sig``water_quality_inversion_pipeline_GUI.py` 头部可能含 BOM`code_replacement_state_audit` skill 里有同样提示)。
## ⚠️ 这层防御不能解决什么
**` **kwargs` 兜底 ≠ 形参名错位修复**。如果 Runner 注入 `kwargs["training_csv_path"]` 而方法形参是 `csv_path`
-**不会报 TypeError**`training_csv_path``**kwargs` 收走)
-**但 `csv_path` 仍是 None**(方法体内的 `if csv_path is not None: ... else: ...` 走 fallback 分支,可能读 `self.training_csv_path` 哨兵)
**已知形参名错位的方法**2026-06-04 已修commit `64aa5b8`
| step | Runner 注入的 ctx key | 方法实际形参 | 实际落地的修复 |
|---|---|---|---|
| step6_5 | `training_csv_path` | `csv_path` | `parameter_map={"training_csv_path": "csv_path"}` |
| step6_75 | (已切到 `indices_path` | `csv_path` | ⚠️ 见下方"step6_75 路由修复"专题 |
| step8_5 | `models_dir` | `non_empirical_models_dir` | `parameter_map={"models_dir": "non_empirical_models_dir"}` |
| step8_75 | `models_dir` | `custom_regression_dir` | `parameter_map={"models_dir": "custom_regression_dir"}` |
> **parameter_map** 是 `StepSpec` 已有字段runner.py:33作用是把 ctx 字段重命名到方法形参名。**优先用 parameter_map 而非改 requires**——保持 ctx 字段语义清晰(声明式描述上游依赖),形参名是方法私有约定。
### step6_75 路由修复特殊案例2026-06-04 commit `64aa5b8`
`step6_75_custom_regression` **不是简单的 ctx 字段名错位**——方法体内的 fallback 链透露了**真正的数据源是 `indices_path`**
```python
# step6_75 形参csv_path
# 方法体 fallback:
# if csv_path is not None: input_csv = csv_path
# elif self.indices_path is not None: input_csv = self.indices_path # ★ 真相
```
`parameter_map``training_csv_path → csv_path` 看似能跑通,但**实际用错了数据**training_csv 是 step5 输出,不是 step6_75 想要的 indices CSV
**正确做法 = 同时改 requires + parameter_map**
```python
StepSpec(
step_id="step6_75", method_name="step6_75_custom_regression",
requires=["indices_path"], # ★ 从 training_csv_path 切到 indices_path
produces=["models_dir"],
parameter_map={"indices_path": "csv_path"}, # ★ 同步改 key
description="自定义回归分析",
),
```
**配合 `skip_when_missing` 兜底**:若用户没跑 step5_5`ctx.indices_path` 为 Nonerunner 自动 skip 整个 step6_75不会用错位数据静默执行。
**判别何时需要"路由切"vs"纯 rename"**
- 看方法体 fallback 链fallback 到 `self.indices_path`/`self.deglint_img_path`/其他 ctx 字段名 → **需要改 requires**
- 仅是 key 名字不同,方法体直接用形参 → 只改 parameter_map
## L2 注入顺序冲突:多个 requires 字段解析到同一形参名
### 场景
`StepSpec.requires` 里有**多个 ctx 字段**,经过 `_default_param_name` / `parameter_map` 解析后,**会落到同一个方法形参名**。L2 注入是**顺序敏感**的(后者覆盖前者),后注入的会**默默覆盖**前一个的赋值。
### 真实案例2026-06-04 step5 修复)
业务需求step5 真正需要 step4 产物 `processed_csv_path`,但**保留 raw `csv_path` 字段**作为 `user_config` 覆盖入口。
**❌ 错误的 parameter_map 写法**(用户原方案的隐藏 bug
```python
StepSpec(
step_id="step5", method_name="step5_extract_training_spectra",
requires=["deglint_img_path", "processed_csv_path", "csv_path", ...], # raw csv_path 也在
parameter_map={"processed_csv_path": "csv_path"}, # ★ 只映射了一个
...
)
```
L2 注入顺序(`runner.py:184-186`
1. `deglint_img_path``kwargs[deglint_img_path] = ctx.deglint_img_path`
2. `processed_csv_path``kwargs[csv_path] = ctx.processed_csv_path` ← 主路径生效
3. `csv_path`(无映射 → 默认)→ `kwargs[csv_path] = ctx.csv_path`**后注入 None 覆盖了主路径!**
**症状**step5 形参 `csv_path` 拿到的是 raw `ctx.csv_path`(通常是 None方法体 fallback 到 `self.processed_csv_path`——但这个 fallback 也可能是 Nonestep4 没跑step5 内部空跑 → "**静默错误**"。
**✅ 修法parameter_map 双向映射 + 占位名落 **kwargs**
```python
parameter_map={
"processed_csv_path": "csv_path", # 主路径(注入到方法形参)
"csv_path": "_raw_csv_ignored", # 占位(落到 step5 形参列表末尾的 **kwargs
},
```
注入顺序重排后:
- `processed_csv_path``kwargs[csv_path] = ctx.processed_csv_path` ← 主路径
- `csv_path``kwargs[_raw_csv_ignored] = ctx.csv_path` ← 落 **kwargs被吞
step5 形参 `csv_path` 最终拿到 `ctx.processed_csv_path` 的值 ✓。
### 验证模板(行为模拟)
写临时 `_verify_l2_inject.py` 复刻 `runner.py:184-186` 的 L2 注入循环,**不要只靠 AST 静态检查**——parameter_map 的 key 顺序、requires 的字段顺序都是动态的:
```python
import sys
sys.path.insert(0, r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI')
from src.core.pipeline.context import PipelineContext
from src.core.pipeline.runner import PIPELINE_STEPS
spec5 = next(s for s in PIPELINE_STEPS if s.step_id == 'step5')
# 复刻 L2 注入(与 runner.py:184-186 完全一致)
def l2_inject(spec, ctx):
kwargs = {}
for ctx_key in spec.requires:
param_name = spec.parameter_map.get(ctx_key, ctx_key) # ★ 必须原样复刻
kwargs[param_name] = ctx.get(ctx_key)
return kwargs
# 关键断言
ctx = PipelineContext(processed_csv_path='/csv/processed.csv', csv_path='/csv/raw.csv')
kw = l2_inject(spec5, ctx)
assert kw.get('csv_path') == '/csv/processed.csv', \
f"形参 csv_path 应等于 processed_csv_path, 实际 {kw.get('csv_path')!r}"
assert '_raw_csv_ignored' in kw, "占位名应被注入到 kwargs"
print(f'OK: csv_path 形参 = {kw["csv_path"]!r} (processed, 主路径正确)')
print(f'OK: _raw_csv_ignored 占位 = {kw["_raw_csv_ignored"]!r} (raw, 落 **kwargs 被吞)')
```
跑完删掉:`py _verify_l2_inject.py & del _verify_l2_inject.py`Windows 一行模式)。
### 何时需要警惕这个冲突
修改 StepSpec 时检查清单(**先看这一段再写 parameter_map**
- [ ] **requires 里是否有多于 1 个 ctx 字段,解析后会落到同名方法形参?** 典型撞车:
- 同名字段(如 `processed_csv_path``csv_path` 都能映射到 `csv_path`
- 同名 `_default_param_name` 退化(如 `boundary_path``boundary_shp_path` 默认都映射到 `boundary_path`——但要注意 `_default_param_name` 已废弃去后缀,原样返回 ctx key所以 `boundary_path``boundary_shp_path` 默认就会撞 `boundary_path` / `boundary_shp_path` 不会撞,要撞就必须显式 parameter_map
- 字段名 + parameter_map 重命名撞车
- [ ] **"主路径"字段在 requires 列表靠前位置**(让后续"备路径"覆盖,但**这不解决冲突**——只要有第二次注入就一定会覆盖)
- [ ] **"备路径"字段**用占位名 `_xxx_ignored` / `_xxx_kwargs_only` 映射,让它落到 **kwargs
- [ ] **确认方法形参表末尾有 `**kwargs`** 兜底(`facade_kwargs_defense` skill 核心要求,已 14/14 落地)
### 反例(不要做)
- ❌ "我让 `csv_path` 不在 requires 里就行了"——会**丢失 user_config 覆盖入口**(如果用户想用 raw CSV 而不是 processed
- ❌ "改 L2 注入循环,让 parameter_map 字段最后注入"——会**改变 runner 通用语义**,影响所有 step 的注入顺序
- ❌ "加 `if param_name in kwargs: continue` 在 L2 注入里"——隐式"第一次优先"语义,新人读代码摸不着头脑
- ❌ "用 position in requires 做加权"——把数据语义哪个字段优先级高塞到列表顺序里runner 应该保持"声明式"
### 与"纯 rename"的区别
| 维度 | 纯 rename已有 skill 案例) | 多→1 冲突(本节案例) |
|---|---|---|
| 典型场景 | step6_5/6_75/8_5/8_751 个 requires 重命名到形参 | step52 个 requires 撞到同一形参 |
| parameter_map | 1 个 key→value | 2 个 key→同名 value + 占位名 |
| requires | 1 个字段 | 2 个字段(主 + 备) |
| 冲突来源 | 不会出现(单 key | 出现(顺序敏感 + 撞名) |
| 修法 | 只加 parameter_map | 双向 parameter_map + 占位名 |
## 与其他防御层的关系
```
PipelineRunner.run() 主循环
├─ L1 runner.py:152 skip_when_missing ─── ctx.<required> 全 None → skip step
├─ L2 runner.py:182 ctx 字段注入 ─── 形参表里没声明 → TypeError ⚠️ → **kwargs 兜底
├─ L3 runner.py:188 user_config 合并 ─── user_config 有"空字符串"/None → 跳过(上一轮加的守卫)✅
└─ L4 runner.py:211 except 捕获 ─── 业务抛异常 → ctx.status="error" + raise
```
`**kwargs`**L2 的"消极兜底"**——宁愿吞掉多余 key 也不报 TypeError。**真正的"积极修复"是 parameter_map**(让 ctx 字段名映射到正确形参名)。两层配合:
- **保守期间(重构初期)**:先 `**kwargs` 兜住TypeError 消失
- **稳定阶段**:补 parameter_map让方法收到正确数据
## 反例(不要做)
- ❌ "不写 `**kwargs`,靠 type hint + IDE 检查兜底"——Runner 是运行时注入IDE 看不到
- ❌ "把 `**kwargs` 放形参表中间"——Python 语法错误
- ❌ "改 requires 列表去掉冗余 ctx 字段"——会导致 `skip_when_missing` 误判(以为 step 不需要该 ctx 字段),应该用 `parameter_map` 重命名而非删除 requires
- ❌ "在 14 个 Facade 方法体里加 `if 'glint_mask_path' in kwargs: kwargs.pop('glint_mask_path')`"——脏活,且每个方法都要加,远不如 `**kwargs` 一行优雅
## 案例来源
- 2026-06-04 WQ_GUI PipelineRunner 迁移第二步
- 触发:`step3_remove_glint() got an unexpected keyword argument 'glint_mask_path'`
- 根因:`PIPELINE_STEPS.step3.requires` 写了 `glint_mask_path`,但 `GlintRemovalStep` 内部使用Facade 自身不接这个形参
- 落地14 个 Facade 全部加 `, **kwargs`0 个 TypeError
- 验证:临时 `_check.py` 14/14 命中 + AST 解析通过
-4 个 parameter_map 全部落地commit `64aa5b8`),含 step6_75 路由切到 indices_pathL3 非空过滤同步加入 `runner._invoke:188`
- 2026-06-04 step5 严格依赖修复:发现 L2 注入顺序冲突requires 多个字段解析到同一形参名),引入"双向 parameter_map + 占位名落 **kwargs"模式step5 形参 `csv_path` 真正接到 `processed_csv_path`step4 产物raw `csv_path` 保留为 user_config 覆盖入口,落占位名 `_raw_csv_ignored` 后被 `**kwargs` 吞。skip_when_missing 块同步加 `_notify` 通知,**拒绝静默跳过**15 条 _notify 全带具体 missing 字段列表证据)。

View File

@ -0,0 +1,294 @@
---
name: WQ_GUI PyQt5 面板外部模型导入模式
description: 在 Step8 等预测面板中通过 QRadioButton + FileSelectWidget + joblib.load 防御性加载实现"内置/导入"双模式切换的标准模式
source: auto-skill
extracted_at: '2026-06-08T01:38:14.481Z'
---
# WQ_GUI PyQt5 面板外部模型导入模式
## 适用场景
Step8机器学习预测、Step8_5、Step8_75 等面板需要同时支持:
1. **内置模式**:使用 `step6` 训练流程生成的模型目录
2. **导入模式**:用户手动选择本地预训练 `.joblib` 文件直接加载
---
## 1. 模板(可直接复制到 `__init__` + `init_ui`
```python
from PyQt5.QtWidgets import QRadioButton
class StepXPanel(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.current_model = None # ★ 外部模型实例缓存
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# -------- 模型来源选择(单选按钮组) --------
source_group = QGroupBox("模型来源")
source_layout = QVBoxLayout()
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
self.use_trained_model.setChecked(True)
source_layout.addWidget(self.use_trained_model)
source_layout.addWidget(self.use_external_model)
self.use_trained_model.toggled.connect(self._on_model_source_changed)
self.use_external_model.toggled.connect(self._on_model_source_changed)
source_group.setLayout(source_layout)
layout.addWidget(source_group)
# -------- 外部模型文件选择(条件显示) --------
self.external_model_widget = FileSelectWidget(
"预训练模型:",
"Joblib Files (*.joblib);;All Files (*.*)"
)
# FileSelectWidget 的 browse_btn 默认连着 open file 行为,
# 需要先断开默认连接,再接自定义槽
self.external_model_widget.browse_btn.clicked.disconnect()
self.external_model_widget.browse_btn.clicked.connect(self._browse_external_model)
self.external_model_widget.setVisible(False)
layout.addWidget(self.external_model_widget)
# ... 其余原有 UI ...
```
---
## 2. 槽函数模板
### `_on_model_source_changed`
单选按钮 `toggled` 信号在**两个**按钮上都会触发(点击 A 时 A 触发B 也触发),所以用 `if not checked: return` 让非选中分支短路。
```python
def _on_model_source_changed(self, checked: bool):
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
if not checked:
return
is_external = self.use_external_model.isChecked()
self.external_model_widget.setVisible(is_external)
# 切回"使用当前模型"时清空缓存,释放内存并避免误用旧模型
if not is_external:
self.current_model = None
```
### `_browse_external_model`
-`QFileDialog.getOpenFileName` 而非 `getExistingDirectory`
- 防御性解析两种格式:`{"model": pipeline, ...}`Step6 输出格式)和裸 `Pipeline` 对象
- 失败用 `QMessageBox.warning` 友善提示;成功用 `QMessageBox.information` 告知
```python
from PyQt5.QtWidgets import QFileDialog, QMessageBox
from pathlib import Path
def _browse_external_model(self):
"""浏览并加载外部 .joblib 预训练模型文件"""
default = self._get_default_work_dir()
path, _ = QFileDialog.getOpenFileName(
self,
"选择预训练模型 (.joblib)",
default,
"Joblib Files (*.joblib);;All Files (*.*)",
)
if not path:
return
try:
import joblib
loaded = joblib.load(path)
# 兼容两种格式dict{"model": obj} 或裸 Pipeline
if isinstance(loaded, dict) and "model" in loaded:
self.current_model = loaded["model"]
elif hasattr(loaded, "predict"):
self.current_model = loaded
else:
QMessageBox.warning(
self,
"模型格式错误",
f"无法识别的模型格式,文件内容类型为:{type(loaded).__name__}",
)
return
self.external_model_widget.set_path(path)
QMessageBox.information(
self,
"模型加载成功",
f"已加载模型:{Path(path).name}\n类型:{type(self.current_model).__name__}",
)
except Exception as e:
self.current_model = None
QMessageBox.warning(
self,
"模型加载失败",
f"加载模型时发生错误:\n{type(e).__name__}: {e}",
)
```
---
## 3. `run_step` 改造模板
在原有目录加载逻辑之前,插入外部模型优先分支:
```python
def run_step(self):
"""独立运行步骤X"""
# ... 公共输入校验 ...
# ★ 外部模型优先分支
if self.use_external_model.isChecked():
if self.current_model is None:
QMessageBox.warning(
self,
"模型未加载",
"请先点击「浏览...」按钮加载预训练模型文件!",
)
return
external_model_path = self.external_model_widget.get_path() or ""
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {
'stepX': self.get_config(),
'_external_model': self.current_model, # ★ 直接传对象
'_external_model_path': external_model_path, # 供日志/回溯用
}
main_window.run_single_step('stepX', config)
return
# 默认流程:使用模型目录(原有逻辑不变)
models_dir = self.models_dir_file.get_path()
if not models_dir:
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
return
# ... 原有 run_step 剩余代码 ...
```
---
## 4. 后端三层完整接入2026-06-08 已落地)
完整数据流分为三层,每层各一处分流点:
```
GUI step8_panel
↓ config = {'_external_model': obj, '_external_model_path': path, 'step8': {...}}
worker_thread.run_single_step() [第1处分流透传顶层 key]
↓ step_config = config['step8'] + {'_external_model': obj, '_external_model_path': path}
prediction_step.predict_water_quality() [第2处分流接收 + 透传]
↓ _external_model=obj, _external_model_path=path
WaterQualityInference(artifacts_dir, external_model=obj, external_model_path=path)
inference_batch.batch_inference_multi_models() [第3处分流effective_model 短路]
↓ external_model=obj
inference_batch.inference_pipeline()
→ self.external_model is not None → self.loaded_model_data = self.external_model跳过磁盘加载
```
### 4a. worker_thread.py — run_single_step 透传
`step_config = dict(config.get(step_name, {}))` 之后、"skip_dependency_check" 之前插入:
```python
# 透传面板顶层传入的外部预训练模型GUI step8_panel 通过 config['_external_model'] 传入)
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
for key in ('_external_model', '_external_model_path'):
val = config.get(key)
if val is not None and val != "":
step_config[key] = val
```
### 4b. prediction_step.py — predict_water_quality 签名 + 透传
形参表末尾增加两个参数:
```python
_external_model=None,
_external_model_path=None,
```
构造处透传:
```python
inferencer = WaterQualityInference(
models_dir,
external_model=_external_model,
external_model_path=_external_model_path,
)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
...
external_model=_external_model,
external_model_path=_external_model_path,
)
```
### 4c. inference_batch.py — 三处修改
**`__init__` 存储**
```python
def __init__(self, artifacts_dir: str = "models/artifacts",
external_model=None, external_model_path=None):
...
self.external_model = external_model
self.external_model_path = external_model_path
```
**`batch_inference_multi_models` 短路 + 注入**
```python
# 优先级:外部预训练模型 > 从磁盘加载
if external_model is not None:
effective_model = external_model
print(f"\n使用外部预训练模型: type={type(external_model).__name__}")
else:
effective_model = None
# 子目录循环中注入:
if effective_model is not None:
model_inferencer = WaterQualityInference(
str(subdir),
external_model=effective_model,
external_model_path=external_model_path,
)
else:
model_inferencer = WaterQualityInference(str(subdir))
```
**`inference_pipeline` 模型加载短路**`load_best_model` 调用前):
```python
if self.external_model is not None:
self.loaded_model_data = self.external_model
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
elif model_file_path:
self.load_specific_model(model_file_path)
else:
self.load_best_model(metric=metric)
```
**关键约束**
- `joblib.load` 在 panel 槽函数里完成GUI 进程内),对象通过 config 引用直接透传;**不跨进程**,所以不需要担心 pickle 序列化问题
- `batch_inference_multi_models` 形参 `external_model``external_model_path` **与实例属性同名**`self.external_model`),两者都传是为了让每个子目录创建的 `WaterQualityInference` 实例都能独立持有引用
- 原有从 `models_dir` 目录加载的逻辑完全保留,只在 `external_model is not None` 时短路
---
## 5. 已知约束
- `FileSelectWidget.browse_btn.clicked``init_ui` 里会重复 connect每次 `init_ui` 被调用时会累积;解决方案是在 connect 前先 `.disconnect()`(如模板所示)。
- `QRadioButton.toggled` 信号在两个按钮上都会触发,**必须**用 `if not checked: return` 短路,否则会导致切换时状态错乱。
- `self.current_model` 会在面板切换到"使用当前模型"时清空,防止用户忘记换回内置模式后仍使用旧导入模型。
- 当前项目 venv 路径:`D:\111\office\ZHLduijie\1.WQ\WQ_GUI\venv`,导入 `joblib` 时注意 venv 环境一致性。

View File

@ -30,6 +30,7 @@ app/api/modeling.py
"""
import asyncio
import shutil
import traceback
import uuid
from datetime import datetime
@ -40,7 +41,7 @@ import joblib
import numpy as np
import pandas as pd
import xarray as xr
from fastapi import APIRouter, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, File
from pydantic import BaseModel, Field
from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import RandomForestRegressor
@ -784,3 +785,63 @@ async def submit_predict(
payload.output_zarr_path,
)
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
# ---------------------------------------------------------------------------
# models_router — 独立于 modeling_router路径前缀为 /models
# 最终完整路径: GET /api/models, POST /api/models/upload
# ---------------------------------------------------------------------------
models_router = APIRouter(prefix="/models", tags=["models"])
# ---------------------------------------------------------------------------
# GET /api/models
# ---------------------------------------------------------------------------
@models_router.get("")
async def list_models() -> Dict[str, Any]:
"""
扫描 ./data/models/ 目录,返回所有 .joblib 文件名(不含后缀)。
异常处理:目录不存在时自动创建,返回空列表。
"""
models_dir = Path("./data/models")
models_dir.mkdir(parents=True, exist_ok=True)
model_names = [
p.stem for p in models_dir.iterdir() if p.suffix == ".joblib"
]
return {"models": model_names}
# ---------------------------------------------------------------------------
# POST /api/models/upload
# ---------------------------------------------------------------------------
@models_router.post("/upload")
async def upload_model(
file: UploadFile = File(...),
) -> Dict[str, Any]:
"""
接收上传的 .joblib 模型文件,保存到 ./data/models/ 目录。
- 校验后缀必须为 .joblib
- 目录不存在时自动创建
- 返回状态和文件名(不含后缀)
"""
if not file.filename or not file.filename.lower().endswith(".joblib"):
raise HTTPException(
status_code=400,
detail="仅支持 .joblib 格式的文件",
)
models_dir = Path("./data/models")
models_dir.mkdir(parents=True, exist_ok=True)
dest_path = models_dir / file.filename
with dest_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return {
"status": "success",
"model_id": dest_path.stem,
}

View File

@ -18,6 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import router as deglint_router
from app.api.modeling import router as modeling_router
from app.api.modeling import models_router
# ---------------------------------------------------------------------------
@ -52,6 +53,7 @@ app.add_middleware(
# ---------------------------------------------------------------------------
app.include_router(deglint_router, prefix="/api")
app.include_router(modeling_router, prefix="/api")
app.include_router(models_router, prefix="/api")
# ---------------------------------------------------------------------------

View File

@ -26,12 +26,15 @@ from sklearn.model_selection import train_test_split
class WaterQualityInference:
"""水质参数反演推理类"""
def __init__(self, artifacts_dir: str = "models/artifacts"):
def __init__(self, artifacts_dir: str = "models/artifacts",
external_model=None, external_model_path=None):
"""
初始化推理类
Args:
artifacts_dir: 模型保存目录
external_model: 外部预训练模型对象(来自 GUI 导入,跳过磁盘加载)
external_model_path: 外部模型文件路径(仅用于日志)
"""
self.artifacts_dir = Path(artifacts_dir)
if not self.artifacts_dir.exists():
@ -39,6 +42,8 @@ class WaterQualityInference:
self.best_model_info = None
self.loaded_model_data = None
self.external_model = external_model
self.external_model_path = external_model_path
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
@ -745,7 +750,11 @@ class WaterQualityInference:
# 1. 加载模型
print("\n步骤1: 加载模型")
print("-" * 40)
if model_file_path:
if self.external_model is not None:
# 外部预训练模型已注入,直接使用,跳过磁盘加载
self.loaded_model_data = self.external_model
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
elif model_file_path:
self.load_specific_model(model_file_path)
else:
self.load_best_model(metric=metric)
@ -866,7 +875,9 @@ class WaterQualityInference:
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
output_dir: str, metric: str = 'test_r2',
prediction_column: str = 'prediction',
output_format: str = 'csv'):
output_format: str = 'csv',
external_model=None,
external_model_path=None):
"""
使用多个子文件夹中的模型进行批量推理
@ -882,6 +893,17 @@ class WaterQualityInference:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 优先级:外部预训练模型 > 从磁盘加载
if external_model is not None:
effective_model = external_model
model_desc = (
f"外部导入模型 ({external_model_path or 'unknown'}), "
f"type={type(external_model).__name__}"
)
print(f"\n使用外部预训练模型: {model_desc}")
else:
effective_model = None
# 查找所有子文件夹
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
@ -901,7 +923,14 @@ class WaterQualityInference:
print(f"处理模型文件夹: {subdir_name}")
print(f"{'='*60}")
# 创建新的推理实例使用当前子文件夹作为artifacts_dir
# 创建推理实例:外部模型优先注入,跳过磁盘查找
if effective_model is not None:
model_inferencer = WaterQualityInference(
str(subdir),
external_model=effective_model,
external_model_path=external_model_path,
)
else:
model_inferencer = WaterQualityInference(str(subdir))
# 根据输出格式设置文件扩展名

View File

@ -103,6 +103,10 @@ class PredictionStep:
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
callback: Optional[Callable] = None,
_report_generator=None,
_external_model=None,
_external_model_path=None,
_external_models_dict=None,
_external_model_dir=None,
) -> Dict[str, str]:
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
from src.core.prediction.inference_batch import WaterQualityInference
@ -149,7 +153,35 @@ class PredictionStep:
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
inferencer = WaterQualityInference(models_dir)
if _external_models_dict:
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数,
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
for target_name, model_obj in _external_models_dict.items():
try:
output_file = ml_prediction_dir / f"{target_name}.csv"
model_inferencer = WaterQualityInference(
models_dir or "./",
external_model=model_obj,
external_model_path=_external_model_dir or "",
)
predictions, result_df = model_inferencer.inference_pipeline(
sampling_csv_path=sampling_csv_path,
output_csv_path=str(output_file),
metric=metric,
prediction_column=prediction_column,
)
prediction_files[target_name] = str(output_file)
print(f"{target_name}: {len(predictions)} 个预测值")
except Exception as e:
print(f"{target_name}: 失败 — {type(e).__name__}: {e}")
prediction_files[target_name] = None
else:
inferencer = WaterQualityInference(
models_dir,
external_model=_external_model,
external_model_path=_external_model_path,
)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
sampling_csv_path=sampling_csv_path,
@ -157,8 +189,9 @@ class PredictionStep:
metric=metric,
prediction_column=prediction_column,
output_format="csv",
external_model=_external_model,
external_model_path=_external_model_path,
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]

View File

@ -326,6 +326,14 @@ class WorkerThread(QThread):
method_name = step_method_map[step_name]
step_config = dict(config.get(step_name, {}))
# 透传面板顶层传入的外部预训练模型GUI step8_panel 通过 config['_external_model'] 传入)
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
for key in ('_external_model', '_external_model_path',
'_external_models_dict', '_external_model_dir'):
val = config.get(key)
if val is not None and val != "":
step_config[key] = val
step_config['skip_dependency_check'] = True
if step_name == 'step9':

View File

@ -91,3 +91,13 @@ Traceback (most recent call last):
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt
============================================================
[2026-06-04 09:54:07]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3237, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3214, in main
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt

View File

@ -10,7 +10,7 @@ from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog,
QFileDialog, QRadioButton,
)
from src.gui.components.custom_widgets import FileSelectWidget
@ -21,12 +21,61 @@ class Step8Panel(QWidget):
"""步骤8机器学习预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.external_models_dict = {} # {subdir_name: model_obj, ...}
self.external_model_dir = "" # 母文件夹路径(隐藏)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 采样光谱CSV文件用于独立运行
# -------- 模型来源选择(单选按钮组) --------
source_group = QGroupBox("模型来源")
source_layout = QVBoxLayout()
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
self.use_trained_model.setChecked(True)
source_layout.addWidget(self.use_trained_model)
source_layout.addWidget(self.use_external_model)
self.use_trained_model.toggled.connect(self._on_model_source_changed)
self.use_external_model.toggled.connect(self._on_model_source_changed)
source_group.setStyleSheet("""
QRadioButton {
font-size: 13px;
spacing: 8px;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border-radius: 9px;
border: 2px solid #A0A0A0;
background-color: #FFFFFF;
}
QRadioButton::indicator:hover {
border: 2px solid #0078D7;
}
QRadioButton::indicator:checked {
background-color: #0078D7;
border: 2px solid #0078D7;
}
""")
source_group.setLayout(source_layout)
layout.addWidget(source_group)
# -------- 外部模型文件选择(条件显示) --------
self.external_model_widget = FileSelectWidget(
"模型母文件夹:",
"Directories"
)
self.external_model_widget.browse_btn.clicked.disconnect()
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
self.external_model_widget.setVisible(False)
layout.addWidget(self.external_model_widget)
# -------- 采样光谱CSV文件用于独立运行--------
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
@ -79,6 +128,94 @@ class Step8Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
def _on_model_source_changed(self, checked: bool):
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
if not checked:
return
is_external = self.use_external_model.isChecked()
self.external_model_widget.setVisible(is_external)
if not is_external:
self.external_models_dict = {}
self.external_model_dir = ""
def _scan_external_model_dir(self):
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "7_Supervised_Model_Training")
dir_path = QFileDialog.getExistingDirectory(
self,
"选择模型母文件夹",
default,
)
if not dir_path:
return
self.external_model_dir = dir_path
models_found = {}
errors = []
try:
import joblib
for subentry in os.scandir(dir_path):
if not subentry.is_dir():
continue
subdir_name = subentry.name
joblib_files = [
f for f in os.scandir(subentry.path)
if f.is_file() and f.name.lower().endswith(".joblib")
]
if not joblib_files:
continue
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
joblib_path = joblib_files[0].path
try:
loaded = joblib.load(joblib_path)
if isinstance(loaded, dict) and "model" in loaded:
model_obj = loaded["model"]
elif hasattr(loaded, "predict"):
model_obj = loaded
else:
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
continue
models_found[subdir_name] = model_obj
except Exception as e:
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
except Exception as e:
QMessageBox.warning(
self,
"扫描失败",
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
)
return
if not models_found:
QMessageBox.warning(
self,
"未找到模型",
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
)
self.external_model_widget.set_path("")
self.external_models_dict = {}
return
self.external_models_dict = models_found
names = sorted(models_found.keys())
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
self.external_model_widget.set_path(display)
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
err_lines = "\n".join(errors) if errors else ""
QMessageBox.information(
self,
"模型扫描完成",
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
f"加载失败 {len(errors)} 个:\n{err_lines}",
)
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充采样光谱和模型目录
@ -197,10 +334,31 @@ class Step8Panel(QWidget):
def run_step(self):
"""独立运行步骤8"""
sampling_csv_path = self.sampling_csv_file.get_path()
models_dir = self.models_dir_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
# 外部模型优先:用户选择了"导入本地预训练模型"
if self.use_external_model.isChecked():
if not self.external_models_dict:
QMessageBox.warning(
self,
"模型未加载",
"请先点击「浏览...」按钮选择模型母文件夹!",
)
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {
'step8': self.get_config(),
'_external_models_dict': self.external_models_dict,
'_external_model_dir': self.external_model_dir,
}
main_window.run_single_step('step8', config)
return
# 默认流程:使用模型目录
models_dir = self.models_dir_file.get_path()
if not models_dir:
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
return

Binary file not shown.