GUI
This commit is contained in:
117
README_water_GUI.md
Normal file
117
README_water_GUI.md
Normal file
@ -0,0 +1,117 @@
|
||||
# water_GUI.py 使用说明(SAM3 水体分割 GUI)
|
||||
|
||||
本说明文档对应脚本:[water_GUI.py](file:///e:/code/sam3-main/sam3-main/water_GUI.py)。
|
||||
该脚本是 **纯 GUI 启动器**:界面负责参数编辑、提示说明、日志与进度展示;真正的分割计算由 [water_V5.py](file:///e:/code/sam3-main/sam3-main/water_V5.py) 提供的 `run_segmentation()` 执行,并在点击“运行”时才会被懒加载导入。
|
||||
|
||||
## 1. 功能概览
|
||||
|
||||
- 选择输入 GeoTIFF(待分割遥感影像)与输出掩膜 GeoTIFF 路径
|
||||
- 编辑分割参数(Overview/Coarse/Fine/Region/Post)
|
||||
- 鼠标悬停在参数输入框上显示详细提示(tooltip)
|
||||
- 后台线程执行分割,GUI 不会卡死
|
||||
- 右侧日志窗口实时输出运行信息
|
||||
- 进度条显示子区域处理进度
|
||||
- 支持“停止”(会在子区域边界处响应)
|
||||
|
||||
## 2. 运行方式
|
||||
|
||||
在项目目录下运行:
|
||||
|
||||
```bash
|
||||
python e:\code\sam3-main\sam3-main\water_GUI.py
|
||||
```
|
||||
|
||||
运行后:
|
||||
- 点击“浏览”选择输入影像(`.tif/.tiff`)
|
||||
- 点击“浏览”选择输出掩膜路径(建议 `.tif`)
|
||||
- 调整参数(可保持默认)
|
||||
- 点击“运行”
|
||||
|
||||
## 3. 依赖与环境
|
||||
|
||||
GUI 本身依赖:
|
||||
- Python 标准库 `tkinter`
|
||||
|
||||
执行分割时会懒加载并依赖(来自 water_V5.py):
|
||||
- torch(建议 CUDA 可用)
|
||||
- rasterio
|
||||
- numpy
|
||||
- scipy
|
||||
- pillow
|
||||
- tqdm
|
||||
- 以及 sam3 模型代码与权重(由项目现有加载逻辑处理)
|
||||
|
||||
## 4. 参数说明(界面分组)
|
||||
|
||||
### 4.1 Overview 兜底
|
||||
|
||||
用于减少“中心水域漏检”。Overview 会对每个子区域做一次低分辨率全局预测,并在最终输出中与 coarse/fine 结果取并集。
|
||||
|
||||
- `overview_max_side`:兜底预测的最大边长。越大越准越慢。
|
||||
- `overview_threshold`:兜底阈值。偏低可减少漏检,但误检会增加。
|
||||
|
||||
### 4.2 Coarse 粗分割
|
||||
|
||||
粗分割用于提高召回、生成精修的边缘带(band),并与 overview 一起决定精修区域范围。
|
||||
|
||||
- `coarse_threshold`:粗分割阈值。偏低提高召回,偏高减少误检。
|
||||
- `coarse_tile_size`:粗分割分块大小(原图像素)。
|
||||
- `coarse_tile_overlap`:粗分割块重叠像素。
|
||||
- `coarse_downsample_factor`:粗分割输出降采样比例。2 更细更慢;4 折中;8 更快更粗。
|
||||
- `band_radius`:边缘带半径(原图像素)。越大精修范围越大,速度越慢。
|
||||
|
||||
### 4.3 Fine 精修
|
||||
|
||||
只对边缘带覆盖到的区域进行精细分割,补足边界细节。
|
||||
|
||||
- `final_threshold`:最终阈值。偏低更易连通,偏高更干净。
|
||||
- `fine_tile_size`:精修分块大小(原图像素)。一般 1536 更稳,2048 更快但更吃显存/更慢。
|
||||
- `fine_overlap`:精修块重叠像素。越大接缝越少但更慢。
|
||||
|
||||
### 4.4 Region 子区域
|
||||
|
||||
将整图划分成多个有重叠的子区域,降低内存峰值。
|
||||
|
||||
- `num_splits_y`:纵向切分数。
|
||||
- `num_splits_x`:横向切分数。
|
||||
- `region_overlap`:子区域重叠像素,减少子区域边界断裂。
|
||||
|
||||
### 4.5 Post 后处理
|
||||
|
||||
- `min_area`:移除小碎片的最小连通域面积(像素)。
|
||||
- `keep_largest_only`:只保留最大连通域(适合只要最大水体)。
|
||||
|
||||
## 5. 输出说明
|
||||
|
||||
输出文件为单通道 `uint8` 掩膜 GeoTIFF:
|
||||
- 0:非水体
|
||||
- 1:水体
|
||||
|
||||
输出文件的空间参考(CRS/transform)来自输入影像的 profile。
|
||||
|
||||
## 6. 使用建议(16GB 显存)
|
||||
|
||||
优先按以下顺序调参:
|
||||
- 先把 `fine_tile_size` 从 2048 降到 1536(更稳)
|
||||
- `fine_overlap` 从 256 降到 128(更快,可能稍有拼接缝)
|
||||
- `coarse_downsample_factor` 从 4 改 2(更细更慢,但中心/边缘召回更稳)
|
||||
- `overview_threshold` 从 0.4 降到 0.35(召回更强,误检更可能增加)
|
||||
|
||||
## 7. 常见问题
|
||||
|
||||
### 7.1 点击“运行”后报 `ModuleNotFoundError: water_V5`
|
||||
|
||||
确保 `water_GUI.py` 与 `water_V5.py` 在同一目录:
|
||||
- `e:\code\sam3-main\sam3-main\water_GUI.py`
|
||||
- `e:\code\sam3-main\sam3-main\water_V5.py`
|
||||
|
||||
并从该目录执行运行命令,或将该目录加入 Python 路径。
|
||||
|
||||
### 7.2 界面正常但运行很慢
|
||||
|
||||
这通常是正常的:粗分割/精修会对大量瓦片做模型推理。优先调小 `fine_tile_size`、`fine_overlap`,并提高 `coarse_downsample_factor`。
|
||||
|
||||
### 7.3 “停止”不立即生效
|
||||
|
||||
停止信号会在子区域边界处检查,因此可能需要等待当前子区域处理完成后才退出。
|
||||
|
||||
155
tif_caijain.py
Normal file
155
tif_caijain.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""
|
||||
使用二值掩膜 TIF 文件(值为1的区域需要去除)对数据 TIF 文件进行掩膜。
|
||||
输入:
|
||||
data_tif: 要掩膜的数据文件路径
|
||||
mask_tif: 二值掩膜文件路径(值为1表示需要去除的区域)
|
||||
输出:
|
||||
掩膜后的数据 TIF 文件,仅将掩膜对应位置设为 NoData
|
||||
要求:
|
||||
两个 TIF 文件具有相同的投影、分辨率、范围和尺寸(精确对齐),
|
||||
否则程序将报错或行为未定义。
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.windows import Window
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def mask_data_by_binary_mask(
|
||||
data_path,
|
||||
mask_path,
|
||||
output_path=None,
|
||||
remove_value=1,
|
||||
nodata_value=None,
|
||||
tile_size=4096,
|
||||
):
|
||||
"""使用二值掩膜 TIF 对数据 TIF 进行掩膜。
|
||||
|
||||
将数据 TIF 中对应掩膜值等于 remove_value 的像素设为 NoData,其余保留。
|
||||
|
||||
性能建议:
|
||||
- 若数据源是 tiled GeoTIFF,可将 tile_size 设为 0 以按源文件块窗口遍历(通常更快)。
|
||||
"""
|
||||
data_path = Path(data_path)
|
||||
mask_path = Path(mask_path)
|
||||
|
||||
if output_path is None:
|
||||
output_path = data_path.parent / f"{data_path.stem}_masked{data_path.suffix}"
|
||||
else:
|
||||
output_path = Path(output_path)
|
||||
|
||||
logger.info(f"数据文件: {data_path.name}")
|
||||
logger.info(f"掩膜文件: {mask_path.name}")
|
||||
logger.info(f"去除掩膜值: {remove_value}")
|
||||
|
||||
with rasterio.Env(GDAL_NUM_THREADS="ALL_CPUS"):
|
||||
with rasterio.open(data_path) as src_data, rasterio.open(mask_path) as src_mask:
|
||||
if src_data.crs != src_mask.crs:
|
||||
raise ValueError("数据与掩膜的 CRS 不一致,请先统一投影。")
|
||||
if src_data.transform != src_mask.transform:
|
||||
logger.warning(
|
||||
"数据与掩膜的地理变换不一致,可能未精确对齐,继续处理可能存在风险。"
|
||||
)
|
||||
if (src_data.width, src_data.height) != (src_mask.width, src_mask.height):
|
||||
raise ValueError("数据与掩膜的尺寸不一致,无法直接按像素对应掩膜。")
|
||||
|
||||
# 确定输出 NoData 值(并尽量匹配数据 dtype,避免隐式类型转换带来的开销)
|
||||
if nodata_value is None:
|
||||
nodata_value = src_data.nodata if src_data.nodata is not None else 0
|
||||
try:
|
||||
nodata_value_cast = np.array(
|
||||
nodata_value, dtype=src_data.dtypes[0]
|
||||
).item()
|
||||
except Exception:
|
||||
nodata_value_cast = nodata_value
|
||||
|
||||
# 创建输出元数据:基于数据源的元数据,更新 nodata 和压缩选项
|
||||
out_meta = src_data.meta.copy()
|
||||
out_meta.update(
|
||||
{
|
||||
"nodata": nodata_value,
|
||||
"compress": (
|
||||
src_data.compression.value if src_data.compression else "lzw"
|
||||
),
|
||||
"tiled": src_data.is_tiled,
|
||||
}
|
||||
)
|
||||
if src_data.is_tiled:
|
||||
out_meta.update(
|
||||
{
|
||||
"blockxsize": src_data.block_shapes[0][0],
|
||||
"blockysize": src_data.block_shapes[0][1],
|
||||
}
|
||||
)
|
||||
|
||||
# 创建输出文件
|
||||
with rasterio.open(output_path, "w", **out_meta) as dst:
|
||||
width, height = src_data.width, src_data.height
|
||||
|
||||
if tile_size is None or tile_size <= 0:
|
||||
windows = [w for _, w in src_data.block_windows(1)]
|
||||
else:
|
||||
stride = int(tile_size)
|
||||
windows = [
|
||||
Window(i, j, min(stride, width - i), min(stride, height - j))
|
||||
for i in range(0, width, stride)
|
||||
for j in range(0, height, stride)
|
||||
]
|
||||
|
||||
with tqdm(total=len(windows), desc="处理瓦片", unit="块") as pbar:
|
||||
for window in windows:
|
||||
# 读取相同位置的掩膜瓦片(假设完全对齐)
|
||||
mask = src_mask.read(1, window=window)
|
||||
remove_mask = mask == remove_value
|
||||
|
||||
# 读取数据瓦片
|
||||
data = src_data.read(window=window) # shape: (bands, h, w)
|
||||
|
||||
if remove_mask.any():
|
||||
for band_idx in range(data.shape[0]):
|
||||
np.putmask(
|
||||
data[band_idx], remove_mask, nodata_value_cast
|
||||
)
|
||||
|
||||
dst.write(data, window=window)
|
||||
pbar.update(1)
|
||||
|
||||
logger.info(f"处理完成,输出文件:{output_path}")
|
||||
return str(output_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="使用二值掩膜 TIF(值为1的区域)对数据 TIF 进行掩膜,将对应位置设为 NoData。"
|
||||
)
|
||||
parser.add_argument("data_tif", help="要掩膜的数据 TIF 文件路径")
|
||||
parser.add_argument("mask_tif", help="二值掩膜 TIF 文件路径(值为1表示需要去除的区域)")
|
||||
parser.add_argument("-o", "--output", help="输出文件路径 (可选)")
|
||||
parser.add_argument("-r", "--remove_value", type=int, default=1,
|
||||
help="掩膜中要去除的值,默认为1")
|
||||
parser.add_argument("-n", "--nodata", type=float,
|
||||
help="输出 NoData 值 (可选,默认使用数据 TIF 的 NoData 或 0)")
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tile_size",
|
||||
type=int,
|
||||
default=4096,
|
||||
help="分块大小(像素),默认4096;设为0则按源文件块窗口遍历(tiled 文件通常更快)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
mask_data_by_binary_mask(
|
||||
args.data_tif, args.mask_tif, args.output,
|
||||
args.remove_value, args.nodata, args.tile_size
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
350
water_GUI.py
Normal file
350
water_GUI.py
Normal file
@ -0,0 +1,350 @@
|
||||
"""
|
||||
water_V5 的图形界面启动器(纯 GUI 文件)。
|
||||
|
||||
特点:
|
||||
- 参数可配置
|
||||
- 悬停显示提示
|
||||
- 后台线程运行分割,界面不假死
|
||||
"""
|
||||
|
||||
import threading
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog, messagebox, ttk
|
||||
|
||||
|
||||
DEFAULTS = {
|
||||
"image_path": "",
|
||||
"mask_output_path": "",
|
||||
"prompt": "water body",
|
||||
"overview_max_side": 1400,
|
||||
"overview_threshold": 0.4,
|
||||
"coarse_threshold": 0.5,
|
||||
"coarse_tile_size": 4096,
|
||||
"coarse_tile_overlap": 256,
|
||||
"coarse_downsample_factor": 4,
|
||||
"band_radius": 64,
|
||||
"final_threshold": 0.5,
|
||||
"fine_tile_size": 2048,
|
||||
"fine_overlap": 256,
|
||||
"num_splits_y": 2,
|
||||
"num_splits_x": 3,
|
||||
"region_overlap": 256,
|
||||
"min_area": 5000,
|
||||
"keep_largest_only": False,
|
||||
}
|
||||
|
||||
|
||||
def _lazy_import_backend():
|
||||
import water_V5
|
||||
|
||||
return water_V5
|
||||
|
||||
|
||||
class Tooltip:
|
||||
def __init__(self, widget, text):
|
||||
self.widget = widget
|
||||
self.text = text
|
||||
self.tip = None
|
||||
widget.bind("<Enter>", self._show)
|
||||
widget.bind("<Leave>", self._hide)
|
||||
|
||||
def _show(self, _event=None):
|
||||
if self.tip is not None or not self.text:
|
||||
return
|
||||
x = self.widget.winfo_rootx() + 10
|
||||
y = self.widget.winfo_rooty() + self.widget.winfo_height() + 5
|
||||
self.tip = tk.Toplevel(self.widget)
|
||||
self.tip.wm_overrideredirect(True)
|
||||
self.tip.wm_geometry(f"+{x}+{y}")
|
||||
label = tk.Label(
|
||||
self.tip,
|
||||
text=self.text,
|
||||
justify="left",
|
||||
background="#ffffe0",
|
||||
relief="solid",
|
||||
borderwidth=1,
|
||||
font=("Segoe UI", 9),
|
||||
)
|
||||
label.pack(ipadx=6, ipady=4)
|
||||
|
||||
def _hide(self, _event=None):
|
||||
if self.tip is not None:
|
||||
self.tip.destroy()
|
||||
self.tip = None
|
||||
|
||||
|
||||
class WaterSegGUI(tk.Tk):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.title("SAM3 水体分割 GUI")
|
||||
self.geometry("980x720")
|
||||
|
||||
self.stop_event = threading.Event()
|
||||
self.worker = None
|
||||
|
||||
self.vars = {}
|
||||
self._build_ui()
|
||||
|
||||
def _build_ui(self):
|
||||
container = ttk.Frame(self)
|
||||
container.pack(fill="both", expand=True, padx=10, pady=10)
|
||||
|
||||
top = ttk.Frame(container)
|
||||
top.pack(fill="x")
|
||||
|
||||
self._add_path_row(
|
||||
top,
|
||||
"image_path",
|
||||
"输入影像 (TIF)",
|
||||
DEFAULTS["image_path"],
|
||||
"选择要分割的遥感影像 GeoTIFF。",
|
||||
is_save=False,
|
||||
)
|
||||
self._add_path_row(
|
||||
top,
|
||||
"mask_output_path",
|
||||
"输出掩膜 (TIF)",
|
||||
DEFAULTS["mask_output_path"],
|
||||
"输出单通道 uint8 掩膜(0/1)。",
|
||||
is_save=True,
|
||||
)
|
||||
|
||||
row = ttk.Frame(top)
|
||||
row.pack(fill="x", pady=(6, 0))
|
||||
ttk.Label(row, text="prompt").pack(side="left")
|
||||
prompt_var = tk.StringVar(value=DEFAULTS["prompt"])
|
||||
self.vars["prompt"] = prompt_var
|
||||
prompt_box = ttk.Combobox(
|
||||
row,
|
||||
textvariable=prompt_var,
|
||||
values=["water body", "water", "river", "lake", "reservoir"],
|
||||
)
|
||||
prompt_box.pack(side="left", fill="x", expand=True, padx=(10, 0))
|
||||
Tooltip(prompt_box, "文本提示。建议优先尝试:water / river / lake。")
|
||||
|
||||
panes = ttk.Panedwindow(container, orient="horizontal")
|
||||
panes.pack(fill="both", expand=True, pady=(10, 0))
|
||||
|
||||
left = ttk.Frame(panes)
|
||||
right = ttk.Frame(panes)
|
||||
panes.add(left, weight=3)
|
||||
panes.add(right, weight=2)
|
||||
|
||||
params = ttk.Frame(left)
|
||||
params.pack(fill="both", expand=True)
|
||||
|
||||
self._add_group(
|
||||
params,
|
||||
"Overview 兜底",
|
||||
[
|
||||
("overview_max_side", int, "子区域兜底预测最大边长。越大越准越慢。"),
|
||||
("overview_threshold", float, "兜底阈值。偏低减少中心漏检,但误检增加。"),
|
||||
],
|
||||
)
|
||||
self._add_group(
|
||||
params,
|
||||
"Coarse 粗分割",
|
||||
[
|
||||
("coarse_threshold", float, "粗分割阈值。偏低提高召回,偏高减少误检。"),
|
||||
("coarse_tile_size", int, "粗分割分块大小(原图像素)。"),
|
||||
("coarse_tile_overlap", int, "粗分割块重叠像素。"),
|
||||
("coarse_downsample_factor", int, "粗分割输出降采样比例。2 更细更慢;4 折中;8 更快更粗。"),
|
||||
("band_radius", int, "边缘带半径(原图像素)。越大精修范围越大越慢。"),
|
||||
],
|
||||
)
|
||||
self._add_group(
|
||||
params,
|
||||
"Fine 精修",
|
||||
[
|
||||
("final_threshold", float, "最终阈值。偏低更易连通,偏高更干净。"),
|
||||
("fine_tile_size", int, "精修分块大小(原图像素)。"),
|
||||
("fine_overlap", int, "精修重叠像素。越大接缝越少但更慢。"),
|
||||
],
|
||||
)
|
||||
self._add_group(
|
||||
params,
|
||||
"Region 子区域",
|
||||
[
|
||||
("num_splits_y", int, "纵向切分数。"),
|
||||
("num_splits_x", int, "横向切分数。"),
|
||||
("region_overlap", int, "子区域之间重叠像素,减少边界断裂。"),
|
||||
],
|
||||
)
|
||||
self._add_group(
|
||||
params,
|
||||
"Post 后处理",
|
||||
[
|
||||
("min_area", int, "移除小碎片的最小连通域面积(像素)。"),
|
||||
("keep_largest_only", bool, "只保留最大连通域(适合只要最大水体)。"),
|
||||
],
|
||||
)
|
||||
|
||||
ctrl = ttk.Frame(right)
|
||||
ctrl.pack(fill="x")
|
||||
|
||||
self.progress = ttk.Progressbar(ctrl, maximum=100)
|
||||
self.progress.pack(fill="x")
|
||||
self.status = tk.StringVar(value="就绪")
|
||||
ttk.Label(ctrl, textvariable=self.status).pack(anchor="w", pady=(6, 0))
|
||||
|
||||
btn_row = ttk.Frame(ctrl)
|
||||
btn_row.pack(fill="x", pady=(8, 0))
|
||||
self.run_btn = ttk.Button(btn_row, text="运行", command=self._on_run)
|
||||
self.run_btn.pack(side="left")
|
||||
self.stop_btn = ttk.Button(btn_row, text="停止", command=self._on_stop, state="disabled")
|
||||
self.stop_btn.pack(side="left", padx=(8, 0))
|
||||
|
||||
log_frame = ttk.Frame(right)
|
||||
log_frame.pack(fill="both", expand=True, pady=(10, 0))
|
||||
ttk.Label(log_frame, text="日志").pack(anchor="w")
|
||||
self.log_text = tk.Text(log_frame, height=20, wrap="word")
|
||||
self.log_text.pack(side="left", fill="both", expand=True)
|
||||
scroll = ttk.Scrollbar(log_frame, command=self.log_text.yview)
|
||||
scroll.pack(side="right", fill="y")
|
||||
self.log_text.configure(yscrollcommand=scroll.set)
|
||||
|
||||
def _add_path_row(self, parent, key, label, default, tip, is_save):
|
||||
row = ttk.Frame(parent)
|
||||
row.pack(fill="x", pady=(0, 6))
|
||||
ttk.Label(row, text=label).pack(side="left")
|
||||
var = tk.StringVar(value=default)
|
||||
self.vars[key] = var
|
||||
entry = ttk.Entry(row, textvariable=var)
|
||||
entry.pack(side="left", fill="x", expand=True, padx=(10, 0))
|
||||
Tooltip(entry, tip)
|
||||
btn = ttk.Button(
|
||||
row,
|
||||
text="浏览",
|
||||
command=lambda: self._browse_path(var, is_save=is_save),
|
||||
)
|
||||
btn.pack(side="left", padx=(8, 0))
|
||||
|
||||
def _add_group(self, parent, title, items):
|
||||
frame = ttk.LabelFrame(parent, text=title)
|
||||
frame.pack(fill="x", pady=(0, 10))
|
||||
for r, (key, typ, tip) in enumerate(items):
|
||||
ttk.Label(frame, text=key).grid(row=r, column=0, sticky="w", padx=(6, 8), pady=4)
|
||||
if typ is bool:
|
||||
var = tk.BooleanVar(value=bool(DEFAULTS.get(key)))
|
||||
self.vars[key] = var
|
||||
w = ttk.Checkbutton(frame, variable=var)
|
||||
w.grid(row=r, column=1, sticky="w", pady=4)
|
||||
Tooltip(w, tip)
|
||||
else:
|
||||
var = tk.StringVar(value=str(DEFAULTS.get(key, "")))
|
||||
self.vars[key] = var
|
||||
entry = ttk.Entry(frame, textvariable=var, width=14)
|
||||
entry.grid(row=r, column=1, sticky="w", pady=4)
|
||||
Tooltip(entry, tip)
|
||||
frame.grid_columnconfigure(1, weight=1)
|
||||
|
||||
def _browse_path(self, var, is_save):
|
||||
if is_save:
|
||||
p = filedialog.asksaveasfilename(
|
||||
title="选择输出文件",
|
||||
defaultextension=".tif",
|
||||
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
|
||||
)
|
||||
else:
|
||||
p = filedialog.askopenfilename(
|
||||
title="选择输入文件",
|
||||
filetypes=[("GeoTIFF", "*.tif *.tiff"), ("All files", "*.*")],
|
||||
)
|
||||
if p:
|
||||
var.set(p)
|
||||
|
||||
def _append_log(self, msg):
|
||||
self.log_text.insert("end", msg + "\n")
|
||||
self.log_text.see("end")
|
||||
|
||||
def _log(self, msg):
|
||||
self.after(0, lambda: self._append_log(str(msg)))
|
||||
|
||||
def _progress(self, stage, cur, total):
|
||||
def _update():
|
||||
if total > 0:
|
||||
self.progress["value"] = (cur / total) * 100.0
|
||||
self.status.set(f"{stage}: {cur}/{total}")
|
||||
self.after(0, _update)
|
||||
|
||||
def _get_config(self):
|
||||
cfg = {}
|
||||
for k, v in self.vars.items():
|
||||
if isinstance(v, tk.BooleanVar):
|
||||
cfg[k] = bool(v.get())
|
||||
else:
|
||||
cfg[k] = v.get()
|
||||
|
||||
int_keys = [
|
||||
"overview_max_side",
|
||||
"band_radius",
|
||||
"fine_tile_size",
|
||||
"fine_overlap",
|
||||
"coarse_tile_size",
|
||||
"coarse_tile_overlap",
|
||||
"coarse_downsample_factor",
|
||||
"num_splits_y",
|
||||
"num_splits_x",
|
||||
"region_overlap",
|
||||
"min_area",
|
||||
]
|
||||
float_keys = ["overview_threshold", "coarse_threshold", "final_threshold"]
|
||||
for k in int_keys:
|
||||
if k in cfg and cfg[k] != "":
|
||||
cfg[k] = int(float(cfg[k]))
|
||||
for k in float_keys:
|
||||
if k in cfg and cfg[k] != "":
|
||||
cfg[k] = float(cfg[k])
|
||||
|
||||
cfg["use_tqdm"] = False
|
||||
cfg["device"] = None
|
||||
return cfg
|
||||
|
||||
def _on_run(self):
|
||||
if self.worker is not None and self.worker.is_alive():
|
||||
return
|
||||
cfg = self._get_config()
|
||||
if not cfg.get("image_path") or not cfg.get("mask_output_path"):
|
||||
messagebox.showerror("错误", "请先选择输入影像与输出路径。")
|
||||
return
|
||||
|
||||
self.stop_event.clear()
|
||||
self.run_btn.configure(state="disabled")
|
||||
self.stop_btn.configure(state="normal")
|
||||
self.progress["value"] = 0
|
||||
self.status.set("启动中...")
|
||||
self._append_log("开始运行...")
|
||||
|
||||
def _work():
|
||||
try:
|
||||
water_v5 = _lazy_import_backend()
|
||||
water_v5.run_segmentation(
|
||||
config=cfg,
|
||||
progress_callback=self._progress,
|
||||
log_callback=self._log,
|
||||
stop_event=self.stop_event,
|
||||
)
|
||||
except Exception as e:
|
||||
self._log(f"错误: {e}")
|
||||
finally:
|
||||
def _done():
|
||||
self.run_btn.configure(state="normal")
|
||||
self.stop_btn.configure(state="disabled")
|
||||
self.status.set("完成")
|
||||
self.after(0, _done)
|
||||
|
||||
self.worker = threading.Thread(target=_work, daemon=True)
|
||||
self.worker.start()
|
||||
|
||||
def _on_stop(self):
|
||||
self.stop_event.set()
|
||||
self.status.set("停止中...")
|
||||
|
||||
|
||||
def main():
|
||||
WaterSegGUI().mainloop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
484
water_V5.py
484
water_V5.py
@ -1,14 +1,17 @@
|
||||
"""
|
||||
使用SAM3模型分割超大遥感影像中的水体(分区域处理,优化内存)
|
||||
使用SAM3模型分割超大遥感影像中的水体(三层分割策略)
|
||||
|
||||
流程:
|
||||
1. 将影像划分为若干有重叠的子区域。
|
||||
2. 对每个子区域独立执行粗分割、边缘带构建、精修分割。
|
||||
3. 合并所有子区域的掩码(重叠区域取最大值)。
|
||||
4. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
|
||||
5. 保存结果。
|
||||
1. 将影像划分为若干有重叠的子区域(第一层)。
|
||||
2. 对每个子区域进行中分辨率粗分割:以4096为块大小,滑动窗口推理,得到粗掩码(第二层)。
|
||||
3. 基于粗掩码构建边缘带,对边缘带进行精细分块推理,得到精修掩码(第三层)。
|
||||
4. 合并所有子区域的掩码(重叠区域取最大值)。
|
||||
5. 后处理:填充内部NoData空洞、移除小面积碎片、可选保留最大连通域。
|
||||
6. 保存结果。
|
||||
|
||||
优点:避免在全尺寸概率图上进行GPU操作,显著降低显存占用。
|
||||
优点:
|
||||
- 粗分割采用分块推理,避免降采样丢失细节,提高召回率。
|
||||
- 精修只处理边缘带,节省计算资源。
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -22,6 +25,7 @@ from rasterio.windows import Window
|
||||
from rasterio.io import MemoryFile
|
||||
from tqdm import tqdm
|
||||
from scipy import ndimage
|
||||
import math
|
||||
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
@ -77,6 +81,13 @@ def compute_stretch_params(bands, sample_max_pixels=2_000_000):
|
||||
return vmin, vmax
|
||||
|
||||
|
||||
def make_overview_bands(bands, max_side):
|
||||
_, h, w = bands.shape
|
||||
step = int(math.ceil(max(h, w) / float(max_side)))
|
||||
step = max(step, 1)
|
||||
return bands[:, ::step, ::step], step
|
||||
|
||||
|
||||
def _to_uint8(arr, vmin=None, vmax=None):
|
||||
if arr.dtype == np.uint8:
|
||||
return arr
|
||||
@ -112,7 +123,7 @@ def _bands_to_pil(bands, stretch_params=None):
|
||||
return Image.fromarray(rgb, mode="RGB")
|
||||
|
||||
def _read_pil_window(src, window, stretch_params=None):
|
||||
bands = src.read(window=window, boundless=True, fill_value=0)
|
||||
bands = src.read(window=window)
|
||||
return _bands_to_pil(bands, stretch_params=stretch_params)
|
||||
|
||||
def infer_prompt_prob(processor, image, prompt):
|
||||
@ -126,57 +137,169 @@ def infer_prompt_prob(processor, image, prompt):
|
||||
prob = combine_masks_logits(state["masks_logits"])
|
||||
return prob
|
||||
|
||||
def make_overview_bands(sub_bands, max_side):
|
||||
_, h, w = sub_bands.shape
|
||||
step = int(np.ceil(max(h, w) / float(max_side)))
|
||||
step = max(step, 1)
|
||||
return sub_bands[:, ::step, ::step], step
|
||||
# ---------- 新增:中分辨率粗分割(分块推理) ----------
|
||||
def _downsample_any(mask, factor):
|
||||
if factor <= 1:
|
||||
return mask
|
||||
h, w = mask.shape
|
||||
pad_h = (-h) % factor
|
||||
pad_w = (-w) % factor
|
||||
if pad_h or pad_w:
|
||||
mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=False)
|
||||
h, w = mask.shape
|
||||
h2 = h // factor
|
||||
w2 = w // factor
|
||||
return mask.reshape(h2, factor, w2, factor).any(axis=(1, 3))
|
||||
|
||||
# ---------- 精修分块 ----------
|
||||
def refine_tiles(processor, src, prompt, band_coarse_cpu, tile_size, overlap, stretch_params=None):
|
||||
|
||||
def _build_band_ndimage(mask, radius):
|
||||
if radius <= 0:
|
||||
return np.zeros_like(mask, dtype=bool)
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
dil = ndimage.binary_dilation(mask, structure=struct, iterations=radius)
|
||||
ero = ndimage.binary_erosion(mask, structure=struct, iterations=radius)
|
||||
return np.logical_xor(dil, ero)
|
||||
|
||||
|
||||
def coarse_tile_segmentation(
|
||||
processor,
|
||||
src,
|
||||
prompt,
|
||||
tile_size,
|
||||
overlap,
|
||||
downsample_factor=4,
|
||||
stretch_params=None,
|
||||
nodata_mask=None,
|
||||
use_tqdm=True,
|
||||
):
|
||||
"""
|
||||
对边缘带进行精修分割
|
||||
band_coarse_cpu: 2D numpy bool,指示需要精修的区域(低分辨率)
|
||||
对子区域进行分块粗分割,返回低分辨率概率图(float16)
|
||||
|
||||
Args:
|
||||
processor: 粗处理器(输入分辨率 coarse_resolution)
|
||||
src: rasterio 数据集(子区域)
|
||||
prompt: 文本提示
|
||||
tile_size: 粗分割块大小(原图像素)
|
||||
overlap: 块重叠像素
|
||||
stretch_params: 拉伸参数
|
||||
nodata_mask: 子区域的 NoData 掩码(bool, 与子区域同尺寸),用于过滤无效区域
|
||||
|
||||
Returns:
|
||||
coarse_probs_lr: 子区域粗概率图 (H_lr,W_lr) float16
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
full_probs = np.zeros((height, width), dtype=np.float16)
|
||||
|
||||
band_h, band_w = band_coarse_cpu.shape
|
||||
scale_y = band_h / float(height)
|
||||
scale_x = band_w / float(width)
|
||||
height_lr = int(math.ceil(height / float(downsample_factor)))
|
||||
width_lr = int(math.ceil(width / float(downsample_factor)))
|
||||
full_probs_lr = np.zeros((height_lr, width_lr), dtype=np.float16)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
# 内部进度条,leave=True 以便保留历史记录
|
||||
with tqdm(total=total_tiles, desc="精修分块", unit="块", leave=True) as pbar:
|
||||
for top, left, bottom, right in tile_slices(height, width, tile_size, overlap):
|
||||
c_top = int(top * scale_y)
|
||||
c_left = int(left * scale_x)
|
||||
c_bottom = max(int(np.ceil(bottom * scale_y)), c_top + 1)
|
||||
c_right = max(int(np.ceil(right * scale_x)), c_left + 1)
|
||||
if use_tqdm:
|
||||
iterator = tqdm(
|
||||
tile_slices(height, width, tile_size, overlap),
|
||||
total=total_tiles,
|
||||
desc="粗分割分块",
|
||||
unit="块",
|
||||
leave=True,
|
||||
)
|
||||
else:
|
||||
iterator = tile_slices(height, width, tile_size, overlap)
|
||||
|
||||
if not band_coarse_cpu[c_top:c_bottom, c_left:c_right].any():
|
||||
pbar.update(1)
|
||||
for top, left, bottom, right in iterator:
|
||||
# 如果该块完全在 NoData 区域内,则跳过(加速)
|
||||
if nodata_mask is not None:
|
||||
block_nodata = nodata_mask[top:bottom, left:right]
|
||||
if block_nodata.all():
|
||||
continue
|
||||
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window, stretch_params=stretch_params)
|
||||
tile_prob = infer_prompt_prob(processor, crop, prompt)
|
||||
if tile_prob is None:
|
||||
continue
|
||||
|
||||
h_lr = int(math.ceil((bottom - top) / float(downsample_factor)))
|
||||
w_lr = int(math.ceil((right - left) / float(downsample_factor)))
|
||||
if tile_prob.shape[-2:] != (h_lr, w_lr):
|
||||
tile_prob = upsample_prob(tile_prob, (h_lr, w_lr))
|
||||
|
||||
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
|
||||
top_lr = top // downsample_factor
|
||||
left_lr = left // downsample_factor
|
||||
bottom_lr = top_lr + h_lr
|
||||
right_lr = left_lr + w_lr
|
||||
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr] = np.maximum(
|
||||
full_probs_lr[top_lr:bottom_lr, left_lr:right_lr], tile_prob_cpu
|
||||
)
|
||||
|
||||
if nodata_mask is not None:
|
||||
nodata_lr = _downsample_any(nodata_mask, downsample_factor)
|
||||
full_probs_lr[nodata_lr] = 0.0
|
||||
|
||||
return full_probs_lr
|
||||
|
||||
# ---------- 精修分块(保持不变,但可复用之前的函数) ----------
|
||||
def refine_tiles(
|
||||
processor,
|
||||
src,
|
||||
prompt,
|
||||
band_lr,
|
||||
downsample_factor,
|
||||
tile_size,
|
||||
overlap,
|
||||
stretch_params=None,
|
||||
use_tqdm=True,
|
||||
):
|
||||
"""
|
||||
对边缘带进行精修分割
|
||||
band_lr: 2D numpy bool,指示需要精修的区域(低分辨率)
|
||||
"""
|
||||
height, width = src.height, src.width
|
||||
full_probs = np.zeros((height, width), dtype=np.float16)
|
||||
|
||||
stride = max(tile_size - overlap, 1)
|
||||
num_tiles_y = (height + stride - 1) // stride
|
||||
num_tiles_x = (width + stride - 1) // stride
|
||||
total_tiles = num_tiles_y * num_tiles_x
|
||||
|
||||
if use_tqdm:
|
||||
iterator = tqdm(
|
||||
tile_slices(height, width, tile_size, overlap),
|
||||
total=total_tiles,
|
||||
desc="精修分块",
|
||||
unit="块",
|
||||
leave=True,
|
||||
)
|
||||
else:
|
||||
iterator = tile_slices(height, width, tile_size, overlap)
|
||||
|
||||
for top, left, bottom, right in iterator:
|
||||
c_top = top // downsample_factor
|
||||
c_left = left // downsample_factor
|
||||
c_bottom = int(math.ceil(bottom / float(downsample_factor)))
|
||||
c_right = int(math.ceil(right / float(downsample_factor)))
|
||||
c_bottom = max(c_bottom, c_top + 1)
|
||||
c_right = max(c_right, c_left + 1)
|
||||
|
||||
if not band_lr[c_top:c_bottom, c_left:c_right].any():
|
||||
continue
|
||||
|
||||
window = Window(left, top, right - left, bottom - top)
|
||||
crop = _read_pil_window(src, window, stretch_params=stretch_params)
|
||||
tile_prob = infer_prompt_prob(processor, crop, prompt)
|
||||
if tile_prob is None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
if tile_prob.shape[-2:] != (bottom - top, right - left):
|
||||
tile_prob = upsample_prob(tile_prob, (bottom - top, right - left))
|
||||
|
||||
tile_prob_cpu = tile_prob.detach().float().cpu().numpy().astype(np.float16)
|
||||
tile_prob_cpu = tile_prob.to(torch.float16).detach().cpu().numpy()
|
||||
full_probs[top:bottom, left:right] = np.maximum(
|
||||
full_probs[top:bottom, left:right], tile_prob_cpu
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
return full_probs
|
||||
|
||||
@ -230,80 +353,121 @@ def split_into_regions(width, height, num_splits_y=2, num_splits_x=3, overlap=25
|
||||
regions.append((left, top, right, bottom))
|
||||
return regions
|
||||
|
||||
# ---------- 主程序 ----------
|
||||
matplotlib.use("TkAgg")
|
||||
DEFAULT_CONFIG = {
|
||||
"image_path": r"E:\is2\guidingsahn\result.tif",
|
||||
"mask_output_path": r"E:\is2\guidingsahn\result_mask.tif",
|
||||
"prompt": "water body",
|
||||
"overview_max_side": 1400,
|
||||
"overview_threshold": 0.4,
|
||||
"coarse_resolution": 1008,
|
||||
"fine_resolution": 1008,
|
||||
"coarse_threshold": 0.5,
|
||||
"final_threshold": 0.5,
|
||||
"band_radius": 64,
|
||||
"fine_tile_size": 2048,
|
||||
"fine_overlap": 256,
|
||||
"coarse_tile_size": 4096,
|
||||
"coarse_tile_overlap": 256,
|
||||
"coarse_downsample_factor": 4,
|
||||
"num_splits_y": 2,
|
||||
"num_splits_x": 3,
|
||||
"region_overlap": 256,
|
||||
"min_area": 5000,
|
||||
"keep_largest_only": False,
|
||||
"use_tqdm": True,
|
||||
"device": None,
|
||||
}
|
||||
|
||||
# 参数设置
|
||||
image_path = r"E:\is2\dingshanhu\result_caijian.tif"
|
||||
mask_output_path = r"E:\is2\dingshanhu\result_maskV2.tif"
|
||||
prompt = "water body"
|
||||
coarse_read_max_side = 1200
|
||||
coarse_resolution = 1008
|
||||
fine_resolution = 1008
|
||||
coarse_threshold = 0.5
|
||||
final_threshold = 0.5
|
||||
band_radius = 64
|
||||
tile_size = 1536
|
||||
overlap = 128
|
||||
|
||||
# ========== 分区域参数 ==========
|
||||
num_splits_y = 2 # 纵向切分数
|
||||
num_splits_x = 3 # 横向切分数(共 2x3=6 份)
|
||||
region_overlap = 256 # 子区域之间的重叠像素数
|
||||
# ===============================
|
||||
def run_segmentation(config=None, progress_callback=None, log_callback=None, stop_event=None):
|
||||
cfg = dict(DEFAULT_CONFIG)
|
||||
if config:
|
||||
cfg.update(config)
|
||||
|
||||
# ========== 后处理参数 ==========
|
||||
min_area = 1000 # 最小面积阈值(像素),小于此值的连通域将被移除;设为0或None表示不进行面积过滤
|
||||
keep_largest_only = False # 是否只保留最大的连通域(True/False)
|
||||
# ===============================
|
||||
log = log_callback if log_callback is not None else print
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
device = cfg["device"]
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
log(f"使用设备: {device}")
|
||||
|
||||
# 加载模型(所有子区域共享同一个模型)
|
||||
model = build_sam3_image_model().to(device).eval()
|
||||
coarse_processor = Sam3Processor(model, resolution=coarse_resolution, device=device)
|
||||
fine_processor = Sam3Processor(model, resolution=fine_resolution, device=device)
|
||||
model = build_sam3_image_model().to(device).eval()
|
||||
coarse_processor = Sam3Processor(
|
||||
model, resolution=cfg["coarse_resolution"], device=device
|
||||
)
|
||||
fine_processor = Sam3Processor(model, resolution=cfg["fine_resolution"], device=device)
|
||||
|
||||
# 打开原始影像
|
||||
with rasterio.open(image_path) as src:
|
||||
image_path = cfg["image_path"]
|
||||
mask_output_path = cfg["mask_output_path"]
|
||||
|
||||
with rasterio.open(image_path) as src:
|
||||
nodata = src.nodata
|
||||
print(f"原始影像NoData值: {nodata}")
|
||||
log(f"原始影像NoData值: {nodata}")
|
||||
height, width = src.height, src.width
|
||||
|
||||
# 划分区域
|
||||
regions = split_into_regions(width, height, num_splits_y, num_splits_x, region_overlap)
|
||||
print(f"共划分 {len(regions)} 个子区域")
|
||||
band1 = src.read(1)
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask_full = (
|
||||
np.isclose(band1, float(nodata))
|
||||
if nodata is not None
|
||||
else np.zeros_like(band1, dtype=bool)
|
||||
)
|
||||
else:
|
||||
nodata_mask_full = (
|
||||
(band1 == nodata) if nodata is not None else np.zeros_like(band1, dtype=bool)
|
||||
)
|
||||
|
||||
regions = split_into_regions(
|
||||
width, height, cfg["num_splits_y"], cfg["num_splits_x"], cfg["region_overlap"]
|
||||
)
|
||||
log(f"共划分 {len(regions)} 个子区域")
|
||||
|
||||
# 创建全尺寸掩码数组(CPU内存)
|
||||
full_mask = np.zeros((height, width), dtype=np.uint8)
|
||||
|
||||
# 子区域总进度条
|
||||
with tqdm(total=len(regions), desc="处理子区域", unit="子区域") as region_pbar:
|
||||
for idx, (left, top, right, bottom) in enumerate(regions):
|
||||
print(f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})")
|
||||
use_tqdm = bool(cfg.get("use_tqdm", True))
|
||||
if use_tqdm:
|
||||
region_iter = tqdm(
|
||||
list(enumerate(regions)),
|
||||
total=len(regions),
|
||||
desc="处理子区域",
|
||||
unit="子区域",
|
||||
)
|
||||
else:
|
||||
region_iter = enumerate(regions)
|
||||
|
||||
for idx, (left, top, right, bottom) in region_iter:
|
||||
if stop_event is not None and stop_event.is_set():
|
||||
log("已停止")
|
||||
return
|
||||
if progress_callback is not None:
|
||||
progress_callback("region", idx + 1, len(regions))
|
||||
|
||||
log(
|
||||
f"\n子区域 {idx+1}/{len(regions)}: 坐标范围 ({left},{top}) -> ({right},{bottom})"
|
||||
)
|
||||
region_w = right - left
|
||||
region_h = bottom - top
|
||||
|
||||
# 读取子区域数据
|
||||
window = Window(left, top, region_w, region_h)
|
||||
sub_bands = src.read(window=window) # shape: (bands, h, w)
|
||||
sub_bands = src.read(window=window)
|
||||
|
||||
# 构建子区域元数据
|
||||
sub_profile = src.profile.copy()
|
||||
if not src.is_tiled:
|
||||
sub_profile.pop('blockxsize', None)
|
||||
sub_profile.pop('blockysize', None)
|
||||
sub_profile['tiled'] = False
|
||||
sub_profile.pop("blockxsize", None)
|
||||
sub_profile.pop("blockysize", None)
|
||||
sub_profile["tiled"] = False
|
||||
else:
|
||||
sub_profile['tiled'] = True
|
||||
sub_profile.update({
|
||||
'height': region_h,
|
||||
'width': region_w,
|
||||
'transform': rasterio.windows.transform(window, src.transform)
|
||||
})
|
||||
sub_profile["tiled"] = True
|
||||
sub_profile.update(
|
||||
{
|
||||
"height": region_h,
|
||||
"width": region_w,
|
||||
"transform": rasterio.windows.transform(window, src.transform),
|
||||
}
|
||||
)
|
||||
|
||||
sub_nodata = nodata_mask_full[top:bottom, left:right]
|
||||
|
||||
# 将子区域数据包装成内存中的 rasterio 数据集
|
||||
with MemoryFile() as memfile:
|
||||
with memfile.open(**sub_profile) as sub_dst:
|
||||
sub_dst.write(sub_bands)
|
||||
@ -311,56 +475,98 @@ with rasterio.open(image_path) as src:
|
||||
stretch_params = compute_stretch_params(sub_bands)
|
||||
|
||||
overview_bands, overview_step = make_overview_bands(
|
||||
sub_bands, max_side=coarse_read_max_side
|
||||
sub_bands, max_side=cfg["overview_max_side"]
|
||||
)
|
||||
overview_img = _bands_to_pil(
|
||||
overview_bands, stretch_params=stretch_params
|
||||
)
|
||||
overview_prob = infer_prompt_prob(
|
||||
coarse_processor, overview_img, cfg["prompt"]
|
||||
)
|
||||
overview_img = _bands_to_pil(overview_bands, stretch_params=stretch_params)
|
||||
overview_prob = infer_prompt_prob(coarse_processor, overview_img, prompt)
|
||||
if overview_prob is None:
|
||||
overview_prob_cpu = np.zeros((overview_img.height, overview_img.width), dtype=np.float16)
|
||||
else:
|
||||
overview_prob_cpu = overview_prob.detach().float().cpu().numpy().astype(np.float16)
|
||||
overview_mask_cpu = overview_prob_cpu > coarse_threshold
|
||||
|
||||
scale = overview_img.height / float(region_h)
|
||||
band_radius_small = max(int(round(band_radius * scale)), 1)
|
||||
band_small = build_band(torch.from_numpy(overview_mask_cpu), band_radius_small).numpy()
|
||||
|
||||
fine_probs_cpu = refine_tiles(
|
||||
fine_processor,
|
||||
sub_src,
|
||||
prompt,
|
||||
band_small,
|
||||
tile_size=tile_size,
|
||||
overlap=overlap,
|
||||
stretch_params=stretch_params,
|
||||
overview_mask_small = np.zeros(
|
||||
(overview_img.height, overview_img.width), dtype=bool
|
||||
)
|
||||
fine_mask = fine_probs_cpu > final_threshold
|
||||
else:
|
||||
overview_mask_small = (
|
||||
(overview_prob > cfg["overview_threshold"])
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
if sub_nodata.any():
|
||||
nodata_overview = sub_nodata[::overview_step, ::overview_step]
|
||||
overview_mask_small[nodata_overview] = False
|
||||
|
||||
coarse_probs = coarse_tile_segmentation(
|
||||
processor=coarse_processor,
|
||||
src=sub_src,
|
||||
prompt=cfg["prompt"],
|
||||
tile_size=cfg["coarse_tile_size"],
|
||||
overlap=cfg["coarse_tile_overlap"],
|
||||
downsample_factor=cfg["coarse_downsample_factor"],
|
||||
stretch_params=stretch_params,
|
||||
nodata_mask=sub_nodata if sub_nodata.any() else None,
|
||||
use_tqdm=use_tqdm,
|
||||
)
|
||||
coarse_mask_lr = coarse_probs > cfg["coarse_threshold"]
|
||||
|
||||
overview_mask_lr = np.array(
|
||||
Image.fromarray(
|
||||
overview_mask_small.astype(np.uint8), mode="L"
|
||||
).resize(
|
||||
(coarse_mask_lr.shape[1], coarse_mask_lr.shape[0]),
|
||||
resample=Image.NEAREST,
|
||||
)
|
||||
).astype(bool)
|
||||
combined_mask_lr = coarse_mask_lr | overview_mask_lr
|
||||
|
||||
band_radius_lr = max(
|
||||
int(round(cfg["band_radius"] / float(cfg["coarse_downsample_factor"]))),
|
||||
1,
|
||||
)
|
||||
band_lr = _build_band_ndimage(combined_mask_lr, band_radius_lr)
|
||||
|
||||
fine_probs = refine_tiles(
|
||||
processor=fine_processor,
|
||||
src=sub_src,
|
||||
prompt=cfg["prompt"],
|
||||
band_lr=band_lr,
|
||||
downsample_factor=cfg["coarse_downsample_factor"],
|
||||
tile_size=cfg["fine_tile_size"],
|
||||
overlap=cfg["fine_overlap"],
|
||||
stretch_params=stretch_params,
|
||||
use_tqdm=use_tqdm,
|
||||
)
|
||||
fine_mask = fine_probs > cfg["final_threshold"]
|
||||
|
||||
coarse_mask_full = np.array(
|
||||
Image.fromarray(overview_mask_cpu.astype(np.uint8), mode="L").resize(
|
||||
Image.fromarray(coarse_mask_lr.astype(np.uint8), mode="L").resize(
|
||||
(region_w, region_h), resample=Image.NEAREST
|
||||
)
|
||||
).astype(bool)
|
||||
sub_mask_np = (fine_mask | coarse_mask_full).astype(np.uint8)
|
||||
overview_mask_full = np.array(
|
||||
Image.fromarray(
|
||||
overview_mask_small.astype(np.uint8), mode="L"
|
||||
).resize((region_w, region_h), resample=Image.NEAREST)
|
||||
).astype(bool)
|
||||
sub_mask_np = (fine_mask | coarse_mask_full | overview_mask_full).astype(
|
||||
np.uint8
|
||||
)
|
||||
|
||||
# 合并到全尺寸掩码(重叠区域取最大值)
|
||||
full_mask[top:bottom, left:right] = np.maximum(
|
||||
full_mask[top:bottom, left:right], sub_mask_np
|
||||
)
|
||||
|
||||
# 更新子区域进度条
|
||||
region_pbar.update(1)
|
||||
|
||||
# ========== 后处理 ==========
|
||||
print("\n后处理:填充内部NoData空洞...")
|
||||
log("\n后处理:填充内部NoData空洞...")
|
||||
if nodata is not None:
|
||||
band1 = src.read(1)
|
||||
if np.issubdtype(band1.dtype, np.floating):
|
||||
nodata_mask = np.isclose(band1, float(nodata))
|
||||
else:
|
||||
nodata_mask = (band1 == nodata)
|
||||
nodata_mask = band1 == nodata
|
||||
|
||||
labeled_mask, num_features = ndimage.label(nodata_mask)
|
||||
labeled_mask, _ = ndimage.label(nodata_mask)
|
||||
boundary_mask = np.zeros_like(nodata_mask, dtype=bool)
|
||||
boundary_mask[0, :] = True
|
||||
boundary_mask[-1, :] = True
|
||||
@ -368,27 +574,45 @@ with rasterio.open(image_path) as src:
|
||||
boundary_mask[:, -1] = True
|
||||
|
||||
boundary_labels = set(labeled_mask[boundary_mask])
|
||||
internal_nodata_mask = np.isin(labeled_mask, list(boundary_labels), invert=True) & nodata_mask
|
||||
internal_nodata_mask = np.isin(
|
||||
labeled_mask, list(boundary_labels), invert=True
|
||||
) & nodata_mask
|
||||
|
||||
if internal_nodata_mask.any():
|
||||
struct = ndimage.generate_binary_structure(2, 2)
|
||||
internal_dilated = ndimage.binary_dilation(internal_nodata_mask, structure=struct, iterations=7)
|
||||
internal_dilated = ndimage.binary_dilation(
|
||||
internal_nodata_mask, structure=struct, iterations=7
|
||||
)
|
||||
full_mask[internal_dilated] = 1
|
||||
print(f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}")
|
||||
log(
|
||||
f" 内部NoData原始像素数: {np.sum(internal_nodata_mask)},膨胀后像素数: {np.sum(internal_dilated)}"
|
||||
)
|
||||
else:
|
||||
print(" 无内部NoData区域")
|
||||
log(" 无内部NoData区域")
|
||||
|
||||
print("后处理:面积过滤...")
|
||||
if min_area is not None or keep_largest_only:
|
||||
log("后处理:面积过滤...")
|
||||
if cfg["min_area"] is not None or cfg["keep_largest_only"]:
|
||||
original_count = np.sum(full_mask)
|
||||
full_mask = filter_by_area(full_mask, min_area=min_area, keep_largest_only=keep_largest_only)
|
||||
full_mask = filter_by_area(
|
||||
full_mask,
|
||||
min_area=cfg["min_area"],
|
||||
keep_largest_only=cfg["keep_largest_only"],
|
||||
)
|
||||
filtered_count = np.sum(full_mask)
|
||||
print(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
|
||||
log(f" 后处理前水体像素数: {original_count},后处理后: {filtered_count}")
|
||||
|
||||
# ========== 保存结果 ==========
|
||||
profile = src.profile.copy()
|
||||
profile.update(count=1, dtype="uint8", compress='lzw')
|
||||
profile.update(count=1, dtype="uint8", compress="lzw")
|
||||
with rasterio.open(mask_output_path, "w", **profile) as dst:
|
||||
dst.write(full_mask, 1)
|
||||
|
||||
print(f"分割完成,结果已保存至:{mask_output_path}")
|
||||
log(f"分割完成,结果已保存至:{mask_output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
matplotlib.use("TkAgg")
|
||||
run_segmentation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user