Files
StripStitch/test.py
2026-03-06 17:24:55 +08:00

291 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
使用地理信息进行粗定位,然后用图像匹配进行精配准
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
# Logging: timestamped records at INFO level and above, configured once on import.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ---------- Configuration ----------
# Adjust these paths to the local environment.
REF_TIF = r"E:\is2\jiashixian\result.tif"  # reference raster used as the target grid
BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1")  # folder scanned for *.bip inputs
OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma")  # destination folder

# Matcher passed to vismatch.get_matcher.
# Other names: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, ...
MATCHER_NAME = "matchanything-roma"
DEVICE = "cuda"  # or "cpu"

# Matching geometry.
MATCH_MAX_SIDE = 1200  # longest side (pixels) of the images handed to the matcher
ROI_PAD_PX = 300  # padding around the coarse ROI, in reference-raster pixels

# Quality gates applied to the matcher output.
MIN_INLIERS = 30  # minimum number of inliers
MIN_INLIER_RATIO = 0.15  # minimum inliers / matches

# The output folder is created at import time.
OUT_DIR.mkdir(exist_ok=True, parents=True)
# ---------- 工具函数 ----------
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow a raster window by ``pad`` pixels on every side, clipped to the raster.

    Args:
        win: window with ``col_off``, ``row_off``, ``width`` and ``height``; these
            may be fractional (e.g. when produced by ``from_bounds``).
        pad: padding in pixels added on each side.
        max_w: raster width used as the right-hand clipping limit.
        max_h: raster height used as the bottom clipping limit.

    Returns:
        An integer ``rasterio.windows.Window``. The start is floored and the end
        ceiled, so fractional edge pixels are kept rather than cut off. When the
        padded window does not overlap the raster, its width and/or height is 0
        (never negative).
    """
    col_off = max(0, int(np.floor(win.col_off - pad)))
    row_off = max(0, int(np.floor(win.row_off - pad)))
    col_end = min(max_w, int(np.ceil(win.col_off + win.width + pad)))
    row_end = min(max_h, int(np.ceil(win.row_off + win.height + pad)))
    return rasterio.windows.Window(
        col_off,
        row_off,
        max(0, col_end - col_off),
        max(0, row_end - row_off),
    )
def _estimate_src_to_roi(src, ref_dataset, win, matcher, name) -> "np.ndarray | None":
    """Match ``src`` against the reference window ``win`` and return a full-resolution 3x3 affine.

    The returned matrix maps full-resolution source pixel coordinates to pixel
    coordinates inside ``win``. Returns None (after logging a warning) when the
    match fails the MIN_INLIERS / MIN_INLIER_RATIO gates or the affine fit fails.
    """
    # Only the first (up to) 3 bands feed the matcher, so read just those rather
    # than the whole (possibly hyperspectral) cube.
    src_img = _to_3ch_float01(src.read(list(range(1, min(src.count, 3) + 1))))
    ref_img = _to_3ch_float01(
        ref_dataset.read(list(range(1, min(ref_dataset.count, 3) + 1)), window=win)
    )

    # Match on downsampled copies: faster, and usually more stable.
    src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
    ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
    logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")

    # Fine registration: img0 = source, img1 = reference ROI.
    result = matcher(src_small, ref_small)
    num_inl = int(result["num_inliers"])
    num_m = len(result["matched_kpts0"])
    ratio = (num_inl / num_m) if num_m else 0.0
    logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
    if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
        logger.warning(f"匹配质量不足: {name}")
        return None

    k0 = np.asarray(result["inlier_kpts0"], dtype=np.float32).reshape(-1, 2)
    k1 = np.asarray(result["inlier_kpts1"], dtype=np.float32).reshape(-1, 2)
    # Similarity transform (rotation + uniform scale + translation) between the small grids.
    affine_small, _ = cv2.estimateAffinePartial2D(
        k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0
    )
    if affine_small is None:
        logger.warning(f"仿射估计失败: {name}")
        return None

    # Lift small->small to full->full: M_full = S1_inv @ A @ S0, where S0 shrinks
    # full-res source coordinates onto the small grid and S1_inv enlarges small
    # ROI coordinates back to full-res ROI coordinates. Per-axis factors absorb
    # the rounding done when the images were downscaled.
    # NOTE(review): this scales about the (0, 0) pixel corner. If the matcher
    # reports pixel-centre coordinates, a half-pixel offset of 0.5 * (scale - 1)
    # is ignored here — confirm the vismatch coordinate convention.
    to_small = np.diag([
        src_small.shape[2] / src_img.shape[2],
        src_small.shape[1] / src_img.shape[1],
        1.0,
    ])
    to_full_roi = np.diag([
        ref_img.shape[2] / ref_small.shape[2],
        ref_img.shape[1] / ref_small.shape[1],
        1.0,
    ])
    a3 = np.vstack([affine_small, [0.0, 0.0, 1.0]])
    return to_full_roi @ a3 @ to_small


def _roi_affine_to_map(m_full, win, ref_transform) -> "Affine":
    """Compose source pixel -> ROI pixel -> reference pixel -> reference map coordinates.

    ``m_full`` is a 3x3 matrix mapping source pixels to pixels inside ``win``.
    The window offset shifts them to full-reference pixel indices, and
    ``ref_transform`` takes those to map coordinates. The result serves as the
    corrected geotransform of the source in the reference CRS.
    """
    roi_to_ref = Affine.translation(win.col_off, win.row_off)
    src_to_roi = Affine(*m_full[0, :3], *m_full[1, :3])
    return ref_transform * roi_to_ref * src_to_roi


def _band_to_dtype(band: np.ndarray, dtype, nodata) -> np.ndarray:
    """Convert a resampled float32 band to the output ``dtype``.

    For integer outputs, values are rounded to the nearest integer (plain
    truncation would bias every interpolated value toward zero) and clipped to
    the dtype's range. Pixels equal to ``nodata`` (when given) keep that exact
    value. Float outputs are cast directly.
    """
    dtype = np.dtype(dtype)
    if not np.issubdtype(dtype, np.integer):
        return band.astype(dtype)
    info = np.iinfo(dtype)
    out = np.clip(np.rint(band), info.min, info.max).astype(dtype)
    if nodata is not None:
        out[band == nodata] = nodata
    return out


def _log_band_stats(band_idx: int, band: np.ndarray, nodata) -> None:
    """Log min/max/mean over valid pixels and the nodata fraction of a resampled band."""
    # NaN never compares equal to itself, so a NaN nodata value needs isnan().
    is_nodata = np.isnan(band) if np.isnan(nodata) else (band == nodata)
    valid = band[~is_nodata & np.isfinite(band)]
    if valid.size:
        vmin, vmax, vmean = float(valid.min()), float(valid.max()), float(valid.mean())
    else:
        vmin = vmax = vmean = float("nan")
    logger.info(
        f"波段 {band_idx}: min={vmin:.2f}, max={vmax:.2f}, "
        f"mean={vmean:.2f}, nodata占比={float(is_nodata.mean()):.2%}"
    )


def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
    """Register one .bip raster to the reference raster and write it on the reference grid.

    Steps:
      1. Use georeferencing to cut a padded ROI of the reference around the
         source footprint.
      2. Match source vs. ROI with ``matcher`` and fit a similarity transform.
      3. Derive a corrected source geotransform in the reference CRS.
      4. Bilinearly resample every source band onto the full reference grid,
         and write ``<stem>_registered.bip`` (ENVI driver, BIP interleave) to
         ``out_dir``.

    Args:
        bip_path: path of the source raster.
        ref_dataset: open rasterio dataset of the reference; it must have a CRS.
        matcher: callable ``matcher(img0, img1)`` returning a mapping with
            ``num_inliers``, ``matched_kpts0``, ``inlier_kpts0`` and ``inlier_kpts1``.
        out_dir: output directory (expected to exist already).

    Returns:
        True on success. False when there is no overlap, the match is too weak,
        the affine fit fails, or any exception occurs (logged with its traceback).
    """
    try:
        with rasterio.open(bip_path) as src:
            logger.info(f"处理文件: {bip_path.name}")
            if src.crs is None:
                logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
                src_crs = ref_dataset.crs
            else:
                src_crs = src.crs
            ref_crs = ref_dataset.crs
            if ref_crs is None:
                raise RuntimeError("参考文件缺少CRS信息")

            # 1) Coarse localisation: project the source footprint into the
            #    reference CRS and cut a padded ROI of the reference around it.
            bounds = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
            win = _expand_window(
                from_bounds(*bounds, transform=ref_dataset.transform),
                ROI_PAD_PX,
                ref_dataset.width,
                ref_dataset.height,
            )
            if win.width <= 0 or win.height <= 0:
                logger.warning(f"无重叠区域: {bip_path.name}")
                return False

            # 2) Fine registration by image matching (full-res source -> ROI pixels).
            m_full = _estimate_src_to_roi(src, ref_dataset, win, matcher, bip_path.name)
            if m_full is None:
                return False

            # 3) Corrected source geotransform, expressed in the reference CRS.
            ref_transform = ref_dataset.transform
            corrected_affine = _roi_affine_to_map(m_full, win, ref_transform)

            # 4) Resample all bands onto the full reference grid (ENVI, BIP interleave).
            out_path = out_dir / f"{bip_path.stem}_registered.bip"
            src_nodata = src.nodata  # may be None, 0, 65535, ...
            # Without a source nodata value, 0 is used as the output nodata value.
            dst_nodata = src_nodata if src_nodata is not None else 0
            out_dtype = src.dtypes[0]  # keep the source data type
            # Build a fresh profile instead of copying the reference's profile, so
            # GeoTIFF-only creation options (tiling, block sizes, compression, ...)
            # are not passed to the ENVI driver.
            out_profile = dict(
                driver="ENVI",
                dtype=out_dtype,
                height=ref_dataset.height,
                width=ref_dataset.width,
                count=src.count,
                transform=ref_transform,
                crs=ref_crs,
                interleave="bip",
                nodata=dst_nodata,
            )
            with rasterio.open(out_path, "w", **out_profile) as out_ds:
                for band_idx in range(1, src.count + 1):
                    # Resample in float32, then convert to the output dtype.
                    dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32)
                    reproject(
                        source=src.read(band_idx).astype(np.float32),
                        destination=dst_band,
                        src_transform=corrected_affine,
                        # corrected_affine already maps source pixels into the reference CRS.
                        src_crs=ref_crs,
                        dst_transform=ref_transform,
                        dst_crs=ref_crs,
                        src_nodata=src_nodata,
                        dst_nodata=dst_nodata,
                        resampling=Resampling.bilinear,
                    )
                    _log_band_stats(band_idx, dst_band, dst_nodata)
                    out_ds.write(_band_to_dtype(dst_band, out_dtype, src_nodata), band_idx)

            logger.info(f"成功配准: {bip_path.name} -> {out_path.name}")
            return True
    except Exception as e:
        # Batch boundary: one bad file must not abort the run, but keep the traceback.
        logger.exception(f"处理失败 {bip_path.name}: {str(e)}")
        return False
# ---------- Main ----------
def main():
    """Register every ``*.bip`` file in BIP_DIR against REF_TIF and write results to OUT_DIR.

    Files are processed in sorted (reproducible) order, and the suffix is matched
    case-insensitively. Per-file failures are reported by ``process_bip_to_tif``'s
    False return and counted; they do not stop the batch. The matcher model is
    only loaded when there is at least one input file.
    """
    logger.info("开始批量配准处理...")
    # Validate the inputs before doing any expensive work.
    if not Path(REF_TIF).exists():
        logger.error(f"参考文件不存在: {REF_TIF}")
        return
    if not BIP_DIR.is_dir():
        logger.error(f"BIP文件夹不存在: {BIP_DIR}")
        return

    # Sorted for a deterministic order; regular files only (skip directories named *.bip).
    bip_files = sorted(
        p for p in BIP_DIR.iterdir() if p.is_file() and p.suffix.lower() == ".bip"
    )
    logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
    if not bip_files:
        # Nothing to register: skip loading the matching model.
        return

    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    success_count = 0
    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
        for bip_path in bip_files:
            if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
                success_count += 1
    logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")


if __name__ == "__main__":
    main()