# -*- coding: utf-8 -*- """ GDAL 底层 IO 工具模块 提供遥感影像读写、格式转换等底层 GDAL 操作功能。 这些函数不依赖任何业务逻辑,可在其他项目中独立复用。 """ import os from pathlib import Path from typing import Tuple, Optional import numpy as np # GDAL 导入(可选) try: from osgeo import gdal, ogr, gdal_array GDAL_AVAILABLE = True except ImportError: GDAL_AVAILABLE = False # hdr 文件工具 try: from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path UTIL_AVAILABLE = True except ImportError: UTIL_AVAILABLE = False write_fields_to_hdrfile = None get_hdr_file_path = None # ============================================================ # 影像信息读取 # ============================================================ def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]: """ 获取影像的地理信息(不加载图像数据,节省内存) Args: img_path: 影像文件路径 Returns: tuple: (geotransform, projection, width, height, n_bands) """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") dataset = gdal.Open(img_path, gdal.GA_ReadOnly) if dataset is None: raise ValueError(f"无法打开影像文件: {img_path}") try: width = dataset.RasterXSize height = dataset.RasterYSize n_bands = dataset.RasterCount geotransform = dataset.GetGeoTransform() projection = dataset.GetProjection() return geotransform, projection, width, height, n_bands finally: dataset = None def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]: """ 加载影像文件为numpy数组 注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。 建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。 Args: img_path: 影像文件路径 Returns: tuple: (image_array, geotransform, projection) image_array: numpy数组,形状为(height, width, bands) geotransform: 地理变换参数 projection: 投影信息 """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") dataset = gdal.Open(img_path, gdal.GA_ReadOnly) if dataset is None: raise ValueError(f"无法打开影像文件: {img_path}") try: width = dataset.RasterXSize height = dataset.RasterYSize n_bands = dataset.RasterCount geotransform = dataset.GetGeoTransform() projection = dataset.GetProjection() image_bands = [] for i in range(1, n_bands + 1): band = dataset.GetRasterBand(i) band_data = band.ReadAsArray() image_bands.append(band_data) image_array = np.dstack(image_bands) return image_array, geotransform, projection finally: dataset = None def read_band_as_array(img_path: str, band_index: int) -> np.ndarray: """ 读取单个波段为 numpy 数组 Args: img_path: 影像文件路径 band_index: 波段索引(从 0 开始) Returns: numpy 数组,形状为 (height, width) """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") dataset = gdal.Open(img_path, gdal.GA_ReadOnly) if dataset is None: raise ValueError(f"无法打开影像文件: {img_path}") try: band = dataset.GetRasterBand(band_index + 1) return band.ReadAsArray() finally: dataset = None def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]: """ 读取多个指定波段为列表 Args: img_path: 影像文件路径 band_indices: 波段索引列表 Returns: tuple: (band_list, geotransform, projection) """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取影像文件") dataset = gdal.Open(img_path, gdal.GA_ReadOnly) if dataset is None: raise ValueError(f"无法打开影像文件: {img_path}") try: geotransform = dataset.GetGeoTransform() projection = dataset.GetProjection() bands = [] for idx in band_indices: band = dataset.GetRasterBand(idx + 1) bands.append(band.ReadAsArray()) return bands, geotransform, projection finally: dataset = None # ============================================================ # 影像写入 # ============================================================ def save_array_as_image(image_array: np.ndarray, output_path: str, geotransform: tuple, projection: str, dtype=None) -> str: """ 将numpy数组保存为影像文件 Args: image_array: numpy数组,形状为(height, width, bands) 或 (height, width) output_path: 输出文件路径 geotransform: 地理变换参数 projection: 投影信息 dtype: GDAL数据类型(默认 gdal.GDT_Float32) Returns: 输出文件路径 """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法保存影像文件") if dtype is None: dtype = gdal.GDT_Float32 if image_array.ndim == 2: height, width = image_array.shape n_bands = 1 else: height, width, n_bands = image_array.shape driver = gdal.GetDriverByName('ENVI') if driver is None: driver = gdal.GetDriverByName('GTiff') if driver is None: raise ValueError("无法创建影像文件,没有可用的驱动") dataset = driver.Create(output_path, width, height, n_bands, dtype) if dataset is None: raise ValueError(f"无法创建输出文件: {output_path}") try: dataset.SetGeoTransform(geotransform) dataset.SetProjection(projection) if n_bands == 1: band = dataset.GetRasterBand(1) band.WriteArray(image_array) band.FlushCache() else: for i in range(n_bands): band = dataset.GetRasterBand(i + 1) band.WriteArray(image_array[:, :, i]) band.FlushCache() finally: dataset = None return output_path def save_bands_as_image(corrected_bands: list, output_path: str, geotransform: tuple, projection: str, dtype=None) -> str: """ 直接从波段列表保存影像文件(避免堆叠,节省内存) Args: corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组 output_path: 输出文件路径 geotransform: 地理变换参数 projection: 投影信息 dtype: GDAL数据类型 Returns: 输出文件路径 """ if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法保存影像文件") if not corrected_bands: raise ValueError("波段列表为空") if dtype is None: dtype = gdal.GDT_Float32 n_bands = len(corrected_bands) height, width = corrected_bands[0].shape driver = gdal.GetDriverByName('ENVI') if driver is None: driver = gdal.GetDriverByName('GTiff') if driver is None: raise ValueError("无法创建影像文件,没有可用的驱动") dataset = driver.Create(output_path, width, height, n_bands, dtype) if dataset is None: raise ValueError(f"无法创建输出文件: {output_path}") try: dataset.SetGeoTransform(geotransform) dataset.SetProjection(projection) for i, band_array in enumerate(corrected_bands): if band_array.shape != (height, width): raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配") band = dataset.GetRasterBand(i + 1) band.WriteArray(band_array) band.FlushCache() finally: dataset = None return output_path def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool: """ 复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件 Args: source_img_path: 源影像文件路径(原始bsq文件) dest_img_path: 目标影像文件路径(去耀斑后的bsq文件) Returns: bool: 是否成功 """ if not UTIL_AVAILABLE: print("警告: util模块未导入,无法复制hdr文件信息") return False try: source_hdr_path = get_hdr_file_path(source_img_path) dest_hdr_path = get_hdr_file_path(dest_img_path) if not Path(source_hdr_path).exists(): print(f"警告: 源hdr文件不存在: {source_hdr_path}") return False if not Path(dest_hdr_path).exists(): print(f"警告: 目标hdr文件不存在: {dest_hdr_path}") return False write_fields_to_hdrfile(source_hdr_path, dest_hdr_path) print(f"已复制原始hdr文件信息到: {dest_hdr_path}") return True except Exception as e: print(f"警告: 复制hdr文件信息时出错: {e}") return False