Files
StripStitch/test V10.py
2026-04-22 09:26:39 +08:00

1336 lines
60 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
import csv
from datetime import datetime
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
from vismatch.viz import plot_matches, plot_keypoints
import logging
try:
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
try:
from matplotlib.path import Path as MplPath
from scipy.spatial import ConvexHull
MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
MATPLOTLIB_SCIPY_AVAILABLE = False
MplPath = None
logging.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")
try:
import SimpleITK as sitk
SITK_AVAILABLE = True
except ImportError:
SITK_AVAILABLE = False
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
try:
import pirt
PIRT_AVAILABLE = True
except ImportError:
PIRT_AVAILABLE = False
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
try:
from scipy.interpolate import Rbf
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
logging.warning("scipy 不可用,将跳过 TPS 变换")
# Logging setup. force=True (Python 3.8+) removes any handler already attached
# to the root logger: the module-level logging.warning(...) calls in the
# optional-import fallbacks above implicitly run a default basicConfig() when
# an optional dependency is missing, after which a plain basicConfig() here
# would be a silent no-op (root level stuck at WARNING, so every
# logger.info(...) message would be dropped).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)
# ---------- Configuration ----------
# Adjust these paths to your environment.
REF_TIF = r"E:\is2\yaopu\result.tif"  # reference .tif file
BIP_DIR = Path(r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\agnle")  # folder containing the .bip files
OUT_DIR = Path(r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\jiaozhen")  # output folder
# Feature matcher name (presumably passed to vismatch.get_matcher, imported above -- confirm at call site)
MATCHER_NAME = "matchanything-roma"  # e.g. xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue
DEVICE = "cuda"  # or "cpu"
# Candidate geometric models. process_bip_to_tif fits every entry and keeps the
# one with the lowest p95 reprojection error (ties keep the earlier entry).
TRANSFORM_METHODS = ["similarity", "affine", "homography"]
# Names are dispatched by estimate_transform(), which raises ValueError for a
# name it does not recognise. Names it handles include: "translation",
# "euclidean", "similarity", "affine", "homography", "piecewise_affine",
# "polynomial".
# Matching parameters
MATCH_MAX_SIDE = 1200  # longest image side (px) handed to the matcher
ROI_PAD_PX = 500  # padding of the coarse reference ROI window (reference-tif pixels)
MASK_PAD_PX = 100  # dilation of the matching mask in px, matching stage only (process_bip_to_tif caps the kernel at 99x99)
# Quality-control thresholds
MIN_INLIERS = 10
MIN_INLIER_RATIO = 0.01
# Mask-edge feathering and filtering
FEATHER_PX = 20  # feather width (px) of the soft mask, applied at full / ROI resolution
EDGE_BAND_PX = 30  # drop matches closer than this (px) to the mask border; rescaled to the matching resolution
# Texture filtering
MIN_GRAD_QUANTILE = 0.20  # gradient-magnitude quantile (0..1); points below it count as low-texture and are dropped
# Create the output directory
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Statistics output directory and CSV file
STATS_DIR = OUT_DIR / "stats"
STATS_DIR.mkdir(parents=True, exist_ok=True)
STATS_CSV = STATS_DIR / "registration_stats.csv"
# ---------- Utility functions ----------
def init_stats_csv(csv_path: Path):
    """Create the statistics CSV with its header row.

    If the file already exists it is left untouched; rows are added later by
    log_registration_stats, which opens the file in append mode.
    """
    if csv_path.exists():
        return
    header = [
        'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio',
        'selected_method', 'median_error', 'p95_error', 'success',
    ]
    with open(csv_path, 'w', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(header)
def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int,
                           inlier_ratio: float, selected_method: str, median_error: float,
                           p95_error: float, success: bool):
    """Append one registration result row to the statistics CSV.

    The row is stamped with the current local time ('%Y-%m-%d %H:%M:%S');
    the ratio and the two error values are written with 4 decimals.
    """
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    row = [
        stamp,
        filename,
        num_inliers,
        num_matches,
        f"{inlier_ratio:.4f}",
        selected_method,
        f"{median_error:.4f}",
        f"{p95_error:.4f}",
        success,
    ]
    with open(csv_path, 'a', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(row)
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow a rasterio Window by ``pad`` pixels on every side, clipped to the raster.

    Parameters
    ----------
    win : rasterio.windows.Window
        Window to expand; offsets and sizes may be fractional (e.g. the result
        of ``rasterio.windows.from_bounds``).
    pad : int
        Padding in pixels added on each side.
    max_w, max_h : int
        Width and height of the raster the window must stay inside.

    Returns
    -------
    rasterio.windows.Window
        Integer-offset window. Its width or height is <= 0 when ``win`` does
        not overlap the raster, so callers can detect "no overlap".
    """
    col_off = int(max(0, win.col_off - pad))
    row_off = int(max(0, win.row_off - pad))
    # Round the far edge up so a fractional window is fully covered.
    col_end = int(min(max_w, np.ceil(win.col_off + win.width + pad)))
    row_end = int(min(max_h, np.ceil(win.row_off + win.height + pad)))
    # Height comes from the row extent (it was previously taken from the column extent).
    return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)
def _pixel_size_xy(transform: Affine):
    """Return the ground size (x, y) of one pixel for an affine geotransform.

    x is the length of the column-step vector (a, d) and y the length of the
    row-step vector (b, e), so rotated grids are handled. If a length is
    non-finite or not positive, it falls back to |a| (for x) or |e| (for y),
    or to 1.0 when that coefficient is itself zero.
    """
    def _axis_length(p, q, fallback_coef):
        length = float(np.hypot(p, q))
        if np.isfinite(length) and length > 0:
            return length
        return float(abs(fallback_coef)) if fallback_coef != 0 else 1.0

    size_x = _axis_length(transform.a, transform.d, transform.a)
    size_y = _axis_length(transform.b, transform.e, transform.e)
    return size_x, size_y
def _grid_from_bounds(bounds, res_x: float, res_y: float):
    """Build a north-up output grid covering ``bounds`` at the given pixel size.

    ``bounds`` is (left, bottom, right, top). The signs of the resolutions are
    ignored. Width and height are rounded up so the grid covers the whole
    extent, and are at least 1 pixel.

    Returns (transform, width, height).
    """
    left, bottom, right, top = map(float, bounds)
    step_x = float(abs(res_x))
    step_y = float(abs(res_y))
    width = max(1, int(np.ceil((right - left) / max(1e-12, step_x))))
    height = max(1, int(np.ceil((top - bottom) / max(1e-12, step_y))))
    grid_transform = Affine(step_x, 0.0, left, 0.0, -step_y, top)
    return grid_transform, width, height
def _fit_rigid_2d(src, dst):
    """Least-squares 2-D rigid fit (rotation + translation, no scale or reflection) mapping src -> dst.

    src, dst: (N, 2) point arrays with N >= 2. Returns a 2x3 float64 matrix.
    """
    src = np.asarray(src, dtype=np.float64)
    dst = np.asarray(dst, dtype=np.float64)
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)
    cov = (src - src_mean).T @ (dst - dst_mean)
    u, _, vt = np.linalg.svd(cov)
    # Flip the last axis when needed so the result is a proper rotation (det = +1).
    sign = 1.0 if np.linalg.det(vt.T @ u.T) >= 0 else -1.0
    rot = vt.T @ np.diag([1.0, sign]) @ u.T
    shift = dst_mean - rot @ src_mean
    return np.hstack([rot, shift[:, None]])


def estimate_transform(method, k0, k1, ransac_thresh=3.0):
    """Estimate a geometric model mapping points k0 -> k1.

    Parameters
    ----------
    method : str
        "translation", "euclidean", "similarity", "affine", "homography",
        "piecewise_affine", "polynomial" (order 2) or "polynomial_order3".
    k0, k1 : array-like, shape (N, 2)
        Corresponding (x, y) points: source and destination.
    ransac_thresh : float, optional
        RANSAC reprojection threshold (px) for the OpenCV estimators
        (default 3.0, the value previously hard-coded).

    Returns
    -------
    (kind, model) : tuple
        kind is "A" (2x3 affine matrix), "H" (3x3 homography), "piecewise",
        "polynomial" or "polynomial_order3" (scikit-image transform objects).
        The model is None when OpenCV fails. (None, None) is returned when
        there are too few or mismatched points, scikit-image is unavailable,
        or the scikit-image fit fails.

    Raises
    ------
    ValueError
        If ``method`` is not recognised.
    """
    # Minimum number of correspondences each model needs.
    min_points = {
        "translation": 1,
        "euclidean": 2,
        "similarity": 2,
        "affine": 3,
        "homography": 4,
        "piecewise_affine": 3,
        "polynomial": 6,
        "polynomial_order3": 10,
    }
    if method not in min_points:
        raise ValueError(f"未知变换方法: {method}")
    k0 = np.asarray(k0, dtype=np.float64)
    k1 = np.asarray(k1, dtype=np.float64)
    n = len(k0)
    if n != len(k1) or n < min_points[method]:
        return None, None

    if method == "translation":
        # Pure shift: mean displacement of the correspondences.
        dx, dy = (k1 - k0).mean(axis=0)
        return "A", np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)

    # The OpenCV estimators are fed float32 point arrays.
    p0 = k0.astype(np.float32)
    p1 = k1.astype(np.float32)

    if method in ("euclidean", "similarity"):
        # estimateAffinePartial2D fits rotation + *uniform scale* + translation.
        A, inliers = cv2.estimateAffinePartial2D(
            p0, p1, method=cv2.RANSAC, ransacReprojThreshold=ransac_thresh)
        if method == "similarity" or A is None:
            return "A", A
        # Euclidean: keep the RANSAC inlier set, then refit without scale so
        # the scale really is fixed at 1.
        keep = inliers.ravel().astype(bool) if inliers is not None else np.ones(n, dtype=bool)
        if keep.sum() < 2:
            return "A", None
        return "A", _fit_rigid_2d(k0[keep], k1[keep])

    if method == "affine":
        # Full affine: rotation + anisotropic scale + shear + translation.
        A, _ = cv2.estimateAffine2D(p0, p1, method=cv2.RANSAC, ransacReprojThreshold=ransac_thresh)
        return "A", A

    if method == "homography":
        # 8-DOF projective transform.
        H, _ = cv2.findHomography(p0, p1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=ransac_thresh)
        return "H", H

    # Remaining methods use scikit-image: piecewise_affine, polynomial, polynomial_order3.
    if not SKIMAGE_AVAILABLE:
        return None, None
    try:
        if method == "piecewise_affine":
            kind, tform = "piecewise", PiecewiseAffineTransform()
            ok = tform.estimate(k0, k1)
        else:
            kind, tform = method, PolynomialTransform()
            ok = tform.estimate(k0, k1, order=2 if method == "polynomial" else 3)
    except Exception:
        return None, None
    # estimate() reports a failed fit through its boolean return value.
    if ok is not None and not ok:
        return None, None
    return kind, tform
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Reprojection error of a fitted model on its point correspondences.

    Parameters
    ----------
    transform_type : str
        "A" (2x3 affine matrix), "H" (3x3 homography), or one of "piecewise",
        "polynomial", "polynomial_order3" (a callable mapping an (N, 2) point
        array to (N, 2), e.g. a scikit-image transform).
    transform : object
        The model; None yields (inf, inf).
    k0, k1 : array-like, shape (N, 2)
        Source points and their expected destinations, as (x, y).

    Returns
    -------
    (median, p95) : tuple of float
        Median and 95th percentile of the Euclidean distance between the
        mapped k0 and k1. (inf, inf) for empty input, a None model or an
        unknown transform_type.
    """
    if transform is None or len(k0) == 0:
        return np.inf, np.inf
    src = np.asarray(k0, dtype=np.float64)
    dst = np.asarray(k1, dtype=np.float64)
    if transform_type == "A":
        M = np.asarray(transform, dtype=np.float64)
        pred = src @ M[:2, :2].T + M[:2, 2]
    elif transform_type == "H":
        M = np.asarray(transform, dtype=np.float64)
        proj = src @ M[:, :2].T + M[:, 2]  # homogeneous coordinates, (N, 3)
        w = proj[:, 2:3]
        # Guard only a (near-)zero w, i.e. points mapped to infinity. Adding a
        # constant to w instead would bias every projection and still blows up
        # when w is close to minus that constant.
        w = np.where(np.abs(w) < 1e-12, np.copysign(1e-12, w), w)
        pred = proj[:, :2] / w
    elif transform_type in ("piecewise", "polynomial", "polynomial_order3"):
        pred = np.asarray(transform(src), dtype=np.float64)
    else:
        return np.inf, np.inf
    err = np.linalg.norm(pred - dst, axis=1)
    return float(np.median(err)), float(np.percentile(err, 95))
def _norm01_hw(x: np.ndarray) -> np.ndarray:
"""对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性"""
x = x.astype(np.float32, copy=False)
p2 = float(np.percentile(x, 2))
p98 = float(np.percentile(x, 98))
y = (x - p2) / (p98 - p2 + 1e-6)
return np.clip(y, 0.0, 1.0)
def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)):
    """Wrap a (H,W) numpy array as a float32 SimpleITK image in pixel coordinates.

    Spacing is 1, direction is the 2x2 identity and the origin is
    (origin_xy[0], origin_xy[1]), so physical coordinates are pixel
    coordinates shifted by that origin.
    """
    unit_spacing = (1.0, 1.0)
    identity_direction = (1.0, 0.0, 0.0, 1.0)
    image = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False))
    origin = (float(origin_xy[0]), float(origin_xy[1]))
    image.SetSpacing(unit_spacing)
    image.SetOrigin(origin)
    image.SetDirection(identity_direction)
    return image
def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10):
"""用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内"""
min_x = int(np.floor(k1_global[:, 0].min())) - pad
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
min_y = int(np.floor(k1_global[:, 1].min())) - pad
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
min_x = max(0, min_x)
min_y = max(0, min_y)
max_x = min(ref_w, max_x)
max_y = min(ref_h, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
return min_x, min_y, bbox_w, bbox_h
def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Resample a (H,W) mask to (target_h, target_w) with nearest-neighbour sampling.

    Returns a bool array that is True where the resampled mask is non-zero.
    """
    mask_u8 = mask_hw.astype(np.uint8)
    dsize = (target_w, target_h)  # OpenCV expects (width, height)
    resized = cv2.resize(mask_u8, dsize, interpolation=cv2.INTER_NEAREST)
    return resized != 0
def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray:
"""
二值掩膜 -> 软掩膜 alpha∈[0,1],边缘处按距离线性上升,避免硬边缘。
mask_hw: bool/uint8 (H,W) True/1表示有效
"""
if mask_hw is None:
return None
m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255
# 距离变换仅对前景内部有效,计算到边界的距离
dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3)
if feather_px <= 0:
alpha = (dist > 0).astype(np.float32)
else:
alpha = np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32)
return alpha # (H,W) float32
def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray:
"""
生成"远离边界"的保留掩膜:仅保留距离边界>=min_dist_px的像素。
"""
if mask_hw is None:
return None
m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255
dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3)
keep = dist >= float(max(1, min_dist_px))
return keep
def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray:
    """Texture mask (H,W): True where the gradient magnitude reaches the given quantile.

    img_chw is the (C,H,W) image at matching resolution. It is averaged over
    channels to grey, and 3x3 Sobel gradients give the magnitude. Pixels whose
    magnitude is at least the ``quantile``-th quantile count as textured.
    """
    gray = img_chw.mean(axis=0).astype(np.float32)
    dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = np.sqrt(dx * dx + dy * dy)
    cutoff = 0.0
    if magnitude.size > 0:
        cutoff = float(np.quantile(magnitude, quantile))
    return magnitude >= cutoff
def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict:
"""将匹配与内点严格限制在掩膜内"""
if src_mask_small is None or ref_mask_small is None:
return result
def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray:
if kpts is None or len(kpts) == 0:
return np.zeros((0,), dtype=bool)
kpts = np.asarray(kpts)
xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
return mask_hw[ys, xs]
# 过滤 matched_kpts
if "matched_kpts0" in result and "matched_kpts1" in result:
mk0 = np.asarray(result["matched_kpts0"])
mk1 = np.asarray(result["matched_kpts1"])
if len(mk0) == len(mk1) and len(mk0) > 0:
keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small)
result["matched_kpts0"] = mk0[keep_m]
result["matched_kpts1"] = mk1[keep_m]
# 过滤 inlier_kpts
if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None:
ik0 = np.asarray(result["inlier_kpts0"])
ik1 = np.asarray(result["inlier_kpts1"])
if len(ik0) == len(ik1) and len(ik0) > 0:
keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small)
result["inlier_kpts0"] = ik0[keep_i]
result["inlier_kpts1"] = ik1[keep_i]
result["num_inliers"] = int(len(result["inlier_kpts0"]))
return result
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path):
"""处理单个 .bip 文件到参考 .tif 的配准"""
try:
with rasterio.open(bip_path) as src:
logger.info(f"处理文件: {bip_path.name}")
# 初始化统计变量
num_inliers = 0
num_matches = 0
inlier_ratio = 0.0
selected_method = "none"
median_error = float('inf')
p95_error = float('inf')
success = False
# 检查CRS
if src.crs is None:
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
src_crs = ref_dataset.crs
else:
src_crs = src.crs
ref_crs = ref_dataset.crs
if ref_crs is None:
raise RuntimeError(f"参考文件缺少CRS信息")
# 1) 用"源图有效掩膜"的包围盒推参考ROI比整图bounds更贴近有效重叠
try:
src_mask = (src.read_masks(1) > 0) # True=有效
rows_any = np.any(src_mask, axis=1)
cols_any = np.any(src_mask, axis=0)
if rows_any.any() and cols_any.any():
rmin = int(rows_any.argmax())
rmax = int(src.height - 1 - rows_any[::-1].argmax())
cmin = int(cols_any.argmax())
cmax = int(src.width - 1 - cols_any[::-1].argmax())
valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1)
valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform)
b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21)
else:
# 掩膜无效时回退到整图bounds
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
except Exception:
src_mask = None # 后续可选源图掩膜时用到
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
win0 = from_bounds(*b, transform=ref_dataset.transform)
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
if win.width <= 0 or win.height <= 0:
logger.warning(f"无重叠区域: {bip_path.name}")
return False
# 2) 读取数据
# 读取所有波段,如果是多波段的话
src_arr = src.read() # (bands, H, W)
if src_arr.ndim == 2: # 单波段
src_arr = src_arr[None, ...] # 增加波段维度
# 读取参考文件的ROI
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
if ref_arr.ndim == 2: # 单波段
ref_arr = ref_arr[None, ...] # 增加波段维度
# 将源图有效掩膜重投影到参考ROI并适度膨胀后作为匹配掩膜
try:
if src_mask is None:
src_mask = (src.read_masks(1) > 0)
ref_roi_transform = ref_dataset.window_transform(win)
roi_h, roi_w = int(win.height), int(win.width)
dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8)
reproject(
source=src_mask.astype(np.uint8),
destination=dst_mask,
src_transform=src.transform,
src_crs=src_crs,
dst_transform=ref_roi_transform,
dst_crs=ref_crs,
resampling=Resampling.nearest
)
if MASK_PAD_PX > 0:
k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size
k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除
kernel = np.ones((k, k), np.uint8)
dst_mask = cv2.dilate(dst_mask, kernel, iterations=1)
except Exception:
# 掩膜获取/重投影失败则不使用掩膜
dst_mask = None
# 转换为匹配所需的格式
src_img = _to_3ch_float01(src_arr)
ref_img = _to_3ch_float01(ref_arr)
# 软掩膜:避免在边界产生硬高对比边
try:
alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None
except Exception:
alpha_src = None
try:
alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None
except Exception:
alpha_ref = None
if alpha_src is not None:
alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype)
src_img = src_img * alpha_src3
if alpha_ref is not None:
alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype)
ref_img = ref_img * alpha_ref3
# 3) 匹配用降采样版本,提速 + 增稳
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
# 4) 精配准img0=src, img1=ref_roi
result = matcher(src_small, ref_small)
# 与小图同尺寸的掩膜
src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None
ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None
# 剔除掩膜边缘带(小图尺度的最小距离)
def _scale_px(px_full: int, full_wh, small_wh) -> int:
# 用平均缩放也可以分别对H/W计算后取最小
sy = small_wh[0] / max(1, full_wh[0])
sx = small_wh[1] / max(1, full_wh[1])
s = 0.5 * (sx + sy)
return max(1, int(round(px_full * s)))
edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2]))
edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2]))
keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None
keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None
# 纹理掩膜
keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE)
keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE)
# 组合最终保留掩膜(边缘+纹理),二者都要满足
def _combine_keep(m_edge, m_tex):
if m_edge is None:
return m_tex
return (m_edge & m_tex)
keep_src_final = _combine_keep(keep_src_edge, keep_src_tex)
keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex)
# 将匹配与内点严格限制在最终掩膜内
def _filter_by_bool_masks(res, m_src, m_ref):
if m_src is None or m_ref is None:
return res
def keep_in(mask_hw, pts):
if pts is None or len(pts) == 0:
return np.zeros((0,), dtype=bool)
xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
return mask_hw[ys, xs]
# matched
if "matched_kpts0" in res and "matched_kpts1" in res:
mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"])
if len(mk0) == len(mk1) and len(mk0) > 0:
keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1)
res["matched_kpts0"] = mk0[keep_m]
res["matched_kpts1"] = mk1[keep_m]
# inliers
if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None:
ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"])
if len(ik0) == len(ik1) and len(ik0) > 0:
keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1)
res["inlier_kpts0"] = ik0[keep_i]
res["inlier_kpts1"] = ik1[keep_i]
res["num_inliers"] = int(len(res["inlier_kpts0"]))
return res
result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final)
# 统计(以过滤后的结果为准)
num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", []))))
num_m = len(result.get("matched_kpts0", []))
ratio = (num_inl / num_m) if num_m else 0.0
# 更新统计变量
num_inliers = num_inl
num_matches = num_m
inlier_ratio = ratio
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
# 保存匹配可视化图像使用与匹配同尺寸的图像保持CHW格式
viz_dir = out_dir / "visualizations"
viz_dir.mkdir(exist_ok=True)
matches_path = viz_dir / f"{bip_path.stem}_matches.png"
plot_matches(src_small, ref_small, result, save_path=str(matches_path))
logger.info(f"匹配可视化已保存: {matches_path}")
# 关键点可视化(源图像)
kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png"
plot_keypoints(
src_small,
{"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]},
save_path=str(kpts_src_path)
)
logger.info(f"源图像关键点可视化已保存: {kpts_src_path}")
# 关键点可视化(参考图像)
kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png"
plot_keypoints(
ref_small,
{"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]},
save_path=str(kpts_ref_path)
)
logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}")
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
logger.warning(f"匹配质量不足: {bip_path.name}")
# 记录失败的统计信息
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "failed_quality_check", median_error, p95_error, False)
return False
# 5) 用内点估计多种变换并自动选择最优
# 先计算全分辨率坐标
k0_small = result["inlier_kpts0"].astype(np.float32)
k1_small = result["inlier_kpts1"].astype(np.float32)
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src)
S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI)
ones = np.ones((k0_small.shape[0], 1), dtype=np.float32)
k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素
k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素
k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素
# 用全分辨率坐标进行所有模型的估计和评估
best_transform = None
best_transform_type = None
best_error = np.inf
best_median_error = np.inf
best_method = None
for method in TRANSFORM_METHODS:
transform_type, transform = estimate_transform(method, k0_full, k1_global)
if transform is None:
continue
med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global)
# 选择重投影误差最小的变换
if p95_err < best_error:
best_transform = transform
best_transform_type = transform_type
best_error = p95_err
best_median_error = med_err
best_method = method
logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
if best_transform is None:
logger.warning(f"所有变换方法都失败: {bip_path.name}")
# 记录失败的统计信息
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "failed_transform", median_error, p95_error, False)
return False
# 更新统计变量
selected_method = best_method
median_error = best_median_error
p95_error = best_error
logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}")
# 6) 根据变换类型进行相应的配准处理
if best_transform_type == "A":
# 仿射变换A 已是 src_full_pixel -> ref_full_pixel直接构造像素->地图仿射
A = best_transform # 2x3, src_full_pixel -> ref_full_pixel
A3 = np.eye(3, dtype=np.float64)
A3[:2, :] = A
# src_pixel -> map
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
M_map = Rt @ A3
corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2],
M_map[1,0], M_map[1,1], M_map[1,2])
# 用 M_map 求最小外接矩形(先到 map再到 ref 像素)
Rt_inv = np.linalg.inv(Rt)
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corn_h = np.hstack([corners, np.ones((4,1))]).T
map_corners = (M_map @ corn_h).T[:, :2]
pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2]
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil (pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil (pix_corners[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=out_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.nearest,
)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
# ---- 非仿射变换处理 ----
elif best_transform_type == "H":
# 单应变换H 已是 src_full_pixel -> ref_full_pixel
H_full = best_transform # 3x3
try:
# 用 H_full 映射源四角 -> 参考像素,求最小外接矩形
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32)
corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T
dst_h = (H_full @ corn_h)
dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T
min_x = int(np.floor(dst[:,0].min())) - 10
max_x = int(np.ceil (dst[:,0].max())) + 10
min_y = int(np.floor(dst[:,1].min())) - 10
max_y = int(np.ceil (dst[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
ref_transform = ref_dataset.transform
Rt = np.array(
[[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0.0, 0.0, 1.0]],
dtype=np.float64,
)
Out = np.array(
[[out_transform.a, out_transform.b, out_transform.c],
[out_transform.d, out_transform.e, out_transform.f],
[0.0, 0.0, 1.0]],
dtype=np.float64,
)
M = np.linalg.inv(Out) @ Rt @ H_full.astype(np.float64)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = cv2.warpPerspective(
src_band,
M,
(out_w, out_h),
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=float(dst_nodata)
).astype(np.float32)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
except Exception as e:
logger.warning(f"单应变换异常: {e}")
# 继续到仿射回退
elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]:
# 分片仿射或多项式变换:使用 scikit-image
transform = best_transform # 已用 k0_full/k1_global 估计
try:
# 用目标侧匹配点k1_global决定外接矩形更稳
pad = 10
min_x = int(np.floor(k1_global[:, 0].min())) - pad
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
min_y = int(np.floor(k1_global[:, 1].min())) - pad
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
min_x = max(0, min_x)
min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}")
return False
# 创建输出窗口
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 定义带偏移的逆映射函数
off_x, off_y = min_x, min_y
if best_transform_type in ["polynomial", "polynomial_order3"]:
# 对于多项式,估计逆变换
order = 2 if best_transform_type == "polynomial" else 3
t_inv = PolynomialTransform()
t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源
# 目标侧点集的内点判定(用于限制外推)
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y],
[min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
def point_inside(xy):
return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) &
(xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h))
def inv_map_rc(coords):
# coords: (N,2) in (row, col)
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref
inside = point_inside(xy)
xy_src = np.full_like(xy, fill_value=-1.0)
if np.any(inside):
xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src
# 确保坐标在源图像范围内
xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1)
xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1)
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
elif best_transform_type == "piecewise": # piecewise_affine
# 目标侧点集的内点判定
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
# 使用当前裁剪窗口的边界创建矩形
rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
# 退化为矩形内判断
def point_inside(xy):
return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \
(xy[:,1] >= min_y) & (xy[:,1] <= max_y)
def inv_map_rc(coords):
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref
inside = point_inside(xy)
xy_src = np.full_like(xy, fill_value=-1.0)
if np.any(inside):
xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src)
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
# 使用 scikit-image 进行变换重采样
from skimage.transform import warp
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = warp(
src_band,
inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射
output_shape=(bbox_h, bbox_w),
mode='constant',
cval=dst_nodata,
preserve_range=True,
order=0
).astype(np.float32)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
except Exception as e:
logger.warning(f"{best_transform_type}变换异常: {e}")
# 继续到仿射回退
# ---- 回退:使用仿射变换,保证最小可用结果 ----
transform = best_transform
try:
min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1(
k1_global, ref_dataset.width, ref_dataset.height, pad=10
)
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
rect = np.array(
[[min_x, min_y], [min_x + bbox_w, min_y],
[min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]],
dtype=float
)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
def point_inside(xy):
return (
(xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) &
(xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h)
)
off_x, off_y = min_x, min_y
tps_inv = transform["inv"] # ref -> src
def inv_map_rc(coords):
rc = np.asarray(coords, dtype=np.float64)
xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y)
inside = point_inside(xy_ref)
xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64)
if np.any(inside):
# 使用RBF插值计算逆映射
xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1])
xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1])
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 优先用 skimage.warp缺失时用 SimpleITK Resample 兜底
if SKIMAGE_AVAILABLE:
from skimage.transform import warp
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = warp(
src_band,
inverse_map=inv_map_rc,
output_shape=(bbox_h, bbox_w),
mode='constant',
cval=dst_nodata,
preserve_range=True,
order=0
).astype(np.float32)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
else:
# OpenCV remap 版本(无需 skimage/SimpleITK
with rasterio.open(out_path, "w", **out_profile) as out_ds:
# 创建映射网格
y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w]
coords = np.column_stack([y_coords.ravel(), x_coords.ravel()])
# 计算逆映射
mapped_coords = inv_map_rc(coords)
map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32)
map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32)
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
# 使用OpenCV的remap进行重采样
dst_band = cv2.remap(
src_band, map_x, map_y,
interpolation=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=dst_nodata
)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}")
return True
except Exception as e:
logger.warning(f"tps变换异常: {e}")
# 继续到仿射回退
# ---- 回退:使用仿射变换,保证最小可用结果 ----
# 重新估计仿射变换作为fallback
A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
if A_fallback is None:
logger.warning(f"仿射回退也失败: {bip_path.name}")
return False
# 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64)
S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64)
A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback
M_full = S1_inv @ A3 @ S0
T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64)
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
src_pixel_to_map_corrected = Rt @ T_off @ M_full
corrected_affine = Affine(
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
)
# 计算源 BIP 四角经过仿射变换后的最小外接矩形
# 将 rasterio.Affine 转为 3x3 像素->地图矩阵
M_map = np.array([
[corrected_affine.a, corrected_affine.b, corrected_affine.c],
[corrected_affine.d, corrected_affine.e, corrected_affine.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
# 参考底图的 像素->地图 矩阵及其逆
ref_transform = ref_dataset.transform
Rt = np.array([
[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
Rt_inv = np.linalg.inv(Rt)
# 源影像四角(源像素坐标)
src_h, src_w = src.height, src.width
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4)
# 源像素 -> 地图坐标
map_corners = (M_map @ corners_h).T[:, :2]
# 地图坐标 -> 参考像素坐标
pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3)
pix_corners = pix_corners_h[:, :2]
# 最小外接矩形(像素)
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil( pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil( pix_corners[:,1].max())) + 10
# 边界裁剪
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
# 如果外接矩形太小,跳过
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=out_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.nearest,
)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "affine_fallback", median_error, p95_error, success)
return True
except Exception as e:
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
# 记录失败的统计信息
try:
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "exception", median_error, p95_error, False)
except:
pass # 避免统计记录失败影响主要错误处理
return False
# ---------- 主逻辑 ----------
def main():
    """Batch-register every ``*.bip`` file in ``BIP_DIR`` against ``REF_TIF``.

    Steps:
      1. Validate that the reference raster and the BIP input folder exist.
      2. Ensure the output folder exists and initialise the statistics CSV.
      3. Build the feature matcher once (``MATCHER_NAME`` on ``DEVICE``).
      4. Open the reference raster and call ``process_bip_to_tif`` for each
         BIP file, in sorted order, counting the ones that report success.

    Returns None; progress and results are reported through ``logger``.
    """
    logger.info("开始批量配准处理...")
    # Validate inputs up front so we fail fast, before the (potentially
    # expensive) matcher is loaded.
    if not Path(REF_TIF).exists():
        logger.error(f"参考文件不存在: {REF_TIF}")
        return
    # is_dir(): a regular file at BIP_DIR would pass exists() and then
    # silently glob zero inputs.
    if not BIP_DIR.is_dir():
        logger.error(f"BIP文件夹不存在: {BIP_DIR}")
        return
    # Registered outputs are written into OUT_DIR by process_bip_to_tif; make
    # sure it exists so the first write does not fail on a missing directory.
    # Done before init_stats_csv in case STATS_CSV lives inside OUT_DIR.
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    # Initialise the statistics CSV file.
    init_stats_csv(STATS_CSV)
    logger.info(f"统计信息将保存到: {STATS_CSV}")
    # Initialise the matcher once and reuse it for every file.
    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)
    # Keep the reference raster open for the whole batch.
    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
        # Sorted so processing (and log/CSV row) order is reproducible;
        # Path.glob yields entries in arbitrary filesystem order.
        bip_files = sorted(BIP_DIR.glob("*.bip"))
        logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
        success_count = 0
        for bip_path in bip_files:
            try:
                ok = process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV)
            except Exception:
                # Batch boundary: one unexpected failure must not abort the
                # remaining files. Log it with a traceback and move on.
                logger.exception(f"处理异常 {bip_path.name}")
                continue
            if ok:
                success_count += 1
        logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")
if __name__ == '__main__':
    # Script entry point: run the batch registration only when this file is
    # executed directly, never when it is imported as a module.
    main()