refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块
This commit is contained in:
36
src/core/algorithms/__init__.py
Normal file
36
src/core/algorithms/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
算法层模块
|
||||
包含插值算法和耀斑检测算法等核心数学计算
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 插值
|
||||
'interpolate_pixels',
|
||||
'interpolate_zero_pixels_batch',
|
||||
# 耀斑检测
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
31
src/core/algorithms/glint_detection/__init__.py
Normal file
31
src/core/algorithms/glint_detection/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
包含各种耀斑检测的核心数学计算函数
|
||||
"""
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
595
src/core/algorithms/glint_detection/detectors.py
Normal file
595
src/core/algorithms/glint_detection/detectors.py
Normal file
@ -0,0 +1,595 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
|
||||
包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。
|
||||
支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
|
||||
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
|
||||
def timeit(func):
|
||||
"""装饰器:测量函数执行时间"""
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
import time
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
print(f"[{func.__name__}] 耗时: {end - start:.2f}s")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数拉伸
|
||||
# =============================================================================
|
||||
|
||||
def percentile_stretch(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
lower_percentile: float = 2,
|
||||
upper_percentile: float = 98,
|
||||
output_range: Tuple[int, int] = (0, 255)
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用百分位数裁剪进行归一化,适用于低反射率数据
|
||||
通过排除极值,更好地利用数据的动态范围
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(反射率值,通常在0-1之间)
|
||||
data_water_mask: 水域掩膜
|
||||
lower_percentile: 下百分位数,用于裁剪最小值(默认2)
|
||||
upper_percentile: 上百分位数,用于裁剪最大值(默认98)
|
||||
output_range: 输出范围,默认(0, 255)
|
||||
|
||||
Returns:
|
||||
归一化后的图像数组(整数类型)
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return img.astype(np.int32)
|
||||
|
||||
p_lower = np.percentile(valid_pixels, lower_percentile)
|
||||
p_upper = np.percentile(valid_pixels, upper_percentile)
|
||||
|
||||
if p_lower >= p_upper:
|
||||
p_lower = np.percentile(valid_pixels, 1)
|
||||
p_upper = np.percentile(valid_pixels, 99)
|
||||
if p_lower >= p_upper:
|
||||
p_upper = valid_pixels.max()
|
||||
p_lower = valid_pixels.min()
|
||||
|
||||
img_clipped = np.clip(img, p_lower, p_upper)
|
||||
|
||||
if p_upper > p_lower:
|
||||
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
|
||||
output_range[1] - output_range[0]
|
||||
) + output_range[0]
|
||||
else:
|
||||
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
|
||||
|
||||
return img_stretched.astype(np.int32)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Otsu阈值分割
|
||||
# =============================================================================
|
||||
|
||||
def otsu_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
ignore_value: int = 0,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Otsu方法的自动阈值分割
|
||||
通过最大化类间方差找到最佳分割阈值
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(整数值)
|
||||
data_water_mask: 水域掩膜
|
||||
ignore_value: 忽略的值(默认为0)
|
||||
foreground: 耀斑区域值(默认1)
|
||||
background: 背景值(默认0)
|
||||
|
||||
Returns:
|
||||
二值化检测结果数组
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
max_value = int(np.max(img[img > ignore_value])) + 1
|
||||
if max_value < 2:
|
||||
max_value = 256
|
||||
|
||||
hist = np.zeros([max_value], np.float32)
|
||||
|
||||
invalid_counter = 0
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
|
||||
invalid_counter += 1
|
||||
continue
|
||||
hist[img[i, j]] += 1
|
||||
|
||||
total_valid = height * width - invalid_counter
|
||||
if total_valid <= 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
hist /= total_valid
|
||||
|
||||
threshold = 0
|
||||
deltaMax = 0
|
||||
|
||||
for i in range(max_value):
|
||||
wA = sum(hist[:i + 1])
|
||||
wB = sum(hist[i + 1:])
|
||||
if wA == 0:
|
||||
wA = 1e-10
|
||||
if wB == 0:
|
||||
wB = 1e-10
|
||||
|
||||
uAtmp = sum(j * hist[j] for j in range(i + 1))
|
||||
uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
|
||||
uA = uAtmp / wA
|
||||
uB = uBtmp / wB
|
||||
u = uAtmp + uBtmp
|
||||
|
||||
deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
|
||||
if deltaTmp > deltaMax:
|
||||
deltaMax = deltaTmp
|
||||
threshold = i
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z-score阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def zscore_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
z_threshold: float = 2.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Z-score(标准化分数)的耀斑检测方法
|
||||
使用统计方法识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
mean_val = np.mean(valid_pixels)
|
||||
std_val = np.std(valid_pixels)
|
||||
|
||||
if std_val == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
z_scores = np.zeros_like(img, dtype=np.float32)
|
||||
valid_mask = (data_water_mask > 0) & np.isfinite(img)
|
||||
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[z_scores > z_threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def percentile_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
percentile: float = 95,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于百分位数的耀斑检测方法
|
||||
使用百分位数作为阈值,对异常值更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
percentile: 百分位数阈值,默认95(即超过95%的像素值)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
threshold_val = np.percentile(valid_pixels, percentile)
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold_val] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IQR异常值检测
|
||||
# =============================================================================
|
||||
|
||||
def iqr_outlier_detection(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
iqr_multiplier: float = 1.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于IQR(四分位距)的异常值检测方法
|
||||
使用四分位距识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
iqr_multiplier: IQR倍数,默认1.5(标准异常值检测)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
q1 = np.percentile(valid_pixels, 25)
|
||||
q3 = np.percentile(valid_pixels, 75)
|
||||
iqr = q3 - q1
|
||||
|
||||
upper_bound = q3 + iqr_multiplier * iqr
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > upper_bound] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 自适应阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def adaptive_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
window_size: int = 15,
|
||||
percentile: float = 90,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
自适应阈值方法
|
||||
基于局部统计特性进行阈值分割,对光照变化更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
window_size: 局部窗口大小(奇数)
|
||||
percentile: 局部百分位数阈值
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
if window_size % 2 == 0:
|
||||
window_size += 1
|
||||
|
||||
half_window = window_size // 2
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
for i in range(half_window, height - half_window):
|
||||
for j in range(half_window, width - half_window):
|
||||
if data_water_mask[i, j] == 0:
|
||||
continue
|
||||
|
||||
local_window = img[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
local_mask = data_water_mask[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
|
||||
valid_pixels = local_window[local_mask > 0]
|
||||
|
||||
if len(valid_pixels) > 0:
|
||||
local_th = np.percentile(valid_pixels, percentile)
|
||||
if img[i, j] > local_th:
|
||||
det_img[i, j] = foreground
|
||||
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 多波段融合耀斑检测
|
||||
# =============================================================================
|
||||
|
||||
def multi_band_glint_detection(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
glint_waves: List[float],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'zscore',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
sub_band_arrays: Optional[List[np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
多波段融合的耀斑检测方法
|
||||
结合多个波段的耀斑特征,提高检测的稳健性
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组(主波段,用于兼容性)
|
||||
water_mask: 水域掩膜数组
|
||||
glint_waves: 用于检测的波长列表,如[750, 800, 850]
|
||||
weights: 各波段的权重,如果为None则使用等权重
|
||||
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
|
||||
z_threshold: Z-score阈值(当method='zscore'时使用)
|
||||
percentile: 百分位数阈值(当method='percentile'时使用)
|
||||
sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应)
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
if weights is None:
|
||||
weights = [1.0 / len(glint_waves)] * len(glint_waves)
|
||||
|
||||
if len(weights) != len(glint_waves):
|
||||
raise ValueError("权重数量必须与波长数量相同")
|
||||
|
||||
fused_band = None
|
||||
|
||||
if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
|
||||
for i, band_array in enumerate(sub_band_arrays):
|
||||
if fused_band is None:
|
||||
fused_band = (band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = nir_band.astype(np.float32)
|
||||
|
||||
if method == 'otsu':
|
||||
stretched = percentile_stretch(fused_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(fused_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(fused_band, water_mask, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 连通域过滤
|
||||
# =============================================================================
|
||||
|
||||
def filter_large_components(
|
||||
binary_img: np.ndarray,
|
||||
max_area: Optional[int] = None,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
过滤掉面积超过阈值的连通域
|
||||
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
|
||||
|
||||
Args:
|
||||
binary_img: 二值化图像
|
||||
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
过滤后的二值化图像
|
||||
"""
|
||||
if max_area is None or max_area <= 0:
|
||||
return binary_img
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
binary_for_label = (binary_img == foreground).astype(np.uint8)
|
||||
num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
|
||||
binary_for_label, connectivity=8
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = stats[1:, cv2.CC_STAT_AREA]
|
||||
keep_labels = np.where(component_sizes <= max_area)[0] + 1
|
||||
|
||||
keep_mask = np.isin(labeled_array, keep_labels)
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
else:
|
||||
from scipy import ndimage
|
||||
labeled_array, num_features = ndimage.label(
|
||||
(binary_img == foreground).astype(np.int32)
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = ndimage.sum(
|
||||
(labeled_array == i).astype(np.int32),
|
||||
labeled_array,
|
||||
range(1, num_features + 1)
|
||||
)
|
||||
|
||||
keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 岸边缓冲区处理
|
||||
# =============================================================================
|
||||
|
||||
def create_shoreline_buffer(
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
创建岸边缓冲区掩膜(向内缓冲)
|
||||
用于去除岸边附近的错误耀斑检测区域
|
||||
|
||||
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
|
||||
|
||||
Args:
|
||||
water_mask: 水域掩膜数组(水域=1,非水域=0)
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
岸边缓冲区掩膜(缓冲区区域=1,其他=0)
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return np.zeros_like(water_mask, dtype=np.int32)
|
||||
|
||||
water_binary = (water_mask > 0).astype(np.uint8)
|
||||
structure_size = buffer_size * 2 + 1
|
||||
structure = np.ones((structure_size, structure_size), dtype=np.uint8)
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
|
||||
else:
|
||||
from scipy import ndimage
|
||||
eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
|
||||
|
||||
buffer_mask = (water_binary - eroded_water).astype(np.int32)
|
||||
|
||||
return buffer_mask
|
||||
|
||||
|
||||
def remove_shoreline_buffer(
|
||||
glint_mask: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
从耀斑掩膜中去除岸边缓冲区内的区域
|
||||
|
||||
Args:
|
||||
glint_mask: 耀斑掩膜数组
|
||||
water_mask: 水域掩膜数组
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
去除岸边缓冲区后的耀斑掩膜
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return glint_mask
|
||||
|
||||
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
|
||||
|
||||
cleaned = glint_mask.copy()
|
||||
cleaned[buffer_mask > 0] = background
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 高级组合函数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_glint_mask(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
method: str = 'otsu',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
apply_percentile_stretch: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算耀斑掩膜的统一入口函数
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组
|
||||
water_mask: 水域掩膜
|
||||
method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
|
||||
z_threshold: Z-score阈值
|
||||
percentile: 百分位数阈值
|
||||
iqr_multiplier: IQR倍数
|
||||
window_size: 自适应阈值窗口大小
|
||||
apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸
|
||||
|
||||
Returns:
|
||||
二值化耀斑掩膜
|
||||
"""
|
||||
if method == 'otsu':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
else:
|
||||
return otsu_threshold(nir_band.astype(np.int32), water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(nir_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(nir_band, water_mask, percentile)
|
||||
elif method == 'iqr':
|
||||
return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
|
||||
elif method == 'adaptive':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return adaptive_threshold(stretched, water_mask, window_size, percentile)
|
||||
else:
|
||||
return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
7
src/core/algorithms/interpolation/__init__.py
Normal file
7
src/core/algorithms/interpolation/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
插值算法模块
|
||||
包含0值像素插值的核心数学逻辑
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
|
||||
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']
|
||||
320
src/core/algorithms/interpolation/interpolator.py
Normal file
320
src/core/algorithms/interpolation/interpolator.py
Normal file
@ -0,0 +1,320 @@
|
||||
"""
|
||||
像素插值算法模块
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import griddata, RBFInterpolator
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
interpolation_method: str = 'nearest',
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对指定坐标的像素进行插值(核心数学函数,不涉及文件I/O)
|
||||
|
||||
Args:
|
||||
image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组
|
||||
zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y]
|
||||
valid_coords: 有效像素坐标,形状为 (n_valid, 2)
|
||||
valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands)
|
||||
interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
water_mask: 可选的水域掩膜数组
|
||||
|
||||
Returns:
|
||||
插值后的影像副本,形状与 image_stack 相同
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
# 兼容中文和各种格式的method参数
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
|
||||
method = 'bilinear'
|
||||
elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
|
||||
method = 'spline'
|
||||
elif 'kriging' in raw_method or '克里金' in raw_method:
|
||||
method = 'kriging'
|
||||
else:
|
||||
method = 'nearest'
|
||||
|
||||
if len(valid_values) == 0:
|
||||
return result
|
||||
|
||||
is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
|
||||
|
||||
if is_multiband:
|
||||
for band_idx in range(n_bands):
|
||||
band_valid_values = valid_values[:, band_idx]
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, band_valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords, band_idx] = interpolated_values
|
||||
else:
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords] = interpolated_values
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _interpolate_single_band(
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
method: str
|
||||
) -> np.ndarray:
|
||||
"""对单个波段执行插值计算"""
|
||||
if method == 'nearest':
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords)
|
||||
return valid_values[indices]
|
||||
|
||||
elif method == 'bilinear':
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'spline':
|
||||
try:
|
||||
max_points = 10000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
|
||||
interpolated = rbf(zero_coords)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'kriging':
|
||||
try:
|
||||
from src.utils.kriging import KrigingInterpolator
|
||||
interpolator = KrigingInterpolator()
|
||||
max_points = 5000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
interpolated = griddata(
|
||||
sample_coords, sample_values, zero_coords,
|
||||
method='cubic', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组)
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数
|
||||
|
||||
Returns:
|
||||
(output_path, interpolated_image_stack) 元组
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
||||
|
||||
# 检查文件是否已存在
|
||||
if output_path and Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 读取所有波段数据
|
||||
all_bands = []
|
||||
for band_idx in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(band_idx)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
all_bands.append(band_data)
|
||||
|
||||
image_stack = np.dstack(all_bands)
|
||||
|
||||
# 读取水域掩膜
|
||||
mask_array = None
|
||||
if water_mask is not None:
|
||||
if isinstance(water_mask, str):
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
|
||||
# 找出所有波段都为0的像素点
|
||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
||||
|
||||
if mask_array is not None:
|
||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
||||
|
||||
zero_pixel_count = np.sum(all_bands_zero)
|
||||
if zero_pixel_count == 0:
|
||||
# 无需插值,直接保存
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(all_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
return output_path, image_stack
|
||||
|
||||
# 获取坐标
|
||||
zero_y, zero_x = np.where(all_bands_zero)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_bands_zero
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
raise ValueError("没有有效像素可用于插值")
|
||||
|
||||
# 逐波段插值
|
||||
interpolated_bands = []
|
||||
for band_idx in range(n_bands):
|
||||
if callback_progress:
|
||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
||||
band_data = all_bands[band_idx].copy()
|
||||
valid_values_band = band_data[valid_mask]
|
||||
|
||||
if len(valid_values_band) == 0:
|
||||
interpolated_bands.append(band_data)
|
||||
continue
|
||||
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
||||
)
|
||||
band_data[all_bands_zero] = band_result
|
||||
interpolated_bands.append(band_data)
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(interpolated_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
|
||||
result_stack = np.dstack(interpolated_bands)
|
||||
return output_path, result_stack
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
42
src/core/utils/__init__.py
Normal file
42
src/core/utils/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块 - 统一导出接口
|
||||
"""
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info,
|
||||
load_image_as_array,
|
||||
save_array_as_image,
|
||||
save_bands_as_image,
|
||||
copy_hdr_info,
|
||||
read_band_as_array,
|
||||
read_multiple_bands,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm,
|
||||
ensure_water_mask_dat,
|
||||
)
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
select_rgb_bands_by_wavelength,
|
||||
get_wavelength_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# GDAL IO
|
||||
'get_image_geo_info',
|
||||
'load_image_as_array',
|
||||
'save_array_as_image',
|
||||
'save_bands_as_image',
|
||||
'copy_hdr_info',
|
||||
'read_band_as_array',
|
||||
'read_multiple_bands',
|
||||
# 掩膜转换
|
||||
'prepare_water_mask_for_algorithm',
|
||||
'ensure_water_mask_dat',
|
||||
# 预览图生成
|
||||
'generate_image_preview',
|
||||
'generate_water_mask_overlay',
|
||||
'select_rgb_bands_by_wavelength',
|
||||
'get_wavelength_info',
|
||||
]
|
||||
309
src/core/utils/gdal_helper.py
Normal file
309
src/core/utils/gdal_helper.py
Normal file
@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GDAL 底层 IO 工具模块
|
||||
|
||||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# GDAL 导入(可选)
|
||||
try:
|
||||
from osgeo import gdal, ogr, gdal_array
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# hdr 文件工具
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||||
UTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
UTIL_AVAILABLE = False
|
||||
write_fields_to_hdrfile = None
|
||||
get_hdr_file_path = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像信息读取
|
||||
# ============================================================
|
||||
|
||||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||||
"""
|
||||
获取影像的地理信息(不加载图像数据,节省内存)
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (geotransform, projection, width, height, n_bands)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
return geotransform, projection, width, height, n_bands
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||||
"""
|
||||
加载影像文件为numpy数组
|
||||
|
||||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (image_array, geotransform, projection)
|
||||
image_array: numpy数组,形状为(height, width, bands)
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
image_bands = []
|
||||
for i in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(i)
|
||||
band_data = band.ReadAsArray()
|
||||
image_bands.append(band_data)
|
||||
|
||||
image_array = np.dstack(image_bands)
|
||||
return image_array, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
读取单个波段为 numpy 数组
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_index: 波段索引(从 0 开始)
|
||||
|
||||
Returns:
|
||||
numpy 数组,形状为 (height, width)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
band = dataset.GetRasterBand(band_index + 1)
|
||||
return band.ReadAsArray()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||||
"""
|
||||
读取多个指定波段为列表
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_indices: 波段索引列表
|
||||
|
||||
Returns:
|
||||
tuple: (band_list, geotransform, projection)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
bands = []
|
||||
for idx in band_indices:
|
||||
band = dataset.GetRasterBand(idx + 1)
|
||||
bands.append(band.ReadAsArray())
|
||||
return bands, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像写入
|
||||
# ============================================================
|
||||
|
||||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
将numpy数组保存为影像文件
|
||||
|
||||
Args:
|
||||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
if image_array.ndim == 2:
|
||||
height, width = image_array.shape
|
||||
n_bands = 1
|
||||
else:
|
||||
height, width, n_bands = image_array.shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
if n_bands == 1:
|
||||
band = dataset.GetRasterBand(1)
|
||||
band.WriteArray(image_array)
|
||||
band.FlushCache()
|
||||
else:
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(image_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||||
|
||||
Args:
|
||||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if not corrected_bands:
|
||||
raise ValueError("波段列表为空")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
n_bands = len(corrected_bands)
|
||||
height, width = corrected_bands[0].shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
for i, band_array in enumerate(corrected_bands):
|
||||
if band_array.shape != (height, width):
|
||||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(band_array)
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||||
"""
|
||||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||||
|
||||
Args:
|
||||
source_img_path: 源影像文件路径(原始bsq文件)
|
||||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not UTIL_AVAILABLE:
|
||||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||||
return False
|
||||
|
||||
try:
|
||||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||||
|
||||
if not Path(source_hdr_path).exists():
|
||||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||||
return False
|
||||
|
||||
if not Path(dest_hdr_path).exists():
|
||||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||||
return False
|
||||
|
||||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||||
return False
|
||||
210
src/core/utils/mask_converter.py
Normal file
210
src/core/utils/mask_converter.py
Normal file
@ -0,0 +1,210 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
掩膜转换工具模块
|
||||
|
||||
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
|
||||
以及水体掩膜的预处理逻辑。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_water_mask_for_algorithm(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
image_shape: Union[tuple, np.ndarray],
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
img_path: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
准备水域掩膜供算法使用
|
||||
|
||||
支持格式:
|
||||
- None:自动使用预先生成的 dat 格式掩膜
|
||||
- numpy.ndarray:直接返回(确保是 0/1 格式)
|
||||
- .dat / .tif 等栅格文件:读取并返回
|
||||
- .shp 文件:先栅格化,再读取返回
|
||||
|
||||
Args:
|
||||
water_mask: 掩膜来源
|
||||
image_shape: 影像形状 (height, width) 或 (height, width, channels)
|
||||
geotransform: GDAL 地理变换参数
|
||||
projection: 投影信息
|
||||
img_path: 影像路径(用于 shp 栅格化)
|
||||
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
|
||||
callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
numpy数组(dtype=uint8,0=非水域,1=水域)或 None
|
||||
"""
|
||||
img_height, img_width = image_shape[0], image_shape[1]
|
||||
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# numpy 数组直接返回
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8)
|
||||
|
||||
# 字符串路径
|
||||
if isinstance(water_mask, str):
|
||||
ext = Path(water_mask).suffix.lower()
|
||||
|
||||
# shapefile 格式
|
||||
if ext == '.shp':
|
||||
return _convert_shp_to_mask(
|
||||
shp_path=water_mask,
|
||||
img_path=img_path,
|
||||
image_shape=image_shape,
|
||||
geotransform=geotransform,
|
||||
projection=projection,
|
||||
water_mask_dir=water_mask_dir,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
# 栅格文件格式
|
||||
return _load_raster_mask(water_mask, img_height, img_width)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
|
||||
def _convert_shp_to_mask(shp_path: str, img_path: str,
|
||||
image_shape: tuple,
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None) -> np.ndarray:
|
||||
"""将 shapefile 栅格化为掩膜数组"""
|
||||
from src.utils.extract_water_area import rasterize_shp
|
||||
|
||||
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
|
||||
shp_name = Path(safe_shp_path).stem
|
||||
|
||||
if water_mask_dir:
|
||||
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
|
||||
else:
|
||||
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
|
||||
|
||||
# 缓存:已栅格化则直接读取
|
||||
if Path(temp_mask_path).exists():
|
||||
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
# 需要栅格化
|
||||
if img_path is None:
|
||||
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
|
||||
|
||||
print(f"正在将 SHP 栅格化: {safe_shp_path}")
|
||||
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
|
||||
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
|
||||
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
|
||||
"""从栅格文件加载掩膜"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取掩膜文件")
|
||||
|
||||
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {mask_path}")
|
||||
|
||||
try:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def ensure_water_mask_dat(img_path: str,
|
||||
existing_dat_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法)
|
||||
|
||||
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
|
||||
如果存在同名 .dat 文件,直接返回。
|
||||
否则从 img_path 生成并保存到 output_dir。
|
||||
|
||||
Args:
|
||||
img_path: 用于生成掩膜的遥感影像路径
|
||||
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
|
||||
output_dir: 输出目录(可选)
|
||||
|
||||
Returns:
|
||||
dat 格式掩膜文件路径
|
||||
"""
|
||||
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
|
||||
if Path(existing_dat_path).exists():
|
||||
return existing_dat_path
|
||||
|
||||
img_name = Path(img_path).stem
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(img_path).parent)
|
||||
|
||||
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
|
||||
|
||||
if Path(dat_path).exists():
|
||||
return dat_path
|
||||
|
||||
# 如果已有其他格式的掩膜,转换为 dat
|
||||
for ext in ['.tif', '.img', '.tiff']:
|
||||
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
|
||||
if Path(alt_path).exists():
|
||||
return _convert_to_dat(alt_path, dat_path)
|
||||
|
||||
return dat_path # 返回目标路径,让调用方决定是否需要生成
|
||||
|
||||
|
||||
def _convert_to_dat(src_path: str, dest_path: str) -> str:
|
||||
"""将其他栅格格式转换为 ENVI dat 格式"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法转换格式")
|
||||
|
||||
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
|
||||
if src_ds is None:
|
||||
raise ValueError(f"无法打开源掩膜文件: {src_path}")
|
||||
|
||||
try:
|
||||
geotransform = src_ds.GetGeoTransform()
|
||||
projection = src_ds.GetProjection()
|
||||
band = src_ds.GetRasterBand(1)
|
||||
array = band.ReadAsArray()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
|
||||
if dest_ds is None:
|
||||
raise ValueError(f"无法创建输出文件: {dest_path}")
|
||||
|
||||
try:
|
||||
dest_ds.SetGeoTransform(geotransform)
|
||||
dest_ds.SetProjection(projection)
|
||||
dest_band = dest_ds.GetRasterBand(1)
|
||||
dest_band.WriteArray((array > 0).astype(np.uint8))
|
||||
dest_band.FlushCache()
|
||||
finally:
|
||||
dest_ds = None
|
||||
|
||||
return dest_path
|
||||
finally:
|
||||
src_ds = None
|
||||
339
src/core/utils/preview_generator.py
Normal file
339
src/core/utils/preview_generator.py
Normal file
@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
遥感影像预览图生成工具模块
|
||||
|
||||
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具)
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 辅助函数:波段选择
|
||||
# ============================================================
|
||||
|
||||
def select_rgb_bands_by_wavelength(band_count: int,
|
||||
wavelength_info: Optional[List[float]] = None,
|
||||
fallback_bands: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
根据波长自动选择 RGB 波段
|
||||
|
||||
Args:
|
||||
band_count: 总波段数
|
||||
wavelength_info: 各波段波长列表(nm),长度为 band_count
|
||||
fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B]
|
||||
|
||||
Returns:
|
||||
波段索引列表 [R_index, G_index, B_index](0-based)
|
||||
"""
|
||||
if fallback_bands is None:
|
||||
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
|
||||
|
||||
if wavelength_info is None:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
# 目标波长(nm)
|
||||
TARGET_R = 650
|
||||
TARGET_G = 550
|
||||
TARGET_B = 460
|
||||
|
||||
def find_closest(target: float) -> int:
|
||||
min_dist = float('inf')
|
||||
best_idx = 0
|
||||
for i, wl in enumerate(wavelength_info):
|
||||
dist = abs(wl - target)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
try:
|
||||
r_idx = find_closest(TARGET_R)
|
||||
g_idx = find_closest(TARGET_G)
|
||||
b_idx = find_closest(TARGET_B)
|
||||
return [r_idx, g_idx, b_idx]
|
||||
except Exception:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
|
||||
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
|
||||
"""从 hdr 文件读取波长信息"""
|
||||
try:
|
||||
hdr_path = Path(img_path).with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
in_wl = False
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength ='):
|
||||
in_wl = True
|
||||
line = line.split('=', 1)[1].strip()
|
||||
elif in_wl:
|
||||
if line.startswith('{'):
|
||||
line = line[1:]
|
||||
if line.endswith('}'):
|
||||
line = line[:-1]
|
||||
in_wl = False
|
||||
# 解析逗号分隔的数值
|
||||
for token in line.replace(',', ' ').split():
|
||||
try:
|
||||
wavelengths.append(float(token))
|
||||
except ValueError:
|
||||
pass
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心预览图生成函数
|
||||
# ============================================================
|
||||
|
||||
def generate_image_preview(img_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
title: str = "影像预览") -> str:
|
||||
"""
|
||||
生成高光谱影像的 PNG 预览图
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
output_path: 输出 PNG 文件路径
|
||||
bands: RGB 波段索引 [R, G, B],None 则自动选择
|
||||
title: 图片标题
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成影像预览图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预览图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
# 读取波段
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
# 线性拉伸
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.imshow(rgb_image)
|
||||
ax.set_title(title, fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
|
||||
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
|
||||
color='white',
|
||||
bbox=dict(facecolor='black', alpha=0.6,
|
||||
boxstyle='round,pad=0.3'))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def generate_water_mask_overlay(img_path: str,
|
||||
mask_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
mask_color: tuple = (0, 100, 255),
|
||||
mask_alpha: float = 0.5) -> str:
|
||||
"""
|
||||
生成水域掩膜叠加到原图的 PNG 图像
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
mask_path: 水域掩膜文件路径
|
||||
output_path: 输出 PNG 路径
|
||||
bands: RGB 波段索引,None 则自动选择
|
||||
mask_color: 掩膜叠加颜色 (R, G, B)
|
||||
mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明)
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成叠加图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
|
||||
rgb_image = rgb_image.astype(np.uint8)
|
||||
|
||||
# 读取掩膜
|
||||
mask_dataset = gdal.Open(mask_path)
|
||||
if mask_dataset is not None:
|
||||
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
else:
|
||||
print(f"警告: 无法打开掩膜文件: {mask_path}")
|
||||
mask_data = None
|
||||
|
||||
# Alpha 混合
|
||||
overlay = np.zeros((height, width, 4), dtype=np.uint8)
|
||||
overlay[:, :, 0:3] = mask_color
|
||||
overlay[:, :, 3] = 255 # 全不透明
|
||||
|
||||
blended = rgb_image.astype(np.float32)
|
||||
if mask_data is not None:
|
||||
alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha
|
||||
for c in range(3):
|
||||
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
|
||||
blended = blended.astype(np.uint8)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
ax.imshow(blended)
|
||||
ax.axis('off')
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
|
||||
edgecolor='black', alpha=mask_alpha, label='水域范围')
|
||||
]
|
||||
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
|
||||
|
||||
# 面积计算
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
pixel_size_y = abs(geotransform[5])
|
||||
pixel_area = pixel_size_x * pixel_size_y
|
||||
|
||||
if mask_data is not None:
|
||||
water_pixels = np.sum(mask_data > 0)
|
||||
valid_pixels = np.sum(mask_data >= 0)
|
||||
water_km2 = water_pixels * pixel_area / 1_000_000
|
||||
valid_km2 = valid_pixels * pixel_area / 1_000_000
|
||||
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
|
||||
|
||||
area_text = (f'水域面积: {water_km2:.2f} km² | '
|
||||
f'影像总面积: {valid_km2:.2f} km² | '
|
||||
f'占比: {pct:.1f}%')
|
||||
ax.text(0.02, 0.98, area_text,
|
||||
transform=ax.transAxes, fontsize=11,
|
||||
color='white', fontweight='bold',
|
||||
bbox=dict(facecolor='#0064FF', alpha=0.8,
|
||||
edgecolor='black', boxstyle='round,pad=0.5'),
|
||||
verticalalignment='top')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
21
src/core/visualization/__init__.py
Normal file
21
src/core/visualization/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统一导出接口
|
||||
|
||||
本模块从各子模块导入可视化函数,提供统一的导出接口。
|
||||
"""
|
||||
from src.core.visualization.scatter_plot import generate_model_scatter_plots
|
||||
from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots
|
||||
from src.core.visualization.boxplot import generate_boxplots
|
||||
from src.core.visualization.statistics import generate_statistical_charts
|
||||
from src.core.visualization.preview import generate_glint_deglint_previews
|
||||
from src.core.visualization.report import generate_pipeline_report
|
||||
|
||||
__all__ = [
|
||||
'generate_model_scatter_plots',
|
||||
'generate_spectrum_comparison_plots',
|
||||
'generate_boxplots',
|
||||
'generate_statistical_charts',
|
||||
'generate_glint_deglint_previews',
|
||||
'generate_pipeline_report',
|
||||
]
|
||||
183
src/core/visualization/boxplot.py
Normal file
183
src/core/visualization/boxplot.py
Normal file
@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 箱型图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
sns.set_style("whitegrid")
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
def generate_boxplots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
data_start_column: int = 4,
|
||||
save_individual: bool = True,
|
||||
use_seaborn: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成水质参数的箱型图
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
data_start_column: 数据开始列索引(从第几列开始,默认第5列,索引为4)
|
||||
save_individual: 是否为每个参数单独保存箱型图
|
||||
use_seaborn: 是否使用seaborn绘制(更美观)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
箱型图文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成水质参数箱型图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "boxplots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# 确定参数列
|
||||
if parameter_columns is None:
|
||||
data_columns = df.iloc[:, data_start_column:]
|
||||
parameter_columns = list(data_columns.columns)
|
||||
else:
|
||||
parameter_columns = [col for col in parameter_columns if col in df.columns]
|
||||
|
||||
if not parameter_columns:
|
||||
print("警告: 未找到有效的参数列")
|
||||
return {}
|
||||
|
||||
boxplot_dir = Path(output_dir)
|
||||
boxplot_paths = {}
|
||||
|
||||
if save_individual:
|
||||
print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)")
|
||||
|
||||
for column in parameter_columns:
|
||||
if column not in df.columns:
|
||||
continue
|
||||
|
||||
clean_data = df[column].dropna()
|
||||
|
||||
if len(clean_data) == 0:
|
||||
print(f"跳过列 '{column}': 没有有效数据")
|
||||
continue
|
||||
|
||||
try:
|
||||
plt.figure(figsize=(8, 6))
|
||||
|
||||
if use_seaborn:
|
||||
plot_data = pd.DataFrame({
|
||||
'参数': [column] * len(clean_data),
|
||||
'数值': clean_data
|
||||
})
|
||||
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot([clean_data], labels=[column],
|
||||
patch_artist=True, showfliers=False)
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
stats_text = (f'数据点数: {len(clean_data)}\n'
|
||||
f'均值: {clean_data.mean():.2f}\n'
|
||||
f'中位数: {clean_data.median():.2f}\n'
|
||||
f'标准差: {clean_data.std():.2f}')
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round',
|
||||
facecolor='wheat' if not use_seaborn else 'lightgreen',
|
||||
alpha=0.8))
|
||||
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = boxplot_dir / f'{safe_column_name}_boxplot.png'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths[column] = str(save_path)
|
||||
print(f" 已保存: {save_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 处理参数 {column} 时出错: {e}")
|
||||
continue
|
||||
|
||||
# 综合箱型图
|
||||
try:
|
||||
print("\n生成综合箱型图(所有参数在一张图上)")
|
||||
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
|
||||
|
||||
box_data = []
|
||||
labels = []
|
||||
for column in parameter_columns:
|
||||
if column in df.columns:
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
|
||||
if box_data:
|
||||
if use_seaborn:
|
||||
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
|
||||
melted_data = melted_data.dropna()
|
||||
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
|
||||
for i, data in enumerate(box_data):
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
|
||||
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths['all_parameters'] = str(combined_path)
|
||||
print(f" 已保存综合箱型图: {combined_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成综合箱型图时出错: {e}")
|
||||
|
||||
print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表")
|
||||
return boxplot_paths
|
||||
59
src/core/visualization/preview.py
Normal file
59
src/core/visualization/preview.py
Normal file
@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 耀斑影像预览图生成
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_glint_deglint_previews(
|
||||
work_dir: str,
|
||||
output_subdir: str = "glint_deglint_previews",
|
||||
generate_glint: bool = True,
|
||||
generate_deglint: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成2_glint和3_deglint文件夹中影像文件的PNG预览图
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录
|
||||
output_subdir: 输出子目录名称
|
||||
generate_glint: 是否处理2_glint文件夹
|
||||
generate_deglint: 是否处理3_deglint文件夹
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
生成的预览图路径字典
|
||||
"""
|
||||
print(f"\n{'='*70}")
|
||||
print("步骤: 生成耀斑分析影像预览图")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if work_dir is None:
|
||||
raise ValueError("请提供 work_dir")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(work_dir) / "visualization" / output_subdir)
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
try:
|
||||
preview_paths = visualizer.generate_glint_deglint_previews(
|
||||
work_dir=work_dir,
|
||||
output_subdir=output_subdir,
|
||||
generate_glint=generate_glint,
|
||||
generate_deglint=generate_deglint
|
||||
)
|
||||
|
||||
print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图")
|
||||
return preview_paths
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成耀斑分析影像预览图时出错: {e}")
|
||||
return {}
|
||||
147
src/core/visualization/report.py
Normal file
147
src/core/visualization/report.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 流程执行报告生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def generate_pipeline_report(
|
||||
step_timings: Dict,
|
||||
pipeline_start_time: Optional[float] = None,
|
||||
pipeline_end_time: Optional[float] = None,
|
||||
output_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成流程执行报告,包含每步的耗时统计
|
||||
|
||||
Args:
|
||||
step_timings: 步骤耗时字典(格式:{step_name: {start_time, end_time, elapsed_seconds, elapsed_formatted, status, error}})
|
||||
pipeline_start_time: 流程开始时间戳
|
||||
pipeline_end_time: 流程结束时间戳
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
|
||||
Returns:
|
||||
报告文件路径
|
||||
"""
|
||||
if output_path is None:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _format_time(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.2f}秒"
|
||||
elif seconds < 3600:
|
||||
minutes = int(seconds // 60)
|
||||
secs = seconds % 60
|
||||
return f"{minutes}分{secs:.2f}秒"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = seconds % 60
|
||||
return f"{hours}小时{minutes}分{secs:.2f}秒"
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
total_time = 0.0
|
||||
|
||||
step_order = [
|
||||
"步骤1: 生成水域mask",
|
||||
"步骤2: 找到耀斑区域",
|
||||
"步骤3: 去除耀斑",
|
||||
"步骤4: 处理CSV文件",
|
||||
"步骤5: 提取训练样本点光谱",
|
||||
"步骤5.5: 计算水质光谱指数",
|
||||
"步骤6: 训练机器学习模型",
|
||||
"步骤6.5: 非经验模型训练",
|
||||
"步骤6.75: 自定义回归",
|
||||
"步骤7: 生成预测采样点",
|
||||
"步骤8: 预测水质参数",
|
||||
"步骤9: 生成分布图"
|
||||
]
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
report_data.append({
|
||||
'步骤': step_name,
|
||||
'开始时间': timing_info['start_time'],
|
||||
'结束时间': timing_info['end_time'],
|
||||
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
|
||||
'耗时(格式化)': timing_info['elapsed_formatted'],
|
||||
'状态': timing_info['status'],
|
||||
'错误信息': timing_info.get('error', '')
|
||||
})
|
||||
if timing_info['status'] == 'completed':
|
||||
total_time += timing_info['elapsed_seconds']
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
pipeline_total = pipeline_end_time - pipeline_start_time
|
||||
report_data.append({
|
||||
'步骤': '总计',
|
||||
'开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'耗时(秒)': f"{pipeline_total:.2f}",
|
||||
'耗时(格式化)': _format_time(pipeline_total),
|
||||
'状态': 'completed',
|
||||
'错误信息': ''
|
||||
})
|
||||
|
||||
df_report = pd.DataFrame(report_data)
|
||||
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
# 文本格式报告
|
||||
txt_output_path = str(Path(output_path).with_suffix('.txt'))
|
||||
with open(txt_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write("="*80 + "\n")
|
||||
f.write("水质参数反演流程执行报告\n")
|
||||
f.write("="*80 + "\n\n")
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("各步骤执行详情:\n")
|
||||
f.write("-"*80 + "\n\n")
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
f.write(f"{step_name}\n")
|
||||
f.write(f" 开始时间: {timing_info['start_time']}\n")
|
||||
f.write(f" 结束时间: {timing_info['end_time']}\n")
|
||||
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
|
||||
f.write(f" 状态: {timing_info['status']}\n")
|
||||
if timing_info.get('error'):
|
||||
f.write(f" 错误: {timing_info['error']}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("统计摘要:\n")
|
||||
f.write("-"*80 + "\n")
|
||||
completed_steps = [s for s in step_timings.values() if s['status'] == 'completed']
|
||||
failed_steps = [s for s in step_timings.values() if s['status'] == 'failed']
|
||||
skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped']
|
||||
|
||||
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
|
||||
f.write(f"失败的步骤: {len(failed_steps)}\n")
|
||||
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
|
||||
|
||||
if completed_steps:
|
||||
completed_times = [s['elapsed_seconds'] for s in completed_steps]
|
||||
f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n")
|
||||
f.write(f"最长耗时: {_format_time(np.max(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.max(completed_times)][0]})\n")
|
||||
f.write(f"最短耗时: {_format_time(np.min(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.min(completed_times)][0]})\n")
|
||||
|
||||
print(f"\n流程报告已生成:")
|
||||
print(f" CSV格式: {output_path}")
|
||||
print(f" 文本格式: {txt_output_path}")
|
||||
|
||||
return output_path
|
||||
147
src/core/visualization/scatter_plot.py
Normal file
147
src/core/visualization/scatter_plot.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 散点图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_model_scatter_plots(
|
||||
models_dir: str,
|
||||
training_csv_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = 'test_r2',
|
||||
use_enhanced: bool = True,
|
||||
feature_start_column: Union[str, int] = 13,
|
||||
test_size: float = 0.2,
|
||||
random_state: int = 42,
|
||||
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成模型评估散点图(真实值vs预测值)
|
||||
|
||||
Args:
|
||||
models_dir: 模型保存目录
|
||||
training_csv_path: 训练数据CSV路径
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
metric: 选择最佳模型的指标
|
||||
use_enhanced: 是否使用增强版散点图(带置信区间,使用sctter_batch)
|
||||
feature_start_column: 特征开始列名或索引
|
||||
test_size: 测试集比例
|
||||
random_state: 随机种子
|
||||
scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象
|
||||
|
||||
Returns:
|
||||
散点图文件路径字典(键为目标参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成模型评估散点图")
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供 training_csv_path")
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
raise ValueError(f"模型目录不存在: {models_dir}")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
# 增强版散点图
|
||||
if use_enhanced:
|
||||
print("使用增强版散点图(带置信区间)")
|
||||
try:
|
||||
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
|
||||
if scatter_batch is None:
|
||||
scatter_batch = WaterQualityScatterBatch()
|
||||
|
||||
results = scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=output_dir,
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state
|
||||
)
|
||||
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
print(f" ✓ {target_name}: {result.get('save_path', '')}")
|
||||
else:
|
||||
print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用增强版散点图时出错: {e}")
|
||||
print("回退到基础版散点图")
|
||||
use_enhanced = False
|
||||
|
||||
# 基础版散点图
|
||||
if not use_enhanced or not scatter_paths:
|
||||
print("使用基础版散点图")
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
print(f"\n处理目标参数: {target_name}")
|
||||
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric
|
||||
)
|
||||
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
|
||||
scatter_path = visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0)
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception as e:
|
||||
print(f"处理目标参数 {target_name} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
|
||||
return scatter_paths
|
||||
80
src/core/visualization/spectrum_plot.py
Normal file
80
src/core/visualization/spectrum_plot.py
Normal file
@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 光谱曲线图生成
|
||||
"""
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_spectrum_comparison_plots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
wavelength_start_column: Union[str, int] = "UTM_Y",
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成光谱曲线对比图(不同参数值的光谱曲线对比)
|
||||
|
||||
Args:
|
||||
csv_path: 包含光谱和参数值的CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
wavelength_start_column: 波长开始列名或索引
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
光谱曲线图文件路径字典(键为参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成光谱曲线对比图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "spectrum_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
if isinstance(wavelength_start_column, str):
|
||||
try:
|
||||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||||
except:
|
||||
wavelength_start_idx = 13
|
||||
else:
|
||||
wavelength_start_idx = wavelength_start_column
|
||||
|
||||
parameter_columns = list(df.columns[:wavelength_start_idx])
|
||||
if len(parameter_columns) > 2:
|
||||
parameter_columns = parameter_columns[2:]
|
||||
|
||||
spectrum_paths = {}
|
||||
for param_col in parameter_columns:
|
||||
if param_col not in df.columns:
|
||||
continue
|
||||
|
||||
print(f"\n处理参数: {param_col}")
|
||||
try:
|
||||
spectrum_path = visualizer.plot_spectrum_by_parameter(
|
||||
csv_path=csv_path,
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wavelength_start_column,
|
||||
n_groups=5
|
||||
)
|
||||
spectrum_paths[param_col] = spectrum_path
|
||||
except Exception as e:
|
||||
print(f"处理参数 {param_col} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表")
|
||||
return spectrum_paths
|
||||
59
src/core/visualization/statistics.py
Normal file
59
src/core/visualization/statistics.py
Normal file
@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统计图表生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_statistical_charts(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成统计图表(箱线图、直方图、相关性热力图)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
统计图表文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成统计图表")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "statistical_charts")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
parameter_columns = list(df.columns[2:])
|
||||
parameter_columns = [col for col in parameter_columns
|
||||
if df[col].dtype in [np.float64, np.int64]]
|
||||
|
||||
chart_paths = visualizer.plot_statistical_charts(
|
||||
csv_path=csv_path,
|
||||
parameter_columns=parameter_columns
|
||||
)
|
||||
|
||||
print(f"\n统计图表生成完成")
|
||||
return chart_paths
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user