refactor(step4): 剥离 Steps 层 - step1~step3 业务逻辑下沉到独立模块

This commit is contained in:
DXC
2026-05-09 17:30:49 +08:00
parent 605ec86108
commit d0eb458392
5 changed files with 674 additions and 528 deletions

View File

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""业务步骤层模块"""
from src.core.steps.water_mask_step import WaterMaskStep
from src.core.steps.glint_detection_step import GlintDetectionStep
from src.core.steps.glint_removal_step import GlintRemovalStep
__all__ = [
"WaterMaskStep",
"GlintDetectionStep",
"GlintRemovalStep",
]

View File

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
步骤2: 耀斑区域检测
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
"""
import time
from pathlib import Path
from typing import Optional, List, Union
class GlintDetectionStep:
"""耀斑区域检测步骤"""
@staticmethod
def run(
img_path: str,
glint_wave: float = 750.0,
method: str = "otsu",
z_threshold: float = 2.5,
percentile: float = 95.0,
iqr_multiplier: float = 1.5,
window_size: int = 15,
multi_band_waves: Optional[List[float]] = None,
sub_method: str = "zscore",
weights: Optional[List[float]] = None,
max_area: Optional[int] = None,
buffer_size: Optional[int] = None,
water_mask_path: Optional[str] = None,
glint_dir: Union[str, Path] = "./2_glint",
callback: Optional[callable] = None,
) -> str:
"""
执行耀斑区域检测
Args:
img_path: 输入影像文件路径
glint_wave: 用于耀斑检测的波段波长nm
method: 检测方法 ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
z_threshold: Z-score 方法阈值(默认 2.5
percentile: 百分位数阈值(默认 95.0
iqr_multiplier: IQR 倍数(默认 1.5
window_size: 自适应阈值窗口大小(默认 15
multi_band_waves: 多波段方法的波长列表,如 [750, 800, 850]
sub_method: 多波段方法的子方法(默认 'zscore'
weights: 多波段方法的权重列表None 表示等权重)
max_area: 最大连通域面积阈值(像素),超过则过滤
buffer_size: 岸边缓冲区大小(像素),用于去除岸边附近错误掩膜
water_mask_path: 水域掩膜文件路径dat 格式优先)
glint_dir: 工作目录
callback: 回调函数
Returns:
耀斑掩膜文件路径 (.dat)
"""
from src.utils.find_severe_glint_area import find_severe_glint_area
glint_dir = Path(glint_dir)
glint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤2", status, msg)
print("\n" + "=" * 80)
print("步骤2: 找到耀斑区域")
print("=" * 80)
step_start_time = time.time()
# 确定水体掩膜路径
if water_mask_path is not None and Path(water_mask_path).exists():
final_water_mask_path = water_mask_path
else:
final_water_mask_path = None
output_path = str(glint_dir / "severe_glint_area.dat")
# 跳过已存在的文件
if Path(output_path).exists():
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
notify("skipped", f"耀斑掩膜已设置: {output_path}")
return output_path
# 构建检测参数字典
kwargs = {
"method": method,
"z_threshold": z_threshold,
"percentile": percentile,
"iqr_multiplier": iqr_multiplier,
"window_size": window_size,
}
if method == "multi_band":
if multi_band_waves is not None:
kwargs["multi_band_waves"] = multi_band_waves
if sub_method is not None:
kwargs["sub_method"] = sub_method
if weights is not None:
kwargs["weights"] = weights
if max_area is not None:
kwargs["max_area"] = max_area
if buffer_size is not None:
kwargs["buffer_size"] = buffer_size
glint_mask_path = find_severe_glint_area(
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
)
print(f"耀斑掩膜已生成: {glint_mask_path}")
print(f"使用检测方法: {method}")
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
return glint_mask_path

View File

@ -0,0 +1,314 @@
# -*- coding: utf-8 -*-
"""
步骤3: 去除耀斑
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
每种方法都会:
1. 准备水域掩膜(支持 shp 自动转 dat
2. 调用对应的算法类执行处理
3. 复制 hdr 文件到输出影像
"""
import os
import time
from pathlib import Path
from typing import Optional, List, Union, Callable
import numpy as np
class GlintRemovalStep:
"""去除耀斑步骤"""
@staticmethod
def run(
img_path: str,
method: str = "subtract_nir",
start_wave: Optional[float] = None,
end_wave: Optional[float] = None,
json_path: Optional[str] = None,
left_shoulder_wave: Optional[float] = None,
valley_wave: Optional[float] = None,
right_shoulder_wave: Optional[float] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
interpolated_img_path: Optional[str] = None,
interpolate_zeros: bool = False,
interpolation_method: str = "nearest",
enabled: bool = True,
# Kutser 参数
kutser_shp_path: Optional[str] = None,
oxy_band: int = 38,
lower_oxy: int = 36,
upper_oxy: int = 49,
nir_band: int = 47,
# Goodman 参数
nir_lower: int = 25,
nir_upper: int = 37,
goodman_A: float = 0.000019,
goodman_B: float = 0.1,
# Hedley 参数
hedley_shp_path: Optional[str] = None,
hedley_nir_band: int = 47,
# SUGAR 参数
sugar_bounds: Optional[List[tuple]] = None,
sugar_sigma: float = 1.0,
sugar_estimate_background: bool = True,
sugar_glint_mask_method: str = "cdf",
sugar_iter: Optional[int] = 3,
sugar_termination_thresh: float = 20.0,
# 内部工具函数
_get_image_geo_info=None,
_load_image_as_array=None,
_save_bands_as_image=None,
_copy_hdr_info=None,
_prepare_water_mask_for_algorithm=None,
_interpolate_zero_pixels_batch=None,
deglint_dir: Union[str, Path] = "./3_deglint",
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
) -> str:
"""
执行去除耀斑处理
Args:
img_path: 输入影像文件路径
method: 去耀斑方法
...(其余参数同主类 step3_remove_glint
Returns:
去除耀斑后的影像文件路径
"""
from src.core.glint_removal.Kutser import Kutser
from src.core.glint_removal.Goodman import Goodman
from src.core.glint_removal.Hedley import Hedley
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
from src.core.utils.gdal_helper import (
get_image_geo_info as _default_get_geo,
load_image_as_array as _default_load,
save_bands_as_image as _default_save_bands,
copy_hdr_info as _default_copy_hdr,
)
from src.core.utils.mask_converter import (
prepare_water_mask_for_algorithm as _default_prepare,
)
# 使用提供的函数或默认函数
if _get_image_geo_info is None:
_get_image_geo_info = _default_get_geo
if _load_image_as_array is None:
_load_image_as_array = _default_load
if _save_bands_as_image is None:
_save_bands_as_image = _default_save_bands
if _copy_hdr_info is None:
_copy_hdr_info = _default_copy_hdr
if _prepare_water_mask_for_algorithm is None:
_prepare_water_mask_for_algorithm = _default_prepare
deglint_dir = Path(deglint_dir)
deglint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤3", status, msg)
print("\n" + "=" * 80)
print("步骤3: 去除耀斑")
print("=" * 80)
step_start_time = time.time()
# 方法名标准化
raw_method = str(method).lower()
if "kutser" in raw_method:
method = "kutser"
elif "goodman" in raw_method:
method = "goodman"
elif "hedley" in raw_method:
method = "hedley"
elif "sugar" in raw_method:
method = "sugar"
# 如果未启用,直接返回原始影像
if not enabled:
print("已设置跳过去除耀斑enabled=False将直接使用原始影像。")
notify("skipped", "跳过去耀斑,使用原始影像")
return img_path
# ---- 确定水域掩膜 ----
final_water_mask = water_mask
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
# shp 自动替换为 dat
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask).exists():
print(f"检测到输入掩膜为 .shp自动替换为栅格掩膜: {dat_mask}")
final_water_mask = dat_mask
if final_water_mask is None:
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask_default).exists():
final_water_mask = dat_mask_default
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
# ---- 步骤3.1: 0值像素插值 ----
if interpolate_zeros:
print("\n" + "-" * 80)
print("步骤3.1: 对0值像素进行插值")
print("-" * 80)
interp_start_time = time.time()
if _interpolate_zero_pixels_batch is None:
from src.core.algorithms.interpolation.interpolator import (
interpolate_zero_pixels_batch as _interp_batch,
)
_interpolate_zero_pixels_batch = _interp_batch
interp_result, _ = _interpolate_zero_pixels_batch(
img_path=img_path,
interpolation_method=interpolation_method,
output_path=None,
water_mask=final_water_mask,
deglint_dir=str(deglint_dir),
callback_progress=lambda msg: print(f" {msg}"),
)
img_path = interp_result
interp_end_time = time.time()
print(f"插值完成,使用插值后的影像: {img_path}")
# ---- 获取影像信息 ----
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
print(f"影像尺寸: {width} x {height} x {n_bands}")
mask_for_algorithm = _prepare_water_mask_for_algorithm(
final_water_mask, (height, width), geotransform, projection, img_path
)
# ==================== Kutser ====================
if method == "kutser":
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
output_path = str(deglint_dir / "deglint_kutser.bsq")
if Path(output_path).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {output_path}")
notify("skipped", f"去耀斑影像已设置: {output_path}")
return output_path
kutser = Kutser(
img_path,
shp_path=None,
oxy_band=oxy_band,
lower_oxy=lower_oxy,
upper_oxy=upper_oxy,
NIR_band=nir_band,
water_mask=mask_for_algorithm,
output_path=output_path,
)
kutser.get_corrected_bands()
if Path(output_path).exists():
_copy_hdr_info(img_path, output_path)
notify("completed", f"去耀斑影像已生成: {output_path}")
return output_path
raise RuntimeError(f"Kutser算法未生成输出文件: {output_path}")
# ==================== Goodman ====================
elif method == "goodman":
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
output_path = str(deglint_dir / "deglint_goodman.bsq")
if Path(output_path).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {output_path}")
notify("skipped", f"去耀斑影像已设置: {output_path}")
return output_path
goodman = Goodman(
img_path,
NIR_lower=nir_lower,
NIR_upper=nir_upper,
A=goodman_A,
B=goodman_B,
water_mask=mask_for_algorithm,
output_path=output_path,
)
corrected_bands = goodman.get_corrected_bands()
if not Path(output_path).exists():
_save_bands_as_image(corrected_bands, output_path, geotransform, projection)
_copy_hdr_info(img_path, output_path)
else:
_copy_hdr_info(img_path, output_path)
del corrected_bands
notify("completed", f"去耀斑影像已生成: {output_path}")
return output_path
# ==================== Hedley ====================
elif method == "hedley":
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
output_path = str(deglint_dir / "deglint_hedley.bsq")
if Path(output_path).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {output_path}")
notify("skipped", f"去耀斑影像已设置: {output_path}")
return output_path
hedley = Hedley(
img_path,
shp_path=None,
NIR_band=hedley_nir_band,
water_mask=mask_for_algorithm,
output_path=output_path,
)
hedley.get_corrected_bands()
if Path(output_path).exists():
_copy_hdr_info(img_path, output_path)
notify("completed", f"去耀斑影像已生成: {output_path}")
return output_path
raise RuntimeError(f"Hedley算法未生成输出文件: {output_path}")
# ==================== SUGAR ====================
elif method == "sugar":
# 方法名标准化
glint_method_raw = str(sugar_glint_mask_method).lower()
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
sugar_glint_mask_method_fixed = "cdf"
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
sugar_glint_mask_method_fixed = "otsu"
else:
sugar_glint_mask_method_fixed = "cdf"
print(
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
)
output_path = str(deglint_dir / "deglint_sugar.bsq")
if Path(output_path).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {output_path}")
notify("skipped", f"去耀斑影像已设置: {output_path}")
return output_path
if sugar_bounds is None:
sugar_bounds = [(1, 2)]
correction_iterative(
img_path,
iter=sugar_iter,
bounds=sugar_bounds,
estimate_background=sugar_estimate_background,
glint_mask_method=sugar_glint_mask_method_fixed,
termination_thresh=sugar_termination_thresh,
water_mask=mask_for_algorithm,
output_path=output_path,
)
if Path(output_path).exists():
_copy_hdr_info(img_path, output_path)
notify("completed", f"去耀斑影像已生成: {output_path}")
return output_path
raise RuntimeError(f"SUGAR算法未生成输出文件: {output_path}")
else:
raise ValueError(
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
)

View File

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
步骤1: 水域掩膜生成
支持三种方式:
1. 基于 shp 文件栅格化
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
3. 基于 NDWI 从影像自动生成水体掩膜
"""
import os
import time
from pathlib import Path
from typing import Optional, List, Callable, Union
import numpy as np
class WaterMaskStep:
"""水域掩膜生成步骤"""
@staticmethod
def run(
mask_path: Optional[str] = None,
img_path: Optional[str] = None,
ndwi_threshold: float = 0.4,
use_ndwi: bool = False,
generate_png: bool = True,
output_path: Optional[str] = None,
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
) -> str:
"""
执行水域掩膜生成
Args:
mask_path: 水体掩膜文件路径,支持 .shp需 img_path或 .dat/.tif直接使用
img_path: 输入影像文件路径(当 mask_path 为 shp 或 use_ndwi=True 时必须提供)
ndwi_threshold: NDWI 阈值use_ndwi=True 时使用)
use_ndwi: 是否使用 NDWI 方法从影像生成水体掩膜
generate_png: 是否生成 PNG 预览图(默认 True
output_path: 指定输出掩膜文件的保存路径(可选)
water_mask_dir: 工作目录
callback: 回调函数,签名为 callback(step, status, message)
Returns:
dat 格式的水域掩膜文件路径
"""
from src.utils.extract_water_area import rasterize_shp, ndwi
from src.core.utils.preview_generator import (
generate_image_preview,
generate_water_mask_overlay,
)
water_mask_dir = Path(water_mask_dir)
water_mask_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤1", status, msg)
print("\n" + "=" * 80)
print("步骤1: 生成或设置水域mask")
print("=" * 80)
step_start_time = time.time()
# 生成影像预览图
if generate_png and img_path is not None and Path(img_path).exists():
preview_path = str(water_mask_dir / "hsi_preview.png")
generate_image_preview(
img_path=img_path,
output_path=preview_path,
title="影像预览: RGB波段(基于波长)"
)
# ---- NDWI 方法 ----
if use_ndwi:
if img_path is None:
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
if not Path(img_path).exists():
raise ValueError(f"影像文件不存在: {img_path}")
print(f"使用NDWI方法从影像生成水体掩膜阈值={ndwi_threshold}...")
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
if Path(ndwi_output_path).exists():
print(f"检测到已存在的NDWI掩膜文件直接使用: {ndwi_output_path}")
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
return ndwi_output_path
ndwi(img_path, ndwi_threshold, ndwi_output_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
)
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
return ndwi_output_path
# ---- 必须提供 mask_path ----
if mask_path is None:
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
if not Path(mask_path).exists():
raise ValueError(f"文件不存在: {mask_path}")
file_ext = Path(mask_path).suffix.lower()
# ---- SHP 栅格化 ----
if file_ext == ".shp":
if img_path is None:
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
print(f"检测到shp格式的水体掩膜正在转换为dat格式...")
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
if Path(shp_output_path).exists():
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
if not overlay_path.exists():
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
return shp_output_path
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
rasterize_shp(safe_mask_path, shp_output_path, img_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
return shp_output_path
# ---- 栅格格式直接使用 ----
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
if generate_png and img_path is not None and Path(img_path).exists():
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
notify("completed", f"水域掩膜已设置: {mask_path}")
return mask_path