Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

View File

@ -0,0 +1,367 @@
import numpy as np
# import preprocessing
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
print("警告: GDAL未安装将使用numpy处理模式")
try:
from tqdm import tqdm
TQDM_AVAILABLE = True
except ImportError:
TQDM_AVAILABLE = False
# 如果tqdm不可用定义一个简单的包装器
def tqdm(iterable, desc=None, total=None, **kwargs):
    """Minimal stand-in used when the real tqdm package is not installed.

    Returns *iterable* unchanged. Progress-bar options (*desc*, *total*,
    and any other keyword the real tqdm accepts, e.g. ``unit``) are
    accepted and ignored so call sites do not need to special-case the
    fallback.
    """
    return iterable
class Goodman:
    def __init__(self, im_aligned, NIR_lower = 25, NIR_upper = 37, A = 0.000019, B = 0.1,
                 use_gdal=True, chunk_size=None, water_mask=None, output_path=None):
        """Goodman et al. per-pixel sunglint correction.

        :param im_aligned (np.ndarray or str): band aligned and calibrated & corrected reflectance image;
            may be a numpy array or a file path readable by GDAL
        :param NIR_lower (int): band index which corresponds to 641.93nm, closest band to 640nm
        :param NIR_upper (int): band index which corresponds to 751.49nm, closest band to 750nm
        :param A (float): the value in Goodman et al's paper, using AVIRIS reflectance (rather than radiance) data
        :param B (float): the value in Goodman et al's paper, using AVIRIS reflectance (rather than radiance) data
            see Goodman et al, which corrects each pixel independently. The NIR radiance is subtracted
            from the radiance at each wavelength, but a wavelength-independent offset is also added.
            It is not clear how A and B were chosen, but an optimization for a case where in situ data
            is available would enable values to be found.
        :param use_gdal (bool): whether to use GDAL-accelerated processing (requires GDAL to be
            available and the input to be a file path or a large array)
        :param chunk_size (int): deprecated; chunked processing was replaced by band-by-band processing
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
            may be a numpy array, a raster file path (.dat/.tif) or a shapefile path (.shp);
            if None, the whole image is processed
        :param output_path (str or None): output file path; if provided, the corrected image is
            saved there; if None, nothing is written
        """
        self.im_aligned = im_aligned
        self.NIR_lower = NIR_lower
        self.NIR_upper = NIR_upper
        self.A = A
        self.B = B
        # GDAL use is only honoured when the library actually imported.
        self.use_gdal = use_gdal and GDAL_AVAILABLE
        self.chunk_size = chunk_size
        self.is_file_path = isinstance(im_aligned, str)
        self.output_path = output_path
        # Image dimensions must be known before the mask is loaded
        # (the mask loader validates its shape against them).
        if self.is_file_path:
            if not self.use_gdal:
                raise ValueError("输入为文件路径时必须安装GDAL")
            self.dataset = gdal.Open(im_aligned, gdal.GA_ReadOnly)
            if self.dataset is None:
                raise ValueError(f"无法打开影像文件: {im_aligned}")
            self.height = self.dataset.RasterYSize
            self.width = self.dataset.RasterXSize
            self.n_bands = self.dataset.RasterCount
        else:
            self.dataset = None
            self.height = im_aligned.shape[0]
            self.width = im_aligned.shape[1]
            self.n_bands = im_aligned.shape[-1]
        # Load the water mask (after the image size is known).
        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """Normalise the water mask input into a 0/1 uint8 array.

        :param water_mask: None, a numpy array, a raster file path
            (.dat/.tif) or a shapefile path (.shp)
        :return: np.ndarray (uint8, 1 = water, 0 = non-water) or None
        :raises ValueError: on size mismatch, unreadable files, or an
            unsupported mask type
        """
        if water_mask is None:
            return None
        # Already a numpy array: validate the size and binarise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)  # force a 0/1 mask
        # A file path.
        if isinstance(water_mask, str):
            if not GDAL_AVAILABLE:
                raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
            # Shapefile: rasterise against the input raster's grid.
            if water_mask.lower().endswith('.shp'):
                # A georeferenced raster is required as the rasterisation target.
                if self.is_file_path:
                    ref_path = self.im_aligned
                else:
                    raise ValueError("输入为numpy数组时无法从shp文件创建掩膜需要参考栅格")
                try:
                    from osgeo import ogr
                    ref_dataset = gdal.Open(ref_path, gdal.GA_ReadOnly)
                    if ref_dataset is None:
                        raise ValueError(f"无法打开参考栅格文件: {ref_path}")
                    geotransform = ref_dataset.GetGeoTransform()
                    projection = ref_dataset.GetProjection()
                    width = ref_dataset.RasterXSize
                    height = ref_dataset.RasterYSize
                    # In-memory byte raster, initialised to 0 (non-water).
                    mem_driver = gdal.GetDriverByName('MEM')
                    mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
                    mask_dataset.SetGeoTransform(geotransform)
                    mask_dataset.SetProjection(projection)
                    mask_band = mask_dataset.GetRasterBand(1)
                    mask_band.Fill(0)
                    # Burn the shapefile polygons in as 1s.
                    shp_dataset = ogr.Open(water_mask)
                    if shp_dataset is None:
                        raise ValueError(f"无法打开shp文件: {water_mask}")
                    layer = shp_dataset.GetLayer()
                    gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
                    water_mask_array = mask_band.ReadAsArray()
                    # Release GDAL/OGR handles explicitly.
                    ref_dataset = None
                    mask_dataset = None
                    shp_dataset = None
                    return (water_mask_array > 0).astype(np.uint8)
                except Exception as e:
                    # NOTE(review): this also re-wraps the ValueErrors raised
                    # above, losing their type — intentional? confirm.
                    raise ValueError(f"从shp文件创建掩膜时出错: {e}")
            else:
                # Raster mask file: read band 1 and binarise.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")
                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None
                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
                return (mask_array > 0).astype(np.uint8)
        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def _get_corrected_bands_numpy(self):
        """Correct all bands with pure numpy (small images / no GDAL).

        The input array is already in memory; processing is band by band
        so that only one corrected band is created at a time.
        Peak memory = original array + 2 NIR bands + current band.
        :return: list of np.ndarray, one corrected band per input band
        """
        # Pre-extract the two NIR reference bands used for every band's
        # correction; they stay in memory for the whole loop.
        R_640 = self.im_aligned[:,:,self.NIR_lower]
        R_750 = self.im_aligned[:,:,self.NIR_upper]
        # Pre-compute the loop-invariant difference term.
        diff_640_750 = R_640 - R_750
        corrected_bands = []
        # Boolean water mask, if any.
        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
        # Band-by-band: process one band, append, move on.
        for i in tqdm(range(self.n_bands), desc="处理波段 (numpy)", total=self.n_bands):
            # A view into the source array, not a copy.
            R = self.im_aligned[:,:,i]
            # Goodman correction: subtract NIR, add offset A + B*(R640-R750).
            corrected_band = R - R_750 + self.A + self.B * diff_640_750
            # Clip negatives to 0 in place.
            np.maximum(corrected_band, 0, out=corrected_band)
            # With a water mask, only water pixels are corrected;
            # non-water pixels keep their original reflectance.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)
            corrected_bands.append(corrected_band)
        return corrected_bands

    def _get_corrected_bands_gdal(self):
        """Correct all bands by reading them one at a time through GDAL.

        Whole bands are read (no tiling). Peak memory = 2 NIR bands +
        the band currently being processed + the accumulated result list.
        :return: list of np.ndarray, one corrected band per input band
        """
        corrected_bands = []
        # GDAL band indices are 1-based, hence the +1.
        band_640 = self.dataset.GetRasterBand(self.NIR_lower + 1)
        band_750 = self.dataset.GetRasterBand(self.NIR_upper + 1)
        # The NIR reference bands are read once and kept for the whole loop.
        R_640 = band_640.ReadAsArray().astype(np.float32)
        R_750 = band_750.ReadAsArray().astype(np.float32)
        diff_640_750 = R_640 - R_750
        # Boolean water mask, if any.
        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
        # Band-by-band: only one source band in memory at a time.
        for i in tqdm(range(self.n_bands), desc="处理波段 (GDAL)", total=self.n_bands):
            current_band = self.dataset.GetRasterBand(i + 1)
            R = current_band.ReadAsArray().astype(np.float32)
            # Same Goodman formula as the numpy path.
            corrected_band = R - R_750 + self.A + self.B * diff_640_750
            np.maximum(corrected_band, 0, out=corrected_band)
            # Only water pixels are corrected when a mask exists.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)
            corrected_bands.append(corrected_band)
            # Drop the raw band promptly to keep peak memory down.
            del R
        return corrected_bands

    def _get_corrected_bands_gdal_mem(self):
        """Process a numpy-array input through GDAL's in-memory driver.

        Copies the array into a MEM dataset, then reuses the band-by-band
        GDAL path by temporarily swapping ``self.dataset``.
        :return: list of np.ndarray, one corrected band per input band
        """
        # Build an in-memory GDAL dataset mirroring the numpy array.
        driver = gdal.GetDriverByName('MEM')
        mem_dataset = driver.Create('', self.width, self.height, self.n_bands, gdal.GDT_Float32)
        # Write each band into the MEM dataset (with progress display).
        for i in tqdm(range(self.n_bands), desc="加载波段到内存", total=self.n_bands):
            band = mem_dataset.GetRasterBand(i + 1)
            band.WriteArray(self.im_aligned[:,:,i])
            band.FlushCache()
        # Temporarily point self.dataset at the MEM dataset.
        original_dataset = self.dataset
        self.dataset = mem_dataset
        try:
            # Reuse the band-by-band GDAL implementation.
            result = self._get_corrected_bands_gdal()
        finally:
            # Always restore the original dataset and drop the MEM one.
            self.dataset = original_dataset
            mem_dataset = None
        return result

    def _save_corrected_bands(self, corrected_bands):
        """Write the corrected bands to an ENVI/BSQ file at ``output_path``.

        Bands are written one at a time instead of being stacked first,
        to keep memory usage down. The extension is forced to ``.bsq``;
        GDAL's ENVI driver also emits the companion ``.hdr`` file.
        :param corrected_bands: list of corrected band arrays
        :raises ImportError: if GDAL is unavailable
        :raises ValueError: if the list is empty or the file cannot be created
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装无法保存影像文件")
        if self.output_path is None:
            return
        import os
        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        # Derive dimensions from the first band (avoids stacking).
        if not corrected_bands:
            raise ValueError("校正后的波段列表为空")
        first_band = corrected_bands[0]
        height, width = first_band.shape
        n_bands = len(corrected_bands)
        # Carry over georeferencing from the source raster when available.
        if self.is_file_path and self.dataset is not None:
            geotransform = self.dataset.GetGeoTransform()
            projection = self.dataset.GetProjection()
        else:
            # No geo information: fall back to an identity-like transform.
            geotransform = (0, 1, 0, 0, 0, -1)
            projection = ""
        # Force the .bsq extension (ENVI, band-sequential layout).
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path
        # The ENVI driver defaults to BSQ interleave.
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
        # Creating the dataset also creates the .hdr header file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")
        try:
            # Attach georeferencing if we have any.
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)
            # Write band by band straight from the list (no stacking).
            for i in tqdm(range(n_bands), desc="保存波段", total=n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_bands[i])
                band.FlushCache()
        finally:
            # Dropping the reference closes and flushes the GDAL dataset.
            dataset = None
        # Report whether the .hdr companion file appeared as expected.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Run the correction, auto-selecting the best processing path.

        File-path input uses GDAL directly; numpy input uses the GDAL MEM
        driver for large images (> 1e8 elements) or pure numpy otherwise.
        Saves the result if ``output_path`` was provided.
        :return: list of corrected band arrays
        """
        # File path input: read and process through GDAL.
        if self.is_file_path:
            if self.use_gdal:
                corrected_bands = self._get_corrected_bands_gdal()
            else:
                raise ValueError("输入为文件路径时必须安装GDAL")
        else:
            # numpy array input.
            if self.use_gdal and self.height * self.width * self.n_bands > 100000000:
                # Large image: go through the GDAL MEM driver band by band.
                corrected_bands = self._get_corrected_bands_gdal_mem()
            else:
                # Small image: plain numpy.
                corrected_bands = self._get_corrected_bands_numpy()
        # Persist the result when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)
        return corrected_bands

    def __del__(self):
        """Release the GDAL dataset handle for file-path inputs.

        NOTE(review): if __init__ raised before setting ``dataset`` or
        ``is_file_path``, this will hit an AttributeError during garbage
        collection — consider guarding with getattr; confirm.
        """
        if self.dataset is not None and self.is_file_path:
            self.dataset = None

View File

@ -0,0 +1,290 @@
import numpy as np
# import preprocessing
import os
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
class Hedley:
    def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
        """Hedley regression-based sunglint correction.

        :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
        :param shp_path (str, optional): path to shapefile (.shp) defining the region containing the
            glint region in deep water. If None, uses the entire image. The shapefile can use pixel
            coordinates or geographic coordinates.
        :param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which
            corresponds closely to the NIR band in Micasense
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
            may be a numpy array or a raster file path (.dat/.tif); if None, the whole image is processed
        :param output_path (str or None): output file path; if provided, the corrected image is saved
        """
        self.im_aligned = im_aligned
        self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
        self.NIR_band = NIR_band
        self.n_bands = im_aligned.shape[-1]
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path
        # Load the water mask first; R_min is then computed inside it.
        self.water_mask = self._load_water_mask(water_mask)
        # ravel() avoids the copy flatten() would make.
        # R_min = 5th percentile of the NIR band (ambient NIR level);
        # restricted to water pixels when a mask exists.
        # NOTE(review): np.percentile's `interpolation=` keyword was renamed
        # to `method=` in NumPy 1.22 and removed in NumPy 2.0 — confirm the
        # pinned NumPy version supports it.
        if self.water_mask is not None:
            nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
            self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
        else:
            self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')

    def _read_shp_to_bbox(self, shp_path):
        """Read a shapefile and extract its bounding box in pixel terms.

        Uses geopandas when available, otherwise fiona + shapely.
        NOTE(review): the geometry bounds are cast directly to integer
        pixel indices — geographic (non-pixel) coordinates are not
        transformed through a geotransform here; confirm inputs are in
        pixel coordinates.
        :param shp_path (str): shapefile path
        :return: tuple ((x1,y1),(x2,y2)); (x1,y1) is the upper-left
            corner, (x2,y2) the lower-right corner
        :raises FileNotFoundError: if the shapefile does not exist
        :raises ValueError: if the shapefile cannot be read
        """
        if not os.path.exists(shp_path):
            raise FileNotFoundError(f"Shapefile not found: {shp_path}")
        try:
            try:
                import geopandas as gpd
                gdf = gpd.read_file(shp_path)
                # Combined bounding box of all geometries.
                bounds = gdf.total_bounds  # [minx, miny, maxx, maxy]
                min_x, min_y, max_x, max_y = bounds
            except ImportError:
                # geopandas unavailable: accumulate bounds via fiona.
                import fiona
                from shapely.geometry import shape
                min_x = float('inf')
                min_y = float('inf')
                max_x = float('-inf')
                max_y = float('-inf')
                with fiona.open(shp_path) as shp:
                    for feature in shp:
                        geom = shape(feature['geometry'])
                        if geom:
                            bounds = geom.bounds
                            min_x = min(min_x, bounds[0])
                            min_y = min(min_y, bounds[1])
                            max_x = max(max_x, bounds[2])
                            max_y = max(max_y, bounds[3])
            # Convert to integer pixel coordinates clipped to the image.
            x1 = max(0, int(min_x))
            y1 = max(0, int(min_y))
            x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
            y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
            return ((x1, y1), (x2, y2))
        except Exception as e:
            raise ValueError(f"Error reading shapefile {shp_path}: {e}")

    def _load_water_mask(self, water_mask):
        """Normalise the water mask input into a 0/1 uint8 array.

        :param water_mask: None, a numpy array, or a raster file path
            (.dat/.tif); shapefile paths are rejected because this class
            holds no raster reference to rasterise against
        :return: np.ndarray (uint8, 1 = water, 0 = non-water) or None
        :raises ValueError: on size mismatch, unreadable files, or an
            unsupported mask type
        """
        if water_mask is None:
            return None
        # Already a numpy array: validate the size and binarise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)  # force a 0/1 mask
        # A file path.
        if isinstance(water_mask, str):
            try:
                from osgeo import gdal, ogr
            except ImportError:
                raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
            # Shapefile input cannot be rasterised without a raster reference.
            if water_mask.lower().endswith('.shp'):
                raise ValueError("Hedley类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
            else:
                # Raster mask file: read band 1 and binarise.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")
                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None
                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
                return (mask_array > 0).astype(np.uint8)
        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def covariance_NIR(self,NIR,b):
        """Regression slope of band *b* against the NIR band.

        NIR & b are 1-D vectors of reflectance samples.
        Returns cov(NIR, b) / var(NIR), i.e. the least-squares slope,
        or 0.0 when the NIR variance is zero.
        """
        # NOTE(review): n is computed but never used.
        n = len(NIR)
        # Compute the means once instead of inside each expression.
        nir_mean = np.mean(NIR)
        b_mean = np.mean(b)
        # Covariance and NIR variance via mean of centred products.
        pij = np.mean((NIR - nir_mean) * (b - b_mean))
        pjj = np.mean((NIR - nir_mean) ** 2)
        # Guard against division by zero (constant NIR).
        return pij / pjj if pjj != 0 else 0.0

    def correlation_bands_reflectance(self):
        """Regression slopes between the NIR band and every other band.

        Samples come from the bbox region when one was supplied,
        restricted further to water pixels when a mask exists.
        :return: list of per-band slopes (length n_bands)
        """
        # If bbox is None, use the entire image.
        if self.bbox is None:
            # Views, not copies; flattening happens only where needed.
            im_region = self.im_aligned
            mask_region = self.water_mask
        else:
            ((x1,y1),(x2,y2)) = self.bbox
            im_region = self.im_aligned[y1:y2,x1:x2,:]
            mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
        # With a water mask, sample only water pixels.
        if mask_region is not None:
            mask_bool = mask_region.astype(bool)
            if mask_bool.any():
                # Extract NIR samples inside the mask only.
                NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
            else:
                # Empty mask: fall back to the whole region.
                NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
                mask_bool = None
        else:
            NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
            mask_bool = None
        # One slope per band against the same NIR sample vector.
        corr_list = []
        for v in range(self.n_bands):
            if mask_bool is not None and mask_bool.any():
                band_reflectance = im_region[:,:,v][mask_bool]
            else:
                band_reflectance = im_region[:,:,v].ravel()
            corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
            corr_list.append(corr)
        return corr_list

    def _save_corrected_bands(self, corrected_bands):
        """Write the corrected bands to an ENVI/BSQ file at ``output_path``.

        The band list is stacked into one array first (unlike Goodman's
        band-by-band writer). A default geotransform is used because this
        class has no source raster to copy georeferencing from.
        :param corrected_bands: list of corrected band arrays
        :raises ImportError: if GDAL is unavailable
        :raises ValueError: if the file cannot be created
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装无法保存影像文件")
        if self.output_path is None:
            return
        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        # Stack the band list into an (H, W, B) array.
        corrected_array = np.stack(corrected_bands, axis=2)
        # No geo information available: use defaults.
        geotransform = (0, 1, 0, 0, 0, -1)
        projection = ""
        # Force the .bsq extension (ENVI, band-sequential layout).
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path
        # The ENVI driver defaults to BSQ interleave.
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
        height, width, n_bands = corrected_array.shape
        # Creating the dataset also creates the .hdr header file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")
        try:
            # Attach the (default) georeferencing.
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)
            # Write each band (BSQ stores bands sequentially).
            for i in range(n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_array[:, :, i])
                band.FlushCache()
        finally:
            # Dropping the reference closes and flushes the GDAL dataset.
            dataset = None
        # Report whether the .hdr companion file appeared as expected.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Apply the Hedley correction: R' = R - b * (NIR - R_min).

        Correction is done in reflectance. Saves the result if
        ``output_path`` was provided.
        :return: list of corrected band arrays
        """
        corr = self.correlation_bands_reflectance()
        NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
        # Pre-compute NIR - R_min once; it is reused for every band.
        NIR_diff = NIR_reflectance - self.R_min
        # Boolean water mask, if any.
        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
        corrected_bands = []
        for band_number in range(self.n_bands):  # iterate across bands
            b = corr[band_number]
            R = self.im_aligned[:,:,band_number]
            # Subtract the regressed glint contribution.
            corrected_band = R - b * NIR_diff
            # Only water pixels are corrected when a mask exists.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)
            corrected_bands.append(corrected_band)
        # Persist the result when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)
        return corrected_bands

View File

@ -0,0 +1,313 @@
import numpy as np
# import preprocessing
import os
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
class Kutser:
    def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
        """Kutser oxygen-absorption-based sunglint correction.

        :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
        :param shp_path (str, optional): path to shapefile (.shp) defining the region containing the
            glint region in deep water. If None, uses the entire image. The shapefile can use pixel
            coordinates or geographic coordinates.
        :param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
        :param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
        :param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
            see Kutser, Vahtmäe and Praks
        :param NIR_band (int): band index used for the (currently unused) R_min estimate
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
            may be a numpy array or a raster file path (.dat/.tif); if None, the whole image is processed
        :param output_path (str or None): output file path; if provided, the corrected image is saved
        """
        self.im_aligned = im_aligned
        self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
        self.oxy_band = oxy_band
        self.lower_oxy = lower_oxy
        self.upper_oxy = upper_oxy
        self.NIR_band = NIR_band
        self.n_bands = im_aligned.shape[-1]
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path
        # Load the water mask first; R_min is then computed inside it.
        self.water_mask = self._load_water_mask(water_mask)
        # ravel() avoids the copy flatten() would make.
        # R_min = 5th percentile of the NIR band, restricted to water
        # pixels when a mask exists.
        # NOTE(review): np.percentile's `interpolation=` keyword was renamed
        # to `method=` in NumPy 1.22 and removed in NumPy 2.0 — confirm the
        # pinned NumPy version supports it.
        if self.water_mask is not None:
            nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
            self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
        else:
            self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')

    def _read_shp_to_bbox(self, shp_path):
        """Read a shapefile and extract its bounding box in pixel terms.

        Uses geopandas when available, otherwise fiona + shapely.
        NOTE(review): the geometry bounds are cast directly to integer
        pixel indices — geographic (non-pixel) coordinates are not
        transformed through a geotransform here; confirm inputs are in
        pixel coordinates.
        :param shp_path (str): shapefile path
        :return: tuple ((x1,y1),(x2,y2)); (x1,y1) is the upper-left
            corner, (x2,y2) the lower-right corner
        :raises FileNotFoundError: if the shapefile does not exist
        :raises ValueError: if the shapefile cannot be read
        """
        if not os.path.exists(shp_path):
            raise FileNotFoundError(f"Shapefile not found: {shp_path}")
        try:
            try:
                import geopandas as gpd
                gdf = gpd.read_file(shp_path)
                # Combined bounding box of all geometries.
                bounds = gdf.total_bounds  # [minx, miny, maxx, maxy]
                min_x, min_y, max_x, max_y = bounds
            except ImportError:
                # geopandas unavailable: accumulate bounds via fiona.
                import fiona
                from shapely.geometry import shape
                min_x = float('inf')
                min_y = float('inf')
                max_x = float('-inf')
                max_y = float('-inf')
                with fiona.open(shp_path) as shp:
                    for feature in shp:
                        geom = shape(feature['geometry'])
                        if geom:
                            bounds = geom.bounds
                            min_x = min(min_x, bounds[0])
                            min_y = min(min_y, bounds[1])
                            max_x = max(max_x, bounds[2])
                            max_y = max(max_y, bounds[3])
            # Convert to integer pixel coordinates clipped to the image.
            x1 = max(0, int(min_x))
            y1 = max(0, int(min_y))
            x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
            y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
            return ((x1, y1), (x2, y2))
        except Exception as e:
            raise ValueError(f"Error reading shapefile {shp_path}: {e}")

    def _load_water_mask(self, water_mask):
        """Normalise the water mask input into a 0/1 uint8 array.

        :param water_mask: None, a numpy array, or a raster file path
            (.dat/.tif); shapefile paths are rejected because this class
            holds no raster reference to rasterise against
        :return: np.ndarray (uint8, 1 = water, 0 = non-water) or None
        :raises ValueError: on size mismatch, unreadable files, or an
            unsupported mask type
        """
        if water_mask is None:
            return None
        # Already a numpy array: validate the size and binarise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)  # force a 0/1 mask
        # A file path.
        if isinstance(water_mask, str):
            try:
                from osgeo import gdal, ogr
            except ImportError:
                raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
            # Shapefile input cannot be rasterised without a raster reference.
            if water_mask.lower().endswith('.shp'):
                raise ValueError("Kutser类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
            else:
                # Raster mask file: read band 1 and binarise.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")
                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None
                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
                return (mask_array > 0).astype(np.uint8)
        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def get_depth_D(self):
        """Normalised depth of the oxygen absorption feature, D.

        The amount of glint is assumed proportional to the depth of the
        oxygen absorption feature: D = mean of the two shoulder bands
        minus the absorption band, normalised by the maximum D found in
        the deep-water search region (bbox and/or water mask).
        :return: np.ndarray of D / D_max, or zeros when D_max == 0
        """
        # Compute D with minimal intermediate arrays.
        lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
        upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
        oxy_band = self.im_aligned[:,:,self.oxy_band]
        D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
        # Region in which D_max is searched for.
        if self.bbox is None:
            search_region = D
        else:
            ((x1,y1),(x2,y2)) = self.bbox
            search_region = D[y1:y2,x1:x2]
        # With a water mask, search the maximum inside the mask only.
        if self.water_mask is not None:
            if self.bbox is None:
                mask_region = self.water_mask.astype(bool)
            else:
                ((x1,y1),(x2,y2)) = self.bbox
                mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
            if mask_region.any():
                D_max = search_region[mask_region].max()
            else:
                D_max = search_region.max()
        else:
            D_max = search_region.max()  # assumed to be the maximum glint value
        # Guard against division by zero.
        if D_max == 0:
            return np.zeros_like(D)
        return D / D_max

    def get_glint_G(self):
        """Spectral variation of glint, G, per band.

        G is found by subtracting the spectrum at the darkest (i.e.
        lowest D) NIR deep-water pixel from the brightest; here it is
        approximated as (max - min) of each band over the search region.
        :return: list of per-band G values (length n_bands)
        """
        # If bbox is None, use the entire image.
        if self.bbox is None:
            im_region = self.im_aligned
            mask_region = self.water_mask
        else:
            ((x1,y1),(x2,y2)) = self.bbox
            im_region = self.im_aligned[y1:y2,x1:x2,:]
            mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
        # With a water mask, compute min/max inside the mask only.
        if mask_region is not None:
            mask_bool = mask_region.astype(bool)
            if mask_bool.any():
                # Per band, min/max restricted to the mask.
                G_list = []
                for i in range(self.n_bands):
                    band_data = im_region[:,:,i]
                    G_max = band_data[mask_bool].max()
                    G_min = band_data[mask_bool].min()
                    G_list.append(G_max - G_min)
            else:
                # Empty mask: fall back to the whole region.
                G_max = np.amax(im_region, axis=(0, 1))
                G_min = np.amin(im_region, axis=(0, 1))
                G_list = (G_max - G_min).tolist()
        else:
            # No mask: compute all bands at once along the spatial axes.
            G_max = np.amax(im_region, axis=(0, 1))  # per-band spatial maximum
            G_min = np.amin(im_region, axis=(0, 1))  # per-band spatial minimum
            G_list = (G_max - G_min).tolist()
        return G_list

    def _save_corrected_bands(self, corrected_bands):
        """Write the corrected bands to an ENVI/BSQ file at ``output_path``.

        The band list is stacked into one array first. A default
        geotransform is used because this class has no source raster to
        copy georeferencing from.
        :param corrected_bands: list of corrected band arrays
        :raises ImportError: if GDAL is unavailable
        :raises ValueError: if the file cannot be created
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装无法保存影像文件")
        if self.output_path is None:
            return
        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        # Stack the band list into an (H, W, B) array.
        corrected_array = np.stack(corrected_bands, axis=2)
        # No geo information available: use defaults.
        geotransform = (0, 1, 0, 0, 0, -1)
        projection = ""
        # Force the .bsq extension (ENVI, band-sequential layout).
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path
        # The ENVI driver defaults to BSQ interleave.
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
        height, width, n_bands = corrected_array.shape
        # Creating the dataset also creates the .hdr header file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")
        try:
            # Attach the (default) georeferencing.
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)
            # Write each band (BSQ stores bands sequentially).
            for i in range(n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_array[:, :, i])
                band.FlushCache()
        finally:
            # Dropping the reference closes and flushes the GDAL dataset.
            dataset = None
        # Report whether the .hdr companion file appeared as expected.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Apply the Kutser correction: R' = R - G * D per band.

        Correction is done in reflectance. Saves the result if
        ``output_path`` was provided.
        :return: list of corrected band arrays
        """
        g_list = self.get_glint_G()
        D = self.get_depth_D()
        # Boolean water mask, if any.
        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
        corrected_bands = []
        for band_number in range(self.n_bands):  # iterate across bands
            G = g_list[band_number]
            R = self.im_aligned[:,:,band_number]
            # Subtract the glint estimate for this band.
            corrected_band = R - G * D
            # Only water pixels are corrected when a mask exists.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)
            corrected_bands.append(corrected_band)
        # Persist the result when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)
        return corrected_bands

View File

@ -0,0 +1,572 @@
import cv2
import os
import numpy as np
from scipy import ndimage
from scipy.optimize import minimize_scalar
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint
class SUGAR:
def __init__(self, im_aligned, bounds=None, sigma=1, estimate_background=True, glint_mask_method="cdf", water_mask=None, output_path=None):
    """
    :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
    :param bounds (a list of tuple or None): lower and upper bound for optimisation of b for
        each band; defaults to [(1, 2)] when None (a None sentinel is used instead of a
        mutable default argument)
    :param sigma (float): smoothing sigma for LoG
    :param estimate_background (bool): whether to estimate background spectra using median filtering
    :param glint_mask_method (str): choose either "cdf" or "otsu", "cdf" is set as the default
    :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
        may be a numpy array or a raster file path (.dat/.tif); if None, the whole image is processed
    :param output_path (str or None): output file path; if provided, the corrected image is saved
    """
    # Avoid the mutable-default-argument pitfall: resolve the sentinel here.
    if bounds is None:
        bounds = [(1, 2)]
    self.im_aligned = im_aligned
    self.sigma = sigma
    self.estimate_background = estimate_background
    self.n_bands = im_aligned.shape[-1]
    # One (lower, upper) bound tuple per band.
    self.bounds = bounds * self.n_bands
    self.glint_mask_method = glint_mask_method
    self.height = im_aligned.shape[0]
    self.width = im_aligned.shape[1]
    self.output_path = output_path
    # Load (and validate) the water mask.
    self.water_mask = self._load_water_mask(water_mask)
def _load_water_mask(self, water_mask):
    """Normalise *water_mask* into a 0/1 uint8 array (1 = water).

    Accepts None (no masking), a numpy array whose first two dimensions
    match the image, or a GDAL-readable raster file path. Shapefile
    paths are rejected, since this class holds no raster reference to
    rasterise them against.

    :param water_mask: None, np.ndarray, or raster file path (.dat/.tif)
    :return: uint8 mask array, or None when no mask was given
    :raises ValueError: on size mismatch, unreadable files, shapefile
        input, or an unsupported mask type
    """
    if water_mask is None:
        return None
    # In-memory array: check its footprint, then binarise.
    if isinstance(water_mask, np.ndarray):
        if water_mask.shape[:2] != (self.height, self.width):
            raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
        return (water_mask > 0).astype(np.uint8)
    # Path on disk: GDAL is mandatory for this branch.
    if isinstance(water_mask, str):
        try:
            from osgeo import gdal, ogr
        except ImportError:
            raise ValueError("使用文件路径作为掩膜时必须安装GDAL")
        if water_mask.lower().endswith('.shp'):
            # No raster reference is available to rasterise a shapefile.
            raise ValueError("SUGAR类输入为numpy数组时无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
        ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
        if ds is None:
            raise ValueError(f"无法打开掩膜文件: {water_mask}")
        mask_array = ds.GetRasterBand(1).ReadAsArray()
        ds = None  # release the GDAL handle promptly
        if mask_array.shape != (self.height, self.width):
            raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
        return (mask_array > 0).astype(np.uint8)
    raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def otsu_thresholding(self,im):
    """Otsu threshold of a LoG image via Brent's bounded minimisation.

    :param im (np.ndarray) of shape mxn. Note that it is the LoG of image
    otsu thresholding with Brent's minimisation of a univariate function
    returns the value of the threshold for input
    """
    # Bin count scales with image size (0.5% of the pixel count).
    auto_bins = int(0.005*im.shape[0]*im.shape[1])
    # ravel() avoids the copy flatten() would make (when possible).
    # Invalid values (NaN / inf) are filtered out before histogramming.
    im_flat = im.ravel()
    valid_mask = np.isfinite(im_flat)
    if not valid_mask.all():
        im_flat = im_flat[valid_mask]
    count, bin_edges = np.histogram(im_flat, bins=auto_bins)
    # NOTE: `bin` shadows the builtin of the same name.
    bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5  # bin centers (multiply instead of divide)
    count_sum = count.sum()
    hist_norm = count / count_sum  # normalised histogram
    Q = hist_norm.cumsum()  # CDF function, ranges from 0 to 1
    N = count.shape[0]
    # Number of bins whose centre is negative (candidate thresholds).
    N_negative = np.sum(bin < 0)
    bins = np.arange(N, dtype=np.float32)  # float32 to reduce memory
    def otsu_thresh(x):
        # Otsu objective: weighted intra-class variance for a split at bin x.
        x = int(x)
        # Slices instead of hsplit: no new array allocation.
        p1 = hist_norm[:x]
        p2 = hist_norm[x:]
        q1 = Q[x]
        q2 = Q[N-1] - Q[x]
        b1 = bins[:x]
        b2 = bins[x:]
        # Class means and variances (guarded against empty classes).
        m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
        m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
        v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
        v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
        # The quantity Brent's method minimises.
        fn = v1 * q1 + v2 * q2
        return fn
    # Brent's method minimises the univariate objective over a bounded
    # range. The search is limited to negative bin centres because the
    # LoG is negative (L < 0) at glint pixels.
    if N_negative <= 1:
        # Not enough negative bins: fall back to the histogram mode.
        return bin[np.argmax(count)]
    res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
    thresh = bin[int(res.x)]
    return thresh
# def cdf_thresholding(self,im, percentile=0.05):
# """
# :param im (np.ndarray) of shape mxn
# :param percentile (float): lower and upper percentile values are potential glint pixels
# """
# lower_perc = percentile
# upper_perc = 1-percentile
# im_flatten = im.flatten()
# H,X1 = np.histogram(im_flatten, bins = int(0.005*im.shape[0]*im.shape[1]), density=True )
# dx = X1[1] - X1[0]
# F1 = np.cumsum(H)*dx
# F_lower = X1[1:][F1<lower_perc]
# F_upper = X1[1:][F1>upper_perc]
# while((F_lower.size == 0) or (F_upper.size == 0)):
# if (F_lower.size == 0):
# lower_perc += 0.01
# F_lower = X1[1:][F1<lower_perc]
# if (F_upper.size == 0):
# upper_perc -= 0.01
# F_upper = X1[1:][F1>upper_perc]
# lower_thresh = F_lower[-1]
# upper_thresh = F_upper[0]
# return lower_thresh,upper_thresh
def cdf_thresholding(self,im,auto_bins=10):
"""
:param im (np.ndarray) of shape mxn. Note that it is the LoG of image
:param percentile (float): lower and upper percentile values are potential glint pixels
"""
# 使用ravel()而不是flatten(),避免不必要的复制
im_flat = im.ravel()
# 过滤掉NaN和无穷大值
valid_mask = np.isfinite(im_flat)
if not valid_mask.all():
im_flat = im_flat[valid_mask]
count, bin_edges = np.histogram(im_flat, bins=auto_bins)
bin = (bin_edges[:-1] + bin_edges[1:]) * 0.5 # bin centers使用乘法替代除法
thresh = bin[np.argmax(count)]
return thresh
def glint_list(self):
    """
    Extract the glint contribution of every band.

    Each band is multiplied by its binary glint mask (from glint_mask_list),
    so non-glint pixels become 0 and glint pixels keep their value.
    :return (list of np.ndarray): one extracted-glint image per band
    """
    masks = self.glint_mask_list()
    n_bands = self.im_aligned.shape[-1]
    return [masks[band] * self.im_aligned[:, :, band] for band in range(n_bands)]
def glint_mask_list(self):
    """
    Build the per-band glint masks from the LoG of each band.
    :return (list of np.ndarray): one binary glint mask per band
    """
    n_bands = self.im_aligned.shape[-1]
    return [self.get_glint_mask(self.im_aligned[:, :, band]) for band in range(n_bands)]
def log_image_list(self):
    """
    Compute the Laplacian-of-Gaussian (LoG) image of every band.
    :return (list of np.ndarray): one LoG image per band
    """
    n_bands = self.im_aligned.shape[-1]
    return [self.get_log_image(self.im_aligned[:, :, band]) for band in range(n_bands)]
def get_log_image(self, im):
"""
get Laplacian of Gaussian (LoG) image for a single band.
returns a np.ndarray
"""
LoG_im = ndimage.gaussian_laplace(im, sigma=self.sigma)
return LoG_im
def get_glint_mask(self, im):
    """
    Build a binary glint mask for one band from its LoG image.

    Water constituents vary smoothly while glint pixels vary sharply in
    space and intensity, so pixels whose LoG response falls below a
    data-driven threshold are flagged as glint.
    Note that for very extensive glint this assumption may break down.

    :param im (np.ndarray): single-band image
    :return (np.ndarray of np.uint8): 1 = glint, 0 = non-glint; pixels
        outside the water mask (if any) are always 0
    :raises ValueError: if glint_mask_method is neither "otsu" nor "cdf"
    """
    log_im = ndimage.gaussian_laplace(im, sigma=self.sigma)
    # Restrict threshold estimation to water pixels when a mask exists:
    # non-water pixels are pushed above the water maximum so they cannot
    # influence the threshold choice.
    thresh_input = log_im
    if self.water_mask is not None:
        water = self.water_mask.astype(bool)
        if water.any():
            thresh_input = log_im.copy()
            thresh_input[~water] = log_im[water].max() + 1
    if self.glint_mask_method == "otsu":
        thresh = self.otsu_thresholding(thresh_input)
    elif self.glint_mask_method == "cdf":
        thresh = self.cdf_thresholding(thresh_input)
    else:
        raise ValueError('Enter only cdf or otsu as glint_mask_method')
    # Glint pixels have strongly negative LoG responses.
    mask = (log_im < thresh).astype(np.uint8)
    # Zero out everything outside the water mask.
    if self.water_mask is not None:
        mask = mask * self.water_mask
    return mask
def get_est_background(self, im, k_size=5):
    """
    Estimate the glint-free background of a band by grey-scale erosion.

    Erosion with an elliptical kernel replaces each pixel with its local
    minimum, suppressing bright glint speckles.
    :param im (np.ndarray): single-band image
    :param k_size (int): diameter of the elliptical structuring element
    :return (np.ndarray): estimated background, same shape as im
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    return cv2.erode(im, ellipse)
def optimise_correction_by_band(self, im, glint_mask, R_BG, bounds):
    """
    Find the per-band correction slope b that minimises image variance.

    A glint pixel is corrected as im - glint_mask * (im / b - R_BG); a good
    b removes the glint-induced variance while leaving the background
    intact. Brent's bounded method performs the 1-D search.

    :param im (np.ndarray): image of a band
    :param glint_mask (np.ndarray): glint mask, 1 = glint, 0 = non-glint
    :param R_BG (float or np.ndarray): estimated background reflectance,
        either a scalar or a per-pixel array
    :param bounds (tuple): (lower, upper) search bounds for b
    :return (float): optimal slope b
    """
    glint_bool = glint_mask.astype(bool)
    # Pre-extract everything that does not depend on b, so the objective
    # only performs the b-dependent arithmetic. (The original kept two
    # near-identical branches and an unused R_BG_flat local; both removed.)
    im_glint = im[glint_bool]
    weights = glint_mask[glint_bool]
    background = R_BG if isinstance(R_BG, (int, float)) else R_BG[glint_bool]

    def corrected_variance(b):
        # Only glint pixels change; the rest of the image is untouched.
        im_corrected = im.copy()
        im_corrected[glint_bool] = im_glint - weights * (im_glint / b - background)
        return np.var(im_corrected)

    res = minimize_scalar(corrected_variance, bounds=bounds, method='bounded')
    return res.x
def divide_and_conquer(self):
    """
    Placeholder — not yet implemented (currently returns None).

    Idea: instead of computing b_list from scratch for each window, use the
    previous window's b_list to narrow the search bounds. Because of the
    strong spatial autocorrelation, the correction magnitude b cannot
    differ much between neighbouring windows, which would cut the
    optimisation run time.
    """
def optimise_correction(self):
    """
    Optimise the correction slope for every band.

    For each band: build a glint mask, estimate the background (per-pixel
    morphological erosion, or a scalar 5th percentile when
    estimate_background is disabled), then search for the slope b that
    minimises the corrected-image variance.
    Side effects: stores b_list, glint_mask and est_background on self.

    :return (tuple): (b_list, glint_mask_list, est_background_list), one
        entry per band in band order 0..n_bands-1
    """
    b_list = []
    glint_mask_list = []
    est_background_list = []
    for i in range(self.n_bands):
        glint_mask = self.get_glint_mask(self.im_aligned[:,:,i])
        glint_mask_list.append(glint_mask)
        if self.estimate_background is True:
            # Per-pixel background from grey-scale erosion.
            est_background = self.get_est_background(self.im_aligned[:,:,i])
            est_background_list.append(est_background)
        else:
            # Scalar background: 5th percentile of the band.
            # NOTE(review): the `interpolation` kwarg of np.percentile is
            # deprecated since NumPy 1.22 (renamed to `method`) and removed
            # in NumPy 2.0 — confirm the pinned NumPy version.
            est_background = np.percentile(self.im_aligned[:,:,i], 5, interpolation='nearest')
            est_background_list.append(est_background)
        bounds = self.bounds[i]
        b = self.optimise_correction_by_band(self.im_aligned[:,:,i],glint_mask,est_background,bounds)
        b_list.append(b)
    # add attributes
    self.b_list = b_list
    self.glint_mask = glint_mask_list
    self.est_background = est_background_list
    return b_list, glint_mask_list, est_background_list
def _save_corrected_bands(self, corrected_bands):
    """
    Save the corrected bands to an ENVI/BSQ file at self.output_path.

    The body was a line-for-line duplicate of the module-level
    _save_corrected_image helper; it now stacks the per-band list into a
    (height, width, bands) cube and delegates to that helper, which handles
    path normalisation, the ENVI driver, band writing and reporting.
    No-op when self.output_path is None.

    :param corrected_bands: list of corrected 2-D band arrays
    :raises ImportError: if GDAL is not installed
    :raises ValueError: if the ENVI driver or output file cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if self.output_path is None:
        return
    # Reuse the shared writer so both save paths stay consistent.
    _save_corrected_image(np.stack(corrected_bands, axis=2), self.output_path)
def get_corrected_bands(self):
    """
    Correct every band by replacing glint pixels with the estimated background.

    For each band, glint pixels (from get_glint_mask) are substituted with
    the eroded-background estimate; all other pixels keep their original
    value. When a water mask is present, correction is confined to water
    pixels (get_glint_mask already zeroes glint outside water, so the extra
    intersection is a safety net). If self.output_path is set the result is
    also written to disk.

    Fix: the original water-mask branch computed the correction twice —
    first with copy + fancy assignment, then discarded that work and
    rebuilt the result with two np.where passes. A single masked
    assignment yields the same pixels with one pass and preserves dtype.

    :return: list of corrected 2-D band arrays, one per band
    """
    corrected_bands = []
    water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
    for i in range(self.n_bands):
        im_band = self.im_aligned[:, :, i]
        # One mask + one background estimate per band.
        glint_mask = self.get_glint_mask(im_band)
        background = self.get_est_background(im_band, k_size=5)
        replace = glint_mask.astype(bool)
        if water_mask_bool is not None:
            # Never touch pixels outside the water mask.
            replace &= water_mask_bool
        im_corrected = im_band.copy()
        im_corrected[replace] = background[replace]
        corrected_bands.append(im_corrected)
    # Persist the result when an output path was provided.
    if self.output_path is not None:
        self._save_corrected_bands(corrected_bands)
    return corrected_bands
def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None):
    """
    Conduct iterative glint correction using SUGAR.

    :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
    :param iter (int or None): number of iterations to run the sugar algorithm.
        If None, variance-based termination conditions are applied instead
        (capped at 100 iterations).
    :param bounds (list of tuples): to limit correction magnitude.
        NOTE(review): mutable default argument — safe only if SUGAR never
        mutates it; confirm.
    :param estimate_background (bool): passed through to SUGAR
    :param glint_mask_method (str): "cdf" or "otsu", passed through to SUGAR
    :param get_glint_mask (bool): when True, intermediate iterations keep an
        independent copy of the image (the glint-mask saving itself is
        currently commented out below)
    :param termination_thresh (float): percentage threshold used for both
        termination conditions when iter is None
    :param water_mask (np.ndarray or str or None): water mask, 1 = water,
        0 = non-water; passed through to SUGAR
    :param output_path (str or None): if given, the last iteration's result
        is written to disk via _save_corrected_image
    :return (list of np.ndarray): corrected image after each iteration
    """
    glint_image = im_aligned.copy()
    corrected_images = []
    if iter is None:
        # Termination conditions: stop once the variance has collapsed
        # relative to the original image, or when an iteration no longer
        # reduces the variance by a meaningful margin.
        relative_difference = lambda sd0,sd1: sd1/sd0*100
        marginal_difference = lambda sd1,sd2: (sd1-sd2)/sd1*100
        relative_diff_thresh = marginal_difference_thresh = termination_thresh
        sd_og = np.var(im_aligned)
        iter_count = 0
        sd_next = sd_og  # plain value, no copy needed
        max_iter = 100  # hard cap to avoid an endless loop
        while ((relative_difference(sd_og,sd_next) > relative_diff_thresh) and iter_count < max_iter):
            # One full SUGAR pass over all bands.
            HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
            corrected_bands = HM.get_corrected_bands()
            glint_image = np.stack(corrected_bands,axis=2)
            sd_temp = np.var(glint_image)
            # Keep an independent copy only when required, to limit memory.
            if get_glint_mask or iter_count == 0:
                corrected_images.append(glint_image.copy())
            else:
                corrected_images.append(glint_image)  # result of the latest iteration
            # save glint_mask
            # if iter_count == 0 and get_glint_mask is True:
            #     glint_mask = np.stack(HM.glint_mask,axis=2)
            if (marginal_difference(sd_next,sd_temp)<marginal_difference_thresh):
                break
            else:
                sd_next = sd_temp
            #increase count
            iter_count += 1
        # Persist the final iteration when an output path was provided.
        if output_path is not None and len(corrected_images) > 0:
            _save_corrected_image(corrected_images[-1], output_path)
    else:
        for i in range(iter):
            HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
            corrected_bands = HM.get_corrected_bands()
            glint_image = np.stack(corrected_bands,axis=2)
            # Copy only the final (or explicitly requested) results.
            if i == iter - 1 or get_glint_mask:
                corrected_images.append(glint_image.copy())
            else:
                # Intermediate iterations keep a reference only (mind memory).
                corrected_images.append(glint_image)
            # save glint_mask
            # if i == 0 and get_glint_mask is True:
            #     glint_mask = np.stack(HM.glint_mask,axis=2)
        # Persist the final iteration when an output path was provided.
        if output_path is not None and len(corrected_images) > 0:
            _save_corrected_image(corrected_images[-1], output_path)
    return corrected_images
def _save_corrected_image(corrected_image, output_path):
    """
    Write a corrected image cube to disk as an ENVI/BSQ raster.

    Used by correction_iterative. The output extension is forced to .bsq;
    the ENVI driver writes an accompanying .hdr header automatically. A
    placeholder geotransform is used because no georeference is carried
    through the correction pipeline.

    :param corrected_image: corrected image array of shape (height, width, bands)
    :param output_path: destination file path (extension normalised to .bsq)
    :raises ImportError: if GDAL is not installed
    :raises ValueError: if the ENVI driver or output file cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if output_path is None:
        return
    # Make sure the destination directory exists.
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    # Placeholder georeference (no projection information available).
    geotransform = (0, 1, 0, 0, 0, -1)
    projection = ""
    # Force the .bsq extension for the ENVI/BSQ output.
    base_path, ext = os.path.splitext(output_path)
    if ext.lower() != '.bsq':
        bsq_path = base_path + '.bsq'
    else:
        bsq_path = output_path
    # The ENVI driver defaults to band-sequential (BSQ) layout.
    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        raise ValueError("无法创建ENVI格式文件ENVI驱动不可用")
    height, width, n_bands = corrected_image.shape
    # Creating the dataset also creates the .hdr header file.
    dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {bsq_path}")
    try:
        if geotransform:
            dataset.SetGeoTransform(geotransform)
        if projection:
            dataset.SetProjection(projection)
        # Write the bands one by one (band-sequential order).
        for i in range(n_bands):
            band = dataset.GetRasterBand(i + 1)
            band.WriteArray(corrected_image[:, :, i])
            band.FlushCache()
    finally:
        dataset = None  # closing the dataset flushes it to disk
    # Report whether the header file was produced as expected.
    hdr_path = bsq_path + '.hdr'
    if os.path.exists(hdr_path):
        print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
        print(f"头文件已保存至: {hdr_path}")
    else:
        print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
        print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")

View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@ -0,0 +1,926 @@
from osgeo import gdal, osr
import numpy as np
import pandas as pd
import os
import spectral
from math import sin, cos, tan, sqrt, radians
try:
from scipy.ndimage import distance_transform_edt
from scipy.spatial import cKDTree
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
# 启用GDAL异常处理
osr.UseExceptions()
# WGS84椭球参数
WGS84_A = 6378137.0 # 长半轴(米)
WGS84_F = 1 / 298.257223563 # 扁率
WGS84_E2 = WGS84_F * (2 - WGS84_F) # 第一偏心率平方
WGS84_EP2 = WGS84_E2 / (1 - WGS84_E2) # 第二偏心率平方
UTM_K0 = 0.9996 # UTM比例因子
def pixel_to_geo(pixel_x, pixel_y, geotransform):
    """
    Convert pixel coordinates to map coordinates via the GDAL affine transform.

    :param pixel_x: column index
    :param pixel_y: row index
    :param geotransform: 6-element GDAL geotransform
        (origin_x, px_width, row_rot, origin_y, col_rot, px_height)
    :return: (geo_x, geo_y) map coordinates of the pixel
    """
    origin_x, px_width, row_rot, origin_y, col_rot, px_height = geotransform
    x = origin_x + pixel_x * px_width + pixel_y * row_rot
    y = origin_y + pixel_x * col_rot + pixel_y * px_height
    return x, y
def prepare_boundary_adjuster(boundary_mask):
    """
    Build the helper structure used to shift sampling centres inside water.

    :param boundary_mask: water mask array (>0 means water) or None
    :return: dict with keys 'mask' (bool array), 'distance_map' (Euclidean
        distance to the nearest non-water pixel) and 'trees' (per-radius
        KD-tree cache), or None when SciPy is missing or the mask is
        absent/empty
    """
    if not SCIPY_AVAILABLE:
        print("警告: 未安装SciPy无法根据水体边界自动调整采样点位置。")
        return None
    if boundary_mask is None:
        return None
    water = boundary_mask > 0
    if not np.any(water):
        print("警告: 边界掩膜中未检测到有效水域,无法调整采样点。")
        return None
    # Distance (in pixels) from each water pixel to the nearest shore pixel.
    dist = distance_transform_edt(water.astype(np.uint8))
    return {'mask': water, 'distance_map': dist, 'trees': {}}
def _get_boundary_tree(adjuster, radius):
"""
根据半径获取或构建适用的KDTree
"""
radius_key = float(radius)
if radius_key in adjuster['trees']:
return adjuster['trees'][radius_key]
distance_map = adjuster['distance_map']
valid_positions = np.column_stack(np.where(distance_map >= radius_key))
if valid_positions.size == 0:
adjuster['trees'][radius_key] = None
return None
tree = cKDTree(valid_positions)
adjuster['trees'][radius_key] = (tree, valid_positions)
return adjuster['trees'][radius_key]
def adjust_sampling_center(pixel_x, pixel_y, radius, adjuster):
    """
    Move a sampling centre so the whole sampling disc lies inside the water.

    If the disc of `radius` around (pixel_x, pixel_y) crosses the water
    boundary, the centre is moved to the nearest pixel whose shore distance
    is >= radius, so the disc becomes tangent to (or inside) the boundary.

    :param pixel_x: column of the requested centre
    :param pixel_y: row of the requested centre
    :param radius: sampling radius in pixels
    :param adjuster: structure from prepare_boundary_adjuster (or None)
    :return: (new_x, new_y, moved) — the original coordinates with
        moved=False whenever no adjustment is possible or necessary
    """
    if adjuster is None or radius <= 0:
        return pixel_x, pixel_y, False
    distance_map = adjuster['distance_map']
    mask = adjuster['mask']
    # Centre outside the raster: nothing sensible to do.
    if pixel_y < 0 or pixel_y >= distance_map.shape[0] or pixel_x < 0 or pixel_x >= distance_map.shape[1]:
        return pixel_x, pixel_y, False
    if not mask[pixel_y, pixel_x]:
        # Centre is not on water: relocate to the nearest feasible pixel.
        tree_info = _get_boundary_tree(adjuster, max(radius, 1))
        if tree_info is None:
            return pixel_x, pixel_y, False
    else:
        # Centre already far enough from shore: keep it.
        if distance_map[pixel_y, pixel_x] >= radius:
            return pixel_x, pixel_y, False
        tree_info = _get_boundary_tree(adjuster, radius)
        if tree_info is None:
            # No pixel at all can host a disc of this radius.
            return pixel_x, pixel_y, False
    tree, valid_positions = tree_info
    if tree is None or valid_positions.size == 0:
        return pixel_x, pixel_y, False
    # Query a handful of nearby feasible positions (tree works in (row, col)).
    max_candidates = min(64, len(valid_positions))
    distances, indices = tree.query([pixel_y, pixel_x], k=max_candidates)
    if np.isscalar(indices):
        indices = [int(indices)]
    else:
        indices = np.atleast_1d(indices).astype(int)
    best_candidate = None
    best_delta = None
    for idx in indices:
        cy, cx = valid_positions[idx]
        if distance_map[cy, cx] < radius:
            continue
        # Prefer discs that just fit (small delta), then minimal centre shift.
        delta = distance_map[cy, cx] - radius
        center_shift = (cx - pixel_x) ** 2 + (cy - pixel_y) ** 2
        score = (abs(delta), center_shift)
        if best_candidate is None or score < best_delta:
            best_candidate = (cx, cy)
            best_delta = score
    if best_candidate is None:
        # None of the queried candidates could host the disc.
        return pixel_x, pixel_y, False
    return int(best_candidate[0]), int(best_candidate[1]), True
def transform_coordinates(lon, lat, source_srs, target_srs):
    """
    Reproject a single point between two OSR spatial reference systems.

    :param lon: input x/longitude
    :param lat: input y/latitude
    :param source_srs: osr.SpatialReference of the input
    :param target_srs: osr.SpatialReference of the output
    :return: (x, y) of the reprojected point
    """
    # TransformPoint returns (x, y, z); the z component is discarded.
    converter = osr.CoordinateTransformation(source_srs, target_srs)
    result = converter.TransformPoint(lon, lat)
    return result[0], result[1]
def geo_to_pixel(lon, lat, geotransform, dataset_srs=None):
    """
    Convert map coordinates to integer pixel coordinates.

    Inverts the translation/scale part of the affine geotransform; the
    rotation terms geotransform[2] and geotransform[4] are assumed to be 0.
    NOTE(review): int() truncates toward zero rather than flooring, so
    points left/above the origin map unexpectedly — callers keep points
    inside the raster, where the two agree.

    :param lon: map x coordinate
    :param lat: map y coordinate
    :param geotransform: 6-element GDAL geotransform
    :param dataset_srs: unused; kept for interface compatibility
    :return: (pixel_x, pixel_y) integer pixel coordinates
    """
    col = int((lon - geotransform[0]) / geotransform[1])
    row = int((lat - geotransform[3]) / geotransform[5])
    return col, row
def get_pixel_spectrum_batch(dataset, pixel_x_array, pixel_y_array):
    """
    Batch-extract the spectra of many pixels from a GDAL dataset.

    Each band is read once; all requested pixels are then gathered with a
    single vectorised fancy-index instead of a Python loop per point
    (the per-point loop dominated runtime for large point sets).
    Out-of-bounds points are filled with NaN, as before.

    Args:
        dataset: GDAL dataset (needs RasterCount / GetRasterBand)
        pixel_x_array: array-like of pixel X (column) coordinates
        pixel_y_array: array-like of pixel Y (row) coordinates
    Returns:
        spectrum_array: np.ndarray of shape (n_points, n_bands), float32
    """
    px = np.asarray(pixel_x_array, dtype=np.int64)
    py = np.asarray(pixel_y_array, dtype=np.int64)
    n_points = len(px)
    n_bands = dataset.RasterCount
    spectrum_array = np.zeros((n_points, n_bands), dtype=np.float32)
    for band_idx in range(n_bands):
        band = dataset.GetRasterBand(band_idx + 1)  # GDAL bands are 1-based
        band_data = band.ReadAsArray()
        # Vectorised bounds check + gather replaces the per-point loop.
        inside = (px >= 0) & (px < band_data.shape[1]) & (py >= 0) & (py < band_data.shape[0])
        spectrum_array[:, band_idx] = np.nan
        spectrum_array[inside, band_idx] = band_data[py[inside], px[inside]]
    return spectrum_array
def get_average_spectral_in_radius(dataset, center_x, center_y, radius, flare_mask=None, boundary_mask=None):
    """
    Average each band over a circular window, avoiding glint and boundary.

    Samples every band inside a disc of `radius` pixels around
    (center_x, center_y). Pixels flagged by the flare mask (non-zero) or
    outside the boundary mask (!= 1) are excluded, as are zero-valued and
    non-finite samples. Returns a zero vector when nothing can be sampled.

    Args:
        dataset: GDAL dataset
        center_x, center_y: centre pixel coordinates
        radius: radius in pixels
        flare_mask: optional glint mask array (0 = no glint)
        boundary_mask: optional boundary mask array (1 = inside water)
    Returns:
        np.ndarray of per-band means, length dataset.RasterCount
    """
    n_bands = dataset.RasterCount
    # Clip the square window that bounds the sampling circle to the raster.
    x0 = max(0, center_x - radius)
    x1 = min(dataset.RasterXSize, center_x + radius + 1)
    y0 = max(0, center_y - radius)
    y1 = min(dataset.RasterYSize, center_y + radius + 1)
    win_w = x1 - x0
    win_h = y1 - y0
    if win_w <= 0 or win_h <= 0:
        return np.zeros(n_bands)
    cube = dataset.ReadAsArray(x0, y0, win_w, win_h)
    if cube is None:
        return np.zeros(n_bands)
    # A single band comes back 2-D; normalise to (bands, height, width).
    if len(cube.shape) == 2:
        cube = cube.reshape(1, cube.shape[0], cube.shape[1])
    # Circular footprint around the window-local centre.
    rows, cols = np.ogrid[:win_h, :win_w]
    cx_local = center_x - x0
    cy_local = center_y - y0
    footprint = ((cols - cx_local) ** 2 + (rows - cy_local) ** 2) <= radius ** 2
    if flare_mask is not None:
        flare_win = flare_mask[y0:y1, x0:x1]
        if flare_win.shape == footprint.shape:
            footprint = footprint & (flare_win == 0)  # 0 means no glint
    if boundary_mask is not None:
        boundary_win = boundary_mask[y0:y1, x0:x1]
        if boundary_win.shape == footprint.shape:
            footprint = footprint & (boundary_win == 1)  # 1 means inside water
    mean_spectrum = np.zeros(n_bands)
    if np.sum(footprint) > 0:
        for b in range(n_bands):
            layer = cube[b, :, :]
            # Drop zero (nodata) and non-finite samples before averaging.
            samples = layer[footprint & (layer != 0) & np.isfinite(layer)]
            if len(samples) > 0:
                mean_spectrum[b] = np.mean(samples)
    return mean_spectrum
def load_mask_file(mask_path):
    """
    Load a raster mask file as a 2-D array.

    Opens the file explicitly as a raster (gdal.OF_RASTER) so that vector
    sources fail cleanly, then reads band 1. Any failure is reported with a
    warning and converted to a None return.

    Args:
        mask_path: path to a raster mask (.dat/.tif etc.), or None
    Returns:
        np.ndarray mask, or None when unavailable/invalid
    """
    if mask_path is None or not os.path.exists(mask_path):
        return None
    try:
        # Prefer OpenEx with OF_RASTER: a vector dataset yields None here
        # instead of the confusing "multi-layer" error.
        ds = gdal.OpenEx(mask_path, gdal.OF_RASTER)
        if ds is None:
            # Fall back to plain Open for backwards compatibility.
            ds = gdal.Open(mask_path, gdal.GA_ReadOnly)
        if ds is None:
            print(f"警告: 无法打开掩膜文件 {mask_path},可能不是有效的栅格文件")
            return None
        # Reject anything that is not a raster dataset with bands.
        if not hasattr(ds, 'RasterCount') or ds.RasterCount == 0:
            print(f"警告: {mask_path} 不是有效的栅格文件")
            del ds
            return None
        data = ds.GetRasterBand(1).ReadAsArray()
        del ds
        return data
    except Exception as e:
        print(f"警告: 加载掩膜文件 {mask_path} 时出错: {str(e)}")
        return None
def get_hdr_file_path(file_path):
    """
    Derive the ENVI header (.hdr) path for an image file.

    Args:
        file_path: image file path
    Returns:
        the same base name with a .hdr extension
    """
    base, _ = os.path.splitext(file_path)
    return base + ".hdr"
def calculate_utm_zone(longitude):
    """
    Compute the UTM zone number from a longitude.

    UTM zones are 6 degrees wide starting at 180°W; the result is clamped
    to the valid range 1..60.

    Args:
        longitude: longitude in degrees
    Returns:
        utm_zone: UTM zone number (1-60)
    """
    zone = int((longitude + 180) / 6) + 1
    return min(60, max(1, zone))
def latlon_to_utm_math(lat_deg, lon_deg, zone=None):
    """
    Convert a WGS84 lat/lon pair to UTM with the closed-form series.

    Transverse-Mercator series expansion (Snyder's formulas) on the WGS84
    ellipsoid. Southern-hemisphere points receive the 10,000,000 m false
    northing.

    Args:
        lat_deg: latitude in degrees
        lon_deg: longitude in degrees
        zone: UTM zone number; derived from the longitude when None
    Returns:
        easting, northing: UTM coordinates in metres
    """
    if zone is None:
        zone = calculate_utm_zone(lon_deg)
    # Central meridian of the zone, in radians.
    lam0 = radians(zone * 6 - 183)
    phi = radians(lat_deg)
    lam = radians(lon_deg)
    sinphi = sin(phi)
    cosphi = cos(phi)
    tanphi = tan(phi)
    # Prime-vertical radius of curvature and auxiliary series terms.
    N = WGS84_A / sqrt(1 - WGS84_E2 * sinphi * sinphi)
    T = tanphi * tanphi
    C = WGS84_EP2 * cosphi * cosphi
    A = cosphi * (lam - lam0)
    # Meridian arc length from the equator (Snyder's formula).
    M = WGS84_A * ((1 - WGS84_E2 / 4 - 3 * WGS84_E2 ** 2 / 64 - 5 * WGS84_E2 ** 3 / 256) * phi
                   - (3 * WGS84_E2 / 8 + 3 * WGS84_E2 ** 2 / 32 + 45 * WGS84_E2 ** 3 / 1024) * sin(2 * phi)
                   + (15 * WGS84_E2 ** 2 / 256 + 45 * WGS84_E2 ** 3 / 1024) * sin(4 * phi)
                   - (35 * WGS84_E2 ** 3 / 3072) * sin(6 * phi))
    easting = (UTM_K0 * N * (A + (1 - T + C) * A ** 3 / 6
                             + (5 - 18 * T + T * T + 72 * C - 58 * WGS84_EP2) * A ** 5 / 120)
               + 500000.0)
    northing = UTM_K0 * (M + N * tanphi * (A ** 2 / 2
                                           + (5 - T + 9 * C + 4 * C * C) * A ** 4 / 24
                                           + (61 - 58 * T + T * T + 600 * C - 330 * WGS84_EP2) * A ** 6 / 720))
    if lat_deg < 0:
        # False northing for the southern hemisphere.
        northing += 10000000.0
    return easting, northing
def convert_to_utm(lon, lat, source_epsg=4326, target_epsg=None):
    """
    Convert lon/lat arrays to UTM coordinates with the closed-form formula.

    The UTM zone is taken from target_epsg (EPSG:326xx / 327xx) when given,
    otherwise computed per point from the longitude. Invalid inputs and
    failed conversions come back as NaN.

    Args:
        lon: longitude array (np.ndarray)
        lat: latitude array (np.ndarray)
        source_epsg: EPSG code of the input, default 4326 (WGS84)
        target_epsg: target EPSG code; None = auto zone per point
    Returns:
        utm_x, utm_y: float64 arrays of UTM coordinates, NaN where invalid
    """
    try:
        # The math-formula path is only exact for WGS84 input.
        if source_epsg != 4326:
            print(f"警告: 数学公式转换仅支持WGS84 (EPSG:4326)当前源坐标系为EPSG:{source_epsg}")
            print("将尝试使用数学公式进行转换,但可能不准确")
        # BUGFIX: allocate float outputs explicitly. np.zeros_like(lon)
        # inherits the input dtype, so integer lon/lat arrays made the
        # later NaN assignments fail with a casting error.
        utm_x = np.zeros(np.shape(lon), dtype=np.float64)
        utm_y = np.zeros(np.shape(lat), dtype=np.float64)
        # Extract the zone number from a UTM EPSG code when provided.
        fixed_zone = None
        if target_epsg is not None:
            # EPSG:326xx = northern zones, EPSG:327xx = southern zones.
            if 32601 <= target_epsg <= 32660:
                fixed_zone = target_epsg - 32600
            elif 32701 <= target_epsg <= 32760:
                fixed_zone = target_epsg - 32700
            else:
                print(f"警告: 无法从EPSG代码 {target_epsg} 提取UTM分区号将根据经度自动计算")
        # Vectorised validity screen.
        invalid_mask = (np.isnan(lon) | np.isnan(lat) |
                        (lon < -180) | (lon > 180) |
                        (lat < -90) | (lat > 90))
        invalid_count = np.sum(invalid_mask)
        if invalid_count > 0:
            invalid_indices = np.where(invalid_mask)[0]
            print(f"警告: 发现 {invalid_count} 个无效坐标点,将跳过")
            for idx in invalid_indices[:10]:  # report at most the first 10
                print(f" 坐标点 {idx + 1}: 经度={lon[idx]}, 纬度={lat[idx]}")
            if invalid_count > 10:
                print(f" ... 还有 {invalid_count - 10} 个无效坐标点")
        # Convert the valid points one by one; a failure only invalidates
        # that single point.
        valid_mask = ~invalid_mask
        if np.any(valid_mask):
            valid_lon = lon[valid_mask]
            valid_lat = lat[valid_mask]
            valid_indices = np.where(valid_mask)[0]
            if fixed_zone is not None:
                zones = np.full(len(valid_lon), fixed_zone)
            else:
                zones = np.array([calculate_utm_zone(lon_val) for lon_val in valid_lon])
            for i, (lat_val, lon_val, zone) in enumerate(zip(valid_lat, valid_lon, zones)):
                try:
                    E, Nn = latlon_to_utm_math(lat_val, lon_val, zone)
                    if np.isfinite(E) and np.isfinite(Nn):
                        utm_x[valid_indices[i]] = E
                        utm_y[valid_indices[i]] = Nn
                    else:
                        utm_x[valid_indices[i]] = np.nan
                        utm_y[valid_indices[i]] = np.nan
                except Exception:
                    utm_x[valid_indices[i]] = np.nan
                    utm_y[valid_indices[i]] = np.nan
        # Mark the invalid inputs as NaN in the outputs.
        utm_x[invalid_mask] = np.nan
        utm_y[invalid_mask] = np.nan
        return utm_x, utm_y
    except Exception as e:
        print(f"坐标转换初始化失败: {str(e)}")
        # Float NaN arrays regardless of the input dtype (full_like would
        # fail for integer inputs, same as the zeros_like issue above).
        return np.full(np.shape(lon), np.nan), np.full(np.shape(lat), np.nan)
def convert_to_utm51n(lon, lat, source_epsg=4326):
    """
    Convert coordinates to WGS84 UTM 51N (kept for backward compatibility).

    Thin wrapper around convert_to_utm that pins the target zone to
    UTM 51N (EPSG:32651) instead of deriving it from the longitude.

    Args:
        lon: longitude array
        lat: latitude array
        source_epsg: EPSG code of the input, default 4326 (WGS84)
    Returns:
        utm_x, utm_y: converted UTM coordinates
    """
    # Delegate with a fixed northern-hemisphere zone 51 EPSG code.
    return convert_to_utm(lon, lat, source_epsg, target_epsg=32651)
def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, boundary_path=None, source_epsg=4326):
    """
    Extract spectra at given coordinates and convert them to UTM.

    Reads a CSV of (lat, lon, ...) rows, converts the coordinates to UTM
    (zone derived from the longitude), maps them to pixel positions in the
    image, optionally shifts sampling centres inside the water body, then
    samples each point (single pixel, or a masked circular average when
    radius > 0) and writes the result table to a CSV.

    Args:
        imgpath: image file path (BIL format)
        coorpath: coordinate CSV path; columns 1 and 2 are lat and lon
        outpath: output CSV path
        radius: sampling radius in pixels (0 = single pixel)
        flare_path: optional glint-mask raster path
        boundary_path: optional boundary-mask raster path
        source_epsg: EPSG code of the input coordinates, default 4326
    Returns:
        np.ndarray: original columns + [map_x, map_y] + per-band spectra
    """
    # ---- 1. read the coordinate CSV, trying several encodings ----
    coor_df = None
    coor_data = None
    encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'cp1252']
    for encoding in encodings:
        try:
            coor_df = pd.read_csv(coorpath, encoding=encoding)
            # Keep only numeric columns (drops textual headers/IDs).
            coor_data = coor_df.select_dtypes(include=[np.number]).values
            if coor_data.shape[1] == 0:
                # No numeric dtype detected: coerce every column to numbers,
                # treating the first row as the header.
                coor_df = pd.read_csv(coorpath, encoding=encoding, header=0)
                numeric_df = coor_df.apply(pd.to_numeric, errors='coerce')
                # Drop all-NaN rows (typically a failed header row).
                numeric_df = numeric_df.dropna(how='all')
                coor_data = numeric_df.values
            print(f"成功使用 {encoding} 编码读取文件")
            break
        except Exception as e:
            print(f"使用 {encoding} 编码读取失败: {str(e)}")
            continue
    # Last resort: plain numpy parsing, skipping the header row.
    if coor_data is None:
        try:
            print("尝试使用numpy读取数值数据...")
            coor_data = np.loadtxt(coorpath, delimiter=",", skiprows=1)
        except:
            try:
                coor_data = np.loadtxt(coorpath, delimiter="\t", skiprows=1)
            except Exception as e:
                raise Exception(f"无法读取坐标文件,请检查文件格式: {str(e)}")
        # loadtxt returns 1-D for a single row; normalise to 2-D.
        if len(coor_data.shape) == 1:
            coor_data = coor_data.reshape(1, -1)
    # ---- 2. sanity checks and preview ----
    if coor_data is None or coor_data.shape[1] < 2:
        raise Exception("坐标文件格式错误需要至少2列数据第1列为纬度第2列为经度")
    print(f"成功读取坐标文件,共 {coor_data.shape[0]} 行,{coor_data.shape[1]} 列")
    print(f"数据预览前3行")
    for i in range(min(3, coor_data.shape[0])):
        print(f"行 {i + 1}: {coor_data[i, :min(5, coor_data.shape[1])]}")  # first 5 columns only
    # Column 1 is latitude, column 2 is longitude.
    lat_array = coor_data[:, 0]
    lon_array = coor_data[:, 1]
    print(f"\n=== 原始坐标信息 ===")
    print(f"原始坐标范围: 经度 {np.min(lon_array):.6f} ~ {np.max(lon_array):.6f}, 纬度 {np.min(lat_array):.6f} ~ {np.max(lat_array):.6f}")
    # ---- 3. lon/lat -> UTM (zone derived from longitude) ----
    print("正在进行坐标转换...")
    utm_x, utm_y = convert_to_utm(lon_array, lat_array, source_epsg, target_epsg=None)
    # Count the successfully converted points.
    valid_utm_mask = ~(np.isnan(utm_x) | np.isnan(utm_y) | np.isinf(utm_x) | np.isinf(utm_y))
    valid_count = np.sum(valid_utm_mask)
    if valid_count > 0:
        print(f"转换后UTM坐标范围: X {np.nanmin(utm_x):.2f} ~ {np.nanmax(utm_x):.2f}, Y {np.nanmin(utm_y):.2f} ~ {np.nanmax(utm_y):.2f}")
        print(f"成功转换 {valid_count}/{len(utm_x)} 个坐标点")
    else:
        print("警告: 所有UTM坐标转换都失败了将尝试使用原始经纬度坐标进行像素坐标转换")
    # ---- 4. open the image and collect its metadata ----
    dataset = gdal.Open(imgpath)
    im_width = dataset.RasterXSize    # number of columns
    im_height = dataset.RasterYSize   # number of rows
    num_bands = dataset.RasterCount   # number of bands
    geotransform = dataset.GetGeoTransform()  # affine transform
    im_proj = dataset.GetProjection()         # projection info (unused below)
    print(f"影像尺寸: {im_width} x {im_height}, 波段数: {num_bands}")
    print(f"仿射变换参数: {geotransform}")
    print("\n=== 开始光谱提取 ===")
    # ---- 5. optional masks and the boundary adjuster ----
    flare_mask = load_mask_file(flare_path)
    boundary_mask = load_mask_file(boundary_path)
    boundary_adjuster = None
    if boundary_mask is not None and radius > 0:
        boundary_adjuster = prepare_boundary_adjuster(boundary_mask)
        if boundary_adjuster is None:
            print("提示: 无法构建边界调整器,采样点将不会根据水体边界进行内移。")
    dataset_srs = dataset.GetSpatialRef()
    # ---- 6. output buffer: original columns + [x, y] + spectra ----
    original_cols = coor_data.shape[1]
    new_columns = np.zeros((coor_data.shape[0], 2 + num_bands))
    coor_spectral = np.hstack((coor_data, new_columns))
    # Provisional UTM columns; updated after sampling-centre adjustment.
    coor_spectral[:, original_cols] = utm_x
    coor_spectral[:, original_cols + 1] = utm_y
    print(f"处理 {coor_data.shape[0]} 个坐标点...")
    # Fallback: reproject into the image SRS when UTM conversion failed.
    use_utm_fallback = False
    if valid_count == 0 and dataset_srs is not None:
        print("尝试使用影像坐标系进行坐标转换...")
        try:
            source_srs = osr.SpatialReference()
            source_srs.ImportFromEPSG(source_epsg)
            transform_to_image = osr.CoordinateTransformation(source_srs, dataset_srs)
            use_utm_fallback = True
        except:
            use_utm_fallback = False
    # ---- 7. per-point pixel coordinates (with boundary adjustment) ----
    pixel_x_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    pixel_y_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    valid_pixel_mask = np.zeros(coor_data.shape[0], dtype=bool)
    for i in range(coor_data.shape[0]):
        utm_x_point = utm_x[i]
        utm_y_point = utm_y[i]
        if np.isnan(utm_x_point) or np.isnan(utm_y_point) or np.isinf(utm_x_point) or np.isinf(utm_y_point):
            # UTM failed for this point: try the image SRS, then raw lon/lat.
            if use_utm_fallback:
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        img_coords = transform_to_image.TransformPoint(lon_point, lat_point)
                        pixel_x, pixel_y = geo_to_pixel(img_coords[0], img_coords[1], geotransform, dataset_srs)
                        # Store the image-SRS coordinates instead of UTM.
                        coor_spectral[i, original_cols] = img_coords[0]
                        coor_spectral[i, original_cols + 1] = img_coords[1]
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except Exception as e:
                    # Image-SRS reprojection failed too: use raw lon/lat.
                    try:
                        lon_point = lon_array[i]
                        lat_point = lat_array[i]
                        if not (np.isnan(lon_point) or np.isnan(lat_point)):
                            pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                            # Keep the original lon/lat as the stored coordinates.
                            coor_spectral[i, original_cols] = lon_point
                            coor_spectral[i, original_cols + 1] = lat_point
                        else:
                            print(f"跳过坐标点 {i + 1}: 坐标无效")
                            coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                            continue
                    except:
                        print(f"跳过坐标点 {i + 1}: 所有坐标转换方式都失败")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
            else:
                # No fallback transform available: try raw lon/lat directly.
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                        # Keep the original lon/lat as the stored coordinates.
                        coor_spectral[i, original_cols] = lon_point
                        coor_spectral[i, original_cols + 1] = lat_point
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except:
                    print(f"跳过坐标点 {i + 1}: 坐标转换失败")
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                    continue
        else:
            # Normal path: UTM -> pixel via the affine transform.
            pixel_x, pixel_y = geo_to_pixel(utm_x_point, utm_y_point, geotransform, dataset_srs)
        pixel_x_array[i] = pixel_x
        pixel_y_array[i] = pixel_y
        # Shift the sampling centre inside the water body if required.
        moved = False
        original_pixel_x, original_pixel_y = pixel_x, pixel_y
        if boundary_adjuster is not None and radius > 0:
            new_pixel_x, new_pixel_y, moved = adjust_sampling_center(pixel_x, pixel_y, radius, boundary_adjuster)
            if moved:
                pixel_x, pixel_y = new_pixel_x, new_pixel_y
                if i < 10 or (i % 100 == 0):
                    print(f" 采样点 {i + 1} 调整至水体内部: ({original_pixel_x}, {original_pixel_y}) -> ({pixel_x}, {pixel_y})")
                pixel_x_array[i] = pixel_x
                pixel_y_array[i] = pixel_y
        # Bounds check with the (possibly adjusted) pixel coordinates.
        if 0 <= pixel_x < im_width and 0 <= pixel_y < im_height:
            valid_pixel_mask[i] = True
            # Overwrite the provisional columns with the final map position.
            geo_x, geo_y = pixel_to_geo(pixel_x, pixel_y, geotransform)
            coor_spectral[i, original_cols] = geo_x
            coor_spectral[i, original_cols + 1] = geo_y
        else:
            valid_pixel_mask[i] = False
            if i < 10 or (i % 100 == 0):  # throttle the warnings
                print(f"警告: 坐标点 {i + 1} (UTM X:{utm_x_point:.2f}, Y:{utm_y_point:.2f}) 超出影像范围")
    # ---- 8. spectral extraction ----
    print(f"批量提取光谱数据... (有效坐标点: {np.sum(valid_pixel_mask)})")
    if radius > 0:
        # Radius sampling: one masked-average call per point.
        for i in range(coor_data.shape[0]):
            if valid_pixel_mask[i]:
                spectrum = get_average_spectral_in_radius(
                    dataset, pixel_x_array[i], pixel_y_array[i], radius, flare_mask, boundary_mask
                )
                coor_spectral[i, original_cols + 2:] = spectrum
            else:
                coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    else:
        # Single-pixel sampling: preload everything, then fancy-index.
        try:
            print("正在预加载所有波段数据到内存(优化模式)...")
            all_bands_data = []
            for band_idx in range(num_bands):
                band = dataset.GetRasterBand(band_idx + 1)
                band_data = band.ReadAsArray()
                all_bands_data.append(band_data)
            all_bands_data = np.array(all_bands_data)  # shape: (bands, height, width)
            print("预加载完成,开始批量提取像素值...")
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = int(pixel_x_array[i]), int(pixel_y_array[i])
                    # GDAL layout is (bands, height, width): pixel (x, y)
                    # maps to index [:, y, x] — py is the row, px the column.
                    if 0 <= px < all_bands_data.shape[2] and 0 <= py < all_bands_data.shape[1]:
                        spectrum = all_bands_data[:, py, px]  # direct index, very fast
                        coor_spectral[i, original_cols + 2:] = spectrum
                    else:
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
            del all_bands_data  # release the preloaded cube
            print("批量提取完成")
        except MemoryError:
            # Not enough RAM for the cube: read single pixels band by band.
            print("内存不足,使用逐个波段读取模式...")
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = pixel_x_array[i], pixel_y_array[i]
                    spectrum = np.zeros(num_bands)
                    for band_idx in range(num_bands):
                        band = dataset.GetRasterBand(band_idx + 1)
                        spectrum[band_idx] = band.ReadAsArray(px, py, 1, 1)[0, 0]
                    coor_spectral[i, original_cols + 2:] = spectrum
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    del dataset
    # ---- 9. assemble the output table (drop lat/lon + UTM columns) ----
    try:
        if coor_df is not None and hasattr(coor_df, 'columns'):
            # Reuse the original column names from column 3 onwards.
            if len(coor_df.columns) >= original_cols:
                if original_cols > 2:
                    original_columns = list(coor_df.columns[2:original_cols])
                else:
                    original_columns = []
            else:
                # Fewer names than data columns: keep whatever exists.
                if len(coor_df.columns) > 2:
                    original_columns = list(coor_df.columns[2:])
                else:
                    original_columns = []
        else:
            # No header available: invent names for columns 3..n.
            if original_cols > 2:
                original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
            else:
                original_columns = []
    except:
        if original_cols > 2:
            original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
        else:
            original_columns = []
    # Band wavelengths (from the ENVI header) become the spectral headers.
    wavelengths = None
    try:
        in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
        wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
        spectral_columns = [str(wl) for wl in wavelengths]
        print(f"成功读取波长信息,共 {len(spectral_columns)} 个波段")
    except Exception as e:
        print(f"警告: 无法读取波长信息 ({str(e)}),使用默认列名 band_1, band_2, ...")
        spectral_columns = ["band_" + str(j + 1) for j in range(num_bands)]
    all_columns = original_columns + spectral_columns
    # Select output columns: extra original columns + the spectra only
    # (the two coordinate columns and the two UTM columns are dropped).
    output_data = []
    if original_cols > 2:
        output_data.append(coor_spectral[:, 2:original_cols])
    output_data.append(coor_spectral[:, original_cols + 2:])
    if len(output_data) > 0:
        output_array = np.hstack(output_data) if len(output_data) > 1 else output_data[0]
    else:
        output_array = coor_spectral[:, original_cols + 2:]
    result_df = pd.DataFrame(output_array, columns=all_columns)
    result_df.to_csv(outpath, index=False, float_format='%.6f')
    print(f"结果已保存到CSV文件: {outpath}")
    return coor_spectral
# 直接运行示例
if __name__ == '__main__':
    # Example invocation with hard-coded paths (edit before running).
    imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"  # BIL-format image path
    coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"  # CSV with lat, lon in columns 1 and 2
    output_path = r"E:\code\WQ\封装\test/yangdian_output.csv"  # output CSV path
    radius = 5  # sampling radius in pixels: 0 = single pixel, >0 = mean within radius
    flare_path = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat"  # optional glint mask (None disables)
    # BUGFIX: raw string — the original non-raw literal contained
    # unrecognised escapes (\B, \y, \w) which emit SyntaxWarning on modern
    # Python; the raw form has the identical value.
    boundary_path = r"D:\BaiduNetdiskDownload\yaobao\water_mask.dat"  # optional boundary mask (None disables)
    source_epsg = 4326  # EPSG code of the input coordinates (WGS84)
    verbose = True  # print a summary of the configuration before running
    if verbose:
        print(f"影像文件: {imgpath}")
        print(f"坐标文件: {coorpath}")
        print(f"输出文件: {output_path}")
        print(f"采样半径: {radius}")
        if flare_path:
            print(f"耀斑掩膜: {flare_path}")
        if boundary_path:
            print(f"边界掩膜: {boundary_path}")
        if source_epsg:
            print(f"指定坐标系: EPSG:{source_epsg}")
    tmp = get_spectral_in_coor(imgpath, coorpath, output_path,
                               radius, flare_path, boundary_path, source_epsg)

View File

@ -0,0 +1,785 @@
from osgeo import gdal, osr
import numpy as np
import pandas as pd
import os
import spectral
from math import sin, cos, tan, sqrt, radians
# Enable OSR exceptions so CRS errors raise instead of returning error codes
osr.UseExceptions()
# WGS84 ellipsoid parameters (used by the closed-form UTM conversion below)
WGS84_A = 6378137.0  # semi-major axis (metres)
WGS84_F = 1 / 298.257223563  # flattening
WGS84_E2 = WGS84_F * (2 - WGS84_F)  # first eccentricity squared
WGS84_EP2 = WGS84_E2 / (1 - WGS84_E2)  # second eccentricity squared
UTM_K0 = 0.9996  # UTM central-meridian scale factor
def transform_coordinates(lon, lat, source_srs, target_srs):
    """
    Reproject a single point between two coordinate reference systems.
    Args:
        lon: longitude (or x) of the point
        lat: latitude (or y) of the point
        source_srs: source osr.SpatialReference
        target_srs: target osr.SpatialReference
    Returns:
        transformed_lon, transformed_lat: the point in the target CRS
    """
    # Build the OSR transformer and run the point through it.
    converter = osr.CoordinateTransformation(source_srs, target_srs)
    result = converter.TransformPoint(lon, lat)
    x_out, y_out = result[0], result[1]
    return x_out, y_out
def geo_to_pixel(lon, lat, geotransform, dataset_srs=None):
    """
    Convert a georeferenced coordinate to integer pixel (column, row) indices.

    Inverts the axis-aligned affine geotransform. np.floor is used instead of
    plain int() because int() truncates toward zero: a point just left/above
    the raster origin (fractional index -0.3) would be truncated onto pixel 0
    and wrongly pass the caller's in-bounds check; floor maps it to -1 so it
    is correctly rejected.

    Args:
        lon: x / easting coordinate in the raster's CRS
        lat: y / northing coordinate in the raster's CRS
        geotransform: GDAL 6-element affine transform
            (x_origin, pixel_width, 0, y_origin, 0, pixel_height);
            rotation terms are assumed to be zero
        dataset_srs: unused; kept for interface compatibility
    Returns:
        pixel_x, pixel_y: integer pixel indices (may be out of range —
        the caller is responsible for bounds checking)
    """
    x_origin = geotransform[0]
    y_origin = geotransform[3]
    pixel_width = geotransform[1]
    pixel_height = geotransform[5]
    # floor (not truncation) so points outside the origin side go negative
    pixel_x = int(np.floor((lon - x_origin) / pixel_width))
    pixel_y = int(np.floor((lat - y_origin) / pixel_height))
    return pixel_x, pixel_y
def get_pixel_spectrum_batch(dataset, pixel_x_array, pixel_y_array):
    """
    Extract full spectra for many pixels at once.

    Bands are read one at a time (efficient for band-sequential rasters), and
    the per-pixel extraction is done with NumPy fancy indexing instead of a
    Python loop — the original per-point loop inside the band loop was
    O(bands x points) in interpreted code.

    Args:
        dataset: GDAL dataset
        pixel_x_array: pixel column (x) coordinates
        pixel_y_array: pixel row (y) coordinates
    Returns:
        spectrum_array: (n_points, n_bands) float32 array; rows whose pixel
        coordinates fall outside the raster are NaN
    """
    px = np.asarray(pixel_x_array, dtype=np.int64)
    py = np.asarray(pixel_y_array, dtype=np.int64)
    n_points = len(px)
    n_bands = dataset.RasterCount
    # Initialise to NaN; valid pixels are overwritten below.
    spectrum_array = np.full((n_points, n_bands), np.nan, dtype=np.float32)
    if n_points == 0:
        return spectrum_array
    for band_idx in range(n_bands):
        band = dataset.GetRasterBand(band_idx + 1)  # GDAL bands are 1-based
        band_data = band.ReadAsArray()  # whole band in one read
        height, width = band_data.shape
        # Vectorised bounds check + gather for all points of this band.
        valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
        spectrum_array[valid, band_idx] = band_data[py[valid], px[valid]]
    return spectrum_array
def get_average_spectral_in_radius(dataset, center_x, center_y, radius, flare_mask=None, boundary_mask=None):
    """
    Mean spectrum of the pixels within `radius` of (center_x, center_y),
    excluding glint-flagged pixels, pixels outside the boundary mask, and
    zero / non-finite samples.
    Args:
        dataset: GDAL dataset to sample
        center_x, center_y: centre pixel coordinates
        radius: sampling radius in pixels
        flare_mask: optional glint mask array (0 = glint-free)
        boundary_mask: optional boundary mask array (1 = inside boundary)
    Returns:
        1-D array of per-band means; 0.0 where no valid pixel exists
    """
    n_bands = dataset.RasterCount
    # Clip the sampling window to the raster extent.
    x0 = max(0, center_x - radius)
    x1 = min(dataset.RasterXSize, center_x + radius + 1)
    y0 = max(0, center_y - radius)
    y1 = min(dataset.RasterYSize, center_y + radius + 1)
    win_w = x1 - x0
    win_h = y1 - y0
    if win_w <= 0 or win_h <= 0:
        return np.zeros(n_bands)
    # One read for every band over the window.
    cube = dataset.ReadAsArray(x0, y0, win_w, win_h)
    if cube is None:
        return np.zeros(n_bands)
    # A single-band read comes back 2-D; promote to (bands, height, width).
    if cube.ndim == 2:
        cube = cube[np.newaxis, :, :]
    # Circular footprint around the window-local centre.
    rows, cols = np.ogrid[:win_h, :win_w]
    dx = cols - (center_x - x0)
    dy = rows - (center_y - y0)
    keep = (dx * dx + dy * dy) <= radius ** 2
    # Drop glint pixels (mask value 0 means glint-free).
    if flare_mask is not None:
        flare_win = flare_mask[y0:y1, x0:x1]
        if flare_win.shape == keep.shape:
            keep = keep & (flare_win == 0)
    # Keep only pixels inside the boundary (mask value 1 means inside).
    if boundary_mask is not None:
        boundary_win = boundary_mask[y0:y1, x0:x1]
        if boundary_win.shape == keep.shape:
            keep = keep & (boundary_win == 1)
    mean_spectrum = np.zeros(n_bands)
    if np.sum(keep) > 0:
        for b in range(n_bands):
            layer = cube[b, :, :]
            # Also discard zero-valued and non-finite samples.
            samples = layer[keep & (layer != 0) & np.isfinite(layer)]
            if len(samples) > 0:
                mean_spectrum[b] = np.mean(samples)
    return mean_spectrum
def load_mask_file(mask_path):
    """
    Load a raster mask file and return its first band as a 2-D array.
    Args:
        mask_path: mask file path (raster formats such as .dat/.tif);
            None is accepted and simply yields None
    Returns:
        First-band array, or None when the path is missing, unreadable,
        or does not point at a valid raster
    """
    if mask_path is None or not os.path.exists(mask_path):
        return None
    try:
        # Open explicitly as a raster: OpenEx returns None for vector
        # sources, avoiding the "multiple layers" error gdal.Open can hit.
        ds = gdal.OpenEx(mask_path, gdal.OF_RASTER)
        if ds is None:
            # Fall back to plain Open for backwards compatibility.
            ds = gdal.Open(mask_path, gdal.GA_ReadOnly)
        if ds is None:
            print(f"警告: 无法打开掩膜文件 {mask_path},可能不是有效的栅格文件")
            return None
        # Reject anything without at least one raster band.
        if not hasattr(ds, 'RasterCount') or ds.RasterCount == 0:
            print(f"警告: {mask_path} 不是有效的栅格文件")
            del ds
            return None
        band_array = ds.GetRasterBand(1).ReadAsArray()
        del ds
        return band_array
    except Exception as e:
        print(f"警告: 加载掩膜文件 {mask_path} 时出错: {str(e)}")
        return None
def get_hdr_file_path(file_path):
    """
    Return the ENVI header (.hdr) path paired with an image file.
    Args:
        file_path: image file path
    Returns:
        The same path with its extension replaced by ".hdr"
    """
    base, _ext = os.path.splitext(file_path)
    return base + ".hdr"
def calculate_utm_zone(longitude):
    """
    Return the UTM zone number covering a longitude.
    Zones are 6 degrees wide, numbered eastwards from 180°W.
    Args:
        longitude: longitude in degrees
    Returns:
        utm_zone: zone number clamped into [1, 60]
    """
    zone = int((longitude + 180) / 6) + 1
    # Clamp into the legal range (e.g. longitude exactly 180 -> 60).
    return min(60, max(1, zone))
def latlon_to_utm_math(lat_deg, lon_deg, zone=None):
    """
    Convert WGS84 geographic coordinates to UTM with the closed-form Snyder
    series expansion (no PROJ/OSR dependency).
    Args:
        lat_deg: latitude (degrees)
        lon_deg: longitude (degrees)
        zone: UTM zone number; if None it is derived from the longitude
    Returns:
        easting, northing: UTM coordinates in metres (southern-hemisphere
        points carry the conventional 10,000,000 m false northing)
    """
    # Derive the zone from the longitude when not supplied.
    if zone is None:
        zone = calculate_utm_zone(lon_deg)
    # Central meridian of the zone (degrees), then radians.
    lon0 = (zone * 6 - 183)
    lam0 = radians(lon0)
    # Convert the input to radians.
    phi = radians(lat_deg)
    lam = radians(lon_deg)
    # Intermediate trigonometric quantities.
    sinphi = sin(phi)
    cosphi = cos(phi)
    tanphi = tan(phi)
    # Radius of curvature in the prime vertical.
    N = WGS84_A / sqrt(1 - WGS84_E2 * sinphi * sinphi)
    T = tanphi * tanphi
    C = WGS84_EP2 * cosphi * cosphi
    A = cosphi * (lam - lam0)
    # Meridional arc length (Snyder's series expansion).
    M = (WGS84_A * ((1 - WGS84_E2/4 - 3*WGS84_E2**2/64 - 5*WGS84_E2**3/256) * phi
                    - (3*WGS84_E2/8 + 3*WGS84_E2**2/32 + 45*WGS84_E2**3/1024) * sin(2*phi)
                    + (15*WGS84_E2**2/256 + 45*WGS84_E2**3/1024) * sin(4*phi)
                    - (35*WGS84_E2**3/3072) * sin(6*phi)))
    # Easting, with the 500,000 m false easting applied.
    E = (UTM_K0 * N * (A + (1 - T + C) * A**3 / 6
                       + (5 - 18*T + T*T + 72*C - 58*WGS84_EP2) * A**5 / 120)
         + 500000.0)
    # Northing; southern-hemisphere points get the 10,000,000 m false northing.
    if lat_deg < 0:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720))
              + 10000000.0)
    else:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720)))
    return E, Nn
def convert_to_utm(lon, lat, source_epsg=4326, target_epsg=None):
    """
    Convert geographic coordinates to UTM using the closed-form formulas,
    deriving the UTM zone per point from longitude unless pinned.

    Args:
        lon: longitudes (array-like, degrees)
        lat: latitudes (array-like, degrees)
        source_epsg: source CRS EPSG code, default 4326 (WGS84 geographic)
        target_epsg: target EPSG; None derives the zone per point from
            longitude, a WGS84/UTM code (326xx north, 327xx south) fixes it
    Returns:
        utm_x, utm_y: float64 arrays; NaN where the input was invalid or
        the conversion failed
    """
    try:
        # Normalise to float64 arrays. FIX: np.zeros_like preserved the input
        # dtype, so integer coordinate arrays produced int outputs that cannot
        # hold the NaN sentinel; list inputs crashed on boolean masking.
        lon = np.asarray(lon, dtype=np.float64)
        lat = np.asarray(lat, dtype=np.float64)
        # The closed-form conversion is only exact for WGS84 input.
        if source_epsg != 4326:
            print(f"警告: 数学公式转换仅支持WGS84 (EPSG:4326)当前源坐标系为EPSG:{source_epsg}")
            print("将尝试使用数学公式进行转换,但可能不准确")
        utm_x = np.zeros_like(lon)
        utm_y = np.zeros_like(lat)
        # If a target EPSG was supplied, derive the fixed zone number from it
        # (EPSG:326xx = northern hemisphere, EPSG:327xx = southern).
        fixed_zone = None
        if target_epsg is not None:
            if 32601 <= target_epsg <= 32660:
                fixed_zone = target_epsg - 32600
            elif 32701 <= target_epsg <= 32760:
                fixed_zone = target_epsg - 32700
            else:
                print(f"警告: 无法从EPSG代码 {target_epsg} 提取UTM分区号将根据经度自动计算")
        # Vectorised validity screen: NaNs and out-of-range lon/lat are skipped.
        invalid_mask = (np.isnan(lon) | np.isnan(lat) |
                        (lon < -180) | (lon > 180) |
                        (lat < -90) | (lat > 90))
        invalid_count = np.sum(invalid_mask)
        if invalid_count > 0:
            invalid_indices = np.where(invalid_mask)[0]
            print(f"警告: 发现 {invalid_count} 个无效坐标点,将跳过")
            for idx in invalid_indices[:10]:  # report at most the first 10
                print(f" 坐标点 {idx + 1}: 经度={lon[idx]}, 纬度={lat[idx]}")
            if invalid_count > 10:
                print(f" ... 还有 {invalid_count - 10} 个无效坐标点")
        # Convert only the valid points.
        valid_mask = ~invalid_mask
        if np.any(valid_mask):
            valid_lon = lon[valid_mask]
            valid_lat = lat[valid_mask]
            valid_indices = np.where(valid_mask)[0]
            # Zone per point (constant array when pinned by target_epsg).
            if fixed_zone is not None:
                zones = np.full(len(valid_lon), fixed_zone)
            else:
                zones = np.array([calculate_utm_zone(lon_val) for lon_val in valid_lon])
            # Per-point conversion; failures degrade to NaN for that point.
            for i, (lat_val, lon_val, zone) in enumerate(zip(valid_lat, valid_lon, zones)):
                try:
                    E, Nn = latlon_to_utm_math(lat_val, lon_val, zone)
                    if not (np.isnan(E) or np.isnan(Nn) or np.isinf(E) or np.isinf(Nn)):
                        utm_x[valid_indices[i]] = E
                        utm_y[valid_indices[i]] = Nn
                    else:
                        utm_x[valid_indices[i]] = np.nan
                        utm_y[valid_indices[i]] = np.nan
                except Exception:
                    utm_x[valid_indices[i]] = np.nan
                    utm_y[valid_indices[i]] = np.nan
        # Invalid inputs are reported as NaN in the output.
        utm_x[invalid_mask] = np.nan
        utm_y[invalid_mask] = np.nan
        return utm_x, utm_y
    except Exception as e:
        print(f"坐标转换初始化失败: {str(e)}")
        # Same NaN-array fallback, forced to float so NaN is representable.
        return (np.full(np.shape(lon), np.nan, dtype=np.float64),
                np.full(np.shape(lat), np.nan, dtype=np.float64))
def convert_to_utm51n(lon, lat, source_epsg=4326):
    """
    Convert coordinates to WGS84 / UTM zone 51N (kept for backward
    compatibility with older callers).
    Args:
        lon: longitude array
        lat: latitude array
        source_epsg: source CRS EPSG code, default 4326 (WGS84 geographic)
    Returns:
        utm_x, utm_y: converted UTM coordinates
    """
    # Delegate to the general converter with the zone pinned to 51N.
    utm51n_epsg = 32651
    return convert_to_utm(lon, lat, source_epsg, target_epsg=utm51n_epsg)
def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, boundary_path=None, source_epsg=4326):
    """
    Extract per-point spectra from an image at given coordinates, converting
    the coordinates to UTM (zone auto-derived from longitude).
    Args:
        imgpath: image file path (BIL format)
        coorpath: coordinate file path (CSV; columns 1-2 are latitude, longitude)
        outpath: output file path (CSV)
        radius: sampling radius in pixels (0 = single pixel, >0 = mean within radius)
        flare_path: glint mask file path (optional)
        boundary_path: boundary mask file path (optional)
        source_epsg: source CRS EPSG code, default 4326 (WGS84 geographic)
    Returns:
        Array of [original columns | UTM X | UTM Y | one column per band];
        the CSV written to `outpath` omits the lat/lon and UTM columns.
    """
    # Read the source coordinate file (CSV)
    coor_df = None
    coor_data = None
    # Try several encodings, since field CSVs are often GBK-encoded
    encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'cp1252']
    for encoding in encodings:
        try:
            # Attempt a straight CSV read
            coor_df = pd.read_csv(coorpath, encoding=encoding)
            # Keep only numeric columns, dropping any header noise
            coor_data = coor_df.select_dtypes(include=[np.number]).values
            # If no numeric column survived, coerce everything to numbers
            if coor_data.shape[1] == 0:
                # Re-read treating the first row as the header
                coor_df = pd.read_csv(coorpath, encoding=encoding, header=0)
                # Coerce all columns to numeric (non-numbers become NaN)
                numeric_df = coor_df.apply(pd.to_numeric, errors='coerce')
                # Drop all-NaN rows (typically failed header rows)
                numeric_df = numeric_df.dropna(how='all')
                coor_data = numeric_df.values
            print(f"成功使用 {encoding} 编码读取文件")
            break
        except Exception as e:
            print(f"使用 {encoding} 编码读取失败: {str(e)}")
            continue
    # If every encoding failed, fall back to a plain numpy read
    if coor_data is None:
        try:
            print("尝试使用numpy读取数值数据...")
            # Skip the header row and read comma-separated values
            coor_data = np.loadtxt(coorpath, delimiter=",", skiprows=1)
        except:
            try:
                coor_data = np.loadtxt(coorpath, delimiter="\t", skiprows=1)
            except Exception as e:
                raise Exception(f"无法读取坐标文件,请检查文件格式: {str(e)}")
        if len(coor_data.shape) == 1:
            coor_data = coor_data.reshape(1, -1)
    # Validate the loaded table
    if coor_data is None or coor_data.shape[1] < 2:
        raise Exception("坐标文件格式错误需要至少2列数据第1列为纬度第2列为经度")
    print(f"成功读取坐标文件,共 {coor_data.shape[0]} 行,{coor_data.shape[1]}")
    print(f"数据预览前3行")
    for i in range(min(3, coor_data.shape[0])):
        print(f"{i + 1}: {coor_data[i, :min(5, coor_data.shape[1])]}")  # show only the first 5 columns
    # Pull out the raw coordinates
    lat_array = coor_data[:, 0]  # column 1 is latitude
    lon_array = coor_data[:, 1]  # column 2 is longitude
    print(f"\n=== 原始坐标信息 ===")
    print(f"原始坐标范围: 经度 {np.min(lon_array):.6f} ~ {np.max(lon_array):.6f}, 纬度 {np.min(lat_array):.6f} ~ {np.max(lat_array):.6f}")
    # Convert to UTM (zone derived per point from longitude)
    print("正在进行坐标转换...")
    utm_x, utm_y = convert_to_utm(lon_array, lat_array, source_epsg, target_epsg=None)
    # Inspect the conversion result
    valid_utm_mask = ~(np.isnan(utm_x) | np.isnan(utm_y) | np.isinf(utm_x) | np.isinf(utm_y))
    valid_count = np.sum(valid_utm_mask)
    if valid_count > 0:
        print(f"转换后UTM坐标范围: X {np.nanmin(utm_x):.2f} ~ {np.nanmax(utm_x):.2f}, Y {np.nanmin(utm_y):.2f} ~ {np.nanmax(utm_y):.2f}")
        print(f"成功转换 {valid_count}/{len(utm_x)} 个坐标点")
    else:
        print("警告: 所有UTM坐标转换都失败了将尝试使用原始经纬度坐标进行像素坐标转换")
    # Open the image dataset
    dataset = gdal.Open(imgpath)
    im_width = dataset.RasterXSize  # number of raster columns
    im_height = dataset.RasterYSize  # number of raster rows
    num_bands = dataset.RasterCount  # number of raster bands
    geotransform = dataset.GetGeoTransform()  # affine transform parameters
    im_proj = dataset.GetProjection()  # map projection information
    print(f"影像尺寸: {im_width} x {im_height}, 波段数: {num_bands}")
    print(f"仿射变换参数: {geotransform}")
    print("\n=== 开始光谱提取 ===")
    # Load the optional masks
    flare_mask = load_mask_file(flare_path)
    boundary_mask = load_mask_file(boundary_path)
    # Spatial reference of the dataset (used for the fallback reprojection)
    dataset_srs = dataset.GetSpatialRef()
    # Output array: original columns + UTM X/Y + one column per band
    original_cols = coor_data.shape[1]
    # Append 2 UTM columns and num_bands spectral columns
    new_columns = np.zeros((coor_data.shape[0], 2 + num_bands))
    coor_spectral = np.hstack((coor_data, new_columns))
    # Store the UTM coordinates
    coor_spectral[:, original_cols] = utm_x  # UTM X coordinate
    coor_spectral[:, original_cols + 1] = utm_y  # UTM Y coordinate
    print(f"处理 {coor_data.shape[0]} 个坐标点...")
    # If the UTM conversion failed entirely, try the image's own CRS instead
    use_utm_fallback = False
    if valid_count == 0 and dataset_srs is not None:
        print("尝试使用影像坐标系进行坐标转换...")
        try:
            source_srs = osr.SpatialReference()
            source_srs.ImportFromEPSG(source_epsg)
            transform_to_image = osr.CoordinateTransformation(source_srs, dataset_srs)
            use_utm_fallback = True
        except:
            use_utm_fallback = False
    # Convert every coordinate point to pixel coordinates
    pixel_x_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    pixel_y_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    valid_pixel_mask = np.zeros(coor_data.shape[0], dtype=bool)
    # Compute pixel coordinates point by point
    for i in range(coor_data.shape[0]):
        # Prefer the UTM coordinates; fall back when they are unusable
        utm_x_point = utm_x[i]
        utm_y_point = utm_y[i]
        # Check whether this point's UTM coordinates are usable
        if np.isnan(utm_x_point) or np.isnan(utm_y_point) or np.isinf(utm_x_point) or np.isinf(utm_y_point):
            # UTM failed for this point: try the image-CRS transform
            if use_utm_fallback:
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        # Reproject into the image CRS
                        img_coords = transform_to_image.TransformPoint(lon_point, lat_point)
                        pixel_x, pixel_y = geo_to_pixel(img_coords[0], img_coords[1], geotransform, dataset_srs)
                        # Record the image-CRS coordinates in the UTM columns
                        coor_spectral[i, original_cols] = img_coords[0]
                        coor_spectral[i, original_cols + 1] = img_coords[1]
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except Exception as e:
                    # Image-CRS reprojection failed too: use raw lat/lon directly
                    try:
                        lon_point = lon_array[i]
                        lat_point = lat_array[i]
                        if not (np.isnan(lon_point) or np.isnan(lat_point)):
                            pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                            # Keep the raw lat/lon as the recorded coordinates
                            coor_spectral[i, original_cols] = lon_point
                            coor_spectral[i, original_cols + 1] = lat_point
                        else:
                            print(f"跳过坐标点 {i + 1}: 坐标无效")
                            coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                            continue
                    except:
                        print(f"跳过坐标点 {i + 1}: 所有坐标转换方式都失败")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
            else:
                # No fallback transform available: try raw lat/lon directly
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                        # Keep the raw lat/lon as the recorded coordinates
                        coor_spectral[i, original_cols] = lon_point
                        coor_spectral[i, original_cols + 1] = lat_point
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except:
                    print(f"跳过坐标点 {i + 1}: 坐标转换失败")
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                    continue
        else:
            # UTM coordinates are valid: map them to pixel coordinates
            pixel_x, pixel_y = geo_to_pixel(utm_x_point, utm_y_point, geotransform, dataset_srs)
        # Store the pixel coordinates
        pixel_x_array[i] = pixel_x
        pixel_y_array[i] = pixel_y
        # Flag whether the point falls inside the image
        if 0 <= pixel_x < im_width and 0 <= pixel_y < im_height:
            valid_pixel_mask[i] = True
        else:
            valid_pixel_mask[i] = False
            if i < 10 or (i % 100 == 0):  # print only the first 10 or every 100th
                print(f"警告: 坐标点 {i + 1} (UTM X:{utm_x_point:.2f}, Y:{utm_y_point:.2f}) 超出影像范围")
    # Batch spectral extraction (minimises I/O operations)
    print(f"批量提取光谱数据... (有效坐标点: {np.sum(valid_pixel_mask)})")
    if radius > 0:
        # Radius-sampling mode: points must be handled one by one
        for i in range(coor_data.shape[0]):
            if valid_pixel_mask[i]:
                spectrum = get_average_spectral_in_radius(
                    dataset, pixel_x_array[i], pixel_y_array[i], radius, flare_mask, boundary_mask
                )
                coor_spectral[i, original_cols + 2:] = spectrum
            else:
                coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    else:
        # Single-pixel mode: batch read (optimised)
        # Preload every band into memory when it fits
        try:
            # Preload all bands (fast path for machines with enough RAM)
            print("正在预加载所有波段数据到内存(优化模式)...")
            all_bands_data = []
            for band_idx in range(num_bands):
                band = dataset.GetRasterBand(band_idx + 1)
                band_data = band.ReadAsArray()
                all_bands_data.append(band_data)
            all_bands_data = np.array(all_bands_data)  # shape: (bands, height, width)
            print("预加载完成,开始批量提取像素值...")
            # Extract pixel values from the in-memory cube
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = int(pixel_x_array[i]), int(pixel_y_array[i])
                    # The cube is (bands, height, width), so pixel (x, y)
                    # maps to array index [:, y, x]
                    # note: py is the row (y), px is the column (x)
                    if 0 <= px < all_bands_data.shape[2] and 0 <= py < all_bands_data.shape[1]:
                        spectrum = all_bands_data[:, py, px]  # direct indexing, very fast
                        coor_spectral[i, original_cols + 2:] = spectrum
                    else:
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
            # Release the cube
            del all_bands_data
            print("批量提取完成")
        except MemoryError:
            # Not enough memory: fall back to per-band, per-pixel reads
            print("内存不足,使用逐个波段读取模式...")
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = pixel_x_array[i], pixel_y_array[i]
                    spectrum = np.zeros(num_bands)
                    for band_idx in range(num_bands):
                        band = dataset.GetRasterBand(band_idx + 1)
                        spectrum[band_idx] = band.ReadAsArray(px, py, 1, 1)[0, 0]
                    coor_spectral[i, original_cols + 2:] = spectrum
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    del dataset
    # Build the output DataFrame for CSV export,
    # dropping the first two columns (lat/lon) and the UTM columns
    try:
        # Reuse the original column names when available (skipping lat/lon)
        if coor_df is not None and hasattr(coor_df, 'columns'):
            # Skip the first two columns (lat/lon), start from column 3
            if len(coor_df.columns) >= original_cols:
                # Keep columns 3..original_cols, if any
                if original_cols > 2:
                    original_columns = list(coor_df.columns[2:original_cols])
                else:
                    original_columns = []
            else:
                # Fewer named columns than expected: keep what exists past lat/lon
                if len(coor_df.columns) > 2:
                    original_columns = list(coor_df.columns[2:])
                else:
                    original_columns = []
        else:
            # No column names: synthesise names for columns 3 onwards, if any
            if original_cols > 2:
                original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
            else:
                original_columns = []
    except:
        # On any failure, synthesise names for columns 3 onwards, if any
        if original_cols > 2:
            original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
        else:
            original_columns = []
    # Read the wavelength list to use as spectral column names
    wavelengths = None
    try:
        in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
        wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
        # Stringify wavelengths into column names
        spectral_columns = [str(wl) for wl in wavelengths]
        print(f"成功读取波长信息,共 {len(spectral_columns)} 个波段")
    except Exception as e:
        print(f"警告: 无法读取波长信息 ({str(e)}),使用默认列名 band_1, band_2, ...")
        spectral_columns = ["band_" + str(j + 1) for j in range(num_bands)]
    # Output column names exclude the lat/lon and UTM columns
    all_columns = original_columns + spectral_columns
    # Select the output columns from coor_spectral:
    # skip the first two (lat/lon) and the UTM pair, keeping
    # - columns 3..original_cols (if any)
    # - the spectral columns starting at original_cols + 2
    output_data = []
    if original_cols > 2:
        # Keep columns 3..original_cols
        output_data.append(coor_spectral[:, 2:original_cols])
    # Keep the spectral columns (from original_cols + 2 onwards)
    output_data.append(coor_spectral[:, original_cols + 2:])
    # Merge the selected pieces
    if len(output_data) > 0:
        output_array = np.hstack(output_data) if len(output_data) > 1 else output_data[0]
    else:
        # No extra original columns: spectra only
        output_array = coor_spectral[:, original_cols + 2:]
    # Build the result DataFrame
    result_df = pd.DataFrame(output_array, columns=all_columns)
    # Write as CSV
    result_df.to_csv(outpath, index=False, float_format='%.6f')
    print(f"结果已保存到CSV文件: {outpath}")
    return coor_spectral
# Standalone usage example
if __name__ == '__main__':
    # --- input/output configuration ---
    imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq"  # image file path (BIL format)
    coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"  # CSV coordinates (cols 1-2: lat, lon)
    output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv"  # CSV output path
    # --- sampling configuration ---
    radius = 5  # sampling radius in pixels (0 = single pixel, >0 = mean within radius)
    flare_path = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat"  # optional glint mask; None disables it
    boundary_path = r"D:\BaiduNetdiskDownload\yaobao\water_mask.dat"  # optional boundary mask; None disables it
    source_epsg = 4326  # source CRS EPSG code (4326 = WGS84 geographic)
    verbose = True  # enable verbose reporting
    if verbose:
        print(f"影像文件: {imgpath}")
        print(f"坐标文件: {coorpath}")
        print(f"输出文件: {output_path}")
        print(f"采样半径: {radius}")
        if flare_path:
            print(f"耀斑掩膜: {flare_path}")
        if boundary_path:
            print(f"边界掩膜: {boundary_path}")
        if source_epsg:
            print(f"指定坐标系: EPSG:{source_epsg}")
    tmp = get_spectral_in_coor(imgpath, coorpath, output_path,
                               radius, flare_path, boundary_path, source_epsg)