Initial commit of WQ_GUI
This commit is contained in:
572
src/core/glint_removal/SUGAR.py
Normal file
572
src/core/glint_removal/SUGAR.py
Normal file
@ -0,0 +1,572 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from scipy import ndimage
|
||||
from scipy.optimize import minimize_scalar
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint
|
||||
class SUGAR:
|
||||
def __init__(self, im_aligned,bounds=[(1,2)],sigma=1,estimate_background=True, glint_mask_method="cdf", water_mask=None, output_path=None):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param bounds (a list of tuple): lower and upper bound for optimisation of b for each band
|
||||
:param sigma (float): smoothing sigma for LoG
|
||||
:param estimate_background (bool): whether to estimate background spectra using median filtering
|
||||
:param glint_mask_method (str): choose either "cdf" or "otsu", "cdf" is set as the default
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.sigma = sigma
|
||||
self.estimate_background = estimate_background
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.bounds = bounds*self.n_bands
|
||||
self.glint_mask_method = glint_mask_method
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("SUGAR类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def otsu_thresholding(self,im):
|
||||
"""
|
||||
:param im (np.ndarray) of shape mxn. Note that it is the LoG of image
|
||||
otsu thresholding with Brent's minimisation of a univariate function
|
||||
returns the value of the threshold for input
|
||||
"""
|
||||
auto_bins = int(0.005*im.shape[0]*im.shape[1])
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制(如果可能)
|
||||
# 如果存在无效值(如NaN或极大值),过滤掉它们
|
||||
im_flat = im.ravel()
|
||||
# 过滤掉NaN和无穷大值
|
||||
valid_mask = np.isfinite(im_flat)
|
||||
if not valid_mask.all():
|
||||
im_flat = im_flat[valid_mask]
|
||||
count, bin_edges = np.histogram(im_flat, bins=auto_bins)
|
||||
bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers,使用乘法替代除法
|
||||
|
||||
count_sum = count.sum()
|
||||
hist_norm = count / count_sum # normalised histogram
|
||||
Q = hist_norm.cumsum() # CDF function ranges from 0 to 1
|
||||
N = count.shape[0]
|
||||
N_negative = np.sum(bin < 0)
|
||||
bins = np.arange(N, dtype=np.float32) # 使用float32减少内存
|
||||
|
||||
def otsu_thresh(x):
|
||||
x = int(x)
|
||||
# 使用切片而不是hsplit,避免创建新数组
|
||||
p1 = hist_norm[:x]
|
||||
p2 = hist_norm[x:]
|
||||
q1 = Q[x]
|
||||
q2 = Q[N-1] - Q[x]
|
||||
b1 = bins[:x]
|
||||
b2 = bins[x:]
|
||||
# finding means and variances
|
||||
m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
|
||||
m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
|
||||
v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
|
||||
v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
|
||||
# calculates the minimization function
|
||||
fn = v1 * q1 + v2 * q2
|
||||
return fn
|
||||
|
||||
# brent method is used to minimise an univariate function
|
||||
# bounded minimisation
|
||||
# we can just limit the search to negative values since we know thresh should be negative as L<0 for glint pixels
|
||||
if N_negative <= 1:
|
||||
# 如果没有足够的负值,使用默认阈值
|
||||
return bin[np.argmax(count)]
|
||||
res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
|
||||
thresh = bin[int(res.x)]
|
||||
|
||||
return thresh
|
||||
|
||||
# def cdf_thresholding(self,im, percentile=0.05):
|
||||
# """
|
||||
# :param im (np.ndarray) of shape mxn
|
||||
# :param percentile (float): lower and upper percentile values are potential glint pixels
|
||||
# """
|
||||
# lower_perc = percentile
|
||||
# upper_perc = 1-percentile
|
||||
# im_flatten = im.flatten()
|
||||
# H,X1 = np.histogram(im_flatten, bins = int(0.005*im.shape[0]*im.shape[1]), density=True )
|
||||
# dx = X1[1] - X1[0]
|
||||
# F1 = np.cumsum(H)*dx
|
||||
# F_lower = X1[1:][F1<lower_perc]
|
||||
# F_upper = X1[1:][F1>upper_perc]
|
||||
# while((F_lower.size == 0) or (F_upper.size == 0)):
|
||||
# if (F_lower.size == 0):
|
||||
# lower_perc += 0.01
|
||||
# F_lower = X1[1:][F1<lower_perc]
|
||||
# if (F_upper.size == 0):
|
||||
# upper_perc -= 0.01
|
||||
# F_upper = X1[1:][F1>upper_perc]
|
||||
|
||||
# lower_thresh = F_lower[-1]
|
||||
# upper_thresh = F_upper[0]
|
||||
|
||||
# return lower_thresh,upper_thresh
|
||||
|
||||
def cdf_thresholding(self,im,auto_bins=10):
|
||||
"""
|
||||
:param im (np.ndarray) of shape mxn. Note that it is the LoG of image
|
||||
:param percentile (float): lower and upper percentile values are potential glint pixels
|
||||
"""
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
im_flat = im.ravel()
|
||||
# 过滤掉NaN和无穷大值
|
||||
valid_mask = np.isfinite(im_flat)
|
||||
if not valid_mask.all():
|
||||
im_flat = im_flat[valid_mask]
|
||||
count, bin_edges = np.histogram(im_flat, bins=auto_bins)
|
||||
bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers,使用乘法替代除法
|
||||
thresh = bin[np.argmax(count)]
|
||||
return thresh
|
||||
|
||||
def glint_list(self):
|
||||
"""
|
||||
returns a list of np.ndarray, where each item is an extracted glint for each band based on get_glint_mask
|
||||
"""
|
||||
glint_mask = self.glint_mask_list()
|
||||
extracted_glint_list = []
|
||||
for i in range(self.im_aligned.shape[-1]):
|
||||
gm = glint_mask[i]
|
||||
extracted_glint = gm*self.im_aligned[:,:,i]
|
||||
extracted_glint_list.append(extracted_glint)
|
||||
|
||||
return extracted_glint_list
|
||||
|
||||
def glint_mask_list(self):
|
||||
"""
|
||||
get glint mask using laplacian of gaussian image.
|
||||
returns a list of np.ndarray
|
||||
"""
|
||||
glint_mask_list = []
|
||||
for i in range(self.im_aligned.shape[-1]):
|
||||
glint_mask = self.get_glint_mask(self.im_aligned[:,:,i])
|
||||
glint_mask_list.append(glint_mask)
|
||||
|
||||
return glint_mask_list
|
||||
|
||||
def log_image_list(self):
|
||||
"""
|
||||
get Laplacian of Gaussian (LoG) images for all bands.
|
||||
returns a list of np.ndarray
|
||||
"""
|
||||
log_image_list = []
|
||||
for i in range(self.im_aligned.shape[-1]):
|
||||
log_im = self.get_log_image(self.im_aligned[:,:,i])
|
||||
log_image_list.append(log_im)
|
||||
return log_image_list
|
||||
|
||||
def get_log_image(self, im):
|
||||
"""
|
||||
get Laplacian of Gaussian (LoG) image for a single band.
|
||||
returns a np.ndarray
|
||||
"""
|
||||
LoG_im = ndimage.gaussian_laplace(im, sigma=self.sigma)
|
||||
return LoG_im
|
||||
|
||||
def get_glint_mask(self,im):
|
||||
"""
|
||||
get glint mask using laplacian of gaussian image.
|
||||
We assume that water constituents and features follow a smooth continuum,
|
||||
but glint pixels vary a lot spatially and in intensities
|
||||
Note that for very extensive glint, this method may not work as well <--:TODO use U-net to identify glint mask
|
||||
returns a np.ndarray
|
||||
"""
|
||||
LoG_im = ndimage.gaussian_laplace(im,sigma=self.sigma)
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算阈值
|
||||
if self.water_mask is not None:
|
||||
mask_bool = self.water_mask.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 只在掩膜内提取LoG值用于阈值计算
|
||||
LoG_masked = LoG_im[mask_bool]
|
||||
# 将非掩膜区域设为极大值,确保不影响阈值计算
|
||||
LoG_for_thresh = LoG_im.copy()
|
||||
LoG_for_thresh[~mask_bool] = LoG_masked.max() + 1
|
||||
else:
|
||||
LoG_for_thresh = LoG_im
|
||||
else:
|
||||
LoG_for_thresh = LoG_im
|
||||
|
||||
#threshold mask
|
||||
if (self.glint_mask_method == "otsu"):
|
||||
thresh = self.otsu_thresholding(LoG_for_thresh)
|
||||
elif (self.glint_mask_method == "cdf"):
|
||||
thresh = self.cdf_thresholding(LoG_for_thresh)
|
||||
else:
|
||||
raise ValueError('Enter only cdf or otsu as glint_mask_method')
|
||||
# 使用更高效的方式创建mask,避免np.where的开销
|
||||
glint_mask = (LoG_im < thresh).astype(np.uint8)
|
||||
|
||||
# 如果存在水域掩膜,将非水域区域设为0
|
||||
if self.water_mask is not None:
|
||||
glint_mask = glint_mask * self.water_mask
|
||||
|
||||
return glint_mask
|
||||
|
||||
def get_est_background(self, im,k_size=5):
|
||||
"""
|
||||
:param im (np.ndarray): image of a band
|
||||
estimate background spectra
|
||||
returns a np.ndarray
|
||||
"""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(k_size,k_size))
|
||||
dst = cv2.erode(im, kernel)
|
||||
|
||||
return dst
|
||||
|
||||
def optimise_correction_by_band(self,im,glint_mask,R_BG,bounds):
|
||||
"""
|
||||
:param im (np.ndarray): image of a band
|
||||
:param glint_mask (np.ndarray): glint mask, where glint area is 1 and non-glint area is 0
|
||||
use brent method to get the optimimum b which minimises the variation (i.e. variance) in the entire image
|
||||
returns regression slope b
|
||||
"""
|
||||
# 预计算常量,避免在优化函数中重复计算
|
||||
glint_mask_bool = glint_mask.astype(bool)
|
||||
R_BG_flat = R_BG if isinstance(R_BG, (int, float)) else R_BG[glint_mask_bool]
|
||||
|
||||
def optimise_b(b):
|
||||
# 优化计算:只在glint区域计算校正
|
||||
if isinstance(R_BG, (int, float)):
|
||||
im_corrected = im.copy()
|
||||
im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG)
|
||||
else:
|
||||
im_corrected = im.copy()
|
||||
im_corrected[glint_mask_bool] = im[glint_mask_bool] - glint_mask[glint_mask_bool] * (im[glint_mask_bool] / b - R_BG[glint_mask_bool])
|
||||
return np.var(im_corrected)
|
||||
|
||||
res = minimize_scalar(optimise_b, bounds=bounds, method='bounded')
|
||||
return res.x
|
||||
|
||||
def divide_and_conquer(self):
|
||||
"""
|
||||
instead of computing b_list for each window, use the previous b_list to narrow the bounds,
|
||||
because of the strong spatial autocorrelation, we know that the b (correction magnitude) cannot diff too much
|
||||
this can optimise the run time
|
||||
"""
|
||||
|
||||
|
||||
def optimise_correction(self):
|
||||
"""
|
||||
returns a list of slope in band order i.e. 0,1,2,3,4,5,6,7,8,9 through optimisation
|
||||
"""
|
||||
b_list = []
|
||||
glint_mask_list = []
|
||||
est_background_list = []
|
||||
for i in range(self.n_bands):
|
||||
glint_mask = self.get_glint_mask(self.im_aligned[:,:,i])
|
||||
glint_mask_list.append(glint_mask)
|
||||
if self.estimate_background is True:
|
||||
est_background = self.get_est_background(self.im_aligned[:,:,i])
|
||||
est_background_list.append(est_background)
|
||||
else:
|
||||
est_background = np.percentile(self.im_aligned[:,:,i], 5, interpolation='nearest')
|
||||
est_background_list.append(est_background)
|
||||
bounds = self.bounds[i]
|
||||
b = self.optimise_correction_by_band(self.im_aligned[:,:,i],glint_mask,est_background,bounds)
|
||||
b_list.append(b)
|
||||
|
||||
# add attributes
|
||||
self.b_list = b_list
|
||||
self.glint_mask = glint_mask_list
|
||||
self.est_background = est_background_list
|
||||
|
||||
return b_list, glint_mask_list, est_background_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
获取校正后的波段
|
||||
|
||||
:return: 校正后的波段列表
|
||||
"""
|
||||
corrected_bands = []
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
|
||||
for i in range(self.n_bands):
|
||||
im_band = self.im_aligned[:,:,i]
|
||||
# 一次性计算mask和background,避免重复计算
|
||||
glint_mask = self.get_glint_mask(im_band)
|
||||
background = self.get_est_background(im_band, k_size=5)
|
||||
# 使用视图和原地操作减少内存
|
||||
im_corrected = im_band.copy()
|
||||
glint_mask_bool = glint_mask.astype(bool)
|
||||
im_corrected[glint_mask_bool] = background[glint_mask_bool]
|
||||
|
||||
# 如果存在水域掩膜,确保只在水域内应用校正
|
||||
if water_mask_bool is not None:
|
||||
# 只在水域掩膜内应用校正
|
||||
correction_mask = glint_mask_bool & water_mask_bool
|
||||
im_corrected = np.where(correction_mask, background, im_band)
|
||||
# 非水域区域保持原值
|
||||
im_corrected = np.where(water_mask_bool, im_corrected, im_band)
|
||||
|
||||
corrected_bands.append(im_corrected)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param iter (int or None): number of iterations to run the sugar algorithm. If None, termination conditions are automatically applied
|
||||
:param bounds (list of tuples): to limit correction magnitude
|
||||
:param get_glint_mask (np.ndarray):
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存最后一次迭代的校正结果
|
||||
如果为None,则不保存
|
||||
conducts iterative correction using SUGAR
|
||||
"""
|
||||
glint_image = im_aligned.copy()
|
||||
corrected_images = []
|
||||
|
||||
if iter is None:
|
||||
# termination conditions
|
||||
relative_difference = lambda sd0,sd1: sd1/sd0*100
|
||||
marginal_difference = lambda sd1,sd2: (sd1-sd2)/sd1*100
|
||||
relative_diff_thresh = marginal_difference_thresh = termination_thresh
|
||||
sd_og = np.var(im_aligned)
|
||||
iter_count = 0
|
||||
sd_next = sd_og # 不需要copy,直接使用值
|
||||
max_iter = 100 # 添加最大迭代次数限制,防止无限循环
|
||||
|
||||
while ((relative_difference(sd_og,sd_next) > relative_diff_thresh) and iter_count < max_iter):
|
||||
# do all the processing here
|
||||
HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
|
||||
corrected_bands = HM.get_corrected_bands()
|
||||
glint_image = np.stack(corrected_bands,axis=2)
|
||||
sd_temp = np.var(glint_image)
|
||||
# 只在需要时保存中间结果,减少内存占用
|
||||
if get_glint_mask or iter_count == 0:
|
||||
corrected_images.append(glint_image.copy())
|
||||
else:
|
||||
corrected_images.append(glint_image) # 最后一次迭代的结果
|
||||
# save glint_mask
|
||||
# if iter_count == 0 and get_glint_mask is True:
|
||||
# glint_mask = np.stack(HM.glint_mask,axis=2)
|
||||
if (marginal_difference(sd_next,sd_temp)<marginal_difference_thresh):
|
||||
break
|
||||
else:
|
||||
sd_next = sd_temp
|
||||
#increase count
|
||||
iter_count += 1
|
||||
|
||||
# 如果提供了输出路径,保存最后一次迭代的结果
|
||||
if output_path is not None and len(corrected_images) > 0:
|
||||
_save_corrected_image(corrected_images[-1], output_path)
|
||||
|
||||
else:
|
||||
for i in range(iter):
|
||||
HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
|
||||
corrected_bands = HM.get_corrected_bands()
|
||||
glint_image = np.stack(corrected_bands,axis=2)
|
||||
# 只在最后一次迭代或需要时保存所有结果
|
||||
if i == iter - 1 or get_glint_mask:
|
||||
corrected_images.append(glint_image.copy())
|
||||
else:
|
||||
# 对于中间迭代,可以只保存引用(但要注意内存管理)
|
||||
corrected_images.append(glint_image)
|
||||
# save glint_mask
|
||||
# if i == 0 and get_glint_mask is True:
|
||||
# glint_mask = np.stack(HM.glint_mask,axis=2)
|
||||
|
||||
# 如果提供了输出路径,保存最后一次迭代的结果
|
||||
if output_path is not None and len(corrected_images) > 0:
|
||||
_save_corrected_image(corrected_images[-1], output_path)
|
||||
|
||||
return corrected_images
|
||||
|
||||
def _save_corrected_image(corrected_image, output_path):
|
||||
"""
|
||||
保存校正后的图像到文件(用于correction_iterative函数,BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_image: 校正后的图像数组,形状为(height, width, bands)
|
||||
:param output_path: 输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
else:
|
||||
bsq_path = output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_image.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_image[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
Reference in New Issue
Block a user