1061 lines
46 KiB
Python
1061 lines
46 KiB
Python
from src.utils.util import *
|
||
import math
|
||
import os
|
||
import numpy as np
|
||
from osgeo import gdal, ogr
|
||
import spectral
|
||
from scipy import ndimage
|
||
try:
|
||
from skimage import morphology
|
||
from skimage.morphology import skeletonize, medial_axis
|
||
SKIMAGE_AVAILABLE = True
|
||
except ImportError:
|
||
SKIMAGE_AVAILABLE = False
|
||
print("警告: skimage未安装,将无法使用主水轴线检测功能")
|
||
|
||
|
||
def get_wavelengths_from_bil_header(bil_file):
    """
    Read the wavelength list from the ENVI-style header next to a BIL file.

    The header is looked up beside the image, using the same base name with a
    ``.hdr`` extension. The ``wavelength = { ... }`` entry may sit on a single
    line or span several lines; parsing stops at its closing brace so that the
    entries which follow (``fwhm``, ``band names`` ...) are never mixed in.

    Args:
        bil_file: str - path of the BIL image.

    Returns:
        list - wavelengths as floats in header order, or None when the header
        is missing, contains no ``wavelength`` entry, or cannot be read.
    """
    try:
        # The header normally lives beside the image with a ".hdr" suffix.
        header_file = os.path.splitext(bil_file)[0] + ".hdr"

        if not os.path.exists(header_file):
            print(f"警告: 找不到头文件 {header_file}")
            return None

        with open(header_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        wavelength_parts = []
        block_found = False
        in_wavelength_block = False

        for line in lines:
            stripped_line = line.strip()

            if in_wavelength_block:
                text = stripped_line
            else:
                # Match the key exactly (ignoring case and spacing around '=')
                # so that "wavelength units = ..." is not mistaken for it.
                key, separator, value = stripped_line.partition('=')
                if not separator or key.strip().lower() != 'wavelength':
                    continue
                block_found = True
                in_wavelength_block = True
                text = value

            # Keep only what precedes the closing brace. The brace may already
            # be on the first line, in which case the block ends right here.
            body, closing, _ = text.partition('}')
            wavelength_parts.append(body)
            if closing:
                break

        if not block_found:
            print("警告: 头文件中未找到波长信息")
            return None

        # Values may be separated by commas and/or whitespace.
        tokens = ' '.join(wavelength_parts).replace('{', ' ').replace(',', ' ').split()

        wavelength_values = []
        for token in tokens:
            try:
                wavelength_values.append(float(token))
            except ValueError:
                # Skip stray non-numeric tokens instead of failing the whole read.
                continue

        print(f"从头文件读取到 {len(wavelength_values)} 个波长值")
        return wavelength_values

    except Exception as e:
        # Deliberately best-effort: callers fall back to generic band names.
        print(f"读取头文件波长信息时发生错误: {str(e)}")
        return None
|
||
|
||
|
||
def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=None, output_csvpath=None,
                                         interval=100, sample_radius=1, chunk_size=1000,
                                         use_adaptive_sampling=True, min_interval=10, max_interval=200):
    """
    Generate sampling points over water and extract their mean spectra, reading
    the image in row chunks to bound memory use.

    A sampling point is accepted only when every pixel of its
    (2 * sample_radius + 1) square window is water and not glint. The sampling
    grid is anchored at row/column ``sample_radius`` of the whole image and the
    adaptive width normalisation uses the whole water mask, so the selected
    points do not depend on ``chunk_size``.

    Args:
        bil_file: str - path of the spectral image.
        water_mask_shp: str - water mask path (passed to create_water_mask_from_shp).
        severe_glint: str - glint mask raster path (optional).
        output_csvpath: str - output CSV path (optional); when given, the valid
            area raster is also written next to it.
        interval: int - grid spacing in pixels; used for rows, and for columns
            when adaptive sampling is off. Values below 1 are treated as 1.
        sample_radius: int - half size of the averaging window, in pixels.
        chunk_size: int - number of rows handled per chunk (values below 1 are treated as 1).
        use_adaptive_sampling: bool - vary the column spacing with the local water width.
        min_interval: int - smallest adaptive spacing (values below 1 are treated as 1).
        max_interval: int - largest adaptive spacing (values below 1 are treated as 1).

    Returns:
        tuple: (x_coords, y_coords, spectral_data) - coordinate lists and the
        array of per-point mean spectra.

    Raises:
        ValueError: when an input cannot be opened, ``sample_radius`` is
            negative, or a mask does not match the image size.
    """
    # Enable GDAL/OGR exceptions.
    gdal.UseExceptions()
    ogr.UseExceptions()

    if sample_radius < 0:
        raise ValueError(f"sample_radius 不能为负数: {sample_radius}")

    # A step below one pixel would stall the scan loops, so clamp every step.
    interval = max(1, int(interval))
    min_interval = max(1, int(min_interval))
    max_interval = max(1, int(max_interval))
    chunk_size = max(1, int(chunk_size))
    base_interval = min(interval, max_interval)

    try:
        dataset_bil = gdal.Open(bil_file)
        if dataset_bil is None:
            raise ValueError(f"无法打开bil文件: {bil_file}")

        im_width = dataset_bil.RasterXSize
        im_height = dataset_bil.RasterYSize
        num_bands = dataset_bil.RasterCount
        geotransform_input = dataset_bil.GetGeoTransform()

        print(f"bil文件信息: 宽度={im_width}, 高度={im_height}, 波段数={num_bands}")
        print(f"分块处理,每次处理 {chunk_size} 行")

        # Water mask on the image grid.
        print("正在处理水体掩膜...")
        water_mask_raster = create_water_mask_from_shp(water_mask_shp, bil_file)
        if water_mask_raster.shape != (im_height, im_width):
            raise ValueError(
                f"水体掩膜尺寸 {water_mask_raster.shape} 与影像尺寸 {(im_height, im_width)} 不一致")

        # Optional glint mask, grown by a small buffer.
        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            print("已加载耀斑掩膜")
            data_severe_glint = expand_glint_buffer(data_severe_glint, buffer_size=2)
            if data_severe_glint.shape != water_mask_raster.shape:
                raise ValueError(
                    f"耀斑掩膜尺寸 {data_severe_glint.shape} 与水体掩膜尺寸 {water_mask_raster.shape} 不一致")
            valid_area = (water_mask_raster > 0) & (~(data_severe_glint > 0))
        else:
            print("未使用耀斑掩膜")
            valid_area = (water_mask_raster > 0)

        # Water width map for adaptive sampling.
        width_map = None
        width_min = width_max = None
        if use_adaptive_sampling:
            print("正在计算水体宽度(用于自适应采样)...")
            width_map = calculate_water_width(water_mask_raster)
            water_widths = width_map[water_mask_raster > 0] if width_map is not None else None
            # np.min/np.max raise on an empty selection, so an empty water mask
            # falls back to fixed spacing as well.
            if water_widths is not None and water_widths.size > 0:
                width_min = np.min(water_widths)
                width_max = np.max(water_widths)
                print(f"水体宽度范围: {width_min:.1f} - {width_max:.1f} 像元")
            else:
                print("警告: 无法计算水体宽度,将使用固定间隔采样")
                use_adaptive_sampling = False
                width_map = None

        # Save the valid area next to the CSV (optional).
        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.bsq"
            write_bands(bil_file, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []
        spectral_out = []
        sampled_pixels = set()  # guards against sampling the same pixel twice
        f = None

        def try_add_sample(x, y, local_y, spectral_chunk, valid_chunk):
            """Record the sample at image pixel (x, y); return True when it was accepted."""
            if (x, y) in sampled_pixels:
                return False

            # The whole window must lie inside the image and inside the chunk.
            if (x < sample_radius or x >= im_width - sample_radius or
                    local_y < sample_radius or local_y >= valid_chunk.shape[0] - sample_radius):
                return False

            rows = slice(local_y - sample_radius, local_y + sample_radius + 1)
            cols = slice(x - sample_radius, x + sample_radius + 1)
            if not np.all(valid_chunk[rows, cols]):
                return False

            # Every window pixel is valid here, so the per-band mean is taken
            # over the full window, for all bands at once.
            window = spectral_chunk[:, rows, cols]
            spectral_sample = list(window.reshape(window.shape[0], -1).mean(axis=1))

            # Pixel centre in map coordinates.
            geo_x, geo_y = gdal.ApplyGeoTransform(geotransform_input, x + 0.5, y + 0.5)

            if f:
                fields = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{x}", f"{y}"]
                fields.extend(f"{spec_val:.6f}" for spec_val in spectral_sample)
                f.write(",".join(fields) + "\n")

            x_out.append(geo_x)
            y_out.append(geo_y)
            spectral_out.append(spectral_sample)
            sampled_pixels.add((x, y))
            return True

        def interval_for_width(local_width):
            """Map a local width onto [min_interval, max_interval]: narrow water gets dense sampling."""
            if width_max > width_min:
                normalized_width = (local_width - width_min) / (width_max - width_min)
                step = int(min_interval + normalized_width * (max_interval - min_interval))
                return max(min_interval, min(max_interval, step))
            return base_interval

        try:
            if output_csvpath:
                # Opened inside the try block so the handle is closed even if
                # building the header fails.
                f = open(output_csvpath, "w")
                wavelengths = get_wavelengths_from_bil_header(bil_file)
                if wavelengths is not None and len(wavelengths) == num_bands:
                    band_labels = [f"{wavelength:.6f}" for wavelength in wavelengths]
                else:
                    # No usable wavelength list: fall back to band numbers.
                    band_labels = [f"band_{i + 1}" for i in range(num_bands)]
                f.write(",".join(["x_coord", "y_coord", "pixel_x", "pixel_y"] + band_labels) + "\n")

            print("正在分块生成采样点...")
            sample_count = 0

            row_step = base_interval if use_adaptive_sampling else interval
            total_chunks = math.ceil((im_height - 2 * sample_radius) / chunk_size)

            for chunk_idx in range(total_chunks):
                # Rows whose sampling windows are centred in this chunk.
                start_row = sample_radius + chunk_idx * chunk_size
                end_row = min(start_row + chunk_size, im_height - sample_radius)

                # Read range widened by the window radius.
                read_start = max(0, start_row - sample_radius)
                read_end = min(im_height, end_row + sample_radius)

                print(f"处理块 {chunk_idx + 1}/{total_chunks}: 行 {start_row}-{end_row}")

                # Continue the image-wide row grid instead of restarting it at
                # the top of each chunk.
                first_row = start_row + (sample_radius - start_row) % row_step
                grid_rows = range(first_row, end_row, row_step)
                valid_chunk = valid_area[read_start:read_end, :]

                # Skip the disk read when nothing can be sampled in this chunk.
                if len(grid_rows) > 0 and valid_chunk.any():
                    spectral_chunk = dataset_bil.ReadAsArray(
                        0, read_start, im_width, read_end - read_start
                    )
                    if spectral_chunk.ndim == 2:
                        # Single-band images come back 2-D; restore the band axis.
                        spectral_chunk = spectral_chunk[np.newaxis, :, :]

                    if use_adaptive_sampling:
                        print(f" 使用自适应采样(间隔范围: {min_interval}-{max_interval})...")
                        water_chunk = water_mask_raster[read_start:read_end, :]
                        width_chunk = width_map[read_start:read_end, :]

                        for y in grid_rows:
                            local_y = y - read_start
                            x = sample_radius
                            while x < im_width - sample_radius:
                                if water_chunk[local_y, x] > 0:
                                    if try_add_sample(x, y, local_y, spectral_chunk, valid_chunk):
                                        sample_count += 1
                                    # Advance by the width-dependent step.
                                    x += interval_for_width(width_chunk[local_y, x])
                                else:
                                    # Outside water: skip ahead with the base step.
                                    x += base_interval
                    else:
                        print(f" 使用固定间隔采样(间隔: {interval})...")
                        for y in grid_rows:
                            local_y = y - read_start
                            for x in range(sample_radius, im_width - sample_radius, interval):
                                if try_add_sample(x, y, local_y, spectral_chunk, valid_chunk):
                                    sample_count += 1

                    # Release the chunk before the next one is read.
                    del spectral_chunk

                print(f"块 {chunk_idx + 1} 完成,当前采样点总数: {sample_count}")

            print(f"所有块处理完成,成功生成 {sample_count} 个采样点")

        finally:
            if f:
                f.close()

        return x_out, y_out, np.array(spectral_out)

    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
|
||
|
||
|
||
def get_spectral_sampling_points(bil_file, water_mask_shp, severe_glint=None, output_csvpath=None,
                                 interval=100, sample_radius=1,
                                 use_adaptive_sampling=True, min_interval=10, max_interval=200):
    """
    Generate sampling points over water and extract their mean spectra, with the
    whole image read into memory at once.

    A sampling point is accepted only when every pixel of its
    (2 * sample_radius + 1) square window is water and not glint.

    Args:
        bil_file: str - path of the spectral image.
        water_mask_shp: str - water mask path (passed to create_water_mask_from_shp).
        severe_glint: str - glint mask raster path (optional).
        output_csvpath: str - output CSV path (optional); when given, the valid
            area raster is also written next to it.
        interval: int - grid spacing in pixels; used for rows, and for columns
            when adaptive sampling is off. Values below 1 are treated as 1.
        sample_radius: int - half size of the averaging window, in pixels.
        use_adaptive_sampling: bool - vary the column spacing with the local water width.
        min_interval: int - smallest adaptive spacing (values below 1 are treated as 1).
        max_interval: int - largest adaptive spacing (values below 1 are treated as 1).

    Returns:
        tuple: (x_coords, y_coords, spectral_data) - coordinate lists and the
        array of per-point mean spectra.

    Raises:
        ValueError: when an input cannot be opened, ``sample_radius`` is
            negative, the geotransform is not invertible, or a mask does not
            match the image size.
    """
    # Enable GDAL/OGR exceptions.
    gdal.UseExceptions()
    ogr.UseExceptions()

    if sample_radius < 0:
        raise ValueError(f"sample_radius 不能为负数: {sample_radius}")

    # A step below one pixel would stall the scan loops, so clamp every step.
    interval = max(1, int(interval))
    min_interval = max(1, int(min_interval))
    max_interval = max(1, int(max_interval))
    base_interval = min(interval, max_interval)

    try:
        dataset_bil = gdal.Open(bil_file)
        if dataset_bil is None:
            raise ValueError(f"无法打开bil文件: {bil_file}")

        im_width = dataset_bil.RasterXSize
        im_height = dataset_bil.RasterYSize
        num_bands = dataset_bil.RasterCount
        geotransform_input = dataset_bil.GetGeoTransform()

        print(f"bil文件信息: 宽度={im_width}, 高度={im_height}, 波段数={num_bands}")

        # Read every band of the image.
        print("正在读取光谱数据...")
        spectral_data_full = dataset_bil.ReadAsArray()
        if spectral_data_full.ndim == 2:
            # Single-band images come back 2-D; restore the band axis.
            spectral_data_full = spectral_data_full[np.newaxis, :, :]

        # Water mask on the image grid.
        print("正在处理水体掩膜...")
        water_mask_raster = create_water_mask_from_shp(water_mask_shp, bil_file)
        if water_mask_raster.shape != (im_height, im_width):
            raise ValueError(
                f"水体掩膜尺寸 {water_mask_raster.shape} 与影像尺寸 {(im_height, im_width)} 不一致")

        # Optional glint mask, grown by a small buffer.
        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            print("已加载耀斑掩膜")
            data_severe_glint = expand_glint_buffer(data_severe_glint, buffer_size=2)
            if data_severe_glint.shape != water_mask_raster.shape:
                raise ValueError(
                    f"耀斑掩膜尺寸 {data_severe_glint.shape} 与水体掩膜尺寸 {water_mask_raster.shape} 不一致")
        else:
            data_severe_glint = None
            print("未使用耀斑掩膜")

        # The inverse transform itself is not needed; the call is kept so that
        # a degenerate geotransform is still rejected.
        if gdal.InvGeoTransform(geotransform_input) is None:
            raise ValueError("无法计算逆仿射变换")

        # Valid area: water and not glint.
        if data_severe_glint is not None:
            valid_area = (water_mask_raster > 0) & (~(data_severe_glint > 0))
        else:
            valid_area = (water_mask_raster > 0)

        # Water width map for adaptive sampling.
        width_map = None
        width_min = width_max = None
        if use_adaptive_sampling:
            print("正在计算水体宽度(用于自适应采样)...")
            width_map = calculate_water_width(water_mask_raster)
            water_widths = width_map[water_mask_raster > 0] if width_map is not None else None
            # np.min/np.max raise on an empty selection, so an empty water mask
            # falls back to fixed spacing as well.
            if water_widths is not None and water_widths.size > 0:
                width_min = np.min(water_widths)
                width_max = np.max(water_widths)
                print(f"水体宽度范围: {width_min:.1f} - {width_max:.1f} 像元")
            else:
                print("警告: 无法计算水体宽度,将使用固定间隔采样")
                use_adaptive_sampling = False
                width_map = None

        # Save the valid area next to the CSV (optional).
        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.tif"
            write_bands(bil_file, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []
        spectral_out = []
        sampled_pixels = set()  # guards against sampling the same pixel twice
        f = None

        def try_add_sample(x, y):
            """Record the sample at image pixel (x, y); return True when it was accepted."""
            if (x, y) in sampled_pixels:
                return False

            # The whole window must lie inside the image.
            if (x < sample_radius or x >= im_width - sample_radius or
                    y < sample_radius or y >= im_height - sample_radius):
                return False

            rows = slice(y - sample_radius, y + sample_radius + 1)
            cols = slice(x - sample_radius, x + sample_radius + 1)
            if not np.all(valid_area[rows, cols]):
                return False

            # Every window pixel is valid here, so the per-band mean is taken
            # over the full window, for all bands at once.
            window = spectral_data_full[:, rows, cols]
            spectral_sample = list(window.reshape(window.shape[0], -1).mean(axis=1))

            # Pixel centre in map coordinates.
            geo_x, geo_y = gdal.ApplyGeoTransform(geotransform_input, x + 0.5, y + 0.5)

            if f:
                fields = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{x}", f"{y}"]
                fields.extend(f"{spec_val:.6f}" for spec_val in spectral_sample)
                f.write(",".join(fields) + "\n")

            x_out.append(geo_x)
            y_out.append(geo_y)
            spectral_out.append(spectral_sample)
            sampled_pixels.add((x, y))
            return True

        def interval_for_width(local_width):
            """Map a local width onto [min_interval, max_interval]: narrow water gets dense sampling."""
            if width_max > width_min:
                normalized_width = (local_width - width_min) / (width_max - width_min)
                step = int(min_interval + normalized_width * (max_interval - min_interval))
                return max(min_interval, min(max_interval, step))
            return base_interval

        try:
            if output_csvpath:
                # Opened inside the try block so the handle is closed even if
                # building the header fails.
                f = open(output_csvpath, "w")
                wavelengths = get_wavelengths_from_bil_header(bil_file)
                if wavelengths is not None and len(wavelengths) == num_bands:
                    band_labels = [f"{wavelength:.6f}" for wavelength in wavelengths]
                else:
                    # No usable wavelength list: fall back to band numbers.
                    band_labels = [f"band_{i + 1}" for i in range(num_bands)]
                f.write(",".join(["x_coord", "y_coord", "pixel_x", "pixel_y"] + band_labels) + "\n")

            print("正在生成采样点...")
            sample_count = 0

            if use_adaptive_sampling:
                print("使用自适应采样(根据水体宽度调整间隔)...")
                # Rows advance by the base step; columns by the width-dependent step.
                for y in range(sample_radius, im_height - sample_radius, base_interval):
                    x = sample_radius
                    while x < im_width - sample_radius:
                        if water_mask_raster[y, x] > 0:
                            if try_add_sample(x, y):
                                sample_count += 1
                            x += interval_for_width(width_map[y, x])
                        else:
                            # Outside water: skip ahead with the base step.
                            x += base_interval
            else:
                print(f"使用固定间隔采样(间隔: {interval})...")
                for y in range(sample_radius, im_height - sample_radius, interval):
                    for x in range(sample_radius, im_width - sample_radius, interval):
                        if try_add_sample(x, y):
                            sample_count += 1

            print(f"成功生成 {sample_count} 个采样点")

        finally:
            if f:
                f.close()

        return x_out, y_out, np.array(spectral_out)

    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
|
||
|
||
|
||
def create_water_mask_from_shp(shp_file, reference_raster):
    """
    Build a water mask array from a shapefile or from an existing mask raster.

    A ``.shp`` input is rasterised (burn value 1, background 0) onto the grid
    of ``reference_raster``. Any other extension is opened as a raster and its
    first band is returned unchanged.

    Args:
        shp_file: str - shapefile path, or a raster mask path (.dat/.tif ...).
        reference_raster: str - raster supplying size, geotransform and
            projection; only used for ``.shp`` input.

    Returns:
        numpy.ndarray - the water mask array.

    Raises:
        ValueError: when one of the input files cannot be opened.
    """
    try:
        extension = os.path.splitext(shp_file)[1].lower()

        if extension != '.shp':
            # Raster input: read band 1 as it is.
            raster_ds = gdal.Open(shp_file, gdal.GA_ReadOnly)
            if raster_ds is None:
                raise ValueError(f"无法打开栅格掩膜文件: {shp_file}")
            mask_array = raster_ds.GetRasterBand(1).ReadAsArray()
            raster_ds = None  # release the GDAL handle
            return mask_array

        # Vector input: take the target grid from the reference raster.
        ref_ds = gdal.Open(reference_raster)
        if ref_ds is None:
            raise ValueError(f"无法打开参考栅格文件: {reference_raster}")

        n_cols = ref_ds.RasterXSize
        n_rows = ref_ds.RasterYSize

        # In-memory byte raster sharing the reference georeferencing.
        target_ds = gdal.GetDriverByName('MEM').Create('', n_cols, n_rows, 1, gdal.GDT_Byte)
        target_ds.SetGeoTransform(ref_ds.GetGeoTransform())
        target_ds.SetProjection(ref_ds.GetProjection())

        target_band = target_ds.GetRasterBand(1)
        target_band.Fill(0)

        vector_ds = ogr.Open(shp_file)
        if vector_ds is None:
            raise ValueError(f"无法打开shp文件: {shp_file}")

        # Burn every feature of the layer into band 1 with value 1.
        vector_layer = vector_ds.GetLayer()
        gdal.RasterizeLayer(target_ds, [1], vector_layer, burn_values=[1])

        mask_array = target_band.ReadAsArray()

        # Release the GDAL/OGR handles.
        ref_ds = None
        target_ds = None
        vector_ds = None

        return mask_array

    except Exception as e:
        print(f"创建水体掩膜时发生错误: {str(e)}")
        raise
|
||
|
||
|
||
def expand_glint_buffer(glint_mask, buffer_size=2):
    """
    Grow the glint mask outwards so that a safety margin surrounds every glint area.

    Args:
        glint_mask: numpy.ndarray - glint mask (glint where value > 0), or None.
        buffer_size: int - growth in pixels; clamped to the range 1..2.

    Returns:
        numpy.ndarray - uint8 mask (1 = glint or buffer, 0 = elsewhere), or
        None when ``glint_mask`` is None.
    """
    if glint_mask is None:
        return None

    # Only a one- or two-pixel buffer is supported.
    buffer_size = int(buffer_size)
    if buffer_size < 1:
        buffer_size = 1
    elif buffer_size > 2:
        buffer_size = 2

    binary_mask = (glint_mask > 0).astype(np.uint8)

    # Square structuring element: 3x3 for one pixel, 5x5 for two.
    side = 2 * buffer_size + 1
    footprint = np.ones((side, side), dtype=np.uint8)

    grown_mask = ndimage.binary_dilation(binary_mask, structure=footprint).astype(np.uint8)

    expanded_pixels = np.count_nonzero(grown_mask) - np.count_nonzero(binary_mask)
    print(f"Glint边界外扩: 外扩 {buffer_size} 像素,新增 {expanded_pixels} 个像素")

    return grown_mask
|
||
|
||
|
||
def calculate_water_width(water_mask):
    """
    Estimate the local water width with a Euclidean distance transform.

    Args:
        water_mask: numpy.ndarray - water mask (water where value > 0).

    Returns:
        numpy.ndarray - float32 array holding, for every pixel, the distance to
        the nearest non-water pixel (roughly half the local width; 0 outside
        water), or None when the computation fails.
    """
    try:
        inside_water = (water_mask > 0).astype(bool)

        from scipy.ndimage import distance_transform_edt

        # The distance to the nearest background pixel is small in narrow
        # reaches and large in wide ones; it is used directly as the width
        # proxy rather than being doubled.
        return distance_transform_edt(inside_water).astype(np.float32)

    except Exception as e:
        print(f"计算水体宽度时发生错误: {str(e)}")
        return None
|
||
|
||
|
||
def detect_water_centerline(water_mask):
    """
    Extract the centre line (skeleton) of the water body.

    Args:
        water_mask: numpy.ndarray - water mask (water where value > 0).

    Returns:
        numpy.ndarray - uint8 mask with 1 on the centre line and 0 elsewhere,
        or None when skimage is unavailable or skeletonisation fails.
    """
    # Skeletonisation relies on the optional skimage dependency.
    if not SKIMAGE_AVAILABLE:
        print("警告: skimage未安装,无法检测主水轴线")
        return None

    try:
        centerline = skeletonize((water_mask > 0).astype(bool))

        skeleton_pixels = np.count_nonzero(centerline > 0)
        print(f"主水轴线检测: 检测到 {skeleton_pixels} 个中心线像素")

        return centerline.astype(np.uint8)

    except Exception as e:
        print(f"检测主水轴线时发生错误: {str(e)}")
        return None
|
||
|
||
|
||
def detect_water_features(water_mask, centerline_mask=None):
    """
    Detect feature points of the water body on its centre line: junctions
    (branches / confluences) and bends.

    Junctions are centre-line pixels with at least three centre-line pixels in
    their 8-neighbourhood. Bends are centre-line pixels whose curvature proxy
    (absolute Laplacian of the smoothed Sobel gradient magnitude) exceeds the
    75th percentile of that same proxy over the centre line.

    Args:
        water_mask: numpy.ndarray - water mask (water where value > 0).
        centerline_mask: numpy.ndarray - centre-line mask (optional); when
            omitted the centre line is derived from ``water_mask`` with skimage.

    Returns:
        numpy.ndarray - uint8 mask with 1 at feature points and 0 elsewhere;
        all zeros when no centre line is available or detection fails.
    """
    try:
        if centerline_mask is not None:
            skeleton = (centerline_mask > 0).astype(bool)
        elif SKIMAGE_AVAILABLE:
            # No centre line supplied: derive it from the water mask.
            skeleton = skeletonize((water_mask > 0).astype(bool))
        else:
            print("警告: 无法检测地形特征点(需要中心线)")
            return np.zeros_like(water_mask, dtype=np.uint8)

        # Count the centre-line pixels in each 8-neighbourhood.
        kernel = np.array([[1, 1, 1],
                           [1, 0, 1],
                           [1, 1, 1]], dtype=np.uint8)
        neighbor_count = ndimage.convolve(skeleton.astype(np.uint8), kernel, mode='constant')

        # Junctions: three or more neighbours on the centre line.
        branch_points = skeleton & (neighbor_count >= 3)

        if skeleton.any():
            # Curvature proxy: Laplacian of the smoothed gradient magnitude.
            skeleton_float = skeleton.astype(float)
            sobel_x = ndimage.sobel(skeleton_float, axis=1)
            sobel_y = ndimage.sobel(skeleton_float, axis=0)
            gradient_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
            smoothed_gradient = ndimage.gaussian_filter(gradient_magnitude, sigma=1.0)
            curvature = np.abs(ndimage.laplace(smoothed_gradient))

            # Threshold and comparison use the same absolute quantity, so
            # roughly the strongest quarter of centre-line pixels is kept.
            curvature_threshold = np.percentile(curvature[skeleton], 75)
            bend_points = skeleton & (curvature > curvature_threshold)
        else:
            # An empty centre line has no percentile; there are no bends.
            bend_points = np.zeros_like(skeleton)

        features = (branch_points | bend_points).astype(np.uint8)

        feature_count = np.sum(features > 0)
        branch_count = np.sum(branch_points > 0)
        bend_count = np.sum(bend_points > 0)

        print(f"地形特征点检测: 检测到 {feature_count} 个特征点(分支/汇入: {branch_count}, 弯头: {bend_count})")

        return features

    except Exception as e:
        print(f"检测地形特征点时发生错误: {str(e)}")
        return np.zeros_like(water_mask, dtype=np.uint8)
|
||
|
||
|
||
def get_coor_base_interval(water_mask, severe_glint=None, output_csvpath=None, interval=100):
    """
    Pick one sampling coordinate per grid cell over the valid water area.

    The image extent is divided into square cells of ``interval`` pixels (never
    smaller than one pixel). In each cell the first pixel, in row-major order,
    that is water and not glint is reported at its pixel centre.

    Args:
        water_mask: str - water mask raster path (water where value > 0).
        severe_glint: str - glint mask raster path (optional); when given, its
            size and geotransform define the grid.
        output_csvpath: str - output CSV path (optional). When given, the rows
            ``x,y,pixel_x,pixel_y`` are written without a header line and the
            valid area raster is saved next to the CSV.
        interval: int - cell size in pixels.

    Returns:
        tuple: (x_coords, y_coords) - map coordinates of the selected pixels.

    Raises:
        ValueError: when an input cannot be opened, the two masks differ in
            size, or the geotransform is not invertible.
    """
    # Enable GDAL exceptions.
    gdal.UseExceptions()

    try:
        dataset_water_mask = gdal.Open(water_mask)
        if dataset_water_mask is None:
            raise ValueError(f"无法打开水体掩膜文件: {water_mask}")
        data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()

        # Optional glint mask; when present it supplies the grid geometry.
        if severe_glint is not None:
            dataset_severe_glint = gdal.Open(severe_glint)
            if dataset_severe_glint is None:
                raise ValueError(f"无法打开耀斑掩膜文件: {severe_glint}")
            data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()
            if data_severe_glint.shape != data_water_mask.shape:
                raise ValueError(
                    f"耀斑掩膜尺寸 {data_severe_glint.shape} 与水体掩膜尺寸 {data_water_mask.shape} 不一致")
            geometry_source = dataset_severe_glint
        else:
            data_severe_glint = None
            geometry_source = dataset_water_mask

        im_width = geometry_source.RasterXSize
        im_height = geometry_source.RasterYSize
        geotransform_input = geometry_source.GetGeoTransform()

        inv_geotransform_input = gdal.InvGeoTransform(geotransform_input)
        if inv_geotransform_input is None:
            raise ValueError("无法计算逆仿射变换")

        # Image extent in map coordinates.
        x_min = geotransform_input[0]
        y_max = geotransform_input[3]
        x_max = x_min + im_width * geotransform_input[1]
        y_min = y_max + im_height * geotransform_input[5]

        # Cell size in map units, at least one pixel.
        pixel_size = abs(geotransform_input[1])
        grid_size = max(pixel_size * interval, pixel_size)

        # Number of cells along each axis.
        dx = max(1, math.ceil((x_max - x_min) / grid_size))
        dy = max(1, math.ceil((y_max - y_min) / grid_size))

        # Valid area: water and not glint.
        if data_severe_glint is not None:
            valid_area = (data_water_mask > 0) & (~(data_severe_glint > 0))
        else:
            valid_area = (data_water_mask > 0)

        # The valid-area raster is named after the CSV, so it can only be
        # written when a CSV path was supplied.
        if output_csvpath:
            valid_area_path = os.path.splitext(output_csvpath)[0] + "_valid_area.tif"
            write_bands(water_mask, valid_area_path, valid_area.astype(np.uint8))

        x_out = []
        y_out = []

        f = open(output_csvpath, "w") if output_csvpath else None

        try:
            for row in range(dy):
                for column in range(dx):
                    # Cell corners in map coordinates (Y decreases downwards).
                    left = x_min + column * grid_size
                    top = y_max - row * grid_size
                    right = x_min + (column + 1) * grid_size
                    bottom = y_max - (row + 1) * grid_size

                    # Corners in pixel coordinates.
                    top_left_px = gdal.ApplyGeoTransform(inv_geotransform_input, left, top)
                    bottom_right_px = gdal.ApplyGeoTransform(inv_geotransform_input, right, bottom)

                    # Clip the cell to the image; +1 keeps the last pixel included.
                    x1 = max(0, int(top_left_px[0]))
                    y1 = max(0, int(top_left_px[1]))
                    x2 = min(im_width, int(bottom_right_px[0]) + 1)
                    y2 = min(im_height, int(bottom_right_px[1]) + 1)

                    if x2 <= x1 or y2 <= y1:
                        continue

                    valid_pixels = np.argwhere(valid_area[y1:y2, x1:x2])
                    if valid_pixels.size == 0:
                        continue

                    # First valid pixel of the cell, as image pixel coordinates.
                    local_y, local_x = valid_pixels[0]
                    global_x = x1 + local_x
                    global_y = y1 + local_y

                    # Pixel centre in map coordinates.
                    geo_x, geo_y = gdal.ApplyGeoTransform(
                        geotransform_input,
                        global_x + 0.5,
                        global_y + 0.5
                    )

                    if f:
                        line_parts = [f"{geo_x:.6f}", f"{geo_y:.6f}", f"{global_x}", f"{global_y}"]
                        f.write(",".join(line_parts) + "\n")
                    x_out.append(geo_x)
                    y_out.append(geo_y)

        finally:
            if f:
                f.close()

        return x_out, y_out

    except Exception as e:
        print(f"处理过程中发生错误: {str(e)}")
        raise
|
||
|
||
|
||
# 使用示例
|
||
if __name__ == "__main__":
    # Example: chunked spectral sampling (memory friendly for large images).
    bil_file = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"
    water_mask_shp = r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp"
    severe_glint = r"D:\BaiduNetdiskDownload\yaobao\find_glint\result3_glint_otsu"
    output_csvpath = r"D:\BaiduNetdiskDownload\yaobao\csv\spectral_sampling_results.csv"

    # Grid parameters.
    interval = 50  # base spacing in pixels; used when use_adaptive_sampling is False
    sample_radius = 5  # radius of the averaging window, in pixels
    chunk_size = 1000  # rows per chunk; tune to the available memory (500-2000 suggested)

    # Adaptive sampling parameters.
    use_adaptive_sampling = True  # vary the spacing with the water width
    min_interval = 10  # smallest spacing, for narrow inflow channels
    max_interval = 200  # largest spacing, for the wide reservoir area

    try:
        x_coords, y_coords, spectral_data = get_spectral_sampling_points_chunked(
            bil_file=bil_file,
            water_mask_shp=water_mask_shp,
            severe_glint=severe_glint,
            output_csvpath=output_csvpath,
            interval=interval,
            sample_radius=sample_radius,
            chunk_size=chunk_size,
            use_adaptive_sampling=use_adaptive_sampling,
            min_interval=min_interval,
            max_interval=max_interval,
        )
        print(f"成功生成 {len(x_coords)} 个采样点")
        print(f"每个采样点包含 {spectral_data.shape[1]} 个波段的光谱数据")
        print(f"光谱数据形状: {spectral_data.shape}")

    except Exception as e:
        print(f"处理失败: {str(e)}")

    # Example of the original grid-coordinate function (kept for backward compatibility):
    # water_mask = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\1.mask\mask.dat"
    # severe_glint = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\2.glint\ref_mosaic_1m_severe_glint"
    # output_csvpath = r"D:\hsi\application\LICA_Work\laodaohe\preprocession\liuyang-guitang2\5.interval\coor_interval_3.csv"
    # interval = 30
    # x_coords, y_coords = get_coor_base_interval(water_mask, severe_glint, output_csvpath, interval)
    # print(f"成功生成 {len(x_coords)} 个采样点")