feat(sampling): add adaptive sampling toggle + interactive sampling point viewer

This commit is contained in:
DXC
2026-06-08 15:39:43 +08:00
parent e57fdb4f75
commit d22414bf7d
4 changed files with 294 additions and 5 deletions

View File

@ -26,6 +26,7 @@ class PredictionStep:
glint_mask_path: Optional[str] = None, glint_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./10_sampling", output_dir: Union[str, Path] = "./10_sampling",
callback: Optional[Callable] = None, callback: Optional[Callable] = None,
use_adaptive_sampling: bool = True,
) -> str: ) -> str:
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱""" """生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
from pathlib import Path from pathlib import Path
@ -83,10 +84,14 @@ class PredictionStep:
if glint_mask_to_use is None: if glint_mask_to_use is None:
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。") print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
# 传递极度安全的 deglint_img_str 进底层 # 传递极度安全的 deglint_img_str 进底层(关键字传参,避免 positional 参数顺序陷阱)
get_spectral_sampling_points_chunked( get_spectral_sampling_points_chunked(
deglint_img_str, water_mask_path, glint_mask_to_use, deglint_img_str, water_mask_path, glint_mask_to_use,
output_path, interval, sample_radius, chunk_size output_path,
interval=interval,
sample_radius=sample_radius,
chunk_size=chunk_size,
use_adaptive_sampling=use_adaptive_sampling,
) )
notify("completed", f"采样点光谱数据已保存: {output_path}") notify("completed", f"采样点光谱数据已保存: {output_path}")

View File

@ -753,17 +753,19 @@ class WaterQualityInversionPipeline:
chunk_size: int = 1000, chunk_size: int = 1000,
water_mask_path: Optional[str] = None, water_mask_path: Optional[str] = None,
glint_mask_path: Optional[str] = None, glint_mask_path: Optional[str] = None,
use_adaptive_sampling: bool = True,
skip_dependency_check: bool = False, **kwargs) -> str: skip_dependency_check: bool = False, **kwargs) -> str:
""" """
步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱 步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
Args: Args:
deglint_img_path: 去除耀斑后的影像文件路径如果为None使用步骤3的结果 deglint_img_path: 去除耀斑后的影像文件路径如果为None使用步骤3的结果
interval: 采样点间隔(像元数) interval: 采样点间隔(像元数)
sample_radius: 采样点半径(像元数) sample_radius: 采样点半径(像元数)
chunk_size: 每次处理的行数(控制内存使用) chunk_size: 每次处理的行数(控制内存使用)
water_mask_path: dat格式的水域掩膜文件路径如果为None将使用步骤1生成的dat格式掩膜 water_mask_path: dat格式的水域掩膜文件路径如果为None将使用步骤1生成的dat格式掩膜
use_adaptive_sampling: 是否启用自适应采样(根据水体宽度动态调整间隔)
Returns: Returns:
采样点光谱数据CSV文件路径 采样点光谱数据CSV文件路径
""" """
@ -786,6 +788,7 @@ class WaterQualityInversionPipeline:
water_mask_path=water_mask_path, water_mask_path=water_mask_path,
glint_mask_path=glint_mask_path, glint_mask_path=glint_mask_path,
output_dir=str(self.sampling_dir), output_dir=str(self.sampling_dir),
use_adaptive_sampling=use_adaptive_sampling,
) )
self._record_step_time("步骤7: 生成预测采样点", 0, 0) self._record_step_time("步骤7: 生成预测采样点", 0, 0)
self._notify("completed", f"采样点光谱数据已保存: {result}") self._notify("completed", f"采样点光谱数据已保存: {result}")

View File

@ -7,6 +7,11 @@
""" """
import os import os
from typing import Optional
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QFont from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import (
@ -20,7 +25,11 @@ from PyQt5.QtWidgets import (
QWidget, QWidget,
QComboBox, QComboBox,
QLineEdit, QLineEdit,
QTableWidget,
QTableWidgetItem,
) )
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
class BandConfirmDialog(QDialog): class BandConfirmDialog(QDialog):
@ -334,3 +343,231 @@ class AISettingsDialog(QDialog):
"text_model": s.value("text_model", "", type=str), "text_model": s.value("text_model", "", type=str),
"timeout_s": s.value("timeout_s", 120, type=int), "timeout_s": s.value("timeout_s", 120, type=int),
} }
# ─────────────────────────────────────────────────────────────────────────────
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
class SamplingViewerDialog(QDialog):
"""交互式采样点与光谱查看器
左右分栏:
- 左侧matplotlib 散点图pixel_x / pixel_y
- 右侧:上方信息面板(坐标),下方子图画光谱曲线
点击散点图任意点 → 定位最近采样点 → 高亮 + 显示坐标 + 绘制该点光谱
"""
def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
super().__init__(parent)
self.csv_path = csv_path
self.df: Optional[pd.DataFrame] = None
self.selected_idx: Optional[int] = None
self._init_ui()
self._load_data()
self._draw_scatter()
def _init_ui(self):
"""搭建对话框 UI"""
self.setWindowTitle("采样点与光谱查看器")
self.setModal(False)
self.resize(1100, 600)
main_layout = QHBoxLayout(self)
# --- 左侧matplotlib 散点图画布 ---
self._fig = Figure(figsize=(6, 5))
self._canvas = FigureCanvasQTAgg(self._fig)
self._ax_scatter = self._fig.add_subplot(111)
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
self._fig.tight_layout()
main_layout.addWidget(self._canvas, stretch=2)
# --- 右侧:信息面板 + 光谱子图 ---
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
self._info_label = QLabel("点击左侧散点图选择采样点")
self._info_label.setStyleSheet(
"QLabel { background-color: #f0f0f0; padding: 6px; "
"border: 1px solid #ccc; font-size: 13px; }"
)
right_layout.addWidget(self._info_label)
# 采样点列表迷你表格
self._point_table = QTableWidget()
self._point_table.setColumnCount(3)
self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
self._point_table.setMaximumHeight(120)
self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
right_layout.addWidget(QLabel("采样点列表(共 N 个):"))
right_layout.addWidget(self._point_table)
# 光谱曲线子图
self._fig_right = Figure(figsize=(5, 3))
self._ax_spectrum = self._fig_right.add_subplot(111)
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_title("光谱曲线")
self._fig_right.tight_layout()
self._canvas_right = FigureCanvasQTAgg(self._fig_right)
right_layout.addWidget(self._canvas_right, stretch=1)
main_layout.addWidget(right_widget, stretch=1)
self.setLayout(main_layout)
# 绑定鼠标点击事件
self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)
def _load_data(self):
"""加载 CSV 数据"""
if not os.path.exists(self.csv_path):
self.df = None
return
self.df = pd.read_csv(self.csv_path)
self._info_label.setText(
f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}"
)
def _draw_scatter(self):
"""绘制散点图"""
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
if self.df is None:
self._ax_scatter.text(
0.5, 0.5, "无采样数据",
ha='center', va='center', transform=self._ax_scatter.transAxes,
color='gray', fontsize=14
)
self._fig.canvas.draw_idle()
return
xs = self.df['pixel_x'].values
ys = self.df['pixel_y'].values
self._ax_scatter.scatter(xs, ys, s=15, alpha=0.7, color='#2196F3')
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._fig.canvas.draw_idle()
# 填充迷你表格
self._point_table.blockSignals(True)
self._point_table.setRowCount(min(len(self.df), 200))
for i, (_, row) in enumerate(self.df.iterrows()):
if i >= 200:
break
self._point_table.setItem(
i, 0, QTableWidgetItem(str(int(row.get('pixel_x', 0))))
)
self._point_table.setItem(
i, 1, QTableWidgetItem(str(int(row.get('pixel_y', 0))))
)
self._point_table.setItem(i, 2, QTableWidgetItem(str(i)))
self._point_table.blockSignals(False)
self._point_table.setRowCount(min(len(self.df), 200))
def _on_click(self, event):
"""鼠标点击散点图 → 找最近点 + 高亮 + 画光谱"""
if self.df is None or event.xdata is None or event.ydata is None:
return
click_x, click_y = event.xdata, event.ydata
distances = np.sqrt(
(self.df['pixel_x'].values - click_x) ** 2 +
(self.df['pixel_y'].values - click_y) ** 2
)
nearest_idx = int(np.argmin(distances))
self.selected_idx = nearest_idx
row = self.df.iloc[nearest_idx]
pixel_x = int(row.get('pixel_x', 0))
pixel_y = int(row.get('pixel_y', 0))
x_coord = row.get('x_coord', 'N/A')
y_coord = row.get('y_coord', 'N/A')
self._info_label.setText(
f"<b>选中的采样点 #{nearest_idx}</b><br>"
f"pixel_x = {pixel_x} &nbsp; pixel_y = {pixel_y}<br>"
f"x_coord = {x_coord} &nbsp; y_coord = {y_coord}"
)
# 高亮散点图
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._ax_scatter.invert_yaxis()
self._ax_scatter.scatter(
self.df['pixel_x'].values, self.df['pixel_y'].values,
s=15, alpha=0.5, color='#90CAF9'
)
self._ax_scatter.scatter(
[pixel_x], [pixel_y], s=80, alpha=0.9,
color='red', zorder=5, label=f"#{nearest_idx}"
)
self._ax_scatter.legend()
self._fig.canvas.draw_idle()
self._draw_spectrum(row)
def _draw_spectrum(self, row: pd.Series):
"""从一行数据中提取纯波段数值并绘图"""
self._ax_spectrum.clear()
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_title("光谱曲线")
exclude_patterns = (
'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
'wqi', 'index', 'Unnamed', 'id', 'ID',
'longitude', 'latitude', 'lon', 'lat',
)
spectral_cols = [
col for col in self.df.columns
if not any(p in col.lower() for p in exclude_patterns)
and pd.api.types.is_numeric_dtype(self.df[col])
]
if not spectral_cols:
self._ax_spectrum.text(
0.5, 0.5, "无可用波段数据",
ha='center', va='center', transform=self._ax_spectrum.transAxes,
color='gray', fontsize=11
)
self._fig_right.canvas.draw_idle()
return
band_values = []
for col in spectral_cols:
val = row[col]
try:
band_values.append(float(val))
except (ValueError, TypeError):
band_values.append(np.nan)
band_indices = np.arange(len(band_values), dtype=float)
valid = ~np.isnan(band_values)
if not valid.any():
self._ax_spectrum.text(
0.5, 0.5, "波段数据全为无效值",
ha='center', va='center', transform=self._ax_spectrum.transAxes,
color='gray', fontsize=11
)
else:
self._ax_spectrum.plot(
band_indices[valid], np.array(band_values)[valid],
color='#1E88E5', linewidth=1.2
)
self._ax_spectrum.grid(True, alpha=0.3)
self._fig_right.canvas.draw_idle()

View File

@ -12,6 +12,7 @@ from PyQt5.QtWidgets import (
) )
from src.gui.components.custom_widgets import FileSelectWidget from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet from src.gui.styles import ModernStylesheet
@ -58,6 +59,10 @@ class Step7Panel(QWidget):
self.chunk_size.setValue(1000) self.chunk_size.setValue(1000)
params_layout.addRow("处理块大小:", self.chunk_size) params_layout.addRow("处理块大小:", self.chunk_size)
self.use_adaptive_sampling = QCheckBox("启用自适应采样")
self.use_adaptive_sampling.setChecked(True)
params_layout.addRow("采样模式:", self.use_adaptive_sampling)
params_group.setLayout(params_layout) params_group.setLayout(params_layout)
layout.addWidget(params_group) layout.addWidget(params_group)
@ -80,15 +85,25 @@ class Step7Panel(QWidget):
self.run_btn.clicked.connect(self.run_step) self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn) layout.addWidget(self.run_btn)
# 交互式预览按钮
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
self.preview_btn.setEnabled(False)
self.preview_btn.clicked.connect(self._open_sampling_viewer)
layout.addWidget(self.preview_btn)
layout.addStretch() layout.addStretch()
self.setLayout(layout) self.setLayout(layout)
# 监听输出路径变化,实时更新预览按钮状态
self.output_file.line_edit.textChanged.connect(self._on_output_changed)
def get_config(self): def get_config(self):
"""获取配置""" """获取配置"""
config = { config = {
'interval': self.interval.value(), 'interval': self.interval.value(),
'sample_radius': self.sample_radius.value(), 'sample_radius': self.sample_radius.value(),
'chunk_size': self.chunk_size.value(), 'chunk_size': self.chunk_size.value(),
'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
} }
deglint_img_path = self.deglint_img_file.get_path() deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path: if deglint_img_path:
@ -96,7 +111,6 @@ class Step7Panel(QWidget):
water_mask_path = self.water_mask_file.get_path() water_mask_path = self.water_mask_file.get_path()
if water_mask_path: if water_mask_path:
config['water_mask_path'] = water_mask_path config['water_mask_path'] = water_mask_path
# 注意step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成
return config return config
def set_config(self, config): def set_config(self, config):
@ -107,6 +121,8 @@ class Step7Panel(QWidget):
self.sample_radius.setValue(config['sample_radius']) self.sample_radius.setValue(config['sample_radius'])
if 'chunk_size' in config: if 'chunk_size' in config:
self.chunk_size.setValue(config['chunk_size']) self.chunk_size.setValue(config['chunk_size'])
if 'use_adaptive_sampling' in config:
self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
if 'deglint_img_path' in config: if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path']) self.deglint_img_file.set_path(config['deglint_img_path'])
if 'water_mask_path' in config: if 'water_mask_path' in config:
@ -195,6 +211,9 @@ class Step7Panel(QWidget):
os.makedirs(os.path.dirname(output_path), exist_ok=True) os.makedirs(os.path.dirname(output_path), exist_ok=True)
self.output_file.set_path(output_path.replace('\\', '/')) self.output_file.set_path(output_path.replace('\\', '/'))
# 4. 同步更新预览按钮状态(路径可能已自动填充)
self._check_csv_exists()
def run_step(self): def run_step(self):
"""独立运行步骤7""" """独立运行步骤7"""
deglint_img_path = self.deglint_img_file.get_path() deglint_img_path = self.deglint_img_file.get_path()
@ -206,3 +225,28 @@ class Step7Panel(QWidget):
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = {'step7': self.get_config()} config = {'step7': self.get_config()}
main_window.run_single_step('step7', config) main_window.run_single_step('step7', config)
def _check_csv_exists(self):
"""检查 output csv 是否存在,驱动预览按钮启停"""
csv_path = self.output_file.get_path()
enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
self.preview_btn.setEnabled(enabled)
return enabled
def _on_output_changed(self, _text=None):
"""输出路径输入框内容变化时调用_text 为 line_edit.textChanged 信号参数)"""
self._check_csv_exists()
def _open_sampling_viewer(self):
"""打开交互式采样点查看器弹窗"""
csv_path = self.output_file.get_path()
if not csv_path or not os.path.exists(csv_path):
QMessageBox.warning(
self, "文件不存在",
f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
)
return
dialog = SamplingViewerDialog(csv_path, self)
dialog.exec_()
# 弹窗关闭后再次检查状态(可能文件被覆盖等)
self._check_csv_exists()