refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块
This commit is contained in:
42
src/core/utils/__init__.py
Normal file
42
src/core/utils/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块 - 统一导出接口
|
||||
"""
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info,
|
||||
load_image_as_array,
|
||||
save_array_as_image,
|
||||
save_bands_as_image,
|
||||
copy_hdr_info,
|
||||
read_band_as_array,
|
||||
read_multiple_bands,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm,
|
||||
ensure_water_mask_dat,
|
||||
)
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
select_rgb_bands_by_wavelength,
|
||||
get_wavelength_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# GDAL IO
|
||||
'get_image_geo_info',
|
||||
'load_image_as_array',
|
||||
'save_array_as_image',
|
||||
'save_bands_as_image',
|
||||
'copy_hdr_info',
|
||||
'read_band_as_array',
|
||||
'read_multiple_bands',
|
||||
# 掩膜转换
|
||||
'prepare_water_mask_for_algorithm',
|
||||
'ensure_water_mask_dat',
|
||||
# 预览图生成
|
||||
'generate_image_preview',
|
||||
'generate_water_mask_overlay',
|
||||
'select_rgb_bands_by_wavelength',
|
||||
'get_wavelength_info',
|
||||
]
|
||||
309
src/core/utils/gdal_helper.py
Normal file
309
src/core/utils/gdal_helper.py
Normal file
@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GDAL 底层 IO 工具模块
|
||||
|
||||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# GDAL 导入(可选)
|
||||
try:
|
||||
from osgeo import gdal, ogr, gdal_array
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# hdr 文件工具
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||||
UTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
UTIL_AVAILABLE = False
|
||||
write_fields_to_hdrfile = None
|
||||
get_hdr_file_path = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像信息读取
|
||||
# ============================================================
|
||||
|
||||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||||
"""
|
||||
获取影像的地理信息(不加载图像数据,节省内存)
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (geotransform, projection, width, height, n_bands)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
return geotransform, projection, width, height, n_bands
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||||
"""
|
||||
加载影像文件为numpy数组
|
||||
|
||||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (image_array, geotransform, projection)
|
||||
image_array: numpy数组,形状为(height, width, bands)
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
image_bands = []
|
||||
for i in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(i)
|
||||
band_data = band.ReadAsArray()
|
||||
image_bands.append(band_data)
|
||||
|
||||
image_array = np.dstack(image_bands)
|
||||
return image_array, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
读取单个波段为 numpy 数组
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_index: 波段索引(从 0 开始)
|
||||
|
||||
Returns:
|
||||
numpy 数组,形状为 (height, width)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
band = dataset.GetRasterBand(band_index + 1)
|
||||
return band.ReadAsArray()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||||
"""
|
||||
读取多个指定波段为列表
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_indices: 波段索引列表
|
||||
|
||||
Returns:
|
||||
tuple: (band_list, geotransform, projection)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
bands = []
|
||||
for idx in band_indices:
|
||||
band = dataset.GetRasterBand(idx + 1)
|
||||
bands.append(band.ReadAsArray())
|
||||
return bands, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像写入
|
||||
# ============================================================
|
||||
|
||||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
将numpy数组保存为影像文件
|
||||
|
||||
Args:
|
||||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
if image_array.ndim == 2:
|
||||
height, width = image_array.shape
|
||||
n_bands = 1
|
||||
else:
|
||||
height, width, n_bands = image_array.shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
if n_bands == 1:
|
||||
band = dataset.GetRasterBand(1)
|
||||
band.WriteArray(image_array)
|
||||
band.FlushCache()
|
||||
else:
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(image_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||||
|
||||
Args:
|
||||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if not corrected_bands:
|
||||
raise ValueError("波段列表为空")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
n_bands = len(corrected_bands)
|
||||
height, width = corrected_bands[0].shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
for i, band_array in enumerate(corrected_bands):
|
||||
if band_array.shape != (height, width):
|
||||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(band_array)
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||||
"""
|
||||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||||
|
||||
Args:
|
||||
source_img_path: 源影像文件路径(原始bsq文件)
|
||||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not UTIL_AVAILABLE:
|
||||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||||
return False
|
||||
|
||||
try:
|
||||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||||
|
||||
if not Path(source_hdr_path).exists():
|
||||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||||
return False
|
||||
|
||||
if not Path(dest_hdr_path).exists():
|
||||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||||
return False
|
||||
|
||||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||||
return False
|
||||
210
src/core/utils/mask_converter.py
Normal file
210
src/core/utils/mask_converter.py
Normal file
@ -0,0 +1,210 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
掩膜转换工具模块
|
||||
|
||||
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
|
||||
以及水体掩膜的预处理逻辑。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_water_mask_for_algorithm(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
image_shape: Union[tuple, np.ndarray],
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
img_path: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
准备水域掩膜供算法使用
|
||||
|
||||
支持格式:
|
||||
- None:自动使用预先生成的 dat 格式掩膜
|
||||
- numpy.ndarray:直接返回(确保是 0/1 格式)
|
||||
- .dat / .tif 等栅格文件:读取并返回
|
||||
- .shp 文件:先栅格化,再读取返回
|
||||
|
||||
Args:
|
||||
water_mask: 掩膜来源
|
||||
image_shape: 影像形状 (height, width) 或 (height, width, channels)
|
||||
geotransform: GDAL 地理变换参数
|
||||
projection: 投影信息
|
||||
img_path: 影像路径(用于 shp 栅格化)
|
||||
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
|
||||
callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
numpy数组(dtype=uint8,0=非水域,1=水域)或 None
|
||||
"""
|
||||
img_height, img_width = image_shape[0], image_shape[1]
|
||||
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# numpy 数组直接返回
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8)
|
||||
|
||||
# 字符串路径
|
||||
if isinstance(water_mask, str):
|
||||
ext = Path(water_mask).suffix.lower()
|
||||
|
||||
# shapefile 格式
|
||||
if ext == '.shp':
|
||||
return _convert_shp_to_mask(
|
||||
shp_path=water_mask,
|
||||
img_path=img_path,
|
||||
image_shape=image_shape,
|
||||
geotransform=geotransform,
|
||||
projection=projection,
|
||||
water_mask_dir=water_mask_dir,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
# 栅格文件格式
|
||||
return _load_raster_mask(water_mask, img_height, img_width)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
|
||||
def _convert_shp_to_mask(shp_path: str, img_path: str,
|
||||
image_shape: tuple,
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None) -> np.ndarray:
|
||||
"""将 shapefile 栅格化为掩膜数组"""
|
||||
from src.utils.extract_water_area import rasterize_shp
|
||||
|
||||
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
|
||||
shp_name = Path(safe_shp_path).stem
|
||||
|
||||
if water_mask_dir:
|
||||
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
|
||||
else:
|
||||
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
|
||||
|
||||
# 缓存:已栅格化则直接读取
|
||||
if Path(temp_mask_path).exists():
|
||||
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
# 需要栅格化
|
||||
if img_path is None:
|
||||
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
|
||||
|
||||
print(f"正在将 SHP 栅格化: {safe_shp_path}")
|
||||
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
|
||||
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
|
||||
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
|
||||
"""从栅格文件加载掩膜"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取掩膜文件")
|
||||
|
||||
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {mask_path}")
|
||||
|
||||
try:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def ensure_water_mask_dat(img_path: str,
|
||||
existing_dat_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法)
|
||||
|
||||
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
|
||||
如果存在同名 .dat 文件,直接返回。
|
||||
否则从 img_path 生成并保存到 output_dir。
|
||||
|
||||
Args:
|
||||
img_path: 用于生成掩膜的遥感影像路径
|
||||
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
|
||||
output_dir: 输出目录(可选)
|
||||
|
||||
Returns:
|
||||
dat 格式掩膜文件路径
|
||||
"""
|
||||
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
|
||||
if Path(existing_dat_path).exists():
|
||||
return existing_dat_path
|
||||
|
||||
img_name = Path(img_path).stem
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(img_path).parent)
|
||||
|
||||
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
|
||||
|
||||
if Path(dat_path).exists():
|
||||
return dat_path
|
||||
|
||||
# 如果已有其他格式的掩膜,转换为 dat
|
||||
for ext in ['.tif', '.img', '.tiff']:
|
||||
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
|
||||
if Path(alt_path).exists():
|
||||
return _convert_to_dat(alt_path, dat_path)
|
||||
|
||||
return dat_path # 返回目标路径,让调用方决定是否需要生成
|
||||
|
||||
|
||||
def _convert_to_dat(src_path: str, dest_path: str) -> str:
|
||||
"""将其他栅格格式转换为 ENVI dat 格式"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法转换格式")
|
||||
|
||||
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
|
||||
if src_ds is None:
|
||||
raise ValueError(f"无法打开源掩膜文件: {src_path}")
|
||||
|
||||
try:
|
||||
geotransform = src_ds.GetGeoTransform()
|
||||
projection = src_ds.GetProjection()
|
||||
band = src_ds.GetRasterBand(1)
|
||||
array = band.ReadAsArray()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
|
||||
if dest_ds is None:
|
||||
raise ValueError(f"无法创建输出文件: {dest_path}")
|
||||
|
||||
try:
|
||||
dest_ds.SetGeoTransform(geotransform)
|
||||
dest_ds.SetProjection(projection)
|
||||
dest_band = dest_ds.GetRasterBand(1)
|
||||
dest_band.WriteArray((array > 0).astype(np.uint8))
|
||||
dest_band.FlushCache()
|
||||
finally:
|
||||
dest_ds = None
|
||||
|
||||
return dest_path
|
||||
finally:
|
||||
src_ds = None
|
||||
339
src/core/utils/preview_generator.py
Normal file
339
src/core/utils/preview_generator.py
Normal file
@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
遥感影像预览图生成工具模块
|
||||
|
||||
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具)
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 辅助函数:波段选择
|
||||
# ============================================================
|
||||
|
||||
def select_rgb_bands_by_wavelength(band_count: int,
|
||||
wavelength_info: Optional[List[float]] = None,
|
||||
fallback_bands: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
根据波长自动选择 RGB 波段
|
||||
|
||||
Args:
|
||||
band_count: 总波段数
|
||||
wavelength_info: 各波段波长列表(nm),长度为 band_count
|
||||
fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B]
|
||||
|
||||
Returns:
|
||||
波段索引列表 [R_index, G_index, B_index](0-based)
|
||||
"""
|
||||
if fallback_bands is None:
|
||||
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
|
||||
|
||||
if wavelength_info is None:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
# 目标波长(nm)
|
||||
TARGET_R = 650
|
||||
TARGET_G = 550
|
||||
TARGET_B = 460
|
||||
|
||||
def find_closest(target: float) -> int:
|
||||
min_dist = float('inf')
|
||||
best_idx = 0
|
||||
for i, wl in enumerate(wavelength_info):
|
||||
dist = abs(wl - target)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
try:
|
||||
r_idx = find_closest(TARGET_R)
|
||||
g_idx = find_closest(TARGET_G)
|
||||
b_idx = find_closest(TARGET_B)
|
||||
return [r_idx, g_idx, b_idx]
|
||||
except Exception:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
|
||||
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
|
||||
"""从 hdr 文件读取波长信息"""
|
||||
try:
|
||||
hdr_path = Path(img_path).with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
in_wl = False
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength ='):
|
||||
in_wl = True
|
||||
line = line.split('=', 1)[1].strip()
|
||||
elif in_wl:
|
||||
if line.startswith('{'):
|
||||
line = line[1:]
|
||||
if line.endswith('}'):
|
||||
line = line[:-1]
|
||||
in_wl = False
|
||||
# 解析逗号分隔的数值
|
||||
for token in line.replace(',', ' ').split():
|
||||
try:
|
||||
wavelengths.append(float(token))
|
||||
except ValueError:
|
||||
pass
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心预览图生成函数
|
||||
# ============================================================
|
||||
|
||||
def generate_image_preview(img_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
title: str = "影像预览") -> str:
|
||||
"""
|
||||
生成高光谱影像的 PNG 预览图
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
output_path: 输出 PNG 文件路径
|
||||
bands: RGB 波段索引 [R, G, B],None 则自动选择
|
||||
title: 图片标题
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成影像预览图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预览图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
# 读取波段
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
# 线性拉伸
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.imshow(rgb_image)
|
||||
ax.set_title(title, fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
|
||||
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
|
||||
color='white',
|
||||
bbox=dict(facecolor='black', alpha=0.6,
|
||||
boxstyle='round,pad=0.3'))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def generate_water_mask_overlay(img_path: str,
|
||||
mask_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
mask_color: tuple = (0, 100, 255),
|
||||
mask_alpha: float = 0.5) -> str:
|
||||
"""
|
||||
生成水域掩膜叠加到原图的 PNG 图像
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
mask_path: 水域掩膜文件路径
|
||||
output_path: 输出 PNG 路径
|
||||
bands: RGB 波段索引,None 则自动选择
|
||||
mask_color: 掩膜叠加颜色 (R, G, B)
|
||||
mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明)
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成叠加图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
|
||||
rgb_image = rgb_image.astype(np.uint8)
|
||||
|
||||
# 读取掩膜
|
||||
mask_dataset = gdal.Open(mask_path)
|
||||
if mask_dataset is not None:
|
||||
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
else:
|
||||
print(f"警告: 无法打开掩膜文件: {mask_path}")
|
||||
mask_data = None
|
||||
|
||||
# Alpha 混合
|
||||
overlay = np.zeros((height, width, 4), dtype=np.uint8)
|
||||
overlay[:, :, 0:3] = mask_color
|
||||
overlay[:, :, 3] = 255 # 全不透明
|
||||
|
||||
blended = rgb_image.astype(np.float32)
|
||||
if mask_data is not None:
|
||||
alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha
|
||||
for c in range(3):
|
||||
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
|
||||
blended = blended.astype(np.uint8)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
ax.imshow(blended)
|
||||
ax.axis('off')
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
|
||||
edgecolor='black', alpha=mask_alpha, label='水域范围')
|
||||
]
|
||||
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
|
||||
|
||||
# 面积计算
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
pixel_size_y = abs(geotransform[5])
|
||||
pixel_area = pixel_size_x * pixel_size_y
|
||||
|
||||
if mask_data is not None:
|
||||
water_pixels = np.sum(mask_data > 0)
|
||||
valid_pixels = np.sum(mask_data >= 0)
|
||||
water_km2 = water_pixels * pixel_area / 1_000_000
|
||||
valid_km2 = valid_pixels * pixel_area / 1_000_000
|
||||
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
|
||||
|
||||
area_text = (f'水域面积: {water_km2:.2f} km² | '
|
||||
f'影像总面积: {valid_km2:.2f} km² | '
|
||||
f'占比: {pct:.1f}%')
|
||||
ax.text(0.02, 0.98, area_text,
|
||||
transform=ax.transAxes, fontsize=11,
|
||||
color='white', fontweight='bold',
|
||||
bbox=dict(facecolor='#0064FF', alpha=0.8,
|
||||
edgecolor='black', boxstyle='round,pad=0.5'),
|
||||
verticalalignment='top')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
Reference in New Issue
Block a user