766 lines
30 KiB
Python
766 lines
30 KiB
Python
from src.utils.util import *
|
||
from osgeo import gdal, ogr
|
||
import argparse
|
||
import cv2
|
||
|
||
|
||
|
||
def percentile_stretch(img, data_water_mask, lower_percentile=2, upper_percentile=98, output_range=(0, 255)):
|
||
"""
|
||
使用百分位数裁剪进行归一化,适用于低反射率数据
|
||
通过排除极值,更好地利用数据的动态范围
|
||
|
||
Args:
|
||
img: 输入图像数组(反射率值,通常在0-1之间)
|
||
data_water_mask: 水域掩膜
|
||
lower_percentile: 下百分位数,用于裁剪最小值(默认2)
|
||
upper_percentile: 上百分位数,用于裁剪最大值(默认98)
|
||
output_range: 输出范围,默认(0, 255)
|
||
|
||
Returns:
|
||
归一化后的图像数组(整数类型)
|
||
"""
|
||
# 只在水域掩膜区域计算百分位数
|
||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||
|
||
if len(valid_pixels) == 0:
|
||
print("警告: 没有有效像素用于百分位数计算,使用原始值")
|
||
return img.astype(np.int32)
|
||
|
||
# 计算百分位数
|
||
p_lower = np.percentile(valid_pixels, lower_percentile)
|
||
p_upper = np.percentile(valid_pixels, upper_percentile)
|
||
|
||
# 如果上下界相同,使用最大值作为上界
|
||
if p_lower >= p_upper:
|
||
p_lower = np.percentile(valid_pixels, 1)
|
||
p_upper = np.percentile(valid_pixels, 99)
|
||
if p_lower >= p_upper:
|
||
p_upper = valid_pixels.max()
|
||
p_lower = valid_pixels.min()
|
||
|
||
print(f"百分位数拉伸: {lower_percentile}%={p_lower:.6f}, {upper_percentile}%={p_upper:.6f}, "
|
||
f"数据范围=[{img.min():.6f}, {img.max():.6f}]")
|
||
|
||
# 裁剪到百分位数范围
|
||
img_clipped = np.clip(img, p_lower, p_upper)
|
||
|
||
# 线性拉伸到输出范围
|
||
if p_upper > p_lower:
|
||
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (output_range[1] - output_range[0]) + output_range[0]
|
||
else:
|
||
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
|
||
|
||
return img_stretched.astype(np.int32)
|
||
|
||
|
||
@timeit
|
||
def otsu(img, max_value, data_water_mask, ignore_value=0, foreground=1, background=0):
|
||
height = img.shape[0]
|
||
width = img.shape[1]
|
||
|
||
hist = np.zeros([max_value], np.float32)
|
||
|
||
# 计算直方图
|
||
invalid_counter = 0
|
||
for i in range(height):
|
||
for j in range(width):
|
||
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
|
||
invalid_counter = invalid_counter + 1
|
||
continue
|
||
|
||
hist[img[i, j]] += 1
|
||
|
||
hist /= (height * width - invalid_counter)
|
||
|
||
threshold = 0
|
||
deltaMax = 0
|
||
# 遍历像素值,计算最大类间方差
|
||
for i in range(max_value):
|
||
wA = 0
|
||
wB = 0
|
||
uAtmp = 0
|
||
uBtmp = 0
|
||
uA = 0
|
||
uB = 0
|
||
u = 0
|
||
for j in range(max_value):
|
||
if j <= i:
|
||
wA += hist[j]
|
||
uAtmp += j * hist[j]
|
||
else:
|
||
wB += hist[j]
|
||
uBtmp += j * hist[j]
|
||
if wA == 0:
|
||
wA = 1e-10
|
||
if wB == 0:
|
||
wB = 1e-10
|
||
uA = uAtmp / wA
|
||
uB = uBtmp / wB
|
||
u = uAtmp + uBtmp
|
||
|
||
# 计算类间方差
|
||
deltaTmp = wA * ((uA - u)**2) + wB * ((uB - u)**2)
|
||
# 找出最大类间方差以及阈值
|
||
if deltaTmp > deltaMax:
|
||
deltaMax = deltaTmp
|
||
threshold = i
|
||
|
||
# 二值化
|
||
det_img = img.copy()
|
||
det_img[img > threshold] = foreground
|
||
det_img[img <= threshold] = background
|
||
det_img[np.where(data_water_mask == 0)] = background
|
||
return det_img
|
||
|
||
|
||
@timeit
|
||
def zscore_threshold(img, data_water_mask, z_threshold=2.5, foreground=1, background=0):
|
||
"""
|
||
基于Z-score(标准化分数)的耀斑检测方法
|
||
使用统计方法识别异常高亮的像素,对数据分布不敏感
|
||
|
||
Args:
|
||
img: 输入图像数组
|
||
data_water_mask: 水域掩膜
|
||
z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差)
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
二值化检测结果
|
||
"""
|
||
# 只在水域掩膜区域计算统计量,排除无效值
|
||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||
|
||
if len(valid_pixels) == 0:
|
||
print("警告: 没有有效像素用于统计计算")
|
||
return np.zeros_like(img, dtype=np.int32)
|
||
|
||
mean_val = np.mean(valid_pixels)
|
||
std_val = np.std(valid_pixels)
|
||
|
||
if std_val == 0:
|
||
print("警告: 标准差为0,无法使用Z-score方法")
|
||
return np.zeros_like(img, dtype=np.int32)
|
||
|
||
# 计算Z-score(对无效值进行保护)
|
||
z_scores = np.zeros_like(img, dtype=np.float32)
|
||
valid_mask = (data_water_mask > 0) & np.isfinite(img)
|
||
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
|
||
|
||
# 二值化
|
||
det_img = np.zeros_like(img, dtype=np.int32)
|
||
det_img[z_scores > z_threshold] = foreground
|
||
det_img[np.where(data_water_mask == 0)] = background
|
||
|
||
print(f"Z-score方法: 均值={mean_val:.2f}, 标准差={std_val:.2f}, 阈值={mean_val + z_threshold * std_val:.2f}")
|
||
|
||
return det_img
|
||
|
||
|
||
@timeit
|
||
def percentile_threshold(img, data_water_mask, percentile=95, foreground=1, background=0):
|
||
"""
|
||
基于百分位数的耀斑检测方法
|
||
使用百分位数作为阈值,对异常值更稳健
|
||
|
||
Args:
|
||
img: 输入图像数组
|
||
data_water_mask: 水域掩膜
|
||
percentile: 百分位数阈值,默认95(即超过95%的像素值)
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
二值化检测结果
|
||
"""
|
||
# 只在水域掩膜区域计算百分位数,排除无效值
|
||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||
|
||
if len(valid_pixels) == 0:
|
||
print("警告: 没有有效像素用于统计计算")
|
||
return np.zeros_like(img, dtype=np.int32)
|
||
|
||
threshold = np.percentile(valid_pixels, percentile)
|
||
|
||
# 二值化
|
||
det_img = np.zeros_like(img, dtype=np.int32)
|
||
det_img[img > threshold] = foreground
|
||
det_img[np.where(data_water_mask == 0)] = background
|
||
|
||
print(f"百分位数方法: {percentile}%分位数为 {threshold:.2f}")
|
||
|
||
return det_img
|
||
|
||
|
||
@timeit
|
||
def multi_band_glint_detection(dataset, img_path, water_mask, glint_waves, weights=None, method='zscore',
|
||
z_threshold=2.5, percentile=95, foreground=1, background=0):
|
||
"""
|
||
多波段融合的耀斑检测方法
|
||
结合多个波段的耀斑特征,提高检测的稳健性
|
||
|
||
Args:
|
||
dataset: GDAL数据集
|
||
img_path: 影像文件路径(用于获取波长信息)
|
||
water_mask: 水域掩膜数组
|
||
glint_waves: 用于检测的波长列表,如[750, 800, 850]
|
||
weights: 各波段的权重,如果为None则使用等权重
|
||
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
|
||
z_threshold: Z-score阈值(当method='zscore'时使用)
|
||
percentile: 百分位数阈值(当method='percentile'时使用)
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
二值化检测结果
|
||
"""
|
||
num_bands = dataset.RasterCount
|
||
|
||
if weights is None:
|
||
weights = [1.0 / len(glint_waves)] * len(glint_waves)
|
||
|
||
if len(weights) != len(glint_waves):
|
||
raise ValueError("权重数量必须与波长数量相同")
|
||
|
||
# 读取多个波段并加权融合(使用float32保持精度)
|
||
fused_band = None
|
||
for i, wave in enumerate(glint_waves):
|
||
band_num = find_band_number(wave, img_path)
|
||
if band_num >= num_bands:
|
||
print(f"警告: 波段号 {band_num} 超出范围,跳过波长 {wave}")
|
||
continue
|
||
|
||
tmp = dataset.GetRasterBand(band_num + 1).ReadAsArray().astype(np.float32)
|
||
|
||
if fused_band is None:
|
||
fused_band = (tmp * weights[i]).astype(np.float32)
|
||
else:
|
||
fused_band = (fused_band + tmp * weights[i]).astype(np.float32)
|
||
|
||
if fused_band is None:
|
||
raise ValueError("没有有效的波段可以融合")
|
||
|
||
# 根据方法选择是否需要归一化
|
||
# 对于统计方法(zscore, percentile),直接使用原始反射率值
|
||
# 对于Otsu方法,需要归一化到整数范围
|
||
if method == 'otsu':
|
||
# Otsu方法需要整数范围,使用百分位数拉伸
|
||
fused_band_stretch = percentile_stretch(fused_band, water_mask,
|
||
lower_percentile=2, upper_percentile=98)
|
||
return otsu(fused_band_stretch, fused_band_stretch.max() + 1, water_mask,
|
||
foreground=foreground, background=background)
|
||
elif method == 'zscore':
|
||
# Z-score方法直接使用原始反射率值
|
||
return zscore_threshold(fused_band, water_mask, z_threshold, foreground, background)
|
||
elif method == 'percentile':
|
||
# 百分位数方法直接使用原始反射率值
|
||
return percentile_threshold(fused_band, water_mask, percentile, foreground, background)
|
||
else:
|
||
raise ValueError(f"不支持的方法: {method}")
|
||
|
||
|
||
@timeit
|
||
def adaptive_threshold(img, data_water_mask, window_size=15, percentile=90, foreground=1, background=0):
|
||
"""
|
||
自适应阈值方法
|
||
基于局部统计特性进行阈值分割,对光照变化更稳健
|
||
|
||
Args:
|
||
img: 输入图像数组
|
||
data_water_mask: 水域掩膜
|
||
window_size: 局部窗口大小(奇数)
|
||
percentile: 局部百分位数阈值
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
二值化检测结果
|
||
"""
|
||
height, width = img.shape
|
||
|
||
# 确保窗口大小为奇数
|
||
if window_size % 2 == 0:
|
||
window_size += 1
|
||
|
||
half_window = window_size // 2
|
||
|
||
# 创建输出图像
|
||
det_img = np.zeros_like(img, dtype=np.int32)
|
||
|
||
# 对每个像素计算局部阈值
|
||
for i in range(half_window, height - half_window):
|
||
for j in range(half_window, width - half_window):
|
||
# 只在水域掩膜内处理
|
||
if data_water_mask[i, j] == 0:
|
||
continue
|
||
|
||
# 提取局部窗口
|
||
local_window = img[i - half_window:i + half_window + 1,
|
||
j - half_window:j + half_window + 1]
|
||
local_mask = data_water_mask[i - half_window:i + half_window + 1,
|
||
j - half_window:j + half_window + 1]
|
||
|
||
# 只考虑有效像素
|
||
valid_pixels = local_window[local_mask > 0]
|
||
|
||
if len(valid_pixels) > 0:
|
||
local_threshold = np.percentile(valid_pixels, percentile)
|
||
if img[i, j] > local_threshold:
|
||
det_img[i, j] = foreground
|
||
|
||
det_img[np.where(data_water_mask == 0)] = background
|
||
|
||
print(f"自适应阈值方法: 窗口大小={window_size}, 局部百分位数={percentile}%")
|
||
|
||
return det_img
|
||
|
||
|
||
@timeit
|
||
def iqr_outlier_detection(img, data_water_mask, iqr_multiplier=1.5, foreground=1, background=0):
|
||
"""
|
||
基于IQR(四分位距)的异常值检测方法
|
||
使用四分位距识别异常高亮的像素,对数据分布不敏感
|
||
|
||
Args:
|
||
img: 输入图像数组
|
||
data_water_mask: 水域掩膜
|
||
iqr_multiplier: IQR倍数,默认1.5(标准异常值检测)
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
二值化检测结果
|
||
"""
|
||
# 只在水域掩膜区域计算统计量,排除无效值
|
||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||
|
||
if len(valid_pixels) == 0:
|
||
print("警告: 没有有效像素用于统计计算")
|
||
return np.zeros_like(img, dtype=np.int32)
|
||
|
||
q1 = np.percentile(valid_pixels, 25)
|
||
q3 = np.percentile(valid_pixels, 75)
|
||
iqr = q3 - q1
|
||
|
||
# 上界 = Q3 + 1.5 * IQR
|
||
upper_bound = q3 + iqr_multiplier * iqr
|
||
|
||
# 二值化
|
||
det_img = np.zeros_like(img, dtype=np.int32)
|
||
det_img[img > upper_bound] = foreground
|
||
det_img[np.where(data_water_mask == 0)] = background
|
||
|
||
print(f"IQR方法: Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}, 上界={upper_bound:.2f}")
|
||
|
||
return det_img
|
||
|
||
|
||
@timeit
|
||
def create_shoreline_buffer(water_mask, buffer_size=5, foreground=1, background=0):
|
||
"""
|
||
创建岸边缓冲区掩膜(向内缓冲)
|
||
用于去除岸边附近的错误耀斑检测区域
|
||
|
||
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
|
||
|
||
Args:
|
||
water_mask: 水域掩膜数组(水域=1,非水域=0)
|
||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
岸边缓冲区掩膜(缓冲区区域=1,其他=0)
|
||
"""
|
||
if buffer_size <= 0:
|
||
print("缓冲区大小为0或负数,不创建岸边缓冲区")
|
||
return np.zeros_like(water_mask, dtype=np.int32)
|
||
|
||
# 将水域掩膜转换为二值图像
|
||
water_binary = (water_mask > 0).astype(np.int32)
|
||
|
||
# 创建结构元素(方形结构元素)
|
||
# 结构元素大小由buffer_size决定,确保是奇数
|
||
structure_size = buffer_size * 2 + 1
|
||
structure = np.ones((structure_size, structure_size), dtype=np.int32)
|
||
|
||
# 对水域进行腐蚀,得到缩小后的水域
|
||
# 使用OpenCV替代scipy.ndimage.binary_erosion
|
||
eroded_water = cv2.erode(water_binary.astype(np.uint8), structure.astype(np.uint8)).astype(np.int32)
|
||
|
||
# 岸边缓冲区 = 原始水域 - 腐蚀后的水域
|
||
# 这给出了水域边缘向内buffer_size像素宽的缓冲区区域
|
||
buffer_mask = (water_binary - eroded_water).astype(np.int32)
|
||
|
||
buffer_pixels = np.sum(buffer_mask > 0)
|
||
print(f"岸边缓冲区: 创建了 {buffer_size} 像素宽的内向缓冲区,共 {buffer_pixels} 个像素")
|
||
|
||
return buffer_mask
|
||
|
||
|
||
@timeit
|
||
def remove_shoreline_buffer(glint_mask, water_mask, buffer_size=5, foreground=1, background=0):
|
||
"""
|
||
从耀斑掩膜中去除岸边缓冲区内的区域
|
||
|
||
Args:
|
||
glint_mask: 耀斑掩膜数组
|
||
water_mask: 水域掩膜数组
|
||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
去除岸边缓冲区后的耀斑掩膜
|
||
"""
|
||
if buffer_size <= 0:
|
||
print("缓冲区大小为0,不进行岸边缓冲区去除")
|
||
return glint_mask
|
||
|
||
# 创建岸边缓冲区掩膜
|
||
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
|
||
|
||
# 从耀斑掩膜中去除缓冲区内的区域
|
||
cleaned_glint_mask = glint_mask.copy()
|
||
cleaned_glint_mask[buffer_mask > 0] = background
|
||
|
||
removed_pixels = np.sum((glint_mask > 0) & (buffer_mask > 0))
|
||
remaining_pixels = np.sum(cleaned_glint_mask > 0)
|
||
|
||
if removed_pixels > 0:
|
||
print(f"岸边缓冲区去除: 从耀斑掩膜中移除了 {removed_pixels} 个岸边向内缓冲区域的像素,"
|
||
f"剩余 {remaining_pixels} 个像素")
|
||
else:
|
||
print(f"岸边缓冲区去除: 缓冲区区域没有耀斑掩膜,无需移除")
|
||
|
||
return cleaned_glint_mask
|
||
|
||
|
||
@timeit
|
||
def filter_large_components(binary_img, max_area=None, foreground=1, background=0):
|
||
"""
|
||
过滤掉面积超过阈值的连通域
|
||
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
|
||
|
||
Args:
|
||
binary_img: 二值化图像
|
||
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
|
||
如果为None,则不进行过滤
|
||
foreground: 前景值
|
||
background: 背景值
|
||
|
||
Returns:
|
||
过滤后的二值化图像
|
||
"""
|
||
if max_area is None or max_area <= 0:
|
||
return binary_img
|
||
|
||
# 连通域标记
|
||
# 使用OpenCV替代scipy.ndimage.label
|
||
binary_for_label = (binary_img == foreground).astype(np.uint8)
|
||
num_features, labeled_array, stats, centroids = cv2.connectedComponentsWithStats(binary_for_label, connectivity=8)
|
||
|
||
if num_features == 0:
|
||
print("没有检测到连通域")
|
||
return binary_img
|
||
|
||
# 使用OpenCV返回的stats信息直接获取连通域面积
|
||
# stats[:, cv2.CC_STAT_AREA] 包含每个连通域的面积(包括背景)
|
||
# 跳过索引0(背景)的面积,从索引1开始获取连通域面积
|
||
component_sizes = stats[1:, cv2.CC_STAT_AREA]
|
||
|
||
# 找出需要保留的连通域(面积 <= max_area)
|
||
keep_labels = np.where(component_sizes <= max_area)[0] + 1 # +1 因为标签从1开始
|
||
|
||
# 使用布尔索引一次性过滤(高效方法)
|
||
# 创建一个mask,标记所有需要保留的连通域
|
||
keep_mask = np.isin(labeled_array, keep_labels)
|
||
|
||
# 创建输出图像
|
||
filtered_img = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||
filtered_img[keep_mask] = foreground
|
||
|
||
# 统计信息
|
||
removed_count = num_features - len(keep_labels)
|
||
kept_count = len(keep_labels)
|
||
total_removed_pixels = np.sum(component_sizes[component_sizes > max_area])
|
||
|
||
if removed_count > 0:
|
||
print(f"连通域面积过滤: 移除了 {removed_count} 个大面积连通域(面积 > {max_area} 像素),"
|
||
f"共移除 {total_removed_pixels} 个像素;保留了 {kept_count} 个小面积连通域")
|
||
else:
|
||
print(f"连通域面积过滤: 所有 {kept_count} 个连通域面积均小于阈值 {max_area},全部保留")
|
||
|
||
return filtered_img
|
||
|
||
|
||
def find_overexposure_area(img_path, threhold=4095):
|
||
# 第一步通过某个像素的光谱找到信号最强的波段
|
||
|
||
# 根据上步所得的波段号检测过曝区域
|
||
pass
|
||
|
||
|
||
def create_water_mask_from_shp(shp_file, reference_raster):
|
||
"""
|
||
从shp文件创建水体掩膜栅格数组(内存中,不保存到磁盘)
|
||
|
||
参数:
|
||
shp_file: str - shp文件路径
|
||
reference_raster: str - 参考栅格文件路径(用于获取空间范围和分辨率)
|
||
|
||
返回:
|
||
numpy.ndarray - 水体掩膜数组
|
||
"""
|
||
try:
|
||
# 打开参考栅格获取空间信息
|
||
ref_dataset = gdal.Open(reference_raster)
|
||
if ref_dataset is None:
|
||
raise ValueError(f"无法打开参考栅格文件: {reference_raster}")
|
||
|
||
geotransform = ref_dataset.GetGeoTransform()
|
||
projection = ref_dataset.GetProjection()
|
||
width = ref_dataset.RasterXSize
|
||
height = ref_dataset.RasterYSize
|
||
|
||
# 创建内存中的栅格数据集
|
||
mem_driver = gdal.GetDriverByName('MEM')
|
||
mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
|
||
mask_dataset.SetGeoTransform(geotransform)
|
||
mask_dataset.SetProjection(projection)
|
||
|
||
# 初始化为0
|
||
mask_band = mask_dataset.GetRasterBand(1)
|
||
mask_band.Fill(0)
|
||
|
||
# 打开shp文件
|
||
shp_dataset = ogr.Open(shp_file)
|
||
if shp_dataset is None:
|
||
raise ValueError(f"无法打开shp文件: {shp_file}")
|
||
|
||
layer = shp_dataset.GetLayer()
|
||
|
||
# 栅格化shp文件
|
||
gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
|
||
|
||
# 读取栅格化结果
|
||
water_mask = mask_band.ReadAsArray()
|
||
|
||
# 清理
|
||
ref_dataset = None
|
||
mask_dataset = None
|
||
shp_dataset = None
|
||
|
||
return water_mask
|
||
|
||
except Exception as e:
|
||
print(f"创建水体掩膜时发生错误: {str(e)}")
|
||
raise
|
||
|
||
|
||
@timeit
|
||
def find_severe_glint_area(img_path, water_mask, glint_wave=750, output_path=None,
|
||
method='otsu', multi_band_waves=None, **kwargs):
|
||
"""
|
||
找到严重耀斑区域的主函数
|
||
|
||
注意:对于低反射率数据(如水面反射率约0.02),本函数采用了改进的归一化策略:
|
||
- 统计方法(zscore, percentile, iqr):直接使用原始反射率值,无需归一化
|
||
- Otsu和adaptive方法:使用百分位数裁剪拉伸(2%-98%分位数),避免极值影响
|
||
|
||
Args:
|
||
img_path: 输入影像路径
|
||
water_mask: 水域掩膜路径(支持栅格文件如.dat/.tif,或SHP文件如.shp;如果为None或空字符串,则使用全图进行检测)
|
||
glint_wave: 用于检测的波长(单个波段方法使用)
|
||
output_path: 输出路径
|
||
method: 检测方法,可选:
|
||
- 'otsu': Otsu阈值分割(默认,使用百分位数拉伸)
|
||
- 'zscore': Z-score统计方法(直接使用原始反射率)
|
||
- 'percentile': 百分位数阈值方法(直接使用原始反射率)
|
||
- 'iqr': IQR异常值检测(直接使用原始反射率)
|
||
- 'adaptive': 自适应阈值方法(使用百分位数拉伸)
|
||
- 'multi_band': 多波段融合方法
|
||
multi_band_waves: 多波段方法的波长列表,如[750, 800, 850]
|
||
**kwargs: 其他方法特定参数
|
||
- z_threshold: Z-score阈值(默认2.5)
|
||
- percentile: 百分位数(默认95)
|
||
- iqr_multiplier: IQR倍数(默认1.5)
|
||
- window_size: 自适应阈值窗口大小(默认15)
|
||
- weights: 多波段方法的权重列表
|
||
- sub_method: 多波段方法的子方法('otsu', 'zscore', 'percentile')
|
||
- max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉
|
||
用于去除岸边、浅水、水华等大面积区域(默认None,表示不过滤)
|
||
- buffer_size: 岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜
|
||
默认None,表示不进行岸边缓冲区去除;设置为正整数时启用
|
||
|
||
Returns:
|
||
输出文件路径
|
||
"""
|
||
if output_path is None:
|
||
output_path = append2filename(img_path, "_severe_glint_area")
|
||
|
||
dataset = gdal.Open(img_path)
|
||
num_bands = dataset.RasterCount
|
||
im_width = dataset.RasterXSize
|
||
im_height = dataset.RasterYSize
|
||
|
||
# 读取水域掩膜,如果water_mask为None或空字符串,则创建全图掩膜
|
||
if water_mask is None or water_mask == "":
|
||
print("注意: water_mask为空,使用全图进行检测")
|
||
data_water_mask = np.ones((im_height, im_width), dtype=np.int32)
|
||
else:
|
||
# 检查是否为SHP文件
|
||
water_mask_lower = water_mask.lower()
|
||
if water_mask_lower.endswith('.shp'):
|
||
# 直接使用SHP文件,在内存中栅格化
|
||
print(f"检测到SHP文件,正在从 {water_mask} 创建水体掩膜...")
|
||
data_water_mask = create_water_mask_from_shp(water_mask, img_path)
|
||
else:
|
||
# 使用栅格文件
|
||
dataset_water_mask = gdal.Open(water_mask)
|
||
if dataset_water_mask is None:
|
||
raise ValueError(f"无法打开水域掩膜文件: {water_mask}")
|
||
data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()
|
||
del dataset_water_mask
|
||
|
||
print(f"使用检测方法: {method}")
|
||
|
||
# 根据方法选择检测算法
|
||
if method == 'multi_band':
|
||
if multi_band_waves is None:
|
||
# 默认使用几个常见NIR波段
|
||
multi_band_waves = [glint_wave, glint_wave + 50, glint_wave + 100]
|
||
print(f"多波段方法: 使用默认波长 {multi_band_waves}")
|
||
else:
|
||
print(f"多波段方法: 使用波长 {multi_band_waves}")
|
||
|
||
sub_method = kwargs.get('sub_method', 'zscore')
|
||
weights = kwargs.get('weights', None)
|
||
z_threshold = kwargs.get('z_threshold', 2.5)
|
||
percentile = kwargs.get('percentile', 95)
|
||
|
||
flare_binary = multi_band_glint_detection(
|
||
dataset, img_path, data_water_mask, multi_band_waves, weights,
|
||
method=sub_method, z_threshold=z_threshold, percentile=percentile
|
||
)
|
||
else:
|
||
# 单波段方法
|
||
glint_band_number = find_band_number(glint_wave, img_path)
|
||
tmp = dataset.GetRasterBand(glint_band_number + 1)
|
||
band_flare = tmp.ReadAsArray().astype(np.float32)
|
||
|
||
# 根据方法选择是否需要归一化
|
||
# 对于统计方法(zscore, percentile, iqr),直接使用原始反射率值
|
||
# 对于Otsu和adaptive方法,需要归一化到整数范围
|
||
if method == 'otsu':
|
||
# Otsu方法需要整数范围,使用百分位数拉伸
|
||
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
|
||
lower_percentile=2, upper_percentile=98)
|
||
flare_binary = otsu(band_flare_stretch, band_flare_stretch.max() + 1, data_water_mask)
|
||
elif method == 'zscore':
|
||
# Z-score方法直接使用原始反射率值
|
||
z_threshold = kwargs.get('z_threshold', 2.5)
|
||
flare_binary = zscore_threshold(band_flare, data_water_mask, z_threshold)
|
||
elif method == 'percentile':
|
||
# 百分位数方法直接使用原始反射率值
|
||
percentile = kwargs.get('percentile', 95)
|
||
flare_binary = percentile_threshold(band_flare, data_water_mask, percentile)
|
||
elif method == 'iqr':
|
||
# IQR方法直接使用原始反射率值
|
||
iqr_multiplier = kwargs.get('iqr_multiplier', 1.5)
|
||
flare_binary = iqr_outlier_detection(band_flare, data_water_mask, iqr_multiplier)
|
||
elif method == 'adaptive':
|
||
# 自适应阈值方法需要归一化
|
||
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
|
||
lower_percentile=2, upper_percentile=98)
|
||
window_size = kwargs.get('window_size', 15)
|
||
percentile = kwargs.get('percentile', 90)
|
||
flare_binary = adaptive_threshold(band_flare_stretch, data_water_mask, window_size, percentile)
|
||
else:
|
||
raise ValueError(f"不支持的方法: {method}。可选方法: otsu, zscore, percentile, iqr, adaptive, multi_band")
|
||
|
||
# 过滤掉面积超过阈值的连通域(用于去除岸边、浅水、水华等大面积区域)
|
||
max_area = kwargs.get('max_area', None)
|
||
if max_area is not None and max_area > 0:
|
||
print(f"应用连通域面积过滤,最大面积阈值: {max_area} 像素")
|
||
flare_binary = filter_large_components(flare_binary, max_area=max_area)
|
||
|
||
# 去除岸边缓冲区内的耀斑掩膜(用于去除岸边的错误检测)
|
||
buffer_size = kwargs.get('buffer_size', None)
|
||
if buffer_size is not None and buffer_size > 0:
|
||
print(f"应用岸边缓冲区去除,缓冲区大小: {buffer_size} 像素")
|
||
flare_binary = remove_shoreline_buffer(flare_binary, data_water_mask, buffer_size=buffer_size)
|
||
|
||
write_bands(img_path, output_path, flare_binary)
|
||
|
||
del dataset
|
||
|
||
return output_path
|
||
|
||
|
||
# Press the green button in the gutter to run the script.
|
||
if __name__ == '__main__':
|
||
img_path = r"D:\PycharmProjects\0water_rlx\test_data\ref_mosaic_1m_bsq"
|
||
|
||
parser = argparse.ArgumentParser(
|
||
description="此程序通过多种算法分割图像,提取耀斑最严重的区域。"
|
||
"支持的算法: otsu, zscore, percentile, iqr, adaptive, multi_band"
|
||
)
|
||
|
||
parser.add_argument('-i1', '--input', type=str, required=True, help='输入影像文件的路径')
|
||
parser.add_argument('-i2', '--input_water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
|
||
parser.add_argument('-gw', '--glint_wave', type=float, default=750.0,
|
||
help='用于提取耀斑严重区域的波段波长(单波段方法使用)')
|
||
parser.add_argument('-m', '--method', type=str, default='otsu',
|
||
choices=['otsu', 'zscore', 'percentile', 'iqr', 'adaptive', 'multi_band'],
|
||
help='检测方法: otsu(默认), zscore, percentile, iqr, adaptive, multi_band')
|
||
parser.add_argument('-o', '--output', type=str, help='输出文件的路径')
|
||
|
||
# 方法特定参数
|
||
parser.add_argument('-zt', '--z_threshold', type=float, default=2.5,
|
||
help='Z-score方法的阈值(默认2.5)')
|
||
parser.add_argument('-p', '--percentile', type=float, default=95.0,
|
||
help='百分位数阈值(默认95)')
|
||
parser.add_argument('-iqr', '--iqr_multiplier', type=float, default=1.5,
|
||
help='IQR方法的倍数(默认1.5)')
|
||
parser.add_argument('-ws', '--window_size', type=int, default=15,
|
||
help='自适应阈值方法的窗口大小(默认15)')
|
||
parser.add_argument('-mbw', '--multi_band_waves', type=str, default=None,
|
||
help='多波段方法的波长列表,用逗号分隔,如: 750,800,850')
|
||
parser.add_argument('-sm', '--sub_method', type=str, default='zscore',
|
||
choices=['otsu', 'zscore', 'percentile'],
|
||
help='多波段方法的子方法(默认zscore)')
|
||
parser.add_argument('-ma', '--max_area', type=int, default=None,
|
||
help='最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉,'
|
||
'用于去除岸边、浅水、水华等大面积区域(默认None,表示不过滤)')
|
||
parser.add_argument('-bs', '--buffer_size', type=int, default=None,
|
||
help='岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜'
|
||
'(默认None,表示不进行岸边缓冲区去除;设置为正整数时启用)')
|
||
|
||
parser.add_argument('-v', '--verbose', action='store_true', help='启用详细模式')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 解析多波段波长列表
|
||
multi_band_waves = None
|
||
if args.multi_band_waves:
|
||
multi_band_waves = [float(x.strip()) for x in args.multi_band_waves.split(',')]
|
||
|
||
# 构建kwargs
|
||
kwargs = {
|
||
'z_threshold': args.z_threshold,
|
||
'percentile': args.percentile,
|
||
'iqr_multiplier': args.iqr_multiplier,
|
||
'window_size': args.window_size,
|
||
'sub_method': args.sub_method,
|
||
'max_area': args.max_area,
|
||
'buffer_size': args.buffer_size
|
||
}
|
||
|
||
find_severe_glint_area(
|
||
args.input, args.input_water_mask, args.glint_wave, args.output,
|
||
method=args.method, multi_band_waves=multi_band_waves, **kwargs
|
||
)
|