refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,36 @@
"""
算法层模块
包含插值算法和耀斑检测算法等核心数学计算
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
from src.core.algorithms.glint_detection.detectors import (
otsu_threshold,
zscore_threshold,
percentile_threshold,
iqr_outlier_detection,
adaptive_threshold,
multi_band_glint_detection,
percentile_stretch,
filter_large_components,
create_shoreline_buffer,
remove_shoreline_buffer,
calculate_glint_mask,
)
__all__ = [
# 插值
'interpolate_pixels',
'interpolate_zero_pixels_batch',
# 耀斑检测
'otsu_threshold',
'zscore_threshold',
'percentile_threshold',
'iqr_outlier_detection',
'adaptive_threshold',
'multi_band_glint_detection',
'percentile_stretch',
'filter_large_components',
'create_shoreline_buffer',
'remove_shoreline_buffer',
'calculate_glint_mask',
]

View File

@ -0,0 +1,31 @@
"""
耀斑检测算法模块
包含各种耀斑检测的核心数学计算函数
"""
from src.core.algorithms.glint_detection.detectors import (
otsu_threshold,
zscore_threshold,
percentile_threshold,
iqr_outlier_detection,
adaptive_threshold,
multi_band_glint_detection,
percentile_stretch,
filter_large_components,
create_shoreline_buffer,
remove_shoreline_buffer,
calculate_glint_mask,
)
__all__ = [
'otsu_threshold',
'zscore_threshold',
'percentile_threshold',
'iqr_outlier_detection',
'adaptive_threshold',
'multi_band_glint_detection',
'percentile_stretch',
'filter_large_components',
'create_shoreline_buffer',
'remove_shoreline_buffer',
'calculate_glint_mask',
]

View File

@ -0,0 +1,595 @@
"""
耀斑检测算法模块
包含各种耀斑检测的核心数学计算函数纯数学逻辑不涉及文件I/O。
支持的方法otsu, zscore, percentile, iqr, adaptive, multi_band
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
"""
import numpy as np
from typing import Optional, List, Tuple
from functools import wraps
try:
import cv2
CV2_AVAILABLE = True
except ImportError:
CV2_AVAILABLE = False
def timeit(func):
"""装饰器:测量函数执行时间"""
@wraps(func)
def wrapper(*args, **kwargs):
import time
start = time.time()
result = func(*args, **kwargs)
end = time.time()
print(f"[{func.__name__}] 耗时: {end - start:.2f}s")
return result
return wrapper
# =============================================================================
# 百分位数拉伸
# =============================================================================
def percentile_stretch(
img: np.ndarray,
data_water_mask: np.ndarray,
lower_percentile: float = 2,
upper_percentile: float = 98,
output_range: Tuple[int, int] = (0, 255)
) -> np.ndarray:
"""
使用百分位数裁剪进行归一化,适用于低反射率数据
通过排除极值,更好地利用数据的动态范围
Args:
img: 输入图像数组反射率值通常在0-1之间
data_water_mask: 水域掩膜
lower_percentile: 下百分位数用于裁剪最小值默认2
upper_percentile: 上百分位数用于裁剪最大值默认98
output_range: 输出范围,默认(0, 255)
Returns:
归一化后的图像数组(整数类型)
"""
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
return img.astype(np.int32)
p_lower = np.percentile(valid_pixels, lower_percentile)
p_upper = np.percentile(valid_pixels, upper_percentile)
if p_lower >= p_upper:
p_lower = np.percentile(valid_pixels, 1)
p_upper = np.percentile(valid_pixels, 99)
if p_lower >= p_upper:
p_upper = valid_pixels.max()
p_lower = valid_pixels.min()
img_clipped = np.clip(img, p_lower, p_upper)
if p_upper > p_lower:
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
output_range[1] - output_range[0]
) + output_range[0]
else:
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
return img_stretched.astype(np.int32)
# =============================================================================
# Otsu阈值分割
# =============================================================================
def otsu_threshold(
img: np.ndarray,
data_water_mask: np.ndarray,
ignore_value: int = 0,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
基于Otsu方法的自动阈值分割
通过最大化类间方差找到最佳分割阈值
Args:
img: 输入图像数组(整数值)
data_water_mask: 水域掩膜
ignore_value: 忽略的值默认为0
foreground: 耀斑区域值默认1
background: 背景值默认0
Returns:
二值化检测结果数组
"""
height, width = img.shape
max_value = int(np.max(img[img > ignore_value])) + 1
if max_value < 2:
max_value = 256
hist = np.zeros([max_value], np.float32)
invalid_counter = 0
for i in range(height):
for j in range(width):
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
invalid_counter += 1
continue
hist[img[i, j]] += 1
total_valid = height * width - invalid_counter
if total_valid <= 0:
return np.zeros_like(img, dtype=np.int32)
hist /= total_valid
threshold = 0
deltaMax = 0
for i in range(max_value):
wA = sum(hist[:i + 1])
wB = sum(hist[i + 1:])
if wA == 0:
wA = 1e-10
if wB == 0:
wB = 1e-10
uAtmp = sum(j * hist[j] for j in range(i + 1))
uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
uA = uAtmp / wA
uB = uBtmp / wB
u = uAtmp + uBtmp
deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
if deltaTmp > deltaMax:
deltaMax = deltaTmp
threshold = i
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > threshold] = foreground
det_img[data_water_mask == 0] = background
return det_img
# =============================================================================
# Z-score阈值检测
# =============================================================================
def zscore_threshold(
img: np.ndarray,
data_water_mask: np.ndarray,
z_threshold: float = 2.5,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
基于Z-score标准化分数的耀斑检测方法
使用统计方法识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
z_threshold: Z-score阈值默认2.5即超过均值2.5个标准差)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
return np.zeros_like(img, dtype=np.int32)
mean_val = np.mean(valid_pixels)
std_val = np.std(valid_pixels)
if std_val == 0:
return np.zeros_like(img, dtype=np.int32)
z_scores = np.zeros_like(img, dtype=np.float32)
valid_mask = (data_water_mask > 0) & np.isfinite(img)
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
det_img = np.zeros_like(img, dtype=np.int32)
det_img[z_scores > z_threshold] = foreground
det_img[data_water_mask == 0] = background
return det_img
# =============================================================================
# 百分位数阈值检测
# =============================================================================
def percentile_threshold(
img: np.ndarray,
data_water_mask: np.ndarray,
percentile: float = 95,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
基于百分位数的耀斑检测方法
使用百分位数作为阈值,对异常值更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
percentile: 百分位数阈值默认95即超过95%的像素值)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
return np.zeros_like(img, dtype=np.int32)
threshold_val = np.percentile(valid_pixels, percentile)
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > threshold_val] = foreground
det_img[data_water_mask == 0] = background
return det_img
# =============================================================================
# IQR异常值检测
# =============================================================================
def iqr_outlier_detection(
img: np.ndarray,
data_water_mask: np.ndarray,
iqr_multiplier: float = 1.5,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
基于IQR四分位距的异常值检测方法
使用四分位距识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
iqr_multiplier: IQR倍数默认1.5(标准异常值检测)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
return np.zeros_like(img, dtype=np.int32)
q1 = np.percentile(valid_pixels, 25)
q3 = np.percentile(valid_pixels, 75)
iqr = q3 - q1
upper_bound = q3 + iqr_multiplier * iqr
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > upper_bound] = foreground
det_img[data_water_mask == 0] = background
return det_img
# =============================================================================
# 自适应阈值检测
# =============================================================================
def adaptive_threshold(
img: np.ndarray,
data_water_mask: np.ndarray,
window_size: int = 15,
percentile: float = 90,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
自适应阈值方法
基于局部统计特性进行阈值分割,对光照变化更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
window_size: 局部窗口大小(奇数)
percentile: 局部百分位数阈值
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
height, width = img.shape
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
det_img = np.zeros_like(img, dtype=np.int32)
for i in range(half_window, height - half_window):
for j in range(half_window, width - half_window):
if data_water_mask[i, j] == 0:
continue
local_window = img[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
local_mask = data_water_mask[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
valid_pixels = local_window[local_mask > 0]
if len(valid_pixels) > 0:
local_th = np.percentile(valid_pixels, percentile)
if img[i, j] > local_th:
det_img[i, j] = foreground
det_img[data_water_mask == 0] = background
return det_img
# =============================================================================
# 多波段融合耀斑检测
# =============================================================================
def multi_band_glint_detection(
nir_band: np.ndarray,
water_mask: np.ndarray,
glint_waves: List[float],
weights: Optional[List[float]] = None,
method: str = 'zscore',
z_threshold: float = 2.5,
percentile: float = 95,
sub_band_arrays: Optional[List[np.ndarray]] = None
) -> np.ndarray:
"""
多波段融合的耀斑检测方法
结合多个波段的耀斑特征,提高检测的稳健性
Args:
nir_band: 近红外波段数组(主波段,用于兼容性)
water_mask: 水域掩膜数组
glint_waves: 用于检测的波长列表,如[750, 800, 850]
weights: 各波段的权重如果为None则使用等权重
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
z_threshold: Z-score阈值当method='zscore'时使用)
percentile: 百分位数阈值当method='percentile'时使用)
sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应)
Returns:
二值化检测结果
"""
if weights is None:
weights = [1.0 / len(glint_waves)] * len(glint_waves)
if len(weights) != len(glint_waves):
raise ValueError("权重数量必须与波长数量相同")
fused_band = None
if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
for i, band_array in enumerate(sub_band_arrays):
if fused_band is None:
fused_band = (band_array * weights[i]).astype(np.float32)
else:
fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
else:
fused_band = nir_band.astype(np.float32)
if method == 'otsu':
stretched = percentile_stretch(fused_band, water_mask, 2, 98)
return otsu_threshold(stretched, water_mask)
elif method == 'zscore':
return zscore_threshold(fused_band, water_mask, z_threshold)
elif method == 'percentile':
return percentile_threshold(fused_band, water_mask, percentile)
else:
raise ValueError(f"不支持的方法: {method}")
# =============================================================================
# 连通域过滤
# =============================================================================
def filter_large_components(
binary_img: np.ndarray,
max_area: Optional[int] = None,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
过滤掉面积超过阈值的连通域
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
Args:
binary_img: 二值化图像
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
foreground: 前景值
background: 背景值
Returns:
过滤后的二值化图像
"""
if max_area is None or max_area <= 0:
return binary_img
if CV2_AVAILABLE:
binary_for_label = (binary_img == foreground).astype(np.uint8)
num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
binary_for_label, connectivity=8
)
if num_features == 0:
return binary_img
component_sizes = stats[1:, cv2.CC_STAT_AREA]
keep_labels = np.where(component_sizes <= max_area)[0] + 1
keep_mask = np.isin(labeled_array, keep_labels)
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
filtered[keep_mask] = foreground
return filtered
else:
from scipy import ndimage
labeled_array, num_features = ndimage.label(
(binary_img == foreground).astype(np.int32)
)
if num_features == 0:
return binary_img
component_sizes = ndimage.sum(
(labeled_array == i).astype(np.int32),
labeled_array,
range(1, num_features + 1)
)
keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
filtered[keep_mask] = foreground
return filtered
# =============================================================================
# 岸边缓冲区处理
# =============================================================================
def create_shoreline_buffer(
water_mask: np.ndarray,
buffer_size: int = 5,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
创建岸边缓冲区掩膜(向内缓冲)
用于去除岸边附近的错误耀斑检测区域
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
Args:
water_mask: 水域掩膜数组(水域=1非水域=0
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
岸边缓冲区掩膜(缓冲区区域=1其他=0
"""
if buffer_size <= 0:
return np.zeros_like(water_mask, dtype=np.int32)
water_binary = (water_mask > 0).astype(np.uint8)
structure_size = buffer_size * 2 + 1
structure = np.ones((structure_size, structure_size), dtype=np.uint8)
if CV2_AVAILABLE:
eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
else:
from scipy import ndimage
eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
buffer_mask = (water_binary - eroded_water).astype(np.int32)
return buffer_mask
def remove_shoreline_buffer(
glint_mask: np.ndarray,
water_mask: np.ndarray,
buffer_size: int = 5,
foreground: int = 1,
background: int = 0
) -> np.ndarray:
"""
从耀斑掩膜中去除岸边缓冲区内的区域
Args:
glint_mask: 耀斑掩膜数组
water_mask: 水域掩膜数组
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
去除岸边缓冲区后的耀斑掩膜
"""
if buffer_size <= 0:
return glint_mask
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
cleaned = glint_mask.copy()
cleaned[buffer_mask > 0] = background
return cleaned
# =============================================================================
# 高级组合函数
# =============================================================================
def calculate_glint_mask(
nir_band: np.ndarray,
water_mask: np.ndarray,
method: str = 'otsu',
z_threshold: float = 2.5,
percentile: float = 95,
iqr_multiplier: float = 1.5,
window_size: int = 15,
apply_percentile_stretch: bool = True
) -> np.ndarray:
"""
计算耀斑掩膜的统一入口函数
Args:
nir_band: 近红外波段数组
water_mask: 水域掩膜
method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
z_threshold: Z-score阈值
percentile: 百分位数阈值
iqr_multiplier: IQR倍数
window_size: 自适应阈值窗口大小
apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸
Returns:
二值化耀斑掩膜
"""
if method == 'otsu':
if apply_percentile_stretch:
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
return otsu_threshold(stretched, water_mask)
else:
return otsu_threshold(nir_band.astype(np.int32), water_mask)
elif method == 'zscore':
return zscore_threshold(nir_band, water_mask, z_threshold)
elif method == 'percentile':
return percentile_threshold(nir_band, water_mask, percentile)
elif method == 'iqr':
return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
elif method == 'adaptive':
if apply_percentile_stretch:
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
return adaptive_threshold(stretched, water_mask, window_size, percentile)
else:
return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
else:
raise ValueError(f"不支持的方法: {method}")

View File

@ -0,0 +1,7 @@
"""
插值算法模块
包含0值像素插值的核心数学逻辑
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']

View File

@ -0,0 +1,320 @@
"""
像素插值算法模块
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
支持多种插值方法nearest, bilinear, spline (RBF), kriging。
"""
import numpy as np
from typing import Optional, Union, Tuple, List
from pathlib import Path
try:
from scipy import ndimage
from scipy.interpolate import griddata, RBFInterpolator
from scipy.spatial import cKDTree
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def interpolate_pixels(
image_stack: np.ndarray,
zero_coords: np.ndarray,
valid_coords: np.ndarray,
valid_values: np.ndarray,
interpolation_method: str = 'nearest',
water_mask: Optional[np.ndarray] = None
) -> np.ndarray:
"""
对指定坐标的像素进行插值核心数学函数不涉及文件I/O
Args:
image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组
zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y]
valid_coords: 有效像素坐标,形状为 (n_valid, 2)
valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands)
interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging'
water_mask: 可选的水域掩膜数组
Returns:
插值后的影像副本,形状与 image_stack 相同
"""
if not SCIPY_AVAILABLE:
raise ImportError("scipy未安装无法进行0值像素插值")
height, width, n_bands = image_stack.shape
result = image_stack.copy()
# 兼容中文和各种格式的method参数
raw_method = str(interpolation_method).lower()
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
method = 'nearest'
elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
method = 'bilinear'
elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
method = 'spline'
elif 'kriging' in raw_method or '克里金' in raw_method:
method = 'kriging'
else:
method = 'nearest'
if len(valid_values) == 0:
return result
is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
if is_multiband:
for band_idx in range(n_bands):
band_valid_values = valid_values[:, band_idx]
interpolated_values = _interpolate_single_band(
zero_coords, valid_coords, band_valid_values, method
)
y_coords = zero_coords[:, 1].astype(int)
x_coords = zero_coords[:, 0].astype(int)
result[y_coords, x_coords, band_idx] = interpolated_values
else:
interpolated_values = _interpolate_single_band(
zero_coords, valid_coords, valid_values, method
)
y_coords = zero_coords[:, 1].astype(int)
x_coords = zero_coords[:, 0].astype(int)
result[y_coords, x_coords] = interpolated_values
return result
def _interpolate_single_band(
zero_coords: np.ndarray,
valid_coords: np.ndarray,
valid_values: np.ndarray,
method: str
) -> np.ndarray:
"""对单个波段执行插值计算"""
if method == 'nearest':
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords)
return valid_values[indices]
elif method == 'bilinear':
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'spline':
try:
max_points = 10000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
interpolated = rbf(zero_coords)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'kriging':
try:
from src.utils.kriging import KrigingInterpolator
interpolator = KrigingInterpolator()
max_points = 5000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
interpolated = griddata(
sample_coords, sample_values, zero_coords,
method='cubic', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
return np.zeros(len(zero_coords))
def interpolate_zero_pixels_batch(
img_path: str,
interpolation_method: str = 'nearest',
output_path: Optional[str] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
deglint_dir: Optional[str] = None,
callback_progress: Optional[callable] = None
) -> Tuple[str, Optional[np.ndarray]]:
"""
对影像中所有波段都为0的像素点进行插值完整流程含文件I/O
Args:
img_path: 输入影像文件路径
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
output_path: 输出文件路径如果为None自动生成
water_mask: 水域掩膜(文件路径或数组)
deglint_dir: 去耀斑目录(用于生成默认输出路径)
callback_progress: 进度回调函数
Returns:
(output_path, interpolated_image_stack) 元组
"""
if not SCIPY_AVAILABLE:
raise ImportError("scipy未安装无法进行0值像素插值")
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
# 确定输出路径
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
# 检查文件是否已存在
if output_path and Path(output_path).exists():
return output_path, None
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
# 读取所有波段数据
all_bands = []
for band_idx in range(1, n_bands + 1):
band = dataset.GetRasterBand(band_idx)
band_data = band.ReadAsArray().astype(np.float32)
all_bands.append(band_data)
image_stack = np.dstack(all_bands)
# 读取水域掩膜
mask_array = None
if water_mask is not None:
if isinstance(water_mask, str):
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
# 找出所有波段都为0的像素点
all_bands_zero = np.all(image_stack == 0, axis=2)
if mask_array is not None:
all_bands_zero = all_bands_zero & (mask_array > 0)
zero_pixel_count = np.sum(all_bands_zero)
if zero_pixel_count == 0:
# 无需插值,直接保存
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(all_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
return output_path, image_stack
# 获取坐标
zero_y, zero_x = np.where(all_bands_zero)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_bands_zero
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
raise ValueError("没有有效像素可用于插值")
# 逐波段插值
interpolated_bands = []
for band_idx in range(n_bands):
if callback_progress:
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
band_data = all_bands[band_idx].copy()
valid_values_band = band_data[valid_mask]
if len(valid_values_band) == 0:
interpolated_bands.append(band_data)
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, interpolation_method
)
band_data[all_bands_zero] = band_result
interpolated_bands.append(band_data)
# 保存结果
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(interpolated_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
result_stack = np.dstack(interpolated_bands)
return output_path, result_stack
finally:
dataset = None