# Source listing metadata: 636 lines, 23 KiB, Python.
import os
import shutil

import cv2
import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar

try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False


def otsu_thresholding(im, auto_bins=None):
    """Pick a threshold with an Otsu-style search over the value histogram.

    Non-finite samples are discarded, the remaining values are histogrammed,
    and the split index that minimises the weighted within-class variance is
    searched among the bins whose centre is negative.  When at most one bin
    centre is negative, the centre of the most populated bin is returned
    instead.

    :param im: array of values to threshold.
    :param auto_bins: number of histogram bins.  ``None`` selects
        ``max(10, int(0.005 * n))`` where ``n`` is the product of the first
        two dimensions of ``im`` (or its size when it has fewer than two).
    :return: centre of the selected histogram bin, or ``0.0`` when ``im``
        contains no finite value.
    """
    im = np.asarray(im)
    if auto_bins is None:
        # Same formula as before for 2-D input; 1-D input no longer raises.
        n_samples = im.shape[0] * im.shape[1] if im.ndim >= 2 else im.size
        auto_bins = max(10, int(0.005 * n_samples))

    values = im.ravel()
    finite = np.isfinite(values)
    if not finite.all():
        values = values[finite]
    if values.size == 0:
        # Nothing to histogram: avoid a 0/0 normalisation and an arbitrary
        # result derived from numpy's default (0, 1) histogram range.
        return 0.0

    counts, edges = np.histogram(values, bins=auto_bins)
    centers = (edges[:-1] + edges[1:]) * 0.5
    hist_norm = counts / counts.sum()
    cdf = hist_norm.cumsum()
    n_bins = counts.shape[0]
    n_negative = int(np.sum(centers < 0))
    bin_index = np.arange(n_bins, dtype=np.float32)

    def within_class_variance(x):
        # The optimiser proposes real numbers; the split is their integer part,
        # clamped so that the indexing below can never run past the last bin.
        split = min(int(x), n_bins - 1)
        p1, p2 = hist_norm[:split], hist_norm[split:]
        b1, b2 = bin_index[:split], bin_index[split:]
        # NOTE(review): q1 uses cdf[split], i.e. it also counts bin `split`,
        # although p1 stops before it.  Kept unchanged to preserve results.
        q1 = cdf[split]
        q2 = cdf[n_bins - 1] - cdf[split]
        m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
        m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
        v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
        v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
        return v1 * q1 + v2 * q2

    if n_negative <= 1:
        return centers[np.argmax(counts)]

    # NOTE(review): the objective is piecewise constant in x (integer
    # truncation), so the bounded scalar minimiser is only approximate here.
    res = minimize_scalar(within_class_variance, bounds=(1, n_negative), method='bounded')
    return centers[min(int(res.x), n_bins - 1)]


def cdf_thresholding(im, auto_bins=10):
    """Return the centre of the most populated histogram bin of ``im``.

    Non-finite samples are ignored before the histogram is built.

    :param im: array of values to threshold.
    :param auto_bins: number of histogram bins passed to ``np.histogram``.
    :return: centre of the bin with the highest count, or ``0.0`` when ``im``
        contains no finite value.
    """
    values = np.asarray(im).ravel()
    finite = np.isfinite(values)
    if not finite.all():
        values = values[finite]
    if values.size == 0:
        # Without data numpy would histogram the default (0, 1) range and the
        # result would be an arbitrary bin centre.
        return 0.0

    counts, edges = np.histogram(values, bins=auto_bins)
    centers = (edges[:-1] + edges[1:]) * 0.5
    return centers[np.argmax(counts)]


class SUGAR:
    """SUGAR sun-glint removal, processed block by block and band by band.

    Pipeline (driven by :meth:`get_corrected_bands`):

    1. Scan the image in blocks and derive one global threshold per band from
       the Laplacian-of-Gaussian (LoG) values.
    2. Scan again and collect, per band, the values of the glint pixels and of
       their estimated background.
    3. Optimise one ``b`` value per band by minimising a variance objective
       over those collected pixels.
    4. Process every block (glint mask -> background -> correction) and write
       it to the output dataset.

    NOTE(review): the LoG filter is applied to each block independently, so
    values next to block borders depend on the filter's boundary handling
    rather than on the neighbouring block.
    """

    def __init__(self, img_path, bounds=None, sigma=1.0, estimate_background=True,
                 glint_mask_method="cdf", water_mask=None, output_path=None,
                 block_size=1000):
        """Open the input image and store the processing options.

        :param img_path (str): path of the input image.
        :param bounds: optimisation bounds per band, default ``[(1, 2)]``
            repeated for every band.
        :param sigma (float): sigma of the LoG filter.
        :param estimate_background (bool): estimate the background per pixel
            (erosion) instead of using one low percentile per block.
        :param glint_mask_method (str): ``"otsu"`` selects Otsu thresholding,
            any other value selects the CDF threshold.
        :param water_mask: ``None``, a numpy array, or the path of a raster
            mask; values > 0 mark the pixels to process.
        :param output_path (str): path of the output file.
        :param block_size (int): edge length of the square processing blocks.
        :raises ImportError: if GDAL is not installed.
        :raises ValueError: if the image cannot be opened.
        """
        # Assigned before anything can raise so that __del__ is always safe.
        self.dataset = None

        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法读取影像文件")

        if bounds is None:
            bounds = [(1, 2)]

        self.img_path = img_path
        self.bounds = bounds
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.glint_mask_method = glint_mask_method
        # Cache filled by _load_water_mask() on first use.
        self.water_mask = None
        self._water_mask_loaded = False
        self.water_mask_path = water_mask
        self.output_path = output_path
        self.block_size = block_size

        # Position state read by _get_glint_mask_block().
        self._current_band = None
        self._current_x = 0
        self._current_y = 0

        # Open the image
        self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
        if self.dataset is None:
            raise ValueError(f"无法打开影像文件: {img_path}")
        self.width = self.dataset.RasterXSize
        self.height = self.dataset.RasterYSize
        self.n_bands = self.dataset.RasterCount

        # Repeat the bounds so that every band index has an entry
        self.bounds_all = self.bounds * self.n_bands

        # Global optimisation results
        self.b_list = None
        self.glint_pixel_indices = []
        self.thresholds = []  # one global threshold per band

    # ------------------------------------------------------------------
    # Water mask
    # ------------------------------------------------------------------
    def _load_water_mask(self):
        """Return the water mask as a uint8 0/1 array, or ``None``.

        The mask is read once and cached, so repeated calls (one per block in
        step 4) do not read the mask file again.
        """
        if getattr(self, "_water_mask_loaded", False):
            return self.water_mask
        self.water_mask = self._read_water_mask()
        self._water_mask_loaded = True
        return self.water_mask

    def _read_water_mask(self):
        """Build the 0/1 water mask from ``self.water_mask_path``.

        :raises ValueError: if the mask is a shapefile, cannot be opened, or
            does not match the image size.
        """
        source = self.water_mask_path
        if source is None:
            return None

        if isinstance(source, np.ndarray):
            if source.shape[:2] != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            return (source > 0).astype(np.uint8)

        if isinstance(source, str):
            if source.lower().endswith('.shp'):
                raise ValueError("请先栅格化shapefile为栅格掩膜文件")
            mask_dataset = gdal.Open(source, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None
            if mask_array.shape != (self.height, self.width):
                raise ValueError(
                    f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
                )
            return (mask_array > 0).astype(np.uint8)

        # Any other type is ignored, as before.
        return None

    # ------------------------------------------------------------------
    # Small building blocks
    # ------------------------------------------------------------------
    def _compute_threshold(self, im):
        """Compute the glint threshold with the configured method."""
        if self.glint_mask_method == "otsu":
            return otsu_thresholding(im)
        return cdf_thresholding(im)

    def _iter_blocks(self):
        """Yield ``(x_off, y_off, x_size, y_size)`` for every block, row by row."""
        for y_off in range(0, self.height, self.block_size):
            y_size = min(self.block_size, self.height - y_off)
            for x_off in range(0, self.width, self.block_size):
                x_size = min(self.block_size, self.width - x_off)
                yield x_off, y_off, x_size, y_size

    def _read_block(self, band_idx, x_off, y_off, x_size, y_size):
        """Read one block of band ``band_idx`` (0-based) as float32."""
        band = self.dataset.GetRasterBand(band_idx + 1)
        return band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)

    def _block_glint(self, R_block, band_idx, water_mask, x_off, y_off):
        """Boolean glint mask of a block: LoG below the band threshold, inside the water mask."""
        log_im = ndimage.gaussian_laplace(R_block, sigma=self.sigma)
        glint = log_im < self.thresholds[band_idx]
        if water_mask is not None:
            rows, cols = R_block.shape
            glint &= water_mask[y_off:y_off + rows, x_off:x_off + cols].astype(bool)
        return glint

    def _block_background(self, R_block):
        """Background estimate with the same shape as ``R_block``.

        With ``estimate_background`` the eroded block is used; otherwise the
        5th percentile of the block is broadcast to a full array so that it
        can be indexed with the glint mask like the per-pixel estimate.
        """
        if self.estimate_background:
            return self._get_est_background(R_block)
        level = np.percentile(R_block, 5, method='nearest')
        return np.full(R_block.shape, level, dtype=np.float32)

    def _get_glint_mask_block(self, band_data, band_idx=None):
        """Compute the LoG image and the glint mask of a single-band block.

        The threshold is the global one stored in ``self.thresholds``.  The
        water mask window is taken at ``self._current_x`` / ``self._current_y``.

        :param band_data: 2-D block of one band.
        :param band_idx: 0-based band index; defaults to ``self._current_band``.
        :return: ``(log_im, glint_mask)`` with ``glint_mask`` as uint8 0/1.
        :raises ValueError: if no band index is known.
        """
        if band_idx is None:
            band_idx = getattr(self, "_current_band", None)
        if band_idx is None:
            raise ValueError("未指定波段索引: 请传入 band_idx 或先设置 _current_band")

        log_im = ndimage.gaussian_laplace(band_data.astype(np.float32), sigma=self.sigma)
        glint_mask = (log_im < self.thresholds[band_idx]).astype(np.uint8)

        # Apply the water mask
        water_mask = self._load_water_mask()
        if water_mask is not None:
            y_off = self._current_y
            x_off = self._current_x
            mask_block = water_mask[y_off:y_off + band_data.shape[0],
                                    x_off:x_off + band_data.shape[1]]
            glint_mask = glint_mask * mask_block

        return log_im, glint_mask

    def _get_est_background(self, im, k_size=5):
        """Estimate the background by morphological erosion.

        The image is eroded with an elliptical ``k_size`` x ``k_size``
        structuring element (``cv2.erode``).
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        return cv2.erode(im.astype(np.float32), kernel)

    def _optimise_correction_band(self, R_glint, R_bg_glint, bounds):
        """Optimise the ``b`` value of one band over all its glint pixels.

        Minimises ``Var(R - (R / b - R_bg))`` within ``bounds``.

        NOTE(review): _process_and_write_block applies
        ``R_bg + (R - R_bg) / b``, which is not the expression minimised here
        - confirm which formulation is intended.

        :param R_glint: values of all glint pixels (1-D array).
        :param R_bg_glint: matching background values (1-D array).
        :param bounds: ``(low, high)`` search interval.
        :return: optimal ``b``; ``1.0`` when there is no glint pixel.
        """
        if len(R_glint) == 0:
            return 1.0

        R_glint = R_glint.astype(np.float32)
        R_bg_glint = R_bg_glint.astype(np.float32)

        def objective(b):
            b = float(b)
            R_corrected = R_glint - (R_glint / b - R_bg_glint)
            return np.var(R_corrected)

        res = minimize_scalar(objective, bounds=bounds, method='bounded')
        return res.x

    # ------------------------------------------------------------------
    # Pipeline steps
    # ------------------------------------------------------------------
    def _scan_and_collect_glint(self):
        """Step 1: scan the image in blocks and compute one threshold per band.

        Bands are handled one after the other, so only the LoG values of a
        single band are held at a time.  A band without any value inside the
        water mask gets the threshold ``0.0``.
        """
        print("[SUGAR] 步骤1: 扫描全图收集glint像素...")
        water_mask = self._load_water_mask()

        self.thresholds = [None] * self.n_bands
        print(f"[SUGAR] 计算 {self.n_bands} 个波段的全局阈值...")

        for b in range(self.n_bands):
            log_values = []
            for x_off, y_off, x_size, y_size in self._iter_blocks():
                block = self._read_block(b, x_off, y_off, x_size, y_size)
                log_im = ndimage.gaussian_laplace(block, sigma=self.sigma)
                if water_mask is None:
                    log_values.append(log_im.ravel())
                    continue
                inside = water_mask[y_off:y_off + y_size, x_off:x_off + x_size].astype(bool)
                if inside.any():
                    log_values.append(log_im[inside])

            if not log_values:
                self.thresholds[b] = 0.0
            else:
                all_log = np.concatenate(log_values)
                # The thresholding helpers derive their default bin count
                # from a 2-D shape, hence the (1, N) view.
                self.thresholds[b] = float(self._compute_threshold(all_log.reshape(1, -1)))
                del all_log
            print(f" 波段{b}: thresh={self.thresholds[b]:.4f}")

    def _collect_glint_pixel_values(self):
        """Step 2: scan again and collect ``(R, R_bg)`` of every glint pixel.

        Results are stored as one 1-D array per band in ``self.R_glint_all``
        and ``self.R_bg_glint_all``.
        """
        print("[SUGAR] 步骤2: 收集glint像素值用于全局优化...")
        water_mask = self._load_water_mask()

        self.R_glint_all = []
        self.R_bg_glint_all = []

        for b in range(self.n_bands):
            glint_values = []
            background_values = []
            for x_off, y_off, x_size, y_size in self._iter_blocks():
                R_block = self._read_block(b, x_off, y_off, x_size, y_size)
                glint = self._block_glint(R_block, b, water_mask, x_off, y_off)
                if not glint.any():
                    continue
                R_bg = self._block_background(R_block)
                glint_values.append(R_block[glint])
                background_values.append(R_bg[glint])

            if glint_values:
                self.R_glint_all.append(np.concatenate(glint_values))
                self.R_bg_glint_all.append(np.concatenate(background_values))
            else:
                self.R_glint_all.append(np.array([], dtype=np.float32))
                self.R_bg_glint_all.append(np.array([], dtype=np.float32))
            n = len(self.R_glint_all[b])
            print(f" 波段{b}: 收集到 {n} 个glint像素")

    def _optimize_b_list(self):
        """Step 3: optimise ``b`` for every band and release the collected pixels."""
        print("[SUGAR] 步骤3: 全局优化 b 值...")
        self.b_list = []
        for b in range(self.n_bands):
            b_opt = self._optimise_correction_band(
                self.R_glint_all[b], self.R_bg_glint_all[b], self.bounds_all[b]
            )
            self.b_list.append(float(b_opt))
            print(f" 波段{b}: b={b_opt:.4f}")

        # Free memory
        self.R_glint_all = None
        self.R_bg_glint_all = None

    def _process_and_write_block(self, x_off, y_off, x_size, y_size, out_dataset):
        """Step 4: correct one block of every band and write it to ``out_dataset``.

        Glint pixels become ``R_bg + (R - R_bg) / b``; all other pixels are
        written unchanged.
        """
        water_mask = self._load_water_mask()

        for b in range(self.n_bands):
            R_block = self._read_block(b, x_off, y_off, x_size, y_size)
            glint = self._block_glint(R_block, b, water_mask, x_off, y_off)

            R_corrected = R_block.copy()
            if glint.any():
                R_bg = self._block_background(R_block)
                b_val = self.b_list[b]
                R_corrected[glint] = R_bg[glint] + (R_block[glint] - R_bg[glint]) / b_val

            out_band = out_dataset.GetRasterBand(b + 1)
            out_band.WriteArray(R_corrected, x_off, y_off)
            out_band.FlushCache()

    def get_corrected_bands(self):
        """Run the whole block-wise pipeline and write the result to disk.

        The output is an ENVI Float32 raster whose name is ``output_path``
        with the extension forced to ``.bsq``.  The input dataset is released
        after a successful run.

        :return: an empty list (the corrected bands are only written to file).
        :raises ValueError: if ``output_path`` is missing or the output file
            cannot be created.
        """
        if self.output_path is None:
            raise ValueError("output_path 必须提供,分块处理需要直接写入文件")

        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path

        geotransform = self.dataset.GetGeoTransform()
        projection = self.dataset.GetProjection()

        driver = gdal.GetDriverByName('ENVI')
        out_dataset = driver.Create(bsq_path, self.width, self.height,
                                    self.n_bands, gdal.GDT_Float32)
        if out_dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            out_dataset.SetGeoTransform(geotransform)
            out_dataset.SetProjection(projection)

            # Step 1: thresholds
            self._scan_and_collect_glint()
            # Step 2: glint pixel values
            self._collect_glint_pixel_values()
            # Step 3: global optimisation of b
            self._optimize_b_list()

            # Step 4: block-wise correction and writing
            n_blocks_x = (self.width + self.block_size - 1) // self.block_size
            n_blocks_y = (self.height + self.block_size - 1) // self.block_size
            total_blocks = n_blocks_x * n_blocks_y
            print(f"[SUGAR] 步骤4: 分块处理写入,共 {total_blocks} 块")

            for block_idx, (x_off, y_off, x_size, y_size) in enumerate(self._iter_blocks(), start=1):
                self._current_x = x_off
                self._current_y = y_off
                print(f"[SUGAR] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
                self._process_and_write_block(x_off, y_off, x_size, y_size, out_dataset)
        finally:
            # Dropping the reference closes the output even when a step failed.
            out_dataset = None

        self.dataset = None

        # The header is expected either with the extension replaced
        # (<name>.hdr) or appended (<name>.bsq.hdr); accept both.
        header_candidates = (os.path.splitext(bsq_path)[0] + '.hdr', bsq_path + '.hdr')
        if any(os.path.exists(path) for path in header_candidates):
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}")
        else:
            print(f"[SUGAR] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")

        return []

    def __del__(self):
        # getattr: __init__ may have failed before the attribute was assigned.
        if getattr(self, "dataset", None) is not None:
            self.dataset = None


# ============================================================================
# Standalone function: correction_iterative (iterative version, supports large images)
# ============================================================================
def _envi_header_candidates(data_path):
    """Return the header names that may accompany the ENVI data file ``data_path``."""
    # Extension replaced (<name>.hdr) or appended (<name>.bsq.hdr).
    return [os.path.splitext(data_path)[0] + '.hdr', data_path + '.hdr']


def _remove_envi_files(data_path):
    """Delete an ENVI data file and whichever header file exists for it."""
    for path in [data_path] + _envi_header_candidates(data_path):
        if os.path.exists(path):
            os.remove(path)


def _move_envi_result(src_path, dst_path):
    """Move an ENVI data file to ``dst_path`` and its header to ``dst_path + '.hdr'``."""
    if os.path.exists(dst_path):
        os.remove(dst_path)
    shutil.move(src_path, dst_path)
    for header in _envi_header_candidates(src_path):
        if os.path.exists(header):
            shutil.move(header, dst_path + '.hdr')
            break


def _band_variance(raster_path, width, height, block_size, band_index=1):
    """Variance of the finite values of one band, accumulated strip by strip.

    Strips of ``block_size`` rows are merged with the pairwise mean/M2 update,
    so the band is never held in memory as a whole.  Returns ``0.0`` when the
    file cannot be opened or holds no finite value.
    """
    ds = gdal.Open(raster_path, gdal.GA_ReadOnly)
    if ds is None:
        return 0.0

    band = ds.GetRasterBand(band_index)
    count = 0
    mean = 0.0
    m2 = 0.0
    for y_off in range(0, height, block_size):
        rows = min(block_size, height - y_off)
        strip = band.ReadAsArray(0, y_off, width, rows).astype(np.float64).ravel()
        strip = strip[np.isfinite(strip)]
        n = strip.size
        if n == 0:
            continue
        strip_mean = float(strip.mean())
        strip_m2 = float(((strip - strip_mean) ** 2).sum())
        delta = strip_mean - mean
        total = count + n
        mean += delta * n / total
        m2 += strip_m2 + delta * delta * count * n / total
        count = total
    band = None
    ds = None
    return m2 / count if count else 0.0


def correction_iterative(img_path, iter=3, bounds=None, estimate_background=True,
                         glint_mask_method="cdf", get_glint_mask=False,
                         termination_thresh=20.0, water_mask=None, output_path=None,
                         block_size=1000):
    """Iterative SUGAR glint removal, block-wise version.

    Every pass runs :class:`SUGAR` on the result of the previous pass and
    writes a temporary ``_sugar_iter_<i>.bsq`` file next to ``output_path``
    (or in the current directory when no ``output_path`` is given).  The
    result of the last pass is moved to ``output_path``; the temporary files
    of earlier passes are deleted.  Without ``output_path`` the last
    temporary file is left in place.

    :param img_path (str): path of the input image.
    :param iter (int or None): number of passes; ``None`` stops automatically
        once the variance of band 1 decreases by less than
        ``termination_thresh`` percent between two passes (at most 101 passes).
    :param bounds: optimisation bounds, default ``[(1, 2)]``.
    :param estimate_background: whether to estimate the background per pixel.
    :param glint_mask_method: ``"cdf"`` or ``"otsu"``.
    :param get_glint_mask: deprecated, kept for interface compatibility.
    :param termination_thresh: automatic termination threshold in percent.
    :param water_mask: water mask (array or raster path).
    :param output_path: path of the final output file.
    :param block_size: edge length of the processing blocks.
    :return: an empty list (results are written to disk only).
    :raises ImportError: if GDAL is not installed.
    :raises ValueError: if the input image cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装")

    if bounds is None:
        bounds = [(1, 2)]

    # Open the image once to obtain its size
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    dataset = None

    # Temporary outputs live next to the final output
    temp_dir = os.path.dirname(output_path) if output_path else os.getcwd()
    temp_base = os.path.join(temp_dir, "_sugar_iter")

    # Make sure the output directory exists
    if output_path:
        output_dir = os.path.dirname(output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

    def run_pass(source_path, index):
        """Run one SUGAR pass on ``source_path`` and return the file it wrote."""
        target = f"{temp_base}_{index}.bsq"
        sugar = SUGAR(
            source_path,
            bounds=bounds,
            estimate_background=estimate_background,
            glint_mask_method=glint_mask_method,
            water_mask=water_mask,
            output_path=target,
            block_size=block_size,
        )
        sugar.get_corrected_bands()
        return target

    produced = []  # files written so far, oldest first
    current_path = img_path

    if iter is None:
        max_iter = 100
        var_prev = None
        iter_count = 0
        while True:
            current_path = run_pass(current_path, iter_count)
            produced.append(current_path)
            var_current = _band_variance(current_path, width, height, block_size)

            if iter_count > 0:
                # Zero or undefined variance: the relative gain cannot be
                # measured, so further passes would never terminate on it.
                if not var_prev > 0:
                    break
                if (var_prev - var_current) / var_prev * 100 < termination_thresh:
                    break
            if iter_count >= max_iter:
                break

            var_prev = var_current
            iter_count += 1
    else:
        for i in range(iter):
            current_path = run_pass(current_path, i)
            produced.append(current_path)

    if produced:
        final_path = produced[-1]
        # Move the result of the LAST pass (data file and header) to output_path
        if output_path and os.path.abspath(final_path) != os.path.abspath(output_path):
            if os.path.exists(final_path):
                _move_envi_result(final_path, output_path)

        # Clean up the temporary files of all earlier passes
        for stale in produced[:-1]:
            _remove_envi_files(stale)

    return []