Files
water-body-segmentation/tif_caijain.py
2026-03-10 17:29:24 +08:00

155 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
使用二值掩膜 TIF 文件值为1的区域需要去除对数据 TIF 文件进行掩膜。
输入:
data_tif: 要掩膜的数据文件路径
mask_tif: 二值掩膜文件路径值为1表示需要去除的区域
输出:
掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData
要求:
两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐),
否则程序将报错或行为未定义。
"""
import argparse
import numpy as np
import rasterio
from rasterio.windows import Window
import logging
from pathlib import Path
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def mask_data_by_binary_mask(
data_path,
mask_path,
output_path=None,
remove_value=1,
nodata_value=None,
tile_size=4096,
):
"""使用二值掩膜 TIF 对数据 TIF 进行掩膜。
将数据 TIF 中对应掩膜值等于 remove_value 的像素设为 NoData其余保留。
性能建议:
- 若数据源是 tiled GeoTIFF可将 tile_size 设为 0 以按源文件块窗口遍历(通常更快)。
"""
data_path = Path(data_path)
mask_path = Path(mask_path)
if output_path is None:
output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
else:
output_path = Path(output_path)
logger.info(f"数据文件: {data_path.name}")
logger.info(f"掩膜文件: {mask_path.name}")
logger.info(f"去除掩膜值: {remove_value}")
with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
if src_data.crs != src_mask.crs:
raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
if src_data.transform != src_mask.transform:
logger.warning(
"数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
)
if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")
# 确定输出 NoData 值(并尽量匹配数据 dtype避免隐式类型转换带来的开销
if nodata_value is None:
nodata_value = src_data.nodata if src_data.nodata is not None else 0
try:
nodata_value_cast = np.array(
nodata_value, dtype=src_data.dtypes[0]
).item()
except Exception:
nodata_value_cast = nodata_value
# 创建输出元数据:基于数据源的元数据,更新 nodata 和压缩选项
out_meta = src_data.meta.copy()
out_meta.update(
{
"nodata": nodata_value,
"compress": (
src_data.compression.value if src_data.compression else "lzw"
),
"tiled": src_data.is_tiled,
}
)
if src_data.is_tiled:
out_meta.update(
{
"blockxsize": src_data.block_shapes[0][0],
"blockysize": src_data.block_shapes[0][1],
}
)
# 创建输出文件
with rasterio.open(output_path, "w", **out_meta) as dst:
width, height = src_data.width, src_data.height
if tile_size is None or tile_size <= 0:
windows = [w for _, w in src_data.block_windows(1)]
else:
stride = int(tile_size)
windows = [
Window(i, j, min(stride, width - i), min(stride, height - j))
for i in range(0, width, stride)
for j in range(0, height, stride)
]
with tqdm(total=len(windows), desc="处理瓦片", unit="") as pbar:
for window in windows:
# 读取相同位置的掩膜瓦片(假设完全对齐)
mask = src_mask.read(1, window=window)
remove_mask = mask == remove_value
# 读取数据瓦片
data = src_data.read(window=window) # shape: (bands, h, w)
if remove_mask.any():
for band_idx in range(data.shape[0]):
np.putmask(
data[band_idx], remove_mask, nodata_value_cast
)
dst.write(data, window=window)
pbar.update(1)
logger.info(f"处理完成,输出文件:{output_path}")
return str(output_path)
def main():
parser = argparse.ArgumentParser(
description="使用二值掩膜 TIF值为1的区域对数据 TIF 进行掩膜,将对应位置设为 NoData。"
)
parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径值为1表示需要去除的区域")
parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
parser.add_argument("-r", "--remove_value", type=int, default=1,
help="掩膜中要去除的值默认为1")
parser.add_argument("-n", "--nodata", type=float,
help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)")
parser.add_argument(
"-t",
"--tile_size",
type=int,
default=4096,
help="分块大小像素默认4096设为0则按源文件块窗口遍历tiled 文件通常更快)",
)
args = parser.parse_args()
mask_data_by_binary_mask(
args.data_tif, args.mask_tif, args.output,
args.remove_value, args.nodata, args.tile_size
)
if __name__ == "__main__":
exit(main())