"""
GUI launcher for water_V5 (GUI-only module).

Features:
- Configurable parameters
- Tooltips shown on hover
- Segmentation runs in a background thread so the interface stays responsive
"""

import threading
import tkinter as tk
import traceback
from tkinter import filedialog, messagebox, ttk


# Initial values shown in the GUI. The keys double as the configuration keys
# collected by ``WaterSegGUI._get_config`` and handed to the backend.
DEFAULTS = {
    # Input / output paths and the text prompt.
    "image_path": "",
    "mask_output_path": "",
    "prompt": "water body",
    # Overview (fallback) stage.
    "overview_max_side": 1400,
    "overview_threshold": 0.4,
    # Coarse segmentation stage.
    "coarse_threshold": 0.5,
    "coarse_tile_size": 4096,
    "coarse_tile_overlap": 256,
    "coarse_downsample_factor": 4,
    "band_radius": 64,
    # Fine refinement stage.
    "final_threshold": 0.5,
    "fine_tile_size": 2048,
    "fine_overlap": 256,
    # Region splitting.
    "num_splits_y": 2,
    "num_splits_x": 3,
    "region_overlap": 256,
    # Post-processing.
    "min_area": 5000,
    "keep_largest_only": False,
}


def _lazy_import_backend():
    """Import and return the ``water_V5`` backend module.

    The import happens inside the function, so the backend is only loaded
    when this function is first called rather than when the GUI module is
    imported.

    Returns:
        The imported ``water_V5`` module.

    Raises:
        ImportError: If ``water_V5`` cannot be imported.
    """
    import water_V5

    return water_V5


class Tooltip:
    """Show a small borderless pop-up with help text while the pointer hovers a widget.

    Args:
        widget: The widget to attach the tooltip to. It must provide ``bind``
            and, once the tooltip is shown, ``winfo_rootx`` / ``winfo_rooty``
            / ``winfo_height``.
        text: The text to display. When empty, no pop-up is ever created.
    """

    def __init__(self, widget, text):
        self.widget = widget
        self.text = text
        # The pop-up window while visible, otherwise None.
        self.tip = None
        # add="+" appends these handlers instead of replacing any
        # <Enter>/<Leave> bindings the widget already has.
        widget.bind("<Enter>", self._show, add="+")
        widget.bind("<Leave>", self._hide, add="+")

    def _show(self, _event=None):
        """Create the pop-up just below the widget, unless it is already shown or the text is empty."""
        if self.tip is not None or not self.text:
            return

        # Screen coordinates: slightly right of the widget's left edge and
        # just under its bottom edge.
        x = self.widget.winfo_rootx() + 10
        y = self.widget.winfo_rooty() + self.widget.winfo_height() + 5

        self.tip = tk.Toplevel(self.widget)
        # Remove window-manager decorations (title bar, border).
        self.tip.wm_overrideredirect(True)
        self.tip.wm_geometry(f"+{x}+{y}")

        label = tk.Label(
            self.tip,
            text=self.text,
            justify="left",
            background="#ffffe0",
            relief="solid",
            borderwidth=1,
            font=("Segoe UI", 9),
        )
        label.pack(ipadx=6, ipady=4)

    def _hide(self, _event=None):
        """Destroy the pop-up if it is currently shown."""
        if self.tip is not None:
            self.tip.destroy()
            self.tip = None


class WaterSegGUI(tk.Tk):
    """Main window of the water-segmentation launcher.

    The window lets the user pick input/output paths, a text prompt and the
    numeric parameters, then runs ``water_V5.run_segmentation`` in a
    background thread while showing progress and log output.
    """

    # Parameter groups shown in the left pane, as
    # (frame title, ((config key, value type, tooltip text), ...)).
    # This table drives both the widgets built by ``_add_group`` and the
    # string -> number conversion in ``_get_config``, so the two cannot drift.
    _PARAM_GROUPS = (
        (
            "Overview 兜底",
            (
                ("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"),
                ("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"),
            ),
        ),
        (
            "Coarse 粗分割",
            (
                ("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"),
                ("coarse_tile_size", int, "粗分割分块大小(原图像素)。"),
                ("coarse_tile_overlap", int, "粗分割块重叠像素。"),
                ("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢;4 折中;8 更快更粗。"),
                ("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"),
            ),
        ),
        (
            "Fine 精修",
            (
                ("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"),
                ("fine_tile_size", int, "精修分块大小(原图像素)。"),
                ("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"),
            ),
        ),
        (
            "Region 子区域",
            (
                ("num_splits_y", int, "纵向切分数。"),
                ("num_splits_x", int, "横向切分数。"),
                ("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"),
            ),
        ),
        (
            "Post 后处理",
            (
                ("min_area", int, "移除小碎片的最小连通域面积(像素)。"),
                ("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"),
            ),
        ),
    )

    def __init__(self):
        super().__init__()
        self.title("SAM3 水体分割 GUI")
        self.geometry("980x720")

        # Set by the stop button; passed to the backend as ``stop_event``.
        self.stop_event = threading.Event()
        # The background thread of the current/last run, or None before the first run.
        self.worker = None

        # Maps each configuration key to the Tk variable backing its widget.
        self.vars = {}
        self._build_ui()

    # ------------------------------------------------------------------ UI

    def _build_ui(self):
        """Create all widgets: inputs on top, parameters left, controls and log right."""
        container = ttk.Frame(self)
        container.pack(fill="both", expand=True, padx=10, pady=10)

        self._build_inputs(container)

        panes = ttk.Panedwindow(container, orient="horizontal")
        panes.pack(fill="both", expand=True, pady=(10, 0))

        left = ttk.Frame(panes)
        right = ttk.Frame(panes)
        panes.add(left, weight=3)
        panes.add(right, weight=2)

        self._build_params(left)
        self._build_controls(right)
        self._build_log(right)

    def _build_inputs(self, container):
        """Build the input-path row, the output-path row and the prompt combobox."""
        top = ttk.Frame(container)
        top.pack(fill="x")

        self._add_path_row(
            top,
            "image_path",
            "输入影像 (TIF)",
            DEFAULTS["image_path"],
            "选择要分割的遥感影像 GeoTIFF。",
            is_save=False,
        )
        self._add_path_row(
            top,
            "mask_output_path",
            "输出掩膜 (TIF)",
            DEFAULTS["mask_output_path"],
            "输出单通道 uint8 掩膜(0/1)。",
            is_save=True,
        )

        row = ttk.Frame(top)
        row.pack(fill="x", pady=(6, 0))
        ttk.Label(row, text="prompt").pack(side="left")
        prompt_var = tk.StringVar(value=DEFAULTS["prompt"])
        self.vars["prompt"] = prompt_var
        prompt_box = ttk.Combobox(
            row,
            textvariable=prompt_var,
            values=["water body", "water", "river", "lake", "reservoir"],
        )
        prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0))
        Tooltip(prompt_box, "文本提示。建议优先尝试:water / river / lake。")

    def _build_params(self, left):
        """Build one labelled frame per entry of ``_PARAM_GROUPS``."""
        params = ttk.Frame(left)
        params.pack(fill="both", expand=True)
        for title, items in self._PARAM_GROUPS:
            self._add_group(params, title, items)

    def _build_controls(self, right):
        """Build the progress bar, the status line and the run/stop buttons."""
        ctrl = ttk.Frame(right)
        ctrl.pack(fill="x")

        self.progress = ttk.Progressbar(ctrl, maximum=100)
        self.progress.pack(fill="x")
        self.status = tk.StringVar(value="就绪")
        ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0))

        btn_row = ttk.Frame(ctrl)
        btn_row.pack(fill="x", pady=(8, 0))
        self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run)
        self.run_btn.pack(side="left")
        # The stop button is only enabled while a run is in progress.
        self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled")
        self.stop_btn.pack(side="left", padx=(8, 0))

    def _build_log(self, right):
        """Build the scrollable log text area."""
        log_frame = ttk.Frame(right)
        log_frame.pack(fill="both", expand=True, pady=(10, 0))
        ttk.Label(log_frame, text="日志").pack(anchor="w")
        self.log_text = tk.Text(log_frame, height=20, wrap="word")
        self.log_text.pack(side="left", fill="both", expand=True)
        scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview)
        scroll.pack(side="right", fill="y")
        self.log_text.configure(yscrollcommand=scroll.set)

    def _add_path_row(self, parent, key, label, default, tip, is_save):
        """Add a "label / entry / browse button" row bound to ``self.vars[key]``.

        Args:
            parent: Container the row is packed into.
            key: Configuration key under which the path variable is stored.
            label: Text of the label left of the entry.
            default: Initial content of the entry.
            tip: Tooltip text for the entry.
            is_save: True to browse with a "save as" dialog, False for "open".
        """
        row = ttk.Frame(parent)
        row.pack(fill="x", pady=(0, 6))
        ttk.Label(row, text=label).pack(side="left")

        var = tk.StringVar(value=default)
        self.vars[key] = var

        entry = ttk.Entry(row, textvariable=var)
        entry.pack(side="left", fill="x", expand=True, padx=(10, 0))
        Tooltip(entry, tip)

        btn = ttk.Button(
            row,
            text="浏览",
            command=lambda: self._browse_path(var, is_save=is_save),
        )
        btn.pack(side="left", padx=(8, 0))

    def _add_group(self, parent, title, items):
        """Add a labelled frame with one row per parameter.

        Args:
            parent: Container the frame is packed into.
            title: Caption of the labelled frame.
            items: Iterable of ``(key, type, tooltip)``. ``bool`` parameters
                get a checkbox; all other types get a text entry. Initial
                values come from ``DEFAULTS``.
        """
        frame = ttk.LabelFrame(parent, text=title)
        frame.pack(fill="x", pady=(0, 10))

        for r, (key, typ, tip) in enumerate(items):
            ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4)
            if typ is bool:
                var = tk.BooleanVar(value=bool(DEFAULTS.get(key)))
                widget = ttk.Checkbutton(frame, variable=var)
            else:
                var = tk.StringVar(value=str(DEFAULTS.get(key, "")))
                widget = ttk.Entry(frame, textvariable=var, width=14)
            self.vars[key] = var
            widget.grid(row=r, column=1, sticky="w", pady=4)
            Tooltip(widget, tip)

        frame.grid_columnconfigure(1, weight=1)

    def _browse_path(self, var, is_save):
        """Open a file dialog and store the chosen path in ``var``; do nothing on cancel."""
        file_types = [("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")]
        if is_save:
            p = filedialog.asksaveasfilename(
                title="选择输出文件",
                defaultextension=".tif",
                filetypes=file_types,
            )
        else:
            p = filedialog.askopenfilename(
                title="选择输入文件",
                filetypes=file_types,
            )
        if p:
            var.set(p)

    # ------------------------------------------------- logging / progress

    def _append_log(self, msg):
        """Append ``msg`` as a new line to the log widget and scroll to the end."""
        self.log_text.insert("end", msg + "\n")
        self.log_text.see("end")

    def _log(self, msg):
        """Schedule ``msg`` to be appended to the log by the Tk event loop.

        This is passed to the backend as ``log_callback`` and is therefore
        called from the worker thread.
        NOTE(review): this relies on ``after`` being callable from a
        non-main thread - confirm on the targeted Tcl/Tk builds.
        """
        # Convert on the calling thread so the scheduled callback only touches Tk.
        self.after(0, self._append_log, str(msg))

    def _progress(self, stage, cur, total):
        """Schedule a progress-bar and status update (``progress_callback`` for the backend).

        Args:
            stage: Name of the current stage, shown in the status line.
            cur: Number of completed steps.
            total: Total number of steps; the bar is only updated when > 0.
        """

        def _update():
            if total > 0:
                self.progress["value"] = (cur / total) * 100.0
            self.status.set(f"{stage}: {cur}/{total}")

        self.after(0, _update)

    # ------------------------------------------------------- configuration

    def _get_config(self):
        """Collect the widget values into the configuration dict for the backend.

        Numeric parameters (as typed in ``_PARAM_GROUPS``) are converted from
        their entry text to ``int`` / ``float``. A blank entry is passed
        through as an empty string.
        NOTE(review): confirm how the backend treats an empty-string value.

        Returns:
            dict with one item per key in ``self.vars`` plus
            ``"use_tqdm": False`` and ``"device": None``.

        Raises:
            ValueError: If a numeric entry does not contain a valid number;
                the message names the offending parameter.
        """
        cfg = {}
        for key, var in self.vars.items():
            value = var.get()
            cfg[key] = bool(value) if isinstance(var, tk.BooleanVar) else value

        for _title, items in self._PARAM_GROUPS:
            for key, typ, _tip in items:
                if typ is bool or key not in cfg:
                    continue
                text = str(cfg[key]).strip()
                if not text:
                    cfg[key] = ""
                    continue
                try:
                    number = float(text)
                    # Go through float first so that entries such as "4.0"
                    # are accepted for integer parameters.
                    cfg[key] = int(number) if typ is int else number
                except (ValueError, OverflowError) as exc:
                    # OverflowError: int(float("inf")).
                    raise ValueError(f"参数 {key} 的值无效: {text!r}") from exc

        cfg["use_tqdm"] = False
        cfg["device"] = None
        return cfg

    # ------------------------------------------------------------ actions

    def _on_run(self):
        """Validate the inputs and start the segmentation in a background thread.

        Does nothing if a previous run is still alive. Invalid numeric
        entries or missing paths are reported in an error dialog and no run
        is started.
        """
        if self.worker is not None and self.worker.is_alive():
            return

        try:
            cfg = self._get_config()
        except ValueError as exc:
            # Without this, the exception would escape the Tk callback and
            # the user would get no feedback in the window.
            messagebox.showerror("错误", str(exc))
            return

        if not cfg.get("image_path") or not cfg.get("mask_output_path"):
            messagebox.showerror("错误", "请先选择输入影像与输出路径。")
            return

        self.stop_event.clear()
        self.run_btn.configure(state="disabled")
        self.stop_btn.configure(state="normal")
        self.progress["value"] = 0
        self.status.set("启动中...")
        self._append_log("开始运行...")

        def _work():
            failed = False
            try:
                water_v5 = _lazy_import_backend()
                water_v5.run_segmentation(
                    config=cfg,
                    progress_callback=self._progress,
                    log_callback=self._log,
                    stop_event=self.stop_event,
                )
            except Exception as e:
                # Top-level boundary of the worker thread: report everything
                # in the log, including the traceback, instead of losing it.
                failed = True
                self._log(f"错误: {e}")
                self._log(traceback.format_exc().rstrip())
            finally:
                # Report the real outcome instead of always claiming success.
                if self.stop_event.is_set():
                    final_status = "已停止"
                elif failed:
                    final_status = "出错"
                else:
                    final_status = "完成"
                self.after(0, self._finish_run, final_status)

        self.worker = threading.Thread(target=_work, daemon=True)
        self.worker.start()

    def _finish_run(self, final_status):
        """Restore the button states after a run and show ``final_status``."""
        self.run_btn.configure(state="normal")
        self.stop_btn.configure(state="disabled")
        self.status.set(final_status)

    def _on_stop(self):
        """Request the running segmentation to stop by setting ``stop_event``."""
        self.stop_event.set()
        self.status.set("停止中...")


def main():
    """Create the launcher window and run the Tk event loop until the window is closed."""
    WaterSegGUI().mainloop()


if __name__ == "__main__":
    # Start the GUI only when this file is executed as a script, not on import.
    main()