refactor(pipeline): pass paths directly; unify ctx field names / panel keys / step parameter names
24
.qwen/settings.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(\"c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe\" *)",
|
||||
"Bash(get-childitem *)",
|
||||
"Bash(select-object *)",
|
||||
"Bash(python *)",
|
||||
"Bash(where *)",
|
||||
"Bash(conda *)",
|
||||
"Bash(dir *)",
|
||||
"Bash(cmd *)",
|
||||
"Bash(del *)",
|
||||
"Bash(powershell *)",
|
||||
"Bash(git *)",
|
||||
"Bash(type *)",
|
||||
"Bash(.\\venv\\scripts\\python.exe *)",
|
||||
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
|
||||
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
|
||||
"Bash(venv\\scripts\\python.exe *)",
|
||||
"Bash(findstr *)"
|
||||
]
|
||||
},
|
||||
"$version": 4
|
||||
}
|
||||
23
.qwen/settings.json.orig
Normal file
@ -0,0 +1,23 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(\"c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe\" *)",
|
||||
"Bash(get-childitem *)",
|
||||
"Bash(select-object *)",
|
||||
"Bash(python *)",
|
||||
"Bash(where *)",
|
||||
"Bash(conda *)",
|
||||
"Bash(dir *)",
|
||||
"Bash(cmd *)",
|
||||
"Bash(del *)",
|
||||
"Bash(powershell *)",
|
||||
"Bash(git *)",
|
||||
"Bash(type *)",
|
||||
"Bash(.\\venv\\scripts\\python.exe *)",
|
||||
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
|
||||
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
|
||||
"Bash(venv\\scripts\\python.exe *)"
|
||||
]
|
||||
},
|
||||
"$version": 4
|
||||
}
|
||||
141
.qwen/skills/code_replacement_state_audit/SKILL.md
Normal file
@ -0,0 +1,141 @@
|
||||
---
|
||||
name: 代码替换请求的现状审计
|
||||
description: 处理用户"代码替换/新增"指令时,先审计磁盘真实状态再用 ask_user_question 确认——避免覆盖已落盘的高版本代码
|
||||
source: auto-skill
|
||||
extracted_at: '2026-06-03T05:36:58.746Z'
|
||||
---
|
||||
|
||||
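# Auditing the current state before a code-replacement request

## When to use

The user issues a "replace this code" or "add code based on version X" instruction but **gives no comparison against what is currently on disk**. Typical triggers:

- The user pastes a block of code and says "please write/replace this"
- The user points to a document, an older version, or an earlier chat and says "follow this"
- Earlier `state_snapshot` / `memory` / `git log` descriptions may no longer match what is on disk

## Core principle

**Never assume that the code the user pasted is the latest version.** The code on disk may already be a more complete version (iterated on by the user or another agent). Overwriting it means losing functionality.

The cost of a blind overwrite is not always an explicit bug; it can also be the loss of design decisions the user already approved (for example duck-type detection, the ctx abstraction, the signal protocol, the second-confirmation dialog, or error localization).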
|
||||
|
||||
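## The 5-step standard procedure

### 1. Confirm whether the file exists

Use `glob` or `list_directory` to check whether the target file already exists:

- Does not exist → create it
- Exists → go on to the step 2 audit

### 2. Grep key symbols and read the key sections

- Pick 3-5 key symbols from the code the user pasted (function names, class names, key constants, imports)
- Grep the on-disk file for the same symbols
- `read_file` the key sections (take the line numbers straight from the grep results)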
|
||||
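The symbol comparison in step 2 can be partly automated. The sketch below is a minimal, hypothetical helper (`top_level_symbols` and `audit_symbols` are illustrative names, not part of the project). It assumes both inputs are valid Python source and uses the standard `ast` module to list top-level functions, classes, and imported names on each side.

```python
import ast


def top_level_symbols(source: str) -> set[str]:
    """Collect the names of top-level functions, classes, and imports."""
    names: set[str] = set()
    for node in ast.parse(source).body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                names.add(alias.asname or alias.name)
    return names


def audit_symbols(pasted: str, on_disk: str) -> dict[str, set[str]]:
    """Return the symbols unique to each side.

    'only_on_disk' is a first approximation of what a blind overwrite would drop.
    """
    pasted_names = top_level_symbols(pasted)
    disk_names = top_level_symbols(on_disk)
    return {
        "only_in_pasted": pasted_names - disk_names,
        "only_on_disk": disk_names - pasted_names,
    }
```

A name-level diff like this only narrows down where to look; the key sections still need to be read, because two functions with the same name can differ in the defensive logic they contain.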
|
||||
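### 3. Build a diff comparison table

List:

```
| Target file | Version the user pasted | Version on disk | Lost by overwriting directly |
```

**The key column** is "what a direct overwrite would lose", so the user can judge the cost. Be specific down to the level of feature module / design decision / defensive layer / entry protocol; do not write something empty like "code differences".

### 4. Let the user decide via ask_user_question

Offer 3 standard options (wording may vary, but you **must present the current state plus a three-way choice**):

- **A. Keep the current state** (recommended when the disk already holds the newer version): go straight to the Smoke Test
- **B. Force-overwrite with the older version**: state what will be lost and suggest a backup (git stash / copy to `_old.py`)
- **C. Hybrid: take only a specific increment**: see step 5

**Do not list the specific increments in the first ask.** Let the user choose among A/B/C first; if they choose C, move on to step 5.

### 5. If the user chooses C, identify the "real increment"

Compare 1.0 against 2.0 and identify what is truly unique to 1.0 (absent from 2.0):

- ❌ Exclude parts where 1.0 is simply simpler than 2.0 (2.0 is a superset / has factory layering / adds a CLI)
- ❌ Exclude parts of 1.0 that 2.0's factory layering supersedes as a whole (_make_objective vs _build_model + _get_search_space)
- ✅ Focus on functional layers unique to 1.0 (even if 2.0 does not "obviously" need them)

For each candidate increment, ask once more "which one to adopt" and let the user pick specifically (multiSelect=false; picking one increment at a time is the safest).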
|
||||
|
||||
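## Landing principles

When carrying out "adopt a 1.0 increment into 2.0":

- **Minimal, surgical edits**: touch only the files that need touching and change only the sections that need changing
- **Preserve 2.0's design decisions** (duck-type detection / ctx abstraction / signal protocol / second-confirmation dialog / error localization)
- **Insert new top-level imports at a single point with `replace_all=False`**, so the order of the other imports is not disturbed
- **Replace a renamed variable along the whole chain** (e.g. `self.config` → `clean_config`): ctx construction, the v2 call, and the v1 fallback must all use it, to avoid two diverging sources
- **Single-step mode does not necessarily need sanitizing** (it does not go through the panel's full config, so the sanitizer is irrelevant)
- **Preventive code such as the sanitizer must log** (`self.log_message.emit(f"[清洗器] 已删除 N 个未知 key")`) so that it is visible at runtime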
|
||||
|
||||
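## The three-part verification

Always run after landing:

1. **AST syntax check**: `ast.parse(open(p, encoding='utf-8-sig').read())` on the 5 core files
   - `utf-8-sig` is mandatory: line 1 of WQ_GUI's water_quality_gui.py starts with a BOM, so plain `utf-8` fails
2. **Key-symbol grep**: confirm that the key symbols of the new code (imports, key function calls) are all hit and that the hit counts match expectations
3. **Top-level import test**: with a mocked PyQt5 plus `sys.path.insert(0, 'src/gui/core')`, verify that the module loads as a whole
   - For the PyQt5 mock template, see the reference-code section below
   - Calling Python on Windows: use the full path to the conda env's `python.exe`; do not rely on PATH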
|
||||
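The AST check in item 1 can be wrapped in a small helper. This is a sketch under the assumption that the core files are given as a plain list of paths; `check_syntax` is a hypothetical name, not an existing project function.

```python
import ast
from pathlib import Path


def check_syntax(paths: list[str]) -> dict[str, str]:
    """Parse each file and return {path: error message} for those that fail."""
    failures: dict[str, str] = {}
    for p in paths:
        try:
            # utf-8-sig strips a leading BOM if present and otherwise
            # behaves like plain utf-8
            source = Path(p).read_text(encoding="utf-8-sig")
            ast.parse(source, filename=p)
        except (SyntaxError, UnicodeDecodeError, OSError) as exc:
            failures[p] = f"{type(exc).__name__}: {exc}"
    return failures
```

An empty result means every file parsed; any entry names the file and the reason, which is usually enough to locate the broken edit.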
|
||||
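## Anti-patterns (do not do these)

- ❌ "Write the user's pasted code in verbatim": the overwrite trap of the simplified 1.0 version
- ❌ "Trust the state_snapshot description": a state snapshot may be inaccurate (it records intent; the disk is the fact)
- ❌ "Infer the current state from git log": git log does not reflect uncommitted working-tree changes
- ❌ "Infer the current state from memory": memory may be 22 days old (confirmed stale)
- ❌ "Skip the audit because the disk and the pasted code look the same": a one-line difference may be exactly the lost "bulletproof layer"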
|
||||
|
||||
## Reference code

### PyQt5 mock template (top-level import test for worker_thread.py)
|
||||
|
||||
```python
|
||||
import os
import sys
import types

os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
os.environ['SHAPE_ENCODING'] = 'UTF-8'
sys.path.insert(0, 'src/gui/core')

# Minimal stand-ins for PyQt5.QtCore (QThread, pyqtSignal, Qt)
pyqt5 = types.ModuleType("PyQt5")
qtc = types.ModuleType("PyQt5.QtCore")


class _QThread:
    def __init__(self, *a, **kw):
        pass


class _Signal:
    def __init__(self, *a, **kw):
        pass


qtc.QThread = _QThread
qtc.pyqtSignal = _Signal
qtc.Qt = type("Qt", (), {"QueuedConnection": 1, "UserRole": 0})()
pyqt5.QtCore = qtc  # lets "from PyQt5 import QtCore" resolve as well
sys.modules["PyQt5"] = pyqt5
sys.modules["PyQt5.QtCore"] = qtc

import worker_thread
# Side effect: check_pipeline_dependencies() prints dependency-check logs (safe to ignore)
|
||||
```
|
||||
|
||||
### Running the conda env's Python on Windows
|
||||
|
||||
```bat
|
||||
cmd /c "D:\xxx\anconda\envs\XXX\python.exe D:\path\to\script.py"
|
||||
```
|
||||
|
||||
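A PowerShell one-liner `python -c "..."` breaks easily with Chinese paths or nested double/single quotes; **writing a temporary .py file and invoking it via `cmd /c`** is the most reliable approach.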
|
||||
|
||||
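## Case source (2026-06-03, WQ_GUI Route B MVP)

- The user pasted the simplified 1.0 version: a 300-line automl_trainer / simplified worker_thread.run() / simplified on_run_all_clicked
- The 2.0 version on disk: a 545-line automl_trainer (_build_model + _get_search_space factories / argparse CLI) / duck-type detection of v2 + the PipelineContext abstraction / full second-confirmation dialog / _focus_step localization of the failed step / [DEPRECATED] stop retained
- The only real increment in 1.0 = the **"bulletproof parameter sanitizer"** (14-entry method_map + unknown-key filtering via inspect.signature + has_kwargs exemption + a log of the unknown-key count)
- Landing: inserted the 53-line sanitizer after set_callback inside worker_thread.py:run(), and replaced self.config with clean_config in 6 places
- Verification: all 5 files pass the AST check + 7 key symbols hit + import succeeds under the PyQt5 mock
- Net line change: 407 → 457 (+50 lines)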
|
||||
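To illustrate the sanitizer idea from the case notes (filter unknown keys via `inspect.signature`, exempt callables that take `**kwargs`, report what was dropped), here is a minimal sketch. It is not the project's 53-line implementation: `clean_kwargs` is a hypothetical name, and the 14-entry `method_map` is left out.

```python
import inspect
from typing import Any, Callable


def clean_kwargs(
    func: Callable[..., Any], config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
    """Keep only the keys func can accept; return (clean_config, dropped_keys)."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter means func accepts any keyword, so nothing is filtered.
    has_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if has_kwargs:
        return dict(config), []
    accepted = {
        name
        for name, p in params.items()
        if p.kind
        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    clean = {k: v for k, v in config.items() if k in accepted}
    dropped = sorted(k for k in config if k not in accepted)
    return clean, dropped
```

The caller can then log `len(dropped)` so that silently discarded keys stay visible at runtime, in line with the landing principles.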
206
.qwen/skills/wq_gui_data_flow/SKILL.md
Normal file
@ -0,0 +1,206 @@
|
||||
---
|
||||
name: WQ_GUI 数据流转架构
|
||||
description: WQ_GUI ProjectSession 事件总线驱动的步骤间数据传递机制(完整重构版)
|
||||
source: auto-skill
|
||||
extracted_at: '2026-05-28T09:07:34.967Z'
|
||||
---
|
||||
|
||||
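# WQ_GUI data-flow architecture

## Key takeaway

The whole system is a **file-path-driven** pipeline; all data lives on local disk. After the refactor, the `ProjectSession` event bus fully decouples the Panels from one another.

---

## 1. Old architecture (already removed from the code)

The main window managed path filling between steps through the `self.step_outputs` dict, the `step_dependencies` configuration, and the `auto_populate_*` family of methods, which left the code tightly coupled: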
|
||||
|
||||
```python
|
||||
# Deprecated and removed
|
||||
self.step_outputs = {}
|
||||
self._init_step_dependencies()
|
||||
self.update_step_outputs(step_name, work_path)
|
||||
self.auto_populate_dependent_steps(completed_step)
|
||||
self.auto_populate_step_inputs(step_id)
|
||||
self.find_step_output(work_path, step_id, output_type)
|
||||
self.add_auto_fill_buttons_to_panels()
|
||||
self.scan_work_directory_for_files(work_path)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. New architecture: the ProjectSession event bus

### Session core API (`src/core/project_session.py`)
|
||||
|
||||
```python
|
||||
from PyQt5.QtCore import QObject, pyqtSignal


# API outline: signatures and docstrings only, method bodies omitted
class ProjectSession(QObject):
    path_updated = pyqtSignal(str, str, str)   # step, out_type, path
    step_outputs_ready = pyqtSignal(str, str)  # step, out_type

    def update_output(self, step, out_type, path):
        """Broadcast one output path after a Panel finishes."""

    def update_outputs(self, step, outputs):
        """Broadcast several output paths at once; outputs is {out_type: path, ...}."""

    def get_output(self, step, out_type):
        """Let a Panel query an upstream path on demand (used for auto-fill)."""

    def get_step_outputs(self, step):
        """Return the full output dict of the given step."""

    def scan_work_directory(self):
        """Called at the end of the main window's on_step_completed; scans and broadcasts all known paths."""
|
||||
```
|
||||
|
||||
### Panel refactor template
|
||||
|
||||
```python
|
||||
from PyQt5.QtCore import Qt, pyqtSlot
from PyQt5.QtWidgets import QWidget


class StepXPanel(QWidget):
    def __init__(self, session=None, parent=None):
        super().__init__(parent)
        self.session = session
        self.work_dir = None
        self.init_ui()
        self._bind_session_signals()

    def _bind_session_signals(self):
        if not self.session:
            return
        self.session.path_updated.connect(
            self._on_session_path_updated, Qt.QueuedConnection
        )

    @pyqtSlot(str, str, str)
    def _on_session_path_updated(self, step_name, output_type, path):
        print(f"[StepX Debug] received broadcast: step={step_name}, type={output_type}, path={path}")
        if step_name == 'step1':
            if output_type == 'reference_img':
                # Auto-fill only when the field is still empty
                if not self.img_file.get_path().strip():
                    self.img_file.set_path(path)
                    print(f"[StepX] auto-filled reference image: {path}")
            elif output_type == 'water_mask':
                if not self.water_mask_file.get_path().strip():
                    self.water_mask_file.set_path(path)
                    print(f"[StepX] auto-filled water mask: {path}")
        # ...

    def on_step_finished(self, success, message):
        """Called dynamically via getattr from the main window's on_step_completed."""
        if not success:
            return
        if self.session:
            outputs = {}
            path = self.output_widget.get_path().strip()
            if path:
                outputs['output_type'] = path
            if outputs:
                self.session.update_outputs('stepX', outputs)
|
||||
```
|
||||
|
||||
### Two changes in the main window
|
||||
|
||||
```python
|
||||
# 1. Inject the session in __init__ (every Panel receives the same session)
self.step1_panel = Step1Panel(session=self.session)
self.step2_panel = Step2Panel(session=self.session)
self.step3_panel = Step3Panel(session=self.session)
self.step4_panel = Step4Panel(session=self.session)
self.step5_panel = Step5Panel(session=self.session)
self.step5_5_panel = Step5_5Panel(session=self.session)
self.step6_panel = Step6Panel(session=self.session)
self.step6_5_panel = Step6_5Panel(session=self.session)
self.step6_75_panel = Step6_75Panel(session=self.session)
self.step7_panel = Step7Panel(session=self.session)
self.step8_panel = Step8Panel(session=self.session)
self.step8_5_panel = Step8_5Panel(session=self.session)
self.step8_75_panel = Step8_75Panel(session=self.session)
self.step9_panel = Step9Panel(session=self.session)

# 2. on_step_completed (generic dynamic lookup; no dict to maintain)
def on_step_completed(self, step_name, success, message):
    if not success:
        return
    if hasattr(self, 'session') and self.session:
        self.session.scan_work_directory()

    panel = getattr(self, f"{step_name}_panel", None)
    if panel and hasattr(panel, 'on_step_finished'):
        panel.on_step_finished(success, message)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. End-to-end event flow

### step1 → step2 / step3 paths (via the rasterized Shapefile product)

| Scenario | Broadcast water_mask path |
|------|----------------------|
| NDWI mode | the user-specified `output_file` path |
| Shapefile mode | `{work_dir}/1_water_mask/water_mask_from_shp.dat` (preferred)<br>falls back to `mask_file.get_path()` if the file does not exist |
|
||||
|
||||
```
|
||||
step1 finishes
  → step1_panel.on_step_finished()
    → session.update_outputs('step1', {
          'reference_img': img_path,
          'water_mask': mask_path   # may be .dat or .shp (see the table above)
      })
      → step2_panel._on_session_path_updated()
      → step3_panel._on_session_path_updated()
|
||||
```
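The path-selection rule for step1's `water_mask` can be sketched as a small pure function. This is an illustration only: `resolve_water_mask` and the mode strings `"ndwi"` / `"shapefile"` are assumptions, and the real panel reads these values from its widgets.

```python
from pathlib import Path


def resolve_water_mask(mode: str, work_dir: str, output_file: str, shp_input: str) -> str:
    """Pick the water_mask path that step1 should broadcast."""
    if mode == "ndwi":
        # NDWI mode: broadcast the user-specified output file
        return output_file
    # Shapefile mode: prefer the rasterized product under the work directory
    rasterized = Path(work_dir) / "1_water_mask" / "water_mask_from_shp.dat"
    if rasterized.is_file():
        return str(rasterized)
    # Fallback described for Shapefile mode: the path from the mask input widget
    return shp_input
```

Keeping the rule in one function makes the fallback explicit, so downstream panels never need to know which mode produced the path.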
|
||||
|
||||
### step3 → step5 / step7; step5 → downstream training
|
||||
|
||||
```
|
||||
step3.deglint_image ──┬─→ step5.deglint_image (fills img_file)
                      └─→ step7.deglint_image (fills img_file)

step5.training_spectra ──┬─→ step5_5.index_features
                         ├─→ step6.models_dir ──→ step8.predictions
                         ├─→ step6_5.models_dir ──→ step8_5.predictions
                         └─→ step6_75.models_dir ──→ step8_75.predictions

step7.sampling_points ──┬─→ step8
                        ├─→ step8_5
                        └─→ step8_75

step8/8_5/8_75.predictions ──→ step9.distribution_map
|
||||
```
|
||||
|
||||
### Listen/publish matrix per Panel (complete)

| Panel | Listens to | Publishes |
|-------|------|------|
| step1 | (none) | `reference_img`, `water_mask` |
| step2 | `step1.reference_img`, `step1.water_mask` | `glint_mask` |
| step3 | `step1.reference_img`, `step1.water_mask`, `step2.glint_mask` | `deglint_image` |
| step4 | (none) | `processed_data` |
| step5 | `step3.deglint_image`, `step4.processed_data`, `step2.glint_mask` | `training_spectra` |
| step5_5 | `step5.training_spectra` | `index_features` |
| step6 | `step5.training_spectra` | `models_dir` |
| step6_5 | `step5.training_spectra` | `models_dir` |
| step6_75 | `step5.training_spectra` | `models_dir` |
| step7 | `step3.deglint_image`, `step1.water_mask`, `step2.glint_mask` | `sampling_points` |
| step8 | `step7.sampling_points`, `step6.models_dir` | `predictions` |
| step8_5 | `step7.sampling_points`, `step6_5.models_dir` | `predictions` |
| step8_75 | `step7.sampling_points`, `step6_75.models_dir` | `predictions` |
| step9 | `step8.predictions`, `step8_5.predictions`, `step8_75.predictions` | `distribution_map` |
|
||||
|
||||
---
|
||||
|
||||
## 4. Key constraints

- `__init__` takes `session=None` (backward compatible; the main window may still omit it)
- Every Panel keeps `init_ui / get_config / set_config / update_from_config` intact
- Remove all cross-boundary `self.window().stepX_panel` access
- Use `self.session.get_output()` instead of reading another panel's widget directly
- Connect listeners with `Qt.QueuedConnection` to stay safe across threads
- Auto-fill only when the field is empty (`not widget.get_path().strip()`)
- In `update_from_config`, read paths from the Session first, then broadcast through the Session
- In the main window's `on_step_completed`, use `getattr(self, f"{step_name}_panel", None)` for generic dynamic lookup, with no hard-coded dict to maintain
- In `step1` Shapefile mode, do **not** broadcast the `.shp` input file directly; build `{work_dir}/1_water_mask/water_mask_from_shp.dat` as the product path
|
||||
229
.qwen/skills/wq_gui_frontend_scaffold/SKILL.md
Normal file
@ -0,0 +1,229 @@
|
||||
---
|
||||
name: WQ_GUI 前端 Vue3 + Element Plus 脚手架
|
||||
description: WQ_GUI 项目 frontend/ 目录的 Vite + Vue 3 + TS + Element Plus 最小可运行脚手架,以及 useTaskPoller 与 Element Plus UI 的接线模式
|
||||
source: auto-skill
|
||||
extracted_at: '2026-06-02T08:17:33.116Z'
|
||||
---
|
||||
|
||||
# WQ_GUI frontend scaffold (Vue 3 + Element Plus)

## When to use

Build a **minimal, integration-ready** browser console for the WQ_GUI FastAPI backend (`127.0.0.1:8000`).
The backend already exposes:

- `POST /api/modeling/train` → `{ task_id, status, kind }`
- `POST /api/modeling/predict` → `{ task_id, status, kind }`
- `GET /api/tasks/{task_id}` → `TaskRecord` (PENDING/PROCESSING/SUCCESS/FAILED plus model metrics / output paths)
- `GET /api/algorithms` → list of algorithms

The frontend already has (`frontend/src/`):

- `api/request.ts`: an axios singleton with a response interceptor that unwraps automatically; baseURL comes from `VITE_API_BASE_URL` and defaults to `http://127.0.0.1:8000`
- `api/tasks.ts`: all submit/query functions plus the full `TaskRecord` / `TaskStatus` / `TaskKind` types
- `composables/useTaskPoller.ts`: a complete polling composable supporting 3 usages (static / reactive taskId / manual)

## 1. Scaffold files to add in one go

Initially `frontend/` **contains only `src/api` and `src/composables`** and lacks the whole Vite skeleton. Lay down the 10 files listed below (8 at the root, 2 under `src/`):
|
||||
|
||||
```
|
||||
frontend/
|
||||
├── .env.development # VITE_API_BASE_URL=http://127.0.0.1:8000
|
||||
├── .gitignore # node_modules / dist / .vite
|
||||
├── env.d.ts # vite/client + ImportMeta + *.vue shim
|
||||
├── index.html # 挂载 #app
|
||||
├── package.json
|
||||
├── tsconfig.json # 严格模式 + @ → src + bundler resolution
|
||||
├── tsconfig.node.json # 给 vite.config.ts 用
|
||||
├── vite.config.ts # @ alias + 0.0.0.0:5173
|
||||
└── src/
|
||||
├── main.ts
|
||||
└── App.vue
|
||||
```
|
||||
|
||||
### 锁定版本(2026-06 联调通过)
|
||||
|
||||
```json
|
||||
{
|
||||
"dependencies": {
|
||||
"vue": "^3.4.27",
|
||||
"element-plus": "^2.7.5",
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"axios": "^1.7.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.12.12",
|
||||
"@vitejs/plugin-vue": "^5.0.4",
|
||||
"typescript": "^5.4.5",
|
||||
"vite": "^5.2.11",
|
||||
"vue-tsc": "^2.0.19"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**`@types/node` is required**: `vite.config.ts` uses `import { fileURLToPath, URL } from 'node:url'`, and without it the type check in `npm run build` fails.
|
||||
|
||||
### Key `tsconfig.json` fields

- `"moduleResolution": "bundler"`
- `"allowImportingTsExtensions": true` (used together with `vue-tsc --noEmit`)
- `"paths": { "@/*": ["src/*"] }` + `"baseUrl": "."`
- `"include": ["src/**/*.vue"]` (so that `vue-tsc` processes SFCs)
- `"references": [{ "path": "./tsconfig.node.json" }]`
|
||||
|
||||
### Key `vite.config.ts` fields
|
||||
|
||||
```ts
|
||||
resolve: {
|
||||
alias: { '@': fileURLToPath(new URL('./src', import.meta.url)) },
|
||||
},
|
||||
server: { host: '0.0.0.0', port: 5173 },
|
||||
```
|
||||
|
||||
`0.0.0.0` makes it easy to debug from real devices on the LAN; on a port conflict, `strictPort: false` lets Vite move on to the next free port automatically.
|
||||
|
||||
---
|
||||
|
||||
## 2. main.ts template (full Element Plus registration)
|
||||
|
||||
```ts
|
||||
import { createApp } from 'vue'
|
||||
import ElementPlus from 'element-plus'
|
||||
import 'element-plus/dist/index.css'
|
||||
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
|
||||
|
||||
import App from './App.vue'
|
||||
|
||||
const app = createApp(App)
|
||||
app.use(ElementPlus)
|
||||
|
||||
// 全量注册图标 (<el-icon><Cpu /></el-icon>)
|
||||
for (const [name, component] of Object.entries(ElementPlusIconsVue)) {
|
||||
app.component(name, component)
|
||||
}
|
||||
app.mount('#app')
|
||||
```
|
||||
|
||||
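During integration, **full registration is the least effort**; if the bundle grows too large later, switch to on-demand imports with `unplugin-vue-components`.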
|
||||
|
||||
---
|
||||
|
||||
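## 3. Wiring pattern for useTaskPoller (two instances)

Training and prediction are **two independent pipelines**, each with its own `useTaskPoller` instance. The core trick: wrap `task_id` in `ref<string | null>(null)`; the `watch` inside the composable then **calls start() automatically**, so no manual call is needed: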
|
||||
|
||||
```ts
|
||||
import { ref, watch, computed } from 'vue'
|
||||
import { submitTrain, submitPredict, type TaskRecord } from './api/tasks'
|
||||
import { useTaskPoller } from './composables/useTaskPoller'
|
||||
|
||||
// —— 训练 ——
|
||||
const trainTaskId = ref<string | null>(null)
|
||||
const trainPoller = useTaskPoller(trainTaskId) // 传 ref 进去, 自动 watch
|
||||
|
||||
async function onStartTrain() {
|
||||
const { task_id } = await submitTrain({ ... })
|
||||
trainTaskId.value = task_id // 赋值后 watch 触发 start()
|
||||
}
|
||||
|
||||
// —— 推断 ——
|
||||
const predictTaskId = ref<string | null>(null)
|
||||
const predictPoller = useTaskPoller(predictTaskId)
|
||||
const modelId = ref('')
|
||||
|
||||
// 训练一成功, model_id 自动填入推断输入框
|
||||
watch(
|
||||
() => trainPoller.result.value?.model_id,
|
||||
(newId) => { if (newId) modelId.value = newId },
|
||||
)
|
||||
|
||||
async function onStartPredict() {
|
||||
const { task_id } = await submitPredict({ model_id: modelId.value, ... })
|
||||
predictTaskId.value = task_id
|
||||
}
|
||||
```
|
||||
|
||||
**Key points**:

- `trainPoller.result.value` is the full `TaskRecord` after SUCCESS; `record.value` is the latest record at any moment (intermediate states included). To show both in the template, use `trainPoller.record.value ?? trainPoller.result.value`.
- `poller.isPolling.value` / `poller.status.value` / `poller.error.value` / `poller.taskId.value` are all `Ref`s and must be accessed with `.value` in the template (they are nested refs, and **Vue templates do not unwrap them automatically**).
|
||||
|
||||
---
|
||||
|
||||
## 4. el-progress 状态映射
|
||||
|
||||
`PollerStatus = 'idle' | 'PENDING' | 'PROCESSING' | 'SUCCESS' | 'FAILED'`
|
||||
`el-progress` 的 `status` 接受 `'' | 'success' | 'warning' | 'exception'`。
|
||||
|
||||
```ts
|
||||
function progressOf(status: string): number {
  switch (status) {
    case 'idle':
    case 'PENDING':    return 10
    case 'PROCESSING': return 60
    case 'SUCCESS':
    case 'FAILED':     return 100
    default:           return 0
  }
}

function progressStatusOf(s: string): '' | 'success' | 'exception' {
  if (s === 'SUCCESS') return 'success'
  if (s === 'FAILED') return 'exception'
  return ''
}
|
||||
```
|
||||
|
||||
模板里 `v-if="poller.isPolling.value || poller.status.value === 'SUCCESS' || poller.status.value === 'FAILED'"` 控制展示。
|
||||
|
||||
---
|
||||
|
||||
## 5. CSS:深色控制台风(slate 渐变 + 卡片玻璃态)
|
||||
|
||||
```css
|
||||
.app-root {
|
||||
min-height: 100vh;
|
||||
background: linear-gradient(180deg, #0f172a 0%, #1e293b 100%);
|
||||
color: #e2e8f0;
|
||||
}
|
||||
.panel {
|
||||
background: rgba(30, 41, 59, 0.7) !important;
|
||||
border: 1px solid rgba(148, 163, 184, 0.18) !important;
|
||||
}
|
||||
.app-main {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr; /* 左训练 / 右推断 */
|
||||
gap: 20px;
|
||||
}
|
||||
@media (max-width: 960px) { .app-main { grid-template-columns: 1fr; } }
|
||||
```
|
||||
|
||||
On a dark background, Element Plus's `el-form-item__label` / `el-descriptions__label` default to dark text and must be overridden to a light color with `:deep()`.
|
||||
|
||||
---
|
||||
|
||||
## 6. 启动与验证
|
||||
|
||||
```bat
|
||||
cd /d D:\111\office\ZHLduijie\1.WQ\WQ_GUI\frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
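Open `http://127.0.0.1:5173/`; the expected integration path:

1. Left side, 「开始训练」 (start training) → `task_id` is returned immediately, with a yellow `轮询中` (polling) tag and the progress bar at 60%
2. Backend SUCCESS → the progress bar turns green, and the `model_id` tag plus R²/RMSE/MAE appear below
3. `model_id` on the right is filled in automatically → 「开始推断」 (start prediction) → `output_zarr_path` is displayed
4. Any step FAILED → the progress bar turns red and the backend `error` field is shown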
|
||||
|
||||
---
|
||||
|
||||
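## 7. Known caveats

- **The first `npm install` downloads about 150 MB**; be patient.
- `useTaskPoller` already cleans up in `onUnmounted`; **do not hand-write `clearInterval` as well**.
- The comments in `request.ts` state that FastAPI uses `allow_origins=["*"]` during development, so **no Vite proxy is needed**; if the backend tightens CORS later, add `server.proxy['/api']` in `vite.config.ts`.
- The backend accepts `number | string` for `feature_start`; el-input's v-model yields a string, which **can be passed to the API as is**; the backend tells the two apart itself.
- The `ref<number | string>(4)` type annotation on the `v-model` binding is required; otherwise TS infers `Ref<number>` and the input reports an error on blur.
- After full registration, `@element-plus/icons-vue` icons are used as `<el-icon><Cpu /></el-icon>`; App.vue does not use them in this iteration, but they are kept as an extension point.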
|
||||
BIN
data/icons-1/1.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/10.ico
Normal file
|
After Width: | Height: | Size: 37 KiB |
BIN
data/icons-1/11.ico
Normal file
|
After Width: | Height: | Size: 8.3 KiB |
BIN
data/icons-1/2.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/3.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/4.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/5.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/6.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/7.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/8.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/9.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/IMG_20250904_123453.ico
Normal file
|
After Width: | Height: | Size: 79 KiB |
BIN
data/icons-1/IMG_20250904_134825.ico
Normal file
|
After Width: | Height: | Size: 92 KiB |
BIN
data/icons-1/IRIS.ico
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
data/icons-1/Mega Water 1.0.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/fenmian.ico
Normal file
|
After Width: | Height: | Size: 15 KiB |
BIN
data/icons-1/lica.ico
Normal file
|
After Width: | Height: | Size: 18 KiB |
BIN
data/icons-1/liucheng.ico
Normal file
|
After Width: | Height: | Size: 14 KiB |
BIN
data/icons-1/logo.ico
Normal file
|
After Width: | Height: | Size: 6.5 KiB |
BIN
data/icons-1/table.ico
Normal file
|
After Width: | Height: | Size: 47 KiB |
BIN
data/icons-1/uitubiao.ico
Normal file
|
After Width: | Height: | Size: 94 KiB |
BIN
data/icons-1/图片矢量化与编辑.ico
Normal file
|
After Width: | Height: | Size: 50 KiB |
BIN
data/icons-1/屏幕截图 2026-03-27 172136.ico
Normal file
|
After Width: | Height: | Size: 6 B |
BIN
data/icons-1/屏幕截图 2026-03-31 144131.ico
Normal file
|
After Width: | Height: | Size: 52 KiB |
BIN
data/icons-1/演示文稿1.ico
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
data/icons-1/生成软件GUI矢量图标 (2).ico
Normal file
|
After Width: | Height: | Size: 6.5 KiB |
BIN
data/icons-1/生成软件GUI矢量图标 (3).ico
Normal file
|
After Width: | Height: | Size: 8.4 KiB |
BIN
data/icons-1/生成软件GUI矢量图标 (4).ico
Normal file
|
After Width: | Height: | Size: 11 KiB |
BIN
data/icons/uitubiao.jpg
Normal file
|
After Width: | Height: | Size: 204 KiB |
85
data/格式转化.py
Normal file
@ -0,0 +1,85 @@
|
||||
from pathlib import Path

from PIL import Image


def batch_convert_to_ico(source_dirs, output_dir, target_size=(256, 256)):
    """
    Batch-convert the image files in the given directories to ICO format.

    :param source_dirs: list of source folder paths
    :param output_dir: directory where the converted ICO files are saved
    :param target_size: size of the output ICO, 256x256 by default
    """
    # Common input image suffixes that are supported
    supported_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tiff'}

    # Make sure the output directory exists; create it if needed
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    total_converted = 0
    total_failed = 0

    print("=" * 50)
    print("🚀 Starting batch conversion to ICO...")
    print(f"📁 Output directory: {out_path}")
    print("=" * 50)

    # Walk through every source directory that was passed in
    for folder in source_dirs:
        folder_path = Path(folder)

        if not folder_path.exists():
            print(f"⚠️ Warning: source directory does not exist, skipped -> {folder_path}")
            continue

        print(f"\n📂 Scanning directory: {folder_path}")

        for file_path in folder_path.iterdir():
            # Only regular files with a supported suffix (case-insensitive)
            if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
                try:
                    with Image.open(file_path) as img:
                        # Keep transparency: images with an alpha channel
                        # (RGBA/LA, or P with transparency) become RGBA;
                        # everything else (e.g. JPG) becomes RGB
                        if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
                            img_clean = img.convert('RGBA')
                        else:
                            img_clean = img.convert('RGB')

                        # Pillow's ICO writer ignores any requested size larger
                        # than the source image, which leaves a header-only
                        # 6-byte file. Upscale small sources first (non-square
                        # small images get stretched to target_size).
                        if img_clean.width < target_size[0] or img_clean.height < target_size[1]:
                            img_clean = img_clean.resize(target_size, Image.LANCZOS)

                        # Output file name: <original stem>.ico; an existing
                        # file with the same name is overwritten
                        new_filename = f"{file_path.stem}.ico"
                        save_path = out_path / new_filename
                        img_clean.save(save_path, format="ICO", sizes=[target_size])

                        print(f"  ✅ OK: {file_path.name} -> {new_filename}")
                        total_converted += 1

                except Exception as e:
                    print(f"  ❌ Failed: could not convert {file_path.name}, error: {e}")
                    total_failed += 1

    print("\n" + "=" * 50)
    print("🎉 Conversion finished!")
    print(f"Summary: {total_converted} file(s) converted, {total_failed} failed.")
    print("=" * 50)


if __name__ == "__main__":
    # 1. The two source folders to read from
    SOURCES = [
        r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons",
        r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons\word"
    ]

    # 2. The single output folder
    OUTPUT = r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons-1"

    # Run the conversion
    batch_convert_to_ico(SOURCES, OUTPUT)
|
||||
350
docs/SMOKE_TEST_ROUTE_B_MVP.md
Normal file
@ -0,0 +1,350 @@
|
||||
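# Smoke Test: Route B MVP (PipelineContext + AutoML + soft cancel + GUI integration)

> Scope: the end-to-end ignition checklist to run once the 4 parts of the Route B refactor (pipeline package / AutoML trainer / WorkerThread soft cancel / GUI one-click automation) have landed on disk.
> Goal: **verify the whole chain with a minimal dataset (1 BSQ + 1 CSV) in 10-20 minutes**.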
|
||||
|
||||
---
|
||||
|
||||
## 0. Prerequisites (5 minutes)

### 0.1 Install Optuna

`environment.yml` currently **does not list** optuna (it is a new dependency introduced by this refactor). Without it, Step 6 automatically degrades to the old GridSearchCV (it still runs, but emits a fallback log).

```bat
|
||||
call venv\Scripts\activate.bat
|
||||
pip install "optuna>=3.6,<4.0"
|
||||
```
|
||||
|
||||
Patch for `environment.yml` (apply it when committing):

```yaml
# Route B AutoML engine (optional; without it Step 6 takes the old GridSearchCV fallback path)
|
||||
- optuna>=3.6
|
||||
```
|
||||
|
||||
### 0.2 Prepare a minimal dataset

```text
work_dir_smoke/
├── raw/
│   ├── sample.b           # false-color BSQ (any small resolution works; 50×50×6 bands suggested)
│   ├── sample_mask.tif    # (optional) water mask; if omitted, Step 1 generates one from NDWI
│   └── sample.csv         # 3-6 water-quality target columns (Chl-a / TSS / SD / TN / TP / COD…) + 6 band-reflectance columns
└── (other files are generated by the pipeline)
```

**CSV template example** (`feature_start_column` defaults to the first column; target columns must come **before the feature columns**):
|
||||
|
||||
```csv
|
||||
Chl-a,TSS,SD,B1,B2,B3,B4,B5,B6
|
||||
12.3,15.1,0.8,0.045,0.052,0.038,0.061,0.072,0.085
|
||||
11.8,14.2,0.9,0.044,0.051,0.037,0.060,0.071,0.084
|
||||
... (≥ 200 rows; AutoML smart subsampling only takes effect when N > 5000)
|
||||
```
|
||||
|
||||
### 0.3 Activate the venv

```bat
|
||||
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=src;%PYTHONPATH%
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. CLI smoke test (fastest path, 3 minutes): **grade A, mandatory**

Skip the GUI and verify directly that `automl_trainer.py` runs standalone, together with Optuna subsampling and the fallback path:

```bat
|
||||
python -m src.core.prediction.automl_trainer ^
|
||||
--csv work_dir_smoke/raw/sample.csv ^
|
||||
--feature-start 6 ^
|
||||
--n-trials 5 ^
|
||||
--timeout 60.0 ^
|
||||
--out work_dir_smoke/7_Supervised_Model_Training_AutoML
|
||||
```
|
||||
|
||||
**Pass criteria**:

- [ ] Process exit code is 0
- [ ] The console prints `AutoML: 目标列 X 共尝试 N 个 trial,最佳 CV R²=…`
- [ ] `<out>/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` exists
- [ ] `<out>/automl_summary.json` exists and has `success=true`

**If Optuna is not installed**, expect to see:

```
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
```

In that case the `_AUTOML` filename suffix is **not applied** (the fallback takes the old path); this is expected.
|
||||
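The fallback described above depends on detecting an optional dependency without importing it. A minimal sketch of that check follows; `has_module` and `choose_trainer` are hypothetical names, and the real trainer may detect Optuna differently (for example with a `try`/`except ImportError`).

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def choose_trainer() -> str:
    """Pick the tuning backend: Optuna when available, otherwise grid search."""
    return "optuna" if has_module("optuna") else "gridsearch"
```

Checking availability once at start-up keeps the decision in a single place and lets the log state clearly which path was taken.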
|
||||
---
|
||||
|
||||
## 2. GUI end-to-end, 9 steps (core scenario, 10-20 minutes): **grade S, mandatory**

### 2.1 Launch the GUI

```bat
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=src;%PYTHONPATH%
|
||||
python -m src.gui.water_quality_gui
|
||||
```
|
||||
|
||||
### 2.2 UI configuration

| Step | Action | Expected |
| ----- | ----- | ----- |
| 1/9 | Click "选择工作目录" (select working directory) → choose `work_dir_smoke/` | The step list on the left is highlighted; no UI errors |
| 2/9 | In the Step 1 panel choose `sample.b`; **leave the mask empty** (exercises the automatic NDWI path) | The mask text box stays blank |
| 3/9 | In the Step 4 panel choose `sample.csv` | The CSV path is displayed correctly |
| 4/9 | **Key**: leave the other steps (2/3/5/5.5/6/7/8/9) at their defaults; change no parameters | AutoML is on by default (use_automl=True) |
| 5/9 | Click **▶ 运行完整流程** (run full pipeline; do not use the old `run_full_pipeline` slot) | A **second-confirmation dialog** appears, showing:<br>• 掩膜:`未指定(将自动生成 NDWI 水域掩膜)`<br>• 去耀斑:开启<br>• AutoML:开启(Optuna 子采样寻优) |
| 6/9 | Click "是(Y)" | The "运行" (Run) button greys out and the "停止" (Stop) button lights up; the progress bar resets to zero |
|
||||
|
||||
### 2.3 Watch the log (4 key checkpoints)

#### ✅ Checkpoint 1: ctx path passing

Within the **first second** after launch you should see something like:
|
||||
|
||||
```
|
||||
[Runner] ctx 已构造:14 路径字段,4 目录字段
|
||||
[Runner] 步骤 1/14:step1_generate_water_mask(requires=['raw_img_path', 'water_mask_path'])
|
||||
[Runner] 步骤 2/14:step2_find_glint_area(requires=['raw_img_path', 'water_mask_path', 'output_dir'])
|
||||
...
|
||||
[Runner] ctx 路径校准:water_mask_path = ...\work_dir_smoke\2_Glint_Area_Mask\glint_mask.tif
|
||||
```
|
||||
|
||||
→ **If there are no `[Runner]` log lines**, the old v1 path was taken, i.e. **the `inspect.signature` duck-type check did not detect v2**; go back and check `worker_thread.py:run()`.
|
||||
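For orientation, duck-type detection with `inspect.signature` can look like the sketch below. It assumes, purely for illustration, that the v2 entry point is recognized by a parameter named `ctx`; `accepts_param` and `run_pipeline` are hypothetical names and the real check in `worker_thread.py:run()` may use different criteria.

```python
import inspect
from typing import Any, Callable


def accepts_param(func: Callable[..., Any], name: str) -> bool:
    """Return True if func declares a parameter with the given name."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) expose no signature
        return False
    return name in params


def run_pipeline(entry: Callable[..., Any], config: dict, ctx: Any) -> Any:
    """Call the v2-style entry with ctx when supported, else the v1-style one."""
    if accepts_param(entry, "ctx"):
        return entry(config, ctx=ctx)
    return entry(config)
```

Logging which branch was taken at this point is what makes a missing `[Runner]` line easy to diagnose.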
|
||||
#### ✅ Checkpoint 2: automatic NDWI generation in Step 1
|
||||
|
||||
```
|
||||
[Step1] 未指定 mask_path,自动基于 NDWI 生成水域掩膜
|
||||
[Step1] NDWI 阈值=0.4,写入 1_Water_Mask/water_mask.tif
|
||||
```
|
||||
|
||||
→ Verify that `<work_dir>/1_Water_Mask/water_mask.tif` exists and is non-empty.

#### ✅ Checkpoint 3: AutoML enabled
|
||||
|
||||
```
|
||||
[Step6] AutoML 启用 Optuna 子采样寻优(timeout=300s, n_trials=20, max_samples=5000)
|
||||
[Step6] 目标列 'Chl-a' 共 3 个候选模型,最佳 R²=0.812(model=RandomForest)
|
||||
[Step6] 目标列 'TSS' 共 3 个候选模型,最佳 R²=0.745(model=XGBoost)
|
||||
[Step6] 训练完成,产物写入 7_Supervised_Model_Training_AutoML/
|
||||
[Step6] automl_summary.json 写入完成
|
||||
```
|
||||
|
||||
→ 验证产物:
|
||||
- [ ] `7_Supervised_Model_Training_AutoML/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` ≥ 1 个
|
||||
- [ ] `7_Supervised_Model_Training_AutoML/automl_summary.json` 含 `automl: true` 字段
|
||||
- [ ] 老目录 `7_Supervised_Model_Training/` **不应该被创建**(AutoML 路径独立)
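The three artifact checks above could be scripted roughly as follows. This is a sketch under assumptions: the helper name `check_automl_artifacts` is invented for illustration, and it assumes `automl_summary.json` is a JSON object with a top-level `automl` key.

```python
import json
from pathlib import Path


def check_automl_artifacts(work_dir):
    """Return a list of problems; an empty list means all three checks pass."""
    root = Path(work_dir)
    automl_dir = root / "7_Supervised_Model_Training_AutoML"
    problems = []

    # 1. At least one *_AUTOML.joblib inside some <preprocess>/ subfolder.
    models = list(automl_dir.glob("*/*_AUTOML.joblib")) if automl_dir.is_dir() else []
    if not models:
        problems.append("no *_AUTOML.joblib model found")

    # 2. automl_summary.json exists and carries automl == true (assumed top-level key).
    summary_path = automl_dir / "automl_summary.json"
    try:
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        problems.append("automl_summary.json missing or unreadable")
    else:
        if not (isinstance(summary, dict) and summary.get("automl") is True):
            problems.append("automl_summary.json lacks automl: true")

    # 3. The legacy directory must not have been created.
    if (root / "7_Supervised_Model_Training").exists():
        problems.append("legacy 7_Supervised_Model_Training/ exists")
    return problems
```

Run it against the smoke-test working directory after the pipeline finishes; any returned string points at the checkbox that failed.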

#### ✅ Checkpoint 4: AutoML fallback (only when Optuna is not installed)

```
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
[Step6] 降级路径:调用 WaterQualityModelingBatch.train_models_batch(132 组 GridSearchCV)
```

→ A completed run is enough (model files are still produced), but the **fallback** is not the preferred path.
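The fallback decision can be pictured as an optional-import guard. The sketch below is an assumption about the general pattern, not the code in `automl_trainer.py`; the function name `pick_trainer` and its return labels are hypothetical.

```python
import importlib


def pick_trainer(module_name="optuna"):
    """Return "automl" when the optional dependency imports, otherwise "gridsearch"."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Dependency missing: take the legacy grid-search path instead of failing.
        return "gridsearch"
    return "automl"
```

Catching `ImportError` also covers `ModuleNotFoundError`, which is its subclass, so a missing package and a broken install both land on the fallback.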

### 2.4 Full-run observation list for the 9 steps

| Step | Expected artifact (path relative to `work_dir`) | Expected duration (50×50 test data) |
| ---- | -------------------------------------------------------------- | -------------------------- |
| 1 | `1_Water_Mask/water_mask.tif` | < 5 s |
| 2 | `2_Glint_Area_Mask/glint_mask.tif` | < 5 s |
| 3 | `3_Remove_Glint_Image/deglint_image.tif` | < 5 s |
| 4 | `4_Process_CSV/processed_data.csv` | < 2 s |
| 5 | `5_Training_Sample/training_spectra.csv` | < 5 s |
| 5.5 | `5_5_Calculate_Indices/indices.csv` (if enabled) | < 2 s |
| **6** | `7_Supervised_Model_Training_AutoML/` (**new path!**) | **< 5 min (Optuna 5 trial)** |
| 6.5 | `6_5_Non_Empirical_Modeling/` (if enabled) | 1-2 min |
| 6.75 | `6_75_Custom_Regression/` (if enabled) | 1-2 min |
| 7 | `7_Sampling_Points/sampling_points.csv` | < 3 s |
| 8 | `8_Prediction/predicted_values.csv` | < 5 s |
| 8.5 | `8_5_Prediction_Non_Empirical/predicted.csv` (if enabled) | < 5 s |
| 8.75 | `8_75_Prediction_Custom/predicted.csv` (if enabled) | < 5 s |
| 9 | `9_Kriging_Distribution_Map/distribution_map.tif` | 5-30 s (pure Python, slow) |

### 2.5 End of the run

- [ ] Progress bar reaches 100%
- [ ] "Run" button is clickable again
- [ ] "Stop" button is greyed out
- [ ] The last log line is `=== 流程执行完成 ===` or `=== 流程被取消 ===` (depending on whether Stop was clicked)
- [ ] `on_pipeline_finished` fires in the console: UI state is restored in one place

---

## 3. Soft-cancel test (3 minutes): **Level A, mandatory**

Verifies the `threading.Event` soft-cancel chain (`terminate()` is no longer used).
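For orientation, cooperative cancellation with `threading.Event` generally looks like the sketch below: the runner checks the event between steps and exits cleanly instead of being killed. The function `run_steps` and the three placeholder steps are illustrative assumptions, not the project's `PipelineRunner`.

```python
import threading


def run_steps(steps, cancel_event):
    """Run (name, callable) pairs in order; stop before the next one once cancel_event is set."""
    completed = []
    for name, step in steps:
        if cancel_event.is_set():
            break  # soft cancel: the running step finished, remaining steps are skipped
        step()
        completed.append(name)
    skipped = len(steps) - len(completed)
    return completed, skipped


cancel_event = threading.Event()
steps = [
    ("step1", lambda: None),
    ("step2", cancel_event.set),  # simulates the user pressing Stop during step 2
    ("step3", lambda: None),
]
completed, skipped = run_steps(steps, cancel_event)
```

Because the worker thread returns normally, the GUI thread can restore its buttons from a single finished callback, which is what the checks in this section observe.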

### 3.1 Start the full pipeline

Start the pipeline as in 2.2.

### 3.2 Click "Stop" midway

**Timing**: click "Stop" while Step 6 AutoML is running trials (any time after you see `[Step6] 目标列 'Chl-a' 共 N 个候选模型`).

**Expected output**:

```
[STOP] 用户请求软取消
[Step6] 检测到 cancel_event,本 trial 完成后退出
[Step6] AutoML 在 trial #X 中止,已完成 5/20 trial
[Runner] 软取消:跳过剩余 8 个 step
=== 流程被取消 ===
```

UI state:

- [ ] "Run" button lights up again
- [ ] "Stop" button is greyed out
- [ ] Progress bar stays at the percentage reached at interruption (it does **not** reset to zero)
- [ ] `on_pipeline_finished` fires (distinguished by `success=False, cancelled=True`)
- [ ] **The Python process does not exit** (you can click "Run" again in the GUI to start a new run)

**Counter-examples (must not happen)**:

- ❌ A `QThread: Destroyed while thread is still running` warning
- ❌ The Python interpreter crashes outright
- ❌ The UI hangs forever (`run_all_btn` stays grey)

### 3.3 Regression of the old `stop()` path

To catch old code that was never migrated, temporarily change `water_quality_gui.py:stop_pipeline` back to `self.worker.stop()`, run the full pipeline once, and check whether this appears:

```
[DEPRECATED] WorkerThread.stop() 已弃用,请改用 soft_stop()。
```

**This is expected** (the deprecated method is kept but emits a warning); the test passes if the pipeline still completes.
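One common way to keep a deprecated method alive while steering callers to its replacement is shown below. It is only a sketch: the `Worker` class is a stand-in, and having `stop()` delegate to `soft_stop()` is an assumption, since the checklist does not say what the real `WorkerThread.stop()` does after warning.

```python
import threading
import warnings


class Worker:
    """Stand-in for the real worker; only the cancellation surface is modelled."""

    def __init__(self):
        self.cancel_event = threading.Event()

    def soft_stop(self):
        # Preferred API: request cooperative cancellation.
        self.cancel_event.set()

    def stop(self):
        # Deprecated API: warn, then delegate (assumed behaviour).
        warnings.warn(
            "stop() is deprecated; use soft_stop() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.soft_stop()
```

`stacklevel=2` makes the warning point at the caller of `stop()`, which helps when hunting down the remaining old call sites.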

---

## 4. Failure / fallback scenarios (5 minutes): **Level B, optional**

### 4.1 Mask left empty + extreme NDWI threshold

Set the NDWI threshold to `0.9` (almost no water); Step 1 should emit a warning but not crash:

```
[Step1] NDWI 阈值=0.9,水域覆盖率 < 1%,请检查影像
```
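The effect of an extreme threshold can be illustrated with the standard NDWI definition, (green - nir) / (green + nir). The sketch below uses plain Python lists and invented helper names; which bands Step 1 actually reads, and whether it compares with `>` or `>=`, are assumptions.

```python
def water_coverage(green, nir, threshold):
    """Fraction of pixels whose NDWI = (green - nir) / (green + nir) exceeds threshold."""
    water = 0
    total = 0
    for g, n in zip(green, nir):
        total += 1
        denom = g + n
        if denom == 0:
            continue  # NDWI undefined; counted as non-water
        if (g - n) / denom > threshold:
            water += 1
    return water / total if total else 0.0


def coverage_warning(green, nir, threshold, min_coverage=0.01):
    """Return a warning string when coverage drops below min_coverage, else None."""
    cov = water_coverage(green, nir, threshold)
    if cov < min_coverage:
        return f"NDWI threshold={threshold}: water coverage {cov:.2%} is below {min_coverage:.0%}"
    return None
```

With a threshold of 0.9, only pixels where the green reflectance is at least 19 times the NIR reflectance qualify, which is why coverage collapses on ordinary imagery.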

### 4.2 CSV with no target column at all

Prepare a **CSV without any target column** (feature columns only) and click Run:

```
[AutoML] 训练 CSV 不存在或无目标列:未识别出目标列
[Step6] AutoML 全部失败,所有目标列返回 success=False
```

→ The UI does not crash; `error: "未识别出目标列"` is written to `automl_summary.json`.

### 4.3 Step 1 path does not exist

In Step 1, select a **.bsq file that does not exist**:

```
[Runner] step1_generate_water_mask 异常:FileNotFoundError
[STOP] 流程中止在 step 1
```

→ The UI shows an error dialog, and `setCurrentRow(0)` moves the step list on the left to Step 1 (`_focus_step` takes effect).

### 4.4 Optuna version conflict

Install `optuna==2.10` (major API differences) and run the GUI:

```
[AutoML] optuna API 不兼容(>=3.6 要求):<error>
[AutoML] 全目标列回退老 GridSearchCV
```

→ The test passes if the fallback path takes effect.
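A version gate matching the `>=3.6` requirement in the log could be written as below. This is a hedged sketch: the helper names are invented, only the major and minor components are compared, and the real trainer may detect incompatibility differently (for example by catching an API error at call time).

```python
def parse_version(text):
    """Turn "3.6.1" or "4.0.0b1" into a (major, minor) tuple of integers."""
    parts = []
    for chunk in text.split(".")[:2]:
        digits = ""
        for ch in chunk:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 2:
        parts.append(0)
    return tuple(parts)


def is_compatible(installed, minimum=(3, 6)):
    """True when the installed version meets the minimum (major, minor)."""
    return parse_version(installed) >= minimum
```

The installed version string can be obtained with `importlib.metadata.version("optuna")`. Comparing tuples rather than strings matters here: as text, "2.10" sorts after "2.9", and "10.0" would sort before "3.6".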

---

## 5. Verification matrix checklist

Copy the following into the PR description / acceptance sheet:

```markdown
## Route B MVP verification matrix

### Code on disk
- [ ] src/core/pipeline/__init__.py (17 lines, 4 exports)
- [ ] src/core/pipeline/context.py (PipelineContext dataclass)
- [ ] src/core/pipeline/runner.py (StepSpec + PIPELINE_STEPS + PipelineRunner)
- [ ] src/core/prediction/__init__.py (adds the train_with_automl export)
- [ ] src/core/prediction/automl_trainer.py (AutoMLResult + train_with_automl + CLI)
- [ ] src/core/steps/modeling_step.py (use_automl branch + _train_models_automl)
- [ ] src/core/water_quality_inversion_pipeline_GUI.py (run_full_pipeline_v2 + LEGACY_ATTR_MAP + _sync_legacy_attrs_from_context)
- [ ] src/gui/core/worker_thread.py (cancel_event + soft_stop + run() duck-type)
- [ ] src/gui/water_quality_gui.py (on_run_all_clicked + _collect_minimal_config + button rewiring)

### CLI self-test
- [ ] A.1 `python -m src.core.prediction.automl_trainer --csv ...` exits with code 0
- [ ] A.2 Output .joblib files carry the `_AUTOML` suffix
- [ ] A.3 automl_summary.json contains success=true

### GUI end to end
- [ ] B.1 Starts without ImportError
- [ ] B.2 Confirmation dialog text includes the mask hint + AutoML state
- [ ] B.3 Log contains the [Runner] prefix (v2 path active)
- [ ] B.4 Step 1 automatic NDWI generation path works
- [ ] B.5 Artifact paths of all 9 steps exist
- [ ] B.6 UI state is restored after the run (Run button lit, Stop button grey)

### Soft cancel
- [ ] C.1 Clicking Stop midway triggers cancel_event
- [ ] C.2 The run is cancelled rather than crashing
- [ ] C.3 UI state is restored in one place by on_pipeline_finished
- [ ] C.4 An old stop() call emits the [DEPRECATED] warning

### Fallback
- [ ] D.1 Optuna not installed → all target columns fall back to the old GridSearchCV
- [ ] D.2 CSV without target column → error written to summary, UI does not crash
- [ ] D.3 Non-existent file → _focus_step jumps to the corresponding step
```

---

## 6. Known gaps (out of scope for this change)

- [ ] Kriging multi-process parallelism (currently backend="loop", pure Python)
- [ ] Step 5 radius==0 memory optimization (reads the whole band into memory)
- [ ] Sub-step granularity for the progress bar (currently step level only)
- [ ] Step 8 full-image prediction (currently predicts at sampling points only)
- [ ] Project-wide search and replace of old `self.worker.stop()` calls (this session only changed stop_pipeline in `water_quality_gui.py`)
- [ ] Sync Optuna into `requirements.txt` (currently listed only in `environment.yml`)
- [ ] Unit test suite (`tests/` is empty; pytest coverage of train_with_automl / PipelineRunner is suggested)

---

## 7. Where to look when something goes wrong

| Symptom | Where to look |
| --------------------------------------------- | ------------------------------------------------------- |
| No `[Runner]` log lines | the `inspect.signature` probe in `worker_thread.py:run()` |
| No `[AutoML]` output at all | whether `if use_automl` at `modeling_step.py:170` was entered |
| AutoML reports `optuna API 不兼容` | the `try import` block at `automl_trainer.py:236` |
| Soft cancel has no effect | the `cancel_event.is_set()` check at the end of `worker_thread.py:run()` |
| Confirmation dialog does not appear | `water_quality_gui.py:on_run_all_clicked`, line ~2848 |
| Artifact paths of the 9 steps are misplaced | the `output` field of `pipeline/runner.py:PIPELINE_STEPS` |
| Old v1 path is taken | `_sync_legacy_attrs_from_context` was not called, or v2 raised |

---

> **Author's note**: this checklist covers the acceptance scenario in which **all 4 parts of the Route B one-click fully automatic refactor are on disk**; its numbering is in sync with todo 8.
> Passing §1 + §2 + §3 counts as MVP acceptance; §4 is a robustness spot check.
2
frontend/.env.development
Normal file
@ -0,0 +1,2 @@
|
||||
# During integration testing, point at the local FastAPI dev server
|
||||
VITE_API_BASE_URL=http://127.0.0.1:9090
|
||||
7
frontend/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
.vite
|
||||
*.local
|
||||
.DS_Store
|
||||
*.log
|
||||
15
frontend/env.d.ts
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
/// <reference types="vite/client" />
|
||||
|
||||
interface ImportMetaEnv {
|
||||
readonly VITE_API_BASE_URL?: string
|
||||
}
|
||||
|
||||
interface ImportMeta {
|
||||
readonly env: ImportMetaEnv
|
||||
}
|
||||
|
||||
declare module '*.vue' {
|
||||
import type { DefineComponent } from 'vue'
|
||||
const component: DefineComponent<{}, {}, any>
|
||||
export default component
|
||||
}
|
||||
13
frontend/index.html
Normal file
@ -0,0 +1,13 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WQ_GUI · 水质反演联调控制台</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
<script type="module" src="/src/main.ts"></script>
|
||||
</body>
|
||||
</html>
|
||||
25
frontend/package.json
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
"name": "wq-gui-frontend",
|
||||
"private": true,
|
||||
"version": "0.0.1",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vue-tsc --noEmit && vite build",
|
||||
"preview": "vite preview",
|
||||
"type-check": "vue-tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"vue": "^3.4.27",
|
||||
"element-plus": "^2.7.5",
|
||||
"@element-plus/icons-vue": "^2.3.1",
|
||||
"axios": "^1.7.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.12.12",
|
||||
"@vitejs/plugin-vue": "^5.0.4",
|
||||
"typescript": "^5.4.5",
|
||||
"vite": "^5.2.11",
|
||||
"vue-tsc": "^2.0.19"
|
||||
}
|
||||
}
|
||||
225
frontend/src/App.vue
Normal file
@ -0,0 +1,225 @@
|
||||
<template>
|
||||
<div class="dashboard-container">
|
||||
<h1 class="title">高光谱水质反演控制台</h1>
|
||||
<el-row :gutter="20">
|
||||
|
||||
<el-col :span="12">
|
||||
<el-card class="box-card" shadow="hover">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<span class="header-title">🚀 模型训练 (Train)</span>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<el-form label-position="top">
|
||||
<el-form-item label="算法选择 (Model Type)">
|
||||
<el-select v-model="trainForm.model_type" placeholder="请选择算法" class="w-full">
|
||||
<el-option label="随机森林 (RF)" value="RF" />
|
||||
<el-option label="支持向量回归 (SVR)" value="SVR" />
|
||||
<el-option label="线性回归 (LinearRegression)" value="LinearRegression" />
|
||||
<el-option label="K近邻 (KNN)" value="KNN" />
|
||||
<el-option label="偏最小二乘 (PLS)" value="PLS" />
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
<el-form-item label="目标参数 (Target)">
|
||||
<el-input v-model="trainForm.target" placeholder="如 Chl-a" />
|
||||
</el-form-item>
|
||||
<el-form-item label="训练数据路径 (CSV 绝对路径)">
|
||||
<el-input v-model="trainForm.train_data_path" placeholder="如 D:\111\data.csv" />
|
||||
</el-form-item>
|
||||
<el-form-item label="特征起始列 (如 4, 或列名)">
|
||||
<el-input v-model="trainForm.feature_start" placeholder="填写数字或列名" />
|
||||
</el-form-item>
|
||||
<el-button type="primary" @click="handleTrain" :loading="trainPoller?.isPolling?.value" class="w-full">
|
||||
开始训练
|
||||
</el-button>
|
||||
</el-form>
|
||||
|
||||
<div v-if="trainTaskId" class="status-board">
|
||||
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ trainTaskId }}</el-tag></p>
|
||||
<p><strong>当前状态:</strong>
|
||||
<el-tag :type="getStatusType(trainPoller?.status?.value || 'PENDING')" style="margin-left:10px">
|
||||
{{ trainPoller?.status?.value || 'PENDING' }}
|
||||
</el-tag>
|
||||
</p>
|
||||
<el-progress
|
||||
v-if="trainPoller?.isPolling?.value || trainPoller?.status?.value === 'SUCCESS'"
|
||||
:percentage="trainPoller?.status?.value === 'SUCCESS' ? 100 : 60"
|
||||
:status="trainPoller?.status?.value === 'SUCCESS' ? 'success' : (trainPoller?.status?.value === 'FAILED' ? 'exception' : '')"
|
||||
:indeterminate="trainPoller?.isPolling?.value"
|
||||
/>
|
||||
|
||||
<div v-if="trainPoller?.error?.value" class="error-msg">
|
||||
<el-alert :title="trainPoller.error.value" type="error" :closable="false" show-icon />
|
||||
</div>
|
||||
|
||||
<div v-if="trainPoller?.result?.value?.model_id" class="result-msg">
|
||||
<el-descriptions border :column="1" size="small" title="训练指标">
|
||||
<el-descriptions-item label="Model ID">{{ trainPoller.result.value.model_id }}</el-descriptions-item>
|
||||
<el-descriptions-item label="Test R²">{{ Number(trainPoller.result.value.test_r2).toFixed(4) }}</el-descriptions-item>
|
||||
<el-descriptions-item label="Test RMSE">{{ Number(trainPoller.result.value.test_rmse).toFixed(4) }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
||||
<el-col :span="12">
|
||||
<el-card class="box-card" shadow="hover">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<span class="header-title">🎯 模型推断 (Predict)</span>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<el-form label-position="top">
|
||||
<el-form-item label="已训练模型 ID (Model ID)">
|
||||
<el-input v-model="predictForm.model_id" placeholder="将自动填入左侧训练好的 ID" />
|
||||
</el-form-item>
|
||||
<el-form-item label="待推断影像路径 (Zarr 绝对路径)">
|
||||
<el-input v-model="predictForm.input_zarr_path" placeholder="如 D:\111\image.zarr" />
|
||||
</el-form-item>
|
||||
<el-button type="success" @click="handlePredict" :loading="predictPoller?.isPolling?.value" class="w-full">
|
||||
开始大图反演推断
|
||||
</el-button>
|
||||
</el-form>
|
||||
|
||||
<div v-if="predictTaskId" class="status-board">
|
||||
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ predictTaskId }}</el-tag></p>
|
||||
<p><strong>当前状态:</strong>
|
||||
<el-tag :type="getStatusType(predictPoller?.status?.value || 'PENDING')" style="margin-left:10px">
|
||||
{{ predictPoller?.status?.value || 'PENDING' }}
|
||||
</el-tag>
|
||||
</p>
|
||||
<el-progress
|
||||
v-if="predictPoller?.isPolling?.value || predictPoller?.status?.value === 'SUCCESS'"
|
||||
:percentage="predictPoller?.status?.value === 'SUCCESS' ? 100 : 50"
|
||||
:status="predictPoller?.status?.value === 'SUCCESS' ? 'success' : (predictPoller?.status?.value === 'FAILED' ? 'exception' : '')"
|
||||
:indeterminate="predictPoller?.isPolling?.value"
|
||||
/>
|
||||
|
||||
<div v-if="predictPoller?.error?.value" class="error-msg">
|
||||
<el-alert :title="predictPoller.error.value" type="error" :closable="false" show-icon />
|
||||
</div>
|
||||
|
||||
<div v-if="predictPoller?.result?.value?.output_zarr_path" class="result-msg">
|
||||
<el-alert :title="'推断成功!结果已落盘至: ' + predictPoller.result.value.output_zarr_path" type="success" :closable="false" show-icon />
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</el-col>
|
||||
|
||||
</el-row>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, reactive } from 'vue'
|
||||
import { submitTrain, submitPredict } from '@/api/tasks'
|
||||
import { useTaskPoller } from '@/composables/useTaskPoller'
|
||||
|
||||
// 训练表单状态
|
||||
const trainForm = reactive({
|
||||
model_type: 'RF',
|
||||
target: 'Chl-a',
|
||||
train_data_path: '',
|
||||
feature_start: '4'
|
||||
})
|
||||
const trainTaskId = ref<string | null>(null)
|
||||
const trainPoller = useTaskPoller(trainTaskId)
|
||||
|
||||
// 推断表单状态
|
||||
const predictForm = reactive({
|
||||
model_id: '',
|
||||
input_zarr_path: ''
|
||||
})
|
||||
const predictTaskId = ref<string | null>(null)
|
||||
const predictPoller = useTaskPoller(predictTaskId)
|
||||
|
||||
// 自动填入联动
|
||||
watch(() => trainPoller?.result?.value?.model_id, (newId) => {
|
||||
if (newId) predictForm.model_id = newId as string
|
||||
})
|
||||
|
||||
// 提交训练
|
||||
const handleTrain = async () => {
|
||||
try {
|
||||
const res = await submitTrain({
|
||||
model_type: trainForm.model_type,
|
||||
target: trainForm.target,
|
||||
train_data_path: trainForm.train_data_path,
|
||||
feature_start: trainForm.feature_start,
|
||||
params: {}
|
||||
})
|
||||
trainTaskId.value = res.task_id
|
||||
} catch (e: any) {
|
||||
console.error('训练接口调用失败', e)
|
||||
alert('提交失败,请检查后端是否在 9090 端口启动,或按 F12 查看控制台跨域报错')
|
||||
}
|
||||
}
|
||||
|
||||
// 提交推断
|
||||
const handlePredict = async () => {
|
||||
try {
|
||||
const res = await submitPredict({
|
||||
model_id: predictForm.model_id,
|
||||
input_zarr_path: predictForm.input_zarr_path
|
||||
})
|
||||
predictTaskId.value = res.task_id
|
||||
} catch (e: any) {
|
||||
console.error('推断接口调用失败', e)
|
||||
}
|
||||
}
|
||||
|
||||
// 样式辅助
|
||||
const getStatusType = (status: string) => {
|
||||
if (status === 'SUCCESS') return 'success'
|
||||
if (status === 'FAILED') return 'danger'
|
||||
if (status === 'PROCESSING') return 'warning'
|
||||
return 'info'
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
/* 去除全局默认边距 */
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
</style>
|
||||
|
||||
<style scoped>
|
||||
.dashboard-container {
|
||||
padding: 40px;
|
||||
min-height: 100vh;
|
||||
background-color: #1e1e2d; /* 科技深色底 */
|
||||
}
|
||||
.title {
|
||||
text-align: center;
|
||||
margin-bottom: 40px;
|
||||
color: #ffffff;
|
||||
font-weight: 300;
|
||||
letter-spacing: 2px;
|
||||
}
|
||||
.header-title {
|
||||
font-weight: bold;
|
||||
font-size: 16px;
|
||||
}
|
||||
.box-card {
|
||||
margin-bottom: 20px;
|
||||
background-color: rgba(255, 255, 255, 0.95);
|
||||
}
|
||||
.w-full {
|
||||
width: 100%;
|
||||
}
|
||||
.status-board {
|
||||
margin-top: 25px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
border: 1px solid #e4e7ed;
|
||||
}
|
||||
.error-msg, .result-msg {
|
||||
margin-top: 20px;
|
||||
}
|
||||
</style>
|
||||
94
frontend/src/api/request.ts
Normal file
@ -0,0 +1,94 @@
|
||||
/**
 * Axios singleton + response interceptor
 * --------------------------------
 * 1. baseURL defaults to the local FastAPI dev server.
 *    Override it with the Vite environment variable VITE_API_BASE_URL, e.g.:
 *      .env.development:  VITE_API_BASE_URL=http://127.0.0.1:9090
 *      .env.production:   VITE_API_BASE_URL=https://api.example.com
 *
 * 2. The response interceptor unwraps response.data, so callers receive the
 *    business object rather than the AxiosResponse wrapper. Failures are
 *    rethrown as Error, with message taken from FastAPI's detail field first.
 *
 * 3. Type augmentation: casting to UnwrappedAxiosInstance makes
 *    request.get<T>(url) return T directly instead of AxiosResponse<T>,
 *    so callers need no second unwrap.
 */
|
||||
import axios, {
|
||||
type AxiosInstance,
|
||||
type AxiosRequestConfig,
|
||||
} from 'axios'
|
||||
|
||||
// Under Vite use import.meta.env; elsewhere (webpack / plain ts-node) fall back to process.env
type ViteEnv = { env?: Record<string, string | undefined> }
const viteEnv: ViteEnv | undefined =
  typeof import.meta !== 'undefined' ? ((import.meta as unknown) as ViteEnv) : undefined

const baseURL: string =
  viteEnv?.env?.VITE_API_BASE_URL ??
  // Ternary instead of `&&`: `false ?? fallback` yields false and would skip the default URL
  (typeof process !== 'undefined' ? process.env?.VITE_API_BASE_URL : undefined) ??
  'http://127.0.0.1:9090'
|
||||
|
||||
const _instance: AxiosInstance = axios.create({
|
||||
baseURL,
|
||||
timeout: 15000,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
  // In development FastAPI CORS uses allow_origins=["*"], so no cookies are needed
|
||||
withCredentials: false,
|
||||
})
|
||||
|
||||
// ----- Request interceptor: placeholder for token / logging -----
|
||||
_instance.interceptors.request.use(
|
||||
(config) => {
|
||||
// const token = localStorage.getItem('token')
|
||||
// if (token) config.headers.Authorization = `Bearer ${token}`
|
||||
return config
|
||||
},
|
||||
(error) => Promise.reject(error),
|
||||
)
|
||||
|
||||
// ----- Response interceptor: unwrap data + unified error message -----
|
||||
_instance.interceptors.response.use(
|
||||
(response) => response.data,
|
||||
(error) => {
|
||||
const detail = error?.response?.data?.detail
|
||||
const message =
|
||||
(typeof detail === 'string' ? detail : detail?.msg) ??
|
||||
error?.response?.data?.message ??
|
||||
error?.message ??
|
||||
'请求失败'
|
||||
return Promise.reject(new Error(message))
|
||||
},
|
||||
)
|
||||
|
||||
// ----- Type augmentation: express in the types that the response interceptor unwraps -----
|
||||
type UnwrappedAxiosInstance = Omit<
|
||||
AxiosInstance,
|
||||
'get' | 'delete' | 'head' | 'options' | 'post' | 'put' | 'patch'
|
||||
> & {
|
||||
get<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
|
||||
delete<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
|
||||
head<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
|
||||
options<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
|
||||
post<T = unknown, D = unknown>(
|
||||
url: string,
|
||||
data?: D,
|
||||
config?: AxiosRequestConfig,
|
||||
): Promise<T>
|
||||
put<T = unknown, D = unknown>(
|
||||
url: string,
|
||||
data?: D,
|
||||
config?: AxiosRequestConfig,
|
||||
): Promise<T>
|
||||
patch<T = unknown, D = unknown>(
|
||||
url: string,
|
||||
data?: D,
|
||||
config?: AxiosRequestConfig,
|
||||
): Promise<T>
|
||||
}
|
||||
|
||||
const request = _instance as UnwrappedAxiosInstance
|
||||
|
||||
export default request
|
||||
export { baseURL }
|
||||
155
frontend/src/api/tasks.ts
Normal file
@ -0,0 +1,155 @@
|
||||
/**
 * API functions that talk to the FastAPI backend
 * --------------------------------
 * All calls go through the request singleton, so callers receive the business
 * object (the response interceptor has already unwrapped it).
 *
 * Backend routes:
 *   GET  /api/algorithms
 *   POST /api/process/deglint
 *   POST /api/modeling/train
 *   POST /api/modeling/predict   (extra, pairs with train)
 *   GET  /api/tasks/{task_id}
 */
|
||||
import request from './request'
|
||||
|
||||
// ============================================================
|
||||
// 通用类型
|
||||
// ============================================================
|
||||
|
||||
/** 后端任务状态机 (与 app.core.task_store.TASK_STORE 保持一致) */
|
||||
export type TaskStatus = 'PENDING' | 'PROCESSING' | 'SUCCESS' | 'FAILED'
|
||||
|
||||
/** 任务类型, 区分去耀斑 / 训练 / 推断 */
|
||||
export type TaskKind = 'deglint' | 'train' | 'predict'
|
||||
|
||||
/** 提交后端后立即返回的最小任务凭证 */
|
||||
export interface TaskAcceptedResponse {
|
||||
task_id: string
|
||||
status: TaskStatus
|
||||
kind: TaskKind
|
||||
}
|
||||
|
||||
/**
|
||||
* 任务详情 (与后端 TASK_STORE 里记录的字段对齐, 通用 + 各 kind 增量字段)
|
||||
* 用 [key: string]: unknown 兜底, 兼容未来后端新增字段
|
||||
*/
|
||||
export interface TaskRecord {
|
||||
task_id: string
|
||||
kind: TaskKind
|
||||
status: TaskStatus
|
||||
// 去耀斑
|
||||
algorithm?: string
|
||||
input_zarr_path?: string
|
||||
output_zarr_path?: string | null
|
||||
// 训练
|
||||
model_type?: string
|
||||
target?: string
|
||||
train_data_path?: string
|
||||
feature_start?: number | string
|
||||
params?: Record<string, unknown>
|
||||
model_id?: string | null
|
||||
model_path?: string | null
|
||||
test_r2?: number | null
|
||||
test_rmse?: number | null
|
||||
test_mae?: number | null
|
||||
n_features?: number | null
|
||||
n_samples?: number | null
|
||||
// 推断
|
||||
// (model_id / input_zarr_path / output_zarr_path 已在上方)
|
||||
// 失败
|
||||
error?: string | null
|
||||
traceback?: string | null
|
||||
// 元
|
||||
created_at?: string
|
||||
updated_at?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 1) 算法列表 GET /api/algorithms
|
||||
// ============================================================
|
||||
|
||||
export interface AlgorithmInfo {
|
||||
name: string
|
||||
doc?: string
|
||||
}
|
||||
|
||||
export interface AlgorithmListResponse {
|
||||
algorithms: AlgorithmInfo[]
|
||||
count: number
|
||||
}
|
||||
|
||||
export function getAlgorithms(): Promise<AlgorithmListResponse> {
|
||||
return request.get<AlgorithmListResponse>('/api/algorithms')
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 2) 提交去耀斑 POST /api/process/deglint
|
||||
// ============================================================
|
||||
|
||||
export interface DeglintParams {
|
||||
input_zarr_path: string
|
||||
output_zarr_path?: string
|
||||
/** 算法自定义参数 (D_max / band 选择等) */
|
||||
[key: string]: unknown
|
||||
}
|
||||
|
||||
export function submitDeglint(
|
||||
method: string,
|
||||
params: DeglintParams,
|
||||
): Promise<TaskAcceptedResponse> {
|
||||
return request.post<TaskAcceptedResponse, { method: string; params: DeglintParams }>(
|
||||
'/api/process/deglint',
|
||||
{ method, params },
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 3) 提交训练 POST /api/modeling/train
|
||||
// ============================================================
|
||||
|
||||
export interface TrainRequest {
|
||||
model_type: string
|
||||
target: string
|
||||
train_data_path: string
|
||||
/** 特征起始列, int 索引或 str 列名, 默认 4 */
|
||||
feature_start?: number | string
|
||||
/** sklearn 估计器超参 */
|
||||
params?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export function submitTrain(payload: TrainRequest): Promise<TaskAcceptedResponse> {
|
||||
return request.post<TaskAcceptedResponse, TrainRequest>(
|
||||
'/api/modeling/train',
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 4) 提交推断 POST /api/modeling/predict (配套, 训练后才能用)
|
||||
// ============================================================
|
||||
|
||||
export interface PredictRequest {
|
||||
model_id: string
|
||||
input_zarr_path: string
|
||||
output_zarr_path?: string
|
||||
}
|
||||
|
||||
export function submitPredict(
|
||||
payload: PredictRequest,
|
||||
): Promise<TaskAcceptedResponse> {
|
||||
return request.post<TaskAcceptedResponse, PredictRequest>(
|
||||
'/api/modeling/predict',
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 5) 查询任务状态 GET /api/tasks/{task_id}
|
||||
// ============================================================
|
||||
|
||||
export function getTaskStatus(task_id: string): Promise<TaskRecord> {
|
||||
return request.get<TaskRecord>(
|
||||
`/api/tasks/${encodeURIComponent(task_id)}`,
|
||||
)
|
||||
}
|
||||
238
frontend/src/composables/useTaskPoller.ts
Normal file
@ -0,0 +1,238 @@
|
||||
/**
 * Task polling composable (Vue 3 + TypeScript)
 * -----------------------------------------
 * Usage 1: static task_id, polling starts immediately:
 *   const { status, result, error, waitForCompletion } = useTaskPoller(taskId)
 *
 * Usage 2: reactive task_id (assigned once it arrives asynchronously; polling starts automatically):
 *   const taskId = ref<string | null>(null)
 *   const poller = useTaskPoller(taskId)
 *   ;(async () => { taskId.value = (await submitTrain({...})).task_id })()
 *   await poller.waitForCompletion()
 *
 * Usage 3: manual control:
 *   const poller = useTaskPoller()
 *   poller.start(taskId)   // start
 *   poller.stop()          // stop
 *   poller.reset()         // clear state
 *
 * Design notes:
 *   - Polling stops automatically on a terminal state (SUCCESS/FAILED)
 *   - Cleanup runs automatically when the component unmounts (onUnmounted)
 *   - A network error does not end polling at once; it is written to error.value
 *     and polling continues (tolerates transient jitter)
 *   - waitForCompletion is a one-shot promise: SUCCESS resolve(record), FAILED reject(error);
 *     an external stop() also rejects
 */
|
||||
import {
|
||||
onUnmounted,
|
||||
ref,
|
||||
watch,
|
||||
type MaybeRefOrGetter,
|
||||
type Ref,
|
||||
} from 'vue'
|
||||
import { toValue } from 'vue'
|
||||
import {
|
||||
getTaskStatus,
|
||||
type TaskRecord,
|
||||
type TaskStatus,
|
||||
} from '../api/tasks'
|
||||
|
||||
// 显式包含 'idle', 用于未开始轮询的初始态
|
||||
export type PollerStatus = TaskStatus | 'idle'
|
||||
|
||||
export interface UseTaskPollerOptions {
|
||||
/** 轮询间隔 ms, 默认 2000 */
|
||||
intervalMs?: number
|
||||
/** task_id 变 null 时是否自动停止, 默认 true */
|
||||
autoStopOnNull?: boolean
|
||||
}
|
||||
|
||||
export interface UseTaskPollerReturn {
|
||||
/** 当前任务状态, 初始 'idle' */
|
||||
status: Ref<PollerStatus>
|
||||
/** SUCCESS 时的完整任务记录 (含 output_zarr_path / model_id 等) */
|
||||
result: Ref<TaskRecord | null>
|
||||
/** FAILED 时的错误描述, 或轮询过程中网络异常的消息 */
|
||||
error: Ref<string | null>
|
||||
/** 最新一次拉取到的任务记录 (含 PENDING/PROCESSING 占位字段) */
|
||||
record: Ref<TaskRecord | null>
|
||||
/** 是否正在轮询中 */
|
||||
isPolling: Ref<boolean>
|
||||
/** 当前轮询的 task_id (可能为 null) */
|
||||
taskId: Ref<string | null>
|
||||
/** 开始轮询某 task, 已轮询同一 id 时是 no-op */
|
||||
start: (taskId: string) => void
|
||||
/** 主动停止 (会 reject 未完成的 waitForCompletion) */
|
||||
stop: () => void
|
||||
/** 清空所有状态回 'idle' */
|
||||
reset: () => void
|
||||
/**
|
||||
* 等到 SUCCESS/FAILED。
|
||||
* - SUCCESS: resolve(record)
|
||||
* - FAILED : reject(Error)
|
||||
* - stop() : reject(Error('Polling stopped'))
|
||||
* - 组件卸载: reject(Error('Component unmounted'))
|
||||
* 已处于终态时立刻 resolve/reject, 不重复等待。
|
||||
*/
|
||||
waitForCompletion: () => Promise<TaskRecord>
|
||||
}
|
||||
|
||||
export function useTaskPoller(
|
||||
taskIdSource?: MaybeRefOrGetter<string | null>,
|
||||
options: UseTaskPollerOptions = {},
|
||||
): UseTaskPollerReturn {
|
||||
const { intervalMs = 2000, autoStopOnNull = true } = options
|
||||
|
||||
const status = ref<PollerStatus>('idle')
|
||||
const result = ref<TaskRecord | null>(null)
|
||||
const error = ref<string | null>(null)
|
||||
const record = ref<TaskRecord | null>(null)
|
||||
const isPolling = ref(false)
|
||||
const taskId = ref<string | null>(null)
|
||||
|
||||
let timerId: ReturnType<typeof setInterval> | null = null
|
||||
let inFlightTick = false
|
||||
let resolveWait: ((rec: TaskRecord) => void) | null = null
|
||||
let rejectWait: ((err: Error) => void) | null = null
|
||||
|
||||
function clearTimer() {
|
||||
if (timerId !== null) {
|
||||
clearInterval(timerId)
|
||||
timerId = null
|
||||
}
|
||||
}
|
||||
|
||||
function resolveOrRejectWait(rec: TaskRecord | null, err: Error | null) {
|
||||
const r = resolveWait
|
||||
const rj = rejectWait
|
||||
resolveWait = null
|
||||
rejectWait = null
|
||||
if (rec && r) r(rec)
|
||||
else if (err && rj) rj(err)
|
||||
}
|
||||
|
||||
function applyTerminalRecord(rec: TaskRecord) {
|
||||
record.value = rec
|
||||
status.value = rec.status
|
||||
if (rec.status === 'SUCCESS') {
|
||||
result.value = rec
|
||||
error.value = null
|
||||
resolveOrRejectWait(rec, null)
|
||||
} else if (rec.status === 'FAILED') {
|
||||
result.value = null
|
||||
error.value = rec.error ?? '任务失败 (无具体错误信息)'
|
||||
resolveOrRejectWait(
|
||||
null,
|
||||
new Error(error.value ?? '任务失败'),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async function tick() {
|
||||
const currentId = taskId.value
|
||||
if (!currentId || inFlightTick) return
|
||||
inFlightTick = true
|
||||
try {
|
||||
const rec = await getTaskStatus(currentId)
|
||||
// 防止 await 期间用户 stop() / start() 了别的 task
|
||||
if (taskId.value !== currentId) return
|
||||
if (rec.status === 'SUCCESS' || rec.status === 'FAILED') {
|
||||
applyTerminalRecord(rec)
|
||||
stop()
|
||||
} else {
|
||||
// PENDING / PROCESSING 阶段, 更新 record 与 status 供 UI 展示
|
||||
record.value = rec
|
||||
status.value = rec.status
|
||||
}
|
||||
} catch (e) {
|
||||
const msg = e instanceof Error ? e.message : String(e)
|
||||
// 单次失败不立刻终止, 写入 error 但保持轮询
|
||||
error.value = `轮询异常: ${msg}`
|
||||
} finally {
|
||||
inFlightTick = false
|
||||
}
|
||||
}
|
||||
|
||||
function start(nextId: string) {
|
||||
if (!nextId) return
|
||||
// 已在轮询同一 id, 幂等
|
||||
if (taskId.value === nextId && isPolling.value) return
|
||||
clearTimer()
|
||||
taskId.value = nextId
|
||||
status.value = 'idle'
|
||||
result.value = null
|
||||
error.value = null
|
||||
record.value = null
|
||||
isPolling.value = true
|
||||
// 立刻拉一次, 避免 2s 空窗
|
||||
void tick()
|
||||
timerId = setInterval(() => void tick(), intervalMs)
|
||||
}
|
||||
|
||||
function stop() {
|
||||
const wasActive = isPolling.value
|
||||
clearTimer()
|
||||
isPolling.value = false
|
||||
if (wasActive) {
|
||||
resolveOrRejectWait(null, new Error('Polling stopped'))
|
||||
}
|
||||
}
|
||||
|
||||
function reset() {
|
||||
stop()
|
||||
taskId.value = null
|
||||
status.value = 'idle'
|
||||
result.value = null
|
||||
error.value = null
|
||||
record.value = null
|
||||
}
|
||||
|
||||
function waitForCompletion(): Promise<TaskRecord> {
|
||||
const r = record.value
|
||||
if (r && r.status === 'SUCCESS') return Promise.resolve(r)
|
||||
if (r && r.status === 'FAILED') {
|
||||
return Promise.reject(
|
||||
new Error(r.error ?? '任务失败 (无具体错误信息)'),
|
||||
)
|
||||
}
|
||||
return new Promise<TaskRecord>((resolve, reject) => {
|
||||
resolveWait = resolve
|
||||
rejectWait = reject
|
||||
})
|
||||
}
|
||||
|
||||
// 自动模式: 监听外部 taskIdSource
|
||||
if (taskIdSource !== undefined) {
|
||||
const stopWatch = watch(
|
||||
() => toValue(taskIdSource),
|
||||
(newId, oldId) => {
|
||||
if (newId && newId !== oldId) start(newId)
|
||||
else if (!newId && autoStopOnNull) stop()
|
||||
},
|
||||
{ immediate: true },
|
||||
)
|
||||
onUnmounted(() => {
|
||||
stopWatch()
|
||||
reset()
|
||||
resolveOrRejectWait(null, new Error('Component unmounted'))
|
||||
})
|
||||
} else {
|
||||
onUnmounted(() => {
|
||||
reset()
|
||||
resolveOrRejectWait(null, new Error('Component unmounted'))
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
status,
|
||||
result,
|
||||
error,
|
||||
record,
|
||||
isPolling,
|
||||
taskId,
|
||||
start,
|
||||
stop,
|
||||
reset,
|
||||
waitForCompletion,
|
||||
}
|
||||
}
|
||||
9
frontend/src/main.ts
Normal file
@ -0,0 +1,9 @@
|
||||
import { createApp } from 'vue'
|
||||
import ElementPlus from 'element-plus'
|
||||
import 'element-plus/dist/index.css'
|
||||
import App from './App.vue'
|
||||
|
||||
const app = createApp(App)
|
||||
|
||||
app.use(ElementPlus)
|
||||
app.mount('#app')
|
||||
24
frontend/tsconfig.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"useDefineForClassFields": true,
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"jsx": "preserve",
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"esModuleInterop": true,
|
||||
"lib": ["ES2020", "DOM", "DOM.Iterable"],
|
||||
"skipLibCheck": true,
|
||||
"noEmit": true,
|
||||
"allowImportingTsExtensions": true,
|
||||
"baseUrl": ".",
|
||||
"paths": {
|
||||
"@/*": ["src/*"]
|
||||
},
|
||||
"types": ["vite/client"]
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.vue", "env.d.ts"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
}
|
||||
12
frontend/tsconfig.node.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"composite": true,
|
||||
"skipLibCheck": true,
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"strict": true,
|
||||
"types": ["node"]
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
21
frontend/vite.config.ts
Normal file
@ -0,0 +1,21 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import { fileURLToPath, URL } from 'node:url'
|
||||
|
||||
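// Vite configuration:
// - @ -> frontend/src
// - the dev server listens on 0.0.0.0:5173 so real devices on the LAN can connect
//   (strictPort is false, so Vite falls back to the next free port if 5173 is taken)
// - VITE_API_BASE_URL is injected via .env.development; when unset, the fallback in src/api/request.ts applies (http://127.0.0.1:8000)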
|
||||
export default defineConfig({
|
||||
plugins: [vue()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': fileURLToPath(new URL('./src', import.meta.url)),
|
||||
},
|
||||
},
|
||||
server: {
|
||||
host: '0.0.0.0',
|
||||
port: 5173,
|
||||
strictPort: false,
|
||||
},
|
||||
})
|
||||
8
license.lic
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"product": "WaterQualityInversion",
|
||||
"machine_code": "76E4992A5CF08BA570D6150908E04755",
|
||||
"generated_at": "2026-05-28 14:21:35",
|
||||
"expiry": "2099-12-31",
|
||||
"signature": "DC9AB900D7033A281E54F41F3F76D026FFA75D635484D40C7F6FC1F6023E02AB"
|
||||
}
|
||||
201
new/app/api/_smoke_test_train.py
Normal file
@ -0,0 +1,201 @@
|
"""
Smoke test for _run_train_sync: runs the real training pipeline end to end on synthetic data.
Does not depend on FastAPI / xarray / dask; it only checks training, persistence, and reloading.
"""
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# Avoid going through main.py to import the app package (only the modeling module is imported).
# This file lives at new/app/api/_smoke_test_train.py;
# the app package is at new/app/__init__.py, so new/ must be on sys.path.
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from app.api.modeling import (
|
||||
_get_model_pipeline,
|
||||
_load_train_df,
|
||||
_resolve_feature_start,
|
||||
_run_train_sync,
|
||||
_MODEL_CLASS_REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
|
"""Generate a CSV laid out as [lat, lon, target, lat2, lon2, feat_0, feat_1, ...]."""
|
||||
rng = np.random.default_rng(seed)
|
||||
lat = rng.uniform(20, 25, n_samples)
|
||||
lon = rng.uniform(110, 115, n_samples)
|
||||
target = rng.uniform(0, 50, n_samples)
|
||||
lat2 = rng.uniform(0, 1, n_samples) # 元数据
|
||||
lon2 = rng.uniform(0, 1, n_samples) # 元数据
|
||||
feats = rng.normal(0, 1, (n_samples, n_features))
|
||||
# Inject the target into the first 3 features so they correlate with y; RF should then reach at least R² > 0.5
|
||||
feats[:, 0] += target / 10
|
||||
feats[:, 1] += target / 20
|
||||
feats[:, 2] -= target / 15
|
||||
|
||||
df = pd.DataFrame({
|
||||
"lat": lat,
|
||||
"lon": lon,
|
||||
"Chl-a": target,
|
||||
"lat2": lat2,
|
||||
"lon2": lon2,
|
||||
**{f"feat_{i}": feats[:, i] for i in range(n_features)},
|
||||
})
|
||||
tmp = Path(tempfile.mkdtemp()) / "train.csv"
|
||||
df.to_csv(tmp, index=False)
|
||||
return tmp
|
||||
|
||||
|
||||
def test_load_train_df():
|
||||
print("== test_load_train_df ==")
|
||||
p = make_synthetic_csv(n_samples=50)
|
||||
df = _load_train_df(str(p))
|
||||
assert df.shape == (50, 5 + 8), f"shape={df.shape}"
|
||||
print(f" shape={df.shape}, columns[:6]={list(df.columns[:6])}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_int_and_str():
|
||||
print("== test_resolve_feature_start (int + str) ==")
|
||||
p = make_synthetic_csv()
|
||||
df = _load_train_df(str(p))
|
||||
idx_int = _resolve_feature_start(df, 5)
|
||||
idx_str = _resolve_feature_start(df, "feat_0")
|
||||
assert idx_int == 5 == idx_str, f"int={idx_int}, str={idx_str}"
|
||||
print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_str_miss():
    print("== test_resolve_feature_start (missing str -> raises) ==")
    p = make_synthetic_csv()
    df = _load_train_df(str(p))
    try:
        _resolve_feature_start(df, "not_exist")
    except ValueError as e:
        print(f"  raised ValueError as expected: {e}")
    else:
        # Previously this branch only printed FAIL and the test still reported PASS.
        raise AssertionError("expected ValueError")
    print("  PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_all_types():
|
||||
print("== test_get_model_pipeline (5 种 model_type) ==")
|
||||
for mt in ["RF", "SVR", "LinearRegression", "KNN", "PLS"]:
|
||||
p = _get_model_pipeline(mt, {})
|
||||
assert len(p.steps) == 2
|
||||
assert p.steps[0][0] == "scaler"
|
||||
assert p.steps[1][0] == "model"
|
||||
print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_bad_type():
    print("== test_get_model_pipeline (bad model_type) ==")
    try:
        _get_model_pipeline("XGBoost", {})
    except ValueError as e:
        print(f"  raised ValueError as expected: {e}")
    else:
        # Fail the test instead of printing FAIL and then PASS.
        raise AssertionError("expected ValueError")
    print("  PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_rf_end_to_end():
|
||||
print("== test_run_train_sync (RF 端到端) ==")
|
||||
p = make_synthetic_csv(n_samples=200)
|
||||
out_dir = Path(tempfile.mkdtemp())
|
||||
out_path = out_dir / "model.joblib"
|
||||
|
||||
import time
|
||||
t0 = time.time()
|
||||
metadata = _run_train_sync(
|
||||
model_type="RF",
|
||||
target="Chl-a",
|
||||
train_data_path=str(p),
|
||||
feature_start=5,
|
||||
params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
|
||||
output_model_path=out_path,
|
||||
)
|
||||
dt = time.time() - t0
|
||||
|
||||
assert out_path.exists(), f"joblib 未落盘: {out_path}"
|
||||
print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
|
||||
print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
|
||||
print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
|
||||
print(f" 耗时 {dt:.2f}s")
|
||||
|
||||
# Round-trip check: reload the joblib file and verify the model / metadata payload
|
||||
import joblib
|
||||
saved = joblib.load(out_path)
|
||||
assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
|
||||
assert hasattr(saved["model"], "predict")
|
||||
assert saved["metadata"]["test_r2"] == metadata["test_r2"]
|
||||
print(f" joblib 加载 OK, 含 'model' 和 'metadata' 双 key")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_linearregression_fast():
|
||||
print("== test_run_train_sync (LinearRegression 快速路径) ==")
|
||||
p = make_synthetic_csv(n_samples=150)
|
||||
out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
|
||||
metadata = _run_train_sync(
|
||||
model_type="LinearRegression",
|
||||
target="Chl-a",
|
||||
train_data_path=str(p),
|
||||
feature_start=5,
|
||||
params={},
|
||||
output_model_path=out_path,
|
||||
)
|
||||
print(f" test_r2={metadata['test_r2']:.4f} (LR should capture the linear signal; expect R² > 0.3)")
assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} is too low; the synthetic data may be wrong"
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_csv():
    print("== test_run_train_sync (CSV does not exist) ==")
    try:
        _run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
    except (FileNotFoundError, ValueError) as e:
        print(f"  raised {type(e).__name__} as expected: {e}")
    else:
        # Fail the test instead of printing FAIL and then PASS.
        raise AssertionError("expected FileNotFoundError or ValueError")
    print("  PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_target():
    print("== test_run_train_sync (target column does not exist) ==")
    p = make_synthetic_csv()
    try:
        _run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
    except ValueError as e:
        print(f"  raised ValueError as expected: {e}")
    else:
        # Fail the test instead of printing FAIL and then PASS.
        raise AssertionError("expected ValueError")
    print("  PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_str_feature_start():
|
||||
print("== test_run_train_sync (feature_start 用列名) ==")
|
||||
p = make_synthetic_csv()
|
||||
out_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
|
||||
metadata = _run_train_sync("RF", "Chl-a", str(p), "feat_0", {"n_estimators": 10}, out_path)
|
||||
assert metadata["feature_start"] == "feat_0"
|
||||
assert metadata["n_features"] == 8
|
||||
assert metadata["feature_columns"][0] == "feat_0"
|
||||
print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_train_df()
|
||||
test_resolve_feature_start_int_and_str()
|
||||
test_resolve_feature_start_str_miss()
|
||||
test_get_model_pipeline_all_types()
|
||||
test_get_model_pipeline_bad_type()
|
||||
test_run_train_sync_rf_end_to_end()
|
||||
test_run_train_sync_linearregression_fast()
|
||||
test_run_train_sync_bad_csv()
|
||||
test_run_train_sync_bad_target()
|
||||
test_run_train_sync_str_feature_start()
|
||||
print("\n>>> ALL SMOKE TESTS PASSED")
|
||||
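The negative tests in this file all follow the same try / except / else shape. A minimal sketch of how that pattern could be factored into one helper is shown here; `expect_raises` and `parse_positive` are hypothetical names used only for illustration and are not part of this repository.

```python
from contextlib import contextmanager


@contextmanager
def expect_raises(*exc_types):
    """Fail loudly unless the wrapped block raises one of exc_types (hypothetical helper)."""
    try:
        yield
    except exc_types as exc:
        # Expected path: the exception is swallowed and the test continues.
        print(f"  raised {type(exc).__name__} as expected: {exc}")
    else:
        # Nothing was raised, so the test must fail rather than print PASS.
        names = ", ".join(t.__name__ for t in exc_types)
        raise AssertionError(f"expected one of: {names}")


def parse_positive(text: str) -> int:
    """Toy function under test: rejects non-positive integers."""
    value = int(text)
    if value <= 0:
        raise ValueError(f"not positive: {value}")
    return value


if __name__ == "__main__":
    with expect_raises(ValueError):
        parse_positive("-3")
    print("  PASS")
```

With such a helper, a missing exception surfaces as an AssertionError, so the final summary line can only be reached when every negative case really raised.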
222
new/app/api/endpoints.py
Normal file
@ -0,0 +1,222 @@
|
"""
API routes
==========

All business endpoints are collected on one APIRouter, which main.py mounts via include_router.

Endpoints currently included:
    GET  /api/algorithms        list every registered deglint algorithm (for the frontend dropdown)
    POST /api/process/deglint   submit a deglint task; returns task_id immediately
    GET  /api/tasks/{task_id}   query the status and result of a task

Dispatch chain:
    POST /api/process/deglint
      └─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
           └─ get_remover(method)    look up the algorithm class in the registry
                └─ remover.process(input_zarr, output_zarr, **params)
"""
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Concurrency-safe task state store (replaces the old MOCK_TASK_DB)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
# Algorithm registry API
|
||||
from app.core.algorithms import get_remover, list_removers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# The prefix is not set here; it is supplied when main.py mounts the router, which makes
# versioned mounting easier later (e.g. reusing this router object under both /api/v1 and /api/v2).
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(tags=["deglint"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求 / 响应数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class DeglintRequest(BaseModel):
|
||||
"""POST /api/process/deglint 的请求体"""
|
||||
|
||||
method: str = Field(
|
||||
...,
|
||||
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
|
||||
examples=["kutser"],
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description=(
|
||||
"传递给算法 process() 的超参数字典,例如 "
|
||||
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
|
||||
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
|
||||
),
|
||||
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
|
||||
)
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
"""提交任务成功后立即返回的响应"""
|
||||
|
||||
task_id: str
|
||||
status: str # always "PENDING"
|
||||
|
||||
|
||||
class AlgorithmListResponse(BaseModel):
|
||||
"""GET /api/algorithms 的响应"""
|
||||
|
||||
algorithms: list # 已注册算法名列表
|
||||
count: int # 算法总数
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 后台任务执行器(真实派发链)
|
||||
# ---------------------------------------------------------------------------
|
||||
# Note: this is an async def.
# FastAPI / Starlette BackgroundTasks accept async functions and await them
# after the response has been sent, so the request path is not delayed.
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_glint_removal_task(
|
||||
task_id: str,
|
||||
method: str,
|
||||
params: Dict[str, Any],
|
||||
) -> None:
|
"""
Background executor: look up the algorithm class by method name in the registry, instantiate it, and run process().

State machine:
    PENDING -> PROCESSING -> SUCCESS
                    └──> FAILED (with error / traceback)
"""
|
||||
# 0. 安全检查:任务记录必须已存在(POST 阶段已写入)
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 任务不存在, 跳过")
|
||||
return
|
||||
|
||||
# 1. 状态推进到 PROCESSING
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 开始处理 method={method} params={params}")
|
||||
|
||||
# 2. Temporarily hard-coded IO paths (to be provided by the data management layer)
# TODO: replace with the zarr paths returned by the real data management service
|
||||
input_zarr_path = "./data/temp_in.zarr"
|
||||
output_zarr_path = f"./data/{task_id}_out.zarr"
|
||||
|
||||
try:
|
||||
# 3. Look up the algorithm class by method name and instantiate it
# get_remover raises KeyError for an unknown method; the except below catches it
|
||||
algorithm_cls = get_remover(method)
|
||||
remover = algorithm_cls()
|
||||
|
||||
# 4. 调用算法(注意 await,因为 BaseGlintRemover.process 是 async)
|
||||
await remover.process(input_zarr_path, output_zarr_path, **params)
|
||||
|
||||
# 5. 成功:写回结果路径与状态
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
output_zarr_path=output_zarr_path,
|
||||
error=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
|
||||
|
||||
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
|
||||
# 6. 失败:记录错误信息与堆栈,便于前端排查
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
output_zarr_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=traceback.format_exc(),
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /algorithms
|
||||
# ---------------------------------------------------------------------------
|
||||
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/algorithms", response_model=AlgorithmListResponse)
|
||||
async def list_registered_algorithms() -> Dict[str, Any]:
|
||||
"""列出已注册的去耀斑算法。"""
|
||||
names = list(list_removers().keys())
|
||||
return {"algorithms": names, "count": len(names)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /process/deglint
|
||||
# ---------------------------------------------------------------------------
|
||||
# Submit a deglint task. FastAPI only sends the response after the handler returns,
# so the slow work is handed to BackgroundTasks and the endpoint returns task_id right away.
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
|
||||
async def submit_deglint(
|
||||
payload: DeglintRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
|
||||
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
|
||||
|
||||
# 1. 生成唯一任务 ID(UUID4 足以保证全局唯一性)
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
|
||||
# 注意:output_zarr_path / error / traceback 字段在执行过程中被填充
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"method": payload.method,
|
||||
"params": payload.params,
|
||||
"status": "PENDING",
|
||||
"output_zarr_path": None,
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# 3. 把真实执行器丢到后台
|
||||
background_tasks.add_task(
|
||||
execute_glint_removal_task,
|
||||
task_id,
|
||||
payload.method,
|
||||
payload.params,
|
||||
)
|
||||
|
||||
# 4. 立即返回 task_id 与 PENDING 状态
|
||||
return {"task_id": task_id, "status": "PENDING"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tasks/{task_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
# The frontend polls this endpoint for task status. PENDING / PROCESSING mean still running,
# SUCCESS means done (with output_zarr_path), FAILED means failed (with error / traceback).
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task_status(task_id: str) -> Dict[str, Any]:
|
||||
"""查询指定任务的当前状态与结果。"""
|
||||
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
# 找不到 task_id 通常意味着客户端拼错了 ID,或者记录已被清理
|
||||
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
|
||||
|
||||
# 直接返回字典,FastAPI 会自动 JSON 序列化
|
||||
return record
|
||||
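The PENDING -> PROCESSING -> SUCCESS / FAILED state machine used by this file can be sketched in isolation. The sketch below is only an illustration under stated assumptions: `TaskStore`, `run_task`, and `demo` are hypothetical stand-ins, since the real app.core.task_store implementation is not shown in this diff.

```python
import asyncio
import traceback
from typing import Any, Awaitable, Callable, Dict, Optional


class TaskStore:
    """Hypothetical in-memory store guarded by an asyncio.Lock."""

    def __init__(self) -> None:
        self._tasks: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()

    async def set(self, task_id: str, record: Dict[str, Any]) -> None:
        async with self._lock:
            self._tasks[task_id] = dict(record)

    async def update(self, task_id: str, **fields: Any) -> None:
        async with self._lock:
            self._tasks[task_id].update(fields)

    async def get(self, task_id: str) -> Optional[Dict[str, Any]]:
        async with self._lock:
            record = self._tasks.get(task_id)
            return dict(record) if record is not None else None


async def run_task(
    store: TaskStore, task_id: str, job: Callable[[], Awaitable[Any]]
) -> None:
    """Drive one task through PENDING -> PROCESSING -> SUCCESS / FAILED."""
    if await store.get(task_id) is None:
        return  # never registered: nothing to do
    await store.update(task_id, status="PROCESSING")
    try:
        result = await job()
        await store.update(task_id, status="SUCCESS", result=result, error=None)
    except Exception as exc:  # top-level catch so no task fails silently
        await store.update(
            task_id,
            status="FAILED",
            error=f"{type(exc).__name__}: {exc}",
            traceback=traceback.format_exc(),
        )


async def demo():
    store = TaskStore()

    async def ok() -> str:
        return "out.zarr"

    async def boom() -> str:
        raise KeyError("unknown method")

    await store.set("a", {"status": "PENDING"})
    await store.set("b", {"status": "PENDING"})
    await asyncio.gather(run_task(store, "a", ok), run_task(store, "b", boom))
    return await store.get("a"), await store.get("b")


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Returning copies from `get` keeps callers from mutating stored records outside the lock, which is one plausible reason for routing every write through an `update` call.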
786
new/app/api/modeling.py
Normal file
@ -0,0 +1,786 @@
|
"""
app/api/modeling.py
===================

API routes for machine learning and water-quality inversion.

Endpoints (final paths, after mounting):
    POST /api/modeling/train     submit a model training task; returns task_id immediately
    GET  /api/modeling/models    list trained models (to be read from joblib files on disk in the future)
    POST /api/modeling/predict   submit a model inference task; returns task_id immediately

Design notes
------------
- Training and inference both run as asynchronous background tasks and reuse the
  concurrency-safe task state in app.core.task_store.
- Model metadata is held in the module-level _MODEL_REGISTRY (in-memory storage for development);
  switching to on-disk joblib files later only requires replacing the internals of list_trained_models().
- /predict runs real streaming inference with sklearn + xarray + dask:
    * joblib.load reads the model (falls back to a Dummy RandomForestRegressor when the file is missing)
    * xr.open_zarr opens the image lazily; NaN is filled with 0
    * xr.apply_ufunc(dask="parallelized") calls model.predict chunk by chunk along (y, x)
    * to_zarr(mode="w", compute=True) streams the output, so peak memory is about one chunk
- /train runs a real sklearn + pandas training pipeline:
    * pd.read_csv reads the structured training table (supports the [lat, lon, target_*, feature_*] layout)
    * rows are cleaned with dropna on the target column; features are sliced by the feature_start index or column name
    * sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
    * train_test_split (80/20), then test_r2 / rmse / mae are computed
    * joblib.dump({model, metadata}) writes ./data/models/{model_id}.joblib
    * test metrics are written back to TASK_STORE and the model is registered in _MODEL_REGISTRY
  Note: the legacy SPXY / KS splits are left for a future extension; the split is
  currently fixed to random (test_size=0.2, random_state=42).
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVR
|
||||
|
||||
# Reuse the concurrency-safe task state store (the same TASK_STORE that deglint uses;
# the "kind" field in each task record distinguishes train / predict / deglint)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# prefix="/modeling" lets this file use short paths such as /train, /models, /predict;
# main.py adds the leading /api when it mounts the router.
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(prefix="/modeling", tags=["modeling"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class TrainRequest(BaseModel):
|
||||
"""POST /api/modeling/train 的请求体"""
|
||||
|
||||
model_type: str = Field(
|
||||
...,
|
||||
description="Model type; must be one of the registered types: 'RF' (random forest) / 'SVR' (support vector regression) / 'LinearRegression' / 'KNN' / 'PLS'",
|
||||
examples=["RF", "SVR"],
|
||||
)
|
||||
target: str = Field(
|
||||
...,
|
||||
description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
|
||||
examples=["Chl-a", "TSS", "CDOM"],
|
||||
)
|
||||
train_data_path: str = Field(
|
||||
...,
|
||||
description="Path to the training dataset CSV (columns laid out as [lat, lon, target_*, feature_*])",
examples=["./data/train.csv"],
|
||||
)
|
||||
feature_start: Union[int, str] = Field(
|
||||
default=4,
|
||||
description=(
|
||||
"特征列起始位置. 表格布局假定为 "
|
||||
"[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
|
||||
"可传 int 列索引(如 4)或 str 列名(如 '374.285' 波长起点)。"
|
||||
"默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
|
||||
),
|
||||
examples=[4, "374.285"],
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
|
||||
examples=[{"n_estimators": 100, "max_depth": 20}],
|
||||
)
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""POST /api/modeling/predict 的请求体"""
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="已训练模型的 ID(由 /api/modeling/train 返回或 /api/modeling/models 列出)",
|
||||
)
|
||||
input_zarr_path: str = Field(
|
||||
...,
|
||||
description="待推断影像的 zarr 路径",
|
||||
examples=["./data/scene.zarr"],
|
||||
)
|
||||
output_zarr_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"输出 zarr 路径, 缺省时由后端按规则生成 "
|
||||
"(如 ./data/{model_id}_{input_stem}_pred.zarr)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
"""提交训练/推断任务后立即返回的响应"""
|
||||
|
||||
task_id: str
|
||||
status: str # 一定是 PENDING
|
||||
kind: str # "train" / "predict", 便于前端识别任务类型
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""单个模型的元信息(GET /api/modeling/models 的元素)"""
|
||||
|
||||
model_id: str
|
||||
model_type: str
|
||||
target: str
|
||||
params: Dict[str, Any]
|
||||
path: str # joblib 文件路径
|
||||
created_at: str
|
||||
train_task_id: str # 产生此模型的那个训练任务的 ID
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
"""GET /api/modeling/models 的响应"""
|
||||
|
||||
models: List[ModelInfo]
|
||||
count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
|
||||
# ---------------------------------------------------------------------------
|
||||
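# Maps model_id -> ModelInfo dict.
# Read-mostly, so a plain dict is enough (individual dict operations are atomic under the CPython GIL).
# Each model is written once, when its training finishes, so there is no concurrent-write risk.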
|
||||
# ---------------------------------------------------------------------------
|
||||
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _register_model(record: Dict[str, Any]) -> None:
|
||||
"""将训练完成的模型登记到内存注册表。"""
|
||||
_MODEL_REGISTRY[record["model_id"]] = record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 训练管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
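# Design notes (same as the inference pipeline):
# 1) Module-level functions: if a dask / joblib backend pickles work into subprocesses, nested closures lose fields.
# 2) Synchronous execution: execute_train_task dispatches via asyncio.to_thread; everything inside blocks synchronously.
# 3) Failures raise: execute_train_task catches the exception and records FAILED plus the traceback.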
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# model_type (string key) -> sklearn estimator class
# Same idea as OpenClaw's model_configs, but only the class is kept here (arguments are passed through via params)
|
||||
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
|
||||
"RF": RandomForestRegressor,
|
||||
"SVR": SVR,
|
||||
"LinearRegression": LinearRegression,
|
||||
"KNN": KNeighborsRegressor,
|
||||
"PLS": PLSRegression,
|
||||
}
|
||||
|
||||
|
||||
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
|
"""
Model factory: pick the sklearn class for model_type and build a Pipeline of StandardScaler + estimator.

Unlike OpenClaw, the scaler is the first Pipeline step, so inference only needs
pipeline.predict(X) and the scaler parameters are exactly those fitted during training.
"""
|
||||
model_cls = _MODEL_CLASS_REGISTRY.get(model_type)
|
||||
if model_cls is None:
|
||||
raise ValueError(
|
||||
f"不支持的 model_type='{model_type}', "
|
||||
f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
|
||||
)
|
||||
estimator = model_cls(**(params or {}))
|
||||
return Pipeline([("scaler", StandardScaler()), ("model", estimator)])
|
||||
|
||||
|
||||
def _load_train_df(csv_path: str) -> pd.DataFrame:
|
"""
Read the CSV training table and normalize empty strings / whitespace / NULL to NaN.

Follows the read strategy of OpenClaw modeling_batch.load_data_batch:
an explicit na_values list plus a second regex pass (guards against cells that contain only whitespace).
"""
|
||||
try:
|
||||
df = pd.read_csv(
|
||||
csv_path,
|
||||
na_values=["", " ", "NaN", "nan", "NULL", "null"],
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
|
||||
except pd.errors.EmptyDataError as exc:
|
||||
raise ValueError(f"训练数据文件为空: {csv_path}") from exc
|
||||
# 二次清理: 残留的纯空白 cell
|
||||
df = df.replace(r"^\s*$", np.nan, regex=True)
|
||||
return df
|
||||
|
||||
|
||||
def _resolve_feature_start(
|
||||
df: pd.DataFrame,
|
||||
feature_start: Union[int, str],
|
||||
) -> int:
|
"""
Resolve feature_start (int index or str column name) to an int column index.

Consistent with OpenClaw modeling_batch.load_data_batch / load_data_single:
a str goes through columns.get_loc, an int is returned as-is.
"""
|
||||
if isinstance(feature_start, str):
|
||||
if feature_start not in df.columns:
|
||||
raise ValueError(
|
||||
f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
|
||||
)
|
||||
return int(df.columns.get_loc(feature_start))
|
||||
return int(feature_start)
|
||||
|
||||
|
||||
def _run_train_sync(
|
||||
model_type: str,
|
||||
target: str,
|
||||
train_data_path: str,
|
||||
feature_start: Union[int, str],
|
||||
params: Optional[Dict[str, Any]],
|
||||
output_model_path: Path,
|
||||
) -> Dict[str, Any]:
|
"""
Full synchronous training flow (called by execute_train_task inside a worker thread):

pd.read_csv -> dropna on the target column -> slice features -> train_test_split (80/20)
-> Pipeline(StandardScaler + model).fit -> evaluate test_r2 / rmse / mae
-> joblib.dump({model, metadata}, output_model_path)

Returns:
    A metadata dict with test_r2 / test_rmse / test_mae / n_features and more;
    the caller is responsible for writing it back to TASK_STORE and _MODEL_REGISTRY.

Note: the legacy SPXY / KS splits are left for a future extension (to be selected via params.split_method);
the split is currently fixed to random with test_size=0.2 and random_state=42.
"""
|
||||
df = _load_train_df(train_data_path)
|
||||
|
||||
if target not in df.columns:
|
||||
raise ValueError(
|
||||
f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
|
||||
)
|
||||
|
||||
# 1) Cleaning: drop only the rows whose target is NaN (consistent with OpenClaw load_data_single)
|
||||
df = df[df[target].notna()].copy()
|
||||
if df.empty:
|
||||
raise ValueError("target 剔除 NaN 后无样本, 终止训练")
|
||||
|
||||
# 2) 特征切分
|
||||
feature_start_idx = _resolve_feature_start(df, feature_start)
|
||||
feature_columns = list(df.columns[feature_start_idx:])
|
||||
|
||||
X = df.iloc[:, feature_start_idx:].astype(np.float64)
|
||||
y = df[target].astype(np.float64).values
|
||||
|
||||
# 3) Split (fixed random split; spxy / ks are future extensions)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X.values,
|
||||
y,
|
||||
test_size=0.2,
|
||||
random_state=42,
|
||||
)
|
||||
|
||||
# 4) 构造 Pipeline + 训练
|
||||
pipeline = _get_model_pipeline(model_type, params)
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# 5) Evaluate on both the test set and the training set
|
||||
y_pred = pipeline.predict(X_test)
|
||||
test_r2 = float(r2_score(y_test, y_pred))
|
||||
test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
|
||||
test_mae = float(mean_absolute_error(y_test, y_pred))
|
||||
|
||||
y_train_pred = pipeline.predict(X_train)
|
||||
train_r2 = float(r2_score(y_train, y_train_pred))
|
||||
train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
|
||||
train_mae = float(mean_absolute_error(y_train, y_train_pred))
|
||||
|
||||
metadata: Dict[str, Any] = {
|
||||
"model_type": model_type,
|
||||
"target": target,
|
||||
"feature_start": feature_start,
|
||||
"feature_columns": feature_columns,
|
||||
"n_features": int(X.shape[1]),
|
||||
"n_samples": int(X.shape[0]),
|
||||
"train_size": int(X_train.shape[0]),
|
||||
"test_size": int(X_test.shape[0]),
|
||||
"params": dict(params or {}),
|
||||
"test_r2": test_r2,
|
||||
"test_rmse": test_rmse,
|
||||
"test_mae": test_mae,
|
||||
"train_r2": train_r2,
|
||||
"train_rmse": train_rmse,
|
||||
"train_mae": train_mae,
|
||||
"split_method": "random",
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 6) Persist (the directory may not exist yet, so create it)
|
||||
output_model_path = Path(output_model_path)
|
||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(
|
||||
{"model": pipeline, "metadata": metadata},
|
||||
output_model_path,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 推断管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
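# Design notes:
# 1) Under dask scheduling, functions must be picklable by worker processes.
#    _predict_block / _load_model / _make_dummy_model / _run_predict_sync are therefore
#    all module-level functions (not nested), which avoids closure pitfalls.
# 2) _predict_block predicts a whole batch with model.predict(spectra_2d).
#    Predicting the full image at once, O(n_pixels * n_bands), would exhaust memory on large
#    rasters, so the outer layer uses xr.apply_ufunc(dask="parallelized") to split the array
#    and feed it in block by block; peak memory per call is about one (y_chunk, x_chunk, band) block.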
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
|
||||
"""
|
||||
构造一个 Dummy 随机森林回归器。
|
||||
|
||||
用途:
|
||||
1) 真实 joblib 文件不存在时的连通性测试
|
||||
2) 训练骨架尚未接入真实数据时的占位推断
|
||||
"""
|
||||
rng = np.random.default_rng(42)
|
||||
X = rng.random((200, n_features))
|
||||
y = rng.random(200)
|
||||
model = RandomForestRegressor(
|
||||
n_estimators=10, max_depth=5, random_state=0, n_jobs=1
|
||||
)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
|
||||
def _load_model(path: str, n_features: int) -> Any:
    """
    Load a trained sklearn model, falling back to a Dummy model on failure.

    Priority:
      1) path exists, is non-empty, and joblib.load succeeds -> return the loaded object
      2) otherwise -> fall back to a Dummy random forest (n_features is required)
    """
    p = Path(path)
    if p.is_file() and p.stat().st_size > 0:
        try:
            print(f"[model] loading from disk: {path}")
            return joblib.load(path)
        except Exception as exc:  # noqa: BLE001
            print(f"[model] joblib.load failed ({type(exc).__name__}: {exc}), falling back to Dummy RandomForest")
    else:
        # Only report a missing file when the file really is missing or empty.
        print(f"[model] joblib file missing or empty ({path}), falling back to Dummy RandomForest")
    return _make_dummy_model(n_features)
|
||||
|
||||
|
||||
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
    """
    Inference function for a single dask chunk (invoked by xr.apply_ufunc).

    Parameters
    ----------
    spectra_3d : np.ndarray
        Shape (y_chunk, x_chunk, n_bands).
        The shape follows from input_core_dims=[["band"]]: xarray moves the
        band dimension to the last axis and calls this function once per
        (y, x) chunk.
    model : fitted sklearn estimator
        Accepts (n_samples, n_features) input and returns (n_samples,) predictions.

    Returns
    -------
    np.ndarray
        Scalar prediction map of shape (y_chunk, x_chunk), dtype float32.
    """
    yc, xc, nb = spectra_3d.shape
    # Flatten to 2D: one spectrum per pixel row
    flat = spectra_3d.reshape(yc * xc, nb)
    # sklearn-style batch prediction
    pred = model.predict(flat)
    # Restore the 2D spatial map; float32 uses half the memory of float64
    return pred.reshape(yc, xc).astype(np.float32, copy=False)
|
||||
|
||||
|
||||
def _run_predict_sync(
    model: Any,
    model_id: str,
    input_zarr_path: str,
    output_zarr_path: str,
) -> None:
    """
    Synchronous inference driver (called via asyncio.to_thread).

    Steps:
      1) open the input lazily with xr.open_zarr (dask arrays, nothing loaded eagerly)
      2) replace NaN with 0 (many sklearn estimators reject NaN input)
      3) xr.apply_ufunc calls _predict_block chunk by chunk over (y, x)
      4) set non-water pixels to NaN (zarr stores float NaN)
      5) to_zarr triggers computation of the whole graph and writes chunks as they finish
    """
    # 1. Open the input lazily (key point: dask does not load it all at once)
    ds = xr.open_zarr(input_zarr_path, chunks="auto")
    if "reflectance" not in ds.data_vars:
        raise KeyError(
            f"input zarr has no 'reflectance' variable; found: {list(ds.data_vars)}"
        )

    reflectance = ds["reflectance"]  # dims: (y, x, band)
    n_bands = reflectance.sizes["band"]

    # 2. Water mask (same convention as the glint-removal algorithms)
    if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
        water_mask = ds["water_mask"].astype(bool)
    else:
        water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)

    # 3. NaN cleanup: fill with 0 before calling model.predict
    refl_clean = reflectance.fillna(0.0)

    # 4. Core step: apply model.predict over (y, x) with apply_ufunc.
    #    dask="parallelized" makes every (y_chunk, x_chunk, band) chunk call
    #    _predict_block independently, so only a few chunks are in memory at a time.
    prediction: xr.DataArray = xr.apply_ufunc(
        _predict_block,
        refl_clean,
        kwargs={"model": model},
        input_core_dims=[["band"]],
        output_core_dims=[[]],
        dask="parallelized",
        output_dtypes=[np.float32],
        dask_gufunc_kwargs={"allow_rechunk": True},
        vectorize=False,
    )

    # 5. Set non-water pixels to NaN (eases later visualisation / mask analysis)
    prediction = prediction.where(water_mask, np.nan)

    # 6. Wrap in a Dataset and stream it out
    out = xr.Dataset(
        {"prediction": prediction},
        attrs={
            "model_id": model_id,
            "input_zarr_path": input_zarr_path,
            "n_bands": n_bands,
            "created_at": datetime.now().isoformat(),
        },
    )
    # Keep the y/x coordinates
    out = out.assign_coords(y=ds["y"], x=ds["x"])

    # to_zarr(compute=True) evaluates the whole dask graph.
    # Chunks are dispatched to the thread pool one by one, so peak memory scales
    # with the number of chunks in flight, not with the full image size.
    out.to_zarr(output_zarr_path, mode="w", compute=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background task executors
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_train_task(
|
||||
task_id: str,
|
||||
model_type: str,
|
||||
target: str,
|
||||
train_data_path: str,
|
||||
feature_start: Union[int, str],
|
||||
params: Dict[str, Any],
|
||||
) -> None:
    """
    Background executor for training tasks (wired to the real sklearn training flow).

    Steps:
      1) get_task verifies that the task exists
      2) update_task(PROCESSING)
      3) generate model_id / model_path
      4) asyncio.to_thread dispatches _run_train_sync to the default thread pool
      5) success -> _register_model + update_task(SUCCESS, with test_r2/rmse/mae)
      6) failure -> update_task(FAILED, with error + traceback)
    """
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 训练任务不存在, 跳过")
|
||||
return
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(
|
||||
f"[{task_id}] 开始训练 model_type={model_type} target={target} "
|
||||
f"train_data_path={train_data_path} feature_start={feature_start}"
|
||||
)
|
||||
|
||||
    # model_id uses the first 12 hex digits of a uuid4 (48 bits):
    # 8 digits collide too easily, 12 is still readable
|
||||
model_id = f"model_{uuid.uuid4().hex[:12]}"
|
||||
model_path = Path(f"./data/models/{model_id}.joblib")
|
||||
|
||||
try:
|
||||
        # Run the synchronous sklearn / pandas training in the default thread pool
        # so the event loop is not blocked
|
||||
metadata = await asyncio.to_thread(
|
||||
_run_train_sync,
|
||||
model_type,
|
||||
target,
|
||||
train_data_path,
|
||||
feature_start,
|
||||
params,
|
||||
model_path,
|
||||
)
|
||||
|
||||
        # Register in the in-memory registry (so /predict can look up model_id)
|
||||
_register_model(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"model_type": model_type,
|
||||
"target": target,
|
||||
"params": dict(params or {}),
|
||||
"path": str(model_path),
|
||||
"feature_start": feature_start,
|
||||
"n_features": metadata["n_features"],
|
||||
"test_r2": metadata["test_r2"],
|
||||
"test_rmse": metadata["test_rmse"],
|
||||
"test_mae": metadata["test_mae"],
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"train_task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
        # Write the training metrics back to the task record so the polling frontend sees them
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
model_id=model_id,
|
||||
model_path=str(model_path),
|
||||
test_r2=metadata["test_r2"],
|
||||
test_rmse=metadata["test_rmse"],
|
||||
test_mae=metadata["test_mae"],
|
||||
n_features=metadata["n_features"],
|
||||
n_samples=metadata["n_samples"],
|
||||
error=None,
|
||||
traceback=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(
|
||||
f"[{task_id}] 训练完成 -> model_id={model_id} "
|
||||
f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
|
||||
)
|
||||
|
||||
except Exception as exc: # noqa: BLE001
|
||||
        # On failure model_path may have no artifact; set it to None explicitly for the frontend
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
model_id=None,
|
||||
model_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=traceback.format_exc(),
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
async def execute_predict_task(
|
||||
task_id: str,
|
||||
model_id: str,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: Optional[str],
|
||||
) -> None:
    """
    Background executor for inference tasks (real implementation).

    OOM safeguards:
      - xr.open_zarr(..., chunks="auto") opens lazily; the image is never loaded in one piece
      - xr.apply_ufunc(..., dask="parallelized") splits the image into chunks
      - each chunk is reshaped to 2D, passed to model.predict, then reshaped back to (y, x)
      - peak memory is on the order of the (y_chunk, x_chunk, band) chunks in flight
      - to_zarr(compute=True) triggers the computation and writes each chunk as it completes
    """
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 推断任务不存在, 跳过")
|
||||
return
|
||||
|
||||
    # 1. Check that model_id is registered (avoids an obscure error later in the background task)
|
||||
model_meta = _MODEL_REGISTRY.get(model_id)
|
||||
if model_meta is None:
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
error=f"model_id 不存在: {model_id}",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
|
||||
return
|
||||
|
||||
    # 2. Derive output_zarr_path automatically when none was given
|
||||
if output_zarr_path is None:
|
||||
stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
|
||||
stem = stem.replace(".zarr", "")
|
||||
output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")
|
||||
|
||||
try:
|
||||
        # 3. Probe the band count (needed to size the dummy model).
        #    Only zarr metadata (the shape in .zarray) is read here, not the data itself.
|
||||
ds_probe = xr.open_zarr(input_zarr_path, chunks="auto")
|
||||
if "reflectance" not in ds_probe.data_vars:
|
||||
raise KeyError(
|
||||
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
|
||||
)
|
||||
n_bands = ds_probe["reflectance"].sizes["band"]
|
||||
ds_probe.close()
|
||||
|
||||
        # 4. Load the model (real file first, dummy as fallback)
|
||||
model = _load_model(model_meta["path"], n_features=n_bands)
|
||||
|
||||
        # 5. Run the synchronous driver in the thread pool so the event loop is not blocked
|
||||
await asyncio.to_thread(
|
||||
_run_predict_sync,
|
||||
model,
|
||||
model_id,
|
||||
input_zarr_path,
|
||||
output_zarr_path,
|
||||
)
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
output_zarr_path=output_zarr_path,
|
||||
model_id=model_id,
|
||||
error=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")
|
||||
|
||||
except Exception as exc: # noqa: BLE001
|
||||
tb_text = traceback.format_exc()
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
output_zarr_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=tb_text,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
|
||||
print(tb_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/train
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/train", response_model=TaskAcceptedResponse)
|
||||
async def submit_train(
|
||||
payload: TrainRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
    """Submit a model training task and return its task_id immediately."""
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"kind": "train",
|
||||
"model_type": payload.model_type,
|
||||
"target": payload.target,
|
||||
"train_data_path": payload.train_data_path,
|
||||
"feature_start": payload.feature_start,
|
||||
"params": payload.params,
|
||||
"status": "PENDING",
|
||||
"model_id": None,
|
||||
"model_path": None,
|
||||
"test_r2": None,
|
||||
"test_rmse": None,
|
||||
"test_mae": None,
|
||||
"n_features": None,
|
||||
"n_samples": None,
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
background_tasks.add_task(
|
||||
execute_train_task,
|
||||
task_id,
|
||||
payload.model_type,
|
||||
payload.target,
|
||||
payload.train_data_path,
|
||||
payload.feature_start,
|
||||
payload.params,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "train"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/modeling/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/models", response_model=ModelListResponse)
|
||||
async def list_trained_models() -> Dict[str, Any]:
    """
    List trained models.

    Currently reads straight from the in-memory _MODEL_REGISTRY; a future
    version is expected to scan metadata from ./data/models/*.joblib.
    """
|
||||
models = list(_MODEL_REGISTRY.values())
|
||||
    # Sort by created_at descending, most recently trained first
|
||||
models.sort(key=lambda m: m.get("created_at", ""), reverse=True)
|
||||
return {"models": models, "count": len(models)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/predict
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/predict", response_model=TaskAcceptedResponse)
|
||||
async def submit_predict(
|
||||
payload: PredictRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
    """Submit a model inference task and return its task_id immediately."""
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"kind": "predict",
|
||||
"model_id": payload.model_id,
|
||||
"input_zarr_path": payload.input_zarr_path,
|
||||
"output_zarr_path": payload.output_zarr_path,
|
||||
"status": "PENDING",
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
background_tasks.add_task(
|
||||
execute_predict_task,
|
||||
task_id,
|
||||
payload.model_id,
|
||||
payload.input_zarr_path,
|
||||
payload.output_zarr_path,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
|
||||
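The chunk-level step in `_predict_block` above (flatten `(y, x, band)` to one spectrum per row, predict, restore the spatial shape) can be sketched in isolation with NumPy only. `MeanSpectrumModel` is a hypothetical stand-in for a fitted sklearn estimator and is not part of this change; the sketch assumes only that the model exposes `predict(X)` returning one value per row.

```python
import numpy as np


class MeanSpectrumModel:
    """Hypothetical stand-in for a fitted sklearn estimator."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        # One prediction per row: the mean of that pixel's spectrum.
        return X.mean(axis=1)


def predict_block(spectra_3d: np.ndarray, model) -> np.ndarray:
    """Flatten (y, x, band) to (pixels, band), predict, restore (y, x)."""
    yc, xc, nb = spectra_3d.shape
    flat = spectra_3d.reshape(yc * xc, nb)
    pred = np.asarray(model.predict(flat))
    return pred.reshape(yc, xc).astype(np.float32, copy=False)


# A 2 x 3 pixel block with 4 bands, filled with 0..23.
block = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)
result = predict_block(block, MeanSpectrumModel())
print(result.shape, result.dtype)
```

Because the band axis is last, a C-order reshape keeps each pixel's spectrum contiguous, so row `i` of the flattened array maps back to pixel `(i // xc, i % xc)`.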
40
new/app/core/algorithms/__init__.py
Normal file
@ -0,0 +1,40 @@
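"""
Sun-glint removal algorithm package
===================================

Glint-removal algorithms are organised with a registry plus the strategy pattern.
Every concrete algorithm should inherit from BaseGlintRemover and bind its name
to its implementation class with the @register_glint_remover decorator.

Conventions for callers
-----------------------
1. Every algorithm submodule must be imported explicitly in this __init__,
   so that the decorators run and the registry is populated.
2. Upper layers (endpoints, workers) should only use
       from app.core.algorithms import get_remover
   to obtain an algorithm class and should not import concrete classes
   directly, which keeps the dispatch layer decoupled from the algorithms.
"""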
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import (
|
||||
get_remover,
|
||||
list_removers,
|
||||
register_glint_remover,
|
||||
unregister_glint_remover,
|
||||
)
|
||||
|
||||
# ---- Algorithm submodule imports ----
# When adding an algorithm, add one import line here so its decorator runs.
from app.core.algorithms import goodman  # Goodman
from app.core.algorithms import kutser  # Kutser
# from app.core.algorithms import hedley  # Hedley
# from app.core.algorithms import sugar  # SUGAR
|
||||
|
||||
__all__ = [
|
||||
"BaseGlintRemover",
|
||||
"register_glint_remover",
|
||||
"get_remover",
|
||||
"list_removers",
|
||||
"unregister_glint_remover",
|
||||
]
|
||||
85
new/app/core/algorithms/base.py
Normal file
@ -0,0 +1,85 @@
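"""
Abstract base class for sun-glint removal algorithms
====================================================

Design goal (Strategy pattern)
------------------------------
This module defines the standard interface every glint-removal algorithm must follow.
Kutser, Goodman, Hedley, SUGAR and future algorithms all inherit from this base
class and implement a common process() method.

Input/output contract
---------------------
Inputs and outputs of every algorithm are **Zarr store paths** (strings),
not in-memory numpy ndarrays. The main benefits are:

1. **Storage is decoupled from in-memory computation**:
   an algorithm only cares which zarr to read and which zarr to write.
   Whether the data originally came from GeoTIFF / HDF5 / NetCDF / an
   in-memory array, the IO layer is responsible for normalising it to zarr.
2. **Out-of-core computation**:
   images often exceed available memory; zarr chunks support block-wise
   reads, so implementations can stream the computation with dask / xarray.
3. **Cacheable and reusable**:
   once intermediate products are on disk, downstream algorithms
   (atmospheric correction, radiometric calibration) can consume them
   directly without repeating IO.
4. **Easy to parallelise and distribute**:
   the scheduling layer only hands two paths to a worker and does not need
   to know anything about the data.

Conventions
-----------
- Subclasses implement process(), covering the full read -> compute -> write flow.
- process() returns True on success and False on failure.
- On failure, prefer raising an exception over merely returning False, so that
  the calling BackgroundTasks wrapper can catch it and fill in the error field.
"""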
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseGlintRemover(ABC):
    """
    Abstract base class for sun-glint removal algorithms.

    Every concrete algorithm (Kutser / Goodman / Hedley / SUGAR ...) must inherit
    from this class and implement process(). Subclasses may take their own
    hyperparameters in __init__ (reference band, thresholds, ...); the actual
    input and output data are given by the two zarr path arguments of process().
    """

    # Algorithm name, overridable by subclasses; the dispatch layer looks it up by method name
    name: str = "base"

    @abstractmethod
    async def process(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        **kwargs: Any,
    ) -> bool:
        """
        Run glint removal.

        Parameters
        ----------
        input_zarr_path : str
            Zarr store path of the input hyperspectral image.
            The IO layer has already normalised the data (bands, CRS and
            spatial dimensions are aligned).
        output_zarr_path : str
            Zarr store path for the result (the deglinted image).
            The subclass creates this store and writes the result itself.
        **kwargs : Any
            Optional algorithm hyperparameters, for example:
              - reference_band: index of the near-infrared reference band
              - chunk_size: chunk size used for the computation
              - other algorithm-specific parameters

        Returns
        -------
        bool
            True on success, False on failure.
            Raising on error is preferred, so the caller can record it in the task state.
        """
        raise NotImplementedError

    def __repr__(self) -> str:  # pragma: no cover - debugging aid
        return f"<{self.__class__.__name__} name={self.name!r}>"
|
||||
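The path-in/path-out contract of `BaseGlintRemover`, combined with the `asyncio.to_thread` hand-off used by the concrete removers, can be illustrated with a standard-library sketch. `PathProcessor` and `UppercaseProcessor` are hypothetical names invented for this example, and plain text files stand in for zarr stores; it shows the shape of the pattern under those assumptions rather than the project's implementation.

```python
import asyncio
import os
import tempfile
from abc import ABC, abstractmethod


class PathProcessor(ABC):
    """Hypothetical path-in/path-out strategy interface."""

    @abstractmethod
    async def process(self, input_path: str, output_path: str) -> bool:
        raise NotImplementedError


class UppercaseProcessor(PathProcessor):
    async def process(self, input_path: str, output_path: str) -> bool:
        # Run the blocking file work in a worker thread so the event loop stays free.
        return await asyncio.to_thread(self._process_sync, input_path, output_path)

    @staticmethod
    def _process_sync(input_path: str, output_path: str) -> bool:
        with open(input_path, encoding="utf-8") as src:
            text = src.read()
        with open(output_path, "w", encoding="utf-8") as dst:
            dst.write(text.upper())
        return True


with tempfile.TemporaryDirectory() as tmp:
    src_path = os.path.join(tmp, "in.txt")
    dst_path = os.path.join(tmp, "out.txt")
    with open(src_path, "w", encoding="utf-8") as fh:
        fh.write("glint")
    ok = asyncio.run(UppercaseProcessor().process(src_path, dst_path))
    with open(dst_path, encoding="utf-8") as fh:
        written = fh.read()
```

The caller only ever passes two paths, which is what lets a scheduler swap implementations without knowing how the data is read or written.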
123
new/app/core/algorithms/goodman.py
Normal file
@ -0,0 +1,123 @@
"""
app/core/algorithms/goodman.py
===============================

Streaming xarray + dask implementation of the Goodman et al. 2008 glint removal.

Formula
-------
    R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)

where:
    R_raw  -- raw reflectance (y, x, band)
    R_750  -- reflectance at λ=750 nm (infrared reference band, away from water-vapour absorption)
    R_640  -- reflectance at λ=640 nm (visible difference band)
    A, B   -- empirical regression parameters (passed by the user via params, default 0)

Post-processing
---------------
    - negative values are clamped to 0
    - the correction only applies inside the water mask (water_mask); pixels outside are set to 0

Dimension conventions
---------------------
    reflectance: (y, x, band), the band coordinate is normally wavelength (nm)
    water_mask : (y, x), boolean, True = water
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Default parameters
# ---------------------------------------------------------------------------
# Symbols follow the original Goodman 2008 paper so users can cross-check.
# A and B are normally obtained by regressing (R_corr - R_raw) on
# (R_640 - R_750) over clean deep water. Without prior knowledge the method
# falls back to A=0, B=0, which is R_corrected = clip(R_raw - R_750, 0).
# ---------------------------------------------------------------------------
DEFAULT_BAND_REF: float = 750.0    # λ_750 nm, infrared reference band
DEFAULT_BAND_DIFF: float = 640.0   # λ_640 nm, visible difference band
DEFAULT_A: float = 0.0             # constant offset term of the formula
DEFAULT_B: float = 0.0             # slope term of the formula
|
||||
|
||||
|
||||
@register_glint_remover("goodman")
class GoodmanGlintRemover(BaseGlintRemover):
    """Goodman et al. 2008 sun-glint removal."""

    name = "goodman"

    async def process(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        **kwargs: Any,
    ) -> bool:
        # 1. Parse hyperparameters (with defaults, so callers override only what they need)
        band_ref: float = kwargs.get("band_ref", DEFAULT_BAND_REF)
        band_diff: float = kwargs.get("band_diff", DEFAULT_BAND_DIFF)
        A: float = kwargs.get("A", DEFAULT_A)
        B: float = kwargs.get("B", DEFAULT_B)

        # 2. Run the synchronous xarray/dask computation in a worker thread
        #    so it does not block the FastAPI event loop
        return await asyncio.to_thread(
            self._process_sync,
            input_zarr_path,
            output_zarr_path,
            band_ref,
            band_diff,
            A,
            B,
        )

    @staticmethod
    def _process_sync(
        input_zarr_path: str,
        output_zarr_path: str,
        band_ref: float,
        band_diff: float,
        A: float,
        B: float,
    ) -> bool:
        # 1. Open by zarr path (dask-backed, not materialised in memory).
        #    chunks="auto" lets dask choose the chunking along each axis.
        ds = xr.open_zarr(input_zarr_path, chunks="auto")
        reflectance = ds["reflectance"]  # (y, x, band)

        # 2. Pick the two key bands with sel + method="nearest".
        #    Each has shape (y, x) and broadcasts against (y, x, band) below.
        R_750 = reflectance.sel(band=band_ref, method="nearest")
        R_640 = reflectance.sel(band=band_diff, method="nearest")

        # 3. Goodman formula; xarray broadcasts along the band dimension:
        #    R_corr = R_raw - R_750 + A + B * (R_640 - R_750)
        result = reflectance - R_750 + A + B * (R_640 - R_750)

        # 4. Clamp negative values to 0. clip(min=0) is preferred over a
        #    where()-based mask: no boolean intermediate array is built and
        #    dask uses its vectorised clip path.
        result = result.clip(min=0)

        # 5. Apply only inside water (pixels outside are forced to 0).
        #    Read water_mask from the zarr store if present; otherwise treat
        #    the whole image as water.
        if "water_mask" in ds:
            water_mask = ds["water_mask"].astype(bool)
            result = result.where(water_mask, 0)

        # 6. Build the output Dataset, keeping metadata (band coordinate, attributes, ...)
        out = xr.Dataset({"reflectance": result})
        if ds.attrs:
            out.attrs = dict(ds.attrs)
        if reflectance.attrs:
            out["reflectance"].attrs = dict(reflectance.attrs)

        # 7. Stream the result out (out-of-core): the full array is never
        #    materialised; dask computes and writes chunk by chunk, so peak
        #    memory depends on the chunks in flight, not on the image size.
        out.to_zarr(output_zarr_path, mode="w", compute=True)
        return True
|
||||
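The broadcasting used in the Goodman formula above can be checked on a tiny array. The following is a minimal NumPy sketch, not the xarray implementation from this file: `goodman_correct` and its index-based band arguments are hypothetical, chosen only so the arithmetic can be followed by hand.

```python
import numpy as np


def goodman_correct(refl, idx_750, idx_640, A=0.0, B=0.0, water_mask=None):
    """R_corr = clip(R_raw - R_750 + A + B * (R_640 - R_750), 0), zero outside water."""
    # Keep a trailing axis of length 1 so (y, x, 1) broadcasts against (y, x, band).
    r750 = refl[..., idx_750][..., np.newaxis]
    r640 = refl[..., idx_640][..., np.newaxis]
    corrected = refl - r750 + A + B * (r640 - r750)
    corrected = np.clip(corrected, 0.0, None)
    if water_mask is not None:
        corrected = np.where(water_mask[..., np.newaxis], corrected, 0.0)
    return corrected


# 1 x 2 pixel image with 3 bands; band 1 stands in for 640 nm, band 2 for 750 nm.
refl = np.array([[[0.10, 0.08, 0.02], [0.05, 0.04, 0.06]]])
mask = np.array([[True, False]])
out = goodman_correct(refl, idx_750=2, idx_640=1, A=0.0, B=0.5, water_mask=mask)
print(out)
```

For the water pixel, `R_750 = 0.02` and the offset term is `0.5 * (0.08 - 0.02) = 0.03`, so each band becomes `R_raw - 0.02 + 0.03`; the second pixel lies outside the mask and is zeroed.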
211
new/app/core/algorithms/kutser.py
Normal file
@ -0,0 +1,211 @@
"""
Kutser glint removal (xarray + dask rewrite)
============================================

Problems with the old version
-----------------------------
The original Kutser implementation (after Kutser et al., 2013) is typically written like this:

    R_corr = np.zeros_like(R_raw)
    for b in range(n_bands):
        for y in range(H):
            for x in range(W):
                if water_mask[y, x]:
                    R_corr[y, x, b] = (
                        R_raw[y, x, b] - G_list[b] * D_norm[y, x]
                    )
    with rasterio.open(..., 'w') as dst:
        dst.write(R_corr)

Issues:
    1. Three nested Python loops, each iteration doing a single floating-point
       operation, so interpreter overhead dominates.
    2. The whole R_raw image is read into memory at once; large images run out of memory.
    3. rasterio needs a contiguous numpy array to write, which increases memory use further.

This file rewrites it with xarray + dask:
    - DataArray dimension broadcasting turns the triple loop into one expression;
    - dask chunks keep the data on disk and stream the computation;
    - to_zarr writes while computing, so the output format is fully decoupled from the algorithm layer.
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Algorithm implementation
|
||||
# ---------------------------------------------------------------------------
|
||||
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
    """
    Kutser near-infrared subtraction glint removal.

    Formulas (equivalent to the old version)
    ----------------------------------------
    1) Absorption depth D at λ_oxy (per pixel):
        D = (R(λ_lower) + R(λ_upper)) / 2 - R(λ_oxy)
    2) Global normalisation factor D_max:
        D_max = max(D) over water
       Normalisation:
        D_norm = D / D_max
    3) Per-band range over water:
        G_list[b] = max(R[:, :, b] over water) - min(R[:, :, b] over water)
    4) Correction (per pixel, per band):
        R_corr(λ_b) = R_raw(λ_b) - G_list[b] * D_norm
    """

    # Reference bands (nm) used in the Kutser 2013 paper:
    #   λ_lower = 773, λ_oxy = 845, λ_upper = 893
    # They can be overridden via kwargs to fit other sensors such as MERIS / OLCI / Landsat.
    DEFAULT_BAND_LOWER: float = 773.0
    DEFAULT_BAND_OXY: float = 845.0
    DEFAULT_BAND_UPPER: float = 893.0

    # --------------------------------------------------------------
    # Public async entry point
    # --------------------------------------------------------------
    # xarray / dask operations are synchronous and blocking. Inside an
    # async function the synchronous body is handed to the default thread
    # pool with asyncio.to_thread so the FastAPI event loop is not blocked.
    # --------------------------------------------------------------
|
||||
async def process(
|
||||
self,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
return await asyncio.to_thread(
|
||||
self._process_sync,
|
||||
input_zarr_path,
|
||||
output_zarr_path,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
    # --------------------------------------------------------------
    # Synchronous core implementation
    # --------------------------------------------------------------
    def _process_sync(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        kwargs: dict,
    ) -> bool:
        # ============================================================
        # Step 0: open the zarr store and build the dask graph
        # ============================================================
        # chunks="auto": dask picks chunk sizes from the stored zarr chunks
        # and its configured target chunk size, so the data is never
        # materialised into RAM in one piece.
        # ============================================================
        ds = xr.open_zarr(input_zarr_path, chunks="auto")
        reflectance: xr.DataArray = ds["reflectance"]  # expected dims: (y, x, band)

        # Validate the dimension names (raise instead of assert so the check
        # also runs under python -O).
        if not {"y", "x", "band"} <= set(reflectance.dims):
            raise ValueError(
                f"reflectance must have the dims y/x/band, got: {reflectance.dims}"
            )

        # ============================================================
        # Step 1: extract the 2D (y, x) slices of the 3 reference bands
        # ============================================================
        # Assumes the band coordinate holds wavelength (nm).
        # sel(..., method="nearest") matches the closest available band.
        # ============================================================
        wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
        wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
        wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))

        R_lower = reflectance.sel(band=wl_lower, method="nearest")  # (y, x)
        R_upper = reflectance.sel(band=wl_upper, method="nearest")  # (y, x)
        R_oxy = reflectance.sel(band=wl_oxy, method="nearest")      # (y, x)

        # ============================================================
        # Step 2: water mask
        # ============================================================
        # Read water_mask from the zarr store if present; otherwise assume
        # the whole image is water (development-time fallback).
        # ============================================================
        if "water_mask" in ds:
            water_mask = ds["water_mask"].astype(bool)
        else:
            water_mask = xr.ones_like(
                reflectance.isel(band=0), dtype=bool
            )

        # ============================================================
        # Step 3: absorption depth D (per pixel, shape (y, x))
        # ============================================================
        # Old: D[y, x] = (R_lower[y, x] + R_upper[y, x]) / 2 - R_oxy[y, x]
        # New: one expression; dask builds the lazy graph.
        # ============================================================
        D = (R_lower + R_upper) / 2.0 - R_oxy  # (y, x), same dtype as reflectance

        # ============================================================
        # Step 4: global normalisation factor D_max (scalar, 0-dim DataArray)
        # ============================================================
        # .where(water_mask) sets non-water pixels to NaN, then .max()
        # reduces over (x, y) to 0 dimensions. Nothing is computed yet;
        # evaluation is triggered by to_zarr.
        # ============================================================
        D_max = D.where(water_mask).max()  # scalar
        # Fallback: an empty water mask makes D_max NaN; replace it with a
        # tiny value. This does not guard against D_max == 0.
        D_max = D_max.fillna(1e-6)

        # ============================================================
        # Step 5: normalised D_norm (shape (y, x))
        # ============================================================
        D_norm = D / D_max  # (y, x) array divided by a scalar, broadcast automatically

        # ============================================================
        # Step 6: per-band range over water, G_list (shape (band,))
        # ============================================================
        # The old triple loop needed a separate min/max aggregation pass.
        # With xarray, (y, x) are reduced together and only band remains.
        # ============================================================
        R_water = reflectance.where(water_mask)   # (y, x, band), NaN outside water
        G_min = R_water.min(dim=["x", "y"])       # (band,)
        G_max = R_water.max(dim=["x", "y"])       # (band,)
        G_list = (G_max - G_min).fillna(0.0)      # (band,), NaN-safe

        # ============================================================
        # Step 7: correction formula (the key line, showing xarray broadcasting)
        # ============================================================
        # The old version needed:
        #   for b in bands:
        #       for y in range(H):
        #           for x in range(W):
        #               R_corr[y,x,b] = R_raw[y,x,b] - G_list[b] * D_norm[y,x]
        #
        # xarray dimension alignment:
        #   R_raw : (y, x, band)
        #   G_list: (band,)  -> missing y, x are expanded
        #   D_norm: (y, x)   -> missing band is expanded
        #   product: (y, x, band) -> aligned for the subtraction
        # One expression covers the triple loop with scalar indexing.
        # ============================================================
        corrected = reflectance - G_list * D_norm  # (y, x, band)

        # ============================================================
        # Step 8: apply the water mask (non-water pixels become NaN)
        # ============================================================
        result = corrected.where(water_mask)

        # ============================================================
        # Step 9: persist as zarr
        # ============================================================
        # mode="w": overwrite (an existing target is deleted and recreated).
        # compute=True: block until the whole image is computed and written.
        # The data stays in dask chunks and is streamed out, so peak memory
        # depends on the chunks in flight, not on the size of the image.
        # ============================================================
        out = xr.Dataset({"reflectance": result})
        # Keep the global attributes / coordinate info of the source dataset (CRS, wavelength, ...)
        out.attrs = dict(ds.attrs)
        out["reflectance"].attrs = dict(reflectance.attrs)
        out.to_zarr(output_zarr_path, mode="w", compute=True)

        return True
|
||||
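The four Kutser formulas above reduce to a few broadcast operations. Here is a small NumPy sketch of the same arithmetic, assuming the three reference bands are addressed by index rather than by wavelength; `kutser_correct` is a hypothetical helper written for this illustration, not the class method.

```python
import numpy as np


def kutser_correct(refl, i_lower, i_oxy, i_upper, water_mask):
    """R_corr = R_raw - G[b] * D / D_max, with statistics taken over water only."""
    # Absorption depth per pixel, shape (y, x).
    depth = (refl[..., i_lower] + refl[..., i_upper]) / 2.0 - refl[..., i_oxy]
    d_max = np.nanmax(np.where(water_mask, depth, np.nan))
    d_norm = depth / d_max

    # Per-band range over water, shape (band,).
    water = np.where(water_mask[..., np.newaxis], refl, np.nan)
    g = np.nanmax(water, axis=(0, 1)) - np.nanmin(water, axis=(0, 1))

    # (band,) * (y, x, 1) broadcasts to (y, x, band).
    corrected = refl - g * d_norm[..., np.newaxis]
    return np.where(water_mask[..., np.newaxis], corrected, np.nan)


# 1 x 2 pixel image; the bands are ordered [lower, oxy, upper].
refl = np.array([[[0.4, 0.2, 0.4], [0.2, 0.1, 0.2]]])
mask = np.array([[True, True]])
out = kutser_correct(refl, i_lower=0, i_oxy=1, i_upper=2, water_mask=mask)
print(out)
```

Here `D = [0.2, 0.1]`, so `D_norm = [1.0, 0.5]`, and the per-band ranges are `[0.2, 0.1, 0.2]`; the first pixel loses the full range in every band and the second loses half of it.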
135
new/app/core/algorithms/registry.py
Normal file
@ -0,0 +1,135 @@
"""
Algorithm registry (Registry / Factory)
=======================================

A decorator binds an algorithm name string to its implementation class.
The dispatch layer (FastAPI endpoints, BackgroundTasks workers) only needs the
method string sent by the frontend to reach the matching implementation,
without a long if/elif chain.

Usage example
-------------

    from app.core.algorithms import BaseGlintRemover
    from app.core.algorithms.registry import (
        register_glint_remover,
        get_remover,
        list_removers,
    )

    @register_glint_remover("kutser")
    class KutserGlintRemover(BaseGlintRemover):
        async def process(self, input_zarr_path, output_zarr_path, **kwargs):
            ...

    # Dispatch
    Cls = get_remover(method_from_request)
    remover = Cls()
    await remover.process(input_zarr_path, output_zarr_path, **kwargs)

Design notes
------------
- Registration happens at class-definition time, so the registry is only
  complete after every algorithm module has been imported. Importing the
  algorithm submodules in `app/core/algorithms/__init__.py` forces this.
- Registering the same name twice raises immediately, which prevents silent overwrites.
- The name is also written back to the class attribute `name`, so an algorithm
  can query its own identity.
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
|
||||
|
||||
# Global registry: name (str) -> implementation class (type); classes are not instantiated
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}


def register_glint_remover(name: str):
    """
    Class-decorator factory: registers the decorated algorithm class under name.

    Parameters
    ----------
    name : str
        Algorithm identifier, preferably lower_snake_case, e.g. "kutser", "goodman".

    Raises
    ------
    ValueError
        - name is not a non-empty string
        - name is already taken by another class
    TypeError
        - the decorated object is not a subclass of BaseGlintRemover
    """

    # ---- Defensive check: name must be a valid string ----
    if not isinstance(name, str) or not name.strip():
        raise ValueError(
            f"register_glint_remover: name must be a non-empty string, got: {name!r}"
        )

    def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
        # ---- Defensive check: the decorated object must subclass BaseGlintRemover ----
        if not isinstance(cls, type) or not issubclass(cls, BaseGlintRemover):
            raise TypeError(
                f"@register_glint_remover can only decorate subclasses of BaseGlintRemover, "
                f"got: {cls!r}"
            )

        # ---- Defensive check: no silent overwrite ----
        if name in _REGISTRY:
            raise ValueError(
                f"algorithm name {name!r} is already taken by {_REGISTRY[name].__name__}; "
                f"use another name or call unregister_glint_remover() on the old one first."
            )

        # Write name back to the class attribute for the algorithm itself and for logging
        cls.name = name
        _REGISTRY[name] = cls
        return cls

    return decorator


def get_remover(name: str) -> Type[BaseGlintRemover]:
    """
    Return the implementation class (not an instance) registered under name.

    The caller builds the instance with `Cls(...)` and then calls process().

    Raises
    ------
    KeyError
        If name is not registered; the message lists the registered names to help debugging.
    """
    try:
        return _REGISTRY[name]
    except KeyError as exc:
        known = ", ".join(sorted(_REGISTRY)) or "<empty>"
        raise KeyError(
            f"unregistered algorithm name: {name!r}. Registered algorithms: {known}"
        ) from exc


def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
    """
    Return a shallow copy of the current registry.
    Useful for:
      - debug logging
      - exposing a GET /api/algorithms endpoint to the frontend
      - unit-test assertions
    """
    return dict(_REGISTRY)


def unregister_glint_remover(name: str) -> None:
    """
    Unregister an algorithm. Mainly intended for:
      - unit tests
      - hot reload / plugin unload scenarios
    Production code normally does not need it.
    """
    if name not in _REGISTRY:
        raise KeyError(f"unregistered algorithm name: {name!r}")
    del _REGISTRY[name]
|
||||
91
new/app/core/task_store.py
Normal file
@ -0,0 +1,91 @@
"""
app/core/task_store.py
======================

Coroutine-safe in-memory task state store, replacing MOCK_TASK_DB from the earlier mock pipeline.

Design goals
------------
1. Provide event-loop-level mutual exclusion (asyncio.Lock) within a single process,
   so state stays consistent even if an await is interleaved between update and set/get.
2. Expose an async API (set_task / update_task / get_task)
   so callers mark critical sections explicitly in async code.
3. Keep a synchronous has_task() for lightweight existence checks.
4. Production should replace this with Redis / SQLite / PostgreSQL,
   keeping the same interface shape so callers can migrate without changes.

Usage conventions
-----------------
- Write the initial PENDING record:                 await set_task(task_id, record)
- Update fields (PROCESSING/SUCCESS/FAILED):        await update_task(task_id, **fields)
- Read a task record:                               await get_task(task_id)  # may return None
- Synchronous existence check:                      has_task(task_id)
"""

import asyncio
from typing import Any, Dict, Optional


# ---------------------------------------------------------------------------
# Global store and lock
# ---------------------------------------------------------------------------
# TASK_STORE: task_id -> task record
# Task record fields (kept consistent with endpoints.py):
#   task_id, method, params, status,
#   output_zarr_path, error, traceback,
#   created_at, updated_at
# ---------------------------------------------------------------------------
TASK_STORE: Dict[str, Dict[str, Any]] = {}

# Event-loop-level mutex for a single process.
# Note: on Python 3.10+ an asyncio.Lock is bound to an event loop on first use, not at
# construction, so creating it at module import time is safe; older versions bind at construction.
_lock: asyncio.Lock = asyncio.Lock()


# ---------------------------------------------------------------------------
# Async API
# ---------------------------------------------------------------------------
async def set_task(task_id: str, record: Dict[str, Any]) -> None:
    """
    Create a task record, or overwrite it entirely.

    Usage: the POST endpoint calls this right after receiving a submission, writing the initial PENDING record.
    """
    async with _lock:
        TASK_STORE[task_id] = record


async def update_task(task_id: str, **fields: Any) -> None:
    """
    Update a task record field by field.

    Usage: the background executor calls this on transitions such as PROCESSING / SUCCESS / FAILED.
    If task_id does not exist, setdefault creates an empty dict first and then updates it (defensive fallback).
    """
    async with _lock:
        record = TASK_STORE.setdefault(task_id, {})
        record.update(fields)


async def get_task(task_id: str) -> Optional[Dict[str, Any]]:
    """
    Read a task record; returns None if it does not exist.

    Usage: GET /api/tasks/{task_id} queries through this function.
    """
    async with _lock:
        return TASK_STORE.get(task_id)


# ---------------------------------------------------------------------------
# Sync API (lightweight)
# ---------------------------------------------------------------------------
def has_task(task_id: str) -> bool:
    """
    Synchronously check whether task_id exists.

    Intended for lightweight cases that do not need the lock (e.g. a pre-check before logging);
    it is safe to call from async code because a dict membership test never awaits, so no other coroutine can interleave.
    """
    return task_id in TASK_STORE
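The usage conventions in the task store docstring (write PENDING, update on each transition, read back) can be sketched end to end. This is a minimal, self-contained illustration rather than the module itself: the `TaskStore` class, the task ID `"t1"` and the output path are assumptions made for the example, and unlike `get_task` above this version returns a copy so callers cannot mutate the stored record outside the lock.

```python
import asyncio
from typing import Any, Dict, Optional


class TaskStore:
    """Illustrative wrapper with the same three-call shape as task_store.py."""

    def __init__(self) -> None:
        self._data: Dict[str, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()

    async def set_task(self, task_id: str, record: Dict[str, Any]) -> None:
        async with self._lock:
            self._data[task_id] = record

    async def update_task(self, task_id: str, **fields: Any) -> None:
        async with self._lock:
            self._data.setdefault(task_id, {}).update(fields)

    async def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        async with self._lock:
            record = self._data.get(task_id)
            # Return a copy so the caller cannot change stored state without the lock.
            return dict(record) if record is not None else None


async def demo() -> Dict[str, Any]:
    store = TaskStore()
    await store.set_task("t1", {"task_id": "t1", "status": "PENDING"})
    await store.update_task("t1", status="PROCESSING")
    await store.update_task("t1", status="SUCCESS", output_zarr_path="/tmp/out.zarr")
    record = await store.get_task("t1")
    assert record is not None
    return record


final_record = asyncio.run(demo())
```

Keeping the same method names means a later Redis- or database-backed store could be swapped in behind the same calls, which is the migration path the design goals describe.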
62
new/app/main.py
Normal file
@ -0,0 +1,62 @@
"""
WQ_GUI FastAPI backend entry point
==================================

Application startup and global middleware configuration:
- CORS: all origins are allowed during development so a local frontend (Vite / Webpack dev server) can connect
- Routing: the business endpoints in app/api/endpoints.py are mounted via include_router

Business endpoints:
    POST /api/process/deglint    Submit a deglint processing task; returns a task_id immediately
    GET  /api/tasks/{task_id}    Query the status and result of a task
"""

from typing import Dict

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api.endpoints import router as deglint_router
from app.api.modeling import router as modeling_router


# ---------------------------------------------------------------------------
# FastAPI application instance
# ---------------------------------------------------------------------------
app = FastAPI(
    title="WQ_GUI Backend",
    description="Hyperspectral image deglint processing API",
    version="0.2.0",
)


# ---------------------------------------------------------------------------
# CORS middleware
# ---------------------------------------------------------------------------
# Development: allow all origins, methods and headers so a local frontend on any port can connect.
# Production must restrict allow_origins to the real frontend origin; a wildcard with allow_credentials=True is a security risk.
# ---------------------------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Router registration
# ---------------------------------------------------------------------------
# Everything is mounted under the /api prefix to make future versioning easier (e.g. /api/v1, /api/v2).
# ---------------------------------------------------------------------------
app.include_router(deglint_router, prefix="/api")
app.include_router(modeling_router, prefix="/api")


# ---------------------------------------------------------------------------
# Root health check (handy for local debugging, not required by the business logic)
# ---------------------------------------------------------------------------
@app.get("/")
async def root() -> Dict[str, str]:
    return {"service": "WQ_GUI Backend", "status": "ok"}
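The CORS comment in `main.py` asks for `allow_origins` to be narrowed in production. One possible way to do that, sketched here under stated assumptions, is to read the list from an environment variable. The variable name `WQ_ALLOWED_ORIGINS` and the helper `load_allowed_origins` are hypothetical and do not appear in this diff.

```python
import os
from typing import List, Mapping


def load_allowed_origins(env: Mapping[str, str] = os.environ) -> List[str]:
    """Parse a comma-separated origin list from WQ_ALLOWED_ORIGINS (hypothetical name)."""
    raw = env.get("WQ_ALLOWED_ORIGINS", "")
    origins = [item.strip() for item in raw.split(",") if item.strip()]
    # Fall back to the permissive development default when nothing is configured.
    return origins or ["*"]
```

The result could then be passed as `allow_origins=load_allowed_origins()` when adding `CORSMiddleware`, with credentials enabled only when the list is explicit rather than the wildcard.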
24
new/frontend/.gitignore
vendored
Normal file
@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
5
new/frontend/README.md
Normal file
@ -0,0 +1,5 @@
# Vue 3 + TypeScript + Vite

This template should help get you started developing with Vue 3 and TypeScript in Vite. The template uses Vue 3 `<script setup>` SFCs, check out the [script setup docs](https://v3.vuejs.org/api/sfc-script-setup.html#sfc-script-setup) to learn more.

Learn more about the recommended Project Setup and IDE Support in the [Vue Docs TypeScript Guide](https://vuejs.org/guide/typescript/overview.html#project-setup).
13
new/frontend/index.html
Normal file
@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <link rel="icon" type="image/svg+xml" href="/favicon.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>frontend</title>
  </head>
  <body>
    <div id="app"></div>
    <script type="module" src="/src/main.ts"></script>
  </body>
</html>
2412
new/frontend/package-lock.json
generated
Normal file
27
new/frontend/package.json
Normal file
@ -0,0 +1,27 @@
{
  "name": "frontend",
  "private": true,
  "version": "0.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "vue-tsc -b && vite build",
    "preview": "vite preview"
  },
  "dependencies": {
    "axios": "^1.16.1",
    "echarts": "^6.1.0",
    "element-plus": "^2.14.1",
    "pinia": "^3.0.4",
    "vue": "^3.5.34",
    "vue-router": "^5.1.0"
  },
  "devDependencies": {
    "@types/node": "^24.12.3",
    "@vitejs/plugin-vue": "^6.0.6",
    "@vue/tsconfig": "^0.9.1",
    "typescript": "~6.0.2",
    "vite": "^8.0.12",
    "vue-tsc": "^3.2.8"
  }
}
1
new/frontend/public/favicon.svg
Normal file
|
After Width: | Height: | Size: 9.3 KiB |
24
new/frontend/public/icons.svg
Normal file
@ -0,0 +1,24 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg">
|
||||
<symbol id="bluesky-icon" viewBox="0 0 16 17">
|
||||
<g clip-path="url(#bluesky-clip)"><path fill="#08060d" d="M7.75 7.735c-.693-1.348-2.58-3.86-4.334-5.097-1.68-1.187-2.32-.981-2.74-.79C.188 2.065.1 2.812.1 3.251s.241 3.602.398 4.13c.52 1.744 2.367 2.333 4.07 2.145-2.495.37-4.71 1.278-1.805 4.512 3.196 3.309 4.38-.71 4.987-2.746.608 2.036 1.307 5.91 4.93 2.746 2.72-2.746.747-4.143-1.747-4.512 1.702.189 3.55-.4 4.07-2.145.156-.528.397-3.691.397-4.13s-.088-1.186-.575-1.406c-.42-.19-1.06-.395-2.741.79-1.755 1.24-3.64 3.752-4.334 5.099"/></g>
|
||||
<defs><clipPath id="bluesky-clip"><path fill="#fff" d="M.1.85h15.3v15.3H.1z"/></clipPath></defs>
|
||||
</symbol>
|
||||
<symbol id="discord-icon" viewBox="0 0 20 19">
|
||||
<path fill="#08060d" d="M16.224 3.768a14.5 14.5 0 0 0-3.67-1.153c-.158.286-.343.67-.47.976a13.5 13.5 0 0 0-4.067 0c-.128-.306-.317-.69-.476-.976A14.4 14.4 0 0 0 3.868 3.77C1.546 7.28.916 10.703 1.231 14.077a14.7 14.7 0 0 0 4.5 2.306q.545-.748.965-1.587a9.5 9.5 0 0 1-1.518-.74q.191-.14.372-.293c2.927 1.369 6.107 1.369 8.999 0q.183.152.372.294-.723.437-1.52.74.418.838.963 1.588a14.6 14.6 0 0 0 4.504-2.308c.37-3.911-.63-7.302-2.644-10.309m-9.13 8.234c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.894 0 1.614.82 1.599 1.82.001 1-.705 1.82-1.6 1.82m5.91 0c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.893 0 1.614.82 1.599 1.82 0 1-.706 1.82-1.6 1.82"/>
|
||||
</symbol>
|
||||
<symbol id="documentation-icon" viewBox="0 0 21 20">
|
||||
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="m15.5 13.333 1.533 1.322c.645.555.967.833.967 1.178s-.322.623-.967 1.179L15.5 18.333m-3.333-5-1.534 1.322c-.644.555-.966.833-.966 1.178s.322.623.966 1.179l1.534 1.321"/>
|
||||
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M17.167 10.836v-4.32c0-1.41 0-2.117-.224-2.68-.359-.906-1.118-1.621-2.08-1.96-.599-.21-1.349-.21-2.848-.21-2.623 0-3.935 0-4.983.369-1.684.591-3.013 1.842-3.641 3.428C3 6.449 3 7.684 3 10.154v2.122c0 2.558 0 3.838.706 4.726q.306.383.713.671c.76.536 1.79.64 3.581.66"/>
|
||||
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M3 10a2.78 2.78 0 0 1 2.778-2.778c.555 0 1.209.097 1.748-.047.48-.129.854-.503.982-.982.145-.54.048-1.194.048-1.749a2.78 2.78 0 0 1 2.777-2.777"/>
|
||||
</symbol>
|
||||
<symbol id="github-icon" viewBox="0 0 19 19">
|
||||
<path fill="#08060d" fill-rule="evenodd" d="M9.356 1.85C5.05 1.85 1.57 5.356 1.57 9.694a7.84 7.84 0 0 0 5.324 7.44c.387.079.528-.168.528-.376 0-.182-.013-.805-.013-1.454-2.165.467-2.616-.935-2.616-.935-.349-.91-.864-1.143-.864-1.143-.71-.48.051-.48.051-.48.787.051 1.2.805 1.2.805.695 1.194 1.817.857 2.268.649.064-.507.27-.857.49-1.052-1.728-.182-3.545-.857-3.545-3.87 0-.857.31-1.558.8-2.104-.078-.195-.349-1 .077-2.078 0 0 .657-.208 2.14.805a7.5 7.5 0 0 1 1.946-.26c.657 0 1.328.092 1.946.26 1.483-1.013 2.14-.805 2.14-.805.426 1.078.155 1.883.078 2.078.502.546.799 1.247.799 2.104 0 3.013-1.818 3.675-3.558 3.87.284.247.528.714.528 1.454 0 1.052-.012 1.896-.012 2.156 0 .208.142.455.528.377a7.84 7.84 0 0 0 5.324-7.441c.013-4.338-3.48-7.844-7.773-7.844" clip-rule="evenodd"/>
|
||||
</symbol>
|
||||
<symbol id="social-icon" viewBox="0 0 20 20">
|
||||
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M12.5 6.667a4.167 4.167 0 1 0-8.334 0 4.167 4.167 0 0 0 8.334 0"/>
|
||||
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M2.5 16.667a5.833 5.833 0 0 1 8.75-5.053m3.837.474.513 1.035c.07.144.257.282.414.309l.93.155c.596.1.736.536.307.965l-.723.73a.64.64 0 0 0-.152.531l.207.903c.164.715-.213.991-.84.618l-.872-.52a.63.63 0 0 0-.577 0l-.872.52c-.624.373-1.003.094-.84-.618l.207-.903a.64.64 0 0 0-.152-.532l-.723-.729c-.426-.43-.289-.864.306-.964l.93-.156a.64.64 0 0 0 .412-.31l.513-1.034c.28-.562.735-.562 1.012 0"/>
|
||||
</symbol>
|
||||
<symbol id="x-icon" viewBox="0 0 19 19">
|
||||
<path fill="#08060d" fill-rule="evenodd" d="M1.893 1.98c.052.072 1.245 1.769 2.653 3.77l2.892 4.114c.183.261.333.48.333.486s-.068.089-.152.183l-.522.593-.765.867-3.597 4.087c-.375.426-.734.834-.798.905a1 1 0 0 0-.118.148c0 .01.236.017.664.017h.663l.729-.83c.4-.457.796-.906.879-.999a692 692 0 0 0 1.794-2.038c.034-.037.301-.34.594-.675l.551-.624.345-.392a7 7 0 0 1 .34-.374c.006 0 .93 1.306 2.052 2.903l2.084 2.965.045.063h2.275c1.87 0 2.273-.003 2.266-.021-.008-.02-1.098-1.572-3.894-5.547-2.013-2.862-2.28-3.246-2.273-3.266.008-.019.282-.332 2.085-2.38l2-2.274 1.567-1.782c.022-.028-.016-.03-.65-.03h-.674l-.3.342a871 871 0 0 1-1.782 2.025c-.067.075-.405.458-.75.852a100 100 0 0 1-.803.91c-.148.172-.299.344-.99 1.127-.304.343-.32.358-.345.327-.015-.019-.904-1.282-1.976-2.808L6.365 1.85H1.8zm1.782.91 8.078 11.294c.772 1.08 1.413 1.973 1.425 1.984.016.017.241.02 1.05.017l1.03-.004-2.694-3.766L7.796 5.75 5.722 2.852l-1.039-.004-1.039-.004z" clip-rule="evenodd"/>
|
||||
</symbol>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.9 KiB |
225
new/frontend/src/App.vue
Normal file
@ -0,0 +1,225 @@
<template>
  <div class="dashboard-container">
    <h1 class="title">Hyperspectral Water Quality Retrieval Console</h1>
    <el-row :gutter="20">

      <el-col :span="12">
        <el-card class="box-card" shadow="hover">
          <template #header>
            <div class="card-header">
              <span class="header-title">🚀 Model Training (Train)</span>
            </div>
          </template>

          <el-form label-position="top">
            <el-form-item label="Algorithm (Model Type)">
              <el-select v-model="trainForm.model_type" placeholder="Select an algorithm" class="w-full">
                <el-option label="Random Forest (RF)" value="RF" />
                <el-option label="Support Vector Regression (SVR)" value="SVR" />
                <el-option label="Linear Regression (LinearRegression)" value="LinearRegression" />
                <el-option label="K-Nearest Neighbors (KNN)" value="KNN" />
                <el-option label="Partial Least Squares (PLS)" value="PLS" />
              </el-select>
            </el-form-item>
            <el-form-item label="Target Parameter (Target)">
              <el-input v-model="trainForm.target" placeholder="e.g. Chl-a" />
            </el-form-item>
            <el-form-item label="Training Data Path (absolute CSV path)">
              <el-input v-model="trainForm.train_data_path" placeholder="e.g. D:\111\data.csv" />
            </el-form-item>
            <el-form-item label="Feature Start Column (e.g. 4, or a column name)">
              <el-input v-model="trainForm.feature_start" placeholder="Enter a number or a column name" />
            </el-form-item>
            <el-button type="primary" @click="handleTrain" :loading="trainPoller?.isPolling?.value" class="w-full">
              Start Training
            </el-button>
          </el-form>

          <div v-if="trainTaskId" class="status-board">
            <p><strong>Task ID:</strong> <el-tag size="small" type="info">{{ trainTaskId }}</el-tag></p>
            <p><strong>Status:</strong>
              <el-tag :type="getStatusType(trainPoller?.status?.value || 'PENDING')" style="margin-left:10px">
                {{ trainPoller?.status?.value || 'PENDING' }}
              </el-tag>
            </p>
            <el-progress
              v-if="trainPoller?.isPolling?.value || trainPoller?.status?.value === 'SUCCESS'"
              :percentage="trainPoller?.status?.value === 'SUCCESS' ? 100 : 60"
              :status="trainPoller?.status?.value === 'SUCCESS' ? 'success' : (trainPoller?.status?.value === 'FAILED' ? 'exception' : '')"
              :indeterminate="trainPoller?.isPolling?.value"
            />

            <div v-if="trainPoller?.error?.value" class="error-msg">
              <el-alert :title="trainPoller.error.value" type="error" :closable="false" show-icon />
            </div>

            <div v-if="trainPoller?.result?.value?.model_id" class="result-msg">
              <el-descriptions border :column="1" size="small" title="Training Metrics">
                <el-descriptions-item label="Model ID">{{ trainPoller.result.value.model_id }}</el-descriptions-item>
                <el-descriptions-item label="Test R²">{{ Number(trainPoller.result.value.test_r2).toFixed(4) }}</el-descriptions-item>
                <el-descriptions-item label="Test RMSE">{{ Number(trainPoller.result.value.test_rmse).toFixed(4) }}</el-descriptions-item>
              </el-descriptions>
            </div>
          </div>
        </el-card>
      </el-col>

      <el-col :span="12">
        <el-card class="box-card" shadow="hover">
          <template #header>
            <div class="card-header">
              <span class="header-title">🎯 Model Inference (Predict)</span>
            </div>
          </template>

          <el-form label-position="top">
            <el-form-item label="Trained Model ID (Model ID)">
              <el-input v-model="predictForm.model_id" placeholder="Filled in automatically from the training panel on the left" />
            </el-form-item>
            <el-form-item label="Input Image Path (absolute Zarr path)">
              <el-input v-model="predictForm.input_zarr_path" placeholder="e.g. D:\111\image.zarr" />
            </el-form-item>
            <el-button type="success" @click="handlePredict" :loading="predictPoller?.isPolling?.value" class="w-full">
              Start Full-Image Inference
            </el-button>
          </el-form>

          <div v-if="predictTaskId" class="status-board">
            <p><strong>Task ID:</strong> <el-tag size="small" type="info">{{ predictTaskId }}</el-tag></p>
            <p><strong>Status:</strong>
              <el-tag :type="getStatusType(predictPoller?.status?.value || 'PENDING')" style="margin-left:10px">
                {{ predictPoller?.status?.value || 'PENDING' }}
              </el-tag>
            </p>
            <el-progress
              v-if="predictPoller?.isPolling?.value || predictPoller?.status?.value === 'SUCCESS'"
              :percentage="predictPoller?.status?.value === 'SUCCESS' ? 100 : 50"
              :status="predictPoller?.status?.value === 'SUCCESS' ? 'success' : (predictPoller?.status?.value === 'FAILED' ? 'exception' : '')"
              :indeterminate="predictPoller?.isPolling?.value"
            />

            <div v-if="predictPoller?.error?.value" class="error-msg">
              <el-alert :title="predictPoller.error.value" type="error" :closable="false" show-icon />
            </div>

            <div v-if="predictPoller?.result?.value?.output_zarr_path" class="result-msg">
              <el-alert :title="'Inference succeeded. Result written to: ' + predictPoller.result.value.output_zarr_path" type="success" :closable="false" show-icon />
            </div>
          </div>
        </el-card>
      </el-col>

    </el-row>
  </div>
</template>

<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { submitTrain, submitPredict } from './api/tasks'
import { useTaskPoller } from './composables/useTaskPoller'

// Training form state
const trainForm = reactive({
  model_type: 'RF',
  target: 'Chl-a',
  train_data_path: '',
  feature_start: '4'
})
const trainTaskId = ref<string | null>(null)
const trainPoller = useTaskPoller(trainTaskId)

// Inference form state
const predictForm = reactive({
  model_id: '',
  input_zarr_path: ''
})
const predictTaskId = ref<string | null>(null)
const predictPoller = useTaskPoller(predictTaskId)

// Auto-fill: copy the trained model ID into the inference form
watch(() => trainPoller?.result?.value?.model_id, (newId) => {
  if (newId) predictForm.model_id = newId as string
})

// Submit a training task
const handleTrain = async () => {
  try {
    const res = await submitTrain({
      model_type: trainForm.model_type,
      target: trainForm.target,
      train_data_path: trainForm.train_data_path,
      feature_start: trainForm.feature_start,
      params: {}
    })
    trainTaskId.value = res.task_id
  } catch (e: any) {
    console.error('Training request failed', e)
    alert('Submission failed. Check that the backend is running on port 9090, or press F12 and look for CORS errors in the console.')
  }
}

// Submit an inference task
const handlePredict = async () => {
  try {
    const res = await submitPredict({
      model_id: predictForm.model_id,
      input_zarr_path: predictForm.input_zarr_path
    })
    predictTaskId.value = res.task_id
  } catch (e: any) {
    console.error('Inference request failed', e)
  }
}

// Style helper: map a task status to an Element Plus tag type
const getStatusType = (status: string) => {
  if (status === 'SUCCESS') return 'success'
  if (status === 'FAILED') return 'danger'
  if (status === 'PROCESSING') return 'warning'
  return 'info'
}
</script>

<style>
/* Remove the default global margins */
body {
  margin: 0;
  padding: 0;
}
</style>

<style scoped>
.dashboard-container {
  padding: 40px;
  min-height: 100vh;
  background-color: #1e1e2d; /* dark tech-style background */
}
.title {
  text-align: center;
  margin-bottom: 40px;
  color: #ffffff;
  font-weight: 300;
  letter-spacing: 2px;
}
.header-title {
  font-weight: bold;
  font-size: 16px;
}
.box-card {
  margin-bottom: 20px;
  background-color: rgba(255, 255, 255, 0.95);
}
.w-full {
  width: 100%;
}
.status-board {
  margin-top: 25px;
  padding: 20px;
  background: #f8f9fa;
  border-radius: 8px;
  border: 1px solid #e4e7ed;
}
.error-msg, .result-msg {
  margin-top: 20px;
}
</style>
15
new/frontend/src/api/request.ts
Normal file
@ -0,0 +1,15 @@
import axios from 'axios'

const request = axios.create({
  // Note: points directly at the backend, which now listens on port 9090
  baseURL: 'http://127.0.0.1:9090',
  timeout: 60000
})

// Interceptor: unwrap response.data so callers receive the payload directly
request.interceptors.response.use(
  response => response.data,
  error => Promise.reject(error)
)

export default request
13
new/frontend/src/api/tasks.ts
Normal file
@ -0,0 +1,13 @@
import request from './request'

export const submitTrain = (data: any) => {
  return request.post<any, any>('/api/modeling/train', data)
}

export const submitPredict = (data: any) => {
  return request.post<any, any>('/api/modeling/predict', data)
}

export const getTaskStatus = (task_id: string) => {
  return request.get<any, any>(`/api/tasks/${task_id}`)
}
BIN
new/frontend/src/assets/hero.png
Normal file
|
After Width: | Height: | Size: 13 KiB |
1
new/frontend/src/assets/vite.svg
Normal file
|
After Width: | Height: | Size: 8.5 KiB |
1
new/frontend/src/assets/vue.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="37.07" height="36" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 198"><path fill="#41B883" d="M204.8 0H256L128 220.8L0 0h97.92L128 51.2L157.44 0h47.36Z"></path><path fill="#41B883" d="m0 0l128 220.8L256 0h-51.2L128 132.48L50.56 0H0Z"></path><path fill="#35495E" d="M50.56 0L128 133.12L204.8 0h-47.36L128 51.2L97.92 0H50.56Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 496 B |
95
new/frontend/src/components/HelloWorld.vue
Normal file
@ -0,0 +1,95 @@
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
import viteLogo from '../assets/vite.svg'
|
||||
import heroImg from '../assets/hero.png'
|
||||
import vueLogo from '../assets/vue.svg'
|
||||
|
||||
const count = ref(0)
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<section id="center">
|
||||
<div class="hero">
|
||||
<img :src="heroImg" class="base" width="170" height="179" alt="" />
|
||||
<img :src="vueLogo" class="framework" alt="Vue logo" />
|
||||
<img :src="viteLogo" class="vite" alt="Vite logo" />
|
||||
</div>
|
||||
<div>
|
||||
<h1>Get started</h1>
|
||||
<p>Edit <code>src/App.vue</code> and save to test <code>HMR</code></p>
|
||||
</div>
|
||||
<button type="button" class="counter" @click="count++">
|
||||
Count is {{ count }}
|
||||
</button>
|
||||
</section>
|
||||
|
||||
<div class="ticks"></div>
|
||||
|
||||
<section id="next-steps">
|
||||
<div id="docs">
|
||||
<svg class="icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#documentation-icon"></use>
|
||||
</svg>
|
||||
<h2>Documentation</h2>
|
||||
<p>Your questions, answered</p>
|
||||
<ul>
|
||||
<li>
|
||||
<a href="https://vite.dev/" target="_blank">
|
||||
<img class="logo" :src="viteLogo" alt="" />
|
||||
Explore Vite
|
||||
</a>
|
||||
</li>
|
||||
<li>
|
||||
<a href="https://vuejs.org/" target="_blank">
|
||||
<img class="button-icon" :src="vueLogo" alt="" />
|
||||
Learn more
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="social">
|
||||
<svg class="icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#social-icon"></use>
|
||||
</svg>
|
||||
<h2>Connect with us</h2>
|
||||
<p>Join the Vite community</p>
|
||||
<ul>
|
||||
<li>
|
||||
<a href="https://github.com/vitejs/vite" target="_blank">
|
||||
<svg class="button-icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#github-icon"></use>
|
||||
</svg>
|
||||
GitHub
|
||||
</a>
|
||||
</li>
|
||||
<li>
|
||||
<a href="https://chat.vite.dev/" target="_blank">
|
||||
<svg class="button-icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#discord-icon"></use>
|
||||
</svg>
|
||||
Discord
|
||||
</a>
|
||||
</li>
|
||||
<li>
|
||||
<a href="https://x.com/vite_js" target="_blank">
|
||||
<svg class="button-icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#x-icon"></use>
|
||||
</svg>
|
||||
X.com
|
||||
</a>
|
||||
</li>
|
||||
<li>
|
||||
<a href="https://bsky.app/profile/vite.dev" target="_blank">
|
||||
<svg class="button-icon" role="presentation" aria-hidden="true">
|
||||
<use href="/icons.svg#bluesky-icon"></use>
|
||||
</svg>
|
||||
Bluesky
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<div class="ticks"></div>
|
||||
<section id="spacer"></section>
|
||||
</template>
|
||||
51
new/frontend/src/composables/useTaskPoller.ts
Normal file
@ -0,0 +1,51 @@
import { ref, watch, onUnmounted, type Ref } from 'vue'
import { getTaskStatus } from '../api/tasks'

export function useTaskPoller(taskIdRef: Ref<string | null>) {
  const status = ref<string>('')
  const isPolling = ref(false)
  const error = ref<string | null>(null)
  const result = ref<any>(null)
  let timer: ReturnType<typeof setInterval> | null = null

  const start = () => {
    if (!taskIdRef.value) return
    isPolling.value = true
    error.value = null
    status.value = 'PENDING'
    result.value = null // clear the previous task's result so stale metrics are not shown
    timer = setInterval(async () => {
      try {
        const res = await getTaskStatus(taskIdRef.value!)
        status.value = res.status

        if (res.status === 'SUCCESS') {
          result.value = res
          stop()
        } else if (res.status === 'FAILED') {
          error.value = res.error || 'Task execution failed'
          stop()
        }
      } catch (e: any) {
        error.value = 'Network request failed; please check the backend status'
        stop()
      }
    }, 2000)
  }

  const stop = () => {
    isPolling.value = false
    if (timer) clearInterval(timer)
  }

  // Start polling automatically whenever the task ID changes
  watch(taskIdRef, (newVal) => {
    stop()
    if (newVal) start()
  })

  // Clear the timer when the component is unmounted
  onUnmounted(() => stop())

  return { status, isPolling, error, result, stop }
}
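The polling protocol that `useTaskPoller` implements (query every 2 seconds, stop on `SUCCESS` or `FAILED`) can also be exercised from a script, for example when testing the backend without the UI. The sketch below expresses the same loop in Python under stated assumptions: `poll_task` and its parameters are hypothetical names, and `fetch_status` stands in for whatever function issues `GET /api/tasks/{task_id}` and returns the decoded JSON record.

```python
import time
from typing import Any, Callable, Dict

TERMINAL_STATUSES = ("SUCCESS", "FAILED")


def poll_task(
    fetch_status: Callable[[str], Dict[str, Any]],
    task_id: str,
    interval_s: float = 2.0,
    max_attempts: int = 30,
    sleep: Callable[[float], None] = time.sleep,
) -> Dict[str, Any]:
    """Query the task until it reaches a terminal status, then return the record."""
    for _ in range(max_attempts):
        record = fetch_status(task_id)
        if record.get("status") in TERMINAL_STATUSES:
            return record
        # Wait before the next query; each request finishes before the next starts.
        sleep(interval_s)
    raise TimeoutError(f"task {task_id!r} did not finish after {max_attempts} attempts")
```

Unlike a fixed-interval timer, this loop waits for each response before scheduling the next query, so slow responses cannot pile up, and `max_attempts` bounds how long a stuck task is polled.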
9
new/frontend/src/main.ts
Normal file
@ -0,0 +1,9 @@
import { createApp } from 'vue'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import App from './App.vue'

const app = createApp(App)

app.use(ElementPlus)
app.mount('#app')
296
new/frontend/src/style.css
Normal file
@ -0,0 +1,296 @@
|
||||
:root {
|
||||
--text: #6b6375;
|
||||
--text-h: #08060d;
|
||||
--bg: #fff;
|
||||
--border: #e5e4e7;
|
||||
--code-bg: #f4f3ec;
|
||||
--accent: #aa3bff;
|
||||
--accent-bg: rgba(170, 59, 255, 0.1);
|
||||
--accent-border: rgba(170, 59, 255, 0.5);
|
||||
--social-bg: rgba(244, 243, 236, 0.5);
|
||||
--shadow:
|
||||
rgba(0, 0, 0, 0.1) 0 10px 15px -3px, rgba(0, 0, 0, 0.05) 0 4px 6px -2px;
|
||||
|
||||
--sans: system-ui, 'Segoe UI', Roboto, sans-serif;
|
||||
--heading: system-ui, 'Segoe UI', Roboto, sans-serif;
|
||||
--mono: ui-monospace, Consolas, monospace;
|
||||
|
||||
font: 18px/145% var(--sans);
|
||||
letter-spacing: 0.18px;
|
||||
color-scheme: light dark;
|
||||
color: var(--text);
|
||||
background: var(--bg);
|
||||
font-synthesis: none;
|
||||
text-rendering: optimizeLegibility;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
font-size: 16px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--text: #9ca3af;
|
||||
--text-h: #f3f4f6;
|
||||
--bg: #16171d;
|
||||
--border: #2e303a;
|
||||
--code-bg: #1f2028;
|
||||
--accent: #c084fc;
|
||||
--accent-bg: rgba(192, 132, 252, 0.15);
|
||||
--accent-border: rgba(192, 132, 252, 0.5);
|
||||
--social-bg: rgba(47, 48, 58, 0.5);
|
||||
--shadow:
|
||||
rgba(0, 0, 0, 0.4) 0 10px 15px -3px, rgba(0, 0, 0, 0.25) 0 4px 6px -2px;
|
||||
}
|
||||
|
||||
#social .button-icon {
|
||||
filter: invert(1) brightness(2);
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h1,
|
||||
h2 {
|
||||
font-family: var(--heading);
|
||||
font-weight: 500;
|
||||
color: var(--text-h);
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 56px;
|
||||
letter-spacing: -1.68px;
|
||||
margin: 32px 0;
|
||||
@media (max-width: 1024px) {
|
||||
font-size: 36px;
|
||||
margin: 20px 0;
|
||||
}
|
||||
}
|
||||
h2 {
|
||||
font-size: 24px;
|
||||
line-height: 118%;
|
||||
letter-spacing: -0.24px;
|
||||
margin: 0 0 8px;
|
||||
@media (max-width: 1024px) {
|
||||
font-size: 20px;
|
||||
}
|
||||
}
|
||||
p {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
code,
|
||||
.counter {
|
||||
font-family: var(--mono);
|
||||
display: inline-flex;
|
||||
border-radius: 4px;
|
||||
color: var(--text-h);
|
||||
}
|
||||
|
||||
code {
|
||||
font-size: 15px;
|
||||
line-height: 135%;
|
||||
padding: 4px 8px;
|
||||
background: var(--code-bg);
|
||||
}
|
||||
|
||||
.counter {
|
||||
font-size: 16px;
|
||||
padding: 5px 10px;
|
||||
border-radius: 5px;
|
||||
color: var(--accent);
|
||||
background: var(--accent-bg);
|
||||
border: 2px solid transparent;
|
||||
transition: border-color 0.3s;
|
||||
margin-bottom: 24px;
|
||||
|
||||
&:hover {
|
||||
border-color: var(--accent-border);
|
||||
}
|
||||
&:focus-visible {
|
||||
outline: 2px solid var(--accent);
|
||||
outline-offset: 2px;
|
||||
}
|
||||
}
|
||||
|
||||
.hero {
|
||||
position: relative;
|
||||
|
||||
.base,
|
||||
.framework,
|
||||
.vite {
|
||||
inset-inline: 0;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.base {
|
||||
width: 170px;
|
||||
position: relative;
|
||||
z-index: 0;
|
||||
}
|
||||
|
||||
.framework,
|
||||
.vite {
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
.framework {
|
||||
z-index: 1;
|
||||
top: 34px;
|
||||
height: 28px;
|
||||
transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg)
|
||||
scale(1.4);
|
||||
}
|
||||
|
||||
.vite {
|
||||
z-index: 0;
|
||||
top: 107px;
|
||||
height: 26px;
|
||||
width: auto;
|
||||
transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg)
|
||||
scale(0.8);
|
||||
}
|
||||
}
|
||||
|
||||
#app {
|
||||
width: 1126px;
|
||||
max-width: 100%;
|
||||
margin: 0 auto;
|
||||
text-align: center;
|
||||
border-inline: 1px solid var(--border);
|
||||
min-height: 100svh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
#center {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 25px;
|
||||
place-content: center;
|
||||
place-items: center;
|
||||
flex-grow: 1;
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
padding: 32px 20px 24px;
|
||||
gap: 18px;
|
||||
}
|
||||
}
|
||||
|
||||
#next-steps {
|
||||
display: flex;
|
||||
border-top: 1px solid var(--border);
|
||||
text-align: left;
|
||||
|
||||
& > div {
|
||||
flex: 1 1 0;
|
||||
padding: 32px;
|
||||
@media (max-width: 1024px) {
|
||||
padding: 24px 20px;
|
||||
}
|
||||
}
|
||||
|
||||
.icon {
|
||||
margin-bottom: 16px;
|
||||
width: 22px;
|
||||
height: 22px;
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
flex-direction: column;
|
||||
text-align: center;
|
||||
}
|
||||
}
|
||||
|
||||
#docs {
|
||||
border-right: 1px solid var(--border);
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
border-right: none;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
}
|
||||
|
||||
#next-steps ul {
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin: 32px 0 0;
|
||||
|
||||
.logo {
|
||||
height: 18px;
|
||||
}
|
||||
|
||||
a {
|
||||
color: var(--text-h);
|
||||
font-size: 16px;
|
||||
border-radius: 6px;
|
||||
background: var(--social-bg);
|
||||
display: flex;
|
||||
padding: 6px 12px;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
text-decoration: none;
|
||||
transition: box-shadow 0.3s;
|
||||
|
||||
&:hover {
|
||||
box-shadow: var(--shadow);
|
||||
}
|
||||
.button-icon {
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 1024px) {
|
||||
margin-top: 20px;
|
||||
flex-wrap: wrap;
|
||||
justify-content: center;
|
||||
|
||||
li {
|
||||
flex: 1 1 calc(50% - 8px);
|
||||
}
|
||||
|
||||
a {
|
||||
width: 100%;
|
||||
justify-content: center;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#spacer {
|
||||
height: 88px;
|
||||
border-top: 1px solid var(--border);
|
||||
@media (max-width: 1024px) {
|
||||
height: 48px;
|
||||
}
|
||||
}
|
||||
|
||||
.ticks {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
|
||||
&::before,
|
||||
&::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: -4.5px;
|
||||
border: 5px solid transparent;
|
||||
}
|
||||
|
||||
&::before {
|
||||
left: 0;
|
||||
border-left-color: var(--border);
|
||||
}
|
||||
&::after {
|
||||
right: 0;
|
||||
border-right-color: var(--border);
|
||||
}
|
||||
}
|
||||
14
new/frontend/tsconfig.app.json
Normal file
@ -0,0 +1,14 @@
|
||||
{
|
||||
"extends": "@vue/tsconfig/tsconfig.dom.json",
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
|
||||
"types": ["vite/client"],
|
||||
|
||||
/* Linting */
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"]
|
||||
}
|
||||
7
new/frontend/tsconfig.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"files": [],
|
||||
"references": [
|
||||
{ "path": "./tsconfig.app.json" },
|
||||
{ "path": "./tsconfig.node.json" }
|
||||
]
|
||||
}
|
||||
24
new/frontend/tsconfig.node.json
Normal file
@ -0,0 +1,24 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
|
||||
"target": "es2023",
|
||||
"lib": ["ES2023"],
|
||||
"module": "esnext",
|
||||
"types": ["node"],
|
||||
"skipLibCheck": true,
|
||||
|
||||
/* Bundler mode */
|
||||
"moduleResolution": "bundler",
|
||||
"allowImportingTsExtensions": true,
|
||||
"verbatimModuleSyntax": true,
|
||||
"moduleDetection": "force",
|
||||
"noEmit": true,
|
||||
|
||||
/* Linting */
|
||||
"noUnusedLocals": true,
|
||||
"noUnusedParameters": true,
|
||||
"erasableSyntaxOnly": true,
|
||||
"noFallthroughCasesInSwitch": true
|
||||
},
|
||||
"include": ["vite.config.ts"]
|
||||
}
|
||||
7
new/frontend/vite.config.ts
Normal file
@ -0,0 +1,7 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [vue()],
|
||||
})
|
||||
6
run_smoke.bat
Normal file
@ -0,0 +1,6 @@
|
||||
@echo off
|
||||
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=new\app\api;%PYTHONPATH%
|
||||
python -c "import _smoke_test_train; _smoke_test_train.test_load_train_df(); _smoke_test_train.test_get_model_pipeline_all_types(); _smoke_test_train.test_run_train_sync_linearregression_fast(); _smoke_test_train.test_run_train_sync_bad_csv(); _smoke_test_train.test_run_train_sync_bad_target(); print('OK')" > "%TEMP%\smoke_log.txt" 2>&1
|
||||
type "%TEMP%\smoke_log.txt"
|
||||
@ -20,23 +20,28 @@ class PipelineContext:
|
||||
"""流水线运行上下文(在 14 个 step 之间传递的内存字典)
|
||||
|
||||
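Field naming conventions:
- Main path fields all use the `_path` suffix (e.g. water_mask_path)
- Directory fields have no `_path` suffix (e.g. models_dir)
- Path field name = panel key name = step parameter name (no translation anywhere in the chain)
- Training and product files use the `_path` suffix (e.g. training_csv_path / water_mask_path)
- Input image/CSV fields keep the panel's original names (img_path / csv_path), without the old `raw_` prefix
- Directory fields have no `_path` suffix (e.g. models_dir / prediction_dir)
- Metadata fields have no suffix (e.g. user_config / status / log)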
|
||||
"""
|
||||
|
||||
# ── 9 步主路径(按 step 输出顺序排列) ──
|
||||
raw_img_path: Optional[str] = None # Step 1 入参:原始影像
|
||||
# ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参) ──
|
||||
img_path: Optional[str] = None # Step 1/2/3 入参:原始影像
|
||||
water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入
|
||||
glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入
|
||||
deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入
|
||||
raw_csv_path: Optional[str] = None # Step 4 入:原始 CSV
|
||||
csv_path: Optional[str] = None # Step 4/5/6_5/6_75 入参:原始/训练 CSV
|
||||
processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入
|
||||
training_spectra_path: Optional[str] = None # Step 5 出 → Step 6 入
|
||||
training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75 入
|
||||
boundary_path: Optional[str] = None # Step 5 入参:边界 SHP(panel step5 名)
|
||||
indices_path: Optional[str] = None # Step 5.5 出
|
||||
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/9 入
|
||||
prediction_csv_path: Optional[str] = None # Step 8 出
|
||||
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入
|
||||
prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入
|
||||
distribution_map_path: Optional[str] = None # Step 9 出
|
||||
boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHP(panel step9 名)
|
||||
formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV
|
||||
|
||||
# ── 目录类(命名不带 _path 以示区别) ──
|
||||
models_dir: Optional[str] = None
|
||||
|
||||
@ -4,10 +4,8 @@ PipelineRunner:基于 StepSpec 声明式调度 14 个 step。
|
||||
|
||||
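Design notes:
- StepSpec declares requires (a list of ctx field names) + produces (a list of ctx field names)
- Default convention: ctx field name with the `_path` suffix removed = step method parameter name
e.g. ctx.water_mask_path → parameter water_mask
e.g. ctx.raw_img_path → parameter raw_img
- Can be overridden by spec.parameter_map
- Naming convention: ctx field name == panel key name == step parameter name (no translation anywhere in the chain)
- The spec.parameter_map field is kept as a skeleton for the rare special case that needs an override (empty dict by default)
- Scheduling order: follows the PIPELINE_STEPS list; a step is skipped when any of its requires is missing
- Soft cancel: ctx.is_cancelled() is checked before each step
- Duck-typed pipeline: the runner only calls getattr(pipeline, method_name) and does not depend on a class hierarchy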
|
||||
@ -48,101 +46,76 @@ class StepSpec:
|
||||
PIPELINE_STEPS: List[StepSpec] = [
|
||||
StepSpec(
|
||||
step_id="step1", method_name="step1_generate_water_mask",
|
||||
requires=["raw_img_path"], produces=["water_mask_path"],
|
||||
# ctx.raw_img_path → 形参 img_path(老 step1 形参名是 img_path,不是 raw_img)
|
||||
parameter_map={"raw_img_path": "img_path"},
|
||||
requires=["img_path"], produces=["water_mask_path"],
|
||||
description="水域掩膜生成(NDWI 或 SHP)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step2", method_name="step2_find_glint_area",
|
||||
requires=["raw_img_path", "water_mask_path"], produces=["glint_mask_path"],
|
||||
# raw_img_path→img_path;water_mask_path 不变
|
||||
parameter_map={"raw_img_path": "img_path"},
|
||||
requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
|
||||
description="耀斑区域检测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step3", method_name="step3_remove_glint",
|
||||
requires=["deglint_img_path"], produces=["deglint_img_path"],
|
||||
# deglint_img_path→img_path(老 step3 形参名是 img_path)
|
||||
# 注意:glint_mask_path 不在 requires 中——step3 形参表无该参数,内部走 self.glint_mask_path 回退
|
||||
parameter_map={"deglint_img_path": "img_path"},
|
||||
requires=["img_path", "water_mask_path", "glint_mask_path"],
|
||||
produces=["deglint_img_path"],
|
||||
description="耀斑去除",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step4", method_name="step4_process_csv",
|
||||
requires=["raw_csv_path"], produces=["processed_csv_path"],
|
||||
# raw_csv_path→csv_path(老 step4 形参名是 csv_path)
|
||||
parameter_map={"raw_csv_path": "csv_path"},
|
||||
requires=["csv_path"], produces=["processed_csv_path"],
|
||||
description="CSV 异常值清洗",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step5", method_name="step5_extract_training_spectra",
|
||||
requires=["deglint_img_path", "processed_csv_path"], produces=["training_spectra_path"],
|
||||
# processed_csv_path→csv_path(老 step5 形参名是 csv_path);deglint_img_path 不变
|
||||
parameter_map={"processed_csv_path": "csv_path"},
|
||||
requires=["deglint_img_path", "csv_path", "boundary_path", "glint_mask_path"],
|
||||
produces=["training_csv_path"],
|
||||
description="实测样本点光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step5_5", method_name="step5_5_calculate_water_quality_indices",
|
||||
requires=["training_spectra_path"], produces=["indices_path"],
|
||||
# 老 step5.5 形参是 training_spectra_path;ctx 字段同名,无需映射
|
||||
parameter_map={},
|
||||
requires=["training_csv_path"], produces=["indices_path"],
|
||||
description="水质光谱指数计算(optional)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6", method_name="step6_train_models",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→training_csv_path(老 step6 形参名是 training_csv_path)
|
||||
parameter_map={"training_spectra_path": "training_csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="ML 建模(GridSearchCV / AutoML)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6_5", method_name="step6_5_non_empirical_modeling",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→csv_path(老 step6.5 形参名是 csv_path)
|
||||
parameter_map={"training_spectra_path": "csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="非经验统计回归",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step6_75", method_name="step6_75_custom_regression",
|
||||
requires=["training_spectra_path"], produces=["models_dir"],
|
||||
# training_spectra_path→csv_path(老 step6.75 形参名是 csv_path)
|
||||
parameter_map={"training_spectra_path": "csv_path"},
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
description="自定义回归分析",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step7", method_name="step7_generate_sampling_points",
|
||||
requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"],
|
||||
# 老 step7 形参是 deglint_img_path / water_mask_path;ctx 字段同名
|
||||
parameter_map={},
|
||||
description="整景密集采样点生成 + 光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8", method_name="step8_predict_water_quality",
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"],
|
||||
parameter_map={},
|
||||
description="ML 模型预测(采样点)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8_5", method_name="step8_5_predict_with_non_empirical_models",
|
||||
requires=["sampling_csv_path"], produces=["prediction_dir"],
|
||||
parameter_map={},
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_dir"],
|
||||
description="非经验模型预测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8_75", method_name="step8_75_predict_with_custom_regression",
|
||||
requires=["sampling_csv_path"], produces=["prediction_dir"],
|
||||
parameter_map={},
|
||||
requires=["sampling_csv_path", "models_dir", "formula_csv_path"],
|
||||
produces=["prediction_dir"],
|
||||
description="自定义回归预测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step9", method_name="step9_generate_distribution_map",
|
||||
requires=["prediction_csv_path"],
|
||||
requires=["prediction_csv_path", "boundary_shp_path"],
|
||||
produces=["distribution_map_path"],
|
||||
# 老 step9 形参是 prediction_csv_path / boundary_shp_path;ctx 字段同名
|
||||
# 注意:sampling_csv_path / water_mask_path 不在 requires 中——step9 形参表无该参数,
|
||||
# 内部走 self.sampling_csv_path / self.water_mask_path 回退
|
||||
parameter_map={},
|
||||
description="克里金插值成图",
|
||||
),
|
||||
]
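The name-based dispatch described above (ctx field name == step parameter name, `parameter_map` only for rare overrides, skip when a `requires` entry is missing) can be sketched as follows. This is a minimal illustration, not the project's runner: `Spec`, `Ctx`, and `build_kwargs` are hypothetical stand-ins, and the file names are made up.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Spec:
    """Hypothetical stand-in for StepSpec; only the fields used below."""
    step_id: str
    method_name: str
    requires: List[str]
    produces: List[str]
    parameter_map: Dict[str, str] = field(default_factory=dict)


@dataclass
class Ctx:
    """Hypothetical two-field context, just enough for the demo."""
    img_path: Optional[str] = None
    water_mask_path: Optional[str] = None


def build_kwargs(spec: Spec, ctx: Any) -> Optional[Dict[str, Any]]:
    """Collect step kwargs from ctx; return None when a required field is unset."""
    kwargs: Dict[str, Any] = {}
    for name in spec.requires:
        value = getattr(ctx, name, None)
        if value is None:
            return None  # the caller would skip this step
        # Identity mapping by default; parameter_map only renames special cases.
        kwargs[spec.parameter_map.get(name, name)] = value
    return kwargs


step2 = Spec(
    "step2", "step2_find_glint_area",
    requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
)

ready = build_kwargs(step2, Ctx(img_path="scene.tif", water_mask_path="mask.tif"))
skipped = build_kwargs(step2, Ctx(img_path="scene.tif"))
```

With an empty `parameter_map`, the kwargs keys are exactly the ctx field names, so a step method can be called as `getattr(pipeline, spec.method_name)(**kwargs)` without any renaming table.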
|
||||
@ -157,7 +130,7 @@ class PipelineRunner:
|
||||
|
||||
用法:
|
||||
runner = PipelineRunner(pipeline_instance)
|
||||
ctx = PipelineContext(raw_img_path=..., ...)
|
||||
ctx = PipelineContext(img_path=..., ...)
|
||||
result_ctx = runner.run(ctx)
|
||||
"""
|
||||
|
||||
|
||||
544
src/core/prediction/automl_trainer.py
Normal file
@ -0,0 +1,544 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
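Optuna + smart-subsampling AutoML trainer (Route B: the engine that guards against blow-ups).

Why this exists:
- Legacy path: 11 preprocessing methods × 4 models × 3 splits = 132 GridSearchCV runs.
This takes 10+ minutes on small and medium datasets and runs out of memory (OOM) on large ones (50,000+ rows).
- AutoML path: 1 preprocessing method × N models (hyperparameters tuned by Optuna), with smart subsampling to avoid OOM,
then a refit on the **full data** with the best hyperparameters, and finally a single model is saved.

Design notes:
- Entry point: train_with_automl(training_csv_path, feature_start_column, model_names=..., ...)
- Returns AutoMLResult dataclasses (one per target column)
- smart_subsample: random downsampling when N > max_samples
- Failure fallback: optuna not installed / every trial failed → fall back to WaterQualityModelingBatch
- File naming convention: {target}_{preprocess}_{model}_AUTOML.joblib
- save_data["metadata"]["automl"] = True marks models produced by AutoML

Usage: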
|
||||
from src.core.prediction.automl_trainer import train_with_automl
|
||||
results = train_with_automl(
|
||||
training_csv_path=".../training_spectra.csv",
|
||||
feature_start_column="374.285004",
|
||||
model_names=["RF", "SVR", "Ridge"],
|
||||
n_trials=20,
|
||||
timeout_sec=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Constants
|
||||
# ============================================================
|
||||
|
||||
# Maximum number of samples allowed during the AutoML search phase (avoids OOM)
# 5000 samples are enough for the Optuna search to give a stable CV score for RF/SVR/Ridge
|
||||
DEFAULT_MAX_SAMPLES = 5000
|
||||
|
||||
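# Default search time budget in seconds per target column; it is split evenly across the candidate models (not a per-trial timeout)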
|
||||
DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
# Default number of Optuna trials per model
|
||||
DEFAULT_N_TRIALS = 20
|
||||
|
||||
# Suffix for the AutoML output directory name
|
||||
AUTOML_DIR_SUFFIX = "_AutoML"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data classes
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class AutoMLResult:
|
||||
"""单个目标列的 AutoML 训练结果"""
|
||||
success: bool = False
|
||||
model_path: Optional[str] = None
|
||||
cv_score: float = -float("inf")
|
||||
best_params: Optional[Dict[str, Any]] = None
|
||||
target_column: str = ""
|
||||
preprocessing: str = ""
|
||||
model_name: str = ""
|
||||
n_trials_done: int = 0
|
||||
n_samples_used: int = 0
|
||||
fallback_used: bool = False
|
||||
elapsed_sec: float = 0.0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Smart subsampling
|
||||
# ============================================================
|
||||
|
||||
def smart_subsample(
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
) -> Tuple[np.ndarray, np.ndarray, bool]:
|
||||
"""当 N > max_samples 时随机下采样;否则原样返回。
|
||||
|
||||
Returns:
|
||||
(X_sub, y_sub, was_subsampled)
|
||||
"""
|
||||
n = X.shape[0]
|
||||
if n <= max_samples:
|
||||
return X, y, False
|
||||
rng = np.random.default_rng(random_state)
|
||||
idx = rng.choice(n, size=max_samples, replace=False)
|
||||
return X[idx], y[idx], True
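The contract of `smart_subsample` (no-op at or below the threshold, sampling without replacement above it, reproducible for a fixed seed) can be checked with a small stand-alone sketch. This variant is an assumption for illustration only: it uses the standard library's `random.Random.sample` on plain lists instead of NumPy's `Generator.choice`, and `subsample_rows` is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple


def subsample_rows(
    X: Sequence[Sequence[float]],
    y: Sequence[float],
    max_samples: int,
    random_state: int = 42,
) -> Tuple[List[Sequence[float]], List[float], bool]:
    """Return (X_sub, y_sub, was_subsampled), keeping X and y rows aligned."""
    n = len(X)
    if n <= max_samples:
        return list(X), list(y), False
    rng = random.Random(random_state)
    idx = rng.sample(range(n), max_samples)  # indices drawn without replacement
    return [X[i] for i in idx], [y[i] for i in idx], True


# Toy data: row i is [i, 2i] with target i, so alignment is easy to verify.
X = [[float(i), float(i) * 2.0] for i in range(10)]
y = [float(i) for i in range(10)]

X_small, y_small, flag_small = subsample_rows(X, y, max_samples=20)
X_sub, y_sub, flag_sub = subsample_rows(X, y, max_samples=4)
```

Because the same index list selects from both `X` and `y`, each sampled feature row stays paired with its own target value, which is the property the search phase depends on.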
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Model factory
|
||||
# ============================================================
|
||||
|
||||
def _build_model(model_name: str, random_state: int = 42):
|
||||
"""根据英文模型键名构造 sklearn-compatible 模型实例(factory)。"""
|
||||
from sklearn.ensemble import (
|
||||
AdaBoostRegressor, ExtraTreesRegressor, GradientBoostingRegressor,
|
||||
RandomForestRegressor,
|
||||
)
|
||||
from sklearn.linear_model import (
|
||||
ElasticNet, Lasso, LinearRegression, Ridge,
|
||||
)
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
|
||||
factory = {
|
||||
"RF": lambda **kw: RandomForestRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"ET": lambda **kw: ExtraTreesRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"GradientBoosting": lambda **kw: GradientBoostingRegressor(random_state=random_state, **kw),
|
||||
"AdaBoost": lambda **kw: AdaBoostRegressor(random_state=random_state, **kw),
|
||||
"Ridge": lambda **kw: Ridge(**kw),
|
||||
"Lasso": lambda **kw: Lasso(max_iter=5000, **kw),
|
||||
"ElasticNet": lambda **kw: ElasticNet(max_iter=5000, **kw),
|
||||
"LinearRegression": lambda **kw: LinearRegression(**kw),
|
||||
"SVR": lambda **kw: SVR(**kw),
|
||||
"KNN": lambda **kw: KNeighborsRegressor(n_jobs=1, **kw),
|
||||
"MLP": lambda **kw: MLPRegressor(max_iter=500, random_state=random_state, **kw),
|
||||
"DecisionTree": lambda **kw: DecisionTreeRegressor(random_state=random_state, **kw),
|
||||
"PLS": None, # sklearn.cross_decomposition.PLSRegression 暂未集成
|
||||
}
|
||||
builder = factory.get(model_name)
|
||||
if builder is None:
|
||||
return None
|
||||
return builder  # a callable, not an instance: builder(**params) returns the estimator
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Optuna hyperparameter search space
|
||||
# ============================================================
|
||||
|
||||
def _get_search_space(model_name: str, trial) -> Dict[str, Any]:
|
||||
"""按模型名返回 Optuna 超参 search space。"""
|
||||
sp: Dict[str, Any] = {}
|
||||
if model_name == "RF":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
sp["min_samples_leaf"] = trial.suggest_int("min_samples_leaf", 1, 5)
|
||||
elif model_name == "ET":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
elif model_name == "GradientBoosting":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 8)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
|
||||
elif model_name == "SVR":
|
||||
sp["C"] = trial.suggest_float("C", 0.1, 100.0, log=True)
|
||||
sp["epsilon"] = trial.suggest_float("epsilon", 0.001, 1.0, log=True)
|
||||
sp["kernel"] = trial.suggest_categorical("kernel", ["rbf", "linear"])
|
||||
elif model_name == "KNN":
|
||||
sp["n_neighbors"] = trial.suggest_int("n_neighbors", 3, 20)
|
||||
sp["weights"] = trial.suggest_categorical("weights", ["uniform", "distance"])
|
||||
elif model_name in ("Ridge", "Lasso", "ElasticNet"):
|
||||
sp["alpha"] = trial.suggest_float("alpha", 0.01, 100.0, log=True)
|
||||
if model_name == "ElasticNet":
|
||||
sp["l1_ratio"] = trial.suggest_float("l1_ratio", 0.0, 1.0)
|
||||
elif model_name == "MLP":
|
||||
sp["hidden_layer_sizes"] = trial.suggest_categorical(
|
||||
"hidden_layer_sizes", [(50,), (100,), (50, 50), (100, 50)]
|
||||
)
|
||||
sp["alpha"] = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
|
||||
sp["learning_rate_init"] = trial.suggest_float("learning_rate_init", 1e-4, 1e-2, log=True)
|
||||
elif model_name == "DecisionTree":
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
elif model_name == "AdaBoost":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 30, 200, step=30)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 1.0, log=True)
|
||||
else:
pass  # no tunable hyperparameters (e.g. LinearRegression, which would reject an n_estimators kwarg and fail every trial)
|
||||
return sp
|
||||
|
||||
|
||||
def _make_objective(model_name: str, X: np.ndarray, y: np.ndarray,
|
||||
cv_folds: int, random_state: int):
|
||||
"""构造 Optuna objective(5 折 CV R²)。"""
|
||||
from sklearn.model_selection import KFold, cross_val_score
|
||||
|
||||
def objective(trial):
|
||||
params = _get_search_space(model_name, trial)
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return -1.0
|
||||
model = builder(**params)
|
||||
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
scores = cross_val_score(model, X, y, cv=kf, scoring="r2", n_jobs=1)
|
||||
return float(np.mean(scores))
|
||||
except Exception:
|
||||
return -1.0  # sentinel for a failed trial; train_with_automl treats best_value <= -1.0 as "every trial failed"
|
||||
|
||||
return objective
|
||||
|
||||
|
||||
def _refit_full(model_name: str, best_params: Dict[str, Any],
|
||||
X: np.ndarray, y: np.ndarray, random_state: int):
|
||||
"""用 best params 在**全量数据**上 refit。"""
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return None
|
||||
model = builder(**best_params)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Failure fallback (back to the legacy GridSearchCV path)
|
||||
# ============================================================
|
||||
|
||||
def _fallback_train(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing: str,
|
||||
model_name: str,
|
||||
split_method: str,
|
||||
cv_folds: int,
|
||||
output_dir: Path,
|
||||
target_column: str,
|
||||
) -> AutoMLResult:
|
||||
"""AutoML 失败时调老 WaterQualityModelingBatch。
|
||||
|
||||
返回的 AutoMLResult.fallback_used=True。
|
||||
"""
|
||||
try:
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
except ImportError as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 导入失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
try:
|
||||
out_dir = output_dir / preprocessing
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
modeler = WaterQualityModelingBatch(str(out_dir))
|
||||
modeler.train_models_batch(
|
||||
csv_path=training_csv_path,
|
||||
feature_start_column=feature_start_column,
|
||||
preprocessing_methods=[preprocessing],
|
||||
model_names=[model_name],
|
||||
split_methods=[split_method],
|
||||
cv_folds=cv_folds,
|
||||
)
|
||||
# Locate the model file the fallback produced
|
||||
candidates = list(out_dir.rglob(f"{target_column}_{preprocessing}_{model_name}.joblib"))
|
||||
model_path = str(candidates[0]) if candidates else None
|
||||
return AutoMLResult(
|
||||
success=model_path is not None,
|
||||
model_path=model_path,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
fallback_used=True,
|
||||
metadata={"source": "WaterQualityModelingBatch"},
|
||||
)
|
||||
except Exception as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main entry point
|
||||
# ============================================================
|
||||
|
||||
def train_with_automl(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
model_names: Optional[List[str]] = None,
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
n_trials: int = DEFAULT_N_TRIALS,
|
||||
timeout_sec: float = DEFAULT_TIMEOUT,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
callback: Optional[Callable[[str, str, str], None]] = None,
|
||||
) -> List[AutoMLResult]:
|
||||
"""用 Optuna + 子采样跑 AutoML。失败时自动回退到 GridSearchCV。
|
||||
|
||||
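Args:
training_csv_path: Training CSV (the Step 5 product, training_spectra.csv)
feature_start_column: Name or index of the first feature column (numeric columns before it are treated as targets y)
preprocessing_methods: Candidate preprocessing methods (**only the first is used**, to avoid a Cartesian blow-up)
model_names: Candidate models (Optuna runs once for each)
split_methods: Candidate data splits (AutoML uses only the first)
cv_folds: Number of cross-validation folds
output_dir: Output directory (defaults to ./7_Supervised_Model_Training_AutoML)
n_trials: Number of Optuna trials per model
timeout_sec: Search time budget in seconds per target column, split evenly across the models (at least 10 s each); a model's search stops once its share expires
max_samples: Maximum number of samples allowed during the search phase
random_state: Seed shared by the subsampling, the CV shuffling, and the Optuna sampler
callback: Status callback, called as callback(step_name, status, message)

Returns:
List[AutoMLResult], one result per target column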
|
||||
"""
|
||||
def notify(status: str, msg: str = "") -> None:
|
||||
if callback:
|
||||
callback("步骤6_AutoML", status, msg)
|
||||
|
||||
# ---- 1) Parameter defaults ----
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["MMS"]
|
||||
if model_names is None:
|
||||
model_names = ["RF", "SVR", "Ridge"]
|
||||
if split_methods is None:
|
||||
split_methods = ["spxy"]
|
||||
|
||||
# Decision: use only the first preprocessing method and the first split, to avoid a Cartesian blow-up
|
||||
preproc = preprocessing_methods[0]
|
||||
split_method = split_methods[0]
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = "./7_Supervised_Model_Training_AutoML"
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
preproc_dir = out_dir / preproc
|
||||
preproc_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- 2) Load the data ----
|
||||
notify("start", f"AutoML 训练开始 (n_trials={n_trials}, timeout={timeout_sec}s, max_samples={max_samples})")
|
||||
if not Path(training_csv_path).exists():
|
||||
return [AutoMLResult(success=False, error=f"训练 CSV 不存在: {training_csv_path}")]
|
||||
|
||||
df = pd.read_csv(training_csv_path)
|
||||
|
||||
# Extract the target columns (all numeric columns before feature_start_column)
|
||||
if isinstance(feature_start_column, int):
|
||||
y_cols = [c for c in df.columns[:feature_start_column]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
else:
|
||||
try:
|
||||
idx = list(df.columns).index(feature_start_column)
|
||||
y_cols = [c for c in df.columns[:idx]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
except ValueError:
|
||||
y_cols = []
|
||||
|
||||
if not y_cols:
|
||||
notify("error", "AutoML: 未识别出目标列(feature_start_column 之前的所有数值列)")
|
||||
return [AutoMLResult(success=False, error="未识别出目标列")]
|
||||
|
||||
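# Features: numeric columns that are not targets; non-numeric columns (e.g. a sample-name column) would break astype(np.float64)
feat_cols = [c for c in df.columns if c not in y_cols and pd.api.types.is_numeric_dtype(df[c])]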
|
||||
X_all = df[feat_cols].values.astype(np.float64)
|
||||
|
||||
# ---- 3) Preprocessing (first method only) ----
|
||||
if preproc != "None":
|
||||
try:
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
processed = Preprocessing(preproc, df[feat_cols])
|
||||
if isinstance(processed, pd.DataFrame):
|
||||
X_all = processed.values.astype(np.float64)
|
||||
else:
|
||||
X_all = np.asarray(processed, dtype=np.float64)
|
||||
except Exception as e:
|
||||
notify("warning", f"预处理 {preproc} 失败: {e!r},改用 None")
|
||||
preproc = "None"
|
||||
|
||||
# ---- 4) Check whether Optuna is available ----
|
||||
try:
|
||||
import optuna
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
optuna_available = True
|
||||
except ImportError:
|
||||
optuna_available = False
|
||||
notify("warning", "optuna 未安装,全目标列回退到 GridSearchCV(pip install \"optuna>=3.6\")")
|
||||
|
||||
# ---- 5) Run target by target ----
|
||||
results: List[AutoMLResult] = []
|
||||
total = len(y_cols)
|
||||
per_model_timeout = max(10.0, timeout_sec / max(1, len(model_names)))
|
||||
|
||||
for ti, tgt in enumerate(y_cols, 1):
|
||||
t0 = time.time()
|
||||
yv = df[tgt].values.astype(np.float64)
|
||||
mask = ~np.isnan(yv)
|
||||
X_t = X_all[mask]
|
||||
y_t = yv[mask]
|
||||
|
||||
if X_t.shape[0] < cv_folds * 2:
|
||||
notify("warning", f"目标 {tgt}: 有效样本 {X_t.shape[0]} 不足,跳过")
|
||||
results.append(AutoMLResult(
|
||||
success=False, target_column=tgt, error=f"样本不足({X_t.shape[0]})",
|
||||
preprocessing=preproc,
|
||||
))
|
||||
continue
|
||||
|
||||
X_sub, y_sub, was_sub = smart_subsample(X_t, y_t, max_samples=max_samples, random_state=random_state)
|
||||
if was_sub:
|
||||
notify("info", f"目标 {tgt}: {X_t.shape[0]} 样本 → 子采样 {X_sub.shape[0]}(寻优用)")
|
||||
|
||||
best_overall = AutoMLResult(success=False, target_column=tgt, preprocessing=preproc)
|
||||
|
||||
if not optuna_available:
|
||||
# Optuna unavailable: fall back for this target column (this happens for every target)
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
else:
|
||||
for model_name in model_names:
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
notify("warning", f"模型 {model_name} 暂不支持 AutoML 寻优")
|
||||
continue
|
||||
|
||||
study = optuna.create_study(
|
||||
direction="maximize",
|
||||
sampler=optuna.samplers.TPESampler(seed=random_state),
|
||||
)
|
||||
study.optimize(
|
||||
_make_objective(model_name, X_sub, y_sub, cv_folds, random_state),
|
||||
n_trials=n_trials,
|
||||
timeout=per_model_timeout,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
if study.best_value is None or study.best_value <= -1.0:
|
||||
notify("warning", f"{tgt}/{model_name}: 全部 trial 失败(CV 全部 <= -1)")
|
||||
continue
|
||||
|
||||
# refit on FULL
|
||||
final_model = _refit_full(model_name, study.best_params, X_t, y_t, random_state)
|
||||
if final_model is None:
|
||||
continue
|
||||
|
||||
# Save
|
||||
import joblib
|
||||
fname = f"{tgt}_{preproc}_{model_name}_AUTOML.joblib"
|
||||
fpath = preproc_dir / fname
|
||||
joblib.dump({
|
||||
"model": final_model,
|
||||
"target_column_name": tgt,
|
||||
"preprocess_method": preproc,
|
||||
"model_name": model_name,
|
||||
"metadata": {
|
||||
"automl": True,
|
||||
"best_params": study.best_params,
|
||||
"cv_score": float(study.best_value),
|
||||
"n_trials_done": len(study.trials),
|
||||
"n_samples_used_full": int(X_t.shape[0]),
|
||||
"n_samples_used_for_search": int(X_sub.shape[0]),
|
||||
"was_subsampled": was_sub,
|
||||
"split_method": split_method,
|
||||
},
|
||||
}, fpath)
|
||||
|
||||
cand = AutoMLResult(
|
||||
success=True,
|
||||
model_path=str(fpath),
|
||||
cv_score=float(study.best_value),
|
||||
best_params=study.best_params,
|
||||
target_column=tgt,
|
||||
preprocessing=preproc,
|
||||
model_name=model_name,
|
||||
n_trials_done=len(study.trials),
|
||||
n_samples_used=int(X_sub.shape[0]),
|
||||
metadata={"refit_on_full": True, "n_samples_full": int(X_t.shape[0])},
|
||||
)
|
||||
if cand.cv_score > best_overall.cv_score:
|
||||
best_overall = cand
|
||||
except Exception as e:
|
||||
notify("warning", f"目标 {tgt} / 模型 {model_name} 失败: {e!r}")
|
||||
continue
|
||||
|
||||
if not best_overall.success:
|
||||
notify("warning", f"目标 {tgt} 全部 Optuna trial 失败,回退 GridSearchCV")
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
|
||||
best_overall.elapsed_sec = time.time() - t0
|
||||
results.append(best_overall)
|
||||
notify("info", f"AutoML 目标 {tgt} 完成 ({ti}/{total}) cv={best_overall.cv_score:.4f}")
|
||||
|
||||
# ---- 6) Summary JSON ----
|
||||
summary_path = out_dir / "automl_summary.json"
|
||||
try:
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(r) for r in results], f, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
notify("warning", f"写 automl_summary.json 失败: {e!r}")
|
||||
|
||||
success_n = sum(1 for r in results if r.success)
|
||||
fallback_n = sum(1 for r in results if r.fallback_used)
|
||||
notify("completed", f"AutoML 训练完成 {success_n}/{len(results)} 成功({fallback_n} 走 fallback),汇总 {summary_path}")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI self-test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(description="AutoML 训练器 CLI 自测")
|
||||
p.add_argument("--csv", required=True, help="训练用 CSV(feature_start_column 之前的列为目标 y)")
|
||||
p.add_argument("--feature-start", default="0", help="特征起始列名或索引(默认 0)")
|
||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
|
||||
args = p.parse_args()
|
||||
|
||||
# Infer the type of feature_start_column: an integer index if it parses as int, otherwise a column name
|
||||
fsc: Any = args.feature_start
|
||||
try:
|
||||
fsc = int(fsc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
res = train_with_automl(
|
||||
training_csv_path=args.csv,
|
||||
feature_start_column=fsc,
|
||||
n_trials=args.n_trials,
|
||||
timeout_sec=args.timeout,
|
||||
max_samples=args.max_samples,
|
||||
output_dir=args.out,
|
||||
)
|
||||
print(f"\n训练完成 {len(res)} 个目标")
|
||||
for r in res:
|
||||
marker = "✓" if r.success else "✗"
|
||||
fb = " [fallback]" if r.fallback_used else ""
|
||||
print(f" {marker} {r.target_column}: cv={r.cv_score:.4f} path={r.model_path}{fb}")
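Review note: the per-target loop above keeps the candidate with the highest CV score and only calls the GridSearchCV fallback when no candidate succeeded. A minimal sketch of that selection pattern follows. It assumes the starting value is a placeholder with `success=False` and a score of negative infinity (the real initialisation of `best_overall` is outside this hunk); `Candidate` and `pick_best` are hypothetical names, not part of the commit.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Optional


@dataclass
class Candidate:
    """Hypothetical stand-in for AutoMLResult; only the fields used here."""
    success: bool = False
    cv_score: float = float("-inf")
    model_name: str = ""
    fallback_used: bool = False


def pick_best(candidates: Iterable[Optional[Candidate]],
              fallback: Callable[[], Candidate]) -> Candidate:
    """Return the successful candidate with the highest cv_score.

    Failed or missing candidates are skipped. If nothing succeeded,
    the fallback callable is invoked exactly once.
    """
    best = Candidate()  # assumed placeholder: success=False, cv_score=-inf
    for cand in candidates:
        if cand is None or not cand.success:
            continue
        if cand.cv_score > best.cv_score:
            best = cand
    if not best.success:
        best = fallback()
    return best
```

Starting from a failed placeholder means the fallback check after the loop needs no separate counter: `best.success` is only true if at least one candidate replaced the placeholder.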
|
||||
@@ -126,7 +126,7 @@ class DataPreparationStep:
|
||||
|
||||
@staticmethod
|
||||
def calculate_water_quality_indices(
|
||||
- training_spectra_path: Optional[str] = None,
+ training_csv_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
@@ -153,8 +153,8 @@ class DataPreparationStep:
|
||||
notify("skipped", "跳过水质指数计算")
|
||||
return None
|
||||
|
||||
- if training_spectra_path is None:
-     raise ValueError("必须提供 training_spectra_path 参数")
+ if training_csv_path is None:
+     raise ValueError("必须提供 training_csv_path 参数")
|
||||
if formula_csv_file is None:
|
||||
raise ValueError("必须提供 formula_csv_file 参数")
|
||||
|
||||
@@ -170,7 +170,7 @@ class DataPreparationStep:
|
||||
|
||||
from src.utils.band_math import BandMathCalculator
|
||||
|
||||
- calculator = BandMathCalculator(training_spectra_path)
+ calculator = BandMathCalculator(training_csv_path)
|
||||
result_df = calculator.process_formulas_from_csv(
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
|
||||
@@ -173,7 +173,7 @@ class WaterQualityInversionPipeline:
|
||||
self.interpolated_img_path = None # 存储插值后的影像路径
|
||||
self.deglint_img_path = None
|
||||
self.processed_csv_path = None
|
||||
- self.training_spectra_path = None
+ self.training_csv_path = None
|
||||
self.indices_path = None
|
||||
self.custom_regression_path = None
|
||||
|
||||
@@ -511,7 +511,7 @@ class WaterQualityInversionPipeline:
|
||||
left_shoulder_wave: Optional[float] = None,
|
||||
valley_wave: Optional[float] = None,
|
||||
right_shoulder_wave: Optional[float] = None,
|
||||
- water_mask: Optional[Union[str, np.ndarray]] = None,
+ water_mask_path: Optional[Union[str, np.ndarray]] = None,
|
||||
interpolate_zeros: bool = False,
|
||||
interpolation_method: str = 'nearest',
|
||||
enabled: bool = True,
|
||||
@@ -546,7 +546,7 @@ class WaterQualityInversionPipeline:
|
||||
left_shoulder_wave=left_shoulder_wave,
|
||||
valley_wave=valley_wave,
|
||||
right_shoulder_wave=right_shoulder_wave,
|
||||
- water_mask=water_mask,
+ water_mask=water_mask_path,
|
||||
interpolate_zeros=interpolate_zeros,
|
||||
interpolation_method=interpolation_method,
|
||||
enabled=enabled,
|
||||
@@ -655,13 +655,13 @@ class WaterQualityInversionPipeline:
|
||||
water_mask_path=self.water_mask_path,
|
||||
output_dir=str(self.training_spectra_dir),
|
||||
)
|
||||
- self.training_spectra_path = result
+ self.training_csv_path = result
|
||||
self._record_step_time("步骤5: 提取训练样本点光谱", 0, 0)
|
||||
self._notify("completed", f"训练光谱数据已保存: {result}")
|
||||
return result
|
||||
|
||||
def step5_5_calculate_water_quality_indices(self,
|
||||
- training_spectra_path: Optional[str] = None,
+ training_csv_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
@@ -673,7 +673,7 @@ class WaterQualityInversionPipeline:
|
||||
使用band_math.py中的方法实现,支持从公式CSV文件中批量计算指定公式
|
||||
|
||||
Args:
|
||||
- training_spectra_path: 训练光谱数据CSV路径(如果为None,使用步骤5的结果)
+ training_csv_path: 训练光谱数据CSV路径(如果为None,使用步骤5的结果)
|
||||
formula_csv_file: 公式CSV文件路径,包含公式名称和具体公式
|
||||
formula_names: 要计算的公式名称列表,如果为None则计算所有公式
|
||||
output_file: 输出文件完整路径(支持绝对路径),如果为None则使用默认路径
|
||||
@@ -682,16 +682,16 @@ class WaterQualityInversionPipeline:
|
||||
包含计算结果的新CSV文件路径
|
||||
"""
|
||||
# 参数解析(保留原逻辑)
|
||||
- if training_spectra_path is not None:
-     csv_path = training_spectra_path
- elif self.training_spectra_path is not None:
-     csv_path = self.training_spectra_path
+ if training_csv_path is not None:
+     csv_path = training_csv_path
+ elif self.training_csv_path is not None:
+     csv_path = self.training_csv_path
|
||||
else:
|
||||
csv_path = None
|
||||
|
||||
self._notify("started", "步骤5.5: 计算水质光谱指数")
|
||||
result = DataPreparationStep.calculate_water_quality_indices(
|
||||
- training_spectra_path=csv_path,
+ training_csv_path=csv_path,
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
output_file=output_file,
|
||||
@@ -727,8 +727,8 @@ class WaterQualityInversionPipeline:
|
||||
# 参数解析(保留原逻辑)
|
||||
if training_csv_path is not None:
|
||||
final_csv_path = training_csv_path
|
||||
- elif self.training_spectra_path is not None:
-     final_csv_path = self.training_spectra_path
+ elif self.training_csv_path is not None:
+     final_csv_path = self.training_csv_path
|
||||
else:
|
||||
final_csv_path = None
|
||||
|
||||
@@ -911,7 +911,7 @@ class WaterQualityInversionPipeline:
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
- training_csv_path = self.training_spectra_path
+ training_csv_path = self.training_csv_path
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供训练数据CSV路径,或先执行步骤5")
|
||||
|
||||
@@ -1033,7 +1033,7 @@ class WaterQualityInversionPipeline:
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
- csv_path = self.training_spectra_path
+ csv_path = self.training_csv_path
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供CSV文件路径,或先执行步骤5")
|
||||
|
||||
@@ -1506,7 +1506,7 @@ class WaterQualityInversionPipeline:
|
||||
if 'step5' in config:
|
||||
self._notify("步骤5: 光谱提取", "start")
|
||||
self.step5_extract_training_spectra(**config['step5'])
|
||||
- self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_spectra_path})")
+ self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_csv_path})")
|
||||
else:
|
||||
self._notify("步骤5: 光谱提取", "skipped", "未配置")
|
||||
|
||||
@@ -1615,7 +1615,7 @@ class WaterQualityInversionPipeline:
|
||||
|
||||
# 生成散点图
|
||||
if 'visualization' in config and config['visualization'].get('generate_scatter', True):
|
||||
- if self.training_spectra_path and self.models_dir.exists():
+ if self.training_csv_path and self.models_dir.exists():
|
||||
try:
|
||||
self._notify("可视化", "info", "生成模型评估散点图...")
|
||||
scatter_config = config['visualization'].get('scatter_config', {})
|
||||
@@ -1653,7 +1653,7 @@ class WaterQualityInversionPipeline:
|
||||
|
||||
# 生成光谱曲线图
|
||||
if 'visualization' in config and config['visualization'].get('generate_spectrum', True):
|
||||
- if self.training_spectra_path:
+ if self.training_csv_path:
|
||||
try:
|
||||
self._notify("可视化", "info", "生成光谱曲线对比图...")
|
||||
spectrum_paths = self.generate_spectrum_comparison_plots(
|
||||
@@ -1701,7 +1701,7 @@ class WaterQualityInversionPipeline:
|
||||
pipeline_info['step2'] = {'status': 'completed', 'output_file': str(self.glint_mask_path) if self.glint_mask_path else 'N/A'}
|
||||
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
|
||||
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
|
||||
- pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_spectra_path) if self.training_spectra_path else 'N/A'}
+ pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
|
||||
pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
|
||||
pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
|
||||
pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
|
||||
@@ -1784,8 +1784,8 @@ class WaterQualityInversionPipeline:
|
||||
# 参数解析(保留原逻辑)
|
||||
if csv_path is not None:
|
||||
final_csv_path = csv_path
|
||||
- elif self.training_spectra_path is not None:
-     final_csv_path = self.training_spectra_path
+ elif self.training_csv_path is not None:
+     final_csv_path = self.training_csv_path
|
||||
else:
|
||||
final_csv_path = None
|
||||
|
||||
@@ -2109,7 +2109,7 @@ def main():
|
||||
'interpolation_method': 'bilinear', # 插值方法: 'nearest'(邻近), 'bilinear'(双线性),
|
||||
# 'spline'(样条), 'kriging'(克里金)
|
||||
# 水域掩膜参数(可选):
|
||||
- 'water_mask':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜,也可以提供:
+ 'water_mask_path':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜,也可以提供:
|
||||
# # - numpy数组
|
||||
# # - 栅格文件路径(.dat/.tif)
|
||||
# # - shapefile路径(.shp)
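Review note: because the step methods are called as `step(**config['stepN'])`, renaming a keyword parameter (here `water_mask` to `water_mask_path`, and `training_spectra_path` to `training_csv_path`) makes any saved config that still uses the old key fail with a `TypeError`. One possible mitigation, sketched below, is to migrate old keys before the call. `RENAMED_KEYS` and `migrate_config_keys` are hypothetical names and not part of this commit; the key pairs are only the renames visible in this diff.

```python
from typing import Any, Dict

# Old-name -> new-name pairs, taken from the renames visible in this diff.
RENAMED_KEYS = {
    "water_mask": "water_mask_path",
    "training_spectra_path": "training_csv_path",
}


def migrate_config_keys(step_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a step config with legacy keys renamed.

    If both the old and the new key are present, the value stored
    under the new key wins, regardless of insertion order.
    """
    migrated: Dict[str, Any] = {}
    for key, value in step_cfg.items():
        new_key = RENAMED_KEYS.get(key, key)
        if new_key != key and new_key in migrated:
            # The new-style key was already copied; keep its value.
            continue
        migrated[new_key] = value
    return migrated
```

A caller could then write `self.step3_...(**migrate_config_keys(config['step3']))` during a transition period, which keeps older config files working without reintroducing the old parameter names.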
|
||||
|
||||
430
src/gui/components/chart_dialogs.py
Normal file
@@ -0,0 +1,430 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Chart and interactive dialog module
|
||||
|
||||
Contains the ChartViewerDialog, ChartBrowserDialog and InteractiveViewerDialog classes.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PyQt5.QtWidgets import (
|
||||
QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
|
||||
QSizePolicy, QFileDialog, QMessageBox, QGroupBox,
|
||||
QListWidget, QLabel, QComboBox, QCheckBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
class ChartViewerDialog(QDialog):
|
||||
"""图表查看器对话框"""
|
||||
def __init__(self, title="图表查看器", parent=None):
|
||||
super().__init__(parent)
|
||||
self.setWindowTitle(title)
|
||||
self.resize(1000, 700)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
self.figure = Figure(figsize=(10, 7))
|
||||
self.canvas = FigureCanvas(self.figure)
|
||||
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
|
||||
|
||||
self.toolbar = NavigationToolbar(self.canvas, self)
|
||||
|
||||
layout.addWidget(self.toolbar)
|
||||
layout.addWidget(self.canvas)
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
|
||||
self.save_btn = QPushButton("保存图表")
|
||||
self.save_btn.clicked.connect(self.save_chart)
|
||||
btn_layout.addWidget(self.save_btn)
|
||||
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.close_btn = QPushButton("关闭")
|
||||
self.close_btn.clicked.connect(self.close)
|
||||
btn_layout.addWidget(self.close_btn)
|
||||
|
||||
layout.addLayout(btn_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
def display_image(self, image_path):
|
||||
"""显示图片"""
|
||||
self.figure.clear()
|
||||
ax = self.figure.add_subplot(111)
|
||||
|
||||
try:
|
||||
import matplotlib.image as mpimg
|
||||
img = mpimg.imread(image_path)
|
||||
ax.imshow(img)
|
||||
ax.axis('off')
|
||||
self.figure.tight_layout()
|
||||
self.canvas.draw()
|
||||
self.current_image_path = image_path
|
||||
except Exception as e:
|
||||
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
def display_custom_plot(self, plot_func):
|
||||
"""显示自定义绘图函数"""
|
||||
self.figure.clear()
|
||||
try:
|
||||
plot_func(self.figure)
|
||||
self.canvas.draw()
|
||||
except Exception as e:
|
||||
ax = self.figure.add_subplot(111)
|
||||
ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
def save_chart(self):
|
||||
"""保存图表"""
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存图表", "",
|
||||
"PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
try:
|
||||
self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
|
||||
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
|
||||
|
||||
|
||||
class ChartBrowserDialog(QDialog):
|
||||
"""图表浏览器对话框"""
|
||||
def __init__(self, chart_files, parent=None):
|
||||
super().__init__(parent)
|
||||
self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True)
|
||||
self.current_index = 0
|
||||
self.setWindowTitle("图表浏览器")
|
||||
self.resize(1200, 800)
|
||||
self.init_ui()
|
||||
self.show_chart(0)
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
|
||||
list_layout = QHBoxLayout()
|
||||
|
||||
self.chart_list = QListWidget()
|
||||
self.chart_list.setMaximumHeight(150)
|
||||
for chart_file in self.chart_files:
|
||||
self.chart_list.addItem(chart_file.name)
|
||||
self.chart_list.currentRowChanged.connect(self.show_chart)
|
||||
|
||||
list_layout.addWidget(self.chart_list)
|
||||
list_group.setLayout(list_layout)
|
||||
layout.addWidget(list_group)
|
||||
|
||||
self.figure = Figure(figsize=(12, 8))
|
||||
self.canvas = FigureCanvas(self.figure)
|
||||
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
|
||||
|
||||
self.toolbar = NavigationToolbar(self.canvas, self)
|
||||
layout.addWidget(self.toolbar)
|
||||
layout.addWidget(self.canvas, 1)
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
|
||||
self.prev_btn = QPushButton("◀ 上一个")
|
||||
self.prev_btn.clicked.connect(self.prev_chart)
|
||||
btn_layout.addWidget(self.prev_btn)
|
||||
|
||||
self.next_btn = QPushButton("下一个 >")
|
||||
self.next_btn.clicked.connect(self.next_chart)
|
||||
btn_layout.addWidget(self.next_btn)
|
||||
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.save_btn = QPushButton("💾 保存当前图表")
|
||||
self.save_btn.clicked.connect(self.save_current_chart)
|
||||
btn_layout.addWidget(self.save_btn)
|
||||
|
||||
self.close_btn = QPushButton("关闭")
|
||||
self.close_btn.clicked.connect(self.close)
|
||||
btn_layout.addWidget(self.close_btn)
|
||||
|
||||
layout.addLayout(btn_layout)
|
||||
self.setLayout(layout)
|
||||
|
||||
def show_chart(self, index):
|
||||
"""显示指定索引的图表"""
|
||||
if 0 <= index < len(self.chart_files):
|
||||
self.current_index = index
|
||||
self.chart_list.setCurrentRow(index)
|
||||
|
||||
chart_file = self.chart_files[index]
|
||||
self.figure.clear()
|
||||
ax = self.figure.add_subplot(111)
|
||||
|
||||
try:
|
||||
import matplotlib.image as mpimg
|
||||
img = mpimg.imread(str(chart_file))
|
||||
ax.imshow(img)
|
||||
ax.axis('off')
|
||||
ax.set_title(chart_file.name, fontsize=12, pad=10)
|
||||
self.figure.tight_layout()
|
||||
self.canvas.draw()
|
||||
except Exception as e:
|
||||
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
self.canvas.draw()
|
||||
|
||||
self.prev_btn.setEnabled(index > 0)
|
||||
self.next_btn.setEnabled(index < len(self.chart_files) - 1)
|
||||
|
||||
def prev_chart(self):
|
||||
"""上一个图表"""
|
||||
if self.current_index > 0:
|
||||
self.show_chart(self.current_index - 1)
|
||||
|
||||
def next_chart(self):
|
||||
"""下一个图表"""
|
||||
if self.current_index < len(self.chart_files) - 1:
|
||||
self.show_chart(self.current_index + 1)
|
||||
|
||||
def save_current_chart(self):
|
||||
"""保存当前图表"""
|
||||
if 0 <= self.current_index < len(self.chart_files):
|
||||
current_file = self.chart_files[self.current_index]
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存图表", current_file.name,
|
||||
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
try:
|
||||
import shutil
|
||||
shutil.copy(str(current_file), file_path)
|
||||
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
|
||||
|
||||
|
||||
class InteractiveViewerDialog(QDialog):
|
||||
"""交互式影像预览对话框:显示影像、参考点散点图、点击查询坐标/值"""
|
||||
|
||||
def __init__(self, parent, img_path, ref_csv=None):
|
||||
super().__init__(parent)
|
||||
self.img_path = img_path
|
||||
self.ref_csv = ref_csv
|
||||
self.geotransform = None
|
||||
self.fig = None
|
||||
self.canvas = None
|
||||
self.ax = None
|
||||
self.status_label = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
self.setWindowTitle("👁️ 交互式影像预览")
|
||||
self.setMinimumSize(900, 700)
|
||||
|
||||
layout = QVBoxLayout()
|
||||
|
||||
toolbar = QHBoxLayout()
|
||||
self.band_combo = QComboBox()
|
||||
self.band_combo.currentIndexChanged.connect(self.on_band_changed)
|
||||
toolbar.addWidget(QLabel("显示波段:"))
|
||||
toolbar.addWidget(self.band_combo)
|
||||
|
||||
self.gray_check = QCheckBox("灰度显示")
|
||||
self.gray_check.stateChanged.connect(self.on_band_changed)
|
||||
toolbar.addWidget(self.gray_check)
|
||||
toolbar.addStretch()
|
||||
layout.addLayout(toolbar)
|
||||
|
||||
try:
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.figure import Figure
|
||||
import matplotlib
|
||||
matplotlib.use('Qt5Agg')
|
||||
|
||||
self.fig = Figure(figsize=(10, 8))
|
||||
self.canvas = FigureCanvas(self.fig)
|
||||
self.ax = self.fig.add_subplot(111)
|
||||
self.fig.tight_layout()
|
||||
layout.addWidget(self.canvas)
|
||||
|
||||
self.load_and_display()
|
||||
|
||||
except ImportError as e:
|
||||
layout.addWidget(QLabel(f"Matplotlib 未安装: {e}"))
|
||||
|
||||
self.status_label = QLabel("点击影像查看像素坐标和经纬度")
|
||||
self.status_label.setStyleSheet("background:#f0f0f0;padding:4px;font-size:12px;")
|
||||
self.status_label.setWordWrap(True)
|
||||
layout.addWidget(self.status_label)
|
||||
|
||||
close_btn = QPushButton("关闭")
|
||||
close_btn.clicked.connect(self.close)
|
||||
layout.addWidget(close_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def load_and_display(self):
|
||||
"""加载影像并显示"""
|
||||
from osgeo import gdal
|
||||
|
||||
dataset = gdal.Open(self.img_path)
|
||||
if dataset is None:
|
||||
self.status_label.setText(f"无法打开影像: {self.img_path}")
|
||||
return
|
||||
|
||||
self.geotransform = dataset.GetGeoTransform()
|
||||
self.projection = dataset.GetProjection()
|
||||
n_bands = dataset.RasterCount
|
||||
self.height = dataset.RasterYSize
|
||||
self.width = dataset.RasterXSize
|
||||
|
||||
self.band_combo.clear()
|
||||
if n_bands >= 3:
|
||||
for i in range(1, n_bands + 1):
|
||||
self.band_combo.addItem(f"RGB (B{i-0}, G{i-1}, R{i-2})" if i >= 3 else f"波段 {i}", i)
|
||||
self.band_combo.addItem(f"单波段 (B1)", 0)
|
||||
else:
|
||||
for i in range(1, n_bands + 1):
|
||||
self.band_combo.addItem(f"波段 {i}", i - 1)
|
||||
self.band_combo.setCurrentIndex(0)
|
||||
|
||||
self.dataset = dataset
|
||||
self.display_band(0, is_gray=False)
|
||||
self.load_ref_points()
|
||||
|
||||
def display_band(self, band_idx, is_gray=False):
|
||||
"""显示指定波段组合"""
|
||||
from osgeo import gdal
|
||||
import numpy as np
|
||||
|
||||
dataset = self.dataset
|
||||
self.ax.clear()
|
||||
|
||||
if is_gray or (self.band_combo.currentData() == 0 and dataset.RasterCount == 1):
|
||||
band = dataset.GetRasterBand(1 if band_idx == 0 else band_idx + 1)
|
||||
data = band.ReadAsArray()
|
||||
data = np.nan_to_num(data, nan=0.0)
|
||||
self.ax.imshow(data, cmap='gray')
|
||||
self.ax.set_title(f"波段 {band_idx + 1} (灰度)")
|
||||
else:
|
||||
n = min(3, dataset.RasterCount)
|
||||
bands_data = []
|
||||
for i in range(n):
|
||||
b = dataset.GetRasterBand(i + 1)
|
||||
bd = b.ReadAsArray()
|
||||
bd = np.nan_to_num(bd, nan=0.0)
|
||||
bands_data.append(bd)
|
||||
rgb = np.dstack(bands_data)
|
||||
|
||||
for i in range(rgb.shape[2]):
|
||||
p2, p98 = np.percentile(rgb[:, :, i], [2, 98])
|
||||
if p98 > p2:
|
||||
rgb[:, :, i] = np.clip((rgb[:, :, i] - p2) / (p98 - p2), 0, 1)
|
||||
else:
|
||||
rgb[:, :, i] = np.clip(rgb[:, :, i] / (p98 + 1e-6), 0, 1)
|
||||
|
||||
self.ax.imshow(rgb)
|
||||
self.ax.set_title(f"RGB 显示")
|
||||
|
||||
self.ax.set_xlabel("列 (Column)")
|
||||
self.ax.set_ylabel("行 (Row)")
|
||||
self.fig.tight_layout()
|
||||
self.canvas.draw()
|
||||
|
||||
self.cid = self.canvas.mpl_connect('button_press_event', self.on_click)
|
||||
|
||||
def on_band_changed(self):
|
||||
"""波段选择变化时更新显示"""
|
||||
if not hasattr(self, 'dataset'):
|
||||
return
|
||||
is_gray = self.gray_check.isChecked()
|
||||
band_data = self.band_combo.currentData()
|
||||
self.display_band(band_data if band_data != 0 else 0, is_gray=is_gray)
|
||||
|
||||
def load_ref_points(self):
|
||||
"""加载并显示参考点"""
|
||||
import os
|
||||
if not self.ref_csv or not os.path.isfile(self.ref_csv):
|
||||
return
|
||||
|
||||
try:
|
||||
import csv
|
||||
lon_list, lat_list = [], []
|
||||
with open(self.ref_csv, 'r', encoding='utf-8-sig') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
try:
|
||||
lon = float(row.get('Lon', row.get('lon', row.get('LON', 0))))
|
||||
lat = float(row.get('Lat', row.get('lat', row.get('LAT', 0))))
|
||||
if lon and lat:
|
||||
lon_list.append(lon)
|
||||
lat_list.append(lat)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if not lon_list:
|
||||
return
|
||||
|
||||
px_list, py_list = [], []
|
||||
gt = self.geotransform
|
||||
if gt and (gt[1] != 0 or gt[5] != 0):
|
||||
for lon, lat in zip(lon_list, lat_list):
|
||||
px = (lon - gt[0]) / gt[1]
|
||||
py = (lat - gt[3]) / gt[5]
|
||||
if 0 <= px < self.width and 0 <= py < self.height:
|
||||
px_list.append(px)
|
||||
py_list.append(py)
|
||||
|
||||
if px_list:
|
||||
self.ax.scatter(px_list, py_list, c='red', s=40, marker='o',
|
||||
edgecolors='white', linewidths=0.8, zorder=5, alpha=0.9,
|
||||
label=f'参考点 ({len(px_list)}个)')
|
||||
self.ax.legend(loc='upper right', fontsize=9)
|
||||
self.fig.tight_layout()
|
||||
self.canvas.draw()
|
||||
self.status_label.setText(
|
||||
f"已加载 {len(px_list)} 个参考点(仅显示在影像范围内的点)"
|
||||
)
|
||||
except Exception as e:
|
||||
self.status_label.setText(f"加载参考点失败: {e}")
|
||||
|
||||
def pixel_to_geo(self, px, py):
|
||||
"""像素坐标转经纬度"""
|
||||
gt = self.geotransform
|
||||
if gt is None:
|
||||
return None, None
|
||||
lon = gt[0] + px * gt[1] + py * gt[2]
|
||||
lat = gt[3] + px * gt[4] + py * gt[5]
|
||||
return lon, lat
|
||||
|
||||
def on_click(self, event):
|
||||
"""鼠标点击事件"""
|
||||
if event.inaxes != self.ax or event.xdata is None or event.ydata is None:
|
||||
return
|
||||
|
||||
px, py = int(round(event.xdata)), int(round(event.ydata))
|
||||
if not (0 <= px < self.width and 0 <= py < self.height):
|
||||
return
|
||||
|
||||
from osgeo import gdal
|
||||
import numpy as np
|
||||
dataset = self.dataset
|
||||
n_bands = dataset.RasterCount
|
||||
vals = []
|
||||
for b in range(1, n_bands + 1):
|
||||
val = dataset.GetRasterBand(b).ReadAsArray()[py, px]
|
||||
vals.append(f"{val:.4f}" if isinstance(val, float) else str(val))
|
||||
|
||||
lon, lat = self.pixel_to_geo(px, py)
|
||||
geo_str = f"Lon={lon:.6f}, Lat={lat:.6f}" if lon is not None else "无地理参考"
|
||||
|
||||
self.status_label.setText(
|
||||
f"像素: (行={py}, 列={px}) | {geo_str} | "
|
||||
f"波段值: {' | '.join(vals[:5])}" +
|
||||
(f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "")
|
||||
)
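Review note: `pixel_to_geo` applies the full six-term geotransform (including the rotation terms `gt[2]` and `gt[4]`), while `load_ref_points` inverts it with only `gt[1]` and `gt[5]`. The two agree for north-up images, but not when the rotation terms are non-zero. A minimal sketch of the exact inverse, in plain Python and assuming the usual GDAL coefficient order, is shown below; `geo_to_pixel` is a hypothetical helper name, not a function in this file.

```python
from typing import Sequence, Tuple


def pixel_to_geo(gt: Sequence[float], px: float, py: float) -> Tuple[float, float]:
    """Forward affine transform, same formula as the dialog's pixel_to_geo."""
    x = gt[0] + px * gt[1] + py * gt[2]
    y = gt[3] + px * gt[4] + py * gt[5]
    return x, y


def geo_to_pixel(gt: Sequence[float], x: float, y: float) -> Tuple[float, float]:
    """Exact inverse of pixel_to_geo (hypothetical helper).

    Solves the 2x2 linear system, so it stays correct when the
    rotation terms gt[2] and gt[4] are non-zero.
    """
    det = gt[1] * gt[5] - gt[2] * gt[4]
    if det == 0:
        raise ValueError("geotransform is not invertible")
    dx = x - gt[0]
    dy = y - gt[3]
    px = (gt[5] * dx - gt[2] * dy) / det
    py = (-gt[4] * dx + gt[1] * dy) / det
    return px, py
```

With `gt[2] == gt[4] == 0` this reduces to the `(lon - gt[0]) / gt[1]` and `(lat - gt[3]) / gt[5]` expressions already used in `load_ref_points`, so swapping it in should not change behaviour for north-up rasters.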
|
||||
50
src/gui/components/data_models.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Data model module
|
||||
|
||||
Contains data model classes such as PandasTableModel.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtCore import Qt, QAbstractTableModel
|
||||
|
||||
|
||||
class PandasTableModel(QAbstractTableModel):
|
||||
"""支持DataFrame的表格模型"""
|
||||
def __init__(self, data_frame: pd.DataFrame):
|
||||
super().__init__()
|
||||
self._data = data_frame.copy()
|
||||
if self._data.empty:
|
||||
self._data = pd.DataFrame()
|
||||
self._data.fillna("", inplace=True)
|
||||
self._columns = [str(col) for col in self._data.columns]
|
||||
|
||||
def rowCount(self, parent=None):
|
||||
return len(self._data)
|
||||
|
||||
def columnCount(self, parent=None):
|
||||
return len(self._columns)
|
||||
|
||||
def data(self, index, role=Qt.DisplayRole):
|
||||
if not index.isValid() or role != Qt.DisplayRole:
|
||||
return None
|
||||
|
||||
value = self._data.iat[index.row(), index.column()]
|
||||
if pd.isna(value):
|
||||
return ""
|
||||
return str(value)
|
||||
|
||||
def headerData(self, section, orientation, role=Qt.DisplayRole):
|
||||
if role != Qt.DisplayRole:
|
||||
return None
|
||||
if orientation == Qt.Horizontal:
|
||||
if section < len(self._columns):
|
||||
return self._columns[section]
|
||||
return str(section)
|
||||
return str(section + 1)
|
||||
|
||||
def flags(self, index):
|
||||
if not index.isValid():
|
||||
return Qt.NoItemFlags
|
||||
return Qt.ItemIsEnabled | Qt.ItemIsSelectable
|
||||
351
src/gui/components/image_widgets.py
Normal file
@@ -0,0 +1,351 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Image browsing widget module
|
||||
|
||||
Contains the ImageCategoryTree and ImageViewerWidget classes.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QTreeWidget, QTreeWidgetItem, QWidget, QVBoxLayout, QHBoxLayout,
|
||||
QPushButton, QLabel, QScrollArea, QFrame, QGroupBox,
|
||||
QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QTimer
|
||||
from PyQt5.QtGui import QPixmap
|
||||
|
||||
|
||||
class ImageCategoryTree(QTreeWidget):
|
||||
"""图像分类目录树 - 按类别组织图像文件"""
|
||||
|
||||
CATEGORIES = [
|
||||
("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"),
|
||||
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
|
||||
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
|
||||
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
|
||||
("含量分布图", [], "📁"),
|
||||
]
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setHeaderLabel("图像目录")
|
||||
self.setMaximumWidth(300)
|
||||
self.setMinimumWidth(250)
|
||||
self.setup_categories()
|
||||
self.setStyleSheet("""
|
||||
QTreeWidget {
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 5px;
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
QTreeWidget::item {
|
||||
padding: 5px;
|
||||
border-radius: 3px;
|
||||
}
|
||||
QTreeWidget::item:selected {
|
||||
background-color: #0078D4;
|
||||
color: white;
|
||||
}
|
||||
QTreeWidget::item:hover {
|
||||
background-color: #e3f2fd;
|
||||
}
|
||||
""")
|
||||
|
||||
def setup_categories(self):
|
||||
"""初始化类别节点"""
|
||||
self.category_items = {}
|
||||
for category_name, keywords, icon in self.CATEGORIES:
|
||||
item = QTreeWidgetItem(self)
|
||||
item.setText(0, f"{icon} {category_name}")
|
||||
item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name})
|
||||
item.setExpanded(True)
|
||||
self.category_items[category_name] = item
|
||||
|
||||
def clear_all_images(self):
|
||||
"""清除所有图像项"""
|
||||
for category_item in self.category_items.values():
|
||||
while category_item.childCount() > 0:
|
||||
category_item.removeChild(category_item.child(0))
|
||||
|
||||
def add_image(self, file_path: Path, display_name: str = None):
|
||||
"""添加图像到对应的类别"""
|
||||
if display_name is None:
|
||||
display_name = file_path.stem
|
||||
|
||||
category = self._determine_category(file_path.name)
|
||||
category_item = self.category_items.get(category, self.category_items["含量分布图"])
|
||||
|
||||
image_item = QTreeWidgetItem(category_item)
|
||||
image_item.setText(0, f" └─ {display_name}")
|
||||
image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)})
|
||||
image_item.setToolTip(0, str(file_path))
|
||||
|
||||
return image_item
|
||||
|
||||
def _determine_category(self, filename: str) -> str:
|
||||
"""根据文件名确定类别"""
|
||||
filename_lower = filename.lower()
|
||||
|
||||
for category_name, keywords, _ in self.CATEGORIES:
|
||||
if any(keyword in filename_lower for keyword in keywords):
|
||||
return category_name
|
||||
|
||||
return "含量分布图"
|
||||
|
||||
def scan_directory(self, work_dir: str):
|
||||
"""扫描目录中的所有图像文件"""
|
||||
self.clear_all_images()
|
||||
|
||||
work_path = Path(work_dir)
|
||||
if not work_path.exists():
|
||||
return
|
||||
|
||||
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
|
||||
scan_roots: List[Path] = []
|
||||
_viz = work_path / "14_visualization"
|
||||
if _viz.is_dir():
|
||||
scan_roots.append(_viz)
|
||||
_wm = work_path / "1_water_mask"
|
||||
if _wm.is_dir():
|
||||
scan_roots.append(_wm)
|
||||
if not scan_roots:
|
||||
scan_roots.append(work_path)
|
||||
|
||||
seen_norm: set = set()
|
||||
image_files: List[Path] = []
|
||||
for root in scan_roots:
|
||||
for ext in image_extensions:
|
||||
for p in root.glob(f"**/{ext}"):
|
||||
key = os.path.normcase(os.path.normpath(str(p.resolve())))
|
||||
if key in seen_norm:
|
||||
continue
|
||||
seen_norm.add(key)
|
||||
image_files.append(p)
|
||||
|
||||
for img_file in sorted(image_files):
|
||||
if img_file.name.startswith('.') or 'thumb' in img_file.name.lower():
|
||||
continue
|
||||
self.add_image(img_file)
|
||||
|
||||
for category_name, item in self.category_items.items():
|
||||
count = item.childCount()
|
||||
if count > 0:
|
||||
for cat_name, _, icon in self.CATEGORIES:
|
||||
if cat_name == category_name:
|
||||
item.setText(0, f"{icon} {category_name} ({count})")
|
||||
break
|
||||
|
||||
def get_selected_image_path(self) -> Optional[str]:
|
||||
"""获取当前选中的图像路径"""
|
||||
selected_item = self.currentItem()
|
||||
if not selected_item:
|
||||
return None
|
||||
|
||||
data = selected_item.data(0, Qt.UserRole)
|
||||
if data and data.get("type") == "image":
|
||||
return data.get("path")
|
||||
return None
|
||||
|
||||
|
||||
class ImageViewerWidget(QWidget):
    """Image viewer widget with zoom and pan support."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.current_image_path = None
        self.scale_factor = 1.0
        self._update_timer = QTimer()
        self._update_timer.setSingleShot(True)
        self._update_timer.timeout.connect(self._do_update_display)
        self._pending_scale = None
        self.setup_ui()

    def setup_ui(self):
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        toolbar = QHBoxLayout()

        self.refresh_btn = QPushButton("🔄 刷新目录")
        self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
        toolbar.addWidget(self.refresh_btn)

        separator = QFrame()
        separator.setFrameShape(QFrame.VLine)
        separator.setFrameShadow(QFrame.Sunken)
        toolbar.addWidget(separator)

        self.zoom_in_btn = QPushButton("🔍+")
        self.zoom_in_btn.setToolTip("放大")
        self.zoom_in_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_in_btn)

        self.zoom_out_btn = QPushButton("🔍-")
        self.zoom_out_btn.setToolTip("缩小")
        self.zoom_out_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_out_btn)

        self.fit_btn = QPushButton("⬜ 适应窗口")
        self.fit_btn.setToolTip("适应窗口大小")
        toolbar.addWidget(self.fit_btn)

        self.original_btn = QPushButton("1:1 原始大小")
        self.original_btn.setToolTip("原始大小")
        toolbar.addWidget(self.original_btn)

        toolbar.addStretch()

        self.save_btn = QPushButton("💾 保存")
        self.save_btn.setToolTip("保存当前图像")
        toolbar.addWidget(self.save_btn)

        layout.addLayout(toolbar)

        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setStyleSheet("background-color: white;")

        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("background-color: white;")

        self.scroll_area.setWidget(self.image_label)
        layout.addWidget(self.scroll_area, 1)

        status_layout = QHBoxLayout()
        self.status_label = QLabel("就绪")
        self.status_label.setStyleSheet("color: #666; font-size: 11px;")
        status_layout.addWidget(self.status_label)
        status_layout.addStretch()
        layout.addLayout(status_layout)

        self.setLayout(layout)

        self.zoom_in_btn.clicked.connect(self.zoom_in)
        self.zoom_out_btn.clicked.connect(self.zoom_out)
        self.fit_btn.clicked.connect(self.fit_to_window)
        self.original_btn.clicked.connect(self.original_size)
        self.save_btn.clicked.connect(self.save_image)

    def load_image(self, image_path: str):
        """Load and display an image."""
        if not image_path or not Path(image_path).exists():
            self.image_label.setText("图像不存在")
            self.status_label.setText("图像加载失败")
            return

        self.current_image_path = image_path
        self.scale_factor = 1.0

        pixmap = QPixmap(image_path)
        if pixmap.isNull():
            self.image_label.setText("无法加载图像")
            self.status_label.setText("图像格式不支持")
            return

        self.original_pixmap = pixmap
        self.fit_to_window()

        file_info = Path(image_path).stat()
        size_mb = file_info.st_size / (1024 * 1024)
        self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")

    def update_image_display(self):
        """Refresh the displayed image, debounced to avoid stutter from frequent repaints."""
        self._update_timer.stop()
        self._pending_scale = self.scale_factor
        self._update_timer.start(50)

    def _do_update_display(self):
        """Perform the actual image update."""
        if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
            return

        if self._pending_scale is None:
            return

        if self._pending_scale > 2.0 or self._pending_scale < 0.5:
            transform = Qt.FastTransformation
        else:
            transform = Qt.SmoothTransformation

        scaled_pixmap = self.original_pixmap.scaled(
            int(self.original_pixmap.width() * self._pending_scale),
            int(self.original_pixmap.height() * self._pending_scale),
            Qt.KeepAspectRatio,
            transform
        )
        self.image_label.setPixmap(scaled_pixmap)
        self._pending_scale = None

    def wheelEvent(self, event):
        """Zoom with the mouse wheel, responding as the wheel turns."""
        delta = event.angleDelta().y()

        if delta > 0:
            if self.scale_factor < 5.0:
                self.scale_factor = min(self.scale_factor * 1.1, 5.0)
                self.update_image_display()
        else:
            if self.scale_factor > 0.1:
                self.scale_factor = max(self.scale_factor / 1.1, 0.1)
                self.update_image_display()

        event.accept()

    def zoom_in(self):
        """Zoom in."""
        if self.scale_factor < 5.0:
            self.scale_factor = min(self.scale_factor * 1.25, 5.0)
            self.update_image_display()

    def zoom_out(self):
        """Zoom out."""
        if self.scale_factor > 0.1:
            self.scale_factor = max(self.scale_factor / 1.25, 0.1)
            self.update_image_display()

    def fit_to_window(self):
        """Fit the image to the window."""
        if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
            return

        view_size = self.scroll_area.viewport().size()
        img_size = self.original_pixmap.size()

        scale_w = view_size.width() / img_size.width()
        scale_h = view_size.height() / img_size.height()

        self._fit_scale = min(scale_w, scale_h)
        self.scale_factor = self._fit_scale

        self.update_image_display()
        self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")

    def original_size(self):
        """Show the image at its original size."""
        self.scale_factor = 1.0
        self._fit_scale = None
        self.update_image_display()
        self.status_label.setText("原始大小 | 缩放: 100%")

    def save_image(self):
        """Save a copy of the current image."""
        if not self.current_image_path:
            return

        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图像", Path(self.current_image_path).name,
            "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
        )

        if file_path:
            try:
                import shutil
                shutil.copy(self.current_image_path, file_path)
            except Exception as e:
                QMessageBox.critical(self, "错误", f"保存失败: {e}")
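The de-duplication step in `scan_directory` can be looked at in isolation. The sketch below uses a hypothetical standalone helper, `dedupe_paths`, which is not part of this diff. It applies the same `normcase`/`normpath` key to plain path strings and, unlike the original, skips `Path.resolve()` so that it does not touch the filesystem.

```python
import os
from typing import Iterable, List


def dedupe_paths(paths: Iterable[str]) -> List[str]:
    """Return paths with duplicates removed, keeping first occurrences.

    Hypothetical helper (an assumption, not code from the diff): two paths
    count as duplicates when they normalise to the same key, i.e. after
    redundant separators and "." / ".." segments are collapsed and, on
    platforms where normcase does so (e.g. Windows), case is folded.
    """
    seen = set()
    unique: List[str] = []
    for p in paths:
        key = os.path.normcase(os.path.normpath(p))
        if key in seen:
            continue
        seen.add(key)
        unique.append(p)  # keep the caller's original spelling
    return unique
```

Keeping the first spelling of each path while comparing on the normalised key mirrors what the method does with `seen_norm` and `image_files`.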
112
src/gui/core/test_modeling.py
Normal file
@ -0,0 +1,112 @@
import time
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression

# Suppress noisy sklearn warnings
warnings.filterwarnings("ignore")

print("====== 🚀 启动 Mega Water 模型终极体检脚本 ======")

# ---------------------------------------------------------
# 1. Reproduce the CSV layout described in the survey report.
#    The report states: targets (y) on the left, spectral features (X) on the right.
# ---------------------------------------------------------
print("📦 正在生成符合系统结构的模拟测试数据...")
X_raw, y_raw = make_regression(n_samples=200, n_features=50, noise=0.1, random_state=42)

# Mimic real CSV column names: the first 2 columns are water-quality parameters,
# followed by 50 spectral bands
columns = ['Chla', 'SS'] + [f"Band_{i}" for i in range(50)]
# Assemble everything into one table
data = pd.DataFrame(np.hstack((y_raw.reshape(-1, 1), (y_raw * 0.5).reshape(-1, 1), X_raw)), columns=columns)

# Split following the logic of load_data_batch
feature_start_index = 2
X = data.iloc[:, feature_start_index:]  # spectra as X
y = data['Chla']  # one target parameter as y

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"✅ 数据切割完毕! 模拟波段数: {X.shape[1]}, 训练集样本数: {X_train.shape[0]}\n")

# ---------------------------------------------------------
# 2. Load the 16 models listed in the survey report
# ---------------------------------------------------------
print("🔍 正在加载底层真实配置库中的模型...")
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.cross_decomposition import PLSRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor

# Keep the hyperparameters minimal so each model finishes within about a second
models = {
    'SVR': SVR(),
    'RF': RandomForestRegressor(n_estimators=10, max_depth=5, n_jobs=-1),
    'KNN': KNeighborsRegressor(),
    'LinearRegression': LinearRegression(),
    'Ridge': Ridge(),
    'Lasso': Lasso(),
    'ElasticNet': ElasticNet(),
    'PLS': PLSRegression(),
    'GradientBoosting': GradientBoostingRegressor(n_estimators=10, max_depth=5),
    'AdaBoost': AdaBoostRegressor(n_estimators=10),
    'DecisionTree': DecisionTreeRegressor(max_depth=5),
    'MLP': MLPRegressor(max_iter=50),
    'ExtraTrees': ExtraTreesRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
}

# Probe the three third-party boosting libraries that the report found disabled
try:
    from xgboost import XGBRegressor

    models['XGBoost'] = XGBRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
    models['XGBoost'] = "IMPORT_ERROR"

try:
    from lightgbm import LGBMRegressor

    models['LightGBM'] = LGBMRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
    models['LightGBM'] = "IMPORT_ERROR"

try:
    from catboost import CatBoostRegressor

    models['CatBoost'] = CatBoostRegressor(iterations=10, depth=5, verbose=0)
except ImportError:
    models['CatBoost'] = "IMPORT_ERROR"

# ---------------------------------------------------------
# 3. Run the check loop over every model
# ---------------------------------------------------------
print("\n================ 开始跑分测试 ================")
results = []

for name, model in models.items():
    if model == "IMPORT_ERROR":
        results.append(f"⚠️ [缺库] {name:<16} : 环境未安装此库 (建议: pip install {name.lower()})")
        continue

    start_time = time.time()
    try:
        # Quick fit and score
        model.fit(X_train, y_train)
        score = model.score(X_test, y_test)
        cost_time = time.time() - start_time
        results.append(f"✅ [成功] {name:<16} : 耗时 {cost_time:.3f} 秒 (R2: {score:.2f})")
    except Exception as e:
        error_msg = str(e).split('\n')[0]
        results.append(f"❌ [崩溃] {name:<16} : {error_msg}")

# ---------------------------------------------------------
# 4. Print the final report
# ---------------------------------------------------------
print("\n=============== 🏥 最终体检报告 ===============")
for res in results:
    print(res)
print("===============================================")
346
src/gui/core/viz_thread.py
Normal file
@ -0,0 +1,346 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Background thread module for visualization.

Contains the VisualizationWorkerThread background thread class and helper functions.
"""

from pathlib import Path
from typing import Optional, List, Union

from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np


def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
    """Infer the first spectral column (training_spectra usually uses wavelength values as column names and may not contain UTM_Y)."""
    import pandas as pd
    for i, col in enumerate(df.columns):
        name = str(col).strip().lstrip("\ufeff")
        try:
            v = float(name)
        except ValueError:
            continue
        if 200.0 <= v <= 3000.0:
            return i
    if "UTM_Y" in df.columns:
        return "UTM_Y"
    return 0


class VisualizationWorkerThread(QThread):
    """Run expensive visualization work in a background thread, temporarily switching to the Agg backend so the main window stays responsive."""

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            mpl_prev = None
        try:
            wp = Path(self.work_dir)
            if self.task == "mask_glint":
                from src.postprocessing.visualization_reports import WaterQualityVisualization
                viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
                preview_paths = viz.generate_glint_deglint_previews(
                    work_dir=str(wp),
                    output_subdir="glint_deglint_previews",
                )
                cnt = len(preview_paths) if preview_paths else 0
                self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})
            elif self.task == "sampling_map":
                hyperspectral_files = []
                deglint_dir = wp / "3_deglint"
                if deglint_dir.exists():
                    for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
                        hyperspectral_files.extend(list(deglint_dir.glob(ext)))
                if not hyperspectral_files:
                    for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
                        hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
                if not hyperspectral_files:
                    self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。")
                    return
                hyperspectral_path = str(hyperspectral_files[0])
                csv_files = []
                processed_dir = wp / "4_processed_data"
                if processed_dir.exists():
                    csv_files = list(processed_dir.glob("*.csv"))
                if not csv_files:
                    csv_files = (
                        list(wp.glob("**/*sampling*.csv"))
                        + list(wp.glob("**/*point*.csv"))
                        + list(wp.glob("**/*.csv"))
                    )
                if not csv_files:
                    self.failed.emit("未找到采样点 CSV 文件。")
                    return
                csv_path = str(csv_files[0])
                from src.postprocessing.point_map import SamplingPointMap
                map_generator = SamplingPointMap(
                    output_dir=str(wp / "14_visualization" / "sampling_maps"),
                    fast_mode=True,
                )
                map_path = map_generator.create_sampling_point_map(
                    hyperspectral_path=hyperspectral_path,
                    csv_path=csv_path,
                    point_color="red",
                    point_size=100,
                    point_alpha=0.9,
                    show_north_arrow=True,
                    show_scale_bar=True,
                    show_legend=True,
                    downsample=True,
                    dpi=180,
                )
                self.finished_ok.emit(
                    {
                        "task": "sampling_map",
                        "map_path": map_path,
                        "hyperspectral_path": hyperspectral_path,
                        "csv_path": csv_path,
                    }
                )
            elif self.task == "spectrum":
                from src.postprocessing.visualization_reports import WaterQualityVisualization
                viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
                csv_file = self.extra.get("csv_path")
                wl = self.extra.get("wavelength_start_column", "UTM_Y")
                n_groups = int(self.extra.get("n_groups", 5))
                param_cols = self.extra.get("param_cols") or []
                if param_cols:
                    output_paths: List[str] = []
                    err_lines: List[str] = []
                    for param_col in param_cols:
                        try:
                            out = viz.plot_spectrum_by_parameter(
                                csv_path=str(csv_file),
                                parameter_column=param_col,
                                wavelength_start_column=wl,
                                n_groups=n_groups,
                            )
                            output_paths.append(out)
                        except Exception as _ex:
                            err_lines.append(f"{param_col}: {_ex}")
                    if not output_paths:
                        self.failed.emit(
                            "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
                        )
                        return
                    self.finished_ok.emit(
                        {
                            "task": "spectrum",
                            "output_paths": output_paths,
                            "errors": err_lines,
                        }
                    )
                else:
                    param_col = self.extra.get("param_col")
                    out = viz.plot_spectrum_by_parameter(
                        csv_path=str(csv_file),
                        parameter_column=param_col,
                        wavelength_start_column=wl,
                        n_groups=n_groups,
                    )
                    self.finished_ok.emit(
                        {"task": "spectrum", "output_path": out, "param_col": param_col}
                    )
            elif self.task == "statistics":
                from src.postprocessing.visualization_reports import WaterQualityVisualization
                viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
                csv_file = self.extra.get("csv_path")
                param_cols = self.extra.get("param_cols") or []
                output_paths = viz.plot_statistical_charts(
                    csv_path=str(csv_file),
                    parameter_columns=param_cols,
                )
                self.finished_ok.emit(
                    {"task": "statistics", "output_paths": output_paths}
                )
            elif self.task == "scatter":
                from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline

                training_csv_path = (self.extra.get("training_csv_path") or "").strip()
                models_dir = (self.extra.get("models_dir") or "").strip()
                if not training_csv_path or not Path(training_csv_path).is_file():
                    self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
                    return
                if not models_dir or not Path(models_dir).is_dir():
                    self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
                    return
                pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
                scatter_paths = pipeline.generate_model_scatter_plots(
                    training_csv_path=training_csv_path,
                    models_dir=models_dir,
                )
                self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})
            elif self.task == "generate_all_selected":
                from src.postprocessing.visualization_reports import WaterQualityVisualization
                viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
                parts = []

                training_csv = wp / "5_training_spectra" / "training_spectra.csv"

                if self.extra.get("gen_scatter"):
                    if training_csv.is_file():
                        models_dir = wp / "7_Supervised_Model_Training"
                        if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
                            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
                            pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
                            scatter_paths = pipeline.generate_model_scatter_plots(
                                training_csv_path=str(training_csv),
                                models_dir=str(models_dir),
                            )
                            count = len(scatter_paths) if scatter_paths else 0
                            parts.append(f"散点图: {count} 个")
                        else:
                            parts.append("散点图: 跳过(无模型目录)")
                    else:
                        parts.append("散点图: 跳过(无训练数据)")

                if self.extra.get("gen_spectrum"):
                    if training_csv.is_file():
                        import pandas as pd
                        df = pd.read_csv(training_csv)
                        wl_col = _viz_infer_wavelength_start_column(df)
                        if isinstance(wl_col, str):
                            idx = int(df.columns.get_loc(wl_col)) + 1
                        else:
                            idx = int(wl_col)
                        param_cols = []
                        if idx > 0 and idx < len(df.columns):
                            param_cols = [
                                c for c in df.columns[:idx]
                                if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
                            ]
                        if param_cols:
                            spectrum_paths = []
                            for param_col in param_cols:
                                try:
                                    path = viz.plot_spectrum_by_parameter(
                                        csv_path=str(training_csv),
                                        parameter_column=param_col,
                                        wavelength_start_column=wl_col,
                                        n_groups=5,
                                    )
                                    if path:
                                        spectrum_paths.append(path)
                                except Exception as e:
                                    print(f"生成光谱图失败 ({param_col}): {e}")
                            count = len(spectrum_paths)
                            parts.append(f"光谱图: {count} 个")
                        else:
                            parts.append("光谱图: 跳过(无可用参数列)")
                    else:
                        parts.append("光谱图: 跳过(无训练数据)")

                if self.extra.get("gen_boxplots"):
                    if training_csv.is_file():
                        import pandas as pd
                        df = pd.read_csv(training_csv)
                        exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
                        param_cols = [
                            c for c in df.select_dtypes(include=[np.number]).columns
                            if not any(exc in c.lower() for exc in exclude_cols)
                        ]
                        wl = _viz_infer_wavelength_start_column(df)
                        if isinstance(wl, str):
                            idx = int(df.columns.get_loc(wl)) + 1
                        else:
                            idx = int(wl)
                        if 0 < idx < len(df.columns):
                            meta_set = set(df.columns[:idx])
                            param_cols = [c for c in param_cols if c in meta_set]

                        if param_cols:
                            output_dict = viz.plot_statistical_charts(
                                csv_path=str(training_csv),
                                parameter_columns=param_cols,
                            )
                            count = len([v for v in output_dict.values() if v]) if output_dict else 0
                            parts.append(f"统计图: {count} 个")
                        else:
                            parts.append("统计图: 跳过(无可用水质参数列)")
                    else:
                        parts.append("统计图: 跳过(无训练数据)")

                if self.extra.get("gen_mask_glint"):
                    preview_paths = viz.generate_glint_deglint_previews(
                        work_dir=str(wp),
                        output_subdir="glint_deglint_previews",
                    )
                    parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")

                if self.extra.get("gen_sampling_map"):
                    hyperspectral_files = []
                    deglint_dir = wp / "3_deglint"
                    if deglint_dir.exists():
                        for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
                            hyperspectral_files.extend(list(deglint_dir.glob(ext)))
                    if not hyperspectral_files:
                        for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
                            hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
                    if hyperspectral_files:
                        hyperspectral_path = str(hyperspectral_files[0])
                        csv_files = []
                        processed_dir = wp / "4_processed_data"
                        if processed_dir.exists():
                            csv_files = list(processed_dir.glob("*.csv"))
                        if not csv_files:
                            csv_files = (
                                list(wp.glob("**/*sampling*.csv"))
                                + list(wp.glob("**/*point*.csv"))
                                + list(wp.glob("**/*.csv"))
                            )
                        if csv_files:
                            csv_path = str(csv_files[0])
                            from src.postprocessing.point_map import SamplingPointMap
                            map_generator = SamplingPointMap(
                                output_dir=str(wp / "14_visualization" / "sampling_maps"),
                                fast_mode=True,
                            )
                            map_path = map_generator.create_sampling_point_map(
                                hyperspectral_path=hyperspectral_path,
                                csv_path=csv_path,
                                point_color="red",
                                point_size=100,
                                point_alpha=0.9,
                                show_north_arrow=True,
                                show_scale_bar=True,
                                show_legend=True,
                                downsample=True,
                                dpi=180,
                            )
                            parts.append(f"采样点图: {Path(map_path).name}")
                        else:
                            parts.append("采样点图: 跳过(无CSV)")
                    else:
                        parts.append("采样点图: 跳过(无影像)")
                self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})
            else:
                self.failed.emit(f"未知可视化任务: {self.task}")
        except Exception as e:
            import traceback
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    pass
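The column-inference rule in `_viz_infer_wavelength_start_column` can be tried without pandas. The sketch below uses a hypothetical function, `infer_spectral_start`, which is not part of this diff; it applies the same rule to a plain list of column names, on the assumption that only the names matter and not the DataFrame itself.

```python
from typing import List, Union


def infer_spectral_start(columns: List[str]) -> Union[str, int]:
    """Mirror the inference rule on a plain list of column names.

    Hypothetical illustration: returns the index of the first column whose
    name parses as a number in [200, 3000] (read as a wavelength), else the
    string "UTM_Y" if such a column exists, else 0.
    """
    for i, col in enumerate(columns):
        name = str(col).strip().lstrip("\ufeff")  # drop whitespace and a BOM
        try:
            value = float(name)
        except ValueError:
            continue
        if 200.0 <= value <= 3000.0:
            return i
    if "UTM_Y" in columns:
        return "UTM_Y"
    return 0
```

The two return types mean different things to the callers in this file. An integer is used directly as the index of the first spectral column, whereas a string is treated as the last metadata column, so callers add 1 to its position.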
93
src/gui/crash_dump.txt
Normal file
@ -0,0 +1,93 @@

============================================================
[2026-05-12 11:14:51]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 130, in <module>
    from src.gui.panels.step9_panel import Step9Panel
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\panels\step9_panel.py", line 24, in <module>
    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py", line 45, in <module>
    from src.preprocessing.process_water_quality_data import process_water_quality_data
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\preprocessing\process_water_quality_data.py", line 9, in <module>
    from scipy import stats
  File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\__init__.py", line 143, in __getattr__
    return _importlib.import_module(f'scipy.{name}')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\importlib\__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\__init__.py", line 632, in <module>
    from ._multicomp import *
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_multicomp.py", line 11, in <module>
    from scipy.stats._qmc import check_random_state
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_qmc.py", line 26, in <module>
    from scipy.sparse.csgraph import minimum_spanning_tree
  File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\sparse\csgraph\__init__.py", line 188, in <module>
    from ._shortest_path import (
  File "scipy/sparse/csgraph/_shortest_path.pyx", line 21, in init scipy.sparse.csgraph._shortest_path
  File "<frozen importlib._bootstrap>", line 1349, in _find_and_load
KeyboardInterrupt

============================================================
[2026-05-12 11:57:28]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
    main()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3093, in main
    _dialog.exec_()
KeyboardInterrupt

============================================================
[2026-05-28 15:45:11]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
    main()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
    window = WaterQualityGUI()
             ^^^^^^^^^^^^^^^^^
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
    self.init_ui()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
    self.create_content_area()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
    self.step2_panel = Step2Panel()
                       ^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'

============================================================
[2026-05-28 15:45:19]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
    main()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
    window = WaterQualityGUI()
             ^^^^^^^^^^^^^^^^^
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
    self.init_ui()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
    self.create_content_area()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
    self.step2_panel = Step2Panel()
                       ^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'

============================================================
[2026-05-28 16:00:53]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2149, in on_step_changed
    self.auto_populate_step_inputs(item_data)
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2362, in auto_populate_step_inputs
    if step_id not in self.step_dependencies:
                      ^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'WaterQualityGUI' object has no attribute 'step_dependencies'. Did you mean: '_init_step_dependencies'?

============================================================
[2026-06-03 13:56:59]
Traceback (most recent call last):
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3354, in <module>
    main()
  File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3331, in main
    sys.exit(app.exec_())
             ^^^^^^^^^^^
KeyboardInterrupt
@ -325,7 +325,7 @@ class Step3Panel(QWidget):
         }
         water_mask_path = self.water_mask_file.get_path()
         if water_mask_path:
-            config['water_mask'] = water_mask_path
+            config['water_mask_path'] = water_mask_path
         output_path = self.output_file.get_path()
         if output_path:
             config['output_path'] = output_path
@ -366,8 +366,8 @@ class Step3Panel(QWidget):
         """Apply a configuration."""
         if 'img_path' in config:
             self.img_file.set_path(config['img_path'])
-        if 'water_mask' in config:
-            self.water_mask_file.set_path(config['water_mask'])
+        if 'water_mask_path' in config:
+            self.water_mask_file.set_path(config['water_mask_path'])
         if 'output_path' in config:
             self.output_file.set_path(config['output_path'])
         if 'reference_csv' in config:

@ -187,7 +187,7 @@ class Step5_5Panel(QWidget):
     def get_config(self):
         selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
         return {
-            'training_spectra_path': self.training_data_widget.get_path(),
+            'training_csv_path': self.training_data_widget.get_path(),
             'formula_csv_file': self.builtin_formula_path,
             'formula_names': selected,
             'output_file': self.output_file_widget.get_path(),
@ -195,7 +195,7 @@ class Step5_5Panel(QWidget):
         }

     def set_config(self, config):
-        if 'training_spectra_path' in config: self.training_data_widget.set_path(config['training_spectra_path'])
+        if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
         if 'formula_names' in config:
             sel = set(config['formula_names'])
             for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
@ -217,7 +217,7 @@ class Step5_5Panel(QWidget):

     def run_step(self):
         config = self.get_config()
-        if not config['training_spectra_path']:
+        if not config['training_csv_path']:
             QMessageBox.warning(self, "提示", "请先选择输入数据")
             return
         parent = self.parent()

@ -124,7 +124,7 @@ class Step5Panel(QWidget):
         glint_mask_path = self.glint_mask_file.get_path()
         if glint_mask_path:
             config['glint_mask_path'] = glint_mask_path
-        # Note: step5_extract_training_spectra does not accept the output_path / training_spectra_path
+        # Note: step5_extract_training_spectra does not accept the output_path / training_csv_path
         # arguments; the pipeline derives the output path internally from training_spectra_dir.
         return config


@ -363,7 +363,7 @@ class Step6Panel(QWidget):
         # Fallback: look up the likely key names in Step5's config dict
         step5_cfg = main_window.step5_panel.get_config()
         step5_csv = (
-            step5_cfg.get('training_spectra_path')
+            step5_cfg.get('training_csv_path')
             or step5_cfg.get('output_file')
             or step5_cfg.get('csv_path')
             or step5_cfg.get('output_csv')

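The hunks above rename `water_mask` to `water_mask_path` and `training_spectra_path` to `training_csv_path`, and the updated `set_config` methods only check the new names. If configurations saved under the old names exist (an assumption, the diff does not say), they could be mapped forward before being applied. The sketch below is one hypothetical way to do that; `RENAMED_KEYS` and `migrate_config_keys` are illustrative names, not code from this commit.

```python
from typing import Any, Dict

# Old key -> new key, taken from the renames visible in this diff.
RENAMED_KEYS = {
    "training_spectra_path": "training_csv_path",
    "water_mask": "water_mask_path",
}


def migrate_config_keys(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of config with legacy keys renamed (hypothetical helper).

    A value already stored under the new key wins over its legacy
    counterpart; the input dict is left unmodified.
    """
    migrated = dict(config)
    for old, new in RENAMED_KEYS.items():
        if old in migrated:
            value = migrated.pop(old)
            migrated.setdefault(new, value)
    return migrated
```

Under these assumptions a caller would wrap the loaded dict once, for example `panel.set_config(migrate_config_keys(saved))`, rather than teaching every panel about both spellings.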
BIN
src/gui/scaler_params.pkl
Normal file
@ -1432,7 +1432,7 @@ class WaterQualityGUI(QMainWindow):
                 'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file')  # optional glint mask for step 5
             },
             'step5_5': {
-                'training_spectra_path': ('step5', 'training_spectra', 'output_file')  # step 5.5 needs the training spectra produced by step 5
+                'training_csv_path': ('step5', 'training_spectra', 'output_file')  # step 5.5 needs the training spectra produced by step 5
             },
             'step6': {
                 'csv_path': ('step5', 'training_spectra', 'csv_file')  # step 6 needs the training spectra data