import numpy as np
import os

try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False


class Hedley:
    """Hedley sun-glint removal, processed block by block and band by band.

    Each band is corrected as ``R - b_i * (NIR - R_min)`` where ``b_i`` is
    ``Cov(NIR, R) / Var(NIR)`` estimated from sub-sampled (optionally
    water-masked) pixels and ``R_min`` is the 5th percentile of the sampled
    NIR values. Results are written straight to an ENVI file.
    """

    def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
                 output_path=None, block_size=1000):
        """Open the input image and store the processing parameters.

        :param img_path (str): input image path (any GDAL-readable format)
        :param shp_path (str, optional): deprecated and unused; pass
            ``water_mask`` instead
        :param NIR_band (int): zero-based index of the NIR band (default 47)
        :param water_mask (np.ndarray or str or None): water mask, either an
            array whose first two dimensions are (height, width) or the path
            of a raster mask file; values > 0 mark pixels to sample/correct
        :param output_path (str): output file path (required by
            ``get_corrected_bands``, which writes block by block)
        :param block_size (int): edge length of a processing block
        :raises ImportError: if GDAL could not be imported
        :raises ValueError: if the image cannot be opened
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法读取影像文件")

        self.img_path = img_path
        # Accepts ints, floats and numeric strings such as "47" or "47.0".
        self.NIR_band = int(float(NIR_band))
        self.water_mask = None  # cache filled lazily by _load_water_mask()
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = block_size
        self.R_min = None
        self.corr_list = None  # per-band Cov(NIR, band) / Var(NIR)

        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")

        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount

    def _load_water_mask(self):
        """Return the water mask as a uint8 0/1 array, or None if not set.

        The mask is loaded on first use and cached in ``self.water_mask`` so
        that per-block callers do not re-read the mask raster every time.

        :raises ValueError: if the mask is a shapefile path, cannot be
            opened, or does not match the image size
        """
        cached = getattr(self, "water_mask", None)
        if cached is not None:
            return cached

        source = self.water_mask_path
        if source is None:
            return None

        expected = (self.height, self.width)
        if isinstance(source, np.ndarray):
            if source.shape[:2] != expected:
                raise ValueError(
                    f"掩膜尺寸 {source.shape[:2]} 与图像尺寸 {expected} 不匹配"
                )
            mask = (source > 0).astype(np.uint8)
        elif isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {source}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None  # dropping the reference closes the file
            if mask_array.shape != expected:
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {expected} 不匹配"
                )
            mask = (mask_array > 0).astype(np.uint8)
        else:
            # Unsupported mask type: behave as if no mask was given.
            return None

        self.water_mask = mask
        return mask

    def covariance_NIR(self, NIR, b):
        """Return the regression slope b_i = Cov(NIR, b) / Var(NIR).

        Returns 0.0 when the input is empty or the NIR variance is zero.
        """
        NIR = np.asarray(NIR)
        b = np.asarray(b)
        if NIR.size == 0:
            return 0.0
        nir_centered = NIR - np.mean(NIR)
        pij = np.mean(nir_centered * (b - np.mean(b)))
        pjj = np.mean(nir_centered ** 2)
        return pij / pjj if pjj != 0 else 0.0

    @staticmethod
    def _sample_pixels(block, mask_bool, sample_step):
        """Return every ``sample_step``-th (masked) pixel as a compact copy.

        Copying detaches the sample from the full-size block so the block
        can be freed while the samples are kept.
        """
        pixels = block.ravel() if mask_bool is None else block[mask_bool]
        return pixels[::sample_step].copy()

    def _scan_global_stats(self, sample_step=20):
        """Scan the NIR band in row blocks and set the global ``R_min``.

        ``R_min`` is the 5th percentile (nearest observation) of the sampled
        NIR pixels, or 0.0 if no pixel was sampled. Only every
        ``sample_step``-th (masked) pixel of each row block is kept.
        """
        print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...")

        water_mask = self._load_water_mask()
        nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
        nir_samples = []

        for y_off in range(0, self.height, self.block_size):
            y_end = min(y_off + self.block_size, self.height)
            nir_block = nir_band.ReadAsArray(0, y_off, self.width, y_end - y_off)

            mask_bool = None
            if water_mask is not None:
                mask_bool = water_mask[y_off:y_end, :].astype(bool)
                if not mask_bool.any():
                    continue  # no water pixel in this row block

            nir_samples.append(
                self._sample_pixels(nir_block, mask_bool, sample_step)
            )
        nir_band = None

        if sum(s.size for s in nir_samples) == 0:
            self.R_min = 0.0
        else:
            all_nir = np.concatenate(nir_samples)
            self.R_min = float(np.percentile(all_nir, 5, method='nearest'))

        print(f"[Hedley] 全局 R_min={self.R_min:.4f}")

    def _compute_corr_list(self, sample_step=5):
        """Compute ``corr_list[b] = Cov(NIR, band_b) / Var(NIR)`` per band.

        The image is read one band and one row block at a time; only every
        ``sample_step``-th (masked) pixel is kept. The default step is
        smaller than the one used for ``R_min`` because the regression needs
        more samples. If no pixel is sampled every coefficient is 0.0.
        """
        print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...")

        water_mask = self._load_water_mask()
        nir_samples = []
        band_samples = [[] for _ in range(self.n_bands)]

        for y_off in range(0, self.height, self.block_size):
            y_end = min(y_off + self.block_size, self.height)
            block_height = y_end - y_off

            mask_bool = None
            if water_mask is not None:
                mask_bool = water_mask[y_off:y_end, :].astype(bool)
                if not mask_bool.any():
                    continue  # nothing to sample in this row block

            # NIR is read once per row block and reused for its own band.
            nir_block = self.dataset.GetRasterBand(self.NIR_band + 1).ReadAsArray(
                0, y_off, self.width, block_height
            ).astype(np.float32)
            nir_sampled = self._sample_pixels(nir_block, mask_bool, sample_step)
            del nir_block
            nir_samples.append(nir_sampled)

            for b in range(self.n_bands):
                if b == self.NIR_band:
                    band_samples[b].append(nir_sampled)
                    continue
                block = self.dataset.GetRasterBand(b + 1).ReadAsArray(
                    0, y_off, self.width, block_height
                ).astype(np.float32)
                band_samples[b].append(
                    self._sample_pixels(block, mask_bool, sample_step)
                )
                del block

        if sum(s.size for s in nir_samples) == 0:
            self.corr_list = [0.0] * self.n_bands
        else:
            all_nir = np.concatenate(nir_samples)
            self.corr_list = []
            for b in range(self.n_bands):
                all_band = np.concatenate(band_samples[b])
                band_samples[b] = None  # release this band's samples early
                self.corr_list.append(
                    float(self.covariance_NIR(all_nir, all_band))
                )

        print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")

    def _process_block(self, x_off, y_off, x_size, y_size):
        """Correct one block and return its bands.

        Requires ``self.R_min`` and ``self.corr_list`` to be set. Pixels
        outside the water mask (when one is configured) keep their original
        values.

        Returns:
            list of np.ndarray: one float32 array of shape (y_size, x_size)
            per band.
        """
        NIR = self.dataset.GetRasterBand(self.NIR_band + 1).ReadAsArray(
            x_off, y_off, x_size, y_size
        ).astype(np.float32)
        NIR_diff = NIR - self.R_min

        water_mask = self._load_water_mask()  # cached after the first call
        if water_mask is not None:
            mask_block = water_mask[
                y_off:y_off + y_size, x_off:x_off + x_size
            ].astype(bool)
        else:
            mask_block = None

        corrected_bands = []
        for b in range(self.n_bands):
            if b == self.NIR_band:
                R = NIR  # already in memory, no need to read it again
            else:
                R = self.dataset.GetRasterBand(b + 1).ReadAsArray(
                    x_off, y_off, x_size, y_size
                ).astype(np.float32)

            # Hedley correction: R_corrected = R - corr * (NIR - R_min)
            corrected = R - self.corr_list[b] * NIR_diff
            if mask_block is not None:
                corrected = np.where(mask_block, corrected, R)
            corrected_bands.append(corrected)

        return corrected_bands

    def get_corrected_bands(self):
        """Run the full correction and write the result to an ENVI file.

        The output is written to ``output_path`` (its extension replaced by
        ``.bsq`` unless it already is ``.bsq``) as Float32. The input dataset
        is released when processing finishes, so the instance cannot be
        reused afterwards.

        :return: always an empty list; the corrected bands are on disk only
        :raises ValueError: if ``output_path`` is missing, the NIR band index
            is outside the image's band range, or the output file cannot be
            created
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
        if not 0 <= self.NIR_band < self.n_bands:
            raise ValueError(
                f"NIR波段索引 {self.NIR_band} 超出范围 (影像共 {self.n_bands} 个波段)"
            )

        # Step 1: global R_min
        self._scan_global_stats(sample_step=20)

        # Step 2: per-band regression coefficients
        self._compute_corr_list(sample_step=5)

        # Step 3: create the output file
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path

        geotransform = self.dataset.GetGeoTransform()
        projection = self.dataset.GetProjection()

        driver = gdal.GetDriverByName('ENVI')
        out_dataset = driver.Create(
            bsq_path, self.width, self.height, self.n_bands, gdal.GDT_Float32
        )
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(geotransform)
            out_dataset.SetProjection(projection)

            # Step 4: block-wise processing
            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")

            block_idx = 0
            for y_off in range(0, self.height, self.block_size):
                y_size = min(self.block_size, self.height - y_off)
                for x_off in range(0, self.width, self.block_size):
                    x_size = min(self.block_size, self.width - x_off)
                    block_idx += 1
                    print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")

                    corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
                    for b, corrected in enumerate(corrected_bands):
                        out_band = out_dataset.GetRasterBand(b + 1)
                        out_band.WriteArray(corrected, x_off, y_off)
                        # Flush per band so written blocks do not pile up
                        # in GDAL's cache.
                        out_band.FlushCache()
                        out_band = None
                    del corrected_bands
        finally:
            # Dropping the reference closes the file, even if a block failed.
            out_dataset = None

        self.dataset = None

        # The ENVI driver names the header either "<base>.hdr" or
        # "<file>.hdr" depending on its SUFFIX option, so accept both.
        hdr_candidates = (
            os.path.splitext(bsq_path)[0] + '.hdr',
            bsq_path + '.hdr',
        )
        if any(os.path.exists(path) for path in hdr_candidates):
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")

        return []

    def __del__(self):
        # __init__ may have raised before `dataset` was assigned, so do not
        # assume the attribute exists.
        if getattr(self, "dataset", None) is not None:
            self.dataset = None