refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""
GDAL 底层 IO 工具模块
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
"""
import os
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
# GDAL 导入(可选)
try:
from osgeo import gdal, ogr, gdal_array
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# hdr 文件工具
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
UTIL_AVAILABLE = True
except ImportError:
UTIL_AVAILABLE = False
write_fields_to_hdrfile = None
get_hdr_file_path = None
# ============================================================
# 影像信息读取
# ============================================================
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
"""
获取影像的地理信息(不加载图像数据,节省内存)
Args:
img_path: 影像文件路径
Returns:
tuple: (geotransform, projection, width, height, n_bands)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
return geotransform, projection, width, height, n_bands
finally:
dataset = None
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
"""
加载影像文件为numpy数组
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
建议直接传递文件路径给算法类让算法类使用GDAL逐波段处理。
Args:
img_path: 影像文件路径
Returns:
tuple: (image_array, geotransform, projection)
image_array: numpy数组形状为(height, width, bands)
geotransform: 地理变换参数
projection: 投影信息
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
image_bands = []
for i in range(1, n_bands + 1):
band = dataset.GetRasterBand(i)
band_data = band.ReadAsArray()
image_bands.append(band_data)
image_array = np.dstack(image_bands)
return image_array, geotransform, projection
finally:
dataset = None
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
"""
读取单个波段为 numpy 数组
Args:
img_path: 影像文件路径
band_index: 波段索引(从 0 开始)
Returns:
numpy 数组,形状为 (height, width)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
band = dataset.GetRasterBand(band_index + 1)
return band.ReadAsArray()
finally:
dataset = None
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
"""
读取多个指定波段为列表
Args:
img_path: 影像文件路径
band_indices: 波段索引列表
Returns:
tuple: (band_list, geotransform, projection)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
bands = []
for idx in band_indices:
band = dataset.GetRasterBand(idx + 1)
bands.append(band.ReadAsArray())
return bands, geotransform, projection
finally:
dataset = None
# ============================================================
# 影像写入
# ============================================================
def save_array_as_image(image_array: np.ndarray, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
将numpy数组保存为影像文件
Args:
image_array: numpy数组形状为(height, width, bands) 或 (height, width)
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型默认 gdal.GDT_Float32
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if dtype is None:
dtype = gdal.GDT_Float32
if image_array.ndim == 2:
height, width = image_array.shape
n_bands = 1
else:
height, width, n_bands = image_array.shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
if n_bands == 1:
band = dataset.GetRasterBand(1)
band.WriteArray(image_array)
band.FlushCache()
else:
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(image_array[:, :, i])
band.FlushCache()
finally:
dataset = None
return output_path
def save_bands_as_image(corrected_bands: list, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
直接从波段列表保存影像文件(避免堆叠,节省内存)
Args:
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if not corrected_bands:
raise ValueError("波段列表为空")
if dtype is None:
dtype = gdal.GDT_Float32
n_bands = len(corrected_bands)
height, width = corrected_bands[0].shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
for i, band_array in enumerate(corrected_bands):
if band_array.shape != (height, width):
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
band = dataset.GetRasterBand(i + 1)
band.WriteArray(band_array)
band.FlushCache()
finally:
dataset = None
return output_path
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
"""
复制原始影像的hdr文件信息如波长等到目标影像的hdr文件
Args:
source_img_path: 源影像文件路径原始bsq文件
dest_img_path: 目标影像文件路径去耀斑后的bsq文件
Returns:
bool: 是否成功
"""
if not UTIL_AVAILABLE:
print("警告: util模块未导入无法复制hdr文件信息")
return False
try:
source_hdr_path = get_hdr_file_path(source_img_path)
dest_hdr_path = get_hdr_file_path(dest_img_path)
if not Path(source_hdr_path).exists():
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
return False
if not Path(dest_hdr_path).exists():
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
return False
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
return True
except Exception as e:
print(f"警告: 复制hdr文件信息时出错: {e}")
return False