feat(sampling): add adaptive sampling toggle + interactive sampling point viewer

This commit is contained in:
DXC
2026-06-08 15:39:43 +08:00
parent e57fdb4f75
commit d22414bf7d
4 changed files with 294 additions and 5 deletions

View File

@ -26,6 +26,7 @@ class PredictionStep:
glint_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./10_sampling",
callback: Optional[Callable] = None,
use_adaptive_sampling: bool = True,
) -> str:
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
from pathlib import Path
@ -83,10 +84,14 @@ class PredictionStep:
if glint_mask_to_use is None:
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
# 传递极度安全的 deglint_img_str 进底层
# 传递极度安全的 deglint_img_str 进底层(关键字传参,避免 positional 参数顺序陷阱)
get_spectral_sampling_points_chunked(
deglint_img_str, water_mask_path, glint_mask_to_use,
output_path, interval, sample_radius, chunk_size
output_path,
interval=interval,
sample_radius=sample_radius,
chunk_size=chunk_size,
use_adaptive_sampling=use_adaptive_sampling,
)
notify("completed", f"采样点光谱数据已保存: {output_path}")

View File

@ -753,6 +753,7 @@ class WaterQualityInversionPipeline:
chunk_size: int = 1000,
water_mask_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
use_adaptive_sampling: bool = True,
skip_dependency_check: bool = False, **kwargs) -> str:
"""
步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
@ -763,6 +764,7 @@ class WaterQualityInversionPipeline:
sample_radius: 采样点半径(像元数)
chunk_size: 每次处理的行数(控制内存使用)
water_mask_path: dat格式的水域掩膜文件路径如果为None将使用步骤1生成的dat格式掩膜
use_adaptive_sampling: 是否启用自适应采样(根据水体宽度动态调整间隔)
Returns:
采样点光谱数据CSV文件路径
@ -786,6 +788,7 @@ class WaterQualityInversionPipeline:
water_mask_path=water_mask_path,
glint_mask_path=glint_mask_path,
output_dir=str(self.sampling_dir),
use_adaptive_sampling=use_adaptive_sampling,
)
self._record_step_time("步骤7: 生成预测采样点", 0, 0)
self._notify("completed", f"采样点光谱数据已保存: {result}")

View File

@ -7,6 +7,11 @@
"""
import os
from typing import Optional
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import (
@ -20,7 +25,11 @@ from PyQt5.QtWidgets import (
QWidget,
QComboBox,
QLineEdit,
QTableWidget,
QTableWidgetItem,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
class BandConfirmDialog(QDialog):
@ -334,3 +343,231 @@ class AISettingsDialog(QDialog):
"text_model": s.value("text_model", "", type=str),
"timeout_s": s.value("timeout_s", 120, type=int),
}
# ─────────────────────────────────────────────────────────────────────────────
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
class SamplingViewerDialog(QDialog):
"""交互式采样点与光谱查看器
左右分栏:
- 左侧matplotlib 散点图pixel_x / pixel_y
- 右侧:上方信息面板(坐标),下方子图画光谱曲线
点击散点图任意点 → 定位最近采样点 → 高亮 + 显示坐标 + 绘制该点光谱
"""
def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
super().__init__(parent)
self.csv_path = csv_path
self.df: Optional[pd.DataFrame] = None
self.selected_idx: Optional[int] = None
self._init_ui()
self._load_data()
self._draw_scatter()
def _init_ui(self):
"""搭建对话框 UI"""
self.setWindowTitle("采样点与光谱查看器")
self.setModal(False)
self.resize(1100, 600)
main_layout = QHBoxLayout(self)
# --- 左侧matplotlib 散点图画布 ---
self._fig = Figure(figsize=(6, 5))
self._canvas = FigureCanvasQTAgg(self._fig)
self._ax_scatter = self._fig.add_subplot(111)
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
self._fig.tight_layout()
main_layout.addWidget(self._canvas, stretch=2)
# --- 右侧:信息面板 + 光谱子图 ---
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
self._info_label = QLabel("点击左侧散点图选择采样点")
self._info_label.setStyleSheet(
"QLabel { background-color: #f0f0f0; padding: 6px; "
"border: 1px solid #ccc; font-size: 13px; }"
)
right_layout.addWidget(self._info_label)
# 采样点列表迷你表格
self._point_table = QTableWidget()
self._point_table.setColumnCount(3)
self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
self._point_table.setMaximumHeight(120)
self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
right_layout.addWidget(QLabel("采样点列表(共 N 个):"))
right_layout.addWidget(self._point_table)
# 光谱曲线子图
self._fig_right = Figure(figsize=(5, 3))
self._ax_spectrum = self._fig_right.add_subplot(111)
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_title("光谱曲线")
self._fig_right.tight_layout()
self._canvas_right = FigureCanvasQTAgg(self._fig_right)
right_layout.addWidget(self._canvas_right, stretch=1)
main_layout.addWidget(right_widget, stretch=1)
self.setLayout(main_layout)
# 绑定鼠标点击事件
self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)
def _load_data(self):
"""加载 CSV 数据"""
if not os.path.exists(self.csv_path):
self.df = None
return
self.df = pd.read_csv(self.csv_path)
self._info_label.setText(
f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}"
)
def _draw_scatter(self):
"""绘制散点图"""
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title("采样点分布(点击查看详情)")
self._ax_scatter.invert_yaxis()
if self.df is None:
self._ax_scatter.text(
0.5, 0.5, "无采样数据",
ha='center', va='center', transform=self._ax_scatter.transAxes,
color='gray', fontsize=14
)
self._fig.canvas.draw_idle()
return
xs = self.df['pixel_x'].values
ys = self.df['pixel_y'].values
self._ax_scatter.scatter(xs, ys, s=15, alpha=0.7, color='#2196F3')
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._fig.canvas.draw_idle()
# 填充迷你表格
self._point_table.blockSignals(True)
self._point_table.setRowCount(min(len(self.df), 200))
for i, (_, row) in enumerate(self.df.iterrows()):
if i >= 200:
break
self._point_table.setItem(
i, 0, QTableWidgetItem(str(int(row.get('pixel_x', 0))))
)
self._point_table.setItem(
i, 1, QTableWidgetItem(str(int(row.get('pixel_y', 0))))
)
self._point_table.setItem(i, 2, QTableWidgetItem(str(i)))
self._point_table.blockSignals(False)
self._point_table.setRowCount(min(len(self.df), 200))
def _on_click(self, event):
"""鼠标点击散点图 → 找最近点 + 高亮 + 画光谱"""
if self.df is None or event.xdata is None or event.ydata is None:
return
click_x, click_y = event.xdata, event.ydata
distances = np.sqrt(
(self.df['pixel_x'].values - click_x) ** 2 +
(self.df['pixel_y'].values - click_y) ** 2
)
nearest_idx = int(np.argmin(distances))
self.selected_idx = nearest_idx
row = self.df.iloc[nearest_idx]
pixel_x = int(row.get('pixel_x', 0))
pixel_y = int(row.get('pixel_y', 0))
x_coord = row.get('x_coord', 'N/A')
y_coord = row.get('y_coord', 'N/A')
self._info_label.setText(
f"<b>选中的采样点 #{nearest_idx}</b><br>"
f"pixel_x = {pixel_x} &nbsp; pixel_y = {pixel_y}<br>"
f"x_coord = {x_coord} &nbsp; y_coord = {y_coord}"
)
# 高亮散点图
self._ax_scatter.clear()
self._ax_scatter.set_xlabel("pixel_x")
self._ax_scatter.set_ylabel("pixel_y")
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
self._ax_scatter.invert_yaxis()
self._ax_scatter.scatter(
self.df['pixel_x'].values, self.df['pixel_y'].values,
s=15, alpha=0.5, color='#90CAF9'
)
self._ax_scatter.scatter(
[pixel_x], [pixel_y], s=80, alpha=0.9,
color='red', zorder=5, label=f"#{nearest_idx}"
)
self._ax_scatter.legend()
self._fig.canvas.draw_idle()
self._draw_spectrum(row)
def _draw_spectrum(self, row: pd.Series):
"""从一行数据中提取纯波段数值并绘图"""
self._ax_spectrum.clear()
self._ax_spectrum.set_xlabel("Band Index")
self._ax_spectrum.set_ylabel("Reflectance")
self._ax_spectrum.set_title("光谱曲线")
exclude_patterns = (
'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
'wqi', 'index', 'Unnamed', 'id', 'ID',
'longitude', 'latitude', 'lon', 'lat',
)
spectral_cols = [
col for col in self.df.columns
if not any(p in col.lower() for p in exclude_patterns)
and pd.api.types.is_numeric_dtype(self.df[col])
]
if not spectral_cols:
self._ax_spectrum.text(
0.5, 0.5, "无可用波段数据",
ha='center', va='center', transform=self._ax_spectrum.transAxes,
color='gray', fontsize=11
)
self._fig_right.canvas.draw_idle()
return
band_values = []
for col in spectral_cols:
val = row[col]
try:
band_values.append(float(val))
except (ValueError, TypeError):
band_values.append(np.nan)
band_indices = np.arange(len(band_values), dtype=float)
valid = ~np.isnan(band_values)
if not valid.any():
self._ax_spectrum.text(
0.5, 0.5, "波段数据全为无效值",
ha='center', va='center', transform=self._ax_spectrum.transAxes,
color='gray', fontsize=11
)
else:
self._ax_spectrum.plot(
band_indices[valid], np.array(band_values)[valid],
color='#1E88E5', linewidth=1.2
)
self._ax_spectrum.grid(True, alpha=0.3)
self._fig_right.canvas.draw_idle()

View File

@ -12,6 +12,7 @@ from PyQt5.QtWidgets import (
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet
@ -58,6 +59,10 @@ class Step7Panel(QWidget):
self.chunk_size.setValue(1000)
params_layout.addRow("处理块大小:", self.chunk_size)
self.use_adaptive_sampling = QCheckBox("启用自适应采样")
self.use_adaptive_sampling.setChecked(True)
params_layout.addRow("采样模式:", self.use_adaptive_sampling)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
@ -80,15 +85,25 @@ class Step7Panel(QWidget):
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
# 交互式预览按钮
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
self.preview_btn.setEnabled(False)
self.preview_btn.clicked.connect(self._open_sampling_viewer)
layout.addWidget(self.preview_btn)
layout.addStretch()
self.setLayout(layout)
# 监听输出路径变化,实时更新预览按钮状态
self.output_file.line_edit.textChanged.connect(self._on_output_changed)
def get_config(self):
"""获取配置"""
config = {
'interval': self.interval.value(),
'sample_radius': self.sample_radius.value(),
'chunk_size': self.chunk_size.value(),
'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
}
deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path:
@ -96,7 +111,6 @@ class Step7Panel(QWidget):
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask_path'] = water_mask_path
# 注意step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成
return config
def set_config(self, config):
@ -107,6 +121,8 @@ class Step7Panel(QWidget):
self.sample_radius.setValue(config['sample_radius'])
if 'chunk_size' in config:
self.chunk_size.setValue(config['chunk_size'])
if 'use_adaptive_sampling' in config:
self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path'])
if 'water_mask_path' in config:
@ -195,6 +211,9 @@ class Step7Panel(QWidget):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
self.output_file.set_path(output_path.replace('\\', '/'))
# 4. 同步更新预览按钮状态(路径可能已自动填充)
self._check_csv_exists()
def run_step(self):
"""独立运行步骤7"""
deglint_img_path = self.deglint_img_file.get_path()
@ -206,3 +225,28 @@ class Step7Panel(QWidget):
if hasattr(main_window, 'run_single_step'):
config = {'step7': self.get_config()}
main_window.run_single_step('step7', config)
def _check_csv_exists(self):
"""检查 output csv 是否存在,驱动预览按钮启停"""
csv_path = self.output_file.get_path()
enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
self.preview_btn.setEnabled(enabled)
return enabled
def _on_output_changed(self, _text=None):
"""输出路径输入框内容变化时调用_text 为 line_edit.textChanged 信号参数)"""
self._check_csv_exists()
def _open_sampling_viewer(self):
"""打开交互式采样点查看器弹窗"""
csv_path = self.output_file.get_path()
if not csv_path or not os.path.exists(csv_path):
QMessageBox.warning(
self, "文件不存在",
f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
)
return
dialog = SamplingViewerDialog(csv_path, self)
dialog.exec_()
# 弹窗关闭后再次检查状态(可能文件被覆盖等)
self._check_csv_exists()