Files
StripStitch/test V4.py
2026-03-06 17:24:55 +08:00

468 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量配准 .bip 文件到参考 .tif 文件
直接进行配准
"""
from pathlib import Path
import numpy as np
import cv2
import rasterio
from rasterio.windows import from_bounds
from rasterio.warp import transform_bounds, reproject, Resampling
from affine import Affine
from vismatch import get_matcher
import logging
# Optional third-party dependencies. Each import is wrapped in try/except ImportError so the
# script still starts when a package is missing; the *_AVAILABLE flag records the outcome.
# NOTE(review): none of the names imported here (PiecewiseAffineTransform,
# PolynomialTransform, MplPath, ConvexHull, sitk, pirt) nor the *_AVAILABLE flags are used
# in the visible part of this file, so the fallbacks described in the warning messages do
# not appear to be implemented here -- confirm before relying on them.
# NOTE(review): these logging.warning() calls go through the root logger and run before
# the logging.basicConfig() call further down. If one fires, logging installs a default
# root handler at WARNING level, and a later basicConfig() without force=True is then a
# no-op, which hides all INFO-level progress messages.
try:
    from skimage.transform import PiecewiseAffineTransform, PolynomialTransform
    SKIMAGE_AVAILABLE = True
except ImportError:
    SKIMAGE_AVAILABLE = False
    logging.warning("scikit-image 不可用,将跳过 piecewise_affine 和 polynomial 变换")
try:
    from matplotlib.path import Path as MplPath
    from scipy.spatial import ConvexHull
    MATPLOTLIB_SCIPY_AVAILABLE = True
except ImportError:
    MATPLOTLIB_SCIPY_AVAILABLE = False
    # Keep the MplPath name defined (as None); ConvexHull stays undefined on this path.
    MplPath = None
    logging.warning("matplotlib 或 scipy 不可用piecewise_affine 将退化为矩形内判断")
try:
    import SimpleITK as sitk
    SITK_AVAILABLE = True
except ImportError:
    SITK_AVAILABLE = False
    logging.warning("SimpleITK 不可用,将使用仿射变换作为替代")
try:
    import pirt
    PIRT_AVAILABLE = True
except ImportError:
    PIRT_AVAILABLE = False
    logging.warning("PIRT 不可用,将使用 SimpleITK TPS 作为替代")
# Configure logging for the script.
# force=True (Python 3.8+) is required: the optional-import fallbacks above call the
# module-level logging.warning(), which implicitly installs a default root handler at
# WARNING level when it runs first. A plain basicConfig() would then do nothing and every
# INFO-level progress message below would be silently dropped.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)
# ---------- Configuration ----------
# Adjust these paths to your environment.
REF_TIF = r"E:\is2\yaopu\result.tif"  # path of the reference tif
BIP_DIR = Path(r"E:\is2\yaopu")  # folder containing the .bip files
OUT_DIR = Path(r"E:\is2\yaopu\output")  # output folder
# Matching algorithm (name passed to vismatch.get_matcher).
MATCHER_NAME = "matchanything-roma"  # options include: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue, ...
DEVICE = "cuda"  # or "cpu"
# Registration is meant to use the dense flow of a dense matching model directly when one is available.
# Matching parameters
MATCH_MAX_SIDE = 1200  # images are downscaled so their longest side is at most this many pixels before matching
ROI_PAD_PX = 500  # padding, in reference-tif pixels, added around the coarse-localisation window
# Quality-control thresholds
MIN_INLIERS = 10  # minimum number of inliers
MIN_INLIER_RATIO = 0.01  # minimum inlier ratio (inliers / matched points)
# Create the output directory (runs at import time).
OUT_DIR.mkdir(parents=True, exist_ok=True)
# ---------- 工具函数 ----------
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
arr = arr_chw.astype(np.float32)
if arr.shape[0] == 1:
# 单波段复制为3通道
arr = np.repeat(arr, 3, axis=0)
elif arr.shape[0] >= 3:
# 取前3波段
arr = arr[:3]
else:
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
# 百分位数拉伸,增强跨传感器匹配稳定性
p2 = np.percentile(arr, 2)
p98 = np.percentile(arr, 98)
arr = (arr - p2) / (p98 - p2 + 1e-6)
arr = np.clip(arr, 0.0, 1.0)
return arr
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
c, h, w = arr_chw.shape
s = min(1.0, max_side / max(h, w))
if s >= 1.0:
return arr_chw
new_w = int(round(w * s))
new_h = int(round(h * s))
# 用opencv缩放(逐通道)
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
return out
def _expand_window(win, pad, max_w, max_h):
    """Grow a (possibly fractional) window by `pad` pixels and clip it to the raster.

    Start offsets are floored and end offsets ceiled, so the returned integer
    window always covers the whole input window. Plain int() truncation of the
    end offset could drop a partially covered last row/column.

    Width and height are clamped at 0 when the padded window lies completely
    outside the raster. Callers detect "no overlap" with `width <= 0` / `height <= 0`.

    Args:
        win: rasterio Window, e.g. as returned by rasterio.windows.from_bounds.
        pad: Padding in pixels added on every side.
        max_w: Raster width, used as the upper clip limit for columns.
        max_h: Raster height, used as the upper clip limit for rows.

    Returns:
        rasterio.windows.Window with integer offsets and sizes.
    """
    col_off = max(0, int(np.floor(win.col_off - pad)))
    row_off = max(0, int(np.floor(win.row_off - pad)))
    col_end = min(max_w, int(np.ceil(win.col_off + win.width + pad)))
    row_end = min(max_h, int(np.ceil(win.row_off + win.height + pad)))
    return rasterio.windows.Window(
        col_off, row_off, max(0, col_end - col_off), max(0, row_end - row_off)
    )
def _read_match_bands(dataset, window=None) -> np.ndarray:
    """Read only the bands `_to_3ch_float01` can use, as a (bands, H, W) array.

    `_to_3ch_float01` replicates a single band or keeps only the first three, so
    reading every band of a many-band .bip cube would only waste memory. A
    two-band dataset is still read as two bands, so `_to_3ch_float01` rejects it
    with its usual ValueError.
    """
    indexes = list(range(1, min(dataset.count, 3) + 1))
    return dataset.read(indexes, window=window)


def _to_gray_u8(img_chw: np.ndarray) -> np.ndarray:
    """Convert a (3, H, W) float image in [0, 1] to a contiguous 8-bit grayscale image."""
    rgb = np.clip(np.nan_to_num(np.transpose(img_chw, (1, 2, 0)), nan=0.0), 0.0, 1.0)
    rgb_u8 = np.ascontiguousarray((rgb * 255).astype(np.uint8))
    return cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2GRAY)


def _find_dense_flow(result):
    """Return the first dense-flow entry in the matcher result, or None.

    Different matchers may use different key names, so several candidates are
    tried in order of preference (ref->src first).
    """
    for key in ("flow_ref2src", "flow21", "flow_1_0", "flow10", "flow"):
        if key in result:
            return result[key]
    return None


def _estimate_affine_ref2src(result):
    """Fit a 2x3 affine mapping ref_small pixel coords to src_small pixel coords.

    Uses inlier keypoints when the result provides at least three of them, and
    otherwise falls back to all matched keypoints. Returns None when no fit is
    possible.

    NOTE(review): assumes "*_kpts0" are in the pixel coordinates of the first
    image passed to the matcher (src) and "*_kpts1" of the second (ref); confirm
    against the vismatch result format.
    """
    for key0, key1 in (("inlier_kpts0", "inlier_kpts1"), ("matched_kpts0", "matched_kpts1")):
        kpts0 = result[key0] if key0 in result else None
        kpts1 = result[key1] if key1 in result else None
        if kpts0 is None or kpts1 is None:
            continue
        src_pts = np.asarray(kpts0, dtype=np.float32).reshape(-1, 2)
        ref_pts = np.asarray(kpts1, dtype=np.float32).reshape(-1, 2)
        if len(src_pts) < 3 or len(src_pts) != len(ref_pts):
            continue
        affine, _ = cv2.estimateAffine2D(
            ref_pts, src_pts, method=cv2.RANSAC, ransacReprojThreshold=3.0
        )
        if affine is not None:
            return affine
    return None


def _optical_flow(ref_gray: np.ndarray, mov_gray: np.ndarray):
    """Dense optical flow from ref_gray to mov_gray, two same-size uint8 images.

    Tries DIS first and Farneback second. Returns an (H, W, 2) float32 array of
    (dx, dy) such that ref(p) ~ mov(p + flow(p)), or None if both methods fail.
    """
    try:
        dis = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_MEDIUM)
        return dis.calc(ref_gray, mov_gray, None).astype(np.float32)
    except (cv2.error, AttributeError) as exc:
        logger.warning(f"DIS 光流失败, 改用 Farneback: {exc}")
    try:
        # Typical parameters; tune (window size, iterations, ...) to the imagery.
        return cv2.calcOpticalFlowFarneback(
            ref_gray, mov_gray, None, 0.5, 3, 25, 3, 5, 1.2, 0
        ).astype(np.float32)
    except cv2.error as exc:
        logger.warning(f"Farneback 光流失败: {exc}")
        return None


def _ref_to_src_small_positions(result, src_small: np.ndarray, ref_small: np.ndarray):
    """Map every ref_small pixel to its (x, y) position in src_small pixel coordinates.

    The position field is built in one of two ways:
    * From a dense flow in the matcher result, when present ("DenseFlow").
    * Otherwise from an affine fitted to the matched keypoints. src_small is
      resampled onto the ref_small grid with that affine so both images have the
      same size, and a residual optical flow refines the mapping
      ("仿射+光流"); without it, the plain affine is used ("仿射回退").

    Returns:
        (positions, method): positions is float32 of shape (h_s, w_s, 2), where
        (h_s, w_s) is the ref_small size. Returns (None, None) on failure; the
        reason is logged.
    """
    h_s, w_s = ref_small.shape[1:]
    grid_x, grid_y = np.meshgrid(
        np.arange(w_s, dtype=np.float32), np.arange(h_s, dtype=np.float32)
    )

    flow = _find_dense_flow(result)
    if flow is not None:
        # NOTE(review): assumes the matcher flow is (h, w, 2) with (dx, dy) in pixels of
        # its own grid, the grid covers ref_small, and p + flow(p) lies in src_small
        # pixel coordinates. Confirm for the chosen matcher.
        flow = np.asarray(flow, dtype=np.float32)
        if flow.ndim != 3 or flow.shape[2] != 2:
            logger.warning(f"稠密流形状异常: {flow.shape}")
            return None, None
        flow_h, flow_w = flow.shape[:2]
        if (flow_h, flow_w) != (h_s, w_s):
            # Bring the flow onto the ref_small grid; displacements scale with the grid.
            flow = cv2.resize(flow, (w_s, h_s), interpolation=cv2.INTER_LINEAR)
            flow[..., 0] *= w_s / flow_w
            flow[..., 1] *= h_s / flow_h
        return np.dstack([grid_x + flow[..., 0], grid_y + flow[..., 1]]), "DenseFlow"

    affine = _estimate_affine_ref2src(result)
    if affine is None:
        logger.warning("无稠密流且无法从匹配点估计仿射变换")
        return None, None

    # WARP_INVERSE_MAP: dst(p) = src(affine @ p), i.e. src_small resampled onto the
    # ref_small grid. Optical flow needs two images of the same size.
    warped_gray = cv2.warpAffine(
        _to_gray_u8(src_small),
        affine,
        (w_s, h_s),
        flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )
    residual = _optical_flow(_to_gray_u8(ref_small), warped_gray)
    if residual is None:
        ref_x, ref_y, method = grid_x, grid_y, "仿射回退"
    else:
        # ref(p) ~ warped(p + r) = src_small(affine @ (p + r))
        ref_x = grid_x + residual[..., 0]
        ref_y = grid_y + residual[..., 1]
        method = "仿射+光流"
    src_x = affine[0, 0] * ref_x + affine[0, 1] * ref_y + affine[0, 2]
    src_y = affine[1, 0] * ref_x + affine[1, 1] * ref_y + affine[1, 2]
    return np.dstack([src_x, src_y]).astype(np.float32), method


def _cast_to_dtype(values: np.ndarray, dtype) -> np.ndarray:
    """Cast resampled float values to the output dtype.

    Integer targets are rounded to the nearest integer and clipped to the
    dtype's range first. Out-of-range interpolation results therefore saturate
    instead of wrapping, and truncation does not bias values downwards.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        return np.clip(np.rint(values), info.min, info.max).astype(dtype)
    return values.astype(dtype)


def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
    """Register one .bip raster onto the reference raster and write it as ENVI/BIP.

    Steps:
    1. Project the source bounds into the reference CRS (the reference CRS is
       assumed when the source has none) and read a ROI of the reference padded
       by ROI_PAD_PX.
    2. Convert both images to 3-channel [0, 1] images, downscale them to
       MATCH_MAX_SIDE and run ``matcher(src_small, ref_small)``.
    3. Reject the pair when the inlier count or inlier ratio falls below
       MIN_INLIERS / MIN_INLIER_RATIO.
    4. Build, for each ref_small pixel, its position in src_small (see
       `_ref_to_src_small_positions`).
    5. Upsample that field to the full ROI and convert it to full-resolution
       source pixel coordinates. Then cv2.remap every source band onto the
       reference grid, cropped to the bounding box of pixels that map inside
       the source, and write ``<stem>_registered.bip`` into `out_dir`.

    Args:
        bip_path: Source raster to register.
        ref_dataset: Open rasterio dataset of the reference image.
        matcher: Callable returning a dict-like result with at least
            "num_inliers" and "matched_kpts0".
        out_dir: Output folder.

    Returns:
        True if the registered file was written. False if the file was skipped or
        processing failed; the reason is logged, with a traceback for exceptions.
    """
    try:
        with rasterio.open(bip_path) as src:
            logger.info(f"处理文件: {bip_path.name}")
            if src.crs is None:
                logger.warning(f"源文件 {bip_path.name} 缺少CRS信息尝试使用参考文件的CRS")
                src_crs = ref_dataset.crs
            else:
                src_crs = src.crs
            ref_crs = ref_dataset.crs
            if ref_crs is None:
                raise RuntimeError("参考文件缺少CRS信息")

            # 1) Coarse localisation: source bounds -> reference CRS -> padded reference ROI.
            ref_bounds = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
            win0 = from_bounds(*ref_bounds, transform=ref_dataset.transform)
            win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
            if win.width <= 0 or win.height <= 0:
                logger.warning(f"无重叠区域: {bip_path.name}")
                return False

            # 2) Read only the bands used for matching and normalise them.
            src_img = _to_3ch_float01(_read_match_bands(src))
            ref_img = _to_3ch_float01(_read_match_bands(ref_dataset, window=win))
            roi_h, roi_w = ref_img.shape[1:]  # ref_img is the ROI sub-image

            # 3) Match on downscaled copies, which is faster and more stable.
            src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
            ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
            del src_img, ref_img  # full-resolution copies are no longer needed
            logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")

            # 4) Fine matching: img0 = src, img1 = reference ROI.
            result = matcher(src_small, ref_small)
            num_inl = int(result["num_inliers"])
            num_m = len(result["matched_kpts0"])
            ratio = (num_inl / num_m) if num_m else 0.0
            logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
            if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
                logger.warning(f"匹配质量不足: {bip_path.name}")
                return False

            # 5) Position in src_small of every ref_small pixel.
            pos_small, method = _ref_to_src_small_positions(result, src_small, ref_small)
            if pos_small is None:
                return False

            # 6) Upsample to the full ROI grid and convert src_small -> full-resolution
            #    source pixels. cv2.resize aligns pixel centres, so the same
            #    (p + 0.5) * scale - 0.5 convention is used for the source scaling.
            scale_src_x = src.width / src_small.shape[2]
            scale_src_y = src.height / src_small.shape[1]
            pos_full = cv2.resize(pos_small, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR)
            map_x = (pos_full[..., 0] + 0.5) * scale_src_x - 0.5
            map_y = (pos_full[..., 1] + 0.5) * scale_src_y - 0.5
            del pos_full

            # 7) Bounding box of ROI pixels whose source position lies inside the source image.
            valid = (
                (map_x >= 0) & (map_x <= src.width - 1)
                & (map_y >= 0) & (map_y <= src.height - 1)
            )
            valid_rows = np.flatnonzero(valid.any(axis=1))
            valid_cols = np.flatnonzero(valid.any(axis=0))
            if valid_rows.size == 0:
                logger.warning(f"稠密流无有效映射: {bip_path.name}")
                return False
            min_y, max_y = int(valid_rows[0]), int(valid_rows[-1]) + 1
            min_x, max_x = int(valid_cols[0]), int(valid_cols[-1]) + 1
            crop_h, crop_w = max_y - min_y, max_x - min_x
            # Resample only inside the bounding box to limit memory use.
            map_x_crop = np.ascontiguousarray(map_x[min_y:max_y, min_x:max_x], dtype=np.float32)
            map_y_crop = np.ascontiguousarray(map_y[min_y:max_y, min_x:max_x], dtype=np.float32)
            del map_x, map_y, valid

            # 8) Georeference of the output: reference grid, offset to the cropped box.
            crop_window = rasterio.windows.Window(
                win.col_off + min_x, win.row_off + min_y, crop_w, crop_h
            )
            out_transform = ref_dataset.window_transform(crop_window)

            # 9) Write ENVI/BIP. Build a fresh profile instead of copying the reference
            #    profile, whose GTiff creation options do not apply to ENVI.
            out_path = out_dir / f"{bip_path.stem}_registered.bip"
            src_nodata = src.nodata
            dst_nodata = src_nodata if src_nodata is not None else 0
            out_profile = {
                "driver": "ENVI",
                "dtype": src.dtypes[0],
                "height": crop_h,
                "width": crop_w,
                "count": src.count,
                "transform": out_transform,
                "crs": ref_crs,
                "interleave": "bip",
                "nodata": dst_nodata,
            }
            # NOTE(review): bilinear remap blends valid pixels with source nodata pixels
            # near nodata edges, and OpenCV's remap limits image/map sides to under 32767 px.
            with rasterio.open(out_path, "w", **out_profile) as out_ds:
                for band_idx in range(1, src.count + 1):
                    band = src.read(band_idx).astype(np.float32)
                    # Inverse mapping: each output pixel samples the source at (map_y, map_x).
                    warped = cv2.remap(
                        band, map_x_crop, map_y_crop,
                        interpolation=cv2.INTER_LINEAR,
                        borderMode=cv2.BORDER_CONSTANT,
                        borderValue=float(dst_nodata),
                    )
                    out_ds.write(_cast_to_dtype(warped, out_profile["dtype"]), band_idx)

            logger.info(f"成功配准({method}): {bip_path.name} -> {out_path.name}")
            return True
    except Exception as e:
        # Top-level boundary for one file: log with traceback and continue the batch.
        logger.exception(f"处理失败 {bip_path.name}: {e}")
        return False
# ---------- 主逻辑 ----------
def main():
    """Register every ``*.bip`` file in BIP_DIR onto REF_TIF, writing results to OUT_DIR.

    Both input paths are checked first; if either is missing, an error is logged
    and the function returns early. The matcher is created once and reused for
    all files. Files are processed in sorted order so runs and logs are
    reproducible (``Path.glob`` yields entries in filesystem-dependent order).
    The search is not recursive.
    """
    logger.info("开始批量配准处理...")
    # Validate inputs before loading the (potentially expensive) matcher.
    if not Path(REF_TIF).exists():
        logger.error(f"参考文件不存在: {REF_TIF}")
        return
    if not BIP_DIR.is_dir():
        logger.error(f"BIP文件夹不存在: {BIP_DIR}")
        return
    logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
    matcher = get_matcher(MATCHER_NAME, device=DEVICE)
    # Keep the reference dataset open for the whole batch.
    with rasterio.open(REF_TIF) as ref:
        logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
        bip_files = sorted(BIP_DIR.glob("*.bip"))
        logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
        success_count = 0
        for bip_path in bip_files:
            if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
                success_count += 1
        logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")


if __name__ == "__main__":
    main()