refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

View File

@ -0,0 +1,40 @@
"""
去耀斑算法包
============
通过「注册表 + 策略模式」组织不同的去耀斑算法。
所有具体算法都应继承 BaseGlintRemover并使用 @register_glint_remover
装饰器把算法名和实现类绑定。
外部调用约定
------------
1. 所有算法子模块必须在本 __init__ 中显式 import
这样装饰器才会被执行、注册表才会被填满。
2. 上层endpoints、worker只允许
from app.core.algorithms import get_remover
来获取算法类,不要直接 import 具体实现类,
保持调度层与具体算法的解耦。
"""
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import (
get_remover,
list_removers,
register_glint_remover,
unregister_glint_remover,
)
# ---- 算法子模块 import 区 ----
# 新增算法时,在这里加一行 import确保装饰器被执行。
from app.core.algorithms import goodman # Goodman
from app.core.algorithms import kutser # Kutser
# from app.core.algorithms import hedley # Hedley
# from app.core.algorithms import sugar # SUGAR
__all__ = [
"BaseGlintRemover",
"register_glint_remover",
"get_remover",
"list_removers",
"unregister_glint_remover",
]

View File

@ -0,0 +1,85 @@
"""
去耀斑算法抽象基类
==================
设计目标(策略模式 Strategy Pattern
------------------------------------
本模块定义了所有去耀斑算法必须遵守的标准接口。
未来的 Kutser、Goodman、Hedley、SUGAR 等算法都将继承本基类,
并实现统一的 process() 方法。
输入输出规范
------------
所有算法的输入与输出均统一为 **Zarr 文件路径**(字符串),
而不是内存中的 numpy ndarray。这样做的核心收益是
1. **解耦数据存储与内存计算**
算法只关心「从哪个 zarr 读、写到哪个 zarr」
至于数据最初来自 GeoTIFF / HDF5 / NetCDF / 内存数组,
都由 IO 层负责归一化转为 zarr。
2. **支持 Out-of-Core 计算**
影像往往超过内存上限zarr 分块chunk天然支持按块读取
算法实现可以借助 dask / xarray 进行流式计算。
3. **可缓存、可复用**
中间产物落盘后,下游算法(大气校正、辐射定标)能直接消费,
避免重复 IO。
4. **易于并行与分布式**
任务调度层只需把两个路径扔给 worker无需关心数据细节。
约定
----
- 子类应实现 process(),完成「读 -> 计算 -> 写」的完整流程。
- process() 返回 True 表示成功False 表示失败。
- 失败时建议抛出异常而非仅返回 False便于上层 BackgroundTasks 捕获并写入 error 字段。
"""
from abc import ABC, abstractmethod
from typing import Any
class BaseGlintRemover(ABC):
"""
去耀斑算法抽象基类。
所有具体算法Kutser / Goodman / Hedley / SUGAR …)必须继承本类并实现 process()。
子类可在 __init__ 中接收自己的超参数(如参考波段、阈值等),
真正的输入输出数据则由 process() 的两个 zarr 路径参数指定。
"""
# 子类可覆盖的算法名称标识,用于调度层按 method 名字查找
name: str = "base"
@abstractmethod
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
"""
执行去耀斑处理。
Parameters
----------
input_zarr_path : str
输入高光谱影像的 zarr 存储路径。
数据已由 IO 层完成格式归一化(波段、坐标系、空间维度均已对齐)。
output_zarr_path : str
处理结果(去耀斑后影像)的 zarr 存储路径。
子类需自行创建该 zarr 存储并写入结果。
**kwargs : Any
算法的可选超参数,例如:
- reference_band: 参考近红外波段索引
- chunk_size: 计算分块大小
- 其它算法特定参数
Returns
-------
bool
True 表示处理成功False 表示失败。
建议在出错时直接 raise由调用方统一记录到任务状态。
"""
raise NotImplementedError
def __repr__(self) -> str: # pragma: no cover - 调试辅助
return f"<{self.__class__.__name__} name={self.name!r}>"

View File

@ -0,0 +1,123 @@
"""
app/core/algorithms/goodman.py
===============================
Goodman et al. 2008 去耀斑算法的 xarray + dask 流式实现。
算法公式
--------
R_corrected = R_raw - R_750 + A + B * (R_640 - R_750)
其中:
R_raw -- 原始反射率 (y, x, band)
R_750 -- λ=750 nm 处的反射率(红外参考波段, 远离水汽吸收)
R_640 -- λ=640 nm 处的反射率(可见光差异波段)
A, B -- 经验回归参数(用户可通过 params 传入, 默认全 0
后处理
------
- 负值截断为 0Clamp to 0
- 仅在水域掩膜 (water_mask) 内生效, 水外置 0
维度约定
--------
reflectance: (y, x, band), band 坐标通常为 wavelength (nm)
water_mask : (y, x), 布尔类型, True = 水域
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 默认参数
# ---------------------------------------------------------------------------
# 与原始 Goodman 2008 论文符号保持一致, 方便用户交叉对照。
# A、B 通常通过对纯净深水区做 (R_corr - R_raw) ~ (R_640 - R_750) 回归得到;
# 在缺乏先验知识时, 退化为 A=0, B=0 即等价于 R_corrected = clip(R_raw - R_750, 0)。
# ---------------------------------------------------------------------------
DEFAULT_BAND_REF: float = 750.0 # λ_750 nm, 红外参考波段
DEFAULT_BAND_DIFF: float = 640.0 # λ_640 nm, 可见光差异波段
DEFAULT_A: float = 0.0 # 公式中的常数偏移项
DEFAULT_B: float = 0.0 # 公式中的斜率项
@register_glint_remover("goodman")
class GoodmanGlintRemover(BaseGlintRemover):
"""Goodman et al. 2008 去耀斑算法"""
name = "goodman"
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
# 1. 解析超参数(带默认值, 方便用户按需覆盖)
band_ref: float = kwargs.get("band_ref", DEFAULT_BAND_REF)
band_diff: float = kwargs.get("band_diff", DEFAULT_BAND_DIFF)
A: float = kwargs.get("A", DEFAULT_A)
B: float = kwargs.get("B", DEFAULT_B)
# 2. 把同步的 xarray/dask 计算丢到工作线程,
# 避免阻塞 FastAPI 的事件循环
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
band_ref,
band_diff,
A,
B,
)
@staticmethod
def _process_sync(
input_zarr_path: str,
output_zarr_path: str,
band_ref: float,
band_diff: float,
A: float,
B: float,
) -> bool:
# 1. 以 zarr 路径打开dask-backed, 不物化到内存)
# chunks="auto" 让 dask 根据每条坐标轴的大小自动决定分块
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance = ds["reflectance"] # (y, x, band)
# 2. 用 sel + method='nearest' 提取两个关键波段
# 返回形状 (y, x), 后续与 (y, x, band) 算术时会自动广播
R_750 = reflectance.sel(band=band_ref, method="nearest")
R_640 = reflectance.sel(band=band_diff, method="nearest")
# 3. Goodman 公式: xarray 沿 band 维度自动广播
# R_corr = R_raw - R_750 + A + B * (R_640 - R_750)
result = reflectance - R_750 + A + B * (R_640 - R_750)
# 4. 负值截断为 0clip(min=0) 优于 where(>0, 0, _)
# 不构造布尔中间数组, 底层走 dask 矢量化 clip 路径)
result = result.clip(min=0)
# 5. 仅在水域内生效(水外强制为 0
# 优先从 zarr 内部读 water_mask 变量, 缺失则视为全图水域
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
result = result.where(water_mask, 0)
# 6. 构造输出 Dataset, 保留元信息(波段坐标/属性等)
out = xr.Dataset({"reflectance": result})
if ds.attrs:
out.attrs = dict(ds.attrs)
if reflectance.attrs:
out["reflectance"].attrs = dict(reflectance.attrs)
# 7. 流式写出Out-of-Core不一次性物化大数组,
# dask 会按 chunk 边算边写, 内存峰值 ≈ 单个 chunk 大小
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,211 @@
"""
Kutser 去耀斑算法xarray + dask 重构版)
========================================
旧版痛点
--------
原始 Kutser 实现(参考 Kutser et al., 2013通常写成像这样
R_corr = np.zeros_like(R_raw)
for b in range(n_bands):
for y in range(H):
for x in range(W):
if water_mask[y, x]:
R_corr[y, x, b] = (
R_raw[y, x, b] - G_list[b] * D_norm[y, x]
)
with rasterio.open(..., 'w') as dst:
dst.write(R_corr)
问题:
1. 三重 Python 循环,每次只做一个浮点运算,解释器开销巨大;
2. 一次性把整张图 R_raw 读进内存,大影像直接 OOM
3. rasterio 写出要求 numpy 连续数组,进一步放大内存。
本文件用 xarray + dask 重写:
- 用 DataArray 维度广播,三重循环 → 一行表达式;
- 用 dask chunk 保持数据常驻磁盘、流式计算;
- 用 to_zarr 边算边写,输出格式与算法层彻底解耦。
"""
import asyncio
from typing import Any
import xarray as xr
from app.core.algorithms.base import BaseGlintRemover
from app.core.algorithms.registry import register_glint_remover
# ---------------------------------------------------------------------------
# 算法实现
# ---------------------------------------------------------------------------
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
"""
Kutser 近红外扣除法去耀斑。
数学公式(与旧版完全等价)
-------------------------
1) 水汽吸收深度 D每像素
D = (R(λ_lower) + R(λ_upper)) / 2 - R(λ_oxy)
2) 全局归一化因子 D_max
D_max = max(D) over 水域
归一化:
D_norm = D / D_max
3) 每波段水域范围:
G_list[b] = max(R[:, :, b] over 水域) - min(R[:, :, b] over 水域)
4) 校正公式(每像素、每波段):
R_corr(λ_b) = R_raw(λ_b) - G_list[b] * D_norm
"""
# Kutser 2013 论文里使用的参考波段nm
# λ_lower = 773, λ_oxy = 845, λ_upper = 893
# 允许通过 kwargs 覆盖,便于适配 MERIS / OLCI / Landsat 等不同传感器。
DEFAULT_BAND_LOWER: float = 773.0
DEFAULT_BAND_OXY: float = 845.0
DEFAULT_BAND_UPPER: float = 893.0
# --------------------------------------------------------------
# 公开异步入口
# --------------------------------------------------------------
# xarray / dask 的算子本身是同步阻塞的。在 async 函数中,
# 用 asyncio.to_thread 把同步体丢到默认线程池执行,
# 避免阻塞 FastAPI 的事件循环。
# --------------------------------------------------------------
async def process(
self,
input_zarr_path: str,
output_zarr_path: str,
**kwargs: Any,
) -> bool:
return await asyncio.to_thread(
self._process_sync,
input_zarr_path,
output_zarr_path,
kwargs,
)
# --------------------------------------------------------------
# 同步核心实现
# --------------------------------------------------------------
def _process_sync(
self,
input_zarr_path: str,
output_zarr_path: str,
kwargs: dict,
) -> bool:
# ============================================================
# 步骤 0打开 zarr建立 dask 计算图
# ============================================================
# chunks="auto":让 dask 根据 zarr 的存储分块自动选择内存上限,
# 数据不会一次性全部 materialize 进 RAM。
# ============================================================
ds = xr.open_zarr(input_zarr_path, chunks="auto")
reflectance: xr.DataArray = ds["reflectance"] # 维度约定:(y, x, band)
# 维度顺序约定(也可根据 ds.dims 自动适配):
assert "y" in reflectance.dims and "x" in reflectance.dims and "band" in reflectance.dims, (
f"reflectance 必须包含 y/x/band 三个维度,实际为: {reflectance.dims}"
)
# ============================================================
# 步骤 1取出 3 个参考波段对应的二维 (y, x) 切片
# ============================================================
# 假设 band 维度的坐标是 wavelengthnm
# 用 sel(..., method="nearest") 自动匹配最接近的波段。
# ============================================================
wl_lower = float(kwargs.get("band_lower", self.DEFAULT_BAND_LOWER))
wl_oxy = float(kwargs.get("band_oxy", self.DEFAULT_BAND_OXY))
wl_upper = float(kwargs.get("band_upper", self.DEFAULT_BAND_UPPER))
R_lower = reflectance.sel(band=wl_lower, method="nearest") # (y, x)
R_upper = reflectance.sel(band=wl_upper, method="nearest") # (y, x)
R_oxy = reflectance.sel(band=wl_oxy, method="nearest") # (y, x)
# ============================================================
# 步骤 2水域掩膜
# ============================================================
# 优先从 zarr 内部读取 water_mask 变量;
# 如果不存在,则假定整幅图都是水域(开发期兜底)。
# ============================================================
if "water_mask" in ds:
water_mask = ds["water_mask"].astype(bool)
else:
water_mask = xr.ones_like(
reflectance.isel(band=0), dtype=bool
)
# ============================================================
# 步骤 3水汽吸收深度 D每像素形状 (y, x)
# ============================================================
# 旧版D[y, x] = (R_lower[y, x] + R_upper[y, x]) / 2 - R_oxy[y, x]
# 新版一行表达式dask 自动构建 lazy 计算图。
# ============================================================
D = (R_lower + R_upper) / 2.0 - R_oxy # (y, x)dtype 与 reflectance 一致
# ============================================================
# 步骤 4全局归一化因子 D_max标量0-dim DataArray
# ============================================================
# 关键:先 .where(water_mask) 把非水域置 NaN
# 再 .max() 跨 (x, y) 聚合,自动规约到 0 维。
# dask 此时仍然没有真正计算,等到 to_zarr 时再触发。
# ============================================================
D_max = D.where(water_mask).max() # scalar
# 容错:如果水域为空导致 D_max 为 NaN用极小值兜底避免除零
D_max = D_max.fillna(1e-6)
# ============================================================
# 步骤 5归一化 D_norm形状 (y, x)
# ============================================================
D_norm = D / D_max # 标量除以 (y, x) 数组 → 自动广播
# ============================================================
# 步骤 6每波段水域范围 G_list形状 (band,)
# ============================================================
# 旧版三重循环内部还要做一次 min/max 聚合。
# xarray 版本:把 (y, x) 一起 reduce只保留 band 维度。
# ============================================================
R_water = reflectance.where(water_mask) # (y, x, band),非水域 NaN
G_min = R_water.min(dim=["x", "y"]) # (band,)
G_max = R_water.max(dim=["x", "y"]) # (band,)
G_list = (G_max - G_min).fillna(0.0) # (band,),容错
# ============================================================
# 步骤 7校正公式最关键的一行演示 xarray 广播)
# ============================================================
# 旧版需要:
# for b in bands:
# for y in range(H):
# for x in range(W):
# R_corr[y,x,b] = R_raw[y,x,b] - G_list[b] * D_norm[y,x]
#
# xarray 维度对齐规则:
# R_raw : (y, x, band)
# G_list: (band,) → 缺失 y, x 自动扩展
# D_norm: (y, x) → 缺失 band 自动扩展
# 乘法结果: (y, x, band) → 减法对齐
# 一行表达式完成「三重 for 循环 + 标量索引」的语义。
# ============================================================
corrected = reflectance - G_list * D_norm # (y, x, band)
# ============================================================
# 步骤 8水域掩膜过滤非水域置 NaN
# ============================================================
result = corrected.where(water_mask)
# ============================================================
# 步骤 9持久化为 zarr
# ============================================================
# mode="w":覆盖写入(如果目标已存在则删除重建)。
# compute=True阻塞直到整张图算完并落盘。
# 由于数据始终是 dask chunk + 流式写出,
# 内存峰值 ≈ 单个 chunk 大小,与整张影像大小无关。
# ============================================================
out = xr.Dataset({"reflectance": result})
# 保留原数据集的全局属性 / 坐标信息CRS、wavelength、...
out.attrs = dict(ds.attrs)
out["reflectance"].attrs = dict(reflectance.attrs)
out.to_zarr(output_zarr_path, mode="w", compute=True)
return True

View File

@ -0,0 +1,135 @@
"""
算法注册表Registry / Factory
================================
通过装饰器把「算法名字符串」与「算法实现类」绑定在一起。
上层调度层FastAPI endpoints、BackgroundTasks worker只需要拿到
前端传过来的 method 字符串,就可以自动派发到对应的算法实现,
而无需写一长串 if/elif。
使用示例
--------
from app.core.algorithms import BaseGlintRemover
from app.core.algorithms.registry import (
register_glint_remover,
get_remover,
list_removers,
)
@register_glint_remover("kutser")
class KutserGlintRemover(BaseGlintRemover):
async def process(self, input_zarr_path, output_zarr_path, **kwargs):
...
# 派发
Cls = get_remover(method_from_request)
remover = Cls()
await remover.process(input_zarr_path, output_zarr_path, **kwargs)
设计要点
--------
- 注册动作发生在「类定义时」,所以必须在所有算法 import 完之后
注册表才完整。可以在 `app/core/algorithms/__init__.py` 中
把算法子模块 import 一遍来强制触发注册。
- 重复注册同名算法会直接抛错,避免静默覆盖。
- name 会同步写回到类的 `name` 属性,便于算法自身查询身份。
"""
from typing import Dict, Type
from app.core.algorithms.base import BaseGlintRemover
# 全局注册表name(str) -> 实现类(type),类未被实例化
_REGISTRY: Dict[str, Type[BaseGlintRemover]] = {}
def register_glint_remover(name: str):
"""
类装饰器工厂:把传入 name 的算法类注册到全局注册表。
Parameters
----------
name : str
算法标识,建议小写下划线风格,例如 "kutser""goodman"
Raises
------
ValueError
- name 不是非空字符串
- name 已经被其它类占用
TypeError
- 被装饰的对象不是 BaseGlintRemover 的子类
"""
# ---- 防御性校验name 必须是合法字符串 ----
if not isinstance(name, str) or not name.strip():
raise ValueError(
f"register_glint_remover 的 name 必须是非空字符串,收到: {name!r}"
)
def decorator(cls: Type[BaseGlintRemover]) -> Type[BaseGlintRemover]:
# ---- 防御性校验:被装饰对象必须是 BaseGlintRemover 子类 ----
if not isinstance(cls, type) or not issubclass(cls, BaseGlintRemover):
raise TypeError(
f"@register_glint_remover 只能装饰 BaseGlintRemover 的子类,"
f"收到: {cls!r}"
)
# ---- 防御性校验:禁止静默覆盖 ----
if name in _REGISTRY:
raise ValueError(
f"算法名 {name!r} 已被 {_REGISTRY[name].__name__} 占用,"
f"请使用其它名字或先调用 unregister_glint_remover() 注销旧实现。"
)
# 同步把 name 写回类属性,便于算法自身和日志输出使用
cls.name = name
_REGISTRY[name] = cls
return cls
return decorator
def get_remover(name: str) -> Type[BaseGlintRemover]:
"""
按算法名字符串取出对应的实现类(未实例化)。
调用方拿到类后自行 `Cls(...)` 构造实例,再调用 process()。
Raises
------
KeyError
当 name 不在注册表中时抛出,错误信息中附带已注册列表便于排查。
"""
try:
return _REGISTRY[name]
except KeyError as exc:
known = ", ".join(sorted(_REGISTRY)) or "<空>"
raise KeyError(
f"未注册的算法名: {name!r}。已注册的算法: {known}"
) from exc
def list_removers() -> Dict[str, Type[BaseGlintRemover]]:
"""
返回当前注册表的浅拷贝。
可用于:
- 调试日志
- 给前端暴露一个 GET /api/algorithms 接口
- 单元测试断言
"""
return dict(_REGISTRY)
def unregister_glint_remover(name: str) -> None:
"""
注销指定算法。主要给:
- 单元测试
- 热重载 / 插件卸载场景
生产代码一般不需要调用。
"""
if name not in _REGISTRY:
raise KeyError(f"未注册的算法名: {name!r}")
del _REGISTRY[name]