# Origin: StripStitch.py (added as a new file by the patch this text came from).
"""
Batch-register .bip rasters onto a reference .tif raster.

Known issue: when an image is mostly water, too many matches land on the mask
edge, and the filtering stage then also removes the already scarce land
matches.
"""

# NOTE: import order in this module is load-order sensitive (stdout guard ->
# HuggingFace environment -> vismatch), so imports are intentionally not
# regrouped PEP 8 style.
import sys
import os

# Fix for PyInstaller GUI apps: ensure stdout/stderr are never None.
# This prevents "'NoneType' object has no attribute 'write'" errors when
# libraries such as PyTorch try to print download progress.
if sys.stdout is None:
    sys.stdout = open(os.devnull, 'w')
if sys.stderr is None:
    sys.stderr = open(os.devnull, 'w')

from pathlib import Path


def _find_bundled_hf_cache(base, exe_dir):
    """Return the first bundled HuggingFace ``hub`` directory, or ``None``.

    A candidate qualifies when it exists and contains at least one
    sub-directory whose name includes ``vismatch``. Candidates that cannot be
    inspected (``OSError``) are skipped.

    Args:
        base: ``Path`` of the PyInstaller extraction directory.
        exe_dir: ``Path`` of the directory holding the executable.
    """
    candidates = (
        base / "hub",
        base / "_internal" / "hub",
        exe_dir / "_internal" / "hub",
        exe_dir / "hub",
    )
    for hub_dir in candidates:
        try:
            if not hub_dir.exists():
                continue
            if any("vismatch" in d.name.lower() for d in hub_dir.iterdir() if d.is_dir()):
                return hub_dir
        except OSError:
            continue
    return None


def _early_pyinstaller_hf_env():
    """Point HuggingFace at the bundled model cache in a frozen build.

    Must run before ``import vismatch``: per the original author's note,
    ``vismatch/__init__.py`` imports ``huggingface_hub`` immediately.
    Does nothing outside a PyInstaller bundle (no ``sys._MEIPASS``).
    """
    if not hasattr(sys, "_MEIPASS"):
        return
    hub_dir = _find_bundled_hf_cache(Path(sys._MEIPASS), Path(sys.executable).resolve().parent)
    if hub_dir is None:
        return
    os.environ.setdefault("HF_HOME", str(hub_dir.parent))
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hub_dir))
    # Offline mode is forced (not setdefault) here, unlike the other variables.
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")


_early_pyinstaller_hf_env()

import numpy as np
import cv2
import rasterio
import csv
from datetime import datetime
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
from vismatch.viz import plot_matches, plot_keypoints
import logging
import threading
import queue
import traceback
import types
from dataclasses import dataclass
import tkinter as tk
from tkinter import ttk, filedialog, messagebox

# Logging setup.
# This has to happen BEFORE the optional-import probes below: a module-level
# logging call on an unconfigured root logger implicitly runs basicConfig()
# with defaults, after which this explicit basicConfig() would be a silent
# no-op (INFO messages dropped, custom format ignored).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---- Optional dependencies: each probe only records availability ----
try:
    from tif_caijain import mask_data_by_binary_mask
    TIF_MASK_AVAILABLE = True
except Exception:
    # Deliberately broad (as in the original): any failure while importing the
    # project-local helper just disables the feature.
    TIF_MASK_AVAILABLE = False

try:
    from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    logger.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")

try:
    from matplotlib.path import Path as MplPath
    from scipy.spatial import ConvexHull
    MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
    MATPLOTLIB_SCIPY_AVAILABLE = False
    MplPath = None
    logger.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断")

try:
    import SimpleITK as sitk
    SITK_AVAILABLE = True
except ImportError:
    SITK_AVAILABLE = False
    logger.warning("SimpleITK 不可用,将使用仿射变换作为替代")


try:
    import pirt
    PIRT_AVAILABLE = True
except ImportError:
    PIRT_AVAILABLE = False
    logger.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")

try:
    from scipy.interpolate import Rbf
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False
    logger.warning("scipy 不可用,将跳过 TPS 变换")


def _prepend_to_sys_path(directory, label):
    """Insert *directory* at the front of ``sys.path`` unless already present."""
    p = str(directory)
    if p not in sys.path:
        sys.path.insert(0, p)
        logger.info(f"已添加 {label} 到 sys.path: {p}")


def _locate_matchanything_root(third_party_base):
    """Find the MatchAnything directory (the one containing ``src``).

    Tries a fixed list of known layouts first, then falls back to walking
    *third_party_base*. Returns a ``Path`` or ``None``.
    """
    candidates = [
        # Original expected structure
        third_party_base / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
        # Alternative: direct MatchAnything without the nested imcui structure
        third_party_base / "MatchAnything",
        # Alternative: MatchAnything with imcui but different nesting
        third_party_base / "MatchAnything" / "MatchAnything",
        # One more level up possibility
        third_party_base.parent / "MatchAnything" / "imcui" / "third_party" / "MatchAnything",
    ]
    for candidate in candidates:
        # A candidate is valid when it holds a 'src' sub-directory. (The
        # original also special-cased candidates ending in 'src'; none do.)
        if (candidate / "src").exists():
            logger.info(f"找到 MatchAnything 根目录: {candidate}")
            return candidate

    logger.warning("未找到 MatchAnything 目录,尝试的路径:")
    for c in candidates:
        logger.warning(f" - {c} (exists={c.exists()})")

    # Last resort: walk the tree for a directory that has a 'src' child.
    try:
        for root, _dirs, _files in os.walk(third_party_base):
            root_path = Path(root)
            src_dir = root_path / "src"
            if not src_dir.exists():
                continue
            if "matchanything" in root.lower():
                logger.info(f"通过递归搜索找到 MatchAnything: {root_path}")
                return root_path
            # Weaker heuristic: any directory whose 'src' holds Python files.
            if any(src_dir.glob("*.py")):
                logger.info(f"通过递归搜索找到潜在 MatchAnything: {root_path}")
                return root_path
    except OSError as e:
        logger.warning(f"递归搜索失败: {e}")
    return None


def _ensure_pyinstaller_third_party_paths():
    """Make bundled third-party sources importable in a PyInstaller build.

    Locates vismatch's ``third_party`` directory, adds the MatchAnything root
    and (if found) the ROMA directory to ``sys.path``, then fills in the
    HuggingFace cache environment variables that are still unset. Does nothing
    outside a frozen build, and returns early (after logging) when the
    ``third_party`` or MatchAnything directory cannot be found.
    """
    if not hasattr(sys, "_MEIPASS"):
        return
    base = Path(sys._MEIPASS)
    exe_dir = Path(sys.executable).resolve().parent

    # More comprehensive candidate paths for third_party
    candidates = [
        base / "vismatch" / "third_party",
        base / "_internal" / "vismatch" / "third_party",
        exe_dir / "_internal" / "vismatch" / "third_party",
        exe_dir / "vismatch" / "third_party",
        base / "third_party",  # In case vismatch is directly included
    ]
    third_party_base = next((c for c in candidates if c.exists()), None)

    if third_party_base is None:
        logger.warning(f"未找到 third_party 目录,MEIPASS={base}, exe_dir={exe_dir}")
        # List what's available for debugging
        try:
            if base.exists():
                logger.info(f"MEIPASS 内容: {list(base.iterdir())[:10]}")
            if exe_dir.exists():
                logger.info(f"exe_dir 内容: {list(exe_dir.iterdir())[:10]}")
        except OSError as e:
            logger.warning(f"无法列出目录内容: {e}")
        return
    logger.info(f"找到 third_party 目录: {third_party_base}")

    matchanything_root = _locate_matchanything_root(third_party_base)
    if matchanything_root is None:
        return

    # Add MatchAnything root to path (contains 'src' module)
    _prepend_to_sys_path(matchanything_root, "MatchAnything")

    # Try multiple possible ROMA paths
    roma_candidates = [
        matchanything_root / "third_party" / "ROMA",
        third_party_base / "ROMA",
        third_party_base / "MatchAnything" / "third_party" / "ROMA",
        matchanything_root.parent / "ROMA",
    ]
    roma_root = next((c for c in roma_candidates if c.exists()), None)
    if roma_root:
        logger.info(f"找到 ROMA 目录: {roma_root}")
        _prepend_to_sys_path(roma_root, "ROMA")
    else:
        logger.warning("未找到 ROMA 目录")

    # HuggingFace cache: normally already set by _early_pyinstaller_hf_env()
    # (which must run before `import vismatch`); setdefault keeps those values.
    hub_dir = _find_bundled_hf_cache(base, exe_dir)
    if hub_dir is not None:
        os.environ.setdefault("HF_HOME", str(hub_dir.parent))
        os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hub_dir))
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
        logger.info(
            f"HuggingFace 缓存: {os.environ.get('HUGGINGFACE_HUB_CACHE')} "
            f"(HF_HUB_OFFLINE={os.environ.get('HF_HUB_OFFLINE')})"
        )


def _install_loguru_stub_if_missing():
    """Register a minimal ``loguru`` stand-in when the real package is absent.

    The stub forwards to the stdlib logger named ``"loguru"``. loguru messages
    use brace-style placeholders (``"{}"``), so the message is rendered with
    ``str.format`` before being handed to ``logging``; passing the raw message
    and args through would make ``logging`` attempt %-formatting and fail at
    emit time.
    """
    try:
        import loguru  # noqa: F401
        return
    except Exception:
        # Broad on purpose (as in the original): any import failure means
        # "use the stub".
        pass

    py_logger = logging.getLogger("loguru")

    def _render(msg, args, kwargs):
        """Render a loguru-style message; fall back to the raw text on mismatch."""
        text = str(msg)
        if args or kwargs:
            try:
                text = text.format(*args, **kwargs)
            except (IndexError, KeyError, ValueError):
                pass
        return text

    class _StubLogger:
        def _emit(self, level, msg, args, kwargs, exc_info=False):
            # "%s" keeps stray '%' characters in the message from being
            # interpreted by logging's own %-formatting.
            py_logger.log(level, "%s", _render(msg, args, kwargs), exc_info=exc_info)

        def trace(self, msg, *args, **kwargs):
            self._emit(logging.DEBUG, msg, args, kwargs)

        def debug(self, msg, *args, **kwargs):
            self._emit(logging.DEBUG, msg, args, kwargs)

        def info(self, msg, *args, **kwargs):
            self._emit(logging.INFO, msg, args, kwargs)

        def success(self, msg, *args, **kwargs):
            self._emit(logging.INFO, msg, args, kwargs)

        def warning(self, msg, *args, **kwargs):
            self._emit(logging.WARNING, msg, args, kwargs)

        def error(self, msg, *args, **kwargs):
            self._emit(logging.ERROR, msg, args, kwargs)

        def critical(self, msg, *args, **kwargs):
            self._emit(logging.CRITICAL, msg, args, kwargs)

        def exception(self, msg, *args, **kwargs):
            self._emit(logging.ERROR, msg, args, kwargs, exc_info=True)

        def bind(self, *args, **kwargs):
            return self

        def opt(self, *args, **kwargs):
            return self

        def add(self, *args, **kwargs):
            return 0

        def remove(self, *args, **kwargs):
            return None

    m = types.ModuleType("loguru")
    m.logger = _StubLogger()
    sys.modules["loguru"] = m


# ---------- Configuration ----------
# Adjust these paths to your environment.
REF_TIF = r"E:\is2\dingshanhu\mask_water.tif"  # path of the reference tif
BIP_DIR = Path(r"E:\is2\dingshanhu")  # folder holding the .bip files
OUT_DIR = Path(r"E:\is2\dingshanhu\output")  # output folder

# Matcher selection
MATCHER_NAME = "matchanything-roma"  # options: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, ...
DEVICE = "cuda"  # or "cpu"

# Transform models (tried in priority order)
TRANSFORM_METHODS = ["similarity", "affine", "homography"]
# Options: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps"

# Matching parameters
MATCH_MAX_SIDE = 1200  # longest side (pixels) of the images fed to the matcher
ROI_PAD_PX = 500  # padding of the coarse-localisation window (reference tif pixels)
MASK_PAD_PX = 100  # dilation of the matching mask in pixels (matching stage only)

# Quality-control thresholds
MIN_INLIERS = 10
MIN_INLIER_RATIO = 0.01

# Mask-edge feathering and filtering
FEATHER_PX = 20  # feather width in pixels (applied at full / ROI resolution first)
EDGE_BAND_PX = 30  # drop matches closer than this to the mask boundary (scaled on the small image)

# Texture filtering
MIN_GRAD_QUANTILE = 0.20  # gradient-magnitude quantile (0~1); points below it count as low-texture and are dropped

STATS_DIR = None
STATS_CSV = None


@dataclass
class RegistrationConfig:
    """Bundle of registration settings.

    Most fields share their names with the module-level constants above
    (presumably a per-run override of those defaults; confirm against the
    code that constructs it).
    """
    ref_tif: str
    bip_dir: str
    out_dir: str
    enable_ref_mask: bool
    ref_mask_tif: str
    ref_mask_remove_value: int
    matcher_name: str
    device: str
    transform_methods: list
    match_max_side: int
    roi_pad_px: int
    mask_pad_px: int
    min_inliers: int
    min_inlier_ratio: float
    feather_px: int
    edge_band_px: int
    min_grad_quantile: float

# ---------- Utility functions
---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end 
= int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + + +def _pixel_size_xy(transform: Affine): + rx = float(np.hypot(transform.a, transform.d)) + ry = float(np.hypot(transform.b, transform.e)) + if not np.isfinite(rx) or rx <= 0: + rx = float(abs(transform.a)) if transform.a != 0 else 1.0 + if not np.isfinite(ry) or ry <= 0: + ry = float(abs(transform.e)) if transform.e != 0 else 1.0 + return rx, ry + + +def _grid_from_bounds(bounds, res_x: float, res_y: float): + left, bottom, right, top = [float(v) for v in bounds] + res_x = float(abs(res_x)) + res_y = float(abs(res_y)) + w = int(np.ceil((right - left) / max(1e-12, res_x))) + h = int(np.ceil((top - bottom) / max(1e-12, res_y))) + w = max(1, w) + h = max(1, h) + out_transform = Affine(res_x, 0.0, left, 0.0, -res_y, top) + return out_transform, w, h + + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + 
tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 
10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + """将(H,W)二值掩膜缩放到目标尺寸,保持最近邻""" + m = cv2.resize(mask_hw.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST) + return m > 0 + +def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray: + """ + 二值掩膜 -> 软掩膜 alpha∈[0,1],边缘处按距离线性上升,避免硬边缘。 + mask_hw: bool/uint8 (H,W) True/1表示有效 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + # 距离变换仅对前景内部有效,计算到边界的距离 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + if feather_px <= 0: + alpha = (dist > 0).astype(np.float32) + else: + alpha = np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32) + return alpha # (H,W) float32 + +def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray: + """ + 生成"远离边界"的保留掩膜:仅保留距离边界>=min_dist_px的像素。 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + keep = dist >= float(max(1, min_dist_px)) + return keep + +def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray: + """ + 根据梯度幅值生成纹理掩膜(H,W)True=纹理足够。 + 使用与匹配同尺寸的CHW图像。 + """ + # 转灰度 + g = img_chw.mean(axis=0).astype(np.float32) # (H,W) + gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3) + gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3) + mag = np.sqrt(gx*gx + gy*gy) + thr = float(np.quantile(mag, quantile)) if mag.size > 
0 else 0.0 + return mag >= thr # (H,W) bool + +def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict: + """将匹配与内点严格限制在掩膜内""" + if src_mask_small is None or ref_mask_small is None: + return result + + def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray: + if kpts is None or len(kpts) == 0: + return np.zeros((0,), dtype=bool) + kpts = np.asarray(kpts) + xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # 过滤 matched_kpts + if "matched_kpts0" in result and "matched_kpts1" in result: + mk0 = np.asarray(result["matched_kpts0"]) + mk1 = np.asarray(result["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small) + result["matched_kpts0"] = mk0[keep_m] + result["matched_kpts1"] = mk1[keep_m] + + # 过滤 inlier_kpts + if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None: + ik0 = np.asarray(result["inlier_kpts0"]) + ik1 = np.asarray(result["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small) + result["inlier_kpts0"] = ik0[keep_i] + result["inlier_kpts1"] = ik1[keep_i] + result["num_inliers"] = int(len(result["inlier_kpts0"])) + + return result + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + 
+ ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠) + try: + src_mask = (src.read_masks(1) > 0) # True=有效 + rows_any = np.any(src_mask, axis=1) + cols_any = np.any(src_mask, axis=0) + if rows_any.any() and cols_any.any(): + rmin = int(rows_any.argmax()) + rmax = int(src.height - 1 - rows_any[::-1].argmax()) + cmin = int(cols_any.argmax()) + cmax = int(src.width - 1 - cols_any[::-1].argmax()) + valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1) + valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform) + b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21) + else: + # 掩膜无效时回退到整图bounds + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + except Exception: + src_mask = None # 后续可选源图掩膜时用到 + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] 
# 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 软掩膜:避免在边界产生硬高对比边 + try: + alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None + except Exception: + alpha_src = None + try: + alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None + except Exception: + alpha_ref = None + + if alpha_src is not None: + alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * alpha_src3 + + if alpha_ref is not None: + alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * alpha_ref3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in 
locals() and dst_mask is not None else None + + # 剔除掩膜边缘带(小图尺度的最小距离) + def _scale_px(px_full: int, full_wh, small_wh) -> int: + # 用平均缩放;也可以分别对H/W计算后取最小 + sy = small_wh[0] / max(1, full_wh[0]) + sx = small_wh[1] / max(1, full_wh[1]) + s = 0.5 * (sx + sy) + return max(1, int(round(px_full * s))) + + edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2])) + edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2])) + + keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None + keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None + + # 纹理掩膜 + keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE) + keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE) + + # 组合最终保留掩膜(边缘+纹理),二者都要满足 + def _combine_keep(m_edge, m_tex): + if m_edge is None: + return m_tex + return (m_edge & m_tex) + + keep_src_final = _combine_keep(keep_src_edge, keep_src_tex) + keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex) + + # 将匹配与内点严格限制在最终掩膜内 + def _filter_by_bool_masks(res, m_src, m_ref): + if m_src is None or m_ref is None: + return res + + def keep_in(mask_hw, pts): + if pts is None or len(pts) == 0: + return np.zeros((0,), dtype=bool) + xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # matched + if "matched_kpts0" in res and "matched_kpts1" in res: + mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1) + res["matched_kpts0"] = mk0[keep_m] + res["matched_kpts1"] = mk1[keep_m] + + # inliers + if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not 
None: + ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1) + res["inlier_kpts0"] = ik0[keep_i] + res["inlier_kpts1"] = ik1[keep_i] + res["num_inliers"] = int(len(res["inlier_kpts0"])) + return res + + result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + 
s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, 
dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + + res_x, res_y = _pixel_size_xy(src.transform) + out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = src.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=out_h, + width=out_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + 
src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((out_h, out_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=out_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.nearest, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = 
rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + + res_x, res_y = _pixel_size_xy(src.transform) + out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = src.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=out_h, + width=out_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + ref_transform = ref_dataset.transform + Rt = np.array( + [[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0]], + dtype=np.float64, + ) + Out = np.array( + [[out_transform.a, out_transform.b, out_transform.c], + [out_transform.d, out_transform.e, out_transform.f], + [0.0, 0.0, 1.0]], + dtype=np.float64, + ) + M = np.linalg.inv(Out) @ Rt @ H_full.astype(np.float64) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = cv2.warpPerspective( + src_band, + M, + (out_w, out_h), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=float(dst_nodata) + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, 
success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def 
point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, 
bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True, + order=0 + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = 
np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True, + order=0 + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 
0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + 
corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + + res_x, res_y = _pixel_size_xy(src.transform) + out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y) + + out_path = 
def _apply_config(cfg: RegistrationConfig):
    """Publish the fields of *cfg* as this module's settings globals.

    Each configuration attribute is read, converted to the type the rest of
    the module works with, and stored under its upper-case global name.
    Values are written one at a time, in the order listed below, so a
    conversion error leaves the earlier globals already updated.

    Args:
        cfg: Configuration object exposing the attributes named in the
            table below (``ref_tif``, ``bip_dir``, ``out_dir``, ...).
    """
    def as_is(value):
        # Used for fields that are stored without any conversion.
        return value

    # (global name, cfg attribute, converter) in assignment order.
    bindings = (
        ("REF_TIF", "ref_tif", as_is),
        ("BIP_DIR", "bip_dir", Path),
        ("OUT_DIR", "out_dir", Path),
        ("MATCHER_NAME", "matcher_name", as_is),
        ("DEVICE", "device", as_is),
        # Copied into a fresh list so later edits do not touch cfg's own.
        ("TRANSFORM_METHODS", "transform_methods", list),
        ("MATCH_MAX_SIDE", "match_max_side", int),
        ("ROI_PAD_PX", "roi_pad_px", int),
        ("MASK_PAD_PX", "mask_pad_px", int),
        ("MIN_INLIERS", "min_inliers", int),
        ("MIN_INLIER_RATIO", "min_inlier_ratio", float),
        ("FEATHER_PX", "feather_px", int),
        ("EDGE_BAND_PX", "edge_band_px", int),
        ("MIN_GRAD_QUANTILE", "min_grad_quantile", float),
    )

    module_globals = globals()
    for global_name, field_name, convert in bindings:
        module_globals[global_name] = convert(getattr(cfg, field_name))
class ToolTip:
    """Delayed hover tooltip attached to a single Tk widget.

    The tooltip is scheduled when the pointer enters the widget and is
    cancelled / hidden when the pointer leaves or a mouse button is pressed.

    Args:
        widget: Widget to attach to. It must provide ``bind``, ``after`` and
            ``after_cancel``; the ``winfo_*`` geometry methods are used when
            the tooltip is actually shown.
        text: Text displayed in the tooltip. An empty text shows nothing.
        delay_ms: Delay between the pointer entering and the tooltip
            appearing, in milliseconds (coerced with ``int``).
    """

    def __init__(self, widget, text: str, delay_ms: int = 400):
        self.widget = widget
        self.text = text
        self.delay_ms = int(delay_ms)
        self._after_id = None  # id of the pending ``after`` timer, if any
        self._tip = None       # the Toplevel currently showing, if any

        # An empty event sequence is rejected by Tk, so the events must be
        # named explicitly. add=True keeps any bindings the widget already has.
        self.widget.bind("<Enter>", self._on_enter, add=True)
        self.widget.bind("<Leave>", self._on_leave, add=True)
        self.widget.bind("<ButtonPress>", self._on_leave, add=True)

    def _on_enter(self, _event=None):
        """Start the show timer when the pointer enters the widget."""
        self._schedule()

    def _on_leave(self, _event=None):
        """Abort a pending show and remove a visible tooltip."""
        self._cancel()
        self._hide()

    def _schedule(self):
        """(Re)start the delayed call to ``_show``."""
        self._cancel()
        try:
            self._after_id = self.widget.after(self.delay_ms, self._show)
        except Exception:
            # Best effort: a widget that cannot schedule simply gets no tooltip.
            self._after_id = None

    def _cancel(self):
        """Cancel the pending show timer, if there is one."""
        if self._after_id is not None:
            try:
                self.widget.after_cancel(self._after_id)
            except Exception:
                pass
            self._after_id = None

    def _show(self):
        """Create the tooltip window just below the widget."""
        # The timer that invoked us has fired; its id is no longer valid.
        self._after_id = None
        if self._tip is not None:
            return
        if not self.text:
            return
        try:
            x = self.widget.winfo_rootx() + 10
            y = self.widget.winfo_rooty() + self.widget.winfo_height() + 6
        except Exception:
            return

        self._tip = tk.Toplevel(self.widget)
        # No window-manager decorations: the tooltip is a bare floating label.
        self._tip.wm_overrideredirect(True)
        self._tip.wm_geometry(f"+{x}+{y}")

        label = tk.Label(
            self._tip,
            text=self.text,
            justify=tk.LEFT,
            background="#ffffe0",
            relief=tk.SOLID,
            borderwidth=1,
            wraplength=520,
        )
        label.pack(ipadx=6, ipady=4)

    def _hide(self):
        """Destroy the tooltip window if it is currently shown."""
        if self._tip is not None:
            try:
                self._tip.destroy()
            except Exception:
                pass
            self._tip = None
def show_exception_dialog(self, title: str, exc: BaseException):
    """Show *exc* in the error dialog together with its traceback.

    The traceback is taken from the exception object itself rather than from
    the interpreter's "currently handled" exception, so the details are
    correct even when this is called after the ``except`` block has ended.

    Args:
        title: Window title passed through to ``show_error_dialog``.
        exc: The exception to report.
    """
    details = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    # Some exceptions carry no message; show the class name instead of a
    # blank summary line.
    summary = str(exc) or type(exc).__name__
    self.show_error_dialog(title=title, summary=summary, details=details)
text="启用底图掩膜", + variable=self.enable_ref_mask_var, + command=self._on_toggle_ref_mask, + ) + ref_mask_chk.grid(row=1, column=0, sticky=tk.W, padx=(0, 5)) + + self.ref_mask_tif_var = tk.StringVar(value="") + self.ref_mask_entry = ttk.Entry(config_frame, textvariable=self.ref_mask_tif_var, width=50, state=tk.DISABLED) + self.ref_mask_entry.grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + self.ref_mask_btn = ttk.Button(config_frame, text="选择文件", command=self.select_ref_mask_tif, state=tk.DISABLED) + self.ref_mask_btn.grid(row=1, column=2) + + bip_label = ttk.Label(config_frame, text="BIP文件夹:") + bip_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5)) + self.bip_dir_var = tk.StringVar(value=str(BIP_DIR)) + bip_entry = ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50) + bip_entry.grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + bip_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir) + bip_btn.grid(row=2, column=2) + + out_label = ttk.Label(config_frame, text="输出文件夹:") + out_label.grid(row=3, column=0, sticky=tk.W, padx=(0, 5)) + self.out_dir_var = tk.StringVar(value=str(OUT_DIR)) + out_entry = ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50) + out_entry.grid(row=3, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + out_btn = ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir) + out_btn.grid(row=3, column=2) + + matcher_label = ttk.Label(config_frame, text="匹配算法:") + matcher_label.grid(row=4, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + self.matcher_var = tk.StringVar(value=str(MATCHER_NAME)) + matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47) + matcher_combo['values'] = _MATCHER_VALUES + matcher_combo.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + device_label = ttk.Label(config_frame, text="设备:") + device_label.grid(row=5, column=0, sticky=tk.W, padx=(0, 5)) + self.device_var = tk.StringVar(value=str(DEVICE)) + 
device_frame = ttk.Frame(config_frame) + device_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E)) + cuda_rb = ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda") + cpu_rb = ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu") + cuda_rb.pack(side=tk.LEFT) + cpu_rb.pack(side=tk.LEFT) + + transform_label = ttk.Label(config_frame, text="变换方法 (按优先级):") + transform_label.grid(row=6, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + transform_frame = ttk.Frame(config_frame) + transform_frame.grid(row=6, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False) + transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps"] + for method in transform_methods: + self.transform_listbox.insert(tk.END, method) + if method in TRANSFORM_METHODS: + self.transform_listbox.selection_set(transform_methods.index(method)) + + scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview) + self.transform_listbox.configure(yscrollcommand=scrollbar.set) + self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + button_frame = ttk.Frame(transform_frame) + button_frame.pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2)) + ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X) + + param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5") + param_frame.grid(row=7, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0)) + + match_max_side_label = ttk.Label(param_frame, text="匹配最大边长:") + match_max_side_label.grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.match_max_side_var = tk.IntVar(value=int(MATCH_MAX_SIDE)) + match_max_side_entry = 
ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10) + match_max_side_entry.grid(row=0, column=1, sticky=tk.W) + + roi_pad_label = ttk.Label(param_frame, text="ROI填充像素:") + roi_pad_label.grid(row=0, column=2, sticky=tk.W, padx=(10, 5)) + self.roi_pad_px_var = tk.IntVar(value=int(ROI_PAD_PX)) + roi_pad_entry = ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10) + roi_pad_entry.grid(row=0, column=3, sticky=tk.W) + + mask_pad_label = ttk.Label(param_frame, text="掩膜膨胀像素:") + mask_pad_label.grid(row=0, column=4, sticky=tk.W, padx=(10, 5)) + self.mask_pad_px_var = tk.IntVar(value=int(MASK_PAD_PX)) + mask_pad_entry = ttk.Entry(param_frame, textvariable=self.mask_pad_px_var, width=10) + mask_pad_entry.grid(row=0, column=5, sticky=tk.W) + + min_inliers_label = ttk.Label(param_frame, text="最少内点数:") + min_inliers_label.grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.min_inliers_var = tk.IntVar(value=int(MIN_INLIERS)) + min_inliers_entry = ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10) + min_inliers_entry.grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + + min_ratio_label = ttk.Label(param_frame, text="最少内点比例:") + min_ratio_label.grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_inlier_ratio_var = tk.DoubleVar(value=float(MIN_INLIER_RATIO)) + min_ratio_entry = ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10) + min_ratio_entry.grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + + feather_label = ttk.Label(param_frame, text="羽化像素:") + feather_label.grid(row=2, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.feather_px_var = tk.IntVar(value=int(FEATHER_PX)) + feather_entry = ttk.Entry(param_frame, textvariable=self.feather_px_var, width=10) + feather_entry.grid(row=2, column=1, sticky=tk.W, pady=(5, 0)) + + edge_band_label = ttk.Label(param_frame, text="边界剔除像素:") + edge_band_label.grid(row=2, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + 
self.edge_band_px_var = tk.IntVar(value=int(EDGE_BAND_PX)) + edge_band_entry = ttk.Entry(param_frame, textvariable=self.edge_band_px_var, width=10) + edge_band_entry.grid(row=2, column=3, sticky=tk.W, pady=(5, 0)) + + grad_q_label = ttk.Label(param_frame, text="梯度分位阈值:") + grad_q_label.grid(row=2, column=4, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_grad_quantile_var = tk.DoubleVar(value=float(MIN_GRAD_QUANTILE)) + grad_q_entry = ttk.Entry(param_frame, textvariable=self.min_grad_quantile_var, width=10) + grad_q_entry.grid(row=2, column=5, sticky=tk.W, pady=(5, 0)) + + self.add_tooltip(ref_label, "参考底图 GeoTIFF,用于批量配准的目标坐标系与位置基准。建议确保 CRS、transform 正确。") + self.add_tooltip(ref_entry, "参考底图 GeoTIFF 路径。配准时会读取该底图的 ROI 进行匹配。") + self.add_tooltip(ref_btn, "选择参考底图 GeoTIFF 文件。") + + self.add_tooltip(ref_mask_chk, "勾选后先用掩膜 TIF 对底图进行掩膜(掩膜值=1 的区域设置为 NoData),并保存为新的底图;后续配准使用掩膜后的底图。") + self.add_tooltip(self.ref_mask_entry, "掩膜 GeoTIFF 路径。要求与底图严格对齐(相同 CRS、分辨率、范围、尺寸),否则会报错或效果不可控。") + self.add_tooltip(self.ref_mask_btn, "选择掩膜 GeoTIFF 文件。") + + self.add_tooltip(bip_label, "包含待配准航带 .bip 文件的文件夹。程序会批量遍历 *.bip。") + self.add_tooltip(bip_entry, "BIP 文件夹路径。") + self.add_tooltip(bip_btn, "选择 BIP 文件夹。") + + self.add_tooltip(out_label, "输出目录:配准后的航带、可视化图片、统计 CSV 等都会写到这里。") + self.add_tooltip(out_entry, "输出文件夹路径。") + self.add_tooltip(out_btn, "选择输出文件夹。") + + self.add_tooltip(matcher_label, "特征匹配算法名称。不同 matcher 在精度、速度、鲁棒性上差异较大。") + self.add_tooltip(matcher_combo, "选择/输入 matcher 名称。若使用 cuda,需要环境支持 GPU。") + + self.add_tooltip(device_label, "运行设备:cuda(GPU)更快,cpu 更通用。") + self.add_tooltip(cuda_rb, "使用 GPU(CUDA)运行匹配器与部分计算。") + self.add_tooltip(cpu_rb, "使用 CPU 运行。速度可能较慢。") + + self.add_tooltip(transform_label, "变换模型选择(可多选)。配准会按优先级尝试,并自动选择误差较小的模型。") + self.add_tooltip(self.transform_listbox, "按住 Ctrl/Shift 多选。右侧可上移/下移调整优先级。一般 homography 更灵活但更易发散,affine 更稳定。") + + self.add_tooltip(match_max_side_label, "匹配阶段会把图像等比缩小到最大边长不超过该值。值越大越慢,但细节更多。") + self.add_tooltip(match_max_side_entry, 
"匹配用降采样尺寸上限(像素)。") + + self.add_tooltip(roi_pad_label, "参考底图 ROI 的额外扩展像素。增大可覆盖更大不确定区域,但会增加内存与耗时。") + self.add_tooltip(roi_pad_entry, "ROI padding(像素,参考底图坐标系)。") + + self.add_tooltip(mask_pad_label, "仅用于匹配阶段:对源图有效掩膜/重投影后的掩膜做膨胀,增加可匹配区域。") + self.add_tooltip(mask_pad_entry, "掩膜膨胀像素(只影响匹配,不直接改变输出)。") + + self.add_tooltip(min_inliers_label, "RANSAC 内点数量阈值。低于该值认为匹配质量不足,判定失败。") + self.add_tooltip(min_inliers_entry, "最少内点数。") + + self.add_tooltip(min_ratio_label, "内点比例阈值(内点数/匹配点数)。过低通常意味着匹配不可靠。") + self.add_tooltip(min_ratio_entry, "最少内点比例。") + + self.add_tooltip(feather_label, "对掩膜边缘做羽化,降低硬边缘带来的高对比假匹配。数值越大边缘过渡越宽。") + self.add_tooltip(feather_entry, "掩膜羽化宽度(像素)。") + + self.add_tooltip(edge_band_label, "剔除距离掩膜边界过近的匹配点,减少边缘假匹配。数值越大剔除越多。") + self.add_tooltip(edge_band_entry, "边缘带剔除宽度(像素)。") + + self.add_tooltip(grad_q_label, "纹理过滤分位阈值:梯度幅值低于该分位的区域视为低纹理,匹配点会被剔除。") + self.add_tooltip(grad_q_entry, "梯度分位阈值(0~1)。") + + control_frame = ttk.Frame(main_frame) + control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing) + self.start_btn.pack(side=tk.LEFT, padx=(0, 10)) + self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT) + + progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5") + progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 5)) + self.progress_label = ttk.Label(progress_frame, text="准备就绪") + self.progress_label.pack(anchor=tk.W) + + log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5") + log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0)) + log_text_frame = ttk.Frame(log_frame) + 
def select_ref_tif(self):
    """Ask for the reference GeoTIFF and store the chosen path in ``ref_tif_var``.

    A cancelled dialog (empty result) leaves the current value untouched.
    """
    filename = filedialog.askopenfilename(
        title="选择参考TIF文件",
        # One glob per tuple item: a ";"-joined pattern is only understood by the
        # native Windows dialog and matches nothing on other platforms.
        filetypes=[("TIF files", ("*.tif", "*.tiff")), ("All files", "*.*")],
    )
    if filename:
        self.ref_tif_var.set(filename)

def select_ref_mask_tif(self):
    """Ask for the mask GeoTIFF and store the chosen path in ``ref_mask_tif_var``.

    A cancelled dialog (empty result) leaves the current value untouched.
    """
    filename = filedialog.askopenfilename(
        title="选择掩膜TIF文件",
        # Same portable pattern form as in select_ref_tif.
        filetypes=[("TIF files", ("*.tif", "*.tiff")), ("All files", "*.*")],
    )
    if filename:
        self.ref_mask_tif_var.set(filename)

def select_bip_dir(self):
    """Ask for the folder holding the .bip strips and store it in ``bip_dir_var``."""
    dirname = filedialog.askdirectory(title="选择BIP文件夹")
    if dirname:
        self.bip_dir_var.set(dirname)

def select_out_dir(self):
    """Ask for the output folder and store it in ``out_dir_var``."""
    dirname = filedialog.askdirectory(title="选择输出文件夹")
    if dirname:
        self.out_dir_var.set(dirname)

def move_up(self):
    """Move the first selected transform method one row up and keep it selected."""
    selection = self.transform_listbox.curselection()
    if selection and selection[0] > 0:
        idx = selection[0]
        text = self.transform_listbox.get(idx)
        self.transform_listbox.delete(idx)
        self.transform_listbox.insert(idx - 1, text)
        self.transform_listbox.selection_set(idx - 1)

def move_down(self):
    """Move the first selected transform method one row down and keep it selected."""
    selection = self.transform_listbox.curselection()
    if selection and selection[0] < self.transform_listbox.size() - 1:
        idx = selection[0]
        text = self.transform_listbox.get(idx)
        self.transform_listbox.delete(idx)
        self.transform_listbox.insert(idx + 1, text)
        self.transform_listbox.selection_set(idx + 1)

def start_processing(self):
    """Validate the form, build a RegistrationConfig and start the worker thread.

    Any validation problem is reported through a message box and the method
    returns without starting anything. The batch itself runs in a daemon
    thread executing ``run_processing``.
    """
    if self.processing_thread and self.processing_thread.is_alive():
        messagebox.showwarning("警告", "处理正在进行中")
        return

    selected_indices = self.transform_listbox.curselection()
    if not selected_indices:
        messagebox.showwarning("警告", "请至少选择一种变换方法")
        return

    # curselection() is ordered by row, so this keeps the list's priority order.
    transform_methods = [self.transform_listbox.get(i) for i in selected_indices]

    try:
        cfg = RegistrationConfig(
            ref_tif=self.ref_tif_var.get().strip(),
            bip_dir=self.bip_dir_var.get().strip(),
            out_dir=self.out_dir_var.get().strip(),
            enable_ref_mask=bool(self.enable_ref_mask_var.get()),
            ref_mask_tif=self.ref_mask_tif_var.get().strip(),
            ref_mask_remove_value=1,
            matcher_name=self.matcher_var.get().strip(),
            device=self.device_var.get().strip(),
            transform_methods=transform_methods,
            match_max_side=int(self.match_max_side_var.get()),
            roi_pad_px=int(self.roi_pad_px_var.get()),
            mask_pad_px=int(self.mask_pad_px_var.get()),
            min_inliers=int(self.min_inliers_var.get()),
            min_inlier_ratio=float(self.min_inlier_ratio_var.get()),
            feather_px=int(self.feather_px_var.get()),
            edge_band_px=int(self.edge_band_px_var.get()),
            min_grad_quantile=float(self.min_grad_quantile_var.get()),
        )
    except (tk.TclError, ValueError) as exc:
        # IntVar/DoubleVar.get() raises TclError when the entry holds text that
        # is not a number; without this the click would fail with no feedback.
        messagebox.showerror("错误", f"参数输入无效,请检查数值参数: {exc}")
        return

    # Path("") resolves to the current directory and therefore "exists", so
    # empty fields must be rejected explicitly.
    if not cfg.ref_tif or not Path(cfg.ref_tif).is_file():
        messagebox.showerror("错误", "参考 TIF 不存在")
        return
    if not cfg.bip_dir or not Path(cfg.bip_dir).is_dir():
        messagebox.showerror("错误", "BIP 文件夹不存在")
        return
    if not cfg.out_dir:
        messagebox.showerror("错误", "输出文件夹不能为空")
        return
    if cfg.enable_ref_mask:
        if not TIF_MASK_AVAILABLE:
            messagebox.showerror("错误", "tif_caijain.py 不可用,无法进行底图掩膜")
            return
        if not cfg.ref_mask_tif or not Path(cfg.ref_mask_tif).exists():
            messagebox.showerror("错误", "已启用底图掩膜,但掩膜 TIF 文件不存在")
            return

    self.stop_event.clear()
    self.start_btn.config(state=tk.DISABLED)
    self.stop_btn.config(state=tk.NORMAL)
    self.progress_var.set(0)
    self.progress_label.config(text="正在初始化...")

    self.processing_thread = threading.Thread(
        target=self.run_processing,
        args=(cfg,),
        daemon=True
    )
    self.processing_thread.start()

def _on_toggle_ref_mask(self):
    """Enable or disable the mask path entry and button to follow the checkbox."""
    enabled = bool(self.enable_ref_mask_var.get())
    state = tk.NORMAL if enabled else tk.DISABLED
    try:
        self.ref_mask_entry.configure(state=state)
        self.ref_mask_btn.configure(state=state)
    except Exception:
        # Deliberately best-effort, as in the original.
        # NOTE(review): presumably guards against the widgets not existing yet - confirm.
        pass

def stop_processing(self):
    """Request a stop: set ``stop_event`` and update the status label."""
    if self.processing_thread and self.processing_thread.is_alive():
        self.stop_event.set()
        self.progress_label.config(text="正在停止...")

def run_processing(self, cfg: RegistrationConfig):
    """Worker-thread body: run the batch and hand UI updates to the Tk loop.

    Errors are written to the log queue and shown in an error dialog. The
    buttons are always restored, and the status label reflects how the run
    ended (finished, stopped or failed).
    """
    final_text = "处理完成"
    try:
        def progress_cb(current, total, filename):
            self.on_progress(current, total, filename)
        _run_batch(cfg, self.stop_event, progress_cb=progress_cb)
        if self.stop_event.is_set():
            final_text = "已停止"
    except Exception as e:
        tb = traceback.format_exc()
        # Python unbinds ``e`` when the except block ends, but the callback
        # below runs later in the Tk loop; capture the text now.
        err_text = str(e)
        final_text = "处理失败"
        self.log_queue.put(f"处理过程中发生错误: {err_text}\n{tb}")
        try:
            self.root.after(
                0,
                lambda msg=err_text, detail=tb: self.show_error_dialog("处理失败", msg, detail),
            )
        except Exception:
            pass
    finally:
        self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL))
        self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED))
        self.root.after(0, lambda text=final_text: self.progress_label.config(text=text))

def on_progress(self, current, total, filename):
    """Schedule a progress-bar and status-label update; ignored when total <= 0."""
    if total > 0:
        progress = (current / total) * 100
        self.root.after(0, lambda: self.progress_var.set(progress))
        if filename:
            self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})"))
        else:
            self.root.after(0, lambda: self.progress_label.config(text=f"处理中: ({current}/{total})"))

def check_log_queue(self):
    """Drain pending log messages into the text widget, then re-arm in 100 ms."""
    try:
        while True:
            message = self.log_queue.get_nowait()
            self.log_text.insert(tk.END, message + '\n')
            self.log_text.see(tk.END)
    except queue.Empty:
        pass
    self.root.after(100, self.check_log_queue)

def clear_log(self):
    """Delete everything in the log text widget."""
    self.log_text.delete("1.0", tk.END)
def create_gui():
    """Create the Tk root window, attach the registration GUI and run the event loop."""
    window = tk.Tk()
    RegistrationGUI(window)
    window.mainloop()


# ---------- Main logic ----------
def main():
    """Run one headless batch configured from the module-level default constants."""
    settings = {
        "ref_tif": str(REF_TIF),
        "bip_dir": str(BIP_DIR),
        "out_dir": str(OUT_DIR),
        "enable_ref_mask": False,
        "ref_mask_tif": "",
        "ref_mask_remove_value": 1,
        "matcher_name": str(MATCHER_NAME),
        "device": str(DEVICE),
        "transform_methods": list(TRANSFORM_METHODS),
        "match_max_side": int(MATCH_MAX_SIDE),
        "roi_pad_px": int(ROI_PAD_PX),
        "mask_pad_px": int(MASK_PAD_PX),
        "min_inliers": int(MIN_INLIERS),
        "min_inlier_ratio": float(MIN_INLIER_RATIO),
        "feather_px": int(FEATHER_PX),
        "edge_band_px": int(EDGE_BAND_PX),
        "min_grad_quantile": float(MIN_GRAD_QUANTILE),
    }
    # A fresh, never-set event: the CLI run has no way to request a stop.
    _run_batch(RegistrationConfig(**settings), threading.Event())


if __name__ == "__main__":
    # "--cli" anywhere on the command line selects the headless batch run.
    entry_point = main if "--cli" in sys.argv else create_gui
    entry_point()
+import vismatch as _vismatch_pkg +vismatch_sitepkg_root = os.path.dirname(_vismatch_pkg.__file__) + + +hiddenimports = [] +hiddenimports += _safe_collect_submodules("vismatch") +hiddenimports += _safe_collect_submodules("rasterio") +hiddenimports += _safe_collect_submodules("rasterio._base") +hiddenimports += _safe_collect_submodules("rasterio._io") +hiddenimports += _safe_collect_submodules("affine") +hiddenimports += _safe_collect_submodules("cv2") +hiddenimports += ["tif_caijain"] + +hiddenimports += _safe_collect_submodules("pyproj") +hiddenimports += _safe_collect_submodules("scipy") +hiddenimports += _safe_collect_submodules("skimage") +hiddenimports += _safe_collect_submodules("SimpleITK") +hiddenimports += _safe_collect_submodules("pirt") +hiddenimports += _safe_collect_submodules("loguru") + +# MatchAnything's src module - need to collect from third_party directory +# Add the src module path to pathex so PyInstaller can find it during analysis +matchanything_src_dir = os.path.join(vismatch_sitepkg_root, "third_party", "MatchAnything", "imcui", "third_party", "MatchAnything") +if os.path.isdir(matchanything_src_dir): + # Add to pathex for analysis + pass # Will be added to pathex in Analysis + # Also try to collect src submodules if they exist + try: + hiddenimports += _safe_collect_submodules("src") + except Exception: + pass + + +datas = [] +datas += _safe_collect_data_files("vismatch") +datas += _safe_collect_data_files("rasterio") +datas += _safe_collect_data_files("pyproj") + +# vismatch 的 third_party 下包含大量运行时动态 add_to_path 的源码(例如 matchanything 依赖的 src.*)。 +# 这些 .py 文件不会被 collect_data_files 收集,需作为 datas 复制到 dist 里,并保持相对路径结构, +# 使得运行时 THIRD_PARTY_DIR = Path(vismatch.__file__).parent/"third_party" 可找到它们. +# Use the actual vismatch installation location (site-packages) instead of project root +# NOTE: Analysis() expects hook-style 2-tuples (source_dir, dest_dir), not Tree() TOC 3-tuples. 
+third_party_dir = os.path.join(vismatch_sitepkg_root, "third_party") +if os.path.isdir(third_party_dir): + datas.append((third_party_dir, os.path.join("vismatch", "third_party"))) + +# Include HuggingFace model weights for offline use (~/.cache/huggingface/hub/) +hf_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub") +if os.path.isdir(hf_cache_dir): + for model_dir in os.listdir(hf_cache_dir): + if "vismatch" in model_dir.lower(): + full_model_path = os.path.join(hf_cache_dir, model_dir) + if os.path.isdir(full_model_path): + datas.append((full_model_path, os.path.join("hub", model_dir))) + + +binaries = [] +binaries += _safe_collect_dynamic_libs("rasterio") +binaries += _safe_collect_dynamic_libs("pyproj") +binaries += _safe_collect_dynamic_libs("cv2") + +# Build pathex - include MatchAnything src directory for proper import analysis +pathex = [project_root, test_dir] +matchanything_src_dir = os.path.join(vismatch_sitepkg_root, "third_party", "MatchAnything", "imcui", "third_party", "MatchAnything") +if os.path.isdir(matchanything_src_dir): + pathex.append(matchanything_src_dir) + # Also add ROMA if it exists + roma_dir = os.path.join(vismatch_sitepkg_root, "third_party", "MatchAnything", "imcui", "third_party", "MatchAnything", "third_party", "ROMA") + if os.path.isdir(roma_dir): + pathex.append(roma_dir) + +a = Analysis( + [script_path], + pathex=pathex, + binaries=binaries, + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, +) + +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name="StripStitch", + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + upx_exclude=[], + runtime_tmpdir=None, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) + +coll = COLLECT( + exe, + a.binaries, + 
a.zipfiles, + a.datas, + strip=False, + upx=False, + upx_exclude=[], + name="StripStitch", +) diff --git a/StripStitch_说明.md b/StripStitch_说明.md new file mode 100644 index 0000000..04a75cf --- /dev/null +++ b/StripStitch_说明.md @@ -0,0 +1,287 @@ +## `StripStitch.py` 说明文档(遥感航带批量配准) + +`StripStitch.py` 用于 **批量将文件夹内的 `.bip` 航带影像配准到一张参考 GeoTIFF(`.tif`)底图**。脚本默认启动 **GUI(Tkinter)**,也支持通过 `--cli` 走“无界面批处理”(但当前版本 **没有 argparse 参数**,CLI 模式仍使用脚本内的默认配置/或你手动改常量)。 + +--- + +## 1. 功能概览 + +- **批量处理**:遍历 `BIP_DIR` 下的 `*.bip`,逐个配准到参考底图。 +- **自动裁剪参考 ROI**:根据源图有效像素掩膜推算与底图的重叠范围,并在底图上取 ROI(可额外 padding)。 +- **特征匹配**:通过 `vismatch.get_matcher()` 调用多种 matcher(可选 GPU/CPU)。 +- **质量控制**: + - 掩膜羽化(降低硬边界导致的假匹配) + - 掩膜边缘带剔除(减少“掩膜边界”上的伪匹配点) + - 纹理过滤(低梯度/低纹理区域的匹配点剔除) + - 最少内点数/最少内点比例阈值 +- **多模型变换尝试并自动选最优**:按你设置的优先级尝试 `similarity/affine/homography/piecewise_affine/polynomial/tps...`,以重投影误差(p95)选择最优。 +- **输出**: + - 配准后的 `.bip`(ENVI driver) + - 匹配可视化图(匹配线、关键点) + - 统计 CSV(每个航带一行:内点数、内点比例、选用模型、误差等) +- **可选底图掩膜**:先用另一个 mask GeoTIFF 对参考底图做掩膜生成新底图,再用新底图进行配准(依赖 `tif_caijain.py`)。 + +--- + +## 2. 运行方式 + +### 2.1 GUI(默认) + +在脚本所在目录执行: + +```bash +python StripStitch.py +``` + +GUI 会打开“遥感影像批量配准工具”,在界面里选择: + +- 参考底图(GeoTIFF) +- BIP 文件夹 +- 输出文件夹 +- matcher、设备(CUDA/CPU) +- 变换模型优先级 +- 一系列阈值/过滤参数 + +点击“开始”后批处理运行,日志会实时显示;可点击“停止”中断。 + +### 2.2 CLI(无界面批处理) + +```bash +python StripStitch.py --cli +``` + +注意: + +- 当前脚本仅通过 `if "--cli" in sys.argv` 切换模式,**没有命令行参数解析**。 +- CLI 模式使用脚本顶部的默认常量(如 `REF_TIF / BIP_DIR / OUT_DIR / MATCHER_NAME ...`)构造配置。 +- 如果你希望 CLI 可配置参数,需要额外改造(例如用 `argparse`)。 + +--- + +## 3. 输入数据要求 + +- **参考底图**:GeoTIFF(`.tif`),必须有 **有效 CRS/transform**(脚本会报错或无法正确投影/裁剪)。 +- **待配准航带**:ENVI BIP(扩展名 `.bip`,内部 profile 由 `rasterio` 读取)。 +- **坐标系**: + - 源 `.bip` 若缺少 CRS,脚本会尝试使用参考底图 CRS 继续(可能导致错误配准,建议为源数据补齐 CRS)。 + - ROI 的推算使用 `rasterio.warp.transform_bounds` 从源 CRS 转到参考 CRS。 + +--- + +## 4. 
输出结构与文件命名 + +假设输出目录为 `OUT_DIR`,运行一次批处理会生成: + +- **配准结果**:`OUT_DIR/_registered.bip` + - 写入使用 `rasterio.open(..., driver="ENVI", interleave="bip")` + - 旁边通常会伴随 ENVI 的头文件(如 `.hdr` 等,取决于 `rasterio/GDAL` 行为) +- **可视化**:`OUT_DIR/visualizations/` + - `_matches.png`:匹配线可视化 + - `_keypoints_src.png`:源图关键点 + - `_keypoints_ref.png`:参考 ROI 关键点 +- **统计**:`OUT_DIR/stats/registration_stats_.csv` + - 列:`timestamp, filename, num_inliers, num_matches, inlier_ratio, selected_method, median_error, p95_error, success` +- **(可选)掩膜后的参考底图**:`OUT_DIR/masked_refs/_masked_.tif` + - 仅当 GUI 勾选“启用底图掩膜”且 `tif_caijain.py` 可用时生成 + +--- + +## 5. 参数说明(GUI 与脚本配置一致) + +### 5.1 基础路径 + +- **参考 TIF 文件**:配准目标底图(匹配在底图 ROI 上进行)。 +- **BIP 文件夹**:批量遍历 `*.bip`。 +- **输出文件夹**:结果、可视化、统计都写入这里。 + +### 5.2 底图掩膜(可选) + +GUI 勾选“启用底图掩膜”后: + +- 需要提供 **掩膜 TIF**,并要求与底图 **严格对齐**(同 CRS、分辨率、范围、尺寸)。 +- 掩膜值为 `remove_value` 的区域会被设置为 NoData(由 `tif_caijain.mask_data_by_binary_mask` 实现)。 +- 脚本会先生成“掩膜后的底图”,后续配准基于该底图进行。 + +### 5.3 matcher 与设备 + +- **匹配算法(matcher_name)**:由 `vismatch.get_matcher(name, device=...)` 创建。 + - GUI 下拉框内内置了一长串候选(如 `matchanything-roma`, `loftr`, `sift-lightglue`, `roma`, `xfeat-star` 等)。 + - 不同 matcher 在速度/鲁棒性/显存占用方面差异很大。 +- **设备(device)**: + - `cuda`:更快,但需要 GPU + CUDA 环境 + - `cpu`:更通用,但会慢 + +### 5.4 变换模型(按优先级) + +可多选并排序;脚本会按优先级逐一估计并评估误差,最终选 p95 重投影误差最小者。 + +常用含义(对应 GUI 列表): + +- **`similarity`**:相似变换(平移 + 旋转 + 等比缩放) +- **`affine`**:仿射变换(含非等比缩放/切变) +- **`homography`**:单应(透视)变换,最灵活但更易受离群点影响 +- **`piecewise_affine` / `polynomial` / `polynomial_order3` / `tps`**:非刚性/高阶模型(依赖可选库,见“依赖”) + +### 5.5 匹配与 ROI 参数 + +- **匹配最大边长(match_max_side)**:匹配前会等比缩小到最大边不超过该值。越大越慢、细节更多。 +- **ROI 填充像素(roi_pad_px)**:从源有效区域推算出的底图 ROI,会再向外扩展该像素数(底图像素尺度)。 +- **掩膜膨胀像素(mask_pad_px)**:仅用于匹配阶段,对源有效掩膜重投影到参考 ROI 后进行膨胀,扩大可匹配区域。 + +### 5.6 质量控制(建议从默认值起调) + +- **最少内点数(min_inliers)**:过滤后 RANSAC 内点数低于该值直接判失败。 +- **最少内点比例(min_inlier_ratio)**:\(\text{inliers} / \text{matches}\) 低于阈值判失败。 +- **掩膜羽化宽度(feather_px)**:在掩膜边界做平滑过渡,减少硬边界假匹配。 +- 
**边缘带剔除宽度(edge_band_px)**:剔除距离掩膜边界过近的匹配点(小图尺度会按缩放比例换算)。 +- **纹理过滤分位阈值(min_grad_quantile)**:在匹配尺寸上计算梯度幅值分位数,低纹理区域的点会被剔除。值越大,保留的区域越“高纹理”。 + +--- + +## 6. 内部处理流程(高层) + +单个 `.bip` 的核心流程(简化版): + +- 读取源图与源有效掩膜(`read_masks`),在源 CRS 下取有效像素包围盒 +- 包围盒 bounds 投影到参考 CRS,在参考底图上构建 ROI window,并额外 `roi_pad_px` +- 读取源图全图、参考底图 ROI +- 将源有效掩膜重投影到参考 ROI,并按 `mask_pad_px` 膨胀 +- 对源/参考做掩膜羽化(`feather_px`)后进入匹配 +- 下采样到 `match_max_side`,运行 matcher 得到匹配点/内点 +- 用“边缘带剔除 + 纹理过滤”对匹配点/内点二次过滤,得到最终内点与质量指标 +- 在你选定的多个变换模型中估计并评估,选 p95 误差最小者 +- 根据模型类型执行重采样并写出 ENVI BIP +- 写统计 CSV,保存可视化图片 + +--- + +## 7. 依赖与可选依赖(缺失时的行为) + +### 7.1 必需依赖(脚本直接 import) + +- `numpy` +- `opencv-python`(`cv2`) +- `rasterio` +- `affine` +- `tkinter`(Windows 自带 Python 通常包含) +- `vismatch`(脚本依赖它来创建 matcher,并用 `vismatch.viz` 保存可视化图) + +### 7.2 可选依赖(缺失会降级/跳过) + +- **`tif_caijain.py`**:用于“底图掩膜”功能;缺失则 GUI 勾选会报错。 +- **`scikit-image`**:用于 `piecewise_affine` / `polynomial` 等变换;缺失会跳过这些方法或走回退逻辑。 +- **`matplotlib` + `scipy`(ConvexHull)**:用于点集凸包的“内点判定”,缺失时会退化为矩形内判断(更可能外推导致异常区域)。 +- **`SimpleITK` / `pirt` / `scipy.interpolate.Rbf`**:用于 TPS 等更复杂的非线性变换与回退路径;缺失时可能回退到更简单的仿射。 + +--- + +## 8. 常见问题排查 + +- **报错:参考文件缺少 CRS 信息** + - 参考 GeoTIFF 必须有 CRS。用 GIS 软件或 `gdalinfo` 检查并修复投影信息。 + +- **源 `.bip` 缺少 CRS** + - 脚本会“尝试用参考 CRS”,但这通常不可靠;建议为源数据补齐正确 CRS。 + +- **内点很少 / 内点比例很低** + - 优先检查: + - 参考底图与航带是否确实有重叠区域 + - ROI padding 是否过小(`roi_pad_px`) + - 边缘带剔除是否过强(`edge_band_px`) + - 纹理过滤是否过强(`min_grad_quantile`) + - 其次尝试更鲁棒的 matcher(或改用 GPU)。 + +- **输出范围不对 / 被裁得太小** + - 输出范围由匹配点映射到参考像素后的外接矩形决定;匹配点集中在局部会导致 bbox 偏小。 + - 可尝试增大 `roi_pad_px`、降低过滤强度、或换更稳定的 matcher。 + +- **非线性方法(piecewise/polynomial/tps)经常失败** + - 先确保 `scikit-image`、`scipy` 等可用。 + - 非线性方法更依赖匹配点覆盖范围与质量;当点分布很局部时更容易外推/数值不稳。 + - 生产环境通常建议把 `affine` 作为兜底并放在较高优先级。 + +--- + +## 9. 
推荐使用习惯(实操)

- **先用少量样本跑通**:把 `BIP_DIR` 里先放 1–3 条航带验证匹配与输出,再批量跑全量。
- **先看 `visualizations/`**:匹配线与关键点能最快判断“是没重叠、还是过滤过强、还是 matcher 不适配”。
- **保留 `stats/`**:后续筛选失败样本、做参数回归/对比很有用。

```mermaid
flowchart TD
    A[开始处理单个BIP文件] --> B[读取源文件和参考文件]
    B --> C{检查CRS坐标系统}
    C -->|源文件无CRS| D[使用参考文件CRS]
    C -->|源文件有CRS| E[使用源文件CRS]
    D --> F[计算有效区域ROI]
    E --> F

    F --> F1[基于有效掩膜计算包围盒]
    F1 --> F2[将包围盒转换到参考坐标系]
    F2 --> F3[扩展窗口ROI_PAD_PX]

    F3 --> G[读取图像数据]
    G --> H[将源图掩膜重投影到参考空间]
    H --> H1[可选: 膨胀掩膜MASK_PAD_PX]

    H1 --> I[图像预处理]
    I --> I1[转3通道_float01格式]
    I1 --> I2[百分位数拉伸归一化]

    I2 --> J[生成软掩膜]
    J --> J1[距离变换生成羽化边缘]
    J1 --> J2[应用掩膜到图像]

    J2 --> K[降采样用于匹配]
    K --> K1[等比缩放到MATCH_MAX_SIDE]

    K1 --> L[特征匹配]
    L --> L1[调用Matcher获取匹配点]

    L1 --> M[匹配点过滤]
    M --> M1[边缘带剔除 EDGE_BAND_PX]
    M --> M2[纹理过滤 MIN_GRAD_QUANTILE]
    M1 --> M3[组合掩膜过滤]
    M2 --> M3

    M3 --> N{质量控制检查}
    N -->|内点数低于 MIN_INLIERS| O[配准失败]
    N -->|内点比例低于 MIN_INLIER_RATIO| O
    N -->|通过检查| P[计算全分辨率坐标]

    P --> Q[估计多种变换模型]
    Q --> Q1[similarity]
    Q --> Q2[affine]
    Q --> Q3[homography]
    Q --> Q4[piecewise_affine]
    Q --> Q5[polynomial]

    Q1 --> R[评估变换质量]
    Q2 --> R
    Q3 --> R
    Q4 --> R
    Q5 --> R

    R --> S[选择最优变换模型]
    S --> T{变换类型判断}

    T -->|Affine| U[仿射变换处理]
    T -->|Homography| V[单应变换处理]
    T -->|Piecewise/Polynomial| W[非线性变换处理]

    U --> X[计算最小外接矩形]
    V --> X
    W --> X

    X --> Y[创建输出文件]
    Y --> Z[逐波段几何重采样]
    Z --> AA[保存配准结果]
    AA --> AB[记录统计信息]
    AB --> AC[结束]

    O --> AD[记录失败信息]
    AD --> AC

    T -->|所有变换失败| AE[仿射回退处理]
    AE -->|回退成功| AA
    AE -->|回退失败| O
```

diff --git a/__pycache__/tif_caijain.cpython-310.pyc b/__pycache__/tif_caijain.cpython-310.pyc
new file mode 100644
index 0000000..1fbd339
Binary files /dev/null and b/__pycache__/tif_caijain.cpython-310.pyc differ
diff --git a/test V10.py b/test V10.py
new file mode 100644
index 0000000..ae526b0
--- /dev/null
+++ b/test V10.py
@@ -0,0 +1,1336 @@
+"""
+批量配准 .bip 文件到参考 .tif 文件
+问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了
+"""
+
+from pathlib import Path
+import numpy as np
+import cv2
+import rasterio
+import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\agnle") # .bip 文件所在文件夹 +OUT_DIR = Path(r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\jiaozhen") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["similarity", "affine", "homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + 
def init_stats_csv(csv_path: Path):
    """Create the statistics CSV with its header row unless the file already exists."""
    if not csv_path.exists():
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow([
                'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio',
                'selected_method', 'median_error', 'p95_error', 'success'
            ])


def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int,
                           inlier_ratio: float, selected_method: str, median_error: float,
                           p95_error: float, success: bool):
    """Append one registration result row to the statistics CSV.

    The row starts with the current local time; the ratio and both errors are
    written with four decimals.
    """
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    with open(csv_path, 'a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow([
            timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}",
            selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success
        ])


def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
    """Convert a (C,H,W) array to (3,H,W) float32 values in [0, 1].

    One band is repeated three times; three or more bands keep the first
    three. Values are stretched between the 2nd and 98th percentile of the
    selected bands and clipped.

    Raises:
        ValueError: if the array has 0 or 2 bands.
    """
    arr = arr_chw.astype(np.float32)

    if arr.shape[0] == 1:
        # Single band: replicate to three channels.
        arr = np.repeat(arr, 3, axis=0)
    elif arr.shape[0] >= 3:
        # Keep the first three bands.
        arr = arr[:3]
    else:
        raise ValueError(f"不支持的通道数: {arr.shape[0]}")

    # Percentile stretch makes cross-sensor matching more stable.
    p2 = np.percentile(arr, 2)
    p98 = np.percentile(arr, 98)
    arr = (arr - p2) / (p98 - p2 + 1e-6)
    arr = np.clip(arr, 0.0, 1.0)
    return arr


def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
    """Shrink a (C,H,W) array proportionally so that max(H, W) <= max_side.

    The input object is returned unchanged when it already fits or has an
    empty spatial dimension. The image is never enlarged.
    """
    c, h, w = arr_chw.shape
    if h == 0 or w == 0:
        # Nothing to resize; also avoids dividing by zero below.
        return arr_chw
    s = min(1.0, max_side / max(h, w))
    if s >= 1.0:
        return arr_chw
    # Long thin strips can round the short side down to 0, which cv2.resize
    # rejects; keep at least one pixel per side.
    new_w = max(1, int(round(w * s)))
    new_h = max(1, int(round(h * s)))
    # Resize with OpenCV, channel by channel.
    out = np.stack(
        [cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)],
        axis=0,
    )
    return out


def _expand_window(win, pad, max_w, max_h):
    """Grow a raster window by ``pad`` pixels per side, clamped to the raster.

    Args:
        win: window exposing ``col_off``, ``row_off``, ``width`` and ``height``.
        pad: number of pixels added on every side.
        max_w: raster width used as the right clamp.
        max_h: raster height used as the bottom clamp.

    Returns:
        A ``rasterio.windows.Window`` with integer offsets and size.
    """
    col_off = int(max(0, win.col_off - pad))
    row_off = int(max(0, win.row_off - pad))
    col_end = int(min(max_w, win.col_off + win.width + pad))
    row_end = int(min(max_h, win.row_off + win.height + pad))
    # The height must come from the row extent; the original passed the column
    # extent twice, which made every expanded window square.
    return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)


def _pixel_size_xy(transform: "Affine"):
    """Return the pixel size (rx, ry) described by an affine geotransform.

    rx is hypot(a, d) and ry is hypot(b, e). A non-finite or non-positive
    result falls back to |a| / |e|, or to 1.0 when that coefficient is zero.
    """
    rx = float(np.hypot(transform.a, transform.d))
    ry = float(np.hypot(transform.b, transform.e))
    if not np.isfinite(rx) or rx <= 0:
        rx = float(abs(transform.a)) if transform.a != 0 else 1.0
    if not np.isfinite(ry) or ry <= 0:
        ry = float(abs(transform.e)) if transform.e != 0 else 1.0
    return rx, ry


def _grid_from_bounds(bounds, res_x: float, res_y: float):
    """Build a north-up grid covering ``bounds`` at the given resolution.

    Args:
        bounds: (left, bottom, right, top).
        res_x: pixel width; its absolute value is used.
        res_y: pixel height; its absolute value is used.

    Returns:
        (transform, width, height), with width and height at least 1.
    """
    left, bottom, right, top = [float(v) for v in bounds]
    res_x = float(abs(res_x))
    res_y = float(abs(res_y))
    # The 1e-12 floor avoids dividing by a zero resolution.
    w = int(np.ceil((right - left) / max(1e-12, res_x)))
    h = int(np.ceil((top - bottom) / max(1e-12, res_y)))
    w = max(1, w)
    h = max(1, h)
    out_transform = Affine(res_x, 0.0, left, 0.0, -res_y, top)
    return out_transform, w, h
def evaluate_transform_quality(transform_type, transform, k0, k1):
    """Measure how well a transform maps ``k0`` onto ``k1``.

    Args:
        transform_type: "A" (2x3 affine matrix), "H" (3x3 homography), or
            "piecewise" / "polynomial" (a callable applied to the points).
        transform: the matrix or callable; None means estimation failed.
        k0: (N, 2) source points.
        k1: (N, 2) target points.

    Returns:
        (median, 95th percentile) of the point-wise Euclidean error, or
        (inf, inf) for a missing transform, no points, or an unknown type.
    """
    failed = (np.inf, np.inf)
    if transform is None or len(k0) == 0:
        return failed

    if transform_type in ("A", "H"):
        # Homogeneous source coordinates, laid out as columns.
        unit_col = np.ones((k0.shape[0], 1), dtype=np.float32)
        homog = np.hstack([k0, unit_col]).T
        mapped = transform @ homog
        if transform_type == "H":
            # Perspective divide; the epsilon keeps a zero w from dividing by zero.
            mapped = mapped / (mapped[2:3, :] + 1e-6)
            mapped = mapped[:2, :]
        predicted = mapped.T
    elif transform_type in ("piecewise", "polynomial"):
        # These transform objects are applied by calling them on the points.
        predicted = transform(k0)
    else:
        return failed

    residual = predicted - k1
    dist = np.sqrt((residual ** 2).sum(axis=1))
    return float(np.median(dist)), float(np.percentile(dist, 95))


def _norm01_hw(x: np.ndarray) -> np.ndarray:
    """Stretch an array to [0, 1] between its 2nd and 98th percentile, then clip."""
    img = x.astype(np.float32, copy=False)
    low = float(np.percentile(img, 2))
    high = float(np.percentile(img, 98))
    span = high - low + 1e-6
    return np.clip((img - low) / span, 0.0, 1.0)


def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)):
    """Wrap a numpy (H,W) array as a float32 SimpleITK image.

    Physical coordinates are set up as pixel coordinates: spacing 1,
    identity direction, and origin (x0, y0) taken from ``origin_xy``.
    """
    x0, y0 = float(origin_xy[0]), float(origin_xy[1])
    image = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False))
    image.SetSpacing((1.0, 1.0))
    image.SetOrigin((x0, y0))
    image.SetDirection((1.0, 0.0, 0.0, 1.0))
    return image


def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10):
    """Return the padded bounding box of the target-side points, clamped to the reference.

    Returns:
        (min_x, min_y, width, height) in reference pixel coordinates.
    """
    xs = k1_global[:, 0]
    ys = k1_global[:, 1]
    left = max(0, int(np.floor(xs.min())) - pad)
    top = max(0, int(np.floor(ys.min())) - pad)
    right = min(ref_w, int(np.ceil(xs.max())) + pad)
    bottom = min(ref_h, int(np.ceil(ys.max())) + pad)
    return left, top, right - left, bottom - top


def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Resize an (H,W) binary mask to the target size with nearest-neighbour sampling."""
    as_bytes = mask_hw.astype(np.uint8)
    resized = cv2.resize(as_bytes, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
    return resized > 0


def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray:
    """Turn a binary mask into a soft alpha in [0, 1] that ramps up away from the edge.

    Args:
        mask_hw: bool/uint8 (H,W); True/non-zero marks valid pixels. None is
            passed through as None.
        feather_px: ramp width in pixels; <= 0 yields a hard 0/1 alpha.

    Returns:
        float32 (H,W) alpha, or None when ``mask_hw`` is None.
    """
    if mask_hw is None:
        return None
    foreground = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255
    # Distance of every foreground pixel to the nearest background pixel.
    dist = cv2.distanceTransform(foreground, distanceType=cv2.DIST_L2, maskSize=3)
    if feather_px <= 0:
        return (dist > 0).astype(np.float32)
    return np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32)
def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray:
    """Build a texture mask (H,W) from the gradient magnitude; True = enough texture.

    The channels are averaged to grey, the 3x3 Sobel gradient magnitude is
    computed, and pixels at or above the given quantile of that magnitude are
    kept. Intended for the CHW image at matching size.
    """
    # Grey image as the channel mean.
    gray = img_chw.mean(axis=0).astype(np.float32)
    dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = np.sqrt(dx * dx + dy * dy)
    if magnitude.size > 0:
        threshold = float(np.quantile(magnitude, quantile))
    else:
        threshold = 0.0
    return magnitude >= threshold


def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict:
    """Keep only the matches and inliers whose two endpoints both lie inside the masks.

    ``result`` is modified in place and returned. It is left untouched when
    either mask is None. Point pairs are filtered only when both key arrays
    exist, have equal length and are non-empty; filtering the inliers also
    refreshes ``num_inliers``.
    """
    if src_mask_small is None or ref_mask_small is None:
        return result

    def inside(points: np.ndarray, mask_hw: np.ndarray) -> np.ndarray:
        # Look the mask up at the rounded (x, y), clamped to the mask bounds.
        if points is None or len(points) == 0:
            return np.zeros((0,), dtype=bool)
        points = np.asarray(points)
        rows, cols = mask_hw.shape[0], mask_hw.shape[1]
        col_idx = np.clip(np.rint(points[:, 0]).astype(int), 0, cols - 1)
        row_idx = np.clip(np.rint(points[:, 1]).astype(int), 0, rows - 1)
        return mask_hw[row_idx, col_idx]

    key_pairs = (
        ("matched_kpts0", "matched_kpts1", False),
        ("inlier_kpts0", "inlier_kpts1", True),
    )
    for key0, key1, is_inlier in key_pairs:
        if key0 not in result or key1 not in result:
            continue
        # Only the inlier entry is allowed to be None (no inliers reported).
        if is_inlier and result[key0] is None:
            continue
        pts0 = np.asarray(result[key0])
        pts1 = np.asarray(result[key1])
        if len(pts0) != len(pts1) or len(pts0) == 0:
            continue
        keep = inside(pts0, src_mask_small) & inside(pts1, ref_mask_small)
        result[key0] = pts0[keep]
        result[key1] = pts1[keep]
        if is_inlier:
            result["num_inliers"] = int(len(result[key0]))

    return result
# 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 软掩膜:避免在边界产生硬高对比边 + try: + alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None + except Exception: + alpha_src = None + try: + alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None + except Exception: + alpha_ref = None + + if alpha_src is not None: + alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * alpha_src3 + + if alpha_ref is not None: + alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * alpha_ref3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() 
and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None + + # 剔除掩膜边缘带(小图尺度的最小距离) + def _scale_px(px_full: int, full_wh, small_wh) -> int: + # 用平均缩放;也可以分别对H/W计算后取最小 + sy = small_wh[0] / max(1, full_wh[0]) + sx = small_wh[1] / max(1, full_wh[1]) + s = 0.5 * (sx + sy) + return max(1, int(round(px_full * s))) + + edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2])) + edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2])) + + keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None + keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None + + # 纹理掩膜 + keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE) + keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE) + + # 组合最终保留掩膜(边缘+纹理),二者都要满足 + def _combine_keep(m_edge, m_tex): + if m_edge is None: + return m_tex + return (m_edge & m_tex) + + keep_src_final = _combine_keep(keep_src_edge, keep_src_tex) + keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex) + + # 将匹配与内点严格限制在最终掩膜内 + def _filter_by_bool_masks(res, m_src, m_ref): + if m_src is None or m_ref is None: + return res + + def keep_in(mask_hw, pts): + if pts is None or len(pts) == 0: + return np.zeros((0,), dtype=bool) + xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # matched + if "matched_kpts0" in res and "matched_kpts1" in res: + mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1) + res["matched_kpts0"] = mk0[keep_m] 
+ res["matched_kpts1"] = mk1[keep_m] + + # inliers + if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None: + ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1) + res["inlier_kpts0"] = ik0[keep_i] + res["inlier_kpts1"] = ik1[keep_i] + res["num_inliers"] = int(len(res["inlier_kpts0"])) + return res + + result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 
用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 
仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + + res_x, res_y = _pixel_size_xy(src.transform) + out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = src.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=out_h, + width=out_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + 
compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((out_h, out_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=out_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.nearest, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: 
{bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + + res_x, res_y = _pixel_size_xy(src.transform) + out_transform, out_w, out_h = _grid_from_bounds(bounds, res_x, res_y) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = src.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=out_h, + width=out_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + ref_transform = ref_dataset.transform + Rt = np.array( + [[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0]], + dtype=np.float64, + ) + Out = np.array( + [[out_transform.a, out_transform.b, out_transform.c], + [out_transform.d, out_transform.e, out_transform.f], + [0.0, 0.0, 1.0]], + dtype=np.float64, + ) + M = np.linalg.inv(Out) @ Rt @ H_full.astype(np.float64) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = cv2.warpPerspective( + src_band, + M, + (out_w, out_h), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=float(dst_nodata) + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + 
log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x 
+ bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band 
= src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True, + order=0 + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = 
transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True, + order=0 + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = 
np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + 
[ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bounds = rasterio.windows.bounds(bbox_window, transform=ref_dataset.transform) + 
# ---------- Main logic ----------
def main():
    """Batch-register every ``*.bip`` file in ``BIP_DIR`` against ``REF_TIF``.

    Command-line entry point driven by the module-level configuration
    (``REF_TIF``, ``BIP_DIR``, ``OUT_DIR``, ``STATS_CSV``, ``MATCHER_NAME``,
    ``DEVICE``). It validates the inputs, initialises the statistics CSV,
    builds the matcher once, opens the reference raster once and calls
    ``process_bip_to_tif`` for each input file. Returns ``None``; progress and
    the final success count are reported through the module logger.
    """
    logger.info("开始批量配准处理...")

    # Validate the inputs before doing any expensive work.
    if not Path(REF_TIF).exists():
        logger.error(f"参考文件不存在: {REF_TIF}")
        return

    if not BIP_DIR.exists():
        logger.error(f"BIP文件夹不存在: {BIP_DIR}")
        return

    # process_bip_to_tif creates ``out_dir / "visualizations"`` without
    # parents=True and writes its outputs into out_dir, so the directory has
    # to exist before the first file is processed.
    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

    # Initialise the statistics CSV.
    init_stats_csv(STATS_CSV)
    logger.info(f"统计信息将保存到: {STATS_CSV}")

    # Collect inputs first, in sorted order so runs are reproducible.
    bip_files = sorted(BIP_DIR.glob("*.bip"))
    logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
    if not bip_files:
        # Nothing to do: skip loading the matcher model and opening the raster.
        logger.info(f"处理完成: 0/{len(bip_files)} 个文件成功配准")
        return

    # Build the matcher once and reuse it for every file.
    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)

    # Keep the reference dataset open for the whole batch.
    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")

        success_count = 0
        for bip_path in bip_files:
            if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV):
                success_count += 1

        logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")


if __name__ == "__main__":
    main()
class _NullWriter:
    """Write-only sink that discards everything; used when no console stream exists."""

    def write(self, _s):
        return 0

    def flush(self):
        return None


def _tqdm_output():
    """Return the stream tqdm should write to.

    Prefers ``sys.stderr``, then ``sys.stdout``; either may be ``None`` or lack
    ``write`` (for example in a windowed/frozen application), in which case a
    ``_NullWriter`` is returned instead.
    """
    for fp in (getattr(sys, "stderr", None), getattr(sys, "stdout", None)):
        if fp is not None and hasattr(fp, "write"):
            return fp
    return _NullWriter()


def _tqdm_disable(fp) -> bool:
    """Return True when the progress bar should be disabled for stream ``fp``.

    The bar is disabled for the null sink, for streams that are not a TTY, and
    for streams whose ``isatty()`` is missing or raises.
    """
    if isinstance(fp, _NullWriter):
        return True
    try:
        return not bool(fp.isatty())
    except Exception:
        return True


def _block_size_options(block_shape):
    """Translate a rasterio block shape into GTiff block-size creation options.

    rasterio reports block shapes as ``(rows, cols)`` — the same ordering as
    the dataset shape — whereas ``blockxsize`` is the tile width (columns) and
    ``blockysize`` the tile height (rows).
    """
    block_rows, block_cols = block_shape
    return {"blockxsize": block_cols, "blockysize": block_rows}


def mask_data_by_binary_mask(
    data_path,
    mask_path,
    output_path=None,
    remove_value=1,
    nodata_value=None,
    tile_size=4096,
):
    """Mask a data raster with a binary mask raster.

    Every pixel whose mask value (band 1 of ``mask_path``) equals
    ``remove_value`` is set to NoData in all bands of the output; all other
    pixels are copied unchanged. The rasters are processed window by window.

    Args:
        data_path: Path of the raster to be masked.
        mask_path: Path of the mask raster; it must have the same CRS, width
            and height as the data raster.
        output_path: Output path. Defaults to ``<data stem>_masked<suffix>``
            next to the data file.
        remove_value: Mask value that marks pixels to remove.
        nodata_value: NoData value to write. Defaults to the data raster's own
            NoData value, or 0 when it has none.
        tile_size: Window edge length in pixels. ``None`` or a value <= 0
            iterates over the source file's native block windows instead,
            which is usually faster for tiled GeoTIFFs.

    Returns:
        The output path as a string.

    Raises:
        ValueError: If the CRS or the pixel dimensions of the two rasters
            differ. A differing geotransform only logs a warning.
    """
    data_path = Path(data_path)
    mask_path = Path(mask_path)

    if output_path is None:
        output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
    else:
        output_path = Path(output_path)

    logger.info(f"数据文件: {data_path.name}")
    logger.info(f"掩膜文件: {mask_path.name}")
    logger.info(f"去除掩膜值: {remove_value}")

    with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
        with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
            if src_data.crs != src_mask.crs:
                raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
            if src_data.transform != src_mask.transform:
                logger.warning(
                    "数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
                )
            if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
                raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")

            # Pick the output NoData value and, where possible, cast it to the
            # data dtype so the per-pixel fill does not trigger implicit
            # conversions.
            if nodata_value is None:
                nodata_value = src_data.nodata if src_data.nodata is not None else 0
            try:
                nodata_value_cast = np.array(
                    nodata_value, dtype=src_data.dtypes[0]
                ).item()
            except Exception:
                nodata_value_cast = nodata_value

            # Output metadata: start from the source metadata, then set the
            # NoData value, compression and tiling layout.
            out_meta = src_data.meta.copy()
            out_meta.update(
                {
                    "nodata": nodata_value,
                    "compress": (
                        src_data.compression.value if src_data.compression else "lzw"
                    ),
                    "tiled": src_data.is_tiled,
                }
            )
            if src_data.is_tiled:
                # Preserve the source tile layout (width -> x, height -> y).
                out_meta.update(_block_size_options(src_data.block_shapes[0]))

            with rasterio.open(output_path, "w", **out_meta) as dst:
                width, height = src_data.width, src_data.height

                if tile_size is None or tile_size <= 0:
                    windows = [w for _, w in src_data.block_windows(1)]
                else:
                    stride = int(tile_size)
                    # Window(col_off, row_off, width, height); edge windows
                    # are shrunk to stay inside the raster.
                    windows = [
                        Window(i, j, min(stride, width - i), min(stride, height - j))
                        for i in range(0, width, stride)
                        for j in range(0, height, stride)
                    ]

                tqdm_fp = _tqdm_output()
                with tqdm(
                    total=len(windows),
                    desc="处理瓦片",
                    unit="块",
                    file=tqdm_fp,
                    disable=_tqdm_disable(tqdm_fp),
                ) as pbar:
                    for window in windows:
                        # The same window is read from both rasters, relying
                        # on the size check above for pixel correspondence.
                        mask = src_mask.read(1, window=window)
                        remove_mask = mask == remove_value

                        data = src_data.read(window=window)  # (bands, h, w)

                        if remove_mask.any():
                            for band_idx in range(data.shape[0]):
                                np.putmask(
                                    data[band_idx], remove_mask, nodata_value_cast
                                )

                        dst.write(data, window=window)
                        pbar.update(1)

    logger.info(f"处理完成,输出文件:{output_path}")
    return str(output_path)


def main():
    """Parse the command line and run ``mask_data_by_binary_mask``."""
    parser = argparse.ArgumentParser(
        description="使用二值掩膜 TIF(值为1的区域)对数据 TIF 进行掩膜,将对应位置设为 NoData。"
    )
    parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
    parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径(值为1表示需要去除的区域)")
    parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
    parser.add_argument("-r", "--remove_value", type=int, default=1,
                        help="掩膜中要去除的值,默认为1")
    parser.add_argument("-n", "--nodata", type=float,
                        help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)")
    parser.add_argument(
        "-t",
        "--tile_size",
        type=int,
        default=4096,
        help="分块大小(像素),默认4096;设为0则按源文件块窗口遍历(tiled 文件通常更快)",
    )

    args = parser.parse_args()
    mask_data_by_binary_mask(
        args.data_tif, args.mask_tif, args.output,
        args.remove_value, args.nodata, args.tile_size
    )


if __name__ == "__main__":
    # sys.exit is always available; the bare ``exit`` builtin is injected by
    # the ``site`` module and is missing in frozen or ``python -S`` runs.
    sys.exit(main())
环境依赖 + +### 3.1 必需依赖(运行 GUI/配准) +- Python 3.9+(建议与现有环境一致) +- `numpy` +- `opencv-python`(cv2) +- `rasterio` +- `affine` +- `vismatch`(本仓库内模块) + +### 3.2 可选依赖(对应功能可用/不可用) +- `scikit-image`:piecewise/polynomial 变换相关 +- `matplotlib` + `scipy`:凸包内点判定等 +- `SimpleITK` / `pirt` / `scipy.interpolate`:TPS 相关 +- `tqdm`:掩膜处理进度条(已做无控制台环境兼容) + +--- + +## 4. 输入数据要求 + +### 4.1 参考底图(GeoTIFF) +- 文件格式:`.tif` / `.tiff` +- 应包含正确 CRS 与 transform(地理参考信息) + +### 4.2 待配准航带(ENVI BIP) +- 文件格式:`.bip` +- 建议包含合理的波段与 NoData(若有) + +### 4.3 (可选)底图掩膜(GeoTIFF) +启用“底图掩膜”时: +- 掩膜必须是与底图 **严格对齐** 的 GeoTIFF(同 CRS、同 transform、同 width/height) +- `tif_caijain.py` 默认逻辑:掩膜中值等于 `1` 的像素会被置为 NoData(可在后续扩展为 GUI 可配) + +--- + +## 5. GUI 使用方法(推荐) + +### 5.1 启动 GUI +在 Windows PowerShell 或 CMD 中运行: + +```powershell +cd "e:\code\vismatch-main\vismatch-main\test" +python "StripStitch.py" +``` + +### 5.2 GUI 字段说明 +- 参考TIF文件:要配准到的底图 +- 启用底图掩膜:勾选后需要选择掩膜 TIF,并会先生成“掩膜后的底图” +- BIP文件夹:包含待配准 `.bip` 的文件夹(程序会遍历 `*.bip`) +- 输出文件夹:所有输出写入此目录 +- 匹配算法:选择 matcher(例如 `matchanything-roma`) +- 设备:`cuda`(更快)或 `cpu` +- 变换方法:可多选,按优先级尝试(可上移/下移顺序) +- 参数设置:匹配缩放、ROI 扩展、质量阈值、边缘/纹理过滤等 + +### 5.3 开始/停止 +- 开始处理:开始批量配准 +- 停止处理:设置停止标记,当前文件处理结束后停止 + +--- + +## 6. 命令行模式(CLI) + +脚本保留了 CLI 模式入口(若实现支持 `--cli`): + +```powershell +python "StripStitch.py" --cli +``` + +说明: +- CLI 模式下通常使用脚本顶部的默认配置(REF_TIF / BIP_DIR / OUT_DIR 等)。 +- 若需要完整参数化 CLI,可后续再扩展 argparse。 + +--- + +## 7. 输出说明 + +假设输出目录为 `OUT_DIR`: + +### 7.1 统计 CSV +每次运行会生成一个新 CSV: + +- `OUT_DIR\stats\registration_stats_YYYYMMDD_HHMMSS.csv` + +内容包含(示例字段): +- 时间戳、文件名、匹配点/内点数、内点比例、所选方法、误差统计、成功与否等 + +### 7.2 掩膜后的底图(如果启用底图掩膜) +- `OUT_DIR\masked_refs\<底图名>_masked_<时间戳>.tif` + +配准时会使用这个掩膜后的底图作为参考。 + +### 7.3 配准输出 +- 脚本会将每个 `.bip` 配准后输出到 `OUT_DIR`(具体命名以代码为准,例如 `_registered.bip`)。 + +--- + +## 8. 
常见问题与排查 + +### 8.1 处理过程中发生错误: No module named 'src' +原因: +- `matchanything-*` 依赖第三方源码目录(`vismatch/third_party/MatchAnything/.../src`),打包或运行时没有被正确加入 `sys.path`。 + +解决: +- 确保 PyInstaller spec 中把 `vismatch/third_party` 打进包里(Tree 方式) +- 运行时脚本会尝试从 `_MEIPASS` 和 exe 目录自动寻找 third_party 并加入 `sys.path` +- 如果仍失败,检查 dist 目录下是否存在: + - `dist\StripStitch\_internal\vismatch\third_party\MatchAnything\imcui\third_party\MatchAnything\src\` + +### 8.2 No module named 'loguru' +原因: +- MatchAnything 的第三方代码中引用了 `loguru` + +解决: +- 打包时将 `loguru` 加入 hiddenimports(或安装 loguru) +- 脚本也提供了缺失时的兼容 stub(以避免直接崩溃) + +### 8.3 底图掩膜时报错:'NoneType' object has no attribute 'write' +原因: +- PyInstaller `console=False` 时 tqdm 可能没有可写的输出流 + +解决: +- `tif_caijain.py` 已处理:无控制台环境会自动禁用 tqdm 或使用安全输出,不应再崩溃 +- 重新打包并运行最新代码 + +--- + +## 9. 打包(PyInstaller,文件夹模式 onedir) + +### 9.1 使用 spec 打包 +在 `test` 目录执行: + +```powershell +pyinstaller --clean -y "e:\code\vismatch-main\vismatch-main\test\StripStitch.spec" +``` + +输出: +- `dist\StripStitch\StripStitch.exe`(以及同目录依赖文件) + +### 9.2 打包注意事项 +- 若你改动了 `.spec` 或 `.py`,务必 `--clean` 重新打包 +- 深度学习 matcher 依赖多,出现 “No module named xxx” 时通常需要: + - spec 增加 hiddenimports + - 或把第三方源码目录作为 datas 打进去 + - 或做运行时兼容 stub(仅当该依赖不影响核心逻辑/可替代) + +--- + +## 10. 建议的使用流程(从零到一) + +1) 准备底图 `result.tif` 与待配准 `.bip` 文件夹 +2) (可选)准备与底图对齐的掩膜 `result_mask.tif`(值=1 为去除区域) +3) 启动 GUI,选择底图、BIP 文件夹、输出目录 +4) 选择 matcher(建议从 `matchanything-roma` 或你已验证可用的 matcher 开始) +5) 选择变换方法(建议先 `affine + homography`) +6) 点“开始处理”,观察日志与进度 +7) 处理结束后查看: + - 输出 BIP + - `stats\registration_stats_*.csv` + - (可选)`masked_refs\*_masked_*.tif` + +--- \ No newline at end of file