Files
WQ_GUI/new/app/api/endpoints.py

223 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API 路由集合
============
把业务接口统一收口到 APIRouter再由 main.py 通过 include_router 挂载。
当前包含的接口:
GET /api/algorithms 列出已注册的所有去耀斑算法(供前端下拉框)
POST /api/process/deglint 提交去耀斑处理任务,立即返回 task_id
GET /api/tasks/{task_id} 查询指定任务的状态与结果
派发链:
POST /api/process/deglint
└─ BackgroundTasks.add_task(execute_glint_removal_task, ...)
└─ get_remover(method) 从注册表拿到算法类
└─ remover.process(input_zarr, output_zarr, **params)
"""
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field
# 并发安全的任务状态存储(替代旧版的 MOCK_TASK_DB
from app.core.task_store import get_task, set_task, update_task
# 算法注册表 API
from app.core.algorithms import get_remover, list_removers
# ---------------------------------------------------------------------------
# 路由实例
# ---------------------------------------------------------------------------
# prefix 不在此处设置,统一在 main.py 挂载时给定,便于将来按版本拆分
# (例如 /api/v1、/api/v2 共存时复用同一个 router 对象)。
# ---------------------------------------------------------------------------
router = APIRouter(tags=["deglint"])
# ---------------------------------------------------------------------------
# 请求 / 响应数据模型
# ---------------------------------------------------------------------------
class DeglintRequest(BaseModel):
"""POST /api/process/deglint 的请求体"""
method: str = Field(
...,
description="去耀斑方法名称,必须是已注册算法,例如 'kutser' / 'goodman'",
examples=["kutser"],
)
params: Dict[str, Any] = Field(
default_factory=dict,
description=(
"传递给算法 process() 的超参数字典,例如 "
"Kutser: {'band_lower': 773, 'band_oxy': 845, 'band_upper': 893}; "
"Goodman: {'band_ref': 750, 'band_diff': 640, 'A': 0.0, 'B': 0.0}"
),
examples=[{"band_lower": 773, "band_oxy": 845, "band_upper": 893}],
)
class TaskAcceptedResponse(BaseModel):
"""提交任务成功后立即返回的响应"""
task_id: str
status: str # 一定是 PENDING
class AlgorithmListResponse(BaseModel):
"""GET /api/algorithms 的响应"""
algorithms: list # 已注册算法名列表
count: int # 算法总数
# ---------------------------------------------------------------------------
# 后台任务执行器(真实派发链)
# ---------------------------------------------------------------------------
# 注意:这里使用 async def。
# FastAPI / Starlette 的 BackgroundTasks 支持 async function
# 会在响应返回后自动 await 它,不影响主请求链路。
# ---------------------------------------------------------------------------
async def execute_glint_removal_task(
task_id: str,
method: str,
params: Dict[str, Any],
) -> None:
"""
后台异步执行器:按 method 名字从注册表取出算法类,实例化并运行 process()。
状态机:
PENDING -> PROCESSING -> SUCCESS
└──> FAILED含 error / traceback
"""
# 0. 安全检查任务记录必须已存在POST 阶段已写入)
record = await get_task(task_id)
if record is None:
print(f"[{task_id}] 任务不存在, 跳过")
return
# 1. 状态推进到 PROCESSING
await update_task(
task_id,
status="PROCESSING",
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 开始处理 method={method} params={params}")
# 2. 临时硬编码 IO 路径(未来由数据管理层提供)
# TODO: 替换为真实的数据管理服务返回的 zarr 路径
input_zarr_path = "./data/temp_in.zarr"
output_zarr_path = f"./data/{task_id}_out.zarr"
try:
# 3. 按 method 名字从注册表取算法类并实例化
# get_remover 找不到时会抛 KeyError下面的 except 会兜住
algorithm_cls = get_remover(method)
remover = algorithm_cls()
# 4. 调用算法(注意 await因为 BaseGlintRemover.process 是 async
await remover.process(input_zarr_path, output_zarr_path, **params)
# 5. 成功:写回结果路径与状态
await update_task(
task_id,
status="SUCCESS",
output_zarr_path=output_zarr_path,
error=None,
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理完成 -> SUCCESS, output={output_zarr_path}")
except Exception as exc: # noqa: BLE001 顶层兜底,绝不让后台任务静默失败
# 6. 失败:记录错误信息与堆栈,便于前端排查
await update_task(
task_id,
status="FAILED",
output_zarr_path=None,
error=f"{type(exc).__name__}: {exc}",
traceback=traceback.format_exc(),
updated_at=datetime.now().isoformat(),
)
print(f"[{task_id}] 处理失败 -> {type(exc).__name__}: {exc}")
# ---------------------------------------------------------------------------
# GET /algorithms
# ---------------------------------------------------------------------------
# 返回当前已注册的所有算法名,供前端动态渲染下拉框 / 选择器。
# ---------------------------------------------------------------------------
@router.get("/algorithms", response_model=AlgorithmListResponse)
async def list_registered_algorithms() -> Dict[str, Any]:
"""列出已注册的去耀斑算法。"""
names = list(list_removers().keys())
return {"algorithms": names, "count": len(names)}
# ---------------------------------------------------------------------------
# POST /process/deglint
# ---------------------------------------------------------------------------
# 提交去耀斑处理任务。FastAPI 在函数返回后才会把响应发给前端,
# 因此通过 BackgroundTasks 把耗时操作丢到后台,接口本身立刻返回 task_id。
# ---------------------------------------------------------------------------
@router.post("/process/deglint", response_model=TaskAcceptedResponse)
async def submit_deglint(
payload: DeglintRequest,
background_tasks: BackgroundTasks,
) -> Dict[str, Any]:
"""提交一个去耀斑处理任务,并立即返回 task_id。"""
# 1. 生成唯一任务 IDUUID4 足以保证全局唯一性)
task_id = str(uuid.uuid4())
# 2. 在任务库中登记一条 PENDING 记录(并发安全)
# 注意output_zarr_path / error / traceback 字段在执行过程中被填充
await set_task(
task_id,
{
"task_id": task_id,
"method": payload.method,
"params": payload.params,
"status": "PENDING",
"output_zarr_path": None,
"error": None,
"traceback": None,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat(),
},
)
# 3. 把真实执行器丢到后台
background_tasks.add_task(
execute_glint_removal_task,
task_id,
payload.method,
payload.params,
)
# 4. 立即返回 task_id 与 PENDING 状态
return {"task_id": task_id, "status": "PENDING"}
# ---------------------------------------------------------------------------
# GET /tasks/{task_id}
# ---------------------------------------------------------------------------
# 前端轮询此接口获取任务状态。PENDING / PROCESSING 表示仍在跑,
# SUCCESS 表示成功(含 output_zarr_pathFAILED 表示失败(含 error / traceback
# ---------------------------------------------------------------------------
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str) -> Dict[str, Any]:
"""查询指定任务的当前状态与结果。"""
record = await get_task(task_id)
if record is None:
# 找不到 task_id 通常意味着客户端拼错了 ID或者记录已被清理
raise HTTPException(status_code=404, detail=f"task_id 不存在: {task_id}")
# 直接返回字典FastAPI 会自动 JSON 序列化
return record