feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描
This commit is contained in:
@ -30,6 +30,7 @@ app/api/modeling.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -40,7 +41,7 @@ import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xarray as xr
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel, Field
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
@ -784,3 +785,63 @@ async def submit_predict(
|
||||
payload.output_zarr_path,
|
||||
)
|
||||
return {"task_id": task_id, "status": "PENDING", "kind": "predict"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# models_router — 独立于 modeling_router,路径前缀为 /models
|
||||
# 最终完整路径: GET /api/models, POST /api/models/upload
|
||||
# ---------------------------------------------------------------------------
|
||||
models_router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@models_router.get("")
|
||||
async def list_models() -> Dict[str, Any]:
|
||||
"""
|
||||
扫描 ./data/models/ 目录,返回所有 .joblib 文件名(不含后缀)。
|
||||
|
||||
异常处理:目录不存在时自动创建,返回空列表。
|
||||
"""
|
||||
models_dir = Path("./data/models")
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_names = [
|
||||
p.stem for p in models_dir.iterdir() if p.suffix == ".joblib"
|
||||
]
|
||||
return {"models": model_names}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/models/upload
|
||||
# ---------------------------------------------------------------------------
|
||||
@models_router.post("/upload")
|
||||
async def upload_model(
|
||||
file: UploadFile = File(...),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
接收上传的 .joblib 模型文件,保存到 ./data/models/ 目录。
|
||||
|
||||
- 校验后缀必须为 .joblib
|
||||
- 目录不存在时自动创建
|
||||
- 返回状态和文件名(不含后缀)
|
||||
"""
|
||||
if not file.filename or not file.filename.lower().endswith(".joblib"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="仅支持 .joblib 格式的文件",
|
||||
)
|
||||
|
||||
models_dir = Path("./data/models")
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dest_path = models_dir / file.filename
|
||||
|
||||
with dest_path.open("wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"model_id": dest_path.stem,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user