refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名
This commit is contained in:
40
new/app/core/algorithms/__init__.py
Normal file
40
new/app/core/algorithms/__init__.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""
|
||||
去耀斑算法包
|
||||
============
|
||||
|
||||
通过「注册表 + 策略模式」组织不同的去耀斑算法。
|
||||
所有具体算法都应继承 BaseGlintRemover,并使用 @register_glint_remover
|
||||
装饰器把算法名和实现类绑定。
|
||||
|
||||
外部调用约定
|
||||
------------
|
||||
1. 所有算法子模块必须在本 __init__ 中显式 import,
|
||||
这样装饰器才会被执行、注册表才会被填满。
|
||||
2. 上层(endpoints、worker)只允许:
|
||||
from app.core.algorithms import get_remover
|
||||
来获取算法类,不要直接 import 具体实现类,
|
||||
保持调度层与具体算法的解耦。
|
||||
"""
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import (
|
||||
get_remover,
|
||||
list_removers,
|
||||
register_glint_remover,
|
||||
unregister_glint_remover,
|
||||
)
|
||||
|
||||
# ---- 算法子模块 import 区 ----
|
||||
# 新增算法时,在这里加一行 import,确保装饰器被执行。
|
||||
from app.core.algorithms import goodman # Goodman
|
||||
from app.core.algorithms import kutser # Kutser
|
||||
# from app.core.algorithms import hedley # Hedley
|
||||
# from app.core.algorithms import sugar # SUGAR
|
||||
|
||||
__all__ = [
|
||||
"BaseGlintRemover",
|
||||
"register_glint_remover",
|
||||
"get_remover",
|
||||
"list_removers",
|
||||
"unregister_glint_remover",
|
||||
]
|
||||
85
new/app/core/algorithms/base.py
Normal file
85
new/app/core/algorithms/base.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
去耀斑算法抽象基类
|
||||
==================
|
||||
|
||||
设计目标(策略模式 Strategy Pattern)
|
||||
------------------------------------
|
||||
本模块定义了所有去耀斑算法必须遵守的标准接口。
|
||||
未来的 Kutser、Goodman、Hedley、SUGAR 等算法都将继承本基类,
|
||||
并实现统一的 process() 方法。
|
||||
|
||||
输入输出规范
|
||||
------------
|
||||
所有算法的输入与输出均统一为 **Zarr 文件路径**(字符串),
|
||||
而不是内存中的 numpy ndarray。这样做的核心收益是:
|
||||
|
||||
1. **解耦数据存储与内存计算**:
|
||||
算法只关心「从哪个 zarr 读、写到哪个 zarr」,
|
||||
至于数据最初来自 GeoTIFF / HDF5 / NetCDF / 内存数组,
|
||||
都由 IO 层负责归一化转为 zarr。
|
||||
2. **支持 Out-of-Core 计算**:
|
||||
影像往往超过内存上限,zarr 分块(chunk)天然支持按块读取,
|
||||
算法实现可以借助 dask / xarray 进行流式计算。
|
||||
3. **可缓存、可复用**:
|
||||
中间产物落盘后,下游算法(大气校正、辐射定标)能直接消费,
|
||||
避免重复 IO。
|
||||
4. **易于并行与分布式**:
|
||||
任务调度层只需把两个路径扔给 worker,无需关心数据细节。
|
||||
|
||||
约定
|
||||
----
|
||||
- 子类应实现 process(),完成「读 -> 计算 -> 写」的完整流程。
|
||||
- process() 返回 True 表示成功,False 表示失败。
|
||||
- 失败时建议抛出异常而非仅返回 False,便于上层 BackgroundTasks 捕获并写入 error 字段。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseGlintRemover(ABC):
|
||||
"""
|
||||
去耀斑算法抽象基类。
|
||||
|
||||
所有具体算法(Kutser / Goodman / Hedley / SUGAR …)必须继承本类并实现 process()。
|
||||
子类可在 __init__ 中接收自己的超参数(如参考波段、阈值等),
|
||||
真正的输入输出数据则由 process() 的两个 zarr 路径参数指定。
|
||||
"""
|
||||
|
||||
# 子类可覆盖的算法名称标识,用于调度层按 method 名字查找
|
||||
name: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
执行去耀斑处理。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_zarr_path : str
|
||||
输入高光谱影像的 zarr 存储路径。
|
||||
数据已由 IO 层完成格式归一化(波段、坐标系、空间维度均已对齐)。
|
||||
output_zarr_path : str
|
||||
处理结果(去耀斑后影像)的 zarr 存储路径。
|
||||
子类需自行创建该 zarr 存储并写入结果。
|
||||
**kwargs : Any
|
||||
算法的可选超参数,例如:
|
||||
- reference_band: 参考近红外波段索引
|
||||
- chunk_size: 计算分块大小
|
||||
- 其它算法特定参数
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True 表示处理成功,False 表示失败。
|
||||
建议在出错时直接 raise,由调用方统一记录到任务状态。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover - 调试辅助
|
||||
return f"<{self.__class__.__name__} name={self.name!r}>"
|
||||
123
new/app/core/algorithms/goodman.py
Normal file
123
new/app/core/algorithms/goodman.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
app/core/algorithms/goodman.py
|
||||
===============================
|
||||
|
||||
Goodman et al. 2008 去耀斑算法的 xarray + dask 流式实现。
|
||||
|
||||
算法公式
|
||||
--------
|
||||
R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)
|
||||
|
||||
其中:
|
||||
R_raw -- 原始反射率 (y, x, band)
|
||||
R_750 -- λ=750 nm 处的反射率(红外参考波段, 远离水汽吸收)
|
||||
R_640 -- λ=640 nm 处的反射率(可见光差异波段)
|
||||
A, B -- 经验回归参数(用户可通过 params 传入, 默认全 0)
|
||||
|
||||
后处理
|
||||
------
|
||||
- 负值截断为 0(Clamp to 0)
|
||||
- 仅在水域掩膜 (water_mask) 内生效, 水外置 0
|
||||
|
||||
维度约定
|
||||
--------
|
||||
reflectance: (y, x, band), band 坐标通常为 wavelength (nm)
|
||||
water_mask : (y, x), 布尔类型, True = 水域
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 默认参数
|
||||
# ---------------------------------------------------------------------------
|
||||
# 与原始 Goodman 2008 论文符号保持一致, 方便用户交叉对照。
|
||||
# A、B 通常通过对纯净深水区做 (R_corr - R_raw) ~ (R_640 - R_750) 回归得到;
|
||||
# 在缺乏先验知识时, 退化为 A=0, B=0 即等价于 R_corrected = clip(R_raw - R_750, 0)。
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_BAND_REF: float = 750.0 # λ_750 nm, 红外参考波段
|
||||
DEFAULT_BAND_DIFF: float = 640.0 # λ_640 nm, 可见光差异波段
|
||||
DEFAULT_A: float = 0.0 # 公式中的常数偏移项
|
||||
DEFAULT_B: float = 0.0 # 公式中的斜率项
|
||||
|
||||
|
||||
@register_glint_remover("goodman")
|
||||
class GoodmanGlintRemover(BaseGlintRemover):
|
||||
"""Goodman et al. 2008 去耀斑算法"""
|
||||
|
||||
name = "goodman"
|
||||
|
||||
async def process(
|
||||
self,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
# 1. 解析超参数(带默认值, 方便用户按需覆盖)
|
||||
band_ref: float = kwargs.get("band_ref", DEFAULT_BAND_REF)
|
||||
band_diff: float = kwargs.get("band_diff", DEFAULT_BAND_DIFF)
|
||||
A: float = kwargs.get("A", DEFAULT_A)
|
||||
B: float = kwargs.get("B", DEFAULT_B)
|
||||
|
||||
# 2. 把同步的 xarray/dask 计算丢到工作线程,
|
||||
# 避免阻塞 FastAPI 的事件循环
|
||||
return await asyncio.to_thread(
|
||||
self._process_sync,
|
||||
input_zarr_path,
|
||||
output_zarr_path,
|
||||
band_ref,
|
||||
band_diff,
|
||||
A,
|
||||
B,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_sync(
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
band_ref: float,
|
||||
band_diff: float,
|
||||
A: float,
|
||||
B: float,
|
||||
) -> bool:
|
||||
# 1. 以 zarr 路径打开(dask-backed, 不物化到内存)
|
||||
# chunks="auto" 让 dask 根据每条坐标轴的大小自动决定分块
|
||||
ds = xr.open_zarr(input_zarr_path, chunks="auto")
|
||||
reflectance = ds["reflectance"] # (y, x, band)
|
||||
|
||||
# 2. 用 sel + method='nearest' 提取两个关键波段
|
||||
# 返回形状 (y, x), 后续与 (y, x, band) 算术时会自动广播
|
||||
R_750 = reflectance.sel(band=band_ref, method="nearest")
|
||||
R_640 = reflectance.sel(band=band_diff, method="nearest")
|
||||
|
||||
# 3. Goodman 公式: xarray 沿 band 维度自动广播
|
||||
# R_corr = R_raw - R_750 + A + B * (R_640 - R_750)
|
||||
result = reflectance - R_750 + A + B * (R_640 - R_750)
|
||||
|
||||
# 4. 负值截断为 0(clip(min=0) 优于 where(>0, 0, _):
|
||||
# 不构造布尔中间数组, 底层走 dask 矢量化 clip 路径)
|
||||
result = result.clip(min=0)
|
||||
|
||||
# 5. 仅在水域内生效(水外强制为 0)
|
||||
# 优先从 zarr 内部读 water_mask 变量, 缺失则视为全图水域
|
||||
if "water_mask" in ds:
|
||||
water_mask = ds["water_mask"].astype(bool)
|
||||
result = result.where(water_mask, 0)
|
||||
|
||||
# 6. 构造输出 Dataset, 保留元信息(波段坐标/属性等)
|
||||
out = xr.Dataset({"reflectance": result})
|
||||
if ds.attrs:
|
||||
out.attrs = dict(ds.attrs)
|
||||
if reflectance.attrs:
|
||||
out["reflectance"].attrs = dict(reflectance.attrs)
|
||||
|
||||
# 7. 流式写出(Out-of-Core):不一次性物化大数组,
|
||||
# dask 会按 chunk 边算边写, 内存峰值 ≈ 单个 chunk 大小
|
||||
out.to_zarr(output_zarr_path, mode="w", compute=True)
|
||||
return True
|
||||
211
new/app/core/algorithms/kutser.py
Normal file
211
new/app/core/algorithms/kutser.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""
|
||||
Kutser 去耀斑算法(xarray + dask 重构版)
|
||||
========================================
|
||||
|
||||
旧版痛点
|
||||
--------
|
||||
原始 Kutser 实现(参考 Kutser et al., 2013)通常写成像这样:
|
||||
|
||||
R_corr = np.zeros_like(R_raw)
|
||||
for b in range(n_bands):
|
||||
for y in range(H):
|
||||
for x in range(W):
|
||||
if water_mask[y, x]:
|
||||
R_corr[y, x, b] = (
|
||||
R_raw[y, x, b] - G_list[b] * D_norm[y, x]
|
||||
)
|
||||
with rasterio.open(..., 'w') as dst:
|
||||
dst.write(R_corr)
|
||||
|
||||
问题:
|
||||
1. 三重 Python 循环,每次只做一个浮点运算,解释器开销巨大;
|
||||
2. 一次性把整张图 R_raw 读进内存,大影像直接 OOM;
|
||||
3. rasterio 写出要求 numpy 连续数组,进一步放大内存。
|
||||
|
||||
本文件用 xarray + dask 重写:
|
||||
- 用 DataArray 维度广播,三重循环 → 一行表达式;
|
||||
- 用 dask chunk 保持数据常驻磁盘、流式计算;
|
||||
- 用 to_zarr 边算边写,输出格式与算法层彻底解耦。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import xarray as xr
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
from app.core.algorithms.registry import register_glint_remover
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 算法实现
|
||||
# ---------------------------------------------------------------------------
|
||||
@register_glint_remover("kutser")
|
||||
class KutserGlintRemover(BaseGlintRemover):
|
||||
"""
|
||||
Kutser 近红外扣除法去耀斑。
|
||||
|
||||
数学公式(与旧版完全等价)
|
||||
-------------------------
|
||||
1) 水汽吸收深度 D(每像素):
|
||||
D = (R(λ_lower) + R(λ_upper)) / 2 - R(λ_oxy)
|
||||
2) 全局归一化因子 D_max:
|
||||
D_max = max(D) over 水域
|
||||
归一化:
|
||||
D_norm = D / D_max
|
||||
3) 每波段水域范围:
|
||||
G_list[b] = max(R[:, :, b] over 水域) - min(R[:, :, b] over 水域)
|
||||
4) 校正公式(每像素、每波段):
|
||||
R_corr(λ_b) = R_raw(λ_b) - G_list[b] * D_norm
|
||||
"""
|
||||
|
||||
# Kutser 2013 论文里使用的参考波段(nm):
|
||||
# λ_lower = 773, λ_oxy = 845, λ_upper = 893
|
||||
# 允许通过 kwargs 覆盖,便于适配 MERIS / OLCI / Landsat 等不同传感器。
|
||||
DEFAULT_BAND_LOWER: float = 773.0
|
||||
DEFAULT_BAND_OXY: float = 845.0
|
||||
DEFAULT_BAND_UPPER: float = 893.0
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# 公开异步入口
|
||||
# --------------------------------------------------------------
|
||||
# xarray / dask 的算子本身是同步阻塞的。在 async 函数中,
|
||||
# 用 asyncio.to_thread 把同步体丢到默认线程池执行,
|
||||
# 避免阻塞 FastAPI 的事件循环。
|
||||
# --------------------------------------------------------------
|
||||
async def process(
|
||||
self,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
return await asyncio.to_thread(
|
||||
self._process_sync,
|
||||
input_zarr_path,
|
||||
output_zarr_path,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------
|
||||
# 同步核心实现
|
||||
# --------------------------------------------------------------
|
||||
def _process_sync(
|
||||
self,
|
||||
input_zarr_path: str,
|
||||
output_zarr_path: str,
|
||||
kwargs: dict,
|
||||
) -> bool:
|
||||
# ============================================================
|
||||
# 步骤 0:打开 zarr,建立 dask 计算图
|
||||
# ============================================================
|
||||
# chunks="auto":让 dask 根据 zarr 的存储分块自动选择内存上限,
|
||||
# 数据不会一次性全部 materialize 进 RAM。
|
||||
# ============================================================
|
||||
ds = xr.open_zarr(input_zarr_path, chunks="auto")
|
||||
reflectance: xr.DataArray = ds["reflectance"] # 维度约定:(y, x, band)
|
||||
|
||||
# 维度顺序约定(也可根据 ds.dims 自动适配):
|
||||
assert "y" in reflectance.dims and "x" in reflectance.dims and "band" in reflectance.dims, (
|
||||
f"reflectance 必须包含 y/x/band 三个维度,实际为: {reflectance.dims}"
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 1:取出 3 个参考波段对应的二维 (y, x) 切片
|
||||
# ============================================================
|
||||
# 假设 band 维度的坐标是 wavelength(nm)。
|
||||
# 用 sel(..., method="nearest") 自动匹配最接近的波段。
|
||||
# ============================================================
|
||||
wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
|
||||
wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
|
||||
wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))
|
||||
|
||||
R_lower = reflectance.sel(band=wl_lower, method="nearest") # (y, x)
|
||||
R_upper = reflectance.sel(band=wl_upper, method="nearest") # (y, x)
|
||||
R_oxy = reflectance.sel(band=wl_oxy, method="nearest") # (y, x)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 2:水域掩膜
|
||||
# ============================================================
|
||||
# 优先从 zarr 内部读取 water_mask 变量;
|
||||
# 如果不存在,则假定整幅图都是水域(开发期兜底)。
|
||||
# ============================================================
|
||||
if "water_mask" in ds:
|
||||
water_mask = ds["water_mask"].astype(bool)
|
||||
else:
|
||||
water_mask = xr.ones_like(
|
||||
reflectance.isel(band=0), dtype=bool
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 3:水汽吸收深度 D(每像素,形状 (y, x))
|
||||
# ============================================================
|
||||
# 旧版:D[y, x] = (R_lower[y, x] + R_upper[y, x]) / 2 - R_oxy[y, x]
|
||||
# 新版:一行表达式,dask 自动构建 lazy 计算图。
|
||||
# ============================================================
|
||||
D = (R_lower + R_upper) / 2.0 - R_oxy # (y, x),dtype 与 reflectance 一致
|
||||
|
||||
# ============================================================
|
||||
# 步骤 4:全局归一化因子 D_max(标量,0-dim DataArray)
|
||||
# ============================================================
|
||||
# 关键:先 .where(water_mask) 把非水域置 NaN,
|
||||
# 再 .max() 跨 (x, y) 聚合,自动规约到 0 维。
|
||||
# dask 此时仍然没有真正计算,等到 to_zarr 时再触发。
|
||||
# ============================================================
|
||||
D_max = D.where(water_mask).max() # scalar
|
||||
# 容错:如果水域为空导致 D_max 为 NaN,用极小值兜底,避免除零
|
||||
D_max = D_max.fillna(1e-6)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 5:归一化 D_norm(形状 (y, x))
|
||||
# ============================================================
|
||||
D_norm = D / D_max # 标量除以 (y, x) 数组 → 自动广播
|
||||
|
||||
# ============================================================
|
||||
# 步骤 6:每波段水域范围 G_list(形状 (band,))
|
||||
# ============================================================
|
||||
# 旧版三重循环内部还要做一次 min/max 聚合。
|
||||
# xarray 版本:把 (y, x) 一起 reduce,只保留 band 维度。
|
||||
# ============================================================
|
||||
R_water = reflectance.where(water_mask) # (y, x, band),非水域 NaN
|
||||
G_min = R_water.min(dim=["x", "y"]) # (band,)
|
||||
G_max = R_water.max(dim=["x", "y"]) # (band,)
|
||||
G_list = (G_max - G_min).fillna(0.0) # (band,),容错
|
||||
|
||||
# ============================================================
|
||||
# 步骤 7:校正公式(最关键的一行,演示 xarray 广播)
|
||||
# ============================================================
|
||||
# 旧版需要:
|
||||
# for b in bands:
|
||||
# for y in range(H):
|
||||
# for x in range(W):
|
||||
# R_corr[y,x,b] = R_raw[y,x,b] - G_list[b] * D_norm[y,x]
|
||||
#
|
||||
# xarray 维度对齐规则:
|
||||
# R_raw : (y, x, band)
|
||||
# G_list: (band,) → 缺失 y, x 自动扩展
|
||||
# D_norm: (y, x) → 缺失 band 自动扩展
|
||||
# 乘法结果: (y, x, band) → 减法对齐
|
||||
# 一行表达式完成「三重 for 循环 + 标量索引」的语义。
|
||||
# ============================================================
|
||||
corrected = reflectance - G_list * D_norm # (y, x, band)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 8:水域掩膜过滤(非水域置 NaN)
|
||||
# ============================================================
|
||||
result = corrected.where(water_mask)
|
||||
|
||||
# ============================================================
|
||||
# 步骤 9:持久化为 zarr
|
||||
# ============================================================
|
||||
# mode="w":覆盖写入(如果目标已存在则删除重建)。
|
||||
# compute=True:阻塞直到整张图算完并落盘。
|
||||
# 由于数据始终是 dask chunk + 流式写出,
|
||||
# 内存峰值 ≈ 单个 chunk 大小,与整张影像大小无关。
|
||||
# ============================================================
|
||||
out = xr.Dataset({"reflectance": result})
|
||||
# 保留原数据集的全局属性 / 坐标信息(CRS、wavelength、...)
|
||||
out.attrs = dict(ds.attrs)
|
||||
out["reflectance"].attrs = dict(reflectance.attrs)
|
||||
out.to_zarr(output_zarr_path, mode="w", compute=True)
|
||||
|
||||
return True
|
||||
135
new/app/core/algorithms/registry.py
Normal file
135
new/app/core/algorithms/registry.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""
|
||||
算法注册表(Registry / Factory)
|
||||
================================
|
||||
|
||||
通过装饰器把「算法名字符串」与「算法实现类」绑定在一起。
|
||||
上层调度层(FastAPI endpoints、BackgroundTasks worker)只需要拿到
|
||||
前端传过来的 method 字符串,就可以自动派发到对应的算法实现,
|
||||
而无需写一长串 if/elif。
|
||||
|
||||
使用示例
|
||||
--------
|
||||
|
||||
from app.core.algorithms import BaseGlintRemover
|
||||
from app.core.algorithms.registry import (
|
||||
register_glint_remover,
|
||||
get_remover,
|
||||
list_removers,
|
||||
)
|
||||
|
||||
@register_glint_remover("kutser")
|
||||
class KutserGlintRemover(BaseGlintRemover):
|
||||
async def process(self, input_zarr_path, output_zarr_path, **kwargs):
|
||||
...
|
||||
|
||||
# 派发
|
||||
Cls = get_remover(method_from_request)
|
||||
remover = Cls()
|
||||
await remover.process(input_zarr_path, output_zarr_path, **kwargs)
|
||||
|
||||
设计要点
|
||||
--------
|
||||
- 注册动作发生在「类定义时」,所以必须在所有算法 import 完之后
|
||||
注册表才完整。可以在 `app/core/algorithms/__init__.py` 中
|
||||
把算法子模块 import 一遍来强制触发注册。
|
||||
- 重复注册同名算法会直接抛错,避免静默覆盖。
|
||||
- name 会同步写回到类的 `name` 属性,便于算法自身查询身份。
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from app.core.algorithms.base import BaseGlintRemover
|
||||
|
||||
|
||||
# 全局注册表:name(str) -> 实现类(type),类未被实例化
|
||||
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}
|
||||
|
||||
|
||||
def register_glint_remover(name: str):
|
||||
"""
|
||||
类装饰器工厂:把传入 name 的算法类注册到全局注册表。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
算法标识,建议小写下划线风格,例如 "kutser"、"goodman"。
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
- name 不是非空字符串
|
||||
- name 已经被其它类占用
|
||||
TypeError
|
||||
- 被装饰的对象不是 BaseGlintRemover 的子类
|
||||
"""
|
||||
|
||||
# ---- 防御性校验:name 必须是合法字符串 ----
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError(
|
||||
f"register_glint_remover 的 name 必须是非空字符串,收到: {name!r}"
|
||||
)
|
||||
|
||||
def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
|
||||
# ---- 防御性校验:被装饰对象必须是 BaseGlintRemover 子类 ----
|
||||
if not isinstance(cls, type) or not issubclass(cls, BaseGlintRemover):
|
||||
raise TypeError(
|
||||
f"@register_glint_remover 只能装饰 BaseGlintRemover 的子类,"
|
||||
f"收到: {cls!r}"
|
||||
)
|
||||
|
||||
# ---- 防御性校验:禁止静默覆盖 ----
|
||||
if name in _REGISTRY:
|
||||
raise ValueError(
|
||||
f"算法名 {name!r} 已被 {_REGISTRY[name].__name__} 占用,"
|
||||
f"请使用其它名字或先调用 unregister_glint_remover() 注销旧实现。"
|
||||
)
|
||||
|
||||
# 同步把 name 写回类属性,便于算法自身和日志输出使用
|
||||
cls.name = name
|
||||
_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_remover(name: str) -> Type[BaseGlintRemover]:
|
||||
"""
|
||||
按算法名字符串取出对应的实现类(未实例化)。
|
||||
|
||||
调用方拿到类后自行 `Cls(...)` 构造实例,再调用 process()。
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
当 name 不在注册表中时抛出,错误信息中附带已注册列表便于排查。
|
||||
"""
|
||||
try:
|
||||
return _REGISTRY[name]
|
||||
except KeyError as exc:
|
||||
known = ", ".join(sorted(_REGISTRY)) or "<空>"
|
||||
raise KeyError(
|
||||
f"未注册的算法名: {name!r}。已注册的算法: {known}"
|
||||
) from exc
|
||||
|
||||
|
||||
def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
|
||||
"""
|
||||
返回当前注册表的浅拷贝。
|
||||
可用于:
|
||||
- 调试日志
|
||||
- 给前端暴露一个 GET /api/algorithms 接口
|
||||
- 单元测试断言
|
||||
"""
|
||||
return dict(_REGISTRY)
|
||||
|
||||
|
||||
def unregister_glint_remover(name: str) -> None:
|
||||
"""
|
||||
注销指定算法。主要给:
|
||||
- 单元测试
|
||||
- 热重载 / 插件卸载场景
|
||||
生产代码一般不需要调用。
|
||||
"""
|
||||
if name not in _REGISTRY:
|
||||
raise KeyError(f"未注册的算法名: {name!r}")
|
||||
del _REGISTRY[name]
|
||||
Reference in New Issue
Block a user