From dcbcc043e499a383959ae71bf0fe7318e41439d9 Mon Sep 17 00:00:00 2001 From: DXC Date: Sat, 9 May 2026 17:18:34 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=B8=90=E8=BF=9B=E5=BC=8F?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E5=8C=96=E9=87=8D=E6=9E=84=20=E2=80=94=20?= =?UTF-8?q?=E5=89=A5=E7=A6=BB=E5=8F=AF=E8=A7=86=E5=8C=96=E5=B1=82=E3=80=81?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E5=B1=82=E3=80=81=E7=AE=97=E6=B3=95=E5=B1=82?= =?UTF-8?q?=E5=88=B0=E7=8B=AC=E7=AB=8B=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/algorithms/__init__.py | 36 + .../algorithms/glint_detection/__init__.py | 31 + .../algorithms/glint_detection/detectors.py | 595 ++++++++++ src/core/algorithms/interpolation/__init__.py | 7 + .../algorithms/interpolation/interpolator.py | 320 +++++ src/core/utils/__init__.py | 42 + src/core/utils/gdal_helper.py | 309 +++++ src/core/utils/mask_converter.py | 210 ++++ src/core/utils/preview_generator.py | 339 ++++++ src/core/visualization/__init__.py | 21 + src/core/visualization/boxplot.py | 183 +++ src/core/visualization/preview.py | 59 + src/core/visualization/report.py | 147 +++ src/core/visualization/scatter_plot.py | 147 +++ src/core/visualization/spectrum_plot.py | 80 ++ src/core/visualization/statistics.py | 59 + .../water_quality_inversion_pipeline_GUI.py | 1036 ++--------------- 17 files changed, 2673 insertions(+), 948 deletions(-) create mode 100644 src/core/algorithms/__init__.py create mode 100644 src/core/algorithms/glint_detection/__init__.py create mode 100644 src/core/algorithms/glint_detection/detectors.py create mode 100644 src/core/algorithms/interpolation/__init__.py create mode 100644 src/core/algorithms/interpolation/interpolator.py create mode 100644 src/core/utils/__init__.py create mode 100644 src/core/utils/gdal_helper.py create mode 100644 src/core/utils/mask_converter.py create mode 100644 src/core/utils/preview_generator.py create mode 100644 src/core/visualization/__init__.py create mode 100644 src/core/visualization/boxplot.py create mode 100644 src/core/visualization/preview.py create mode 100644 src/core/visualization/report.py create mode 100644 src/core/visualization/scatter_plot.py create mode 100644 src/core/visualization/spectrum_plot.py create mode 100644 src/core/visualization/statistics.py diff --git a/src/core/algorithms/__init__.py b/src/core/algorithms/__init__.py new file mode 100644 index 0000000..ac04e68 --- /dev/null +++ b/src/core/algorithms/__init__.py @@ -0,0 +1,36 @@ +""" +算法层模块 +包含插值算法和耀斑检测算法等核心数学计算 +""" +from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch +from src.core.algorithms.glint_detection.detectors import ( + otsu_threshold, + zscore_threshold, + percentile_threshold, + iqr_outlier_detection, + adaptive_threshold, + multi_band_glint_detection, + percentile_stretch, + filter_large_components, + create_shoreline_buffer, + remove_shoreline_buffer, + calculate_glint_mask, +) + +__all__ = [ + # 插值 + 'interpolate_pixels', + 'interpolate_zero_pixels_batch', + # 耀斑检测 + 'otsu_threshold', + 'zscore_threshold', + 'percentile_threshold', + 'iqr_outlier_detection', + 'adaptive_threshold', + 'multi_band_glint_detection', + 'percentile_stretch', + 'filter_large_components', + 'create_shoreline_buffer', + 'remove_shoreline_buffer', + 'calculate_glint_mask', +] diff --git a/src/core/algorithms/glint_detection/__init__.py b/src/core/algorithms/glint_detection/__init__.py new file mode 100644 index 0000000..e6425d2 --- /dev/null +++ b/src/core/algorithms/glint_detection/__init__.py @@ -0,0 +1,31 @@ +""" +耀斑检测算法模块 +包含各种耀斑检测的核心数学计算函数 +""" +from src.core.algorithms.glint_detection.detectors import ( + otsu_threshold, + zscore_threshold, + percentile_threshold, + iqr_outlier_detection, + adaptive_threshold, + multi_band_glint_detection, + percentile_stretch, + filter_large_components, + create_shoreline_buffer, + remove_shoreline_buffer, + calculate_glint_mask, +) + +__all__ = [ + 'otsu_threshold', + 'zscore_threshold', + 'percentile_threshold', + 'iqr_outlier_detection', + 'adaptive_threshold', + 'multi_band_glint_detection', + 'percentile_stretch', + 'filter_large_components', + 'create_shoreline_buffer', + 'remove_shoreline_buffer', + 'calculate_glint_mask', +] diff --git a/src/core/algorithms/glint_detection/detectors.py b/src/core/algorithms/glint_detection/detectors.py new file mode 100644 index 0000000..fb15fd3 --- /dev/null +++ b/src/core/algorithms/glint_detection/detectors.py @@ -0,0 +1,595 @@ +""" +耀斑检测算法模块 + +包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。 +支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band + +本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。 +""" + +import numpy as np +from typing import Optional, List, Tuple +from functools import wraps + +try: + import cv2 + CV2_AVAILABLE = True +except ImportError: + CV2_AVAILABLE = False + + +def timeit(func): + """装饰器:测量函数执行时间""" + @wraps(func) + def wrapper(*args, **kwargs): + import time + start = time.time() + result = func(*args, **kwargs) + end = time.time() + print(f"[{func.__name__}] 耗时: {end - start:.2f}s") + return result + return wrapper + + +# ============================================================================= +# 百分位数拉伸 +# ============================================================================= + +def percentile_stretch( + img: np.ndarray, + data_water_mask: np.ndarray, + lower_percentile: float = 2, + upper_percentile: float = 98, + output_range: Tuple[int, int] = (0, 255) +) -> np.ndarray: + """ + 使用百分位数裁剪进行归一化,适用于低反射率数据 + 通过排除极值,更好地利用数据的动态范围 + + Args: + img: 输入图像数组(反射率值,通常在0-1之间) + data_water_mask: 水域掩膜 + lower_percentile: 下百分位数,用于裁剪最小值(默认2) + upper_percentile: 上百分位数,用于裁剪最大值(默认98) + output_range: 输出范围,默认(0, 255) + + Returns: + 归一化后的图像数组(整数类型) + """ + valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)] + + if len(valid_pixels) == 0: + return img.astype(np.int32) + + p_lower = np.percentile(valid_pixels, lower_percentile) + p_upper = np.percentile(valid_pixels, upper_percentile) + + if p_lower >= p_upper: + p_lower = np.percentile(valid_pixels, 1) + p_upper = np.percentile(valid_pixels, 99) + if p_lower >= p_upper: + p_upper = valid_pixels.max() + p_lower = valid_pixels.min() + + img_clipped = np.clip(img, p_lower, p_upper) + + if p_upper > p_lower: + img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * ( + output_range[1] - output_range[0] + ) + output_range[0] + else: + img_stretched = np.full_like(img, output_range[0], dtype=np.float32) + + return img_stretched.astype(np.int32) + + +# ============================================================================= +# Otsu阈值分割 +# ============================================================================= + +def otsu_threshold( + img: np.ndarray, + data_water_mask: np.ndarray, + ignore_value: int = 0, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 基于Otsu方法的自动阈值分割 + 通过最大化类间方差找到最佳分割阈值 + + Args: + img: 输入图像数组(整数值) + data_water_mask: 水域掩膜 + ignore_value: 忽略的值(默认为0) + foreground: 耀斑区域值(默认1) + background: 背景值(默认0) + + Returns: + 二值化检测结果数组 + """ + height, width = img.shape + + max_value = int(np.max(img[img > ignore_value])) + 1 + if max_value < 2: + max_value = 256 + + hist = np.zeros([max_value], np.float32) + + invalid_counter = 0 + for i in range(height): + for j in range(width): + if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0: + invalid_counter += 1 + continue + hist[img[i, j]] += 1 + + total_valid = height * width - invalid_counter + if total_valid <= 0: + return np.zeros_like(img, dtype=np.int32) + hist /= total_valid + + threshold = 0 + deltaMax = 0 + + for i in range(max_value): + wA = sum(hist[:i + 1]) + wB = sum(hist[i + 1:]) + if wA == 0: + wA = 1e-10 + if wB == 0: + wB = 1e-10 + + uAtmp = sum(j * hist[j] for j in range(i + 1)) + uBtmp = sum(j * hist[j] for j in range(i + 1, max_value)) + uA = uAtmp / wA + uB = uBtmp / wB + u = uAtmp + uBtmp + + deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2) + if deltaTmp > deltaMax: + deltaMax = deltaTmp + threshold = i + + det_img = np.zeros_like(img, dtype=np.int32) + det_img[img > threshold] = foreground + det_img[data_water_mask == 0] = background + + return det_img + + +# ============================================================================= +# Z-score阈值检测 +# ============================================================================= + +def zscore_threshold( + img: np.ndarray, + data_water_mask: np.ndarray, + z_threshold: float = 2.5, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 基于Z-score(标准化分数)的耀斑检测方法 + 使用统计方法识别异常高亮的像素,对数据分布不敏感 + + Args: + img: 输入图像数组 + data_water_mask: 水域掩膜 + z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差) + foreground: 前景值 + background: 背景值 + + Returns: + 二值化检测结果 + """ + valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)] + + if len(valid_pixels) == 0: + return np.zeros_like(img, dtype=np.int32) + + mean_val = np.mean(valid_pixels) + std_val = np.std(valid_pixels) + + if std_val == 0: + return np.zeros_like(img, dtype=np.int32) + + z_scores = np.zeros_like(img, dtype=np.float32) + valid_mask = (data_water_mask > 0) & np.isfinite(img) + z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val + + det_img = np.zeros_like(img, dtype=np.int32) + det_img[z_scores > z_threshold] = foreground + det_img[data_water_mask == 0] = background + + return det_img + + +# ============================================================================= +# 百分位数阈值检测 +# ============================================================================= + +def percentile_threshold( + img: np.ndarray, + data_water_mask: np.ndarray, + percentile: float = 95, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 基于百分位数的耀斑检测方法 + 使用百分位数作为阈值,对异常值更稳健 + + Args: + img: 输入图像数组 + data_water_mask: 水域掩膜 + percentile: 百分位数阈值,默认95(即超过95%的像素值) + foreground: 前景值 + background: 背景值 + + Returns: + 二值化检测结果 + """ + valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)] + + if len(valid_pixels) == 0: + return np.zeros_like(img, dtype=np.int32) + + threshold_val = np.percentile(valid_pixels, percentile) + + det_img = np.zeros_like(img, dtype=np.int32) + det_img[img > threshold_val] = foreground + det_img[data_water_mask == 0] = background + + return det_img + + +# ============================================================================= +# IQR异常值检测 +# ============================================================================= + +def iqr_outlier_detection( + img: np.ndarray, + data_water_mask: np.ndarray, + iqr_multiplier: float = 1.5, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 基于IQR(四分位距)的异常值检测方法 + 使用四分位距识别异常高亮的像素,对数据分布不敏感 + + Args: + img: 输入图像数组 + data_water_mask: 水域掩膜 + iqr_multiplier: IQR倍数,默认1.5(标准异常值检测) + foreground: 前景值 + background: 背景值 + + Returns: + 二值化检测结果 + """ + valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)] + + if len(valid_pixels) == 0: + return np.zeros_like(img, dtype=np.int32) + + q1 = np.percentile(valid_pixels, 25) + q3 = np.percentile(valid_pixels, 75) + iqr = q3 - q1 + + upper_bound = q3 + iqr_multiplier * iqr + + det_img = np.zeros_like(img, dtype=np.int32) + det_img[img > upper_bound] = foreground + det_img[data_water_mask == 0] = background + + return det_img + + +# ============================================================================= +# 自适应阈值检测 +# ============================================================================= + +def adaptive_threshold( + img: np.ndarray, + data_water_mask: np.ndarray, + window_size: int = 15, + percentile: float = 90, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 自适应阈值方法 + 基于局部统计特性进行阈值分割,对光照变化更稳健 + + Args: + img: 输入图像数组 + data_water_mask: 水域掩膜 + window_size: 局部窗口大小(奇数) + percentile: 局部百分位数阈值 + foreground: 前景值 + background: 背景值 + + Returns: + 二值化检测结果 + """ + height, width = img.shape + + if window_size % 2 == 0: + window_size += 1 + + half_window = window_size // 2 + + det_img = np.zeros_like(img, dtype=np.int32) + + for i in range(half_window, height - half_window): + for j in range(half_window, width - half_window): + if data_water_mask[i, j] == 0: + continue + + local_window = img[i - half_window:i + half_window + 1, + j - half_window:j + half_window + 1] + local_mask = data_water_mask[i - half_window:i + half_window + 1, + j - half_window:j + half_window + 1] + + valid_pixels = local_window[local_mask > 0] + + if len(valid_pixels) > 0: + local_th = np.percentile(valid_pixels, percentile) + if img[i, j] > local_th: + det_img[i, j] = foreground + + det_img[data_water_mask == 0] = background + + return det_img + + +# ============================================================================= +# 多波段融合耀斑检测 +# ============================================================================= + +def multi_band_glint_detection( + nir_band: np.ndarray, + water_mask: np.ndarray, + glint_waves: List[float], + weights: Optional[List[float]] = None, + method: str = 'zscore', + z_threshold: float = 2.5, + percentile: float = 95, + sub_band_arrays: Optional[List[np.ndarray]] = None +) -> np.ndarray: + """ + 多波段融合的耀斑检测方法 + 结合多个波段的耀斑特征,提高检测的稳健性 + + Args: + nir_band: 近红外波段数组(主波段,用于兼容性) + water_mask: 水域掩膜数组 + glint_waves: 用于检测的波长列表,如[750, 800, 850] + weights: 各波段的权重,如果为None则使用等权重 + method: 使用的检测方法 ('zscore', 'percentile', 'otsu') + z_threshold: Z-score阈值(当method='zscore'时使用) + percentile: 百分位数阈值(当method='percentile'时使用) + sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应) + + Returns: + 二值化检测结果 + """ + if weights is None: + weights = [1.0 / len(glint_waves)] * len(glint_waves) + + if len(weights) != len(glint_waves): + raise ValueError("权重数量必须与波长数量相同") + + fused_band = None + + if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves): + for i, band_array in enumerate(sub_band_arrays): + if fused_band is None: + fused_band = (band_array * weights[i]).astype(np.float32) + else: + fused_band = (fused_band + band_array * weights[i]).astype(np.float32) + else: + fused_band = nir_band.astype(np.float32) + + if method == 'otsu': + stretched = percentile_stretch(fused_band, water_mask, 2, 98) + return otsu_threshold(stretched, water_mask) + elif method == 'zscore': + return zscore_threshold(fused_band, water_mask, z_threshold) + elif method == 'percentile': + return percentile_threshold(fused_band, water_mask, percentile) + else: + raise ValueError(f"不支持的方法: {method}") + + +# ============================================================================= +# 连通域过滤 +# ============================================================================= + +def filter_large_components( + binary_img: np.ndarray, + max_area: Optional[int] = None, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 过滤掉面积超过阈值的连通域 + 用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域 + + Args: + binary_img: 二值化图像 + max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除 + foreground: 前景值 + background: 背景值 + + Returns: + 过滤后的二值化图像 + """ + if max_area is None or max_area <= 0: + return binary_img + + if CV2_AVAILABLE: + binary_for_label = (binary_img == foreground).astype(np.uint8) + num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats( + binary_for_label, connectivity=8 + ) + + if num_features == 0: + return binary_img + + component_sizes = stats[1:, cv2.CC_STAT_AREA] + keep_labels = np.where(component_sizes <= max_area)[0] + 1 + + keep_mask = np.isin(labeled_array, keep_labels) + filtered = np.zeros_like(binary_img, dtype=binary_img.dtype) + filtered[keep_mask] = foreground + + return filtered + else: + from scipy import ndimage + labeled_array, num_features = ndimage.label( + (binary_img == foreground).astype(np.int32) + ) + + if num_features == 0: + return binary_img + + component_sizes = ndimage.sum( + (labeled_array == i).astype(np.int32), + labeled_array, + range(1, num_features + 1) + ) + + keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area]) + filtered = np.zeros_like(binary_img, dtype=binary_img.dtype) + filtered[keep_mask] = foreground + + return filtered + + +# ============================================================================= +# 岸边缓冲区处理 +# ============================================================================= + +def create_shoreline_buffer( + water_mask: np.ndarray, + buffer_size: int = 5, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 创建岸边缓冲区掩膜(向内缓冲) + 用于去除岸边附近的错误耀斑检测区域 + + 方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域 + + Args: + water_mask: 水域掩膜数组(水域=1,非水域=0) + buffer_size: 缓冲区大小(像素数),默认5像素 + foreground: 前景值 + background: 背景值 + + Returns: + 岸边缓冲区掩膜(缓冲区区域=1,其他=0) + """ + if buffer_size <= 0: + return np.zeros_like(water_mask, dtype=np.int32) + + water_binary = (water_mask > 0).astype(np.uint8) + structure_size = buffer_size * 2 + 1 + structure = np.ones((structure_size, structure_size), dtype=np.uint8) + + if CV2_AVAILABLE: + eroded_water = cv2.erode(water_binary, structure).astype(np.int32) + else: + from scipy import ndimage + eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32) + + buffer_mask = (water_binary - eroded_water).astype(np.int32) + + return buffer_mask + + +def remove_shoreline_buffer( + glint_mask: np.ndarray, + water_mask: np.ndarray, + buffer_size: int = 5, + foreground: int = 1, + background: int = 0 +) -> np.ndarray: + """ + 从耀斑掩膜中去除岸边缓冲区内的区域 + + Args: + glint_mask: 耀斑掩膜数组 + water_mask: 水域掩膜数组 + buffer_size: 缓冲区大小(像素数),默认5像素 + foreground: 前景值 + background: 背景值 + + Returns: + 去除岸边缓冲区后的耀斑掩膜 + """ + if buffer_size <= 0: + return glint_mask + + buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background) + + cleaned = glint_mask.copy() + cleaned[buffer_mask > 0] = background + + return cleaned + + +# ============================================================================= +# 高级组合函数 +# ============================================================================= + +def calculate_glint_mask( + nir_band: np.ndarray, + water_mask: np.ndarray, + method: str = 'otsu', + z_threshold: float = 2.5, + percentile: float = 95, + iqr_multiplier: float = 1.5, + window_size: int = 15, + apply_percentile_stretch: bool = True +) -> np.ndarray: + """ + 计算耀斑掩膜的统一入口函数 + + Args: + nir_band: 近红外波段数组 + water_mask: 水域掩膜 + method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive') + z_threshold: Z-score阈值 + percentile: 百分位数阈值 + iqr_multiplier: IQR倍数 + window_size: 自适应阈值窗口大小 + apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸 + + Returns: + 二值化耀斑掩膜 + """ + if method == 'otsu': + if apply_percentile_stretch: + stretched = percentile_stretch(nir_band, water_mask, 2, 98) + return otsu_threshold(stretched, water_mask) + else: + return otsu_threshold(nir_band.astype(np.int32), water_mask) + elif method == 'zscore': + return zscore_threshold(nir_band, water_mask, z_threshold) + elif method == 'percentile': + return percentile_threshold(nir_band, water_mask, percentile) + elif method == 'iqr': + return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier) + elif method == 'adaptive': + if apply_percentile_stretch: + stretched = percentile_stretch(nir_band, water_mask, 2, 98) + return adaptive_threshold(stretched, water_mask, window_size, percentile) + else: + return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile) + else: + raise ValueError(f"不支持的方法: {method}") diff --git a/src/core/algorithms/interpolation/__init__.py b/src/core/algorithms/interpolation/__init__.py new file mode 100644 index 0000000..6f7fc6f --- /dev/null +++ b/src/core/algorithms/interpolation/__init__.py @@ -0,0 +1,7 @@ +""" +插值算法模块 +包含0值像素插值的核心数学逻辑 +""" +from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch + +__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch'] diff --git a/src/core/algorithms/interpolation/interpolator.py b/src/core/algorithms/interpolation/interpolator.py new file mode 100644 index 0000000..3d1f370 --- /dev/null +++ b/src/core/algorithms/interpolation/interpolator.py @@ -0,0 +1,320 @@ +""" +像素插值算法模块 + +提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。 +支持多种插值方法:nearest, bilinear, spline (RBF), kriging。 +""" + +import numpy as np +from typing import Optional, Union, Tuple, List +from pathlib import Path + +try: + from scipy import ndimage + from scipy.interpolate import griddata, RBFInterpolator + from scipy.spatial import cKDTree + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + +try: + from osgeo import gdal + GDAL_AVAILABLE = True +except ImportError: + GDAL_AVAILABLE = False + + +def interpolate_pixels( + image_stack: np.ndarray, + zero_coords: np.ndarray, + valid_coords: np.ndarray, + valid_values: np.ndarray, + interpolation_method: str = 'nearest', + water_mask: Optional[np.ndarray] = None +) -> np.ndarray: + """ + 对指定坐标的像素进行插值(核心数学函数,不涉及文件I/O) + + Args: + image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组 + zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y] + valid_coords: 有效像素坐标,形状为 (n_valid, 2) + valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands) + interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging' + water_mask: 可选的水域掩膜数组 + + Returns: + 插值后的影像副本,形状与 image_stack 相同 + """ + if not SCIPY_AVAILABLE: + raise ImportError("scipy未安装,无法进行0值像素插值") + + height, width, n_bands = image_stack.shape + result = image_stack.copy() + + # 兼容中文和各种格式的method参数 + raw_method = str(interpolation_method).lower() + if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method: + method = 'nearest' + elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method: + method = 'bilinear' + elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method: + method = 'spline' + elif 'kriging' in raw_method or '克里金' in raw_method: + method = 'kriging' + else: + method = 'nearest' + + if len(valid_values) == 0: + return result + + is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1 + + if is_multiband: + for band_idx in range(n_bands): + band_valid_values = valid_values[:, band_idx] + interpolated_values = _interpolate_single_band( + zero_coords, valid_coords, band_valid_values, method + ) + y_coords = zero_coords[:, 1].astype(int) + x_coords = zero_coords[:, 0].astype(int) + result[y_coords, x_coords, band_idx] = interpolated_values + else: + interpolated_values = _interpolate_single_band( + zero_coords, valid_coords, valid_values, method + ) + y_coords = zero_coords[:, 1].astype(int) + x_coords = zero_coords[:, 0].astype(int) + result[y_coords, x_coords] = interpolated_values + + return result + + +def _interpolate_single_band( + zero_coords: np.ndarray, + valid_coords: np.ndarray, + valid_values: np.ndarray, + method: str +) -> np.ndarray: + """对单个波段执行插值计算""" + if method == 'nearest': + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords) + return valid_values[indices] + + elif method == 'bilinear': + interpolated = griddata( + valid_coords, valid_values, zero_coords, + method='linear', fill_value=0.0 + ) + nan_mask = np.isnan(interpolated) + if np.any(nan_mask): + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords[nan_mask]) + interpolated[nan_mask] = valid_values[indices] + return interpolated + + elif method == 'spline': + try: + max_points = 10000 + if len(valid_values) > max_points: + indices = np.random.choice(len(valid_values), max_points, replace=False) + sample_coords = valid_coords[indices] + sample_values = valid_values[indices] + else: + sample_coords = valid_coords + sample_values = valid_values + rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline') + interpolated = rbf(zero_coords) + nan_mask = np.isnan(interpolated) + if np.any(nan_mask): + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords[nan_mask]) + interpolated[nan_mask] = valid_values[indices] + return interpolated + except Exception: + interpolated = griddata( + valid_coords, valid_values, zero_coords, + method='linear', fill_value=0.0 + ) + nan_mask = np.isnan(interpolated) + if np.any(nan_mask): + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords[nan_mask]) + interpolated[nan_mask] = valid_values[indices] + return interpolated + + elif method == 'kriging': + try: + from src.utils.kriging import KrigingInterpolator + interpolator = KrigingInterpolator() + max_points = 5000 + if len(valid_values) > max_points: + indices = np.random.choice(len(valid_values), max_points, replace=False) + sample_coords = valid_coords[indices] + sample_values = valid_values[indices] + else: + sample_coords = valid_coords + sample_values = valid_values + interpolated = griddata( + sample_coords, sample_values, zero_coords, + method='cubic', fill_value=0.0 + ) + nan_mask = np.isnan(interpolated) + if np.any(nan_mask): + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords[nan_mask]) + interpolated[nan_mask] = valid_values[indices] + return interpolated + except Exception: + interpolated = griddata( + valid_coords, valid_values, zero_coords, + method='linear', fill_value=0.0 + ) + nan_mask = np.isnan(interpolated) + if np.any(nan_mask): + tree = cKDTree(valid_coords) + _, indices = tree.query(zero_coords[nan_mask]) + interpolated[nan_mask] = valid_values[indices] + return interpolated + + return np.zeros(len(zero_coords)) + + +def interpolate_zero_pixels_batch( + img_path: str, + interpolation_method: str = 'nearest', + output_path: Optional[str] = None, + water_mask: Optional[Union[str, np.ndarray]] = None, + deglint_dir: Optional[str] = None, + callback_progress: Optional[callable] = None +) -> Tuple[str, Optional[np.ndarray]]: + """ + 对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O) + + Args: + img_path: 输入影像文件路径 + interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging' + output_path: 输出文件路径(如果为None,自动生成) + water_mask: 水域掩膜(文件路径或数组) + deglint_dir: 去耀斑目录(用于生成默认输出路径) + callback_progress: 进度回调函数 + + Returns: + (output_path, interpolated_image_stack) 元组 + """ + if not SCIPY_AVAILABLE: + raise ImportError("scipy未安装,无法进行0值像素插值") + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取影像文件") + + # 确定输出路径 + if output_path is None and deglint_dir is not None: + output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq") + + # 检查文件是否已存在 + if output_path and Path(output_path).exists(): + return output_path, None + + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + width = dataset.RasterXSize + height = dataset.RasterYSize + n_bands = dataset.RasterCount + geotransform = dataset.GetGeoTransform() + projection = dataset.GetProjection() + + # 读取所有波段数据 + all_bands = [] + for band_idx in range(1, n_bands + 1): + band = dataset.GetRasterBand(band_idx) + band_data = band.ReadAsArray().astype(np.float32) + all_bands.append(band_data) + + image_stack = np.dstack(all_bands) + + # 读取水域掩膜 + mask_array = None + if water_mask is not None: + if isinstance(water_mask, str): + mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) + if mask_dataset: + mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() + mask_dataset = None + elif isinstance(water_mask, np.ndarray): + mask_array = water_mask + + # 找出所有波段都为0的像素点 + all_bands_zero = np.all(image_stack == 0, axis=2) + + if mask_array is not None: + all_bands_zero = all_bands_zero & (mask_array > 0) + + zero_pixel_count = np.sum(all_bands_zero) + if zero_pixel_count == 0: + # 无需插值,直接保存 + if output_path: + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) + out_dataset.SetGeoTransform(geotransform) + out_dataset.SetProjection(projection) + for i, band_data in enumerate(all_bands): + out_band = out_dataset.GetRasterBand(i + 1) + out_band.WriteArray(band_data) + out_band.FlushCache() + out_dataset = None + return output_path, image_stack + + # 获取坐标 + zero_y, zero_x = np.where(all_bands_zero) + zero_coords = np.column_stack([zero_x, zero_y]) + + valid_mask = ~all_bands_zero + valid_y, valid_x = np.where(valid_mask) + valid_coords = np.column_stack([valid_x, valid_y]) + + if len(valid_coords) == 0: + raise ValueError("没有有效像素可用于插值") + + # 逐波段插值 + interpolated_bands = [] + for band_idx in range(n_bands): + if callback_progress: + callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...") + band_data = all_bands[band_idx].copy() + valid_values_band = band_data[valid_mask] + + if len(valid_values_band) == 0: + interpolated_bands.append(band_data) + continue + + band_result = _interpolate_single_band( + zero_coords, valid_coords, valid_values_band, interpolation_method + ) + band_data[all_bands_zero] = band_result + interpolated_bands.append(band_data) + + # 保存结果 + if output_path: + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) + out_dataset.SetGeoTransform(geotransform) + out_dataset.SetProjection(projection) + for i, band_data in enumerate(interpolated_bands): + out_band = out_dataset.GetRasterBand(i + 1) + out_band.WriteArray(band_data) + out_band.FlushCache() + out_dataset = None + + result_stack = np.dstack(interpolated_bands) + return output_path, result_stack + + finally: + dataset = None diff --git a/src/core/utils/__init__.py b/src/core/utils/__init__.py new file mode 100644 index 0000000..6b87817 --- /dev/null +++ b/src/core/utils/__init__.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +""" +工具模块 - 统一导出接口 +""" +from src.core.utils.gdal_helper import ( + get_image_geo_info, + load_image_as_array, + save_array_as_image, + save_bands_as_image, + copy_hdr_info, + read_band_as_array, + read_multiple_bands, +) +from src.core.utils.mask_converter import ( + prepare_water_mask_for_algorithm, + ensure_water_mask_dat, +) +from src.core.utils.preview_generator import ( + generate_image_preview, + generate_water_mask_overlay, + select_rgb_bands_by_wavelength, + get_wavelength_info, +) + +__all__ = [ + # GDAL IO + 'get_image_geo_info', + 'load_image_as_array', + 'save_array_as_image', + 'save_bands_as_image', + 'copy_hdr_info', + 'read_band_as_array', + 'read_multiple_bands', + # 掩膜转换 + 'prepare_water_mask_for_algorithm', + 'ensure_water_mask_dat', + # 预览图生成 + 'generate_image_preview', + 'generate_water_mask_overlay', + 'select_rgb_bands_by_wavelength', + 'get_wavelength_info', +] \ No newline at end of file diff --git a/src/core/utils/gdal_helper.py b/src/core/utils/gdal_helper.py new file mode 100644 index 0000000..abc672c --- /dev/null +++ b/src/core/utils/gdal_helper.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +GDAL 底层 IO 工具模块 + +提供遥感影像读写、格式转换等底层 GDAL 操作功能。 +这些函数不依赖任何业务逻辑,可在其他项目中独立复用。 +""" +import os +from pathlib import Path +from typing import Tuple, Optional + +import numpy as np + +# GDAL 导入(可选) +try: + from osgeo import gdal, ogr, gdal_array + GDAL_AVAILABLE = True +except ImportError: + GDAL_AVAILABLE = False + +# hdr 文件工具 +try: + from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path + UTIL_AVAILABLE = True +except ImportError: + UTIL_AVAILABLE = False + write_fields_to_hdrfile = None + get_hdr_file_path = None + + +# ============================================================ +# 影像信息读取 +# ============================================================ + +def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]: + """ + 获取影像的地理信息(不加载图像数据,节省内存) + + Args: + img_path: 影像文件路径 + + Returns: + tuple: (geotransform, projection, width, height, n_bands) + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取影像文件") + + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + width = dataset.RasterXSize + height = dataset.RasterYSize + n_bands = dataset.RasterCount + geotransform = dataset.GetGeoTransform() + projection = dataset.GetProjection() + return geotransform, projection, width, height, n_bands + finally: + dataset = None + + +def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]: + """ + 加载影像文件为numpy数组 + + 注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。 + 建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。 + + Args: + img_path: 影像文件路径 + + Returns: + tuple: (image_array, geotransform, projection) + image_array: numpy数组,形状为(height, width, bands) + geotransform: 地理变换参数 + projection: 投影信息 + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取影像文件") + + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + width = dataset.RasterXSize + height = dataset.RasterYSize + n_bands = dataset.RasterCount + geotransform = dataset.GetGeoTransform() + projection = dataset.GetProjection() + + image_bands = [] + for i in range(1, n_bands + 1): + band = dataset.GetRasterBand(i) + band_data = band.ReadAsArray() + image_bands.append(band_data) + + image_array = np.dstack(image_bands) + return image_array, geotransform, projection + finally: + dataset = None + + +def read_band_as_array(img_path: str, band_index: int) -> np.ndarray: + """ + 读取单个波段为 numpy 数组 + + Args: + img_path: 影像文件路径 + band_index: 波段索引(从 0 开始) + + Returns: + numpy 数组,形状为 (height, width) + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取影像文件") + + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + band = dataset.GetRasterBand(band_index + 1) + return band.ReadAsArray() + finally: + dataset = None + + +def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]: + """ + 读取多个指定波段为列表 + + Args: + img_path: 影像文件路径 + band_indices: 波段索引列表 + + Returns: + tuple: (band_list, geotransform, projection) + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取影像文件") + + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + geotransform = dataset.GetGeoTransform() + projection = dataset.GetProjection() + bands = [] + for idx in band_indices: + band = dataset.GetRasterBand(idx + 1) + bands.append(band.ReadAsArray()) + return bands, geotransform, projection + finally: + dataset = None + + +# ============================================================ +# 影像写入 +# ============================================================ + +def save_array_as_image(image_array: np.ndarray, output_path: str, + geotransform: tuple, projection: str, + dtype=None) -> str: + """ + 将numpy数组保存为影像文件 + + Args: + image_array: numpy数组,形状为(height, width, bands) 或 (height, width) + output_path: 输出文件路径 + geotransform: 地理变换参数 + projection: 投影信息 + dtype: GDAL数据类型(默认 gdal.GDT_Float32) + + Returns: + 输出文件路径 + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法保存影像文件") + + if dtype is None: + dtype = gdal.GDT_Float32 + + if image_array.ndim == 2: + height, width = image_array.shape + n_bands = 1 + else: + height, width, n_bands = image_array.shape + + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + + if driver is None: + raise ValueError("无法创建影像文件,没有可用的驱动") + + dataset = driver.Create(output_path, width, height, n_bands, dtype) + if dataset is None: + raise ValueError(f"无法创建输出文件: {output_path}") + + try: + dataset.SetGeoTransform(geotransform) + dataset.SetProjection(projection) + + if n_bands == 1: + band = dataset.GetRasterBand(1) + band.WriteArray(image_array) + band.FlushCache() + else: + for i in range(n_bands): + band = dataset.GetRasterBand(i + 1) + band.WriteArray(image_array[:, :, i]) + band.FlushCache() + finally: + dataset = None + + return output_path + + +def save_bands_as_image(corrected_bands: list, output_path: str, + geotransform: tuple, projection: str, + dtype=None) -> str: + """ + 直接从波段列表保存影像文件(避免堆叠,节省内存) + + Args: + corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组 + output_path: 输出文件路径 + geotransform: 地理变换参数 + projection: 投影信息 + dtype: GDAL数据类型 + + Returns: + 输出文件路径 + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法保存影像文件") + + if not corrected_bands: + raise ValueError("波段列表为空") + + if dtype is None: + dtype = gdal.GDT_Float32 + + n_bands = len(corrected_bands) + height, width = corrected_bands[0].shape + + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + + if driver is None: + raise ValueError("无法创建影像文件,没有可用的驱动") + + dataset = driver.Create(output_path, width, height, n_bands, dtype) + if dataset is None: + raise ValueError(f"无法创建输出文件: {output_path}") + + try: + dataset.SetGeoTransform(geotransform) + dataset.SetProjection(projection) + + for i, band_array in enumerate(corrected_bands): + if band_array.shape != (height, width): + raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配") + band = dataset.GetRasterBand(i + 1) + band.WriteArray(band_array) + band.FlushCache() + finally: + dataset = None + + return output_path + + +def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool: + """ + 复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件 + + Args: + source_img_path: 源影像文件路径(原始bsq文件) + dest_img_path: 目标影像文件路径(去耀斑后的bsq文件) + + Returns: + bool: 是否成功 + """ + if not UTIL_AVAILABLE: + print("警告: util模块未导入,无法复制hdr文件信息") + return False + + try: + source_hdr_path = get_hdr_file_path(source_img_path) + dest_hdr_path = get_hdr_file_path(dest_img_path) + + if not Path(source_hdr_path).exists(): + print(f"警告: 源hdr文件不存在: {source_hdr_path}") + return False + + if not Path(dest_hdr_path).exists(): + print(f"警告: 目标hdr文件不存在: {dest_hdr_path}") + return False + + write_fields_to_hdrfile(source_hdr_path, dest_hdr_path) + print(f"已复制原始hdr文件信息到: {dest_hdr_path}") + return True + except Exception as e: + print(f"警告: 复制hdr文件信息时出错: {e}") + return False \ No newline at end of file diff --git a/src/core/utils/mask_converter.py b/src/core/utils/mask_converter.py new file mode 100644 index 0000000..72da6e0 --- /dev/null +++ b/src/core/utils/mask_converter.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +""" +掩膜转换工具模块 + +提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换, +以及水体掩膜的预处理逻辑。 +""" +import os +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +try: + from osgeo import gdal, ogr + GDAL_AVAILABLE = True +except ImportError: + GDAL_AVAILABLE = False + + +def prepare_water_mask_for_algorithm( + water_mask: Optional[Union[str, np.ndarray]], + image_shape: Union[tuple, np.ndarray], + geotransform: tuple, + projection: str, + img_path: str, + water_mask_dir: Optional[str] = None, + callback=None +) -> Optional[np.ndarray]: + """ + 准备水域掩膜供算法使用 + + 支持格式: + - None:自动使用预先生成的 dat 格式掩膜 + - numpy.ndarray:直接返回(确保是 0/1 格式) + - .dat / .tif 等栅格文件:读取并返回 + - .shp 文件:先栅格化,再读取返回 + + Args: + water_mask: 掩膜来源 + image_shape: 影像形状 (height, width) 或 (height, width, channels) + geotransform: GDAL 地理变换参数 + projection: 投影信息 + img_path: 影像路径(用于 shp 栅格化) + water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果) + callback: 进度回调函数(可选) + + Returns: + numpy数组(dtype=uint8,0=非水域,1=水域)或 None + """ + img_height, img_width = image_shape[0], image_shape[1] + + if water_mask is None: + return None + + # numpy 数组直接返回 + if isinstance(water_mask, np.ndarray): + if water_mask.shape[:2] != (img_height, img_width): + raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配") + return (water_mask > 0).astype(np.uint8) + + # 字符串路径 + if isinstance(water_mask, str): + ext = Path(water_mask).suffix.lower() + + # shapefile 格式 + if ext == '.shp': + return _convert_shp_to_mask( + shp_path=water_mask, + img_path=img_path, + image_shape=image_shape, + geotransform=geotransform, + projection=projection, + water_mask_dir=water_mask_dir, + callback=callback + ) + + # 栅格文件格式 + return _load_raster_mask(water_mask, img_height, img_width) + + raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") + + +def _convert_shp_to_mask(shp_path: str, img_path: str, + image_shape: tuple, + geotransform: tuple, + projection: str, + water_mask_dir: Optional[str] = None, + callback=None) -> np.ndarray: + """将 shapefile 栅格化为掩膜数组""" + from src.utils.extract_water_area import rasterize_shp + + safe_shp_path = os.path.abspath(shp_path).replace('\\', '/') + shp_name = Path(safe_shp_path).stem + + if water_mask_dir: + temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat") + else: + temp_mask_path = f"/tmp/water_mask_{shp_name}.dat" + + # 缓存:已栅格化则直接读取 + if Path(temp_mask_path).exists(): + print(f"使用已存在的栅格化掩膜: {temp_mask_path}") + return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1]) + + # 需要栅格化 + if img_path is None: + raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化") + + print(f"正在将 SHP 栅格化: {safe_shp_path}") + rasterize_shp(safe_shp_path, temp_mask_path, img_path) + + return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1]) + + +def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray: + """从栅格文件加载掩膜""" + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法读取掩膜文件") + + mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly) + if mask_dataset is None: + raise ValueError(f"无法打开掩膜文件: {mask_path}") + + try: + mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() + finally: + mask_dataset = None + + if mask_array.shape != (img_height, img_width): + raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配") + + return (mask_array > 0).astype(np.uint8) + + +def ensure_water_mask_dat(img_path: str, + existing_dat_path: Optional[str] = None, + output_dir: Optional[str] = None) -> str: + """ + 确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法) + + 如果 existing_dat_path 存在且是 .dat 文件,直接返回。 + 如果存在同名 .dat 文件,直接返回。 + 否则从 img_path 生成并保存到 output_dir。 + + Args: + img_path: 用于生成掩膜的遥感影像路径 + existing_dat_path: 已有的 dat 格式掩膜路径(可选) + output_dir: 输出目录(可选) + + Returns: + dat 格式掩膜文件路径 + """ + if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat': + if Path(existing_dat_path).exists(): + return existing_dat_path + + img_name = Path(img_path).stem + if output_dir is None: + output_dir = str(Path(img_path).parent) + + dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat") + + if Path(dat_path).exists(): + return dat_path + + # 如果已有其他格式的掩膜,转换为 dat + for ext in ['.tif', '.img', '.tiff']: + alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}") + if Path(alt_path).exists(): + return _convert_to_dat(alt_path, dat_path) + + return dat_path # 返回目标路径,让调用方决定是否需要生成 + + +def _convert_to_dat(src_path: str, dest_path: str) -> str: + """将其他栅格格式转换为 ENVI dat 格式""" + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法转换格式") + + src_ds = gdal.Open(src_path, gdal.GA_ReadOnly) + if src_ds is None: + raise ValueError(f"无法打开源掩膜文件: {src_path}") + + try: + geotransform = src_ds.GetGeoTransform() + projection = src_ds.GetProjection() + band = src_ds.GetRasterBand(1) + array = band.ReadAsArray() + + driver = gdal.GetDriverByName('ENVI') + if driver is None: + driver = gdal.GetDriverByName('GTiff') + + dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte) + if dest_ds is None: + raise ValueError(f"无法创建输出文件: {dest_path}") + + try: + dest_ds.SetGeoTransform(geotransform) + dest_ds.SetProjection(projection) + dest_band = dest_ds.GetRasterBand(1) + dest_band.WriteArray((array > 0).astype(np.uint8)) + dest_band.FlushCache() + finally: + dest_ds = None + + return dest_path + finally: + src_ds = None \ No newline at end of file diff --git a/src/core/utils/preview_generator.py b/src/core/utils/preview_generator.py new file mode 100644 index 0000000..168be01 --- /dev/null +++ b/src/core/utils/preview_generator.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- +""" +遥感影像预览图生成工具模块 + +提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。 +""" +import numpy as np +from pathlib import Path +from typing import Optional, List + +try: + from osgeo import gdal + GDAL_AVAILABLE = True +except ImportError: + GDAL_AVAILABLE = False + +# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具) +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS'] +plt.rcParams['axes.unicode_minus'] = False + + +# ============================================================ +# 辅助函数:波段选择 +# ============================================================ + +def select_rgb_bands_by_wavelength(band_count: int, + wavelength_info: Optional[List[float]] = None, + fallback_bands: Optional[List[int]] = None) -> List[int]: + """ + 根据波长自动选择 RGB 波段 + + Args: + band_count: 总波段数 + wavelength_info: 各波段波长列表(nm),长度为 band_count + fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B] + + Returns: + 波段索引列表 [R_index, G_index, B_index](0-based) + """ + if fallback_bands is None: + fallback_bands = [band_count - 3, band_count - 2, band_count - 1] + + if wavelength_info is None: + return [max(0, min(i, band_count - 1)) for i in fallback_bands] + + # 目标波长(nm) + TARGET_R = 650 + TARGET_G = 550 + TARGET_B = 460 + + def find_closest(target: float) -> int: + min_dist = float('inf') + best_idx = 0 + for i, wl in enumerate(wavelength_info): + dist = abs(wl - target) + if dist < min_dist: + min_dist = dist + best_idx = i + return best_idx + + try: + r_idx = find_closest(TARGET_R) + g_idx = find_closest(TARGET_G) + b_idx = find_closest(TARGET_B) + return [r_idx, g_idx, b_idx] + except Exception: + return [max(0, min(i, band_count - 1)) for i in fallback_bands] + + +def get_wavelength_info(img_path: str) -> Optional[List[float]]: + """从 hdr 文件读取波长信息""" + try: + hdr_path = Path(img_path).with_suffix('.hdr') + if not hdr_path.exists(): + return None + + wavelengths = [] + in_wl = False + with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f: + for line in f: + line = line.strip() + if line.startswith('wavelength ='): + in_wl = True + line = line.split('=', 1)[1].strip() + elif in_wl: + if line.startswith('{'): + line = line[1:] + if line.endswith('}'): + line = line[:-1] + in_wl = False + # 解析逗号分隔的数值 + for token in line.replace(',', ' ').split(): + try: + wavelengths.append(float(token)) + except ValueError: + pass + return wavelengths if wavelengths else None + except Exception: + return None + + +# ============================================================ +# 核心预览图生成函数 +# ============================================================ + +def generate_image_preview(img_path: str, + output_path: str, + bands: Optional[List[int]] = None, + title: str = "影像预览") -> str: + """ + 生成高光谱影像的 PNG 预览图 + + Args: + img_path: 输入影像路径 + output_path: 输出 PNG 文件路径 + bands: RGB 波段索引 [R, G, B],None 则自动选择 + title: 图片标题 + + Returns: + 生成的 PNG 文件路径 + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法生成影像预览图") + + if Path(output_path).exists(): + print(f"检测到已存在的预览图,跳过生成: {output_path}") + return output_path + + dataset = gdal.Open(img_path) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + width = dataset.RasterXSize + height = dataset.RasterYSize + band_count = dataset.RasterCount + geotransform = dataset.GetGeoTransform() + + # 自动选择波段 + if bands is None: + if band_count >= 3: + wl_info = get_wavelength_info(img_path) + bands = select_rgb_bands_by_wavelength(band_count, wl_info) + else: + bands = [0, 0, 0] + + # 读取波段 + r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32) + g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32) + b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32) + + r_data[r_data <= 0] = np.nan + if band_count > 1: + g_data[g_data <= 0] = np.nan + if band_count > 2: + b_data[b_data <= 0] = np.nan + + # 线性拉伸 + def linear_stretch(data, low=2, high=98): + valid = data[~np.isnan(data)] + if len(valid) == 0: + return np.zeros_like(data) + lo = np.percentile(valid, low) + hi = np.percentile(valid, high) + if hi - lo < 1e-10: + return np.zeros_like(data) + stretched = np.clip((data - lo) / (hi - lo), 0, 1) + return np.nan_to_num(stretched, nan=0.0) + + r_s = linear_stretch(r_data) + g_s = linear_stretch(g_data) if band_count > 1 else r_s + b_s = linear_stretch(b_data) if band_count > 2 else r_s + + rgb_image = np.stack([r_s, g_s, b_s], axis=2) + + # 绘图 + fig, ax = plt.subplots(figsize=(12, 10)) + ax.imshow(rgb_image) + ax.set_title(title, fontsize=12, fontweight='bold') + ax.axis('off') + + geotransform = dataset.GetGeoTransform() + if geotransform and geotransform[1] != 0: + pixel_size_x = abs(geotransform[1]) + scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px" + fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9, + color='white', + bbox=dict(facecolor='black', alpha=0.6, + boxstyle='round,pad=0.3')) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1) + plt.close(fig) + + return output_path + + finally: + dataset = None + + +def generate_water_mask_overlay(img_path: str, + mask_path: str, + output_path: str, + bands: Optional[List[int]] = None, + mask_color: tuple = (0, 100, 255), + mask_alpha: float = 0.5) -> str: + """ + 生成水域掩膜叠加到原图的 PNG 图像 + + Args: + img_path: 输入影像路径 + mask_path: 水域掩膜文件路径 + output_path: 输出 PNG 路径 + bands: RGB 波段索引,None 则自动选择 + mask_color: 掩膜叠加颜色 (R, G, B) + mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明) + + Returns: + 生成的 PNG 文件路径 + """ + if not GDAL_AVAILABLE: + raise ImportError("GDAL未安装,无法生成叠加图") + + if Path(output_path).exists(): + print(f"检测到已存在的叠加图,跳过生成: {output_path}") + return output_path + + dataset = gdal.Open(img_path) + if dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + + try: + width = dataset.RasterXSize + height = dataset.RasterYSize + band_count = dataset.RasterCount + geotransform = dataset.GetGeoTransform() + + # 自动选择波段 + if bands is None: + if band_count >= 3: + wl_info = get_wavelength_info(img_path) + bands = select_rgb_bands_by_wavelength(band_count, wl_info) + else: + bands = [0, 0, 0] + + r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32) + g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32) + b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32) + + r_data[r_data <= 0] = np.nan + if band_count > 1: + g_data[g_data <= 0] = np.nan + if band_count > 2: + b_data[b_data <= 0] = np.nan + + def linear_stretch(data, low=2, high=98): + valid = data[~np.isnan(data)] + if len(valid) == 0: + return np.zeros_like(data) + lo = np.percentile(valid, low) + hi = np.percentile(valid, high) + if hi - lo < 1e-10: + return np.zeros_like(data) + stretched = np.clip((data - lo) / (hi - lo), 0, 1) + return np.nan_to_num(stretched, nan=0.0) + + r_s = linear_stretch(r_data) + g_s = linear_stretch(g_data) if band_count > 1 else r_s + b_s = linear_stretch(b_data) if band_count > 2 else r_s + + rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255 + rgb_image = rgb_image.astype(np.uint8) + + # 读取掩膜 + mask_dataset = gdal.Open(mask_path) + if mask_dataset is not None: + mask_data = mask_dataset.GetRasterBand(1).ReadAsArray() + mask_dataset = None + else: + print(f"警告: 无法打开掩膜文件: {mask_path}") + mask_data = None + + # Alpha 混合 + overlay = np.zeros((height, width, 4), dtype=np.uint8) + overlay[:, :, 0:3] = mask_color + overlay[:, :, 3] = 255 # 全不透明 + + blended = rgb_image.astype(np.float32) + if mask_data is not None: + alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha + for c in range(3): + blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha + blended = blended.astype(np.uint8) + + # 绘图 + fig, ax = plt.subplots(figsize=(14, 10)) + ax.imshow(blended) + ax.axis('off') + + legend_elements = [ + Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}', + edgecolor='black', alpha=mask_alpha, label='水域范围') + ] + ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9) + + # 面积计算 + if geotransform and geotransform[1] != 0: + pixel_size_x = abs(geotransform[1]) + pixel_size_y = abs(geotransform[5]) + pixel_area = pixel_size_x * pixel_size_y + + if mask_data is not None: + water_pixels = np.sum(mask_data > 0) + valid_pixels = np.sum(mask_data >= 0) + water_km2 = water_pixels * pixel_area / 1_000_000 + valid_km2 = valid_pixels * pixel_area / 1_000_000 + pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0 + + area_text = (f'水域面积: {water_km2:.2f} km² | ' + f'影像总面积: {valid_km2:.2f} km² | ' + f'占比: {pct:.1f}%') + ax.text(0.02, 0.98, area_text, + transform=ax.transAxes, fontsize=11, + color='white', fontweight='bold', + bbox=dict(facecolor='#0064FF', alpha=0.8, + edgecolor='black', boxstyle='round,pad=0.5'), + verticalalignment='top') + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1) + plt.close(fig) + + return output_path + + finally: + dataset = None \ No newline at end of file diff --git a/src/core/visualization/__init__.py b/src/core/visualization/__init__.py new file mode 100644 index 0000000..4589d01 --- /dev/null +++ b/src/core/visualization/__init__.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 统一导出接口 + +本模块从各子模块导入可视化函数,提供统一的导出接口。 +""" +from src.core.visualization.scatter_plot import generate_model_scatter_plots +from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots +from src.core.visualization.boxplot import generate_boxplots +from src.core.visualization.statistics import generate_statistical_charts +from src.core.visualization.preview import generate_glint_deglint_previews +from src.core.visualization.report import generate_pipeline_report + +__all__ = [ + 'generate_model_scatter_plots', + 'generate_spectrum_comparison_plots', + 'generate_boxplots', + 'generate_statistical_charts', + 'generate_glint_deglint_previews', + 'generate_pipeline_report', +] \ No newline at end of file diff --git a/src/core/visualization/boxplot.py b/src/core/visualization/boxplot.py new file mode 100644 index 0000000..5b0c270 --- /dev/null +++ b/src/core/visualization/boxplot.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 箱型图生成 +""" +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from pathlib import Path +from typing import Optional, Dict, List + +sns.set_style("whitegrid") +plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS'] +plt.rcParams['axes.unicode_minus'] = False + + +def generate_boxplots( + csv_path: str, + parameter_columns: Optional[List[str]] = None, + data_start_column: int = 4, + save_individual: bool = True, + use_seaborn: bool = True, + output_dir: Optional[str] = None +) -> Dict[str, str]: + """ + 生成水质参数的箱型图 + + Args: + csv_path: CSV文件路径 + parameter_columns: 参数列名列表(如果为None,自动检测) + data_start_column: 数据开始列索引(从第几列开始,默认第5列,索引为4) + save_individual: 是否为每个参数单独保存箱型图 + use_seaborn: 是否使用seaborn绘制(更美观) + output_dir: 输出目录(None则使用默认) + + Returns: + 箱型图文件路径字典 + """ + print("\n" + "="*80) + print("生成水质参数箱型图") + print("="*80) + + if csv_path is None: + raise ValueError("请提供 csv_path") + + # 确定输出目录 + if output_dir is None: + csv_dir = Path(csv_path).parent + output_dir = str(csv_dir / "visualization" / "boxplots") + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # 读取数据 + df = pd.read_csv(csv_path) + + # 确定参数列 + if parameter_columns is None: + data_columns = df.iloc[:, data_start_column:] + parameter_columns = list(data_columns.columns) + else: + parameter_columns = [col for col in parameter_columns if col in df.columns] + + if not parameter_columns: + print("警告: 未找到有效的参数列") + return {} + + boxplot_dir = Path(output_dir) + boxplot_paths = {} + + if save_individual: + print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)") + + for column in parameter_columns: + if column not in df.columns: + continue + + clean_data = df[column].dropna() + + if len(clean_data) == 0: + print(f"跳过列 '{column}': 没有有效数据") + continue + + try: + plt.figure(figsize=(8, 6)) + + if use_seaborn: + plot_data = pd.DataFrame({ + '参数': [column] * len(clean_data), + '数值': clean_data + }) + sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2') + sns.stripplot(data=plot_data, x='参数', y='数值', + color='red', alpha=0.6, size=5, jitter=True) + else: + box_plot = plt.boxplot([clean_data], labels=[column], + patch_artist=True, showfliers=False) + box_plot['boxes'][0].set_facecolor('lightblue') + box_plot['boxes'][0].set_alpha(0.7) + + x_pos = np.random.normal(1, 0.04, size=len(clean_data)) + plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red', + edgecolors='black', linewidth=0.5, zorder=3) + + plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold') + plt.xlabel('参数', fontsize=12) + plt.ylabel('数值', fontsize=12) + + stats_text = (f'数据点数: {len(clean_data)}\n' + f'均值: {clean_data.mean():.2f}\n' + f'中位数: {clean_data.median():.2f}\n' + f'标准差: {clean_data.std():.2f}') + plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes, + verticalalignment='top', + bbox=dict(boxstyle='round', + facecolor='wheat' if not use_seaborn else 'lightgreen', + alpha=0.8)) + + plt.grid(True, alpha=0.3, linestyle='--') + plt.tight_layout() + + safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_') + save_path = boxplot_dir / f'{safe_column_name}_boxplot.png' + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + boxplot_paths[column] = str(save_path) + print(f" 已保存: {save_path.name}") + + except Exception as e: + print(f" 处理参数 {column} 时出错: {e}") + continue + + # 综合箱型图 + try: + print("\n生成综合箱型图(所有参数在一张图上)") + plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8)) + + box_data = [] + labels = [] + for column in parameter_columns: + if column in df.columns: + clean_data = df[column].dropna() + if len(clean_data) > 0: + box_data.append(clean_data) + labels.append(column) + + if box_data: + if use_seaborn: + melted_data = pd.melt(df[labels], var_name='参数', value_name='数值') + melted_data = melted_data.dropna() + sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3') + sns.stripplot(data=melted_data, x='参数', y='数值', + color='red', alpha=0.6, size=4, jitter=True) + else: + box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False) + colors = plt.cm.Set3(np.linspace(0, 1, len(box_data))) + for patch, color in zip(box_plot['boxes'], colors): + patch.set_facecolor(color) + patch.set_alpha(0.7) + + for i, data in enumerate(box_data): + x_pos = np.random.normal(i + 1, 0.04, size=len(data)) + plt.scatter(x_pos, data, alpha=0.6, s=20, color='red', + edgecolors='black', linewidth=0.5, zorder=3) + + plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold') + plt.xlabel('参数', fontsize=12) + plt.ylabel('数值', fontsize=12) + plt.xticks(rotation=45, ha='right') + plt.grid(True, alpha=0.3, linestyle='--') + plt.tight_layout() + + combined_path = boxplot_dir / 'all_parameters_boxplot.png' + plt.savefig(combined_path, dpi=300, bbox_inches='tight') + plt.close() + + boxplot_paths['all_parameters'] = str(combined_path) + print(f" 已保存综合箱型图: {combined_path.name}") + + except Exception as e: + print(f"生成综合箱型图时出错: {e}") + + print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表") + return boxplot_paths \ No newline at end of file diff --git a/src/core/visualization/preview.py b/src/core/visualization/preview.py new file mode 100644 index 0000000..7f6e597 --- /dev/null +++ b/src/core/visualization/preview.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 耀斑影像预览图生成 +""" +from pathlib import Path +from typing import Optional, Dict + +from src.postprocessing.visualization_reports import WaterQualityVisualization + + +def generate_glint_deglint_previews( + work_dir: str, + output_subdir: str = "glint_deglint_previews", + generate_glint: bool = True, + generate_deglint: bool = True, + output_dir: Optional[str] = None +) -> Dict[str, str]: + """ + 生成2_glint和3_deglint文件夹中影像文件的PNG预览图 + + Args: + work_dir: 工作目录 + output_subdir: 输出子目录名称 + generate_glint: 是否处理2_glint文件夹 + generate_deglint: 是否处理3_deglint文件夹 + output_dir: 输出目录(None则使用默认) + + Returns: + 生成的预览图路径字典 + """ + print(f"\n{'='*70}") + print("步骤: 生成耀斑分析影像预览图") + print(f"{'='*70}") + + if work_dir is None: + raise ValueError("请提供 work_dir") + + # 确定输出目录 + if output_dir is None: + output_dir = str(Path(work_dir) / "visualization" / output_subdir) + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # 实例化可视化器 + visualizer = WaterQualityVisualization(output_dir) + + try: + preview_paths = visualizer.generate_glint_deglint_previews( + work_dir=work_dir, + output_subdir=output_subdir, + generate_glint=generate_glint, + generate_deglint=generate_deglint + ) + + print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图") + return preview_paths + + except Exception as e: + print(f"生成耀斑分析影像预览图时出错: {e}") + return {} \ No newline at end of file diff --git a/src/core/visualization/report.py b/src/core/visualization/report.py new file mode 100644 index 0000000..e8e5c29 --- /dev/null +++ b/src/core/visualization/report.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 流程执行报告生成 +""" +import numpy as np +import pandas as pd +from pathlib import Path +from typing import Optional, Dict +from datetime import datetime + + +def generate_pipeline_report( + step_timings: Dict, + pipeline_start_time: Optional[float] = None, + pipeline_end_time: Optional[float] = None, + output_path: Optional[str] = None +) -> str: + """ + 生成流程执行报告,包含每步的耗时统计 + + Args: + step_timings: 步骤耗时字典(格式:{step_name: {start_time, end_time, elapsed_seconds, elapsed_formatted, status, error}}) + pipeline_start_time: 流程开始时间戳 + pipeline_end_time: 流程结束时间戳 + output_path: 输出文件路径(如果为None,自动生成) + + Returns: + 报告文件路径 + """ + if output_path is None: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv") + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + def _format_time(seconds: float) -> str: + if seconds < 60: + return f"{seconds:.2f}秒" + elif seconds < 3600: + minutes = int(seconds // 60) + secs = seconds % 60 + return f"{minutes}分{secs:.2f}秒" + else: + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = seconds % 60 + return f"{hours}小时{minutes}分{secs:.2f}秒" + + # 准备报告数据 + report_data = [] + total_time = 0.0 + + step_order = [ + "步骤1: 生成水域mask", + "步骤2: 找到耀斑区域", + "步骤3: 去除耀斑", + "步骤4: 处理CSV文件", + "步骤5: 提取训练样本点光谱", + "步骤5.5: 计算水质光谱指数", + "步骤6: 训练机器学习模型", + "步骤6.5: 非经验模型训练", + "步骤6.75: 自定义回归", + "步骤7: 生成预测采样点", + "步骤8: 预测水质参数", + "步骤9: 生成分布图" + ] + + for step_name in step_order: + if step_name in step_timings: + timing_info = step_timings[step_name] + report_data.append({ + '步骤': step_name, + '开始时间': timing_info['start_time'], + '结束时间': timing_info['end_time'], + '耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}", + '耗时(格式化)': timing_info['elapsed_formatted'], + '状态': timing_info['status'], + '错误信息': timing_info.get('error', '') + }) + if timing_info['status'] == 'completed': + total_time += timing_info['elapsed_seconds'] + + if pipeline_start_time and pipeline_end_time: + pipeline_total = pipeline_end_time - pipeline_start_time + report_data.append({ + '步骤': '总计', + '开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'), + '结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'), + '耗时(秒)': f"{pipeline_total:.2f}", + '耗时(格式化)': _format_time(pipeline_total), + '状态': 'completed', + '错误信息': '' + }) + + df_report = pd.DataFrame(report_data) + df_report.to_csv(output_path, index=False, encoding='utf-8-sig') + + # 文本格式报告 + txt_output_path = str(Path(output_path).with_suffix('.txt')) + with open(txt_output_path, 'w', encoding='utf-8') as f: + f.write("="*80 + "\n") + f.write("水质参数反演流程执行报告\n") + f.write("="*80 + "\n\n") + + if pipeline_start_time and pipeline_end_time: + f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n") + + f.write("-"*80 + "\n") + f.write("各步骤执行详情:\n") + f.write("-"*80 + "\n\n") + + for step_name in step_order: + if step_name in step_timings: + timing_info = step_timings[step_name] + f.write(f"{step_name}\n") + f.write(f" 开始时间: {timing_info['start_time']}\n") + f.write(f" 结束时间: {timing_info['end_time']}\n") + f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n") + f.write(f" 状态: {timing_info['status']}\n") + if timing_info.get('error'): + f.write(f" 错误: {timing_info['error']}\n") + f.write("\n") + + f.write("-"*80 + "\n") + f.write("统计摘要:\n") + f.write("-"*80 + "\n") + completed_steps = [s for s in step_timings.values() if s['status'] == 'completed'] + failed_steps = [s for s in step_timings.values() if s['status'] == 'failed'] + skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped'] + + f.write(f"成功完成的步骤: {len(completed_steps)}\n") + f.write(f"失败的步骤: {len(failed_steps)}\n") + f.write(f"跳过的步骤: {len(skipped_steps)}\n") + + if completed_steps: + completed_times = [s['elapsed_seconds'] for s in completed_steps] + f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n") + f.write(f"最长耗时: {_format_time(np.max(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.max(completed_times)][0]})\n") + f.write(f"最短耗时: {_format_time(np.min(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.min(completed_times)][0]})\n") + + print(f"\n流程报告已生成:") + print(f" CSV格式: {output_path}") + print(f" 文本格式: {txt_output_path}") + + return output_path \ No newline at end of file diff --git a/src/core/visualization/scatter_plot.py b/src/core/visualization/scatter_plot.py new file mode 100644 index 0000000..78dce80 --- /dev/null +++ b/src/core/visualization/scatter_plot.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 散点图生成 +""" +import numpy as np +import pandas as pd +from pathlib import Path +from typing import Optional, Dict, List, Union +from src.core.prediction.inference_batch import WaterQualityInference +from src.postprocessing.visualization_reports import WaterQualityVisualization + + +def generate_model_scatter_plots( + models_dir: str, + training_csv_path: str, + output_dir: Optional[str] = None, + metric: str = 'test_r2', + use_enhanced: bool = True, + feature_start_column: Union[str, int] = 13, + test_size: float = 0.2, + random_state: int = 42, + scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象 +) -> Dict[str, str]: + """ + 生成模型评估散点图(真实值vs预测值) + + Args: + models_dir: 模型保存目录 + training_csv_path: 训练数据CSV路径 + output_dir: 输出目录(None则使用默认) + metric: 选择最佳模型的指标 + use_enhanced: 是否使用增强版散点图(带置信区间,使用sctter_batch) + feature_start_column: 特征开始列名或索引 + test_size: 测试集比例 + random_state: 随机种子 + scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象 + + Returns: + 散点图文件路径字典(键为目标参数名) + """ + print("\n" + "="*80) + print("生成模型评估散点图") + print("="*80) + + if training_csv_path is None: + raise ValueError("请提供 training_csv_path") + + models_path = Path(models_dir) + if not models_path.exists(): + raise ValueError(f"模型目录不存在: {models_dir}") + + # 确定输出目录 + if output_dir is None: + output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots") + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # 实例化可视化器 + visualizer = WaterQualityVisualization(output_dir) + + scatter_paths = {} + + # 增强版散点图 + if use_enhanced: + print("使用增强版散点图(带置信区间)") + try: + from src.core.prediction.sctter_batch import WaterQualityScatterBatch + if scatter_batch is None: + scatter_batch = WaterQualityScatterBatch() + + results = scatter_batch.batch_plot_scatter( + models_root_dir=models_dir, + csv_path=training_csv_path, + output_dir=output_dir, + metric=metric, + target_column=None, + feature_start_column=feature_start_column, + test_size=test_size, + random_state=random_state + ) + + for target_name, result in results.items(): + if result.get('status') == 'success': + scatter_paths[target_name] = result.get('save_path', '') + print(f" ✓ {target_name}: {result.get('save_path', '')}") + else: + print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}") + + except Exception as e: + print(f"使用增强版散点图时出错: {e}") + print("回退到基础版散点图") + use_enhanced = False + + # 基础版散点图 + if not use_enhanced or not scatter_paths: + print("使用基础版散点图") + for target_folder in models_path.iterdir(): + if not target_folder.is_dir(): + continue + + target_name = target_folder.name + print(f"\n处理目标参数: {target_name}") + + try: + inferencer = WaterQualityInference(str(target_folder)) + eval_result = inferencer.evaluate_with_split( + data_csv_path=training_csv_path, + split_method="spxy", + test_size=test_size, + random_state=random_state, + metric=metric + ) + + predictions = eval_result.get('predictions', {}) + if predictions: + y_train_true = predictions.get('y_train_true') + y_train_pred = predictions.get('y_train_pred') + y_test_true = predictions.get('y_test_true') + y_test_pred = predictions.get('y_test_pred') + metrics = eval_result.get('test_metrics', {}) + + if y_train_true is not None and y_test_true is not None: + y_all_true = np.concatenate([y_train_true, y_test_true]) + y_all_pred = np.concatenate([y_train_pred, y_test_pred]) + + train_indices = np.arange(len(y_train_true)) + test_indices = np.arange(len(y_train_true), len(y_all_true)) + + scatter_path = visualizer.plot_scatter_true_vs_pred( + y_true=y_all_true, + y_pred=y_all_pred, + target_name=target_name, + train_indices=train_indices, + test_indices=test_indices, + metrics={ + 'train_r2': eval_result.get('train_metrics', {}).get('r2', 0), + 'test_r2': metrics.get('r2', 0), + 'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0), + 'test_rmse': metrics.get('rmse', 0) + } + ) + scatter_paths[target_name] = scatter_path + except Exception as e: + print(f"处理目标参数 {target_name} 时出错: {e}") + continue + + print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表") + return scatter_paths \ No newline at end of file diff --git a/src/core/visualization/spectrum_plot.py b/src/core/visualization/spectrum_plot.py new file mode 100644 index 0000000..5770073 --- /dev/null +++ b/src/core/visualization/spectrum_plot.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 光谱曲线图生成 +""" +import pandas as pd +from pathlib import Path +from typing import Optional, Dict, List, Union +from src.postprocessing.visualization_reports import WaterQualityVisualization + + +def generate_spectrum_comparison_plots( + csv_path: str, + parameter_columns: Optional[List[str]] = None, + wavelength_start_column: Union[str, int] = "UTM_Y", + output_dir: Optional[str] = None +) -> Dict[str, str]: + """ + 生成光谱曲线对比图(不同参数值的光谱曲线对比) + + Args: + csv_path: 包含光谱和参数值的CSV文件路径 + parameter_columns: 参数列名列表(如果为None,自动检测) + wavelength_start_column: 波长开始列名或索引 + output_dir: 输出目录(None则使用默认) + + Returns: + 光谱曲线图文件路径字典(键为参数名) + """ + print("\n" + "="*80) + print("生成光谱曲线对比图") + print("="*80) + + if csv_path is None: + raise ValueError("请提供 csv_path") + + # 确定输出目录 + if output_dir is None: + csv_dir = Path(csv_path).parent + output_dir = str(csv_dir / "visualization" / "spectrum_plots") + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # 实例化可视化器 + visualizer = WaterQualityVisualization(output_dir) + + # 读取数据以检测参数列 + df = pd.read_csv(csv_path) + + if parameter_columns is None: + if isinstance(wavelength_start_column, str): + try: + wavelength_start_idx = df.columns.get_loc(wavelength_start_column) + except: + wavelength_start_idx = 13 + else: + wavelength_start_idx = wavelength_start_column + + parameter_columns = list(df.columns[:wavelength_start_idx]) + if len(parameter_columns) > 2: + parameter_columns = parameter_columns[2:] + + spectrum_paths = {} + for param_col in parameter_columns: + if param_col not in df.columns: + continue + + print(f"\n处理参数: {param_col}") + try: + spectrum_path = visualizer.plot_spectrum_by_parameter( + csv_path=csv_path, + parameter_column=param_col, + wavelength_start_column=wavelength_start_column, + n_groups=5 + ) + spectrum_paths[param_col] = spectrum_path + except Exception as e: + print(f"处理参数 {param_col} 时出错: {e}") + continue + + print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表") + return spectrum_paths \ No newline at end of file diff --git a/src/core/visualization/statistics.py b/src/core/visualization/statistics.py new file mode 100644 index 0000000..381a629 --- /dev/null +++ b/src/core/visualization/statistics.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +""" +可视化模块 - 统计图表生成 +""" +import numpy as np +import pandas as pd +from pathlib import Path +from typing import Optional, Dict, List + +from src.postprocessing.visualization_reports import WaterQualityVisualization + + +def generate_statistical_charts( + csv_path: str, + parameter_columns: Optional[List[str]] = None, + output_dir: Optional[str] = None +) -> Dict[str, str]: + """ + 生成统计图表(箱线图、直方图、相关性热力图) + + Args: + csv_path: CSV文件路径 + parameter_columns: 参数列名列表(如果为None,自动检测) + output_dir: 输出目录(None则使用默认) + + Returns: + 统计图表文件路径字典 + """ + print("\n" + "="*80) + print("生成统计图表") + print("="*80) + + if csv_path is None: + raise ValueError("请提供 csv_path") + + # 确定输出目录 + if output_dir is None: + csv_dir = Path(csv_path).parent + output_dir = str(csv_dir / "visualization" / "statistical_charts") + Path(output_dir).mkdir(parents=True, exist_ok=True) + + # 实例化可视化器 + visualizer = WaterQualityVisualization(output_dir) + + # 读取数据以检测参数列 + df = pd.read_csv(csv_path) + + if parameter_columns is None: + parameter_columns = list(df.columns[2:]) + parameter_columns = [col for col in parameter_columns + if df[col].dtype in [np.float64, np.int64]] + + chart_paths = visualizer.plot_statistical_charts( + csv_path=csv_path, + parameter_columns=parameter_columns + ) + + print(f"\n统计图表生成完成") + return chart_paths \ No newline at end of file diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index b63c1cb..abee6da 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -51,6 +51,26 @@ from src.utils.kriging import KrigingInterpolator, batch_kriging_interpolation from src.postprocessing.map import ContentMapper from src.postprocessing.visualization_reports import WaterQualityVisualization, ReportGenerator from src.core.prediction.sctter_batch import WaterQualityScatterBatch +# 导入底层工具模块(从 utils/ 目录迁移) +from src.core.utils.gdal_helper import ( + get_image_geo_info as _get_image_geo_info, + load_image_as_array as _load_image_as_array, + save_array_as_image as _save_array_as_image, + save_bands_as_image as _save_bands_as_image, + copy_hdr_info as _copy_hdr_info, +) +from src.core.utils.mask_converter import ( + prepare_water_mask_for_algorithm as _prepare_water_mask_for_algorithm, + ensure_water_mask_dat as _ensure_water_mask_dat, +) +from src.core.utils.preview_generator import ( + generate_image_preview as _generate_image_preview, + generate_water_mask_overlay as _generate_water_mask_overlay, + select_rgb_bands_by_wavelength as _select_rgb_bands_by_wavelength, + get_wavelength_info as _get_wavelength_info, +) +# 导入算法层模块 +from src.core.algorithms.interpolation.interpolator import interpolate_zero_pixels_batch as _interpolate_zero_pixels_batch # 导入新的耀斑去除算法 from src.core.glint_removal.Kutser import Kutser from src.core.glint_removal.Goodman import Goodman @@ -225,23 +245,13 @@ class WaterQualityInversionPipeline: return f"{hours}小时{minutes}分{secs:.2f}秒" def _ensure_water_mask_dat(self, img_path: str) -> str: - """ - 确保有dat格式的水体掩膜文件(简化版本,因为步骤1已经确保有dat文件) - - Args: - img_path: 影像文件路径(已废弃,保留用于兼容性) - - Returns: - dat格式的水体掩膜文件路径 - """ - if self.water_mask_path is not None: - if Path(self.water_mask_path).exists(): - return self.water_mask_path - else: - raise ValueError(f"水体掩膜文件不存在: {self.water_mask_path}") - - raise ValueError("未找到水体掩膜文件,请先执行步骤1") - + """确保存在 dat 格式水体掩膜,转发至工具模块""" + return _ensure_water_mask_dat( + img_path=img_path, + existing_dat_path=self.water_mask_path, + output_dir=str(self.water_mask_dir) + ) + def step1_generate_water_mask(self, mask_path: Optional[str] = None, img_path: Optional[str] = None, @@ -409,313 +419,23 @@ class WaterQualityInversionPipeline: raise def _generate_image_preview(self, img_path: str, bands: Optional[List[int]] = None) -> str: - """ - 生成高光谱影像的PNG预览图 - - 根据波长选择RGB波段: - - 蓝波段 (Blue): ~460nm - - 绿波段 (Green): ~550nm - - 红波段 (Red): ~650nm - - 如果无法通过波长获取波段索引,则回退到基于波段序号的近似选择。 - - Args: - img_path: 输入高光谱影像文件路径 (.dat格式) - bands: 用于RGB合成的三个波段索引 [R, G, B],默认为None自动选择 - - Returns: - 生成的PNG文件路径 - """ - try: - print(f"正在生成影像预览图...") - - # 设置输出PNG路径 - img_name = Path(img_path).stem - png_path = str(self.water_mask_dir / f"hsi_preview.png") - - # 检查是否已存在 - if Path(png_path).exists(): - print(f"检测到已存在的预览图,跳过生成: {png_path}") - return png_path - - if not GDAL_AVAILABLE: - print("警告: GDAL未安装,无法生成影像预览图") - return "" - - # 使用GDAL读取影像 - dataset = gdal.Open(img_path) - if dataset is None: - print(f"警告: 无法打开影像文件: {img_path}") - return "" - - # 获取影像信息 - width = dataset.RasterXSize - height = dataset.RasterYSize - band_count = dataset.RasterCount - - # 如果没有指定波段,根据波长选择RGB波段 - if bands is None: - if band_count >= 3: - bands = self._select_rgb_bands_by_wavelength(img_path, band_count) - else: - # 如果只有一个波段,使用灰度显示 - bands = [0, 0, 0] - - # 读取指定波段 - r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32) - g_data = dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32) if band_count > 1 else r_data - b_data = dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32) if band_count > 2 else r_data - - # 去除无效值 - r_data[r_data <= 0] = np.nan - if band_count > 1: - g_data[g_data <= 0] = np.nan - if band_count > 2: - b_data[b_data <= 0] = np.nan - - # 对每个波段进行2%线性拉伸,增强视觉效果 - def linear_stretch(data, low_percent=2, high_percent=98): - """线性拉伸""" - valid_data = data[~np.isnan(data)] - if len(valid_data) == 0: - return np.zeros_like(data) - - low_val = np.percentile(valid_data, low_percent) - high_val = np.percentile(valid_data, high_percent) - - if high_val - low_val < 1e-10: - return np.zeros_like(data) - - stretched = (data - low_val) / (high_val - low_val) - stretched = np.clip(stretched, 0, 1) - return stretched - - r_stretched = linear_stretch(r_data) - g_stretched = linear_stretch(g_data) if band_count > 1 else r_stretched - b_stretched = linear_stretch(b_data) if band_count > 2 else r_stretched - - # 合成为RGB图像 - rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2) - - # 处理可能存在的nan值 - rgb_image = np.nan_to_num(rgb_image, nan=0.0) - - # 创建图形 - fig, ax = plt.subplots(figsize=(12, 10)) - ax.imshow(rgb_image) - ax.set_title(f'影像预览: RGB波段(基于波长): R=650 nm, G=550 nm, B=460 nm', - fontsize=12, fontweight='bold') - ax.axis('off') - - # 添加比例尺信息(白色文字,黑色背景下清晰可见) - geo_transform = dataset.GetGeoTransform() - if geo_transform: - pixel_size_x = abs(geo_transform[1]) - pixel_size_y = abs(geo_transform[5]) - scale_text = f"分辨率: {pixel_size_x:.2f}m x {pixel_size_y:.2f}m | 尺寸: {width} x {height}" - fig.text(0.5, 0.02, scale_text, ha='center', fontsize=10, style='italic', color='white') - - plt.tight_layout() - plt.savefig(png_path, dpi=150, bbox_inches='tight', pad_inches=0.1) - plt.close(fig) - - # 释放GDAL数据集 - dataset = None - - print(f"已生成影像预览图: {png_path}") - return png_path - - except Exception as e: - print(f"生成影像预览图时出错: {e}") - plt.close('all') - return "" + """生成影像预览图,转发至工具模块""" + output_path = str(self.water_mask_dir / f"hsi_preview.png") + return _generate_image_preview( + img_path=img_path, + output_path=output_path, + bands=bands, + title="影像预览: RGB波段(基于波长)" + ) def _generate_water_mask_overlay(self, img_path: str, mask_path: str) -> str: - """ - 生成水域掩膜叠加到原图的PNG图像 - - 将水域掩膜以透明度(蓝色半透明)叠加到RGB原图上,便于可视化水域范围。 - - Args: - img_path: 输入高光谱影像文件路径 - mask_path: 水域掩膜文件路径 (.dat格式) - - Returns: - 生成的PNG文件路径 - """ - try: - print(f"正在生成水域掩膜叠加图...") - - # 设置输出PNG路径 - png_path = str(self.water_mask_dir / "water_mask_overlay.png") - - # 检查是否已存在 - if Path(png_path).exists(): - print(f"检测到已存在的叠加图,跳过生成: {png_path}") - return png_path - - if not GDAL_AVAILABLE: - print("警告: GDAL未安装,无法生成叠加图") - return "" - - # 使用GDAL读取影像 - dataset = gdal.Open(img_path) - if dataset is None: - print(f"警告: 无法打开影像文件: {img_path}") - return "" - - # 获取影像信息 - width = dataset.RasterXSize - height = dataset.RasterYSize - band_count = dataset.RasterCount - - # 读取RGB波段(基于波长选择) - if band_count >= 3: - bands = self._select_rgb_bands_by_wavelength(img_path, band_count) - else: - bands = [0, 0, 0] - - r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32) - g_data = dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32) if band_count > 1 else r_data - b_data = dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32) if band_count > 2 else r_data - - # 去除无效值 - r_data[r_data <= 0] = np.nan - if band_count > 1: - g_data[g_data <= 0] = np.nan - if band_count > 2: - b_data[b_data <= 0] = np.nan - - # 线性拉伸 - def linear_stretch(data, low_percent=2, high_percent=98): - valid_data = data[~np.isnan(data)] - if len(valid_data) == 0: - return np.zeros_like(data) - low_val = np.percentile(valid_data, low_percent) - high_val = np.percentile(valid_data, high_percent) - if high_val - low_val < 1e-10: - return np.zeros_like(data) - stretched = (data - low_val) / (high_val - low_val) - stretched = np.clip(stretched, 0, 1) - return stretched - - r_stretched = linear_stretch(r_data) - g_stretched = linear_stretch(g_data) if band_count > 1 else r_stretched - b_stretched = linear_stretch(b_data) if band_count > 2 else r_stretched - - # 合成为RGB背景图像 (0-255) - rgb_image = np.stack([r_stretched, g_stretched, b_stretched], axis=2) - rgb_image = np.nan_to_num(rgb_image, nan=0.0) - rgb_image = (rgb_image * 255).astype(np.uint8) - - # 释放影像数据集 - dataset = None - - # 读取水域掩膜 - mask_dataset = gdal.Open(mask_path) - if mask_dataset is None: - print(f"警告: 无法打开掩膜文件: {mask_path}") - # 保存原图预览 - fig, ax = plt.subplots(figsize=(12, 10)) - ax.imshow(rgb_image) - ax.set_title('影像预览 (无掩膜叠加)', fontsize=12, fontweight='bold') - ax.axis('off') - plt.tight_layout() - plt.savefig(png_path, dpi=150, bbox_inches='tight', pad_inches=0.1) - plt.close(fig) - return png_path - - mask_data = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - - # 创建叠加图 - # 创建RGBA图像(带透明通道) - rgba_image = np.zeros((height, width, 4), dtype=np.uint8) - rgba_image[:, :, 0:3] = rgb_image # RGB通道 - rgba_image[:, :, 3] = 255 # Alpha通道(完全不透明) - - # 创建掩膜叠加层(蓝色半透明) - # 水域区域用蓝色高亮显示,透明度50% - mask_overlay = np.zeros((height, width, 4), dtype=np.uint8) - mask_overlay[:, :, 0] = 0 # R: 0 - mask_overlay[:, :, 1] = 100 # G: 100 (蓝色偏青) - mask_overlay[:, :, 2] = 255 # B: 255 (纯蓝) - mask_overlay[:, :, 3] = (mask_data > 0).astype(np.uint8) * 128 # Alpha: 50%透明 - - # 混合原图和掩膜层 - # 使用alpha混合公式: result = fg * alpha + bg * (1 - alpha) - alpha = mask_overlay[:, :, 3:4].astype(np.float32) / 255.0 - blended = rgb_image.astype(np.float32) * (1 - alpha) + mask_overlay[:, :, 0:3].astype(np.float32) * alpha - blended = blended.astype(np.uint8) - - # 创建图形 - fig, ax = plt.subplots(figsize=(14, 10)) - ax.imshow(blended) - ax.axis('off') - - # 添加图例 - from matplotlib.patches import Patch - legend_elements = [ - Patch(facecolor='#0064FF', edgecolor='black', alpha=0.5, label='水域范围') - ] - ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9) - - # 计算水域面积 - dataset = gdal.Open(img_path) - geo_transform = dataset.GetGeoTransform() - if geo_transform: - pixel_size_x = abs(geo_transform[1]) - pixel_size_y = abs(geo_transform[5]) - pixel_area = pixel_size_x * pixel_size_y # 平方米 - - # 计算水域像素数量和有效像素数量(非零像素) - water_pixels = np.sum(mask_data > 0) - valid_pixels = np.sum(mask_data >= 0) # 有效像素(包括水域和非水域) - - # 计算面积(平方米 -> 平方千米) - # 水域面积 - water_area_m2 = water_pixels * pixel_area - water_area_km2 = water_area_m2 / 1_000_000 - - # 有效像素面积(影像实际覆盖面积) - valid_area_m2 = valid_pixels * pixel_area - valid_area_km2 = valid_area_m2 / 1_000_000 - - # 水域占比(相对于有效像素) - water_percentage = (water_pixels / valid_pixels) * 100 if valid_pixels > 0 else 0 - - # 在图像上添加面积标注(合并显示) - area_text = f'水域面积: {water_area_km2:.2f} 平方千米 | 影像总面积: {valid_area_km2:.2f} 平方千米 | 水域占比: {water_percentage:.1f}%' - ax.text(0.02, 0.98, area_text, - transform=ax.transAxes, - fontsize=11, fontweight='bold', - color='white', - bbox=dict(facecolor='#0064FF', alpha=0.8, edgecolor='black', - boxstyle='round,pad=0.5', linewidth=2), - verticalalignment='top') - - # 添加比例尺信息 - scale_text = f"分辨率: {pixel_size_x:.2f}m x {pixel_size_y:.2f}m | 影像尺寸: {width} x {height}像素" - fig.text(0.5, 0.02, scale_text, ha='center', fontsize=10, style='italic', - color='white', - bbox=dict(facecolor='black', alpha=0.6, edgecolor='none', - boxstyle='round,pad=0.3')) - - print(f" 水域面积: {water_area_km2:.2f} km² | 影像总面积: {valid_area_km2:.2f} km² | 占比: {water_percentage:.1f}%") - dataset = None - - plt.tight_layout() - plt.savefig(png_path, dpi=150, bbox_inches='tight', pad_inches=0.1) - plt.close(fig) - - print(f"已生成水域掩膜叠加图: {png_path}") - return png_path - - except Exception as e: - print(f"生成水域掩膜叠加图时出错: {e}") - plt.close('all') - return "" - + """生成水域掩膜叠加图,转发至工具模块""" + output_path = str(self.water_mask_dir / "water_mask_overlay.png") + return _generate_water_mask_overlay( + img_path=img_path, + mask_path=mask_path, + output_path=output_path + ) def _select_rgb_bands_by_wavelength(self, img_path: str, band_count: int) -> List[int]: """ 根据波长选择RGB波段 @@ -901,373 +621,51 @@ class WaterQualityInversionPipeline: raise def _get_image_geo_info(self, img_path: str) -> tuple: - """ - 获取影像的地理信息(不加载图像数据,节省内存) - - Args: - img_path: 影像文件路径 - - Returns: - tuple: (geotransform, projection, width, height, n_bands) - geotransform: 地理变换参数 - projection: 投影信息 - width: 图像宽度 - height: 图像高度 - n_bands: 波段数 - """ - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法读取影像文件") - - dataset = gdal.Open(img_path, gdal.GA_ReadOnly) - if dataset is None: - raise ValueError(f"无法打开影像文件: {img_path}") - - try: - width = dataset.RasterXSize - height = dataset.RasterYSize - n_bands = dataset.RasterCount - geotransform = dataset.GetGeoTransform() - projection = dataset.GetProjection() - - return geotransform, projection, width, height, n_bands - finally: - dataset = None - + """获取影像地理信息,转发至工具模块""" + return _get_image_geo_info(img_path) + def _load_image_as_array(self, img_path: str) -> tuple: - """ - 加载影像文件为numpy数组(已废弃,建议直接使用GDAL读取) - - 注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。 - 建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。 - - Args: - img_path: 影像文件路径 - - Returns: - tuple: (image_array, geotransform, projection) - image_array: numpy数组,形状为(height, width, bands) - geotransform: 地理变换参数 - projection: 投影信息 - """ - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法读取影像文件") - - dataset = gdal.Open(img_path, gdal.GA_ReadOnly) - if dataset is None: - raise ValueError(f"无法打开影像文件: {img_path}") - - try: - width = dataset.RasterXSize - height = dataset.RasterYSize - n_bands = dataset.RasterCount - geotransform = dataset.GetGeoTransform() - projection = dataset.GetProjection() - - # 读取所有波段 - image_bands = [] - for i in range(1, n_bands + 1): - band = dataset.GetRasterBand(i) - band_data = band.ReadAsArray() - image_bands.append(band_data) - - # 堆叠为(height, width, bands)格式 - image_array = np.dstack(image_bands) - - return image_array, geotransform, projection - finally: - dataset = None - + """加载影像为numpy数组,转发至工具模块""" + return _load_image_as_array(img_path) + def _save_array_as_image(self, image_array: np.ndarray, output_path: str, geotransform: tuple, projection: str, - dtype: type = gdal.GDT_Float32) -> str: - """ - 将numpy数组保存为影像文件 - - Args: - image_array: numpy数组,形状为(height, width, bands) - output_path: 输出文件路径 - geotransform: 地理变换参数 - projection: 投影信息 - dtype: GDAL数据类型 - - Returns: - 输出文件路径 - """ - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法保存影像文件") - - # 兼容 (H,W) 和 (H,W,C) 两种 shape 格式 - if image_array.ndim == 2: - height, width = image_array.shape - n_bands = 1 - else: - height, width, n_bands = image_array.shape - - # 获取驱动 - driver = gdal.GetDriverByName('ENVI') - if driver is None: - # 如果ENVI驱动不可用,尝试使用GTiff - driver = gdal.GetDriverByName('GTiff') - - if driver is None: - raise ValueError("无法创建影像文件,没有可用的驱动") - - # 创建数据集 - dataset = driver.Create(output_path, width, height, n_bands, dtype) - if dataset is None: - raise ValueError(f"无法创建输出文件: {output_path}") - - try: - # 设置地理变换和投影 - dataset.SetGeoTransform(geotransform) - dataset.SetProjection(projection) - - # 写入每个波段 - for i in range(n_bands): - band = dataset.GetRasterBand(i + 1) - band.WriteArray(image_array[:, :, i]) - band.FlushCache() - - finally: - dataset = None - - return output_path - + dtype=None) -> str: + """保存numpy数组为影像,转发至工具模块""" + return _save_array_as_image(image_array, output_path, geotransform, projection, dtype) + def _save_bands_as_image(self, corrected_bands: list, output_path: str, geotransform: tuple, projection: str, - dtype: type = gdal.GDT_Float32) -> str: - """ - 直接从波段列表保存影像文件(避免堆叠,节省内存) - - Args: - corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组 - output_path: 输出文件路径 - geotransform: 地理变换参数 - projection: 投影信息 - dtype: GDAL数据类型 - - Returns: - 输出文件路径 - """ - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法保存影像文件") - - if not corrected_bands: - raise ValueError("波段列表为空") - - n_bands = len(corrected_bands) - height, width = corrected_bands[0].shape - - # 获取驱动 - driver = gdal.GetDriverByName('ENVI') - if driver is None: - # 如果ENVI驱动不可用,尝试使用GTiff - driver = gdal.GetDriverByName('GTiff') - - if driver is None: - raise ValueError("无法创建影像文件,没有可用的驱动") - - # 创建数据集 - dataset = driver.Create(output_path, width, height, n_bands, dtype) - if dataset is None: - raise ValueError(f"无法创建输出文件: {output_path}") - - try: - # 设置地理变换和投影 - dataset.SetGeoTransform(geotransform) - dataset.SetProjection(projection) - - # 逐个写入波段(避免堆叠所有波段,节省内存) - for i, band_array in enumerate(corrected_bands): - if band_array.shape != (height, width): - raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配") - band = dataset.GetRasterBand(i + 1) - band.WriteArray(band_array) - band.FlushCache() - # 注意:这里不能删除band_array,因为它还在corrected_bands列表中 - # 但保存后可以提示垃圾回收器(如果需要) - - finally: - dataset = None - - return output_path - + dtype=None) -> str: + """从波段列表保存影像,转发至工具模块""" + return _save_bands_as_image(corrected_bands, output_path, geotransform, projection, dtype) + def _prepare_water_mask_for_algorithm(self, water_mask: Optional[Union[str, np.ndarray]], image_shape: Union[tuple, np.ndarray], geotransform: tuple, projection: str, img_path: str) -> Optional[np.ndarray]: - """ - 准备水域掩膜供算法使用 - - 注意:如果传入的是shp文件,会先检查是否已经栅格化过,避免重复转换 - - Args: - water_mask: 水域掩膜,可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp) - image_shape: 影像形状,可以是(height, width)元组或numpy数组(用于获取形状) - geotransform: 地理变换参数 - projection: 投影信息 - img_path: 影像文件路径(用于栅格化shp文件) - - Returns: - numpy数组或None,1表示水域,0表示非水域 - """ - # 获取图像尺寸(统一从 shape 元组中提取前两个维度,兼容 (H,W)、(H,W,C)、(B,H,W) 等多种格式) - img_height, img_width = image_shape[0], image_shape[1] - - if water_mask is None: - # 如果water_mask为None,使用步骤1生成的dat格式掩膜 - if self.water_mask_path is not None: - try: - dat_mask_path = self._ensure_water_mask_dat(img_path) - water_mask = dat_mask_path - print(f"使用步骤1生成的水域掩膜: {water_mask}") - except Exception as e: - print(f"警告: 无法使用步骤1的水域掩膜: {e}") - return None - else: - return None - - # 如果已经是numpy数组 - if isinstance(water_mask, np.ndarray): - if water_mask.shape[:2] != (img_height, img_width): - raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配") - return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜 - - # 如果是文件路径 - if isinstance(water_mask, str): - # 检查是否为shapefile - if water_mask.lower().endswith('.shp'): - # 从shp文件创建掩膜(这种情况应该很少,因为步骤1已经统一转换为dat) - try: - from src.utils.extract_water_area import rasterize_shp - - # 路径标准化:转为绝对路径,统一正斜杠 - safe_shp_path = os.path.abspath(water_mask).replace('\\', '/') - print(f"[DEBUG] 标准化后的 SHP 路径: {safe_shp_path}") - - # 检查必要的伴随文件是否存在 - shp_dir = os.path.dirname(safe_shp_path) - shp_base = os.path.splitext(safe_shp_path)[0] - for ext in ['.dbf', '.shx', '.prj']: - companion = shp_base + ext - if not os.path.exists(companion): - print(f"[DEBUG] 缺失伴随文件: {companion}") - - # 检查 ESRI Shapefile 驱动是否可用 - if GDAL_AVAILABLE: - driver = ogr.GetDriverByName("ESRI Shapefile") - if driver is None: - raise RuntimeError("系统中未找到 ESRI Shapefile 驱动!请检查 GDAL 安装。") - print(f"[DEBUG] ESRI Shapefile 驱动可用: {driver.GetName()}") - - # 尝试用 ogr.Open 打开,捕获详细错误 - try: - ogr_ds = ogr.Open(safe_shp_path) - if ogr_ds is None: - # 通过 gdal.OpenEx 再试一次,获取详细原因 - ogr_ds2 = gdal.OpenEx(safe_shp_path, gdal.OF_VECTOR) - if ogr_ds2 is None: - raise RuntimeError( - f"ogr.Open 和 gdal.OpenEx 均无法打开 SHP 文件。\n" - f"可能原因:\n" - f" 1. 文件路径包含中文/特殊字符(当前路径: {safe_shp_path})\n" - f" 2. .dbf/.shx 等伴随文件缺失或损坏\n" - f" 3. GDAL 驱动未注册\n" - f"建议:将 SHP 文件复制到纯英文路径下重试" - ) - else: - print(f"[DEBUG] ogr.Open 成功打开 SHP,图层数: {ogr_ds.GetLayerCount()}") - ogr_ds = None # 仅用于诊断,不做后续处理 - except Exception as ogr_err: - raise RuntimeError( - f"OGR 打开 SHP 时出错(详细原因): {str(ogr_err)}\n" - f"文件路径: {safe_shp_path}" - ) - - # 使用固定路径,避免重复转换 - shp_name = Path(safe_shp_path).stem - temp_mask_path = str(self.water_mask_dir / f"water_mask_{shp_name}.dat") - - # 如果文件已存在,直接使用 - if Path(temp_mask_path).exists(): - print(f"使用已存在的栅格化掩膜: {temp_mask_path}") - water_mask = temp_mask_path - else: - # 需要栅格化(需要img_path) - if img_path is None: - raise ValueError("当water_mask为shp格式时,需要提供img_path参数用于栅格化") - # 传入标准化后的路径 - rasterize_shp(safe_shp_path, temp_mask_path, img_path) - water_mask = temp_mask_path - print(f"已将shp格式的水域掩膜栅格化为: {temp_mask_path}") - - # 读取栅格化的掩膜 - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法读取掩膜文件") - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset is None: - raise ValueError(f"无法打开栅格化的掩膜文件: {water_mask}") - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - except Exception as e: - raise ValueError(f"无法从shp文件创建掩膜: {e}") - else: - # 栅格文件 - if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法读取掩膜文件") - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset is None: - raise ValueError(f"无法打开掩膜文件: {water_mask}") - - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - - # 检查尺寸 - if mask_array.shape != (img_height, img_width): - raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配") - - return (mask_array > 0).astype(np.uint8) - - raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") - + """准备水体掩膜供算法使用,转发至工具模块""" + return _prepare_water_mask_for_algorithm( + water_mask=water_mask, + image_shape=image_shape, + geotransform=geotransform, + projection=projection, + img_path=img_path, + water_mask_dir=str(self.water_mask_dir), + callback=getattr(self, 'callback', None) + ) def _copy_hdr_info(self, source_img_path: str, dest_img_path: str): - """ - 复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件 - - Args: - source_img_path: 源影像文件路径(原始bsq文件) - dest_img_path: 目标影像文件路径(去耀斑后的bsq文件) - """ - if not UTIL_AVAILABLE: - print("警告: util模块未导入,无法复制hdr文件信息") - return - - try: - source_hdr_path = get_hdr_file_path(source_img_path) - dest_hdr_path = get_hdr_file_path(dest_img_path) - - if not Path(source_hdr_path).exists(): - print(f"警告: 源hdr文件不存在: {source_hdr_path}") - return - - if not Path(dest_hdr_path).exists(): - print(f"警告: 目标hdr文件不存在: {dest_hdr_path}") - return - - # 复制hdr文件信息(波长等) - write_fields_to_hdrfile(source_hdr_path, dest_hdr_path) - print(f"已复制原始hdr文件信息到: {dest_hdr_path}") - except Exception as e: - print(f"警告: 复制hdr文件信息时出错: {e}") - - def _interpolate_zero_pixels(self, img_path: str, + """复制hdr文件信息,转发至工具模块""" + return _copy_hdr_info(source_img_path, dest_img_path) + + def _interpolate_zero_pixels(self, img_path: str, interpolation_method: str = 'nearest', output_path: Optional[str] = None, water_mask: Optional[Union[str, np.ndarray]] = None) -> str: """ - 对影像中所有波段都为0的像素点进行插值(只处理所有波段都为0的像素) - + 对影像中所有波段都为0的像素点进行插值(转发至算法模块) + Args: img_path: 输入影像文件路径 interpolation_method: 插值方法,支持: @@ -1277,290 +675,32 @@ class WaterQualityInversionPipeline: - 'kriging': 克里金插值(最慢但最准确) output_path: 输出文件路径(如果为None,自动生成) water_mask: 水域掩膜,用于限制插值区域(可选) - + Returns: 插值后的影像文件路径 """ if not SCIPY_AVAILABLE: raise ImportError("scipy未安装,无法进行0值像素插值") - if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") - + print(f"\n开始对0值像素进行插值,方法: {interpolation_method}") print("注意: 只处理所有波段都为0的像素点") - - # 确定输出路径 - if output_path is None: - output_path = str(self.deglint_dir / f"interpolated_{interpolation_method}.bsq") - - # 检查文件是否已存在 - if Path(output_path).exists(): - print(f"检测到已存在的插值影像文件,直接使用: {output_path}") - self.interpolated_img_path = output_path - return output_path - - # 读取影像 - dataset = gdal.Open(img_path, gdal.GA_ReadOnly) - if dataset is None: - raise ValueError(f"无法打开影像文件: {img_path}") - - try: - width = dataset.RasterXSize - height = dataset.RasterYSize - n_bands = dataset.RasterCount - geotransform = dataset.GetGeoTransform() - projection = dataset.GetProjection() - - print(f"影像尺寸: {width} x {height} x {n_bands}") - - # 读取所有波段数据 - print("读取所有波段数据...") - all_bands = [] - for band_idx in range(1, n_bands + 1): - band = dataset.GetRasterBand(band_idx) - band_data = band.ReadAsArray().astype(np.float32) - all_bands.append(band_data) - - # 堆叠为 (height, width, n_bands) 格式 - image_stack = np.dstack(all_bands) - - # 读取水域掩膜(如果提供) - mask_array = None - if water_mask is not None: - if isinstance(water_mask, str): - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset: - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - elif isinstance(water_mask, np.ndarray): - mask_array = water_mask - - # 找出所有波段都为0的像素点 - # 检查每个像素在所有波段是否都为0 - all_bands_zero = np.all(image_stack == 0, axis=2) # (height, width) - - # 如果提供了水域掩膜,只在水域掩膜内处理 - if mask_array is not None: - all_bands_zero = all_bands_zero & (mask_array > 0) - - # 统计需要插值的像素数量 - zero_pixel_count = np.sum(all_bands_zero) - print(f"发现 {zero_pixel_count} 个所有波段都为0的像素点") - - if zero_pixel_count == 0: - print("没有需要插值的像素点,直接保存原影像") - # 直接保存原影像 - driver = gdal.GetDriverByName('ENVI') - if driver is None: - driver = gdal.GetDriverByName('GTiff') - if driver is None: - raise ValueError("无法创建影像文件,没有可用的驱动") - - out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) - if out_dataset is None: - raise ValueError(f"无法创建输出文件: {output_path}") - - out_dataset.SetGeoTransform(geotransform) - out_dataset.SetProjection(projection) - - for i, band_data in enumerate(all_bands): - out_band = out_dataset.GetRasterBand(i + 1) - out_band.WriteArray(band_data) - out_band.FlushCache() - - out_dataset = None - self.interpolated_img_path = output_path - return output_path - - # 获取需要插值的像素坐标 - zero_y, zero_x = np.where(all_bands_zero) - zero_coords = np.column_stack([zero_x, zero_y]) # (n_zero_pixels, 2) - - # 获取有效像素的坐标(至少有一个波段不为0的像素) - valid_mask = ~all_bands_zero - valid_y, valid_x = np.where(valid_mask) - valid_coords = np.column_stack([valid_x, valid_y]) # (n_valid_pixels, 2) - - if len(valid_coords) == 0: - raise ValueError("没有有效像素可用于插值") - - print(f"使用 {len(valid_coords)} 个有效像素进行插值") - - # 创建输出数据集 - driver = gdal.GetDriverByName('ENVI') - if driver is None: - driver = gdal.GetDriverByName('GTiff') - if driver is None: - raise ValueError("无法创建影像文件,没有可用的驱动") - - out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32) - if out_dataset is None: - raise ValueError(f"无法创建输出文件: {output_path}") - - out_dataset.SetGeoTransform(geotransform) - out_dataset.SetProjection(projection) - - # 逐波段进行插值(但只对"所有波段都为0"的像素进行插值) - interpolated_bands = [] - - for band_idx in range(n_bands): - print(f"处理波段 {band_idx + 1}/{n_bands}...", end=' ') - band_data = all_bands[band_idx].copy() - - # 获取有效像素的值 - valid_values = band_data[valid_mask] # (n_valid_pixels,) - - if len(valid_values) == 0: - print(f"警告: 波段 {band_idx + 1} 没有有效像素,跳过插值") - interpolated_bands.append(band_data) - continue - - # 兼容中文和各种格式 - raw_interp = str(interpolation_method).lower() - if 'nearest' in raw_interp or '邻近' in raw_interp or '最邻近' in raw_interp: - interpolation_method = 'nearest' - elif 'bilinear' in raw_interp or '线性' in raw_interp or '双线性' in raw_interp: - interpolation_method = 'bilinear' - elif 'spline' in raw_interp or '样条' in raw_interp or 'rbf' in raw_interp: - interpolation_method = 'spline' - elif 'kriging' in raw_interp or '克里金' in raw_interp: - interpolation_method = 'kriging' - else: - interpolation_method = 'nearest' - # 对需要插值的像素进行插值 - if interpolation_method == 'nearest': - # 邻近插值 - from scipy.spatial import cKDTree - tree = cKDTree(valid_coords) - _, indices = tree.query(zero_coords) - interpolated_values = valid_values[indices] - - elif interpolation_method == 'bilinear': - # 双线性插值(使用griddata) - interpolated_values = griddata( - valid_coords, valid_values, zero_coords, - method='linear', fill_value=0.0 - ) - - # 如果线性插值失败,使用邻近插值 - nan_mask = np.isnan(interpolated_values) - if np.any(nan_mask): - from scipy.spatial import cKDTree - tree = cKDTree(valid_coords) - _, indices = tree.query(zero_coords[nan_mask]) - interpolated_values[nan_mask] = valid_values[indices] - - elif interpolation_method == 'spline': - # 样条插值(RBF) - try: - # 如果有效点太多,随机采样以提高速度 - max_points = 10000 - if len(valid_values) > max_points: - indices = np.random.choice(len(valid_values), max_points, replace=False) - sample_coords = valid_coords[indices] - sample_values = valid_values[indices] - else: - sample_coords = valid_coords - sample_values = valid_values - - # 使用RBF插值 - rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline') - interpolated_values = rbf(zero_coords) - except Exception as e: - print(f"样条插值失败: {e},回退到双线性插值") - interpolated_values = griddata( - valid_coords, valid_values, zero_coords, - method='linear', fill_value=0.0 - ) - nan_mask = np.isnan(interpolated_values) - if np.any(nan_mask): - from scipy.spatial import cKDTree - tree = cKDTree(valid_coords) - _, indices = tree.query(zero_coords[nan_mask]) - interpolated_values[nan_mask] = valid_values[indices] - - elif interpolation_method == 'kriging': - # 克里金插值 - try: - from src.utils.kriging import KrigingInterpolator - interpolator = KrigingInterpolator() - - # 如果有效点太多,随机采样以提高速度 - max_points = 5000 - if len(valid_values) > max_points: - indices = np.random.choice(len(valid_values), max_points, replace=False) - sample_coords = valid_coords[indices] - sample_values = valid_values[indices] - else: - sample_coords = valid_coords - sample_values = valid_values - - # 执行克里金插值 - result = interpolator.interpolate( - sample_coords[:, 0], sample_coords[:, 1], sample_values, - spatial_resolution=1.0, - output_path=None, - proj=projection - ) - - if result is not None: - # 从结果中提取插值点 - # 注意:KrigingInterpolator返回的是网格,需要提取对应位置的值 - # 这里简化处理,使用griddata作为后备 - interpolated_values = griddata( - valid_coords, valid_values, zero_coords, - method='cubic', fill_value=0.0 - ) - nan_mask = np.isnan(interpolated_values) - if np.any(nan_mask): - from scipy.spatial import cKDTree - tree = cKDTree(valid_coords) - _, indices = tree.query(zero_coords[nan_mask]) - interpolated_values[nan_mask] = valid_values[indices] - else: - raise ValueError("克里金插值失败") - except Exception as e: - print(f"克里金插值失败: {e},回退到双线性插值") - interpolated_values = griddata( - valid_coords, valid_values, zero_coords, - method='linear', fill_value=0.0 - ) - nan_mask = np.isnan(interpolated_values) - if np.any(nan_mask): - from scipy.spatial import cKDTree - tree = cKDTree(valid_coords) - _, indices = tree.query(zero_coords[nan_mask]) - interpolated_values[nan_mask] = valid_values[indices] - else: - raise ValueError(f"不支持的插值方法: {interpolation_method}") - - # 更新波段数据(只更新所有波段都为0的像素) - band_data[all_bands_zero] = interpolated_values - interpolated_bands.append(band_data) - print(f"完成") - - # 保存所有波段 - for i, band_data in enumerate(interpolated_bands): - out_band = out_dataset.GetRasterBand(i + 1) - out_band.WriteArray(band_data) - out_band.FlushCache() - - out_dataset = None - dataset = None - - print(f"\n插值完成,共处理 {zero_pixel_count} 个所有波段都为0的像素点") - print(f"插值后的影像已保存: {output_path}") - - self.interpolated_img_path = output_path - return output_path - - finally: - if dataset: - dataset = None - - def step3_remove_glint(self, img_path: str, + # 转发至算法模块 + result_path, _ = _interpolate_zero_pixels_batch( + img_path=img_path, + interpolation_method=interpolation_method, + output_path=output_path, + water_mask=water_mask, + deglint_dir=str(self.deglint_dir), + callback_progress=lambda msg: print(f" {msg}") + ) + + self.interpolated_img_path = result_path + return result_path + + def step3_remove_glint(self, img_path: str, method: str = "subtract_nir", start_wave: Optional[float] = None, end_wave: Optional[float] = None,