Initial commit
This commit is contained in:
151
roi_裁剪.py
Normal file
151
roi_裁剪.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
掩膜超大tif文件的ROI区域并保存(分块处理,保持原格式,输出范围精确为ROI边界)
|
||||
|
||||
输入: tif文件路径, roi文件路径(shp/geojson等)
|
||||
输出: 掩膜后的tif文件(仅包含ROI区域)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.windows import Window
|
||||
from rasterio.features import geometry_mask
|
||||
import geopandas as gpd
|
||||
from shapely.geometry import box
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# Module-wide logging: INFO-level records with a timestamp and level prefix.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _pixel_aligned_roi_window(src, bounds):
    """Return the integer pixel window of ``src`` that fully covers ``bounds``.

    ``rasterio.windows.from_bounds`` yields fractional offsets/sizes whenever
    the bounds do not fall exactly on pixel edges. The start is floored and
    the stop is ceiled so every pixel touched by the bounds is included, then
    the result is clamped to the raster extent.

    Raises:
        ValueError: If the bounds do not overlap the raster.
    """
    raw = rasterio.windows.from_bounds(*bounds, transform=src.transform)
    col_start = max(0, int(np.floor(raw.col_off)))
    row_start = max(0, int(np.floor(raw.row_off)))
    col_stop = min(src.width, int(np.ceil(raw.col_off + raw.width)))
    row_stop = min(src.height, int(np.ceil(raw.row_off + raw.height)))
    if col_stop <= col_start or row_stop <= row_start:
        raise ValueError("ROI与影像无重叠区域")
    return Window(col_start, row_start, col_stop - col_start, row_stop - row_start)


def _build_output_meta(src, roi_window, nodata_value):
    """Build the creation options of the cropped output from the source dataset."""
    out_meta = src.meta.copy()
    out_meta.update({
        'height': int(roi_window.height),
        'width': int(roi_window.width),
        'transform': rasterio.windows.transform(roi_window, src.transform),
        'nodata': nodata_value,
        # Uncompressed sources are written with LZW.
        'compress': src.compression.value if src.compression else 'lzw',
        'tiled': src.is_tiled,
    })
    if src.is_tiled:
        # block_shapes entries are ordered (rows, cols), i.e. (ysize, xsize).
        block_rows, block_cols = src.block_shapes[0]
        out_meta.update({
            'blockxsize': block_cols,
            'blockysize': block_rows,
        })
    return out_meta


def mask_tif_by_roi_large(tif_path, roi_path, output_path=None, nodata_value=None, tile_size=4096):
    """Mask a very large TIF with an ROI vector file, processing it tile by tile.

    The output is cropped to the pixel-aligned bounding rectangle of the ROI
    (clamped to the raster extent). Pixels outside the ROI geometries are set
    to the nodata value. Compression and tiling settings are taken from the
    source file.

    Args:
        tif_path: Path of the input raster.
        roi_path: Path of the ROI vector file (read with ``geopandas.read_file``).
        output_path: Destination path. Defaults to
            ``<stem>_masked_roi<suffix>`` next to the input file.
        nodata_value: Value written outside the ROI. Defaults to the source
            nodata value, or 0 when the source defines none.
        tile_size: Edge length in pixels of the square processing tiles.

    Returns:
        The output path as a string.

    Raises:
        ValueError: If ``tile_size`` is not positive or the ROI does not
            overlap the raster.
    """
    if tile_size <= 0:
        raise ValueError("tile_size必须为正整数")

    tif_path = Path(tif_path)
    roi_path = Path(roi_path)

    if output_path is None:
        output_path = tif_path.parent / f"{tif_path.stem}_masked_roi{tif_path.suffix}"
    else:
        output_path = Path(output_path)

    logger.info(f"开始处理: {tif_path.name}")
    logger.info(f"ROI文件: {roi_path.name}")

    # Load the ROI; with no features there is nothing to mask, so copy the input.
    gdf = gpd.read_file(roi_path)
    if gdf.empty:
        logger.warning("ROI文件为空,直接复制原文件")
        import shutil
        shutil.copy2(tif_path, output_path)
        return str(output_path)

    with rasterio.open(tif_path) as src:
        # Bring the ROI into the raster CRS before deriving any pixel window.
        if gdf.crs != src.crs:
            gdf = gdf.to_crs(src.crs)
        geometries = gdf.geometry.tolist()

        # total_bounds is [minx, miny, maxx, maxy] over all ROI features.
        roi_window = _pixel_aligned_roi_window(src, gdf.total_bounds)
        roi_width = int(roi_window.width)
        roi_height = int(roi_window.height)

        logger.info(f"ROI裁剪窗口: 起始列={roi_window.col_off}, 起始行={roi_window.row_off}, "
                    f"宽度={roi_width}, 高度={roi_height}")

        if nodata_value is None:
            nodata_value = src.nodata if src.nodata is not None else 0

        out_meta = _build_output_meta(src, roi_window, nodata_value)

        with rasterio.open(output_path, 'w', **out_meta) as dst:
            # Tile offsets are relative to the ROI window (output pixel space).
            tiles_across = (roi_width + tile_size - 1) // tile_size
            tiles_down = (roi_height + tile_size - 1) // tile_size

            with tqdm(total=tiles_across * tiles_down, desc="处理瓦片", unit="块") as pbar:
                # Rows in the outer loop so tiles are visited row by row.
                for row_off in range(0, roi_height, tile_size):
                    h = min(tile_size, roi_height - row_off)
                    for col_off in range(0, roi_width, tile_size):
                        w = min(tile_size, roi_width - col_off)

                        # Same tile expressed in output and in source pixel space.
                        out_window = Window(col_off, row_off, w, h)
                        src_window = Window(
                            roi_window.col_off + col_off,
                            roi_window.row_off + row_off,
                            w, h
                        )

                        data = src.read(window=src_window)

                        # Rasterize the ROI on this tile's grid; invert=True
                        # makes pixels inside the ROI True.
                        win_transform = rasterio.windows.transform(src_window, src.transform)
                        inside = geometry_mask(
                            geometries,
                            transform=win_transform,
                            invert=True,
                            out_shape=(h, w)
                        )

                        # Blank everything outside the ROI in all bands at once.
                        data[:, ~inside] = nodata_value

                        dst.write(data, window=out_window)
                        pbar.update(1)

    logger.info(f"处理完成,输出文件:{output_path}")
    return str(output_path)
|
||||
|
||||
|
||||
def main(argv=None):
    """Parse command-line arguments and run the ROI masking.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the process command line, as before.

    A non-positive tile size is rejected with a usage error before any file
    is opened.
    """
    parser = argparse.ArgumentParser(description="使用ROI文件掩膜超大TIF文件(输出范围精确为ROI边界)")
    parser.add_argument("tif_path", help="输入TIF文件路径")
    parser.add_argument("roi_path", help="ROI文件路径 (shp/geojson等)")
    parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
    parser.add_argument("-n", "--nodata", type=float, help="nodata值 (可选,默认使用原文件nodata)")
    parser.add_argument("-t", "--tile_size", type=int, default=4096, help="分块大小(像素),默认4096")

    args = parser.parse_args(argv)
    if args.tile_size <= 0:
        # parser.error prints the usage text and exits with status 2.
        parser.error("分块大小必须为正整数")

    mask_tif_by_roi_large(args.tif_path, args.roi_path, args.output, args.nodata, args.tile_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use SystemExit directly: the exit() builtin is injected by the site
    # module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user