Files
StripStitch/StripStitch.py
2026-04-22 09:26:39 +08:00

2218 lines
98 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了
"""
import sys
import os
# Fix for PyInstaller GUI apps: ensure stdout/stderr are never None.
# Without a console both streams can be None, and libraries that print
# progress (e.g. PyTorch while downloading) then fail with
# "'NoneType' object has no attribute 'write'".
# The null sink is opened as UTF-8 with errors="replace" so that characters
# the locale code page cannot encode (progress-bar glyphs, emoji, lone
# surrogates) are discarded instead of raising UnicodeEncodeError.
# The handles are intentionally kept open for the lifetime of the process.
if sys.stdout is None:
    sys.stdout = open(os.devnull, 'w', encoding='utf-8', errors='replace')
if sys.stderr is None:
    sys.stderr = open(os.devnull, 'w', encoding='utf-8', errors='replace')
from pathlib import Path
def _early_pyinstaller_hf_env():
    """Point Hugging Face at the model cache bundled with a PyInstaller build.

    Must run before ``import vismatch`` because ``vismatch/__init__.py``
    imports ``huggingface_hub`` immediately.

    Does nothing unless ``sys._MEIPASS`` exists.  Otherwise the first ``hub``
    directory (searched under the bundle directory and next to the
    executable) that contains a sub-directory with "vismatch" in its name
    is used: ``HF_HOME``, ``HUGGINGFACE_HUB_CACHE`` and
    ``TRANSFORMERS_OFFLINE`` are set if not already present, and
    ``HF_HUB_OFFLINE`` is forced to "1".
    """
    if hasattr(sys, "_MEIPASS"):
        bundle_dir = Path(sys._MEIPASS)
        launcher_dir = Path(sys.executable).resolve().parent

        def _holds_vismatch_models(hub_dir):
            # Unreadable directories are treated like missing ones.
            try:
                if not hub_dir.exists():
                    return False
                return any(
                    "vismatch" in entry.name.lower()
                    for entry in hub_dir.iterdir()
                    if entry.is_dir()
                )
            except OSError:
                return False

        search_order = (
            bundle_dir / "hub",
            bundle_dir / "_internal" / "hub",
            launcher_dir / "_internal" / "hub",
            launcher_dir / "hub",
        )
        chosen = next(
            (hub_dir for hub_dir in search_order if _holds_vismatch_models(hub_dir)),
            None,
        )
        if chosen is not None:
            os.environ.setdefault("HF_HOME", str(chosen.parent))
            os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(chosen))
            os.environ["HF_HUB_OFFLINE"] = "1"
            os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
_early_pyinstaller_hf_env()
import numpy as np
import cv2
import rasterio
import csv
from datetime import datetime
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
from vismatch.viz import plot_matches, plot_keypoints
import logging
import threading
import queue
import sys
import traceback
import types
from dataclasses import dataclass
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
try:
from tif_caijain import mask_data_by_binary_mask
TIF_MASK_AVAILABLE = True
except Exception:
TIF_MASK_AVAILABLE = False
try:
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
try:
from matplotlib.path import Path as MplPath
from scipy.spatial import ConvexHull
MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
MATPLOTLIB_SCIPY_AVAILABLE = False
MplPath = None
logging.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")
try:
import SimpleITK as sitk
SITK_AVAILABLE = True
except ImportError:
SITK_AVAILABLE = False
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
try:
import pirt
PIRT_AVAILABLE = True
except ImportError:
PIRT_AVAILABLE = False
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
try:
from scipy.interpolate import Rbf
SCIPY_AVAILABLE = True
except ImportError:
SCIPY_AVAILABLE = False
logging.warning("scipy 不可用,将跳过 TPS 变换")
# Logging setup.
# The optional-dependency fallbacks earlier in this module call
# logging.warning() when an import fails.  That call implicitly installs a
# default root handler, after which a plain basicConfig() is a silent no-op
# and the INFO level / format below would never be applied.  force=True
# (Python 3.8+) replaces any pre-existing root handlers so this configuration
# always takes effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)
def _ensure_pyinstaller_third_party_paths():
    """Expose vismatch's bundled third-party sources inside a PyInstaller build.

    Does nothing unless ``sys._MEIPASS`` is present.  Otherwise it:

    1. picks the first existing ``third_party`` directory from a fixed list
       of locations under the bundle directory and the executable directory;
    2. picks a MatchAnything root (a directory with a ``src`` child), falling
       back to walking the ``third_party`` tree, and puts it at the front of
       ``sys.path``;
    3. puts the first existing ROMA directory at the front of ``sys.path``;
    4. sets default Hugging Face cache / offline environment variables from
       the first ``hub`` directory holding a "vismatch" sub-directory.

    Returns early, after logging, when step 1 or step 2 finds nothing.
    """
    if not hasattr(sys, "_MEIPASS"):
        return

    base = Path(sys._MEIPASS)
    exe_dir = Path(sys.executable).resolve().parent

    # ---- 1) third_party directory ----
    third_party_base = next(
        (
            folder
            for folder in (
                base / "vismatch" / "third_party",
                base / "_internal" / "vismatch" / "third_party",
                exe_dir / "_internal" / "vismatch" / "third_party",
                exe_dir / "vismatch" / "third_party",
                base / "third_party",  # in case vismatch is bundled directly
            )
            if folder.exists()
        ),
        None,
    )
    if third_party_base is None:
        logger.warning(f"未找到 third_party 目录MEIPASS={base}, exe_dir={exe_dir}")
        # Dump a short directory listing to help debugging the bundle layout.
        try:
            if base.exists():
                logger.info(f"MEIPASS 内容: {list(base.iterdir())[:10]}")
            if exe_dir.exists():
                logger.info(f"exe_dir 内容: {list(exe_dir.iterdir())[:10]}")
        except Exception as exc:
            logger.warning(f"无法列出目录内容: {exc}")
        return
    logger.info(f"找到 third_party 目录: {third_party_base}")

    # ---- 2) MatchAnything root (must contain a 'src' package) ----
    layout_guesses = [
        # originally expected nesting
        third_party_base / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
        # MatchAnything placed directly under third_party
        third_party_base / "MatchAnything",
        # doubled MatchAnything directory
        third_party_base / "MatchAnything" / "MatchAnything",
        # same nesting, one level higher
        third_party_base.parent / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
    ]
    matchanything_root = None
    for guess in layout_guesses:
        # A guess that itself ends in 'src' stands for its parent directory.
        names_src = str(guess).endswith("src")
        src_found = guess.exists() if names_src else (guess / "src").exists()
        if guess.exists() and src_found:
            matchanything_root = guess.parent if names_src else guess
            logger.info(f"找到 MatchAnything 根目录: {matchanything_root}")
            break
    else:
        logger.warning(f"未找到 MatchAnything 目录,尝试的路径:")
        for guess in layout_guesses:
            logger.warning(f" - {guess} (exists={guess.exists()})")
        # Last resort: walk the tree for a directory that has a 'src' child.
        try:
            for walked, _dirnames, _filenames in os.walk(third_party_base):
                here = Path(walked)
                if not (here / "src").exists():
                    continue
                if "matchanything" in walked.lower():
                    matchanything_root = here
                    logger.info(f"通过递归搜索找到 MatchAnything: {matchanything_root}")
                    break
                # Otherwise accept it only if 'src' holds Python files.
                if any((here / "src").glob("*.py")):
                    matchanything_root = here
                    logger.info(f"通过递归搜索找到潜在 MatchAnything: {matchanything_root}")
                    break
        except Exception as exc:
            logger.warning(f"递归搜索失败: {exc}")
        if matchanything_root is None:
            return

    # The MatchAnything root holds the importable 'src' module.
    root_entry = str(matchanything_root)
    if root_entry not in sys.path:
        sys.path.insert(0, root_entry)
        logger.info(f"已添加 MatchAnything 到 sys.path: {root_entry}")

    # ---- 3) ROMA ----
    roma_root = next(
        (
            folder
            for folder in (
                matchanything_root / "third_party" / "ROMA",
                third_party_base / "ROMA",
                third_party_base / "MatchAnything" / "third_party" / "ROMA",
                matchanything_root.parent / "ROMA",
            )
            if folder.exists()
        ),
        None,
    )
    if roma_root:
        logger.info(f"找到 ROMA 目录: {roma_root}")
        roma_entry = str(roma_root)
        if roma_entry not in sys.path:
            sys.path.insert(0, roma_entry)
            logger.info(f"已添加 ROMA 到 sys.path: {roma_entry}")
    else:
        logger.warning(f"未找到 ROMA 目录")

    # ---- 4) Hugging Face cache ----
    # Normally already configured by _early_pyinstaller_hf_env(), which has
    # to run before vismatch is imported; setdefault keeps those values.
    if hasattr(sys, "_MEIPASS"):
        for hub_dir in (
            base / "hub",
            base / "_internal" / "hub",
            exe_dir / "_internal" / "hub",
            exe_dir / "hub",
        ):
            try:
                usable = hub_dir.exists() and any(
                    "vismatch" in child.name.lower()
                    for child in hub_dir.iterdir()
                    if child.is_dir()
                )
            except OSError:
                continue
            if not usable:
                continue
            os.environ.setdefault("HF_HOME", str(hub_dir.parent))
            os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hub_dir))
            os.environ.setdefault("HF_HUB_OFFLINE", "1")
            os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
            logger.info(
                f"HuggingFace 缓存: {os.environ.get('HUGGINGFACE_HUB_CACHE')} "
                f"(HF_HUB_OFFLINE={os.environ.get('HF_HUB_OFFLINE')})"
            )
            break
def _install_loguru_stub_if_missing():
    """Register a minimal stand-in for ``loguru`` when it cannot be imported.

    If ``import loguru`` succeeds nothing happens.  Otherwise a module named
    ``loguru`` exposing a ``logger`` object is placed in ``sys.modules``; the
    object forwards messages to the standard-library logger "loguru".

    loguru formats messages brace-style (``logger.info("x={}", x)``), so
    positional and keyword arguments are applied with ``str.format`` before
    forwarding.  If that formatting fails, the message and positional
    arguments are handed to the standard logger unchanged.
    """
    try:
        import loguru  # noqa: F401
        return
    except Exception:
        # Best effort: any import failure means there is no usable loguru.
        pass

    py_logger = logging.getLogger("loguru")

    def _forward(log_method, msg, args, kwargs):
        """Render a loguru-style message and pass it to ``log_method``."""
        if args or kwargs:
            try:
                msg = str(msg).format(*args, **kwargs)
            except (IndexError, KeyError, ValueError, AttributeError):
                # Not a valid brace template: keep stdlib %-style handling.
                log_method(msg, *args)
                return
        log_method(msg)

    class _StubLogger:
        """Subset of the loguru logger API backed by ``logging``."""

        def trace(self, msg, *args, **kwargs):
            # stdlib logging has no TRACE level; DEBUG is the closest.
            _forward(py_logger.debug, msg, args, kwargs)

        def debug(self, msg, *args, **kwargs):
            _forward(py_logger.debug, msg, args, kwargs)

        def info(self, msg, *args, **kwargs):
            _forward(py_logger.info, msg, args, kwargs)

        def success(self, msg, *args, **kwargs):
            # stdlib logging has no SUCCESS level; INFO is the closest.
            _forward(py_logger.info, msg, args, kwargs)

        def warning(self, msg, *args, **kwargs):
            _forward(py_logger.warning, msg, args, kwargs)

        def error(self, msg, *args, **kwargs):
            _forward(py_logger.error, msg, args, kwargs)

        def critical(self, msg, *args, **kwargs):
            _forward(py_logger.critical, msg, args, kwargs)

        def exception(self, msg, *args, **kwargs):
            _forward(py_logger.exception, msg, args, kwargs)

        def add(self, *args, **kwargs):
            # Sinks are ignored; return a dummy handler id.
            return 0

        def remove(self, *args, **kwargs):
            return None

        def bind(self, *args, **kwargs):
            # Contextual values are ignored; keep call chains working.
            return self

        def opt(self, *args, **kwargs):
            return self

    m = types.ModuleType("loguru")
    m.logger = _StubLogger()
    sys.modules["loguru"] = m
# ---------- Configuration ----------
# Adjust these paths to match your environment
REF_TIF = r"E:\is2\dingshanhu\mask_water.tif" # Path of the reference tif file
BIP_DIR = Path(r"E:\is2\dingshanhu") # Folder containing the .bip files
OUT_DIR = Path(r"E:\is2\dingshanhu\output") # Output folder
# Matching algorithm
MATCHER_NAME = "matchanything-roma" # Options: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, etc.
DEVICE = "cuda" # or "cpu"
# Transform methods to try.
# NOTE: process_bip_to_tif evaluates every listed method and keeps the one
# with the lowest p95 reprojection error, so list order only breaks ties.
TRANSFORM_METHODS = ["similarity", "affine", "homography"]
# Options: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps"
# Matching parameters
MATCH_MAX_SIDE = 1200 # Maximum side length (pixels) of the images used for matching
ROI_PAD_PX = 500 # Padding of the coarse-localisation window (reference-tif pixels)
MASK_PAD_PX = 100 # Dilation of the matching mask in pixels (matching stage only)
# Quality-control thresholds
MIN_INLIERS = 10
MIN_INLIER_RATIO = 0.01
# Mask edge feathering and filtering
FEATHER_PX = 20 # Mask feather width (pixels; applied at full / ROI resolution first)
EDGE_BAND_PX = 30 # Drop matches closer than this many pixels to the mask boundary (scaled proportionally on the downscaled image)
# Texture filtering
MIN_GRAD_QUANTILE = 0.20 # Quantile (0~1) of gradient magnitude; points below it count as low-texture and are dropped
# Statistics output locations; presumably assigned at runtime -- verify against the caller.
STATS_DIR = None
STATS_CSV = None
@dataclass
class RegistrationConfig:
    """Settings bundle for one registration run.

    Most fields share their lower-cased name with a module-level default
    constant (``ref_tif`` / ``REF_TIF``, ``match_max_side`` /
    ``MATCH_MAX_SIDE``, ...); see the comments on those constants for the
    meaning of each value.
    """
    # Input / output locations (cf. REF_TIF, BIP_DIR, OUT_DIR).
    ref_tif: str
    bip_dir: str
    out_dir: str
    # Optional reference mask.  Presumably pixels equal to
    # ref_mask_remove_value are removed -- TODO confirm against the consumer.
    enable_ref_mask: bool
    ref_mask_tif: str
    ref_mask_remove_value: int
    # Matcher selection (cf. MATCHER_NAME, DEVICE).
    matcher_name: str
    device: str
    # Candidate transform methods (cf. TRANSFORM_METHODS).
    transform_methods: list
    # Matching parameters (cf. MATCH_MAX_SIDE, ROI_PAD_PX, MASK_PAD_PX).
    match_max_side: int
    roi_pad_px: int
    mask_pad_px: int
    # Quality-control thresholds (cf. MIN_INLIERS, MIN_INLIER_RATIO).
    min_inliers: int
    min_inlier_ratio: float
    # Mask feathering / edge band / texture filter
    # (cf. FEATHER_PX, EDGE_BAND_PX, MIN_GRAD_QUANTILE).
    feather_px: int
    edge_band_px: int
    min_grad_quantile: float
# ---------- 工具函数 ----------
def init_stats_csv(csv_path: Path):
    """Create the registration statistics CSV with its header row if needed.

    The header is written when the file does not exist yet or exists but is
    empty; a file that already has content is left untouched.  Missing parent
    directories are created first.

    Args:
        csv_path: Location of the statistics file (``Path`` or ``str``).
    """
    csv_path = Path(csv_path)
    try:
        needs_header = csv_path.stat().st_size == 0
    except FileNotFoundError:
        needs_header = True
    if not needs_header:
        return

    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # newline='' lets the csv module control line endings itself.
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow([
            'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio',
            'selected_method', 'median_error', 'p95_error', 'success'
        ])
def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int,
                           inlier_ratio: float, selected_method: str, median_error: float,
                           p95_error: float, success: bool):
    """Append one registration result as a row of the statistics CSV.

    The row starts with the current local time (``YYYY-MM-DD HH:MM:SS``);
    the inlier ratio and both error values are written with four decimals.
    """
    record = [
        datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        filename,
        num_inliers,
        num_matches,
        format(inlier_ratio, ".4f"),
        selected_method,
        format(median_error, ".4f"),
        format(p95_error, ".4f"),
        success,
    ]
    with open(csv_path, 'a', newline='', encoding='utf-8') as sink:
        csv.writer(sink).writerow(record)
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
    """Shrink a (C,H,W) array proportionally so that max(H,W) <= max_side.

    The input is returned unchanged when it already fits or when it has no
    pixels (H == 0 or W == 0).  Each target dimension is kept at >= 1 pixel
    so extreme aspect ratios cannot round to a zero-size image, which
    ``cv2.resize`` rejects.
    """
    c, h, w = arr_chw.shape
    if h == 0 or w == 0:
        return arr_chw
    s = min(1.0, max_side / max(h, w))
    if s >= 1.0:
        return arr_chw
    new_w = max(1, int(round(w * s)))
    new_h = max(1, int(round(h * s)))
    # Resize channel by channel with OpenCV; dsize is (width, height).
    return np.stack(
        [cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)],
        axis=0,
    )
def _expand_window(win, pad, max_w, max_h):
    """Grow ``win`` by ``pad`` pixels on every side, clamped to the raster.

    The left/top edges are clamped at 0 and the right/bottom edges at
    ``max_w`` / ``max_h``; all four edges are truncated to integers before
    the new window is built.
    """
    left = int(max(0, win.col_off - pad))
    top = int(max(0, win.row_off - pad))
    right = int(min(max_w, win.col_off + win.width + pad))
    bottom = int(min(max_h, win.row_off + win.height + pad))
    return rasterio.windows.Window(
        col_off=left,
        row_off=top,
        width=right - left,
        height=bottom - top,
    )
def _pixel_size_xy(transform: Affine):
    """Return the pixel size ``(rx, ry)`` implied by an affine transform.

    ``rx`` is the length of the (a, d) column and ``ry`` the length of the
    (b, e) column.  If a length is not finite or not positive, the absolute
    value of ``a`` (resp. ``e``) is used instead, or 1.0 when that
    coefficient is zero.
    """
    def _axis_length(first, second, fallback_coeff):
        length = float(np.hypot(first, second))
        if np.isfinite(length) and length > 0:
            return length
        return float(abs(fallback_coeff)) if fallback_coeff != 0 else 1.0

    size_x = _axis_length(transform.a, transform.d, transform.a)
    size_y = _axis_length(transform.b, transform.e, transform.e)
    return size_x, size_y
def _grid_from_bounds(bounds, res_x: float, res_y: float):
    """Build a north-up grid covering ``bounds`` at the given resolution.

    Args:
        bounds: ``(left, bottom, right, top)``.
        res_x, res_y: Pixel sizes; their absolute values are used.

    Returns:
        ``(transform, width, height)`` where the transform has its origin at
        ``(left, top)`` and width/height are rounded up and at least 1.
    """
    left, bottom, right, top = (float(edge) for edge in bounds)
    step_x = float(abs(res_x))
    step_y = float(abs(res_y))
    # The divisor is floored at 1e-12 to avoid dividing by zero.
    n_cols = max(1, int(np.ceil((right - left) / max(1e-12, step_x))))
    n_rows = max(1, int(np.ceil((top - bottom) / max(1e-12, step_y))))
    grid_transform = Affine(step_x, 0.0, left, 0.0, -step_y, top)
    return grid_transform, n_cols, n_rows
def _rigid_fit_2d(src_pts, dst_pts):
    """Least-squares rotation + translation (no scale) mapping src onto dst.

    Kabsch algorithm on the centred point sets; returns a 2x3 float64 matrix.
    """
    src = np.asarray(src_pts, dtype=np.float64)
    dst = np.asarray(dst_pts, dtype=np.float64)
    src_mean = src.mean(axis=0)
    dst_mean = dst.mean(axis=0)
    cov = (src - src_mean).T @ (dst - dst_mean)
    u, _, vt = np.linalg.svd(cov)
    rot = vt.T @ u.T
    if np.linalg.det(rot) < 0:
        # Reject the reflection solution: flip the weakest singular direction.
        vt[-1, :] *= -1.0
        rot = vt.T @ u.T
    shift = dst_mean - rot @ src_mean
    return np.hstack([rot, shift[:, None]])


def estimate_transform(method, k0, k1):
    """Estimate a transform of the requested kind mapping ``k0`` onto ``k1``.

    Args:
        method: One of "translation", "euclidean", "similarity", "affine",
            "homography", "piecewise_affine", "polynomial" (order 2) or
            "polynomial_order3".
        k0, k1: Corresponding (N,2) point arrays (source, target).

    Returns:
        ``(transform_type, transform)``.  ``transform_type`` is "A" (2x3
        matrix), "H" (3x3 matrix), "piecewise", "polynomial" or
        "polynomial_order3" (scikit-image transform objects).  The transform
        is ``None`` when estimation fails; ``(None, None)`` is returned when
        scikit-image is unavailable, its estimation fails, or "translation"
        gets no points.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == "translation":
        # Pure shift: mean displacement of the point pairs.
        if len(k0) == 0:
            return None, None
        dx = np.mean(k1[:, 0] - k0[:, 0])
        dy = np.mean(k1[:, 1] - k0[:, 1])
        A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)
        return "A", A

    if method == "euclidean":
        # Rotation + translation with scale fixed to 1.  The partial-affine
        # RANSAC fit (which also estimates a scale) is only used to select
        # inliers; the rigid transform is then fitted on those inliers.
        A, inlier_flags = cv2.estimateAffinePartial2D(
            k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0
        )
        if A is None:
            return "A", None
        if inlier_flags is None:
            keep = np.ones(len(k0), dtype=bool)
        else:
            keep = np.asarray(inlier_flags).ravel().astype(bool)
        if np.count_nonzero(keep) < 2:
            return "A", None
        return "A", _rigid_fit_2d(np.asarray(k0)[keep], np.asarray(k1)[keep])

    if method == "similarity":
        # Rotation + uniform scale + translation.
        A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "affine":
        # Full affine: rotation + anisotropic scale + shear + translation.
        A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "homography":
        # Projective transform (8 DOF).
        H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0)
        return "H", H

    if method == "piecewise_affine":
        if not SKIMAGE_AVAILABLE:
            return None, None
        try:
            tform = PiecewiseAffineTransform()
            # estimate() reports failure through its return value.
            if tform.estimate(k0, k1) is False:
                return None, None
            return "piecewise", tform
        except Exception:
            return None, None

    if method in ("polynomial", "polynomial_order3"):
        # The returned type string equals the method name; downstream code
        # derives the polynomial order from it.
        if not SKIMAGE_AVAILABLE:
            return None, None
        order = 2 if method == "polynomial" else 3
        try:
            tform = PolynomialTransform()
            if tform.estimate(k0, k1, order=order) is False:
                return None, None
            return method, tform
        except Exception:
            return None, None

    raise ValueError(f"未知变换方法: {method}")
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Score a transform by its reprojection error on the point pairs.

    Args:
        transform_type: "A" (2x3 matrix), "H" (3x3 matrix), or "piecewise" /
            "polynomial" / "polynomial_order3" (callable transform objects).
        transform: The transform to evaluate, or ``None``.
        k0, k1: Corresponding (N,2) source / target points.

    Returns:
        ``(median_error, p95_error)`` of the Euclidean distance between the
        transformed ``k0`` and ``k1``; ``(inf, inf)`` when the transform is
        ``None``, there are no points, or the type is unknown.
    """
    if transform is None or len(k0) == 0:
        return np.inf, np.inf

    src = np.asarray(k0)
    if transform_type == "A":
        A = np.asarray(transform)
        pred = src @ A[:, :2].T + A[:, 2]
    elif transform_type == "H":
        H = np.asarray(transform, dtype=np.float64)
        src_h = np.column_stack([src, np.ones(len(src), dtype=np.float64)]).T
        warped = H @ src_h
        w = warped[2:3, :]
        # Exact projective division; only a w that is numerically zero is
        # replaced, so regular points are not biased by an added epsilon.
        w = np.where(np.abs(w) < 1e-12, 1e-12, w)
        pred = (warped[:2, :] / w).T
    elif transform_type in ("piecewise", "polynomial", "polynomial_order3"):
        # scikit-image transform objects are callable on (N,2) points.
        pred = transform(k0)
    else:
        return np.inf, np.inf

    e = np.sqrt(((pred - k1) ** 2).sum(axis=1))
    return float(np.median(e)), float(np.percentile(e, 95))
def _norm01_hw(x: np.ndarray) -> np.ndarray:
    """Percentile-normalise a single band (H,W) to [0,1] as float32.

    The band is stretched between its 2nd and 98th percentile and clipped,
    which stabilises intensity-based registration across sensors.

    Non-finite samples are ignored when the percentiles are computed, so a
    NaN in the input no longer turns the whole result into NaN.  In the
    output NaN becomes 0 and +/-inf saturates at 1 / 0.  An input without
    any finite sample yields all zeros.
    """
    x = x.astype(np.float32, copy=False)
    finite = np.isfinite(x)
    if not finite.any():
        return np.zeros_like(x)
    sample = x if finite.all() else x[finite]
    p2 = float(np.percentile(sample, 2))
    p98 = float(np.percentile(sample, 98))
    y = np.clip((x - p2) / (p98 - p2 + 1e-6), 0.0, 1.0)
    y[np.isnan(y)] = 0.0
    return y
def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)):
    """Wrap a numpy (H,W) array as a float32 SimpleITK image.

    Physical coordinates are set up to coincide with pixel coordinates:
    spacing is 1 on both axes, the direction matrix is the identity and the
    origin is ``(x0, y0) = origin_xy``.
    """
    pixels = arr_hw.astype(np.float32, copy=False)
    image = sitk.GetImageFromArray(pixels)
    unit_spacing = (1.0, 1.0)
    identity_direction = (1.0, 0.0, 0.0, 1.0)
    image.SetSpacing(unit_spacing)
    x0, y0 = origin_xy[0], origin_xy[1]
    image.SetOrigin((float(x0), float(y0)))
    image.SetDirection(identity_direction)
    return image
def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10):
    """Compute a padded crop window around the target-side match points.

    Args:
        k1_global: (N,2) array of (x, y) points in reference pixel coordinates.
        ref_w, ref_h: Size of the reference image; the window is clipped to it.
        pad: Margin in pixels added on every side before clipping.

    Returns:
        ``(min_x, min_y, width, height)`` as Python ints.  Width or height
        can be <= 0 when the points lie outside the reference image.  An
        empty point set yields ``(0, 0, 0, 0)`` instead of raising from the
        min/max reduction, so callers can apply the usual ``width <= 0``
        check.
    """
    pts = np.asarray(k1_global)
    if pts.size == 0:
        return 0, 0, 0, 0

    xs, ys = pts[:, 0], pts[:, 1]
    min_x = max(0, int(np.floor(xs.min())) - pad)
    min_y = max(0, int(np.floor(ys.min())) - pad)
    max_x = min(ref_w, int(np.ceil(xs.max())) + pad)
    max_y = min(ref_h, int(np.ceil(ys.max())) + pad)
    return min_x, min_y, max_x - min_x, max_y - min_y
def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Resize an (H,W) mask to ``(target_h, target_w)`` with nearest neighbour.

    The mask is cast to uint8 for OpenCV; the result is a boolean array that
    is True wherever the resized value is non-zero.
    """
    as_uint8 = mask_hw.astype(np.uint8)
    # OpenCV expects the target size as (width, height).
    resized = cv2.resize(
        as_uint8,
        dsize=(target_w, target_h),
        interpolation=cv2.INTER_NEAREST,
    )
    return resized > 0
def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray:
    """Turn a binary mask into a soft alpha map in [0,1].

    Alpha rises linearly with the distance to the nearest zero pixel and
    reaches 1 at ``feather_px`` pixels, which avoids a hard edge.  With
    ``feather_px <= 0`` the result is a plain 0/1 map.

    Args:
        mask_hw: bool/uint8 (H,W) array, True/non-zero meaning valid; may be
            ``None``, in which case ``None`` is returned.
        feather_px: Feather width in pixels.

    Returns:
        (H,W) float32 array, or ``None``.
    """
    if mask_hw is None:
        return None

    foreground = np.where(mask_hw.astype(np.uint8) > 0, 255, 0).astype(np.uint8)
    # For every non-zero pixel: distance to the nearest zero pixel.
    distance = cv2.distanceTransform(foreground, distanceType=cv2.DIST_L2, maskSize=3)
    if feather_px > 0:
        return np.clip(distance / float(feather_px), 0.0, 1.0).astype(np.float32)
    return (distance > 0).astype(np.float32)
def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray:
    """Build a "far from the boundary" keep mask.

    A pixel is kept when its distance to the nearest zero pixel of the mask
    is at least ``min_dist_px`` (values below 1 are treated as 1).  Returns
    ``None`` when ``mask_hw`` is ``None``.
    """
    if mask_hw is None:
        return None

    foreground = np.where(mask_hw.astype(np.uint8) > 0, 255, 0).astype(np.uint8)
    distance = cv2.distanceTransform(foreground, distanceType=cv2.DIST_L2, maskSize=3)
    threshold = float(max(1, min_dist_px))
    return distance >= threshold
def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray:
    """Build an (H,W) texture mask from gradient magnitude; True = textured.

    The channel mean of ``img_chw`` (the CHW image used for matching) is
    filtered with 3x3 Sobel kernels; a pixel is True when its gradient
    magnitude reaches the ``quantile`` quantile of all magnitudes.
    """
    gray = img_chw.mean(axis=0).astype(np.float32)  # (H,W)
    grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = np.sqrt(grad_x * grad_x + grad_y * grad_y)
    if magnitude.size > 0:
        threshold = float(np.quantile(magnitude, quantile))
    else:
        threshold = 0.0
    return magnitude >= threshold  # (H,W) bool
def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict:
    """Restrict matches and inliers to points that fall inside both masks.

    ``result`` is modified in place and returned.  A pair is kept only when
    its source point lies in ``src_mask_small`` and its target point lies in
    ``ref_mask_small``; points are rounded to the nearest pixel and clamped
    to the mask extent before the lookup.  ``num_inliers`` is refreshed when
    the inlier arrays are filtered.  If either mask is ``None`` the result
    is returned untouched.

    Masks are coerced to bool (non-zero = valid).  Without that, an integer
    mask would produce an integer index array and rows would be gathered
    rather than filtered.
    """
    if src_mask_small is None or ref_mask_small is None:
        return result

    src_valid = np.asarray(src_mask_small).astype(bool, copy=False)
    ref_valid = np.asarray(ref_mask_small).astype(bool, copy=False)

    def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray:
        """Boolean flag per (x, y) keypoint: does it hit a True mask pixel?"""
        if kpts is None or len(kpts) == 0:
            return np.zeros((0,), dtype=bool)
        kpts = np.asarray(kpts)
        xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
        ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
        return mask_hw[ys, xs]

    # Filter matched_kpts
    if "matched_kpts0" in result and "matched_kpts1" in result:
        mk0 = np.asarray(result["matched_kpts0"])
        mk1 = np.asarray(result["matched_kpts1"])
        if len(mk0) == len(mk1) and len(mk0) > 0:
            keep_m = keep_in_mask(mk0, src_valid) & keep_in_mask(mk1, ref_valid)
            result["matched_kpts0"] = mk0[keep_m]
            result["matched_kpts1"] = mk1[keep_m]

    # Filter inlier_kpts
    if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None:
        ik0 = np.asarray(result["inlier_kpts0"])
        ik1 = np.asarray(result["inlier_kpts1"])
        if len(ik0) == len(ik1) and len(ik0) > 0:
            keep_i = keep_in_mask(ik0, src_valid) & keep_in_mask(ik1, ref_valid)
            result["inlier_kpts0"] = ik0[keep_i]
            result["inlier_kpts1"] = ik1[keep_i]
            result["num_inliers"] = int(len(result["inlier_kpts0"]))
    return result
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path):
"""处理单个 .bip 文件到参考 .tif 的配准"""
try:
with rasterio.open(bip_path) as src:
logger.info(f"处理文件: {bip_path.name}")
# 初始化统计变量
num_inliers = 0
num_matches = 0
inlier_ratio = 0.0
selected_method = "none"
median_error = float('inf')
p95_error = float('inf')
success = False
# 检查CRS
if src.crs is None:
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
src_crs = ref_dataset.crs
else:
src_crs = src.crs
ref_crs = ref_dataset.crs
if ref_crs is None:
raise RuntimeError(f"参考文件缺少CRS信息")
# 1) 用"源图有效掩膜"的包围盒推参考ROI比整图bounds更贴近有效重叠
try:
src_mask = (src.read_masks(1) > 0) # True=有效
rows_any = np.any(src_mask, axis=1)
cols_any = np.any(src_mask, axis=0)
if rows_any.any() and cols_any.any():
rmin = int(rows_any.argmax())
rmax = int(src.height - 1 - rows_any[::-1].argmax())
cmin = int(cols_any.argmax())
cmax = int(src.width - 1 - cols_any[::-1].argmax())
valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1)
valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform)
b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21)
else:
# 掩膜无效时回退到整图bounds
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
except Exception:
src_mask = None # 后续可选源图掩膜时用到
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
win0 = from_bounds(*b, transform=ref_dataset.transform)
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
if win.width <= 0 or win.height <= 0:
logger.warning(f"无重叠区域: {bip_path.name}")
return False
# 2) 读取数据
# 读取所有波段,如果是多波段的话
src_arr = src.read() # (bands, H, W)
if src_arr.ndim == 2: # 单波段
src_arr = src_arr[None, ...] # 增加波段维度
# 读取参考文件的ROI
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
if ref_arr.ndim == 2: # 单波段
ref_arr = ref_arr[None, ...] # 增加波段维度
# 将源图有效掩膜重投影到参考ROI并适度膨胀后作为匹配掩膜
try:
if src_mask is None:
src_mask = (src.read_masks(1) > 0)
ref_roi_transform = ref_dataset.window_transform(win)
roi_h, roi_w = int(win.height), int(win.width)
dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8)
reproject(
source=src_mask.astype(np.uint8),
destination=dst_mask,
src_transform=src.transform,
src_crs=src_crs,
dst_transform=ref_roi_transform,
dst_crs=ref_crs,
resampling=Resampling.nearest
)
if MASK_PAD_PX > 0:
k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size
k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除
kernel = np.ones((k, k), np.uint8)
dst_mask = cv2.dilate(dst_mask, kernel, iterations=1)
except Exception:
# 掩膜获取/重投影失败则不使用掩膜
dst_mask = None
# 转换为匹配所需的格式
src_img = _to_3ch_float01(src_arr)
ref_img = _to_3ch_float01(ref_arr)
# 软掩膜:避免在边界产生硬高对比边
try:
alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None
except Exception:
alpha_src = None
try:
alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None
except Exception:
alpha_ref = None
if alpha_src is not None:
alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype)
src_img = src_img * alpha_src3
if alpha_ref is not None:
alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype)
ref_img = ref_img * alpha_ref3
# 3) 匹配用降采样版本,提速 + 增稳
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
# 4) 精配准img0=src, img1=ref_roi
result = matcher(src_small, ref_small)
# 与小图同尺寸的掩膜
src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None
ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None
# 剔除掩膜边缘带(小图尺度的最小距离)
def _scale_px(px_full: int, full_wh, small_wh) -> int:
# 用平均缩放也可以分别对H/W计算后取最小
sy = small_wh[0] / max(1, full_wh[0])
sx = small_wh[1] / max(1, full_wh[1])
s = 0.5 * (sx + sy)
return max(1, int(round(px_full * s)))
edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2]))
edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2]))
keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None
keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None
# 纹理掩膜
keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE)
keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE)
# 组合最终保留掩膜(边缘+纹理),二者都要满足
def _combine_keep(m_edge, m_tex):
if m_edge is None:
return m_tex
return (m_edge & m_tex)
keep_src_final = _combine_keep(keep_src_edge, keep_src_tex)
keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex)
# 将匹配与内点严格限制在最终掩膜内
def _filter_by_bool_masks(res, m_src, m_ref):
if m_src is None or m_ref is None:
return res
def keep_in(mask_hw, pts):
if pts is None or len(pts) == 0:
return np.zeros((0,), dtype=bool)
xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
return mask_hw[ys, xs]
# matched
if "matched_kpts0" in res and "matched_kpts1" in res:
mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"])
if len(mk0) == len(mk1) and len(mk0) > 0:
keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1)
res["matched_kpts0"] = mk0[keep_m]
res["matched_kpts1"] = mk1[keep_m]
# inliers
if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None:
ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"])
if len(ik0) == len(ik1) and len(ik0) > 0:
keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1)
res["inlier_kpts0"] = ik0[keep_i]
res["inlier_kpts1"] = ik1[keep_i]
res["num_inliers"] = int(len(res["inlier_kpts0"]))
return res
result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final)
# 统计(以过滤后的结果为准)
num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", []))))
num_m = len(result.get("matched_kpts0", []))
ratio = (num_inl / num_m) if num_m else 0.0
# 更新统计变量
num_inliers = num_inl
num_matches = num_m
inlier_ratio = ratio
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
# 保存匹配可视化图像使用与匹配同尺寸的图像保持CHW格式
viz_dir = out_dir / "visualizations"
viz_dir.mkdir(exist_ok=True)
matches_path = viz_dir / f"{bip_path.stem}_matches.png"
plot_matches(src_small, ref_small, result, save_path=str(matches_path))
logger.info(f"匹配可视化已保存: {matches_path}")
# 关键点可视化(源图像)
kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png"
plot_keypoints(
src_small,
{"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]},
save_path=str(kpts_src_path)
)
logger.info(f"源图像关键点可视化已保存: {kpts_src_path}")
# 关键点可视化(参考图像)
kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png"
plot_keypoints(
ref_small,
{"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]},
save_path=str(kpts_ref_path)
)
logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}")
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
logger.warning(f"匹配质量不足: {bip_path.name}")
# 记录失败的统计信息
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "failed_quality_check", median_error, p95_error, False)
return False
# 5) 用内点估计多种变换并自动选择最优
# 先计算全分辨率坐标
k0_small = result["inlier_kpts0"].astype(np.float32)
k1_small = result["inlier_kpts1"].astype(np.float32)
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src)
S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI)
ones = np.ones((k0_small.shape[0], 1), dtype=np.float32)
k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素
k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素
k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素
# 用全分辨率坐标进行所有模型的估计和评估
best_transform = None
best_transform_type = None
best_error = np.inf
best_median_error = np.inf
best_method = None
for method in TRANSFORM_METHODS:
transform_type, transform = estimate_transform(method, k0_full, k1_global)
if transform is None:
continue
med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global)
# 选择重投影误差最小的变换
if p95_err < best_error:
best_transform = transform
best_transform_type = transform_type
best_error = p95_err
best_median_error = med_err
best_method = method
logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
if best_transform is None:
logger.warning(f"所有变换方法都失败: {bip_path.name}")
# 记录失败的统计信息
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "failed_transform", median_error, p95_error, False)
return False
# 更新统计变量
selected_method = best_method
median_error = best_median_error
p95_error = best_error
logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}")
# 6) 根据变换类型进行相应的配准处理
if best_transform_type == "A":
# 仿射变换A 已是 src_full_pixel -> ref_full_pixel直接构造像素->地图仿射
A = best_transform # 2x3, src_full_pixel -> ref_full_pixel
A3 = np.eye(3, dtype=np.float64)
A3[:2, :] = A
# src_pixel -> map
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
M_map = Rt @ A3
corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2],
M_map[1,0], M_map[1,1], M_map[1,2])
# 用 M_map 求最小外接矩形(先到 map再到 ref 像素)
Rt_inv = np.linalg.inv(Rt)
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corn_h = np.hstack([corners, np.ones((4,1))]).T
map_corners = (M_map @ corn_h).T[:, :2]
pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2]
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil (pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil (pix_corners[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=out_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.nearest,
)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
# ---- 非仿射变换处理 ----
elif best_transform_type == "H":
# 单应变换H 已是 src_full_pixel -> ref_full_pixel
H_full = best_transform # 3x3
try:
# 用 H_full 映射源四角 -> 参考像素,求最小外接矩形
src_h, src_w = src.height, src.width
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32)
corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T
dst_h = (H_full @ corn_h)
dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T
min_x = int(np.floor(dst[:,0].min())) - 10
max_x = int(np.ceil (dst[:,0].max())) + 10
min_y = int(np.floor(dst[:,1].min())) - 10
max_y = int(np.ceil (dst[:,1].max())) + 10
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
ref_transform = ref_dataset.transform
Rt = np.array(
[[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0.0, 0.0, 1.0]],
dtype=np.float64,
)
Out = np.array(
[[out_transform.a, out_transform.b, out_transform.c],
[out_transform.d, out_transform.e, out_transform.f],
[0.0, 0.0, 1.0]],
dtype=np.float64,
)
M = np.linalg.inv(Out) @ Rt @ H_full.astype(np.float64)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = cv2.warpPerspective(
src_band,
M,
(out_w, out_h),
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=float(dst_nodata)
).astype(np.float32)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
except Exception as e:
logger.warning(f"单应变换异常: {e}")
# 继续到仿射回退
elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]:
# 分片仿射或多项式变换:使用 scikit-image
transform = best_transform # 已用 k0_full/k1_global 估计
try:
# 用目标侧匹配点k1_global决定外接矩形更稳
pad = 10
min_x = int(np.floor(k1_global[:, 0].min())) - pad
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
min_y = int(np.floor(k1_global[:, 1].min())) - pad
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
min_x = max(0, min_x)
min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}")
return False
# 创建输出窗口
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 定义带偏移的逆映射函数
off_x, off_y = min_x, min_y
if best_transform_type in ["polynomial", "polynomial_order3"]:
# 对于多项式,估计逆变换
order = 2 if best_transform_type == "polynomial" else 3
t_inv = PolynomialTransform()
t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源
# 目标侧点集的内点判定(用于限制外推)
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y],
[min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
def point_inside(xy):
return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) &
(xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h))
def inv_map_rc(coords):
# coords: (N,2) in (row, col)
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref
inside = point_inside(xy)
xy_src = np.full_like(xy, fill_value=-1.0)
if np.any(inside):
xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src
# 确保坐标在源图像范围内
xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1)
xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1)
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
elif best_transform_type == "piecewise": # piecewise_affine
# 目标侧点集的内点判定
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
# 使用当前裁剪窗口的边界创建矩形
rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
# 退化为矩形内判断
def point_inside(xy):
return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \
(xy[:,1] >= min_y) & (xy[:,1] <= max_y)
def inv_map_rc(coords):
rc = np.asarray(coords)
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref
inside = point_inside(xy)
xy_src = np.full_like(xy, fill_value=-1.0)
if np.any(inside):
xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src)
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
# 使用 scikit-image 进行变换重采样
from skimage.transform import warp
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = warp(
src_band,
inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射
output_shape=(bbox_h, bbox_w),
mode='constant',
cval=dst_nodata,
preserve_range=True,
order=0
).astype(np.float32)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, selected_method, median_error, p95_error, success)
return True
except Exception as e:
logger.warning(f"{best_transform_type}变换异常: {e}")
# 继续到仿射回退
# ---- 回退:使用仿射变换,保证最小可用结果 ----
transform = best_transform
try:
min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1(
k1_global, ref_dataset.width, ref_dataset.height, pad=10
)
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bbox_transform = ref_dataset.window_transform(bbox_window)
if MATPLOTLIB_SCIPY_AVAILABLE:
try:
hull = ConvexHull(k1_global)
hull_path = MplPath(k1_global[hull.vertices])
except Exception:
rect = np.array(
[[min_x, min_y], [min_x + bbox_w, min_y],
[min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]],
dtype=float
)
hull_path = MplPath(rect)
def point_inside(xy):
return hull_path.contains_points(xy)
else:
def point_inside(xy):
return (
(xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) &
(xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h)
)
off_x, off_y = min_x, min_y
tps_inv = transform["inv"] # ref -> src
def inv_map_rc(coords):
rc = np.asarray(coords, dtype=np.float64)
xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y)
inside = point_inside(xy_ref)
xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64)
if np.any(inside):
# 使用RBF插值计算逆映射
xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1])
xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1])
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = ref_dataset.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=bbox_h,
width=bbox_w,
count=src.count,
transform=bbox_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
# 优先用 skimage.warp缺失时用 SimpleITK Resample 兜底
if SKIMAGE_AVAILABLE:
from skimage.transform import warp
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = warp(
src_band,
inverse_map=inv_map_rc,
output_shape=(bbox_h, bbox_w),
mode='constant',
cval=dst_nodata,
preserve_range=True,
order=0
).astype(np.float32)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
else:
# OpenCV remap 版本(无需 skimage/SimpleITK
with rasterio.open(out_path, "w", **out_profile) as out_ds:
# 创建映射网格
y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w]
coords = np.column_stack([y_coords.ravel(), x_coords.ravel()])
# 计算逆映射
mapped_coords = inv_map_rc(coords)
map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32)
map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32)
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
# 使用OpenCV的remap进行重采样
dst_band = cv2.remap(
src_band, map_x, map_y,
interpolation=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=dst_nodata
)
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}")
return True
except Exception as e:
logger.warning(f"tps变换异常: {e}")
# 继续到仿射回退
# ---- 回退:使用仿射变换,保证最小可用结果 ----
# 重新估计仿射变换作为fallback
A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
if A_fallback is None:
logger.warning(f"仿射回退也失败: {bip_path.name}")
return False
# 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标
s0x = src_img.shape[2] / src_small.shape[2]
s0y = src_img.shape[1] / src_small.shape[1]
s1x = ref_img.shape[2] / ref_small.shape[2]
s1y = ref_img.shape[1] / ref_small.shape[1]
S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64)
S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64)
A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback
M_full = S1_inv @ A3 @ S0
T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64)
ref_transform = ref_dataset.transform
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0, 0, 1]], dtype=np.float64)
src_pixel_to_map_corrected = Rt @ T_off @ M_full
corrected_affine = Affine(
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
)
# 计算源 BIP 四角经过仿射变换后的最小外接矩形
# 将 rasterio.Affine 转为 3x3 像素->地图矩阵
M_map = np.array([
[corrected_affine.a, corrected_affine.b, corrected_affine.c],
[corrected_affine.d, corrected_affine.e, corrected_affine.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
# 参考底图的 像素->地图 矩阵及其逆
ref_transform = ref_dataset.transform
Rt = np.array([
[ref_transform.a, ref_transform.b, ref_transform.c],
[ref_transform.d, ref_transform.e, ref_transform.f],
[0.0, 0.0, 1.0]
], dtype=np.float64)
Rt_inv = np.linalg.inv(Rt)
# 源影像四角(源像素坐标)
src_h, src_w = src.height, src.width
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4)
# 源像素 -> 地图坐标
map_corners = (M_map @ corners_h).T[:, :2]
# 地图坐标 -> 参考像素坐标
pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3)
pix_corners = pix_corners_h[:, :2]
# 最小外接矩形(像素)
min_x = int(np.floor(pix_corners[:,0].min())) - 10
max_x = int(np.ceil( pix_corners[:,0].max())) + 10
min_y = int(np.floor(pix_corners[:,1].min())) - 10
max_y = int(np.ceil( pix_corners[:,1].max())) + 10
# 边界裁剪
min_x = max(0, min_x); min_y = max(0, min_y)
max_x = min(ref_dataset.width, max_x)
max_y = min(ref_dataset.height, max_y)
bbox_w = max_x - min_x
bbox_h = max_y - min_y
# 如果外接矩形太小,跳过
if bbox_w <= 0 or bbox_h <= 0:
logger.warning(f"最小外接矩形无效: {bip_path.name}")
return False
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
res_x, res_y = _pixel_size_xy(src.transform)
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
out_path = out_dir / f"{bip_path.stem}_registered.bip"
src_nodata = src.nodata
dst_nodata = src_nodata if src_nodata is not None else 0
out_profile = src.profile.copy()
out_profile.update(
driver="ENVI",
dtype=src.dtypes[0],
height=out_h,
width=out_w,
count=src.count,
transform=out_transform,
crs=ref_crs,
interleave="bip",
compress=None,
nodata=dst_nodata
)
with rasterio.open(out_path, "w", **out_profile) as out_ds:
for b in range(1, src.count + 1):
src_band = src.read(b).astype(np.float32)
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
reproject(
source=src_band,
destination=dst_band,
src_transform=corrected_affine,
src_crs=ref_crs,
dst_transform=out_transform,
dst_crs=ref_crs,
src_nodata=src_nodata,
dst_nodata=dst_nodata,
resampling=Resampling.nearest,
)
# 转回目标 dtype
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
mask = (dst_band == dst_nodata) if src_nodata is not None else None
info = np.iinfo(out_profile["dtype"])
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
if mask is not None:
dst_band[mask] = dst_nodata
else:
dst_band = dst_band.astype(out_profile["dtype"])
out_ds.write(dst_band, b)
logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
success = True
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "affine_fallback", median_error, p95_error, success)
return True
except Exception as e:
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
# 记录失败的统计信息
try:
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
inlier_ratio, "exception", median_error, p95_error, False)
except:
pass # 避免统计记录失败影响主要错误处理
return False
def _apply_config(cfg: RegistrationConfig):
    """Publish the fields of ``cfg`` as the module-level settings globals.

    Every field is read from ``cfg``, coerced to the type listed in the plan
    below, and stored under the matching upper-case global name. Entries are
    applied one after another in the listed order, so a failing coercion
    leaves the earlier globals already updated.

    Args:
        cfg: Configuration object providing the attributes named in the plan.
    """

    def _unchanged(value):
        # The reference path is stored exactly as supplied (no Path wrapping).
        return value

    # (global name, cfg attribute, coercion)
    assignment_plan = (
        ("REF_TIF", "ref_tif", _unchanged),
        ("BIP_DIR", "bip_dir", Path),
        ("OUT_DIR", "out_dir", Path),
        ("MATCHER_NAME", "matcher_name", _unchanged),
        ("DEVICE", "device", _unchanged),
        ("TRANSFORM_METHODS", "transform_methods", list),
        ("MATCH_MAX_SIDE", "match_max_side", int),
        ("ROI_PAD_PX", "roi_pad_px", int),
        ("MASK_PAD_PX", "mask_pad_px", int),
        ("MIN_INLIERS", "min_inliers", int),
        ("MIN_INLIER_RATIO", "min_inlier_ratio", float),
        ("FEATHER_PX", "feather_px", int),
        ("EDGE_BAND_PX", "edge_band_px", int),
        ("MIN_GRAD_QUANTILE", "min_grad_quantile", float),
    )

    module_namespace = globals()
    for global_name, attr_name, coerce in assignment_plan:
        module_namespace[global_name] = coerce(getattr(cfg, attr_name))
def _run_batch(cfg: RegistrationConfig, stop_event: threading.Event, progress_cb=None):
    """Register every ``*.bip`` file in the configured folder against the reference.

    The configuration is first copied into the module-level settings, then an
    output folder and a timestamped statistics CSV are prepared, the matcher is
    created, the reference raster is optionally masked, and each input file is
    passed to ``process_bip_to_tif``.

    Args:
        cfg: Run configuration; applied through ``_apply_config``.
        stop_event: Checked before each file; when set, the loop ends early.
        progress_cb: Optional callable ``(index, total, filename)``. It is
            called once with ``(0, total, "")`` before the loop and once after
            each processed file.

    Returns:
        The number of files for which ``process_bip_to_tif`` returned a truthy
        value.

    Raises:
        RuntimeError: Reference masking is enabled but the masking helper is
            unavailable or the mask file does not exist.
    """
    _apply_config(cfg)
    out_dir = OUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    stats_dir = out_dir / "stats"
    stats_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime('%Y%m%d_%H%M%S')
    stats_csv = stats_dir / f"registration_stats_{ts}.csv"
    logger.info(f"统计信息将保存到: {stats_csv}")
    init_stats_csv(stats_csv)

    # Validate the masking preconditions up front so a misconfiguration is
    # reported without first constructing the matcher.
    use_ref_mask = bool(cfg.enable_ref_mask)
    if use_ref_mask:
        if not TIF_MASK_AVAILABLE:
            raise RuntimeError("未能导入 tif_caijain.py无法进行底图掩膜。")
        if not cfg.ref_mask_tif or not Path(cfg.ref_mask_tif).exists():
            raise RuntimeError("已启用底图掩膜，但掩膜 TIF 文件不存在。")

    _ensure_pyinstaller_third_party_paths()
    _install_loguru_stub_if_missing()
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    ref_path_to_use = REF_TIF
    if use_ref_mask:
        masked_dir = out_dir / "masked_refs"
        masked_dir.mkdir(parents=True, exist_ok=True)
        masked_ref_path = masked_dir / f"{Path(REF_TIF).stem}_masked_{ts}.tif"
        logger.info(f"开始对底图进行掩膜: {REF_TIF}")
        logger.info(f"掩膜文件: {cfg.ref_mask_tif}")
        mask_data_by_binary_mask(
            data_path=REF_TIF,
            mask_path=cfg.ref_mask_tif,
            output_path=str(masked_ref_path),
            remove_value=int(cfg.ref_mask_remove_value),
        )
        ref_path_to_use = str(masked_ref_path)
        logger.info(f"掩膜后的底图: {ref_path_to_use}")

    with rasterio.open(ref_path_to_use) as ref:
        # glob() yields entries in arbitrary directory order; sorting makes the
        # processing order, the progress display and the CSV rows reproducible.
        bip_files = sorted(Path(BIP_DIR).glob("*.bip"))
        total = len(bip_files)
        if total == 0:
            logger.warning(f"未找到任何 .bip 文件: {BIP_DIR}")
        success_count = 0
        if progress_cb is not None:
            progress_cb(0, total, "")
        for idx, bip_path in enumerate(bip_files, start=1):
            if stop_event.is_set():
                break
            if process_bip_to_tif(bip_path, ref, matcher, out_dir, stats_csv):
                success_count += 1
            if progress_cb is not None:
                progress_cb(idx, total, bip_path.name)
    return success_count
class QueueHandler(logging.Handler):
    """Logging handler that pushes each formatted record onto a queue.

    The queue is expected to expose ``put``; the consumer is responsible for
    draining it (the GUI in this file polls it periodically).
    """

    def __init__(self, log_queue):
        """Store the target queue.

        Args:
            log_queue: Object with a ``put(item)`` method receiving the
                formatted log lines.
        """
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record):
        """Format ``record`` and enqueue the resulting string.

        Failures while formatting or enqueueing are routed to
        ``Handler.handleError`` instead of propagating, so a malformed log
        call cannot raise into the code that issued it.
        """
        try:
            self.log_queue.put(self.format(record))
        except Exception:
            self.handleError(record)
class ToolTip:
    """Delayed hover tooltip attached to a single widget.

    Entering the widget schedules the popup after ``delay_ms`` milliseconds;
    leaving it or pressing a mouse button cancels the pending timer and
    removes any visible popup.
    """

    def __init__(self, widget, text: str, delay_ms: int = 400):
        """Bind the hover handlers to ``widget``.

        Args:
            widget: Widget exposing ``bind``, ``after`` and ``after_cancel``.
            text: Tooltip text; an empty value suppresses the popup.
            delay_ms: Delay before the popup appears, coerced with ``int``.
        """
        self.widget = widget
        self.text = text
        self.delay_ms = int(delay_ms)
        self._after_id = None  # id of the pending ``after`` timer, if any
        self._tip = None       # popup window while it is displayed
        event_bindings = (
            ("<Enter>", self._on_enter),
            ("<Leave>", self._on_leave),
            ("<ButtonPress>", self._on_leave),
        )
        for sequence, callback in event_bindings:
            self.widget.bind(sequence, callback, add=True)

    def _on_enter(self, _event=None):
        """Start the countdown towards showing the popup."""
        self._schedule()

    def _on_leave(self, _event=None):
        """Abort a pending popup and remove a visible one."""
        self._cancel()
        self._hide()

    def _schedule(self):
        """Restart the delay timer; a failing ``after`` leaves no timer set."""
        self._cancel()
        try:
            pending = self.widget.after(self.delay_ms, self._show)
        except Exception:
            pending = None
        self._after_id = pending

    def _cancel(self):
        """Cancel the pending timer, ignoring errors from the widget."""
        pending = self._after_id
        if pending is None:
            return
        try:
            self.widget.after_cancel(pending)
        except Exception:
            pass
        self._after_id = None

    def _show(self):
        """Create the popup just below the widget unless one is already shown."""
        if self._tip is not None or not self.text:
            return
        try:
            left = self.widget.winfo_rootx() + 10
            top = self.widget.winfo_rooty() + self.widget.winfo_height() + 6
        except Exception:
            return
        popup = self._tip = tk.Toplevel(self.widget)
        popup.wm_overrideredirect(True)
        popup.wm_geometry(f"+{left}+{top}")
        caption = tk.Label(
            popup,
            text=self.text,
            justify=tk.LEFT,
            background="#ffffe0",
            relief=tk.SOLID,
            borderwidth=1,
            wraplength=520,
        )
        caption.pack(ipadx=6, ipady=4)

    def _hide(self):
        """Destroy the popup if present, ignoring errors from the toolkit."""
        popup = self._tip
        if popup is None:
            return
        try:
            popup.destroy()
        except Exception:
            pass
        self._tip = None
# Matcher names offered in the GUI's matcher combobox, in display order.
_MATCHER_VALUES = [
    "liftfeat",
    "loftr",
    "eloftr",
    "se2loftr",
    "xoftr",
    "aspanformer",
    "matchanything-eloftr",
    "matchanything-roma",
    "matchformer",
    "sift-lightglue",
    "superpoint-lightglue",
    "disk-lightglue",
    "aliked-lightglue",
    "doghardnet-lightglue",
    "roma",
    "romav2",
    "tiny-roma",
    "dedode",
    "steerers",
    "affine-steerers",
    "dedode-kornia",
    "sift-nn",
    "orb-nn",
    "patch2pix",
    "superglue",
    "r2d2",
    "d2net",
    "duster",
    "master",
    "doghardnet-nn",
    "xfeat",
    "xfeat-star",
    "xfeat-lightglue",
    "dedode-lightglue",
    "gim-dkm",
    "gim-lightglue",
    "omniglue",
    "xfeat-subpx",
    "xfeat-lightglue-subpx",
    "dedode-subpx",
    "superpoint-lightglue-subpx",
    "aliked-lightglue-subpx",
    "sift-sphereglue",
    "superpoint-sphereglue",
    "minima",
    "minima-roma",
    "minima-roma-tiny",
    "minima-superpoint-lightglue",
    "minima-loftr",
    "minima-xoftr",
    "edm",
    "lisrd-aliked",
    "lisrd-superpoint",
    "lisrd",
    "lisrd-sift",
    "ripe",
    "topicfm",
    "topicfm-plus",
    "silk",
    "zippypoint",
    "xfeat-steerers-perm",
    "xfeat-steerers-learned",
    "xfeat-star-steerers-perm",
    "xfeat-star-steerers-learned",
]
class RegistrationGUI:
def __init__(self, root):
self.root = root
self.root.title("遥感影像批量配准工具")
self.root.geometry("1000x800")
self._tooltips = []
self.log_queue = queue.Queue()
self.stop_event = threading.Event()
self.processing_thread = None
queue_handler = QueueHandler(self.log_queue)
queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(queue_handler)
logger.setLevel(logging.INFO)
self.create_widgets()
self.check_log_queue()
def add_tooltip(self, widget, text: str):
self._tooltips.append(ToolTip(widget, text))
def show_error_dialog(self, title: str, summary: str, details: str):
win = tk.Toplevel(self.root)
win.title(title)
win.geometry("900x600")
top = ttk.Frame(win, padding=10)
top.pack(fill=tk.BOTH, expand=True)
summary_label = tk.Label(top, text=summary, fg="#b00020", justify=tk.LEFT, wraplength=860)
summary_label.pack(anchor=tk.W, fill=tk.X)
text_frame = ttk.Frame(top)
text_frame.pack(fill=tk.BOTH, expand=True, pady=(10, 0))
scrollbar = ttk.Scrollbar(text_frame, orient=tk.VERTICAL)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
text = tk.Text(text_frame, wrap=tk.NONE, yscrollcommand=scrollbar.set)
text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar.config(command=text.yview)
if details:
text.insert(tk.END, details)
text.config(state=tk.DISABLED)
btns = ttk.Frame(top)
btns.pack(fill=tk.X, pady=(10, 0))
def copy_details():
try:
self.root.clipboard_clear()
self.root.clipboard_append(details or summary)
self.root.update()
except Exception:
pass
ttk.Button(btns, text="复制详情", command=copy_details).pack(side=tk.LEFT)
ttk.Button(btns, text="关闭", command=win.destroy).pack(side=tk.RIGHT)
try:
win.transient(self.root)
win.grab_set()
win.focus_force()
except Exception:
pass
def show_exception_dialog(self, title: str, exc: BaseException):
self.show_error_dialog(title=title, summary=str(exc), details=traceback.format_exc())
def create_widgets(self):
main_frame = ttk.Frame(self.root, padding="10")
main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5")
config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10))
ref_label = ttk.Label(config_frame, text="参考TIF文件:")
ref_label.grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
self.ref_tif_var = tk.StringVar(value=str(REF_TIF))
ref_entry = ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50)
ref_entry.grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
ref_btn = ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif)
ref_btn.grid(row=0, column=2)
self.enable_ref_mask_var = tk.BooleanVar(value=False)
ref_mask_chk = ttk.Checkbutton(
config_frame,
text="启用底图掩膜",
variable=self.enable_ref_mask_var,
command=self._on_toggle_ref_mask,
)
ref_mask_chk.grid(row=1, column=0, sticky=tk.W, padx=(0, 5))
self.ref_mask_tif_var = tk.StringVar(value="")
self.ref_mask_entry = ttk.Entry(config_frame, textvariable=self.ref_mask_tif_var, width=50, state=tk.DISABLED)
self.ref_mask_entry.grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
self.ref_mask_btn = ttk.Button(config_frame, text="选择文件", command=self.select_ref_mask_tif, state=tk.DISABLED)
self.ref_mask_btn.grid(row=1, column=2)
bip_label = ttk.Label(config_frame, text="BIP文件夹:")
bip_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5))
self.bip_dir_var = tk.StringVar(value=str(BIP_DIR))
bip_entry = ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50)
bip_entry.grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
bip_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir)
bip_btn.grid(row=2, column=2)
out_label = ttk.Label(config_frame, text="输出文件夹:")
out_label.grid(row=3, column=0, sticky=tk.W, padx=(0, 5))
self.out_dir_var = tk.StringVar(value=str(OUT_DIR))
out_entry = ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50)
out_entry.grid(row=3, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
out_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir)
out_btn.grid(row=3, column=2)
matcher_label = ttk.Label(config_frame, text="匹配算法:")
matcher_label.grid(row=4, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
self.matcher_var = tk.StringVar(value=str(MATCHER_NAME))
matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47)
matcher_combo['values'] = _MATCHER_VALUES
matcher_combo.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
device_label = ttk.Label(config_frame, text="设备:")
device_label.grid(row=5, column=0, sticky=tk.W, padx=(0, 5))
self.device_var = tk.StringVar(value=str(DEVICE))
device_frame = ttk.Frame(config_frame)
device_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E))
cuda_rb = ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda")
cpu_rb = ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu")
cuda_rb.pack(side=tk.LEFT)
cpu_rb.pack(side=tk.LEFT)
transform_label = ttk.Label(config_frame, text="变换方法 (按优先级):")
transform_label.grid(row=6, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
transform_frame = ttk.Frame(config_frame)
transform_frame.grid(row=6, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False)
transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps"]
for method in transform_methods:
self.transform_listbox.insert(tk.END, method)
if method in TRANSFORM_METHODS:
self.transform_listbox.selection_set(transform_methods.index(method))
scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview)
self.transform_listbox.configure(yscrollcommand=scrollbar.set)
self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
button_frame = ttk.Frame(transform_frame)
button_frame.pack(side=tk.RIGHT, padx=(5, 0))
ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2))
ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X)
param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5")
param_frame.grid(row=7, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0))
match_max_side_label = ttk.Label(param_frame, text="匹配最大边长:")
match_max_side_label.grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
self.match_max_side_var = tk.IntVar(value=int(MATCH_MAX_SIDE))
match_max_side_entry = ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10)
match_max_side_entry.grid(row=0, column=1, sticky=tk.W)
roi_pad_label = ttk.Label(param_frame, text="ROI填充像素:")
roi_pad_label.grid(row=0, column=2, sticky=tk.W, padx=(10, 5))
self.roi_pad_px_var = tk.IntVar(value=int(ROI_PAD_PX))
roi_pad_entry = ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10)
roi_pad_entry.grid(row=0, column=3, sticky=tk.W)
mask_pad_label = ttk.Label(param_frame, text="掩膜膨胀像素:")
mask_pad_label.grid(row=0, column=4, sticky=tk.W, padx=(10, 5))
self.mask_pad_px_var = tk.IntVar(value=int(MASK_PAD_PX))
mask_pad_entry = ttk.Entry(param_frame, textvariable=self.mask_pad_px_var, width=10)
mask_pad_entry.grid(row=0, column=5, sticky=tk.W)
min_inliers_label = ttk.Label(param_frame, text="最少内点数:")
min_inliers_label.grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0))
self.min_inliers_var = tk.IntVar(value=int(MIN_INLIERS))
min_inliers_entry = ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10)
min_inliers_entry.grid(row=1, column=1, sticky=tk.W, pady=(5, 0))
min_ratio_label = ttk.Label(param_frame, text="最少内点比例:")
min_ratio_label.grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0))
self.min_inlier_ratio_var = tk.DoubleVar(value=float(MIN_INLIER_RATIO))
min_ratio_entry = ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10)
min_ratio_entry.grid(row=1, column=3, sticky=tk.W, pady=(5, 0))
feather_label = ttk.Label(param_frame, text="羽化像素:")
feather_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0))
self.feather_px_var = tk.IntVar(value=int(FEATHER_PX))
feather_entry = ttk.Entry(param_frame, textvariable=self.feather_px_var, width=10)
feather_entry.grid(row=2, column=1, sticky=tk.W, pady=(5, 0))
edge_band_label = ttk.Label(param_frame, text="边界剔除像素:")
edge_band_label.grid(row=2, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0))
self.edge_band_px_var = tk.IntVar(value=int(EDGE_BAND_PX))
edge_band_entry = ttk.Entry(param_frame, textvariable=self.edge_band_px_var, width=10)
edge_band_entry.grid(row=2, column=3, sticky=tk.W, pady=(5, 0))
grad_q_label = ttk.Label(param_frame, text="梯度分位阈值:")
grad_q_label.grid(row=2, column=4, sticky=tk.W, padx=(10, 5), pady=(5, 0))
self.min_grad_quantile_var = tk.DoubleVar(value=float(MIN_GRAD_QUANTILE))
grad_q_entry = ttk.Entry(param_frame, textvariable=self.min_grad_quantile_var, width=10)
grad_q_entry.grid(row=2, column=5, sticky=tk.W, pady=(5, 0))
self.add_tooltip(ref_label, "参考底图 GeoTIFF用于批量配准的目标坐标系与位置基准。建议确保 CRS、transform 正确。")
self.add_tooltip(ref_entry, "参考底图 GeoTIFF 路径。配准时会读取该底图的 ROI 进行匹配。")
self.add_tooltip(ref_btn, "选择参考底图 GeoTIFF 文件。")
self.add_tooltip(ref_mask_chk, "勾选后先用掩膜 TIF 对底图进行掩膜(掩膜值=1 的区域设置为 NoData并保存为新的底图后续配准使用掩膜后的底图。")
self.add_tooltip(self.ref_mask_entry, "掩膜 GeoTIFF 路径。要求与底图严格对齐(相同 CRS、分辨率、范围、尺寸否则会报错或效果不可控。")
self.add_tooltip(self.ref_mask_btn, "选择掩膜 GeoTIFF 文件。")
self.add_tooltip(bip_label, "包含待配准航带 .bip 文件的文件夹。程序会批量遍历 *.bip。")
self.add_tooltip(bip_entry, "BIP 文件夹路径。")
self.add_tooltip(bip_btn, "选择 BIP 文件夹。")
self.add_tooltip(out_label, "输出目录:配准后的航带、可视化图片、统计 CSV 等都会写到这里。")
self.add_tooltip(out_entry, "输出文件夹路径。")
self.add_tooltip(out_btn, "选择输出文件夹。")
self.add_tooltip(matcher_label, "特征匹配算法名称。不同 matcher 在精度、速度、鲁棒性上差异较大。")
self.add_tooltip(matcher_combo, "选择/输入 matcher 名称。若使用 cuda需要环境支持 GPU。")
self.add_tooltip(device_label, "运行设备cudaGPU更快cpu 更通用。")
self.add_tooltip(cuda_rb, "使用 GPUCUDA运行匹配器与部分计算。")
self.add_tooltip(cpu_rb, "使用 CPU 运行。速度可能较慢。")
self.add_tooltip(transform_label, "变换模型选择(可多选)。配准会按优先级尝试,并自动选择误差较小的模型。")
self.add_tooltip(self.transform_listbox, "按住 Ctrl/Shift 多选。右侧可上移/下移调整优先级。一般 homography 更灵活但更易发散affine 更稳定。")
self.add_tooltip(match_max_side_label, "匹配阶段会把图像等比缩小到最大边长不超过该值。值越大越慢,但细节更多。")
self.add_tooltip(match_max_side_entry, "匹配用降采样尺寸上限(像素)。")
self.add_tooltip(roi_pad_label, "参考底图 ROI 的额外扩展像素。增大可覆盖更大不确定区域,但会增加内存与耗时。")
self.add_tooltip(roi_pad_entry, "ROI padding像素参考底图坐标系")
self.add_tooltip(mask_pad_label, "仅用于匹配阶段:对源图有效掩膜/重投影后的掩膜做膨胀,增加可匹配区域。")
self.add_tooltip(mask_pad_entry, "掩膜膨胀像素(只影响匹配,不直接改变输出)。")
self.add_tooltip(min_inliers_label, "RANSAC 内点数量阈值。低于该值认为匹配质量不足,判定失败。")
self.add_tooltip(min_inliers_entry, "最少内点数。")
self.add_tooltip(min_ratio_label, "内点比例阈值(内点数/匹配点数)。过低通常意味着匹配不可靠。")
self.add_tooltip(min_ratio_entry, "最少内点比例。")
self.add_tooltip(feather_label, "对掩膜边缘做羽化,降低硬边缘带来的高对比假匹配。数值越大边缘过渡越宽。")
self.add_tooltip(feather_entry, "掩膜羽化宽度(像素)。")
self.add_tooltip(edge_band_label, "剔除距离掩膜边界过近的匹配点,减少边缘假匹配。数值越大剔除越多。")
self.add_tooltip(edge_band_entry, "边缘带剔除宽度(像素)。")
self.add_tooltip(grad_q_label, "纹理过滤分位阈值:梯度幅值低于该分位的区域视为低纹理,匹配点会被剔除。")
self.add_tooltip(grad_q_entry, "梯度分位阈值0~1")
control_frame = ttk.Frame(main_frame)
control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing)
self.start_btn.pack(side=tk.LEFT, padx=(0, 10))
self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED)
self.stop_btn.pack(side=tk.LEFT)
progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5")
progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
self.progress_var = tk.DoubleVar()
self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100)
self.progress_bar.pack(fill=tk.X, pady=(0, 5))
self.progress_label = ttk.Label(progress_frame, text="准备就绪")
self.progress_label.pack(anchor=tk.W)
log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5")
log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0))
log_text_frame = ttk.Frame(log_frame)
log_text_frame.pack(fill=tk.BOTH, expand=True)
self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD)
scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview)
self.log_text.configure(yscrollcommand=scrollbar.set)
self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
log_btn_frame = ttk.Frame(log_frame)
log_btn_frame.pack(fill=tk.X, pady=(5, 0))
ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5))
ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT)
self.root.columnconfigure(0, weight=1)
self.root.rowconfigure(0, weight=1)
main_frame.columnconfigure(1, weight=1)
main_frame.rowconfigure(3, weight=1)
def select_ref_tif(self):
filename = filedialog.askopenfilename(
title="选择参考TIF文件",
filetypes=[("TIF files", "*.tif;*.tiff"), ("All files", "*.*")]
)
if filename:
self.ref_tif_var.set(filename)
def select_ref_mask_tif(self):
filename = filedialog.askopenfilename(
title="选择掩膜TIF文件",
filetypes=[("TIF files", "*.tif;*.tiff"), ("All files", "*.*")]
)
if filename:
self.ref_mask_tif_var.set(filename)
def select_bip_dir(self):
dirname = filedialog.askdirectory(title="选择BIP文件夹")
if dirname:
self.bip_dir_var.set(dirname)
def select_out_dir(self):
dirname = filedialog.askdirectory(title="选择输出文件夹")
if dirname:
self.out_dir_var.set(dirname)
def move_up(self):
selection = self.transform_listbox.curselection()
if selection and selection[0] > 0:
idx = selection[0]
text = self.transform_listbox.get(idx)
self.transform_listbox.delete(idx)
self.transform_listbox.insert(idx - 1, text)
self.transform_listbox.selection_set(idx - 1)
def move_down(self):
selection = self.transform_listbox.curselection()
if selection and selection[0] < self.transform_listbox.size() - 1:
idx = selection[0]
text = self.transform_listbox.get(idx)
self.transform_listbox.delete(idx)
self.transform_listbox.insert(idx + 1, text)
self.transform_listbox.selection_set(idx + 1)
def start_processing(self):
if self.processing_thread and self.processing_thread.is_alive():
messagebox.showwarning("警告", "处理正在进行中")
return
selected_indices = self.transform_listbox.curselection()
if not selected_indices:
messagebox.showwarning("警告", "请至少选择一种变换方法")
return
transform_methods = [self.transform_listbox.get(i) for i in selected_indices]
cfg = RegistrationConfig(
ref_tif=self.ref_tif_var.get().strip(),
bip_dir=self.bip_dir_var.get().strip(),
out_dir=self.out_dir_var.get().strip(),
enable_ref_mask=bool(self.enable_ref_mask_var.get()),
ref_mask_tif=self.ref_mask_tif_var.get().strip(),
ref_mask_remove_value=1,
matcher_name=self.matcher_var.get().strip(),
device=self.device_var.get().strip(),
transform_methods=transform_methods,
match_max_side=int(self.match_max_side_var.get()),
roi_pad_px=int(self.roi_pad_px_var.get()),
mask_pad_px=int(self.mask_pad_px_var.get()),
min_inliers=int(self.min_inliers_var.get()),
min_inlier_ratio=float(self.min_inlier_ratio_var.get()),
feather_px=int(self.feather_px_var.get()),
edge_band_px=int(self.edge_band_px_var.get()),
min_grad_quantile=float(self.min_grad_quantile_var.get()),
)
if not Path(cfg.ref_tif).exists():
messagebox.showerror("错误", "参考 TIF 不存在")
return
if not Path(cfg.bip_dir).exists():
messagebox.showerror("错误", "BIP 文件夹不存在")
return
if not cfg.out_dir:
messagebox.showerror("错误", "输出文件夹不能为空")
return
if cfg.enable_ref_mask:
if not TIF_MASK_AVAILABLE:
messagebox.showerror("错误", "tif_caijain.py 不可用,无法进行底图掩膜")
return
if not cfg.ref_mask_tif or not Path(cfg.ref_mask_tif).exists():
messagebox.showerror("错误", "已启用底图掩膜,但掩膜 TIF 文件不存在")
return
self.stop_event.clear()
self.start_btn.config(state=tk.DISABLED)
self.stop_btn.config(state=tk.NORMAL)
self.progress_var.set(0)
self.progress_label.config(text="正在初始化...")
self.processing_thread = threading.Thread(
target=self.run_processing,
args=(cfg,),
daemon=True
)
self.processing_thread.start()
def _on_toggle_ref_mask(self):
    """Enable or disable the reference-mask entry and button to match the checkbox.

    Any error raised while reconfiguring the widgets is deliberately ignored
    (best effort).
    """
    if bool(self.enable_ref_mask_var.get()):
        widget_state = tk.NORMAL
    else:
        widget_state = tk.DISABLED
    try:
        # Entry first, then button: if the entry fails the button is left untouched.
        self.ref_mask_entry.configure(state=widget_state)
        self.ref_mask_btn.configure(state=widget_state)
    except Exception:
        pass
def stop_processing(self):
    """Signal a running worker thread to stop and show a "stopping" status.

    Does nothing when there is no processing thread or it is no longer alive.
    """
    worker = self.processing_thread
    if not worker or not worker.is_alive():
        return
    self.stop_event.set()
    self.progress_label.config(text="正在停止...")
def run_processing(self, cfg: RegistrationConfig):
    """Run the batch for ``cfg`` and restore the UI state afterwards.

    This is the target of the thread created in ``start_processing``.
    Progress is forwarded to ``on_progress``. On failure the message and
    traceback are put on ``log_queue`` and an error dialog is scheduled on
    the Tk event loop. In every case the start/stop buttons are reset and
    the status label is set to the outcome.

    Args:
        cfg: Configuration handed unchanged to ``_run_batch``.
    """
    failed = False
    try:
        _run_batch(cfg, self.stop_event, progress_cb=self.on_progress)
    except Exception as e:
        failed = True
        tb = traceback.format_exc()
        # Capture the text now: Python unbinds ``e`` when this except clause
        # ends, but the scheduled callback runs later on the event loop and
        # would otherwise fail with NameError instead of showing the dialog.
        err_text = str(e)
        self.log_queue.put(f"处理过程中发生错误: {err_text}\n{tb}")
        try:
            self.root.after(0, lambda: self.show_error_dialog("处理失败", err_text, tb))
        except Exception:
            # Best effort: the failure is already in the log queue, so a
            # scheduling error must not hide the UI reset below.
            pass
    finally:
        # Do not report "complete" after a failure.
        final_text = "处理失败" if failed else "处理完成"
        self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL))
        self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED))
        self.root.after(0, lambda: self.progress_label.config(text=final_text))
def on_progress(self, current, total, filename):
    """Schedule progress-bar and status-label updates on the Tk event loop.

    Args:
        current: Number of items handled so far.
        total: Total number of items; the bar is only updated when this is > 0.
        filename: Name shown in the status label; omitted from the text when falsy.
    """
    schedule = self.root.after

    if total > 0:
        percent = (current / total) * 100
        schedule(0, lambda: self.progress_var.set(percent))

    if filename:
        status = f"处理中: {filename} ({current}/{total})"
    else:
        status = f"处理中: ({current}/{total})"
    schedule(0, lambda: self.progress_label.config(text=status))
def check_log_queue(self):
    """Move every pending message from ``log_queue`` into the log text widget.

    Each message is appended with a trailing newline and the widget is
    scrolled to the end. The method then re-schedules itself to run again
    in 100 ms.
    """
    while True:
        try:
            entry = self.log_queue.get_nowait()
        except queue.Empty:
            break  # queue drained for this tick
        self.log_text.insert(tk.END, entry + '\n')
        self.log_text.see(tk.END)
    self.root.after(100, self.check_log_queue)
def clear_log(self):
    """Delete the whole content of the log text widget."""
    first, last = 1.0, tk.END
    self.log_text.delete(first, last)
def save_log(self):
    """Ask for a destination file and write the log widget's text to it as UTF-8.

    Nothing is written when the save dialog returns an empty value. A failed
    write is reported in an error dialog rather than raised to the caller.
    """
    filename = filedialog.asksaveasfilename(
        title="保存日志",
        defaultextension=".txt",
        filetypes=[("Text files", "*.txt"), ("All files", "*.*")]
    )
    if not filename:
        return
    # Read the widget before opening the file, so a failed read cannot leave
    # a truncated file behind.
    content = self.log_text.get(1.0, tk.END)
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(content)
    except OSError as exc:
        # Permission denied, missing directory, full disk and similar.
        messagebox.showerror("错误", f"保存日志失败: {exc}")
def create_gui():
    """Create the Tk root window, build the registration GUI on it and run the event loop."""
    window = tk.Tk()
    # Constructed for its side effects; the instance is not kept here.
    RegistrationGUI(window)
    window.mainloop()
# ---------- 主逻辑 ----------
def main():
    """Run one batch registration from the module-level constants, without the GUI.

    The reference mask is switched off and a fresh, never-set stop event is
    passed to ``_run_batch``.
    """
    settings = {
        "ref_tif": str(REF_TIF),
        "bip_dir": str(BIP_DIR),
        "out_dir": str(OUT_DIR),
        "enable_ref_mask": False,
        "ref_mask_tif": "",
        "ref_mask_remove_value": 1,
        "matcher_name": str(MATCHER_NAME),
        "device": str(DEVICE),
        "transform_methods": list(TRANSFORM_METHODS),
        "match_max_side": int(MATCH_MAX_SIDE),
        "roi_pad_px": int(ROI_PAD_PX),
        "mask_pad_px": int(MASK_PAD_PX),
        "min_inliers": int(MIN_INLIERS),
        "min_inlier_ratio": float(MIN_INLIER_RATIO),
        "feather_px": int(FEATHER_PX),
        "edge_band_px": int(EDGE_BAND_PX),
        "min_grad_quantile": float(MIN_GRAD_QUANTILE),
    }
    batch_cfg = RegistrationConfig(**settings)
    _run_batch(batch_cfg, threading.Event())
if __name__ == "__main__":
    # "--cli" anywhere on the command line selects the headless batch run;
    # otherwise the Tk GUI is started.
    entry_point = main if "--cli" in sys.argv else create_gui
    entry_point()