848 lines
30 KiB
Python
848 lines
30 KiB
Python
"""
|
||
app/api/modeling.py
|
||
===================
|
||
|
||
机器学习与水质反演相关的 API 路由。
|
||
|
||
接口(最终路径, 挂载后):
|
||
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
|
||
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
|
||
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
|
||
|
||
设计要点
|
||
--------
|
||
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
|
||
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
|
||
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
|
||
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
|
||
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor)
|
||
* xr.open_zarr 延迟打开影像, NaN 填 0
|
||
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
|
||
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
|
||
- /train 已接入真实 sklearn + pandas 训练管线:
|
||
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
|
||
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
|
||
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
|
||
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
|
||
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
|
||
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
|
||
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
|
||
"""
|
||
|
||
import asyncio
|
||
import shutil
|
||
import traceback
|
||
import uuid
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
import joblib
|
||
import numpy as np
|
||
import pandas as pd
|
||
import xarray as xr
|
||
from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, File
|
||
from pydantic import BaseModel, Field
|
||
from sklearn.cross_decomposition import PLSRegression
|
||
from sklearn.ensemble import RandomForestRegressor
|
||
from sklearn.linear_model import LinearRegression
|
||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||
from sklearn.model_selection import train_test_split
|
||
from sklearn.neighbors import KNeighborsRegressor
|
||
from sklearn.pipeline import Pipeline
|
||
from sklearn.preprocessing import StandardScaler
|
||
from sklearn.svm import SVR
|
||
|
||
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
|
||
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint)
|
||
from app.core.task_store import get_task, set_task, update_task
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 路由实例
|
||
# ---------------------------------------------------------------------------
|
||
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
|
||
# 最终完整路径由 main.py 挂载时再补 /api。
|
||
# ---------------------------------------------------------------------------
|
||
# Router for the modeling endpoints. Routes below use short paths such as
# /train; the surrounding comments state that main.py adds the /api prefix.
router = APIRouter(
    prefix="/modeling",
    tags=["modeling"],
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 数据模型
|
||
# ---------------------------------------------------------------------------
|
||
class TrainRequest(BaseModel):
    """Request body for POST /api/modeling/train.

    The field descriptions are user-facing API documentation, so they are kept
    in sync with what the training pipeline in this module actually does: the
    training table is read with ``pd.read_csv`` and ``model_type`` must be one
    of the keys of ``_MODEL_CLASS_REGISTRY``.
    """

    # Key into _MODEL_CLASS_REGISTRY; unknown values make the background task fail.
    model_type: str = Field(
        ...,
        description=(
            "模型类型, 可选 'RF' (随机森林) / 'SVR' (支持向量回归) / "
            "'LinearRegression' (线性回归) / 'KNN' (K 近邻) / 'PLS' (偏最小二乘)"
        ),
        examples=["RF", "SVR"],
    )
    # Name of the label column in the training table.
    target: str = Field(
        ...,
        description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
        examples=["Chl-a", "TSS", "CDOM"],
    )
    # Path handed to pd.read_csv by the training pipeline (a CSV file, not zarr).
    train_data_path: str = Field(
        ...,
        description="训练数据集的 CSV 路径(包含 target 标签列与特征列)",
        examples=["./data/train.csv"],
    )
    # Either a positional column index or a column name marking the first feature.
    feature_start: Union[int, str] = Field(
        default=4,
        description=(
            "特征列起始位置. 表格布局假定为 "
            "[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
            "可传 int 列索引(如 4)或 str 列名(如 '374.285' 波长起点)。"
            "默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
        ),
        examples=[4, "374.285"],
    )
    # Keyword arguments forwarded to the estimator constructor.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
        examples=[{"n_estimators": 100, "max_depth": 20}],
    )
|
||
|
||
|
||
class PredictRequest(BaseModel):
    """Request body for POST /api/modeling/predict."""

    # Identifier of a previously trained model.
    model_id: str = Field(
        ...,
        description="已训练模型的 ID(由 /api/modeling/train 返回或 /api/modeling/models 列出)",
    )
    # Location of the zarr store holding the scene to run inference on.
    input_zarr_path: str = Field(
        ...,
        description="待推断影像的 zarr 路径",
        examples=["./data/scene.zarr"],
    )
    # Optional destination; when omitted the background task derives one.
    output_zarr_path: Optional[str] = Field(
        None,
        description="输出 zarr 路径, 缺省时由后端按规则生成 (如 ./data/{model_id}_{input_stem}_pred.zarr)",
    )
|
||
|
||
|
||
class TaskAcceptedResponse(BaseModel):
    """Acknowledgement returned right after a train/predict task is queued."""

    # Identifier the client uses to look the task up later.
    task_id: str
    # Task state at submission time; the submit handlers in this module send "PENDING".
    status: str
    # "train" or "predict", so the frontend can tell the task types apart.
    kind: str
|
||
|
||
|
||
class ModelInfo(BaseModel):
    """Metadata describing one trained model (element of GET /api/modeling/models)."""

    # Unique identifier of the model.
    model_id: str
    # Estimator family the model was trained with.
    model_type: str
    # Name of the target column the model predicts.
    target: str
    # Hyperparameters supplied at training time.
    params: Dict[str, Any]
    # Location of the persisted joblib file.
    path: str
    # Creation timestamp, stored as a string.
    created_at: str
    # Identifier of the training task that produced this model.
    train_task_id: str
|
||
|
||
|
||
class ModelListResponse(BaseModel):
    """Response body of GET /api/modeling/models."""

    # Metadata entries, one per known model.
    models: List[ModelInfo]
    # Number of entries in ``models``.
    count: int
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
|
||
# ---------------------------------------------------------------------------
|
||
# model_id -> ModelInfo 字典
|
||
# 读多写少, 用一个普通 dict 足够(CPython GIL 兜底)。
|
||
# 写时(训练完成时)只发生一次, 无并发风险。
|
||
# ---------------------------------------------------------------------------
|
||
# In-memory registry: model_id -> metadata dict for that model.
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = dict()
|
||
|
||
|
||
def _register_model(record: Dict[str, Any]) -> None:
    """Store a trained model's metadata in the in-memory registry.

    The record is filed under its ``"model_id"`` entry; an existing entry with
    the same id is replaced.
    """
    model_key = record["model_id"]
    _MODEL_REGISTRY[model_key] = record
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 训练管线的模块级辅助函数
|
||
# ---------------------------------------------------------------------------
|
||
# 设计要点 (与推断管线一致):
|
||
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
|
||
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
|
||
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
|
||
# ---------------------------------------------------------------------------
|
||
|
||
# model_type (大写字符串) -> sklearn 估计器类
|
||
# 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传)
|
||
# Maps the model_type string accepted by the API to a scikit-learn estimator
# class. Only the class is kept here; hyperparameters arrive via ``params``.
_MODEL_CLASS_REGISTRY: Dict[str, type] = dict(
    RF=RandomForestRegressor,
    SVR=SVR,
    LinearRegression=LinearRegression,
    KNN=KNeighborsRegressor,
    PLS=PLSRegression,
)
|
||
|
||
|
||
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
    """Build a ``StandardScaler -> estimator`` pipeline for ``model_type``.

    Keeping the scaler inside the pipeline means a later ``pipeline.predict(X)``
    applies the scaling that was fitted during training.

    Args:
        model_type: Key into ``_MODEL_CLASS_REGISTRY``.
        params: Keyword arguments for the estimator constructor; ``None`` or an
            empty dict means the estimator's defaults.

    Returns:
        An unfitted pipeline with steps named ``"scaler"`` and ``"model"``.

    Raises:
        ValueError: If ``model_type`` is not a registered key.
    """
    if model_type not in _MODEL_CLASS_REGISTRY:
        supported = sorted(_MODEL_CLASS_REGISTRY.keys())
        raise ValueError(
            f"不支持的 model_type='{model_type}', "
            f"可选: {supported}"
        )

    hyperparams = params if params else {}
    estimator = _MODEL_CLASS_REGISTRY[model_type](**hyperparams)
    steps = [("scaler", StandardScaler()), ("model", estimator)]
    return Pipeline(steps)
|
||
|
||
|
||
def _load_train_df(csv_path: str) -> pd.DataFrame:
    """Read the CSV training table and normalise blank-like cells to NaN.

    Cleaning happens in two passes: ``read_csv`` is given an explicit list of
    missing-value markers, then any cell that is empty or whitespace-only is
    replaced with NaN via a regex.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The loaded table with missing markers converted to NaN.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If pandas reports the file as empty.
    """
    missing_markers = ["", " ", "NaN", "nan", "NULL", "null"]
    try:
        table = pd.read_csv(csv_path, na_values=missing_markers)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
    except pd.errors.EmptyDataError as exc:
        raise ValueError(f"训练数据文件为空: {csv_path}") from exc

    # Second pass: cells holding only whitespace are not caught by na_values.
    blank_cell_pattern = r"^\s*$"
    return table.replace(blank_cell_pattern, np.nan, regex=True)
|
||
|
||
|
||
def _resolve_feature_start(
    df: pd.DataFrame,
    feature_start: Union[int, str],
) -> int:
    """Resolve ``feature_start`` to a positional column index.

    Args:
        df: Training table whose columns are being split.
        feature_start: Either a column name (looked up with
            ``columns.get_loc``) or an integer position.

    Returns:
        The integer position of the first feature column. Negative integers
        are returned unchanged, so callers that slice with the result get
        Python's usual negative-index semantics.

    Raises:
        ValueError: If a column name is not present, or if an integer position
            lies at or beyond the last column (which would leave no feature
            columns at all).
    """
    if isinstance(feature_start, str):
        if feature_start not in df.columns:
            raise ValueError(
                f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
            )
        return int(df.columns.get_loc(feature_start))

    start_idx = int(feature_start)
    n_columns = len(df.columns)
    # Fail here with a clear message instead of letting an empty feature
    # matrix reach the estimator and surface as an obscure fitting error.
    if start_idx >= n_columns:
        raise ValueError(
            f"feature_start={start_idx} 超出 CSV 列范围 (共 {n_columns} 列), 无可用特征列"
        )
    return start_idx
|
||
|
||
|
||
def _run_train_sync(
    model_type: str,
    target: str,
    train_data_path: str,
    feature_start: Union[int, str],
    params: Optional[Dict[str, Any]],
    output_model_path: Path,
) -> Dict[str, Any]:
    """Run the whole training flow synchronously and persist the result.

    Steps: load the CSV, drop rows whose target is NaN, select the feature
    columns, split 80/20 with a fixed seed, fit ``StandardScaler + estimator``,
    evaluate on both splits, then ``joblib.dump`` a ``{"model", "metadata"}``
    bundle to ``output_model_path``.

    Args:
        model_type: Key understood by ``_get_model_pipeline``.
        target: Name of the label column in the CSV.
        train_data_path: Path of the CSV training table.
        feature_start: Position or name of the first feature column.
        params: Estimator hyperparameters (may be ``None``).
        output_model_path: Destination of the joblib bundle; parent directories
            are created when missing.

    Returns:
        The metadata dict that is also stored in the bundle (split sizes,
        feature names, train/test R2, RMSE and MAE, ...).

    Raises:
        ValueError: If the target column is missing, no sample remains after
            dropping NaN targets, or no feature column is left to train on.
            Errors raised by the helpers and by scikit-learn propagate as-is.
    """
    df = _load_train_df(train_data_path)

    if target not in df.columns:
        raise ValueError(
            f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
        )

    # 1) Cleaning: only rows with a NaN target are removed.
    df = df[df[target].notna()].copy()
    if df.empty:
        raise ValueError("target 剔除 NaN 后无样本, 终止训练")

    # 2) Feature selection. The target column is excluded even when it lies at
    #    or after feature_start; otherwise the label would leak into X and the
    #    reported metrics would be meaningless.
    feature_start_idx = _resolve_feature_start(df, feature_start)
    all_columns = list(df.columns)
    feature_positions = [
        pos
        for pos in range(len(all_columns))[feature_start_idx:]
        if all_columns[pos] != target
    ]
    if not feature_positions:
        raise ValueError(
            f"feature_start={feature_start!r} 之后没有可用的特征列, 终止训练"
        )
    feature_columns = [all_columns[pos] for pos in feature_positions]

    X = df.iloc[:, feature_positions].astype(np.float64)
    y = df[target].astype(np.float64).values

    # 3) Fixed random split (test_size=0.2, random_state=42).
    X_train, X_test, y_train, y_test = train_test_split(
        X.values,
        y,
        test_size=0.2,
        random_state=42,
    )

    # 4) Build the pipeline and fit it on the training split.
    pipeline = _get_model_pipeline(model_type, params)
    pipeline.fit(X_train, y_train)

    # 5) Evaluate on the held-out split and on the training split.
    y_pred = pipeline.predict(X_test)
    test_r2 = float(r2_score(y_test, y_pred))
    test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
    test_mae = float(mean_absolute_error(y_test, y_pred))

    y_train_pred = pipeline.predict(X_train)
    train_r2 = float(r2_score(y_train, y_train_pred))
    train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
    train_mae = float(mean_absolute_error(y_train, y_train_pred))

    metadata: Dict[str, Any] = {
        "model_type": model_type,
        "target": target,
        "feature_start": feature_start,
        "feature_columns": feature_columns,
        "n_features": int(X.shape[1]),
        "n_samples": int(X.shape[0]),
        "train_size": int(X_train.shape[0]),
        "test_size": int(X_test.shape[0]),
        "params": dict(params or {}),
        "test_r2": test_r2,
        "test_rmse": test_rmse,
        "test_mae": test_mae,
        "train_r2": train_r2,
        "train_rmse": train_rmse,
        "train_mae": train_mae,
        "split_method": "random",
        "trained_at": datetime.now().isoformat(),
    }

    # 6) Persist the bundle; the target directory may not exist yet.
    output_model_path = Path(output_model_path)
    output_model_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(
        {"model": pipeline, "metadata": metadata},
        output_model_path,
    )

    return metadata
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 推断管线的模块级辅助函数
|
||
# ---------------------------------------------------------------------------
|
||
# 设计要点:
|
||
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
|
||
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
|
||
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
|
||
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
|
||
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
|
||
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
|
||
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
    """Fit a small random forest on seeded random data.

    It is a placeholder used when no real model file can be loaded, so the
    inference path can still be exercised end to end. Its predictions carry no
    meaning.

    Args:
        n_features: Number of input features the dummy model must accept.

    Returns:
        A fitted ``RandomForestRegressor`` trained on 200 random samples.
    """
    generator = np.random.default_rng(42)
    # Draw order (features first, then labels) is kept so the fitted model is
    # reproducible for a given n_features.
    features = generator.random((200, n_features))
    labels = generator.random(200)

    forest_options = {
        "n_estimators": 10,
        "max_depth": 5,
        "random_state": 0,
        "n_jobs": 1,
    }
    forest = RandomForestRegressor(**forest_options)
    forest.fit(features, labels)
    return forest
|
||
|
||
|
||
def _unwrap_model_bundle(loaded: Any) -> Any:
    """Return the estimator held in a training bundle, or ``loaded`` itself.

    The training pipeline in this module dumps ``{"model": ..., "metadata": ...}``;
    a file may also contain a bare estimator, which is passed through.
    """
    if isinstance(loaded, dict) and "model" in loaded:
        return loaded["model"]
    return loaded


def _load_model(path: str, n_features: int) -> Any:
    """Load a trained model from disk, falling back to a dummy model.

    Args:
        path: Location of the joblib file.
        n_features: Feature count used to build the dummy model when the real
            one cannot be loaded.

    Returns:
        An object exposing ``predict``: the estimator stored in the file
        (unwrapped from the ``{"model", "metadata"}`` bundle when present), or
        a dummy random forest if the file is missing, empty, or fails to load.
    """
    p = Path(path)
    if p.is_file() and p.stat().st_size > 0:
        try:
            print(f"[model] 从磁盘加载: {path}")
            # NOTE(security): joblib.load unpickles arbitrary objects. Files in
            # the models directory can come from the upload endpoint, so only
            # trusted users should be allowed to upload models.
            return _unwrap_model_bundle(joblib.load(path))
        except Exception as exc:  # noqa: BLE001
            # Deliberate best-effort: a broken file degrades to the dummy model.
            print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
    else:
        print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
    return _make_dummy_model(n_features)
|
||
|
||
|
||
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
    """Run ``model.predict`` over one 3-D block of spectra.

    Parameters
    ----------
    spectra_3d : np.ndarray
        Array with exactly three axes. The last axis is treated as the feature
        axis (bands) and the first two as the spatial grid.
    model : Any
        Object whose ``predict`` accepts a 2-D ``(n_samples, n_features)``
        array and returns one value per sample.

    Returns
    -------
    np.ndarray
        Predictions reshaped to the block's first two axes, as float32.
    """
    n_rows, n_cols, n_bands = spectra_3d.shape

    # One row per pixel, one column per band.
    pixel_table = spectra_3d.reshape(n_rows * n_cols, n_bands)
    flat_prediction = model.predict(pixel_table)

    # Back to the spatial grid; float32 halves the memory of float64.
    grid = flat_prediction.reshape(n_rows, n_cols)
    return grid.astype(np.float32, copy=False)
|
||
|
||
|
||
def _run_predict_sync(
    model: Any,
    model_id: str,
    input_zarr_path: str,
    output_zarr_path: str,
) -> None:
    """Run inference over a zarr scene and write the prediction map.

    Flow:
        1) open the input lazily with ``xr.open_zarr``;
        2) fill NaN reflectance with 0 before predicting;
        3) apply ``_predict_block`` with ``xr.apply_ufunc`` over the ``band``
           core dimension;
        4) set pixels outside the water mask to NaN;
        5) write the result with ``to_zarr(mode="w", compute=True)``.

    Args:
        model: Fitted estimator passed through to ``_predict_block``.
        model_id: Stored in the output attributes for traceability.
        input_zarr_path: Zarr store containing a ``reflectance`` variable.
        output_zarr_path: Destination store; existing content is overwritten.

    Raises:
        KeyError: If the input has no ``reflectance`` variable.
    """
    # The context manager closes the input store on success and on error. The
    # write below must stay inside it because the computation is lazy and
    # reads from the store only when to_zarr triggers it.
    with xr.open_zarr(input_zarr_path, chunks="auto") as ds:
        if "reflectance" not in ds.data_vars:
            raise KeyError(
                f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
            )

        # NOTE(review): the original author's comment gives dims (y, x, band);
        # only the "band" dimension is actually required by the code below.
        reflectance = ds["reflectance"]
        n_bands = reflectance.sizes["band"]

        # Water mask: use the stored one when present, otherwise treat every
        # pixel as water.
        if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
            water_mask = ds["water_mask"].astype(bool)
        else:
            water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)

        # NaN inputs are replaced with 0 before calling the model.
        refl_clean = reflectance.fillna(0.0)

        # "band" is the core dimension, so it is moved to the last axis and
        # consumed by _predict_block; the remaining dimensions are preserved.
        prediction: xr.DataArray = xr.apply_ufunc(
            _predict_block,
            refl_clean,
            kwargs={"model": model},
            input_core_dims=[["band"]],
            output_core_dims=[[]],
            dask="parallelized",
            output_dtypes=[np.float32],
            dask_gufunc_kwargs={"allow_rechunk": True},
            vectorize=False,
        )

        # Non-water pixels become NaN for later visualisation/masking.
        prediction = prediction.where(water_mask, np.nan)

        out = xr.Dataset(
            {"prediction": prediction},
            attrs={
                "model_id": model_id,
                "input_zarr_path": input_zarr_path,
                "n_bands": n_bands,
                "created_at": datetime.now().isoformat(),
            },
        )
        # Keep the spatial coordinates of the input.
        out = out.assign_coords(y=ds["y"], x=ds["x"])

        # compute=True evaluates the lazy graph while writing.
        out.to_zarr(output_zarr_path, mode="w", compute=True)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 后台任务执行器
|
||
# ---------------------------------------------------------------------------
|
||
async def execute_train_task(
    task_id: str,
    model_type: str,
    target: str,
    train_data_path: str,
    feature_start: Union[int, str],
    params: Dict[str, Any],
) -> None:
    """Background executor for a training task.

    Flow:
        1) return early when the task record does not exist;
        2) mark the task PROCESSING;
        3) generate ``model_id`` and the joblib path;
        4) run ``_run_train_sync`` in a worker thread via ``asyncio.to_thread``;
        5) on success register the model and mark the task SUCCESS with the
           test metrics;
        6) on any exception mark the task FAILED with the error and traceback.
    """
    if await get_task(task_id) is None:
        print(f"[{task_id}] 训练任务不存在, 跳过")
        return

    await update_task(
        task_id,
        status="PROCESSING",
        updated_at=datetime.now().isoformat(),
    )
    print(
        f"[{task_id}] 开始训练 model_type={model_type} target={target} "
        f"train_data_path={train_data_path} feature_start={feature_start}"
    )

    # 12 hex characters of a uuid4 form the model identifier.
    model_id = f"model_{uuid.uuid4().hex[:12]}"
    model_path = Path(f"./data/models/{model_id}.joblib")

    try:
        # Training is synchronous; running it in a thread keeps the event loop free.
        metrics = await asyncio.to_thread(
            _run_train_sync,
            model_type,
            target,
            train_data_path,
            feature_start,
            params,
            model_path,
        )

        # Make the model discoverable by model_id.
        registry_entry = {
            "model_id": model_id,
            "model_type": model_type,
            "target": target,
            "params": dict(params or {}),
            "path": str(model_path),
            "feature_start": feature_start,
            "n_features": metrics["n_features"],
            "test_r2": metrics["test_r2"],
            "test_rmse": metrics["test_rmse"],
            "test_mae": metrics["test_mae"],
            "created_at": datetime.now().isoformat(),
            "train_task_id": task_id,
        }
        _register_model(registry_entry)

        # Copy the metrics onto the task record so polling clients see them.
        await update_task(
            task_id,
            status="SUCCESS",
            model_id=model_id,
            model_path=str(model_path),
            test_r2=metrics["test_r2"],
            test_rmse=metrics["test_rmse"],
            test_mae=metrics["test_mae"],
            n_features=metrics["n_features"],
            n_samples=metrics["n_samples"],
            error=None,
            traceback=None,
            updated_at=datetime.now().isoformat(),
        )
        print(
            f"[{task_id}] 训练完成 -> model_id={model_id} "
            f"test_r2={metrics['test_r2']:.4f} test_rmse={metrics['test_rmse']:.4f}"
        )

    except Exception as exc:  # noqa: BLE001
        # No usable artefact on failure, so the model fields are reset to None.
        failure = f"{type(exc).__name__}: {exc}"
        await update_task(
            task_id,
            status="FAILED",
            model_id=None,
            model_path=None,
            error=failure,
            traceback=traceback.format_exc(),
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 训练失败 -> {failure}")
|
||
|
||
|
||
async def execute_predict_task(
    task_id: str,
    model_id: str,
    input_zarr_path: str,
    output_zarr_path: Optional[str],
) -> None:
    """Background executor for an inference task.

    Flow:
        1) return early when the task record does not exist;
        2) resolve the model file: the in-memory registry first, then
           ``./data/models/{model_id}.joblib`` so that uploaded models and
           models trained before a restart remain usable;
        3) derive ``output_zarr_path`` when the caller did not supply one;
        4) mark the task PROCESSING, probe the band count, load the model and
           run ``_run_predict_sync`` in a worker thread;
        5) mark the task SUCCESS, or FAILED with the error and traceback.
    """
    record = await get_task(task_id)
    if record is None:
        print(f"[{task_id}] 推断任务不存在, 跳过")
        return

    # 1. Resolve the model path up front so a bad model_id fails clearly.
    model_meta = _MODEL_REGISTRY.get(model_id)
    model_path: Optional[str] = None
    if model_meta is not None:
        model_path = model_meta["path"]
    elif model_id and "/" not in model_id and "\\" not in model_id:
        # Path separators are rejected so the id cannot point outside the
        # models directory.
        candidate = Path("./data/models") / f"{model_id}.joblib"
        if candidate.is_file():
            model_path = str(candidate)

    if model_path is None:
        await update_task(
            task_id,
            status="FAILED",
            error=f"model_id 不存在: {model_id}",
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
        return

    # 2. Derive the output path from the input store name when not provided.
    if output_zarr_path is None:
        stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
        stem = stem.replace(".zarr", "")
        output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"

    await update_task(
        task_id,
        status="PROCESSING",
        updated_at=datetime.now().isoformat(),
    )
    print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")

    try:
        # 3. Probe the band count (needed to size the dummy fallback model).
        #    The context manager closes the probe even when the check raises.
        with xr.open_zarr(input_zarr_path, chunks="auto") as ds_probe:
            if "reflectance" not in ds_probe.data_vars:
                raise KeyError(
                    f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
                )
            n_bands = ds_probe["reflectance"].sizes["band"]

        # 4. Load the model (real file preferred, dummy as fallback).
        model = _load_model(model_path, n_features=n_bands)

        # 5. Run the synchronous inference in a worker thread.
        await asyncio.to_thread(
            _run_predict_sync,
            model,
            model_id,
            input_zarr_path,
            output_zarr_path,
        )

        await update_task(
            task_id,
            status="SUCCESS",
            output_zarr_path=output_zarr_path,
            model_id=model_id,
            error=None,
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")

    except Exception as exc:  # noqa: BLE001
        tb_text = traceback.format_exc()
        await update_task(
            task_id,
            status="FAILED",
            output_zarr_path=None,
            error=f"{type(exc).__name__}: {exc}",
            traceback=tb_text,
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
        print(tb_text)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# POST /api/modeling/train
|
||
# ---------------------------------------------------------------------------
|
||
@router.post("/train", response_model=TaskAcceptedResponse)
async def submit_train(
    payload: TrainRequest,
    background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
    """Queue a model training task and return its task_id immediately."""
    new_task_id = str(uuid.uuid4())

    # Initial task record: request fields plus result fields still unset.
    initial_record = {
        "task_id": new_task_id,
        "kind": "train",
        "model_type": payload.model_type,
        "target": payload.target,
        "train_data_path": payload.train_data_path,
        "feature_start": payload.feature_start,
        "params": payload.params,
        "status": "PENDING",
        "model_id": None,
        "model_path": None,
        "test_r2": None,
        "test_rmse": None,
        "test_mae": None,
        "n_features": None,
        "n_samples": None,
        "error": None,
        "traceback": None,
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
    }
    await set_task(new_task_id, initial_record)

    # The training itself runs after the response has been sent.
    background_tasks.add_task(
        execute_train_task,
        new_task_id,
        payload.model_type,
        payload.target,
        payload.train_data_path,
        payload.feature_start,
        payload.params,
    )

    accepted = {"task_id": new_task_id, "status": "PENDING", "kind": "train"}
    return accepted
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/modeling/models
|
||
# ---------------------------------------------------------------------------
|
||
@router.get("/models", response_model=ModelListResponse)
async def list_trained_models() -> Dict[str, Any]:
    """List the models currently held in the in-memory registry.

    Entries are ordered by ``created_at`` descending, newest first; an entry
    without that key sorts as the empty string.
    """
    newest_first = sorted(
        _MODEL_REGISTRY.values(),
        key=lambda entry: entry.get("created_at", ""),
        reverse=True,
    )
    return {"models": newest_first, "count": len(newest_first)}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# POST /api/modeling/predict
|
||
# ---------------------------------------------------------------------------
|
||
@router.post("/predict", response_model=TaskAcceptedResponse)
async def submit_predict(
    payload: PredictRequest,
    background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
    """Queue a model inference task and return its task_id immediately."""
    new_task_id = str(uuid.uuid4())

    # Initial task record; the background executor updates it as it runs.
    initial_record = {
        "task_id": new_task_id,
        "kind": "predict",
        "model_id": payload.model_id,
        "input_zarr_path": payload.input_zarr_path,
        "output_zarr_path": payload.output_zarr_path,
        "status": "PENDING",
        "error": None,
        "traceback": None,
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
    }
    await set_task(new_task_id, initial_record)

    # Inference runs after the response has been sent.
    background_tasks.add_task(
        execute_predict_task,
        new_task_id,
        payload.model_id,
        payload.input_zarr_path,
        payload.output_zarr_path,
    )

    accepted = {"task_id": new_task_id, "status": "PENDING", "kind": "predict"}
    return accepted
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# models_router — 独立于 modeling_router,路径前缀为 /models
|
||
# 最终完整路径: GET /api/models, POST /api/models/upload
|
||
# ---------------------------------------------------------------------------
|
||
# Second router, with its own /models prefix, for the model-file endpoints.
models_router = APIRouter(
    prefix="/models",
    tags=["models"],
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# GET /api/models
|
||
# ---------------------------------------------------------------------------
|
||
@models_router.get("")
async def list_models() -> Dict[str, Any]:
    """List the model files stored under ./data/models.

    Returns the file names without suffix of every regular file whose suffix
    is ``.joblib`` (compared case-insensitively, matching what the upload
    endpoint accepts), sorted so the response is stable between calls. The
    directory is created when missing, in which case the list is empty.
    """
    models_dir = Path("./data/models")
    models_dir.mkdir(parents=True, exist_ok=True)

    model_names = sorted(
        entry.stem
        for entry in models_dir.iterdir()
        # Skip directories that merely happen to be named *.joblib.
        if entry.is_file() and entry.suffix.lower() == ".joblib"
    )
    return {"models": model_names}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# POST /api/models/upload
|
||
# ---------------------------------------------------------------------------
|
||
@models_router.post("/upload")
async def upload_model(
    file: UploadFile = File(...),
) -> Dict[str, Any]:
    """Store an uploaded .joblib model file under ./data/models.

    - Only the final path component of the client-supplied filename is used,
      so the upload cannot be written outside the models directory.
    - The suffix must be ``.joblib`` (any letter case); the file is stored
      with a lower-case ``.joblib`` suffix.
    - The directory is created when missing.

    Returns:
        ``{"status": "success", "model_id": <file name without suffix>}``.

    Raises:
        HTTPException: 400 when the filename is missing, has another suffix,
            or has no name part before the suffix.
    """
    suffix = ".joblib"

    # The filename is untrusted input and may contain directory parts such as
    # "../"; both separator styles are normalised before taking the last part.
    base_name = Path((file.filename or "").replace("\\", "/")).name
    model_id = base_name[: -len(suffix)] if base_name.lower().endswith(suffix) else ""
    if not model_id.strip(". "):
        raise HTTPException(
            status_code=400,
            detail="仅支持 .joblib 格式的文件",
        )

    models_dir = Path("./data/models")
    models_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(security): joblib files are pickles and execute code when loaded;
    # this endpoint should only be reachable by trusted users.
    dest_path = models_dir / f"{model_id}{suffix}"
    with dest_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    return {
        "status": "success",
        "model_id": model_id,
    }
|