推送
This commit is contained in:
186
tif_caijain.py
Normal file
186
tif_caijain.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""
|
||||
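Mask a data TIF file with a binary mask TIF file (areas where the mask is 1 are removed).

Inputs:
    data_tif: path of the data file to be masked
    mask_tif: path of the binary mask file (a value of 1 marks areas to remove)

Output:
    The masked data TIF file; only the positions covered by the mask are set to NoData.

Requirements:
    Both TIF files must share the same projection, resolution, extent and size
    (exactly aligned); otherwise the program raises an error or its behaviour
    is undefined.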
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.windows import Window
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
# Configure the root logger at import time so every record carries a timestamp
# and level, then create the module-level logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||||
|
||||
class _NullWriter:
    """Minimal file-like sink that discards everything written to it."""

    def write(self, _s):
        """Ignore ``_s`` and report that zero characters were written."""
        return 0

    def flush(self):
        """Do nothing and return ``None``; there is no buffer to flush."""
|
||||
|
||||
|
||||
def _tqdm_output():
    """Pick the stream the progress bar should be written to.

    ``sys.stderr`` is tried first, then ``sys.stdout``. A candidate is accepted
    only if it is not ``None`` and has a ``write`` attribute. When neither
    qualifies, a ``_NullWriter`` instance is returned instead.
    """
    candidates = (getattr(sys, name, None) for name in ("stderr", "stdout"))
    usable = next(
        (
            stream
            for stream in candidates
            if stream is not None and hasattr(stream, "write")
        ),
        None,
    )
    if usable is None:
        return _NullWriter()
    return usable
|
||||
|
||||
|
||||
def _tqdm_disable(fp) -> bool:
    """Decide whether the progress bar should be switched off for ``fp``.

    Returns ``True`` (disable) when ``fp`` is a ``_NullWriter``, when
    ``fp.isatty()`` is falsy, or when calling ``fp.isatty()`` raises.
    """
    if not isinstance(fp, _NullWriter):
        try:
            interactive = bool(fp.isatty())
        except Exception:
            # Best effort: a stream that cannot answer is treated as
            # non-interactive, so the bar is disabled.
            interactive = False
        return not interactive
    return True
|
||||
|
||||
|
||||
def mask_data_by_binary_mask(
    data_path,
    mask_path,
    output_path=None,
    remove_value=1,
    nodata_value=None,
    tile_size=4096,
):
    """Set pixels of a data TIF to NoData wherever a mask TIF matches a value.

    Every pixel whose value in band 1 of the mask equals ``remove_value`` is
    overwritten with the NoData value in all bands of the data raster; all
    other pixels are copied through unchanged. The rasters are processed
    window by window rather than being read in one piece.

    Args:
        data_path: Path of the raster to be masked.
        mask_path: Path of the mask raster. Only band 1 is read.
        output_path: Destination path. Defaults to ``<stem>_masked<suffix>``
            in the directory of ``data_path``.
        remove_value: Mask value that marks pixels to remove.
        nodata_value: NoData value to write. Defaults to the NoData value of
            the data raster, or ``0`` when it has none.
        tile_size: Edge length in pixels of the square processing windows.
            ``None`` or a value that truncates to ``<= 0`` iterates over the
            block windows of band 1 of the data raster instead.

    Returns:
        The output path as a string.

    Raises:
        ValueError: If the output path is the same file as the data or mask
            input, if the two rasters have different CRS, or if their
            width/height differ.
    """
    data_path = Path(data_path)
    mask_path = Path(mask_path)

    if output_path is None:
        output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
    else:
        output_path = Path(output_path)

    # The destination is opened in "w" mode while both inputs are still open
    # for reading; writing onto an input would destroy it mid-read.
    if output_path.resolve() in (data_path.resolve(), mask_path.resolve()):
        raise ValueError("输出文件不能与输入的数据或掩膜文件相同。")

    logger.info("数据文件: %s", data_path.name)
    logger.info("掩膜文件: %s", mask_path.name)
    logger.info("去除掩膜值: %s", remove_value)

    with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
        with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
            if src_data.crs != src_mask.crs:
                raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
            if src_data.transform != src_mask.transform:
                # Deliberately only a warning: processing continues pixel by pixel.
                logger.warning(
                    "数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
                )
            if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
                raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")

            # Resolve the NoData value and cast it to the dtype of the first
            # band so the written pixels and the NoData tag agree exactly.
            if nodata_value is None:
                nodata_value = src_data.nodata if src_data.nodata is not None else 0
            try:
                nodata_value_cast = np.array(
                    nodata_value, dtype=src_data.dtypes[0]
                ).item()
            except (TypeError, ValueError, OverflowError):
                logger.warning(
                    "NoData 值 %r 无法转换为数据类型 %s,将按原值使用。",
                    nodata_value,
                    src_data.dtypes[0],
                )
                nodata_value_cast = nodata_value

            # Output metadata: start from the source profile, then set the
            # NoData tag, compression and tiling.
            is_tiled = src_data.is_tiled
            out_meta = src_data.meta.copy()
            out_meta.update(
                {
                    "nodata": nodata_value_cast,
                    "compress": (
                        src_data.compression.value if src_data.compression else "lzw"
                    ),
                    "tiled": is_tiled,
                }
            )
            if is_tiled:
                # rasterio block shapes are ordered (rows, cols), i.e.
                # (block height, block width): x size is the column count.
                block_rows, block_cols = src_data.block_shapes[0]
                out_meta.update(
                    {
                        "blockxsize": block_cols,
                        "blockysize": block_rows,
                    }
                )

            with rasterio.open(output_path, "w", **out_meta) as dst:
                width, height = src_data.width, src_data.height

                stride = int(tile_size) if tile_size is not None else 0
                if stride <= 0:
                    windows = [w for _, w in src_data.block_windows(1)]
                else:
                    # Row-major order (all columns of a row band first) so the
                    # traversal follows the row-wise layout of the raster.
                    windows = [
                        Window(
                            col,
                            row,
                            min(stride, width - col),
                            min(stride, height - row),
                        )
                        for row in range(0, height, stride)
                        for col in range(0, width, stride)
                    ]

                tqdm_fp = _tqdm_output()
                with tqdm(
                    total=len(windows),
                    desc="处理瓦片",
                    unit="块",
                    file=tqdm_fp,
                    disable=_tqdm_disable(tqdm_fp),
                ) as pbar:
                    for window in windows:
                        # Mask and data are read with the same window; the
                        # size check above guarantees matching pixel grids.
                        remove_mask = src_mask.read(1, window=window) == remove_value

                        data = src_data.read(window=window)  # (bands, h, w)

                        if remove_mask.any():
                            for band in data:
                                # ``band`` is a view, so this edits ``data`` in place.
                                np.putmask(band, remove_mask, nodata_value_cast)

                        dst.write(data, window=window)
                        pbar.update(1)

    logger.info("处理完成,输出文件:%s", output_path)
    return str(output_path)
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run ``mask_data_by_binary_mask`` with it.

    Positional arguments are the data TIF and the mask TIF; ``-o``, ``-r``,
    ``-n`` and ``-t`` supply the output path, the mask value to remove, the
    NoData value and the tile size respectively.
    """
    cli = argparse.ArgumentParser(
        description="使用二值掩膜 TIF(值为1的区域)对数据 TIF 进行掩膜,将对应位置设为 NoData。"
    )
    cli.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
    cli.add_argument("mask_tif", help="二值掩膜 TIF 文件路径(值为1表示需要去除的区域)")
    cli.add_argument("-o", "--output", help="输出文件路径 (可选)")
    cli.add_argument(
        "-r",
        "--remove_value",
        type=int,
        default=1,
        help="掩膜中要去除的值,默认为1",
    )
    cli.add_argument(
        "-n",
        "--nodata",
        type=float,
        help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)",
    )
    cli.add_argument(
        "-t",
        "--tile_size",
        type=int,
        default=4096,
        help="分块大小(像素),默认4096;设为0则按源文件块窗口遍历(tiled 文件通常更快)",
    )

    opts = cli.parse_args()
    mask_data_by_binary_mask(
        data_path=opts.data_tif,
        mask_path=opts.mask_tif,
        output_path=opts.output,
        remove_value=opts.remove_value,
        nodata_value=opts.nodata,
        tile_size=opts.tile_size,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The builtin ``exit`` is injected by the ``site`` module and can be
    # missing (e.g. ``python -S`` or frozen builds); ``sys.exit`` always exists.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user