309 lines
9.2 KiB
Python
309 lines
9.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
GDAL 底层 IO 工具模块
|
||
|
||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||
"""
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Tuple, Optional
|
||
|
||
import numpy as np
|
||
|
||
# GDAL 导入(可选)
|
||
try:
|
||
from osgeo import gdal, ogr, gdal_array
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
|
||
# hdr 文件工具
|
||
try:
|
||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||
UTIL_AVAILABLE = True
|
||
except ImportError:
|
||
UTIL_AVAILABLE = False
|
||
write_fields_to_hdrfile = None
|
||
get_hdr_file_path = None
|
||
|
||
|
||
# ============================================================
|
||
# 影像信息读取
|
||
# ============================================================
|
||
|
||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||
"""
|
||
获取影像的地理信息(不加载图像数据,节省内存)
|
||
|
||
Args:
|
||
img_path: 影像文件路径
|
||
|
||
Returns:
|
||
tuple: (geotransform, projection, width, height, n_bands)
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||
|
||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||
|
||
try:
|
||
width = dataset.RasterXSize
|
||
height = dataset.RasterYSize
|
||
n_bands = dataset.RasterCount
|
||
geotransform = dataset.GetGeoTransform()
|
||
projection = dataset.GetProjection()
|
||
return geotransform, projection, width, height, n_bands
|
||
finally:
|
||
dataset = None
|
||
|
||
|
||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||
"""
|
||
加载影像文件为numpy数组
|
||
|
||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||
|
||
Args:
|
||
img_path: 影像文件路径
|
||
|
||
Returns:
|
||
tuple: (image_array, geotransform, projection)
|
||
image_array: numpy数组,形状为(height, width, bands)
|
||
geotransform: 地理变换参数
|
||
projection: 投影信息
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||
|
||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||
|
||
try:
|
||
width = dataset.RasterXSize
|
||
height = dataset.RasterYSize
|
||
n_bands = dataset.RasterCount
|
||
geotransform = dataset.GetGeoTransform()
|
||
projection = dataset.GetProjection()
|
||
|
||
image_bands = []
|
||
for i in range(1, n_bands + 1):
|
||
band = dataset.GetRasterBand(i)
|
||
band_data = band.ReadAsArray()
|
||
image_bands.append(band_data)
|
||
|
||
image_array = np.dstack(image_bands)
|
||
return image_array, geotransform, projection
|
||
finally:
|
||
dataset = None
|
||
|
||
|
||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||
"""
|
||
读取单个波段为 numpy 数组
|
||
|
||
Args:
|
||
img_path: 影像文件路径
|
||
band_index: 波段索引(从 0 开始)
|
||
|
||
Returns:
|
||
numpy 数组,形状为 (height, width)
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||
|
||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||
|
||
try:
|
||
band = dataset.GetRasterBand(band_index + 1)
|
||
return band.ReadAsArray()
|
||
finally:
|
||
dataset = None
|
||
|
||
|
||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||
"""
|
||
读取多个指定波段为列表
|
||
|
||
Args:
|
||
img_path: 影像文件路径
|
||
band_indices: 波段索引列表
|
||
|
||
Returns:
|
||
tuple: (band_list, geotransform, projection)
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||
|
||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||
if dataset is None:
|
||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||
|
||
try:
|
||
geotransform = dataset.GetGeoTransform()
|
||
projection = dataset.GetProjection()
|
||
bands = []
|
||
for idx in band_indices:
|
||
band = dataset.GetRasterBand(idx + 1)
|
||
bands.append(band.ReadAsArray())
|
||
return bands, geotransform, projection
|
||
finally:
|
||
dataset = None
|
||
|
||
|
||
# ============================================================
|
||
# 影像写入
|
||
# ============================================================
|
||
|
||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||
geotransform: tuple, projection: str,
|
||
dtype=None) -> str:
|
||
"""
|
||
将numpy数组保存为影像文件
|
||
|
||
Args:
|
||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||
output_path: 输出文件路径
|
||
geotransform: 地理变换参数
|
||
projection: 投影信息
|
||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||
|
||
Returns:
|
||
输出文件路径
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||
|
||
if dtype is None:
|
||
dtype = gdal.GDT_Float32
|
||
|
||
if image_array.ndim == 2:
|
||
height, width = image_array.shape
|
||
n_bands = 1
|
||
else:
|
||
height, width, n_bands = image_array.shape
|
||
|
||
driver = gdal.GetDriverByName('ENVI')
|
||
if driver is None:
|
||
driver = gdal.GetDriverByName('GTiff')
|
||
|
||
if driver is None:
|
||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||
|
||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||
if dataset is None:
|
||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||
|
||
try:
|
||
dataset.SetGeoTransform(geotransform)
|
||
dataset.SetProjection(projection)
|
||
|
||
if n_bands == 1:
|
||
band = dataset.GetRasterBand(1)
|
||
band.WriteArray(image_array)
|
||
band.FlushCache()
|
||
else:
|
||
for i in range(n_bands):
|
||
band = dataset.GetRasterBand(i + 1)
|
||
band.WriteArray(image_array[:, :, i])
|
||
band.FlushCache()
|
||
finally:
|
||
dataset = None
|
||
|
||
return output_path
|
||
|
||
|
||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||
geotransform: tuple, projection: str,
|
||
dtype=None) -> str:
|
||
"""
|
||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||
|
||
Args:
|
||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||
output_path: 输出文件路径
|
||
geotransform: 地理变换参数
|
||
projection: 投影信息
|
||
dtype: GDAL数据类型
|
||
|
||
Returns:
|
||
输出文件路径
|
||
"""
|
||
if not GDAL_AVAILABLE:
|
||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||
|
||
if not corrected_bands:
|
||
raise ValueError("波段列表为空")
|
||
|
||
if dtype is None:
|
||
dtype = gdal.GDT_Float32
|
||
|
||
n_bands = len(corrected_bands)
|
||
height, width = corrected_bands[0].shape
|
||
|
||
driver = gdal.GetDriverByName('ENVI')
|
||
if driver is None:
|
||
driver = gdal.GetDriverByName('GTiff')
|
||
|
||
if driver is None:
|
||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||
|
||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||
if dataset is None:
|
||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||
|
||
try:
|
||
dataset.SetGeoTransform(geotransform)
|
||
dataset.SetProjection(projection)
|
||
|
||
for i, band_array in enumerate(corrected_bands):
|
||
if band_array.shape != (height, width):
|
||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||
band = dataset.GetRasterBand(i + 1)
|
||
band.WriteArray(band_array)
|
||
band.FlushCache()
|
||
finally:
|
||
dataset = None
|
||
|
||
return output_path
|
||
|
||
|
||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||
"""
|
||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||
|
||
Args:
|
||
source_img_path: 源影像文件路径(原始bsq文件)
|
||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||
|
||
Returns:
|
||
bool: 是否成功
|
||
"""
|
||
if not UTIL_AVAILABLE:
|
||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||
return False
|
||
|
||
try:
|
||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||
|
||
if not Path(source_hdr_path).exists():
|
||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||
return False
|
||
|
||
if not Path(dest_hdr_path).exists():
|
||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||
return False
|
||
|
||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||
return True
|
||
except Exception as e:
|
||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||
return False |