Files
WQ_GUI/src/core/glint_removal/SUGAR.py

636 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import os
from scipy import ndimage
from scipy.optimize import minimize_scalar
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def otsu_thresholding(im, auto_bins=None):
    """Pick a threshold from the histogram of *im* using an Otsu-style search.

    Non-finite values are ignored. The histogram is built with ``auto_bins``
    equal-width bins; when ``auto_bins`` is None it defaults to 0.5 % of the
    pixel count (``im`` must then be 2-D), but never fewer than 10.

    The search for the split that minimises the weighted within-class
    variance is restricted to the bins whose centre is negative. If at most
    one such bin exists, the centre of the most populated bin is returned
    instead.

    :param im: 2-D array of values to threshold.
    :param auto_bins: number of histogram bins, or None for the default.
    :return: the centre of the selected histogram bin.
    """
    if auto_bins is None:
        auto_bins = max(10, int(0.005 * im.shape[0] * im.shape[1]))

    values = im.ravel()
    finite = np.isfinite(values)
    if not finite.all():
        values = values[finite]

    counts, edges = np.histogram(values, bins=auto_bins)
    centers = (edges[:-1] + edges[1:]) * 0.5
    probs = counts / counts.sum()
    cum_probs = probs.cumsum()
    n_bins = counts.shape[0]
    n_negative = np.sum(centers < 0)
    levels = np.arange(n_bins, dtype=np.float32)

    # Nothing to search over: fall back to the histogram mode.
    if n_negative <= 1:
        return centers[np.argmax(counts)]

    def class_variance(weights, positions, mass):
        """Variance of bin positions within one class (0 for an empty class)."""
        mean = np.sum(weights * positions) / mass if mass > 0 else 0
        return np.sum(((positions - mean) ** 2) * weights) / mass if mass > 0 else 0

    def within_class_variance(x):
        """Weighted within-class variance when splitting at bin index int(x)."""
        split = int(x)
        mass_low = cum_probs[split]
        mass_high = cum_probs[n_bins - 1] - cum_probs[split]
        var_low = class_variance(probs[:split], levels[:split], mass_low)
        var_high = class_variance(probs[split:], levels[split:], mass_high)
        return var_low * mass_low + var_high * mass_high

    result = minimize_scalar(within_class_variance, bounds=(1, n_negative),
                             method='bounded')
    return centers[int(result.x)]
def cdf_thresholding(im, auto_bins=10):
    """Return the centre of the most populated histogram bin of *im*.

    Non-finite values are dropped before the histogram is built.

    :param im: array of values to threshold (any shape).
    :param auto_bins: number of equal-width histogram bins.
    :return: the centre of the bin holding the most samples.
    """
    samples = im.ravel()
    is_finite = np.isfinite(samples)
    if not is_finite.all():
        samples = samples[is_finite]

    counts, edges = np.histogram(samples, bins=auto_bins)
    centers = (edges[:-1] + edges[1:]) * 0.5
    mode_index = np.argmax(counts)
    return centers[mode_index]
class SUGAR:
    """SUGAR sun-glint removal, processed block by block and band by band.

    Pipeline (driven by :meth:`get_corrected_bands`):

    1. Scan the image in blocks and derive one global threshold per band from
       the Laplacian-of-Gaussian (LoG) response.
    2. Scan again and collect, per band, the values of all glint pixels
       together with their background estimate.
    3. Optimise one ``b`` per band by minimising the variance of the
       corrected glint pixels.
    4. Re-process every block with the optimised ``b`` and write the result
       into an ENVI output file.

    NOTE: step 1 keeps the LoG values of the whole image (or of all water
    pixels) in memory until the thresholds have been computed.
    """

    def __init__(self, img_path, bounds=None, sigma=1.0, estimate_background=True,
                 glint_mask_method="cdf", water_mask=None, output_path=None,
                 block_size=1000):
        """Open the input image and store the processing options.

        :param img_path (str): path of the input image.
        :param bounds: list of ``(low, high)`` optimisation bounds for ``b``;
            the list is repeated ``n_bands`` times. Defaults to ``[(1, 2)]``.
        :param sigma (float): sigma of the LoG filter.
        :param estimate_background (bool): estimate the background per pixel
            (see :meth:`_get_est_background`) instead of using the block's
            5th percentile.
        :param glint_mask_method (str): ``"otsu"`` selects Otsu thresholding;
            any other value selects the CDF method.
        :param water_mask: None, a raster path, or an array of the image size;
            values > 0 mark water.
        :param output_path (str): output file path.
        :param block_size (int): edge length of the square processing blocks.
        :raises ImportError: GDAL is not installed.
        :raises ValueError: the image cannot be opened or ``block_size`` is
            not positive.
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装无法读取影像文件")
        if block_size is None or int(block_size) <= 0:
            raise ValueError("block_size 必须为正整数")
        if bounds is None:
            bounds = [(1, 2)]

        self.img_path = img_path
        self.bounds = bounds
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.glint_mask_method = glint_mask_method
        self.water_mask = None  # filled lazily by _load_water_mask()
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = int(block_size)

        # Open the image.
        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount

        # Repeat the bounds so that there is an entry for every band.
        self.bounds_all = self.bounds * self.n_bands

        # Global optimisation results.
        self.b_list = None
        self.glint_pixel_indices = []
        self.thresholds = []  # one global LoG threshold per band

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _iter_blocks(self):
        """Yield ``(x_off, y_off, x_size, y_size)`` for every block, row-major."""
        for y_off in range(0, self.height, self.block_size):
            y_size = min(y_off + self.block_size, self.height) - y_off
            for x_off in range(0, self.width, self.block_size):
                x_size = min(x_off + self.block_size, self.width) - x_off
                yield x_off, y_off, x_size, y_size

    def _read_band_block(self, band_idx, x_off, y_off, x_size, y_size):
        """Read one block of the 0-based band *band_idx* as float32."""
        band = self.dataset.GetRasterBand(band_idx + 1)
        return band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)

    def _estimate_background_block(self, R_block):
        """Return a background array with the same shape as *R_block*.

        Without background estimation the block's 5th percentile is used; it
        is broadcast to a full array so callers can index it with a boolean
        mask (indexing the bare scalar raises IndexError).
        """
        if self.estimate_background:
            return self._get_est_background(R_block)
        level = np.percentile(R_block, 5, method='nearest')
        return np.full(R_block.shape, level, dtype=np.float32)

    @staticmethod
    def _apply_correction(R, R_bg, b):
        """Glint correction for glint pixels: ``R - (R / b - R_bg)``.

        Shared by the optimiser objective and the final write step so the
        optimised ``b`` is always applied with the formula it was fitted for.
        """
        return R - (R / b - R_bg)

    # ------------------------------------------------------------------
    # Masks, thresholds, background
    # ------------------------------------------------------------------
    def _load_water_mask(self):
        """Load the water mask on first use and cache it.

        :return: uint8 array (1 = water) of image size, or None when no mask
            was supplied or its type is neither ``str`` nor ``np.ndarray``.
        :raises ValueError: shapefile given, file cannot be opened, or the
            mask size differs from the image size.
        """
        source = self.water_mask_path
        if source is None:
            return None
        cached = getattr(self, "water_mask", None)
        if cached is not None:
            return cached

        mask = None
        if isinstance(source, np.ndarray):
            if source.shape[:2] != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (source > 0).astype(np.uint8)
        elif isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None
            if mask_array.shape != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            mask = (mask_array > 0).astype(np.uint8)

        self.water_mask = mask
        return mask

    def _compute_threshold(self, im):
        """Compute the glint threshold of *im* with the configured method."""
        if self.glint_mask_method == "otsu":
            return otsu_thresholding(im)
        return cdf_thresholding(im)

    def _get_glint_mask_block(self, band_data, band_idx=None, x_off=None, y_off=None):
        """Compute the LoG response and glint mask of one single-band block.

        A pixel is glint when its LoG response is below the band's global
        threshold and, if a water mask exists, it lies on water.

        :param band_data: 2-D block of one band.
        :param band_idx: 0-based band index; defaults to ``self._current_band``.
        :param x_off: block column offset; defaults to ``self._current_x``.
        :param y_off: block row offset; defaults to ``self._current_y``.
        :return: ``(log_im, glint_mask)`` with ``glint_mask`` as uint8.
        """
        if band_idx is None:
            band_idx = self._current_band
        log_im = ndimage.gaussian_laplace(band_data.astype(np.float32), sigma=self.sigma)
        glint_mask = (log_im < self.thresholds[band_idx]).astype(np.uint8)

        water_mask = self._load_water_mask()
        if water_mask is not None:
            if x_off is None:
                x_off = self._current_x
            if y_off is None:
                y_off = self._current_y
            rows, cols = band_data.shape[0], band_data.shape[1]
            glint_mask = glint_mask * water_mask[y_off:y_off + rows, x_off:x_off + cols]
        return log_im, glint_mask

    def _get_est_background(self, im, k_size=5):
        """Estimate the background by grey-scale erosion with an elliptical kernel.

        Erosion takes the local minimum inside the ``k_size`` x ``k_size``
        ellipse (it is not a median filter).
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        return cv2.erode(im.astype(np.float32), kernel)

    def _optimise_correction_band(self, R_glint, R_bg_glint, bounds):
        """Find the ``b`` that minimises the variance of the corrected glint pixels.

        The corrected value is ``R - (R / b - R_bg)``.

        :param R_glint: 1-D array with the values of all glint pixels.
        :param R_bg_glint: 1-D array with the matching background values.
        :param bounds: ``(low, high)`` search interval for ``b``.
        :return: the optimal ``b``; 1.0 when there are no glint pixels.
        """
        if len(R_glint) == 0:
            return 1.0
        R_glint = R_glint.astype(np.float32)
        R_bg_glint = R_bg_glint.astype(np.float32)

        def objective(b):
            return np.var(self._apply_correction(R_glint, R_bg_glint, float(b)))

        res = minimize_scalar(objective, bounds=bounds, method='bounded')
        return res.x

    # ------------------------------------------------------------------
    # Pipeline steps
    # ------------------------------------------------------------------
    def _scan_and_collect_glint(self):
        """Step 1: scan all blocks and compute one global LoG threshold per band.

        With a water mask only LoG values on water pixels contribute. A band
        without any contributing pixel gets the threshold 0.0.
        """
        print(f"[SUGAR] 步骤1: 扫描全图收集glint像素...")
        water_mask = self._load_water_mask()
        self.thresholds = [None] * self.n_bands
        log_collections = [[] for _ in range(self.n_bands)]

        for x_off, y_off, x_size, y_size in self._iter_blocks():
            mask_block = None
            if water_mask is not None:
                # The mask slice is shared by all bands of this block.
                mask_block = water_mask[y_off:y_off + y_size,
                                        x_off:x_off + x_size].astype(bool)
                if not mask_block.any():
                    continue  # no water here: nothing would be collected
            for b in range(self.n_bands):
                block = self._read_band_block(b, x_off, y_off, x_size, y_size)
                log_im = ndimage.gaussian_laplace(block, sigma=self.sigma)
                if mask_block is None:
                    log_collections[b].append(log_im.ravel())
                else:
                    log_collections[b].append(log_im[mask_block])
                del block, log_im

        # The global threshold needs the LoG values of the whole band.
        print(f"[SUGAR] 计算 {self.n_bands} 个波段的全局阈值...")
        for b in range(self.n_bands):
            if len(log_collections[b]) == 0:
                self.thresholds[b] = 0.0
            else:
                all_log = np.concatenate(log_collections[b])
                # Shape (1, N) mimics the 2-D input the threshold functions expect.
                thresh = self._compute_threshold(all_log.reshape(1, -1))
                self.thresholds[b] = float(thresh)
                del all_log
            print(f" 波段{b}: thresh={self.thresholds[b]:.4f}")
            log_collections[b] = None

    def _collect_glint_pixel_values(self):
        """Step 2: collect ``(R, R_bg)`` of every glint pixel, per band.

        Results are stored as 1-D arrays in ``self.R_glint_all`` and
        ``self.R_bg_glint_all`` (one entry per band).
        """
        print(f"[SUGAR] 步骤2: 收集glint像素值用于全局优化...")
        R_glint_list = [[] for _ in range(self.n_bands)]
        R_bg_glint_list = [[] for _ in range(self.n_bands)]

        for x_off, y_off, x_size, y_size in self._iter_blocks():
            for b in range(self.n_bands):
                R_block = self._read_band_block(b, x_off, y_off, x_size, y_size)
                log_im, glint_mask = self._get_glint_mask_block(
                    R_block, band_idx=b, x_off=x_off, y_off=y_off
                )
                glint_idx = glint_mask.astype(bool)
                if glint_idx.any():
                    R_bg = self._estimate_background_block(R_block)
                    R_glint_list[b].append(R_block[glint_idx])
                    R_bg_glint_list[b].append(R_bg[glint_idx])
                    del R_bg
                del R_block, log_im, glint_mask

        # Merge the per-block pieces into one array per band.
        self.R_glint_all = []
        self.R_bg_glint_all = []
        for b in range(self.n_bands):
            if len(R_glint_list[b]) == 0:
                self.R_glint_all.append(np.array([], dtype=np.float32))
                self.R_bg_glint_all.append(np.array([], dtype=np.float32))
            else:
                self.R_glint_all.append(np.concatenate(R_glint_list[b]))
                self.R_bg_glint_all.append(np.concatenate(R_bg_glint_list[b]))
            n = len(self.R_glint_all[b])
            print(f" 波段{b}: 收集到 {n} 个glint像素")
            R_glint_list[b] = None
            R_bg_glint_list[b] = None

    def _optimize_b_list(self):
        """Step 3: optimise ``b`` for every band and store it in ``self.b_list``."""
        print(f"[SUGAR] 步骤3: 全局优化 b 值...")
        self.b_list = []
        for b in range(self.n_bands):
            b_opt = self._optimise_correction_band(
                self.R_glint_all[b], self.R_bg_glint_all[b], self.bounds_all[b]
            )
            self.b_list.append(float(b_opt))
            print(f" 波段{b}: b={b_opt:.4f}")
        # The collected pixel values are no longer needed.
        self.R_glint_all = None
        self.R_bg_glint_all = None

    def _process_and_write_block(self, x_off, y_off, x_size, y_size, out_dataset):
        """Step 4: correct one block in every band and write it to *out_dataset*.

        Non-glint pixels are copied unchanged; glint pixels are replaced by
        ``R - (R / b - R_bg)`` using the band's optimised ``b``.
        """
        for b in range(self.n_bands):
            R_block = self._read_band_block(b, x_off, y_off, x_size, y_size)
            log_im, glint_mask = self._get_glint_mask_block(
                R_block, band_idx=b, x_off=x_off, y_off=y_off
            )
            glint_bool = glint_mask.astype(bool)

            R_corrected = R_block.copy()
            if glint_bool.any():
                R_bg = self._estimate_background_block(R_block)
                R_corrected[glint_bool] = self._apply_correction(
                    R_block[glint_bool], R_bg[glint_bool], self.b_list[b]
                )
                del R_bg

            out_band = out_dataset.GetRasterBand(b + 1)
            out_band.WriteArray(R_corrected, x_off, y_off)
            out_band.FlushCache()
            del R_block, log_im, glint_mask, R_corrected

    def get_corrected_bands(self):
        """Run the whole pipeline and write the corrected image to disk.

        The output is an ENVI raster (Float32). If ``output_path`` does not
        end in ``.bsq`` the extension is replaced by ``.bsq``. The input
        dataset is released afterwards, so the instance cannot be run twice.

        :return: always an empty list (results are written to file).
        :raises ValueError: ``output_path`` is missing or the output file
            cannot be created.
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path

        geotransform = self.dataset.GetGeoTransform()
        projection = self.dataset.GetProjection()
        driver = gdal.GetDriverByName('ENVI')
        out_dataset = driver.Create(bsq_path, self.width, self.height,
                                    self.n_bands, gdal.GDT_Float32)
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(geotransform)
            out_dataset.SetProjection(projection)

            self._scan_and_collect_glint()       # Step 1
            self._collect_glint_pixel_values()   # Step 2
            self._optimize_b_list()              # Step 3

            # Step 4: block-wise correction and writing.
            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[SUGAR] 步骤4: 分块处理写入,共 {total_blocks}")
            for block_idx, (x_off, y_off, x_size, y_size) in enumerate(
                    self._iter_blocks(), start=1):
                self._current_x = x_off
                self._current_y = y_off
                print(f"[SUGAR] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
                self._process_and_write_block(x_off, y_off, x_size, y_size, out_dataset)
        finally:
            # Dropping the reference closes the GDAL dataset and flushes it.
            out_dataset = None
        self.dataset = None

        # The ENVI driver replaces the data file suffix with ".hdr" by
        # default and appends it only with SUFFIX=ADD, so accept both names.
        hdr_candidates = (os.path.splitext(bsq_path)[0] + '.hdr', bsq_path + '.hdr')
        if any(os.path.exists(p) for p in hdr_candidates):
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件")
        return []

    def __del__(self):
        # getattr: __init__ may have failed before ``dataset`` was assigned.
        if getattr(self, "dataset", None) is not None:
            self.dataset = None
# ============================================================================
# Standalone function: correction_iterative (iterative version, supports large images)
# ============================================================================
def _sugar_header_candidates(data_path):
    """Return the two header names an ENVI data file may have (suffix replaced / appended)."""
    return (os.path.splitext(data_path)[0] + '.hdr', data_path + '.hdr')


def _sugar_remove_raster(data_path):
    """Delete an ENVI data file and whichever of its header files exist."""
    for path in (data_path,) + _sugar_header_candidates(data_path):
        if os.path.exists(path):
            os.remove(path)


def _sugar_band_variance(raster_path, block_size):
    """Variance of band 1 of *raster_path*, accumulated stripe by stripe.

    Stripes of ``block_size`` rows are merged with the pairwise mean/M2
    update, so the whole band never has to be held in memory.
    Returns 0.0 when the file cannot be opened or holds no pixels.
    """
    ds = gdal.Open(raster_path, gdal.GA_ReadOnly)
    if ds is None:
        return 0.0
    width, height = ds.RasterXSize, ds.RasterYSize
    band = ds.GetRasterBand(1)
    count, mean, m2 = 0, 0.0, 0.0
    for y_off in range(0, height, block_size):
        rows = min(y_off + block_size, height) - y_off
        data = band.ReadAsArray(0, y_off, width, rows).astype(np.float64).ravel()
        if data.size == 0:
            continue
        stripe_mean = float(data.mean())
        stripe_m2 = float(((data - stripe_mean) ** 2).sum())
        total = count + data.size
        delta = stripe_mean - mean
        m2 += stripe_m2 + delta * delta * count * data.size / total
        mean += delta * data.size / total
        count = total
    band = None
    ds = None
    return m2 / count if count else 0.0


def _sugar_promote_result(final_path, output_path):
    """Move the ENVI file *final_path* (and its header) to *output_path*.

    The header is stored as ``output_path + '.hdr'``.
    """
    import shutil

    if not os.path.exists(final_path):
        return
    if os.path.abspath(final_path) == os.path.abspath(output_path):
        return
    header_src = next(
        (p for p in _sugar_header_candidates(final_path) if os.path.exists(p)), None
    )
    if os.path.exists(output_path):
        os.remove(output_path)
    shutil.move(final_path, output_path)
    if header_src is not None:
        header_dst = output_path + '.hdr'
        if os.path.exists(header_dst):
            os.remove(header_dst)
        shutil.move(header_src, header_dst)


def correction_iterative(img_path, iter=3, bounds=None, estimate_background=True,
                         glint_mask_method="cdf", get_glint_mask=False,
                         termination_thresh=20.0, water_mask=None, output_path=None,
                         block_size=1000):
    """Run SUGAR repeatedly, feeding each pass's output into the next one.

    Every pass writes a temporary ``_sugar_iter_<i>.bsq`` file next to
    ``output_path`` (or into the current directory when no output path is
    given). Afterwards the last pass is moved to ``output_path`` and the
    intermediate files are deleted. Without ``output_path`` the last
    temporary file is left in place.

    :param img_path (str): path of the input image.
    :param iter (int or None): number of passes. None stops automatically
        once the relative variance reduction of band 1 between two
        consecutive passes drops below ``termination_thresh`` percent
        (at most 101 passes).
    :param bounds: optimisation bounds passed to :class:`SUGAR`.
    :param estimate_background: passed to :class:`SUGAR`.
    :param glint_mask_method: ``"cdf"`` or ``"otsu"``.
    :param get_glint_mask: deprecated and ignored; kept for compatibility.
    :param termination_thresh: automatic-stop threshold in percent.
    :param water_mask: water mask passed to :class:`SUGAR`.
    :param output_path: path of the final output file.
    :param block_size: block edge length passed to :class:`SUGAR`.
    :return: always an empty list (results are written to file).
    :raises ImportError: GDAL is not installed.
    :raises ValueError: the input image cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装")
    if bounds is None:
        bounds = [(1, 2)]

    # Fail early if the input cannot be read at all.
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    dataset = None

    temp_dir = os.path.dirname(output_path) if output_path else os.getcwd()
    temp_base = os.path.join(temp_dir, "_sugar_iter")
    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    def run_pass(source_path, index):
        """Run one SUGAR pass on *source_path*; return the path it wrote."""
        iter_output = f"{temp_base}_{index}.bsq"
        sugar = SUGAR(
            source_path,
            bounds=bounds,
            estimate_background=estimate_background,
            glint_mask_method=glint_mask_method,
            water_mask=water_mask,
            output_path=iter_output,
            block_size=block_size
        )
        sugar.get_corrected_bands()
        return iter_output

    outputs = []  # temporary files in the order they were produced
    glint_img_path = img_path

    if iter is None:
        max_iter = 100
        iter_count = 0
        sd_prev = None
        while True:
            glint_img_path = run_pass(glint_img_path, iter_count)
            outputs.append(glint_img_path)
            sd_current = _sugar_band_variance(glint_img_path, block_size)

            if iter_count > 0:
                # Zero or non-finite variance: the improvement cannot be
                # evaluated, so stop instead of dividing by zero or looping
                # until max_iter.
                if not np.isfinite(sd_prev) or not np.isfinite(sd_current) or sd_prev <= 0:
                    break
                marginal_diff = (sd_prev - sd_current) / sd_prev * 100
                if marginal_diff < termination_thresh:
                    break
            if iter_count >= max_iter:
                break
            sd_prev = sd_current
            iter_count += 1
    else:
        for i in range(iter):
            glint_img_path = run_pass(glint_img_path, i)
            outputs.append(glint_img_path)

    if outputs:
        # The last completed pass is the result; everything before it is scratch.
        final_path = outputs[-1]
        if output_path:
            _sugar_promote_result(final_path, output_path)
        for stale in outputs[:-1]:
            _sugar_remove_raster(stale)
    return []