251 lines
8.9 KiB
Python
251 lines
8.9 KiB
Python
import torch
|
||
import torch.nn.functional as F
|
||
import matplotlib
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
from PIL import Image
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
from tqdm import tqdm
|
||
|
||
from sam3.model_builder import build_sam3_image_model
|
||
from sam3.model.sam3_image_processor import Sam3Processor
|
||
|
||
|
||
# ---------- Utility functions ----------
def binary_dilate(mask, radius):
    """Dilate a mask with a square (2*radius+1) x (2*radius+1) window.

    Args:
        mask: Tensor accepted by ``F.max_pool2d`` -- (N, C, H, W) or (C, H, W).
            After ``.float()``, values > 0.5 count as foreground (True for a
            bool mask).
        radius: Half-width of the square window in pixels. ``radius <= 0``
            returns ``mask`` itself, unchanged (no copy, no dtype conversion).

    Returns:
        Boolean tensor with the same shape as ``mask``.
    """
    if radius <= 0:
        return mask
    kernel = 2 * radius + 1
    pooled = mask.float()
    # A square max filter is separable: a horizontal 1 x k pass followed by a
    # vertical k x 1 pass equals one k x k pass, but costs O(k) instead of
    # O(k^2) comparisons per pixel. This matters because the band radius is
    # applied to a full-resolution mask. max_pool2d pads with -inf, so pixels
    # outside the image never contribute, exactly as in a single k x k pass.
    pooled = F.max_pool2d(pooled, kernel_size=(1, kernel), stride=1, padding=(0, radius))
    pooled = F.max_pool2d(pooled, kernel_size=(kernel, 1), stride=1, padding=(radius, 0))
    return pooled > 0.5
def binary_erode(mask, radius):
    """Erode a boolean mask with a square (2*radius+1) window.

    Uses morphological duality: the erosion of ``mask`` is the complement of
    the dilation of its complement. ``mask`` must support ``~`` (a bool
    tensor) and have a shape accepted by ``binary_dilate``. Because that
    dilation ignores out-of-image pixels, foreground touching the image
    border is not eroded from the border side.
    ``radius <= 0`` returns ``mask`` itself, unchanged.
    """
    if radius > 0:
        grown_background = binary_dilate(~mask, radius)
        return ~grown_background
    return mask
def combine_masks_logits(masks_logits):
    """Merge per-instance mask maps into a single map by pixel-wise maximum.

    Args:
        masks_logits: Tensor read from the processor state under
            ``"masks_logits"``. Presumably (N, 1, H, W), one map per detected
            instance -- TODO confirm against Sam3Processor. Callers threshold
            the merged map at 0.5 and combine tiles with ``torch.maximum``,
            i.e. treat it like a probability map. Whether the processor stores
            probabilities or raw logits under this key is not visible here.

    Returns:
        ``None`` when the tensor has no elements (no detections). Otherwise
        dimension 1 is squeezed. A 2-D result is returned as is; anything
        higher is reduced with ``amax`` over dimension 0, giving (H, W) for
        an (N, 1, H, W) input.
    """
    if masks_logits.numel() == 0:
        return None
    per_instance = masks_logits.squeeze(1)
    return per_instance if per_instance.dim() == 2 else per_instance.amax(dim=0)
def upsample_prob(prob, size):
    """Bilinearly resize a 2-D map to ``size`` = (height, width).

    Works for both up- and down-sampling. Uses ``align_corners=False`` and
    returns a 2-D tensor.
    """
    batched = prob.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W) for F.interpolate
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]
def build_band(mask, radius):
    """Return the boundary band of a 2-D boolean mask.

    The band holds the pixels that are in the square-window dilation of
    ``mask`` but not in its erosion (same window). These are the pixels within
    ``radius`` of a foreground/background transition. With ``radius <= 0``
    both operations return ``mask`` and the band is empty.
    """
    as_4d = mask.unsqueeze(0).unsqueeze(0)  # morphology helpers expect (N, C, H, W)
    outer = binary_dilate(as_4d, radius)
    inner = binary_erode(as_4d, radius)
    return torch.logical_xor(outer, inner)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` windows covering a height x width grid.

    Windows advance by ``stride = max(tile_size - overlap, 1)``. On each axis,
    the last window is shifted back so that it ends exactly at the image edge
    and keeps the full ``tile_size`` extent whenever the image is at least
    that large.

    A plain ``range(0, size, stride)`` walk instead emits thin sliver windows
    at the bottom/right edges, which are already covered by the previous
    window. Those slivers are presumably resized to the processor's fixed
    square resolution, so they are heavily distorted, and the callers still
    merge their predictions with ``torch.maximum``.

    ``bottom``/``right`` are exclusive and every pixel is covered by at least
    one window. A grid with ``height <= 0`` or ``width <= 0`` yields nothing.

    Raises:
        ValueError: if ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")
    if height <= 0 or width <= 0:
        return
    stride = max(tile_size - overlap, 1)

    def axis_spans(length):
        # Number of windows needed so that the last one reaches the edge.
        count = max(length - tile_size + stride - 1, 0) // stride + 1
        spans = []
        for idx in range(count):
            end = min(idx * stride + tile_size, length)
            spans.append((max(end - tile_size, 0), end))
        return spans

    col_spans = axis_spans(width)
    for top, bottom in axis_spans(height):
        for left, right in col_spans:
            yield top, left, bottom, right
def _to_uint8(arr):
|
||
if arr.dtype == np.uint8:
|
||
return arr
|
||
arr = arr.astype(np.float32)
|
||
vmin = np.percentile(arr, 2.0)
|
||
vmax = np.percentile(arr, 98.0)
|
||
if vmax <= vmin:
|
||
return np.zeros_like(arr, dtype=np.uint8)
|
||
arr = (arr - vmin) / (vmax - vmin)
|
||
arr = np.clip(arr, 0.0, 1.0)
|
||
return (arr * 255.0).astype(np.uint8)
|
||
|
||
def _bands_to_pil(bands):
    """Render a (C, H, W) band stack as an 8-bit RGB PIL image.

    - 1 or 2 bands: the first band is replicated into R, G and B. A 2-band
      stack would otherwise yield an (H, W, 2) array that PIL cannot build
      as RGB.
    - 3 or more bands: the first three bands are used, in file order, as
      R, G, B. NOTE(review): this assumes bands 1-3 are red/green/blue;
      confirm for multispectral inputs stored in another order.

    Non-uint8 data is contrast-stretched to uint8 by ``_to_uint8``.

    Raises:
        ValueError: if ``bands`` is not 3-D or contains zero bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    if bands.shape[0] == 0:
        raise ValueError("bands must contain at least one band")
    if bands.shape[0] < 3:
        rgb = np.repeat(bands[:1], 3, axis=0)
    else:
        rgb = bands[:3]
    rgb = _to_uint8(rgb)
    # PIL expects channel-last (H, W, 3).
    rgb = np.ascontiguousarray(np.transpose(rgb, (1, 2, 0)))
    # The "RGB" mode is inferred from the (H, W, 3) uint8 array; passing
    # mode= explicitly is deprecated in recent Pillow releases.
    return Image.fromarray(rgb)
def _read_pil_window(src, window):
    """Read all bands of ``window`` from an open rasterio dataset as an RGB PIL image.

    The read is boundless: parts of the window outside the raster are filled
    with 0 instead of raising.
    """
    return _bands_to_pil(src.read(window=window, boundless=True, fill_value=0))
def _coarse_shape(height, width, max_side):
|
||
scale = max_side / float(max(height, width))
|
||
h = max(int(round(height * scale)), 1)
|
||
w = max(int(round(width * scale)), 1)
|
||
return h, w
|
||
|
||
# ---------- Coarse segmentation: tiled inference over the full image ----------
def coarse_tiles(processor, src, prompt, tile_size, overlap):
    """Run text-prompted segmentation tile by tile and merge into a full-size map.

    Args:
        processor: Object exposing ``device``, ``set_image(pil_image)`` and
            ``set_text_prompt(prompt=..., state=...)``. The latter returns a
            state dict holding a ``"masks_logits"`` tensor (a Sam3Processor
            in this script).
        src: Open rasterio dataset that tiles are read from.
        prompt: Text prompt passed to the processor for every tile.
        tile_size: Tile edge length in pixels, forwarded to ``tile_slices``.
        overlap: Tile overlap in pixels, forwarded to ``tile_slices``.

    Returns:
        float32 tensor of shape (src.height, src.width) on
        ``processor.device``. Each pixel holds the maximum value produced by
        any tile covering it. Pixels covered only by tiles without detections
        stay 0.
    """
    height, width = src.height, src.width
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)

    # Materialize the windows once so the progress-bar total always matches
    # what tile_slices actually yields (no separately maintained formula).
    tiles = list(tile_slices(height, width, tile_size, overlap))
    for top, left, bottom, right in tqdm(tiles, desc="粗分割分块", unit="块"):
        tile_h, tile_w = bottom - top, right - left
        crop = _read_pil_window(src, Window(left, top, tile_w, tile_h))
        state = processor.set_image(crop)
        state = processor.set_text_prompt(prompt=prompt, state=state)
        tile_prob = combine_masks_logits(state["masks_logits"])
        if tile_prob is None:
            continue  # nothing detected in this tile
        if tile_prob.shape[-2:] != (tile_h, tile_w):
            tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
        # Match the accumulator's dtype/device (a no-op when they already agree).
        tile_prob = tile_prob.to(device=full_probs.device, dtype=full_probs.dtype)
        # Overlapping tiles are merged with a pixel-wise maximum.
        full_probs[top:bottom, left:right] = torch.maximum(
            full_probs[top:bottom, left:right], tile_prob
        )
    return full_probs
# ---------- Refinement: re-run inference only on tiles touching the edge band ----------
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
    """Re-segment the tiles that intersect ``band_coarse`` and merge them into a full-size map.

    Args:
        processor: Same interface as for ``coarse_tiles`` (``device``,
            ``set_image``, ``set_text_prompt`` returning ``"masks_logits"``).
        src: Open rasterio dataset that tiles are read from.
        prompt: Text prompt passed to the processor for every tile.
        band_coarse: 2-D boolean tensor marking pixels near coarse-mask
            boundaries. It may have a different resolution than ``src``:
            tile windows are mapped into it by proportional scaling. Tiles
            whose mapped region holds no band pixel are skipped.
        tile_size: Tile edge length in pixels, forwarded to ``tile_slices``.
        overlap: Tile overlap in pixels, forwarded to ``tile_slices``.

    Returns:
        float32 tensor of shape (src.height, src.width) on
        ``processor.device``. Each pixel is the maximum over the processed
        tiles covering it, and 0 where every covering tile was skipped or
        had no detection.
    """
    height, width = src.height, src.width
    full_probs = torch.zeros((height, width), dtype=torch.float32, device=processor.device)

    band_h, band_w = band_coarse.shape
    scale_y = band_h / float(height)
    scale_x = band_w / float(width)

    # Materialize the windows once so the progress-bar total always matches
    # what tile_slices actually yields (no separately maintained formula).
    tiles = list(tile_slices(height, width, tile_size, overlap))
    for top, left, bottom, right in tqdm(tiles, desc="精修分块", unit="块"):
        # Map the tile into band coordinates: floor the start, ceil the end,
        # and keep at least one row/column so the slice is never empty.
        c_top = int(top * scale_y)
        c_left = int(left * scale_x)
        c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
        c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
        if not band_coarse[c_top:c_bottom, c_left:c_right].any():
            continue  # no boundary pixels here; keep the coarse result

        tile_h, tile_w = bottom - top, right - left
        crop = _read_pil_window(src, Window(left, top, tile_w, tile_h))
        state = processor.set_image(crop)
        state = processor.set_text_prompt(prompt=prompt, state=state)
        tile_prob = combine_masks_logits(state["masks_logits"])
        if tile_prob is None:
            continue  # nothing detected in this tile
        if tile_prob.shape[-2:] != (tile_h, tile_w):
            tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
        # Match the accumulator's dtype/device (a no-op when they already agree).
        tile_prob = tile_prob.to(device=full_probs.device, dtype=full_probs.dtype)
        # Overlapping tiles are merged with a pixel-wise maximum.
        full_probs[top:bottom, left:right] = torch.maximum(
            full_probs[top:bottom, left:right], tile_prob
        )
    return full_probs
# ---------- Main script ----------
matplotlib.use("TkAgg")

# Parameters
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV1.tif"
prompt = "water body"
coarse_read_max_side = 768  # no longer used; kept so nothing that refers to it breaks
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64  # in pixels of the coarse probability map (full image resolution here)
tile_size = 2048  # refinement tile size
overlap = 256  # refinement tile overlap
coarse_tile_size = 6144  # coarse tile size (tunable)
coarse_overlap = 256  # coarse tile overlap (tunable)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")

# One model instance is shared by both processors; only the resolution differs.
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)

with rasterio.open(image_path) as src:
    nodata = src.nodata
    print(f"原始影像NoData值: {nodata}")

    print("开始粗分割(分块推理)...")
    coarse_prob = coarse_tiles(
        coarse_processor, src, prompt,
        tile_size=coarse_tile_size, overlap=coarse_overlap
    )
    coarse_mask = coarse_prob > coarse_threshold

    print("构建边缘带...")
    band = build_band(coarse_mask, band_radius)

    print("开始精修分割...")
    fine_probs = refine_tiles(
        fine_processor, src, prompt, band, tile_size, overlap
    )

    print("合并粗/细结果...")
    # NOTE(review): a pixel-wise max means the refinement pass can only add
    # foreground, never remove coarse false positives. If fine predictions
    # are meant to override coarse ones inside the band, select by `band`
    # instead -- confirm the intent.
    final_prob = torch.maximum(fine_probs, coarse_prob)
    final_mask = final_prob > final_threshold

    mask_np = final_mask.cpu().numpy().astype(np.uint8)

    # ========== Fill internal NoData holes (with dilation) ==========
    # NoData components of band 1 that do not touch the image border are
    # treated as holes inside the scene. They are forced to 1 together with
    # a 7-iteration, 8-connected dilation around them.
    if nodata is not None:
        band1 = src.read(1)
        if np.issubdtype(band1.dtype, np.floating):
            nodata_value = float(nodata)
            if np.isnan(nodata_value):
                # NaN never compares equal (np.isclose(x, nan) is False), so
                # a NaN nodata value must be matched with isnan.
                nodata_mask = np.isnan(band1)
            else:
                nodata_mask = np.isclose(band1, nodata_value)
        else:
            nodata_mask = (band1 == nodata)

        from scipy import ndimage

        labeled_mask, num_features = ndimage.label(nodata_mask)
        # Labels of components touching any image edge. Label 0 (valid data)
        # may appear here too; the `& nodata_mask` below discards it anyway.
        edge_labels = np.unique(np.concatenate([
            labeled_mask[0, :], labeled_mask[-1, :],
            labeled_mask[:, 0], labeled_mask[:, -1],
        ]))
        internal_nodata_mask = np.isin(labeled_mask, edge_labels, invert=True) & nodata_mask

        if internal_nodata_mask.any():
            struct = ndimage.generate_binary_structure(2, 2)  # 8-connectivity
            internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
            mask_np[internal_dilated] = 1
            print(f"内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
        else:
            print("无内部NoData区域")
    # ==================================================

    profile = src.profile.copy()
    # nodata=None: inheriting the source nodata would either flag every 0
    # (background) pixel as NoData (source nodata == 0) or make rasterio
    # reject a value outside the uint8 range (e.g. -9999).
    profile.update(count=1, dtype="uint8", nodata=None)
    # A source colour interpretation (e.g. YCbCr for JPEG-compressed RGB)
    # does not apply to a single-band mask.
    profile.pop("photometric", None)
    # Lossy codecs would corrupt the 0/1 mask values.
    if str(profile.get("compress", "")).lower() in ("jpeg", "webp"):
        profile["compress"] = "lzw"
    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(mask_np, 1)

print(f"分割完成,结果已保存至:{mask_output_path}")