Files
water-body-segmentation/water_V3.py
2026-03-09 17:23:53 +08:00

251 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm import tqdm
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# ---------- Utility functions (unchanged) ----------
def binary_dilate(mask, radius):
    """Dilate a binary mask with a (2*radius+1) x (2*radius+1) square window.

    Args:
        mask: Tensor accepted by ``F.max_pool2d`` (3-D or 4-D); the caller in
            this file passes a ``(1, 1, H, W)`` boolean tensor.
        radius: Half-width of the square window in pixels.

    Returns:
        ``mask`` itself, unchanged, when ``radius <= 0``; otherwise a boolean
        tensor of the same shape that is True wherever any pixel inside the
        window is set.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    pooled = mask.float()
    # The max over a square window is separable: a vertical pass followed by a
    # horizontal pass gives exactly the same result as one 2-D pool, but costs
    # O(kernel) per pixel instead of O(kernel ** 2).
    pooled = F.max_pool2d(pooled, kernel_size=(kernel, 1), stride=1, padding=(radius, 0))
    pooled = F.max_pool2d(pooled, kernel_size=(1, kernel), stride=1, padding=(0, radius))
    return pooled > 0.5
def binary_erode(mask, radius):
    """Erode a boolean mask by dilating its complement.

    Args:
        mask: Tensor supporting ``~`` (a boolean tensor is assumed).
        radius: Half-width of the square window in pixels.

    Returns:
        ``mask`` itself when ``radius <= 0``; otherwise the complement of the
        dilated complement of ``mask``.
    """
    if radius > 0:
        background = ~mask
        grown_background = binary_dilate(background, radius)
        return ~grown_background
    return mask
def combine_masks_logits(masks_logits):
    """Collapse a stack of per-detection maps into one map.

    Args:
        masks_logits: Tensor of per-detection maps; presumably shaped
            ``(N, 1, H, W)`` as produced by the processor -- TODO confirm.

    Returns:
        ``None`` when the tensor has no elements. Otherwise dimension 1 is
        squeezed (if it has size 1); a 2-D result is returned as is, and any
        other result is reduced with an element-wise maximum over dimension 0.
    """
    if masks_logits.numel() == 0:
        return None
    stacked = masks_logits.squeeze(1)
    return stacked if stacked.dim() == 2 else stacked.amax(dim=0)
def upsample_prob(prob, size):
    """Resize a 2-D map to ``size`` with bilinear interpolation.

    Args:
        prob: 2-D tensor ``(H, W)``; it is given batch and channel axes before
            being passed to ``F.interpolate``.
        size: Target ``(height, width)``.

    Returns:
        The resized 2-D tensor.
    """
    batched = prob.unsqueeze(0).unsqueeze(0)
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]
def build_band(mask, radius):
    """Return the band of pixels around the mask boundary.

    The band is the XOR of the mask dilated by ``radius`` and the mask eroded
    by ``radius``.

    Args:
        mask: 2-D boolean tensor ``(H, W)``.
        radius: Half-width of the dilation/erosion window in pixels.

    Returns:
        2-D boolean tensor marking the boundary band.
    """
    batched = mask.unsqueeze(0).unsqueeze(0)
    grown = binary_dilate(batched, radius)
    shrunk = binary_erode(batched, radius)
    return torch.logical_xor(grown, shrunk)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` tiles that cover a height x width grid.

    Tiles are generated row-major. Consecutive tile origins are
    ``max(tile_size - overlap, 1)`` apart, and tiles are clipped at the grid
    edge. Along each axis, generation stops as soon as a tile reaches the edge,
    so no trailing tile is emitted that would lie entirely inside the previous
    one (such slivers add no coverage and only give the model a degenerate
    crop).

    Args:
        height: Grid height in pixels.
        width: Grid width in pixels.
        tile_size: Maximum tile side length in pixels; must be positive.
        overlap: Number of pixels shared by neighbouring tiles.

    Yields:
        ``(top, left, bottom, right)`` with exclusive ``bottom``/``right``.

    Raises:
        ValueError: If ``tile_size`` is not positive (raised when iteration
            starts, because this is a generator).
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    stride = max(tile_size - overlap, 1)

    def axis_starts(length):
        # Tile origins along one axis, ending at the first tile that touches
        # the far edge.
        starts = []
        pos = 0
        while pos < length:
            starts.append(pos)
            if pos + tile_size >= length:
                break
            pos += stride
        return starts

    lefts = axis_starts(width)
    for top in axis_starts(height):
        bottom = min(top + tile_size, height)
        for left in lefts:
            yield top, left, bottom, min(left + tile_size, width)
def _to_uint8(arr):
if arr.dtype == np.uint8:
return arr
arr = arr.astype(np.float32)
vmin = np.percentile(arr, 2.0)
vmax = np.percentile(arr, 98.0)
if vmax <= vmin:
return np.zeros_like(arr, dtype=np.uint8)
arr = (arr - vmin) / (vmax - vmin)
arr = np.clip(arr, 0.0, 1.0)
return (arr * 255.0).astype(np.uint8)
def _bands_to_pil(bands):
    """Convert a ``(C, H, W)`` band stack into an 8-bit RGB PIL image.

    With three or more bands the first three are used as R, G and B. With
    fewer than three bands the first band is replicated into all three
    channels; this also covers 2-band input, which previously produced an
    ``(H, W, 2)`` array that cannot form an RGB image.

    Args:
        bands: NumPy array shaped ``(C, H, W)``.

    Returns:
        ``PIL.Image.Image`` in mode ``"RGB"``.

    Raises:
        ValueError: If ``bands`` is not 3-D or contains no bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    channels = bands.shape[0]
    if channels == 0:
        raise ValueError("bands must contain at least one band")
    if channels >= 3:
        rgb = bands[:3]
    else:
        # Grayscale fallback: repeat the first band into R, G and B.
        rgb = np.repeat(bands[:1], 3, axis=0)
    rgb = _to_uint8(rgb)
    # PIL expects channels last: (H, W, 3).
    rgb = np.transpose(rgb, (1, 2, 0))
    return Image.fromarray(rgb, mode="RGB")
def _read_pil_window(src, window):
    """Read one window of a raster and return it as an RGB PIL image.

    Only the first three bands (or fewer, if the dataset has fewer) are read,
    because the image conversion never looks past the third band; this avoids
    loading unused bands of multispectral rasters.

    Args:
        src: Open raster dataset exposing ``count`` and
            ``read(indexes=..., window=..., boundless=..., fill_value=...)``
            (a rasterio dataset in this script).
        window: Window to read; areas outside the dataset are filled with 0.

    Returns:
        The RGB image produced by ``_bands_to_pil``.
    """
    band_count = min(src.count, 3)
    indexes = list(range(1, band_count + 1))  # band indexes are 1-based
    bands = src.read(indexes=indexes, window=window, boundless=True, fill_value=0)
    return _bands_to_pil(bands)
def _coarse_shape(height, width, max_side):
scale = max_side / float(max(height, width))
h = max(int(round(height * scale)), 1)
w = max(int(round(width * scale)), 1)
return h, w
# ---------- New: tiled coarse-segmentation inference ----------
def coarse_tiles(processor, src, prompt, tile_size, overlap):
    """Run prompt-based segmentation tile by tile and merge into one full-size map.

    Every tile is read from ``src``, passed through ``processor.set_image`` and
    ``processor.set_text_prompt``, and the per-tile map is merged into the
    full-size result with an element-wise maximum, so overlapping tiles keep
    the highest value.

    Args:
        processor: Object exposing ``device``, ``set_image(image)`` and
            ``set_text_prompt(prompt=..., state=...)``; the returned state must
            contain ``"masks_logits"``. Presumably these are per-detection
            probability maps, since callers threshold the result at 0.5 --
            TODO confirm against the processor.
        src: Open raster dataset exposing ``height``, ``width`` and ``read``.
        prompt: Text prompt describing the target class.
        tile_size: Tile side length in pixels.
        overlap: Overlap between neighbouring tiles in pixels.

    Returns:
        float32 tensor of shape ``(height, width)`` on ``processor.device``;
        pixels of tiles without any detection stay 0.
    """
    height, width = src.height, src.width
    device = processor.device
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=device)
    # Materialise the tiles so the progress total always matches what is
    # actually iterated, instead of re-deriving the count by hand.
    tiles = list(tile_slices(height, width, tile_size, overlap))
    # no_grad: inference only; guarantees no autograd graph is kept alive
    # across tiles through the in-place accumulation into full_probs.
    with torch.no_grad(), tqdm(total=len(tiles), desc="粗分割分块", unit="") as pbar:
        for top, left, bottom, right in tiles:
            tile_h, tile_w = bottom - top, right - left
            crop = _read_pil_window(src, Window(left, top, tile_w, tile_h))
            state = processor.set_image(crop)
            state = processor.set_text_prompt(prompt=prompt, state=state)
            tile_prob = combine_masks_logits(state["masks_logits"])
            if tile_prob is not None:
                # Match the accumulator so merging never depends on the
                # model's output dtype or device.
                tile_prob = tile_prob.to(device=device, dtype=full_probs.dtype)
                if tile_prob.shape[-2:] != (tile_h, tile_w):
                    tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
                full_probs[top:bottom, left:right] = torch.maximum(
                    full_probs[top:bottom, left:right], tile_prob
                )
            pbar.update(1)
    return full_probs
# ---------- Tiled refinement (unchanged) ----------
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
    """Re-run segmentation only on tiles that intersect the boundary band.

    Tiles whose footprint contains no set pixel of ``band_coarse`` are skipped.
    For the remaining tiles the per-tile map is merged into the full-size
    result with an element-wise maximum.

    Args:
        processor: Object exposing ``device``, ``set_image(image)`` and
            ``set_text_prompt(prompt=..., state=...)``; the returned state must
            contain ``"masks_logits"``.
        src: Open raster dataset exposing ``height``, ``width`` and ``read``.
        prompt: Text prompt describing the target class.
        band_coarse: 2-D boolean band mask. It may have a different resolution
            than the raster; tile coordinates are rescaled to it.
        tile_size: Tile side length in pixels.
        overlap: Overlap between neighbouring tiles in pixels.

    Returns:
        float32 tensor of shape ``(height, width)`` on ``processor.device``;
        pixels of skipped tiles and of tiles without detections stay 0.
    """
    height, width = src.height, src.width
    device = processor.device
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=device)
    band_h, band_w = band_coarse.shape
    scale_y = band_h / float(height)
    scale_x = band_w / float(width)
    # Materialise the tiles so the progress total always matches what is
    # actually iterated, instead of re-deriving the count by hand.
    tiles = list(tile_slices(height, width, tile_size, overlap))
    # no_grad: inference only; guarantees no autograd graph is kept alive
    # across tiles through the in-place accumulation into full_probs.
    with torch.no_grad(), tqdm(total=len(tiles), desc="精修分块", unit="") as pbar:
        for top, left, bottom, right in tiles:
            # Project the tile onto the band grid: floor the start, ceil the
            # end, and keep at least one row/column.
            c_top = int(top * scale_y)
            c_left = int(left * scale_x)
            c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
            c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
            if band_coarse[c_top:c_bottom, c_left:c_right].any():
                tile_h, tile_w = bottom - top, right - left
                crop = _read_pil_window(src, Window(left, top, tile_w, tile_h))
                state = processor.set_image(crop)
                state = processor.set_text_prompt(prompt=prompt, state=state)
                tile_prob = combine_masks_logits(state["masks_logits"])
                if tile_prob is not None:
                    # Match the accumulator so merging never depends on the
                    # model's output dtype or device.
                    tile_prob = tile_prob.to(device=device, dtype=full_probs.dtype)
                    if tile_prob.shape[-2:] != (tile_h, tile_w):
                        tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
                    full_probs[top:bottom, left:right] = torch.maximum(
                        full_probs[top:bottom, left:right], tile_prob
                    )
            pbar.update(1)
    return full_probs
# ---------- Main program ----------
# Select the Tk GUI backend for matplotlib.
matplotlib.use("TkAgg")

# ---- Parameters ----
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV1.tif"
prompt = "water body"
coarse_read_max_side = 768  # no longer used; kept so nothing else that refers to it breaks
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 2048  # refinement tile size
overlap = 256  # refinement tile overlap
coarse_tile_size = 6144  # coarse tile size (tunable)
coarse_overlap = 256  # coarse tile overlap (tunable)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# One model instance is shared by both processors.
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)

with rasterio.open(image_path) as src:
    nodata = src.nodata
    print(f"原始影像NoData值: {nodata}")

    # Stage 1: coarse tiled segmentation over the whole raster.
    print("开始粗分割(分块推理)...")
    coarse_prob = coarse_tiles(
        coarse_processor, src, prompt,
        tile_size=coarse_tile_size, overlap=coarse_overlap
    )
    coarse_mask = coarse_prob > coarse_threshold

    # Stage 2: band around the coarse mask boundary.
    print("构建边缘带...")
    band = build_band(coarse_mask, band_radius)

    # Stage 3: re-segment only the tiles that touch the band.
    print("开始精修分割...")
    fine_probs = refine_tiles(
        fine_processor, src, prompt, band, tile_size, overlap
    )

    # Stage 4: merge by element-wise maximum, then threshold.
    print("合并粗/细结果...")
    final_prob = torch.maximum(fine_probs, coarse_prob)
    final_mask = final_prob > final_threshold
    mask_np = final_mask.cpu().numpy().astype(np.uint8)

    # ========== Fill interior NoData holes (with dilation) ==========
    if nodata is not None:
        band1 = src.read(1)
        if np.issubdtype(band1.dtype, np.floating):
            # equal_nan=True so that a NaN nodata value matches NaN pixels;
            # a plain isclose() never reports NaN as equal to NaN.
            nodata_mask = np.isclose(band1, float(nodata), equal_nan=True)
        else:
            nodata_mask = (band1 == nodata)

        from scipy import ndimage

        labeled_mask, num_features = ndimage.label(nodata_mask)
        # Labels present on any of the four image edges belong to exterior
        # NoData (label 0, the background, is excluded again by the
        # `& nodata_mask` below).
        boundary_labels = np.unique(
            np.concatenate(
                [
                    labeled_mask[0, :],
                    labeled_mask[-1, :],
                    labeled_mask[:, 0],
                    labeled_mask[:, -1],
                ]
            )
        )
        internal_nodata_mask = np.isin(labeled_mask, boundary_labels, invert=True) & nodata_mask

        if internal_nodata_mask.any():
            struct = ndimage.generate_binary_structure(2, 2)  # 8-connectivity
            internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
            # Interior NoData holes (grown by the dilation) are written as 1.
            mask_np[internal_dilated] = 1
            print(f"内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
        else:
            print("无内部NoData区域")
    # ==================================================

    profile = src.profile.copy()
    profile.update(count=1, dtype="uint8")
    # The profile inherits the source nodata value. Drop it when it cannot be
    # represented in uint8 (e.g. NaN or -9999), otherwise opening the output
    # for writing fails.
    out_nodata = profile.get("nodata")
    if out_nodata is not None:
        out_nodata = float(out_nodata)
        if not (out_nodata.is_integer() and 0 <= out_nodata <= 255):
            profile["nodata"] = None

    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(mask_np, 1)

print(f"分割完成,结果已保存至:{mask_output_path}")