Files
WQ_GUI/src/core/glint_removal/Hedley.py

307 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import os
# GDAL is an optional dependency: record whether it could be imported so that
# callers can fail with a clear message instead of an import-time crash.
try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
class Hedley:
    """Hedley sun-glint removal, processed block by block and band by band.

    Workflow (see ``get_corrected_bands``):

    1. scan the NIR band to estimate the global ambient NIR level ``R_min``
       (5th percentile of the sampled NIR pixels);
    2. compute, for every band, the regression slope
       ``Cov(NIR, band) / Var(NIR)`` from sampled pixels;
    3. write ``R - slope * (NIR - R_min)`` to an ENVI raster, block by block.

    When a water mask is supplied, statistics are taken only from pixels where
    the mask is > 0 and only those pixels are corrected; all other pixels are
    copied through unchanged.
    """

    def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
                 output_path=None, block_size=1000):
        """
        Open the input image and store the processing parameters.

        :param img_path (str): input image path, any format GDAL can open
        :param shp_path (str, optional): deep-water shapefile; deprecated and
            ignored, use ``water_mask`` instead
        :param NIR_band (int): zero-based index of the NIR band (default 47,
            documented by the original author as 842.36 nm)
        :param water_mask (np.ndarray or str or None): water mask, either an
            array or the path of a raster file; values > 0 mean "water"
        :param output_path (str): output file path (required by
            ``get_corrected_bands``, which writes blocks straight to disk)
        :param block_size (int): edge length of a processing block in pixels
        :raises ImportError: if GDAL is not installed
        :raises ValueError: if the image cannot be opened, ``block_size`` is
            not positive, or ``NIR_band`` is outside the image's band range
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装无法读取影像文件")
        self.img_path = img_path
        # Accept ints, floats and numeric strings (e.g. values from a GUI).
        self.NIR_band = int(float(NIR_band))
        self.water_mask = None  # cache filled by _load_water_mask()
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = int(float(block_size))
        if self.block_size <= 0:
            # A zero step makes range() raise and a negative one silently
            # processes nothing, so reject both up front.
            raise ValueError(f"block_size 必须为正整数: {block_size}")
        self.R_min = None
        self.corr_list = None  # per-band Cov(NIR, band) / Var(NIR)
        # Open the image.
        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount
        if not 0 <= self.NIR_band < self.n_bands:
            # Otherwise GetRasterBand() would hand back None much later.
            raise ValueError(
                f"NIR波段索引 {self.NIR_band} 超出范围 [0, {self.n_bands - 1}]"
            )

    def _load_water_mask(self):
        """Return the water mask as a uint8 0/1 array, or None if there is none.

        The mask is built on first use and cached in ``self.water_mask`` so
        that a mask file is read from disk only once, not once per block.

        :raises ValueError: if the mask is a shapefile path, cannot be opened,
            or does not match the image size
        """
        if self.water_mask is not None:
            return self.water_mask

        source = self.water_mask_path
        if source is None:
            return None
        if isinstance(source, os.PathLike):
            source = os.fspath(source)

        if isinstance(source, np.ndarray):
            if source.shape[:2] != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {source.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (source > 0).astype(np.uint8)
        elif isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None  # dropping the reference closes the file
            if mask_array.shape != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (mask_array > 0).astype(np.uint8)
        else:
            # Unsupported mask type: behave as if no mask was given.
            return None

        self.water_mask = mask
        return mask

    @staticmethod
    def _sample_pixels(block, mask_bool, sample_step):
        """Return every ``sample_step``-th selected pixel of ``block`` as 1-D.

        ``mask_bool`` selects pixels (None selects all, in row-major order).
        The result is copied so it does not keep the whole block alive.
        """
        selected = block.ravel() if mask_bool is None else block[mask_bool]
        return selected[::sample_step].copy()

    def covariance_NIR(self, NIR, b):
        """Return the regression slope ``Cov(NIR, b) / Var(NIR)``.

        :param NIR: 1-D samples of the NIR band
        :param b: samples of another band, aligned with ``NIR``
        :return: the slope, or 0.0 when ``NIR`` is empty or has zero variance
        """
        NIR = np.asarray(NIR)
        b = np.asarray(b)
        if NIR.size == 0:
            return 0.0
        nir_dev = NIR - np.mean(NIR)
        pjj = np.mean(nir_dev ** 2)
        if pjj == 0:
            return 0.0
        pij = np.mean(nir_dev * (b - np.mean(b)))
        return pij / pjj

    def _scan_global_stats(self, sample_step=20):
        """
        Scan the NIR band and set ``self.R_min``.

        The band is read in full-width strips of ``block_size`` rows and only
        every ``sample_step``-th selected pixel is kept, to limit memory use.
        ``R_min`` is the 5th percentile of the samples, or 0.0 if there are
        none.
        """
        print(f"[Hedley] 扫描全局统计量(采样步长={sample_step}...")
        water_mask = self._load_water_mask()
        nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
        nir_samples = []
        for y_off in range(0, self.height, self.block_size):
            y_end = min(y_off + self.block_size, self.height)
            nir_block = nir_band.ReadAsArray(0, y_off, self.width, y_end - y_off)
            if water_mask is None:
                mask_bool = None
            else:
                mask_bool = water_mask[y_off:y_end, :].astype(bool)
            sampled = self._sample_pixels(nir_block, mask_bool, sample_step)
            if sampled.size:
                nir_samples.append(sampled)
            del nir_block
        nir_band = None

        if not nir_samples:
            self.R_min = 0.0
        else:
            all_nir = np.concatenate(nir_samples)
            self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
            del all_nir
        print(f"[Hedley] 全局 R_min={self.R_min:.4f}")

    def _compute_corr_list(self, sample_step=5):
        """
        Compute ``self.corr_list[b] = Cov(NIR, band_b) / Var(NIR)`` per band.

        The image is read in full-width strips of ``block_size`` rows, one
        band at a time; every ``sample_step``-th selected pixel of each band
        is kept until the end, when the coefficients are computed. A smaller
        step than in ``_scan_global_stats`` is used by default because the
        regression needs more samples. All coefficients are 0.0 when no pixel
        is selected.
        """
        print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step}...")
        water_mask = self._load_water_mask()
        nir_samples = []
        band_samples = [[] for _ in range(self.n_bands)]

        for y_off in range(0, self.height, self.block_size):
            y_end = min(y_off + self.block_size, self.height)
            block_height = y_end - y_off
            if water_mask is None:
                mask_bool = None
            else:
                mask_bool = water_mask[y_off:y_end, :].astype(bool)
                if not mask_bool.any():
                    # Nothing to sample here: skip reading every band.
                    continue

            # The NIR strip is read once per block and reused for its own
            # entry in band_samples.
            nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
            nir_block = nir_band.ReadAsArray(
                0, y_off, self.width, block_height).astype(np.float32)
            nir_band = None
            nir_sampled = self._sample_pixels(nir_block, mask_bool, sample_step)
            del nir_block
            nir_samples.append(nir_sampled)

            for b in range(self.n_bands):
                if b == self.NIR_band:
                    band_samples[b].append(nir_sampled)
                    continue
                band = self.dataset.GetRasterBand(b + 1)
                block = band.ReadAsArray(
                    0, y_off, self.width, block_height).astype(np.float32)
                band = None
                band_samples[b].append(
                    self._sample_pixels(block, mask_bool, sample_step))
                del block

        if not nir_samples or sum(len(s) for s in nir_samples) == 0:
            self.corr_list = [0.0] * self.n_bands
        else:
            all_nir = np.concatenate(nir_samples)
            self.corr_list = []
            for b in range(self.n_bands):
                all_band = np.concatenate(band_samples[b])
                self.corr_list.append(float(self.covariance_NIR(all_nir, all_band)))
                band_samples[b] = None  # release this band's samples early
            del all_nir
        print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")

    def _process_block(self, x_off, y_off, x_size, y_size):
        """
        Correct one block of the image.

        Requires ``self.R_min`` and ``self.corr_list`` to have been computed.

        :param x_off, y_off: pixel offset of the block's upper-left corner
        :param x_size, y_size: block width and height in pixels
        :return: list of float32 arrays, one corrected array per band
        """
        nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
        NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
        nir_band = None
        NIR_diff = NIR - self.R_min

        water_mask = self._load_water_mask()  # cached after the first call
        if water_mask is None:
            mask_block = None
        else:
            mask_block = water_mask[y_off:y_off + y_size,
                                    x_off:x_off + x_size].astype(bool)

        corrected_bands = []
        for b in range(self.n_bands):
            if b == self.NIR_band:
                R = NIR  # already in memory, no need to read it again
            else:
                band = self.dataset.GetRasterBand(b + 1)
                R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
                band = None
            # Hedley correction: R_corrected = R - corr * (NIR - R_min)
            corrected = R - self.corr_list[b] * NIR_diff
            if mask_block is not None:
                # Pixels outside the mask keep their original value.
                corrected = np.where(mask_block, corrected, R)
            corrected_bands.append(corrected)
            del R
        del NIR, NIR_diff
        return corrected_bands

    def get_corrected_bands(self):
        """
        Run the whole correction and write the result to an ENVI raster.

        The output path is ``output_path`` with its extension replaced by
        ``.bsq`` (unless it already ends in ``.bsq``); geotransform and
        projection are copied from the input. The input dataset is released
        afterwards and reopened automatically on a later call.

        :return: always an empty list; the corrected bands are written to
            disk rather than kept in memory
        :raises ValueError: if ``output_path`` is None or the input/output
            file cannot be opened/created
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
        if self.dataset is None:
            # A previous run released the input; reopen it.
            self.dataset = gdal.Open(self.img_path, gdal.GA_ReadOnly)
            if self.dataset is None:
                raise ValueError(f"无法打开影像文件: {self.img_path}")

        # Step 1: global R_min.
        self._scan_global_stats(sample_step=20)
        # Step 2: per-band regression slopes.
        self._compute_corr_list(sample_step=5)

        # Step 3: create the output file.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = self.output_path if ext.lower() == '.bsq' else base_path + '.bsq'
        driver = gdal.GetDriverByName('ENVI')
        out_dataset = driver.Create(bsq_path, self.width, self.height,
                                    self.n_bands, gdal.GDT_Float32)
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(self.dataset.GetGeoTransform())
            out_dataset.SetProjection(self.dataset.GetProjection())

            # Step 4: process block by block.
            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
            block_idx = 0
            for y_off in range(0, self.height, self.block_size):
                y_size = min(self.block_size, self.height - y_off)
                for x_off in range(0, self.width, self.block_size):
                    x_size = min(self.block_size, self.width - x_off)
                    block_idx += 1
                    print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
                    corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
                    for b, corrected in enumerate(corrected_bands):
                        out_band = out_dataset.GetRasterBand(b + 1)
                        out_band.WriteArray(corrected, x_off, y_off)
                        out_band.FlushCache()
                    del corrected_bands
        finally:
            # Dropping the references closes the files, also on error.
            out_dataset = None
            self.dataset = None

        # The ENVI driver replaces the data file's suffix with ".hdr" by
        # default and appends it only with SUFFIX=ADD, so accept either name.
        hdr_candidates = (os.path.splitext(bsq_path)[0] + '.hdr', bsq_path + '.hdr')
        if any(os.path.exists(p) for p in hdr_candidates):
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件")
        return []

    def __del__(self):
        """Release the input dataset, tolerating a partially built instance."""
        # __init__ may have raised before self.dataset was ever assigned.
        if getattr(self, "dataset", None) is not None:
            self.dataset = None