""" app/api/modeling.py =================== 机器学习与水质反演相关的 API 路由。 接口(最终路径, 挂载后): POST /api/modeling/train 提交模型训练任务, 立即返回 task_id GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读) POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id 设计要点 -------- - 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。 - 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储), 未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。 - /predict 已接入真实 sklearn + xarray + dask 流式推断: * joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor) * xr.open_zarr 延迟打开影像, NaN 填 0 * xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict * to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk - /train 已接入真实 sklearn + pandas 训练管线: * pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局) * 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征 * sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS} * train_test_split(80/20) 划分, 计算 test_r2/rmse/mae * joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib * 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY 注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。 """ import asyncio import shutil import traceback import uuid from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Union import joblib import numpy as np import pandas as pd import xarray as xr from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, File from pydantic import BaseModel, Field from sklearn.cross_decomposition import PLSRegression from sklearn.ensemble import RandomForestRegressor from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsRegressor from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from sklearn.svm import SVR # 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE, # 通过 task 记录里的 "kind" 字段区分 train / predict / deglint) from app.core.task_store import get_task, set_task, update_task # --------------------------------------------------------------------------- # 路由实例 # --------------------------------------------------------------------------- # prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径, # 最终完整路径由 main.py 挂载时再补 /api。 # --------------------------------------------------------------------------- router = APIRouter(prefix="/modeling", tags=["modeling"]) # --------------------------------------------------------------------------- # 数据模型 # --------------------------------------------------------------------------- class TrainRequest(BaseModel): """POST /api/modeling/train 的请求体""" model_type: str = Field( ..., description="模型类型, 例如 'RF' (随机森林) / 'SVR' (支持向量回归) / 'XGBoost' / 'MLP'", examples=["RF", "SVR"], ) target: str = Field( ..., description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)", examples=["Chl-a", "TSS", "CDOM"], ) train_data_path: str = Field( ..., description="训练数据集的 zarr 路径(包含 reflectance 变量与 target 标签)", examples=["./data/train.zarr"], ) feature_start: Union[int, str] = Field( default=4, description=( "特征列起始位置. 表格布局假定为 " "[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] " "可传 int 列索引(如 4)或 str 列名(如 '374.285' 波长起点)。" "默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。" ), examples=[4, "374.285"], ) params: Dict[str, Any] = Field( default_factory=dict, description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}", examples=[{"n_estimators": 100, "max_depth": 20}], ) class PredictRequest(BaseModel): """POST /api/modeling/predict 的请求体""" model_id: str = Field( ..., description="已训练模型的 ID(由 /api/modeling/train 返回或 /api/modeling/models 列出)", ) input_zarr_path: str = Field( ..., description="待推断影像的 zarr 路径", examples=["./data/scene.zarr"], ) output_zarr_path: Optional[str] = Field( default=None, description=( "输出 zarr 路径, 缺省时由后端按规则生成 " "(如 ./data/{model_id}_{input_stem}_pred.zarr)" ), ) class TaskAcceptedResponse(BaseModel): """提交训练/推断任务后立即返回的响应""" task_id: str status: str # 一定是 PENDING kind: str # "train" / "predict", 便于前端识别任务类型 class ModelInfo(BaseModel): """单个模型的元信息(GET /api/modeling/models 的元素)""" model_id: str model_type: str target: str params: Dict[str, Any] path: str # joblib 文件路径 created_at: str train_task_id: str # 产生此模型的那个训练任务的 ID class ModelListResponse(BaseModel): """GET /api/modeling/models 的响应""" models: List[ModelInfo] count: int # --------------------------------------------------------------------------- # 模块级模型注册表(开发期内存, 未来替换为磁盘扫描) # --------------------------------------------------------------------------- # model_id -> ModelInfo 字典 # 读多写少, 用一个普通 dict 足够(CPython GIL 兜底)。 # 写时(训练完成时)只发生一次, 无并发风险。 # --------------------------------------------------------------------------- _MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {} def _register_model(record: Dict[str, Any]) -> None: """将训练完成的模型登记到内存注册表。""" _MODEL_REGISTRY[record["model_id"]] = record # --------------------------------------------------------------------------- # 训练管线的模块级辅助函数 # --------------------------------------------------------------------------- # 设计要点 (与推断管线一致): # 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。 # 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。 # 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。 # --------------------------------------------------------------------------- # model_type (大写字符串) -> sklearn 估计器类 # 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传) _MODEL_CLASS_REGISTRY: Dict[str, type] = { "RF": RandomForestRegressor, "SVR": SVR, "LinearRegression": LinearRegression, "KNN": KNeighborsRegressor, "PLS": PLSRegression, } def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline: """ 模型工厂: 按 model_type 选 sklearn 类, 用 StandardScaler + 估计器构造 Pipeline。 与 OpenClaw 不同之处: 把 scaler 放进 Pipeline 第一步, 推断时直接 pipeline.predict(X) 即可, scaler 参数与训练时严格一致。 """ model_cls = _MODEL_CLASS_REGISTRY.get(model_type) if model_cls is None: raise ValueError( f"不支持的 model_type='{model_type}', " f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}" ) estimator = model_cls(**(params or {})) return Pipeline([("scaler", StandardScaler()), ("model", estimator)]) def _load_train_df(csv_path: str) -> pd.DataFrame: """ 读 CSV 训练表, 规整空串 / 空白 / NULL 等为 NaN。 沿用 OpenClaw modeling_batch.load_data_batch 的读取策略: na_values 显式列举 + 正则二次清理 (防 cell 内出现 " " 等纯空白)。 """ try: df = pd.read_csv( csv_path, na_values=["", " ", "NaN", "nan", "NULL", "null"], ) except FileNotFoundError as exc: raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc except pd.errors.EmptyDataError as exc: raise ValueError(f"训练数据文件为空: {csv_path}") from exc # 二次清理: 残留的纯空白 cell df = df.replace(r"^\s*$", np.nan, regex=True) return df def _resolve_feature_start( df: pd.DataFrame, feature_start: Union[int, str], ) -> int: """ 将 feature_start (int 索引 / str 列名) 统一解析为 int 列索引。 与 OpenClaw modeling_batch.load_data_batch / load_data_single 一致: str 走 columns.get_loc, int 直接返回。 """ if isinstance(feature_start, str): if feature_start not in df.columns: raise ValueError( f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}" ) return int(df.columns.get_loc(feature_start)) return int(feature_start) def _run_train_sync( model_type: str, target: str, train_data_path: str, feature_start: Union[int, str], params: Optional[Dict[str, Any]], output_model_path: Path, ) -> Dict[str, Any]: """ 完整同步训练流程 (由 execute_train_task 在线程池内调用): pd.read_csv -> 目标列 dropna -> 切特征 -> train_test_split(80/20) -> Pipeline(StandardScaler + model).fit -> 评估 test_r2/rmse/mae -> joblib.dump({model, metadata}, output_model_path) Returns: metadata 字典, 含 test_r2 / test_rmse / test_mae / n_features 等, 调用方负责写回 TASK_STORE 和 _MODEL_REGISTRY。 注: 旧版 SPXY / KS 划分留作未来扩展 (params.split_method 控制), 当前固定 random + test_size=0.2 + random_state=42。 """ df = _load_train_df(train_data_path) if target not in df.columns: raise ValueError( f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}" ) # 1) 清洗: 仅剔除 target NaN 的行 (与 OpenClaw load_data_single 一致) df = df[df[target].notna()].copy() if df.empty: raise ValueError("target 剔除 NaN 后无样本, 终止训练") # 2) 特征切分 feature_start_idx = _resolve_feature_start(df, feature_start) feature_columns = list(df.columns[feature_start_idx:]) X = df.iloc[:, feature_start_idx:].astype(np.float64) y = df[target].astype(np.float64).values # 3) 划分 (固定 random, 未来扩展 spxy/ks) X_train, X_test, y_train, y_test = train_test_split( X.values, y, test_size=0.2, random_state=42, ) # 4) 构造 Pipeline + 训练 pipeline = _get_model_pipeline(model_type, params) pipeline.fit(X_train, y_train) # 5) 测试集与训练集评估 y_pred = pipeline.predict(X_test) test_r2 = float(r2_score(y_test, y_pred)) test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred))) test_mae = float(mean_absolute_error(y_test, y_pred)) y_train_pred = pipeline.predict(X_train) train_r2 = float(r2_score(y_train, y_train_pred)) train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred))) train_mae = float(mean_absolute_error(y_train, y_train_pred)) metadata: Dict[str, Any] = { "model_type": model_type, "target": target, "feature_start": feature_start, "feature_columns": feature_columns, "n_features": int(X.shape[1]), "n_samples": int(X.shape[0]), "train_size": int(X_train.shape[0]), "test_size": int(X_test.shape[0]), "params": dict(params or {}), "test_r2": test_r2, "test_rmse": test_rmse, "test_mae": test_mae, "train_r2": train_r2, "train_rmse": train_rmse, "train_mae": train_mae, "split_method": "random", "trained_at": datetime.now().isoformat(), } # 7) 持久化 (目录可能不存在, 顺手建) output_model_path = Path(output_model_path) output_model_path.parent.mkdir(parents=True, exist_ok=True) joblib.dump( {"model": pipeline, "metadata": metadata}, output_model_path, ) return metadata # --------------------------------------------------------------------------- # 推断管线的模块级辅助函数 # --------------------------------------------------------------------------- # 设计要点: # 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。 # 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync # 全部是模块级函数 (而非嵌套), 避免闭包陷阱。 # 2) _predict_block 通过 model.predict(spectra_2d) 整批预测, # 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM, # 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块 # 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。 # --------------------------------------------------------------------------- def _make_dummy_model(n_features: int) -> RandomForestRegressor: """ 构造一个 Dummy 随机森林回归器。 用途: 1) 真实 joblib 文件不存在时的连通性测试 2) 训练骨架尚未接入真实数据时的占位推断 """ rng = np.random.default_rng(42) X = rng.random((200, n_features)) y = rng.random(200) model = RandomForestRegressor( n_estimators=10, max_depth=5, random_state=0, n_jobs=1 ) model.fit(X, y) return model def _load_model(path: str, n_features: int) -> Any: """ 加载训练好的 sklearn 模型, 失败时降级 Dummy。 优先级: 1) path 存在且 joblib.load 成功 -> 返回真实模型 2) 否则 -> 降级为 Dummy 随机森林 (n_features 必须指定) """ p = Path(path) if p.is_file() and p.stat().st_size > 0: try: print(f"[model] 从磁盘加载: {path}") return joblib.load(path) except Exception as exc: # noqa: BLE001 print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy") print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest") return _make_dummy_model(n_features) def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray: """ 单个 dask chunk 的推断函数 (xr.apply_ufunc 会自动调度调用)。 Parameters ---------- spectra_3d : np.ndarray 形状 (y_chunk, x_chunk, n_bands)。 此形状由 input_core_dims=[["band"]] 决定: xarray 会把 band 维移到最后一轴, 然后按 (y, x) 的 chunk 切分调用本函数。 model : 已 fit 好的 sklearn 估计器 接受 (n_samples, n_features) 输入, 返回 (n_samples,) 预测。 Returns ------- np.ndarray 形状 (y_chunk, x_chunk), dtype float32 的标量预测图。 """ yc, xc, nb = spectra_3d.shape # 2D 化: 每个像素一行光谱 flat = spectra_3d.reshape(yc * xc, nb) # sklearn 风格的批量预测 pred = model.predict(flat) # 还原为 2D 空间图, 强制 float32 节约一半内存 return pred.reshape(yc, xc).astype(np.float32, copy=False) def _run_predict_sync( model: Any, model_id: str, input_zarr_path: str, output_zarr_path: str, ) -> None: """ 同步推断主流程 (被 asyncio.to_thread 调用)。 流程: 1) xr.open_zarr 延迟打开 (dask 数组, 不一次性读入内存) 2) NaN -> 0 清洗 (model.predict 不接受 NaN) 3) xr.apply_ufunc 沿 (y, x) 逐 chunk 调 _predict_block 4) 非水域置 NaN (zarr 支持 float NaN) 5) to_zarr 触发整图计算 + 流式写出 """ # 1. 延迟打开输入 (关键: Dask 不一次性读入内存) ds = xr.open_zarr(input_zarr_path, chunks="auto") if "reflectance" not in ds.data_vars: raise KeyError( f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}" ) reflectance = ds["reflectance"] # dims: (y, x, band) n_bands = reflectance.sizes["band"] # 2. 水域掩膜 (与去耀斑算法同约定) if "water_mask" in ds.data_vars or "water_mask" in ds.coords: water_mask = ds["water_mask"].astype(bool) else: water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool) # 3. NaN 清洗: 填充 0 (model.predict 不接受 NaN) refl_clean = reflectance.fillna(0.0) # 4. 核心: 用 apply_ufunc 把 model.predict 沿 (y, x) 应用 # dask="parallelized" 让每个 (y_chunk, x_chunk, band) chunk # 独立调 _predict_block, 任意时刻内存中只有若干个 chunk。 prediction: xr.DataArray = xr.apply_ufunc( _predict_block, refl_clean, kwargs={"model": model}, input_core_dims=[["band"]], output_core_dims=[[]], dask="parallelized", output_dtypes=[np.float32], dask_gufunc_kwargs={"allow_rechunk": True}, vectorize=False, ) # 5. 非水域置 NaN (zarr 支持 float NaN, 便于后续可视化/掩膜分析) prediction = prediction.where(water_mask, np.nan) # 6. 包装为 Dataset 并流式写出 out = xr.Dataset( {"prediction": prediction}, attrs={ "model_id": model_id, "input_zarr_path": input_zarr_path, "n_bands": n_bands, "created_at": datetime.now().isoformat(), }, ) # 保留 y/x 坐标 out = out.assign_coords(y=ds["y"], x=ds["x"]) # to_zarr + compute=True 触发整图 dask 图求值 # 中间会按 chunk 逐块调度到线程池, 内存峰值 ≈ 1 个 chunk 的体量 out.to_zarr(output_zarr_path, mode="w", compute=True) # --------------------------------------------------------------------------- # 后台任务执行器 # --------------------------------------------------------------------------- async def execute_train_task( task_id: str, model_type: str, target: str, train_data_path: str, feature_start: Union[int, str], params: Dict[str, Any], ) -> None: """ 训练任务后台执行器(已接入真实 sklearn 训练流程)。 流程: 1) get_task 校验任务存在 2) update_task(PROCESSING) 3) 生成 model_id / model_path 4) asyncio.to_thread 派发 _run_train_sync 到默认线程池 5) 成功 -> _register_model + update_task(SUCCESS, 附 test_r2/rmse/mae) 6) 失败 -> update_task(FAILED, 附 error + traceback) """ record = await get_task(task_id) if record is None: print(f"[{task_id}] 训练任务不存在, 跳过") return await update_task( task_id, status="PROCESSING", updated_at=datetime.now().isoformat(), ) print( f"[{task_id}] 开始训练 model_type={model_type} target={target} " f"train_data_path={train_data_path} feature_start={feature_start}" ) # model_id 用 uuid4 前 12 位 (8 位易撞, 12 位兼顾可读性) model_id = f"model_{uuid.uuid4().hex[:12]}" model_path = Path(f"./data/models/{model_id}.joblib") try: # 同步 sklearn / pandas 训练丢到默认线程池, 不阻塞 event loop metadata = await asyncio.to_thread( _run_train_sync, model_type, target, train_data_path, feature_start, params, model_path, ) # 登记到内存注册表 (供 /predict 查 model_id) _register_model( { "model_id": model_id, "model_type": model_type, "target": target, "params": dict(params or {}), "path": str(model_path), "feature_start": feature_start, "n_features": metadata["n_features"], "test_r2": metadata["test_r2"], "test_rmse": metadata["test_rmse"], "test_mae": metadata["test_mae"], "created_at": datetime.now().isoformat(), "train_task_id": task_id, } ) # 把训练指标写回任务记录, 前端轮询时可直接看 await update_task( task_id, status="SUCCESS", model_id=model_id, model_path=str(model_path), test_r2=metadata["test_r2"], test_rmse=metadata["test_rmse"], test_mae=metadata["test_mae"], n_features=metadata["n_features"], n_samples=metadata["n_samples"], error=None, traceback=None, updated_at=datetime.now().isoformat(), ) print( f"[{task_id}] 训练完成 -> model_id={model_id} " f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}" ) except Exception as exc: # noqa: BLE001 # 失败时 model_path 不一定有产物, 显式置 None 方便前端判断 await update_task( task_id, status="FAILED", model_id=None, model_path=None, error=f"{type(exc).__name__}: {exc}", traceback=traceback.format_exc(), updated_at=datetime.now().isoformat(), ) print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}") async def execute_predict_task( task_id: str, model_id: str, input_zarr_path: str, output_zarr_path: Optional[str], ) -> None: """ 推断任务后台执行器(真实实现版)。 OOM 防护策略: - xr.open_zarr(..., chunks="auto") 延迟打开, 整图不一次性读入内存 - xr.apply_ufunc(..., dask="parallelized") 把影像按 chunk 切分 - 每个 chunk 内部 reshape 成 2D, 调 model.predict, 再 reshape 回 2D - 任意时刻内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) chunk 的体量 - 整图完成计算后再 to_zarr(compute=True) 流式写出 """ record = await get_task(task_id) if record is None: print(f"[{task_id}] 推断任务不存在, 跳过") return # 1. 校验 model_id 是否已注册 (避免在后台任务里报模糊错误) model_meta = _MODEL_REGISTRY.get(model_id) if model_meta is None: await update_task( task_id, status="FAILED", error=f"model_id 不存在: {model_id}", updated_at=datetime.now().isoformat(), ) print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}") return # 2. 自动生成 output_zarr_path (若未提供) if output_zarr_path is None: stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1] stem = stem.replace(".zarr", "") output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr" await update_task( task_id, status="PROCESSING", updated_at=datetime.now().isoformat(), ) print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}") try: # 3. 探测波段数 (用于 Dummy 模型适配) # 这里只读 zarr 元数据 (.zarray 的 shape), 不读真实数据 ds_probe = xr.open_zarr(input_zarr_path, chunks="auto") if "reflectance" not in ds_probe.data_vars: raise KeyError( f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}" ) n_bands = ds_probe["reflectance"].sizes["band"] ds_probe.close() # 4. 加载模型 (真实文件优先, Dummy 兜底) model = _load_model(model_meta["path"], n_features=n_bands) # 5. 包装同步执行, 丢到线程池, 事件循环不阻塞 await asyncio.to_thread( _run_predict_sync, model, model_id, input_zarr_path, output_zarr_path, ) await update_task( task_id, status="SUCCESS", output_zarr_path=output_zarr_path, model_id=model_id, error=None, updated_at=datetime.now().isoformat(), ) print(f"[{task_id}] 推断完成 -> output={output_zarr_path}") except Exception as exc: # noqa: BLE001 tb_text = traceback.format_exc() await update_task( task_id, status="FAILED", output_zarr_path=None, error=f"{type(exc).__name__}: {exc}", traceback=tb_text, updated_at=datetime.now().isoformat(), ) print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}") print(tb_text) # --------------------------------------------------------------------------- # POST /api/modeling/train # --------------------------------------------------------------------------- @router.post("/train", response_model=TaskAcceptedResponse) async def submit_train( payload: TrainRequest, background_tasks: BackgroundTasks, ) -> Dict[str, Any]: """提交一个模型训练任务, 立即返回 task_id。""" task_id = str(uuid.uuid4()) await set_task( task_id, { "task_id": task_id, "kind": "train", "model_type": payload.model_type, "target": payload.target, "train_data_path": payload.train_data_path, "feature_start": payload.feature_start, "params": payload.params, "status": "PENDING", "model_id": None, "model_path": None, "test_r2": None, "test_rmse": None, "test_mae": None, "n_features": None, "n_samples": None, "error": None, "traceback": None, "created_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(), }, ) background_tasks.add_task( execute_train_task, task_id, payload.model_type, payload.target, payload.train_data_path, payload.feature_start, payload.params, ) return {"task_id": task_id, "status": "PENDING", "kind": "train"} # --------------------------------------------------------------------------- # GET /api/modeling/models # --------------------------------------------------------------------------- @router.get("/models", response_model=ModelListResponse) async def list_trained_models() -> Dict[str, Any]: """ 列出已训练好的模型。 未来实现: 从 ./data/models/*.joblib 扫描元信息, 当前直接从内存 _MODEL_REGISTRY 读。 """ models = list(_MODEL_REGISTRY.values()) # 按 created_at 倒序, 最新训练的在前 models.sort(key=lambda m: m.get("created_at", ""), reverse=True) return {"models": models, "count": len(models)} # --------------------------------------------------------------------------- # POST /api/modeling/predict # --------------------------------------------------------------------------- @router.post("/predict", response_model=TaskAcceptedResponse) async def submit_predict( payload: PredictRequest, background_tasks: BackgroundTasks, ) -> Dict[str, Any]: """提交一个模型推断任务, 立即返回 task_id。""" task_id = str(uuid.uuid4()) await set_task( task_id, { "task_id": task_id, "kind": "predict", "model_id": payload.model_id, "input_zarr_path": payload.input_zarr_path, "output_zarr_path": payload.output_zarr_path, "status": "PENDING", "error": None, "traceback": None, "created_at": datetime.now().isoformat(), "updated_at": datetime.now().isoformat(), }, ) background_tasks.add_task( execute_predict_task, task_id, payload.model_id, payload.input_zarr_path, payload.output_zarr_path, ) return {"task_id": task_id, "status": "PENDING", "kind": "predict"} # --------------------------------------------------------------------------- # models_router — 独立于 modeling_router,路径前缀为 /models # 最终完整路径: GET /api/models, POST /api/models/upload # --------------------------------------------------------------------------- models_router = APIRouter(prefix="/models", tags=["models"]) # --------------------------------------------------------------------------- # GET /api/models # --------------------------------------------------------------------------- @models_router.get("") async def list_models() -> Dict[str, Any]: """ 扫描 ./data/models/ 目录,返回所有 .joblib 文件名(不含后缀)。 异常处理:目录不存在时自动创建,返回空列表。 """ models_dir = Path("./data/models") models_dir.mkdir(parents=True, exist_ok=True) model_names = [ p.stem for p in models_dir.iterdir() if p.suffix == ".joblib" ] return {"models": model_names} # --------------------------------------------------------------------------- # POST /api/models/upload # --------------------------------------------------------------------------- @models_router.post("/upload") async def upload_model( file: UploadFile = File(...), ) -> Dict[str, Any]: """ 接收上传的 .joblib 模型文件,保存到 ./data/models/ 目录。 - 校验后缀必须为 .joblib - 目录不存在时自动创建 - 返回状态和文件名(不含后缀) """ if not file.filename or not file.filename.lower().endswith(".joblib"): raise HTTPException( status_code=400, detail="仅支持 .joblib 格式的文件", ) models_dir = Path("./data/models") models_dir.mkdir(parents=True, exist_ok=True) dest_path = models_dir / file.filename with dest_path.open("wb") as buffer: shutil.copyfileobj(file.file, buffer) return { "status": "success", "model_id": dest_path.stem, }