refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名
This commit is contained in:
201
new/app/api/_smoke_test_train.py
Normal file
201
new/app/api/_smoke_test_train.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""
|
||||
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
|
||||
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
|
||||
"""
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
|
||||
# 当前文件位于 new/app/api/_smoke_test_train.py
|
||||
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from app.api.modeling import (
|
||||
_get_model_pipeline,
|
||||
_load_train_df,
|
||||
_resolve_feature_start,
|
||||
_run_train_sync,
|
||||
_MODEL_CLASS_REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
|
||||
"""生成 [lat, lon, target, lat2, lon2, feat_0, feat_1, ...] 布局的 CSV"""
|
||||
rng = np.random.default_rng(seed)
|
||||
lat = rng.uniform(20, 25, n_samples)
|
||||
lon = rng.uniform(110, 115, n_samples)
|
||||
target = rng.uniform(0, 50, n_samples)
|
||||
lat2 = rng.uniform(0, 1, n_samples) # 元数据
|
||||
lon2 = rng.uniform(0, 1, n_samples) # 元数据
|
||||
feats = rng.normal(0, 1, (n_samples, n_features))
|
||||
# 让 y 真正依赖前 3 个特征, RF 至少应该能学到 R² > 0.5
|
||||
feats[:, 0] += target / 10
|
||||
feats[:, 1] += target / 20
|
||||
feats[:, 2] -= target / 15
|
||||
|
||||
df = pd.DataFrame({
|
||||
"lat": lat,
|
||||
"lon": lon,
|
||||
"Chl-a": target,
|
||||
"lat2": lat2,
|
||||
"lon2": lon2,
|
||||
**{f"feat_{i}": feats[:, i] for i in range(n_features)},
|
||||
})
|
||||
tmp = Path(tempfile.mkdtemp()) / "train.csv"
|
||||
df.to_csv(tmp, index=False)
|
||||
return tmp
|
||||
|
||||
|
||||
def test_load_train_df():
|
||||
print("== test_load_train_df ==")
|
||||
p = make_synthetic_csv(n_samples=50)
|
||||
df = _load_train_df(str(p))
|
||||
assert df.shape == (50, 5 + 8), f"shape={df.shape}"
|
||||
print(f" shape={df.shape}, columns[:6]={list(df.columns[:6])}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_int_and_str():
|
||||
print("== test_resolve_feature_start (int + str) ==")
|
||||
p = make_synthetic_csv()
|
||||
df = _load_train_df(str(p))
|
||||
idx_int = _resolve_feature_start(df, 5)
|
||||
idx_str = _resolve_feature_start(df, "feat_0")
|
||||
assert idx_int == 5 == idx_str, f"int={idx_int}, str={idx_str}"
|
||||
print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_str_miss():
|
||||
print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
|
||||
p = make_synthetic_csv()
|
||||
df = _load_train_df(str(p))
|
||||
try:
|
||||
_resolve_feature_start(df, "not_exist")
|
||||
print(" FAIL: 应抛 ValueError")
|
||||
except ValueError as e:
|
||||
print(f" 正确抛 ValueError: {e}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_all_types():
|
||||
print("== test_get_model_pipeline (5 种 model_type) ==")
|
||||
for mt in ["RF", "SVR", "LinearRegression", "KNN", "PLS"]:
|
||||
p = _get_model_pipeline(mt, {})
|
||||
assert len(p.steps) == 2
|
||||
assert p.steps[0][0] == "scaler"
|
||||
assert p.steps[1][0] == "model"
|
||||
print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_bad_type():
|
||||
print("== test_get_model_pipeline (坏 model_type) ==")
|
||||
try:
|
||||
_get_model_pipeline("XGBoost", {})
|
||||
print(" FAIL: 应抛 ValueError")
|
||||
except ValueError as e:
|
||||
print(f" 正确抛 ValueError: {e}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_rf_end_to_end():
|
||||
print("== test_run_train_sync (RF 端到端) ==")
|
||||
p = make_synthetic_csv(n_samples=200)
|
||||
out_dir = Path(tempfile.mkdtemp())
|
||||
out_path = out_dir / "model.joblib"
|
||||
|
||||
import time
|
||||
t0 = time.time()
|
||||
metadata = _run_train_sync(
|
||||
model_type="RF",
|
||||
target="Chl-a",
|
||||
train_data_path=str(p),
|
||||
feature_start=5,
|
||||
params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
|
||||
output_model_path=out_path,
|
||||
)
|
||||
dt = time.time() - t0
|
||||
|
||||
assert out_path.exists(), f"joblib 未落盘: {out_path}"
|
||||
print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
|
||||
print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
|
||||
print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
|
||||
print(f" 耗时 {dt:.2f}s")
|
||||
|
||||
# 回测: 加载 joblib 再 predict
|
||||
import joblib
|
||||
saved = joblib.load(out_path)
|
||||
assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
|
||||
assert hasattr(saved["model"], "predict")
|
||||
assert saved["metadata"]["test_r2"] == metadata["test_r2"]
|
||||
print(f" joblib 加载 OK, 含 'model' 和 'metadata' 双 key")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_linearregression_fast():
|
||||
print("== test_run_train_sync (LinearRegression 快速路径) ==")
|
||||
p = make_synthetic_csv(n_samples=150)
|
||||
out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
|
||||
metadata = _run_train_sync(
|
||||
model_type="LinearRegression",
|
||||
target="Chl-a",
|
||||
train_data_path=str(p),
|
||||
feature_start=5,
|
||||
params={},
|
||||
output_model_path=out_path,
|
||||
)
|
||||
print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 >= 0.4)")
|
||||
assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_csv():
|
||||
print("== test_run_train_sync (CSV 不存在) ==")
|
||||
try:
|
||||
_run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
|
||||
print(" FAIL: 应抛异常")
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
print(f" 正确抛 {type(e).__name__}: {e}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_target():
|
||||
print("== test_run_train_sync (target 列不存在) ==")
|
||||
p = make_synthetic_csv()
|
||||
try:
|
||||
_run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
|
||||
print(" FAIL: 应抛 ValueError")
|
||||
except ValueError as e:
|
||||
print(f" 正确抛 ValueError: {e}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_str_feature_start():
|
||||
print("== test_run_train_sync (feature_start 用列名) ==")
|
||||
p = make_synthetic_csv()
|
||||
out_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
|
||||
metadata = _run_train_sync("RF", "Chl-a", str(p), "feat_0", {"n_estimators": 10}, out_path)
|
||||
assert metadata["feature_start"] == "feat_0"
|
||||
assert metadata["n_features"] == 8
|
||||
assert metadata["feature_columns"][0] == "feat_0"
|
||||
print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
|
||||
print(" PASS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_load_train_df()
|
||||
test_resolve_feature_start_int_and_str()
|
||||
test_resolve_feature_start_str_miss()
|
||||
test_get_model_pipeline_all_types()
|
||||
test_get_model_pipeline_bad_type()
|
||||
test_run_train_sync_rf_end_to_end()
|
||||
test_run_train_sync_linearregression_fast()
|
||||
test_run_train_sync_bad_csv()
|
||||
test_run_train_sync_bad_target()
|
||||
test_run_train_sync_str_feature_start()
|
||||
print("\n>>> ALL SMOKE TESTS PASSED")
|
||||
222
new/app/api/endpoints.py
Normal file
222
new/app/api/endpoints.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""
|
||||
API 路由集合
|
||||
============
|
||||
|
||||
把业务接口统一收口到 APIRouter,再由 main.py 通过 include_router 挂载。
|
||||
|
||||
当前包含的接口:
|
||||
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
|
||||
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
|
||||
GET /api/tasks/{task_id} 查询指定任务的状态与结果
|
||||
|
||||
派发链:
|
||||
POST /api/process/deglint
|
||||
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
|
||||
└─ get_remover(method) 从注册表拿到算法类
|
||||
└─ remover.process(input_zarr, output_zarr, **params)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
# 算法注册表 API
|
||||
from app.core.algorithms import get_remover, list_removers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
|
||||
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(tags=["deglint"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求 / 响应数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class DeglintRequest(BaseModel):
|
||||
"""POST /api/process/deglint 的请求体"""
|
||||
|
||||
method: str = Field(
|
||||
...,
|
||||
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
|
||||
examples=["kutser"],
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description=(
|
||||
"传递给算法 process() 的超参数字典,例如 "
|
||||
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
|
||||
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
|
||||
),
|
||||
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
|
||||
)
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
"""提交任务成功后立即返回的响应"""
|
||||
|
||||
task_id: str
|
||||
status: str # 一定是 PENDING
|
||||
|
||||
|
||||
class AlgorithmListResponse(BaseModel):
|
||||
"""GET /api/algorithms 的响应"""
|
||||
|
||||
algorithms: list # 已注册算法名列表
|
||||
count: int # 算法总数
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 后台任务执行器(真实派发链)
|
||||
# ---------------------------------------------------------------------------
|
||||
# 注意:这里使用 async def。
|
||||
# FastAPI / Starlette 的 BackgroundTasks 支持 async function,
|
||||
# 会在响应返回后自动 await 它,不影响主请求链路。
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_glint_removal_task(
|
||||
task_id: str,
|
||||
method: str,
|
||||
params: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
后台异步执行器:按 method 名字从注册表取出算法类,实例化并运行 process()。
|
||||
|
||||
状态机:
|
||||
PENDING -> PROCESSING -> SUCCESS
|
||||
└──> FAILED(含 error / traceback)
|
||||
"""
|
||||
# 0. 安全检查:任务记录必须已存在(POST 阶段已写入)
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 任务不存在, 跳过")
|
||||
return
|
||||
|
||||
# 1. 状态推进到 PROCESSING
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 开始处理 method={method} params={params}")
|
||||
|
||||
# 2. 临时硬编码 IO 路径(未来由数据管理层提供)
|
||||
# TODO: 替换为真实的数据管理服务返回的 zarr 路径
|
||||
input_zarr_path = "./data/temp_in.zarr"
|
||||
output_zarr_path = f"./data/{task_id}_out.zarr"
|
||||
|
||||
try:
|
||||
# 3. 按 method 名字从注册表取算法类并实例化
|
||||
# get_remover 找不到时会抛 KeyError,下面的 except 会兜住
|
||||
algorithm_cls = get_remover(method)
|
||||
remover = algorithm_cls()
|
||||
|
||||
# 4. 调用算法(注意 await,因为 BaseGlintRemover.process 是 async)
|
||||
await remover.process(input_zarr_path, output_zarr_path, **params)
|
||||
|
||||
# 5. 成功:写回结果路径与状态
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
output_zarr_path=output_zarr_path,
|
||||
error=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
|
||||
|
||||
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
|
||||
# 6. 失败:记录错误信息与堆栈,便于前端排查
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
output_zarr_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=traceback.format_exc(),
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /algorithms
|
||||
# ---------------------------------------------------------------------------
|
||||
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/algorithms", response_model=AlgorithmListResponse)
|
||||
async def list_registered_algorithms() -> Dict[str, Any]:
|
||||
"""列出已注册的去耀斑算法。"""
|
||||
names = list(list_removers().keys())
|
||||
return {"algorithms": names, "count": len(names)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /process/deglint
|
||||
# ---------------------------------------------------------------------------
|
||||
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
|
||||
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
|
||||
async def submit_deglint(
|
||||
payload: DeglintRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
|
||||
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
|
||||
|
||||
# 1. 生成唯一任务 ID(UUID4 足以保证全局唯一性)
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
|
||||
# 注意:output_zarr_path / error / traceback 字段在执行过程中被填充
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"method": payload.method,
|
||||
"params": payload.params,
|
||||
"status": "PENDING",
|
||||
"output_zarr_path": None,
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# 3. 把真实执行器丢到后台
|
||||
background_tasks.add_task(
|
||||
execute_glint_removal_task,
|
||||
task_id,
|
||||
payload.method,
|
||||
payload.params,
|
||||
)
|
||||
|
||||
# 4. 立即返回 task_id 与 PENDING 状态
|
||||
return {"task_id": task_id, "status": "PENDING"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tasks/{task_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
|
||||
# SUCCESS 表示成功(含 output_zarr_path),FAILED 表示失败(含 error / traceback)。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task_status(task_id: str) -> Dict[str, Any]:
|
||||
"""查询指定任务的当前状态与结果。"""
|
||||
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
# 找不到 task_id 通常意味着客户端拼错了 ID,或者记录已被清理
|
||||
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
|
||||
|
||||
# 直接返回字典,FastAPI 会自动 JSON 序列化
|
||||
return record
|
||||
786
new/app/api/modeling.py
Normal file
786
new/app/api/modeling.py
Normal file
@ -0,0 +1,786 @@
|
||||
"""
|
||||
app/api/modeling.py
|
||||
===================
|
||||
|
||||
机器学习与水质反演相关的 API 路由。
|
||||
|
||||
接口(最终路径, 挂载后):
|
||||
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
|
||||
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
|
||||
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
|
||||
|
||||
设计要点
|
||||
--------
|
||||
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
|
||||
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
|
||||
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
|
||||
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
|
||||
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor)
|
||||
* xr.open_zarr 延迟打开影像, NaN 填 0
|
||||
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
|
||||
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
|
||||
- /train 已接入真实 sklearn + pandas 训练管线:
|
||||
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
|
||||
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
|
||||
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
|
||||
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
|
||||
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
|
||||
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
|
||||
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVR
|
||||
|
||||
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
|
||||
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
|
||||
# 最终完整路径由 main.py 挂载时再补 /api。
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(prefix="/modeling", tags=["modeling"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class TrainRequest(BaseModel):
|
||||
"""POST /api/modeling/train 的请求体"""
|
||||
|
||||
model_type: str = Field(
|
||||
...,
|
||||
description="模型类型, 例如 'RF' (随机森林) / 'SVR' (支持向量回归) / 'XGBoost' / 'MLP'",
|
||||
examples=["RF", "SVR"],
|
||||
)
|
||||
target: str = Field(
|
||||
...,
|
||||
description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
|
||||
examples=["Chl-a", "TSS", "CDOM"],
|
||||
)
|
||||
train_data_path: str = Field(
|
||||
...,
|
||||
description="训练数据集的 zarr 路径(包含 reflectance 变量与 target 标签)",
|
||||
examples=["./data/train.zarr"],
|
||||
)
|
||||
feature_start: Union[int, str] = Field(
|
||||
default=4,
|
||||
description=(
|
||||
"特征列起始位置. 表格布局假定为 "
|
||||
"[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
|
||||
"可传 int 列索引(如 4)或 str 列名(如 '374.285' 波长起点)。"
|
||||
"默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
|
||||
),
|
||||
examples=[4, "374.285"],
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
|
||||
examples=[{"n_estimators": 100, "max_depth": 20}],
|
||||
)
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""POST /api/modeling/predict 的请求体"""
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="已训练模型的 ID(由 /api/modeling/train 返回或 /api/modeling/models 列出)",
|
||||
)
|
||||
input_zarr_path: str = Field(
|
||||
...,
|
||||
description="待推断影像的 zarr 路径",
|
||||
examples=["./data/scene.zarr"],
|
||||
)
|
||||
output_zarr_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"输出 zarr 路径, 缺省时由后端按规则生成 "
|
||||
"(如 ./data/{model_id}_{input_stem}_pred.zarr)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
|
||||
"""提交训练/推断任务后立即返回的响应"""
|
||||
|
||||
task_id: str
|
||||
status: str # 一定是 PENDING
|
||||
kind: str # "train" / "predict", 便于前端识别任务类型
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""单个模型的元信息(GET /api/modeling/models 的元素)"""
|
||||
|
||||
model_id: str
|
||||
model_type: str
|
||||
target: str
|
||||
params: Dict[str, Any]
|
||||
path: str # joblib 文件路径
|
||||
created_at: str
|
||||
train_task_id: str # 产生此模型的那个训练任务的 ID
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
|
||||
"""GET /api/modeling/models 的响应"""
|
||||
|
||||
models: List[ModelInfo]
|
||||
count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_id -> ModelInfo 字典
|
||||
# 读多写少, 用一个普通 dict 足够(CPython GIL 兜底)。
|
||||
# 写时(训练完成时)只发生一次, 无并发风险。
|
||||
# ---------------------------------------------------------------------------
|
||||
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _register_model(record: Dict[str, Any]) -> None:
|
||||
"""将训练完成的模型登记到内存注册表。"""
|
||||
_MODEL_REGISTRY[record["model_id"]] = record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 训练管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设计要点 (与推断管线一致):
|
||||
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
|
||||
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
|
||||
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# model_type (大写字符串) -> sklearn 估计器类
|
||||
# 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传)
|
||||
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
|
||||
"RF": RandomForestRegressor,
|
||||
"SVR": SVR,
|
||||
"LinearRegression": LinearRegression,
|
||||
"KNN": KNeighborsRegressor,
|
||||
"PLS": PLSRegression,
|
||||
}
|
||||
|
||||
|
||||
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
|
||||
"""
|
||||
模型工厂: 按 model_type 选 sklearn 类, 用 StandardScaler + 估计器构造 Pipeline。
|
||||
|
||||
与 OpenClaw 不同之处: 把 scaler 放进 Pipeline 第一步,
|
||||
推断时直接 pipeline.predict(X) 即可, scaler 参数与训练时严格一致。
|
||||
"""
|
||||
model_cls = _MODEL_CLASS_REGISTRY.get(model_type)
|
||||
if model_cls is None:
|
||||
raise ValueError(
|
||||
f"不支持的 model_type='{model_type}', "
|
||||
f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
|
||||
)
|
||||
estimator = model_cls(**(params or {}))
|
||||
return Pipeline([("scaler", StandardScaler()), ("model", estimator)])
|
||||
|
||||
|
||||
def _load_train_df(csv_path: str) -> pd.DataFrame:
|
||||
"""
|
||||
读 CSV 训练表, 规整空串 / 空白 / NULL 等为 NaN。
|
||||
|
||||
沿用 OpenClaw modeling_batch.load_data_batch 的读取策略:
|
||||
na_values 显式列举 + 正则二次清理 (防 cell 内出现 " " 等纯空白)。
|
||||
"""
|
||||
try:
|
||||
df = pd.read_csv(
|
||||
csv_path,
|
||||
na_values=["", " ", "NaN", "nan", "NULL", "null"],
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
|
||||
except pd.errors.EmptyDataError as exc:
|
||||
raise ValueError(f"训练数据文件为空: {csv_path}") from exc
|
||||
# 二次清理: 残留的纯空白 cell
|
||||
df = df.replace(r"^\s*$", np.nan, regex=True)
|
||||
return df
|
||||
|
||||
|
||||
def _resolve_feature_start(
|
||||
df: pd.DataFrame,
|
||||
feature_start: Union[int, str],
|
||||
) -> int:
|
||||
"""
|
||||
将 feature_start (int 索引 / str 列名) 统一解析为 int 列索引。
|
||||
|
||||
与 OpenClaw modeling_batch.load_data_batch / load_data_single 一致:
|
||||
str 走 columns.get_loc, int 直接返回。
|
||||
"""
|
||||
if isinstance(feature_start, str):
|
||||
if feature_start not in df.columns:
|
||||
raise ValueError(
|
||||
f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
|
||||
)
|
||||
return int(df.columns.get_loc(feature_start))
|
||||
return int(feature_start)
|
||||
|
||||
|
||||
def _run_train_sync(
|
||||
model_type: str,
|
||||
target: str,
|
||||
train_data_path: str,
|
||||
feature_start: Union[int, str],
|
||||
params: Optional[Dict[str, Any]],
|
||||
output_model_path: Path,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
完整同步训练流程 (由 execute_train_task 在线程池内调用):
|
||||
|
||||
pd.read_csv -> 目标列 dropna -> 切特征 -> train_test_split(80/20)
|
||||
-> Pipeline(StandardScaler + model).fit -> 评估 test_r2/rmse/mae
|
||||
-> joblib.dump({model, metadata}, output_model_path)
|
||||
|
||||
Returns:
|
||||
metadata 字典, 含 test_r2 / test_rmse / test_mae / n_features 等,
|
||||
调用方负责写回 TASK_STORE 和 _MODEL_REGISTRY。
|
||||
|
||||
注: 旧版 SPXY / KS 划分留作未来扩展 (params.split_method 控制),
|
||||
当前固定 random + test_size=0.2 + random_state=42。
|
||||
"""
|
||||
df = _load_train_df(train_data_path)
|
||||
|
||||
if target not in df.columns:
|
||||
raise ValueError(
|
||||
f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
|
||||
)
|
||||
|
||||
# 1) 清洗: 仅剔除 target NaN 的行 (与 OpenClaw load_data_single 一致)
|
||||
df = df[df[target].notna()].copy()
|
||||
if df.empty:
|
||||
raise ValueError("target 剔除 NaN 后无样本, 终止训练")
|
||||
|
||||
# 2) 特征切分
|
||||
feature_start_idx = _resolve_feature_start(df, feature_start)
|
||||
feature_columns = list(df.columns[feature_start_idx:])
|
||||
|
||||
X = df.iloc[:, feature_start_idx:].astype(np.float64)
|
||||
y = df[target].astype(np.float64).values
|
||||
|
||||
# 3) 划分 (固定 random, 未来扩展 spxy/ks)
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X.values,
|
||||
y,
|
||||
test_size=0.2,
|
||||
random_state=42,
|
||||
)
|
||||
|
||||
# 4) 构造 Pipeline + 训练
|
||||
pipeline = _get_model_pipeline(model_type, params)
|
||||
pipeline.fit(X_train, y_train)
|
||||
|
||||
# 5) 测试集与训练集评估
|
||||
y_pred = pipeline.predict(X_test)
|
||||
test_r2 = float(r2_score(y_test, y_pred))
|
||||
test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
|
||||
test_mae = float(mean_absolute_error(y_test, y_pred))
|
||||
|
||||
y_train_pred = pipeline.predict(X_train)
|
||||
train_r2 = float(r2_score(y_train, y_train_pred))
|
||||
train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
|
||||
train_mae = float(mean_absolute_error(y_train, y_train_pred))
|
||||
|
||||
metadata: Dict[str, Any] = {
|
||||
"model_type": model_type,
|
||||
"target": target,
|
||||
"feature_start": feature_start,
|
||||
"feature_columns": feature_columns,
|
||||
"n_features": int(X.shape[1]),
|
||||
"n_samples": int(X.shape[0]),
|
||||
"train_size": int(X_train.shape[0]),
|
||||
"test_size": int(X_test.shape[0]),
|
||||
"params": dict(params or {}),
|
||||
"test_r2": test_r2,
|
||||
"test_rmse": test_rmse,
|
||||
"test_mae": test_mae,
|
||||
"train_r2": train_r2,
|
||||
"train_rmse": train_rmse,
|
||||
"train_mae": train_mae,
|
||||
"split_method": "random",
|
||||
"trained_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# 7) 持久化 (目录可能不存在, 顺手建)
|
||||
output_model_path = Path(output_model_path)
|
||||
output_model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(
|
||||
{"model": pipeline, "metadata": metadata},
|
||||
output_model_path,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 推断管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设计要点:
|
||||
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
|
||||
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
|
||||
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
|
||||
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
|
||||
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
|
||||
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
|
||||
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
|
||||
"""
|
||||
构造一个 Dummy 随机森林回归器。
|
||||
|
||||
用途:
|
||||
1) 真实 joblib 文件不存在时的连通性测试
|
||||
2) 训练骨架尚未接入真实数据时的占位推断
|
||||
"""
|
||||
rng = np.random.default_rng(42)
|
||||
X = rng.random((200, n_features))
|
||||
y = rng.random(200)
|
||||
model = RandomForestRegressor(
|
||||
n_estimators=10, max_depth=5, random_state=0, n_jobs=1
|
||||
)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
|
||||
def _load_model(path: str, n_features: int) -> Any:
|
||||
"""
|
||||
加载训练好的 sklearn 模型, 失败时降级 Dummy。
|
||||
|
||||
优先级:
|
||||
1) path 存在且 joblib.load 成功 -> 返回真实模型
|
||||
2) 否则 -> 降级为 Dummy 随机森林 (n_features 必须指定)
|
||||
"""
|
||||
p = Path(path)
|
||||
if p.is_file() and p.stat().st_size > 0:
|
||||
try:
|
||||
print(f"[model] 从磁盘加载: {path}")
|
||||
return joblib.load(path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
|
||||
print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
|
||||
return _make_dummy_model(n_features)
|
||||
|
||||
|
||||
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
|
||||
"""
|
||||
单个 dask chunk 的推断函数 (xr.apply_ufunc 会自动调度调用)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spectra_3d : np.ndarray
|
||||
形状 (y_chunk, x_chunk, n_bands)。
|
||||
此形状由 input_core_dims=[["band"]] 决定:
|
||||
xarray 会把 band 维移到最后一轴, 然后按 (y, x) 的 chunk 切分调用本函数。
|
||||
model : 已 fit 好的 sklearn 估计器
|
||||
接受 (n_samples, n_features) 输入, 返回 (n_samples,) 预测。
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
形状 (y_chunk, x_chunk), dtype float32 的标量预测图。
|
||||
"""
|
||||
yc, xc, nb = spectra_3d.shape
|
||||
# 2D 化: 每个像素一行光谱
|
||||
flat = spectra_3d.reshape(yc * xc, nb)
|
||||
# sklearn 风格的批量预测
|
||||
pred = model.predict(flat)
|
||||
# 还原为 2D 空间图, 强制 float32 节约一半内存
|
||||
return pred.reshape(yc, xc).astype(np.float32, copy=False)
|
||||
|
||||
|
||||
def _run_predict_sync(
|
||||
model: Any,
|
||||
model_id: str,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
同步推断主流程 (被 asyncio.to_thread 调用)。
|
||||
|
||||
流程:
|
||||
1) xr.open_zarr 延迟打开 (dask 数组, 不一次性读入内存)
|
||||
2) NaN -> 0 清洗 (model.predict 不接受 NaN)
|
||||
3) xr.apply_ufunc 沿 (y, x) 逐 chunk 调 _predict_block
|
||||
4) 非水域置 NaN (zarr 支持 float NaN)
|
||||
5) to_zarr 触发整图计算 + 流式写出
|
||||
"""
|
||||
# 1. 延迟打开输入 (关键: Dask 不一次性读入内存)
|
||||
ds = xr.open_zarr(input_zarr_path, chunks="auto")
|
||||
if "reflectance" not in ds.data_vars:
|
||||
raise KeyError(
|
||||
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
|
||||
)
|
||||
|
||||
reflectance = ds["reflectance"] # dims: (y, x, band)
|
||||
n_bands = reflectance.sizes["band"]
|
||||
|
||||
# 2. 水域掩膜 (与去耀斑算法同约定)
|
||||
if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
|
||||
water_mask = ds["water_mask"].astype(bool)
|
||||
else:
|
||||
water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)
|
||||
|
||||
# 3. NaN 清洗: 填充 0 (model.predict 不接受 NaN)
|
||||
refl_clean = reflectance.fillna(0.0)
|
||||
|
||||
# 4. 核心: 用 apply_ufunc 把 model.predict 沿 (y, x) 应用
|
||||
# dask="parallelized" 让每个 (y_chunk, x_chunk, band) chunk
|
||||
# 独立调 _predict_block, 任意时刻内存中只有若干个 chunk。
|
||||
prediction: xr.DataArray = xr.apply_ufunc(
|
||||
_predict_block,
|
||||
refl_clean,
|
||||
kwargs={"model": model},
|
||||
input_core_dims=[["band"]],
|
||||
output_core_dims=[[]],
|
||||
dask="parallelized",
|
||||
output_dtypes=[np.float32],
|
||||
dask_gufunc_kwargs={"allow_rechunk": True},
|
||||
vectorize=False,
|
||||
)
|
||||
|
||||
# 5. 非水域置 NaN (zarr 支持 float NaN, 便于后续可视化/掩膜分析)
|
||||
prediction = prediction.where(water_mask, np.nan)
|
||||
|
||||
# 6. 包装为 Dataset 并流式写出
|
||||
out = xr.Dataset(
|
||||
{"prediction": prediction},
|
||||
attrs={
|
||||
"model_id": model_id,
|
||||
"input_zarr_path": input_zarr_path,
|
||||
"n_bands": n_bands,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
# 保留 y/x 坐标
|
||||
out = out.assign_coords(y=ds["y"], x=ds["x"])
|
||||
|
||||
# to_zarr + compute=True 触发整图 dask 图求值
|
||||
# 中间会按 chunk 逐块调度到线程池, 内存峰值 ≈ 1 个 chunk 的体量
|
||||
out.to_zarr(output_zarr_path, mode="w", compute=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 后台任务执行器
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_train_task(
|
||||
task_id: str,
|
||||
model_type: str,
|
||||
target: str,
|
||||
train_data_path: str,
|
||||
feature_start: Union[int, str],
|
||||
params: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
训练任务后台执行器(已接入真实 sklearn 训练流程)。
|
||||
|
||||
流程:
|
||||
1) get_task 校验任务存在
|
||||
2) update_task(PROCESSING)
|
||||
3) 生成 model_id / model_path
|
||||
4) asyncio.to_thread 派发 _run_train_sync 到默认线程池
|
||||
5) 成功 -> _register_model + update_task(SUCCESS, 附 test_r2/rmse/mae)
|
||||
6) 失败 -> update_task(FAILED, 附 error + traceback)
|
||||
"""
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 训练任务不存在, 跳过")
|
||||
return
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(
|
||||
f"[{task_id}] 开始训练 model_type={model_type} target={target} "
|
||||
f"train_data_path={train_data_path} feature_start={feature_start}"
|
||||
)
|
||||
|
||||
# model_id 用 uuid4 前 12 位 (8 位易撞, 12 位兼顾可读性)
|
||||
model_id = f"model_{uuid.uuid4().hex[:12]}"
|
||||
model_path = Path(f"./data/models/{model_id}.joblib")
|
||||
|
||||
try:
|
||||
# 同步 sklearn / pandas 训练丢到默认线程池, 不阻塞 event loop
|
||||
metadata = await asyncio.to_thread(
|
||||
_run_train_sync,
|
||||
model_type,
|
||||
target,
|
||||
train_data_path,
|
||||
feature_start,
|
||||
params,
|
||||
model_path,
|
||||
)
|
||||
|
||||
# 登记到内存注册表 (供 /predict 查 model_id)
|
||||
_register_model(
|
||||
{
|
||||
"model_id": model_id,
|
||||
"model_type": model_type,
|
||||
"target": target,
|
||||
"params": dict(params or {}),
|
||||
"path": str(model_path),
|
||||
"feature_start": feature_start,
|
||||
"n_features": metadata["n_features"],
|
||||
"test_r2": metadata["test_r2"],
|
||||
"test_rmse": metadata["test_rmse"],
|
||||
"test_mae": metadata["test_mae"],
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"train_task_id": task_id,
|
||||
}
|
||||
)
|
||||
|
||||
# 把训练指标写回任务记录, 前端轮询时可直接看
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
model_id=model_id,
|
||||
model_path=str(model_path),
|
||||
test_r2=metadata["test_r2"],
|
||||
test_rmse=metadata["test_rmse"],
|
||||
test_mae=metadata["test_mae"],
|
||||
n_features=metadata["n_features"],
|
||||
n_samples=metadata["n_samples"],
|
||||
error=None,
|
||||
traceback=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(
|
||||
f"[{task_id}] 训练完成 -> model_id={model_id} "
|
||||
f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
|
||||
)
|
||||
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# 失败时 model_path 不一定有产物, 显式置 None 方便前端判断
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
model_id=None,
|
||||
model_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=traceback.format_exc(),
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
async def execute_predict_task(
|
||||
task_id: str,
|
||||
model_id: str,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
推断任务后台执行器(真实实现版)。
|
||||
|
||||
OOM 防护策略:
|
||||
- xr.open_zarr(..., chunks="auto") 延迟打开, 整图不一次性读入内存
|
||||
- xr.apply_ufunc(..., dask="parallelized") 把影像按 chunk 切分
|
||||
- 每个 chunk 内部 reshape 成 2D, 调 model.predict, 再 reshape 回 2D
|
||||
- 任意时刻内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) chunk 的体量
|
||||
- 整图完成计算后再 to_zarr(compute=True) 流式写出
|
||||
"""
|
||||
record = await get_task(task_id)
|
||||
if record is None:
|
||||
print(f"[{task_id}] 推断任务不存在, 跳过")
|
||||
return
|
||||
|
||||
# 1. 校验 model_id 是否已注册 (避免在后台任务里报模糊错误)
|
||||
model_meta = _MODEL_REGISTRY.get(model_id)
|
||||
if model_meta is None:
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
error=f"model_id 不存在: {model_id}",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
|
||||
return
|
||||
|
||||
# 2. 自动生成 output_zarr_path (若未提供)
|
||||
if output_zarr_path is None:
|
||||
stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
|
||||
stem = stem.replace(".zarr", "")
|
||||
output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="PROCESSING",
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")
|
||||
|
||||
try:
|
||||
# 3. 探测波段数 (用于 Dummy 模型适配)
|
||||
# 这里只读 zarr 元数据 (.zarray 的 shape), 不读真实数据
|
||||
ds_probe = xr.open_zarr(input_zarr_path, chunks="auto")
|
||||
if "reflectance" not in ds_probe.data_vars:
|
||||
raise KeyError(
|
||||
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
|
||||
)
|
||||
n_bands = ds_probe["reflectance"].sizes["band"]
|
||||
ds_probe.close()
|
||||
|
||||
# 4. 加载模型 (真实文件优先, Dummy 兜底)
|
||||
model = _load_model(model_meta["path"], n_features=n_bands)
|
||||
|
||||
# 5. 包装同步执行, 丢到线程池, 事件循环不阻塞
|
||||
await asyncio.to_thread(
|
||||
_run_predict_sync,
|
||||
model,
|
||||
model_id,
|
||||
input_zarr_path,
|
||||
output_zarr_path,
|
||||
)
|
||||
|
||||
await update_task(
|
||||
task_id,
|
||||
status="SUCCESS",
|
||||
output_zarr_path=output_zarr_path,
|
||||
model_id=model_id,
|
||||
error=None,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")
|
||||
|
||||
except Exception as exc: # noqa: BLE001
|
||||
tb_text = traceback.format_exc()
|
||||
await update_task(
|
||||
task_id,
|
||||
status="FAILED",
|
||||
output_zarr_path=None,
|
||||
error=f"{type(exc).__name__}: {exc}",
|
||||
traceback=tb_text,
|
||||
updated_at=datetime.now().isoformat(),
|
||||
)
|
||||
print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
|
||||
print(tb_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/train
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/train", response_model=TaskAcceptedResponse)
|
||||
async def submit_train(
|
||||
payload: TrainRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
|
||||
"""提交一个模型训练任务, 立即返回 task_id。"""
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"kind": "train",
|
||||
"model_type": payload.model_type,
|
||||
"target": payload.target,
|
||||
"train_data_path": payload.train_data_path,
|
||||
"feature_start": payload.feature_start,
|
||||
"params": payload.params,
|
||||
"status": "PENDING",
|
||||
"model_id": None,
|
||||
"model_path": None,
|
||||
"test_r2": None,
|
||||
"test_rmse": None,
|
||||
"test_mae": None,
|
||||
"n_features": None,
|
||||
"n_samples": None,
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
background_tasks.add_task(
|
||||
execute_train_task,
|
||||
task_id,
|
||||
payload.model_type,
|
||||
payload.target,
|
||||
payload.train_data_path,
|
||||
payload.feature_start,
|
||||
payload.params,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "train"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/modeling/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/models", response_model=ModelListResponse)
|
||||
async def list_trained_models() -> Dict[str, Any]:
|
||||
"""
|
||||
列出已训练好的模型。
|
||||
|
||||
未来实现: 从 ./data/models/*.joblib 扫描元信息,
|
||||
当前直接从内存 _MODEL_REGISTRY 读。
|
||||
"""
|
||||
models = list(_MODEL_REGISTRY.values())
|
||||
# 按 created_at 倒序, 最新训练的在前
|
||||
models.sort(key=lambda m: m.get("created_at", ""), reverse=True)
|
||||
return {"models": models, "count": len(models)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/predict
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/predict", response_model=TaskAcceptedResponse)
|
||||
async def submit_predict(
|
||||
payload: PredictRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
) -> Dict[str, Any]:
|
||||
"""提交一个模型推断任务, 立即返回 task_id。"""
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
await set_task(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"kind": "predict",
|
||||
"model_id": payload.model_id,
|
||||
"input_zarr_path": payload.input_zarr_path,
|
||||
"output_zarr_path": payload.output_zarr_path,
|
||||
"status": "PENDING",
|
||||
"error": None,
|
||||
"traceback": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
background_tasks.add_task(
|
||||
execute_predict_task,
|
||||
task_id,
|
||||
payload.model_id,
|
||||
payload.input_zarr_path,
|
||||
payload.output_zarr_path,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
|
||||
Reference in New Issue
Block a user