"""Segment water bodies in very large remote-sensing rasters with SAM3.

A three-level strategy is used:

1. Split the image into several overlapping sub-regions (level 1).
2. For each sub-region run a medium-resolution coarse segmentation with a
   sliding window (4096-pixel tiles by default) to obtain a coarse mask
   (level 2).
3. Build an edge band from the coarse mask and re-run inference only on the
   tiles touching that band to obtain a refined mask (level 3).
4. Merge the masks of all sub-regions (maximum in overlapping areas).
5. Post-process: fill interior NoData holes, remove small fragments and
   optionally keep only the largest connected component.
6. Save the result.

Advantages:
- Coarse segmentation is tiled, so detail is not lost to a global downsample.
- Refinement only touches the edge band, which saves compute.
"""
import contextlib
import math

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
import torch.nn.functional as F
from PIL import Image
from rasterio.io import MemoryFile
from rasterio.windows import Window
from scipy import ndimage
from tqdm import tqdm

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# Number of 3x3 dilation passes applied to interior NoData holes before they
# are burned into the final mask.
_INTERNAL_NODATA_DILATION_ITERS = 7


# ---------- Utility functions ----------
def binary_dilate(mask, radius):
    """Dilate a binary mask with a square (2*radius+1) window via max pooling.

    ``mask`` is passed to ``F.max_pool2d`` and therefore needs a layout that
    function accepts (``build_band`` passes a (1,1,H,W) tensor). A radius <= 0
    returns ``mask`` unchanged.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    pooled = F.max_pool2d(mask.float(), kernel_size=kernel, stride=1, padding=radius)
    return pooled > 0.5


def binary_erode(mask, radius):
    """Erode a boolean mask by dilating its complement.

    A radius <= 0 returns ``mask`` unchanged.
    """
    if radius <= 0:
        return mask
    return ~binary_dilate(~mask, radius)


def combine_masks_logits(masks_logits):
    """Collapse a stack of per-instance maps into a single 2D map.

    Returns ``None`` for an empty tensor. Dimension 1 is squeezed; a 2D
    result is returned as is, otherwise the element-wise maximum over
    dimension 0 is returned.
    """
    if masks_logits.numel() == 0:
        return None
    probs = masks_logits.squeeze(1)
    if probs.dim() == 2:
        return probs
    return torch.amax(probs, dim=0)


def upsample_prob(prob, size):
    """Bilinearly resize a 2D tensor to ``size`` given as (height, width)."""
    resized = F.interpolate(
        prob[None, None, ...], size=size, mode="bilinear", align_corners=False
    )
    return resized.squeeze(0).squeeze(0)


def build_band(mask, radius):
    """Return the band between the dilation and erosion of ``mask``.

    mask: 2D tensor (CPU or GPU).
    """
    mask_4d = mask[None, None, ...]  # (1,1,H,W)
    dil = binary_dilate(mask_4d, radius)
    ero = binary_erode(mask_4d, radius)
    return torch.logical_xor(dil, ero).squeeze(0).squeeze(0)


def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` tiles covering a height x width grid.

    Tiles start every ``max(tile_size - overlap, 1)`` pixels and are clipped
    to the grid, so tiles on the bottom/right edges may be smaller.
    """
    stride = max(tile_size - overlap, 1)
    for top in range(0, height, stride):
        bottom = min(top + tile_size, height)
        for left in range(0, width, stride):
            right = min(left + tile_size, width)
            yield top, left, bottom, right


def _tile_iterator(height, width, tile_size, overlap, desc, use_tqdm):
    """Return the ``tile_slices`` iterator, optionally wrapped in a progress bar."""
    tiles = tile_slices(height, width, tile_size, overlap)
    if not use_tqdm:
        return tiles
    stride = max(tile_size - overlap, 1)
    # Same count as the number of tiles tile_slices yields.
    total_tiles = ((height + stride - 1) // stride) * ((width + stride - 1) // stride)
    return tqdm(tiles, total=total_tiles, desc=desc, unit="块", leave=True)


def compute_stretch_params(bands, sample_max_pixels=2_000_000):
    """Compute per-band 2nd/98th percentiles from a strided sample.

    ``bands`` is a (C,H,W) array. The sampling step is chosen so that roughly
    ``sample_max_pixels`` pixels per band are inspected. NaN samples are
    ignored. Returns ``(vmin, vmax)`` arrays of length C with ``vmax`` forced
    to be at least ``vmin + 1e-6``.
    """
    _, h, w = bands.shape
    step = max(int(np.ceil(np.sqrt((h * w) / float(sample_max_pixels)))), 1)
    sample = bands[:, ::step, ::step].astype(np.float32)
    vmin = np.nanpercentile(sample, 2.0, axis=(1, 2))
    vmax = np.nanpercentile(sample, 98.0, axis=(1, 2))
    vmax = np.maximum(vmax, vmin + 1e-6)
    return vmin, vmax


def make_overview_bands(bands, max_side):
    """Subsample a (C,H,W) array by striding so its longer side is <= ``max_side``.

    Returns the strided view and the step that was used.
    """
    _, h, w = bands.shape
    step = max(int(math.ceil(max(h, w) / float(max_side))), 1)
    return bands[:, ::step, ::step], step


def _to_uint8(arr, vmin=None, vmax=None):
    """Linearly stretch ``arr`` from [vmin, vmax] to uint8 [0, 255].

    uint8 input is returned untouched. When either bound is missing, the
    2nd/98th percentiles of ``arr`` are used. A degenerate range
    (``vmax <= vmin``) yields an all-zero array. NaN pixels map to 0.
    """
    if arr.dtype == np.uint8:
        return arr
    arr = arr.astype(np.float32)
    if vmin is None or vmax is None:
        vmin = np.nanpercentile(arr, 2.0)
        vmax = np.nanpercentile(arr, 98.0)
    if vmax <= vmin:
        return np.zeros_like(arr, dtype=np.uint8)
    scaled = (arr - vmin) / (vmax - vmin)
    # Casting NaN to uint8 is undefined, so neutralise non-finite values first.
    scaled = np.nan_to_num(scaled, nan=0.0, posinf=1.0, neginf=0.0)
    scaled = np.clip(scaled, 0.0, 1.0)
    return (scaled * 255.0).astype(np.uint8)


def _bands_to_pil(bands, stretch_params=None):
    """Convert a (C,H,W) array to an RGB PIL image.

    With fewer than three bands the first band is replicated to grey RGB;
    otherwise the first three bands are used. ``stretch_params`` is an
    optional ``(vmin, vmax)`` pair of per-band arrays; when it holds fewer
    than three entries the first entry is used for all channels.

    Raises:
        ValueError: if ``bands`` is not 3D or has no bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    channels = bands.shape[0]
    if channels == 0:
        raise ValueError("bands must contain at least one band")
    if channels < 3:
        rgb = np.repeat(bands[:1], 3, axis=0)
    else:
        rgb = bands[:3]

    if stretch_params is not None:
        vmin, vmax = stretch_params
        vmin3 = vmin[:3] if vmin.shape[0] >= 3 else np.repeat(vmin[:1], 3)
        vmax3 = vmax[:3] if vmax.shape[0] >= 3 else np.repeat(vmax[:1], 3)
        rgb_u8 = np.empty(rgb.shape, dtype=np.uint8)
        for i, (lo, hi) in enumerate(zip(vmin3, vmax3)):
            rgb_u8[i] = _to_uint8(rgb[i], vmin=lo, vmax=hi)
        rgb = rgb_u8
    else:
        rgb = _to_uint8(rgb)

    rgb = np.transpose(rgb, (1, 2, 0))
    return Image.fromarray(rgb, mode="RGB")


def _read_pil_window(src, window, stretch_params=None):
    """Read ``window`` from a rasterio dataset and return it as an RGB image."""
    bands = src.read(window=window)
    return _bands_to_pil(bands, stretch_params=stretch_params)


def infer_prompt_prob(processor, image, prompt):
    """Run text-prompted inference on ``image`` and return a combined 2D map.

    float16 autocast is enabled only when the processor's device is a CUDA
    device. Returns ``None`` when the processor produced no masks.
    """
    if str(processor.device).startswith("cuda"):
        autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        autocast_ctx = contextlib.nullcontext()
    with autocast_ctx:
        state = processor.set_image(image)
        state = processor.set_text_prompt(prompt=prompt, state=state)
    return combine_masks_logits(state["masks_logits"])


# ---------- Medium-resolution coarse segmentation (tiled inference) ----------
def _downsample_any(mask, factor):
    """Downsample a 2D boolean mask by ``factor`` using a block-wise ``any``.

    The mask is padded with ``False`` up to a multiple of ``factor``, so the
    result has shape (ceil(H/factor), ceil(W/factor)).
    """
    if factor <= 1:
        return mask
    h, w = mask.shape
    pad_h = (-h) % factor
    pad_w = (-w) % factor
    if pad_h or pad_w:
        mask = np.pad(
            mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False
        )
        h, w = mask.shape
    return mask.reshape(h // factor, factor, w // factor, factor).any(axis=(1, 3))


def _build_band_ndimage(mask, radius):
    """Return dilation XOR erosion of ``mask`` (8-connected, ``radius`` passes).

    A radius <= 0 yields an all-False band.
    """
    if radius <= 0:
        return np.zeros_like(mask, dtype=bool)
    struct = ndimage.generate_binary_structure(2, 2)
    dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
    ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
    return np.logical_xor(dil, ero)


def _resize_mask_nearest(mask, shape):
    """Resize a 2D mask to ``shape`` (height, width) with nearest-neighbour sampling."""
    target_h, target_w = shape
    img = Image.fromarray(mask.astype(np.uint8), mode="L")
    resized = img.resize((target_w, target_h), resample=Image.NEAREST)
    return np.array(resized).astype(bool)


def coarse_tile_segmentation(
    processor,
    src,
    prompt,
    tile_size,
    overlap,
    downsample_factor=4,
    stretch_params=None,
    nodata_mask=None,
    use_tqdm=True,
):
    """Run tiled coarse segmentation over a sub-region.

    Args:
        processor: processor used for coarse inference.
        src: rasterio dataset of the sub-region.
        prompt: text prompt.
        tile_size: coarse tile size in source pixels.
        overlap: tile overlap in source pixels.
        downsample_factor: ratio between source and output resolution.
        stretch_params: optional ``(vmin, vmax)`` contrast-stretch parameters.
        nodata_mask: optional boolean NoData mask with the sub-region's shape;
            tiles lying entirely in NoData are skipped and NoData cells are
            zeroed in the result.
        use_tqdm: show a progress bar.

    Returns:
        float16 array of shape (ceil(H/factor), ceil(W/factor)) holding the
        per-cell maximum over all tiles.
    """
    height, width = src.height, src.width
    height_lr = int(math.ceil(height / float(downsample_factor)))
    width_lr = int(math.ceil(width / float(downsample_factor)))
    full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)

    tiles = _tile_iterator(height, width, tile_size, overlap, "粗分割分块", use_tqdm)
    for top, left, bottom, right in tiles:
        # Skip tiles that lie completely inside NoData (speed-up).
        if nodata_mask is not None and nodata_mask[top:bottom, left:right].all():
            continue

        window = Window(left, top, right - left, bottom - top)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue

        h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
        w_lr = int(math.ceil((right - left) / float(downsample_factor)))
        if tile_prob.shape[-2:] != (h_lr, w_lr):
            tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
        tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()

        top_lr = top // downsample_factor
        left_lr = left // downsample_factor
        target = full_probs_lr[top_lr:top_lr + h_lr, left_lr:left_lr + w_lr]
        # Overlapping tiles are merged by maximum; ``target`` is a view.
        np.maximum(target, tile_prob_cpu, out=target)

    if nodata_mask is not None:
        nodata_lr = _downsample_any(nodata_mask, downsample_factor)
        full_probs_lr[nodata_lr] = 0.0
    return full_probs_lr


# ---------- Refinement tiles ----------
def refine_tiles(
    processor,
    src,
    prompt,
    band_lr,
    downsample_factor,
    tile_size,
    overlap,
    stretch_params=None,
    use_tqdm=True,
):
    """Run full-resolution inference on the tiles that touch the edge band.

    Args:
        processor: processor used for fine inference.
        src: rasterio dataset of the sub-region.
        prompt: text prompt.
        band_lr: 2D boolean numpy array marking the area to refine, at
            ``1/downsample_factor`` of the source resolution.
        downsample_factor: ratio between source and ``band_lr`` resolution.
        tile_size: tile size in source pixels.
        overlap: tile overlap in source pixels.
        stretch_params: optional ``(vmin, vmax)`` contrast-stretch parameters.
        use_tqdm: show a progress bar.

    Returns:
        float16 array of shape (H, W); tiles that were skipped stay at zero.
    """
    height, width = src.height, src.width
    full_probs = np.zeros((height, width), dtype=np.float16)

    tiles = _tile_iterator(height, width, tile_size, overlap, "精修分块", use_tqdm)
    for top, left, bottom, right in tiles:
        c_top = top // downsample_factor
        c_left = left // downsample_factor
        c_bottom = max(int(math.ceil(bottom / float(downsample_factor))), c_top + 1)
        c_right = max(int(math.ceil(right / float(downsample_factor))), c_left + 1)
        if not band_lr[c_top:c_bottom, c_left:c_right].any():
            continue

        tile_h = bottom - top
        tile_w = right - left
        window = Window(left, top, tile_w, tile_h)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue
        if tile_prob.shape[-2:] != (tile_h, tile_w):
            tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
        tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()

        target = full_probs[top:bottom, left:right]
        np.maximum(target, tile_prob_cpu, out=target)
    return full_probs


# ---------- Post-processing (area filter) ----------
def filter_by_area(mask, min_area=None, keep_largest_only=False):
    """Filter the connected components of a binary mask by pixel area.

    Args:
        mask: (H,W) array, bool or integer; non-zero pixels are foreground.
        min_area: minimum component area in pixels; smaller components are
            removed. ``None`` disables the threshold.
        keep_largest_only: keep only the largest component (still subject to
            ``min_area``).

    Returns:
        The filtered mask as uint8 with values 0/1.
    """
    labeled, num = ndimage.label(mask)
    if num == 0:
        return mask.astype(np.uint8)

    # Count pixels per label; unlike summing mask values this is independent
    # of the foreground value (1, 255, ...).
    sizes = np.bincount(labeled.ravel(), minlength=num + 1)[1:]

    # keep[label] says whether that label survives; index 0 is background.
    keep = np.zeros(num + 1, dtype=bool)
    if keep_largest_only:
        largest = int(np.argmax(sizes))
        if min_area is None or sizes[largest] >= min_area:
            keep[largest + 1] = True
    elif min_area is None:
        keep[1:] = True
    else:
        keep[1:] = sizes >= min_area
    return keep[labeled].astype(np.uint8)


# ---------- Sub-region splitting ----------
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
    """Split an image into a grid of overlapping sub-regions.

    Returns a list of ``(left, top, right, bottom)`` pixel coordinates relative
    to the full image, in row-major order. Each grid cell is grown by
    ``overlap`` on every side and clipped to the image. Cells that end up
    empty (more splits than pixels) are omitted.

    Raises:
        ValueError: if either split count is smaller than 1.
    """
    if num_splits_x < 1 or num_splits_y < 1:
        raise ValueError("num_splits_x and num_splits_y must be >= 1")
    tile_w = (width + num_splits_x - 1) // num_splits_x
    tile_h = (height + num_splits_y - 1) // num_splits_y

    regions = []
    for i in range(num_splits_y):
        top = max(i * tile_h - overlap, 0)
        bottom = min((i + 1) * tile_h + overlap, height)
        for j in range(num_splits_x):
            left = max(j * tile_w - overlap, 0)
            right = min((j + 1) * tile_w + overlap, width)
            if right <= left or bottom <= top:
                continue  # degenerate cell: nothing to process
            regions.append((left, top, right, bottom))
    return regions


DEFAULT_CONFIG = {
    "image_path": r"E:\is2\guidingsahn\result.tif",
    "mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
    "prompt": "water body",
    "overview_max_side": 1400,
    "overview_threshold": 0.4,
    "coarse_resolution": 1008,
    "fine_resolution": 1008,
    "coarse_threshold": 0.5,
    "final_threshold": 0.5,
    "band_radius": 64,
    "fine_tile_size": 2048,
    "fine_overlap": 256,
    "coarse_tile_size": 4096,
    "coarse_tile_overlap": 256,
    "coarse_downsample_factor": 4,
    "num_splits_y": 2,
    "num_splits_x": 3,
    "region_overlap": 256,
    "min_area": 5000,
    "keep_largest_only": False,
    "use_tqdm": True,
    "device": None,
}


def _compute_nodata_mask(band, nodata):
    """Return a boolean mask of the pixels in ``band`` equal to ``nodata``.

    Floating-point bands are compared with ``np.isclose`` and a NaN nodata
    value matches NaN pixels. ``nodata=None`` yields an all-False mask.
    """
    if nodata is None:
        return np.zeros_like(band, dtype=bool)
    if np.issubdtype(band.dtype, np.floating):
        return np.isclose(band, float(nodata), equal_nan=True)
    return band == nodata


def _find_internal_nodata(nodata_mask):
    """Return the NoData pixels whose connected component does not touch the image border."""
    labeled, num = ndimage.label(nodata_mask)
    if num == 0:
        return np.zeros_like(nodata_mask, dtype=bool)
    edge_labels = np.concatenate(
        [labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1]]
    )
    touches_edge = np.zeros(num + 1, dtype=bool)
    touches_edge[edge_labels] = True
    touches_edge[0] = True  # label 0 is valid data, never "internal NoData"
    return ~touches_edge[labeled]


def _nodata_fits_uint8(nodata):
    """Tell whether ``nodata`` can be declared on a uint8 raster (``None`` counts as fine)."""
    if nodata is None:
        return True
    try:
        value = float(nodata)
    except (TypeError, ValueError):
        return False
    return math.isfinite(value) and 0.0 <= value <= 255.0


def _overview_mask(processor, sub_bands, sub_nodata, stretch_params, cfg):
    """Segment a strided overview of the sub-region and return a small boolean mask."""
    overview_bands, overview_step = make_overview_bands(
        sub_bands, max_side=cfg["overview_max_side"]
    )
    overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
    overview_prob = infer_prompt_prob(processor, overview_img, cfg["prompt"])
    if overview_prob is None:
        mask_small = np.zeros((overview_img.height, overview_img.width), dtype=bool)
    else:
        mask_small = (
            (overview_prob > cfg["overview_threshold"]).detach().cpu().numpy()
        )

    if sub_nodata.any():
        nodata_overview = sub_nodata[::overview_step, ::overview_step]
        if nodata_overview.shape != mask_small.shape:
            # The model output is not guaranteed here to match the overview
            # size; align the NoData mask instead of failing on indexing.
            nodata_overview = _resize_mask_nearest(nodata_overview, mask_small.shape)
        mask_small[nodata_overview] = False
    return mask_small


def _segment_region(
    src, bounds, cfg, coarse_processor, fine_processor, nodata_mask_full, use_tqdm
):
    """Run the overview, coarse and refinement passes on one sub-region.

    ``bounds`` is ``(left, top, right, bottom)`` in full-image pixels. Returns
    a uint8 0/1 mask with the sub-region's shape: the union of the refined,
    coarse and overview masks.
    """
    left, top, right, bottom = bounds
    region_w = right - left
    region_h = bottom - top
    window = Window(left, top, region_w, region_h)
    sub_bands = src.read(window=window)

    sub_profile = src.profile.copy()
    if src.is_tiled:
        sub_profile["tiled"] = True
    else:
        sub_profile.pop("blockxsize", None)
        sub_profile.pop("blockysize", None)
        sub_profile["tiled"] = False
    sub_profile.update(
        {
            "height": region_h,
            "width": region_w,
            "transform": rasterio.windows.transform(window, src.transform),
        }
    )

    sub_nodata = nodata_mask_full[top:bottom, left:right]
    factor = cfg["coarse_downsample_factor"]

    # The sub-region is copied into an in-memory dataset so the tiled passes
    # can read windows relative to the sub-region.
    with MemoryFile() as memfile:
        with memfile.open(**sub_profile) as sub_dst:
            sub_dst.write(sub_bands)

        with memfile.open() as sub_src:
            stretch_params = compute_stretch_params(sub_bands)
            overview_mask_small = _overview_mask(
                coarse_processor, sub_bands, sub_nodata, stretch_params, cfg
            )

            coarse_probs = coarse_tile_segmentation(
                processor=coarse_processor,
                src=sub_src,
                prompt=cfg["prompt"],
                tile_size=cfg["coarse_tile_size"],
                overlap=cfg["coarse_tile_overlap"],
                downsample_factor=factor,
                stretch_params=stretch_params,
                nodata_mask=sub_nodata if sub_nodata.any() else None,
                use_tqdm=use_tqdm,
            )
            coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
            overview_mask_lr = _resize_mask_nearest(
                overview_mask_small, coarse_mask_lr.shape
            )
            combined_mask_lr = coarse_mask_lr | overview_mask_lr

            band_radius_lr = max(int(round(cfg["band_radius"] / float(factor))), 1)
            band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)

            fine_probs = refine_tiles(
                processor=fine_processor,
                src=sub_src,
                prompt=cfg["prompt"],
                band_lr=band_lr,
                downsample_factor=factor,
                tile_size=cfg["fine_tile_size"],
                overlap=cfg["fine_overlap"],
                stretch_params=stretch_params,
                use_tqdm=use_tqdm,
            )

    fine_mask = fine_probs > cfg["final_threshold"]
    coarse_mask_full = _resize_mask_nearest(coarse_mask_lr, (region_h, region_w))
    overview_mask_full = _resize_mask_nearest(
        overview_mask_small, (region_h, region_w)
    )
    return (fine_mask | coarse_mask_full | overview_mask_full).astype(np.uint8)


def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
    """Run the full three-level water segmentation and write the mask raster.

    Args:
        config: optional dict overriding entries of ``DEFAULT_CONFIG``.
        progress_callback: optional callable invoked as
            ``progress_callback("region", current_index, total)`` before each
            sub-region is processed.
        log_callback: optional callable receiving log strings; defaults to
            ``print``.
        stop_event: optional object with ``is_set()``; checked before each
            sub-region. When set, processing stops and nothing is written.

    Returns:
        None. The mask is written to ``cfg["mask_output_path"]`` as a
        single-band, LZW-compressed uint8 raster.
    """
    cfg = dict(DEFAULT_CONFIG)
    if config:
        cfg.update(config)
    log = log_callback if log_callback is not None else print

    device = cfg["device"]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"使用设备: {device}")

    model = build_sam3_image_model().to(device).eval()
    coarse_processor = Sam3Processor(
        model, resolution=cfg["coarse_resolution"], device=device
    )
    fine_processor = Sam3Processor(
        model, resolution=cfg["fine_resolution"], device=device
    )

    image_path = cfg["image_path"]
    mask_output_path = cfg["mask_output_path"]
    use_tqdm = bool(cfg.get("use_tqdm", True))

    with rasterio.open(image_path) as src:
        nodata = src.nodata
        log(f"原始影像NoData值: {nodata}")
        height, width = src.height, src.width
        # NoData is detected on band 1 only; the mask is reused for hole filling.
        nodata_mask_full = _compute_nodata_mask(src.read(1), nodata)
        profile = src.profile.copy()

        regions = split_into_regions(
            width,
            height,
            cfg["num_splits_y"],
            cfg["num_splits_x"],
            cfg["region_overlap"],
        )
        log(f"共划分 {len(regions)} 个子区域")
        full_mask = np.zeros((height, width), dtype=np.uint8)

        if use_tqdm:
            region_iter = tqdm(
                list(enumerate(regions)),
                total=len(regions),
                desc="处理子区域",
                unit="子区域",
            )
        else:
            region_iter = enumerate(regions)

        for idx, bounds in region_iter:
            if stop_event is not None and stop_event.is_set():
                log("已停止")
                return
            if progress_callback is not None:
                progress_callback("region", idx + 1, len(regions))
            left, top, right, bottom = bounds
            log(
                f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
            )

            sub_mask_np = _segment_region(
                src,
                bounds,
                cfg,
                coarse_processor,
                fine_processor,
                nodata_mask_full,
                use_tqdm,
            )
            target = full_mask[top:bottom, left:right]
            # Overlapping sub-regions are merged by maximum; ``target`` is a view.
            np.maximum(target, sub_mask_np, out=target)

    log("\n后处理:填充内部NoData空洞...")
    internal_nodata_mask = None
    if nodata is not None:
        internal_nodata_mask = _find_internal_nodata(nodata_mask_full)
    if internal_nodata_mask is not None and internal_nodata_mask.any():
        struct = ndimage.generate_binary_structure(2, 2)
        internal_dilated = ndimage.binary_dilation(
            internal_nodata_mask,
            structure=struct,
            iterations=_INTERNAL_NODATA_DILATION_ITERS,
        )
        full_mask[internal_dilated] = 1
        log(
            f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
        )
    else:
        log(" 无内部NoData区域")

    log("后处理:面积过滤...")
    if cfg["min_area"] is not None or cfg["keep_largest_only"]:
        original_count = np.sum(full_mask)
        full_mask = filter_by_area(
            full_mask,
            min_area=cfg["min_area"],
            keep_largest_only=cfg["keep_largest_only"],
        )
        filtered_count = np.sum(full_mask)
        log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")

    profile.update(count=1, dtype="uint8", compress="lzw")
    if not _nodata_fits_uint8(profile.get("nodata")):
        # A source nodata value such as -9999 or NaN is outside the uint8
        # range and would make opening the output dataset fail.
        profile["nodata"] = None
    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(full_mask, 1)
    log(f"分割完成,结果已保存至:{mask_output_path}")


def main():
    """Script entry point: select the TkAgg backend and run with the default config."""
    matplotlib.use("TkAgg")
    run_segmentation()


if __name__ == "__main__":
    main()