diff --git a/src/core/steps/prediction_step.py b/src/core/steps/prediction_step.py
index 12262e2..a0fa10e 100644
--- a/src/core/steps/prediction_step.py
+++ b/src/core/steps/prediction_step.py
@@ -26,6 +26,7 @@ class PredictionStep:
glint_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./10_sampling",
callback: Optional[Callable] = None,
+ use_adaptive_sampling: bool = True,
) -> str:
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
from pathlib import Path
@@ -83,10 +84,14 @@ class PredictionStep:
if glint_mask_to_use is None:
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
- # 传递极度安全的 deglint_img_str 进底层
+ # 传递极度安全的 deglint_img_str 进底层(关键字传参,避免 positional 参数顺序陷阱)
get_spectral_sampling_points_chunked(
deglint_img_str, water_mask_path, glint_mask_to_use,
- output_path, interval, sample_radius, chunk_size
+ output_path,
+ interval=interval,
+ sample_radius=sample_radius,
+ chunk_size=chunk_size,
+ use_adaptive_sampling=use_adaptive_sampling,
)
notify("completed", f"采样点光谱数据已保存: {output_path}")
diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py
index 8e7a8ab..ac5d639 100644
--- a/src/core/water_quality_inversion_pipeline_GUI.py
+++ b/src/core/water_quality_inversion_pipeline_GUI.py
@@ -753,17 +753,19 @@ class WaterQualityInversionPipeline:
chunk_size: int = 1000,
water_mask_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
+ use_adaptive_sampling: bool = True,
skip_dependency_check: bool = False, **kwargs) -> str:
"""
步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱
-
+
Args:
deglint_img_path: 去除耀斑后的影像文件路径(如果为None,使用步骤3的结果)
interval: 采样点间隔(像元数)
sample_radius: 采样点半径(像元数)
chunk_size: 每次处理的行数(控制内存使用)
water_mask_path: dat格式的水域掩膜文件路径(如果为None,将使用步骤1生成的dat格式掩膜)
-
+ use_adaptive_sampling: 是否启用自适应采样(根据水体宽度动态调整间隔)
+
Returns:
采样点光谱数据CSV文件路径
"""
@@ -786,6 +788,7 @@ class WaterQualityInversionPipeline:
water_mask_path=water_mask_path,
glint_mask_path=glint_mask_path,
output_dir=str(self.sampling_dir),
+ use_adaptive_sampling=use_adaptive_sampling,
)
self._record_step_time("步骤7: 生成预测采样点", 0, 0)
self._notify("completed", f"采样点光谱数据已保存: {result}")
diff --git a/src/gui/dialogs.py b/src/gui/dialogs.py
index eb88b2e..04ecc8a 100644
--- a/src/gui/dialogs.py
+++ b/src/gui/dialogs.py
@@ -7,6 +7,11 @@
"""
import os
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import (
@@ -20,7 +25,11 @@ from PyQt5.QtWidgets import (
QWidget,
QComboBox,
QLineEdit,
+ QTableWidget,
+ QTableWidgetItem,
)
+from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
+from matplotlib.figure import Figure
class BandConfirmDialog(QDialog):
@@ -334,3 +343,231 @@ class AISettingsDialog(QDialog):
"text_model": s.value("text_model", "", type=str),
"timeout_s": s.value("timeout_s", 120, type=int),
}
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 交互式采样点与光谱查看器
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+class SamplingViewerDialog(QDialog):
+ """交互式采样点与光谱查看器
+
+ 左右分栏:
+ - 左侧:matplotlib 散点图(pixel_x / pixel_y)
+ - 右侧:上方信息面板(坐标),下方子图画光谱曲线
+
+ 点击散点图任意点 → 定位最近采样点 → 高亮 + 显示坐标 + 绘制该点光谱
+ """
+
+ def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
+ super().__init__(parent)
+ self.csv_path = csv_path
+ self.df: Optional[pd.DataFrame] = None
+ self.selected_idx: Optional[int] = None
+ self._init_ui()
+ self._load_data()
+ self._draw_scatter()
+
+ def _init_ui(self):
+ """搭建对话框 UI"""
+ self.setWindowTitle("采样点与光谱查看器")
+ self.setModal(False)
+ self.resize(1100, 600)
+
+ main_layout = QHBoxLayout(self)
+
+ # --- 左侧:matplotlib 散点图画布 ---
+ self._fig = Figure(figsize=(6, 5))
+ self._canvas = FigureCanvasQTAgg(self._fig)
+ self._ax_scatter = self._fig.add_subplot(111)
+ self._ax_scatter.set_xlabel("pixel_x")
+ self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_title("采样点分布(点击查看详情)")
+ self._ax_scatter.invert_yaxis()
+ self._fig.tight_layout()
+
+ main_layout.addWidget(self._canvas, stretch=2)
+
+ # --- 右侧:信息面板 + 光谱子图 ---
+ right_widget = QWidget()
+ right_layout = QVBoxLayout(right_widget)
+
+ self._info_label = QLabel("点击左侧散点图选择采样点")
+ self._info_label.setStyleSheet(
+ "QLabel { background-color: #f0f0f0; padding: 6px; "
+ "border: 1px solid #ccc; font-size: 13px; }"
+ )
+ right_layout.addWidget(self._info_label)
+
+ # 采样点列表迷你表格
+ self._point_table = QTableWidget()
+ self._point_table.setColumnCount(3)
+ self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
+ self._point_table.setMaximumHeight(120)
+ self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
+ self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
+ right_layout.addWidget(QLabel("采样点列表(共 N 个):"))
+ right_layout.addWidget(self._point_table)
+
+ # 光谱曲线子图
+ self._fig_right = Figure(figsize=(5, 3))
+ self._ax_spectrum = self._fig_right.add_subplot(111)
+ self._ax_spectrum.set_xlabel("Band Index")
+ self._ax_spectrum.set_ylabel("Reflectance")
+ self._ax_spectrum.set_title("光谱曲线")
+ self._fig_right.tight_layout()
+
+ self._canvas_right = FigureCanvasQTAgg(self._fig_right)
+ right_layout.addWidget(self._canvas_right, stretch=1)
+
+ main_layout.addWidget(right_widget, stretch=1)
+ self.setLayout(main_layout)
+
+ # 绑定鼠标点击事件
+ self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)
+
+ def _load_data(self):
+ """加载 CSV 数据"""
+ if not os.path.exists(self.csv_path):
+ self.df = None
+ return
+ self.df = pd.read_csv(self.csv_path)
+ self._info_label.setText(
+ f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}"
+ )
+
+ def _draw_scatter(self):
+ """绘制散点图"""
+ self._ax_scatter.clear()
+ self._ax_scatter.set_xlabel("pixel_x")
+ self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_title("采样点分布(点击查看详情)")
+ self._ax_scatter.invert_yaxis()
+
+ if self.df is None:
+ self._ax_scatter.text(
+ 0.5, 0.5, "无采样数据",
+ ha='center', va='center', transform=self._ax_scatter.transAxes,
+ color='gray', fontsize=14
+ )
+ self._fig.canvas.draw_idle()
+ return
+
+ xs = self.df['pixel_x'].values
+ ys = self.df['pixel_y'].values
+ self._ax_scatter.scatter(xs, ys, s=15, alpha=0.7, color='#2196F3')
+ self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
+ self._fig.canvas.draw_idle()
+
+ # 填充迷你表格
+ self._point_table.blockSignals(True)
+ self._point_table.setRowCount(min(len(self.df), 200))
+ for i, (_, row) in enumerate(self.df.iterrows()):
+ if i >= 200:
+ break
+ self._point_table.setItem(
+ i, 0, QTableWidgetItem(str(int(row.get('pixel_x', 0))))
+ )
+ self._point_table.setItem(
+ i, 1, QTableWidgetItem(str(int(row.get('pixel_y', 0))))
+ )
+ self._point_table.setItem(i, 2, QTableWidgetItem(str(i)))
+ self._point_table.blockSignals(False)
+ self._point_table.setRowCount(min(len(self.df), 200))
+
+ def _on_click(self, event):
+ """鼠标点击散点图 → 找最近点 + 高亮 + 画光谱"""
+ if self.df is None or event.xdata is None or event.ydata is None:
+ return
+
+ click_x, click_y = event.xdata, event.ydata
+ distances = np.sqrt(
+ (self.df['pixel_x'].values - click_x) ** 2 +
+ (self.df['pixel_y'].values - click_y) ** 2
+ )
+ nearest_idx = int(np.argmin(distances))
+ self.selected_idx = nearest_idx
+
+ row = self.df.iloc[nearest_idx]
+ pixel_x = int(row.get('pixel_x', 0))
+ pixel_y = int(row.get('pixel_y', 0))
+ x_coord = row.get('x_coord', 'N/A')
+ y_coord = row.get('y_coord', 'N/A')
+
+ self._info_label.setText(
+ f"选中的采样点 #{nearest_idx}
"
+ f"pixel_x = {pixel_x} pixel_y = {pixel_y}
"
+ f"x_coord = {x_coord} y_coord = {y_coord}"
+ )
+
+ # 高亮散点图
+ self._ax_scatter.clear()
+ self._ax_scatter.set_xlabel("pixel_x")
+ self._ax_scatter.set_ylabel("pixel_y")
+ self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
+ self._ax_scatter.invert_yaxis()
+ self._ax_scatter.scatter(
+ self.df['pixel_x'].values, self.df['pixel_y'].values,
+ s=15, alpha=0.5, color='#90CAF9'
+ )
+ self._ax_scatter.scatter(
+ [pixel_x], [pixel_y], s=80, alpha=0.9,
+ color='red', zorder=5, label=f"#{nearest_idx}"
+ )
+ self._ax_scatter.legend()
+ self._fig.canvas.draw_idle()
+
+ self._draw_spectrum(row)
+
+ def _draw_spectrum(self, row: pd.Series):
+ """从一行数据中提取纯波段数值并绘图"""
+ self._ax_spectrum.clear()
+ self._ax_spectrum.set_xlabel("Band Index")
+ self._ax_spectrum.set_ylabel("Reflectance")
+ self._ax_spectrum.set_title("光谱曲线")
+
+ exclude_patterns = (
+ 'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
+ 'wqi', 'index', 'Unnamed', 'id', 'ID',
+ 'longitude', 'latitude', 'lon', 'lat',
+ )
+ spectral_cols = [
+ col for col in self.df.columns
+ if not any(p in col.lower() for p in exclude_patterns)
+ and pd.api.types.is_numeric_dtype(self.df[col])
+ ]
+
+ if not spectral_cols:
+ self._ax_spectrum.text(
+ 0.5, 0.5, "无可用波段数据",
+ ha='center', va='center', transform=self._ax_spectrum.transAxes,
+ color='gray', fontsize=11
+ )
+ self._fig_right.canvas.draw_idle()
+ return
+
+ band_values = []
+ for col in spectral_cols:
+ val = row[col]
+ try:
+ band_values.append(float(val))
+ except (ValueError, TypeError):
+ band_values.append(np.nan)
+
+ band_indices = np.arange(len(band_values), dtype=float)
+ valid = ~np.isnan(band_values)
+ if not valid.any():
+ self._ax_spectrum.text(
+ 0.5, 0.5, "波段数据全为无效值",
+ ha='center', va='center', transform=self._ax_spectrum.transAxes,
+ color='gray', fontsize=11
+ )
+ else:
+ self._ax_spectrum.plot(
+ band_indices[valid], np.array(band_values)[valid],
+ color='#1E88E5', linewidth=1.2
+ )
+ self._ax_spectrum.grid(True, alpha=0.3)
+
+ self._fig_right.canvas.draw_idle()
diff --git a/src/gui/panels/step7_panel.py b/src/gui/panels/step7_panel.py
index a819201..7c9ed67 100644
--- a/src/gui/panels/step7_panel.py
+++ b/src/gui/panels/step7_panel.py
@@ -12,6 +12,7 @@ from PyQt5.QtWidgets import (
)
from src.gui.components.custom_widgets import FileSelectWidget
+from src.gui.dialogs import SamplingViewerDialog
from src.gui.styles import ModernStylesheet
@@ -58,6 +59,10 @@ class Step7Panel(QWidget):
self.chunk_size.setValue(1000)
params_layout.addRow("处理块大小:", self.chunk_size)
+ self.use_adaptive_sampling = QCheckBox("启用自适应采样")
+ self.use_adaptive_sampling.setChecked(True)
+ params_layout.addRow("采样模式:", self.use_adaptive_sampling)
+
params_group.setLayout(params_layout)
layout.addWidget(params_group)
@@ -80,15 +85,25 @@ class Step7Panel(QWidget):
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
+ # 交互式预览按钮
+ self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
+ self.preview_btn.setEnabled(False)
+ self.preview_btn.clicked.connect(self._open_sampling_viewer)
+ layout.addWidget(self.preview_btn)
+
layout.addStretch()
self.setLayout(layout)
+ # 监听输出路径变化,实时更新预览按钮状态
+ self.output_file.line_edit.textChanged.connect(self._on_output_changed)
+
def get_config(self):
"""获取配置"""
config = {
'interval': self.interval.value(),
'sample_radius': self.sample_radius.value(),
'chunk_size': self.chunk_size.value(),
+ 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
}
deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path:
@@ -96,7 +111,6 @@ class Step7Panel(QWidget):
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask_path'] = water_mask_path
- # 注意:step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成
return config
def set_config(self, config):
@@ -107,6 +121,8 @@ class Step7Panel(QWidget):
self.sample_radius.setValue(config['sample_radius'])
if 'chunk_size' in config:
self.chunk_size.setValue(config['chunk_size'])
+ if 'use_adaptive_sampling' in config:
+ self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path'])
if 'water_mask_path' in config:
@@ -195,6 +211,9 @@ class Step7Panel(QWidget):
os.makedirs(os.path.dirname(output_path), exist_ok=True)
self.output_file.set_path(output_path.replace('\\', '/'))
+ # 4. 同步更新预览按钮状态(路径可能已自动填充)
+ self._check_csv_exists()
+
def run_step(self):
"""独立运行步骤7"""
deglint_img_path = self.deglint_img_file.get_path()
@@ -206,3 +225,28 @@ class Step7Panel(QWidget):
if hasattr(main_window, 'run_single_step'):
config = {'step7': self.get_config()}
main_window.run_single_step('step7', config)
+
+ def _check_csv_exists(self):
+ """检查 output csv 是否存在,驱动预览按钮启停"""
+ csv_path = self.output_file.get_path()
+ enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
+ self.preview_btn.setEnabled(enabled)
+ return enabled
+
+ def _on_output_changed(self, _text=None):
+ """输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)"""
+ self._check_csv_exists()
+
+ def _open_sampling_viewer(self):
+ """打开交互式采样点查看器弹窗"""
+ csv_path = self.output_file.get_path()
+ if not csv_path or not os.path.exists(csv_path):
+ QMessageBox.warning(
+ self, "文件不存在",
+ f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
+ )
+ return
+ dialog = SamplingViewerDialog(csv_path, self)
+ dialog.exec_()
+ # 弹窗关闭后再次检查状态(可能文件被覆盖等)
+ self._check_csv_exists()