first commit

This commit is contained in:
2026-03-06 17:24:55 +08:00
commit 5e0984bf9c
18 changed files with 10178 additions and 0 deletions

970
test V3.py Normal file
View File

@ -0,0 +1,970 @@
"""
批量配准 .bip 文件到参考 .tif 文件
使用 SimpleITK 实现 B 样条变换
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
try:
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
try:
from matplotlib.path import Path as MplPath
from scipy.spatial import ConvexHull
MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
MATPLOTLIB_SCIPY_AVAILABLE = False
MplPath = None
logging.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")
try:
import SimpleITK as sitk
SITK_AVAILABLE = True
except ImportError:
SITK_AVAILABLE = False
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
try:
import pirt
PIRT_AVAILABLE = True
except ImportError:
PIRT_AVAILABLE = False
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ---------- 配置 ----------
# 请根据实际情况修改这些路径
REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径
BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹
OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹
# 匹配算法选择
MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等
DEVICE = "cuda" # 或 "cpu"
# 变换方法选择(按优先级尝试)
TRANSFORM_METHODS = ["homography"]
# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "tps"
# 匹配参数
MATCH_MAX_SIDE = 1500 # 匹配时最大边长(像素)
ROI_PAD_PX = 500 # 粗定位窗口的padding参考tif像素
# 质量控制阈值
MIN_INLIERS = 10 # 最少内点数
MIN_INLIER_RATIO = 0.01 # 最少内点比例
# 创建输出目录
OUT_DIR.mkdir(parents=True, exist_ok=True)
# ---------- 工具函数 ----------
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
"""扩展窗口并确保边界有效"""
col_off = int(max(0, win.col_off - pad))
row_off = int(max(0, win.row_off - pad))
col_end = int(min(max_w, win.col_off + win.width + pad))
row_end = int(min(max_h, win.row_off + win.height + pad))
return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)
def estimate_transform(method, k0, k1):
"""统一的变换估计函数,支持多种变换类型"""
if method == "translation":
# 简单平移:用内点的平均位移
if len(k0) == 0:
return None, None
dx = np.mean(k1[:, 0] - k0[:, 0])
dy = np.mean(k1[:, 1] - k0[:, 1])
A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
return "A", A
elif method == "euclidean":
# 欧式变换(旋转+平移),约束等比缩放=1
A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
return "A", A
elif method == "similarity":
# 相似变换(旋转+等比缩放+平移)
A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
return "A", A
elif method == "affine":
# 全仿射变换(旋转+非等比缩放+剪切+平移)
A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
return "A", A
elif method == "homography":
# 投影变换8DOF透视
H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0)
return "H", H
elif method == "piecewise_affine":
# 分片仿射变换
if not SKIMAGE_AVAILABLE:
return None, None
try:
tform = PiecewiseAffineTransform()
tform.estimate(k0, k1)
return "piecewise", tform
except Exception:
return None, None
elif method == "polynomial":
# 多项式变换2阶
if not SKIMAGE_AVAILABLE:
return None, None
try:
tform = PolynomialTransform()
tform.estimate(k0, k1, order=2)
return "polynomial", tform
except Exception:
return None, None
elif method == "tps":
# 薄板样条变换如果SimpleITK可用
SITK_TPS = SITK_AVAILABLE and hasattr(sitk, "ThinPlateSplineKernelTransform")
if not SITK_TPS:
return None, None
try:
tps_transform = sitk.ThinPlateSplineKernelTransform()
tps_transform.SetKernelTypeToThinPlateSpline()
fixed_landmarks = sitk.vectorDPoint()
moving_landmarks = sitk.vectorDPoint()
for (rx, ry), (sx, sy) in zip(k1, k0):
fixed_landmarks.push_back([float(rx), float(ry)])
moving_landmarks.push_back([float(sx), float(sy)])
tps_transform.SetFixedLandmarks(fixed_landmarks)
tps_transform.SetMovingLandmarks(moving_landmarks)
return "tps", tps_transform
except Exception:
return None, None
else:
raise ValueError(f"未知变换方法: {method}")
def evaluate_transform_quality(transform_type, transform, k0, k1):
"""评估变换质量(重投影误差)"""
if transform is None or len(k0) == 0:
return np.inf, np.inf
if transform_type == "A":
# 仿射变换重投影误差
A = transform
ones = np.ones((k0.shape[0], 1), dtype=np.float32)
pred = (A @ np.hstack([k0, ones]).T).T
e = np.sqrt(((pred - k1) ** 2).sum(axis=1))
elif transform_type == "H":
# 单应变换重投影误差
H = transform
ones = np.ones((k0.shape[0], 1), dtype=np.float32)
src_h = np.hstack([k0, ones]).T
warped = H @ src_h
warped /= (warped[2:3, :] + 1e-6)
pred = warped[:2, :].T
e = np.sqrt(((pred - k1) ** 2).sum(axis=1))
elif transform_type in ["piecewise", "polynomial"]:
# scikit-image 变换重投影误差
pred = transform(k0)
e = np.sqrt(((pred - k1) ** 2).sum(axis=1))
elif transform_type == "tps":
# TPS 变换重投影误差SimpleITK
pred = []
for pt in k0:
transformed_pt = transform.TransformPoint([float(pt[0]), float(pt[1])])
pred.append([transformed_pt[0], transformed_pt[1]])
pred = np.array(pred)
e = np.sqrt(((pred - k1) ** 2).sum(axis=1))
else:
return np.inf, np.inf
return float(np.median(e)), float(np.percentile(e, 95))
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
"""处理单个 .bip 文件到参考 .tif 的配准"""
try:
with rasterio.open(bip_path) as src:
logger.info(f"处理文件: {bip_path.name}")
# 检查CRS
if src.crs is None:
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
src_crs = ref_dataset.crs
else:
src_crs = src.crs
ref_crs = ref_dataset.crs
if ref_crs is None:
raise RuntimeError(f"参考文件缺少CRS信息")
# 1) 用地理信息把 src.bounds 转到 ref CRS再裁 ref ROI
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
win0 = from_bounds(*b, transform=ref_dataset.transform)
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
if win.width <= 0 or win.height <= 0:
logger.warning(f"无重叠区域: {bip_path.name}")
return False
# 2) 读取数据
# 读取所有波段,如果是多波段的话
src_arr = src.read() # (bands, H, W)
if src_arr.ndim == 2: # 单波段
src_arr = src_arr[None, ...] # 增加波段维度
# 读取参考文件的ROI
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
if ref_arr.ndim == 2: # 单波段
ref_arr = ref_arr[None, ...] # 增加波段维度
# 转换为匹配所需的格式
src_img = _to_3ch_float01(src_arr)
ref_img = _to_3ch_float01(ref_arr)
# 3) 匹配用降采样版本,提速 + 增稳
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
# 4) 精配准img0=src, img1=ref_roi
result = matcher(src_small, ref_small)
num_inl = int(result["num_inliers"])
num_m = len(result["matched_kpts0"])
ratio = (num_inl / num_m) if num_m else 0.0
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
logger.warning(f"匹配质量不足: {bip_path.name}")
return False
# 5) 用内点估计多种变换并自动选择最优
# 先计算全分辨率坐标
k0_small = result["inlier_kpts0"].astype(np.float32)
k1_small = result["inlier_kpts1"].astype(np.float32)
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src)
S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI)
ones = np.ones((k0_small.shape[0], 1), dtype=np.float32)
k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素
k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素
k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素
# 用全分辨率坐标进行所有模型的估计和评估
best_transform = None
best_transform_type = None
best_error = np.inf
best_method = None
for method in TRANSFORM_METHODS:
transform_type, transform = estimate_transform(method, k0_full, k1_global)
if transform is None:
continue
med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global)
# 选择重投影误差最小的变换
if p95_err < best_error:
best_transform = transform
best_transform_type = transform_type
best_error = p95_err
best_method = method
logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
if best_transform is None:
logger.warning(f"所有变换方法都失败: {bip_path.name}")
return False
logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}")
# 6) 根据变换类型进行相应的配准处理
if best_transform_type == "A":
# 仿射变换A 已是 src_full_pixel -> ref_full_pixel直接构造像素->地图仿射
A = best_transform # 2x3, src_full_pixel -> ref_full_pixel
A3 = np.eye(3, dtype=np.float64)
A3[:2, :] = A
# src_pixel -> map
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
M_map = Rt @ A3
corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2],
M_map[1,0], M_map[1,1], M_map[1,2])
# 用 M_map 求最小外接矩形(先到 map再到 ref 像素)
Rt_inv = np.linalg.inv(Rt)
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corn_h = np.hstack([corners, np.ones((4,1))]).T
map_corners = (M_map @ corn_h).T[:, :2]
pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2]
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil (pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil (pix_corners[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 重采样到最小外接矩形
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=bbox_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.bilinear,
)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
return True
# ---- 非仿射变换处理 ----
elif best_transform_type == "H":
# 单应变换H 已是 src_full_pixel -> ref_full_pixel
H_full = best_transform # 3x3
try:
# 用 H_full 映射源四角 -> 参考像素,求最小外接矩形
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32)
corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T
dst_h = (H_full @ corn_h)
dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T
min_x = int(np.floor(dst[:,0].min())) - 10
max_x = int(np.ceil (dst[:,0].max())) + 10
min_y = int(np.floor(dst[:,1].min())) - 10
max_y = int(np.ceil (dst[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}")
return False
# 创建输出窗口
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
# 子窗口坐标的单应矩阵(输出坐标系是子窗口像素)
T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64)
H_sub = np.linalg.inv(T_off) @ H_full
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 使用 OpenCV 进行单应变换重采样
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32)
# 使用 OpenCV warpPerspective子窗口坐标
dst_band = cv2.warpPerspective(
src_band, H_sub,
(bbox_w, bbox_h),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=dst_nodata
)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}")
return True
except Exception as e:
logger.warning(f"单应变换异常: {e}")
# 继续到仿射回退
elif best_transform_type in ["piecewise", "polynomial"]:
# 分片仿射或多项式变换:使用 scikit-image
transform = best_transform # 已用 k0_full/k1_global 估计
try:
# 用目标侧匹配点k1_global决定外接矩形更稳
pad = 10
min_x = int(np.floor(k1_global[:, 0].min())) - pad
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
min_y = int(np.floor(k1_global[:, 1].min())) - pad
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
min_x = max(0, min_x)
min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}")
return False
# 创建输出窗口
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 定义带偏移的逆映射函数
off_x, off_y = min_x, min_y
if best_transform_type == "polynomial":
# 对于多项式,估计逆变换
t_inv = PolynomialTransform()
t_inv.estimate(k1_global, k0_full, order=2) # 顺序:目标->源
def inv_map_rc(coords):
# coords: (N,2) in (row, col)
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref
xy_src = t_inv(xy) # -> (x_src, y_src) in full-src
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
else: # piecewise_affine
# 目标侧点集的内点判定
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
# 退化为矩形内判断
def point_inside(xy):
return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & (xy[:,1] >= min_y) & (xy[:,1] <= max_y)
def inv_map_rc(coords):
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref
inside = point_inside(xy)
xy_src = np.full_like(xy, fill_value=-1.0)
if np.any(inside):
xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src)
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
# 使用 scikit-image 进行变换重采样
from skimage.transform import warp
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = warp(
src_band,
inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射
output_shape=(bbox_h, bbox_w),
mode='constant',
cval=dst_nodata,
preserve_range=True
).astype(np.float32)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}")
return True
except Exception as e:
logger.warning(f"{best_transform_type}变换异常: {e}")
# 继续到仿射回退
elif best_transform_type == "tps":
# B样条变换优先使用 PIRT如果不可用则使用 SimpleITK TPS
try:
if PIRT_AVAILABLE:
# 使用 PIRT 实现 B样条弹性变换
logger.info("使用 PIRT B样条变换")
# 读取用于配准的单波段并归一化
ref_roi_data = ref_dataset.read(1, window=win).astype(np.float32)
src_band_data = src.read(1).astype(np.float32)
from skimage.exposure import rescale_intensity
ref_roi_reg = rescale_intensity(ref_roi_data, in_range='image', out_range=(0.0, 1.0))
src_reg = rescale_intensity(src_band_data, in_range='image', out_range=(0.0, 1.0))
# 构建 PIRT 注册器
reg = pirt.Registration(fixed=ref_roi_reg, moving=src_reg)
# 设置 B样条变换
bspline = pirt.transform.BSplineTransform(grid_spacing=(96, 96)) # 可调节控制点间距
reg.set_transformation(bspline)
# 设置相似度度量NCC 或 MI
reg.set_similarity(pirt.metrics.NCC())
# 多分辨率金字塔
reg.set_pyramid([4, 2, 1])
# 优化器设置
reg.set_optimizer(pirt.optimizers.LBFGS(max_iter=200), smooth=1.0)
# 执行注册
reg.run()
# 获取前向映射(参考网格到源的位移场)
phi = reg.get_forward_mapping() # (H, W, 2) 位移场
# 应用到所有波段
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=ref_roi_data.shape[0],
width=ref_roi_data.shape[1],
count=src.count,
transform=ref_dataset.window_transform(win),
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
# 使用 PIRT 的 warp 函数应用位移场
warped = pirt.warp(src_band, phi, mode='constant', cval=float(dst_nodata))
band_data = warped.astype(out_profile["dtype"])
out_ds.write(band_data, b)
logger.info(f"成功配准(B样条-PIRT): {bip_path.name} -> {out_path.name}")
return True
elif SITK_AVAILABLE and hasattr(sitk, "LandmarkBasedTransformInitializer"):
# 回退到 SimpleITK TPS
logger.info("PIRT 不可用,使用 SimpleITK TPS")
# 1) 统一坐标系:用"全图像素"作为物理坐标spacing=1, origin=(0,0), direction=I
fixed_pts = [(float(x1), float(y1)) for (x1,y1) in k1_global] # 参考(输出)侧
moving_pts = [(float(x0), float(y0)) for (x0,y0) in k0_full] # 源(输入)侧
# 2) 构造 TPS(参考→源) 用于 Resample输出点 -> 输入点)
tps_ref2src = sitk.LandmarkBasedTransformInitializer(
sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform),
fixed_pts, # fixed = 参考
moving_pts # moving = 源
)
# 3) 构造 TPS(源→参考) 仅用于外接矩形估计(源顶点投到参考)
tps_src2ref = sitk.LandmarkBasedTransformInitializer(
sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform),
moving_pts, # fixed = 源
fixed_pts # moving = 参考
)
# 4) 用 tps_src2ref 变换源四角,求参考全图上的外接矩形,并与参考范围求交
src_h, src_w = src.height, src.width
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32)
dst_corners = np.array([tps_src2ref.TransformPoint((float(x),float(y))) for x,y in src_corners], dtype=np.float32)
min_x = max(0, int(np.floor(dst_corners[:,0].min())) - 10)
max_x = min(ref_dataset.width, int(np.ceil (dst_corners[:,0].max())) + 10)
min_y = max(0, int(np.floor(dst_corners[:,1].min())) - 10)
max_y = min(ref_dataset.height, int(np.ceil (dst_corners[:,1].max())) + 10)
bbox_w, bbox_h = max_x-min_x, max_y-min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"TPS最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
# 5) 参考(输出)影像定义spacing=1origin=(min_x,min_y)direction=I
ref_img = sitk.Image([bbox_w, bbox_h], sitk.sitkFloat32)
ref_img.SetSpacing((1.0, 1.0))
ref_img.SetOrigin((float(min_x), float(min_y)))
ref_img.SetDirection((1.0,0.0,0.0,1.0))
# 6) 源影像:用 rasterio 读为 numpy再转 SITKspacing=1, origin=0, direction=I
src_band = src.read(1).astype(np.float32)
src_img = sitk.GetImageFromArray(src_band)
# 7) 重采样:设置参考图像 + 变换=参考→源tps_ref2src
res = sitk.ResampleImageFilter()
res.SetReferenceImage(ref_img)
res.SetTransform(tps_ref2src)
res.SetInterpolator(sitk.sitkLinear)
if src.nodata is not None:
res.SetDefaultPixelValue(float(src.nodata))
# 8) 写 ENVI/BIP对所有波段逐一 TPS 重采样
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
src_img_band = sitk.GetImageFromArray(src_band)
warped = res.Execute(src_img_band)
band_data = sitk.GetArrayFromImage(warped).astype(out_profile["dtype"])
out_ds.write(band_data, b)
logger.info(f"成功配准(TPS-SimpleITK): {bip_path.name} -> {out_path.name}")
return True
else:
logger.warning("PIRT 和 SimpleITK TPS 都不可用,回退到仿射")
# 继续到仿射回退
except Exception as e:
logger.warning(f"B样条变换异常: {e}")
# 继续到仿射回退
# ---- 回退:使用仿射变换,保证最小可用结果 ----
# 重新估计仿射变换作为fallback
A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
if A_fallback is None:
logger.warning(f"仿射回退也失败: {bip_path.name}")
return False
# 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64)
S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64)
A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback
M_full = S1_inv @ A3 @ S0
T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64)
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
src_pixel_to_map_corrected = Rt @ T_off @ M_full
corrected_affine = Affine(
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
)
# 计算源 BIP 四角经过仿射变换后的最小外接矩形
# 将 rasterio.Affine 转为 3x3 像素->地图矩阵
M_map = np.array([
[corrected_affine.a, corrected_affine.b, corrected_affine.c],
[corrected_affine.d, corrected_affine.e, corrected_affine.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
# 参考底图的 像素->地图 矩阵及其逆
ref_transform = ref_dataset.transform
Rt = np.array([
[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
Rt_inv = np.linalg.inv(Rt)
# 源影像四角(源像素坐标)
src_h, src_w = src.height, src.width
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4)
# 源像素 -> 地图坐标
map_corners = (M_map @ corners_h).T[:, :2]
# 地图坐标 -> 参考像素坐标
pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3)
pix_corners = pix_corners_h[:, :2]
# 最小外接矩形(像素)
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil( pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil( pix_corners[:,1].max())) + 10
# 边界裁剪
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
# 如果外接矩形太小,跳过
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
# 创建裁剪窗口和变换
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
# 更新输出 profile 使用最小外接矩形
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform, # 使用最小外接矩形的变换
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 重采样到最小外接矩形
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=bbox_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.bilinear,
)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
return True
except Exception as e:
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
return False
# ---------- 主逻辑 ----------
def main():
logger.info("开始批量配准处理...")
# 检查输入文件是否存在
if not Path(REF_TIF).exists():
logger.error(f"参考文件不存在: {REF_TIF}")
return
if not BIP_DIR.exists():
logger.error(f"BIP文件夹不存在: {BIP_DIR}")
return
# 初始化匹配器
logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
matcher = get_matcher(MATCHER_NAME, device=DEVICE)
# 打开参考文件
with rasterio.open(REF_TIF) as ref:
logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
# 查找所有 .bip 文件
bip_files = list(BIP_DIR.glob("*.bip"))
logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
success_count = 0
for bip_path in bip_files:
if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
success_count += 1
logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")
if __name__ == "__main__":
main()