import os

import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar

# OpenCV is only needed for background estimation (SUGAR.get_est_background).
# Treat it as optional, like GDAL below, so the rest of the module can be
# imported and used without it; get_est_background raises if it is missing.
try:
    import cv2
except ImportError:
    cv2 = None

# GDAL is optional: it is only needed to read raster water masks from disk and
# to write corrected images.
try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False


def _percentile_nearest(a, q):
    """Return the ``q``-th percentile of ``a`` using the 'nearest' method.

    NumPy 1.22 renamed the ``interpolation`` keyword of ``np.percentile`` to
    ``method`` and deprecated the old name. Try the new keyword first and fall
    back to the old one on NumPy versions that do not accept it.
    """
    try:
        return np.percentile(a, q, method='nearest')
    except TypeError:
        return np.percentile(a, q, interpolation='nearest')


# SUn-Glint-Aware Restoration (SUGAR): a sweet and simple algorithm for correcting sunglint
class SUGAR:
    """Sunglint detection and correction for a band-aligned reflectance image.

    Glint pixels are detected per band by thresholding a Laplacian-of-Gaussian
    (LoG) image, optionally restricted to a water mask.
    """

    def __init__(self, im_aligned, bounds=((1, 2),), sigma=1, estimate_background=True,
                 glint_mask_method="cdf", water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray): band-aligned, calibrated and corrected reflectance
            image, indexed as (row, column, band)
        :param bounds (sequence of tuple): (lower, upper) bounds for the optimisation of b.
            The sequence is repeated once per band and band i uses entry i, so a single
            tuple inside a list/tuple applies the same bounds to every band
        :param sigma (float): Gaussian sigma of the LoG
        :param estimate_background (bool): used by optimise_correction; if true the
            background is estimated by morphological erosion, otherwise the 5th
            percentile of the band is used
        :param glint_mask_method (str): "cdf" (default) or "otsu"
        :param water_mask (np.ndarray, str, os.PathLike or None): water mask, 1 = water,
            0 = non-water. Either an array of shape (H, W) or (H, W, 1), or the path of a
            raster file (.dat/.tif) readable by GDAL. Shapefiles (.shp) are not supported
            and must be rasterised first. None processes the whole image.
        :param output_path (str or None): if given, get_corrected_bands saves the corrected
            image there (ENVI/BSQ format); if None, nothing is saved
        :raises ValueError: if the water mask cannot be loaded or does not match the image
        """
        self.im_aligned = im_aligned
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.n_bands = im_aligned.shape[-1]
        # list() copies the caller's sequence, so the default argument is never shared
        self.bounds = list(bounds) * self.n_bands
        self.glint_mask_method = glint_mask_method
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path
        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """
        Load and validate the water mask.

        :param water_mask: None, a numpy array ((H, W) or (H, W, 1)), or a raster file path
        :return: uint8 array of shape (H, W) with 1 = water and 0 = non-water, or None
        :raises ValueError: on size mismatch, unsupported type, .shp input, missing GDAL,
            or an unreadable file
        """
        if water_mask is None:
            return None

        if isinstance(water_mask, np.ndarray):
            mask = water_mask
            # Accept an (H, W, 1) mask by dropping the trailing singleton axis; otherwise
            # it would broadcast incorrectly against the (H, W) glint masks.
            if mask.ndim == 3 and mask.shape[2] == 1:
                mask = mask[:, :, 0]
            if mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            if mask.ndim != 2:
                raise ValueError(f"掩膜必须是二维数组,当前形状为 {water_mask.shape}")
            return (mask > 0).astype(np.uint8)  # force a 0/1 mask

        if isinstance(water_mask, (str, os.PathLike)):
            path = os.fsdecode(water_mask)
            if not GDAL_AVAILABLE:
                raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
            if path.lower().endswith('.shp'):
                # Rasterising a shapefile needs georeferencing that a bare numpy image lacks.
                raise ValueError("SUGAR类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
            mask_dataset = gdal.Open(path, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {water_mask}")
            try:
                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            finally:
                mask_dataset = None  # dereferencing closes the GDAL dataset
            if mask_array.shape != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def otsu_thresholding(self, im):
        """
        Otsu threshold of LoG values, found by bounded Brent minimisation of the
        within-class variance.

        :param im (np.ndarray): LoG values, either a 2-D image or a 1-D array of selected
            pixels. Non-finite values are ignored.
        :return: the threshold (a histogram bin centre)
        """
        im = np.asarray(im)
        # About 0.5% of the number of pixels as bins; np.histogram rejects bins=0, so
        # small inputs still get one bin.
        n_bins = max(1, int(0.005 * im.size))
        values = im.ravel()
        values = values[np.isfinite(values)]

        count, bin_edges = np.histogram(values, bins=n_bins)
        centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        hist_norm = count / count.sum()  # normalised histogram
        Q = hist_norm.cumsum()  # CDF, ranges from 0 to 1
        N = count.shape[0]
        N_negative = int(np.sum(centers < 0))
        bins = np.arange(N, dtype=np.float32)  # bin indices, used as class weights

        # Glint pixels have L < 0, so only negative bins are searched. Without enough
        # negative bins, fall back to the modal bin centre.
        if N_negative <= 1:
            return centers[np.argmax(count)]

        def otsu_thresh(x):
            """Within-class variance when splitting the histogram before bin x."""
            x = int(x)
            p1, p2 = hist_norm[:x], hist_norm[x:]
            b1, b2 = bins[:x], bins[x:]
            # p1 holds bins 0..x-1, so its total probability is Q[x - 1]
            q1 = Q[x - 1]
            q2 = Q[N - 1] - q1
            m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
            m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
            v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
            v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
            return v1 * q1 + v2 * q2

        res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
        return centers[min(int(res.x), N - 1)]

    def cdf_thresholding(self, im, auto_bins=10):
        """
        Threshold LoG values at the centre of their most populated histogram bin.

        :param im (np.ndarray): LoG values (2-D image or 1-D array); non-finite values are ignored
        :param auto_bins (int): number of histogram bins
        :return: centre of the modal histogram bin
        """
        values = np.asarray(im).ravel()
        values = values[np.isfinite(values)]
        count, bin_edges = np.histogram(values, bins=auto_bins)
        centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        return centers[np.argmax(count)]

    def glint_list(self):
        """
        :return: list with one array per band: the band multiplied by its glint mask
        """
        return [gm * self.im_aligned[:, :, i] for i, gm in enumerate(self.glint_mask_list())]

    def glint_mask_list(self):
        """
        :return: list of glint masks (see get_glint_mask), one per band
        """
        return [self.get_glint_mask(self.im_aligned[:, :, i])
                for i in range(self.im_aligned.shape[-1])]

    def log_image_list(self):
        """
        :return: list of Laplacian-of-Gaussian images, one per band
        """
        return [self.get_log_image(self.im_aligned[:, :, i])
                for i in range(self.im_aligned.shape[-1])]

    def get_log_image(self, im):
        """
        :param im (np.ndarray): a single band
        :return: Laplacian of Gaussian of im with sigma = self.sigma
        """
        return ndimage.gaussian_laplace(im, sigma=self.sigma)

    def get_glint_mask(self, im):
        """
        Glint mask of a single band from its Laplacian-of-Gaussian image.

        Water constituents and features are assumed to follow a smooth continuum,
        whereas glint pixels vary strongly in space and intensity. For very extensive
        glint this may not work as well. TODO: use a U-net to identify the glint mask.

        :param im (np.ndarray): a single 2-D band
        :return: uint8 mask, 1 = glint, 0 = no glint (always 0 outside the water mask)
        :raises ValueError: if glint_mask_method is neither "cdf" nor "otsu"
        """
        LoG_im = self.get_log_image(im)

        if self.glint_mask_method == "otsu":
            thresholding = self.otsu_thresholding
        elif self.glint_mask_method == "cdf":
            thresholding = self.cdf_thresholding
        else:
            raise ValueError('Enter only cdf or otsu as glint_mask_method')

        if self.water_mask is None:
            return (LoG_im < thresholding(LoG_im)).astype(np.uint8)

        water = self.water_mask.astype(bool)
        if not water.any():
            # No water pixels: nothing can be flagged as glint.
            return np.zeros(LoG_im.shape, dtype=np.uint8)
        # Threshold on water pixels only. Filling non-water pixels with a sentinel
        # value would still put them in the histogram; for "cdf", a large non-water
        # area would make the sentinel bin the mode and flag almost all water as glint.
        thresh = thresholding(LoG_im[water])
        return ((LoG_im < thresh) & water).astype(np.uint8)

    def get_est_background(self, im, k_size=5):
        """
        Estimate the background of a band by grey-scale erosion with an elliptical kernel.

        :param im (np.ndarray): a single band
        :param k_size (int): kernel width/height in pixels
        :return: eroded band (np.ndarray)
        :raises ImportError: if OpenCV (cv2) is not installed
        """
        if cv2 is None:
            raise ImportError("OpenCV (cv2)未安装,无法估计背景")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        return cv2.erode(im, kernel)

    def optimise_correction_by_band(self, im, glint_mask, R_BG, bounds):
        """
        Find the slope b that minimises the variance of the whole band after correction.

        Glint pixels are corrected as ``im - G * (im / b - R_BG)`` (G = glint mask value);
        other pixels are unchanged. b is found with bounded Brent minimisation.

        :param im (np.ndarray): a single band
        :param glint_mask (np.ndarray): glint mask, glint = non-zero (normally 1), else 0
        :param R_BG (np.ndarray or scalar): background, either per pixel (same shape as im)
            or a single value; NumPy scalars such as np.float32 count as single values
        :param bounds (tuple): (lower, upper) bounds for b
        :return: the optimal b
        """
        glint_mask = np.asarray(glint_mask)
        glint = glint_mask.astype(bool)
        im_glint = im[glint]
        gain = glint_mask[glint]
        # np.ndim is 0 for Python and NumPy scalars alike (isinstance(x, float) misses np.float32)
        R_BG_glint = R_BG if np.ndim(R_BG) == 0 else np.asarray(R_BG)[glint]
        # Non-glint pixels never change and every evaluation overwrites all glint
        # pixels, so a single working copy can be reused.
        im_corrected = im.copy()

        def optimise_b(b):
            im_corrected[glint] = im_glint - gain * (im_glint / b - R_BG_glint)
            return np.var(im_corrected)

        res = minimize_scalar(optimise_b, bounds=bounds, method='bounded')
        return res.x

    def divide_and_conquer(self):
        """
        Instead of computing b_list for each window, use the previous b_list to narrow the
        bounds: because of the strong spatial autocorrelation, b (correction magnitude)
        cannot differ much between neighbouring windows, which would reduce run time.

        NOTE: not implemented yet; currently does nothing and returns None.
        """

    def optimise_correction(self):
        """
        Optimise the slope b for every band.

        Also stores the results as self.b_list, self.glint_mask and self.est_background.

        :return: (b_list, glint_mask_list, est_background_list), each in band order
        """
        b_list = []
        glint_mask_list = []
        est_background_list = []
        for i in range(self.n_bands):
            im_band = self.im_aligned[:, :, i]
            glint_mask = self.get_glint_mask(im_band)
            glint_mask_list.append(glint_mask)
            if self.estimate_background:
                est_background = self.get_est_background(im_band)
            else:
                est_background = _percentile_nearest(im_band, 5)
            est_background_list.append(est_background)
            b_list.append(self.optimise_correction_by_band(
                im_band, glint_mask, est_background, self.bounds[i]))

        self.b_list = b_list
        self.glint_mask = glint_mask_list
        self.est_background = est_background_list
        return b_list, glint_mask_list, est_background_list

    def _save_corrected_bands(self, corrected_bands):
        """
        Save the corrected bands to self.output_path as an ENVI (BSQ) file.

        :param corrected_bands: list of corrected 2-D bands
        :raises ImportError: if GDAL is not installed
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")
        if self.output_path is None:
            return
        _save_corrected_image(np.stack(corrected_bands, axis=2), self.output_path)

    def get_corrected_bands(self):
        """
        Replace detected glint pixels in every band with the eroded background estimate.

        get_glint_mask already excludes pixels outside the water mask, so those keep
        their original values. NOTE(review): this always uses the eroded background
        (k_size=5) and ignores estimate_background and the optimised slopes of
        optimise_correction; confirm that this is intended.

        :return: list of corrected 2-D bands in band order (also saved if output_path is set)
        """
        corrected_bands = []
        for i in range(self.n_bands):
            im_band = self.im_aligned[:, :, i]
            glint = self.get_glint_mask(im_band).astype(bool)
            background = self.get_est_background(im_band, k_size=5)
            corrected_bands.append(np.where(glint, background, im_band))

        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)
        return corrected_bands


def correction_iterative(im_aligned, iter=3, bounds=((1, 2),), estimate_background=True,
                         glint_mask_method="cdf", get_glint_mask=False, termination_thresh=20,
                         water_mask=None, output_path=None):
    """
    Run SUGAR repeatedly, feeding each corrected image into the next iteration.

    :param im_aligned (np.ndarray): band-aligned, calibrated and corrected reflectance image
    :param iter (int or None): number of iterations. If None, iterate until the image
        variance is at most termination_thresh percent of the original, or one iteration
        reduces the variance by less than termination_thresh percent (at most 100 iterations)
    :param bounds (sequence of tuples): limits on the correction magnitude (see SUGAR)
    :param estimate_background (bool): forwarded to SUGAR
    :param glint_mask_method (str): "cdf" or "otsu", forwarded to SUGAR
    :param get_glint_mask (bool): kept for backward compatibility; currently unused
    :param termination_thresh (float): percentage threshold for both termination conditions
    :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
        an array or raster file path (see SUGAR). None processes the whole image
    :param output_path (str or None): if given, the result of the last iteration is saved
        there (ENVI/BSQ); if None, nothing is saved
    :return: list of corrected images (H, W, bands), one per iteration
    """
    def _correct_once(image, mask):
        # One SUGAR pass; returns the corrected image and the loaded mask array so a
        # mask file is read from disk only once across iterations.
        sugar = SUGAR(image, bounds, estimate_background=estimate_background,
                      glint_mask_method=glint_mask_method, water_mask=mask)
        return np.stack(sugar.get_corrected_bands(), axis=2), sugar.water_mask

    glint_image = im_aligned.copy()
    # Each iteration produces a new array and none is modified afterwards, so the
    # list can hold them directly without copying.
    corrected_images = []

    if iter is None:
        def relative_difference(sd0, sd1):
            return sd1 / sd0 * 100

        def marginal_difference(sd1, sd2):
            return (sd1 - sd2) / sd1 * 100

        sd_og = np.var(im_aligned)
        sd_next = sd_og
        iter_count = 0
        max_iter = 100  # guard against non-terminating runs
        while relative_difference(sd_og, sd_next) > termination_thresh and iter_count < max_iter:
            glint_image, water_mask = _correct_once(glint_image, water_mask)
            sd_temp = np.var(glint_image)
            corrected_images.append(glint_image)
            if marginal_difference(sd_next, sd_temp) < termination_thresh:
                break
            sd_next = sd_temp
            iter_count += 1
    else:
        for _ in range(iter):
            glint_image, water_mask = _correct_once(glint_image, water_mask)
            corrected_images.append(glint_image)

    if output_path is not None and corrected_images:
        _save_corrected_image(corrected_images[-1], output_path)
    return corrected_images


def _save_corrected_image(corrected_image, output_path):
    """
    Save an image as an ENVI file with BSQ interleave (float32).

    The output extension is forced to .bsq. A placeholder geotransform
    (0, 1, 0, 0, 0, -1) is written and no projection is set.

    :param corrected_image (np.ndarray): image of shape (H, W, bands) or (H, W)
    :param output_path (str): output file path; nothing is written if None
    :raises ImportError: if GDAL is not installed
    :raises ValueError: if the ENVI driver is unavailable or the file cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装,无法保存影像文件")
    if output_path is None:
        return

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    corrected_image = np.asarray(corrected_image)
    if corrected_image.ndim == 2:
        corrected_image = corrected_image[:, :, np.newaxis]  # single band

    # No georeferencing information is available: placeholder values
    geotransform = (0, 1, 0, 0, 0, -1)
    projection = ""

    base_path, ext = os.path.splitext(output_path)
    bsq_path = output_path if ext.lower() == '.bsq' else base_path + '.bsq'

    # ENVI driver (BSQ is its default interleave); it writes the .hdr header itself
    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

    height, width, n_bands = corrected_image.shape
    dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {bsq_path}")

    try:
        if geotransform:
            dataset.SetGeoTransform(geotransform)
        if projection:
            dataset.SetProjection(projection)
        for i in range(n_bands):
            band = dataset.GetRasterBand(i + 1)
            band.WriteArray(corrected_image[:, :, i])
            band.FlushCache()
    finally:
        dataset = None  # dereferencing closes the dataset and flushes the header

    # By default the ENVI driver replaces the extension with .hdr (x.bsq -> x.hdr);
    # with SUFFIX=ADD it appends it (x.bsq.hdr). Accept either.
    hdr_candidates = (os.path.splitext(bsq_path)[0] + '.hdr', bsq_path + '.hdr')
    hdr_path = next((p for p in hdr_candidates if os.path.exists(p)), None)
    print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
    if hdr_path is not None:
        print(f"头文件已保存至: {hdr_path}")
    else:
        print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")