Files
StripStitch/test V3.py
2026-03-06 17:24:55 +08:00

971 lines
43 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
使用 SimpleITK 实现 B 样条变换
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
# Configure logging BEFORE the optional-import guards below.  The module-level
# ``logging.warning`` helper installs a default root handler when none exists,
# after which a later ``logging.basicConfig(...)`` silently does nothing; doing
# the configuration first keeps the INFO level and the message format in
# effect even when an optional dependency is missing.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# scikit-image: required for the piecewise-affine and polynomial transforms.
try:
    from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    PiecewiseAffineTransform = None
    PolynomialTransform = None
    logger.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")

# matplotlib + scipy: convex-hull membership test used by the piecewise-affine
# warp; without them a plain rectangle test is used instead.
try:
    from matplotlib.path import Path as MplPath
    from scipy.spatial import ConvexHull
    MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
    MATPLOTLIB_SCIPY_AVAILABLE = False
    MplPath = None
    ConvexHull = None
    logger.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")

# SimpleITK: optional, used by the thin-plate-spline path.
try:
    import SimpleITK as sitk
    SITK_AVAILABLE = True
except ImportError:
    SITK_AVAILABLE = False
    sitk = None
    logger.warning("SimpleITK 不可用,将使用仿射变换作为替代")

# PIRT: optional, used by the B-spline path.
try:
    import pirt
    PIRT_AVAILABLE = True
except ImportError:
    PIRT_AVAILABLE = False
    pirt = None
    logger.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
# ---------- Configuration ----------
# Adjust these paths to match your environment
REF_TIF = r"E:\is2\guidingsahn\result.tif" # path of the reference .tif file
BIP_DIR = Path(r"E:\is2\guidingsahn") # folder containing the .bip files
OUT_DIR = Path(r"E:\is2\guidingsahn\output") # output folder
# Matching algorithm
MATCHER_NAME = "matchanything-roma" # alternatives: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, etc.
DEVICE = "cuda" # or "cpu"
# Transform methods: every listed method is estimated and the one with the
# lowest 95th-percentile reprojection error is used
TRANSFORM_METHODS = ["homography"]
# Options: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "tps"
# Matching parameters
MATCH_MAX_SIDE = 1500 # maximum side length (pixels) of the images handed to the matcher
ROI_PAD_PX = 500 # padding of the coarse-localisation window, in reference-tif pixels
# Quality-control thresholds
MIN_INLIERS = 10 # minimum number of inliers
MIN_INLIER_RATIO = 0.01 # minimum inlier ratio (inliers / matched points)
# Create the output directory (runs at import time)
OUT_DIR.mkdir(parents=True, exist_ok=True)
# ---------- Utility functions ----------
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow *win* by *pad* pixels on every side, clamped to a max_w x max_h raster.

    The returned window has integer offsets and sizes.  Its width or height
    can be zero or negative when *win* lies outside the raster; the caller is
    expected to check for that.
    """
    # Left / top edges: move outwards, but never before pixel 0.
    left = int(max(0, win.col_off - pad))
    top = int(max(0, win.row_off - pad))

    # Right / bottom edges: move outwards, but never past the raster extent.
    right = int(min(max_w, win.col_off + win.width + pad))
    bottom = int(min(max_h, win.row_off + win.height + pad))

    new_width = right - left
    new_height = bottom - top
    return rasterio.windows.Window(left, top, new_width, new_height)
# Smallest number of correspondences each model can be fitted from.  Checked
# up front so under-determined input yields (None, None) instead of an
# exception from the underlying estimator.
_MIN_POINTS = {
    "translation": 1,
    "euclidean": 2,
    "similarity": 2,
    "affine": 3,
    "homography": 4,
    "piecewise_affine": 3,
    "polynomial": 6,  # 2nd-order polynomial: 6 coefficients per axis
    "tps": 3,
}


def estimate_transform(method, k0, k1):
    """Fit the geometric model *method* mapping points *k0* onto *k1*.

    Args:
        method: One of "translation", "euclidean", "similarity", "affine",
            "homography", "piecewise_affine", "polynomial", "tps".
        k0: (N, 2) source points as (x, y).
        k1: (N, 2) destination points as (x, y), row-aligned with *k0*.

    Returns:
        A ``(kind, transform)`` pair:

        * ``("A", 2x3 matrix)`` for translation / euclidean / similarity / affine,
        * ``("H", 3x3 matrix)`` for homography,
        * ``("piecewise", tform)`` / ``("polynomial", tform)`` scikit-image objects,
        * ``("tps", transform)`` SimpleITK transform.

        ``transform`` is ``None`` when the OpenCV estimator finds no model;
        ``(None, None)`` is returned when there are too few points, the
        required library is unavailable, or the estimation fails.

    Raises:
        ValueError: If *method* is not a known method name.
    """
    if method not in _MIN_POINTS:
        raise ValueError(f"未知变换方法: {method}")
    if len(k0) < _MIN_POINTS[method]:
        return None, None

    if method == "translation":
        # Pure shift: mean displacement of the correspondences.
        dx = np.mean(k1[:, 0] - k0[:, 0])
        dy = np.mean(k1[:, 1] - k0[:, 1])
        return "A", np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)

    if method == "euclidean":
        # Rotation + translation with the scale fixed to 1.  OpenCV only
        # offers the 4-DOF similarity fit, so take its rotation, drop the
        # scale, and re-solve the translation (least squares) on the inliers.
        sim, inlier_mask = cv2.estimateAffinePartial2D(
            k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        if sim is None:
            return "A", None
        scale = float(np.hypot(sim[0, 0], sim[1, 0]))
        if not np.isfinite(scale) or scale < 1e-12:
            return "A", None
        rotation = sim[:, :2] / scale
        if inlier_mask is None:
            keep = np.ones(len(k0), dtype=bool)
        else:
            keep = np.asarray(inlier_mask).ravel().astype(bool)
            if not keep.any():
                keep[:] = True
        src = np.asarray(k0, dtype=np.float64)[keep]
        dst = np.asarray(k1, dtype=np.float64)[keep]
        shift = (dst - src @ rotation.T).mean(axis=0)
        return "A", np.hstack([rotation, shift[:, None]])

    if method == "similarity":
        # Rotation + uniform scale + translation.
        A, _ = cv2.estimateAffinePartial2D(
            k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "affine":
        # Full affine: rotation + anisotropic scale + shear + translation.
        A, _ = cv2.estimateAffine2D(
            k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "homography":
        # 8-DOF projective transform.
        H, _ = cv2.findHomography(
            k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0)
        return "H", H

    if method in ("piecewise_affine", "polynomial"):
        if not SKIMAGE_AVAILABLE:
            return None, None
        try:
            if method == "piecewise_affine":
                kind, tform = "piecewise", PiecewiseAffineTransform()
                ok = tform.estimate(k0, k1)
            else:
                kind, tform = "polynomial", PolynomialTransform()
                ok = tform.estimate(k0, k1, order=2)
        except Exception as exc:
            logger.warning("%s estimation failed: %s", method, exc)
            return None, None
        # Only an explicit False counts as failure, in case a scikit-image
        # version returns nothing from estimate().
        if ok is False:
            return None, None
        return kind, tform

    # method == "tps": thin-plate spline through SimpleITK, when present.
    # NOTE(review): this relies on ``sitk.ThinPlateSplineKernelTransform``;
    # when the installed SimpleITK does not expose it the hasattr() check
    # makes this path return (None, None) -- confirm against your build.
    if not (SITK_AVAILABLE and hasattr(sitk, "ThinPlateSplineKernelTransform")):
        return None, None
    try:
        tps_transform = sitk.ThinPlateSplineKernelTransform()
        tps_transform.SetKernelTypeToThinPlateSpline()
        fixed_landmarks = sitk.vectorDPoint()
        moving_landmarks = sitk.vectorDPoint()
        # Fixed landmarks come from k1, moving landmarks from k0.
        for (rx, ry), (sx, sy) in zip(k1, k0):
            fixed_landmarks.push_back([float(rx), float(ry)])
            moving_landmarks.push_back([float(sx), float(sy)])
        tps_transform.SetFixedLandmarks(fixed_landmarks)
        tps_transform.SetMovingLandmarks(moving_landmarks)
        return "tps", tps_transform
    except Exception as exc:
        logger.warning("tps estimation failed: %s", exc)
        return None, None
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Score a fitted transform by its reprojection error on (k0 -> k1).

    Returns ``(median_error, p95_error)`` as floats, i.e. the median and the
    95th percentile of the Euclidean distance between the transformed *k0*
    and *k1*.  ``(inf, inf)`` is returned when there is no transform, no
    points, or the transform type is not recognised.
    """
    worst = (np.inf, np.inf)
    if transform is None or len(k0) == 0:
        return worst

    if transform_type in ("A", "H"):
        # Both matrix models act on homogeneous (x, y, 1) columns.
        homogeneous = np.hstack(
            [k0, np.ones((k0.shape[0], 1), dtype=np.float32)]
        ).T
        if transform_type == "A":
            # 2x3 affine matrix.
            predicted = (transform @ homogeneous).T
        else:
            # 3x3 homography: divide by the (slightly regularised) w row.
            projected = transform @ homogeneous
            projected /= (projected[2:3, :] + 1e-6)
            predicted = projected[:2, :].T
    elif transform_type in ("piecewise", "polynomial"):
        # scikit-image transforms are callable on (N, 2) point arrays.
        predicted = transform(k0)
    elif transform_type == "tps":
        # SimpleITK transforms map one point at a time.
        mapped = (
            transform.TransformPoint([float(p[0]), float(p[1])]) for p in k0
        )
        predicted = np.array([[q[0], q[1]] for q in mapped])
    else:
        return worst

    residuals = np.sqrt(((predicted - k1) ** 2).sum(axis=1))
    return float(np.median(residuals)), float(np.percentile(residuals, 95))
# Extra margin (reference pixels) added around a projected footprint before it
# is clipped to the reference raster.
_BBOX_PAD_PX = 10


def _affine_to_matrix(aff):
    """Return the 3x3 float64 pixel->map matrix of an ``Affine`` transform."""
    return np.array([[aff.a, aff.b, aff.c],
                     [aff.d, aff.e, aff.f],
                     [0.0, 0.0, 1.0]], dtype=np.float64)


def _src_corners(src):
    """Return the four corner pixels of *src* as a (4, 2) float64 (x, y) array."""
    w, h = src.width, src.height
    return np.array([[0, 0], [w, 0], [w, h], [0, h]], dtype=np.float64)


def _read_match_bands(dataset, window=None):
    """Read only the bands the matcher uses: the first three, or all if fewer."""
    n_bands = 3 if dataset.count >= 3 else dataset.count
    return dataset.read(indexes=list(range(1, n_bands + 1)), window=window)


def _clipped_bbox(points_xy, ref_dataset, pad=_BBOX_PAD_PX):
    """Padded integer bounding box of *points_xy*, clipped to the reference raster.

    Returns ``(min_x, min_y, width, height)`` in reference pixels, or ``None``
    when the clipped box is empty.
    """
    pts = np.asarray(points_xy)
    min_x = max(0, int(np.floor(pts[:, 0].min())) - pad)
    min_y = max(0, int(np.floor(pts[:, 1].min())) - pad)
    max_x = min(ref_dataset.width, int(np.ceil(pts[:, 0].max())) + pad)
    max_y = min(ref_dataset.height, int(np.ceil(pts[:, 1].max())) + pad)
    width, height = max_x - min_x, max_y - min_y
    if width <= 0 or height <= 0:
        return None
    return min_x, min_y, width, height


def _output_profile(ref_dataset, src, ref_crs, width, height, transform, dst_nodata):
    """Build the ENVI/BIP output profile: reference grid, source dtype and band count."""
    profile = ref_dataset.profile.copy()
    profile.update(
        driver="ENVI",
        dtype=src.dtypes[0],
        height=height,
        width=width,
        count=src.count,
        transform=transform,
        crs=ref_crs,
        interleave="bip",
        compress=None,
        nodata=dst_nodata,
    )
    return profile


def _cast_band(band, dtype, dst_nodata, has_nodata):
    """Cast a resampled band to *dtype*; integers are clipped and nodata is kept."""
    np_dtype = np.dtype(dtype)
    if not np.issubdtype(np_dtype, np.integer):
        return band.astype(np_dtype)
    # Remember nodata pixels before clipping so they survive the cast.
    mask = (band == dst_nodata) if has_nodata else None
    info = np.iinfo(np_dtype)
    out = np.clip(band, info.min, info.max).astype(np_dtype)
    if mask is not None:
        out[mask] = dst_nodata
    return out


def _write_bands(out_path, profile, src, dst_nodata, warp_band):
    """Warp every band of *src* with ``warp_band(float32 2-D) -> 2-D`` and write it."""
    has_nodata = src.nodata is not None
    with rasterio.open(out_path, "w", **profile) as out_ds:
        for band_idx in range(1, src.count + 1):
            src_band = src.read(band_idx).astype(np.float32)
            warped = warp_band(src_band)
            out_ds.write(
                _cast_band(warped, profile["dtype"], dst_nodata, has_nodata),
                band_idx,
            )


def _write_affine(src, ref_dataset, ref_crs, A, out_path):
    """Resample *src* with the 2x3 affine *A* (source pixel -> global reference pixel).

    Returns False when the projected footprint does not overlap the reference.
    """
    A3 = np.eye(3, dtype=np.float64)
    A3[:2, :] = A
    # Source pixel -> map coordinates: reference pixel->map composed with A.
    M_map = _affine_to_matrix(ref_dataset.transform) @ A3
    corrected_affine = Affine(M_map[0, 0], M_map[0, 1], M_map[0, 2],
                              M_map[1, 0], M_map[1, 1], M_map[1, 2])

    # Footprint of the source corners in reference pixels.
    corners_h = np.hstack([_src_corners(src), np.ones((4, 1))]).T
    pix_corners = (A3 @ corners_h).T[:, :2]
    bbox = _clipped_bbox(pix_corners, ref_dataset)
    if bbox is None:
        return False
    min_x, min_y, bbox_w, bbox_h = bbox
    bbox_transform = ref_dataset.window_transform(
        rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h))

    src_nodata = src.nodata
    dst_nodata = src_nodata if src_nodata is not None else 0
    profile = _output_profile(ref_dataset, src, ref_crs, bbox_w, bbox_h,
                              bbox_transform, dst_nodata)

    def warp_band(src_band):
        dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32)
        reproject(
            source=src_band,
            destination=dst_band,
            src_transform=corrected_affine,
            src_crs=ref_crs,
            dst_transform=bbox_transform,
            dst_crs=ref_crs,
            src_nodata=src_nodata,
            dst_nodata=dst_nodata,
            resampling=Resampling.bilinear,
        )
        return dst_band

    _write_bands(out_path, profile, src, dst_nodata, warp_band)
    return True


def _write_homography(src, ref_dataset, ref_crs, H_full, out_path):
    """Resample *src* with the 3x3 homography *H_full* (source -> global reference pixel).

    Returns False when the projected footprint does not overlap the reference.
    """
    corners_h = np.hstack([_src_corners(src), np.ones((4, 1))]).T
    projected = H_full @ corners_h
    footprint = (projected[:2] / (projected[2:] + 1e-6)).T
    bbox = _clipped_bbox(footprint, ref_dataset)
    if bbox is None:
        return False
    min_x, min_y, bbox_w, bbox_h = bbox
    bbox_transform = ref_dataset.window_transform(
        rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h))

    # Re-express the homography in output-window pixel coordinates.
    T_off = np.array([[1, 0, min_x], [0, 1, min_y], [0, 0, 1]], dtype=np.float64)
    H_sub = np.linalg.inv(T_off) @ H_full

    dst_nodata = src.nodata if src.nodata is not None else 0
    profile = _output_profile(ref_dataset, src, ref_crs, bbox_w, bbox_h,
                              bbox_transform, dst_nodata)

    def warp_band(src_band):
        return cv2.warpPerspective(
            src_band, H_sub,
            (bbox_w, bbox_h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=dst_nodata,
        )

    _write_bands(out_path, profile, src, dst_nodata, warp_band)
    return True


def _hull_membership(points_xy, min_x, min_y, max_x, max_y):
    """Return ``f(xy) -> bool mask`` telling which points the piecewise model covers.

    Uses the convex hull of *points_xy* when matplotlib and scipy are
    available, otherwise (or if the hull cannot be built) the bounding rectangle.
    """
    if MATPLOTLIB_SCIPY_AVAILABLE:
        try:
            hull = ConvexHull(points_xy)
            hull_path = MplPath(points_xy[hull.vertices])
        except Exception:
            rect = np.array([[min_x, min_y], [max_x, min_y],
                             [max_x, max_y], [min_x, max_y]], dtype=float)
            hull_path = MplPath(rect)
        return hull_path.contains_points

    def inside_rect(xy):
        return ((xy[:, 0] >= min_x) & (xy[:, 0] <= max_x)
                & (xy[:, 1] >= min_y) & (xy[:, 1] <= max_y))

    return inside_rect


def _write_skimage(kind, transform, src, ref_dataset, ref_crs, k0_full, k1_global, out_path):
    """Resample *src* with a scikit-image piecewise-affine or polynomial transform.

    The output window is the padded bounding box of the reference-side match
    points.  Returns False when that box does not overlap the reference.
    """
    from skimage.transform import warp

    bbox = _clipped_bbox(k1_global, ref_dataset)
    if bbox is None:
        return False
    min_x, min_y, bbox_w, bbox_h = bbox
    max_x, max_y = min_x + bbox_w, min_y + bbox_h
    bbox_transform = ref_dataset.window_transform(
        rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h))

    dst_nodata = src.nodata if src.nodata is not None else 0
    profile = _output_profile(ref_dataset, src, ref_crs, bbox_w, bbox_h,
                              bbox_transform, dst_nodata)

    if kind == "polynomial":
        # A polynomial has no closed-form inverse: fit reference -> source directly.
        t_inv = PolynomialTransform()
        t_inv.estimate(k1_global, k0_full, order=2)

        def ref_to_src(xy):
            return t_inv(xy)
    else:
        point_inside = _hull_membership(k1_global, min_x, min_y, max_x, max_y)

        def ref_to_src(xy):
            # Points outside the hull map to (-1, -1), i.e. outside the
            # source image, so they receive the constant fill value.
            inside = point_inside(xy)
            xy_src = np.full_like(xy, fill_value=-1.0)
            if np.any(inside):
                xy_src[inside] = transform.inverse(xy[inside])
            return xy_src

    offset = np.array([min_x, min_y], dtype=np.float64)

    def inverse_map_xy(coords):
        # skimage.transform.warp hands a callable (col, row) == (x, y)
        # coordinates of the OUTPUT image and expects (x, y) of the input
        # back, so only the window offset has to be added -- no axis swap.
        xy = np.asarray(coords, dtype=np.float64) + offset
        return ref_to_src(xy)

    def warp_band(src_band):
        return warp(
            src_band,
            inverse_map=inverse_map_xy,
            output_shape=(bbox_h, bbox_w),
            mode='constant',
            cval=dst_nodata,
            preserve_range=True,
        ).astype(np.float32)

    _write_bands(out_path, profile, src, dst_nodata, warp_band)
    return True


def _write_bspline_pirt(src, ref_dataset, ref_crs, win, out_path):
    """Intensity-based B-spline registration with PIRT over the reference ROI *win*.

    NOTE(review): the PIRT calls below (Registration, BSplineTransform,
    metrics.NCC, optimizers.LBFGS, get_forward_mapping, warp) are kept as
    written and could not be checked against PIRT's public API -- confirm.
    """
    from skimage.exposure import rescale_intensity

    # Register on band 1 of both images, normalised to [0, 1].
    ref_roi_data = ref_dataset.read(1, window=win).astype(np.float32)
    src_band_data = src.read(1).astype(np.float32)
    ref_roi_reg = rescale_intensity(ref_roi_data, in_range='image', out_range=(0.0, 1.0))
    src_reg = rescale_intensity(src_band_data, in_range='image', out_range=(0.0, 1.0))

    reg = pirt.Registration(fixed=ref_roi_reg, moving=src_reg)
    bspline = pirt.transform.BSplineTransform(grid_spacing=(96, 96))  # control-point spacing
    reg.set_transformation(bspline)
    reg.set_similarity(pirt.metrics.NCC())
    reg.set_pyramid([4, 2, 1])
    reg.set_optimizer(pirt.optimizers.LBFGS(max_iter=200), smooth=1.0)
    reg.run()
    phi = reg.get_forward_mapping()

    dst_nodata = src.nodata if src.nodata is not None else 0
    profile = _output_profile(ref_dataset, src, ref_crs,
                              ref_roi_data.shape[1], ref_roi_data.shape[0],
                              ref_dataset.window_transform(win), dst_nodata)

    def warp_band(src_band):
        return pirt.warp(src_band, phi, mode='constant', cval=float(dst_nodata))

    _write_bands(out_path, profile, src, dst_nodata, warp_band)


def _write_tps_sitk(src, ref_dataset, ref_crs, k0_full, k1_global, out_path):
    """Thin-plate-spline resampling through SimpleITK landmark initialisation.

    Full-image pixel coordinates are used as physical coordinates (spacing 1,
    identity direction).  Returns False when the projected footprint does not
    overlap the reference.

    NOTE(review): relies on ``sitk.sitkThinPlateSplineKernelTransform``; if
    the installed SimpleITK lacks it this raises and the caller falls back to
    the affine path -- confirm against your build.
    """
    fixed_pts = [(float(x), float(y)) for x, y in k1_global]   # reference (output) side
    moving_pts = [(float(x), float(y)) for x, y in k0_full]    # source (input) side

    # Resampling needs output -> input, i.e. reference -> source.
    tps_ref2src = sitk.LandmarkBasedTransformInitializer(
        sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform),
        fixed_pts,
        moving_pts,
    )
    # Source -> reference is only used to project the source corners.
    tps_src2ref = sitk.LandmarkBasedTransformInitializer(
        sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform),
        moving_pts,
        fixed_pts,
    )
    dst_corners = np.array(
        [tps_src2ref.TransformPoint((float(x), float(y))) for x, y in _src_corners(src)],
        dtype=np.float32,
    )
    bbox = _clipped_bbox(dst_corners, ref_dataset)
    if bbox is None:
        return False
    min_x, min_y, bbox_w, bbox_h = bbox
    bbox_transform = ref_dataset.window_transform(
        rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h))

    # Output grid: spacing 1, origin at the window offset, identity direction.
    grid = sitk.Image([bbox_w, bbox_h], sitk.sitkFloat32)
    grid.SetSpacing((1.0, 1.0))
    grid.SetOrigin((float(min_x), float(min_y)))
    grid.SetDirection((1.0, 0.0, 0.0, 1.0))

    dst_nodata = src.nodata if src.nodata is not None else 0
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(grid)
    resampler.SetTransform(tps_ref2src)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(float(dst_nodata))

    profile = _output_profile(ref_dataset, src, ref_crs, bbox_w, bbox_h,
                              bbox_transform, dst_nodata)

    def warp_band(src_band):
        return sitk.GetArrayFromImage(resampler.Execute(sitk.GetImageFromArray(src_band)))

    _write_bands(out_path, profile, src, dst_nodata, warp_band)
    return True


def _write_nonlinear(kind, transform, src, ref_dataset, ref_crs, win,
                     k0_full, k1_global, bip_path, out_path):
    """Run the non-affine resampling path for *kind*.

    Returns True when the output was written, False when the output window is
    empty (give up), and None when the caller should fall back to affine.
    """
    if kind == "H":
        error_label, bbox_label, success_tag = "单应变换", "单应变换", "Homography"
    elif kind in ("piecewise", "polynomial"):
        error_label, bbox_label, success_tag = f"{kind}变换", f"{kind}变换", kind
    elif kind == "tps":
        error_label, bbox_label, success_tag = "B样条变换", "TPS", None
    else:
        return None

    try:
        if kind == "H":
            written = _write_homography(src, ref_dataset, ref_crs, transform, out_path)
        elif kind == "tps":
            # Prefer PIRT B-spline; otherwise SimpleITK TPS; otherwise affine.
            if PIRT_AVAILABLE:
                logger.info("使用 PIRT B样条变换")
                _write_bspline_pirt(src, ref_dataset, ref_crs, win, out_path)
                written, success_tag = True, "B样条-PIRT"
            elif SITK_AVAILABLE and hasattr(sitk, "LandmarkBasedTransformInitializer"):
                logger.info("PIRT 不可用,使用 SimpleITK TPS")
                written = _write_tps_sitk(src, ref_dataset, ref_crs,
                                          k0_full, k1_global, out_path)
                success_tag = "TPS-SimpleITK"
            else:
                logger.warning("PIRT 和 SimpleITK TPS 都不可用,回退到仿射")
                return None
        else:
            written = _write_skimage(kind, transform, src, ref_dataset, ref_crs,
                                     k0_full, k1_global, out_path)
    except Exception as e:
        logger.warning(f"{error_label}异常: {e}")
        return None  # fall back to affine

    if not written:
        logger.warning(f"{bbox_label}最小外接矩形无效: {bip_path.name}")
        return False
    logger.info(f"成功配准({success_tag}): {bip_path.name} -> {out_path.name}")
    return True


def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
    """Register one .bip file onto the reference raster and write the result.

    Steps:
      1. Project the source bounds into the reference CRS and cut a padded
         region of interest (ROI) out of the reference.
      2. Run *matcher* on downscaled 3-channel versions of source and ROI.
      3. Scale the inlier keypoints back to full-resolution source pixels and
         global reference pixels, fit every method in ``TRANSFORM_METHODS``
         and keep the one with the lowest 95th-percentile reprojection error.
      4. Resample all source bands onto the reference grid and write
         ``<stem>_registered.bip`` (ENVI, BIP interleave) into *out_dir*.
         If a non-affine path raises, an affine transform fitted on the same
         correspondences is used instead.

    Args:
        bip_path: Path of the source .bip file.
        ref_dataset: Open rasterio dataset of the reference image.
        matcher: Callable ``matcher(img0, img1)`` returning a mapping with
            "num_inliers", "matched_kpts0", "inlier_kpts0" and "inlier_kpts1".
        out_dir: Existing directory that receives the output file.

    Returns:
        True when a registered file was written, False otherwise.  Exceptions
        are logged and reported as False; none propagate to the caller.
    """
    try:
        with rasterio.open(bip_path) as src:
            logger.info(f"处理文件: {bip_path.name}")

            # CRS handling: a source without CRS is assumed to share the reference CRS.
            ref_crs = ref_dataset.crs
            if src.crs is None:
                logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
                src_crs = ref_crs
            else:
                src_crs = src.crs
            if ref_crs is None:
                raise RuntimeError("参考文件缺少CRS信息")

            # 1) Coarse localisation from the georeferencing.
            bounds_in_ref = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
            win0 = from_bounds(*bounds_in_ref, transform=ref_dataset.transform)
            win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
            if win.width <= 0 or win.height <= 0:
                logger.warning(f"无重叠区域: {bip_path.name}")
                return False

            # 2) Matching images: only the bands the matcher sees are read.
            src_img = _to_3ch_float01(_read_match_bands(src))
            ref_img = _to_3ch_float01(_read_match_bands(ref_dataset, window=win))

            # 3) Match on downscaled copies (faster and more stable).
            src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
            ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
            logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")

            # 4) img0 = source, img1 = reference ROI.
            result = matcher(src_small, ref_small)
            num_inl = int(result["num_inliers"])
            num_m = len(result["matched_kpts0"])
            ratio = (num_inl / num_m) if num_m else 0.0
            logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
            if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
                logger.warning(f"匹配质量不足: {bip_path.name}")
                return False

            # 5) Keypoints back to full resolution: source pixels on one side,
            #    global reference pixels (ROI offset added) on the other.
            k0_small = result["inlier_kpts0"].astype(np.float32)
            k1_small = result["inlier_kpts1"].astype(np.float32)
            src_scale = np.array([src_img.shape[2] / src_small.shape[2],
                                  src_img.shape[1] / src_small.shape[1]], dtype=np.float32)
            ref_scale = np.array([ref_img.shape[2] / ref_small.shape[2],
                                  ref_img.shape[1] / ref_small.shape[1]], dtype=np.float32)
            k0_full = k0_small * src_scale
            k1_global = (k1_small * ref_scale
                         + np.array([win.col_off, win.row_off], dtype=np.float32))

            # Fit every configured model; keep the lowest p95 reprojection error.
            best_transform = None
            best_transform_type = None
            best_error = np.inf
            best_method = None
            for method in TRANSFORM_METHODS:
                transform_type, transform = estimate_transform(method, k0_full, k1_global)
                if transform is None:
                    continue
                med_err, p95_err = evaluate_transform_quality(
                    transform_type, transform, k0_full, k1_global)
                if p95_err < best_error:
                    best_transform = transform
                    best_transform_type = transform_type
                    best_error = p95_err
                    best_method = method
                logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
            if best_transform is None:
                logger.warning(f"所有变换方法都失败: {bip_path.name}")
                return False
            logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}")

            # 6) Resample with the selected model.
            out_path = out_dir / f"{bip_path.stem}_registered.bip"
            if best_transform_type == "A":
                if not _write_affine(src, ref_dataset, ref_crs, best_transform, out_path):
                    logger.warning(f"最小外接矩形无效: {bip_path.name}")
                    return False
                logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
                return True

            outcome = _write_nonlinear(best_transform_type, best_transform, src,
                                       ref_dataset, ref_crs, win,
                                       k0_full, k1_global, bip_path, out_path)
            if outcome is not None:
                return outcome

            # ---- Fallback: affine fitted on the same correspondences ----
            # k0_full / k1_global are already full-resolution source pixels and
            # global reference pixels, so the matrix is used as-is: no extra
            # scaling or ROI offset may be applied on top of it.
            A_fallback, _ = cv2.estimateAffine2D(
                k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
            if A_fallback is None:
                logger.warning(f"仿射回退也失败: {bip_path.name}")
                return False
            if not _write_affine(src, ref_dataset, ref_crs, A_fallback, out_path):
                logger.warning(f"最小外接矩形无效: {bip_path.name}")
                return False
            logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
            return True
    except Exception as e:
        logger.error(f"处理失败 {bip_path.name}: {str(e)}")
        return False
# ---------- 主逻辑 ----------
def main():
    """Register every .bip file found in ``BIP_DIR`` against ``REF_TIF``.

    Logs an error and returns early when the reference file or the input
    folder is missing.  Otherwise the matcher is created once, the reference
    is opened once, and each ``*.bip`` file is processed in turn; a summary
    of how many files succeeded is logged at the end.
    """
    logger.info("开始批量配准处理...")

    # Both inputs must exist before any expensive setup happens.
    required_inputs = (
        (Path(REF_TIF), f"参考文件不存在: {REF_TIF}"),
        (BIP_DIR, f"BIP文件夹不存在: {BIP_DIR}"),
    )
    for location, complaint in required_inputs:
        if not location.exists():
            logger.error(complaint)
            return

    # Build the matcher once; it is reused for every file.
    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")

        # Collect the .bip files directly inside BIP_DIR.
        bip_files = list(BIP_DIR.glob("*.bip"))
        logger.info(f"找到 {len(bip_files)} 个 .bip 文件")

        # Files are processed one after another; count the successes.
        success_count = sum(
            1
            for candidate in bip_files
            if process_bip_to_tif(candidate, ref, matcher, OUT_DIR)
        )
        logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")


if __name__ == "__main__":
    main()