Files
water-body-segmentation/water.py
2026-03-09 17:23:53 +08:00

183 lines
5.9 KiB
Python

import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
def binary_dilate(mask, radius):
    """Grow a mask by ``radius`` pixels with a square structuring element.

    ``mask`` is fed to ``F.max_pool2d``, so it must be a 3-D ``(C, H, W)`` or
    4-D ``(N, C, H, W)`` tensor. The result is a bool tensor of the same shape.
    A non-positive ``radius`` returns ``mask`` itself, untouched.
    """
    if radius <= 0:
        return mask
    window = radius * 2 + 1
    # stride=1 with padding=radius preserves the spatial size; max-pool padding
    # is implicitly -inf, so out-of-image pixels never switch a pixel on.
    pooled = F.max_pool2d(mask.float(), kernel_size=window, stride=1, padding=radius)
    return pooled > 0.5
def binary_erode(mask, radius):
    """Shrink a boolean mask by ``radius`` pixels (square structuring element).

    Computed as the complement of dilating the complement. Max-pool padding
    never turns a pixel on, so the area outside the image acts as if it were
    inside the mask: pixels are not eroded merely for touching the border.
    ``mask`` must support ``~`` (a bool tensor) and have the 3-D/4-D layout
    ``F.max_pool2d`` expects. A non-positive ``radius`` returns ``mask`` itself.
    """
    if radius <= 0:
        return mask
    window = radius * 2 + 1
    background = (~mask).float()
    grown_background = (
        F.max_pool2d(background, kernel_size=window, stride=1, padding=radius) > 0.5
    )
    return ~grown_background
def combine_masks_logits(masks_logits):
    """Merge per-instance mask maps into a single 2-D map by per-pixel maximum.

    Axis 1 is squeezed away when it has size 1 (a per-instance channel axis).
    If what remains is already 2-D it is returned as-is; otherwise the maximum
    is taken over axis 0 (the instance axis). Returns ``None`` for an empty
    tensor, i.e. when no instance was produced.

    NOTE(review): callers threshold the result against probability cut-offs
    such as 0.5, which assumes these values are sigmoid probabilities despite
    the "logits" name — confirm against Sam3Processor.
    """
    if not masks_logits.numel():
        return None
    stacked = masks_logits.squeeze(1)
    return stacked if stacked.dim() == 2 else stacked.amax(dim=0)
def upsample_prob(prob, size):
    """Bilinearly resize a 2-D map ``prob`` to ``size`` = ``(height, width)``.

    Batch and channel axes are added for ``F.interpolate`` and stripped again,
    so a 2-D tensor goes in and a 2-D tensor comes out (``align_corners=False``).
    """
    batched = prob.unsqueeze(0).unsqueeze(0)
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]
def infer_coarse(processor, image, prompt, threshold, target_size=None):
    """Run one text-prompted SAM3 pass on a (typically downsampled) image.

    Args:
        processor: a ``Sam3Processor``; ``set_image`` then ``set_text_prompt``
            are called on it.
        image: PIL image to segment. Its ``height``/``width`` size the empty
            fallback map and are the default output size.
        prompt: text prompt, e.g. ``"water body"``.
        threshold: cut-off applied to the probability map to build the mask.
        target_size: optional ``(height, width)`` to bilinearly resize the map
            to; defaults to the image's own size.

    Returns:
        ``(prob_cpu, mask_cpu)``: the resized map moved to CPU, and the boolean
        mask ``prob_cpu > threshold``.
    """
    session = processor.set_image(image)
    session = processor.set_text_prompt(prompt=prompt, state=session)
    merged = combine_masks_logits(session["masks_logits"])
    if merged is None:
        # No instance came back for the prompt: treat the scene as all-negative.
        merged = torch.zeros((image.height, image.width), device=processor.device)
    out_size = (image.height, image.width) if target_size is None else target_size
    # NOTE(review): the resize to ``out_size`` runs on whatever device ``merged``
    # lives on before the copy to CPU; for full-resolution targets on large
    # rasters this can be memory-heavy.
    prob_cpu = upsample_prob(merged, out_size).detach().cpu()
    return prob_cpu, prob_cpu > threshold
def build_band(mask_cpu, radius):
    """Return the boolean band of pixels within ``radius`` of a mask boundary.

    The band is the dilated mask XOR the eroded mask. With ``radius <= 0`` both
    morphology steps are no-ops and the band is empty. ``mask_cpu`` is expected
    to be a 2-D boolean tensor; batch and channel axes are added for the
    pooling-based morphology and removed again, so the result is 2-D.
    """
    as_batch = mask_cpu.unsqueeze(0).unsqueeze(0)
    outer = binary_dilate(as_batch, radius)
    inner = binary_erode(as_batch, radius)
    return torch.logical_xor(outer, inner)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` tiles covering a ``height x width`` grid.

    The step between tile starts is ``tile_size - overlap`` (never below 1).
    Along an axis longer than ``tile_size``, the last tile is shifted back so it
    ends exactly at the edge. Every tile is therefore full-sized, and no thin
    edge sliver is emitted that lies entirely inside the previous tile. Along an
    axis no longer than ``tile_size``, a single tile spans the whole axis.
    ``bottom``/``right`` are exclusive. Tiles are yielded in row-major order.

    Raises:
        ValueError: if ``tile_size`` is smaller than 1 (raised on first iteration).
    """
    if tile_size < 1:
        raise ValueError("tile_size must be >= 1")
    stride = max(tile_size - overlap, 1)
    for top in _tile_starts(height, tile_size, stride):
        bottom = min(top + tile_size, height)
        for left in _tile_starts(width, tile_size, stride):
            right = min(left + tile_size, width)
            yield top, left, bottom, right


def _tile_starts(length, tile_size, stride):
    """Return tile start offsets along one axis of ``length`` pixels."""
    if length <= 0:
        return []
    if length <= tile_size:
        return [0]
    last = length - tile_size
    # Regular steps up to (but excluding) the edge-aligned final start.
    starts = list(range(0, last, stride))
    starts.append(last)
    return starts
def _to_uint8(arr):
if arr.dtype == np.uint8:
return arr
arr = arr.astype(np.float32)
vmin = np.percentile(arr, 2.0)
vmax = np.percentile(arr, 98.0)
if vmax <= vmin:
return np.zeros_like(arr, dtype=np.uint8)
arr = (arr - vmin) / (vmax - vmin)
arr = np.clip(arr, 0.0, 1.0)
return (arr * 255.0).astype(np.uint8)
def _bands_to_pil(bands):
    """Convert a ``(C, H, W)`` band array into an 8-bit RGB ``PIL.Image``.

    The first three bands become R, G, B. Rasters with fewer than three bands
    are rendered as grayscale by repeating the first band. Previously a two-band
    raster produced an ``(H, W, 2)`` array that PIL cannot build an RGB image
    from. Values pass through ``_to_uint8``, so non-uint8 data is
    contrast-stretched.

    Raises:
        ValueError: if ``bands`` is not 3-D or contains no bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    count = bands.shape[0]
    if count == 0:
        raise ValueError("bands must contain at least one band")
    if count < 3:
        rgb = np.repeat(bands[:1], 3, axis=0)
    else:
        rgb = bands[:3]
    rgb = np.transpose(_to_uint8(rgb), (1, 2, 0))
    # Pillow infers "RGB" from an (H, W, 3) uint8 array, so the deprecated
    # ``mode=`` argument of ``Image.fromarray`` is unnecessary.
    return Image.fromarray(rgb)
def _read_pil_window(src, window):
    """Read ``window`` from the open rasterio dataset ``src`` as an RGB PIL image.

    The read is boundless with ``fill_value=0``, so any part of the window that
    lies outside the raster comes back as zeros instead of raising.
    """
    return _bands_to_pil(src.read(window=window, boundless=True, fill_value=0))
def _coarse_shape(height, width, max_side):
scale = max_side / float(max(height, width))
h = max(int(round(height * scale)), 1)
w = max(int(round(width * scale)), 1)
return h, w
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
    """Re-run SAM3 at full resolution on tiles that intersect an edge band.

    Args:
        processor: ``Sam3Processor`` used for every tile.
        src: open rasterio dataset; each tile is read from it as a window.
        prompt: text prompt passed for every tile.
        band_coarse: 2-D band mask (e.g. from ``build_band``) at any
            resolution; tile bounds are mapped onto it by the height/width ratio.
        tile_size, overlap: tiling parameters forwarded to ``tile_slices``.

    Returns:
        A ``(src.height, src.width)`` float16 tensor holding, per pixel, the
        maximum probability over all refined tiles covering it. Pixels covered
        only by skipped tiles (no band pixel, or no mask returned) stay 0.
    """
    height, width = src.height, src.width
    merged = torch.zeros((height, width), dtype=torch.float16)
    rows, cols = band_coarse.shape
    ratio_y = rows / float(height)
    ratio_x = cols / float(width)
    for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
        tile_h = bottom - top
        tile_w = right - left
        # Project the tile onto the band grid; keep at least one band cell.
        y0 = int(top * ratio_y)
        x0 = int(left * ratio_x)
        y1 = max(int(np.ceil(bottom * ratio_y)), y0 + 1)
        x1 = max(int(np.ceil(right * ratio_x)), x0 + 1)
        if not band_coarse[y0:y1, x0:x1].any():
            continue
        crop = _read_pil_window(src, Window(left, top, tile_w, tile_h))
        session = processor.set_image(crop)
        session = processor.set_text_prompt(prompt=prompt, state=session)
        tile_map = combine_masks_logits(session["masks_logits"])
        if tile_map is None:
            continue
        if tile_map.shape[-2:] != (tile_h, tile_w):
            tile_map = upsample_prob(tile_map, (tile_h, tile_w))
        tile_map = tile_map.detach().cpu().to(merged.dtype)
        current = merged[top:bottom, left:right]
        # Overlapping tiles are merged by taking the per-pixel maximum.
        merged[top:bottom, left:right] = torch.maximum(current, tile_map)
    return merged
# ---------------------------------------------------------------------------
# Script: two-pass "water body" segmentation of a raster with SAM3.
#   1. Coarse pass on a downsampled copy of the whole raster.
#   2. Full-resolution refinement on tiles that intersect a band around the
#      coarse mask edges.
# The per-pixel maximum of both probability maps is thresholded and written as
# a single-band uint8 GeoTIFF (1 where the probability exceeds final_threshold).
#
# No figures are drawn, so no interactive matplotlib backend is forced here;
# selecting "TkAgg" unconditionally raised ImportError on headless machines.
# ---------------------------------------------------------------------------
image_path = r"E:\is2\yaopu\result.tif"
mask_output_path = r"E:\is2\yaopu\result_mask.tif"
prompt = "water body"
coarse_read_max_side = 768  # longest side (px) of the downsampled first-pass read
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64  # edge-band half-width in full-resolution pixels
tile_size = 2048
overlap = 256
device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
# Inference only: no_grad prevents autograd bookkeeping across the many tile
# passes (harmless if the processor already disables gradients itself).
with rasterio.open(image_path) as src, torch.no_grad():
    coarse_h, coarse_w = _coarse_shape(src.height, src.width, coarse_read_max_side)
    coarse_bands = src.read(out_shape=(src.count, coarse_h, coarse_w))
    coarse_image = _bands_to_pil(coarse_bands)
    # Coarse probabilities are resized straight to full raster resolution.
    coarse_prob, coarse_mask = infer_coarse(
        coarse_processor,
        coarse_image,
        prompt,
        coarse_threshold,
        target_size=(src.height, src.width),
    )
    band = build_band(coarse_mask, band_radius)
    fine_probs = refine_tiles(
        fine_processor, src, prompt, band, tile_size, overlap
    )
    # NOTE(review): a per-pixel maximum means refinement can only add positives;
    # coarse false positives inside the band are never removed.
    final_prob = torch.maximum(fine_probs.float(), coarse_prob.float())
    final_mask = final_prob > final_threshold
    mask_np = final_mask.numpy().astype(np.uint8)
    profile = src.profile.copy()
    # Write a lossless single-band uint8 GeoTIFF. The source's nodata,
    # compression and photometric settings do not fit a 0/1 mask:
    #   - a float nodata such as -9999 is outside the uint8 range;
    #   - a nodata of 0 would mark every background pixel as missing;
    #   - JPEG compression is lossy;
    #   - a YCbCr photometric tag needs three bands.
    profile.update(
        driver="GTiff", count=1, dtype="uint8", nodata=None, compress="lzw"
    )
    profile.pop("photometric", None)
    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(mask_np, 1)