Files
WQ_GUI/src/core/utils/gdal_helper.py

309 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
GDAL 底层 IO 工具模块
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
"""
import os
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
# GDAL 导入(可选)
try:
from osgeo import gdal, ogr, gdal_array
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# hdr 文件工具
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
UTIL_AVAILABLE = True
except ImportError:
UTIL_AVAILABLE = False
write_fields_to_hdrfile = None
get_hdr_file_path = None
# ============================================================
# 影像信息读取
# ============================================================
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
"""
获取影像的地理信息(不加载图像数据,节省内存)
Args:
img_path: 影像文件路径
Returns:
tuple: (geotransform, projection, width, height, n_bands)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
return geotransform, projection, width, height, n_bands
finally:
dataset = None
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
"""
加载影像文件为numpy数组
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
建议直接传递文件路径给算法类让算法类使用GDAL逐波段处理。
Args:
img_path: 影像文件路径
Returns:
tuple: (image_array, geotransform, projection)
image_array: numpy数组形状为(height, width, bands)
geotransform: 地理变换参数
projection: 投影信息
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
image_bands = []
for i in range(1, n_bands + 1):
band = dataset.GetRasterBand(i)
band_data = band.ReadAsArray()
image_bands.append(band_data)
image_array = np.dstack(image_bands)
return image_array, geotransform, projection
finally:
dataset = None
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
"""
读取单个波段为 numpy 数组
Args:
img_path: 影像文件路径
band_index: 波段索引(从 0 开始)
Returns:
numpy 数组,形状为 (height, width)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
band = dataset.GetRasterBand(band_index + 1)
return band.ReadAsArray()
finally:
dataset = None
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
"""
读取多个指定波段为列表
Args:
img_path: 影像文件路径
band_indices: 波段索引列表
Returns:
tuple: (band_list, geotransform, projection)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
bands = []
for idx in band_indices:
band = dataset.GetRasterBand(idx + 1)
bands.append(band.ReadAsArray())
return bands, geotransform, projection
finally:
dataset = None
# ============================================================
# 影像写入
# ============================================================
def save_array_as_image(image_array: np.ndarray, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
将numpy数组保存为影像文件
Args:
image_array: numpy数组形状为(height, width, bands) 或 (height, width)
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型默认 gdal.GDT_Float32
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if dtype is None:
dtype = gdal.GDT_Float32
if image_array.ndim == 2:
height, width = image_array.shape
n_bands = 1
else:
height, width, n_bands = image_array.shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
if n_bands == 1:
band = dataset.GetRasterBand(1)
band.WriteArray(image_array)
band.FlushCache()
else:
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(image_array[:, :, i])
band.FlushCache()
finally:
dataset = None
return output_path
def save_bands_as_image(corrected_bands: list, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
直接从波段列表保存影像文件(避免堆叠,节省内存)
Args:
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if not corrected_bands:
raise ValueError("波段列表为空")
if dtype is None:
dtype = gdal.GDT_Float32
n_bands = len(corrected_bands)
height, width = corrected_bands[0].shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
for i, band_array in enumerate(corrected_bands):
if band_array.shape != (height, width):
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
band = dataset.GetRasterBand(i + 1)
band.WriteArray(band_array)
band.FlushCache()
finally:
dataset = None
return output_path
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
"""
复制原始影像的hdr文件信息如波长等到目标影像的hdr文件
Args:
source_img_path: 源影像文件路径原始bsq文件
dest_img_path: 目标影像文件路径去耀斑后的bsq文件
Returns:
bool: 是否成功
"""
if not UTIL_AVAILABLE:
print("警告: util模块未导入无法复制hdr文件信息")
return False
try:
source_hdr_path = get_hdr_file_path(source_img_path)
dest_hdr_path = get_hdr_file_path(dest_img_path)
if not Path(source_hdr_path).exists():
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
return False
if not Path(dest_hdr_path).exists():
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
return False
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
return True
except Exception as e:
print(f"警告: 复制hdr文件信息时出错: {e}")
return False