Initial commit of WQ_GUI
This commit is contained in:
367
src/core/glint_removal/Goodman.py
Normal file
367
src/core/glint_removal/Goodman.py
Normal file
@ -0,0 +1,367 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
print("警告: GDAL未安装,将使用numpy处理模式")
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
TQDM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TQDM_AVAILABLE = False
|
||||
# 如果tqdm不可用,定义一个简单的包装器
|
||||
def tqdm(iterable, desc=None, total=None):
|
||||
return iterable
|
||||
|
||||
class Goodman:
    """Goodman et al. sun-glint correction.

    The NIR reflectance is subtracted from the reflectance at each wavelength,
    and a wavelength-independent offset ``A + B * (R_640 - R_750)`` is added
    back, correcting each pixel independently (see Goodman et al.).
    """

    def __init__(self, im_aligned, NIR_lower = 25, NIR_upper = 37, A = 0.000019, B = 0.1,
                 use_gdal=True, chunk_size=None, water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray or str): band aligned and calibrated & corrected
            reflectance image; either a numpy array or a file path readable by GDAL
        :param NIR_lower (int): band index which corresponds to 641.93nm, closest band to 640nm
        :param NIR_upper (int): band index which corresponds to 751.49nm, closest band to 750nm
        :param A (float): value from Goodman et al's paper, using AVIRIS reflectance
            (rather than radiance) data
        :param B (float): value from Goodman et al's paper, using AVIRIS reflectance
            (rather than radiance) data
            It is not clear how A and B were chosen; an optimization against in situ
            data would enable better values to be found.
        :param use_gdal (bool): whether to use GDAL-accelerated processing (requires GDAL
            to be available and the input to be a file path or a large array)
        :param chunk_size (int): deprecated; processing is now done band by band
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
            a numpy array, a raster path (.dat/.tif) or a shapefile path (.shp);
            if None, the whole image is processed
        :param output_path (str or None): if given, the corrected image is saved there;
            if None, nothing is saved
        """
        # Assign the attributes that __del__ inspects BEFORE any statement that can
        # raise, so a partially-constructed instance never triggers AttributeError
        # during garbage collection (the original assigned them after validation).
        self.dataset = None
        self.is_file_path = isinstance(im_aligned, str)
        self.im_aligned = im_aligned
        self.NIR_lower = NIR_lower
        self.NIR_upper = NIR_upper
        self.A = A
        self.B = B
        self.use_gdal = use_gdal and GDAL_AVAILABLE
        self.chunk_size = chunk_size
        self.output_path = output_path

        # Image geometry must be known before the mask is loaded/validated.
        if self.is_file_path:
            if not self.use_gdal:
                raise ValueError("输入为文件路径时,必须安装GDAL")
            self.dataset = gdal.Open(im_aligned, gdal.GA_ReadOnly)
            if self.dataset is None:
                raise ValueError(f"无法打开影像文件: {im_aligned}")
            self.height = self.dataset.RasterYSize
            self.width = self.dataset.RasterXSize
            self.n_bands = self.dataset.RasterCount
        else:
            self.height = im_aligned.shape[0]
            self.width = im_aligned.shape[1]
            self.n_bands = im_aligned.shape[-1]

        # Load the water mask (after the image size is known).
        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """Load and validate the water mask.

        :param water_mask: None, a numpy array, a raster path (.dat/.tif) or a
            shapefile path (.shp)
        :return: uint8 array (1 = water, 0 = non-water) or None
        """
        if water_mask is None:
            return None

        # Already a numpy array: just validate and normalise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)  # force a 0/1 mask

        # A file path: raster or shapefile.
        if isinstance(water_mask, str):
            if not GDAL_AVAILABLE:
                raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")

            if water_mask.lower().endswith('.shp'):
                # Rasterise the shapefile onto the input raster's grid; this needs
                # a georeferenced reference raster, so numpy input cannot be used.
                if self.is_file_path:
                    ref_path = self.im_aligned
                else:
                    raise ValueError("输入为numpy数组时,无法从shp文件创建掩膜(需要参考栅格)")

                try:
                    from osgeo import ogr
                    ref_dataset = gdal.Open(ref_path, gdal.GA_ReadOnly)
                    if ref_dataset is None:
                        raise ValueError(f"无法打开参考栅格文件: {ref_path}")

                    geotransform = ref_dataset.GetGeoTransform()
                    projection = ref_dataset.GetProjection()
                    width = ref_dataset.RasterXSize
                    height = ref_dataset.RasterYSize

                    # In-memory byte raster matching the reference grid.
                    mem_driver = gdal.GetDriverByName('MEM')
                    mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
                    mask_dataset.SetGeoTransform(geotransform)
                    mask_dataset.SetProjection(projection)

                    mask_band = mask_dataset.GetRasterBand(1)
                    mask_band.Fill(0)

                    # Open the shapefile and burn its features into the raster.
                    shp_dataset = ogr.Open(water_mask)
                    if shp_dataset is None:
                        raise ValueError(f"无法打开shp文件: {water_mask}")

                    layer = shp_dataset.GetLayer()
                    gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])

                    water_mask_array = mask_band.ReadAsArray()

                    # Dropping the references closes the GDAL/OGR datasets.
                    ref_dataset = None
                    mask_dataset = None
                    shp_dataset = None

                    return (water_mask_array > 0).astype(np.uint8)
                except Exception as e:
                    raise ValueError(f"从shp文件创建掩膜时出错: {e}")
            else:
                # Raster mask file.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")

                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None

                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")

                return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def _get_corrected_bands_numpy(self):
        """Correct band by band with plain numpy (small images, or GDAL missing).

        The input array is already in memory; processing one band at a time keeps
        the peak at: original array + 2 NIR bands + the band currently in flight.
        """
        # The two NIR bands are reused for every band's correction, so they are
        # extracted once and kept alive for the whole loop.
        R_640 = self.im_aligned[:,:,self.NIR_lower]
        R_750 = self.im_aligned[:,:,self.NIR_upper]
        # Precompute the constant wavelength-independent term.
        diff_640_750 = R_640 - R_750
        corrected_bands = []

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        for i in tqdm(range(self.n_bands), desc="处理波段 (numpy)", total=self.n_bands):
            # View into the source array (no copy).
            R = self.im_aligned[:,:,i]
            corrected_band = R - R_750 + self.A + self.B * diff_640_750
            # Clip negatives to zero in place.
            np.maximum(corrected_band, 0, out=corrected_band)

            # With a water mask, only water pixels are corrected; land keeps R.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)

            corrected_bands.append(corrected_band)
        return corrected_bands

    def _get_corrected_bands_gdal(self):
        """Correct band by band, reading whole bands from the GDAL dataset.

        Peak memory: 2 NIR bands + the band in flight + the accumulated results.
        """
        corrected_bands = []

        # GDAL band indices start at 1.
        band_640 = self.dataset.GetRasterBand(self.NIR_lower + 1)
        band_750 = self.dataset.GetRasterBand(self.NIR_upper + 1)

        # NIR bands are needed for every correction, so read them up front.
        R_640 = band_640.ReadAsArray().astype(np.float32)
        R_750 = band_750.ReadAsArray().astype(np.float32)
        diff_640_750 = R_640 - R_750

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        for i in tqdm(range(self.n_bands), desc="处理波段 (GDAL)", total=self.n_bands):
            # Load a single band into memory.
            current_band = self.dataset.GetRasterBand(i + 1)
            R = current_band.ReadAsArray().astype(np.float32)

            corrected_band = R - R_750 + self.A + self.B * diff_640_750
            np.maximum(corrected_band, 0, out=corrected_band)

            # With a water mask, only water pixels are corrected; land keeps R.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)

            corrected_bands.append(corrected_band)

            # Explicitly drop the raw band to release memory promptly.
            del R

        return corrected_bands

    def _get_corrected_bands_gdal_mem(self):
        """Process a numpy input via GDAL's in-memory driver, band by band."""
        driver = gdal.GetDriverByName('MEM')
        mem_dataset = driver.Create('', self.width, self.height, self.n_bands, gdal.GDT_Float32)

        # Copy the numpy array into the in-memory dataset (with progress).
        for i in tqdm(range(self.n_bands), desc="加载波段到内存", total=self.n_bands):
            band = mem_dataset.GetRasterBand(i + 1)
            band.WriteArray(self.im_aligned[:,:,i])
            band.FlushCache()

        # Temporarily swap the dataset so the GDAL path can be reused.
        original_dataset = self.dataset
        self.dataset = mem_dataset

        try:
            result = self._get_corrected_bands_gdal()
        finally:
            # Always restore the original dataset and drop the in-memory one.
            self.dataset = original_dataset
            mem_dataset = None

        return result

    def _save_corrected_bands(self, corrected_bands):
        """Save corrected bands to an ENVI (BSQ) file.

        Bands are written one at a time straight from the list — no stacking
        into a full cube — to keep memory usage down.

        :param corrected_bands: list of corrected 2-D band arrays
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")

        if self.output_path is None:
            return

        import os
        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Take the geometry from the first band (avoids stacking all bands).
        if not corrected_bands:
            raise ValueError("校正后的波段列表为空")
        first_band = corrected_bands[0]
        height, width = first_band.shape
        n_bands = len(corrected_bands)

        # Carry over georeferencing when the input was a GDAL file.
        if self.is_file_path and self.dataset is not None:
            geotransform = self.dataset.GetGeoTransform()
            projection = self.dataset.GetProjection()
        else:
            # No georeferencing available: use identity-like defaults.
            geotransform = (0, 1, 0, 0, 0, -1)
            projection = ""

        # Force the ENVI (BSQ) format; normalise the extension to .bsq.
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path

        # The ENVI driver writes BSQ by default.
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

        # Creating the dataset also generates the .hdr sidecar file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)

            # Write each band directly from the list (no full-cube copy).
            for i in tqdm(range(n_bands), desc="保存波段", total=n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_bands[i])
                band.FlushCache()
        finally:
            # Dropping the reference flushes and closes the file.
            dataset = None

        # Report whether the .hdr sidecar was produced.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Return the corrected bands, picking the best strategy automatically.

        :return: list of corrected 2-D band arrays
        """
        if self.is_file_path:
            # File input: always read through GDAL.
            if self.use_gdal:
                corrected_bands = self._get_corrected_bands_gdal()
            else:
                raise ValueError("输入为文件路径时,必须安装GDAL")
        else:
            # Numpy input: route large cubes through the GDAL MEM driver.
            if self.use_gdal and self.height * self.width * self.n_bands > 100000000:
                corrected_bands = self._get_corrected_bands_gdal_mem()
            else:
                corrected_bands = self._get_corrected_bands_numpy()

        # Persist the result when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)

        return corrected_bands

    def __del__(self):
        """Release the GDAL dataset handle (no-op for numpy input).

        getattr guards against a partially-initialised instance: if __init__
        raised before these attributes existed, __del__ must not raise.
        """
        if getattr(self, 'is_file_path', False) and getattr(self, 'dataset', None) is not None:
            self.dataset = None
|
||||
290
src/core/glint_removal/Hedley.py
Normal file
290
src/core/glint_removal/Hedley.py
Normal file
@ -0,0 +1,290 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
class Hedley:
    """Hedley et al. sun-glint correction.

    For each band i, the slope b_i of band-i reflectance against the NIR band
    is estimated over a deep-water region, and ``b_i * (R_NIR - R_min)`` is
    subtracted from the band.
    """

    def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
        :param shp_path (str, optional): path to shapefile (.shp) defining the region containing
            the glint region in deep water. If None, uses the entire image. The shapefile can use
            pixel coordinates or geographic coordinates.
        :param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which
            corresponds closely to the NIR band in Micasense
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
            a numpy array or a raster path (.dat/.tif); if None, the whole image is processed
        :param output_path (str or None): if given, the corrected image is saved there
        """
        self.im_aligned = im_aligned
        self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
        self.NIR_band = NIR_band
        self.n_bands = im_aligned.shape[-1]
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path

        # Load the water mask.
        self.water_mask = self._load_water_mask(water_mask)

        # R_min: 5th percentile of the NIR band, restricted to the water mask
        # when one is supplied (ravel() avoids an unnecessary copy).
        if self.water_mask is not None:
            nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
            self.R_min = self._percentile_nearest(nir_band_masked, 5) if nir_band_masked.size > 0 else 0
        else:
            self.R_min = self._percentile_nearest(self.im_aligned[:,:,self.NIR_band].ravel(), 5)

    @staticmethod
    def _percentile_nearest(values, q):
        """np.percentile with 'nearest' selection, compatible across NumPy versions.

        The ``interpolation`` keyword was renamed to ``method`` in NumPy 1.22 and
        removed in NumPy 2.0, so the original call crashes on modern NumPy.
        """
        try:
            return np.percentile(values, q, method='nearest')
        except TypeError:
            # Pre-1.22 NumPy: only the old keyword exists.
            return np.percentile(values, q, interpolation='nearest')

    def _read_shp_to_bbox(self, shp_path):
        """Read a shapefile and return its bounding box in pixel coordinates.

        :param shp_path (str): shapefile path
        :return: tuple ((x1,y1),(x2,y2)) — upper-left and lower-right corners
        """
        if not os.path.exists(shp_path):
            raise FileNotFoundError(f"Shapefile not found: {shp_path}")

        try:
            try:
                import geopandas as gpd
                gdf = gpd.read_file(shp_path)
                # Combined bounding box of all geometries.
                bounds = gdf.total_bounds  # [minx, miny, maxx, maxy]
                min_x, min_y, max_x, max_y = bounds
            except ImportError:
                # Fall back to fiona when geopandas is not installed.
                import fiona
                from shapely.geometry import shape

                min_x = float('inf')
                min_y = float('inf')
                max_x = float('-inf')
                max_y = float('-inf')

                with fiona.open(shp_path) as shp:
                    for feature in shp:
                        geom = shape(feature['geometry'])
                        if geom:
                            bounds = geom.bounds
                            min_x = min(min_x, bounds[0])
                            min_y = min(min_y, bounds[1])
                            max_x = max(max_x, bounds[2])
                            max_y = max(max_y, bounds[3])

            # Convert to integer pixel coordinates, clamped to the image.
            x1 = max(0, int(min_x))
            y1 = max(0, int(min_y))
            x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
            y2 = min(self.im_aligned.shape[0], int(max_y) + 1)

            return ((x1, y1), (x2, y2))

        except Exception as e:
            raise ValueError(f"Error reading shapefile {shp_path}: {e}")

    def _load_water_mask(self, water_mask):
        """Load and validate the water mask.

        :param water_mask: None, a numpy array, a raster path (.dat/.tif) or a
            shapefile path (.shp — rejected here, see below)
        :return: uint8 array (1 = water, 0 = non-water) or None
        """
        if water_mask is None:
            return None

        # Already a numpy array: validate and normalise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)  # force a 0/1 mask

        # A file path.
        if isinstance(water_mask, str):
            try:
                from osgeo import gdal, ogr
            except ImportError:
                raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")

            if water_mask.lower().endswith('.shp'):
                # Rasterising a shapefile needs a georeferenced reference raster;
                # this class only takes numpy input, so reject it.
                raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
            else:
                # Raster mask file.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")

                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None

                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")

                return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def covariance_NIR(self,NIR,b):
        """Regression slope of band reflectance *b* against *NIR*.

        NIR & b are 1-D vectors of reflectance; returns cov(NIR, b) / var(NIR),
        or 0.0 when the NIR variance is zero (avoids division by zero).
        """
        nir_mean = np.mean(NIR)
        b_mean = np.mean(b)
        # cov / var computed directly; reuses the centred NIR vector.
        nir_centred = NIR - nir_mean
        pij = np.mean(nir_centred * (b - b_mean))
        pjj = np.mean(nir_centred ** 2)
        return pij / pjj if pjj != 0 else 0.0

    def correlation_bands_reflectance(self):
        """Regression slopes between the NIR band and every band (reflectance).

        The slopes are computed over the bbox region (or the whole image) and,
        when a water mask exists, restricted to water pixels.
        """
        # Select the region of interest.
        if self.bbox is None:
            im_region = self.im_aligned
            mask_region = self.water_mask
        else:
            ((x1,y1),(x2,y2)) = self.bbox
            im_region = self.im_aligned[y1:y2,x1:x2,:]
            mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None

        # Restrict to water pixels when a non-empty mask exists.
        if mask_region is not None:
            mask_bool = mask_region.astype(bool)
            if mask_bool.any():
                NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
            else:
                # Empty mask: fall back to the full region.
                NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
                mask_bool = None
        else:
            NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
            mask_bool = None

        # One slope per band against the (fixed) NIR vector.
        corr_list = []
        for v in range(self.n_bands):
            if mask_bool is not None and mask_bool.any():
                band_reflectance = im_region[:,:,v][mask_bool]
            else:
                band_reflectance = im_region[:,:,v].ravel()
            corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
            corr_list.append(corr)

        return corr_list

    def _save_corrected_bands(self, corrected_bands):
        """Save corrected bands to an ENVI (BSQ) file.

        Bands are written one at a time straight from the list — no stacking
        into a full cube — to keep memory usage down.

        :param corrected_bands: list of corrected 2-D band arrays
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")

        if self.output_path is None:
            return

        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Take the geometry from the first band (no full-cube np.stack needed).
        height, width = corrected_bands[0].shape
        n_bands = len(corrected_bands)

        # Numpy input carries no georeferencing: use identity-like defaults.
        geotransform = (0, 1, 0, 0, 0, -1)
        projection = ""

        # Force the ENVI (BSQ) format; normalise the extension to .bsq.
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path

        # The ENVI driver writes BSQ by default.
        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

        # Creating the dataset also generates the .hdr sidecar file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)

            # Write each band in order (BSQ: band-sequential storage).
            for i in range(n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_bands[i])
                band.FlushCache()
        finally:
            # Dropping the reference flushes and closes the file.
            dataset = None

        # Report whether the .hdr sidecar was produced.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Apply the Hedley correction (done in reflectance).

        :return: list of corrected 2-D band arrays
        """
        corr = self.correlation_bands_reflectance()
        NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
        # Precompute NIR - R_min once; it is reused for every band.
        NIR_diff = NIR_reflectance - self.R_min

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        corrected_bands = []
        for band_number in range(self.n_bands):  # iterate across bands
            b = corr[band_number]
            R = self.im_aligned[:,:,band_number]
            corrected_band = R - b * NIR_diff

            # With a water mask, only water pixels are corrected; land keeps R.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)

            corrected_bands.append(corrected_band)

        # Persist the result when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)

        return corrected_bands
|
||||
313
src/core/glint_removal/Kutser.py
Normal file
313
src/core/glint_removal/Kutser.py
Normal file
@ -0,0 +1,313 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
class Kutser:
|
||||
def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
|
||||
:param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
|
||||
:param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
|
||||
see Kutser, Vahtmäe and Praks
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.oxy_band = oxy_band
|
||||
self.lower_oxy = lower_oxy
|
||||
self.upper_oxy = upper_oxy
|
||||
self.NIR_band = NIR_band
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Kutser类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def get_depth_D(self):
|
||||
"""
|
||||
Assume the amount of glint is proportional to the depth of the oxygen absorption feature, D
|
||||
returns the normalised D by dividing it by the maximum D found in a deep water region
|
||||
"""
|
||||
# 优化:减少中间数组创建,使用更高效的计算
|
||||
lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
|
||||
upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
|
||||
oxy_band = self.im_aligned[:,:,self.oxy_band]
|
||||
D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
|
||||
|
||||
# 确定用于计算D_max的区域
|
||||
if self.bbox is None:
|
||||
search_region = D
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
search_region = D[y1:y2,x1:x2]
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内搜索最大值
|
||||
if self.water_mask is not None:
|
||||
if self.bbox is None:
|
||||
mask_region = self.water_mask.astype(bool)
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
|
||||
|
||||
if mask_region.any():
|
||||
D_max = search_region[mask_region].max()
|
||||
else:
|
||||
D_max = search_region.max()
|
||||
else:
|
||||
D_max = search_region.max() # assumed to be the maximum glint value
|
||||
|
||||
# 避免除零错误
|
||||
if D_max == 0:
|
||||
return np.zeros_like(D)
|
||||
return D / D_max
|
||||
|
||||
def get_glint_G(self):
|
||||
"""
|
||||
The spectral variation of glint G is found by subtracting the spectrum at the darkest (ie. lowest D) NIR deep-water pixel from the brightest
|
||||
returns G as a function of wavelength
|
||||
"""
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算最大最小值
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 对每个波段,只在掩膜内计算最大最小值
|
||||
G_list = []
|
||||
for i in range(self.n_bands):
|
||||
band_data = im_region[:,:,i]
|
||||
G_max = band_data[mask_bool].max()
|
||||
G_min = band_data[mask_bool].min()
|
||||
G_list.append(G_max - G_min)
|
||||
else:
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
G_max = np.amax(im_region, axis=(0, 1))
|
||||
G_min = np.amin(im_region, axis=(0, 1))
|
||||
G_list = (G_max - G_min).tolist()
|
||||
else:
|
||||
# 优化:一次性计算所有波段的最大最小值,减少循环开销
|
||||
# 使用numpy的amax和amin沿最后一个轴计算
|
||||
G_max = np.amax(im_region, axis=(0, 1)) # 沿空间维度计算最大值
|
||||
G_min = np.amin(im_region, axis=(0, 1)) # 沿空间维度计算最小值
|
||||
G_list = (G_max - G_min).tolist()
|
||||
return G_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
|
||||
def get_corrected_bands(self):
    """Apply the Goodman glint correction (performed in reflectance space).

    :return: list of corrected 2-D band arrays, in band order
    """
    g_list = self.get_glint_G()
    D = self.get_depth_D()

    # Boolean water mask, or None when the whole scene is processed.
    mask = None if self.water_mask is None else self.water_mask.astype(bool)

    corrected_bands = []
    for idx in range(self.n_bands):
        reflectance = self.im_aligned[:, :, idx]
        # Per-pixel correction: subtract the glint term scaled by depth.
        deglinted = reflectance - g_list[idx] * D

        # Outside the water mask the original reflectance is kept untouched.
        if mask is not None:
            deglinted = np.where(mask, deglinted, reflectance)

        corrected_bands.append(deglinted)

    # Persist to disk when an output path was supplied.
    if self.output_path is not None:
        self._save_corrected_bands(corrected_bands)

    return corrected_bands
572
src/core/glint_removal/SUGAR.py
Normal file
572
src/core/glint_removal/SUGAR.py
Normal file
@ -0,0 +1,572 @@
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from scipy import ndimage
|
||||
from scipy.optimize import minimize_scalar
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# SUn-Glint-Aware Restoration (SUGAR):A sweet and simple algorithm for correcting sunglint
|
||||
class SUGAR:
    """SUn-Glint-Aware Restoration (SUGAR): a sweet and simple algorithm for
    correcting sunglint via Laplacian-of-Gaussian glint detection."""

    def __init__(self, im_aligned, bounds=None, sigma=1, estimate_background=True,
                 glint_mask_method="cdf", water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray): band-aligned, calibrated & corrected reflectance image
        :param bounds (list of tuple or None): (lower, upper) bound for the optimisation of b
            per band; defaults to [(1, 2)]. (None-sentinel replaces the original mutable
            default argument ``[(1, 2)]``.)
        :param sigma (float): smoothing sigma for the Laplacian of Gaussian (LoG)
        :param estimate_background (bool): whether to estimate background spectra with a
            morphological (erosion) filter instead of a global percentile
        :param glint_mask_method (str): either "cdf" or "otsu"; "cdf" is the default
        :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water.
            May be a numpy array or a raster file path (.dat/.tif); None processes the full image.
        :param output_path (str or None): where to save the corrected image; None disables saving
        """
        if bounds is None:
            bounds = [(1, 2)]
        self.im_aligned = im_aligned
        self.sigma = sigma
        self.estimate_background = estimate_background
        self.n_bands = im_aligned.shape[-1]
        # One (lower, upper) pair per band (replicates the single default pair).
        self.bounds = bounds * self.n_bands
        self.glint_mask_method = glint_mask_method
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path

        # Load / normalise the water mask up front.
        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """
        Load and binarise the water mask.

        :param water_mask: None, a numpy array, a raster path (.dat/.tif) or a shapefile (.shp)
        :return: uint8 array (1 = water, 0 = non-water), or None
        :raises ValueError: on size mismatch, unreadable file, shp input or unsupported type
        """
        if water_mask is None:
            return None

        # Already an in-memory array: validate the shape and binarise.
        if isinstance(water_mask, np.ndarray):
            if water_mask.shape[:2] != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)

        # A path: GDAL is required to read it.
        if isinstance(water_mask, str):
            try:
                from osgeo import gdal, ogr
            except ImportError:
                raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")

            if water_mask.lower().endswith('.shp'):
                # Rasterising a shapefile needs a raster reference, which a
                # plain numpy input cannot provide.
                raise ValueError("SUGAR类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
            else:
                # Raster mask file.
                mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
                if mask_dataset is None:
                    raise ValueError(f"无法打开掩膜文件: {water_mask}")

                mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
                mask_dataset = None

                if mask_array.shape != (self.height, self.width):
                    raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")

                return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def otsu_thresholding(self, im):
        """
        :param im (np.ndarray): shape m x n; expected to be the LoG of an image.
        Otsu thresholding via Brent's bounded minimisation of a univariate
        function. Returns the threshold value for the input.
        """
        auto_bins = int(0.005 * im.shape[0] * im.shape[1])
        # ravel() avoids an unnecessary copy where possible; drop NaN/inf.
        im_flat = im.ravel()
        valid_mask = np.isfinite(im_flat)
        if not valid_mask.all():
            im_flat = im_flat[valid_mask]
        count, bin_edges = np.histogram(im_flat, bins=auto_bins)
        # Renamed from 'bin', which shadowed the builtin.
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5

        count_sum = count.sum()
        hist_norm = count / count_sum  # normalised histogram
        Q = hist_norm.cumsum()  # CDF, ranges from 0 to 1
        N = count.shape[0]
        N_negative = np.sum(bin_centers < 0)
        bins = np.arange(N, dtype=np.float32)

        def otsu_thresh(x):
            # Intra-class variance for a split at bin index x.
            x = int(x)
            p1 = hist_norm[:x]
            p2 = hist_norm[x:]
            q1 = Q[x]
            q2 = Q[N - 1] - Q[x]
            b1 = bins[:x]
            b2 = bins[x:]
            m1 = np.sum(p1 * b1) / q1 if q1 > 0 else 0
            m2 = np.sum(p2 * b2) / q2 if q2 > 0 else 0
            v1 = np.sum(((b1 - m1) ** 2) * p1) / q1 if q1 > 0 else 0
            v2 = np.sum(((b2 - m2) ** 2) * p2) / q2 if q2 > 0 else 0
            return v1 * q1 + v2 * q2

        # Glint pixels have LoG < 0, so the search is restricted to negative bins.
        if N_negative <= 1:
            # Not enough negative bins: fall back to the histogram mode.
            return bin_centers[np.argmax(count)]
        res = minimize_scalar(otsu_thresh, bounds=(1, N_negative), method='bounded')
        return bin_centers[int(res.x)]

    def cdf_thresholding(self, im, auto_bins=10):
        """
        :param im (np.ndarray): shape m x n; expected to be the LoG of an image.
        :param auto_bins (int): number of histogram bins.
        Returns the centre of the most populated histogram bin.
        """
        im_flat = im.ravel()
        # Filter out NaN and infinite values before histogramming.
        valid_mask = np.isfinite(im_flat)
        if not valid_mask.all():
            im_flat = im_flat[valid_mask]
        count, bin_edges = np.histogram(im_flat, bins=auto_bins)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
        return bin_centers[np.argmax(count)]

    def glint_list(self):
        """Return, per band, the glint pixels extracted with get_glint_mask."""
        glint_mask = self.glint_mask_list()
        extracted_glint_list = []
        for i in range(self.im_aligned.shape[-1]):
            extracted_glint_list.append(glint_mask[i] * self.im_aligned[:, :, i])
        return extracted_glint_list

    def glint_mask_list(self):
        """Return the LoG-based glint mask of every band (list of np.ndarray)."""
        return [self.get_glint_mask(self.im_aligned[:, :, i])
                for i in range(self.im_aligned.shape[-1])]

    def log_image_list(self):
        """Return the Laplacian-of-Gaussian (LoG) image of every band."""
        return [self.get_log_image(self.im_aligned[:, :, i])
                for i in range(self.im_aligned.shape[-1])]

    def get_log_image(self, im):
        """Laplacian of Gaussian (LoG) of a single band."""
        return ndimage.gaussian_laplace(im, sigma=self.sigma)

    def get_glint_mask(self, im):
        """
        Build a glint mask from the LoG image. Water constituents follow a
        smooth continuum, whereas glint varies strongly both spatially and in
        intensity, so strongly negative LoG responses flag glint pixels.
        Note that for very extensive glint this may under-perform.
        # TODO: use a U-net to identify the glint mask
        Returns a uint8 np.ndarray (1 = glint).
        """
        LoG_im = ndimage.gaussian_laplace(im, sigma=self.sigma)

        # With a water mask, derive the threshold from water pixels only.
        if self.water_mask is not None:
            mask_bool = self.water_mask.astype(bool)
            if mask_bool.any():
                LoG_masked = LoG_im[mask_bool]
                # Push non-water pixels above the water maximum so they cannot
                # influence the histogram-based threshold.
                LoG_for_thresh = LoG_im.copy()
                LoG_for_thresh[~mask_bool] = LoG_masked.max() + 1
            else:
                LoG_for_thresh = LoG_im
        else:
            LoG_for_thresh = LoG_im

        if (self.glint_mask_method == "otsu"):
            thresh = self.otsu_thresholding(LoG_for_thresh)
        elif (self.glint_mask_method == "cdf"):
            thresh = self.cdf_thresholding(LoG_for_thresh)
        else:
            raise ValueError('Enter only cdf or otsu as glint_mask_method')

        glint_mask = (LoG_im < thresh).astype(np.uint8)

        # Zero-out everything outside the water mask.
        if self.water_mask is not None:
            glint_mask = glint_mask * self.water_mask

        return glint_mask

    def get_est_background(self, im, k_size=5):
        """
        :param im (np.ndarray): image of one band
        :param k_size (int): erosion kernel size
        Estimate the background spectrum via grayscale erosion.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
        return cv2.erode(im, kernel)

    def optimise_correction_by_band(self, im, glint_mask, R_BG, bounds):
        """
        :param im (np.ndarray): image of one band
        :param glint_mask (np.ndarray): glint mask (1 = glint, 0 = non-glint)
        :param R_BG: background reflectance (scalar or per-pixel array)
        :param bounds (tuple): (lower, upper) bounds for b
        Uses Brent's bounded method to find the b that minimises the variance
        of the corrected image. Returns the regression slope b.
        """
        # Precompute loop invariants so the objective stays cheap.
        glint_mask_bool = glint_mask.astype(bool)
        scalar_bg = isinstance(R_BG, (int, float))
        bg_glint = R_BG if scalar_bg else R_BG[glint_mask_bool]
        im_glint = im[glint_mask_bool]
        weights = glint_mask[glint_mask_bool]

        def objective(b):
            # Only glint pixels change: R' = R - g * (R / b - R_BG)
            corrected = im.copy()
            corrected[glint_mask_bool] = im_glint - weights * (im_glint / b - bg_glint)
            return np.var(corrected)

        res = minimize_scalar(objective, bounds=bounds, method='bounded')
        return res.x

    def divide_and_conquer(self):
        """
        Planned optimisation: instead of computing b_list for each window, use
        the previous b_list to narrow the bounds — the strong spatial
        autocorrelation means b (the correction magnitude) cannot differ too
        much between neighbouring windows, which would cut the run time.

        NOTE(review): not implemented yet; calling this is a no-op.
        """

    def optimise_correction(self):
        """
        Optimise the correction slope per band.

        :return: (b_list, glint_mask_list, est_background_list) in band order
            0..n_bands-1; the three lists are also stored as the attributes
            ``b_list``, ``glint_mask`` and ``est_background``.
        """
        b_list = []
        glint_mask_list = []
        est_background_list = []
        for i in range(self.n_bands):
            band = self.im_aligned[:, :, i]
            glint_mask = self.get_glint_mask(band)
            glint_mask_list.append(glint_mask)
            if self.estimate_background is True:
                est_background = self.get_est_background(band)
            else:
                # 5th percentile as a global background estimate.
                # NumPy >= 1.22 renamed 'interpolation' to 'method' (the old
                # keyword was removed in NumPy 2.0); support both.
                try:
                    est_background = np.percentile(band, 5, method='nearest')
                except TypeError:
                    est_background = np.percentile(band, 5, interpolation='nearest')
            est_background_list.append(est_background)
            b = self.optimise_correction_by_band(band, glint_mask, est_background, self.bounds[i])
            b_list.append(b)

        # Expose the results as attributes too.
        self.b_list = b_list
        self.glint_mask = glint_mask_list
        self.est_background = est_background_list

        return b_list, glint_mask_list, est_background_list

    def _save_corrected_bands(self, corrected_bands):
        """
        Save the corrected bands to disk as an ENVI (BSQ-interleaved) raster.
        The extension is normalised to .bsq; GDAL's ENVI driver writes the
        .hdr header automatically.

        :param corrected_bands: list of 2-D arrays (one per band)
        :raises ImportError: if GDAL is not installed
        :raises ValueError: if the ENVI driver or output file cannot be created
        """
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")

        if self.output_path is None:
            return

        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)

        # Stack the per-band arrays into (height, width, bands).
        corrected_array = np.stack(corrected_bands, axis=2)

        # No georeferencing is available here; use identity-style defaults.
        geotransform = (0, 1, 0, 0, 0, -1)
        projection = ""

        # Normalise the extension to .bsq.
        base_path, ext = os.path.splitext(self.output_path)
        if ext.lower() != '.bsq':
            bsq_path = base_path + '.bsq'
        else:
            bsq_path = self.output_path

        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

        height, width, n_bands = corrected_array.shape
        # Creating the dataset also generates the .hdr sidecar file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)

            # Write each band in order (BSQ storage).
            for i in range(n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_array[:, :, i])
                band.FlushCache()
        finally:
            # Dereferencing the dataset flushes and closes the file.
            dataset = None

        # Report whether the .hdr sidecar was produced.
        hdr_path = bsq_path + '.hdr'
        if os.path.exists(hdr_path):
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"头文件已保存至: {hdr_path}")
        else:
            print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
            print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """
        Correct each band by replacing glint pixels with the estimated
        background (correction restricted to the water mask when one is set).

        :return: list of corrected 2-D band arrays
        """
        water_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        corrected_bands = []
        for i in range(self.n_bands):
            im_band = self.im_aligned[:, :, i]
            glint_bool = self.get_glint_mask(im_band).astype(bool)
            background = self.get_est_background(im_band, k_size=5)

            # Replacement set: glint pixels, intersected with water when a
            # mask exists. (Single np.where replaces the original's
            # redundant double computation; results are identical.)
            if water_bool is not None:
                replace = glint_bool & water_bool
            else:
                replace = glint_bool
            corrected_bands.append(np.where(replace, background, im_band))

        # Persist to disk when an output path was supplied.
        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)

        return corrected_bands
def correction_iterative(im_aligned,iter=3,bounds = [(1,2)],estimate_background=True,glint_mask_method="cdf",get_glint_mask=False,termination_thresh = 20, water_mask=None, output_path=None):
    """
    Conducts iterative glint correction using SUGAR.

    :param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
    :param iter (int or None): number of iterations to run the sugar algorithm. If None,
        termination conditions are applied automatically.
        NOTE(review): the parameter name shadows the builtin ``iter``.
    :param bounds (list of tuples): to limit the correction magnitude.
        NOTE(review): mutable default argument; harmless here only because it is never mutated.
    :param estimate_background (bool): forwarded to SUGAR
    :param glint_mask_method (str): forwarded to SUGAR ("cdf" or "otsu")
    :param get_glint_mask (bool): when True, every intermediate image is snapshot-copied
    :param termination_thresh (float): percentage threshold used by both termination tests
    :param water_mask (np.ndarray or str or None): water mask, 1 = water, 0 = non-water;
        may be a numpy array or a raster file path (.dat/.tif). None processes the whole image.
    :param output_path (str or None): if given, the last iteration's result is saved
    :return: list of corrected images (np.ndarray), one per iteration
    """
    glint_image = im_aligned.copy()
    corrected_images = []

    if iter is None:
        # Termination conditions: stop when the image variance stops dropping meaningfully.
        relative_difference = lambda sd0,sd1: sd1/sd0*100
        marginal_difference = lambda sd1,sd2: (sd1-sd2)/sd1*100
        relative_diff_thresh = marginal_difference_thresh = termination_thresh
        sd_og = np.var(im_aligned)
        iter_count = 0
        sd_next = sd_og  # plain value; no copy needed
        max_iter = 100  # hard cap to guard against non-convergence

        while ((relative_difference(sd_og,sd_next) > relative_diff_thresh) and iter_count < max_iter):
            # One SUGAR pass over the current image.
            HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
            corrected_bands = HM.get_corrected_bands()
            glint_image = np.stack(corrected_bands,axis=2)
            sd_temp = np.var(glint_image)
            # Snapshot-copy only when the intermediate must survive later overwrites.
            if get_glint_mask or iter_count == 0:
                corrected_images.append(glint_image.copy())
            else:
                corrected_images.append(glint_image)  # result of the latest iteration
            # save glint_mask
            # if iter_count == 0 and get_glint_mask is True:
            #     glint_mask = np.stack(HM.glint_mask,axis=2)
            # Stop when the variance reduction of this pass fell below the threshold.
            if (marginal_difference(sd_next,sd_temp)<marginal_difference_thresh):
                break
            else:
                sd_next = sd_temp
            # increase count
            iter_count += 1

        # Save the last iteration's result when an output path was supplied.
        if output_path is not None and len(corrected_images) > 0:
            _save_corrected_image(corrected_images[-1], output_path)

    else:
        # Fixed number of iterations.
        for i in range(iter):
            HM = SUGAR(glint_image,bounds,estimate_background=estimate_background, glint_mask_method=glint_mask_method, water_mask=water_mask)
            corrected_bands = HM.get_corrected_bands()
            glint_image = np.stack(corrected_bands,axis=2)
            # Copy on the final iteration (or always, when masks are requested).
            if i == iter - 1 or get_glint_mask:
                corrected_images.append(glint_image.copy())
            else:
                # Intermediate iterations keep a reference only (mind memory use).
                corrected_images.append(glint_image)
            # save glint_mask
            # if i == 0 and get_glint_mask is True:
            #     glint_mask = np.stack(HM.glint_mask,axis=2)

        # Save the last iteration's result when an output path was supplied.
        if output_path is not None and len(corrected_images) > 0:
            _save_corrected_image(corrected_images[-1], output_path)

    return corrected_images
def _save_corrected_image(corrected_image, output_path):
    """Write a (height, width, bands) array to disk as an ENVI/BSQ raster.

    Companion saver for correction_iterative. The file extension is
    normalised to ``.bsq``; the ENVI driver emits the ``.hdr`` header itself.

    :param corrected_image: corrected image array, shape (height, width, bands)
    :param output_path: destination path
    :raises ImportError: if GDAL is not installed
    :raises ValueError: if the ENVI driver or output file cannot be created
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装,无法保存影像文件")

    if output_path is None:
        return

    # Create the parent directory when necessary.
    parent = os.path.dirname(output_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent, exist_ok=True)

    # Placeholder georeferencing (none is available at this point).
    default_gt = (0, 1, 0, 0, 0, -1)
    default_proj = ""

    # Normalise the extension to .bsq.
    root, extension = os.path.splitext(output_path)
    target = output_path if extension.lower() == '.bsq' else root + '.bsq'

    # The ENVI driver stores band-sequential (BSQ) by default.
    envi_driver = gdal.GetDriverByName('ENVI')
    if envi_driver is None:
        raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

    rows, cols, band_count = corrected_image.shape
    # Creating the dataset also generates the .hdr sidecar file.
    out_ds = envi_driver.Create(target, cols, rows, band_count, gdal.GDT_Float32)
    if out_ds is None:
        raise ValueError(f"无法创建输出文件: {target}")

    try:
        if default_gt:
            out_ds.SetGeoTransform(default_gt)
        if default_proj:
            out_ds.SetProjection(default_proj)

        # Write the bands one after another (BSQ storage order).
        for b in range(band_count):
            out_band = out_ds.GetRasterBand(b + 1)
            out_band.WriteArray(corrected_image[:, :, b])
            out_band.FlushCache()
    finally:
        # Dereferencing the dataset flushes and closes the file.
        out_ds = None

    # Report whether the .hdr sidecar was produced.
    header = target + '.hdr'
    if os.path.exists(header):
        print(f"校正后的图像已保存至: {target} (BSQ格式)")
        print(f"头文件已保存至: {header}")
    else:
        print(f"校正后的图像已保存至: {target} (BSQ格式)")
        print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
1
src/core/glint_removal/__init__.py
Normal file
1
src/core/glint_removal/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
926
src/core/glint_removal/get_spectral-test.py
Normal file
926
src/core/glint_removal/get_spectral-test.py
Normal file
@ -0,0 +1,926 @@
|
||||
from osgeo import gdal, osr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import spectral
|
||||
from math import sin, cos, tan, sqrt, radians
|
||||
|
||||
try:
|
||||
from scipy.ndimage import distance_transform_edt
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
# Enable GDAL/OSR exceptions instead of silent error codes.
osr.UseExceptions()

# WGS84 ellipsoid parameters
WGS84_A = 6378137.0  # semi-major axis (metres)
WGS84_F = 1 / 298.257223563  # flattening
WGS84_E2 = WGS84_F * (2 - WGS84_F)  # first eccentricity squared
WGS84_EP2 = WGS84_E2 / (1 - WGS84_E2)  # second eccentricity squared
UTM_K0 = 0.9996  # UTM central-meridian scale factor
def pixel_to_geo(pixel_x, pixel_y, geotransform):
    """Convert pixel coordinates to map coordinates via the GDAL affine transform.

    Args:
        pixel_x: column index
        pixel_y: row index
        geotransform: 6-element GDAL geotransform

    Returns:
        (geo_x, geo_y) map coordinates
    """
    origin_x, px_width, rot_x, origin_y, rot_y, px_height = geotransform
    geo_x = origin_x + pixel_x * px_width + pixel_y * rot_x
    geo_y = origin_y + pixel_x * rot_y + pixel_y * px_height
    return geo_x, geo_y
def prepare_boundary_adjuster(boundary_mask):
    """Build helper structures for shifting sampling centres inside a water mask.

    Args:
        boundary_mask: 2-D mask, >0 inside water; may be None

    Returns:
        dict with 'mask' (bool array), 'distance_map' (Euclidean distance
        transform) and an empty per-radius KD-tree cache under 'trees', or
        None when SciPy is missing, the mask is None, or it contains no water.
    """
    if not SCIPY_AVAILABLE:
        print("警告: 未安装SciPy,无法根据水体边界自动调整采样点位置。")
        return None

    if boundary_mask is None:
        return None

    inside = boundary_mask > 0
    if not np.any(inside):
        print("警告: 边界掩膜中未检测到有效水域,无法调整采样点。")
        return None

    # Distance (in pixels) from each water pixel to the nearest boundary.
    edt = distance_transform_edt(inside.astype(np.uint8))
    return {
        'mask': inside,
        'distance_map': edt,
        'trees': {}
    }
def _get_boundary_tree(adjuster, radius):
|
||||
"""
|
||||
根据半径获取或构建适用的KDTree
|
||||
"""
|
||||
radius_key = float(radius)
|
||||
if radius_key in adjuster['trees']:
|
||||
return adjuster['trees'][radius_key]
|
||||
|
||||
distance_map = adjuster['distance_map']
|
||||
valid_positions = np.column_stack(np.where(distance_map >= radius_key))
|
||||
if valid_positions.size == 0:
|
||||
adjuster['trees'][radius_key] = None
|
||||
return None
|
||||
|
||||
tree = cKDTree(valid_positions)
|
||||
adjuster['trees'][radius_key] = (tree, valid_positions)
|
||||
return adjuster['trees'][radius_key]
|
||||
|
||||
|
||||
def adjust_sampling_center(pixel_x, pixel_y, radius, adjuster):
    """
    Shift a sampling centre inland so that a circle of ``radius`` pixels fits
    entirely inside the water body (ideally ending up tangent to the boundary).

    :param pixel_x: column of the requested centre
    :param pixel_y: row of the requested centre
    :param radius: sampling radius in pixels
    :param adjuster: structure built by prepare_boundary_adjuster (or None)
    :return: (new_x, new_y, moved) where ``moved`` is True only when the
        centre was actually relocated
    """
    # No adjuster available or degenerate radius: keep the point as-is.
    if adjuster is None or radius <= 0:
        return pixel_x, pixel_y, False

    distance_map = adjuster['distance_map']
    mask = adjuster['mask']

    # Out-of-raster coordinates are returned unchanged.
    if pixel_y < 0 or pixel_y >= distance_map.shape[0] or pixel_x < 0 or pixel_x >= distance_map.shape[1]:
        return pixel_x, pixel_y, False

    if not mask[pixel_y, pixel_x]:
        # Centre lies outside the water: find the nearest legal position.
        # max(radius, 1) guards against sub-pixel radii.
        tree_info = _get_boundary_tree(adjuster, max(radius, 1))
        if tree_info is None:
            return pixel_x, pixel_y, False
    else:
        # Already far enough from the boundary: nothing to do.
        if distance_map[pixel_y, pixel_x] >= radius:
            return pixel_x, pixel_y, False
        tree_info = _get_boundary_tree(adjuster, radius)
        if tree_info is None:
            # No pixel anywhere in the mask can accommodate this radius.
            return pixel_x, pixel_y, False

    tree, valid_positions = tree_info
    if tree is None or valid_positions.size == 0:
        return pixel_x, pixel_y, False

    # Query a handful of nearby candidate positions.
    max_candidates = min(64, len(valid_positions))
    distances, indices = tree.query([pixel_y, pixel_x], k=max_candidates)

    # KDTree.query returns a scalar when k == 1; normalise to a 1-D index list.
    if np.isscalar(indices):
        indices = [int(indices)]
    else:
        indices = np.atleast_1d(indices).astype(int)

    best_candidate = None
    best_delta = None

    # Score = (closeness to boundary-tangency, squared displacement from the
    # requested centre); lexicographic comparison picks tangent-and-near first.
    for idx in indices:
        cy, cx = valid_positions[idx]
        if distance_map[cy, cx] < radius:
            continue
        delta = distance_map[cy, cx] - radius
        center_shift = (cx - pixel_x) ** 2 + (cy - pixel_y) ** 2
        score = (abs(delta), center_shift)
        if best_candidate is None or score < best_delta:
            best_candidate = (cx, cy)
            best_delta = score

    if best_candidate is None:
        # No candidate satisfied the radius constraint.
        return pixel_x, pixel_y, False

    return int(best_candidate[0]), int(best_candidate[1]), True
def transform_coordinates(lon, lat, source_srs, target_srs):
    """Reproject a single point between two OSR spatial reference systems.

    Args:
        lon: x coordinate (longitude/easting) in the source SRS
        lat: y coordinate (latitude/northing) in the source SRS
        source_srs: source osr.SpatialReference
        target_srs: target osr.SpatialReference

    Returns:
        (x, y) of the transformed point
    """
    # Build the converter and run the single-point transformation.
    converter = osr.CoordinateTransformation(source_srs, target_srs)
    transformed = converter.TransformPoint(lon, lat)
    return transformed[0], transformed[1]
def geo_to_pixel(lon, lat, geotransform, dataset_srs=None):
    """Convert map coordinates to pixel coordinates (inverse geotransform).

    Inverts the full 2x2 affine part of the geotransform, so rotated/sheared
    rasters are handled correctly (the original ignored geotransform[2] and
    geotransform[4], although its counterpart pixel_to_geo uses them). For the
    common north-up case (both rotation terms zero) the result is identical
    to the simple per-axis division.

    Args:
        lon: x map coordinate (e.g. longitude or easting)
        lat: y map coordinate (e.g. latitude or northing)
        geotransform: 6-element GDAL geotransform
        dataset_srs: unused; kept for backward compatibility

    Returns:
        (pixel_x, pixel_y) as ints (truncated toward zero, matching the
        original behaviour)
    """
    x_origin = geotransform[0]
    pixel_width = geotransform[1]
    rot_x = geotransform[2]
    y_origin = geotransform[3]
    rot_y = geotransform[4]
    pixel_height = geotransform[5]

    dx = lon - x_origin
    dy = lat - y_origin

    # Determinant of the 2x2 affine matrix [[pixel_width, rot_x],
    # [rot_y, pixel_height]]; a zero determinant raises ZeroDivisionError,
    # just as the original did for a zero pixel size.
    det = pixel_width * pixel_height - rot_x * rot_y

    pixel_x = int((dx * pixel_height - dy * rot_x) / det)
    pixel_y = int((dy * pixel_width - dx * rot_y) / det)

    return pixel_x, pixel_y
def get_pixel_spectrum_batch(dataset, pixel_x_array, pixel_y_array):
    """Fetch spectra for many pixels at once via band-by-band bulk reads.

    Args:
        dataset: GDAL dataset
        pixel_x_array: pixel x coordinates
        pixel_y_array: pixel y coordinates

    Returns:
        float32 array of shape (n_points, n_bands); out-of-bounds pixels
        are filled with NaN
    """
    n_points = len(pixel_x_array)
    n_bands = dataset.RasterCount

    spectra = np.zeros((n_points, n_bands), dtype=np.float32)

    # Pre-cast the coordinates once instead of per band.
    coords = [(int(pixel_x_array[j]), int(pixel_y_array[j])) for j in range(n_points)]

    for b in range(n_bands):
        # GDAL band indices start at 1; reading the whole band once is
        # cheaper than n_points individual pixel reads.
        data = dataset.GetRasterBand(b + 1).ReadAsArray()
        rows, cols = data.shape[0], data.shape[1]
        for j, (px, py) in enumerate(coords):
            if 0 <= px < cols and 0 <= py < rows:
                spectra[j, b] = data[py, px]
            else:
                spectra[j, b] = np.nan

    return spectra
def get_average_spectral_in_radius(dataset, center_x, center_y, radius, flare_mask=None, boundary_mask=None):
    """
    Average spectrum within a circular window, skipping flare/boundary pixels.

    Args:
        dataset: GDAL dataset
        center_x, center_y: centre pixel coordinates
        radius: radius in pixels
        flare_mask: optional flare mask array (value 0 is treated as "no flare")
        boundary_mask: optional boundary/water mask array.
            NOTE(review): pixels are kept where this mask equals 1, although
            the original inline comment claimed 0 means "no boundary" --
            confirm the mask convention with its producer.

    Returns:
        1-D array of per-band mean values (all zeros when nothing valid found)
    """
    num_bands = dataset.RasterCount

    # Clip the sampling window to the raster extent.
    x_start = max(0, center_x - radius)
    x_end = min(dataset.RasterXSize, center_x + radius + 1)
    y_start = max(0, center_y - radius)
    y_end = min(dataset.RasterYSize, center_y + radius + 1)

    # Window dimensions after clipping.
    width = x_end - x_start
    height = y_end - y_start

    if width <= 0 or height <= 0:
        return np.zeros(num_bands)

    # Read all bands of the window in one call.
    spectral_data = dataset.ReadAsArray(x_start, y_start, width, height)
    if spectral_data is None:
        return np.zeros(num_bands)

    # Normalise to 3-D (bands, height, width) for the single-band case.
    if len(spectral_data.shape) == 2:
        spectral_data = spectral_data.reshape(1, spectral_data.shape[0], spectral_data.shape[1])

    # Circular mask around the (window-local) centre.
    y_indices, x_indices = np.ogrid[:height, :width]
    center_x_local = center_x - x_start
    center_y_local = center_y - y_start

    distance_mask = ((x_indices - center_x_local) ** 2 + (y_indices - center_y_local) ** 2) <= radius ** 2

    # Exclude flare pixels (mask value 0 = no flare).
    if flare_mask is not None:
        flare_region = flare_mask[y_start:y_end, x_start:x_end]
        if flare_region.shape == distance_mask.shape:
            distance_mask = distance_mask & (flare_region == 0)

    # Keep only pixels flagged 1 in the boundary mask (see NOTE above).
    if boundary_mask is not None:
        boundary_region = boundary_mask[y_start:y_end, x_start:x_end]
        if boundary_region.shape == distance_mask.shape:
            distance_mask = distance_mask & (boundary_region == 1)

    # Per-band mean over the surviving pixels.
    average_spectrum = np.zeros(num_bands)
    valid_pixels = np.sum(distance_mask)

    if valid_pixels > 0:
        for band in range(num_bands):
            band_data = spectral_data[band, :, :]
            # Drop zeros and non-finite values.
            # NOTE(review): treating exactly-zero values as invalid assumes
            # 0 is the nodata fill; confirm if 0 can be a real reflectance.
            valid_data = band_data[distance_mask & (band_data != 0) & np.isfinite(band_data)]
            if len(valid_data) > 0:
                average_spectrum[band] = np.mean(valid_data)

    return average_spectrum
def load_mask_file(mask_path):
    """Load a raster mask from disk.

    Args:
        mask_path: path to a raster mask (.dat/.tif etc.); None is tolerated

    Returns:
        2-D mask array, or None when the path is missing or not a valid raster
    """
    if mask_path is None or not os.path.exists(mask_path):
        return None

    try:
        # OpenEx with OF_RASTER rejects vector sources up front, avoiding the
        # "multiple layers" error that plain Open can hit on shapefile-like inputs.
        ds = gdal.OpenEx(mask_path, gdal.OF_RASTER)
        if ds is None:
            # Fall back to the classic opener for backward compatibility.
            ds = gdal.Open(mask_path, gdal.GA_ReadOnly)
            if ds is None:
                print(f"警告: 无法打开掩膜文件 {mask_path},可能不是有效的栅格文件")
                return None

        # Reject datasets that carry no raster bands.
        if not hasattr(ds, 'RasterCount') or ds.RasterCount == 0:
            print(f"警告: {mask_path} 不是有效的栅格文件")
            del ds
            return None

        data = ds.GetRasterBand(1).ReadAsArray()
        del ds
        return data
    except Exception as e:
        # Best-effort loader: report the problem and signal "no mask".
        print(f"警告: 加载掩膜文件 {mask_path} 时出错: {str(e)}")
        return None
def get_hdr_file_path(file_path):
    """Return the ENVI header (.hdr) path that pairs with *file_path*.

    The header sits next to the image file: same base name, '.hdr' suffix.
    """
    base, _ext = os.path.splitext(file_path)
    return base + ".hdr"
|
||||
|
||||
|
||||
def calculate_utm_zone(longitude):
    """Return the UTM zone number (1-60) for a longitude in degrees.

    Zones are 6 degrees wide, numbered eastwards starting at 180°W;
    the result is clamped into the valid 1-60 range.
    """
    zone = 1 + int((longitude + 180) / 6)
    return min(60, max(1, zone))
|
||||
|
||||
|
||||
def latlon_to_utm_math(lat_deg, lon_deg, zone=None):
    """Convert WGS84 geographic coordinates to UTM using closed-form math.

    Implements the standard transverse-Mercator series expansion
    (Snyder's formulas) instead of going through a projection library.

    Args:
        lat_deg: latitude in degrees.
        lon_deg: longitude in degrees.
        zone: UTM zone number; when None it is derived from the longitude.

    Returns:
        (easting, northing) in metres.  Southern-hemisphere northings
        carry the conventional 10,000,000 m false northing.
    """
    # Derive the zone from the longitude when not given explicitly.
    if zone is None:
        zone = calculate_utm_zone(lon_deg)

    # Central meridian of the zone (degrees), then radians.
    lon0 = (zone * 6 - 183)
    lam0 = radians(lon0)

    # Work in radians from here on.
    phi = radians(lat_deg)
    lam = radians(lon_deg)

    # Trigonometric terms reused below.
    sinphi = sin(phi)
    cosphi = cos(phi)
    tanphi = tan(phi)

    # Radius of curvature in the prime vertical.
    N = WGS84_A / sqrt(1 - WGS84_E2 * sinphi * sinphi)

    T = tanphi * tanphi
    C = WGS84_EP2 * cosphi * cosphi
    A = cosphi * (lam - lam0)

    # Meridional arc length (Snyder's series in the eccentricity).
    M = (WGS84_A * ((1 - WGS84_E2/4 - 3*WGS84_E2**2/64 - 5*WGS84_E2**3/256) * phi
                    - (3*WGS84_E2/8 + 3*WGS84_E2**2/32 + 45*WGS84_E2**3/1024) * sin(2*phi)
                    + (15*WGS84_E2**2/256 + 45*WGS84_E2**3/1024) * sin(4*phi)
                    - (35*WGS84_E2**3/3072) * sin(6*phi)))

    # Easting, including the 500,000 m false easting.
    E = (UTM_K0 * N * (A + (1 - T + C) * A**3 / 6
                       + (5 - 18*T + T*T + 72*C - 58*WGS84_EP2) * A**5 / 120)
         + 500000.0)

    # Northing; the southern hemisphere gets a 10,000,000 m false northing.
    if lat_deg < 0:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720))
              + 10000000.0)
    else:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720)))

    return E, Nn
|
||||
|
||||
|
||||
def convert_to_utm(lon, lat, source_epsg=4326, target_epsg=None):
    """Convert arrays of lon/lat to UTM metres via latlon_to_utm_math.

    Args:
        lon: array of longitudes (degrees).
        lat: array of latitudes (degrees).
        source_epsg: EPSG code of the input CRS; only 4326 (WGS84) is
            truly supported by the closed-form math — anything else is
            still converted, with a printed warning.
        target_epsg: optional WGS84/UTM EPSG code (326xx north, 327xx
            south); when given, its zone number is applied to every
            point, otherwise each point's zone is derived from its
            longitude.

    Returns:
        (utm_x, utm_y) arrays in metres; invalid or failed points are NaN.
    """
    try:
        # The series expansion assumes WGS84 — warn when it is not.
        if source_epsg != 4326:
            print(f"警告: 数学公式转换仅支持WGS84 (EPSG:4326),当前源坐标系为EPSG:{source_epsg}")
            print("将尝试使用数学公式进行转换,但可能不准确")

        # Output buffers with the same shape as the inputs.
        utm_x = np.zeros_like(lon)
        utm_y = np.zeros_like(lat)

        # Extract a fixed zone number from the target EPSG when provided.
        fixed_zone = None
        if target_epsg is not None:
            # EPSG:32651 -> zone 51 (north), EPSG:32751 -> zone 51 (south).
            if 32601 <= target_epsg <= 32660:
                fixed_zone = target_epsg - 32600
            elif 32701 <= target_epsg <= 32760:
                fixed_zone = target_epsg - 32700
            else:
                print(f"警告: 无法从EPSG代码 {target_epsg} 提取UTM分区号,将根据经度自动计算")

        # Vectorized flagging of NaN / out-of-range coordinates.
        invalid_mask = (np.isnan(lon) | np.isnan(lat) |
                        (lon < -180) | (lon > 180) |
                        (lat < -90) | (lat > 90))

        # Report invalid points (first 10 only, to keep the log short).
        invalid_count = np.sum(invalid_mask)
        if invalid_count > 0:
            invalid_indices = np.where(invalid_mask)[0]
            print(f"警告: 发现 {invalid_count} 个无效坐标点,将跳过")
            for idx in invalid_indices[:10]:  # print at most 10
                print(f" 坐标点 {idx + 1}: 经度={lon[idx]}, 纬度={lat[idx]}")
            if invalid_count > 10:
                print(f" ... 还有 {invalid_count - 10} 个无效坐标点")

        # Convert only the valid points.
        valid_mask = ~invalid_mask
        if np.any(valid_mask):
            valid_lon = lon[valid_mask]
            valid_lat = lat[valid_mask]
            valid_indices = np.where(valid_mask)[0]

            # Zone per point: fixed, or derived from each longitude.
            if fixed_zone is not None:
                zones = np.full(len(valid_lon), fixed_zone)
            else:
                zones = np.array([calculate_utm_zone(lon_val) for lon_val in valid_lon])

            # Point-by-point conversion; any failure becomes NaN.
            for i, (lat_val, lon_val, zone) in enumerate(zip(valid_lat, valid_lon, zones)):
                try:
                    E, Nn = latlon_to_utm_math(lat_val, lon_val, zone)
                    if not (np.isnan(E) or np.isnan(Nn) or np.isinf(E) or np.isinf(Nn)):
                        utm_x[valid_indices[i]] = E
                        utm_y[valid_indices[i]] = Nn
                    else:
                        utm_x[valid_indices[i]] = np.nan
                        utm_y[valid_indices[i]] = np.nan
                except Exception as e:
                    utm_x[valid_indices[i]] = np.nan
                    utm_y[valid_indices[i]] = np.nan

        # Invalid inputs are NaN in the output as well.
        utm_x[invalid_mask] = np.nan
        utm_y[invalid_mask] = np.nan

        return utm_x, utm_y

    except Exception as e:
        print(f"坐标转换初始化失败: {str(e)}")
        return np.full_like(lon, np.nan), np.full_like(lat, np.nan)
|
||||
|
||||
|
||||
def convert_to_utm51n(lon, lat, source_epsg=4326):
    """Convert coordinates to WGS84 / UTM zone 51N (backward-compat shim).

    Args:
        lon: array of longitudes (degrees).
        lat: array of latitudes (degrees).
        source_epsg: EPSG code of the input CRS (default 4326, WGS84).

    Returns:
        (utm_x, utm_y) arrays in metres, always in zone 51 (EPSG:32651).
    """
    # Delegate to the general converter, pinning the zone via EPSG:32651.
    return convert_to_utm(lon, lat, source_epsg=source_epsg, target_epsg=32651)
|
||||
|
||||
|
||||
def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, boundary_path=None, source_epsg=4326):
    """Sample spectra at given lat/lon points and write them to a CSV.

    Coordinates are converted to UTM (zone derived from longitude),
    mapped to pixel positions, optionally nudged inside the water body,
    and each point's spectrum is extracted either at a single pixel or
    as a mean over a radius (avoiding glint/boundary-masked pixels).

    Args:
        imgpath: image path (BIL/BSQ raster readable by GDAL).
        coorpath: CSV path; columns 1 and 2 are latitude and longitude.
        outpath: output CSV path.
        radius: sampling radius in pixels (0 = single pixel).
        flare_path: optional glint-mask raster path.
        boundary_path: optional water-boundary mask raster path.
        source_epsg: EPSG code of the input coordinates (default 4326).

    Returns:
        Array of [original columns, map x, map y, spectrum...] per point.
    """
    # --- read the coordinate CSV, trying several encodings ---
    coor_df = None
    coor_data = None

    encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'cp1252']

    for encoding in encodings:
        try:
            coor_df = pd.read_csv(coorpath, encoding=encoding)
            # Keep only numeric columns (drops textual header-like columns).
            coor_data = coor_df.select_dtypes(include=[np.number]).values

            # If nothing was numeric, coerce every column and drop all-NaN
            # rows (typically a header row that failed conversion).
            if coor_data.shape[1] == 0:
                coor_df = pd.read_csv(coorpath, encoding=encoding, header=0)
                numeric_df = coor_df.apply(pd.to_numeric, errors='coerce')
                numeric_df = numeric_df.dropna(how='all')
                coor_data = numeric_df.values

            print(f"成功使用 {encoding} 编码读取文件")
            break

        except Exception as e:
            print(f"使用 {encoding} 编码读取失败: {str(e)}")
            continue

    # Fall back to numpy loaders (comma, then tab) if pandas failed.
    if coor_data is None:
        try:
            print("尝试使用numpy读取数值数据...")
            coor_data = np.loadtxt(coorpath, delimiter=",", skiprows=1)
        except:
            try:
                coor_data = np.loadtxt(coorpath, delimiter="\t", skiprows=1)
            except Exception as e:
                raise Exception(f"无法读取坐标文件,请检查文件格式: {str(e)}")

    # A single row loads as 1-D; normalise to 2-D.
    if len(coor_data.shape) == 1:
        coor_data = coor_data.reshape(1, -1)

    # Need at least latitude and longitude columns.
    if coor_data is None or coor_data.shape[1] < 2:
        raise Exception("坐标文件格式错误:需要至少2列数据(第1列为纬度,第2列为经度)")

    print(f"成功读取坐标文件,共 {coor_data.shape[0]} 行,{coor_data.shape[1]} 列")
    print(f"数据预览(前3行):")
    for i in range(min(3, coor_data.shape[0])):
        print(f" 行{i + 1}: {coor_data[i, :min(5, coor_data.shape[1])]}")  # first 5 columns only

    # --- original coordinates: column 1 = latitude, column 2 = longitude ---
    lat_array = coor_data[:, 0]
    lon_array = coor_data[:, 1]

    print(f"\n=== 原始坐标信息 ===")
    print(f"原始坐标范围: 经度 {np.min(lon_array):.6f} ~ {np.max(lon_array):.6f}, 纬度 {np.min(lat_array):.6f} ~ {np.max(lat_array):.6f}")

    # Convert to UTM; the zone is derived per point from the longitude.
    print("正在进行坐标转换...")
    utm_x, utm_y = convert_to_utm(lon_array, lat_array, source_epsg, target_epsg=None)

    # Count how many conversions succeeded.
    valid_utm_mask = ~(np.isnan(utm_x) | np.isnan(utm_y) | np.isinf(utm_x) | np.isinf(utm_y))
    valid_count = np.sum(valid_utm_mask)

    if valid_count > 0:
        print(f"转换后UTM坐标范围: X {np.nanmin(utm_x):.2f} ~ {np.nanmax(utm_x):.2f}, Y {np.nanmin(utm_y):.2f} ~ {np.nanmax(utm_y):.2f}")
        print(f"成功转换 {valid_count}/{len(utm_x)} 个坐标点")
    else:
        print("警告: 所有UTM坐标转换都失败了,将尝试使用原始经纬度坐标进行像素坐标转换")

    # --- open the image and gather its geometry ---
    dataset = gdal.Open(imgpath)
    im_width = dataset.RasterXSize  # columns
    im_height = dataset.RasterYSize  # rows
    num_bands = dataset.RasterCount  # bands
    geotransform = dataset.GetGeoTransform()  # affine transform
    im_proj = dataset.GetProjection()  # projection WKT (currently unused)

    print(f"影像尺寸: {im_width} x {im_height}, 波段数: {num_bands}")
    print(f"仿射变换参数: {geotransform}")

    print("\n=== 开始光谱提取 ===")

    # Optional masks; boundary adjuster can move sample centres into water.
    flare_mask = load_mask_file(flare_path)
    boundary_mask = load_mask_file(boundary_path)
    boundary_adjuster = None
    if boundary_mask is not None and radius > 0:
        boundary_adjuster = prepare_boundary_adjuster(boundary_mask)
        if boundary_adjuster is None:
            print("提示: 无法构建边界调整器,采样点将不会根据水体边界进行内移。")

    # Spatial reference of the image (used by the fallback path below).
    dataset_srs = dataset.GetSpatialRef()

    # Output layout: [original columns | map x | map y | spectrum bands].
    original_cols = coor_data.shape[1]
    new_columns = np.zeros((coor_data.shape[0], 2 + num_bands))
    coor_spectral = np.hstack((coor_data, new_columns))

    # Seed the map-coordinate columns; refined later per point.
    coor_spectral[:, original_cols] = utm_x
    coor_spectral[:, original_cols + 1] = utm_y

    print(f"处理 {coor_data.shape[0]} 个坐标点...")

    # When every UTM conversion failed, try reprojecting lon/lat directly
    # into the image CRS instead.
    use_utm_fallback = False
    if valid_count == 0 and dataset_srs is not None:
        print("尝试使用影像坐标系进行坐标转换...")
        try:
            source_srs = osr.SpatialReference()
            source_srs.ImportFromEPSG(source_epsg)
            transform_to_image = osr.CoordinateTransformation(source_srs, dataset_srs)
            use_utm_fallback = True
        except:
            use_utm_fallback = False

    # Per-point pixel coordinates and validity flags.
    pixel_x_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    pixel_y_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    valid_pixel_mask = np.zeros(coor_data.shape[0], dtype=bool)

    # --- map every point to a pixel position ---
    for i in range(coor_data.shape[0]):
        utm_x_point = utm_x[i]
        utm_y_point = utm_y[i]

        # Invalid UTM result: try the fallbacks.
        if np.isnan(utm_x_point) or np.isnan(utm_y_point) or np.isinf(utm_x_point) or np.isinf(utm_y_point):
            if use_utm_fallback:
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        # Reproject into the image CRS, then to pixels.
                        img_coords = transform_to_image.TransformPoint(lon_point, lat_point)
                        pixel_x, pixel_y = geo_to_pixel(img_coords[0], img_coords[1], geotransform, dataset_srs)
                        # Record the image-CRS coordinates in the output.
                        coor_spectral[i, original_cols] = img_coords[0]
                        coor_spectral[i, original_cols + 1] = img_coords[1]
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except Exception as e:
                    # Reprojection failed too: use raw lon/lat directly.
                    try:
                        lon_point = lon_array[i]
                        lat_point = lat_array[i]
                        if not (np.isnan(lon_point) or np.isnan(lat_point)):
                            pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                            # Keep the raw lon/lat as the output coordinates.
                            coor_spectral[i, original_cols] = lon_point
                            coor_spectral[i, original_cols + 1] = lat_point
                        else:
                            print(f"跳过坐标点 {i + 1}: 坐标无效")
                            coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                            continue
                    except:
                        print(f"跳过坐标点 {i + 1}: 所有坐标转换方式都失败")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
            else:
                # No reprojection available: use raw lon/lat directly.
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                        # Keep the raw lon/lat as the output coordinates.
                        coor_spectral[i, original_cols] = lon_point
                        coor_spectral[i, original_cols + 1] = lat_point
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except:
                    print(f"跳过坐标点 {i + 1}: 坐标转换失败")
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                    continue
        else:
            # Normal path: UTM map coordinates to pixel indices.
            pixel_x, pixel_y = geo_to_pixel(utm_x_point, utm_y_point, geotransform, dataset_srs)

        pixel_x_array[i] = pixel_x
        pixel_y_array[i] = pixel_y

        # Optionally nudge the sampling centre inside the water body.
        moved = False
        original_pixel_x, original_pixel_y = pixel_x, pixel_y
        if boundary_adjuster is not None and radius > 0:
            new_pixel_x, new_pixel_y, moved = adjust_sampling_center(pixel_x, pixel_y, radius, boundary_adjuster)
            if moved:
                pixel_x, pixel_y = new_pixel_x, new_pixel_y
                if i < 10 or (i % 100 == 0):
                    print(f" 采样点 {i + 1} 调整至水体内部: ({original_pixel_x}, {original_pixel_y}) -> ({pixel_x}, {pixel_y})")

        pixel_x_array[i] = pixel_x
        pixel_y_array[i] = pixel_y

        # In-image check, using the (possibly adjusted) coordinates.
        if 0 <= pixel_x < im_width and 0 <= pixel_y < im_height:
            valid_pixel_mask[i] = True
            # Record the final sampling point's actual map coordinates.
            geo_x, geo_y = pixel_to_geo(pixel_x, pixel_y, geotransform)
            coor_spectral[i, original_cols] = geo_x
            coor_spectral[i, original_cols + 1] = geo_y
        else:
            valid_pixel_mask[i] = False
            if i < 10 or (i % 100 == 0):  # limit log spam
                print(f"警告: 坐标点 {i + 1} (UTM X:{utm_x_point:.2f}, Y:{utm_y_point:.2f}) 超出影像范围")

    # --- extract the spectra (batched to minimise I/O) ---
    print(f"批量提取光谱数据... (有效坐标点: {np.sum(valid_pixel_mask)})")

    if radius > 0:
        # Radius mode: per-point averaged window.
        for i in range(coor_data.shape[0]):
            if valid_pixel_mask[i]:
                spectrum = get_average_spectral_in_radius(
                    dataset, pixel_x_array[i], pixel_y_array[i], radius, flare_mask, boundary_mask
                )
                coor_spectral[i, original_cols + 2:] = spectrum
            else:
                coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    else:
        # Single-pixel mode: preload all bands if memory allows.
        try:
            print("正在预加载所有波段数据到内存(优化模式)...")
            all_bands_data = []
            for band_idx in range(num_bands):
                band = dataset.GetRasterBand(band_idx + 1)
                band_data = band.ReadAsArray()
                all_bands_data.append(band_data)
            all_bands_data = np.array(all_bands_data)  # (bands, height, width)
            print("预加载完成,开始批量提取像素值...")

            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = int(pixel_x_array[i]), int(pixel_y_array[i])
                    # Array index order is [band, row, col] = [:, y, x].
                    if 0 <= px < all_bands_data.shape[2] and 0 <= py < all_bands_data.shape[1]:
                        spectrum = all_bands_data[:, py, px]  # O(1) lookup
                        coor_spectral[i, original_cols + 2:] = spectrum
                    else:
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)

            # Free the preloaded cube.
            del all_bands_data
            print("批量提取完成")

        except MemoryError:
            # Not enough RAM: read each pixel band by band instead.
            print("内存不足,使用逐个波段读取模式...")
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = pixel_x_array[i], pixel_y_array[i]
                    spectrum = np.zeros(num_bands)
                    for band_idx in range(num_bands):
                        band = dataset.GetRasterBand(band_idx + 1)
                        spectrum[band_idx] = band.ReadAsArray(px, py, 1, 1)[0, 0]
                    coor_spectral[i, original_cols + 2:] = spectrum
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)

    del dataset

    # --- build the output CSV: drop lat/lon and UTM columns ---
    try:
        # Use the original column names when available (skipping lat/lon).
        if coor_df is not None and hasattr(coor_df, 'columns'):
            if len(coor_df.columns) >= original_cols:
                if original_cols > 2:
                    original_columns = list(coor_df.columns[2:original_cols])
                else:
                    original_columns = []
            else:
                # Fewer names than data columns: keep whatever exists.
                if len(coor_df.columns) > 2:
                    original_columns = list(coor_df.columns[2:])
                else:
                    original_columns = []
        else:
            # No names available: synthesise col_3, col_4, ...
            if original_cols > 2:
                original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
            else:
                original_columns = []
    except:
        # Any failure above: fall back to synthetic names.
        if original_cols > 2:
            original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
        else:
            original_columns = []

    # Spectral column names from the ENVI header's wavelengths if possible.
    wavelengths = None
    try:
        in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
        wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
        spectral_columns = [str(wl) for wl in wavelengths]
        print(f"成功读取波长信息,共 {len(spectral_columns)} 个波段")
    except Exception as e:
        print(f"警告: 无法读取波长信息 ({str(e)}),使用默认列名 band_1, band_2, ...")
        spectral_columns = ["band_" + str(j + 1) for j in range(num_bands)]

    all_columns = original_columns + spectral_columns

    # Select output columns: extra original columns (3..n) + spectra.
    output_data = []
    if original_cols > 2:
        output_data.append(coor_spectral[:, 2:original_cols])
    output_data.append(coor_spectral[:, original_cols + 2:])

    if len(output_data) > 0:
        output_array = np.hstack(output_data) if len(output_data) > 1 else output_data[0]
    else:
        output_array = coor_spectral[:, original_cols + 2:]

    result_df = pd.DataFrame(output_array, columns=all_columns)

    # Write the CSV (fixed 6-decimal float formatting).
    result_df.to_csv(outpath, index=False, float_format='%.6f')
    print(f"结果已保存到CSV文件: {outpath}")

    return coor_spectral
|
||||
|
||||
|
||||
# 直接运行示例
|
||||
# Standalone demo: sample spectra from a local image at CSV coordinates.
if __name__ == '__main__':
    # Input/output locations (edit these to run locally).
    imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"  # BIL/BSQ image path
    coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"  # CSV: lat, lon in columns 1-2
    output_path = r"E:\code\WQ\封装\test/yangdian_output.csv"  # CSV output path

    radius = 5  # sampling radius in pixels; 0 = single pixel, >0 = mean within radius
    flare_path = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat"  # glint mask (optional, None to disable)
    # FIX: raw string — the original non-raw literal only worked because
    # "\B", "\y", "\w" happen not to be escape sequences (SyntaxWarning on
    # modern Python); the string value is unchanged.
    boundary_path = r"D:\BaiduNetdiskDownload\yaobao\water_mask.dat"  # boundary mask (optional, None to disable)
    source_epsg = 4326  # EPSG code of the input coordinates (WGS84)

    verbose = True  # print the configuration before running

    if verbose:
        print(f"影像文件: {imgpath}")
        print(f"坐标文件: {coorpath}")
        print(f"输出文件: {output_path}")
        print(f"采样半径: {radius}")
        if flare_path:
            print(f"耀斑掩膜: {flare_path}")
        if boundary_path:
            print(f"边界掩膜: {boundary_path}")
        if source_epsg:
            print(f"指定坐标系: EPSG:{source_epsg}")

    tmp = get_spectral_in_coor(imgpath, coorpath, output_path,
                               radius, flare_path, boundary_path, source_epsg)
|
||||
|
||||
785
src/core/glint_removal/get_spectral.py
Normal file
785
src/core/glint_removal/get_spectral.py
Normal file
@ -0,0 +1,785 @@
|
||||
from osgeo import gdal, osr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import spectral
|
||||
from math import sin, cos, tan, sqrt, radians
|
||||
|
||||
# Make OSR raise exceptions on CRS errors instead of returning error codes.
osr.UseExceptions()

# WGS84 ellipsoid parameters used by the closed-form UTM conversion below.
WGS84_A = 6378137.0  # semi-major axis (metres)
WGS84_F = 1 / 298.257223563  # flattening
WGS84_E2 = WGS84_F * (2 - WGS84_F)  # first eccentricity squared
WGS84_EP2 = WGS84_E2 / (1 - WGS84_E2)  # second eccentricity squared
UTM_K0 = 0.9996  # UTM central-meridian scale factor
|
||||
|
||||
|
||||
def transform_coordinates(lon, lat, source_srs, target_srs):
    """Reproject a single point between two OSR spatial reference systems.

    Args:
        lon: input x / longitude.
        lat: input y / latitude.
        source_srs: osr.SpatialReference of the input.
        target_srs: osr.SpatialReference of the output.

    Returns:
        (x, y) of the transformed point.
    """
    # Build the transformer and push the point through it in one chain.
    transformed = osr.CoordinateTransformation(source_srs, target_srs).TransformPoint(lon, lat)
    x, y = transformed[0], transformed[1]
    return x, y
|
||||
|
||||
|
||||
|
||||
def geo_to_pixel(lon, lat, geotransform, dataset_srs=None):
    """Map a map coordinate to integer pixel (column, row) indices.

    Applies the inverse of the GDAL affine geotransform, assuming no
    rotation terms (geotransform[2] == geotransform[4] == 0).

    Args:
        lon: x map coordinate (longitude or easting, in the image CRS).
        lat: y map coordinate (latitude or northing, in the image CRS).
        geotransform: 6-element GDAL geotransform.
        dataset_srs: unused; kept for interface compatibility with callers.

    Returns:
        (pixel_x, pixel_y) as ints (int() truncates toward zero).
    """
    x_origin, pixel_width = geotransform[0], geotransform[1]
    y_origin, pixel_height = geotransform[3], geotransform[5]

    col = int((lon - x_origin) / pixel_width)
    row = int((lat - y_origin) / pixel_height)
    return col, row
|
||||
|
||||
|
||||
def get_pixel_spectrum_batch(dataset, pixel_x_array, pixel_y_array):
    """Read the spectra of several pixels at once (band-major reads).

    Each band is read as a whole array once, then all requested pixels are
    sampled from it — far cheaper than per-pixel GDAL reads.

    Args:
        dataset: open GDAL raster dataset.
        pixel_x_array: pixel column indices.
        pixel_y_array: pixel row indices.

    Returns:
        float32 array of shape (n_points, n_bands); out-of-bounds points
        are NaN.
    """
    n_points = len(pixel_x_array)
    n_bands = dataset.RasterCount

    spectrum_array = np.zeros((n_points, n_bands), dtype=np.float32)

    # Hoist loop-invariant work out of the band loop: the integer pixel
    # coordinates and the in-bounds test do not depend on the band (all
    # bands of a GDAL dataset share the image dimensions).
    px = np.asarray(pixel_x_array, dtype=np.int64)
    py = np.asarray(pixel_y_array, dtype=np.int64)
    in_bounds = None
    px_ok = py_ok = None

    for band_idx in range(n_bands):
        # GDAL band indices start at 1.
        band_data = dataset.GetRasterBand(band_idx + 1).ReadAsArray()

        if in_bounds is None:
            h, w = band_data.shape[0], band_data.shape[1]
            in_bounds = (px >= 0) & (px < w) & (py >= 0) & (py < h)
            px_ok = px[in_bounds]
            py_ok = py[in_bounds]

        # Vectorized gather instead of a Python loop over the points;
        # out-of-bounds points stay NaN.
        col = np.full(n_points, np.nan, dtype=np.float32)
        col[in_bounds] = band_data[py_ok, px_ok]
        spectrum_array[:, band_idx] = col

    return spectrum_array
|
||||
|
||||
|
||||
def get_average_spectral_in_radius(dataset, center_x, center_y, radius, flare_mask=None, boundary_mask=None):
    """Average the spectrum over a circular window, skipping masked pixels.

    Args:
        dataset: open GDAL raster dataset.
        center_x, center_y: centre pixel (column, row).
        radius: sampling radius in pixels.
        flare_mask: optional glint mask; pixels where it is 0 are kept.
        boundary_mask: optional water mask; pixels where it is 1 are kept.

    Returns:
        Array of per-band means; all zeros when no valid pixel remains.
    """
    num_bands = dataset.RasterCount

    # Window bounds, clipped to the image extent.
    x_start = max(0, center_x - radius)
    x_end = min(dataset.RasterXSize, center_x + radius + 1)
    y_start = max(0, center_y - radius)
    y_end = min(dataset.RasterYSize, center_y + radius + 1)

    width = x_end - x_start
    height = y_end - y_start

    # Degenerate window (centre far outside the image): nothing to average.
    if width <= 0 or height <= 0:
        return np.zeros(num_bands)

    # Read all bands of the window in a single call.
    spectral_data = dataset.ReadAsArray(x_start, y_start, width, height)
    if spectral_data is None:
        return np.zeros(num_bands)

    # Single-band reads come back 2-D; normalise to (bands, height, width).
    if len(spectral_data.shape) == 2:
        spectral_data = spectral_data.reshape(1, spectral_data.shape[0], spectral_data.shape[1])

    # Circular mask around the (window-local) centre.
    y_indices, x_indices = np.ogrid[:height, :width]
    center_x_local = center_x - x_start
    center_y_local = center_y - y_start

    distance_mask = ((x_indices - center_x_local) ** 2 + (y_indices - center_y_local) ** 2) <= radius ** 2

    # Drop glint pixels (mask value 0 is assumed to mean "no glint").
    if flare_mask is not None:
        flare_region = flare_mask[y_start:y_end, x_start:x_end]
        if flare_region.shape == distance_mask.shape:
            distance_mask = distance_mask & (flare_region == 0)

    # Keep only pixels where the boundary/water mask is 1.
    # NOTE(review): assumes 1 marks water — confirm with the mask producer
    # (the original comment claimed 0, contradicting the code).
    if boundary_mask is not None:
        boundary_region = boundary_mask[y_start:y_end, x_start:x_end]
        if boundary_region.shape == distance_mask.shape:
            distance_mask = distance_mask & (boundary_region == 1)

    # Per-band mean over surviving pixels; 0 and non-finite values are
    # treated as nodata and excluded.
    average_spectrum = np.zeros(num_bands)
    valid_pixels = np.sum(distance_mask)

    if valid_pixels > 0:
        for band in range(num_bands):
            band_data = spectral_data[band, :, :]
            valid_data = band_data[distance_mask & (band_data != 0) & np.isfinite(band_data)]
            if len(valid_data) > 0:
                average_spectrum[band] = np.mean(valid_data)

    return average_spectrum
|
||||
|
||||
|
||||
def load_mask_file(mask_path):
    """Load a raster mask file into a numpy array.

    Args:
        mask_path: path to a raster mask (e.g. .dat/.tif); None is allowed.

    Returns:
        The first band as a 2-D array, or None when the path is missing,
        the file is not a readable raster, or any GDAL error occurs.
    """
    if mask_path is None or not os.path.exists(mask_path):
        return None

    try:
        # Open explicitly as a raster first so vector files (which would
        # trigger a "multiple layers" error) are rejected cleanly; fall
        # back to plain Open for backwards compatibility.
        ds = gdal.OpenEx(mask_path, gdal.OF_RASTER) or gdal.Open(mask_path, gdal.GA_ReadOnly)
        if ds is None:
            print(f"警告: 无法打开掩膜文件 {mask_path},可能不是有效的栅格文件")
            return None

        # A dataset without any band is not usable as a mask.
        if not hasattr(ds, 'RasterCount') or ds.RasterCount == 0:
            print(f"警告: {mask_path} 不是有效的栅格文件")
            del ds
            return None

        arr = ds.GetRasterBand(1).ReadAsArray()
        del ds
        return arr
    except Exception as e:
        print(f"警告: 加载掩膜文件 {mask_path} 时出错: {str(e)}")
        return None
|
||||
|
||||
|
||||
def get_hdr_file_path(file_path):
    """Return the ENVI header (.hdr) path that pairs with *file_path*.

    The header sits next to the image file: same base name, '.hdr' suffix.
    """
    base, _ext = os.path.splitext(file_path)
    return base + ".hdr"
|
||||
|
||||
|
||||
def calculate_utm_zone(longitude):
    """Return the UTM zone number (1-60) for a longitude in degrees.

    Zones are 6 degrees wide, numbered eastwards starting at 180°W;
    the result is clamped into the valid 1-60 range.
    """
    zone = 1 + int((longitude + 180) / 6)
    return min(60, max(1, zone))
|
||||
|
||||
|
||||
def latlon_to_utm_math(lat_deg, lon_deg, zone=None):
    """Convert WGS84 geographic coordinates to UTM using closed-form math.

    Implements the standard transverse-Mercator series expansion
    (Snyder's formulas) instead of going through a projection library.

    Args:
        lat_deg: latitude in degrees.
        lon_deg: longitude in degrees.
        zone: UTM zone number; when None it is derived from the longitude.

    Returns:
        (easting, northing) in metres.  Southern-hemisphere northings
        carry the conventional 10,000,000 m false northing.
    """
    # Derive the zone from the longitude when not given explicitly.
    if zone is None:
        zone = calculate_utm_zone(lon_deg)

    # Central meridian of the zone (degrees), then radians.
    lon0 = (zone * 6 - 183)
    lam0 = radians(lon0)

    # Work in radians from here on.
    phi = radians(lat_deg)
    lam = radians(lon_deg)

    # Trigonometric terms reused below.
    sinphi = sin(phi)
    cosphi = cos(phi)
    tanphi = tan(phi)

    # Radius of curvature in the prime vertical.
    N = WGS84_A / sqrt(1 - WGS84_E2 * sinphi * sinphi)

    T = tanphi * tanphi
    C = WGS84_EP2 * cosphi * cosphi
    A = cosphi * (lam - lam0)

    # Meridional arc length (Snyder's series in the eccentricity).
    M = (WGS84_A * ((1 - WGS84_E2/4 - 3*WGS84_E2**2/64 - 5*WGS84_E2**3/256) * phi
                    - (3*WGS84_E2/8 + 3*WGS84_E2**2/32 + 45*WGS84_E2**3/1024) * sin(2*phi)
                    + (15*WGS84_E2**2/256 + 45*WGS84_E2**3/1024) * sin(4*phi)
                    - (35*WGS84_E2**3/3072) * sin(6*phi)))

    # Easting, including the 500,000 m false easting.
    E = (UTM_K0 * N * (A + (1 - T + C) * A**3 / 6
                       + (5 - 18*T + T*T + 72*C - 58*WGS84_EP2) * A**5 / 120)
         + 500000.0)

    # Northing; the southern hemisphere gets a 10,000,000 m false northing.
    if lat_deg < 0:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720))
              + 10000000.0)
    else:
        Nn = (UTM_K0 * (M + N * tanphi * (A**2 / 2
                                          + (5 - T + 9*C + 4*C*C) * A**4 / 24
                                          + (61 - 58*T + T*T + 600*C - 330*WGS84_EP2) * A**6 / 720)))

    return E, Nn
|
||||
|
||||
|
||||
def convert_to_utm(lon, lat, source_epsg=4326, target_epsg=None):
    """Convert arrays of lon/lat to UTM metres via latlon_to_utm_math.

    Args:
        lon: array of longitudes (degrees).
        lat: array of latitudes (degrees).
        source_epsg: EPSG code of the input CRS; only 4326 (WGS84) is
            truly supported by the closed-form math — anything else is
            still converted, with a printed warning.
        target_epsg: optional WGS84/UTM EPSG code (326xx north, 327xx
            south); when given, its zone number is applied to every
            point, otherwise each point's zone is derived from its
            longitude.

    Returns:
        (utm_x, utm_y) arrays in metres; invalid or failed points are NaN.
    """
    try:
        # The series expansion assumes WGS84 — warn when it is not.
        if source_epsg != 4326:
            print(f"警告: 数学公式转换仅支持WGS84 (EPSG:4326),当前源坐标系为EPSG:{source_epsg}")
            print("将尝试使用数学公式进行转换,但可能不准确")

        # Output buffers with the same shape as the inputs.
        utm_x = np.zeros_like(lon)
        utm_y = np.zeros_like(lat)

        # Extract a fixed zone number from the target EPSG when provided.
        fixed_zone = None
        if target_epsg is not None:
            # EPSG:32651 -> zone 51 (north), EPSG:32751 -> zone 51 (south).
            if 32601 <= target_epsg <= 32660:
                fixed_zone = target_epsg - 32600
            elif 32701 <= target_epsg <= 32760:
                fixed_zone = target_epsg - 32700
            else:
                print(f"警告: 无法从EPSG代码 {target_epsg} 提取UTM分区号,将根据经度自动计算")

        # Vectorized flagging of NaN / out-of-range coordinates.
        invalid_mask = (np.isnan(lon) | np.isnan(lat) |
                        (lon < -180) | (lon > 180) |
                        (lat < -90) | (lat > 90))

        # Report invalid points (first 10 only, to keep the log short).
        invalid_count = np.sum(invalid_mask)
        if invalid_count > 0:
            invalid_indices = np.where(invalid_mask)[0]
            print(f"警告: 发现 {invalid_count} 个无效坐标点,将跳过")
            for idx in invalid_indices[:10]:  # print at most 10
                print(f" 坐标点 {idx + 1}: 经度={lon[idx]}, 纬度={lat[idx]}")
            if invalid_count > 10:
                print(f" ... 还有 {invalid_count - 10} 个无效坐标点")

        # Convert only the valid points.
        valid_mask = ~invalid_mask
        if np.any(valid_mask):
            valid_lon = lon[valid_mask]
            valid_lat = lat[valid_mask]
            valid_indices = np.where(valid_mask)[0]

            # Zone per point: fixed, or derived from each longitude.
            if fixed_zone is not None:
                zones = np.full(len(valid_lon), fixed_zone)
            else:
                zones = np.array([calculate_utm_zone(lon_val) for lon_val in valid_lon])

            # Point-by-point conversion; any failure becomes NaN.
            for i, (lat_val, lon_val, zone) in enumerate(zip(valid_lat, valid_lon, zones)):
                try:
                    E, Nn = latlon_to_utm_math(lat_val, lon_val, zone)
                    if not (np.isnan(E) or np.isnan(Nn) or np.isinf(E) or np.isinf(Nn)):
                        utm_x[valid_indices[i]] = E
                        utm_y[valid_indices[i]] = Nn
                    else:
                        utm_x[valid_indices[i]] = np.nan
                        utm_y[valid_indices[i]] = np.nan
                except Exception as e:
                    utm_x[valid_indices[i]] = np.nan
                    utm_y[valid_indices[i]] = np.nan

        # Invalid inputs are NaN in the output as well.
        utm_x[invalid_mask] = np.nan
        utm_y[invalid_mask] = np.nan

        return utm_x, utm_y

    except Exception as e:
        print(f"坐标转换初始化失败: {str(e)}")
        return np.full_like(lon, np.nan), np.full_like(lat, np.nan)
|
||||
|
||||
|
||||
def convert_to_utm51n(lon, lat, source_epsg=4326):
    """Convert coordinates to WGS84 / UTM zone 51N (backward-compat shim).

    Args:
        lon: array of longitudes (degrees).
        lat: array of latitudes (degrees).
        source_epsg: EPSG code of the input CRS (default 4326, WGS84).

    Returns:
        (utm_x, utm_y) arrays in metres, always in zone 51 (EPSG:32651).
    """
    # Delegate to the general converter, pinning the zone via EPSG:32651.
    return convert_to_utm(lon, lat, source_epsg=source_epsg, target_epsg=32651)
|
||||
|
||||
|
||||
def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None, boundary_path=None, source_epsg=4326):
    """
    Extract spectral curves at given coordinates from a raster image and write them to CSV.

    Input coordinates are converted to UTM (the UTM zone is auto-derived from
    longitude via `convert_to_utm`), mapped to pixel positions, and the pixel
    (or radius-averaged) spectrum is sampled for each point.

    Args:
        imgpath: path to the image file (BIL format, opened with GDAL)
        coorpath: path to the coordinate file (CSV; column 1 = latitude, column 2 = longitude)
        outpath: path for the output CSV file
        radius: sampling radius in pixels (0 = single-pixel sampling)
        flare_path: optional glint/flare mask file path
        boundary_path: optional boundary mask file path
        source_epsg: EPSG code of the source CRS, default 4326 (WGS84 geographic)

    Returns:
        np.ndarray: the original columns extended with 2 UTM columns and one
        column per band (rows for failed points hold zeros in the band columns).
    """
    # --- Read the raw coordinate file (CSV) ---
    coor_df = None
    coor_data = None

    # Try several encodings, since files may come from Chinese-locale Excel exports.
    encodings = ['utf-8', 'gbk', 'gb2312', 'latin1', 'cp1252']

    for encoding in encodings:
        try:
            # Attempt to read the CSV with this encoding.
            coor_df = pd.read_csv(coorpath, encoding=encoding)
            # Keep only numeric columns, dropping any header-ish text columns.
            coor_data = coor_df.select_dtypes(include=[np.number]).values

            # If nothing was numeric, coerce every column to numbers
            # (treating the first row as a header) and drop all-NaN rows.
            if coor_data.shape[1] == 0:
                coor_df = pd.read_csv(coorpath, encoding=encoding, header=0)
                numeric_df = coor_df.apply(pd.to_numeric, errors='coerce')
                # Rows that are all NaN are usually failed header conversions.
                numeric_df = numeric_df.dropna(how='all')
                coor_data = numeric_df.values

            print(f"成功使用 {encoding} 编码读取文件")
            break

        except Exception as e:
            print(f"使用 {encoding} 编码读取失败: {str(e)}")
            continue

    # If every encoding failed, fall back to a plain numpy text read.
    if coor_data is None:
        try:
            print("尝试使用numpy读取数值数据...")
            # Skip the first row (assumed header) and read numbers only.
            coor_data = np.loadtxt(coorpath, delimiter=",", skiprows=1)
        except:
            try:
                # Last resort: tab-separated layout.
                coor_data = np.loadtxt(coorpath, delimiter="\t", skiprows=1)
            except Exception as e:
                raise Exception(f"无法读取坐标文件,请检查文件格式: {str(e)}")

    # A single data row comes back 1-D from loadtxt; normalize to 2-D.
    if len(coor_data.shape) == 1:
        coor_data = coor_data.reshape(1, -1)

    # Validate shape: need at least latitude + longitude columns.
    if coor_data is None or coor_data.shape[1] < 2:
        raise Exception("坐标文件格式错误:需要至少2列数据(第1列为纬度,第2列为经度)")

    print(f"成功读取坐标文件,共 {coor_data.shape[0]} 行,{coor_data.shape[1]} 列")
    print(f"数据预览(前3行):")
    for i in range(min(3, coor_data.shape[0])):
        print(f"  行{i + 1}: {coor_data[i, :min(5, coor_data.shape[1])]}")  # show at most the first 5 columns

    # --- Extract the raw coordinates ---
    lat_array = coor_data[:, 0]  # column 1 is latitude
    lon_array = coor_data[:, 1]  # column 2 is longitude

    print(f"\n=== 原始坐标信息 ===")
    print(f"原始坐标范围: 经度 {np.min(lon_array):.6f} ~ {np.max(lon_array):.6f}, 纬度 {np.min(lat_array):.6f} ~ {np.max(lat_array):.6f}")

    # --- Convert to UTM (zone auto-derived from longitude; target_epsg=None) ---
    print("正在进行坐标转换...")
    utm_x, utm_y = convert_to_utm(lon_array, lat_array, source_epsg, target_epsg=None)

    # Check how many conversions produced finite coordinates.
    valid_utm_mask = ~(np.isnan(utm_x) | np.isnan(utm_y) | np.isinf(utm_x) | np.isinf(utm_y))
    valid_count = np.sum(valid_utm_mask)

    if valid_count > 0:
        print(f"转换后UTM坐标范围: X {np.nanmin(utm_x):.2f} ~ {np.nanmax(utm_x):.2f}, Y {np.nanmin(utm_y):.2f} ~ {np.nanmax(utm_y):.2f}")
        print(f"成功转换 {valid_count}/{len(utm_x)} 个坐标点")
    else:
        print("警告: 所有UTM坐标转换都失败了,将尝试使用原始经纬度坐标进行像素坐标转换")

    # --- Open the image dataset ---
    # NOTE(review): gdal.Open returns None on failure; the attribute accesses
    # below would then raise AttributeError — confirm callers pass valid paths.
    dataset = gdal.Open(imgpath)
    im_width = dataset.RasterXSize    # raster columns
    im_height = dataset.RasterYSize   # raster rows
    num_bands = dataset.RasterCount   # number of bands
    geotransform = dataset.GetGeoTransform()  # affine transform coefficients
    im_proj = dataset.GetProjection() # map projection info (currently unused below)

    print(f"影像尺寸: {im_width} x {im_height}, 波段数: {num_bands}")
    print(f"仿射变换参数: {geotransform}")

    print("\n=== 开始光谱提取 ===")

    # Load optional mask files (semantics of load_mask_file defined elsewhere in this module).
    flare_mask = load_mask_file(flare_path)
    boundary_mask = load_mask_file(boundary_path)

    # Spatial reference of the dataset (may be None for unreferenced rasters).
    dataset_srs = dataset.GetSpatialRef()

    # --- Prepare output array: original columns + 2 UTM columns + one column per band ---
    original_cols = coor_data.shape[1]
    new_columns = np.zeros((coor_data.shape[0], 2 + num_bands))
    coor_spectral = np.hstack((coor_data, new_columns))

    # Store the UTM coordinates right after the original columns.
    coor_spectral[:, original_cols] = utm_x      # UTM X
    coor_spectral[:, original_cols + 1] = utm_y  # UTM Y

    print(f"处理 {coor_data.shape[0]} 个坐标点...")

    # If the UTM conversion failed for all points, try building a transform
    # from the source CRS into the image's own CRS as a fallback.
    use_utm_fallback = False
    if valid_count == 0 and dataset_srs is not None:
        print("尝试使用影像坐标系进行坐标转换...")
        try:
            source_srs = osr.SpatialReference()
            source_srs.ImportFromEPSG(source_epsg)
            transform_to_image = osr.CoordinateTransformation(source_srs, dataset_srs)
            use_utm_fallback = True
        except:
            use_utm_fallback = False

    # --- Convert all points to pixel coordinates ---
    pixel_x_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    pixel_y_array = np.zeros(coor_data.shape[0], dtype=np.int32)
    valid_pixel_mask = np.zeros(coor_data.shape[0], dtype=bool)

    for i in range(coor_data.shape[0]):
        # Prefer the UTM coordinates; fall back per-point if they are invalid.
        utm_x_point = utm_x[i]
        utm_y_point = utm_y[i]

        if np.isnan(utm_x_point) or np.isnan(utm_y_point) or np.isinf(utm_x_point) or np.isinf(utm_y_point):
            # UTM conversion failed for this point.
            if use_utm_fallback:
                # Fallback 1: reproject lon/lat into the image CRS.
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        img_coords = transform_to_image.TransformPoint(lon_point, lat_point)
                        pixel_x, pixel_y = geo_to_pixel(img_coords[0], img_coords[1], geotransform, dataset_srs)
                        # Overwrite the UTM columns with image-CRS coordinates.
                        coor_spectral[i, original_cols] = img_coords[0]
                        coor_spectral[i, original_cols + 1] = img_coords[1]
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except Exception as e:
                    # Fallback 2: use raw lon/lat directly with the geotransform.
                    try:
                        lon_point = lon_array[i]
                        lat_point = lat_array[i]
                        if not (np.isnan(lon_point) or np.isnan(lat_point)):
                            pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                            # Keep the raw lon/lat as the stored coordinates.
                            coor_spectral[i, original_cols] = lon_point
                            coor_spectral[i, original_cols + 1] = lat_point
                        else:
                            print(f"跳过坐标点 {i + 1}: 坐标无效")
                            coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                            continue
                    except:
                        print(f"跳过坐标点 {i + 1}: 所有坐标转换方式都失败")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
            else:
                # No image-CRS transform available: try raw lon/lat directly.
                try:
                    lon_point = lon_array[i]
                    lat_point = lat_array[i]
                    if not (np.isnan(lon_point) or np.isnan(lat_point)):
                        pixel_x, pixel_y = geo_to_pixel(lon_point, lat_point, geotransform, dataset_srs)
                        # Keep the raw lon/lat as the stored coordinates.
                        coor_spectral[i, original_cols] = lon_point
                        coor_spectral[i, original_cols + 1] = lat_point
                    else:
                        print(f"跳过坐标点 {i + 1}: 坐标无效")
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                        continue
                except:
                    print(f"跳过坐标点 {i + 1}: 坐标转换失败")
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                    continue
        else:
            # Normal path: UTM -> pixel via geo_to_pixel (defined elsewhere in this module).
            pixel_x, pixel_y = geo_to_pixel(utm_x_point, utm_y_point, geotransform, dataset_srs)

        # Record the pixel position for this point.
        pixel_x_array[i] = pixel_x
        pixel_y_array[i] = pixel_y

        # Keep only points that land inside the raster extent.
        if 0 <= pixel_x < im_width and 0 <= pixel_y < im_height:
            valid_pixel_mask[i] = True
        else:
            valid_pixel_mask[i] = False
            if i < 10 or (i % 100 == 0):  # throttle warnings: first 10, then every 100th
                print(f"警告: 坐标点 {i + 1} (UTM X:{utm_x_point:.2f}, Y:{utm_y_point:.2f}) 超出影像范围")

    # --- Extract the spectra (batched to reduce I/O) ---
    print(f"批量提取光谱数据... (有效坐标点: {np.sum(valid_pixel_mask)})")

    if radius > 0:
        # Radius sampling: each point needs its own neighborhood average.
        for i in range(coor_data.shape[0]):
            if valid_pixel_mask[i]:
                spectrum = get_average_spectral_in_radius(
                    dataset, pixel_x_array[i], pixel_y_array[i], radius, flare_mask, boundary_mask
                )
                coor_spectral[i, original_cols + 2:] = spectrum
            else:
                coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
    else:
        # Single-pixel sampling: preload all bands if memory allows.
        try:
            print("正在预加载所有波段数据到内存(优化模式)...")
            all_bands_data = []
            for band_idx in range(num_bands):
                band = dataset.GetRasterBand(band_idx + 1)  # GDAL bands are 1-based
                band_data = band.ReadAsArray()
                all_bands_data.append(band_data)
            all_bands_data = np.array(all_bands_data)  # shape: (bands, height, width)
            print("预加载完成,开始批量提取像素值...")

            # Fast in-memory pixel extraction.
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = int(pixel_x_array[i]), int(pixel_y_array[i])
                    # Array layout is (bands, height, width): pixel (x, y) -> [:, y, x].
                    if 0 <= px < all_bands_data.shape[2] and 0 <= py < all_bands_data.shape[1]:
                        spectrum = all_bands_data[:, py, px]  # direct index, very fast
                        coor_spectral[i, original_cols + 2:] = spectrum
                    else:
                        coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)

            # Release the preloaded cube.
            del all_bands_data
            print("批量提取完成")

        except MemoryError:
            # Not enough memory: fall back to per-pixel, per-band GDAL reads.
            print("内存不足,使用逐个波段读取模式...")
            for i in range(coor_data.shape[0]):
                if valid_pixel_mask[i]:
                    px, py = pixel_x_array[i], pixel_y_array[i]
                    spectrum = np.zeros(num_bands)
                    for band_idx in range(num_bands):
                        band = dataset.GetRasterBand(band_idx + 1)
                        spectrum[band_idx] = band.ReadAsArray(px, py, 1, 1)[0, 0]
                    coor_spectral[i, original_cols + 2:] = spectrum
                else:
                    coor_spectral[i, original_cols + 2:] = np.zeros(num_bands)

    # Close the GDAL dataset (GDAL releases the handle on deletion).
    del dataset

    # --- Build the output DataFrame ---
    # Output drops the first two coordinate columns (lat/lon) and the UTM columns.
    try:
        # If the original file had named columns, reuse names from column 3 onward.
        if coor_df is not None and hasattr(coor_df, 'columns'):
            if len(coor_df.columns) >= original_cols:
                if original_cols > 2:
                    original_columns = list(coor_df.columns[2:original_cols])
                else:
                    original_columns = []
            else:
                # Fewer named columns than data columns: keep whatever exists past the first two.
                if len(coor_df.columns) > 2:
                    original_columns = list(coor_df.columns[2:])
                else:
                    original_columns = []
        else:
            # No column names available: synthesize col_3, col_4, ...
            if original_cols > 2:
                original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
            else:
                original_columns = []
    except:
        # On any naming failure, fall back to synthesized names.
        if original_cols > 2:
            original_columns = ["col_" + str(j + 1) for j in range(2, original_cols)]
        else:
            original_columns = []

    # Read wavelength metadata from the ENVI header to use as spectral column names.
    wavelengths = None
    try:
        in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
        wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
        # Wavelength values become the column labels.
        spectral_columns = [str(wl) for wl in wavelengths]
        print(f"成功读取波长信息,共 {len(spectral_columns)} 个波段")
    except Exception as e:
        print(f"警告: 无法读取波长信息 ({str(e)}),使用默认列名 band_1, band_2, ...")
        spectral_columns = ["band_" + str(j + 1) for j in range(num_bands)]

    # Final column set: surviving original columns + spectral columns.
    all_columns = original_columns + spectral_columns

    # Assemble the output array:
    # - columns 3..original_cols of the input (if any)
    # - the spectral columns (starting at original_cols + 2, after the UTM pair)
    output_data = []
    if original_cols > 2:
        output_data.append(coor_spectral[:, 2:original_cols])
    output_data.append(coor_spectral[:, original_cols + 2:])

    # Merge the pieces.
    if len(output_data) > 0:
        output_array = np.hstack(output_data) if len(output_data) > 1 else output_data[0]
    else:
        # No original columns survived: spectra only.
        output_array = coor_spectral[:, original_cols + 2:]

    # Build and save the result.
    result_df = pd.DataFrame(output_array, columns=all_columns)

    result_df.to_csv(outpath, index=False, float_format='%.6f')
    print(f"结果已保存到CSV文件: {outpath}")

    return coor_spectral
|
||||
|
||||
|
||||
# Standalone execution example
if __name__ == '__main__':
    # Inline configuration for this example run.
    image_file = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq"          # BIL-format image path
    coord_file = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"    # CSV coords (col 1 lat, col 2 lon)
    result_file = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV output path

    sample_radius = 5  # sampling radius in pixels: 0 = single pixel, >0 = neighborhood average
    glint_mask_file = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat"  # optional glint mask (None to disable)
    water_mask_file = r"D:\BaiduNetdiskDownload\yaobao\water_mask.dat"           # optional boundary mask (None to disable)
    input_epsg = 4326  # EPSG code of the source CRS (WGS84 geographic)

    verbose_mode = True  # echo the configuration before running

    if verbose_mode:
        print(f"影像文件: {image_file}")
        print(f"坐标文件: {coord_file}")
        print(f"输出文件: {result_file}")
        print(f"采样半径: {sample_radius}")
        if glint_mask_file:
            print(f"耀斑掩膜: {glint_mask_file}")
        if water_mask_file:
            print(f"边界掩膜: {water_mask_file}")
        if input_epsg:
            print(f"指定坐标系: EPSG:{input_epsg}")

    # Run the extraction; the returned array is kept only for interactive inspection.
    result = get_spectral_in_coor(image_file, coord_file, result_file,
                                  sample_radius, glint_mask_file, water_mask_file, input_epsg)
||||
Reference in New Issue
Block a user