feat: Kutser算法分块读写改造 + GUI标题更名为Mega Water

This commit is contained in:
DXC
2026-05-09 13:30:33 +08:00
parent 820986d975
commit 82af2d75d3
3 changed files with 384 additions and 451 deletions

View File

@ -1,5 +1,4 @@
import numpy as np
# import preprocessing
import os
try:
@ -8,306 +7,333 @@ try:
except ImportError:
GDAL_AVAILABLE = False
class Kutser:
def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
def __init__(self, img_path, shp_path=None, oxy_band=38, lower_oxy=36,
upper_oxy=49, NIR_band=47, water_mask=None, output_path=None,
block_size=1000):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
:param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
:param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
:param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
see Kutser, Vahtmäe and Praks
:param water_mask (np.ndarray or str or None): 水域掩膜1表示水域0表示非水域
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
如果为None则处理全图
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
如果为None则不保存
Kutser 耀斑去除算法 - 分块逐波段处理版本
:param img_path (str): 输入影像文件路径GDAL可读取的格式
:param shp_path (str, optional): 深水区域shapefile已废弃请使用water_mask
:param oxy_band (int): 氧吸收波段索引默认38对应760.6nm
:param lower_oxy (int): 氧吸收下方波段索引默认36对应742.39nm
:param upper_oxy (int): 氧吸收上方波段索引默认49对应860.48nm
:param NIR_band (int): NIR波段索引默认47对应842.36nm
:param water_mask (np.ndarray or str or None): 水域掩膜
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
:param block_size (int): 分块大小默认1000
"""
self.im_aligned = im_aligned
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
self.img_path = img_path
self.oxy_band = int(float(oxy_band))
self.lower_oxy = int(float(lower_oxy))
self.upper_oxy = int(float(upper_oxy))
self.NIR_band = int(float(NIR_band))
self.n_bands = im_aligned.shape[-1]
self.height = im_aligned.shape[0]
self.width = im_aligned.shape[1]
self.water_mask = None # 延迟加载,在处理前初始化
self.water_mask_path = water_mask
self.output_path = output_path
# 加载水域掩膜
self.water_mask = self._load_water_mask(water_mask)
# 使用ravel()而不是flatten(),避免不必要的复制
# 如果存在水域掩膜只在掩膜内计算R_min
if self.water_mask is not None:
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
self.R_min = np.percentile(nir_band_masked, 5, method='nearest') if nir_band_masked.size > 0 else 0
else:
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, method='nearest')
def _read_shp_to_bbox(self, shp_path):
"""
读取shapefile并提取边界框
:param shp_path (str): shapefile文件路径
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
"""
if not os.path.exists(shp_path):
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
try:
try:
import geopandas as gpd
gdf = gpd.read_file(shp_path)
# 获取所有几何体的总边界框
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
min_x, min_y, max_x, max_y = bounds
except ImportError:
# 如果geopandas不可用尝试使用fiona
import fiona
from shapely.geometry import shape
min_x = float('inf')
min_y = float('inf')
max_x = float('-inf')
max_y = float('-inf')
with fiona.open(shp_path) as shp:
for feature in shp:
geom = shape(feature['geometry'])
if geom:
bounds = geom.bounds
min_x = min(min_x, bounds[0])
min_y = min(min_y, bounds[1])
max_x = max(max_x, bounds[2])
max_y = max(max_y, bounds[3])
# 转换为整数像素坐标
x1 = max(0, int(min_x))
y1 = max(0, int(min_y))
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
return ((x1, y1), (x2, y2))
except Exception as e:
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
def _load_water_mask(self, water_mask):
"""
加载水域掩膜
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
:return: numpy数组或None1表示水域0表示非水域
"""
if water_mask is None:
return None
# 如果已经是numpy数组
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
# 如果是文件路径
if isinstance(water_mask, str):
try:
from osgeo import gdal, ogr
except ImportError:
raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
# 检查是否为shapefile
if water_mask.lower().endswith('.shp'):
# 从shp文件创建掩膜需要参考图像这里假设使用im_aligned的尺寸
# 注意如果输入是numpy数组无法从shp创建掩膜需要提供栅格参考
raise ValueError("Kutser类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
else:
# 栅格文件
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {water_mask}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def get_depth_D(self):
"""
Assume the amount of glint is proportional to the depth of the oxygen absorption feature, D
returns the normalised D by dividing it by the maximum D found in a deep water region
"""
# 优化:减少中间数组创建,使用更高效的计算
lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
oxy_band = self.im_aligned[:,:,self.oxy_band]
D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
# 确定用于计算D_max的区域
if self.bbox is None:
search_region = D
else:
((x1,y1),(x2,y2)) = self.bbox
search_region = D[y1:y2,x1:x2]
# 如果存在水域掩膜,只在掩膜内搜索最大值
if self.water_mask is not None:
if self.bbox is None:
mask_region = self.water_mask.astype(bool)
else:
((x1,y1),(x2,y2)) = self.bbox
mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
if mask_region.any():
D_max = search_region[mask_region].max()
else:
D_max = search_region.max()
else:
D_max = search_region.max() # assumed to be the maximum glint value
# 避免除零错误
if D_max == 0:
return np.zeros_like(D)
return D / D_max
def get_glint_G(self):
"""
The spectral variation of glint G is found by subtracting the spectrum at the darkest (ie. lowest D) NIR deep-water pixel from the brightest
returns G as a function of wavelength
"""
# If bbox is None, use the entire image
if self.bbox is None:
im_region = self.im_aligned
mask_region = self.water_mask
else:
((x1,y1),(x2,y2)) = self.bbox
im_region = self.im_aligned[y1:y2,x1:x2,:]
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
self.block_size = block_size
self.R_min = None # 全局R_min来自重采样扫描
self.D_max = None # 全局D_max来自重采样扫描
self.G_list = None # 全局G值列表
# 如果存在水域掩膜,只在掩膜内计算最大最小值
if mask_region is not None:
mask_bool = mask_region.astype(bool)
if mask_bool.any():
# 对每个波段,只在掩膜内计算最大最小值
G_list = []
for i in range(self.n_bands):
band_data = im_region[:,:,i]
G_max = band_data[mask_bool].max()
G_min = band_data[mask_bool].min()
G_list.append(G_max - G_min)
# 打开影像获取基本信息
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if self.dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
self.width = self.dataset.RasterXSize
self.height = self.dataset.RasterYSize
self.n_bands = self.dataset.RasterCount
def _load_water_mask(self):
"""延迟加载水域掩膜"""
if self.water_mask_path is None:
return None
if isinstance(self.water_mask_path, np.ndarray):
if self.water_mask_path.shape[:2] != (self.height, self.width):
raise ValueError(
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
)
return (self.water_mask_path > 0).astype(np.uint8)
if isinstance(self.water_mask_path, str):
if self.water_mask_path.lower().endswith('.shp'):
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
)
return (mask_array > 0).astype(np.uint8)
return None
def _scan_global_stats(self, sample_step=20):
"""
通过重采样扫描获取全图全局统计量R_min 和 D_max
分块读取按sample_step跳行/跳列采样,大幅降低内存占用。
内存峰值 ≈ 单波段块大小 + 几个掩膜数组 ≈ block_size² × 4~8MB
"""
print(f"[Kutser] 扫描全局统计量(采样步长={sample_step}...")
water_mask = self._load_water_mask()
# 预分配采样数组NIR波段和D值
nir_samples = []
d_samples = []
sample_count = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# 读取NIR波段用于R_min
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
nir_band = None
# 读取氧吸收相关波段用于D_max
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
lower_block = lower_band.ReadAsArray(0, y_off, self.width, block_height)
lower_band = None
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
upper_block = upper_band.ReadAsArray(0, y_off, self.width, block_height)
upper_band = None
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
oxy_block = oxy_band_obj.ReadAsArray(0, y_off, self.width, block_height)
oxy_band_obj = None
# 计算D = (lower + upper) * 0.5 - oxy
d_block = (lower_block.astype(np.float32) + upper_block.astype(np.float32)) * 0.5 - oxy_block.astype(np.float32)
# 获取掩膜(整块)
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
else:
# 如果掩膜内没有有效像素,使用全区域
G_max = np.amax(im_region, axis=(0, 1))
G_min = np.amin(im_region, axis=(0, 1))
G_list = (G_max - G_min).tolist()
mask_block = np.ones((block_height, self.width), dtype=np.uint8)
# 对掩膜区域进行采样
mask_bool = mask_block.astype(bool)
if mask_bool.any():
# 按步长采样
nir_sampled = nir_block[mask_bool][::sample_step]
d_sampled = d_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
d_samples.append(d_sampled)
sample_count += nir_sampled.size
# 显式释放块内存
del nir_block, lower_block, upper_block, oxy_block, d_block, mask_block
# 汇总
if sample_count == 0:
self.R_min = 0.0
self.D_max = 1.0
else:
# 优化:一次性计算所有波段的最大最小值,减少循环开销
# 使用numpy的amax和amin沿最后一个轴计算
G_max = np.amax(im_region, axis=(0, 1)) # 沿空间维度计算最大值
G_min = np.amin(im_region, axis=(0, 1)) # 沿空间维度计算最小值
G_list = (G_max - G_min).tolist()
return G_list
def _save_corrected_bands(self, corrected_bands):
all_nir = np.concatenate(nir_samples)
all_d = np.concatenate(d_samples)
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
self.D_max = float(all_d.max())
del all_nir, all_d
print(f"[Kutser] 全局 R_min={self.R_min:.4f}, D_max={self.D_max:.4f}")
def _compute_G_list(self):
"""
保存校正后的波段到文件BSQ格式ENVI格式
:param corrected_bands: 校正后的波段列表
计算全局G值列表每个波段
G = G_max - G_min所有水域像素的极值差异
使用全分辨率扫描,但逐波段读取,每波段内存 ≈ block_size²
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if self.output_path is None:
return
# 确保输出目录存在
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 将波段列表转换为数组
corrected_array = np.stack(corrected_bands, axis=2)
# 如果没有地理信息,使用默认值
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# 强制使用ENVI格式BSQ格式确保文件扩展名为.bsq
base_path, ext = os.path.splitext(self.output_path)
# 如果扩展名不是.bsq使用基础路径添加.bsq
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
print(f"[Kutser] 计算全局G值列表n_bands={self.n_bands}...")
water_mask = self._load_water_mask()
# 初始化G_max和G_min为极值
g_max = np.full(self.n_bands, -np.inf, dtype=np.float32)
g_min = np.full(self.n_bands, np.inf, dtype=np.float32)
# 逐块扫描
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# 读取所有波段的当前块
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
band = None
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
if mask_bool.any():
band_masked = block[mask_bool]
g_max[b] = max(g_max[b], band_masked.max())
g_min[b] = min(g_min[b], band_masked.min())
else:
g_max[b] = max(g_max[b], block.max())
g_min[b] = min(g_min[b], block.min())
del block
self.G_list = (g_max - g_min).tolist()
print(f"[Kutser] G值范围: min={min(self.G_list):.4f}, max={max(self.G_list):.4f}")
def _process_block(self, x_off, y_off, x_size, y_size):
"""
处理单个分块:读取数据 -> 计算D -> 逐波段校正 -> 返回块结果
Returns:
list of np.ndarray: 校正后的波段列表(每波段形状为 y_size x x_size
"""
# 读取氧吸收相关波段
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
lower_block = lower_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
lower_band = None
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
upper_block = upper_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
upper_band = None
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
oxy_block = oxy_band_obj.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
oxy_band_obj = None
# 计算D
D = (lower_block + upper_block) * 0.5 - oxy_block
# 避免除零
if self.D_max == 0:
D_normalized = np.zeros_like(D)
else:
bsq_path = self.output_path
# 使用ENVI驱动默认就是BSQ格式
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
height, width, n_bands = corrected_array.shape
# 创建ENVI格式数据集会自动生成.hdr文件
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
try:
# 设置地理变换和投影
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# 写入每个波段BSQ格式按波段顺序存储
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_array[:, :, i])
band.FlushCache()
finally:
dataset = None
# 检查.hdr文件是否已创建
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"头文件已保存至: {hdr_path}")
D_normalized = D / self.D_max
# 释放临时块
del lower_block, upper_block, oxy_block, D
# 获取当前块的水域掩膜
water_mask = self._load_water_mask()
if water_mask is not None:
y_end = y_off + y_size
x_end = x_off + x_size
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")
mask_block = None
# 逐波段处理
corrected_bands = []
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
band = None
G = self.G_list[b]
# 校正公式R_corrected = R - G * D_normalized
corrected = R - G * D_normalized
# 只对水域区域应用校正
if mask_block is not None:
corrected = np.where(mask_block, corrected, R)
corrected_bands.append(corrected)
del R
# 释放D块
del D_normalized
return corrected_bands
def get_corrected_bands(self):
"""
correction is done in reflectance
:return: 校正后的波段列表
"""
g_list = self.get_glint_G()
D = self.get_depth_D()
# 获取水域掩膜(如果存在)
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
执行分块处理,返回校正后的波段列表
corrected_bands = []
for band_number in range(self.n_bands): #iterate across bands
G = g_list[band_number]
R = self.im_aligned[:,:,band_number]
# 优化:减少中间数组创建,直接计算
corrected_band = R - G * D
# 如果存在水域掩膜,只对水域区域应用校正
if water_mask_bool is not None:
corrected_band = np.where(water_mask_bool, corrected_band, R)
corrected_bands.append(corrected_band)
# 如果提供了输出路径,保存结果
if self.output_path is not None:
self._save_corrected_bands(corrected_bands)
return corrected_bands
内存峰值 ≈ 单波段块大小 + 几个辅助数组 ≈ 1000×1000×4B × 3 ≈ 12MB
"""
if self.output_path is None:
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
# Step 1: 扫描全局统计量R_min, D_max
self._scan_global_stats(sample_step=20)
# Step 2: 计算全局G列表
self._compute_G_list()
# Step 3: 创建输出文件
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
base_path, ext = os.path.splitext(self.output_path)
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
# 获取地理信息
geotransform = self.dataset.GetGeoTransform()
projection = self.dataset.GetProjection()
driver = gdal.GetDriverByName('ENVI')
out_dataset = driver.Create(bsq_path, self.width, self.height,
self.n_bands, gdal.GDT_Float32)
if out_dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
# Step 4: 分块处理
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
total_blocks = n_blocks_x * n_blocks_y
print(f"[Kutser] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
block_idx = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
y_size = y_end - y_off
for x_off in range(0, self.width, self.block_size):
x_end = min(x_off + self.block_size, self.width)
x_size = x_end - x_off
block_idx += 1
print(f"[Kutser] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
# 处理当前块
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
# 写入输出文件
for b in range(self.n_bands):
out_band = out_dataset.GetRasterBand(b + 1)
out_band.WriteArray(corrected_bands[b], x_off, y_off)
out_band.FlushCache()
del corrected_bands
out_dataset = None
self.dataset = None
# 检查.hdr文件
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"[Kutser] 校正完成,已保存至: {bsq_path}")
else:
print(f"[Kutser] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件")
# 返回空列表(结果已直接写入文件)
return []
def __del__(self):
if self.dataset is not None:
self.dataset = None