Files
water-body-segmentation/water_GUI.py
2026-03-10 17:29:24 +08:00

351 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
water_V5 的图形界面启动器(纯 GUI 文件)。
特点:
- 参数可配置
- 悬停显示提示
- 后台线程运行分割,界面不假死
"""
import threading
import tkinter as tk
import traceback
from tkinter import filedialog, messagebox, ttk
# Initial value for every GUI field. The keys are also the config keys that
# WaterSegGUI collects and hands to the backend; the value types
# (str / int / float / bool) match the entry types declared in the GUI groups.
DEFAULTS = dict(
    # Input / output paths: empty until the user picks files.
    image_path="",
    mask_output_path="",
    # Text prompt for the segmentation model.
    prompt="water body",
    # "Overview" fallback pass.
    overview_max_side=1400,
    overview_threshold=0.4,
    # "Coarse" pass.
    coarse_threshold=0.5,
    coarse_tile_size=4096,
    coarse_tile_overlap=256,
    coarse_downsample_factor=4,
    band_radius=64,
    # "Fine" refinement pass.
    final_threshold=0.5,
    fine_tile_size=2048,
    fine_overlap=256,
    # "Region" split of the image.
    num_splits_y=2,
    num_splits_x=3,
    region_overlap=256,
    # "Post" processing.
    min_area=5000,
    keep_largest_only=False,
)
def _lazy_import_backend():
import water_V5
return water_V5
class Tooltip:
    """Hover tooltip that shows ``text`` in a borderless popup below ``widget``.

    The popup is created on ``<Enter>`` and destroyed on ``<Leave>``; at most
    one popup exists per Tooltip at any time. An empty ``text`` disables it.
    """

    def __init__(self, widget, text):
        self.widget, self.text, self.tip = widget, text, None
        for sequence, handler in (("<Enter>", self._show), ("<Leave>", self._hide)):
            widget.bind(sequence, handler)

    def _show(self, _event=None):
        """Open the popup unless one is already visible or there is no text."""
        already_open = self.tip is not None
        if already_open or not self.text:
            return
        anchor = self.widget
        # Place the popup 10 px right of the widget's left edge and 5 px below
        # its bottom edge, in screen coordinates.
        pos_x = anchor.winfo_rootx() + 10
        pos_y = anchor.winfo_rooty() + anchor.winfo_height() + 5
        self.tip = popup = tk.Toplevel(anchor)
        popup.wm_overrideredirect(True)  # no title bar / window decorations
        popup.wm_geometry(f"+{pos_x}+{pos_y}")
        body = tk.Label(
            popup,
            text=self.text,
            justify="left",
            background="#ffffe0",
            relief="solid",
            borderwidth=1,
            font=("Segoe UI", 9),
        )
        body.pack(ipadx=6, ipady=4)

    def _hide(self, _event=None):
        """Destroy the popup if it is currently shown."""
        if self.tip is None:
            return
        self.tip.destroy()
        self.tip = None
class WaterSegGUI(tk.Tk):
    """Main window: path pickers, parameter form, progress bar, log and controls.

    Segmentation runs in a daemon worker thread so the UI stays responsive;
    every widget update coming from that thread is scheduled back onto the Tk
    event loop with ``after(0, ...)`` (see ``_post``).
    """

    # Config keys that _get_config converts from entry text to numbers. They
    # must stay in sync with the int / float entries declared in _build_ui.
    _INT_KEYS = (
        "overview_max_side",
        "band_radius",
        "fine_tile_size",
        "fine_overlap",
        "coarse_tile_size",
        "coarse_tile_overlap",
        "coarse_downsample_factor",
        "num_splits_y",
        "num_splits_x",
        "region_overlap",
        "min_area",
    )
    _FLOAT_KEYS = ("overview_threshold", "coarse_threshold", "final_threshold")

    def __init__(self):
        super().__init__()
        self.title("SAM3 水体分割 GUI")
        self.geometry("980x720")
        # Set by the Stop button / window close; passed to the backend.
        self.stop_event = threading.Event()
        self.worker = None
        # Config key -> Tk variable bound to the matching input widget.
        self.vars = {}
        self._build_ui()
        # Closing the window mid-run should also ask the backend to stop.
        self.protocol("WM_DELETE_WINDOW", self._on_close)

    def _build_ui(self):
        """Create all widgets: path rows, prompt box, parameter groups, controls, log."""
        container = ttk.Frame(self)
        container.pack(fill="both", expand=True, padx=10, pady=10)
        top = ttk.Frame(container)
        top.pack(fill="x")
        self._add_path_row(
            top,
            "image_path",
            "输入影像 (TIF)",
            DEFAULTS["image_path"],
            "选择要分割的遥感影像 GeoTIFF。",
            is_save=False,
        )
        self._add_path_row(
            top,
            "mask_output_path",
            "输出掩膜 (TIF)",
            DEFAULTS["mask_output_path"],
            "输出单通道 uint8 掩膜(0/1)。",
            is_save=True,
        )
        row = ttk.Frame(top)
        row.pack(fill="x", pady=(6, 0))
        ttk.Label(row, text="prompt").pack(side="left")
        prompt_var = tk.StringVar(value=DEFAULTS["prompt"])
        self.vars["prompt"] = prompt_var
        prompt_box = ttk.Combobox(
            row,
            textvariable=prompt_var,
            values=["water body", "water", "river", "lake", "reservoir"],
        )
        prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0))
        Tooltip(prompt_box, "文本提示。建议优先尝试water / river / lake。")

        # Left pane: parameter form; right pane: controls and log.
        panes = ttk.Panedwindow(container, orient="horizontal")
        panes.pack(fill="both", expand=True, pady=(10, 0))
        left = ttk.Frame(panes)
        right = ttk.Frame(panes)
        panes.add(left, weight=3)
        panes.add(right, weight=2)
        params = ttk.Frame(left)
        params.pack(fill="both", expand=True)
        self._add_group(
            params,
            "Overview 兜底",
            [
                ("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"),
                ("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"),
            ],
        )
        self._add_group(
            params,
            "Coarse 粗分割",
            [
                ("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"),
                ("coarse_tile_size", int, "粗分割分块大小(原图像素)。"),
                ("coarse_tile_overlap", int, "粗分割块重叠像素。"),
                ("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢,4 折中,8 更快更粗。"),
                ("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"),
            ],
        )
        self._add_group(
            params,
            "Fine 精修",
            [
                ("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"),
                ("fine_tile_size", int, "精修分块大小(原图像素)。"),
                ("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"),
            ],
        )
        self._add_group(
            params,
            "Region 子区域",
            [
                ("num_splits_y", int, "纵向切分数。"),
                ("num_splits_x", int, "横向切分数。"),
                ("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"),
            ],
        )
        self._add_group(
            params,
            "Post 后处理",
            [
                ("min_area", int, "移除小碎片的最小连通域面积(像素)。"),
                ("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"),
            ],
        )
        ctrl = ttk.Frame(right)
        ctrl.pack(fill="x")
        self.progress = ttk.Progressbar(ctrl, maximum=100)
        self.progress.pack(fill="x")
        self.status = tk.StringVar(value="就绪")
        ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0))
        btn_row = ttk.Frame(ctrl)
        btn_row.pack(fill="x", pady=(8, 0))
        self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run)
        self.run_btn.pack(side="left")
        self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled")
        self.stop_btn.pack(side="left", padx=(8, 0))
        log_frame = ttk.Frame(right)
        log_frame.pack(fill="both", expand=True, pady=(10, 0))
        ttk.Label(log_frame, text="日志").pack(anchor="w")
        self.log_text = tk.Text(log_frame, height=20, wrap="word")
        self.log_text.pack(side="left", fill="both", expand=True)
        scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview)
        scroll.pack(side="right", fill="y")
        self.log_text.configure(yscrollcommand=scroll.set)

    def _add_path_row(self, parent, key, label, default, tip, is_save):
        """Add a labelled path entry plus a browse button bound to ``self.vars[key]``.

        ``is_save`` selects a save-as dialog (output) instead of an open dialog (input).
        """
        row = ttk.Frame(parent)
        row.pack(fill="x", pady=(0, 6))
        ttk.Label(row, text=label).pack(side="left")
        var = tk.StringVar(value=default)
        self.vars[key] = var
        entry = ttk.Entry(row, textvariable=var)
        entry.pack(side="left", fill="x", expand=True, padx=(10, 0))
        Tooltip(entry, tip)
        btn = ttk.Button(
            row,
            text="浏览",
            command=lambda: self._browse_path(var, is_save=is_save),
        )
        btn.pack(side="left", padx=(8, 0))

    def _add_group(self, parent, title, items):
        """Add a labelled group of parameter inputs.

        ``items`` is a list of ``(key, type, tooltip)``: ``bool`` keys get a
        checkbox backed by a BooleanVar, every other type a text entry backed
        by a StringVar (converted to a number later in ``_get_config``).
        """
        frame = ttk.LabelFrame(parent, text=title)
        frame.pack(fill="x", pady=(0, 10))
        for r, (key, typ, tip) in enumerate(items):
            ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4)
            if typ is bool:
                var = tk.BooleanVar(value=bool(DEFAULTS.get(key)))
                self.vars[key] = var
                w = ttk.Checkbutton(frame, variable=var)
                w.grid(row=r, column=1, sticky="w", pady=4)
                Tooltip(w, tip)
            else:
                var = tk.StringVar(value=str(DEFAULTS.get(key, "")))
                self.vars[key] = var
                entry = ttk.Entry(frame, textvariable=var, width=14)
                entry.grid(row=r, column=1, sticky="w", pady=4)
                Tooltip(entry, tip)
        frame.grid_columnconfigure(1, weight=1)

    def _browse_path(self, var, is_save):
        """Show a GeoTIFF file dialog and store the chosen path in ``var`` (if any)."""
        if is_save:
            p = filedialog.asksaveasfilename(
                title="选择输出文件",
                defaultextension=".tif",
                filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
            )
        else:
            p = filedialog.askopenfilename(
                title="选择输入文件",
                filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
            )
        if p:
            var.set(p)

    def _append_log(self, msg):
        """Append one line to the log pane and scroll to it (main thread only)."""
        self.log_text.insert("end", msg + "\n")
        self.log_text.see("end")

    def _post(self, func, *args):
        """Schedule ``func(*args)`` on the Tk event loop; callable from any thread.

        After the window has been destroyed (e.g. closed mid-run) Tk rejects
        the call; the update is dropped because there is nothing left to draw.
        """
        try:
            self.after(0, func, *args)
        except (RuntimeError, tk.TclError):
            pass

    def _log(self, msg):
        """Backend log callback (worker thread): queue ``msg`` for the log pane."""
        self._post(self._append_log, str(msg))

    def _progress(self, stage, cur, total):
        """Backend progress callback (worker thread): update bar and status text."""

        def _update():
            if total > 0:
                self.progress["value"] = min(100.0, max(0.0, (cur / total) * 100.0))
            self.status.set(f"{stage}: {cur}/{total}")

        self._post(_update)

    def _get_config(self):
        """Collect the form into a backend config dict.

        Numeric fields are converted to int/float; an empty numeric field is
        passed through as ``""`` unchanged.

        Raises:
            ValueError: a numeric field holds text that is not a finite-size
                number; the message names the offending key.
        """
        cfg = {}
        for key, var in self.vars.items():
            cfg[key] = bool(var.get()) if isinstance(var, tk.BooleanVar) else var.get()
        conversions = (
            (self._INT_KEYS, lambda s: int(float(s))),  # accepts "1400.0" / "1e3"
            (self._FLOAT_KEYS, float),
        )
        for keys, convert in conversions:
            for key in keys:
                raw = cfg.get(key, "")
                if raw == "":
                    continue
                try:
                    cfg[key] = convert(raw)
                except (ValueError, OverflowError) as err:
                    # OverflowError: int(float("inf")).
                    raise ValueError(f"参数 {key} 不是有效数字: {raw!r}") from err
        cfg["use_tqdm"] = False
        cfg["device"] = None
        return cfg

    def _on_run(self):
        """Validate the form and start a segmentation run in a worker thread."""
        if self.worker is not None and self.worker.is_alive():
            return
        try:
            cfg = self._get_config()
        except ValueError as err:
            messagebox.showerror("错误", str(err))
            return
        if not cfg.get("image_path") or not cfg.get("mask_output_path"):
            messagebox.showerror("错误", "请先选择输入影像与输出路径。")
            return
        self.stop_event.clear()
        self.run_btn.configure(state="disabled")
        self.stop_btn.configure(state="normal")
        self.progress["value"] = 0
        self.status.set("启动中...")
        self._append_log("开始运行...")
        self.worker = threading.Thread(target=self._run_worker, args=(cfg,), daemon=True)
        self.worker.start()

    def _run_worker(self, cfg):
        """Worker-thread body: run the backend, then report the outcome to the UI."""
        failed = False
        try:
            backend = _lazy_import_backend()
            backend.run_segmentation(
                config=cfg,
                progress_callback=self._progress,
                log_callback=self._log,
                stop_event=self.stop_event,
            )
        except Exception as e:
            # Top-level boundary of the thread: surface the error in the log
            # (with traceback) instead of letting the thread die silently.
            failed = True
            self._log(f"错误: {e}")
            self._log(traceback.format_exc().rstrip())
        finally:
            if self.stop_event.is_set():
                outcome = "已停止"
            elif failed:
                outcome = "失败"
            else:
                outcome = "完成"
            self._post(self._finish, outcome)

    def _finish(self, outcome):
        """Re-enable Run, disable Stop and show the final status (main thread)."""
        self.run_btn.configure(state="normal")
        self.stop_btn.configure(state="disabled")
        self.status.set(outcome)

    def _on_stop(self):
        """Ask the running backend to stop; it is expected to poll ``stop_event``."""
        self.stop_event.set()
        self.status.set("停止中...")

    def _on_close(self):
        """Window-close handler: signal any running worker to stop, then close."""
        self.stop_event.set()
        self.destroy()
def main():
    """Open the segmentation GUI and block in the Tk event loop until it closes."""
    app = WaterSegGUI()
    app.mainloop()


if __name__ == "__main__":
    main()