"""
使用SAM3模型分割超大遥感影像中的水体(三层分割策略)

流程:
1. 将影像划分为若干有重叠的子区域(第一层)。
2. 对每个子区域进行中分辨率粗分割:以 coarse_tile_size(默认4096)为块大小,滑动窗口推理,得到粗掩码(第二层);
   同时对整个子区域生成缩略图(最长边不超过 overview_max_side)推理得到概览掩码,与粗掩码合并。
3. 基于合并后的粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层);
   子区域最终掩码为精修掩码、粗掩码与概览掩码三者的并集。
4. 合并所有子区域的掩码(重叠区域取最大值)。
5. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
6. 保存结果。

优点:
- 粗分割采用分块推理,避免整幅降采样丢失细节,提高召回率。
- 精修只处理边缘带,节省计算资源。
"""
import contextlib
import math
import warnings

# Third-party imports keep their original relative order (torch first): on some
# Windows installs the DLL load order of torch vs. GDAL-based packages matters.
import torch
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# ---------- Utility helpers ----------
def binary_dilate(mask, radius):
    """Dilate a boolean tensor with a square (2*radius+1) window via max pooling.

    ``mask`` is converted with ``.float()`` and fed to ``F.max_pool2d``, so it
    must have a layout max_pool2d accepts (e.g. ``(N, C, H, W)``).

    Returns:
        ``mask`` itself when ``radius <= 0``; otherwise a bool tensor of the
        same spatial size (stride 1, padding ``radius``).
    """
    if radius <= 0:
        return mask
    window = radius * 2 + 1
    pooled = F.max_pool2d(
        mask.float(),
        kernel_size=window,
        stride=1,
        padding=radius,
    )
    return pooled > 0.5


def binary_erode(mask, radius):
    """Erode a boolean tensor by ``radius`` pixels using a square window.

    Implemented by duality as the complement of the dilated complement, so
    ``mask`` must support ``~`` (a bool tensor).

    Returns:
        ``mask`` itself when ``radius <= 0``; otherwise a bool tensor.
    """
    if radius <= 0:
        return mask
    grown_complement = binary_dilate(~mask, radius)
    return ~grown_complement


def combine_masks_logits(masks_logits):
    """Merge per-instance mask maps into a single map by pixel-wise maximum.

    Despite the name, callers in this file threshold the result directly
    (e.g. ``> 0.5``), i.e. treat the values as probabilities.

    Args:
        masks_logits: tensor; a size-1 dimension 1 is squeezed away
            (presumably ``(N, 1, H, W)`` from the processor state — verify).

    Returns:
        None for an empty tensor; the squeezed tensor if it is already 2-D;
        otherwise its maximum over dimension 0.
    """
    if not masks_logits.numel():
        return None
    squeezed = masks_logits.squeeze(1)
    return squeezed if squeezed.dim() == 2 else squeezed.amax(dim=0)


def upsample_prob(prob, size):
    """Bilinearly resize a 2-D tensor to ``size`` = (height, width).

    Used both to enlarge and to shrink maps. ``align_corners=False``.
    """
    batched = prob.unsqueeze(0).unsqueeze(0)
    resized = F.interpolate(batched, size=size, mode="bilinear", align_corners=False)
    return resized[0, 0]


def build_band(mask, radius):
    """Return the edge band of a 2-D mask tensor (CPU or GPU).

    The band is the XOR of the mask dilated and eroded by ``radius``, i.e.
    the pixels within ``radius`` of the mask boundary. Output is 2-D.
    """
    batched = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    outer = binary_dilate(batched, radius)
    inner = binary_erode(batched, radius)
    return torch.logical_xor(outer, inner)[0, 0]


def _tile_starts(length, tile_size, stride):
    """Yield tile start offsets along one axis, stopping once a tile reaches the end."""
    for start in range(0, length, stride):
        yield start
        if start + tile_size >= length:
            # Every later start would give a tile nested inside this one
            # (a redundant sliver), so stop here.
            break


def tile_slices(height, width, tile_size, overlap):
    """Yield ``(top, left, bottom, right)`` windows tiling a ``height`` x ``width`` grid.

    Tiles are ``tile_size`` pixels on a side. Their starts are spaced
    ``max(tile_size - overlap, 1)`` apart, in row-major order. Tiles touching
    the far edges are clipped to the grid. Starts whose tile would lie
    entirely inside the previous, edge-reaching tile are not emitted, so no
    inference is wasted on redundant slivers. The union of the tiles still
    covers the whole grid.
    """
    stride = max(tile_size - overlap, 1)
    tops = list(_tile_starts(height, tile_size, stride))
    lefts = list(_tile_starts(width, tile_size, stride))
    for top in tops:
        bottom = min(top + tile_size, height)
        for left in lefts:
            yield top, left, bottom, min(left + tile_size, width)


def compute_stretch_params(bands, sample_max_pixels=2_000_000):
    """Compute per-band 2nd/98th percentile stretch limits from a strided sample.

    Args:
        bands: array of shape ``(C, H, W)``.
        sample_max_pixels: approximate upper bound on sampled pixels per band.
            The image is subsampled with one uniform stride along both axes.

    Returns:
        ``(vmin, vmax)``: arrays of length C with ``vmax >= vmin + 1e-6``.

    NaN pixels (e.g. NaN NoData in floating-point rasters) are ignored. With
    plain ``np.percentile`` a single NaN made the limits NaN and corrupted
    every stretched pixel. A band whose sample is entirely NaN falls back to
    ``vmin = 0``.
    """
    _, height, width = bands.shape
    step = int(np.ceil(np.sqrt((height * width) / float(sample_max_pixels))))
    step = max(step, 1)
    sample = bands[:, ::step, ::step].astype(np.float32)

    with warnings.catch_warnings():
        # All-NaN bands emit "All-NaN slice encountered"; handled just below.
        warnings.simplefilter("ignore", RuntimeWarning)
        vmin = np.nanpercentile(sample, 2.0, axis=(1, 2))
        vmax = np.nanpercentile(sample, 98.0, axis=(1, 2))

    vmin = np.where(np.isnan(vmin), 0.0, vmin)
    vmax = np.where(np.isnan(vmax), 0.0, vmax)
    vmax = np.maximum(vmax, vmin + 1e-6)
    return vmin, vmax


def make_overview_bands(bands, max_side):
    """Subsample a ``(C, H, W)`` stack so its longest side is at most ``max_side``.

    The same integer stride (at least 1) is applied to both spatial axes.

    Returns:
        ``(subsampled_bands, stride)``; the result is a strided view of ``bands``.
    """
    _, rows, cols = bands.shape
    longest = max(rows, cols)
    stride = max(int(math.ceil(longest / float(max_side))), 1)
    return bands[:, ::stride, ::stride], stride


def _to_uint8(arr, vmin=None, vmax=None):
|
||
if arr.dtype == np.uint8:
|
||
return arr
|
||
arr = arr.astype(np.float32)
|
||
if vmin is None or vmax is None:
|
||
vmin = np.percentile(arr, 2.0)
|
||
vmax = np.percentile(arr, 98.0)
|
||
if vmax <= vmin:
|
||
return np.zeros_like(arr, dtype=np.uint8)
|
||
arr = (arr - vmin) / (vmax - vmin)
|
||
arr = np.clip(arr, 0.0, 1.0)
|
||
return (arr * 255.0).astype(np.uint8)
|
||
|
||
def _bands_to_pil(bands, stretch_params=None):
    """Render a ``(C, H, W)`` band stack as an 8-bit RGB ``PIL.Image``.

    Channel selection: with three or more bands, the first three are used.
    With fewer bands, missing channels repeat the last available band
    (1 band -> grey, 2 bands -> (b0, b1, b1)).

    Args:
        bands: numpy array of shape ``(C, H, W)``.
        stretch_params: optional ``(vmin, vmax)`` per-band limits, e.g. from
            ``compute_stretch_params``. Each output channel is stretched with
            the limits of the band it was taken from. If fewer limits than
            bands are supplied, the last available limit is reused. Without
            limits, the three channels are stretched jointly by ``_to_uint8``.

    Raises:
        ValueError: if ``bands`` is not 3-D or contains no bands.
    """
    if bands.ndim != 3:
        raise ValueError("bands must be (C,H,W)")
    n_bands = bands.shape[0]
    if n_bands == 0:
        raise ValueError("bands must contain at least one band")

    # Source band index for each of the R, G, B output channels. Padding by
    # repeating the last band keeps 2-band rasters from crashing.
    channel_src = [min(i, n_bands - 1) for i in range(3)]
    rgb = bands[:3] if n_bands >= 3 else bands[channel_src]

    if stretch_params is not None:
        vmin, vmax = (np.atleast_1d(np.asarray(v)) for v in stretch_params)
        channels = []
        for ch, src_band in enumerate(channel_src):
            lo = vmin[min(src_band, vmin.shape[0] - 1)]
            hi = vmax[min(src_band, vmax.shape[0] - 1)]
            channels.append(_to_uint8(rgb[ch], vmin=lo, vmax=hi))
        rgb = np.stack(channels)
    else:
        rgb = _to_uint8(rgb)

    # (3, H, W) -> (H, W, 3). PIL infers "RGB" from a uint8 array of this
    # shape, so the deprecated ``mode=`` argument is not needed.
    return Image.fromarray(np.ascontiguousarray(np.transpose(rgb, (1, 2, 0))))


def _read_pil_window(src, window, stretch_params=None):
    """Read every band of ``window`` from the open raster ``src`` and render it as RGB.

    Rendering (band selection and stretching) is delegated to ``_bands_to_pil``.
    """
    return _bands_to_pil(src.read(window=window), stretch_params=stretch_params)


def infer_prompt_prob(processor, image, prompt):
    """Run text-prompted SAM3 inference on one image and return a merged mask map.

    Args:
        processor: a ``Sam3Processor``. Its ``device`` attribute decides
            whether CUDA float16 autocast is used.
        image: image passed to ``processor.set_image`` (a PIL image in this file).
        prompt: text prompt for ``processor.set_text_prompt``.

    Returns:
        ``combine_masks_logits(state["masks_logits"])``: a tensor on the
        processor's device, or None when no masks were produced.
    """
    # Enable autocast only for CUDA devices ("cuda", "cuda:1", torch.device).
    # The old `device != "cpu"` test also forced CUDA autocast for other
    # accelerators and for torch.device objects.
    use_cuda_amp = torch.device(processor.device).type == "cuda"
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=torch.float16)
        if use_cuda_amp
        else contextlib.nullcontext()
    )
    with amp_ctx:
        state = processor.set_image(image)
        state = processor.set_text_prompt(prompt=prompt, state=state)
    return combine_masks_logits(state["masks_logits"])


# ---------- 新增:中分辨率粗分割(分块推理) ----------
|
||
def _downsample_any(mask, factor):
|
||
if factor <= 1:
|
||
return mask
|
||
h, w = mask.shape
|
||
pad_h = (-h) % factor
|
||
pad_w = (-w) % factor
|
||
if pad_h or pad_w:
|
||
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
|
||
h, w = mask.shape
|
||
h2 = h // factor
|
||
w2 = w // factor
|
||
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
|
||
|
||
|
||
def _build_band_ndimage(mask, radius):
|
||
if radius <= 0:
|
||
return np.zeros_like(mask, dtype=bool)
|
||
struct = ndimage.generate_binary_structure(2, 2)
|
||
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
|
||
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
|
||
return np.logical_xor(dil, ero)
|
||
|
||
|
||
def coarse_tile_segmentation(
    processor,
    src,
    prompt,
    tile_size,
    overlap,
    downsample_factor=4,
    stretch_params=None,
    nodata_mask=None,
    use_tqdm=True,
):
    """Tiled coarse segmentation of one sub-region on a downsampled grid.

    Each tile (``tile_size`` full-resolution pixels, sharing ``overlap``
    pixels with its neighbours) is read from ``src`` and run through
    ``processor`` with ``prompt``. The result is resized to the
    low-resolution grid and merged into the output by pixel-wise maximum.

    Args:
        processor: coarse processor used by ``infer_prompt_prob``.
        src: open rasterio dataset for the sub-region.
        prompt: text prompt.
        tile_size: tile edge length in full-resolution pixels.
        overlap: overlap between neighbouring tiles in full-resolution pixels.
        downsample_factor: full-resolution pixels per low-resolution cell along
            each axis.
        stretch_params: optional ``(vmin, vmax)`` forwarded to ``_bands_to_pil``.
        nodata_mask: optional bool array shaped like the sub-region (True =
            NoData). Tiles lying entirely in NoData are skipped, and
            low-resolution cells containing any NoData pixel are zeroed in the
            result.
        use_tqdm: show a progress bar.

    Returns:
        float16 array of shape ``(ceil(H / f), ceil(W / f))`` with
        ``f = downsample_factor``; 0 where no tile produced a mask.
    """
    height, width = src.height, src.width
    factor = downsample_factor
    full_probs_lr = np.zeros(
        (int(math.ceil(height / float(factor))), int(math.ceil(width / float(factor)))),
        dtype=np.float16,
    )

    # Materialise the tiles so the progress-bar total is exactly what
    # tile_slices yields, instead of re-deriving the tile count by hand.
    tiles = list(tile_slices(height, width, tile_size, overlap))
    iterator = (
        tqdm(tiles, total=len(tiles), desc="粗分割分块", unit="块", leave=True)
        if use_tqdm
        else tiles
    )

    for top, left, bottom, right in iterator:
        # Skip tiles lying entirely in NoData (speed-up).
        if nodata_mask is not None and nodata_mask[top:bottom, left:right].all():
            continue

        window = Window(left, top, right - left, bottom - top)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue

        h_lr = int(math.ceil((bottom - top) / float(factor)))
        w_lr = int(math.ceil((right - left) / float(factor)))
        if tuple(tile_prob.shape[-2:]) != (h_lr, w_lr):
            tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
        tile_np = tile_prob.to(torch.float16).detach().cpu().numpy()

        top_lr, left_lr = top // factor, left // factor
        # Max-merge in place into the view of the output covered by this tile.
        dst = full_probs_lr[top_lr:top_lr + h_lr, left_lr:left_lr + w_lr]
        np.maximum(dst, tile_np, out=dst)

    if nodata_mask is not None:
        full_probs_lr[_downsample_any(nodata_mask, factor)] = 0.0

    return full_probs_lr


# ---------- Edge-band refinement (tiled, full resolution) ----------
def refine_tiles(
    processor,
    src,
    prompt,
    band_lr,
    downsample_factor,
    tile_size,
    overlap,
    stretch_params=None,
    use_tqdm=True,
):
    """Run full-resolution tiled inference only on tiles that touch the edge band.

    Args:
        processor: fine processor used by ``infer_prompt_prob``.
        src: open rasterio dataset for the sub-region.
        prompt: text prompt.
        band_lr: 2-D bool numpy array on the low-resolution grid (True where
            refinement is wanted).
        downsample_factor: full-resolution pixels per ``band_lr`` cell per axis.
        tile_size: tile edge length in full-resolution pixels.
        overlap: overlap between neighbouring tiles in full-resolution pixels.
        stretch_params: optional ``(vmin, vmax)`` forwarded to ``_bands_to_pil``.
        use_tqdm: show a progress bar.

    Returns:
        ``(H, W)`` float16 array: pixel-wise maximum over the refined tiles;
        0 where no tile was run.
    """
    height, width = src.height, src.width
    full_probs = np.zeros((height, width), dtype=np.float16)

    # Materialise the tiles so the progress-bar total is exactly what
    # tile_slices yields, instead of re-deriving the tile count by hand.
    tiles = list(tile_slices(height, width, tile_size, overlap))
    iterator = (
        tqdm(tiles, total=len(tiles), desc="精修分块", unit="块", leave=True)
        if use_tqdm
        else tiles
    )

    for top, left, bottom, right in iterator:
        # Low-resolution cells touched by this tile (at least one per axis).
        c_top = top // downsample_factor
        c_left = left // downsample_factor
        c_bottom = max(int(math.ceil(bottom / float(downsample_factor))), c_top + 1)
        c_right = max(int(math.ceil(right / float(downsample_factor))), c_left + 1)
        if not band_lr[c_top:c_bottom, c_left:c_right].any():
            continue

        tile_h, tile_w = bottom - top, right - left
        window = Window(left, top, tile_w, tile_h)
        crop = _read_pil_window(src, window, stretch_params=stretch_params)
        tile_prob = infer_prompt_prob(processor, crop, prompt)
        if tile_prob is None:
            continue

        if tuple(tile_prob.shape[-2:]) != (tile_h, tile_w):
            tile_prob = upsample_prob(tile_prob, (tile_h, tile_w))
        tile_np = tile_prob.to(torch.float16).detach().cpu().numpy()

        # Max-merge in place into the view of the output covered by this tile.
        dst = full_probs[top:bottom, left:right]
        np.maximum(dst, tile_np, out=dst)

    return full_probs


# ---------- Post-processing (area filtering) ----------
def filter_by_area(mask, min_area=None, keep_largest_only=False):
    """Filter connected components of a binary mask by pixel area.

    Args:
        mask: 2-D array (bool or numeric). Non-zero pixels are foreground, and
            components use ``ndimage.label``'s default connectivity.
        min_area: minimum component area in pixels; smaller components are
            removed. None disables the threshold.
        keep_largest_only: if True, keep only the largest component (ties go
            to the lowest label), still subject to ``min_area``.

    Returns:
        uint8 array of 0/1 with the same shape as ``mask``.
    """
    labeled, num = ndimage.label(mask)
    if num == 0:
        return np.zeros_like(mask, dtype=np.uint8)

    # Pixel count per label (index 0 is background). Counting pixels rather
    # than summing mask values keeps areas correct for e.g. 0/255 masks.
    sizes = np.bincount(labeled.ravel(), minlength=num + 1)

    keep = np.zeros(num + 1, dtype=bool)  # lookup table indexed by label
    if keep_largest_only:
        largest = int(np.argmax(sizes[1:])) + 1
        if min_area is None or sizes[largest] >= min_area:
            keep[largest] = True
    elif min_area is None:
        keep[1:] = True
    else:
        keep[1:] = sizes[1:] >= min_area

    return keep[labeled].astype(np.uint8)


# ---------- Sub-region partitioning ----------
def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=256):
    """Split an image into a grid of overlapping sub-regions.

    The base cell size is ``ceil(width / num_splits_x)`` x
    ``ceil(height / num_splits_y)``. Each cell is grown by ``overlap`` pixels
    on every side and clipped to the image.

    Returns:
        List of ``(left, top, right, bottom)`` pixel bounds relative to the
        full image, in row-major order. Degenerate (zero-width or zero-height)
        regions, which occur when the ceil-based cells overshoot a small
        image, are omitted because they cannot be read as raster windows.
    """
    tile_w = (width + num_splits_x - 1) // num_splits_x
    tile_h = (height + num_splits_y - 1) // num_splits_y

    regions = []
    for i in range(num_splits_y):
        top = max(i * tile_h - overlap, 0)
        bottom = min((i + 1) * tile_h + overlap, height)
        for j in range(num_splits_x):
            left = max(j * tile_w - overlap, 0)
            right = min((j + 1) * tile_w + overlap, width)
            if right > left and bottom > top:
                regions.append((left, top, right, bottom))
    return regions


# Default pipeline configuration. run_segmentation() copies this mapping and
# overlays any keys passed in its ``config`` argument.
DEFAULT_CONFIG = dict(
    # Input raster and output mask paths.
    image_path=r"E:\is2\guidingsahn\result.tif",
    mask_output_path=r"E:\is2\guidingsahn\result_mask.tif",
    # Text prompt given to SAM3.
    prompt="water body",
    # Overview pass: longest side (pixels) of the whole-region thumbnail, and
    # the probability threshold applied to its prediction.
    overview_max_side=1400,
    overview_threshold=0.4,
    # ``resolution`` passed to the coarse / fine Sam3Processor instances.
    coarse_resolution=1008,
    fine_resolution=1008,
    # Probability thresholds for the coarse map and the refined map.
    coarse_threshold=0.5,
    final_threshold=0.5,
    # Edge-band radius in full-resolution pixels (divided by
    # coarse_downsample_factor before use on the coarse grid).
    band_radius=64,
    # Refinement tiling, in full-resolution pixels.
    fine_tile_size=2048,
    fine_overlap=256,
    # Coarse tiling (full-resolution pixels) and the coarse grid's scale.
    coarse_tile_size=4096,
    coarse_tile_overlap=256,
    coarse_downsample_factor=4,
    # Sub-region grid (rows x columns) and overlap between sub-regions.
    num_splits_y=2,
    num_splits_x=3,
    region_overlap=256,
    # Area filter: minimum component size in pixels; optionally keep only
    # the largest component.
    min_area=5000,
    keep_largest_only=False,
    # Progress bars; device None means "cuda" if available, else "cpu".
    use_tqdm=True,
    device=None,
)


def _nodata_mask_from_band(band, nodata):
    """Return a bool mask of the pixels in ``band`` that equal ``nodata``.

    Floating-point bands use ``np.isclose(..., equal_nan=True)`` so that a NaN
    NoData value matches NaN pixels; plain ``isclose``/``==`` never match NaN.
    Integer bands use exact equality. ``nodata=None`` yields an all-False mask.
    """
    if nodata is None:
        return np.zeros(band.shape, dtype=bool)
    if np.issubdtype(band.dtype, np.floating):
        return np.isclose(band, float(nodata), equal_nan=True)
    return band == nodata


def _region_profile(src, window, region_w, region_h):
    """Return a copy of ``src.profile`` resized and georeferenced to ``window``."""
    profile = src.profile.copy()
    if src.is_tiled:
        profile["tiled"] = True
    else:
        # Strip-organised source: drop its strip block sizes.
        profile.pop("blockxsize", None)
        profile.pop("blockysize", None)
        profile["tiled"] = False
    profile.update(
        height=region_h,
        width=region_w,
        transform=rasterio.windows.transform(window, src.transform),
    )
    return profile


def _resize_mask_nearest(mask, width, height):
    """Nearest-neighbour resize of a 2-D bool mask to ``(height, width)`` via PIL."""
    small = Image.fromarray(mask.astype(np.uint8))  # 2-D uint8 -> mode "L"
    return np.array(small.resize((width, height), resample=Image.NEAREST)).astype(bool)


def _segment_region(src, bounds, sub_nodata, cfg, coarse_processor, fine_processor, use_tqdm):
    """Segment one sub-region and return its uint8 0/1 mask (region-sized).

    Three passes are OR-ed together: (1) inference on a whole-region overview,
    (2) tiled coarse inference on a grid downsampled by
    ``coarse_downsample_factor``, and (3) full-resolution refinement of the
    tiles that touch the edge band of (1)|(2).
    """
    left, top, right, bottom = bounds
    region_w, region_h = right - left, bottom - top
    window = Window(left, top, region_w, region_h)
    sub_bands = src.read(window=window)
    has_nodata = bool(sub_nodata.any())
    factor = cfg["coarse_downsample_factor"]

    with MemoryFile() as memfile:
        with memfile.open(**_region_profile(src, window, region_w, region_h)) as sub_dst:
            sub_dst.write(sub_bands)
        with memfile.open() as sub_src:
            stretch_params = compute_stretch_params(sub_bands)

            # Pass 1: overview of the whole sub-region.
            overview_bands, overview_step = make_overview_bands(
                sub_bands, max_side=cfg["overview_max_side"]
            )
            overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
            overview_prob = infer_prompt_prob(coarse_processor, overview_img, cfg["prompt"])
            if overview_prob is None:
                overview_mask_small = np.zeros(
                    (overview_img.height, overview_img.width), dtype=bool
                )
            else:
                overview_mask_small = (
                    (overview_prob > cfg["overview_threshold"]).detach().cpu().numpy()
                )
            if has_nodata:
                overview_mask_small[sub_nodata[::overview_step, ::overview_step]] = False

            # Pass 2: tiled coarse segmentation on the downsampled grid.
            coarse_probs = coarse_tile_segmentation(
                processor=coarse_processor,
                src=sub_src,
                prompt=cfg["prompt"],
                tile_size=cfg["coarse_tile_size"],
                overlap=cfg["coarse_tile_overlap"],
                downsample_factor=factor,
                stretch_params=stretch_params,
                nodata_mask=sub_nodata if has_nodata else None,
                use_tqdm=use_tqdm,
            )
            coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
            lr_h, lr_w = coarse_mask_lr.shape
            combined_mask_lr = coarse_mask_lr | _resize_mask_nearest(
                overview_mask_small, lr_w, lr_h
            )

            # Pass 3: refine tiles touching the edge band of the combined mask.
            band_radius_lr = max(int(round(cfg["band_radius"] / float(factor))), 1)
            band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
            fine_probs = refine_tiles(
                processor=fine_processor,
                src=sub_src,
                prompt=cfg["prompt"],
                band_lr=band_lr,
                downsample_factor=factor,
                tile_size=cfg["fine_tile_size"],
                overlap=cfg["fine_overlap"],
                stretch_params=stretch_params,
                use_tqdm=use_tqdm,
            )
            fine_mask = fine_probs > cfg["final_threshold"]

    coarse_mask_full = _resize_mask_nearest(coarse_mask_lr, region_w, region_h)
    overview_mask_full = _resize_mask_nearest(overview_mask_small, region_w, region_h)
    return (fine_mask | coarse_mask_full | overview_mask_full).astype(np.uint8)


def _fill_internal_nodata(full_mask, nodata_mask, log):
    """Set ``full_mask`` to 1 over interior NoData holes, in place.

    Interior holes are connected NoData components (``ndimage.label`` default
    connectivity) that do not touch the image border. They are dilated by 7
    iterations of an 8-connected structure before being written.
    """
    labeled, num = ndimage.label(nodata_mask)
    is_internal = np.ones(num + 1, dtype=bool)  # lookup table indexed by label
    is_internal[0] = False  # label 0 = valid (non-NoData) pixels
    border = np.concatenate((labeled[0, :], labeled[-1, :], labeled[:, 0], labeled[:, -1]))
    is_internal[np.unique(border)] = False
    internal = is_internal[labeled]

    if not internal.any():
        log(" 无内部NoData区域")
        return

    struct = ndimage.generate_binary_structure(2, 2)
    dilated = ndimage.binary_dilation(internal, structure=struct, iterations=7)
    full_mask[dilated] = 1
    log(f" 内部NoData原始像素数: {np.sum(internal)},膨胀后像素数: {np.sum(dilated)}")


def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
    """Segment water bodies in a large raster and write a uint8 0/1 mask.

    Args:
        config: optional dict whose keys override a copy of ``DEFAULT_CONFIG``.
        progress_callback: optional callable ``(stage, current, total)``. It is
            called with ``"region"`` before each sub-region is processed.
        log_callback: optional callable taking one string; defaults to ``print``.
        stop_event: optional object with ``is_set()`` (e.g. ``threading.Event``),
            checked before each sub-region. When set, the function logs and
            returns without writing any output.

    The output raster copies the input's georeferencing. ``nodata`` is
    cleared, because the source value is either invalid for uint8 or would
    mark background 0 as NoData. ``photometric`` is dropped because the
    output has a single band.
    """
    cfg = dict(DEFAULT_CONFIG)
    if config:
        cfg.update(config)
    log = log_callback if log_callback is not None else print

    device = cfg["device"]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"使用设备: {device}")

    model = build_sam3_image_model().to(device).eval()
    coarse_processor = Sam3Processor(model, resolution=cfg["coarse_resolution"], device=device)
    fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)

    image_path = cfg["image_path"]
    mask_output_path = cfg["mask_output_path"]

    with rasterio.open(image_path) as src:
        nodata = src.nodata
        log(f"原始影像NoData值: {nodata}")
        height, width = src.height, src.width
        # NoData is detected on band 1 only; computed once and reused below.
        nodata_mask_full = _nodata_mask_from_band(src.read(1), nodata)

        regions = split_into_regions(
            width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
        )
        log(f"共划分 {len(regions)} 个子区域")

        full_mask = np.zeros((height, width), dtype=np.uint8)
        use_tqdm = bool(cfg.get("use_tqdm", True))
        if use_tqdm:
            region_iter = tqdm(
                list(enumerate(regions)),
                total=len(regions),
                desc="处理子区域",
                unit="子区域",
            )
        else:
            region_iter = enumerate(regions)

        for idx, (left, top, right, bottom) in region_iter:
            if stop_event is not None and stop_event.is_set():
                log("已停止")
                return
            if progress_callback is not None:
                progress_callback("region", idx + 1, len(regions))
            log(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")

            sub_mask = _segment_region(
                src,
                (left, top, right, bottom),
                nodata_mask_full[top:bottom, left:right],
                cfg,
                coarse_processor,
                fine_processor,
                use_tqdm,
            )
            # Overlapping sub-regions are merged by maximum (logical OR).
            region_view = full_mask[top:bottom, left:right]
            np.maximum(region_view, sub_mask, out=region_view)

        log("\n后处理:填充内部NoData空洞...")
        _fill_internal_nodata(full_mask, nodata_mask_full, log)

        log("后处理:面积过滤...")
        if cfg["min_area"] is not None or cfg["keep_largest_only"]:
            original_count = np.sum(full_mask)
            full_mask = filter_by_area(
                full_mask,
                min_area=cfg["min_area"],
                keep_largest_only=cfg["keep_largest_only"],
            )
            filtered_count = np.sum(full_mask)
            log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")

        profile = src.profile.copy()
        profile.update(count=1, dtype="uint8", compress="lzw", nodata=None)
        profile.pop("photometric", None)
        with rasterio.open(mask_output_path, "w", **profile) as dst:
            dst.write(full_mask, 1)

    log(f"分割完成,结果已保存至:{mask_output_path}")


def main():
    """Script entry point: run the segmentation with ``DEFAULT_CONFIG``."""
    # Nothing in this script plots. With force=False, matplotlib silently keeps
    # the current backend when Tk cannot be used (e.g. on headless servers),
    # instead of raising ImportError before any segmentation work starts.
    matplotlib.use("TkAgg", force=False)
    run_segmentation()


if __name__ == "__main__":
    main()