This commit is contained in:
2026-03-10 17:29:24 +08:00
parent 64692d8382
commit b8b6c6227d
4 changed files with 993 additions and 147 deletions

View File

@ -1,14 +1,17 @@
"""
使用SAM3模型分割超大遥感影像中的水体分区域处理,优化内存
使用SAM3模型分割超大遥感影像中的水体三层分割策略
流程:
1. 将影像划分为若干有重叠的子区域。
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割
3. 合并所有子区域的掩码(重叠区域取最大值)。
4. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域
5. 保存结果
1. 将影像划分为若干有重叠的子区域(第一层)
2. 对每个子区域进行中分辨率粗分割以4096为块大小滑动窗口推理得到粗掩码第二层
3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。
4. 合并所有子区域的掩码(重叠区域取最大值)
5. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域
6. 保存结果。
优点:避免在全尺寸概率图上进行GPU操作显著降低显存占用。
优点:
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
- 精修只处理边缘带,节省计算资源。
"""
import torch
@ -22,6 +25,7 @@ from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage
import math
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000):
return vmin, vmax
def make_overview_bands(bands, max_side):
    """Build a strided low-resolution overview of a band stack.

    Args:
        bands: Array whose shape unpacks as ``(band, height, width)``.
        max_side: Upper bound, in pixels, for the longer spatial side of the
            overview. Must be positive.

    Returns:
        tuple: ``(overview, step)`` where ``overview`` is
        ``bands[:, ::step, ::step]`` and ``step`` is the integer stride used
        (always >= 1).

    Raises:
        ValueError: If ``max_side`` is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side!r}")

    _, h, w = bands.shape
    # Smallest integer stride that brings the longer side down to max_side;
    # clamp to 1 so small (or empty) inputs are returned at full resolution.
    step = max(int(math.ceil(max(h, w) / float(max_side))), 1)
    return bands[:, ::step, ::step], step
def _to_uint8(arr, vmin=None, vmax=None):
if arr.dtype == np.uint8:
return arr
@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None):
return Image.fromarray(rgb, mode="RGB")
def _read_pil_window(src, window, stretch_params=None):
    """Read one window from a raster dataset and convert it to a PIL image.

    Args:
        src: Dataset object exposing ``read(window=...)``.
        window: Window to read; passed straight through to ``src.read``.
        stretch_params: Forwarded unchanged to ``_bands_to_pil``.

    Returns:
        The value returned by ``_bands_to_pil`` for the bands that were read.
    """
    # Read the window exactly once. A preceding boundless, zero-filled read of
    # the same window was overwritten immediately and only doubled the I/O.
    # NOTE(review): this is a plain (non-boundless) read, so the window is
    # assumed to lie inside the dataset -- confirm against the tile generator.
    bands = src.read(window=window)
    return _bands_to_pil(bands, stretch_params=stretch_params)
def infer_prompt_prob(processor, image, prompt):
@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt):
prob = combine_masks_logits(state["masks_logits"])
return prob
def make_overview_bands(sub_bands, max_side):
    """Build a strided low-resolution overview of a sub-region band stack.

    Args:
        sub_bands: Array whose shape unpacks as ``(band, height, width)``.
        max_side: Upper bound, in pixels, for the longer spatial side of the
            overview. Must be positive.

    Returns:
        tuple: ``(overview, step)`` where ``overview`` is
        ``sub_bands[:, ::step, ::step]`` and ``step`` is the integer stride
        used (always >= 1).

    Raises:
        ValueError: If ``max_side`` is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side!r}")

    _, h, w = sub_bands.shape
    # Smallest integer stride that brings the longer side down to max_side;
    # clamp to 1 so small (or empty) inputs are returned at full resolution.
    step = max(int(np.ceil(max(h, w) / float(max_side))), 1)
    return sub_bands[:, ::step, ::step], step
# ---------- 新增:中分辨率粗分割(分块推理) ----------
def _downsample_any(mask, factor):
if factor <= 1:
return mask
h, w = mask.shape
pad_h = (-h) % factor
pad_w = (-w) % factor
if pad_h or pad_w:
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
h, w = mask.shape
h2 = h // factor
w2 = w // factor
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
# ---------- 精修分块 ----------
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None):
def _build_band_ndimage(mask, radius):
if radius <= 0:
return np.zeros_like(mask, dtype=bool)
struct = ndimage.generate_binary_structure(2, 2)
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
return np.logical_xor(dil, ero)
def coarse_tile_segmentation(
    processor,
    src,
    prompt,
    tile_size,
    overlap,
    downsample_factor=4,
    stretch_params=None,
    nodata_mask=None,
    use_tqdm=True,
):
    """Tile-wise coarse segmentation of a sub-region into a low-res probability map.

    The sub-region is walked tile by tile; each tile is run through the model
    and its probability map is stored at ``1 / downsample_factor`` scale.
    Overlapping tiles are merged with an element-wise maximum.

    Args:
        processor: Processor passed to ``infer_prompt_prob``.
        src: Open raster dataset for the sub-region (``height``, ``width`` and
            windowed reads are used).
        prompt: Text prompt passed to ``infer_prompt_prob``.
        tile_size: Tile edge length in source pixels.
        overlap: Overlap between neighbouring tiles in source pixels.
        downsample_factor: Ratio between source and output resolution.
        stretch_params: Forwarded to ``_read_pil_window``.
        nodata_mask: Optional boolean NoData mask at source resolution; tiles
            that are entirely NoData are skipped and NoData cells are zeroed
            in the result.
        use_tqdm: Show a tqdm progress bar over the tiles.

    Returns:
        ``float16`` array of shape
        ``(ceil(height / downsample_factor), ceil(width / downsample_factor))``.
    """
    height, width = src.height, src.width
    height_lr = int(math.ceil(height / float(downsample_factor)))
    width_lr = int(math.ceil(width / float(downsample_factor)))
    full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)

    # Tile count is only used as the progress-bar total.
    # NOTE(review): assumes tile_slices advances by this same stride -- confirm.
    stride = max(tile_size - overlap, 1)
    num_tiles_y = (height + stride - 1) // stride
    num_tiles_x = (width + stride - 1) // stride
    total_tiles = num_tiles_y * num_tiles_x

    iterator = tile_slices(height, width, tile_size, overlap)
    if use_tqdm:
        iterator = tqdm(
            iterator,
            total=total_tiles,
            desc="粗分割分块",
            unit="",
            leave=True,
        )

    for top, left, bottom, right in iterator:
        # Skip tiles that contain nothing but NoData (saves an inference call).
        if nodata_mask is not None:
            block_nodata = nodata_mask[top:bottom, left:right]
            if block_nodata.all():
                continue

        window = Window(left, top, right - left, bottom - top)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue

        # Bring the tile prediction to its low-resolution footprint.
        h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
        w_lr = int(math.ceil((right - left) / float(downsample_factor)))
        if tile_prob.shape[-2:] != (h_lr, w_lr):
            tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
        tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()

        # Tile origin is floored to the low-res grid; origins that are not a
        # multiple of downsample_factor are therefore placed with sub-cell
        # (< 1 low-res pixel) offset.
        top_lr = top // downsample_factor
        left_lr = left // downsample_factor
        bottom_lr = top_lr + h_lr
        right_lr = left_lr + w_lr
        full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum(
            full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu
        )

    # Any low-res cell that touches NoData is forced to probability 0.
    if nodata_mask is not None:
        nodata_lr = _downsample_any(nodata_mask, downsample_factor)
        full_probs_lr[nodata_lr] = 0.0

    return full_probs_lr
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
def refine_tiles(
    processor,
    src,
    prompt,
    band_lr,
    downsample_factor,
    tile_size,
    overlap,
    stretch_params=None,
    use_tqdm=True,
):
    """Run fine inference only on tiles that intersect the low-res edge band.

    Args:
        processor: Processor passed to ``infer_prompt_prob``.
        src: Open raster dataset for the sub-region (``height``, ``width`` and
            windowed reads are used).
        prompt: Text prompt passed to ``infer_prompt_prob``.
        band_lr: 2-D boolean NumPy array at ``1 / downsample_factor`` scale
            marking the area that needs refinement.
        downsample_factor: Ratio between source resolution and ``band_lr``.
        tile_size: Tile edge length in source pixels.
        overlap: Overlap between neighbouring tiles in source pixels.
        stretch_params: Forwarded to ``_read_pil_window``.
        use_tqdm: Show a tqdm progress bar over the tiles.

    Returns:
        ``float16`` array of shape ``(height, width)``; tiles that do not touch
        the band (or yield no prediction) stay 0, overlapping tiles are merged
        with an element-wise maximum.
    """
    height, width = src.height, src.width
    full_probs = np.zeros((height, width), dtype=np.float16)

    # Tile count is only used as the progress-bar total.
    # NOTE(review): assumes tile_slices advances by this same stride -- confirm.
    stride = max(tile_size - overlap, 1)
    num_tiles_y = (height + stride - 1) // stride
    num_tiles_x = (width + stride - 1) // stride
    total_tiles = num_tiles_y * num_tiles_x

    iterator = tile_slices(height, width, tile_size, overlap)
    if use_tqdm:
        iterator = tqdm(
            iterator,
            total=total_tiles,
            desc="精修分块",
            unit="",
            leave=True,
        )

    for top, left, bottom, right in iterator:
        # Tile footprint on the low-res grid: floor the origin, ceil the far
        # edge, and keep at least one cell in each direction.
        c_top = top // downsample_factor
        c_left = left // downsample_factor
        c_bottom = int(math.ceil(bottom / float(downsample_factor)))
        c_right = int(math.ceil(right / float(downsample_factor)))
        c_bottom = max(c_bottom, c_top + 1)
        c_right = max(c_right, c_left + 1)

        # Tiles that do not touch the edge band need no refinement.
        if not band_lr[c_top:c_bottom, c_left:c_right].any():
            continue

        window = Window(left, top, right - left, bottom - top)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue

        if tile_prob.shape[-2:] != (bottom - top, right - left):
            tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))

        tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
        full_probs[top:bottom, left:right] = np.maximum(
            full_probs[top:bottom, left:right], tile_prob_cpu
        )

    return full_probs
@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25
regions.append((left, top, right, bottom))
return regions
# ---------- 主程序 ----------
matplotlib.use("TkAgg")
# Default settings for run_segmentation(); entries of the ``config`` argument
# override individual keys.
DEFAULT_CONFIG = {
    # --- input / output ---
    "image_path": r"E:\is2\guidingsahn\result.tif",
    "mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
    "prompt": "water body",  # text prompt sent to the model
    # --- whole-region overview pass ---
    "overview_max_side": 1400,  # max side passed to make_overview_bands
    "overview_threshold": 0.4,  # probability cut-off for the overview mask
    # --- processor resolutions and thresholds ---
    "coarse_resolution": 1008,  # resolution of the coarse Sam3Processor
    "fine_resolution": 1008,  # resolution of the fine Sam3Processor
    "coarse_threshold": 0.5,  # cut-off applied to the coarse probability map
    "final_threshold": 0.5,  # cut-off applied to the refined probability map
    # --- edge band and refinement tiling ---
    "band_radius": 64,  # divided by coarse_downsample_factor before use
    "fine_tile_size": 2048,
    "fine_overlap": 256,
    # --- coarse tiling ---
    "coarse_tile_size": 4096,
    "coarse_tile_overlap": 256,
    "coarse_downsample_factor": 4,  # coarse map is stored at 1/factor scale
    # --- region split ---
    "num_splits_y": 2,
    "num_splits_x": 3,
    "region_overlap": 256,
    # --- post-processing (forwarded to filter_by_area) ---
    "min_area": 5000,
    "keep_largest_only": False,
    # --- runtime ---
    "use_tqdm": True,  # show tqdm progress bars
    "device": None,  # None -> "cuda" when available, otherwise "cpu"
}
# 参数设置
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
prompt = "water body"
coarse_read_max_side = 1200
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 1536
overlap = 128
# ========== 分区域参数 ==========
num_splits_y = 2 # 纵向切分数
num_splits_x = 3 # 横向切分数(共 2x3=6 份)
region_overlap = 256 # 子区域之间的重叠像素数
# ===============================
def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
cfg = dict(DEFAULT_CONFIG)
if config:
cfg.update(config)
# ========== 后处理参数 ==========
min_area = 1000 # 最小面积阈值像素小于此值的连通域将被移除设为0或None表示不进行面积过滤
keep_largest_only = False # 是否只保留最大的连通域True/False
# ===============================
log = log_callback if log_callback is not None else print
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
device = cfg["device"]
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
log(f"使用设备: {device}")
# 加载模型(所有子区域共享同一个模型)
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(
model, resolution=cfg["coarse_resolution"], device=device
)
fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)
# 打开原始影像
with rasterio.open(image_path) as src:
nodata = src.nodata
print(f"原始影像NoData值: {nodata}")
height, width = src.height, src.width
image_path = cfg["image_path"]
mask_output_path = cfg["mask_output_path"]
# 划分区域
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
print(f"共划分 {len(regions)} 个子区域")
with rasterio.open(image_path) as src:
nodata = src.nodata
log(f"原始影像NoData值: {nodata}")
height, width = src.height, src.width
# 创建全尺寸掩码数组CPU内存
full_mask = np.zeros((height, width), dtype=np.uint8)
band1 = src.read(1)
if np.issubdtype(band1.dtype, np.floating):
nodata_mask_full = (
np.isclose(band1, float(nodata))
if nodata is not None
else np.zeros_like(band1, dtype=bool)
)
else:
nodata_mask_full = (
(band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool)
)
# 子区域总进度条
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
for idx, (left, top, right, bottom) in enumerate(regions):
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
regions = split_into_regions(
width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
)
log(f"共划分 {len(regions)} 个子区域")
full_mask = np.zeros((height, width), dtype=np.uint8)
use_tqdm = bool(cfg.get("use_tqdm", True))
if use_tqdm:
region_iter = tqdm(
list(enumerate(regions)),
total=len(regions),
desc="处理子区域",
unit="子区域",
)
else:
region_iter = enumerate(regions)
for idx, (left, top, right, bottom) in region_iter:
if stop_event is not None and stop_event.is_set():
log("已停止")
return
if progress_callback is not None:
progress_callback("region", idx + 1, len(regions))
log(
f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
)
region_w = right - left
region_h = bottom - top
# 读取子区域数据
window = Window(left, top, region_w, region_h)
sub_bands = src.read(window=window) # shape: (bands, h, w)
sub_bands = src.read(window=window)
# 构建子区域元数据
sub_profile = src.profile.copy()
if not src.is_tiled:
sub_profile.pop('blockxsize', None)
sub_profile.pop('blockysize', None)
sub_profile['tiled'] = False
sub_profile.pop("blockxsize", None)
sub_profile.pop("blockysize", None)
sub_profile["tiled"] = False
else:
sub_profile['tiled'] = True
sub_profile.update({
'height': region_h,
'width': region_w,
'transform': rasterio.windows.transform(window, src.transform)
})
sub_profile["tiled"] = True
sub_profile.update(
{
"height": region_h,
"width": region_w,
"transform": rasterio.windows.transform(window, src.transform),
}
)
sub_nodata = nodata_mask_full[top:bottom, left:right]
# 将子区域数据包装成内存中的 rasterio 数据集
with MemoryFile() as memfile:
with memfile.open(**sub_profile) as sub_dst:
sub_dst.write(sub_bands)
@ -311,84 +475,144 @@ with rasterio.open(image_path) as src:
stretch_params = compute_stretch_params(sub_bands)
overview_bands, overview_step = make_overview_bands(
sub_bands, max_side=coarse_read_max_side
sub_bands, max_side=cfg["overview_max_side"]
)
overview_img = _bands_to_pil(
overview_bands, stretch_params=stretch_params
)
overview_prob = infer_prompt_prob(
coarse_processor, overview_img, cfg["prompt"]
)
overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
if overview_prob is None:
overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16)
overview_mask_small = np.zeros(
(overview_img.height, overview_img.width), dtype=bool
)
else:
overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
overview_mask_cpu = overview_prob_cpu > coarse_threshold
overview_mask_small = (
(overview_prob > cfg["overview_threshold"])
.detach()
.cpu()
.numpy()
)
scale = overview_img.height / float(region_h)
band_radius_small = max(int(round(band_radius * scale)), 1)
band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
if sub_nodata.any():
nodata_overview = sub_nodata[::overview_step, ::overview_step]
overview_mask_small[nodata_overview] = False
fine_probs_cpu = refine_tiles(
fine_processor,
sub_src,
prompt,
band_small,
tile_size=tile_size,
overlap=overlap,
coarse_probs = coarse_tile_segmentation(
processor=coarse_processor,
src=sub_src,
prompt=cfg["prompt"],
tile_size=cfg["coarse_tile_size"],
overlap=cfg["coarse_tile_overlap"],
downsample_factor=cfg["coarse_downsample_factor"],
stretch_params=stretch_params,
nodata_mask=sub_nodata if sub_nodata.any() else None,
use_tqdm=use_tqdm,
)
fine_mask = fine_probs_cpu > final_threshold
coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
overview_mask_lr = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize(
(coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]),
resample=Image.NEAREST,
)
).astype(bool)
combined_mask_lr = coarse_mask_lr | overview_mask_lr
band_radius_lr = max(
int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))),
1,
)
band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
fine_probs = refine_tiles(
processor=fine_processor,
src=sub_src,
prompt=cfg["prompt"],
band_lr=band_lr,
downsample_factor=cfg["coarse_downsample_factor"],
tile_size=cfg["fine_tile_size"],
overlap=cfg["fine_overlap"],
stretch_params=stretch_params,
use_tqdm=use_tqdm,
)
fine_mask = fine_probs > cfg["final_threshold"]
coarse_mask_full = np.array(
Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize(
Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize(
(region_w, region_h), resample=Image.NEAREST
)
).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8)
overview_mask_full = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize((region_w, region_h), resample=Image.NEAREST)
).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype(
np.uint8
)
# 合并到全尺寸掩码(重叠区域取最大值)
full_mask[top:bottom, left:right] = np.maximum(
full_mask[top:bottom, left:right], sub_mask_np
)
# 更新子区域进度条
region_pbar.update(1)
log("\n后处理填充内部NoData空洞...")
if nodata is not None:
if np.issubdtype(band1.dtype, np.floating):
nodata_mask = np.isclose(band1, float(nodata))
else:
nodata_mask = band1 == nodata
# ========== 后处理 ==========
print("\n后处理填充内部NoData空洞...")
if nodata is not None:
band1 = src.read(1)
if np.issubdtype(band1.dtype, np.floating):
nodata_mask = np.isclose(band1, float(nodata))
else:
nodata_mask = (band1 == nodata)
labeled_mask, _ = ndimage.label(nodata_mask)
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
boundary_mask[0, :] = True
boundary_mask[-1, :] = True
boundary_mask[:, 0] = True
boundary_mask[:, -1] = True
labeled_mask, num_features = ndimage.label(nodata_mask)
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
boundary_mask[0, :] = True
boundary_mask[-1, :] = True
boundary_mask[:, 0] = True
boundary_mask[:, -1] = True
boundary_labels = set(labeled_mask[boundary_mask])
internal_nodata_mask = np.isin(
labeled_mask, list(boundary_labels), invert=True
) & nodata_mask
boundary_labels = set(labeled_mask[boundary_mask])
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
if internal_nodata_mask.any():
struct = ndimage.generate_binary_structure(2, 2)
internal_dilated = ndimage.binary_dilation(
internal_nodata_mask, structure=struct, iterations=7
)
full_mask[internal_dilated] = 1
log(
f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
)
else:
log(" 无内部NoData区域")
if internal_nodata_mask.any():
struct = ndimage.generate_binary_structure(2, 2)
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
full_mask[internal_dilated] = 1
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
else:
print(" 无内部NoData区域")
log("后处理:面积过滤...")
if cfg["min_area"] is not None or cfg["keep_largest_only"]:
original_count = np.sum(full_mask)
full_mask = filter_by_area(
full_mask,
min_area=cfg["min_area"],
keep_largest_only=cfg["keep_largest_only"],
)
filtered_count = np.sum(full_mask)
log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
print("后处理:面积过滤...")
if min_area is not None or keep_largest_only:
original_count = np.sum(full_mask)
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
filtered_count = np.sum(full_mask)
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
profile = src.profile.copy()
profile.update(count=1, dtype="uint8", compress="lzw")
with rasterio.open(mask_output_path, "w", **profile) as dst:
dst.write(full_mask, 1)
# ========== 保存结果 ==========
profile = src.profile.copy()
profile.update(count=1, dtype="uint8", compress='lzw')
with rasterio.open(mask_output_path, "w", **profile) as dst:
dst.write(full_mask, 1)
log(f"分割完成,结果已保存至:{mask_output_path}")
print(f"分割完成,结果已保存至:{mask_output_path}")
def main():
    """Script entry point.

    Selects the ``TkAgg`` matplotlib backend and runs the segmentation
    pipeline with its default configuration.
    """
    matplotlib.use("TkAgg")
    run_segmentation()


if __name__ == "__main__":
    main()