"""Segment water bodies in very large remote-sensing images with SAM3.

The image is processed region by region to bound memory use.

Pipeline:
1. Split the image into several overlapping sub-regions.
2. For each sub-region independently:
   - run coarse segmentation;
   - build an edge band around the coarse mask;
   - run fine segmentation on the tiles that touch that band.
3. Merge all sub-region masks into the full-size mask (overlaps take the
   maximum).
4. Post-process:
   - fill internal NoData holes;
   - remove small fragments;
   - optionally keep only the largest connected component.
5. Save the result.

Advantage: GPU work never operates on a full-image probability map, which
greatly reduces GPU memory use.
"""
import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor


# ---------- Utility functions ----------
def binary_dilate(mask, radius):
    """Dilate a 4D (N, C, H, W) mask with a square (2*radius+1) window.

    Returns ``mask`` unchanged when ``radius <= 0``. Otherwise returns a bool
    tensor of the same shape.

    A square max filter is separable. It is applied as a horizontal 1D
    max-pool followed by a vertical one. The result is identical to a single
    (k, k) pool, but the work per pixel is O(radius) instead of O(radius**2).
    This matters for large radii on full-region tensors.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    x = mask.float()
    x = F.max_pool2d(x, kernel_size=(1, kernel), stride=1, padding=(0, radius))
    x = F.max_pool2d(x, kernel_size=(kernel, 1), stride=1, padding=(radius, 0))
    return x > 0.5


def binary_erode(mask, radius):
    """Erode a bool 4D mask as the complement of dilating its complement.

    ``max_pool2d`` pads with -inf, so pixels outside the image never count as
    background. As a result, the image border itself does not erode the mask.
    """
    if radius <= 0:
        return mask
    return ~binary_dilate(~mask, radius)


def combine_masks_logits(masks_logits):
    """Collapse per-instance masks into one 2D map by taking the per-pixel max.

    Returns None when there are no masks. Presumably the input is shaped
    (N, 1, H, W) as produced by the SAM3 processor (TODO confirm).
    """
    if masks_logits.numel() == 0:
        return None
    probs = masks_logits.squeeze(1)
    if probs.dim() == 2:
        return probs
    return torch.amax(probs, dim=0)


def upsample_prob(prob, size):
    """Resize a 2D map to ``size`` = (H, W) with bilinear interpolation."""
    return F.interpolate(
        prob[None, None, ...], size=size, mode="bilinear", align_corners=False
    ).squeeze(0).squeeze(0)


def build_band(mask, radius):
    """Return the band of pixels within ``radius`` of the mask boundary.

    mask: 2D bool tensor (CPU or GPU).
    The band is dilation XOR erosion, i.e. pixels whose label can flip when
    the boundary moves by up to ``radius`` pixels.
    """
    mask_4d = mask[None, None, ...]  # (1, 1, H, W)
    dil = binary_dilate(mask_4d, radius)
    ero = binary_erode(mask_4d, radius)
    return torch.logical_xor(dil, ero).squeeze(0).squeeze(0)


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis.

    Stops as soon as a tile reaches the far edge. Later starts would only
    produce slivers that lie wholly inside the previous tile.
    """
    starts = []
    start = 0
    while start < length:
        starts.append(start)
        if start + tile_size >= length:
            break
        start += stride
    return starts


def tile_slices(height, width, tile_size, overlap):
    """Yield (top, left, bottom, right) for overlapping tiles covering the image.

    Tiles at the right and bottom edges are clipped to the image.
    """
    stride = max(tile_size - overlap, 1)
    for top in _tile_starts(height, tile_size, stride):
        for left in _tile_starts(width, tile_size, stride):
            yield top, left, min(top + tile_size, height), min(left + tile_size, width)


def _to_uint8(arr):
    """Stretch ``arr`` to uint8 using its 2nd/98th percentiles.

    All channels share one stretch. Non-finite values (e.g. NaN NoData) are
    ignored when computing the percentiles and map to 0. Without this, a
    single NaN pixel would turn the whole tile into garbage.

    NOTE(review): the stretch is computed per tile, so adjacent tiles may get
    different radiometry.
    """
    if arr.dtype == np.uint8:
        return arr
    arr = arr.astype(np.float32)
    finite = np.isfinite(arr)
    if not finite.any():
        return np.zeros(arr.shape, dtype=np.uint8)
    values = arr if finite.all() else arr[finite]
    vmin, vmax = np.percentile(values, (2.0, 98.0))
    if vmax <= vmin:
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - vmin) / (vmax - vmin)
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=1.0, neginf=0.0)
    return (np.clip(scaled, 0.0, 1.0) * 255.0).astype(np.uint8)


def _bands_to_pil(bands):
    """Convert a (C, H, W) band stack to an RGB PIL image.

    Uses the first three bands. Images with fewer than three bands
    (including 2-band images) are shown as grayscale built from the first
    band.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    if bands.shape[0] >= 3:
        rgb = bands[:3]
    else:
        rgb = np.repeat(bands[:1], 3, axis=0)
    rgb = np.transpose(_to_uint8(rgb), (1, 2, 0))
    # A uint8 (H, W, 3) array is inferred as "RGB". The ``mode`` argument of
    # fromarray is deprecated in recent Pillow releases.
    return Image.fromarray(np.ascontiguousarray(rgb))


def _read_pil_window(src, window):
    """Read ``window`` from ``src`` (out-of-bounds pixels filled with 0) as an RGB image."""
    bands = src.read(window=window, boundless=True, fill_value=0)
    return _bands_to_pil(bands)


def _coarse_shape(height, width, max_side):
    """Return (h, w) scaled so the longer side equals ``max_side``. Currently unused."""
    scale = max_side / float(max(height, width))
    h = max(int(round(height * scale)), 1)
    w = max(int(round(width * scale)), 1)
    return h, w


def _resolve_region(src, region):
    """Return (col_off, row_off, height, width) for an optional region.

    ``region`` is (left, top, right, bottom) in pixel coordinates of ``src``.
    None means the whole dataset.
    """
    if region is None:
        return 0, 0, src.height, src.width
    left, top, right, bottom = region
    return left, top, bottom - top, right - left


def _predict_window(processor, src, window, prompt):
    """Run the text-prompted processor on one window.

    Returns a 2D map resized to the window size, or None when nothing is
    detected.
    """
    crop = _read_pil_window(src, window)
    state = processor.set_image(crop)
    state = processor.set_text_prompt(prompt=prompt, state=state)
    tile_prob = combine_masks_logits(state["masks_logits"])
    if tile_prob is None:
        return None
    out_hw = (int(window.height), int(window.width))
    if tuple(tile_prob.shape[-2:]) != out_hw:
        tile_prob = upsample_prob(tile_prob, out_hw)
    return tile_prob


# ---------- Coarse segmentation (tiled inference) ----------
def coarse_tiles(processor, src, prompt, tile_size, overlap, region=None):
    """Run tiled coarse segmentation over ``src`` (or one ``region`` of it).

    src: open rasterio dataset.
    region: optional (left, top, right, bottom) pixel window of ``src``.
        None processes the whole dataset.
    Returns a float32 map on ``processor.device`` with the region's shape.
    Overlapping tiles are merged with a per-pixel maximum.
    """
    col_off, row_off, height, width = _resolve_region(src, region)
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)
    tiles = list(tile_slices(height, width, tile_size, overlap))

    # leave=True keeps the finished bar visible as a history record.
    with tqdm(total=len(tiles), desc="粗分割分块", unit="块", leave=True) as pbar:
        for top, left, bottom, right in tiles:
            window = Window(col_off + left, row_off + top, right - left, bottom - top)
            tile_prob = _predict_window(processor, src, window, prompt)
            if tile_prob is not None:
                full_probs[top:bottom, left:right] = torch.maximum(
                    full_probs[top:bottom, left:right], tile_prob
                )
            pbar.update(1)
    return full_probs


# ---------- Refinement (tiled) ----------
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap, region=None):
    """Run tiled segmentation only on tiles that intersect the edge band.

    band_coarse: 2D bool tensor marking where refinement is needed. It may
        have a lower resolution than the region; tile bounds are scaled to
        its grid.
    region: optional (left, top, right, bottom) pixel window of ``src``.
        None processes the whole dataset.
    Returns a float32 map on ``processor.device`` with the region's shape.
    Skipped tiles stay 0.
    """
    col_off, row_off, height, width = _resolve_region(src, region)
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)

    band_h, band_w = band_coarse.shape
    scale_y = band_h / float(height)
    scale_x = band_w / float(width)
    tiles = list(tile_slices(height, width, tile_size, overlap))

    # leave=True keeps the finished bar visible as a history record.
    with tqdm(total=len(tiles), desc="精修分块", unit="块", leave=True) as pbar:
        for top, left, bottom, right in tiles:
            c_top = int(top * scale_y)
            c_left = int(left * scale_x)
            c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
            c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
            if band_coarse[c_top:c_bottom, c_left:c_right].any():
                window = Window(col_off + left, row_off + top, right - left, bottom - top)
                tile_prob = _predict_window(processor, src, window, prompt)
                if tile_prob is not None:
                    full_probs[top:bottom, left:right] = torch.maximum(
                        full_probs[top:bottom, left:right], tile_prob
                    )
            pbar.update(1)
    return full_probs


# ---------- Post-processing (area filter) ----------
def filter_by_area(mask, min_area=None, keep_largest_only=False):
    """Filter the connected components of a binary mask by pixel area.

    mask: np.ndarray (bool or uint8, 0/1) of shape (H, W).
    min_area: components smaller than this many pixels are removed.
        None disables the threshold.
    keep_largest_only: keep only the largest component. ``min_area`` still
        applies to it.
    Returns the filtered mask as uint8 (0/1).
    """
    labeled, num = ndimage.label(mask)
    if num == 0:
        return mask.astype(np.uint8)
    sizes = np.bincount(labeled.ravel(), minlength=num + 1)
    sizes[0] = 0  # label 0 is background, never a candidate
    # keep[label] is a lookup table, which is much cheaper than np.isin on
    # large images.
    keep = np.zeros(num + 1, dtype=bool)
    if keep_largest_only:
        largest = int(np.argmax(sizes))
        if min_area is None or sizes[largest] >= min_area:
            keep[largest] = True
    elif min_area is None:
        keep[1:] = True
    else:
        keep[1:] = sizes[1:] >= min_area
    return keep[labeled].astype(np.uint8)


# ---------- Region splitting ----------
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
    """Split the image into a grid of overlapping sub-regions.

    Returns a list of (left, top, right, bottom) pixel coordinates relative
    to the full image. Split counts below 1 are treated as 1. Degenerate
    (empty) regions are dropped.
    """
    num_splits_x = max(int(num_splits_x), 1)
    num_splits_y = max(int(num_splits_y), 1)
    tile_w = (width + num_splits_x - 1) // num_splits_x
    tile_h = (height + num_splits_y - 1) // num_splits_y
    regions = []
    for i in range(num_splits_y):
        for j in range(num_splits_x):
            left = max(j * tile_w - overlap, 0)
            top = max(i * tile_h - overlap, 0)
            right = min((j + 1) * tile_w + overlap, width)
            bottom = min((i + 1) * tile_h + overlap, height)
            if right > left and bottom > top:
                regions.append((left, top, right, bottom))
    return regions


# ---------- Configuration ----------
matplotlib.use("TkAgg")

image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV1.tif"
prompt = "water body"
coarse_read_max_side = 768  # unused; kept for reference
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 2048  # refinement tile size
overlap = 256  # refinement tile overlap
coarse_tile_size = 4096  # coarse tile size
coarse_overlap = 256  # coarse tile overlap

# ========== Region parameters ==========
num_splits_y = 2  # vertical splits
num_splits_x = 3  # horizontal splits (2x3 = 6 regions in total)
region_overlap = 256  # overlap between neighbouring regions, in pixels

# ========== Post-processing parameters ==========
min_area = 1000  # components smaller than this (pixels) are removed; 0 or None disables
keep_largest_only = False  # keep only the largest connected component


def _internal_nodata_mask(band, nodata):
    """Return the NoData pixels of ``band`` that are not connected to the image border."""
    if np.issubdtype(band.dtype, np.floating):
        if np.isnan(nodata):
            # np.isclose(x, nan) is always False, so NaN NoData needs np.isnan.
            nodata_mask = np.isnan(band)
        else:
            nodata_mask = np.isclose(band, float(nodata))
    else:
        nodata_mask = band == nodata
    labeled, num = ndimage.label(nodata_mask)
    border_labels = np.concatenate(
        (labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1])
    )
    touches_border = np.zeros(num + 1, dtype=bool)
    touches_border[border_labels] = True
    # Label 0 (valid data) may be marked too, but "& nodata_mask" excludes it.
    return nodata_mask & ~touches_border[labeled]


def _segment_region(src, region, coarse_processor, fine_processor):
    """Run coarse + band refinement on one region and return its uint8 (0/1) mask.

    Everything runs under inference mode, and all GPU tensors are local to
    this call. They are released before the next region allocates its maps,
    so peak GPU memory covers a single region.
    """
    with torch.inference_mode():
        coarse_prob = coarse_tiles(
            coarse_processor, src, prompt,
            tile_size=coarse_tile_size, overlap=coarse_overlap, region=region,
        )
        # Build the edge band directly on the device.
        band = build_band(coarse_prob > coarse_threshold, band_radius)
        fine_probs = refine_tiles(
            fine_processor, src, prompt, band,
            tile_size=tile_size, overlap=overlap, region=region,
        )
        del band
        # Equivalent to max(fine, coarse) > threshold, without another float map.
        # NOTE(review): fine probabilities are merged over whole refined tiles,
        # not only inside the band. Refinement can add water anywhere in such a
        # tile and never removes coarse positives; confirm this is intended.
        final_mask = (fine_probs > final_threshold) | (coarse_prob > final_threshold)
        del fine_probs, coarse_prob
        return final_mask.cpu().numpy().astype(np.uint8)


def main():
    """Segment the configured image region by region, post-process and save the mask."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")

    # One model shared by both processors and all regions.
    model = build_sam3_image_model().to(device).eval()
    coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
    fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)

    with rasterio.open(image_path) as src:
        nodata = src.nodata
        print(f"原始影像NoData值: {nodata}")
        height, width = src.height, src.width

        regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
        print(f"共划分 {len(regions)} 个子区域")

        # Full-size mask kept in CPU memory.
        full_mask = np.zeros((height, width), dtype=np.uint8)

        with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
            for idx, (left, top, right, bottom) in enumerate(regions):
                print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
                # Tiles are read straight from the source file with offset
                # windows. No per-region in-memory copy of the bands is made.
                sub_mask = _segment_region(
                    src, (left, top, right, bottom), coarse_processor, fine_processor
                )
                # Masks are 0/1, so OR equals the element-wise maximum in overlaps.
                full_mask[top:bottom, left:right] |= sub_mask
                region_pbar.update(1)

        # ========== Post-processing ==========
        print("\n后处理:填充内部NoData空洞...")
        if nodata is not None:
            internal_nodata_mask = _internal_nodata_mask(src.read(1), nodata)
            if internal_nodata_mask.any():
                struct = ndimage.generate_binary_structure(2, 2)
                internal_dilated = ndimage.binary_dilation(
                    internal_nodata_mask, structure=struct, iterations=7
                )
                full_mask[internal_dilated] = 1
                print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
            else:
                print(" 无内部NoData区域")

        print("后处理:面积过滤...")
        # min_area of 0/None means "no area filtering", so skip the costly
        # labelling pass unless a filter is actually requested.
        if min_area or keep_largest_only:
            original_count = np.sum(full_mask)
            full_mask = filter_by_area(
                full_mask, min_area=min_area or None, keep_largest_only=keep_largest_only
            )
            filtered_count = np.sum(full_mask)
            print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")

        # ========== Save ==========
        profile = src.profile.copy()
        # Source photometric (e.g. YCbCr/RGB) is invalid for a 1-band uint8 mask.
        profile.pop("photometric", None)
        # nodata=None: do not inherit the source NoData. A source value of 0
        # would flag every non-water pixel as NoData, and values outside the
        # uint8 range make rasterio refuse to create the file.
        profile.update(driver="GTiff", count=1, dtype="uint8", compress="lzw", nodata=None)
        with rasterio.open(mask_output_path, "w", **profile) as dst:
            dst.write(full_mask, 1)

    print(f"分割完成,结果已保存至:{mask_output_path}")


if __name__ == "__main__":
    main()