Files
WQ_GUI/src/gui/dialogs.py

582 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
自定义对话框集合(波段越界确认 / AI 引擎设置 / 采样点与光谱查看器)。
按"职责单一 + 不污染主 GUI 文件"原则拆分。
与 water_quality_gui.py 保持 1:1 风格(中文注释 / 顶部 encoding 声明)。
"""
import os
from typing import Optional
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import (
QDialog,
QLabel,
QSpinBox,
QPushButton,
QVBoxLayout,
QHBoxLayout,
QDialogButtonBox,
QWidget,
QComboBox,
QLineEdit,
QTableWidget,
QTableWidgetItem,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
class BandConfirmDialog(QDialog):
    """Confirmation dialog shown when a requested band index is out of range.

    Scenario
    --------
    The user configured e.g. ``nir_band=66`` in the step-3 panel, but the
    image only has 50 bands; fetching band 66 would make the pipeline raise
    ``IndexError``.

    Behaviour
    ---------
    - A 1 Hz ``QTimer`` counts down ``timeout_seconds``; the OK button text
      shows the remaining time as "确定 (Ns)".
    - Clicking OK accepts immediately and records the current spin-box value.
    - After ``timeout_seconds`` *without interaction* the dialog accepts on
      its own, again recording the current spin-box value (which starts at
      ``max_band`` as the fallback). Changing the spin-box value restarts the
      countdown, so the dialog does not auto-close while the user is still
      picking a band.
    - Clicking "取消运行" rejects; the caller should then abort
      ``run_full_pipeline``.

    After ``exec_()`` returns, call ``selected_band()`` for the chosen value.
    """

    DEFAULT_TIMEOUT = 60  # seconds

    def __init__(
        self,
        parent: Optional[QWidget] = None,
        requested_band: int = 66,
        max_band: int = 50,
        recommended_band: int = 66,
        method_label: str = "NIR",
        timeout_seconds: int = DEFAULT_TIMEOUT,
    ):
        """
        Args:
            parent: Parent widget (may be None).
            requested_band: The band index the user asked for (out of range).
            max_band: Band count of the image; upper bound of the spin box
                and the fallback result.
            recommended_band: Recommended band number shown as a grey hint.
            method_label: Label of the band role, e.g. "NIR".
            timeout_seconds: Inactivity timeout before auto-accepting.
        """
        super().__init__(parent)
        self._requested_band = requested_band
        self._max_band = max_band
        self._recommended_band = recommended_band
        self._method_label = method_label
        self._timeout_seconds = timeout_seconds
        self._remaining = timeout_seconds
        # Fallback result (largest band) if the dialog never reaches accept().
        self._selected_band = max_band
        # Created in _start_timer(); pre-set so accept()/reject() can always
        # safely stop it.
        self._timer: Optional[QTimer] = None

        self.setWindowTitle("波段索引越界")
        self.setModal(True)
        self.setMinimumWidth(420)
        self._init_ui()
        self._start_timer()

    def _init_ui(self):
        """Build the UI: warning text, grey hint, spin box, countdown buttons."""
        layout = QVBoxLayout(self)

        # 1) Main message (HTML emphasis on the numbers)
        self._msg_label = QLabel(
            f"影像仅有 <b>{self._max_band}</b> 个波段,"
            f"无法读取第 <b>{self._requested_band}</b> 波段({self._method_label})。"
        )
        self._msg_label.setWordWrap(True)
        self._msg_label.setFont(QFont("Microsoft YaHei", 10))
        layout.addWidget(self._msg_label)

        # 2) Small grey recommendation hint
        hint_label = QLabel(
            f"(推荐近红外波段序号:{self._recommended_band}"
        )
        hint_label.setStyleSheet("color: #888; font-size: 10px;")
        layout.addWidget(hint_label)

        # 3) Band spin box; starts at max_band, which is also the timeout fallback
        spin_row = QHBoxLayout()
        spin_row.addWidget(QLabel(f"请选择要使用的{self._method_label}索引:"))
        self._spin = QSpinBox()
        self._spin.setRange(0, self._max_band)
        self._spin.setValue(self._max_band)
        self._spin.setSuffix(f" / 0~{self._max_band}")
        spin_row.addWidget(self._spin)
        spin_row.addStretch(1)
        layout.addLayout(spin_row)

        # 4) Countdown explanation
        countdown_tip = QLabel(
            f"{self._timeout_seconds} 秒内不操作,将自动使用最大波段 "
            f"{self._max_band})继续运行。"
        )
        countdown_tip.setStyleSheet("color: #555; font-size: 9px;")
        countdown_tip.setWordWrap(True)
        layout.addWidget(countdown_tip)

        # 5) Buttons: "确定 (Ns)" + "取消运行"
        btn_box = QDialogButtonBox()
        self._ok_btn = QPushButton(f"确定 ({self._remaining}s)")
        self._ok_btn.setDefault(True)
        self._ok_btn.clicked.connect(self.accept)
        self._cancel_btn = QPushButton("取消运行")
        self._cancel_btn.clicked.connect(self.reject)
        btn_box.addButton(self._ok_btn, QDialogButtonBox.AcceptRole)
        btn_box.addButton(self._cancel_btn, QDialogButtonBox.RejectRole)
        layout.addWidget(btn_box)

        # Connected only after setValue() above, so the initial value does not
        # count as user interaction.
        self._spin.valueChanged.connect(self._restart_countdown)

    def _start_timer(self):
        """Start the 1 Hz countdown; _tick() auto-accepts when it reaches zero."""
        self._timer = QTimer(self)
        self._timer.setInterval(1000)
        self._timer.timeout.connect(self._tick)
        self._timer.start()

    def _stop_timer(self):
        """Stop the countdown if it has been started."""
        if self._timer is not None:
            self._timer.stop()

    def _refresh_ok_text(self):
        """Show the remaining seconds on the OK button."""
        self._ok_btn.setText(f"确定 ({self._remaining}s)")

    def _restart_countdown(self, *_):
        """User touched the spin box: give them the full timeout again."""
        self._remaining = self._timeout_seconds
        self._refresh_ok_text()

    def _tick(self):
        """Once per second: update the button; auto-accept at zero."""
        self._remaining -= 1
        if self._remaining <= 0:
            # Timeout: accept() records the current spin-box value and stops
            # the timer.
            self.accept()
        else:
            self._refresh_ok_text()

    # ── Result interface exposed to callers ──────────────────────────────
    def selected_band(self) -> int:
        """Return the band index chosen by the user (read after the dialog closes)."""
        return self._selected_band

    def accept(self):
        """OK clicked or countdown expired: record the spin-box value, then close."""
        self._selected_band = self._spin.value()
        self._stop_timer()
        super().accept()

    def reject(self):
        """"取消运行" clicked (or dialog dismissed): stop the timer and close.

        The caller is expected to abort the pipeline.
        """
        self._stop_timer()
        super().reject()
# ─────────────────────────────────────────────────────────────────────────────
# AI 引擎设置对话框
# ─────────────────────────────────────────────────────────────────────────────
from PyQt5.QtCore import QSettings
# QSettings organisation / application names shared by every reader and
# writer of the AI engine configuration.
AI_SETTINGS_ORG = "IrisWaterQuality"
AI_SETTINGS_APP = "WQ_GUI"

# Recommended per-provider defaults. AISettingsDialog uses them to pre-fill the
# form when the provider changes and as the last fallback when loading.
AI_DEFAULTS = dict(
    ollama=dict(
        api_base_url="http://localhost:11434",
        vision_model="qwen3-vl:8b",
        text_model="qwen3-vl:8b",
    ),
    minimax=dict(
        api_base_url="https://api.minimaxi.com/v1/text/chatcompletion_v2",
        vision_model="abab6.5s-chat",
        text_model="abab6.5s-chat",
    ),
)
class AISettingsDialog(QDialog):
    """Dialog for configuring the AI engine; values are persisted to QSettings.

    Keys written by ``_save_and_close`` and returned by
    ``read_ai_config_from_settings``:
    ai_provider / api_base_url / api_key / vision_model / text_model / timeout_s.
    """

    # Environment variables consulted when QSettings has no saved value,
    # per provider and per field.
    _ENV_FALLBACKS = {
        "ollama": {
            "api_base_url": "OLLAMA_URL",
            "vision_model": "OLLAMA_VISION_MODEL",
            "text_model": "OLLAMA_TEXT_MODEL",
        },
        "minimax": {
            "api_base_url": "MINIMAX_BASE_URL",
            "vision_model": "MINIMAX_VISION_MODEL",
            "text_model": "MINIMAX_TEXT_MODEL",
        },
    }
    # Key name the dialog used to read the API key from. Nothing in this file
    # writes it; it is still read as a fallback so a key stored under it is
    # not lost.
    _LEGACY_API_KEY = "minimax_api_key"

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("AI 引擎配置")
        self.setModal(True)
        self.setMinimumWidth(520)
        self._load_settings()
        self._init_ui()

    def _load_settings(self):
        """Load saved values from QSettings, else env variables, else AI_DEFAULTS."""
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        self._provider = s.value("ai_provider", "minimax", type=str)

        # Any provider other than "ollama" uses the Minimax fallbacks, matching
        # how the provider combo box is initialised.
        provider_key = "ollama" if self._provider == "ollama" else "minimax"
        defaults = AI_DEFAULTS[provider_key]
        env_names = self._ENV_FALLBACKS[provider_key]

        def _resolve(field):
            saved = s.value(field, "", type=str)
            return saved or os.environ.get(env_names[field], defaults[field])

        self._api_base_url = _resolve("api_base_url")
        self._vision_model = _resolve("vision_model")
        self._text_model = _resolve("text_model")

        # The API key has no default (sensitive; the user must enter it).
        # Read the same key _save_and_close writes; previously this read only
        # the legacy name, so a saved key never appeared and re-saving wiped it.
        self._api_key = (
            s.value("api_key", "", type=str)
            or s.value(self._LEGACY_API_KEY, "", type=str)
        )
        self._timeout = s.value("timeout_s", 120, type=int)

    def _init_ui(self):
        """Build the form: provider, URL, key, model names, timeout, buttons."""
        layout = QVBoxLayout(self)
        layout.setSpacing(12)

        # ── Provider ──────────────────────────────────────────────────────
        provider_row = QHBoxLayout()
        provider_row.addWidget(QLabel("AI 引擎提供商:"))
        self._provider_combo = QComboBox()
        self._provider_combo.addItems(["Ollama", "Minimax"])
        self._provider_combo.setCurrentText("Ollama" if self._provider == "ollama" else "Minimax")
        # Connected after setCurrentText so the initial selection does not
        # overwrite the loaded values with defaults.
        self._provider_combo.currentIndexChanged.connect(self._on_provider_changed)
        provider_row.addWidget(self._provider_combo, 1)
        provider_row.addStretch(1)
        layout.addLayout(provider_row)

        # ── API Base URL ──────────────────────────────────────────────────
        url_row = QHBoxLayout()
        url_row.addWidget(QLabel("API Base URL:"))
        self._url_edit = QLineEdit(self._api_base_url)
        self._url_edit.setPlaceholderText("例如: http://localhost:11434")
        url_row.addWidget(self._url_edit, 1)
        layout.addLayout(url_row)

        # ── API Key ───────────────────────────────────────────────────────
        key_row = QHBoxLayout()
        key_row.addWidget(QLabel("API Key:"))
        self._key_edit = QLineEdit(self._api_key)
        # QSettings stores values unencrypted, so do not claim encryption.
        self._key_edit.setPlaceholderText("输入 API Key(明文保存在本机)")
        self._key_edit.setEchoMode(QLineEdit.Password)
        key_row.addWidget(self._key_edit, 1)
        layout.addLayout(key_row)

        # ── Model names ───────────────────────────────────────────────────
        model_row = QHBoxLayout()
        model_row.addWidget(QLabel("视觉模型:"))
        self._vision_edit = QLineEdit(self._vision_model)
        model_row.addWidget(self._vision_edit, 1)
        model_row.addSpacing(12)
        model_row.addWidget(QLabel("文本模型:"))
        self._text_edit = QLineEdit(self._text_model)
        model_row.addWidget(self._text_edit, 1)
        layout.addLayout(model_row)

        # ── Timeout ───────────────────────────────────────────────────────
        timeout_row = QHBoxLayout()
        timeout_row.addWidget(QLabel("请求超时(秒):"))
        self._timeout_spin = QSpinBox()
        self._timeout_spin.setRange(30, 3600)
        self._timeout_spin.setSingleStep(30)
        self._timeout_spin.setValue(self._timeout)
        timeout_row.addWidget(self._timeout_spin)
        timeout_row.addStretch(1)
        layout.addLayout(timeout_row)

        # ── Hint ──────────────────────────────────────────────────────────
        hint = QLabel(
            "提示:切换引擎后将自动填充推荐默认值(可手动修改)。"
            "API Key 仅保存在本机配置中(未加密),请勿在共享电脑上保存。"
        )
        hint.setStyleSheet("color: #888; font-size: 10px;")
        hint.setWordWrap(True)
        layout.addWidget(hint)

        # ── Buttons ───────────────────────────────────────────────────────
        btn_box = QDialogButtonBox()
        save_btn = QPushButton("保存")
        save_btn.setDefault(True)
        save_btn.clicked.connect(self._save_and_close)
        cancel_btn = QPushButton("取消")
        cancel_btn.clicked.connect(self.reject)
        btn_box.addButton(save_btn, QDialogButtonBox.AcceptRole)
        btn_box.addButton(cancel_btn, QDialogButtonBox.RejectRole)
        layout.addWidget(btn_box)

    def _on_provider_changed(self):
        """Fill the form with the recommended defaults of the newly chosen provider."""
        provider = self._provider_combo.currentText().lower()
        defaults = AI_DEFAULTS.get(provider, AI_DEFAULTS["minimax"])
        self._url_edit.setText(defaults["api_base_url"])
        self._vision_edit.setText(defaults["vision_model"])
        self._text_edit.setText(defaults["text_model"])

    def _save_and_close(self):
        """Persist the form to QSettings (plain values, no encryption) and close."""
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        provider = self._provider_combo.currentText().lower()
        s.setValue("ai_provider", provider)
        s.setValue("api_base_url", self._url_edit.text().strip())
        s.setValue("api_key", self._key_edit.text().strip())
        s.setValue("vision_model", self._vision_edit.text().strip())
        s.setValue("text_model", self._text_edit.text().strip())
        s.setValue("timeout_s", self._timeout_spin.value())
        s.sync()
        self.accept()

    @staticmethod
    def read_ai_config_from_settings():
        """
        Read the AI configuration dict from QSettings (used by
        report_generation_panel.py and others).

        Returns a dict with keys: ai_provider / api_base_url / api_key /
        vision_model / text_model / timeout_s. Missing string values come
        back as "". ``api_key`` falls back to the legacy key name when
        ``api_key`` itself is empty.
        """
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        provider = s.value("ai_provider", "minimax", type=str)
        api_key = (
            s.value("api_key", "", type=str)
            or s.value(AISettingsDialog._LEGACY_API_KEY, "", type=str)
        )
        return {
            "ai_provider": provider,
            "api_base_url": s.value("api_base_url", "", type=str),
            "api_key": api_key,
            "vision_model": s.value("vision_model", "", type=str),
            "text_model": s.value("text_model", "", type=str),
            "timeout_s": s.value("timeout_s", 120, type=int),
        }
# ─────────────────────────────────────────────────────────────────────────────
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
import matplotlib.pyplot as _plt
# Global matplotlib font setup (process-wide): prefer CJK-capable sans-serif
# fonts so Chinese labels render, and draw minus signs with an ASCII hyphen
# instead of the Unicode minus glyph, which these fonts may not provide.
_plt.rcParams.update({
    'font.sans-serif': ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
})
class SamplingViewerDialog(QDialog):
    """Interactive viewer for sampling points and their spectra.

    Left/right split:
    - left: matplotlib scatter plot of ``pixel_x`` / ``pixel_y``;
    - right: an info panel (coordinates), a mini table of the first
      ``_MAX_TABLE_ROWS`` points, and a sub-plot of the spectral curve.

    Clicking anywhere on the scatter plot selects the nearest sampling point,
    highlights it, shows its coordinates and plots its spectrum.
    """

    # Columns the CSV must provide (numeric) for plotting / nearest-point lookup.
    _REQUIRED_COLS = ("pixel_x", "pixel_y")
    # Maximum number of rows listed in the mini table.
    _MAX_TABLE_ROWS = 200
    # Lower-case substrings identifying non-spectral columns. Matching is done
    # against the lower-cased column name, so every pattern must be lower-case
    # (the former 'Unnamed' / 'ID' entries could never match, and the pandas
    # "Unnamed: 0" index column was plotted as a band).
    _EXCLUDE_PATTERNS = (
        'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
        'wqi', 'index', 'unnamed', 'id',
        'longitude', 'latitude', 'lon', 'lat',
    )

    def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
        super().__init__(parent)
        self.csv_path = csv_path
        self.df: Optional[pd.DataFrame] = None
        self.selected_idx: Optional[int] = None
        self._init_ui()
        self._load_data()
        self._draw_scatter()

    def _init_ui(self):
        """Build the dialog UI and hook up the scatter-plot click handler."""
        self.setWindowTitle("采样点与光谱查看器")
        self.setModal(False)
        self.resize(1100, 600)
        main_layout = QHBoxLayout(self)

        # --- Left: matplotlib scatter canvas ---
        self._fig = Figure(figsize=(6, 5))
        self._canvas = FigureCanvasQTAgg(self._fig)
        self._ax_scatter = self._fig.add_subplot(111)
        self._reset_scatter_axes("采样点分布(点击查看详情)")
        self._fig.tight_layout()
        main_layout.addWidget(self._canvas, stretch=2)

        # --- Right: info panel + point table + spectrum plot ---
        right_widget = QWidget()
        right_layout = QVBoxLayout(right_widget)
        right_layout.setContentsMargins(0, 0, 0, 0)

        # Coordinate info panel (multi-line rich text)
        self._info_label = QLabel("点击左侧散点图选择采样点")
        self._info_label.setStyleSheet(
            "QLabel { background-color: #f0f0f0; padding: 8px; "
            "border: 1px solid #ccc; border-radius: 4px; font-size: 13px; }"
        )
        right_layout.addWidget(self._info_label)

        # Mini table of sampling points; its caption is updated with the
        # real point count in _fill_table().
        self._point_table = QTableWidget()
        self._point_table.setColumnCount(3)
        self._point_table.setHorizontalHeaderLabels(["像素 X", "像素 Y", "序号"])
        self._point_table.setMaximumHeight(120)
        self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
        self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
        self._table_title = QLabel("采样点列表(共 0 个):")
        right_layout.addWidget(self._table_title)
        right_layout.addWidget(self._point_table)

        # Spectrum sub-plot
        self._fig_right = Figure(figsize=(5, 3))
        self._ax_spectrum = self._fig_right.add_subplot(111)
        self._ax_spectrum.set_xlabel("波段序号")
        self._ax_spectrum.set_ylabel("反射率")
        self._ax_spectrum.set_title("光谱曲线")
        self._fig_right.tight_layout()
        self._canvas_right = FigureCanvasQTAgg(self._fig_right)
        right_layout.addWidget(self._canvas_right, stretch=1)

        main_layout.addWidget(right_widget, stretch=1)

        # Mouse click on the scatter canvas
        self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)

    def _reset_scatter_axes(self, title: str):
        """Clear the scatter axes and restore labels, title and inverted Y axis."""
        self._ax_scatter.clear()
        self._ax_scatter.set_xlabel("像素 X")
        self._ax_scatter.set_ylabel("像素 Y")
        self._ax_scatter.set_title(title)
        # Presumably pixel_y is an image row index (origin top-left), hence the
        # flipped Y axis — TODO confirm against the sampling CSV producer.
        self._ax_scatter.invert_yaxis()

    def _load_data(self):
        """Load the CSV into ``self.df``; leave it ``None`` if it is unusable.

        Unreadable files or files without numeric ``pixel_x``/``pixel_y``
        columns are reported in the info panel instead of raising from the
        constructor.
        """
        self.df = None
        if not os.path.exists(self.csv_path):
            return
        try:
            df = pd.read_csv(self.csv_path)
        except (OSError, UnicodeDecodeError,
                pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
            self._info_label.setText(f"读取采样文件失败:{exc}")
            return
        invalid = [
            col for col in self._REQUIRED_COLS
            if col not in df.columns or not pd.api.types.is_numeric_dtype(df[col])
        ]
        if invalid:
            self._info_label.setText(f"采样文件缺少有效的数值列:{', '.join(invalid)}")
            return
        self.df = df
        self._info_label.setText(
            f"共加载 {len(self.df)} 个采样点,文件:{os.path.basename(self.csv_path)}"
        )

    def _draw_scatter(self):
        """Draw all sampling points (or a "no data" note) and fill the table."""
        self._reset_scatter_axes("采样点分布(点击查看详情)")
        if self.df is None:
            self._ax_scatter.text(
                0.5, 0.5, "无采样数据",
                ha='center', va='center', transform=self._ax_scatter.transAxes,
                color='gray', fontsize=14
            )
            self._fig.canvas.draw_idle()
            return
        self._ax_scatter.scatter(
            self.df['pixel_x'].values, self.df['pixel_y'].values,
            s=15, alpha=0.7, color='#2196F3'
        )
        self._ax_scatter.set_title(f"采样点分布(共 {len(self.df)} 个)")
        self._fig.canvas.draw_idle()
        self._fill_table()

    @staticmethod
    def _fmt_pixel(value) -> str:
        """Format a pixel coordinate as an integer string; blank if NaN/invalid."""
        try:
            return str(int(value))
        except (TypeError, ValueError):
            return ""

    @staticmethod
    def _to_float(value) -> float:
        """Convert a cell to float, mapping unconvertible values to NaN."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return np.nan

    def _fill_table(self):
        """List the first _MAX_TABLE_ROWS points and show the real total count."""
        shown = self.df.head(self._MAX_TABLE_ROWS)
        self._table_title.setText(f"采样点列表(共 {len(self.df)} 个):")
        self._point_table.blockSignals(True)
        self._point_table.setRowCount(len(shown))
        for i, (px, py) in enumerate(zip(shown['pixel_x'], shown['pixel_y'])):
            self._point_table.setItem(i, 0, QTableWidgetItem(self._fmt_pixel(px)))
            self._point_table.setItem(i, 1, QTableWidgetItem(self._fmt_pixel(py)))
            self._point_table.setItem(i, 2, QTableWidgetItem(str(i)))
        self._point_table.blockSignals(False)

    def _on_click(self, event):
        """Scatter click: pick the nearest point, highlight it, plot its spectrum."""
        if self.df is None or self.df.empty:
            return
        # xdata/ydata are None when the click lands outside the axes.
        if event.xdata is None or event.ydata is None:
            return
        xs = self.df['pixel_x'].to_numpy(dtype=float)
        ys = self.df['pixel_y'].to_numpy(dtype=float)
        distances = np.hypot(xs - event.xdata, ys - event.ydata)
        # Rows with NaN coordinates are ignored; bail out if nothing is left.
        if np.isnan(distances).all():
            return
        nearest_idx = int(np.nanargmin(distances))
        self.selected_idx = nearest_idx

        row = self.df.iloc[nearest_idx]
        pixel_x = int(row.get('pixel_x', 0))
        pixel_y = int(row.get('pixel_y', 0))
        x_coord = row.get('x_coord', 'N/A')
        y_coord = row.get('y_coord', 'N/A')
        self._info_label.setText(
            f"<b>选中的采样点 #{nearest_idx}</b><br>"
            f"图像像素坐标: X = {pixel_x}, Y = {pixel_y}<br>"
            f"地理真实坐标: 经度(X) = {x_coord}, 纬度(Y) = {y_coord}"
        )

        # Redraw all points faded, then the selected one in red on top.
        self._reset_scatter_axes(f"采样点分布(共 {len(self.df)} 个)")
        self._ax_scatter.scatter(
            xs, ys,
            s=15, alpha=0.5, color='#90CAF9'
        )
        self._ax_scatter.scatter(
            [pixel_x], [pixel_y], s=80, alpha=0.9,
            color='red', zorder=5, label=f"#{nearest_idx}"
        )
        self._ax_scatter.legend()
        self._fig.canvas.draw_idle()

        self._draw_spectrum(row)

    def _draw_spectrum(self, row: pd.Series):
        """Plot the numeric, non-excluded columns of *row* as a spectral curve.

        The X axis is the position of each spectral column in the CSV
        (0, 1, 2, ...), not a wavelength.
        """
        self._ax_spectrum.clear()
        self._ax_spectrum.set_xlabel("波段序号")
        self._ax_spectrum.set_ylabel("反射率")
        self._ax_spectrum.set_title("光谱曲线")

        spectral_cols = [
            col for col in self.df.columns
            if not any(p in str(col).lower() for p in self._EXCLUDE_PATTERNS)
            and pd.api.types.is_numeric_dtype(self.df[col])
        ]
        if not spectral_cols:
            self._ax_spectrum.text(
                0.5, 0.5, "无可用波段数据",
                ha='center', va='center', transform=self._ax_spectrum.transAxes,
                color='gray', fontsize=11
            )
            self._fig_right.canvas.draw_idle()
            return

        band_values = np.array([self._to_float(row[col]) for col in spectral_cols], dtype=float)
        band_indices = np.arange(len(band_values), dtype=float)
        valid = ~np.isnan(band_values)
        if not valid.any():
            self._ax_spectrum.text(
                0.5, 0.5, "波段数据全为无效值",
                ha='center', va='center', transform=self._ax_spectrum.transAxes,
                color='gray', fontsize=11
            )
        else:
            self._ax_spectrum.plot(
                band_indices[valid], band_values[valid],
                color='#1E88E5', linewidth=1.2
            )
            self._ax_spectrum.grid(True, alpha=0.3)
        self._fig_right.canvas.draw_idle()