import numpy as np import sys # import preprocessing try: from osgeo import gdal GDAL_AVAILABLE = True except ImportError: GDAL_AVAILABLE = False print("警告: GDAL未安装,将使用numpy处理模式") try: from tqdm import tqdm TQDM_AVAILABLE = True except ImportError: TQDM_AVAILABLE = False # 如果tqdm不可用,定义一个简单的包装器 def tqdm(iterable, desc=None, total=None, disable=None): return iterable # 检测是否在 PyInstaller 打包环境(无控制台) _is_frozen_gui = getattr(sys, "frozen", False) and (not hasattr(sys, 'stdout') or sys.stdout is None) class Goodman: def __init__(self, im_aligned, NIR_lower = 25, NIR_upper = 37, A = 0.000019, B = 0.1, use_gdal=True, chunk_size=None, water_mask=None, output_path=None): """ :param im_aligned (np.ndarray or str): band aligned and calibrated & corrected reflectance image 可以是numpy数组或GDAL可读取的文件路径 :param NIR_lower (int): band index which corresponds to 641.93nm, closest band to 640nm :param NIR_upper (int): band index which corresponds to 751.49nm, closest band to 750nm :param A (float): the values in Goodman et al's paper, using AVIRIS reflectance (rather than radiance) data :param B (float): the values in Goodman et al's paper, using AVIRIS reflectance (rather than radiance) data see Goodman et al, which corrects each pixel independently. The NIR radiance is subtracted from the radiance at each wavelength, but a wavelength-independent offset is also added. it is not clear how A and B were chosen, but an optimization for a case where in situ data is available would enable values to be found :param use_gdal (bool): 是否使用GDAL加速处理(需要GDAL可用且输入为文件路径或大数组) :param chunk_size (int): 已废弃,不再使用分块处理,改为逐波段处理 :param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域 可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp) 如果为None,则处理全图 :param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像 如果为None,则不保存 """ self.im_aligned = im_aligned self.NIR_lower = NIR_lower self.NIR_upper = NIR_upper self.A = A self.B = B self.use_gdal = use_gdal and GDAL_AVAILABLE self.chunk_size = chunk_size self.is_file_path = isinstance(im_aligned, str) self.output_path = output_path # 获取图像信息(需要在加载掩膜之前获取尺寸) if self.is_file_path: if not self.use_gdal: raise ValueError("输入为文件路径时,必须安装GDAL") self.dataset = gdal.Open(im_aligned, gdal.GA_ReadOnly) if self.dataset is None: raise ValueError(f"无法打开影像文件: {im_aligned}") self.height = self.dataset.RasterYSize self.width = self.dataset.RasterXSize self.n_bands = self.dataset.RasterCount else: self.dataset = None self.height = im_aligned.shape[0] self.width = im_aligned.shape[1] self.n_bands = im_aligned.shape[-1] # 加载水域掩膜(在获取图像尺寸之后) self.water_mask = self._load_water_mask(water_mask) def _load_water_mask(self, water_mask): """ 加载水域掩膜 :param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp) :return: numpy数组或None,1表示水域,0表示非水域 """ if water_mask is None: return None # 如果已经是numpy数组 if isinstance(water_mask, np.ndarray): if water_mask.shape[:2] != (self.height, self.width): raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配") return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜 # 如果是文件路径 if isinstance(water_mask, str): if not GDAL_AVAILABLE: raise ValueError("使用文件路径作为掩膜时,必须安装GDAL") # 检查是否为shapefile if water_mask.lower().endswith('.shp'): # 从shp文件创建掩膜 if self.is_file_path: ref_path = self.im_aligned else: raise ValueError("输入为numpy数组时,无法从shp文件创建掩膜(需要参考栅格)") try: from osgeo import ogr ref_dataset = gdal.Open(ref_path, gdal.GA_ReadOnly) if ref_dataset is None: raise ValueError(f"无法打开参考栅格文件: {ref_path}") geotransform = ref_dataset.GetGeoTransform() projection = ref_dataset.GetProjection() width = ref_dataset.RasterXSize height = ref_dataset.RasterYSize # 创建内存中的栅格数据集 mem_driver = gdal.GetDriverByName('MEM') mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte) mask_dataset.SetGeoTransform(geotransform) mask_dataset.SetProjection(projection) mask_band = mask_dataset.GetRasterBand(1) mask_band.Fill(0) # 打开shp文件 shp_dataset = ogr.Open(water_mask) if shp_dataset is None: raise ValueError(f"无法打开shp文件: {water_mask}") layer = shp_dataset.GetLayer() gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1]) water_mask_array = mask_band.ReadAsArray() ref_dataset = None mask_dataset = None shp_dataset = None return (water_mask_array > 0).astype(np.uint8) except Exception as e: raise ValueError(f"从shp文件创建掩膜时出错: {e}") else: # 栅格文件 mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly) if mask_dataset is None: raise ValueError(f"无法打开掩膜文件: {water_mask}") mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() mask_dataset = None if mask_array.shape != (self.height, self.width): raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配") return (mask_array > 0).astype(np.uint8) raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") def _get_corrected_bands_numpy(self): """ 使用numpy处理(用于小图像或GDAL不可用时) 注意:由于输入已经是numpy数组,数据已在内存中。 此方法通过逐波段处理,避免同时创建多个校正后的波段数组。 内存峰值 = 原始数组 + NIR波段(2个) + 当前处理的波段(1个) """ # 预提取重复使用的NIR波段,避免在循环中重复访问 # 这些波段会一直保存在内存中,因为它们需要用于所有波段的校正 R_640 = self.im_aligned[:,:,self.NIR_lower] R_750 = self.im_aligned[:,:,self.NIR_upper] # 预计算常量部分 diff_640_750 = R_640 - R_750 corrected_bands = [] # 获取水域掩膜(如果存在) water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None # 逐波段处理:每次只处理一个波段,处理完后立即添加到结果列表 for i in tqdm(range(self.n_bands), desc="处理波段 (numpy)", total=self.n_bands, disable=_is_frozen_gui): # 获取当前波段(这是数组视图,不是复制) R = self.im_aligned[:,:,i] # 优化计算:减少中间数组创建 corrected_band = R - R_750 + self.A + self.B * diff_640_750 # 使用np.maximum原地操作,将负值设为0 np.maximum(corrected_band, 0, out=corrected_band) # 如果存在水域掩膜,只对水域区域应用校正 if water_mask_bool is not None: corrected_band = np.where(water_mask_bool, corrected_band, R) # 立即添加到结果列表(corrected_band会保留在列表中) corrected_bands.append(corrected_band) return corrected_bands def _get_corrected_bands_gdal(self): """ 使用GDAL逐波段处理,直接处理整个波段(不分块) 内存峰值 = NIR波段(2个) + 当前处理的波段(1个) + 已处理的波段(累积在列表中) """ corrected_bands = [] # 获取NIR波段对象(用于所有波段的校正) band_640 = self.dataset.GetRasterBand(self.NIR_lower + 1) # GDAL波段从1开始 band_750 = self.dataset.GetRasterBand(self.NIR_upper + 1) # 先读取NIR波段(用于所有波段的校正,会一直保存在内存中) R_640 = band_640.ReadAsArray().astype(np.float32) R_750 = band_750.ReadAsArray().astype(np.float32) diff_640_750 = R_640 - R_750 # 获取水域掩膜 water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None # 逐波段处理:每次只读取和处理一个波段 for i in tqdm(range(self.n_bands), desc="处理波段 (GDAL)", total=self.n_bands, disable=_is_frozen_gui): # 读取当前波段(只加载一个波段到内存) current_band = self.dataset.GetRasterBand(i + 1) R = current_band.ReadAsArray().astype(np.float32) # 校正计算 corrected_band = R - R_750 + self.A + self.B * diff_640_750 np.maximum(corrected_band, 0, out=corrected_band) # 如果存在水域掩膜,只对水域区域应用校正 if water_mask_bool is not None: corrected_band = np.where(water_mask_bool, corrected_band, R) # 添加到结果列表(corrected_band会保留在列表中) corrected_bands.append(corrected_band) # 释放当前波段数据(显式删除有助于及时释放内存) del R return corrected_bands def _get_corrected_bands_gdal_mem(self): """使用GDAL内存驱动处理numpy数组,逐波段处理""" # 创建内存数据集 driver = gdal.GetDriverByName('MEM') mem_dataset = driver.Create('', self.width, self.height, self.n_bands, gdal.GDT_Float32) # 将numpy数组写入内存数据集(显示进度) for i in tqdm(range(self.n_bands), desc="加载波段到内存", total=self.n_bands, disable=_is_frozen_gui): band = mem_dataset.GetRasterBand(i + 1) band.WriteArray(self.im_aligned[:,:,i]) band.FlushCache() # 临时保存原始dataset引用 original_dataset = self.dataset self.dataset = mem_dataset try: # 使用逐波段处理方法 result = self._get_corrected_bands_gdal() finally: # 恢复原始dataset self.dataset = original_dataset mem_dataset = None return result def _save_corrected_bands(self, corrected_bands): """ 保存校正后的波段到文件(BSQ格式,ENVI格式) 注意:为了节省内存,直接逐波段写入,不先堆叠成完整数组 :param corrected_bands: 校正后的波段列表 """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法保存影像文件") if self.output_path is None: return import os # 确保输出目录存在 output_dir = os.path.dirname(self.output_path) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) # 从第一个波段获取尺寸信息(避免堆叠所有波段) if not corrected_bands: raise ValueError("校正后的波段列表为空") first_band = corrected_bands[0] height, width = first_band.shape n_bands = len(corrected_bands) # 获取地理变换和投影信息 if self.is_file_path and self.dataset is not None: geotransform = self.dataset.GetGeoTransform() projection = self.dataset.GetProjection() else: # 如果没有地理信息,使用默认值 geotransform = (0, 1, 0, 0, 0, -1) projection = "" # 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq base_path, ext = os.path.splitext(self.output_path) # 如果扩展名不是.bsq,使用基础路径添加.bsq if ext.lower() != '.bsq': bsq_path = base_path + '.bsq' else: bsq_path = self.output_path # 使用ENVI驱动(默认就是BSQ格式) driver = gdal.GetDriverByName('ENVI') if driver is None: raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用") # 创建ENVI格式数据集(会自动生成.hdr文件) dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32) if dataset is None: raise ValueError(f"无法创建输出文件: {bsq_path}") try: # 设置地理变换和投影 if geotransform: dataset.SetGeoTransform(geotransform) if projection: dataset.SetProjection(projection) # 直接逐波段写入(不先堆叠,节省内存) for i in tqdm(range(n_bands), desc="保存波段", total=n_bands, disable=_is_frozen_gui): band = dataset.GetRasterBand(i + 1) # 直接从列表中获取波段并写入,避免创建完整数组 band.WriteArray(corrected_bands[i]) band.FlushCache() finally: dataset = None # 检查.hdr文件是否已创建 hdr_path = bsq_path + '.hdr' if os.path.exists(hdr_path): print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") print(f"头文件已保存至: {hdr_path}") else: print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)") print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建") def get_corrected_bands(self): """ 获取校正后的波段 根据输入类型和大小自动选择最优处理方法 :return: 校正后的波段列表 """ # 如果输入是文件路径,使用GDAL直接读取 if self.is_file_path: if self.use_gdal: corrected_bands = self._get_corrected_bands_gdal() else: raise ValueError("输入为文件路径时,必须安装GDAL") else: # 如果输入是numpy数组 if self.use_gdal and self.height * self.width * self.n_bands > 100000000: # 大图像使用GDAL内存驱动逐波段处理 corrected_bands = self._get_corrected_bands_gdal_mem() else: # 小图像使用numpy直接处理 corrected_bands = self._get_corrected_bands_numpy() # 如果提供了输出路径,保存结果 if self.output_path is not None: self._save_corrected_bands(corrected_bands) return corrected_bands def __del__(self): """清理资源""" if self.dataset is not None and self.is_file_path: self.dataset = None