Files
water-body-segmentation/roi_裁剪.py
2026-03-09 17:23:53 +08:00

151 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mask the ROI region of a very large TIF file and save the result.

The raster is processed tile by tile, the output keeps the source file's
format, and the output extent is cropped to the ROI bounds.

Input:  path to the TIF file, path to the ROI file (shp/geojson, etc.)
Output: masked TIF file that contains only the ROI region
"""
import argparse
import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.features import geometry_mask
import geopandas as gpd
from shapely.geometry import box
import logging
from pathlib import Path
from tqdm import tqdm
# Configure root logging at import time: INFO level, "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def _pixel_aligned_extent(col_off, row_off, width, height, max_width, max_height):
    """Snap a fractional pixel window outward to whole pixels, clipped to the raster.

    Returns ``(col_start, row_start, n_cols, n_rows)`` as ints, or ``None``
    when the snapped window does not overlap ``[0, max_width) x [0, max_height)``.
    """
    # Tolerance (in pixels) so that bounds lying on a pixel edge up to
    # floating-point noise do not pull in an extra, fully masked row/column.
    eps = 1e-6
    col_start = max(0, int(np.floor(col_off + eps)))
    row_start = max(0, int(np.floor(row_off + eps)))
    col_stop = min(max_width, int(np.ceil(col_off + width - eps)))
    row_stop = min(max_height, int(np.ceil(row_off + height - eps)))
    if col_stop <= col_start or row_stop <= row_start:
        return None
    return col_start, row_start, col_stop - col_start, row_stop - row_start


def mask_tif_by_roi_large(tif_path, roi_path, output_path=None, nodata_value=None, tile_size=4096):
    """Mask a very large TIF with an ROI file, tile by tile, cropped to the ROI extent.

    The output raster covers the pixel-aligned bounding rectangle of the ROI
    (clipped to the image). Pixels outside the ROI geometries are set to the
    nodata value. Data type, compression and tiling are taken from the source.

    Args:
        tif_path: Path of the input TIF file.
        roi_path: Path of the ROI vector file (shp/geojson, etc.).
        output_path: Path of the output file. Defaults to
            ``<stem>_masked_roi<suffix>`` next to the input file.
        nodata_value: Value written outside the ROI. Defaults to the source
            nodata value, or 0 when the source defines none.
        tile_size: Edge length in pixels of the square processing tiles.

    Returns:
        The output file path as a string.

    Raises:
        ValueError: If ``tile_size`` is not positive, or if the ROI does not
            overlap the image.
    """
    # Validate before any I/O: a zero step crashes range(), a negative one
    # would silently produce an output file with no tile ever written.
    if tile_size <= 0:
        raise ValueError(f"tile_size must be a positive integer, got {tile_size!r}")

    tif_path = Path(tif_path)
    roi_path = Path(roi_path)
    if output_path is None:
        output_path = tif_path.parent / f"{tif_path.stem}_masked_roi{tif_path.suffix}"
    else:
        output_path = Path(output_path)

    logger.info(f"开始处理: {tif_path.name}")
    logger.info(f"ROI文件: {roi_path.name}")

    # Read the ROI; an empty ROI means there is nothing to mask, so the
    # input file is copied unchanged.
    gdf = gpd.read_file(roi_path)
    if gdf.empty:
        logger.warning("ROI文件为空直接复制原文件")
        import shutil
        shutil.copy2(tif_path, output_path)
        return str(output_path)

    with rasterio.open(tif_path) as src:
        # Bring the ROI into the raster CRS so bounds and masks line up.
        if gdf.crs != src.crs:
            gdf = gdf.to_crs(src.crs)
        geometries = gdf.geometry.tolist()

        # total_bounds is [minx, miny, maxx, maxy]; from_bounds returns a
        # window with fractional offsets and lengths.
        bounds = gdf.total_bounds
        raw_window = rasterio.windows.from_bounds(*bounds, transform=src.transform)

        # Snap outward to whole pixels and clip to the image. Truncating only
        # the lengths (as int() would) can drop the last ROI row/column and
        # leaves reads and the output transform off the pixel grid.
        extent = _pixel_aligned_extent(
            raw_window.col_off, raw_window.row_off,
            raw_window.width, raw_window.height,
            src.width, src.height,
        )
        if extent is None:
            raise ValueError("ROI与影像无重叠区域")
        col_start, row_start, roi_width, roi_height = extent
        roi_window = Window(col_start, row_start, roi_width, roi_height)

        # Geotransform of the cropped output raster.
        roi_transform = rasterio.windows.transform(roi_window, src.transform)
        logger.info(f"ROI裁剪窗口: 起始列={roi_window.col_off}, 起始行={roi_window.row_off}, "
                    f"宽度={roi_width}, 高度={roi_height}")

        if nodata_value is None:
            nodata_value = src.nodata if src.nodata is not None else 0

        # Output metadata: source metadata with the new size, transform,
        # nodata, and the source's compression/tiling settings.
        out_meta = src.meta.copy()
        out_meta.update({
            'height': roi_height,
            'width': roi_width,
            'transform': roi_transform,
            'nodata': nodata_value,
            'compress': src.compression.value if src.compression else 'lzw',
            'tiled': src.is_tiled,
        })
        if src.is_tiled:
            # block_shapes entries are (rows, cols): x size is the column
            # count, y size is the row count.
            block_rows, block_cols = src.block_shapes[0]
            out_meta.update({
                'blockxsize': block_cols,
                'blockysize': block_rows,
            })

        with rasterio.open(output_path, 'w', **out_meta) as dst:
            tiles_across = (roi_width + tile_size - 1) // tile_size
            tiles_down = (roi_height + tile_size - 1) // tile_size
            with tqdm(total=tiles_across * tiles_down, desc="处理瓦片", unit="") as pbar:
                # Tile offsets are relative to the ROI window (output pixel
                # coordinates). Tiles are independent, so visiting them
                # row-major does not change the result.
                for row in range(0, roi_height, tile_size):
                    h = min(tile_size, roi_height - row)
                    for col in range(0, roi_width, tile_size):
                        w = min(tile_size, roi_width - col)
                        # Tile position in the output raster ...
                        out_window = Window(col, row, w, h)
                        # ... and the matching window in the source raster.
                        src_window = Window(col_start + col, row_start + row, w, h)

                        data = src.read(window=src_window)

                        # Rasterize the ROI on this tile's grid; with
                        # invert=True the mask is True inside the ROI.
                        win_transform = rasterio.windows.transform(src_window, src.transform)
                        inside = geometry_mask(
                            geometries,
                            out_shape=(h, w),
                            transform=win_transform,
                            invert=True,
                        )
                        # Blank everything outside the ROI in all bands at once.
                        data[:, ~inside] = nodata_value

                        dst.write(data, window=out_window)
                        pbar.update(1)

    logger.info(f"处理完成,输出文件:{output_path}")
    return str(output_path)
def main():
    """Command-line entry point: parse the arguments and run the ROI masking."""
    cli = argparse.ArgumentParser(
        description="使用ROI文件掩膜超大TIF文件输出范围精确为ROI边界"
    )
    # (flags, keyword settings) for each command-line argument, in order.
    argument_specs = (
        (("tif_path",), {"help": "输入TIF文件路径"}),
        (("roi_path",), {"help": "ROI文件路径 (shp/geojson等)"}),
        (("-o", "--output"), {"help": "输出文件路径 (可选)"}),
        (("-n", "--nodata"), {"type": float, "help": "nodata值 (可选默认使用原文件nodata)"}),
        (("-t", "--tile_size"), {"type": int, "default": 4096, "help": "分块大小像素默认4096"}),
    )
    for flags, settings in argument_specs:
        cli.add_argument(*flags, **settings)

    opts = cli.parse_args()
    mask_tif_by_roi_large(
        tif_path=opts.tif_path,
        roi_path=opts.roi_path,
        output_path=opts.output,
        nodata_value=opts.nodata,
        tile_size=opts.tile_size,
    )
if __name__ == "__main__":
    # exit() is a helper the site module adds for interactive sessions and is
    # not guaranteed to exist (e.g. under `python -S`); raising SystemExit
    # has the same effect and is always available.
    raise SystemExit(main())