Files
StripStitch/mask_water.py
2026-03-06 17:24:55 +08:00

131 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
掩膜tif文件的ROI区域并保存
输入: tif文件路径, roi文件路径(shp/geojson等)
输出: 掩膜后的tif文件
"""
import argparse
import numpy as np
import rasterio
from rasterio.mask import mask
from rasterio.features import shapes
import geopandas as gpd
# from shapely.geometry import shape
import logging
from pathlib import Path
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def mask_tif_by_roi(tif_path: str, roi_path: str, output_path: str = None,
nodata_value: float = None):
"""
使用ROI文件掩膜TIF文件
Args:
tif_path: 输入TIF文件路径
roi_path: ROI文件路径shp/geojson等
output_path: 输出文件路径如果为None则自动生成
nodata_value: nodata值如果为None则使用原始文件的nodata值
Returns:
str: 输出文件路径
"""
tif_path = Path(tif_path)
roi_path = Path(roi_path)
# 检查输入文件存在性
if not tif_path.exists():
raise FileNotFoundError(f"TIF文件不存在: {tif_path}")
if not roi_path.exists():
raise FileNotFoundError(f"ROI文件不存在: {roi_path}")
# 自动生成输出路径
if output_path is None:
output_path = tif_path.parent / f"{tif_path.stem}_masked{tif_path.suffix}"
else:
output_path = Path(output_path)
logger.info(f"开始处理: {tif_path.name}")
logger.info(f"ROI文件: {roi_path.name}")
try:
# 读取ROI文件
logger.info("读取ROI文件...")
gdf = gpd.read_file(roi_path)
if gdf.empty:
logger.warning("ROI文件为空直接复制原文件")
# 直接复制原文件
import shutil
shutil.copy2(tif_path, output_path)
return str(output_path)
# 转换为GeoJSON格式的几何对象
geometries = gdf.geometry.tolist()
logger.info(f"找到 {len(geometries)} 个ROI几何对象")
# 读取并掩膜TIF
logger.info("读取并掩膜TIF文件...")
with rasterio.open(tif_path) as src:
# 使用ROI掩膜 - 保留ROI以外的区域ROI区域设为nodata
masked_data, masked_transform = mask(
src,
geometries,
crop=False, # 不裁剪,保持原始尺寸
invert=True, # 反转ROI区域设为nodataROI以外保留
nodata=nodata_value if nodata_value is not None else src.nodata
)
# 更新元数据
out_meta = src.meta.copy()
out_meta.update({
"driver": "GTiff",
"height": masked_data.shape[1],
"width": masked_data.shape[2],
"transform": masked_transform,
"nodata": nodata_value if nodata_value is not None else src.nodata
})
# 保存结果
logger.info(f"保存掩膜结果到: {output_path.name}")
with rasterio.open(output_path, "w", **out_meta) as dest:
dest.write(masked_data)
logger.info("处理完成!")
return str(output_path)
except Exception as e:
logger.error(f"处理失败: {str(e)}")
raise
def main():
parser = argparse.ArgumentParser(description="使用ROI文件掩膜TIF文件")
parser.add_argument("tif_path", help="输入TIF文件路径")
parser.add_argument("roi_path", help="ROI文件路径 (shp/geojson等)")
parser.add_argument("-o", "--output", help="输出文件路径 (可选,默认自动生成)")
parser.add_argument("-n", "--nodata", type=float, help="nodata值 (可选默认使用原文件nodata)")
args = parser.parse_args()
try:
output_path = mask_tif_by_roi(
args.tif_path,
args.roi_path,
args.output,
args.nodata
)
print(f"成功完成!输出文件: {output_path}")
except Exception as e:
print(f"错误: {e}")
return 1
return 0
if __name__ == "__main__":
exit(main())