Files
WQ_GUI/src/utils/find_severe_glint_area.py
2026-04-08 15:25:08 +08:00

766 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from src.utils.util import *
from osgeo import gdal, ogr
import argparse
import cv2
def percentile_stretch(img, data_water_mask, lower_percentile=2, upper_percentile=98, output_range=(0, 255)):
"""
使用百分位数裁剪进行归一化,适用于低反射率数据
通过排除极值,更好地利用数据的动态范围
Args:
img: 输入图像数组反射率值通常在0-1之间
data_water_mask: 水域掩膜
lower_percentile: 下百分位数用于裁剪最小值默认2
upper_percentile: 上百分位数用于裁剪最大值默认98
output_range: 输出范围,默认(0, 255)
Returns:
归一化后的图像数组(整数类型)
"""
# 只在水域掩膜区域计算百分位数
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于百分位数计算,使用原始值")
return img.astype(np.int32)
# 计算百分位数
p_lower = np.percentile(valid_pixels, lower_percentile)
p_upper = np.percentile(valid_pixels, upper_percentile)
# 如果上下界相同,使用最大值作为上界
if p_lower >= p_upper:
p_lower = np.percentile(valid_pixels, 1)
p_upper = np.percentile(valid_pixels, 99)
if p_lower >= p_upper:
p_upper = valid_pixels.max()
p_lower = valid_pixels.min()
print(f"百分位数拉伸: {lower_percentile}%={p_lower:.6f}, {upper_percentile}%={p_upper:.6f}, "
f"数据范围=[{img.min():.6f}, {img.max():.6f}]")
# 裁剪到百分位数范围
img_clipped = np.clip(img, p_lower, p_upper)
# 线性拉伸到输出范围
if p_upper > p_lower:
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (output_range[1] - output_range[0]) + output_range[0]
else:
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
return img_stretched.astype(np.int32)
@timeit
def otsu(img, max_value, data_water_mask, ignore_value=0, foreground=1, background=0):
height = img.shape[0]
width = img.shape[1]
hist = np.zeros([max_value], np.float32)
# 计算直方图
invalid_counter = 0
for i in range(height):
for j in range(width):
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
invalid_counter = invalid_counter + 1
continue
hist[img[i, j]] += 1
hist /= (height * width - invalid_counter)
threshold = 0
deltaMax = 0
# 遍历像素值,计算最大类间方差
for i in range(max_value):
wA = 0
wB = 0
uAtmp = 0
uBtmp = 0
uA = 0
uB = 0
u = 0
for j in range(max_value):
if j <= i:
wA += hist[j]
uAtmp += j * hist[j]
else:
wB += hist[j]
uBtmp += j * hist[j]
if wA == 0:
wA = 1e-10
if wB == 0:
wB = 1e-10
uA = uAtmp / wA
uB = uBtmp / wB
u = uAtmp + uBtmp
# 计算类间方差
deltaTmp = wA * ((uA - u)**2) + wB * ((uB - u)**2)
# 找出最大类间方差以及阈值
if deltaTmp > deltaMax:
deltaMax = deltaTmp
threshold = i
# 二值化
det_img = img.copy()
det_img[img > threshold] = foreground
det_img[img <= threshold] = background
det_img[np.where(data_water_mask == 0)] = background
return det_img
@timeit
def zscore_threshold(img, data_water_mask, z_threshold=2.5, foreground=1, background=0):
"""
基于Z-score标准化分数的耀斑检测方法
使用统计方法识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
z_threshold: Z-score阈值默认2.5即超过均值2.5个标准差)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
mean_val = np.mean(valid_pixels)
std_val = np.std(valid_pixels)
if std_val == 0:
print("警告: 标准差为0无法使用Z-score方法")
return np.zeros_like(img, dtype=np.int32)
# 计算Z-score对无效值进行保护
z_scores = np.zeros_like(img, dtype=np.float32)
valid_mask = (data_water_mask > 0) & np.isfinite(img)
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[z_scores > z_threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"Z-score方法: 均值={mean_val:.2f}, 标准差={std_val:.2f}, 阈值={mean_val + z_threshold * std_val:.2f}")
return det_img
@timeit
def percentile_threshold(img, data_water_mask, percentile=95, foreground=1, background=0):
"""
基于百分位数的耀斑检测方法
使用百分位数作为阈值,对异常值更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
percentile: 百分位数阈值默认95即超过95%的像素值)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算百分位数,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
threshold = np.percentile(valid_pixels, percentile)
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"百分位数方法: {percentile}%分位数为 {threshold:.2f}")
return det_img
@timeit
def multi_band_glint_detection(dataset, img_path, water_mask, glint_waves, weights=None, method='zscore',
z_threshold=2.5, percentile=95, foreground=1, background=0):
"""
多波段融合的耀斑检测方法
结合多个波段的耀斑特征,提高检测的稳健性
Args:
dataset: GDAL数据集
img_path: 影像文件路径(用于获取波长信息)
water_mask: 水域掩膜数组
glint_waves: 用于检测的波长列表,如[750, 800, 850]
weights: 各波段的权重如果为None则使用等权重
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
z_threshold: Z-score阈值当method='zscore'时使用)
percentile: 百分位数阈值当method='percentile'时使用)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
num_bands = dataset.RasterCount
if weights is None:
weights = [1.0 / len(glint_waves)] * len(glint_waves)
if len(weights) != len(glint_waves):
raise ValueError("权重数量必须与波长数量相同")
# 读取多个波段并加权融合使用float32保持精度
fused_band = None
for i, wave in enumerate(glint_waves):
band_num = find_band_number(wave, img_path)
if band_num >= num_bands:
print(f"警告: 波段号 {band_num} 超出范围,跳过波长 {wave}")
continue
tmp = dataset.GetRasterBand(band_num + 1).ReadAsArray().astype(np.float32)
if fused_band is None:
fused_band = (tmp * weights[i]).astype(np.float32)
else:
fused_band = (fused_band + tmp * weights[i]).astype(np.float32)
if fused_band is None:
raise ValueError("没有有效的波段可以融合")
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile直接使用原始反射率值
# 对于Otsu方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
fused_band_stretch = percentile_stretch(fused_band, water_mask,
lower_percentile=2, upper_percentile=98)
return otsu(fused_band_stretch, fused_band_stretch.max() + 1, water_mask,
foreground=foreground, background=background)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
return zscore_threshold(fused_band, water_mask, z_threshold, foreground, background)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
return percentile_threshold(fused_band, water_mask, percentile, foreground, background)
else:
raise ValueError(f"不支持的方法: {method}")
@timeit
def adaptive_threshold(img, data_water_mask, window_size=15, percentile=90, foreground=1, background=0):
"""
自适应阈值方法
基于局部统计特性进行阈值分割,对光照变化更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
window_size: 局部窗口大小(奇数)
percentile: 局部百分位数阈值
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
height, width = img.shape
# 确保窗口大小为奇数
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
# 创建输出图像
det_img = np.zeros_like(img, dtype=np.int32)
# 对每个像素计算局部阈值
for i in range(half_window, height - half_window):
for j in range(half_window, width - half_window):
# 只在水域掩膜内处理
if data_water_mask[i, j] == 0:
continue
# 提取局部窗口
local_window = img[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
local_mask = data_water_mask[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
# 只考虑有效像素
valid_pixels = local_window[local_mask > 0]
if len(valid_pixels) > 0:
local_threshold = np.percentile(valid_pixels, percentile)
if img[i, j] > local_threshold:
det_img[i, j] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"自适应阈值方法: 窗口大小={window_size}, 局部百分位数={percentile}%")
return det_img
@timeit
def iqr_outlier_detection(img, data_water_mask, iqr_multiplier=1.5, foreground=1, background=0):
"""
基于IQR四分位距的异常值检测方法
使用四分位距识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
iqr_multiplier: IQR倍数默认1.5(标准异常值检测)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
q1 = np.percentile(valid_pixels, 25)
q3 = np.percentile(valid_pixels, 75)
iqr = q3 - q1
# 上界 = Q3 + 1.5 * IQR
upper_bound = q3 + iqr_multiplier * iqr
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > upper_bound] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"IQR方法: Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}, 上界={upper_bound:.2f}")
return det_img
@timeit
def create_shoreline_buffer(water_mask, buffer_size=5, foreground=1, background=0):
"""
创建岸边缓冲区掩膜(向内缓冲)
用于去除岸边附近的错误耀斑检测区域
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
Args:
water_mask: 水域掩膜数组(水域=1非水域=0
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
岸边缓冲区掩膜(缓冲区区域=1其他=0
"""
if buffer_size <= 0:
print("缓冲区大小为0或负数不创建岸边缓冲区")
return np.zeros_like(water_mask, dtype=np.int32)
# 将水域掩膜转换为二值图像
water_binary = (water_mask > 0).astype(np.int32)
# 创建结构元素(方形结构元素)
# 结构元素大小由buffer_size决定确保是奇数
structure_size = buffer_size * 2 + 1
structure = np.ones((structure_size, structure_size), dtype=np.int32)
# 对水域进行腐蚀,得到缩小后的水域
# 使用OpenCV替代scipy.ndimage.binary_erosion
eroded_water = cv2.erode(water_binary.astype(np.uint8), structure.astype(np.uint8)).astype(np.int32)
# 岸边缓冲区 = 原始水域 - 腐蚀后的水域
# 这给出了水域边缘向内buffer_size像素宽的缓冲区区域
buffer_mask = (water_binary - eroded_water).astype(np.int32)
buffer_pixels = np.sum(buffer_mask > 0)
print(f"岸边缓冲区: 创建了 {buffer_size} 像素宽的内向缓冲区,共 {buffer_pixels} 个像素")
return buffer_mask
@timeit
def remove_shoreline_buffer(glint_mask, water_mask, buffer_size=5, foreground=1, background=0):
"""
从耀斑掩膜中去除岸边缓冲区内的区域
Args:
glint_mask: 耀斑掩膜数组
water_mask: 水域掩膜数组
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
去除岸边缓冲区后的耀斑掩膜
"""
if buffer_size <= 0:
print("缓冲区大小为0不进行岸边缓冲区去除")
return glint_mask
# 创建岸边缓冲区掩膜
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
# 从耀斑掩膜中去除缓冲区内的区域
cleaned_glint_mask = glint_mask.copy()
cleaned_glint_mask[buffer_mask > 0] = background
removed_pixels = np.sum((glint_mask > 0) & (buffer_mask > 0))
remaining_pixels = np.sum(cleaned_glint_mask > 0)
if removed_pixels > 0:
print(f"岸边缓冲区去除: 从耀斑掩膜中移除了 {removed_pixels} 个岸边向内缓冲区域的像素,"
f"剩余 {remaining_pixels} 个像素")
else:
print(f"岸边缓冲区去除: 缓冲区区域没有耀斑掩膜,无需移除")
return cleaned_glint_mask
@timeit
def filter_large_components(binary_img, max_area=None, foreground=1, background=0):
"""
过滤掉面积超过阈值的连通域
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
Args:
binary_img: 二值化图像
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
如果为None则不进行过滤
foreground: 前景值
background: 背景值
Returns:
过滤后的二值化图像
"""
if max_area is None or max_area <= 0:
return binary_img
# 连通域标记
# 使用OpenCV替代scipy.ndimage.label
binary_for_label = (binary_img == foreground).astype(np.uint8)
num_features, labeled_array, stats, centroids = cv2.connectedComponentsWithStats(binary_for_label, connectivity=8)
if num_features == 0:
print("没有检测到连通域")
return binary_img
# 使用OpenCV返回的stats信息直接获取连通域面积
# stats[:, cv2.CC_STAT_AREA] 包含每个连通域的面积(包括背景)
# 跳过索引0背景的面积从索引1开始获取连通域面积
component_sizes = stats[1:, cv2.CC_STAT_AREA]
# 找出需要保留的连通域(面积 <= max_area
keep_labels = np.where(component_sizes <= max_area)[0] + 1 # +1 因为标签从1开始
# 使用布尔索引一次性过滤(高效方法)
# 创建一个mask标记所有需要保留的连通域
keep_mask = np.isin(labeled_array, keep_labels)
# 创建输出图像
filtered_img = np.zeros_like(binary_img, dtype=binary_img.dtype)
filtered_img[keep_mask] = foreground
# 统计信息
removed_count = num_features - len(keep_labels)
kept_count = len(keep_labels)
total_removed_pixels = np.sum(component_sizes[component_sizes > max_area])
if removed_count > 0:
print(f"连通域面积过滤: 移除了 {removed_count} 个大面积连通域(面积 > {max_area} 像素),"
f"共移除 {total_removed_pixels} 个像素;保留了 {kept_count} 个小面积连通域")
else:
print(f"连通域面积过滤: 所有 {kept_count} 个连通域面积均小于阈值 {max_area},全部保留")
return filtered_img
def find_overexposure_area(img_path, threhold=4095):
# 第一步通过某个像素的光谱找到信号最强的波段
# 根据上步所得的波段号检测过曝区域
pass
def create_water_mask_from_shp(shp_file, reference_raster):
"""
从shp文件创建水体掩膜栅格数组内存中不保存到磁盘
参数:
shp_file: str - shp文件路径
reference_raster: str - 参考栅格文件路径(用于获取空间范围和分辨率)
返回:
numpy.ndarray - 水体掩膜数组
"""
try:
# 打开参考栅格获取空间信息
ref_dataset = gdal.Open(reference_raster)
if ref_dataset is None:
raise ValueError(f"无法打开参考栅格文件: {reference_raster}")
geotransform = ref_dataset.GetGeoTransform()
projection = ref_dataset.GetProjection()
width = ref_dataset.RasterXSize
height = ref_dataset.RasterYSize
# 创建内存中的栅格数据集
mem_driver = gdal.GetDriverByName('MEM')
mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
mask_dataset.SetGeoTransform(geotransform)
mask_dataset.SetProjection(projection)
# 初始化为0
mask_band = mask_dataset.GetRasterBand(1)
mask_band.Fill(0)
# 打开shp文件
shp_dataset = ogr.Open(shp_file)
if shp_dataset is None:
raise ValueError(f"无法打开shp文件: {shp_file}")
layer = shp_dataset.GetLayer()
# 栅格化shp文件
gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
# 读取栅格化结果
water_mask = mask_band.ReadAsArray()
# 清理
ref_dataset = None
mask_dataset = None
shp_dataset = None
return water_mask
except Exception as e:
print(f"创建水体掩膜时发生错误: {str(e)}")
raise
@timeit
def find_severe_glint_area(img_path, water_mask, glint_wave=750, output_path=None,
method='otsu', multi_band_waves=None, **kwargs):
"""
找到严重耀斑区域的主函数
注意对于低反射率数据如水面反射率约0.02),本函数采用了改进的归一化策略:
- 统计方法zscore, percentile, iqr直接使用原始反射率值无需归一化
- Otsu和adaptive方法使用百分位数裁剪拉伸2%-98%分位数),避免极值影响
Args:
img_path: 输入影像路径
water_mask: 水域掩膜路径(支持栅格文件如.dat/.tif或SHP文件如.shp如果为None或空字符串则使用全图进行检测
glint_wave: 用于检测的波长(单个波段方法使用)
output_path: 输出路径
method: 检测方法,可选:
- 'otsu': Otsu阈值分割默认使用百分位数拉伸
- 'zscore': Z-score统计方法直接使用原始反射率
- 'percentile': 百分位数阈值方法(直接使用原始反射率)
- 'iqr': IQR异常值检测直接使用原始反射率
- 'adaptive': 自适应阈值方法(使用百分位数拉伸)
- 'multi_band': 多波段融合方法
multi_band_waves: 多波段方法的波长列表,如[750, 800, 850]
**kwargs: 其他方法特定参数
- z_threshold: Z-score阈值默认2.5
- percentile: 百分位数默认95
- iqr_multiplier: IQR倍数默认1.5
- window_size: 自适应阈值窗口大小默认15
- weights: 多波段方法的权重列表
- sub_method: 多波段方法的子方法('otsu', 'zscore', 'percentile'
- max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉
用于去除岸边、浅水、水华等大面积区域默认None表示不过滤
- buffer_size: 岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜
默认None表示不进行岸边缓冲区去除设置为正整数时启用
Returns:
输出文件路径
"""
if output_path is None:
output_path = append2filename(img_path, "_severe_glint_area")
dataset = gdal.Open(img_path)
num_bands = dataset.RasterCount
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
# 读取水域掩膜如果water_mask为None或空字符串则创建全图掩膜
if water_mask is None or water_mask == "":
print("注意: water_mask为空使用全图进行检测")
data_water_mask = np.ones((im_height, im_width), dtype=np.int32)
else:
# 检查是否为SHP文件
water_mask_lower = water_mask.lower()
if water_mask_lower.endswith('.shp'):
# 直接使用SHP文件在内存中栅格化
print(f"检测到SHP文件正在从 {water_mask} 创建水体掩膜...")
data_water_mask = create_water_mask_from_shp(water_mask, img_path)
else:
# 使用栅格文件
dataset_water_mask = gdal.Open(water_mask)
if dataset_water_mask is None:
raise ValueError(f"无法打开水域掩膜文件: {water_mask}")
data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()
del dataset_water_mask
print(f"使用检测方法: {method}")
# 根据方法选择检测算法
if method == 'multi_band':
if multi_band_waves is None:
# 默认使用几个常见NIR波段
multi_band_waves = [glint_wave, glint_wave + 50, glint_wave + 100]
print(f"多波段方法: 使用默认波长 {multi_band_waves}")
else:
print(f"多波段方法: 使用波长 {multi_band_waves}")
sub_method = kwargs.get('sub_method', 'zscore')
weights = kwargs.get('weights', None)
z_threshold = kwargs.get('z_threshold', 2.5)
percentile = kwargs.get('percentile', 95)
flare_binary = multi_band_glint_detection(
dataset, img_path, data_water_mask, multi_band_waves, weights,
method=sub_method, z_threshold=z_threshold, percentile=percentile
)
else:
# 单波段方法
glint_band_number = find_band_number(glint_wave, img_path)
tmp = dataset.GetRasterBand(glint_band_number + 1)
band_flare = tmp.ReadAsArray().astype(np.float32)
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile, iqr直接使用原始反射率值
# 对于Otsu和adaptive方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
flare_binary = otsu(band_flare_stretch, band_flare_stretch.max() + 1, data_water_mask)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
z_threshold = kwargs.get('z_threshold', 2.5)
flare_binary = zscore_threshold(band_flare, data_water_mask, z_threshold)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
percentile = kwargs.get('percentile', 95)
flare_binary = percentile_threshold(band_flare, data_water_mask, percentile)
elif method == 'iqr':
# IQR方法直接使用原始反射率值
iqr_multiplier = kwargs.get('iqr_multiplier', 1.5)
flare_binary = iqr_outlier_detection(band_flare, data_water_mask, iqr_multiplier)
elif method == 'adaptive':
# 自适应阈值方法需要归一化
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
window_size = kwargs.get('window_size', 15)
percentile = kwargs.get('percentile', 90)
flare_binary = adaptive_threshold(band_flare_stretch, data_water_mask, window_size, percentile)
else:
raise ValueError(f"不支持的方法: {method}。可选方法: otsu, zscore, percentile, iqr, adaptive, multi_band")
# 过滤掉面积超过阈值的连通域(用于去除岸边、浅水、水华等大面积区域)
max_area = kwargs.get('max_area', None)
if max_area is not None and max_area > 0:
print(f"应用连通域面积过滤,最大面积阈值: {max_area} 像素")
flare_binary = filter_large_components(flare_binary, max_area=max_area)
# 去除岸边缓冲区内的耀斑掩膜(用于去除岸边的错误检测)
buffer_size = kwargs.get('buffer_size', None)
if buffer_size is not None and buffer_size > 0:
print(f"应用岸边缓冲区去除,缓冲区大小: {buffer_size} 像素")
flare_binary = remove_shoreline_buffer(flare_binary, data_water_mask, buffer_size=buffer_size)
write_bands(img_path, output_path, flare_binary)
del dataset
return output_path
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
img_path = r"D:\PycharmProjects\0water_rlx\test_data\ref_mosaic_1m_bsq"
parser = argparse.ArgumentParser(
description="此程序通过多种算法分割图像,提取耀斑最严重的区域。"
"支持的算法: otsu, zscore, percentile, iqr, adaptive, multi_band"
)
parser.add_argument('-i1', '--input', type=str, required=True, help='输入影像文件的路径')
parser.add_argument('-i2', '--input_water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
parser.add_argument('-gw', '--glint_wave', type=float, default=750.0,
help='用于提取耀斑严重区域的波段波长(单波段方法使用)')
parser.add_argument('-m', '--method', type=str, default='otsu',
choices=['otsu', 'zscore', 'percentile', 'iqr', 'adaptive', 'multi_band'],
help='检测方法: otsu(默认), zscore, percentile, iqr, adaptive, multi_band')
parser.add_argument('-o', '--output', type=str, help='输出文件的路径')
# 方法特定参数
parser.add_argument('-zt', '--z_threshold', type=float, default=2.5,
help='Z-score方法的阈值默认2.5')
parser.add_argument('-p', '--percentile', type=float, default=95.0,
help='百分位数阈值默认95')
parser.add_argument('-iqr', '--iqr_multiplier', type=float, default=1.5,
help='IQR方法的倍数默认1.5')
parser.add_argument('-ws', '--window_size', type=int, default=15,
help='自适应阈值方法的窗口大小默认15')
parser.add_argument('-mbw', '--multi_band_waves', type=str, default=None,
help='多波段方法的波长列表,用逗号分隔,如: 750,800,850')
parser.add_argument('-sm', '--sub_method', type=str, default='zscore',
choices=['otsu', 'zscore', 'percentile'],
help='多波段方法的子方法默认zscore')
parser.add_argument('-ma', '--max_area', type=int, default=None,
help='最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉,'
'用于去除岸边、浅水、水华等大面积区域默认None表示不过滤')
parser.add_argument('-bs', '--buffer_size', type=int, default=None,
help='岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜'
'默认None表示不进行岸边缓冲区去除设置为正整数时启用')
parser.add_argument('-v', '--verbose', action='store_true', help='启用详细模式')
args = parser.parse_args()
# 解析多波段波长列表
multi_band_waves = None
if args.multi_band_waves:
multi_band_waves = [float(x.strip()) for x in args.multi_band_waves.split(',')]
# 构建kwargs
kwargs = {
'z_threshold': args.z_threshold,
'percentile': args.percentile,
'iqr_multiplier': args.iqr_multiplier,
'window_size': args.window_size,
'sub_method': args.sub_method,
'max_area': args.max_area,
'buffer_size': args.buffer_size
}
find_severe_glint_area(
args.input, args.input_water_mask, args.glint_wave, args.output,
method=args.method, multi_band_waves=multi_band_waves, **kwargs
)