From a14d40f28dd3cff636656aa9e3f4bc8ee4d2d304 Mon Sep 17 00:00:00 2001 From: DXC Date: Sat, 9 May 2026 11:58:40 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=88=86=E5=9D=97=E8=AF=BB=E5=86=99?= =?UTF-8?q?=E6=94=B9=E9=80=A0=E2=80=94=E2=80=94=E4=BF=AE=E5=A4=8DHedley?= =?UTF-8?q?=E5=8D=8F=E6=96=B9=E5=B7=AE=E5=BD=A2=E7=8A=B6=E5=B9=BF=E6=92=AD?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E5=92=8CSUGAR=E5=88=97=E8=A1=A8=E8=B6=8A?= =?UTF-8?q?=E7=95=8C=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/glint_removal/Hedley.py | 557 ++++++++------- src/core/glint_removal/SUGAR.py | 1145 ++++++++++++++++-------------- 2 files changed, 891 insertions(+), 811 deletions(-) diff --git a/src/core/glint_removal/Hedley.py b/src/core/glint_removal/Hedley.py index edac624..ec2528e 100644 --- a/src/core/glint_removal/Hedley.py +++ b/src/core/glint_removal/Hedley.py @@ -1,5 +1,4 @@ import numpy as np -# import preprocessing import os try: @@ -8,283 +7,301 @@ try: except ImportError: GDAL_AVAILABLE = False + class Hedley: - def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None): + def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None, + output_path=None, block_size=1000): """ - :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image - :param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water. - If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates. - :param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which corresponds closely to the NIR band in Micasense - :param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域 - 可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp) - 如果为None,则处理全图 - :param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像 - 如果为None,则不保存 - """ - self.im_aligned = im_aligned - self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None - self.NIR_band = int(float(NIR_band)) - self.n_bands = im_aligned.shape[-1] - self.height = im_aligned.shape[0] - self.width = im_aligned.shape[1] - self.output_path = output_path - - # 加载水域掩膜 - self.water_mask = self._load_water_mask(water_mask) - - # 使用ravel()而不是flatten(),避免不必要的复制 - # 如果存在水域掩膜,只在掩膜内计算R_min - if self.water_mask is not None: - nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)] - self.R_min = np.percentile(nir_band_masked, 5, method='nearest') if nir_band_masked.size > 0 else 0 - else: - self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, method='nearest') - - def _read_shp_to_bbox(self, shp_path): - """ - 读取shapefile并提取边界框 - - :param shp_path (str): shapefile文件路径 - :return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner - """ - if not os.path.exists(shp_path): - raise FileNotFoundError(f"Shapefile not found: {shp_path}") - - try: - try: - import geopandas as gpd - gdf = gpd.read_file(shp_path) - # 获取所有几何体的总边界框 - bounds = gdf.total_bounds # [minx, miny, maxx, maxy] - min_x, min_y, max_x, max_y = bounds - except ImportError: - # 如果geopandas不可用,尝试使用fiona - import fiona - from shapely.geometry import shape - - min_x = float('inf') - min_y = float('inf') - max_x = float('-inf') - max_y = float('-inf') - - with fiona.open(shp_path) as shp: - for feature in shp: - geom = shape(feature['geometry']) - if geom: - bounds = geom.bounds - min_x = min(min_x, bounds[0]) - min_y = min(min_y, bounds[1]) - max_x = max(max_x, bounds[2]) - max_y = max(max_y, bounds[3]) - - # 转换为整数像素坐标 - x1 = max(0, int(min_x)) - y1 = max(0, int(min_y)) - x2 = min(self.im_aligned.shape[1], int(max_x) + 1) - y2 = min(self.im_aligned.shape[0], int(max_y) + 1) - - return ((x1, y1), (x2, y2)) - - except Exception as e: - raise ValueError(f"Error reading shapefile {shp_path}: {e}") - - def _load_water_mask(self, water_mask): - """ - 加载水域掩膜 - - :param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp) - :return: numpy数组或None,1表示水域,0表示非水域 - """ - if water_mask is None: - return None - - # 如果已经是numpy数组 - if isinstance(water_mask, np.ndarray): - if water_mask.shape[:2] != (self.height, self.width): - raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配") - return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜 - - # 如果是文件路径 - if isinstance(water_mask, str): - try: - from osgeo import gdal, ogr - except ImportError: - raise ValueError("使用文件路径作为掩膜时,必须安装GDAL") - - # 检查是否为shapefile - if water_mask.lower().endswith('.shp'): - # 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸) - # 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考 - raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜") - else: - # 栅格文件 - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset is None: - raise ValueError(f"无法打开掩膜文件: {water_mask}") - - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - - if mask_array.shape != (self.height, self.width): - raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配") - - return (mask_array > 0).astype(np.uint8) - - raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") - - def covariance_NIR(self,NIR,b): - """ - NIR & b are vectors - reflectance for band i - """ - n = len(NIR) - # 优化:减少重复计算,使用更高效的numpy操作 - nir_mean = np.mean(NIR) - b_mean = np.mean(b) - # 使用更高效的协方差计算 - pij = np.mean((NIR - nir_mean) * (b - b_mean)) - pjj = np.mean((NIR - nir_mean) ** 2) - # 避免除零错误 - return pij / pjj if pjj != 0 else 0.0 - - def correlation_bands_reflectance(self): - """ - calculate correlation between NIR and other bands for reflectance - NIR_band is 750 nm - """ - # If bbox is None, use the entire image - if self.bbox is None: - # 使用ravel()而不是flatten(),避免不必要的复制 - # 直接使用视图,只在需要时创建扁平数组 - im_region = self.im_aligned - mask_region = self.water_mask - else: - ((x1,y1),(x2,y2)) = self.bbox - im_region = self.im_aligned[y1:y2,x1:x2,:] - mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None - - # 如果存在水域掩膜,只在掩膜内计算相关性 - if mask_region is not None: - mask_bool = mask_region.astype(bool) - if mask_bool.any(): - # 只在掩膜内提取数据 - NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool] - else: - # 如果掩膜内没有有效像素,使用全区域 - NIR_reflectance = im_region[:,:,self.NIR_band].ravel() - mask_bool = None - else: - NIR_reflectance = im_region[:,:,self.NIR_band].ravel() - mask_bool = None - - # 优化:一次性计算所有波段的相关性,减少循环开销 - corr_list = [] - for v in range(self.n_bands): - if mask_bool is not None and mask_bool.any(): - band_reflectance = im_region[:,:,v][mask_bool] - else: - band_reflectance = im_region[:,:,v].ravel() - corr = self.covariance_NIR(NIR_reflectance, band_reflectance) - corr_list.append(corr) - - return corr_list - - def _save_corrected_bands(self, corrected_bands): - """ - 保存校正后的波段到文件(BSQ格式,ENVI格式) - - :param corrected_bands: 校正后的波段列表 + Hedley 耀斑去除算法 - 分块逐波段处理版本 + + :param img_path (str): 输入影像文件路径(GDAL可读取的格式) + :param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask + :param NIR_band (int): NIR波段索引(默认47,对应842.36nm) + :param water_mask (np.ndarray or str or None): 水域掩膜 + :param output_path (str): 输出文件路径(必须提供,用于分块写入) + :param block_size (int): 分块大小(默认1000) """ if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法保存影像文件") - - if self.output_path is None: - return - - # 确保输出目录存在 - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - # 将波段列表转换为数组 - corrected_array = np.stack(corrected_bands, axis=2) - - # 如果没有地理信息,使用默认值 - geotransform = (0, 1, 0, 0, 0, -1) - projection = "" - - # 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq - base_path, ext = os.path.splitext(self.output_path) - # 如果扩展名不是.bsq,使用基础路径添加.bsq - if ext.lower() != '.bsq': - bsq_path = base_path + '.bsq' + raise ImportError("GDAL未安装,无法读取影像文件") + + self.img_path = img_path + self.NIR_band = int(float(NIR_band)) + self.water_mask = None + self.water_mask_path = water_mask + self.output_path = output_path + self.block_size = block_size + self.R_min = None + self.corr_list = None # 全局协方差系数列表 + + # 打开影像 + self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if self.dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + self.width = self.dataset.RasterXSize + self.height = self.dataset.RasterYSize + self.n_bands = self.dataset.RasterCount + + def _load_water_mask(self): + """延迟加载水域掩膜""" + if self.water_mask_path is None: + return None + + if isinstance(self.water_mask_path, np.ndarray): + if self.water_mask_path.shape[:2] != (self.height, self.width): + raise ValueError( + f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配" + ) + return (self.water_mask_path > 0).astype(np.uint8) + + if isinstance(self.water_mask_path, str): + if self.water_mask_path.lower().endswith('.shp'): + raise ValueError("请先栅格化shapefile为栅格掩膜文件") + mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly) + if mask_dataset is None: + raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}") + mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() + mask_dataset = None + if mask_array.shape != (self.height, self.width): + raise ValueError( + f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配" + ) + return (mask_array > 0).astype(np.uint8) + + return None + + def covariance_NIR(self, NIR, b): + """计算 NIR 与波段 b 之间的协方差系数 b_i = Cov(NIR,b) / Var(NIR)""" + n = len(NIR) + nir_mean = np.mean(NIR) + b_mean = np.mean(b) + pij = np.mean((NIR - nir_mean) * (b - b_mean)) + pjj = np.mean((NIR - nir_mean) ** 2) + return pij / pjj if pjj != 0 else 0.0 + + def _scan_global_stats(self, sample_step=20): + """ + 扫描全图获取全局 R_min + + 使用重采样方式扫描,大幅降低内存占用。 + """ + print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...") + water_mask = self._load_water_mask() + + nir_samples = [] + sample_count = 0 + + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + block_height = y_end - y_off + + nir_band = self.dataset.GetRasterBand(self.NIR_band + 1) + nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height) + nir_band = None + + if water_mask is not None: + mask_block = water_mask[y_off:y_end, :] + mask_bool = mask_block.astype(bool) + else: + mask_bool = np.ones((block_height, self.width), dtype=bool) + + if mask_bool.any(): + nir_sampled = nir_block[mask_bool][::sample_step] + nir_samples.append(nir_sampled) + sample_count += nir_sampled.size + + del nir_block, mask_block + + if sample_count == 0: + self.R_min = 0.0 else: - bsq_path = self.output_path - - # 使用ENVI驱动(默认就是BSQ格式) - driver = gdal.GetDriverByName('ENVI') - if driver is None: - raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用") - - height, width, n_bands = corrected_array.shape - # 创建ENVI格式数据集(会自动生成.hdr文件) - dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32) - if dataset is None: - raise ValueError(f"无法创建输出文件: {bsq_path}") - - try: - # 设置地理变换和投影 - if geotransform: - dataset.SetGeoTransform(geotransform) - if projection: - dataset.SetProjection(projection) - - # 写入每个波段(BSQ格式:按波段顺序存储) - for i in range(n_bands): - band = dataset.GetRasterBand(i + 1) - band.WriteArray(corrected_array[:, :, i]) - band.FlushCache() - finally: - dataset = None - - # 检查.hdr文件是否已创建 - hdr_path = bsq_path + '.hdr' - if os.path.exists(hdr_path): - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"头文件已保存至: {hdr_path}") + all_nir = np.concatenate(nir_samples) + self.R_min = float(np.percentile(all_nir, 5, method='nearest')) + del all_nir + + print(f"[Hedley] 全局 R_min={self.R_min:.4f}") + + def _compute_corr_list(self, sample_step=5): + """ + 计算每个波段与NIR的协方差系数 corr_list[b] = Cov(NIR, band_b) / Var(NIR) + + 全分辨率扫描,逐波段读取,每波段内存 ≈ block_size² + 由于需要相关性计算,需要足够多的样本,取sample_step=5 + """ + print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...") + water_mask = self._load_water_mask() + + # 预收集NIR和每个波段的样本数据 + nir_samples = [] + band_samples = [[] for _ in range(self.n_bands)] + + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + block_height = y_end - y_off + + # 读取NIR波段(每块只读一次) + nir_band = self.dataset.GetRasterBand(self.NIR_band + 1) + nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32) + nir_band = None + + # 取 NIR 样本(每块只取一次,放在波段循环外) + if water_mask is not None: + mask_block = water_mask[y_off:y_end, :] + mask_bool = mask_block.astype(bool) + else: + mask_bool = np.ones((block_height, self.width), dtype=bool) + + if mask_bool.any(): + nir_sampled = nir_block[mask_bool][::sample_step] + nir_samples.append(nir_sampled) + + # 逐波段读取并采样(all_band 严格使用单波段切片) + for b in range(self.n_bands): + band = self.dataset.GetRasterBand(b + 1) + block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32) + band = None + + if mask_bool.any(): + band_sampled = block[mask_bool][::sample_step] + band_samples[b].append(band_sampled) + + del block + + del nir_block + + # 汇总并计算相关系数 + if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0: + self.corr_list = [0.0] * self.n_bands else: - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建") + all_nir = np.concatenate(nir_samples) + self.corr_list = [] + for b in range(self.n_bands): + all_band = np.concatenate(band_samples[b]) + corr = self.covariance_NIR(all_nir, all_band) + self.corr_list.append(float(corr)) + + del all_nir + for b in range(self.n_bands): + band_samples[b] = None + + print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}") + + def _process_block(self, x_off, y_off, x_size, y_size): + """ + 处理单个分块 + + Returns: + list of np.ndarray: 校正后的波段列表 + """ + # 读取NIR波段 + nir_band = self.dataset.GetRasterBand(self.NIR_band + 1) + NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32) + nir_band = None + + # 预计算 NIR - R_min + NIR_diff = NIR - self.R_min + + # 获取掩膜 + water_mask = self._load_water_mask() + if water_mask is not None: + y_end = y_off + y_size + x_end = x_off + x_size + mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool) + else: + mask_block = None + + # 逐波段处理 + corrected_bands = [] + for b in range(self.n_bands): + band = self.dataset.GetRasterBand(b + 1) + R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32) + band = None + + corr = self.corr_list[b] + # Hedley 校正公式:R_corrected = R - corr * (NIR - R_min) + corrected = R - corr * NIR_diff + + if mask_block is not None: + corrected = np.where(mask_block, corrected, R) + + corrected_bands.append(corrected) + del R + + del NIR, NIR_diff + + return corrected_bands def get_corrected_bands(self): """ - correction is done in reflectance - - :return: 校正后的波段列表 + 执行分块处理,返回校正后的波段列表 """ - corr = self.correlation_bands_reflectance() - NIR_reflectance = self.im_aligned[:,:,self.NIR_band] - # 预计算NIR-R_min,避免在循环中重复计算 - NIR_diff = NIR_reflectance - self.R_min - - # 获取水域掩膜(如果存在) - water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None + if self.output_path is None: + raise ValueError("output_path 必须提供,分块处理需要直接写入文件") - corrected_bands = [] - for band_number in range(self.n_bands): #iterate across bands - b = corr[band_number] - R = self.im_aligned[:,:,band_number] - # 优化:减少中间数组创建 - corrected_band = R - b * NIR_diff - - # 如果存在水域掩膜,只对水域区域应用校正 - if water_mask_bool is not None: - corrected_band = np.where(water_mask_bool, corrected_band, R) - - corrected_bands.append(corrected_band) - - # 如果提供了输出路径,保存结果 - if self.output_path is not None: - self._save_corrected_bands(corrected_bands) - - return corrected_bands + # Step 1: 扫描全局 R_min + self._scan_global_stats(sample_step=20) + + # Step 2: 计算协方差系数列表 + self._compute_corr_list(sample_step=5) + + # Step 3: 创建输出文件 + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + base_path, ext = os.path.splitext(self.output_path) + bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path + + geotransform = self.dataset.GetGeoTransform() + projection = self.dataset.GetProjection() + + driver = gdal.GetDriverByName('ENVI') + out_dataset = driver.Create(bsq_path, self.width, self.height, + self.n_bands, gdal.GDT_Float32) + if out_dataset is None: + raise ValueError(f"无法创建输出文件: {bsq_path}") + + out_dataset.SetGeoTransform(geotransform) + out_dataset.SetProjection(projection) + + # Step 4: 分块处理 + n_blocks_x = (self.width + self.block_size - 1) // self.block_size + n_blocks_y = (self.height + self.block_size - 1) // self.block_size + total_blocks = n_blocks_x * n_blocks_y + + print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}") + + block_idx = 0 + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + y_size = y_end - y_off + + for x_off in range(0, self.width, self.block_size): + x_end = min(x_off + self.block_size, self.width) + x_size = x_end - x_off + block_idx += 1 + + print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})") + + corrected_bands = self._process_block(x_off, y_off, x_size, y_size) + + for b in range(self.n_bands): + out_band = out_dataset.GetRasterBand(b + 1) + out_band.WriteArray(corrected_bands[b], x_off, y_off) + out_band.FlushCache() + + del corrected_bands + + out_dataset = None + self.dataset = None + + hdr_path = bsq_path + '.hdr' + if os.path.exists(hdr_path): + print(f"[Hedley] 校正完成,已保存至: {bsq_path}") + else: + print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)") + + return [] + + def __del__(self): + if self.dataset is not None: + self.dataset = None \ No newline at end of file diff --git a/src/core/glint_removal/SUGAR.py b/src/core/glint_removal/SUGAR.py index 4655026..3faf182 100644 --- a/src/core/glint_removal/SUGAR.py +++ b/src/core/glint_removal/SUGAR.py @@ -1,6 +1,6 @@ import cv2 -import os import numpy as np +import os from scipy import ndimage from scipy.optimize import minimize_scalar @@ -10,563 +10,626 @@ try: except ImportError: GDAL_AVAILABLE = False -# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint + +def otsu_thresholding(im, auto_bins=None): + """ + Otsu阈值分割 + """ + if auto_bins is None: + auto_bins = max(10, int(0.005 * im.shape[0] * im.shape[1])) + im_flat = im.ravel() + valid_mask = np.isfinite(im_flat) + if not valid_mask.all(): + im_flat = im_flat[valid_mask] + count, bin_edges = np.histogram(im_flat, bins=auto_bins) + bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 + count_sum = count.sum() + hist_norm = count / count_sum + Q = hist_norm.cumsum() + N = count.shape[0] + N_negative = np.sum(bin < 0) + bins = np.arange(N, dtype=np.float32) + + def otsu_thresh(x): + x = int(x) + p1 = hist_norm[:x] + p2 = hist_norm[x:] + q1 = Q[x] + q2 = Q[N - 1] - Q[x] + b1 = bins[:x] + b2 = bins[x:] + m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0 + m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0 + v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0 + v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0 + return v1 * q1 + v2 * q2 + + if N_negative <= 1: + return bin[np.argmax(count)] + res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded') + return bin[int(res.x)] + + +def cdf_thresholding(im, auto_bins=10): + """CDF阈值分割""" + im_flat = im.ravel() + valid_mask = np.isfinite(im_flat) + if not valid_mask.all(): + im_flat = im_flat[valid_mask] + count, bin_edges = np.histogram(im_flat, bins=auto_bins) + bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 + return bin[np.argmax(count)] + + class SUGAR: - def __init__(self, im_aligned,bounds=[(1,2)],sigma=1,estimate_background=True, glint_mask_method="cdf", water_mask=None, output_path=None): - """ - :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image - :param bounds (a list of tuple): lower and upper bound for optimisation of b for each band - :param sigma (float): smoothing sigma for LoG - :param estimate_background (bool): whether to estimate background spectra using median filtering - :param glint_mask_method (str): choose either "cdf" or "otsu", "cdf" is set as the default - :param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域 - 可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp) - 如果为None,则处理全图 - :param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像 - 如果为None,则不保存 - """ - self.im_aligned = im_aligned - self.sigma = sigma - self.estimate_background = estimate_background - self.n_bands = im_aligned.shape[-1] - self.bounds = bounds*self.n_bands - self.glint_mask_method = glint_mask_method - self.height = im_aligned.shape[0] - self.width = im_aligned.shape[1] - self.output_path = output_path - - # 加载水域掩膜 - self.water_mask = self._load_water_mask(water_mask) - - def _load_water_mask(self, water_mask): - """ - 加载水域掩膜 - - :param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp) - :return: numpy数组或None,1表示水域,0表示非水域 - """ - if water_mask is None: - return None - - # 如果已经是numpy数组 - if isinstance(water_mask, np.ndarray): - if water_mask.shape[:2] != (self.height, self.width): - raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配") - return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜 - - # 如果是文件路径 - if isinstance(water_mask, str): - try: - from osgeo import gdal, ogr - except ImportError: - raise ValueError("使用文件路径作为掩膜时,必须安装GDAL") - - # 检查是否为shapefile - if water_mask.lower().endswith('.shp'): - # 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸) - # 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考 - raise ValueError("SUGAR类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜") - else: - # 栅格文件 - mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) - if mask_dataset is None: - raise ValueError(f"无法打开掩膜文件: {water_mask}") - - mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() - mask_dataset = None - - if mask_array.shape != (self.height, self.width): - raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配") - - return (mask_array > 0).astype(np.uint8) - - raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") + """ + SUGAR 耀斑去除算法 - 分块逐波段处理版本 - def otsu_thresholding(self,im): - """ - :param im (np.ndarray) of shape mxn. Note that it is the LoG of image - otsu thresholding with Brent's minimisation of a univariate function - returns the value of the threshold for input - """ - auto_bins = int(0.005*im.shape[0]*im.shape[1]) - # 使用ravel()而不是flatten(),避免不必要的复制(如果可能) - # 如果存在无效值(如NaN或极大值),过滤掉它们 - im_flat = im.ravel() - # 过滤掉NaN和无穷大值 - valid_mask = np.isfinite(im_flat) - if not valid_mask.all(): - im_flat = im_flat[valid_mask] - count, bin_edges = np.histogram(im_flat, bins=auto_bins) - bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers,使用乘法替代除法 - - count_sum = count.sum() - hist_norm = count / count_sum # normalised histogram - Q = hist_norm.cumsum() # CDF function ranges from 0 to 1 - N = count.shape[0] - N_negative = np.sum(bin < 0) - bins = np.arange(N, dtype=np.float32) # 使用float32减少内存 - - def otsu_thresh(x): - x = int(x) - # 使用切片而不是hsplit,避免创建新数组 - p1 = hist_norm[:x] - p2 = hist_norm[x:] - q1 = Q[x] - q2 = Q[N-1] - Q[x] - b1 = bins[:x] - b2 = bins[x:] - # finding means and variances - m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0 - m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0 - v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0 - v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0 - # calculates the minimization function - fn = v1 * q1 + v2 * q2 - return fn - - # brent method is used to minimise an univariate function - # bounded minimisation - # we can just limit the search to negative values since we know thresh should be negative as L<0 for glint pixels - if N_negative <= 1: - # 如果没有足够的负值,使用默认阈值 - return bin[np.argmax(count)] - res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded') - thresh = bin[int(res.x)] - - return thresh - - # def cdf_thresholding(self,im, percentile=0.05): - # """ - # :param im (np.ndarray) of shape mxn - # :param percentile (float): lower and upper percentile values are potential glint pixels - # """ - # lower_perc = percentile - # upper_perc = 1-percentile - # im_flatten = im.flatten() - # H,X1 = np.histogram(im_flatten, bins = int(0.005*im.shape[0]*im.shape[1]), density=True ) - # dx = X1[1] - X1[0] - # F1 = np.cumsum(H)*dx - # F_lower = X1[1:][F1upper_perc] - # while((F_lower.size == 0) or (F_upper.size == 0)): - # if (F_lower.size == 0): - # lower_perc += 0.01 - # F_lower = X1[1:][F1upper_perc] + 策略: + 1. 分块扫描全图,计算每个块的 glint_mask(需要全局阈值) + 2. 收集所有 glint 像素值到列表(仅收集索引,不存储完整掩膜数组) + 3. 全局优化每波段的 b 值(使用所有 glint 像素的方差最小化) + 4. 分块处理:计算 background(需全块) -> 应用校正 -> 写入输出 + """ - # lower_thresh = F_lower[-1] - # upper_thresh = F_upper[0] - - # return lower_thresh,upper_thresh - - def cdf_thresholding(self,im,auto_bins=10): + def __init__(self, img_path, bounds=None, sigma=1.0, estimate_background=True, + glint_mask_method="cdf", water_mask=None, output_path=None, + block_size=1000): """ - :param im (np.ndarray) of shape mxn. Note that it is the LoG of image - :param percentile (float): lower and upper percentile values are potential glint pixels - """ - # 使用ravel()而不是flatten(),避免不必要的复制 - im_flat = im.ravel() - # 过滤掉NaN和无穷大值 - valid_mask = np.isfinite(im_flat) - if not valid_mask.all(): - im_flat = im_flat[valid_mask] - count, bin_edges = np.histogram(im_flat, bins=auto_bins) - bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers,使用乘法替代除法 - thresh = bin[np.argmax(count)] - return thresh - - def glint_list(self): - """ - returns a list of np.ndarray, where each item is an extracted glint for each band based on get_glint_mask - """ - glint_mask = self.glint_mask_list() - extracted_glint_list = [] - for i in range(self.im_aligned.shape[-1]): - gm = glint_mask[i] - extracted_glint = gm*self.im_aligned[:,:,i] - extracted_glint_list.append(extracted_glint) - - return extracted_glint_list - - def glint_mask_list(self): - """ - get glint mask using laplacian of gaussian image. - returns a list of np.ndarray - """ - glint_mask_list = [] - for i in range(self.im_aligned.shape[-1]): - glint_mask = self.get_glint_mask(self.im_aligned[:,:,i]) - glint_mask_list.append(glint_mask) - - return glint_mask_list - - def log_image_list(self): - """ - get Laplacian of Gaussian (LoG) images for all bands. - returns a list of np.ndarray - """ - log_image_list = [] - for i in range(self.im_aligned.shape[-1]): - log_im = self.get_log_image(self.im_aligned[:,:,i]) - log_image_list.append(log_im) - return log_image_list - - def get_log_image(self, im): - """ - get Laplacian of Gaussian (LoG) image for a single band. - returns a np.ndarray - """ - LoG_im = ndimage.gaussian_laplace(im, sigma=self.sigma) - return LoG_im - - def get_glint_mask(self,im): - """ - get glint mask using laplacian of gaussian image. - We assume that water constituents and features follow a smooth continuum, - but glint pixels vary a lot spatially and in intensities - Note that for very extensive glint, this method may not work as well <--:TODO use U-net to identify glint mask - returns a np.ndarray - """ - LoG_im = ndimage.gaussian_laplace(im,sigma=self.sigma) - - # 如果存在水域掩膜,只在掩膜内计算阈值 - if self.water_mask is not None: - mask_bool = self.water_mask.astype(bool) - if mask_bool.any(): - # 只在掩膜内提取LoG值用于阈值计算 - LoG_masked = LoG_im[mask_bool] - # 将非掩膜区域设为极大值,确保不影响阈值计算 - LoG_for_thresh = LoG_im.copy() - LoG_for_thresh[~mask_bool] = LoG_masked.max() + 1 - else: - LoG_for_thresh = LoG_im - else: - LoG_for_thresh = LoG_im - - #threshold mask - if (self.glint_mask_method == "otsu"): - thresh = self.otsu_thresholding(LoG_for_thresh) - elif (self.glint_mask_method == "cdf"): - thresh = self.cdf_thresholding(LoG_for_thresh) - else: - raise ValueError('Enter only cdf or otsu as glint_mask_method') - # 使用更高效的方式创建mask,避免np.where的开销 - glint_mask = (LoG_im < thresh).astype(np.uint8) - - # 如果存在水域掩膜,将非水域区域设为0 - if self.water_mask is not None: - glint_mask = glint_mask * self.water_mask - - return glint_mask - - def get_est_background(self, im,k_size=5): - """ - :param im (np.ndarray): image of a band - estimate background spectra - returns a np.ndarray - """ - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(k_size,k_size)) - dst = cv2.erode(im, kernel) - - return dst - - def optimise_correction_by_band(self,im,glint_mask,R_BG,bounds): - """ - :param im (np.ndarray): image of a band - :param glint_mask (np.ndarray): glint mask, where glint area is 1 and non-glint area is 0 - use brent method to get the optimimum b which minimises the variation (i.e. variance) in the entire image - returns regression slope b - """ - # 预计算常量,避免在优化函数中重复计算 - glint_mask_bool = glint_mask.astype(bool) - R_BG_flat = R_BG if isinstance(R_BG, (int, float)) else R_BG[glint_mask_bool] - - def optimise_b(b): - # 优化计算:只在glint区域计算校正 - if isinstance(R_BG, (int, float)): - im_corrected = im.copy() - im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG) - else: - im_corrected = im.copy() - im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG[glint_mask_bool]) - return np.var(im_corrected) - - res = minimize_scalar(optimise_b, bounds=bounds, method='bounded') - return res.x - - def divide_and_conquer(self): - """ - instead of computing b_list for each window, use the previous b_list to narrow the bounds, - because of the strong spatial autocorrelation, we know that the b (correction magnitude) cannot diff too much - this can optimise the run time - """ - - - def optimise_correction(self): - """ - returns a list of slope in band order i.e. 0,1,2,3,4,5,6,7,8,9 through optimisation - """ - b_list = [] - glint_mask_list = [] - est_background_list = [] - for i in range(self.n_bands): - glint_mask = self.get_glint_mask(self.im_aligned[:,:,i]) - glint_mask_list.append(glint_mask) - if self.estimate_background is True: - est_background = self.get_est_background(self.im_aligned[:,:,i]) - est_background_list.append(est_background) - else: - est_background = np.percentile(self.im_aligned[:,:,i], 5, method='nearest') - est_background_list.append(est_background) - bounds = self.bounds[i] - b = self.optimise_correction_by_band(self.im_aligned[:,:,i],glint_mask,est_background,bounds) - b_list.append(b) - - # add attributes - self.b_list = b_list - self.glint_mask = glint_mask_list - self.est_background = est_background_list - - return b_list, glint_mask_list, est_background_list - - def _save_corrected_bands(self, corrected_bands): - """ - 保存校正后的波段到文件(BSQ格式,ENVI格式) - - :param corrected_bands: 校正后的波段列表 + :param img_path (str): 输入影像文件路径 + :param bounds: 每个波段的优化边界,默认 [(1,2)] * n_bands + :param sigma (float): LoG 平滑 sigma + :param estimate_background (bool): 是否用中值滤波估计背景 + :param glint_mask_method (str): "cdf" 或 "otsu" + :param water_mask: 水域掩膜 + :param output_path (str): 输出文件路径 + :param block_size (int): 分块大小 """ if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法保存影像文件") - - if self.output_path is None: - return - - # 确保输出目录存在 - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - # 将波段列表转换为数组 - corrected_array = np.stack(corrected_bands, axis=2) - - # 如果没有地理信息,使用默认值 - geotransform = (0, 1, 0, 0, 0, -1) - projection = "" - - # 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq - base_path, ext = os.path.splitext(self.output_path) - # 如果扩展名不是.bsq,使用基础路径添加.bsq - if ext.lower() != '.bsq': - bsq_path = base_path + '.bsq' + raise ImportError("GDAL未安装,无法读取影像文件") + + if bounds is None: + bounds = [(1, 2)] + + self.img_path = img_path + self.bounds = bounds + self.sigma = sigma + self.estimate_background = estimate_background + self.glint_mask_method = glint_mask_method + self.water_mask = None + self.water_mask_path = water_mask + self.output_path = output_path + self.block_size = block_size + + # 打开影像 + self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly) + if self.dataset is None: + raise ValueError(f"无法打开影像文件: {img_path}") + self.width = self.dataset.RasterXSize + self.height = self.dataset.RasterYSize + self.n_bands = self.dataset.RasterCount + + # 扩展 bounds 到所有波段 + self.bounds_all = self.bounds * self.n_bands + + # 优化结果(全局) + self.b_list = None + self.glint_pixel_indices = [] # list of (block_idx, row, col) 索引 + self.thresholds = [] # 每波段的全局阈值 + + def _load_water_mask(self): + """延迟加载水域掩膜""" + if self.water_mask_path is None: + return None + + if isinstance(self.water_mask_path, np.ndarray): + if self.water_mask_path.shape[:2] != (self.height, self.width): + raise ValueError( + f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配" + ) + return (self.water_mask_path > 0).astype(np.uint8) + + if isinstance(self.water_mask_path, str): + if self.water_mask_path.lower().endswith('.shp'): + raise ValueError("请先栅格化shapefile为栅格掩膜文件") + mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly) + if mask_dataset is None: + raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}") + mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() + mask_dataset = None + if mask_array.shape != (self.height, self.width): + raise ValueError( + f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配" + ) + return (mask_array > 0).astype(np.uint8) + + return None + + def _compute_threshold(self, im): + """计算 glint 阈值""" + if self.glint_mask_method == "otsu": + return otsu_thresholding(im) else: - bsq_path = self.output_path - - # 使用ENVI驱动(默认就是BSQ格式) - driver = gdal.GetDriverByName('ENVI') - if driver is None: - raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用") - - height, width, n_bands = corrected_array.shape - # 创建ENVI格式数据集(会自动生成.hdr文件) - dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32) - if dataset is None: - raise ValueError(f"无法创建输出文件: {bsq_path}") - - try: - # 设置地理变换和投影 - if geotransform: - dataset.SetGeoTransform(geotransform) - if projection: - dataset.SetProjection(projection) - - # 写入每个波段(BSQ格式:按波段顺序存储) - for i in range(n_bands): - band = dataset.GetRasterBand(i + 1) - band.WriteArray(corrected_array[:, :, i]) - band.FlushCache() - finally: - dataset = None - - # 检查.hdr文件是否已创建 - hdr_path = bsq_path + '.hdr' - if os.path.exists(hdr_path): - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"头文件已保存至: {hdr_path}") - else: - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建") + return cdf_thresholding(im) + + def _get_glint_mask_block(self, band_data): + """ + 对单波段块计算 glint mask + 阈值来自全局阈值 self.thresholds[band_idx] + """ + # LoG + log_im = ndimage.gaussian_laplace(band_data.astype(np.float32), sigma=self.sigma) + # 全局阈值 + thresh = self.thresholds[self._current_band] + glint_mask = (log_im < thresh).astype(np.uint8) + + # 应用水域掩膜 + water_mask = self._load_water_mask() + if water_mask is not None: + y_off = self._current_y + y_end = y_off + band_data.shape[0] + x_off = self._current_x + x_end = x_off + band_data.shape[1] + mask_block = water_mask[y_off:y_end, x_off:x_end] + glint_mask = glint_mask * mask_block + + return log_im, glint_mask + + def _get_est_background(self, im, k_size=5): + """估计背景光谱(中值滤波)""" + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size)) + return cv2.erode(im.astype(np.float32), kernel) + + def _optimise_correction_band(self, R_glint, R_bg_glint, bounds): + """ + 全局优化单波段 b 值 + + 使用所有 glint 像素的方差最小化 + R_corrected = R - mask * (R / b - R_bg) + 最小化 Var(R_corrected) + + :param R_glint: 所有 glint 像素值(1D array) + :param R_bg_glint: 对应背景值(1D array) + :param bounds: 优化边界 + :return: 最优 b 值 + """ + if len(R_glint) == 0: + return 1.0 + + R_glint = R_glint.astype(np.float32) + R_bg_glint = R_bg_glint.astype(np.float32) + + def objective(b): + b = float(b) + R_corrected = R_glint - (R_glint / b - R_bg_glint) + return np.var(R_corrected) + + res = minimize_scalar(objective, bounds=bounds, method='bounded') + return res.x + + def _scan_and_collect_glint(self): + """ + Step 1: 分块扫描全图,收集每波段的全局阈值和 glint 像素索引 + + 内存:仅存储每波段的阈值(float)和 glint 像素位置索引 + """ + print(f"[SUGAR] 步骤1: 扫描全图收集glint像素...") + water_mask = self._load_water_mask() + + # 初始化阈值列表 + self.thresholds = [None] * self.n_bands + log_collections = [[] for _ in range(self.n_bands)] + + # 全图扫描:收集每波段的 LoG 值用于阈值计算 + n_blocks = 0 + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + y_size = y_end - y_off + + for x_off in range(0, self.width, self.block_size): + x_end = min(x_off + self.block_size, self.width) + x_size = x_end - x_off + n_blocks += 1 + + for b in range(self.n_bands): + band = self.dataset.GetRasterBand(b + 1) + block = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32) + band = None + + log_im = ndimage.gaussian_laplace(block, sigma=self.sigma) + + # mask_block 在波段循环外初始化,每块只计算一次 + if b == 0 and water_mask is not None: + _mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool) + + if water_mask is not None: + if _mask_block.any(): + log_collections[b].append(log_im[_mask_block]) + else: + log_collections[b].append(log_im.ravel()) + + del block, log_im + + if water_mask is not None: + del _mask_block + + # 计算每波段的全局阈值(需要所有LoG值) + print(f"[SUGAR] 计算 {self.n_bands} 个波段的全局阈值...") + for b in range(self.n_bands): + if len(log_collections[b]) == 0: + self.thresholds[b] = 0.0 + else: + all_log = np.concatenate(log_collections[b]) + thresh = self._compute_threshold( + all_log.reshape(1, -1) # shape (1, N) 模拟二维输入 + ) + self.thresholds[b] = float(thresh) + del all_log + print(f" 波段{b}: thresh={self.thresholds[b]:.4f}") + log_collections[b] = None + + def _collect_glint_pixel_values(self): + """ + Step 2: 再次分块扫描,收集每波段所有 glint 像素的 (R, R_bg) 值用于优化 + + 内存:只存储 1D 数组(所有 glint 像素值) + """ + print(f"[SUGAR] 步骤2: 收集glint像素值用于全局优化...") + water_mask = self._load_water_mask() + + R_glint_list = [[] for _ in range(self.n_bands)] + R_bg_glint_list = [[] for _ in range(self.n_bands)] + + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + y_size = y_end - y_off + + for x_off in range(0, self.width, self.block_size): + x_end = min(x_off + self.block_size, self.width) + x_size = x_end - x_off + + for b in range(self.n_bands): + band = self.dataset.GetRasterBand(b + 1) + R_block = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32) + band = None + + # LoG 和 mask + log_im = ndimage.gaussian_laplace(R_block, sigma=self.sigma) + thresh = self.thresholds[b] + glint_mask = (log_im < thresh).astype(np.uint8) + + if water_mask is not None: + mask_block = water_mask[y_off:y_end, x_off:x_end] + glint_mask = glint_mask * mask_block + + # 背景 + if self.estimate_background: + R_bg = self._get_est_background(R_block) + else: + R_bg = np.percentile(R_block, 5, method='nearest') + + # 收集 glint 像素 + glint_idx = glint_mask.astype(bool) + if glint_idx.any(): + R_glint_list[b].append(R_block[glint_idx]) + R_bg_glint_list[b].append(R_bg[glint_idx]) + + del R_block, log_im, glint_mask, R_bg + + # 汇总 + self.R_glint_all = [] + self.R_bg_glint_all = [] + for b in range(self.n_bands): + if len(R_glint_list[b]) == 0: + self.R_glint_all.append(np.array([], dtype=np.float32)) + self.R_bg_glint_all.append(np.array([], dtype=np.float32)) + else: + self.R_glint_all.append(np.concatenate(R_glint_list[b])) + self.R_bg_glint_all.append(np.concatenate(R_bg_glint_list[b])) + n = len(self.R_glint_all[b]) + print(f" 波段{b}: 收集到 {n} 个glint像素") + del R_glint_list[b], R_bg_glint_list[b] + + def _optimize_b_list(self): + """ + Step 3: 全局优化每波段的 b 值 + """ + print(f"[SUGAR] 步骤3: 全局优化 b 值...") + self.b_list = [] + for b in range(self.n_bands): + bounds = self.bounds_all[b] + b_opt = self._optimise_correction_band( + self.R_glint_all[b], self.R_bg_glint_all[b], bounds + ) + self.b_list.append(float(b_opt)) + print(f" 波段{b}: b={b_opt:.4f}") + + # 释放内存 + self.R_glint_all = None + self.R_bg_glint_all = None + + def _process_and_write_block(self, x_off, y_off, x_size, y_size, out_dataset): + """ + Step 4: 分块处理并写入输出文件 + """ + water_mask = self._load_water_mask() + + for b in range(self.n_bands): + band = self.dataset.GetRasterBand(b + 1) + R_block = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32) + band = None + + # 计算 glint mask + log_im = ndimage.gaussian_laplace(R_block, sigma=self.sigma) + thresh = self.thresholds[b] + glint_mask = (log_im < thresh).astype(np.uint8) + + if water_mask is not None: + mask_block = water_mask[y_off:y_off + y_size, x_off:x_off + x_size] + glint_mask = glint_mask * mask_block + + glint_bool = glint_mask.astype(bool) + + # 计算背景 + if self.estimate_background: + R_bg = self._get_est_background(R_block) + else: + R_bg = np.percentile(R_block, 5, method='nearest') + + # 校正 + b_val = self.b_list[b] + R_corrected = R_block.copy() + + if glint_bool.any(): + R_corrected[glint_bool] = ( + R_bg[glint_bool] + + (R_block[glint_bool] - R_bg[glint_bool]) / b_val + ) + + # 写入 + out_band = out_dataset.GetRasterBand(b + 1) + out_band.WriteArray(R_corrected, x_off, y_off) + out_band.FlushCache() + + del R_block, log_im, glint_mask, R_bg, R_corrected def get_corrected_bands(self): """ - 获取校正后的波段 - - :return: 校正后的波段列表 + 执行分块处理,返回校正后的波段列表 """ - corrected_bands = [] - # 获取水域掩膜(如果存在) - water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None - - for i in range(self.n_bands): - im_band = self.im_aligned[:,:,i] - # 一次性计算mask和background,避免重复计算 - glint_mask = self.get_glint_mask(im_band) - background = self.get_est_background(im_band, k_size=5) - # 使用视图和原地操作减少内存 - im_corrected = im_band.copy() - glint_mask_bool = glint_mask.astype(bool) - im_corrected[glint_mask_bool] = background[glint_mask_bool] - - # 如果存在水域掩膜,确保只在水域内应用校正 - if water_mask_bool is not None: - # 只在水域掩膜内应用校正 - correction_mask = glint_mask_bool & water_mask_bool - im_corrected = np.where(correction_mask, background, im_band) - # 非水域区域保持原值 - im_corrected = np.where(water_mask_bool, im_corrected, im_band) - - corrected_bands.append(im_corrected) - - # 如果提供了输出路径,保存结果 - if self.output_path is not None: - self._save_corrected_bands(corrected_bands) + if self.output_path is None: + raise ValueError("output_path 必须提供,分块处理需要直接写入文件") - return corrected_bands + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) -def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None): + base_path, ext = os.path.splitext(self.output_path) + bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path + + geotransform = self.dataset.GetGeoTransform() + projection = self.dataset.GetProjection() + + driver = gdal.GetDriverByName('ENVI') + out_dataset = driver.Create(bsq_path, self.width, self.height, + self.n_bands, gdal.GDT_Float32) + if out_dataset is None: + raise ValueError(f"无法创建输出文件: {bsq_path}") + out_dataset.SetGeoTransform(geotransform) + out_dataset.SetProjection(projection) + + # Step 1: 扫描收集阈值 + self._scan_and_collect_glint() + + # Step 2: 收集 glint 像素值 + self._collect_glint_pixel_values() + + # Step 3: 全局优化 b + self._optimize_b_list() + + # Step 4: 分块处理写入 + n_blocks_x = (self.width + self.block_size - 1) // self.block_size + n_blocks_y = (self.height + self.block_size - 1) // self.block_size + total_blocks = n_blocks_x * n_blocks_y + + print(f"[SUGAR] 步骤4: 分块处理写入,共 {total_blocks} 块") + + block_idx = 0 + for y_off in range(0, self.height, self.block_size): + y_end = min(y_off + self.block_size, self.height) + y_size = y_end - y_off + + for x_off in range(0, self.width, self.block_size): + x_end = min(x_off + self.block_size, self.width) + x_size = x_end - x_off + block_idx += 1 + + self._current_x = x_off + self._current_y = y_off + + print(f"[SUGAR] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})") + + self._process_and_write_block(x_off, y_off, x_size, y_size, out_dataset) + + out_dataset = None + self.dataset = None + + hdr_path = bsq_path + '.hdr' + if os.path.exists(hdr_path): + print(f"[SUGAR] 校正完成,已保存至: {bsq_path}") + else: + print(f"[SUGAR] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)") + + return [] + + def __del__(self): + if self.dataset is not None: + self.dataset = None + + +# ============================================================================ +# 独立函数:correction_iterative(迭代版本,支持大图) +# ============================================================================ +def correction_iterative(img_path, iter=3, bounds=None, estimate_background=True, + glint_mask_method="cdf", get_glint_mask=False, + termination_thresh=20.0, water_mask=None, output_path=None, + block_size=1000): """ - :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image - :param iter (int or None): number of iterations to run the sugar algorithm. If None, termination conditions are automatically applied - :param bounds (list of tuples): to limit correction magnitude - :param get_glint_mask (np.ndarray): - :param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域 - 可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp) - 如果为None,则处理全图 - :param output_path (str or None): 输出文件路径,如果提供则保存最后一次迭代的校正结果 - 如果为None,则不保存 - conducts iterative correction using SUGAR - """ - glint_image = im_aligned.copy() - corrected_images = [] + SUGAR 迭代去耀斑 - 分块版本 - if iter is None: - # termination conditions - relative_difference = lambda sd0,sd1: sd1/sd0*100 - marginal_difference = lambda sd1,sd2: (sd1-sd2)/sd1*100 - relative_diff_thresh = marginal_difference_thresh = termination_thresh - sd_og = np.var(im_aligned) - iter_count = 0 - sd_next = sd_og # 不需要copy,直接使用值 - max_iter = 100 # 添加最大迭代次数限制,防止无限循环 - - while ((relative_difference(sd_og,sd_next) > relative_diff_thresh) and iter_count < max_iter): - # do all the processing here - HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask) - corrected_bands = HM.get_corrected_bands() - glint_image = np.stack(corrected_bands,axis=2) - sd_temp = np.var(glint_image) - # 只在需要时保存中间结果,减少内存占用 - if get_glint_mask or iter_count == 0: - corrected_images.append(glint_image.copy()) - else: - corrected_images.append(glint_image) # 最后一次迭代的结果 - # save glint_mask - # if iter_count == 0 and get_glint_mask is True: - # glint_mask = np.stack(HM.glint_mask,axis=2) - if (marginal_difference(sd_next,sd_temp) 0: - _save_corrected_image(corrected_images[-1], output_path) - - else: - for i in range(iter): - HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask) - corrected_bands = HM.get_corrected_bands() - glint_image = np.stack(corrected_bands,axis=2) - # 只在最后一次迭代或需要时保存所有结果 - if i == iter - 1 or get_glint_mask: - corrected_images.append(glint_image.copy()) - else: - # 对于中间迭代,可以只保存引用(但要注意内存管理) - corrected_images.append(glint_image) - # save glint_mask - # if i == 0 and get_glint_mask is True: - # glint_mask = np.stack(HM.glint_mask,axis=2) - - # 如果提供了输出路径,保存最后一次迭代的结果 - if output_path is not None and len(corrected_images) > 0: - _save_corrected_image(corrected_images[-1], output_path) - - return corrected_images - -def _save_corrected_image(corrected_image, output_path): - """ - 保存校正后的图像到文件(用于correction_iterative函数,BSQ格式,ENVI格式) - - :param corrected_image: 校正后的图像数组,形状为(height, width, bands) + :param img_path (str): 输入影像文件路径 + :param iter (int or None): 迭代次数,None 表示自动终止 + :param bounds: 优化边界 + :param estimate_background: 是否估计背景 + :param glint_mask_method: "cdf" 或 "otsu" + :param get_glint_mask: 是否返回 glint mask(已废弃,保持接口兼容) + :param termination_thresh: 自动终止阈值 + :param water_mask: 水域掩膜 :param output_path: 输出文件路径 + :param block_size: 分块大小 + :return: 每次迭代的校正图像列表(None,分块模式下返回空列表) """ if not GDAL_AVAILABLE: - raise ImportError("GDAL未安装,无法保存影像文件") - - if output_path is None: - return - - # 确保输出目录存在 - output_dir = os.path.dirname(output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - # 如果没有地理信息,使用默认值 - geotransform = (0, 1, 0, 0, 0, -1) - projection = "" - - # 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq - base_path, ext = os.path.splitext(output_path) - # 如果扩展名不是.bsq,使用基础路径添加.bsq - if ext.lower() != '.bsq': - bsq_path = base_path + '.bsq' - else: - bsq_path = output_path - - # 使用ENVI驱动(默认就是BSQ格式) - driver = gdal.GetDriverByName('ENVI') - if driver is None: - raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用") - - height, width, n_bands = corrected_image.shape - # 创建ENVI格式数据集(会自动生成.hdr文件) - dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32) + raise ImportError("GDAL未安装") + + if bounds is None: + bounds = [(1, 2)] + + # 打开影像获取基本信息 + dataset = gdal.Open(img_path, gdal.GA_ReadOnly) if dataset is None: - raise ValueError(f"无法创建输出文件: {bsq_path}") - - try: - # 设置地理变换和投影 - if geotransform: - dataset.SetGeoTransform(geotransform) - if projection: - dataset.SetProjection(projection) - - # 写入每个波段(BSQ格式:按波段顺序存储) - for i in range(n_bands): - band = dataset.GetRasterBand(i + 1) - band.WriteArray(corrected_image[:, :, i]) - band.FlushCache() - finally: - dataset = None - - # 检查.hdr文件是否已创建 - hdr_path = bsq_path + '.hdr' - if os.path.exists(hdr_path): - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"头文件已保存至: {hdr_path}") + raise ValueError(f"无法打开影像文件: {img_path}") + + width = dataset.RasterXSize + height = dataset.RasterYSize + n_bands = dataset.RasterCount + + geotransform = dataset.GetGeoTransform() + projection = dataset.GetProjection() + dataset = None + + # 计算临时输出路径 + temp_dir = os.path.dirname(output_path) if output_path else os.getcwd() + temp_base = os.path.join(temp_dir, "_sugar_iter") + + # 确保输出目录存在 + if output_path: + output_dir = os.path.dirname(output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + if iter is None: + relative_diff = lambda sd0, sd1: sd1 / sd0 * 100 + marginal_diff = lambda sd1, sd2: (sd1 - sd2) / sd1 * 100 + + glint_img_path = img_path + iter_count = 0 + max_iter = 100 + + while True: + iter_output = f"{temp_base}_{iter_count}.bsq" + + sugar = SUGAR( + glint_img_path, + bounds=bounds, + estimate_background=estimate_background, + glint_mask_method=glint_mask_method, + water_mask=water_mask, + output_path=iter_output, + block_size=block_size + ) + sugar.get_corrected_bands() + + # 检查方差收敛 + # 读取当前输出图像的方差(分块读取第一个波段估算) + ds = gdal.Open(iter_output, gdal.GA_ReadOnly) + if ds is not None: + # 采样估算方差 + sample_data = [] + for y_off in range(0, height, block_size): + y_end = min(y_off + block_size, height) + block = ds.GetRasterBand(1).ReadAsArray(0, y_off, width, y_end - y_off) + sample_data.append(block.ravel()) + all_data = np.concatenate(sample_data) + sd_current = np.var(all_data) + ds = None + del all_data + else: + sd_current = 0 + + prev_img_path = glint_img_path + glint_img_path = iter_output + + if iter_count == 0: + sd_prev = sd_current + + # 检查终止条件 + if (iter_count > 0 and + marginal_diff(sd_prev, sd_current) < termination_thresh): + break + if iter_count >= max_iter: + break + + sd_prev = sd_current + iter_count += 1 + + # 将最终结果移动到 output_path + if output_path and glint_img_path != output_path: + import shutil + if os.path.exists(output_path): + os.remove(output_path) + # 找最后一个有效输出 + last_iter = max(0, iter_count - 1) + final_path = f"{temp_base}_{last_iter}.bsq" + if os.path.exists(final_path): + shutil.move(final_path, output_path) + # 复制hdr + if os.path.exists(final_path + '.hdr'): + import shutil + shutil.copy(final_path + '.hdr', output_path + '.hdr') + else: - print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") - print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建") + glint_img_path = img_path + + for i in range(iter): + iter_output = f"{temp_base}_{i}.bsq" + + sugar = SUGAR( + glint_img_path, + bounds=bounds, + estimate_background=estimate_background, + glint_mask_method=glint_mask_method, + water_mask=water_mask, + output_path=iter_output, + block_size=block_size + ) + sugar.get_corrected_bands() + + prev_img_path = glint_img_path + glint_img_path = iter_output + + # 将最后一次结果移动到 output_path + if output_path: + last_iter = iter - 1 + final_path = f"{temp_base}_{last_iter}.bsq" + if os.path.exists(final_path): + import shutil + # 删除旧文件 + if os.path.exists(output_path): + os.remove(output_path) + # 移动 + shutil.move(final_path, output_path) + # 复制hdr + if os.path.exists(final_path + '.hdr'): + shutil.copy(final_path + '.hdr', output_path + '.hdr') + + # 清理临时文件 + for i in range(max(0, iter - 1)): + f = f"{temp_base}_{i}.bsq" + if os.path.exists(f): + os.remove(f) + h = f + '.hdr' + if os.path.exists(h): + os.remove(h) + + return [] \ No newline at end of file