Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

View File

@ -0,0 +1,765 @@
from src.utils.util import *
from osgeo import gdal, ogr
import argparse
import cv2
def percentile_stretch(img, data_water_mask, lower_percentile=2, upper_percentile=98, output_range=(0, 255)):
"""
使用百分位数裁剪进行归一化,适用于低反射率数据
通过排除极值,更好地利用数据的动态范围
Args:
img: 输入图像数组反射率值通常在0-1之间
data_water_mask: 水域掩膜
lower_percentile: 下百分位数用于裁剪最小值默认2
upper_percentile: 上百分位数用于裁剪最大值默认98
output_range: 输出范围,默认(0, 255)
Returns:
归一化后的图像数组(整数类型)
"""
# 只在水域掩膜区域计算百分位数
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于百分位数计算,使用原始值")
return img.astype(np.int32)
# 计算百分位数
p_lower = np.percentile(valid_pixels, lower_percentile)
p_upper = np.percentile(valid_pixels, upper_percentile)
# 如果上下界相同,使用最大值作为上界
if p_lower >= p_upper:
p_lower = np.percentile(valid_pixels, 1)
p_upper = np.percentile(valid_pixels, 99)
if p_lower >= p_upper:
p_upper = valid_pixels.max()
p_lower = valid_pixels.min()
print(f"百分位数拉伸: {lower_percentile}%={p_lower:.6f}, {upper_percentile}%={p_upper:.6f}, "
f"数据范围=[{img.min():.6f}, {img.max():.6f}]")
# 裁剪到百分位数范围
img_clipped = np.clip(img, p_lower, p_upper)
# 线性拉伸到输出范围
if p_upper > p_lower:
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (output_range[1] - output_range[0]) + output_range[0]
else:
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
return img_stretched.astype(np.int32)
@timeit
def otsu(img, max_value, data_water_mask, ignore_value=0, foreground=1, background=0):
height = img.shape[0]
width = img.shape[1]
hist = np.zeros([max_value], np.float32)
# 计算直方图
invalid_counter = 0
for i in range(height):
for j in range(width):
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
invalid_counter = invalid_counter + 1
continue
hist[img[i, j]] += 1
hist /= (height * width - invalid_counter)
threshold = 0
deltaMax = 0
# 遍历像素值,计算最大类间方差
for i in range(max_value):
wA = 0
wB = 0
uAtmp = 0
uBtmp = 0
uA = 0
uB = 0
u = 0
for j in range(max_value):
if j <= i:
wA += hist[j]
uAtmp += j * hist[j]
else:
wB += hist[j]
uBtmp += j * hist[j]
if wA == 0:
wA = 1e-10
if wB == 0:
wB = 1e-10
uA = uAtmp / wA
uB = uBtmp / wB
u = uAtmp + uBtmp
# 计算类间方差
deltaTmp = wA * ((uA - u)**2) + wB * ((uB - u)**2)
# 找出最大类间方差以及阈值
if deltaTmp > deltaMax:
deltaMax = deltaTmp
threshold = i
# 二值化
det_img = img.copy()
det_img[img > threshold] = foreground
det_img[img <= threshold] = background
det_img[np.where(data_water_mask == 0)] = background
return det_img
@timeit
def zscore_threshold(img, data_water_mask, z_threshold=2.5, foreground=1, background=0):
"""
基于Z-score标准化分数的耀斑检测方法
使用统计方法识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
z_threshold: Z-score阈值默认2.5即超过均值2.5个标准差)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
mean_val = np.mean(valid_pixels)
std_val = np.std(valid_pixels)
if std_val == 0:
print("警告: 标准差为0无法使用Z-score方法")
return np.zeros_like(img, dtype=np.int32)
# 计算Z-score对无效值进行保护
z_scores = np.zeros_like(img, dtype=np.float32)
valid_mask = (data_water_mask > 0) & np.isfinite(img)
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[z_scores > z_threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"Z-score方法: 均值={mean_val:.2f}, 标准差={std_val:.2f}, 阈值={mean_val + z_threshold * std_val:.2f}")
return det_img
@timeit
def percentile_threshold(img, data_water_mask, percentile=95, foreground=1, background=0):
"""
基于百分位数的耀斑检测方法
使用百分位数作为阈值,对异常值更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
percentile: 百分位数阈值默认95即超过95%的像素值)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算百分位数,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
threshold = np.percentile(valid_pixels, percentile)
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"百分位数方法: {percentile}%分位数为 {threshold:.2f}")
return det_img
@timeit
def multi_band_glint_detection(dataset, img_path, water_mask, glint_waves, weights=None, method='zscore',
z_threshold=2.5, percentile=95, foreground=1, background=0):
"""
多波段融合的耀斑检测方法
结合多个波段的耀斑特征,提高检测的稳健性
Args:
dataset: GDAL数据集
img_path: 影像文件路径(用于获取波长信息)
water_mask: 水域掩膜数组
glint_waves: 用于检测的波长列表,如[750, 800, 850]
weights: 各波段的权重如果为None则使用等权重
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
z_threshold: Z-score阈值当method='zscore'时使用)
percentile: 百分位数阈值当method='percentile'时使用)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
num_bands = dataset.RasterCount
if weights is None:
weights = [1.0 / len(glint_waves)] * len(glint_waves)
if len(weights) != len(glint_waves):
raise ValueError("权重数量必须与波长数量相同")
# 读取多个波段并加权融合使用float32保持精度
fused_band = None
for i, wave in enumerate(glint_waves):
band_num = find_band_number(wave, img_path)
if band_num >= num_bands:
print(f"警告: 波段号 {band_num} 超出范围,跳过波长 {wave}")
continue
tmp = dataset.GetRasterBand(band_num + 1).ReadAsArray().astype(np.float32)
if fused_band is None:
fused_band = (tmp * weights[i]).astype(np.float32)
else:
fused_band = (fused_band + tmp * weights[i]).astype(np.float32)
if fused_band is None:
raise ValueError("没有有效的波段可以融合")
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile直接使用原始反射率值
# 对于Otsu方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
fused_band_stretch = percentile_stretch(fused_band, water_mask,
lower_percentile=2, upper_percentile=98)
return otsu(fused_band_stretch, fused_band_stretch.max() + 1, water_mask,
foreground=foreground, background=background)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
return zscore_threshold(fused_band, water_mask, z_threshold, foreground, background)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
return percentile_threshold(fused_band, water_mask, percentile, foreground, background)
else:
raise ValueError(f"不支持的方法: {method}")
@timeit
def adaptive_threshold(img, data_water_mask, window_size=15, percentile=90, foreground=1, background=0):
"""
自适应阈值方法
基于局部统计特性进行阈值分割,对光照变化更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
window_size: 局部窗口大小(奇数)
percentile: 局部百分位数阈值
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
height, width = img.shape
# 确保窗口大小为奇数
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
# 创建输出图像
det_img = np.zeros_like(img, dtype=np.int32)
# 对每个像素计算局部阈值
for i in range(half_window, height - half_window):
for j in range(half_window, width - half_window):
# 只在水域掩膜内处理
if data_water_mask[i, j] == 0:
continue
# 提取局部窗口
local_window = img[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
local_mask = data_water_mask[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
# 只考虑有效像素
valid_pixels = local_window[local_mask > 0]
if len(valid_pixels) > 0:
local_threshold = np.percentile(valid_pixels, percentile)
if img[i, j] > local_threshold:
det_img[i, j] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"自适应阈值方法: 窗口大小={window_size}, 局部百分位数={percentile}%")
return det_img
@timeit
def iqr_outlier_detection(img, data_water_mask, iqr_multiplier=1.5, foreground=1, background=0):
"""
基于IQR四分位距的异常值检测方法
使用四分位距识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
iqr_multiplier: IQR倍数默认1.5(标准异常值检测)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
q1 = np.percentile(valid_pixels, 25)
q3 = np.percentile(valid_pixels, 75)
iqr = q3 - q1
# 上界 = Q3 + 1.5 * IQR
upper_bound = q3 + iqr_multiplier * iqr
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > upper_bound] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"IQR方法: Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}, 上界={upper_bound:.2f}")
return det_img
@timeit
def create_shoreline_buffer(water_mask, buffer_size=5, foreground=1, background=0):
"""
创建岸边缓冲区掩膜(向内缓冲)
用于去除岸边附近的错误耀斑检测区域
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
Args:
water_mask: 水域掩膜数组(水域=1非水域=0
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
岸边缓冲区掩膜(缓冲区区域=1其他=0
"""
if buffer_size <= 0:
print("缓冲区大小为0或负数不创建岸边缓冲区")
return np.zeros_like(water_mask, dtype=np.int32)
# 将水域掩膜转换为二值图像
water_binary = (water_mask > 0).astype(np.int32)
# 创建结构元素(方形结构元素)
# 结构元素大小由buffer_size决定确保是奇数
structure_size = buffer_size * 2 + 1
structure = np.ones((structure_size, structure_size), dtype=np.int32)
# 对水域进行腐蚀,得到缩小后的水域
# 使用OpenCV替代scipy.ndimage.binary_erosion
eroded_water = cv2.erode(water_binary.astype(np.uint8), structure.astype(np.uint8)).astype(np.int32)
# 岸边缓冲区 = 原始水域 - 腐蚀后的水域
# 这给出了水域边缘向内buffer_size像素宽的缓冲区区域
buffer_mask = (water_binary - eroded_water).astype(np.int32)
buffer_pixels = np.sum(buffer_mask > 0)
print(f"岸边缓冲区: 创建了 {buffer_size} 像素宽的内向缓冲区,共 {buffer_pixels} 个像素")
return buffer_mask
@timeit
def remove_shoreline_buffer(glint_mask, water_mask, buffer_size=5, foreground=1, background=0):
"""
从耀斑掩膜中去除岸边缓冲区内的区域
Args:
glint_mask: 耀斑掩膜数组
water_mask: 水域掩膜数组
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
去除岸边缓冲区后的耀斑掩膜
"""
if buffer_size <= 0:
print("缓冲区大小为0不进行岸边缓冲区去除")
return glint_mask
# 创建岸边缓冲区掩膜
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
# 从耀斑掩膜中去除缓冲区内的区域
cleaned_glint_mask = glint_mask.copy()
cleaned_glint_mask[buffer_mask > 0] = background
removed_pixels = np.sum((glint_mask > 0) & (buffer_mask > 0))
remaining_pixels = np.sum(cleaned_glint_mask > 0)
if removed_pixels > 0:
print(f"岸边缓冲区去除: 从耀斑掩膜中移除了 {removed_pixels} 个岸边向内缓冲区域的像素,"
f"剩余 {remaining_pixels} 个像素")
else:
print(f"岸边缓冲区去除: 缓冲区区域没有耀斑掩膜,无需移除")
return cleaned_glint_mask
@timeit
def filter_large_components(binary_img, max_area=None, foreground=1, background=0):
"""
过滤掉面积超过阈值的连通域
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
Args:
binary_img: 二值化图像
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
如果为None则不进行过滤
foreground: 前景值
background: 背景值
Returns:
过滤后的二值化图像
"""
if max_area is None or max_area <= 0:
return binary_img
# 连通域标记
# 使用OpenCV替代scipy.ndimage.label
binary_for_label = (binary_img == foreground).astype(np.uint8)
num_features, labeled_array, stats, centroids = cv2.connectedComponentsWithStats(binary_for_label, connectivity=8)
if num_features == 0:
print("没有检测到连通域")
return binary_img
# 使用OpenCV返回的stats信息直接获取连通域面积
# stats[:, cv2.CC_STAT_AREA] 包含每个连通域的面积(包括背景)
# 跳过索引0背景的面积从索引1开始获取连通域面积
component_sizes = stats[1:, cv2.CC_STAT_AREA]
# 找出需要保留的连通域(面积 <= max_area
keep_labels = np.where(component_sizes <= max_area)[0] + 1 # +1 因为标签从1开始
# 使用布尔索引一次性过滤(高效方法)
# 创建一个mask标记所有需要保留的连通域
keep_mask = np.isin(labeled_array, keep_labels)
# 创建输出图像
filtered_img = np.zeros_like(binary_img, dtype=binary_img.dtype)
filtered_img[keep_mask] = foreground
# 统计信息
removed_count = num_features - len(keep_labels)
kept_count = len(keep_labels)
total_removed_pixels = np.sum(component_sizes[component_sizes > max_area])
if removed_count > 0:
print(f"连通域面积过滤: 移除了 {removed_count} 个大面积连通域(面积 > {max_area} 像素),"
f"共移除 {total_removed_pixels} 个像素;保留了 {kept_count} 个小面积连通域")
else:
print(f"连通域面积过滤: 所有 {kept_count} 个连通域面积均小于阈值 {max_area},全部保留")
return filtered_img
def find_overexposure_area(img_path, threhold=4095):
# 第一步通过某个像素的光谱找到信号最强的波段
# 根据上步所得的波段号检测过曝区域
pass
def create_water_mask_from_shp(shp_file, reference_raster):
"""
从shp文件创建水体掩膜栅格数组内存中不保存到磁盘
参数:
shp_file: str - shp文件路径
reference_raster: str - 参考栅格文件路径(用于获取空间范围和分辨率)
返回:
numpy.ndarray - 水体掩膜数组
"""
try:
# 打开参考栅格获取空间信息
ref_dataset = gdal.Open(reference_raster)
if ref_dataset is None:
raise ValueError(f"无法打开参考栅格文件: {reference_raster}")
geotransform = ref_dataset.GetGeoTransform()
projection = ref_dataset.GetProjection()
width = ref_dataset.RasterXSize
height = ref_dataset.RasterYSize
# 创建内存中的栅格数据集
mem_driver = gdal.GetDriverByName('MEM')
mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
mask_dataset.SetGeoTransform(geotransform)
mask_dataset.SetProjection(projection)
# 初始化为0
mask_band = mask_dataset.GetRasterBand(1)
mask_band.Fill(0)
# 打开shp文件
shp_dataset = ogr.Open(shp_file)
if shp_dataset is None:
raise ValueError(f"无法打开shp文件: {shp_file}")
layer = shp_dataset.GetLayer()
# 栅格化shp文件
gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
# 读取栅格化结果
water_mask = mask_band.ReadAsArray()
# 清理
ref_dataset = None
mask_dataset = None
shp_dataset = None
return water_mask
except Exception as e:
print(f"创建水体掩膜时发生错误: {str(e)}")
raise
@timeit
def find_severe_glint_area(img_path, water_mask, glint_wave=750, output_path=None,
method='otsu', multi_band_waves=None, **kwargs):
"""
找到严重耀斑区域的主函数
注意对于低反射率数据如水面反射率约0.02),本函数采用了改进的归一化策略:
- 统计方法zscore, percentile, iqr直接使用原始反射率值无需归一化
- Otsu和adaptive方法使用百分位数裁剪拉伸2%-98%分位数),避免极值影响
Args:
img_path: 输入影像路径
water_mask: 水域掩膜路径(支持栅格文件如.dat/.tif或SHP文件如.shp如果为None或空字符串则使用全图进行检测
glint_wave: 用于检测的波长(单个波段方法使用)
output_path: 输出路径
method: 检测方法,可选:
- 'otsu': Otsu阈值分割默认使用百分位数拉伸
- 'zscore': Z-score统计方法直接使用原始反射率
- 'percentile': 百分位数阈值方法(直接使用原始反射率)
- 'iqr': IQR异常值检测直接使用原始反射率
- 'adaptive': 自适应阈值方法(使用百分位数拉伸)
- 'multi_band': 多波段融合方法
multi_band_waves: 多波段方法的波长列表,如[750, 800, 850]
**kwargs: 其他方法特定参数
- z_threshold: Z-score阈值默认2.5
- percentile: 百分位数默认95
- iqr_multiplier: IQR倍数默认1.5
- window_size: 自适应阈值窗口大小默认15
- weights: 多波段方法的权重列表
- sub_method: 多波段方法的子方法('otsu', 'zscore', 'percentile'
- max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉
用于去除岸边、浅水、水华等大面积区域默认None表示不过滤
- buffer_size: 岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜
默认None表示不进行岸边缓冲区去除设置为正整数时启用
Returns:
输出文件路径
"""
if output_path is None:
output_path = append2filename(img_path, "_severe_glint_area")
dataset = gdal.Open(img_path)
num_bands = dataset.RasterCount
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
# 读取水域掩膜如果water_mask为None或空字符串则创建全图掩膜
if water_mask is None or water_mask == "":
print("注意: water_mask为空使用全图进行检测")
data_water_mask = np.ones((im_height, im_width), dtype=np.int32)
else:
# 检查是否为SHP文件
water_mask_lower = water_mask.lower()
if water_mask_lower.endswith('.shp'):
# 直接使用SHP文件在内存中栅格化
print(f"检测到SHP文件正在从 {water_mask} 创建水体掩膜...")
data_water_mask = create_water_mask_from_shp(water_mask, img_path)
else:
# 使用栅格文件
dataset_water_mask = gdal.Open(water_mask)
if dataset_water_mask is None:
raise ValueError(f"无法打开水域掩膜文件: {water_mask}")
data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()
del dataset_water_mask
print(f"使用检测方法: {method}")
# 根据方法选择检测算法
if method == 'multi_band':
if multi_band_waves is None:
# 默认使用几个常见NIR波段
multi_band_waves = [glint_wave, glint_wave + 50, glint_wave + 100]
print(f"多波段方法: 使用默认波长 {multi_band_waves}")
else:
print(f"多波段方法: 使用波长 {multi_band_waves}")
sub_method = kwargs.get('sub_method', 'zscore')
weights = kwargs.get('weights', None)
z_threshold = kwargs.get('z_threshold', 2.5)
percentile = kwargs.get('percentile', 95)
flare_binary = multi_band_glint_detection(
dataset, img_path, data_water_mask, multi_band_waves, weights,
method=sub_method, z_threshold=z_threshold, percentile=percentile
)
else:
# 单波段方法
glint_band_number = find_band_number(glint_wave, img_path)
tmp = dataset.GetRasterBand(glint_band_number + 1)
band_flare = tmp.ReadAsArray().astype(np.float32)
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile, iqr直接使用原始反射率值
# 对于Otsu和adaptive方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
flare_binary = otsu(band_flare_stretch, band_flare_stretch.max() + 1, data_water_mask)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
z_threshold = kwargs.get('z_threshold', 2.5)
flare_binary = zscore_threshold(band_flare, data_water_mask, z_threshold)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
percentile = kwargs.get('percentile', 95)
flare_binary = percentile_threshold(band_flare, data_water_mask, percentile)
elif method == 'iqr':
# IQR方法直接使用原始反射率值
iqr_multiplier = kwargs.get('iqr_multiplier', 1.5)
flare_binary = iqr_outlier_detection(band_flare, data_water_mask, iqr_multiplier)
elif method == 'adaptive':
# 自适应阈值方法需要归一化
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
window_size = kwargs.get('window_size', 15)
percentile = kwargs.get('percentile', 90)
flare_binary = adaptive_threshold(band_flare_stretch, data_water_mask, window_size, percentile)
else:
raise ValueError(f"不支持的方法: {method}。可选方法: otsu, zscore, percentile, iqr, adaptive, multi_band")
# 过滤掉面积超过阈值的连通域(用于去除岸边、浅水、水华等大面积区域)
max_area = kwargs.get('max_area', None)
if max_area is not None and max_area > 0:
print(f"应用连通域面积过滤,最大面积阈值: {max_area} 像素")
flare_binary = filter_large_components(flare_binary, max_area=max_area)
# 去除岸边缓冲区内的耀斑掩膜(用于去除岸边的错误检测)
buffer_size = kwargs.get('buffer_size', None)
if buffer_size is not None and buffer_size > 0:
print(f"应用岸边缓冲区去除,缓冲区大小: {buffer_size} 像素")
flare_binary = remove_shoreline_buffer(flare_binary, data_water_mask, buffer_size=buffer_size)
write_bands(img_path, output_path, flare_binary)
del dataset
return output_path
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
img_path = r"D:\PycharmProjects\0water_rlx\test_data\ref_mosaic_1m_bsq"
parser = argparse.ArgumentParser(
description="此程序通过多种算法分割图像,提取耀斑最严重的区域。"
"支持的算法: otsu, zscore, percentile, iqr, adaptive, multi_band"
)
parser.add_argument('-i1', '--input', type=str, required=True, help='输入影像文件的路径')
parser.add_argument('-i2', '--input_water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
parser.add_argument('-gw', '--glint_wave', type=float, default=750.0,
help='用于提取耀斑严重区域的波段波长(单波段方法使用)')
parser.add_argument('-m', '--method', type=str, default='otsu',
choices=['otsu', 'zscore', 'percentile', 'iqr', 'adaptive', 'multi_band'],
help='检测方法: otsu(默认), zscore, percentile, iqr, adaptive, multi_band')
parser.add_argument('-o', '--output', type=str, help='输出文件的路径')
# 方法特定参数
parser.add_argument('-zt', '--z_threshold', type=float, default=2.5,
help='Z-score方法的阈值默认2.5')
parser.add_argument('-p', '--percentile', type=float, default=95.0,
help='百分位数阈值默认95')
parser.add_argument('-iqr', '--iqr_multiplier', type=float, default=1.5,
help='IQR方法的倍数默认1.5')
parser.add_argument('-ws', '--window_size', type=int, default=15,
help='自适应阈值方法的窗口大小默认15')
parser.add_argument('-mbw', '--multi_band_waves', type=str, default=None,
help='多波段方法的波长列表,用逗号分隔,如: 750,800,850')
parser.add_argument('-sm', '--sub_method', type=str, default='zscore',
choices=['otsu', 'zscore', 'percentile'],
help='多波段方法的子方法默认zscore')
parser.add_argument('-ma', '--max_area', type=int, default=None,
help='最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉,'
'用于去除岸边、浅水、水华等大面积区域默认None表示不过滤')
parser.add_argument('-bs', '--buffer_size', type=int, default=None,
help='岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜'
'默认None表示不进行岸边缓冲区去除设置为正整数时启用')
parser.add_argument('-v', '--verbose', action='store_true', help='启用详细模式')
args = parser.parse_args()
# 解析多波段波长列表
multi_band_waves = None
if args.multi_band_waves:
multi_band_waves = [float(x.strip()) for x in args.multi_band_waves.split(',')]
# 构建kwargs
kwargs = {
'z_threshold': args.z_threshold,
'percentile': args.percentile,
'iqr_multiplier': args.iqr_multiplier,
'window_size': args.window_size,
'sub_method': args.sub_method,
'max_area': args.max_area,
'buffer_size': args.buffer_size
}
find_severe_glint_area(
args.input, args.input_water_mask, args.glint_wave, args.output,
method=args.method, multi_band_waves=multi_band_waves, **kwargs
)