371 lines
16 KiB
Python
371 lines
16 KiB
Python
import numpy as np
|
||
import sys
|
||
# import preprocessing
|
||
|
||
try:
    from osgeo import gdal

    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
    print("警告: GDAL未安装,将使用numpy处理模式")

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False

    def tqdm(iterable=None, *args, **kwargs):
        """No-op stand-in for ``tqdm.tqdm`` used when tqdm is not installed.

        Returns *iterable* unchanged. Any extra positional or keyword arguments
        (``desc``, ``total``, ``disable``, ``leave``, ...) are accepted and
        ignored, so call sites written against the real tqdm API keep working
        without it.
        """
        return iterable


# True when running as a frozen (e.g. PyInstaller) app whose stdout is missing
# or None, i.e. there is no console; the progress bars in this module pass it
# as ``disable=`` so nothing is drawn in that case.
_is_frozen_gui = bool(getattr(sys, "frozen", False)) and (
    not hasattr(sys, "stdout") or sys.stdout is None
)
|
||
|
||
class Goodman:
    """Goodman et al. sun-glint correction for band-aligned reflectance imagery.

    Every band is corrected independently as::

        R'(band) = max(0, R(band) - R_750 + A + B * (R_640 - R_750))

    where R_640 / R_750 are the bands selected by ``NIR_lower`` / ``NIR_upper``.
    When a water mask is given, pixels outside the mask keep their input value.
    """

    # Arrays with more elements than this are routed through the GDAL MEM path
    # (when GDAL is available); smaller arrays are processed with plain numpy.
    _GDAL_MEM_THRESHOLD = 100_000_000

    def __init__(self, im_aligned, NIR_lower=25, NIR_upper=37, A=0.000019, B=0.1,
                 use_gdal=True, chunk_size=None, water_mask=None, output_path=None):
        """Store the correction parameters and inspect the input image.

        :param im_aligned: band-aligned, calibrated and corrected reflectance
            image: a numpy array indexed as (rows, cols, bands), or a file path
            that GDAL can open.
        :param NIR_lower: index of the band closest to 640 nm (641.93 nm in the
            original configuration).
        :param NIR_upper: index of the band closest to 750 nm (751.49 nm).
        :param A: wavelength-independent offset; the default is the value from
            Goodman et al., derived for AVIRIS reflectance (not radiance) data.
        :param B: weight of (R_640 - R_750); default also from Goodman et al.
            The method corrects each pixel independently: the 750 nm value is
            subtracted from every band and an offset is added. How A and B were
            chosen is not documented; with in-situ data they could be fitted.
        :param use_gdal: use GDAL when it is installed (required for path input).
        :param chunk_size: deprecated and ignored; processing is band by band.
        :param water_mask: None (correct the whole image), a numpy array, a
            raster path (.dat/.tif) or a shapefile path (.shp). 1 = water,
            0 = non-water; only water pixels are corrected.
        :param output_path: if given, get_corrected_bands() also writes the
            result to this path as an ENVI/BSQ file.
        :raises ValueError: if a path is given without GDAL, the file cannot be
            opened, the array is not 3-D, or the water mask is invalid.
        """
        self.im_aligned = im_aligned
        self.NIR_lower = NIR_lower
        self.NIR_upper = NIR_upper
        self.A = A
        self.B = B
        self.use_gdal = use_gdal and GDAL_AVAILABLE
        self.chunk_size = chunk_size  # deprecated; kept only for API compatibility
        self.is_file_path = isinstance(im_aligned, str)
        self.output_path = output_path
        # Assigned before anything can raise so __del__ is always safe.
        self.dataset = None

        # Image dimensions are needed before the water mask can be validated.
        if self.is_file_path:
            if not self.use_gdal:
                raise ValueError("输入为文件路径时,必须安装GDAL")
            self.dataset = gdal.Open(im_aligned, gdal.GA_ReadOnly)
            if self.dataset is None:
                raise ValueError(f"无法打开影像文件: {im_aligned}")
            self.height = self.dataset.RasterYSize
            self.width = self.dataset.RasterXSize
            self.n_bands = self.dataset.RasterCount
        else:
            shape = getattr(im_aligned, "shape", ())
            if len(shape) != 3:
                raise ValueError(f"影像数组必须是三维的 (行, 列, 波段),实际形状: {shape}")
            self.height, self.width, self.n_bands = shape

        self.water_mask = self._load_water_mask(water_mask)

    def _load_water_mask(self, water_mask):
        """Load the water mask and normalise it to a 0/1 ``uint8`` array.

        :param water_mask: None, a numpy array, a raster path (.dat/.tif; band 1
            is used) or a shapefile path (.shp; rasterised onto the input grid).
        :return: None when *water_mask* is None, otherwise a (height, width)
            ``uint8`` array where 1 marks water and 0 marks non-water.
        :raises ValueError: on an unsupported type, a size mismatch, a missing
            GDAL installation, or a read/rasterisation failure.
        """
        if water_mask is None:
            return None

        if isinstance(water_mask, np.ndarray):
            # Accept a single-band (H, W, 1) mask; left 3-D it would broadcast
            # against the (H, W) bands in np.where and produce a wrong shape.
            if water_mask.ndim == 3 and water_mask.shape[2] == 1:
                water_mask = water_mask[:, :, 0]
            if water_mask.shape != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (water_mask > 0).astype(np.uint8)

        if not isinstance(water_mask, str):
            raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
        if not GDAL_AVAILABLE:
            raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")

        if water_mask.lower().endswith('.shp'):
            return self._rasterize_shapefile_mask(water_mask)
        return self._read_raster_mask(water_mask)

    def _rasterize_shapefile_mask(self, shp_path):
        """Burn the features of *shp_path* (value 1) onto the input image's grid."""
        if not self.is_file_path:
            raise ValueError("输入为numpy数组时,无法从shp文件创建掩膜(需要参考栅格)")
        ref_path = self.im_aligned

        try:
            from osgeo import ogr

            ref_dataset = gdal.Open(ref_path, gdal.GA_ReadOnly)
            if ref_dataset is None:
                raise ValueError(f"无法打开参考栅格文件: {ref_path}")

            # In-memory byte raster with the same size and georeferencing.
            mem_driver = gdal.GetDriverByName('MEM')
            mask_dataset = mem_driver.Create(
                '', ref_dataset.RasterXSize, ref_dataset.RasterYSize, 1, gdal.GDT_Byte
            )
            mask_dataset.SetGeoTransform(ref_dataset.GetGeoTransform())
            mask_dataset.SetProjection(ref_dataset.GetProjection())
            mask_band = mask_dataset.GetRasterBand(1)
            mask_band.Fill(0)

            shp_dataset = ogr.Open(shp_path)
            if shp_dataset is None:
                raise ValueError(f"无法打开shp文件: {shp_path}")

            err = gdal.RasterizeLayer(mask_dataset, [1], shp_dataset.GetLayer(), burn_values=[1])
            if err:
                raise ValueError(f"RasterizeLayer 返回错误码 {err}")

            water_mask_array = mask_band.ReadAsArray()
        except Exception as e:
            raise ValueError(f"从shp文件创建掩膜时出错: {e}") from e

        return (water_mask_array > 0).astype(np.uint8)

    def _read_raster_mask(self, mask_path):
        """Read band 1 of the raster at *mask_path* as a 0/1 mask."""
        mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
        if mask_dataset is None:
            raise ValueError(f"无法打开掩膜文件: {mask_path}")
        mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
        mask_dataset = None  # release the GDAL handle

        if mask_array.shape != (self.height, self.width):
            raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
        return (mask_array > 0).astype(np.uint8)

    def _correct_band(self, R, R_750, scaled_diff, water_mask_bool):
        """Apply the Goodman correction to one band.

        :param R: the band to correct (float array).
        :param R_750: the ~750 nm reference band, same dtype as *R*.
        :param scaled_diff: precomputed ``B * (R_640 - R_750)``.
        :param water_mask_bool: boolean (H, W) mask or None.
        :return: new array; negatives clipped to 0, non-water pixels taken from *R*.
        """
        # Same evaluation order as ``R - R_750 + A + B * diff``; only the
        # band-independent product is hoisted by the caller.
        corrected = R - R_750 + self.A + scaled_diff
        np.maximum(corrected, 0, out=corrected)
        if water_mask_bool is not None:
            corrected = np.where(water_mask_bool, corrected, R)
        return corrected

    def _get_corrected_bands_numpy(self):
        """Correct an in-memory numpy image band by band.

        Float input keeps its dtype. Integer input is promoted to float64 first,
        because subtracting unsigned bands in their own dtype wraps around.
        All corrected bands are accumulated in the returned list.

        :return: list of ``n_bands`` 2-D arrays.
        """
        src = self.im_aligned
        if np.issubdtype(src.dtype, np.floating):
            work_dtype = src.dtype
        else:
            work_dtype = np.float64

        # astype(copy=False) returns the view unchanged for float input.
        R_640 = src[:, :, self.NIR_lower].astype(work_dtype, copy=False)
        R_750 = src[:, :, self.NIR_upper].astype(work_dtype, copy=False)
        scaled_diff = self.B * (R_640 - R_750)  # identical for every band

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        corrected_bands = []
        for i in tqdm(range(self.n_bands), desc="处理波段 (numpy)", total=self.n_bands, disable=_is_frozen_gui):
            R = src[:, :, i].astype(work_dtype, copy=False)
            corrected_bands.append(self._correct_band(R, R_750, scaled_diff, water_mask_bool))
        return corrected_bands

    def _get_corrected_bands_gdal(self):
        """Correct every band of ``self.dataset``, reading one band at a time.

        Bands are read as float32. Only the reference bands and the current raw
        band are held as input, but every corrected band is accumulated in the
        returned list.

        :return: list of ``n_bands`` float32 2-D arrays.
        :raises IndexError: if NIR_lower / NIR_upper is not a valid band index.
        """
        for idx in (self.NIR_lower, self.NIR_upper):
            if not 0 <= idx < self.n_bands:
                raise IndexError(f"波段索引 {idx} 超出范围 [0, {self.n_bands - 1}]")

        # GDAL band numbers are 1-based.
        R_640 = self.dataset.GetRasterBand(self.NIR_lower + 1).ReadAsArray().astype(np.float32)
        R_750 = self.dataset.GetRasterBand(self.NIR_upper + 1).ReadAsArray().astype(np.float32)
        scaled_diff = self.B * (R_640 - R_750)
        del R_640  # only the scaled difference is needed from here on

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        corrected_bands = []
        for i in tqdm(range(self.n_bands), desc="处理波段 (GDAL)", total=self.n_bands, disable=_is_frozen_gui):
            R = self.dataset.GetRasterBand(i + 1).ReadAsArray().astype(np.float32)
            corrected_bands.append(self._correct_band(R, R_750, scaled_diff, water_mask_bool))
        return corrected_bands

    def _get_corrected_bands_gdal_mem(self):
        """Copy the numpy image into a GDAL MEM dataset and run the GDAL path.

        NOTE: this makes a full float32 copy of the input before processing,
        so it does not reduce peak memory compared with the numpy path.
        """
        driver = gdal.GetDriverByName('MEM')
        mem_dataset = driver.Create('', self.width, self.height, self.n_bands, gdal.GDT_Float32)
        if mem_dataset is None:
            raise ValueError("无法创建GDAL内存数据集")

        for i in tqdm(range(self.n_bands), desc="加载波段到内存", total=self.n_bands, disable=_is_frozen_gui):
            band = mem_dataset.GetRasterBand(i + 1)
            band.WriteArray(self.im_aligned[:, :, i])
            band.FlushCache()

        # Temporarily point self.dataset at the MEM copy so the GDAL path can
        # be reused; the original reference is always restored.
        original_dataset = self.dataset
        self.dataset = mem_dataset
        try:
            return self._get_corrected_bands_gdal()
        finally:
            self.dataset = original_dataset
            mem_dataset = None

    @staticmethod
    def _envi_header_paths(data_path):
        """Return the header paths GDAL's ENVI driver may create for *data_path*.

        By default the driver replaces the extension (``a.bsq`` -> ``a.hdr``);
        with ``SUFFIX=ADD`` it appends (``a.bsq.hdr``). Most likely first.
        """
        import os

        base, _ = os.path.splitext(data_path)
        return [base + '.hdr', data_path + '.hdr']

    def _save_corrected_bands(self, corrected_bands):
        """Write *corrected_bands* to ``self.output_path`` as an ENVI (BSQ) file.

        Bands are written one at a time straight from the list, without first
        stacking them into a single array. The extension is forced to ``.bsq``.
        Does nothing when ``output_path`` is None.

        :param corrected_bands: list of 2-D arrays of identical shape.
        :raises ImportError: if GDAL is not installed.
        :raises ValueError: if the list is empty or the file cannot be created.
        """
        if self.output_path is None:
            return
        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")
        if not corrected_bands:
            raise ValueError("校正后的波段列表为空")

        import os

        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        height, width = corrected_bands[0].shape
        n_bands = len(corrected_bands)

        if self.is_file_path and self.dataset is not None:
            geotransform = self.dataset.GetGeoTransform()
            projection = self.dataset.GetProjection()
        else:
            # Array input carries no georeferencing: use a unit pixel grid.
            geotransform = (0, 1, 0, 0, 0, -1)
            projection = ""

        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = self.output_path if ext.lower() == '.bsq' else base_path + '.bsq'

        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32,
                                options=['INTERLEAVE=BSQ'])
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        try:
            if geotransform:
                dataset.SetGeoTransform(geotransform)
            if projection:
                dataset.SetProjection(projection)

            progress = tqdm(corrected_bands, desc="保存波段", total=n_bands, disable=_is_frozen_gui)
            for band_number, band_data in enumerate(progress, start=1):
                band = dataset.GetRasterBand(band_number)
                band.WriteArray(band_data)
                band.FlushCache()
        finally:
            # Dropping the last reference closes the dataset in GDAL's Python API.
            dataset = None

        print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
        hdr_path = next((p for p in self._envi_header_paths(bsq_path) if os.path.exists(p)), None)
        if hdr_path is not None:
            print(f"头文件已保存至: {hdr_path}")
        else:
            print("警告: 未检测到.hdr文件,但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """Return the glint-corrected bands, choosing the processing path by input.

        Path input is read through GDAL. Array input above ``_GDAL_MEM_THRESHOLD``
        elements goes through a GDAL MEM copy when GDAL is enabled; otherwise it
        is processed with numpy. If ``output_path`` is set, the result is saved.

        :return: list of ``n_bands`` corrected 2-D arrays.
        """
        if self.is_file_path:
            if not self.use_gdal:
                raise ValueError("输入为文件路径时,必须安装GDAL")
            corrected_bands = self._get_corrected_bands_gdal()
        elif self.use_gdal and self.height * self.width * self.n_bands > self._GDAL_MEM_THRESHOLD:
            corrected_bands = self._get_corrected_bands_gdal_mem()
        else:
            corrected_bands = self._get_corrected_bands_numpy()

        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)

        return corrected_bands

    def __del__(self):
        """Drop the reference to the GDAL dataset opened for path input."""
        # getattr: __init__ may have raised before these attributes were set.
        if getattr(self, "is_file_path", False) and getattr(self, "dataset", None) is not None:
            self.dataset = None