import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm import tqdm
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor


# ---------- Mask / probability helpers ----------


def binary_dilate(mask, radius):
    """Dilate a boolean mask with a (2*radius+1) x (2*radius+1) square element.

    ``mask`` must be a tensor accepted by ``F.max_pool2d`` (3-D or 4-D, spatial
    axes last). A non-positive ``radius`` returns ``mask`` unchanged.

    A square max filter is separable, so it is applied as a vertical pass then a
    horizontal pass. The output is identical to one ``kernel x kernel`` max-pool,
    but costs O(kernel) instead of O(kernel**2) per pixel, which matters for
    full-resolution masks with large radii. max_pool2d pads with -inf, so pixels
    outside the image never count as foreground.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    pooled = mask.float()
    pooled = F.max_pool2d(pooled, kernel_size=(kernel, 1), stride=1, padding=(radius, 0))
    pooled = F.max_pool2d(pooled, kernel_size=(1, kernel), stride=1, padding=(0, radius))
    return pooled > 0.5


def binary_erode(mask, radius):
    """Erode a boolean mask with a (2*radius+1) square element (dual of dilation).

    Because dilation ignores out-of-image pixels, pixels outside the image act
    as foreground here: the mask is not eroded inward from the image border.
    """
    if radius <= 0:
        return mask
    return ~binary_dilate(~mask, radius)


def combine_masks_logits(masks_logits):
    """Merge per-instance masks into one (H, W) map by pixel-wise maximum.

    Presumably ``masks_logits`` is ``(N, 1, H, W)`` as stored by Sam3Processor in
    ``state["masks_logits"]`` -- TODO confirm. Despite the name, callers threshold
    the result at 0.5 and max-merge it onto a zero canvas, i.e. they treat the
    values as probabilities in [0, 1].

    Returns None when there are no instances.
    """
    if masks_logits.numel() == 0:
        return None
    probs = masks_logits.squeeze(1)
    if probs.dim() == 2:
        return probs
    return torch.amax(probs, dim=0)


def upsample_prob(prob, size):
    """Bilinearly resize a 2-D map ``prob`` to ``size`` = (height, width)."""
    return F.interpolate(
        prob[None, None, ...], size=size, mode="bilinear", align_corners=False
    ).squeeze(0).squeeze(0)


def build_band(mask, radius):
    """Return the edge band of a 2-D boolean mask: dilate(mask) XOR erode(mask)."""
    mask_4d = mask[None, None, ...]
    dil = binary_dilate(mask_4d, radius)
    ero = binary_erode(mask_4d, radius)
    return torch.logical_xor(dil, ero).squeeze(0).squeeze(0)


# ---------- Tiling and raster reading ----------


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis of ``length`` pixels.

    Tiles advance by ``stride``. The last tile is shifted back so it ends exactly
    at ``length``, instead of leaving a thin sliver tile past the previous one.
    Slivers are badly distorted when the model resizes them, and max-merging keeps
    any false positives they produce. An axis no longer than ``tile_size`` gets a
    single tile; an empty axis gets none.
    """
    if length <= 0:
        return []
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size, stride))
    starts.append(length - tile_size)
    return starts


def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` windows covering a height x width image.

    Consecutive tiles overlap by at least ``overlap`` pixels (the stride is
    ``max(tile_size - overlap, 1)``). Every tile is ``tile_size`` on each axis
    unless that image axis is smaller than ``tile_size``.

    Raises:
        ValueError: if ``tile_size`` is not positive (raised on first iteration).
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    stride = max(tile_size - overlap, 1)
    for top in _tile_starts(height, tile_size, stride):
        for left in _tile_starts(width, tile_size, stride):
            yield top, left, min(top + tile_size, height), min(left + tile_size, width)


def _to_uint8(arr):
    """Contrast-stretch ``arr`` to uint8 using its 2nd/98th percentiles.

    uint8 input is returned unchanged. NaN pixels (e.g. masked NoData) are
    excluded from the percentile statistics and rendered as 0. Empty, all-NaN, or
    constant input yields all zeros.
    """
    if arr.dtype == np.uint8:
        return arr
    # The arithmetic below never writes into arr, so avoiding a copy is safe.
    arr = arr.astype(np.float32, copy=False)
    if arr.size == 0 or np.isnan(arr).all():
        return np.zeros(arr.shape, dtype=np.uint8)
    vmin, vmax = np.nanpercentile(arr, [2.0, 98.0])
    if not vmax > vmin:
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - vmin) / (vmax - vmin)
    scaled = np.nan_to_num(scaled, nan=0.0)
    scaled = np.clip(scaled, 0.0, 1.0)
    return (scaled * 255.0).astype(np.uint8)


def _bands_to_pil(bands):
    """Convert a (C, H, W) band stack to an RGB PIL image.

    The first three bands are used as R, G, B. NOTE(review): this assumes the
    raster's band order starts with R, G, B -- confirm for the input data. With
    fewer than three bands, the first band is replicated as grayscale.

    Raises:
        ValueError: if ``bands`` is not 3-D.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    if bands.shape[0] >= 3:
        rgb = bands[:3]
    else:
        # Covers 1-band and 2-band rasters (a 2-band slice is not valid RGB).
        rgb = np.repeat(bands[:1], 3, axis=0)
    rgb = _to_uint8(rgb)
    # An (H, W, 3) uint8 array is inferred as mode "RGB".
    return Image.fromarray(np.transpose(rgb, (1, 2, 0)))


def _nodata_mask(values, nodata):
    """Return a boolean array marking elements of ``values`` equal to ``nodata``.

    Floating data is compared with ``np.isclose``. NaN NoData is matched with
    ``np.isnan``, because NaN never compares equal (or close) to itself.
    """
    if np.issubdtype(values.dtype, np.floating):
        if np.isnan(float(nodata)):
            return np.isnan(values)
        return np.isclose(values, float(nodata))
    return values == nodata


def _read_pil_window(src, window):
    """Read all bands of ``window`` from ``src`` and return them as an RGB image.

    For non-uint8 rasters, pixels equal to ``src.nodata`` are turned into NaN.
    This keeps them out of the contrast stretch: a fill value such as -9999
    would otherwise dominate the low percentile and flatten the whole tile.
    """
    bands = src.read(window=window, boundless=True, fill_value=0)
    nodata = src.nodata
    if nodata is not None and bands.dtype != np.uint8:
        invalid = _nodata_mask(bands, nodata)
        if invalid.any():
            bands = bands.astype(np.float32)
            bands[invalid] = np.nan
    return _bands_to_pil(bands)


def _coarse_shape(height, width, max_side):
    """Scale (height, width) so the longer side is ``max_side`` (each side >= 1)."""
    scale = max_side / float(max(height, width))
    h = max(int(round(height * scale)), 1)
    w = max(int(round(width * scale)), 1)
    return h, w


# ---------- Tiled inference ----------


def _predict_tile(processor, src, prompt, top, left, bottom, right):
    """Run one text-prompted SAM3 prediction on a raster window.

    Returns a (bottom - top, right - left) map, or None when the model returned
    no masks.
    """
    window = Window(left, top, right - left, bottom - top)
    crop = _read_pil_window(src, window)
    state = processor.set_image(crop)
    state = processor.set_text_prompt(prompt=prompt, state=state)
    tile_prob = combine_masks_logits(state["masks_logits"])
    if tile_prob is None:
        return None
    size = (bottom - top, right - left)
    if tile_prob.shape[-2:] != size:
        tile_prob = upsample_prob(tile_prob, size)
    return tile_prob


def _run_tiles(processor, src, prompt, tile_size, overlap, desc, should_skip=None):
    """Predict every tile of ``src`` and max-merge the results into a full-size map.

    Returns a float32 (src.height, src.width) tensor on ``processor.device``.
    ``should_skip(top, left, bottom, right)`` may return True to skip a tile. A
    skipped tile's area stays 0 unless another tile covers it.
    """
    height, width = src.height, src.width
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)
    tiles = list(tile_slices(height, width, tile_size, overlap))
    for top, left, bottom, right in tqdm(tiles, desc=desc, unit="块"):
        if should_skip is not None and should_skip(top, left, bottom, right):
            continue
        tile_prob = _predict_tile(processor, src, prompt, top, left, bottom, right)
        if tile_prob is None:
            continue
        tile_prob = tile_prob.to(full_probs.device, torch.float32)
        # Overlapping tiles are combined by pixel-wise maximum.
        full_probs[top:bottom, left:right] = torch.maximum(
            full_probs[top:bottom, left:right], tile_prob
        )
    return full_probs


def coarse_tiles(processor, src, prompt, tile_size, overlap):
    """Run tiled coarse segmentation over the whole raster.

    Returns a full-size float32 probability map on ``processor.device``.
    """
    return _run_tiles(processor, src, prompt, tile_size, overlap, desc="粗分割分块")


def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
    """Re-segment only the tiles that intersect the edge band ``band_coarse``.

    ``band_coarse`` is a 2-D boolean mask, possibly at a lower resolution than
    ``src``. Tile coordinates are scaled into its grid to test for intersection.
    Tiles that miss the band are skipped and stay 0 in the returned full-size
    float32 map.
    """
    height, width = src.height, src.width
    band_h, band_w = band_coarse.shape
    scale_y = band_h / float(height)
    scale_x = band_w / float(width)

    def outside_band(top, left, bottom, right):
        # Map the tile into band coordinates; keep at least one row/column.
        c_top = int(top * scale_y)
        c_left = int(left * scale_x)
        c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
        c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
        return not band_coarse[c_top:c_bottom, c_left:c_right].any()

    return _run_tiles(
        processor, src, prompt, tile_size, overlap, desc="精修分块", should_skip=outside_band
    )


# ---------- Post-processing and output ----------


def _fill_internal_nodata(mask_np, band1, nodata, iterations=7):
    """Set enclosed NoData holes (plus a dilated rim) to 1 in ``mask_np``, in place.

    NoData regions touching the image border are left alone; only holes fully
    enclosed by valid data are filled. Labelling uses scipy's default
    connectivity (4-neighbour). The dilation uses an 8-neighbour structure,
    repeated ``iterations`` times.
    """
    from scipy import ndimage  # local import: scipy is only needed when NoData is set

    nodata_mask = _nodata_mask(band1, nodata)
    labeled_mask, _ = ndimage.label(nodata_mask)
    border_labels = np.unique(
        np.concatenate(
            (labeled_mask[0, :], labeled_mask[-1, :], labeled_mask[:, 0], labeled_mask[:, -1])
        )
    )
    internal_nodata_mask = nodata_mask & ~np.isin(labeled_mask, border_labels)
    if not internal_nodata_mask.any():
        print("无内部NoData区域")
        return
    struct = ndimage.generate_binary_structure(2, 2)  # 8-neighbourhood
    internal_dilated = ndimage.binary_dilation(
        internal_nodata_mask, structure=struct, iterations=iterations
    )
    mask_np[internal_dilated] = 1
    print(f"内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")


def _write_mask(src, mask_np, path):
    """Write ``mask_np`` as a single-band uint8 raster with ``src``'s georeferencing."""
    profile = src.profile.copy()
    # The source NoData (e.g. -9999 or NaN) is typically out of range for uint8,
    # and rasterio raises ValueError for that. NoData=0 would also hide the valid
    # "background" class, so the mask gets no NoData value.
    profile.update(count=1, dtype="uint8", nodata=None)
    # A multi-band source may carry photometric=RGB/YCbCr, which is invalid for 1 band.
    profile.pop("photometric", None)
    if str(profile.get("compress", "")).lower() in ("jpeg", "webp"):
        profile["compress"] = "deflate"  # lossy codecs would corrupt the 0/1 labels
    with rasterio.open(path, "w", **profile) as dst:
        dst.write(mask_np, 1)


# ---------- Main program ----------

# NOTE(review): nothing in this script plots; this backend switch looks unused.
matplotlib.use("TkAgg")

# Parameters
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV1.tif"
prompt = "water body"
coarse_read_max_side = 768  # no longer used; kept so nothing referring to it breaks
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 2048  # refinement tile size
overlap = 256  # refinement tile overlap
coarse_tile_size = 6144  # coarse tile size (tunable)
coarse_overlap = 256  # coarse tile overlap (tunable)
nodata_dilate_iterations = 7  # dilation applied to interior NoData holes before filling


def main():
    """Run coarse-to-fine SAM3 segmentation on ``image_path`` and save the mask."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")

    model = build_sam3_image_model().to(device).eval()
    coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
    fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)

    # No autograd is needed anywhere in this pipeline.
    with rasterio.open(image_path) as src, torch.inference_mode():
        nodata = src.nodata
        print(f"原始影像NoData值: {nodata}")

        print("开始粗分割(分块推理)...")
        coarse_prob = coarse_tiles(
            coarse_processor, src, prompt, tile_size=coarse_tile_size, overlap=coarse_overlap
        )
        coarse_mask = coarse_prob > coarse_threshold

        print("构建边缘带...")
        band = build_band(coarse_mask, band_radius)
        del coarse_mask  # full-size intermediates are large; release them early

        print("开始精修分割...")
        fine_probs = refine_tiles(fine_processor, src, prompt, band, tile_size, overlap)
        del band

        print("合并粗/细结果...")
        # NOTE(review): a max-merge lets refinement only add foreground; it can
        # never remove coarse false positives. If refinement should override the
        # coarse result inside the edge band, use torch.where(band, fine, coarse).
        final_prob = torch.maximum(fine_probs, coarse_prob)
        del fine_probs, coarse_prob
        mask_np = (final_prob > final_threshold).cpu().numpy().astype(np.uint8)
        del final_prob

        # Fill NoData holes enclosed by valid data (with a dilated rim).
        if nodata is not None:
            _fill_internal_nodata(mask_np, src.read(1), nodata, nodata_dilate_iterations)

        _write_mask(src, mask_np, mask_output_path)

    print(f"分割完成,结果已保存至:{mask_output_path}")


if __name__ == "__main__":
    main()