refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

24
.qwen/settings.json Normal file
View File

@ -0,0 +1,24 @@
{
"permissions": {
"allow": [
"Bash(\"c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe\" *)",
"Bash(get-childitem *)",
"Bash(select-object *)",
"Bash(python *)",
"Bash(where *)",
"Bash(conda *)",
"Bash(dir *)",
"Bash(cmd *)",
"Bash(del *)",
"Bash(powershell *)",
"Bash(git *)",
"Bash(type *)",
"Bash(.\\venv\\scripts\\python.exe *)",
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
"Bash(venv\\scripts\\python.exe *)",
"Bash(findstr *)"
]
},
"$version": 4
}

23
.qwen/settings.json.orig Normal file
View File

@ -0,0 +1,23 @@
{
"permissions": {
"allow": [
"Bash(\"c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe\" *)",
"Bash(get-childitem *)",
"Bash(select-object *)",
"Bash(python *)",
"Bash(where *)",
"Bash(conda *)",
"Bash(dir *)",
"Bash(cmd *)",
"Bash(del *)",
"Bash(powershell *)",
"Bash(git *)",
"Bash(type *)",
"Bash(.\\venv\\scripts\\python.exe *)",
"Bash(\"d:\\111\\office\\zhlduijie\\1.wq\\wq_gui\\venv\\scripts\\python.exe\" *)",
"Bash(c:\\users\\duxin\\appdata\\local\\programs\\python\\python311\\python.exe *)",
"Bash(venv\\scripts\\python.exe *)"
]
},
"$version": 4
}

View File

@ -0,0 +1,141 @@
---
name: 代码替换请求的现状审计
description: 处理用户"代码替换/新增"指令时,先审计磁盘真实状态再用 ask_user_question 确认——避免覆盖已落盘的高版本代码
source: auto-skill
extracted_at: '2026-06-03T05:36:58.746Z'
---
# 代码替换请求的现状审计
## 适用场景
用户给出"代码替换"或"按某版本代码新增"指令,但**没有提供与磁盘当前状态对比信息**时。典型触发:
- 用户贴了一段代码说"请帮我写/替换这个"
- 用户引用某个文档/旧版本/旧 chat 说"按这个来"
- 之前的 `state_snapshot` / `memory` / `git log` 描述可能与磁盘现状不一致
## 核心原则
**永远不要盲信"用户给的代码是最新版本"**——磁盘上的代码可能已经是更完善的版本(用户或其他 agent 已迭代过)。覆盖 = 丢功能。
直接覆盖的代价不一定是显式 bug也可能是"丢失用户已批准的设计决策"(如 duck-type 探测 / ctx 抽象 / 信号协议 / 二次确认窗 / 错误定位)。
## 5 步标准操作
### 1. 确认文件存在
`glob``list_directory` 看目标文件是否已存在:
- 不存在 → 新建
- 存在 → 进入第 2 步审计
### 2. grep 关键符号 + 读关键段
- 找"用户贴的代码"里的 3-5 个关键符号(函数名 / 类名 / 关键常量 / import
- 在磁盘文件里 grep 同样的符号
- `read_file` 关键段(行号从 grep 结果直接拿)
### 3. 构造差异对照表
列出:
```
| 目标文件 | 用户贴的版本 | 磁盘现有版本 | 直接覆盖会丢失 |
```
**关键列**"直接覆盖会丢失什么"——让用户判断成本。具体粒度到"功能模块 / 设计决策 / 防御层 / 入口协议",不要写"代码差异"这种空话。
### 4. ask_user_question 让用户拍板
3 个标准选项(措辞可调,但**必须给出现状 + 三选一**
- **A. 保留现状**(推荐,磁盘已是更新版)—— 直接进 Smoke Test
- **B. 强制覆盖到旧版** —— 写明丢什么 + 备份建议git stash / 复制到 `_old.py`
- **C. 混合:只取某段增量** —— 见第 5 步
**不要在第 1 次 ask 时就列具体的"哪段增量"**——先让用户在 A/B/C 之间选。如果选 C再做第 5 步。
### 5. 若用户选 C识别"真正增量"
对比 1.0 vs 2.0,识别 1.0 真正独有的部分2.0 没有的):
- ❌ 排除 1.0 比 2.0 简单的2.0 是超集 / 工厂分层 / 多了 CLI
- ❌ 排除 1.0 整体被 2.0 工厂分层超越的_make_objective vs _build_model + _get_search_space
- ✅ 关注 1.0 独有的功能层(即使 2.0 不"明显"需要)
对每个候选增量,再问一次"采纳哪段"让用户具体选multiSelect=false一次只选 1 段最稳)。
## 落地原则
执行"采纳 1.0 某段增量到 2.0"时:
- **最小化外科手术式编辑**:只动需要动的文件,只改需要改的段
- **保留 2.0 的设计决策**duck-type 探测 / ctx 抽象 / 信号协议 / 二次确认窗 / 错误定位)
- **顶部 import 增量用 `replace_all=False` 单点插入**,避免破坏其他 import 顺序
- **同名变量全链路替换**(如 `self.config``clean_config`)要贯穿 ctx 构造 / v2 调用 / v1 fallback避免双源差异
- **单步模式不一定要清洗**(不走 panel 完整 config与清洗器无关
- **清洗器这种"防患于未然"的代码要给日志**`self.log_message.emit(f"[清洗器] 已删除 N 个未知 key")`)让运行时可见
## 验证三件套
落地后必跑:
1. **AST 语法检查**`ast.parse(open(p, encoding='utf-8-sig').read())` 对 5 个核心文件
- 必加 `utf-8-sig`WQ_GUI 的 water_quality_gui.py line 1 是 BOMplain `utf-8` 必挂
2. **关键符号 grep**确认新代码的关键符号import / 关键函数调用都命中hit 数符合预期
3. **顶层导入测试**:用 mock PyQt5 + `sys.path.insert(0, 'src/gui/core')`,验证模块整体可加载
- PyQt5 mock 模板见下方"参考代码"
- Windows 环境调 Python用 conda env 的 `python.exe` 全路径,不要靠 PATH
## 反例(不要做)
- ❌ "按用户贴的代码原封不动写入"——1.0 简化版的覆盖陷阱
- ❌ "保留 state_snapshot 描述"——state snapshot 可能不准确(写的是意图,磁盘才是事实)
- ❌ "用 git log 反推当前状态"——git log 不能反映工作区未提交改动
- ❌ "靠 memory 推断当前状态"——memory 可能是 22 天前的(已确认过期)
- ❌ "磁盘和用户给的代码看起来一样就不审计"——一行之差可能就是"防弹层"丢失
## 参考代码
### PyQt5 mock 模板worker_thread.py 顶层导入测试)
```python
import os, sys
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
os.environ['SHAPE_ENCODING'] = 'UTF-8'
sys.path.insert(0, 'src/gui/core')
import types
pyqt5 = types.ModuleType("PyQt5")
qtc = types.ModuleType("PyQt5.QtCore")
class _QThread:
def __init__(self, *a, **kw): pass
class _Signal:
def __init__(self, *a, **kw): pass
qtc.QThread = _QThread
qtc.pyqtSignal = _Signal
qtc.Qt = type("Qt", (), {"QueuedConnection": 1, "UserRole": 0})()
sys.modules["PyQt5"] = pyqt5
sys.modules["PyQt5.QtCore"] = qtc
import worker_thread
# 副作用: check_pipeline_dependencies() 会打印依赖检查日志(可忽略)
```
### Windows 上跑 conda env python
```bat
cmd /c "D:\xxx\anconda\envs\XXX\python.exe D:\path\to\script.py"
```
PowerShell 单行 `python -c "..."` 在中文路径 / 双引号 / 单引号嵌套时易翻车,**写临时 .py 文件再用 `cmd /c` 调**最稳。
## 案例来源2026-06-03 WQ_GUI 路线 B MVP
- 用户贴 1.0 简化版300 行 automl_trainer / 简化 worker_thread.run() / 简化 on_run_all_clicked
- 磁盘上 2.0 落盘版545 行 automl_trainer_build_model + _get_search_space 工厂 / argparse CLI/ duck-type 探测 v2 + PipelineContext 抽象 / 完整二次确认窗 / 失败步骤 _focus_step 定位 / [DEPRECATED] stop 保留
- 1.0 唯一真增量 = **"防弹级参数清洗器"**method_map 14 项 + inspect.signature 过滤未知 key + has_kwargs 豁免 + 未知 key 数量日志)
- 落地worker_thread.py:run() 内 set_callback 之后插入 53 行清洗器self.config 6 处替换为 clean_config
- 验证5 文件 AST 全通过 + 关键符号 7 项命中 + PyQt5 mock 下 import 成功
- 净增行数407 → 457+50 行)

View File

@ -0,0 +1,206 @@
---
name: WQ_GUI 数据流转架构
description: WQ_GUI ProjectSession 事件总线驱动的步骤间数据传递机制(完整重构版)
source: auto-skill
extracted_at: '2026-05-28T09:07:34.967Z'
---
# WQ_GUI 数据流转架构
## 核心结论
整个系统是**基于文件路径驱动**的管道,所有数据存储在本地磁盘。重构后通过 `ProjectSession` 事件总线实现 Panel 间完全解耦。
---
## 1. 旧架构(旧代码中已删除)
主窗口通过 `self.step_outputs` 字典 + `step_dependencies` 配置 + `auto_populate_*` 系列方法管理步骤间路径填充。存在高度耦合问题:
```python
# 已废弃并删除
self.step_outputs = {}
self._init_step_dependencies()
self.update_step_outputs(step_name, work_path)
self.auto_populate_dependent_steps(completed_step)
self.auto_populate_step_inputs(step_id)
self.find_step_output(work_path, step_id, output_type)
self.add_auto_fill_buttons_to_panels()
self.scan_work_directory_for_files(work_path)
```
---
## 2. 新架构ProjectSession 事件总线
### Session 核心 API`src/core/project_session.py`
```python
class ProjectSession(QObject):
path_updated = pyqtSignal(str, str, str) # step, out_type, path
step_outputs_ready = pyqtSignal(str, str) # step, out_type
def update_output(step, out_type, path):
"""Panel 完成后广播输出路径"""
def update_outputs(step, {out_type: path, ...}):
"""Panel 完成后批量广播多个输出路径"""
def get_output(step, out_type):
"""Panel 可主动查询上游路径(用于自动填充)"""
def get_step_outputs(step):
"""返回该 step 的全部输出字典"""
def scan_work_directory():
"""主窗口 on_step_completed 末尾调用,扫描并广播所有已知路径"""
```
### Panel 重构模板
```python
class StepXPanel(QWidget):
def __init__(self, session=None, parent=None):
super().__init__(parent)
self.session = session
self.work_dir = None
self.init_ui()
self._bind_session_signals()
def _bind_session_signals(self):
if not self.session:
return
self.session.path_updated.connect(
self._on_session_path_updated, Qt.QueuedConnection
)
@pyqtSlot(str, str, str)
def _on_session_path_updated(self, step_name, output_type, path):
print(f"[StepX Debug] 收到广播: step={step_name}, type={output_type}, path={path}")
if step_name == 'step1':
if output_type == 'reference_img':
if not self.img_file.get_path().strip():
self.img_file.set_path(path)
print(f"[StepX] 自动填充参考影像: {path}")
elif output_type == 'water_mask':
if not self.water_mask_file.get_path().strip():
self.water_mask_file.set_path(path)
print(f"[StepX] 自动填充水域掩膜: {path}")
# ...
def on_step_finished(self, success, message):
"""由主窗口 on_step_completed 通过 getattr 动态调用"""
if not success:
return
if self.session:
outputs = {}
path = self.output_widget.get_path().strip()
if path:
outputs['output_type'] = path
if outputs:
self.session.update_outputs('stepX', outputs)
```
### 主窗口两处改动
```python
# 1. __init__ 中注入 session所有 Panel 统一注入)
self.step1_panel = Step1Panel(session=self.session)
self.step2_panel = Step2Panel(session=self.session)
self.step3_panel = Step3Panel(session=self.session)
self.step4_panel = Step4Panel(session=self.session)
self.step5_panel = Step5Panel(session=self.session)
self.step5_5_panel = Step5_5Panel(session=self.session)
self.step6_panel = Step6Panel(session=self.session)
self.step6_5_panel = Step6_5Panel(session=self.session)
self.step6_75_panel = Step6_75Panel(session=self.session)
self.step7_panel = Step7Panel(session=self.session)
self.step8_panel = Step8Panel(session=self.session)
self.step8_5_panel = Step8_5Panel(session=self.session)
self.step8_75_panel = Step8_75Panel(session=self.session)
self.step9_panel = Step9Panel(session=self.session)
# 2. on_step_completed通用动态获取无需维护字典
def on_step_completed(self, step_name, success, message):
if not success:
return
if hasattr(self, 'session') and self.session:
self.session.scan_work_directory()
panel = getattr(self, f"{step_name}_panel", None)
if panel and hasattr(panel, 'on_step_finished'):
panel.on_step_finished(success, message)
```
---
## 3. 全链路事件流
### step1 → step2 / step3 路径(通过 Shapefile 栅格化产物)
| 场景 | 广播的 water_mask 路径 |
|------|----------------------|
| NDWI 模式 | `output_file` 用户指定路径 |
| Shapefile 模式 | `{work_dir}/1_water_mask/water_mask_from_shp.dat`(优先)<br>若文件不存在则 fallback 回 `mask_file.get_path()` |
```
step1 完成
→ step1_panel.on_step_finished()
→ session.update_outputs('step1', {
'reference_img': img_path,
'water_mask': mask_path # 可能是 .dat 或 .shp见上表
})
→ step2_panel._on_session_path_updated()
→ step3_panel._on_session_path_updated()
```
### step3 → step5 / step7step5 → 下游训练
```
step3.deglint_image ──┬─→ step5.deglint_image填充 img_file
└─→ step7.deglint_image填充 img_file
step5.training_spectra ──┬─→ step5_5.index_features
├─→ step6.models_dir ──→ step8.predictions
├─→ step6_5.models_dir ──→ step8_5.predictions
└─→ step6_75.models_dir ──→ step8_75.predictions
step7.sampling_points ──┬─→ step8
├─→ step8_5
└─→ step8_75
step8/8_5/8_75.predictions ──→ step9.distribution_map
```
### 各 Panel 监听/发布对照表(完整版)
| Panel | 监听 | 发布 |
|-------|------|------|
| step1 | — | `reference_img`, `water_mask` |
| step2 | `step1.reference_img`, `step1.water_mask` | `glint_mask` |
| step3 | `step1.reference_img`, `step1.water_mask`, `step2.glint_mask` | `deglint_image` |
| step4 | — | `processed_data` |
| step5 | `step3.deglint_image`, `step4.processed_data`, `step2.glint_mask` | `training_spectra` |
| step5_5 | `step5.training_spectra` | `index_features` |
| step6 | `step5.training_spectra` | `models_dir` |
| step6_5 | `step5.training_spectra` | `models_dir` |
| step6_75 | `step5.training_spectra` | `models_dir` |
| step7 | `step3.deglint_image`, `step1.water_mask`, `step2.glint_mask` | `sampling_points` |
| step8 | `step7.sampling_points`, `step6.models_dir` | `predictions` |
| step8_5 | `step7.sampling_points`, `step6_5.models_dir` | `predictions` |
| step8_75 | `step7.sampling_points`, `step6_75.models_dir` | `predictions` |
| step9 | `step8.predictions`, `step8_5.predictions`, `step8_75.predictions` | `distribution_map` |
---
## 4. 关键约束
- `__init__` 参数 `session=None`(向后兼容,主窗口可继续不传)
- 所有 Panel 的 `init_ui / get_config / set_config / update_from_config` 完整保留
- 删除所有 `self.window().stepX_panel` 跨界访问
- 使用 `self.session.get_output()` 替代直接读取其他 panel 的 widget
- 监听使用 `Qt.QueuedConnection` 确保跨线程安全
- 仅在 field 为空时自动填充(`not widget.get_path().strip()`
- `update_from_config` 中优先从 Session 获取路径,再用 Session 广播
- 主窗口 `on_step_completed` 中使用 `getattr(self, f"{step_name}_panel", None)` 实现通用动态获取,无需维护硬编码字典
- `step1` Shapefile 模式下,**不能**直接广播 `.shp` 输入文件,必须拼接 `{work_dir}/1_water_mask/water_mask_from_shp.dat` 作为产物路径

View File

@ -0,0 +1,229 @@
---
name: WQ_GUI 前端 Vue3 + Element Plus 脚手架
description: WQ_GUI 项目 frontend/ 目录的 Vite + Vue 3 + TS + Element Plus 最小可运行脚手架,以及 useTaskPoller 与 Element Plus UI 的接线模式
source: auto-skill
extracted_at: '2026-06-02T08:17:33.116Z'
---
# WQ_GUI 前端脚手架 (Vue 3 + Element Plus)
## 适用场景
为 WQ_GUI FastAPI 后端 (`127.0.0.1:8000`) 搭建一个**最小可联调**的浏览器控制台。
后端已暴露:
- `POST /api/modeling/train``{ task_id, status, kind }`
- `POST /api/modeling/predict``{ task_id, status, kind }`
- `GET /api/tasks/{task_id}``TaskRecord`(含 PENDING/PROCESSING/SUCCESS/FAILED + 模型指标 / 输出路径)
- `GET /api/algorithms` → 算法清单
前端已有 (`frontend/src/`):
- `api/request.ts`axios 单例 + 响应拦截器自动 unwrapbaseURL 走 `VITE_API_BASE_URL` 缺省 `http://127.0.0.1:8000`
- `api/tasks.ts`:所有提交 / 查询函数 + 完整 `TaskRecord` / `TaskStatus` / `TaskKind` 类型
- `composables/useTaskPoller.ts`:完整轮询 composable支持 3 种用法(静态 / 响应式 taskId / 手动)
## 1. 一次性补齐的脚手架文件
`frontend/` 初始状态**只有 `src/api``src/composables`**,缺整个 Vite 骨架。直接照下面这 7 个文件铺一遍:
```
frontend/
├── .env.development # VITE_API_BASE_URL=http://127.0.0.1:8000
├── .gitignore # node_modules / dist / .vite
├── env.d.ts # vite/client + ImportMeta + *.vue shim
├── index.html # 挂载 #app
├── package.json
├── tsconfig.json # 严格模式 + @ → src + bundler resolution
├── tsconfig.node.json # 给 vite.config.ts 用
├── vite.config.ts # @ alias + 0.0.0.0:5173
└── src/
├── main.ts
└── App.vue
```
### 锁定版本2026-06 联调通过)
```json
{
"dependencies": {
"vue": "^3.4.27",
"element-plus": "^2.7.5",
"@element-plus/icons-vue": "^2.3.1",
"axios": "^1.7.2"
},
"devDependencies": {
"@types/node": "^20.12.12",
"@vitejs/plugin-vue": "^5.0.4",
"typescript": "^5.4.5",
"vite": "^5.2.11",
"vue-tsc": "^2.0.19"
}
}
```
**`@types/node` 必加**——`vite.config.ts` 用了 `import { fileURLToPath, URL } from 'node:url'`,否则 `npm run build` 类型检查必挂。
### `tsconfig.json` 关键字段
- `"moduleResolution": "bundler"`
- `"allowImportingTsExtensions": true`(配合 `vue-tsc --noEmit`
- `"paths": { "@/*": ["src/*"] }` + `"baseUrl": "."`
- `"include": ["src/**/*.vue"]``vue-tsc` 才会处理 SFC
- `"references": [{ "path": "./tsconfig.node.json" }]`
### `vite.config.ts` 关键字段
```ts
resolve: {
alias: { '@': fileURLToPath(new URL('./src', import.meta.url)) },
},
server: { host: '0.0.0.0', port: 5173 },
```
`0.0.0.0` 方便局域网真机调试;端口冲突时 `strictPort: false` 允许 Vite 自动 +1。
---
## 2. main.ts 模板(全量注册 Element Plus
```ts
import { createApp } from 'vue'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
import App from './App.vue'
const app = createApp(App)
app.use(ElementPlus)
// 全量注册图标 (<el-icon><Cpu /></el-icon>)
for (const [name, component] of Object.entries(ElementPlusIconsVue)) {
app.component(name, component)
}
app.mount('#app')
```
联调期**全量注册最省事**;后期打包体积大再换 `unplugin-vue-components` 按需。
---
## 3. useTaskPoller 接线模式(双实例)
训练 / 推断是**两条独立流水线**,各起一个 `useTaskPoller` 实例。核心套路:把 `task_id` 包成 `ref<string | null>(null)`composable 内部 `watch` 会**自动 start()**,无需手动调:
```ts
import { ref, watch, computed } from 'vue'
import { submitTrain, submitPredict, type TaskRecord } from './api/tasks'
import { useTaskPoller } from './composables/useTaskPoller'
// —— 训练 ——
const trainTaskId = ref<string | null>(null)
const trainPoller = useTaskPoller(trainTaskId) // 传 ref 进去, 自动 watch
async function onStartTrain() {
const { task_id } = await submitTrain({ ... })
trainTaskId.value = task_id // 赋值后 watch 触发 start()
}
// —— 推断 ——
const predictTaskId = ref<string | null>(null)
const predictPoller = useTaskPoller(predictTaskId)
const modelId = ref('')
// 训练一成功, model_id 自动填入推断输入框
watch(
() => trainPoller.result.value?.model_id,
(newId) => { if (newId) modelId.value = newId },
)
async function onStartPredict() {
const { task_id } = await submitPredict({ model_id: modelId.value, ... })
predictTaskId.value = task_id
}
```
**关键点**
- `trainPoller.result.value` 才是 SUCCESS 后的完整 `TaskRecord``record.value` 是任意时刻(含中间态)的最新记录。模板里同时展示用 `trainPoller.record.value ?? trainPoller.result.value`
- `poller.isPolling.value` / `poller.status.value` / `poller.error.value` / `poller.taskId.value` 都是 `Ref`,模板里必须用 `.value`(它们是嵌套 ref**Vue 模板不会自动 unwrap**)。
---
## 4. el-progress 状态映射
`PollerStatus = 'idle' | 'PENDING' | 'PROCESSING' | 'SUCCESS' | 'FAILED'`
`el-progress``status` 接受 `'' | 'success' | 'warning' | 'exception'`
```ts
function progressOf(status: string): number {
switch (status) {
case 'idle':
case 'PENDING': return 10
case 'PROCESSING':return 60
case 'SUCCESS':
case 'FAILED': return 100
default: return 0
}
}
function progressStatusOf(s: string): '' | 'success' | 'exception' {
if (s === 'SUCCESS') return 'success'
if (s === 'FAILED') return 'exception'
return ''
}
```
模板里 `v-if="poller.isPolling.value || poller.status.value === 'SUCCESS' || poller.status.value === 'FAILED'"` 控制展示。
---
## 5. CSS深色控制台风slate 渐变 + 卡片玻璃态)
```css
.app-root {
min-height: 100vh;
background: linear-gradient(180deg, #0f172a 0%, #1e293b 100%);
color: #e2e8f0;
}
.panel {
background: rgba(30, 41, 59, 0.7) !important;
border: 1px solid rgba(148, 163, 184, 0.18) !important;
}
.app-main {
display: grid;
grid-template-columns: 1fr 1fr; /* 左训练 / 右推断 */
gap: 20px;
}
@media (max-width: 960px) { .app-main { grid-template-columns: 1fr; } }
```
深色背景下 Element Plus 的 `el-form-item__label` / `el-descriptions__label` 默认是黑色文字,必须 `:deep()` 覆盖成浅色。
---
## 6. 启动与验证
```bat
cd /d D:\111\office\ZHLduijie\1.WQ\WQ_GUI\frontend
npm install
npm run dev
```
打开 `http://127.0.0.1:5173/`,联调期望路径:
1. 左侧「开始训练」→ 立即拿到 `task_id` + 黄色 `轮询中` + 进度条 60%
2. 后端 SUCCESS → 进度条变绿,下面出现 `model_id` 标签 + R²/RMSE/MAE
3. 右侧 `model_id` 被自动填入 → 「开始推断」→ 走 `output_zarr_path` 展示
4. 任何一步 FAILED → 进度条变红 + 后端 `error` 字段
---
## 7. 已知 caveat
- **第一次 `npm install` 约 150MB**,要耐心等。
- `useTaskPoller` 已有 `onUnmounted` 自动清理,**不要再手写 `clearInterval`**。
- `request.ts` 注释里写明 FastAPI dev 期 `allow_origins=["*"]`**不需要配 Vite proxy**;如果未来后端收紧 CORS再在 `vite.config.ts``server.proxy['/api']`
- `feature_start` 后端接受 `number | string`el-input v-model 出来是 string**直接传给 API 即可**,后端会自己判别。
- `v-model``ref<number | string>(4)` 类型注解是必须的,否则 TS 会推断成 `Ref<number>`,输入框失焦报错。
- `@element-plus/icons-vue` 全量注册后用 `<el-icon><Cpu /></el-icon>` 调,本期 App.vue 没用到但留着扩展位。

BIN
data/icons-1/1.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/10.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
data/icons-1/11.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.3 KiB

BIN
data/icons-1/2.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/3.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/4.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/5.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/6.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/7.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/8.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/9.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 79 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 92 KiB

BIN
data/icons-1/IRIS.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

BIN
data/icons-1/fenmian.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

BIN
data/icons-1/lica.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

BIN
data/icons-1/liucheng.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
data/icons-1/logo.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

BIN
data/icons-1/table.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

BIN
data/icons-1/uitubiao.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

BIN
data/icons/uitubiao.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

85
data/格式转化.py Normal file
View File

@ -0,0 +1,85 @@
import os
from pathlib import Path
from PIL import Image
def batch_convert_to_ico(source_dirs, output_dir, target_size=(256, 256)):
"""
批量将指定目录下的图像文件转换为 ICO 格式。
:param source_dirs: 包含源文件夹路径的列表
:param output_dir: 转换后 ICO 文件的保存目录
:param target_size: 输出 ICO 的尺寸,默认 256x256
"""
# 支持的常见输入图像后缀
supported_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tiff'}
# 确保输出目录存在,若无则自动创建
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
total_converted = 0
total_failed = 0
print("=" * 50)
print(f"🚀 开始批量转换 ICO 图标...")
print(f"📁 目标输出目录: {out_path}")
print("=" * 50)
# 遍历所有传入的源目录
for folder in source_dirs:
folder_path = Path(folder)
if not folder_path.exists():
print(f"⚠️ 警告: 源目录不存在,已跳过 -> {folder_path}")
continue
print(f"\n📂 正在扫描目录: {folder_path}")
# 遍历目录下的所有文件
for file_path in folder_path.iterdir():
# 仅处理普通文件且后缀在支持列表内(忽略大小写)
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
try:
with Image.open(file_path) as img:
# 处理透明通道问题:
# 如果图片支持透明通道 (RGBA/P/LA),转为 RGBA 确保透明背景不丢失
# 如果是普通 RGB (如 JPG),转为 RGB
if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
img_clean = img.convert('RGBA')
else:
img_clean = img.convert('RGB')
# 构造输出文件名 (原文件名.ico)
new_filename = f"{file_path.stem}.ico"
save_path = out_path / new_filename
# 如果目标文件夹中已存在同名文件,为了防止覆盖,可以在文件名后加个标识
# 但通常图标库同名直接覆盖较符合需求,这里默认直接保存
img_clean.save(save_path, format="ICO", sizes=[target_size])
print(f" ✅ 成功: {file_path.name} -> {new_filename}")
total_converted += 1
except Exception as e:
print(f" ❌ 失败: 无法转换 {file_path.name},错误信息: {e}")
total_failed += 1
print("\n" + "=" * 50)
print("🎉 转换任务结束!")
print(f"统计: 成功转换 {total_converted} 个文件,失败 {total_failed} 个。")
print("=" * 50)
if __name__ == "__main__":
# 1. 定义你要读取的两个源文件夹路径列表
SOURCES = [
r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons",
r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons\word"
]
# 2. 定义统一输出的目标文件夹路径
OUTPUT = r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons-1"
# 执行转换
batch_convert_to_ico(SOURCES, OUTPUT)

View File

@ -0,0 +1,350 @@
# Smoke Test — 路线 B MVPPipelineContext + AutoML + 软取消 + GUI 缝合)
> 适用范围:路线 B 重构 4 部分pipeline 包 / AutoML 训练器 / WorkerThread 软取消 / GUI 一键全自动)落盘后的端到端点火试飞清单。
> 目标:**用最小数据集1 个 BSQ + 1 个 CSV在 1020 分钟内验证全链路打通**。
---
## 0. 前置准备5 分钟)
### 0.1 装 Optuna
`environment.yml` 当前**未列** optuna属于本次重构新增依赖。若不装Step 6 会自动降级到老 GridSearchCV仍能跑通但会触发 fallback 日志)。
```bash
call venv\Scripts\activate.bat
pip install "optuna>=3.6,<4.0"
```
写入 `environment.yml` 的 patch提交时改
```yaml
# 路线 B AutoML 防爆引擎(可选;未装时 Step 6 走老 GridSearchCV 降级路径)
- optuna>=3.6
```
### 0.2 准备最小数据集
```text
work_dir_smoke/
├── raw/
│ ├── sample.b # 假彩色 BSQ任意小分辨率都行建议 50×50×6 波段)
│ ├── sample_mask.tif # (可选)水域掩膜;不提供则 Step 1 自动生成 NDWI
│ └── sample.csv # 含 36 个水质参数目标列Chl-a / TSS / SD / TN / TP / COD…+ 6 列波段反射率
└── (其他文件由流程自动生成)
```
**CSV 模板示例**`feature_start_column` 默认为第一列;目标列必须**在特征列之前**
```csv
Chl-a,TSS,SD,B1,B2,B3,B4,B5,B6
12.3,15.1,0.8,0.045,0.052,0.038,0.061,0.072,0.085
11.8,14.2,0.9,0.044,0.051,0.037,0.060,0.071,0.084
... (≥ 200 行AutoML 智能子采样 N>5000 时才生效)
```
### 0.3 启动 venv
```bash
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
call venv\Scripts\activate.bat
set PYTHONPATH=src;%PYTHONPATH%
```
---
## 1. CLI 烟雾最快路径3 分钟)— **A 级:必跑**
跳过 GUI直接验证 `automl_trainer.py` 自身可独立运行 + Optuna 子采样 + 降级路径:
```bash
python -m src.core.prediction.automl_trainer ^
--csv work_dir_smoke/raw/sample.csv ^
--feature-start 6 ^
--n-trials 5 ^
--timeout 60.0 ^
--out work_dir_smoke/7_Supervised_Model_Training_AutoML
```
**通过标准**
- [ ] 进程退出码 0
- [ ] 控制台打印 `AutoML: 目标列 X 共尝试 N 个 trial最佳 CV R²=…`
- [ ] `<out>/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` 存在
- [ ] `<out>/automl_summary.json` 存在且 `success=true`
**若 Optuna 未装**,期待看到:
```
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
```
产物文件名带 `_AUTOML` 后缀的逻辑此时**不会触发**fallback 走老路径),属正常。
---
## 2. GUI 端到端 9 步核心场景1020 分钟)— **S 级:必跑**
### 2.1 启动 GUI
```bash
call venv\Scripts\activate.bat
set PYTHONPATH=src;%PYTHONPATH%
python -m src.gui.water_quality_gui
```
### 2.2 UI 配置
| 步骤 | 操作 | 期望 |
| ----- | -------------------------------------------------------------------- | ------------------------------------------------------------------------------------ |
| 1/9 | 点"选择工作目录" → 选 `work_dir_smoke/` | 左侧步骤列表高亮UI 不报错 |
| 2/9 | 在 Step 1 面板选 `sample.b`**掩膜留空**(验证 NDWI 自动生成路径) | 掩膜文本框保持空白 |
| 3/9 | 在 Step 4 面板选 `sample.csv` | CSV 路径显示正确 |
| 4/9 | **关键**其他步骤2/3/5/5.5/6/7/8/9保持默认不改任何参数 | AutoML 默认开启use_automl=True |
| 5/9 | 点 **▶ 运行完整流程**(不要用老 `run_full_pipeline` 槽) | 弹出**二次确认窗**,文案显示:<br>• 掩膜:`未指定(将自动生成 NDWI 水域掩膜)`<br>• 去耀斑:开启<br>• AutoML开启Optuna 子采样寻优) |
| 6/9 | 点"是(Y)" | "运行"按钮变灰,"停止"按钮亮起;进度条归零 |
### 2.3 观察日志(重点 4 大检查点)
#### ✅ 检查点 1ctx 路径传递
启动后**第一秒**应看到类似:
```
[Runner] ctx 已构造14 路径字段4 目录字段
[Runner] 步骤 1/14step1_generate_water_maskrequires=['raw_img_path', 'water_mask_path']
[Runner] 步骤 2/14step2_find_glint_arearequires=['raw_img_path', 'water_mask_path', 'output_dir']
...
[Runner] ctx 路径校准water_mask_path = ...\work_dir_smoke\2_Glint_Area_Mask\glint_mask.tif
```
**若没有 `[Runner]` 日志**,说明 v1 旧路径被走到了,**`inspect.signature` duck-type 没探测到 v2**,回去检查 `worker_thread.py:run()`
#### ✅ 检查点 2Step 1 NDWI 自动生成
```
[Step1] 未指定 mask_path自动基于 NDWI 生成水域掩膜
[Step1] NDWI 阈值=0.4,写入 1_Water_Mask/water_mask.tif
```
→ 验证 `<work_dir>/1_Water_Mask/water_mask.tif` 文件存在且非空。
#### ✅ 检查点 3AutoML 启用
```
[Step6] AutoML 启用 Optuna 子采样寻优timeout=300s, n_trials=20, max_samples=5000
[Step6] 目标列 'Chl-a' 共 3 个候选模型,最佳 R²=0.812model=RandomForest
[Step6] 目标列 'TSS' 共 3 个候选模型,最佳 R²=0.745model=XGBoost
[Step6] 训练完成,产物写入 7_Supervised_Model_Training_AutoML/
[Step6] automl_summary.json 写入完成
```
→ 验证产物:
- [ ] `7_Supervised_Model_Training_AutoML/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` ≥ 1 个
- [ ] `7_Supervised_Model_Training_AutoML/automl_summary.json``automl: true` 字段
- [ ] 老目录 `7_Supervised_Model_Training/` **不应该被创建**AutoML 路径独立)
#### ✅ 检查点 4AutoML 降级(仅未装 Optuna 时)
```
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
[Step6] 降级路径:调用 WaterQualityModelingBatch.train_models_batch132 组 GridSearchCV
```
→ 跑通即可(仍能产生模型文件),但**降级**属于非优选路径。
### 2.4 9 步全程观察清单
| 步 | 期望产物(路径相对 `work_dir` | 期望耗时50×50 测试数据) |
| ---- | -------------------------------------------------------------- | -------------------------- |
| 1 | `1_Water_Mask/water_mask.tif` | < 5 s |
| 2 | `2_Glint_Area_Mask/glint_mask.tif` | < 5 s |
| 3 | `3_Remove_Glint_Image/deglint_image.tif` | < 5 s |
| 4 | `4_Process_CSV/processed_data.csv` | < 2 s |
| 5 | `5_Training_Sample/training_spectra.csv` | < 5 s |
| 5.5 | `5_5_Calculate_Indices/indices.csv`如启用 | < 2 s |
| **6**| `7_Supervised_Model_Training_AutoML/`**新路径** | **< 5 minOptuna 5 trial** |
| 6.5 | `6_5_Non_Empirical_Modeling/`如启用 | 12 min |
| 6.75 | `6_75_Custom_Regression/`如启用 | 12 min |
| 7 | `7_Sampling_Points/sampling_points.csv` | < 3 s |
| 8 | `8_Prediction/predicted_values.csv` | < 5 s |
| 8.5 | `8_5_Prediction_Non_Empirical/predicted.csv`如启用 | < 5 s |
| 8.75 | `8_75_Prediction_Custom/predicted.csv`如启用 | < 5 s |
| 9 | `9_Kriging_Distribution_Map/distribution_map.tif` | 530 s Python |
### 2.5 流程结束
- [ ] 进度条到 100%
- [ ] "运行"按钮恢复可点
- [ ] "停止"按钮变灰
- [ ] 日志末行出现 `=== 流程执行完成 ===` `=== 流程被取消 ===`取决于是否点过停止
- [ ] 控制台 `on_pipeline_finished` 触发UI 状态被统一恢复
---
## 3. 软取消测试3 分钟)— **A 级:必跑**
验证 `threading.Event` 软取消链路不再用 `terminate()`)。
### 3.1 启动完整流程
2.2 启动流程
### 3.2 中途点"停止"
**时机** Step 6 AutoML trials 的中途看到 `[Step6] 目标列 'Chl-a' 共 N 个候选模型` 之后任意时刻"停止"。
**期望看到**
```
[STOP] 用户请求软取消
[Step6] 检测到 cancel_event本 trial 完成后退出
[Step6] AutoML 在 trial #X 中止,已完成 5/20 trial
[Runner] 软取消:跳过剩余 8 个 step
=== 流程被取消 ===
```
UI 状态
- [ ] "运行"按钮重新亮起
- [ ] "停止"按钮变灰
- [ ] 进度条保留在中断时的百分比****归零
- [ ] `on_pipeline_finished` 触发 `success=False, cancelled=True` 区分
- [ ] **Python 进程不退出**GUI 仍可继续点"运行"开新流程
**反例(不应该发生)**
- `QThread: Destroyed while thread is still running` 警告
- Python 解释器直接崩溃
- UI 永远卡死`run_all_btn` 一直是灰的
### 3.3 旧 `stop()` 路径回归
为防老代码忘了改临时把 `water_quality_gui.py:stop_pipeline` 改回 `self.worker.stop()`跑一次完整流程看是否出现
```
[DEPRECATED] WorkerThread.stop() 已弃用,请改用 soft_stop()。
```
**这是预期行为**弃用方法保留但打 warning流程仍能完成即视为通过
---
## 4. 失败 / 降级场景5 分钟)— **B 级:选跑**
### 4.1 未填掩膜 + NDWI 阈值设极端值
NDWI 阈值设到 `0.9`几乎无水域Step 1 应给出 warning 但不崩
```
[Step1] NDWI 阈值=0.9,水域覆盖率 < 1%,请检查影像
```
### 4.2 CSV 完全无目标列
准备一个**没有目标列的 CSV**全特征列点运行
```
[AutoML] 训练 CSV 不存在或无目标列:未识别出目标列
[Step6] AutoML 全部失败,所有目标列返回 success=False
```
UI 不会崩会在 `automl_summary.json` `error: "未识别出目标列"`
### 4.3 Step 1 路径不存在
Step 1 选了一个**不存在的 .bsq 文件**
```
[Runner] step1_generate_water_mask 异常FileNotFoundError
[STOP] 流程中止在 step 1
```
UI 弹错误窗 + 把左侧步骤列表 `setCurrentRow(0)` 自动定位到 Step 1`_focus_step` 起效)。
### 4.4 Optuna 版本冲突
装一个 `optuna==2.10`API 大改 GUI
```
[AutoML] optuna API 不兼容(>=3.6 要求):<error>
[AutoML] 全目标列回退老 GridSearchCV
```
降级路径生效即视为通过
---
## 5. 验证矩阵 Checklist
复制以下到 PR 描述 / 验收单
```markdown
## 路线 B MVP 验证矩阵
### 代码落盘
- [ ] src/core/pipeline/__init__.py17 行4 export
- [ ] src/core/pipeline/context.pyPipelineContext dataclass
- [ ] src/core/pipeline/runner.pyStepSpec + PIPELINE_STEPS + PipelineRunner
- [ ] src/core/prediction/__init__.py追加 train_with_automl export
- [ ] src/core/prediction/automl_trainer.pyAutoMLResult + train_with_automl + CLI
- [ ] src/core/steps/modeling_step.pyuse_automl 分支 + _train_models_automl
- [ ] src/core/water_quality_inversion_pipeline_GUI.pyrun_full_pipeline_v2 + LEGACY_ATTR_MAP + _sync_legacy_attrs_from_context
- [ ] src/gui/core/worker_thread.pycancel_event + soft_stop + run() duck-type
- [ ] src/gui/water_quality_gui.pyon_run_all_clicked + _collect_minimal_config + 按钮重连)
### CLI 自测
- [ ] A.1 `python -m src.core.prediction.automl_trainer --csv ...` 退出码 0
- [ ] A.2 产物 .joblib 含 `_AUTOML` 后缀
- [ ] A.3 automl_summary.json 含 success=true
### GUI 端到端
- [ ] B.1 启动无 ImportError
- [ ] B.2 二次确认窗文案含 mask 提示 + AutoML 状态
- [ ] B.3 日志含 [Runner] 前缀v2 路径生效)
- [ ] B.4 Step 1 NDWI 自动生成路径生效
- [ ] B.5 9 步产物路径全部存在
- [ ] B.6 流程结束后 UI 状态恢复(运行按钮亮、停止按钮灰)
### 软取消
- [ ] C.1 流程中途点停止cancel_event 触发
- [ ] C.2 流程被取消而非崩溃
- [ ] C.3 UI 状态由 on_pipeline_finished 统一恢复
- [ ] C.4 旧 stop() 调用打 [DEPRECATED] warning
### 降级
- [ ] D.1 Optuna 未装 → 全目标列回退老 GridSearchCV
- [ ] D.2 无目标列 CSV → 写 error 到 summary不崩 UI
- [ ] D.3 不存在文件 → _focus_step 定位到对应 step
```
---
## 6. 已知未做(不在本次范围)
- [ ] Kriging 多进程并行当前 backend="loop" Python
- [ ] Step 5 radius==0 内存优化整波段读入
- [ ] 进度条 sub-step 粒度当前只到 step
- [ ] Step 8 全图预测当前只对采样点预测
- [ ] 全项目搜替换老 `self.worker.stop()` 调用仅本会话改了 `water_quality_gui.py` stop_pipeline
- [ ] `requirements.txt` 同步 Optuna `environment.yml`
- [ ] 单元测试套件`tests/` 目录为空建议用 pytest 覆盖 train_with_automl / PipelineRunner
---
## 7. 出问题找哪里
| 现象 | 看哪里 |
| --------------------------------------------- | ------------------------------------------------------- |
| `[Runner]` 日志没出来 | `worker_thread.py:run()` `inspect.signature` 探测 |
| `[AutoML]` 完全没打 | `modeling_step.py:170` `if use_automl` 是否进了 |
| AutoML `optuna API 不兼容` | `automl_trainer.py:236` `try import` |
| 软取消无反应 | `worker_thread.py:run()` 末尾的 `cancel_event.is_set()` |
| 二次确认窗没出来 | `water_quality_gui.py:on_run_all_clicked` line ~2848 |
| 9 步产物路径错位 | `pipeline/runner.py:PIPELINE_STEPS` `output` 字段 |
| v1 路径被走到 | `_sync_legacy_attrs_from_context` 没调 v2 异常 |
---
> **作者注**:本清单对应**路线 B 一键全自动重构 4 部分全部落盘**的验收场景,编号与 todo 8 同步。
> 跑通 §1 + §2 + §3 三段即视为 MVP 验收通过§4 用于鲁棒性抽查。

View File

@ -0,0 +1,2 @@
# 联调期指向本地 FastAPI dev 服务
VITE_API_BASE_URL=http://127.0.0.1:9090

7
frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
node_modules
dist
dist-ssr
.vite
*.local
.DS_Store
*.log

15
frontend/env.d.ts vendored Normal file
View File

@ -0,0 +1,15 @@
/// <reference types="vite/client" />
interface ImportMetaEnv {
readonly VITE_API_BASE_URL?: string
}
interface ImportMeta {
readonly env: ImportMetaEnv
}
declare module '*.vue' {
import type { DefineComponent } from 'vue'
const component: DefineComponent<{}, {}, any>
export default component
}

13
frontend/index.html Normal file
View File

@ -0,0 +1,13 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>WQ_GUI · 水质反演联调控制台</title>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

25
frontend/package.json Normal file
View File

@ -0,0 +1,25 @@
{
"name": "wq-gui-frontend",
"private": true,
"version": "0.0.1",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vue-tsc --noEmit && vite build",
"preview": "vite preview",
"type-check": "vue-tsc --noEmit"
},
"dependencies": {
"vue": "^3.4.27",
"element-plus": "^2.7.5",
"@element-plus/icons-vue": "^2.3.1",
"axios": "^1.7.2"
},
"devDependencies": {
"@types/node": "^20.12.12",
"@vitejs/plugin-vue": "^5.0.4",
"typescript": "^5.4.5",
"vite": "^5.2.11",
"vue-tsc": "^2.0.19"
}
}

225
frontend/src/App.vue Normal file
View File

@ -0,0 +1,225 @@
<template>
<div class="dashboard-container">
<h1 class="title">高光谱水质反演控制台</h1>
<el-row :gutter="20">
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🚀 模型训练 (Train)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="算法选择 (Model Type)">
<el-select v-model="trainForm.model_type" placeholder="请选择算法" class="w-full">
<el-option label="随机森林 (RF)" value="RF" />
<el-option label="支持向量回归 (SVR)" value="SVR" />
<el-option label="线性回归 (LinearRegression)" value="LinearRegression" />
<el-option label="K近邻 (KNN)" value="KNN" />
<el-option label="偏最小二乘 (PLS)" value="PLS" />
</el-select>
</el-form-item>
<el-form-item label="目标参数 (Target)">
<el-input v-model="trainForm.target" placeholder="如 Chl-a" />
</el-form-item>
<el-form-item label="训练数据路径 (CSV 绝对路径)">
<el-input v-model="trainForm.train_data_path" placeholder="如 D:\111\data.csv" />
</el-form-item>
<el-form-item label="特征起始列 (如 4, 或列名)">
<el-input v-model="trainForm.feature_start" placeholder="填写数字或列名" />
</el-form-item>
<el-button type="primary" @click="handleTrain" :loading="trainPoller?.isPolling?.value" class="w-full">
开始训练
</el-button>
</el-form>
<div v-if="trainTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ trainTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(trainPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ trainPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="trainPoller?.isPolling?.value || trainPoller?.status?.value === 'SUCCESS'"
:percentage="trainPoller?.status?.value === 'SUCCESS' ? 100 : 60"
:status="trainPoller?.status?.value === 'SUCCESS' ? 'success' : (trainPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="trainPoller?.isPolling?.value"
/>
<div v-if="trainPoller?.error?.value" class="error-msg">
<el-alert :title="trainPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="trainPoller?.result?.value?.model_id" class="result-msg">
<el-descriptions border :column="1" size="small" title="训练指标">
<el-descriptions-item label="Model ID">{{ trainPoller.result.value.model_id }}</el-descriptions-item>
<el-descriptions-item label="Test R²">{{ Number(trainPoller.result.value.test_r2).toFixed(4) }}</el-descriptions-item>
<el-descriptions-item label="Test RMSE">{{ Number(trainPoller.result.value.test_rmse).toFixed(4) }}</el-descriptions-item>
</el-descriptions>
</div>
</div>
</el-card>
</el-col>
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🎯 模型推断 (Predict)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="已训练模型 ID (Model ID)">
<el-input v-model="predictForm.model_id" placeholder="将自动填入左侧训练好的 ID" />
</el-form-item>
<el-form-item label="待推断影像路径 (Zarr 绝对路径)">
<el-input v-model="predictForm.input_zarr_path" placeholder="如 D:\111\image.zarr" />
</el-form-item>
<el-button type="success" @click="handlePredict" :loading="predictPoller?.isPolling?.value" class="w-full">
开始大图反演推断
</el-button>
</el-form>
<div v-if="predictTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ predictTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(predictPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ predictPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="predictPoller?.isPolling?.value || predictPoller?.status?.value === 'SUCCESS'"
:percentage="predictPoller?.status?.value === 'SUCCESS' ? 100 : 50"
:status="predictPoller?.status?.value === 'SUCCESS' ? 'success' : (predictPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="predictPoller?.isPolling?.value"
/>
<div v-if="predictPoller?.error?.value" class="error-msg">
<el-alert :title="predictPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="predictPoller?.result?.value?.output_zarr_path" class="result-msg">
<el-alert :title="'推断成功!结果已落盘至: ' + predictPoller.result.value.output_zarr_path" type="success" :closable="false" show-icon />
</div>
</div>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { submitTrain, submitPredict } from '@/api/tasks'
import { useTaskPoller } from '@/composables/useTaskPoller'
// 训练表单状态
const trainForm = reactive({
model_type: 'RF',
target: 'Chl-a',
train_data_path: '',
feature_start: '4'
})
const trainTaskId = ref<string | null>(null)
const trainPoller = useTaskPoller(trainTaskId)
// 推断表单状态
const predictForm = reactive({
model_id: '',
input_zarr_path: ''
})
const predictTaskId = ref<string | null>(null)
const predictPoller = useTaskPoller(predictTaskId)
// 自动填入联动
watch(() => trainPoller?.result?.value?.model_id, (newId) => {
if (newId) predictForm.model_id = newId as string
})
// 提交训练
const handleTrain = async () => {
try {
const res = await submitTrain({
model_type: trainForm.model_type,
target: trainForm.target,
train_data_path: trainForm.train_data_path,
feature_start: trainForm.feature_start,
params: {}
})
trainTaskId.value = res.task_id
} catch (e: any) {
console.error('训练接口调用失败', e)
alert('提交失败,请检查后端是否在 9090 端口启动,或按 F12 查看控制台跨域报错')
}
}
// 提交推断
const handlePredict = async () => {
try {
const res = await submitPredict({
model_id: predictForm.model_id,
input_zarr_path: predictForm.input_zarr_path
})
predictTaskId.value = res.task_id
} catch (e: any) {
console.error('推断接口调用失败', e)
}
}
// 样式辅助
const getStatusType = (status: string) => {
if (status === 'SUCCESS') return 'success'
if (status === 'FAILED') return 'danger'
if (status === 'PROCESSING') return 'warning'
return 'info'
}
</script>
<style>
/* 去除全局默认边距 */
body {
margin: 0;
padding: 0;
}
</style>
<style scoped>
.dashboard-container {
padding: 40px;
min-height: 100vh;
background-color: #1e1e2d; /* 科技深色底 */
}
.title {
text-align: center;
margin-bottom: 40px;
color: #ffffff;
font-weight: 300;
letter-spacing: 2px;
}
.header-title {
font-weight: bold;
font-size: 16px;
}
.box-card {
margin-bottom: 20px;
background-color: rgba(255, 255, 255, 0.95);
}
.w-full {
width: 100%;
}
.status-board {
margin-top: 25px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
border: 1px solid #e4e7ed;
}
.error-msg, .result-msg {
margin-top: 20px;
}
</style>

View File

@ -0,0 +1,94 @@
/**
* Axios 单例 + 响应拦截器
* --------------------------------
* 1. baseURL 默认指向本地 FastAPI dev 服务。
* 通过 Vite 环境变量 VITE_API_BASE_URL 可覆盖, 例如:
* .env.development: VITE_API_BASE_URL=http://127.0.0.1:8000
* .env.production: VITE_API_BASE_URL=https://api.example.com
*
* 2. 响应拦截器统一 unwrap response.data, 调用方拿到的是真正的业务对象,
* 而不是 AxiosResponse 包装。失败时统一抛 Error, message 优先取
* FastAPI 的 detail 字段。
*
* 3. 类型增强: cast 成 UnwrappedAxiosInstance, 让 request.get<T>(url)
* 的返回类型直接是 T, 而不是 AxiosResponse<T>, 调用方无需二次解包。
*/
import axios, {
type AxiosInstance,
type AxiosRequestConfig,
} from 'axios'
// 在 Vite 下用 import.meta.env; 其它环境 (webpack/直接 ts-node) 兜底到 process.env
type ViteEnv = { env?: Record<string, string | undefined> }
const viteEnv: ViteEnv | undefined =
typeof import.meta !== 'undefined' ? ((import.meta as unknown) as ViteEnv) : undefined
const baseURL: string =
viteEnv?.env?.VITE_API_BASE_URL ??
(typeof process !== 'undefined' && process.env?.VITE_API_BASE_URL) ??
'http://127.0.0.1:9090'
const _instance: AxiosInstance = axios.create({
baseURL,
timeout: 15000,
headers: {
'Content-Type': 'application/json',
},
// FastAPI 开发期 CORS 是 allow_origins=["*"], 不需要带 cookie
withCredentials: false,
})
// ----- 请求拦截器: 预留 token / 日志位 -----
_instance.interceptors.request.use(
(config) => {
// const token = localStorage.getItem('token')
// if (token) config.headers.Authorization = `Bearer ${token}`
return config
},
(error) => Promise.reject(error),
)
// ----- 响应拦截器: unwrap data + 统一错误 message -----
_instance.interceptors.response.use(
(response) => response.data,
(error) => {
const detail = error?.response?.data?.detail
const message =
(typeof detail === 'string' ? detail : detail?.msg) ??
error?.response?.data?.message ??
error?.message ??
'请求失败'
return Promise.reject(new Error(message))
},
)
// ----- 类型增强: 把响应拦截器 unwrap 的事实在类型上表达出来 -----
type UnwrappedAxiosInstance = Omit<
AxiosInstance,
'get' | 'delete' | 'head' | 'options' | 'post' | 'put' | 'patch'
> & {
get<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
delete<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
head<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
options<T = unknown>(url: string, config?: AxiosRequestConfig): Promise<T>
post<T = unknown, D = unknown>(
url: string,
data?: D,
config?: AxiosRequestConfig,
): Promise<T>
put<T = unknown, D = unknown>(
url: string,
data?: D,
config?: AxiosRequestConfig,
): Promise<T>
patch<T = unknown, D = unknown>(
url: string,
data?: D,
config?: AxiosRequestConfig,
): Promise<T>
}
const request = _instance as UnwrappedAxiosInstance
export default request
export { baseURL }

155
frontend/src/api/tasks.ts Normal file
View File

@ -0,0 +1,155 @@
/**
* 与 FastAPI 后端对接的 API 函数
* --------------------------------
* 全部用 request 单例, 调用方拿到的就是业务对象 (response 拦截器已 unwrap)。
*
* 后端路由:
* GET /api/algorithms
* POST /api/process/deglint
* POST /api/modeling/train
* POST /api/modeling/predict (额外, 与 train 配套)
* GET /api/tasks/{task_id}
*/
import request from './request'
// ============================================================
// 通用类型
// ============================================================
/** 后端任务状态机 (与 app.core.task_store.TASK_STORE 保持一致) */
export type TaskStatus = 'PENDING' | 'PROCESSING' | 'SUCCESS' | 'FAILED'
/** 任务类型, 区分去耀斑 / 训练 / 推断 */
export type TaskKind = 'deglint' | 'train' | 'predict'
/** 提交后端后立即返回的最小任务凭证 */
export interface TaskAcceptedResponse {
task_id: string
status: TaskStatus
kind: TaskKind
}
/**
* 任务详情 (与后端 TASK_STORE 里记录的字段对齐, 通用 + 各 kind 增量字段)
* 用 [key: string]: unknown 兜底, 兼容未来后端新增字段
*/
export interface TaskRecord {
task_id: string
kind: TaskKind
status: TaskStatus
// 去耀斑
algorithm?: string
input_zarr_path?: string
output_zarr_path?: string | null
// 训练
model_type?: string
target?: string
train_data_path?: string
feature_start?: number | string
params?: Record<string, unknown>
model_id?: string | null
model_path?: string | null
test_r2?: number | null
test_rmse?: number | null
test_mae?: number | null
n_features?: number | null
n_samples?: number | null
// 推断
// (model_id / input_zarr_path / output_zarr_path 已在上方)
// 失败
error?: string | null
traceback?: string | null
// 元
created_at?: string
updated_at?: string
[key: string]: unknown
}
// ============================================================
// 1) 算法列表 GET /api/algorithms
// ============================================================
export interface AlgorithmInfo {
name: string
doc?: string
}
export interface AlgorithmListResponse {
algorithms: AlgorithmInfo[]
count: number
}
export function getAlgorithms(): Promise<AlgorithmListResponse> {
return request.get<AlgorithmListResponse>('/api/algorithms')
}
// ============================================================
// 2) 提交去耀斑 POST /api/process/deglint
// ============================================================
export interface DeglintParams {
input_zarr_path: string
output_zarr_path?: string
/** 算法自定义参数 (D_max / band 选择等) */
[key: string]: unknown
}
export function submitDeglint(
method: string,
params: DeglintParams,
): Promise<TaskAcceptedResponse> {
return request.post<TaskAcceptedResponse, { method: string; params: DeglintParams }>(
'/api/process/deglint',
{ method, params },
)
}
// ============================================================
// 3) 提交训练 POST /api/modeling/train
// ============================================================
export interface TrainRequest {
model_type: string
target: string
train_data_path: string
/** 特征起始列, int 索引或 str 列名, 默认 4 */
feature_start?: number | string
/** sklearn 估计器超参 */
params?: Record<string, unknown>
}
export function submitTrain(payload: TrainRequest): Promise<TaskAcceptedResponse> {
return request.post<TaskAcceptedResponse, TrainRequest>(
'/api/modeling/train',
payload,
)
}
// ============================================================
// 4) 提交推断 POST /api/modeling/predict (配套, 训练后才能用)
// ============================================================
export interface PredictRequest {
model_id: string
input_zarr_path: string
output_zarr_path?: string
}
export function submitPredict(
payload: PredictRequest,
): Promise<TaskAcceptedResponse> {
return request.post<TaskAcceptedResponse, PredictRequest>(
'/api/modeling/predict',
payload,
)
}
// ============================================================
// 5) 查询任务状态 GET /api/tasks/{task_id}
// ============================================================
export function getTaskStatus(task_id: string): Promise<TaskRecord> {
return request.get<TaskRecord>(
`/api/tasks/${encodeURIComponent(task_id)}`,
)
}

View File

@ -0,0 +1,238 @@
/**
* 任务轮询 Composable (Vue 3 + TypeScript)
* -----------------------------------------
* 用法 1 — 静态 task_id, 立即开始轮询:
* const { status, result, error, waitForCompletion } = useTaskPoller(taskId)
*
* 用法 2 — 响应式 task_id (异步拿到后赋值, 自动开始):
* const taskId = ref<string | null>(null)
* const poller = useTaskPoller(taskId)
* ;(async () => { taskId.value = (await submitTrain({...})).task_id })()
* await poller.waitForCompletion()
*
* 用法 3 — 手动控制:
* const poller = useTaskPoller()
* poller.start(taskId) // 开始
* poller.stop() // 停止
* poller.reset() // 清空状态
*
* 设计要点:
* - 终态 (SUCCESS/FAILED) 自动停止轮询
* - 组件卸载自动清理 (onUnmounted)
* - 网络错误不立刻终止, 计入 error.value 但继续轮询 (兼容临时抖动)
* - waitForCompletion 是单次承诺: SUCCESS resolve(record), FAILED reject(error)
* 外部 stop() 也会 reject
*/
import {
onUnmounted,
ref,
watch,
type MaybeRefOrGetter,
type Ref,
} from 'vue'
import { toValue } from 'vue'
import {
getTaskStatus,
type TaskRecord,
type TaskStatus,
} from '../api/tasks'
// 显式包含 'idle', 用于未开始轮询的初始态
export type PollerStatus = TaskStatus | 'idle'
export interface UseTaskPollerOptions {
/** 轮询间隔 ms, 默认 2000 */
intervalMs?: number
/** task_id 变 null 时是否自动停止, 默认 true */
autoStopOnNull?: boolean
}
export interface UseTaskPollerReturn {
/** 当前任务状态, 初始 'idle' */
status: Ref<PollerStatus>
/** SUCCESS 时的完整任务记录 (含 output_zarr_path / model_id 等) */
result: Ref<TaskRecord | null>
/** FAILED 时的错误描述, 或轮询过程中网络异常的消息 */
error: Ref<string | null>
/** 最新一次拉取到的任务记录 (含 PENDING/PROCESSING 占位字段) */
record: Ref<TaskRecord | null>
/** 是否正在轮询中 */
isPolling: Ref<boolean>
/** 当前轮询的 task_id (可能为 null) */
taskId: Ref<string | null>
/** 开始轮询某 task, 已轮询同一 id 时是 no-op */
start: (taskId: string) => void
/** 主动停止 (会 reject 未完成的 waitForCompletion) */
stop: () => void
/** 清空所有状态回 'idle' */
reset: () => void
/**
* 等到 SUCCESS/FAILED。
* - SUCCESS: resolve(record)
* - FAILED : reject(Error)
* - stop() : reject(Error('Polling stopped'))
* - 组件卸载: reject(Error('Component unmounted'))
* 已处于终态时立刻 resolve/reject, 不重复等待。
*/
waitForCompletion: () => Promise<TaskRecord>
}
export function useTaskPoller(
taskIdSource?: MaybeRefOrGetter<string | null>,
options: UseTaskPollerOptions = {},
): UseTaskPollerReturn {
const { intervalMs = 2000, autoStopOnNull = true } = options
const status = ref<PollerStatus>('idle')
const result = ref<TaskRecord | null>(null)
const error = ref<string | null>(null)
const record = ref<TaskRecord | null>(null)
const isPolling = ref(false)
const taskId = ref<string | null>(null)
let timerId: ReturnType<typeof setInterval> | null = null
let inFlightTick = false
let resolveWait: ((rec: TaskRecord) => void) | null = null
let rejectWait: ((err: Error) => void) | null = null
function clearTimer() {
if (timerId !== null) {
clearInterval(timerId)
timerId = null
}
}
function resolveOrRejectWait(rec: TaskRecord | null, err: Error | null) {
const r = resolveWait
const rj = rejectWait
resolveWait = null
rejectWait = null
if (rec && r) r(rec)
else if (err && rj) rj(err)
}
function applyTerminalRecord(rec: TaskRecord) {
record.value = rec
status.value = rec.status
if (rec.status === 'SUCCESS') {
result.value = rec
error.value = null
resolveOrRejectWait(rec, null)
} else if (rec.status === 'FAILED') {
result.value = null
error.value = rec.error ?? '任务失败 (无具体错误信息)'
resolveOrRejectWait(
null,
new Error(error.value ?? '任务失败'),
)
}
}
async function tick() {
const currentId = taskId.value
if (!currentId || inFlightTick) return
inFlightTick = true
try {
const rec = await getTaskStatus(currentId)
// 防止 await 期间用户 stop() / start() 了别的 task
if (taskId.value !== currentId) return
if (rec.status === 'SUCCESS' || rec.status === 'FAILED') {
applyTerminalRecord(rec)
stop()
} else {
// PENDING / PROCESSING 阶段, 更新 record 与 status 供 UI 展示
record.value = rec
status.value = rec.status
}
} catch (e) {
const msg = e instanceof Error ? e.message : String(e)
// 单次失败不立刻终止, 写入 error 但保持轮询
error.value = `轮询异常: ${msg}`
} finally {
inFlightTick = false
}
}
function start(nextId: string) {
if (!nextId) return
// 已在轮询同一 id, 幂等
if (taskId.value === nextId && isPolling.value) return
clearTimer()
taskId.value = nextId
status.value = 'idle'
result.value = null
error.value = null
record.value = null
isPolling.value = true
// 立刻拉一次, 避免 2s 空窗
void tick()
timerId = setInterval(() => void tick(), intervalMs)
}
function stop() {
const wasActive = isPolling.value
clearTimer()
isPolling.value = false
if (wasActive) {
resolveOrRejectWait(null, new Error('Polling stopped'))
}
}
function reset() {
stop()
taskId.value = null
status.value = 'idle'
result.value = null
error.value = null
record.value = null
}
function waitForCompletion(): Promise<TaskRecord> {
const r = record.value
if (r && r.status === 'SUCCESS') return Promise.resolve(r)
if (r && r.status === 'FAILED') {
return Promise.reject(
new Error(r.error ?? '任务失败 (无具体错误信息)'),
)
}
return new Promise<TaskRecord>((resolve, reject) => {
resolveWait = resolve
rejectWait = reject
})
}
// 自动模式: 监听外部 taskIdSource
if (taskIdSource !== undefined) {
const stopWatch = watch(
() => toValue(taskIdSource),
(newId, oldId) => {
if (newId && newId !== oldId) start(newId)
else if (!newId && autoStopOnNull) stop()
},
{ immediate: true },
)
onUnmounted(() => {
stopWatch()
reset()
resolveOrRejectWait(null, new Error('Component unmounted'))
})
} else {
onUnmounted(() => {
reset()
resolveOrRejectWait(null, new Error('Component unmounted'))
})
}
return {
status,
result,
error,
record,
isPolling,
taskId,
start,
stop,
reset,
waitForCompletion,
}
}

9
frontend/src/main.ts Normal file
View File

@ -0,0 +1,9 @@
import { createApp } from 'vue'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import App from './App.vue'
const app = createApp(App)
app.use(ElementPlus)
app.mount('#app')

24
frontend/tsconfig.json Normal file
View File

@ -0,0 +1,24 @@
{
"compilerOptions": {
"target": "ES2020",
"useDefineForClassFields": true,
"module": "ESNext",
"moduleResolution": "bundler",
"strict": true,
"jsx": "preserve",
"resolveJsonModule": true,
"isolatedModules": true,
"esModuleInterop": true,
"lib": ["ES2020", "DOM", "DOM.Iterable"],
"skipLibCheck": true,
"noEmit": true,
"allowImportingTsExtensions": true,
"baseUrl": ".",
"paths": {
"@/*": ["src/*"]
},
"types": ["vite/client"]
},
"include": ["src/**/*.ts", "src/**/*.d.ts", "src/**/*.vue", "env.d.ts"],
"references": [{ "path": "./tsconfig.node.json" }]
}

View File

@ -0,0 +1,12 @@
{
"compilerOptions": {
"composite": true,
"skipLibCheck": true,
"module": "ESNext",
"moduleResolution": "bundler",
"allowSyntheticDefaultImports": true,
"strict": true,
"types": ["node"]
},
"include": ["vite.config.ts"]
}

21
frontend/vite.config.ts Normal file
View File

@ -0,0 +1,21 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
import { fileURLToPath, URL } from 'node:url'
// Vite 配置:
// - @ -> frontend/src
// - dev server 监听 0.0.0.0:5173, 允许局域网内真机调试
// - VITE_API_BASE_URL 通过 .env.development 注入, 缺省走 src/api/request.ts 内的兜底 (http://127.0.0.1:8000)
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': fileURLToPath(new URL('./src', import.meta.url)),
},
},
server: {
host: '0.0.0.0',
port: 5173,
strictPort: false,
},
})

8
license.lic Normal file
View File

@ -0,0 +1,8 @@
{
"version": "1.0",
"product": "WaterQualityInversion",
"machine_code": "76E4992A5CF08BA570D6150908E04755",
"generated_at": "2026-05-28 14:21:35",
"expiry": "2099-12-31",
"signature": "DC9AB900D7033A281E54F41F3F76D026FFA75D635484D40C7F6FC1F6023E02AB"
}

View File

@ -0,0 +1,201 @@
"""
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
"""
import sys
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
# 当前文件位于 new/app/api/_smoke_test_train.py
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from app.api.modeling import (
_get_model_pipeline,
_load_train_df,
_resolve_feature_start,
_run_train_sync,
_MODEL_CLASS_REGISTRY,
)
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
"""生成 [lat, lon, target, lat2, lon2, feat_0, feat_1, ...] 布局的 CSV"""
rng = np.random.default_rng(seed)
lat = rng.uniform(20, 25, n_samples)
lon = rng.uniform(110, 115, n_samples)
target = rng.uniform(0, 50, n_samples)
lat2 = rng.uniform(0, 1, n_samples) # 元数据
lon2 = rng.uniform(0, 1, n_samples) # 元数据
feats = rng.normal(0, 1, (n_samples, n_features))
# 让 y 真正依赖前 3 个特征, RF 至少应该能学到 R² > 0.5
feats[:, 0] += target / 10
feats[:, 1] += target / 20
feats[:, 2] -= target / 15
df = pd.DataFrame({
"lat": lat,
"lon": lon,
"Chl-a": target,
"lat2": lat2,
"lon2": lon2,
**{f"feat_{i}": feats[:, i] for i in range(n_features)},
})
tmp = Path(tempfile.mkdtemp()) / "train.csv"
df.to_csv(tmp, index=False)
return tmp
def test_load_train_df():
print("== test_load_train_df ==")
p = make_synthetic_csv(n_samples=50)
df = _load_train_df(str(p))
assert df.shape == (50, 5 + 8), f"shape={df.shape}"
print(f" shape={df.shape}, columns[:6]={list(df.columns[:6])}")
print(" PASS")
def test_resolve_feature_start_int_and_str():
print("== test_resolve_feature_start (int + str) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
idx_int = _resolve_feature_start(df, 5)
idx_str = _resolve_feature_start(df, "feat_0")
assert idx_int == 5 == idx_str, f"int={idx_int}, str={idx_str}"
print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
print(" PASS")
def test_resolve_feature_start_str_miss():
print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
try:
_resolve_feature_start(df, "not_exist")
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_get_model_pipeline_all_types():
print("== test_get_model_pipeline (5 种 model_type) ==")
for mt in ["RF", "SVR", "LinearRegression", "KNN", "PLS"]:
p = _get_model_pipeline(mt, {})
assert len(p.steps) == 2
assert p.steps[0][0] == "scaler"
assert p.steps[1][0] == "model"
print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
print(" PASS")
def test_get_model_pipeline_bad_type():
print("== test_get_model_pipeline (坏 model_type) ==")
try:
_get_model_pipeline("XGBoost", {})
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_rf_end_to_end():
print("== test_run_train_sync (RF 端到端) ==")
p = make_synthetic_csv(n_samples=200)
out_dir = Path(tempfile.mkdtemp())
out_path = out_dir / "model.joblib"
import time
t0 = time.time()
metadata = _run_train_sync(
model_type="RF",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
output_model_path=out_path,
)
dt = time.time() - t0
assert out_path.exists(), f"joblib 未落盘: {out_path}"
print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
print(f" 耗时 {dt:.2f}s")
# 回测: 加载 joblib 再 predict
import joblib
saved = joblib.load(out_path)
assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
assert hasattr(saved["model"], "predict")
assert saved["metadata"]["test_r2"] == metadata["test_r2"]
print(f" joblib 加载 OK, 含 'model''metadata' 双 key")
print(" PASS")
def test_run_train_sync_linearregression_fast():
print("== test_run_train_sync (LinearRegression 快速路径) ==")
p = make_synthetic_csv(n_samples=150)
out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
metadata = _run_train_sync(
model_type="LinearRegression",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={},
output_model_path=out_path,
)
print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 >= 0.4)")
assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
print(" PASS")
def test_run_train_sync_bad_csv():
print("== test_run_train_sync (CSV 不存在) ==")
try:
_run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛异常")
except (FileNotFoundError, ValueError) as e:
print(f" 正确抛 {type(e).__name__}: {e}")
print(" PASS")
def test_run_train_sync_bad_target():
print("== test_run_train_sync (target 列不存在) ==")
p = make_synthetic_csv()
try:
_run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_str_feature_start():
print("== test_run_train_sync (feature_start 用列名) ==")
p = make_synthetic_csv()
out_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
metadata = _run_train_sync("RF", "Chl-a", str(p), "feat_0", {"n_estimators": 10}, out_path)
assert metadata["feature_start"] == "feat_0"
assert metadata["n_features"] == 8
assert metadata["feature_columns"][0] == "feat_0"
print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
print(" PASS")
if __name__ == "__main__":
test_load_train_df()
test_resolve_feature_start_int_and_str()
test_resolve_feature_start_str_miss()
test_get_model_pipeline_all_types()
test_get_model_pipeline_bad_type()
test_run_train_sync_rf_end_to_end()
test_run_train_sync_linearregression_fast()
test_run_train_sync_bad_csv()
test_run_train_sync_bad_target()
test_run_train_sync_str_feature_start()
print("\n>>> ALL SMOKE TESTS PASSED")

222
new/app/api/endpoints.py Normal file
View File

@ -0,0 +1,222 @@
"""
API 路由集合
============
把业务接口统一收口到 APIRouter再由 main.py 通过 include_router 挂载。
当前包含的接口:
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
派发链:
POST /api/process/deglint
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
└─ get_remover(method) 从注册表拿到算法类
└─ remover.process(input_zarr, output_zarr, **params)
"""
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB
from app.core.task_store import get_task, set_task, update_task
# 算法注册表 API
from app.core.algorithms import get_remover, list_removers
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
# ---------------------------------------------------------------------------
router = APIRouter(tags=["deglint"])
# ---------------------------------------------------------------------------
# 请求 / 响应数据模型
# ---------------------------------------------------------------------------
class DeglintRequest(BaseModel):
"""POST /api/process/deglint 的请求体"""
method: str = Field(
...,
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
examples=["kutser"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description=(
"传递给算法 process() 的超参数字典,例如 "
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
),
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
)
class TaskAcceptedResponse(BaseModel):
"""提交任务成功后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
class AlgorithmListResponse(BaseModel):
"""GET /api/algorithms 的响应"""
algorithms: list # 已注册算法名列表
count: int # 算法总数
# ---------------------------------------------------------------------------
# 后台任务执行器(真实派发链)
# ---------------------------------------------------------------------------
# 注意:这里使用 async def。
# FastAPI / Starlette 的 BackgroundTasks 支持 async function
# 会在响应返回后自动 await 它,不影响主请求链路。
# ---------------------------------------------------------------------------
async def execute_glint_removal_task(
task_id: str,
method: str,
params: Dict[str, Any],
) -> None:
"""
后台异步执行器:按 method 名字从注册表取出算法类,实例化并运行 process()。
状态机:
PENDING -> PROCESSING -> SUCCESS
└──> FAILED含 error / traceback
"""
# 0. 安全检查任务记录必须已存在POST 阶段已写入)
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 任务不存在, 跳过")
return
# 1. 状态推进到 PROCESSING
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始处理 method={method} params={params}")
# 2. 临时硬编码 IO 路径(未来由数据管理层提供)
# TODO: 替换为真实的数据管理服务返回的 zarr 路径
input_zarr_path = "./data/temp_in.zarr"
output_zarr_path = f"./data/{task_id}_out.zarr"
try:
# 3. 按 method 名字从注册表取算法类并实例化
# get_remover 找不到时会抛 KeyError下面的 except 会兜住
algorithm_cls = get_remover(method)
remover = algorithm_cls()
# 4. 调用算法(注意 await因为 BaseGlintRemover.process 是 async
await remover.process(input_zarr_path, output_zarr_path, **params)
# 5. 成功:写回结果路径与状态
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
# 6. 失败:记录错误信息与堆栈,便于前端排查
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
# ---------------------------------------------------------------------------
# GET /algorithms
# ---------------------------------------------------------------------------
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
# ---------------------------------------------------------------------------
@router.get("/algorithms", response_model=AlgorithmListResponse)
async def list_registered_algorithms() -> Dict[str, Any]:
"""列出已注册的去耀斑算法。"""
names = list(list_removers().keys())
return {"algorithms": names, "count": len(names)}
# ---------------------------------------------------------------------------
# POST /process/deglint
# ---------------------------------------------------------------------------
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
# ---------------------------------------------------------------------------
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
async def submit_deglint(
payload: DeglintRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
# 1. 生成唯一任务 IDUUID4 足以保证全局唯一性)
task_id = str(uuid.uuid4())
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
# 注意output_zarr_path / error / traceback 字段在执行过程中被填充
await set_task(
task_id,
{
"task_id": task_id,
"method": payload.method,
"params": payload.params,
"status": "PENDING",
"output_zarr_path": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
# 3. 把真实执行器丢到后台
background_tasks.add_task(
execute_glint_removal_task,
task_id,
payload.method,
payload.params,
)
# 4. 立即返回 task_id 与 PENDING 状态
return {"task_id": task_id, "status": "PENDING"}
# ---------------------------------------------------------------------------
# GET /tasks/{task_id}
# ---------------------------------------------------------------------------
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
# SUCCESS 表示成功(含 output_zarr_pathFAILED 表示失败(含 error / traceback
# ---------------------------------------------------------------------------
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> Dict[str, Any]:
"""查询指定任务的当前状态与结果。"""
record = await get_task(task_id)
if record is None:
# 找不到 task_id 通常意味着客户端拼错了 ID或者记录已被清理
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
# 直接返回字典FastAPI 会自动 JSON 序列化
return record

786
new/app/api/modeling.py Normal file
View File

@ -0,0 +1,786 @@
"""
app/api/modeling.py
===================
机器学习与水质反演相关的 API 路由。
接口(最终路径, 挂载后):
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
设计要点
--------
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor
* xr.open_zarr 延迟打开影像, NaN 填 0
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
- /train 已接入真实 sklearn + pandas 训练管线:
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
"""
import asyncio
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import joblib
import numpy as np
import pandas as pd
import xarray as xr
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel, Field
from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint
from app.core.task_store import get_task, set_task, update_task
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
# 最终完整路径由 main.py 挂载时再补 /api。
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/modeling", tags=["modeling"])
# ---------------------------------------------------------------------------
# 数据模型
# ---------------------------------------------------------------------------
class TrainRequest(BaseModel):
"""POST /api/modeling/train 的请求体"""
model_type: str = Field(
...,
description="模型类型, 例如 'RF' (随机森林) / 'SVR' (支持向量回归) / 'XGBoost' / 'MLP'",
examples=["RF", "SVR"],
)
target: str = Field(
...,
description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
examples=["Chl-a", "TSS", "CDOM"],
)
train_data_path: str = Field(
...,
description="训练数据集的 zarr 路径(包含 reflectance 变量与 target 标签)",
examples=["./data/train.zarr"],
)
feature_start: Union[int, str] = Field(
default=4,
description=(
"特征列起始位置. 表格布局假定为 "
"[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
"可传 int 列索引(如 4或 str 列名(如 '374.285' 波长起点)。"
"默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
),
examples=[4, "374.285"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
examples=[{"n_estimators": 100, "max_depth": 20}],
)
class PredictRequest(BaseModel):
"""POST /api/modeling/predict 的请求体"""
model_id: str = Field(
...,
description="已训练模型的 ID由 /api/modeling/train 返回或 /api/modeling/models 列出)",
)
input_zarr_path: str = Field(
...,
description="待推断影像的 zarr 路径",
examples=["./data/scene.zarr"],
)
output_zarr_path: Optional[str] = Field(
default=None,
description=(
"输出 zarr 路径, 缺省时由后端按规则生成 "
"(如 ./data/{model_id}_{input_stem}_pred.zarr)"
),
)
class TaskAcceptedResponse(BaseModel):
"""提交训练/推断任务后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
kind: str # "train" / "predict", 便于前端识别任务类型
class ModelInfo(BaseModel):
"""单个模型的元信息GET /api/modeling/models 的元素)"""
model_id: str
model_type: str
target: str
params: Dict[str, Any]
path: str # joblib 文件路径
created_at: str
train_task_id: str # 产生此模型的那个训练任务的 ID
class ModelListResponse(BaseModel):
"""GET /api/modeling/models 的响应"""
models: List[ModelInfo]
count: int
# ---------------------------------------------------------------------------
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
# ---------------------------------------------------------------------------
# model_id -> ModelInfo 字典
# 读多写少, 用一个普通 dict 足够CPython GIL 兜底)。
# 写时(训练完成时)只发生一次, 无并发风险。
# ---------------------------------------------------------------------------
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}
def _register_model(record: Dict[str, Any]) -> None:
"""将训练完成的模型登记到内存注册表。"""
_MODEL_REGISTRY[record["model_id"]] = record
# ---------------------------------------------------------------------------
# 训练管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点 (与推断管线一致):
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
# ---------------------------------------------------------------------------
# model_type (大写字符串) -> sklearn 估计器类
# 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传)
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
"RF": RandomForestRegressor,
"SVR": SVR,
"LinearRegression": LinearRegression,
"KNN": KNeighborsRegressor,
"PLS": PLSRegression,
}
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
"""
模型工厂: 按 model_type 选 sklearn 类, 用 StandardScaler + 估计器构造 Pipeline。
与 OpenClaw 不同之处: 把 scaler 放进 Pipeline 第一步,
推断时直接 pipeline.predict(X) 即可, scaler 参数与训练时严格一致。
"""
model_cls = _MODEL_CLASS_REGISTRY.get(model_type)
if model_cls is None:
raise ValueError(
f"不支持的 model_type='{model_type}', "
f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
)
estimator = model_cls(**(params or {}))
return Pipeline([("scaler", StandardScaler()), ("model", estimator)])
def _load_train_df(csv_path: str) -> pd.DataFrame:
"""
读 CSV 训练表, 规整空串 / 空白 / NULL 等为 NaN。
沿用 OpenClaw modeling_batch.load_data_batch 的读取策略:
na_values 显式列举 + 正则二次清理 (防 cell 内出现 " " 等纯空白)。
"""
try:
df = pd.read_csv(
csv_path,
na_values=["", " ", "NaN", "nan", "NULL", "null"],
)
except FileNotFoundError as exc:
raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
except pd.errors.EmptyDataError as exc:
raise ValueError(f"训练数据文件为空: {csv_path}") from exc
# 二次清理: 残留的纯空白 cell
df = df.replace(r"^\s*$", np.nan, regex=True)
return df
def _resolve_feature_start(
df: pd.DataFrame,
feature_start: Union[int, str],
) -> int:
"""
将 feature_start (int 索引 / str 列名) 统一解析为 int 列索引。
与 OpenClaw modeling_batch.load_data_batch / load_data_single 一致:
str 走 columns.get_loc, int 直接返回。
"""
if isinstance(feature_start, str):
if feature_start not in df.columns:
raise ValueError(
f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
)
return int(df.columns.get_loc(feature_start))
return int(feature_start)
def _run_train_sync(
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Optional[Dict[str, Any]],
output_model_path: Path,
) -> Dict[str, Any]:
"""
完整同步训练流程 (由 execute_train_task 在线程池内调用):
pd.read_csv -> 目标列 dropna -> 切特征 -> train_test_split(80/20)
-> Pipeline(StandardScaler + model).fit -> 评估 test_r2/rmse/mae
-> joblib.dump({model, metadata}, output_model_path)
Returns:
metadata 字典, 含 test_r2 / test_rmse / test_mae / n_features 等,
调用方负责写回 TASK_STORE 和 _MODEL_REGISTRY。
注: 旧版 SPXY / KS 划分留作未来扩展 (params.split_method 控制),
当前固定 random + test_size=0.2 + random_state=42。
"""
df = _load_train_df(train_data_path)
if target not in df.columns:
raise ValueError(
f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
)
# 1) 清洗: 仅剔除 target NaN 的行 (与 OpenClaw load_data_single 一致)
df = df[df[target].notna()].copy()
if df.empty:
raise ValueError("target 剔除 NaN 后无样本, 终止训练")
# 2) 特征切分
feature_start_idx = _resolve_feature_start(df, feature_start)
feature_columns = list(df.columns[feature_start_idx:])
X = df.iloc[:, feature_start_idx:].astype(np.float64)
y = df[target].astype(np.float64).values
# 3) 划分 (固定 random, 未来扩展 spxy/ks)
X_train, X_test, y_train, y_test = train_test_split(
X.values,
y,
test_size=0.2,
random_state=42,
)
# 4) 构造 Pipeline + 训练
pipeline = _get_model_pipeline(model_type, params)
pipeline.fit(X_train, y_train)
# 5) 测试集与训练集评估
y_pred = pipeline.predict(X_test)
test_r2 = float(r2_score(y_test, y_pred))
test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
test_mae = float(mean_absolute_error(y_test, y_pred))
y_train_pred = pipeline.predict(X_train)
train_r2 = float(r2_score(y_train, y_train_pred))
train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
train_mae = float(mean_absolute_error(y_train, y_train_pred))
metadata: Dict[str, Any] = {
"model_type": model_type,
"target": target,
"feature_start": feature_start,
"feature_columns": feature_columns,
"n_features": int(X.shape[1]),
"n_samples": int(X.shape[0]),
"train_size": int(X_train.shape[0]),
"test_size": int(X_test.shape[0]),
"params": dict(params or {}),
"test_r2": test_r2,
"test_rmse": test_rmse,
"test_mae": test_mae,
"train_r2": train_r2,
"train_rmse": train_rmse,
"train_mae": train_mae,
"split_method": "random",
"trained_at": datetime.now().isoformat(),
}
# 7) 持久化 (目录可能不存在, 顺手建)
output_model_path = Path(output_model_path)
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(
{"model": pipeline, "metadata": metadata},
output_model_path,
)
return metadata
# ---------------------------------------------------------------------------
# 推断管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点:
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
# ---------------------------------------------------------------------------
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
"""
构造一个 Dummy 随机森林回归器。
用途:
1) 真实 joblib 文件不存在时的连通性测试
2) 训练骨架尚未接入真实数据时的占位推断
"""
rng = np.random.default_rng(42)
X = rng.random((200, n_features))
y = rng.random(200)
model = RandomForestRegressor(
n_estimators=10, max_depth=5, random_state=0, n_jobs=1
)
model.fit(X, y)
return model
def _load_model(path: str, n_features: int) -> Any:
"""
加载训练好的 sklearn 模型, 失败时降级 Dummy。
优先级:
1) path 存在且 joblib.load 成功 -> 返回真实模型
2) 否则 -> 降级为 Dummy 随机森林 (n_features 必须指定)
"""
p = Path(path)
if p.is_file() and p.stat().st_size > 0:
try:
print(f"[model] 从磁盘加载: {path}")
return joblib.load(path)
except Exception as exc: # noqa: BLE001
print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
return _make_dummy_model(n_features)
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
"""
单个 dask chunk 的推断函数 (xr.apply_ufunc 会自动调度调用)。
Parameters
----------
spectra_3d : np.ndarray
形状 (y_chunk, x_chunk, n_bands)。
此形状由 input_core_dims=[["band"]] 决定:
xarray 会把 band 维移到最后一轴, 然后按 (y, x) 的 chunk 切分调用本函数。
model : 已 fit 好的 sklearn 估计器
接受 (n_samples, n_features) 输入, 返回 (n_samples,) 预测。
Returns
-------
np.ndarray
形状 (y_chunk, x_chunk), dtype float32 的标量预测图。
"""
yc, xc, nb = spectra_3d.shape
# 2D 化: 每个像素一行光谱
flat = spectra_3d.reshape(yc * xc, nb)
# sklearn 风格的批量预测
pred = model.predict(flat)
# 还原为 2D 空间图, 强制 float32 节约一半内存
return pred.reshape(yc, xc).astype(np.float32, copy=False)
def _run_predict_sync(
model: Any,
model_id: str,
input_zarr_path: str,
output_zarr_path: str,
) -> None:
"""
同步推断主流程 (被 asyncio.to_thread 调用)。
流程:
1) xr.open_zarr 延迟打开 (dask 数组, 不一次性读入内存)
2) NaN -> 0 清洗 (model.predict 不接受 NaN)
3) xr.apply_ufunc 沿 (y, x) 逐 chunk 调 _predict_block
4) 非水域置 NaN (zarr 支持 float NaN)
5) to_zarr 触发整图计算 + 流式写出
"""
# 1. 延迟打开输入 (关键: Dask 不一次性读入内存)
ds = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
)
reflectance = ds["reflectance"] # dims: (y, x, band)
n_bands = reflectance.sizes["band"]
# 2. 水域掩膜 (与去耀斑算法同约定)
if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)
# 3. NaN 清洗: 填充 0 (model.predict 不接受 NaN)
refl_clean = reflectance.fillna(0.0)
# 4. 核心: 用 apply_ufunc 把 model.predict 沿 (y, x) 应用
# dask="parallelized" 让每个 (y_chunk, x_chunk, band) chunk
# 独立调 _predict_block, 任意时刻内存中只有若干个 chunk。
prediction: xr.DataArray = xr.apply_ufunc(
_predict_block,
refl_clean,
kwargs={"model": model},
input_core_dims=[["band"]],
output_core_dims=[[]],
dask="parallelized",
output_dtypes=[np.float32],
dask_gufunc_kwargs={"allow_rechunk": True},
vectorize=False,
)
# 5. 非水域置 NaN (zarr 支持 float NaN, 便于后续可视化/掩膜分析)
prediction = prediction.where(water_mask, np.nan)
# 6. 包装为 Dataset 并流式写出
out = xr.Dataset(
{"prediction": prediction},
attrs={
"model_id": model_id,
"input_zarr_path": input_zarr_path,
"n_bands": n_bands,
"created_at": datetime.now().isoformat(),
},
)
# 保留 y/x 坐标
out = out.assign_coords(y=ds["y"], x=ds["x"])
# to_zarr + compute=True 触发整图 dask 图求值
# 中间会按 chunk 逐块调度到线程池, 内存峰值 ≈ 1 个 chunk 的体量
out.to_zarr(output_zarr_path, mode="w", compute=True)
# ---------------------------------------------------------------------------
# 后台任务执行器
# ---------------------------------------------------------------------------
async def execute_train_task(
task_id: str,
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Dict[str, Any],
) -> None:
"""
训练任务后台执行器(已接入真实 sklearn 训练流程)。
流程:
1) get_task 校验任务存在
2) update_task(PROCESSING)
3) 生成 model_id / model_path
4) asyncio.to_thread 派发 _run_train_sync 到默认线程池
5) 成功 -> _register_model + update_task(SUCCESS, 附 test_r2/rmse/mae)
6) 失败 -> update_task(FAILED, 附 error + traceback)
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 训练任务不存在, 跳过")
return
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 开始训练 model_type={model_type} target={target} "
f"train_data_path={train_data_path} feature_start={feature_start}"
)
# model_id 用 uuid4 前 12 位 (8 位易撞, 12 位兼顾可读性)
model_id = f"model_{uuid.uuid4().hex[:12]}"
model_path = Path(f"./data/models/{model_id}.joblib")
try:
# 同步 sklearn / pandas 训练丢到默认线程池, 不阻塞 event loop
metadata = await asyncio.to_thread(
_run_train_sync,
model_type,
target,
train_data_path,
feature_start,
params,
model_path,
)
# 登记到内存注册表 (供 /predict 查 model_id)
_register_model(
{
"model_id": model_id,
"model_type": model_type,
"target": target,
"params": dict(params or {}),
"path": str(model_path),
"feature_start": feature_start,
"n_features": metadata["n_features"],
"test_r2": metadata["test_r2"],
"test_rmse": metadata["test_rmse"],
"test_mae": metadata["test_mae"],
"created_at": datetime.now().isoformat(),
"train_task_id": task_id,
}
)
# 把训练指标写回任务记录, 前端轮询时可直接看
await update_task(
task_id,
status="SUCCESS",
model_id=model_id,
model_path=str(model_path),
test_r2=metadata["test_r2"],
test_rmse=metadata["test_rmse"],
test_mae=metadata["test_mae"],
n_features=metadata["n_features"],
n_samples=metadata["n_samples"],
error=None,
traceback=None,
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 训练完成 -> model_id={model_id} "
f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
)
except Exception as exc: # noqa: BLE001
# 失败时 model_path 不一定有产物, 显式置 None 方便前端判断
await update_task(
task_id,
status="FAILED",
model_id=None,
model_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}")
async def execute_predict_task(
task_id: str,
model_id: str,
input_zarr_path: str,
output_zarr_path: Optional[str],
) -> None:
"""
推断任务后台执行器(真实实现版)。
OOM 防护策略:
- xr.open_zarr(..., chunks="auto") 延迟打开, 整图不一次性读入内存
- xr.apply_ufunc(..., dask="parallelized") 把影像按 chunk 切分
- 每个 chunk 内部 reshape 成 2D, 调 model.predict, 再 reshape 回 2D
- 任意时刻内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) chunk 的体量
- 整图完成计算后再 to_zarr(compute=True) 流式写出
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 推断任务不存在, 跳过")
return
# 1. 校验 model_id 是否已注册 (避免在后台任务里报模糊错误)
model_meta = _MODEL_REGISTRY.get(model_id)
if model_meta is None:
await update_task(
task_id,
status="FAILED",
error=f"model_id 不存在: {model_id}",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
return
# 2. 自动生成 output_zarr_path (若未提供)
if output_zarr_path is None:
stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
stem = stem.replace(".zarr", "")
output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")
try:
# 3. 探测波段数 (用于 Dummy 模型适配)
# 这里只读 zarr 元数据 (.zarray 的 shape), 不读真实数据
ds_probe = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds_probe.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
)
n_bands = ds_probe["reflectance"].sizes["band"]
ds_probe.close()
# 4. 加载模型 (真实文件优先, Dummy 兜底)
model = _load_model(model_meta["path"], n_features=n_bands)
# 5. 包装同步执行, 丢到线程池, 事件循环不阻塞
await asyncio.to_thread(
_run_predict_sync,
model,
model_id,
input_zarr_path,
output_zarr_path,
)
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
model_id=model_id,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")
except Exception as exc: # noqa: BLE001
tb_text = traceback.format_exc()
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=tb_text,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
print(tb_text)
# ---------------------------------------------------------------------------
# POST /api/modeling/train
# ---------------------------------------------------------------------------
@router.post("/train", response_model=TaskAcceptedResponse)
async def submit_train(
payload: TrainRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型训练任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "train",
"model_type": payload.model_type,
"target": payload.target,
"train_data_path": payload.train_data_path,
"feature_start": payload.feature_start,
"params": payload.params,
"status": "PENDING",
"model_id": None,
"model_path": None,
"test_r2": None,
"test_rmse": None,
"test_mae": None,
"n_features": None,
"n_samples": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_train_task,
task_id,
payload.model_type,
payload.target,
payload.train_data_path,
payload.feature_start,
payload.params,
)
return {"task_id": task_id, "status": "PENDING", "kind": "train"}
# ---------------------------------------------------------------------------
# GET /api/modeling/models
# ---------------------------------------------------------------------------
@router.get("/models", response_model=ModelListResponse)
async def list_trained_models() -> Dict[str, Any]:
"""
列出已训练好的模型。
未来实现: 从 ./data/models/*.joblib 扫描元信息,
当前直接从内存 _MODEL_REGISTRY 读。
"""
models = list(_MODEL_REGISTRY.values())
# 按 created_at 倒序, 最新训练的在前
models.sort(key=lambda m: m.get("created_at", ""), reverse=True)
return {"models": models, "count": len(models)}
# ---------------------------------------------------------------------------
# POST /api/modeling/predict
# ---------------------------------------------------------------------------
@router.post("/predict", response_model=TaskAcceptedResponse)
async def submit_predict(
payload: PredictRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型推断任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "predict",
"model_id": payload.model_id,
"input_zarr_path": payload.input_zarr_path,
"output_zarr_path": payload.output_zarr_path,
"status": "PENDING",
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_predict_task,
task_id,
payload.model_id,
payload.input_zarr_path,
payload.output_zarr_path,
)
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}

View File

@ -0,0 +1,40 @@
"""
去耀斑算法包
============
通过「注册表 + 策略模式」组织不同的去耀斑算法。
所有具体算法都应继承 BaseGlintRemover并使用 @register_glint_remover
装饰器把算法名和实现类绑定。
外部调用约定
------------
1. 所有算法子模块必须在本 __init__ 中显式 import
这样装饰器才会被执行、注册表才会被填满。
2. 上层endpoints、worker只允许
from app.core.algorithms import get_remover
来获取算法类,不要直接 import 具体实现类,
保持调度层与具体算法的解耦。
"""
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import (
get_remover,
list_removers,
register_glint_remover,
unregister_glint_remover,
)
# ---- 算法子模块 import 区 ----
# 新增算法时,在这里加一行 import确保装饰器被执行。
from app.core.algorithms import goodman # Goodman
from app.core.algorithms import kutser # Kutser
# from app.core.algorithms import hedley # Hedley
# from app.core.algorithms import sugar # SUGAR
__all__ = [
"BaseGlintRemover",
"register_glint_remover",
"get_remover",
"list_removers",
"unregister_glint_remover",
]

View File

@ -0,0 +1,85 @@
"""
去耀斑算法抽象基类
==================
设计目标(策略模式 Strategy Pattern
------------------------------------
本模块定义了所有去耀斑算法必须遵守的标准接口。
未来的 Kutser、Goodman、Hedley、SUGAR 等算法都将继承本基类,
并实现统一的 process() 方法。
输入输出规范
------------
所有算法的输入与输出均统一为 **Zarr 文件路径**(字符串),
而不是内存中的 numpy ndarray。这样做的核心收益是
1. **解耦数据存储与内存计算**
算法只关心「从哪个 zarr 读、写到哪个 zarr」
至于数据最初来自 GeoTIFF / HDF5 / NetCDF / 内存数组,
都由 IO 层负责归一化转为 zarr。
2. **支持 Out-of-Core 计算**
影像往往超过内存上限zarr 分块chunk天然支持按块读取
算法实现可以借助 dask / xarray 进行流式计算。
3. **可缓存、可复用**
中间产物落盘后,下游算法(大气校正、辐射定标)能直接消费,
避免重复 IO。
4. **易于并行与分布式**
任务调度层只需把两个路径扔给 worker无需关心数据细节。
约定
----
- 子类应实现 process(),完成「读 -> 计算 -> 写」的完整流程。
- process() 返回 True 表示成功False 表示失败。
- 失败时建议抛出异常而非仅返回 False便于上层 BackgroundTasks 捕获并写入 error 字段。
"""
from abc import ABC, abstractmethod
from typing import Any
class BaseGlintRemover(ABC):
"""
去耀斑算法抽象基类。
所有具体算法Kutser / Goodman / Hedley / SUGAR …)必须继承本类并实现 process()。
子类可在 __init__ 中接收自己的超参数(如参考波段、阈值等),
真正的输入输出数据则由 process() 的两个 zarr 路径参数指定。
"""
# 子类可覆盖的算法名称标识,用于调度层按 method 名字查找
name: str = "base"
@abstractmethod
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
"""
执行去耀斑处理。
Parameters
----------
input_zarr_path : str
输入高光谱影像的 zarr 存储路径。
数据已由 IO 层完成格式归一化(波段、坐标系、空间维度均已对齐)。
output_zarr_path : str
处理结果(去耀斑后影像)的 zarr 存储路径。
子类需自行创建该 zarr 存储并写入结果。
**kwargs : Any
算法的可选超参数,例如:
- reference_band: 参考近红外波段索引
- chunk_size: 计算分块大小
- 其它算法特定参数
Returns
-------
bool
True 表示处理成功False 表示失败。
建议在出错时直接 raise由调用方统一记录到任务状态。
"""
raise NotImplementedError
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return f"<{self.__class__.__name__} name={self.name!r}>"

View File

@ -0,0 +1,123 @@
"""
app/core/algorithms/goodman.py
===============================
Goodman et al. 2008 去耀斑算法的 xarray + dask 流式实现。
算法公式
--------
R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)
其中:
R_raw -- 原始反射率 (y, x, band)
R_750 -- λ=750 nm 处的反射率(红外参考波段, 远离水汽吸收)
R_640 -- λ=640 nm 处的反射率(可见光差异波段)
A, B -- 经验回归参数(用户可通过 params 传入, 默认全 0
后处理
------
- 负值截断为 0Clamp to 0
- 仅在水域掩膜 (water_mask) 内生效, 水外置 0
维度约定
--------
reflectance: (y, x, band), band 坐标通常为 wavelength (nm)
water_mask : (y, x), 布尔类型, True = 水域
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 默认参数
# ---------------------------------------------------------------------------
# 与原始 Goodman 2008 论文符号保持一致, 方便用户交叉对照。
# A、B 通常通过对纯净深水区做 (R_corr - R_raw) ~ (R_640 - R_750) 回归得到;
# 在缺乏先验知识时, 退化为 A=0, B=0 即等价于 R_corrected = clip(R_raw - R_750, 0)。
# ---------------------------------------------------------------------------
DEFAULT_BAND_REF: float = 750.0 # λ_750 nm, 红外参考波段
DEFAULT_BAND_DIFF: float = 640.0 # λ_640 nm, 可见光差异波段
DEFAULT_A: float = 0.0 # 公式中的常数偏移项
DEFAULT_B: float = 0.0 # 公式中的斜率项
@register_glint_remover("goodman")
class GoodmanGlintRemover(BaseGlintRemover):
"""Goodman et al. 2008 去耀斑算法"""
name = "goodman"
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
# 1. 解析超参数(带默认值, 方便用户按需覆盖)
band_ref: float = kwargs.get("band_ref", DEFAULT_BAND_REF)
band_diff: float = kwargs.get("band_diff", DEFAULT_BAND_DIFF)
A: float = kwargs.get("A", DEFAULT_A)
B: float = kwargs.get("B", DEFAULT_B)
# 2. 把同步的 xarray/dask 计算丢到工作线程,
# 避免阻塞 FastAPI 的事件循环
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
band_ref,
band_diff,
A,
B,
)
@staticmethod
def _process_sync(
input_zarr_path: str,
output_zarr_path: str,
band_ref: float,
band_diff: float,
A: float,
B: float,
) -> bool:
# 1. 以 zarr 路径打开dask-backed, 不物化到内存)
# chunks="auto" 让 dask 根据每条坐标轴的大小自动决定分块
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance = ds["reflectance"] # (y, x, band)
# 2. 用 sel + method='nearest' 提取两个关键波段
# 返回形状 (y, x), 后续与 (y, x, band) 算术时会自动广播
R_750 = reflectance.sel(band=band_ref, method="nearest")
R_640 = reflectance.sel(band=band_diff, method="nearest")
# 3. Goodman 公式: xarray 沿 band 维度自动广播
# R_corr = R_raw - R_750 + A + B * (R_640 - R_750)
result = reflectance - R_750 + A + B * (R_640 - R_750)
# 4. 负值截断为 0clip(min=0) 优于 where(>0, 0, _)
# 不构造布尔中间数组, 底层走 dask 矢量化 clip 路径)
result = result.clip(min=0)
# 5. 仅在水域内生效(水外强制为 0
# 优先从 zarr 内部读 water_mask 变量, 缺失则视为全图水域
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
result = result.where(water_mask, 0)
# 6. 构造输出 Dataset, 保留元信息(波段坐标/属性等)
out = xr.Dataset({"reflectance": result})
if ds.attrs:
out.attrs = dict(ds.attrs)
if reflectance.attrs:
out["reflectance"].attrs = dict(reflectance.attrs)
# 7. 流式写出Out-of-Core不一次性物化大数组,
# dask 会按 chunk 边算边写, 内存峰值 ≈ 单个 chunk 大小
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,211 @@
"""
Kutser 去耀斑算法xarray + dask 重构版)
========================================
旧版痛点
--------
原始 Kutser 实现(参考 Kutser et al., 2013通常写成像这样
R_corr = np.zeros_like(R_raw)
for b in range(n_bands):
for y in range(H):
for x in range(W):
if water_mask[y, x]:
R_corr[y, x, b] = (
R_raw[y, x, b] - G_list[b] * D_norm[y, x]
)
with rasterio.open(..., 'w') as dst:
dst.write(R_corr)
问题:
1. 三重 Python 循环,每次只做一个浮点运算,解释器开销巨大;
2. 一次性把整张图 R_raw 读进内存,大影像直接 OOM
3. rasterio 写出要求 numpy 连续数组,进一步放大内存。
本文件用 xarray + dask 重写:
- 用 DataArray 维度广播,三重循环 → 一行表达式;
- 用 dask chunk 保持数据常驻磁盘、流式计算;
- 用 to_zarr 边算边写,输出格式与算法层彻底解耦。
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 算法实现
# ---------------------------------------------------------------------------
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
"""
Kutser 近红外扣除法去耀斑。
数学公式(与旧版完全等价)
-------------------------
1) 水汽吸收深度 D每像素
D = (R(λ_lower) + R(λ_upper)) / 2 - R(λ_oxy)
2) 全局归一化因子 D_max
D_max = max(D) over 水域
归一化:
D_norm = D / D_max
3) 每波段水域范围:
G_list[b] = max(R[:, :, b] over 水域) - min(R[:, :, b] over 水域)
4) 校正公式(每像素、每波段):
R_corr(λ_b) = R_raw(λ_b) - G_list[b] * D_norm
"""
# Kutser 2013 论文里使用的参考波段nm
# λ_lower = 773, λ_oxy = 845, λ_upper = 893
# 允许通过 kwargs 覆盖,便于适配 MERIS / OLCI / Landsat 等不同传感器。
DEFAULT_BAND_LOWER: float = 773.0
DEFAULT_BAND_OXY: float = 845.0
DEFAULT_BAND_UPPER: float = 893.0
# --------------------------------------------------------------
# 公开异步入口
# --------------------------------------------------------------
# xarray / dask 的算子本身是同步阻塞的。在 async 函数中,
# 用 asyncio.to_thread 把同步体丢到默认线程池执行,
# 避免阻塞 FastAPI 的事件循环。
# --------------------------------------------------------------
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
kwargs,
)
# --------------------------------------------------------------
# 同步核心实现
# --------------------------------------------------------------
def _process_sync(
self,
input_zarr_path: str,
output_zarr_path: str,
kwargs: dict,
) -> bool:
# ============================================================
# 步骤 0打开 zarr建立 dask 计算图
# ============================================================
# chunks="auto":让 dask 根据 zarr 的存储分块自动选择内存上限,
# 数据不会一次性全部 materialize 进 RAM。
# ============================================================
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance: xr.DataArray = ds["reflectance"] # 维度约定:(y, x, band)
# 维度顺序约定(也可根据 ds.dims 自动适配):
assert "y" in reflectance.dims and "x" in reflectance.dims and "band" in reflectance.dims, (
f"reflectance 必须包含 y/x/band 三个维度,实际为: {reflectance.dims}"
)
# ============================================================
# 步骤 1取出 3 个参考波段对应的二维 (y, x) 切片
# ============================================================
# 假设 band 维度的坐标是 wavelengthnm
# 用 sel(..., method="nearest") 自动匹配最接近的波段。
# ============================================================
wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))
R_lower = reflectance.sel(band=wl_lower, method="nearest") # (y, x)
R_upper = reflectance.sel(band=wl_upper, method="nearest") # (y, x)
R_oxy = reflectance.sel(band=wl_oxy, method="nearest") # (y, x)
# ============================================================
# 步骤 2水域掩膜
# ============================================================
# 优先从 zarr 内部读取 water_mask 变量;
# 如果不存在,则假定整幅图都是水域(开发期兜底)。
# ============================================================
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(
reflectance.isel(band=0), dtype=bool
)
# ============================================================
# 步骤 3水汽吸收深度 D每像素形状 (y, x)
# ============================================================
# 旧版D[y, x] = (R_lower[y, x] + R_upper[y, x]) / 2 - R_oxy[y, x]
# 新版一行表达式dask 自动构建 lazy 计算图。
# ============================================================
D = (R_lower + R_upper) / 2.0 - R_oxy # (y, x)dtype 与 reflectance 一致
# ============================================================
# 步骤 4全局归一化因子 D_max标量0-dim DataArray
# ============================================================
# 关键:先 .where(water_mask) 把非水域置 NaN
# 再 .max() 跨 (x, y) 聚合,自动规约到 0 维。
# dask 此时仍然没有真正计算,等到 to_zarr 时再触发。
# ============================================================
D_max = D.where(water_mask).max() # scalar
# 容错:如果水域为空导致 D_max 为 NaN用极小值兜底避免除零
D_max = D_max.fillna(1e-6)
# ============================================================
# 步骤 5归一化 D_norm形状 (y, x)
# ============================================================
D_norm = D / D_max # 标量除以 (y, x) 数组 → 自动广播
# ============================================================
# 步骤 6每波段水域范围 G_list形状 (band,)
# ============================================================
# 旧版三重循环内部还要做一次 min/max 聚合。
# xarray 版本:把 (y, x) 一起 reduce只保留 band 维度。
# ============================================================
R_water = reflectance.where(water_mask) # (y, x, band),非水域 NaN
G_min = R_water.min(dim=["x", "y"]) # (band,)
G_max = R_water.max(dim=["x", "y"]) # (band,)
G_list = (G_max - G_min).fillna(0.0) # (band,),容错
# ============================================================
# 步骤 7校正公式最关键的一行演示 xarray 广播)
# ============================================================
# 旧版需要:
# for b in bands:
# for y in range(H):
# for x in range(W):
# R_corr[y,x,b] = R_raw[y,x,b] - G_list[b] * D_norm[y,x]
#
# xarray 维度对齐规则:
# R_raw : (y, x, band)
# G_list: (band,) → 缺失 y, x 自动扩展
# D_norm: (y, x) → 缺失 band 自动扩展
# 乘法结果: (y, x, band) → 减法对齐
# 一行表达式完成「三重 for 循环 + 标量索引」的语义。
# ============================================================
corrected = reflectance - G_list * D_norm # (y, x, band)
# ============================================================
# 步骤 8水域掩膜过滤非水域置 NaN
# ============================================================
result = corrected.where(water_mask)
# ============================================================
# 步骤 9持久化为 zarr
# ============================================================
# mode="w":覆盖写入(如果目标已存在则删除重建)。
# compute=True阻塞直到整张图算完并落盘。
# 由于数据始终是 dask chunk + 流式写出,
# 内存峰值 ≈ 单个 chunk 大小,与整张影像大小无关。
# ============================================================
out = xr.Dataset({"reflectance": result})
# 保留原数据集的全局属性 / 坐标信息CRS、wavelength、...
out.attrs = dict(ds.attrs)
out["reflectance"].attrs = dict(reflectance.attrs)
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,135 @@
"""
算法注册表Registry / Factory
================================
通过装饰器把「算法名字符串」与「算法实现类」绑定在一起。
上层调度层FastAPI endpoints、BackgroundTasks worker只需要拿到
前端传过来的 method 字符串,就可以自动派发到对应的算法实现,
而无需写一长串 if/elif。
使用示例
--------
from app.core.algorithms import BaseGlintRemover
from app.core.algorithms.registry import (
register_glint_remover,
get_remover,
list_removers,
)
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
async def process(self, input_zarr_path, output_zarr_path, **kwargs):
...
# 派发
Cls = get_remover(method_from_request)
remover = Cls()
await remover.process(input_zarr_path, output_zarr_path, **kwargs)
设计要点
--------
- 注册动作发生在「类定义时」,所以必须在所有算法 import 完之后
注册表才完整。可以在 `app/core/algorithms/__init__.py` 中
把算法子模块 import 一遍来强制触发注册。
- 重复注册同名算法会直接抛错,避免静默覆盖。
- name 会同步写回到类的 `name` 属性,便于算法自身查询身份。
"""
from typing import Dict, Type
from app.core.algorithms.base import BaseGlintRemover
# 全局注册表name(str) -> 实现类(type),类未被实例化
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}
def register_glint_remover(name: str):
"""
类装饰器工厂:把传入 name 的算法类注册到全局注册表。
Parameters
----------
name : str
算法标识,建议小写下划线风格,例如 "kutser""goodman"
Raises
------
ValueError
- name 不是非空字符串
- name 已经被其它类占用
TypeError
- 被装饰的对象不是 BaseGlintRemover 的子类
"""
# ---- 防御性校验name 必须是合法字符串 ----
if not isinstance(name, str) or not name.strip():
raise ValueError(
f"register_glint_remover 的 name 必须是非空字符串,收到: {name!r}"
)
def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
# ---- 防御性校验:被装饰对象必须是 BaseGlintRemover 子类 ----
if not isinstance(cls, type) or not issubclass(cls, BaseGlintRemover):
raise TypeError(
f"@register_glint_remover 只能装饰 BaseGlintRemover 的子类,"
f"收到: {cls!r}"
)
# ---- 防御性校验:禁止静默覆盖 ----
if name in _REGISTRY:
raise ValueError(
f"算法名 {name!r} 已被 {_REGISTRY[name].__name__} 占用,"
f"请使用其它名字或先调用 unregister_glint_remover() 注销旧实现。"
)
# 同步把 name 写回类属性,便于算法自身和日志输出使用
cls.name = name
_REGISTRY[name] = cls
return cls
return decorator
def get_remover(name: str) -> Type[BaseGlintRemover]:
"""
按算法名字符串取出对应的实现类(未实例化)。
调用方拿到类后自行 `Cls(...)` 构造实例,再调用 process()。
Raises
------
KeyError
当 name 不在注册表中时抛出,错误信息中附带已注册列表便于排查。
"""
try:
return _REGISTRY[name]
except KeyError as exc:
known = ", ".join(sorted(_REGISTRY)) or "<空>"
raise KeyError(
f"未注册的算法名: {name!r}。已注册的算法: {known}"
) from exc
def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
"""
返回当前注册表的浅拷贝。
可用于:
- 调试日志
- 给前端暴露一个 GET /api/algorithms 接口
- 单元测试断言
"""
return dict(_REGISTRY)
def unregister_glint_remover(name: str) -> None:
"""
注销指定算法。主要给:
- 单元测试
- 热重载 / 插件卸载场景
生产代码一般不需要调用。
"""
if name not in _REGISTRY:
raise KeyError(f"未注册的算法名: {name!r}")
del _REGISTRY[name]

View File

@ -0,0 +1,91 @@
"""
app/core/task_store.py
======================
并发安全的内存任务状态存储,替代早期 mock 流水线中的 MOCK_TASK_DB。
设计目标
--------
1. 在单进程内提供事件循环级别的互斥asyncio.Lock
避免在 update 与 set/get 之间穿插 await 时发生状态不一致。
2. 暴露异步 APIset_task / update_task / get_task
让调用方在 async 上下文中显式表达临界区。
3. 保留一个同步的 has_task() 用于轻量存在性判断。
4. 生产环境应替换为 Redis / SQLite / PostgreSQL
但接口形状保持一致, 便于上层调用方无缝迁移。
使用约定
--------
- 写入初始 PENDING 记录: await set_task(task_id, record)
- 增量更新字段PROCESSING/SUCCESS/FAILEDawait update_task(task_id, **fields)
- 读取任务记录: await get_task(task_id) # 可能返回 None
- 同步判断是否存在: has_task(task_id)
"""
import asyncio
from typing import Any, Dict, Optional
# ---------------------------------------------------------------------------
# 全局存储与锁
# ---------------------------------------------------------------------------
# TASK_STORE: task_id -> 任务记录
# 任务记录字段约定(与 endpoints.py 保持一致):
# task_id, method, params, status,
# output_zarr_path, error, traceback,
# created_at, updated_at
# ---------------------------------------------------------------------------
TASK_STORE: Dict[str, Dict[str, Any]] = {}
# 单进程内的事件循环级互斥锁
# 注意asyncio.Lock 必须在事件循环内创建, 故在模块顶层实例化时
# 仅获取引用, 第一次使用 (await lock.acquire()) 会在运行循环内进行。
_lock: asyncio.Lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# 异步 API
# ---------------------------------------------------------------------------
async def set_task(task_id: str, record: Dict[str, Any]) -> None:
"""
初始化或整体覆盖一个任务记录。
用法POST 端点收到提交请求后立即调用, 写入 PENDING 状态的初始记录。
"""
async with _lock:
TASK_STORE[task_id] = record
async def update_task(task_id: str, **fields: Any) -> None:
"""
按字段增量更新任务记录。
用法:后台执行器在 PROCESSING / SUCCESS / FAILED 等状态切换时调用。
若 task_id 不存在, setdefault 会自动创建一个空 dict 再 update防御性兜底
"""
async with _lock:
record = TASK_STORE.setdefault(task_id, {})
record.update(fields)
async def get_task(task_id: str) -> Optional[Dict[str, Any]]:
"""
读取任务记录; 不存在时返回 None。
用法GET /api/tasks/{task_id} 用此接口查询。
"""
async with _lock:
return TASK_STORE.get(task_id)
# ---------------------------------------------------------------------------
# 同步 API轻量
# ---------------------------------------------------------------------------
def has_task(task_id: str) -> bool:
"""
同步判断 task_id 是否存在。
适用于不需要锁的轻量场景(例如日志前置判断);
在 async 上下文中仍可调用, 因为 dict 的 in 判断是原子操作。
"""
return task_id in TASK_STORE

62
new/app/main.py Normal file
View File

@ -0,0 +1,62 @@
"""
WQ_GUI FastAPI 后端入口
=======================
应用启动与全局中间件配置:
- CORS开发阶段允许所有来源方便本地前端Vite / Webpack dev server联调
- 路由:通过 include_router 挂载 app/api/endpoints.py 中的业务接口
业务接口说明:
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
"""
from typing import Dict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import router as deglint_router
from app.api.modeling import router as modeling_router
# ---------------------------------------------------------------------------
# FastAPI 应用实例
# ---------------------------------------------------------------------------
app = FastAPI(
title="WQ_GUI Backend",
description="高光谱影像去耀斑处理 API",
version="0.2.0",
)
# ---------------------------------------------------------------------------
# CORS 中间件
# ---------------------------------------------------------------------------
# 开发阶段:放开所有来源、方法和头部,方便本地前端(任意端口)联调。
# 生产环境务必收敛 allow_origins 为前端真实域名,避免安全风险。
# ---------------------------------------------------------------------------
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# 路由注册
# ---------------------------------------------------------------------------
# 统一以 /api 为前缀,便于将来做版本管理(如 /api/v1、/api/v2
# ---------------------------------------------------------------------------
app.include_router(deglint_router, prefix="/api")
app.include_router(modeling_router, prefix="/api")
# ---------------------------------------------------------------------------
# 根路径健康检查(方便本地调试,非业务必需)
# ---------------------------------------------------------------------------
@app.get("/")
async def root() -> Dict[str, str]:
return {"service": "WQ_GUI Backend", "status": "ok"}

24
new/frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

5
new/frontend/README.md Normal file
View File

@ -0,0 +1,5 @@
# Vue 3 + TypeScript + Vite
This template should help get you started developing with Vue 3 and TypeScript in Vite. The template uses Vue 3 `<script setup>` SFCs, check out the [script setup docs](https://v3.vuejs.org/api/sfc-script-setup.html#sfc-script-setup) to learn more.
Learn more about the recommended Project Setup and IDE Support in the [Vue Docs TypeScript Guide](https://vuejs.org/guide/typescript/overview.html#project-setup).

13
new/frontend/index.html Normal file
View File

@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>frontend</title>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

2412
new/frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

27
new/frontend/package.json Normal file
View File

@ -0,0 +1,27 @@
{
"name": "frontend",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vue-tsc -b && vite build",
"preview": "vite preview"
},
"dependencies": {
"axios": "^1.16.1",
"echarts": "^6.1.0",
"element-plus": "^2.14.1",
"pinia": "^3.0.4",
"vue": "^3.5.34",
"vue-router": "^5.1.0"
},
"devDependencies": {
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
"typescript": "~6.0.2",
"vite": "^8.0.12",
"vue-tsc": "^3.2.8"
}
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 9.3 KiB

View File

@ -0,0 +1,24 @@
<svg xmlns="http://www.w3.org/2000/svg">
<symbol id="bluesky-icon" viewBox="0 0 16 17">
<g clip-path="url(#bluesky-clip)"><path fill="#08060d" d="M7.75 7.735c-.693-1.348-2.58-3.86-4.334-5.097-1.68-1.187-2.32-.981-2.74-.79C.188 2.065.1 2.812.1 3.251s.241 3.602.398 4.13c.52 1.744 2.367 2.333 4.07 2.145-2.495.37-4.71 1.278-1.805 4.512 3.196 3.309 4.38-.71 4.987-2.746.608 2.036 1.307 5.91 4.93 2.746 2.72-2.746.747-4.143-1.747-4.512 1.702.189 3.55-.4 4.07-2.145.156-.528.397-3.691.397-4.13s-.088-1.186-.575-1.406c-.42-.19-1.06-.395-2.741.79-1.755 1.24-3.64 3.752-4.334 5.099"/></g>
<defs><clipPath id="bluesky-clip"><path fill="#fff" d="M.1.85h15.3v15.3H.1z"/></clipPath></defs>
</symbol>
<symbol id="discord-icon" viewBox="0 0 20 19">
<path fill="#08060d" d="M16.224 3.768a14.5 14.5 0 0 0-3.67-1.153c-.158.286-.343.67-.47.976a13.5 13.5 0 0 0-4.067 0c-.128-.306-.317-.69-.476-.976A14.4 14.4 0 0 0 3.868 3.77C1.546 7.28.916 10.703 1.231 14.077a14.7 14.7 0 0 0 4.5 2.306q.545-.748.965-1.587a9.5 9.5 0 0 1-1.518-.74q.191-.14.372-.293c2.927 1.369 6.107 1.369 8.999 0q.183.152.372.294-.723.437-1.52.74.418.838.963 1.588a14.6 14.6 0 0 0 4.504-2.308c.37-3.911-.63-7.302-2.644-10.309m-9.13 8.234c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.894 0 1.614.82 1.599 1.82.001 1-.705 1.82-1.6 1.82m5.91 0c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.893 0 1.614.82 1.599 1.82 0 1-.706 1.82-1.6 1.82"/>
</symbol>
<symbol id="documentation-icon" viewBox="0 0 21 20">
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="m15.5 13.333 1.533 1.322c.645.555.967.833.967 1.178s-.322.623-.967 1.179L15.5 18.333m-3.333-5-1.534 1.322c-.644.555-.966.833-.966 1.178s.322.623.966 1.179l1.534 1.321"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M17.167 10.836v-4.32c0-1.41 0-2.117-.224-2.68-.359-.906-1.118-1.621-2.08-1.96-.599-.21-1.349-.21-2.848-.21-2.623 0-3.935 0-4.983.369-1.684.591-3.013 1.842-3.641 3.428C3 6.449 3 7.684 3 10.154v2.122c0 2.558 0 3.838.706 4.726q.306.383.713.671c.76.536 1.79.64 3.581.66"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M3 10a2.78 2.78 0 0 1 2.778-2.778c.555 0 1.209.097 1.748-.047.48-.129.854-.503.982-.982.145-.54.048-1.194.048-1.749a2.78 2.78 0 0 1 2.777-2.777"/>
</symbol>
<symbol id="github-icon" viewBox="0 0 19 19">
<path fill="#08060d" fill-rule="evenodd" d="M9.356 1.85C5.05 1.85 1.57 5.356 1.57 9.694a7.84 7.84 0 0 0 5.324 7.44c.387.079.528-.168.528-.376 0-.182-.013-.805-.013-1.454-2.165.467-2.616-.935-2.616-.935-.349-.91-.864-1.143-.864-1.143-.71-.48.051-.48.051-.48.787.051 1.2.805 1.2.805.695 1.194 1.817.857 2.268.649.064-.507.27-.857.49-1.052-1.728-.182-3.545-.857-3.545-3.87 0-.857.31-1.558.8-2.104-.078-.195-.349-1 .077-2.078 0 0 .657-.208 2.14.805a7.5 7.5 0 0 1 1.946-.26c.657 0 1.328.092 1.946.26 1.483-1.013 2.14-.805 2.14-.805.426 1.078.155 1.883.078 2.078.502.546.799 1.247.799 2.104 0 3.013-1.818 3.675-3.558 3.87.284.247.528.714.528 1.454 0 1.052-.012 1.896-.012 2.156 0 .208.142.455.528.377a7.84 7.84 0 0 0 5.324-7.441c.013-4.338-3.48-7.844-7.773-7.844" clip-rule="evenodd"/>
</symbol>
<symbol id="social-icon" viewBox="0 0 20 20">
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M12.5 6.667a4.167 4.167 0 1 0-8.334 0 4.167 4.167 0 0 0 8.334 0"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M2.5 16.667a5.833 5.833 0 0 1 8.75-5.053m3.837.474.513 1.035c.07.144.257.282.414.309l.93.155c.596.1.736.536.307.965l-.723.73a.64.64 0 0 0-.152.531l.207.903c.164.715-.213.991-.84.618l-.872-.52a.63.63 0 0 0-.577 0l-.872.52c-.624.373-1.003.094-.84-.618l.207-.903a.64.64 0 0 0-.152-.532l-.723-.729c-.426-.43-.289-.864.306-.964l.93-.156a.64.64 0 0 0 .412-.31l.513-1.034c.28-.562.735-.562 1.012 0"/>
</symbol>
<symbol id="x-icon" viewBox="0 0 19 19">
<path fill="#08060d" fill-rule="evenodd" d="M1.893 1.98c.052.072 1.245 1.769 2.653 3.77l2.892 4.114c.183.261.333.48.333.486s-.068.089-.152.183l-.522.593-.765.867-3.597 4.087c-.375.426-.734.834-.798.905a1 1 0 0 0-.118.148c0 .01.236.017.664.017h.663l.729-.83c.4-.457.796-.906.879-.999a692 692 0 0 0 1.794-2.038c.034-.037.301-.34.594-.675l.551-.624.345-.392a7 7 0 0 1 .34-.374c.006 0 .93 1.306 2.052 2.903l2.084 2.965.045.063h2.275c1.87 0 2.273-.003 2.266-.021-.008-.02-1.098-1.572-3.894-5.547-2.013-2.862-2.28-3.246-2.273-3.266.008-.019.282-.332 2.085-2.38l2-2.274 1.567-1.782c.022-.028-.016-.03-.65-.03h-.674l-.3.342a871 871 0 0 1-1.782 2.025c-.067.075-.405.458-.75.852a100 100 0 0 1-.803.91c-.148.172-.299.344-.99 1.127-.304.343-.32.358-.345.327-.015-.019-.904-1.282-1.976-2.808L6.365 1.85H1.8zm1.782.91 8.078 11.294c.772 1.08 1.413 1.973 1.425 1.984.016.017.241.02 1.05.017l1.03-.004-2.694-3.766L7.796 5.75 5.722 2.852l-1.039-.004-1.039-.004z" clip-rule="evenodd"/>
</symbol>
</svg>

After

Width:  |  Height:  |  Size: 4.9 KiB

225
new/frontend/src/App.vue Normal file
View File

@ -0,0 +1,225 @@
<template>
<div class="dashboard-container">
<h1 class="title">高光谱水质反演控制台</h1>
<el-row :gutter="20">
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🚀 模型训练 (Train)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="算法选择 (Model Type)">
<el-select v-model="trainForm.model_type" placeholder="请选择算法" class="w-full">
<el-option label="随机森林 (RF)" value="RF" />
<el-option label="支持向量回归 (SVR)" value="SVR" />
<el-option label="线性回归 (LinearRegression)" value="LinearRegression" />
<el-option label="K近邻 (KNN)" value="KNN" />
<el-option label="偏最小二乘 (PLS)" value="PLS" />
</el-select>
</el-form-item>
<el-form-item label="目标参数 (Target)">
<el-input v-model="trainForm.target" placeholder="如 Chl-a" />
</el-form-item>
<el-form-item label="训练数据路径 (CSV 绝对路径)">
<el-input v-model="trainForm.train_data_path" placeholder="如 D:\111\data.csv" />
</el-form-item>
<el-form-item label="特征起始列 (如 4, 或列名)">
<el-input v-model="trainForm.feature_start" placeholder="填写数字或列名" />
</el-form-item>
<el-button type="primary" @click="handleTrain" :loading="trainPoller?.isPolling?.value" class="w-full">
开始训练
</el-button>
</el-form>
<div v-if="trainTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ trainTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(trainPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ trainPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="trainPoller?.isPolling?.value || trainPoller?.status?.value === 'SUCCESS'"
:percentage="trainPoller?.status?.value === 'SUCCESS' ? 100 : 60"
:status="trainPoller?.status?.value === 'SUCCESS' ? 'success' : (trainPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="trainPoller?.isPolling?.value"
/>
<div v-if="trainPoller?.error?.value" class="error-msg">
<el-alert :title="trainPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="trainPoller?.result?.value?.model_id" class="result-msg">
<el-descriptions border :column="1" size="small" title="训练指标">
<el-descriptions-item label="Model ID">{{ trainPoller.result.value.model_id }}</el-descriptions-item>
<el-descriptions-item label="Test R²">{{ Number(trainPoller.result.value.test_r2).toFixed(4) }}</el-descriptions-item>
<el-descriptions-item label="Test RMSE">{{ Number(trainPoller.result.value.test_rmse).toFixed(4) }}</el-descriptions-item>
</el-descriptions>
</div>
</div>
</el-card>
</el-col>
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🎯 模型推断 (Predict)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="已训练模型 ID (Model ID)">
<el-input v-model="predictForm.model_id" placeholder="将自动填入左侧训练好的 ID" />
</el-form-item>
<el-form-item label="待推断影像路径 (Zarr 绝对路径)">
<el-input v-model="predictForm.input_zarr_path" placeholder="如 D:\111\image.zarr" />
</el-form-item>
<el-button type="success" @click="handlePredict" :loading="predictPoller?.isPolling?.value" class="w-full">
开始大图反演推断
</el-button>
</el-form>
<div v-if="predictTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ predictTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(predictPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ predictPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="predictPoller?.isPolling?.value || predictPoller?.status?.value === 'SUCCESS'"
:percentage="predictPoller?.status?.value === 'SUCCESS' ? 100 : 50"
:status="predictPoller?.status?.value === 'SUCCESS' ? 'success' : (predictPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="predictPoller?.isPolling?.value"
/>
<div v-if="predictPoller?.error?.value" class="error-msg">
<el-alert :title="predictPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="predictPoller?.result?.value?.output_zarr_path" class="result-msg">
<el-alert :title="'推断成功!结果已落盘至: ' + predictPoller.result.value.output_zarr_path" type="success" :closable="false" show-icon />
</div>
</div>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { submitTrain, submitPredict } from './api/tasks'
import { useTaskPoller } from './composables/useTaskPoller'
// 训练表单状态
const trainForm = reactive({
model_type: 'RF',
target: 'Chl-a',
train_data_path: '',
feature_start: '4'
})
const trainTaskId = ref<string | null>(null)
const trainPoller = useTaskPoller(trainTaskId)
// 推断表单状态
const predictForm = reactive({
model_id: '',
input_zarr_path: ''
})
const predictTaskId = ref<string | null>(null)
const predictPoller = useTaskPoller(predictTaskId)
// 自动填入联动
watch(() => trainPoller?.result?.value?.model_id, (newId) => {
if (newId) predictForm.model_id = newId as string
})
// 提交训练
const handleTrain = async () => {
try {
const res = await submitTrain({
model_type: trainForm.model_type,
target: trainForm.target,
train_data_path: trainForm.train_data_path,
feature_start: trainForm.feature_start,
params: {}
})
trainTaskId.value = res.task_id
} catch (e: any) {
console.error('训练接口调用失败', e)
alert('提交失败,请检查后端是否在 9090 端口启动,或按 F12 查看控制台跨域报错')
}
}
// 提交推断
const handlePredict = async () => {
try {
const res = await submitPredict({
model_id: predictForm.model_id,
input_zarr_path: predictForm.input_zarr_path
})
predictTaskId.value = res.task_id
} catch (e: any) {
console.error('推断接口调用失败', e)
}
}
// 样式辅助
const getStatusType = (status: string) => {
if (status === 'SUCCESS') return 'success'
if (status === 'FAILED') return 'danger'
if (status === 'PROCESSING') return 'warning'
return 'info'
}
</script>
<style>
/* 去除全局默认边距 */
body {
margin: 0;
padding: 0;
}
</style>
<style scoped>
.dashboard-container {
padding: 40px;
min-height: 100vh;
background-color: #1e1e2d; /* 科技深色底 */
}
.title {
text-align: center;
margin-bottom: 40px;
color: #ffffff;
font-weight: 300;
letter-spacing: 2px;
}
.header-title {
font-weight: bold;
font-size: 16px;
}
.box-card {
margin-bottom: 20px;
background-color: rgba(255, 255, 255, 0.95);
}
.w-full {
width: 100%;
}
.status-board {
margin-top: 25px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
border: 1px solid #e4e7ed;
}
.error-msg, .result-msg {
margin-top: 20px;
}
</style>

View File

@ -0,0 +1,15 @@
import axios from 'axios'
const request = axios.create({
// 注意:直接指向我们刚刚改好的 9090 端口
baseURL: 'http://127.0.0.1:9090',
timeout: 60000
})
// 拦截器:直接剥离 data
request.interceptors.response.use(
response => response.data,
error => Promise.reject(error)
)
export default request

View File

@ -0,0 +1,13 @@
import request from './request'
export const submitTrain = (data: any) => {
return request.post<any, any>('/api/modeling/train', data)
}
export const submitPredict = (data: any) => {
return request.post<any, any>('/api/modeling/predict', data)
}
export const getTaskStatus = (task_id: string) => {
return request.get<any, any>(`/api/tasks/${task_id}`)
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 8.5 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="37.07" height="36" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 198"><path fill="#41B883" d="M204.8 0H256L128 220.8L0 0h97.92L128 51.2L157.44 0h47.36Z"></path><path fill="#41B883" d="m0 0l128 220.8L256 0h-51.2L128 132.48L50.56 0H0Z"></path><path fill="#35495E" d="M50.56 0L128 133.12L204.8 0h-47.36L128 51.2L97.92 0H50.56Z"></path></svg>

After

Width:  |  Height:  |  Size: 496 B

View File

@ -0,0 +1,95 @@
<script setup lang="ts">
import { ref } from 'vue'
import viteLogo from '../assets/vite.svg'
import heroImg from '../assets/hero.png'
import vueLogo from '../assets/vue.svg'
const count = ref(0)
</script>
<template>
<section id="center">
<div class="hero">
<img :src="heroImg" class="base" width="170" height="179" alt="" />
<img :src="vueLogo" class="framework" alt="Vue logo" />
<img :src="viteLogo" class="vite" alt="Vite logo" />
</div>
<div>
<h1>Get started</h1>
<p>Edit <code>src/App.vue</code> and save to test <code>HMR</code></p>
</div>
<button type="button" class="counter" @click="count++">
Count is {{ count }}
</button>
</section>
<div class="ticks"></div>
<section id="next-steps">
<div id="docs">
<svg class="icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#documentation-icon"></use>
</svg>
<h2>Documentation</h2>
<p>Your questions, answered</p>
<ul>
<li>
<a href="https://vite.dev/" target="_blank">
<img class="logo" :src="viteLogo" alt="" />
Explore Vite
</a>
</li>
<li>
<a href="https://vuejs.org/" target="_blank">
<img class="button-icon" :src="vueLogo" alt="" />
Learn more
</a>
</li>
</ul>
</div>
<div id="social">
<svg class="icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#social-icon"></use>
</svg>
<h2>Connect with us</h2>
<p>Join the Vite community</p>
<ul>
<li>
<a href="https://github.com/vitejs/vite" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#github-icon"></use>
</svg>
GitHub
</a>
</li>
<li>
<a href="https://chat.vite.dev/" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#discord-icon"></use>
</svg>
Discord
</a>
</li>
<li>
<a href="https://x.com/vite_js" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#x-icon"></use>
</svg>
X.com
</a>
</li>
<li>
<a href="https://bsky.app/profile/vite.dev" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#bluesky-icon"></use>
</svg>
Bluesky
</a>
</li>
</ul>
</div>
</section>
<div class="ticks"></div>
<section id="spacer"></section>
</template>

View File

@ -0,0 +1,51 @@
import { ref, watch, onUnmounted, type Ref } from 'vue'
import { getTaskStatus } from '../api/tasks'
export function useTaskPoller(taskIdRef: Ref<string | null>) {
const status = ref<string>('')
const isPolling = ref(false)
const error = ref<string | null>(null)
const result = ref<any>(null)
let timer: any = null
const start = () => {
if (!taskIdRef.value) return
isPolling.value = true
error.value = null
status.value = 'PENDING'
timer = setInterval(async () => {
try {
const res = await getTaskStatus(taskIdRef.value!)
status.value = res.status
if (res.status === 'SUCCESS') {
result.value = res
stop()
} else if (res.status === 'FAILED') {
error.value = res.error || '任务执行失败'
stop()
}
} catch (e: any) {
error.value = '网络请求失败,请检查后端状态'
stop()
}
}, 2000)
}
const stop = () => {
isPolling.value = false
if (timer) clearInterval(timer)
}
// 监听 Task ID 变化自动开启轮询
watch(taskIdRef, (newVal) => {
stop()
if (newVal) start()
})
// 组件销毁时清理定时器
onUnmounted(() => stop())
return { status, isPolling, error, result, stop }
}

9
new/frontend/src/main.ts Normal file
View File

@ -0,0 +1,9 @@
import { createApp } from 'vue'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import App from './App.vue'
const app = createApp(App)
app.use(ElementPlus)
app.mount('#app')

296
new/frontend/src/style.css Normal file
View File

@ -0,0 +1,296 @@
:root {
--text: #6b6375;
--text-h: #08060d;
--bg: #fff;
--border: #e5e4e7;
--code-bg: #f4f3ec;
--accent: #aa3bff;
--accent-bg: rgba(170, 59, 255, 0.1);
--accent-border: rgba(170, 59, 255, 0.5);
--social-bg: rgba(244, 243, 236, 0.5);
--shadow:
rgba(0, 0, 0, 0.1) 0 10px 15px -3px, rgba(0, 0, 0, 0.05) 0 4px 6px -2px;
--sans: system-ui, 'Segoe UI', Roboto, sans-serif;
--heading: system-ui, 'Segoe UI', Roboto, sans-serif;
--mono: ui-monospace, Consolas, monospace;
font: 18px/145% var(--sans);
letter-spacing: 0.18px;
color-scheme: light dark;
color: var(--text);
background: var(--bg);
font-synthesis: none;
text-rendering: optimizeLegibility;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
@media (max-width: 1024px) {
font-size: 16px;
}
}
@media (prefers-color-scheme: dark) {
:root {
--text: #9ca3af;
--text-h: #f3f4f6;
--bg: #16171d;
--border: #2e303a;
--code-bg: #1f2028;
--accent: #c084fc;
--accent-bg: rgba(192, 132, 252, 0.15);
--accent-border: rgba(192, 132, 252, 0.5);
--social-bg: rgba(47, 48, 58, 0.5);
--shadow:
rgba(0, 0, 0, 0.4) 0 10px 15px -3px, rgba(0, 0, 0, 0.25) 0 4px 6px -2px;
}
#social .button-icon {
filter: invert(1) brightness(2);
}
}
body {
margin: 0;
}
h1,
h2 {
font-family: var(--heading);
font-weight: 500;
color: var(--text-h);
}
h1 {
font-size: 56px;
letter-spacing: -1.68px;
margin: 32px 0;
@media (max-width: 1024px) {
font-size: 36px;
margin: 20px 0;
}
}
h2 {
font-size: 24px;
line-height: 118%;
letter-spacing: -0.24px;
margin: 0 0 8px;
@media (max-width: 1024px) {
font-size: 20px;
}
}
p {
margin: 0;
}
code,
.counter {
font-family: var(--mono);
display: inline-flex;
border-radius: 4px;
color: var(--text-h);
}
code {
font-size: 15px;
line-height: 135%;
padding: 4px 8px;
background: var(--code-bg);
}
.counter {
font-size: 16px;
padding: 5px 10px;
border-radius: 5px;
color: var(--accent);
background: var(--accent-bg);
border: 2px solid transparent;
transition: border-color 0.3s;
margin-bottom: 24px;
&:hover {
border-color: var(--accent-border);
}
&:focus-visible {
outline: 2px solid var(--accent);
outline-offset: 2px;
}
}
.hero {
position: relative;
.base,
.framework,
.vite {
inset-inline: 0;
margin: 0 auto;
}
.base {
width: 170px;
position: relative;
z-index: 0;
}
.framework,
.vite {
position: absolute;
}
.framework {
z-index: 1;
top: 34px;
height: 28px;
transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg)
scale(1.4);
}
.vite {
z-index: 0;
top: 107px;
height: 26px;
width: auto;
transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg)
scale(0.8);
}
}
#app {
width: 1126px;
max-width: 100%;
margin: 0 auto;
text-align: center;
border-inline: 1px solid var(--border);
min-height: 100svh;
display: flex;
flex-direction: column;
box-sizing: border-box;
}
#center {
display: flex;
flex-direction: column;
gap: 25px;
place-content: center;
place-items: center;
flex-grow: 1;
@media (max-width: 1024px) {
padding: 32px 20px 24px;
gap: 18px;
}
}
#next-steps {
display: flex;
border-top: 1px solid var(--border);
text-align: left;
& > div {
flex: 1 1 0;
padding: 32px;
@media (max-width: 1024px) {
padding: 24px 20px;
}
}
.icon {
margin-bottom: 16px;
width: 22px;
height: 22px;
}
@media (max-width: 1024px) {
flex-direction: column;
text-align: center;
}
}
#docs {
border-right: 1px solid var(--border);
@media (max-width: 1024px) {
border-right: none;
border-bottom: 1px solid var(--border);
}
}
#next-steps ul {
list-style: none;
padding: 0;
display: flex;
gap: 8px;
margin: 32px 0 0;
.logo {
height: 18px;
}
a {
color: var(--text-h);
font-size: 16px;
border-radius: 6px;
background: var(--social-bg);
display: flex;
padding: 6px 12px;
align-items: center;
gap: 8px;
text-decoration: none;
transition: box-shadow 0.3s;
&:hover {
box-shadow: var(--shadow);
}
.button-icon {
height: 18px;
width: 18px;
}
}
@media (max-width: 1024px) {
margin-top: 20px;
flex-wrap: wrap;
justify-content: center;
li {
flex: 1 1 calc(50% - 8px);
}
a {
width: 100%;
justify-content: center;
box-sizing: border-box;
}
}
}
#spacer {
height: 88px;
border-top: 1px solid var(--border);
@media (max-width: 1024px) {
height: 48px;
}
}
.ticks {
position: relative;
width: 100%;
&::before,
&::after {
content: '';
position: absolute;
top: -4.5px;
border: 5px solid transparent;
}
&::before {
left: 0;
border-left-color: var(--border);
}
&::after {
right: 0;
border-right-color: var(--border);
}
}

View File

@ -0,0 +1,14 @@
{
"extends": "@vue/tsconfig/tsconfig.dom.json",
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"types": ["vite/client"],
/* Linting */
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true
},
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"]
}

View File

@ -0,0 +1,7 @@
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}

View File

@ -0,0 +1,24 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "es2023",
"lib": ["ES2023"],
"module": "esnext",
"types": ["node"],
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
/* Linting */
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true
},
"include": ["vite.config.ts"]
}

View File

@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
// https://vite.dev/config/
export default defineConfig({
plugins: [vue()],
})

6
run_smoke.bat Normal file
View File

@ -0,0 +1,6 @@
@echo off
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
call venv\Scripts\activate.bat
set PYTHONPATH=new\app\api;%PYTHONPATH%
python -c "import _smoke_test_train; _smoke_test_train.test_load_train_df(); _smoke_test_train.test_get_model_pipeline_all_types(); _smoke_test_train.test_run_train_sync_linearregression_fast(); _smoke_test_train.test_run_train_sync_bad_csv(); _smoke_test_train.test_run_train_sync_bad_target(); print('OK')" > %TEMP%\smoke_log.txt 2>&1
type %TEMP%\smoke_log.txt

View File

@ -20,23 +20,28 @@ class PipelineContext:
"""流水线运行上下文(在 14 个 step 之间传递的内存字典)
字段命名约定:
- 路径字段统一 `_path` 后缀(如 water_mask_path
- 目录类字段无 `_path` 后缀(如 models_dir
- 路径字段名 = panel key 名 = step 形参名(全链路无翻译
- 训练/产物 CSV 用 `_path` 后缀(如 training_csv_path / water_mask_path
- 入参影像/CSV 沿用 panel 原名img_path / csv_path无 `_path` 后缀
- 目录类字段无 `_path` 后缀(如 models_dir / prediction_dir
- 元信息字段无后缀(如 user_config / status / log
"""
# ── 9 步主路径(按 step 输出顺序排列) ──
raw_img_path: Optional[str] = None # Step 1 入参:原始影像
# ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参 ──
img_path: Optional[str] = None # Step 1/2/3 入参:原始影像
water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入
glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入
deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入
raw_csv_path: Optional[str] = None # Step 4 入:原始 CSV
csv_path: Optional[str] = None # Step 4/5/6_5/6_75:原始/训练 CSV
processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入
training_spectra_path: Optional[str] = None # Step 5 出 → Step 6
training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75
boundary_path: Optional[str] = None # Step 5 入参:边界 SHPpanel step5 名)
indices_path: Optional[str] = None # Step 5.5 出
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/9 入
prediction_csv_path: Optional[str] = None # Step 8 出
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入
prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入
distribution_map_path: Optional[str] = None # Step 9 出
boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHPpanel step9 名)
formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV
# ── 目录类(命名不带 _path 以示区别) ──
models_dir: Optional[str] = None

View File

@ -4,10 +4,8 @@ PipelineRunner基于 StepSpec 声明式调度 14 个 step。
设计要点:
- StepSpec 声明 requiresctx 字段名列表)+ producesctx 字段名列表)
- 默认约定ctx 字段名去掉 `_path` 后缀 = step 方法形参名
ctx.water_mask_path → 形参 water_mask
ctx.raw_img_path → 形参 raw_img
- 可被 spec.parameter_map 覆盖
- 命名约定ctx 字段名 == panel key 名 == step 形参名(全链路无翻译)
- 保留 spec.parameter_map 字段骨架供极少数特例覆盖(默认空 dict
- 调度顺序:按 PIPELINE_STEPS 列表顺序requires 缺则 skip
- 软取消:在每个 step 前检查 ctx.is_cancelled()
- duck-typed pipelinerunner 只调 getattr(pipeline, method_name),不强依赖类层级
@ -48,101 +46,76 @@ class StepSpec:
PIPELINE_STEPS: List[StepSpec] = [
StepSpec(
step_id="step1", method_name="step1_generate_water_mask",
requires=["raw_img_path"], produces=["water_mask_path"],
# ctx.raw_img_path → 形参 img_path老 step1 形参名是 img_path不是 raw_img
parameter_map={"raw_img_path": "img_path"},
requires=["img_path"], produces=["water_mask_path"],
description="水域掩膜生成NDWI 或 SHP",
),
StepSpec(
step_id="step2", method_name="step2_find_glint_area",
requires=["raw_img_path", "water_mask_path"], produces=["glint_mask_path"],
# raw_img_path→img_pathwater_mask_path 不变
parameter_map={"raw_img_path": "img_path"},
requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
description="耀斑区域检测",
),
StepSpec(
step_id="step3", method_name="step3_remove_glint",
requires=["deglint_img_path"], produces=["deglint_img_path"],
# deglint_img_path→img_path老 step3 形参名是 img_path
# 注意glint_mask_path 不在 requires 中——step3 形参表无该参数,内部走 self.glint_mask_path 回退
parameter_map={"deglint_img_path": "img_path"},
requires=["img_path", "water_mask_path", "glint_mask_path"],
produces=["deglint_img_path"],
description="耀斑去除",
),
StepSpec(
step_id="step4", method_name="step4_process_csv",
requires=["raw_csv_path"], produces=["processed_csv_path"],
# raw_csv_path→csv_path老 step4 形参名是 csv_path
parameter_map={"raw_csv_path": "csv_path"},
requires=["csv_path"], produces=["processed_csv_path"],
description="CSV 异常值清洗",
),
StepSpec(
step_id="step5", method_name="step5_extract_training_spectra",
requires=["deglint_img_path", "processed_csv_path"], produces=["training_spectra_path"],
# processed_csv_path→csv_path老 step5 形参名是 csv_pathdeglint_img_path 不变
parameter_map={"processed_csv_path": "csv_path"},
requires=["deglint_img_path", "csv_path", "boundary_path", "glint_mask_path"],
produces=["training_csv_path"],
description="实测样本点光谱提取",
),
StepSpec(
step_id="step5_5", method_name="step5_5_calculate_water_quality_indices",
requires=["training_spectra_path"], produces=["indices_path"],
# 老 step5.5 形参是 training_spectra_pathctx 字段同名,无需映射
parameter_map={},
requires=["training_csv_path"], produces=["indices_path"],
description="水质光谱指数计算optional",
),
StepSpec(
step_id="step6", method_name="step6_train_models",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→training_csv_path老 step6 形参名是 training_csv_path
parameter_map={"training_spectra_path": "training_csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="ML 建模GridSearchCV / AutoML",
),
StepSpec(
step_id="step6_5", method_name="step6_5_non_empirical_modeling",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→csv_path老 step6.5 形参名是 csv_path
parameter_map={"training_spectra_path": "csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="非经验统计回归",
),
StepSpec(
step_id="step6_75", method_name="step6_75_custom_regression",
requires=["training_spectra_path"], produces=["models_dir"],
# training_spectra_path→csv_path老 step6.75 形参名是 csv_path
parameter_map={"training_spectra_path": "csv_path"},
requires=["training_csv_path"], produces=["models_dir"],
description="自定义回归分析",
),
StepSpec(
step_id="step7", method_name="step7_generate_sampling_points",
requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"],
# 老 step7 形参是 deglint_img_path / water_mask_pathctx 字段同名
parameter_map={},
description="整景密集采样点生成 + 光谱提取",
),
StepSpec(
step_id="step8", method_name="step8_predict_water_quality",
requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"],
parameter_map={},
description="ML 模型预测(采样点)",
),
StepSpec(
step_id="step8_5", method_name="step8_5_predict_with_non_empirical_models",
requires=["sampling_csv_path"], produces=["prediction_dir"],
parameter_map={},
requires=["sampling_csv_path", "models_dir"], produces=["prediction_dir"],
description="非经验模型预测",
),
StepSpec(
step_id="step8_75", method_name="step8_75_predict_with_custom_regression",
requires=["sampling_csv_path"], produces=["prediction_dir"],
parameter_map={},
requires=["sampling_csv_path", "models_dir", "formula_csv_path"],
produces=["prediction_dir"],
description="自定义回归预测",
),
StepSpec(
step_id="step9", method_name="step9_generate_distribution_map",
requires=["prediction_csv_path"],
requires=["prediction_csv_path", "boundary_shp_path"],
produces=["distribution_map_path"],
# 老 step9 形参是 prediction_csv_path / boundary_shp_pathctx 字段同名
# 注意sampling_csv_path / water_mask_path 不在 requires 中——step9 形参表无该参数,
# 内部走 self.sampling_csv_path / self.water_mask_path 回退
parameter_map={},
description="克里金插值成图",
),
]
@ -157,7 +130,7 @@ class PipelineRunner:
用法:
runner = PipelineRunner(pipeline_instance)
ctx = PipelineContext(raw_img_path=..., ...)
ctx = PipelineContext(img_path=..., ...)
result_ctx = runner.run(ctx)
"""

View File

@ -0,0 +1,544 @@
# -*- coding: utf-8 -*-
"""
Optuna + 智能子采样 AutoML 训练器(路线 B 防爆引擎)。
为什么需要这个:
- 老路径11 预处理 × 4 模型 × 3 划分 = 132 组 GridSearchCV
对中小数据集 10 分钟+,对大数据集 5w+ 行 直接 OOM
- AutoML 路径1 预处理 × N 模型Optuna 调超参),用智能子采样避开 OOM
再用最优超参在**全量数据**上 refit最终保存单一模型
设计要点:
- 入口 train_with_automl(csv, feature_start_column, model_names, ...)
- AutoMLResult dataclass 返回(每个目标列一份)
- smart_subsampleN > max_samples 时随机下采样
- 失败兜底optuna 未装 / 全 trial 失败 → fallback 到 WaterQualityModelingBatch
- 文件命名规范:{target}_{preprocess}_{model}_AUTOML.joblib
- save_data["metadata"]["automl"] = True 标记
调用:
from src.core.prediction.automl_trainer import train_with_automl
results = train_with_automl(
training_csv_path=".../training_spectra.csv",
feature_start_column="374.285004",
model_names=["RF", "SVR", "Ridge"],
n_trials=20,
timeout_sec=300,
)
"""
from __future__ import annotations
import json
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
# ============================================================
# 常量
# ============================================================
# AutoML 寻优阶段允许的最大样本数(避免 OOM
# 5000 样本对 RF/SVR/Ridge 的 Optuna 寻优足够给出稳定 CV
DEFAULT_MAX_SAMPLES = 5000
# 单次 Optuna trial 的默认超时(秒)
DEFAULT_TIMEOUT = 300.0
# 默认 trial 数
DEFAULT_N_TRIALS = 20
# AutoML 输出目录名后缀
AUTOML_DIR_SUFFIX = "_AutoML"
# ============================================================
# 数据类
# ============================================================
@dataclass
class AutoMLResult:
"""单个目标列的 AutoML 训练结果"""
success: bool = False
model_path: Optional[str] = None
cv_score: float = -float("inf")
best_params: Optional[Dict[str, Any]] = None
target_column: str = ""
preprocessing: str = ""
model_name: str = ""
n_trials_done: int = 0
n_samples_used: int = 0
fallback_used: bool = False
elapsed_sec: float = 0.0
error: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
# ============================================================
# 智能子采样
# ============================================================
def smart_subsample(
X: np.ndarray,
y: np.ndarray,
max_samples: int = DEFAULT_MAX_SAMPLES,
random_state: int = 42,
) -> Tuple[np.ndarray, np.ndarray, bool]:
"""当 N > max_samples 时随机下采样;否则原样返回。
Returns:
(X_sub, y_sub, was_subsampled)
"""
n = X.shape[0]
if n <= max_samples:
return X, y, False
rng = np.random.default_rng(random_state)
idx = rng.choice(n, size=max_samples, replace=False)
return X[idx], y[idx], True
# ============================================================
# 模型工厂
# ============================================================
def _build_model(model_name: str, random_state: int = 42):
"""根据英文模型键名构造 sklearn-compatible 模型实例factory"""
from sklearn.ensemble import (
AdaBoostRegressor, ExtraTreesRegressor, GradientBoostingRegressor,
RandomForestRegressor,
)
from sklearn.linear_model import (
ElasticNet, Lasso, LinearRegression, Ridge,
)
from sklearn.neighbors import KNeighborsRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
factory = {
"RF": lambda **kw: RandomForestRegressor(random_state=random_state, n_jobs=1, **kw),
"ET": lambda **kw: ExtraTreesRegressor(random_state=random_state, n_jobs=1, **kw),
"GradientBoosting": lambda **kw: GradientBoostingRegressor(random_state=random_state, **kw),
"AdaBoost": lambda **kw: AdaBoostRegressor(random_state=random_state, **kw),
"Ridge": lambda **kw: Ridge(**kw),
"Lasso": lambda **kw: Lasso(max_iter=5000, **kw),
"ElasticNet": lambda **kw: ElasticNet(max_iter=5000, **kw),
"LinearRegression": lambda **kw: LinearRegression(**kw),
"SVR": lambda **kw: SVR(**kw),
"KNN": lambda **kw: KNeighborsRegressor(n_jobs=1, **kw),
"MLP": lambda **kw: MLPRegressor(max_iter=500, random_state=random_state, **kw),
"DecisionTree": lambda **kw: DecisionTreeRegressor(random_state=random_state, **kw),
"PLS": None, # sklearn.cross_decomposition.PLSRegression 暂未集成
}
builder = factory.get(model_name)
if builder is None:
return None
return builder
# ============================================================
# Optuna 超参 search space
# ============================================================
def _get_search_space(model_name: str, trial) -> Dict[str, Any]:
"""按模型名返回 Optuna 超参 search space。"""
sp: Dict[str, Any] = {}
if model_name == "RF":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
sp["min_samples_leaf"] = trial.suggest_int("min_samples_leaf", 1, 5)
elif model_name == "ET":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
elif model_name == "GradientBoosting":
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
sp["max_depth"] = trial.suggest_int("max_depth", 3, 8)
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
elif model_name == "SVR":
sp["C"] = trial.suggest_float("C", 0.1, 100.0, log=True)
sp["epsilon"] = trial.suggest_float("epsilon", 0.001, 1.0, log=True)
sp["kernel"] = trial.suggest_categorical("kernel", ["rbf", "linear"])
elif model_name == "KNN":
sp["n_neighbors"] = trial.suggest_int("n_neighbors", 3, 20)
sp["weights"] = trial.suggest_categorical("weights", ["uniform", "distance"])
elif model_name in ("Ridge", "Lasso", "ElasticNet"):
sp["alpha"] = trial.suggest_float("alpha", 0.01, 100.0, log=True)
if model_name == "ElasticNet":
sp["l1_ratio"] = trial.suggest_float("l1_ratio", 0.0, 1.0)
elif model_name == "MLP":
sp["hidden_layer_sizes"] = trial.suggest_categorical(
"hidden_layer_sizes", [(50,), (100,), (50, 50), (100, 50)]
)
sp["alpha"] = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
sp["learning_rate_init"] = trial.suggest_float("learning_rate_init", 1e-4, 1e-2, log=True)
elif model_name == "DecisionTree":
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
elif model_name == "AdaBoost":
sp["n_estimators"] = trial.suggest_int("n_estimators", 30, 200, step=30)
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 1.0, log=True)
else:
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 200, step=50)
return sp
def _make_objective(model_name: str, X: np.ndarray, y: np.ndarray,
cv_folds: int, random_state: int):
"""构造 Optuna objective5 折 CV R²"""
from sklearn.model_selection import KFold, cross_val_score
def objective(trial):
params = _get_search_space(model_name, trial)
try:
builder = _build_model(model_name, random_state=random_state)
if builder is None:
return -1.0
model = builder(**params)
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
scores = cross_val_score(model, X, y, cv=kf, scoring="r2", n_jobs=1)
return float(np.mean(scores))
except Exception:
return -1.0
return objective
def _refit_full(model_name: str, best_params: Dict[str, Any],
X: np.ndarray, y: np.ndarray, random_state: int):
"""用 best params 在**全量数据**上 refit。"""
builder = _build_model(model_name, random_state=random_state)
if builder is None:
return None
model = builder(**best_params)
model.fit(X, y)
return model
# ============================================================
# 失败兜底(回退到老 GridSearchCV 路径)
# ============================================================
def _fallback_train(
training_csv_path: str,
feature_start_column,
preprocessing: str,
model_name: str,
split_method: str,
cv_folds: int,
output_dir: Path,
target_column: str,
) -> AutoMLResult:
"""AutoML 失败时调老 WaterQualityModelingBatch。
返回的 AutoMLResult.fallback_used=True。
"""
try:
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
except ImportError as e:
return AutoMLResult(
success=False, error=f"fallback 导入失败: {e!r}", fallback_used=True,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
)
try:
out_dir = output_dir / preprocessing
out_dir.mkdir(parents=True, exist_ok=True)
modeler = WaterQualityModelingBatch(str(out_dir))
modeler.train_models_batch(
csv_path=training_csv_path,
feature_start_column=feature_start_column,
preprocessing_methods=[preprocessing],
model_names=[model_name],
split_methods=[split_method],
cv_folds=cv_folds,
)
# 找产出
candidates = list(out_dir.rglob(f"{target_column}_{preprocessing}_{model_name}.joblib"))
model_path = str(candidates[0]) if candidates else None
return AutoMLResult(
success=model_path is not None,
model_path=model_path,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
fallback_used=True,
metadata={"source": "WaterQualityModelingBatch"},
)
except Exception as e:
return AutoMLResult(
success=False, error=f"fallback 失败: {e!r}", fallback_used=True,
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
)
# ============================================================
# 主入口
# ============================================================
def train_with_automl(
training_csv_path: str,
feature_start_column,
preprocessing_methods: Optional[List[str]] = None,
model_names: Optional[List[str]] = None,
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
output_dir: Optional[str] = None,
n_trials: int = DEFAULT_N_TRIALS,
timeout_sec: float = DEFAULT_TIMEOUT,
max_samples: int = DEFAULT_MAX_SAMPLES,
random_state: int = 42,
callback: Optional[Callable[[str, str, str], None]] = None,
) -> List[AutoMLResult]:
"""用 Optuna + 子采样跑 AutoML。失败时自动回退到 GridSearchCV。
Args:
training_csv_path: 训练用 CSVStep 5 产物 training_spectra.csv
feature_start_column: 特征起始列名或索引(之前所有列视为目标 y
preprocessing_methods: 候选预处理列表(**仅用第 1 个**,避免笛卡尔爆炸)
model_names: 候选模型列表(每个都会跑一遍 Optuna
split_methods: 候选数据划分列表AutoML 仅用第 1 个)
cv_folds: 交叉验证折数
output_dir: 输出目录(默认 <models_dir>_AutoML
n_trials: 单模型 Optuna trial 数
timeout_sec: 单模型超时(秒),到时强制停止
max_samples: 寻优阶段允许的最大样本数
callback: 状态回调 callback(step_name, status, message)
Returns:
List[AutoMLResult],每个目标列一份结果
"""
def notify(status: str, msg: str = "") -> None:
if callback:
callback("步骤6_AutoML", status, msg)
# ---- 1) 参数默认值 ----
if preprocessing_methods is None:
preprocessing_methods = ["MMS"]
if model_names is None:
model_names = ["RF", "SVR", "Ridge"]
if split_methods is None:
split_methods = ["spxy"]
# 决策:仅用第一个预处理 + 第一个划分,避免笛卡尔爆炸
preproc = preprocessing_methods[0]
split_method = split_methods[0]
if output_dir is None:
output_dir = "./7_Supervised_Model_Training_AutoML"
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
preproc_dir = out_dir / preproc
preproc_dir.mkdir(parents=True, exist_ok=True)
# ---- 2) 加载数据 ----
notify("start", f"AutoML 训练开始 (n_trials={n_trials}, timeout={timeout_sec}s, max_samples={max_samples})")
if not Path(training_csv_path).exists():
return [AutoMLResult(success=False, error=f"训练 CSV 不存在: {training_csv_path}")]
df = pd.read_csv(training_csv_path)
# 提取目标列feature_start_column 之前所有数值列)
if isinstance(feature_start_column, int):
y_cols = [c for c in df.columns[:feature_start_column]
if pd.api.types.is_numeric_dtype(df[c])]
else:
try:
idx = list(df.columns).index(feature_start_column)
y_cols = [c for c in df.columns[:idx]
if pd.api.types.is_numeric_dtype(df[c])]
except ValueError:
y_cols = []
if not y_cols:
notify("error", "AutoML: 未识别出目标列feature_start_column 之前的所有数值列)")
return [AutoMLResult(success=False, error="未识别出目标列")]
feat_cols = [c for c in df.columns if c not in y_cols]
X_all = df[feat_cols].values.astype(np.float64)
# ---- 3) 预处理(仅第一项) ----
if preproc != "None":
try:
from src.preprocessing.spectral_Preprocessing import Preprocessing
processed = Preprocessing(preproc, df[feat_cols])
if isinstance(processed, pd.DataFrame):
X_all = processed.values.astype(np.float64)
else:
X_all = np.asarray(processed, dtype=np.float64)
except Exception as e:
notify("warning", f"预处理 {preproc} 失败: {e!r},改用 None")
preproc = "None"
# ---- 4) 检查 Optuna 是否可用 ----
try:
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
optuna_available = True
except ImportError:
optuna_available = False
notify("warning", "optuna 未安装,全目标列回退到 GridSearchCVpip install \"optuna>=3.6\"")
# ---- 5) 逐 target 跑 ----
results: List[AutoMLResult] = []
total = len(y_cols)
per_model_timeout = max(10.0, timeout_sec / max(1, len(model_names)))
for ti, tgt in enumerate(y_cols, 1):
t0 = time.time()
yv = df[tgt].values.astype(np.float64)
mask = ~np.isnan(yv)
X_t = X_all[mask]
y_t = yv[mask]
if X_t.shape[0] < cv_folds * 2:
notify("warning", f"目标 {tgt}: 有效样本 {X_t.shape[0]} 不足,跳过")
results.append(AutoMLResult(
success=False, target_column=tgt, error=f"样本不足({X_t.shape[0]})",
preprocessing=preproc,
))
continue
X_sub, y_sub, was_sub = smart_subsample(X_t, y_t, max_samples=max_samples, random_state=random_state)
if was_sub:
notify("info", f"目标 {tgt}: {X_t.shape[0]} 样本 → 子采样 {X_sub.shape[0]}(寻优用)")
best_overall = AutoMLResult(success=False, target_column=tgt, preprocessing=preproc)
if not optuna_available:
# 全目标列一次性 fallback
best_overall = _fallback_train(
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
cv_folds, out_dir, tgt,
)
else:
for model_name in model_names:
try:
builder = _build_model(model_name, random_state=random_state)
if builder is None:
notify("warning", f"模型 {model_name} 暂不支持 AutoML 寻优")
continue
study = optuna.create_study(
direction="maximize",
sampler=optuna.samplers.TPESampler(seed=random_state),
)
study.optimize(
_make_objective(model_name, X_sub, y_sub, cv_folds, random_state),
n_trials=n_trials,
timeout=per_model_timeout,
show_progress_bar=False,
)
if study.best_value is None or study.best_value <= -1.0:
notify("warning", f"{tgt}/{model_name}: 全部 trial 失败CV 全部 <= -1")
continue
# refit on FULL
final_model = _refit_full(model_name, study.best_params, X_t, y_t, random_state)
if final_model is None:
continue
# 保存
import joblib
fname = f"{tgt}_{preproc}_{model_name}_AUTOML.joblib"
fpath = preproc_dir / fname
joblib.dump({
"model": final_model,
"target_column_name": tgt,
"preprocess_method": preproc,
"model_name": model_name,
"metadata": {
"automl": True,
"best_params": study.best_params,
"cv_score": float(study.best_value),
"n_trials_done": len(study.trials),
"n_samples_used_full": int(X_t.shape[0]),
"n_samples_used_for_search": int(X_sub.shape[0]),
"was_subsampled": was_sub,
"split_method": split_method,
},
}, fpath)
cand = AutoMLResult(
success=True,
model_path=str(fpath),
cv_score=float(study.best_value),
best_params=study.best_params,
target_column=tgt,
preprocessing=preproc,
model_name=model_name,
n_trials_done=len(study.trials),
n_samples_used=int(X_sub.shape[0]),
metadata={"refit_on_full": True, "n_samples_full": int(X_t.shape[0])},
)
if cand.cv_score > best_overall.cv_score:
best_overall = cand
except Exception as e:
notify("warning", f"目标 {tgt} / 模型 {model_name} 失败: {e!r}")
continue
if not best_overall.success:
notify("warning", f"目标 {tgt} 全部 Optuna trial 失败,回退 GridSearchCV")
best_overall = _fallback_train(
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
cv_folds, out_dir, tgt,
)
best_overall.elapsed_sec = time.time() - t0
results.append(best_overall)
notify("info", f"AutoML 目标 {tgt} 完成 ({ti}/{total}) cv={best_overall.cv_score:.4f}")
# ---- 6) 汇总 json ----
summary_path = out_dir / "automl_summary.json"
try:
with open(summary_path, "w", encoding="utf-8") as f:
json.dump([asdict(r) for r in results], f, ensure_ascii=False, indent=2, default=str)
except Exception as e:
notify("warning", f"写 automl_summary.json 失败: {e!r}")
success_n = sum(1 for r in results if r.success)
fallback_n = sum(1 for r in results if r.fallback_used)
notify("completed", f"AutoML 训练完成 {success_n}/{len(results)} 成功({fallback_n} 走 fallback汇总 {summary_path}")
return results
# ============================================================
# CLI 自测
# ============================================================
if __name__ == "__main__":
import argparse
p = argparse.ArgumentParser(description="AutoML 训练器 CLI 自测")
p.add_argument("--csv", required=True, help="训练用 CSVfeature_start_column 之前的列为目标 y")
p.add_argument("--feature-start", default="0", help="特征起始列名或索引(默认 0")
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
p.add_argument("--out", default="./7_Supervised_Model_Training_AutoML")
args = p.parse_args()
# 智能推断 feature_start_column 类型
fsc: Any = args.feature_start
try:
fsc = int(fsc)
except ValueError:
pass
res = train_with_automl(
training_csv_path=args.csv,
feature_start_column=fsc,
n_trials=args.n_trials,
timeout_sec=args.timeout,
max_samples=args.max_samples,
output_dir=args.out,
)
print(f"\n训练完成 {len(res)} 个目标")
for r in res:
marker = "" if r.success else ""
fb = " [fallback]" if r.fallback_used else ""
print(f" {marker} {r.target_column}: cv={r.cv_score:.4f} path={r.model_path}{fb}")

View File

@ -126,7 +126,7 @@ class DataPreparationStep:
@staticmethod
def calculate_water_quality_indices(
training_spectra_path: Optional[str] = None,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
@ -153,8 +153,8 @@ class DataPreparationStep:
notify("skipped", "跳过水质指数计算")
return None
if training_spectra_path is None:
raise ValueError("必须提供 training_spectra_path 参数")
if training_csv_path is None:
raise ValueError("必须提供 training_csv_path 参数")
if formula_csv_file is None:
raise ValueError("必须提供 formula_csv_file 参数")
@ -170,7 +170,7 @@ class DataPreparationStep:
from src.utils.band_math import BandMathCalculator
calculator = BandMathCalculator(training_spectra_path)
calculator = BandMathCalculator(training_csv_path)
result_df = calculator.process_formulas_from_csv(
formula_csv_file=formula_csv_file,
formula_names=formula_names,

View File

@ -173,7 +173,7 @@ class WaterQualityInversionPipeline:
self.interpolated_img_path = None # 存储插值后的影像路径
self.deglint_img_path = None
self.processed_csv_path = None
self.training_spectra_path = None
self.training_csv_path = None
self.indices_path = None
self.custom_regression_path = None
@ -511,7 +511,7 @@ class WaterQualityInversionPipeline:
left_shoulder_wave: Optional[float] = None,
valley_wave: Optional[float] = None,
right_shoulder_wave: Optional[float] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
water_mask_path: Optional[Union[str, np.ndarray]] = None,
interpolate_zeros: bool = False,
interpolation_method: str = 'nearest',
enabled: bool = True,
@ -546,7 +546,7 @@ class WaterQualityInversionPipeline:
left_shoulder_wave=left_shoulder_wave,
valley_wave=valley_wave,
right_shoulder_wave=right_shoulder_wave,
water_mask=water_mask,
water_mask=water_mask_path,
interpolate_zeros=interpolate_zeros,
interpolation_method=interpolation_method,
enabled=enabled,
@ -655,13 +655,13 @@ class WaterQualityInversionPipeline:
water_mask_path=self.water_mask_path,
output_dir=str(self.training_spectra_dir),
)
self.training_spectra_path = result
self.training_csv_path = result
self._record_step_time("步骤5: 提取训练样本点光谱", 0, 0)
self._notify("completed", f"训练光谱数据已保存: {result}")
return result
def step5_5_calculate_water_quality_indices(self,
training_spectra_path: Optional[str] = None,
training_csv_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
@ -673,7 +673,7 @@ class WaterQualityInversionPipeline:
使用band_math.py中的方法实现支持从公式CSV文件中批量计算指定公式
Args:
training_spectra_path: 训练光谱数据CSV路径如果为None使用步骤5的结果
training_csv_path: 训练光谱数据CSV路径如果为None使用步骤5的结果
formula_csv_file: 公式CSV文件路径包含公式名称和具体公式
formula_names: 要计算的公式名称列表如果为None则计算所有公式
output_file: 输出文件完整路径支持绝对路径如果为None则使用默认路径
@ -682,16 +682,16 @@ class WaterQualityInversionPipeline:
包含计算结果的新CSV文件路径
"""
# 参数解析(保留原逻辑)
if training_spectra_path is not None:
csv_path = training_spectra_path
elif self.training_spectra_path is not None:
csv_path = self.training_spectra_path
if training_csv_path is not None:
csv_path = training_csv_path
elif self.training_csv_path is not None:
csv_path = self.training_csv_path
else:
csv_path = None
self._notify("started", "步骤5.5: 计算水质光谱指数")
result = DataPreparationStep.calculate_water_quality_indices(
training_spectra_path=csv_path,
training_csv_path=csv_path,
formula_csv_file=formula_csv_file,
formula_names=formula_names,
output_file=output_file,
@ -727,8 +727,8 @@ class WaterQualityInversionPipeline:
# 参数解析(保留原逻辑)
if training_csv_path is not None:
final_csv_path = training_csv_path
elif self.training_spectra_path is not None:
final_csv_path = self.training_spectra_path
elif self.training_csv_path is not None:
final_csv_path = self.training_csv_path
else:
final_csv_path = None
@ -911,7 +911,7 @@ class WaterQualityInversionPipeline:
print("="*80)
if training_csv_path is None:
training_csv_path = self.training_spectra_path
training_csv_path = self.training_csv_path
if training_csv_path is None:
raise ValueError("请提供训练数据CSV路径或先执行步骤5")
@ -1033,7 +1033,7 @@ class WaterQualityInversionPipeline:
print("="*80)
if csv_path is None:
csv_path = self.training_spectra_path
csv_path = self.training_csv_path
if csv_path is None:
raise ValueError("请提供CSV文件路径或先执行步骤5")
@ -1506,7 +1506,7 @@ class WaterQualityInversionPipeline:
if 'step5' in config:
self._notify("步骤5: 光谱提取", "start")
self.step5_extract_training_spectra(**config['step5'])
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_spectra_path})")
self._notify("步骤5: 光谱提取", "completed", f"(输出: {self.training_csv_path})")
else:
self._notify("步骤5: 光谱提取", "skipped", "未配置")
@ -1615,7 +1615,7 @@ class WaterQualityInversionPipeline:
# 生成散点图
if 'visualization' in config and config['visualization'].get('generate_scatter', True):
if self.training_spectra_path and self.models_dir.exists():
if self.training_csv_path and self.models_dir.exists():
try:
self._notify("可视化", "info", "生成模型评估散点图...")
scatter_config = config['visualization'].get('scatter_config', {})
@ -1653,7 +1653,7 @@ class WaterQualityInversionPipeline:
# 生成光谱曲线图
if 'visualization' in config and config['visualization'].get('generate_spectrum', True):
if self.training_spectra_path:
if self.training_csv_path:
try:
self._notify("可视化", "info", "生成光谱曲线对比图...")
spectrum_paths = self.generate_spectrum_comparison_plots(
@ -1701,7 +1701,7 @@ class WaterQualityInversionPipeline:
pipeline_info['step2'] = {'status': 'completed', 'output_file': str(self.glint_mask_path) if self.glint_mask_path else 'N/A'}
pipeline_info['step3'] = {'status': 'completed', 'output_file': str(self.deglint_img_path) if self.deglint_img_path else 'N/A'}
pipeline_info['step4'] = {'status': 'completed', 'output_file': str(self.processed_csv_path) if self.processed_csv_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_spectra_path) if self.training_spectra_path else 'N/A'}
pipeline_info['step5'] = {'status': 'completed', 'output_file': str(self.training_csv_path) if self.training_csv_path else 'N/A'}
pipeline_info['step5_5'] = {'status': 'completed', 'output_file': str(self.indices_path) if self.indices_path else 'N/A'}
pipeline_info['step6'] = {'status': 'completed', 'output_file': str(self.models_dir)}
pipeline_info['step6_75'] = {'status': 'completed', 'output_file': str(self.custom_regression_path) if self.custom_regression_path else 'N/A'}
@ -1784,8 +1784,8 @@ class WaterQualityInversionPipeline:
# 参数解析(保留原逻辑)
if csv_path is not None:
final_csv_path = csv_path
elif self.training_spectra_path is not None:
final_csv_path = self.training_spectra_path
elif self.training_csv_path is not None:
final_csv_path = self.training_csv_path
else:
final_csv_path = None
@ -2109,7 +2109,7 @@ def main():
'interpolation_method': 'bilinear', # 插值方法: 'nearest'(邻近), 'bilinear'(双线性),
# 'spline'(样条), 'kriging'(克里金)
# 水域掩膜参数(可选):
'water_mask':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜也可以提供
'water_mask_path':r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp", # None表示自动使用步骤1生成的掩膜也可以提供
# # - numpy数组
# # - 栅格文件路径(.dat/.tif)
# # - shapefile路径(.shp)

View File

@ -0,0 +1,430 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
图表与交互弹窗模块
包含 ChartViewerDialog、ChartBrowserDialog 和 InteractiveViewerDialog 类。
"""
import numpy as np
from PyQt5.QtWidgets import (
QDialog, QVBoxLayout, QHBoxLayout, QPushButton,
QSizePolicy, QFileDialog, QMessageBox, QGroupBox,
QListWidget, QLabel, QComboBox, QCheckBox,
)
from PyQt5.QtCore import Qt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
class ChartViewerDialog(QDialog):
"""图表查看器对话框"""
def __init__(self, title="图表查看器", parent=None):
super().__init__(parent)
self.setWindowTitle(title)
self.resize(1000, 700)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
self.figure = Figure(figsize=(10, 7))
self.canvas = FigureCanvas(self.figure)
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.toolbar = NavigationToolbar(self.canvas, self)
layout.addWidget(self.toolbar)
layout.addWidget(self.canvas)
btn_layout = QHBoxLayout()
self.save_btn = QPushButton("保存图表")
self.save_btn.clicked.connect(self.save_chart)
btn_layout.addWidget(self.save_btn)
btn_layout.addStretch()
self.close_btn = QPushButton("关闭")
self.close_btn.clicked.connect(self.close)
btn_layout.addWidget(self.close_btn)
layout.addLayout(btn_layout)
self.setLayout(layout)
def display_image(self, image_path):
"""显示图片"""
self.figure.clear()
ax = self.figure.add_subplot(111)
try:
import matplotlib.image as mpimg
img = mpimg.imread(image_path)
ax.imshow(img)
ax.axis('off')
self.figure.tight_layout()
self.canvas.draw()
self.current_image_path = image_path
except Exception as e:
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
def display_custom_plot(self, plot_func):
"""显示自定义绘图函数"""
self.figure.clear()
try:
plot_func(self.figure)
self.canvas.draw()
except Exception as e:
ax = self.figure.add_subplot(111)
ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
def save_chart(self):
"""保存图表"""
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图表", "",
"PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
)
if file_path:
try:
self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class ChartBrowserDialog(QDialog):
"""图表浏览器对话框"""
def __init__(self, chart_files, parent=None):
super().__init__(parent)
self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True)
self.current_index = 0
self.setWindowTitle("图表浏览器")
self.resize(1200, 800)
self.init_ui()
self.show_chart(0)
def init_ui(self):
layout = QVBoxLayout()
list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
list_layout = QHBoxLayout()
self.chart_list = QListWidget()
self.chart_list.setMaximumHeight(150)
for chart_file in self.chart_files:
self.chart_list.addItem(chart_file.name)
self.chart_list.currentRowChanged.connect(self.show_chart)
list_layout.addWidget(self.chart_list)
list_group.setLayout(list_layout)
layout.addWidget(list_group)
self.figure = Figure(figsize=(12, 8))
self.canvas = FigureCanvas(self.figure)
self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.toolbar = NavigationToolbar(self.canvas, self)
layout.addWidget(self.toolbar)
layout.addWidget(self.canvas, 1)
btn_layout = QHBoxLayout()
self.prev_btn = QPushButton("◀ 上一个")
self.prev_btn.clicked.connect(self.prev_chart)
btn_layout.addWidget(self.prev_btn)
self.next_btn = QPushButton("下一个 >")
self.next_btn.clicked.connect(self.next_chart)
btn_layout.addWidget(self.next_btn)
btn_layout.addStretch()
self.save_btn = QPushButton("💾 保存当前图表")
self.save_btn.clicked.connect(self.save_current_chart)
btn_layout.addWidget(self.save_btn)
self.close_btn = QPushButton("关闭")
self.close_btn.clicked.connect(self.close)
btn_layout.addWidget(self.close_btn)
layout.addLayout(btn_layout)
self.setLayout(layout)
def show_chart(self, index):
"""显示指定索引的图表"""
if 0 <= index < len(self.chart_files):
self.current_index = index
self.chart_list.setCurrentRow(index)
chart_file = self.chart_files[index]
self.figure.clear()
ax = self.figure.add_subplot(111)
try:
import matplotlib.image as mpimg
img = mpimg.imread(str(chart_file))
ax.imshow(img)
ax.axis('off')
ax.set_title(chart_file.name, fontsize=12, pad=10)
self.figure.tight_layout()
self.canvas.draw()
except Exception as e:
ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
ha='center', va='center', transform=ax.transAxes)
self.canvas.draw()
self.prev_btn.setEnabled(index > 0)
self.next_btn.setEnabled(index < len(self.chart_files) - 1)
def prev_chart(self):
"""上一个图表"""
if self.current_index > 0:
self.show_chart(self.current_index - 1)
def next_chart(self):
"""下一个图表"""
if self.current_index < len(self.chart_files) - 1:
self.show_chart(self.current_index + 1)
def save_current_chart(self):
"""保存当前图表"""
if 0 <= self.current_index < len(self.chart_files):
current_file = self.chart_files[self.current_index]
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图表", current_file.name,
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
)
if file_path:
try:
import shutil
shutil.copy(str(current_file), file_path)
QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class InteractiveViewerDialog(QDialog):
"""交互式影像预览对话框:显示影像、参考点散点图、点击查询坐标/值"""
def __init__(self, parent, img_path, ref_csv=None):
super().__init__(parent)
self.img_path = img_path
self.ref_csv = ref_csv
self.geotransform = None
self.fig = None
self.canvas = None
self.ax = None
self.status_label = None
self.init_ui()
def init_ui(self):
self.setWindowTitle("👁️ 交互式影像预览")
self.setMinimumSize(900, 700)
layout = QVBoxLayout()
toolbar = QHBoxLayout()
self.band_combo = QComboBox()
self.band_combo.currentIndexChanged.connect(self.on_band_changed)
toolbar.addWidget(QLabel("显示波段:"))
toolbar.addWidget(self.band_combo)
self.gray_check = QCheckBox("灰度显示")
self.gray_check.stateChanged.connect(self.on_band_changed)
toolbar.addWidget(self.gray_check)
toolbar.addStretch()
layout.addLayout(toolbar)
try:
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib
matplotlib.use('Qt5Agg')
self.fig = Figure(figsize=(10, 8))
self.canvas = FigureCanvas(self.fig)
self.ax = self.fig.add_subplot(111)
self.fig.tight_layout()
layout.addWidget(self.canvas)
self.load_and_display()
except ImportError as e:
layout.addWidget(QLabel(f"Matplotlib 未安装: {e}"))
self.status_label = QLabel("点击影像查看像素坐标和经纬度")
self.status_label.setStyleSheet("background:#f0f0f0;padding:4px;font-size:12px;")
self.status_label.setWordWrap(True)
layout.addWidget(self.status_label)
close_btn = QPushButton("关闭")
close_btn.clicked.connect(self.close)
layout.addWidget(close_btn)
self.setLayout(layout)
def load_and_display(self):
"""加载影像并显示"""
from osgeo import gdal
dataset = gdal.Open(self.img_path)
if dataset is None:
self.status_label.setText(f"无法打开影像: {self.img_path}")
return
self.geotransform = dataset.GetGeoTransform()
self.projection = dataset.GetProjection()
n_bands = dataset.RasterCount
self.height = dataset.RasterYSize
self.width = dataset.RasterXSize
self.band_combo.clear()
if n_bands >= 3:
for i in range(1, n_bands + 1):
self.band_combo.addItem(f"RGB (B{i-0}, G{i-1}, R{i-2})" if i >= 3 else f"波段 {i}", i)
self.band_combo.addItem(f"单波段 (B1)", 0)
else:
for i in range(1, n_bands + 1):
self.band_combo.addItem(f"波段 {i}", i - 1)
self.band_combo.setCurrentIndex(0)
self.dataset = dataset
self.display_band(0, is_gray=False)
self.load_ref_points()
def display_band(self, band_idx, is_gray=False):
"""显示指定波段组合"""
from osgeo import gdal
import numpy as np
dataset = self.dataset
self.ax.clear()
if is_gray or (self.band_combo.currentData() == 0 and dataset.RasterCount == 1):
band = dataset.GetRasterBand(1 if band_idx == 0 else band_idx + 1)
data = band.ReadAsArray()
data = np.nan_to_num(data, nan=0.0)
self.ax.imshow(data, cmap='gray')
self.ax.set_title(f"波段 {band_idx + 1} (灰度)")
else:
n = min(3, dataset.RasterCount)
bands_data = []
for i in range(n):
b = dataset.GetRasterBand(i + 1)
bd = b.ReadAsArray()
bd = np.nan_to_num(bd, nan=0.0)
bands_data.append(bd)
rgb = np.dstack(bands_data)
for i in range(rgb.shape[2]):
p2, p98 = np.percentile(rgb[:, :, i], [2, 98])
if p98 > p2:
rgb[:, :, i] = np.clip((rgb[:, :, i] - p2) / (p98 - p2), 0, 1)
else:
rgb[:, :, i] = np.clip(rgb[:, :, i] / (p98 + 1e-6), 0, 1)
self.ax.imshow(rgb)
self.ax.set_title(f"RGB 显示")
self.ax.set_xlabel("列 (Column)")
self.ax.set_ylabel("行 (Row)")
self.fig.tight_layout()
self.canvas.draw()
self.cid = self.canvas.mpl_connect('button_press_event', self.on_click)
def on_band_changed(self):
"""波段选择变化时更新显示"""
if not hasattr(self, 'dataset'):
return
is_gray = self.gray_check.isChecked()
band_data = self.band_combo.currentData()
self.display_band(band_data if band_data != 0 else 0, is_gray=is_gray)
def load_ref_points(self):
"""加载并显示参考点"""
import os
if not self.ref_csv or not os.path.isfile(self.ref_csv):
return
try:
import csv
lon_list, lat_list = [], []
with open(self.ref_csv, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
try:
lon = float(row.get('Lon', row.get('lon', row.get('LON', 0))))
lat = float(row.get('Lat', row.get('lat', row.get('LAT', 0))))
if lon and lat:
lon_list.append(lon)
lat_list.append(lat)
except (ValueError, TypeError):
continue
if not lon_list:
return
px_list, py_list = [], []
gt = self.geotransform
if gt and (gt[1] != 0 or gt[5] != 0):
for lon, lat in zip(lon_list, lat_list):
px = (lon - gt[0]) / gt[1]
py = (lat - gt[3]) / gt[5]
if 0 <= px < self.width and 0 <= py < self.height:
px_list.append(px)
py_list.append(py)
if px_list:
self.ax.scatter(px_list, py_list, c='red', s=40, marker='o',
edgecolors='white', linewidths=0.8, zorder=5, alpha=0.9,
label=f'参考点 ({len(px_list)}个)')
self.ax.legend(loc='upper right', fontsize=9)
self.fig.tight_layout()
self.canvas.draw()
self.status_label.setText(
f"已加载 {len(px_list)} 个参考点(仅显示在影像范围内的点)"
)
except Exception as e:
self.status_label.setText(f"加载参考点失败: {e}")
def pixel_to_geo(self, px, py):
"""像素坐标转经纬度"""
gt = self.geotransform
if gt is None:
return None, None
lon = gt[0] + px * gt[1] + py * gt[2]
lat = gt[3] + px * gt[4] + py * gt[5]
return lon, lat
def on_click(self, event):
"""鼠标点击事件"""
if event.inaxes != self.ax or event.xdata is None or event.ydata is None:
return
px, py = int(round(event.xdata)), int(round(event.ydata))
if not (0 <= px < self.width and 0 <= py < self.height):
return
from osgeo import gdal
import numpy as np
dataset = self.dataset
n_bands = dataset.RasterCount
vals = []
for b in range(1, n_bands + 1):
val = dataset.GetRasterBand(b).ReadAsArray()[py, px]
vals.append(f"{val:.4f}" if isinstance(val, float) else str(val))
lon, lat = self.pixel_to_geo(px, py)
geo_str = f"Lon={lon:.6f}, Lat={lat:.6f}" if lon is not None else "无地理参考"
self.status_label.setText(
f"像素: (行={py}, 列={px}) | {geo_str} | "
f"波段值: {' | '.join(vals[:5])}" +
(f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "")
)

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据模型模块
包含 PandasTableModel 等数据模型类。
"""
import pandas as pd
from PyQt5.QtCore import Qt, QAbstractTableModel
class PandasTableModel(QAbstractTableModel):
"""支持DataFrame的表格模型"""
def __init__(self, data_frame: pd.DataFrame):
super().__init__()
self._data = data_frame.copy()
if self._data.empty:
self._data = pd.DataFrame()
self._data.fillna("", inplace=True)
self._columns = [str(col) for col in self._data.columns]
def rowCount(self, parent=None):
return len(self._data)
def columnCount(self, parent=None):
return len(self._columns)
def data(self, index, role=Qt.DisplayRole):
if not index.isValid() or role != Qt.DisplayRole:
return None
value = self._data.iat[index.row(), index.column()]
if pd.isna(value):
return ""
return str(value)
def headerData(self, section, orientation, role=Qt.DisplayRole):
if role != Qt.DisplayRole:
return None
if orientation == Qt.Horizontal:
if section < len(self._columns):
return self._columns[section]
return str(section)
return str(section + 1)
def flags(self, index):
if not index.isValid():
return Qt.NoItemFlags
return Qt.ItemIsEnabled | Qt.ItemIsSelectable

View File

@ -0,0 +1,351 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
图像浏览组件模块
包含 ImageCategoryTree 和 ImageViewerWidget 类。
"""
import os
from pathlib import Path
from typing import List, Optional
from PyQt5.QtWidgets import (
QTreeWidget, QTreeWidgetItem, QWidget, QVBoxLayout, QHBoxLayout,
QPushButton, QLabel, QScrollArea, QFrame, QGroupBox,
QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QPixmap
class ImageCategoryTree(QTreeWidget):
"""图像分类目录树 - 按类别组织图像文件"""
CATEGORIES = [
("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"),
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
("含量分布图", [], "📁"),
]
def __init__(self, parent=None):
super().__init__(parent)
self.setHeaderLabel("图像目录")
self.setMaximumWidth(300)
self.setMinimumWidth(250)
self.setup_categories()
self.setStyleSheet("""
QTreeWidget {
border: 1px solid #ddd;
border-radius: 5px;
background-color: #f8f9fa;
}
QTreeWidget::item {
padding: 5px;
border-radius: 3px;
}
QTreeWidget::item:selected {
background-color: #0078D4;
color: white;
}
QTreeWidget::item:hover {
background-color: #e3f2fd;
}
""")
def setup_categories(self):
"""初始化类别节点"""
self.category_items = {}
for category_name, keywords, icon in self.CATEGORIES:
item = QTreeWidgetItem(self)
item.setText(0, f"{icon} {category_name}")
item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name})
item.setExpanded(True)
self.category_items[category_name] = item
def clear_all_images(self):
"""清除所有图像项"""
for category_item in self.category_items.values():
while category_item.childCount() > 0:
category_item.removeChild(category_item.child(0))
def add_image(self, file_path: Path, display_name: str = None):
"""添加图像到对应的类别"""
if display_name is None:
display_name = file_path.stem
category = self._determine_category(file_path.name)
category_item = self.category_items.get(category, self.category_items["含量分布图"])
image_item = QTreeWidgetItem(category_item)
image_item.setText(0, f" └─ {display_name}")
image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)})
image_item.setToolTip(0, str(file_path))
return image_item
def _determine_category(self, filename: str) -> str:
"""根据文件名确定类别"""
filename_lower = filename.lower()
for category_name, keywords, _ in self.CATEGORIES:
if any(keyword in filename_lower for keyword in keywords):
return category_name
return "含量分布图"
def scan_directory(self, work_dir: str):
"""扫描目录中的所有图像文件"""
self.clear_all_images()
work_path = Path(work_dir)
if not work_path.exists():
return
image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
scan_roots: List[Path] = []
_viz = work_path / "14_visualization"
if _viz.is_dir():
scan_roots.append(_viz)
_wm = work_path / "1_water_mask"
if _wm.is_dir():
scan_roots.append(_wm)
if not scan_roots:
scan_roots.append(work_path)
seen_norm: set = set()
image_files: List[Path] = []
for root in scan_roots:
for ext in image_extensions:
for p in root.glob(f"**/{ext}"):
key = os.path.normcase(os.path.normpath(str(p.resolve())))
if key in seen_norm:
continue
seen_norm.add(key)
image_files.append(p)
for img_file in sorted(image_files):
if img_file.name.startswith('.') or 'thumb' in img_file.name.lower():
continue
self.add_image(img_file)
for category_name, item in self.category_items.items():
count = item.childCount()
if count > 0:
for cat_name, _, icon in self.CATEGORIES:
if cat_name == category_name:
item.setText(0, f"{icon} {category_name} ({count})")
break
def get_selected_image_path(self) -> Optional[str]:
"""获取当前选中的图像路径"""
selected_item = self.currentItem()
if not selected_item:
return None
data = selected_item.data(0, Qt.UserRole)
if data and data.get("type") == "image":
return data.get("path")
return None
class ImageViewerWidget(QWidget):
"""图像查看器组件 - 支持缩放、平移"""
def __init__(self, parent=None):
super().__init__(parent)
self.current_image_path = None
self.scale_factor = 1.0
self._update_timer = QTimer()
self._update_timer.setSingleShot(True)
self._update_timer.timeout.connect(self._do_update_display)
self._pending_scale = None
self.setup_ui()
def setup_ui(self):
layout = QVBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
toolbar = QHBoxLayout()
self.refresh_btn = QPushButton("🔄 刷新目录")
self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
toolbar.addWidget(self.refresh_btn)
separator = QFrame()
separator.setFrameShape(QFrame.VLine)
separator.setFrameShadow(QFrame.Sunken)
toolbar.addWidget(separator)
self.zoom_in_btn = QPushButton("🔍+")
self.zoom_in_btn.setToolTip("放大")
self.zoom_in_btn.setMaximumWidth(50)
toolbar.addWidget(self.zoom_in_btn)
self.zoom_out_btn = QPushButton("🔍-")
self.zoom_out_btn.setToolTip("缩小")
self.zoom_out_btn.setMaximumWidth(50)
toolbar.addWidget(self.zoom_out_btn)
self.fit_btn = QPushButton("⬜ 适应窗口")
self.fit_btn.setToolTip("适应窗口大小")
toolbar.addWidget(self.fit_btn)
self.original_btn = QPushButton("1:1 原始大小")
self.original_btn.setToolTip("原始大小")
toolbar.addWidget(self.original_btn)
toolbar.addStretch()
self.save_btn = QPushButton("💾 保存")
self.save_btn.setToolTip("保存当前图像")
toolbar.addWidget(self.save_btn)
layout.addLayout(toolbar)
self.scroll_area = QScrollArea()
self.scroll_area.setWidgetResizable(True)
self.scroll_area.setStyleSheet("background-color: white;")
self.image_label = QLabel()
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setStyleSheet("background-color: white;")
self.scroll_area.setWidget(self.image_label)
layout.addWidget(self.scroll_area, 1)
status_layout = QHBoxLayout()
self.status_label = QLabel("就绪")
self.status_label.setStyleSheet("color: #666; font-size: 11px;")
status_layout.addWidget(self.status_label)
status_layout.addStretch()
layout.addLayout(status_layout)
self.setLayout(layout)
self.zoom_in_btn.clicked.connect(self.zoom_in)
self.zoom_out_btn.clicked.connect(self.zoom_out)
self.fit_btn.clicked.connect(self.fit_to_window)
self.original_btn.clicked.connect(self.original_size)
self.save_btn.clicked.connect(self.save_image)
def load_image(self, image_path: str):
"""加载并显示图像"""
if not image_path or not Path(image_path).exists():
self.image_label.setText("图像不存在")
self.status_label.setText("图像加载失败")
return
self.current_image_path = image_path
self.scale_factor = 1.0
pixmap = QPixmap(image_path)
if pixmap.isNull():
self.image_label.setText("无法加载图像")
self.status_label.setText("图像格式不支持")
return
self.original_pixmap = pixmap
self.fit_to_window()
file_info = Path(image_path).stat()
size_mb = file_info.st_size / (1024 * 1024)
self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")
def update_image_display(self):
"""更新图像显示 - 使用防抖避免频繁重绘卡顿"""
self._update_timer.stop()
self._pending_scale = self.scale_factor
self._update_timer.start(50)
def _do_update_display(self):
"""实际执行图像更新"""
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
return
if self._pending_scale is None:
return
if self._pending_scale > 2.0 or self._pending_scale < 0.5:
transform = Qt.FastTransformation
else:
transform = Qt.SmoothTransformation
scaled_pixmap = self.original_pixmap.scaled(
int(self.original_pixmap.width() * self._pending_scale),
int(self.original_pixmap.height() * self._pending_scale),
Qt.KeepAspectRatio,
transform
)
self.image_label.setPixmap(scaled_pixmap)
self._pending_scale = None
def wheelEvent(self, event):
"""鼠标滚轮缩放 - 实时响应"""
delta = event.angleDelta().y()
if delta > 0:
if self.scale_factor < 5.0:
self.scale_factor = min(self.scale_factor * 1.1, 5.0)
self.update_image_display()
else:
if self.scale_factor > 0.1:
self.scale_factor = max(self.scale_factor / 1.1, 0.1)
self.update_image_display()
event.accept()
def zoom_in(self):
"""放大"""
if self.scale_factor < 5.0:
self.scale_factor = min(self.scale_factor * 1.25, 5.0)
self.update_image_display()
def zoom_out(self):
"""缩小"""
if self.scale_factor > 0.1:
self.scale_factor = max(self.scale_factor / 1.25, 0.1)
self.update_image_display()
def fit_to_window(self):
"""适应窗口"""
if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull():
return
view_size = self.scroll_area.viewport().size()
img_size = self.original_pixmap.size()
scale_w = view_size.width() / img_size.width()
scale_h = view_size.height() / img_size.height()
self._fit_scale = min(scale_w, scale_h)
self.scale_factor = self._fit_scale
self.update_image_display()
self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")
def original_size(self):
"""原始大小"""
self.scale_factor = 1.0
self._fit_scale = None
self.update_image_display()
self.status_label.setText("原始大小 | 缩放: 100%")
def save_image(self):
"""保存图像"""
if not self.current_image_path:
return
file_path, _ = QFileDialog.getSaveFileName(
self, "保存图像", Path(self.current_image_path).name,
"PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
)
if file_path:
try:
import shutil
shutil.copy(self.current_image_path, file_path)
except Exception as e:
QMessageBox.critical(self, "错误", f"保存失败: {e}")

View File

@ -0,0 +1,112 @@
import time
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_regression
# 屏蔽烦人的 sklearn 警告
warnings.filterwarnings("ignore")
print("====== 🚀 启动 Mega Water 模型终极体检脚本 ======")
# ---------------------------------------------------------
# 1. 完美复刻侦察报告中的 CSV 数据结构
# 报告指出: 目标值(y)在左边,光谱特征(X)在右边
# ---------------------------------------------------------
print("📦 正在生成符合系统结构的模拟测试数据...")
X_raw, y_raw = make_regression(n_samples=200, n_features=50, noise=0.1, random_state=42)
# 模拟真实的 CSV 列名前2列是水质参数后面是 50 个光谱波段
columns = ['Chla', 'SS'] + [f"Band_{i}" for i in range(50)]
# 拼装成一整张大表
data = pd.DataFrame(np.hstack((y_raw.reshape(-1, 1), (y_raw * 0.5).reshape(-1, 1), X_raw)), columns=columns)
# 按照 load_data_batch 的逻辑进行切割
feature_start_index = 2
X = data.iloc[:, feature_start_index:] # 截取光谱作为 X
y = data['Chla'] # 提取一个目标参数作为 y
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"✅ 数据切割完毕! 模拟波段数: {X.shape[1]}, 训练集样本数: {X_train.shape[0]}\n")
# ---------------------------------------------------------
# 2. 严格装载侦察报告中的 16 个真实模型
# ---------------------------------------------------------
print("🔍 正在加载底层真实配置库中的模型...")
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.cross_decomposition import PLSRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor
# 将参数压至极低,实施“降维打击”,确保 1 秒内跑完
models = {
'SVR': SVR(),
'RF': RandomForestRegressor(n_estimators=10, max_depth=5, n_jobs=-1),
'KNN': KNeighborsRegressor(),
'LinearRegression': LinearRegression(),
'Ridge': Ridge(),
'Lasso': Lasso(),
'ElasticNet': ElasticNet(),
'PLS': PLSRegression(),
'GradientBoosting': GradientBoostingRegressor(n_estimators=10, max_depth=5),
'AdaBoost': AdaBoostRegressor(n_estimators=10),
'DecisionTree': DecisionTreeRegressor(max_depth=5),
'MLP': MLPRegressor(max_iter=50),
'ExtraTrees': ExtraTreesRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
}
# 针对报告中发现的 3 个“被禁用”的第三方强力库,进行刺探测试
try:
from xgboost import XGBRegressor
models['XGBoost'] = XGBRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
models['XGBoost'] = "IMPORT_ERROR"
try:
from lightgbm import LGBMRegressor
models['LightGBM'] = LGBMRegressor(n_estimators=10, max_depth=5, n_jobs=-1)
except ImportError:
models['LightGBM'] = "IMPORT_ERROR"
try:
from catboost import CatBoostRegressor
models['CatBoost'] = CatBoostRegressor(iterations=10, depth=5, verbose=0)
except ImportError:
models['CatBoost'] = "IMPORT_ERROR"
# ---------------------------------------------------------
# 3. 开始残酷的体检循环
# ---------------------------------------------------------
print("\n================ 开始跑分测试 ================")
results = []
for name, model in models.items():
if model == "IMPORT_ERROR":
results.append(f"⚠️ [缺库] {name:<16} : 环境未安装此库 (建议: pip install {name.lower()})")
continue
start_time = time.time()
try:
# 极速拟合与评分
model.fit(X_train, y_train)
score = model.score(X_test, y_test)
cost_time = time.time() - start_time
results.append(f"✅ [成功] {name:<16} : 耗时 {cost_time:.3f} 秒 (R2: {score:.2f})")
except Exception as e:
error_msg = str(e).split('\n')[0]
results.append(f"❌ [崩溃] {name:<16} : {error_msg}")
# ---------------------------------------------------------
# 4. 打印最终体检报告
# ---------------------------------------------------------
print("\n=============== 🏥 最终体检报告 ===============")
for res in results:
print(res)
print("===============================================")

346
src/gui/core/viz_thread.py Normal file
View File

@ -0,0 +1,346 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
可视化后台线程模块
包含 VisualizationWorkerThread 后台线程类和辅助函数。
"""
from pathlib import Path
from typing import Optional, List, Union
from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
"""推断光谱起始列training_spectra 通常以波长数值为列名,未必含 UTM_Y"""
import pandas as pd
for i, col in enumerate(df.columns):
name = str(col).strip().lstrip("\ufeff")
try:
v = float(name)
except ValueError:
continue
if 200.0 <= v <= 3000.0:
return i
if "UTM_Y" in df.columns:
return "UTM_Y"
return 0
class VisualizationWorkerThread(QThread):
"""可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。"""
finished_ok = pyqtSignal(object)
failed = pyqtSignal(str)
def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
super().__init__()
self.task = task
self.work_dir = str(work_dir)
self.extra = extra or {}
def run(self):
mpl_prev = None
try:
import matplotlib
mpl_prev = matplotlib.get_backend()
except Exception:
pass
try:
import matplotlib.pyplot as plt
plt.switch_backend("Agg")
except Exception:
mpl_prev = None
try:
wp = Path(self.work_dir)
if self.task == "mask_glint":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
cnt = len(preview_paths) if preview_paths else 0
self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})
elif self.task == "sampling_map":
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
if deglint_dir.exists():
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
if not hyperspectral_files:
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
if not hyperspectral_files:
self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif")
return
hyperspectral_path = str(hyperspectral_files[0])
csv_files = []
processed_dir = wp / "4_processed_data"
if processed_dir.exists():
csv_files = list(processed_dir.glob("*.csv"))
if not csv_files:
csv_files = (
list(wp.glob("**/*sampling*.csv"))
+ list(wp.glob("**/*point*.csv"))
+ list(wp.glob("**/*.csv"))
)
if not csv_files:
self.failed.emit("未找到采样点 CSV 文件。")
return
csv_path = str(csv_files[0])
from src.postprocessing.point_map import SamplingPointMap
map_generator = SamplingPointMap(
output_dir=str(wp / "14_visualization" / "sampling_maps"),
fast_mode=True,
)
map_path = map_generator.create_sampling_point_map(
hyperspectral_path=hyperspectral_path,
csv_path=csv_path,
point_color="red",
point_size=100,
point_alpha=0.9,
show_north_arrow=True,
show_scale_bar=True,
show_legend=True,
downsample=True,
dpi=180,
)
self.finished_ok.emit(
{
"task": "sampling_map",
"map_path": map_path,
"hyperspectral_path": hyperspectral_path,
"csv_path": csv_path,
}
)
elif self.task == "spectrum":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
csv_file = self.extra.get("csv_path")
wl = self.extra.get("wavelength_start_column", "UTM_Y")
n_groups = int(self.extra.get("n_groups", 5))
param_cols = self.extra.get("param_cols") or []
if param_cols:
output_paths: List[str] = []
err_lines: List[str] = []
for param_col in param_cols:
try:
out = viz.plot_spectrum_by_parameter(
csv_path=str(csv_file),
parameter_column=param_col,
wavelength_start_column=wl,
n_groups=n_groups,
)
output_paths.append(out)
except Exception as _ex:
err_lines.append(f"{param_col}: {_ex}")
if not output_paths:
self.failed.emit(
"所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
)
return
self.finished_ok.emit(
{
"task": "spectrum",
"output_paths": output_paths,
"errors": err_lines,
}
)
else:
param_col = self.extra.get("param_col")
out = viz.plot_spectrum_by_parameter(
csv_path=str(csv_file),
parameter_column=param_col,
wavelength_start_column=wl,
n_groups=n_groups,
)
self.finished_ok.emit(
{"task": "spectrum", "output_path": out, "param_col": param_col}
)
elif self.task == "statistics":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
csv_file = self.extra.get("csv_path")
param_cols = self.extra.get("param_cols") or []
output_paths = viz.plot_statistical_charts(
csv_path=str(csv_file),
parameter_columns=param_cols,
)
self.finished_ok.emit(
{"task": "statistics", "output_paths": output_paths}
)
elif self.task == "scatter":
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
training_csv_path = (self.extra.get("training_csv_path") or "").strip()
models_dir = (self.extra.get("models_dir") or "").strip()
if not training_csv_path or not Path(training_csv_path).is_file():
self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
return
if not models_dir or not Path(models_dir).is_dir():
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
return
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=training_csv_path,
models_dir=models_dir,
)
self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})
elif self.task == "generate_all_selected":
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
parts = []
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
if self.extra.get("gen_scatter"):
if training_csv.is_file():
models_dir = wp / "7_Supervised_Model_Training"
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=str(training_csv),
models_dir=str(models_dir),
)
count = len(scatter_paths) if scatter_paths else 0
parts.append(f"散点图: {count}")
else:
parts.append("散点图: 跳过(无模型目录)")
else:
parts.append("散点图: 跳过(无训练数据)")
if self.extra.get("gen_spectrum"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
wl_col = _viz_infer_wavelength_start_column(df)
if isinstance(wl_col, str):
idx = int(df.columns.get_loc(wl_col)) + 1
else:
idx = int(wl_col)
param_cols = []
if idx > 0 and idx < len(df.columns):
param_cols = [
c for c in df.columns[:idx]
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
]
if param_cols:
spectrum_paths = []
for param_col in param_cols:
try:
path = viz.plot_spectrum_by_parameter(
csv_path=str(training_csv),
parameter_column=param_col,
wavelength_start_column=wl_col,
n_groups=5,
)
if path:
spectrum_paths.append(path)
except Exception as e:
print(f"生成光谱图失败 ({param_col}): {e}")
count = len(spectrum_paths)
parts.append(f"光谱图: {count}")
else:
parts.append("光谱图: 跳过(无可用参数列)")
else:
parts.append("光谱图: 跳过(无训练数据)")
if self.extra.get("gen_boxplots"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
param_cols = [
c for c in df.select_dtypes(include=[np.number]).columns
if not any(exc in c.lower() for exc in exclude_cols)
]
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
idx = int(df.columns.get_loc(wl)) + 1
else:
idx = int(wl)
if 0 < idx < len(df.columns):
meta_set = set(df.columns[:idx])
param_cols = [c for c in param_cols if c in meta_set]
if param_cols:
output_dict = viz.plot_statistical_charts(
csv_path=str(training_csv),
parameter_columns=param_cols,
)
count = len([v for v in output_dict.values() if v]) if output_dict else 0
parts.append(f"统计图: {count}")
else:
parts.append("统计图: 跳过(无可用水质参数列)")
else:
parts.append("统计图: 跳过(无训练数据)")
if self.extra.get("gen_mask_glint"):
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
if self.extra.get("gen_sampling_map"):
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
if deglint_dir.exists():
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
if not hyperspectral_files:
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
if hyperspectral_files:
hyperspectral_path = str(hyperspectral_files[0])
csv_files = []
processed_dir = wp / "4_processed_data"
if processed_dir.exists():
csv_files = list(processed_dir.glob("*.csv"))
if not csv_files:
csv_files = (
list(wp.glob("**/*sampling*.csv"))
+ list(wp.glob("**/*point*.csv"))
+ list(wp.glob("**/*.csv"))
)
if csv_files:
csv_path = str(csv_files[0])
from src.postprocessing.point_map import SamplingPointMap
map_generator = SamplingPointMap(
output_dir=str(wp / "14_visualization" / "sampling_maps"),
fast_mode=True,
)
map_path = map_generator.create_sampling_point_map(
hyperspectral_path=hyperspectral_path,
csv_path=csv_path,
point_color="red",
point_size=100,
point_alpha=0.9,
show_north_arrow=True,
show_scale_bar=True,
show_legend=True,
downsample=True,
dpi=180,
)
parts.append(f"采样点图: {Path(map_path).name}")
else:
parts.append("采样点图: 跳过(无CSV)")
else:
parts.append("采样点图: 跳过(无影像)")
self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})
else:
self.failed.emit(f"未知可视化任务: {self.task}")
except Exception as e:
import traceback
self.failed.emit(f"{e}\n{traceback.format_exc()}")
finally:
if mpl_prev:
try:
import matplotlib.pyplot as plt
plt.switch_backend(mpl_prev)
except Exception:
pass

93
src/gui/crash_dump.txt Normal file
View File

@ -0,0 +1,93 @@
============================================================
[2026-05-12 11:14:51]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 130, in <module>
from src.gui.panels.step9_panel import Step9Panel
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\panels\step9_panel.py", line 24, in <module>
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\core\water_quality_inversion_pipeline_GUI.py", line 45, in <module>
from src.preprocessing.process_water_quality_data import process_water_quality_data
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\preprocessing\process_water_quality_data.py", line 9, in <module>
from scipy import stats
File "<frozen importlib._bootstrap>", line 1412, in _handle_fromlist
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\__init__.py", line 143, in __getattr__
return _importlib.import_module(f'scipy.{name}')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\importlib\__init__.py", line 90, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\__init__.py", line 632, in <module>
from ._multicomp import *
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_multicomp.py", line 11, in <module>
from scipy.stats._qmc import check_random_state
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\stats\_qmc.py", line 26, in <module>
from scipy.sparse.csgraph import minimum_spanning_tree
File "D:\111\changyongruanjian\anconda\envs\WQ_GUI\Lib\site-packages\scipy\sparse\csgraph\__init__.py", line 188, in <module>
from ._shortest_path import (
File "scipy/sparse/csgraph/_shortest_path.pyx", line 21, in init scipy.sparse.csgraph._shortest_path
File "<frozen importlib._bootstrap>", line 1349, in _find_and_load
KeyboardInterrupt
============================================================
[2026-05-12 11:57:28]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3093, in main
_dialog.exec_()
KeyboardInterrupt
============================================================
[2026-05-28 15:45:11]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
window = WaterQualityGUI()
^^^^^^^^^^^^^^^^^
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
self.init_ui()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
self.create_content_area()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
self.step2_panel = Step2Panel()
^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
============================================================
[2026-05-28 15:45:19]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3123, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3097, in main
window = WaterQualityGUI()
^^^^^^^^^^^^^^^^^
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1352, in __init__
self.init_ui()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1586, in init_ui
self.create_content_area()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 1943, in create_content_area
self.step2_panel = Step2Panel()
^^^^^^^^^^^^
TypeError: Step2Panel.__init__() missing 1 required positional argument: 'session'
============================================================
[2026-05-28 16:00:53]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2149, in on_step_changed
self.auto_populate_step_inputs(item_data)
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 2362, in auto_populate_step_inputs
if step_id not in self.step_dependencies:
^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'WaterQualityGUI' object has no attribute 'step_dependencies'. Did you mean: '_init_step_dependencies'?
============================================================
[2026-06-03 13:56:59]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3354, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3331, in main
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt

View File

@ -325,7 +325,7 @@ class Step3Panel(QWidget):
}
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask'] = water_mask_path
config['water_mask_path'] = water_mask_path
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
@ -366,8 +366,8 @@ class Step3Panel(QWidget):
"""设置配置"""
if 'img_path' in config:
self.img_file.set_path(config['img_path'])
if 'water_mask' in config:
self.water_mask_file.set_path(config['water_mask'])
if 'water_mask_path' in config:
self.water_mask_file.set_path(config['water_mask_path'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
if 'reference_csv' in config:

View File

@ -187,7 +187,7 @@ class Step5_5Panel(QWidget):
def get_config(self):
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
return {
'training_spectra_path': self.training_data_widget.get_path(),
'training_csv_path': self.training_data_widget.get_path(),
'formula_csv_file': self.builtin_formula_path,
'formula_names': selected,
'output_file': self.output_file_widget.get_path(),
@ -195,7 +195,7 @@ class Step5_5Panel(QWidget):
}
def set_config(self, config):
if 'training_spectra_path' in config: self.training_data_widget.set_path(config['training_spectra_path'])
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
if 'formula_names' in config:
sel = set(config['formula_names'])
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
@ -217,7 +217,7 @@ class Step5_5Panel(QWidget):
def run_step(self):
config = self.get_config()
if not config['training_spectra_path']:
if not config['training_csv_path']:
QMessageBox.warning(self, "提示", "请先选择输入数据")
return
parent = self.parent()

View File

@ -124,7 +124,7 @@ class Step5Panel(QWidget):
glint_mask_path = self.glint_mask_file.get_path()
if glint_mask_path:
config['glint_mask_path'] = glint_mask_path
# 注意step5_extract_training_spectra 不接受 output_path / training_spectra_path
# 注意step5_extract_training_spectra 不接受 output_path / training_csv_path
# 参数,输出路径由 pipeline 内部根据 training_spectra_dir 自动生成。
return config

View File

@ -363,7 +363,7 @@ class Step6Panel(QWidget):
# 回退:从 Step5 的 config 字典中查找可能的键名
step5_cfg = main_window.step5_panel.get_config()
step5_csv = (
step5_cfg.get('training_spectra_path')
step5_cfg.get('training_csv_path')
or step5_cfg.get('output_file')
or step5_cfg.get('csv_path')
or step5_cfg.get('output_csv')

BIN
src/gui/scaler_params.pkl Normal file

Binary file not shown.

View File

@ -1432,7 +1432,7 @@ class WaterQualityGUI(QMainWindow):
'glint_mask_path': ('step2', 'glint_mask', 'glint_mask_file') # 步骤5可选耀斑掩膜
},
'step5_5': {
'training_spectra_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
'training_csv_path': ('step5', 'training_spectra', 'output_file') # 步骤5.5需要步骤5输出的训练光谱
},
'step6': {
'csv_path': ('step5', 'training_spectra', 'csv_file') # 步骤6需要训练光谱数据