diff --git a/README_water_GUI.md b/README_water_GUI.md new file mode 100644 index 0000000..889cdc3 --- /dev/null +++ b/README_water_GUI.md @@ -0,0 +1,117 @@ +# water_GUI.py 使用说明(SAM3 水体分割 GUI) + +本说明文档对应脚本:[water_GUI.py](file:///e:/code/sam3-main/sam3-main/water_GUI.py)。 +该脚本是 **纯 GUI 启动器**:界面负责参数编辑、提示说明、日志与进度展示;真正的分割计算由 [water_V5.py](file:///e:/code/sam3-main/sam3-main/water_V5.py) 提供的 `run_segmentation()` 执行,并在点击“运行”时才会被懒加载导入。 + +## 1. 功能概览 + +- 选择输入 GeoTIFF(待分割遥感影像)与输出掩膜 GeoTIFF 路径 +- 编辑分割参数(Overview/Coarse/Fine/Region/Post) +- 鼠标悬停在参数输入框上显示详细提示(tooltip) +- 后台线程执行分割,GUI 不会卡死 +- 右侧日志窗口实时输出运行信息 +- 进度条显示子区域处理进度 +- 支持“停止”(会在子区域边界处响应) + +## 2. 运行方式 + +在项目目录下运行: + +```bash +python e:\code\sam3-main\sam3-main\water_GUI.py +``` + +运行后: +- 点击“浏览”选择输入影像(`.tif/.tiff`) +- 点击“浏览”选择输出掩膜路径(建议 `.tif`) +- 调整参数(可保持默认) +- 点击“运行” + +## 3. 依赖与环境 + +GUI 本身依赖: +- Python 标准库 `tkinter` + +执行分割时会懒加载并依赖(来自 water_V5.py): +- torch(建议 CUDA 可用) +- rasterio +- numpy +- scipy +- pillow +- tqdm +- 以及 sam3 模型代码与权重(由项目现有加载逻辑处理) + +## 4. 参数说明(界面分组) + +### 4.1 Overview 兜底 + +用于减少“中心水域漏检”。Overview 会对每个子区域做一次低分辨率全局预测,并在最终输出中与 coarse/fine 结果取并集。 + +- `overview_max_side`:兜底预测的最大边长。越大越准越慢。 +- `overview_threshold`:兜底阈值。偏低可减少漏检,但误检会增加。 + +### 4.2 Coarse 粗分割 + +粗分割用于提高召回、生成精修的边缘带(band),并与 overview 一起决定精修区域范围。 + +- `coarse_threshold`:粗分割阈值。偏低提高召回,偏高减少误检。 +- `coarse_tile_size`:粗分割分块大小(原图像素)。 +- `coarse_tile_overlap`:粗分割块重叠像素。 +- `coarse_downsample_factor`:粗分割输出降采样比例。2 更细更慢;4 折中;8 更快更粗。 +- `band_radius`:边缘带半径(原图像素)。越大精修范围越大,速度越慢。 + +### 4.3 Fine 精修 + +只对边缘带覆盖到的区域进行精细分割,补足边界细节。 + +- `final_threshold`:最终阈值。偏低更易连通,偏高更干净。 +- `fine_tile_size`:精修分块大小(原图像素)。一般 1536 更稳,2048 更快但更吃显存/更慢。 +- `fine_overlap`:精修块重叠像素。越大接缝越少但更慢。 + +### 4.4 Region 子区域 + +将整图划分成多个有重叠的子区域,降低内存峰值。 + +- `num_splits_y`:纵向切分数。 +- `num_splits_x`:横向切分数。 +- `region_overlap`:子区域重叠像素,减少子区域边界断裂。 + +### 4.5 Post 后处理 + +- `min_area`:移除小碎片的最小连通域面积(像素)。 +- `keep_largest_only`:只保留最大连通域(适合只要最大水体)。 + +## 5. 输出说明 + +输出文件为单通道 `uint8` 掩膜 GeoTIFF: +- 0:非水体 +- 1:水体 + +输出文件的空间参考(CRS/transform)来自输入影像的 profile。 + +## 6. 使用建议(16GB 显存) + +优先按以下顺序调参: +- 先把 `fine_tile_size` 从 2048 降到 1536(更稳) +- `fine_overlap` 从 256 降到 128(更快,可能稍有拼接缝) +- `coarse_downsample_factor` 从 4 改 2(更细更慢,但中心/边缘召回更稳) +- `overview_threshold` 从 0.4 降到 0.35(召回更强,误检更可能增加) + +## 7. 常见问题 + +### 7.1 点击“运行”后报 `ModuleNotFoundError: water_V5` + +确保 `water_GUI.py` 与 `water_V5.py` 在同一目录: +- `e:\code\sam3-main\sam3-main\water_GUI.py` +- `e:\code\sam3-main\sam3-main\water_V5.py` + +并从该目录执行运行命令,或将该目录加入 Python 路径。 + +### 7.2 界面正常但运行很慢 + +这通常是正常的:粗分割/精修会对大量瓦片做模型推理。优先调小 `fine_tile_size`、`fine_overlap`,并提高 `coarse_downsample_factor`。 + +### 7.3 “停止”不立即生效 + +停止信号会在子区域边界处检查,因此可能需要等待当前子区域处理完成后才退出。 + diff --git a/tif_caijain.py b/tif_caijain.py new file mode 100644 index 0000000..dc9218b --- /dev/null +++ b/tif_caijain.py @@ -0,0 +1,155 @@ +""" +使用二值掩膜 TIF 文件(值为1的区域需要去除)对数据 TIF 文件进行掩膜。 +输入: + data_tif: 要掩膜的数据文件路径 + mask_tif: 二值掩膜文件路径(值为1表示需要去除的区域) +输出: + 掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData +要求: + 两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐), + 否则程序将报错或行为未定义。 +""" + +import argparse +import numpy as np +import rasterio +from rasterio.windows import Window +import logging +from pathlib import Path +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + + +def mask_data_by_binary_mask( + data_path, + mask_path, + output_path=None, + remove_value=1, + nodata_value=None, + tile_size=4096, +): + """使用二值掩膜 TIF 对数据 TIF 进行掩膜。 + + 将数据 TIF 中对应掩膜值等于 remove_value 的像素设为 NoData,其余保留。 + + 性能建议: + - 若数据源是 tiled GeoTIFF,可将 tile_size 设为 0 以按源文件块窗口遍历(通常更快)。 + """ + data_path = Path(data_path) + mask_path = Path(mask_path) + + if output_path is None: + output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}" + else: + output_path = Path(output_path) + + logger.info(f"数据文件: {data_path.name}") + logger.info(f"掩膜文件: {mask_path.name}") + logger.info(f"去除掩膜值: {remove_value}") + + with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"): + with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask: + if src_data.crs != src_mask.crs: + raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。") + if src_data.transform != src_mask.transform: + logger.warning( + "数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。" + ) + if (src_data.width, src_data.height) != (src_mask.width, src_mask.height): + raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。") + + # 确定输出 NoData 值(并尽量匹配数据 dtype,避免隐式类型转换带来的开销) + if nodata_value is None: + nodata_value = src_data.nodata if src_data.nodata is not None else 0 + try: + nodata_value_cast = np.array( + nodata_value, dtype=src_data.dtypes[0] + ).item() + except Exception: + nodata_value_cast = nodata_value + + # 创建输出元数据:基于数据源的元数据,更新 nodata 和压缩选项 + out_meta = src_data.meta.copy() + out_meta.update( + { + "nodata": nodata_value, + "compress": ( + src_data.compression.value if src_data.compression else "lzw" + ), + "tiled": src_data.is_tiled, + } + ) + if src_data.is_tiled: + out_meta.update( + { + "blockxsize": src_data.block_shapes[0][0], + "blockysize": src_data.block_shapes[0][1], + } + ) + + # 创建输出文件 + with rasterio.open(output_path, "w", **out_meta) as dst: + width, height = src_data.width, src_data.height + + if tile_size is None or tile_size <= 0: + windows = [w for _, w in src_data.block_windows(1)] + else: + stride = int(tile_size) + windows = [ + Window(i, j, min(stride, width - i), min(stride, height - j)) + for i in range(0, width, stride) + for j in range(0, height, stride) + ] + + with tqdm(total=len(windows), desc="处理瓦片", unit="块") as pbar: + for window in windows: + # 读取相同位置的掩膜瓦片(假设完全对齐) + mask = src_mask.read(1, window=window) + remove_mask = mask == remove_value + + # 读取数据瓦片 + data = src_data.read(window=window) # shape: (bands, h, w) + + if remove_mask.any(): + for band_idx in range(data.shape[0]): + np.putmask( + data[band_idx], remove_mask, nodata_value_cast + ) + + dst.write(data, window=window) + pbar.update(1) + + logger.info(f"处理完成,输出文件:{output_path}") + return str(output_path) + + +def main(): + parser = argparse.ArgumentParser( + description="使用二值掩膜 TIF(值为1的区域)对数据 TIF 进行掩膜,将对应位置设为 NoData。" + ) + parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径") + parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径(值为1表示需要去除的区域)") + parser.add_argument("-o", "--output", help="输出文件路径 (可选)") + parser.add_argument("-r", "--remove_value", type=int, default=1, + help="掩膜中要去除的值,默认为1") + parser.add_argument("-n", "--nodata", type=float, + help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)") + parser.add_argument( + "-t", + "--tile_size", + type=int, + default=4096, + help="分块大小(像素),默认4096;设为0则按源文件块窗口遍历(tiled 文件通常更快)", + ) + + args = parser.parse_args() + mask_data_by_binary_mask( + args.data_tif, args.mask_tif, args.output, + args.remove_value, args.nodata, args.tile_size + ) + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/water_GUI.py b/water_GUI.py new file mode 100644 index 0000000..a38a69a --- /dev/null +++ b/water_GUI.py @@ -0,0 +1,350 @@ +""" +water_V5 的图形界面启动器(纯 GUI 文件)。 + +特点: +- 参数可配置 +- 悬停显示提示 +- 后台线程运行分割,界面不假死 +""" + +import threading +import tkinter as tk +from tkinter import filedialog, messagebox, ttk + + +DEFAULTS = { + "image_path": "", + "mask_output_path": "", + "prompt": "water body", + "overview_max_side": 1400, + "overview_threshold": 0.4, + "coarse_threshold": 0.5, + "coarse_tile_size": 4096, + "coarse_tile_overlap": 256, + "coarse_downsample_factor": 4, + "band_radius": 64, + "final_threshold": 0.5, + "fine_tile_size": 2048, + "fine_overlap": 256, + "num_splits_y": 2, + "num_splits_x": 3, + "region_overlap": 256, + "min_area": 5000, + "keep_largest_only": False, +} + + +def _lazy_import_backend(): + import water_V5 + + return water_V5 + + +class Tooltip: + def __init__(self, widget, text): + self.widget = widget + self.text = text + self.tip = None + widget.bind("", self._show) + widget.bind("", self._hide) + + def _show(self, _event=None): + if self.tip is not None or not self.text: + return + x = self.widget.winfo_rootx() + 10 + y = self.widget.winfo_rooty() + self.widget.winfo_height() + 5 + self.tip = tk.Toplevel(self.widget) + self.tip.wm_overrideredirect(True) + self.tip.wm_geometry(f"+{x}+{y}") + label = tk.Label( + self.tip, + text=self.text, + justify="left", + background="#ffffe0", + relief="solid", + borderwidth=1, + font=("Segoe UI", 9), + ) + label.pack(ipadx=6, ipady=4) + + def _hide(self, _event=None): + if self.tip is not None: + self.tip.destroy() + self.tip = None + + +class WaterSegGUI(tk.Tk): + def __init__(self): + super().__init__() + self.title("SAM3 水体分割 GUI") + self.geometry("980x720") + + self.stop_event = threading.Event() + self.worker = None + + self.vars = {} + self._build_ui() + + def _build_ui(self): + container = ttk.Frame(self) + container.pack(fill="both", expand=True, padx=10, pady=10) + + top = ttk.Frame(container) + top.pack(fill="x") + + self._add_path_row( + top, + "image_path", + "输入影像 (TIF)", + DEFAULTS["image_path"], + "选择要分割的遥感影像 GeoTIFF。", + is_save=False, + ) + self._add_path_row( + top, + "mask_output_path", + "输出掩膜 (TIF)", + DEFAULTS["mask_output_path"], + "输出单通道 uint8 掩膜(0/1)。", + is_save=True, + ) + + row = ttk.Frame(top) + row.pack(fill="x", pady=(6, 0)) + ttk.Label(row, text="prompt").pack(side="left") + prompt_var = tk.StringVar(value=DEFAULTS["prompt"]) + self.vars["prompt"] = prompt_var + prompt_box = ttk.Combobox( + row, + textvariable=prompt_var, + values=["water body", "water", "river", "lake", "reservoir"], + ) + prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0)) + Tooltip(prompt_box, "文本提示。建议优先尝试:water / river / lake。") + + panes = ttk.Panedwindow(container, orient="horizontal") + panes.pack(fill="both", expand=True, pady=(10, 0)) + + left = ttk.Frame(panes) + right = ttk.Frame(panes) + panes.add(left, weight=3) + panes.add(right, weight=2) + + params = ttk.Frame(left) + params.pack(fill="both", expand=True) + + self._add_group( + params, + "Overview 兜底", + [ + ("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"), + ("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"), + ], + ) + self._add_group( + params, + "Coarse 粗分割", + [ + ("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"), + ("coarse_tile_size", int, "粗分割分块大小(原图像素)。"), + ("coarse_tile_overlap", int, "粗分割块重叠像素。"), + ("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢;4 折中;8 更快更粗。"), + ("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"), + ], + ) + self._add_group( + params, + "Fine 精修", + [ + ("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"), + ("fine_tile_size", int, "精修分块大小(原图像素)。"), + ("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"), + ], + ) + self._add_group( + params, + "Region 子区域", + [ + ("num_splits_y", int, "纵向切分数。"), + ("num_splits_x", int, "横向切分数。"), + ("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"), + ], + ) + self._add_group( + params, + "Post 后处理", + [ + ("min_area", int, "移除小碎片的最小连通域面积(像素)。"), + ("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"), + ], + ) + + ctrl = ttk.Frame(right) + ctrl.pack(fill="x") + + self.progress = ttk.Progressbar(ctrl, maximum=100) + self.progress.pack(fill="x") + self.status = tk.StringVar(value="就绪") + ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0)) + + btn_row = ttk.Frame(ctrl) + btn_row.pack(fill="x", pady=(8, 0)) + self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run) + self.run_btn.pack(side="left") + self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled") + self.stop_btn.pack(side="left", padx=(8, 0)) + + log_frame = ttk.Frame(right) + log_frame.pack(fill="both", expand=True, pady=(10, 0)) + ttk.Label(log_frame, text="日志").pack(anchor="w") + self.log_text = tk.Text(log_frame, height=20, wrap="word") + self.log_text.pack(side="left", fill="both", expand=True) + scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview) + scroll.pack(side="right", fill="y") + self.log_text.configure(yscrollcommand=scroll.set) + + def _add_path_row(self, parent, key, label, default, tip, is_save): + row = ttk.Frame(parent) + row.pack(fill="x", pady=(0, 6)) + ttk.Label(row, text=label).pack(side="left") + var = tk.StringVar(value=default) + self.vars[key] = var + entry = ttk.Entry(row, textvariable=var) + entry.pack(side="left", fill="x", expand=True, padx=(10, 0)) + Tooltip(entry, tip) + btn = ttk.Button( + row, + text="浏览", + command=lambda: self._browse_path(var, is_save=is_save), + ) + btn.pack(side="left", padx=(8, 0)) + + def _add_group(self, parent, title, items): + frame = ttk.LabelFrame(parent, text=title) + frame.pack(fill="x", pady=(0, 10)) + for r, (key, typ, tip) in enumerate(items): + ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4) + if typ is bool: + var = tk.BooleanVar(value=bool(DEFAULTS.get(key))) + self.vars[key] = var + w = ttk.Checkbutton(frame, variable=var) + w.grid(row=r, column=1, sticky="w", pady=4) + Tooltip(w, tip) + else: + var = tk.StringVar(value=str(DEFAULTS.get(key, ""))) + self.vars[key] = var + entry = ttk.Entry(frame, textvariable=var, width=14) + entry.grid(row=r, column=1, sticky="w", pady=4) + Tooltip(entry, tip) + frame.grid_columnconfigure(1, weight=1) + + def _browse_path(self, var, is_save): + if is_save: + p = filedialog.asksaveasfilename( + title="选择输出文件", + defaultextension=".tif", + filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")], + ) + else: + p = filedialog.askopenfilename( + title="选择输入文件", + filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")], + ) + if p: + var.set(p) + + def _append_log(self, msg): + self.log_text.insert("end", msg + "\n") + self.log_text.see("end") + + def _log(self, msg): + self.after(0, lambda: self._append_log(str(msg))) + + def _progress(self, stage, cur, total): + def _update(): + if total > 0: + self.progress["value"] = (cur / total) * 100.0 + self.status.set(f"{stage}: {cur}/{total}") + self.after(0, _update) + + def _get_config(self): + cfg = {} + for k, v in self.vars.items(): + if isinstance(v, tk.BooleanVar): + cfg[k] = bool(v.get()) + else: + cfg[k] = v.get() + + int_keys = [ + "overview_max_side", + "band_radius", + "fine_tile_size", + "fine_overlap", + "coarse_tile_size", + "coarse_tile_overlap", + "coarse_downsample_factor", + "num_splits_y", + "num_splits_x", + "region_overlap", + "min_area", + ] + float_keys = ["overview_threshold", "coarse_threshold", "final_threshold"] + for k in int_keys: + if k in cfg and cfg[k] != "": + cfg[k] = int(float(cfg[k])) + for k in float_keys: + if k in cfg and cfg[k] != "": + cfg[k] = float(cfg[k]) + + cfg["use_tqdm"] = False + cfg["device"] = None + return cfg + + def _on_run(self): + if self.worker is not None and self.worker.is_alive(): + return + cfg = self._get_config() + if not cfg.get("image_path") or not cfg.get("mask_output_path"): + messagebox.showerror("错误", "请先选择输入影像与输出路径。") + return + + self.stop_event.clear() + self.run_btn.configure(state="disabled") + self.stop_btn.configure(state="normal") + self.progress["value"] = 0 + self.status.set("启动中...") + self._append_log("开始运行...") + + def _work(): + try: + water_v5 = _lazy_import_backend() + water_v5.run_segmentation( + config=cfg, + progress_callback=self._progress, + log_callback=self._log, + stop_event=self.stop_event, + ) + except Exception as e: + self._log(f"错误: {e}") + finally: + def _done(): + self.run_btn.configure(state="normal") + self.stop_btn.configure(state="disabled") + self.status.set("完成") + self.after(0, _done) + + self.worker = threading.Thread(target=_work, daemon=True) + self.worker.start() + + def _on_stop(self): + self.stop_event.set() + self.status.set("停止中...") + + +def main(): + WaterSegGUI().mainloop() + + +if __name__ == "__main__": + main() + diff --git a/water_V5.py b/water_V5.py index 406c908..27d4e68 100644 --- a/water_V5.py +++ b/water_V5.py @@ -1,14 +1,17 @@ """ -使用SAM3模型分割超大遥感影像中的水体(分区域处理,优化内存) +使用SAM3模型分割超大遥感影像中的水体(三层分割策略) 流程: -1. 将影像划分为若干有重叠的子区域。 -2. 对每个子区域独立执行粗分割、边缘带构建、精修分割。 -3. 合并所有子区域的掩码(重叠区域取最大值)。 -4. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。 -5. 保存结果。 +1. 将影像划分为若干有重叠的子区域(第一层)。 +2. 对每个子区域进行中分辨率粗分割:以4096为块大小,滑动窗口推理,得到粗掩码(第二层)。 +3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。 +4. 合并所有子区域的掩码(重叠区域取最大值)。 +5. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。 +6. 保存结果。 -优点:避免在全尺寸概率图上进行GPU操作,显著降低显存占用。 +优点: +- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。 +- 精修只处理边缘带,节省计算资源。 """ import torch @@ -22,6 +25,7 @@ from rasterio.windows import Window from rasterio.io import MemoryFile from tqdm import tqdm from scipy import ndimage +import math from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor @@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000): return vmin, vmax +def make_overview_bands(bands, max_side): + _, h, w = bands.shape + step = int(math.ceil(max(h, w) / float(max_side))) + step = max(step, 1) + return bands[:, ::step, ::step], step + + def _to_uint8(arr, vmin=None, vmax=None): if arr.dtype == np.uint8: return arr @@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None): return Image.fromarray(rgb, mode="RGB") def _read_pil_window(src, window, stretch_params=None): - bands = src.read(window=window, boundless=True, fill_value=0) + bands = src.read(window=window) return _bands_to_pil(bands, stretch_params=stretch_params) def infer_prompt_prob(processor, image, prompt): @@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt): prob = combine_masks_logits(state["masks_logits"]) return prob -def make_overview_bands(sub_bands, max_side): - _, h, w = sub_bands.shape - step = int(np.ceil(max(h, w) / float(max_side))) - step = max(step, 1) - return sub_bands[:, ::step, ::step], step +# ---------- 新增:中分辨率粗分割(分块推理) ---------- +def _downsample_any(mask, factor): + if factor <= 1: + return mask + h, w = mask.shape + pad_h = (-h) % factor + pad_w = (-w) % factor + if pad_h or pad_w: + mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False) + h, w = mask.shape + h2 = h // factor + w2 = w // factor + return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3)) -# ---------- 精修分块 ---------- -def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None): + +def _build_band_ndimage(mask, radius): + if radius <= 0: + return np.zeros_like(mask, dtype=bool) + struct = ndimage.generate_binary_structure(2, 2) + dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius) + ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius) + return np.logical_xor(dil, ero) + + +def coarse_tile_segmentation( + processor, + src, + prompt, + tile_size, + overlap, + downsample_factor=4, + stretch_params=None, + nodata_mask=None, + use_tqdm=True, +): """ - 对边缘带进行精修分割 - band_coarse_cpu: 2D numpy bool,指示需要精修的区域(低分辨率) + 对子区域进行分块粗分割,返回低分辨率概率图(float16) + + Args: + processor: 粗处理器(输入分辨率 coarse_resolution) + src: rasterio 数据集(子区域) + prompt: 文本提示 + tile_size: 粗分割块大小(原图像素) + overlap: 块重叠像素 + stretch_params: 拉伸参数 + nodata_mask: 子区域的 NoData 掩码(bool, 与子区域同尺寸),用于过滤无效区域 + + Returns: + coarse_probs_lr: 子区域粗概率图 (H_lr,W_lr) float16 """ height, width = src.height, src.width - full_probs = np.zeros((height, width), dtype=np.float16) - - band_h, band_w = band_coarse_cpu.shape - scale_y = band_h / float(height) - scale_x = band_w / float(width) + height_lr = int(math.ceil(height / float(downsample_factor))) + width_lr = int(math.ceil(width / float(downsample_factor))) + full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16) stride = max(tile_size - overlap, 1) num_tiles_y = (height + stride - 1) // stride num_tiles_x = (width + stride - 1) // stride total_tiles = num_tiles_y * num_tiles_x - # 内部进度条,leave=True 以便保留历史记录 - with tqdm(total=total_tiles, desc="精修分块", unit="块", leave=True) as pbar: - for top, left, bottom, right in tile_slices(height, width, tile_size, overlap): - c_top = int(top * scale_y) - c_left = int(left * scale_x) - c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1) - c_right = max(int(np.ceil(right * scale_x)), c_left + 1) + if use_tqdm: + iterator = tqdm( + tile_slices(height, width, tile_size, overlap), + total=total_tiles, + desc="粗分割分块", + unit="块", + leave=True, + ) + else: + iterator = tile_slices(height, width, tile_size, overlap) - if not band_coarse_cpu[c_top:c_bottom, c_left:c_right].any(): - pbar.update(1) + for top, left, bottom, right in iterator: + # 如果该块完全在 NoData 区域内,则跳过(加速) + if nodata_mask is not None: + block_nodata = nodata_mask[top:bottom, left:right] + if block_nodata.all(): + continue + + window = Window(left, top, right - left, bottom - top) + crop = _read_pil_window(src, window, stretch_params=stretch_params) + tile_prob = infer_prompt_prob(processor, crop, prompt) + if tile_prob is None: + continue + + h_lr = int(math.ceil((bottom - top) / float(downsample_factor))) + w_lr = int(math.ceil((right - left) / float(downsample_factor))) + if tile_prob.shape[-2:] != (h_lr, w_lr): + tile_prob = upsample_prob(tile_prob, (h_lr, w_lr)) + + tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy() + top_lr = top // downsample_factor + left_lr = left // downsample_factor + bottom_lr = top_lr + h_lr + right_lr = left_lr + w_lr + full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum( + full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu + ) + + if nodata_mask is not None: + nodata_lr = _downsample_any(nodata_mask, downsample_factor) + full_probs_lr[nodata_lr] = 0.0 + + return full_probs_lr + +# ---------- 精修分块(保持不变,但可复用之前的函数) ---------- +def refine_tiles( + processor, + src, + prompt, + band_lr, + downsample_factor, + tile_size, + overlap, + stretch_params=None, + use_tqdm=True, +): + """ + 对边缘带进行精修分割 + band_lr: 2D numpy bool,指示需要精修的区域(低分辨率) + """ + height, width = src.height, src.width + full_probs = np.zeros((height, width), dtype=np.float16) + + stride = max(tile_size - overlap, 1) + num_tiles_y = (height + stride - 1) // stride + num_tiles_x = (width + stride - 1) // stride + total_tiles = num_tiles_y * num_tiles_x + + if use_tqdm: + iterator = tqdm( + tile_slices(height, width, tile_size, overlap), + total=total_tiles, + desc="精修分块", + unit="块", + leave=True, + ) + else: + iterator = tile_slices(height, width, tile_size, overlap) + + for top, left, bottom, right in iterator: + c_top = top // downsample_factor + c_left = left // downsample_factor + c_bottom = int(math.ceil(bottom / float(downsample_factor))) + c_right = int(math.ceil(right / float(downsample_factor))) + c_bottom = max(c_bottom, c_top + 1) + c_right = max(c_right, c_left + 1) + + if not band_lr[c_top:c_bottom, c_left:c_right].any(): continue window = Window(left, top, right - left, bottom - top) crop = _read_pil_window(src, window, stretch_params=stretch_params) tile_prob = infer_prompt_prob(processor, crop, prompt) if tile_prob is None: - pbar.update(1) continue if tile_prob.shape[-2:] != (bottom - top, right - left): tile_prob = upsample_prob(tile_prob, (bottom - top, right - left)) - tile_prob_cpu = tile_prob.detach().float().cpu().numpy().astype(np.float16) + tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy() full_probs[top:bottom, left:right] = np.maximum( full_probs[top:bottom, left:right], tile_prob_cpu ) - pbar.update(1) return full_probs @@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25 regions.append((left, top, right, bottom)) return regions -# ---------- 主程序 ---------- -matplotlib.use("TkAgg") +DEFAULT_CONFIG = { + "image_path": r"E:\is2\guidingsahn\result.tif", + "mask_output_path": r"E:\is2\guidingsahn\result_mask.tif", + "prompt": "water body", + "overview_max_side": 1400, + "overview_threshold": 0.4, + "coarse_resolution": 1008, + "fine_resolution": 1008, + "coarse_threshold": 0.5, + "final_threshold": 0.5, + "band_radius": 64, + "fine_tile_size": 2048, + "fine_overlap": 256, + "coarse_tile_size": 4096, + "coarse_tile_overlap": 256, + "coarse_downsample_factor": 4, + "num_splits_y": 2, + "num_splits_x": 3, + "region_overlap": 256, + "min_area": 5000, + "keep_largest_only": False, + "use_tqdm": True, + "device": None, +} -# 参数设置 -image_path = r"E:\is2\dingshanhu\result_caijian.tif" -mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif" -prompt = "water body" -coarse_read_max_side = 1200 -coarse_resolution = 1008 -fine_resolution = 1008 -coarse_threshold = 0.5 -final_threshold = 0.5 -band_radius = 64 -tile_size = 1536 -overlap = 128 -# ========== 分区域参数 ========== -num_splits_y = 2 # 纵向切分数 -num_splits_x = 3 # 横向切分数(共 2x3=6 份) -region_overlap = 256 # 子区域之间的重叠像素数 -# =============================== +def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None): + cfg = dict(DEFAULT_CONFIG) + if config: + cfg.update(config) -# ========== 后处理参数 ========== -min_area = 1000 # 最小面积阈值(像素),小于此值的连通域将被移除;设为0或None表示不进行面积过滤 -keep_largest_only = False # 是否只保留最大的连通域(True/False) -# =============================== + log = log_callback if log_callback is not None else print -device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"使用设备: {device}") + device = cfg["device"] + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + log(f"使用设备: {device}") -# 加载模型(所有子区域共享同一个模型) -model = build_sam3_image_model().to(device).eval() -coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device) -fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device) + model = build_sam3_image_model().to(device).eval() + coarse_processor = Sam3Processor( + model, resolution=cfg["coarse_resolution"], device=device + ) + fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device) -# 打开原始影像 -with rasterio.open(image_path) as src: - nodata = src.nodata - print(f"原始影像NoData值: {nodata}") - height, width = src.height, src.width + image_path = cfg["image_path"] + mask_output_path = cfg["mask_output_path"] - # 划分区域 - regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap) - print(f"共划分 {len(regions)} 个子区域") + with rasterio.open(image_path) as src: + nodata = src.nodata + log(f"原始影像NoData值: {nodata}") + height, width = src.height, src.width - # 创建全尺寸掩码数组(CPU内存) - full_mask = np.zeros((height, width), dtype=np.uint8) + band1 = src.read(1) + if np.issubdtype(band1.dtype, np.floating): + nodata_mask_full = ( + np.isclose(band1, float(nodata)) + if nodata is not None + else np.zeros_like(band1, dtype=bool) + ) + else: + nodata_mask_full = ( + (band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool) + ) - # 子区域总进度条 - with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar: - for idx, (left, top, right, bottom) in enumerate(regions): - print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})") + regions = split_into_regions( + width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"] + ) + log(f"共划分 {len(regions)} 个子区域") + + full_mask = np.zeros((height, width), dtype=np.uint8) + + use_tqdm = bool(cfg.get("use_tqdm", True)) + if use_tqdm: + region_iter = tqdm( + list(enumerate(regions)), + total=len(regions), + desc="处理子区域", + unit="子区域", + ) + else: + region_iter = enumerate(regions) + + for idx, (left, top, right, bottom) in region_iter: + if stop_event is not None and stop_event.is_set(): + log("已停止") + return + if progress_callback is not None: + progress_callback("region", idx + 1, len(regions)) + + log( + f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})" + ) region_w = right - left region_h = bottom - top - # 读取子区域数据 window = Window(left, top, region_w, region_h) - sub_bands = src.read(window=window) # shape: (bands, h, w) + sub_bands = src.read(window=window) - # 构建子区域元数据 sub_profile = src.profile.copy() if not src.is_tiled: - sub_profile.pop('blockxsize', None) - sub_profile.pop('blockysize', None) - sub_profile['tiled'] = False + sub_profile.pop("blockxsize", None) + sub_profile.pop("blockysize", None) + sub_profile["tiled"] = False else: - sub_profile['tiled'] = True - sub_profile.update({ - 'height': region_h, - 'width': region_w, - 'transform': rasterio.windows.transform(window, src.transform) - }) + sub_profile["tiled"] = True + sub_profile.update( + { + "height": region_h, + "width": region_w, + "transform": rasterio.windows.transform(window, src.transform), + } + ) + + sub_nodata = nodata_mask_full[top:bottom, left:right] - # 将子区域数据包装成内存中的 rasterio 数据集 with MemoryFile() as memfile: with memfile.open(**sub_profile) as sub_dst: sub_dst.write(sub_bands) @@ -311,84 +475,144 @@ with rasterio.open(image_path) as src: stretch_params = compute_stretch_params(sub_bands) overview_bands, overview_step = make_overview_bands( - sub_bands, max_side=coarse_read_max_side + sub_bands, max_side=cfg["overview_max_side"] + ) + overview_img = _bands_to_pil( + overview_bands, stretch_params=stretch_params + ) + overview_prob = infer_prompt_prob( + coarse_processor, overview_img, cfg["prompt"] ) - overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params) - overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt) if overview_prob is None: - overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16) + overview_mask_small = np.zeros( + (overview_img.height, overview_img.width), dtype=bool + ) else: - overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16) - overview_mask_cpu = overview_prob_cpu > coarse_threshold + overview_mask_small = ( + (overview_prob > cfg["overview_threshold"]) + .detach() + .cpu() + .numpy() + ) - scale = overview_img.height / float(region_h) - band_radius_small = max(int(round(band_radius * scale)), 1) - band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy() + if sub_nodata.any(): + nodata_overview = sub_nodata[::overview_step, ::overview_step] + overview_mask_small[nodata_overview] = False - fine_probs_cpu = refine_tiles( - fine_processor, - sub_src, - prompt, - band_small, - tile_size=tile_size, - overlap=overlap, + coarse_probs = coarse_tile_segmentation( + processor=coarse_processor, + src=sub_src, + prompt=cfg["prompt"], + tile_size=cfg["coarse_tile_size"], + overlap=cfg["coarse_tile_overlap"], + downsample_factor=cfg["coarse_downsample_factor"], stretch_params=stretch_params, + nodata_mask=sub_nodata if sub_nodata.any() else None, + use_tqdm=use_tqdm, ) - fine_mask = fine_probs_cpu > final_threshold + coarse_mask_lr = coarse_probs > cfg["coarse_threshold"] + + overview_mask_lr = np.array( + Image.fromarray( + overview_mask_small.astype(np.uint8), mode="L" + ).resize( + (coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]), + resample=Image.NEAREST, + ) + ).astype(bool) + combined_mask_lr = coarse_mask_lr | overview_mask_lr + + band_radius_lr = max( + int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))), + 1, + ) + band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr) + + fine_probs = refine_tiles( + processor=fine_processor, + src=sub_src, + prompt=cfg["prompt"], + band_lr=band_lr, + downsample_factor=cfg["coarse_downsample_factor"], + tile_size=cfg["fine_tile_size"], + overlap=cfg["fine_overlap"], + stretch_params=stretch_params, + use_tqdm=use_tqdm, + ) + fine_mask = fine_probs > cfg["final_threshold"] coarse_mask_full = np.array( - Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize( + Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize( (region_w, region_h), resample=Image.NEAREST ) ).astype(bool) - sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8) + overview_mask_full = np.array( + Image.fromarray( + overview_mask_small.astype(np.uint8), mode="L" + ).resize((region_w, region_h), resample=Image.NEAREST) + ).astype(bool) + sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype( + np.uint8 + ) - # 合并到全尺寸掩码(重叠区域取最大值) full_mask[top:bottom, left:right] = np.maximum( full_mask[top:bottom, left:right], sub_mask_np ) - # 更新子区域进度条 - region_pbar.update(1) + log("\n后处理:填充内部NoData空洞...") + if nodata is not None: + if np.issubdtype(band1.dtype, np.floating): + nodata_mask = np.isclose(band1, float(nodata)) + else: + nodata_mask = band1 == nodata - # ========== 后处理 ========== - print("\n后处理:填充内部NoData空洞...") - if nodata is not None: - band1 = src.read(1) - if np.issubdtype(band1.dtype, np.floating): - nodata_mask = np.isclose(band1, float(nodata)) - else: - nodata_mask = (band1 == nodata) + labeled_mask, _ = ndimage.label(nodata_mask) + boundary_mask = np.zeros_like(nodata_mask, dtype=bool) + boundary_mask[0, :] = True + boundary_mask[-1, :] = True + boundary_mask[:, 0] = True + boundary_mask[:, -1] = True - labeled_mask, num_features = ndimage.label(nodata_mask) - boundary_mask = np.zeros_like(nodata_mask, dtype=bool) - boundary_mask[0, :] = True - boundary_mask[-1, :] = True - boundary_mask[:, 0] = True - boundary_mask[:, -1] = True + boundary_labels = set(labeled_mask[boundary_mask]) + internal_nodata_mask = np.isin( + labeled_mask, list(boundary_labels), invert=True + ) & nodata_mask - boundary_labels = set(labeled_mask[boundary_mask]) - internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask + if internal_nodata_mask.any(): + struct = ndimage.generate_binary_structure(2, 2) + internal_dilated = ndimage.binary_dilation( + internal_nodata_mask, structure=struct, iterations=7 + ) + full_mask[internal_dilated] = 1 + log( + f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}" + ) + else: + log(" 无内部NoData区域") - if internal_nodata_mask.any(): - struct = ndimage.generate_binary_structure(2, 2) - internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7) - full_mask[internal_dilated] = 1 - print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}") - else: - print(" 无内部NoData区域") + log("后处理:面积过滤...") + if cfg["min_area"] is not None or cfg["keep_largest_only"]: + original_count = np.sum(full_mask) + full_mask = filter_by_area( + full_mask, + min_area=cfg["min_area"], + keep_largest_only=cfg["keep_largest_only"], + ) + filtered_count = np.sum(full_mask) + log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}") - print("后处理:面积过滤...") - if min_area is not None or keep_largest_only: - original_count = np.sum(full_mask) - full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only) - filtered_count = np.sum(full_mask) - print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}") + profile = src.profile.copy() + profile.update(count=1, dtype="uint8", compress="lzw") + with rasterio.open(mask_output_path, "w", **profile) as dst: + dst.write(full_mask, 1) - # ========== 保存结果 ========== - profile = src.profile.copy() - profile.update(count=1, dtype="uint8", compress='lzw') - with rasterio.open(mask_output_path, "w", **profile) as dst: - dst.write(full_mask, 1) + log(f"分割完成,结果已保存至:{mask_output_path}") -print(f"分割完成,结果已保存至:{mask_output_path}") + +def main(): + matplotlib.use("TkAgg") + run_segmentation() + + +if __name__ == "__main__": + main()