refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名
This commit is contained in:
201
new/app/api/_smoke_test_train.py
Normal file
201
new/app/api/_smoke_test_train.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""
|
||||
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
|
||||
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
|
||||
"""
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
|
||||
# 当前文件位于 new/app/api/_smoke_test_train.py
|
||||
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from app.api.modeling import (
|
||||
_get_model_pipeline,
|
||||
_load_train_df,
|
||||
_resolve_feature_start,
|
||||
_run_train_sync,
|
||||
_MODEL_CLASS_REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
    """Write a synthetic training table to a new temp directory and return its path.

    Column layout: ``lat, lon, Chl-a, lat2, lon2, feat_0 .. feat_{n_features-1}``.
    The first three feature columns are shifted by a fraction of the target so
    that a regressor has real signal to learn.

    NOTE(review): ``noise`` is accepted but never read by this function.
    """
    rng = np.random.default_rng(seed)

    # The draw order below fixes the generated values for a given seed.
    columns = {
        "lat": rng.uniform(20, 25, n_samples),
        "lon": rng.uniform(110, 115, n_samples),
        "Chl-a": rng.uniform(0, 50, n_samples),
        "lat2": rng.uniform(0, 1, n_samples),  # metadata
        "lon2": rng.uniform(0, 1, n_samples),  # metadata
    }
    target = columns["Chl-a"]

    feats = rng.normal(0, 1, (n_samples, n_features))
    # Tie the first three features to the target.
    feats[:, 0] += target / 10
    feats[:, 1] += target / 20
    feats[:, 2] -= target / 15

    for i in range(n_features):
        columns[f"feat_{i}"] = feats[:, i]

    csv_path = Path(tempfile.mkdtemp()) / "train.csv"
    pd.DataFrame(columns).to_csv(csv_path, index=False)
    return csv_path
|
||||
|
||||
|
||||
def test_load_train_df():
    """Check that the loader returns every row and column of a synthetic CSV."""
    print("== test_load_train_df ==")
    csv_path = make_synthetic_csv(n_samples=50)
    loaded = _load_train_df(str(csv_path))
    # 5 leading non-feature columns + 8 generated features.
    assert loaded.shape == (50, 5 + 8), f"shape={loaded.shape}"
    leading = list(loaded.columns[:6])
    print(f" shape={loaded.shape}, columns[:6]={leading}")
    print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_int_and_str():
    """Check that an int index and the matching column name resolve to the same position."""
    print("== test_resolve_feature_start (int + str) ==")
    frame = _load_train_df(str(make_synthetic_csv()))
    idx_int = _resolve_feature_start(frame, 5)
    idx_str = _resolve_feature_start(frame, "feat_0")
    both_are_five = idx_int == 5 and idx_str == 5
    assert both_are_five, f"int={idx_int}, str={idx_str}"
    print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
    print(" PASS")
|
||||
|
||||
|
||||
def test_resolve_feature_start_str_miss():
    """Check that an unknown column name makes ``_resolve_feature_start`` raise ValueError.

    Raises:
        AssertionError: if no ValueError is raised. The earlier version only
            printed "FAIL" and then went on to print "PASS".
    """
    print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
    p = make_synthetic_csv()
    df = _load_train_df(str(p))
    try:
        _resolve_feature_start(df, "not_exist")
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_resolve_feature_start 未对不存在的列名抛 ValueError")
    print(" PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_all_types():
    """Check that each of the five model types yields a two-step scaler -> model pipeline."""
    print("== test_get_model_pipeline (5 种 model_type) ==")
    model_types = ("RF", "SVR", "LinearRegression", "KNN", "PLS")
    for model_type in model_types:
        pipe = _get_model_pipeline(model_type, {})
        assert len(pipe.steps) == 2
        first_name, second_name = pipe.steps[0][0], pipe.steps[1][0]
        assert first_name == "scaler"
        assert second_name == "model"
    print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
    print(" PASS")
|
||||
|
||||
|
||||
def test_get_model_pipeline_bad_type():
    """Check that an unsupported model type makes ``_get_model_pipeline`` raise ValueError.

    Raises:
        AssertionError: if no ValueError is raised. The earlier version only
            printed "FAIL" and then went on to print "PASS".
    """
    print("== test_get_model_pipeline (坏 model_type) ==")
    try:
        _get_model_pipeline("XGBoost", {})
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_get_model_pipeline 未对不支持的 model_type 抛 ValueError")
    print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_rf_end_to_end():
    """Train an RF on synthetic data, then reload the saved joblib and compare metadata."""
    import time

    import joblib

    print("== test_run_train_sync (RF 端到端) ==")
    csv_path = make_synthetic_csv(n_samples=200)
    out_path = Path(tempfile.mkdtemp()) / "model.joblib"
    rf_params = {"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1}

    started = time.time()
    metadata = _run_train_sync(
        model_type="RF",
        target="Chl-a",
        train_data_path=str(csv_path),
        feature_start=5,
        params=rf_params,
        output_model_path=out_path,
    )
    dt = time.time() - started

    assert out_path.exists(), f"joblib 未落盘: {out_path}"
    print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
    print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
    print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
    print(f" 耗时 {dt:.2f}s")

    # Round trip: the artifact must hold both the fitted model and its metadata.
    saved = joblib.load(out_path)
    has_both_keys = "model" in saved and "metadata" in saved
    assert has_both_keys, f"joblib 双 key 缺失: {saved.keys()}"
    assert hasattr(saved["model"], "predict")
    assert saved["metadata"]["test_r2"] == metadata["test_r2"]
    print(" joblib 加载 OK, 含 'model' 和 'metadata' 双 key")
    print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_linearregression_fast():
    """Train a LinearRegression on synthetic data and require a test R² above 0.3.

    The printed expectation now states the threshold that is actually
    asserted; it previously said ``>= 0.4`` while the assertion used ``> 0.3``.
    """
    print("== test_run_train_sync (LinearRegression 快速路径) ==")
    p = make_synthetic_csv(n_samples=150)
    out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
    metadata = _run_train_sync(
        model_type="LinearRegression",
        target="Chl-a",
        train_data_path=str(p),
        feature_start=5,
        params={},
        output_model_path=out_path,
    )
    print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 > 0.3)")
    assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
    print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_csv():
    """Check that a missing CSV path makes ``_run_train_sync`` raise.

    Raises:
        AssertionError: if neither FileNotFoundError nor ValueError is raised.
            The earlier version only printed "FAIL" and then printed "PASS".
    """
    print("== test_run_train_sync (CSV 不存在) ==")
    try:
        _run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
    except (FileNotFoundError, ValueError) as e:
        print(f" 正确抛 {type(e).__name__}: {e}")
    else:
        print(" FAIL: 应抛异常")
        raise AssertionError("_run_train_sync 未对不存在的 CSV 抛异常")
    print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_bad_target():
    """Check that an unknown target column makes ``_run_train_sync`` raise ValueError.

    Raises:
        AssertionError: if no ValueError is raised. The earlier version only
            printed "FAIL" and then went on to print "PASS".
    """
    print("== test_run_train_sync (target 列不存在) ==")
    p = make_synthetic_csv()
    try:
        _run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_run_train_sync 未对不存在的 target 抛 ValueError")
    print(" PASS")
|
||||
|
||||
|
||||
def test_run_train_sync_str_feature_start():
    """Train with ``feature_start`` given as a column name and check the recorded features."""
    print("== test_run_train_sync (feature_start 用列名) ==")
    csv_path = make_synthetic_csv()
    model_file = Path(tempfile.mkdtemp()) / "str_fs.joblib"
    metadata = _run_train_sync(
        "RF", "Chl-a", str(csv_path), "feat_0", {"n_estimators": 10}, model_file
    )
    # The raw argument is stored as given; the feature list starts at that column.
    assert metadata["feature_start"] == "feat_0"
    assert metadata["n_features"] == 8
    first_feature = metadata["feature_columns"][0]
    assert first_feature == "feat_0"
    print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
    print(" PASS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the smoke tests in order; the first uncaught exception aborts the run.
    smoke_tests = (
        test_load_train_df,
        test_resolve_feature_start_int_and_str,
        test_resolve_feature_start_str_miss,
        test_get_model_pipeline_all_types,
        test_get_model_pipeline_bad_type,
        test_run_train_sync_rf_end_to_end,
        test_run_train_sync_linearregression_fast,
        test_run_train_sync_bad_csv,
        test_run_train_sync_bad_target,
        test_run_train_sync_str_feature_start,
    )
    for smoke_test in smoke_tests:
        smoke_test()
    print("\n>>> ALL SMOKE TESTS PASSED")
|
||||
222
new/app/api/endpoints.py
Normal file
222
new/app/api/endpoints.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""
|
||||
API 路由集合
|
||||
============
|
||||
|
||||
把业务接口统一收口到 APIRouter,再由 main.py 通过 include_router 挂载。
|
||||
|
||||
当前包含的接口:
|
||||
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
|
||||
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
|
||||
GET /api/tasks/{task_id} 查询指定任务的状态与结果
|
||||
|
||||
派发链:
|
||||
POST /api/process/deglint
|
||||
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
|
||||
└─ get_remover(method) 从注册表拿到算法类
|
||||
└─ remover.process(input_zarr, output_zarr, **params)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
# 算法注册表 API
|
||||
from app.core.algorithms import get_remover, list_removers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
|
||||
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(tags=["deglint"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求 / 响应数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class DeglintRequest(BaseModel):
    """Request body for POST /api/process/deglint."""

    # Name of the glint-removal algorithm; required.
    method: str = Field(
        ...,
        description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
        examples=["kutser"],
    )
    # Hyper-parameters forwarded as keyword arguments to the algorithm's process().
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "传递给算法 process() 的超参数字典,例如 "
            "Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
            "Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
        ),
        examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
    )
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
    """Response returned immediately after a task has been submitted."""

    task_id: str
    # The submit endpoint in this file always returns "PENDING" here.
    status: str
|
||||
|
||||
|
||||
class AlgorithmListResponse(BaseModel):
    """Response body for GET /api/algorithms."""

    # Names of the registered algorithms.
    algorithms: list
    # Number of entries in ``algorithms``.
    count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 后台任务执行器(真实派发链)
|
||||
# ---------------------------------------------------------------------------
|
||||
# 注意:这里使用 async def。
|
||||
# FastAPI / Starlette 的 BackgroundTasks 支持 async function,
|
||||
# 会在响应返回后自动 await 它,不影响主请求链路。
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_glint_removal_task(
    task_id: str,
    method: str,
    params: Dict[str, Any],
) -> None:
    """Run one glint-removal task in the background and record its outcome.

    The algorithm class is looked up by ``method``, instantiated, and its
    ``process()`` coroutine awaited. The task record moves
    PENDING -> PROCESSING -> SUCCESS, or to FAILED with ``error`` and
    ``traceback`` filled in when anything inside the run raises.

    Args:
        task_id: Identifier of a task record that must already exist.
        method: Registered algorithm name.
        params: Keyword arguments forwarded to ``process()``.
    """
    # Nothing to do if the submit step never stored a record for this id.
    if await get_task(task_id) is None:
        print(f"[{task_id}] 任务不存在, 跳过")
        return

    # Temporary hard-coded IO locations.
    # TODO: replace with paths supplied by the data-management layer.
    input_zarr_path = "./data/temp_in.zarr"
    output_zarr_path = f"./data/{task_id}_out.zarr"

    await update_task(
        task_id,
        status="PROCESSING",
        updated_at=datetime.now().isoformat(),
    )
    print(f"[{task_id}] 开始处理 method={method} params={params}")

    try:
        # A failed lookup raises and is handled by the except branch below.
        remover = get_remover(method)()
        await remover.process(input_zarr_path, output_zarr_path, **params)

        success_fields = {
            "status": "SUCCESS",
            "output_zarr_path": output_zarr_path,
            "error": None,
            "updated_at": datetime.now().isoformat(),
        }
        await update_task(task_id, **success_fields)
        print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
    except Exception as exc:  # noqa: BLE001 - top-level guard: never fail silently
        failure_fields = {
            "status": "FAILED",
            "output_zarr_path": None,
            "error": f"{type(exc).__name__}: {exc}",
            "traceback": traceback.format_exc(),
            "updated_at": datetime.now().isoformat(),
        }
        await update_task(task_id, **failure_fields)
        print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /algorithms
|
||||
# ---------------------------------------------------------------------------
|
||||
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/algorithms", response_model=AlgorithmListResponse)
async def list_registered_algorithms() -> Dict[str, Any]:
    """Return the names of all registered glint-removal algorithms and their count."""
    registered = list_removers()
    algorithm_names = list(registered.keys())
    response: Dict[str, Any] = {
        "algorithms": algorithm_names,
        "count": len(algorithm_names),
    }
    return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /process/deglint
|
||||
# ---------------------------------------------------------------------------
|
||||
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
|
||||
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
async def submit_deglint(
    payload: DeglintRequest,
    background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
    """Register a glint-removal task and return its id immediately.

    A PENDING record is stored first, then the executor is queued on
    ``background_tasks``.

    Args:
        payload: Algorithm name and its hyper-parameters.
        background_tasks: FastAPI background task queue for this request.

    Returns:
        ``{"task_id": ..., "status": "PENDING"}``.
    """
    task_id = str(uuid.uuid4())

    # Take one timestamp so a new record has created_at == updated_at;
    # two separate now() calls gave slightly different values.
    now = datetime.now().isoformat()

    # output_zarr_path / error / traceback are filled in by the executor.
    await set_task(
        task_id,
        {
            "task_id": task_id,
            "method": payload.method,
            "params": payload.params,
            "status": "PENDING",
            "output_zarr_path": None,
            "error": None,
            "traceback": None,
            "created_at": now,
            "updated_at": now,
        },
    )

    background_tasks.add_task(
        execute_glint_removal_task,
        task_id,
        payload.method,
        payload.params,
    )

    return {"task_id": task_id, "status": "PENDING"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /tasks/{task_id}
|
||||
# ---------------------------------------------------------------------------
|
||||
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
|
||||
# SUCCESS 表示成功(含 output_zarr_path),FAILED 表示失败(含 error / traceback)。
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> Dict[str, Any]:
    """Return the stored record for ``task_id``, or respond 404 when there is none."""
    record = await get_task(task_id)
    if record is not None:
        # The stored dict is returned as-is and serialised by FastAPI.
        return record
    # No record found for this id.
    raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
|
||||
786
new/app/api/modeling.py
Normal file
786
new/app/api/modeling.py
Normal file
@ -0,0 +1,786 @@
|
||||
"""
|
||||
app/api/modeling.py
|
||||
===================
|
||||
|
||||
机器学习与水质反演相关的 API 路由。
|
||||
|
||||
接口(最终路径, 挂载后):
|
||||
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
|
||||
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
|
||||
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
|
||||
|
||||
设计要点
|
||||
--------
|
||||
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
|
||||
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
|
||||
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
|
||||
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
|
||||
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor)
|
||||
* xr.open_zarr 延迟打开影像, NaN 填 0
|
||||
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
|
||||
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
|
||||
- /train 已接入真实 sklearn + pandas 训练管线:
|
||||
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
|
||||
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
|
||||
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
|
||||
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
|
||||
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
|
||||
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
|
||||
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVR
|
||||
|
||||
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
|
||||
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint)
|
||||
from app.core.task_store import get_task, set_task, update_task
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 路由实例
|
||||
# ---------------------------------------------------------------------------
|
||||
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
|
||||
# 最终完整路径由 main.py 挂载时再补 /api。
|
||||
# ---------------------------------------------------------------------------
|
||||
router = APIRouter(prefix="/modeling", tags=["modeling"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 数据模型
|
||||
# ---------------------------------------------------------------------------
|
||||
class TrainRequest(BaseModel):
    """Request body for POST /api/modeling/train.

    The schema text is aligned with the training pipeline in this module:
    the data file is read with ``pd.read_csv`` and only the model types in
    ``_MODEL_CLASS_REGISTRY`` are accepted. The previous text advertised
    'XGBoost' / 'MLP' and a zarr training path, neither of which the
    pipeline supports.
    """

    # Key into the model class registry; any other value fails the task.
    model_type: str = Field(
        ...,
        description=(
            "模型类型, 可选 'RF' (随机森林) / 'SVR' (支持向量回归) / "
            "'LinearRegression' (线性回归) / 'KNN' (K 近邻) / 'PLS' (偏最小二乘)"
        ),
        examples=["RF", "SVR"],
    )
    # Name of the CSV column used as the regression label.
    target: str = Field(
        ...,
        description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
        examples=["Chl-a", "TSS", "CDOM"],
    )
    # Path of the CSV table read by the training pipeline.
    train_data_path: str = Field(
        ...,
        description="训练数据集的 CSV 路径(表格需包含 target 列与特征列)",
        examples=["./data/train.csv"],
    )
    # Either a positional column index or a column name.
    feature_start: Union[int, str] = Field(
        default=4,
        description=(
            "特征列起始位置. 表格布局假定为 "
            "[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
            "可传 int 列索引(如 4)或 str 列名(如 '374.285' 波长起点)。"
            "默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
        ),
        examples=[4, "374.285"],
    )
    # Passed through as keyword arguments to the estimator constructor.
    params: Dict[str, Any] = Field(
        default_factory=dict,
        description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
        examples=[{"n_estimators": 100, "max_depth": 20}],
    )
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
    """Request body for POST /api/modeling/predict."""

    # Identifier of a previously trained model; required.
    model_id: str = Field(
        ...,
        description="已训练模型的 ID(由 /api/modeling/train 返回或 /api/modeling/models 列出)",
    )
    # Zarr store holding the scene to run inference on; required.
    input_zarr_path: str = Field(
        ...,
        description="待推断影像的 zarr 路径",
        examples=["./data/scene.zarr"],
    )
    # Optional destination; None lets the backend choose one.
    output_zarr_path: Optional[str] = Field(
        default=None,
        description=(
            "输出 zarr 路径, 缺省时由后端按规则生成 "
            "(如 ./data/{model_id}_{input_stem}_pred.zarr)"
        ),
    )
|
||||
|
||||
|
||||
class TaskAcceptedResponse(BaseModel):
    """Response returned immediately after a train/predict task is submitted."""

    task_id: str
    # Expected to be "PENDING" at submission time.
    status: str
    # "train" or "predict", so the client can tell the task types apart.
    kind: str
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
    """Metadata for one trained model (an element of GET /api/modeling/models)."""

    model_id: str
    model_type: str
    target: str
    params: Dict[str, Any]
    # Location of the joblib file.
    path: str
    created_at: str
    # Id of the training task that produced this model.
    train_task_id: str
|
||||
|
||||
|
||||
class ModelListResponse(BaseModel):
    """Response body for GET /api/modeling/models."""

    # One entry per listed model.
    models: List[ModelInfo]
    # Number of listed models.
    count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_id -> ModelInfo 字典
|
||||
# 读多写少, 用一个普通 dict 足够(CPython GIL 兜底)。
|
||||
# 写时(训练完成时)只发生一次, 无并发风险。
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory registry: model_id -> model record dict.
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}


def _register_model(record: Dict[str, Any]) -> None:
    """Store ``record`` in the in-memory registry under its ``"model_id"`` key.

    An existing entry with the same id is replaced.
    """
    model_id = record["model_id"]
    _MODEL_REGISTRY[model_id] = record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 训练管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设计要点 (与推断管线一致):
|
||||
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
|
||||
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
|
||||
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# model_type string -> sklearn estimator class.
# Keys are matched exactly and are not all uppercase (e.g. "LinearRegression").
# Only the class is stored here; constructor arguments come from ``params``.
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
    "RF": RandomForestRegressor,
    "SVR": SVR,
    "LinearRegression": LinearRegression,
    "KNN": KNeighborsRegressor,
    "PLS": PLSRegression,
}
|
||||
|
||||
|
||||
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
    """Build a ``StandardScaler -> estimator`` pipeline for ``model_type``.

    The scaler is part of the pipeline, so ``pipeline.predict(X)`` applies
    the scaling fitted during training.

    Args:
        model_type: Key into ``_MODEL_CLASS_REGISTRY``.
        params: Keyword arguments for the estimator constructor; ``None`` or
            an empty dict means the estimator defaults.

    Raises:
        ValueError: If ``model_type`` is not registered.
    """
    if model_type not in _MODEL_CLASS_REGISTRY:
        raise ValueError(
            f"不支持的 model_type='{model_type}', "
            f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
        )
    estimator_kwargs = params or {}
    estimator = _MODEL_CLASS_REGISTRY[model_type](**estimator_kwargs)
    steps = [("scaler", StandardScaler()), ("model", estimator)]
    return Pipeline(steps)
|
||||
|
||||
|
||||
def _load_train_df(csv_path: str) -> pd.DataFrame:
    """Read a CSV training table, turning empty, blank and NULL-like cells into NaN.

    Missing markers are listed explicitly for ``read_csv``; a regex pass
    afterwards catches cells that hold only whitespace.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        ValueError: If the file is empty.
    """
    missing_markers = ["", " ", "NaN", "nan", "NULL", "null"]
    try:
        raw = pd.read_csv(csv_path, na_values=missing_markers)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
    except pd.errors.EmptyDataError as exc:
        raise ValueError(f"训练数据文件为空: {csv_path}") from exc
    # Second pass: any remaining whitespace-only cell becomes NaN.
    return raw.replace(r"^\s*$", np.nan, regex=True)
|
||||
|
||||
|
||||
def _resolve_feature_start(
    df: pd.DataFrame,
    feature_start: Union[int, str],
) -> int:
    """Resolve ``feature_start`` (column index or column name) to an int column index.

    Args:
        df: Training table whose columns are being addressed.
        feature_start: Positional index, or the name of the first feature column.

    Returns:
        The positional index of the first feature column.

    Raises:
        ValueError: If the name is not a column, the name matches more than
            one column, or the index lies outside the table. An index at or
            past the column count used to be returned unchecked, leaving the
            caller with zero feature columns.
    """
    n_cols = len(df.columns)

    if isinstance(feature_start, str):
        if feature_start not in df.columns:
            raise ValueError(
                f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
            )
        loc = df.columns.get_loc(feature_start)
        # get_loc returns a slice or boolean mask when the label is duplicated.
        if not isinstance(loc, (int, np.integer)):
            raise ValueError(
                f"feature_start='{feature_start}' 对应多个同名列, 无法确定特征起点"
            )
        return int(loc)

    idx = int(feature_start)
    # Negative indices within range keep their previous slice-from-end meaning.
    if not -n_cols <= idx < n_cols:
        raise ValueError(
            f"feature_start={feature_start} 超出列范围, CSV 共 {n_cols} 列"
        )
    return idx
|
||||
|
||||
|
||||
def _regression_metrics(prefix: str, y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return R² / RMSE / MAE for one split, keyed ``{prefix}_r2`` / ``_rmse`` / ``_mae``."""
    return {
        f"{prefix}_r2": float(r2_score(y_true, y_pred)),
        f"{prefix}_rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
        f"{prefix}_mae": float(mean_absolute_error(y_true, y_pred)),
    }


def _run_train_sync(
    model_type: str,
    target: str,
    train_data_path: str,
    feature_start: Union[int, str],
    params: Optional[Dict[str, Any]],
    output_model_path: Path,
) -> Dict[str, Any]:
    """Run the full synchronous training flow and persist the fitted pipeline.

    Steps: read CSV -> drop rows whose target is NaN -> slice feature columns
    -> random 80/20 split (``random_state=42``) -> fit
    ``Pipeline(StandardScaler + model)`` -> score test and train splits ->
    ``joblib.dump({"model", "metadata"})``.

    Args:
        model_type: Key into ``_MODEL_CLASS_REGISTRY``.
        target: Name of the label column.
        train_data_path: Path of the CSV training table.
        feature_start: Index or name of the first feature column; every
            column from there to the end is used as a feature.
        params: Estimator constructor keyword arguments (may be ``None``).
        output_model_path: Destination joblib file; parent directories are
            created when missing.

    Returns:
        The metadata dict that is also stored in the joblib file, including
        ``test_r2`` / ``test_rmse`` / ``test_mae``, the train-split metrics,
        ``n_features`` and ``feature_columns``.

    Raises:
        FileNotFoundError: If the CSV does not exist.
        ValueError: If the target column is missing, no rows remain after
            dropping NaN targets, the feature slice is empty, the target
            column falls inside the feature slice, or ``model_type`` is not
            supported.
    """
    df = _load_train_df(train_data_path)

    if target not in df.columns:
        raise ValueError(
            f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
        )

    # 1) Cleaning: only rows with a NaN target are dropped.
    df = df[df[target].notna()].copy()
    if df.empty:
        raise ValueError("target 剔除 NaN 后无样本, 终止训练")

    # 2) Feature slicing.
    feature_start_idx = _resolve_feature_start(df, feature_start)
    feature_columns = list(df.columns[feature_start_idx:])
    if not feature_columns:
        raise ValueError(f"feature_start={feature_start!r} 之后没有任何特征列")
    # Guard against label leakage: a target inside the feature slice would
    # let the model read the answer and report a meaningless near-perfect score.
    if target in feature_columns:
        raise ValueError(
            f"target='{target}' 落在特征列范围内 (feature_start={feature_start!r}), "
            "会造成标签泄漏; 请调整 feature_start"
        )

    X = df.iloc[:, feature_start_idx:].astype(np.float64)
    y = df[target].astype(np.float64).values

    # 3) Fixed random split.
    X_train, X_test, y_train, y_test = train_test_split(
        X.values,
        y,
        test_size=0.2,
        random_state=42,
    )

    # 4) Build and fit the pipeline.
    pipeline = _get_model_pipeline(model_type, params)
    pipeline.fit(X_train, y_train)

    # 5) Score the held-out split and the training split.
    test_metrics = _regression_metrics("test", y_test, pipeline.predict(X_test))
    train_metrics = _regression_metrics("train", y_train, pipeline.predict(X_train))

    # 6) Metadata stored next to the model and returned to the caller.
    metadata: Dict[str, Any] = {
        "model_type": model_type,
        "target": target,
        "feature_start": feature_start,
        "feature_columns": feature_columns,
        "n_features": int(X.shape[1]),
        "n_samples": int(X.shape[0]),
        "train_size": int(X_train.shape[0]),
        "test_size": int(X_test.shape[0]),
        "params": dict(params or {}),
        **test_metrics,
        **train_metrics,
        "split_method": "random",
        "trained_at": datetime.now().isoformat(),
    }

    # 7) Persist; the target directory may not exist yet.
    output_model_path = Path(output_model_path)
    output_model_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(
        {"model": pipeline, "metadata": metadata},
        output_model_path,
    )

    return metadata
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 推断管线的模块级辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 设计要点:
|
||||
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
|
||||
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
|
||||
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
|
||||
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
|
||||
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
|
||||
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
|
||||
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
    """Fit and return a small random forest on seeded random data.

    Used as a stand-in when no real joblib model is available, so the
    inference path can still be exercised. Its predictions carry no meaning.

    Args:
        n_features: Number of input features the dummy model must accept.
    """
    rng = np.random.default_rng(42)
    # Features are drawn before labels; this order fixes the fitted model.
    fake_features = rng.random((200, n_features))
    fake_labels = rng.random(200)
    forest = RandomForestRegressor(
        n_estimators=10,
        max_depth=5,
        random_state=0,
        n_jobs=1,
    )
    forest.fit(fake_features, fake_labels)
    return forest
|
||||
|
||||
|
||||
def _load_model(path: str, n_features: int) -> Any:
    """Load a trained model from ``path``, falling back to a dummy model.

    The training pipeline in this module saves
    ``{"model": pipeline, "metadata": ...}``, so a loaded dict with a
    ``"model"`` key is unwrapped to the estimator. Returning the dict itself
    would make ``model.predict`` fail. Any other loaded object is returned
    unchanged.

    Args:
        path: Location of the joblib file.
        n_features: Feature count for the dummy fallback model.

    Returns:
        The loaded estimator, or a dummy random forest when the file is
        missing, empty or cannot be loaded.
    """
    p = Path(path)
    if not (p.is_file() and p.stat().st_size > 0):
        print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
        return _make_dummy_model(n_features)

    try:
        print(f"[model] 从磁盘加载: {path}")
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # files written by this service.
        loaded = joblib.load(path)
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort fallback
        print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
        return _make_dummy_model(n_features)

    if isinstance(loaded, dict) and "model" in loaded:
        return loaded["model"]
    return loaded
|
||||
|
||||
|
||||
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
    """Run ``model.predict`` on one chunk of spectra and return a float32 map.

    Parameters
    ----------
    spectra_3d : np.ndarray
        Array whose last axis is the band axis, typically
        ``(y_chunk, x_chunk, n_bands)``. Any number of leading axes is accepted.
    model : object
        Fitted estimator whose ``predict`` takes ``(n_samples, n_features)``
        and returns one value per sample, shaped ``(n_samples,)`` or
        ``(n_samples, 1)``.

    Returns
    -------
    np.ndarray
        float32 array shaped like the input without its last axis.
    """
    *leading, n_bands = spectra_3d.shape
    spatial_shape = tuple(leading)
    n_pixels = int(np.prod(spatial_shape, dtype=np.int64))

    # An empty chunk has nothing to predict, and estimators may reject a
    # zero-sample input, so return an empty map without calling the model.
    if n_pixels == 0:
        return np.empty(spatial_shape, dtype=np.float32)

    # One row of spectrum per pixel.
    flat = spectra_3d.reshape(n_pixels, n_bands)
    # asarray tolerates estimators that return a list or a column vector.
    pred = np.asarray(model.predict(flat))
    # Back to the spatial layout; float32 halves the memory of float64.
    return pred.reshape(spatial_shape).astype(np.float32, copy=False)
|
||||
|
||||
|
||||
def _run_predict_sync(
    model: Any,
    model_id: str,
    input_zarr_path: str,
    output_zarr_path: str,
) -> None:
    """Run chunked inference over a zarr scene and write the prediction map.

    Steps:
      1) open the input lazily with ``xr.open_zarr``
      2) pick up ``water_mask`` if present, otherwise treat every pixel as water
      3) replace NaN reflectance with 0
      4) apply ``_predict_block`` along the band axis via ``xr.apply_ufunc``
      5) set masked-out pixels to NaN
      6) write the result with ``to_zarr(mode="w", compute=True)``

    Raises:
        KeyError: If the input has no ``reflectance`` variable.
    """
    # 1. Lazy open.
    source = xr.open_zarr(input_zarr_path, chunks="auto")
    if "reflectance" not in source.data_vars:
        raise KeyError(
            f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(source.data_vars)}"
        )
    reflectance = source["reflectance"]  # NOTE(review): presumably dims (y, x, band) - confirm
    n_bands = reflectance.sizes["band"]

    # 2. Water mask, looked up in both data variables and coordinates.
    has_mask = "water_mask" in source.data_vars or "water_mask" in source.coords
    if has_mask:
        water_mask = source["water_mask"].astype(bool)
    else:
        water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)

    # 3. NaN -> 0 before the values reach model.predict.
    filled = reflectance.fillna(0.0)

    # 4. "band" is the only core dimension, so each call to _predict_block
    #    receives an array with band as its last axis.
    prediction: xr.DataArray = xr.apply_ufunc(
        _predict_block,
        filled,
        kwargs={"model": model},
        input_core_dims=[["band"]],
        output_core_dims=[[]],
        dask="parallelized",
        output_dtypes=[np.float32],
        dask_gufunc_kwargs={"allow_rechunk": True},
        vectorize=False,
    )

    # 5. Pixels outside the mask become NaN.
    prediction = prediction.where(water_mask, np.nan)

    # 6. Wrap in a Dataset with provenance attributes, keep the y/x coordinates.
    provenance = {
        "model_id": model_id,
        "input_zarr_path": input_zarr_path,
        "n_bands": n_bands,
        "created_at": datetime.now().isoformat(),
    }
    result = xr.Dataset({"prediction": prediction}, attrs=provenance)
    result = result.assign_coords(y=source["y"], x=source["x"])

    # compute=True evaluates the dask graph while writing.
    result.to_zarr(output_zarr_path, mode="w", compute=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 后台任务执行器
|
||||
# ---------------------------------------------------------------------------
|
||||
async def execute_train_task(
    task_id: str,
    model_type: str,
    target: str,
    train_data_path: str,
    feature_start: Union[int, str],
    params: Dict[str, Any],
) -> None:
    """
    Background executor for a model-training task.

    Steps:
        1. Return immediately when ``get_task`` has no record for ``task_id``.
        2. Mark the task as PROCESSING.
        3. Generate a ``model_id`` and the matching ``.joblib`` path.
        4. Run ``_run_train_sync`` via ``asyncio.to_thread`` so the
           synchronous training call does not block the event loop.
        5. On success, register the model and store SUCCESS together with
           the metrics returned by the training call.
        6. On any exception, store FAILED with the error text and traceback.

    Args:
        task_id: Identifier of the task record to update.
        model_type: Model type, forwarded to ``_run_train_sync``.
        target: Target name, forwarded to ``_run_train_sync``.
        train_data_path: Training data location, forwarded as-is.
        feature_start: Feature start specifier (index or name), forwarded
            as-is.
        params: Model parameters; a copy is stored in the registry entry.

    Returns:
        None. The outcome is reported through the task store.
    """
    if await get_task(task_id) is None:
        print(f"[{task_id}] 训练任务不存在, 跳过")
        return

    await update_task(
        task_id,
        status="PROCESSING",
        updated_at=datetime.now().isoformat(),
    )
    start_message = (
        f"[{task_id}] 开始训练 model_type={model_type} target={target} "
        f"train_data_path={train_data_path} feature_start={feature_start}"
    )
    print(start_message)

    # The first 12 hex characters of a uuid4 are used as the model id suffix.
    id_suffix = uuid.uuid4().hex[:12]
    model_id = f"model_{id_suffix}"
    model_path = Path(f"./data/models/{model_id}.joblib")

    try:
        # The synchronous training call runs in a worker thread.
        metadata = await asyncio.to_thread(
            _run_train_sync,
            model_type,
            target,
            train_data_path,
            feature_start,
            params,
            model_path,
        )

        # Entry for the in-memory registry that /predict looks model_id up in.
        registry_entry = {
            "model_id": model_id,
            "model_type": model_type,
            "target": target,
            "params": dict(params or {}),
            "path": str(model_path),
            "feature_start": feature_start,
            "n_features": metadata["n_features"],
            "test_r2": metadata["test_r2"],
            "test_rmse": metadata["test_rmse"],
            "test_mae": metadata["test_mae"],
            "created_at": datetime.now().isoformat(),
            "train_task_id": task_id,
        }
        _register_model(registry_entry)

        # Copy the metrics onto the task record so polling clients see them.
        await update_task(
            task_id,
            status="SUCCESS",
            model_id=model_id,
            model_path=str(model_path),
            test_r2=metadata["test_r2"],
            test_rmse=metadata["test_rmse"],
            test_mae=metadata["test_mae"],
            n_features=metadata["n_features"],
            n_samples=metadata["n_samples"],
            error=None,
            traceback=None,
            updated_at=datetime.now().isoformat(),
        )
        done_message = (
            f"[{task_id}] 训练完成 -> model_id={model_id} "
            f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
        )
        print(done_message)

    except Exception as exc:  # noqa: BLE001
        # model_id / model_path are reset to None on the failed record.
        failure_text = f"{type(exc).__name__}: {exc}"
        await update_task(
            task_id,
            status="FAILED",
            model_id=None,
            model_path=None,
            error=failure_text,
            traceback=traceback.format_exc(),
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 训练失败 -> {failure_text}")
|
||||
|
||||
|
||||
async def execute_predict_task(
    task_id: str,
    model_id: str,
    input_zarr_path: str,
    output_zarr_path: Optional[str],
) -> None:
    """
    Background executor for an inference task.

    Steps:
        1. Return immediately when ``get_task`` has no record for ``task_id``.
        2. Fail the task when ``model_id`` is not in ``_MODEL_REGISTRY``.
        3. Derive ``output_zarr_path`` from the input name when not given.
        4. Mark the task as PROCESSING.
        5. Probe the band count of the input, load the model and run
           ``_run_predict_sync``; all three run in worker threads.
        6. Store SUCCESS with the output path, or FAILED with the error text
           and traceback.

    Args:
        task_id: Identifier of the task record to update.
        model_id: Registry key of the model to use.
        input_zarr_path: Path of the input zarr store; it must contain a
            ``reflectance`` variable with a ``band`` dimension.
        output_zarr_path: Destination zarr path, or None to use
            ``./data/{model_id}_{input stem}_pred.zarr``.

    Returns:
        None. The outcome is reported through the task store.
    """
    record = await get_task(task_id)
    if record is None:
        print(f"[{task_id}] 推断任务不存在, 跳过")
        return

    # Check the model up front so the task fails with a clear message.
    model_meta = _MODEL_REGISTRY.get(model_id)
    if model_meta is None:
        await update_task(
            task_id,
            status="FAILED",
            error=f"model_id 不存在: {model_id}",
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
        return

    # Default output name: last path component (either separator style)
    # with ".zarr" removed.
    if output_zarr_path is None:
        stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
        stem = stem.replace(".zarr", "")
        output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"

    await update_task(
        task_id,
        status="PROCESSING",
        updated_at=datetime.now().isoformat(),
    )
    print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")

    def _probe_band_count() -> int:
        """Return the size of the ``band`` dimension of ``reflectance``."""
        # The context manager closes the store on every path, including the
        # KeyError raised below.
        with xr.open_zarr(input_zarr_path, chunks="auto") as ds_probe:
            if "reflectance" not in ds_probe.data_vars:
                raise KeyError(
                    f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
                )
            return ds_probe["reflectance"].sizes["band"]

    try:
        # Opening the store and loading the model are synchronous calls, so
        # they run in worker threads like the prediction itself.
        n_bands = await asyncio.to_thread(_probe_band_count)
        model = await asyncio.to_thread(
            _load_model, model_meta["path"], n_features=n_bands
        )

        await asyncio.to_thread(
            _run_predict_sync,
            model,
            model_id,
            input_zarr_path,
            output_zarr_path,
        )

        await update_task(
            task_id,
            status="SUCCESS",
            output_zarr_path=output_zarr_path,
            model_id=model_id,
            error=None,
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")

    except Exception as exc:  # noqa: BLE001
        tb_text = traceback.format_exc()
        await update_task(
            task_id,
            status="FAILED",
            output_zarr_path=None,
            error=f"{type(exc).__name__}: {exc}",
            traceback=tb_text,
            updated_at=datetime.now().isoformat(),
        )
        print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
        print(tb_text)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/train
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/train", response_model=TaskAcceptedResponse)
async def submit_train(
    payload: TrainRequest,
    background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
    """
    Accept a training request and return its task id immediately.

    A PENDING record is written to the task store, then
    ``execute_train_task`` is scheduled as a background task.

    Args:
        payload: Training request body.
        background_tasks: FastAPI background task queue.

    Returns:
        Dict with ``task_id``, ``status`` ("PENDING") and ``kind`` ("train").
    """
    task_id = str(uuid.uuid4())

    # Fields copied from the request plus the initial status.
    request_fields = {
        "task_id": task_id,
        "kind": "train",
        "model_type": payload.model_type,
        "target": payload.target,
        "train_data_path": payload.train_data_path,
        "feature_start": payload.feature_start,
        "params": payload.params,
        "status": "PENDING",
    }
    # Result fields start as None until the executor fills them in.
    empty_results = dict.fromkeys(
        (
            "model_id",
            "model_path",
            "test_r2",
            "test_rmse",
            "test_mae",
            "n_features",
            "n_samples",
            "error",
            "traceback",
        )
    )
    initial_record = {
        **request_fields,
        **empty_results,
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
    }
    await set_task(task_id, initial_record)

    background_tasks.add_task(
        execute_train_task,
        task_id,
        payload.model_type,
        payload.target,
        payload.train_data_path,
        payload.feature_start,
        payload.params,
    )
    return {"task_id": task_id, "status": "PENDING", "kind": "train"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/modeling/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.get("/models", response_model=ModelListResponse)
async def list_trained_models() -> Dict[str, Any]:
    """
    List the models currently held in the in-memory ``_MODEL_REGISTRY``.

    Returns:
        Dict with ``models`` (registry entries, newest ``created_at`` first;
        entries without that key sort as an empty string) and ``count``.
    """
    newest_first = sorted(
        _MODEL_REGISTRY.values(),
        key=lambda entry: entry.get("created_at", ""),
        reverse=True,
    )
    return {"models": newest_first, "count": len(newest_first)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/modeling/predict
|
||||
# ---------------------------------------------------------------------------
|
||||
@router.post("/predict", response_model=TaskAcceptedResponse)
async def submit_predict(
    payload: PredictRequest,
    background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
    """
    Accept an inference request and return its task id immediately.

    A PENDING record is written to the task store, then
    ``execute_predict_task`` is scheduled as a background task.

    Args:
        payload: Inference request body.
        background_tasks: FastAPI background task queue.

    Returns:
        Dict with ``task_id``, ``status`` ("PENDING") and ``kind``
        ("predict").
    """
    task_id = str(uuid.uuid4())

    initial_record = {
        "task_id": task_id,
        "kind": "predict",
        "model_id": payload.model_id,
        "input_zarr_path": payload.input_zarr_path,
        "output_zarr_path": payload.output_zarr_path,
        "status": "PENDING",
        **dict.fromkeys(("error", "traceback")),
        "created_at": datetime.now().isoformat(),
        "updated_at": datetime.now().isoformat(),
    }
    await set_task(task_id, initial_record)

    background_tasks.add_task(
        execute_predict_task,
        task_id,
        payload.model_id,
        payload.input_zarr_path,
        payload.output_zarr_path,
    )
    return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
|
||||
40
new/app/core/algorithms/__init__.py
Normal file
40
new/app/core/algorithms/__init__.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""
|
||||
去耀斑算法包
|
||||
============
|
||||
|
||||
通过「注册表 + 策略模式」组织不同的去耀斑算法。
|
||||
所有具体算法都应继承 BaseGlintRemover,并使用 @register_glint_remover
|
||||
装饰器把算法名和实现类绑定。
|
||||
|
||||
外部调用约定
|
||||
------------
|
||||
1. 所有算法子模块必须在本 __init__ 中显式 import,
|
||||
这样装饰器才会被执行、注册表才会被填满。
|
||||
2. 上层(endpoints、worker)只允许:
|
||||
from app.core.algorithms import get_remover
|
||||
来获取算法类,不要直接 import 具体实现类,
|
||||
保持调度层与具体算法的解耦。
|
||||
"""
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import (
|
||||
get_remover,
|
||||
list_removers,
|
||||
register_glint_remover,
|
||||
unregister_glint_remover,
|
||||
)
|
||||
|
||||
# ---- 算法子模块 import 区 ----
|
||||
# 新增算法时,在这里加一行 import,确保装饰器被执行。
|
||||
from app.core.algorithms import goodman # Goodman
|
||||
from app.core.algorithms import kutser # Kutser
|
||||
# from app.core.algorithms import hedley # Hedley
|
||||
# from app.core.algorithms import sugar # SUGAR
|
||||
|
||||
__all__ = [
|
||||
"BaseGlintRemover",
|
||||
"register_glint_remover",
|
||||
"get_remover",
|
||||
"list_removers",
|
||||
"unregister_glint_remover",
|
||||
]
|
||||
85
new/app/core/algorithms/base.py
Normal file
85
new/app/core/algorithms/base.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
去耀斑算法抽象基类
|
||||
==================
|
||||
|
||||
设计目标(策略模式 Strategy Pattern)
|
||||
------------------------------------
|
||||
本模块定义了所有去耀斑算法必须遵守的标准接口。
|
||||
未来的 Kutser、Goodman、Hedley、SUGAR 等算法都将继承本基类,
|
||||
并实现统一的 process() 方法。
|
||||
|
||||
输入输出规范
|
||||
------------
|
||||
所有算法的输入与输出均统一为 **Zarr 文件路径**(字符串),
|
||||
而不是内存中的 numpy ndarray。这样做的核心收益是:
|
||||
|
||||
1. **解耦数据存储与内存计算**:
|
||||
算法只关心「从哪个 zarr 读、写到哪个 zarr」,
|
||||
至于数据最初来自 GeoTIFF / HDF5 / NetCDF / 内存数组,
|
||||
都由 IO 层负责归一化转为 zarr。
|
||||
2. **支持 Out-of-Core 计算**:
|
||||
影像往往超过内存上限,zarr 分块(chunk)天然支持按块读取,
|
||||
算法实现可以借助 dask / xarray 进行流式计算。
|
||||
3. **可缓存、可复用**:
|
||||
中间产物落盘后,下游算法(大气校正、辐射定标)能直接消费,
|
||||
避免重复 IO。
|
||||
4. **易于并行与分布式**:
|
||||
任务调度层只需把两个路径扔给 worker,无需关心数据细节。
|
||||
|
||||
约定
|
||||
----
|
||||
- 子类应实现 process(),完成「读 -> 计算 -> 写」的完整流程。
|
||||
- process() 返回 True 表示成功,False 表示失败。
|
||||
- 失败时建议抛出异常而非仅返回 False,便于上层 BackgroundTasks 捕获并写入 error 字段。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseGlintRemover(ABC):
    """
    Abstract base class for sun-glint removal algorithms.

    Concrete algorithms subclass this and implement ``process()``.
    Hyper-parameters may be taken in the subclass ``__init__`` or through
    the ``**kwargs`` of ``process()``; the data itself is addressed only by
    the two zarr paths passed to ``process()``.
    """

    # Algorithm identifier used for lookup by method name; subclasses may
    # override it.
    name: str = "base"

    @abstractmethod
    async def process(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        **kwargs: Any,
    ) -> bool:
        """
        Run glint removal from one zarr store into another.

        Parameters
        ----------
        input_zarr_path : str
            Path of the zarr store holding the input image.
        output_zarr_path : str
            Path of the zarr store the result is written to; the
            implementation is responsible for creating it.
        **kwargs : Any
            Optional algorithm-specific hyper-parameters (for example a
            reference band or a chunk size).

        Returns
        -------
        bool
            True on success, False on failure. Raising on error is
            preferred so the caller can record it on the task.
        """
        raise NotImplementedError

    def __repr__(self) -> str:  # pragma: no cover - debugging aid
        return "<{} name={!r}>".format(type(self).__name__, self.name)
|
||||
123
new/app/core/algorithms/goodman.py
Normal file
123
new/app/core/algorithms/goodman.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
app/core/algorithms/goodman.py
|
||||
===============================
|
||||
|
||||
Goodman et al. 2008 去耀斑算法的 xarray + dask 流式实现。
|
||||
|
||||
算法公式
|
||||
--------
|
||||
R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)
|
||||
|
||||
其中:
|
||||
R_raw -- 原始反射率 (y, x, band)
|
||||
R_750 -- λ=750 nm 处的反射率(红外参考波段, 远离水汽吸收)
|
||||
R_640 -- λ=640 nm 处的反射率(可见光差异波段)
|
||||
A, B -- 经验回归参数(用户可通过 params 传入, 默认全 0)
|
||||
|
||||
后处理
|
||||
------
|
||||
- 负值截断为 0(Clamp to 0)
|
||||
- 仅在水域掩膜 (water_mask) 内生效, 水外置 0
|
||||
|
||||
维度约定
|
||||
--------
|
||||
reflectance: (y, x, band), band 坐标通常为 wavelength (nm)
|
||||
water_mask : (y, x), 布尔类型, True = 水域
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 默认参数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 与原始 Goodman 2008 论文符号保持一致, 方便用户交叉对照。
|
||||
# A、B 通常通过对纯净深水区做 (R_corr - R_raw) ~ (R_640 - R_750) 回归得到;
|
||||
# 在缺乏先验知识时, 退化为 A=0, B=0 即等价于 R_corrected = clip(R_raw - R_750, 0)。
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_BAND_REF: float = 750.0 # λ_750 nm, 红外参考波段
|
||||
DEFAULT_BAND_DIFF: float = 640.0 # λ_640 nm, 可见光差异波段
|
||||
DEFAULT_A: float = 0.0 # 公式中的常数偏移项
|
||||
DEFAULT_B: float = 0.0 # 公式中的斜率项
|
||||
|
||||
|
||||
@register_glint_remover("goodman")
class GoodmanGlintRemover(BaseGlintRemover):
    """
    Goodman-style glint removal implemented with lazy xarray arithmetic.

    Formula applied per pixel and band::

        R_corrected = R_raw - R_ref + A + B * (R_diff - R_ref)

    where ``R_ref`` / ``R_diff`` are the bands nearest to ``band_ref`` /
    ``band_diff``. Negative results are clipped to 0, and pixels outside
    ``water_mask`` (when that variable exists) are set to 0.
    """

    name = "goodman"

    async def process(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        **kwargs: Any,
    ) -> bool:
        """
        Run the correction in a worker thread.

        Args:
            input_zarr_path: Input zarr store with a ``reflectance`` variable.
            output_zarr_path: Output zarr store (overwritten, ``mode="w"``).
            **kwargs: Optional ``band_ref``, ``band_diff``, ``A`` and ``B``;
                the module-level ``DEFAULT_*`` values apply otherwise.

        Returns:
            True once the output has been written.
        """
        # Band wavelengths are coerced to float, as the Kutser remover does,
        # so values arriving as strings still work with sel(method="nearest").
        band_ref: float = float(kwargs.get("band_ref", DEFAULT_BAND_REF))
        band_diff: float = float(kwargs.get("band_diff", DEFAULT_BAND_DIFF))
        A: float = kwargs.get("A", DEFAULT_A)
        B: float = kwargs.get("B", DEFAULT_B)

        # The xarray/dask work is synchronous; keep it off the event loop.
        return await asyncio.to_thread(
            self._process_sync,
            input_zarr_path,
            output_zarr_path,
            band_ref,
            band_diff,
            A,
            B,
        )

    @staticmethod
    def _process_sync(
        input_zarr_path: str,
        output_zarr_path: str,
        band_ref: float,
        band_diff: float,
        A: float,
        B: float,
    ) -> bool:
        """
        Synchronous core: read, correct, clip, mask and write.

        Raises:
            KeyError: If the input store has no ``reflectance`` variable.
        """
        # The context manager closes the store once writing has finished or
        # an error occurs.
        with xr.open_zarr(input_zarr_path, chunks="auto") as ds:
            if "reflectance" not in ds.data_vars:
                raise KeyError(
                    f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
                )
            reflectance = ds["reflectance"]  # documented layout: (y, x, band)

            # Nearest-band lookup; each selection no longer has the band
            # dimension and broadcasts against the full cube.
            R_750 = reflectance.sel(band=band_ref, method="nearest")
            R_640 = reflectance.sel(band=band_diff, method="nearest")

            # R_corr = R_raw - R_ref + A + B * (R_diff - R_ref)
            result = reflectance - R_750 + A + B * (R_640 - R_750)

            # Clamp negative values to 0.
            result = result.clip(min=0)

            # Outside the water mask the result is forced to 0; without a
            # mask the whole image is treated as water.
            if "water_mask" in ds:
                water_mask = ds["water_mask"].astype(bool)
                result = result.where(water_mask, 0)

            # Carry dataset-level and variable-level attributes over.
            out = xr.Dataset({"reflectance": result})
            if ds.attrs:
                out.attrs = dict(ds.attrs)
            if reflectance.attrs:
                out["reflectance"].attrs = dict(reflectance.attrs)

            # compute=True evaluates the lazy graph while writing.
            out.to_zarr(output_zarr_path, mode="w", compute=True)
        return True
|
||||
211
new/app/core/algorithms/kutser.py
Normal file
211
new/app/core/algorithms/kutser.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""
|
||||
Kutser 去耀斑算法(xarray + dask 重构版)
|
||||
========================================
|
||||
|
||||
旧版痛点
|
||||
--------
|
||||
原始 Kutser 实现(参考 Kutser et al., 2013)通常写成像这样:
|
||||
|
||||
R_corr = np.zeros_like(R_raw)
|
||||
for b in range(n_bands):
|
||||
for y in range(H):
|
||||
for x in range(W):
|
||||
if water_mask[y, x]:
|
||||
R_corr[y, x, b] = (
|
||||
R_raw[y, x, b] - G_list[b] * D_norm[y, x]
|
||||
)
|
||||
with rasterio.open(..., 'w') as dst:
|
||||
dst.write(R_corr)
|
||||
|
||||
问题:
|
||||
1. 三重 Python 循环,每次只做一个浮点运算,解释器开销巨大;
|
||||
2. 一次性把整张图 R_raw 读进内存,大影像直接 OOM;
|
||||
3. rasterio 写出要求 numpy 连续数组,进一步放大内存。
|
||||
|
||||
本文件用 xarray + dask 重写:
|
||||
- 用 DataArray 维度广播,三重循环 → 一行表达式;
|
||||
- 用 dask chunk 保持数据常驻磁盘、流式计算;
|
||||
- 用 to_zarr 边算边写,输出格式与算法层彻底解耦。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 算法实现
|
||||
# ---------------------------------------------------------------------------
|
||||
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
    """
    Kutser-style glint removal implemented with lazy xarray arithmetic.

    Computation:
        1. Per-pixel absorption depth
           ``D = (R(lower) + R(upper)) / 2 - R(oxy)``.
        2. ``D_max = max(D)`` over water pixels; ``D_norm = D / D_max``.
        3. Per-band range over water pixels:
           ``G[b] = max(R[..., b]) - min(R[..., b])``.
        4. ``R_corr[..., b] = R_raw[..., b] - G[b] * D_norm``.
        5. Pixels outside the water mask become NaN.

    NOTE(review): the default wavelengths and the attribution to a specific
    paper come from the original comments; confirm them against the source
    publication.
    """

    # Default reference wavelengths (nm); override via kwargs
    # ``band_lower`` / ``band_oxy`` / ``band_upper``.
    DEFAULT_BAND_LOWER: float = 773.0
    DEFAULT_BAND_OXY: float = 845.0
    DEFAULT_BAND_UPPER: float = 893.0

    async def process(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        **kwargs: Any,
    ) -> bool:
        """
        Run the correction in a worker thread.

        Args:
            input_zarr_path: Input zarr store with a ``reflectance`` variable.
            output_zarr_path: Output zarr store (overwritten, ``mode="w"``).
            **kwargs: Optional ``band_lower``, ``band_oxy``, ``band_upper``.

        Returns:
            True once the output has been written.
        """
        # The xarray/dask work is synchronous; keep it off the event loop.
        return await asyncio.to_thread(
            self._process_sync,
            input_zarr_path,
            output_zarr_path,
            kwargs,
        )

    def _process_sync(
        self,
        input_zarr_path: str,
        output_zarr_path: str,
        kwargs: dict,
    ) -> bool:
        """
        Synchronous core: read, correct, mask and write.

        Raises:
            KeyError: If the input store has no ``reflectance`` variable.
            ValueError: If ``reflectance`` lacks a y, x or band dimension.
        """
        # The context manager closes the store once writing has finished or
        # an error occurs.
        with xr.open_zarr(input_zarr_path, chunks="auto") as ds:
            if "reflectance" not in ds.data_vars:
                raise KeyError(
                    f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
                )
            reflectance: xr.DataArray = ds["reflectance"]

            # Explicit raise instead of assert: assertions are removed when
            # Python runs with -O, which would skip this validation.
            if any(dim not in reflectance.dims for dim in ("y", "x", "band")):
                raise ValueError(
                    f"reflectance 必须包含 y/x/band 三个维度,实际为: {reflectance.dims}"
                )

            # Step 1: the three reference bands, matched to the nearest
            # value of the band coordinate.
            wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
            wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
            wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))

            R_lower = reflectance.sel(band=wl_lower, method="nearest")
            R_upper = reflectance.sel(band=wl_upper, method="nearest")
            R_oxy = reflectance.sel(band=wl_oxy, method="nearest")

            # Step 2: water mask; without one the whole image is treated
            # as water.
            if "water_mask" in ds:
                water_mask = ds["water_mask"].astype(bool)
            else:
                water_mask = xr.ones_like(
                    reflectance.isel(band=0), dtype=bool
                )

            # Step 3: per-pixel absorption depth.
            D = (R_lower + R_upper) / 2.0 - R_oxy

            # Step 4: global normaliser over water pixels. It is NaN when
            # there is no valid water pixel and may be exactly 0; both are
            # replaced by a small epsilon so the division below is defined.
            D_max = D.where(water_mask).max()
            D_max = D_max.where(D_max != 0).fillna(1e-6)

            # Step 5: per-pixel array divided by the scalar normaliser.
            D_norm = D / D_max

            # Step 6: per-band range over water pixels, reduced over x and y.
            R_water = reflectance.where(water_mask)
            G_min = R_water.min(dim=["x", "y"])
            G_max = R_water.max(dim=["x", "y"])
            G_list = (G_max - G_min).fillna(0.0)

            # Step 7: correction. G_list (band) and D_norm (y, x) broadcast
            # against reflectance by dimension name.
            corrected = reflectance - G_list * D_norm

            # Step 8: pixels outside the water mask become NaN.
            result = corrected.where(water_mask)

            # Step 9: write out, carrying dataset-level and variable-level
            # attributes over. compute=True evaluates the lazy graph while
            # writing.
            out = xr.Dataset({"reflectance": result})
            out.attrs = dict(ds.attrs)
            out["reflectance"].attrs = dict(reflectance.attrs)
            out.to_zarr(output_zarr_path, mode="w", compute=True)

        return True
|
||||
135
new/app/core/algorithms/registry.py
Normal file
135
new/app/core/algorithms/registry.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""
|
||||
算法注册表(Registry / Factory)
|
||||
================================
|
||||
|
||||
通过装饰器把「算法名字符串」与「算法实现类」绑定在一起。
|
||||
上层调度层(FastAPI endpoints、BackgroundTasks worker)只需要拿到
|
||||
前端传过来的 method 字符串,就可以自动派发到对应的算法实现,
|
||||
而无需写一长串 if/elif。
|
||||
|
||||
使用示例
|
||||
--------
|
||||
|
||||
from app.core.algorithms import BaseGlintRemover
|
||||
from app.core.algorithms.registry import (
|
||||
register_glint_remover,
|
||||
get_remover,
|
||||
list_removers,
|
||||
)
|
||||
|
||||
@register_glint_remover("kutser")
|
||||
class KutserGlintRemover(BaseGlintRemover):
|
||||
async def process(self, input_zarr_path, output_zarr_path, **kwargs):
|
||||
...
|
||||
|
||||
# 派发
|
||||
Cls = get_remover(method_from_request)
|
||||
remover = Cls()
|
||||
await remover.process(input_zarr_path, output_zarr_path, **kwargs)
|
||||
|
||||
设计要点
|
||||
--------
|
||||
- 注册动作发生在「类定义时」,所以必须在所有算法 import 完之后
|
||||
注册表才完整。可以在 `app/core/algorithms/__init__.py` 中
|
||||
把算法子模块 import 一遍来强制触发注册。
|
||||
- 重复注册同名算法会直接抛错,避免静默覆盖。
|
||||
- name 会同步写回到类的 `name` 属性,便于算法自身查询身份。
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
|
||||
|
||||
# 全局注册表:name(str) -> 实现类(type),类未被实例化
|
||||
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}
|
||||
|
||||
|
||||
def register_glint_remover(name: str):
    """
    Build a class decorator that registers the class under ``name``.

    Parameters
    ----------
    name : str
        Algorithm identifier, e.g. "kutser" or "goodman".

    Returns
    -------
    A decorator that validates the class, sets its ``name`` attribute to
    ``name``, stores it in the module-level registry and returns the class
    unchanged.

    Raises
    ------
    ValueError
        ``name`` is not a non-blank string (raised when the factory is
        called), or ``name`` is already registered (raised when the
        decorator is applied).
    TypeError
        The decorated object is not a ``BaseGlintRemover`` subclass.
    """
    name_is_valid = isinstance(name, str) and bool(name.strip())
    if not name_is_valid:
        raise ValueError(
            f"register_glint_remover 的 name 必须是非空字符串,收到: {name!r}"
        )

    def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
        # Only classes deriving from BaseGlintRemover may be registered.
        if not (isinstance(cls, type) and issubclass(cls, BaseGlintRemover)):
            raise TypeError(
                f"@register_glint_remover 只能装饰 BaseGlintRemover 的子类,"
                f"收到: {cls!r}"
            )

        # Refuse to silently replace an existing registration.
        if name in _REGISTRY:
            holder = _REGISTRY[name].__name__
            raise ValueError(
                f"算法名 {name!r} 已被 {holder} 占用,"
                f"请使用其它名字或先调用 unregister_glint_remover() 注销旧实现。"
            )

        # Write the registered name back onto the class before storing it.
        cls.name = name
        _REGISTRY[name] = cls
        return cls

    return decorator
|
||||
|
||||
|
||||
def get_remover(name: str) -> Type[BaseGlintRemover]:
    """
    Return the (uninstantiated) class registered under ``name``.

    The caller constructs the instance and calls ``process()`` on it.

    Raises
    ------
    KeyError
        ``name`` is not registered; the message lists the registered names.
    """
    try:
        remover_cls = _REGISTRY[name]
    except KeyError as exc:
        joined = ", ".join(sorted(_REGISTRY))
        known = joined if joined else "<空>"
        message = f"未注册的算法名: {name!r}。已注册的算法: {known}"
        raise KeyError(message) from exc
    return remover_cls
|
||||
|
||||
|
||||
def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
    """
    Return a shallow copy of the registry (name -> class).

    Changing the returned dict does not change the registry itself.
    """
    snapshot = {}
    snapshot.update(_REGISTRY)
    return snapshot
|
||||
|
||||
|
||||
def unregister_glint_remover(name: str) -> None:
    """
    Remove ``name`` from the registry.

    Intended mainly for unit tests and hot-reload / plugin-unload cases.

    Raises
    ------
    KeyError
        ``name`` is not registered.
    """
    if name in _REGISTRY:
        del _REGISTRY[name]
        return
    raise KeyError(f"未注册的算法名: {name!r}")
|
||||
91
new/app/core/task_store.py
Normal file
91
new/app/core/task_store.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
app/core/task_store.py
|
||||
======================
|
||||
|
||||
并发安全的内存任务状态存储,替代早期 mock 流水线中的 MOCK_TASK_DB。
|
||||
|
||||
设计目标
|
||||
--------
|
||||
1. 在单进程内提供事件循环级别的互斥(asyncio.Lock),
|
||||
避免在 update 与 set/get 之间穿插 await 时发生状态不一致。
|
||||
2. 暴露异步 API(set_task / update_task / get_task),
|
||||
让调用方在 async 上下文中显式表达临界区。
|
||||
3. 保留一个同步的 has_task() 用于轻量存在性判断。
|
||||
4. 生产环境应替换为 Redis / SQLite / PostgreSQL,
|
||||
但接口形状保持一致, 便于上层调用方无缝迁移。
|
||||
|
||||
使用约定
|
||||
--------
|
||||
- 写入初始 PENDING 记录: await set_task(task_id, record)
|
||||
- 增量更新字段(PROCESSING/SUCCESS/FAILED):await update_task(task_id, **fields)
|
||||
- 读取任务记录: await get_task(task_id) # 可能返回 None
|
||||
- 同步判断是否存在: has_task(task_id)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 全局存储与锁
|
||||
# ---------------------------------------------------------------------------
|
||||
# TASK_STORE: task_id -> 任务记录
|
||||
# 任务记录字段约定(与 endpoints.py 保持一致):
|
||||
# task_id, method, params, status,
|
||||
# output_zarr_path, error, traceback,
|
||||
# created_at, updated_at
|
||||
# ---------------------------------------------------------------------------
|
||||
TASK_STORE: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 单进程内的事件循环级互斥锁
|
||||
# 注意:Python 3.10 起 asyncio.Lock 在构造时不绑定事件循环, 而是在首次发生
|
||||
# 竞争等待时才绑定到当前运行的循环, 因此可以在模块顶层实例化; 3.9 及更早版本会在构造时绑定 get_event_loop() 返回的循环, 需改为在事件循环内惰性创建。
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 异步 API
|
||||
# ---------------------------------------------------------------------------
|
||||
async def set_task(task_id: str, record: Dict[str, Any]) -> None:
    """
    Create a task record, or replace an existing one entirely.

    The POST endpoints call this right after a submission to store the
    initial PENDING record. The write happens while holding the module lock.
    """
    await _lock.acquire()
    try:
        TASK_STORE[task_id] = record
    finally:
        _lock.release()
|
||||
|
||||
|
||||
async def update_task(task_id: str, **fields: Any) -> None:
    """
    Merge ``fields`` into the record of ``task_id``.

    Background executors call this on each status change. If no record
    exists for ``task_id``, an empty one is created first and then updated.
    The whole operation runs while holding the module lock.
    """
    async with _lock:
        if task_id not in TASK_STORE:
            TASK_STORE[task_id] = {}
        TASK_STORE[task_id].update(fields)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> Optional[Dict[str, Any]]:
    """
    Return the record stored for ``task_id``, or None when there is none.

    The stored dict itself is returned, not a copy. The lookup runs while
    holding the module lock.
    """
    async with _lock:
        found = TASK_STORE.get(task_id)
    return found
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 同步 API(轻量)
|
||||
# ---------------------------------------------------------------------------
|
||||
def has_task(task_id: str) -> bool:
    """
    Report synchronously whether task_id has a record in TASK_STORE.

    Meant for lightweight situations that do not need the lock (for example
    a pre-check before logging). The module lock is not acquired here.
    Per the original author's note, it may also be called from async code
    because the dict membership test is treated as an atomic operation.

    Args:
        task_id: Key to look up in TASK_STORE.

    Returns:
        True if a record exists for task_id, otherwise False.
    """
    present = task_id in TASK_STORE
    return present
|
||||
62
new/app/main.py
Normal file
62
new/app/main.py
Normal file
@ -0,0 +1,62 @@
|
||||
"""
|
||||
WQ_GUI FastAPI 后端入口
|
||||
=======================
|
||||
|
||||
应用启动与全局中间件配置:
|
||||
- CORS:开发阶段允许所有来源,方便本地前端(Vite / Webpack dev server)联调
|
||||
- 路由:通过 include_router 挂载 app/api/endpoints.py 与 app/api/modeling.py 中的业务接口
|
||||
|
||||
业务接口说明:
|
||||
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
|
||||
GET /api/tasks/{task_id} 查询指定任务的状态与结果
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.endpoints import router as deglint_router
|
||||
from app.api.modeling import router as modeling_router
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# FastAPI application instance
# ---------------------------------------------------------------------------
app = FastAPI(
    version="0.2.0",
    title="WQ_GUI Backend",
    description="高光谱影像去耀斑处理 API",
)


# ---------------------------------------------------------------------------
# CORS middleware
# ---------------------------------------------------------------------------
# Development stage: every origin, method and header is allowed so that a
# local front end (on any port) can be integrated against this backend.
# In production, allow_origins must be narrowed to the real front-end
# domain(s) to avoid security risks.
# NOTE(review): allow_origins=["*"] is combined with allow_credentials=True
# here; confirm against the FastAPI / Starlette CORS documentation that this
# pairing behaves as intended before it leaves the development stage.
# ---------------------------------------------------------------------------
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


# ---------------------------------------------------------------------------
# Router registration
# ---------------------------------------------------------------------------
# Every router is mounted under the /api prefix, which leaves room for
# future versioning (e.g. /api/v1, /api/v2).
# ---------------------------------------------------------------------------
app.include_router(deglint_router, prefix="/api")
app.include_router(modeling_router, prefix="/api")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 根路径健康检查(方便本地调试,非业务必需)
|
||||
# ---------------------------------------------------------------------------
|
||||
@app.get("/")
async def root() -> Dict[str, str]:
    """
    Root-path health check.

    Returns:
        A fixed payload naming the service and reporting status "ok".
    """
    return dict(service="WQ_GUI Backend", status="ok")
|
||||
Reference in New Issue
Block a user