refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

View File

@ -0,0 +1,201 @@
"""
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
"""
import sys
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
# 当前文件位于 new/app/api/_smoke_test_train.py
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from app.api.modeling import (
_get_model_pipeline,
_load_train_df,
_resolve_feature_start,
_run_train_sync,
_MODEL_CLASS_REGISTRY,
)
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
"""生成 [lat, lon, target, lat2, lon2, feat_0, feat_1, ...] 布局的 CSV"""
rng = np.random.default_rng(seed)
lat = rng.uniform(20, 25, n_samples)
lon = rng.uniform(110, 115, n_samples)
target = rng.uniform(0, 50, n_samples)
lat2 = rng.uniform(0, 1, n_samples) # 元数据
lon2 = rng.uniform(0, 1, n_samples) # 元数据
feats = rng.normal(0, 1, (n_samples, n_features))
# 让 y 真正依赖前 3 个特征, RF 至少应该能学到 R² > 0.5
feats[:, 0] += target / 10
feats[:, 1] += target / 20
feats[:, 2] -= target / 15
df = pd.DataFrame({
"lat": lat,
"lon": lon,
"Chl-a": target,
"lat2": lat2,
"lon2": lon2,
**{f"feat_{i}": feats[:, i] for i in range(n_features)},
})
tmp = Path(tempfile.mkdtemp()) / "train.csv"
df.to_csv(tmp, index=False)
return tmp
def test_load_train_df():
print("== test_load_train_df ==")
p = make_synthetic_csv(n_samples=50)
df = _load_train_df(str(p))
assert df.shape == (50, 5 + 8), f"shape={df.shape}"
print(f" shape={df.shape}, columns[:6]={list(df.columns[:6])}")
print(" PASS")
def test_resolve_feature_start_int_and_str():
print("== test_resolve_feature_start (int + str) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
idx_int = _resolve_feature_start(df, 5)
idx_str = _resolve_feature_start(df, "feat_0")
assert idx_int == 5 == idx_str, f"int={idx_int}, str={idx_str}"
print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
print(" PASS")
def test_resolve_feature_start_str_miss():
print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
try:
_resolve_feature_start(df, "not_exist")
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_get_model_pipeline_all_types():
print("== test_get_model_pipeline (5 种 model_type) ==")
for mt in ["RF", "SVR", "LinearRegression", "KNN", "PLS"]:
p = _get_model_pipeline(mt, {})
assert len(p.steps) == 2
assert p.steps[0][0] == "scaler"
assert p.steps[1][0] == "model"
print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
print(" PASS")
def test_get_model_pipeline_bad_type():
print("== test_get_model_pipeline (坏 model_type) ==")
try:
_get_model_pipeline("XGBoost", {})
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_rf_end_to_end():
print("== test_run_train_sync (RF 端到端) ==")
p = make_synthetic_csv(n_samples=200)
out_dir = Path(tempfile.mkdtemp())
out_path = out_dir / "model.joblib"
import time
t0 = time.time()
metadata = _run_train_sync(
model_type="RF",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
output_model_path=out_path,
)
dt = time.time() - t0
assert out_path.exists(), f"joblib 未落盘: {out_path}"
print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
print(f" 耗时 {dt:.2f}s")
# 回测: 加载 joblib 再 predict
import joblib
saved = joblib.load(out_path)
assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
assert hasattr(saved["model"], "predict")
assert saved["metadata"]["test_r2"] == metadata["test_r2"]
print(f" joblib 加载 OK, 含 'model''metadata' 双 key")
print(" PASS")
def test_run_train_sync_linearregression_fast():
print("== test_run_train_sync (LinearRegression 快速路径) ==")
p = make_synthetic_csv(n_samples=150)
out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
metadata = _run_train_sync(
model_type="LinearRegression",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={},
output_model_path=out_path,
)
print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 >= 0.4)")
assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
print(" PASS")
def test_run_train_sync_bad_csv():
print("== test_run_train_sync (CSV 不存在) ==")
try:
_run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛异常")
except (FileNotFoundError, ValueError) as e:
print(f" 正确抛 {type(e).__name__}: {e}")
print(" PASS")
def test_run_train_sync_bad_target():
print("== test_run_train_sync (target 列不存在) ==")
p = make_synthetic_csv()
try:
_run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_str_feature_start():
print("== test_run_train_sync (feature_start 用列名) ==")
p = make_synthetic_csv()
out_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
metadata = _run_train_sync("RF", "Chl-a", str(p), "feat_0", {"n_estimators": 10}, out_path)
assert metadata["feature_start"] == "feat_0"
assert metadata["n_features"] == 8
assert metadata["feature_columns"][0] == "feat_0"
print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
print(" PASS")
if __name__ == "__main__":
test_load_train_df()
test_resolve_feature_start_int_and_str()
test_resolve_feature_start_str_miss()
test_get_model_pipeline_all_types()
test_get_model_pipeline_bad_type()
test_run_train_sync_rf_end_to_end()
test_run_train_sync_linearregression_fast()
test_run_train_sync_bad_csv()
test_run_train_sync_bad_target()
test_run_train_sync_str_feature_start()
print("\n>>> ALL SMOKE TESTS PASSED")

222
new/app/api/endpoints.py Normal file
View File

@ -0,0 +1,222 @@
"""
API 路由集合
============
把业务接口统一收口到 APIRouter再由 main.py 通过 include_router 挂载。
当前包含的接口:
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
派发链:
POST /api/process/deglint
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
└─ get_remover(method) 从注册表拿到算法类
└─ remover.process(input_zarr, output_zarr, **params)
"""
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB
from app.core.task_store import get_task, set_task, update_task
# 算法注册表 API
from app.core.algorithms import get_remover, list_removers
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
# ---------------------------------------------------------------------------
router = APIRouter(tags=["deglint"])
# ---------------------------------------------------------------------------
# 请求 / 响应数据模型
# ---------------------------------------------------------------------------
class DeglintRequest(BaseModel):
"""POST /api/process/deglint 的请求体"""
method: str = Field(
...,
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
examples=["kutser"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description=(
"传递给算法 process() 的超参数字典,例如 "
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
),
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
)
class TaskAcceptedResponse(BaseModel):
"""提交任务成功后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
class AlgorithmListResponse(BaseModel):
"""GET /api/algorithms 的响应"""
algorithms: list # 已注册算法名列表
count: int # 算法总数
# ---------------------------------------------------------------------------
# 后台任务执行器(真实派发链)
# ---------------------------------------------------------------------------
# 注意:这里使用 async def。
# FastAPI / Starlette 的 BackgroundTasks 支持 async function
# 会在响应返回后自动 await 它,不影响主请求链路。
# ---------------------------------------------------------------------------
async def execute_glint_removal_task(
task_id: str,
method: str,
params: Dict[str, Any],
) -> None:
"""
后台异步执行器:按 method 名字从注册表取出算法类,实例化并运行 process()。
状态机:
PENDING -> PROCESSING -> SUCCESS
└──> FAILED含 error / traceback
"""
# 0. 安全检查任务记录必须已存在POST 阶段已写入)
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 任务不存在, 跳过")
return
# 1. 状态推进到 PROCESSING
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始处理 method={method} params={params}")
# 2. 临时硬编码 IO 路径(未来由数据管理层提供)
# TODO: 替换为真实的数据管理服务返回的 zarr 路径
input_zarr_path = "./data/temp_in.zarr"
output_zarr_path = f"./data/{task_id}_out.zarr"
try:
# 3. 按 method 名字从注册表取算法类并实例化
# get_remover 找不到时会抛 KeyError下面的 except 会兜住
algorithm_cls = get_remover(method)
remover = algorithm_cls()
# 4. 调用算法(注意 await因为 BaseGlintRemover.process 是 async
await remover.process(input_zarr_path, output_zarr_path, **params)
# 5. 成功:写回结果路径与状态
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
# 6. 失败:记录错误信息与堆栈,便于前端排查
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
# ---------------------------------------------------------------------------
# GET /algorithms
# ---------------------------------------------------------------------------
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
# ---------------------------------------------------------------------------
@router.get("/algorithms", response_model=AlgorithmListResponse)
async def list_registered_algorithms() -> Dict[str, Any]:
"""列出已注册的去耀斑算法。"""
names = list(list_removers().keys())
return {"algorithms": names, "count": len(names)}
# ---------------------------------------------------------------------------
# POST /process/deglint
# ---------------------------------------------------------------------------
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
# ---------------------------------------------------------------------------
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
async def submit_deglint(
payload: DeglintRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
# 1. 生成唯一任务 IDUUID4 足以保证全局唯一性)
task_id = str(uuid.uuid4())
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
# 注意output_zarr_path / error / traceback 字段在执行过程中被填充
await set_task(
task_id,
{
"task_id": task_id,
"method": payload.method,
"params": payload.params,
"status": "PENDING",
"output_zarr_path": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
# 3. 把真实执行器丢到后台
background_tasks.add_task(
execute_glint_removal_task,
task_id,
payload.method,
payload.params,
)
# 4. 立即返回 task_id 与 PENDING 状态
return {"task_id": task_id, "status": "PENDING"}
# ---------------------------------------------------------------------------
# GET /tasks/{task_id}
# ---------------------------------------------------------------------------
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
# SUCCESS 表示成功(含 output_zarr_pathFAILED 表示失败(含 error / traceback
# ---------------------------------------------------------------------------
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> Dict[str, Any]:
"""查询指定任务的当前状态与结果。"""
record = await get_task(task_id)
if record is None:
# 找不到 task_id 通常意味着客户端拼错了 ID或者记录已被清理
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
# 直接返回字典FastAPI 会自动 JSON 序列化
return record

786
new/app/api/modeling.py Normal file
View File

@ -0,0 +1,786 @@
"""
app/api/modeling.py
===================
机器学习与水质反演相关的 API 路由。
接口(最终路径, 挂载后):
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
设计要点
--------
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor
* xr.open_zarr 延迟打开影像, NaN 填 0
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
- /train 已接入真实 sklearn + pandas 训练管线:
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
"""
import asyncio
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import joblib
import numpy as np
import pandas as pd
import xarray as xr
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel, Field
from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint
from app.core.task_store import get_task, set_task, update_task
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
# 最终完整路径由 main.py 挂载时再补 /api。
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/modeling", tags=["modeling"])
# ---------------------------------------------------------------------------
# 数据模型
# ---------------------------------------------------------------------------
class TrainRequest(BaseModel):
"""POST /api/modeling/train 的请求体"""
model_type: str = Field(
...,
description="模型类型, 例如 'RF' (随机森林) / 'SVR' (支持向量回归) / 'XGBoost' / 'MLP'",
examples=["RF", "SVR"],
)
target: str = Field(
...,
description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
examples=["Chl-a", "TSS", "CDOM"],
)
train_data_path: str = Field(
...,
description="训练数据集的 zarr 路径(包含 reflectance 变量与 target 标签)",
examples=["./data/train.zarr"],
)
feature_start: Union[int, str] = Field(
default=4,
description=(
"特征列起始位置. 表格布局假定为 "
"[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
"可传 int 列索引(如 4或 str 列名(如 '374.285' 波长起点)。"
"默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
),
examples=[4, "374.285"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
examples=[{"n_estimators": 100, "max_depth": 20}],
)
class PredictRequest(BaseModel):
"""POST /api/modeling/predict 的请求体"""
model_id: str = Field(
...,
description="已训练模型的 ID由 /api/modeling/train 返回或 /api/modeling/models 列出)",
)
input_zarr_path: str = Field(
...,
description="待推断影像的 zarr 路径",
examples=["./data/scene.zarr"],
)
output_zarr_path: Optional[str] = Field(
default=None,
description=(
"输出 zarr 路径, 缺省时由后端按规则生成 "
"(如 ./data/{model_id}_{input_stem}_pred.zarr)"
),
)
class TaskAcceptedResponse(BaseModel):
"""提交训练/推断任务后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
kind: str # "train" / "predict", 便于前端识别任务类型
class ModelInfo(BaseModel):
"""单个模型的元信息GET /api/modeling/models 的元素)"""
model_id: str
model_type: str
target: str
params: Dict[str, Any]
path: str # joblib 文件路径
created_at: str
train_task_id: str # 产生此模型的那个训练任务的 ID
class ModelListResponse(BaseModel):
"""GET /api/modeling/models 的响应"""
models: List[ModelInfo]
count: int
# ---------------------------------------------------------------------------
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
# ---------------------------------------------------------------------------
# model_id -> ModelInfo 字典
# 读多写少, 用一个普通 dict 足够CPython GIL 兜底)。
# 写时(训练完成时)只发生一次, 无并发风险。
# ---------------------------------------------------------------------------
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}
def _register_model(record: Dict[str, Any]) -> None:
"""将训练完成的模型登记到内存注册表。"""
_MODEL_REGISTRY[record["model_id"]] = record
# ---------------------------------------------------------------------------
# 训练管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点 (与推断管线一致):
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
# ---------------------------------------------------------------------------
# model_type (大写字符串) -> sklearn 估计器类
# 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传)
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
"RF": RandomForestRegressor,
"SVR": SVR,
"LinearRegression": LinearRegression,
"KNN": KNeighborsRegressor,
"PLS": PLSRegression,
}
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
"""
模型工厂: 按 model_type 选 sklearn 类, 用 StandardScaler + 估计器构造 Pipeline。
与 OpenClaw 不同之处: 把 scaler 放进 Pipeline 第一步,
推断时直接 pipeline.predict(X) 即可, scaler 参数与训练时严格一致。
"""
model_cls = _MODEL_CLASS_REGISTRY.get(model_type)
if model_cls is None:
raise ValueError(
f"不支持的 model_type='{model_type}', "
f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
)
estimator = model_cls(**(params or {}))
return Pipeline([("scaler", StandardScaler()), ("model", estimator)])
def _load_train_df(csv_path: str) -> pd.DataFrame:
"""
读 CSV 训练表, 规整空串 / 空白 / NULL 等为 NaN。
沿用 OpenClaw modeling_batch.load_data_batch 的读取策略:
na_values 显式列举 + 正则二次清理 (防 cell 内出现 " " 等纯空白)。
"""
try:
df = pd.read_csv(
csv_path,
na_values=["", " ", "NaN", "nan", "NULL", "null"],
)
except FileNotFoundError as exc:
raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
except pd.errors.EmptyDataError as exc:
raise ValueError(f"训练数据文件为空: {csv_path}") from exc
# 二次清理: 残留的纯空白 cell
df = df.replace(r"^\s*$", np.nan, regex=True)
return df
def _resolve_feature_start(
df: pd.DataFrame,
feature_start: Union[int, str],
) -> int:
"""
将 feature_start (int 索引 / str 列名) 统一解析为 int 列索引。
与 OpenClaw modeling_batch.load_data_batch / load_data_single 一致:
str 走 columns.get_loc, int 直接返回。
"""
if isinstance(feature_start, str):
if feature_start not in df.columns:
raise ValueError(
f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
)
return int(df.columns.get_loc(feature_start))
return int(feature_start)
def _run_train_sync(
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Optional[Dict[str, Any]],
output_model_path: Path,
) -> Dict[str, Any]:
"""
完整同步训练流程 (由 execute_train_task 在线程池内调用):
pd.read_csv -> 目标列 dropna -> 切特征 -> train_test_split(80/20)
-> Pipeline(StandardScaler + model).fit -> 评估 test_r2/rmse/mae
-> joblib.dump({model, metadata}, output_model_path)
Returns:
metadata 字典, 含 test_r2 / test_rmse / test_mae / n_features 等,
调用方负责写回 TASK_STORE 和 _MODEL_REGISTRY。
注: 旧版 SPXY / KS 划分留作未来扩展 (params.split_method 控制),
当前固定 random + test_size=0.2 + random_state=42。
"""
df = _load_train_df(train_data_path)
if target not in df.columns:
raise ValueError(
f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
)
# 1) 清洗: 仅剔除 target NaN 的行 (与 OpenClaw load_data_single 一致)
df = df[df[target].notna()].copy()
if df.empty:
raise ValueError("target 剔除 NaN 后无样本, 终止训练")
# 2) 特征切分
feature_start_idx = _resolve_feature_start(df, feature_start)
feature_columns = list(df.columns[feature_start_idx:])
X = df.iloc[:, feature_start_idx:].astype(np.float64)
y = df[target].astype(np.float64).values
# 3) 划分 (固定 random, 未来扩展 spxy/ks)
X_train, X_test, y_train, y_test = train_test_split(
X.values,
y,
test_size=0.2,
random_state=42,
)
# 4) 构造 Pipeline + 训练
pipeline = _get_model_pipeline(model_type, params)
pipeline.fit(X_train, y_train)
# 5) 测试集与训练集评估
y_pred = pipeline.predict(X_test)
test_r2 = float(r2_score(y_test, y_pred))
test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
test_mae = float(mean_absolute_error(y_test, y_pred))
y_train_pred = pipeline.predict(X_train)
train_r2 = float(r2_score(y_train, y_train_pred))
train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
train_mae = float(mean_absolute_error(y_train, y_train_pred))
metadata: Dict[str, Any] = {
"model_type": model_type,
"target": target,
"feature_start": feature_start,
"feature_columns": feature_columns,
"n_features": int(X.shape[1]),
"n_samples": int(X.shape[0]),
"train_size": int(X_train.shape[0]),
"test_size": int(X_test.shape[0]),
"params": dict(params or {}),
"test_r2": test_r2,
"test_rmse": test_rmse,
"test_mae": test_mae,
"train_r2": train_r2,
"train_rmse": train_rmse,
"train_mae": train_mae,
"split_method": "random",
"trained_at": datetime.now().isoformat(),
}
# 7) 持久化 (目录可能不存在, 顺手建)
output_model_path = Path(output_model_path)
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(
{"model": pipeline, "metadata": metadata},
output_model_path,
)
return metadata
# ---------------------------------------------------------------------------
# 推断管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点:
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
# ---------------------------------------------------------------------------
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
"""
构造一个 Dummy 随机森林回归器。
用途:
1) 真实 joblib 文件不存在时的连通性测试
2) 训练骨架尚未接入真实数据时的占位推断
"""
rng = np.random.default_rng(42)
X = rng.random((200, n_features))
y = rng.random(200)
model = RandomForestRegressor(
n_estimators=10, max_depth=5, random_state=0, n_jobs=1
)
model.fit(X, y)
return model
def _load_model(path: str, n_features: int) -> Any:
"""
加载训练好的 sklearn 模型, 失败时降级 Dummy。
优先级:
1) path 存在且 joblib.load 成功 -> 返回真实模型
2) 否则 -> 降级为 Dummy 随机森林 (n_features 必须指定)
"""
p = Path(path)
if p.is_file() and p.stat().st_size > 0:
try:
print(f"[model] 从磁盘加载: {path}")
return joblib.load(path)
except Exception as exc: # noqa: BLE001
print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
return _make_dummy_model(n_features)
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
"""
单个 dask chunk 的推断函数 (xr.apply_ufunc 会自动调度调用)。
Parameters
----------
spectra_3d : np.ndarray
形状 (y_chunk, x_chunk, n_bands)。
此形状由 input_core_dims=[["band"]] 决定:
xarray 会把 band 维移到最后一轴, 然后按 (y, x) 的 chunk 切分调用本函数。
model : 已 fit 好的 sklearn 估计器
接受 (n_samples, n_features) 输入, 返回 (n_samples,) 预测。
Returns
-------
np.ndarray
形状 (y_chunk, x_chunk), dtype float32 的标量预测图。
"""
yc, xc, nb = spectra_3d.shape
# 2D 化: 每个像素一行光谱
flat = spectra_3d.reshape(yc * xc, nb)
# sklearn 风格的批量预测
pred = model.predict(flat)
# 还原为 2D 空间图, 强制 float32 节约一半内存
return pred.reshape(yc, xc).astype(np.float32, copy=False)
def _run_predict_sync(
model: Any,
model_id: str,
input_zarr_path: str,
output_zarr_path: str,
) -> None:
"""
同步推断主流程 (被 asyncio.to_thread 调用)。
流程:
1) xr.open_zarr 延迟打开 (dask 数组, 不一次性读入内存)
2) NaN -> 0 清洗 (model.predict 不接受 NaN)
3) xr.apply_ufunc 沿 (y, x) 逐 chunk 调 _predict_block
4) 非水域置 NaN (zarr 支持 float NaN)
5) to_zarr 触发整图计算 + 流式写出
"""
# 1. 延迟打开输入 (关键: Dask 不一次性读入内存)
ds = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
)
reflectance = ds["reflectance"] # dims: (y, x, band)
n_bands = reflectance.sizes["band"]
# 2. 水域掩膜 (与去耀斑算法同约定)
if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)
# 3. NaN 清洗: 填充 0 (model.predict 不接受 NaN)
refl_clean = reflectance.fillna(0.0)
# 4. 核心: 用 apply_ufunc 把 model.predict 沿 (y, x) 应用
# dask="parallelized" 让每个 (y_chunk, x_chunk, band) chunk
# 独立调 _predict_block, 任意时刻内存中只有若干个 chunk。
prediction: xr.DataArray = xr.apply_ufunc(
_predict_block,
refl_clean,
kwargs={"model": model},
input_core_dims=[["band"]],
output_core_dims=[[]],
dask="parallelized",
output_dtypes=[np.float32],
dask_gufunc_kwargs={"allow_rechunk": True},
vectorize=False,
)
# 5. 非水域置 NaN (zarr 支持 float NaN, 便于后续可视化/掩膜分析)
prediction = prediction.where(water_mask, np.nan)
# 6. 包装为 Dataset 并流式写出
out = xr.Dataset(
{"prediction": prediction},
attrs={
"model_id": model_id,
"input_zarr_path": input_zarr_path,
"n_bands": n_bands,
"created_at": datetime.now().isoformat(),
},
)
# 保留 y/x 坐标
out = out.assign_coords(y=ds["y"], x=ds["x"])
# to_zarr + compute=True 触发整图 dask 图求值
# 中间会按 chunk 逐块调度到线程池, 内存峰值 ≈ 1 个 chunk 的体量
out.to_zarr(output_zarr_path, mode="w", compute=True)
# ---------------------------------------------------------------------------
# 后台任务执行器
# ---------------------------------------------------------------------------
async def execute_train_task(
task_id: str,
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Dict[str, Any],
) -> None:
"""
训练任务后台执行器(已接入真实 sklearn 训练流程)。
流程:
1) get_task 校验任务存在
2) update_task(PROCESSING)
3) 生成 model_id / model_path
4) asyncio.to_thread 派发 _run_train_sync 到默认线程池
5) 成功 -> _register_model + update_task(SUCCESS, 附 test_r2/rmse/mae)
6) 失败 -> update_task(FAILED, 附 error + traceback)
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 训练任务不存在, 跳过")
return
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 开始训练 model_type={model_type} target={target} "
f"train_data_path={train_data_path} feature_start={feature_start}"
)
# model_id 用 uuid4 前 12 位 (8 位易撞, 12 位兼顾可读性)
model_id = f"model_{uuid.uuid4().hex[:12]}"
model_path = Path(f"./data/models/{model_id}.joblib")
try:
# 同步 sklearn / pandas 训练丢到默认线程池, 不阻塞 event loop
metadata = await asyncio.to_thread(
_run_train_sync,
model_type,
target,
train_data_path,
feature_start,
params,
model_path,
)
# 登记到内存注册表 (供 /predict 查 model_id)
_register_model(
{
"model_id": model_id,
"model_type": model_type,
"target": target,
"params": dict(params or {}),
"path": str(model_path),
"feature_start": feature_start,
"n_features": metadata["n_features"],
"test_r2": metadata["test_r2"],
"test_rmse": metadata["test_rmse"],
"test_mae": metadata["test_mae"],
"created_at": datetime.now().isoformat(),
"train_task_id": task_id,
}
)
# 把训练指标写回任务记录, 前端轮询时可直接看
await update_task(
task_id,
status="SUCCESS",
model_id=model_id,
model_path=str(model_path),
test_r2=metadata["test_r2"],
test_rmse=metadata["test_rmse"],
test_mae=metadata["test_mae"],
n_features=metadata["n_features"],
n_samples=metadata["n_samples"],
error=None,
traceback=None,
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 训练完成 -> model_id={model_id} "
f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
)
except Exception as exc: # noqa: BLE001
# 失败时 model_path 不一定有产物, 显式置 None 方便前端判断
await update_task(
task_id,
status="FAILED",
model_id=None,
model_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}")
async def execute_predict_task(
task_id: str,
model_id: str,
input_zarr_path: str,
output_zarr_path: Optional[str],
) -> None:
"""
推断任务后台执行器(真实实现版)。
OOM 防护策略:
- xr.open_zarr(..., chunks="auto") 延迟打开, 整图不一次性读入内存
- xr.apply_ufunc(..., dask="parallelized") 把影像按 chunk 切分
- 每个 chunk 内部 reshape 成 2D, 调 model.predict, 再 reshape 回 2D
- 任意时刻内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) chunk 的体量
- 整图完成计算后再 to_zarr(compute=True) 流式写出
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 推断任务不存在, 跳过")
return
# 1. 校验 model_id 是否已注册 (避免在后台任务里报模糊错误)
model_meta = _MODEL_REGISTRY.get(model_id)
if model_meta is None:
await update_task(
task_id,
status="FAILED",
error=f"model_id 不存在: {model_id}",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
return
# 2. 自动生成 output_zarr_path (若未提供)
if output_zarr_path is None:
stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
stem = stem.replace(".zarr", "")
output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")
try:
# 3. 探测波段数 (用于 Dummy 模型适配)
# 这里只读 zarr 元数据 (.zarray 的 shape), 不读真实数据
ds_probe = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds_probe.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
)
n_bands = ds_probe["reflectance"].sizes["band"]
ds_probe.close()
# 4. 加载模型 (真实文件优先, Dummy 兜底)
model = _load_model(model_meta["path"], n_features=n_bands)
# 5. 包装同步执行, 丢到线程池, 事件循环不阻塞
await asyncio.to_thread(
_run_predict_sync,
model,
model_id,
input_zarr_path,
output_zarr_path,
)
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
model_id=model_id,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")
except Exception as exc: # noqa: BLE001
tb_text = traceback.format_exc()
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=tb_text,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
print(tb_text)
# ---------------------------------------------------------------------------
# POST /api/modeling/train
# ---------------------------------------------------------------------------
@router.post("/train", response_model=TaskAcceptedResponse)
async def submit_train(
payload: TrainRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型训练任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "train",
"model_type": payload.model_type,
"target": payload.target,
"train_data_path": payload.train_data_path,
"feature_start": payload.feature_start,
"params": payload.params,
"status": "PENDING",
"model_id": None,
"model_path": None,
"test_r2": None,
"test_rmse": None,
"test_mae": None,
"n_features": None,
"n_samples": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_train_task,
task_id,
payload.model_type,
payload.target,
payload.train_data_path,
payload.feature_start,
payload.params,
)
return {"task_id": task_id, "status": "PENDING", "kind": "train"}
# ---------------------------------------------------------------------------
# GET /api/modeling/models
# ---------------------------------------------------------------------------
@router.get("/models", response_model=ModelListResponse)
async def list_trained_models() -> Dict[str, Any]:
"""
列出已训练好的模型。
未来实现: 从 ./data/models/*.joblib 扫描元信息,
当前直接从内存 _MODEL_REGISTRY 读。
"""
models = list(_MODEL_REGISTRY.values())
# 按 created_at 倒序, 最新训练的在前
models.sort(key=lambda m: m.get("created_at", ""), reverse=True)
return {"models": models, "count": len(models)}
# ---------------------------------------------------------------------------
# POST /api/modeling/predict
# ---------------------------------------------------------------------------
@router.post("/predict", response_model=TaskAcceptedResponse)
async def submit_predict(
payload: PredictRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型推断任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "predict",
"model_id": payload.model_id,
"input_zarr_path": payload.input_zarr_path,
"output_zarr_path": payload.output_zarr_path,
"status": "PENDING",
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_predict_task,
task_id,
payload.model_id,
payload.input_zarr_path,
payload.output_zarr_path,
)
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}