refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,36 @@
"""
算法层模块
包含插值算法和耀斑检测算法等核心数学计算
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
from src.core.algorithms.glint_detection.detectors import (
otsu_threshold,
zscore_threshold,
percentile_threshold,
iqr_outlier_detection,
adaptive_threshold,
multi_band_glint_detection,
percentile_stretch,
filter_large_components,
create_shoreline_buffer,
remove_shoreline_buffer,
calculate_glint_mask,
)
__all__ = [
# 插值
'interpolate_pixels',
'interpolate_zero_pixels_batch',
# 耀斑检测
'otsu_threshold',
'zscore_threshold',
'percentile_threshold',
'iqr_outlier_detection',
'adaptive_threshold',
'multi_band_glint_detection',
'percentile_stretch',
'filter_large_components',
'create_shoreline_buffer',
'remove_shoreline_buffer',
'calculate_glint_mask',
]

View File

@ -0,0 +1,31 @@
"""
耀斑检测算法模块
包含各种耀斑检测的核心数学计算函数
"""
from src.core.algorithms.glint_detection.detectors import (
otsu_threshold,
zscore_threshold,
percentile_threshold,
iqr_outlier_detection,
adaptive_threshold,
multi_band_glint_detection,
percentile_stretch,
filter_large_components,
create_shoreline_buffer,
remove_shoreline_buffer,
calculate_glint_mask,
)
__all__ = [
'otsu_threshold',
'zscore_threshold',
'percentile_threshold',
'iqr_outlier_detection',
'adaptive_threshold',
'multi_band_glint_detection',
'percentile_stretch',
'filter_large_components',
'create_shoreline_buffer',
'remove_shoreline_buffer',
'calculate_glint_mask',
]

View File

@ -0,0 +1,595 @@
"""
耀斑检测算法模块
包含各种耀斑检测的核心数学计算函数纯数学逻辑不涉及文件I/O。
支持的方法otsu, zscore, percentile, iqr, adaptive, multi_band
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
"""
import numpy as np
from typing import Optional, List, Tuple
from functools import wraps
try:
import cv2
CV2_AVAILABLE = True
except ImportError:
CV2_AVAILABLE = False
def timeit(func):
    """Decorator: measure and print the execution time of *func*.

    Uses ``time.perf_counter()`` (a monotonic, high-resolution clock) rather
    than ``time.time()``, which can jump backwards/forwards when the system
    clock is adjusted and would then report a wrong elapsed time.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        import time
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"[{func.__name__}] 耗时: {elapsed:.2f}s")
        return result
    return wrapper
# =============================================================================
# 百分位数拉伸
# =============================================================================
def percentile_stretch(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    lower_percentile: float = 2,
    upper_percentile: float = 98,
    output_range: Tuple[int, int] = (0, 255)
) -> np.ndarray:
    """Normalise *img* to *output_range* using percentile clipping.

    Designed for low-reflectance data: extreme values are clipped away so the
    remaining dynamic range is used fully.

    Args:
        img: input image (reflectance values, typically in 0-1)
        data_water_mask: water-area mask (>0 marks water)
        lower_percentile: lower clip percentile (default 2)
        upper_percentile: upper clip percentile (default 98)
        output_range: output value range, default (0, 255)
    Returns:
        stretched image as an integer (int32) array
    """
    selection = (data_water_mask > 0) & (img > 0) & np.isfinite(img)
    samples = img[selection]
    if samples.size == 0:
        # No usable pixels: return the raw values as integers.
        return img.astype(np.int32)
    lo, hi = np.percentile(samples, [lower_percentile, upper_percentile])
    if lo >= hi:
        # Degenerate spread: widen to the 1st/99th percentiles, then to min/max.
        lo, hi = np.percentile(samples, [1, 99])
        if lo >= hi:
            lo, hi = samples.min(), samples.max()
    clipped = np.clip(img, lo, hi)
    if hi > lo:
        out_lo, out_hi = output_range
        stretched = (clipped - lo) / (hi - lo) * (out_hi - out_lo) + out_lo
    else:
        # Constant image: everything maps to the bottom of the output range.
        stretched = np.full_like(img, output_range[0], dtype=np.float32)
    return stretched.astype(np.int32)
# =============================================================================
# Otsu阈值分割
# =============================================================================
def otsu_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    ignore_value: int = 0,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Automatic threshold segmentation using Otsu's method.

    Finds the threshold that maximises the between-class variance of the
    valid-pixel histogram, then marks pixels above it as glint.

    Improvements over the previous version:
    - the histogram was built with a per-pixel Python loop and the candidate
      scan recomputed all class sums per threshold (O(max_value^2)); this
      version uses np.bincount and cumulative sums for the same result,
    - an image with no pixel above ``ignore_value`` used to crash on
      ``np.max`` of an empty array; it now returns an all-background mask.

    Args:
        img: input image (integer-valued array)
        data_water_mask: water mask (0 = not water)
        ignore_value: value to skip (default 0)
        foreground: value written for glint pixels (default 1)
        background: value written for background pixels (default 0)
    Returns:
        int32 binary detection mask
    """
    above = img[img > ignore_value]
    if above.size == 0:
        # Nothing can exceed any threshold -> empty detection.
        return np.zeros_like(img, dtype=np.int32)
    max_value = int(above.max()) + 1
    if max_value < 2:
        max_value = 256
    # Valid pixels: not the ignore value, non-negative, inside the water mask
    # (same conditions the original per-pixel loop applied).
    valid = (img != ignore_value) & (img >= 0) & (data_water_mask != 0)
    values = img[valid].astype(np.int64)
    total_valid = values.size
    if total_valid <= 0:
        return np.zeros_like(img, dtype=np.int32)
    hist = np.bincount(values, minlength=max_value).astype(np.float64)
    hist /= total_valid
    bins = np.arange(max_value, dtype=np.float64)
    # Class A = values <= i, class B = values > i, for every candidate i.
    wA = np.cumsum(hist)
    wB = hist.sum() - wA
    uAtmp = np.cumsum(bins * hist)
    uBtmp = uAtmp[-1] - uAtmp
    # Clamp zero weights exactly as the loop version did (avoids div-by-zero).
    wA = np.where(wA == 0, 1e-10, wA)
    wB = np.where(wB == 0, 1e-10, wB)
    uA = uAtmp / wA
    uB = uBtmp / wB
    u = uAtmp + uBtmp  # equals the global first moment for every i
    delta = wA * (uA - u) ** 2 + wB * (uB - u) ** 2
    # argmax returns the first maximum, matching the strict '>' update rule.
    threshold = int(np.argmax(delta))
    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[img > threshold] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# Z-score阈值检测
# =============================================================================
def zscore_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    z_threshold: float = 2.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Detect glint as pixels whose Z-score exceeds *z_threshold*.

    A statistical outlier test: pixels further than ``z_threshold`` standard
    deviations above the mean of the valid water pixels are flagged.

    Args:
        img: input image
        data_water_mask: water mask
        z_threshold: Z-score cutoff (default 2.5)
        foreground: value for detected pixels
        background: value for background pixels
    Returns:
        int32 binary detection mask
    """
    samples = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if samples.size == 0:
        return np.zeros_like(img, dtype=np.int32)
    mu = np.mean(samples)
    sigma = np.std(samples)
    if sigma == 0:
        # Constant image: no outliers possible.
        return np.zeros_like(img, dtype=np.int32)
    scores = np.zeros_like(img, dtype=np.float32)
    in_water = (data_water_mask > 0) & np.isfinite(img)
    scores[in_water] = (img[in_water] - mu) / sigma
    result = np.zeros_like(img, dtype=np.int32)
    result[scores > z_threshold] = foreground
    result[data_water_mask == 0] = background
    return result
# =============================================================================
# 百分位数阈值检测
# =============================================================================
def percentile_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    percentile: float = 95,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Detect glint using a percentile cutoff.

    The threshold is the given percentile of the valid water pixels, which
    makes the method robust against extreme values.

    Args:
        img: input image
        data_water_mask: water mask
        percentile: percentile cutoff (default 95)
        foreground: value for detected pixels
        background: value for background pixels
    Returns:
        int32 binary detection mask
    """
    samples = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if samples.size == 0:
        return np.zeros_like(img, dtype=np.int32)
    cutoff = np.percentile(samples, percentile)
    result = np.zeros_like(img, dtype=np.int32)
    result[img > cutoff] = foreground
    result[data_water_mask == 0] = background
    return result
# =============================================================================
# IQR异常值检测
# =============================================================================
def iqr_outlier_detection(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    iqr_multiplier: float = 1.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Detect glint as IQR (inter-quartile range) outliers.

    Pixels above ``Q3 + iqr_multiplier * IQR`` of the valid water pixels are
    flagged; insensitive to the underlying distribution.

    Args:
        img: input image
        data_water_mask: water mask
        iqr_multiplier: IQR multiplier (default 1.5, the standard outlier rule)
        foreground: value for detected pixels
        background: value for background pixels
    Returns:
        int32 binary detection mask
    """
    samples = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if samples.size == 0:
        return np.zeros_like(img, dtype=np.int32)
    q1, q3 = np.percentile(samples, [25, 75])
    upper_bound = q3 + iqr_multiplier * (q3 - q1)
    result = np.zeros_like(img, dtype=np.int32)
    result[img > upper_bound] = foreground
    result[data_water_mask == 0] = background
    return result
# =============================================================================
# 自适应阈值检测
# =============================================================================
def adaptive_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    window_size: int = 15,
    percentile: float = 90,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Adaptive (locally thresholded) glint detection.

    Each pixel is compared against the given percentile of the valid pixels
    in its local window, making the method robust to illumination changes.
    Pixels closer than half a window to the image border are never visited
    and stay 0.

    Args:
        img: input image
        data_water_mask: water mask
        window_size: local window size (forced odd)
        percentile: local percentile cutoff
        foreground: value for detected pixels
        background: value for background pixels
    Returns:
        int32 binary detection mask
    """
    rows, cols = img.shape
    if window_size % 2 == 0:
        window_size += 1
    half = window_size // 2
    result = np.zeros_like(img, dtype=np.int32)
    for r in range(half, rows - half):
        for c in range(half, cols - half):
            if data_water_mask[r, c] == 0:
                continue
            window = img[r - half:r + half + 1, c - half:c + half + 1]
            window_mask = data_water_mask[r - half:r + half + 1, c - half:c + half + 1]
            neighbours = window[window_mask > 0]
            if neighbours.size > 0 and img[r, c] > np.percentile(neighbours, percentile):
                result[r, c] = foreground
    result[data_water_mask == 0] = background
    return result
# =============================================================================
# 多波段融合耀斑检测
# =============================================================================
def multi_band_glint_detection(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    glint_waves: List[float],
    weights: Optional[List[float]] = None,
    method: str = 'zscore',
    z_threshold: float = 2.5,
    percentile: float = 95,
    sub_band_arrays: Optional[List[np.ndarray]] = None
) -> np.ndarray:
    """Glint detection on a weighted fusion of several bands.

    Fuses the sub-band arrays (weighted sum) and runs the selected detector
    on the fused band; falls back to *nir_band* alone when no matching
    sub-band list is supplied.

    Args:
        nir_band: primary NIR band (fallback when sub bands are unavailable)
        water_mask: water mask
        glint_waves: wavelengths used for detection, e.g. [750, 800, 850]
        weights: per-band weights; equal weights when None
        method: 'zscore', 'percentile' or 'otsu'
        z_threshold: cutoff when method == 'zscore'
        percentile: cutoff when method == 'percentile'
        sub_band_arrays: band arrays matching *glint_waves* one-to-one
    Returns:
        int32 binary detection mask
    Raises:
        ValueError: weight/wavelength count mismatch or unknown method
    """
    n_waves = len(glint_waves)
    if weights is None:
        weights = [1.0 / n_waves] * n_waves
    if len(weights) != n_waves:
        raise ValueError("权重数量必须与波长数量相同")
    if sub_band_arrays is not None and len(sub_band_arrays) == n_waves:
        # Weighted float32 sum of the sub bands.
        fused = None
        for weight, band in zip(weights, sub_band_arrays):
            term = band * weight
            fused = term.astype(np.float32) if fused is None else (fused + term).astype(np.float32)
    else:
        fused = nir_band.astype(np.float32)
    if method == 'otsu':
        stretched = percentile_stretch(fused, water_mask, 2, 98)
        return otsu_threshold(stretched, water_mask)
    if method == 'zscore':
        return zscore_threshold(fused, water_mask, z_threshold)
    if method == 'percentile':
        return percentile_threshold(fused, water_mask, percentile)
    raise ValueError(f"不支持的方法: {method}")
# =============================================================================
# 连通域过滤
# =============================================================================
def filter_large_components(
    binary_img: np.ndarray,
    max_area: Optional[int] = None,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Remove connected components larger than *max_area* pixels.

    Large regions (shore, shallow water, algal bloom, ...) are dropped while
    small glint patches are kept.

    Fixes over the previous version:
    - the scipy fallback passed an undefined variable ``i`` into
      ``ndimage.sum`` and crashed with NameError; component areas are now
      computed correctly,
    - the scipy fallback used 4-connectivity (ndimage.label default) while
      the cv2 path used 8-connectivity; both now use 8-connectivity,
    - cv2's label count includes the background label, so the old
      ``num_features == 0`` check could never fire; the no-component case is
      now detected correctly.

    Args:
        binary_img: binary image
        max_area: maximum component area in pixels; None or <= 0 disables
            filtering and returns *binary_img* unchanged
        foreground: foreground value
        background: background value (kept for interface symmetry; the
            filtered output uses zeros for non-kept pixels, as before)
    Returns:
        filtered binary image
    """
    if max_area is None or max_area <= 0:
        return binary_img
    fg_mask = (binary_img == foreground).astype(np.uint8)
    if CV2_AVAILABLE:
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            fg_mask, connectivity=8
        )
        if num_labels <= 1:
            # Label 0 is the background: no foreground components at all.
            return binary_img
        areas = stats[1:, cv2.CC_STAT_AREA]
    else:
        from scipy import ndimage
        # 8-connectivity structuring element to match the cv2 branch.
        labels, num_components = ndimage.label(
            fg_mask, structure=np.ones((3, 3), dtype=np.int32)
        )
        if num_components == 0:
            return binary_img
        areas = ndimage.sum(fg_mask, labels, range(1, num_components + 1))
    keep_labels = np.flatnonzero(np.asarray(areas) <= max_area) + 1
    filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
    filtered[np.isin(labels, keep_labels)] = foreground
    return filtered
# =============================================================================
# 岸边缓冲区处理
# =============================================================================
def create_shoreline_buffer(
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Build an inward shoreline buffer mask.

    Erodes the water mask by *buffer_size* pixels and subtracts the eroded
    mask from the original, yielding a band of width *buffer_size* along the
    water edge (1 inside the band, 0 elsewhere). Used to drop spurious glint
    detections near the shore.

    Fix: ``cv2.erode`` defaults to a +infinity border value, so water
    touching the image border was never eroded there, while the scipy
    fallback treats outside-the-image as non-water and erodes the border.
    The cv2 call now passes an explicit zero border so both branches agree
    (the image border counts as shoreline).

    Args:
        water_mask: water mask (water = 1, non-water = 0)
        buffer_size: buffer width in pixels (default 5)
        foreground: kept for interface symmetry; the output is always 0/1
        background: kept for interface symmetry
    Returns:
        buffer mask (buffer band = 1, elsewhere = 0), int32
    """
    if buffer_size <= 0:
        return np.zeros_like(water_mask, dtype=np.int32)
    water = (water_mask > 0).astype(np.uint8)
    size = buffer_size * 2 + 1
    kernel = np.ones((size, size), dtype=np.uint8)
    if CV2_AVAILABLE:
        # Zero border value: pixels outside the image are treated as non-water,
        # matching scipy's binary_erosion default.
        eroded = cv2.erode(
            water, kernel, borderType=cv2.BORDER_CONSTANT, borderValue=0
        ).astype(np.int32)
    else:
        from scipy import ndimage
        eroded = ndimage.binary_erosion(water, kernel).astype(np.int32)
    return (water - eroded).astype(np.int32)
def remove_shoreline_buffer(
    glint_mask: np.ndarray,
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """Clear glint detections that fall inside the shoreline buffer.

    Args:
        glint_mask: glint detection mask
        water_mask: water mask
        buffer_size: buffer width in pixels (default 5; <= 0 is a no-op)
        foreground: forwarded to create_shoreline_buffer
        background: value written into cleared pixels
    Returns:
        glint mask copy with the shoreline band set to *background*
    """
    if buffer_size <= 0:
        return glint_mask
    shoreline = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
    cleaned_mask = glint_mask.copy()
    cleaned_mask[shoreline > 0] = background
    return cleaned_mask
# =============================================================================
# 高级组合函数
# =============================================================================
def calculate_glint_mask(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    method: str = 'otsu',
    z_threshold: float = 2.5,
    percentile: float = 95,
    iqr_multiplier: float = 1.5,
    window_size: int = 15,
    apply_percentile_stretch: bool = True
) -> np.ndarray:
    """Unified entry point for computing a glint mask.

    Args:
        nir_band: near-infrared band array
        water_mask: water mask
        method: 'otsu', 'zscore', 'percentile', 'iqr' or 'adaptive'
        z_threshold: Z-score cutoff (zscore)
        percentile: percentile cutoff (percentile / adaptive)
        iqr_multiplier: IQR multiplier (iqr)
        window_size: local window size (adaptive)
        apply_percentile_stretch: pre-stretch input for otsu/adaptive
    Returns:
        int32 binary glint mask
    Raises:
        ValueError: unknown method
    """
    if method == 'otsu':
        source = (percentile_stretch(nir_band, water_mask, 2, 98)
                  if apply_percentile_stretch
                  else nir_band.astype(np.int32))
        return otsu_threshold(source, water_mask)
    if method == 'zscore':
        return zscore_threshold(nir_band, water_mask, z_threshold)
    if method == 'percentile':
        return percentile_threshold(nir_band, water_mask, percentile)
    if method == 'iqr':
        return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
    if method == 'adaptive':
        source = (percentile_stretch(nir_band, water_mask, 2, 98)
                  if apply_percentile_stretch
                  else nir_band.astype(np.int32))
        return adaptive_threshold(source, water_mask, window_size, percentile)
    raise ValueError(f"不支持的方法: {method}")

View File

@ -0,0 +1,7 @@
"""
插值算法模块
包含0值像素插值的核心数学逻辑
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']

View File

@ -0,0 +1,320 @@
"""
像素插值算法模块
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
支持多种插值方法nearest, bilinear, spline (RBF), kriging。
"""
import numpy as np
from typing import Optional, Union, Tuple, List
from pathlib import Path
try:
from scipy import ndimage
from scipy.interpolate import griddata, RBFInterpolator
from scipy.spatial import cKDTree
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def interpolate_pixels(
    image_stack: np.ndarray,
    zero_coords: np.ndarray,
    valid_coords: np.ndarray,
    valid_values: np.ndarray,
    interpolation_method: str = 'nearest',
    water_mask: Optional[np.ndarray] = None
) -> np.ndarray:
    """Interpolate the pixels listed in *zero_coords* (pure math, no file I/O).

    Args:
        image_stack: (height, width, n_bands) float32 stack
        zero_coords: (n_zero, 2) pixel coordinates to fill, each row [x, y]
        valid_coords: (n_valid, 2) coordinates of usable pixels
        valid_values: values at valid_coords, (n_valid,) or (n_valid, n_bands)
        interpolation_method: 'nearest' | 'bilinear' | 'spline' | 'kriging';
            Chinese aliases are also accepted, unknown names map to 'nearest'
        water_mask: accepted for interface compatibility; not used here
    Returns:
        a copy of *image_stack* with the listed pixels replaced
    Raises:
        ImportError: when scipy is unavailable
    """
    if not SCIPY_AVAILABLE:
        raise ImportError("scipy未安装无法进行0值像素插值")
    n_bands = image_stack.shape[2]
    result = image_stack.copy()
    # Normalise the method name (English keywords and Chinese aliases),
    # first match wins; anything unrecognised falls back to 'nearest'.
    raw = str(interpolation_method).lower()
    method = 'nearest'
    for canonical, keywords in (
        ('nearest', ('nearest', '邻近', '最邻近')),
        ('bilinear', ('bilinear', '线性', '双线性')),
        ('spline', ('spline', '样条', 'rbf')),
        ('kriging', ('kriging', '克里金')),
    ):
        if any(keyword in raw for keyword in keywords):
            method = canonical
            break
    if len(valid_values) == 0:
        return result
    xs = zero_coords[:, 0].astype(int)
    ys = zero_coords[:, 1].astype(int)
    if len(valid_values.shape) > 1 and valid_values.shape[1] > 1:
        # Multi-band samples: interpolate each band independently.
        for band_idx in range(n_bands):
            filled = _interpolate_single_band(
                zero_coords, valid_coords, valid_values[:, band_idx], method
            )
            result[ys, xs, band_idx] = filled
    else:
        filled = _interpolate_single_band(
            zero_coords, valid_coords, valid_values, method
        )
        # Broadcasting writes the same interpolated value into every band.
        result[ys, xs] = filled
    return result
def _interpolate_single_band(
zero_coords: np.ndarray,
valid_coords: np.ndarray,
valid_values: np.ndarray,
method: str
) -> np.ndarray:
"""对单个波段执行插值计算"""
if method == 'nearest':
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords)
return valid_values[indices]
elif method == 'bilinear':
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'spline':
try:
max_points = 10000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
interpolated = rbf(zero_coords)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'kriging':
try:
from src.utils.kriging import KrigingInterpolator
interpolator = KrigingInterpolator()
max_points = 5000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
interpolated = griddata(
sample_coords, sample_values, zero_coords,
method='cubic', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
return np.zeros(len(zero_coords))
def interpolate_zero_pixels_batch(
    img_path: str,
    interpolation_method: str = 'nearest',
    output_path: Optional[str] = None,
    water_mask: Optional[Union[str, np.ndarray]] = None,
    deglint_dir: Optional[str] = None,
    callback_progress: Optional[callable] = None
) -> Tuple[str, Optional[np.ndarray]]:
    """
    Interpolate every pixel that is 0 in ALL bands (full pipeline, incl. file I/O).

    Args:
        img_path: input image file path
        interpolation_method: 'nearest', 'bilinear', 'spline' or 'kriging'.
            NOTE(review): unlike interpolate_pixels(), the name is forwarded
            to _interpolate_single_band() without alias normalisation, so
            Chinese aliases (e.g. '最邻近') silently yield zero-filled
            results here — confirm whether normalisation should be applied.
        output_path: output file path (auto-generated when None and
            *deglint_dir* is given)
        water_mask: water mask (file path or array); when given, only zero
            pixels inside the water area are interpolated
        deglint_dir: de-glint output directory, used for the default path
        callback_progress: optional progress callback receiving a status string
    Returns:
        (output_path, interpolated_image_stack); the stack is None when the
        output file already exists and the computation is skipped
    Raises:
        ImportError: scipy or GDAL not installed
        ValueError: image unreadable, or no valid pixels to sample from
    """
    if not SCIPY_AVAILABLE:
        raise ImportError("scipy未安装无法进行0值像素插值")
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")
    # Build a default output path next to the de-glint products.
    if output_path is None and deglint_dir is not None:
        output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
    # Skip the whole computation when the result already exists on disk.
    if output_path and Path(output_path).exists():
        return output_path, None
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        n_bands = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        # Load every band into memory as float32.
        all_bands = []
        for band_idx in range(1, n_bands + 1):
            band = dataset.GetRasterBand(band_idx)
            band_data = band.ReadAsArray().astype(np.float32)
            all_bands.append(band_data)
        image_stack = np.dstack(all_bands)
        # Load the water mask (raster path or ndarray).
        mask_array = None
        if water_mask is not None:
            if isinstance(water_mask, str):
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset:
                    mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                    mask_dataset = None
            elif isinstance(water_mask, np.ndarray):
                mask_array = water_mask
        # Pixels that are zero in every band (restricted to water when masked).
        all_bands_zero = np.all(image_stack == 0, axis=2)
        if mask_array is not None:
            all_bands_zero = all_bands_zero & (mask_array > 0)
        zero_pixel_count = np.sum(all_bands_zero)
        if zero_pixel_count == 0:
            # Nothing to interpolate: write the data through unchanged.
            if output_path:
                driver = gdal.GetDriverByName('ENVI')
                if driver is None:
                    driver = gdal.GetDriverByName('GTiff')
                out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
                out_dataset.SetGeoTransform(geotransform)
                out_dataset.SetProjection(projection)
                for i, band_data in enumerate(all_bands):
                    out_band = out_dataset.GetRasterBand(i + 1)
                    out_band.WriteArray(band_data)
                    out_band.FlushCache()
                out_dataset = None
            return output_path, image_stack
        # Coordinates as [x, y] rows, matching _interpolate_single_band's contract.
        zero_y, zero_x = np.where(all_bands_zero)
        zero_coords = np.column_stack([zero_x, zero_y])
        valid_mask = ~all_bands_zero
        # NOTE(review): when a mask is given, land pixels (even zero-valued
        # ones) still count as valid samples here — confirm intended.
        valid_y, valid_x = np.where(valid_mask)
        valid_coords = np.column_stack([valid_x, valid_y])
        if len(valid_coords) == 0:
            raise ValueError("没有有效像素可用于插值")
        # Interpolate band by band.
        interpolated_bands = []
        for band_idx in range(n_bands):
            if callback_progress:
                callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
            band_data = all_bands[band_idx].copy()
            valid_values_band = band_data[valid_mask]
            if len(valid_values_band) == 0:
                interpolated_bands.append(band_data)
                continue
            band_result = _interpolate_single_band(
                zero_coords, valid_coords, valid_values_band, interpolation_method
            )
            band_data[all_bands_zero] = band_result
            interpolated_bands.append(band_data)
        # Write the interpolated stack (ENVI driver preferred, GTiff fallback).
        if output_path:
            driver = gdal.GetDriverByName('ENVI')
            if driver is None:
                driver = gdal.GetDriverByName('GTiff')
            out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
            out_dataset.SetGeoTransform(geotransform)
            out_dataset.SetProjection(projection)
            for i, band_data in enumerate(interpolated_bands):
                out_band = out_dataset.GetRasterBand(i + 1)
                out_band.WriteArray(band_data)
                out_band.FlushCache()
            out_dataset = None
        result_stack = np.dstack(interpolated_bands)
        return output_path, result_stack
    finally:
        dataset = None  # release the GDAL handle

View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""
Utility modules - unified export interface.
"""
from src.core.utils.gdal_helper import (
    get_image_geo_info,
    load_image_as_array,
    save_array_as_image,
    save_bands_as_image,
    copy_hdr_info,
    read_band_as_array,
    read_multiple_bands,
)
from src.core.utils.mask_converter import (
    prepare_water_mask_for_algorithm,
    ensure_water_mask_dat,
)
from src.core.utils.preview_generator import (
    generate_image_preview,
    generate_water_mask_overlay,
    select_rgb_bands_by_wavelength,
    get_wavelength_info,
)
__all__ = [
    # GDAL IO
    'get_image_geo_info',
    'load_image_as_array',
    'save_array_as_image',
    'save_bands_as_image',
    'copy_hdr_info',
    'read_band_as_array',
    'read_multiple_bands',
    # mask conversion
    'prepare_water_mask_for_algorithm',
    'ensure_water_mask_dat',
    # preview generation
    'generate_image_preview',
    'generate_water_mask_overlay',
    'select_rgb_bands_by_wavelength',
    'get_wavelength_info',
]

View File

@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""
GDAL 底层 IO 工具模块
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
"""
import os
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
# GDAL 导入(可选)
try:
from osgeo import gdal, ogr, gdal_array
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# hdr 文件工具
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
UTIL_AVAILABLE = True
except ImportError:
UTIL_AVAILABLE = False
write_fields_to_hdrfile = None
get_hdr_file_path = None
# ============================================================
# 影像信息读取
# ============================================================
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
    """Read an image's geo-metadata without loading pixel data (saves memory).

    Args:
        img_path: image file path
    Returns:
        (geotransform, projection, width, height, n_bands)
    Raises:
        ImportError: GDAL not installed
        ValueError: file cannot be opened
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        return (
            dataset.GetGeoTransform(),
            dataset.GetProjection(),
            dataset.RasterXSize,
            dataset.RasterYSize,
            dataset.RasterCount,
        )
    finally:
        dataset = None  # release the GDAL handle
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
    """Load an image file into a numpy array.

    Warning: this loads ALL bands into memory, which is expensive for large
    images. Prefer passing the file path to the algorithm classes so they can
    process band by band via GDAL.

    Args:
        img_path: image file path
    Returns:
        (image_array, geotransform, projection) with image_array shaped
        (height, width, bands)
    Raises:
        ImportError: GDAL not installed
        ValueError: file cannot be opened
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        bands = [
            dataset.GetRasterBand(band_no).ReadAsArray()
            for band_no in range(1, dataset.RasterCount + 1)
        ]
        return np.dstack(bands), geotransform, projection
    finally:
        dataset = None  # release the GDAL handle
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
    """Read a single band as a (height, width) numpy array.

    Args:
        img_path: image file path
        band_index: zero-based band index (GDAL itself counts from 1)
    Returns:
        the band data as a numpy array
    Raises:
        ImportError: GDAL not installed
        ValueError: file cannot be opened
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        return dataset.GetRasterBand(band_index + 1).ReadAsArray()
    finally:
        dataset = None  # release the GDAL handle
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
    """Read several bands (zero-based indices) from an image.

    Args:
        img_path: image file path
        band_indices: list of zero-based band indices
    Returns:
        (band_list, geotransform, projection)
    Raises:
        ImportError: GDAL not installed
        ValueError: file cannot be opened
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        band_list = [
            dataset.GetRasterBand(idx + 1).ReadAsArray()
            for idx in band_indices
        ]
        return band_list, geotransform, projection
    finally:
        dataset = None  # release the GDAL handle
# ============================================================
# 影像写入
# ============================================================
def save_array_as_image(image_array: np.ndarray, output_path: str,
                        geotransform: tuple, projection: str,
                        dtype=None) -> str:
    """Write a numpy array to a raster file (ENVI preferred, GTiff fallback).

    Args:
        image_array: array shaped (height, width, bands) or (height, width)
        output_path: output file path
        geotransform: GDAL geotransform parameters
        projection: projection string
        dtype: GDAL data type, defaults to gdal.GDT_Float32
    Returns:
        the output file path
    Raises:
        ImportError: GDAL not installed
        ValueError: no usable driver or the file cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if dtype is None:
        dtype = gdal.GDT_Float32
    if image_array.ndim == 2:
        height, width = image_array.shape
        n_bands = 1
    else:
        height, width, n_bands = image_array.shape
    driver = None
    for driver_name in ('ENVI', 'GTiff'):
        driver = gdal.GetDriverByName(driver_name)
        if driver is not None:
            break
    if driver is None:
        raise ValueError("无法创建影像文件,没有可用的驱动")
    dataset = driver.Create(output_path, width, height, n_bands, dtype)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {output_path}")
    try:
        dataset.SetGeoTransform(geotransform)
        dataset.SetProjection(projection)
        if n_bands == 1:
            out_band = dataset.GetRasterBand(1)
            out_band.WriteArray(image_array)
            out_band.FlushCache()
        else:
            for band_pos in range(n_bands):
                out_band = dataset.GetRasterBand(band_pos + 1)
                out_band.WriteArray(image_array[:, :, band_pos])
                out_band.FlushCache()
    finally:
        dataset = None  # release the GDAL handle
    return output_path
def save_bands_as_image(corrected_bands: list, output_path: str,
                        geotransform: tuple, projection: str,
                        dtype=None) -> str:
    """Write a list of (height, width) band arrays to a raster file.

    Avoids stacking the bands into one array first, saving memory.

    Args:
        corrected_bands: list of (height, width) numpy arrays
        output_path: output file path
        geotransform: GDAL geotransform parameters
        projection: projection string
        dtype: GDAL data type, defaults to gdal.GDT_Float32
    Returns:
        the output file path
    Raises:
        ImportError: GDAL not installed
        ValueError: empty band list, size mismatch, no driver, or the file
            cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if not corrected_bands:
        raise ValueError("波段列表为空")
    if dtype is None:
        dtype = gdal.GDT_Float32
    n_bands = len(corrected_bands)
    height, width = corrected_bands[0].shape
    driver = None
    for driver_name in ('ENVI', 'GTiff'):
        driver = gdal.GetDriverByName(driver_name)
        if driver is not None:
            break
    if driver is None:
        raise ValueError("无法创建影像文件,没有可用的驱动")
    dataset = driver.Create(output_path, width, height, n_bands, dtype)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {output_path}")
    try:
        dataset.SetGeoTransform(geotransform)
        dataset.SetProjection(projection)
        for i, band_array in enumerate(corrected_bands):
            # Every band must match the size of the first one.
            if band_array.shape != (height, width):
                raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
            out_band = dataset.GetRasterBand(i + 1)
            out_band.WriteArray(band_array)
            out_band.FlushCache()
    finally:
        dataset = None  # release the GDAL handle
    return output_path
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
    """Copy hdr metadata (wavelengths etc.) from one image's hdr file to another's.

    Args:
        source_img_path: source image path (original bsq file)
        dest_img_path: destination image path (de-glinted bsq file)
    Returns:
        True on success; False when the util module or either hdr file is
        missing, or any error occurs (errors are printed, never raised).
    """
    if not UTIL_AVAILABLE:
        print("警告: util模块未导入无法复制hdr文件信息")
        return False
    try:
        src_hdr = get_hdr_file_path(source_img_path)
        dst_hdr = get_hdr_file_path(dest_img_path)
        if not Path(src_hdr).exists():
            print(f"警告: 源hdr文件不存在: {src_hdr}")
            return False
        if not Path(dst_hdr).exists():
            print(f"警告: 目标hdr文件不存在: {dst_hdr}")
            return False
        write_fields_to_hdrfile(src_hdr, dst_hdr)
        print(f"已复制原始hdr文件信息到: {dst_hdr}")
        return True
    except Exception as exc:
        print(f"警告: 复制hdr文件信息时出错: {exc}")
        return False

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
掩膜转换工具模块
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
以及水体掩膜的预处理逻辑。
"""
import os
from pathlib import Path
from typing import Optional, Union
import numpy as np
try:
from osgeo import gdal, ogr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def prepare_water_mask_for_algorithm(
    water_mask: Optional[Union[str, np.ndarray]],
    image_shape: Union[tuple, np.ndarray],
    geotransform: tuple,
    projection: str,
    img_path: str,
    water_mask_dir: Optional[str] = None,
    callback=None
) -> Optional[np.ndarray]:
    """
    Normalize the water mask input into a binary array usable by the algorithm.

    Supported inputs:
    - None: caller falls back to a pre-generated .dat mask (returns None here)
    - numpy.ndarray: returned directly, binarized to 0/1
    - raster file path (.dat / .tif / ...): loaded and returned
    - .shp path: rasterized first, then loaded and returned

    Args:
        water_mask: mask source (see above)
        image_shape: image shape, (height, width) or (height, width, channels)
        geotransform: GDAL geotransform parameters
        projection: projection string
        img_path: image path (reference for shp rasterization)
        water_mask_dir: directory for caching rasterized shp results
        callback: optional progress callback

    Returns:
        uint8 array (0 = land, 1 = water) or None

    Raises:
        ValueError: on shape mismatch or unsupported input type.
    """
    rows, cols = image_shape[0], image_shape[1]

    if water_mask is None:
        return None

    # Array input: just validate and binarize.
    if isinstance(water_mask, np.ndarray):
        if water_mask.shape[:2] != (rows, cols):
            raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(rows, cols)} 不匹配")
        return (water_mask > 0).astype(np.uint8)

    # Anything that is not an array must be a file path.
    if not isinstance(water_mask, str):
        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    # Shapefiles need rasterization; every other extension is read as a raster.
    if Path(water_mask).suffix.lower() == '.shp':
        return _convert_shp_to_mask(
            shp_path=water_mask,
            img_path=img_path,
            image_shape=image_shape,
            geotransform=geotransform,
            projection=projection,
            water_mask_dir=water_mask_dir,
            callback=callback
        )
    return _load_raster_mask(water_mask, rows, cols)
def _convert_shp_to_mask(shp_path: str, img_path: str,
                         image_shape: tuple,
                         geotransform: tuple,
                         projection: str,
                         water_mask_dir: Optional[str] = None,
                         callback=None) -> np.ndarray:
    """Rasterize a shapefile into a 0/1 water mask array, caching the result on disk."""
    from src.utils.extract_water_area import rasterize_shp

    normalized_shp = os.path.abspath(shp_path).replace('\\', '/')
    stem = Path(normalized_shp).stem
    # Cache location: caller-provided directory, otherwise /tmp.
    if water_mask_dir:
        cache_path = str(Path(water_mask_dir) / f"water_mask_{stem}.dat")
    else:
        cache_path = f"/tmp/water_mask_{stem}.dat"

    rows, cols = image_shape[0], image_shape[1]

    # Reuse a previous rasterization when available.
    if Path(cache_path).exists():
        print(f"使用已存在的栅格化掩膜: {cache_path}")
        return _load_raster_mask(cache_path, rows, cols)

    # Rasterization needs a reference image for geometry/extent.
    if img_path is None:
        raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
    print(f"正在将 SHP 栅格化: {normalized_shp}")
    rasterize_shp(normalized_shp, cache_path, img_path)
    return _load_raster_mask(cache_path, rows, cols)
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
    """Read band 1 of a raster mask file and binarize it to a uint8 0/1 array."""
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取掩膜文件")
    dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开掩膜文件: {mask_path}")
    try:
        raw = dataset.GetRasterBand(1).ReadAsArray()
    finally:
        dataset = None  # release the GDAL dataset handle
    if raw.shape != (img_height, img_width):
        raise ValueError(f"掩膜尺寸 {raw.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
    # Anything strictly positive counts as water.
    return (raw > 0).astype(np.uint8)
def ensure_water_mask_dat(img_path: str,
                          existing_dat_path: Optional[str] = None,
                          output_dir: Optional[str] = None) -> str:
    """
    Resolve a .dat-format water-mask path for the given image (used by steps 3/4).

    Resolution order:
    1. existing_dat_path, when it is an existing .dat file;
    2. <output_dir>/<stem>_water_mask.dat, when it already exists;
    3. a sibling mask in .tif/.img/.tiff format, converted to .dat;
    4. otherwise the target .dat path (caller decides whether to generate it).

    Args:
        img_path: image path used to derive the mask name
        existing_dat_path: optional pre-existing .dat mask path
        output_dir: optional output directory (defaults to the image's directory)

    Returns:
        Path of the .dat mask file (may not exist yet, see case 4).
    """
    if (existing_dat_path
            and Path(existing_dat_path).suffix.lower() == '.dat'
            and Path(existing_dat_path).exists()):
        return existing_dat_path

    stem = Path(img_path).stem
    base_dir = Path(output_dir) if output_dir is not None else Path(img_path).parent
    dat_path = str(base_dir / f"{stem}_water_mask.dat")
    if Path(dat_path).exists():
        return dat_path

    # A mask may exist in another raster format; convert it to .dat.
    for ext in ('.tif', '.img', '.tiff'):
        candidate = base_dir / f"{stem}_water_mask{ext}"
        if candidate.exists():
            return _convert_to_dat(str(candidate), dat_path)

    # Nothing found: hand back the target path for the caller to generate.
    return dat_path
def _convert_to_dat(src_path: str, dest_path: str) -> str:
    """
    Convert a raster mask in another format into an ENVI .dat file.

    The output is binarized: every pixel > 0 becomes 1, the rest 0,
    stored as a single Byte band with the source georeferencing.

    Args:
        src_path: source raster path (.tif/.img/.tiff/...)
        dest_path: destination ENVI .dat path

    Returns:
        dest_path on success.

    Raises:
        ImportError: when GDAL is not installed.
        ValueError: when the source cannot be opened, no usable driver is
            available, or the destination cannot be created.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法转换格式")
    src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
    if src_ds is None:
        raise ValueError(f"无法打开源掩膜文件: {src_path}")
    try:
        geotransform = src_ds.GetGeoTransform()
        projection = src_ds.GetProjection()
        band = src_ds.GetRasterBand(1)
        array = band.ReadAsArray()
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            driver = gdal.GetDriverByName('GTiff')
        if driver is None:
            # Fix: previously a missing driver fell through to driver.Create()
            # and crashed with AttributeError; raise the same error message
            # used by the image-saving helper for consistency.
            raise ValueError("无法创建影像文件,没有可用的驱动")
        dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
        if dest_ds is None:
            raise ValueError(f"无法创建输出文件: {dest_path}")
        try:
            dest_ds.SetGeoTransform(geotransform)
            dest_ds.SetProjection(projection)
            dest_band = dest_ds.GetRasterBand(1)
            dest_band.WriteArray((array > 0).astype(np.uint8))
            dest_band.FlushCache()
        finally:
            dest_ds = None  # flush and close the output dataset
        return dest_path
    finally:
        src_ds = None  # close the source dataset

View File

@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
遥感影像预览图生成工具模块
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
"""
import numpy as np
from pathlib import Path
from typing import Optional, List
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# matplotlib 仅在实际使用时导入preview_generator 是可视化工具)
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
# ============================================================
# 辅助函数:波段选择
# ============================================================
def select_rgb_bands_by_wavelength(band_count: int,
                                   wavelength_info: Optional[List[float]] = None,
                                   fallback_bands: Optional[List[int]] = None) -> List[int]:
    """
    Choose RGB band indices by proximity to reference wavelengths.

    For each target wavelength — 650 nm (R), 550 nm (G), 460 nm (B) — the
    band with the closest wavelength is selected. When no wavelength
    information is available, falls back to the given (or default: last
    three) band indices, clamped into [0, band_count - 1].

    Args:
        band_count: total number of bands
        wavelength_info: per-band wavelengths in nm (length band_count)
        fallback_bands: [R, G, B] indices used when wavelengths are unusable

    Returns:
        Band index list [R_index, G_index, B_index], 0-based.
    """
    if fallback_bands is None:
        fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
    clamped = [max(0, min(idx, band_count - 1)) for idx in fallback_bands]
    if wavelength_info is None:
        return clamped
    try:
        def nearest(target: float) -> int:
            # First index minimizing |wavelength - target|; 0 for an empty list.
            return min(range(len(wavelength_info)),
                       key=lambda i: abs(wavelength_info[i] - target),
                       default=0)
        # Target wavelengths (nm) for R, G, B.
        return [nearest(t) for t in (650.0, 550.0, 460.0)]
    except Exception:
        return clamped
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
    """
    Read per-band wavelengths from the ENVI .hdr file next to *img_path*.

    Parses the ``wavelength = { ... }`` entry, which may be on a single
    line or span multiple lines.

    Args:
        img_path: image path; the .hdr path is derived by swapping the suffix.

    Returns:
        List of wavelengths (nm), or None when the header is missing,
        unreadable, or contains no wavelength values.
    """
    try:
        hdr_path = Path(img_path).with_suffix('.hdr')
        if not hdr_path.exists():
            return None
        wavelengths: List[float] = []
        in_wl = False
        with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                line = line.strip()
                if line.startswith('wavelength ='):
                    in_wl = True
                    line = line.split('=', 1)[1].strip()
                elif not in_wl:
                    continue
                # Fix: the original only parsed values on lines *after* the
                # "wavelength =" line, so single-line entries such as
                # "wavelength = {450, 550}" yielded nothing — and in_wl was
                # never reset, letting later fields (e.g. fwhm) leak into the
                # wavelength list. Now the entry line itself is parsed too.
                if line.startswith('{'):
                    line = line[1:]
                if line.endswith('}'):
                    line = line[:-1]
                    in_wl = False
                # Values are comma-separated; skip any non-numeric token.
                for token in line.replace(',', ' ').split():
                    try:
                        wavelengths.append(float(token))
                    except ValueError:
                        pass
        return wavelengths if wavelengths else None
    except Exception:
        return None
# ============================================================
# 核心预览图生成函数
# ============================================================
def generate_image_preview(img_path: str,
                           output_path: str,
                           bands: Optional[List[int]] = None,
                           title: str = "影像预览") -> str:
    """
    Generate a PNG quick-look preview of a hyperspectral image.

    Reads three bands (auto-selected by wavelength when possible), applies a
    2%–98% percentile linear stretch per channel, and saves an RGB composite
    with a resolution/size footer when georeferencing is present.

    Args:
        img_path: input image path (any GDAL-readable raster)
        output_path: output PNG path; if it already exists, generation is
            skipped and the existing path is returned
        bands: RGB band indices [R, G, B] (0-based); None selects automatically
        title: figure title

    Returns:
        Path of the generated (or pre-existing) PNG file.

    Raises:
        ImportError: when GDAL is not available.
        ValueError: when the image cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法生成影像预览图")
    # Idempotent: reuse a previously generated preview.
    if Path(output_path).exists():
        print(f"检测到已存在的预览图,跳过生成: {output_path}")
        return output_path
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        band_count = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()
        # Auto-select bands: by wavelength for >= 3 bands, else grayscale from band 0.
        if bands is None:
            if band_count >= 3:
                wl_info = get_wavelength_info(img_path)
                bands = select_rgb_bands_by_wavelength(band_count, wl_info)
            else:
                bands = [0, 0, 0]
        # Read bands; for 1- or 2-band images g/b alias r_data (shared array).
        r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
        g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
        b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
        # Treat non-positive pixels as nodata; the guards skip re-touching aliases.
        r_data[r_data <= 0] = np.nan
        if band_count > 1:
            g_data[g_data <= 0] = np.nan
        if band_count > 2:
            b_data[b_data <= 0] = np.nan
        # Percentile linear stretch to [0, 1]; NaN (nodata) pixels render black.
        def linear_stretch(data, low=2, high=98):
            """Stretch *data* between its low/high percentiles; NaN -> 0."""
            valid = data[~np.isnan(data)]
            if len(valid) == 0:
                return np.zeros_like(data)
            lo = np.percentile(valid, low)
            hi = np.percentile(valid, high)
            if hi - lo < 1e-10:
                return np.zeros_like(data)
            stretched = np.clip((data - lo) / (hi - lo), 0, 1)
            return np.nan_to_num(stretched, nan=0.0)
        r_s = linear_stretch(r_data)
        g_s = linear_stretch(g_data) if band_count > 1 else r_s
        b_s = linear_stretch(b_data) if band_count > 2 else r_s
        rgb_image = np.stack([r_s, g_s, b_s], axis=2)
        # Plot without axes; add a resolution footer when georeferencing exists.
        fig, ax = plt.subplots(figsize=(12, 10))
        ax.imshow(rgb_image)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')
        geotransform = dataset.GetGeoTransform()  # NOTE: re-read; same value as above
        if geotransform and geotransform[1] != 0:
            pixel_size_x = abs(geotransform[1])
            scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
            fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
                     color='white',
                     bbox=dict(facecolor='black', alpha=0.6,
                               boxstyle='round,pad=0.3'))
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)
        return output_path
    finally:
        dataset = None  # close the GDAL dataset
def generate_water_mask_overlay(img_path: str,
                                mask_path: str,
                                output_path: str,
                                bands: Optional[List[int]] = None,
                                mask_color: tuple = (0, 100, 255),
                                mask_alpha: float = 0.5) -> str:
    """
    Render the water mask alpha-blended over an RGB preview and save as PNG.

    Args:
        img_path: input image path
        mask_path: water mask raster path (non-zero pixels count as water)
        output_path: output PNG path; if it exists, generation is skipped
        bands: RGB band indices [R, G, B] (0-based); None selects automatically
        mask_color: overlay color (R, G, B)
        mask_alpha: overlay opacity (0 = transparent, 1 = opaque)

    Returns:
        Path of the generated (or pre-existing) PNG file.

    Raises:
        ImportError: when GDAL is not available.
        ValueError: when the image cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法生成叠加图")
    # Idempotent: reuse a previously generated overlay.
    if Path(output_path).exists():
        print(f"检测到已存在的叠加图,跳过生成: {output_path}")
        return output_path
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    try:
        band_count = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()
        # Auto-select bands: by wavelength for >= 3 bands, else grayscale from band 0.
        if bands is None:
            if band_count >= 3:
                wl_info = get_wavelength_info(img_path)
                bands = select_rgb_bands_by_wavelength(band_count, wl_info)
            else:
                bands = [0, 0, 0]
        # Read bands; for 1- or 2-band images g/b alias r_data (shared array).
        r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
        g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
        b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
        # Treat non-positive pixels as nodata.
        r_data[r_data <= 0] = np.nan
        if band_count > 1:
            g_data[g_data <= 0] = np.nan
        if band_count > 2:
            b_data[b_data <= 0] = np.nan
        # Percentile linear stretch to [0, 1]; NaN (nodata) pixels render black.
        def linear_stretch(data, low=2, high=98):
            """Stretch *data* between its low/high percentiles; NaN -> 0."""
            valid = data[~np.isnan(data)]
            if len(valid) == 0:
                return np.zeros_like(data)
            lo = np.percentile(valid, low)
            hi = np.percentile(valid, high)
            if hi - lo < 1e-10:
                return np.zeros_like(data)
            stretched = np.clip((data - lo) / (hi - lo), 0, 1)
            return np.nan_to_num(stretched, nan=0.0)
        r_s = linear_stretch(r_data)
        g_s = linear_stretch(g_data) if band_count > 1 else r_s
        b_s = linear_stretch(b_data) if band_count > 2 else r_s
        rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
        rgb_image = rgb_image.astype(np.uint8)
        # Read the mask (best effort: a missing mask produces a plain preview).
        mask_dataset = gdal.Open(mask_path)
        if mask_dataset is not None:
            mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None
        else:
            print(f"警告: 无法打开掩膜文件: {mask_path}")
            mask_data = None
        # Alpha-blend the mask color onto the RGB preview.
        # Fix 1: removed a dead RGBA `overlay` buffer that was built but never used.
        # Fix 2: masks produced by this pipeline store 0/1 (not 0/255), so the
        # old `mask / 255.0 * mask_alpha` made the overlay nearly invisible;
        # binarize instead (0/255 masks blend identically).
        blended = rgb_image.astype(np.float32)
        if mask_data is not None:
            alpha = (mask_data > 0).astype(np.float32) * mask_alpha
            for c in range(3):
                blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
        blended = blended.astype(np.uint8)
        # Plot with a legend and, when georeferencing exists, an area summary.
        fig, ax = plt.subplots(figsize=(14, 10))
        ax.imshow(blended)
        ax.axis('off')
        legend_elements = [
            Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
                  edgecolor='black', alpha=mask_alpha, label='水域范围')
        ]
        ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
        # Area statistics from pixel size (m) -> km².
        if geotransform and geotransform[1] != 0:
            pixel_size_x = abs(geotransform[1])
            pixel_size_y = abs(geotransform[5])
            pixel_area = pixel_size_x * pixel_size_y
            if mask_data is not None:
                water_pixels = np.sum(mask_data > 0)
                # NOTE(review): for uint8 masks `>= 0` counts every pixel,
                # i.e. "total image area" — confirm that is the intent.
                valid_pixels = np.sum(mask_data >= 0)
                water_km2 = water_pixels * pixel_area / 1_000_000
                valid_km2 = valid_pixels * pixel_area / 1_000_000
                pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
                area_text = (f'水域面积: {water_km2:.2f} km² | '
                             f'影像总面积: {valid_km2:.2f} km² | '
                             f'占比: {pct:.1f}%')
                ax.text(0.02, 0.98, area_text,
                        transform=ax.transAxes, fontsize=11,
                        color='white', fontweight='bold',
                        bbox=dict(facecolor='#0064FF', alpha=0.8,
                                  edgecolor='black', boxstyle='round,pad=0.5'),
                        verticalalignment='top')
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)
        return output_path
    finally:
        dataset = None  # close the GDAL dataset

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 统一导出接口
本模块从各子模块导入可视化函数,提供统一的导出接口。
"""
from src.core.visualization.scatter_plot import generate_model_scatter_plots
from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots
from src.core.visualization.boxplot import generate_boxplots
from src.core.visualization.statistics import generate_statistical_charts
from src.core.visualization.preview import generate_glint_deglint_previews
from src.core.visualization.report import generate_pipeline_report
__all__ = [
'generate_model_scatter_plots',
'generate_spectrum_comparison_plots',
'generate_boxplots',
'generate_statistical_charts',
'generate_glint_deglint_previews',
'generate_pipeline_report',
]

View File

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 箱型图生成
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Optional, Dict, List
sns.set_style("whitegrid")
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
def generate_boxplots(
    csv_path: str,
    parameter_columns: Optional[List[str]] = None,
    data_start_column: int = 4,
    save_individual: bool = True,
    use_seaborn: bool = True,
    output_dir: Optional[str] = None
) -> Dict[str, str]:
    """
    Generate box plots for water-quality parameters.

    Produces one PNG per parameter (optional) plus a combined box plot of
    all parameters, with data points overlaid.

    Args:
        csv_path: CSV file path
        parameter_columns: parameter column names; auto-detected when None
        data_start_column: 0-based index of the first data column used for
            auto-detection (default 4, i.e. the 5th column)
        save_individual: whether to save a separate box plot per parameter
        use_seaborn: use seaborn styling (nicer) instead of bare matplotlib
        output_dir: output directory; None uses <csv dir>/visualization/boxplots

    Returns:
        Mapping of parameter name (plus 'all_parameters') to PNG file path.
    """
    print("\n" + "="*80)
    print("生成水质参数箱型图")
    print("="*80)
    if csv_path is None:
        raise ValueError("请提供 csv_path")
    # Resolve the output directory.
    if output_dir is None:
        csv_dir = Path(csv_path).parent
        output_dir = str(csv_dir / "visualization" / "boxplots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Load the data.
    df = pd.read_csv(csv_path)
    # Determine parameter columns: auto-detect from data_start_column onward,
    # or keep only the requested names that actually exist in the CSV.
    if parameter_columns is None:
        data_columns = df.iloc[:, data_start_column:]
        parameter_columns = list(data_columns.columns)
    else:
        parameter_columns = [col for col in parameter_columns if col in df.columns]
    if not parameter_columns:
        print("警告: 未找到有效的参数列")
        return {}
    boxplot_dir = Path(output_dir)
    boxplot_paths = {}
    if save_individual:
        print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)")
        for column in parameter_columns:
            if column not in df.columns:
                continue
            clean_data = df[column].dropna()
            if len(clean_data) == 0:
                print(f"跳过列 '{column}': 没有有效数据")
                continue
            # Per-parameter failures are logged and skipped so one bad column
            # does not abort the whole batch.
            try:
                plt.figure(figsize=(8, 6))
                if use_seaborn:
                    plot_data = pd.DataFrame({
                        '参数': [column] * len(clean_data),
                        '数值': clean_data
                    })
                    sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
                    sns.stripplot(data=plot_data, x='参数', y='数值',
                                  color='red', alpha=0.6, size=5, jitter=True)
                else:
                    box_plot = plt.boxplot([clean_data], labels=[column],
                                           patch_artist=True, showfliers=False)
                    box_plot['boxes'][0].set_facecolor('lightblue')
                    box_plot['boxes'][0].set_alpha(0.7)
                    # Jittered scatter of raw points on top of the box.
                    x_pos = np.random.normal(1, 0.04, size=len(clean_data))
                    plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
                                edgecolors='black', linewidth=0.5, zorder=3)
                plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
                plt.xlabel('参数', fontsize=12)
                plt.ylabel('数值', fontsize=12)
                # Summary statistics box in the top-left corner.
                stats_text = (f'数据点数: {len(clean_data)}\n'
                              f'均值: {clean_data.mean():.2f}\n'
                              f'中位数: {clean_data.median():.2f}\n'
                              f'标准差: {clean_data.std():.2f}')
                plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
                         verticalalignment='top',
                         bbox=dict(boxstyle='round',
                                   facecolor='wheat' if not use_seaborn else 'lightgreen',
                                   alpha=0.8))
                plt.grid(True, alpha=0.3, linestyle='--')
                plt.tight_layout()
                # Sanitize the column name for use in a filename.
                safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
                save_path = boxplot_dir / f'{safe_column_name}_boxplot.png'
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                plt.close()
                boxplot_paths[column] = str(save_path)
                print(f"  已保存: {save_path.name}")
            except Exception as e:
                print(f"  处理参数 {column} 时出错: {e}")
                continue
    # Combined box plot: all parameters on a single figure.
    try:
        print("\n生成综合箱型图(所有参数在一张图上)")
        plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
        box_data = []
        labels = []
        for column in parameter_columns:
            if column in df.columns:
                clean_data = df[column].dropna()
                if len(clean_data) > 0:
                    box_data.append(clean_data)
                    labels.append(column)
        if box_data:
            if use_seaborn:
                # Long-format data for seaborn.
                melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
                melted_data = melted_data.dropna()
                sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
                sns.stripplot(data=melted_data, x='参数', y='数值',
                              color='red', alpha=0.6, size=4, jitter=True)
            else:
                box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
                colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
                for patch, color in zip(box_plot['boxes'], colors):
                    patch.set_facecolor(color)
                    patch.set_alpha(0.7)
                for i, data in enumerate(box_data):
                    x_pos = np.random.normal(i + 1, 0.04, size=len(data))
                    plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
                                edgecolors='black', linewidth=0.5, zorder=3)
            plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
            plt.xlabel('参数', fontsize=12)
            plt.ylabel('数值', fontsize=12)
            plt.xticks(rotation=45, ha='right')
            plt.grid(True, alpha=0.3, linestyle='--')
            plt.tight_layout()
            combined_path = boxplot_dir / 'all_parameters_boxplot.png'
            plt.savefig(combined_path, dpi=300, bbox_inches='tight')
            plt.close()
            boxplot_paths['all_parameters'] = str(combined_path)
            print(f"  已保存综合箱型图: {combined_path.name}")
    except Exception as e:
        print(f"生成综合箱型图时出错: {e}")
    print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表")
    return boxplot_paths

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 耀斑影像预览图生成
"""
from pathlib import Path
from typing import Optional, Dict
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_glint_deglint_previews(
    work_dir: str,
    output_subdir: str = "glint_deglint_previews",
    generate_glint: bool = True,
    generate_deglint: bool = True,
    output_dir: Optional[str] = None
) -> Dict[str, str]:
    """
    Generate PNG previews for images in the 2_glint and 3_deglint folders.

    Thin wrapper that delegates to WaterQualityVisualization; errors are
    logged and an empty mapping is returned instead of raising.

    Args:
        work_dir: working directory containing the glint/deglint folders
        output_subdir: name of the output subdirectory
        generate_glint: whether to process the 2_glint folder
        generate_deglint: whether to process the 3_deglint folder
        output_dir: output directory; None uses <work_dir>/visualization/<output_subdir>

    Returns:
        Mapping of preview name to generated PNG path ({} on failure).
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print("步骤: 生成耀斑分析影像预览图")
    print(f"{banner}")
    if work_dir is None:
        raise ValueError("请提供 work_dir")
    # Resolve the output directory and make sure it exists.
    if output_dir is None:
        output_dir = str(Path(work_dir) / "visualization" / output_subdir)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    visualizer = WaterQualityVisualization(output_dir)
    try:
        preview_paths = visualizer.generate_glint_deglint_previews(
            work_dir=work_dir,
            output_subdir=output_subdir,
            generate_glint=generate_glint,
            generate_deglint=generate_deglint
        )
    except Exception as exc:
        print(f"生成耀斑分析影像预览图时出错: {exc}")
        return {}
    print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图")
    return preview_paths

View File

@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 流程执行报告生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict
from datetime import datetime
def generate_pipeline_report(
    step_timings: Dict,
    pipeline_start_time: Optional[float] = None,
    pipeline_end_time: Optional[float] = None,
    output_path: Optional[str] = None
) -> str:
    """
    Generate a pipeline execution report (CSV plus a .txt twin) with
    per-step timing statistics.

    Args:
        step_timings: mapping of step name -> timing dict with keys
            start_time, end_time, elapsed_seconds, elapsed_formatted,
            status ('completed'/'failed'/'skipped') and optional 'error'
        pipeline_start_time: overall start timestamp (time.time())
        pipeline_end_time: overall end timestamp
        output_path: CSV output path; auto-generated under ./reports when None

    Returns:
        Path of the CSV report (a text report is written alongside it).
    """
    if output_path is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    def _format_time(seconds: float) -> str:
        """Humanize a duration in seconds (秒/分/小时)."""
        if seconds < 60:
            return f"{seconds:.2f}秒"
        elif seconds < 3600:
            minutes = int(seconds // 60)
            secs = seconds % 60
            return f"{minutes}分{secs:.2f}秒"
        else:
            hours = int(seconds // 3600)
            minutes = int((seconds % 3600) // 60)
            secs = seconds % 60
            return f"{hours}小时{minutes}分{secs:.2f}秒"

    # Canonical display order of pipeline steps (these are runtime dict keys).
    step_order = [
        "步骤1: 生成水域mask",
        "步骤2: 找到耀斑区域",
        "步骤3: 去除耀斑",
        "步骤4: 处理CSV文件",
        "步骤5: 提取训练样本点光谱",
        "步骤5.5: 计算水质光谱指数",
        "步骤6: 训练机器学习模型",
        "步骤6.5: 非经验模型训练",
        "步骤6.75: 自定义回归",
        "步骤7: 生成预测采样点",
        "步骤8: 预测水质参数",
        "步骤9: 生成分布图"
    ]
    # Build CSV rows in canonical order; steps never run are simply absent.
    # (Fix: removed a `total_time` accumulator that was computed but never used.)
    report_data = []
    for step_name in step_order:
        if step_name in step_timings:
            timing_info = step_timings[step_name]
            report_data.append({
                '步骤': step_name,
                '开始时间': timing_info['start_time'],
                '结束时间': timing_info['end_time'],
                '耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
                '耗时(格式化)': timing_info['elapsed_formatted'],
                '状态': timing_info['status'],
                '错误信息': timing_info.get('error', '')
            })
    # Overall total row, when both pipeline timestamps are available.
    if pipeline_start_time and pipeline_end_time:
        pipeline_total = pipeline_end_time - pipeline_start_time
        report_data.append({
            '步骤': '总计',
            '开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
            '结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
            '耗时(秒)': f"{pipeline_total:.2f}",
            '耗时(格式化)': _format_time(pipeline_total),
            '状态': 'completed',
            '错误信息': ''
        })
    df_report = pd.DataFrame(report_data)
    # utf-8-sig so Excel opens the Chinese headers correctly.
    df_report.to_csv(output_path, index=False, encoding='utf-8-sig')

    # Human-readable text twin next to the CSV.
    txt_output_path = str(Path(output_path).with_suffix('.txt'))
    with open(txt_output_path, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write("水质参数反演流程执行报告\n")
        f.write("="*80 + "\n\n")
        if pipeline_start_time and pipeline_end_time:
            f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n")
        f.write("-"*80 + "\n")
        f.write("各步骤执行详情:\n")
        f.write("-"*80 + "\n\n")
        for step_name in step_order:
            if step_name in step_timings:
                timing_info = step_timings[step_name]
                f.write(f"{step_name}\n")
                f.write(f"  开始时间: {timing_info['start_time']}\n")
                f.write(f"  结束时间: {timing_info['end_time']}\n")
                f.write(f"  耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
                f.write(f"  状态: {timing_info['status']}\n")
                if timing_info.get('error'):
                    f.write(f"  错误: {timing_info['error']}\n")
                f.write("\n")
        f.write("-"*80 + "\n")
        f.write("统计摘要:\n")
        f.write("-"*80 + "\n")
        completed_steps = [s for s in step_timings.values() if s['status'] == 'completed']
        failed_steps = [s for s in step_timings.values() if s['status'] == 'failed']
        skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped']
        f.write(f"成功完成的步骤: {len(completed_steps)}\n")
        f.write(f"失败的步骤: {len(failed_steps)}\n")
        f.write(f"跳过的步骤: {len(skipped_steps)}\n")
        if completed_steps:
            completed_times = [s['elapsed_seconds'] for s in completed_steps]
            # Fix: slowest/fastest were found by re-scanning the list inside an
            # f-string (O(n^2) and hard to read); use max/min with a key.
            slowest = max(completed_steps, key=lambda s: s['elapsed_seconds'])
            fastest = min(completed_steps, key=lambda s: s['elapsed_seconds'])
            f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n")
            f.write(f"最长耗时: {_format_time(slowest['elapsed_seconds'])} ({slowest['elapsed_formatted']})\n")
            f.write(f"最短耗时: {_format_time(fastest['elapsed_seconds'])} ({fastest['elapsed_formatted']})\n")
    print(f"\n流程报告已生成:")
    print(f"  CSV格式: {output_path}")
    print(f"  文本格式: {txt_output_path}")
    return output_path

View File

@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 散点图生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.core.prediction.inference_batch import WaterQualityInference
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_model_scatter_plots(
    models_dir: str,
    training_csv_path: str,
    output_dir: Optional[str] = None,
    metric: str = 'test_r2',
    use_enhanced: bool = True,
    feature_start_column: Union[str, int] = 13,
    test_size: float = 0.2,
    random_state: int = 42,
    scatter_batch=None  # optional: pass in an already-instantiated scatter_batch object
) -> Dict[str, str]:
    """
    Generate model-evaluation scatter plots (true vs. predicted values).

    Tries the enhanced plotter (with confidence intervals, via the project's
    `sctter_batch` module — note the module name's spelling) first; on failure
    falls back to a basic per-target scatter plot.

    Args:
        models_dir: directory containing saved models (one subfolder per target)
        training_csv_path: training-data CSV path
        output_dir: output directory; None uses <models parent>/14_visualization/scatter_plots
        metric: metric used to select the best model
        use_enhanced: use the enhanced plotter with confidence intervals
        feature_start_column: feature start column name or index
        test_size: test split fraction
        random_state: random seed for the split
        scatter_batch: optional pre-built WaterQualityScatterBatch instance

    Returns:
        Mapping of target parameter name to scatter-plot file path.
    """
    print("\n" + "="*80)
    print("生成模型评估散点图")
    print("="*80)
    if training_csv_path is None:
        raise ValueError("请提供 training_csv_path")
    models_path = Path(models_dir)
    if not models_path.exists():
        raise ValueError(f"模型目录不存在: {models_dir}")
    # Resolve the output directory.
    if output_dir is None:
        output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Build the visualizer.
    visualizer = WaterQualityVisualization(output_dir)
    scatter_paths = {}
    # Enhanced scatter plots (with confidence intervals).
    if use_enhanced:
        print("使用增强版散点图(带置信区间)")
        try:
            from src.core.prediction.sctter_batch import WaterQualityScatterBatch
            if scatter_batch is None:
                scatter_batch = WaterQualityScatterBatch()
            results = scatter_batch.batch_plot_scatter(
                models_root_dir=models_dir,
                csv_path=training_csv_path,
                output_dir=output_dir,
                metric=metric,
                target_column=None,
                feature_start_column=feature_start_column,
                test_size=test_size,
                random_state=random_state
            )
            for target_name, result in results.items():
                if result.get('status') == 'success':
                    scatter_paths[target_name] = result.get('save_path', '')
                    print(f"  {target_name}: {result.get('save_path', '')}")
                else:
                    print(f"  {target_name}: 失败 - {result.get('error', '未知错误')}")
        except Exception as e:
            # Any failure here triggers the basic fallback below.
            print(f"使用增强版散点图时出错: {e}")
            print("回退到基础版散点图")
            use_enhanced = False
    # Basic scatter plots (also used when the enhanced path produced nothing).
    if not use_enhanced or not scatter_paths:
        print("使用基础版散点图")
        for target_folder in models_path.iterdir():
            if not target_folder.is_dir():
                continue
            target_name = target_folder.name
            print(f"\n处理目标参数: {target_name}")
            try:
                inferencer = WaterQualityInference(str(target_folder))
                eval_result = inferencer.evaluate_with_split(
                    data_csv_path=training_csv_path,
                    split_method="spxy",
                    test_size=test_size,
                    random_state=random_state,
                    metric=metric
                )
                predictions = eval_result.get('predictions', {})
                if predictions:
                    y_train_true = predictions.get('y_train_true')
                    y_train_pred = predictions.get('y_train_pred')
                    y_test_true = predictions.get('y_test_true')
                    y_test_pred = predictions.get('y_test_pred')
                    metrics = eval_result.get('test_metrics', {})
                    if y_train_true is not None and y_test_true is not None:
                        # Concatenate train+test so both appear in one plot,
                        # distinguished via the index arrays below.
                        y_all_true = np.concatenate([y_train_true, y_test_true])
                        y_all_pred = np.concatenate([y_train_pred, y_test_pred])
                        train_indices = np.arange(len(y_train_true))
                        test_indices = np.arange(len(y_train_true), len(y_all_true))
                        scatter_path = visualizer.plot_scatter_true_vs_pred(
                            y_true=y_all_true,
                            y_pred=y_all_pred,
                            target_name=target_name,
                            train_indices=train_indices,
                            test_indices=test_indices,
                            metrics={
                                'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
                                'test_r2': metrics.get('r2', 0),
                                'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
                                'test_rmse': metrics.get('rmse', 0)
                            }
                        )
                        scatter_paths[target_name] = scatter_path
            except Exception as e:
                # Per-target failures are logged and skipped.
                print(f"处理目标参数 {target_name} 时出错: {e}")
                continue
    print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
    return scatter_paths

View File

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 光谱曲线图生成
"""
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_spectrum_comparison_plots(
    csv_path: str,
    parameter_columns: Optional[List[str]] = None,
    wavelength_start_column: Union[str, int] = "UTM_Y",
    output_dir: Optional[str] = None
) -> Dict[str, str]:
    """
    Generate spectrum-comparison plots (spectral curves grouped by parameter value).

    Args:
        csv_path: CSV file containing spectra and parameter values
        parameter_columns: parameter column names; auto-detected when None
        wavelength_start_column: name or index of the first wavelength column
        output_dir: output directory; None uses <csv dir>/visualization/spectrum_plots

    Returns:
        Mapping of parameter name to spectrum-plot file path.
    """
    print("\n" + "="*80)
    print("生成光谱曲线对比图")
    print("="*80)
    if csv_path is None:
        raise ValueError("请提供 csv_path")
    # Resolve the output directory.
    if output_dir is None:
        csv_dir = Path(csv_path).parent
        output_dir = str(csv_dir / "visualization" / "spectrum_plots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Build the visualizer.
    visualizer = WaterQualityVisualization(output_dir)
    # Read the data to detect parameter columns.
    df = pd.read_csv(csv_path)
    if parameter_columns is None:
        if isinstance(wavelength_start_column, str):
            try:
                wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
            except KeyError:
                # Fix: was a bare `except:`; Index.get_loc raises KeyError for
                # a missing column, and a bare except also swallowed
                # KeyboardInterrupt/SystemExit.
                wavelength_start_idx = 13
        else:
            wavelength_start_idx = wavelength_start_column
        parameter_columns = list(df.columns[:wavelength_start_idx])
        # Drop the leading ID/coordinate columns, keeping actual parameters.
        if len(parameter_columns) > 2:
            parameter_columns = parameter_columns[2:]
    spectrum_paths = {}
    for param_col in parameter_columns:
        if param_col not in df.columns:
            continue
        print(f"\n处理参数: {param_col}")
        # Per-parameter failures are logged and skipped.
        try:
            spectrum_path = visualizer.plot_spectrum_by_parameter(
                csv_path=csv_path,
                parameter_column=param_col,
                wavelength_start_column=wavelength_start_column,
                n_groups=5
            )
            spectrum_paths[param_col] = spectrum_path
        except Exception as e:
            print(f"处理参数 {param_col} 时出错: {e}")
            continue
    print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表")
    return spectrum_paths

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 统计图表生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_statistical_charts(
    csv_path: str,
    parameter_columns: Optional[List[str]] = None,
    output_dir: Optional[str] = None
) -> Dict[str, str]:
    """
    Generate statistical charts (box plots, histograms, correlation heatmap).

    Args:
        csv_path: CSV file path
        parameter_columns: parameter column names; auto-detected when None
            (numeric columns from the 3rd column onward)
        output_dir: output directory; None uses <csv dir>/visualization/statistical_charts

    Returns:
        Mapping of chart name to file path (as returned by the visualizer).
    """
    print("\n" + "="*80)
    print("生成统计图表")
    print("="*80)
    if csv_path is None:
        raise ValueError("请提供 csv_path")
    # Resolve the output directory.
    if output_dir is None:
        csv_dir = Path(csv_path).parent
        output_dir = str(csv_dir / "visualization" / "statistical_charts")
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Build the visualizer.
    visualizer = WaterQualityVisualization(output_dir)
    # Read the data to detect parameter columns.
    df = pd.read_csv(csv_path)
    if parameter_columns is None:
        parameter_columns = list(df.columns[2:])
        # Fix: `dtype in [np.float64, np.int64]` silently dropped int32/float32
        # columns (common when CSVs are read on 32-bit builds/Windows); use the
        # pandas dtype predicates instead, still excluding booleans.
        parameter_columns = [col for col in parameter_columns
                             if pd.api.types.is_numeric_dtype(df[col])
                             and not pd.api.types.is_bool_dtype(df[col])]
    chart_paths = visualizer.plot_statistical_charts(
        csv_path=csv_path,
        parameter_columns=parameter_columns
    )
    print(f"\n统计图表生成完成")
    return chart_paths

File diff suppressed because it is too large Load Diff