Files
WQ_GUI/src/core/glint_removal/SUGAR.py
2026-04-08 15:25:08 +08:00

573 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import os
import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint
class SUGAR:
    """SUn-Glint-Aware Restoration (SUGAR) of a band-aligned reflectance image.

    Glint pixels are located per band by thresholding the Laplacian of
    Gaussian (LoG) of the band, and are replaced by a background estimate
    obtained with a morphological erosion of the band.
    """

    def __init__(self, im_aligned,bounds=[(1,2)],sigma=1,estimate_background=True, glint_mask_method="cdf", water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray): band aligned and calibrated & corrected
            reflectance image, indexed as (row, column, band)
        :param bounds (a list of tuple): lower and upper bound for optimisation
            of b; the list is repeated once per band and indexed by band
        :param sigma (float): smoothing sigma for LoG
        :param estimate_background (bool): used by ``optimise_correction`` only;
            True -> per-pixel background from ``get_est_background`` (an
            erosion), False -> a single 5th-percentile value per band
        :param glint_mask_method (str): choose either "cdf" or "otsu", "cdf" is
            set as the default
        :param water_mask (np.ndarray or str or None): water mask, non-zero is
            water and 0 is non-water. May be a numpy array or the path of a
            raster file readable by GDAL; shapefile paths are rejected.
            If None the whole image is processed.
        :param output_path (str or None): output file path; when given,
            ``get_corrected_bands`` saves the corrected image. None disables
            saving.
        """
        self.im_aligned = im_aligned
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.n_bands = im_aligned.shape[-1]
        # The default argument is never mutated: a new list is built here.
        self.bounds = list(bounds) * self.n_bands
        self.glint_mask_method = glint_mask_method
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path
        # Load / normalise the water mask last: it needs height and width.
        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """
        Normalise the water mask argument to a 0/1 uint8 array.

        :param water_mask: None, a numpy array, or a file path (str)
        :return: np.ndarray of uint8 (1 = water, 0 = non-water) or None
        :raises ValueError: mask size differs from the image size, GDAL is
            missing for a path input, the path is a shapefile, the raster
            cannot be opened, or the argument type is unsupported
        """
        if water_mask is None:
            return None

        # Already an array: only the first two dimensions are compared.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)

        if isinstance(water_mask, str):
            try:
                from osgeo import gdal
            except ImportError as err:
                raise ValueError("使用文件路径作为掩膜时必须安装GDAL") from err

            # A shapefile cannot be rasterised here: the image is a bare
            # array, so there is no raster reference to burn the polygons on.
            if water_mask.lower().endswith('.shp'):
                raise ValueError("SUGAR类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")

            mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {water_mask}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None  # dropping the reference closes the dataset
            if mask_array.shape != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def otsu_thresholding(self,im):
        """
        :param im (np.ndarray): LoG of a band; non-finite entries are ignored
        otsu thresholding with bounded minimisation of a univariate function
        returns the value of the threshold for input
        """
        values = im.ravel()
        finite = np.isfinite(values)
        if not finite.all():
            values = values[finite]

        # The bin count scales with the number of pixels; small images would
        # otherwise ask for zero bins, which np.histogram rejects.
        n_bins = max(1, int(0.005 * im.size))
        count, bin_edges = np.histogram(values, bins=n_bins)
        centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        N = count.shape[0]
        N_negative = np.sum(centers < 0)

        # Glint pixels have L < 0, so the search is limited to negative bins.
        # Without at least two of them, fall back to the histogram mode.
        # Returning here also avoids normalising an all-zero histogram.
        if N_negative <= 1:
            return centers[np.argmax(count)]

        hist_norm = count / count.sum()   # normalised histogram
        Q = hist_norm.cumsum()            # CDF, ranges from 0 to 1
        bins = np.arange(N, dtype=np.float32)

        def otsu_thresh(x):
            x = int(x)
            p1 = hist_norm[:x]
            p2 = hist_norm[x:]
            # NOTE(review): Q[x] includes bin x while p1 excludes it; kept as
            # in the original formulation so thresholds do not shift.
            q1 = Q[x]
            q2 = Q[N-1] - Q[x]
            b1 = bins[:x]
            b2 = bins[x:]
            # class means and variances (0 for an empty class)
            m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
            m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
            v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
            v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
            # weighted within-class variance, the quantity to minimise
            return v1 * q1 + v2 * q2

        res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
        return centers[int(res.x)]

    def cdf_thresholding(self,im,auto_bins=10):
        """
        :param im (np.ndarray): LoG of a band; non-finite entries are ignored
        :param auto_bins (int): number of histogram bins
        returns the centre of the most populated histogram bin
        """
        values = im.ravel()
        finite = np.isfinite(values)
        if not finite.all():
            values = values[finite]
        count, bin_edges = np.histogram(values, bins=auto_bins)
        centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        return centers[np.argmax(count)]

    def glint_list(self):
        """
        returns a list of np.ndarray, where each item is the band multiplied by
        its glint mask (i.e. the extracted glint of that band)
        """
        masks = self.glint_mask_list()
        return [masks[i] * self.im_aligned[:,:,i]
                for i in range(self.im_aligned.shape[-1])]

    def glint_mask_list(self):
        """
        get the glint mask of every band using the laplacian of gaussian image.
        returns a list of np.ndarray
        """
        return [self.get_glint_mask(self.im_aligned[:,:,i])
                for i in range(self.im_aligned.shape[-1])]

    def log_image_list(self):
        """
        get Laplacian of Gaussian (LoG) images for all bands.
        returns a list of np.ndarray
        """
        return [self.get_log_image(self.im_aligned[:,:,i])
                for i in range(self.im_aligned.shape[-1])]

    def get_log_image(self, im):
        """
        get Laplacian of Gaussian (LoG) image for a single band.
        returns a np.ndarray
        """
        return ndimage.gaussian_laplace(im, sigma=self.sigma)

    def get_glint_mask(self,im):
        """
        get glint mask using laplacian of gaussian image.
        We assume that water constituents and features follow a smooth continuum,
        but glint pixels vary a lot spatially and in intensities
        Note that for very extensive glint, this method may not work as well <--:TODO use U-net to identify glint mask
        returns a np.ndarray of uint8 (1 = glint, 0 = no glint)
        :raises ValueError: glint_mask_method is neither "cdf" nor "otsu"
        """
        LoG_im = self.get_log_image(im)

        # With a water mask, only water pixels may drive the threshold.
        # Non-water pixels become NaN, which both threshold routines drop;
        # a large sentinel value would still be counted in the histogram
        # (stretching its range and possibly becoming its mode).
        LoG_for_thresh = LoG_im
        if self.water_mask is not None:
            mask_bool = self.water_mask.astype(bool)
            if mask_bool.any():
                LoG_for_thresh = np.where(mask_bool, LoG_im, np.nan)

        if (self.glint_mask_method == "otsu"):
            thresh = self.otsu_thresholding(LoG_for_thresh)
        elif (self.glint_mask_method == "cdf"):
            thresh = self.cdf_thresholding(LoG_for_thresh)
        else:
            raise ValueError('Enter only cdf or otsu as glint_mask_method')

        glint_mask = (LoG_im < thresh).astype(np.uint8)
        # Non-water pixels are never reported as glint.
        if self.water_mask is not None:
            glint_mask = glint_mask * self.water_mask
        return glint_mask

    def get_est_background(self, im,k_size=5):
        """
        :param im (np.ndarray): image of a band
        :param k_size (int): side length of the elliptical structuring element
        estimate background spectra by eroding the band
        returns a np.ndarray
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(k_size,k_size))
        return cv2.erode(im, kernel)

    def optimise_correction_by_band(self,im,glint_mask,R_BG,bounds):
        """
        :param im (np.ndarray): image of a band
        :param glint_mask (np.ndarray): glint mask, where glint area is 1 and non-glint area is 0
        :param R_BG: background estimate, either a scalar (Python or NumPy)
            or an array shaped like im
        :param bounds (tuple): lower and upper bound for b
        use bounded scalar minimisation to get the optimum b which minimises
        the variance of the corrected band
        returns regression slope b
        """
        glint_px = glint_mask.astype(bool)
        # Everything that does not depend on b is gathered once.
        im_glint = im[glint_px]
        weights = glint_mask[glint_px]
        # np.ndim also recognises NumPy scalars such as np.float32, which are
        # not instances of the Python float type.
        bg_glint = R_BG if np.ndim(R_BG) == 0 else np.asarray(R_BG)[glint_px]

        def optimise_b(b):
            # Only glint pixels change; the rest of the band is kept as is.
            im_corrected = im.copy()
            im_corrected[glint_px] = im_glint - weights * (im_glint / b - bg_glint)
            return np.var(im_corrected)

        res = minimize_scalar(optimise_b, bounds=bounds, method='bounded')
        return res.x

    def divide_and_conquer(self):
        """
        instead of computing b_list for each window, use the previous b_list to narrow the bounds,
        because of the strong spatial autocorrelation, we know that the b (correction magnitude) cannot diff too much
        this can optimise the run time

        Placeholder: not implemented, returns None.
        """

    @staticmethod
    def _nearest_percentile(band, q):
        """Return the q-th percentile of band, picking the nearest sample."""
        try:
            return np.percentile(band, q, method='nearest')
        except TypeError:
            # Older NumPy releases only know the keyword 'interpolation'.
            return np.percentile(band, q, interpolation='nearest')

    def optimise_correction(self):
        """
        returns a list of slope in band order i.e. 0,1,2,3,4,5,6,7,8,9 through optimisation,
        together with the glint masks and the background estimates of all bands.
        The three lists are also stored as self.b_list, self.glint_mask and
        self.est_background.
        """
        b_list = []
        glint_mask_list = []
        est_background_list = []
        for i in range(self.n_bands):
            band = self.im_aligned[:,:,i]
            glint_mask = self.get_glint_mask(band)
            if self.estimate_background is True:
                est_background = self.get_est_background(band)
            else:
                est_background = self._nearest_percentile(band, 5)
            b = self.optimise_correction_by_band(band, glint_mask, est_background, self.bounds[i])
            glint_mask_list.append(glint_mask)
            est_background_list.append(est_background)
            b_list.append(b)
        # add attributes
        self.b_list = b_list
        self.glint_mask = glint_mask_list
        self.est_background = est_background_list
        return b_list, glint_mask_list, est_background_list

    def _save_corrected_bands(self, corrected_bands):
        """
        Save the corrected bands to self.output_path (ENVI format, BSQ).

        :param corrected_bands: list of corrected bands
        Does nothing when self.output_path is None. The writing itself is done
        by the module-level function _save_corrected_image.
        """
        if self.output_path is None:
            return
        _save_corrected_image(np.stack(corrected_bands, axis=2), self.output_path)

    def get_corrected_bands(self):
        """
        Replace the glint pixels of every band by the eroded background.

        :return: list of corrected bands
        When self.output_path is set the result is also written to disk.
        """
        corrected_bands = []
        for i in range(self.n_bands):
            im_band = self.im_aligned[:,:,i]
            # get_glint_mask already zeroes non-water pixels, so the mask can
            # be applied directly whether or not a water mask exists.
            glint_px = self.get_glint_mask(im_band).astype(bool)
            background = self.get_est_background(im_band, k_size=5)
            corrected_bands.append(np.where(glint_px, background, im_band))
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)
        return corrected_bands
def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None):
    """
    conducts iterative correction using SUGAR

    :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
    :param iter (int or None): number of iterations to run the sugar algorithm. If None, termination conditions are automatically applied
    :param bounds (list of tuples): to limit correction magnitude
    :param estimate_background (bool): forwarded to SUGAR
    :param glint_mask_method (str): "cdf" or "otsu", forwarded to SUGAR
    :param get_glint_mask (bool): kept for backward compatibility; it has no
        effect on the returned images
    :param termination_thresh (float): percentage used by both automatic
        termination conditions (only when iter is None)
    :param water_mask (np.ndarray or str or None): water mask, non-zero is
        water and 0 is non-water; forwarded to SUGAR. None processes the
        whole image.
    :param output_path (str or None): output file path; when given, the result
        of the last iteration is saved. None disables saving.
    :return: list with the corrected image of every iteration, in order
    """
    def run_pass(image):
        # One SUGAR pass; np.stack always yields a new array, so the result
        # can be stored in the list without an extra copy.
        model = SUGAR(image, bounds, estimate_background=estimate_background,
                      glint_mask_method=glint_mask_method, water_mask=water_mask)
        return np.stack(model.get_corrected_bands(), axis=2)

    corrected_images = []
    current = im_aligned

    if iter is None:
        max_iter = 100  # hard cap so the loop always terminates
        var_original = np.var(im_aligned)
        var_prev = var_original
        passes = 0
        # Continue while the variance, as a percentage of the original
        # variance, is above the threshold. A zero-variance image has
        # nothing to correct (and would divide by zero).
        while (var_original > 0 and passes < max_iter
               and var_prev / var_original * 100 > termination_thresh):
            current = run_pass(current)
            corrected_images.append(current)
            var_now = np.var(current)
            # Stop once a pass removes less than the threshold percentage of
            # the variance that was left before it.
            if var_prev <= 0 or (var_prev - var_now) / var_prev * 100 < termination_thresh:
                break
            var_prev = var_now
            passes += 1
    else:
        for _ in range(iter):
            current = run_pass(current)
            corrected_images.append(current)

    # Only the result of the last iteration is written to disk.
    if output_path is not None and corrected_images:
        _save_corrected_image(corrected_images[-1], output_path)
    return corrected_images
def _save_corrected_image(corrected_image, output_path):
"""
保存校正后的图像到文件用于correction_iterative函数BSQ格式ENVI格式
:param corrected_image: 校正后的图像数组,形状为(height, width, bands)
:param output_path: 输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if output_path is None:
return
# 确保输出目录存在
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 如果没有地理信息,使用默认值
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# 强制使用ENVI格式BSQ格式确保文件扩展名为.bsq
base_path, ext = os.path.splitext(output_path)
# 如果扩展名不是.bsq使用基础路径添加.bsq
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
else:
bsq_path = output_path
# 使用ENVI驱动默认就是BSQ格式
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
height, width, n_bands = corrected_image.shape
# 创建ENVI格式数据集会自动生成.hdr文件
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
try:
# 设置地理变换和投影
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# 写入每个波段BSQ格式按波段顺序存储
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_image[:, :, i])
band.FlushCache()
finally:
dataset = None
# 检查.hdr文件是否已创建
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"头文件已保存至: {hdr_path}")
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")