307 lines
11 KiB
Python
307 lines
11 KiB
Python
import numpy as np
|
||
import os
|
||
|
||
try:
|
||
from osgeo import gdal
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
|
||
|
||
class Hedley:
    """Hedley sun-glint removal, processed block by block and band by band.

    Pipeline (see ``get_corrected_bands``):
      1. ``_scan_global_stats``  -> global ``R_min`` (5th percentile of NIR over water).
      2. ``_compute_corr_list``  -> per-band slope ``b_i = Cov(NIR, band_i) / Var(NIR)``.
      3. ``_process_block``      -> ``R_corrected = R - b_i * (NIR - R_min)`` per tile,
         written to an ENVI/BSQ Float32 file.
    """

    def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
                 output_path=None, block_size=1000):
        """
        Hedley glint-removal algorithm, block-wise / band-wise processing version.

        :param img_path (str): input image path (any GDAL-readable format)
        :param shp_path (str, optional): deep-water shapefile; deprecated and
            ignored, use ``water_mask`` instead
        :param NIR_band (int): 0-based NIR band index (default 47, 842.36 nm
            per the original author)
        :param water_mask (np.ndarray | str | os.PathLike | None): water mask,
            either an array shaped like the image (non-zero = water) or a
            raster file path readable by GDAL; ``None`` treats every pixel as water
        :param output_path (str): output path (required; results are written
            block by block to an ENVI ``.bsq`` file)
        :param block_size (int): tile edge length in pixels (default 1000)
        :raises ImportError: if GDAL is not installed
        :raises ValueError: if the image cannot be opened, or ``NIR_band`` /
            ``block_size`` are out of range
        """
        # Bound first so __del__ is safe even if construction fails below.
        self.dataset = None

        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法读取影像文件")

        self.img_path = img_path
        self.NIR_band = int(float(NIR_band))
        self.water_mask = None  # lazily filled cache, see _load_water_mask()
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = int(block_size)
        if self.block_size <= 0:
            raise ValueError(f"block_size 必须为正整数: {block_size}")
        self.R_min = None
        self.corr_list = None  # global per-band regression slopes vs. NIR

        # Open the input image.
        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount

        # Fail early: GetRasterBand() would otherwise return None much later.
        if not 0 <= self.NIR_band < self.n_bands:
            self.dataset = None
            raise ValueError(
                f"NIR波段索引 {self.NIR_band} 超出范围 [0, {self.n_bands - 1}]"
            )

    # ------------------------------------------------------------------ #
    # Small I/O helpers
    # ------------------------------------------------------------------ #
    def _read_window(self, band_index, x_off, y_off, x_size, y_size):
        """Read a window of the 0-based band ``band_index`` as float32."""
        band = self.dataset.GetRasterBand(band_index + 1)
        return band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)

    def _row_blocks(self):
        """Yield ``(y_off, block_height)`` for full-width strips of <= block_size rows."""
        for y_off in range(0, self.height, self.block_size):
            yield y_off, min(self.block_size, self.height - y_off)

    def _mask_rows(self, water_mask, y_off, block_height):
        """Boolean water mask for a full-width strip (all True when no mask is set)."""
        if water_mask is None:
            return np.ones((block_height, self.width), dtype=bool)
        return water_mask[y_off:y_off + block_height, :].astype(bool)

    # ------------------------------------------------------------------ #
    # Water mask
    # ------------------------------------------------------------------ #
    def _load_water_mask(self):
        """Return the water mask as a uint8 (H, W) array, or None if no mask is set.

        The mask is read once and cached in ``self.water_mask``; previously
        it was reloaded from disk for every processed tile.
        """
        cached = getattr(self, 'water_mask', None)
        if cached is not None:
            return cached
        mask = self._read_water_mask()
        self.water_mask = mask
        return mask

    def _read_water_mask(self):
        """Build the mask from ``self.water_mask_path`` (array or raster path).

        :raises ValueError: on size mismatch, a shapefile path, or an unreadable file
        :raises TypeError: for unsupported mask types (previously ignored silently)
        """
        source = self.water_mask_path
        if source is None:
            return None

        if isinstance(source, np.ndarray):
            mask_array = source
            # Accept (H, W, 1)-style arrays by dropping trailing singleton axes.
            while mask_array.ndim > 2 and mask_array.shape[-1] == 1:
                mask_array = mask_array[..., 0]
            if mask_array.shape[:2] != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            if mask_array.ndim != 2:
                raise ValueError(f"掩膜必须是二维数组,当前形状为 {source.shape}")
            return (mask_array > 0).astype(np.uint8)

        if isinstance(source, os.PathLike):
            source = os.fspath(source)

        if isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {source}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None  # close the GDAL handle
            if mask_array.shape != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            return (mask_array > 0).astype(np.uint8)

        raise TypeError(f"不支持的掩膜类型: {type(source).__name__}")

    # ------------------------------------------------------------------ #
    # Statistics
    # ------------------------------------------------------------------ #
    def covariance_NIR(self, NIR, b):
        """Return the regression slope ``Cov(NIR, b) / Var(NIR)`` of two sample arrays.

        Computed in float64. Returns 0.0 when the inputs are empty or NIR has
        zero variance.
        """
        nir = np.asarray(NIR, dtype=np.float64)
        band = np.asarray(b, dtype=np.float64)
        if nir.size == 0:
            return 0.0
        dx = nir - nir.mean()
        pjj = np.mean(dx * dx)
        if pjj == 0:
            return 0.0
        pij = np.mean(dx * (band - band.mean()))
        return float(pij / pjj)

    def _scan_global_stats(self, sample_step=20):
        """
        Scan the whole image and set the global ``self.R_min``.

        ``R_min`` is the 5th percentile ('nearest') of NIR values over water
        pixels, taking every ``sample_step``-th water pixel of each strip.
        It is 0.0 when there are no water pixels.
        """
        print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...")
        water_mask = self._load_water_mask()

        nir_samples = []
        for y_off, block_height in self._row_blocks():
            mask_bool = self._mask_rows(water_mask, y_off, block_height)
            if not mask_bool.any():
                continue  # nothing to sample; skip the read entirely
            nir_block = self._read_window(self.NIR_band, 0, y_off, self.width, block_height)
            nir_samples.append(nir_block[mask_bool][::sample_step])

        if not nir_samples:
            self.R_min = 0.0
        else:
            all_nir = np.concatenate(nir_samples)
            self.R_min = float(np.percentile(all_nir, 5, method='nearest'))

        print(f"[Hedley] 全局 R_min={self.R_min:.4f}")

    def _compute_corr_list(self, sample_step=5):
        """
        Set ``self.corr_list[b] = Cov(NIR, band_b) / Var(NIR)`` for every band.

        Samples are every ``sample_step``-th water pixel of each full-width
        strip. Statistics are accumulated in one pass, in float64, on data
        shifted by the first strip's sample means (this limits cancellation).
        Memory therefore stays at one strip per band instead of holding all
        samples of all bands. Slopes are 0.0 when there are no samples or NIR
        has zero variance.
        """
        print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...")
        water_mask = self._load_water_mask()
        n_bands = self.n_bands

        count = 0
        shift_nir = None
        shift_band = np.zeros(n_bands, dtype=np.float64)
        sum_dx = 0.0
        sum_dx2 = 0.0
        sum_db = np.zeros(n_bands, dtype=np.float64)
        sum_dxdb = np.zeros(n_bands, dtype=np.float64)

        for y_off, block_height in self._row_blocks():
            mask_bool = self._mask_rows(water_mask, y_off, block_height)
            if not mask_bool.any():
                continue

            # NIR is read once per strip and reused for the NIR band itself.
            nir_block = self._read_window(self.NIR_band, 0, y_off, self.width, block_height)
            nir_sampled = nir_block[mask_bool][::sample_step].astype(np.float64)

            first_block = shift_nir is None
            if first_block:
                shift_nir = float(nir_sampled.mean())
            dx = nir_sampled - shift_nir
            count += dx.size
            sum_dx += float(dx.sum())
            sum_dx2 += float(np.dot(dx, dx))

            for b in range(n_bands):
                if b == self.NIR_band:
                    block = nir_block
                else:
                    block = self._read_window(b, 0, y_off, self.width, block_height)
                db = block[mask_bool][::sample_step].astype(np.float64)
                if first_block:
                    shift_band[b] = db.mean()
                db -= shift_band[b]
                sum_db[b] += db.sum()
                sum_dxdb[b] += np.dot(dx, db)

        if count == 0:
            self.corr_list = [0.0] * n_bands
        else:
            mean_dx = sum_dx / count
            var_nir = sum_dx2 / count - mean_dx * mean_dx
            if var_nir > 0:
                cov = sum_dxdb / count - mean_dx * (sum_db / count)
                self.corr_list = [float(c) for c in cov / var_nir]
            else:
                self.corr_list = [0.0] * n_bands

        print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")

    # ------------------------------------------------------------------ #
    # Correction
    # ------------------------------------------------------------------ #
    def _process_block(self, x_off, y_off, x_size, y_size):
        """
        Correct a single tile.

        Applies ``R - corr_list[b] * (NIR - R_min)`` to each band. Where a
        water mask is set, non-water pixels keep their original values.

        :returns: list of float32 np.ndarray, one corrected (y_size, x_size)
            array per band
        :raises RuntimeError: if the global statistics have not been computed
        """
        if self.R_min is None or self.corr_list is None:
            raise RuntimeError("请先计算全局统计量 (R_min 与协方差系数)")

        NIR = self._read_window(self.NIR_band, x_off, y_off, x_size, y_size)
        NIR_diff = NIR - self.R_min  # shared by every band

        water_mask = self._load_water_mask()  # cached after the first call
        mask_block = None
        if water_mask is not None:
            mask_block = water_mask[y_off:y_off + y_size, x_off:x_off + x_size].astype(bool)

        corrected_bands = []
        for b in range(self.n_bands):
            if b == self.NIR_band:
                R = NIR
            else:
                R = self._read_window(b, x_off, y_off, x_size, y_size)
            # Hedley correction: R_corrected = R - corr * (NIR - R_min)
            corrected = R - self.corr_list[b] * NIR_diff
            if mask_block is not None:
                corrected = np.where(mask_block, corrected, R)
            corrected_bands.append(corrected)

        return corrected_bands

    def get_corrected_bands(self):
        """
        Run the full block-wise correction and write the result to disk.

        The output is an ENVI Float32 ``.bsq`` file next to ``output_path``
        (the extension is replaced by ``.bsq`` if needed). The input dataset
        is closed on success, so this can only be called once per instance.

        :returns: an empty list (kept for backward compatibility; the data is
            written to the output file, not returned)
        :raises ValueError: if ``output_path`` is missing, the input is already
            closed, or the output file cannot be created
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
        if self.dataset is None:
            raise ValueError("输入影像已关闭,get_corrected_bands 只能调用一次")

        # Step 1: global R_min
        self._scan_global_stats(sample_step=20)

        # Step 2: per-band regression slopes
        self._compute_corr_list(sample_step=5)

        # Step 3: create the output file
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = self.output_path if ext.lower() == '.bsq' else base_path + '.bsq'

        driver = gdal.GetDriverByName('ENVI')
        out_dataset = driver.Create(bsq_path, self.width, self.height,
                                    self.n_bands, gdal.GDT_Float32)
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(self.dataset.GetGeoTransform())
            out_dataset.SetProjection(self.dataset.GetProjection())

            # Step 4: block-wise processing
            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")

            block_idx = 0
            for y_off in range(0, self.height, self.block_size):
                y_size = min(self.block_size, self.height - y_off)
                for x_off in range(0, self.width, self.block_size):
                    x_size = min(self.block_size, self.width - x_off)
                    block_idx += 1
                    print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")

                    corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
                    for b, data in enumerate(corrected_bands):
                        out_band = out_dataset.GetRasterBand(b + 1)
                        out_band.WriteArray(data, x_off, y_off)
                        out_band.FlushCache()
                    del corrected_bands
        finally:
            # Dropping the last reference closes/flushes the GDAL dataset,
            # including when an exception propagates.
            out_dataset = None

        self.dataset = None

        # GDAL's ENVI driver replaces the extension by default (foo.hdr); with
        # SUFFIX=ADD it appends (foo.bsq.hdr). Accept either.
        hdr_candidates = (os.path.splitext(bsq_path)[0] + '.hdr', bsq_path + '.hdr')
        if any(os.path.exists(p) for p in hdr_candidates):
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")

        return []

    def __del__(self):
        # getattr: __init__ may have raised before `dataset` was bound.
        if getattr(self, 'dataset', None) is not None:
            self.dataset = None