Files
WQ_GUI/src/core/utils/mask_converter.py

210 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
掩膜转换工具模块
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
以及水体掩膜的预处理逻辑。
"""
import os
from pathlib import Path
from typing import Optional, Union
import numpy as np
try:
from osgeo import gdal, ogr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def prepare_water_mask_for_algorithm(
water_mask: Optional[Union[str, np.ndarray]],
image_shape: Union[tuple, np.ndarray],
geotransform: tuple,
projection: str,
img_path: str,
water_mask_dir: Optional[str] = None,
callback=None
) -> Optional[np.ndarray]:
"""
准备水域掩膜供算法使用
支持格式:
- None自动使用预先生成的 dat 格式掩膜
- numpy.ndarray直接返回确保是 0/1 格式)
- .dat / .tif 等栅格文件:读取并返回
- .shp 文件:先栅格化,再读取返回
Args:
water_mask: 掩膜来源
image_shape: 影像形状 (height, width) 或 (height, width, channels)
geotransform: GDAL 地理变换参数
projection: 投影信息
img_path: 影像路径(用于 shp 栅格化)
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
callback: 进度回调函数(可选)
Returns:
numpy数组dtype=uint80=非水域1=水域)或 None
"""
img_height, img_width = image_shape[0], image_shape[1]
if water_mask is None:
return None
# numpy 数组直接返回
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (water_mask > 0).astype(np.uint8)
# 字符串路径
if isinstance(water_mask, str):
ext = Path(water_mask).suffix.lower()
# shapefile 格式
if ext == '.shp':
return _convert_shp_to_mask(
shp_path=water_mask,
img_path=img_path,
image_shape=image_shape,
geotransform=geotransform,
projection=projection,
water_mask_dir=water_mask_dir,
callback=callback
)
# 栅格文件格式
return _load_raster_mask(water_mask, img_height, img_width)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def _convert_shp_to_mask(shp_path: str, img_path: str,
image_shape: tuple,
geotransform: tuple,
projection: str,
water_mask_dir: Optional[str] = None,
callback=None) -> np.ndarray:
"""将 shapefile 栅格化为掩膜数组"""
from src.utils.extract_water_area import rasterize_shp
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
shp_name = Path(safe_shp_path).stem
if water_mask_dir:
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
else:
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
# 缓存:已栅格化则直接读取
if Path(temp_mask_path).exists():
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
# 需要栅格化
if img_path is None:
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
print(f"正在将 SHP 栅格化: {safe_shp_path}")
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
"""从栅格文件加载掩膜"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取掩膜文件")
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {mask_path}")
try:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
finally:
mask_dataset = None
if mask_array.shape != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
def ensure_water_mask_dat(img_path: str,
existing_dat_path: Optional[str] = None,
output_dir: Optional[str] = None) -> str:
"""
确保存在 dat 格式的水体掩膜文件用于步骤3/4中的算法
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
如果存在同名 .dat 文件,直接返回。
否则从 img_path 生成并保存到 output_dir。
Args:
img_path: 用于生成掩膜的遥感影像路径
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
output_dir: 输出目录(可选)
Returns:
dat 格式掩膜文件路径
"""
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
if Path(existing_dat_path).exists():
return existing_dat_path
img_name = Path(img_path).stem
if output_dir is None:
output_dir = str(Path(img_path).parent)
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
if Path(dat_path).exists():
return dat_path
# 如果已有其他格式的掩膜,转换为 dat
for ext in ['.tif', '.img', '.tiff']:
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
if Path(alt_path).exists():
return _convert_to_dat(alt_path, dat_path)
return dat_path # 返回目标路径,让调用方决定是否需要生成
def _convert_to_dat(src_path: str, dest_path: str) -> str:
"""将其他栅格格式转换为 ENVI dat 格式"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法转换格式")
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
if src_ds is None:
raise ValueError(f"无法打开源掩膜文件: {src_path}")
try:
geotransform = src_ds.GetGeoTransform()
projection = src_ds.GetProjection()
band = src_ds.GetRasterBand(1)
array = band.ReadAsArray()
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
if dest_ds is None:
raise ValueError(f"无法创建输出文件: {dest_path}")
try:
dest_ds.SetGeoTransform(geotransform)
dest_ds.SetProjection(projection)
dest_band = dest_ds.GetRasterBand(1)
dest_band.WriteArray((array > 0).astype(np.uint8))
dest_band.FlushCache()
finally:
dest_ds = None
return dest_path
finally:
src_ds = None