feat(sampling): add adaptive sampling toggle + interactive sampling point viewer
This commit is contained in:
@ -7,6 +7,11 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from PyQt5.QtCore import Qt, QTimer
|
||||
from PyQt5.QtGui import QFont
|
||||
from PyQt5.QtWidgets import (
|
||||
@ -20,7 +25,11 @@ from PyQt5.QtWidgets import (
|
||||
QWidget,
|
||||
QComboBox,
|
||||
QLineEdit,
|
||||
QTableWidget,
|
||||
QTableWidgetItem,
|
||||
)
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
class BandConfirmDialog(QDialog):
|
||||
@ -334,3 +343,231 @@ class AISettingsDialog(QDialog):
|
||||
"text_model": s.value("text_model", "", type=str),
|
||||
"timeout_s": s.value("timeout_s", 120, type=int),
|
||||
}
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 交互式采样点与光谱查看器
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class SamplingViewerDialog(QDialog):
|
||||
"""交互式采样点与光谱查看器
|
||||
|
||||
左右分栏:
|
||||
- 左侧:matplotlib 散点图(pixel_x / pixel_y)
|
||||
- 右侧:上方信息面板(坐标),下方子图画光谱曲线
|
||||
|
||||
点击散点图任意点 → 定位最近采样点 → 高亮 + 显示坐标 + 绘制该点光谱
|
||||
"""
|
||||
|
||||
def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
|
||||
super().__init__(parent)
|
||||
self.csv_path = csv_path
|
||||
self.df: Optional[pd.DataFrame] = None
|
||||
self.selected_idx: Optional[int] = None
|
||||
self._init_ui()
|
||||
self._load_data()
|
||||
self._draw_scatter()
|
||||
|
||||
def _init_ui(self):
|
||||
"""搭建对话框 UI"""
|
||||
self.setWindowTitle("采样点与光谱查看器")
|
||||
self.setModal(False)
|
||||
self.resize(1100, 600)
|
||||
|
||||
main_layout = QHBoxLayout(self)
|
||||
|
||||
# --- 左侧:matplotlib 散点图画布 ---
|
||||
self._fig = Figure(figsize=(6, 5))
|
||||
self._canvas = FigureCanvasQTAgg(self._fig)
|
||||
self._ax_scatter = self._fig.add_subplot(111)
|
||||
self._ax_scatter.set_xlabel("pixel_x")
|
||||
self._ax_scatter.set_ylabel("pixel_y")
|
||||
self._ax_scatter.set_title("采样点分布(点击查看详情)")
|
||||
self._ax_scatter.invert_yaxis()
|
||||
self._fig.tight_layout()
|
||||
|
||||
main_layout.addWidget(self._canvas, stretch=2)
|
||||
|
||||
# --- 右侧:信息面板 + 光谱子图 ---
|
||||
right_widget = QWidget()
|
||||
right_layout = QVBoxLayout(right_widget)
|
||||
|
||||
self._info_label = QLabel("点击左侧散点图选择采样点")
|
||||
self._info_label.setStyleSheet(
|
||||
"QLabel { background-color: #f0f0f0; padding: 6px; "
|
||||
"border: 1px solid #ccc; font-size: 13px; }"
|
||||
)
|
||||
right_layout.addWidget(self._info_label)
|
||||
|
||||
# 采样点列表迷你表格
|
||||
self._point_table = QTableWidget()
|
||||
self._point_table.setColumnCount(3)
|
||||
self._point_table.setHorizontalHeaderLabels(["pixel_x", "pixel_y", "index"])
|
||||
self._point_table.setMaximumHeight(120)
|
||||
self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
|
||||
self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
|
||||
right_layout.addWidget(QLabel("采样点列表(共 N 个):"))
|
||||
right_layout.addWidget(self._point_table)
|
||||
|
||||
# 光谱曲线子图
|
||||
self._fig_right = Figure(figsize=(5, 3))
|
||||
self._ax_spectrum = self._fig_right.add_subplot(111)
|
||||
self._ax_spectrum.set_xlabel("Band Index")
|
||||
self._ax_spectrum.set_ylabel("Reflectance")
|
||||
self._ax_spectrum.set_title("光谱曲线")
|
||||
self._fig_right.tight_layout()
|
||||
|
||||
self._canvas_right = FigureCanvasQTAgg(self._fig_right)
|
||||
right_layout.addWidget(self._canvas_right, stretch=1)
|
||||
|
||||
main_layout.addWidget(right_widget, stretch=1)
|
||||
self.setLayout(main_layout)
|
||||
|
||||
# 绑定鼠标点击事件
|
||||
self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)
|
||||
|
||||
def _load_data(self):
|
||||
"""加载 CSV 数据"""
|
||||
if not os.path.exists(self.csv_path):
|
||||
self.df = None
|
||||
return
|
||||
self.df = pd.read_csv(self.csv_path)
|
||||
self._info_label.setText(
|
||||
f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}"
|
||||
)
|
||||
|
||||
def _draw_scatter(self):
|
||||
"""绘制散点图"""
|
||||
self._ax_scatter.clear()
|
||||
self._ax_scatter.set_xlabel("pixel_x")
|
||||
self._ax_scatter.set_ylabel("pixel_y")
|
||||
self._ax_scatter.set_title("采样点分布(点击查看详情)")
|
||||
self._ax_scatter.invert_yaxis()
|
||||
|
||||
if self.df is None:
|
||||
self._ax_scatter.text(
|
||||
0.5, 0.5, "无采样数据",
|
||||
ha='center', va='center', transform=self._ax_scatter.transAxes,
|
||||
color='gray', fontsize=14
|
||||
)
|
||||
self._fig.canvas.draw_idle()
|
||||
return
|
||||
|
||||
xs = self.df['pixel_x'].values
|
||||
ys = self.df['pixel_y'].values
|
||||
self._ax_scatter.scatter(xs, ys, s=15, alpha=0.7, color='#2196F3')
|
||||
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
|
||||
self._fig.canvas.draw_idle()
|
||||
|
||||
# 填充迷你表格
|
||||
self._point_table.blockSignals(True)
|
||||
self._point_table.setRowCount(min(len(self.df), 200))
|
||||
for i, (_, row) in enumerate(self.df.iterrows()):
|
||||
if i >= 200:
|
||||
break
|
||||
self._point_table.setItem(
|
||||
i, 0, QTableWidgetItem(str(int(row.get('pixel_x', 0))))
|
||||
)
|
||||
self._point_table.setItem(
|
||||
i, 1, QTableWidgetItem(str(int(row.get('pixel_y', 0))))
|
||||
)
|
||||
self._point_table.setItem(i, 2, QTableWidgetItem(str(i)))
|
||||
self._point_table.blockSignals(False)
|
||||
self._point_table.setRowCount(min(len(self.df), 200))
|
||||
|
||||
def _on_click(self, event):
|
||||
"""鼠标点击散点图 → 找最近点 + 高亮 + 画光谱"""
|
||||
if self.df is None or event.xdata is None or event.ydata is None:
|
||||
return
|
||||
|
||||
click_x, click_y = event.xdata, event.ydata
|
||||
distances = np.sqrt(
|
||||
(self.df['pixel_x'].values - click_x) ** 2 +
|
||||
(self.df['pixel_y'].values - click_y) ** 2
|
||||
)
|
||||
nearest_idx = int(np.argmin(distances))
|
||||
self.selected_idx = nearest_idx
|
||||
|
||||
row = self.df.iloc[nearest_idx]
|
||||
pixel_x = int(row.get('pixel_x', 0))
|
||||
pixel_y = int(row.get('pixel_y', 0))
|
||||
x_coord = row.get('x_coord', 'N/A')
|
||||
y_coord = row.get('y_coord', 'N/A')
|
||||
|
||||
self._info_label.setText(
|
||||
f"<b>选中的采样点 #{nearest_idx}</b><br>"
|
||||
f"pixel_x = {pixel_x} pixel_y = {pixel_y}<br>"
|
||||
f"x_coord = {x_coord} y_coord = {y_coord}"
|
||||
)
|
||||
|
||||
# 高亮散点图
|
||||
self._ax_scatter.clear()
|
||||
self._ax_scatter.set_xlabel("pixel_x")
|
||||
self._ax_scatter.set_ylabel("pixel_y")
|
||||
self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
|
||||
self._ax_scatter.invert_yaxis()
|
||||
self._ax_scatter.scatter(
|
||||
self.df['pixel_x'].values, self.df['pixel_y'].values,
|
||||
s=15, alpha=0.5, color='#90CAF9'
|
||||
)
|
||||
self._ax_scatter.scatter(
|
||||
[pixel_x], [pixel_y], s=80, alpha=0.9,
|
||||
color='red', zorder=5, label=f"#{nearest_idx}"
|
||||
)
|
||||
self._ax_scatter.legend()
|
||||
self._fig.canvas.draw_idle()
|
||||
|
||||
self._draw_spectrum(row)
|
||||
|
||||
def _draw_spectrum(self, row: pd.Series):
|
||||
"""从一行数据中提取纯波段数值并绘图"""
|
||||
self._ax_spectrum.clear()
|
||||
self._ax_spectrum.set_xlabel("Band Index")
|
||||
self._ax_spectrum.set_ylabel("Reflectance")
|
||||
self._ax_spectrum.set_title("光谱曲线")
|
||||
|
||||
exclude_patterns = (
|
||||
'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
|
||||
'wqi', 'index', 'Unnamed', 'id', 'ID',
|
||||
'longitude', 'latitude', 'lon', 'lat',
|
||||
)
|
||||
spectral_cols = [
|
||||
col for col in self.df.columns
|
||||
if not any(p in col.lower() for p in exclude_patterns)
|
||||
and pd.api.types.is_numeric_dtype(self.df[col])
|
||||
]
|
||||
|
||||
if not spectral_cols:
|
||||
self._ax_spectrum.text(
|
||||
0.5, 0.5, "无可用波段数据",
|
||||
ha='center', va='center', transform=self._ax_spectrum.transAxes,
|
||||
color='gray', fontsize=11
|
||||
)
|
||||
self._fig_right.canvas.draw_idle()
|
||||
return
|
||||
|
||||
band_values = []
|
||||
for col in spectral_cols:
|
||||
val = row[col]
|
||||
try:
|
||||
band_values.append(float(val))
|
||||
except (ValueError, TypeError):
|
||||
band_values.append(np.nan)
|
||||
|
||||
band_indices = np.arange(len(band_values), dtype=float)
|
||||
valid = ~np.isnan(band_values)
|
||||
if not valid.any():
|
||||
self._ax_spectrum.text(
|
||||
0.5, 0.5, "波段数据全为无效值",
|
||||
ha='center', va='center', transform=self._ax_spectrum.transAxes,
|
||||
color='gray', fontsize=11
|
||||
)
|
||||
else:
|
||||
self._ax_spectrum.plot(
|
||||
band_indices[valid], np.array(band_values)[valid],
|
||||
color='#1E88E5', linewidth=1.2
|
||||
)
|
||||
self._ax_spectrum.grid(True, alpha=0.3)
|
||||
|
||||
self._fig_right.canvas.draw_idle()
|
||||
|
||||
Reference in New Issue
Block a user