Files
StripStitch/test V2.py
2026-03-06 17:24:55 +08:00

360 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
仿射变换修改
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
from osgeo import gdal
# ---------- Logging ----------
# Module logger plus a root handler that prints timestamped INFO-level messages.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ---------- Configuration ----------
# Input/output locations; adjust these for the machine the script runs on.
REF_TIF = r"E:\is2\jiashixian\result.tif"  # reference GeoTIFF
BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1")  # folder containing the .bip inputs
OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma")  # folder receiving the registered outputs

# Matcher passed to vismatch.get_matcher, and the device it runs on.
MATCHER_NAME = "matchanything-roma"  # alternatives: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, ...
DEVICE = "cuda"  # or "cpu"

# Matching parameters.
MATCH_MAX_SIDE = 1200  # longest image side (px) used during matching
ROI_PAD_PX = 500  # padding, in reference-image pixels, around the coarse ROI

# Quality-control thresholds a match must meet to be accepted.
MIN_INLIERS = 30  # minimum number of inliers
MIN_INLIER_RATIO = 0.15  # minimum inliers / matches ratio

# Create the output folder up front (runs when the module is imported).
OUT_DIR.mkdir(exist_ok=True, parents=True)
# ---------- 工具函数 ----------
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow a (possibly fractional) window by ``pad`` pixels per side and clip it to the raster.

    Args:
        win: rasterio Window, e.g. from ``from_bounds``; offsets and sizes may be fractional.
        pad: padding in pixels added on every side.
        max_w: raster width used to clip the right edge.
        max_h: raster height used to clip the bottom edge.

    Returns:
        An integer-aligned rasterio Window. Its width/height can be <= 0 when the
        padded window does not intersect the raster; callers must check this.
    """
    # Floor the start and ceil the end so the fractional window is fully covered.
    # Plain int() truncation would drop the last partially covered pixel.
    col_off = max(0, int(np.floor(win.col_off - pad)))
    row_off = max(0, int(np.floor(win.row_off - pad)))
    col_end = min(int(max_w), int(np.ceil(win.col_off + win.width + pad)))
    row_end = min(int(max_h), int(np.ceil(win.row_off + win.height + pad)))
    return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)
def _estimate_affine_model(k0: np.ndarray, k1: np.ndarray):
    """Fit similarity and full-affine models mapping k0 -> k1 and pick the better one.

    Both fits use RANSAC with a 3 px reprojection threshold (in the coordinates
    of k0/k1). The full affine is chosen when its 95th-percentile residual is
    lower, or when the 95th percentiles are within 0.5 px and its median
    residual is lower. Otherwise the similarity model is chosen.

    Returns:
        ``(A, model_type, (med_partial, p95_partial), (med_full, p95_full))``,
        where ``A`` is a 2x3 matrix, or None if the chosen estimation failed.
    """
    # Similarity: rotation + uniform scale + translation (4 DOF).
    A_partial, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
    # Full affine: rotation + anisotropic scale + shear + translation (6 DOF).
    A_full, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)

    def _reproj_err(A2x3, pts0, pts1):
        # Median and 95th-percentile Euclidean residuals; inf when the fit failed.
        if A2x3 is None:
            return np.inf, np.inf
        src_h = np.hstack([pts0, np.ones((pts0.shape[0], 1), dtype=np.float32)])  # (N, 3)
        pred = (A2x3 @ src_h.T).T  # (N, 2)
        e = np.sqrt(((pred - pts1) ** 2).sum(axis=1))
        return float(np.median(e)), float(np.percentile(e, 95))

    stats_p = _reproj_err(A_partial, k0, k1)
    stats_f = _reproj_err(A_full, k0, k1)
    (med_p, p95_p), (med_f, p95_f) = stats_p, stats_f
    if (p95_f < p95_p) or (abs(p95_f - p95_p) < 0.5 and med_f < med_p):
        return A_full, "affine_full", stats_p, stats_f
    return A_partial, "affine_partial", stats_p, stats_f


def _warp_with_tps(bip_path, ref_dataset, ref_crs, src_pts, ref_pts,
                   src_nodata, dst_nodata, out_path, max_gcps):
    """Warp ``bip_path`` onto the full reference grid with GDAL's thin-plate-spline transformer.

    ``src_pts`` are full-resolution source pixel coordinates and ``ref_pts`` the
    matching global reference pixel coordinates, both (N, 2) arrays of (x, y).

    ``gdal.WarpOptions`` has no GCP argument. The GCPs (in reference map
    coordinates) are therefore first attached to an in-memory VRT of the source
    with ``gdal.Translate``, and that VRT is then warped with ``tps=True``.

    Returns:
        True if the output was written. False if the caller should fall back to
        the affine model (too few control points, or GDAL failed).
    """
    # Identical source positions produce duplicate rows in the TPS system, which makes it singular.
    _, keep = np.unique(np.round(src_pts, 2), axis=0, return_index=True)
    keep = np.sort(keep)
    src_pts, ref_pts = src_pts[keep], ref_pts[keep]
    if max_gcps and len(src_pts) > max_gcps:
        # Dense matchers can return thousands of inliers, and TPS fitting cost grows
        # steeply with the GCP count. Subsample evenly (deterministically) to bound it.
        idx = np.unique(np.linspace(0, len(src_pts) - 1, int(max_gcps)).round().astype(int))
        src_pts, ref_pts = src_pts[idx], ref_pts[idx]
    if src_pts.shape[0] < 8:
        logger.warning(f"TPS 需要更多控制点(>=8当前 {src_pts.shape[0]},回退到仿射输出。")
        return False

    # GCP target = reference map coordinates (global ref pixel -> map via the ref geotransform).
    ref_transform = ref_dataset.transform
    gcps = []
    for (sx, sy), (rx, ry) in zip(src_pts, ref_pts):
        mx, my = ref_transform * (float(rx), float(ry))
        gcps.append(gdal.GCP(float(mx), float(my), 0.0, float(sx), float(sy)))

    ref_wkt = ref_crs.to_wkt() if hasattr(ref_crs, "to_wkt") else str(ref_crs)
    left, bottom, right, top = ref_dataset.bounds
    try:
        # In-memory VRT ("" path) carrying the GCPs, with the reference SRS assigned to them.
        vrt_ds = gdal.Translate(
            "",
            str(bip_path),
            options=gdal.TranslateOptions(format="VRT", GCPs=gcps, outputSRS=ref_wkt),
        )
        if vrt_ds is None:
            logger.warning("GDAL Warp(TPS) 失败,回退到仿射输出。")
            return False
        warp_opts = gdal.WarpOptions(
            format="ENVI",
            outputBounds=(left, bottom, right, top),
            # Exact reference pixel grid, same as the affine fallback output.
            width=ref_dataset.width,
            height=ref_dataset.height,
            dstSRS=ref_wkt,
            tps=True,  # thin-plate spline (non-rigid) transformer driven by the GCPs
            resampleAlg="bilinear",
            srcNodata=src_nodata,
            dstNodata=dst_nodata,
            multithread=True,
            creationOptions=["INTERLEAVE=BIP"],
        )
        warp_ds = gdal.Warp(destNameOrDestDS=str(out_path), srcDSOrSrcDSTab=vrt_ds, options=warp_opts)
        vrt_ds = None  # release the VRT
        if warp_ds is None:
            logger.warning("GDAL Warp(TPS) 失败,回退到仿射输出。")
            return False
        warp_ds = None  # close, flushing the output file
        return True
    except Exception as e:
        logger.warning(f"GDAL Warp(TPS) 异常,回退到仿射输出: {e}")
        return False


def _warp_with_affine(src, ref_dataset, ref_crs, src_px_to_ref_px, src_nodata, dst_nodata, out_path):
    """Resample every band of ``src`` onto the reference grid through a single affine model.

    ``src_px_to_ref_px`` is a 3x3 homogeneous matrix mapping full-resolution
    source pixel coordinates to global reference pixel coordinates. Composing it
    with the reference geotransform gives a corrected source geotransform. Each
    band is then reprojected (bilinear) into an ENVI/BIP file on the reference grid.
    """
    m = src_px_to_ref_px
    ref_transform = ref_dataset.transform
    # Affine composition: ref_transform * M == Rt @ M (source pixel -> reference map).
    corrected_affine = ref_transform * Affine(m[0, 0], m[0, 1], m[0, 2], m[1, 0], m[1, 1], m[1, 2])

    out_dtype = np.dtype(src.dtypes[0])
    # Build a clean ENVI profile instead of copying the reference profile, whose
    # GeoTIFF-specific creation options (tiling, compression, ...) do not apply to ENVI.
    out_profile = {
        "driver": "ENVI",
        "dtype": src.dtypes[0],
        "height": ref_dataset.height,
        "width": ref_dataset.width,
        "count": src.count,
        "transform": ref_transform,
        "crs": ref_crs,
        "interleave": "bip",
        "nodata": dst_nodata,
    }
    is_int = np.issubdtype(out_dtype, np.integer)
    info = np.iinfo(out_dtype) if is_int else None

    with rasterio.open(out_path, "w", **out_profile) as out_ds:
        for b in range(1, src.count + 1):
            src_band = src.read(b).astype(np.float32)
            dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32)
            reproject(
                source=src_band,
                destination=dst_band,
                src_transform=corrected_affine,
                src_crs=ref_crs,
                dst_transform=ref_transform,
                dst_crs=ref_crs,
                src_nodata=src_nodata,
                dst_nodata=dst_nodata,
                resampling=Resampling.bilinear,
            )
            if is_int:
                mask = (dst_band == dst_nodata) if src_nodata is not None else None
                # Round bilinear results (plain casting would truncate toward zero), then clip to range.
                dst_band = np.clip(np.rint(dst_band), info.min, info.max)
                if mask is not None:
                    dst_band[mask] = dst_nodata
            out_ds.write(dst_band.astype(out_dtype), b)


def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, max_tps_gcps: int = 1000):
    """Register one .bip raster onto the reference raster and write ``<stem>_registered.bip``.

    Pipeline:
      1. Project the source bounds into the reference CRS and read a padded reference ROI.
      2. Stretch both images to 3-channel float [0, 1], downscale them, and run ``matcher``.
      3. Reject weak matches (MIN_INLIERS / MIN_INLIER_RATIO).
      4. Fit similarity and full-affine models and keep the better one.
      5. Try a GDAL thin-plate-spline warp using the inliers as GCPs. If that is not
         possible or fails, resample with the affine model instead.

    Args:
        bip_path: source raster path (read with rasterio and, for TPS, with GDAL).
        ref_dataset: open rasterio dataset of the reference image; must have a CRS.
        matcher: callable ``matcher(img0, img1)`` returning a dict with
            ``num_inliers``, ``matched_kpts0``, ``inlier_kpts0`` and ``inlier_kpts1``.
            The keypoints are (x, y) pixel positions in the images passed in.
        out_dir: folder for the output file.
        max_tps_gcps: maximum number of GCPs handed to the TPS transformer.
            Larger inlier sets are evenly subsampled. 0/None disables the cap.

    Returns:
        True if an output file was written, False otherwise. Errors are logged,
        never raised.
    """
    try:
        with rasterio.open(bip_path) as src:
            logger.info(f"处理文件: {bip_path.name}")
            if src.crs is None:
                logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
                src_crs = ref_dataset.crs
            else:
                src_crs = src.crs
            ref_crs = ref_dataset.crs
            if ref_crs is None:
                raise RuntimeError("参考文件缺少CRS信息")

            # 1) Coarse localisation: source bounds -> reference CRS -> padded reference ROI.
            b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
            win0 = from_bounds(*b, transform=ref_dataset.transform)
            win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
            if win.width <= 0 or win.height <= 0:
                logger.warning(f"无重叠区域: {bip_path.name}")
                return False

            # 2) Read all source bands and the reference ROI as (bands, H, W).
            src_arr = src.read()
            if src_arr.ndim == 2:
                src_arr = src_arr[None, ...]
            ref_arr = ref_dataset.read(window=win)
            if ref_arr.ndim == 2:
                ref_arr = ref_arr[None, ...]
            src_img = _to_3ch_float01(src_arr)
            ref_img = _to_3ch_float01(ref_arr)

            # 3) Match on downscaled copies (faster and more stable); img0 = source, img1 = ref ROI.
            src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
            ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
            logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
            result = matcher(src_small, ref_small)
            num_inl = int(result["num_inliers"])
            num_m = len(result["matched_kpts0"])
            ratio = (num_inl / num_m) if num_m else 0.0
            logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
            if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
                logger.warning(f"匹配质量不足: {bip_path.name}")
                return False

            # 4) Global affine model from the inliers (used for the fallback warp).
            k0 = np.asarray(result["inlier_kpts0"], dtype=np.float32)
            k1 = np.asarray(result["inlier_kpts1"], dtype=np.float32)
            A, model_type, (med_p, p95_p), (med_f, p95_f) = _estimate_affine_model(k0, k1)
            if A is None:
                logger.warning(f"仿射估计失败: {bip_path.name}")
                return False
            logger.info(f"选用模型: {model_type}, partial(p50={med_p:.2f},p95={p95_p:.2f}), full(p50={med_f:.2f},p95={p95_f:.2f})")

            # Per-axis (x, y) scale factors from the matching resolution back to full resolution.
            s0 = np.array([src_img.shape[2] / src_small.shape[2], src_img.shape[1] / src_small.shape[1]])
            s1 = np.array([ref_img.shape[2] / ref_small.shape[2], ref_img.shape[1] / ref_small.shape[1]])
            roi_offset = np.array([win.col_off, win.row_off], dtype=np.float64)
            k0_full = k0 * s0  # full-resolution source pixels
            k1_global = k1 * s1 + roi_offset  # global reference pixels (ROI offset added)

            out_path = out_dir / f"{bip_path.stem}_registered.bip"
            src_nodata = src.nodata
            dst_nodata = src_nodata if src_nodata is not None else 0

            # 5) Non-rigid registration: GDAL TPS driven by the inlier correspondences.
            if _warp_with_tps(bip_path, ref_dataset, ref_crs, k0_full, k1_global,
                              src_nodata, dst_nodata, out_path, max_tps_gcps):
                logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}")
                return True

            # 6) Fallback: full-res source pixel -> small src -> small ref ROI -> full ref ROI -> global ref pixel.
            S0 = np.diag([1.0 / s0[0], 1.0 / s0[1], 1.0])
            S1_inv = np.diag([s1[0], s1[1], 1.0])
            A3 = np.eye(3, dtype=np.float64)
            A3[:2, :] = A
            T_off = np.array([[1, 0, roi_offset[0]], [0, 1, roi_offset[1]], [0, 0, 1]], dtype=np.float64)
            src_px_to_ref_px = T_off @ S1_inv @ A3 @ S0
            _warp_with_affine(src, ref_dataset, ref_crs, src_px_to_ref_px, src_nodata, dst_nodata, out_path)
            logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
            return True
    except Exception as e:
        # Log with traceback so failures inside rasterio/GDAL/the matcher are diagnosable.
        logger.error(f"处理失败 {bip_path.name}: {str(e)}", exc_info=True)
        return False
# ---------- 主逻辑 ----------
def main():
    """Batch entry point: register every ``*.bip`` file in BIP_DIR against REF_TIF.

    Validates the input paths, builds the matcher once, opens the reference
    raster once, calls ``process_bip_to_tif`` for each file (writing into
    OUT_DIR), and logs how many files succeeded.
    """
    logger.info("开始批量配准处理...")
    # Validate inputs before building the matcher.
    if not Path(REF_TIF).exists():
        logger.error(f"参考文件不存在: {REF_TIF}")
        return
    if not BIP_DIR.is_dir():
        logger.error(f"BIP文件夹不存在: {BIP_DIR}")
        return
    # Idempotent: guarantees the output folder exists even if it was removed after import.
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
        # glob order is filesystem-dependent; sort for a reproducible processing order and log.
        bip_files = sorted(BIP_DIR.glob("*.bip"))
        logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
        success_count = 0
        for bip_path in bip_files:
            if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
                success_count += 1
        logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")


if __name__ == "__main__":
    main()