Initial commit

This commit is contained in:
2026-03-09 17:23:53 +08:00
parent 17e0db880e
commit 1422fb5026
515 changed files with 124791 additions and 0 deletions

394
water_V5.py Normal file
View File

@ -0,0 +1,394 @@
"""
使用SAM3模型分割超大遥感影像中的水体分区域处理优化内存
流程:
1. 将影像划分为若干有重叠的子区域。
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割。
3. 合并所有子区域的掩码(重叠区域取最大值)。
4. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
5. 保存结果。
优点避免在全尺寸概率图上进行GPU操作显著降低显存占用。
"""
import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# ---------- 工具函数(保持不变) ----------
def binary_dilate(mask, radius):
    """Dilate a binary mask with a square window of side ``2 * radius + 1``.

    The mask is cast to float and passed through a stride-1 max pooling, so a
    pixel becomes True when any pixel inside its window is set. The input is
    returned untouched when ``radius`` is zero or negative.
    """
    if radius <= 0:
        return mask
    window = radius * 2 + 1
    pooled = F.max_pool2d(
        mask.float(),
        kernel_size=window,
        stride=1,
        padding=radius,
    )
    return pooled > 0.5
def binary_erode(mask, radius):
    """Erode a binary mask by dilating its complement and inverting the result.

    The input is returned untouched when ``radius`` is zero or negative.
    """
    if radius <= 0:
        return mask
    background = ~mask
    grown_background = binary_dilate(background, radius)
    return ~grown_background
def combine_masks_logits(masks_logits):
    """Collapse a stack of per-instance mask maps into a single 2D map.

    Returns None for an empty tensor. Otherwise dimension 1 is squeezed; a 2D
    result is returned directly, anything else is reduced with an element-wise
    maximum over dimension 0.
    """
    if masks_logits.numel() == 0:
        return None
    squeezed = masks_logits.squeeze(1)
    return squeezed if squeezed.dim() == 2 else torch.amax(squeezed, dim=0)
def upsample_prob(prob, size):
    """Bilinearly resize a 2D map to ``size`` (height, width)."""
    batched = prob.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W) as F.interpolate expects
    resized = F.interpolate(
        batched, size=size, mode="bilinear", align_corners=False
    )
    return resized[0, 0]
def build_band(mask, radius):
    """Return the band around the mask boundary as a 2D boolean tensor.

    The band is the set of pixels where the dilated and the eroded mask
    disagree. ``mask`` is a 2D tensor (CPU or GPU).
    """
    batched = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    grown = binary_dilate(batched, radius)
    shrunk = binary_erode(batched, radius)
    return torch.logical_xor(grown, shrunk)[0, 0]
def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` tile bounds in row-major order.

    Tile origins advance by ``tile_size - overlap`` (at least 1); the bottom
    and right edges are clipped to the image extent.
    """
    step = max(tile_size - overlap, 1)
    for y0 in range(0, height, step):
        y1 = min(y0 + tile_size, height)
        for x0 in range(0, width, step):
            yield y0, x0, y1, min(x0 + tile_size, width)
def compute_stretch_params(bands, sample_max_pixels=2_000_000):
    """Compute per-band 2nd/98th percentile stretch limits for a (C, H, W) array.

    The array is subsampled with a regular stride so that roughly at most
    ``sample_max_pixels`` pixels per band enter the percentile computation.

    Returns:
        (vmin, vmax): two arrays with one entry per band; ``vmax`` is forced
        to be at least ``vmin + 1e-6``.
    """
    rows, cols = bands.shape[1:]
    ratio = (rows * cols) / float(sample_max_pixels)
    stride = max(int(np.ceil(np.sqrt(ratio))), 1)
    subset = bands[:, ::stride, ::stride].astype(np.float32)
    low = np.percentile(subset, 2.0, axis=(1, 2))
    high = np.percentile(subset, 98.0, axis=(1, 2))
    high = np.maximum(high, low + 1e-6)
    return low, high
def _to_uint8(arr, vmin=None, vmax=None):
    """Linearly stretch an array into the 0-255 range and cast to uint8.

    uint8 input is returned as-is. When either limit is missing, both are
    taken from the array's 2nd and 98th percentiles. A non-positive range
    (``vmax <= vmin``) yields an all-zero array.
    """
    if arr.dtype == np.uint8:
        return arr
    data = arr.astype(np.float32)
    if vmin is None or vmax is None:
        vmin, vmax = np.percentile(data, 2.0), np.percentile(data, 98.0)
    if vmax <= vmin:
        return np.zeros_like(data, dtype=np.uint8)
    span = vmax - vmin
    normalized = np.clip((data - vmin) / span, 0.0, 1.0)
    return (normalized * 255.0).astype(np.uint8)
def _bands_to_pil(bands, stretch_params=None):
    """Convert a (C, H, W) band array into an 8-bit RGB PIL image.

    Args:
        bands: 3D array laid out as (C, H, W).
        stretch_params: optional ``(vmin, vmax)`` pair of per-band arrays.
            When given, each output channel is stretched with its own limits;
            when fewer than three limits are available, the first band's
            limits are used for all channels. When omitted, ``_to_uint8``
            derives limits from the data itself.

    Returns:
        PIL.Image.Image in mode "RGB".

    Raises:
        ValueError: if ``bands`` is not 3D or contains no band.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    band_count = bands.shape[0]
    if band_count == 0:
        raise ValueError("bands must contain at least one band")
    if band_count < 3:
        # One or two bands cannot form an RGB triple. Render the first band as
        # grayscale, which matches how the stretch limits are replicated below.
        # (Previously a 2-band input produced a 2-channel array and crashed.)
        rgb = np.repeat(bands[:1], 3, axis=0)
    else:
        rgb = bands[:3]
    if stretch_params is not None:
        vmin, vmax = stretch_params
        vmin3 = vmin[:3] if vmin.shape[0] >= 3 else np.repeat(vmin[:1], 3)
        vmax3 = vmax[:3] if vmax.shape[0] >= 3 else np.repeat(vmax[:1], 3)
        rgb_u8 = np.empty_like(rgb, dtype=np.uint8)
        for i in range(3):
            rgb_u8[i] = _to_uint8(rgb[i], vmin=vmin3[i], vmax=vmax3[i])
        rgb = rgb_u8
    else:
        rgb = _to_uint8(rgb)
    # PIL expects channel-last (H, W, 3).
    rgb = np.transpose(rgb, (1, 2, 0))
    return Image.fromarray(rgb, mode="RGB")
def _read_pil_window(src, window, stretch_params=None):
    """Read ``window`` from the dataset ``src`` and return it as an RGB PIL image.

    The read is boundless, with pixels outside the dataset filled with 0.
    """
    pixels = src.read(window=window, boundless=True, fill_value=0)
    image = _bands_to_pil(pixels, stretch_params=stretch_params)
    return image
def infer_prompt_prob(processor, image, prompt):
    """Run text-prompted inference on one image and return a combined 2D map.

    Args:
        processor: object exposing ``device``, ``set_image(image)`` and
            ``set_text_prompt(prompt=..., state=...)``; the returned state must
            contain a ``"masks_logits"`` tensor.
        image: image passed to ``processor.set_image``.
        prompt: text prompt.

    Returns:
        The result of ``combine_masks_logits`` on the predicted masks, i.e. a
        2D tensor, or None when the prediction is empty.
    """
    from contextlib import nullcontext

    # Normalise the device before testing it: ``processor.device`` may be a
    # string ("cuda", "cuda:0", "cpu") or a torch.device. Comparing it to the
    # string "cpu" would enable CUDA autocast for a torch.device("cpu") object
    # or for a non-CUDA accelerator.
    on_cuda = torch.device(processor.device).type == "cuda"
    if on_cuda:
        amp_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        amp_ctx = nullcontext()
    with amp_ctx:
        state = processor.set_image(image)
        state = processor.set_text_prompt(prompt=prompt, state=state)
    return combine_masks_logits(state["masks_logits"])
def make_overview_bands(sub_bands, max_side):
    """Decimate a (C, H, W) array so its longer side is at most about ``max_side``.

    Returns:
        (overview, step): the strided view ``sub_bands[:, ::step, ::step]`` and
        the integer decimation step that was used (at least 1).
    """
    rows, cols = sub_bands.shape[1:]
    longest = max(rows, cols)
    factor = max(int(np.ceil(longest / float(max_side))), 1)
    thinned = sub_bands[:, ::factor, ::factor]
    return thinned, factor
# ---------- 精修分块 ----------
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None):
    """Run tiled inference only where the coarse edge band requests refinement.

    Args:
        processor: processor handed to ``infer_prompt_prob``.
        src: open raster dataset providing ``height``, ``width`` and windowed reads.
        prompt: text prompt.
        band_coarse_cpu: 2D boolean numpy array (low resolution) marking the
            area that needs refinement.
        tile_size: tile side length in full-resolution pixels.
        overlap: overlap between neighbouring tiles in pixels.
        stretch_params: optional stretch limits forwarded to the tile reader.

    Returns:
        float16 array of shape (height, width); tiles that were skipped stay 0,
        overlapping tiles are merged with an element-wise maximum.
    """
    full_h, full_w = src.height, src.width
    merged = np.zeros((full_h, full_w), dtype=np.float16)

    # Ratios that map full-resolution pixel coordinates onto the coarse band grid.
    coarse_h, coarse_w = band_coarse_cpu.shape
    ratio_y = coarse_h / float(full_h)
    ratio_x = coarse_w / float(full_w)

    # Number of tiles tile_slices() will yield, used as the progress-bar total.
    step = max(tile_size - overlap, 1)
    tile_count = ((full_h + step - 1) // step) * ((full_w + step - 1) // step)

    # leave=True keeps the finished bar visible as a history record.
    with tqdm(total=tile_count, desc="精修分块", unit="", leave=True) as progress:
        for y0, x0, y1, x1 in tile_slices(full_h, full_w, tile_size, overlap):
            tile_h, tile_w = y1 - y0, x1 - x0
            by0 = int(y0 * ratio_y)
            bx0 = int(x0 * ratio_x)
            # Always look at a coarse patch of at least one pixel.
            by1 = max(int(np.ceil(y1 * ratio_y)), by0 + 1)
            bx1 = max(int(np.ceil(x1 * ratio_x)), bx0 + 1)
            if band_coarse_cpu[by0:by1, bx0:bx1].any():
                crop = _read_pil_window(
                    src, Window(x0, y0, tile_w, tile_h), stretch_params=stretch_params
                )
                tile_prob = infer_prompt_prob(processor, crop, prompt)
                if tile_prob is not None:
                    if tile_prob.shape[-2:] != (tile_h, tile_w):
                        tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
                    tile_np = tile_prob.detach().float().cpu().numpy().astype(np.float16)
                    merged[y0:y1, x0:x1] = np.maximum(merged[y0:y1, x0:x1], tile_np)
            progress.update(1)
    return merged
# ---------- 后处理函数(面积过滤) ----------
def filter_by_area(mask, min_area=None, keep_largest_only=False):
    """Filter the connected components of a binary mask by pixel area.

    Args:
        mask: 2D np.ndarray, bool or integer; every non-zero pixel is foreground.
        min_area: minimum component area in pixels; smaller components are
            removed. None disables the area threshold.
        keep_largest_only: if True, keep only the largest component (and only
            if it also satisfies ``min_area``).

    Returns:
        uint8 array with the same shape as ``mask``: 1 on kept components and
        0 elsewhere. A mask without any foreground is returned cast to uint8.
    """
    labeled, num = ndimage.label(mask)
    if num == 0:
        return mask.astype(np.uint8)

    # Count pixels per label. Summing the mask values instead would inflate
    # the area for masks that are not strictly 0/1 (e.g. 0/255).
    sizes = np.bincount(labeled.ravel(), minlength=num + 1)
    sizes[0] = 0  # label 0 is background and must never be selected

    # keep[label] says whether that component survives; index 0 stays False.
    keep = np.zeros(num + 1, dtype=bool)
    if keep_largest_only:
        largest = int(np.argmax(sizes))
        if min_area is None or sizes[largest] >= min_area:
            keep[largest] = True
    elif min_area is None:
        keep[1:] = True
    else:
        keep[1:] = sizes[1:] >= min_area

    # A single table lookup maps every pixel's label to its keep flag.
    return keep[labeled].astype(np.uint8)
# ---------- 划分子区域函数 ----------
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
    """Split an image into overlapping rectangular sub-regions.

    The image is cut into a ``num_splits_y`` x ``num_splits_x`` grid and each
    cell is grown by ``overlap`` pixels on every side, clipped to the image.

    Args:
        width: image width in pixels.
        height: image height in pixels.
        num_splits_y: number of grid rows (>= 1).
        num_splits_x: number of grid columns (>= 1).
        overlap: number of pixels each cell is extended by on every side.

    Returns:
        List of ``(left, top, right, bottom)`` pixel coordinates relative to
        the full image, in row-major order. Cells that fall entirely outside
        the image (more splits than pixels) and exact duplicates (overlap
        larger than the cells) are omitted, so every returned region has a
        positive size and is processed only once.

    Raises:
        ValueError: if ``num_splits_x`` or ``num_splits_y`` is smaller than 1.
    """
    if num_splits_x < 1 or num_splits_y < 1:
        raise ValueError("num_splits_x and num_splits_y must be >= 1")

    # Ceil division so the grid always covers the whole image.
    tile_w = (width + num_splits_x - 1) // num_splits_x
    tile_h = (height + num_splits_y - 1) // num_splits_y

    regions = []
    seen = set()
    for i in range(num_splits_y):
        if i * tile_h >= height:
            continue  # this row's un-overlapped cell lies outside the image
        top = max(i * tile_h - overlap, 0)
        bottom = min((i + 1) * tile_h + overlap, height)
        for j in range(num_splits_x):
            if j * tile_w >= width:
                continue  # this column's un-overlapped cell lies outside the image
            left = max(j * tile_w - overlap, 0)
            right = min((j + 1) * tile_w + overlap, width)
            region = (left, top, right, bottom)
            if right <= left or bottom <= top or region in seen:
                continue
            seen.add(region)
            regions.append(region)
    return regions
# ---------- 主程序 ----------
matplotlib.use("TkAgg")
# Run parameters
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
prompt = "water body"
coarse_read_max_side = 1200
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 1536
overlap = 128
# ========== Region-splitting parameters ==========
num_splits_y = 2  # number of splits along the vertical axis
num_splits_x = 3  # number of splits along the horizontal axis (2 x 3 = 6 regions)
region_overlap = 256  # overlap in pixels between neighbouring sub-regions
# =================================================
# ========== Post-processing parameters ==========
# Minimum area in pixels; smaller connected components are removed.
# Set to 0 or None to disable area filtering.
min_area = 1000
keep_largest_only = False  # keep only the largest connected component (True/False)
# ================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
# Load the model once; every sub-region shares it.
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
# Open the source image
with rasterio.open(image_path) as src:
    nodata = src.nodata
    print(f"原始影像NoData值: {nodata}")
    height, width = src.height, src.width
    # Split into sub-regions
    regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
    print(f"共划分 {len(regions)} 个子区域")
    # Full-size mask, kept in CPU memory
    full_mask = np.zeros((height, width), dtype=np.uint8)
    # Overall progress bar over sub-regions
    with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
        for idx, (left, top, right, bottom) in enumerate(regions):
            print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
            region_w = right - left
            region_h = bottom - top
            # Read the sub-region pixels
            window = Window(left, top, region_w, region_h)
            sub_bands = src.read(window=window)  # shape: (bands, h, w)
            # Build the sub-region metadata
            sub_profile = src.profile.copy()
            if not src.is_tiled:
                sub_profile.pop('blockxsize', None)
                sub_profile.pop('blockysize', None)
                sub_profile['tiled'] = False
            else:
                sub_profile['tiled'] = True
            sub_profile.update({
                'height': region_h,
                'width': region_w,
                'transform': rasterio.windows.transform(window, src.transform)
            })
            # Wrap the sub-region in an in-memory rasterio dataset so that
            # refine_tiles() can do windowed reads on it.
            with MemoryFile() as memfile:
                with memfile.open(**sub_profile) as sub_dst:
                    sub_dst.write(sub_bands)
                with memfile.open() as sub_src:
                    stretch_params = compute_stretch_params(sub_bands)
                    # Coarse pass on a decimated overview of the sub-region
                    overview_bands, overview_step = make_overview_bands(
                        sub_bands, max_side=coarse_read_max_side
                    )
                    overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
                    overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
                    if overview_prob is None:
                        overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16)
                    else:
                        overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
                    overview_mask_cpu = overview_prob_cpu > coarse_threshold
                    # Edge band around the coarse mask, radius scaled to overview resolution
                    scale = overview_img.height / float(region_h)
                    band_radius_small = max(int(round(band_radius * scale)), 1)
                    band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
                    # Fine pass, restricted to tiles that touch the edge band
                    fine_probs_cpu = refine_tiles(
                        fine_processor,
                        sub_src,
                        prompt,
                        band_small,
                        tile_size=tile_size,
                        overlap=overlap,
                        stretch_params=stretch_params,
                    )
                    fine_mask = fine_probs_cpu > final_threshold
                    # Bring the coarse mask up to sub-region resolution
                    coarse_mask_full = np.array(
                        Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize(
                            (region_w, region_h), resample=Image.NEAREST
                        )
                    ).astype(bool)
                    sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8)
            # Merge into the full-size mask (overlapping areas take the maximum)
            full_mask[top:bottom, left:right] = np.maximum(
                full_mask[top:bottom, left:right], sub_mask_np
            )
            # Advance the sub-region progress bar
            region_pbar.update(1)
    # ========== Post-processing ==========
    print("\n后处理填充内部NoData空洞...")
    if nodata is not None:
        band1 = src.read(1)
        if np.issubdtype(band1.dtype, np.floating):
            # equal_nan=True so that a NaN nodata value matches NaN pixels;
            # without it np.isclose never matches NaN and no hole is found.
            nodata_mask = np.isclose(band1, float(nodata), equal_nan=True)
        else:
            nodata_mask = (band1 == nodata)
        labeled_mask, num_features = ndimage.label(nodata_mask)
        # NoData components touching the image border are exterior, not holes.
        boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
        boundary_mask[0, :] = True
        boundary_mask[-1, :] = True
        boundary_mask[:, 0] = True
        boundary_mask[:, -1] = True
        boundary_labels = set(labeled_mask[boundary_mask])
        internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
        if internal_nodata_mask.any():
            # Grow the holes (8-connectivity, 7 iterations) and mark them as water.
            struct = ndimage.generate_binary_structure(2, 2)
            internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
            full_mask[internal_dilated] = 1
            print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
        else:
            print(" 无内部NoData区域")
    print("后处理:面积过滤...")
    if min_area is not None or keep_largest_only:
        original_count = np.sum(full_mask)
        full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
        filtered_count = np.sum(full_mask)
        print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
    # ========== Save the result ==========
    profile = src.profile.copy()
    # Drop the source nodata value: it describes the input pixel type and may
    # be out of range for uint8 (e.g. -9999 or NaN), or equal to 0 and thereby
    # flag every non-water pixel of the mask as nodata.
    profile.update(count=1, dtype="uint8", compress='lzw', nodata=None)
    with rasterio.open(mask_output_path, "w", **profile) as dst:
        dst.write(full_mask, 1)
    print(f"分割完成,结果已保存至:{mask_output_path}")