Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

View File

@ -0,0 +1,572 @@
import cv2
import os
import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint
class SUGAR:
def __init__(self, im_aligned,bounds=[(1,2)],sigma=1,estimate_background=True, glint_mask_method="cdf", water_mask=None, output_path=None):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param bounds (a list of tuple): lower and upper bound for optimisation of b for each band
:param sigma (float): smoothing sigma for LoG
:param estimate_background (bool): whether to estimate background spectra using median filtering
:param glint_mask_method (str): choose either "cdf" or "otsu", "cdf" is set as the default
:param water_mask (np.ndarray or str or None): 水域掩膜1表示水域0表示非水域
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
如果为None则处理全图
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
如果为None则不保存
"""
self.im_aligned = im_aligned
self.sigma = sigma
self.estimate_background = estimate_background
self.n_bands = im_aligned.shape[-1]
self.bounds = bounds*self.n_bands
self.glint_mask_method = glint_mask_method
self.height = im_aligned.shape[0]
self.width = im_aligned.shape[1]
self.output_path = output_path
# 加载水域掩膜
self.water_mask = self._load_water_mask(water_mask)
def _load_water_mask(self, water_mask):
"""
加载水域掩膜
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
:return: numpy数组或None1表示水域0表示非水域
"""
if water_mask is None:
return None
# 如果已经是numpy数组
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
# 如果是文件路径
if isinstance(water_mask, str):
try:
from osgeo import gdal, ogr
except ImportError:
raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
# 检查是否为shapefile
if water_mask.lower().endswith('.shp'):
# 从shp文件创建掩膜需要参考图像这里假设使用im_aligned的尺寸
# 注意如果输入是numpy数组无法从shp创建掩膜需要提供栅格参考
raise ValueError("SUGAR类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
else:
# 栅格文件
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {water_mask}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def otsu_thresholding(self,im):
"""
:param im (np.ndarray) of shape mxn. Note that it is the LoG of image
otsu thresholding with Brent's minimisation of a univariate function
returns the value of the threshold for input
"""
auto_bins = int(0.005*im.shape[0]*im.shape[1])
# 使用ravel()而不是flatten(),避免不必要的复制(如果可能)
# 如果存在无效值如NaN或极大值过滤掉它们
im_flat = im.ravel()
# 过滤掉NaN和无穷大值
valid_mask = np.isfinite(im_flat)
if not valid_mask.all():
im_flat = im_flat[valid_mask]
count, bin_edges = np.histogram(im_flat, bins=auto_bins)
bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers使用乘法替代除法
count_sum = count.sum()
hist_norm = count / count_sum # normalised histogram
Q = hist_norm.cumsum() # CDF function ranges from 0 to 1
N = count.shape[0]
N_negative = np.sum(bin < 0)
bins = np.arange(N, dtype=np.float32) # 使用float32减少内存
def otsu_thresh(x):
x = int(x)
# 使用切片而不是hsplit避免创建新数组
p1 = hist_norm[:x]
p2 = hist_norm[x:]
q1 = Q[x]
q2 = Q[N-1] - Q[x]
b1 = bins[:x]
b2 = bins[x:]
# finding means and variances
m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
# calculates the minimization function
fn = v1 * q1 + v2 * q2
return fn
# brent method is used to minimise an univariate function
# bounded minimisation
# we can just limit the search to negative values since we know thresh should be negative as L<0 for glint pixels
if N_negative <= 1:
# 如果没有足够的负值,使用默认阈值
return bin[np.argmax(count)]
res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
thresh = bin[int(res.x)]
return thresh
# def cdf_thresholding(self,im, percentile=0.05):
# """
# :param im (np.ndarray) of shape mxn
# :param percentile (float): lower and upper percentile values are potential glint pixels
# """
# lower_perc = percentile
# upper_perc = 1-percentile
# im_flatten = im.flatten()
# H,X1 = np.histogram(im_flatten, bins = int(0.005*im.shape[0]*im.shape[1]), density=True )
# dx = X1[1] - X1[0]
# F1 = np.cumsum(H)*dx
# F_lower = X1[1:][F1<lower_perc]
# F_upper = X1[1:][F1>upper_perc]
# while((F_lower.size == 0) or (F_upper.size == 0)):
# if (F_lower.size == 0):
# lower_perc += 0.01
# F_lower = X1[1:][F1<lower_perc]
# if (F_upper.size == 0):
# upper_perc -= 0.01
# F_upper = X1[1:][F1>upper_perc]
# lower_thresh = F_lower[-1]
# upper_thresh = F_upper[0]
# return lower_thresh,upper_thresh
def cdf_thresholding(self,im,auto_bins=10):
"""
:param im (np.ndarray) of shape mxn. Note that it is the LoG of image
:param percentile (float): lower and upper percentile values are potential glint pixels
"""
# 使用ravel()而不是flatten(),避免不必要的复制
im_flat = im.ravel()
# 过滤掉NaN和无穷大值
valid_mask = np.isfinite(im_flat)
if not valid_mask.all():
im_flat = im_flat[valid_mask]
count, bin_edges = np.histogram(im_flat, bins=auto_bins)
bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers使用乘法替代除法
thresh = bin[np.argmax(count)]
return thresh
def glint_list(self):
"""
returns a list of np.ndarray, where each item is an extracted glint for each band based on get_glint_mask
"""
glint_mask = self.glint_mask_list()
extracted_glint_list = []
for i in range(self.im_aligned.shape[-1]):
gm = glint_mask[i]
extracted_glint = gm*self.im_aligned[:,:,i]
extracted_glint_list.append(extracted_glint)
return extracted_glint_list
def glint_mask_list(self):
"""
get glint mask using laplacian of gaussian image.
returns a list of np.ndarray
"""
glint_mask_list = []
for i in range(self.im_aligned.shape[-1]):
glint_mask = self.get_glint_mask(self.im_aligned[:,:,i])
glint_mask_list.append(glint_mask)
return glint_mask_list
def log_image_list(self):
"""
get Laplacian of Gaussian (LoG) images for all bands.
returns a list of np.ndarray
"""
log_image_list = []
for i in range(self.im_aligned.shape[-1]):
log_im = self.get_log_image(self.im_aligned[:,:,i])
log_image_list.append(log_im)
return log_image_list
def get_log_image(self, im):
"""
get Laplacian of Gaussian (LoG) image for a single band.
returns a np.ndarray
"""
LoG_im = ndimage.gaussian_laplace(im, sigma=self.sigma)
return LoG_im
def get_glint_mask(self,im):
"""
get glint mask using laplacian of gaussian image.
We assume that water constituents and features follow a smooth continuum,
but glint pixels vary a lot spatially and in intensities
Note that for very extensive glint, this method may not work as well <--:TODO use U-net to identify glint mask
returns a np.ndarray
"""
LoG_im = ndimage.gaussian_laplace(im,sigma=self.sigma)
# 如果存在水域掩膜,只在掩膜内计算阈值
if self.water_mask is not None:
mask_bool = self.water_mask.astype(bool)
if mask_bool.any():
# 只在掩膜内提取LoG值用于阈值计算
LoG_masked = LoG_im[mask_bool]
# 将非掩膜区域设为极大值,确保不影响阈值计算
LoG_for_thresh = LoG_im.copy()
LoG_for_thresh[~mask_bool] = LoG_masked.max() + 1
else:
LoG_for_thresh = LoG_im
else:
LoG_for_thresh = LoG_im
#threshold mask
if (self.glint_mask_method == "otsu"):
thresh = self.otsu_thresholding(LoG_for_thresh)
elif (self.glint_mask_method == "cdf"):
thresh = self.cdf_thresholding(LoG_for_thresh)
else:
raise ValueError('Enter only cdf or otsu as glint_mask_method')
# 使用更高效的方式创建mask避免np.where的开销
glint_mask = (LoG_im < thresh).astype(np.uint8)
# 如果存在水域掩膜将非水域区域设为0
if self.water_mask is not None:
glint_mask = glint_mask * self.water_mask
return glint_mask
def get_est_background(self, im,k_size=5):
"""
:param im (np.ndarray): image of a band
estimate background spectra
returns a np.ndarray
"""
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(k_size,k_size))
dst = cv2.erode(im, kernel)
return dst
def optimise_correction_by_band(self,im,glint_mask,R_BG,bounds):
"""
:param im (np.ndarray): image of a band
:param glint_mask (np.ndarray): glint mask, where glint area is 1 and non-glint area is 0
use brent method to get the optimimum b which minimises the variation (i.e. variance) in the entire image
returns regression slope b
"""
# 预计算常量,避免在优化函数中重复计算
glint_mask_bool = glint_mask.astype(bool)
R_BG_flat = R_BG if isinstance(R_BG, (int, float)) else R_BG[glint_mask_bool]
def optimise_b(b):
# 优化计算只在glint区域计算校正
if isinstance(R_BG, (int, float)):
im_corrected = im.copy()
im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG)
else:
im_corrected = im.copy()
im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG[glint_mask_bool])
return np.var(im_corrected)
res = minimize_scalar(optimise_b, bounds=bounds, method='bounded')
return res.x
def divide_and_conquer(self):
"""
instead of computing b_list for each window, use the previous b_list to narrow the bounds,
because of the strong spatial autocorrelation, we know that the b (correction magnitude) cannot diff too much
this can optimise the run time
"""
def optimise_correction(self):
"""
returns a list of slope in band order i.e. 0,1,2,3,4,5,6,7,8,9 through optimisation
"""
b_list = []
glint_mask_list = []
est_background_list = []
for i in range(self.n_bands):
glint_mask = self.get_glint_mask(self.im_aligned[:,:,i])
glint_mask_list.append(glint_mask)
if self.estimate_background is True:
est_background = self.get_est_background(self.im_aligned[:,:,i])
est_background_list.append(est_background)
else:
est_background = np.percentile(self.im_aligned[:,:,i], 5, interpolation='nearest')
est_background_list.append(est_background)
bounds = self.bounds[i]
b = self.optimise_correction_by_band(self.im_aligned[:,:,i],glint_mask,est_background,bounds)
b_list.append(b)
# add attributes
self.b_list = b_list
self.glint_mask = glint_mask_list
self.est_background = est_background_list
return b_list, glint_mask_list, est_background_list
def _save_corrected_bands(self, corrected_bands):
"""
保存校正后的波段到文件BSQ格式ENVI格式
:param corrected_bands: 校正后的波段列表
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if self.output_path is None:
return
# 确保输出目录存在
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 将波段列表转换为数组
corrected_array = np.stack(corrected_bands, axis=2)
# 如果没有地理信息,使用默认值
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# 强制使用ENVI格式BSQ格式确保文件扩展名为.bsq
base_path, ext = os.path.splitext(self.output_path)
# 如果扩展名不是.bsq使用基础路径添加.bsq
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
else:
bsq_path = self.output_path
# 使用ENVI驱动默认就是BSQ格式
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
height, width, n_bands = corrected_array.shape
# 创建ENVI格式数据集会自动生成.hdr文件
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
try:
# 设置地理变换和投影
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# 写入每个波段BSQ格式按波段顺序存储
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_array[:, :, i])
band.FlushCache()
finally:
dataset = None
# 检查.hdr文件是否已创建
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"头文件已保存至: {hdr_path}")
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")
def get_corrected_bands(self):
"""
获取校正后的波段
:return: 校正后的波段列表
"""
corrected_bands = []
# 获取水域掩膜(如果存在)
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
for i in range(self.n_bands):
im_band = self.im_aligned[:,:,i]
# 一次性计算mask和background避免重复计算
glint_mask = self.get_glint_mask(im_band)
background = self.get_est_background(im_band, k_size=5)
# 使用视图和原地操作减少内存
im_corrected = im_band.copy()
glint_mask_bool = glint_mask.astype(bool)
im_corrected[glint_mask_bool] = background[glint_mask_bool]
# 如果存在水域掩膜,确保只在水域内应用校正
if water_mask_bool is not None:
# 只在水域掩膜内应用校正
correction_mask = glint_mask_bool & water_mask_bool
im_corrected = np.where(correction_mask, background, im_band)
# 非水域区域保持原值
im_corrected = np.where(water_mask_bool, im_corrected, im_band)
corrected_bands.append(im_corrected)
# 如果提供了输出路径,保存结果
if self.output_path is not None:
self._save_corrected_bands(corrected_bands)
return corrected_bands
def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param iter (int or None): number of iterations to run the sugar algorithm. If None, termination conditions are automatically applied
:param bounds (list of tuples): to limit correction magnitude
:param get_glint_mask (np.ndarray):
:param water_mask (np.ndarray or str or None): 水域掩膜1表示水域0表示非水域
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
如果为None则处理全图
:param output_path (str or None): 输出文件路径,如果提供则保存最后一次迭代的校正结果
如果为None则不保存
conducts iterative correction using SUGAR
"""
glint_image = im_aligned.copy()
corrected_images = []
if iter is None:
# termination conditions
relative_difference = lambda sd0,sd1: sd1/sd0*100
marginal_difference = lambda sd1,sd2: (sd1-sd2)/sd1*100
relative_diff_thresh = marginal_difference_thresh = termination_thresh
sd_og = np.var(im_aligned)
iter_count = 0
sd_next = sd_og # 不需要copy直接使用值
max_iter = 100 # 添加最大迭代次数限制,防止无限循环
while ((relative_difference(sd_og,sd_next) > relative_diff_thresh) and iter_count < max_iter):
# do all the processing here
HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
corrected_bands = HM.get_corrected_bands()
glint_image = np.stack(corrected_bands,axis=2)
sd_temp = np.var(glint_image)
# 只在需要时保存中间结果,减少内存占用
if get_glint_mask or iter_count == 0:
corrected_images.append(glint_image.copy())
else:
corrected_images.append(glint_image) # 最后一次迭代的结果
# save glint_mask
# if iter_count == 0 and get_glint_mask is True:
# glint_mask = np.stack(HM.glint_mask,axis=2)
if (marginal_difference(sd_next,sd_temp)<marginal_difference_thresh):
break
else:
sd_next = sd_temp
#increase count
iter_count += 1
# 如果提供了输出路径,保存最后一次迭代的结果
if output_path is not None and len(corrected_images) > 0:
_save_corrected_image(corrected_images[-1], output_path)
else:
for i in range(iter):
HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
corrected_bands = HM.get_corrected_bands()
glint_image = np.stack(corrected_bands,axis=2)
# 只在最后一次迭代或需要时保存所有结果
if i == iter - 1 or get_glint_mask:
corrected_images.append(glint_image.copy())
else:
# 对于中间迭代,可以只保存引用(但要注意内存管理)
corrected_images.append(glint_image)
# save glint_mask
# if i == 0 and get_glint_mask is True:
# glint_mask = np.stack(HM.glint_mask,axis=2)
# 如果提供了输出路径,保存最后一次迭代的结果
if output_path is not None and len(corrected_images) > 0:
_save_corrected_image(corrected_images[-1], output_path)
return corrected_images
def _save_corrected_image(corrected_image, output_path):
"""
保存校正后的图像到文件用于correction_iterative函数BSQ格式ENVI格式
:param corrected_image: 校正后的图像数组,形状为(height, width, bands)
:param output_path: 输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if output_path is None:
return
# 确保输出目录存在
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# 如果没有地理信息,使用默认值
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# 强制使用ENVI格式BSQ格式确保文件扩展名为.bsq
base_path, ext = os.path.splitext(output_path)
# 如果扩展名不是.bsq使用基础路径添加.bsq
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
else:
bsq_path = output_path
# 使用ENVI驱动默认就是BSQ格式
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
height, width, n_bands = corrected_image.shape
# 创建ENVI格式数据集会自动生成.hdr文件
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
try:
# 设置地理变换和投影
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# 写入每个波段BSQ格式按波段顺序存储
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_image[:, :, i])
band.FlushCache()
finally:
dataset = None
# 检查.hdr文件是否已创建
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"头文件已保存至: {hdr_path}")
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")