This commit is contained in:
2026-03-10 17:29:24 +08:00
parent 64692d8382
commit b8b6c6227d
4 changed files with 993 additions and 147 deletions

117
README_water_GUI.md Normal file
View File

@ -0,0 +1,117 @@
# water_GUI.py 使用说明SAM3 水体分割 GUI
本说明文档对应脚本:[water_GUI.py](file:///e:/code/sam3-main/sam3-main/water_GUI.py)。
该脚本是 **纯 GUI 启动器**:界面负责参数编辑、提示说明、日志与进度展示;真正的分割计算由 [water_V5.py](file:///e:/code/sam3-main/sam3-main/water_V5.py) 提供的 `run_segmentation()` 执行,并在点击“运行”时才会被懒加载导入。
## 1. 功能概览
- 选择输入 GeoTIFF待分割遥感影像与输出掩膜 GeoTIFF 路径
- 编辑分割参数Overview/Coarse/Fine/Region/Post
- 鼠标悬停在参数输入框上显示详细提示tooltip
- 后台线程执行分割GUI 不会卡死
- 右侧日志窗口实时输出运行信息
- 进度条显示子区域处理进度
- 支持“停止”(会在子区域边界处响应)
## 2. 运行方式
在项目目录下运行:
```bash
python e:\code\sam3-main\sam3-main\water_GUI.py
```
运行后:
- 点击“浏览”选择输入影像(`.tif/.tiff`
- 点击“浏览”选择输出掩膜路径(建议 `.tif`
- 调整参数(可保持默认)
- 点击“运行”
## 3. 依赖与环境
GUI 本身依赖:
- Python 标准库 `tkinter`
执行分割时会懒加载并依赖(来自 water_V5.py
- torch建议 CUDA 可用)
- rasterio
- numpy
- scipy
- pillow
- tqdm
- 以及 sam3 模型代码与权重(由项目现有加载逻辑处理)
## 4. 参数说明(界面分组)
### 4.1 Overview 兜底
用于减少“中心水域漏检”。Overview 会对每个子区域做一次低分辨率全局预测,并在最终输出中与 coarse/fine 结果取并集。
- `overview_max_side`:兜底预测的最大边长。越大越准越慢。
- `overview_threshold`:兜底阈值。偏低可减少漏检,但误检会增加。
### 4.2 Coarse 粗分割
粗分割用于提高召回、生成精修的边缘带band并与 overview 一起决定精修区域范围。
- `coarse_threshold`:粗分割阈值。偏低提高召回,偏高减少误检。
- `coarse_tile_size`:粗分割分块大小(原图像素)。
- `coarse_tile_overlap`:粗分割块重叠像素。
- `coarse_downsample_factor`粗分割输出降采样比例。2 更细更慢4 折中8 更快更粗。
- `band_radius`:边缘带半径(原图像素)。越大精修范围越大,速度越慢。
### 4.3 Fine 精修
只对边缘带覆盖到的区域进行精细分割,补足边界细节。
- `final_threshold`:最终阈值。偏低更易连通,偏高更干净。
- `fine_tile_size`:精修分块大小(原图像素)。一般 1536 更稳2048 更快但更吃显存/更慢。
- `fine_overlap`:精修块重叠像素。越大接缝越少但更慢。
### 4.4 Region 子区域
将整图划分成多个有重叠的子区域,降低内存峰值。
- `num_splits_y`:纵向切分数。
- `num_splits_x`:横向切分数。
- `region_overlap`:子区域重叠像素,减少子区域边界断裂。
### 4.5 Post 后处理
- `min_area`:移除小碎片的最小连通域面积(像素)。
- `keep_largest_only`:只保留最大连通域(适合只要最大水体)。
## 5. 输出说明
输出文件为单通道 `uint8` 掩膜 GeoTIFF
- 0非水体
- 1水体
输出文件的空间参考CRS/transform来自输入影像的 profile。
## 6. 使用建议16GB 显存)
优先按以下顺序调参:
- 先把 `fine_tile_size` 从 2048 降到 1536更稳
- `fine_overlap` 从 256 降到 128更快可能稍有拼接缝
- `coarse_downsample_factor` 从 4 改 2更细更慢但中心/边缘召回更稳)
- `overview_threshold` 从 0.4 降到 0.35(召回更强,误检更可能增加)
## 7. 常见问题
### 7.1 点击“运行”后报 `ModuleNotFoundError: water_V5`
确保 `water_GUI.py``water_V5.py` 在同一目录:
- `e:\code\sam3-main\sam3-main\water_GUI.py`
- `e:\code\sam3-main\sam3-main\water_V5.py`
并从该目录执行运行命令,或将该目录加入 Python 路径。
### 7.2 界面正常但运行很慢
这通常是正常的:粗分割/精修会对大量瓦片做模型推理。优先调小 `fine_tile_size``fine_overlap`,并提高 `coarse_downsample_factor`
### 7.3 “停止”不立即生效
停止信号会在子区域边界处检查,因此可能需要等待当前子区域处理完成后才退出。

155
tif_caijain.py Normal file
View File

@ -0,0 +1,155 @@
"""
使用二值掩膜 TIF 文件值为1的区域需要去除对数据 TIF 文件进行掩膜。
输入:
data_tif: 要掩膜的数据文件路径
mask_tif: 二值掩膜文件路径值为1表示需要去除的区域
输出:
掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData
要求:
两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐),
否则程序将报错或行为未定义。
"""
import argparse
import numpy as np
import rasterio
from rasterio.windows import Window
import logging
from pathlib import Path
from tqdm import tqdm
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def mask_data_by_binary_mask(
data_path,
mask_path,
output_path=None,
remove_value=1,
nodata_value=None,
tile_size=4096,
):
"""使用二值掩膜 TIF 对数据 TIF 进行掩膜。
将数据 TIF 中对应掩膜值等于 remove_value 的像素设为 NoData其余保留。
性能建议:
- 若数据源是 tiled GeoTIFF可将 tile_size 设为 0 以按源文件块窗口遍历(通常更快)。
"""
data_path = Path(data_path)
mask_path = Path(mask_path)
if output_path is None:
output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
else:
output_path = Path(output_path)
logger.info(f"数据文件: {data_path.name}")
logger.info(f"掩膜文件: {mask_path.name}")
logger.info(f"去除掩膜值: {remove_value}")
with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
if src_data.crs != src_mask.crs:
raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
if src_data.transform != src_mask.transform:
logger.warning(
"数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
)
if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")
# 确定输出 NoData 值(并尽量匹配数据 dtype避免隐式类型转换带来的开销
if nodata_value is None:
nodata_value = src_data.nodata if src_data.nodata is not None else 0
try:
nodata_value_cast = np.array(
nodata_value, dtype=src_data.dtypes[0]
).item()
except Exception:
nodata_value_cast = nodata_value
# 创建输出元数据:基于数据源的元数据,更新 nodata 和压缩选项
out_meta = src_data.meta.copy()
out_meta.update(
{
"nodata": nodata_value,
"compress": (
src_data.compression.value if src_data.compression else "lzw"
),
"tiled": src_data.is_tiled,
}
)
if src_data.is_tiled:
out_meta.update(
{
"blockxsize": src_data.block_shapes[0][0],
"blockysize": src_data.block_shapes[0][1],
}
)
# 创建输出文件
with rasterio.open(output_path, "w", **out_meta) as dst:
width, height = src_data.width, src_data.height
if tile_size is None or tile_size <= 0:
windows = [w for _, w in src_data.block_windows(1)]
else:
stride = int(tile_size)
windows = [
Window(i, j, min(stride, width - i), min(stride, height - j))
for i in range(0, width, stride)
for j in range(0, height, stride)
]
with tqdm(total=len(windows), desc="处理瓦片", unit="") as pbar:
for window in windows:
# 读取相同位置的掩膜瓦片(假设完全对齐)
mask = src_mask.read(1, window=window)
remove_mask = mask == remove_value
# 读取数据瓦片
data = src_data.read(window=window) # shape: (bands, h, w)
if remove_mask.any():
for band_idx in range(data.shape[0]):
np.putmask(
data[band_idx], remove_mask, nodata_value_cast
)
dst.write(data, window=window)
pbar.update(1)
logger.info(f"处理完成,输出文件:{output_path}")
return str(output_path)
def main():
parser = argparse.ArgumentParser(
description="使用二值掩膜 TIF值为1的区域对数据 TIF 进行掩膜,将对应位置设为 NoData。"
)
parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径值为1表示需要去除的区域")
parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
parser.add_argument("-r", "--remove_value", type=int, default=1,
help="掩膜中要去除的值默认为1")
parser.add_argument("-n", "--nodata", type=float,
help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)")
parser.add_argument(
"-t",
"--tile_size",
type=int,
default=4096,
help="分块大小像素默认4096设为0则按源文件块窗口遍历tiled 文件通常更快)",
)
args = parser.parse_args()
mask_data_by_binary_mask(
args.data_tif, args.mask_tif, args.output,
args.remove_value, args.nodata, args.tile_size
)
if __name__ == "__main__":
exit(main())

350
water_GUI.py Normal file
View File

@ -0,0 +1,350 @@
"""
water_V5 的图形界面启动器(纯 GUI 文件)。
特点:
- 参数可配置
- 悬停显示提示
- 后台线程运行分割,界面不假死
"""
import threading
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
DEFAULTS = {
"image_path": "",
"mask_output_path": "",
"prompt": "water body",
"overview_max_side": 1400,
"overview_threshold": 0.4,
"coarse_threshold": 0.5,
"coarse_tile_size": 4096,
"coarse_tile_overlap": 256,
"coarse_downsample_factor": 4,
"band_radius": 64,
"final_threshold": 0.5,
"fine_tile_size": 2048,
"fine_overlap": 256,
"num_splits_y": 2,
"num_splits_x": 3,
"region_overlap": 256,
"min_area": 5000,
"keep_largest_only": False,
}
def _lazy_import_backend():
import water_V5
return water_V5
class Tooltip:
def __init__(self, widget, text):
self.widget = widget
self.text = text
self.tip = None
widget.bind("<Enter>", self._show)
widget.bind("<Leave>", self._hide)
def _show(self, _event=None):
if self.tip is not None or not self.text:
return
x = self.widget.winfo_rootx() + 10
y = self.widget.winfo_rooty() + self.widget.winfo_height() + 5
self.tip = tk.Toplevel(self.widget)
self.tip.wm_overrideredirect(True)
self.tip.wm_geometry(f"+{x}+{y}")
label = tk.Label(
self.tip,
text=self.text,
justify="left",
background="#ffffe0",
relief="solid",
borderwidth=1,
font=("Segoe UI", 9),
)
label.pack(ipadx=6, ipady=4)
def _hide(self, _event=None):
if self.tip is not None:
self.tip.destroy()
self.tip = None
class WaterSegGUI(tk.Tk):
def __init__(self):
super().__init__()
self.title("SAM3 水体分割 GUI")
self.geometry("980x720")
self.stop_event = threading.Event()
self.worker = None
self.vars = {}
self._build_ui()
def _build_ui(self):
container = ttk.Frame(self)
container.pack(fill="both", expand=True, padx=10, pady=10)
top = ttk.Frame(container)
top.pack(fill="x")
self._add_path_row(
top,
"image_path",
"输入影像 (TIF)",
DEFAULTS["image_path"],
"选择要分割的遥感影像 GeoTIFF。",
is_save=False,
)
self._add_path_row(
top,
"mask_output_path",
"输出掩膜 (TIF)",
DEFAULTS["mask_output_path"],
"输出单通道 uint8 掩膜0/1",
is_save=True,
)
row = ttk.Frame(top)
row.pack(fill="x", pady=(6, 0))
ttk.Label(row, text="prompt").pack(side="left")
prompt_var = tk.StringVar(value=DEFAULTS["prompt"])
self.vars["prompt"] = prompt_var
prompt_box = ttk.Combobox(
row,
textvariable=prompt_var,
values=["water body", "water", "river", "lake", "reservoir"],
)
prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0))
Tooltip(prompt_box, "文本提示。建议优先尝试water / river / lake。")
panes = ttk.Panedwindow(container, orient="horizontal")
panes.pack(fill="both", expand=True, pady=(10, 0))
left = ttk.Frame(panes)
right = ttk.Frame(panes)
panes.add(left, weight=3)
panes.add(right, weight=2)
params = ttk.Frame(left)
params.pack(fill="both", expand=True)
self._add_group(
params,
"Overview 兜底",
[
("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"),
("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"),
],
)
self._add_group(
params,
"Coarse 粗分割",
[
("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"),
("coarse_tile_size", int, "粗分割分块大小(原图像素)。"),
("coarse_tile_overlap", int, "粗分割块重叠像素。"),
("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢4 折中8 更快更粗。"),
("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"),
],
)
self._add_group(
params,
"Fine 精修",
[
("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"),
("fine_tile_size", int, "精修分块大小(原图像素)。"),
("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"),
],
)
self._add_group(
params,
"Region 子区域",
[
("num_splits_y", int, "纵向切分数。"),
("num_splits_x", int, "横向切分数。"),
("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"),
],
)
self._add_group(
params,
"Post 后处理",
[
("min_area", int, "移除小碎片的最小连通域面积(像素)。"),
("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"),
],
)
ctrl = ttk.Frame(right)
ctrl.pack(fill="x")
self.progress = ttk.Progressbar(ctrl, maximum=100)
self.progress.pack(fill="x")
self.status = tk.StringVar(value="就绪")
ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0))
btn_row = ttk.Frame(ctrl)
btn_row.pack(fill="x", pady=(8, 0))
self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run)
self.run_btn.pack(side="left")
self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled")
self.stop_btn.pack(side="left", padx=(8, 0))
log_frame = ttk.Frame(right)
log_frame.pack(fill="both", expand=True, pady=(10, 0))
ttk.Label(log_frame, text="日志").pack(anchor="w")
self.log_text = tk.Text(log_frame, height=20, wrap="word")
self.log_text.pack(side="left", fill="both", expand=True)
scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview)
scroll.pack(side="right", fill="y")
self.log_text.configure(yscrollcommand=scroll.set)
def _add_path_row(self, parent, key, label, default, tip, is_save):
row = ttk.Frame(parent)
row.pack(fill="x", pady=(0, 6))
ttk.Label(row, text=label).pack(side="left")
var = tk.StringVar(value=default)
self.vars[key] = var
entry = ttk.Entry(row, textvariable=var)
entry.pack(side="left", fill="x", expand=True, padx=(10, 0))
Tooltip(entry, tip)
btn = ttk.Button(
row,
text="浏览",
command=lambda: self._browse_path(var, is_save=is_save),
)
btn.pack(side="left", padx=(8, 0))
def _add_group(self, parent, title, items):
frame = ttk.LabelFrame(parent, text=title)
frame.pack(fill="x", pady=(0, 10))
for r, (key, typ, tip) in enumerate(items):
ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4)
if typ is bool:
var = tk.BooleanVar(value=bool(DEFAULTS.get(key)))
self.vars[key] = var
w = ttk.Checkbutton(frame, variable=var)
w.grid(row=r, column=1, sticky="w", pady=4)
Tooltip(w, tip)
else:
var = tk.StringVar(value=str(DEFAULTS.get(key, "")))
self.vars[key] = var
entry = ttk.Entry(frame, textvariable=var, width=14)
entry.grid(row=r, column=1, sticky="w", pady=4)
Tooltip(entry, tip)
frame.grid_columnconfigure(1, weight=1)
def _browse_path(self, var, is_save):
if is_save:
p = filedialog.asksaveasfilename(
title="选择输出文件",
defaultextension=".tif",
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
)
else:
p = filedialog.askopenfilename(
title="选择输入文件",
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
)
if p:
var.set(p)
def _append_log(self, msg):
self.log_text.insert("end", msg + "\n")
self.log_text.see("end")
def _log(self, msg):
self.after(0, lambda: self._append_log(str(msg)))
def _progress(self, stage, cur, total):
def _update():
if total > 0:
self.progress["value"] = (cur / total) * 100.0
self.status.set(f"{stage}: {cur}/{total}")
self.after(0, _update)
def _get_config(self):
cfg = {}
for k, v in self.vars.items():
if isinstance(v, tk.BooleanVar):
cfg[k] = bool(v.get())
else:
cfg[k] = v.get()
int_keys = [
"overview_max_side",
"band_radius",
"fine_tile_size",
"fine_overlap",
"coarse_tile_size",
"coarse_tile_overlap",
"coarse_downsample_factor",
"num_splits_y",
"num_splits_x",
"region_overlap",
"min_area",
]
float_keys = ["overview_threshold", "coarse_threshold", "final_threshold"]
for k in int_keys:
if k in cfg and cfg[k] != "":
cfg[k] = int(float(cfg[k]))
for k in float_keys:
if k in cfg and cfg[k] != "":
cfg[k] = float(cfg[k])
cfg["use_tqdm"] = False
cfg["device"] = None
return cfg
def _on_run(self):
if self.worker is not None and self.worker.is_alive():
return
cfg = self._get_config()
if not cfg.get("image_path") or not cfg.get("mask_output_path"):
messagebox.showerror("错误", "请先选择输入影像与输出路径。")
return
self.stop_event.clear()
self.run_btn.configure(state="disabled")
self.stop_btn.configure(state="normal")
self.progress["value"] = 0
self.status.set("启动中...")
self._append_log("开始运行...")
def _work():
try:
water_v5 = _lazy_import_backend()
water_v5.run_segmentation(
config=cfg,
progress_callback=self._progress,
log_callback=self._log,
stop_event=self.stop_event,
)
except Exception as e:
self._log(f"错误: {e}")
finally:
def _done():
self.run_btn.configure(state="normal")
self.stop_btn.configure(state="disabled")
self.status.set("完成")
self.after(0, _done)
self.worker = threading.Thread(target=_work, daemon=True)
self.worker.start()
def _on_stop(self):
self.stop_event.set()
self.status.set("停止中...")
def main():
WaterSegGUI().mainloop()
if __name__ == "__main__":
main()

View File

@ -1,14 +1,17 @@
""" """
使用SAM3模型分割超大遥感影像中的水体分区域处理,优化内存 使用SAM3模型分割超大遥感影像中的水体三层分割策略
流程: 流程:
1. 将影像划分为若干有重叠的子区域。 1. 将影像划分为若干有重叠的子区域(第一层)
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割 2. 对每个子区域进行中分辨率粗分割以4096为块大小滑动窗口推理得到粗掩码第二层
3. 合并所有子区域的掩码(重叠区域取最大值)。 3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。
4. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域 4. 合并所有子区域的掩码(重叠区域取最大值)
5. 保存结果 5. 后处理填充内部NoData空洞、移除小面积碎片、可选保留最大连通域
6. 保存结果。
优点:避免在全尺寸概率图上进行GPU操作显著降低显存占用。 优点:
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
- 精修只处理边缘带,节省计算资源。
""" """
import torch import torch
@ -22,6 +25,7 @@ from rasterio.windows import Window
from rasterio.io import MemoryFile from rasterio.io import MemoryFile
from tqdm import tqdm from tqdm import tqdm
from scipy import ndimage from scipy import ndimage
import math
from sam3.model_builder import build_sam3_image_model from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor from sam3.model.sam3_image_processor import Sam3Processor
@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000):
return vmin, vmax return vmin, vmax
def make_overview_bands(bands, max_side):
_, h, w = bands.shape
step = int(math.ceil(max(h, w) / float(max_side)))
step = max(step, 1)
return bands[:, ::step, ::step], step
def _to_uint8(arr, vmin=None, vmax=None): def _to_uint8(arr, vmin=None, vmax=None):
if arr.dtype == np.uint8: if arr.dtype == np.uint8:
return arr return arr
@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None):
return Image.fromarray(rgb, mode="RGB") return Image.fromarray(rgb, mode="RGB")
def _read_pil_window(src, window, stretch_params=None): def _read_pil_window(src, window, stretch_params=None):
bands = src.read(window=window, boundless=True, fill_value=0) bands = src.read(window=window)
return _bands_to_pil(bands, stretch_params=stretch_params) return _bands_to_pil(bands, stretch_params=stretch_params)
def infer_prompt_prob(processor, image, prompt): def infer_prompt_prob(processor, image, prompt):
@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt):
prob = combine_masks_logits(state["masks_logits"]) prob = combine_masks_logits(state["masks_logits"])
return prob return prob
def make_overview_bands(sub_bands, max_side): # ---------- 新增:中分辨率粗分割(分块推理) ----------
_, h, w = sub_bands.shape def _downsample_any(mask, factor):
step = int(np.ceil(max(h, w) / float(max_side))) if factor <= 1:
step = max(step, 1) return mask
return sub_bands[:, ::step, ::step], step h, w = mask.shape
pad_h = (-h) % factor
pad_w = (-w) % factor
if pad_h or pad_w:
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
h, w = mask.shape
h2 = h // factor
w2 = w // factor
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
# ---------- 精修分块 ----------
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None): def _build_band_ndimage(mask, radius):
if radius <= 0:
return np.zeros_like(mask, dtype=bool)
struct = ndimage.generate_binary_structure(2, 2)
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
return np.logical_xor(dil, ero)
def coarse_tile_segmentation(
processor,
src,
prompt,
tile_size,
overlap,
downsample_factor=4,
stretch_params=None,
nodata_mask=None,
use_tqdm=True,
):
""" """
边缘带进行精修分割 子区域进行分块粗分割返回低分辨率概率图float16
band_coarse_cpu: 2D numpy bool指示需要精修的区域低分辨率
Args:
processor: 粗处理器(输入分辨率 coarse_resolution
src: rasterio 数据集(子区域)
prompt: 文本提示
tile_size: 粗分割块大小(原图像素)
overlap: 块重叠像素
stretch_params: 拉伸参数
nodata_mask: 子区域的 NoData 掩码bool, 与子区域同尺寸),用于过滤无效区域
Returns:
coarse_probs_lr: 子区域粗概率图 (H_lr,W_lr) float16
""" """
height, width = src.height, src.width height, width = src.height, src.width
full_probs = np.zeros((height, width), dtype=np.float16) height_lr = int(math.ceil(height / float(downsample_factor)))
width_lr = int(math.ceil(width / float(downsample_factor)))
band_h, band_w = band_coarse_cpu.shape full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)
scale_y = band_h / float(height)
scale_x = band_w / float(width)
stride = max(tile_size - overlap, 1) stride = max(tile_size - overlap, 1)
num_tiles_y = (height + stride - 1) // stride num_tiles_y = (height + stride - 1) // stride
num_tiles_x = (width + stride - 1) // stride num_tiles_x = (width + stride - 1) // stride
total_tiles = num_tiles_y * num_tiles_x total_tiles = num_tiles_y * num_tiles_x
# 内部进度条leave=True 以便保留历史记录 if use_tqdm:
with tqdm(total=total_tiles, desc="精修分块", unit="", leave=True) as pbar: iterator = tqdm(
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap): tile_slices(height, width, tile_size, overlap),
c_top = int(top * scale_y) total=total_tiles,
c_left = int(left * scale_x) desc="粗分割分块",
c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1) unit="",
c_right = max(int(np.ceil(right * scale_x)), c_left + 1) leave=True,
)
else:
iterator = tile_slices(height, width, tile_size, overlap)
if not band_coarse_cpu[c_top:c_bottom, c_left:c_right].any(): for top, left, bottom, right in iterator:
pbar.update(1) # 如果该块完全在 NoData 区域内,则跳过(加速)
if nodata_mask is not None:
block_nodata = nodata_mask[top:bottom, left:right]
if block_nodata.all():
continue
window = Window(left, top, right - left, bottom - top)
crop = _read_pil_window(src, window, stretch_params=stretch_params)
tile_prob = infer_prompt_prob(processor, crop, prompt)
if tile_prob is None:
continue
h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
w_lr = int(math.ceil((right - left) / float(downsample_factor)))
if tile_prob.shape[-2:] != (h_lr, w_lr):
tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
top_lr = top // downsample_factor
left_lr = left // downsample_factor
bottom_lr = top_lr + h_lr
right_lr = left_lr + w_lr
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum(
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu
)
if nodata_mask is not None:
nodata_lr = _downsample_any(nodata_mask, downsample_factor)
full_probs_lr[nodata_lr] = 0.0
return full_probs_lr
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
def refine_tiles(
processor,
src,
prompt,
band_lr,
downsample_factor,
tile_size,
overlap,
stretch_params=None,
use_tqdm=True,
):
"""
对边缘带进行精修分割
band_lr: 2D numpy bool指示需要精修的区域低分辨率
"""
height, width = src.height, src.width
full_probs = np.zeros((height, width), dtype=np.float16)
stride = max(tile_size - overlap, 1)
num_tiles_y = (height + stride - 1) // stride
num_tiles_x = (width + stride - 1) // stride
total_tiles = num_tiles_y * num_tiles_x
if use_tqdm:
iterator = tqdm(
tile_slices(height, width, tile_size, overlap),
total=total_tiles,
desc="精修分块",
unit="",
leave=True,
)
else:
iterator = tile_slices(height, width, tile_size, overlap)
for top, left, bottom, right in iterator:
c_top = top // downsample_factor
c_left = left // downsample_factor
c_bottom = int(math.ceil(bottom / float(downsample_factor)))
c_right = int(math.ceil(right / float(downsample_factor)))
c_bottom = max(c_bottom, c_top + 1)
c_right = max(c_right, c_left + 1)
if not band_lr[c_top:c_bottom, c_left:c_right].any():
continue continue
window = Window(left, top, right - left, bottom - top) window = Window(left, top, right - left, bottom - top)
crop = _read_pil_window(src, window, stretch_params=stretch_params) crop = _read_pil_window(src, window, stretch_params=stretch_params)
tile_prob = infer_prompt_prob(processor, crop, prompt) tile_prob = infer_prompt_prob(processor, crop, prompt)
if tile_prob is None: if tile_prob is None:
pbar.update(1)
continue continue
if tile_prob.shape[-2:] != (bottom - top, right - left): if tile_prob.shape[-2:] != (bottom - top, right - left):
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left)) tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
tile_prob_cpu = tile_prob.detach().float().cpu().numpy().astype(np.float16) tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
full_probs[top:bottom, left:right] = np.maximum( full_probs[top:bottom, left:right] = np.maximum(
full_probs[top:bottom, left:right], tile_prob_cpu full_probs[top:bottom, left:right], tile_prob_cpu
) )
pbar.update(1)
return full_probs return full_probs
@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25
regions.append((left, top, right, bottom)) regions.append((left, top, right, bottom))
return regions return regions
# ---------- 主程序 ---------- DEFAULT_CONFIG = {
matplotlib.use("TkAgg") "image_path": r"E:\is2\guidingsahn\result.tif",
"mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
"prompt": "water body",
"overview_max_side": 1400,
"overview_threshold": 0.4,
"coarse_resolution": 1008,
"fine_resolution": 1008,
"coarse_threshold": 0.5,
"final_threshold": 0.5,
"band_radius": 64,
"fine_tile_size": 2048,
"fine_overlap": 256,
"coarse_tile_size": 4096,
"coarse_tile_overlap": 256,
"coarse_downsample_factor": 4,
"num_splits_y": 2,
"num_splits_x": 3,
"region_overlap": 256,
"min_area": 5000,
"keep_largest_only": False,
"use_tqdm": True,
"device": None,
}
# 参数设置
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
prompt = "water body"
coarse_read_max_side = 1200
coarse_resolution = 1008
fine_resolution = 1008
coarse_threshold = 0.5
final_threshold = 0.5
band_radius = 64
tile_size = 1536
overlap = 128
# ========== 分区域参数 ========== def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
num_splits_y = 2 # 纵向切分数 cfg = dict(DEFAULT_CONFIG)
num_splits_x = 3 # 横向切分数(共 2x3=6 份) if config:
region_overlap = 256 # 子区域之间的重叠像素数 cfg.update(config)
# ===============================
# ========== 后处理参数 ========== log = log_callback if log_callback is not None else print
min_area = 1000 # 最小面积阈值像素小于此值的连通域将被移除设为0或None表示不进行面积过滤
keep_largest_only = False # 是否只保留最大的连通域True/False
# ===============================
device = "cuda" if torch.cuda.is_available() else "cpu" device = cfg["device"]
print(f"使用设备: {device}") if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
log(f"使用设备: {device}")
# 加载模型(所有子区域共享同一个模型) model = build_sam3_image_model().to(device).eval()
model = build_sam3_image_model().to(device).eval() coarse_processor = Sam3Processor(
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device) model, resolution=cfg["coarse_resolution"], device=device
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device) )
fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)
# 打开原始影像 image_path = cfg["image_path"]
with rasterio.open(image_path) as src: mask_output_path = cfg["mask_output_path"]
with rasterio.open(image_path) as src:
nodata = src.nodata nodata = src.nodata
print(f"原始影像NoData值: {nodata}") log(f"原始影像NoData值: {nodata}")
height, width = src.height, src.width height, width = src.height, src.width
# 划分区域 band1 = src.read(1)
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap) if np.issubdtype(band1.dtype, np.floating):
print(f"共划分 {len(regions)} 个子区域") nodata_mask_full = (
np.isclose(band1, float(nodata))
if nodata is not None
else np.zeros_like(band1, dtype=bool)
)
else:
nodata_mask_full = (
(band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool)
)
regions = split_into_regions(
width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
)
log(f"共划分 {len(regions)} 个子区域")
# 创建全尺寸掩码数组CPU内存
full_mask = np.zeros((height, width), dtype=np.uint8) full_mask = np.zeros((height, width), dtype=np.uint8)
# 子区域总进度条 use_tqdm = bool(cfg.get("use_tqdm", True))
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar: if use_tqdm:
for idx, (left, top, right, bottom) in enumerate(regions): region_iter = tqdm(
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})") list(enumerate(regions)),
total=len(regions),
desc="处理子区域",
unit="子区域",
)
else:
region_iter = enumerate(regions)
for idx, (left, top, right, bottom) in region_iter:
if stop_event is not None and stop_event.is_set():
log("已停止")
return
if progress_callback is not None:
progress_callback("region", idx + 1, len(regions))
log(
f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
)
region_w = right - left region_w = right - left
region_h = bottom - top region_h = bottom - top
# 读取子区域数据
window = Window(left, top, region_w, region_h) window = Window(left, top, region_w, region_h)
sub_bands = src.read(window=window) # shape: (bands, h, w) sub_bands = src.read(window=window)
# 构建子区域元数据
sub_profile = src.profile.copy() sub_profile = src.profile.copy()
if not src.is_tiled: if not src.is_tiled:
sub_profile.pop('blockxsize', None) sub_profile.pop("blockxsize", None)
sub_profile.pop('blockysize', None) sub_profile.pop("blockysize", None)
sub_profile['tiled'] = False sub_profile["tiled"] = False
else: else:
sub_profile['tiled'] = True sub_profile["tiled"] = True
sub_profile.update({ sub_profile.update(
'height': region_h, {
'width': region_w, "height": region_h,
'transform': rasterio.windows.transform(window, src.transform) "width": region_w,
}) "transform": rasterio.windows.transform(window, src.transform),
}
)
sub_nodata = nodata_mask_full[top:bottom, left:right]
# 将子区域数据包装成内存中的 rasterio 数据集
with MemoryFile() as memfile: with MemoryFile() as memfile:
with memfile.open(**sub_profile) as sub_dst: with memfile.open(**sub_profile) as sub_dst:
sub_dst.write(sub_bands) sub_dst.write(sub_bands)
@ -311,56 +475,98 @@ with rasterio.open(image_path) as src:
stretch_params = compute_stretch_params(sub_bands) stretch_params = compute_stretch_params(sub_bands)
overview_bands, overview_step = make_overview_bands( overview_bands, overview_step = make_overview_bands(
sub_bands, max_side=coarse_read_max_side sub_bands, max_side=cfg["overview_max_side"]
)
overview_img = _bands_to_pil(
overview_bands, stretch_params=stretch_params
)
overview_prob = infer_prompt_prob(
coarse_processor, overview_img, cfg["prompt"]
) )
overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
if overview_prob is None: if overview_prob is None:
overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16) overview_mask_small = np.zeros(
else: (overview_img.height, overview_img.width), dtype=bool
overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
overview_mask_cpu = overview_prob_cpu > coarse_threshold
scale = overview_img.height / float(region_h)
band_radius_small = max(int(round(band_radius * scale)), 1)
band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
fine_probs_cpu = refine_tiles(
fine_processor,
sub_src,
prompt,
band_small,
tile_size=tile_size,
overlap=overlap,
stretch_params=stretch_params,
) )
fine_mask = fine_probs_cpu > final_threshold else:
overview_mask_small = (
(overview_prob > cfg["overview_threshold"])
.detach()
.cpu()
.numpy()
)
if sub_nodata.any():
nodata_overview = sub_nodata[::overview_step, ::overview_step]
overview_mask_small[nodata_overview] = False
coarse_probs = coarse_tile_segmentation(
processor=coarse_processor,
src=sub_src,
prompt=cfg["prompt"],
tile_size=cfg["coarse_tile_size"],
overlap=cfg["coarse_tile_overlap"],
downsample_factor=cfg["coarse_downsample_factor"],
stretch_params=stretch_params,
nodata_mask=sub_nodata if sub_nodata.any() else None,
use_tqdm=use_tqdm,
)
coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
overview_mask_lr = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize(
(coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]),
resample=Image.NEAREST,
)
).astype(bool)
combined_mask_lr = coarse_mask_lr | overview_mask_lr
band_radius_lr = max(
int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))),
1,
)
band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
fine_probs = refine_tiles(
processor=fine_processor,
src=sub_src,
prompt=cfg["prompt"],
band_lr=band_lr,
downsample_factor=cfg["coarse_downsample_factor"],
tile_size=cfg["fine_tile_size"],
overlap=cfg["fine_overlap"],
stretch_params=stretch_params,
use_tqdm=use_tqdm,
)
fine_mask = fine_probs > cfg["final_threshold"]
coarse_mask_full = np.array( coarse_mask_full = np.array(
Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize( Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize(
(region_w, region_h), resample=Image.NEAREST (region_w, region_h), resample=Image.NEAREST
) )
).astype(bool) ).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8) overview_mask_full = np.array(
Image.fromarray(
overview_mask_small.astype(np.uint8), mode="L"
).resize((region_w, region_h), resample=Image.NEAREST)
).astype(bool)
sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype(
np.uint8
)
# 合并到全尺寸掩码(重叠区域取最大值)
full_mask[top:bottom, left:right] = np.maximum( full_mask[top:bottom, left:right] = np.maximum(
full_mask[top:bottom, left:right], sub_mask_np full_mask[top:bottom, left:right], sub_mask_np
) )
# 更新子区域进度条 log("\n后处理填充内部NoData空洞...")
region_pbar.update(1)
# ========== 后处理 ==========
print("\n后处理填充内部NoData空洞...")
if nodata is not None: if nodata is not None:
band1 = src.read(1)
if np.issubdtype(band1.dtype, np.floating): if np.issubdtype(band1.dtype, np.floating):
nodata_mask = np.isclose(band1, float(nodata)) nodata_mask = np.isclose(band1, float(nodata))
else: else:
nodata_mask = (band1 == nodata) nodata_mask = band1 == nodata
labeled_mask, num_features = ndimage.label(nodata_mask) labeled_mask, _ = ndimage.label(nodata_mask)
boundary_mask = np.zeros_like(nodata_mask, dtype=bool) boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
boundary_mask[0, :] = True boundary_mask[0, :] = True
boundary_mask[-1, :] = True boundary_mask[-1, :] = True
@ -368,27 +574,45 @@ with rasterio.open(image_path) as src:
boundary_mask[:, -1] = True boundary_mask[:, -1] = True
boundary_labels = set(labeled_mask[boundary_mask]) boundary_labels = set(labeled_mask[boundary_mask])
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask internal_nodata_mask = np.isin(
labeled_mask, list(boundary_labels), invert=True
) & nodata_mask
if internal_nodata_mask.any(): if internal_nodata_mask.any():
struct = ndimage.generate_binary_structure(2, 2) struct = ndimage.generate_binary_structure(2, 2)
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7) internal_dilated = ndimage.binary_dilation(
internal_nodata_mask, structure=struct, iterations=7
)
full_mask[internal_dilated] = 1 full_mask[internal_dilated] = 1
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}") log(
f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
)
else: else:
print(" 无内部NoData区域") log(" 无内部NoData区域")
print("后处理:面积过滤...") log("后处理:面积过滤...")
if min_area is not None or keep_largest_only: if cfg["min_area"] is not None or cfg["keep_largest_only"]:
original_count = np.sum(full_mask) original_count = np.sum(full_mask)
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only) full_mask = filter_by_area(
full_mask,
min_area=cfg["min_area"],
keep_largest_only=cfg["keep_largest_only"],
)
filtered_count = np.sum(full_mask) filtered_count = np.sum(full_mask)
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}") log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
# ========== 保存结果 ==========
profile = src.profile.copy() profile = src.profile.copy()
profile.update(count=1, dtype="uint8", compress='lzw') profile.update(count=1, dtype="uint8", compress="lzw")
with rasterio.open(mask_output_path, "w", **profile) as dst: with rasterio.open(mask_output_path, "w", **profile) as dst:
dst.write(full_mask, 1) dst.write(full_mask, 1)
print(f"分割完成,结果已保存至:{mask_output_path}") log(f"分割完成,结果已保存至:{mask_output_path}")
def main():
matplotlib.use("TkAgg")
run_segmentation()
if __name__ == "__main__":
main()