fix: 分块读写改造——修复Hedley协方差形状广播错误和SUGAR列表越界错误
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
@ -8,283 +7,301 @@ try:
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
class Hedley:
|
||||
def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
|
||||
def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
|
||||
output_path=None, block_size=1000):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which corresponds closely to the NIR band in Micasense
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, method='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, method='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def covariance_NIR(self,NIR,b):
|
||||
"""
|
||||
NIR & b are vectors
|
||||
reflectance for band i
|
||||
"""
|
||||
n = len(NIR)
|
||||
# 优化:减少重复计算,使用更高效的numpy操作
|
||||
nir_mean = np.mean(NIR)
|
||||
b_mean = np.mean(b)
|
||||
# 使用更高效的协方差计算
|
||||
pij = np.mean((NIR - nir_mean) * (b - b_mean))
|
||||
pjj = np.mean((NIR - nir_mean) ** 2)
|
||||
# 避免除零错误
|
||||
return pij / pjj if pjj != 0 else 0.0
|
||||
|
||||
def correlation_bands_reflectance(self):
|
||||
"""
|
||||
calculate correlation between NIR and other bands for reflectance
|
||||
NIR_band is 750 nm
|
||||
"""
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 直接使用视图,只在需要时创建扁平数组
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算相关性
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 只在掩膜内提取数据
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
|
||||
else:
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
else:
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
|
||||
# 优化:一次性计算所有波段的相关性,减少循环开销
|
||||
corr_list = []
|
||||
for v in range(self.n_bands):
|
||||
if mask_bool is not None and mask_bool.any():
|
||||
band_reflectance = im_region[:,:,v][mask_bool]
|
||||
else:
|
||||
band_reflectance = im_region[:,:,v].ravel()
|
||||
corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
|
||||
corr_list.append(corr)
|
||||
|
||||
return corr_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
Hedley 耀斑去除算法 - 分块逐波段处理版本
|
||||
|
||||
:param img_path (str): 输入影像文件路径(GDAL可读取的格式)
|
||||
:param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask
|
||||
:param NIR_band (int): NIR波段索引(默认47,对应842.36nm)
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜
|
||||
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
|
||||
:param block_size (int): 分块大小(默认1000)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
self.img_path = img_path
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.water_mask = None
|
||||
self.water_mask_path = water_mask
|
||||
self.output_path = output_path
|
||||
self.block_size = block_size
|
||||
self.R_min = None
|
||||
self.corr_list = None # 全局协方差系数列表
|
||||
|
||||
# 打开影像
|
||||
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if self.dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.n_bands = self.dataset.RasterCount
|
||||
|
||||
def _load_water_mask(self):
|
||||
"""延迟加载水域掩膜"""
|
||||
if self.water_mask_path is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.water_mask_path, np.ndarray):
|
||||
if self.water_mask_path.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (self.water_mask_path > 0).astype(np.uint8)
|
||||
|
||||
if isinstance(self.water_mask_path, str):
|
||||
if self.water_mask_path.lower().endswith('.shp'):
|
||||
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
|
||||
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
return None
|
||||
|
||||
def covariance_NIR(self, NIR, b):
|
||||
"""计算 NIR 与波段 b 之间的协方差系数 b_i = Cov(NIR,b) / Var(NIR)"""
|
||||
n = len(NIR)
|
||||
nir_mean = np.mean(NIR)
|
||||
b_mean = np.mean(b)
|
||||
pij = np.mean((NIR - nir_mean) * (b - b_mean))
|
||||
pjj = np.mean((NIR - nir_mean) ** 2)
|
||||
return pij / pjj if pjj != 0 else 0.0
|
||||
|
||||
def _scan_global_stats(self, sample_step=20):
|
||||
"""
|
||||
扫描全图获取全局 R_min
|
||||
|
||||
使用重采样方式扫描,大幅降低内存占用。
|
||||
"""
|
||||
print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
nir_samples = []
|
||||
sample_count = 0
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
nir_band = None
|
||||
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
sample_count += nir_sampled.size
|
||||
|
||||
del nir_block, mask_block
|
||||
|
||||
if sample_count == 0:
|
||||
self.R_min = 0.0
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
|
||||
del all_nir
|
||||
|
||||
print(f"[Hedley] 全局 R_min={self.R_min:.4f}")
|
||||
|
||||
def _compute_corr_list(self, sample_step=5):
|
||||
"""
|
||||
计算每个波段与NIR的协方差系数 corr_list[b] = Cov(NIR, band_b) / Var(NIR)
|
||||
|
||||
全分辨率扫描,逐波段读取,每波段内存 ≈ block_size²
|
||||
由于需要相关性计算,需要足够多的样本,取sample_step=5
|
||||
"""
|
||||
print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 预收集NIR和每个波段的样本数据
|
||||
nir_samples = []
|
||||
band_samples = [[] for _ in range(self.n_bands)]
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取NIR波段(每块只读一次)
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 取 NIR 样本(每块只取一次,放在波段循环外)
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
|
||||
# 逐波段读取并采样(all_band 严格使用单波段切片)
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
band = None
|
||||
|
||||
if mask_bool.any():
|
||||
band_sampled = block[mask_bool][::sample_step]
|
||||
band_samples[b].append(band_sampled)
|
||||
|
||||
del block
|
||||
|
||||
del nir_block
|
||||
|
||||
# 汇总并计算相关系数
|
||||
if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0:
|
||||
self.corr_list = [0.0] * self.n_bands
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.corr_list = []
|
||||
for b in range(self.n_bands):
|
||||
all_band = np.concatenate(band_samples[b])
|
||||
corr = self.covariance_NIR(all_nir, all_band)
|
||||
self.corr_list.append(float(corr))
|
||||
|
||||
del all_nir
|
||||
for b in range(self.n_bands):
|
||||
band_samples[b] = None
|
||||
|
||||
print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")
|
||||
|
||||
def _process_block(self, x_off, y_off, x_size, y_size):
|
||||
"""
|
||||
处理单个分块
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: 校正后的波段列表
|
||||
"""
|
||||
# 读取NIR波段
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 预计算 NIR - R_min
|
||||
NIR_diff = NIR - self.R_min
|
||||
|
||||
# 获取掩膜
|
||||
water_mask = self._load_water_mask()
|
||||
if water_mask is not None:
|
||||
y_end = y_off + y_size
|
||||
x_end = x_off + x_size
|
||||
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
|
||||
else:
|
||||
mask_block = None
|
||||
|
||||
# 逐波段处理
|
||||
corrected_bands = []
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
band = None
|
||||
|
||||
corr = self.corr_list[b]
|
||||
# Hedley 校正公式:R_corrected = R - corr * (NIR - R_min)
|
||||
corrected = R - corr * NIR_diff
|
||||
|
||||
if mask_block is not None:
|
||||
corrected = np.where(mask_block, corrected, R)
|
||||
|
||||
corrected_bands.append(corrected)
|
||||
del R
|
||||
|
||||
del NIR, NIR_diff
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
correction is done in reflectance
|
||||
|
||||
:return: 校正后的波段列表
|
||||
执行分块处理,返回校正后的波段列表
|
||||
"""
|
||||
corr = self.correlation_bands_reflectance()
|
||||
NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
|
||||
# 预计算NIR-R_min,避免在循环中重复计算
|
||||
NIR_diff = NIR_reflectance - self.R_min
|
||||
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
if self.output_path is None:
|
||||
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
|
||||
|
||||
corrected_bands = []
|
||||
for band_number in range(self.n_bands): #iterate across bands
|
||||
b = corr[band_number]
|
||||
R = self.im_aligned[:,:,band_number]
|
||||
# 优化:减少中间数组创建
|
||||
corrected_band = R - b * NIR_diff
|
||||
|
||||
# 如果存在水域掩膜,只对水域区域应用校正
|
||||
if water_mask_bool is not None:
|
||||
corrected_band = np.where(water_mask_bool, corrected_band, R)
|
||||
|
||||
corrected_bands.append(corrected_band)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
# Step 1: 扫描全局 R_min
|
||||
self._scan_global_stats(sample_step=20)
|
||||
|
||||
# Step 2: 计算协方差系数列表
|
||||
self._compute_corr_list(sample_step=5)
|
||||
|
||||
# Step 3: 创建输出文件
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
|
||||
|
||||
geotransform = self.dataset.GetGeoTransform()
|
||||
projection = self.dataset.GetProjection()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
out_dataset = driver.Create(bsq_path, self.width, self.height,
|
||||
self.n_bands, gdal.GDT_Float32)
|
||||
if out_dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
# Step 4: 分块处理
|
||||
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
|
||||
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
|
||||
total_blocks = n_blocks_x * n_blocks_y
|
||||
|
||||
print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
|
||||
|
||||
block_idx = 0
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
y_size = y_end - y_off
|
||||
|
||||
for x_off in range(0, self.width, self.block_size):
|
||||
x_end = min(x_off + self.block_size, self.width)
|
||||
x_size = x_end - x_off
|
||||
block_idx += 1
|
||||
|
||||
print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
|
||||
|
||||
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
|
||||
|
||||
for b in range(self.n_bands):
|
||||
out_band = out_dataset.GetRasterBand(b + 1)
|
||||
out_band.WriteArray(corrected_bands[b], x_off, y_off)
|
||||
out_band.FlushCache()
|
||||
|
||||
del corrected_bands
|
||||
|
||||
out_dataset = None
|
||||
self.dataset = None
|
||||
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
|
||||
else:
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
|
||||
|
||||
return []
|
||||
|
||||
def __del__(self):
|
||||
if self.dataset is not None:
|
||||
self.dataset = None
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user