refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

View File

@ -0,0 +1,201 @@
"""
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
"""
import sys
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
# 当前文件位于 new/app/api/_smoke_test_train.py
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from app.api.modeling import (
_get_model_pipeline,
_load_train_df,
_resolve_feature_start,
_run_train_sync,
_MODEL_CLASS_REGISTRY,
)
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
"""生成 [lat, lon, target, lat2, lon2, feat_0, feat_1, ...] 布局的 CSV"""
rng = np.random.default_rng(seed)
lat = rng.uniform(20, 25, n_samples)
lon = rng.uniform(110, 115, n_samples)
target = rng.uniform(0, 50, n_samples)
lat2 = rng.uniform(0, 1, n_samples) # 元数据
lon2 = rng.uniform(0, 1, n_samples) # 元数据
feats = rng.normal(0, 1, (n_samples, n_features))
# 让 y 真正依赖前 3 个特征, RF 至少应该能学到 R² > 0.5
feats[:, 0] += target / 10
feats[:, 1] += target / 20
feats[:, 2] -= target / 15
df = pd.DataFrame({
"lat": lat,
"lon": lon,
"Chl-a": target,
"lat2": lat2,
"lon2": lon2,
**{f"feat_{i}": feats[:, i] for i in range(n_features)},
})
tmp = Path(tempfile.mkdtemp()) / "train.csv"
df.to_csv(tmp, index=False)
return tmp
def test_load_train_df():
print("== test_load_train_df ==")
p = make_synthetic_csv(n_samples=50)
df = _load_train_df(str(p))
assert df.shape == (50, 5 + 8), f"shape={df.shape}"
print(f" shape={df.shape}, columns[:6]={list(df.columns[:6])}")
print(" PASS")
def test_resolve_feature_start_int_and_str():
print("== test_resolve_feature_start (int + str) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
idx_int = _resolve_feature_start(df, 5)
idx_str = _resolve_feature_start(df, "feat_0")
assert idx_int == 5 == idx_str, f"int={idx_int}, str={idx_str}"
print(f" int(5) -> {idx_int}, str('feat_0') -> {idx_str}")
print(" PASS")
def test_resolve_feature_start_str_miss():
print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
p = make_synthetic_csv()
df = _load_train_df(str(p))
try:
_resolve_feature_start(df, "not_exist")
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_get_model_pipeline_all_types():
print("== test_get_model_pipeline (5 种 model_type) ==")
for mt in ["RF", "SVR", "LinearRegression", "KNN", "PLS"]:
p = _get_model_pipeline(mt, {})
assert len(p.steps) == 2
assert p.steps[0][0] == "scaler"
assert p.steps[1][0] == "model"
print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
print(" PASS")
def test_get_model_pipeline_bad_type():
print("== test_get_model_pipeline (坏 model_type) ==")
try:
_get_model_pipeline("XGBoost", {})
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_rf_end_to_end():
print("== test_run_train_sync (RF 端到端) ==")
p = make_synthetic_csv(n_samples=200)
out_dir = Path(tempfile.mkdtemp())
out_path = out_dir / "model.joblib"
import time
t0 = time.time()
metadata = _run_train_sync(
model_type="RF",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
output_model_path=out_path,
)
dt = time.time() - t0
assert out_path.exists(), f"joblib 未落盘: {out_path}"
print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
print(f" 耗时 {dt:.2f}s")
# 回测: 加载 joblib 再 predict
import joblib
saved = joblib.load(out_path)
assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
assert hasattr(saved["model"], "predict")
assert saved["metadata"]["test_r2"] == metadata["test_r2"]
print(f" joblib 加载 OK, 含 'model''metadata' 双 key")
print(" PASS")
def test_run_train_sync_linearregression_fast():
print("== test_run_train_sync (LinearRegression 快速路径) ==")
p = make_synthetic_csv(n_samples=150)
out_path = Path(tempfile.mkdtemp()) / "lr.joblib"
metadata = _run_train_sync(
model_type="LinearRegression",
target="Chl-a",
train_data_path=str(p),
feature_start=5,
params={},
output_model_path=out_path,
)
print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 >= 0.4)")
assert metadata["test_r2"] > 0.3, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
print(" PASS")
def test_run_train_sync_bad_csv():
print("== test_run_train_sync (CSV 不存在) ==")
try:
_run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛异常")
except (FileNotFoundError, ValueError) as e:
print(f" 正确抛 {type(e).__name__}: {e}")
print(" PASS")
def test_run_train_sync_bad_target():
print("== test_run_train_sync (target 列不存在) ==")
p = make_synthetic_csv()
try:
_run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
print(" FAIL: 应抛 ValueError")
except ValueError as e:
print(f" 正确抛 ValueError: {e}")
print(" PASS")
def test_run_train_sync_str_feature_start():
print("== test_run_train_sync (feature_start 用列名) ==")
p = make_synthetic_csv()
out_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
metadata = _run_train_sync("RF", "Chl-a", str(p), "feat_0", {"n_estimators": 10}, out_path)
assert metadata["feature_start"] == "feat_0"
assert metadata["n_features"] == 8
assert metadata["feature_columns"][0] == "feat_0"
print(f" 列名 'feat_0' 解析正确, n_features={metadata['n_features']}")
print(" PASS")
if __name__ == "__main__":
test_load_train_df()
test_resolve_feature_start_int_and_str()
test_resolve_feature_start_str_miss()
test_get_model_pipeline_all_types()
test_get_model_pipeline_bad_type()
test_run_train_sync_rf_end_to_end()
test_run_train_sync_linearregression_fast()
test_run_train_sync_bad_csv()
test_run_train_sync_bad_target()
test_run_train_sync_str_feature_start()
print("\n>>> ALL SMOKE TESTS PASSED")

222
new/app/api/endpoints.py Normal file
View File

@ -0,0 +1,222 @@
"""
API 路由集合
============
把业务接口统一收口到 APIRouter再由 main.py 通过 include_router 挂载。
当前包含的接口:
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
派发链:
POST /api/process/deglint
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
└─ get_remover(method) 从注册表拿到算法类
└─ remover.process(input_zarr, output_zarr, **params)
"""
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB
from app.core.task_store import get_task, set_task, update_task
# 算法注册表 API
from app.core.algorithms import get_remover, list_removers
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
# ---------------------------------------------------------------------------
router = APIRouter(tags=["deglint"])
# ---------------------------------------------------------------------------
# 请求 / 响应数据模型
# ---------------------------------------------------------------------------
class DeglintRequest(BaseModel):
"""POST /api/process/deglint 的请求体"""
method: str = Field(
...,
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
examples=["kutser"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description=(
"传递给算法 process() 的超参数字典,例如 "
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
),
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
)
class TaskAcceptedResponse(BaseModel):
"""提交任务成功后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
class AlgorithmListResponse(BaseModel):
"""GET /api/algorithms 的响应"""
algorithms: list # 已注册算法名列表
count: int # 算法总数
# ---------------------------------------------------------------------------
# 后台任务执行器(真实派发链)
# ---------------------------------------------------------------------------
# 注意:这里使用 async def。
# FastAPI / Starlette 的 BackgroundTasks 支持 async function
# 会在响应返回后自动 await 它,不影响主请求链路。
# ---------------------------------------------------------------------------
async def execute_glint_removal_task(
task_id: str,
method: str,
params: Dict[str, Any],
) -> None:
"""
后台异步执行器:按 method 名字从注册表取出算法类,实例化并运行 process()。
状态机:
PENDING -> PROCESSING -> SUCCESS
└──> FAILED含 error / traceback
"""
# 0. 安全检查任务记录必须已存在POST 阶段已写入)
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 任务不存在, 跳过")
return
# 1. 状态推进到 PROCESSING
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始处理 method={method} params={params}")
# 2. 临时硬编码 IO 路径(未来由数据管理层提供)
# TODO: 替换为真实的数据管理服务返回的 zarr 路径
input_zarr_path = "./data/temp_in.zarr"
output_zarr_path = f"./data/{task_id}_out.zarr"
try:
# 3. 按 method 名字从注册表取算法类并实例化
# get_remover 找不到时会抛 KeyError下面的 except 会兜住
algorithm_cls = get_remover(method)
remover = algorithm_cls()
# 4. 调用算法(注意 await因为 BaseGlintRemover.process 是 async
await remover.process(input_zarr_path, output_zarr_path, **params)
# 5. 成功:写回结果路径与状态
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
# 6. 失败:记录错误信息与堆栈,便于前端排查
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
# ---------------------------------------------------------------------------
# GET /algorithms
# ---------------------------------------------------------------------------
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
# ---------------------------------------------------------------------------
@router.get("/algorithms", response_model=AlgorithmListResponse)
async def list_registered_algorithms() -> Dict[str, Any]:
"""列出已注册的去耀斑算法。"""
names = list(list_removers().keys())
return {"algorithms": names, "count": len(names)}
# ---------------------------------------------------------------------------
# POST /process/deglint
# ---------------------------------------------------------------------------
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
# ---------------------------------------------------------------------------
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
async def submit_deglint(
payload: DeglintRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
# 1. 生成唯一任务 IDUUID4 足以保证全局唯一性)
task_id = str(uuid.uuid4())
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
# 注意output_zarr_path / error / traceback 字段在执行过程中被填充
await set_task(
task_id,
{
"task_id": task_id,
"method": payload.method,
"params": payload.params,
"status": "PENDING",
"output_zarr_path": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
# 3. 把真实执行器丢到后台
background_tasks.add_task(
execute_glint_removal_task,
task_id,
payload.method,
payload.params,
)
# 4. 立即返回 task_id 与 PENDING 状态
return {"task_id": task_id, "status": "PENDING"}
# ---------------------------------------------------------------------------
# GET /tasks/{task_id}
# ---------------------------------------------------------------------------
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
# SUCCESS 表示成功(含 output_zarr_pathFAILED 表示失败(含 error / traceback
# ---------------------------------------------------------------------------
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> Dict[str, Any]:
"""查询指定任务的当前状态与结果。"""
record = await get_task(task_id)
if record is None:
# 找不到 task_id 通常意味着客户端拼错了 ID或者记录已被清理
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
# 直接返回字典FastAPI 会自动 JSON 序列化
return record

786
new/app/api/modeling.py Normal file
View File

@ -0,0 +1,786 @@
"""
app/api/modeling.py
===================
机器学习与水质反演相关的 API 路由。
接口(最终路径, 挂载后):
POST /api/modeling/train 提交模型训练任务, 立即返回 task_id
GET /api/modeling/models 列出已训练好的模型(未来从磁盘 joblib 读)
POST /api/modeling/predict 提交模型推断任务, 立即返回 task_id
设计要点
--------
- 训练 / 推断均为异步后台任务, 复用 app.core.task_store 的并发安全任务状态。
- 模型元数据用模块级 _MODEL_REGISTRY 暂存(开发期内存存储),
未来从磁盘 joblib 读时只需替换 list_trained_models() 内部实现即可。
- /predict 已接入真实 sklearn + xarray + dask 流式推断:
* joblib.load 读模型(缺文件时降级为 Dummy RandomForestRegressor
* xr.open_zarr 延迟打开影像, NaN 填 0
* xr.apply_ufunc(dask="parallelized") 沿 (y, x) 逐 chunk 调 model.predict
* to_zarr(mode="w", compute=True) 流式写出, 内存峰值 ≈ 1 个 chunk
- /train 已接入真实 sklearn + pandas 训练管线:
* pd.read_csv 读结构化训练表(支持 [lat, lon, target_*, feature_*] 布局)
* 按 target 列 dropna 清洗;按 feature_start 索引/列名切分特征
* sklearn Pipeline: StandardScaler -> {RF/SVR/LinearRegression/KNN/PLS}
* train_test_split(80/20) 划分, 计算 test_r2/rmse/mae
* joblib.dump({model, metadata}) 落盘 ./data/models/{model_id}.joblib
* 测试指标写回 TASK_STORE, 同时登记到 _MODEL_REGISTRY
注: 旧版 SPXY / KS 划分留作未来扩展, 当前固定 random 划分 (test_size=0.2, random_state=42)。
"""
import asyncio
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import joblib
import numpy as np
import pandas as pd
import xarray as xr
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel, Field
from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
# 复用并发安全任务状态存储(与 deglint 共享同一份 TASK_STORE,
# 通过 task 记录里的 "kind" 字段区分 train / predict / deglint
from app.core.task_store import get_task, set_task, update_task
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix="/modeling" 让本文件内只写 /train /models /predict 等短路径,
# 最终完整路径由 main.py 挂载时再补 /api。
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/modeling", tags=["modeling"])
# ---------------------------------------------------------------------------
# 数据模型
# ---------------------------------------------------------------------------
class TrainRequest(BaseModel):
"""POST /api/modeling/train 的请求体"""
model_type: str = Field(
...,
description="模型类型, 例如 'RF' (随机森林) / 'SVR' (支持向量回归) / 'XGBoost' / 'MLP'",
examples=["RF", "SVR"],
)
target: str = Field(
...,
description="反演目标水质参数, 例如 'Chl-a' (叶绿素a) / 'TSS' (总悬浮物) / 'CDOM' (有色可溶有机物)",
examples=["Chl-a", "TSS", "CDOM"],
)
train_data_path: str = Field(
...,
description="训练数据集的 zarr 路径(包含 reflectance 变量与 target 标签)",
examples=["./data/train.zarr"],
)
feature_start: Union[int, str] = Field(
default=4,
description=(
"特征列起始位置. 表格布局假定为 "
"[lat, lon, target_1, target_2, ..., feature_1, feature_2, ...] "
"可传 int 列索引(如 4或 str 列名(如 '374.285' 波长起点)。"
"默认 4, 即前 4 列视为元数据/目标, 之后全部是特征。"
),
examples=[4, "374.285"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="模型超参, 例如 RF 的 {'n_estimators': 100, 'max_depth': 20}",
examples=[{"n_estimators": 100, "max_depth": 20}],
)
class PredictRequest(BaseModel):
"""POST /api/modeling/predict 的请求体"""
model_id: str = Field(
...,
description="已训练模型的 ID由 /api/modeling/train 返回或 /api/modeling/models 列出)",
)
input_zarr_path: str = Field(
...,
description="待推断影像的 zarr 路径",
examples=["./data/scene.zarr"],
)
output_zarr_path: Optional[str] = Field(
default=None,
description=(
"输出 zarr 路径, 缺省时由后端按规则生成 "
"(如 ./data/{model_id}_{input_stem}_pred.zarr)"
),
)
class TaskAcceptedResponse(BaseModel):
"""提交训练/推断任务后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
kind: str # "train" / "predict", 便于前端识别任务类型
class ModelInfo(BaseModel):
"""单个模型的元信息GET /api/modeling/models 的元素)"""
model_id: str
model_type: str
target: str
params: Dict[str, Any]
path: str # joblib 文件路径
created_at: str
train_task_id: str # 产生此模型的那个训练任务的 ID
class ModelListResponse(BaseModel):
"""GET /api/modeling/models 的响应"""
models: List[ModelInfo]
count: int
# ---------------------------------------------------------------------------
# 模块级模型注册表(开发期内存, 未来替换为磁盘扫描)
# ---------------------------------------------------------------------------
# model_id -> ModelInfo 字典
# 读多写少, 用一个普通 dict 足够CPython GIL 兜底)。
# 写时(训练完成时)只发生一次, 无并发风险。
# ---------------------------------------------------------------------------
_MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}
def _register_model(record: Dict[str, Any]) -> None:
"""将训练完成的模型登记到内存注册表。"""
_MODEL_REGISTRY[record["model_id"]] = record
# ---------------------------------------------------------------------------
# 训练管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点 (与推断管线一致):
# 1) 模块级函数: dask / joblib 后端若走子进程 pickle, 嵌套闭包会丢字段。
# 2) 同步执行: execute_train_task 用 asyncio.to_thread 派发, 内部全程同步阻塞。
# 3) 失败抛异常: 异常由 execute_train_task 捕获, 转 FAILED + traceback。
# ---------------------------------------------------------------------------
# model_type (大写字符串) -> sklearn 估计器类
# 与 OpenClaw model_configs 思路一致, 但此处只保留类 (参数由 params 透传)
_MODEL_CLASS_REGISTRY: Dict[str, type] = {
"RF": RandomForestRegressor,
"SVR": SVR,
"LinearRegression": LinearRegression,
"KNN": KNeighborsRegressor,
"PLS": PLSRegression,
}
def _get_model_pipeline(model_type: str, params: Optional[Dict[str, Any]]) -> Pipeline:
"""
模型工厂: 按 model_type 选 sklearn 类, 用 StandardScaler + 估计器构造 Pipeline。
与 OpenClaw 不同之处: 把 scaler 放进 Pipeline 第一步,
推断时直接 pipeline.predict(X) 即可, scaler 参数与训练时严格一致。
"""
model_cls = _MODEL_CLASS_REGISTRY.get(model_type)
if model_cls is None:
raise ValueError(
f"不支持的 model_type='{model_type}', "
f"可选: {sorted(_MODEL_CLASS_REGISTRY.keys())}"
)
estimator = model_cls(**(params or {}))
return Pipeline([("scaler", StandardScaler()), ("model", estimator)])
def _load_train_df(csv_path: str) -> pd.DataFrame:
"""
读 CSV 训练表, 规整空串 / 空白 / NULL 等为 NaN。
沿用 OpenClaw modeling_batch.load_data_batch 的读取策略:
na_values 显式列举 + 正则二次清理 (防 cell 内出现 " " 等纯空白)。
"""
try:
df = pd.read_csv(
csv_path,
na_values=["", " ", "NaN", "nan", "NULL", "null"],
)
except FileNotFoundError as exc:
raise FileNotFoundError(f"训练数据文件不存在: {csv_path}") from exc
except pd.errors.EmptyDataError as exc:
raise ValueError(f"训练数据文件为空: {csv_path}") from exc
# 二次清理: 残留的纯空白 cell
df = df.replace(r"^\s*$", np.nan, regex=True)
return df
def _resolve_feature_start(
df: pd.DataFrame,
feature_start: Union[int, str],
) -> int:
"""
将 feature_start (int 索引 / str 列名) 统一解析为 int 列索引。
与 OpenClaw modeling_batch.load_data_batch / load_data_single 一致:
str 走 columns.get_loc, int 直接返回。
"""
if isinstance(feature_start, str):
if feature_start not in df.columns:
raise ValueError(
f"feature_start='{feature_start}' 不在 CSV 列中: {list(df.columns)}"
)
return int(df.columns.get_loc(feature_start))
return int(feature_start)
def _run_train_sync(
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Optional[Dict[str, Any]],
output_model_path: Path,
) -> Dict[str, Any]:
"""
完整同步训练流程 (由 execute_train_task 在线程池内调用):
pd.read_csv -> 目标列 dropna -> 切特征 -> train_test_split(80/20)
-> Pipeline(StandardScaler + model).fit -> 评估 test_r2/rmse/mae
-> joblib.dump({model, metadata}, output_model_path)
Returns:
metadata 字典, 含 test_r2 / test_rmse / test_mae / n_features 等,
调用方负责写回 TASK_STORE 和 _MODEL_REGISTRY。
注: 旧版 SPXY / KS 划分留作未来扩展 (params.split_method 控制),
当前固定 random + test_size=0.2 + random_state=42。
"""
df = _load_train_df(train_data_path)
if target not in df.columns:
raise ValueError(
f"target='{target}' 不在 CSV 列中, 可选: {list(df.columns)}"
)
# 1) 清洗: 仅剔除 target NaN 的行 (与 OpenClaw load_data_single 一致)
df = df[df[target].notna()].copy()
if df.empty:
raise ValueError("target 剔除 NaN 后无样本, 终止训练")
# 2) 特征切分
feature_start_idx = _resolve_feature_start(df, feature_start)
feature_columns = list(df.columns[feature_start_idx:])
X = df.iloc[:, feature_start_idx:].astype(np.float64)
y = df[target].astype(np.float64).values
# 3) 划分 (固定 random, 未来扩展 spxy/ks)
X_train, X_test, y_train, y_test = train_test_split(
X.values,
y,
test_size=0.2,
random_state=42,
)
# 4) 构造 Pipeline + 训练
pipeline = _get_model_pipeline(model_type, params)
pipeline.fit(X_train, y_train)
# 5) 测试集与训练集评估
y_pred = pipeline.predict(X_test)
test_r2 = float(r2_score(y_test, y_pred))
test_rmse = float(np.sqrt(mean_squared_error(y_test, y_pred)))
test_mae = float(mean_absolute_error(y_test, y_pred))
y_train_pred = pipeline.predict(X_train)
train_r2 = float(r2_score(y_train, y_train_pred))
train_rmse = float(np.sqrt(mean_squared_error(y_train, y_train_pred)))
train_mae = float(mean_absolute_error(y_train, y_train_pred))
metadata: Dict[str, Any] = {
"model_type": model_type,
"target": target,
"feature_start": feature_start,
"feature_columns": feature_columns,
"n_features": int(X.shape[1]),
"n_samples": int(X.shape[0]),
"train_size": int(X_train.shape[0]),
"test_size": int(X_test.shape[0]),
"params": dict(params or {}),
"test_r2": test_r2,
"test_rmse": test_rmse,
"test_mae": test_mae,
"train_r2": train_r2,
"train_rmse": train_rmse,
"train_mae": train_mae,
"split_method": "random",
"trained_at": datetime.now().isoformat(),
}
# 7) 持久化 (目录可能不存在, 顺手建)
output_model_path = Path(output_model_path)
output_model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(
{"model": pipeline, "metadata": metadata},
output_model_path,
)
return metadata
# ---------------------------------------------------------------------------
# 推断管线的模块级辅助函数
# ---------------------------------------------------------------------------
# 设计要点:
# 1) Dask 调度时, 函数必须可被工作进程 pickle 序列化。
# 因此 _predict_block / _load_model / _make_dummy_model / _run_predict_sync
# 全部是模块级函数 (而非嵌套), 避免闭包陷阱。
# 2) _predict_block 通过 model.predict(spectra_2d) 整批预测,
# 整张影像的 O(n_pixels * n_bands) 一次性预测在大矩阵上必 OOM,
# 因此外层用 xr.apply_ufunc(dask="parallelized") 把矩阵切块
# 逐块进入此函数, 单次内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) 大小。
# ---------------------------------------------------------------------------
def _make_dummy_model(n_features: int) -> RandomForestRegressor:
"""
构造一个 Dummy 随机森林回归器。
用途:
1) 真实 joblib 文件不存在时的连通性测试
2) 训练骨架尚未接入真实数据时的占位推断
"""
rng = np.random.default_rng(42)
X = rng.random((200, n_features))
y = rng.random(200)
model = RandomForestRegressor(
n_estimators=10, max_depth=5, random_state=0, n_jobs=1
)
model.fit(X, y)
return model
def _load_model(path: str, n_features: int) -> Any:
"""
加载训练好的 sklearn 模型, 失败时降级 Dummy。
优先级:
1) path 存在且 joblib.load 成功 -> 返回真实模型
2) 否则 -> 降级为 Dummy 随机森林 (n_features 必须指定)
"""
p = Path(path)
if p.is_file() and p.stat().st_size > 0:
try:
print(f"[model] 从磁盘加载: {path}")
return joblib.load(path)
except Exception as exc: # noqa: BLE001
print(f"[model] joblib.load 失败 ({type(exc).__name__}: {exc}), 降级 Dummy")
print(f"[model] 真实 joblib 不存在 ({path}), 使用 Dummy RandomForest")
return _make_dummy_model(n_features)
def _predict_block(spectra_3d: np.ndarray, model: Any) -> np.ndarray:
"""
单个 dask chunk 的推断函数 (xr.apply_ufunc 会自动调度调用)。
Parameters
----------
spectra_3d : np.ndarray
形状 (y_chunk, x_chunk, n_bands)。
此形状由 input_core_dims=[["band"]] 决定:
xarray 会把 band 维移到最后一轴, 然后按 (y, x) 的 chunk 切分调用本函数。
model : 已 fit 好的 sklearn 估计器
接受 (n_samples, n_features) 输入, 返回 (n_samples,) 预测。
Returns
-------
np.ndarray
形状 (y_chunk, x_chunk), dtype float32 的标量预测图。
"""
yc, xc, nb = spectra_3d.shape
# 2D 化: 每个像素一行光谱
flat = spectra_3d.reshape(yc * xc, nb)
# sklearn 风格的批量预测
pred = model.predict(flat)
# 还原为 2D 空间图, 强制 float32 节约一半内存
return pred.reshape(yc, xc).astype(np.float32, copy=False)
def _run_predict_sync(
model: Any,
model_id: str,
input_zarr_path: str,
output_zarr_path: str,
) -> None:
"""
同步推断主流程 (被 asyncio.to_thread 调用)。
流程:
1) xr.open_zarr 延迟打开 (dask 数组, 不一次性读入内存)
2) NaN -> 0 清洗 (model.predict 不接受 NaN)
3) xr.apply_ufunc 沿 (y, x) 逐 chunk 调 _predict_block
4) 非水域置 NaN (zarr 支持 float NaN)
5) to_zarr 触发整图计算 + 流式写出
"""
# 1. 延迟打开输入 (关键: Dask 不一次性读入内存)
ds = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds.data_vars)}"
)
reflectance = ds["reflectance"] # dims: (y, x, band)
n_bands = reflectance.sizes["band"]
# 2. 水域掩膜 (与去耀斑算法同约定)
if "water_mask" in ds.data_vars or "water_mask" in ds.coords:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(reflectance.isel(band=0), dtype=bool)
# 3. NaN 清洗: 填充 0 (model.predict 不接受 NaN)
refl_clean = reflectance.fillna(0.0)
# 4. 核心: 用 apply_ufunc 把 model.predict 沿 (y, x) 应用
# dask="parallelized" 让每个 (y_chunk, x_chunk, band) chunk
# 独立调 _predict_block, 任意时刻内存中只有若干个 chunk。
prediction: xr.DataArray = xr.apply_ufunc(
_predict_block,
refl_clean,
kwargs={"model": model},
input_core_dims=[["band"]],
output_core_dims=[[]],
dask="parallelized",
output_dtypes=[np.float32],
dask_gufunc_kwargs={"allow_rechunk": True},
vectorize=False,
)
# 5. 非水域置 NaN (zarr 支持 float NaN, 便于后续可视化/掩膜分析)
prediction = prediction.where(water_mask, np.nan)
# 6. 包装为 Dataset 并流式写出
out = xr.Dataset(
{"prediction": prediction},
attrs={
"model_id": model_id,
"input_zarr_path": input_zarr_path,
"n_bands": n_bands,
"created_at": datetime.now().isoformat(),
},
)
# 保留 y/x 坐标
out = out.assign_coords(y=ds["y"], x=ds["x"])
# to_zarr + compute=True 触发整图 dask 图求值
# 中间会按 chunk 逐块调度到线程池, 内存峰值 ≈ 1 个 chunk 的体量
out.to_zarr(output_zarr_path, mode="w", compute=True)
# ---------------------------------------------------------------------------
# 后台任务执行器
# ---------------------------------------------------------------------------
async def execute_train_task(
task_id: str,
model_type: str,
target: str,
train_data_path: str,
feature_start: Union[int, str],
params: Dict[str, Any],
) -> None:
"""
训练任务后台执行器(已接入真实 sklearn 训练流程)。
流程:
1) get_task 校验任务存在
2) update_task(PROCESSING)
3) 生成 model_id / model_path
4) asyncio.to_thread 派发 _run_train_sync 到默认线程池
5) 成功 -> _register_model + update_task(SUCCESS, 附 test_r2/rmse/mae)
6) 失败 -> update_task(FAILED, 附 error + traceback)
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 训练任务不存在, 跳过")
return
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 开始训练 model_type={model_type} target={target} "
f"train_data_path={train_data_path} feature_start={feature_start}"
)
# model_id 用 uuid4 前 12 位 (8 位易撞, 12 位兼顾可读性)
model_id = f"model_{uuid.uuid4().hex[:12]}"
model_path = Path(f"./data/models/{model_id}.joblib")
try:
# 同步 sklearn / pandas 训练丢到默认线程池, 不阻塞 event loop
metadata = await asyncio.to_thread(
_run_train_sync,
model_type,
target,
train_data_path,
feature_start,
params,
model_path,
)
# 登记到内存注册表 (供 /predict 查 model_id)
_register_model(
{
"model_id": model_id,
"model_type": model_type,
"target": target,
"params": dict(params or {}),
"path": str(model_path),
"feature_start": feature_start,
"n_features": metadata["n_features"],
"test_r2": metadata["test_r2"],
"test_rmse": metadata["test_rmse"],
"test_mae": metadata["test_mae"],
"created_at": datetime.now().isoformat(),
"train_task_id": task_id,
}
)
# 把训练指标写回任务记录, 前端轮询时可直接看
await update_task(
task_id,
status="SUCCESS",
model_id=model_id,
model_path=str(model_path),
test_r2=metadata["test_r2"],
test_rmse=metadata["test_rmse"],
test_mae=metadata["test_mae"],
n_features=metadata["n_features"],
n_samples=metadata["n_samples"],
error=None,
traceback=None,
updated_at=datetime.now().isoformat(),
)
print(
f"[{task_id}] 训练完成 -> model_id={model_id} "
f"test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f}"
)
except Exception as exc: # noqa: BLE001
# 失败时 model_path 不一定有产物, 显式置 None 方便前端判断
await update_task(
task_id,
status="FAILED",
model_id=None,
model_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 训练失败 -> {type(exc).__name__}: {exc}")
async def execute_predict_task(
task_id: str,
model_id: str,
input_zarr_path: str,
output_zarr_path: Optional[str],
) -> None:
"""
推断任务后台执行器(真实实现版)。
OOM 防护策略:
- xr.open_zarr(..., chunks="auto") 延迟打开, 整图不一次性读入内存
- xr.apply_ufunc(..., dask="parallelized") 把影像按 chunk 切分
- 每个 chunk 内部 reshape 成 2D, 调 model.predict, 再 reshape 回 2D
- 任意时刻内存峰值 ≈ 1 个 (y_chunk, x_chunk, band) chunk 的体量
- 整图完成计算后再 to_zarr(compute=True) 流式写出
"""
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 推断任务不存在, 跳过")
return
# 1. 校验 model_id 是否已注册 (避免在后台任务里报模糊错误)
model_meta = _MODEL_REGISTRY.get(model_id)
if model_meta is None:
await update_task(
task_id,
status="FAILED",
error=f"model_id 不存在: {model_id}",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> model_id 不存在: {model_id}")
return
# 2. 自动生成 output_zarr_path (若未提供)
if output_zarr_path is None:
stem = input_zarr_path.rstrip("/\\").split("/")[-1].split("\\")[-1]
stem = stem.replace(".zarr", "")
output_zarr_path = f"./data/{model_id}_{stem}_pred.zarr"
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始推断 model_id={model_id} input={input_zarr_path}")
try:
# 3. 探测波段数 (用于 Dummy 模型适配)
# 这里只读 zarr 元数据 (.zarray 的 shape), 不读真实数据
ds_probe = xr.open_zarr(input_zarr_path, chunks="auto")
if "reflectance" not in ds_probe.data_vars:
raise KeyError(
f"输入 zarr 缺少 'reflectance' 变量; 实际: {list(ds_probe.data_vars)}"
)
n_bands = ds_probe["reflectance"].sizes["band"]
ds_probe.close()
# 4. 加载模型 (真实文件优先, Dummy 兜底)
model = _load_model(model_meta["path"], n_features=n_bands)
# 5. 包装同步执行, 丢到线程池, 事件循环不阻塞
await asyncio.to_thread(
_run_predict_sync,
model,
model_id,
input_zarr_path,
output_zarr_path,
)
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
model_id=model_id,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断完成 -> output={output_zarr_path}")
except Exception as exc: # noqa: BLE001
tb_text = traceback.format_exc()
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=tb_text,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 推断失败 -> {type(exc).__name__}: {exc}")
print(tb_text)
# ---------------------------------------------------------------------------
# POST /api/modeling/train
# ---------------------------------------------------------------------------
@router.post("/train", response_model=TaskAcceptedResponse)
async def submit_train(
payload: TrainRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型训练任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "train",
"model_type": payload.model_type,
"target": payload.target,
"train_data_path": payload.train_data_path,
"feature_start": payload.feature_start,
"params": payload.params,
"status": "PENDING",
"model_id": None,
"model_path": None,
"test_r2": None,
"test_rmse": None,
"test_mae": None,
"n_features": None,
"n_samples": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_train_task,
task_id,
payload.model_type,
payload.target,
payload.train_data_path,
payload.feature_start,
payload.params,
)
return {"task_id": task_id, "status": "PENDING", "kind": "train"}
# ---------------------------------------------------------------------------
# GET /api/modeling/models
# ---------------------------------------------------------------------------
@router.get("/models", response_model=ModelListResponse)
async def list_trained_models() -> Dict[str, Any]:
"""
列出已训练好的模型。
未来实现: 从 ./data/models/*.joblib 扫描元信息,
当前直接从内存 _MODEL_REGISTRY 读。
"""
models = list(_MODEL_REGISTRY.values())
# 按 created_at 倒序, 最新训练的在前
models.sort(key=lambda m: m.get("created_at", ""), reverse=True)
return {"models": models, "count": len(models)}
# ---------------------------------------------------------------------------
# POST /api/modeling/predict
# ---------------------------------------------------------------------------
@router.post("/predict", response_model=TaskAcceptedResponse)
async def submit_predict(
payload: PredictRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个模型推断任务, 立即返回 task_id。"""
task_id = str(uuid.uuid4())
await set_task(
task_id,
{
"task_id": task_id,
"kind": "predict",
"model_id": payload.model_id,
"input_zarr_path": payload.input_zarr_path,
"output_zarr_path": payload.output_zarr_path,
"status": "PENDING",
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
background_tasks.add_task(
execute_predict_task,
task_id,
payload.model_id,
payload.input_zarr_path,
payload.output_zarr_path,
)
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}

View File

@ -0,0 +1,40 @@
"""
去耀斑算法包
============
通过「注册表 + 策略模式」组织不同的去耀斑算法。
所有具体算法都应继承 BaseGlintRemover并使用 @register_glint_remover
装饰器把算法名和实现类绑定。
外部调用约定
------------
1. 所有算法子模块必须在本 __init__ 中显式 import
这样装饰器才会被执行、注册表才会被填满。
2. 上层endpoints、worker只允许
from app.core.algorithms import get_remover
来获取算法类,不要直接 import 具体实现类,
保持调度层与具体算法的解耦。
"""
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import (
get_remover,
list_removers,
register_glint_remover,
unregister_glint_remover,
)
# ---- 算法子模块 import 区 ----
# 新增算法时,在这里加一行 import确保装饰器被执行。
from app.core.algorithms import goodman # Goodman
from app.core.algorithms import kutser # Kutser
# from app.core.algorithms import hedley # Hedley
# from app.core.algorithms import sugar # SUGAR
__all__ = [
"BaseGlintRemover",
"register_glint_remover",
"get_remover",
"list_removers",
"unregister_glint_remover",
]

View File

@ -0,0 +1,85 @@
"""
去耀斑算法抽象基类
==================
设计目标(策略模式 Strategy Pattern
------------------------------------
本模块定义了所有去耀斑算法必须遵守的标准接口。
未来的 Kutser、Goodman、Hedley、SUGAR 等算法都将继承本基类,
并实现统一的 process() 方法。
输入输出规范
------------
所有算法的输入与输出均统一为 **Zarr 文件路径**(字符串),
而不是内存中的 numpy ndarray。这样做的核心收益是
1. **解耦数据存储与内存计算**
算法只关心「从哪个 zarr 读、写到哪个 zarr」
至于数据最初来自 GeoTIFF / HDF5 / NetCDF / 内存数组,
都由 IO 层负责归一化转为 zarr。
2. **支持 Out-of-Core 计算**
影像往往超过内存上限zarr 分块chunk天然支持按块读取
算法实现可以借助 dask / xarray 进行流式计算。
3. **可缓存、可复用**
中间产物落盘后,下游算法(大气校正、辐射定标)能直接消费,
避免重复 IO。
4. **易于并行与分布式**
任务调度层只需把两个路径扔给 worker无需关心数据细节。
约定
----
- 子类应实现 process(),完成「读 -> 计算 -> 写」的完整流程。
- process() 返回 True 表示成功False 表示失败。
- 失败时建议抛出异常而非仅返回 False便于上层 BackgroundTasks 捕获并写入 error 字段。
"""
from abc import ABC, abstractmethod
from typing import Any
class BaseGlintRemover(ABC):
"""
去耀斑算法抽象基类。
所有具体算法Kutser / Goodman / Hedley / SUGAR …)必须继承本类并实现 process()。
子类可在 __init__ 中接收自己的超参数(如参考波段、阈值等),
真正的输入输出数据则由 process() 的两个 zarr 路径参数指定。
"""
# 子类可覆盖的算法名称标识,用于调度层按 method 名字查找
name: str = "base"
@abstractmethod
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
"""
执行去耀斑处理。
Parameters
----------
input_zarr_path : str
输入高光谱影像的 zarr 存储路径。
数据已由 IO 层完成格式归一化(波段、坐标系、空间维度均已对齐)。
output_zarr_path : str
处理结果(去耀斑后影像)的 zarr 存储路径。
子类需自行创建该 zarr 存储并写入结果。
**kwargs : Any
算法的可选超参数,例如:
- reference_band: 参考近红外波段索引
- chunk_size: 计算分块大小
- 其它算法特定参数
Returns
-------
bool
True 表示处理成功False 表示失败。
建议在出错时直接 raise由调用方统一记录到任务状态。
"""
raise NotImplementedError
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return f"<{self.__class__.__name__} name={self.name!r}>"

View File

@ -0,0 +1,123 @@
"""
app/core/algorithms/goodman.py
===============================
Goodman et al. 2008 去耀斑算法的 xarray + dask 流式实现。
算法公式
--------
R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)
其中:
R_raw -- 原始反射率 (y, x, band)
R_750 -- λ=750 nm 处的反射率(红外参考波段, 远离水汽吸收)
R_640 -- λ=640 nm 处的反射率(可见光差异波段)
A, B -- 经验回归参数(用户可通过 params 传入, 默认全 0
后处理
------
- 负值截断为 0Clamp to 0
- 仅在水域掩膜 (water_mask) 内生效, 水外置 0
维度约定
--------
reflectance: (y, x, band), band 坐标通常为 wavelength (nm)
water_mask : (y, x), 布尔类型, True = 水域
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 默认参数
# ---------------------------------------------------------------------------
# 与原始 Goodman 2008 论文符号保持一致, 方便用户交叉对照。
# A、B 通常通过对纯净深水区做 (R_corr - R_raw) ~ (R_640 - R_750) 回归得到;
# 在缺乏先验知识时, 退化为 A=0, B=0 即等价于 R_corrected = clip(R_raw - R_750, 0)。
# ---------------------------------------------------------------------------
DEFAULT_BAND_REF: float = 750.0 # λ_750 nm, 红外参考波段
DEFAULT_BAND_DIFF: float = 640.0 # λ_640 nm, 可见光差异波段
DEFAULT_A: float = 0.0 # 公式中的常数偏移项
DEFAULT_B: float = 0.0 # 公式中的斜率项
@register_glint_remover("goodman")
class GoodmanGlintRemover(BaseGlintRemover):
"""Goodman et al. 2008 去耀斑算法"""
name = "goodman"
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
# 1. 解析超参数(带默认值, 方便用户按需覆盖)
band_ref: float = kwargs.get("band_ref", DEFAULT_BAND_REF)
band_diff: float = kwargs.get("band_diff", DEFAULT_BAND_DIFF)
A: float = kwargs.get("A", DEFAULT_A)
B: float = kwargs.get("B", DEFAULT_B)
# 2. 把同步的 xarray/dask 计算丢到工作线程,
# 避免阻塞 FastAPI 的事件循环
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
band_ref,
band_diff,
A,
B,
)
@staticmethod
def _process_sync(
input_zarr_path: str,
output_zarr_path: str,
band_ref: float,
band_diff: float,
A: float,
B: float,
) -> bool:
# 1. 以 zarr 路径打开dask-backed, 不物化到内存)
# chunks="auto" 让 dask 根据每条坐标轴的大小自动决定分块
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance = ds["reflectance"] # (y, x, band)
# 2. 用 sel + method='nearest' 提取两个关键波段
# 返回形状 (y, x), 后续与 (y, x, band) 算术时会自动广播
R_750 = reflectance.sel(band=band_ref, method="nearest")
R_640 = reflectance.sel(band=band_diff, method="nearest")
# 3. Goodman 公式: xarray 沿 band 维度自动广播
# R_corr = R_raw - R_750 + A + B * (R_640 - R_750)
result = reflectance - R_750 + A + B * (R_640 - R_750)
# 4. 负值截断为 0clip(min=0) 优于 where(>0, 0, _)
# 不构造布尔中间数组, 底层走 dask 矢量化 clip 路径)
result = result.clip(min=0)
# 5. 仅在水域内生效(水外强制为 0
# 优先从 zarr 内部读 water_mask 变量, 缺失则视为全图水域
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
result = result.where(water_mask, 0)
# 6. 构造输出 Dataset, 保留元信息(波段坐标/属性等)
out = xr.Dataset({"reflectance": result})
if ds.attrs:
out.attrs = dict(ds.attrs)
if reflectance.attrs:
out["reflectance"].attrs = dict(reflectance.attrs)
# 7. 流式写出Out-of-Core不一次性物化大数组,
# dask 会按 chunk 边算边写, 内存峰值 ≈ 单个 chunk 大小
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,211 @@
"""
Kutser 去耀斑算法xarray + dask 重构版)
========================================
旧版痛点
--------
原始 Kutser 实现(参考 Kutser et al., 2013通常写成像这样
R_corr = np.zeros_like(R_raw)
for b in range(n_bands):
for y in range(H):
for x in range(W):
if water_mask[y, x]:
R_corr[y, x, b] = (
R_raw[y, x, b] - G_list[b] * D_norm[y, x]
)
with rasterio.open(..., 'w') as dst:
dst.write(R_corr)
问题:
1. 三重 Python 循环,每次只做一个浮点运算,解释器开销巨大;
2. 一次性把整张图 R_raw 读进内存,大影像直接 OOM
3. rasterio 写出要求 numpy 连续数组,进一步放大内存。
本文件用 xarray + dask 重写:
- 用 DataArray 维度广播,三重循环 → 一行表达式;
- 用 dask chunk 保持数据常驻磁盘、流式计算;
- 用 to_zarr 边算边写,输出格式与算法层彻底解耦。
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 算法实现
# ---------------------------------------------------------------------------
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
"""
Kutser 近红外扣除法去耀斑。
数学公式(与旧版完全等价)
-------------------------
1) 水汽吸收深度 D每像素
D = (R(λ_lower) + R(λ_upper)) / 2 - R(λ_oxy)
2) 全局归一化因子 D_max
D_max = max(D) over 水域
归一化:
D_norm = D / D_max
3) 每波段水域范围:
G_list[b] = max(R[:, :, b] over 水域) - min(R[:, :, b] over 水域)
4) 校正公式(每像素、每波段):
R_corr(λ_b) = R_raw(λ_b) - G_list[b] * D_norm
"""
# Kutser 2013 论文里使用的参考波段nm
# λ_lower = 773, λ_oxy = 845, λ_upper = 893
# 允许通过 kwargs 覆盖,便于适配 MERIS / OLCI / Landsat 等不同传感器。
DEFAULT_BAND_LOWER: float = 773.0
DEFAULT_BAND_OXY: float = 845.0
DEFAULT_BAND_UPPER: float = 893.0
# --------------------------------------------------------------
# 公开异步入口
# --------------------------------------------------------------
# xarray / dask 的算子本身是同步阻塞的。在 async 函数中,
# 用 asyncio.to_thread 把同步体丢到默认线程池执行,
# 避免阻塞 FastAPI 的事件循环。
# --------------------------------------------------------------
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
kwargs,
)
# --------------------------------------------------------------
# 同步核心实现
# --------------------------------------------------------------
def _process_sync(
self,
input_zarr_path: str,
output_zarr_path: str,
kwargs: dict,
) -> bool:
# ============================================================
# 步骤 0打开 zarr建立 dask 计算图
# ============================================================
# chunks="auto":让 dask 根据 zarr 的存储分块自动选择内存上限,
# 数据不会一次性全部 materialize 进 RAM。
# ============================================================
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance: xr.DataArray = ds["reflectance"] # 维度约定:(y, x, band)
# 维度顺序约定(也可根据 ds.dims 自动适配):
assert "y" in reflectance.dims and "x" in reflectance.dims and "band" in reflectance.dims, (
f"reflectance 必须包含 y/x/band 三个维度,实际为: {reflectance.dims}"
)
# ============================================================
# 步骤 1取出 3 个参考波段对应的二维 (y, x) 切片
# ============================================================
# 假设 band 维度的坐标是 wavelengthnm
# 用 sel(..., method="nearest") 自动匹配最接近的波段。
# ============================================================
wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))
R_lower = reflectance.sel(band=wl_lower, method="nearest") # (y, x)
R_upper = reflectance.sel(band=wl_upper, method="nearest") # (y, x)
R_oxy = reflectance.sel(band=wl_oxy, method="nearest") # (y, x)
# ============================================================
# 步骤 2水域掩膜
# ============================================================
# 优先从 zarr 内部读取 water_mask 变量;
# 如果不存在,则假定整幅图都是水域(开发期兜底)。
# ============================================================
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(
reflectance.isel(band=0), dtype=bool
)
# ============================================================
# 步骤 3水汽吸收深度 D每像素形状 (y, x)
# ============================================================
# 旧版D[y, x] = (R_lower[y, x] + R_upper[y, x]) / 2 - R_oxy[y, x]
# 新版一行表达式dask 自动构建 lazy 计算图。
# ============================================================
D = (R_lower + R_upper) / 2.0 - R_oxy # (y, x)dtype 与 reflectance 一致
# ============================================================
# 步骤 4全局归一化因子 D_max标量0-dim DataArray
# ============================================================
# 关键:先 .where(water_mask) 把非水域置 NaN
# 再 .max() 跨 (x, y) 聚合,自动规约到 0 维。
# dask 此时仍然没有真正计算,等到 to_zarr 时再触发。
# ============================================================
D_max = D.where(water_mask).max() # scalar
# 容错:如果水域为空导致 D_max 为 NaN用极小值兜底避免除零
D_max = D_max.fillna(1e-6)
# ============================================================
# 步骤 5归一化 D_norm形状 (y, x)
# ============================================================
D_norm = D / D_max # 标量除以 (y, x) 数组 → 自动广播
# ============================================================
# 步骤 6每波段水域范围 G_list形状 (band,)
# ============================================================
# 旧版三重循环内部还要做一次 min/max 聚合。
# xarray 版本:把 (y, x) 一起 reduce只保留 band 维度。
# ============================================================
R_water = reflectance.where(water_mask) # (y, x, band),非水域 NaN
G_min = R_water.min(dim=["x", "y"]) # (band,)
G_max = R_water.max(dim=["x", "y"]) # (band,)
G_list = (G_max - G_min).fillna(0.0) # (band,),容错
# ============================================================
# 步骤 7校正公式最关键的一行演示 xarray 广播)
# ============================================================
# 旧版需要:
# for b in bands:
# for y in range(H):
# for x in range(W):
# R_corr[y,x,b] = R_raw[y,x,b] - G_list[b] * D_norm[y,x]
#
# xarray 维度对齐规则:
# R_raw : (y, x, band)
# G_list: (band,) → 缺失 y, x 自动扩展
# D_norm: (y, x) → 缺失 band 自动扩展
# 乘法结果: (y, x, band) → 减法对齐
# 一行表达式完成「三重 for 循环 + 标量索引」的语义。
# ============================================================
corrected = reflectance - G_list * D_norm # (y, x, band)
# ============================================================
# 步骤 8水域掩膜过滤非水域置 NaN
# ============================================================
result = corrected.where(water_mask)
# ============================================================
# 步骤 9持久化为 zarr
# ============================================================
# mode="w":覆盖写入(如果目标已存在则删除重建)。
# compute=True阻塞直到整张图算完并落盘。
# 由于数据始终是 dask chunk + 流式写出,
# 内存峰值 ≈ 单个 chunk 大小,与整张影像大小无关。
# ============================================================
out = xr.Dataset({"reflectance": result})
# 保留原数据集的全局属性 / 坐标信息CRS、wavelength、...
out.attrs = dict(ds.attrs)
out["reflectance"].attrs = dict(reflectance.attrs)
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,135 @@
"""
算法注册表Registry / Factory
================================
通过装饰器把「算法名字符串」与「算法实现类」绑定在一起。
上层调度层FastAPI endpoints、BackgroundTasks worker只需要拿到
前端传过来的 method 字符串,就可以自动派发到对应的算法实现,
而无需写一长串 if/elif。
使用示例
--------
from app.core.algorithms import BaseGlintRemover
from app.core.algorithms.registry import (
register_glint_remover,
get_remover,
list_removers,
)
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
async def process(self, input_zarr_path, output_zarr_path, **kwargs):
...
# 派发
Cls = get_remover(method_from_request)
remover = Cls()
await remover.process(input_zarr_path, output_zarr_path, **kwargs)
设计要点
--------
- 注册动作发生在「类定义时」,所以必须在所有算法 import 完之后
注册表才完整。可以在 `app/core/algorithms/__init__.py` 中
把算法子模块 import 一遍来强制触发注册。
- 重复注册同名算法会直接抛错,避免静默覆盖。
- name 会同步写回到类的 `name` 属性,便于算法自身查询身份。
"""
from typing import Dict, Type
from app.core.algorithms.base import BaseGlintRemover
# 全局注册表name(str) -> 实现类(type),类未被实例化
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}
def register_glint_remover(name: str):
"""
类装饰器工厂:把传入 name 的算法类注册到全局注册表。
Parameters
----------
name : str
算法标识,建议小写下划线风格,例如 "kutser""goodman"
Raises
------
ValueError
- name 不是非空字符串
- name 已经被其它类占用
TypeError
- 被装饰的对象不是 BaseGlintRemover 的子类
"""
# ---- 防御性校验name 必须是合法字符串 ----
if not isinstance(name, str) or not name.strip():
raise ValueError(
f"register_glint_remover 的 name 必须是非空字符串,收到: {name!r}"
)
def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
# ---- 防御性校验:被装饰对象必须是 BaseGlintRemover 子类 ----
if not isinstance(cls, type) or not issubclass(cls, BaseGlintRemover):
raise TypeError(
f"@register_glint_remover 只能装饰 BaseGlintRemover 的子类,"
f"收到: {cls!r}"
)
# ---- 防御性校验:禁止静默覆盖 ----
if name in _REGISTRY:
raise ValueError(
f"算法名 {name!r} 已被 {_REGISTRY[name].__name__} 占用,"
f"请使用其它名字或先调用 unregister_glint_remover() 注销旧实现。"
)
# 同步把 name 写回类属性,便于算法自身和日志输出使用
cls.name = name
_REGISTRY[name] = cls
return cls
return decorator
def get_remover(name: str) -> Type[BaseGlintRemover]:
"""
按算法名字符串取出对应的实现类(未实例化)。
调用方拿到类后自行 `Cls(...)` 构造实例,再调用 process()。
Raises
------
KeyError
当 name 不在注册表中时抛出,错误信息中附带已注册列表便于排查。
"""
try:
return _REGISTRY[name]
except KeyError as exc:
known = ", ".join(sorted(_REGISTRY)) or "<空>"
raise KeyError(
f"未注册的算法名: {name!r}。已注册的算法: {known}"
) from exc
def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
"""
返回当前注册表的浅拷贝。
可用于:
- 调试日志
- 给前端暴露一个 GET /api/algorithms 接口
- 单元测试断言
"""
return dict(_REGISTRY)
def unregister_glint_remover(name: str) -> None:
"""
注销指定算法。主要给:
- 单元测试
- 热重载 / 插件卸载场景
生产代码一般不需要调用。
"""
if name not in _REGISTRY:
raise KeyError(f"未注册的算法名: {name!r}")
del _REGISTRY[name]

View File

@ -0,0 +1,91 @@
"""
app/core/task_store.py
======================
并发安全的内存任务状态存储,替代早期 mock 流水线中的 MOCK_TASK_DB。
设计目标
--------
1. 在单进程内提供事件循环级别的互斥asyncio.Lock
避免在 update 与 set/get 之间穿插 await 时发生状态不一致。
2. 暴露异步 APIset_task / update_task / get_task
让调用方在 async 上下文中显式表达临界区。
3. 保留一个同步的 has_task() 用于轻量存在性判断。
4. 生产环境应替换为 Redis / SQLite / PostgreSQL
但接口形状保持一致, 便于上层调用方无缝迁移。
使用约定
--------
- 写入初始 PENDING 记录: await set_task(task_id, record)
- 增量更新字段PROCESSING/SUCCESS/FAILEDawait update_task(task_id, **fields)
- 读取任务记录: await get_task(task_id) # 可能返回 None
- 同步判断是否存在: has_task(task_id)
"""
import asyncio
from typing import Any, Dict, Optional
# ---------------------------------------------------------------------------
# 全局存储与锁
# ---------------------------------------------------------------------------
# TASK_STORE: task_id -> 任务记录
# 任务记录字段约定(与 endpoints.py 保持一致):
# task_id, method, params, status,
# output_zarr_path, error, traceback,
# created_at, updated_at
# ---------------------------------------------------------------------------
TASK_STORE: Dict[str, Dict[str, Any]] = {}
# 单进程内的事件循环级互斥锁
# 注意asyncio.Lock 必须在事件循环内创建, 故在模块顶层实例化时
# 仅获取引用, 第一次使用 (await lock.acquire()) 会在运行循环内进行。
_lock: asyncio.Lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# 异步 API
# ---------------------------------------------------------------------------
async def set_task(task_id: str, record: Dict[str, Any]) -> None:
"""
初始化或整体覆盖一个任务记录。
用法POST 端点收到提交请求后立即调用, 写入 PENDING 状态的初始记录。
"""
async with _lock:
TASK_STORE[task_id] = record
async def update_task(task_id: str, **fields: Any) -> None:
"""
按字段增量更新任务记录。
用法:后台执行器在 PROCESSING / SUCCESS / FAILED 等状态切换时调用。
若 task_id 不存在, setdefault 会自动创建一个空 dict 再 update防御性兜底
"""
async with _lock:
record = TASK_STORE.setdefault(task_id, {})
record.update(fields)
async def get_task(task_id: str) -> Optional[Dict[str, Any]]:
"""
读取任务记录; 不存在时返回 None。
用法GET /api/tasks/{task_id} 用此接口查询。
"""
async with _lock:
return TASK_STORE.get(task_id)
# ---------------------------------------------------------------------------
# 同步 API轻量
# ---------------------------------------------------------------------------
def has_task(task_id: str) -> bool:
"""
同步判断 task_id 是否存在。
适用于不需要锁的轻量场景(例如日志前置判断);
在 async 上下文中仍可调用, 因为 dict 的 in 判断是原子操作。
"""
return task_id in TASK_STORE

62
new/app/main.py Normal file
View File

@ -0,0 +1,62 @@
"""
WQ_GUI FastAPI 后端入口
=======================
应用启动与全局中间件配置:
- CORS开发阶段允许所有来源方便本地前端Vite / Webpack dev server联调
- 路由:通过 include_router 挂载 app/api/endpoints.py 中的业务接口
业务接口说明:
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
"""
from typing import Dict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import router as deglint_router
from app.api.modeling import router as modeling_router
# ---------------------------------------------------------------------------
# FastAPI 应用实例
# ---------------------------------------------------------------------------
app = FastAPI(
title="WQ_GUI Backend",
description="高光谱影像去耀斑处理 API",
version="0.2.0",
)
# ---------------------------------------------------------------------------
# CORS 中间件
# ---------------------------------------------------------------------------
# 开发阶段:放开所有来源、方法和头部,方便本地前端(任意端口)联调。
# 生产环境务必收敛 allow_origins 为前端真实域名,避免安全风险。
# ---------------------------------------------------------------------------
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# 路由注册
# ---------------------------------------------------------------------------
# 统一以 /api 为前缀,便于将来做版本管理(如 /api/v1、/api/v2
# ---------------------------------------------------------------------------
app.include_router(deglint_router, prefix="/api")
app.include_router(modeling_router, prefix="/api")
# ---------------------------------------------------------------------------
# 根路径健康检查(方便本地调试,非业务必需)
# ---------------------------------------------------------------------------
@app.get("/")
async def root() -> Dict[str, str]:
return {"service": "WQ_GUI Backend", "status": "ok"}

24
new/frontend/.gitignore vendored Normal file
View File

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

5
new/frontend/README.md Normal file
View File

@ -0,0 +1,5 @@
# Vue 3 + TypeScript + Vite
This template should help get you started developing with Vue 3 and TypeScript in Vite. The template uses Vue 3 `<script setup>` SFCs, check out the [script setup docs](https://v3.vuejs.org/api/sfc-script-setup.html#sfc-script-setup) to learn more.
Learn more about the recommended Project Setup and IDE Support in the [Vue Docs TypeScript Guide](https://vuejs.org/guide/typescript/overview.html#project-setup).

13
new/frontend/index.html Normal file
View File

@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>frontend</title>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>

2412
new/frontend/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

27
new/frontend/package.json Normal file
View File

@ -0,0 +1,27 @@
{
"name": "frontend",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vue-tsc -b && vite build",
"preview": "vite preview"
},
"dependencies": {
"axios": "^1.16.1",
"echarts": "^6.1.0",
"element-plus": "^2.14.1",
"pinia": "^3.0.4",
"vue": "^3.5.34",
"vue-router": "^5.1.0"
},
"devDependencies": {
"@types/node": "^24.12.3",
"@vitejs/plugin-vue": "^6.0.6",
"@vue/tsconfig": "^0.9.1",
"typescript": "~6.0.2",
"vite": "^8.0.12",
"vue-tsc": "^3.2.8"
}
}

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 9.3 KiB

View File

@ -0,0 +1,24 @@
<svg xmlns="http://www.w3.org/2000/svg">
<symbol id="bluesky-icon" viewBox="0 0 16 17">
<g clip-path="url(#bluesky-clip)"><path fill="#08060d" d="M7.75 7.735c-.693-1.348-2.58-3.86-4.334-5.097-1.68-1.187-2.32-.981-2.74-.79C.188 2.065.1 2.812.1 3.251s.241 3.602.398 4.13c.52 1.744 2.367 2.333 4.07 2.145-2.495.37-4.71 1.278-1.805 4.512 3.196 3.309 4.38-.71 4.987-2.746.608 2.036 1.307 5.91 4.93 2.746 2.72-2.746.747-4.143-1.747-4.512 1.702.189 3.55-.4 4.07-2.145.156-.528.397-3.691.397-4.13s-.088-1.186-.575-1.406c-.42-.19-1.06-.395-2.741.79-1.755 1.24-3.64 3.752-4.334 5.099"/></g>
<defs><clipPath id="bluesky-clip"><path fill="#fff" d="M.1.85h15.3v15.3H.1z"/></clipPath></defs>
</symbol>
<symbol id="discord-icon" viewBox="0 0 20 19">
<path fill="#08060d" d="M16.224 3.768a14.5 14.5 0 0 0-3.67-1.153c-.158.286-.343.67-.47.976a13.5 13.5 0 0 0-4.067 0c-.128-.306-.317-.69-.476-.976A14.4 14.4 0 0 0 3.868 3.77C1.546 7.28.916 10.703 1.231 14.077a14.7 14.7 0 0 0 4.5 2.306q.545-.748.965-1.587a9.5 9.5 0 0 1-1.518-.74q.191-.14.372-.293c2.927 1.369 6.107 1.369 8.999 0q.183.152.372.294-.723.437-1.52.74.418.838.963 1.588a14.6 14.6 0 0 0 4.504-2.308c.37-3.911-.63-7.302-2.644-10.309m-9.13 8.234c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.894 0 1.614.82 1.599 1.82.001 1-.705 1.82-1.6 1.82m5.91 0c-.878 0-1.599-.82-1.599-1.82 0-.998.705-1.82 1.6-1.82.893 0 1.614.82 1.599 1.82 0 1-.706 1.82-1.6 1.82"/>
</symbol>
<symbol id="documentation-icon" viewBox="0 0 21 20">
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="m15.5 13.333 1.533 1.322c.645.555.967.833.967 1.178s-.322.623-.967 1.179L15.5 18.333m-3.333-5-1.534 1.322c-.644.555-.966.833-.966 1.178s.322.623.966 1.179l1.534 1.321"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M17.167 10.836v-4.32c0-1.41 0-2.117-.224-2.68-.359-.906-1.118-1.621-2.08-1.96-.599-.21-1.349-.21-2.848-.21-2.623 0-3.935 0-4.983.369-1.684.591-3.013 1.842-3.641 3.428C3 6.449 3 7.684 3 10.154v2.122c0 2.558 0 3.838.706 4.726q.306.383.713.671c.76.536 1.79.64 3.581.66"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M3 10a2.78 2.78 0 0 1 2.778-2.778c.555 0 1.209.097 1.748-.047.48-.129.854-.503.982-.982.145-.54.048-1.194.048-1.749a2.78 2.78 0 0 1 2.777-2.777"/>
</symbol>
<symbol id="github-icon" viewBox="0 0 19 19">
<path fill="#08060d" fill-rule="evenodd" d="M9.356 1.85C5.05 1.85 1.57 5.356 1.57 9.694a7.84 7.84 0 0 0 5.324 7.44c.387.079.528-.168.528-.376 0-.182-.013-.805-.013-1.454-2.165.467-2.616-.935-2.616-.935-.349-.91-.864-1.143-.864-1.143-.71-.48.051-.48.051-.48.787.051 1.2.805 1.2.805.695 1.194 1.817.857 2.268.649.064-.507.27-.857.49-1.052-1.728-.182-3.545-.857-3.545-3.87 0-.857.31-1.558.8-2.104-.078-.195-.349-1 .077-2.078 0 0 .657-.208 2.14.805a7.5 7.5 0 0 1 1.946-.26c.657 0 1.328.092 1.946.26 1.483-1.013 2.14-.805 2.14-.805.426 1.078.155 1.883.078 2.078.502.546.799 1.247.799 2.104 0 3.013-1.818 3.675-3.558 3.87.284.247.528.714.528 1.454 0 1.052-.012 1.896-.012 2.156 0 .208.142.455.528.377a7.84 7.84 0 0 0 5.324-7.441c.013-4.338-3.48-7.844-7.773-7.844" clip-rule="evenodd"/>
</symbol>
<symbol id="social-icon" viewBox="0 0 20 20">
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M12.5 6.667a4.167 4.167 0 1 0-8.334 0 4.167 4.167 0 0 0 8.334 0"/>
<path fill="none" stroke="#aa3bff" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.35" d="M2.5 16.667a5.833 5.833 0 0 1 8.75-5.053m3.837.474.513 1.035c.07.144.257.282.414.309l.93.155c.596.1.736.536.307.965l-.723.73a.64.64 0 0 0-.152.531l.207.903c.164.715-.213.991-.84.618l-.872-.52a.63.63 0 0 0-.577 0l-.872.52c-.624.373-1.003.094-.84-.618l.207-.903a.64.64 0 0 0-.152-.532l-.723-.729c-.426-.43-.289-.864.306-.964l.93-.156a.64.64 0 0 0 .412-.31l.513-1.034c.28-.562.735-.562 1.012 0"/>
</symbol>
<symbol id="x-icon" viewBox="0 0 19 19">
<path fill="#08060d" fill-rule="evenodd" d="M1.893 1.98c.052.072 1.245 1.769 2.653 3.77l2.892 4.114c.183.261.333.48.333.486s-.068.089-.152.183l-.522.593-.765.867-3.597 4.087c-.375.426-.734.834-.798.905a1 1 0 0 0-.118.148c0 .01.236.017.664.017h.663l.729-.83c.4-.457.796-.906.879-.999a692 692 0 0 0 1.794-2.038c.034-.037.301-.34.594-.675l.551-.624.345-.392a7 7 0 0 1 .34-.374c.006 0 .93 1.306 2.052 2.903l2.084 2.965.045.063h2.275c1.87 0 2.273-.003 2.266-.021-.008-.02-1.098-1.572-3.894-5.547-2.013-2.862-2.28-3.246-2.273-3.266.008-.019.282-.332 2.085-2.38l2-2.274 1.567-1.782c.022-.028-.016-.03-.65-.03h-.674l-.3.342a871 871 0 0 1-1.782 2.025c-.067.075-.405.458-.75.852a100 100 0 0 1-.803.91c-.148.172-.299.344-.99 1.127-.304.343-.32.358-.345.327-.015-.019-.904-1.282-1.976-2.808L6.365 1.85H1.8zm1.782.91 8.078 11.294c.772 1.08 1.413 1.973 1.425 1.984.016.017.241.02 1.05.017l1.03-.004-2.694-3.766L7.796 5.75 5.722 2.852l-1.039-.004-1.039-.004z" clip-rule="evenodd"/>
</symbol>
</svg>

After

Width:  |  Height:  |  Size: 4.9 KiB

225
new/frontend/src/App.vue Normal file
View File

@ -0,0 +1,225 @@
<template>
<div class="dashboard-container">
<h1 class="title">高光谱水质反演控制台</h1>
<el-row :gutter="20">
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🚀 模型训练 (Train)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="算法选择 (Model Type)">
<el-select v-model="trainForm.model_type" placeholder="请选择算法" class="w-full">
<el-option label="随机森林 (RF)" value="RF" />
<el-option label="支持向量回归 (SVR)" value="SVR" />
<el-option label="线性回归 (LinearRegression)" value="LinearRegression" />
<el-option label="K近邻 (KNN)" value="KNN" />
<el-option label="偏最小二乘 (PLS)" value="PLS" />
</el-select>
</el-form-item>
<el-form-item label="目标参数 (Target)">
<el-input v-model="trainForm.target" placeholder="如 Chl-a" />
</el-form-item>
<el-form-item label="训练数据路径 (CSV 绝对路径)">
<el-input v-model="trainForm.train_data_path" placeholder="如 D:\111\data.csv" />
</el-form-item>
<el-form-item label="特征起始列 (如 4, 或列名)">
<el-input v-model="trainForm.feature_start" placeholder="填写数字或列名" />
</el-form-item>
<el-button type="primary" @click="handleTrain" :loading="trainPoller?.isPolling?.value" class="w-full">
开始训练
</el-button>
</el-form>
<div v-if="trainTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ trainTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(trainPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ trainPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="trainPoller?.isPolling?.value || trainPoller?.status?.value === 'SUCCESS'"
:percentage="trainPoller?.status?.value === 'SUCCESS' ? 100 : 60"
:status="trainPoller?.status?.value === 'SUCCESS' ? 'success' : (trainPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="trainPoller?.isPolling?.value"
/>
<div v-if="trainPoller?.error?.value" class="error-msg">
<el-alert :title="trainPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="trainPoller?.result?.value?.model_id" class="result-msg">
<el-descriptions border :column="1" size="small" title="训练指标">
<el-descriptions-item label="Model ID">{{ trainPoller.result.value.model_id }}</el-descriptions-item>
<el-descriptions-item label="Test R²">{{ Number(trainPoller.result.value.test_r2).toFixed(4) }}</el-descriptions-item>
<el-descriptions-item label="Test RMSE">{{ Number(trainPoller.result.value.test_rmse).toFixed(4) }}</el-descriptions-item>
</el-descriptions>
</div>
</div>
</el-card>
</el-col>
<el-col :span="12">
<el-card class="box-card" shadow="hover">
<template #header>
<div class="card-header">
<span class="header-title">🎯 模型推断 (Predict)</span>
</div>
</template>
<el-form label-position="top">
<el-form-item label="已训练模型 ID (Model ID)">
<el-input v-model="predictForm.model_id" placeholder="将自动填入左侧训练好的 ID" />
</el-form-item>
<el-form-item label="待推断影像路径 (Zarr 绝对路径)">
<el-input v-model="predictForm.input_zarr_path" placeholder="如 D:\111\image.zarr" />
</el-form-item>
<el-button type="success" @click="handlePredict" :loading="predictPoller?.isPolling?.value" class="w-full">
开始大图反演推断
</el-button>
</el-form>
<div v-if="predictTaskId" class="status-board">
<p><strong>任务 ID:</strong> <el-tag size="small" type="info">{{ predictTaskId }}</el-tag></p>
<p><strong>当前状态:</strong>
<el-tag :type="getStatusType(predictPoller?.status?.value || 'PENDING')" style="margin-left:10px">
{{ predictPoller?.status?.value || 'PENDING' }}
</el-tag>
</p>
<el-progress
v-if="predictPoller?.isPolling?.value || predictPoller?.status?.value === 'SUCCESS'"
:percentage="predictPoller?.status?.value === 'SUCCESS' ? 100 : 50"
:status="predictPoller?.status?.value === 'SUCCESS' ? 'success' : (predictPoller?.status?.value === 'FAILED' ? 'exception' : '')"
:indeterminate="predictPoller?.isPolling?.value"
/>
<div v-if="predictPoller?.error?.value" class="error-msg">
<el-alert :title="predictPoller.error.value" type="error" :closable="false" show-icon />
</div>
<div v-if="predictPoller?.result?.value?.output_zarr_path" class="result-msg">
<el-alert :title="'推断成功!结果已落盘至: ' + predictPoller.result.value.output_zarr_path" type="success" :closable="false" show-icon />
</div>
</div>
</el-card>
</el-col>
</el-row>
</div>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { submitTrain, submitPredict } from './api/tasks'
import { useTaskPoller } from './composables/useTaskPoller'
// 训练表单状态
const trainForm = reactive({
model_type: 'RF',
target: 'Chl-a',
train_data_path: '',
feature_start: '4'
})
const trainTaskId = ref<string | null>(null)
const trainPoller = useTaskPoller(trainTaskId)
// 推断表单状态
const predictForm = reactive({
model_id: '',
input_zarr_path: ''
})
const predictTaskId = ref<string | null>(null)
const predictPoller = useTaskPoller(predictTaskId)
// 自动填入联动
watch(() => trainPoller?.result?.value?.model_id, (newId) => {
if (newId) predictForm.model_id = newId as string
})
// 提交训练
const handleTrain = async () => {
try {
const res = await submitTrain({
model_type: trainForm.model_type,
target: trainForm.target,
train_data_path: trainForm.train_data_path,
feature_start: trainForm.feature_start,
params: {}
})
trainTaskId.value = res.task_id
} catch (e: any) {
console.error('训练接口调用失败', e)
alert('提交失败,请检查后端是否在 9090 端口启动,或按 F12 查看控制台跨域报错')
}
}
// 提交推断
const handlePredict = async () => {
try {
const res = await submitPredict({
model_id: predictForm.model_id,
input_zarr_path: predictForm.input_zarr_path
})
predictTaskId.value = res.task_id
} catch (e: any) {
console.error('推断接口调用失败', e)
}
}
// 样式辅助
const getStatusType = (status: string) => {
if (status === 'SUCCESS') return 'success'
if (status === 'FAILED') return 'danger'
if (status === 'PROCESSING') return 'warning'
return 'info'
}
</script>
<style>
/* 去除全局默认边距 */
body {
margin: 0;
padding: 0;
}
</style>
<style scoped>
.dashboard-container {
padding: 40px;
min-height: 100vh;
background-color: #1e1e2d; /* 科技深色底 */
}
.title {
text-align: center;
margin-bottom: 40px;
color: #ffffff;
font-weight: 300;
letter-spacing: 2px;
}
.header-title {
font-weight: bold;
font-size: 16px;
}
.box-card {
margin-bottom: 20px;
background-color: rgba(255, 255, 255, 0.95);
}
.w-full {
width: 100%;
}
.status-board {
margin-top: 25px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
border: 1px solid #e4e7ed;
}
.error-msg, .result-msg {
margin-top: 20px;
}
</style>

View File

@ -0,0 +1,15 @@
import axios from 'axios'
const request = axios.create({
// 注意:直接指向我们刚刚改好的 9090 端口
baseURL: 'http://127.0.0.1:9090',
timeout: 60000
})
// 拦截器:直接剥离 data
request.interceptors.response.use(
response => response.data,
error => Promise.reject(error)
)
export default request

View File

@ -0,0 +1,13 @@
import request from './request'
export const submitTrain = (data: any) => {
return request.post<any, any>('/api/modeling/train', data)
}
export const submitPredict = (data: any) => {
return request.post<any, any>('/api/modeling/predict', data)
}
export const getTaskStatus = (task_id: string) => {
return request.get<any, any>(`/api/tasks/${task_id}`)
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 8.5 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="37.07" height="36" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 198"><path fill="#41B883" d="M204.8 0H256L128 220.8L0 0h97.92L128 51.2L157.44 0h47.36Z"></path><path fill="#41B883" d="m0 0l128 220.8L256 0h-51.2L128 132.48L50.56 0H0Z"></path><path fill="#35495E" d="M50.56 0L128 133.12L204.8 0h-47.36L128 51.2L97.92 0H50.56Z"></path></svg>

After

Width:  |  Height:  |  Size: 496 B

View File

@ -0,0 +1,95 @@
<script setup lang="ts">
import { ref } from 'vue'
import viteLogo from '../assets/vite.svg'
import heroImg from '../assets/hero.png'
import vueLogo from '../assets/vue.svg'
const count = ref(0)
</script>
<template>
<section id="center">
<div class="hero">
<img :src="heroImg" class="base" width="170" height="179" alt="" />
<img :src="vueLogo" class="framework" alt="Vue logo" />
<img :src="viteLogo" class="vite" alt="Vite logo" />
</div>
<div>
<h1>Get started</h1>
<p>Edit <code>src/App.vue</code> and save to test <code>HMR</code></p>
</div>
<button type="button" class="counter" @click="count++">
Count is {{ count }}
</button>
</section>
<div class="ticks"></div>
<section id="next-steps">
<div id="docs">
<svg class="icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#documentation-icon"></use>
</svg>
<h2>Documentation</h2>
<p>Your questions, answered</p>
<ul>
<li>
<a href="https://vite.dev/" target="_blank">
<img class="logo" :src="viteLogo" alt="" />
Explore Vite
</a>
</li>
<li>
<a href="https://vuejs.org/" target="_blank">
<img class="button-icon" :src="vueLogo" alt="" />
Learn more
</a>
</li>
</ul>
</div>
<div id="social">
<svg class="icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#social-icon"></use>
</svg>
<h2>Connect with us</h2>
<p>Join the Vite community</p>
<ul>
<li>
<a href="https://github.com/vitejs/vite" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#github-icon"></use>
</svg>
GitHub
</a>
</li>
<li>
<a href="https://chat.vite.dev/" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#discord-icon"></use>
</svg>
Discord
</a>
</li>
<li>
<a href="https://x.com/vite_js" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#x-icon"></use>
</svg>
X.com
</a>
</li>
<li>
<a href="https://bsky.app/profile/vite.dev" target="_blank">
<svg class="button-icon" role="presentation" aria-hidden="true">
<use href="/icons.svg#bluesky-icon"></use>
</svg>
Bluesky
</a>
</li>
</ul>
</div>
</section>
<div class="ticks"></div>
<section id="spacer"></section>
</template>

View File

@ -0,0 +1,51 @@
import { ref, watch, onUnmounted, type Ref } from 'vue'
import { getTaskStatus } from '../api/tasks'
export function useTaskPoller(taskIdRef: Ref<string | null>) {
const status = ref<string>('')
const isPolling = ref(false)
const error = ref<string | null>(null)
const result = ref<any>(null)
let timer: any = null
const start = () => {
if (!taskIdRef.value) return
isPolling.value = true
error.value = null
status.value = 'PENDING'
timer = setInterval(async () => {
try {
const res = await getTaskStatus(taskIdRef.value!)
status.value = res.status
if (res.status === 'SUCCESS') {
result.value = res
stop()
} else if (res.status === 'FAILED') {
error.value = res.error || '任务执行失败'
stop()
}
} catch (e: any) {
error.value = '网络请求失败,请检查后端状态'
stop()
}
}, 2000)
}
const stop = () => {
isPolling.value = false
if (timer) clearInterval(timer)
}
// 监听 Task ID 变化自动开启轮询
watch(taskIdRef, (newVal) => {
stop()
if (newVal) start()
})
// 组件销毁时清理定时器
onUnmounted(() => stop())
return { status, isPolling, error, result, stop }
}

9
new/frontend/src/main.ts Normal file
View File

@ -0,0 +1,9 @@
import { createApp } from 'vue'
import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import App from './App.vue'
const app = createApp(App)
app.use(ElementPlus)
app.mount('#app')

296
new/frontend/src/style.css Normal file
View File

@ -0,0 +1,296 @@
:root {
--text: #6b6375;
--text-h: #08060d;
--bg: #fff;
--border: #e5e4e7;
--code-bg: #f4f3ec;
--accent: #aa3bff;
--accent-bg: rgba(170, 59, 255, 0.1);
--accent-border: rgba(170, 59, 255, 0.5);
--social-bg: rgba(244, 243, 236, 0.5);
--shadow:
rgba(0, 0, 0, 0.1) 0 10px 15px -3px, rgba(0, 0, 0, 0.05) 0 4px 6px -2px;
--sans: system-ui, 'Segoe UI', Roboto, sans-serif;
--heading: system-ui, 'Segoe UI', Roboto, sans-serif;
--mono: ui-monospace, Consolas, monospace;
font: 18px/145% var(--sans);
letter-spacing: 0.18px;
color-scheme: light dark;
color: var(--text);
background: var(--bg);
font-synthesis: none;
text-rendering: optimizeLegibility;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
@media (max-width: 1024px) {
font-size: 16px;
}
}
@media (prefers-color-scheme: dark) {
:root {
--text: #9ca3af;
--text-h: #f3f4f6;
--bg: #16171d;
--border: #2e303a;
--code-bg: #1f2028;
--accent: #c084fc;
--accent-bg: rgba(192, 132, 252, 0.15);
--accent-border: rgba(192, 132, 252, 0.5);
--social-bg: rgba(47, 48, 58, 0.5);
--shadow:
rgba(0, 0, 0, 0.4) 0 10px 15px -3px, rgba(0, 0, 0, 0.25) 0 4px 6px -2px;
}
#social .button-icon {
filter: invert(1) brightness(2);
}
}
body {
margin: 0;
}
h1,
h2 {
font-family: var(--heading);
font-weight: 500;
color: var(--text-h);
}
h1 {
font-size: 56px;
letter-spacing: -1.68px;
margin: 32px 0;
@media (max-width: 1024px) {
font-size: 36px;
margin: 20px 0;
}
}
h2 {
font-size: 24px;
line-height: 118%;
letter-spacing: -0.24px;
margin: 0 0 8px;
@media (max-width: 1024px) {
font-size: 20px;
}
}
p {
margin: 0;
}
code,
.counter {
font-family: var(--mono);
display: inline-flex;
border-radius: 4px;
color: var(--text-h);
}
code {
font-size: 15px;
line-height: 135%;
padding: 4px 8px;
background: var(--code-bg);
}
.counter {
font-size: 16px;
padding: 5px 10px;
border-radius: 5px;
color: var(--accent);
background: var(--accent-bg);
border: 2px solid transparent;
transition: border-color 0.3s;
margin-bottom: 24px;
&:hover {
border-color: var(--accent-border);
}
&:focus-visible {
outline: 2px solid var(--accent);
outline-offset: 2px;
}
}
.hero {
position: relative;
.base,
.framework,
.vite {
inset-inline: 0;
margin: 0 auto;
}
.base {
width: 170px;
position: relative;
z-index: 0;
}
.framework,
.vite {
position: absolute;
}
.framework {
z-index: 1;
top: 34px;
height: 28px;
transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg)
scale(1.4);
}
.vite {
z-index: 0;
top: 107px;
height: 26px;
width: auto;
transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg)
scale(0.8);
}
}
#app {
width: 1126px;
max-width: 100%;
margin: 0 auto;
text-align: center;
border-inline: 1px solid var(--border);
min-height: 100svh;
display: flex;
flex-direction: column;
box-sizing: border-box;
}
#center {
display: flex;
flex-direction: column;
gap: 25px;
place-content: center;
place-items: center;
flex-grow: 1;
@media (max-width: 1024px) {
padding: 32px 20px 24px;
gap: 18px;
}
}
#next-steps {
display: flex;
border-top: 1px solid var(--border);
text-align: left;
& > div {
flex: 1 1 0;
padding: 32px;
@media (max-width: 1024px) {
padding: 24px 20px;
}
}
.icon {
margin-bottom: 16px;
width: 22px;
height: 22px;
}
@media (max-width: 1024px) {
flex-direction: column;
text-align: center;
}
}
#docs {
border-right: 1px solid var(--border);
@media (max-width: 1024px) {
border-right: none;
border-bottom: 1px solid var(--border);
}
}
#next-steps ul {
list-style: none;
padding: 0;
display: flex;
gap: 8px;
margin: 32px 0 0;
.logo {
height: 18px;
}
a {
color: var(--text-h);
font-size: 16px;
border-radius: 6px;
background: var(--social-bg);
display: flex;
padding: 6px 12px;
align-items: center;
gap: 8px;
text-decoration: none;
transition: box-shadow 0.3s;
&:hover {
box-shadow: var(--shadow);
}
.button-icon {
height: 18px;
width: 18px;
}
}
@media (max-width: 1024px) {
margin-top: 20px;
flex-wrap: wrap;
justify-content: center;
li {
flex: 1 1 calc(50% - 8px);
}
a {
width: 100%;
justify-content: center;
box-sizing: border-box;
}
}
}
#spacer {
height: 88px;
border-top: 1px solid var(--border);
@media (max-width: 1024px) {
height: 48px;
}
}
.ticks {
position: relative;
width: 100%;
&::before,
&::after {
content: '';
position: absolute;
top: -4.5px;
border: 5px solid transparent;
}
&::before {
left: 0;
border-left-color: var(--border);
}
&::after {
right: 0;
border-right-color: var(--border);
}
}

View File

@ -0,0 +1,14 @@
{
"extends": "@vue/tsconfig/tsconfig.dom.json",
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"types": ["vite/client"],
/* Linting */
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true
},
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"]
}

View File

@ -0,0 +1,7 @@
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}

View File

@ -0,0 +1,24 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
"target": "es2023",
"lib": ["ES2023"],
"module": "esnext",
"types": ["node"],
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"verbatimModuleSyntax": true,
"moduleDetection": "force",
"noEmit": true,
/* Linting */
"noUnusedLocals": true,
"noUnusedParameters": true,
"erasableSyntaxOnly": true,
"noFallthroughCasesInSwitch": true
},
"include": ["vite.config.ts"]
}

View File

@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import vue from '@vitejs/plugin-vue'
// https://vite.dev/config/
export default defineConfig({
plugins: [vue()],
})