# -*- coding: utf-8 -*- """ 掩膜转换工具模块 提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换, 以及水体掩膜的预处理逻辑。 """ import os from pathlib import Path from typing import Optional, Union import numpy as np try: from osgeo import gdal, ogr GDAL_AVAILABLE = True except ImportError: GDAL_AVAILABLE = False def prepare_water_mask_for_algorithm( water_mask: Optional[Union[str, np.ndarray]], image_shape: Union[tuple, np.ndarray], geotransform: tuple, projection: str, img_path: str, water_mask_dir: Optional[str] = None, callback=None ) -> Optional[np.ndarray]: """ 准备水域掩膜供算法使用 支持格式: - None:自动使用预先生成的 dat 格式掩膜 - numpy.ndarray:直接返回(确保是 0/1 格式) - .dat / .tif 等栅格文件:读取并返回 - .shp 文件:先栅格化,再读取返回 Args: water_mask: 掩膜来源 image_shape: 影像形状 (height, width) 或 (height, width, channels) geotransform: GDAL 地理变换参数 projection: 投影信息 img_path: 影像路径(用于 shp 栅格化) water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果) callback: 进度回调函数(可选) Returns: numpy数组(dtype=uint8,0=非水域,1=水域)或 None """ img_height, img_width = image_shape[0], image_shape[1] if water_mask is None: return None # numpy 数组直接返回 if isinstance(water_mask, np.ndarray): if water_mask.shape[:2] != (img_height, img_width): raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配") return (water_mask > 0).astype(np.uint8) # 字符串路径 if isinstance(water_mask, str): ext = Path(water_mask).suffix.lower() # shapefile 格式 if ext == '.shp': return _convert_shp_to_mask( shp_path=water_mask, img_path=img_path, image_shape=image_shape, geotransform=geotransform, projection=projection, water_mask_dir=water_mask_dir, callback=callback ) # 栅格文件格式 return _load_raster_mask(water_mask, img_height, img_width) raise ValueError(f"不支持的掩膜类型: {type(water_mask)}") def _convert_shp_to_mask(shp_path: str, img_path: str, image_shape: tuple, geotransform: tuple, projection: str, water_mask_dir: Optional[str] = None, callback=None) -> np.ndarray: """将 shapefile 栅格化为掩膜数组""" from src.utils.extract_water_area import rasterize_shp safe_shp_path = os.path.abspath(shp_path).replace('\\', '/') shp_name = Path(safe_shp_path).stem if water_mask_dir: temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat") else: temp_mask_path = f"/tmp/water_mask_{shp_name}.dat" # 缓存:已栅格化则直接读取 if Path(temp_mask_path).exists(): print(f"使用已存在的栅格化掩膜: {temp_mask_path}") return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1]) # 需要栅格化 if img_path is None: raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化") print(f"正在将 SHP 栅格化: {safe_shp_path}") rasterize_shp(safe_shp_path, temp_mask_path, img_path) return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1]) def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray: """从栅格文件加载掩膜""" if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法读取掩膜文件") mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly) if mask_dataset is None: raise ValueError(f"无法打开掩膜文件: {mask_path}") try: mask_array = mask_dataset.GetRasterBand(1).ReadAsArray() finally: mask_dataset = None if mask_array.shape != (img_height, img_width): raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配") return (mask_array > 0).astype(np.uint8) def ensure_water_mask_dat(img_path: str, existing_dat_path: Optional[str] = None, output_dir: Optional[str] = None) -> str: """ 确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法) 如果 existing_dat_path 存在且是 .dat 文件,直接返回。 如果存在同名 .dat 文件,直接返回。 否则从 img_path 生成并保存到 output_dir。 Args: img_path: 用于生成掩膜的遥感影像路径 existing_dat_path: 已有的 dat 格式掩膜路径(可选) output_dir: 输出目录(可选) Returns: dat 格式掩膜文件路径 """ if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat': if Path(existing_dat_path).exists(): return existing_dat_path img_name = Path(img_path).stem if output_dir is None: output_dir = str(Path(img_path).parent) dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat") if Path(dat_path).exists(): return dat_path # 如果已有其他格式的掩膜,转换为 dat for ext in ['.tif', '.img', '.tiff']: alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}") if Path(alt_path).exists(): return _convert_to_dat(alt_path, dat_path) return dat_path # 返回目标路径,让调用方决定是否需要生成 def _convert_to_dat(src_path: str, dest_path: str) -> str: """将其他栅格格式转换为 ENVI dat 格式""" if not GDAL_AVAILABLE: raise ImportError("GDAL未安装,无法转换格式") src_ds = gdal.Open(src_path, gdal.GA_ReadOnly) if src_ds is None: raise ValueError(f"无法打开源掩膜文件: {src_path}") try: geotransform = src_ds.GetGeoTransform() projection = src_ds.GetProjection() band = src_ds.GetRasterBand(1) array = band.ReadAsArray() driver = gdal.GetDriverByName('ENVI') if driver is None: driver = gdal.GetDriverByName('GTiff') dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte) if dest_ds is None: raise ValueError(f"无法创建输出文件: {dest_path}") try: dest_ds.SetGeoTransform(geotransform) dest_ds.SetProjection(projection) dest_band = dest_ds.GetRasterBand(1) dest_band.WriteArray((array > 0).astype(np.uint8)) dest_band.FlushCache() finally: dest_ds = None return dest_path finally: src_ds = None