Files
StripStitch/test V6.py
2026-03-06 17:24:55 +08:00

1509 lines
66 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
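Batch-register .bip rasters onto a reference .tif.
Feature matching (vismatch) is followed by fitting homography / affine /
piecewise-affine / polynomial models and resampling into the reference grid.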
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
import csv
from datetime import datetime
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
import threading
import queue
from dataclasses import dataclass
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import sys
import os
try:
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
try:
from matplotlib.path import Path as MplPath
from scipy.spatial import ConvexHull
MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
MATPLOTLIB_SCIPY_AVAILABLE = False
MplPath = None
logging.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")
try:
import SimpleITK as sitk
SITK_AVAILABLE = True
except ImportError:
SITK_AVAILABLE = False
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
try:
import pirt
PIRT_AVAILABLE = True
except ImportError:
PIRT_AVAILABLE = False
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
try:
from scipy.interpolate import Rbf
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
logging.warning("scipy 不可用,将跳过 TPS 变换")
@dataclass
class Config:
    """Parameters for one batch-registration run (consumed by run_batch)."""

    # --- inputs / outputs (main() passes pathlib.Path for the two folders) ---
    ref_tif: str  # reference raster path
    bip_dir: str  # folder scanned for *.bip files
    out_dir: str  # folder for registered outputs and the stats/ subfolder

    # --- feature matcher ---
    matcher_name: str  # name handed to get_matcher
    device: str  # device handed to get_matcher, e.g. "cuda" or "cpu"

    # --- geometry and quality control (same meaning as the module-level globals) ---
    transform_methods: list  # transform names, same vocabulary as TRANSFORM_METHODS
    match_max_side: int  # longest side (px) of the downscaled matching images
    roi_pad_px: int  # padding (reference px) around the coarse ROI
    min_inliers: int  # minimum inlier count
    min_inlier_ratio: float  # minimum inliers / matches
# ---------- Logging ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---------- Defaults (edit these paths for the local data layout) ----------
DEFAULT_REF_TIF = r"E:\is2\yaopu\result.tif"   # reference raster
DEFAULT_BIP_DIR = r"E:\is2\yaopu"              # folder holding the .bip files
DEFAULT_OUT_DIR = r"E:\is2\yaopu\output"       # output folder

# Feature matcher, e.g. xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue
DEFAULT_MATCHER_NAME = "matchanything-roma"
DEFAULT_DEVICE = "cuda"  # or "cpu"

# Every listed transform is fitted; the lowest p95 reprojection error wins
# (earlier entries win ties). Supported by estimate_transform: "translation",
# "euclidean", "similarity", "affine", "homography", "piecewise_affine", "polynomial".
DEFAULT_TRANSFORM_METHODS = ["homography", "affine", "piecewise_affine"]

DEFAULT_MATCH_MAX_SIDE = 1200  # longest image side (px) used for matching
DEFAULT_ROI_PAD_PX = 500       # padding of the coarse ROI, in reference-tif pixels

# Quality-control thresholds
DEFAULT_MIN_INLIERS = 10
DEFAULT_MIN_INLIER_RATIO = 0.01

# ---------- Module-level settings (read by the command-line path) ----------
REF_TIF = DEFAULT_REF_TIF
BIP_DIR = Path(DEFAULT_BIP_DIR)
OUT_DIR = Path(DEFAULT_OUT_DIR)
MATCHER_NAME = DEFAULT_MATCHER_NAME
DEVICE = DEFAULT_DEVICE
TRANSFORM_METHODS = DEFAULT_TRANSFORM_METHODS  # same list object as the default
MATCH_MAX_SIDE = DEFAULT_MATCH_MAX_SIDE
ROI_PAD_PX = DEFAULT_ROI_PAD_PX
MIN_INLIERS = DEFAULT_MIN_INLIERS
MIN_INLIER_RATIO = DEFAULT_MIN_INLIER_RATIO

# Import-time side effect: make sure the output and stats folders exist.
OUT_DIR.mkdir(parents=True, exist_ok=True)
STATS_DIR = OUT_DIR / "stats"
STATS_DIR.mkdir(parents=True, exist_ok=True)
STATS_CSV = STATS_DIR / "registration_stats.csv"
# ---------- 工具函数 ----------
def init_stats_csv(csv_path: Path):
    """Create the statistics CSV with its header row; an existing file is left untouched."""
    if csv_path.exists():
        return
    columns = [
        'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio',
        'selected_method', 'median_error', 'p95_error', 'success',
    ]
    with open(csv_path, 'w', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(columns)
def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int,
                           inlier_ratio: float, selected_method: str, median_error: float,
                           p95_error: float, success: bool):
    """Append one result row to the statistics CSV.

    Columns follow init_stats_csv. The timestamp is local time; the ratio and both
    errors are written with four decimals.
    """
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(csv_path, 'a', newline='', encoding='utf-8') as handle:
        row = [stamp, filename, num_inliers, num_matches]
        row.append(f"{inlier_ratio:.4f}")
        row.append(selected_method)
        row.extend([f"{median_error:.4f}", f"{p95_error:.4f}", success])
        csv.writer(handle).writerow(row)
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow a rasterio window by ``pad`` pixels on every side, clamped to [0, max_w] x [0, max_h].

    Edges are converted with int(). When the padded window misses the raster
    entirely, the returned width or height is non-positive; callers must check it.
    """
    left, top = (int(max(0, start - pad)) for start in (win.col_off, win.row_off))
    right = int(min(max_w, win.col_off + win.width + pad))
    bottom = int(min(max_h, win.row_off + win.height + pad))
    return rasterio.windows.Window(left, top, right - left, bottom - top)
def _rigid_fit_2d(src_pts, dst_pts):
    """Least-squares rotation + translation (no scale) mapping src_pts onto dst_pts, as a 2x3 matrix."""
    p = np.asarray(src_pts, dtype=np.float64)
    q = np.asarray(dst_pts, dtype=np.float64)
    p_mean, q_mean = p.mean(axis=0), q.mean(axis=0)
    p0, q0 = p - p_mean, q - q_mean
    # Maximising sum(q . R p) over rotations gives theta = atan2(sum(p x q), sum(p . q)).
    cross = (p0[:, 0] * q0[:, 1] - p0[:, 1] * q0[:, 0]).sum()
    dot = (p0 * q0).sum()
    theta = np.arctan2(cross, dot)
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, -s], [s, c]])
    shift = q_mean - rot @ p_mean
    return np.hstack([rot, shift[:, None]])


def estimate_transform(method, k0, k1):
    """Fit one geometric model mapping points ``k0`` onto ``k1``.

    Args:
        method: "translation", "euclidean", "similarity", "affine", "homography",
            "piecewise_affine" or "polynomial".
        k0, k1: (N, 2) arrays of corresponding (x, y) points.

    Returns:
        ``(transform_type, transform)``: ``"A"`` with a 2x3 matrix, ``"H"`` with a
        3x3 matrix, or ``"piecewise"`` / ``"polynomial"`` with a scikit-image
        transform object. ``transform`` is None when the fit fails (the type is
        then the model's type for the OpenCV estimators, otherwise None).

    Raises:
        ValueError: unknown ``method``.
    """
    if method == "translation":
        # Pure shift: mean displacement of the correspondences.
        if len(k0) == 0:
            return None, None
        dx = np.mean(k1[:, 0] - k0[:, 0])
        dy = np.mean(k1[:, 1] - k0[:, 1])
        return "A", np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)

    if method == "euclidean":
        # estimateAffinePartial2D fits a *similarity* (it includes uniform scale),
        # so use it only for RANSAC inlier selection, then fit rotation + translation.
        A_sim, inlier_mask = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        if A_sim is None:
            return "A", None
        if inlier_mask is None:
            keep = np.ones(len(k0), dtype=bool)
        else:
            keep = inlier_mask.ravel().astype(bool)
        if keep.sum() < 2:
            return "A", None
        return "A", _rigid_fit_2d(np.asarray(k0)[keep], np.asarray(k1)[keep])

    if method == "similarity":
        # Rotation + uniform scale + translation (4 DOF).
        A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "affine":
        # Full affine: rotation, anisotropic scale, shear, translation (6 DOF).
        A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "homography":
        # Projective transform (8 DOF).
        H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0)
        return "H", H

    if method in ("piecewise_affine", "polynomial"):
        if not SKIMAGE_AVAILABLE:
            return None, None
        try:
            if method == "piecewise_affine":
                tform, kind = PiecewiseAffineTransform(), "piecewise"
                ok = tform.estimate(k0, k1)
            else:
                tform, kind = PolynomialTransform(), "polynomial"
                ok = tform.estimate(k0, k1, order=2)
        except Exception:
            return None, None
        # scikit-image reports failure through the return value, not an exception.
        if ok is not None and not ok:
            return None, None
        return kind, tform

    raise ValueError(f"未知变换方法: {method}")
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Score a fitted model by its reprojection error on the correspondences.

    Returns ``(median, p95)`` of the Euclidean distance between the model's
    prediction for ``k0`` and ``k1``. ``(inf, inf)`` is returned when there is
    no model, no points, or an unrecognised ``transform_type``.
    """
    if transform is None or len(k0) == 0:
        return np.inf, np.inf

    if transform_type in ("A", "H"):
        homogeneous = np.hstack([k0, np.ones((k0.shape[0], 1), dtype=np.float32)]).T
        projected = transform @ homogeneous
        if transform_type == "H":
            # Perspective division (small epsilon guards against w == 0).
            projected /= (projected[2:3, :] + 1e-6)
            predicted = projected[:2, :].T
        else:
            predicted = projected.T
    elif transform_type in ("piecewise", "polynomial"):
        predicted = transform(k0)
    else:
        return np.inf, np.inf

    residuals = np.sqrt(((predicted - k1) ** 2).sum(axis=1))
    return float(np.median(residuals)), float(np.percentile(residuals, 95))
def _norm01_hw(x: np.ndarray) -> np.ndarray:
"""对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性"""
x = x.astype(np.float32, copy=False)
p2 = float(np.percentile(x, 2))
p98 = float(np.percentile(x, 98))
y = (x - p2) / (p98 - p2 + 1e-6)
return np.clip(y, 0.0, 1.0)
def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)):
    """Wrap a numpy (H, W) array as a float32 SimpleITK image in pixel units.

    Physical space is the pixel grid itself: spacing 1, identity direction,
    origin ``(origin_xy[0], origin_xy[1])`` as (x0, y0).
    """
    x0, y0 = float(origin_xy[0]), float(origin_xy[1])
    image = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False))
    image.SetSpacing((1.0, 1.0))
    image.SetOrigin((x0, y0))
    image.SetDirection((1.0, 0.0, 0.0, 1.0))
    return image
def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10):
"""用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内"""
min_x = int(np.floor(k1_global[:, 0].min())) - pad
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
min_y = int(np.floor(k1_global[:, 1].min())) - pad
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
min_x = max(0, min_x)
min_y = max(0, min_y)
max_x = min(ref_w, max_x)
max_y = min(ref_h, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
return min_x, min_y, bbox_w, bbox_h
def _cast_band_to_dtype(band, dtype, dst_nodata, src_nodata):
    """Convert a resampled float band to the output dtype.

    Integer targets are rounded to nearest and clipped to the dtype range. When
    the source declares nodata, pixels equal to ``dst_nodata`` before the cast
    are forced back to exactly ``dst_nodata`` afterwards.
    """
    dtype = np.dtype(dtype)
    if not np.issubdtype(dtype, np.integer):
        return band.astype(dtype)
    nodata_mask = (band == dst_nodata) if src_nodata is not None else None
    info = np.iinfo(dtype)
    # Round instead of truncating: bilinear output is fractional and a bare
    # astype() would bias every pixel toward zero.
    out = np.clip(np.rint(band), info.min, info.max).astype(dtype)
    if nodata_mask is not None:
        out[nodata_mask] = dst_nodata
    return out


def _write_registered_bip(out_path, src, ref_dataset, ref_crs, bbox, resample_band):
    """Write every band of ``src`` as an ENVI/BIP raster on a sub-window of the reference grid.

    ``bbox`` is ``(min_x, min_y, width, height)`` in reference pixels.
    ``resample_band(src_band, bbox_transform, src_nodata, dst_nodata)`` receives one
    float32 source band and must return a float ``(height, width)`` array.
    """
    min_x, min_y, bbox_w, bbox_h = bbox
    bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
    bbox_transform = ref_dataset.window_transform(bbox_window)
    src_nodata = src.nodata
    dst_nodata = src_nodata if src_nodata is not None else 0
    out_profile = ref_dataset.profile.copy()
    out_profile.update(
        driver="ENVI",
        dtype=src.dtypes[0],
        height=bbox_h,
        width=bbox_w,
        count=src.count,
        transform=bbox_transform,
        crs=ref_crs,
        interleave="bip",
        compress=None,
        nodata=dst_nodata,
    )
    with rasterio.open(out_path, "w", **out_profile) as out_ds:
        for b in range(1, src.count + 1):
            src_band = src.read(b).astype(np.float32)
            dst_band = resample_band(src_band, bbox_transform, src_nodata, dst_nodata)
            out_ds.write(_cast_band_to_dtype(dst_band, out_profile["dtype"], dst_nodata, src_nodata), b)


def _project_corners(matrix3, width, height, eps=0.0):
    """Map the corners (0,0), (w,0), (w,h), (0,h) through a 3x3 matrix; returns (4, 2) x/y."""
    corners = np.array([[0, width, width, 0],
                        [0, 0, height, height],
                        [1, 1, 1, 1]], dtype=np.float64)
    mapped = np.asarray(matrix3, dtype=np.float64) @ corners
    return (mapped[:2] / (mapped[2:] + eps)).T


def _register_affine(src, ref_dataset, ref_crs, affine_2x3, out_path):
    """Resample ``src`` with a 2x3 affine (src full-res px -> ref global px) via rasterio.reproject.

    Returns False (nothing written) when the warped footprint misses the reference raster.
    """
    A3 = np.eye(3, dtype=np.float64)
    A3[:2, :] = affine_2x3
    rt = ref_dataset.transform
    Rt = np.array([[rt.a, rt.b, rt.c],
                   [rt.d, rt.e, rt.f],
                   [0.0, 0.0, 1.0]], dtype=np.float64)
    M_map = Rt @ A3  # src pixel -> map coordinates
    src_affine = Affine(M_map[0, 0], M_map[0, 1], M_map[0, 2],
                        M_map[1, 0], M_map[1, 1], M_map[1, 2])
    # Rt^-1 @ M_map == A3, so the footprint in reference pixels is A3 applied to the corners.
    bbox = _compute_bbox_from_k1(_project_corners(A3, src.width, src.height),
                                 ref_dataset.width, ref_dataset.height, pad=10)
    if bbox[2] <= 0 or bbox[3] <= 0:
        return False

    def resample(src_band, bbox_transform, src_nodata, dst_nodata):
        dst_band = np.zeros((bbox[3], bbox[2]), dtype=np.float32)
        reproject(
            source=src_band,
            destination=dst_band,
            src_transform=src_affine,
            src_crs=ref_crs,
            dst_transform=bbox_transform,
            dst_crs=ref_crs,
            src_nodata=src_nodata,
            dst_nodata=dst_nodata,
            resampling=Resampling.bilinear,
        )
        return dst_band

    _write_registered_bip(out_path, src, ref_dataset, ref_crs, bbox, resample)
    return True


def _register_homography(src, ref_dataset, ref_crs, h_full, out_path):
    """Resample ``src`` with a 3x3 homography (src full-res px -> ref global px) via cv2.warpPerspective.

    Returns False (nothing written) when the warped footprint misses the reference raster.
    """
    ref_corners = _project_corners(h_full, src.width, src.height, eps=1e-6)
    bbox = _compute_bbox_from_k1(ref_corners, ref_dataset.width, ref_dataset.height, pad=10)
    min_x, min_y, bbox_w, bbox_h = bbox
    if bbox_w <= 0 or bbox_h <= 0:
        return False
    # Re-express the homography in output-window pixels.
    shift = np.array([[1.0, 0.0, -min_x], [0.0, 1.0, -min_y], [0.0, 0.0, 1.0]])
    h_sub = shift @ np.asarray(h_full, dtype=np.float64)

    def resample(src_band, _bbox_transform, _src_nodata, dst_nodata):
        return cv2.warpPerspective(
            src_band, h_sub, (bbox_w, bbox_h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=dst_nodata,
        )

    _write_registered_bip(out_path, src, ref_dataset, ref_crs, bbox, resample)
    return True


def _make_inside_test(points, min_x, min_y, max_x, max_y):
    """Return a predicate mapping (N, 2) x/y to a bool mask.

    The test region is the convex hull of ``points`` when scipy and matplotlib are
    available and the hull can be built; otherwise it is the given rectangle.
    """
    if MATPLOTLIB_SCIPY_AVAILABLE:
        try:
            hull = ConvexHull(points)
            region = MplPath(points[hull.vertices])
        except Exception:
            region = MplPath(np.array([[min_x, min_y], [max_x, min_y],
                                       [max_x, max_y], [min_x, max_y]], dtype=float))
        return region.contains_points

    def inside_rect(xy):
        return ((xy[:, 0] >= min_x) & (xy[:, 0] <= max_x) &
                (xy[:, 1] >= min_y) & (xy[:, 1] <= max_y))
    return inside_rect


def _register_skimage(src, ref_dataset, ref_crs, transform_type, transform, k0_full, k1_global, out_path):
    """Resample ``src`` with a scikit-image piecewise-affine or polynomial model.

    The output window is the padded bbox of the reference-side tie points. Output
    pixels outside their convex hull (or that bbox as a fallback) are set to
    nodata instead of being extrapolated. Returns False when the window is empty.
    """
    from skimage.transform import warp

    bbox = _compute_bbox_from_k1(k1_global, ref_dataset.width, ref_dataset.height, pad=10)
    min_x, min_y, bbox_w, bbox_h = bbox
    if bbox_w <= 0 or bbox_h <= 0:
        return False

    if transform_type in ("polynomial", "polynomial_order3"):
        # PolynomialTransform has no analytic inverse: fit ref -> src directly.
        ref_to_src = PolynomialTransform()
        order = 2 if transform_type == "polynomial" else 3
        ok = ref_to_src.estimate(k1_global, k0_full, order=order)
        if ok is not None and not ok:
            raise RuntimeError("多项式逆变换估计失败")
    else:
        ref_to_src = transform.inverse

    inside = _make_inside_test(k1_global, min_x, min_y, min_x + bbox_w, min_y + bbox_h)
    offset = np.array([min_x, min_y], dtype=np.float64)

    def inverse_map(xy_out):
        # skimage.transform.warp passes and expects (col, row) == (x, y) pairs.
        xy_ref = np.asarray(xy_out, dtype=np.float64) + offset
        xy_src = np.full_like(xy_ref, -1.0)  # off-image sentinel -> filled with cval
        mask = inside(xy_ref)
        if np.any(mask):
            xy_src[mask] = ref_to_src(xy_ref[mask])
        return xy_src

    def resample(src_band, _bbox_transform, _src_nodata, dst_nodata):
        return warp(
            src_band,
            inverse_map=inverse_map,
            output_shape=(bbox_h, bbox_w),
            mode='constant',
            cval=dst_nodata,
            preserve_range=True,
        ).astype(np.float32)

    _write_registered_bip(out_path, src, ref_dataset, ref_crs, bbox, resample)
    return True


def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path, *,
                       roi_pad_px=None, match_max_side=None, transform_methods=None,
                       min_inliers=None, min_inlier_ratio=None):
    """Register one .bip raster onto the reference raster and write ``<stem>_registered.bip``.

    Steps:
      1. Project the source bounds into the reference CRS and read a padded ROI.
      2. Match the first (up to) three bands of both images on downscaled copies.
      3. Fit every candidate transform and keep the lowest p95 reprojection error.
      4. Resample into the smallest reference sub-window:
         - affine -> rasterio.reproject;
         - homography -> cv2.warpPerspective;
         - piecewise / polynomial -> skimage.warp.
         If the homography or skimage path raises, fall back to a RANSAC affine fit.

    Every call appends one row to ``stats_csv``.

    Args:
        bip_path: source raster path (opened with rasterio).
        ref_dataset: open rasterio dataset of the reference image.
        matcher: callable ``matcher(img0, img1)`` returning a mapping with
            "num_inliers", "matched_kpts0", "inlier_kpts0" and "inlier_kpts1".
        out_dir: destination folder.
        stats_csv: statistics CSV path.
        roi_pad_px, match_max_side, transform_methods, min_inliers, min_inlier_ratio:
            optional overrides. ``None`` uses the module-level ROI_PAD_PX,
            MATCH_MAX_SIDE, TRANSFORM_METHODS, MIN_INLIERS and MIN_INLIER_RATIO
            read at call time.

    Returns:
        True if an output file was written, otherwise False. Exceptions are logged, not raised.
    """
    roi_pad_px = ROI_PAD_PX if roi_pad_px is None else roi_pad_px
    match_max_side = MATCH_MAX_SIDE if match_max_side is None else match_max_side
    transform_methods = TRANSFORM_METHODS if transform_methods is None else transform_methods
    min_inliers = MIN_INLIERS if min_inliers is None else min_inliers
    min_inlier_ratio = MIN_INLIER_RATIO if min_inlier_ratio is None else min_inlier_ratio

    # Initialised before the try so the exception handler can always log them.
    num_inliers = 0
    num_matches = 0
    inlier_ratio = 0.0
    median_error = float('inf')
    p95_error = float('inf')

    def record(method, ok):
        # Closure: reads the statistics variables above at call time.
        log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
                               inlier_ratio, method, median_error, p95_error, ok)

    try:
        with rasterio.open(bip_path) as src:
            logger.info(f"处理文件: {bip_path.name}")
            if src.crs is None:
                logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
                src_crs = ref_dataset.crs
            else:
                src_crs = src.crs
            ref_crs = ref_dataset.crs
            if ref_crs is None:
                raise RuntimeError("参考文件缺少CRS信息")

            # 1) Coarse ROI: source bounds in the reference CRS, padded.
            bounds = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
            win = _expand_window(from_bounds(*bounds, transform=ref_dataset.transform),
                                 roi_pad_px, ref_dataset.width, ref_dataset.height)
            if win.width <= 0 or win.height <= 0:
                logger.warning(f"无重叠区域: {bip_path.name}")
                record("no_overlap", False)
                return False

            # 2) _to_3ch_float01 only uses the first three bands, so read just those
            #    (avoids loading every band of a hyperspectral cube).
            src_img = _to_3ch_float01(src.read(indexes=list(range(1, min(src.count, 3) + 1))))
            ref_img = _to_3ch_float01(ref_dataset.read(
                indexes=list(range(1, min(ref_dataset.count, 3) + 1)), window=win))

            # 3) Match on downscaled copies (faster and more stable).
            src_small = _downscale_chw(src_img, match_max_side)
            ref_small = _downscale_chw(ref_img, match_max_side)
            logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
            result = matcher(src_small, ref_small)
            num_inliers = int(result["num_inliers"])
            num_matches = len(result["matched_kpts0"])
            inlier_ratio = (num_inliers / num_matches) if num_matches else 0.0
            logger.info(f"匹配结果: 内点={num_inliers}, 匹配点={num_matches}, 内点比例={inlier_ratio:.2f}")
            if num_inliers < min_inliers or inlier_ratio < min_inlier_ratio:
                logger.warning(f"匹配质量不足: {bip_path.name}")
                record("failed_quality_check", False)
                return False

            # 4) Lift inliers to full resolution: source pixels and global reference pixels.
            src_scale = np.array([src_img.shape[2] / src_small.shape[2],
                                  src_img.shape[1] / src_small.shape[1]], dtype=np.float32)
            ref_scale = np.array([ref_img.shape[2] / ref_small.shape[2],
                                  ref_img.shape[1] / ref_small.shape[1]], dtype=np.float32)
            k0_full = result["inlier_kpts0"].astype(np.float32) * src_scale
            k1_global = (result["inlier_kpts1"].astype(np.float32) * ref_scale
                         + np.array([win.col_off, win.row_off], dtype=np.float32))

            # 5) Fit every candidate and keep the lowest p95 reprojection error.
            best_type = best_transform = best_method = None
            best_p95 = best_median = np.inf
            for method in transform_methods:
                try:
                    transform_type, transform = estimate_transform(method, k0_full, k1_global)
                except cv2.error as exc:
                    # A degenerate point set for one model should not abort the whole file.
                    logger.debug(f"方法 {method} 估计失败: {exc}")
                    continue
                if transform is None:
                    continue
                med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global)
                logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
                if p95_err < best_p95:
                    best_type, best_transform, best_method = transform_type, transform, method
                    best_p95, best_median = p95_err, med_err

            if best_transform is None:
                logger.warning(f"所有变换方法都失败: {bip_path.name}")
                record("failed_transform", False)
                return False
            median_error, p95_error = best_median, best_p95
            logger.info(f"选用变换: {best_method} ({best_type}), 误差 p95={best_p95:.2f}")

            out_path = out_dir / f"{bip_path.stem}_registered.bip"

            # 6) Resample with the selected model.
            if best_type == "A":
                if not _register_affine(src, ref_dataset, ref_crs, best_transform, out_path):
                    logger.warning(f"最小外接矩形无效: {bip_path.name}")
                    record("invalid_bbox", False)
                    return False
                logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
                record(best_method, True)
                return True

            if best_type == "H":
                try:
                    written = _register_homography(src, ref_dataset, ref_crs, best_transform, out_path)
                except Exception as e:
                    logger.warning(f"单应变换异常: {e}")  # continue to the affine fallback
                else:
                    if not written:
                        logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}")
                        record("invalid_bbox", False)
                        return False
                    logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}")
                    record(best_method, True)
                    return True
            elif best_type in ("piecewise", "polynomial", "polynomial_order3"):
                try:
                    written = _register_skimage(src, ref_dataset, ref_crs, best_type, best_transform,
                                                k0_full, k1_global, out_path)
                except Exception as e:
                    logger.warning(f"{best_type}变换异常: {e}")  # continue to the affine fallback
                else:
                    if not written:
                        logger.warning(f"{best_type}变换最小外接矩形无效: {bip_path.name}")
                        record("invalid_bbox", False)
                        return False
                    logger.info(f"成功配准({best_type}): {bip_path.name} -> {out_path.name}")
                    record(best_method, True)
                    return True

            # 7) Fallback: RANSAC affine. k0_full / k1_global are already full-res
            #    source px / global reference px, so no extra scaling or offset is applied.
            A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC,
                                                 ransacReprojThreshold=3.0)
            if A_fallback is None:
                logger.warning(f"仿射回退也失败: {bip_path.name}")
                record("failed_fallback", False)
                return False
            if not _register_affine(src, ref_dataset, ref_crs, A_fallback, out_path):
                logger.warning(f"最小外接矩形无效: {bip_path.name}")
                record("invalid_bbox", False)
                return False
            logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
            record("affine_fallback", True)
            return True
    except Exception as e:
        logger.error(f"处理失败 {bip_path.name}: {str(e)}")
        try:
            record("exception", False)
        except Exception:
            # Best effort: a broken stats file must not mask the original error.
            logger.debug("统计信息记录失败", exc_info=True)
        return False
# ---------- 主逻辑 ----------
def run_batch(config: Config, on_progress=None, on_log=None, stop_event=None):
    """Register every ``*.bip`` file in ``config.bip_dir`` against ``config.ref_tif``.

    Args:
        config: run parameters. The matching and quality-control fields are
            published to the module-level settings that process_bip_to_tif reads.
        on_progress: optional callback ``(current_idx, total, filename)``. It is
            called before each file and once more with ``(total, total, "完成")``.
        on_log: optional callback ``(message)``. Each message is also sent to
            ``logger.info``.
        stop_event: optional object with ``is_set()`` (the GUI passes a
            ``threading.Event``). It is checked before each file to allow cancellation.
    """
    # process_bip_to_tif reads its tuning parameters from these module globals.
    global TRANSFORM_METHODS, MATCH_MAX_SIDE, ROI_PAD_PX, MIN_INLIERS, MIN_INLIER_RATIO

    def log(message):
        if on_log:
            on_log(message)
        logger.info(message)

    log("开始批量配准处理...")

    # Validate inputs before doing any expensive work.
    if not Path(config.ref_tif).exists():
        log(f"参考文件不存在: {config.ref_tif}")
        return
    if not Path(config.bip_dir).exists():
        log(f"BIP文件夹不存在: {config.bip_dir}")
        return

    # Publish this run's settings. Without this, values from `config` (e.g. those
    # entered in the GUI) were silently ignored in favour of the import-time defaults.
    TRANSFORM_METHODS = list(config.transform_methods)
    MATCH_MAX_SIDE = config.match_max_side
    ROI_PAD_PX = config.roi_pad_px
    MIN_INLIERS = config.min_inliers
    MIN_INLIER_RATIO = config.min_inlier_ratio

    out_dir = Path(config.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stats_dir = out_dir / "stats"
    stats_dir.mkdir(parents=True, exist_ok=True)
    stats_csv = stats_dir / "registration_stats.csv"
    init_stats_csv(stats_csv)
    log(f"统计信息将保存到: {stats_csv}")

    log(f"初始化匹配器: {config.matcher_name} on {config.device}")
    matcher = get_matcher(config.matcher_name, device=config.device)

    with rasterio.open(config.ref_tif) as ref:
        log(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
        # Sorted so the processing order (and the stats CSV) is reproducible.
        bip_files = sorted(Path(config.bip_dir).glob("*.bip"))
        total = len(bip_files)
        log(f"找到 {total} 个 .bip 文件")

        success_count = 0
        for i, bip_path in enumerate(bip_files):
            if stop_event and stop_event.is_set():
                log("处理被用户取消")
                break
            if on_progress:
                on_progress(i, total, bip_path.name)
            if process_bip_to_tif(bip_path, ref, matcher, out_dir, stats_csv):
                success_count += 1

        if on_progress:
            on_progress(total, total, "完成")
        log(f"处理完成: {success_count}/{total} 个文件成功配准")
def main():
    """Command-line entry point: run the batch registration with the module-level defaults.

    Everything else happens inside ``run_batch``: input validation,
    stats-CSV initialisation, matcher loading and per-file processing. This
    function only assembles the ``Config``, so that setup is not repeated here.
    """
    config = Config(
        ref_tif=REF_TIF,
        bip_dir=BIP_DIR,
        out_dir=OUT_DIR,
        matcher_name=MATCHER_NAME,
        device=DEVICE,
        transform_methods=TRANSFORM_METHODS,
        match_max_side=MATCH_MAX_SIDE,
        roi_pad_px=ROI_PAD_PX,
        min_inliers=MIN_INLIERS,
        min_inlier_ratio=MIN_INLIER_RATIO
    )
    run_batch(config)
# ---------- GUI 相关 ----------
class QueueHandler(logging.Handler):
    """Logging handler that pushes each formatted record onto a queue.

    The GUI's ``check_log_queue`` drains this queue on the Tk event loop and
    appends the messages to the log pane.
    """

    def __init__(self, log_queue):
        # log_queue: any object with a ``put(str)`` method (a ``queue.Queue`` in the GUI).
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record):
        """Format *record* and enqueue the resulting string.

        Following the standard ``logging.Handler`` convention, any error while
        formatting or enqueueing goes to ``handleError``. It does not propagate
        into the code that issued the log call, which may be a worker thread.
        """
        try:
            self.log_queue.put(self.format(record))
        except Exception:
            self.handleError(record)
class RegistrationGUI:
    """Tkinter front-end that collects a ``Config`` and runs ``run_batch`` in a worker thread.

    Log records from the module ``logger`` reach the log pane through a
    ``QueueHandler`` attached in ``__init__``. ``check_log_queue`` drains that
    queue every 100 ms on the Tk event loop.
    """

    def __init__(self, root):
        """Build the window on *root*, hook logging into the GUI and start polling the log queue."""
        self.root = root
        self.root.title("遥感影像批量配准工具")
        self.root.geometry("1000x800")
        # Queue shared with the logging handler, plus the worker's cancellation flag.
        self.log_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.processing_thread = None
        # Route the module logger into the GUI log pane.
        queue_handler = QueueHandler(self.log_queue)
        queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(queue_handler)
        logger.setLevel(logging.INFO)
        # Build the widgets.
        self.create_widgets()
        # Start the periodic log-queue poll.
        self.check_log_queue()

    def create_widgets(self):
        """Create and lay out all widgets: config panel, controls, progress bar and log pane."""
        # Main frame
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        # Configuration panel
        config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5")
        config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10))
        # Input/output path selection
        ttk.Label(config_frame, text="参考TIF文件:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
        self.ref_tif_var = tk.StringVar(value=DEFAULT_REF_TIF)
        ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif).grid(row=0, column=2)
        ttk.Label(config_frame, text="BIP文件夹:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5))
        self.bip_dir_var = tk.StringVar(value=DEFAULT_BIP_DIR)
        ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir).grid(row=1, column=2)
        ttk.Label(config_frame, text="输出文件夹:").grid(row=2, column=0, sticky=tk.W, padx=(0, 5))
        self.out_dir_var = tk.StringVar(value=DEFAULT_OUT_DIR)
        ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50).grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir).grid(row=2, column=2)
        # Matcher selection. The combobox keeps its default editable state,
        # so other names can also be typed in.
        ttk.Label(config_frame, text="匹配算法:").grid(row=3, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
        self.matcher_var = tk.StringVar(value=DEFAULT_MATCHER_NAME)
        matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47)
        matcher_combo['values'] = [
            "liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer",
            "matchanything-eloftr", "matchanything-roma", "matchformer",
            "sift-lightglue", "superpoint-lightglue", "disk-lightglue",
            "aliked-lightglue", "doghardnet-lightglue", "roma", "romav2",
            "tiny-roma", "dedode", "steerers", "affine-steerers",
            "dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue",
            "r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat",
            "xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm",
            "gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx",
            "dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx",
            "sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma",
            "minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr",
            "minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd",
            "lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint",
            "xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm",
            "xfeat-star-steerers-learned"
        ]
        matcher_combo.grid(row=3, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        # Device selection
        ttk.Label(config_frame, text="设备:").grid(row=4, column=0, sticky=tk.W, padx=(0, 5))
        self.device_var = tk.StringVar(value=DEFAULT_DEVICE)
        device_frame = ttk.Frame(config_frame)
        device_frame.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E))
        ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda").pack(side=tk.LEFT)
        ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu").pack(side=tk.LEFT)
        # Transform methods. Listbox order is the priority order; the selection
        # is the set of enabled methods.
        ttk.Label(config_frame, text="变换方法 (按优先级):").grid(row=5, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
        transform_frame = ttk.Frame(config_frame)
        transform_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False)
        transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial"]
        # NOTE(review): the initial listbox order is fixed by the list above. The
        # order of DEFAULT_TRANSFORM_METHODS only decides which entries start selected.
        for idx, method in enumerate(transform_methods):
            self.transform_listbox.insert(tk.END, method)
            if method in DEFAULT_TRANSFORM_METHODS:
                self.transform_listbox.selection_set(idx)
        scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview)
        self.transform_listbox.configure(yscrollcommand=scrollbar.set)
        self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        # Reorder buttons
        button_frame = ttk.Frame(transform_frame)
        button_frame.pack(side=tk.RIGHT, padx=(5, 0))
        ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2))
        ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X)
        # Numeric parameters
        param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5")
        param_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0))
        ttk.Label(param_frame, text="匹配最大边长:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
        self.match_max_side_var = tk.IntVar(value=DEFAULT_MATCH_MAX_SIDE)
        ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10).grid(row=0, column=1, sticky=tk.W)
        ttk.Label(param_frame, text="ROI填充像素:").grid(row=0, column=2, sticky=tk.W, padx=(10, 5))
        self.roi_pad_px_var = tk.IntVar(value=DEFAULT_ROI_PAD_PX)
        ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10).grid(row=0, column=3, sticky=tk.W)
        ttk.Label(param_frame, text="最少内点数:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0))
        self.min_inliers_var = tk.IntVar(value=DEFAULT_MIN_INLIERS)
        ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10).grid(row=1, column=1, sticky=tk.W, pady=(5, 0))
        ttk.Label(param_frame, text="最少内点比例:").grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0))
        self.min_inlier_ratio_var = tk.DoubleVar(value=DEFAULT_MIN_INLIER_RATIO)
        ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10).grid(row=1, column=3, sticky=tk.W, pady=(5, 0))
        # Start/stop buttons
        control_frame = ttk.Frame(main_frame)
        control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing)
        self.start_btn.pack(side=tk.LEFT, padx=(0, 10))
        self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED)
        self.stop_btn.pack(side=tk.LEFT)
        # Progress bar and status label
        progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5")
        progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100)
        self.progress_bar.pack(fill=tk.X, pady=(0, 5))
        self.progress_label = ttk.Label(progress_frame, text="准备就绪")
        self.progress_label.pack(anchor=tk.W)
        # Log pane
        log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0))
        # Log text widget with scrollbar
        log_text_frame = ttk.Frame(log_frame)
        log_text_frame.pack(fill=tk.BOTH, expand=True)
        self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD)
        scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview)
        self.log_text.configure(yscrollcommand=scrollbar.set)
        self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
        # Log buttons
        log_btn_frame = ttk.Frame(log_frame)
        log_btn_frame.pack(fill=tk.X, pady=(5, 0))
        ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5))
        ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT)
        # Grid weights so the log pane grows with the window
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)

    def select_ref_tif(self):
        """Ask for the reference TIF file and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择参考TIF文件",
            filetypes=[("TIF files", "*.tif"), ("All files", "*.*")]
        )
        if filename:
            self.ref_tif_var.set(filename)

    def select_bip_dir(self):
        """Ask for the folder containing the .bip files."""
        dirname = filedialog.askdirectory(title="选择BIP文件夹")
        if dirname:
            self.bip_dir_var.set(dirname)

    def select_out_dir(self):
        """Ask for the output folder."""
        dirname = filedialog.askdirectory(title="选择输出文件夹")
        if dirname:
            self.out_dir_var.set(dirname)

    def _move_selected(self, offset):
        """Move the first selected transform method by *offset* positions, keeping it selected."""
        selection = self.transform_listbox.curselection()
        if not selection:
            return
        idx = selection[0]
        new_idx = idx + offset
        if not 0 <= new_idx < self.transform_listbox.size():
            return
        text = self.transform_listbox.get(idx)
        self.transform_listbox.delete(idx)
        self.transform_listbox.insert(new_idx, text)
        self.transform_listbox.selection_set(new_idx)

    def move_up(self):
        """Move the first selected transform method one position up (higher priority)."""
        self._move_selected(-1)

    def move_down(self):
        """Move the first selected transform method one position down (lower priority)."""
        self._move_selected(1)

    def start_processing(self):
        """Validate the form, build a ``Config`` and launch ``run_batch`` in a daemon thread."""
        if self.processing_thread and self.processing_thread.is_alive():
            messagebox.showwarning("警告", "处理正在进行中")
            return
        # Selected transform methods, in listbox (priority) order.
        selected_indices = self.transform_listbox.curselection()
        if not selected_indices:
            messagebox.showwarning("警告", "请至少选择一种变换方法")
            return
        transform_methods = [self.transform_listbox.get(idx) for idx in selected_indices]
        # IntVar/DoubleVar.get() raise TclError when the entry text is not a valid
        # number. Report it to the user instead of letting it escape the Tk callback.
        try:
            match_max_side = self.match_max_side_var.get()
            roi_pad_px = self.roi_pad_px_var.get()
            min_inliers = self.min_inliers_var.get()
            min_inlier_ratio = self.min_inlier_ratio_var.get()
        except tk.TclError as e:
            messagebox.showwarning("警告", f"参数设置无效: {e}")
            return
        config = Config(
            ref_tif=self.ref_tif_var.get(),
            bip_dir=self.bip_dir_var.get(),
            out_dir=self.out_dir_var.get(),
            matcher_name=self.matcher_var.get(),
            device=self.device_var.get(),
            transform_methods=transform_methods,
            match_max_side=match_max_side,
            roi_pad_px=roi_pad_px,
            min_inliers=min_inliers,
            min_inlier_ratio=min_inlier_ratio
        )
        # Reset the cancellation flag from any previous run.
        self.stop_event.clear()
        # Disable Start and enable Stop while the worker runs.
        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.progress_var.set(0)
        self.progress_label.config(text="正在初始化...")
        # Run the batch in a background thread so the UI stays responsive.
        self.processing_thread = threading.Thread(
            target=self.run_processing,
            args=(config,),
            daemon=True
        )
        self.processing_thread.start()

    def stop_processing(self):
        """Request cancellation; the worker stops before starting the next file."""
        if self.processing_thread and self.processing_thread.is_alive():
            self.stop_event.set()
            self.progress_label.config(text="正在停止...")

    def run_processing(self, config):
        """Worker-thread body: run the batch, then schedule the UI reset on the Tk event loop."""
        try:
            # on_log is deliberately not passed. run_batch already sends every
            # message to ``logger``, and the QueueHandler attached in __init__
            # feeds those records into the log pane. Passing self.on_log as well
            # would show each message twice.
            run_batch(config, self.on_progress, None, self.stop_event)
        except Exception as e:
            # logger.exception records the traceback; the QueueHandler shows it in the log pane.
            logger.exception(f"处理过程中发生错误: {e}")
        finally:
            final_text = "处理已停止" if self.stop_event.is_set() else "处理完成"
            self.root.after(0, self._finish_processing, final_text)

    def _finish_processing(self, status_text):
        """Restore the button states and show *status_text* (scheduled via ``root.after``)."""
        self.start_btn.config(state=tk.NORMAL)
        self.stop_btn.config(state=tk.DISABLED)
        self.progress_label.config(text=status_text)

    def on_progress(self, current, total, filename):
        """Progress callback for ``run_batch``; schedules the widget updates via ``root.after``."""
        if total > 0:
            progress = (current / total) * 100
            self.root.after(0, lambda: self.progress_var.set(progress))
        self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})"))

    def on_log(self, message):
        """Log callback that enqueues *message* for the log pane.

        It is not passed to ``run_batch`` by ``run_processing``, because those
        messages already arrive through the logger's QueueHandler.
        """
        self.log_queue.put(message)

    def check_log_queue(self):
        """Drain pending log messages into the text widget, then re-arm itself in 100 ms."""
        try:
            while True:
                message = self.log_queue.get_nowait()
                self.log_text.insert(tk.END, message + '\n')
                self.log_text.see(tk.END)
        except queue.Empty:
            pass
        self.root.after(100, self.check_log_queue)

    def clear_log(self):
        """Clear the log pane."""
        self.log_text.delete(1.0, tk.END)

    def save_log(self):
        """Save the log pane's contents to a user-chosen UTF-8 text file."""
        filename = filedialog.asksaveasfilename(
            title="保存日志",
            defaultextension=".txt",
            filetypes=[("Text files", "*.txt"), ("All files", "*.*")]
        )
        if not filename:
            return
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(self.log_text.get(1.0, tk.END))
        except OSError as e:
            messagebox.showerror("错误", f"保存日志失败: {e}")
def create_gui():
    """Open the registration window and run the Tk event loop until it is closed."""
    main_window = tk.Tk()
    _gui = RegistrationGUI(main_window)  # keep a reference while the event loop runs
    main_window.mainloop()
if __name__ == "__main__":
    # "--cli" as the first argument runs the batch headless with the module
    # defaults. Anything else, including no arguments, opens the GUI.
    if sys.argv[1:2] == ["--cli"]:
        main()
    else:
        create_gui()