2218 lines
98 KiB
Python
2218 lines
98 KiB
Python
"""
|
||
批量配准 .bip 文件到参考 .tif 文件
|
||
问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
# PyInstaller windowed (GUI) builds can start with sys.stdout / sys.stderr set to None.
# Anything that prints (e.g. PyTorch download progress) would then fail with
# "'NoneType' object has no attribute 'write'", so point the missing streams at os.devnull.
for _stream_attr in ("stdout", "stderr"):
    if getattr(sys, _stream_attr) is None:
        setattr(sys, _stream_attr, open(os.devnull, 'w'))
del _stream_attr
|
||
|
||
from pathlib import Path
|
||
|
||
|
||
def _early_pyinstaller_hf_env():
    """Point the HuggingFace cache at a bundled ``hub`` directory (PyInstaller builds only).

    Must run before ``import vismatch``: vismatch/__init__.py imports huggingface_hub
    immediately (presumably huggingface_hub picks these variables up when it is imported).

    Outside a PyInstaller bundle (no ``sys._MEIPASS``) this does nothing. Otherwise the
    first candidate hub directory that contains a sub-directory whose name includes
    "vismatch" is used: HF_HOME / HUGGINGFACE_HUB_CACHE / TRANSFORMERS_OFFLINE are only set
    if not already defined, while HF_HUB_OFFLINE is always forced to "1".
    """
    if not hasattr(sys, "_MEIPASS"):
        return

    bundle_dir = Path(sys._MEIPASS)
    app_dir = Path(sys.executable).resolve().parent
    search_order = (
        bundle_dir / "hub",
        bundle_dir / "_internal" / "hub",
        app_dir / "_internal" / "hub",
        app_dir / "hub",
    )

    def _holds_vismatch_models(hub_dir):
        # Unreadable / vanished directories are simply skipped.
        try:
            return hub_dir.exists() and any(
                "vismatch" in entry.name.lower()
                for entry in hub_dir.iterdir()
                if entry.is_dir()
            )
        except OSError:
            return False

    for hub_dir in search_order:
        if _holds_vismatch_models(hub_dir):
            break
    else:
        return

    os.environ.setdefault("HF_HOME", str(hub_dir.parent))
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hub_dir))
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")


_early_pyinstaller_hf_env()
|
||
|
||
import numpy as np
|
||
import cv2
|
||
import rasterio
|
||
import csv
|
||
from datetime import datetime
|
||
from rasterio.windows import from_bounds
|
||
from rasterio.warp import transform_bounds, reproject, Resampling
|
||
from affine import Affine
|
||
from vismatch import get_matcher
|
||
from vismatch.viz import plot_matches, plot_keypoints
|
||
import logging
|
||
import threading
|
||
import queue
|
||
import sys
|
||
import traceback
|
||
import types
|
||
from dataclasses import dataclass
|
||
import tkinter as tk
|
||
from tkinter import ttk, filedialog, messagebox
|
||
|
||
try:
|
||
from tif_caijain import mask_data_by_binary_mask
|
||
TIF_MASK_AVAILABLE = True
|
||
except Exception:
|
||
TIF_MASK_AVAILABLE = False
|
||
|
||
try:
|
||
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
|
||
SKIMAGE_AVAILABLE = True
|
||
except ImportError:
|
||
SKIMAGE_AVAILABLE = False
|
||
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
|
||
|
||
try:
|
||
from matplotlib.path import Path as MplPath
|
||
from scipy.spatial import ConvexHull
|
||
MATPLOTLIB_SCIPY_AVAILABLE = True
|
||
except ImportError:
|
||
MATPLOTLIB_SCIPY_AVAILABLE = False
|
||
MplPath = None
|
||
logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断")
|
||
|
||
try:
|
||
import SimpleITK as sitk
|
||
SITK_AVAILABLE = True
|
||
except ImportError:
|
||
SITK_AVAILABLE = False
|
||
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
|
||
|
||
|
||
try:
|
||
import pirt
|
||
PIRT_AVAILABLE = True
|
||
except ImportError:
|
||
PIRT_AVAILABLE = False
|
||
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
|
||
|
||
try:
|
||
from scipy.interpolate import Rbf
|
||
SCIPY_AVAILABLE = True
|
||
except ImportError:
|
||
SCIPY_AVAILABLE = False
|
||
logging.warning("scipy 不可用,将跳过 TPS 变换")
|
||
|
||
|
||
# Configure logging.
# force=True is required: when an optional dependency is missing, the try/except import
# fallbacks earlier in this module call logging.warning(). That implicitly configures the
# root logger (WARNING level, default format), which turns a plain basicConfig() into a
# no-op, so every logger.info(...) below would be silently dropped. force=True (Python 3.8+)
# replaces any root handlers that already exist.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)

logger = logging.getLogger(__name__)
|
||
|
||
def _first_existing_path(paths):
    """Return the first entry of *paths* that exists on disk, or ``None``."""
    for path in paths:
        if path.exists():
            return path
    return None


def _search_matchanything_root(third_party_base):
    """Recursively look under *third_party_base* for a MatchAnything source root.

    A directory qualifies when it has a ``src`` entry. Directories whose path contains
    "matchanything" (case-insensitive) take precedence over everything else. Only if none
    exists anywhere in the tree is the first directory (in ``os.walk`` order) whose ``src``
    holds at least one ``*.py`` file returned as a fallback. The previous single-pass
    loop let such a generic ``src`` directory that happened to be visited first pre-empt
    the real MatchAnything directory.

    Returns ``(root_path, is_named_match)``, or ``(None, False)`` when nothing qualifies.
    """
    fallback = None
    for root, _dirs, _files in os.walk(third_party_base):
        root_path = Path(root)
        src_dir = root_path / "src"
        if not src_dir.exists():
            continue
        if "matchanything" in root.lower():
            return root_path, True
        if fallback is None and any(src_dir.glob("*.py")):
            fallback = root_path
    return fallback, False


def _prepend_to_sys_path(directory, label):
    """Insert *directory* at the front of ``sys.path`` unless it is already present."""
    entry = str(directory)
    if entry not in sys.path:
        sys.path.insert(0, entry)
        logger.info(f"已添加 {label} 到 sys.path: {entry}")


def _configure_bundled_hf_cache(base, exe_dir):
    """Fill HuggingFace cache env vars from a bundled hub directory (setdefault only).

    Uses the first candidate hub directory containing a sub-directory whose name includes
    "vismatch". Because only ``os.environ.setdefault`` is used, values already set (e.g. by
    ``_early_pyinstaller_hf_env``) are kept.
    """
    hf_candidates = [
        base / "hub",
        base / "_internal" / "hub",
        exe_dir / "_internal" / "hub",
        exe_dir / "hub",
    ]
    for hf_candidate in hf_candidates:
        try:
            if not hf_candidate.exists():
                continue
            if not any("vismatch" in d.name.lower() for d in hf_candidate.iterdir() if d.is_dir()):
                continue
        except OSError:
            continue
        os.environ.setdefault("HF_HOME", str(hf_candidate.parent))
        os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hf_candidate))
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
        logger.info(
            f"HuggingFace 缓存: {os.environ.get('HUGGINGFACE_HUB_CACHE')} "
            f"(HF_HUB_OFFLINE={os.environ.get('HF_HUB_OFFLINE')})"
        )
        break


def _ensure_pyinstaller_third_party_paths():
    """Make the MatchAnything / ROMA sources bundled by PyInstaller importable.

    No-op outside a PyInstaller bundle (no ``sys._MEIPASS``). Otherwise:

    1. locate vismatch's ``third_party`` directory (bundle dir or next to the executable);
    2. find the MatchAnything root (a directory with a ``src`` sub-directory), first via
       known layouts, then via a recursive search, and prepend it to ``sys.path``;
    3. prepend the ROMA directory to ``sys.path`` when one is found (it ends up ahead of
       MatchAnything);
    4. fill HuggingFace cache env vars from a bundled hub directory (setdefault only).

    Steps 2-4 are skipped when no third_party directory is found, and steps 3-4 are skipped
    when no MatchAnything root is found.
    """
    if not hasattr(sys, "_MEIPASS"):
        return

    base = Path(sys._MEIPASS)
    exe_dir = Path(sys.executable).resolve().parent

    third_party_base = _first_existing_path([
        base / "vismatch" / "third_party",
        base / "_internal" / "vismatch" / "third_party",
        exe_dir / "_internal" / "vismatch" / "third_party",
        exe_dir / "vismatch" / "third_party",
        base / "third_party",  # in case vismatch is included directly
    ])
    if third_party_base is None:
        logger.warning(f"未找到 third_party 目录,MEIPASS={base}, exe_dir={exe_dir}")
        # List what is available to help debug the bundle layout.
        try:
            if base.exists():
                logger.info(f"MEIPASS 内容: {list(base.iterdir())[:10]}")
            if exe_dir.exists():
                logger.info(f"exe_dir 内容: {list(exe_dir.iterdir())[:10]}")
        except Exception as e:
            logger.warning(f"无法列出目录内容: {e}")
        return
    logger.info(f"找到 third_party 目录: {third_party_base}")

    # Known layouts for the MatchAnything sources, most specific first.
    matchanything_candidates = [
        third_party_base / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
        third_party_base / "MatchAnything",
        third_party_base / "MatchAnything" / "MatchAnything",
        third_party_base.parent / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
    ]
    matchanything_root = None
    for candidate in matchanything_candidates:
        if candidate.exists() and (candidate / "src").exists():
            matchanything_root = candidate
            logger.info(f"找到 MatchAnything 根目录: {matchanything_root}")
            break

    if matchanything_root is None:
        logger.warning("未找到 MatchAnything 目录,尝试的路径:")
        for c in matchanything_candidates:
            logger.warning(f" - {c} (exists={c.exists()})")
        # Last resort: walk the whole third_party tree.
        try:
            found, is_named = _search_matchanything_root(third_party_base)
        except Exception as e:
            logger.warning(f"递归搜索失败: {e}")
        else:
            if found is not None:
                matchanything_root = found
                if is_named:
                    logger.info(f"通过递归搜索找到 MatchAnything: {matchanything_root}")
                else:
                    logger.info(f"通过递归搜索找到潜在 MatchAnything: {matchanything_root}")

    if matchanything_root is None:
        return

    # The MatchAnything root contains the importable 'src' package.
    _prepend_to_sys_path(matchanything_root, "MatchAnything")

    roma_root = _first_existing_path([
        matchanything_root / "third_party" / "ROMA",
        third_party_base / "ROMA",
        third_party_base / "MatchAnything" / "third_party" / "ROMA",
        matchanything_root.parent / "ROMA",
    ])
    if roma_root is not None:
        logger.info(f"找到 ROMA 目录: {roma_root}")
        _prepend_to_sys_path(roma_root, "ROMA")
    else:
        logger.warning("未找到 ROMA 目录")

    # The HF cache is normally configured by _early_pyinstaller_hf_env() before vismatch
    # is imported; this only fills in anything still missing.
    _configure_bundled_hf_cache(base, exe_dir)
|
||
|
||
def _install_loguru_stub_if_missing():
    """Install a minimal stand-in ``loguru`` module when the real package cannot be imported.

    The stub is registered as ``sys.modules["loguru"]``. Its ``logger`` forwards to the
    standard-library logger named "loguru". Messages are rendered the way loguru does it:
    ``str.format`` with the call's positional/keyword arguments, and only when arguments are
    given. The message is then passed to std logging as a literal (``"%s"``), so braces or
    percent signs in the text cannot break formatting.
    """
    try:
        import loguru  # noqa: F401
        return
    except Exception:
        pass

    py_logger = logging.getLogger("loguru")

    def _render(msg, args, kwargs):
        text = str(msg)
        if not args and not kwargs:
            return text
        try:
            return text.format(*args, **kwargs)
        except (IndexError, KeyError, ValueError, AttributeError):
            # Not a usable format string for these arguments: keep everything visible.
            return text + "".join(f" {a}" for a in args)

    class _StubLogger:
        def _emit(self, level, msg, args, kwargs):
            py_logger.log(level, "%s", _render(msg, args, kwargs))

        def trace(self, msg, *args, **kwargs):
            self._emit(logging.DEBUG, msg, args, kwargs)

        def debug(self, msg, *args, **kwargs):
            self._emit(logging.DEBUG, msg, args, kwargs)

        def info(self, msg, *args, **kwargs):
            self._emit(logging.INFO, msg, args, kwargs)

        def success(self, msg, *args, **kwargs):
            self._emit(logging.INFO, msg, args, kwargs)

        def warning(self, msg, *args, **kwargs):
            self._emit(logging.WARNING, msg, args, kwargs)

        def error(self, msg, *args, **kwargs):
            self._emit(logging.ERROR, msg, args, kwargs)

        def critical(self, msg, *args, **kwargs):
            self._emit(logging.CRITICAL, msg, args, kwargs)

        def exception(self, msg, *args, **kwargs):
            # logging.Logger.exception attaches the active exception's traceback.
            py_logger.exception("%s", _render(msg, args, kwargs))

        # Configuration helpers are accepted and ignored.
        def add(self, *args, **kwargs):
            return 0

        def remove(self, *args, **kwargs):
            return None

        def opt(self, *args, **kwargs):
            return self

        def bind(self, *args, **kwargs):
            return self

    m = types.ModuleType("loguru")
    m.logger = _StubLogger()
    sys.modules["loguru"] = m
|
||
|
||
# ---------- Configuration ----------
# Adjust these paths to your environment.

# Input / output locations
REF_TIF = r"E:\is2\dingshanhu\mask_water.tif"      # reference .tif
BIP_DIR = Path(r"E:\is2\dingshanhu")               # folder holding the .bip files
OUT_DIR = Path(r"E:\is2\dingshanhu\output")        # output folder

# Feature matcher (other names per the original notes: xfeat-star, loftr, roma,
# superpoint-lightglue, sift-lightglue, ...)
MATCHER_NAME = "matchanything-roma"
DEVICE = "cuda"  # or "cpu"

# Geometric transform models, tried in this order of priority.
# Names listed in the original notes: "similarity", "affine", "homography",
# "piecewise_affine", "polynomial", "polynomial_order3", "tps".
# NOTE(review): check that estimate_transform() accepts the chosen names.
TRANSFORM_METHODS = ["similarity", "affine", "homography"]

# Matching
MATCH_MAX_SIDE = 1200   # longest side (px) of the images handed to the matcher
ROI_PAD_PX = 500        # padding (reference-tif px) around the coarse reference window
MASK_PAD_PX = 100       # dilation (px) of the matching mask; used only for matching

# Quality-control thresholds
MIN_INLIERS = 10
MIN_INLIER_RATIO = 0.01

# Mask-edge feathering and edge-band filtering
FEATHER_PX = 20         # feather width (px) of the soft mask, at full / ROI resolution
EDGE_BAND_PX = 30       # drop matches closer than this (px) to the mask border; rescaled for the small images

# Texture filtering: matches on pixels whose gradient magnitude is below this quantile
# (0..1) are treated as low-texture and dropped.
MIN_GRAD_QUANTILE = 0.20

# Statistics output locations (unset here).
STATS_DIR = STATS_CSV = None
|
||
|
||
|
||
@dataclass()
class RegistrationConfig:
    """Settings for one batch registration run.

    Field names correspond to the module-level configuration constants. The field order
    defines the positional order of the generated ``__init__``, so do not reorder fields.
    """

    # Input / output locations
    ref_tif: str                 # reference .tif path
    bip_dir: str                 # folder with the .bip files
    out_dir: str                 # output folder

    # Optional reference mask
    enable_ref_mask: bool
    ref_mask_tif: str
    ref_mask_remove_value: int   # NOTE(review): presumably the mask value whose pixels are removed — confirm

    # Matcher selection
    matcher_name: str
    device: str
    transform_methods: list      # presumably tried in order, like TRANSFORM_METHODS

    # Matching and quality-control parameters
    match_max_side: int
    roi_pad_px: int
    mask_pad_px: int
    min_inliers: int
    min_inlier_ratio: float
    feather_px: int
    edge_band_px: int
    min_grad_quantile: float
|
||
|
||
# ---------- 工具函数 ----------
|
||
def init_stats_csv(csv_path: Path):
    """Create the statistics CSV with its header row, unless the file already exists.

    An existing file is left untouched, so later rows keep being appended to it.
    """
    if csv_path.exists():
        return
    header = [
        'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio',
        'selected_method', 'median_error', 'p95_error', 'success',
    ]
    with open(csv_path, 'w', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow(header)
|
||
|
||
def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int,
                           inlier_ratio: float, selected_method: str, median_error: float,
                           p95_error: float, success: bool):
    """Append one registration result as a row of the statistics CSV.

    The row starts with the current local time ('%Y-%m-%d %H:%M:%S'). The inlier ratio and
    both error values are written with four decimals.
    """
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(csv_path, 'a', newline='', encoding='utf-8') as handle:
        row = [
            stamp,
            filename,
            num_inliers,
            num_matches,
            f"{inlier_ratio:.4f}",
            selected_method,
            f"{median_error:.4f}",
            f"{p95_error:.4f}",
            success,
        ]
        csv.writer(handle).writerow(row)
|
||
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
|
||
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
|
||
arr = arr_chw.astype(np.float32)
|
||
|
||
if arr.shape[0] == 1:
|
||
# 单波段复制为3通道
|
||
arr = np.repeat(arr, 3, axis=0)
|
||
elif arr.shape[0] >= 3:
|
||
# 取前3波段
|
||
arr = arr[:3]
|
||
else:
|
||
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
|
||
|
||
# 百分位数拉伸,增强跨传感器匹配稳定性
|
||
p2 = np.percentile(arr, 2)
|
||
p98 = np.percentile(arr, 98)
|
||
arr = (arr - p2) / (p98 - p2 + 1e-6)
|
||
arr = np.clip(arr, 0.0, 1.0)
|
||
return arr
|
||
|
||
def _downscaled_size(h: int, w: int, max_side: int):
    """Return (new_h, new_w) after shrinking (h, w) uniformly so the longer side is *max_side*.

    Each side is rounded to the nearest integer but never below 1 pixel, so very elongated
    inputs (e.g. a 1 x 5000 strip) do not collapse to a zero-sized image that cv2.resize rejects.
    """
    scale = max_side / max(h, w)
    return max(1, int(round(h * scale))), max(1, int(round(w * scale)))


def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
    """Shrink a (C,H,W) array uniformly so that max(H, W) <= max_side.

    Arrays that already fit (including zero-sized ones) are returned unchanged, as the same
    object. Otherwise every channel is resized with cv2.INTER_AREA and the channels are
    stacked back into (C, new_H, new_W).
    """
    c, h, w = arr_chw.shape
    if max(h, w) <= max_side:
        return arr_chw
    new_h, new_w = _downscaled_size(h, w, max_side)
    # cv2.resize takes dsize as (width, height) and works per 2-D channel.
    return np.stack(
        [cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)],
        axis=0,
    )
|
||
|
||
def _expand_window(win, pad, max_w, max_h):
    """Grow a rasterio window by *pad* pixels on every side, clipped to a max_w x max_h raster.

    Start offsets are clamped at 0 and end coordinates at max_w / max_h, then truncated to
    int. If *win* lies outside the raster, the returned width/height can be zero or
    negative, so callers must check for that.
    """
    col_start, row_start = (int(max(0, off - pad)) for off in (win.col_off, win.row_off))
    col_stop = int(min(max_w, win.col_off + win.width + pad))
    row_stop = int(min(max_h, win.row_off + win.height + pad))
    return rasterio.windows.Window(col_start, row_start, col_stop - col_start, row_stop - row_start)
|
||
|
||
|
||
def _pixel_size_xy(transform: Affine):
    """Return the (x, y) ground size of one pixel implied by an affine geotransform.

    The sizes are the lengths of the column vector (a, d) and the row vector (b, e), so
    rotated grids are handled. A length that is non-finite or not positive falls back to
    |a| (for x) or |e| (for y), or to 1.0 when that coefficient is 0.
    """
    def _axis_length(u, v, fallback_coeff):
        length = float(np.hypot(u, v))
        if np.isfinite(length) and length > 0:
            return length
        return float(abs(fallback_coeff)) if fallback_coeff != 0 else 1.0

    size_x = _axis_length(transform.a, transform.d, transform.a)
    size_y = _axis_length(transform.b, transform.e, transform.e)
    return size_x, size_y
|
||
|
||
|
||
def _grid_from_bounds(bounds, res_x: float, res_y: float):
    """Build a north-up output grid that covers *bounds* at the given resolution.

    bounds: (left, bottom, right, top). The signs of the resolutions are ignored.
    Returns (transform, width, height). Width and height are rounded up so the grid covers
    the whole extent and are at least 1. The transform's origin is the top-left corner.
    """
    left, bottom, right, top = map(float, bounds)
    step_x, step_y = abs(float(res_x)), abs(float(res_y))
    # Guard against a zero resolution.
    n_cols = max(1, int(np.ceil((right - left) / max(1e-12, step_x))))
    n_rows = max(1, int(np.ceil((top - bottom) / max(1e-12, step_y))))
    grid_transform = Affine(step_x, 0.0, left, 0.0, -step_y, top)
    return grid_transform, n_cols, n_rows
|
||
|
||
|
||
# Minimum number of correspondences each model needs to be determined at all.
# With fewer points cv2 raises, and least-squares polynomial fits become underdetermined:
# they reproduce the points exactly and would look "perfect" to evaluate_transform_quality.
_MIN_POINTS_FOR_METHOD = {
    "euclidean": 2,
    "similarity": 2,
    "affine": 3,
    "homography": 4,
    "piecewise_affine": 3,
    "polynomial": 6,          # 2nd order: 6 coefficients per output axis
    "polynomial_order3": 10,  # 3rd order: 10 coefficients per output axis
}


def _fit_rigid_2d(src, dst):
    """Least-squares rotation + translation (no scale) mapping *src* onto *dst*.

    src, dst: (N, 2) point arrays with N >= 2. Returns a 2x3 float64 matrix [R | t].
    Uses the closed-form 2-D Procrustes solution: the optimal angle maximises sum(d . R s)
    over the centred point sets.
    """
    src = np.asarray(src, dtype=np.float64)
    dst = np.asarray(dst, dtype=np.float64)
    src_c = src.mean(axis=0)
    dst_c = dst.mean(axis=0)
    s = src - src_c
    d = dst - dst_c
    sin_term = np.sum(s[:, 0] * d[:, 1] - s[:, 1] * d[:, 0])
    cos_term = np.sum(s[:, 0] * d[:, 0] + s[:, 1] * d[:, 1])
    theta = np.arctan2(sin_term, cos_term)
    c, sn = np.cos(theta), np.sin(theta)
    rot = np.array([[c, -sn], [sn, c]])
    t = dst_c - rot @ src_c
    return np.hstack([rot, t[:, None]])


def estimate_transform(method, k0, k1):
    """Estimate a transform that maps points k0 onto k1 with the requested model.

    Args:
        method: "translation", "euclidean", "similarity", "affine", "homography",
            "piecewise_affine", "polynomial" (2nd order) or "polynomial_order3".
        k0, k1: (N, 2) arrays of corresponding points.

    Returns:
        (kind, transform):
        * ("A", 2x3 matrix) for translation / euclidean / similarity / affine,
          or ("A", None) when estimation fails or there are too few points;
        * ("H", 3x3 matrix) for homography, or ("H", None) on failure / too few points;
        * ("piecewise" | "polynomial", skimage transform) for the scikit-image models;
        * (None, None) when scikit-image is unavailable, fitting fails, there are too few
          points, or "translation" gets no points.

    Raises:
        ValueError: unknown *method*.
    """
    if method == "translation":
        # Pure shift: mean displacement of the given (inlier) points.
        if len(k0) == 0:
            return None, None
        dx = np.mean(k1[:, 0] - k0[:, 0])
        dy = np.mean(k1[:, 1] - k0[:, 1])
        return "A", np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)

    min_points = _MIN_POINTS_FOR_METHOD.get(method)
    if min_points is None:
        raise ValueError(f"未知变换方法: {method}")
    enough_points = len(k0) >= min_points

    if method == "similarity":
        # Rotation + uniform scale + translation.
        if not enough_points:
            return "A", None
        A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "euclidean":
        # Rotation + translation with scale fixed to 1. estimateAffinePartial2D also fits a
        # scale, so it is only used for RANSAC inlier selection; the rigid model is then
        # re-fitted on the inliers.
        if not enough_points:
            return "A", None
        A_sim, inlier_mask = cv2.estimateAffinePartial2D(
            k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0
        )
        if A_sim is None:
            return "A", None
        keep = (np.ones(len(k0), dtype=bool) if inlier_mask is None
                else np.asarray(inlier_mask).ravel().astype(bool))
        if keep.sum() < 2:
            return "A", None
        return "A", _fit_rigid_2d(np.asarray(k0)[keep], np.asarray(k1)[keep])

    if method == "affine":
        # Full affine: rotation + anisotropic scale + shear + translation.
        if not enough_points:
            return "A", None
        A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
        return "A", A

    if method == "homography":
        # 8-DOF projective transform.
        if not enough_points:
            return "H", None
        H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0)
        return "H", H

    if method == "piecewise_affine":
        if not enough_points or not SKIMAGE_AVAILABLE:
            return None, None
        try:
            tform = PiecewiseAffineTransform()
            tform.estimate(k0, k1)
            return "piecewise", tform
        except Exception:
            return None, None

    # "polynomial" (order 2) or "polynomial_order3"
    if not enough_points or not SKIMAGE_AVAILABLE:
        return None, None
    order = 3 if method == "polynomial_order3" else 2
    try:
        tform = PolynomialTransform()
        tform.estimate(k0, k1, order=order)
        return "polynomial", tform
    except Exception:
        return None, None
|
||
|
||
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Score how well *transform* maps k0 onto k1 (reprojection error).

    transform_type: "A" (2x3 matrix), "H" (3x3 homography), or "piecewise" / "polynomial"
    (a callable that maps an (N, 2) array). Returns (median, 95th percentile) of the
    per-point Euclidean errors as Python floats. Returns (inf, inf) when there is no
    transform, no points, or an unrecognised type.
    """
    if transform is None or len(k0) == 0:
        return np.inf, np.inf

    if transform_type in ("piecewise", "polynomial"):
        predicted = transform(k0)
    elif transform_type in ("A", "H"):
        # Homogeneous points, shape (3, N).
        homog = np.hstack([k0, np.ones((k0.shape[0], 1), dtype=np.float32)]).T
        mapped = transform @ homog
        if transform_type == "H":
            # Perspective divide; the small epsilon avoids dividing by an exact zero.
            mapped /= (mapped[2:3, :] + 1e-6)
            predicted = mapped[:2, :].T
        else:
            predicted = mapped.T
    else:
        return np.inf, np.inf

    errors = np.sqrt(((predicted - k1) ** 2).sum(axis=1))
    return float(np.median(errors)), float(np.percentile(errors, 95))
|
||
|
||
def _norm01_hw(x: np.ndarray) -> np.ndarray:
    """Stretch a single band (H,W) linearly to [0, 1] between its 2nd and 98th percentiles.

    Used to make intensity-based registration more stable across sensors. Values outside
    the percentile range are clipped. The result is float32.
    """
    band = x.astype(np.float32, copy=False)
    lo = float(np.percentile(band, 2))
    hi = float(np.percentile(band, 98))
    span = hi - lo + 1e-6  # epsilon keeps constant images from dividing by zero
    return np.clip((band - lo) / span, 0.0, 1.0)
|
||
|
||
def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)):
    """Wrap a 2-D numpy array (H,W) as a float32 SimpleITK image in pixel coordinates.

    Spacing is 1, direction is the identity, and the origin is (origin_xy[0], origin_xy[1]),
    so SimpleITK physical coordinates are simply pixel coordinates offset by that origin.
    """
    image = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False))
    x0 = float(origin_xy[0])
    y0 = float(origin_xy[1])
    image.SetSpacing((1.0, 1.0))
    image.SetOrigin((x0, y0))
    image.SetDirection((1.0, 0.0, 0.0, 1.0))
    return image
|
||
|
||
def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10):
    """Crop window (min_x, min_y, width, height) around the target-side keypoints.

    The keypoint extent is expanded outward to whole pixels (floor / ceil), padded by *pad*
    on every side, and clipped to the reference raster [0, ref_w] x [0, ref_h].
    """
    xs = k1_global[:, 0]
    ys = k1_global[:, 1]
    left = max(0, int(np.floor(xs.min())) - pad)
    top = max(0, int(np.floor(ys.min())) - pad)
    right = min(ref_w, int(np.ceil(xs.max())) + pad)
    bottom = min(ref_h, int(np.ceil(ys.max())) + pad)
    return left, top, right - left, bottom - top
|
||
|
||
def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Resize a 2-D binary mask to (target_h, target_w) with nearest-neighbour sampling.

    The mask is cast to uint8 before resizing. Returns a bool array that is True where the
    resized value is non-zero.
    """
    as_uint8 = mask_hw.astype(np.uint8)
    # cv2.resize expects the destination size as (width, height).
    resized = cv2.resize(as_uint8, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
    return resized > 0
|
||
|
||
def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray:
    """Turn a binary mask into a soft alpha in [0, 1] to avoid hard edges.

    mask_hw: (H,W) bool/uint8, non-zero = valid. Returns None for a None mask. Alpha is
    the L2 distance to the nearest zero (invalid) pixel divided by *feather_px* and clipped
    to [0, 1], so it rises linearly from the mask border. With feather_px <= 0, alpha is
    simply 1 inside the mask and 0 outside. The result is float32 (H,W).
    """
    if mask_hw is None:
        return None
    binary = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255
    dist = cv2.distanceTransform(binary, distanceType=cv2.DIST_L2, maskSize=3)
    if feather_px > 0:
        return np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32)
    return (dist > 0).astype(np.float32)
|
||
|
||
def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray:
    """Keep-mask of pixels lying at least *min_dist_px* inside the valid region.

    mask_hw: (H,W), non-zero = valid. Returns None for a None mask. Otherwise returns a bool
    (H,W) array that is True where the L2 distance to the nearest invalid pixel is
    >= max(1, min_dist_px).
    """
    if mask_hw is None:
        return None
    binary = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255
    dist = cv2.distanceTransform(binary, distanceType=cv2.DIST_L2, maskSize=3)
    min_dist = float(max(1, min_dist_px))
    return dist >= min_dist
|
||
|
||
def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray:
    """Texture mask (H,W): True where the gradient magnitude reaches the given quantile.

    img_chw is the same-size (C,H,W) image that is fed to the matcher. Its channel mean is
    used as grey level, and the 3x3 Sobel gradient magnitude is compared with its own
    *quantile*. Note: if the quantile value is 0 (e.g. large flat or zeroed-out areas),
    ``magnitude >= 0`` holds everywhere and every pixel is kept.
    """
    gray = np.mean(img_chw, axis=0).astype(np.float32)
    dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = np.sqrt(dx * dx + dy * dy)
    threshold = 0.0 if magnitude.size == 0 else float(np.quantile(magnitude, quantile))
    return magnitude >= threshold
|
||
|
||
def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict:
    """Restrict matches and inliers to keypoints that fall inside both masks.

    Args:
        result: matcher output dict. Filtered in place and also returned. The pairs
            ("matched_kpts0", "matched_kpts1") and ("inlier_kpts0", "inlier_kpts1") are
            processed when present, non-empty and of equal length. Points are (x, y) and
            are rounded to the nearest pixel, clamped to the mask bounds.
        src_mask_small: (H,W) mask for the kpts0 image; non-zero / True = keep.
        ref_mask_small: (H,W) mask for the kpts1 image; non-zero / True = keep.

    If either mask is None, *result* is returned unchanged. When the inliers are filtered,
    "num_inliers" is updated to the remaining count.
    """
    if src_mask_small is None or ref_mask_small is None:
        return result

    # Look up in boolean masks. With a uint8 0/1 mask, ``kpts[keep]`` would be integer
    # (row-number) indexing instead of filtering, duplicating rows and keeping the wrong points.
    src_mask = np.asarray(src_mask_small, dtype=bool)
    ref_mask = np.asarray(ref_mask_small, dtype=bool)

    def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray:
        if kpts is None or len(kpts) == 0:
            return np.zeros((0,), dtype=bool)
        kpts = np.asarray(kpts)
        xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
        ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
        return mask_hw[ys, xs]

    def filter_pair(key0: str, key1: str) -> bool:
        """Filter result[key0] / result[key1] in place; return True if filtering happened."""
        pts0 = np.asarray(result[key0])
        pts1 = np.asarray(result[key1])
        if len(pts0) != len(pts1) or len(pts0) == 0:
            return False
        keep = keep_in_mask(pts0, src_mask) & keep_in_mask(pts1, ref_mask)
        result[key0] = pts0[keep]
        result[key1] = pts1[keep]
        return True

    if "matched_kpts0" in result and "matched_kpts1" in result:
        filter_pair("matched_kpts0", "matched_kpts1")

    if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None:
        if filter_pair("inlier_kpts0", "inlier_kpts1"):
            result["num_inliers"] = int(len(result["inlier_kpts0"]))

    return result
|
||
|
||
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path):
|
||
"""处理单个 .bip 文件到参考 .tif 的配准"""
|
||
try:
|
||
with rasterio.open(bip_path) as src:
|
||
logger.info(f"处理文件: {bip_path.name}")
|
||
|
||
# 初始化统计变量
|
||
num_inliers = 0
|
||
num_matches = 0
|
||
inlier_ratio = 0.0
|
||
selected_method = "none"
|
||
median_error = float('inf')
|
||
p95_error = float('inf')
|
||
success = False
|
||
|
||
# 检查CRS
|
||
if src.crs is None:
|
||
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS")
|
||
src_crs = ref_dataset.crs
|
||
else:
|
||
src_crs = src.crs
|
||
|
||
ref_crs = ref_dataset.crs
|
||
if ref_crs is None:
|
||
raise RuntimeError(f"参考文件缺少CRS信息")
|
||
|
||
# 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠)
|
||
try:
|
||
src_mask = (src.read_masks(1) > 0) # True=有效
|
||
rows_any = np.any(src_mask, axis=1)
|
||
cols_any = np.any(src_mask, axis=0)
|
||
if rows_any.any() and cols_any.any():
|
||
rmin = int(rows_any.argmax())
|
||
rmax = int(src.height - 1 - rows_any[::-1].argmax())
|
||
cmin = int(cols_any.argmax())
|
||
cmax = int(src.width - 1 - cols_any[::-1].argmax())
|
||
valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1)
|
||
valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform)
|
||
b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21)
|
||
else:
|
||
# 掩膜无效时回退到整图bounds
|
||
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
|
||
except Exception:
|
||
src_mask = None # 后续可选源图掩膜时用到
|
||
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
|
||
|
||
win0 = from_bounds(*b, transform=ref_dataset.transform)
|
||
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
|
||
|
||
if win.width <= 0 or win.height <= 0:
|
||
logger.warning(f"无重叠区域: {bip_path.name}")
|
||
return False
|
||
|
||
# 2) 读取数据
|
||
# 读取所有波段,如果是多波段的话
|
||
src_arr = src.read() # (bands, H, W)
|
||
if src_arr.ndim == 2: # 单波段
|
||
src_arr = src_arr[None, ...] # 增加波段维度
|
||
|
||
# 读取参考文件的ROI
|
||
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
|
||
if ref_arr.ndim == 2: # 单波段
|
||
ref_arr = ref_arr[None, ...] # 增加波段维度
|
||
|
||
# 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜
|
||
try:
|
||
if src_mask is None:
|
||
src_mask = (src.read_masks(1) > 0)
|
||
ref_roi_transform = ref_dataset.window_transform(win)
|
||
roi_h, roi_w = int(win.height), int(win.width)
|
||
dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8)
|
||
|
||
reproject(
|
||
source=src_mask.astype(np.uint8),
|
||
destination=dst_mask,
|
||
src_transform=src.transform,
|
||
src_crs=src_crs,
|
||
dst_transform=ref_roi_transform,
|
||
dst_crs=ref_crs,
|
||
resampling=Resampling.nearest
|
||
)
|
||
|
||
if MASK_PAD_PX > 0:
|
||
k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size
|
||
k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除
|
||
kernel = np.ones((k, k), np.uint8)
|
||
dst_mask = cv2.dilate(dst_mask, kernel, iterations=1)
|
||
except Exception:
|
||
# 掩膜获取/重投影失败则不使用掩膜
|
||
dst_mask = None
|
||
|
||
# 转换为匹配所需的格式
|
||
src_img = _to_3ch_float01(src_arr)
|
||
ref_img = _to_3ch_float01(ref_arr)
|
||
|
||
# 软掩膜:避免在边界产生硬高对比边
|
||
try:
|
||
alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None
|
||
except Exception:
|
||
alpha_src = None
|
||
try:
|
||
alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None
|
||
except Exception:
|
||
alpha_ref = None
|
||
|
||
if alpha_src is not None:
|
||
alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype)
|
||
src_img = src_img * alpha_src3
|
||
|
||
if alpha_ref is not None:
|
||
alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype)
|
||
ref_img = ref_img * alpha_ref3
|
||
|
||
# 3) 匹配用降采样版本,提速 + 增稳
|
||
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
|
||
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
|
||
|
||
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
|
||
|
||
# 4) 精配准(img0=src, img1=ref_roi)
|
||
result = matcher(src_small, ref_small)
|
||
|
||
# 与小图同尺寸的掩膜
|
||
src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None
|
||
ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None
|
||
|
||
# 剔除掩膜边缘带(小图尺度的最小距离)
|
||
def _scale_px(px_full: int, full_wh, small_wh) -> int:
|
||
# 用平均缩放;也可以分别对H/W计算后取最小
|
||
sy = small_wh[0] / max(1, full_wh[0])
|
||
sx = small_wh[1] / max(1, full_wh[1])
|
||
s = 0.5 * (sx + sy)
|
||
return max(1, int(round(px_full * s)))
|
||
|
||
edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2]))
|
||
edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2]))
|
||
|
||
keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None
|
||
keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None
|
||
|
||
# 纹理掩膜
|
||
keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE)
|
||
keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE)
|
||
|
||
# 组合最终保留掩膜(边缘+纹理),二者都要满足
|
||
def _combine_keep(m_edge, m_tex):
|
||
if m_edge is None:
|
||
return m_tex
|
||
return (m_edge & m_tex)
|
||
|
||
keep_src_final = _combine_keep(keep_src_edge, keep_src_tex)
|
||
keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex)
|
||
|
||
# 将匹配与内点严格限制在最终掩膜内
|
||
def _filter_by_bool_masks(res, m_src, m_ref):
|
||
if m_src is None or m_ref is None:
|
||
return res
|
||
|
||
def keep_in(mask_hw, pts):
|
||
if pts is None or len(pts) == 0:
|
||
return np.zeros((0,), dtype=bool)
|
||
xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1)
|
||
ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1)
|
||
return mask_hw[ys, xs]
|
||
|
||
# matched
|
||
if "matched_kpts0" in res and "matched_kpts1" in res:
|
||
mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"])
|
||
if len(mk0) == len(mk1) and len(mk0) > 0:
|
||
keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1)
|
||
res["matched_kpts0"] = mk0[keep_m]
|
||
res["matched_kpts1"] = mk1[keep_m]
|
||
|
||
# inliers
|
||
if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None:
|
||
ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"])
|
||
if len(ik0) == len(ik1) and len(ik0) > 0:
|
||
keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1)
|
||
res["inlier_kpts0"] = ik0[keep_i]
|
||
res["inlier_kpts1"] = ik1[keep_i]
|
||
res["num_inliers"] = int(len(res["inlier_kpts0"]))
|
||
return res
|
||
|
||
result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final)
|
||
|
||
# 统计(以过滤后的结果为准)
|
||
num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", []))))
|
||
num_m = len(result.get("matched_kpts0", []))
|
||
ratio = (num_inl / num_m) if num_m else 0.0
|
||
|
||
# 更新统计变量
|
||
num_inliers = num_inl
|
||
num_matches = num_m
|
||
inlier_ratio = ratio
|
||
|
||
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
|
||
|
||
# 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式)
|
||
viz_dir = out_dir / "visualizations"
|
||
viz_dir.mkdir(exist_ok=True)
|
||
|
||
matches_path = viz_dir / f"{bip_path.stem}_matches.png"
|
||
plot_matches(src_small, ref_small, result, save_path=str(matches_path))
|
||
logger.info(f"匹配可视化已保存: {matches_path}")
|
||
|
||
# 关键点可视化(源图像)
|
||
kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png"
|
||
plot_keypoints(
|
||
src_small,
|
||
{"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]},
|
||
save_path=str(kpts_src_path)
|
||
)
|
||
logger.info(f"源图像关键点可视化已保存: {kpts_src_path}")
|
||
|
||
# 关键点可视化(参考图像)
|
||
kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png"
|
||
plot_keypoints(
|
||
ref_small,
|
||
{"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]},
|
||
save_path=str(kpts_ref_path)
|
||
)
|
||
logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}")
|
||
|
||
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
|
||
logger.warning(f"匹配质量不足: {bip_path.name}")
|
||
# 记录失败的统计信息
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, "failed_quality_check", median_error, p95_error, False)
|
||
return False
|
||
|
||
# 5) 用内点估计多种变换并自动选择最优
|
||
# 先计算全分辨率坐标
|
||
k0_small = result["inlier_kpts0"].astype(np.float32)
|
||
k1_small = result["inlier_kpts1"].astype(np.float32)
|
||
|
||
s0x = src_img.shape[2] / src_small.shape[2]
|
||
s0y = src_img.shape[1] / src_small.shape[1]
|
||
s1x = ref_img.shape[2] / ref_small.shape[2]
|
||
s1y = ref_img.shape[1] / ref_small.shape[1]
|
||
|
||
S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src)
|
||
S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI)
|
||
|
||
ones = np.ones((k0_small.shape[0], 1), dtype=np.float32)
|
||
k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素
|
||
k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素
|
||
k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素
|
||
|
||
|
||
# 用全分辨率坐标进行所有模型的估计和评估
|
||
best_transform = None
|
||
best_transform_type = None
|
||
best_error = np.inf
|
||
best_median_error = np.inf
|
||
best_method = None
|
||
|
||
for method in TRANSFORM_METHODS:
|
||
transform_type, transform = estimate_transform(method, k0_full, k1_global)
|
||
if transform is None:
|
||
continue
|
||
|
||
med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global)
|
||
|
||
# 选择重投影误差最小的变换
|
||
if p95_err < best_error:
|
||
best_transform = transform
|
||
best_transform_type = transform_type
|
||
best_error = p95_err
|
||
best_median_error = med_err
|
||
best_method = method
|
||
|
||
logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}")
|
||
|
||
if best_transform is None:
|
||
logger.warning(f"所有变换方法都失败: {bip_path.name}")
|
||
# 记录失败的统计信息
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, "failed_transform", median_error, p95_error, False)
|
||
return False
|
||
|
||
# 更新统计变量
|
||
selected_method = best_method
|
||
median_error = best_median_error
|
||
p95_error = best_error
|
||
|
||
logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}")
|
||
|
||
# 6) 根据变换类型进行相应的配准处理
|
||
if best_transform_type == "A":
|
||
# 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射
|
||
A = best_transform # 2x3, src_full_pixel -> ref_full_pixel
|
||
A3 = np.eye(3, dtype=np.float64)
|
||
A3[:2, :] = A
|
||
|
||
# src_pixel -> map
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0, 0, 1]], dtype=np.float64)
|
||
M_map = Rt @ A3
|
||
corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2],
|
||
M_map[1,0], M_map[1,1], M_map[1,2])
|
||
|
||
# 用 M_map 求最小外接矩形(先到 map,再到 ref 像素)
|
||
Rt_inv = np.linalg.inv(Rt)
|
||
src_h, src_w = src.height, src.width
|
||
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
|
||
corn_h = np.hstack([corners, np.ones((4,1))]).T
|
||
map_corners = (M_map @ corn_h).T[:, :2]
|
||
pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2]
|
||
|
||
min_x = int(np.floor(pix_corners[:,0].min())) - 10
|
||
max_x = int(np.ceil (pix_corners[:,0].max())) + 10
|
||
min_y = int(np.floor(pix_corners[:,1].min())) - 10
|
||
max_y = int(np.ceil (pix_corners[:,1].max())) + 10
|
||
|
||
min_x = max(0, min_x); min_y = max(0, min_y)
|
||
max_x = min(ref_dataset.width, max_x)
|
||
max_y = min(ref_dataset.height, max_y)
|
||
|
||
bbox_w = max_x - min_x
|
||
bbox_h = max_y - min_y
|
||
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
|
||
|
||
res_x, res_y = _pixel_size_xy(src.transform)
|
||
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = src.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=out_h,
|
||
width=out_w,
|
||
count=src.count,
|
||
transform=out_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
|
||
reproject(
|
||
source=src_band,
|
||
destination=dst_band,
|
||
src_transform=corrected_affine,
|
||
src_crs=ref_crs,
|
||
dst_transform=out_transform,
|
||
dst_crs=ref_crs,
|
||
src_nodata=src_nodata,
|
||
dst_nodata=dst_nodata,
|
||
resampling=Resampling.nearest,
|
||
)
|
||
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}")
|
||
success = True
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, selected_method, median_error, p95_error, success)
|
||
return True
|
||
|
||
# ---- 非仿射变换处理 ----
|
||
elif best_transform_type == "H":
|
||
# 单应变换:H 已是 src_full_pixel -> ref_full_pixel
|
||
H_full = best_transform # 3x3
|
||
|
||
try:
|
||
# 用 H_full 映射源四角 -> 参考像素,求最小外接矩形
|
||
src_h, src_w = src.height, src.width
|
||
corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32)
|
||
corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T
|
||
dst_h = (H_full @ corn_h)
|
||
dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T
|
||
|
||
min_x = int(np.floor(dst[:,0].min())) - 10
|
||
max_x = int(np.ceil (dst[:,0].max())) + 10
|
||
min_y = int(np.floor(dst[:,1].min())) - 10
|
||
max_y = int(np.ceil (dst[:,1].max())) + 10
|
||
|
||
min_x = max(0, min_x); min_y = max(0, min_y)
|
||
max_x = min(ref_dataset.width, max_x)
|
||
max_y = min(ref_dataset.height, max_y)
|
||
|
||
bbox_w = max_x - min_x
|
||
bbox_h = max_y - min_y
|
||
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
|
||
|
||
res_x, res_y = _pixel_size_xy(src.transform)
|
||
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = src.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=out_h,
|
||
width=out_w,
|
||
count=src.count,
|
||
transform=out_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array(
|
||
[[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0.0, 0.0, 1.0]],
|
||
dtype=np.float64,
|
||
)
|
||
Out = np.array(
|
||
[[out_transform.a, out_transform.b, out_transform.c],
|
||
[out_transform.d, out_transform.e, out_transform.f],
|
||
[0.0, 0.0, 1.0]],
|
||
dtype=np.float64,
|
||
)
|
||
M = np.linalg.inv(Out) @ Rt @ H_full.astype(np.float64)
|
||
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = cv2.warpPerspective(
|
||
src_band,
|
||
M,
|
||
(out_w, out_h),
|
||
flags=cv2.INTER_NEAREST,
|
||
borderMode=cv2.BORDER_CONSTANT,
|
||
borderValue=float(dst_nodata)
|
||
).astype(np.float32)
|
||
|
||
# 转回目标 dtype
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}")
|
||
success = True
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, selected_method, median_error, p95_error, success)
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.warning(f"单应变换异常: {e}")
|
||
# 继续到仿射回退
|
||
|
||
elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]:
|
||
# 分片仿射或多项式变换:使用 scikit-image
|
||
transform = best_transform # 已用 k0_full/k1_global 估计
|
||
try:
|
||
# 用目标侧匹配点(k1_global)决定外接矩形(更稳)
|
||
pad = 10
|
||
min_x = int(np.floor(k1_global[:, 0].min())) - pad
|
||
max_x = int(np.ceil (k1_global[:, 0].max())) + pad
|
||
min_y = int(np.floor(k1_global[:, 1].min())) - pad
|
||
max_y = int(np.ceil (k1_global[:, 1].max())) + pad
|
||
|
||
min_x = max(0, min_x)
|
||
min_y = max(0, min_y)
|
||
max_x = min(ref_dataset.width, max_x)
|
||
max_y = min(ref_dataset.height, max_y)
|
||
|
||
bbox_w = max_x - min_x
|
||
bbox_h = max_y - min_y
|
||
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
# 创建输出窗口
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bbox_transform = ref_dataset.window_transform(bbox_window)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = ref_dataset.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=bbox_h,
|
||
width=bbox_w,
|
||
count=src.count,
|
||
transform=bbox_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
# 定义带偏移的逆映射函数
|
||
off_x, off_y = min_x, min_y
|
||
|
||
if best_transform_type in ["polynomial", "polynomial_order3"]:
|
||
# 对于多项式,估计逆变换
|
||
order = 2 if best_transform_type == "polynomial" else 3
|
||
t_inv = PolynomialTransform()
|
||
t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源
|
||
|
||
# 目标侧点集的内点判定(用于限制外推)
|
||
if MATPLOTLIB_SCIPY_AVAILABLE:
|
||
try:
|
||
hull = ConvexHull(k1_global)
|
||
hull_path = MplPath(k1_global[hull.vertices])
|
||
except Exception:
|
||
rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y],
|
||
[min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float)
|
||
hull_path = MplPath(rect)
|
||
|
||
def point_inside(xy):
|
||
return hull_path.contains_points(xy)
|
||
else:
|
||
def point_inside(xy):
|
||
return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) &
|
||
(xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h))
|
||
|
||
def inv_map_rc(coords):
|
||
# coords: (N,2) in (row, col)
|
||
rc = np.asarray(coords)
|
||
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref
|
||
inside = point_inside(xy)
|
||
xy_src = np.full_like(xy, fill_value=-1.0)
|
||
if np.any(inside):
|
||
xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src
|
||
# 确保坐标在源图像范围内
|
||
xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1)
|
||
xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1)
|
||
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
|
||
elif best_transform_type == "piecewise": # piecewise_affine
|
||
# 目标侧点集的内点判定
|
||
if MATPLOTLIB_SCIPY_AVAILABLE:
|
||
try:
|
||
hull = ConvexHull(k1_global)
|
||
hull_path = MplPath(k1_global[hull.vertices])
|
||
except Exception:
|
||
# 使用当前裁剪窗口的边界创建矩形
|
||
rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float)
|
||
hull_path = MplPath(rect)
|
||
|
||
def point_inside(xy):
|
||
return hull_path.contains_points(xy)
|
||
else:
|
||
# 退化为矩形内判断
|
||
def point_inside(xy):
|
||
return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \
|
||
(xy[:,1] >= min_y) & (xy[:,1] <= max_y)
|
||
|
||
def inv_map_rc(coords):
|
||
rc = np.asarray(coords)
|
||
xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref
|
||
inside = point_inside(xy)
|
||
xy_src = np.full_like(xy, fill_value=-1.0)
|
||
if np.any(inside):
|
||
xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src)
|
||
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src)
|
||
|
||
# 使用 scikit-image 进行变换重采样
|
||
from skimage.transform import warp
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = warp(
|
||
src_band,
|
||
inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射
|
||
output_shape=(bbox_h, bbox_w),
|
||
mode='constant',
|
||
cval=dst_nodata,
|
||
preserve_range=True,
|
||
order=0
|
||
).astype(np.float32)
|
||
|
||
# 转回目标 dtype
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}")
|
||
success = True
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, selected_method, median_error, p95_error, success)
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.warning(f"{best_transform_type}变换异常: {e}")
|
||
# 继续到仿射回退
|
||
|
||
# ---- 回退:使用仿射变换,保证最小可用结果 ----
|
||
transform = best_transform
|
||
try:
|
||
min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1(
|
||
k1_global, ref_dataset.width, ref_dataset.height, pad=10
|
||
)
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bbox_transform = ref_dataset.window_transform(bbox_window)
|
||
|
||
if MATPLOTLIB_SCIPY_AVAILABLE:
|
||
try:
|
||
hull = ConvexHull(k1_global)
|
||
hull_path = MplPath(k1_global[hull.vertices])
|
||
except Exception:
|
||
rect = np.array(
|
||
[[min_x, min_y], [min_x + bbox_w, min_y],
|
||
[min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]],
|
||
dtype=float
|
||
)
|
||
hull_path = MplPath(rect)
|
||
|
||
def point_inside(xy):
|
||
return hull_path.contains_points(xy)
|
||
else:
|
||
def point_inside(xy):
|
||
return (
|
||
(xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) &
|
||
(xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h)
|
||
)
|
||
|
||
off_x, off_y = min_x, min_y
|
||
tps_inv = transform["inv"] # ref -> src
|
||
|
||
def inv_map_rc(coords):
|
||
rc = np.asarray(coords, dtype=np.float64)
|
||
xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y)
|
||
inside = point_inside(xy_ref)
|
||
xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64)
|
||
if np.any(inside):
|
||
# 使用RBF插值计算逆映射
|
||
xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1])
|
||
xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1])
|
||
return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = ref_dataset.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=bbox_h,
|
||
width=bbox_w,
|
||
count=src.count,
|
||
transform=bbox_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
# 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底
|
||
if SKIMAGE_AVAILABLE:
|
||
from skimage.transform import warp
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = warp(
|
||
src_band,
|
||
inverse_map=inv_map_rc,
|
||
output_shape=(bbox_h, bbox_w),
|
||
mode='constant',
|
||
cval=dst_nodata,
|
||
preserve_range=True,
|
||
order=0
|
||
).astype(np.float32)
|
||
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
else:
|
||
# OpenCV remap 版本(无需 skimage/SimpleITK)
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
# 创建映射网格
|
||
y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w]
|
||
coords = np.column_stack([y_coords.ravel(), x_coords.ravel()])
|
||
|
||
# 计算逆映射
|
||
mapped_coords = inv_map_rc(coords)
|
||
map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32)
|
||
map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32)
|
||
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
|
||
# 使用OpenCV的remap进行重采样
|
||
dst_band = cv2.remap(
|
||
src_band, map_x, map_y,
|
||
interpolation=cv2.INTER_NEAREST,
|
||
borderMode=cv2.BORDER_CONSTANT,
|
||
borderValue=dst_nodata
|
||
)
|
||
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.warning(f"tps变换异常: {e}")
|
||
# 继续到仿射回退
|
||
|
||
|
||
|
||
# ---- 回退:使用仿射变换,保证最小可用结果 ----
|
||
# 重新估计仿射变换作为fallback
|
||
A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
|
||
if A_fallback is None:
|
||
logger.warning(f"仿射回退也失败: {bip_path.name}")
|
||
return False
|
||
|
||
# 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标
|
||
s0x = src_img.shape[2] / src_small.shape[2]
|
||
s0y = src_img.shape[1] / src_small.shape[1]
|
||
s1x = ref_img.shape[2] / ref_small.shape[2]
|
||
s1y = ref_img.shape[1] / ref_small.shape[1]
|
||
S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64)
|
||
S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64)
|
||
A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback
|
||
M_full = S1_inv @ A3 @ S0
|
||
|
||
T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64)
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0, 0, 1]], dtype=np.float64)
|
||
src_pixel_to_map_corrected = Rt @ T_off @ M_full
|
||
corrected_affine = Affine(
|
||
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
|
||
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
|
||
)
|
||
|
||
# 计算源 BIP 四角经过仿射变换后的最小外接矩形
|
||
# 将 rasterio.Affine 转为 3x3 像素->地图矩阵
|
||
M_map = np.array([
|
||
[corrected_affine.a, corrected_affine.b, corrected_affine.c],
|
||
[corrected_affine.d, corrected_affine.e, corrected_affine.f],
|
||
[0.0, 0.0, 1.0]
|
||
], dtype=np.float64)
|
||
|
||
# 参考底图的 像素->地图 矩阵及其逆
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array([
|
||
[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0.0, 0.0, 1.0]
|
||
], dtype=np.float64)
|
||
Rt_inv = np.linalg.inv(Rt)
|
||
|
||
# 源影像四角(源像素坐标)
|
||
src_h, src_w = src.height, src.width
|
||
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
|
||
corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4)
|
||
|
||
# 源像素 -> 地图坐标
|
||
map_corners = (M_map @ corners_h).T[:, :2]
|
||
|
||
# 地图坐标 -> 参考像素坐标
|
||
pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3)
|
||
pix_corners = pix_corners_h[:, :2]
|
||
|
||
# 最小外接矩形(像素)
|
||
min_x = int(np.floor(pix_corners[:,0].min())) - 10
|
||
max_x = int(np.ceil( pix_corners[:,0].max())) + 10
|
||
min_y = int(np.floor(pix_corners[:,1].min())) - 10
|
||
max_y = int(np.ceil( pix_corners[:,1].max())) + 10
|
||
|
||
# 边界裁剪
|
||
min_x = max(0, min_x); min_y = max(0, min_y)
|
||
max_x = min(ref_dataset.width, max_x)
|
||
max_y = min(ref_dataset.height, max_y)
|
||
|
||
bbox_w = max_x - min_x
|
||
bbox_h = max_y - min_y
|
||
|
||
# 如果外接矩形太小,跳过
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform)
|
||
|
||
res_x, res_y = _pixel_size_xy(src.transform)
|
||
out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = src.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=out_h,
|
||
width=out_w,
|
||
count=src.count,
|
||
transform=out_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = np.zeros((out_h, out_w), dtype=np.float32)
|
||
reproject(
|
||
source=src_band,
|
||
destination=dst_band,
|
||
src_transform=corrected_affine,
|
||
src_crs=ref_crs,
|
||
dst_transform=out_transform,
|
||
dst_crs=ref_crs,
|
||
src_nodata=src_nodata,
|
||
dst_nodata=dst_nodata,
|
||
resampling=Resampling.nearest,
|
||
)
|
||
# 转回目标 dtype
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
|
||
success = True
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, "affine_fallback", median_error, p95_error, success)
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
|
||
# 记录失败的统计信息
|
||
try:
|
||
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches,
|
||
inlier_ratio, "exception", median_error, p95_error, False)
|
||
except:
|
||
pass # 避免统计记录失败影响主要错误处理
|
||
return False
|
||
|
||
def _apply_config(cfg: RegistrationConfig):
    """Publish the settings held in *cfg* as this module's configuration globals.

    Each field is read from *cfg* in a fixed order and stored under the
    matching upper-case module-level name:

    * ``bip_dir`` / ``out_dir`` are wrapped in ``Path``;
    * ``transform_methods`` is copied into a new ``list``;
    * pixel sizes and counts are coerced with ``int``;
    * ``min_inlier_ratio`` / ``min_grad_quantile`` are coerced with ``float``;
    * ``ref_tif``, ``matcher_name`` and ``device`` are stored unchanged.
    """
    # (module global, config attribute, converter; None = store unchanged).
    # Order matches the order in which the values are read and assigned.
    bindings = (
        ("REF_TIF", "ref_tif", None),
        ("BIP_DIR", "bip_dir", Path),
        ("OUT_DIR", "out_dir", Path),
        ("MATCHER_NAME", "matcher_name", None),
        ("DEVICE", "device", None),
        ("TRANSFORM_METHODS", "transform_methods", list),
        ("MATCH_MAX_SIDE", "match_max_side", int),
        ("ROI_PAD_PX", "roi_pad_px", int),
        ("MASK_PAD_PX", "mask_pad_px", int),
        ("MIN_INLIERS", "min_inliers", int),
        ("MIN_INLIER_RATIO", "min_inlier_ratio", float),
        ("FEATHER_PX", "feather_px", int),
        ("EDGE_BAND_PX", "edge_band_px", int),
        ("MIN_GRAD_QUANTILE", "min_grad_quantile", float),
    )

    module_namespace = globals()
    for global_name, attr_name, convert in bindings:
        raw_value = getattr(cfg, attr_name)
        module_namespace[global_name] = raw_value if convert is None else convert(raw_value)
|
||
|
||
|
||
def _run_batch(cfg: RegistrationConfig, stop_event: threading.Event, progress_cb=None):
    """Register every ``*.bip`` file in the configured input directory.

    The steps are:

    1. Apply *cfg* to the module globals.
    2. Create the output directory and ``<out_dir>/stats``, plus a timestamped
       statistics CSV.
    3. Build the matcher.
    4. If enabled, mask the reference image.
    5. Call ``process_bip_to_tif`` for each BIP file, in sorted path order.

    Args:
        cfg: Run configuration; copied into the module globals first.
        stop_event: Checked before each file; once set, no further files are
            processed.
        progress_cb: Optional ``callback(done, total, filename)``. It is called
            with ``(0, total, "")`` before the loop and with
            ``(index, total, name)`` after each processed file.

    Returns:
        The number of files for which ``process_bip_to_tif`` returned a truthy
        value.

    Raises:
        RuntimeError: If reference masking is enabled but the masking module
            is unavailable or the mask TIF does not exist. Errors raised by the
            matcher factory or by opening the reference propagate unchanged.
    """
    _apply_config(cfg)

    out_dir = OUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)
    stats_dir = out_dir / "stats"
    stats_dir.mkdir(parents=True, exist_ok=True)

    # One statistics CSV per run, named after the run's start time.
    ts = datetime.now().strftime('%Y%m%d_%H%M%S')
    stats_csv = stats_dir / f"registration_stats_{ts}.csv"
    logger.info(f"统计信息将保存到: {stats_csv}")
    init_stats_csv(stats_csv)

    # Environment fix-ups are run before the matcher is constructed.
    _ensure_pyinstaller_third_party_paths()
    _install_loguru_stub_if_missing()
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    ref_path_to_use = REF_TIF
    if bool(cfg.enable_ref_mask):
        if not TIF_MASK_AVAILABLE:
            raise RuntimeError("未能导入 tif_caijain.py,无法进行底图掩膜。")
        if not cfg.ref_mask_tif or not Path(cfg.ref_mask_tif).exists():
            raise RuntimeError("已启用底图掩膜,但掩膜 TIF 文件不存在。")

        # Write the masked reference to a timestamped file under masked_refs/
        # and register against it instead of the original reference.
        masked_dir = out_dir / "masked_refs"
        masked_dir.mkdir(parents=True, exist_ok=True)
        masked_ref_path = masked_dir / f"{Path(REF_TIF).stem}_masked_{ts}.tif"

        logger.info(f"开始对底图进行掩膜: {REF_TIF}")
        logger.info(f"掩膜文件: {cfg.ref_mask_tif}")
        mask_data_by_binary_mask(
            data_path=REF_TIF,
            mask_path=cfg.ref_mask_tif,
            output_path=str(masked_ref_path),
            remove_value=int(cfg.ref_mask_remove_value),
        )
        ref_path_to_use = str(masked_ref_path)
        logger.info(f"掩膜后的底图: {ref_path_to_use}")

    with rasterio.open(ref_path_to_use) as ref:
        # Path.glob yields entries in arbitrary, filesystem-dependent order;
        # sort so processing order, logs, stats rows and progress are
        # reproducible between runs and machines.
        bip_files = sorted(Path(BIP_DIR).glob("*.bip"))
        total = len(bip_files)
        if not bip_files:
            logger.warning(f"未找到 .bip 文件: {BIP_DIR}")

        success_count = 0
        if progress_cb is not None:
            progress_cb(0, total, "")

        for idx, bip_path in enumerate(bip_files, start=1):
            if stop_event.is_set():
                break
            if process_bip_to_tif(bip_path, ref, matcher, out_dir, stats_csv):
                success_count += 1
            if progress_cb is not None:
                progress_cb(idx, total, bip_path.name)

    return success_count
|
||
|
||
|
||
class QueueHandler(logging.Handler):
    """Logging handler that formats each record and puts the string on a queue.

    ``RegistrationGUI`` attaches one of these, backed by its own queue, to the
    module logger. Log calls made while processing therefore only enqueue text.
    """

    def __init__(self, log_queue):
        """Store *log_queue*; it only needs a ``put(item)`` method (e.g. ``queue.Queue``)."""
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record):
        """Format *record* and enqueue the resulting string.

        As the ``logging.Handler`` contract requires, failures while formatting
        or enqueuing are passed to ``handleError``. They are not propagated into
        the code that issued the log call. ``RecursionError`` is re-raised, as
        in the standard-library handlers.
        """
        try:
            message = self.format(record)
            self.log_queue.put(message)
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
|
||
|
||
class ToolTip:
    """Hover popup that shows *text* just below a Tk widget.

    Once the pointer enters the widget, the popup is scheduled to appear after
    *delay_ms* milliseconds. Leaving the widget, or pressing a mouse button on
    it, cancels a pending popup and removes a visible one.
    """

    def __init__(self, widget, text: str, delay_ms: int = 400):
        self.widget = widget
        self.text = text
        self.delay_ms = int(delay_ms)
        self._after_id = None  # id returned by widget.after() while a show is pending
        self._tip = None       # the popup Toplevel while it is visible

        # add=True asks Tk to append these handlers to existing bindings
        # instead of replacing them.
        for sequence, handler in (
            ("<Enter>", self._on_enter),
            ("<Leave>", self._on_leave),
            ("<ButtonPress>", self._on_leave),
        ):
            self.widget.bind(sequence, handler, add=True)

    def _on_enter(self, _event=None):
        """Pointer entered the widget: (re)start the show timer."""
        self._schedule()

    def _on_leave(self, _event=None):
        """Pointer left or a button was pressed: drop any pending show and hide."""
        self._cancel()
        self._hide()

    def _schedule(self):
        """Queue ``_show`` after ``delay_ms``, replacing any timer already pending."""
        self._cancel()
        try:
            pending = self.widget.after(self.delay_ms, self._show)
        except Exception:
            pending = None
        self._after_id = pending

    def _cancel(self):
        """Cancel the pending show timer, ignoring any error from the widget."""
        pending = self._after_id
        if pending is None:
            return
        try:
            self.widget.after_cancel(pending)
        except Exception:
            pass
        self._after_id = None

    def _show(self):
        """Open the popup offset below the widget, unless one is already up or text is empty."""
        if self._tip is not None or not self.text:
            return
        try:
            left = self.widget.winfo_rootx() + 10
            top = self.widget.winfo_rooty() + self.widget.winfo_height() + 6
        except Exception:
            return

        popup = tk.Toplevel(self.widget)
        self._tip = popup
        popup.wm_overrideredirect(True)
        popup.wm_geometry(f"+{left}+{top}")

        tk.Label(
            popup,
            text=self.text,
            justify=tk.LEFT,
            background="#ffffe0",
            relief=tk.SOLID,
            borderwidth=1,
            wraplength=520,
        ).pack(ipadx=6, ipady=4)

    def _hide(self):
        """Destroy the popup if it is visible, ignoring any error from Tk."""
        popup = self._tip
        if popup is None:
            return
        try:
            popup.destroy()
        except Exception:
            pass
        self._tip = None
|
||
|
||
|
||
_MATCHER_VALUES = [
|
||
"liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer",
|
||
"matchanything-eloftr", "matchanything-roma", "matchformer",
|
||
"sift-lightglue", "superpoint-lightglue", "disk-lightglue",
|
||
"aliked-lightglue", "doghardnet-lightglue", "roma", "romav2",
|
||
"tiny-roma", "dedode", "steerers", "affine-steerers",
|
||
"dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue",
|
||
"r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat",
|
||
"xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm",
|
||
"gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx",
|
||
"dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx",
|
||
"sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma",
|
||
"minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr",
|
||
"minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd",
|
||
"lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint",
|
||
"xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm",
|
||
"xfeat-star-steerers-learned",
|
||
]
|
||
|
||
|
||
class RegistrationGUI:
    """Tkinter front-end for batch-registering .bip strips against a reference GeoTIFF.

    The batch itself runs in ``_run_batch`` on a background daemon thread. Log
    records reach the UI through ``self.log_queue`` (polled every 100 ms by
    ``check_log_queue``). Progress and status updates are scheduled onto the Tk
    event loop with ``root.after``.
    """

    def __init__(self, root):
        """Build the window on *root* and start polling the log queue."""
        self.root = root
        self.root.title("遥感影像批量配准工具")
        self.root.geometry("1000x800")
        # Keeps ToolTip helpers referenced for the lifetime of the GUI.
        self._tooltips = []

        self.log_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.processing_thread = None

        # Forward records from the module logger into log_queue; check_log_queue
        # drains that queue from the Tk event loop.
        queue_handler = QueueHandler(self.log_queue)
        queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(queue_handler)
        logger.setLevel(logging.INFO)

        self.create_widgets()
        self.check_log_queue()

    def add_tooltip(self, widget, text: str):
        """Attach a hover tooltip with *text* to *widget* and keep a reference to it."""
        self._tooltips.append(ToolTip(widget, text))

    def _post_ui(self, callback):
        """Schedule *callback* on the Tk event loop, ignoring failures once the UI is gone.

        Used from the worker thread. ``root.after`` can raise (e.g. TclError /
        RuntimeError) after the root window has been destroyed or the main loop
        has exited; in that case there is no UI left to update.
        """
        try:
            self.root.after(0, callback)
        except (tk.TclError, RuntimeError):
            pass

    def show_error_dialog(self, title: str, summary: str, details: str):
        """Open a modal window showing *summary* plus scrollable, copyable *details*."""
        win = tk.Toplevel(self.root)
        win.title(title)
        win.geometry("900x600")

        top = ttk.Frame(win, padding=10)
        top.pack(fill=tk.BOTH, expand=True)

        summary_label = tk.Label(top, text=summary, fg="#b00020", justify=tk.LEFT, wraplength=860)
        summary_label.pack(anchor=tk.W, fill=tk.X)

        text_frame = ttk.Frame(top)
        text_frame.pack(fill=tk.BOTH, expand=True, pady=(10, 0))

        scrollbar = ttk.Scrollbar(text_frame, orient=tk.VERTICAL)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        text = tk.Text(text_frame, wrap=tk.NONE, yscrollcommand=scrollbar.set)
        text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        scrollbar.config(command=text.yview)

        if details:
            text.insert(tk.END, details)
        text.config(state=tk.DISABLED)

        btns = ttk.Frame(top)
        btns.pack(fill=tk.X, pady=(10, 0))

        def copy_details():
            # Copy the details (or the summary when there are none) to the clipboard.
            try:
                self.root.clipboard_clear()
                self.root.clipboard_append(details or summary)
                self.root.update()
            except tk.TclError:
                pass

        ttk.Button(btns, text="复制详情", command=copy_details).pack(side=tk.LEFT)
        ttk.Button(btns, text="关闭", command=win.destroy).pack(side=tk.RIGHT)

        # Best effort modality: grab_set can fail if another window holds the grab.
        try:
            win.transient(self.root)
            win.grab_set()
            win.focus_force()
        except tk.TclError:
            pass

    def show_exception_dialog(self, title: str, exc: BaseException):
        """Show *exc* in an error dialog, including its own traceback.

        The traceback is formatted from ``exc.__traceback__``, so this works even
        when called outside the ``except`` block that caught *exc*.
        """
        details = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        self.show_error_dialog(title=title, summary=str(exc), details=details)

    def create_widgets(self):
        """Build the config form, parameter grid, controls, progress bar and log view."""
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # ---- input / output paths ----
        config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5")
        config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10))

        ref_label = ttk.Label(config_frame, text="参考TIF文件:")
        ref_label.grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
        self.ref_tif_var = tk.StringVar(value=str(REF_TIF))
        ref_entry = ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50)
        ref_entry.grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        ref_btn = ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif)
        ref_btn.grid(row=0, column=2)

        self.enable_ref_mask_var = tk.BooleanVar(value=False)
        ref_mask_chk = ttk.Checkbutton(
            config_frame,
            text="启用底图掩膜",
            variable=self.enable_ref_mask_var,
            command=self._on_toggle_ref_mask,
        )
        ref_mask_chk.grid(row=1, column=0, sticky=tk.W, padx=(0, 5))

        # Mask path widgets start disabled; _on_toggle_ref_mask enables them.
        self.ref_mask_tif_var = tk.StringVar(value="")
        self.ref_mask_entry = ttk.Entry(config_frame, textvariable=self.ref_mask_tif_var, width=50, state=tk.DISABLED)
        self.ref_mask_entry.grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        self.ref_mask_btn = ttk.Button(config_frame, text="选择文件", command=self.select_ref_mask_tif, state=tk.DISABLED)
        self.ref_mask_btn.grid(row=1, column=2)

        bip_label = ttk.Label(config_frame, text="BIP文件夹:")
        bip_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5))
        self.bip_dir_var = tk.StringVar(value=str(BIP_DIR))
        bip_entry = ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50)
        bip_entry.grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        bip_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir)
        bip_btn.grid(row=2, column=2)

        out_label = ttk.Label(config_frame, text="输出文件夹:")
        out_label.grid(row=3, column=0, sticky=tk.W, padx=(0, 5))
        self.out_dir_var = tk.StringVar(value=str(OUT_DIR))
        out_entry = ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50)
        out_entry.grid(row=3, column=1, sticky=(tk.W, tk.E), padx=(0, 5))
        out_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir)
        out_btn.grid(row=3, column=2)

        # ---- matcher / device ----
        matcher_label = ttk.Label(config_frame, text="匹配算法:")
        matcher_label.grid(row=4, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
        self.matcher_var = tk.StringVar(value=str(MATCHER_NAME))
        matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47)
        matcher_combo['values'] = _MATCHER_VALUES
        matcher_combo.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))

        device_label = ttk.Label(config_frame, text="设备:")
        device_label.grid(row=5, column=0, sticky=tk.W, padx=(0, 5))
        self.device_var = tk.StringVar(value=str(DEVICE))
        device_frame = ttk.Frame(config_frame)
        device_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E))
        cuda_rb = ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda")
        cpu_rb = ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu")
        cuda_rb.pack(side=tk.LEFT)
        cpu_rb.pack(side=tk.LEFT)

        # ---- transform methods (listbox order == priority order) ----
        transform_label = ttk.Label(config_frame, text="变换方法 (按优先级):")
        transform_label.grid(row=6, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0))
        transform_frame = ttk.Frame(config_frame)
        transform_frame.grid(row=6, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))

        self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False)
        transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps"]
        for idx, method in enumerate(transform_methods):
            self.transform_listbox.insert(tk.END, method)
            # Pre-select the methods enabled in the module defaults.
            if method in TRANSFORM_METHODS:
                self.transform_listbox.selection_set(idx)

        transform_scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview)
        self.transform_listbox.configure(yscrollcommand=transform_scrollbar.set)
        self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        transform_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        button_frame = ttk.Frame(transform_frame)
        button_frame.pack(side=tk.RIGHT, padx=(5, 0))
        ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2))
        ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X)

        # ---- numeric parameters ----
        param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5")
        param_frame.grid(row=7, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0))

        match_max_side_label = ttk.Label(param_frame, text="匹配最大边长:")
        match_max_side_label.grid(row=0, column=0, sticky=tk.W, padx=(0, 5))
        self.match_max_side_var = tk.IntVar(value=int(MATCH_MAX_SIDE))
        match_max_side_entry = ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10)
        match_max_side_entry.grid(row=0, column=1, sticky=tk.W)

        roi_pad_label = ttk.Label(param_frame, text="ROI填充像素:")
        roi_pad_label.grid(row=0, column=2, sticky=tk.W, padx=(10, 5))
        self.roi_pad_px_var = tk.IntVar(value=int(ROI_PAD_PX))
        roi_pad_entry = ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10)
        roi_pad_entry.grid(row=0, column=3, sticky=tk.W)

        mask_pad_label = ttk.Label(param_frame, text="掩膜膨胀像素:")
        mask_pad_label.grid(row=0, column=4, sticky=tk.W, padx=(10, 5))
        self.mask_pad_px_var = tk.IntVar(value=int(MASK_PAD_PX))
        mask_pad_entry = ttk.Entry(param_frame, textvariable=self.mask_pad_px_var, width=10)
        mask_pad_entry.grid(row=0, column=5, sticky=tk.W)

        min_inliers_label = ttk.Label(param_frame, text="最少内点数:")
        min_inliers_label.grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0))
        self.min_inliers_var = tk.IntVar(value=int(MIN_INLIERS))
        min_inliers_entry = ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10)
        min_inliers_entry.grid(row=1, column=1, sticky=tk.W, pady=(5, 0))

        min_ratio_label = ttk.Label(param_frame, text="最少内点比例:")
        min_ratio_label.grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0))
        self.min_inlier_ratio_var = tk.DoubleVar(value=float(MIN_INLIER_RATIO))
        min_ratio_entry = ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10)
        min_ratio_entry.grid(row=1, column=3, sticky=tk.W, pady=(5, 0))

        feather_label = ttk.Label(param_frame, text="羽化像素:")
        feather_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0))
        self.feather_px_var = tk.IntVar(value=int(FEATHER_PX))
        feather_entry = ttk.Entry(param_frame, textvariable=self.feather_px_var, width=10)
        feather_entry.grid(row=2, column=1, sticky=tk.W, pady=(5, 0))

        edge_band_label = ttk.Label(param_frame, text="边界剔除像素:")
        edge_band_label.grid(row=2, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0))
        self.edge_band_px_var = tk.IntVar(value=int(EDGE_BAND_PX))
        edge_band_entry = ttk.Entry(param_frame, textvariable=self.edge_band_px_var, width=10)
        edge_band_entry.grid(row=2, column=3, sticky=tk.W, pady=(5, 0))

        grad_q_label = ttk.Label(param_frame, text="梯度分位阈值:")
        grad_q_label.grid(row=2, column=4, sticky=tk.W, padx=(10, 5), pady=(5, 0))
        self.min_grad_quantile_var = tk.DoubleVar(value=float(MIN_GRAD_QUANTILE))
        grad_q_entry = ttk.Entry(param_frame, textvariable=self.min_grad_quantile_var, width=10)
        grad_q_entry.grid(row=2, column=5, sticky=tk.W, pady=(5, 0))

        # ---- tooltips ----
        self.add_tooltip(ref_label, "参考底图 GeoTIFF,用于批量配准的目标坐标系与位置基准。建议确保 CRS、transform 正确。")
        self.add_tooltip(ref_entry, "参考底图 GeoTIFF 路径。配准时会读取该底图的 ROI 进行匹配。")
        self.add_tooltip(ref_btn, "选择参考底图 GeoTIFF 文件。")

        self.add_tooltip(ref_mask_chk, "勾选后先用掩膜 TIF 对底图进行掩膜(掩膜值=1 的区域设置为 NoData),并保存为新的底图;后续配准使用掩膜后的底图。")
        self.add_tooltip(self.ref_mask_entry, "掩膜 GeoTIFF 路径。要求与底图严格对齐(相同 CRS、分辨率、范围、尺寸),否则会报错或效果不可控。")
        self.add_tooltip(self.ref_mask_btn, "选择掩膜 GeoTIFF 文件。")

        self.add_tooltip(bip_label, "包含待配准航带 .bip 文件的文件夹。程序会批量遍历 *.bip。")
        self.add_tooltip(bip_entry, "BIP 文件夹路径。")
        self.add_tooltip(bip_btn, "选择 BIP 文件夹。")

        self.add_tooltip(out_label, "输出目录:配准后的航带、可视化图片、统计 CSV 等都会写到这里。")
        self.add_tooltip(out_entry, "输出文件夹路径。")
        self.add_tooltip(out_btn, "选择输出文件夹。")

        self.add_tooltip(matcher_label, "特征匹配算法名称。不同 matcher 在精度、速度、鲁棒性上差异较大。")
        self.add_tooltip(matcher_combo, "选择/输入 matcher 名称。若使用 cuda,需要环境支持 GPU。")

        self.add_tooltip(device_label, "运行设备:cuda(GPU)更快,cpu 更通用。")
        self.add_tooltip(cuda_rb, "使用 GPU(CUDA)运行匹配器与部分计算。")
        self.add_tooltip(cpu_rb, "使用 CPU 运行。速度可能较慢。")

        self.add_tooltip(transform_label, "变换模型选择(可多选)。配准会按优先级尝试,并自动选择误差较小的模型。")
        self.add_tooltip(self.transform_listbox, "按住 Ctrl/Shift 多选。右侧可上移/下移调整优先级。一般 homography 更灵活但更易发散,affine 更稳定。")

        self.add_tooltip(match_max_side_label, "匹配阶段会把图像等比缩小到最大边长不超过该值。值越大越慢,但细节更多。")
        self.add_tooltip(match_max_side_entry, "匹配用降采样尺寸上限(像素)。")

        self.add_tooltip(roi_pad_label, "参考底图 ROI 的额外扩展像素。增大可覆盖更大不确定区域,但会增加内存与耗时。")
        self.add_tooltip(roi_pad_entry, "ROI padding(像素,参考底图坐标系)。")

        self.add_tooltip(mask_pad_label, "仅用于匹配阶段:对源图有效掩膜/重投影后的掩膜做膨胀,增加可匹配区域。")
        self.add_tooltip(mask_pad_entry, "掩膜膨胀像素(只影响匹配,不直接改变输出)。")

        self.add_tooltip(min_inliers_label, "RANSAC 内点数量阈值。低于该值认为匹配质量不足,判定失败。")
        self.add_tooltip(min_inliers_entry, "最少内点数。")

        self.add_tooltip(min_ratio_label, "内点比例阈值(内点数/匹配点数)。过低通常意味着匹配不可靠。")
        self.add_tooltip(min_ratio_entry, "最少内点比例。")

        self.add_tooltip(feather_label, "对掩膜边缘做羽化,降低硬边缘带来的高对比假匹配。数值越大边缘过渡越宽。")
        self.add_tooltip(feather_entry, "掩膜羽化宽度(像素)。")

        self.add_tooltip(edge_band_label, "剔除距离掩膜边界过近的匹配点,减少边缘假匹配。数值越大剔除越多。")
        self.add_tooltip(edge_band_entry, "边缘带剔除宽度(像素)。")

        self.add_tooltip(grad_q_label, "纹理过滤分位阈值:梯度幅值低于该分位的区域视为低纹理,匹配点会被剔除。")
        self.add_tooltip(grad_q_entry, "梯度分位阈值(0~1)。")

        # ---- start / stop ----
        control_frame = ttk.Frame(main_frame)
        control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing)
        self.start_btn.pack(side=tk.LEFT, padx=(0, 10))
        self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED)
        self.stop_btn.pack(side=tk.LEFT)

        # ---- progress ----
        progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5")
        progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0))
        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100)
        self.progress_bar.pack(fill=tk.X, pady=(0, 5))
        self.progress_label = ttk.Label(progress_frame, text="准备就绪")
        self.progress_label.pack(anchor=tk.W)

        # ---- log view ----
        log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5")
        log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0))
        log_text_frame = ttk.Frame(log_frame)
        log_text_frame.pack(fill=tk.BOTH, expand=True)
        self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD)
        log_scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview)
        self.log_text.configure(yscrollcommand=log_scrollbar.set)
        self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        log_scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        log_btn_frame = ttk.Frame(log_frame)
        log_btn_frame.pack(fill=tk.X, pady=(5, 0))
        ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5))
        ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT)

        # Let the main frame and the log panel absorb window resizes.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)

    def select_ref_tif(self):
        """Ask for the reference GeoTIFF and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择参考TIF文件",
            filetypes=[("TIF files", "*.tif;*.tiff"), ("All files", "*.*")]
        )
        if filename:
            self.ref_tif_var.set(filename)

    def select_ref_mask_tif(self):
        """Ask for the mask GeoTIFF and store the chosen path."""
        filename = filedialog.askopenfilename(
            title="选择掩膜TIF文件",
            filetypes=[("TIF files", "*.tif;*.tiff"), ("All files", "*.*")]
        )
        if filename:
            self.ref_mask_tif_var.set(filename)

    def select_bip_dir(self):
        """Ask for the folder containing the .bip strips."""
        dirname = filedialog.askdirectory(title="选择BIP文件夹")
        if dirname:
            self.bip_dir_var.set(dirname)

    def select_out_dir(self):
        """Ask for the output folder."""
        dirname = filedialog.askdirectory(title="选择输出文件夹")
        if dirname:
            self.out_dir_var.set(dirname)

    def _move_selected(self, offset):
        """Move the first selected transform method by *offset* rows, if it stays in range."""
        selection = self.transform_listbox.curselection()
        if not selection:
            return
        idx = selection[0]
        target = idx + offset
        if not 0 <= target < self.transform_listbox.size():
            return
        text = self.transform_listbox.get(idx)
        self.transform_listbox.delete(idx)
        self.transform_listbox.insert(target, text)
        self.transform_listbox.selection_set(target)
        # The listbox shows only 5 rows; keep the moved item visible.
        self.transform_listbox.see(target)

    def move_up(self):
        """Raise the priority of the first selected transform method by one row."""
        self._move_selected(-1)

    def move_down(self):
        """Lower the priority of the first selected transform method by one row."""
        self._move_selected(1)

    def _read_numeric_params(self):
        """Parse the numeric parameter entries.

        Returns a dict of RegistrationConfig keyword arguments. Returns None after
        showing an error dialog if an entry is not a valid number, or if a ratio
        or quantile lies outside [0, 1].
        """
        try:
            params = {
                "match_max_side": int(self.match_max_side_var.get()),
                "roi_pad_px": int(self.roi_pad_px_var.get()),
                "mask_pad_px": int(self.mask_pad_px_var.get()),
                "min_inliers": int(self.min_inliers_var.get()),
                "min_inlier_ratio": float(self.min_inlier_ratio_var.get()),
                "feather_px": int(self.feather_px_var.get()),
                "edge_band_px": int(self.edge_band_px_var.get()),
                "min_grad_quantile": float(self.min_grad_quantile_var.get()),
            }
        except (tk.TclError, ValueError):
            # IntVar/DoubleVar.get() raise TclError on non-numeric entry text.
            messagebox.showerror("错误", "参数设置中存在无效的数值,请检查后重试")
            return None

        for key, label in (("min_inlier_ratio", "最少内点比例"), ("min_grad_quantile", "梯度分位阈值")):
            if not 0.0 <= params[key] <= 1.0:
                messagebox.showerror("错误", f"{label}必须在 0~1 之间")
                return None
        return params

    def start_processing(self):
        """Validate the form and launch the batch on a background daemon thread."""
        if self.processing_thread and self.processing_thread.is_alive():
            messagebox.showwarning("警告", "处理正在进行中")
            return

        selected_indices = self.transform_listbox.curselection()
        if not selected_indices:
            messagebox.showwarning("警告", "请至少选择一种变换方法")
            return
        # curselection() yields indices in listbox order, which is the priority order.
        transform_methods = [self.transform_listbox.get(i) for i in selected_indices]

        params = self._read_numeric_params()
        if params is None:
            return

        cfg = RegistrationConfig(
            ref_tif=self.ref_tif_var.get().strip(),
            bip_dir=self.bip_dir_var.get().strip(),
            out_dir=self.out_dir_var.get().strip(),
            enable_ref_mask=bool(self.enable_ref_mask_var.get()),
            ref_mask_tif=self.ref_mask_tif_var.get().strip(),
            ref_mask_remove_value=1,
            matcher_name=self.matcher_var.get().strip(),
            device=self.device_var.get().strip(),
            transform_methods=transform_methods,
            **params,
        )

        # Path("") resolves to ".", which always exists. Check for emptiness
        # explicitly and require the right kind of filesystem entry.
        if not cfg.ref_tif or not Path(cfg.ref_tif).is_file():
            messagebox.showerror("错误", "参考 TIF 不存在")
            return
        if not cfg.bip_dir or not Path(cfg.bip_dir).is_dir():
            messagebox.showerror("错误", "BIP 文件夹不存在")
            return
        if not cfg.out_dir:
            messagebox.showerror("错误", "输出文件夹不能为空")
            return
        if cfg.enable_ref_mask:
            if not TIF_MASK_AVAILABLE:
                messagebox.showerror("错误", "tif_caijain.py 不可用,无法进行底图掩膜")
                return
            if not cfg.ref_mask_tif or not Path(cfg.ref_mask_tif).is_file():
                messagebox.showerror("错误", "已启用底图掩膜,但掩膜 TIF 文件不存在")
                return

        self.stop_event.clear()
        self.start_btn.config(state=tk.DISABLED)
        self.stop_btn.config(state=tk.NORMAL)
        self.progress_var.set(0)
        self.progress_label.config(text="正在初始化...")

        self.processing_thread = threading.Thread(
            target=self.run_processing,
            args=(cfg,),
            daemon=True
        )
        self.processing_thread.start()

    def _on_toggle_ref_mask(self):
        """Enable or disable the mask path widgets to match the checkbox."""
        state = tk.NORMAL if self.enable_ref_mask_var.get() else tk.DISABLED
        try:
            self.ref_mask_entry.configure(state=state)
            self.ref_mask_btn.configure(state=state)
        except tk.TclError:
            pass

    def stop_processing(self):
        """Ask the running batch to stop; the worker checks ``stop_event``."""
        if self.processing_thread and self.processing_thread.is_alive():
            self.stop_event.set()
            self.progress_label.config(text="正在停止...")

    def run_processing(self, cfg: "RegistrationConfig"):
        """Worker-thread entry point: run the batch and report the outcome to the UI.

        Exceptions are logged and shown in an error dialog instead of escaping the
        thread. In every case the start/stop buttons are restored and the status
        label reports completion, cancellation or failure.
        """
        failed = False
        try:
            _run_batch(cfg, self.stop_event, progress_cb=self.on_progress)
        except Exception as exc:
            failed = True
            # Capture plain strings now: Python unbinds ``exc`` when this except
            # block ends, so the deferred UI callback must not reference it.
            summary = str(exc)
            details = traceback.format_exc()
            self.log_queue.put(f"处理过程中发生错误: {exc}\n{details}")
            self._post_ui(lambda: self.show_error_dialog("处理失败", summary, details))
        finally:
            if failed:
                status = "处理失败"
            elif self.stop_event.is_set():
                status = "已停止"
            else:
                status = "处理完成"

            def restore_controls():
                self.start_btn.config(state=tk.NORMAL)
                self.stop_btn.config(state=tk.DISABLED)
                self.progress_label.config(text=status)

            self._post_ui(restore_controls)

    def on_progress(self, current, total, filename):
        """Progress callback (called from the worker thread); updates bar and label."""
        if total <= 0:
            return
        percent = (current / total) * 100
        if filename:
            label = f"处理中: {filename} ({current}/{total})"
        else:
            label = f"处理中: ({current}/{total})"

        def apply():
            self.progress_var.set(percent)
            self.progress_label.config(text=label)

        self._post_ui(apply)

    def check_log_queue(self):
        """Move pending log messages into the log view, then re-poll in 100 ms."""
        lines = []
        while True:
            try:
                lines.append(f"{self.log_queue.get_nowait()}\n")
            except queue.Empty:
                break
        if lines:
            # One insert/see per poll instead of one per message.
            self.log_text.insert(tk.END, "".join(lines))
            self.log_text.see(tk.END)
        self.root.after(100, self.check_log_queue)

    def clear_log(self):
        """Remove all text from the log view."""
        self.log_text.delete("1.0", tk.END)

    def save_log(self):
        """Write the log view's contents to a UTF-8 text file chosen by the user."""
        filename = filedialog.asksaveasfilename(
            title="保存日志",
            defaultextension=".txt",
            filetypes=[("Text files", "*.txt"), ("All files", "*.*")]
        )
        if not filename:
            return
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                # "end-1c" drops the trailing newline that Text always appends.
                f.write(self.log_text.get("1.0", "end-1c"))
        except OSError as exc:
            messagebox.showerror("错误", f"保存日志失败: {exc}")
|
||
|
||
|
||
def create_gui():
    """Open the main window and block in the Tk event loop until it is closed."""
    window = tk.Tk()
    RegistrationGUI(window)
    window.mainloop()
|
||
|
||
|
||
# ---------- 主逻辑 ----------
|
||
def main():
    """Run one headless batch with the module-level default settings (``--cli`` mode)."""
    defaults = dict(
        ref_tif=str(REF_TIF),
        bip_dir=str(BIP_DIR),
        out_dir=str(OUT_DIR),
        enable_ref_mask=False,
        ref_mask_tif="",
        ref_mask_remove_value=1,
        matcher_name=str(MATCHER_NAME),
        device=str(DEVICE),
        transform_methods=list(TRANSFORM_METHODS),
        match_max_side=int(MATCH_MAX_SIDE),
        roi_pad_px=int(ROI_PAD_PX),
        mask_pad_px=int(MASK_PAD_PX),
        min_inliers=int(MIN_INLIERS),
        min_inlier_ratio=float(MIN_INLIER_RATIO),
        feather_px=int(FEATHER_PX),
        edge_band_px=int(EDGE_BAND_PX),
        min_grad_quantile=float(MIN_GRAD_QUANTILE),
    )
    batch_cfg = RegistrationConfig(**defaults)
    # A fresh, never-set event: the CLI run has no stop button.
    _run_batch(batch_cfg, threading.Event())
|
||
|
||
if __name__ == "__main__":
    # "--cli" runs one headless batch with the built-in defaults; otherwise start the GUI.
    _entry_point = main if "--cli" in sys.argv else create_gui
    _entry_point()
|