This commit is contained in:
2026-03-10 17:29:24 +08:00
parent 64692d8382
commit b8b6c6227d
4 changed files with 993 additions and 147 deletions

117
README_water_GUI.md Normal file
View File

@ -0,0 +1,117 @@
# water_GUI.py 使用说明SAM3 水体分割 GUI
本说明文档对应脚本:[water_GUI.py](file:///e:/code/sam3-main/sam3-main/water_GUI.py)。
该脚本是 **纯 GUI 启动器**:界面负责参数编辑、提示说明、日志与进度展示;真正的分割计算由 [water_V5.py](file:///e:/code/sam3-main/sam3-main/water_V5.py) 提供的 `run_segmentation()` 执行,并在点击“运行”时才会被懒加载导入。
## 1. 功能概览
- 选择输入 GeoTIFF待分割遥感影像与输出掩膜 GeoTIFF 路径
- 编辑分割参数Overview/Coarse/Fine/Region/Post
- 鼠标悬停在参数输入框上显示详细提示tooltip
- 后台线程执行分割GUI 不会卡死
- 右侧日志窗口实时输出运行信息
- 进度条显示子区域处理进度
- 支持“停止”(会在子区域边界处响应)
## 2. 运行方式
在项目目录下运行:
```bash
python e:\code\sam3-main\sam3-main\water_GUI.py
```
运行后:
- 点击“浏览”选择输入影像(`.tif/.tiff`
- 点击“浏览”选择输出掩膜路径(建议 `.tif`
- 调整参数(可保持默认)
- 点击“运行”
## 3. 依赖与环境
GUI 本身依赖:
- Python 标准库 `tkinter`
执行分割时会懒加载并依赖(来自 water_V5.py
- torch建议 CUDA 可用)
- rasterio
- numpy
- scipy
- pillow
- tqdm
- 以及 sam3 模型代码与权重(由项目现有加载逻辑处理)
## 4. 参数说明(界面分组)
### 4.1 Overview 兜底
用于减少“中心水域漏检”。Overview 会对每个子区域做一次低分辨率全局预测,并在最终输出中与 coarse/fine 结果取并集。
- `overview_max_side`:兜底预测的最大边长。越大越准越慢。
- `overview_threshold`:兜底阈值。偏低可减少漏检,但误检会增加。
### 4.2 Coarse 粗分割
粗分割用于提高召回、生成精修的边缘带band并与 overview 一起决定精修区域范围。
- `coarse_threshold`:粗分割阈值。偏低提高召回,偏高减少误检。
- `coarse_tile_size`:粗分割分块大小(原图像素)。
- `coarse_tile_overlap`:粗分割块重叠像素。
- `coarse_downsample_factor`粗分割输出降采样比例。2 更细更慢4 折中8 更快更粗。
- `band_radius`:边缘带半径(原图像素)。越大精修范围越大,速度越慢。
### 4.3 Fine 精修
只对边缘带覆盖到的区域进行精细分割,补足边界细节。
- `final_threshold`:最终阈值。偏低更易连通,偏高更干净。
- `fine_tile_size`:精修分块大小(原图像素)。一般 1536 更稳2048 更快但更吃显存/更慢。
- `fine_overlap`:精修块重叠像素。越大接缝越少但更慢。
### 4.4 Region 子区域
将整图划分成多个有重叠的子区域,降低内存峰值。
- `num_splits_y`:纵向切分数。
- `num_splits_x`:横向切分数。
- `region_overlap`:子区域重叠像素,减少子区域边界断裂。
### 4.5 Post 后处理
- `min_area`:移除小碎片的最小连通域面积(像素)。
- `keep_largest_only`:只保留最大连通域(适合只要最大水体)。
## 5. 输出说明
输出文件为单通道 `uint8` 掩膜 GeoTIFF
- 0非水体
- 1水体
输出文件的空间参考CRS/transform来自输入影像的 profile。
## 6. 使用建议16GB 显存)
优先按以下顺序调参:
- 先把 `fine_tile_size` 从 2048 降到 1536更稳
- `fine_overlap` 从 256 降到 128更快可能稍有拼接缝
- `coarse_downsample_factor` 从 4 改 2更细更慢但中心/边缘召回更稳)
- `overview_threshold` 从 0.4 降到 0.35(召回更强,误检更可能增加)
## 7. 常见问题
### 7.1 点击“运行”后报 `ModuleNotFoundError: water_V5`
确保 `water_GUI.py``water_V5.py` 在同一目录:
- `e:\code\sam3-main\sam3-main\water_GUI.py`
- `e:\code\sam3-main\sam3-main\water_V5.py`
并从该目录执行运行命令,或将该目录加入 Python 路径。
### 7.2 界面正常但运行很慢
这通常是正常的:粗分割/精修会对大量瓦片做模型推理。优先调小 `fine_tile_size``fine_overlap`,并提高 `coarse_downsample_factor`
### 7.3 “停止”不立即生效
停止信号会在子区域边界处检查,因此可能需要等待当前子区域处理完成后才退出。

155
tif_caijain.py Normal file
View File

@ -0,0 +1,155 @@
"""
使用二值掩膜 TIF 文件值为1的区域需要去除对数据 TIF 文件进行掩膜。
输入:
data_tif: 要掩膜的数据文件路径
mask_tif: 二值掩膜文件路径值为1表示需要去除的区域
输出:
掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData
要求:
两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐),
否则程序将报错或行为未定义。
"""
import argparse
import numpy as np
import rasterio
from rasterio.windows import Window
import logging
from pathlib import Path
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def mask_data_by_binary_mask(
data_path,
mask_path,
output_path=None,
remove_value=1,
nodata_value=None,
tile_size=4096,
):
"""使用二值掩膜 TIF 对数据 TIF 进行掩膜。
将数据 TIF 中对应掩膜值等于 remove_value 的像素设为 NoData其余保留。
性能建议:
- 若数据源是 tiled GeoTIFF可将 tile_size 设为 0 以按源文件块窗口遍历(通常更快)。
"""
data_path = Path(data_path)
mask_path = Path(mask_path)
if output_path is None:
output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
else:
output_path = Path(output_path)
logger.info(f"数据文件: {data_path.name}")
logger.info(f"掩膜文件: {mask_path.name}")
logger.info(f"去除掩膜值: {remove_value}")
with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
if src_data.crs != src_mask.crs:
raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
if src_data.transform != src_mask.transform:
logger.warning(
"数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
)
if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")
# 确定输出 NoData 值(并尽量匹配数据 dtype避免隐式类型转换带来的开销
if nodata_value is None:
nodata_value = src_data.nodata if src_data.nodata is not None else 0
try:
nodata_value_cast = np.array(
nodata_value, dtype=src_data.dtypes[0]
).item()
except Exception:
nodata_value_cast = nodata_value
# 创建输出元数据:基于数据源的元数据,更新 nodata 和压缩选项
out_meta = src_data.meta.copy()
out_meta.update(
{
"nodata": nodata_value,
"compress": (
src_data.compression.value if src_data.compression else "lzw"
),
"tiled": src_data.is_tiled,
}
)
if src_data.is_tiled:
out_meta.update(
{
"blockxsize": src_data.block_shapes[0][0],
"blockysize": src_data.block_shapes[0][1],
}
)
# 创建输出文件
with rasterio.open(output_path, "w", **out_meta) as dst:
width, height = src_data.width, src_data.height
if tile_size is None or tile_size <= 0:
windows = [w for _, w in src_data.block_windows(1)]
else:
stride = int(tile_size)
windows = [
Window(i, j, min(stride, width - i), min(stride, height - j))
for i in range(0, width, stride)
for j in range(0, height, stride)
]
with tqdm(total=len(windows), desc="处理瓦片", unit="") as pbar:
for window in windows:
# 读取相同位置的掩膜瓦片(假设完全对齐)
mask = src_mask.read(1, window=window)
remove_mask = mask == remove_value
# 读取数据瓦片
data = src_data.read(window=window) # shape: (bands, h, w)
if remove_mask.any():
for band_idx in range(data.shape[0]):
np.putmask(
data[band_idx], remove_mask, nodata_value_cast
)
dst.write(data, window=window)
pbar.update(1)
logger.info(f"处理完成,输出文件:{output_path}")
return str(output_path)
def main():
parser = argparse.ArgumentParser(
description="使用二值掩膜 TIF值为1的区域对数据 TIF 进行掩膜,将对应位置设为 NoData。"
)
parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径值为1表示需要去除的区域")
parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
parser.add_argument("-r", "--remove_value", type=int, default=1,
help="掩膜中要去除的值默认为1")
parser.add_argument("-n", "--nodata", type=float,
help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)")
parser.add_argument(
"-t",
"--tile_size",
type=int,
default=4096,
help="分块大小像素默认4096设为0则按源文件块窗口遍历tiled 文件通常更快)",
)
args = parser.parse_args()
mask_data_by_binary_mask(
args.data_tif, args.mask_tif, args.output,
args.remove_value, args.nodata, args.tile_size
)
if __name__ == "__main__":
exit(main())

350
water_GUI.py Normal file
View File

@ -0,0 +1,350 @@
"""
water_V5 的图形界面启动器(纯 GUI 文件)。
特点:
- 参数可配置
- 悬停显示提示
- 后台线程运行分割,界面不假死
"""
import threading
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
DEFAULTS = {
"image_path": "",
"mask_output_path": "",
"prompt": "water body",
"overview_max_side": 1400,
"overview_threshold": 0.4,
"coarse_threshold": 0.5,
"coarse_tile_size": 4096,
"coarse_tile_overlap": 256,
"coarse_downsample_factor": 4,
"band_radius": 64,
"final_threshold": 0.5,
"fine_tile_size": 2048,
"fine_overlap": 256,
"num_splits_y": 2,
"num_splits_x": 3,
"region_overlap": 256,
"min_area": 5000,
"keep_largest_only": False,
}
def _lazy_import_backend():
import water_V5
return water_V5
class Tooltip:
def __init__(self, widget, text):
self.widget = widget
self.text = text
self.tip = None
widget.bind("<Enter>", self._show)
widget.bind("<Leave>", self._hide)
def _show(self, _event=None):
if self.tip is not None or not self.text:
return
x = self.widget.winfo_rootx() + 10
y = self.widget.winfo_rooty() + self.widget.winfo_height() + 5
self.tip = tk.Toplevel(self.widget)
self.tip.wm_overrideredirect(True)
self.tip.wm_geometry(f"+{x}+{y}")
label = tk.Label(
self.tip,
text=self.text,
justify="left",
background="#ffffe0",
relief="solid",
borderwidth=1,
font=("Segoe UI", 9),
)
label.pack(ipadx=6, ipady=4)
def _hide(self, _event=None):
if self.tip is not None:
self.tip.destroy()
self.tip = None
class WaterSegGUI(tk.Tk):
def __init__(self):
super().__init__()
self.title("SAM3 水体分割 GUI")
self.geometry("980x720")
self.stop_event = threading.Event()
self.worker = None
self.vars = {}
self._build_ui()
def _build_ui(self):
container = ttk.Frame(self)
container.pack(fill="both", expand=True, padx=10, pady=10)
top = ttk.Frame(container)
top.pack(fill="x")
self._add_path_row(
top,
"image_path",
"输入影像 (TIF)",
DEFAULTS["image_path"],
"选择要分割的遥感影像 GeoTIFF。",
is_save=False,
)
self._add_path_row(
top,
"mask_output_path",
"输出掩膜 (TIF)",
DEFAULTS["mask_output_path"],
"输出单通道 uint8 掩膜0/1",
is_save=True,
)
row = ttk.Frame(top)
row.pack(fill="x", pady=(6, 0))
ttk.Label(row, text="prompt").pack(side="left")
prompt_var = tk.StringVar(value=DEFAULTS["prompt"])
self.vars["prompt"] = prompt_var
prompt_box = ttk.Combobox(
row,
textvariable=prompt_var,
values=["water body", "water", "river", "lake", "reservoir"],
)
prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0))
Tooltip(prompt_box, "文本提示。建议优先尝试water / river / lake。")
panes = ttk.Panedwindow(container, orient="horizontal")
panes.pack(fill="both", expand=True, pady=(10, 0))
left = ttk.Frame(panes)
right = ttk.Frame(panes)
panes.add(left, weight=3)
panes.add(right, weight=2)
params = ttk.Frame(left)
params.pack(fill="both", expand=True)
self._add_group(
params,
"Overview 兜底",
[
("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"),
("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"),
],
)
self._add_group(
params,
"Coarse 粗分割",
[
("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"),
("coarse_tile_size", int, "粗分割分块大小(原图像素)。"),
("coarse_tile_overlap", int, "粗分割块重叠像素。"),
("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢4 折中8 更快更粗。"),
("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"),
],
)
self._add_group(
params,
"Fine 精修",
[
("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"),
("fine_tile_size", int, "精修分块大小(原图像素)。"),
("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"),
],
)
self._add_group(
params,
"Region 子区域",
[
("num_splits_y", int, "纵向切分数。"),
("num_splits_x", int, "横向切分数。"),
("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"),
],
)
self._add_group(
params,
"Post 后处理",
[
("min_area", int, "移除小碎片的最小连通域面积(像素)。"),
("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"),
],
)
ctrl = ttk.Frame(right)
ctrl.pack(fill="x")
self.progress = ttk.Progressbar(ctrl, maximum=100)
self.progress.pack(fill="x")
self.status = tk.StringVar(value="就绪")
ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0))
btn_row = ttk.Frame(ctrl)
btn_row.pack(fill="x", pady=(8, 0))
self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run)
self.run_btn.pack(side="left")
self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled")
self.stop_btn.pack(side="left", padx=(8, 0))
log_frame = ttk.Frame(right)
log_frame.pack(fill="both", expand=True, pady=(10, 0))
ttk.Label(log_frame, text="日志").pack(anchor="w")
self.log_text = tk.Text(log_frame, height=20, wrap="word")
self.log_text.pack(side="left", fill="both", expand=True)
scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview)
scroll.pack(side="right", fill="y")
self.log_text.configure(yscrollcommand=scroll.set)
def _add_path_row(self, parent, key, label, default, tip, is_save):
row = ttk.Frame(parent)
row.pack(fill="x", pady=(0, 6))
ttk.Label(row, text=label).pack(side="left")
var = tk.StringVar(value=default)
self.vars[key] = var
entry = ttk.Entry(row, textvariable=var)
entry.pack(side="left", fill="x", expand=True, padx=(10, 0))
Tooltip(entry, tip)
btn = ttk.Button(
row,
text="浏览",
command=lambda: self._browse_path(var, is_save=is_save),
)
btn.pack(side="left", padx=(8, 0))
def _add_group(self, parent, title, items):
frame = ttk.LabelFrame(parent, text=title)
frame.pack(fill="x", pady=(0, 10))
for r, (key, typ, tip) in enumerate(items):
ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4)
if typ is bool:
var = tk.BooleanVar(value=bool(DEFAULTS.get(key)))
self.vars[key] = var
w = ttk.Checkbutton(frame, variable=var)
w.grid(row=r, column=1, sticky="w", pady=4)
Tooltip(w, tip)
else:
var = tk.StringVar(value=str(DEFAULTS.get(key, "")))
self.vars[key] = var
entry = ttk.Entry(frame, textvariable=var, width=14)
entry.grid(row=r, column=1, sticky="w", pady=4)
Tooltip(entry, tip)
frame.grid_columnconfigure(1, weight=1)
def _browse_path(self, var, is_save):
if is_save:
p = filedialog.asksaveasfilename(
title="选择输出文件",
defaultextension=".tif",
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
)
else:
p = filedialog.askopenfilename(
title="选择输入文件",
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
)
if p:
var.set(p)
def _append_log(self, msg):
self.log_text.insert("end", msg + "\n")
self.log_text.see("end")
def _log(self, msg):
self.after(0, lambda: self._append_log(str(msg)))
def _progress(self, stage, cur, total):
def _update():
if total > 0:
self.progress["value"] = (cur / total) * 100.0
self.status.set(f"{stage}: {cur}/{total}")
self.after(0, _update)
def _get_config(self):
cfg = {}
for k, v in self.vars.items():
if isinstance(v, tk.BooleanVar):
cfg[k] = bool(v.get())
else:
cfg[k] = v.get()
int_keys = [
"overview_max_side",
"band_radius",
"fine_tile_size",
"fine_overlap",
"coarse_tile_size",
"coarse_tile_overlap",
"coarse_downsample_factor",
"num_splits_y",
"num_splits_x",
"region_overlap",
"min_area",
]
float_keys = ["overview_threshold", "coarse_threshold", "final_threshold"]
for k in int_keys:
if k in cfg and cfg[k] != "":
cfg[k] = int(float(cfg[k]))
for k in float_keys:
if k in cfg and cfg[k] != "":
cfg[k] = float(cfg[k])
cfg["use_tqdm"] = False
cfg["device"] = None
return cfg
def _on_run(self):
if self.worker is not None and self.worker.is_alive():
return
cfg = self._get_config()
if not cfg.get("image_path") or not cfg.get("mask_output_path"):
messagebox.showerror("错误", "请先选择输入影像与输出路径。")
return
self.stop_event.clear()
self.run_btn.configure(state="disabled")
self.stop_btn.configure(state="normal")
self.progress["value"] = 0
self.status.set("启动中...")
self._append_log("开始运行...")
def _work():
try:
water_v5 = _lazy_import_backend()
water_v5.run_segmentation(
config=cfg,
progress_callback=self._progress,
log_callback=self._log,
stop_event=self.stop_event,
)
except Exception as e:
self._log(f"错误: {e}")
finally:
def _done():
self.run_btn.configure(state="normal")
self.stop_btn.configure(state="disabled")
self.status.set("完成")
self.after(0, _done)
self.worker = threading.Thread(target=_work, daemon=True)
self.worker.start()
def _on_stop(self):
self.stop_event.set()
self.status.set("停止中...")
def main():
WaterSegGUI().mainloop()
if __name__ == "__main__":
main()

View File

@ -1,14 +1,17 @@
"""
使用SAM3模型分割超大遥感影像中的水体分区域处理,优化内存
使用SAM3模型分割超大遥感影像中的水体三层分割策略
流程:
1. 将影像划分为若干有重叠的子区域。
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割
3. 合并所有子区域的掩码(重叠区域取最大值)。
4. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域
5. 保存结果
1. 将影像划分为若干有重叠的子区域(第一层)
2. 对每个子区域进行中分辨率粗分割以4096为块大小滑动窗口推理得到粗掩码第二层
3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。
4. 合并所有子区域的掩码(重叠区域取最大值)
5. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域
6. 保存结果。
优点:避免在全尺寸概率图上进行GPU操作显著降低显存占用。
优点:
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
- 精修只处理边缘带,节省计算资源。
"""
import torch
@ -22,6 +25,7 @@ from rasterio.windows import Window
from rasterio.io import MemoryFile
from tqdm import tqdm
from scipy import ndimage
import math
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000):
return vmin, vmax
def make_overview_bands(bands, max_side):
_, h, w = bands.shape
step = int(math.ceil(max(h, w) / float(max_side)))
step = max(step, 1)
return bands[:, ::step, ::step], step
def _to_uint8(arr, vmin=None, vmax=None):
if arr.dtype == np.uint8:
return arr
@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None):
return Image.fromarray(rgb, mode="RGB")
def _read_pil_window(src, window, stretch_params=None):
bands = src.read(window=window, boundless=True, fill_value=0)
bands = src.read(window=window)
return _bands_to_pil(bands, stretch_params=stretch_params)
def infer_prompt_prob(processor, image, prompt):
@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt):
prob = combine_masks_logits(state["masks_logits"])
return prob
def make_overview_bands(sub_bands, max_side):
_, h, w = sub_bands.shape
step = int(np.ceil(max(h, w) / float(max_side)))
step = max(step, 1)
return sub_bands[:, ::step, ::step], step
# ---------- 新增:中分辨率粗分割(分块推理) ----------
def _downsample_any(mask, factor):
if factor <= 1:
return mask
h, w = mask.shape
pad_h = (-h) % factor
pad_w = (-w) % factor
if pad_h or pad_w:
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
h, w = mask.shape
h2 = h // factor
w2 = w // factor
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
# ---------- 精修分块 ----------
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None):
def _build_band_ndimage(mask, radius):
if radius <= 0:
return np.zeros_like(mask, dtype=bool)
struct = ndimage.generate_binary_structure(2, 2)
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
return np.logical_xor(dil, ero)
def coarse_tile_segmentation(
processor,
src,
prompt,
tile_size,
overlap,
downsample_factor=4,
stretch_params=None,
nodata_mask=None,
use_tqdm=True,
):
"""
边缘带进行精修分割
band_coarse_cpu: 2D numpy bool指示需要精修的区域低分辨率
子区域进行分块粗分割返回低分辨率概率图float16
Args:
processor: 粗处理器(输入分辨率 coarse_resolution
src: rasterio 数据集(子区域)
prompt: 文本提示
tile_size: 粗分割块大小(原图像素)
overlap: 块重叠像素
stretch_params: 拉伸参数
nodata_mask: 子区域的 NoData 掩码bool, 与子区域同尺寸),用于过滤无效区域
Returns:
coarse_probs_lr: 子区域粗概率图 (H_lr,W_lr) float16
"""
height, width = src.height, src.width
full_probs = np.zeros((height, width), dtype=np.float16)
band_h, band_w = band_coarse_cpu.shape
scale_y = band_h / float(height)
scale_x = band_w / float(width)
height_lr = int(math.ceil(height / float(downsample_factor)))
width_lr = int(math.ceil(width / float(downsample_factor)))
full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)
stride = max(tile_size - overlap, 1)
num_tiles_y = (height + stride - 1) // stride
num_tiles_x = (width + stride - 1) // stride
total_tiles = num_tiles_y * num_tiles_x
# 内部进度条leave=True 以便保留历史记录
with tqdm(total=total_tiles, desc="精修分块", unit="", leave=True) as pbar:
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
c_top = int(top * scale_y)
c_left = int(left * scale_x)
c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
if use_tqdm:
iterator = tqdm(
tile_slices(height, width, tile_size, overlap),
total=total_tiles,
desc="粗分割分块",
unit="",
leave=True,
)
else:
iterator = tile_slices(height, width, tile_size, overlap)
if not band_coarse_cpu[c_top:c_bottom, c_left:c_right].any():
pbar.update(1)
for top, left, bottom, right in iterator:
# 如果该块完全在 NoData 区域内,则跳过(加速)
if nodata_mask is not None:
block_nodata = nodata_mask[top:bottom, left:right]
if block_nodata.all():
continue
window = Window(left, top, right - left, bottom - top)
crop = _read_pil_window(src, window, stretch_params=stretch_params)
tile_prob = infer_prompt_prob(processor, crop, prompt)
if tile_prob is None:
continue
h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
w_lr = int(math.ceil((right - left) / float(downsample_factor)))
if tile_prob.shape[-2:] != (h_lr, w_lr):
tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
top_lr = top // downsample_factor
left_lr = left // downsample_factor
bottom_lr = top_lr + h_lr
right_lr = left_lr + w_lr
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum(
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu
)
if nodata_mask is not None:
nodata_lr = _downsample_any(nodata_mask, downsample_factor)
full_probs_lr[nodata_lr] = 0.0
return full_probs_lr
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
def refine_tiles(
processor,
src,
prompt,
band_lr,
downsample_factor,
tile_size,
overlap,
stretch_params=None,
use_tqdm=True,
):
"""
对边缘带进行精修分割
band_lr: 2D numpy bool指示需要精修的区域低分辨率
"""
height, width = src.height, src.width
full_probs = np.zeros((height, width), dtype=np.float16)
stride = max(tile_size - overlap, 1)
num_tiles_y = (height + stride - 1) // stride
num_tiles_x = (width + stride - 1) // stride
total_tiles = num_tiles_y * num_tiles_x
if use_tqdm:
iterator = tqdm(
tile_slices(height, width, tile_size, overlap),
total=total_tiles,
desc="精修分块",
unit="",
leave=True,
)
else:
iterator = tile_slices(height, width, tile_size, overlap)
for top, left, bottom, right in iterator:
c_top = top // downsample_factor
c_left = left // downsample_factor
c_bottom = int(math.ceil(bottom / float(downsample_factor)))
c_right = int(math.ceil(right / float(downsample_factor)))
c_bottom = max(c_bottom, c_top + 1)
c_right = max(c_right, c_left + 1)
if not band_lr[c_top:c_bottom, c_left:c_right].any():
continue
window = Window(left, top, right - left, bottom - top)
crop = _read_pil_window(src, window, stretch_params=stretch_params)
tile_prob = infer_prompt_prob(processor, crop, prompt)
if tile_prob is None:
pbar.update(1)
continue
if tile_prob.shape[-2:] != (bottom - top, right - left):
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
tile_prob_cpu = tile_prob.detach().float().cpu().numpy().astype(np.float16)
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
full_probs[top:bottom, left:right] = np.maximum(
full_probs[top:bottom, left:right], tile_prob_cpu
)
pbar.update(1)
return full_probs
@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25
regions.append((left, top, right, bottom))
return regions
# ---------- 主程序 ----------
matplotlib.use("TkAgg")
DEFAULT_CONFIG = {
"image_path": r"E:\is2\guidingsahn\result.tif",
"mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
"prompt": "water body",
"overview_max_side": 1400,
"overview_threshold": 0.4,
"coarse_resolution": 1008,
"fine_resolution": 1008,
"coarse_threshold": 0.5,
"final_threshold": 0.5,
"band_radius": 64,
"fine_tile_size": 2048,
"fine_overlap": 256,
"coarse_tile_size": 4096,
"coarse_tile_overlap": 256,
"coarse_downsample_factor": 4,
"num_splits_y": 2,
"num_splits_x": 3,
"region_overlap": 256,
"min_area": 5000,
"keep_largest_only": False,
"use_tqdm": True,
"device": None,
}
# 参数设置
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
prompt = "water body"
coarse_read_max_side = 1200
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 1536
overlap = 128
# ========== 分区域参数 ==========
num_splits_y = 2 # 纵向切分数
num_splits_x = 3 # 横向切分数(共 2x3=6 份)
region_overlap = 256 # 子区域之间的重叠像素数
# ===============================
def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
cfg = dict(DEFAULT_CONFIG)
if config:
cfg.update(config)
# ========== 后处理参数 ==========
min_area = 1000 # 最小面积阈值像素小于此值的连通域将被移除设为0或None表示不进行面积过滤
keep_largest_only = False # 是否只保留最大的连通域True/False
# ===============================
log = log_callback if log_callback is not None else print
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
device = cfg["device"]
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
log(f"使用设备: {device}")
# 加载模型(所有子区域共享同一个模型)
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
model = build_sam3_image_model().to(device).eval()
coarse_processor = Sam3Processor(
model, resolution=cfg["coarse_resolution"], device=device
)
fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)
# 打开原始影像
with rasterio.open(image_path) as src:
image_path = cfg["image_path"]
mask_output_path = cfg["mask_output_path"]
with rasterio.open(image_path) as src:
nodata = src.nodata
print(f"原始影像NoData值: {nodata}")
log(f"原始影像NoData值: {nodata}")
height, width = src.height, src.width
# 划分区域
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
print(f"共划分 {len(regions)} 个子区域")
band1 = src.read(1)
if np.issubdtype(band1.dtype, np.floating):
nodata_mask_full = (
np.isclose(band1, float(nodata))
if nodata is not None
else np.zeros_like(band1, dtype=bool)
)
else:
nodata_mask_full = (
(band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool)
)
regions = split_into_regions(
width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
)
log(f"共划分 {len(regions)} 个子区域")
# 创建全尺寸掩码数组CPU内存
full_mask = np.zeros((height, width), dtype=np.uint8)
# 子区域总进度条
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
for idx, (left, top, right, bottom) in enumerate(regions):
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
use_tqdm = bool(cfg.get("use_tqdm", True))
if use_tqdm:
region_iter = tqdm(
list(enumerate(regions)),
total=len(regions),
desc="处理子区域",
unit="子区域",
)
else:
region_iter = enumerate(regions)
for idx, (left, top, right, bottom) in region_iter:
if stop_event is not None and stop_event.is_set():
log("已停止")
return
if progress_callback is not None:
progress_callback("region", idx + 1, len(regions))
log(
f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
)
region_w = right - left
region_h = bottom - top
# 读取子区域数据
window = Window(left, top, region_w, region_h)
sub_bands = src.read(window=window) # shape: (bands, h, w)
sub_bands = src.read(window=window)
# 构建子区域元数据
sub_profile = src.profile.copy()
if not src.is_tiled:
sub_profile.pop('blockxsize', None)
sub_profile.pop('blockysize', None)
sub_profile['tiled'] = False
sub_profile.pop("blockxsize", None)
sub_profile.pop("blockysize", None)
sub_profile["tiled"] = False
else:
sub_profile['tiled'] = True
sub_profile.update({
'height': region_h,
'width': region_w,
'transform': rasterio.windows.transform(window, src.transform)
})
sub_profile["tiled"] = True
sub_profile.update(
{
"height": region_h,
"width": region_w,
"transform": rasterio.windows.transform(window, src.transform),
}
)
sub_nodata = nodata_mask_full[top:bottom, left:right]
# 将子区域数据包装成内存中的 rasterio 数据集
with MemoryFile() as memfile:
with memfile.open(**sub_profile) as sub_dst:
sub_dst.write(sub_bands)
@ -311,56 +475,98 @@ with rasterio.open(image_path) as src:
stretch_params = compute_stretch_params(sub_bands)
overview_bands, overview_step = make_overview_bands(
sub_bands, max_side=coarse_read_max_side
sub_bands, max_side=cfg["overview_max_side"]
)
overview_img = _bands_to_pil(
overview_bands, stretch_params=stretch_params
)
overview_prob = infer_prompt_prob(
coarse_processor, overview_img, cfg["prompt"]
)
overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
if overview_prob is None:
overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16)
else:
overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
overview_mask_cpu = overview_prob_cpu > coarse_threshold
scale = overview_img.height / float(region_h)
band_radius_small = max(int(round(band_radius * scale)), 1)
band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
fine_probs_cpu = refine_tiles(
fine_processor,
sub_src,
prompt,
band_small,
tile_size=tile_size,
overlap=overlap,
stretch_params=stretch_params,
overview_mask_small = np.zeros(
(overview_img.height, overview_img.width), dtype=bool
)
fine_mask = fine_probs_cpu > final_threshold
else:
overview_mask_small = (
(overview_prob > cfg["overview_threshold"])
.detach()
.cpu()
.numpy()
)
if sub_nodata.any():
nodata_overview = sub_nodata[::overview_step, ::overview_step]
overview_mask_small[nodata_overview] = False
coarse_probs = coarse_tile_segmentation(
processor=coarse_processor,
src=sub_src,
prompt=cfg["prompt"],
tile_size=cfg["coarse_tile_size"],
overlap=cfg["coarse_tile_overlap"],
downsample_factor=cfg["coarse_downsample_factor"],
stretch_params=stretch_params,
nodata_mask=sub_nodata if sub_nodata.any() else None,
use_tqdm=use_tqdm,
)
coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
overview_mask_lr = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize(
(coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]),
resample=Image.NEAREST,
)
).astype(bool)
combined_mask_lr = coarse_mask_lr | overview_mask_lr
band_radius_lr = max(
int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))),
1,
)
band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
fine_probs = refine_tiles(
processor=fine_processor,
src=sub_src,
prompt=cfg["prompt"],
band_lr=band_lr,
downsample_factor=cfg["coarse_downsample_factor"],
tile_size=cfg["fine_tile_size"],
overlap=cfg["fine_overlap"],
stretch_params=stretch_params,
use_tqdm=use_tqdm,
)
fine_mask = fine_probs > cfg["final_threshold"]
coarse_mask_full = np.array(
Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize(
Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize(
(region_w, region_h), resample=Image.NEAREST
)
).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8)
overview_mask_full = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize((region_w, region_h), resample=Image.NEAREST)
).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype(
np.uint8
)
# 合并到全尺寸掩码(重叠区域取最大值)
full_mask[top:bottom, left:right] = np.maximum(
full_mask[top:bottom, left:right], sub_mask_np
)
# 更新子区域进度条
region_pbar.update(1)
# ========== 后处理 ==========
print("\n后处理填充内部NoData空洞...")
log("\n后处理填充内部NoData空洞...")
if nodata is not None:
band1 = src.read(1)
if np.issubdtype(band1.dtype, np.floating):
nodata_mask = np.isclose(band1, float(nodata))
else:
nodata_mask = (band1 == nodata)
nodata_mask = band1 == nodata
labeled_mask, num_features = ndimage.label(nodata_mask)
labeled_mask, _ = ndimage.label(nodata_mask)
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
boundary_mask[0, :] = True
boundary_mask[-1, :] = True
@ -368,27 +574,45 @@ with rasterio.open(image_path) as src:
boundary_mask[:, -1] = True
boundary_labels = set(labeled_mask[boundary_mask])
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
internal_nodata_mask = np.isin(
labeled_mask, list(boundary_labels), invert=True
) & nodata_mask
if internal_nodata_mask.any():
struct = ndimage.generate_binary_structure(2, 2)
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
internal_dilated = ndimage.binary_dilation(
internal_nodata_mask, structure=struct, iterations=7
)
full_mask[internal_dilated] = 1
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
log(
f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
)
else:
print(" 无内部NoData区域")
log(" 无内部NoData区域")
print("后处理:面积过滤...")
if min_area is not None or keep_largest_only:
log("后处理:面积过滤...")
if cfg["min_area"] is not None or cfg["keep_largest_only"]:
original_count = np.sum(full_mask)
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
full_mask = filter_by_area(
full_mask,
min_area=cfg["min_area"],
keep_largest_only=cfg["keep_largest_only"],
)
filtered_count = np.sum(full_mask)
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
# ========== 保存结果 ==========
profile = src.profile.copy()
profile.update(count=1, dtype="uint8", compress='lzw')
profile.update(count=1, dtype="uint8", compress="lzw")
with rasterio.open(mask_output_path, "w", **profile) as dst:
dst.write(full_mask, 1)
print(f"分割完成,结果已保存至:{mask_output_path}")
log(f"分割完成,结果已保存至:{mask_output_path}")
def main():
matplotlib.use("TkAgg")
run_segmentation()
if __name__ == "__main__":
main()