Files
WQ_GUI/src/utils/sampling.py

1029 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from src.utils.util import *
import math
import os
import numpy as np
from osgeo import gdal, ogr
import spectral
from scipy import ndimage
try:
from skimage import morphology
from skimage.morphology import skeletonize, medial_axis
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
print("警告: skimage未安装将无法使用主水轴线检测功能")
def get_wavelengths_from_bil_header(bil_file):
    """
    Read the band wavelengths from the ENVI header that accompanies a raster.

    The header is looked up first as ``<basename>.hdr`` (extension replaced)
    and then as ``<bil_file>.hdr`` (extension appended), since both naming
    schemes are used for ENVI headers.

    Parameters:
        bil_file: str - path of the raster whose header should be read
    Returns:
        list - positive wavelength values as floats, or None when the header
        is missing, has no usable ``wavelength`` entry, or cannot be parsed
    """
    try:
        primary_header = os.path.splitext(bil_file)[0] + ".hdr"
        candidates = (primary_header, os.fspath(bil_file) + ".hdr")
        header_file = next((path for path in candidates if os.path.exists(path)), None)
        if header_file is None:
            print(f"警告: 找不到头文件 {primary_header}")
            return None

        # Imported lazily so a missing header never requires the spectral package.
        import spectral.io.envi as envi
        header = envi.read_envi_header(header_file)

        wavelengths = header.get('wavelength', None)
        if wavelengths is None:
            print("警告: 头文件中未找到波长信息")
            return None

        # The entry may arrive as one raw "{a, b, c}" string instead of a list.
        if isinstance(wavelengths, str):
            wavelengths = wavelengths.strip('{}').replace(',', ' ').split()

        # Keep strictly positive values only (zeros are treated as invalid).
        values = [float(w) for w in wavelengths]
        wavelengths = [w for w in values if w > 0]
        if not wavelengths:
            # Without this guard min()/max() below would raise on an empty list
            # and the situation would be misreported as a read error.
            print("警告: 头文件中未找到波长信息")
            return None

        print(f"从头文件读取到 {len(wavelengths)} 个波长值")
        print(f"波长范围: {min(wavelengths):.2f} ~ {max(wavelengths):.2f}")
        return wavelengths
    except Exception as e:
        # Best effort: callers fall back to generic band names when None is returned.
        print(f"读取头文件波长信息时发生错误: {str(e)}")
        return None
def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=None, output_csvpath=None,
                                         interval=100, sample_radius=1, chunk_size=1000,
                                         use_adaptive_sampling=True, min_interval=10, max_interval=200):
    """
    Generate sampling points over valid water pixels and extract their mean
    spectra, reading the image in horizontal strips to limit memory use.

    A point is accepted only when every pixel of the
    (2 * sample_radius + 1)-sized square window around it is water and not
    (buffered) glint; its spectrum is the per-band mean over that window.
    Candidate rows lie on the grid ``sample_radius + k * step`` measured from
    the top of the image, so the set of points does not depend on
    ``chunk_size``; the adaptive interval is normalised with the width range
    of the whole water mask for the same reason.

    Parameters:
        bil_file: str - path of the spectral raster
        water_mask_shp: str - path of the water mask (passed to create_water_mask_from_shp)
        severe_glint: str - path of the glint mask raster (optional)
        output_csvpath: str - CSV output path (optional); when given, the
            valid-area mask is also handed to write_bands with the suffix
            "_valid_area.bsq"
        interval: int - grid spacing in pixels (row spacing, and column
            spacing when adaptive sampling is off)
        sample_radius: int - half-size of the averaging window in pixels
        chunk_size: int - number of rows processed per strip
        use_adaptive_sampling: bool - vary the column spacing with the local
            water width (narrow water -> small spacing)
        min_interval: int - smallest adaptive column spacing in pixels
        max_interval: int - largest adaptive column spacing in pixels
    Returns:
        tuple: (x_coords, y_coords, spectral_data) - map coordinates of the
        pixel centres (lists) and an array with one spectrum per point
    Raises:
        ValueError: for non-positive spacing/chunk parameters or unreadable inputs
    """
    gdal.UseExceptions()
    ogr.UseExceptions()
    try:
        # Reject values that would make the grid loops divide by zero or never advance.
        limits = [("interval", interval, 1), ("sample_radius", sample_radius, 0),
                  ("chunk_size", chunk_size, 1)]
        if use_adaptive_sampling:
            limits += [("min_interval", min_interval, 1), ("max_interval", max_interval, 1)]
        for name, value, minimum in limits:
            if value < minimum:
                raise ValueError(f"参数 {name} 必须不小于 {minimum}: {value}")

        dataset_bil = gdal.Open(bil_file)
        if dataset_bil is None:
            raise ValueError(f"无法打开bil文件: {bil_file}")
        im_width = dataset_bil.RasterXSize
        im_height = dataset_bil.RasterYSize
        num_bands = dataset_bil.RasterCount
        geotransform_input = dataset_bil.GetGeoTransform()
        print(f"bil文件信息: 宽度={im_width}, 高度={im_height}, 波段数={num_bands}")
        print(f"分块处理,每次处理 {chunk_size}")

        print("正在处理水体掩膜...")
        water_mask_raster = create_water_mask_from_shp(water_mask_shp, bil_file)

        # Valid area = water pixels that are not inside the buffered glint mask.
        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            print("已加载耀斑掩膜")
            data_severe_glint = expand_glint_buffer(data_severe_glint, buffer_size=2)
            valid_area = (water_mask_raster > 0) & (~(data_severe_glint > 0))
        else:
            print("未使用耀斑掩膜")
            valid_area = (water_mask_raster > 0)

        # Width map and its global range (only needed for adaptive sampling).
        width_map = None
        width_min = width_max = None
        if use_adaptive_sampling:
            print("正在计算水体宽度(用于自适应采样)...")
            width_map = calculate_water_width(water_mask_raster)
            water_widths = width_map[water_mask_raster > 0] if width_map is not None else None
            if water_widths is not None and water_widths.size > 0:
                width_min = np.min(water_widths)
                width_max = np.max(water_widths)
                print(f"水体宽度范围: {width_min:.1f} - {width_max:.1f} 像元")
            else:
                # No width map, or no water pixel at all: fall back instead of
                # crashing in np.min on an empty selection.
                print("警告: 无法计算水体宽度,将使用固定间隔采样")
                use_adaptive_sampling = False

        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.bsq"
            write_bands(bil_file, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []
        spectral_out = []
        base_interval = max(1, min(interval, max_interval))

        def grid_start(first_row, step):
            """Return the first row >= first_row on the grid sample_radius + k * step."""
            overshoot = (first_row - sample_radius) % step
            return first_row if overshoot == 0 else first_row + step - overshoot

        def adaptive_step(local_width):
            """Map a local water width linearly onto [min_interval, max_interval]."""
            if width_max > width_min:
                normalized = (local_width - width_min) / (width_max - width_min)
                step = int(min_interval + normalized * (max_interval - min_interval))
                return max(min_interval, min(max_interval, step))
            return base_interval

        def try_sample(x, y, local_y, spectral_chunk, valid_chunk, csv_file):
            """Record the point at column x / image row y if its window is fully valid."""
            r = sample_radius
            if (x < r or x >= im_width - r or
                    local_y < r or local_y >= valid_chunk.shape[0] - r):
                return False
            rows = slice(local_y - r, local_y + r + 1)
            cols = slice(x - r, x + r + 1)
            if not np.all(valid_chunk[rows, cols]):
                return False

            # One flattened row per band; the mean is taken over the whole window.
            flat = spectral_chunk[:, rows, cols].reshape(num_bands, -1)
            spectrum = [np.mean(band_values) for band_values in flat]

            # +0.5 addresses the pixel centre rather than its top-left corner.
            geo_x, geo_y = gdal.ApplyGeoTransform(geotransform_input, x + 0.5, y + 0.5)
            if csv_file:
                fields = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{x}", f"{y}"]
                fields.extend(f"{value:.6f}" for value in spectrum)
                csv_file.write(",".join(fields) + "\n")
            x_out.append(geo_x)
            y_out.append(geo_y)
            spectral_out.append(spectrum)
            return True

        csv_file = None
        try:
            if output_csvpath:
                # Opened inside the try block so the handle is closed even if
                # writing the header fails.
                csv_file = open(output_csvpath, "w")
                wavelengths = get_wavelengths_from_bil_header(bil_file)
                if wavelengths is not None and len(wavelengths) == num_bands:
                    band_columns = [f"{wavelength:.6f}" for wavelength in wavelengths]
                else:
                    # No usable wavelength list: name the columns by band number.
                    band_columns = [f"band_{i + 1}" for i in range(num_bands)]
                columns = ["x_coord", "y_coord", "pixel_x", "pixel_y"] + band_columns
                csv_file.write(",".join(columns) + "\n")

            print("正在分块生成采样点...")
            sample_count = 0
            total_chunks = max(0, math.ceil((im_height - 2 * sample_radius) / chunk_size))
            for chunk_idx in range(total_chunks):
                # Rows that may hold sample centres in this strip ...
                start_row = sample_radius + chunk_idx * chunk_size
                end_row = min(sample_radius + (chunk_idx + 1) * chunk_size, im_height - sample_radius)
                # ... and the rows actually read, padded by the window radius.
                read_start = max(0, start_row - sample_radius)
                read_end = min(im_height, end_row + sample_radius)
                print(f"处理块 {chunk_idx + 1}/{total_chunks}: 行 {start_row}-{end_row}")

                spectral_chunk = dataset_bil.ReadAsArray(
                    0, read_start, im_width, read_end - read_start
                )
                if spectral_chunk.ndim == 2:
                    # A single-band raster comes back 2-D; restore the band axis.
                    spectral_chunk = spectral_chunk[np.newaxis, ...]
                valid_chunk = valid_area[read_start:read_end, :]
                water_chunk = water_mask_raster[read_start:read_end, :]
                width_chunk = width_map[read_start:read_end, :] if width_map is not None else None

                if use_adaptive_sampling:
                    print(f" 使用自适应采样(间隔范围: {min_interval}-{max_interval}...")
                    # Strips without any water cannot yield samples; skip the scan.
                    if np.any(water_chunk > 0):
                        for y in range(grid_start(start_row, base_interval), end_row, base_interval):
                            local_y = y - read_start
                            x = sample_radius
                            while x < im_width - sample_radius:
                                if water_chunk[local_y, x] > 0:
                                    if try_sample(x, y, local_y, spectral_chunk, valid_chunk, csv_file):
                                        sample_count += 1
                                    x += adaptive_step(width_chunk[local_y, x])
                                else:
                                    # NOTE: land is skipped with the base spacing, so water
                                    # bodies narrower than that spacing can be stepped over.
                                    x += base_interval
                else:
                    print(f" 使用固定间隔采样(间隔: {interval}...")
                    for y in range(grid_start(start_row, interval), end_row, interval):
                        local_y = y - read_start
                        for x in range(sample_radius, im_width - sample_radius, interval):
                            if try_sample(x, y, local_y, spectral_chunk, valid_chunk, csv_file):
                                sample_count += 1

                # Drop the strip before the next read to keep peak memory low.
                del spectral_chunk, valid_chunk, water_chunk, width_chunk
                print(f"{chunk_idx + 1} 完成,当前采样点总数: {sample_count}")
            print(f"所有块处理完成,成功生成 {sample_count} 个采样点")
        finally:
            if csv_file:
                csv_file.close()
        return x_out, y_out, np.array(spectral_out)
    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
def get_spectral_sampling_points(bil_file, water_mask_shp, severe_glint=None, output_csvpath=None,
                                 interval=100, sample_radius=1,
                                 use_adaptive_sampling=True, min_interval=10, max_interval=200):
    """
    Generate sampling points over valid water pixels and extract their mean
    spectra, loading the whole image into memory at once.

    A point is accepted only when every pixel of the
    (2 * sample_radius + 1)-sized square window around it is water and not
    (buffered) glint; its spectrum is the per-band mean over that window.

    Parameters:
        bil_file: str - path of the spectral raster
        water_mask_shp: str - path of the water mask (passed to create_water_mask_from_shp)
        severe_glint: str - path of the glint mask raster (optional)
        output_csvpath: str - CSV output path (optional); when given, the
            valid-area mask is also handed to write_bands with the suffix
            "_valid_area.tif"
        interval: int - grid spacing in pixels (row spacing, and column
            spacing when adaptive sampling is off)
        sample_radius: int - half-size of the averaging window in pixels
        use_adaptive_sampling: bool - vary the column spacing with the local
            water width (narrow water -> small spacing)
        min_interval: int - smallest adaptive column spacing in pixels
        max_interval: int - largest adaptive column spacing in pixels
    Returns:
        tuple: (x_coords, y_coords, spectral_data) - map coordinates of the
        pixel centres (lists) and an array with one spectrum per point
    Raises:
        ValueError: for non-positive spacing parameters or unreadable inputs
    """
    gdal.UseExceptions()
    ogr.UseExceptions()
    try:
        # Reject values that would make the grid loops never advance.
        limits = [("interval", interval, 1), ("sample_radius", sample_radius, 0)]
        if use_adaptive_sampling:
            limits += [("min_interval", min_interval, 1), ("max_interval", max_interval, 1)]
        for name, value, minimum in limits:
            if value < minimum:
                raise ValueError(f"参数 {name} 必须不小于 {minimum}: {value}")

        dataset_bil = gdal.Open(bil_file)
        if dataset_bil is None:
            raise ValueError(f"无法打开bil文件: {bil_file}")
        im_width = dataset_bil.RasterXSize
        im_height = dataset_bil.RasterYSize
        num_bands = dataset_bil.RasterCount
        geotransform_input = dataset_bil.GetGeoTransform()
        print(f"bil文件信息: 宽度={im_width}, 高度={im_height}, 波段数={num_bands}")

        print("正在读取光谱数据...")
        spectral_data_full = dataset_bil.ReadAsArray()
        if spectral_data_full.ndim == 2:
            # A single-band raster comes back 2-D; restore the band axis.
            spectral_data_full = spectral_data_full[np.newaxis, ...]

        print("正在处理水体掩膜...")
        water_mask_raster = create_water_mask_from_shp(water_mask_shp, bil_file)

        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            print("已加载耀斑掩膜")
            data_severe_glint = expand_glint_buffer(data_severe_glint, buffer_size=2)
        else:
            data_severe_glint = None
            print("未使用耀斑掩膜")

        # Sanity check kept from the original flow: refuse a degenerate geotransform.
        if gdal.InvGeoTransform(geotransform_input) is None:
            raise ValueError("无法计算逆仿射变换")

        # Valid area = water pixels that are not inside the buffered glint mask.
        if data_severe_glint is not None:
            valid_area = (water_mask_raster > 0) & (~(data_severe_glint > 0))
        else:
            valid_area = (water_mask_raster > 0)

        # Width map and its range (only needed for adaptive sampling).
        width_map = None
        width_min = width_max = None
        if use_adaptive_sampling:
            print("正在计算水体宽度(用于自适应采样)...")
            width_map = calculate_water_width(water_mask_raster)
            if width_map is None:
                print("警告: 无法计算水体宽度,将使用固定间隔采样")
                use_adaptive_sampling = False
            else:
                water_widths = width_map[water_mask_raster > 0]
                if water_widths.size > 0:
                    width_min = np.min(water_widths)
                    width_max = np.max(water_widths)
                    print(f"水体宽度范围: {width_min:.1f} - {width_max:.1f} 像元")
                else:
                    # Empty water mask: fall back instead of crashing in np.min.
                    print("无法获取有效宽度信息,使用固定间隔采样...")
                    use_adaptive_sampling = False

        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.tif"
            write_bands(bil_file, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []
        spectral_out = []
        base_interval = max(1, min(interval, max_interval))

        def adaptive_step(local_width):
            """Map a local water width linearly onto [min_interval, max_interval]."""
            if width_max > width_min:
                normalized = (local_width - width_min) / (width_max - width_min)
                step = int(min_interval + normalized * (max_interval - min_interval))
                return max(min_interval, min(max_interval, step))
            return base_interval

        def try_sample(x, y, csv_file):
            """Record the point at pixel (x, y) if its whole window is valid."""
            r = sample_radius
            if x < r or x >= im_width - r or y < r or y >= im_height - r:
                return False
            rows = slice(y - r, y + r + 1)
            cols = slice(x - r, x + r + 1)
            if not np.all(valid_area[rows, cols]):
                return False

            # One flattened row per band; the mean is taken over the whole window.
            flat = spectral_data_full[:, rows, cols].reshape(num_bands, -1)
            spectrum = [np.mean(band_values) for band_values in flat]

            # +0.5 addresses the pixel centre rather than its top-left corner.
            geo_x, geo_y = gdal.ApplyGeoTransform(geotransform_input, x + 0.5, y + 0.5)
            if csv_file:
                fields = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{x}", f"{y}"]
                fields.extend(f"{value:.6f}" for value in spectrum)
                csv_file.write(",".join(fields) + "\n")
            x_out.append(geo_x)
            y_out.append(geo_y)
            spectral_out.append(spectrum)
            return True

        csv_file = None
        try:
            if output_csvpath:
                # Opened inside the try block so the handle is closed even if
                # writing the header fails.
                csv_file = open(output_csvpath, "w")
                wavelengths = get_wavelengths_from_bil_header(bil_file)
                if wavelengths is not None and len(wavelengths) == num_bands:
                    band_columns = [f"{wavelength:.6f}" for wavelength in wavelengths]
                else:
                    # No usable wavelength list: name the columns by band number.
                    band_columns = [f"band_{i + 1}" for i in range(num_bands)]
                columns = ["x_coord", "y_coord", "pixel_x", "pixel_y"] + band_columns
                csv_file.write(",".join(columns) + "\n")

            print("正在生成采样点...")
            sample_count = 0
            if use_adaptive_sampling:
                print("使用自适应采样(根据水体宽度调整间隔)...")
                # Rows advance by the base spacing; columns by the width-dependent one.
                for y in range(sample_radius, im_height - sample_radius, base_interval):
                    x = sample_radius
                    while x < im_width - sample_radius:
                        if water_mask_raster[y, x] > 0:
                            if try_sample(x, y, csv_file):
                                sample_count += 1
                            x += adaptive_step(width_map[y, x])
                        else:
                            # NOTE: land is skipped with the base spacing, so water
                            # bodies narrower than that spacing can be stepped over.
                            x += base_interval
            else:
                print(f"使用固定间隔采样(间隔: {interval}...")
                for y in range(sample_radius, im_height - sample_radius, interval):
                    for x in range(sample_radius, im_width - sample_radius, interval):
                        if try_sample(x, y, csv_file):
                            sample_count += 1
            print(f"成功生成 {sample_count} 个采样点")
        finally:
            if csv_file:
                csv_file.close()
        return x_out, y_out, np.array(spectral_out)
    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
def create_water_mask_from_shp(shp_file, reference_raster):
    """
    Build a water-mask array from either a shapefile or an existing raster.

    A ``.shp`` input is burned (value 1 on a zero background) into an
    in-memory byte raster that copies the size, geotransform and projection
    of ``reference_raster``. Any other extension is opened as a raster and
    its first band is returned unchanged.

    Parameters:
        shp_file: str - path of the shapefile, or of a raster mask
        reference_raster: str - raster supplying the output grid; only read
            when shp_file is a shapefile
    Returns:
        numpy.ndarray - the mask array
    Raises:
        ValueError: when an input cannot be opened
        RuntimeError: when rasterising the shapefile reports an error
    """
    try:
        file_ext = os.path.splitext(shp_file)[1].lower()
        if file_ext != '.shp':
            # Already a raster: return band 1 as-is.
            mask_dataset = gdal.Open(shp_file, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开栅格掩膜文件: {shp_file}")
            return mask_dataset.GetRasterBand(1).ReadAsArray()

        ref_dataset = gdal.Open(reference_raster)
        if ref_dataset is None:
            raise ValueError(f"无法打开参考栅格文件: {reference_raster}")

        # In-memory target with the same grid as the reference raster.
        mem_driver = gdal.GetDriverByName('MEM')
        mask_dataset = mem_driver.Create('', ref_dataset.RasterXSize, ref_dataset.RasterYSize,
                                         1, gdal.GDT_Byte)
        mask_dataset.SetGeoTransform(ref_dataset.GetGeoTransform())
        mask_dataset.SetProjection(ref_dataset.GetProjection())
        mask_band = mask_dataset.GetRasterBand(1)
        mask_band.Fill(0)

        shp_dataset = ogr.Open(shp_file)
        if shp_dataset is None:
            raise ValueError(f"无法打开shp文件: {shp_file}")
        layer = shp_dataset.GetLayer()
        if layer is None:
            raise ValueError(f"无法打开shp文件: {shp_file}")

        # When GDAL exceptions are not enabled a failure is only signalled by
        # a non-zero status (CE_None is 0); without this check the caller
        # would silently receive an all-zero mask.
        status = gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
        if status:
            raise RuntimeError(f"栅格化shp文件失败: {shp_file}")

        return mask_band.ReadAsArray()
    except Exception as e:
        print(f"创建水体掩膜时发生错误: {str(e)}")
        raise
def expand_glint_buffer(glint_mask, buffer_size=2):
    """
    Grow the glint mask outwards by a small safety margin.

    Parameters:
        glint_mask: numpy.ndarray - glint mask (glint where value > 0); None is passed through
        buffer_size: int - margin in pixels; values outside 1..2 are clamped into that range
    Returns:
        numpy.ndarray - uint8 mask (1 = glint or margin), or None when glint_mask is None
    """
    if glint_mask is None:
        return None

    # Clamp the margin to the supported 1..2 pixel range.
    buffer_size = min(max(int(buffer_size), 1), 2)

    # A square footprint of side 2 * margin + 1 grows the mask by `margin`
    # pixels in every direction (3x3 for 1, 5x5 for 2).
    side = 2 * buffer_size + 1
    footprint = np.ones((side, side), dtype=bool)

    seed = glint_mask > 0
    grown = ndimage.binary_dilation(seed, structure=footprint)

    expanded_pixels = int(np.count_nonzero(grown)) - int(np.count_nonzero(seed))
    print(f"Glint边界外扩: 外扩 {buffer_size} 像素,新增 {expanded_pixels} 个像素")
    return grown.astype(np.uint8)
def calculate_water_width(water_mask):
    """
    Approximate the local water width with a Euclidean distance transform.

    Parameters:
        water_mask: numpy.ndarray - water mask (water where value > 0)
    Returns:
        numpy.ndarray - float32 array holding, for each water pixel, the
        distance to the nearest non-water pixel (0 outside water); None if
        the computation fails
    """
    try:
        inside_water = water_mask > 0
        # The raw distance (about half the local width) is used as the width
        # proxy: small in narrow reaches, large in open water.
        return ndimage.distance_transform_edt(inside_water).astype(np.float32)
    except Exception as e:
        print(f"计算水体宽度时发生错误: {str(e)}")
        return None
def detect_water_centerline(water_mask):
    """
    Extract the centreline (skeleton) of the water body.

    Parameters:
        water_mask: numpy.ndarray - water mask (water where value > 0)
    Returns:
        numpy.ndarray - uint8 mask (1 on the centreline, 0 elsewhere); None
        when skimage is unavailable or skeletonisation fails
    """
    if not SKIMAGE_AVAILABLE:
        print("警告: skimage未安装无法检测主水轴线")
        return None

    try:
        # Skeletonise the binary water region.
        centerline = skeletonize((water_mask > 0).astype(bool))
        skeleton_pixels = np.sum(centerline > 0)
        print(f"主水轴线检测: 检测到 {skeleton_pixels} 个中心线像素")
    except Exception as e:
        print(f"检测主水轴线时发生错误: {str(e)}")
        return None
    else:
        return centerline.astype(np.uint8)
def detect_water_features(water_mask, centerline_mask=None):
    """
    Mark feature points on the water centreline: branch/confluence points and bends.

    Branch points are centreline pixels with at least three centreline
    neighbours (8-connectivity). Bends are centreline pixels whose absolute
    Laplacian of the smoothed Sobel gradient magnitude exceeds the 75th
    percentile of that same absolute value over the centreline.

    Parameters:
        water_mask: numpy.ndarray - water mask (water where value > 0)
        centerline_mask: numpy.ndarray - centreline mask (optional); when
            omitted the centreline is derived from water_mask with skimage
    Returns:
        numpy.ndarray - uint8 mask (1 at feature points, 0 elsewhere); all
        zeros when no centreline is available or an error occurs
    """
    try:
        if centerline_mask is not None:
            skeleton = (centerline_mask > 0).astype(bool)
        elif SKIMAGE_AVAILABLE:
            skeleton = skeletonize((water_mask > 0).astype(bool))
        else:
            print("警告: 无法检测地形特征点(需要中心线)")
            return np.zeros_like(water_mask, dtype=np.uint8)

        # Count the centreline pixels among the 8 neighbours of every pixel.
        kernel = np.array([[1, 1, 1],
                           [1, 0, 1],
                           [1, 1, 1]], dtype=np.uint8)
        neighbor_count = ndimage.convolve(skeleton.astype(np.uint8), kernel, mode='constant')
        branch_points = skeleton & (neighbor_count >= 3)

        # Curvature proxy: |Laplacian| of the Gaussian-smoothed gradient magnitude.
        skeleton_float = skeleton.astype(float)
        sobel_x = ndimage.sobel(skeleton_float, axis=1)
        sobel_y = ndimage.sobel(skeleton_float, axis=0)
        gradient_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
        smoothed_gradient = ndimage.gaussian_filter(gradient_magnitude, sigma=1.0)
        curvature = np.abs(ndimage.laplace(smoothed_gradient))

        on_centerline = curvature[skeleton]
        if on_centerline.size > 0:
            # The threshold is taken from the same absolute values it is
            # compared against; using signed values here could yield a
            # negative threshold that flags every centreline pixel.
            curvature_threshold = np.percentile(on_centerline, 75)
            bend_points = skeleton & (curvature > curvature_threshold)
        else:
            # Empty centreline: nothing to threshold.
            bend_points = np.zeros_like(skeleton)

        features = (branch_points | bend_points).astype(np.uint8)
        feature_count = np.sum(features > 0)
        branch_count = np.sum(branch_points > 0)
        bend_count = np.sum(bend_points > 0)
        print(f"地形特征点检测: 检测到 {feature_count} 个特征点(分支/汇入: {branch_count}, 弯头: {bend_count}")
        return features
    except Exception as e:
        print(f"检测地形特征点时发生错误: {str(e)}")
        return np.zeros_like(water_mask, dtype=np.uint8)
def get_coor_base_interval(water_mask, severe_glint=None, output_csvpath=None, interval=100):
    """
    Pick one valid water pixel per grid cell and return its map coordinates.

    The image extent is tiled with square cells of ``interval`` pixels; in
    each cell the first valid pixel in row-major order (water and, when a
    glint mask is given, not glint) becomes a sampling point.

    Parameters:
        water_mask: str - path of the water mask raster
        severe_glint: str - path of the glint mask raster (optional); when
            given, its size and geotransform define the grid
        output_csvpath: str - CSV output path (optional); when given, rows of
            "x,y,pixel_x,pixel_y" are written without a header line and the
            valid-area mask is handed to write_bands with the suffix
            "_valid_area.tif"
        interval: int - cell size in pixels
    Returns:
        tuple: (x_coords, y_coords) - map coordinates of the chosen pixel centres
    """
    gdal.UseExceptions()
    try:
        dataset_water_mask = gdal.Open(water_mask)
        if dataset_water_mask is None:
            raise ValueError(f"无法打开水体掩膜文件: {water_mask}")
        data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()

        # Geometry comes from the glint mask when one is supplied, otherwise
        # from the water mask.
        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            geometry_source = dataset_severe_glint
        else:
            data_severe_glint = None
            geometry_source = dataset_water_mask
        im_width = geometry_source.RasterXSize
        im_height = geometry_source.RasterYSize
        geotransform_input = geometry_source.GetGeoTransform()

        inv_geotransform_input = gdal.InvGeoTransform(geotransform_input)
        if inv_geotransform_input is None:
            raise ValueError("无法计算逆仿射变换")

        # Map extent of the image.
        x_min = geotransform_input[0]
        y_max = geotransform_input[3]
        x_max = x_min + im_width * geotransform_input[1]
        y_min = y_max + im_height * geotransform_input[5]

        # Cell size in map units, never smaller than one pixel.
        pixel_size = abs(geotransform_input[1])
        grid_size = max(pixel_size * interval, pixel_size)
        dx = max(1, math.ceil((x_max - x_min) / grid_size))
        dy = max(1, math.ceil((y_max - y_min) / grid_size))

        if data_severe_glint is not None:
            valid_area = (data_water_mask > 0) & (~(data_severe_glint > 0))
        else:
            valid_area = (data_water_mask > 0)

        # The mask export needs a path to derive its name from, so it is
        # skipped when no CSV path was given (previously this raised TypeError).
        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.tif"
            write_bands(water_mask, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []
        f = None
        try:
            if output_csvpath:
                f = open(output_csvpath, "w")
            for row in range(dy):
                for column in range(dx):
                    # Cell corners in map coordinates (y decreases downwards).
                    left = x_min + column * grid_size
                    top = y_max - row * grid_size
                    right = x_min + (column + 1) * grid_size
                    bottom = y_max - (row + 1) * grid_size

                    top_left_px = gdal.ApplyGeoTransform(inv_geotransform_input, left, top)
                    bottom_right_px = gdal.ApplyGeoTransform(inv_geotransform_input, right, bottom)

                    # Clip the cell to the image.
                    # NOTE(review): the +1 makes neighbouring cells share one
                    # pixel row/column, so the same pixel can be reported by
                    # two cells - confirm whether that is intended.
                    x1 = max(0, int(top_left_px[0]))
                    y1 = max(0, int(top_left_px[1]))
                    x2 = min(im_width, int(bottom_right_px[0]) + 1)
                    y2 = min(im_height, int(bottom_right_px[1]) + 1)
                    if x2 <= x1 or y2 <= y1:
                        continue

                    valid_pixels = np.argwhere(valid_area[y1:y2, x1:x2])
                    if valid_pixels.size == 0:
                        continue

                    # First valid pixel of the cell, converted to image coordinates.
                    local_y, local_x = valid_pixels[0]
                    global_x = x1 + local_x
                    global_y = y1 + local_y

                    # +0.5 addresses the pixel centre.
                    geo_x, geo_y = gdal.ApplyGeoTransform(
                        geotransform_input, global_x + 0.5, global_y + 0.5
                    )
                    if f:
                        line_parts = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{global_x}", f"{global_y}"]
                        f.write(",".join(line_parts) + "\n")
                    x_out.append(geo_x)
                    y_out.append(geo_y)
        finally:
            if f:
                f.close()
        return x_out, y_out
    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
# Usage example
if __name__ == "__main__":
    # Demo of the chunked sampler, which suits large rasters because only one
    # strip of rows is held in memory at a time.
    demo_settings = {
        # Input and output paths
        "bil_file": r"E:\wq_gui_test\3_deglint\deglint_goodman.bsq",
        "water_mask_shp": r"E:\wq_gui_test\1_water_mask\water_mask_from_shp.dat",
        "severe_glint": r"E:\wq_gui_test\2_glint\severe_glint_area.dat",
        "output_csvpath": r"E:\wq_gui_test\10_sampling\sampling_spectra.csv",
        # Base grid spacing in pixels
        "interval": 50,
        # Half-size of the averaging window in pixels
        "sample_radius": 5,
        # Rows per strip; tune to the available memory (500-2000 suggested)
        "chunk_size": 1000,
        # Adaptive sampling: spacing follows the local water width
        "use_adaptive_sampling": True,
        # Smallest spacing, used in narrow inflow reaches
        "min_interval": 10,
        # Largest spacing, used in wide reservoir areas
        "max_interval": 200,
    }

    try:
        x_coords, y_coords, spectral_data = get_spectral_sampling_points_chunked(**demo_settings)
    except Exception as e:
        print(f"处理失败: {str(e)}")
    else:
        try:
            print(f"成功生成 {len(x_coords)} 个采样点")
            print(f"每个采样点包含 {spectral_data.shape[1]} 个波段的光谱数据")
            print(f"光谱数据形状: {spectral_data.shape}")
        except Exception as e:
            print(f"处理失败: {str(e)}")

    # Legacy usage example (kept for backward compatibility)
    # water_mask = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\1.mask\mask.dat"
    # severe_glint = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\2.glint\ref_mosaic_1m_severe_glint"
    # output_csvpath = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\5.interval\coor_interval_3.csv"
    # interval = 30
    # x_coords, y_coords = get_coor_base_interval(water_mask, severe_glint, output_csvpath, interval)
    # print(f"成功生成 {len(x_coords)} 个采样点")