fix: 分块读写改造——修复Hedley协方差形状广播错误和SUGAR列表越界错误

This commit is contained in:
DXC
2026-05-09 11:58:40 +08:00
parent 56de4b6fc4
commit a14d40f28d
2 changed files with 891 additions and 811 deletions

View File

@ -1,5 +1,4 @@
import numpy as np
# import preprocessing
import os
try:
@ -8,283 +7,301 @@ try:
except ImportError:
GDAL_AVAILABLE = False
class Hedley:
def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
output_path=None, block_size=1000):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
:param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which corresponds closely to the NIR band in Micasense
:param water_mask (np.ndarray or str or None): 水域掩膜1表示水域0表示非水域
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
如果为None则处理全图
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
如果为None则不保存
"""
self.im_aligned = im_aligned
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
self.NIR_band = int(float(NIR_band))
self.n_bands = im_aligned.shape[-1]
self.height = im_aligned.shape[0]
self.width = im_aligned.shape[1]
self.output_path = output_path
# 加载水域掩膜
self.water_mask = self._load_water_mask(water_mask)
# 使用ravel()而不是flatten(),避免不必要的复制
# 如果存在水域掩膜只在掩膜内计算R_min
if self.water_mask is not None:
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
self.R_min = np.percentile(nir_band_masked, 5, method='nearest') if nir_band_masked.size > 0 else 0
else:
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, method='nearest')
def _read_shp_to_bbox(self, shp_path):
"""
读取shapefile并提取边界框
:param shp_path (str): shapefile文件路径
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
"""
if not os.path.exists(shp_path):
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
try:
try:
import geopandas as gpd
gdf = gpd.read_file(shp_path)
# 获取所有几何体的总边界框
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
min_x, min_y, max_x, max_y = bounds
except ImportError:
# 如果geopandas不可用尝试使用fiona
import fiona
from shapely.geometry import shape
min_x = float('inf')
min_y = float('inf')
max_x = float('-inf')
max_y = float('-inf')
with fiona.open(shp_path) as shp:
for feature in shp:
geom = shape(feature['geometry'])
if geom:
bounds = geom.bounds
min_x = min(min_x, bounds[0])
min_y = min(min_y, bounds[1])
max_x = max(max_x, bounds[2])
max_y = max(max_y, bounds[3])
# 转换为整数像素坐标
x1 = max(0, int(min_x))
y1 = max(0, int(min_y))
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
return ((x1, y1), (x2, y2))
except Exception as e:
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
def _load_water_mask(self, water_mask):
"""
加载水域掩膜
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
:return: numpy数组或None1表示水域0表示非水域
"""
if water_mask is None:
return None
# 如果已经是numpy数组
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
# 如果是文件路径
if isinstance(water_mask, str):
try:
from osgeo import gdal, ogr
except ImportError:
raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
# 检查是否为shapefile
if water_mask.lower().endswith('.shp'):
# 从shp文件创建掩膜需要参考图像这里假设使用im_aligned的尺寸
# 注意如果输入是numpy数组无法从shp创建掩膜需要提供栅格参考
raise ValueError("Hedley类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
else:
# 栅格文件
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {water_mask}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def covariance_NIR(self,NIR,b):
"""
NIR & b are vectors
reflectance for band i
"""
n = len(NIR)
# 优化减少重复计算使用更高效的numpy操作
nir_mean = np.mean(NIR)
b_mean = np.mean(b)
# 使用更高效的协方差计算
pij = np.mean((NIR - nir_mean) * (b - b_mean))
pjj = np.mean((NIR - nir_mean) ** 2)
# 避免除零错误
return pij / pjj if pjj != 0 else 0.0
def correlation_bands_reflectance(self):
"""
calculate correlation between NIR and other bands for reflectance
NIR_band is 750 nm
"""
# If bbox is None, use the entire image
if self.bbox is None:
# 使用ravel()而不是flatten(),避免不必要的复制
# 直接使用视图,只在需要时创建扁平数组
im_region = self.im_aligned
mask_region = self.water_mask
else:
((x1,y1),(x2,y2)) = self.bbox
im_region = self.im_aligned[y1:y2,x1:x2,:]
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
# 如果存在水域掩膜,只在掩膜内计算相关性
if mask_region is not None:
mask_bool = mask_region.astype(bool)
if mask_bool.any():
# 只在掩膜内提取数据
NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
else:
# 如果掩膜内没有有效像素,使用全区域
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
mask_bool = None
else:
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
mask_bool = None
# 优化:一次性计算所有波段的相关性,减少循环开销
corr_list = []
for v in range(self.n_bands):
if mask_bool is not None and mask_bool.any():
band_reflectance = im_region[:,:,v][mask_bool]
else:
band_reflectance = im_region[:,:,v].ravel()
corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
corr_list.append(corr)
return corr_list
def _save_corrected_bands(self, corrected_bands):
"""
保存校正后的波段到文件BSQ格式ENVI格式
:param corrected_bands: 校正后的波段列表
Hedley 耀斑去除算法 - 分块逐波段处理版本
:param img_path (str): 输入影像文件路径GDAL可读取的格式
:param shp_path (str, optional): 深水区域shapefile已废弃请使用water_mask
:param NIR_band (int): NIR波段索引默认47对应842.36nm
:param water_mask (np.ndarray or str or None): 水域掩膜
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
:param block_size (int): 分块大小默认1000
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if self.output_path is None:
return
# 确保输出目录存在
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 将波段列表转换为数组
corrected_array = np.stack(corrected_bands, axis=2)
# 如果没有地理信息,使用默认值
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# 强制使用ENVI格式BSQ格式确保文件扩展名为.bsq
base_path, ext = os.path.splitext(self.output_path)
# 如果扩展名不是.bsq使用基础路径添加.bsq
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
raise ImportError("GDAL未安装无法读取影像文件")
self.img_path = img_path
self.NIR_band = int(float(NIR_band))
self.water_mask = None
self.water_mask_path = water_mask
self.output_path = output_path
self.block_size = block_size
self.R_min = None
self.corr_list = None # 全局协方差系数列表
# 打开影像
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if self.dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
self.width = self.dataset.RasterXSize
self.height = self.dataset.RasterYSize
self.n_bands = self.dataset.RasterCount
def _load_water_mask(self):
"""延迟加载水域掩膜"""
if self.water_mask_path is None:
return None
if isinstance(self.water_mask_path, np.ndarray):
if self.water_mask_path.shape[:2] != (self.height, self.width):
raise ValueError(
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
)
return (self.water_mask_path > 0).astype(np.uint8)
if isinstance(self.water_mask_path, str):
if self.water_mask_path.lower().endswith('.shp'):
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
)
return (mask_array > 0).astype(np.uint8)
return None
def covariance_NIR(self, NIR, b):
"""计算 NIR 与波段 b 之间的协方差系数 b_i = Cov(NIR,b) / Var(NIR)"""
n = len(NIR)
nir_mean = np.mean(NIR)
b_mean = np.mean(b)
pij = np.mean((NIR - nir_mean) * (b - b_mean))
pjj = np.mean((NIR - nir_mean) ** 2)
return pij / pjj if pjj != 0 else 0.0
def _scan_global_stats(self, sample_step=20):
"""
扫描全图获取全局 R_min
使用重采样方式扫描,大幅降低内存占用。
"""
print(f"[Hedley] 扫描全局统计量(采样步长={sample_step}...")
water_mask = self._load_water_mask()
nir_samples = []
sample_count = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
nir_band = None
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
else:
mask_bool = np.ones((block_height, self.width), dtype=bool)
if mask_bool.any():
nir_sampled = nir_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
sample_count += nir_sampled.size
del nir_block, mask_block
if sample_count == 0:
self.R_min = 0.0
else:
bsq_path = self.output_path
# 使用ENVI驱动默认就是BSQ格式
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
height, width, n_bands = corrected_array.shape
# 创建ENVI格式数据集会自动生成.hdr文件
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
try:
# 设置地理变换和投影
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# 写入每个波段BSQ格式按波段顺序存储
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_array[:, :, i])
band.FlushCache()
finally:
dataset = None
# 检查.hdr文件是否已创建
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"头文件已保存至: {hdr_path}")
all_nir = np.concatenate(nir_samples)
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
del all_nir
print(f"[Hedley] 全局 R_min={self.R_min:.4f}")
def _compute_corr_list(self, sample_step=5):
"""
计算每个波段与NIR的协方差系数 corr_list[b] = Cov(NIR, band_b) / Var(NIR)
全分辨率扫描,逐波段读取,每波段内存 ≈ block_size²
由于需要相关性计算需要足够多的样本取sample_step=5
"""
print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step}...")
water_mask = self._load_water_mask()
# 预收集NIR和每个波段的样本数据
nir_samples = []
band_samples = [[] for _ in range(self.n_bands)]
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# 读取NIR波段每块只读一次
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
nir_band = None
# 取 NIR 样本(每块只取一次,放在波段循环外)
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
else:
mask_bool = np.ones((block_height, self.width), dtype=bool)
if mask_bool.any():
nir_sampled = nir_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
# 逐波段读取并采样all_band 严格使用单波段切片)
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
band = None
if mask_bool.any():
band_sampled = block[mask_bool][::sample_step]
band_samples[b].append(band_sampled)
del block
del nir_block
# 汇总并计算相关系数
if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0:
self.corr_list = [0.0] * self.n_bands
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")
all_nir = np.concatenate(nir_samples)
self.corr_list = []
for b in range(self.n_bands):
all_band = np.concatenate(band_samples[b])
corr = self.covariance_NIR(all_nir, all_band)
self.corr_list.append(float(corr))
del all_nir
for b in range(self.n_bands):
band_samples[b] = None
print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")
def _process_block(self, x_off, y_off, x_size, y_size):
"""
处理单个分块
Returns:
list of np.ndarray: 校正后的波段列表
"""
# 读取NIR波段
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
nir_band = None
# 预计算 NIR - R_min
NIR_diff = NIR - self.R_min
# 获取掩膜
water_mask = self._load_water_mask()
if water_mask is not None:
y_end = y_off + y_size
x_end = x_off + x_size
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
else:
mask_block = None
# 逐波段处理
corrected_bands = []
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
band = None
corr = self.corr_list[b]
# Hedley 校正公式R_corrected = R - corr * (NIR - R_min)
corrected = R - corr * NIR_diff
if mask_block is not None:
corrected = np.where(mask_block, corrected, R)
corrected_bands.append(corrected)
del R
del NIR, NIR_diff
return corrected_bands
def get_corrected_bands(self):
"""
correction is done in reflectance
:return: 校正后的波段列表
执行分块处理,返回校正后的波段列表
"""
corr = self.correlation_bands_reflectance()
NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
# 预计算NIR-R_min避免在循环中重复计算
NIR_diff = NIR_reflectance - self.R_min
# 获取水域掩膜(如果存在)
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
if self.output_path is None:
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
corrected_bands = []
for band_number in range(self.n_bands): #iterate across bands
b = corr[band_number]
R = self.im_aligned[:,:,band_number]
# 优化:减少中间数组创建
corrected_band = R - b * NIR_diff
# 如果存在水域掩膜,只对水域区域应用校正
if water_mask_bool is not None:
corrected_band = np.where(water_mask_bool, corrected_band, R)
corrected_bands.append(corrected_band)
# 如果提供了输出路径,保存结果
if self.output_path is not None:
self._save_corrected_bands(corrected_bands)
return corrected_bands
# Step 1: 扫描全局 R_min
self._scan_global_stats(sample_step=20)
# Step 2: 计算协方差系数列表
self._compute_corr_list(sample_step=5)
# Step 3: 创建输出文件
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
base_path, ext = os.path.splitext(self.output_path)
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
geotransform = self.dataset.GetGeoTransform()
projection = self.dataset.GetProjection()
driver = gdal.GetDriverByName('ENVI')
out_dataset = driver.Create(bsq_path, self.width, self.height,
self.n_bands, gdal.GDT_Float32)
if out_dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
# Step 4: 分块处理
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
total_blocks = n_blocks_x * n_blocks_y
print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
block_idx = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
y_size = y_end - y_off
for x_off in range(0, self.width, self.block_size):
x_end = min(x_off + self.block_size, self.width)
x_size = x_end - x_off
block_idx += 1
print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
for b in range(self.n_bands):
out_band = out_dataset.GetRasterBand(b + 1)
out_band.WriteArray(corrected_bands[b], x_off, y_off)
out_band.FlushCache()
del corrected_bands
out_dataset = None
self.dataset = None
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
else:
print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件")
return []
def __del__(self):
if self.dataset is not None:
self.dataset = None

File diff suppressed because it is too large Load Diff