commit 5e0984bf9c2ff26ec3351035fb5fcefb3b38f191 Author: zhanghuilai Date: Fri Mar 6 17:24:55 2026 +0800 first commit diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..f649f0f --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..419e41a --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,44 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..ce9d49a --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/test.iml b/.idea/test.iml new file mode 100644 index 0000000..ec63674 --- /dev/null +++ b/.idea/test.iml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..353a2a1 --- /dev/null +++ b/README.md @@ -0,0 +1,176 @@ +# 图像配准工具集 (Image Registration Toolkit) + +这是一个基于 Python 的遥感图像配准工具集,主要用于批量配准 .bip 格式的影像文件到参考 .tif 文件。支持多种变换模型和质量控制机制,特别适用于森林、水体等复杂地物场景。 + +## 功能特性 + +### 核心功能 +- **批量配准**: 支持批量处理多个 .bip 文件到参考 .tif 文件 +- **多种变换模型**: 支持相似变换、仿射变换、单应变换等 +- **智能掩膜**: 使用有效区域掩膜和水体分割提高配准质量 +- **质量控制**: 基于内点数、内点比例和重投影误差的质量评估 +- **可视化输出**: 自动生成匹配结果和关键点可视化图像 + +### 高级特性 +- **RANSAC 参数控制**: 可配置的重投影误差阈值、置信度和最大迭代次数 +- **纹理过滤**: 基于梯度幅值的纹理质量筛选 +- **几何约束**: 距离边界限制和空间分布均匀性检查 +- **多尺度处理**: 支持不同分辨率的匹配和变换 + +## 文件说明 + +### 核心脚本 +- `test V8.py`: 最新版本的主配准脚本,支持有效区域掩膜 +- `test V9.py`: 增强版本,包含软掩膜和纹理过滤 +- `demo.py`: 基础演示脚本,展示匹配和可视化功能 + +### 工具脚本 +- `mask_water.py`: 基于 ROI 文件掩膜 TIF 图像的工具 +- `test V5.py` - `test V7.py`: 早期版本的配准脚本 + +## 环境要求 + +### 依赖包 +``` +numpy +opencv-python +rasterio +geopandas +shapely +scikit-image +matplotlib +scipy +torch (用于 vismatch) +vismatch +``` + +### 安装 +```bash +pip install numpy opencv-python rasterio geopandas shapely scikit-image matplotlib scipy torch +# 安装 vismatch (假设已配置) +``` + +## 使用方法 + +### 基本配准 +```bash +python test_V8.py # 运行主配准脚本 +``` + +### 参数配置 +在脚本开头修改以下关键参数: + +```python +# 文件路径 +REF_TIF = r"E:\path\to\reference.tif" # 参考影像路径 +BIP_DIR = Path(r"E:\path\to\bip\directory") # BIP文件目录 +OUT_DIR = Path(r"E:\path\to\output") # 输出目录 + +# 匹配参数 +MATCHER_NAME = "matchanything-roma" # 匹配器选择 +MATCH_MAX_SIDE = 1200 # 匹配最大边长 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法优先级 +TRANSFORM_METHODS = ["similarity", "affine", "homography"] + +# RANSAC参数 +RANSAC_REPROJ_THRESHOLD = 3.0 # 重投影误差阈值 +RANSAC_CONFIDENCE = 0.99 # 置信度 +RANSAC_MAX_ITERS = 1000 # 最大迭代次数 + +# 质量控制 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 +``` + +### ROI掩膜工具 +```bash +python mask_water.py input.tif roi.shp -o output_masked.tif +``` + +## 算法流程 + +1. **粗定位**: 使用源影像有效区域确定参考影像的感兴趣区域 (ROI) +2. **数据读取**: 读取源影像和参考影像的对应区域 +3. **掩膜处理**: 应用有效区域掩膜和水体分割 +4. **特征匹配**: 使用深度学习匹配器提取和匹配特征点 +5. **几何变换**: 尝试多种变换模型,选择最优的几何变换 +6. **质量评估**: 基于重投影误差评估变换质量 +7. **输出生成**: 生成配准后的影像和可视化结果 + +## 关键参数说明 + +### 匹配参数 +- `MATCH_MAX_SIDE`: 控制匹配时的图像尺寸,越小速度越快但细节损失越多 +- `ROI_PAD_PX`: ROI扩展像素,影响匹配区域范围 + +### RANSAC参数 +- `RANSAC_REPROJ_THRESHOLD`: 内点判断的误差阈值,越小内点质量越高但数量越少 +- `RANSAC_CONFIDENCE`: 算法置信度,影响迭代次数 +- `RANSAC_MAX_ITERS`: 最大迭代次数上限 + +### 质量控制 +- `MIN_INLIERS`: 最少内点数量要求 +- `MIN_INLIER_RATIO`: 内点比例要求 + +## 场景优化建议 + +### 森林场景 +```python +MATCH_MAX_SIDE = 1600 +MIN_INLIERS = 25 +MIN_INLIER_RATIO = 0.025 +# 启用纹理过滤和边界距离限制 +``` + +### 水体场景 +```python +# 使用水体掩膜预处理 +MASK_PAD_PX = 100 +# 调大RANSAC阈值容忍岸线变化 +RANSAC_REPROJ_THRESHOLD = 4.0 +``` + +### 城市/结构化场景 +```python +MATCH_MAX_SIDE = 1000 # 结构特征明显,不需要很高分辨率 +RANSAC_REPROJ_THRESHOLD = 2.0 # 严格质量控制 +``` + +## 输出结果 + +### 文件输出 +- `*_registered.bip`: 配准后的影像文件 +- `visualizations/*_matches.png`: 匹配结果可视化 +- `visualizations/*_keypoints_src.png`: 源影像关键点 +- `visualizations/*_keypoints_ref.png`: 参考影像关键点 + +### 统计输出 +- `stats/registration_stats.csv`: 包含每对影像的配准统计信息 + +## 故障排除 + +### 常见问题 +1. **内点数过少**: 检查影像质量,调整 RANSAC 参数,或扩大匹配区域 +2. **内存不足**: 降低 `MATCH_MAX_SIDE` 或使用 CPU 模式 +3. **配准失败**: 检查坐标系统一致性,调整质量阈值 + +### 调试建议 +- 查看可视化输出了解匹配情况 +- 检查统计 CSV 文件的各项指标 +- 调整参数时从小范围测试开始 + +## 版本历史 + +- **test V8.py**: 支持有效区域掩膜和 RANSAC 参数控制 +- **test V9.py**: 增加软掩膜和纹理过滤 +- **早期版本**: 基础配准功能实现 + +## 许可证 + +本项目遵循 MIT 许可证。 + +## 贡献 + +欢迎提交 Issue 和 Pull Request 来改进这个工具集。 \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..ea58bad --- /dev/null +++ b/demo.py @@ -0,0 +1,21 @@ + +from vismatch import get_matcher +from vismatch.viz import plot_matches,plot_keypoints + +# Choose any of the 50+ matchers listed below +matcher = get_matcher("matchanything-roma", device="cuda") +img_size = 512 # optional + +img0 = matcher.load_image(r"D:\汇报\水库项目资料\gif\屏幕截图 2026-02-28 140816.png", resize=img_size) +img1 = matcher.load_image(r"D:\汇报\水库项目资料\gif\屏幕截图 2026-02-28 140525.png", resize=img_size) + +result = matcher(img0, img1) +# result.keys() = ["num_inliers", "H", "all_kpts0", "all_kpts1", "all_desc0", "all_desc1", "matched_kpts0", "matched_kpts1", "inlier_kpts0", "inlier_kpts1"] + +# This will plot visualizations for matches as shown in the figures above +plot_matches(img0, img1, result, save_path="plot_matches.png") + +# Or you can extract and visualize keypoints as easily as +result = matcher.extract(img0) +# result.keys() = ["all_kpts0", "all_desc0"] +plot_keypoints(img0, result, save_path="plot_kpts.png") \ No newline at end of file diff --git a/mask_water.py b/mask_water.py new file mode 100644 index 0000000..ce6e52b --- /dev/null +++ b/mask_water.py @@ -0,0 +1,131 @@ +""" +掩膜tif文件的ROI区域并保存 + +输入: tif文件路径, roi文件路径(shp/geojson等) +输出: 掩膜后的tif文件 +""" + +import argparse +import numpy as np +import rasterio +from rasterio.mask import mask +from rasterio.features import shapes +import geopandas as gpd +# from shapely.geometry import shape +import logging +from pathlib import Path + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def mask_tif_by_roi(tif_path: str, roi_path: str, output_path: str = None, + nodata_value: float = None): + """ + 使用ROI文件掩膜TIF文件 + + Args: + tif_path: 输入TIF文件路径 + roi_path: ROI文件路径(shp/geojson等) + output_path: 输出文件路径,如果为None则自动生成 + nodata_value: nodata值,如果为None则使用原始文件的nodata值 + + Returns: + str: 输出文件路径 + """ + + tif_path = Path(tif_path) + roi_path = Path(roi_path) + + # 检查输入文件存在性 + if not tif_path.exists(): + raise FileNotFoundError(f"TIF文件不存在: {tif_path}") + if not roi_path.exists(): + raise FileNotFoundError(f"ROI文件不存在: {roi_path}") + + # 自动生成输出路径 + if output_path is None: + output_path = tif_path.parent / f"{tif_path.stem}_masked{tif_path.suffix}" + else: + output_path = Path(output_path) + + logger.info(f"开始处理: {tif_path.name}") + logger.info(f"ROI文件: {roi_path.name}") + + try: + # 读取ROI文件 + logger.info("读取ROI文件...") + gdf = gpd.read_file(roi_path) + + if gdf.empty: + logger.warning("ROI文件为空,直接复制原文件") + # 直接复制原文件 + import shutil + shutil.copy2(tif_path, output_path) + return str(output_path) + + # 转换为GeoJSON格式的几何对象 + geometries = gdf.geometry.tolist() + logger.info(f"找到 {len(geometries)} 个ROI几何对象") + + # 读取并掩膜TIF + logger.info("读取并掩膜TIF文件...") + with rasterio.open(tif_path) as src: + # 使用ROI掩膜 - 保留ROI以外的区域(ROI区域设为nodata) + masked_data, masked_transform = mask( + src, + geometries, + crop=False, # 不裁剪,保持原始尺寸 + invert=True, # 反转:ROI区域设为nodata,ROI以外保留 + nodata=nodata_value if nodata_value is not None else src.nodata + ) + + # 更新元数据 + out_meta = src.meta.copy() + out_meta.update({ + "driver": "GTiff", + "height": masked_data.shape[1], + "width": masked_data.shape[2], + "transform": masked_transform, + "nodata": nodata_value if nodata_value is not None else src.nodata + }) + + # 保存结果 + logger.info(f"保存掩膜结果到: {output_path.name}") + with rasterio.open(output_path, "w", **out_meta) as dest: + dest.write(masked_data) + + logger.info("处理完成!") + return str(output_path) + + except Exception as e: + logger.error(f"处理失败: {str(e)}") + raise + + +def main(): + parser = argparse.ArgumentParser(description="使用ROI文件掩膜TIF文件") + parser.add_argument("tif_path", help="输入TIF文件路径") + parser.add_argument("roi_path", help="ROI文件路径 (shp/geojson等)") + parser.add_argument("-o", "--output", help="输出文件路径 (可选,默认自动生成)") + parser.add_argument("-n", "--nodata", type=float, help="nodata值 (可选,默认使用原文件nodata)") + + args = parser.parse_args() + + try: + output_path = mask_tif_by_roi( + args.tif_path, + args.roi_path, + args.output, + args.nodata + ) + print(f"成功完成!输出文件: {output_path}") + except Exception as e: + print(f"错误: {e}") + return 1 + + return 0 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/test V2.py b/test V2.py new file mode 100644 index 0000000..4cf1858 --- /dev/null +++ b/test V2.py @@ -0,0 +1,359 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +仿射变换修改 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +from osgeo import gdal + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\jiashixian\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 30 # 最少内点数 +MIN_INLIER_RATIO = 0.15 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计仿射变换(相似/全仿射二选一) + k0 = result["inlier_kpts0"].astype(np.float32) + k1 = result["inlier_kpts1"].astype(np.float32) + + # 相似(旋转+等比缩放+平移,4DOF) + A_partial, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + # 全仿射(旋转+非等比缩放+剪切+平移,6DOF) + A_full, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + + def _reproj_err(A2x3: np.ndarray, pts0: np.ndarray, pts1: np.ndarray): + if A2x3 is None: + return np.inf, np.inf + ones = np.ones((pts0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([pts0, ones]) # (N,3) + pred = (A2x3 @ src_h.T).T # (N,2) + e = np.sqrt(((pred - pts1) ** 2).sum(axis=1)) + return float(np.median(e)), float(np.percentile(e, 95)) + + med_p, p95_p = _reproj_err(A_partial, k0, k1) + med_f, p95_f = _reproj_err(A_full, k0, k1) + + if (p95_f < p95_p) or (abs(p95_f - p95_p) < 0.5 and med_f < med_p): + A = A_full + model_type = "affine_full" + else: + A = A_partial + model_type = "affine_partial" + + if A is None: + logger.warning(f"仿射估计失败: {bip_path.name}") + return False + + logger.info(f"选用模型: {model_type}, partial(p50={med_p:.2f},p95={p95_p:.2f}), full(p50={med_f:.2f},p95={p95_f:.2f})") + + # 6) 基于内点构建 GCP,并使用 GDAL TPS(薄板样条)实现非刚性(B样条近似效果)配准 + # 注:GDAL 提供 TPS(薄板样条),在遥感配准中与B样条同属光滑非刚性方法,工程上更常用 + # 这里我们直接用 TPS 来实现你所需的“B样条变换”效果 + # 将 small 尺寸的内点坐标映射回 full 分辨率 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0_inv = np.array([[s0x, 0, 0], [0, s0y, 0], [0, 0, 1]], dtype=np.float64) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) # small -> full (ref ROI) + + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0, ones]).T).T[:, :2] # 源:full 源像素 + k1_roi_full = (S1_inv @ np.hstack([k1, ones]).T).T[:, :2] + # 加上 ROI 偏移,得到参考影像全局像素坐标 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) + + if k0_full.shape[0] < 8: + logger.warning(f"TPS 需要更多控制点(>=8),当前 {k0_full.shape[0]},回退到仿射输出。") + else: + # 生成 GCP:目标为参考地图坐标(由全局像素 + ref_transform 转换) + ref_transform = ref_dataset.transform + def px_to_map(xp, yp): + X = ref_transform.a * xp + ref_transform.b * yp + ref_transform.c + Y = ref_transform.d * xp + ref_transform.e * yp + ref_transform.f + return X, Y + + gcps = [] + for (sx, sy), (rx, ry) in zip(k0_full, k1_global): + mx, my = px_to_map(rx, ry) + gcps.append(gdal.GCP(mx, my, 0.0, float(sx), float(sy))) + + # 计算参考影像的输出边界与分辨率 + minx = ref_transform.c + maxy = ref_transform.f + maxx = minx + ref_transform.a * ref_dataset.width + ref_transform.b * ref_dataset.height + miny = maxy + ref_transform.d * ref_dataset.width + ref_transform.e * ref_dataset.height + + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 输出路径(ENVI+BIP) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + + # 使用 GDAL Warp 执行 TPS 变换到参考网格 + # 注意:需要系统已安装 GDAL Python 绑定 + try: + warp_opts = gdal.WarpOptions( + format="ENVI", + outputBounds=(minx, miny, maxx, maxy), + xRes=abs(ref_transform.a), + yRes=abs(ref_transform.e), + dstSRS=ref_crs.to_wkt() if hasattr(ref_crs, "to_wkt") else str(ref_crs), + tps=True, # 启用薄板样条(非刚性) + resampleAlg="bilinear", + srcNodata=src_nodata, + dstNodata=dst_nodata, + multithread=True, + creationOptions=["INTERLEAVE=BIP"], + gcps=gcps, + ) + + # 直接从源文件路径进行Warp + warp_ds = gdal.Warp( + destNameOrDestDS=str(out_path), + srcDSOrSrcDSTab=str(bip_path), + options=warp_opts, + ) + if warp_ds is None: + logger.warning("GDAL Warp(TPS) 失败,回退到仿射输出。") + else: + warp_ds = None # 关闭文件 + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + except Exception as e: + logger.warning(f"GDAL Warp(TPS) 异常,回退到仿射输出: {e}") + + # ---- 回退:使用仿射(你前面的流程),保证最小可用结果 ---- + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + # 仍沿用上面已估计的 A(partial/full 二选一) + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=ref_dataset.height, + width=ref_dataset.width, + count=src.count, + transform=ref_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=ref_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V3.py b/test V3.py new file mode 100644 index 0000000..2bf8f69 --- /dev/null +++ b/test V3.py @@ -0,0 +1,970 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 SimpleITK 实现 B 样条变换 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1500 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + elif method == "tps": + # 薄板样条变换(如果SimpleITK可用) + SITK_TPS = SITK_AVAILABLE and hasattr(sitk, "ThinPlateSplineKernelTransform") + if not SITK_TPS: + return None, None + try: + tps_transform = sitk.ThinPlateSplineKernelTransform() + tps_transform.SetKernelTypeToThinPlateSpline() + + fixed_landmarks = sitk.vectorDPoint() + moving_landmarks = sitk.vectorDPoint() + + for (rx, ry), (sx, sy) in zip(k1, k0): + fixed_landmarks.push_back([float(rx), float(ry)]) + moving_landmarks.push_back([float(sx), float(sy)]) + + tps_transform.SetFixedLandmarks(fixed_landmarks) + tps_transform.SetMovingLandmarks(moving_landmarks) + return "tps", tps_transform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "tps": + # TPS 变换重投影误差(SimpleITK) + pred = [] + for pt in k0: + transformed_pt = transform.TransformPoint([float(pt[0]), float(pt[1])]) + pred.append([transformed_pt[0], transformed_pt[1]]) + pred = np.array(pred) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + return False + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type == "polynomial": + # 对于多项式,估计逆变换 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=2) # 顺序:目标->源 + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + xy_src = t_inv(xy) # -> (x_src, y_src) in full-src + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + else: # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type == "tps": + # B样条变换:优先使用 PIRT,如果不可用则使用 SimpleITK TPS + try: + if PIRT_AVAILABLE: + # 使用 PIRT 实现 B样条弹性变换 + logger.info("使用 PIRT B样条变换") + + # 读取用于配准的单波段并归一化 + ref_roi_data = ref_dataset.read(1, window=win).astype(np.float32) + src_band_data = src.read(1).astype(np.float32) + + from skimage.exposure import rescale_intensity + ref_roi_reg = rescale_intensity(ref_roi_data, in_range='image', out_range=(0.0, 1.0)) + src_reg = rescale_intensity(src_band_data, in_range='image', out_range=(0.0, 1.0)) + + # 构建 PIRT 注册器 + reg = pirt.Registration(fixed=ref_roi_reg, moving=src_reg) + + # 设置 B样条变换 + bspline = pirt.transform.BSplineTransform(grid_spacing=(96, 96)) # 可调节控制点间距 + reg.set_transformation(bspline) + + # 设置相似度度量(NCC 或 MI) + reg.set_similarity(pirt.metrics.NCC()) + + # 多分辨率金字塔 + reg.set_pyramid([4, 2, 1]) + + # 优化器设置 + reg.set_optimizer(pirt.optimizers.LBFGS(max_iter=200), smooth=1.0) + + # 执行注册 + reg.run() + + # 获取前向映射(参考网格到源的位移场) + phi = reg.get_forward_mapping() # (H, W, 2) 位移场 + + # 应用到所有波段 + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=ref_roi_data.shape[0], + width=ref_roi_data.shape[1], + count=src.count, + transform=ref_dataset.window_transform(win), + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + # 使用 PIRT 的 warp 函数应用位移场 + warped = pirt.warp(src_band, phi, mode='constant', cval=float(dst_nodata)) + band_data = warped.astype(out_profile["dtype"]) + out_ds.write(band_data, b) + + logger.info(f"成功配准(B样条-PIRT): {bip_path.name} -> {out_path.name}") + return True + + elif SITK_AVAILABLE and hasattr(sitk, "LandmarkBasedTransformInitializer"): + # 回退到 SimpleITK TPS + logger.info("PIRT 不可用,使用 SimpleITK TPS") + + # 1) 统一坐标系:用"全图像素"作为物理坐标(spacing=1, origin=(0,0), direction=I) + fixed_pts = [(float(x1), float(y1)) for (x1,y1) in k1_global] # 参考(输出)侧 + moving_pts = [(float(x0), float(y0)) for (x0,y0) in k0_full] # 源(输入)侧 + + # 2) 构造 TPS(参考→源) 用于 Resample(输出点 -> 输入点) + tps_ref2src = sitk.LandmarkBasedTransformInitializer( + sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform), + fixed_pts, # fixed = 参考 + moving_pts # moving = 源 + ) + + # 3) 构造 TPS(源→参考) 仅用于外接矩形估计(源顶点投到参考) + tps_src2ref = sitk.LandmarkBasedTransformInitializer( + sitk.Transform(2, sitk.sitkThinPlateSplineKernelTransform), + moving_pts, # fixed = 源 + fixed_pts # moving = 参考 + ) + + # 4) 用 tps_src2ref 变换源四角,求参考全图上的外接矩形,并与参考范围求交 + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + dst_corners = np.array([tps_src2ref.TransformPoint((float(x),float(y))) for x,y in src_corners], dtype=np.float32) + + min_x = max(0, int(np.floor(dst_corners[:,0].min())) - 10) + max_x = min(ref_dataset.width, int(np.ceil (dst_corners[:,0].max())) + 10) + min_y = max(0, int(np.floor(dst_corners[:,1].min())) - 10) + max_y = min(ref_dataset.height, int(np.ceil (dst_corners[:,1].max())) + 10) + bbox_w, bbox_h = max_x-min_x, max_y-min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"TPS最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 5) 参考(输出)影像定义:spacing=1,origin=(min_x,min_y),direction=I + ref_img = sitk.Image([bbox_w, bbox_h], sitk.sitkFloat32) + ref_img.SetSpacing((1.0, 1.0)) + ref_img.SetOrigin((float(min_x), float(min_y))) + ref_img.SetDirection((1.0,0.0,0.0,1.0)) + + # 6) 源影像:用 rasterio 读为 numpy,再转 SITK(spacing=1, origin=0, direction=I) + src_band = src.read(1).astype(np.float32) + src_img = sitk.GetImageFromArray(src_band) + + # 7) 重采样:设置参考图像 + 变换=参考→源(tps_ref2src) + res = sitk.ResampleImageFilter() + res.SetReferenceImage(ref_img) + res.SetTransform(tps_ref2src) + res.SetInterpolator(sitk.sitkLinear) + if src.nodata is not None: + res.SetDefaultPixelValue(float(src.nodata)) + + # 8) 写 ENVI/BIP:对所有波段逐一 TPS 重采样 + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + src_img_band = sitk.GetImageFromArray(src_band) + warped = res.Execute(src_img_band) + band_data = sitk.GetArrayFromImage(warped).astype(out_profile["dtype"]) + out_ds.write(band_data, b) + + logger.info(f"成功配准(TPS-SimpleITK): {bip_path.name} -> {out_path.name}") + return True + + else: + logger.warning("PIRT 和 SimpleITK TPS 都不可用,回退到仿射") + # 继续到仿射回退 + + except Exception as e: + logger.warning(f"B样条变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V4.py b/test V4.py new file mode 100644 index 0000000..b5ac605 --- /dev/null +++ b/test V4.py @@ -0,0 +1,467 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +直接进行配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 使用密集匹配模型的稠密流直接进行配准 + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # ==== 稠密流直接重采样(无需后续显式变换估计) ==== + + # 1) 取稠密流(优先 ref->src)。不同模型的键名可能不同,这里做兼容 + flow_small = None + for k in ["flow_ref2src", "flow21", "flow_1_0", "flow10", "flow"]: + if k in result: + flow_small = result[k] + break + + if flow_small is None: + # 回退:优先 DIS 光流(更快/稳),若不可用再用 Farneback (ref -> src) + ref_small_rgb = np.transpose(ref_small, (1, 2, 0)) # (H,W,3) + src_small_rgb = np.transpose(src_small, (1, 2, 0)) + ref_small_gray = cv2.cvtColor((ref_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + src_small_gray = cv2.cvtColor((src_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY) + + flow_small = None + try: + dis = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM) + flow_small = dis.calc(ref_small_gray, src_small_gray, None).astype(np.float32) + except Exception: + pass + + if flow_small is None: + # 典型参数:可按影像特性微调(窗口、迭代次数等) + flow_small = cv2.calcOpticalFlowFarneback( + ref_small_gray, src_small_gray, + None, 0.5, 3, 25, 3, 5, 1.2, 0 + ).astype(np.float32) + + # flow_small 期望形状 (h_s, w_s, 2),分量为 (dx, dy): 参考像素到源像素的位移 + flow_small = np.asarray(flow_small, dtype=np.float32) + if flow_small.ndim != 3 or flow_small.shape[2] != 2: + logger.warning(f"稠密流形状异常: {flow_small.shape}") + return False + + # 2) 将小图的流放大到 ROI 全分辨率,并按比例放大位移 + roi_h, roi_w = ref_img.shape[1], ref_img.shape[2] # 注意 ref_img 是 ROI 子图 + scale_x = roi_w / flow_small.shape[1] + scale_y = roi_h / flow_small.shape[0] + flow_full = cv2.resize(flow_small, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR) + flow_full[..., 0] *= scale_x # dx + flow_full[..., 1] *= scale_y # dy + + # 3) 生成 remap 所需的源坐标图(map_x, map_y),在"参考ROI坐标系"内工作 + yy, xx = np.meshgrid(np.arange(roi_h, dtype=np.float32), + np.arange(roi_w, dtype=np.float32), indexing="ij") + map_x = xx + flow_full[..., 0] # 到源图的 x(列) + map_y = yy + flow_full[..., 1] # 到源图的 y(行) + + # 4) 根据有效映射范围求最小外接矩形(仅统计落在源图范围内的像素) + valid = (map_x >= 0) & (map_x <= (src.width - 1)) & (map_y >= 0) & (map_y <= (src.height - 1)) + if not np.any(valid): + logger.warning(f"稠密流无有效映射: {bip_path.name}") + return False + + ys, xs = np.where(valid) + pad = 0 + min_y = max(int(ys.min()) - pad, 0) + max_y = min(int(ys.max()) + 1 + pad, roi_h) + min_x = max(int(xs.min()) - pad, 0) + max_x = min(int(xs.max()) + 1 + pad, roi_w) + + crop_h = max_y - min_y + crop_w = max_x - min_x + if crop_h <= 0 or crop_w <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 只对外接矩形区域做重采样,减少内存 + map_x_crop = map_x[min_y:max_y, min_x:max_x].astype(np.float32) + map_y_crop = map_y[min_y:max_y, min_x:max_x].astype(np.float32) + + # 5) 计算输出的地理变换:参考ROI窗口 + 外接矩形子窗口 + # 先得到 ROI 的 transform,再叠加子窗口偏移 + roi_transform = ref_dataset.window_transform(win) + crop_window_global = rasterio.windows.Window( + win.col_off + min_x, win.row_off + min_y, crop_w, crop_h + ) + out_transform = ref_dataset.window_transform(crop_window_global) + + # 6) 写出 ENVI/BIP(按最小外接矩形) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=crop_h, + width=crop_w, + count=src.count, + transform=out_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + # 反向映射采样:输出像素在参考ROI坐标,去源图(map_y,map_x)取值 + warped = cv2.remap( + src_band, map_x_crop, map_y_crop, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=float(dst_nodata) + ).astype(np.float32) + + # 转回目标 dtype,保持 nodata + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (warped == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + warped = np.clip(warped, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + warped[mask] = dst_nodata + else: + warped = warped.astype(out_profile["dtype"]) + + out_ds.write(warped, b) + + logger.info(f"成功配准(DenseFlow): {bip_path.name} -> {out_path.name}") + return True + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V5.1.py b/test V5.1.py new file mode 100644 index 0000000..480ee18 --- /dev/null +++ b/test V5.1.py @@ -0,0 +1,1085 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V5.py b/test V5.py new file mode 100644 index 0000000..e0448ed --- /dev/null +++ b/test V5.py @@ -0,0 +1,1058 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 10 # 最少内点数 +MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V6.py b/test V6.py new file mode 100644 index 0000000..fc94794 --- /dev/null +++ b/test V6.py @@ -0,0 +1,1509 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +import threading +import queue +from dataclasses import dataclass +import tkinter as tk +from tkinter import ttk, filedialog, messagebox +import sys +import os + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + +@dataclass +class Config: + """配置参数类""" + ref_tif: str + bip_dir: str + out_dir: str + matcher_name: str + device: str + transform_methods: list + match_max_side: int + roi_pad_px: int + min_inliers: int + min_inlier_ratio: float + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 默认配置,请根据实际情况修改这些路径 +DEFAULT_REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +DEFAULT_BIP_DIR = r"E:\is2\yaopu" # .bip 文件所在文件夹 +DEFAULT_OUT_DIR = r"E:\is2\yaopu\output" # 输出文件夹 + +# 默认匹配算法选择 +DEFAULT_MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEFAULT_DEVICE = "cuda" # 或 "cpu" + +# 默认变换方法选择(按优先级尝试) +DEFAULT_TRANSFORM_METHODS = ["homography", "affine", "piecewise_affine"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial" + +# 默认匹配参数 +DEFAULT_MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +DEFAULT_ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 默认质量控制阈值 +DEFAULT_MIN_INLIERS = 10 # 最少内点数 +DEFAULT_MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 向后兼容的全局变量(用于命令行模式) +REF_TIF = DEFAULT_REF_TIF +BIP_DIR = Path(DEFAULT_BIP_DIR) +OUT_DIR = Path(DEFAULT_OUT_DIR) +MATCHER_NAME = DEFAULT_MATCHER_NAME +DEVICE = DEFAULT_DEVICE +TRANSFORM_METHODS = DEFAULT_TRANSFORM_METHODS +MATCH_MAX_SIDE = DEFAULT_MATCH_MAX_SIDE +ROI_PAD_PX = DEFAULT_ROI_PAD_PX +MIN_INLIERS = DEFAULT_MIN_INLIERS +MIN_INLIER_RATIO = DEFAULT_MIN_INLIER_RATIO + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def run_batch(config: Config, on_progress=None, on_log=None, stop_event=None): + """批量配准处理函数 + + Args: + config: 配置参数 + on_progress: 进度回调函数 (current_idx, total, filename) + on_log: 日志回调函数 (message) + stop_event: 停止事件,用于取消处理 + """ + def log(message): + if on_log: + on_log(message) + logger.info(message) + + log("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(config.ref_tif).exists(): + log(f"参考文件不存在: {config.ref_tif}") + return + + if not Path(config.bip_dir).exists(): + log(f"BIP文件夹不存在: {config.bip_dir}") + return + + # 创建输出目录 + out_dir = Path(config.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # 创建统计输出目录和文件 + stats_dir = out_dir / "stats" + stats_dir.mkdir(parents=True, exist_ok=True) + stats_csv = stats_dir / "registration_stats.csv" + + # 初始化统计CSV文件 + init_stats_csv(stats_csv) + log(f"统计信息将保存到: {stats_csv}") + + # 初始化匹配器 + log(f"初始化匹配器: {config.matcher_name} on {config.device}") + matcher = get_matcher(config.matcher_name, device=config.device) + + # 打开参考文件 + with rasterio.open(config.ref_tif) as ref: + log(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_dir = Path(config.bip_dir) + bip_files = list(bip_dir.glob("*.bip")) + log(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for i, bip_path in enumerate(bip_files): + if stop_event and stop_event.is_set(): + log("处理被用户取消") + break + + if on_progress: + on_progress(i, len(bip_files), bip_path.name) + + if process_bip_to_tif(bip_path, ref, matcher, out_dir, stats_csv): + success_count += 1 + + if on_progress: + on_progress(len(bip_files), len(bip_files), "完成") + + log(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +def main(): + """命令行入口""" + # 使用默认配置运行 + config = Config( + ref_tif=REF_TIF, + bip_dir=BIP_DIR, + out_dir=OUT_DIR, + matcher_name=MATCHER_NAME, + device=DEVICE, + transform_methods=TRANSFORM_METHODS, + match_max_side=MATCH_MAX_SIDE, + roi_pad_px=ROI_PAD_PX, + min_inliers=MIN_INLIERS, + min_inlier_ratio=MIN_INLIER_RATIO + ) + run_batch(config) + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + +# ---------- GUI 相关 ---------- +class QueueHandler(logging.Handler): + """自定义日志处理器,将日志发送到队列""" + def __init__(self, log_queue): + super().__init__() + self.log_queue = log_queue + + def emit(self, record): + self.log_queue.put(self.format(record)) + +class RegistrationGUI: + def __init__(self, root): + self.root = root + self.root.title("遥感影像批量配准工具") + self.root.geometry("1000x800") + + # 日志队列和停止事件 + self.log_queue = queue.Queue() + self.stop_event = threading.Event() + self.processing_thread = None + + # 设置日志处理器 + queue_handler = QueueHandler(self.log_queue) + queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + logger.addHandler(queue_handler) + logger.setLevel(logging.INFO) + + # 创建GUI组件 + self.create_widgets() + + # 定期检查日志队列 + self.check_log_queue() + + def create_widgets(self): + """创建GUI组件""" + # 主框架 + main_frame = ttk.Frame(self.root, padding="10") + main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # 配置展开面板 + config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5") + config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10)) + + # 输入文件选择 + ttk.Label(config_frame, text="参考TIF文件:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.ref_tif_var = tk.StringVar(value=DEFAULT_REF_TIF) + ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif).grid(row=0, column=2) + + ttk.Label(config_frame, text="BIP文件夹:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5)) + self.bip_dir_var = tk.StringVar(value=DEFAULT_BIP_DIR) + ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir).grid(row=1, column=2) + + ttk.Label(config_frame, text="输出文件夹:").grid(row=2, column=0, sticky=tk.W, padx=(0, 5)) + self.out_dir_var = tk.StringVar(value=DEFAULT_OUT_DIR) + ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50).grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir).grid(row=2, column=2) + + # 匹配器选择 + ttk.Label(config_frame, text="匹配算法:").grid(row=3, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + self.matcher_var = tk.StringVar(value=DEFAULT_MATCHER_NAME) + matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47) + matcher_combo['values'] = [ + "liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer", + "matchanything-eloftr", "matchanything-roma", "matchformer", + "sift-lightglue", "superpoint-lightglue", "disk-lightglue", + "aliked-lightglue", "doghardnet-lightglue", "roma", "romav2", + "tiny-roma", "dedode", "steerers", "affine-steerers", + "dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue", + "r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat", + "xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm", + "gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx", + "dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx", + "sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma", + "minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr", + "minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd", + "lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint", + "xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm", + "xfeat-star-steerers-learned" + ] + matcher_combo.grid(row=3, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 设备选择 + ttk.Label(config_frame, text="设备:").grid(row=4, column=0, sticky=tk.W, padx=(0, 5)) + self.device_var = tk.StringVar(value=DEFAULT_DEVICE) + device_frame = ttk.Frame(config_frame) + device_frame.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E)) + ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda").pack(side=tk.LEFT) + ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu").pack(side=tk.LEFT) + + # 变换方法选择 + ttk.Label(config_frame, text="变换方法 (按优先级):").grid(row=5, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + + transform_frame = ttk.Frame(config_frame) + transform_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 变换方法列表 + self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False) + transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial"] + for method in transform_methods: + self.transform_listbox.insert(tk.END, method) + if method in DEFAULT_TRANSFORM_METHODS: + self.transform_listbox.selection_set(transform_methods.index(method)) + + scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview) + self.transform_listbox.configure(yscrollcommand=scrollbar.set) + + self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 移动按钮 + button_frame = ttk.Frame(transform_frame) + button_frame.pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2)) + ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X) + + # 参数设置 + param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5") + param_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0)) + + ttk.Label(param_frame, text="匹配最大边长:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.match_max_side_var = tk.IntVar(value=DEFAULT_MATCH_MAX_SIDE) + ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10).grid(row=0, column=1, sticky=tk.W) + + ttk.Label(param_frame, text="ROI填充像素:").grid(row=0, column=2, sticky=tk.W, padx=(10, 5)) + self.roi_pad_px_var = tk.IntVar(value=DEFAULT_ROI_PAD_PX) + ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10).grid(row=0, column=3, sticky=tk.W) + + ttk.Label(param_frame, text="最少内点数:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.min_inliers_var = tk.IntVar(value=DEFAULT_MIN_INLIERS) + ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10).grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + + ttk.Label(param_frame, text="最少内点比例:").grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_inlier_ratio_var = tk.DoubleVar(value=DEFAULT_MIN_INLIER_RATIO) + ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10).grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + + # 控制按钮 + control_frame = ttk.Frame(main_frame) + control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing) + self.start_btn.pack(side=tk.LEFT, padx=(0, 10)) + + self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT) + + # 进度条 + progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5") + progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 5)) + + self.progress_label = ttk.Label(progress_frame, text="准备就绪") + self.progress_label.pack(anchor=tk.W) + + # 日志窗口 + log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5") + log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0)) + + # 日志文本框和滚动条 + log_text_frame = ttk.Frame(log_frame) + log_text_frame.pack(fill=tk.BOTH, expand=True) + + self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD) + scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview) + self.log_text.configure(yscrollcommand=scrollbar.set) + + self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 日志控制按钮 + log_btn_frame = ttk.Frame(log_frame) + log_btn_frame.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5)) + ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT) + + # 配置网格权重 + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + main_frame.columnconfigure(1, weight=1) + main_frame.rowconfigure(3, weight=1) + + def select_ref_tif(self): + """选择参考TIF文件""" + filename = filedialog.askopenfilename( + title="选择参考TIF文件", + filetypes=[("TIF files", "*.tif"), ("All files", "*.*")] + ) + if filename: + self.ref_tif_var.set(filename) + + def select_bip_dir(self): + """选择BIP文件夹""" + dirname = filedialog.askdirectory(title="选择BIP文件夹") + if dirname: + self.bip_dir_var.set(dirname) + + def select_out_dir(self): + """选择输出文件夹""" + dirname = filedialog.askdirectory(title="选择输出文件夹") + if dirname: + self.out_dir_var.set(dirname) + + def move_up(self): + """上移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] > 0: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx - 1, text) + self.transform_listbox.selection_set(idx - 1) + + def move_down(self): + """下移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] < self.transform_listbox.size() - 1: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx + 1, text) + self.transform_listbox.selection_set(idx + 1) + + def start_processing(self): + """开始处理""" + if self.processing_thread and self.processing_thread.is_alive(): + messagebox.showwarning("警告", "处理正在进行中") + return + + # 获取选中的变换方法 + selected_indices = self.transform_listbox.curselection() + if not selected_indices: + messagebox.showwarning("警告", "请至少选择一种变换方法") + return + + transform_methods = [] + for idx in selected_indices: + transform_methods.append(self.transform_listbox.get(idx)) + + # 创建配置 + config = Config( + ref_tif=self.ref_tif_var.get(), + bip_dir=self.bip_dir_var.get(), + out_dir=self.out_dir_var.get(), + matcher_name=self.matcher_var.get(), + device=self.device_var.get(), + transform_methods=transform_methods, + match_max_side=self.match_max_side_var.get(), + roi_pad_px=self.roi_pad_px_var.get(), + min_inliers=self.min_inliers_var.get(), + min_inlier_ratio=self.min_inlier_ratio_var.get() + ) + + # 重置停止事件 + self.stop_event.clear() + + # 禁用开始按钮,启用停止按钮 + self.start_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.progress_var.set(0) + self.progress_label.config(text="正在初始化...") + + # 在后台线程中运行处理 + self.processing_thread = threading.Thread( + target=self.run_processing, + args=(config,), + daemon=True + ) + self.processing_thread.start() + + def stop_processing(self): + """停止处理""" + if self.processing_thread and self.processing_thread.is_alive(): + self.stop_event.set() + self.progress_label.config(text="正在停止...") + + def run_processing(self, config): + """在后台线程中运行处理""" + try: + run_batch(config, self.on_progress, self.on_log, self.stop_event) + except Exception as e: + self.log_queue.put(f"处理过程中发生错误: {e}") + finally: + # 恢复按钮状态 + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.progress_label.config(text="处理完成")) + + def on_progress(self, current, total, filename): + """进度回调""" + if total > 0: + progress = (current / total) * 100 + self.root.after(0, lambda: self.progress_var.set(progress)) + self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})")) + + def on_log(self, message): + """日志回调""" + self.log_queue.put(message) + + def check_log_queue(self): + """检查日志队列并更新GUI""" + try: + while True: + message = self.log_queue.get_nowait() + self.log_text.insert(tk.END, message + '\n') + self.log_text.see(tk.END) + except queue.Empty: + pass + + # 每100ms检查一次 + self.root.after(100, self.check_log_queue) + + def clear_log(self): + """清空日志""" + self.log_text.delete(1.0, tk.END) + + def save_log(self): + """保存日志""" + filename = filedialog.asksaveasfilename( + title="保存日志", + defaultextension=".txt", + filetypes=[("Text files", "*.txt"), ("All files", "*.*")] + ) + if filename: + with open(filename, 'w', encoding='utf-8') as f: + f.write(self.log_text.get(1.0, tk.END)) + +def create_gui(): + """创建GUI""" + root = tk.Tk() + app = RegistrationGUI(root) + root.mainloop() + +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "--cli": + # 命令行模式 + main() + else: + # 默认GUI模式 + create_gui() \ No newline at end of file diff --git a/test V7.py b/test V7.py new file mode 100644 index 0000000..e57772c --- /dev/null +++ b/test V7.py @@ -0,0 +1,1534 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用 实现非刚性配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging +import threading +import queue +from dataclasses import dataclass +import tkinter as tk +from tkinter import ttk, filedialog, messagebox +import sys +import os + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + +@dataclass +class Config: + """配置参数类""" + ref_tif: str + bip_dir: str + out_dir: str + matcher_name: str + device: str + transform_methods: list + match_max_side: int + roi_pad_px: int + min_inliers: int + min_inlier_ratio: float + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 默认配置,请根据实际情况修改这些路径 +DEFAULT_REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径 +DEFAULT_BIP_DIR = r"E:\is2\yaopu" # .bip 文件所在文件夹 +DEFAULT_OUT_DIR = r"E:\is2\yaopu\output" # 输出文件夹 + +# 默认匹配算法选择 +DEFAULT_MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEFAULT_DEVICE = "cuda" # 或 "cpu" + +# 默认变换方法选择(按优先级尝试) +DEFAULT_TRANSFORM_METHODS = ["homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial" + +# 默认匹配参数 +DEFAULT_MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +DEFAULT_ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) + +# 默认质量控制阈值 +DEFAULT_MIN_INLIERS = 10 # 最少内点数 +DEFAULT_MIN_INLIER_RATIO = 0.01 # 最少内点比例 + +# 向后兼容的全局变量(用于命令行模式) +REF_TIF = DEFAULT_REF_TIF +BIP_DIR = Path(DEFAULT_BIP_DIR) +OUT_DIR = Path(DEFAULT_OUT_DIR) +MATCHER_NAME = DEFAULT_MATCHER_NAME +DEVICE = DEFAULT_DEVICE +TRANSFORM_METHODS = DEFAULT_TRANSFORM_METHODS +MATCH_MAX_SIDE = DEFAULT_MATCH_MAX_SIDE +ROI_PAD_PX = DEFAULT_ROI_PAD_PX +MIN_INLIERS = DEFAULT_MIN_INLIERS +MIN_INLIER_RATIO = DEFAULT_MIN_INLIER_RATIO + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def process_bip_to_tif( + bip_path: Path, + ref_dataset, + matcher, + out_dir: Path, + stats_csv: Path, + transform_methods: list, + match_max_side: int, + roi_pad_px: int, + min_inliers: int, + min_inlier_ratio: float, +): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, roi_pad_px, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, match_max_side) + ref_small = _downscale_chw(ref_img, match_max_side) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < min_inliers or ratio < min_inlier_ratio: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in transform_methods: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def run_batch(config: Config, on_progress=None, on_log=None, stop_event=None): + """批量配准处理函数 + + Args: + config: 配置参数 + on_progress: 进度回调函数 (current_idx, total, filename) + on_log: 日志回调函数 (message) + stop_event: 停止事件,用于取消处理 + """ + def log(message): + if on_log: + on_log(message) + logger.info(message) + + log("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(config.ref_tif).exists(): + log(f"参考文件不存在: {config.ref_tif}") + return + + if not Path(config.bip_dir).exists(): + log(f"BIP文件夹不存在: {config.bip_dir}") + return + + # 创建输出目录 + out_dir = Path(config.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + # 创建统计输出目录和文件 + stats_dir = out_dir / "stats" + stats_dir.mkdir(parents=True, exist_ok=True) + stats_csv = stats_dir / "registration_stats.csv" + + # 初始化统计CSV文件 + init_stats_csv(stats_csv) + log(f"统计信息将保存到: {stats_csv}") + + # 初始化匹配器 + log(f"初始化匹配器: {config.matcher_name} on {config.device}") + matcher = get_matcher(config.matcher_name, device=config.device) + + # 打开参考文件 + with rasterio.open(config.ref_tif) as ref: + log(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_dir = Path(config.bip_dir) + bip_files = list(bip_dir.glob("*.bip")) + log(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for i, bip_path in enumerate(bip_files): + if stop_event and stop_event.is_set(): + log("处理被用户取消") + break + + if on_progress: + on_progress(i, len(bip_files), bip_path.name) + + if process_bip_to_tif( + bip_path, ref, matcher, out_dir, stats_csv, + config.transform_methods, + config.match_max_side, + config.roi_pad_px, + config.min_inliers, + config.min_inlier_ratio, + ): + success_count += 1 + + if on_progress: + on_progress(len(bip_files), len(bip_files), "完成") + + log(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +def main(): + """命令行入口""" + # 使用默认配置运行 + config = Config( + ref_tif=REF_TIF, + bip_dir=BIP_DIR, + out_dir=OUT_DIR, + matcher_name=MATCHER_NAME, + device=DEVICE, + transform_methods=TRANSFORM_METHODS, + match_max_side=MATCH_MAX_SIDE, + roi_pad_px=ROI_PAD_PX, + min_inliers=MIN_INLIERS, + min_inlier_ratio=MIN_INLIER_RATIO + ) + run_batch(config) + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +# ---------- GUI 相关 ---------- +class QueueHandler(logging.Handler): + """自定义日志处理器,将日志发送到队列""" + def __init__(self, log_queue): + super().__init__() + self.log_queue = log_queue + + def emit(self, record): + self.log_queue.put(self.format(record)) + +class RegistrationGUI: + def __init__(self, root): + self.root = root + self.root.title("遥感影像批量配准工具") + self.root.geometry("1000x800") + + # 日志队列和停止事件 + self.log_queue = queue.Queue() + self.stop_event = threading.Event() + self.processing_thread = None + + # 设置日志处理器 + queue_handler = QueueHandler(self.log_queue) + queue_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + logger.addHandler(queue_handler) + logger.setLevel(logging.INFO) + + # 创建GUI组件 + self.create_widgets() + + # 定期检查日志队列 + self.check_log_queue() + + def create_widgets(self): + """创建GUI组件""" + # 主框架 + main_frame = ttk.Frame(self.root, padding="10") + main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # 配置展开面板 + config_frame = ttk.LabelFrame(main_frame, text="配置参数", padding="5") + config_frame.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(0, 10)) + + # 输入文件选择 + ttk.Label(config_frame, text="参考TIF文件:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.ref_tif_var = tk.StringVar(value=DEFAULT_REF_TIF) + ttk.Entry(config_frame, textvariable=self.ref_tif_var, width=50).grid(row=0, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件", command=self.select_ref_tif).grid(row=0, column=2) + + ttk.Label(config_frame, text="BIP文件夹:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5)) + self.bip_dir_var = tk.StringVar(value=DEFAULT_BIP_DIR) + ttk.Entry(config_frame, textvariable=self.bip_dir_var, width=50).grid(row=1, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_bip_dir).grid(row=1, column=2) + + ttk.Label(config_frame, text="输出文件夹:").grid(row=2, column=0, sticky=tk.W, padx=(0, 5)) + self.out_dir_var = tk.StringVar(value=DEFAULT_OUT_DIR) + ttk.Entry(config_frame, textvariable=self.out_dir_var, width=50).grid(row=2, column=1, sticky=(tk.W, tk.E), padx=(0, 5)) + ttk.Button(config_frame, text="选择文件夹", command=self.select_out_dir).grid(row=2, column=2) + + # 匹配器选择 + ttk.Label(config_frame, text="匹配算法:").grid(row=3, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + self.matcher_var = tk.StringVar(value=DEFAULT_MATCHER_NAME) + matcher_combo = ttk.Combobox(config_frame, textvariable=self.matcher_var, width=47) + matcher_combo['values'] = [ + "liftfeat", "loftr", "eloftr", "se2loftr", "xoftr", "aspanformer", + "matchanything-eloftr", "matchanything-roma", "matchformer", + "sift-lightglue", "superpoint-lightglue", "disk-lightglue", + "aliked-lightglue", "doghardnet-lightglue", "roma", "romav2", + "tiny-roma", "dedode", "steerers", "affine-steerers", + "dedode-kornia", "sift-nn", "orb-nn", "patch2pix", "superglue", + "r2d2", "d2net", "duster", "master", "doghardnet-nn", "xfeat", + "xfeat-star", "xfeat-lightglue", "dedode-lightglue", "gim-dkm", + "gim-lightglue", "omniglue", "xfeat-subpx", "xfeat-lightglue-subpx", + "dedode-subpx", "superpoint-lightglue-subpx", "aliked-lightglue-subpx", + "sift-sphereglue", "superpoint-sphereglue", "minima", "minima-roma", + "minima-roma-tiny", "minima-superpoint-lightglue", "minima-loftr", + "minima-xoftr", "edm", "lisrd-aliked", "lisrd-superpoint", "lisrd", + "lisrd-sift", "ripe", "topicfm", "topicfm-plus", "silk", "zippypoint", + "xfeat-steerers-perm", "xfeat-steerers-learned", "xfeat-star-steerers-perm", + "xfeat-star-steerers-learned" + ] + matcher_combo.grid(row=3, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 设备选择 + ttk.Label(config_frame, text="设备:").grid(row=4, column=0, sticky=tk.W, padx=(0, 5)) + self.device_var = tk.StringVar(value=DEFAULT_DEVICE) + device_frame = ttk.Frame(config_frame) + device_frame.grid(row=4, column=1, columnspan=2, sticky=(tk.W, tk.E)) + ttk.Radiobutton(device_frame, text="CUDA", variable=self.device_var, value="cuda").pack(side=tk.LEFT) + ttk.Radiobutton(device_frame, text="CPU", variable=self.device_var, value="cpu").pack(side=tk.LEFT) + + # 变换方法选择 + ttk.Label(config_frame, text="变换方法 (按优先级):").grid(row=5, column=0, sticky=tk.W, padx=(0, 5), pady=(10, 0)) + + transform_frame = ttk.Frame(config_frame) + transform_frame.grid(row=5, column=1, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + # 变换方法列表 + self.transform_listbox = tk.Listbox(transform_frame, selectmode=tk.MULTIPLE, height=5, exportselection=False) + transform_methods = ["similarity", "affine", "homography", "piecewise_affine", "polynomial"] + for method in transform_methods: + self.transform_listbox.insert(tk.END, method) + if method in DEFAULT_TRANSFORM_METHODS: + self.transform_listbox.selection_set(transform_methods.index(method)) + + scrollbar = ttk.Scrollbar(transform_frame, orient=tk.VERTICAL, command=self.transform_listbox.yview) + self.transform_listbox.configure(yscrollcommand=scrollbar.set) + + self.transform_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 移动按钮 + button_frame = ttk.Frame(transform_frame) + button_frame.pack(side=tk.RIGHT, padx=(5, 0)) + ttk.Button(button_frame, text="↑ 上移", command=self.move_up).pack(fill=tk.X, pady=(0, 2)) + ttk.Button(button_frame, text="↓ 下移", command=self.move_down).pack(fill=tk.X) + + # 参数设置 + param_frame = ttk.LabelFrame(config_frame, text="参数设置", padding="5") + param_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=(10, 0)) + + ttk.Label(param_frame, text="匹配最大边长:").grid(row=0, column=0, sticky=tk.W, padx=(0, 5)) + self.match_max_side_var = tk.IntVar(value=DEFAULT_MATCH_MAX_SIDE) + ttk.Entry(param_frame, textvariable=self.match_max_side_var, width=10).grid(row=0, column=1, sticky=tk.W) + + ttk.Label(param_frame, text="ROI填充像素:").grid(row=0, column=2, sticky=tk.W, padx=(10, 5)) + self.roi_pad_px_var = tk.IntVar(value=DEFAULT_ROI_PAD_PX) + ttk.Entry(param_frame, textvariable=self.roi_pad_px_var, width=10).grid(row=0, column=3, sticky=tk.W) + + ttk.Label(param_frame, text="最少内点数:").grid(row=1, column=0, sticky=tk.W, padx=(0, 5), pady=(5, 0)) + self.min_inliers_var = tk.IntVar(value=DEFAULT_MIN_INLIERS) + ttk.Entry(param_frame, textvariable=self.min_inliers_var, width=10).grid(row=1, column=1, sticky=tk.W, pady=(5, 0)) + + ttk.Label(param_frame, text="最少内点比例:").grid(row=1, column=2, sticky=tk.W, padx=(10, 5), pady=(5, 0)) + self.min_inlier_ratio_var = tk.DoubleVar(value=DEFAULT_MIN_INLIER_RATIO) + ttk.Entry(param_frame, textvariable=self.min_inlier_ratio_var, width=10).grid(row=1, column=3, sticky=tk.W, pady=(5, 0)) + + # 控制按钮 + control_frame = ttk.Frame(main_frame) + control_frame.grid(row=1, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.start_btn = ttk.Button(control_frame, text="开始处理", command=self.start_processing) + self.start_btn.pack(side=tk.LEFT, padx=(0, 10)) + + self.stop_btn = ttk.Button(control_frame, text="停止处理", command=self.stop_processing, state=tk.DISABLED) + self.stop_btn.pack(side=tk.LEFT) + + # 进度条 + progress_frame = ttk.LabelFrame(main_frame, text="处理进度", padding="5") + progress_frame.grid(row=2, column=0, columnspan=2, sticky=(tk.W, tk.E), pady=(10, 0)) + + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(progress_frame, variable=self.progress_var, maximum=100) + self.progress_bar.pack(fill=tk.X, pady=(0, 5)) + + self.progress_label = ttk.Label(progress_frame, text="准备就绪") + self.progress_label.pack(anchor=tk.W) + + # 日志窗口 + log_frame = ttk.LabelFrame(main_frame, text="处理日志", padding="5") + log_frame.grid(row=3, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S), pady=(10, 0)) + + # 日志文本框和滚动条 + log_text_frame = ttk.Frame(log_frame) + log_text_frame.pack(fill=tk.BOTH, expand=True) + + self.log_text = tk.Text(log_text_frame, height=15, wrap=tk.WORD) + scrollbar = ttk.Scrollbar(log_text_frame, orient=tk.VERTICAL, command=self.log_text.yview) + self.log_text.configure(yscrollcommand=scrollbar.set) + + self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 日志控制按钮 + log_btn_frame = ttk.Frame(log_frame) + log_btn_frame.pack(fill=tk.X, pady=(5, 0)) + + ttk.Button(log_btn_frame, text="清空日志", command=self.clear_log).pack(side=tk.LEFT, padx=(0, 5)) + ttk.Button(log_btn_frame, text="保存日志", command=self.save_log).pack(side=tk.LEFT) + + # 配置网格权重 + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + main_frame.columnconfigure(1, weight=1) + main_frame.rowconfigure(3, weight=1) + + def select_ref_tif(self): + """选择参考TIF文件""" + filename = filedialog.askopenfilename( + title="选择参考TIF文件", + filetypes=[("TIF files", "*.tif"), ("All files", "*.*")] + ) + if filename: + self.ref_tif_var.set(filename) + + def select_bip_dir(self): + """选择BIP文件夹""" + dirname = filedialog.askdirectory(title="选择BIP文件夹") + if dirname: + self.bip_dir_var.set(dirname) + + def select_out_dir(self): + """选择输出文件夹""" + dirname = filedialog.askdirectory(title="选择输出文件夹") + if dirname: + self.out_dir_var.set(dirname) + + def move_up(self): + """上移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] > 0: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx - 1, text) + self.transform_listbox.selection_set(idx - 1) + + def move_down(self): + """下移选中的变换方法""" + selection = self.transform_listbox.curselection() + if selection and selection[0] < self.transform_listbox.size() - 1: + idx = selection[0] + text = self.transform_listbox.get(idx) + self.transform_listbox.delete(idx) + self.transform_listbox.insert(idx + 1, text) + self.transform_listbox.selection_set(idx + 1) + + def start_processing(self): + """开始处理""" + if self.processing_thread and self.processing_thread.is_alive(): + messagebox.showwarning("警告", "处理正在进行中") + return + + # 获取选中的变换方法 + selected_indices = self.transform_listbox.curselection() + if not selected_indices: + messagebox.showwarning("警告", "请至少选择一种变换方法") + return + + transform_methods = [] + for idx in selected_indices: + transform_methods.append(self.transform_listbox.get(idx)) + + # 创建配置 + config = Config( + ref_tif=self.ref_tif_var.get(), + bip_dir=self.bip_dir_var.get(), + out_dir=self.out_dir_var.get(), + matcher_name=self.matcher_var.get(), + device=self.device_var.get(), + transform_methods=transform_methods, + match_max_side=self.match_max_side_var.get(), + roi_pad_px=self.roi_pad_px_var.get(), + min_inliers=self.min_inliers_var.get(), + min_inlier_ratio=self.min_inlier_ratio_var.get() + ) + + # 重置停止事件 + self.stop_event.clear() + + # 禁用开始按钮,启用停止按钮 + self.start_btn.config(state=tk.DISABLED) + self.stop_btn.config(state=tk.NORMAL) + self.progress_var.set(0) + self.progress_label.config(text="正在初始化...") + + # 在后台线程中运行处理 + self.processing_thread = threading.Thread( + target=self.run_processing, + args=(config,), + daemon=True + ) + self.processing_thread.start() + + def stop_processing(self): + """停止处理""" + if self.processing_thread and self.processing_thread.is_alive(): + self.stop_event.set() + self.progress_label.config(text="正在停止...") + + def run_processing(self, config): + """在后台线程中运行处理""" + try: + run_batch(config, self.on_progress, self.on_log, self.stop_event) + except Exception as e: + self.log_queue.put(f"处理过程中发生错误: {e}") + finally: + # 恢复按钮状态 + self.root.after(0, lambda: self.start_btn.config(state=tk.NORMAL)) + self.root.after(0, lambda: self.stop_btn.config(state=tk.DISABLED)) + self.root.after(0, lambda: self.progress_label.config(text="处理完成")) + + def on_progress(self, current, total, filename): + """进度回调""" + if total > 0: + progress = (current / total) * 100 + self.root.after(0, lambda: self.progress_var.set(progress)) + self.root.after(0, lambda: self.progress_label.config(text=f"处理中: {filename} ({current}/{total})")) + + def on_log(self, message): + """日志回调""" + self.log_queue.put(message) + + def check_log_queue(self): + """检查日志队列并更新GUI""" + try: + while True: + message = self.log_queue.get_nowait() + self.log_text.insert(tk.END, message + '\n') + self.log_text.see(tk.END) + except queue.Empty: + pass + + # 每100ms检查一次 + self.root.after(100, self.check_log_queue) + + def clear_log(self): + """清空日志""" + self.log_text.delete(1.0, tk.END) + + def save_log(self): + """保存日志""" + filename = filedialog.asksaveasfilename( + title="保存日志", + defaultextension=".txt", + filetypes=[("Text files", "*.txt"), ("All files", "*.*")] + ) + if filename: + with open(filename, 'w', encoding='utf-8') as f: + f.write(self.log_text.get(1.0, tk.END)) + +def create_gui(): + """创建GUI""" + root = tk.Tk() + app = RegistrationGUI(root) + root.mainloop() + +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "--cli": + # 命令行模式 + main() + else: + # 默认GUI模式 + create_gui() diff --git a/test V8.py b/test V8.py new file mode 100644 index 0000000..16e74cf --- /dev/null +++ b/test V8.py @@ -0,0 +1,1207 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用有效区域掩膜参考影像 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["similarity", "affine", "homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) +MASK_PAD_PX = 100 # 匹配掩膜扩张像素(仅用于匹配阶段) + +# RANSAC 参数 +RANSAC_REPROJ_THRESHOLD = 5 # RANSAC重投影误差阈值(像素) +RANSAC_CONFIDENCE = 0.99 # RANSAC置信度 +RANSAC_MAX_ITERS = 1000 # RANSAC最大迭代次数 + +# 质量控制阈值 +MIN_INLIERS = 10 +MIN_INLIER_RATIO = 0.01 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, + ransacReprojThreshold=RANSAC_REPROJ_THRESHOLD, + confidence=RANSAC_CONFIDENCE, + maxIters=RANSAC_MAX_ITERS) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + """将(H,W)二值掩膜缩放到目标尺寸,保持最近邻""" + m = cv2.resize(mask_hw.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST) + return m > 0 + +def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict: + """将匹配与内点严格限制在掩膜内""" + if src_mask_small is None or ref_mask_small is None: + return result + + def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray: + if kpts is None or len(kpts) == 0: + return np.zeros((0,), dtype=bool) + kpts = np.asarray(kpts) + xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # 过滤 matched_kpts + if "matched_kpts0" in result and "matched_kpts1" in result: + mk0 = np.asarray(result["matched_kpts0"]) + mk1 = np.asarray(result["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small) + result["matched_kpts0"] = mk0[keep_m] + result["matched_kpts1"] = mk1[keep_m] + + # 过滤 inlier_kpts + if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None: + ik0 = np.asarray(result["inlier_kpts0"]) + ik1 = np.asarray(result["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small) + result["inlier_kpts0"] = ik0[keep_i] + result["inlier_kpts1"] = ik1[keep_i] + result["num_inliers"] = int(len(result["inlier_kpts0"])) + + return result + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠) + try: + src_mask = (src.read_masks(1) > 0) # True=有效 + rows_any = np.any(src_mask, axis=1) + cols_any = np.any(src_mask, axis=0) + if rows_any.any() and cols_any.any(): + rmin = int(rows_any.argmax()) + rmax = int(src.height - 1 - rows_any[::-1].argmax()) + cmin = int(cols_any.argmax()) + cmax = int(src.width - 1 - cols_any[::-1].argmax()) + valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1) + valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform) + b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21) + else: + # 掩膜无效时回退到整图bounds + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + except Exception: + src_mask = None # 后续可选源图掩膜时用到 + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 可选:将源图的有效掩膜也应用到源图,抑制无效区特征 + try: + src_mask3 = np.repeat((src_mask > 0)[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * src_mask3 + except Exception: + pass + + # 将重投影到参考ROI的掩膜应用到参考图像(匹配仅在掩膜内进行) + if dst_mask is not None: + ref_mask3 = np.repeat((dst_mask > 0)[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * ref_mask3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 生成与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None + + # 基于掩膜严格过滤匹配与内点 + result = _filter_matches_by_masks(result, src_mask_small, ref_mask_small) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test V9.py b/test V9.py new file mode 100644 index 0000000..14ecc8e --- /dev/null +++ b/test V9.py @@ -0,0 +1,1299 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +问题:当图像中大部分是水体时,匹配过多出现在掩膜边缘,同时过滤时将本来就少的陆地匹配点也过滤掉了 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +import csv +from datetime import datetime +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +from vismatch.viz import plot_matches, plot_keypoints +import logging + +try: + from skimage.transform import PiecewiseAffineTransform, PolynomialTransform + SKIMAGE_AVAILABLE = True +except ImportError: + SKIMAGE_AVAILABLE = False + logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换") + +try: + from matplotlib.path import Path as MplPath + from scipy.spatial import ConvexHull + MATPLOTLIB_SCIPY_AVAILABLE = True +except ImportError: + MATPLOTLIB_SCIPY_AVAILABLE = False + MplPath = None + logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断") + +try: + import SimpleITK as sitk + SITK_AVAILABLE = True +except ImportError: + SITK_AVAILABLE = False + logging.warning("SimpleITK 不可用,将使用仿射变换作为替代") + + +try: + import pirt + PIRT_AVAILABLE = True +except ImportError: + PIRT_AVAILABLE = False + logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代") + +try: + from scipy.interpolate import Rbf + SCIPY_AVAILABLE = True +except ImportError: + SCIPY_AVAILABLE = False + logging.warning("scipy 不可用,将跳过 TPS 变换") + + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\guidingsahn\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\guidingsahn") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\guidingsahn\output") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 变换方法选择(按优先级尝试) +TRANSFORM_METHODS = ["similarity", "affine", "homography"] +# 可选: "similarity", "affine", "homography", "piecewise_affine", "polynomial", "polynomial_order3", "tps" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素) +MASK_PAD_PX = 100 # 匹配掩膜扩张像素(仅用于匹配阶段) + +# 质量控制阈值 +MIN_INLIERS = 10 +MIN_INLIER_RATIO = 0.01 + +# 掩膜边缘羽化与过滤 +FEATHER_PX = 20 # 掩膜羽化宽度(像素,先在全分辨率/ROI分辨率上做) +EDGE_BAND_PX = 30 # 剔除距离掩膜边界小于此像素的匹配点(在小图上按比例缩放) + +# 纹理过滤 +MIN_GRAD_QUANTILE = 0.20 # 梯度幅值的分位阈值(0~1),低于该阈值的点视为低纹理,剔除 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# 创建统计输出目录和文件 +STATS_DIR = OUT_DIR / "stats" +STATS_DIR.mkdir(parents=True, exist_ok=True) +STATS_CSV = STATS_DIR / "registration_stats.csv" + +# ---------- 工具函数 ---------- +def init_stats_csv(csv_path: Path): + """初始化统计CSV文件""" + if not csv_path.exists(): + with open(csv_path, 'w', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + 'timestamp', 'filename', 'num_inliers', 'num_matches', 'inlier_ratio', + 'selected_method', 'median_error', 'p95_error', 'success' + ]) + +def log_registration_stats(csv_path: Path, filename: str, num_inliers: int, num_matches: int, + inlier_ratio: float, selected_method: str, median_error: float, + p95_error: float, success: bool): + """记录配准统计信息到CSV""" + timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + with open(csv_path, 'a', newline='', encoding='utf-8') as f: + writer = csv.writer(f) + writer.writerow([ + timestamp, filename, num_inliers, num_matches, f"{inlier_ratio:.4f}", + selected_method, f"{median_error:.4f}", f"{p95_error:.4f}", success + ]) +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def estimate_transform(method, k0, k1): + """统一的变换估计函数,支持多种变换类型""" + if method == "translation": + # 简单平移:用内点的平均位移 + if len(k0) == 0: + return None, None + dx = np.mean(k1[:, 0] - k0[:, 0]) + dy = np.mean(k1[:, 1] - k0[:, 1]) + A = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) + return "A", A + + elif method == "euclidean": + # 欧式变换(旋转+平移),约束等比缩放=1 + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "similarity": + # 相似变换(旋转+等比缩放+平移) + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "affine": + # 全仿射变换(旋转+非等比缩放+剪切+平移) + A, _ = cv2.estimateAffine2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + return "A", A + + elif method == "homography": + # 投影变换(8DOF,透视) + H, _ = cv2.findHomography(k0, k1, method=cv2.USAC_MAGSAC, ransacReprojThreshold=3.0) + return "H", H + + elif method == "piecewise_affine": + # 分片仿射变换 + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PiecewiseAffineTransform() + tform.estimate(k0, k1) + return "piecewise", tform + except Exception: + return None, None + + elif method == "polynomial": + # 多项式变换(2阶) + if not SKIMAGE_AVAILABLE: + return None, None + try: + tform = PolynomialTransform() + tform.estimate(k0, k1, order=2) + return "polynomial", tform + except Exception: + return None, None + + else: + raise ValueError(f"未知变换方法: {method}") + +def evaluate_transform_quality(transform_type, transform, k0, k1): + """评估变换质量(重投影误差)""" + if transform is None or len(k0) == 0: + return np.inf, np.inf + + if transform_type == "A": + # 仿射变换重投影误差 + A = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + pred = (A @ np.hstack([k0, ones]).T).T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type == "H": + # 单应变换重投影误差 + H = transform + ones = np.ones((k0.shape[0], 1), dtype=np.float32) + src_h = np.hstack([k0, ones]).T + warped = H @ src_h + warped /= (warped[2:3, :] + 1e-6) + pred = warped[:2, :].T + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + elif transform_type in ["piecewise", "polynomial"]: + # scikit-image 变换重投影误差 + pred = transform(k0) + e = np.sqrt(((pred - k1) ** 2).sum(axis=1)) + + else: + return np.inf, np.inf + + return float(np.median(e)), float(np.percentile(e, 95)) + +def _norm01_hw(x: np.ndarray) -> np.ndarray: + """对单波段(H,W)做简单百分位归一化到[0,1],增强跨传感器强度配准稳定性""" + x = x.astype(np.float32, copy=False) + p2 = float(np.percentile(x, 2)) + p98 = float(np.percentile(x, 98)) + y = (x - p2) / (p98 - p2 + 1e-6) + return np.clip(y, 0.0, 1.0) + +def _np_to_sitk_float_image(arr_hw: np.ndarray, origin_xy=(0.0, 0.0)): + """ + numpy(H,W)->SimpleITK Image。 + 物理坐标约定为“像素坐标系”:spacing=1, direction=I,origin=(x0,y0)。 + """ + img = sitk.GetImageFromArray(arr_hw.astype(np.float32, copy=False)) + img.SetSpacing((1.0, 1.0)) + img.SetOrigin((float(origin_xy[0]), float(origin_xy[1]))) + img.SetDirection((1.0, 0.0, 0.0, 1.0)) + return img + +def _compute_bbox_from_k1(k1_global: np.ndarray, ref_w: int, ref_h: int, pad: int = 10): + """用目标侧匹配点(k1_global)计算裁剪窗口(min_x,min_y,w,h),并裁到参考影像范围内""" + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_w, max_x) + max_y = min(ref_h, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + return min_x, min_y, bbox_w, bbox_h + +def _downscale_mask_hw(mask_hw: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + """将(H,W)二值掩膜缩放到目标尺寸,保持最近邻""" + m = cv2.resize(mask_hw.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST) + return m > 0 + +def _soft_alpha_from_mask(mask_hw: np.ndarray, feather_px: int) -> np.ndarray: + """ + 二值掩膜 -> 软掩膜 alpha∈[0,1],边缘处按距离线性上升,避免硬边缘。 + mask_hw: bool/uint8 (H,W) True/1表示有效 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + # 距离变换仅对前景内部有效,计算到边界的距离 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + if feather_px <= 0: + alpha = (dist > 0).astype(np.float32) + else: + alpha = np.clip(dist / float(feather_px), 0.0, 1.0).astype(np.float32) + return alpha # (H,W) float32 + +def _distance_keep_mask(mask_hw: np.ndarray, min_dist_px: int) -> np.ndarray: + """ + 生成"远离边界"的保留掩膜:仅保留距离边界>=min_dist_px的像素。 + """ + if mask_hw is None: + return None + m = (mask_hw.astype(np.uint8) > 0).astype(np.uint8) * 255 + dist = cv2.distanceTransform(m, distanceType=cv2.DIST_L2, maskSize=3) + keep = dist >= float(max(1, min_dist_px)) + return keep + +def _grad_mask_from_chw(img_chw: np.ndarray, quantile: float) -> np.ndarray: + """ + 根据梯度幅值生成纹理掩膜(H,W)True=纹理足够。 + 使用与匹配同尺寸的CHW图像。 + """ + # 转灰度 + g = img_chw.mean(axis=0).astype(np.float32) # (H,W) + gx = cv2.Sobel(g, cv2.CV_32F, 1, 0, ksize=3) + gy = cv2.Sobel(g, cv2.CV_32F, 0, 1, ksize=3) + mag = np.sqrt(gx*gx + gy*gy) + thr = float(np.quantile(mag, quantile)) if mag.size > 0 else 0.0 + return mag >= thr # (H,W) bool + +def _filter_matches_by_masks(result: dict, src_mask_small: np.ndarray, ref_mask_small: np.ndarray) -> dict: + """将匹配与内点严格限制在掩膜内""" + if src_mask_small is None or ref_mask_small is None: + return result + + def keep_in_mask(kpts: np.ndarray, mask_hw: np.ndarray) -> np.ndarray: + if kpts is None or len(kpts) == 0: + return np.zeros((0,), dtype=bool) + kpts = np.asarray(kpts) + xs = np.clip(np.rint(kpts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(kpts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # 过滤 matched_kpts + if "matched_kpts0" in result and "matched_kpts1" in result: + mk0 = np.asarray(result["matched_kpts0"]) + mk1 = np.asarray(result["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in_mask(mk0, src_mask_small) & keep_in_mask(mk1, ref_mask_small) + result["matched_kpts0"] = mk0[keep_m] + result["matched_kpts1"] = mk1[keep_m] + + # 过滤 inlier_kpts + if "inlier_kpts0" in result and "inlier_kpts1" in result and result["inlier_kpts0"] is not None: + ik0 = np.asarray(result["inlier_kpts0"]) + ik1 = np.asarray(result["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in_mask(ik0, src_mask_small) & keep_in_mask(ik1, ref_mask_small) + result["inlier_kpts0"] = ik0[keep_i] + result["inlier_kpts1"] = ik1[keep_i] + result["num_inliers"] = int(len(result["inlier_kpts0"])) + + return result + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path, stats_csv: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 初始化统计变量 + num_inliers = 0 + num_matches = 0 + inlier_ratio = 0.0 + selected_method = "none" + median_error = float('inf') + p95_error = float('inf') + success = False + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用"源图有效掩膜"的包围盒推参考ROI(比整图bounds更贴近有效重叠) + try: + src_mask = (src.read_masks(1) > 0) # True=有效 + rows_any = np.any(src_mask, axis=1) + cols_any = np.any(src_mask, axis=0) + if rows_any.any() and cols_any.any(): + rmin = int(rows_any.argmax()) + rmax = int(src.height - 1 - rows_any[::-1].argmax()) + cmin = int(cols_any.argmax()) + cmax = int(src.width - 1 - cols_any[::-1].argmax()) + valid_win_src = rasterio.windows.Window(cmin, rmin, cmax - cmin + 1, rmax - rmin + 1) + valid_bounds_src = rasterio.windows.bounds(valid_win_src, transform=src.transform) + b = transform_bounds(src_crs, ref_crs, *valid_bounds_src, densify_pts=21) + else: + # 掩膜无效时回退到整图bounds + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + except Exception: + src_mask = None # 后续可选源图掩膜时用到 + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 将源图有效掩膜重投影到参考ROI,并适度膨胀后作为匹配掩膜 + try: + if src_mask is None: + src_mask = (src.read_masks(1) > 0) + ref_roi_transform = ref_dataset.window_transform(win) + roi_h, roi_w = int(win.height), int(win.width) + dst_mask = np.zeros((roi_h, roi_w), dtype=np.uint8) + + reproject( + source=src_mask.astype(np.uint8), + destination=dst_mask, + src_transform=src.transform, + src_crs=src_crs, + dst_transform=ref_roi_transform, + dst_crs=ref_crs, + resampling=Resampling.nearest + ) + + if MASK_PAD_PX > 0: + k = max(1, MASK_PAD_PX * 2 + 1) # odd kernel size + k = min(k, 99) # 防止核过大导致性能问题,可按需调整/删除 + kernel = np.ones((k, k), np.uint8) + dst_mask = cv2.dilate(dst_mask, kernel, iterations=1) + except Exception: + # 掩膜获取/重投影失败则不使用掩膜 + dst_mask = None + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 软掩膜:避免在边界产生硬高对比边 + try: + alpha_src = _soft_alpha_from_mask(src_mask, FEATHER_PX) if src_mask is not None else None + except Exception: + alpha_src = None + try: + alpha_ref = _soft_alpha_from_mask(dst_mask, FEATHER_PX) if dst_mask is not None else None + except Exception: + alpha_ref = None + + if alpha_src is not None: + alpha_src3 = np.repeat(alpha_src[None, ...], 3, axis=0).astype(src_img.dtype) + src_img = src_img * alpha_src3 + + if alpha_ref is not None: + alpha_ref3 = np.repeat(alpha_ref[None, ...], 3, axis=0).astype(ref_img.dtype) + ref_img = ref_img * alpha_ref3 + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + # 与小图同尺寸的掩膜 + src_mask_small = _downscale_mask_hw(src_mask, src_small.shape[1], src_small.shape[2]) if 'src_mask' in locals() and src_mask is not None else None + ref_mask_small = _downscale_mask_hw(dst_mask, ref_small.shape[1], ref_small.shape[2]) if 'dst_mask' in locals() and dst_mask is not None else None + + # 剔除掩膜边缘带(小图尺度的最小距离) + def _scale_px(px_full: int, full_wh, small_wh) -> int: + # 用平均缩放;也可以分别对H/W计算后取最小 + sy = small_wh[0] / max(1, full_wh[0]) + sx = small_wh[1] / max(1, full_wh[1]) + s = 0.5 * (sx + sy) + return max(1, int(round(px_full * s))) + + edge_band_src_small = _scale_px(EDGE_BAND_PX, (src_img.shape[1], src_img.shape[2]), (src_small.shape[1], src_small.shape[2])) + edge_band_ref_small = _scale_px(EDGE_BAND_PX, (ref_img.shape[1], ref_img.shape[2]), (ref_small.shape[1], ref_small.shape[2])) + + keep_src_edge = _distance_keep_mask(src_mask_small, edge_band_src_small) if src_mask_small is not None else None + keep_ref_edge = _distance_keep_mask(ref_mask_small, edge_band_ref_small) if ref_mask_small is not None else None + + # 纹理掩膜 + keep_src_tex = _grad_mask_from_chw(src_small, MIN_GRAD_QUANTILE) + keep_ref_tex = _grad_mask_from_chw(ref_small, MIN_GRAD_QUANTILE) + + # 组合最终保留掩膜(边缘+纹理),二者都要满足 + def _combine_keep(m_edge, m_tex): + if m_edge is None: + return m_tex + return (m_edge & m_tex) + + keep_src_final = _combine_keep(keep_src_edge, keep_src_tex) + keep_ref_final = _combine_keep(keep_ref_edge, keep_ref_tex) + + # 将匹配与内点严格限制在最终掩膜内 + def _filter_by_bool_masks(res, m_src, m_ref): + if m_src is None or m_ref is None: + return res + + def keep_in(mask_hw, pts): + if pts is None or len(pts) == 0: + return np.zeros((0,), dtype=bool) + xs = np.clip(np.rint(pts[:, 0]).astype(int), 0, mask_hw.shape[1] - 1) + ys = np.clip(np.rint(pts[:, 1]).astype(int), 0, mask_hw.shape[0] - 1) + return mask_hw[ys, xs] + + # matched + if "matched_kpts0" in res and "matched_kpts1" in res: + mk0 = np.asarray(res["matched_kpts0"]); mk1 = np.asarray(res["matched_kpts1"]) + if len(mk0) == len(mk1) and len(mk0) > 0: + keep_m = keep_in(m_src, mk0) & keep_in(m_ref, mk1) + res["matched_kpts0"] = mk0[keep_m] + res["matched_kpts1"] = mk1[keep_m] + + # inliers + if "inlier_kpts0" in res and "inlier_kpts1" in res and res["inlier_kpts0"] is not None: + ik0 = np.asarray(res["inlier_kpts0"]); ik1 = np.asarray(res["inlier_kpts1"]) + if len(ik0) == len(ik1) and len(ik0) > 0: + keep_i = keep_in(m_src, ik0) & keep_in(m_ref, ik1) + res["inlier_kpts0"] = ik0[keep_i] + res["inlier_kpts1"] = ik1[keep_i] + res["num_inliers"] = int(len(res["inlier_kpts0"])) + return res + + result = _filter_by_bool_masks(result, keep_src_final, keep_ref_final) + + # 统计(以过滤后的结果为准) + num_inl = int(result.get("num_inliers", len(result.get("inlier_kpts0", [])))) + num_m = len(result.get("matched_kpts0", [])) + ratio = (num_inl / num_m) if num_m else 0.0 + + # 更新统计变量 + num_inliers = num_inl + num_matches = num_m + inlier_ratio = ratio + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + # 保存匹配可视化图像(使用与匹配同尺寸的图像,保持CHW格式) + viz_dir = out_dir / "visualizations" + viz_dir.mkdir(exist_ok=True) + + matches_path = viz_dir / f"{bip_path.stem}_matches.png" + plot_matches(src_small, ref_small, result, save_path=str(matches_path)) + logger.info(f"匹配可视化已保存: {matches_path}") + + # 关键点可视化(源图像) + kpts_src_path = viz_dir / f"{bip_path.stem}_keypoints_src.png" + plot_keypoints( + src_small, + {"all_kpts0": result["all_kpts0"], "all_desc0": result["all_desc0"]}, + save_path=str(kpts_src_path) + ) + logger.info(f"源图像关键点可视化已保存: {kpts_src_path}") + + # 关键点可视化(参考图像) + kpts_ref_path = viz_dir / f"{bip_path.stem}_keypoints_ref.png" + plot_keypoints( + ref_small, + {"all_kpts0": result["all_kpts1"], "all_desc0": result["all_desc1"]}, + save_path=str(kpts_ref_path) + ) + logger.info(f"参考图像关键点可视化已保存: {kpts_ref_path}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_quality_check", median_error, p95_error, False) + return False + + # 5) 用内点估计多种变换并自动选择最优 + # 先计算全分辨率坐标 + k0_small = result["inlier_kpts0"].astype(np.float32) + k1_small = result["inlier_kpts1"].astype(np.float32) + + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0_inv = np.array([[s0x, 0, 0],[0, s0y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (src) + S1_inv = np.array([[s1x, 0, 0],[0, s1y, 0],[0, 0, 1]], dtype=np.float32) # small -> full (ref ROI) + + ones = np.ones((k0_small.shape[0], 1), dtype=np.float32) + k0_full = (S0_inv @ np.hstack([k0_small, ones]).T).T[:, :2] # 全分辨率源像素 + k1_roi_full = (S1_inv @ np.hstack([k1_small, ones]).T).T[:, :2] # ROI内参考像素 + k1_global = k1_roi_full + np.array([win.col_off, win.row_off], dtype=np.float32) # 全局参考像素 + + + # 用全分辨率坐标进行所有模型的估计和评估 + best_transform = None + best_transform_type = None + best_error = np.inf + best_median_error = np.inf + best_method = None + + for method in TRANSFORM_METHODS: + transform_type, transform = estimate_transform(method, k0_full, k1_global) + if transform is None: + continue + + med_err, p95_err = evaluate_transform_quality(transform_type, transform, k0_full, k1_global) + + # 选择重投影误差最小的变换 + if p95_err < best_error: + best_transform = transform + best_transform_type = transform_type + best_error = p95_err + best_median_error = med_err + best_method = method + + logger.debug(f"方法 {method}: p50={med_err:.2f}, p95={p95_err:.2f}") + + if best_transform is None: + logger.warning(f"所有变换方法都失败: {bip_path.name}") + # 记录失败的统计信息 + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "failed_transform", median_error, p95_error, False) + return False + + # 更新统计变量 + selected_method = best_method + median_error = best_median_error + p95_error = best_error + + logger.info(f"选用变换: {best_method} ({best_transform_type}), 误差 p95={best_error:.2f}") + + # 6) 根据变换类型进行相应的配准处理 + if best_transform_type == "A": + # 仿射变换:A 已是 src_full_pixel -> ref_full_pixel,直接构造像素->地图仿射 + A = best_transform # 2x3, src_full_pixel -> ref_full_pixel + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A + + # src_pixel -> map + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + M_map = Rt @ A3 + corrected_affine = Affine(M_map[0,0], M_map[0,1], M_map[0,2], + M_map[1,0], M_map[1,1], M_map[1,2]) + + # 用 M_map 求最小外接矩形(先到 map,再到 ref 像素) + Rt_inv = np.linalg.inv(Rt) + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corn_h = np.hstack([corners, np.ones((4,1))]).T + map_corners = (M_map @ corn_h).T[:, :2] + pix_corners = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T[:, :2] + + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil (pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil (pix_corners[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Affine): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + # ---- 非仿射变换处理 ---- + elif best_transform_type == "H": + # 单应变换:H 已是 src_full_pixel -> ref_full_pixel + H_full = best_transform # 3x3 + + try: + # 用 H_full 映射源四角 -> 参考像素,求最小外接矩形 + src_h, src_w = src.height, src.width + corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float32) + corn_h = np.hstack([corners, np.ones((4,1), dtype=np.float32)]).T + dst_h = (H_full @ corn_h) + dst = (dst_h[:2] / (dst_h[2:]+1e-6)).T + + min_x = int(np.floor(dst[:,0].min())) - 10 + max_x = int(np.ceil (dst[:,0].max())) + 10 + min_y = int(np.floor(dst[:,1].min())) - 10 + max_y = int(np.ceil (dst[:,1].max())) + 10 + + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"单应变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + # 子窗口坐标的单应矩阵(输出坐标系是子窗口像素) + T_off = np.array([[1,0,min_x],[0,1,min_y],[0,0,1]], dtype=np.float64) + H_sub = np.linalg.inv(T_off) @ H_full + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 使用 OpenCV 进行单应变换重采样 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.full((bbox_h, bbox_w), dst_nodata, dtype=np.float32) + + # 使用 OpenCV warpPerspective(子窗口坐标) + dst_band = cv2.warpPerspective( + src_band, H_sub, + (bbox_w, bbox_h), + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(Homography): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"单应变换异常: {e}") + # 继续到仿射回退 + + elif best_transform_type in ["piecewise", "polynomial", "polynomial_order3"]: + # 分片仿射或多项式变换:使用 scikit-image + transform = best_transform # 已用 k0_full/k1_global 估计 + try: + # 用目标侧匹配点(k1_global)决定外接矩形(更稳) + pad = 10 + min_x = int(np.floor(k1_global[:, 0].min())) - pad + max_x = int(np.ceil (k1_global[:, 0].max())) + pad + min_y = int(np.floor(k1_global[:, 1].min())) - pad + max_y = int(np.ceil (k1_global[:, 1].max())) + pad + + min_x = max(0, min_x) + min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"{best_transform_type}变换最小外接矩形无效: {bip_path.name}") + return False + + # 创建输出窗口 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 定义带偏移的逆映射函数 + off_x, off_y = min_x, min_y + + if best_transform_type in ["polynomial", "polynomial_order3"]: + # 对于多项式,估计逆变换 + order = 2 if best_transform_type == "polynomial" else 3 + t_inv = PolynomialTransform() + t_inv.estimate(k1_global, k0_full, order=order) # 顺序:目标->源 + + # 目标侧点集的内点判定(用于限制外推) + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array([[min_x, min_y],[min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h],[min_x, min_y + bbox_h]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ((xy[:,0] >= min_x) & (xy[:,0] <= min_x + bbox_w) & + (xy[:,1] >= min_y) & (xy[:,1] <= min_y + bbox_h)) + + def inv_map_rc(coords): + # coords: (N,2) in (row, col) + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # -> (x, y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = t_inv(xy[inside]) # -> (x_src, y_src) in full-src + # 确保坐标在源图像范围内 + xy_src[:, 0] = np.clip(xy_src[:, 0], 0, src.height - 1) + xy_src[:, 1] = np.clip(xy_src[:, 1], 0, src.width - 1) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + elif best_transform_type == "piecewise": # piecewise_affine + # 目标侧点集的内点判定 + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + # 使用当前裁剪窗口的边界创建矩形 + rect = np.array([[min_x, min_y],[max_x, min_y],[max_x, max_y],[min_x, max_y]], dtype=float) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + # 退化为矩形内判断 + def point_inside(xy): + return (xy[:,0] >= min_x) & (xy[:,0] <= max_x) & \ + (xy[:,1] >= min_y) & (xy[:,1] <= max_y) + + def inv_map_rc(coords): + rc = np.asarray(coords) + xy = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # (x,y) in full-ref + inside = point_inside(xy) + xy_src = np.full_like(xy, fill_value=-1.0) + if np.any(inside): + xy_src[inside] = transform.inverse(xy[inside]) # -> full-src (x_src, y_src) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # -> (row_src, col_src) + + # 使用 scikit-image 进行变换重采样 + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, # 带偏移和轴序修正的逆映射 + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准({best_transform_type}): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, selected_method, median_error, p95_error, success) + return True + + except Exception as e: + logger.warning(f"{best_transform_type}变换异常: {e}") + # 继续到仿射回退 + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + transform = best_transform + try: + min_x, min_y, bbox_w, bbox_h = _compute_bbox_from_k1( + k1_global, ref_dataset.width, ref_dataset.height, pad=10 + ) + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"tps变换最小外接矩形无效: {bip_path.name}") + return False + + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + if MATPLOTLIB_SCIPY_AVAILABLE: + try: + hull = ConvexHull(k1_global) + hull_path = MplPath(k1_global[hull.vertices]) + except Exception: + rect = np.array( + [[min_x, min_y], [min_x + bbox_w, min_y], + [min_x + bbox_w, min_y + bbox_h], [min_x, min_y + bbox_h]], + dtype=float + ) + hull_path = MplPath(rect) + + def point_inside(xy): + return hull_path.contains_points(xy) + else: + def point_inside(xy): + return ( + (xy[:, 0] >= min_x) & (xy[:, 0] <= min_x + bbox_w) & + (xy[:, 1] >= min_y) & (xy[:, 1] <= min_y + bbox_h) + ) + + off_x, off_y = min_x, min_y + tps_inv = transform["inv"] # ref -> src + + def inv_map_rc(coords): + rc = np.asarray(coords, dtype=np.float64) + xy_ref = np.column_stack([rc[:, 1] + off_x, rc[:, 0] + off_y]) # full-ref (x, y) + inside = point_inside(xy_ref) + xy_src = np.full_like(xy_ref, fill_value=-1.0, dtype=np.float64) + if np.any(inside): + # 使用RBF插值计算逆映射 + xy_src[inside, 0] = tps_inv["rbf_x"](xy_ref[inside, 0], xy_ref[inside, 1]) + xy_src[inside, 1] = tps_inv["rbf_y"](xy_ref[inside, 0], xy_ref[inside, 1]) + return np.column_stack([xy_src[:, 1], xy_src[:, 0]]) # (row_src, col_src) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 优先用 skimage.warp;缺失时用 SimpleITK Resample 兜底 + if SKIMAGE_AVAILABLE: + from skimage.transform import warp + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = warp( + src_band, + inverse_map=inv_map_rc, + output_shape=(bbox_h, bbox_w), + mode='constant', + cval=dst_nodata, + preserve_range=True + ).astype(np.float32) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + else: + # OpenCV remap 版本(无需 skimage/SimpleITK) + with rasterio.open(out_path, "w", **out_profile) as out_ds: + # 创建映射网格 + y_coords, x_coords = np.mgrid[0:bbox_h, 0:bbox_w] + coords = np.column_stack([y_coords.ravel(), x_coords.ravel()]) + + # 计算逆映射 + mapped_coords = inv_map_rc(coords) + map_y = mapped_coords[:, 0].reshape(bbox_h, bbox_w).astype(np.float32) + map_x = mapped_coords[:, 1].reshape(bbox_h, bbox_w).astype(np.float32) + + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + + # 使用OpenCV的remap进行重采样 + dst_band = cv2.remap( + src_band, map_x, map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=dst_nodata + ) + + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(TPS): {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.warning(f"tps变换异常: {e}") + # 继续到仿射回退 + + + + # ---- 回退:使用仿射变换,保证最小可用结果 ---- + # 重新估计仿射变换作为fallback + A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A_fallback is None: + logger.warning(f"仿射回退也失败: {bip_path.name}") + return False + + # 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标 + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64) + S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64) + A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback + M_full = S1_inv @ A3 @ S0 + + T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64) + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + src_pixel_to_map_corrected = Rt @ T_off @ M_full + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 计算源 BIP 四角经过仿射变换后的最小外接矩形 + # 将 rasterio.Affine 转为 3x3 像素->地图矩阵 + M_map = np.array([ + [corrected_affine.a, corrected_affine.b, corrected_affine.c], + [corrected_affine.d, corrected_affine.e, corrected_affine.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + + # 参考底图的 像素->地图 矩阵及其逆 + ref_transform = ref_dataset.transform + Rt = np.array([ + [ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0.0, 0.0, 1.0] + ], dtype=np.float64) + Rt_inv = np.linalg.inv(Rt) + + # 源影像四角(源像素坐标) + src_h, src_w = src.height, src.width + src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64) + corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4) + + # 源像素 -> 地图坐标 + map_corners = (M_map @ corners_h).T[:, :2] + + # 地图坐标 -> 参考像素坐标 + pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3) + pix_corners = pix_corners_h[:, :2] + + # 最小外接矩形(像素) + min_x = int(np.floor(pix_corners[:,0].min())) - 10 + max_x = int(np.ceil( pix_corners[:,0].max())) + 10 + min_y = int(np.floor(pix_corners[:,1].min())) - 10 + max_y = int(np.ceil( pix_corners[:,1].max())) + 10 + + # 边界裁剪 + min_x = max(0, min_x); min_y = max(0, min_y) + max_x = min(ref_dataset.width, max_x) + max_y = min(ref_dataset.height, max_y) + + bbox_w = max_x - min_x + bbox_h = max_y - min_y + + # 如果外接矩形太小,跳过 + if bbox_w <= 0 or bbox_h <= 0: + logger.warning(f"最小外接矩形无效: {bip_path.name}") + return False + + # 创建裁剪窗口和变换 + bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h) + bbox_transform = ref_dataset.window_transform(bbox_window) + + out_path = out_dir / f"{bip_path.stem}_registered.bip" + src_nodata = src.nodata + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 更新输出 profile 使用最小外接矩形 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], + height=bbox_h, + width=bbox_w, + count=src.count, + transform=bbox_transform, # 使用最小外接矩形的变换 + crs=ref_crs, + interleave="bip", + compress=None, + nodata=dst_nodata + ) + + # 重采样到最小外接矩形 + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32) + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, + src_crs=ref_crs, + dst_transform=bbox_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=Resampling.bilinear, + ) + # 转回目标 dtype + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + mask = (dst_band == dst_nodata) if src_nodata is not None else None + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"]) + if mask is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}") + success = True + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "affine_fallback", median_error, p95_error, success) + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + # 记录失败的统计信息 + try: + log_registration_stats(stats_csv, bip_path.name, num_inliers, num_matches, + inlier_ratio, "exception", median_error, p95_error, False) + except: + pass # 避免统计记录失败影响主要错误处理 + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化统计CSV文件 + init_stats_csv(STATS_CSV) + logger.info(f"统计信息将保存到: {STATS_CSV}") + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR, STATS_CSV): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main() diff --git a/test.py b/test.py new file mode 100644 index 0000000..737ae9a --- /dev/null +++ b/test.py @@ -0,0 +1,290 @@ +""" +批量配准 .bip 文件到参考 .tif 文件 +使用地理信息进行粗定位,然后用图像匹配进行精配准 +""" + +from pathlib import Path +import numpy as np +import cv2 +import rasterio +from rasterio.windows import from_bounds +from rasterio.warp import transform_bounds, reproject, Resampling +from affine import Affine +from vismatch import get_matcher +import logging + +# 设置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ---------- 配置 ---------- +# 请根据实际情况修改这些路径 +REF_TIF = r"E:\is2\jiashixian\result.tif" # 参考 tif 文件路径 +BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1") # .bip 文件所在文件夹 +OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma") # 输出文件夹 + +# 匹配算法选择 +MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等 +DEVICE = "cuda" # 或 "cpu" + +# 匹配参数 +MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素) +ROI_PAD_PX = 300 # 粗定位窗口的padding(参考tif像素) + +# 质量控制阈值 +MIN_INLIERS = 30 # 最少内点数 +MIN_INLIER_RATIO = 0.15 # 最少内点比例 + +# 创建输出目录 +OUT_DIR.mkdir(parents=True, exist_ok=True) + +# ---------- 工具函数 ---------- +def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray: + """将任意通道数的数组转换为 (3,H,W) float32 in [0,1]""" + arr = arr_chw.astype(np.float32) + + if arr.shape[0] == 1: + # 单波段复制为3通道 + arr = np.repeat(arr, 3, axis=0) + elif arr.shape[0] >= 3: + # 取前3波段 + arr = arr[:3] + else: + raise ValueError(f"不支持的通道数: {arr.shape[0]}") + + # 百分位数拉伸,增强跨传感器匹配稳定性 + p2 = np.percentile(arr, 2) + p98 = np.percentile(arr, 98) + arr = (arr - p2) / (p98 - p2 + 1e-6) + arr = np.clip(arr, 0.0, 1.0) + return arr + +def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray: + """等比缩放 (C,H,W) 到 max(H,W) <= max_side""" + c, h, w = arr_chw.shape + s = min(1.0, max_side / max(h, w)) + if s >= 1.0: + return arr_chw + new_w = int(round(w * s)) + new_h = int(round(h * s)) + # 用opencv缩放(逐通道) + out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0) + return out + +def _expand_window(win, pad, max_w, max_h): + """扩展窗口并确保边界有效""" + col_off = int(max(0, win.col_off - pad)) + row_off = int(max(0, win.row_off - pad)) + col_end = int(min(max_w, win.col_off + win.width + pad)) + row_end = int(min(max_h, win.row_off + win.height + pad)) + return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off) + +def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path): + """处理单个 .bip 文件到参考 .tif 的配准""" + try: + with rasterio.open(bip_path) as src: + logger.info(f"处理文件: {bip_path.name}") + + # 检查CRS + if src.crs is None: + logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS") + src_crs = ref_dataset.crs + else: + src_crs = src.crs + + ref_crs = ref_dataset.crs + if ref_crs is None: + raise RuntimeError(f"参考文件缺少CRS信息") + + # 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI + b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21) + win0 = from_bounds(*b, transform=ref_dataset.transform) + win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height) + + if win.width <= 0 or win.height <= 0: + logger.warning(f"无重叠区域: {bip_path.name}") + return False + + # 2) 读取数据 + # 读取所有波段,如果是多波段的话 + src_arr = src.read() # (bands, H, W) + if src_arr.ndim == 2: # 单波段 + src_arr = src_arr[None, ...] # 增加波段维度 + + # 读取参考文件的ROI + ref_arr = ref_dataset.read(window=win) # (bands, h, w) + if ref_arr.ndim == 2: # 单波段 + ref_arr = ref_arr[None, ...] # 增加波段维度 + + # 转换为匹配所需的格式 + src_img = _to_3ch_float01(src_arr) + ref_img = _to_3ch_float01(ref_arr) + + # 3) 匹配用降采样版本,提速 + 增稳 + src_small = _downscale_chw(src_img, MATCH_MAX_SIDE) + ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE) + + logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}") + + # 4) 精配准(img0=src, img1=ref_roi) + result = matcher(src_small, ref_small) + + num_inl = int(result["num_inliers"]) + num_m = len(result["matched_kpts0"]) + ratio = (num_inl / num_m) if num_m else 0.0 + + logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}") + + if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO: + logger.warning(f"匹配质量不足: {bip_path.name}") + return False + + # 5) 用内点估计仿射变换 + k0 = result["inlier_kpts0"].astype(np.float32) + k1 = result["inlier_kpts1"].astype(np.float32) + + A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0) + if A is None: + logger.warning(f"仿射估计失败: {bip_path.name}") + return False + + # 6) 把"src_small->ref_small"的仿射映射回"src_full->ref_full_roi" + # 缩放系数:full/small + s0x = src_img.shape[2] / src_small.shape[2] + s0y = src_img.shape[1] / src_small.shape[1] + s1x = ref_img.shape[2] / ref_small.shape[2] + s1y = ref_img.shape[1] / ref_small.shape[1] + + S0 = np.array([[1/s0x, 0, 0], + [0, 1/s0y, 0], + [0, 0, 1]], dtype=np.float64) # full -> small + S1_inv = np.array([[s1x, 0, 0], + [0, s1y, 0], + [0, 0, 1]], dtype=np.float64) # small -> full(roi) + + A3 = np.eye(3, dtype=np.float64) + A3[:2, :] = A # small src -> small ref_roi + + # full_src -> full_ref_roi(用 S0,而不是 S0_inv) + M_full = S1_inv @ A3 @ S0 + + # 7) 目标输出:重采样到 ref 全图网格 + # 需要"src像素 -> 地图坐标"的修正 transform + T_off = np.array([[1, 0, win.col_off], + [0, 1, win.row_off], + [0, 0, 1]], dtype=np.float64) + + # ref_transform 转 3x3 + ref_transform = ref_dataset.transform + Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c], + [ref_transform.d, ref_transform.e, ref_transform.f], + [0, 0, 1]], dtype=np.float64) + + src_pixel_to_map_corrected = Rt @ T_off @ M_full + + # 转回 Affine + corrected_affine = Affine( + src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2], + src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2], + ) + + # 8) 重采样输出到 ref 网格(多波段,BIP 输出) + out_path = out_dir / f"{bip_path.stem}_registered.bip" + + # 获取 NoData 值 + src_nodata = src.nodata # 可能是 None、0、65535 等 + dst_nodata = src_nodata if src_nodata is not None else 0 + + # 以参考影像的网格/坐标为目标,波段数和数据类型继承源影像 + out_profile = ref_dataset.profile.copy() + out_profile.update( + driver="ENVI", + dtype=src.dtypes[0], # 尽量保持与源影像一致的数据类型 + height=ref_dataset.height, + width=ref_dataset.width, + count=src.count, # 多波段 + transform=ref_transform, + crs=ref_crs, + interleave="bip", # 指定 BIP + compress=None, + nodata=dst_nodata # 设置 NoData 值 + ) + + with rasterio.open(out_path, "w", **out_profile) as out_ds: + for b in range(1, src.count + 1): + # 逐波段重投影(float32 计算更稳) + src_band = src.read(b).astype(np.float32) + dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32) + + reproject( + source=src_band, + destination=dst_band, + src_transform=corrected_affine, # 由原逻辑推导的 src像素->ref像素 的仿射 + src_crs=ref_crs, + dst_transform=ref_transform, + dst_crs=ref_crs, + src_nodata=src_nodata, # 新增 NoData 处理 + dst_nodata=dst_nodata, # 新增 NoData 处理 + resampling=Resampling.bilinear, + ) + + # 调试统计信息 + logger.info(f"波段 {b}: min={float(np.nanmin(dst_band)):.2f}, max={float(np.nanmax(dst_band)):.2f}, " + f"mean={float(np.nanmean(dst_band)):.2f}, nodata占比={(dst_band==dst_nodata).mean():.2%}") + + # 转回目标 dtype(若为整型,先裁剪再转型) + if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer): + if src_nodata is not None: + # 保持 nodata 不被拉伸 + mask = (dst_band == dst_nodata) + info = np.iinfo(out_profile["dtype"]) + dst_band = np.clip(dst_band, info.min, info.max) + dst_band = dst_band.astype(out_profile["dtype"]) + if src_nodata is not None: + dst_band[mask] = dst_nodata + else: + dst_band = dst_band.astype(out_profile["dtype"]) + + out_ds.write(dst_band, b) + + logger.info(f"成功配准: {bip_path.name} -> {out_path.name}") + return True + + except Exception as e: + logger.error(f"处理失败 {bip_path.name}: {str(e)}") + return False + +# ---------- 主逻辑 ---------- +def main(): + logger.info("开始批量配准处理...") + + # 检查输入文件是否存在 + if not Path(REF_TIF).exists(): + logger.error(f"参考文件不存在: {REF_TIF}") + return + + if not BIP_DIR.exists(): + logger.error(f"BIP文件夹不存在: {BIP_DIR}") + return + + # 初始化匹配器 + logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}") + matcher = get_matcher(MATCHER_NAME, device=DEVICE) + + # 打开参考文件 + with rasterio.open(REF_TIF) as ref: + logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}") + + # 查找所有 .bip 文件 + bip_files = list(BIP_DIR.glob("*.bip")) + logger.info(f"找到 {len(bip_files)} 个 .bip 文件") + + success_count = 0 + for bip_path in bip_files: + if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR): + success_count += 1 + + logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准") + +if __name__ == "__main__": + main()