feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描
This commit is contained in:
@ -17,7 +17,8 @@
|
||||
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
|
||||
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
|
||||
"Bash(venv\\scripts\\python.exe *)",
|
||||
"Bash(findstr *)"
|
||||
"Bash(findstr *)",
|
||||
"Bash(select-string *)"
|
||||
]
|
||||
},
|
||||
"$version": 4
|
||||
|
||||
@ -16,7 +16,8 @@
|
||||
"Bash(.\\venv\\scripts\\python.exe *)",
|
||||
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
|
||||
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
|
||||
"Bash(venv\\scripts\\python.exe *)"
|
||||
"Bash(venv\\scripts\\python.exe *)",
|
||||
"Bash(findstr *)"
|
||||
]
|
||||
},
|
||||
"$version": 4
|
||||
|
||||
309
.qwen/skills/facade_kwargs_defense/SKILL.md
Normal file
309
.qwen/skills/facade_kwargs_defense/SKILL.md
Normal file
@ -0,0 +1,309 @@
|
||||
---
|
||||
name: PipelineRunner Facade 防御性 kwargs 兜底
|
||||
description: WQ_GUI 14 个 stepX_... Facade 方法必须以 **kwargs 收尾——配合 PipelineRunner 调度模式杜绝 "unexpected keyword argument" TypeError
|
||||
source: auto-skill
|
||||
extracted_at: '2026-06-04T00:54:50.036Z'
|
||||
---
|
||||
|
||||
# PipelineRunner Facade 防御性 kwargs 兜底
|
||||
|
||||
## 适用场景
|
||||
|
||||
在 WQ_GUI 中,**任何被 `PipelineRunner` 调用的 14 个 `stepX_...` Facade 方法**(位于 `src/core/water_quality_inversion_pipeline_GUI.py`),其形参表末尾**必须**带 `**kwargs`。触发信号:
|
||||
|
||||
- 用户报错 `TypeError: stepX_xxx() got an unexpected keyword argument 'yyy'`
|
||||
- 改 PIPELINE_STEPS 的 `requires` 列表
|
||||
- 新增 / 重命名一个 step 方法
|
||||
- 重构 PipelineRunner 的 `_invoke` 注入逻辑
|
||||
|
||||
## 核心原则
|
||||
|
||||
**Facade 的形参表 = 显式声明的形参 + `, **kwargs`**。`kwargs` 必须**严格位于形参表最后**(Python 语法硬要求)。
|
||||
|
||||
```python
|
||||
# ✅ 正确
|
||||
def step3_remove_glint(self, img_path: str,
|
||||
method: str = "subtract_nir",
|
||||
# ... 30+ 业务形参 ...
|
||||
skip_dependency_check: bool = False,
|
||||
**kwargs) -> str:
|
||||
...
|
||||
|
||||
# ❌ 错误:**kwargs 不能放中间或前面
|
||||
def step3_remove_glint(self, img_path, **kwargs, skip_dependency_check): # SyntaxError
|
||||
```
|
||||
|
||||
## 为什么需要这层防御
|
||||
|
||||
`PipelineRunner._invoke`(`src/core/pipeline/runner.py`)会向方法注入两类参数:
|
||||
|
||||
| 层 | 来源 | 形参 key 怎么定 |
|
||||
|---|---|---|
|
||||
| **L2** | ctx 字段(按 `spec.requires` 列表) | `_default_param_name(ctx_key)` 默认去 `_path` 后缀 |
|
||||
| **L3** | `ctx.user_config[step_id]`(14 panel dict 整体) | dict 的 key 原样注入 |
|
||||
|
||||
**L2 触发 TypeError 的真实场景**(2026-06-04 真实发生):
|
||||
|
||||
- `PIPELINE_STEPS.step3.requires = ["img_path", "water_mask_path", "glint_mask_path"]`
|
||||
- Runner 注入 `kwargs["glint_mask_path"] = ctx.glint_mask_path`
|
||||
- `step3_remove_glint` 形参表**没有** `glint_mask_path`(虽然业务上耀斑掩膜是在子调用 `GlintRemovalStep.run` 内部用的,Facade 本身不接)
|
||||
- → **TypeError: step3_remove_glint() got an unexpected keyword argument 'glint_mask_path'**
|
||||
|
||||
**L3 触发 TypeError 的场景**:
|
||||
|
||||
- user_config 14 panel dict 里残留了旧字段名 / 跨 step 串味的字段
|
||||
- 任何 `user_config[step_id][k]` 中的 `k` 都会被注入
|
||||
|
||||
`**kwargs` 一次性解决两类问题。
|
||||
|
||||
## 14 个 Facade 方法清单(截至 2026-06-04 已全部带 **kwargs)
|
||||
|
||||
| step | method | 形参表闭合示例 |
|
||||
|---|---|---|
|
||||
| 1 | `step1_generate_water_mask` | `output_path: Optional[str] = None, **kwargs) -> str:` |
|
||||
| 2 | `step2_find_glint_area` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 3 | `step3_remove_glint` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 4 | `step4_process_csv` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 5 | `step5_extract_training_spectra` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 5.5 | `step5_5_calculate_water_quality_indices` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 6 | `step6_train_models` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 6.5 | `step6_5_non_empirical_modeling` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
|
||||
| 6.75 | `step6_75_custom_regression` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 7 | `step7_generate_sampling_points` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
| 8 | `step8_predict_water_quality` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
|
||||
| 8.5 | `step8_5_predict_with_non_empirical_models` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
|
||||
| 8.75 | `step8_75_predict_with_custom_regression` | `skip_dependency_check: bool = False, **kwargs) -> Dict[str, str]:` |
|
||||
| 9 | `step9_generate_distribution_map` | `skip_dependency_check: bool = False, **kwargs) -> str:` |
|
||||
|
||||
## 标准操作
|
||||
|
||||
### 1. 编辑(最小外科手术式)
|
||||
|
||||
每个方法的最后形参是 `skip_dependency_check: bool = False`,把这一行改成:
|
||||
|
||||
```python
|
||||
skip_dependency_check: bool = False, **kwargs) -> str:
|
||||
```
|
||||
|
||||
**注意缩进必须与原行一致**(13 空格 / 35 空格 / 47 空格 / 48 空格不等,按方法原始缩进)。用 `edit` 工具的 old_string **必须含 docstring 第一行**(`"""步骤X: ..."""`)作唯一标识。
|
||||
|
||||
### 2. 验证
|
||||
|
||||
写一个临时校验脚本(项目根目录运行后删掉):
|
||||
|
||||
```python
|
||||
import ast, re
|
||||
target = r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py'
|
||||
src = open(target, encoding='utf-8-sig').read()
|
||||
ast.parse(src) # AST 语法
|
||||
|
||||
expected = [
|
||||
'step1_generate_water_mask', 'step2_find_glint_area', 'step3_remove_glint',
|
||||
'step4_process_csv', 'step5_extract_training_spectra',
|
||||
'step5_5_calculate_water_quality_indices', 'step6_train_models',
|
||||
'step7_generate_sampling_points', 'step8_predict_water_quality',
|
||||
'step9_generate_distribution_map', 'step6_5_non_empirical_modeling',
|
||||
'step6_75_custom_regression', 'step8_5_predict_with_non_empirical_models',
|
||||
'step8_75_predict_with_custom_regression',
|
||||
]
|
||||
pat = re.compile(r'def\s+(\w+)\s*\((?:[^()]|\([^()]*\))*\*\*kwargs\)\s*->\s*[^:]+:', re.DOTALL)
|
||||
found = set(pat.findall(src))
|
||||
print('Missing:', [m for m in expected if m not in found])
|
||||
print('Extra :', [m for m in found if m not in expected])
|
||||
```
|
||||
|
||||
期望输出:两个列表都为空。
|
||||
|
||||
### 3. Windows 执行
|
||||
|
||||
```bat
|
||||
cd /d D:\111\office\ZHLduijie\1.WQ\WQ_GUI
|
||||
py _check.py
|
||||
del _check.py
|
||||
```
|
||||
|
||||
> 必加 `utf-8-sig`:`water_quality_inversion_pipeline_GUI.py` 头部可能含 BOM(`code_replacement_state_audit` skill 里有同样提示)。
|
||||
|
||||
## ⚠️ 这层防御不能解决什么
|
||||
|
||||
**` **kwargs` 兜底 ≠ 形参名错位修复**。如果 Runner 注入 `kwargs["training_csv_path"]` 而方法形参是 `csv_path`:
|
||||
|
||||
- ✅ **不会报 TypeError**(`training_csv_path` 被 `**kwargs` 收走)
|
||||
- ❌ **但 `csv_path` 仍是 None**(方法体内的 `if csv_path is not None: ... else: ...` 走 fallback 分支,可能读 `self.training_csv_path` 哨兵)
|
||||
|
||||
**已知形参名错位的方法**(2026-06-04 已修,commit `64aa5b8`):
|
||||
|
||||
| step | Runner 注入的 ctx key | 方法实际形参 | 实际落地的修复 |
|
||||
|---|---|---|---|
|
||||
| step6_5 | `training_csv_path` | `csv_path` | `parameter_map={"training_csv_path": "csv_path"}` |
|
||||
| step6_75 | (已切到 `indices_path`) | `csv_path` | ⚠️ 见下方"step6_75 路由修复"专题 |
|
||||
| step8_5 | `models_dir` | `non_empirical_models_dir` | `parameter_map={"models_dir": "non_empirical_models_dir"}` |
|
||||
| step8_75 | `models_dir` | `custom_regression_dir` | `parameter_map={"models_dir": "custom_regression_dir"}` |
|
||||
|
||||
> **parameter_map** 是 `StepSpec` 已有字段(runner.py:33),作用是把 ctx 字段重命名到方法形参名。**优先用 parameter_map 而非改 requires**——保持 ctx 字段语义清晰(声明式描述上游依赖),形参名是方法私有约定。
|
||||
|
||||
### step6_75 路由修复(特殊案例,2026-06-04 commit `64aa5b8`)
|
||||
|
||||
`step6_75_custom_regression` **不是简单的 ctx 字段名错位**——方法体内的 fallback 链透露了**真正的数据源是 `indices_path`**:
|
||||
|
||||
```python
|
||||
# step6_75 形参:csv_path
|
||||
# 方法体 fallback:
|
||||
# if csv_path is not None: input_csv = csv_path
|
||||
# elif self.indices_path is not None: input_csv = self.indices_path # ★ 真相
|
||||
```
|
||||
|
||||
加 `parameter_map` 把 `training_csv_path → csv_path` 看似能跑通,但**实际用错了数据**(training_csv 是 step5 输出,不是 step6_75 想要的 indices CSV)。
|
||||
|
||||
**正确做法 = 同时改 requires + parameter_map**:
|
||||
|
||||
```python
|
||||
StepSpec(
|
||||
step_id="step6_75", method_name="step6_75_custom_regression",
|
||||
requires=["indices_path"], # ★ 从 training_csv_path 切到 indices_path
|
||||
produces=["models_dir"],
|
||||
parameter_map={"indices_path": "csv_path"}, # ★ 同步改 key
|
||||
description="自定义回归分析",
|
||||
),
|
||||
```
|
||||
|
||||
**配合 `skip_when_missing` 兜底**:若用户没跑 step5_5(`ctx.indices_path` 为 None),runner 自动 skip 整个 step6_75,不会用错位数据静默执行。
|
||||
|
||||
**判别何时需要"路由切"vs"纯 rename"**:
|
||||
- 看方法体 fallback 链:fallback 到 `self.indices_path`/`self.deglint_img_path`/其他 ctx 字段名 → **需要改 requires**
|
||||
- 仅是 key 名字不同,方法体直接用形参 → 只改 parameter_map
|
||||
|
||||
## L2 注入顺序冲突:多个 requires 字段解析到同一形参名
|
||||
|
||||
### 场景
|
||||
|
||||
`StepSpec.requires` 里有**多个 ctx 字段**,经过 `_default_param_name` / `parameter_map` 解析后,**会落到同一个方法形参名**。L2 注入是**顺序敏感**的(后者覆盖前者),后注入的会**默默覆盖**前一个的赋值。
|
||||
|
||||
### 真实案例(2026-06-04 step5 修复)
|
||||
|
||||
业务需求:step5 真正需要 step4 产物 `processed_csv_path`,但**保留 raw `csv_path` 字段**作为 `user_config` 覆盖入口。
|
||||
|
||||
**❌ 错误的 parameter_map 写法**(用户原方案的隐藏 bug):
|
||||
|
||||
```python
|
||||
StepSpec(
|
||||
step_id="step5", method_name="step5_extract_training_spectra",
|
||||
requires=["deglint_img_path", "processed_csv_path", "csv_path", ...], # raw csv_path 也在
|
||||
parameter_map={"processed_csv_path": "csv_path"}, # ★ 只映射了一个
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
L2 注入顺序(`runner.py:184-186`):
|
||||
1. `deglint_img_path` → `kwargs[deglint_img_path] = ctx.deglint_img_path`
|
||||
2. `processed_csv_path` → `kwargs[csv_path] = ctx.processed_csv_path` ← 主路径生效
|
||||
3. `csv_path`(无映射 → 默认)→ `kwargs[csv_path] = ctx.csv_path` ← **后注入 None 覆盖了主路径!**
|
||||
|
||||
**症状**:step5 形参 `csv_path` 拿到的是 raw `ctx.csv_path`(通常是 None),方法体 fallback 到 `self.processed_csv_path`——但这个 fallback 也可能是 None(step4 没跑),step5 内部空跑 → "**静默错误**"。
|
||||
|
||||
**✅ 修法:parameter_map 双向映射 + 占位名落 **kwargs**:
|
||||
|
||||
```python
|
||||
parameter_map={
|
||||
"processed_csv_path": "csv_path", # 主路径(注入到方法形参)
|
||||
"csv_path": "_raw_csv_ignored", # 占位(落到 step5 形参列表末尾的 **kwargs)
|
||||
},
|
||||
```
|
||||
|
||||
注入顺序重排后:
|
||||
- `processed_csv_path` → `kwargs[csv_path] = ctx.processed_csv_path` ← 主路径
|
||||
- `csv_path` → `kwargs[_raw_csv_ignored] = ctx.csv_path` ← 落 **kwargs(被吞)
|
||||
|
||||
step5 形参 `csv_path` 最终拿到 `ctx.processed_csv_path` 的值 ✓。
|
||||
|
||||
### 验证模板(行为模拟)
|
||||
|
||||
写临时 `_verify_l2_inject.py` 复刻 `runner.py:184-186` 的 L2 注入循环,**不要只靠 AST 静态检查**——parameter_map 的 key 顺序、requires 的字段顺序都是动态的:
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.insert(0, r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI')
|
||||
from src.core.pipeline.context import PipelineContext
|
||||
from src.core.pipeline.runner import PIPELINE_STEPS
|
||||
|
||||
spec5 = next(s for s in PIPELINE_STEPS if s.step_id == 'step5')
|
||||
|
||||
# 复刻 L2 注入(与 runner.py:184-186 完全一致)
|
||||
def l2_inject(spec, ctx):
|
||||
kwargs = {}
|
||||
for ctx_key in spec.requires:
|
||||
param_name = spec.parameter_map.get(ctx_key, ctx_key) # ★ 必须原样复刻
|
||||
kwargs[param_name] = ctx.get(ctx_key)
|
||||
return kwargs
|
||||
|
||||
# 关键断言
|
||||
ctx = PipelineContext(processed_csv_path='/csv/processed.csv', csv_path='/csv/raw.csv')
|
||||
kw = l2_inject(spec5, ctx)
|
||||
assert kw.get('csv_path') == '/csv/processed.csv', \
|
||||
f"形参 csv_path 应等于 processed_csv_path, 实际 {kw.get('csv_path')!r}"
|
||||
assert '_raw_csv_ignored' in kw, "占位名应被注入到 kwargs"
|
||||
print(f'OK: csv_path 形参 = {kw["csv_path"]!r} (processed, 主路径正确)')
|
||||
print(f'OK: _raw_csv_ignored 占位 = {kw["_raw_csv_ignored"]!r} (raw, 落 **kwargs 被吞)')
|
||||
```
|
||||
|
||||
跑完删掉:`py _verify_l2_inject.py & del _verify_l2_inject.py`(Windows 一行模式)。
|
||||
|
||||
### 何时需要警惕这个冲突
|
||||
|
||||
修改 StepSpec 时检查清单(**先看这一段再写 parameter_map**):
|
||||
|
||||
- [ ] **requires 里是否有多于 1 个 ctx 字段,解析后会落到同名方法形参?** 典型撞车:
|
||||
- 同名字段(如 `processed_csv_path` 和 `csv_path` 都能映射到 `csv_path`)
|
||||
- 同名 `_default_param_name` 退化(如 `boundary_path` 和 `boundary_shp_path` 默认都映射到 `boundary_path`——但要注意 `_default_param_name` 已废弃去后缀,原样返回 ctx key,所以 `boundary_path` 和 `boundary_shp_path` 默认就会撞 `boundary_path` / `boundary_shp_path` 不会撞,要撞就必须显式 parameter_map)
|
||||
- 字段名 + parameter_map 重命名撞车
|
||||
- [ ] **"主路径"字段在 requires 列表靠前位置**(让后续"备路径"覆盖,但**这不解决冲突**——只要有第二次注入就一定会覆盖)
|
||||
- [ ] **"备路径"字段**用占位名 `_xxx_ignored` / `_xxx_kwargs_only` 映射,让它落到 **kwargs
|
||||
- [ ] **确认方法形参表末尾有 `**kwargs`** 兜底(`facade_kwargs_defense` skill 核心要求,已 14/14 落地)
|
||||
|
||||
### 反例(不要做)
|
||||
|
||||
- ❌ "我让 `csv_path` 不在 requires 里就行了"——会**丢失 user_config 覆盖入口**(如果用户想用 raw CSV 而不是 processed)
|
||||
- ❌ "改 L2 注入循环,让 parameter_map 字段最后注入"——会**改变 runner 通用语义**,影响所有 step 的注入顺序
|
||||
- ❌ "加 `if param_name in kwargs: continue` 在 L2 注入里"——隐式"第一次优先"语义,新人读代码摸不着头脑
|
||||
- ❌ "用 position in requires 做加权"——把数据语义(哪个字段优先级高)塞到列表顺序里,runner 应该保持"声明式"
|
||||
|
||||
### 与"纯 rename"的区别
|
||||
|
||||
| 维度 | 纯 rename(已有 skill 案例) | 多→1 冲突(本节案例) |
|
||||
|---|---|---|
|
||||
| 典型场景 | step6_5/6_75/8_5/8_75:1 个 requires 重命名到形参 | step5:2 个 requires 撞到同一形参 |
|
||||
| parameter_map | 1 个 key→value | 2 个 key→同名 value + 占位名 |
|
||||
| requires | 1 个字段 | 2 个字段(主 + 备) |
|
||||
| 冲突来源 | 不会出现(单 key) | 出现(顺序敏感 + 撞名) |
|
||||
| 修法 | 只加 parameter_map | 双向 parameter_map + 占位名 |
|
||||
|
||||
## 与其他防御层的关系
|
||||
|
||||
```
|
||||
PipelineRunner.run() 主循环
|
||||
├─ L1 runner.py:152 skip_when_missing ─── ctx.<required> 全 None → skip step
|
||||
├─ L2 runner.py:182 ctx 字段注入 ─── 形参表里没声明 → TypeError ⚠️ → **kwargs 兜底
|
||||
├─ L3 runner.py:188 user_config 合并 ─── user_config 有"空字符串"/None → 跳过(上一轮加的守卫)✅
|
||||
└─ L4 runner.py:211 except 捕获 ─── 业务抛异常 → ctx.status="error" + raise
|
||||
```
|
||||
|
||||
`**kwargs` 是 **L2 的"消极兜底"**——宁愿吞掉多余 key 也不报 TypeError。**真正的"积极修复"是 parameter_map**(让 ctx 字段名映射到正确形参名)。两层配合:
|
||||
- **保守期间(重构初期)**:先 `**kwargs` 兜住,TypeError 消失
|
||||
- **稳定阶段**:补 parameter_map,让方法收到正确数据
|
||||
|
||||
## 反例(不要做)
|
||||
|
||||
- ❌ "不写 `**kwargs`,靠 type hint + IDE 检查兜底"——Runner 是运行时注入,IDE 看不到
|
||||
- ❌ "把 `**kwargs` 放形参表中间"——Python 语法错误
|
||||
- ❌ "改 requires 列表去掉冗余 ctx 字段"——会导致 `skip_when_missing` 误判(以为 step 不需要该 ctx 字段),应该用 `parameter_map` 重命名而非删除 requires
|
||||
- ❌ "在 14 个 Facade 方法体里加 `if 'glint_mask_path' in kwargs: kwargs.pop('glint_mask_path')`"——脏活,且每个方法都要加,远不如 `**kwargs` 一行优雅
|
||||
|
||||
## 案例来源
|
||||
|
||||
- 2026-06-04 WQ_GUI PipelineRunner 迁移第二步
|
||||
- 触发:`step3_remove_glint() got an unexpected keyword argument 'glint_mask_path'`
|
||||
- 根因:`PIPELINE_STEPS.step3.requires` 写了 `glint_mask_path`,但 `GlintRemovalStep` 内部使用,Facade 自身不接这个形参
|
||||
- 落地:14 个 Facade 全部加 `, **kwargs`,0 个 TypeError
|
||||
- 验证:临时 `_check.py` 14/14 命中 + AST 解析通过
|
||||
- 续:4 个 parameter_map 全部落地(commit `64aa5b8`),含 step6_75 路由切到 indices_path;L3 非空过滤同步加入 `runner._invoke:188`
|
||||
- 2026-06-04 step5 严格依赖修复:发现 L2 注入顺序冲突(requires 多个字段解析到同一形参名),引入"双向 parameter_map + 占位名落 **kwargs"模式;step5 形参 `csv_path` 真正接到 `processed_csv_path`(step4 产物),raw `csv_path` 保留为 user_config 覆盖入口,落占位名 `_raw_csv_ignored` 后被 `**kwargs` 吞。skip_when_missing 块同步加 `_notify` 通知,**拒绝静默跳过**(15 条 _notify 全带具体 missing 字段列表证据)。
|
||||
294
.qwen/skills/wq_gui_external_model_panel/SKILL.md
Normal file
294
.qwen/skills/wq_gui_external_model_panel/SKILL.md
Normal file
@ -0,0 +1,294 @@
|
||||
---
|
||||
name: WQ_GUI PyQt5 面板外部模型导入模式
|
||||
description: 在 Step8 等预测面板中通过 QRadioButton + FileSelectWidget + joblib.load 防御性加载实现"内置/导入"双模式切换的标准模式
|
||||
source: auto-skill
|
||||
extracted_at: '2026-06-08T01:38:14.481Z'
|
||||
---
|
||||
|
||||
# WQ_GUI PyQt5 面板外部模型导入模式
|
||||
|
||||
## 适用场景
|
||||
|
||||
Step8(机器学习预测)、Step8_5、Step8_75 等面板需要同时支持:
|
||||
1. **内置模式**:使用 `step6` 训练流程生成的模型目录
|
||||
2. **导入模式**:用户手动选择本地预训练 `.joblib` 文件直接加载
|
||||
|
||||
---
|
||||
|
||||
## 1. 模板(可直接复制到 `__init__` + `init_ui`)
|
||||
|
||||
```python
|
||||
from PyQt5.QtWidgets import QRadioButton
|
||||
|
||||
class StepXPanel(QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.current_model = None # ★ 外部模型实例缓存
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# -------- 模型来源选择(单选按钮组) --------
|
||||
source_group = QGroupBox("模型来源")
|
||||
source_layout = QVBoxLayout()
|
||||
|
||||
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
|
||||
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
|
||||
self.use_trained_model.setChecked(True)
|
||||
source_layout.addWidget(self.use_trained_model)
|
||||
source_layout.addWidget(self.use_external_model)
|
||||
|
||||
self.use_trained_model.toggled.connect(self._on_model_source_changed)
|
||||
self.use_external_model.toggled.connect(self._on_model_source_changed)
|
||||
|
||||
source_group.setLayout(source_layout)
|
||||
layout.addWidget(source_group)
|
||||
|
||||
# -------- 外部模型文件选择(条件显示) --------
|
||||
self.external_model_widget = FileSelectWidget(
|
||||
"预训练模型:",
|
||||
"Joblib Files (*.joblib);;All Files (*.*)"
|
||||
)
|
||||
# FileSelectWidget 的 browse_btn 默认连着 open file 行为,
|
||||
# 需要先断开默认连接,再接自定义槽
|
||||
self.external_model_widget.browse_btn.clicked.disconnect()
|
||||
self.external_model_widget.browse_btn.clicked.connect(self._browse_external_model)
|
||||
self.external_model_widget.setVisible(False)
|
||||
layout.addWidget(self.external_model_widget)
|
||||
|
||||
# ... 其余原有 UI ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 槽函数模板
|
||||
|
||||
### `_on_model_source_changed`
|
||||
|
||||
单选按钮 `toggled` 信号在**两个**按钮上都会触发(点击 A 时 A 触发,B 也触发),所以用 `if not checked: return` 让非选中分支短路。
|
||||
|
||||
```python
|
||||
def _on_model_source_changed(self, checked: bool):
|
||||
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
|
||||
if not checked:
|
||||
return
|
||||
is_external = self.use_external_model.isChecked()
|
||||
self.external_model_widget.setVisible(is_external)
|
||||
# 切回"使用当前模型"时清空缓存,释放内存并避免误用旧模型
|
||||
if not is_external:
|
||||
self.current_model = None
|
||||
```
|
||||
|
||||
### `_browse_external_model`
|
||||
|
||||
- 用 `QFileDialog.getOpenFileName` 而非 `getExistingDirectory`
|
||||
- 防御性解析两种格式:`{"model": pipeline, ...}`(Step6 输出格式)和裸 `Pipeline` 对象
|
||||
- 失败用 `QMessageBox.warning` 友善提示;成功用 `QMessageBox.information` 告知
|
||||
|
||||
```python
|
||||
from PyQt5.QtWidgets import QFileDialog, QMessageBox
|
||||
from pathlib import Path
|
||||
|
||||
def _browse_external_model(self):
|
||||
"""浏览并加载外部 .joblib 预训练模型文件"""
|
||||
default = self._get_default_work_dir()
|
||||
path, _ = QFileDialog.getOpenFileName(
|
||||
self,
|
||||
"选择预训练模型 (.joblib)",
|
||||
default,
|
||||
"Joblib Files (*.joblib);;All Files (*.*)",
|
||||
)
|
||||
if not path:
|
||||
return
|
||||
|
||||
try:
|
||||
import joblib
|
||||
loaded = joblib.load(path)
|
||||
# 兼容两种格式:dict{"model": obj} 或裸 Pipeline
|
||||
if isinstance(loaded, dict) and "model" in loaded:
|
||||
self.current_model = loaded["model"]
|
||||
elif hasattr(loaded, "predict"):
|
||||
self.current_model = loaded
|
||||
else:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型格式错误",
|
||||
f"无法识别的模型格式,文件内容类型为:{type(loaded).__name__}",
|
||||
)
|
||||
return
|
||||
self.external_model_widget.set_path(path)
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"模型加载成功",
|
||||
f"已加载模型:{Path(path).name}\n类型:{type(self.current_model).__name__}",
|
||||
)
|
||||
except Exception as e:
|
||||
self.current_model = None
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型加载失败",
|
||||
f"加载模型时发生错误:\n{type(e).__name__}: {e}",
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. `run_step` 改造模板
|
||||
|
||||
在原有目录加载逻辑之前,插入外部模型优先分支:
|
||||
|
||||
```python
|
||||
def run_step(self):
|
||||
"""独立运行步骤X"""
|
||||
# ... 公共输入校验 ...
|
||||
|
||||
# ★ 外部模型优先分支
|
||||
if self.use_external_model.isChecked():
|
||||
if self.current_model is None:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型未加载",
|
||||
"请先点击「浏览...」按钮加载预训练模型文件!",
|
||||
)
|
||||
return
|
||||
external_model_path = self.external_model_widget.get_path() or ""
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {
|
||||
'stepX': self.get_config(),
|
||||
'_external_model': self.current_model, # ★ 直接传对象
|
||||
'_external_model_path': external_model_path, # 供日志/回溯用
|
||||
}
|
||||
main_window.run_single_step('stepX', config)
|
||||
return
|
||||
|
||||
# 默认流程:使用模型目录(原有逻辑不变)
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
# ... 原有 run_step 剩余代码 ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 后端三层完整接入(2026-06-08 已落地)
|
||||
|
||||
完整数据流分为三层,每层各一处分流点:
|
||||
|
||||
```
|
||||
GUI step8_panel
|
||||
↓ config = {'_external_model': obj, '_external_model_path': path, 'step8': {...}}
|
||||
↓
|
||||
worker_thread.run_single_step() [第1处分流:透传顶层 key]
|
||||
↓ step_config = config['step8'] + {'_external_model': obj, '_external_model_path': path}
|
||||
↓
|
||||
prediction_step.predict_water_quality() [第2处分流:接收 + 透传]
|
||||
↓ _external_model=obj, _external_model_path=path
|
||||
↓
|
||||
WaterQualityInference(artifacts_dir, external_model=obj, external_model_path=path)
|
||||
↓
|
||||
inference_batch.batch_inference_multi_models() [第3处分流:effective_model 短路]
|
||||
↓ external_model=obj
|
||||
↓
|
||||
inference_batch.inference_pipeline()
|
||||
→ self.external_model is not None → self.loaded_model_data = self.external_model(跳过磁盘加载)
|
||||
```
|
||||
|
||||
### 4a. worker_thread.py — run_single_step 透传
|
||||
|
||||
在 `step_config = dict(config.get(step_name, {}))` 之后、"skip_dependency_check" 之前插入:
|
||||
|
||||
```python
|
||||
# 透传面板顶层传入的外部预训练模型(GUI step8_panel 通过 config['_external_model'] 传入)
|
||||
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
|
||||
for key in ('_external_model', '_external_model_path'):
|
||||
val = config.get(key)
|
||||
if val is not None and val != "":
|
||||
step_config[key] = val
|
||||
```
|
||||
|
||||
### 4b. prediction_step.py — predict_water_quality 签名 + 透传
|
||||
|
||||
形参表末尾增加两个参数:
|
||||
|
||||
```python
|
||||
_external_model=None,
|
||||
_external_model_path=None,
|
||||
```
|
||||
|
||||
构造处透传:
|
||||
|
||||
```python
|
||||
inferencer = WaterQualityInference(
|
||||
models_dir,
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
...
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
```
|
||||
|
||||
### 4c. inference_batch.py — 三处修改
|
||||
|
||||
**① `__init__` 存储**:
|
||||
|
||||
```python
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts",
|
||||
external_model=None, external_model_path=None):
|
||||
...
|
||||
self.external_model = external_model
|
||||
self.external_model_path = external_model_path
|
||||
```
|
||||
|
||||
**② `batch_inference_multi_models` 短路 + 注入**:
|
||||
|
||||
```python
|
||||
# 优先级:外部预训练模型 > 从磁盘加载
|
||||
if external_model is not None:
|
||||
effective_model = external_model
|
||||
print(f"\n使用外部预训练模型: type={type(external_model).__name__}")
|
||||
else:
|
||||
effective_model = None
|
||||
|
||||
# 子目录循环中注入:
|
||||
if effective_model is not None:
|
||||
model_inferencer = WaterQualityInference(
|
||||
str(subdir),
|
||||
external_model=effective_model,
|
||||
external_model_path=external_model_path,
|
||||
)
|
||||
else:
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
```
|
||||
|
||||
**③ `inference_pipeline` 模型加载短路**(`load_best_model` 调用前):
|
||||
|
||||
```python
|
||||
if self.external_model is not None:
|
||||
self.loaded_model_data = self.external_model
|
||||
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
|
||||
elif model_file_path:
|
||||
self.load_specific_model(model_file_path)
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
```
|
||||
|
||||
**关键约束**:
|
||||
- `joblib.load` 在 panel 槽函数里完成(GUI 进程内),对象通过 config 引用直接透传;**不跨进程**,所以不需要担心 pickle 序列化问题
|
||||
- `batch_inference_multi_models` 形参 `external_model` 和 `external_model_path` **与实例属性同名**(`self.external_model`),两者都传是为了让每个子目录创建的 `WaterQualityInference` 实例都能独立持有引用
|
||||
- 原有从 `models_dir` 目录加载的逻辑完全保留,只在 `external_model is not None` 时短路
|
||||
|
||||
---
|
||||
|
||||
## 5. 已知约束
|
||||
|
||||
- `FileSelectWidget.browse_btn.clicked` 在 `init_ui` 里会重复 connect,每次 `init_ui` 被调用时会累积;解决方案是在 connect 前先 `.disconnect()`(如模板所示)。
|
||||
- `QRadioButton.toggled` 信号在两个按钮上都会触发,**必须**用 `if not checked: return` 短路,否则会导致切换时状态错乱。
|
||||
- `self.current_model` 会在面板切换到"使用当前模型"时清空,防止用户忘记换回内置模式后仍使用旧导入模型。
|
||||
- 当前项目 venv 路径:`D:\111\office\ZHLduijie\1.WQ\WQ_GUI\venv`,导入 `joblib` 时注意 venv 环境一致性。
|
||||
@ -30,6 +30,7 @@ app/api/modeling.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -40,7 +41,7 @@ import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel, Field
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
@ -784,3 +785,63 @@ async def submit_predict(
|
||||
payload.output_zarr_path,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# models_router — 独立于 modeling_router,路径前缀为 /models
|
||||
# 最终完整路径: GET /api/models, POST /api/models/upload
|
||||
# ---------------------------------------------------------------------------
|
||||
models_router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@models_router.get("")
|
||||
async def list_models() -> Dict[str, Any]:
|
||||
"""
|
||||
扫描 ./data/models/ 目录,返回所有 .joblib 文件名(不含后缀)。
|
||||
|
||||
异常处理:目录不存在时自动创建,返回空列表。
|
||||
"""
|
||||
models_dir = Path("./data/models")
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_names = [
|
||||
p.stem for p in models_dir.iterdir() if p.suffix == ".joblib"
|
||||
]
|
||||
return {"models": model_names}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/models/upload
|
||||
# ---------------------------------------------------------------------------
|
||||
@models_router.post("/upload")
|
||||
async def upload_model(
|
||||
file: UploadFile = File(...),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
接收上传的 .joblib 模型文件,保存到 ./data/models/ 目录。
|
||||
|
||||
- 校验后缀必须为 .joblib
|
||||
- 目录不存在时自动创建
|
||||
- 返回状态和文件名(不含后缀)
|
||||
"""
|
||||
if not file.filename or not file.filename.lower().endswith(".joblib"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="仅支持 .joblib 格式的文件",
|
||||
)
|
||||
|
||||
models_dir = Path("./data/models")
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dest_path = models_dir / file.filename
|
||||
|
||||
with dest_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"model_id": dest_path.stem,
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.endpoints import router as deglint_router
|
||||
from app.api.modeling import router as modeling_router
|
||||
from app.api.modeling import models_router
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -52,6 +53,7 @@ app.add_middleware(
|
||||
# ---------------------------------------------------------------------------
|
||||
app.include_router(deglint_router, prefix="/api")
|
||||
app.include_router(modeling_router, prefix="/api")
|
||||
app.include_router(models_router, prefix="/api")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -26,19 +26,24 @@ from sklearn.model_selection import train_test_split
|
||||
class WaterQualityInference:
|
||||
"""水质参数反演推理类"""
|
||||
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts"):
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts",
|
||||
external_model=None, external_model_path=None):
|
||||
"""
|
||||
初始化推理类
|
||||
|
||||
Args:
|
||||
artifacts_dir: 模型保存目录
|
||||
external_model: 外部预训练模型对象(来自 GUI 导入,跳过磁盘加载)
|
||||
external_model_path: 外部模型文件路径(仅用于日志)
|
||||
"""
|
||||
self.artifacts_dir = Path(artifacts_dir)
|
||||
if not self.artifacts_dir.exists():
|
||||
print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")
|
||||
|
||||
|
||||
self.best_model_info = None
|
||||
self.loaded_model_data = None
|
||||
self.external_model = external_model
|
||||
self.external_model_path = external_model_path
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
@ -745,7 +750,11 @@ class WaterQualityInference:
|
||||
# 1. 加载模型
|
||||
print("\n步骤1: 加载模型")
|
||||
print("-" * 40)
|
||||
if model_file_path:
|
||||
if self.external_model is not None:
|
||||
# 外部预训练模型已注入,直接使用,跳过磁盘加载
|
||||
self.loaded_model_data = self.external_model
|
||||
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
|
||||
elif model_file_path:
|
||||
self.load_specific_model(model_file_path)
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
@ -863,10 +872,12 @@ class WaterQualityInference:
|
||||
print(f"\n批量推理完成,共处理 {len(csv_files)} 个文件")
|
||||
return results
|
||||
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
prediction_column: str = 'prediction',
|
||||
output_format: str = 'csv'):
|
||||
output_format: str = 'csv',
|
||||
external_model=None,
|
||||
external_model_path=None):
|
||||
"""
|
||||
使用多个子文件夹中的模型进行批量推理
|
||||
|
||||
@ -881,7 +892,18 @@ class WaterQualityInference:
|
||||
models_root = Path(models_root_dir)
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 优先级:外部预训练模型 > 从磁盘加载
|
||||
if external_model is not None:
|
||||
effective_model = external_model
|
||||
model_desc = (
|
||||
f"外部导入模型 ({external_model_path or 'unknown'}), "
|
||||
f"type={type(external_model).__name__}"
|
||||
)
|
||||
print(f"\n使用外部预训练模型: {model_desc}")
|
||||
else:
|
||||
effective_model = None
|
||||
|
||||
# 查找所有子文件夹
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
|
||||
@ -900,9 +922,16 @@ class WaterQualityInference:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"处理模型文件夹: {subdir_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 创建新的推理实例,使用当前子文件夹作为artifacts_dir
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
|
||||
# 创建推理实例:外部模型优先注入,跳过磁盘查找
|
||||
if effective_model is not None:
|
||||
model_inferencer = WaterQualityInference(
|
||||
str(subdir),
|
||||
external_model=effective_model,
|
||||
external_model_path=external_model_path,
|
||||
)
|
||||
else:
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
|
||||
# 根据输出格式设置文件扩展名
|
||||
file_ext = f".{output_format}"
|
||||
|
||||
@ -103,6 +103,10 @@ class PredictionStep:
|
||||
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
_external_model=None,
|
||||
_external_model_path=None,
|
||||
_external_models_dict=None,
|
||||
_external_model_dir=None,
|
||||
) -> Dict[str, str]:
|
||||
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
@ -149,19 +153,48 @@ class PredictionStep:
|
||||
else:
|
||||
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||
|
||||
inferencer = WaterQualityInference(models_dir)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
)
|
||||
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
if _external_models_dict:
|
||||
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数,
|
||||
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
|
||||
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
|
||||
for target_name, model_obj in _external_models_dict.items():
|
||||
try:
|
||||
output_file = ml_prediction_dir / f"{target_name}.csv"
|
||||
model_inferencer = WaterQualityInference(
|
||||
models_dir or "./",
|
||||
external_model=model_obj,
|
||||
external_model_path=_external_model_dir or "",
|
||||
)
|
||||
predictions, result_df = model_inferencer.inference_pipeline(
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_csv_path=str(output_file),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
)
|
||||
prediction_files[target_name] = str(output_file)
|
||||
print(f" ✓ {target_name}: {len(predictions)} 个预测值")
|
||||
except Exception as e:
|
||||
print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}")
|
||||
prediction_files[target_name] = None
|
||||
else:
|
||||
inferencer = WaterQualityInference(
|
||||
models_dir,
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
|
||||
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||
|
||||
|
||||
@ -326,6 +326,14 @@ class WorkerThread(QThread):
|
||||
method_name = step_method_map[step_name]
|
||||
step_config = dict(config.get(step_name, {}))
|
||||
|
||||
# 透传面板顶层传入的外部预训练模型(GUI step8_panel 通过 config['_external_model'] 传入)
|
||||
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
|
||||
for key in ('_external_model', '_external_model_path',
|
||||
'_external_models_dict', '_external_model_dir'):
|
||||
val = config.get(key)
|
||||
if val is not None and val != "":
|
||||
step_config[key] = val
|
||||
|
||||
step_config['skip_dependency_check'] = True
|
||||
|
||||
if step_name == 'step9':
|
||||
|
||||
@ -91,3 +91,13 @@ Traceback (most recent call last):
|
||||
sys.exit(app.exec_())
|
||||
^^^^^^^^^^^
|
||||
KeyboardInterrupt
|
||||
|
||||
============================================================
|
||||
[2026-06-04 09:54:07]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3237, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3214, in main
|
||||
sys.exit(app.exec_())
|
||||
^^^^^^^^^^^
|
||||
KeyboardInterrupt
|
||||
|
||||
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
QFileDialog, QRadioButton,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
@ -21,12 +21,61 @@ class Step8Panel(QWidget):
|
||||
"""步骤8:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.external_models_dict = {} # {subdir_name: model_obj, ...}
|
||||
self.external_model_dir = "" # 母文件夹路径(隐藏)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件(用于独立运行)
|
||||
# -------- 模型来源选择(单选按钮组) --------
|
||||
source_group = QGroupBox("模型来源")
|
||||
source_layout = QVBoxLayout()
|
||||
|
||||
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
|
||||
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
|
||||
self.use_trained_model.setChecked(True)
|
||||
source_layout.addWidget(self.use_trained_model)
|
||||
source_layout.addWidget(self.use_external_model)
|
||||
|
||||
self.use_trained_model.toggled.connect(self._on_model_source_changed)
|
||||
self.use_external_model.toggled.connect(self._on_model_source_changed)
|
||||
|
||||
source_group.setStyleSheet("""
|
||||
QRadioButton {
|
||||
font-size: 13px;
|
||||
spacing: 8px;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 9px;
|
||||
border: 2px solid #A0A0A0;
|
||||
background-color: #FFFFFF;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
""")
|
||||
|
||||
source_group.setLayout(source_layout)
|
||||
layout.addWidget(source_group)
|
||||
|
||||
# -------- 外部模型文件选择(条件显示) --------
|
||||
self.external_model_widget = FileSelectWidget(
|
||||
"模型母文件夹:",
|
||||
"Directories"
|
||||
)
|
||||
self.external_model_widget.browse_btn.clicked.disconnect()
|
||||
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
|
||||
self.external_model_widget.setVisible(False)
|
||||
layout.addWidget(self.external_model_widget)
|
||||
|
||||
# -------- 采样光谱CSV文件(用于独立运行)--------
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
@ -79,6 +128,94 @@ class Step8Panel(QWidget):
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_model_source_changed(self, checked: bool):
|
||||
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
|
||||
if not checked:
|
||||
return
|
||||
is_external = self.use_external_model.isChecked()
|
||||
self.external_model_widget.setVisible(is_external)
|
||||
if not is_external:
|
||||
self.external_models_dict = {}
|
||||
self.external_model_dir = ""
|
||||
|
||||
def _scan_external_model_dir(self):
|
||||
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self,
|
||||
"选择模型母文件夹",
|
||||
default,
|
||||
)
|
||||
if not dir_path:
|
||||
return
|
||||
|
||||
self.external_model_dir = dir_path
|
||||
models_found = {}
|
||||
errors = []
|
||||
|
||||
try:
|
||||
import joblib
|
||||
|
||||
for subentry in os.scandir(dir_path):
|
||||
if not subentry.is_dir():
|
||||
continue
|
||||
subdir_name = subentry.name
|
||||
joblib_files = [
|
||||
f for f in os.scandir(subentry.path)
|
||||
if f.is_file() and f.name.lower().endswith(".joblib")
|
||||
]
|
||||
if not joblib_files:
|
||||
continue
|
||||
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
|
||||
joblib_path = joblib_files[0].path
|
||||
try:
|
||||
loaded = joblib.load(joblib_path)
|
||||
if isinstance(loaded, dict) and "model" in loaded:
|
||||
model_obj = loaded["model"]
|
||||
elif hasattr(loaded, "predict"):
|
||||
model_obj = loaded
|
||||
else:
|
||||
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
|
||||
continue
|
||||
models_found[subdir_name] = model_obj
|
||||
except Exception as e:
|
||||
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"扫描失败",
|
||||
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
|
||||
)
|
||||
return
|
||||
|
||||
if not models_found:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未找到模型",
|
||||
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
|
||||
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
|
||||
)
|
||||
self.external_model_widget.set_path("")
|
||||
self.external_models_dict = {}
|
||||
return
|
||||
|
||||
self.external_models_dict = models_found
|
||||
names = sorted(models_found.keys())
|
||||
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
|
||||
self.external_model_widget.set_path(display)
|
||||
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
|
||||
|
||||
err_lines = "\n".join(errors) if errors else "无"
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"模型扫描完成",
|
||||
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
|
||||
f"加载失败 {len(errors)} 个:\n{err_lines}",
|
||||
)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
@ -197,10 +334,31 @@ class Step8Panel(QWidget):
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
# 外部模型优先:用户选择了"导入本地预训练模型"
|
||||
if self.use_external_model.isChecked():
|
||||
if not self.external_models_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型未加载",
|
||||
"请先点击「浏览...」按钮选择模型母文件夹!",
|
||||
)
|
||||
return
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {
|
||||
'step8': self.get_config(),
|
||||
'_external_models_dict': self.external_models_dict,
|
||||
'_external_model_dir': self.external_model_dir,
|
||||
}
|
||||
main_window.run_single_step('step8', config)
|
||||
return
|
||||
|
||||
# 默认流程:使用模型目录
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user