import os
import shutil

import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar

try:
    import cv2
except ImportError:
    # Optional, like GDAL below: OpenCV is only needed by the erosion-based
    # background estimate (estimate_background=True).
    cv2 = None

try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False


def otsu_thresholding(im, auto_bins=None):
    """
    Otsu threshold searched over the negative part of the histogram of ``im``.

    A histogram of the finite values of ``im`` is built, and Otsu's weighted
    within-class variance is minimised (bounded Brent search) over split
    indices in ``(1, N_negative)``. ``N_negative`` is the number of bins whose
    centre is negative.

    :param im: array of values, e.g. a Laplacian-of-Gaussian response
    :param auto_bins (int or None): number of histogram bins. By default this is
        ``max(10, int(0.005 * n_pixels))``, where ``n_pixels`` is
        ``shape[0] * shape[1]`` for inputs with 2 or more dimensions and
        ``im.size`` for 1-D input.
    :return: the centre of the bin at the optimal split. When at most one bin
        centre is negative, the centre of the most populated bin is returned
        instead.
    """
    im = np.asarray(im)
    if auto_bins is None:
        n_pixels = im.shape[0] * im.shape[1] if im.ndim >= 2 else im.size
        auto_bins = max(10, int(0.005 * n_pixels))

    im_flat = im.ravel()
    valid_mask = np.isfinite(im_flat)
    if not valid_mask.all():
        im_flat = im_flat[valid_mask]

    count, bin_edges = np.histogram(im_flat, bins=auto_bins)
    centres = (bin_edges[:-1] + bin_edges[1:]) * 0.5
    N = count.shape[0]
    N_negative = int(np.sum(centres < 0))
    if N_negative <= 1:
        return centres[np.argmax(count)]

    hist_norm = count / count.sum()
    Q = hist_norm.cumsum()
    bins = np.arange(N, dtype=np.float32)

    def clamp_split(x):
        # Keep both classes non-empty whatever point the optimiser probes.
        return min(max(int(x), 1), N - 1)

    def within_class_variance(x):
        split = clamp_split(x)
        p1, p2 = hist_norm[:split], hist_norm[split:]
        # The class weights must be the mass of exactly these bins: p1 covers
        # bins [0, split), whose cumulative mass is Q[split - 1] (not Q[split]).
        q1 = Q[split - 1]
        q2 = Q[N - 1] - q1
        b1, b2 = bins[:split], bins[split:]
        m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
        m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
        v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
        v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
        return v1 * q1 + v2 * q2

    res = minimize_scalar(within_class_variance, bounds=(1, N_negative), method='bounded')
    return centres[clamp_split(res.x)]


def cdf_thresholding(im, auto_bins=10):
    """
    Threshold at the mode of the histogram of ``im``.

    :param im: array of values, e.g. a Laplacian-of-Gaussian response
    :param auto_bins (int): number of histogram bins
    :return: the centre of the most populated bin among the finite values of ``im``
    """
    im_flat = np.asarray(im).ravel()
    im_flat = im_flat[np.isfinite(im_flat)]
    count, bin_edges = np.histogram(im_flat, bins=auto_bins)
    centres = (bin_edges[:-1] + bin_edges[1:]) * 0.5
    return centres[np.argmax(count)]


def _iter_blocks(width, height, block_size):
    """Yield ``(x_off, y_off, x_size, y_size)`` tiles covering the raster row by row.

    Edge tiles are clipped to the raster extent.
    """
    for y_off in range(0, height, block_size):
        y_size = min(block_size, height - y_off)
        for x_off in range(0, width, block_size):
            x_size = min(block_size, width - x_off)
            yield x_off, y_off, x_size, y_size


def _envi_header_paths(path):
    """Return the candidate ENVI header names for raster ``path``.

    GDAL accepts either ``<path>.hdr`` (creation option SUFFIX=ADD) or
    ``<path without extension>.hdr`` (SUFFIX=REPLACE, the driver default).
    """
    candidates = [path + '.hdr']
    replaced = os.path.splitext(path)[0] + '.hdr'
    if replaced not in candidates:
        candidates.append(replaced)
    return candidates


def _remove_envi(path):
    """Delete an ENVI raster and its header file(s); missing files are ignored."""
    for candidate in [path] + _envi_header_paths(path):
        if os.path.exists(candidate):
            os.remove(candidate)


def _move_envi(src, dst):
    """Move ENVI raster ``src`` and its header to ``dst``, replacing any existing output.

    The header is stored as ``<dst>.hdr``. Any older header for ``dst`` under
    either naming convention is removed first; GDAL could otherwise pick up a
    stale ``<dst stem>.hdr`` when opening ``dst``.
    """
    _remove_envi(dst)
    shutil.move(src, dst)
    headers = [h for h in _envi_header_paths(src) if os.path.exists(h)]
    if headers:
        shutil.move(headers[0], dst + '.hdr')
        for extra in headers[1:]:
            os.remove(extra)


def _streaming_variance(blocks):
    """Population variance (ddof=0) of all values in an iterable of arrays.

    Per-block statistics are merged with Chan et al.'s pairwise update in
    float64, so only one block is held in memory at a time. Returns 0.0 when
    there are no values.
    """
    count = 0
    mean = 0.0
    m2 = 0.0
    for block in blocks:
        values = np.asarray(block, dtype=np.float64).ravel()
        n = values.size
        if n == 0:
            continue
        block_mean = float(values.mean())
        block_m2 = float(np.sum((values - block_mean) ** 2))
        delta = block_mean - mean
        total = count + n
        mean += delta * n / total
        m2 += block_m2 + delta * delta * count * n / total
        count = total
    return m2 / count if count else 0.0


def _band_variance(path, band_index=1, block_rows=1000):
    """Variance of band ``band_index`` of the raster at ``path``.

    The band is read in full-width strips of ``block_rows`` rows. Returns 0.0
    when the file cannot be opened.
    """
    ds = gdal.Open(path, gdal.GA_ReadOnly)
    if ds is None:
        return 0.0
    try:
        band = ds.GetRasterBand(band_index)
        width, height = ds.RasterXSize, ds.RasterYSize
        strips = (
            band.ReadAsArray(0, y_off, width, min(block_rows, height - y_off))
            for y_off in range(0, height, block_rows)
        )
        return _streaming_variance(strips)
    finally:
        ds = None


class SUGAR:
    """
    SUGAR sun-glint removal, processed block by block and band by band.

    Pipeline (driven by ``get_corrected_bands``):
      1. Scan all blocks, compute the Laplacian of Gaussian (LoG) of every band,
         and derive one global glint threshold per band.
      2. Scan again and collect, for every glint pixel (LoG < threshold, inside
         the water mask if one is given), its value R and its background R_bg.
      3. For each band, choose ``b`` within its bounds to minimise the variance
         of ``R - (R / b - R_bg)`` over the collected glint pixels.
      4. Scan a third time, apply that same correction to glint pixels, and
         write an ENVI (BSQ, Float32) output block by block.
    """

    def __init__(self, img_path, bounds=None, sigma=1.0, estimate_background=True,
                 glint_mask_method="cdf", water_mask=None, output_path=None,
                 block_size=1000):
        """
        :param img_path (str): input raster path (opened with GDAL)
        :param bounds: per-band optimisation bounds for ``b``. A single-entry list
            is repeated for every band. Default: ``[(1, 2)]``.
        :param sigma (float): sigma of the Laplacian-of-Gaussian filter
        :param estimate_background (bool): if True, estimate the background by
            grey-scale erosion. Otherwise use the per-block 5th percentile.
        :param glint_mask_method (str): "cdf" or "otsu"
        :param water_mask: None, a 2-D ndarray, or a path to a raster mask
            (> 0 means water)
        :param output_path (str): output file path
        :param block_size (int): tile edge length in pixels
        :raises ImportError: if GDAL is not installed
        :raises ValueError: if the input raster cannot be opened
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法读取影像文件")
        if bounds is None:
            bounds = [(1, 2)]

        self.img_path = img_path
        self.bounds = bounds
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.glint_mask_method = glint_mask_method
        self.water_mask = None  # cache filled lazily by _load_water_mask
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = block_size

        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")

        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount

        # bounds_all[b] is the optimisation interval for band b.
        self.bounds_all = self.bounds * self.n_bands

        self.b_list = None  # per-band optimal b, filled by _optimize_b_list
        self.glint_pixel_indices = []  # kept for compatibility; not populated
        self.thresholds = []  # per-band global LoG threshold

    def _load_water_mask(self):
        """
        Return the water mask as a uint8 0/1 array, or None if no mask was given.

        The mask is built once and cached in ``self.water_mask``; later calls
        return the cached array.

        :raises ValueError: if the mask size differs from the image, the path is
            a shapefile, or the mask raster cannot be opened
        """
        if self.water_mask is not None:
            return self.water_mask

        source = self.water_mask_path
        if source is None:
            return None

        if isinstance(source, np.ndarray):
            if source.shape[:2] != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {source.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (source > 0).astype(np.uint8)
        elif isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {source}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None
            if mask_array.shape != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (mask_array > 0).astype(np.uint8)
        else:
            return None

        self.water_mask = mask
        return mask

    def _compute_threshold(self, im):
        """Glint threshold for LoG values ``im`` using the configured method."""
        if self.glint_mask_method == "otsu":
            return otsu_thresholding(im)
        return cdf_thresholding(im)

    def _read_block(self, band_idx, x_off, y_off, x_size, y_size):
        """Read a window of input band ``band_idx`` (0-based) as float32."""
        band = self.dataset.GetRasterBand(band_idx + 1)
        return band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)

    def _get_glint_mask_block(self, band_data, band_idx=None, x_off=None, y_off=None):
        """
        Compute the LoG and glint mask of one band block.

        A pixel is glint when its LoG is below ``self.thresholds[band_idx]``. It
        is also restricted to the water mask (when one is set), using the mask
        window at ``(y_off, x_off)``.

        :param band_data: 2-D block of one band
        :param band_idx: 0-based band index. Defaults to ``self._current_band``
            if that attribute has been set.
        :param x_off, y_off: block offset in the full image. Default to
            ``self._current_x`` / ``self._current_y``, or 0 if those are unset.
        :return: ``(log_im, glint_mask)`` with ``glint_mask`` as uint8 0/1
        :raises ValueError: if no band index is available
        """
        if band_idx is None:
            band_idx = getattr(self, "_current_band", None)
            if band_idx is None:
                raise ValueError("必须指定 band_idx")
        if x_off is None:
            x_off = getattr(self, "_current_x", 0)
        if y_off is None:
            y_off = getattr(self, "_current_y", 0)

        log_im = ndimage.gaussian_laplace(np.asarray(band_data, dtype=np.float32), sigma=self.sigma)
        glint_mask = (log_im < self.thresholds[band_idx]).astype(np.uint8)

        water_mask = self._load_water_mask()
        if water_mask is not None:
            rows, cols = glint_mask.shape
            glint_mask = glint_mask * water_mask[y_off:y_off + rows, x_off:x_off + cols]
        return log_im, glint_mask

    def _get_est_background(self, im, k_size=5):
        """
        Estimate the background by grey-scale erosion (a local minimum filter).

        Uses an elliptical ``k_size`` x ``k_size`` structuring element.

        :raises ImportError: if OpenCV is not installed
        """
        if cv2 is None:
            raise ImportError("OpenCV (cv2) 未安装,无法估计背景")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        return cv2.erode(im.astype(np.float32), kernel)

    def _get_background(self, R_block):
        """
        Return a background array with the same shape as ``R_block``.

        With ``estimate_background`` this is the erosion of the block. Otherwise
        the block's 5th percentile ('nearest' method) is broadcast to the block
        shape, so callers can index it with the glint mask in either case.
        """
        if self.estimate_background:
            return self._get_est_background(R_block)
        level = np.percentile(R_block, 5, method='nearest')
        return np.full(R_block.shape, level, dtype=np.float32)

    @staticmethod
    def _correct_values(R, R_bg, b):
        """Apply the SUGAR correction ``R - (R / b - R_bg)`` element-wise."""
        return R - (R / b - R_bg)

    def _optimise_correction_band(self, R_glint, R_bg_glint, bounds):
        """
        Find the ``b`` that minimises the variance of the corrected glint pixels.

        The corrected value is ``R - (R / b - R_bg)``; this is the same formula
        that step 4 applies.

        :param R_glint: 1-D array of glint pixel values
        :param R_bg_glint: 1-D array of the matching background values
        :param bounds: ``(low, high)`` search interval for ``b``
        :return: the optimal ``b``, or 1.0 when there are no glint pixels
        """
        if len(R_glint) == 0:
            return 1.0
        R_glint = np.asarray(R_glint, dtype=np.float32)
        R_bg_glint = np.asarray(R_bg_glint, dtype=np.float32)

        def objective(b):
            return np.var(self._correct_values(R_glint, R_bg_glint, float(b)))

        res = minimize_scalar(objective, bounds=bounds, method='bounded')
        return res.x

    def _scan_and_collect_glint(self):
        """
        Step 1: compute one global LoG threshold per band.

        Every block is read and its LoG computed. The LoG values are gathered per
        band, restricted to water pixels when a water mask is set. Blocks with no
        water pixels are skipped. Each band's threshold is then
        ``_compute_threshold`` over all gathered values, or 0.0 if none were
        gathered.

        Note: the gathered LoG values of all bands are held in memory together.
        """
        print(f"[SUGAR] 步骤1: 扫描全图收集glint像素...")
        water_mask = self._load_water_mask()
        self.thresholds = [None] * self.n_bands
        log_collections = [[] for _ in range(self.n_bands)]

        for x_off, y_off, x_size, y_size in _iter_blocks(self.width, self.height, self.block_size):
            mask_block = None
            if water_mask is not None:
                mask_block = water_mask[y_off:y_off + y_size, x_off:x_off + x_size].astype(bool)
                if not mask_block.any():
                    continue  # no water here: nothing to collect for any band
            for b in range(self.n_bands):
                block = self._read_block(b, x_off, y_off, x_size, y_size)
                log_im = ndimage.gaussian_laplace(block, sigma=self.sigma)
                log_collections[b].append(log_im.ravel() if mask_block is None else log_im[mask_block])
                del block, log_im

        print(f"[SUGAR] 计算 {self.n_bands} 个波段的全局阈值...")
        for b in range(self.n_bands):
            if len(log_collections[b]) == 0:
                self.thresholds[b] = 0.0
            else:
                all_log = np.concatenate(log_collections[b])
                # Shape (1, N) so the 2-D default bin rule of otsu_thresholding applies.
                self.thresholds[b] = float(self._compute_threshold(all_log.reshape(1, -1)))
                del all_log
            print(f" 波段{b}: thresh={self.thresholds[b]:.4f}")
            log_collections[b] = None

    def _collect_glint_pixel_values(self):
        """
        Step 2: collect (R, R_bg) for every glint pixel of every band.

        Fills ``self.R_glint_all`` and ``self.R_bg_glint_all``: lists with one
        1-D array per band (empty float32 arrays when a band has no glint).
        """
        print(f"[SUGAR] 步骤2: 收集glint像素值用于全局优化...")
        R_glint_list = [[] for _ in range(self.n_bands)]
        R_bg_glint_list = [[] for _ in range(self.n_bands)]

        for x_off, y_off, x_size, y_size in _iter_blocks(self.width, self.height, self.block_size):
            for b in range(self.n_bands):
                R_block = self._read_block(b, x_off, y_off, x_size, y_size)
                _, glint_mask = self._get_glint_mask_block(R_block, band_idx=b, x_off=x_off, y_off=y_off)
                glint_idx = glint_mask.astype(bool)
                if glint_idx.any():
                    R_bg = self._get_background(R_block)
                    R_glint_list[b].append(R_block[glint_idx])
                    R_bg_glint_list[b].append(R_bg[glint_idx])
                    del R_bg
                del R_block, glint_mask, glint_idx

        self.R_glint_all = []
        self.R_bg_glint_all = []
        for b in range(self.n_bands):
            if len(R_glint_list[b]) == 0:
                self.R_glint_all.append(np.array([], dtype=np.float32))
                self.R_bg_glint_all.append(np.array([], dtype=np.float32))
            else:
                self.R_glint_all.append(np.concatenate(R_glint_list[b]))
                self.R_bg_glint_all.append(np.concatenate(R_bg_glint_list[b]))
            n = len(self.R_glint_all[b])
            print(f" 波段{b}: 收集到 {n} 个glint像素")
            R_glint_list[b] = None
            R_bg_glint_list[b] = None

    def _optimize_b_list(self):
        """Step 3: optimise ``b`` per band into ``self.b_list``, then free the step-2 arrays."""
        print(f"[SUGAR] 步骤3: 全局优化 b 值...")
        self.b_list = []
        for b in range(self.n_bands):
            b_opt = self._optimise_correction_band(
                self.R_glint_all[b], self.R_bg_glint_all[b], self.bounds_all[b]
            )
            self.b_list.append(float(b_opt))
            print(f" 波段{b}: b={b_opt:.4f}")
        self.R_glint_all = None
        self.R_bg_glint_all = None

    def _process_and_write_block(self, x_off, y_off, x_size, y_size, out_dataset):
        """
        Step 4: correct one block of every band and write it to ``out_dataset``.

        Glint pixels become ``R - (R / b - R_bg)`` with the band's optimised
        ``b``. This is the formula minimised in step 3. Other pixels are copied
        unchanged.
        """
        for b in range(self.n_bands):
            R_block = self._read_block(b, x_off, y_off, x_size, y_size)
            _, glint_mask = self._get_glint_mask_block(R_block, band_idx=b, x_off=x_off, y_off=y_off)
            glint_bool = glint_mask.astype(bool)

            R_corrected = R_block.copy()
            if glint_bool.any():
                R_bg = self._get_background(R_block)
                R_corrected[glint_bool] = self._correct_values(
                    R_block[glint_bool], R_bg[glint_bool], self.b_list[b]
                )
                del R_bg

            out_band = out_dataset.GetRasterBand(b + 1)
            out_band.WriteArray(R_corrected, x_off, y_off)
            out_band.FlushCache()
            del R_block, glint_mask, glint_bool, R_corrected

    def get_corrected_bands(self):
        """
        Run steps 1-4 and write the corrected image as an ENVI BSQ Float32 raster.

        The raster is written to ``self.output_path`` if that ends in ``.bsq``
        (case-insensitive). Otherwise it goes to the same path with its
        extension replaced by ``.bsq``. The header is written as
        ``<raster>.hdr``. Geotransform and projection are copied from the
        input. The input dataset is released on success, so each instance can
        run this only once.

        :return: an empty list (results are only written to disk)
        :raises ValueError: if output_path is None or the output cannot be created
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")

        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = self.output_path if ext.lower() == '.bsq' else base_path + '.bsq'

        driver = gdal.GetDriverByName('ENVI')
        # SUFFIX=ADD names the header "<bsq_path>.hdr". The driver default would
        # write "<stem>.hdr" instead.
        out_dataset = driver.Create(bsq_path, self.width, self.height, self.n_bands,
                                    gdal.GDT_Float32, options=['INTERLEAVE=BSQ', 'SUFFIX=ADD'])
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(self.dataset.GetGeoTransform())
            out_dataset.SetProjection(self.dataset.GetProjection())

            self._scan_and_collect_glint()
            self._collect_glint_pixel_values()
            self._optimize_b_list()

            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[SUGAR] 步骤4: 分块处理写入,共 {total_blocks} 块")

            blocks = _iter_blocks(self.width, self.height, self.block_size)
            for block_idx, (x_off, y_off, x_size, y_size) in enumerate(blocks, start=1):
                self._current_x = x_off
                self._current_y = y_off
                print(f"[SUGAR] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
                self._process_and_write_block(x_off, y_off, x_size, y_size, out_dataset)
        finally:
            # Dropping the reference closes the GDAL dataset, even on error.
            out_dataset = None

        self.dataset = None

        if any(os.path.exists(h) for h in _envi_header_paths(bsq_path)):
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
        return []

    def __del__(self):
        # __init__ may raise before self.dataset exists.
        if getattr(self, "dataset", None) is not None:
            self.dataset = None


# ============================================================================
# Standalone function: correction_iterative (iterative version, supports large images)
# ============================================================================
def correction_iterative(img_path, iter=3, bounds=None, estimate_background=True,
                         glint_mask_method="cdf", get_glint_mask=False,
                         termination_thresh=20.0, water_mask=None, output_path=None,
                         block_size=1000):
    """
    Iterative, block-based SUGAR de-glinting.

    Each iteration runs ``SUGAR`` on the previous iteration's output and writes
    ``_sugar_iter_<i>.bsq``. These files go in the directory of ``output_path``,
    or in the current working directory when ``output_path`` is None.
    Intermediate files are deleted once no longer needed.

    The chosen result is moved to ``output_path``, with its header as
    ``<output_path>.hdr``. When ``output_path`` is None, the chosen result is
    left in place as its ``_sugar_iter_<i>.bsq`` file.

    :param img_path (str): input raster path
    :param iter (int or None): number of iterations.
        - An int runs exactly that many iterations; the last one is kept.
        - None runs until the percentage drop in band-1 variance between two
          consecutive iterations is below ``termination_thresh``, or a previous
          variance is 0, or 101 iterations have run. In that mode the result
          kept is the one from the iteration *before* the one that stopped the
          loop.
    :param bounds: optimisation bounds passed to SUGAR (default ``[(1, 2)]``)
    :param estimate_background: passed to SUGAR
    :param glint_mask_method: "cdf" or "otsu"
    :param get_glint_mask: unused; kept for interface compatibility
    :param termination_thresh: stopping threshold in percent (only when iter is None)
    :param water_mask: passed to SUGAR
    :param output_path: final output path
    :param block_size: tile size, also used as the strip height for the variance check
    :return: an empty list (results are only written to disk)
    :raises ImportError: if GDAL is not installed
    :raises ValueError: if the input raster cannot be opened
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装")
    if bounds is None:
        bounds = [(1, 2)]

    # Fail fast on an unreadable input before any temporary file is created.
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    dataset = None

    temp_dir = os.path.dirname(output_path) if output_path else os.getcwd()
    temp_base = os.path.join(temp_dir, "_sugar_iter")

    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def temp_path(index):
        return f"{temp_base}_{index}.bsq"

    def run_iteration(src_path, index):
        """Run one SUGAR pass on src_path and return the path it wrote."""
        dst_path = temp_path(index)
        SUGAR(
            src_path,
            bounds=bounds,
            estimate_background=estimate_background,
            glint_mask_method=glint_mask_method,
            water_mask=water_mask,
            output_path=dst_path,
            block_size=block_size,
        ).get_corrected_bands()
        return dst_path

    def marginal_diff(sd1, sd2):
        """Percentage decrease from sd1 to sd2."""
        return (sd1 - sd2) / sd1 * 100

    n_run = 0
    final_index = None
    if iter is None:
        max_iter = 100
        src_path = img_path
        sd_prev = None
        iter_count = 0
        while True:
            dst_path = run_iteration(src_path, iter_count)
            n_run = iter_count + 1
            # Only the two newest outputs can still become the result.
            if iter_count >= 2:
                _remove_envi(temp_path(iter_count - 2))

            sd_current = _band_variance(dst_path, block_rows=block_size)
            if iter_count > 0 and (
                sd_prev == 0 or marginal_diff(sd_prev, sd_current) < termination_thresh
            ):
                break
            if iter_count >= max_iter:
                break
            sd_prev = sd_current
            src_path = dst_path
            iter_count += 1
        final_index = max(0, iter_count - 1)
    else:
        src_path = img_path
        for i in range(iter):
            src_path = run_iteration(src_path, i)
            n_run = i + 1
            # Iteration i has consumed output i - 1; it is no longer needed.
            if i >= 1:
                _remove_envi(temp_path(i - 1))
        if n_run > 0:
            final_index = n_run - 1

    if final_index is None:
        return []

    final_path = temp_path(final_index)
    if (output_path and os.path.exists(final_path)
            and os.path.abspath(final_path) != os.path.abspath(output_path)):
        _move_envi(final_path, output_path)

    for i in range(n_run):
        if i != final_index:
            _remove_envi(temp_path(i))
    return []