first commit
This commit is contained in:
290
test.py
Normal file
290
test.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""
|
||||
批量配准 .bip 文件到参考 .tif 文件
|
||||
使用地理信息进行粗定位,然后用图像匹配进行精配准
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2
|
||||
import rasterio
|
||||
from rasterio.windows import from_bounds
|
||||
from rasterio.warp import transform_bounds, reproject, Resampling
|
||||
from affine import Affine
|
||||
from vismatch import get_matcher
|
||||
import logging
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------- 配置 ----------
|
||||
# 请根据实际情况修改这些路径
|
||||
REF_TIF = r"E:\is2\jiashixian\result.tif" # 参考 tif 文件路径
|
||||
BIP_DIR = Path(r"E:\is2\jiashixian\Geoout\1") # .bip 文件所在文件夹
|
||||
OUT_DIR = Path(r"E:\is2\jiashixian\matchanything-roma") # 输出文件夹
|
||||
|
||||
# 匹配算法选择
|
||||
MATCHER_NAME = "matchanything-roma" # 可选: xfeat-star, loftr, roma, superpoint-lightglue, sift-lightglue 等
|
||||
DEVICE = "cuda" # 或 "cpu"
|
||||
|
||||
# 匹配参数
|
||||
MATCH_MAX_SIDE = 1200 # 匹配时最大边长(像素)
|
||||
ROI_PAD_PX = 300 # 粗定位窗口的padding(参考tif像素)
|
||||
|
||||
# 质量控制阈值
|
||||
MIN_INLIERS = 30 # 最少内点数
|
||||
MIN_INLIER_RATIO = 0.15 # 最少内点比例
|
||||
|
||||
# 创建输出目录
|
||||
OUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---------- 工具函数 ----------
|
||||
def _to_3ch_float01(arr_chw: np.ndarray) -> np.ndarray:
|
||||
"""将任意通道数的数组转换为 (3,H,W) float32 in [0,1]"""
|
||||
arr = arr_chw.astype(np.float32)
|
||||
|
||||
if arr.shape[0] == 1:
|
||||
# 单波段复制为3通道
|
||||
arr = np.repeat(arr, 3, axis=0)
|
||||
elif arr.shape[0] >= 3:
|
||||
# 取前3波段
|
||||
arr = arr[:3]
|
||||
else:
|
||||
raise ValueError(f"不支持的通道数: {arr.shape[0]}")
|
||||
|
||||
# 百分位数拉伸,增强跨传感器匹配稳定性
|
||||
p2 = np.percentile(arr, 2)
|
||||
p98 = np.percentile(arr, 98)
|
||||
arr = (arr - p2) / (p98 - p2 + 1e-6)
|
||||
arr = np.clip(arr, 0.0, 1.0)
|
||||
return arr
|
||||
|
||||
def _downscale_chw(arr_chw: np.ndarray, max_side: int) -> np.ndarray:
|
||||
"""等比缩放 (C,H,W) 到 max(H,W) <= max_side"""
|
||||
c, h, w = arr_chw.shape
|
||||
s = min(1.0, max_side / max(h, w))
|
||||
if s >= 1.0:
|
||||
return arr_chw
|
||||
new_w = int(round(w * s))
|
||||
new_h = int(round(h * s))
|
||||
# 用opencv缩放(逐通道)
|
||||
out = np.stack([cv2.resize(arr_chw[i], (new_w, new_h), interpolation=cv2.INTER_AREA) for i in range(c)], axis=0)
|
||||
return out
|
||||
|
||||
def _expand_window(win, pad, max_w, max_h):
|
||||
"""扩展窗口并确保边界有效"""
|
||||
col_off = int(max(0, win.col_off - pad))
|
||||
row_off = int(max(0, win.row_off - pad))
|
||||
col_end = int(min(max_w, win.col_off + win.width + pad))
|
||||
row_end = int(min(max_h, win.row_off + win.height + pad))
|
||||
return rasterio.windows.Window(col_off, row_off, col_end - col_off, row_end - row_off)
|
||||
|
||||
def process_bip_to_tif(bip_path: Path, ref_dataset, matcher, out_dir: Path):
|
||||
"""处理单个 .bip 文件到参考 .tif 的配准"""
|
||||
try:
|
||||
with rasterio.open(bip_path) as src:
|
||||
logger.info(f"处理文件: {bip_path.name}")
|
||||
|
||||
# 检查CRS
|
||||
if src.crs is None:
|
||||
logger.warning(f"源文件 {bip_path.name} 缺少CRS信息,尝试使用参考文件的CRS")
|
||||
src_crs = ref_dataset.crs
|
||||
else:
|
||||
src_crs = src.crs
|
||||
|
||||
ref_crs = ref_dataset.crs
|
||||
if ref_crs is None:
|
||||
raise RuntimeError(f"参考文件缺少CRS信息")
|
||||
|
||||
# 1) 用地理信息把 src.bounds 转到 ref CRS,再裁 ref ROI
|
||||
b = transform_bounds(src_crs, ref_crs, *src.bounds, densify_pts=21)
|
||||
win0 = from_bounds(*b, transform=ref_dataset.transform)
|
||||
win = _expand_window(win0, ROI_PAD_PX, ref_dataset.width, ref_dataset.height)
|
||||
|
||||
if win.width <= 0 or win.height <= 0:
|
||||
logger.warning(f"无重叠区域: {bip_path.name}")
|
||||
return False
|
||||
|
||||
# 2) 读取数据
|
||||
# 读取所有波段,如果是多波段的话
|
||||
src_arr = src.read() # (bands, H, W)
|
||||
if src_arr.ndim == 2: # 单波段
|
||||
src_arr = src_arr[None, ...] # 增加波段维度
|
||||
|
||||
# 读取参考文件的ROI
|
||||
ref_arr = ref_dataset.read(window=win) # (bands, h, w)
|
||||
if ref_arr.ndim == 2: # 单波段
|
||||
ref_arr = ref_arr[None, ...] # 增加波段维度
|
||||
|
||||
# 转换为匹配所需的格式
|
||||
src_img = _to_3ch_float01(src_arr)
|
||||
ref_img = _to_3ch_float01(ref_arr)
|
||||
|
||||
# 3) 匹配用降采样版本,提速 + 增稳
|
||||
src_small = _downscale_chw(src_img, MATCH_MAX_SIDE)
|
||||
ref_small = _downscale_chw(ref_img, MATCH_MAX_SIDE)
|
||||
|
||||
logger.info(f"匹配尺寸: src {src_small.shape[1:]} -> ref {ref_small.shape[1:]}")
|
||||
|
||||
# 4) 精配准(img0=src, img1=ref_roi)
|
||||
result = matcher(src_small, ref_small)
|
||||
|
||||
num_inl = int(result["num_inliers"])
|
||||
num_m = len(result["matched_kpts0"])
|
||||
ratio = (num_inl / num_m) if num_m else 0.0
|
||||
|
||||
logger.info(f"匹配结果: 内点={num_inl}, 匹配点={num_m}, 内点比例={ratio:.2f}")
|
||||
|
||||
if num_inl < MIN_INLIERS or ratio < MIN_INLIER_RATIO:
|
||||
logger.warning(f"匹配质量不足: {bip_path.name}")
|
||||
return False
|
||||
|
||||
# 5) 用内点估计仿射变换
|
||||
k0 = result["inlier_kpts0"].astype(np.float32)
|
||||
k1 = result["inlier_kpts1"].astype(np.float32)
|
||||
|
||||
A, _ = cv2.estimateAffinePartial2D(k0, k1, method=cv2.RANSAC, ransacReprojThreshold=3.0)
|
||||
if A is None:
|
||||
logger.warning(f"仿射估计失败: {bip_path.name}")
|
||||
return False
|
||||
|
||||
# 6) 把"src_small->ref_small"的仿射映射回"src_full->ref_full_roi"
|
||||
# 缩放系数:full/small
|
||||
s0x = src_img.shape[2] / src_small.shape[2]
|
||||
s0y = src_img.shape[1] / src_small.shape[1]
|
||||
s1x = ref_img.shape[2] / ref_small.shape[2]
|
||||
s1y = ref_img.shape[1] / ref_small.shape[1]
|
||||
|
||||
S0 = np.array([[1/s0x, 0, 0],
|
||||
[0, 1/s0y, 0],
|
||||
[0, 0, 1]], dtype=np.float64) # full -> small
|
||||
S1_inv = np.array([[s1x, 0, 0],
|
||||
[0, s1y, 0],
|
||||
[0, 0, 1]], dtype=np.float64) # small -> full(roi)
|
||||
|
||||
A3 = np.eye(3, dtype=np.float64)
|
||||
A3[:2, :] = A # small src -> small ref_roi
|
||||
|
||||
# full_src -> full_ref_roi(用 S0,而不是 S0_inv)
|
||||
M_full = S1_inv @ A3 @ S0
|
||||
|
||||
# 7) 目标输出:重采样到 ref 全图网格
|
||||
# 需要"src像素 -> 地图坐标"的修正 transform
|
||||
T_off = np.array([[1, 0, win.col_off],
|
||||
[0, 1, win.row_off],
|
||||
[0, 0, 1]], dtype=np.float64)
|
||||
|
||||
# ref_transform 转 3x3
|
||||
ref_transform = ref_dataset.transform
|
||||
Rt = np.array([[ref_transform.a, ref_transform.b, ref_transform.c],
|
||||
[ref_transform.d, ref_transform.e, ref_transform.f],
|
||||
[0, 0, 1]], dtype=np.float64)
|
||||
|
||||
src_pixel_to_map_corrected = Rt @ T_off @ M_full
|
||||
|
||||
# 转回 Affine
|
||||
corrected_affine = Affine(
|
||||
src_pixel_to_map_corrected[0, 0], src_pixel_to_map_corrected[0, 1], src_pixel_to_map_corrected[0, 2],
|
||||
src_pixel_to_map_corrected[1, 0], src_pixel_to_map_corrected[1, 1], src_pixel_to_map_corrected[1, 2],
|
||||
)
|
||||
|
||||
# 8) 重采样输出到 ref 网格(多波段,BIP 输出)
|
||||
out_path = out_dir / f"{bip_path.stem}_registered.bip"
|
||||
|
||||
# 获取 NoData 值
|
||||
src_nodata = src.nodata # 可能是 None、0、65535 等
|
||||
dst_nodata = src_nodata if src_nodata is not None else 0
|
||||
|
||||
# 以参考影像的网格/坐标为目标,波段数和数据类型继承源影像
|
||||
out_profile = ref_dataset.profile.copy()
|
||||
out_profile.update(
|
||||
driver="ENVI",
|
||||
dtype=src.dtypes[0], # 尽量保持与源影像一致的数据类型
|
||||
height=ref_dataset.height,
|
||||
width=ref_dataset.width,
|
||||
count=src.count, # 多波段
|
||||
transform=ref_transform,
|
||||
crs=ref_crs,
|
||||
interleave="bip", # 指定 BIP
|
||||
compress=None,
|
||||
nodata=dst_nodata # 设置 NoData 值
|
||||
)
|
||||
|
||||
with rasterio.open(out_path, "w", **out_profile) as out_ds:
|
||||
for b in range(1, src.count + 1):
|
||||
# 逐波段重投影(float32 计算更稳)
|
||||
src_band = src.read(b).astype(np.float32)
|
||||
dst_band = np.zeros((ref_dataset.height, ref_dataset.width), dtype=np.float32)
|
||||
|
||||
reproject(
|
||||
source=src_band,
|
||||
destination=dst_band,
|
||||
src_transform=corrected_affine, # 由原逻辑推导的 src像素->ref像素 的仿射
|
||||
src_crs=ref_crs,
|
||||
dst_transform=ref_transform,
|
||||
dst_crs=ref_crs,
|
||||
src_nodata=src_nodata, # 新增 NoData 处理
|
||||
dst_nodata=dst_nodata, # 新增 NoData 处理
|
||||
resampling=Resampling.bilinear,
|
||||
)
|
||||
|
||||
# 调试统计信息
|
||||
logger.info(f"波段 {b}: min={float(np.nanmin(dst_band)):.2f}, max={float(np.nanmax(dst_band)):.2f}, "
|
||||
f"mean={float(np.nanmean(dst_band)):.2f}, nodata占比={(dst_band==dst_nodata).mean():.2%}")
|
||||
|
||||
# 转回目标 dtype(若为整型,先裁剪再转型)
|
||||
if np.issubdtype(np.dtype(out_profile["dtype"]), np.integer):
|
||||
if src_nodata is not None:
|
||||
# 保持 nodata 不被拉伸
|
||||
mask = (dst_band == dst_nodata)
|
||||
info = np.iinfo(out_profile["dtype"])
|
||||
dst_band = np.clip(dst_band, info.min, info.max)
|
||||
dst_band = dst_band.astype(out_profile["dtype"])
|
||||
if src_nodata is not None:
|
||||
dst_band[mask] = dst_nodata
|
||||
else:
|
||||
dst_band = dst_band.astype(out_profile["dtype"])
|
||||
|
||||
out_ds.write(dst_band, b)
|
||||
|
||||
logger.info(f"成功配准: {bip_path.name} -> {out_path.name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理失败 {bip_path.name}: {str(e)}")
|
||||
return False
|
||||
|
||||
# ---------- 主逻辑 ----------
|
||||
def main():
|
||||
logger.info("开始批量配准处理...")
|
||||
|
||||
# 检查输入文件是否存在
|
||||
if not Path(REF_TIF).exists():
|
||||
logger.error(f"参考文件不存在: {REF_TIF}")
|
||||
return
|
||||
|
||||
if not BIP_DIR.exists():
|
||||
logger.error(f"BIP文件夹不存在: {BIP_DIR}")
|
||||
return
|
||||
|
||||
# 初始化匹配器
|
||||
logger.info(f"初始化匹配器: {MATCHER_NAME} on {DEVICE}")
|
||||
matcher = get_matcher(MATCHER_NAME, device=DEVICE)
|
||||
|
||||
# 打开参考文件
|
||||
with rasterio.open(REF_TIF) as ref:
|
||||
logger.info(f"参考文件信息: {ref.width}x{ref.height}, CRS: {ref.crs}")
|
||||
|
||||
# 查找所有 .bip 文件
|
||||
bip_files = list(BIP_DIR.glob("*.bip"))
|
||||
logger.info(f"找到 {len(bip_files)} 个 .bip 文件")
|
||||
|
||||
success_count = 0
|
||||
for bip_path in bip_files:
|
||||
if process_bip_to_tif(bip_path, ref, matcher, OUT_DIR):
|
||||
success_count += 1
|
||||
|
||||
logger.info(f"处理完成: {success_count}/{len(bip_files)} 个文件成功配准")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user