151 lines
6.2 KiB
Python
151 lines
6.2 KiB
Python
"""
|
||
掩膜超大tif文件的ROI区域并保存(分块处理,保持原格式,输出范围精确为ROI边界)
|
||
|
||
输入: tif文件路径, roi文件路径(shp/geojson等)
|
||
输出: 掩膜后的tif文件(仅包含ROI区域)
|
||
"""
|
||
|
||
import argparse
|
||
import numpy as np
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
from rasterio.features import geometry_mask
|
||
import geopandas as gpd
|
||
from shapely.geometry import box
|
||
import logging
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
|
||
# Configure the root logger at import time so the progress messages emitted
# below are visible when the file is run as a script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _pixel_aligned_roi_window(src, bounds):
    """Return the whole-pixel window of ``src`` covering ``bounds``, clipped to the raster.

    Raises ``ValueError`` when the bounds do not overlap the raster at all.
    """
    import math  # local import keeps this block self-contained

    # from_bounds yields fractional offsets/sizes; snap outward (floor the
    # start, ceil the end) so no partially covered ROI pixel is dropped and
    # the output grid stays aligned with the source pixel grid.
    raw = rasterio.windows.from_bounds(*bounds, transform=src.transform)
    col_start = max(0, math.floor(raw.col_off))
    row_start = max(0, math.floor(raw.row_off))
    col_stop = min(src.width, math.ceil(raw.col_off + raw.width))
    row_stop = min(src.height, math.ceil(raw.row_off + raw.height))

    # Clamping by hand (instead of Window.intersection) lets the disjoint
    # case reach this explicit, documented error.
    if col_stop <= col_start or row_stop <= row_start:
        raise ValueError("ROI与影像无重叠区域")
    return Window(col_start, row_start, col_stop - col_start, row_stop - row_start)


def mask_tif_by_roi_large(tif_path, roi_path, output_path=None, nodata_value=None, tile_size=4096):
    """Mask a very large TIF with an ROI vector file, processing it tile by tile.

    The output raster is cropped to the ROI bounding box, snapped outward to
    whole source pixels and clipped to the raster extent. Pixels outside the
    ROI geometries are set to the nodata value. Data type, tiling and
    compression settings are carried over from the source (LZW is used when
    the source reports no compression).

    Args:
        tif_path: Path of the input raster.
        roi_path: Path of the ROI vector file (read with ``geopandas.read_file``).
        output_path: Destination path. Defaults to
            ``<stem>_masked_roi<suffix>`` next to the input file.
        nodata_value: Fill value for pixels outside the ROI. Defaults to the
            source nodata value, or 0 when the source defines none.
        tile_size: Edge length in pixels of the square processing tiles.

    Returns:
        The output path as a string.

    Raises:
        ValueError: If ``tile_size`` is not positive, if the output path is
            the input file itself, or if the ROI does not overlap the raster.
    """
    # Validate first: a zero step would crash inside range() and a negative
    # one would silently produce an output file with no data written.
    if tile_size <= 0:
        raise ValueError(f"tile_size必须为正整数: {tile_size}")

    tif_path = Path(tif_path)
    roi_path = Path(roi_path)

    if output_path is None:
        output_path = tif_path.parent / f"{tif_path.stem}_masked_roi{tif_path.suffix}"
    else:
        output_path = Path(output_path)

    # Opening the input for writing would truncate it while it is being read.
    if output_path.resolve() == tif_path.resolve():
        raise ValueError("输出路径不能与输入文件相同")

    logger.info(f"开始处理: {tif_path.name}")
    logger.info(f"ROI文件: {roi_path.name}")

    # Load the ROI; with no features there is nothing to mask, so copy as-is.
    gdf = gpd.read_file(roi_path)
    if gdf.empty:
        logger.warning("ROI文件为空,直接复制原文件")
        import shutil
        shutil.copy2(tif_path, output_path)
        return str(output_path)

    with rasterio.open(tif_path) as src:
        # Bring the ROI into the raster CRS before any pixel math.
        if gdf.crs != src.crs:
            gdf = gdf.to_crs(src.crs)
        geometries = gdf.geometry.tolist()

        # Integer pixel window covering the ROI bounds [minx, miny, maxx, maxy].
        roi_window = _pixel_aligned_roi_window(src, gdf.total_bounds)
        roi_width = roi_window.width
        roi_height = roi_window.height
        roi_transform = rasterio.windows.transform(roi_window, src.transform)

        logger.info(f"ROI裁剪窗口: 起始列={roi_window.col_off}, 起始行={roi_window.row_off}, "
                    f"宽度={roi_width}, 高度={roi_height}")

        if nodata_value is None:
            nodata_value = src.nodata if src.nodata is not None else 0

        # Output metadata: start from the source and override size/georeferencing.
        out_meta = src.meta.copy()
        out_meta.update({
            'height': roi_height,
            'width': roi_width,
            'transform': roi_transform,
            'nodata': nodata_value,
            'compress': src.compression.value if src.compression else 'lzw',
            'tiled': src.is_tiled,
        })
        if src.is_tiled:
            # block_shapes entries are (rows, cols): rows -> blockysize,
            # cols -> blockxsize.
            block_height, block_width = src.block_shapes[0]
            out_meta.update({
                'blockxsize': block_width,
                'blockysize': block_height,
            })

        with rasterio.open(output_path, 'w', **out_meta) as dst:
            tiles_across = (roi_width + tile_size - 1) // tile_size
            tiles_down = (roi_height + tile_size - 1) // tile_size

            with tqdm(total=tiles_across * tiles_down, desc="处理瓦片", unit="块") as pbar:
                # Row-major order: finish one band of rows before moving down,
                # which matches how strips/tiles are laid out on disk.
                for row in range(0, roi_height, tile_size):
                    h = min(tile_size, roi_height - row)
                    for col in range(0, roi_width, tile_size):
                        w = min(tile_size, roi_width - col)

                        # Tile position in the output and the matching
                        # position in the source (shifted by the ROI origin).
                        out_window = Window(col, row, w, h)
                        src_window = Window(
                            roi_window.col_off + col,
                            roi_window.row_off + row,
                            w, h,
                        )

                        data = src.read(window=src_window)

                        # Rasterise the ROI on this tile's grid; invert=True
                        # makes pixels inside the ROI True.
                        win_transform = rasterio.windows.transform(src_window, src.transform)
                        inside = geometry_mask(
                            geometries,
                            transform=win_transform,
                            invert=True,
                            out_shape=(h, w),
                        )

                        # Blank everything outside the ROI in all bands at once.
                        data[:, ~inside] = nodata_value

                        dst.write(data, window=out_window)
                        pbar.update(1)

    logger.info(f"处理完成,输出文件:{output_path}")
    return str(output_path)
|
||
|
||
|
||
def main():
    """Command-line entry point: parse the arguments and run the ROI masking."""
    cli = argparse.ArgumentParser(
        description="使用ROI文件掩膜超大TIF文件(输出范围精确为ROI边界)",
    )

    # Required positional inputs.
    cli.add_argument("tif_path", help="输入TIF文件路径")
    cli.add_argument("roi_path", help="ROI文件路径 (shp/geojson等)")

    # Optional settings; omitted ones fall back to the defaults of
    # mask_tif_by_roi_large.
    cli.add_argument("-o", "--output", help="输出文件路径 (可选)")
    cli.add_argument(
        "-n", "--nodata",
        type=float,
        help="nodata值 (可选,默认使用原文件nodata)",
    )
    cli.add_argument(
        "-t", "--tile_size",
        type=int,
        default=4096,
        help="分块大小(像素),默认4096",
    )

    opts = cli.parse_args()
    mask_tif_by_roi_large(
        opts.tif_path,
        opts.roi_path,
        opts.output,
        opts.nodata,
        opts.tile_size,
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the exit() helper: that
    # helper is injected by the site module for interactive use and is not
    # available under `python -S` or in frozen builds. main() returns None,
    # so the process still exits with status 0 on success.
    raise SystemExit(main())