From d22414bf7d7ed19ae58661372dd4948cc025f26e Mon Sep 17 00:00:00 2001 From: DXC Date: Mon, 8 Jun 2026 15:39:43 +0800 Subject: [PATCH] feat(sampling): add adaptive sampling toggle + interactive sampling point viewer --- src/core/steps/prediction_step.py | 9 +- .../water_quality_inversion_pipeline_GUI.py | 7 +- src/gui/dialogs.py | 237 ++++++++++++++++++ src/gui/panels/step7_panel.py | 46 +++- 4 files changed, 294 insertions(+), 5 deletions(-) diff --git a/src/core/steps/prediction_step.py b/src/core/steps/prediction_step.py index 12262e2..a0fa10e 100644 --- a/src/core/steps/prediction_step.py +++ b/src/core/steps/prediction_step.py @@ -26,6 +26,7 @@ class PredictionStep: glint_mask_path: Optional[str] = None, output_dir: Union[str, Path] = "./10_sampling", callback: Optional[Callable] = None, + use_adaptive_sampling: bool = True, ) -> str: """生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱""" from pathlib import Path @@ -83,10 +84,14 @@ class PredictionStep: if glint_mask_to_use is None: print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。") - # 传递极度安全的 deglint_img_str 进底层 + # 传递极度安全的 deglint_img_str 进底层(关键字传参,避免 positional 参数顺序陷阱) get_spectral_sampling_points_chunked( deglint_img_str, water_mask_path, glint_mask_to_use, - output_path, interval, sample_radius, chunk_size + output_path, + interval=interval, + sample_radius=sample_radius, + chunk_size=chunk_size, + use_adaptive_sampling=use_adaptive_sampling, ) notify("completed", f"采样点光谱数据已保存: {output_path}") diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index 8e7a8ab..ac5d639 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -753,17 +753,19 @@ class WaterQualityInversionPipeline: chunk_size: int = 1000, water_mask_path: Optional[str] = None, glint_mask_path: Optional[str] = None, + use_adaptive_sampling: bool = True, skip_dependency_check: bool = False, **kwargs) -> str: """ 步骤7: 生成根据水域掩膜内且耀斑掩膜外的采样点,统计采样点的平均光谱 - + Args: deglint_img_path: 去除耀斑后的影像文件路径(如果为None,使用步骤3的结果) interval: 采样点间隔(像元数) sample_radius: 采样点半径(像元数) chunk_size: 每次处理的行数(控制内存使用) water_mask_path: dat格式的水域掩膜文件路径(如果为None,将使用步骤1生成的dat格式掩膜) - + use_adaptive_sampling: 是否启用自适应采样(根据水体宽度动态调整间隔) + Returns: 采样点光谱数据CSV文件路径 """ @@ -786,6 +788,7 @@ class WaterQualityInversionPipeline: water_mask_path=water_mask_path, glint_mask_path=glint_mask_path, output_dir=str(self.sampling_dir), + use_adaptive_sampling=use_adaptive_sampling, ) self._record_step_time("步骤7: 生成预测采样点", 0, 0) self._notify("completed", f"采样点光谱数据已保存: {result}") diff --git a/src/gui/dialogs.py b/src/gui/dialogs.py index eb88b2e..04ecc8a 100644 --- a/src/gui/dialogs.py +++ b/src/gui/dialogs.py @@ -7,6 +7,11 @@ """ import os +from typing import Optional + +import numpy as np +import pandas as pd + from PyQt5.QtCore import Qt, QTimer from PyQt5.QtGui import QFont from PyQt5.QtWidgets import ( @@ -20,7 +25,11 @@ from PyQt5.QtWidgets import ( QWidget, QComboBox, QLineEdit, + QTableWidget, + QTableWidgetItem, ) +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure class BandConfirmDialog(QDialog): @@ -334,3 +343,231 @@ class AISettingsDialog(QDialog): "text_model": s.value("text_model", "", type=str), "timeout_s": s.value("timeout_s", 120, type=int), } + + +# ───────────────────────────────────────────────────────────────────────────── +# 交互式采样点与光谱查看器 +# ───────────────────────────────────────────────────────────────────────────── + + +class SamplingViewerDialog(QDialog): + """交互式采样点与光谱查看器 + + 左右分栏: + - 左侧:matplotlib 散点图(pixel_x / pixel_y) + - 右侧:上方信息面板(坐标),下方子图画光谱曲线 + + 点击散点图任意点 → 定位最近采样点 → 高亮 + 显示坐标 + 绘制该点光谱 + """ + + def __init__(self, csv_path: str, parent: Optional[QWidget] = None): + super().__init__(parent) + self.csv_path = csv_path + self.df: Optional[pd.DataFrame] = None + self.selected_idx: Optional[int] = None + self._init_ui() + self._load_data() + self._draw_scatter() + + def _init_ui(self): + """搭建对话框 UI""" + self.setWindowTitle("采样点与光谱查看器") + self.setModal(False) + self.resize(1100, 600) + + main_layout = QHBoxLayout(self) + + # --- 左侧:matplotlib 散点图画布 --- + self._fig = Figure(figsize=(6, 5)) + self._canvas = FigureCanvasQTAgg(self._fig) + self._ax_scatter = self._fig.add_subplot(111) + self._ax_scatter.set_xlabel("pixel_x") + self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_title("采样点分布(点击查看详情)") + self._ax_scatter.invert_yaxis() + self._fig.tight_layout() + + main_layout.addWidget(self._canvas, stretch=2) + + # --- 右侧:信息面板 + 光谱子图 --- + right_widget = QWidget() + right_layout = QVBoxLayout(right_widget) + + self._info_label = QLabel("点击左侧散点图选择采样点") + self._info_label.setStyleSheet( + "QLabel { background-color: #f0f0f0; padding: 6px; " + "border: 1px solid #ccc; font-size: 13px; }" + ) + right_layout.addWidget(self._info_label) + + # 采样点列表迷你表格 + self._point_table = QTableWidget() + self._point_table.setColumnCount(3) + self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"]) + self._point_table.setMaximumHeight(120) + self._point_table.setEditTriggers(QTableWidget.NoEditTriggers) + self._point_table.setSelectionBehavior(QTableWidget.SelectRows) + right_layout.addWidget(QLabel("采样点列表(共 N 个):")) + right_layout.addWidget(self._point_table) + + # 光谱曲线子图 + self._fig_right = Figure(figsize=(5, 3)) + self._ax_spectrum = self._fig_right.add_subplot(111) + self._ax_spectrum.set_xlabel("Band Index") + self._ax_spectrum.set_ylabel("Reflectance") + self._ax_spectrum.set_title("光谱曲线") + self._fig_right.tight_layout() + + self._canvas_right = FigureCanvasQTAgg(self._fig_right) + right_layout.addWidget(self._canvas_right, stretch=1) + + main_layout.addWidget(right_widget, stretch=1) + self.setLayout(main_layout) + + # 绑定鼠标点击事件 + self._cid = self._canvas.mpl_connect('button_press_event', self._on_click) + + def _load_data(self): + """加载 CSV 数据""" + if not os.path.exists(self.csv_path): + self.df = None + return + self.df = pd.read_csv(self.csv_path) + self._info_label.setText( + f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}" + ) + + def _draw_scatter(self): + """绘制散点图""" + self._ax_scatter.clear() + self._ax_scatter.set_xlabel("pixel_x") + self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_title("采样点分布(点击查看详情)") + self._ax_scatter.invert_yaxis() + + if self.df is None: + self._ax_scatter.text( + 0.5, 0.5, "无采样数据", + ha='center', va='center', transform=self._ax_scatter.transAxes, + color='gray', fontsize=14 + ) + self._fig.canvas.draw_idle() + return + + xs = self.df['pixel_x'].values + ys = self.df['pixel_y'].values + self._ax_scatter.scatter(xs, ys, s=15, alpha=0.7, color='#2196F3') + self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)") + self._fig.canvas.draw_idle() + + # 填充迷你表格 + self._point_table.blockSignals(True) + self._point_table.setRowCount(min(len(self.df), 200)) + for i, (_, row) in enumerate(self.df.iterrows()): + if i >= 200: + break + self._point_table.setItem( + i, 0, QTableWidgetItem(str(int(row.get('pixel_x', 0)))) + ) + self._point_table.setItem( + i, 1, QTableWidgetItem(str(int(row.get('pixel_y', 0)))) + ) + self._point_table.setItem(i, 2, QTableWidgetItem(str(i))) + self._point_table.blockSignals(False) + self._point_table.setRowCount(min(len(self.df), 200)) + + def _on_click(self, event): + """鼠标点击散点图 → 找最近点 + 高亮 + 画光谱""" + if self.df is None or event.xdata is None or event.ydata is None: + return + + click_x, click_y = event.xdata, event.ydata + distances = np.sqrt( + (self.df['pixel_x'].values - click_x) ** 2 + + (self.df['pixel_y'].values - click_y) ** 2 + ) + nearest_idx = int(np.argmin(distances)) + self.selected_idx = nearest_idx + + row = self.df.iloc[nearest_idx] + pixel_x = int(row.get('pixel_x', 0)) + pixel_y = int(row.get('pixel_y', 0)) + x_coord = row.get('x_coord', 'N/A') + y_coord = row.get('y_coord', 'N/A') + + self._info_label.setText( + f"选中的采样点 #{nearest_idx}
" + f"pixel_x = {pixel_x}   pixel_y = {pixel_y}
" + f"x_coord = {x_coord}   y_coord = {y_coord}" + ) + + # 高亮散点图 + self._ax_scatter.clear() + self._ax_scatter.set_xlabel("pixel_x") + self._ax_scatter.set_ylabel("pixel_y") + self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)") + self._ax_scatter.invert_yaxis() + self._ax_scatter.scatter( + self.df['pixel_x'].values, self.df['pixel_y'].values, + s=15, alpha=0.5, color='#90CAF9' + ) + self._ax_scatter.scatter( + [pixel_x], [pixel_y], s=80, alpha=0.9, + color='red', zorder=5, label=f"#{nearest_idx}" + ) + self._ax_scatter.legend() + self._fig.canvas.draw_idle() + + self._draw_spectrum(row) + + def _draw_spectrum(self, row: pd.Series): + """从一行数据中提取纯波段数值并绘图""" + self._ax_spectrum.clear() + self._ax_spectrum.set_xlabel("Band Index") + self._ax_spectrum.set_ylabel("Reflectance") + self._ax_spectrum.set_title("光谱曲线") + + exclude_patterns = ( + 'pixel_x', 'pixel_y', 'x_coord', 'y_coord', + 'wqi', 'index', 'Unnamed', 'id', 'ID', + 'longitude', 'latitude', 'lon', 'lat', + ) + spectral_cols = [ + col for col in self.df.columns + if not any(p in col.lower() for p in exclude_patterns) + and pd.api.types.is_numeric_dtype(self.df[col]) + ] + + if not spectral_cols: + self._ax_spectrum.text( + 0.5, 0.5, "无可用波段数据", + ha='center', va='center', transform=self._ax_spectrum.transAxes, + color='gray', fontsize=11 + ) + self._fig_right.canvas.draw_idle() + return + + band_values = [] + for col in spectral_cols: + val = row[col] + try: + band_values.append(float(val)) + except (ValueError, TypeError): + band_values.append(np.nan) + + band_indices = np.arange(len(band_values), dtype=float) + valid = ~np.isnan(band_values) + if not valid.any(): + self._ax_spectrum.text( + 0.5, 0.5, "波段数据全为无效值", + ha='center', va='center', transform=self._ax_spectrum.transAxes, + color='gray', fontsize=11 + ) + else: + self._ax_spectrum.plot( + band_indices[valid], np.array(band_values)[valid], + color='#1E88E5', linewidth=1.2 + ) + self._ax_spectrum.grid(True, alpha=0.3) + + self._fig_right.canvas.draw_idle() diff --git a/src/gui/panels/step7_panel.py b/src/gui/panels/step7_panel.py index a819201..7c9ed67 100644 --- a/src/gui/panels/step7_panel.py +++ b/src/gui/panels/step7_panel.py @@ -12,6 +12,7 @@ from PyQt5.QtWidgets import ( ) from src.gui.components.custom_widgets import FileSelectWidget +from src.gui.dialogs import SamplingViewerDialog from src.gui.styles import ModernStylesheet @@ -58,6 +59,10 @@ class Step7Panel(QWidget): self.chunk_size.setValue(1000) params_layout.addRow("处理块大小:", self.chunk_size) + self.use_adaptive_sampling = QCheckBox("启用自适应采样") + self.use_adaptive_sampling.setChecked(True) + params_layout.addRow("采样模式:", self.use_adaptive_sampling) + params_group.setLayout(params_layout) layout.addWidget(params_group) @@ -80,15 +85,25 @@ class Step7Panel(QWidget): self.run_btn.clicked.connect(self.run_step) layout.addWidget(self.run_btn) + # 交互式预览按钮 + self.preview_btn = QPushButton("📊 交互式预览采样点与光谱") + self.preview_btn.setEnabled(False) + self.preview_btn.clicked.connect(self._open_sampling_viewer) + layout.addWidget(self.preview_btn) + layout.addStretch() self.setLayout(layout) + # 监听输出路径变化,实时更新预览按钮状态 + self.output_file.line_edit.textChanged.connect(self._on_output_changed) + def get_config(self): """获取配置""" config = { 'interval': self.interval.value(), 'sample_radius': self.sample_radius.value(), 'chunk_size': self.chunk_size.value(), + 'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(), } deglint_img_path = self.deglint_img_file.get_path() if deglint_img_path: @@ -96,7 +111,6 @@ class Step7Panel(QWidget): water_mask_path = self.water_mask_file.get_path() if water_mask_path: config['water_mask_path'] = water_mask_path - # 注意:step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成 return config def set_config(self, config): @@ -107,6 +121,8 @@ class Step7Panel(QWidget): self.sample_radius.setValue(config['sample_radius']) if 'chunk_size' in config: self.chunk_size.setValue(config['chunk_size']) + if 'use_adaptive_sampling' in config: + self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling']) if 'deglint_img_path' in config: self.deglint_img_file.set_path(config['deglint_img_path']) if 'water_mask_path' in config: @@ -195,6 +211,9 @@ class Step7Panel(QWidget): os.makedirs(os.path.dirname(output_path), exist_ok=True) self.output_file.set_path(output_path.replace('\\', '/')) + # 4. 同步更新预览按钮状态(路径可能已自动填充) + self._check_csv_exists() + def run_step(self): """独立运行步骤7""" deglint_img_path = self.deglint_img_file.get_path() @@ -206,3 +225,28 @@ class Step7Panel(QWidget): if hasattr(main_window, 'run_single_step'): config = {'step7': self.get_config()} main_window.run_single_step('step7', config) + + def _check_csv_exists(self): + """检查 output csv 是否存在,驱动预览按钮启停""" + csv_path = self.output_file.get_path() + enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path)) + self.preview_btn.setEnabled(enabled) + return enabled + + def _on_output_changed(self, _text=None): + """输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)""" + self._check_csv_exists() + + def _open_sampling_viewer(self): + """打开交互式采样点查看器弹窗""" + csv_path = self.output_file.get_path() + if not csv_path or not os.path.exists(csv_path): + QMessageBox.warning( + self, "文件不存在", + f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。" + ) + return + dialog = SamplingViewerDialog(csv_path, self) + dialog.exec_() + # 弹窗关闭后再次检查状态(可能文件被覆盖等) + self._check_csv_exists()