refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""
工具模块 - 统一导出接口
"""
from src.core.utils.gdal_helper import (
get_image_geo_info,
load_image_as_array,
save_array_as_image,
save_bands_as_image,
copy_hdr_info,
read_band_as_array,
read_multiple_bands,
)
from src.core.utils.mask_converter import (
prepare_water_mask_for_algorithm,
ensure_water_mask_dat,
)
from src.core.utils.preview_generator import (
generate_image_preview,
generate_water_mask_overlay,
select_rgb_bands_by_wavelength,
get_wavelength_info,
)
# Public API of the utils package: the names imported above from the
# gdal_helper, mask_converter and preview_generator sub-modules.
__all__ = [
    # GDAL I/O
    'get_image_geo_info',
    'load_image_as_array',
    'save_array_as_image',
    'save_bands_as_image',
    'copy_hdr_info',
    'read_band_as_array',
    'read_multiple_bands',
    # Mask conversion
    'prepare_water_mask_for_algorithm',
    'ensure_water_mask_dat',
    # Preview image generation
    'generate_image_preview',
    'generate_water_mask_overlay',
    'select_rgb_bands_by_wavelength',
    'get_wavelength_info',
]

View File

@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
"""
GDAL 底层 IO 工具模块
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
"""
import os
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
# GDAL 导入(可选)
try:
from osgeo import gdal, ogr, gdal_array
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# hdr 文件工具
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
UTIL_AVAILABLE = True
except ImportError:
UTIL_AVAILABLE = False
write_fields_to_hdrfile = None
get_hdr_file_path = None
# ============================================================
# 影像信息读取
# ============================================================
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
    """Read georeferencing metadata from a raster without loading pixel data.

    Args:
        img_path: Path of the image file to inspect.

    Returns:
        A tuple ``(geotransform, projection, width, height, n_bands)``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If GDAL cannot open ``img_path``.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")

    ds = gdal.Open(img_path, gdal.GA_ReadOnly)
    if ds is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    try:
        info = (
            ds.GetGeoTransform(),
            ds.GetProjection(),
            ds.RasterXSize,
            ds.RasterYSize,
            ds.RasterCount,
        )
    finally:
        # Dropping the reference is how this module releases a GDAL dataset.
        ds = None
    return info
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
    """Load every band of an image into one in-memory numpy array.

    All bands are read at once, so large images consume a lot of memory.
    Prefer passing the file path to an algorithm that reads band by band.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A tuple ``(image_array, geotransform, projection)`` where
        ``image_array`` is the bands stacked along the third axis, i.e. shaped
        ``(height, width, bands)``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If GDAL cannot open ``img_path``.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")

    ds = gdal.Open(img_path, gdal.GA_ReadOnly)
    if ds is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    try:
        geotransform = ds.GetGeoTransform()
        projection = ds.GetProjection()
        # GDAL band numbers are 1-based.
        layers = [
            ds.GetRasterBand(band_no).ReadAsArray()
            for band_no in range(1, ds.RasterCount + 1)
        ]
        return np.dstack(layers), geotransform, projection
    finally:
        ds = None
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
    """Read a single band of an image as a numpy array.

    Args:
        img_path: Path of the image file to read.
        band_index: Zero-based band index.

    Returns:
        The band data as returned by GDAL's ``ReadAsArray`` (a 2-D array).

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the file cannot be opened, or ``band_index`` does not
            address an existing band.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")

    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    try:
        n_bands = dataset.RasterCount
        # Validate up front: an out-of-range request would otherwise surface
        # as an opaque failure on the band object instead of a clear error.
        if not 0 <= band_index < n_bands:
            raise ValueError(
                f"波段索引 {band_index} 超出范围 (影像共 {n_bands} 个波段)"
            )
        # GDAL band numbers are 1-based; int() also accepts numpy integers.
        return dataset.GetRasterBand(int(band_index) + 1).ReadAsArray()
    finally:
        dataset = None
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
    """Read several bands of an image, returning them as a list of arrays.

    Args:
        img_path: Path of the image file to read.
        band_indices: Zero-based band indices, in the order they should be
            returned.

    Returns:
        A tuple ``(band_list, geotransform, projection)``; ``band_list`` holds
        one array per requested index.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the file cannot be opened, or any requested index does
            not address an existing band.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取影像文件")

    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    try:
        requested = list(band_indices)
        n_bands = dataset.RasterCount
        # Check every index before reading anything so a bad request fails
        # fast with a clear message instead of midway through the reads.
        for idx in requested:
            if not 0 <= idx < n_bands:
                raise ValueError(
                    f"波段索引 {idx} 超出范围 (影像共 {n_bands} 个波段)"
                )

        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        # GDAL band numbers are 1-based; int() also accepts numpy integers.
        bands = [
            dataset.GetRasterBand(int(idx) + 1).ReadAsArray()
            for idx in requested
        ]
        return bands, geotransform, projection
    finally:
        dataset = None
# ============================================================
# 影像写入
# ============================================================
def save_array_as_image(image_array: np.ndarray, output_path: str,
                        geotransform: tuple, projection: str,
                        dtype=None) -> str:
    """Write a numpy array to a raster file.

    The ENVI driver is preferred; GTiff is used when ENVI is unavailable.

    Args:
        image_array: Array shaped ``(height, width)`` or
            ``(height, width, bands)``. A trailing axis of length 1 is
            accepted and written as a single band.
        output_path: Path of the file to create.
        geotransform: GDAL geotransform to store in the output.
        projection: Projection string to store in the output.
        dtype: GDAL data type; defaults to ``gdal.GDT_Float32``.

    Returns:
        ``output_path``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the array is neither 2-D nor 3-D, no driver is
            available, or the output file cannot be created.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if dtype is None:
        dtype = gdal.GDT_Float32

    # Normalise to a list of 2-D planes (views, no copies). Band.WriteArray
    # only accepts 2-D input, so a (h, w, 1) array must be sliced as well.
    if image_array.ndim == 2:
        planes = [image_array]
    elif image_array.ndim == 3:
        planes = [image_array[:, :, i] for i in range(image_array.shape[2])]
    else:
        raise ValueError(
            f"image_array 必须是二维或三维数组, 实际维度: {image_array.ndim}"
        )
    height, width = image_array.shape[:2]
    n_bands = len(planes)

    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        driver = gdal.GetDriverByName('GTiff')
    if driver is None:
        raise ValueError("无法创建影像文件,没有可用的驱动")

    dataset = driver.Create(output_path, width, height, n_bands, dtype)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {output_path}")

    try:
        dataset.SetGeoTransform(geotransform)
        dataset.SetProjection(projection)
        # GDAL band numbers are 1-based.
        for band_no, plane in enumerate(planes, start=1):
            band = dataset.GetRasterBand(band_no)
            band.WriteArray(plane)
            band.FlushCache()
    finally:
        # Dropping the reference is how this module releases a GDAL dataset.
        dataset = None
    return output_path
def save_bands_as_image(corrected_bands: list, output_path: str,
                        geotransform: tuple, projection: str,
                        dtype=None) -> str:
    """Write a list of 2-D band arrays to a raster file without stacking them.

    The ENVI driver is preferred; GTiff is used when ENVI is unavailable.

    Args:
        corrected_bands: Non-empty list of ``(height, width)`` arrays, one per
            output band; all must share the shape of the first element.
        output_path: Path of the file to create.
        geotransform: GDAL geotransform to store in the output.
        projection: Projection string to store in the output.
        dtype: GDAL data type; defaults to ``gdal.GDT_Float32``.

    Returns:
        ``output_path``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the list is empty, a band has a different shape, no
            driver is available, or the output file cannot be created.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法保存影像文件")
    if not corrected_bands:
        raise ValueError("波段列表为空")
    if dtype is None:
        dtype = gdal.GDT_Float32

    n_bands = len(corrected_bands)
    height, width = corrected_bands[0].shape

    # Validate every band before touching the disk so that a shape mismatch
    # cannot leave a half-written output file behind.
    for i, band_array in enumerate(corrected_bands):
        if band_array.shape != (height, width):
            raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")

    driver = gdal.GetDriverByName('ENVI')
    if driver is None:
        driver = gdal.GetDriverByName('GTiff')
    if driver is None:
        raise ValueError("无法创建影像文件,没有可用的驱动")

    dataset = driver.Create(output_path, width, height, n_bands, dtype)
    if dataset is None:
        raise ValueError(f"无法创建输出文件: {output_path}")

    try:
        dataset.SetGeoTransform(geotransform)
        dataset.SetProjection(projection)
        # GDAL band numbers are 1-based.
        for band_no, band_array in enumerate(corrected_bands, start=1):
            band = dataset.GetRasterBand(band_no)
            band.WriteArray(band_array)
            band.FlushCache()
    finally:
        # Dropping the reference is how this module releases a GDAL dataset.
        dataset = None
    return output_path
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
    """Copy header fields from the source image's hdr file to the destination's.

    This is a best-effort operation: every failure is reported with ``print``
    and signalled by the return value rather than by raising.

    Args:
        source_img_path: Path of the image whose hdr file is read.
        dest_img_path: Path of the image whose hdr file is updated.

    Returns:
        True when the fields were written; False when the util helpers are
        unavailable, either hdr file is missing, or an error occurred.
    """
    if not UTIL_AVAILABLE:
        print("警告: util模块未导入无法复制hdr文件信息")
        return False

    try:
        src_hdr = get_hdr_file_path(source_img_path)
        dst_hdr = get_hdr_file_path(dest_img_path)

        if not Path(src_hdr).exists():
            print(f"警告: 源hdr文件不存在: {src_hdr}")
            succeeded = False
        elif not Path(dst_hdr).exists():
            print(f"警告: 目标hdr文件不存在: {dst_hdr}")
            succeeded = False
        else:
            write_fields_to_hdrfile(src_hdr, dst_hdr)
            print(f"已复制原始hdr文件信息到: {dst_hdr}")
            succeeded = True
        return succeeded
    except Exception as e:
        # Deliberately broad: a header-copy failure must not abort the caller.
        print(f"警告: 复制hdr文件信息时出错: {e}")
        return False

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
掩膜转换工具模块
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
以及水体掩膜的预处理逻辑。
"""
import os
from pathlib import Path
from typing import Optional, Union
import numpy as np
try:
from osgeo import gdal, ogr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def prepare_water_mask_for_algorithm(
    water_mask: Optional[Union[str, Path, np.ndarray]],
    image_shape: Union[tuple, np.ndarray],
    geotransform: tuple,
    projection: str,
    img_path: str,
    water_mask_dir: Optional[str] = None,
    callback=None
) -> Optional[np.ndarray]:
    """Normalise a water mask given in any supported form to a 0/1 array.

    Supported inputs:
        - ``None``: no mask is available; ``None`` is returned unchanged.
        - ``numpy.ndarray``: checked against the image size and binarised.
        - path ending in ``.shp``: rasterised (result cached on disk), then
          loaded.
        - any other path: opened as a raster file and binarised.

    Args:
        water_mask: Mask source (see above). Paths may be ``str`` or
            ``pathlib.Path``.
        image_shape: Image shape; only the first two entries
            ``(height, width)`` are used.
        geotransform: GDAL geotransform, forwarded to the shapefile converter.
        projection: Projection string, forwarded to the shapefile converter.
        img_path: Reference image path, used when rasterising a shapefile.
        water_mask_dir: Directory for the cached rasterised shapefile mask.
        callback: Optional progress callback, forwarded to the shapefile
            converter.

    Returns:
        ``uint8`` array with 1 where the source mask is > 0 and 0 elsewhere,
        or ``None`` when ``water_mask`` is ``None``.

    Raises:
        ValueError: If the mask size does not match the image, or
            ``water_mask`` is of an unsupported type.
    """
    img_height, img_width = image_shape[0], image_shape[1]

    if water_mask is None:
        return None

    # In-memory mask: validate the size and binarise.
    if isinstance(water_mask, np.ndarray):
        if water_mask.shape[:2] != (img_height, img_width):
            raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
        return (water_mask > 0).astype(np.uint8)

    # File-based mask: pathlib.Path is accepted alongside plain strings and
    # converted once so the helpers (and GDAL) always receive a str.
    if isinstance(water_mask, (str, Path)):
        mask_path = str(water_mask)
        if Path(mask_path).suffix.lower() == '.shp':
            return _convert_shp_to_mask(
                shp_path=mask_path,
                img_path=img_path,
                image_shape=image_shape,
                geotransform=geotransform,
                projection=projection,
                water_mask_dir=water_mask_dir,
                callback=callback
            )
        return _load_raster_mask(mask_path, img_height, img_width)

    raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def _convert_shp_to_mask(shp_path: str, img_path: str,
                         image_shape: tuple,
                         geotransform: tuple,
                         projection: str,
                         water_mask_dir: Optional[str] = None,
                         callback=None) -> np.ndarray:
    """Rasterise a shapefile into a 0/1 mask array, caching the raster on disk.

    The rasterised mask is stored as ``water_mask_<shp stem>.dat`` inside
    ``water_mask_dir`` or, when that is not given, in the system temporary
    directory. An existing cache file is reused without re-rasterising.

    Args:
        shp_path: Path of the shapefile.
        img_path: Reference image passed to the rasteriser.
        image_shape: Image shape; the first two entries are the expected
            ``(height, width)`` of the mask.
        geotransform: Not used by this function.
        projection: Not used by this function.
        water_mask_dir: Optional directory for the cached raster.
        callback: Not used by this function.

    Returns:
        ``uint8`` mask array loaded from the cached or freshly written raster.

    Raises:
        ValueError: If rasterisation is needed but ``img_path`` is None, or
            the loaded mask does not match the image size.
    """
    import tempfile

    safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
    shp_name = Path(safe_shp_path).stem

    # Use the platform's temp directory rather than a hard-coded "/tmp",
    # which does not exist on Windows.
    cache_dir = Path(water_mask_dir) if water_mask_dir else Path(tempfile.gettempdir())
    temp_mask_path = str(cache_dir / f"water_mask_{shp_name}.dat")

    # Cache hit: the shapefile was already rasterised.
    if Path(temp_mask_path).exists():
        print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
        return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])

    if img_path is None:
        raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")

    # Imported only when rasterisation is really required, so cached masks
    # remain usable even if this project module cannot be imported.
    from src.utils.extract_water_area import rasterize_shp

    print(f"正在将 SHP 栅格化: {safe_shp_path}")
    rasterize_shp(safe_shp_path, temp_mask_path, img_path)
    return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
    """Load band 1 of a raster file as a 0/1 ``uint8`` mask.

    Args:
        mask_path: Path of the raster mask file.
        img_height: Expected number of rows.
        img_width: Expected number of columns.

    Returns:
        ``uint8`` array with 1 where the stored value is > 0, else 0.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the file cannot be opened or its size differs from
            ``(img_height, img_width)``.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法读取掩膜文件")

    ds = gdal.Open(mask_path, gdal.GA_ReadOnly)
    if ds is None:
        raise ValueError(f"无法打开掩膜文件: {mask_path}")

    try:
        raw = ds.GetRasterBand(1).ReadAsArray()
    finally:
        ds = None

    expected_shape = (img_height, img_width)
    if raw.shape != expected_shape:
        raise ValueError(f"掩膜尺寸 {raw.shape} 与图像尺寸 {expected_shape} 不匹配")
    return (raw > 0).astype(np.uint8)
def ensure_water_mask_dat(img_path: str,
                          existing_dat_path: Optional[str] = None,
                          output_dir: Optional[str] = None) -> str:
    """Resolve the path of the ``.dat`` water mask that belongs to an image.

    Resolution order:
        1. ``existing_dat_path`` if it ends in ``.dat`` and exists.
        2. ``<output_dir>/<image stem>_water_mask.dat`` if it exists.
        3. A sibling ``.tif`` / ``.img`` / ``.tiff`` mask with the same stem,
           converted to ``.dat``.
        4. Otherwise the target ``.dat`` path is returned even though the file
           does not exist; this function never generates a mask itself, the
           caller decides whether to create it.

    Args:
        img_path: Image whose stem names the mask file.
        existing_dat_path: Optional path of an already available ``.dat`` mask.
        output_dir: Directory to look in; defaults to the image's directory.

    Returns:
        Path of the ``.dat`` mask file (which may not exist yet, see above).
    """
    if existing_dat_path:
        supplied = Path(existing_dat_path)
        if supplied.suffix.lower() == '.dat' and supplied.exists():
            return existing_dat_path

    image = Path(img_path)
    target_dir = Path(str(image.parent) if output_dir is None else output_dir)
    mask_stem = f"{image.stem}_water_mask"

    dat_path = str(target_dir / f"{mask_stem}.dat")
    if Path(dat_path).exists():
        return dat_path

    # Convert a mask stored in another raster format, first match wins.
    for suffix in ('.tif', '.img', '.tiff'):
        other_format = str(target_dir / f"{mask_stem}{suffix}")
        if Path(other_format).exists():
            return _convert_to_dat(other_format, dat_path)

    # Nothing found: hand back the target path for the caller to populate.
    return dat_path
def _convert_to_dat(src_path: str, dest_path: str) -> str:
    """Convert band 1 of a raster mask into a single-band byte 0/1 raster.

    The ENVI driver is preferred; GTiff is used when ENVI is unavailable.
    Geotransform and projection are copied from the source.

    Args:
        src_path: Path of the source mask raster.
        dest_path: Path of the file to create.

    Returns:
        ``dest_path``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the source cannot be opened, no driver is available,
            or the output file cannot be created.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法转换格式")

    src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
    if src_ds is None:
        raise ValueError(f"无法打开源掩膜文件: {src_path}")

    try:
        geotransform = src_ds.GetGeoTransform()
        projection = src_ds.GetProjection()
        array = src_ds.GetRasterBand(1).ReadAsArray()

        driver = gdal.GetDriverByName('ENVI')
        if driver is None:
            driver = gdal.GetDriverByName('GTiff')
        # Without this guard a missing driver fails as an AttributeError on
        # None; report it the same way the image-saving helpers do.
        if driver is None:
            raise ValueError("无法创建影像文件,没有可用的驱动")

        dest_ds = driver.Create(dest_path, src_ds.RasterXSize,
                                src_ds.RasterYSize, 1, gdal.GDT_Byte)
        if dest_ds is None:
            raise ValueError(f"无法创建输出文件: {dest_path}")

        try:
            dest_ds.SetGeoTransform(geotransform)
            dest_ds.SetProjection(projection)
            dest_band = dest_ds.GetRasterBand(1)
            # Any positive source value counts as water.
            dest_band.WriteArray((array > 0).astype(np.uint8))
            dest_band.FlushCache()
        finally:
            dest_ds = None
        return dest_path
    finally:
        src_ds = None

View File

@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
遥感影像预览图生成工具模块
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
"""
import numpy as np
from pathlib import Path
from typing import Optional, List
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# matplotlib is imported eagerly at module import time: this module is a
# visualization tool and its preview functions draw with pyplot.
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
# Global font fallback list. The leading entries are CJK fonts, presumably so
# the Chinese titles and labels used in this module render correctly.
# TODO confirm the fonts exist on target hosts.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
# NOTE(review): presumably keeps the minus sign renderable with the fonts above; verify.
plt.rcParams['axes.unicode_minus'] = False
# ============================================================
# 辅助函数:波段选择
# ============================================================
def select_rgb_bands_by_wavelength(band_count: int,
                                   wavelength_info: Optional[List[float]] = None,
                                   fallback_bands: Optional[List[int]] = None) -> List[int]:
    """Pick the bands closest to red, green and blue by wavelength.

    Args:
        band_count: Total number of bands in the image.
        wavelength_info: Per-band wavelengths, compared against targets of
            650 / 550 / 460. Entries beyond ``band_count`` are ignored.
        fallback_bands: Zero-based ``[R, G, B]`` indices used when no usable
            wavelength information is available; defaults to the last three
            bands. Fallback indices are clamped to ``[0, band_count - 1]``.

    Returns:
        Zero-based band indices ``[R_index, G_index, B_index]``.
    """
    if fallback_bands is None:
        fallback_bands = [band_count - 3, band_count - 2, band_count - 1]

    def clamped_fallback() -> List[int]:
        return [max(0, min(i, band_count - 1)) for i in fallback_bands]

    if wavelength_info is None:
        return clamped_fallback()

    # Target wavelengths (nm)
    TARGET_R = 650
    TARGET_G = 550
    TARGET_B = 460

    def find_closest(candidates: list, target: float) -> int:
        # Explicit loop with a strict "<": the first of equally close bands
        # wins and NaN distances never replace the current best.
        best_idx, min_dist = 0, float('inf')
        for i, wl in enumerate(candidates):
            dist = abs(wl - target)
            if dist < min_dist:
                min_dist, best_idx = dist, i
        return best_idx

    try:
        # Only wavelengths that map onto a real band may be selected;
        # otherwise the returned index could exceed the band count.
        usable = list(wavelength_info)[:max(band_count, 0)]
        if not usable:
            # An empty list carries no information; use the fallback rather
            # than silently returning band 0 for every channel.
            return clamped_fallback()
        return [find_closest(usable, target)
                for target in (TARGET_R, TARGET_G, TARGET_B)]
    except (TypeError, ValueError):
        # Non-numeric or non-iterable wavelength data.
        return clamped_fallback()
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
    """Read the ``wavelength = {...}`` list from an image's ENVI header file.

    The header is looked up as ``<stem>.hdr`` first and then as
    ``<img_path>.hdr``. The list may sit on one line or span several lines;
    tokens that are not numbers are skipped. Parsing is best-effort: any
    problem results in ``None`` rather than an exception.

    Args:
        img_path: Path of the image file whose header should be read.

    Returns:
        The wavelengths as floats, or ``None`` if no header or no wavelength
        values could be found.
    """
    import re

    try:
        base = Path(img_path)
        candidates = (base.with_suffix('.hdr'), Path(f"{base}.hdr"))
        hdr_path = next((p for p in candidates if p.exists()), None)
        if hdr_path is None:
            return None

        with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read()
    except (OSError, TypeError, ValueError):
        return None

    # Match the "wavelength" key only (not e.g. "wavelength units") and
    # capture everything up to the closing brace, or to the end of the file
    # if the brace is missing, regardless of line breaks.
    match = re.search(r'^\s*wavelength\s*=\s*\{([^}]*)', text,
                      re.MULTILINE | re.IGNORECASE)
    if match is None:
        return None

    wavelengths = []
    for token in match.group(1).replace(',', ' ').split():
        try:
            wavelengths.append(float(token))
        except ValueError:
            # Skip stray non-numeric tokens instead of failing the whole read.
            continue
    return wavelengths if wavelengths else None
# ============================================================
# 核心预览图生成函数
# ============================================================
def generate_image_preview(img_path: str,
                           output_path: str,
                           bands: Optional[List[int]] = None,
                           title: str = "影像预览") -> str:
    """Render an RGB PNG preview of a (hyperspectral) image.

    If ``output_path`` already exists it is returned untouched. Each channel
    has values <= 0 treated as missing and is stretched linearly between its
    2nd and 98th percentile. Images with fewer than three bands reuse the
    first band for the missing channels.

    Args:
        img_path: Path of the input image.
        output_path: Path of the PNG file to write.
        bands: Zero-based ``[R, G, B]`` band indices; when None they are
            chosen from the header wavelengths (three or more bands) or set
            to ``[0, 0, 0]``.
        title: Figure title.

    Returns:
        ``output_path``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the image cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法生成影像预览图")
    if Path(output_path).exists():
        print(f"检测到已存在的预览图,跳过生成: {output_path}")
        return output_path

    dataset = gdal.Open(img_path)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    def linear_stretch(data, low=2, high=98):
        # Percentile stretch to [0, 1]; missing (NaN) pixels become 0.
        valid = data[~np.isnan(data)]
        if len(valid) == 0:
            return np.zeros_like(data)
        lo = np.percentile(valid, low)
        hi = np.percentile(valid, high)
        if hi - lo < 1e-10:
            return np.zeros_like(data)
        stretched = np.clip((data - lo) / (hi - lo), 0, 1)
        return np.nan_to_num(stretched, nan=0.0)

    def read_channel(band_idx):
        # GDAL band numbers are 1-based; non-positive values count as no-data.
        data = dataset.GetRasterBand(band_idx + 1).ReadAsArray().astype(np.float32)
        data[data <= 0] = np.nan
        return data

    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        band_count = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()

        # Automatic band selection.
        if bands is None:
            if band_count >= 3:
                wl_info = get_wavelength_info(img_path)
                bands = select_rgb_bands_by_wavelength(band_count, wl_info)
            else:
                bands = [0, 0, 0]

        r_s = linear_stretch(read_channel(bands[0]))
        g_s = linear_stretch(read_channel(bands[1])) if band_count > 1 else r_s
        b_s = linear_stretch(read_channel(bands[2])) if band_count > 2 else r_s
        rgb_image = np.stack([r_s, g_s, b_s], axis=2)

        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            ax.imshow(rgb_image)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.axis('off')
            if geotransform and geotransform[1] != 0:
                pixel_size_x = abs(geotransform[1])
                scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
                fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
                         color='white',
                         bbox=dict(facecolor='black', alpha=0.6,
                                   boxstyle='round,pad=0.3'))
            plt.tight_layout()
            plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
        finally:
            # Always release the figure, even when drawing or saving fails;
            # pyplot otherwise keeps it alive for the rest of the process.
            plt.close(fig)
        return output_path
    finally:
        # Dropping the reference is how this module releases a GDAL dataset.
        dataset = None
def generate_water_mask_overlay(img_path: str,
                                mask_path: str,
                                output_path: str,
                                bands: Optional[List[int]] = None,
                                mask_color: tuple = (0, 100, 255),
                                mask_alpha: float = 0.5) -> str:
    """Render a PNG of the image with the water mask blended on top.

    If ``output_path`` already exists it is returned untouched. The RGB
    background uses a 2-98 percentile stretch per channel. Every mask pixel
    with a value > 0 is treated as water and tinted with ``mask_color`` at
    opacity ``mask_alpha``, so 0/1 and 0/255 masks render identically. If the
    mask cannot be opened a warning is printed and the plain image is saved.

    Args:
        img_path: Path of the input image.
        mask_path: Path of the water mask raster (band 1 is used).
        output_path: Path of the PNG file to write.
        bands: Zero-based ``[R, G, B]`` band indices; when None they are
            chosen from the header wavelengths (three or more bands) or set
            to ``[0, 0, 0]``.
        mask_color: Overlay colour ``(R, G, B)`` with integer components
            in 0-255.
        mask_alpha: Overlay opacity (0 = transparent, 1 = opaque).

    Returns:
        ``output_path``.

    Raises:
        ImportError: If GDAL is not installed.
        ValueError: If the image cannot be opened or the mask size differs
            from the image size.
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL未安装无法生成叠加图")
    if Path(output_path).exists():
        print(f"检测到已存在的叠加图,跳过生成: {output_path}")
        return output_path

    dataset = gdal.Open(img_path)
    if dataset is None:
        raise ValueError(f"无法打开影像文件: {img_path}")

    def linear_stretch(data, low=2, high=98):
        # Percentile stretch to [0, 1]; missing (NaN) pixels become 0.
        valid = data[~np.isnan(data)]
        if len(valid) == 0:
            return np.zeros_like(data)
        lo = np.percentile(valid, low)
        hi = np.percentile(valid, high)
        if hi - lo < 1e-10:
            return np.zeros_like(data)
        stretched = np.clip((data - lo) / (hi - lo), 0, 1)
        return np.nan_to_num(stretched, nan=0.0)

    def read_channel(band_idx):
        # GDAL band numbers are 1-based; non-positive values count as no-data.
        data = dataset.GetRasterBand(band_idx + 1).ReadAsArray().astype(np.float32)
        data[data <= 0] = np.nan
        return data

    try:
        width = dataset.RasterXSize
        height = dataset.RasterYSize
        band_count = dataset.RasterCount
        geotransform = dataset.GetGeoTransform()

        # Automatic band selection.
        if bands is None:
            if band_count >= 3:
                wl_info = get_wavelength_info(img_path)
                bands = select_rgb_bands_by_wavelength(band_count, wl_info)
            else:
                bands = [0, 0, 0]

        r_s = linear_stretch(read_channel(bands[0]))
        g_s = linear_stretch(read_channel(bands[1])) if band_count > 1 else r_s
        b_s = linear_stretch(read_channel(bands[2])) if band_count > 2 else r_s
        rgb_image = (np.stack([r_s, g_s, b_s], axis=2) * 255).astype(np.uint8)

        # Read the mask (best effort: a missing mask only produces a warning).
        mask_data = None
        mask_dataset = gdal.Open(mask_path)
        if mask_dataset is None:
            print(f"警告: 无法打开掩膜文件: {mask_path}")
        else:
            mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
            mask_dataset = None

        # Alpha blending.
        blended = rgb_image.astype(np.float32)
        if mask_data is not None:
            if mask_data.shape != (height, width):
                raise ValueError(f"掩膜尺寸 {mask_data.shape} 与图像尺寸 {(height, width)} 不匹配")
            # Binarise before blending: the masks written by this project are
            # 0/1, so scaling the raw value by 1/255 would make the overlay
            # practically invisible.
            alpha = ((mask_data > 0).astype(np.float32) * mask_alpha)[:, :, np.newaxis]
            tint = np.asarray(mask_color[:3], dtype=np.float32)
            blended = blended * (1 - alpha) + tint * alpha
        blended = np.clip(blended, 0, 255).astype(np.uint8)

        fig, ax = plt.subplots(figsize=(14, 10))
        try:
            ax.imshow(blended)
            ax.axis('off')
            legend_elements = [
                Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
                      edgecolor='black', alpha=mask_alpha, label='水域范围')
            ]
            ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)

            # Area statistics; the km² figures assume the geotransform pixel
            # size is expressed in metres.
            if geotransform and geotransform[1] != 0:
                pixel_size_x = abs(geotransform[1])
                pixel_size_y = abs(geotransform[5])
                pixel_area = pixel_size_x * pixel_size_y
                if mask_data is not None:
                    water_pixels = np.sum(mask_data > 0)
                    valid_pixels = np.sum(mask_data >= 0)
                    water_km2 = water_pixels * pixel_area / 1_000_000
                    valid_km2 = valid_pixels * pixel_area / 1_000_000
                    pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
                    area_text = (f'水域面积: {water_km2:.2f} km² | '
                                 f'影像总面积: {valid_km2:.2f} km² | '
                                 f'占比: {pct:.1f}%')
                    ax.text(0.02, 0.98, area_text,
                            transform=ax.transAxes, fontsize=11,
                            color='white', fontweight='bold',
                            bbox=dict(facecolor='#0064FF', alpha=0.8,
                                      edgecolor='black', boxstyle='round,pad=0.5'),
                            verticalalignment='top')

            plt.tight_layout()
            plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
        finally:
            # Always release the figure, even when drawing or saving fails;
            # pyplot otherwise keeps it alive for the rest of the process.
            plt.close(fig)
        return output_path
    finally:
        # Dropping the reference is how this module releases a GDAL dataset.
        dataset = None