From 5e0984bf9c2ff26ec3351035fb5fcefb3b38f191 Mon Sep 17 00:00:00 2001 From: zhanghuilai Date: Fri, 6 Mar 2026 17:24:55 +0800 Subject: [PATCH] first commit --- .idea/.gitignore | 8 + .idea/inspectionProfiles/Project_Default.xml | 44 + .../inspectionProfiles/profiles_settings.xml | 6 + .idea/misc.xml | 7 + .idea/test.iml | 7 + README.md | 176 ++ demo.py | 21 + mask_water.py | 131 ++ test V2.py | 359 ++++ test V3.py | 970 +++++++++++ test V4.py | 467 +++++ test V5.1.py | 1085 ++++++++++++ test V5.py | 1058 ++++++++++++ test V6.py | 1509 ++++++++++++++++ test V7.py | 1534 +++++++++++++++++ test V8.py | 1207 +++++++++++++ test V9.py | 1299 ++++++++++++++ test.py | 290 ++++ 18 files changed, 10178 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/test.iml create mode 100644 README.md create mode 100644 demo.py create mode 100644 mask_water.py create mode 100644 test V2.py create mode 100644 test V3.py create mode 100644 test V4.py create mode 100644 test V5.1.py create mode 100644 test V5.py create mode 100644 test V6.py create mode 100644 test V7.py create mode 100644 test V8.py create mode 100644 test V9.py create mode 100644 test.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..f649f0f --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..419e41a --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,44 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..ce9d49a --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/test.iml b/.idea/test.iml new file mode 100644 index 0000000..ec63674 --- /dev/null +++ b/.idea/test.iml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..353a2a1 --- /dev/null +++ b/README.md @@ -0,0 +1,176 @@ +# 图像配准工具集 (Image Registration Toolkit) + +这是一个基于 Python 的遥感图像配准工具集,主要用于批量配准 .bip 格式的影像文件到参考 .tif 文件。支持多种变换模型和质量控制机制,特别适用于森林、水体等复杂地物场景。 + +## 功能特性 + +### 核心功能 +- **批量配准**: 支持批量处理多个 .bip 文件到参考 .tif 文件 +- **多种变换模型**: 支持相似变换、仿射变换、单应变换等 +- **智能掩膜**: 使用有效区域掩膜和水体分割提高配准质量 +- **质量控制**: 基于内点数、内点比例和重投影误差的质量评估 +- **可视化输出**: 自动生成匹配结果和关键点可视化图像 + +### 高级特性 +- **RANSAC 参数控制**: 可配置的重投影误差阈值、置信度和最大迭代次数 +- **纹理过滤**: 基于梯度幅值的纹理质量筛选 +- **几何约束**: 距离边界限制和空间分布均匀性检查 +- **多尺度处理**: 支持不同分辨率的匹配和变换 + +## 文件说明 + +### 核心脚本 +- `test V8.py`: 最新版本的主配准脚本,支持有效区域掩膜 +- `test V9.py`: 增强版本,包含软掩膜和纹理过滤 +- `demo.py`: 基础演示脚本,展示匹配和可视化功能 + +### 工具脚本 +- `mask_water.py`: 基于 ROI 文件掩膜 TIF 图像的工具 +- `test V5.py` - `test V7.py`: 早期版本的配准脚本 + +## 环境要求 + +### 依赖包 +``` +numpy +opencv-python +rasterio +geopandas +shapely +scikit-image +matplotlib +scipy +torch (用于 vismatch) +vismatch +``` + +### 安装 +```bash +pip install numpy opencv-python rasterio geopandas shapely scikit-image matplotlib scipy torch +# 安装 vismatch (假设已配置) +``` + +## 使用方法 + +### 基本配准 +```bash +python test_V8.py # 运行主配准脚本 +``` + +### 参数配置 +在脚本开头修改以下关键参数: + +```python +# 文件路径 +REF_TIF = r"E:\path\to\reference.tif" # 参考影像路径 +BIP_DIR = Path(r"E:\path\to\bip\directory") # BIP文件目录 +OUT_DIR = Path(r"E:\path\to\output") # 输出目录 + +# 匹配参数 +MATCHER_NAME = "matchanything-roma" # 匹配器选择 +MATCH_MAX_SIDE = 1200 # 匹配最大边长 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法优先级 +TRANSFORM_METHODS = ["similarity", "affine", "homography"] + +# RANSAC参数 +RANSAC_REPROJ_THRESHOLD = 3.0 # 重投影误差阈值 +RANSAC_CONFIDENCE = 0.99 # 置信度 +RANSAC_MAX_ITERS = 1000 # 最大迭代次数 + +# 质量控制 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 +``` + +### ROI掩膜工具 +```bash +python mask_water.py input.tif roi.shp -o output_masked.tif +``` + +## 算法流程 + +1. **粗定位**: 使用源影像有效区域确定参考影像的感兴趣区域 (ROI) +2. **数据读取**: 读取源影像和参考影像的对应区域 +3. **掩膜处理**: 应用有效区域掩膜和水体分割 +4. **特征匹配**: 使用深度学习匹配器提取和匹配特征点 +5. **几何变换**: 尝试多种变换模型,选择最优的几何变换 +6. **质量评估**: 基于重投影误差评估变换质量 +7. **输出生成**: 生成配准后的影像和可视化结果 + +## 关键参数说明 + +### 匹配参数 +- `MATCH_MAX_SIDE`: 控制匹配时的图像尺寸,越小速度越快但细节损失越多 +- `ROI_PAD_PX`: ROI扩展像素,影响匹配区域范围 + +### RANSAC参数 +- `RANSAC_REPROJ_THRESHOLD`: 内点判断的误差阈值,越小内点质量越高但数量越少 +- `RANSAC_CONFIDENCE`: 算法置信度,影响迭代次数 +- `RANSAC_MAX_ITERS`: 最大迭代次数上限 + +### 质量控制 +- `MIN_INLIERS`: 最少内点数量要求 +- `MIN_INLIER_RATIO`: 内点比例要求 + +## 场景优化建议 + +### 森林场景 +```python +MATCH_MAX_SIDE = 1600 +MIN_INLIERS = 25 +MIN_INLIER_RATIO = 0.025 +# 启用纹理过滤和边界距离限制 +``` + +### 水体场景 +```python +# 使用水体掩膜预处理 +MASK_PAD_PX = 100 +# 调大RANSAC阈值容忍岸线变化 +RANSAC_REPROJ_THRESHOLD = 4.0 +``` + +### 城市/结构化场景 +```python +MATCH_MAX_SIDE = 1000 # 结构特征明显,不需要很高分辨率 +RANSAC_REPROJ_THRESHOLD = 2.0 # 严格质量控制 +``` + +## 输出结果 + +### 文件输出 +- `*_registered.bip`: 配准后的影像文件 +- `visualizations/*_matches.png`: 匹配结果可视化 +- `visualizations/*_keypoints_src.png`: 源影像关键点 +- `visualizations/*_keypoints_ref.png`: 参考影像关键点 + +### 统计输出 +- `stats/registration_stats.csv`: 包含每对影像的配准统计信息 + +## 故障排除 + +### 常见问题 +1. **内点数过少**: 检查影像质量,调整 RANSAC 参数,或扩大匹配区域 +2. **内存不足**: 降低 `MATCH_MAX_SIDE` 或使用 CPU 模式 +3. **配准失败**: 检查坐标系统一致性,调整质量阈值 + +### 调试建议 +- 查看可视化输出了解匹配情况 +- 检查统计 CSV 文件的各项指标 +- 调整参数时从小范围测试开始 + +## 版本历史 + +- **test V8.py**: 支持有效区域掩膜和 RANSAC 参数控制 +- **test V9.py**: 增加软掩膜和纹理过滤 +- **早期版本**: 基础配准功能实现 + +## 许可证 + +本项目遵循 MIT 许可证。 + +## 贡献 + +欢迎提交 Issue 和 Pull Request 来改进这个工具集。 \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..ea58bad --- /dev/null +++ b/demo.py @@ -0,0 +1,21 @@ + +from vismatch import get_matcher +from vismatch.viz import plot_matches,plot_keypoints + +# Choose any of the 50+ matchers listed below +matcher = get_matcher("matchanything-roma", device="cuda") +img_size = 512 # optional + +img0 = matcher.load_image(r"D:\汇报\水库项目资料\gif\屏幕截图 2026-02-28 140816.png", resize=img_size) +img1 = matcher.load_image(r"D:\汇报\水库项目资料\gif\屏幕截图 2026-02-28 140525.png", resize=img_size) + +result = matcher(img0, img1) +# result.keys() = ["num_inliers", "H", "all_kpts0", "all_kpts1", "all_desc0", "all_desc1", "matched_kpts0", "matched_kpts1", "inlier_kpts0", "inlier_kpts1"] + +# This will plot visualizations for matches as shown in the figures above +plot_matches(img0, img1, result, save_path="plot_matches.png") + +# Or you can extract and visualize keypoints as easily as +result = matcher.extract(img0) +# result.keys() = ["all_kpts0", "all_desc0"] +plot_keypoints(img0, result, save_path="plot_kpts.png") \ No newline at end of file diff --git a/mask_water.py b/mask_water.py new file mode 100644 index 0000000..ce6e52b --- /dev/null +++ b/mask_water.py @@ -0,0 +1,131 @@ +""" +掩膜tif文件的ROI区域并保存 + +输入: tif文件路径, roi文件路径(shp/geojson等) +输出: 掩膜后的tif文件 +""" + +import argparse +import numpy as np +import rasterio +from rasterio.mask import mask +from rasterio.features import shapes +import geopandas as gpd +# from shapely.geometry import shape +import logging +from pathlib import Path + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def mask_tif_by_roi(tif_path: str, roi_path: str, output_path: str = None, + nodata_value: float = None): + """ + 使用ROI文件掩膜TIF文件 + + Args: + tif_path: 输入TIF文件路径 + roi_path: ROI文件路径(shp/geojson等) + output_path: 输出文件路径,如果为None则自动生成 + nodata_value: nodata值,如果为None则使用原始文件的nodata值 + + Returns: + str: 输出文件路径 + """ + + tif_path = Path(tif_path) + roi_path = Path(roi_path) + + # 检查输入文件存在性 + if not tif_path.exists(): + raise FileNotFoundError(f"TIF文件不存在: {tif_path}") + if not roi_path.exists(): + raise FileNotFoundError(f"ROI文件不存在: {roi_path}") + + # 自动生成输出路径 + if output_path is None: + output_path = tif_path.parent / f"{tif_path.stem}_masked{tif_path.suffix}" + else: + output_path = Path(output_path) + + logger.info(f"开始处理: {tif_path.name}") + logger.info(f"ROI文件: {roi_path.name}") + + try: + # 读取ROI文件 + logger.info("读取ROI文件...") + gdf = gpd.read_file(roi_path) + + if gdf.empty: + logger.warning("ROI文件为空,直接复制原文件") + # 直接复制原文件 + import shutil + shutil.copy2(tif_path, output_path) + return str(output_path) + + # 转换为GeoJSON格式的几何对象 + geometries = gdf.geometry.tolist() + logger.info(f"找到 {len(geometries)} 个ROI几何对象") + + # 读取并掩膜TIF + logger.info("读取并掩膜TIF文件...") + with rasterio.open(tif_path) as src: + # 使用ROI掩膜 - 保留ROI以外的区域(ROI区域设为nodata) + masked_data, masked_transform = mask( + src, + geometries, + crop=False, # 不裁剪,保持原始尺寸 + invert=True, # 反转:ROI区域设为nodata,ROI以外保留 + nodata=nodata_value if nodata_value is not None else src.nodata + ) + + # 更新元数据 + out_meta = src.meta.copy() + out_meta.update({ + "driver": "GTiff", + "height": masked_data.shape[1], + "width": masked_data.shape[2], + "transform": masked_transform, + "nodata": nodata_value if nodata_value is not None else src.nodata + }) + + # 保存结果 + logger.info(f"保存掩膜结果到: {output_path.name}") + with rasterio.open(output_path, "w", **out_meta) as dest: + dest.write(masked_data) + + logger.info("处理完成!") + return str(output_path) + + except Exception as e: + logger.error(f"处理失败: {str(e)}") + raise + + +def main(): + parser = argparse.ArgumentParser(description="使用ROI文件掩膜TIF文件") + parser.add_argument("tif_path", help="输入TIF文件路径") + parser.add_argument("roi_path", help="ROI文件路径 (shp/geojson等)") + parser.add_argument("-o", "--output", help="输出文件路径 (可选,默认自动生成)") + parser.add_argument("-n", "--nodata", type=float, help="nodata值 (可选,默认使用原文件nodata)") + + args = parser.parse_args() + + try: + output_path = mask_tif_by_roi( + args.tif_path, + args.roi_path, + args.output, + args.nodata + ) + print(f"成功完成!输出文件: {output_path}") + except Exception as e: + print(f"错误: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/test V2.py b/test V2.py new file mode 100644 index 0000000..4cf1858 --- /dev/null +++ b/test V2.py @@ -0,0 +1,359 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +仿射变换修改 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +from osgeo import gdal + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\jiashixian\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 30 # 最少内点数 +MIN_INLIER_RATIO = 0.15 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计仿射变换(相似/全仿射二选一) + k0 = result["inlier_kpts0"].astype(np.float32) + k1 = result["inlier_kpts1"].astype(np.float32) + + # 相似(旋转+等比缩放+平移,4DOF) + A_partial, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + # 全仿射(旋转+非等比缩放+剪切+平移,6DOF) + A_full, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + + def _reproj_err(A2x3: np.ndarray, pts0: np.ndarray, pts1: np.ndarray): + if A2x3 is None: + return np.inf, np.inf + ones = np.ones((pts0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([pts0, ones]) # (N,3) + pred = (A2x3 @ src_h.T).T # (N,2) + e = np.sqrt(((pred - pts1) ** 2).sum(axis=1)) + return float(np.median(e)), float(np.percentile(e, 95)) + + med_p, p95_p = _reproj_err(A_partial, k0, k1) + med_f, p95_f = _reproj_err(A_full, k0, k1) + + if (p95_f < p95_p) or (abs(p95_f - p95_p) < 0.5 and med_f < med_p): + A = A_full + model_type = "affine_full" + else: + A = A_partial + model_type = "affine_partial" + + if A is None: + logger.warning(f"仿射估计失败: {bip_path.name}") + return False + + logger.info(f"选用模型: {model_type}, partial(p50={med_p:.2f},p95={p95_p:.2f}), full(p50={med_f:.2f},p95={p95_f:.2f})") + + # 6) 基于内点构建 GCP,并使用 GDAL TPS(薄板样条)实现非刚性(B样条近似效果)配准 + # 注:GDAL 提供 TPS(薄板样条),在遥感配准中与B样条同属光滑非刚性方法,工程上更常用 + # 这里我们直接用 TPS 来实现你所需的“B样条变换”效果 + # 将 small 尺寸的内点坐标映射回 full 分辨率 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0_inv = np.array([[s0x, 0, 0], [0, s0y, 0], [0, 0, 1]], dtype=np.float64) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) # small -> full (ref ROI) + + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0, ones]).T).T[:, :2] # 源:full 源像素 + k1_roi_full = (S1_inv @ np.hstack([k1, ones]).T).T[:, :2] + # 加上 ROI 偏移,得到参考影像全局像素坐标 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) + + if k0_full.shape[0] < 8: + logger.warning(f"TPS 需要更多控制点(>=8),当前 {k0_full.shape[0]},回退到仿射输出。") + else: + # 生成 GCP:目标为参考地图坐标(由全局像素 + ref_transform 转换) + ref_transform = ref_dataset.transform + def px_to_map(xp, yp): + X = ref_transform.a * xp + ref_transform.b * yp + ref_transform.c + Y = ref_transform.d * xp + ref_transform.e * yp + ref_transform.f + return X, Y + + gcps = [] + for (sx, sy), (rx, ry) in zip(k0_full, k1_global): + mx, my = px_to_map(rx, ry) + gcps.append(gdal.GCP(mx, my, 0.0, float(sx), float(sy))) + + # 计算参考影像的输出边界与分辨率 + minx = ref_transform.c + maxy = ref_transform.f + maxx = minx + ref_transform.a * ref_dataset.width + ref_transform.b * ref_dataset.height + miny = maxy + ref_transform.d * ref_dataset.width + ref_transform.e * ref_dataset.height + + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 输出路径(ENVI+BIP) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + + # 使用 GDAL Warp 执行 TPS 变换到参考网格 + # 注意:需要系统已安装 GDAL Python 绑定 + try: + warp_opts = gdal.WarpOptions( + format="ENVI", + outputBounds=(minx, miny, maxx, maxy), + xRes=abs(ref_transform.a), + yRes=abs(ref_transform.e), + dstSRS=ref_crs.to_wkt() if hasattr(ref_crs, "to_wkt") else str(ref_crs), + tps=True, # 启用薄板样条(非刚性) + resampleAlg="bilinear", + srcNodata=src_nodata, + dstNodata=dst_nodata, + multithread=True, + creationOptions=["INTERLEAVE=BIP"], + gcps=gcps, + ) + + # 直接从源文件路径进行Warp + warp_ds = gdal.Warp( + destNameOrDestDS=str(out_path), + srcDSOrSrcDSTab=str(bip_path), + options=warp_opts, + ) + if warp_ds is None: + logger.warning("GDAL Warp(TPS) 失败,回退到仿射输出。") + else: + warp_ds = None # 关闭文件 + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + except Exception as e: + logger.warning(f"GDAL Warp(TPS) 异常,回退到仿射输出: {e}") + + # ---- 回退:使用仿射(你前面的流程),保证最小可用结果 ---- + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + # 仍沿用上面已估计的 A(partial/full 二选一) + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=ref_dataset.height, + width=ref_dataset.width, + count=src.count, + transform=ref_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=ref_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V3.py b/test V3.py new file mode 100644 index 0000000..2bf8f69 --- /dev/null +++ b/test V3.py @@ -0,0 +1,970 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 SimpleITK 实现 B 样条变换 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1500 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + elif method == "tps": + # 薄板样条变换(如果SimpleITK可用) + SITK_TPS = SITK_AVAILABLE and hasattr(sitk, "ThinPlateSplineKernelTransform") + if not SITK_TPS: + return None, None + try: + tps_transform = sitk.ThinPlateSplineKernelTransform() + tps_transform.SetKernelTypeToThinPlateSpline() + + fixed_landmarks = sitk.vectorDPoint() + moving_landmarks = sitk.vectorDPoint() + + for (rx, ry), (sx, sy) in zip(k1, k0): + fixed_landmarks.push_back([float(rx), float(ry)]) + moving_landmarks.push_back([float(sx), float(sy)]) + + tps_transform.SetFixedLandmarks(fixed_landmarks) + tps_transform.SetMovingLandmarks(moving_landmarks) + return "tps", tps_transform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "tps": + # TPS 变换重投影误差(SimpleITK) + pred = [] + for pt in k0: + transformed_pt = transform.TransformPoint([float(pt[0]), float(pt[1])]) + pred.append([transformed_pt[0], transformed_pt[1]]) + pred = np.array(pred) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + return False + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type == "polynomial": + # 对于多项式,估计逆变换 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=2) # 顺序:目标->源 + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + xy_src = t_inv(xy) # -> (x_src, y_src) in full-src + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + else: # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type == "tps": + # B样条变换:优先使用 PIRT,如果不可用则使用 SimpleITK TPS + try: + if PIRT_AVAILABLE: + # 使用 PIRT 实现 B样条弹性变换 + logger.info("使用 PIRT B样条变换") + + # 读取用于配准的单波段并归一化 + ref_roi_data = ref_dataset.read(1, window=win).astype(np.float32) + src_band_data = src.read(1).astype(np.float32) + + from skimage.exposure import rescale_intensity + ref_roi_reg = rescale_intensity(ref_roi_data, in_range='image', out_range=(0.0, 1.0)) + src_reg = rescale_intensity(src_band_data, in_range='image', out_range=(0.0, 1.0)) + + # 构建 PIRT 注册器 + reg = pirt.Registration(fixed=ref_roi_reg, moving=src_reg) + + # 设置 B样条变换 + bspline = pirt.transform.BSplineTransform(grid_spacing=(96, 96)) # 可调节控制点间距 + reg.set_transformation(bspline) + + # 设置相似度度量(NCC 或 MI) + reg.set_similarity(pirt.metrics.NCC()) + + # 多分辨率金字塔 + reg.set_pyramid([4, 2, 1]) + + # 优化器设置 + reg.set_optimizer(pirt.optimizers.LBFGS(max_iter=200), smooth=1.0) + + # 执行注册 + reg.run() + + # 获取前向映射(参考网格到源的位移场) + phi = reg.get_forward_mapping() # (H, W, 2) 位移场 + + # 应用到所有波段 + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=ref_roi_data.shape[0], + width=ref_roi_data.shape[1], + count=src.count, + transform=ref_dataset.window_transform(win), + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + # 使用 PIRT 的 warp 函数应用位移场 + warped = pirt.warp(src_band, phi, mode='constant', cval=float(dst_nodata)) + band_data = warped.astype(out_profile["dtype"]) + out_ds.write(band_data, b) + + logger.info(f"成功配准(B样条-PIRT): {bip_path.name} -> {out_path.name}") + return True + + elif SITK_AVAILABLE and hasattr(sitk, "LandmarkBasedTransformInitializer"): + # 回退到 SimpleITK TPS + logger.info("PIRT 不可用,使用 SimpleITK TPS") + + # 1) 统一坐标系:用"全图像素"作为物理坐标(spacing=1, origin=(0,0), direction=I) + fixed_pts = [(float(x1), float(y1)) for (x1,y1) in k1_global] # 参考(输出)侧 + moving_pts = [(float(x0), float(y0)) for (x0,y0) in k0_full] # 源(输入)侧 + + # 2) 构造 TPS(参考→源) 用于 Resample(输出点 -> 输入点) + tps_ref2src = sitk.LandmarkBasedTransformInitializer( + sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform), + fixed_pts, # fixed = 参考 + moving_pts # moving = 源 + ) + + # 3) 构造 TPS(源→参考) 仅用于外接矩形估计(源顶点投到参考) + tps_src2ref = sitk.LandmarkBasedTransformInitializer( + sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform), + moving_pts, # fixed = 源 + fixed_pts # moving = 参考 + ) + + # 4) 用 tps_src2ref 变换源四角,求参考全图上的外接矩形,并与参考范围求交 + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + dst_corners = np.array([tps_src2ref.TransformPoint((float(x),float(y))) for x,y in src_corners], dtype=np.float32) + + min_x = max(0, int(np.floor(dst_corners[:,0].min())) - 10) + max_x = min(ref_dataset.width, int(np.ceil (dst_corners[:,0].max())) + 10) + min_y = max(0, int(np.floor(dst_corners[:,1].min())) - 10) + max_y = min(ref_dataset.height, int(np.ceil (dst_corners[:,1].max())) + 10) + bbox_w, bbox_h = max_x-min_x, max_y-min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"TPS最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 5) 参考(输出)影像定义:spacing=1,origin=(min_x,min_y),direction=I + ref_img = sitk.Image([bbox_w, bbox_h], sitk.sitkFloat32) + ref_img.SetSpacing((1.0, 1.0)) + ref_img.SetOrigin((float(min_x), float(min_y))) + ref_img.SetDirection((1.0,0.0,0.0,1.0)) + + # 6) 源影像:用 rasterio 读为 numpy,再转 SITK(spacing=1, origin=0, direction=I) + src_band = src.read(1).astype(np.float32) + src_img = sitk.GetImageFromArray(src_band) + + # 7) 重采样:设置参考图像 + 变换=参考→源(tps_ref2src) + res = sitk.ResampleImageFilter() + res.SetReferenceImage(ref_img) + res.SetTransform(tps_ref2src) + res.SetInterpolator(sitk.sitkLinear) + if src.nodata is not None: + res.SetDefaultPixelValue(float(src.nodata)) + + # 8) 写 ENVI/BIP:对所有波段逐一 TPS 重采样 + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + src_img_band = sitk.GetImageFromArray(src_band) + warped = res.Execute(src_img_band) + band_data = sitk.GetArrayFromImage(warped).astype(out_profile["dtype"]) + out_ds.write(band_data, b) + + logger.info(f"成功配准(TPS-SimpleITK): {bip_path.name} -> {out_path.name}") + return True + + else: + logger.warning("PIRT 和 SimpleITK TPS 都不可用,回退到仿射") + # 继续到仿射回退 + + except Exception as e: + logger.warning(f"B样条变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V4.py b/test V4.py new file mode 100644 index 0000000..b5ac605 --- /dev/null +++ b/test V4.py @@ -0,0 +1,467 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +直接进行配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 使用密集匹配模型的稠密流直接进行配准 + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # ==== 稠密流直接重采样(无需后续显式变换估计) ==== + + # 1) 取稠密流(优先 ref->src)。不同模型的键名可能不同,这里做兼容 + flow_small = None + for k in ["flow_ref2src", "flow21", "flow_1_0", "flow10", "flow"]: + if k in result: + flow_small = result[k] + break + + if flow_small is None: + # 回退:优先 DIS 光流(更快/稳),若不可用再用 Farneback (ref -> src) + ref_small_rgb = np.transpose(ref_small, (1, 2, 0)) # (H,W,3) + src_small_rgb = np.transpose(src_small, (1, 2, 0)) + ref_small_gray = cv2.cvtColor((ref_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + src_small_gray = cv2.cvtColor((src_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + + flow_small = None + try: + dis = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM) + flow_small = dis.calc(ref_small_gray, src_small_gray, None).astype(np.float32) + except Exception: + pass + + if flow_small is None: + # 典型参数:可按影像特性微调(窗口、迭代次数等) + flow_small = cv2.calcOpticalFlowFarneback( + ref_small_gray, src_small_gray, + None, 0.5, 3, 25, 3, 5, 1.2, 0 + ).astype(np.float32) + + # flow_small 期望形状 (h_s, w_s, 2),分量为 (dx, dy): 参考像素到源像素的位移 + flow_small = np.asarray(flow_small, dtype=np.float32) + if flow_small.ndim != 3 or flow_small.shape[2] != 2: + logger.warning(f"稠密流形状异常: {flow_small.shape}") + return False + + # 2) 将小图的流放大到 ROI 全分辨率,并按比例放大位移 + roi_h, roi_w = ref_img.shape[1], ref_img.shape[2] # 注意 ref_img 是 ROI 子图 + scale_x = roi_w / flow_small.shape[1] + scale_y = roi_h / flow_small.shape[0] + flow_full = cv2.resize(flow_small, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR) + flow_full[..., 0] *= scale_x # dx + flow_full[..., 1] *= scale_y # dy + + # 3) 生成 remap 所需的源坐标图(map_x, map_y),在"参考ROI坐标系"内工作 + yy, xx = np.meshgrid(np.arange(roi_h, dtype=np.float32), + np.arange(roi_w, dtype=np.float32), indexing="ij") + map_x = xx + flow_full[..., 0] # 到源图的 x(列) + map_y = yy + flow_full[..., 1] # 到源图的 y(行) + + # 4) 根据有效映射范围求最小外接矩形(仅统计落在源图范围内的像素) + valid = (map_x >= 0) & (map_x <= (src.width - 1)) & (map_y >= 0) & (map_y <= (src.height - 1)) + if not np.any(valid): + logger.warning(f"稠密流无有效映射: {bip_path.name}") + return False + + ys, xs = np.where(valid) + pad = 0 + min_y = max(int(ys.min()) - pad, 0) + max_y = min(int(ys.max()) + 1 + pad, roi_h) + min_x = max(int(xs.min()) - pad, 0) + max_x = min(int(xs.max()) + 1 + pad, roi_w) + + crop_h = max_y - min_y + crop_w = max_x - min_x + if crop_h <= 0 or crop_w <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 只对外接矩形区域做重采样,减少内存 + map_x_crop = map_x[min_y:max_y, min_x:max_x].astype(np.float32) + map_y_crop = map_y[min_y:max_y, min_x:max_x].astype(np.float32) + + # 5) 计算输出的地理变换:参考ROI窗口 + 外接矩形子窗口 + # 先得到 ROI 的 transform,再叠加子窗口偏移 + roi_transform = ref_dataset.window_transform(win) + crop_window_global = rasterio.windows.Window( + win.col_off + min_x, win.row_off + min_y, crop_w, crop_h + ) + out_transform = ref_dataset.window_transform(crop_window_global) + + # 6) 写出 ENVI/BIP(按最小外接矩形) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=crop_h, + width=crop_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + # 反向映射采样:输出像素在参考ROI坐标,去源图(map_y,map_x)取值 + warped = cv2.remap( + src_band, map_x_crop, map_y_crop, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=float(dst_nodata) + ).astype(np.float32) + + # 转回目标 dtype,保持 nodata + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (warped == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + warped = np.clip(warped, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + warped[mask] = dst_nodata + else: + warped = warped.astype(out_profile["dtype"]) + + out_ds.write(warped, b) + + logger.info(f"成功配准(DenseFlow): {bip_path.name} -> {out_path.name}") + return True + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V5.1.py b/test V5.1.py new file mode 100644 index 0000000..480ee18 --- /dev/null +++ b/test V5.1.py @@ -0,0 +1,1085 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V5.py b/test V5.py new file mode 100644 index 0000000..e0448ed --- /dev/null +++ b/test V5.py @@ -0,0 +1,1058 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V6.py b/test V6.py new file mode 100644 index 0000000..fc94794 --- /dev/null +++ b/test V6.py @@ -0,0 +1,1509 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +import threading +import queue +from dataclasses import dataclass +import tkinter as tk +from tkinter import ttk, filedialog, messagebox +import sys +import os + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + +@dataclass +class Config: + """配置参数类""" + ref_tif: str + bip_dir: str + out_dir: str + matcher_name: str + device: str + transform_methods: list + match_max_side: int + roi_pad_px: int + min_inliers: int + min_inlier_ratio: float + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 默认配置,请根据实际情况修改这些路径 +DEFAULT_REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +DEFAULT_BIP_DIR = r"E:\is2\yaopu" # .bip 文件所在文件夹 +DEFAULT_OUT_DIR = r"E:\is2\yaopu\output" # 输出文件夹 + +# 默认匹配算法选择 +DEFAULT_MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEFAULT_DEVICE = "cuda" # 或 "cpu" + +# 默认变换方法选择(按优先级尝试) +DEFAULT_TRANSFORM_METHODS = ["homography", "affine", "piecewise_affine"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial" + +# 默认匹配参数 +DEFAULT_MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +DEFAULT_ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 默认质量控制阈值 +DEFAULT_MIN_INLIERS = 10 # 最少内点数 +DEFAULT_MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 向后兼容的全局变量(用于命令行模式) +REF_TIF = DEFAULT_REF_TIF +BIP_DIR = Path(DEFAULT_BIP_DIR) +OUT_DIR = Path(DEFAULT_OUT_DIR) +MATCHER_NAME = DEFAULT_MATCHER_NAME +DEVICE = DEFAULT_DEVICE +TRANSFORM_METHODS = DEFAULT_TRANSFORM_METHODS +MATCH_MAX_SIDE = DEFAULT_MATCH_MAX_SIDE +ROI_PAD_PX = DEFAULT_ROI_PAD_PX +MIN_INLIERS = DEFAULT_MIN_INLIERS +MIN_INLIER_RATIO = DEFAULT_MIN_INLIER_RATIO + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def run_batch(config: Config, on_progress=None, on_log=None, stop_event=None): + """批量配准处理函数 + + Args: + config: 配置参数 + on_progress: 进度回调函数 (current_idx, total, filename) + on_log: 日志回调函数 (message) + stop_event: 停止事件,用于取消处理 + """ + def log(message): + if on_log: + on_log(message) + logger.info(message) + + log("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(config.ref_tif).exists(): + log(f"参考文件不存在: {config.ref_tif}") + return + + if not Path(config.bip_dir).exists(): + log(f"BIP文件夹不存在: {config.bip_dir}") + return + + # 创建输出目录 + out_dir = Path(config.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # 创建统计输出目录和文件 + stats_dir = out_dir / "stats" + stats_dir.mkdir(parents=True, exist_ok=True) + stats_csv = stats_dir / "registration_stats.csv" + + # 初始化统计CSV文件 + init_stats_csv(stats_csv) + log(f"统计信息将保存到: {stats_csv}") + + # 初始化匹配器 + log(f"初始化匹配器: {config.matcher_name} on {config.device}") + matcher = get_matcher(config.matcher_name, device=config.device) + + # 打开参考文件 + with rasterio.open(config.ref_tif) as ref: + log(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_dir = Path(config.bip_dir) + bip_files = list(bip_dir.glob("*.bip")) + log(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for i, bip_path in enumerate(bip_files): + if stop_event and stop_event.is_set(): + log("处理被用户取消") + break + + if on_progress: + on_progress(i, len(bip_files), bip_path.name) + + if process_bip_to_tif(bip_path, ref, matcher, out_dir, stats_csv): + success_count += 1 + + if on_progress: + on_progress(len(bip_files), len(bip_files), "完成") + + log(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +def main(): + """命令行入口""" + # 使用默认配置运行 + config = Config( + ref_tif=REF_TIF, + bip_dir=BIP_DIR, + out_dir=OUT_DIR, + matcher_name=MATCHER_NAME, + device=DEVICE, + transform_methods=TRANSFORM_METHODS, + match_max_side=MATCH_MAX_SIDE, + roi_pad_px=ROI_PAD_PX, + min_inliers=MIN_INLIERS, + min_inlier_ratio=MIN_INLIER_RATIO + ) + run_batch(config) + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + +# ---------- GUI 相关 ---------- +class QueueHandler(logging.Handler): + """自定义日志处理器,将日志发送到队列""" + def __init__(self, log_queue): + super().__init__() + self.log_queue = log_queue + + def emit(self, record): + self.log_queue.put(self.format(record)) + +class RegistrationGUI: + def __init__(self, root): + self.root = root + self.root.title("遥感影像批量配准工具") + self.root.geometry("1000x800") + + # 日志队列和停止事件 + self.log_queue = queue.Queue() + self.stop_event = threading.Event() + self.processing_thread = None + + # 设置日志处理器 + queue_handler = QueueHandler(self.log_queue) + queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + logger.addHandler(queue_handler) + logger.setLevel(logging.INFO) + + # 创建GUI组件 + self.create_widgets() + + # 定期检查日志队列 + self.check_log_queue() + + def create_widgets(self): + """创建GUI组件""" + # 主框架 + main_frame = ttk.Frame(self.root, padding="10") + main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # 配置展开面板 + config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5") + config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10)) + + # 输入文件选择 + ttk.Label(config_frame, text="参考TIF文件:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.ref_tif_var = tk.StringVar(value=DEFAULT_REF_TIF) + ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif).grid(row=0, column=2) + + ttk.Label(config_frame, text="BIP文件夹:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5)) + self.bip_dir_var = tk.StringVar(value=DEFAULT_BIP_DIR) + ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir).grid(row=1, column=2) + + ttk.Label(config_frame, text="输出文件夹:").grid(row=2, column=0, sticky=tk.W, padx=(0, 5)) + self.out_dir_var = tk.StringVar(value=DEFAULT_OUT_DIR) + ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50).grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir).grid(row=2, column=2) + + # 匹配器选择 + ttk.Label(config_frame, text="匹配算法:").grid(row=3, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + self.matcher_var = tk.StringVar(value=DEFAULT_MATCHER_NAME) + matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47) + matcher_combo['values'] = [ + "liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer", + "matchanything-eloftr", "matchanything-roma", "matchformer", + "sift-lightglue", "superpoint-lightglue", "disk-lightglue", + "aliked-lightglue", "doghardnet-lightglue", "roma", "romav2", + "tiny-roma", "dedode", "steerers", "affine-steerers", + "dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue", + "r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat", + "xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm", + "gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx", + "dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx", + "sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma", + "minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr", + "minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd", + "lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint", + "xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm", + "xfeat-star-steerers-learned" + ] + matcher_combo.grid(row=3, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 设备选择 + ttk.Label(config_frame, text="设备:").grid(row=4, column=0, sticky=tk.W, padx=(0, 5)) + self.device_var = tk.StringVar(value=DEFAULT_DEVICE) + device_frame = ttk.Frame(config_frame) + device_frame.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E)) + ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda").pack(side=tk.LEFT) + ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu").pack(side=tk.LEFT) + + # 变换方法选择 + ttk.Label(config_frame, text="变换方法 (按优先级):").grid(row=5, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + + transform_frame = ttk.Frame(config_frame) + transform_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 变换方法列表 + self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False) + transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial"] + for method in transform_methods: + self.transform_listbox.insert(tk.END, method) + if method in DEFAULT_TRANSFORM_METHODS: + self.transform_listbox.selection_set(transform_methods.index(method)) + + scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview) + self.transform_listbox.configure(yscrollcommand=scrollbar.set) + + self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 移动按钮 + button_frame = ttk.Frame(transform_frame) + button_frame.pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2)) + ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X) + + # 参数设置 + param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5") + param_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0)) + + ttk.Label(param_frame, text="匹配最大边长:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.match_max_side_var = tk.IntVar(value=DEFAULT_MATCH_MAX_SIDE) + ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10).grid(row=0, column=1, sticky=tk.W) + + ttk.Label(param_frame, text="ROI填充像素:").grid(row=0, column=2, sticky=tk.W, padx=(10, 5)) + self.roi_pad_px_var = tk.IntVar(value=DEFAULT_ROI_PAD_PX) + ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10).grid(row=0, column=3, sticky=tk.W) + + ttk.Label(param_frame, text="最少内点数:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.min_inliers_var = tk.IntVar(value=DEFAULT_MIN_INLIERS) + ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10).grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + + ttk.Label(param_frame, text="最少内点比例:").grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_inlier_ratio_var = tk.DoubleVar(value=DEFAULT_MIN_INLIER_RATIO) + ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10).grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + + # 控制按钮 + control_frame = ttk.Frame(main_frame) + control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing) + self.start_btn.pack(side=tk.LEFT, padx=(0, 10)) + + self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT) + + # 进度条 + progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5") + progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 5)) + + self.progress_label = ttk.Label(progress_frame, text="准备就绪") + self.progress_label.pack(anchor=tk.W) + + # 日志窗口 + log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5") + log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0)) + + # 日志文本框和滚动条 + log_text_frame = ttk.Frame(log_frame) + log_text_frame.pack(fill=tk.BOTH, expand=True) + + self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD) + scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview) + self.log_text.configure(yscrollcommand=scrollbar.set) + + self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 日志控制按钮 + log_btn_frame = ttk.Frame(log_frame) + log_btn_frame.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5)) + ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT) + + # 配置网格权重 + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + main_frame.columnconfigure(1, weight=1) + main_frame.rowconfigure(3, weight=1) + + def select_ref_tif(self): + """选择参考TIF文件""" + filename = filedialog.askopenfilename( + title="选择参考TIF文件", + filetypes=[("TIF files", "*.tif"), ("All files", "*.*")] + ) + if filename: + self.ref_tif_var.set(filename) + + def select_bip_dir(self): + """选择BIP文件夹""" + dirname = filedialog.askdirectory(title="选择BIP文件夹") + if dirname: + self.bip_dir_var.set(dirname) + + def select_out_dir(self): + """选择输出文件夹""" + dirname = filedialog.askdirectory(title="选择输出文件夹") + if dirname: + self.out_dir_var.set(dirname) + + def move_up(self): + """上移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] > 0: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx - 1, text) + self.transform_listbox.selection_set(idx - 1) + + def move_down(self): + """下移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] < self.transform_listbox.size() - 1: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx + 1, text) + self.transform_listbox.selection_set(idx + 1) + + def start_processing(self): + """开始处理""" + if self.processing_thread and self.processing_thread.is_alive(): + messagebox.showwarning("警告", "处理正在进行中") + return + + # 获取选中的变换方法 + selected_indices = self.transform_listbox.curselection() + if not selected_indices: + messagebox.showwarning("警告", "请至少选择一种变换方法") + return + + transform_methods = [] + for idx in selected_indices: + transform_methods.append(self.transform_listbox.get(idx)) + + # 创建配置 + config = Config( + ref_tif=self.ref_tif_var.get(), + bip_dir=self.bip_dir_var.get(), + out_dir=self.out_dir_var.get(), + matcher_name=self.matcher_var.get(), + device=self.device_var.get(), + transform_methods=transform_methods, + match_max_side=self.match_max_side_var.get(), + roi_pad_px=self.roi_pad_px_var.get(), + min_inliers=self.min_inliers_var.get(), + min_inlier_ratio=self.min_inlier_ratio_var.get() + ) + + # 重置停止事件 + self.stop_event.clear() + + # 禁用开始按钮,启用停止按钮 + self.start_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.progress_var.set(0) + self.progress_label.config(text="正在初始化...") + + # 在后台线程中运行处理 + self.processing_thread = threading.Thread( + target=self.run_processing, + args=(config,), + daemon=True + ) + self.processing_thread.start() + + def stop_processing(self): + """停止处理""" + if self.processing_thread and self.processing_thread.is_alive(): + self.stop_event.set() + self.progress_label.config(text="正在停止...") + + def run_processing(self, config): + """在后台线程中运行处理""" + try: + run_batch(config, self.on_progress, self.on_log, self.stop_event) + except Exception as e: + self.log_queue.put(f"处理过程中发生错误: {e}") + finally: + # 恢复按钮状态 + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.progress_label.config(text="处理完成")) + + def on_progress(self, current, total, filename): + """进度回调""" + if total > 0: + progress = (current / total) * 100 + self.root.after(0, lambda: self.progress_var.set(progress)) + self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})")) + + def on_log(self, message): + """日志回调""" + self.log_queue.put(message) + + def check_log_queue(self): + """检查日志队列并更新GUI""" + try: + while True: + message = self.log_queue.get_nowait() + self.log_text.insert(tk.END, message + '\n') + self.log_text.see(tk.END) + except queue.Empty: + pass + + # 每100ms检查一次 + self.root.after(100, self.check_log_queue) + + def clear_log(self): + """清空日志""" + self.log_text.delete(1.0, tk.END) + + def save_log(self): + """保存日志""" + filename = filedialog.asksaveasfilename( + title="保存日志", + defaultextension=".txt", + filetypes=[("Text files", "*.txt"), ("All files", "*.*")] + ) + if filename: + with open(filename, 'w', encoding='utf-8') as f: + f.write(self.log_text.get(1.0, tk.END)) + +def create_gui(): + """创建GUI""" + root = tk.Tk() + app = RegistrationGUI(root) + root.mainloop() + +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "--cli": + # 命令行模式 + main() + else: + # 默认GUI模式 + create_gui() \ No newline at end of file diff --git a/test V7.py b/test V7.py new file mode 100644 index 0000000..e57772c --- /dev/null +++ b/test V7.py @@ -0,0 +1,1534 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +import threading +import queue +from dataclasses import dataclass +import tkinter as tk +from tkinter import ttk, filedialog, messagebox +import sys +import os + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + +@dataclass +class Config: + """配置参数类""" + ref_tif: str + bip_dir: str + out_dir: str + matcher_name: str + device: str + transform_methods: list + match_max_side: int + roi_pad_px: int + min_inliers: int + min_inlier_ratio: float + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 默认配置,请根据实际情况修改这些路径 +DEFAULT_REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +DEFAULT_BIP_DIR = r"E:\is2\yaopu" # .bip 文件所在文件夹 +DEFAULT_OUT_DIR = r"E:\is2\yaopu\output" # 输出文件夹 + +# 默认匹配算法选择 +DEFAULT_MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEFAULT_DEVICE = "cuda" # 或 "cpu" + +# 默认变换方法选择(按优先级尝试) +DEFAULT_TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial" + +# 默认匹配参数 +DEFAULT_MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +DEFAULT_ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 默认质量控制阈值 +DEFAULT_MIN_INLIERS = 10 # 最少内点数 +DEFAULT_MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 向后兼容的全局变量(用于命令行模式) +REF_TIF = DEFAULT_REF_TIF +BIP_DIR = Path(DEFAULT_BIP_DIR) +OUT_DIR = Path(DEFAULT_OUT_DIR) +MATCHER_NAME = DEFAULT_MATCHER_NAME +DEVICE = DEFAULT_DEVICE +TRANSFORM_METHODS = DEFAULT_TRANSFORM_METHODS +MATCH_MAX_SIDE = DEFAULT_MATCH_MAX_SIDE +ROI_PAD_PX = DEFAULT_ROI_PAD_PX +MIN_INLIERS = DEFAULT_MIN_INLIERS +MIN_INLIER_RATIO = DEFAULT_MIN_INLIER_RATIO + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif( + bip_path: Path, + ref_dataset, + matcher, + out_dir: Path, + stats_csv: Path, + transform_methods: list, + match_max_side: int, + roi_pad_px: int, + min_inliers: int, + min_inlier_ratio: float, +): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, roi_pad_px, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, match_max_side) + ref_small = _downscale_chw(ref_img, match_max_side) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < min_inliers or ratio < min_inlier_ratio: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in transform_methods: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def run_batch(config: Config, on_progress=None, on_log=None, stop_event=None): + """批量配准处理函数 + + Args: + config: 配置参数 + on_progress: 进度回调函数 (current_idx, total, filename) + on_log: 日志回调函数 (message) + stop_event: 停止事件,用于取消处理 + """ + def log(message): + if on_log: + on_log(message) + logger.info(message) + + log("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(config.ref_tif).exists(): + log(f"参考文件不存在: {config.ref_tif}") + return + + if not Path(config.bip_dir).exists(): + log(f"BIP文件夹不存在: {config.bip_dir}") + return + + # 创建输出目录 + out_dir = Path(config.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # 创建统计输出目录和文件 + stats_dir = out_dir / "stats" + stats_dir.mkdir(parents=True, exist_ok=True) + stats_csv = stats_dir / "registration_stats.csv" + + # 初始化统计CSV文件 + init_stats_csv(stats_csv) + log(f"统计信息将保存到: {stats_csv}") + + # 初始化匹配器 + log(f"初始化匹配器: {config.matcher_name} on {config.device}") + matcher = get_matcher(config.matcher_name, device=config.device) + + # 打开参考文件 + with rasterio.open(config.ref_tif) as ref: + log(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_dir = Path(config.bip_dir) + bip_files = list(bip_dir.glob("*.bip")) + log(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for i, bip_path in enumerate(bip_files): + if stop_event and stop_event.is_set(): + log("处理被用户取消") + break + + if on_progress: + on_progress(i, len(bip_files), bip_path.name) + + if process_bip_to_tif( + bip_path, ref, matcher, out_dir, stats_csv, + config.transform_methods, + config.match_max_side, + config.roi_pad_px, + config.min_inliers, + config.min_inlier_ratio, + ): + success_count += 1 + + if on_progress: + on_progress(len(bip_files), len(bip_files), "完成") + + log(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +def main(): + """命令行入口""" + # 使用默认配置运行 + config = Config( + ref_tif=REF_TIF, + bip_dir=BIP_DIR, + out_dir=OUT_DIR, + matcher_name=MATCHER_NAME, + device=DEVICE, + transform_methods=TRANSFORM_METHODS, + match_max_side=MATCH_MAX_SIDE, + roi_pad_px=ROI_PAD_PX, + min_inliers=MIN_INLIERS, + min_inlier_ratio=MIN_INLIER_RATIO + ) + run_batch(config) + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +# ---------- GUI 相关 ---------- +class QueueHandler(logging.Handler): + """自定义日志处理器,将日志发送到队列""" + def __init__(self, log_queue): + super().__init__() + self.log_queue = log_queue + + def emit(self, record): + self.log_queue.put(self.format(record)) + +class RegistrationGUI: + def __init__(self, root): + self.root = root + self.root.title("遥感影像批量配准工具") + self.root.geometry("1000x800") + + # 日志队列和停止事件 + self.log_queue = queue.Queue() + self.stop_event = threading.Event() + self.processing_thread = None + + # 设置日志处理器 + queue_handler = QueueHandler(self.log_queue) + queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + logger.addHandler(queue_handler) + logger.setLevel(logging.INFO) + + # 创建GUI组件 + self.create_widgets() + + # 定期检查日志队列 + self.check_log_queue() + + def create_widgets(self): + """创建GUI组件""" + # 主框架 + main_frame = ttk.Frame(self.root, padding="10") + main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # 配置展开面板 + config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5") + config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10)) + + # 输入文件选择 + ttk.Label(config_frame, text="参考TIF文件:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.ref_tif_var = tk.StringVar(value=DEFAULT_REF_TIF) + ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif).grid(row=0, column=2) + + ttk.Label(config_frame, text="BIP文件夹:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5)) + self.bip_dir_var = tk.StringVar(value=DEFAULT_BIP_DIR) + ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir).grid(row=1, column=2) + + ttk.Label(config_frame, text="输出文件夹:").grid(row=2, column=0, sticky=tk.W, padx=(0, 5)) + self.out_dir_var = tk.StringVar(value=DEFAULT_OUT_DIR) + ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50).grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir).grid(row=2, column=2) + + # 匹配器选择 + ttk.Label(config_frame, text="匹配算法:").grid(row=3, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + self.matcher_var = tk.StringVar(value=DEFAULT_MATCHER_NAME) + matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47) + matcher_combo['values'] = [ + "liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer", + "matchanything-eloftr", "matchanything-roma", "matchformer", + "sift-lightglue", "superpoint-lightglue", "disk-lightglue", + "aliked-lightglue", "doghardnet-lightglue", "roma", "romav2", + "tiny-roma", "dedode", "steerers", "affine-steerers", + "dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue", + "r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat", + "xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm", + "gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx", + "dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx", + "sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma", + "minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr", + "minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd", + "lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint", + "xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm", + "xfeat-star-steerers-learned" + ] + matcher_combo.grid(row=3, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 设备选择 + ttk.Label(config_frame, text="设备:").grid(row=4, column=0, sticky=tk.W, padx=(0, 5)) + self.device_var = tk.StringVar(value=DEFAULT_DEVICE) + device_frame = ttk.Frame(config_frame) + device_frame.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E)) + ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda").pack(side=tk.LEFT) + ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu").pack(side=tk.LEFT) + + # 变换方法选择 + ttk.Label(config_frame, text="变换方法 (按优先级):").grid(row=5, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + + transform_frame = ttk.Frame(config_frame) + transform_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 变换方法列表 + self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False) + transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial"] + for method in transform_methods: + self.transform_listbox.insert(tk.END, method) + if method in DEFAULT_TRANSFORM_METHODS: + self.transform_listbox.selection_set(transform_methods.index(method)) + + scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview) + self.transform_listbox.configure(yscrollcommand=scrollbar.set) + + self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 移动按钮 + button_frame = ttk.Frame(transform_frame) + button_frame.pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2)) + ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X) + + # 参数设置 + param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5") + param_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0)) + + ttk.Label(param_frame, text="匹配最大边长:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.match_max_side_var = tk.IntVar(value=DEFAULT_MATCH_MAX_SIDE) + ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10).grid(row=0, column=1, sticky=tk.W) + + ttk.Label(param_frame, text="ROI填充像素:").grid(row=0, column=2, sticky=tk.W, padx=(10, 5)) + self.roi_pad_px_var = tk.IntVar(value=DEFAULT_ROI_PAD_PX) + ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10).grid(row=0, column=3, sticky=tk.W) + + ttk.Label(param_frame, text="最少内点数:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.min_inliers_var = tk.IntVar(value=DEFAULT_MIN_INLIERS) + ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10).grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + + ttk.Label(param_frame, text="最少内点比例:").grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_inlier_ratio_var = tk.DoubleVar(value=DEFAULT_MIN_INLIER_RATIO) + ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10).grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + + # 控制按钮 + control_frame = ttk.Frame(main_frame) + control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing) + self.start_btn.pack(side=tk.LEFT, padx=(0, 10)) + + self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT) + + # 进度条 + progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5") + progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 5)) + + self.progress_label = ttk.Label(progress_frame, text="准备就绪") + self.progress_label.pack(anchor=tk.W) + + # 日志窗口 + log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5") + log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0)) + + # 日志文本框和滚动条 + log_text_frame = ttk.Frame(log_frame) + log_text_frame.pack(fill=tk.BOTH, expand=True) + + self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD) + scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview) + self.log_text.configure(yscrollcommand=scrollbar.set) + + self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 日志控制按钮 + log_btn_frame = ttk.Frame(log_frame) + log_btn_frame.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5)) + ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT) + + # 配置网格权重 + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + main_frame.columnconfigure(1, weight=1) + main_frame.rowconfigure(3, weight=1) + + def select_ref_tif(self): + """选择参考TIF文件""" + filename = filedialog.askopenfilename( + title="选择参考TIF文件", + filetypes=[("TIF files", "*.tif"), ("All files", "*.*")] + ) + if filename: + self.ref_tif_var.set(filename) + + def select_bip_dir(self): + """选择BIP文件夹""" + dirname = filedialog.askdirectory(title="选择BIP文件夹") + if dirname: + self.bip_dir_var.set(dirname) + + def select_out_dir(self): + """选择输出文件夹""" + dirname = filedialog.askdirectory(title="选择输出文件夹") + if dirname: + self.out_dir_var.set(dirname) + + def move_up(self): + """上移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] > 0: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx - 1, text) + self.transform_listbox.selection_set(idx - 1) + + def move_down(self): + """下移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] < self.transform_listbox.size() - 1: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx + 1, text) + self.transform_listbox.selection_set(idx + 1) + + def start_processing(self): + """开始处理""" + if self.processing_thread and self.processing_thread.is_alive(): + messagebox.showwarning("警告", "处理正在进行中") + return + + # 获取选中的变换方法 + selected_indices = self.transform_listbox.curselection() + if not selected_indices: + messagebox.showwarning("警告", "请至少选择一种变换方法") + return + + transform_methods = [] + for idx in selected_indices: + transform_methods.append(self.transform_listbox.get(idx)) + + # 创建配置 + config = Config( + ref_tif=self.ref_tif_var.get(), + bip_dir=self.bip_dir_var.get(), + out_dir=self.out_dir_var.get(), + matcher_name=self.matcher_var.get(), + device=self.device_var.get(), + transform_methods=transform_methods, + match_max_side=self.match_max_side_var.get(), + roi_pad_px=self.roi_pad_px_var.get(), + min_inliers=self.min_inliers_var.get(), + min_inlier_ratio=self.min_inlier_ratio_var.get() + ) + + # 重置停止事件 + self.stop_event.clear() + + # 禁用开始按钮,启用停止按钮 + self.start_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.progress_var.set(0) + self.progress_label.config(text="正在初始化...") + + # 在后台线程中运行处理 + self.processing_thread = threading.Thread( + target=self.run_processing, + args=(config,), + daemon=True + ) + self.processing_thread.start() + + def stop_processing(self): + """停止处理""" + if self.processing_thread and self.processing_thread.is_alive(): + self.stop_event.set() + self.progress_label.config(text="正在停止...") + + def run_processing(self, config): + """在后台线程中运行处理""" + try: + run_batch(config, self.on_progress, self.on_log, self.stop_event) + except Exception as e: + self.log_queue.put(f"处理过程中发生错误: {e}") + finally: + # 恢复按钮状态 + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.progress_label.config(text="处理完成")) + + def on_progress(self, current, total, filename): + """进度回调""" + if total > 0: + progress = (current / total) * 100 + self.root.after(0, lambda: self.progress_var.set(progress)) + self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})")) + + def on_log(self, message): + """日志回调""" + self.log_queue.put(message) + + def check_log_queue(self): + """检查日志队列并更新GUI""" + try: + while True: + message = self.log_queue.get_nowait() + self.log_text.insert(tk.END, message + '\n') + self.log_text.see(tk.END) + except queue.Empty: + pass + + # 每100ms检查一次 + self.root.after(100, self.check_log_queue) + + def clear_log(self): + """清空日志""" + self.log_text.delete(1.0, tk.END) + + def save_log(self): + """保存日志""" + filename = filedialog.asksaveasfilename( + title="保存日志", + defaultextension=".txt", + filetypes=[("Text files", "*.txt"), ("All files", "*.*")] + ) + if filename: + with open(filename, 'w', encoding='utf-8') as f: + f.write(self.log_text.get(1.0, tk.END)) + +def create_gui(): + """创建GUI""" + root = tk.Tk() + app = RegistrationGUI(root) + root.mainloop() + +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "--cli": + # 命令行模式 + main() + else: + # 默认GUI模式 + create_gui() diff --git a/test V8.py b/test V8.py new file mode 100644 index 0000000..16e74cf --- /dev/null +++ b/test V8.py @@ -0,0 +1,1207 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用有效区域掩膜参考影像 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["similarity", "affine", "homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) +MASK_PAD_PX = 100 # 匹配掩膜扩张像素(仅用于匹配阶段) + +# RANSAC 参数 +RANSAC_REPROJ_THRESHOLD = 5 # RANSAC重投影误差阈值(像素) +RANSAC_CONFIDENCE = 0.99 # RANSAC置信度 +RANSAC_MAX_ITERS = 1000 # RANSAC最大迭代次数 + +# 质量控制阈值 +MIN_INLIERS = 10 +MIN_INLIER_RATIO = 0.01 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + """将(H,W)二值掩膜缩放到目标尺寸,保持最近邻""" + m = cv2.resize(mask_hw.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST) + return m > 0 + +def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict: + """将匹配与内点严格限制在掩膜内""" + if src_mask_small is None or ref_mask_small is None: + return result + + def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray: + if kpts is None or len(kpts) == 0: + return np.zeros((0,), dtype=bool) + kpts = np.asarray(kpts) + xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # 过滤 matched_kpts + if "matched_kpts0" in result and "matched_kpts1" in result: + mk0 = np.asarray(result["matched_kpts0"]) + mk1 = np.asarray(result["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small) + result["matched_kpts0"] = mk0[keep_m] + result["matched_kpts1"] = mk1[keep_m] + + # 过滤 inlier_kpts + if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None: + ik0 = np.asarray(result["inlier_kpts0"]) + ik1 = np.asarray(result["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small) + result["inlier_kpts0"] = ik0[keep_i] + result["inlier_kpts1"] = ik1[keep_i] + result["num_inliers"] = int(len(result["inlier_kpts0"])) + + return result + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠) + try: + src_mask = (src.read_masks(1) > 0) # True=有效 + rows_any = np.any(src_mask, axis=1) + cols_any = np.any(src_mask, axis=0) + if rows_any.any() and cols_any.any(): + rmin = int(rows_any.argmax()) + rmax = int(src.height - 1 - rows_any[::-1].argmax()) + cmin = int(cols_any.argmax()) + cmax = int(src.width - 1 - cols_any[::-1].argmax()) + valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1) + valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform) + b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21) + else: + # 掩膜无效时回退到整图bounds + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + except Exception: + src_mask = None # 后续可选源图掩膜时用到 + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 可选:将源图的有效掩膜也应用到源图,抑制无效区特征 + try: + src_mask3 = np.repeat((src_mask > 0)[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * src_mask3 + except Exception: + pass + + # 将重投影到参考ROI的掩膜应用到参考图像(匹配仅在掩膜内进行) + if dst_mask is not None: + ref_mask3 = np.repeat((dst_mask > 0)[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * ref_mask3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 生成与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None + + # 基于掩膜严格过滤匹配与内点 + result = _filter_matches_by_masks(result, src_mask_small, ref_mask_small) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V9.py b/test V9.py new file mode 100644 index 0000000..14ecc8e --- /dev/null +++ b/test V9.py @@ -0,0 +1,1299 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["similarity", "affine", "homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) +MASK_PAD_PX = 100 # 匹配掩膜扩张像素(仅用于匹配阶段) + +# 质量控制阈值 +MIN_INLIERS = 10 +MIN_INLIER_RATIO = 0.01 + +# 掩膜边缘羽化与过滤 +FEATHER_PX = 20 # 掩膜羽化宽度(像素,先在全分辨率/ROI分辨率上做) +EDGE_BAND_PX = 30 # 剔除距离掩膜边界小于此像素的匹配点(在小图上按比例缩放) + +# 纹理过滤 +MIN_GRAD_QUANTILE = 0.20 # 梯度幅值的分位阈值(0~1),低于该阈值的点视为低纹理,剔除 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + """将(H,W)二值掩膜缩放到目标尺寸,保持最近邻""" + m = cv2.resize(mask_hw.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST) + return m > 0 + +def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray: + """ + 二值掩膜 -> 软掩膜 alpha∈[0,1],边缘处按距离线性上升,避免硬边缘。 + mask_hw: bool/uint8 (H,W) True/1表示有效 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + # 距离变换仅对前景内部有效,计算到边界的距离 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + if feather_px <= 0: + alpha = (dist > 0).astype(np.float32) + else: + alpha = np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32) + return alpha # (H,W) float32 + +def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray: + """ + 生成"远离边界"的保留掩膜:仅保留距离边界>=min_dist_px的像素。 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + keep = dist >= float(max(1, min_dist_px)) + return keep + +def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray: + """ + 根据梯度幅值生成纹理掩膜(H,W)True=纹理足够。 + 使用与匹配同尺寸的CHW图像。 + """ + # 转灰度 + g = img_chw.mean(axis=0).astype(np.float32) # (H,W) + gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3) + gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3) + mag = np.sqrt(gx*gx + gy*gy) + thr = float(np.quantile(mag, quantile)) if mag.size > 0 else 0.0 + return mag >= thr # (H,W) bool + +def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict: + """将匹配与内点严格限制在掩膜内""" + if src_mask_small is None or ref_mask_small is None: + return result + + def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray: + if kpts is None or len(kpts) == 0: + return np.zeros((0,), dtype=bool) + kpts = np.asarray(kpts) + xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # 过滤 matched_kpts + if "matched_kpts0" in result and "matched_kpts1" in result: + mk0 = np.asarray(result["matched_kpts0"]) + mk1 = np.asarray(result["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small) + result["matched_kpts0"] = mk0[keep_m] + result["matched_kpts1"] = mk1[keep_m] + + # 过滤 inlier_kpts + if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None: + ik0 = np.asarray(result["inlier_kpts0"]) + ik1 = np.asarray(result["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small) + result["inlier_kpts0"] = ik0[keep_i] + result["inlier_kpts1"] = ik1[keep_i] + result["num_inliers"] = int(len(result["inlier_kpts0"])) + + return result + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠) + try: + src_mask = (src.read_masks(1) > 0) # True=有效 + rows_any = np.any(src_mask, axis=1) + cols_any = np.any(src_mask, axis=0) + if rows_any.any() and cols_any.any(): + rmin = int(rows_any.argmax()) + rmax = int(src.height - 1 - rows_any[::-1].argmax()) + cmin = int(cols_any.argmax()) + cmax = int(src.width - 1 - cols_any[::-1].argmax()) + valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1) + valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform) + b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21) + else: + # 掩膜无效时回退到整图bounds + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + except Exception: + src_mask = None # 后续可选源图掩膜时用到 + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 软掩膜:避免在边界产生硬高对比边 + try: + alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None + except Exception: + alpha_src = None + try: + alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None + except Exception: + alpha_ref = None + + if alpha_src is not None: + alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * alpha_src3 + + if alpha_ref is not None: + alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * alpha_ref3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None + + # 剔除掩膜边缘带(小图尺度的最小距离) + def _scale_px(px_full: int, full_wh, small_wh) -> int: + # 用平均缩放;也可以分别对H/W计算后取最小 + sy = small_wh[0] / max(1, full_wh[0]) + sx = small_wh[1] / max(1, full_wh[1]) + s = 0.5 * (sx + sy) + return max(1, int(round(px_full * s))) + + edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2])) + edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2])) + + keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None + keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None + + # 纹理掩膜 + keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE) + keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE) + + # 组合最终保留掩膜(边缘+纹理),二者都要满足 + def _combine_keep(m_edge, m_tex): + if m_edge is None: + return m_tex + return (m_edge & m_tex) + + keep_src_final = _combine_keep(keep_src_edge, keep_src_tex) + keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex) + + # 将匹配与内点严格限制在最终掩膜内 + def _filter_by_bool_masks(res, m_src, m_ref): + if m_src is None or m_ref is None: + return res + + def keep_in(mask_hw, pts): + if pts is None or len(pts) == 0: + return np.zeros((0,), dtype=bool) + xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # matched + if "matched_kpts0" in res and "matched_kpts1" in res: + mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1) + res["matched_kpts0"] = mk0[keep_m] + res["matched_kpts1"] = mk1[keep_m] + + # inliers + if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None: + ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1) + res["inlier_kpts0"] = ik0[keep_i] + res["inlier_kpts1"] = ik1[keep_i] + res["num_inliers"] = int(len(res["inlier_kpts0"])) + return res + + result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test.py b/test.py new file mode 100644 index 0000000..737ae9a --- /dev/null +++ b/test.py @@ -0,0 +1,290 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用地理信息进行粗定位,然后用图像匹配进行精配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\jiashixian\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 300 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 30 # 最少内点数 +MIN_INLIER_RATIO = 0.15 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计仿射变换 + k0 = result["inlier_kpts0"].astype(np.float32) + k1 = result["inlier_kpts1"].astype(np.float32) + + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A is None: + logger.warning(f"仿射估计失败: {bip_path.name}") + return False + + # 6) 把"src_small->ref_small"的仿射映射回"src_full->ref_full_roi" + # 缩放系数:full/small + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0 = np.array([[1/s0x, 0, 0], + [0, 1/s0y, 0], + [0, 0, 1]], dtype=np.float64) # full -> small + S1_inv = np.array([[s1x, 0, 0], + [0, s1y, 0], + [0, 0, 1]], dtype=np.float64) # small -> full(roi) + + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A # small src -> small ref_roi + + # full_src -> full_ref_roi(用 S0,而不是 S0_inv) + M_full = S1_inv @ A3 @ S0 + + # 7) 目标输出:重采样到 ref 全图网格 + # 需要"src像素 -> 地图坐标"的修正 transform + T_off = np.array([[1, 0, win.col_off], + [0, 1, win.row_off], + [0, 0, 1]], dtype=np.float64) + + # ref_transform 转 3x3 + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + + src_pixel_to_map_corrected = Rt @ T_off @ M_full + + # 转回 Affine + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 8) 重采样输出到 ref 网格(多波段,BIP 输出) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + + # 获取 NoData 值 + src_nodata = src.nodata # 可能是 None、0、65535 等 + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 以参考影像的网格/坐标为目标,波段数和数据类型继承源影像 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], # 尽量保持与源影像一致的数据类型 + height=ref_dataset.height, + width=ref_dataset.width, + count=src.count, # 多波段 + transform=ref_transform, + crs=ref_crs, + interleave="bip", # 指定 BIP + compress=None, + nodata=dst_nodata # 设置 NoData 值 + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + # 逐波段重投影(float32 计算更稳) + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32) + + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, # 由原逻辑推导的 src像素->ref像素 的仿射 + src_crs=ref_crs, + dst_transform=ref_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, # 新增 NoData 处理 + dst_nodata=dst_nodata, # 新增 NoData 处理 + resampling=Resampling.bilinear, + ) + + # 调试统计信息 + logger.info(f"波段 {b}: min={float(np.nanmin(dst_band)):.2f}, max={float(np.nanmax(dst_band)):.2f}, " + f"mean={float(np.nanmean(dst_band)):.2f}, nodata占比={(dst_band==dst_nodata).mean():.2%}") + + # 转回目标 dtype(若为整型,先裁剪再转型) + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + if src_nodata is not None: + # 保持 nodata 不被拉伸 + mask = (dst_band == dst_nodata) + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max) + dst_band = dst_band.astype(out_profile["dtype"]) + if src_nodata is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准: {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main()