291 lines
12 KiB
Python
291 lines
12 KiB
Python
import numpy as np
|
||
# import preprocessing
|
||
import os
|
||
|
||
try:
|
||
from osgeo import gdal
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
|
||
class Hedley:
    """Regression-based sun-glint removal for a multi-band reflectance image.

    For every band ``i`` the slope ``b_i`` of a linear regression of that band
    against the NIR band is estimated over a sample region, and the band is
    corrected as ``R_i - b_i * (R_NIR - R_min)``, where ``R_min`` is the 5th
    percentile of the NIR band.  This matches the deglinting scheme usually
    attributed to Hedley et al. (2005).
    """

    def __init__(self, im_aligned, shp_path=None, NIR_band=47, water_mask=None, output_path=None):
        """
        :param im_aligned (np.ndarray): band aligned, calibrated and corrected
            reflectance image, indexed as (row, column, band)
        :param shp_path (str, optional): path to a shapefile (.shp) whose total
            bounding box delimits the deep-water glint sample region.
            If None/empty, the entire image is used.  NOTE: the bounds are
            interpreted directly as pixel column/row indices; no geotransform
            is applied, so a shapefile in geographic coordinates will not map
            onto the image correctly.
        :param NIR_band (int): index of the NIR band along the last axis.  The
            original author documents the default (47) as the 842.36 nm band.
            Anything accepted by ``int(float(...))`` is allowed.
        :param water_mask (np.ndarray or str or None): water mask, values > 0
            mean water and everything else means non-water.  May be a numpy
            array or the path of a raster file readable by GDAL (.dat/.tif).
            Shapefile paths are rejected.  If None the whole image is processed.
        :param output_path (str or None): if given, the corrected image is
            written there by ``get_corrected_bands``; if None nothing is saved.
        """
        self.im_aligned = im_aligned
        self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
        self.NIR_band = int(float(NIR_band))
        self.n_bands = im_aligned.shape[-1]
        self.height = im_aligned.shape[0]
        self.width = im_aligned.shape[1]
        self.output_path = output_path

        # Load (and normalise to a 2-D 0/1 array) the water mask.
        self.water_mask = self._load_water_mask(water_mask)

        # R_min is the 5th percentile of the NIR band, restricted to the water
        # mask when one is available.  An all-zero mask yields R_min = 0.
        nir_band = self.im_aligned[:, :, self.NIR_band]
        if self.water_mask is not None:
            nir_band_masked = nir_band[self.water_mask.astype(bool)]
            if nir_band_masked.size > 0:
                self.R_min = np.percentile(nir_band_masked, 5, method='nearest')
            else:
                self.R_min = 0
        else:
            # ravel() avoids the copy flatten() would always make.
            self.R_min = np.percentile(nir_band.ravel(), 5, method='nearest')

    def _read_shp_to_bbox(self, shp_path):
        """
        Read a shapefile and return its total bounding box clipped to the image.

        geopandas is tried first; if it cannot be imported, fiona + shapely are
        used instead.  Bounds are truncated to integers and used as pixel
        indices as-is.

        :param shp_path (str): path of the shapefile
        :return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner
            (inclusive) and x2,y2 is the lower right corner (exclusive)
        :raises FileNotFoundError: if ``shp_path`` does not exist
        :raises ValueError: if the file cannot be read, contains no usable
            geometry, or its bounding box does not overlap the image
        """
        if not os.path.exists(shp_path):
            raise FileNotFoundError(f"Shapefile not found: {shp_path}")

        try:
            try:
                import geopandas as gpd
                gdf = gpd.read_file(shp_path)
                # total_bounds is [minx, miny, maxx, maxy] over all geometries.
                min_x, min_y, max_x, max_y = gdf.total_bounds
            except ImportError:
                # geopandas unavailable: fall back to fiona + shapely.
                import fiona
                from shapely.geometry import shape

                min_x = float('inf')
                min_y = float('inf')
                max_x = float('-inf')
                max_y = float('-inf')

                with fiona.open(shp_path) as shp:
                    for feature in shp:
                        geometry = feature['geometry']
                        # Records without a geometry would make shape() fail.
                        if geometry is None:
                            continue
                        geom = shape(geometry)
                        if geom.is_empty:
                            continue
                        bounds = geom.bounds
                        min_x = min(min_x, bounds[0])
                        min_y = min(min_y, bounds[1])
                        max_x = max(max_x, bounds[2])
                        max_y = max(max_y, bounds[3])

            # An empty layer leaves inf/NaN bounds; report that explicitly
            # instead of failing inside int().
            if not all(np.isfinite(v) for v in (min_x, min_y, max_x, max_y)):
                raise ValueError("shapefile contains no geometry with finite bounds")

            # Convert to integer pixel coordinates clipped to the image.
            x1 = max(0, int(min_x))
            y1 = max(0, int(min_y))
            x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
            y2 = min(self.im_aligned.shape[0], int(max_y) + 1)

            # A box outside the image would otherwise give an empty slice, or
            # (for negative x2/y2) a slice counted from the end of the array.
            if x2 <= x1 or y2 <= y1:
                raise ValueError(
                    f"bounding box ({min_x}, {min_y}, {max_x}, {max_y}) does not overlap "
                    f"the image of size {self.im_aligned.shape[1]}x{self.im_aligned.shape[0]} "
                    f"(bounds are interpreted as pixel coordinates)"
                )

            return ((x1, y1), (x2, y2))

        except Exception as e:
            raise ValueError(f"Error reading shapefile {shp_path}: {e}") from e

    def _load_water_mask(self, water_mask):
        """
        Load the water mask and normalise it to a 2-D uint8 array of 0/1.

        :param water_mask: None, a numpy array of shape (H, W) or (H, W, 1),
            or the path of a raster file (.dat/.tif) whose first band is used
        :return: numpy uint8 array (1 = water, 0 = non-water) or None
        :raises ValueError: on a shape mismatch, a shapefile path, a missing
            GDAL installation, an unreadable raster, or an unsupported type
        """
        if water_mask is None:
            return None

        # Already a numpy array.
        if isinstance(water_mask, np.ndarray):
            mask = water_mask
            # Accept a single-band (H, W, 1) mask; a 3-D boolean index would
            # otherwise fail later when applied to a 2-D band.
            if mask.ndim == 3 and mask.shape[2] == 1:
                mask = mask[:, :, 0]
            if mask.shape != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {water_mask.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
            return (mask > 0).astype(np.uint8)  # force a 0/1 mask

        # A file path.
        if isinstance(water_mask, str):
            # Shapefiles cannot be rasterised here: the image is a bare numpy
            # array, so there is no raster reference to burn the shapes into.
            # Checked before importing GDAL so the real reason is reported.
            if water_mask.lower().endswith('.shp'):
                raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")

            try:
                from osgeo import gdal
            except ImportError as e:
                raise ValueError("使用文件路径作为掩膜时,必须安装GDAL") from e

            # Raster file: only the first band is read.
            mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
            if mask_dataset is None:
                raise ValueError(f"无法打开掩膜文件: {water_mask}")

            mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None  # dropping the reference closes the dataset

            if mask_array is None:
                raise ValueError(f"无法打开掩膜文件: {water_mask}")

            if mask_array.shape != (self.height, self.width):
                raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")

            return (mask_array > 0).astype(np.uint8)

        raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")

    def covariance_NIR(self, NIR, b):
        """
        Slope of the least-squares regression of band ``b`` on ``NIR``.

        Despite the name, the value returned is cov(NIR, b) / var(NIR), not
        the covariance itself.

        :param NIR: 1-D vector of NIR reflectances
        :param b: 1-D vector of reflectances of band i, same length as NIR
        :return: the slope, or 0.0 when NIR is empty or has zero variance
        """
        NIR = np.asarray(NIR)
        b = np.asarray(b)
        if NIR.size == 0:
            return 0.0

        nir_centered = NIR - np.mean(NIR)
        pjj = np.mean(nir_centered ** 2)
        # Avoid dividing by zero for a constant NIR sample.
        if pjj == 0:
            return 0.0
        pij = np.mean(nir_centered * (b - np.mean(b)))
        return pij / pjj

    def correlation_bands_reflectance(self):
        """
        Compute, for every band, the regression slope against the NIR band
        (``self.NIR_band``) over the sample region.

        The sample region is the bounding box if one was given, otherwise the
        whole image.  If a water mask exists, only water pixels inside the
        region are used; when the region contains no water pixel at all, every
        pixel of the region is used instead.

        :return: list with one slope per band (see ``covariance_NIR``)
        """
        if self.bbox is None:
            im_region = self.im_aligned
            mask_region = self.water_mask
        else:
            (x1, y1), (x2, y2) = self.bbox
            im_region = self.im_aligned[y1:y2, x1:x2, :]
            mask_region = self.water_mask[y1:y2, x1:x2] if self.water_mask is not None else None

        # Decide once, outside the band loop, whether the mask is usable.
        selector = None
        if mask_region is not None:
            mask_bool = mask_region.astype(bool)
            if mask_bool.any():
                selector = mask_bool

        def sample(band_index):
            # 1-D vector of the sampled pixels of one band.
            band = im_region[:, :, band_index]
            return band[selector] if selector is not None else band.ravel()

        NIR_reflectance = sample(self.NIR_band)
        return [self.covariance_NIR(NIR_reflectance, sample(v)) for v in range(self.n_bands)]

    def _save_corrected_bands(self, corrected_bands):
        """
        Write the corrected bands to ``self.output_path`` with GDAL's ENVI
        driver as Float32 (the driver's default interleave is BSQ).

        The output file name is forced to end in ``.bsq``.  Does nothing when
        ``self.output_path`` is None.

        :param corrected_bands: list of 2-D arrays, one per band
        :raises ImportError: if GDAL is not installed
        :raises ValueError: if the ENVI driver or the output file is unavailable
        """
        if self.output_path is None:
            return

        if not GDAL_AVAILABLE:
            raise ImportError("GDAL未安装,无法保存影像文件")

        # Make sure the output directory exists.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Stack the band list into a (row, column, band) array.
        corrected_array = np.stack(corrected_bands, axis=2)

        # Force the .bsq extension.
        base_path, ext = os.path.splitext(self.output_path)
        bsq_path = self.output_path if ext.lower() == '.bsq' else base_path + '.bsq'

        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")

        height, width, n_bands = corrected_array.shape
        # Creating the dataset also makes GDAL write the ENVI header file.
        dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
        if dataset is None:
            raise ValueError(f"无法创建输出文件: {bsq_path}")

        band = None
        try:
            # The input carries no georeferencing, so a unit pixel grid with a
            # north-up orientation is written; no projection is set.
            dataset.SetGeoTransform((0, 1, 0, 0, 0, -1))

            for i in range(n_bands):
                band = dataset.GetRasterBand(i + 1)
                band.WriteArray(corrected_array[:, :, i])
                band.FlushCache()
        finally:
            # Dropping the references flushes and closes the file.
            band = None
            dataset = None

        # By default the ENVI driver replaces the extension with .hdr
        # (SUFFIX=REPLACE); with SUFFIX=ADD it appends .hdr.  Check both.
        hdr_candidates = (base_path + '.hdr', bsq_path + '.hdr')
        hdr_path = next((p for p in hdr_candidates if os.path.exists(p)), None)

        print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
        if hdr_path is not None:
            print(f"头文件已保存至: {hdr_path}")
        else:
            print("警告: 未检测到.hdr文件,但GDAL应该已自动创建")

    def get_corrected_bands(self):
        """
        Apply the glint correction in reflectance to every band.

        Each band is corrected as ``R - b * (NIR - R_min)``.  When a water mask
        exists, only water pixels are corrected and all other pixels keep their
        original value.  If ``self.output_path`` is set the result is also
        written to disk.

        :return: list of corrected 2-D bands, one per input band
        """
        corr = self.correlation_bands_reflectance()
        NIR_reflectance = self.im_aligned[:, :, self.NIR_band]
        # NIR - R_min is the same for every band, so compute it once.
        NIR_diff = NIR_reflectance - self.R_min

        water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None

        corrected_bands = []
        for band_number in range(self.n_bands):
            b = corr[band_number]
            R = self.im_aligned[:, :, band_number]
            corrected_band = R - b * NIR_diff

            # Restrict the correction to water pixels when a mask is present.
            if water_mask_bool is not None:
                corrected_band = np.where(water_mask_bool, corrected_band, R)

            corrected_bands.append(corrected_band)

        if self.output_path is not None:
            self._save_corrected_bands(corrected_bands)

        return corrected_bands
|