Initial commit
This commit is contained in:
391
water_V4.py
Normal file
391
water_V4.py
Normal file
@ -0,0 +1,391 @@
|
||||
"""
|
||||
使用SAM3模型分割超大遥感影像中的水体(分区域处理,优化内存)
|
||||
|
||||
流程:
|
||||
1. 将影像划分为若干有重叠的子区域。
|
||||
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割。
|
||||
3. 合并所有子区域的掩码(重叠区域取最大值)。
|
||||
4. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
|
||||
5. 保存结果。
|
||||
|
||||
优点:避免在全尺寸概率图上进行GPU操作,显著降低显存占用。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import rasterio
|
||||
from rasterio.windows import Window
|
||||
from rasterio.io import MemoryFile
|
||||
from tqdm import tqdm
|
||||
from scipy import ndimage
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
|
||||
|
||||
# ---------- 工具函数(保持不变) ----------
|
||||
def binary_dilate(mask, radius):
|
||||
if radius <= 0:
|
||||
return mask
|
||||
kernel = 2 * radius + 1
|
||||
return F.max_pool2d(mask.float(), kernel_size=kernel, stride=1, padding=radius) > 0.5
|
||||
|
||||
def binary_erode(mask, radius):
|
||||
if radius <= 0:
|
||||
return mask
|
||||
return ~binary_dilate(~mask, radius)
|
||||
|
||||
def combine_masks_logits(masks_logits):
|
||||
if masks_logits.numel() == 0:
|
||||
return None
|
||||
probs = masks_logits.squeeze(1)
|
||||
if probs.dim() == 2:
|
||||
return probs
|
||||
return torch.amax(probs, dim=0)
|
||||
|
||||
def upsample_prob(prob, size):
|
||||
return F.interpolate(prob[None, None, ...], size=size, mode="bilinear", align_corners=False).squeeze(0).squeeze(0)
|
||||
|
||||
def build_band(mask, radius):
|
||||
"""mask: 2D tensor (CPU or GPU)"""
|
||||
mask_4d = mask[None, None, ...] # (1,1,H,W)
|
||||
dil = binary_dilate(mask_4d, radius)
|
||||
ero = binary_erode(mask_4d, radius)
|
||||
band = torch.logical_xor(dil, ero).squeeze(0).squeeze(0)
|
||||
return band
|
||||
|
||||
def tile_slices(height, width, tile_size, overlap):
|
||||
stride = max(tile_size - overlap, 1)
|
||||
for top in range(0, height, stride):
|
||||
for left in range(0, width, stride):
|
||||
bottom = min(top + tile_size, height)
|
||||
right = min(left + tile_size, width)
|
||||
yield top, left, bottom, right
|
||||
|
||||
def _to_uint8(arr):
|
||||
if arr.dtype == np.uint8:
|
||||
return arr
|
||||
arr = arr.astype(np.float32)
|
||||
vmin = np.percentile(arr, 2.0)
|
||||
vmax = np.percentile(arr, 98.0)
|
||||
if vmax <= vmin:
|
||||
return np.zeros_like(arr, dtype=np.uint8)
|
||||
arr = (arr - vmin) / (vmax - vmin)
|
||||
arr = np.clip(arr, 0.0, 1.0)
|
||||
return (arr * 255.0).astype(np.uint8)
|
||||
|
||||
def _bands_to_pil(bands):
|
||||
if bands.ndim != 3:
|
||||
raise ValueError("bands must be (C,H,W)")
|
||||
c, h, w = bands.shape
|
||||
if c == 1:
|
||||
rgb = np.repeat(bands, 3, axis=0)
|
||||
else:
|
||||
rgb = bands[:3]
|
||||
rgb = _to_uint8(rgb)
|
||||
rgb = np.transpose(rgb, (1, 2, 0))
|
||||
return Image.fromarray(rgb, mode="RGB")
|
||||
|
||||
def _read_pil_window(src, window):
|
||||
bands = src.read(window=window, boundless=True, fill_value=0)
|
||||
return _bands_to_pil(bands)
|
||||
|
||||
def _coarse_shape(height, width, max_side):
|
||||
scale = max_side / float(max(height, width))
|
||||
h = max(int(round(height * scale)), 1)
|
||||
w = max(int(round(width * scale)), 1)
|
||||
return h, w
|
||||
|
||||
# ---------- 粗分割分块推理 ----------
|
||||
def coarse_tiles(processor, src, prompt, tile_size, overlap):
|
||||
"""
|
||||
对数据集src进行分块粗分割,返回全尺寸概率图(GPU张量)
|
||||
src: rasterio数据集(可以是内存文件或原始文件)
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
device = processor.device
|
||||
full_probs = torch.zeros((height, width), dtype=torch.float32, device=device)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
# 内部进度条,leave=True 以便保留历史记录
|
||||
with tqdm(total=total_tiles, desc="粗分割分块", unit="块", leave=True) as pbar:
|
||||
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window)
|
||||
state = processor.set_image(crop)
|
||||
state = processor.set_text_prompt(prompt=prompt, state=state)
|
||||
tile_prob = combine_masks_logits(state["masks_logits"])
|
||||
if tile_prob is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
if tile_prob.shape[-2:] != (bottom - top, right - left):
|
||||
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
|
||||
|
||||
full_probs[top:bottom, left:right] = torch.maximum(
|
||||
full_probs[top:bottom, left:right], tile_prob
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
return full_probs
|
||||
|
||||
# ---------- 精修分块 ----------
|
||||
def refine_tiles(processor, src, prompt, band_coarse, tile_size, overlap):
|
||||
"""
|
||||
对边缘带进行精修分割
|
||||
band_coarse: 2D GPU张量,指示需要精修的区域
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
device = processor.device
|
||||
full_probs = torch.zeros((height, width), dtype=torch.float32, device=device)
|
||||
|
||||
band_h, band_w = band_coarse.shape
|
||||
scale_y = band_h / float(height)
|
||||
scale_x = band_w / float(width)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
# 内部进度条,leave=True 以便保留历史记录
|
||||
with tqdm(total=total_tiles, desc="精修分块", unit="块", leave=True) as pbar:
|
||||
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
|
||||
c_top = int(top * scale_y)
|
||||
c_left = int(left * scale_x)
|
||||
c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
|
||||
c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
|
||||
|
||||
if not band_coarse[c_top:c_bottom, c_left:c_right].any():
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window)
|
||||
state = processor.set_image(crop)
|
||||
state = processor.set_text_prompt(prompt=prompt, state=state)
|
||||
tile_prob = combine_masks_logits(state["masks_logits"])
|
||||
if tile_prob is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
if tile_prob.shape[-2:] != (bottom - top, right - left):
|
||||
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
|
||||
|
||||
full_probs[top:bottom, left:right] = torch.maximum(
|
||||
full_probs[top:bottom, left:right], tile_prob
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
return full_probs
|
||||
|
||||
# ---------- 后处理函数(面积过滤) ----------
|
||||
def filter_by_area(mask, min_area=None, keep_largest_only=False):
|
||||
"""
|
||||
对二值掩码进行连通域面积过滤。
|
||||
mask: np.ndarray, dtype=bool or uint8, shape (H,W)
|
||||
min_area: int, 最小面积阈值(像素),小于此值的连通域移除。若为None则不过滤。
|
||||
keep_largest_only: bool, 若True,只保留面积最大的连通域(同时受min_area约束)。
|
||||
返回过滤后的掩码 (uint8, 0/1)。
|
||||
"""
|
||||
labeled, num = ndimage.label(mask)
|
||||
if num == 0:
|
||||
return mask.astype(np.uint8)
|
||||
|
||||
sizes = ndimage.sum(mask, labeled, range(1, num+1))
|
||||
keep_labels = []
|
||||
|
||||
if keep_largest_only:
|
||||
max_idx = np.argmax(sizes)
|
||||
if min_area is None or sizes[max_idx] >= min_area:
|
||||
keep_labels = [max_idx + 1]
|
||||
else:
|
||||
for i, size in enumerate(sizes, start=1):
|
||||
if min_area is None or size >= min_area:
|
||||
keep_labels.append(i)
|
||||
|
||||
if not keep_labels:
|
||||
return np.zeros_like(mask, dtype=np.uint8)
|
||||
|
||||
keep_mask = np.isin(labeled, keep_labels)
|
||||
return keep_mask.astype(np.uint8)
|
||||
|
||||
# ---------- 划分子区域函数 ----------
|
||||
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
|
||||
"""
|
||||
将图像划分为若干有重叠的子区域,返回每个子区域的 (left, top, right, bottom)
|
||||
坐标(像素坐标,相对于原图)。
|
||||
"""
|
||||
regions = []
|
||||
tile_w = (width + num_splits_x - 1) // num_splits_x
|
||||
tile_h = (height + num_splits_y - 1) // num_splits_y
|
||||
|
||||
for i in range(num_splits_y):
|
||||
for j in range(num_splits_x):
|
||||
left = max(j * tile_w - overlap, 0)
|
||||
top = max(i * tile_h - overlap, 0)
|
||||
right = min((j + 1) * tile_w + overlap, width)
|
||||
bottom = min((i + 1) * tile_h + overlap, height)
|
||||
regions.append((left, top, right, bottom))
|
||||
return regions
|
||||
|
||||
# ---------- 主程序 ----------
|
||||
matplotlib.use("TkAgg")
|
||||
|
||||
# 参数设置
|
||||
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
|
||||
mask_output_path = r"E:\is2\dingshanhu\result_maskV1.tif"
|
||||
prompt = "water body"
|
||||
coarse_read_max_side = 768 # 不再使用,保留以防万一
|
||||
coarse_resolution = 1008
|
||||
fine_resolution = 1008
|
||||
coarse_threshold = 0.5
|
||||
final_threshold = 0.5
|
||||
band_radius = 64
|
||||
tile_size = 2048 # 精修分块大小
|
||||
overlap = 256 # 精修重叠
|
||||
coarse_tile_size = 4096 # 粗分割分块大小
|
||||
coarse_overlap = 256 # 粗分割重叠
|
||||
|
||||
# ========== 分区域参数 ==========
|
||||
num_splits_y = 2 # 纵向切分数
|
||||
num_splits_x = 3 # 横向切分数(共 2x3=6 份)
|
||||
region_overlap = 256 # 子区域之间的重叠像素数
|
||||
# ===============================
|
||||
|
||||
# ========== 后处理参数 ==========
|
||||
min_area = 1000 # 最小面积阈值(像素),小于此值的连通域将被移除;设为0或None表示不进行面积过滤
|
||||
keep_largest_only = False # 是否只保留最大的连通域(True/False)
|
||||
# ===============================
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 加载模型(所有子区域共享同一个模型)
|
||||
model = build_sam3_image_model().to(device).eval()
|
||||
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
|
||||
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
|
||||
|
||||
# 打开原始影像
|
||||
with rasterio.open(image_path) as src:
|
||||
nodata = src.nodata
|
||||
print(f"原始影像NoData值: {nodata}")
|
||||
height, width = src.height, src.width
|
||||
|
||||
# 划分区域
|
||||
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
|
||||
print(f"共划分 {len(regions)} 个子区域")
|
||||
|
||||
# 创建全尺寸掩码数组(CPU内存)
|
||||
full_mask = np.zeros((height, width), dtype=np.uint8)
|
||||
|
||||
# 子区域总进度条
|
||||
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
|
||||
for idx, (left, top, right, bottom) in enumerate(regions):
|
||||
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
|
||||
region_w = right - left
|
||||
region_h = bottom - top
|
||||
|
||||
# 读取子区域数据
|
||||
window = Window(left, top, region_w, region_h)
|
||||
sub_bands = src.read(window=window) # shape: (bands, h, w)
|
||||
|
||||
# 构建子区域元数据
|
||||
sub_profile = src.profile.copy()
|
||||
if not src.is_tiled:
|
||||
sub_profile.pop('blockxsize', None)
|
||||
sub_profile.pop('blockysize', None)
|
||||
sub_profile['tiled'] = False
|
||||
else:
|
||||
sub_profile['tiled'] = True
|
||||
sub_profile.update({
|
||||
'height': region_h,
|
||||
'width': region_w,
|
||||
'transform': rasterio.windows.transform(window, src.transform)
|
||||
})
|
||||
|
||||
# 将子区域数据包装成内存中的 rasterio 数据集
|
||||
with MemoryFile() as memfile:
|
||||
with memfile.open(**sub_profile) as sub_dst:
|
||||
sub_dst.write(sub_bands)
|
||||
with memfile.open() as sub_src:
|
||||
# ---------- 粗分割 ----------
|
||||
coarse_prob = coarse_tiles(
|
||||
coarse_processor, sub_src, prompt,
|
||||
tile_size=coarse_tile_size, overlap=coarse_overlap
|
||||
)
|
||||
coarse_mask = coarse_prob > coarse_threshold
|
||||
|
||||
# ---------- 构建边缘带(直接在 GPU 上进行) ----------
|
||||
band = build_band(coarse_mask, band_radius)
|
||||
|
||||
# ---------- 精修分割 ----------
|
||||
fine_probs = refine_tiles(
|
||||
fine_processor, sub_src, prompt, band,
|
||||
tile_size=tile_size, overlap=overlap
|
||||
)
|
||||
|
||||
# ---------- 合并粗/细结果 ----------
|
||||
final_prob = torch.maximum(fine_probs, coarse_prob)
|
||||
final_mask = final_prob > final_threshold
|
||||
|
||||
# ---------- 获取子区域掩码(numpy) ----------
|
||||
sub_mask_np = final_mask.cpu().numpy().astype(np.uint8)
|
||||
|
||||
# 合并到全尺寸掩码(重叠区域取最大值)
|
||||
full_mask[top:bottom, left:right] = np.maximum(
|
||||
full_mask[top:bottom, left:right], sub_mask_np
|
||||
)
|
||||
|
||||
# 更新子区域进度条
|
||||
region_pbar.update(1)
|
||||
|
||||
# ========== 后处理 ==========
|
||||
print("\n后处理:填充内部NoData空洞...")
|
||||
if nodata is not None:
|
||||
band1 = src.read(1)
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask = np.isclose(band1, float(nodata))
|
||||
else:
|
||||
nodata_mask = (band1 == nodata)
|
||||
|
||||
labeled_mask, num_features = ndimage.label(nodata_mask)
|
||||
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
|
||||
boundary_mask[0, :] = True
|
||||
boundary_mask[-1, :] = True
|
||||
boundary_mask[:, 0] = True
|
||||
boundary_mask[:, -1] = True
|
||||
|
||||
boundary_labels = set(labeled_mask[boundary_mask])
|
||||
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
|
||||
|
||||
if internal_nodata_mask.any():
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
|
||||
full_mask[internal_dilated] = 1
|
||||
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
|
||||
else:
|
||||
print(" 无内部NoData区域")
|
||||
|
||||
print("后处理:面积过滤...")
|
||||
if min_area is not None or keep_largest_only:
|
||||
original_count = np.sum(full_mask)
|
||||
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
|
||||
filtered_count = np.sum(full_mask)
|
||||
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
|
||||
|
||||
# ========== 保存结果 ==========
|
||||
profile = src.profile.copy()
|
||||
profile.update(count=1, dtype="uint8", compress='lzw')
|
||||
with rasterio.open(mask_output_path, "w", **profile) as dst:
|
||||
dst.write(full_mask, 1)
|
||||
|
||||
print(f"分割完成,结果已保存至:{mask_output_path}")
|
||||
Reference in New Issue
Block a user