refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
掩膜转换工具模块
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
以及水体掩膜的预处理逻辑。
"""
import os
from pathlib import Path
from typing import Optional, Union
import numpy as np
try:
from osgeo import gdal, ogr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def prepare_water_mask_for_algorithm(
water_mask: Optional[Union[str, np.ndarray]],
image_shape: Union[tuple, np.ndarray],
geotransform: tuple,
projection: str,
img_path: str,
water_mask_dir: Optional[str] = None,
callback=None
) -> Optional[np.ndarray]:
"""
准备水域掩膜供算法使用
支持格式:
- None自动使用预先生成的 dat 格式掩膜
- numpy.ndarray直接返回确保是 0/1 格式)
- .dat / .tif 等栅格文件:读取并返回
- .shp 文件:先栅格化,再读取返回
Args:
water_mask: 掩膜来源
image_shape: 影像形状 (height, width) 或 (height, width, channels)
geotransform: GDAL 地理变换参数
projection: 投影信息
img_path: 影像路径(用于 shp 栅格化)
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
callback: 进度回调函数(可选)
Returns:
numpy数组dtype=uint80=非水域1=水域)或 None
"""
img_height, img_width = image_shape[0], image_shape[1]
if water_mask is None:
return None
# numpy 数组直接返回
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (water_mask > 0).astype(np.uint8)
# 字符串路径
if isinstance(water_mask, str):
ext = Path(water_mask).suffix.lower()
# shapefile 格式
if ext == '.shp':
return _convert_shp_to_mask(
shp_path=water_mask,
img_path=img_path,
image_shape=image_shape,
geotransform=geotransform,
projection=projection,
water_mask_dir=water_mask_dir,
callback=callback
)
# 栅格文件格式
return _load_raster_mask(water_mask, img_height, img_width)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def _convert_shp_to_mask(shp_path: str, img_path: str,
image_shape: tuple,
geotransform: tuple,
projection: str,
water_mask_dir: Optional[str] = None,
callback=None) -> np.ndarray:
"""将 shapefile 栅格化为掩膜数组"""
from src.utils.extract_water_area import rasterize_shp
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
shp_name = Path(safe_shp_path).stem
if water_mask_dir:
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
else:
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
# 缓存:已栅格化则直接读取
if Path(temp_mask_path).exists():
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
# 需要栅格化
if img_path is None:
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
print(f"正在将 SHP 栅格化: {safe_shp_path}")
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
"""从栅格文件加载掩膜"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取掩膜文件")
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {mask_path}")
try:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
finally:
mask_dataset = None
if mask_array.shape != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
def ensure_water_mask_dat(img_path: str,
existing_dat_path: Optional[str] = None,
output_dir: Optional[str] = None) -> str:
"""
确保存在 dat 格式的水体掩膜文件用于步骤3/4中的算法
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
如果存在同名 .dat 文件,直接返回。
否则从 img_path 生成并保存到 output_dir。
Args:
img_path: 用于生成掩膜的遥感影像路径
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
output_dir: 输出目录(可选)
Returns:
dat 格式掩膜文件路径
"""
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
if Path(existing_dat_path).exists():
return existing_dat_path
img_name = Path(img_path).stem
if output_dir is None:
output_dir = str(Path(img_path).parent)
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
if Path(dat_path).exists():
return dat_path
# 如果已有其他格式的掩膜,转换为 dat
for ext in ['.tif', '.img', '.tiff']:
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
if Path(alt_path).exists():
return _convert_to_dat(alt_path, dat_path)
return dat_path # 返回目标路径,让调用方决定是否需要生成
def _convert_to_dat(src_path: str, dest_path: str) -> str:
"""将其他栅格格式转换为 ENVI dat 格式"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法转换格式")
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
if src_ds is None:
raise ValueError(f"无法打开源掩膜文件: {src_path}")
try:
geotransform = src_ds.GetGeoTransform()
projection = src_ds.GetProjection()
band = src_ds.GetRasterBand(1)
array = band.ReadAsArray()
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
if dest_ds is None:
raise ValueError(f"无法创建输出文件: {dest_path}")
try:
dest_ds.SetGeoTransform(geotransform)
dest_ds.SetProjection(projection)
dest_band = dest_ds.GetRasterBand(1)
dest_band.WriteArray((array > 0).astype(np.uint8))
dest_band.FlushCache()
finally:
dest_ds = None
return dest_path
finally:
src_ds = None