Files
water-body-segmentation/water_V5.py
2026-03-10 17:29:24 +08:00

619 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
使用SAM3模型分割超大遥感影像中的水体(三层分割策略)。
流程:
1. 将影像划分为若干有重叠的子区域(第一层)。
2. 对每个子区域进行中分辨率粗分割(以4096为块大小滑动窗口推理),得到粗掩码(第二层);
   同时对整个子区域的缩略图做一次概览推理,其掩码与粗掩码合并。
3. 基于合并后的粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层);
   子区域结果为精修掩码、粗掩码与概览掩码的并集。
4. 合并所有子区域的掩码(重叠区域取最大值)。
5. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
6. 保存结果。
优点:
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
- 精修只处理边缘带,节省计算资源。
"""
import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage
import math
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# ---------- 工具函数(保持不变) ----------
def binary_dilate(mask, radius):
    """Dilate a mask tensor with a square window of side ``2 * radius + 1``.

    The dilation is a stride-1 ``F.max_pool2d`` over the mask cast to float,
    thresholded back to a bool tensor. *mask* must be in a layout accepted by
    ``max_pool2d`` (e.g. ``(1, 1, H, W)``).

    Returns:
        Bool tensor of the same shape, or *mask* itself (same object, same
        dtype) when ``radius <= 0``.
    """
    if radius > 0:
        side = radius * 2 + 1
        pooled = F.max_pool2d(mask.float(), kernel_size=side, stride=1, padding=radius)
        return pooled > 0.5
    return mask
def binary_erode(mask, radius):
    """Erode a boolean mask tensor with a square window of side ``2 * radius + 1``.

    Erosion is computed as the complement of the dilated complement: the
    inverted mask is max-pooled (stride 1) and the result inverted again.
    ``max_pool2d`` pads with negative infinity, so pixels outside the tensor
    never count as background. *mask* must be a bool tensor in a layout
    accepted by ``max_pool2d`` (e.g. ``(1, 1, H, W)``).

    Returns:
        Bool tensor of the same shape, or *mask* itself when ``radius <= 0``.
    """
    if radius <= 0:
        return mask
    side = 2 * radius + 1
    complement = (~mask).float()
    grown_background = F.max_pool2d(complement, kernel_size=side, stride=1, padding=radius) > 0.5
    return ~grown_background
def combine_masks_logits(masks_logits):
    """Collapse a stack of per-instance mask maps into a single 2-D map.

    Dim 1 of *masks_logits* is squeezed (presumably a singleton channel axis,
    e.g. ``(N, 1, H, W)`` — TODO confirm against the SAM3 processor output).
    If the squeezed tensor is already 2-D it is returned as-is; otherwise the
    per-pixel maximum over dim 0 is taken.

    Returns:
        2-D tensor, or ``None`` if *masks_logits* has no elements (e.g. no
        instances were returned).
    """
    if not masks_logits.numel():
        return None
    per_instance = masks_logits.squeeze(1)
    return per_instance if per_instance.dim() == 2 else per_instance.amax(dim=0)
def upsample_prob(prob, size):
    """Bilinearly resize a 2-D probability map to ``size = (height, width)``.

    Despite the name this also shrinks maps (callers use it both ways).
    Uses ``align_corners=False``; the map is temporarily given batch and
    channel axes because ``F.interpolate`` requires them.
    """
    batched = prob.unsqueeze(0).unsqueeze(0)
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]
def build_band(mask, radius):
    """Return the edge band of a 2-D boolean mask tensor (torch implementation).

    The band is ``dilate(mask) XOR erode(mask)`` with square windows of side
    ``2 * radius + 1``, i.e. pixels within ``radius`` of the mask boundary.
    Works on CPU or GPU tensors; *mask* should be a 2-D bool tensor.
    """
    batched = mask.unsqueeze(0).unsqueeze(0)  # add batch/channel axes -> (1, 1, H, W)
    grown = binary_dilate(batched, radius)
    shrunk = binary_erode(batched, radius)
    return torch.logical_xor(grown, shrunk)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` windows sweeping a ``height x width`` grid.

    Windows are ``tile_size`` pixels on a side (clipped at the grid edge) and
    start every ``max(tile_size - overlap, 1)`` pixels, in row-major order.

    Along each axis the sweep stops at the first window that already reaches
    the grid edge. Any later window on that axis would lie entirely inside
    it, so it would only add a redundant, often very thin crop, which is
    then stretched for inference and max-merged into the result. The set of
    covered pixels is unchanged.
    """
    stride = max(tile_size - overlap, 1)

    def starts(length):
        # Window origins along one axis; stop once a window reaches the end.
        pos = 0
        while pos < length:
            yield pos
            if pos + tile_size >= length:
                return
            pos += stride

    col_starts = list(starts(width))
    for top in starts(height):
        bottom = min(top + tile_size, height)
        for left in col_starts:
            yield top, left, bottom, min(left + tile_size, width)
def compute_stretch_params(bands, sample_max_pixels=2_000_000):
    """Estimate per-band 2% / 98% percentile stretch limits from a subsample.

    Args:
        bands: array of shape ``(C, H, W)``.
        sample_max_pixels: approximate cap on pixels per band used for the
            estimate; both spatial axes are strided by the same step to meet it.

    Returns:
        ``(vmin, vmax)``: arrays of length ``C`` with ``vmax > vmin``.

    NaN pixels (a common float NoData encoding) are ignored rather than
    poisoning the percentiles. A band with no non-NaN sample gets
    ``vmin = 0``, and an empty image gives ``vmin = 0, vmax = 1e-6`` per band.
    """
    import warnings

    n_bands, rows, cols = bands.shape
    step = max(int(np.ceil(np.sqrt((rows * cols) / float(sample_max_pixels)))), 1)
    sample = bands[:, ::step, ::step].astype(np.float32)
    if sample.size == 0:
        vmin = np.zeros(n_bands, dtype=np.float32)
        return vmin, vmin + np.float32(1e-6)
    with warnings.catch_warnings():
        # All-NaN bands emit "All-NaN slice" warnings; they are handled below.
        warnings.simplefilter("ignore", RuntimeWarning)
        vmin = np.nanpercentile(sample, 2.0, axis=(1, 2))
        vmax = np.nanpercentile(sample, 98.0, axis=(1, 2))
    vmin = np.nan_to_num(vmin, nan=0.0)
    vmax = np.nan_to_num(vmax, nan=0.0)
    vmax = np.maximum(vmax, vmin + 1e-6)
    return vmin, vmax
def make_overview_bands(bands, max_side):
    """Decimate ``(C, H, W)`` bands so the longer spatial side is at most *max_side*.

    Returns:
        ``(view, step)``: a strided NumPy view (no copy) taking every
        ``step``-th row and column, and the integer ``step >= 1`` used.
    """
    _, rows, cols = bands.shape
    longest = max(rows, cols)
    step = max(int(math.ceil(longest / float(max_side))), 1)
    return bands[:, ::step, ::step], step
def _to_uint8(arr, vmin=None, vmax=None):
    """Linearly stretch *arr* into ``uint8`` between *vmin* and *vmax*.

    ``uint8`` input is returned unchanged (same object). When either limit is
    ``None``, both are taken from the 2nd / 98th percentiles of *arr*,
    ignoring NaN. Values are clipped to ``[0, 255]``.

    Edge cases are explicit:
      * NaN pixels map to 0, rather than going through an undefined
        float -> uint8 cast.
      * A degenerate or undefined range (``vmax <= vmin``, or a NaN limit)
        yields all zeros.
      * An empty or all-NaN input with automatic limits yields all zeros.
    """
    if arr.dtype == np.uint8:
        return arr
    data = arr.astype(np.float32)
    if vmin is None or vmax is None:
        if np.isnan(data).all():  # also true for an empty array
            return np.zeros_like(data, dtype=np.uint8)
        vmin = np.nanpercentile(data, 2.0)
        vmax = np.nanpercentile(data, 98.0)
    if not vmax > vmin:  # written this way so NaN limits also take this branch
        return np.zeros_like(data, dtype=np.uint8)
    scaled = np.clip((data - vmin) / (vmax - vmin), 0.0, 1.0)
    scaled = np.nan_to_num(scaled, nan=0.0)
    return (scaled * 255.0).astype(np.uint8)
def _bands_to_pil(bands, stretch_params=None):
    """Render ``(C, H, W)`` raster bands as an 8-bit RGB PIL image.

    Channel mapping: bands 0, 1, 2 become R, G, B. With fewer than three
    bands the last band fills the missing channels (1 band -> grey,
    2 bands -> ``(b0, b1, b1)``). Bands beyond the third are ignored.

    Args:
        bands: array of shape ``(C, H, W)`` with ``C >= 1``.
        stretch_params: optional ``(vmin, vmax)`` per-band arrays (as
            returned by ``compute_stretch_params``). Each channel is stretched
            with the limits of the band it was taken from. If the arrays are
            too short for that, band 0's limits are used for every channel.
            Without stretch parameters the three channels are stretched
            jointly using ``_to_uint8``'s own percentiles.

    Returns:
        PIL image in mode ``"RGB"``.

    Raises:
        ValueError: if *bands* is not 3-D or contains no band.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    n_bands = bands.shape[0]
    if n_bands == 0:
        raise ValueError("bands must contain at least one band")
    # Index of the source band feeding each of R, G, B.
    channel_src = [min(i, n_bands - 1) for i in range(3)]
    rgb = bands[channel_src]
    if stretch_params is None:
        rgb_u8 = _to_uint8(rgb)
    else:
        vmin, vmax = (np.asarray(p) for p in stretch_params)
        stretch_src = channel_src if vmin.shape[0] > max(channel_src) else [0, 0, 0]
        rgb_u8 = np.empty(rgb.shape, dtype=np.uint8)
        for ch in range(3):
            src_band = stretch_src[ch]
            rgb_u8[ch] = _to_uint8(rgb[ch], vmin=vmin[src_band], vmax=vmax[src_band])
    # A (H, W, 3) uint8 array is inferred as "RGB"; the explicit ``mode``
    # argument of ``fromarray`` is deprecated in recent Pillow.
    return Image.fromarray(np.transpose(rgb_u8, (1, 2, 0)))
def _read_pil_window(src, window, stretch_params=None):
    """Read every band of *window* from rasterio dataset *src* and render it as RGB.

    *stretch_params* is forwarded unchanged to ``_bands_to_pil``.
    """
    return _bands_to_pil(src.read(window=window), stretch_params=stretch_params)
def infer_prompt_prob(processor, image, prompt):
    """Run one text-prompted SAM3 inference on *image* and merge its masks.

    Args:
        processor: SAM3 processor exposing ``device``, ``set_image`` and
            ``set_text_prompt`` (as used below).
        image: PIL image to segment.
        prompt: text prompt.

    Returns:
        2-D tensor holding the per-pixel maximum over the returned masks (see
        ``combine_masks_logits``), or ``None`` if no masks came back.

    float16 autocast is enabled only when the processor's device is a CUDA
    device. The device is inspected through ``str()``, so strings (``"cuda"``,
    ``"cuda:1"``) and ``torch.device`` objects both work. A plain
    ``!= "cpu"`` test does not reliably handle ``torch.device`` objects, and
    it would wrap non-CUDA accelerators in a CUDA autocast context.
    """
    import contextlib

    on_cuda = str(processor.device).startswith("cuda")
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.float16)
        if on_cuda
        else contextlib.nullcontext()
    )
    with amp_ctx:
        state = processor.set_image(image)
        state = processor.set_text_prompt(prompt=prompt, state=state)
    return combine_masks_logits(state["masks_logits"])
# ---------- 新增:中分辨率粗分割(分块推理) ----------
def _downsample_any(mask, factor):
    """Shrink a 2-D mask by an integer *factor* with a logical OR per block.

    The mask is padded with ``False`` at the bottom/right up to a multiple of
    *factor*. Each output pixel is ``True`` if any pixel of its
    ``factor x factor`` block is truthy, giving an output shape of
    ``(ceil(H / factor), ceil(W / factor))``. For ``factor <= 1`` the input
    is returned unchanged.
    """
    if factor <= 1:
        return mask
    rows, cols = mask.shape
    extra_rows = -rows % factor
    extra_cols = -cols % factor
    padded = mask
    if extra_rows or extra_cols:
        padded = np.pad(
            mask,
            ((0, extra_rows), (0, extra_cols)),
            mode="constant",
            constant_values=False,
        )
    out_rows = padded.shape[0] // factor
    out_cols = padded.shape[1] // factor
    blocks = padded.reshape(out_rows, factor, out_cols, factor)
    return blocks.any(axis=3).any(axis=1)
def _build_band_ndimage(mask, radius):
    """Return the band of pixels within *radius* steps of a boolean mask's edge.

    Computed as ``dilation XOR erosion``, each applied *radius* times with
    the full 3x3 (8-connected) structuring element from ``scipy.ndimage``.
    Erosion uses SciPy's default ``border_value=0``, so foreground that
    touches the array edge also produces band pixels there.
    ``radius <= 0`` yields an all-False array.
    """
    if radius > 0:
        eight_conn = ndimage.generate_binary_structure(2, 2)
        grown = ndimage.binary_dilation(mask, structure=eight_conn, iterations=radius)
        shrunk = ndimage.binary_erosion(mask, structure=eight_conn, iterations=radius)
        return grown ^ shrunk
    return np.zeros_like(mask, dtype=bool)
def coarse_tile_segmentation(
    processor,
    src,
    prompt,
    tile_size,
    overlap,
    downsample_factor=4,
    stretch_params=None,
    nodata_mask=None,
    use_tqdm=True,
):
    """Tiled coarse segmentation of one region, accumulated at reduced resolution.

    The region is swept with ``tile_size`` windows (``overlap`` pixels shared
    between neighbours). Each window is rendered to RGB and segmented with
    *prompt*. The result is resized to ``ceil(tile / downsample_factor)``
    pixels per side and merged into the output by per-pixel maximum.

    Args:
        processor: coarse SAM3 processor (its input resolution is set by the caller).
        src: rasterio dataset for the region.
        prompt: text prompt.
        tile_size: tile side length in full-resolution pixels.
        overlap: pixels shared by neighbouring tiles.
        downsample_factor: integer reduction factor of the output grid.
        stretch_params: optional ``(vmin, vmax)`` forwarded to RGB rendering.
        nodata_mask: optional bool array with the region's shape. Tiles that
            are entirely NoData are skipped, and output cells containing any
            NoData pixel are zeroed.
        use_tqdm: show a progress bar.

    Returns:
        float16 array of shape ``(ceil(H / f), ceil(W / f))`` with
        ``f = downsample_factor``.
    """
    factor = downsample_factor
    out_h = int(math.ceil(src.height / float(factor)))
    out_w = int(math.ceil(src.width / float(factor)))
    probs_lr = np.zeros((out_h, out_w), dtype=np.float16)

    # Materialise the windows so the progress total always matches them.
    windows = list(tile_slices(src.height, src.width, tile_size, overlap))
    if use_tqdm:
        window_iter = tqdm(
            windows,
            total=len(windows),
            desc="粗分割分块",
            unit="",
            leave=True,
        )
    else:
        window_iter = windows

    for top, left, bottom, right in window_iter:
        if nodata_mask is not None and nodata_mask[top:bottom, left:right].all():
            continue  # nothing valid to segment in this tile
        tile_h = bottom - top
        tile_w = right - left
        crop = _read_pil_window(
            src, Window(left, top, tile_w, tile_h), stretch_params=stretch_params
        )
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue
        lr_h = int(math.ceil(tile_h / float(factor)))
        lr_w = int(math.ceil(tile_w / float(factor)))
        if tile_prob.shape[-2:] != (lr_h, lr_w):
            tile_prob = upsample_prob(tile_prob, (lr_h, lr_w))
        tile_np = tile_prob.to(torch.float16).detach().cpu().numpy()
        row0 = top // factor
        col0 = left // factor
        target = probs_lr[row0:row0 + lr_h, col0:col0 + lr_w]
        np.maximum(target, tile_np, out=target)  # in-place max-merge of overlaps

    if nodata_mask is not None:
        probs_lr[_downsample_any(nodata_mask, factor)] = 0.0
    return probs_lr
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
def refine_tiles(
    processor,
    src,
    prompt,
    band_lr,
    downsample_factor,
    tile_size,
    overlap,
    stretch_params=None,
    use_tqdm=True,
):
    """Full-resolution tiled inference restricted to the refinement band.

    A tile is segmented only if its footprint on the low-resolution grid
    *band_lr* (scale ``1 / downsample_factor``) contains at least one
    ``True`` cell; all other tiles stay zero. Per-tile probabilities are
    resized to the tile size when needed and merged by per-pixel maximum.

    Args:
        processor: fine SAM3 processor.
        src: rasterio dataset for the region.
        prompt: text prompt.
        band_lr: 2-D bool array marking, at low resolution, where refinement
            is needed.
        downsample_factor: scale between the full-resolution grid and *band_lr*.
        tile_size: tile side length in full-resolution pixels.
        overlap: pixels shared by neighbouring tiles.
        stretch_params: optional ``(vmin, vmax)`` forwarded to RGB rendering.
        use_tqdm: show a progress bar.

    Returns:
        float16 array of shape ``(H, W)`` of the region.
    """
    factor = downsample_factor
    probs = np.zeros((src.height, src.width), dtype=np.float16)

    # Materialise the windows so the progress total always matches them.
    windows = list(tile_slices(src.height, src.width, tile_size, overlap))
    if use_tqdm:
        window_iter = tqdm(
            windows,
            total=len(windows),
            desc="精修分块",
            unit="",
            leave=True,
        )
    else:
        window_iter = windows

    for top, left, bottom, right in window_iter:
        # Low-resolution footprint of the tile (at least one cell per axis).
        lr_top = top // factor
        lr_left = left // factor
        lr_bottom = max(int(math.ceil(bottom / float(factor))), lr_top + 1)
        lr_right = max(int(math.ceil(right / float(factor))), lr_left + 1)
        if not band_lr[lr_top:lr_bottom, lr_left:lr_right].any():
            continue
        tile_h = bottom - top
        tile_w = right - left
        crop = _read_pil_window(
            src, Window(left, top, tile_w, tile_h), stretch_params=stretch_params
        )
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue
        if tile_prob.shape[-2:] != (tile_h, tile_w):
            tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
        tile_np = tile_prob.to(torch.float16).detach().cpu().numpy()
        target = probs[top:bottom, left:right]
        np.maximum(target, tile_np, out=target)  # in-place max-merge of overlaps
    return probs
# ---------- 后处理函数(面积过滤) ----------
def filter_by_area(mask, min_area=None, keep_largest_only=False):
    """Remove connected components of a binary mask by pixel area.

    Args:
        mask: 2-D array; any non-zero value is foreground (bool, 0/1 and
            0/255 masks all work).
        min_area: minimum component size in pixels. Smaller components are
            removed; ``None`` disables this filter.
        keep_largest_only: keep only the largest component (the first one on
            ties), still subject to *min_area*.

    Returns:
        ``uint8`` array of 0/1 with the same shape as *mask*.

    Components use ``scipy.ndimage.label``'s default structure
    (4-connectivity in 2-D). Areas are pixel counts, not sums of mask values,
    so a 0/255 mask is measured correctly.
    """
    labeled, num = ndimage.label(mask)
    if num == 0:
        return np.zeros(labeled.shape, dtype=np.uint8)
    # areas[i] = pixel count of label i + 1.
    areas = np.bincount(labeled.ravel(), minlength=num + 1)[1:]
    # Lookup table indexed by label; label 0 (background) is never kept.
    keep = np.zeros(num + 1, dtype=bool)
    if keep_largest_only:
        largest = int(np.argmax(areas))
        if min_area is None or areas[largest] >= min_area:
            keep[largest + 1] = True
    elif min_area is None:
        keep[1:] = True
    else:
        keep[1:] = areas >= min_area
    return keep[labeled].astype(np.uint8)
# ---------- 划分子区域函数 ----------
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
    """Split a ``width x height`` image into an overlapping grid of regions.

    The grid has *num_splits_y* rows and *num_splits_x* columns of nominal
    size ``ceil(height / num_splits_y) x ceil(width / num_splits_x)``. Each
    cell is grown by *overlap* pixels on every side and clipped to the image.

    Returns:
        List of ``(left, top, right, bottom)`` pixel boxes relative to the
        image, in row-major order. Boxes that would be empty are omitted;
        this can happen when the image is smaller than the grid and the
        overlap is small. An empty window cannot be read or segmented.

    Raises:
        ValueError: if either split count is smaller than 1.
    """
    if num_splits_x < 1 or num_splits_y < 1:
        raise ValueError("num_splits_x and num_splits_y must be >= 1")
    cell_w = (width + num_splits_x - 1) // num_splits_x
    cell_h = (height + num_splits_y - 1) // num_splits_y
    regions = []
    for row in range(num_splits_y):
        top = max(row * cell_h - overlap, 0)
        bottom = min((row + 1) * cell_h + overlap, height)
        for col in range(num_splits_x):
            left = max(col * cell_w - overlap, 0)
            right = min((col + 1) * cell_w + overlap, width)
            if right > left and bottom > top:
                regions.append((left, top, right, bottom))
    return regions
# Default parameters of run_segmentation(); keys supplied in its ``config``
# argument override these.
DEFAULT_CONFIG = dict(
    # Input raster and output mask locations (machine-specific defaults).
    image_path=r"E:\is2\guidingsahn\result.tif",
    mask_output_path=r"E:\is2\guidingsahn\result_mask.tif",
    # Text prompt sent to SAM3 for every inference.
    prompt="water body",
    # Overview pass: longest side of the strided thumbnail, and its
    # probability threshold.
    overview_max_side=1400,
    overview_threshold=0.4,
    # ``resolution`` of the coarse / fine Sam3Processor instances.
    coarse_resolution=1008,
    fine_resolution=1008,
    # Probability thresholds for the coarse low-res map and the fine map.
    coarse_threshold=0.5,
    final_threshold=0.5,
    # Refinement band radius in full-resolution pixels (divided by
    # coarse_downsample_factor before use).
    band_radius=64,
    # Fine-pass tiling, in full-resolution pixels.
    fine_tile_size=2048,
    fine_overlap=256,
    # Coarse-pass tiling and the reduction factor of its output grid.
    coarse_tile_size=4096,
    coarse_tile_overlap=256,
    coarse_downsample_factor=4,
    # Region grid (rows x columns) and overlap between regions, in pixels.
    num_splits_y=2,
    num_splits_x=3,
    region_overlap=256,
    # Post-processing: drop components smaller than min_area pixels
    # (None disables); optionally keep only the largest component.
    min_area=5000,
    keep_largest_only=False,
    # Show tqdm progress bars.
    use_tqdm=True,
    # Torch device; None selects "cuda" when available, else "cpu".
    device=None,
)
def _nodata_mask(band, nodata):
    """Return a bool mask of *band* pixels equal to the NoData value *nodata*.

    ``None`` means "no NoData value" (all False). Float bands are compared
    with ``np.isclose``. A NaN *nodata* is special-cased to match NaN pixels,
    because ``np.isclose`` never treats NaN as equal to NaN.
    """
    if nodata is None:
        return np.zeros(band.shape, dtype=bool)
    if np.issubdtype(band.dtype, np.floating):
        if math.isnan(float(nodata)):
            return np.isnan(band)
        return np.isclose(band, float(nodata))
    return band == nodata


def _region_profile(src, window):
    """Build the profile of an in-memory GeoTIFF copy of *window* of *src*."""
    profile = src.profile.copy()
    if src.is_tiled:
        profile["tiled"] = True
    else:
        # Non-tiled source: do not carry its block sizes over to the copy.
        profile.pop("blockxsize", None)
        profile.pop("blockysize", None)
        profile["tiled"] = False
    profile.update(
        # Force GTiff so the in-memory copy is writable even when the
        # source format is not.
        driver="GTiff",
        height=int(window.height),
        width=int(window.width),
        transform=rasterio.windows.transform(window, src.transform),
    )
    return profile


def _resize_mask(mask, width, height):
    """Nearest-neighbour resize of a 2-D boolean mask to ``(height, width)``."""
    img = Image.fromarray(mask.astype(np.uint8))
    return np.asarray(img.resize((width, height), resample=Image.NEAREST)).astype(bool)


def _segment_region(sub_src, sub_bands, sub_nodata, coarse_processor, fine_processor, cfg, use_tqdm):
    """Segment one region and return its 0/1 ``uint8`` water mask.

    Three passes:
      1. An overview inference on a strided thumbnail of the whole region.
      2. Tiled coarse inference on a grid reduced by
         ``coarse_downsample_factor``.
      3. Tiled fine inference, limited to tiles that touch a band around the
         union of passes 1 and 2.
    The result is the OR of the three masks with NoData pixels cleared.
    """
    region_h, region_w = sub_nodata.shape
    prompt = cfg["prompt"]
    factor = cfg["coarse_downsample_factor"]
    has_nodata = bool(sub_nodata.any())
    stretch_params = compute_stretch_params(sub_bands)

    # 1) Overview pass.
    overview_bands, overview_step = make_overview_bands(
        sub_bands, max_side=cfg["overview_max_side"]
    )
    overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
    overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
    if overview_prob is None:
        overview_small = np.zeros((overview_img.height, overview_img.width), dtype=bool)
    else:
        overview_small = (
            (overview_prob > cfg["overview_threshold"]).detach().cpu().numpy()
        )
    if has_nodata:
        overview_small[sub_nodata[::overview_step, ::overview_step]] = False

    # 2) Coarse pass.
    coarse_probs = coarse_tile_segmentation(
        processor=coarse_processor,
        src=sub_src,
        prompt=prompt,
        tile_size=cfg["coarse_tile_size"],
        overlap=cfg["coarse_tile_overlap"],
        downsample_factor=factor,
        stretch_params=stretch_params,
        nodata_mask=sub_nodata if has_nodata else None,
        use_tqdm=use_tqdm,
    )
    coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
    lr_h, lr_w = coarse_mask_lr.shape
    seed_lr = coarse_mask_lr | _resize_mask(overview_small, lr_w, lr_h)
    band_radius_lr = max(int(round(cfg["band_radius"] / float(factor))), 1)
    band_lr = _build_band_ndimage(seed_lr, band_radius_lr)

    # 3) Fine pass on the band.
    fine_probs = refine_tiles(
        processor=fine_processor,
        src=sub_src,
        prompt=prompt,
        band_lr=band_lr,
        downsample_factor=factor,
        tile_size=cfg["fine_tile_size"],
        overlap=cfg["fine_overlap"],
        stretch_params=stretch_params,
        use_tqdm=use_tqdm,
    )
    sub_mask = (
        (fine_probs > cfg["final_threshold"])
        | _resize_mask(coarse_mask_lr, region_w, region_h)
        | _resize_mask(overview_small, region_w, region_h)
    )
    if has_nodata:
        # The fine pass does not see NoData, and nearest-neighbour upsampling
        # can spread coarse/overview positives onto NoData pixels. Clear
        # them here; internal NoData holes are filled later as a separate
        # post-processing step.
        sub_mask[sub_nodata] = False
    return sub_mask.astype(np.uint8)


def _fill_internal_nodata(full_mask, nodata_mask, log):
    """Set NoData holes that do not touch the image border to 1 in *full_mask*, in place.

    NoData components (SciPy's default 4-connectivity) reaching no image
    edge are "internal". They are dilated by 7 iterations of an 8-connected
    structuring element before being marked.
    """
    labeled, num = ndimage.label(nodata_mask)
    edge_labels = np.concatenate(
        (labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1])
    )
    touches_edge = np.zeros(num + 1, dtype=bool)
    touches_edge[edge_labels] = True
    internal = nodata_mask & ~touches_edge[labeled]
    if not internal.any():
        log(" 无内部NoData区域")
        return
    struct = ndimage.generate_binary_structure(2, 2)
    dilated = ndimage.binary_dilation(internal, structure=struct, iterations=7)
    full_mask[dilated] = 1
    log(f" 内部NoData原始像素数: {np.sum(internal)},膨胀后像素数: {np.sum(dilated)}")


def _write_mask(src, mask, path):
    """Write *mask* as a single-band LZW-compressed uint8 GeoTIFF georeferenced like *src*."""
    profile = src.profile.copy()
    # The source's NoData value is meaningless for a 0/1 mask. rasterio
    # rejects a value outside the uint8 range, and a value of 0 would flag
    # every non-water pixel as NoData. A multi-band photometric
    # interpretation (e.g. RGB) is likewise invalid for a single band.
    profile.pop("photometric", None)
    profile.update(driver="GTiff", count=1, dtype="uint8", compress="lzw", nodata=None)
    with rasterio.open(path, "w", **profile) as dst:
        dst.write(mask, 1)


def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
    """Segment water bodies in a large raster and write a 0/1 uint8 GeoTIFF mask.

    Args:
        config: optional dict overriding keys of ``DEFAULT_CONFIG``.
        progress_callback: optional callable, invoked as
            ``("region", index, total)`` (1-based index) before each region.
        log_callback: optional callable taking one message string; defaults
            to ``print``.
        stop_event: optional object with ``is_set()`` (e.g. a
            ``threading.Event``), checked before each region. When set, the
            function logs "已停止" and returns without writing any output.
    """
    cfg = dict(DEFAULT_CONFIG)
    if config:
        cfg.update(config)
    log = log_callback if log_callback is not None else print
    use_tqdm = bool(cfg.get("use_tqdm", True))

    device = cfg["device"]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"使用设备: {device}")
    model = build_sam3_image_model().to(device).eval()
    coarse_processor = Sam3Processor(
        model, resolution=cfg["coarse_resolution"], device=device
    )
    fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)

    image_path = cfg["image_path"]
    mask_output_path = cfg["mask_output_path"]
    with rasterio.open(image_path) as src:
        nodata = src.nodata
        log(f"原始影像NoData值: {nodata}")
        height, width = src.height, src.width
        # NoData is defined by band 1; only the derived mask is kept in memory.
        nodata_mask_full = _nodata_mask(src.read(1), nodata)

        regions = split_into_regions(
            width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
        )
        log(f"共划分 {len(regions)} 个子区域")
        full_mask = np.zeros((height, width), dtype=np.uint8)
        if use_tqdm:
            region_iter = tqdm(
                list(enumerate(regions)),
                total=len(regions),
                desc="处理子区域",
                unit="子区域",
            )
        else:
            region_iter = enumerate(regions)

        for idx, (left, top, right, bottom) in region_iter:
            if stop_event is not None and stop_event.is_set():
                log("已停止")
                return
            if progress_callback is not None:
                progress_callback("region", idx + 1, len(regions))
            log(
                f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
            )
            window = Window(left, top, right - left, bottom - top)
            sub_bands = src.read(window=window)
            sub_nodata = nodata_mask_full[top:bottom, left:right]
            # The tiling helpers read through a rasterio dataset, so the
            # region is staged in an in-memory GeoTIFF.
            with MemoryFile() as memfile:
                with memfile.open(**_region_profile(src, window)) as sub_dst:
                    sub_dst.write(sub_bands)
                with memfile.open() as sub_src:
                    sub_mask = _segment_region(
                        sub_src,
                        sub_bands,
                        sub_nodata,
                        coarse_processor,
                        fine_processor,
                        cfg,
                        use_tqdm,
                    )
            # Overlapping regions are merged with a per-pixel maximum.
            full_mask[top:bottom, left:right] = np.maximum(
                full_mask[top:bottom, left:right], sub_mask
            )

        log("\n后处理填充内部NoData空洞...")
        if nodata is not None:
            _fill_internal_nodata(full_mask, nodata_mask_full, log)

        log("后处理:面积过滤...")
        if cfg["min_area"] is not None or cfg["keep_largest_only"]:
            original_count = np.sum(full_mask)
            full_mask = filter_by_area(
                full_mask,
                min_area=cfg["min_area"],
                keep_largest_only=cfg["keep_largest_only"],
            )
            filtered_count = np.sum(full_mask)
            log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")

        _write_mask(src, full_mask, mask_output_path)
    log(f"分割完成,结果已保存至:{mask_output_path}")
def main():
    """Command-line entry point: run the segmentation with ``DEFAULT_CONFIG``."""
    try:
        # The Tk backend is only a convenience; nothing in this module calls
        # pyplot. pyplot is already imported, so switching backends loads Tk
        # immediately. On headless machines that raises ImportError, so in
        # that case keep whatever backend matplotlib already selected.
        matplotlib.use("TkAgg")
    except ImportError:
        pass
    run_segmentation()


if __name__ == "__main__":
    main()