187 lines
6.8 KiB
Python
187 lines
6.8 KiB
Python
"""
|
||
使用二值掩膜 TIF 文件(值为1的区域需要去除)对数据 TIF 文件进行掩膜。
|
||
输入:
|
||
data_tif: 要掩膜的数据文件路径
|
||
mask_tif: 二值掩膜文件路径(值为1表示需要去除的区域)
|
||
输出:
|
||
掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData
|
||
要求:
|
||
两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐),
|
||
否则程序将报错或行为未定义。
|
||
"""
|
||
|
||
import argparse
|
||
import numpy as np
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
import logging
|
||
import sys
|
||
from pathlib import Path
|
||
from tqdm import tqdm
|
||
|
||
# Module-wide logging: INFO and above, formatted as "time - LEVEL - message".
# (logging.basicConfig does nothing if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
||
|
||
class _NullWriter:
    """File-like sink that silently discards everything written to it.

    Serves as the fallback progress-bar stream when neither ``sys.stderr``
    nor ``sys.stdout`` is usable.
    """

    def write(self, _s):
        # Drop the payload and report that zero characters were written.
        del _s
        return 0

    def flush(self):
        # Nothing is ever buffered, so flushing is a no-op.
        pass
|
||
|
||
|
||
def _tqdm_output():
    """Choose the stream the progress bar is drawn on.

    Returns ``sys.stderr`` if it exists and has ``write``, otherwise
    ``sys.stdout`` under the same condition, otherwise a fresh
    ``_NullWriter`` that discards all output.
    """
    candidates = (getattr(sys, attr, None) for attr in ("stderr", "stdout"))
    chosen = next(
        (
            stream
            for stream in candidates
            if stream is not None and hasattr(stream, "write")
        ),
        None,
    )
    return chosen if chosen is not None else _NullWriter()
|
||
|
||
|
||
def _tqdm_disable(fp) -> bool:
    """Decide whether the progress bar should be suppressed for stream *fp*.

    The bar is only shown on an interactive terminal. It is disabled for the
    ``_NullWriter`` sink, for streams whose ``isatty()`` is false, and for any
    stream where calling ``isatty()`` raises.
    """
    if isinstance(fp, _NullWriter):
        return True
    try:
        interactive = bool(fp.isatty())
    except Exception:
        # Missing or broken isatty(): treat the stream as non-interactive.
        interactive = False
    return not interactive
|
||
|
||
|
||
def _nodata_for_dtype(value, dtype):
    """Return *value* as a Python scalar that is storable in *dtype*.

    Integer dtypes accept only integral, in-range values: ``255.0`` is accepted
    for ``uint8``, but ``1.5``, ``-1`` and NaN are rejected. Floating and
    complex dtypes accept any value that does not overflow to infinity. The
    result is rounded to that dtype's precision, so the number recorded in the
    metadata is exactly the number written into the pixels.

    Raises:
        ValueError: if *value* cannot be represented in *dtype*.
    """
    np_dtype = np.dtype(dtype)
    error = ValueError(f"NoData 值 {value!r} 无法在数据类型 {np_dtype} 中表示。")
    if np.issubdtype(np_dtype, np.integer):
        try:
            as_int = int(value)
        except (TypeError, ValueError, OverflowError) as err:
            raise error from err
        info = np.iinfo(np_dtype)
        # `as_int != value` rejects fractional inputs that int() would truncate.
        if as_int != value or not info.min <= as_int <= info.max:
            raise error
        return as_int
    try:
        with np.errstate(over="ignore", invalid="ignore"):
            cast = np.array(value, dtype=np_dtype).item()
        overflowed = bool(np.isinf(cast)) and not bool(np.isinf(value))
    except (TypeError, ValueError, OverflowError) as err:
        raise error from err
    if overflowed:
        raise error
    return cast


def _same_nodata(a, b):
    """Return True if two NoData scalars are equal, treating NaN as equal to NaN."""
    if a == b:
        return True
    try:
        return bool(np.isnan(a) and np.isnan(b))
    except TypeError:
        return False


def _output_profile(src, nodata):
    """Build the creation options for the output dataset from the source dataset."""
    profile = src.meta.copy()
    profile.update(
        {
            "nodata": nodata,
            "compress": src.compression.value if src.compression else "lzw",
            "tiled": src.is_tiled,
        }
    )
    if src.is_tiled:
        # rasterio block shapes are (rows, cols), i.e. (blockysize, blockxsize).
        block_rows, block_cols = src.block_shapes[0]
        profile.update({"blockxsize": block_cols, "blockysize": block_rows})
    return profile


def _processing_windows(src, tile_size):
    """Return the windows covering *src*, in row-major order.

    A positive *tile_size* yields a regular grid of square windows, clipped at
    the right and bottom edges. Otherwise, the source's own block windows for
    band 1 are used.
    """
    stride = int(tile_size) if tile_size else 0
    if stride <= 0:
        return [window for _, window in src.block_windows(1)]
    width, height = src.width, src.height
    # Rows form the outer loop so that each band of rows is finished before the
    # next one starts. This is friendlier to striped and compressed GeoTIFF
    # output than column-major order.
    return [
        Window(col, row, min(stride, width - col), min(stride, height - row))
        for row in range(0, height, stride)
        for col in range(0, width, stride)
    ]


def mask_data_by_binary_mask(
    data_path,
    mask_path,
    output_path=None,
    remove_value=1,
    nodata_value=None,
    tile_size=4096,
):
    """Mask a data raster with a binary mask raster.

    In every band of the data raster, each pixel whose value in band 1 of the
    mask raster equals *remove_value* is set to NoData. All other pixels are
    copied unchanged. If the output NoData value differs from the source's
    NoData value, pixels that were already NoData in the source are rewritten
    to the new value, so they remain NoData.

    Performance tip: for tiled GeoTIFF input, ``tile_size=0`` iterates the
    source's own block windows, which is usually faster.

    Args:
        data_path: Path of the raster to mask.
        mask_path: Path of the mask raster. Only band 1 is read.
        output_path: Destination path. Defaults to
            ``<data stem>_masked<data suffix>`` next to the data file.
        remove_value: Mask value that marks pixels to remove.
        nodata_value: Output NoData value. Defaults to the data file's NoData
            value, or 0 if it has none. Must be representable in the data
            dtype.
        tile_size: Processing window edge length in pixels. ``0``, ``None``
            or a negative value selects the source's block windows instead.

    Returns:
        The output path, as a string.

    Raises:
        ValueError: If the CRS or pixel dimensions differ, if the output path
            equals an input path, or if the NoData value does not fit the
            data dtype.
    """
    data_path = Path(data_path)
    mask_path = Path(mask_path)
    if output_path is None:
        output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
    else:
        output_path = Path(output_path)
    # Writing onto an input would clobber a file that is still being read.
    if output_path.resolve() in (data_path.resolve(), mask_path.resolve()):
        raise ValueError(f"输出路径不能与输入文件相同: {output_path}")

    logger.info(f"数据文件: {data_path.name}")
    logger.info(f"掩膜文件: {mask_path.name}")
    logger.info(f"去除掩膜值: {remove_value}")

    with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
        with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
            if src_data.crs != src_mask.crs:
                raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
            if src_data.transform != src_mask.transform:
                logger.warning(
                    "数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
                )
            if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
                raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")

            dtype = src_data.dtypes[0]
            src_nodata = src_data.nodata
            if nodata_value is None:
                nodata_value = src_nodata if src_nodata is not None else 0
            # Use one scalar for both the metadata and the pixels, so readers
            # recognise the masked pixels as NoData.
            out_nodata = _nodata_for_dtype(nodata_value, dtype)

            # If the NoData value changes, existing NoData pixels must be
            # rewritten. Otherwise they would silently turn into valid data.
            old_nodata = None
            if src_nodata is not None:
                try:
                    old_nodata = _nodata_for_dtype(src_nodata, dtype)
                except ValueError:
                    # A value the dtype cannot hold cannot occur in the pixels.
                    old_nodata = None
            remap_old = old_nodata is not None and not _same_nodata(old_nodata, out_nodata)
            old_is_nan = remap_old and _same_nodata(old_nodata, float("nan"))
            if remap_old:
                logger.info(f"原 NoData 值 {src_nodata} 将统一替换为 {out_nodata}")

            profile = _output_profile(src_data, out_nodata)
            windows = _processing_windows(src_data, tile_size)

            with rasterio.open(output_path, "w", **profile) as dst:
                tqdm_fp = _tqdm_output()
                with tqdm(
                    total=len(windows),
                    desc="处理瓦片",
                    unit="块",
                    file=tqdm_fp,
                    disable=_tqdm_disable(tqdm_fp),
                ) as pbar:
                    for window in windows:
                        # Same window in both rasters. This relies on the pixel
                        # alignment that is checked above (by CRS and size).
                        mask = src_mask.read(1, window=window)
                        data = src_data.read(window=window)  # (bands, rows, cols)

                        # Broadcast the 2-D mask over all bands.
                        remove = np.broadcast_to(mask == remove_value, data.shape)
                        if remap_old:
                            stale = np.isnan(data) if old_is_nan else data == old_nodata
                            remove = remove | stale
                        if remove.any():
                            data[remove] = out_nodata

                        dst.write(data, window=window)
                        pbar.update(1)

    logger.info(f"处理完成,输出文件:{output_path}")
    return str(output_path)
|
||
|
||
|
||
def main():
    """Command-line entry point: parse arguments and run the masking.

    Returns None, so the process exits with status 0 on success.
    """
    cli = argparse.ArgumentParser(
        description="使用二值掩膜 TIF(值为1的区域)对数据 TIF 进行掩膜,将对应位置设为 NoData。"
    )

    # Positional inputs.
    cli.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
    cli.add_argument("mask_tif", help="二值掩膜 TIF 文件路径(值为1表示需要去除的区域)")

    # Optional settings.
    cli.add_argument("-o", "--output", help="输出文件路径 (可选)")
    cli.add_argument(
        "-r",
        "--remove_value",
        type=int,
        default=1,
        help="掩膜中要去除的值,默认为1",
    )
    cli.add_argument(
        "-n",
        "--nodata",
        type=float,
        help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)",
    )
    cli.add_argument(
        "-t",
        "--tile_size",
        type=int,
        default=4096,
        help="分块大小(像素),默认4096;设为0则按源文件块窗口遍历(tiled 文件通常更快)",
    )

    opts = cli.parse_args()
    mask_data_by_binary_mask(
        data_path=opts.data_tif,
        mask_path=opts.mask_tif,
        output_path=opts.output,
        remove_value=opts.remove_value,
        nodata_value=opts.nodata,
        tile_size=opts.tile_size,
    )
|
||
|
||
|
||
if __name__ == "__main__":
|
||
exit(main())
|