refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块
This commit is contained in:
36
src/core/algorithms/__init__.py
Normal file
36
src/core/algorithms/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
算法层模块
|
||||
包含插值算法和耀斑检测算法等核心数学计算
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 插值
|
||||
'interpolate_pixels',
|
||||
'interpolate_zero_pixels_batch',
|
||||
# 耀斑检测
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
31
src/core/algorithms/glint_detection/__init__.py
Normal file
31
src/core/algorithms/glint_detection/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
包含各种耀斑检测的核心数学计算函数
|
||||
"""
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
595
src/core/algorithms/glint_detection/detectors.py
Normal file
595
src/core/algorithms/glint_detection/detectors.py
Normal file
@ -0,0 +1,595 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
|
||||
包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。
|
||||
支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
|
||||
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
|
||||
def timeit(func):
    """Decorator that reports how long each call to ``func`` takes.

    The wall-clock duration, measured with ``time.time()``, is printed to
    stdout after ``func`` returns. The wrapped function's return value is
    passed through unchanged; nothing is printed if ``func`` raises.
    """
    import time

    @wraps(func)
    def timed_call(*args, **kwargs):
        started_at = time.time()
        outcome = func(*args, **kwargs)
        elapsed = time.time() - started_at
        print(f"[{func.__name__}] 耗时: {elapsed:.2f}s")
        return outcome

    return timed_call
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数拉伸
|
||||
# =============================================================================
|
||||
|
||||
def percentile_stretch(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    lower_percentile: float = 2,
    upper_percentile: float = 98,
    output_range: Tuple[int, int] = (0, 255)
) -> np.ndarray:
    """Linearly rescale an image between two percentiles of its water pixels.

    The percentiles are computed only from pixels that lie inside the mask
    (``data_water_mask > 0``), are strictly positive and are finite. The
    whole image is then clipped to ``[p_lower, p_upper]`` and mapped linearly
    onto ``output_range``.

    Args:
        img: Input image array.
        data_water_mask: Mask array; values > 0 mark pixels used for the
            statistics.
        lower_percentile: Percentile used as the lower clip bound.
        upper_percentile: Percentile used as the upper clip bound.
        output_range: ``(min, max)`` of the rescaled output.

    Returns:
        The rescaled image as ``int32`` (fractions are truncated by the
        cast). NaN pixels are mapped to ``output_range[0]``. If no usable
        pixel exists, the input itself is returned as ``int32`` with
        non-finite values replaced by ``output_range[0]``.
    """
    out_min, out_max = output_range[0], output_range[1]
    valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]

    if valid_pixels.size == 0:
        # Casting NaN/inf to int32 is undefined, so neutralise them first.
        return np.where(np.isfinite(img), img, out_min).astype(np.int32)

    p_lower = np.percentile(valid_pixels, lower_percentile)
    p_upper = np.percentile(valid_pixels, upper_percentile)

    if p_lower >= p_upper:
        # Degenerate range: widen to the 1st/99th percentiles, then to the
        # full min/max of the valid pixels.
        p_lower = np.percentile(valid_pixels, 1)
        p_upper = np.percentile(valid_pixels, 99)
        if p_lower >= p_upper:
            p_upper = valid_pixels.max()
            p_lower = valid_pixels.min()

    if p_upper > p_lower:
        img_clipped = np.clip(img, p_lower, p_upper)
        img_stretched = (
            (img_clipped - p_lower) / (p_upper - p_lower) * (out_max - out_min)
            + out_min
        )
        # np.clip propagates NaN; without this the int32 cast below would
        # turn NaN pixels into arbitrary integers.
        img_stretched = np.where(np.isnan(img_stretched), out_min, img_stretched)
    else:
        # All valid pixels share one value: nothing to stretch.
        img_stretched = np.full_like(img, out_min, dtype=np.float32)

    return img_stretched.astype(np.int32)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Otsu阈值分割
|
||||
# =============================================================================
|
||||
|
||||
def otsu_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    ignore_value: int = 0,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Binarise an integer image with Otsu's between-class-variance criterion.

    The histogram is built only from pixels that are inside the mask
    (``data_water_mask != 0``), non-negative and different from
    ``ignore_value``. The threshold is the first grey level that maximises
    the between-class variance; pixels strictly above it are marked as
    foreground.

    Args:
        img: Integer-valued image (e.g. the output of a percentile stretch).
        data_water_mask: Mask array; pixels where it equals 0 are excluded
            from the histogram and set to ``background`` in the result.
        ignore_value: Pixel value excluded from the histogram.
        foreground: Value written where ``img`` exceeds the threshold.
        background: Value written where the mask equals 0.

    Returns:
        ``int32`` array shaped like ``img``. An all-zero array is returned
        when no pixel exceeds ``ignore_value`` or no pixel is usable.

    Raises:
        IndexError: If usable pixels exist but ``img`` is not an integer
            array (grey levels are used as histogram bin indices).
    """
    above_ignore = img[img > ignore_value]
    if above_ignore.size == 0:
        # Previously np.max() raised on this empty selection.
        return np.zeros_like(img, dtype=np.int32)

    max_value = int(above_ignore.max()) + 1
    if max_value < 2:
        max_value = 256

    usable = (img != ignore_value) & (img >= 0) & (data_water_mask != 0)
    total_valid = int(np.count_nonzero(usable))
    if total_valid <= 0:
        return np.zeros_like(img, dtype=np.int32)

    if not np.issubdtype(img.dtype, np.integer):
        raise IndexError("otsu_threshold requires an integer-typed image")

    # Exact integer histogram; cumulative sums replace the nested Python loops.
    counts = np.bincount(
        img[usable].astype(np.intp), minlength=max_value
    ).astype(np.float64)
    levels = np.arange(counts.size, dtype=np.float64)

    cum_count = np.cumsum(counts)
    cum_moment = np.cumsum(levels * counts)

    weight_a = cum_count / total_valid
    weight_b = (total_valid - cum_count) / total_valid
    moment_a = cum_moment / total_valid
    moment_b = (cum_moment[-1] - cum_moment) / total_valid
    global_mean = moment_a + moment_b

    # Empty classes get a tiny weight to avoid division by zero.
    weight_a = np.where(weight_a == 0, 1e-10, weight_a)
    weight_b = np.where(weight_b == 0, 1e-10, weight_b)

    mean_a = moment_a / weight_a
    mean_b = moment_b / weight_b
    between_class = (
        weight_a * (mean_a - global_mean) ** 2
        + weight_b * (mean_b - global_mean) ** 2
    )

    # argmax returns the first maximum, matching a strict ">" scan from 0.
    threshold = int(np.argmax(between_class))

    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[img > threshold] = foreground
    det_img[data_water_mask == 0] = background

    return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z-score阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def zscore_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    z_threshold: float = 2.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Flag pixels whose z-score exceeds ``z_threshold``.

    Mean and standard deviation are taken from pixels that are inside the
    mask (``data_water_mask > 0``), strictly positive and finite. Z-scores
    are then computed (as float32) for every finite pixel inside the mask.

    Args:
        img: Input image array.
        data_water_mask: Mask array; pixels where it equals 0 are set to
            ``background`` in the result.
        z_threshold: Number of standard deviations above the mean a pixel
            must exceed to be flagged.
        foreground: Value written for flagged pixels.
        background: Value written where the mask equals 0.

    Returns:
        ``int32`` array shaped like ``img``; all zeros when there are no
        usable pixels or their standard deviation is 0.
    """
    inside_water = data_water_mask > 0
    is_finite = np.isfinite(img)
    samples = img[inside_water & (img > 0) & is_finite]

    if samples.size == 0:
        return np.zeros_like(img, dtype=np.int32)

    centre = np.mean(samples)
    spread = np.std(samples)
    if spread == 0:
        return np.zeros_like(img, dtype=np.int32)

    scored = inside_water & is_finite
    scores = np.zeros_like(img, dtype=np.float32)
    scores[scored] = (img[scored] - centre) / spread

    result = np.zeros_like(img, dtype=np.int32)
    result[scores > z_threshold] = foreground
    result[data_water_mask == 0] = background
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def percentile_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    percentile: float = 95,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Flag pixels brighter than a percentile of the usable water pixels.

    The cut-off is the given percentile of pixels that are inside the mask
    (``data_water_mask > 0``), strictly positive and finite.

    Args:
        img: Input image array.
        data_water_mask: Mask array; pixels where it equals 0 are set to
            ``background`` in the result.
        percentile: Percentile (0-100) used as the cut-off.
        foreground: Value written where ``img`` exceeds the cut-off.
        background: Value written where the mask equals 0.

    Returns:
        ``int32`` array shaped like ``img``; all zeros when there are no
        usable pixels.
    """
    usable = (data_water_mask > 0) & (img > 0) & np.isfinite(img)
    result = np.zeros_like(img, dtype=np.int32)
    if not usable.any():
        return result

    cutoff = np.percentile(img[usable], percentile)
    result[img > cutoff] = foreground
    result[data_water_mask == 0] = background
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IQR异常值检测
|
||||
# =============================================================================
|
||||
|
||||
def iqr_outlier_detection(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    iqr_multiplier: float = 1.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Flag pixels above the upper interquartile-range fence.

    Quartiles are computed from pixels that are inside the mask
    (``data_water_mask > 0``), strictly positive and finite. A pixel is
    flagged when it exceeds ``Q3 + iqr_multiplier * (Q3 - Q1)``.

    Args:
        img: Input image array.
        data_water_mask: Mask array; pixels where it equals 0 are set to
            ``background`` in the result.
        iqr_multiplier: Multiple of the IQR added to Q3 to form the fence.
        foreground: Value written for flagged pixels.
        background: Value written where the mask equals 0.

    Returns:
        ``int32`` array shaped like ``img``; all zeros when there are no
        usable pixels.
    """
    usable = (data_water_mask > 0) & (img > 0) & np.isfinite(img)
    samples = img[usable]
    result = np.zeros_like(img, dtype=np.int32)
    if samples.size == 0:
        return result

    first_quartile = np.percentile(samples, 25)
    third_quartile = np.percentile(samples, 75)
    fence = third_quartile + iqr_multiplier * (third_quartile - first_quartile)

    result[img > fence] = foreground
    result[data_water_mask == 0] = background
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 自适应阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def adaptive_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    window_size: int = 15,
    percentile: float = 90,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Flag pixels brighter than a percentile of their local neighbourhood.

    For every pixel that is inside the mask and at least half a window away
    from the image border, the local cut-off is the given percentile of the
    masked (``data_water_mask > 0``) pixels in the surrounding square window.
    Pixels closer to the border than half a window are never evaluated and
    stay 0.

    Args:
        img: 2-D input image array.
        data_water_mask: Mask array; pixels where it equals 0 are skipped
            and set to ``background`` in the result.
        window_size: Side length of the square window; even values are
            incremented by one to make them odd.
        percentile: Percentile (0-100) of the window used as local cut-off.
        foreground: Value written for flagged pixels.
        background: Value written where the mask equals 0.

    Returns:
        ``int32`` array shaped like ``img``.
    """
    n_rows, n_cols = img.shape

    if window_size % 2 == 0:
        window_size += 1
    radius = window_size // 2

    result = np.zeros_like(img, dtype=np.int32)

    for r in range(radius, n_rows - radius):
        row_span = slice(r - radius, r + radius + 1)
        for c in range(radius, n_cols - radius):
            if data_water_mask[r, c] == 0:
                continue

            col_span = slice(c - radius, c + radius + 1)
            neighbourhood = img[row_span, col_span]
            water_here = data_water_mask[row_span, col_span] > 0
            samples = neighbourhood[water_here]

            if len(samples) > 0 and img[r, c] > np.percentile(samples, percentile):
                result[r, c] = foreground

    result[data_water_mask == 0] = background
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 多波段融合耀斑检测
|
||||
# =============================================================================
|
||||
|
||||
def multi_band_glint_detection(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    glint_waves: List[float],
    weights: Optional[List[float]] = None,
    method: str = 'zscore',
    z_threshold: float = 2.5,
    percentile: float = 95,
    sub_band_arrays: Optional[List[np.ndarray]] = None
) -> np.ndarray:
    """Detect glint on a weighted fusion of several bands.

    When ``sub_band_arrays`` is given and has one array per entry of
    ``glint_waves``, the arrays are combined into a float32 weighted sum;
    otherwise ``nir_band`` alone (cast to float32) is used. The fused band is
    then passed to the detector selected by ``method``.

    Args:
        nir_band: Fallback band used when no matching sub-bands are supplied.
        water_mask: Mask array forwarded to the detector.
        glint_waves: Wavelengths of the bands; only its length is used here.
        weights: One weight per wavelength; equal weights when ``None``.
        method: ``'zscore'``, ``'percentile'`` or ``'otsu'``.
        z_threshold: Threshold forwarded to ``zscore_threshold``.
        percentile: Percentile forwarded to ``percentile_threshold``.
        sub_band_arrays: Band arrays aligned with ``glint_waves``.

    Returns:
        The binary detection array produced by the selected detector.

    Raises:
        ValueError: If the number of weights differs from the number of
            wavelengths, or ``method`` is not supported.
    """
    band_count = len(glint_waves)

    if weights is None:
        weights = [1.0 / band_count] * band_count
    if len(weights) != band_count:
        raise ValueError("权重数量必须与波长数量相同")

    has_matching_sub_bands = (
        sub_band_arrays is not None and len(sub_band_arrays) == band_count
    )

    if has_matching_sub_bands:
        fused = None
        for weight, band in zip(weights, sub_band_arrays):
            contribution = band * weight
            if fused is None:
                fused = contribution.astype(np.float32)
            else:
                fused = (fused + contribution).astype(np.float32)
    else:
        fused = nir_band.astype(np.float32)

    if method == 'zscore':
        return zscore_threshold(fused, water_mask, z_threshold)
    if method == 'percentile':
        return percentile_threshold(fused, water_mask, percentile)
    if method == 'otsu':
        # Otsu needs integer grey levels, so stretch first.
        return otsu_threshold(percentile_stretch(fused, water_mask, 2, 98), water_mask)

    raise ValueError(f"不支持的方法: {method}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 连通域过滤
|
||||
# =============================================================================
|
||||
|
||||
def filter_large_components(
    binary_img: np.ndarray,
    max_area: Optional[int] = None,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Remove connected foreground regions larger than ``max_area`` pixels.

    Components are labelled with 8-connectivity, using OpenCV when it is
    available and ``scipy.ndimage`` otherwise. Components whose pixel count
    is at most ``max_area`` are kept; everything else is written as 0.

    Args:
        binary_img: 2-D array in which pixels equal to ``foreground`` form
            the components.
        max_area: Largest component size (in pixels) to keep. ``None`` or a
            non-positive value disables filtering.
        foreground: Value identifying component pixels, also written for
            kept components.
        background: Accepted for signature compatibility; removed pixels
            are written as 0.

    Returns:
        ``binary_img`` itself when filtering is disabled or no foreground
        component exists; otherwise a new array of the same dtype.
    """
    if max_area is None or max_area <= 0:
        return binary_img

    is_foreground = binary_img == foreground

    if CV2_AVAILABLE:
        num_labels, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
            is_foreground.astype(np.uint8), connectivity=8
        )
        # Label 0 is the background, so real components start at 1.
        num_features = num_labels - 1
        component_sizes = stats[1:, cv2.CC_STAT_AREA]
    else:
        from scipy import ndimage
        # A full 3x3 structure gives 8-connectivity, matching the cv2 path.
        labeled_array, num_features = ndimage.label(
            is_foreground, structure=np.ones((3, 3), dtype=bool)
        )
        component_sizes = np.bincount(
            labeled_array.ravel(), minlength=num_features + 1
        )[1:]

    if num_features == 0:
        return binary_img

    keep_labels = np.flatnonzero(component_sizes <= max_area) + 1
    keep_mask = np.isin(labeled_array, keep_labels)

    filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
    filtered[keep_mask] = foreground
    return filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 岸边缓冲区处理
|
||||
# =============================================================================
|
||||
|
||||
def create_shoreline_buffer(
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Build a mask of the inward shoreline band of a water mask.

    The water mask is binarised (``> 0``), eroded with a square structuring
    element of side ``2 * buffer_size + 1``, and the eroded result is
    subtracted from the binarised mask. Erosion uses OpenCV when available
    and ``scipy.ndimage`` otherwise.

    Args:
        water_mask: Mask array; values > 0 are treated as water.
        buffer_size: Buffer width in pixels. Non-positive values yield an
            all-zero mask.
        foreground: Not used by this function.
        background: Not used by this function.

    Returns:
        ``int32`` array shaped like ``water_mask`` with 1 inside the
        shoreline band and 0 elsewhere.
    """
    if buffer_size <= 0:
        return np.zeros_like(water_mask, dtype=np.int32)

    is_water = (water_mask > 0).astype(np.uint8)
    side = 2 * buffer_size + 1
    kernel = np.ones((side, side), dtype=np.uint8)

    if not CV2_AVAILABLE:
        from scipy import ndimage
        shrunk = ndimage.binary_erosion(is_water, kernel).astype(np.int32)
    else:
        shrunk = cv2.erode(is_water, kernel).astype(np.int32)

    return (is_water - shrunk).astype(np.int32)
|
||||
|
||||
|
||||
def remove_shoreline_buffer(
    glint_mask: np.ndarray,
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Clear glint detections that fall inside the shoreline buffer.

    Args:
        glint_mask: Glint detection array.
        water_mask: Water mask from which the shoreline buffer is derived.
        buffer_size: Buffer width in pixels; non-positive disables the step.
        foreground: Forwarded to ``create_shoreline_buffer``.
        background: Value written into buffer pixels; also forwarded to
            ``create_shoreline_buffer``.

    Returns:
        ``glint_mask`` itself when ``buffer_size <= 0``; otherwise a copy
        with every buffer pixel set to ``background``.
    """
    if buffer_size <= 0:
        return glint_mask

    shoreline = create_shoreline_buffer(
        water_mask, buffer_size, foreground, background
    )

    result = glint_mask.copy()
    result[shoreline > 0] = background
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 高级组合函数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_glint_mask(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    method: str = 'otsu',
    z_threshold: float = 2.5,
    percentile: float = 95,
    iqr_multiplier: float = 1.5,
    window_size: int = 15,
    apply_percentile_stretch: bool = True
) -> np.ndarray:
    """Single entry point that dispatches to one of the glint detectors.

    Args:
        nir_band: Band array to analyse.
        water_mask: Mask array forwarded to the detector.
        method: ``'otsu'``, ``'zscore'``, ``'percentile'``, ``'iqr'`` or
            ``'adaptive'``.
        z_threshold: Threshold for the ``'zscore'`` method.
        percentile: Percentile for the ``'percentile'`` and ``'adaptive'``
            methods.
        iqr_multiplier: IQR multiple for the ``'iqr'`` method.
        window_size: Window size for the ``'adaptive'`` method.
        apply_percentile_stretch: For ``'otsu'`` and ``'adaptive'``, stretch
            the band to the 2nd-98th percentile range first; when ``False``
            the band is cast to ``int32`` instead.

    Returns:
        The binary glint mask produced by the selected detector.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'zscore':
        return zscore_threshold(nir_band, water_mask, z_threshold)
    if method == 'percentile':
        return percentile_threshold(nir_band, water_mask, percentile)
    if method == 'iqr':
        return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
    if method not in ('otsu', 'adaptive'):
        raise ValueError(f"不支持的方法: {method}")

    # Both remaining methods share the same preparation step.
    if apply_percentile_stretch:
        prepared = percentile_stretch(nir_band, water_mask, 2, 98)
    else:
        prepared = nir_band.astype(np.int32)

    if method == 'otsu':
        return otsu_threshold(prepared, water_mask)
    return adaptive_threshold(prepared, water_mask, window_size, percentile)
|
||||
7
src/core/algorithms/interpolation/__init__.py
Normal file
7
src/core/algorithms/interpolation/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
插值算法模块
|
||||
包含0值像素插值的核心数学逻辑
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
|
||||
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']
|
||||
320
src/core/algorithms/interpolation/interpolator.py
Normal file
320
src/core/algorithms/interpolation/interpolator.py
Normal file
@ -0,0 +1,320 @@
|
||||
"""
|
||||
像素插值算法模块
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import griddata, RBFInterpolator
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def interpolate_pixels(
    image_stack: np.ndarray,
    zero_coords: np.ndarray,
    valid_coords: np.ndarray,
    valid_values: np.ndarray,
    interpolation_method: str = 'nearest',
    water_mask: Optional[np.ndarray] = None
) -> np.ndarray:
    """Fill the given pixels of an image stack by spatial interpolation.

    Pure array computation; no file I/O is performed.

    Args:
        image_stack: Array of shape ``(height, width, n_bands)``.
        zero_coords: ``(n_zero, 2)`` array of ``[x, y]`` pixel coordinates to
            fill.
        valid_coords: ``(n_valid, 2)`` array of ``[x, y]`` sample coordinates.
        valid_values: Sample values, shape ``(n_valid,)`` or
            ``(n_valid, n_bands)``. A 1-D or single-column array yields one
            interpolated value per pixel that is written to every band.
        interpolation_method: Method name. Matching is case-insensitive and
            by substring, and Chinese aliases are accepted; unrecognised
            names fall back to ``'nearest'``.
        water_mask: Accepted for interface compatibility; not used here.

    Returns:
        A copy of ``image_stack`` with the requested pixels replaced. The
        unmodified copy is returned when ``valid_values`` is empty.

    Raises:
        ImportError: If SciPy is not available.
    """
    if not SCIPY_AVAILABLE:
        raise ImportError("scipy未安装,无法进行0值像素插值")

    _, _, n_bands = image_stack.shape
    result = image_stack.copy()

    # Map free-form / localised method names onto the canonical ones.
    # Order matters: the first matching entry wins.
    aliases = (
        ('nearest', ('nearest', '邻近')),
        ('bilinear', ('bilinear', '线性')),
        ('spline', ('spline', '样条', 'rbf')),
        ('kriging', ('kriging', '克里金')),
    )
    raw_method = str(interpolation_method).lower()
    method = 'nearest'
    for canonical, keywords in aliases:
        if any(keyword in raw_method for keyword in keywords):
            method = canonical
            break

    if len(valid_values) == 0:
        return result

    # Coordinates are [x, y]; numpy indexing is [row(y), col(x)].
    rows = zero_coords[:, 1].astype(int)
    cols = zero_coords[:, 0].astype(int)

    is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1

    if is_multiband:
        for band_idx in range(n_bands):
            result[rows, cols, band_idx] = _interpolate_single_band(
                zero_coords, valid_coords, valid_values[:, band_idx], method
            )
    else:
        filled = np.asarray(
            _interpolate_single_band(zero_coords, valid_coords, valid_values, method)
        )
        if filled.ndim == 1:
            # result[rows, cols] has shape (n_zero, n_bands); a flat
            # (n_zero,) vector must become a column to broadcast over bands.
            filled = filled[:, np.newaxis]
        result[rows, cols] = filled

    return result
|
||||
|
||||
|
||||
def _interpolate_single_band(
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
method: str
|
||||
) -> np.ndarray:
|
||||
"""对单个波段执行插值计算"""
|
||||
if method == 'nearest':
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords)
|
||||
return valid_values[indices]
|
||||
|
||||
elif method == 'bilinear':
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'spline':
|
||||
try:
|
||||
max_points = 10000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
|
||||
interpolated = rbf(zero_coords)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'kriging':
|
||||
try:
|
||||
from src.utils.kriging import KrigingInterpolator
|
||||
interpolator = KrigingInterpolator()
|
||||
max_points = 5000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
interpolated = griddata(
|
||||
sample_coords, sample_values, zero_coords,
|
||||
method='cubic', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def _canonical_interpolation_method(interpolation_method) -> str:
    """Map a free-form or localised method name onto a canonical one.

    Matching is case-insensitive and by substring; the first match wins and
    unrecognised names fall back to ``'nearest'``.
    """
    raw_method = str(interpolation_method).lower()
    aliases = (
        ('nearest', ('nearest', '邻近')),
        ('bilinear', ('bilinear', '线性')),
        ('spline', ('spline', '样条', 'rbf')),
        ('kriging', ('kriging', '克里金')),
    )
    for canonical, keywords in aliases:
        if any(keyword in raw_method for keyword in keywords):
            return canonical
    return 'nearest'


def _write_float32_bands(output_path, bands, width, height, geotransform, projection):
    """Write ``bands`` as a Float32 raster (ENVI driver, GTiff as fallback)."""
    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        driver = gdal.GetDriverByName('GTiff')
    out_dataset = driver.Create(output_path, width, height, len(bands), gdal.GDT_Float32)
    if out_dataset is None:
        raise ValueError(f"无法创建输出文件: {output_path}")
    out_dataset.SetGeoTransform(geotransform)
    out_dataset.SetProjection(projection)
    for band_number, band_data in enumerate(bands, start=1):
        out_band = out_dataset.GetRasterBand(band_number)
        out_band.WriteArray(band_data)
        out_band.FlushCache()
    # Release the reference (GDAL Python idiom for closing the dataset).
    out_dataset = None


def interpolate_zero_pixels_batch(
    img_path: str,
    interpolation_method: str = 'nearest',
    output_path: Optional[str] = None,
    water_mask: Optional[Union[str, np.ndarray]] = None,
    deglint_dir: Optional[str] = None,
    callback_progress: Optional[callable] = None
) -> Tuple[str, Optional[np.ndarray]]:
    """Interpolate pixels that are 0 in every band of a raster file.

    Full pipeline including file I/O: the image is read with GDAL, every
    pixel whose bands are all 0 (restricted to ``water_mask > 0`` when a mask
    is given) is filled band by band, and the result is written as Float32.

    Args:
        img_path: Path of the input raster.
        interpolation_method: Method name; aliases are normalised the same
            way as in ``interpolate_pixels`` (unknown names use nearest).
        output_path: Output file path. When ``None`` and ``deglint_dir`` is
            given, ``<deglint_dir>/interpolated_<interpolation_method>.bsq``
            is used. When still ``None``, nothing is written.
        water_mask: Optional mask as a file path or an array. A path that
            GDAL cannot open is ignored.
        deglint_dir: Directory used to derive the default output path.
        callback_progress: Optional callable invoked with a progress message
            string before each band is processed.

    Returns:
        ``(output_path, stack)`` where ``stack`` has shape
        ``(height, width, n_bands)``. If ``output_path`` already exists the
        function returns ``(output_path, None)`` without reading the input.

    Raises:
        ImportError: If SciPy or GDAL is not available.
        ValueError: If the input cannot be opened, the output cannot be
            created, or no pixel with data exists to interpolate from.
    """
    if not SCIPY_AVAILABLE:
        raise ImportError("scipy未安装,无法进行0值像素插值")
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装,无法读取影像文件")

    # Resolve the output path (the raw method name is kept in the file name).
    if output_path is None and deglint_dir is not None:
        output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")

    # Reuse an existing result instead of recomputing it.
    if output_path and Path(output_path).exists():
        return output_path, None

    # _interpolate_single_band only understands canonical names; an alias
    # passed through unchanged would silently fill zeros.
    method = _canonical_interpolation_method(interpolation_method)

    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        n_bands = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()

        # Read every band as float32.
        all_bands = [
            dataset.GetRasterBand(band_number).ReadAsArray().astype(np.float32)
            for band_number in range(1, n_bands + 1)
        ]
        image_stack = np.dstack(all_bands)

        # Load the water mask, if any.
        mask_array = None
        if water_mask is not None:
            if isinstance(water_mask, str):
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset:
                    mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                    mask_dataset = None
            elif isinstance(water_mask, np.ndarray):
                mask_array = water_mask

        # Pixels without data: 0 in every band.
        no_data = np.all(image_stack == 0, axis=2)

        # Pixels to fill: no-data pixels, limited to water when masked.
        all_bands_zero = no_data
        if mask_array is not None:
            all_bands_zero = all_bands_zero & (mask_array > 0)

        zero_pixel_count = np.sum(all_bands_zero)
        if zero_pixel_count == 0:
            # Nothing to interpolate; just persist the input as-is.
            if output_path:
                _write_float32_bands(
                    output_path, all_bands, width, height, geotransform, projection
                )
            return output_path, image_stack

        zero_y, zero_x = np.where(all_bands_zero)
        zero_coords = np.column_stack([zero_x, zero_y])

        # Samples are pixels that actually carry data. No-data pixels outside
        # the water mask must not be used as zero-valued samples.
        valid_mask = ~no_data
        valid_y, valid_x = np.where(valid_mask)
        valid_coords = np.column_stack([valid_x, valid_y])

        if len(valid_coords) == 0:
            raise ValueError("没有有效像素可用于插值")

        # Interpolate band by band.
        interpolated_bands = []
        for band_idx in range(n_bands):
            if callback_progress:
                callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
            band_data = all_bands[band_idx].copy()
            valid_values_band = band_data[valid_mask]

            if len(valid_values_band) == 0:
                interpolated_bands.append(band_data)
                continue

            # Boolean-mask assignment and np.where both use row-major order,
            # so results line up with zero_coords.
            band_data[all_bands_zero] = _interpolate_single_band(
                zero_coords, valid_coords, valid_values_band, method
            )
            interpolated_bands.append(band_data)

        if output_path:
            _write_float32_bands(
                output_path, interpolated_bands, width, height, geotransform, projection
            )

        return output_path, np.dstack(interpolated_bands)

    finally:
        # Release the input dataset handle.
        dataset = None
|
||||
Reference in New Issue
Block a user