468 lines
20 KiB
Python
468 lines
20 KiB
Python
"""
|
||
批量配准 .bip 文件到参考 .tif 文件
|
||
直接进行配准
|
||
"""
|
||
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import cv2
|
||
import rasterio
|
||
from rasterio.windows import from_bounds
|
||
from rasterio.warp import transform_bounds, reproject, Resampling
|
||
from affine import Affine
|
||
from vismatch import get_matcher
|
||
import logging
|
||
|
||
try:
|
||
from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
|
||
SKIMAGE_AVAILABLE = True
|
||
except ImportError:
|
||
SKIMAGE_AVAILABLE = False
|
||
logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
|
||
|
||
try:
|
||
from matplotlib.path import Path as MplPath
|
||
from scipy.spatial import ConvexHull
|
||
MATPLOTLIB_SCIPY_AVAILABLE = True
|
||
except ImportError:
|
||
MATPLOTLIB_SCIPY_AVAILABLE = False
|
||
MplPath = None
|
||
logging.warning("matplotlib 或 scipy 不可用,piecewise_affine 将退化为矩形内判断")
|
||
|
||
try:
|
||
import SimpleITK as sitk
|
||
SITK_AVAILABLE = True
|
||
except ImportError:
|
||
SITK_AVAILABLE = False
|
||
logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
|
||
|
||
try:
|
||
import pirt
|
||
PIRT_AVAILABLE = True
|
||
except ImportError:
|
||
PIRT_AVAILABLE = False
|
||
logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ---------- 配置 ----------
|
||
# 请根据实际情况修改这些路径
|
||
REF_TIF = r"E:\is2\yaopu\result.tif" # 参考 tif 文件路径
|
||
BIP_DIR = Path(r"E:\is2\yaopu") # .bip 文件所在文件夹
|
||
OUT_DIR = Path(r"E:\is2\yaopu\output") # 输出文件夹
|
||
|
||
# 匹配算法选择
|
||
MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等
|
||
DEVICE = "cuda" # 或 "cpu"
|
||
|
||
# 使用密集匹配模型的稠密流直接进行配准
|
||
|
||
# 匹配参数
|
||
MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素)
|
||
ROI_PAD_PX = 500 # 粗定位窗口的padding(参考tif像素)
|
||
|
||
# 质量控制阈值
|
||
MIN_INLIERS = 10 # 最少内点数
|
||
MIN_INLIER_RATIO = 0.01 # 最少内点比例
|
||
|
||
# 创建输出目录
|
||
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
||
|
||
# ---------- 工具函数 ----------
|
||
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
|
||
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
|
||
arr = arr_chw.astype(np.float32)
|
||
|
||
if arr.shape[0] == 1:
|
||
# 单波段复制为3通道
|
||
arr = np.repeat(arr, 3, axis=0)
|
||
elif arr.shape[0] >= 3:
|
||
# 取前3波段
|
||
arr = arr[:3]
|
||
else:
|
||
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
|
||
|
||
# 百分位数拉伸,增强跨传感器匹配稳定性
|
||
p2 = np.percentile(arr, 2)
|
||
p98 = np.percentile(arr, 98)
|
||
arr = (arr - p2) / (p98 - p2 + 1e-6)
|
||
arr = np.clip(arr, 0.0, 1.0)
|
||
return arr
|
||
|
||
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
|
||
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
|
||
c, h, w = arr_chw.shape
|
||
s = min(1.0, max_side / max(h, w))
|
||
if s >= 1.0:
|
||
return arr_chw
|
||
new_w = int(round(w * s))
|
||
new_h = int(round(h * s))
|
||
# 用opencv缩放(逐通道)
|
||
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
|
||
return out
|
||
|
||
def _expand_window(win, pad, max_w, max_h):
|
||
"""扩展窗口并确保边界有效"""
|
||
col_off = int(max(0, win.col_off - pad))
|
||
row_off = int(max(0, win.row_off - pad))
|
||
col_end = int(min(max_w, win.col_off + win.width + pad))
|
||
row_end = int(min(max_h, win.row_off + win.height + pad))
|
||
return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)
|
||
|
||
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
|
||
"""处理单个 .bip 文件到参考 .tif 的配准"""
|
||
try:
|
||
with rasterio.open(bip_path) as src:
|
||
logger.info(f"处理文件: {bip_path.name}")
|
||
|
||
# 检查CRS
|
||
if src.crs is None:
|
||
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS")
|
||
src_crs = ref_dataset.crs
|
||
else:
|
||
src_crs = src.crs
|
||
|
||
ref_crs = ref_dataset.crs
|
||
if ref_crs is None:
|
||
raise RuntimeError(f"参考文件缺少CRS信息")
|
||
|
||
# 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI
|
||
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
|
||
win0 = from_bounds(*b, transform=ref_dataset.transform)
|
||
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
|
||
|
||
if win.width <= 0 or win.height <= 0:
|
||
logger.warning(f"无重叠区域: {bip_path.name}")
|
||
return False
|
||
|
||
# 2) 读取数据
|
||
# 读取所有波段,如果是多波段的话
|
||
src_arr = src.read() # (bands, H, W)
|
||
if src_arr.ndim == 2: # 单波段
|
||
src_arr = src_arr[None, ...] # 增加波段维度
|
||
|
||
# 读取参考文件的ROI
|
||
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
|
||
if ref_arr.ndim == 2: # 单波段
|
||
ref_arr = ref_arr[None, ...] # 增加波段维度
|
||
|
||
# 转换为匹配所需的格式
|
||
src_img = _to_3ch_float01(src_arr)
|
||
ref_img = _to_3ch_float01(ref_arr)
|
||
|
||
# 3) 匹配用降采样版本,提速 + 增稳
|
||
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
|
||
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
|
||
|
||
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
|
||
|
||
# 4) 精配准(img0=src, img1=ref_roi)
|
||
result = matcher(src_small, ref_small)
|
||
|
||
num_inl = int(result["num_inliers"])
|
||
num_m = len(result["matched_kpts0"])
|
||
ratio = (num_inl / num_m) if num_m else 0.0
|
||
|
||
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
|
||
|
||
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
|
||
logger.warning(f"匹配质量不足: {bip_path.name}")
|
||
return False
|
||
|
||
# ==== 稠密流直接重采样(无需后续显式变换估计) ====
|
||
|
||
# 1) 取稠密流(优先 ref->src)。不同模型的键名可能不同,这里做兼容
|
||
flow_small = None
|
||
for k in ["flow_ref2src", "flow21", "flow_1_0", "flow10", "flow"]:
|
||
if k in result:
|
||
flow_small = result[k]
|
||
break
|
||
|
||
if flow_small is None:
|
||
# 回退:优先 DIS 光流(更快/稳),若不可用再用 Farneback (ref -> src)
|
||
ref_small_rgb = np.transpose(ref_small, (1, 2, 0)) # (H,W,3)
|
||
src_small_rgb = np.transpose(src_small, (1, 2, 0))
|
||
ref_small_gray = cv2.cvtColor((ref_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
|
||
src_small_gray = cv2.cvtColor((src_small_rgb * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)
|
||
|
||
flow_small = None
|
||
try:
|
||
dis = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM)
|
||
flow_small = dis.calc(ref_small_gray, src_small_gray, None).astype(np.float32)
|
||
except Exception:
|
||
pass
|
||
|
||
if flow_small is None:
|
||
# 典型参数:可按影像特性微调(窗口、迭代次数等)
|
||
flow_small = cv2.calcOpticalFlowFarneback(
|
||
ref_small_gray, src_small_gray,
|
||
None, 0.5, 3, 25, 3, 5, 1.2, 0
|
||
).astype(np.float32)
|
||
|
||
# flow_small 期望形状 (h_s, w_s, 2),分量为 (dx, dy): 参考像素到源像素的位移
|
||
flow_small = np.asarray(flow_small, dtype=np.float32)
|
||
if flow_small.ndim != 3 or flow_small.shape[2] != 2:
|
||
logger.warning(f"稠密流形状异常: {flow_small.shape}")
|
||
return False
|
||
|
||
# 2) 将小图的流放大到 ROI 全分辨率,并按比例放大位移
|
||
roi_h, roi_w = ref_img.shape[1], ref_img.shape[2] # 注意 ref_img 是 ROI 子图
|
||
scale_x = roi_w / flow_small.shape[1]
|
||
scale_y = roi_h / flow_small.shape[0]
|
||
flow_full = cv2.resize(flow_small, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR)
|
||
flow_full[..., 0] *= scale_x # dx
|
||
flow_full[..., 1] *= scale_y # dy
|
||
|
||
# 3) 生成 remap 所需的源坐标图(map_x, map_y),在"参考ROI坐标系"内工作
|
||
yy, xx = np.meshgrid(np.arange(roi_h, dtype=np.float32),
|
||
np.arange(roi_w, dtype=np.float32), indexing="ij")
|
||
map_x = xx + flow_full[..., 0] # 到源图的 x(列)
|
||
map_y = yy + flow_full[..., 1] # 到源图的 y(行)
|
||
|
||
# 4) 根据有效映射范围求最小外接矩形(仅统计落在源图范围内的像素)
|
||
valid = (map_x >= 0) & (map_x <= (src.width - 1)) & (map_y >= 0) & (map_y <= (src.height - 1))
|
||
if not np.any(valid):
|
||
logger.warning(f"稠密流无有效映射: {bip_path.name}")
|
||
return False
|
||
|
||
ys, xs = np.where(valid)
|
||
pad = 0
|
||
min_y = max(int(ys.min()) - pad, 0)
|
||
max_y = min(int(ys.max()) + 1 + pad, roi_h)
|
||
min_x = max(int(xs.min()) - pad, 0)
|
||
max_x = min(int(xs.max()) + 1 + pad, roi_w)
|
||
|
||
crop_h = max_y - min_y
|
||
crop_w = max_x - min_x
|
||
if crop_h <= 0 or crop_w <= 0:
|
||
logger.warning(f"最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
# 只对外接矩形区域做重采样,减少内存
|
||
map_x_crop = map_x[min_y:max_y, min_x:max_x].astype(np.float32)
|
||
map_y_crop = map_y[min_y:max_y, min_x:max_x].astype(np.float32)
|
||
|
||
# 5) 计算输出的地理变换:参考ROI窗口 + 外接矩形子窗口
|
||
# 先得到 ROI 的 transform,再叠加子窗口偏移
|
||
roi_transform = ref_dataset.window_transform(win)
|
||
crop_window_global = rasterio.windows.Window(
|
||
win.col_off + min_x, win.row_off + min_y, crop_w, crop_h
|
||
)
|
||
out_transform = ref_dataset.window_transform(crop_window_global)
|
||
|
||
# 6) 写出 ENVI/BIP(按最小外接矩形)
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
out_profile = ref_dataset.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=crop_h,
|
||
width=crop_w,
|
||
count=src.count,
|
||
transform=out_transform,
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
# 反向映射采样:输出像素在参考ROI坐标,去源图(map_y,map_x)取值
|
||
warped = cv2.remap(
|
||
src_band, map_x_crop, map_y_crop,
|
||
interpolation=cv2.INTER_LINEAR,
|
||
borderMode=cv2.BORDER_CONSTANT,
|
||
borderValue=float(dst_nodata)
|
||
).astype(np.float32)
|
||
|
||
# 转回目标 dtype,保持 nodata
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (warped == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
warped = np.clip(warped, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
warped[mask] = dst_nodata
|
||
else:
|
||
warped = warped.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(warped, b)
|
||
|
||
logger.info(f"成功配准(DenseFlow): {bip_path.name} -> {out_path.name}")
|
||
return True
|
||
|
||
# ---- 回退:使用仿射变换,保证最小可用结果 ----
|
||
# 重新估计仿射变换作为fallback
|
||
A_fallback, _ = cv2.estimateAffine2D(k0_full, k1_global, method=cv2.RANSAC, ransacReprojThreshold=3.0)
|
||
if A_fallback is None:
|
||
logger.warning(f"仿射回退也失败: {bip_path.name}")
|
||
return False
|
||
|
||
# 构造 full_src -> full_ref_roi 的仿射并回写到地图坐标
|
||
s0x = src_img.shape[2] / src_small.shape[2]
|
||
s0y = src_img.shape[1] / src_small.shape[1]
|
||
s1x = ref_img.shape[2] / ref_small.shape[2]
|
||
s1y = ref_img.shape[1] / ref_small.shape[1]
|
||
S0 = np.array([[1/s0x, 0, 0], [0, 1/s0y, 0], [0, 0, 1]], dtype=np.float64)
|
||
S1_inv = np.array([[s1x, 0, 0], [0, s1y, 0], [0, 0, 1]], dtype=np.float64)
|
||
A3 = np.eye(3, dtype=np.float64); A3[:2, :] = A_fallback
|
||
M_full = S1_inv @ A3 @ S0
|
||
|
||
T_off = np.array([[1, 0, win.col_off], [0, 1, win.row_off], [0, 0, 1]], dtype=np.float64)
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0, 0, 1]], dtype=np.float64)
|
||
src_pixel_to_map_corrected = Rt @ T_off @ M_full
|
||
corrected_affine = Affine(
|
||
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
|
||
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
|
||
)
|
||
|
||
# 计算源 BIP 四角经过仿射变换后的最小外接矩形
|
||
# 将 rasterio.Affine 转为 3x3 像素->地图矩阵
|
||
M_map = np.array([
|
||
[corrected_affine.a, corrected_affine.b, corrected_affine.c],
|
||
[corrected_affine.d, corrected_affine.e, corrected_affine.f],
|
||
[0.0, 0.0, 1.0]
|
||
], dtype=np.float64)
|
||
|
||
# 参考底图的 像素->地图 矩阵及其逆
|
||
ref_transform = ref_dataset.transform
|
||
Rt = np.array([
|
||
[ref_transform.a, ref_transform.b, ref_transform.c],
|
||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||
[0.0, 0.0, 1.0]
|
||
], dtype=np.float64)
|
||
Rt_inv = np.linalg.inv(Rt)
|
||
|
||
# 源影像四角(源像素坐标)
|
||
src_h, src_w = src.height, src.width
|
||
src_corners = np.array([[0,0],[src_w,0],[src_w,src_h],[0,src_h]], dtype=np.float64)
|
||
corners_h = np.hstack([src_corners, np.ones((4,1))]).T # (3,4)
|
||
|
||
# 源像素 -> 地图坐标
|
||
map_corners = (M_map @ corners_h).T[:, :2]
|
||
|
||
# 地图坐标 -> 参考像素坐标
|
||
pix_corners_h = (Rt_inv @ np.hstack([map_corners, np.ones((4,1))]).T).T # (4,3)
|
||
pix_corners = pix_corners_h[:, :2]
|
||
|
||
# 最小外接矩形(像素)
|
||
min_x = int(np.floor(pix_corners[:,0].min())) - 10
|
||
max_x = int(np.ceil( pix_corners[:,0].max())) + 10
|
||
min_y = int(np.floor(pix_corners[:,1].min())) - 10
|
||
max_y = int(np.ceil( pix_corners[:,1].max())) + 10
|
||
|
||
# 边界裁剪
|
||
min_x = max(0, min_x); min_y = max(0, min_y)
|
||
max_x = min(ref_dataset.width, max_x)
|
||
max_y = min(ref_dataset.height, max_y)
|
||
|
||
bbox_w = max_x - min_x
|
||
bbox_h = max_y - min_y
|
||
|
||
# 如果外接矩形太小,跳过
|
||
if bbox_w <= 0 or bbox_h <= 0:
|
||
logger.warning(f"最小外接矩形无效: {bip_path.name}")
|
||
return False
|
||
|
||
# 创建裁剪窗口和变换
|
||
bbox_window = rasterio.windows.Window(min_x, min_y, bbox_w, bbox_h)
|
||
bbox_transform = ref_dataset.window_transform(bbox_window)
|
||
|
||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||
src_nodata = src.nodata
|
||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||
|
||
# 更新输出 profile 使用最小外接矩形
|
||
out_profile = ref_dataset.profile.copy()
|
||
out_profile.update(
|
||
driver="ENVI",
|
||
dtype=src.dtypes[0],
|
||
height=bbox_h,
|
||
width=bbox_w,
|
||
count=src.count,
|
||
transform=bbox_transform, # 使用最小外接矩形的变换
|
||
crs=ref_crs,
|
||
interleave="bip",
|
||
compress=None,
|
||
nodata=dst_nodata
|
||
)
|
||
|
||
# 重采样到最小外接矩形
|
||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||
for b in range(1, src.count + 1):
|
||
src_band = src.read(b).astype(np.float32)
|
||
dst_band = np.zeros((bbox_h, bbox_w), dtype=np.float32)
|
||
reproject(
|
||
source=src_band,
|
||
destination=dst_band,
|
||
src_transform=corrected_affine,
|
||
src_crs=ref_crs,
|
||
dst_transform=bbox_transform,
|
||
dst_crs=ref_crs,
|
||
src_nodata=src_nodata,
|
||
dst_nodata=dst_nodata,
|
||
resampling=Resampling.bilinear,
|
||
)
|
||
# 转回目标 dtype
|
||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||
mask = (dst_band == dst_nodata) if src_nodata is not None else None
|
||
info = np.iinfo(out_profile["dtype"])
|
||
dst_band = np.clip(dst_band, info.min, info.max).astype(out_profile["dtype"])
|
||
if mask is not None:
|
||
dst_band[mask] = dst_nodata
|
||
else:
|
||
dst_band = dst_band.astype(out_profile["dtype"])
|
||
|
||
out_ds.write(dst_band, b)
|
||
|
||
logger.info(f"成功配准(仿射回退): {bip_path.name} -> {out_path.name}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
|
||
return False
|
||
|
||
# ---------- 主逻辑 ----------
|
||
def main():
|
||
logger.info("开始批量配准处理...")
|
||
|
||
# 检查输入文件是否存在
|
||
if not Path(REF_TIF).exists():
|
||
logger.error(f"参考文件不存在: {REF_TIF}")
|
||
return
|
||
|
||
if not BIP_DIR.exists():
|
||
logger.error(f"BIP文件夹不存在: {BIP_DIR}")
|
||
return
|
||
|
||
# 初始化匹配器
|
||
logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
|
||
matcher = get_matcher(MATCHER_NAME, device=DEVICE)
|
||
|
||
# 打开参考文件
|
||
with rasterio.open(REF_TIF) as ref:
|
||
logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
|
||
|
||
# 查找所有 .bip 文件
|
||
bip_files = list(BIP_DIR.glob("*.bip"))
|
||
logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
|
||
|
||
success_count = 0
|
||
for bip_path in bip_files:
|
||
if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
|
||
success_count += 1
|
||
|
||
logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")
|
||
|
||
if __name__ == "__main__":
|
||
main()
|