Files
water-body-segmentation/water_V2.py
2026-03-09 17:23:53 +08:00

246 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm import tqdm
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
def binary_dilate(mask, radius):
    """Dilate a boolean mask with a square ``(2*radius + 1)`` structuring element.

    Args:
        mask: Tensor accepted by ``F.max_pool2d``. In this file it is a
            ``(N, C, H, W)`` boolean tensor (``build_band`` passes
            ``mask[None, None, ...]``).
        radius: Half-width of the square window in pixels. ``radius <= 0``
            returns ``mask`` itself, unchanged.

    Returns:
        Boolean tensor with the same shape as ``mask``.

    A square max filter is separable, so it is computed as a vertical
    ``(k, 1)`` pass followed by a horizontal ``(1, k)`` pass. The result is
    identical to a single ``k x k`` pooling, but each pixel inspects ``2k``
    values instead of ``k*k``. This matters for full-resolution masks with
    large radii. ``max_pool2d`` pads with -inf, so pixels outside the image
    never count as foreground.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    grown = F.max_pool2d(
        mask.float(), kernel_size=(kernel, 1), stride=1, padding=(radius, 0)
    )
    grown = F.max_pool2d(
        grown, kernel_size=(1, kernel), stride=1, padding=(0, radius)
    )
    return grown > 0.5
def binary_erode(mask, radius):
    """Erode a boolean mask with a square ``(2*radius + 1)`` window.

    Erosion is computed by duality: it is the complement of dilating the
    complement. Because ``binary_dilate`` pools with implicit -inf padding,
    pixels outside the image never count as background. So the mask is not
    eroded inward from the image border.

    Args:
        mask: Boolean tensor (``~`` is applied to it, so it must be bool).
        radius: Half-width of the window. ``radius <= 0`` returns ``mask``
            unchanged.

    Returns:
        Boolean tensor with the same shape as ``mask``.
    """
    if radius > 0:
        background = ~mask
        grown_background = binary_dilate(background, radius)
        return ~grown_background
    return mask
def combine_masks_logits(masks_logits):
    """Merge per-instance masks into one 2-D map by taking the pixelwise max.

    Args:
        masks_logits: Tensor taken from ``state["masks_logits"]``. Dim 1 is
            squeezed if it has size 1; a resulting 2-D tensor is returned
            as-is, otherwise the max over dim 0 is taken.
            NOTE(review): despite the name, callers threshold the result at
            0.5 and merge it with ``torch.maximum``, so this presumably holds
            per-mask probabilities. Confirm against ``Sam3Processor``.

    Returns:
        The merged tensor, or ``None`` when ``masks_logits`` is empty (no
        detections).
    """
    if not masks_logits.numel():
        return None
    per_mask = masks_logits.squeeze(1)
    return per_mask if per_mask.dim() == 2 else torch.amax(per_mask, dim=0)
def upsample_prob(prob, size):
    """Bilinearly resize a 2-D probability map to ``size`` = ``(height, width)``.

    Args:
        prob: 2-D floating-point tensor.
        size: Target ``(height, width)``.

    Returns:
        2-D tensor of shape ``size`` (``align_corners=False`` sampling).
    """
    batched = prob.unsqueeze(0).unsqueeze(0)  # F.interpolate wants (N, C, H, W)
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]
def infer_coarse(processor, image, prompt, threshold, target_size=None):
    """Segment ``image`` with ``prompt`` and return a probability map and mask.

    Args:
        processor: SAM3 processor (``set_image`` / ``set_text_prompt``).
        image: PIL image to segment.
        prompt: Text prompt.
        threshold: Pixels with probability strictly above this are foreground.
        target_size: Optional ``(height, width)`` to resize the probability map
            to. Defaults to the image's own ``(height, width)``.

    Returns:
        ``(prob, mask)``: float probability map and boolean mask of the same
        shape. When nothing is detected, ``prob`` is all zeros (created on
        ``processor.device``).
    """
    native_size = (image.height, image.width)
    seg_state = processor.set_image(image)
    seg_state = processor.set_text_prompt(prompt=prompt, state=seg_state)
    merged = combine_masks_logits(seg_state["masks_logits"])
    if merged is None:
        merged = torch.zeros(native_size, device=processor.device)
    output_size = native_size if target_size is None else target_size
    merged = upsample_prob(merged, output_size)
    return merged, merged > threshold
def build_band(mask, radius):
    """Return the boundary band of ``mask``: pixels within ``radius`` of an edge.

    The band is ``dilate(mask) XOR erode(mask)``, i.e. pixels whose label could
    change if the boundary moved by up to ``radius`` pixels.

    Args:
        mask: 2-D boolean tensor.
        radius: Band half-width in pixels.

    Returns:
        2-D boolean tensor with the same shape as ``mask``.
    """
    batched = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W) for pooling
    outer = binary_dilate(batched, radius)
    inner = binary_erode(batched, radius)
    return torch.logical_xor(outer, inner)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield overlapping tile bounds covering a ``height x width`` grid.

    Tiles start every ``tile_size - overlap`` pixels (at least 1), row by row,
    and are clipped to the grid. Each item is ``(top, left, bottom, right)``
    with half-open ``[top, bottom) x [left, right)`` bounds.

    NOTE(review): tiles at the far edges can be much narrower than
    ``tile_size`` and may lie entirely inside the previous tile's overlap.
    """
    step = max(tile_size - overlap, 1)
    yield from (
        (row, col, min(row + tile_size, height), min(col + tile_size, width))
        for row in range(0, height, step)
        for col in range(0, width, step)
    )
def _to_uint8(arr):
if arr.dtype == np.uint8:
return arr
arr = arr.astype(np.float32)
vmin = np.percentile(arr, 2.0)
vmax = np.percentile(arr, 98.0)
if vmax <= vmin:
return np.zeros_like(arr, dtype=np.uint8)
arr = (arr - vmin) / (vmax - vmin)
arr = np.clip(arr, 0.0, 1.0)
return (arr * 255.0).astype(np.uint8)
def _bands_to_pil(bands):
    """Convert a ``(C, H, W)`` band stack into an 8-bit RGB PIL image.

    * ``C >= 3``: the first three bands are used as R, G, B.
    * ``C`` of 1 or 2: the first band is replicated into all three channels
      (grayscale). Taking ``bands[:3]`` of a 2-band stack would give an
      ``(H, W, 2)`` array, which cannot form an RGB image.

    Non-``uint8`` data is contrast-stretched by ``_to_uint8``.

    Raises:
        ValueError: if ``bands`` is not 3-D or contains no bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    channels = bands.shape[0]
    if channels == 0:
        raise ValueError("bands must contain at least one band")
    if channels >= 3:
        rgb = bands[:3]
    else:
        # NOTE(review): for 2-band rasters band 2 is ignored; change this if a
        # different band should drive the grayscale preview.
        rgb = np.repeat(bands[:1], 3, axis=0)
    rgb = _to_uint8(rgb)
    # (3, H, W) uint8 -> (H, W, 3). Pillow infers mode "RGB" from this shape
    # and dtype; passing ``mode=`` explicitly is deprecated in recent Pillow.
    return Image.fromarray(np.transpose(rgb, (1, 2, 0)))
def _read_pil_window(src, window):
    """Read every band of ``src`` inside ``window`` and return it as an RGB PIL image.

    The read is ``boundless``, so any part of the window outside the raster is
    filled with 0 rather than raising.
    """
    return _bands_to_pil(src.read(window=window, boundless=True, fill_value=0))
def _coarse_shape(height, width, max_side):
scale = max_side / float(max(height, width))
h = max(int(round(height * scale)), 1)
w = max(int(round(width * scale)), 1)
return h, w
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
    """Re-segment full-resolution tiles that overlap the refinement band.

    Args:
        processor: SAM3 processor used on each tile. ``processor.device`` is
            where the output tensor is allocated.
        src: Open rasterio dataset; each tile is read from it as a window.
        prompt: Text prompt passed to ``processor.set_text_prompt``.
        band_coarse: 2-D boolean tensor marking pixels worth refining. It may
            have a lower resolution than ``src``; tile bounds are rescaled
            onto it.
        tile_size: Tile edge length in full-resolution pixels.
        overlap: Overlap between neighbouring tiles in pixels.

    Returns:
        ``(src.height, src.width)`` float32 tensor on ``processor.device``.
        Pixels covered by tiles with detections hold the maximum tile
        probability over those tiles. Every other pixel is 0.
    """
    height, width = src.height, src.width
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)
    if height == 0 or width == 0:
        return full_probs
    band_h, band_w = band_coarse.shape
    scale_y = band_h / float(height)
    scale_x = band_w / float(width)
    # Enumerate tiles with the shared helper so the progress-bar total always
    # matches the tiles actually visited (no duplicated stride arithmetic).
    tiles = list(tile_slices(height, width, tile_size, overlap))
    with tqdm(tiles, desc="精修分块", unit="") as progress:
        for top, left, bottom, right in progress:
            # Project the tile onto band_coarse. Floor the start, ceil the end,
            # and keep at least one coarse cell so no band pixel is missed.
            c_top = int(top * scale_y)
            c_left = int(left * scale_x)
            c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
            c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
            if not band_coarse[c_top:c_bottom, c_left:c_right].any():
                continue  # no band pixels here: leave this tile at 0
            window = Window(left, top, right - left, bottom - top)
            state = processor.set_image(_read_pil_window(src, window))
            state = processor.set_text_prompt(prompt=prompt, state=state)
            tile_prob = combine_masks_logits(state["masks_logits"])
            if tile_prob is None:
                continue  # nothing detected in this tile
            tile_hw = (bottom - top, right - left)
            if tile_prob.shape[-2:] != tile_hw:
                tile_prob = upsample_prob(tile_prob, tile_hw)
            # Overlapping tiles are merged by taking the pixelwise maximum.
            full_probs[top:bottom, left:right] = torch.maximum(
                full_probs[top:bottom, left:right], tile_prob
            )
    return full_probs
# Matplotlib backend setup
matplotlib.use("TkAgg")
# Parameters
image_path = r"E:\is2\yaopu\result.tif"
mask_output_path = r"E:\is2\yaopu\result_mask.tif"
prompt = "water body"
coarse_read_max_side = 768
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 2048
overlap = 256
# Number of 8-connected dilation passes applied to interior NoData holes before
# they are burned into the water mask, to close the seam between a hole and
# the SAM water mask around it.
nodata_dilate_iterations = 7
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
with rasterio.open(image_path) as src:
    nodata = src.nodata  # source NoData value; may be None
    print(f"原始影像NoData值: {nodata}")
    print("开始粗分割...")
    coarse_h, coarse_w = _coarse_shape(src.height, src.width, coarse_read_max_side)
    coarse_bands = src.read(out_shape=(src.count, coarse_h, coarse_w))
    coarse_image = _bands_to_pil(coarse_bands)
    coarse_prob, coarse_mask = infer_coarse(
        coarse_processor,
        coarse_image,
        prompt,
        coarse_threshold,
        target_size=(src.height, src.width),
    )
    print("构建边缘带...")
    band = build_band(coarse_mask, band_radius)
    print("开始精修分割...")
    fine_probs = refine_tiles(
        fine_processor, src, prompt, band, tile_size, overlap
    )
    print("合并粗/细结果...")
    final_prob = torch.maximum(fine_probs, coarse_prob)
    final_mask = final_prob > final_threshold
    mask_np = final_mask.cpu().numpy().astype(np.uint8)
    # ========== Fill interior NoData holes as water ==========
    # NoData components fully enclosed by valid pixels (not touching the
    # raster border) are marked as water. Border-connected NoData is left alone.
    if nodata is not None:
        band1 = src.read(1)  # NOTE(review): NoData is detected on band 1 only
        if np.issubdtype(band1.dtype, np.floating):
            if np.isnan(nodata):
                # np.isclose(x, nan) is always False, so a NaN NoData value
                # must be matched with isnan or no NoData would ever be found.
                nodata_mask = np.isnan(band1)
            else:
                nodata_mask = np.isclose(band1, float(nodata))
        else:
            nodata_mask = (band1 == nodata)
        from scipy import ndimage
        labeled_mask, _ = ndimage.label(nodata_mask)
        boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
        boundary_mask[0, :] = True
        boundary_mask[-1, :] = True
        boundary_mask[:, 0] = True
        boundary_mask[:, -1] = True
        # Labels present on the raster border (label 0 = not NoData).
        boundary_labels = np.unique(labeled_mask[boundary_mask])
        internal_nodata_mask = np.isin(labeled_mask, boundary_labels, invert=True) & nodata_mask
        if internal_nodata_mask.any():
            # Grow the holes by nodata_dilate_iterations pixels so they overlap
            # the SAM water mask instead of leaving a gap between the two.
            struct = ndimage.generate_binary_structure(2, 2)  # rank 2, connectivity 2 = 8-neighbourhood
            internal_dilated = ndimage.binary_dilation(
                internal_nodata_mask, structure=struct, iterations=nodata_dilate_iterations
            )
            mask_np[internal_dilated] = 1
            print(f"内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
        else:
            print("无内部NoData区域")
    # ==========================================================
    profile = src.profile.copy()
    # The output is a single-band 0/1 mask, so settings describing the source
    # imagery must not be inherited:
    # * nodata: an inherited 0 would flag every non-water pixel as NoData, and
    #   values such as NaN or -9999 are outside the uint8 range (rasterio
    #   refuses such nodata values when opening for writing).
    # * photometric: e.g. RGB/YCbCr only apply to multi-band data.
    # * compress: use a lossless codec; a lossy source codec (e.g. JPEG) would
    #   alter the 0/1 values.
    profile.pop("photometric", None)
    profile.update(count=1, dtype="uint8", nodata=None, compress="deflate")
    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(mask_np, 1)
print(f"分割完成,结果已保存至:{mask_output_path}")