GUI
This commit is contained in:
518
water_V5.py
518
water_V5.py
@ -1,14 +1,17 @@
|
||||
"""
|
||||
使用SAM3模型分割超大遥感影像中的水体(分区域处理,优化内存)
|
||||
使用SAM3模型分割超大遥感影像中的水体(三层分割策略)
|
||||
|
||||
流程:
|
||||
1. 将影像划分为若干有重叠的子区域。
|
||||
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割。
|
||||
3. 合并所有子区域的掩码(重叠区域取最大值)。
|
||||
4. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
|
||||
5. 保存结果。
|
||||
1. 将影像划分为若干有重叠的子区域(第一层)。
|
||||
2. 对每个子区域进行中分辨率粗分割:以4096为块大小,滑动窗口推理,得到粗掩码(第二层)。
|
||||
3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。
|
||||
4. 合并所有子区域的掩码(重叠区域取最大值)。
|
||||
5. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
|
||||
6. 保存结果。
|
||||
|
||||
优点:避免在全尺寸概率图上进行GPU操作,显著降低显存占用。
|
||||
优点:
|
||||
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
|
||||
- 精修只处理边缘带,节省计算资源。
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -22,6 +25,7 @@ from rasterio.windows import Window
|
||||
from rasterio.io import MemoryFile
|
||||
from tqdm import tqdm
|
||||
from scipy import ndimage
|
||||
import math
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000):
|
||||
return vmin, vmax
|
||||
|
||||
|
||||
def make_overview_bands(bands, max_side):
|
||||
_, h, w = bands.shape
|
||||
step = int(math.ceil(max(h, w) / float(max_side)))
|
||||
step = max(step, 1)
|
||||
return bands[:, ::step, ::step], step
|
||||
|
||||
|
||||
def _to_uint8(arr, vmin=None, vmax=None):
|
||||
if arr.dtype == np.uint8:
|
||||
return arr
|
||||
@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None):
|
||||
return Image.fromarray(rgb, mode="RGB")
|
||||
|
||||
def _read_pil_window(src, window, stretch_params=None):
|
||||
bands = src.read(window=window, boundless=True, fill_value=0)
|
||||
bands = src.read(window=window)
|
||||
return _bands_to_pil(bands, stretch_params=stretch_params)
|
||||
|
||||
def infer_prompt_prob(processor, image, prompt):
|
||||
@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt):
|
||||
prob = combine_masks_logits(state["masks_logits"])
|
||||
return prob
|
||||
|
||||
def make_overview_bands(sub_bands, max_side):
|
||||
_, h, w = sub_bands.shape
|
||||
step = int(np.ceil(max(h, w) / float(max_side)))
|
||||
step = max(step, 1)
|
||||
return sub_bands[:, ::step, ::step], step
|
||||
# ---------- 新增:中分辨率粗分割(分块推理) ----------
|
||||
def _downsample_any(mask, factor):
|
||||
if factor <= 1:
|
||||
return mask
|
||||
h, w = mask.shape
|
||||
pad_h = (-h) % factor
|
||||
pad_w = (-w) % factor
|
||||
if pad_h or pad_w:
|
||||
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
|
||||
h, w = mask.shape
|
||||
h2 = h // factor
|
||||
w2 = w // factor
|
||||
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
|
||||
|
||||
# ---------- 精修分块 ----------
|
||||
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None):
|
||||
|
||||
def _build_band_ndimage(mask, radius):
|
||||
if radius <= 0:
|
||||
return np.zeros_like(mask, dtype=bool)
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
|
||||
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
|
||||
return np.logical_xor(dil, ero)
|
||||
|
||||
|
||||
def coarse_tile_segmentation(
|
||||
processor,
|
||||
src,
|
||||
prompt,
|
||||
tile_size,
|
||||
overlap,
|
||||
downsample_factor=4,
|
||||
stretch_params=None,
|
||||
nodata_mask=None,
|
||||
use_tqdm=True,
|
||||
):
|
||||
"""
|
||||
对边缘带进行精修分割
|
||||
band_coarse_cpu: 2D numpy bool,指示需要精修的区域(低分辨率)
|
||||
对子区域进行分块粗分割,返回低分辨率概率图(float16)
|
||||
|
||||
Args:
|
||||
processor: 粗处理器(输入分辨率 coarse_resolution)
|
||||
src: rasterio 数据集(子区域)
|
||||
prompt: 文本提示
|
||||
tile_size: 粗分割块大小(原图像素)
|
||||
overlap: 块重叠像素
|
||||
stretch_params: 拉伸参数
|
||||
nodata_mask: 子区域的 NoData 掩码(bool, 与子区域同尺寸),用于过滤无效区域
|
||||
|
||||
Returns:
|
||||
coarse_probs_lr: 子区域粗概率图 (H_lr,W_lr) float16
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
full_probs = np.zeros((height, width), dtype=np.float16)
|
||||
|
||||
band_h, band_w = band_coarse_cpu.shape
|
||||
scale_y = band_h / float(height)
|
||||
scale_x = band_w / float(width)
|
||||
height_lr = int(math.ceil(height / float(downsample_factor)))
|
||||
width_lr = int(math.ceil(width / float(downsample_factor)))
|
||||
full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
# 内部进度条,leave=True 以便保留历史记录
|
||||
with tqdm(total=total_tiles, desc="精修分块", unit="块", leave=True) as pbar:
|
||||
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
|
||||
c_top = int(top * scale_y)
|
||||
c_left = int(left * scale_x)
|
||||
c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
|
||||
c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
|
||||
if use_tqdm:
|
||||
iterator = tqdm(
|
||||
tile_slices(height, width, tile_size, overlap),
|
||||
total=total_tiles,
|
||||
desc="粗分割分块",
|
||||
unit="块",
|
||||
leave=True,
|
||||
)
|
||||
else:
|
||||
iterator = tile_slices(height, width, tile_size, overlap)
|
||||
|
||||
if not band_coarse_cpu[c_top:c_bottom, c_left:c_right].any():
|
||||
pbar.update(1)
|
||||
for top, left, bottom, right in iterator:
|
||||
# 如果该块完全在 NoData 区域内,则跳过(加速)
|
||||
if nodata_mask is not None:
|
||||
block_nodata = nodata_mask[top:bottom, left:right]
|
||||
if block_nodata.all():
|
||||
continue
|
||||
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window, stretch_params=stretch_params)
|
||||
tile_prob = infer_prompt_prob(processor, crop, prompt)
|
||||
if tile_prob is None:
|
||||
continue
|
||||
|
||||
h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
|
||||
w_lr = int(math.ceil((right - left) / float(downsample_factor)))
|
||||
if tile_prob.shape[-2:] != (h_lr, w_lr):
|
||||
tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
|
||||
|
||||
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
|
||||
top_lr = top // downsample_factor
|
||||
left_lr = left // downsample_factor
|
||||
bottom_lr = top_lr + h_lr
|
||||
right_lr = left_lr + w_lr
|
||||
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum(
|
||||
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu
|
||||
)
|
||||
|
||||
if nodata_mask is not None:
|
||||
nodata_lr = _downsample_any(nodata_mask, downsample_factor)
|
||||
full_probs_lr[nodata_lr] = 0.0
|
||||
|
||||
return full_probs_lr
|
||||
|
||||
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
|
||||
def refine_tiles(
|
||||
processor,
|
||||
src,
|
||||
prompt,
|
||||
band_lr,
|
||||
downsample_factor,
|
||||
tile_size,
|
||||
overlap,
|
||||
stretch_params=None,
|
||||
use_tqdm=True,
|
||||
):
|
||||
"""
|
||||
对边缘带进行精修分割
|
||||
band_lr: 2D numpy bool,指示需要精修的区域(低分辨率)
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
full_probs = np.zeros((height, width), dtype=np.float16)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
if use_tqdm:
|
||||
iterator = tqdm(
|
||||
tile_slices(height, width, tile_size, overlap),
|
||||
total=total_tiles,
|
||||
desc="精修分块",
|
||||
unit="块",
|
||||
leave=True,
|
||||
)
|
||||
else:
|
||||
iterator = tile_slices(height, width, tile_size, overlap)
|
||||
|
||||
for top, left, bottom, right in iterator:
|
||||
c_top = top // downsample_factor
|
||||
c_left = left // downsample_factor
|
||||
c_bottom = int(math.ceil(bottom / float(downsample_factor)))
|
||||
c_right = int(math.ceil(right / float(downsample_factor)))
|
||||
c_bottom = max(c_bottom, c_top + 1)
|
||||
c_right = max(c_right, c_left + 1)
|
||||
|
||||
if not band_lr[c_top:c_bottom, c_left:c_right].any():
|
||||
continue
|
||||
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window, stretch_params=stretch_params)
|
||||
tile_prob = infer_prompt_prob(processor, crop, prompt)
|
||||
if tile_prob is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
if tile_prob.shape[-2:] != (bottom - top, right - left):
|
||||
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
|
||||
|
||||
tile_prob_cpu = tile_prob.detach().float().cpu().numpy().astype(np.float16)
|
||||
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
|
||||
full_probs[top:bottom, left:right] = np.maximum(
|
||||
full_probs[top:bottom, left:right], tile_prob_cpu
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
return full_probs
|
||||
|
||||
@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25
|
||||
regions.append((left, top, right, bottom))
|
||||
return regions
|
||||
|
||||
# ---------- 主程序 ----------
|
||||
matplotlib.use("TkAgg")
|
||||
DEFAULT_CONFIG = {
|
||||
"image_path": r"E:\is2\guidingsahn\result.tif",
|
||||
"mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
|
||||
"prompt": "water body",
|
||||
"overview_max_side": 1400,
|
||||
"overview_threshold": 0.4,
|
||||
"coarse_resolution": 1008,
|
||||
"fine_resolution": 1008,
|
||||
"coarse_threshold": 0.5,
|
||||
"final_threshold": 0.5,
|
||||
"band_radius": 64,
|
||||
"fine_tile_size": 2048,
|
||||
"fine_overlap": 256,
|
||||
"coarse_tile_size": 4096,
|
||||
"coarse_tile_overlap": 256,
|
||||
"coarse_downsample_factor": 4,
|
||||
"num_splits_y": 2,
|
||||
"num_splits_x": 3,
|
||||
"region_overlap": 256,
|
||||
"min_area": 5000,
|
||||
"keep_largest_only": False,
|
||||
"use_tqdm": True,
|
||||
"device": None,
|
||||
}
|
||||
|
||||
# 参数设置
|
||||
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
|
||||
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
|
||||
prompt = "water body"
|
||||
coarse_read_max_side = 1200
|
||||
coarse_resolution = 1008
|
||||
fine_resolution = 1008
|
||||
coarse_threshold = 0.5
|
||||
final_threshold = 0.5
|
||||
band_radius = 64
|
||||
tile_size = 1536
|
||||
overlap = 128
|
||||
|
||||
# ========== 分区域参数 ==========
|
||||
num_splits_y = 2 # 纵向切分数
|
||||
num_splits_x = 3 # 横向切分数(共 2x3=6 份)
|
||||
region_overlap = 256 # 子区域之间的重叠像素数
|
||||
# ===============================
|
||||
def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
|
||||
cfg = dict(DEFAULT_CONFIG)
|
||||
if config:
|
||||
cfg.update(config)
|
||||
|
||||
# ========== 后处理参数 ==========
|
||||
min_area = 1000 # 最小面积阈值(像素),小于此值的连通域将被移除;设为0或None表示不进行面积过滤
|
||||
keep_largest_only = False # 是否只保留最大的连通域(True/False)
|
||||
# ===============================
|
||||
log = log_callback if log_callback is not None else print
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
device = cfg["device"]
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
log(f"使用设备: {device}")
|
||||
|
||||
# 加载模型(所有子区域共享同一个模型)
|
||||
model = build_sam3_image_model().to(device).eval()
|
||||
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
|
||||
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
|
||||
model = build_sam3_image_model().to(device).eval()
|
||||
coarse_processor = Sam3Processor(
|
||||
model, resolution=cfg["coarse_resolution"], device=device
|
||||
)
|
||||
fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)
|
||||
|
||||
# 打开原始影像
|
||||
with rasterio.open(image_path) as src:
|
||||
nodata = src.nodata
|
||||
print(f"原始影像NoData值: {nodata}")
|
||||
height, width = src.height, src.width
|
||||
image_path = cfg["image_path"]
|
||||
mask_output_path = cfg["mask_output_path"]
|
||||
|
||||
# 划分区域
|
||||
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
|
||||
print(f"共划分 {len(regions)} 个子区域")
|
||||
with rasterio.open(image_path) as src:
|
||||
nodata = src.nodata
|
||||
log(f"原始影像NoData值: {nodata}")
|
||||
height, width = src.height, src.width
|
||||
|
||||
# 创建全尺寸掩码数组(CPU内存)
|
||||
full_mask = np.zeros((height, width), dtype=np.uint8)
|
||||
band1 = src.read(1)
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask_full = (
|
||||
np.isclose(band1, float(nodata))
|
||||
if nodata is not None
|
||||
else np.zeros_like(band1, dtype=bool)
|
||||
)
|
||||
else:
|
||||
nodata_mask_full = (
|
||||
(band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool)
|
||||
)
|
||||
|
||||
# 子区域总进度条
|
||||
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
|
||||
for idx, (left, top, right, bottom) in enumerate(regions):
|
||||
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
|
||||
regions = split_into_regions(
|
||||
width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
|
||||
)
|
||||
log(f"共划分 {len(regions)} 个子区域")
|
||||
|
||||
full_mask = np.zeros((height, width), dtype=np.uint8)
|
||||
|
||||
use_tqdm = bool(cfg.get("use_tqdm", True))
|
||||
if use_tqdm:
|
||||
region_iter = tqdm(
|
||||
list(enumerate(regions)),
|
||||
total=len(regions),
|
||||
desc="处理子区域",
|
||||
unit="子区域",
|
||||
)
|
||||
else:
|
||||
region_iter = enumerate(regions)
|
||||
|
||||
for idx, (left, top, right, bottom) in region_iter:
|
||||
if stop_event is not None and stop_event.is_set():
|
||||
log("已停止")
|
||||
return
|
||||
if progress_callback is not None:
|
||||
progress_callback("region", idx + 1, len(regions))
|
||||
|
||||
log(
|
||||
f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
|
||||
)
|
||||
region_w = right - left
|
||||
region_h = bottom - top
|
||||
|
||||
# 读取子区域数据
|
||||
window = Window(left, top, region_w, region_h)
|
||||
sub_bands = src.read(window=window) # shape: (bands, h, w)
|
||||
sub_bands = src.read(window=window)
|
||||
|
||||
# 构建子区域元数据
|
||||
sub_profile = src.profile.copy()
|
||||
if not src.is_tiled:
|
||||
sub_profile.pop('blockxsize', None)
|
||||
sub_profile.pop('blockysize', None)
|
||||
sub_profile['tiled'] = False
|
||||
sub_profile.pop("blockxsize", None)
|
||||
sub_profile.pop("blockysize", None)
|
||||
sub_profile["tiled"] = False
|
||||
else:
|
||||
sub_profile['tiled'] = True
|
||||
sub_profile.update({
|
||||
'height': region_h,
|
||||
'width': region_w,
|
||||
'transform': rasterio.windows.transform(window, src.transform)
|
||||
})
|
||||
sub_profile["tiled"] = True
|
||||
sub_profile.update(
|
||||
{
|
||||
"height": region_h,
|
||||
"width": region_w,
|
||||
"transform": rasterio.windows.transform(window, src.transform),
|
||||
}
|
||||
)
|
||||
|
||||
sub_nodata = nodata_mask_full[top:bottom, left:right]
|
||||
|
||||
# 将子区域数据包装成内存中的 rasterio 数据集
|
||||
with MemoryFile() as memfile:
|
||||
with memfile.open(**sub_profile) as sub_dst:
|
||||
sub_dst.write(sub_bands)
|
||||
@ -311,84 +475,144 @@ with rasterio.open(image_path) as src:
|
||||
stretch_params = compute_stretch_params(sub_bands)
|
||||
|
||||
overview_bands, overview_step = make_overview_bands(
|
||||
sub_bands, max_side=coarse_read_max_side
|
||||
sub_bands, max_side=cfg["overview_max_side"]
|
||||
)
|
||||
overview_img = _bands_to_pil(
|
||||
overview_bands, stretch_params=stretch_params
|
||||
)
|
||||
overview_prob = infer_prompt_prob(
|
||||
coarse_processor, overview_img, cfg["prompt"]
|
||||
)
|
||||
overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
|
||||
overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
|
||||
if overview_prob is None:
|
||||
overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16)
|
||||
overview_mask_small = np.zeros(
|
||||
(overview_img.height, overview_img.width), dtype=bool
|
||||
)
|
||||
else:
|
||||
overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
|
||||
overview_mask_cpu = overview_prob_cpu > coarse_threshold
|
||||
overview_mask_small = (
|
||||
(overview_prob > cfg["overview_threshold"])
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
scale = overview_img.height / float(region_h)
|
||||
band_radius_small = max(int(round(band_radius * scale)), 1)
|
||||
band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
|
||||
if sub_nodata.any():
|
||||
nodata_overview = sub_nodata[::overview_step, ::overview_step]
|
||||
overview_mask_small[nodata_overview] = False
|
||||
|
||||
fine_probs_cpu = refine_tiles(
|
||||
fine_processor,
|
||||
sub_src,
|
||||
prompt,
|
||||
band_small,
|
||||
tile_size=tile_size,
|
||||
overlap=overlap,
|
||||
coarse_probs = coarse_tile_segmentation(
|
||||
processor=coarse_processor,
|
||||
src=sub_src,
|
||||
prompt=cfg["prompt"],
|
||||
tile_size=cfg["coarse_tile_size"],
|
||||
overlap=cfg["coarse_tile_overlap"],
|
||||
downsample_factor=cfg["coarse_downsample_factor"],
|
||||
stretch_params=stretch_params,
|
||||
nodata_mask=sub_nodata if sub_nodata.any() else None,
|
||||
use_tqdm=use_tqdm,
|
||||
)
|
||||
fine_mask = fine_probs_cpu > final_threshold
|
||||
coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
|
||||
|
||||
overview_mask_lr = np.array(
|
||||
Image.fromarray(
|
||||
overview_mask_small.astype(np.uint8), mode="L"
|
||||
).resize(
|
||||
(coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]),
|
||||
resample=Image.NEAREST,
|
||||
)
|
||||
).astype(bool)
|
||||
combined_mask_lr = coarse_mask_lr | overview_mask_lr
|
||||
|
||||
band_radius_lr = max(
|
||||
int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))),
|
||||
1,
|
||||
)
|
||||
band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
|
||||
|
||||
fine_probs = refine_tiles(
|
||||
processor=fine_processor,
|
||||
src=sub_src,
|
||||
prompt=cfg["prompt"],
|
||||
band_lr=band_lr,
|
||||
downsample_factor=cfg["coarse_downsample_factor"],
|
||||
tile_size=cfg["fine_tile_size"],
|
||||
overlap=cfg["fine_overlap"],
|
||||
stretch_params=stretch_params,
|
||||
use_tqdm=use_tqdm,
|
||||
)
|
||||
fine_mask = fine_probs > cfg["final_threshold"]
|
||||
|
||||
coarse_mask_full = np.array(
|
||||
Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize(
|
||||
Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize(
|
||||
(region_w, region_h), resample=Image.NEAREST
|
||||
)
|
||||
).astype(bool)
|
||||
sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8)
|
||||
overview_mask_full = np.array(
|
||||
Image.fromarray(
|
||||
overview_mask_small.astype(np.uint8), mode="L"
|
||||
).resize((region_w, region_h), resample=Image.NEAREST)
|
||||
).astype(bool)
|
||||
sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype(
|
||||
np.uint8
|
||||
)
|
||||
|
||||
# 合并到全尺寸掩码(重叠区域取最大值)
|
||||
full_mask[top:bottom, left:right] = np.maximum(
|
||||
full_mask[top:bottom, left:right], sub_mask_np
|
||||
)
|
||||
|
||||
# 更新子区域进度条
|
||||
region_pbar.update(1)
|
||||
log("\n后处理:填充内部NoData空洞...")
|
||||
if nodata is not None:
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask = np.isclose(band1, float(nodata))
|
||||
else:
|
||||
nodata_mask = band1 == nodata
|
||||
|
||||
# ========== 后处理 ==========
|
||||
print("\n后处理:填充内部NoData空洞...")
|
||||
if nodata is not None:
|
||||
band1 = src.read(1)
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask = np.isclose(band1, float(nodata))
|
||||
else:
|
||||
nodata_mask = (band1 == nodata)
|
||||
labeled_mask, _ = ndimage.label(nodata_mask)
|
||||
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
|
||||
boundary_mask[0, :] = True
|
||||
boundary_mask[-1, :] = True
|
||||
boundary_mask[:, 0] = True
|
||||
boundary_mask[:, -1] = True
|
||||
|
||||
labeled_mask, num_features = ndimage.label(nodata_mask)
|
||||
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
|
||||
boundary_mask[0, :] = True
|
||||
boundary_mask[-1, :] = True
|
||||
boundary_mask[:, 0] = True
|
||||
boundary_mask[:, -1] = True
|
||||
boundary_labels = set(labeled_mask[boundary_mask])
|
||||
internal_nodata_mask = np.isin(
|
||||
labeled_mask, list(boundary_labels), invert=True
|
||||
) & nodata_mask
|
||||
|
||||
boundary_labels = set(labeled_mask[boundary_mask])
|
||||
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
|
||||
if internal_nodata_mask.any():
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
internal_dilated = ndimage.binary_dilation(
|
||||
internal_nodata_mask, structure=struct, iterations=7
|
||||
)
|
||||
full_mask[internal_dilated] = 1
|
||||
log(
|
||||
f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
|
||||
)
|
||||
else:
|
||||
log(" 无内部NoData区域")
|
||||
|
||||
if internal_nodata_mask.any():
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
|
||||
full_mask[internal_dilated] = 1
|
||||
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
|
||||
else:
|
||||
print(" 无内部NoData区域")
|
||||
log("后处理:面积过滤...")
|
||||
if cfg["min_area"] is not None or cfg["keep_largest_only"]:
|
||||
original_count = np.sum(full_mask)
|
||||
full_mask = filter_by_area(
|
||||
full_mask,
|
||||
min_area=cfg["min_area"],
|
||||
keep_largest_only=cfg["keep_largest_only"],
|
||||
)
|
||||
filtered_count = np.sum(full_mask)
|
||||
log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
|
||||
|
||||
print("后处理:面积过滤...")
|
||||
if min_area is not None or keep_largest_only:
|
||||
original_count = np.sum(full_mask)
|
||||
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
|
||||
filtered_count = np.sum(full_mask)
|
||||
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
|
||||
profile = src.profile.copy()
|
||||
profile.update(count=1, dtype="uint8", compress="lzw")
|
||||
with rasterio.open(mask_output_path, "w", **profile) as dst:
|
||||
dst.write(full_mask, 1)
|
||||
|
||||
# ========== 保存结果 ==========
|
||||
profile = src.profile.copy()
|
||||
profile.update(count=1, dtype="uint8", compress='lzw')
|
||||
with rasterio.open(mask_output_path, "w", **profile) as dst:
|
||||
dst.write(full_mask, 1)
|
||||
log(f"分割完成,结果已保存至:{mask_output_path}")
|
||||
|
||||
print(f"分割完成,结果已保存至:{mask_output_path}")
|
||||
|
||||
def main():
|
||||
matplotlib.use("TkAgg")
|
||||
run_segmentation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user