# -*- coding: utf-8 -*-
"""
自定义确认对话框集合
按"职责单一 + 不污染主 GUI 文件"原则拆分。
与 water_quality_gui.py 保持 1:1 风格(中文注释 / 顶部 encoding 声明)。
"""
import os
from typing import Optional
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer
from PyQt5.QtGui import QFont
from PyQt5.QtWidgets import (
QDialog,
QLabel,
QSpinBox,
QPushButton,
QVBoxLayout,
QHBoxLayout,
QDialogButtonBox,
QWidget,
QComboBox,
QLineEdit,
QTableWidget,
QTableWidgetItem,
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
class BandConfirmDialog(QDialog):
    """Smart confirmation dialog for an out-of-range band index (60 s countdown).

    Scenario
    --------
    The user set e.g. ``nir_band=66`` in the step-3 panel, but the image only
    has 50 bands; a pipeline that reads band 66 would raise ``IndexError``.

    Behaviour contract
    ------------------
    - On start a 1 Hz ``QTimer`` counts down ``timeout_seconds`` (default 60)
      and the OK button text mirrors it as "确定 (Ns)".
    - Any change to the spin box restarts the countdown, so the automatic
      accept only fires after ``timeout_seconds`` without user interaction.
    - Clicking "确定": ``accept()`` immediately; ``selected_band()`` then
      returns the current spin box value.
    - Countdown reaches zero: ``accept()`` automatically with the current
      spin box value (initially ``max_band``, the fallback when the desired
      band is unavailable).
    - Clicking "取消运行": ``reject()``; the caller should abort
      ``run_full_pipeline``.
    """

    DEFAULT_TIMEOUT = 60  # seconds

    def __init__(
        self,
        parent: QWidget = None,
        requested_band: int = 66,
        max_band: int = 50,
        recommended_band: int = 66,
        method_label: str = "NIR",
        timeout_seconds: int = DEFAULT_TIMEOUT,
    ):
        """Store the band parameters, build the UI and start the countdown.

        Args:
            parent: Optional parent widget.
            requested_band: Band index the user asked for (out of range).
            max_band: Number of bands in the image; also the default choice.
            recommended_band: Band index shown as a grey recommendation.
            method_label: Label of the index being configured (e.g. "NIR").
            timeout_seconds: Inactivity period before auto-accepting.
        """
        super().__init__(parent)
        self._requested_band = requested_band
        self._max_band = max_band
        self._recommended_band = recommended_band
        self._method_label = method_label
        self._timeout_seconds = timeout_seconds
        self._remaining = timeout_seconds
        # Returned by selected_band() until accept() records the spin box
        # value (e.g. after reject()); max_band is the fallback.
        self._selected_band = max_band
        self.setWindowTitle("波段索引越界")
        self.setModal(True)
        self.setMinimumWidth(420)
        self._init_ui()
        self._start_timer()

    def _init_ui(self):
        """Build the UI: warning text, grey hint, spin box, countdown buttons."""
        layout = QVBoxLayout(self)

        # 1) Main warning message.
        self._msg_label = QLabel(
            f"影像仅有 {self._max_band} 个波段,"
            f"无法读取第 {self._requested_band} 波段({self._method_label})。"
        )
        self._msg_label.setWordWrap(True)
        self._msg_label.setFont(QFont("Microsoft YaHei", 10))
        layout.addWidget(self._msg_label)

        # 2) Grey recommendation hint.
        hint_label = QLabel(
            f"(推荐近红外波段序号:{self._recommended_band})"
        )
        hint_label.setStyleSheet("color: #888; font-size: 10px;")
        layout.addWidget(hint_label)

        # 3) Band spin box; initial value = max_band (the timeout fallback).
        # NOTE(review): the range 0..max_band holds max_band + 1 values, so one
        # end is not a valid band for a max_band-band image (0 if bands are
        # 1-based, max_band if 0-based). Confirm the pipeline's indexing
        # convention before narrowing the range.
        spin_row = QHBoxLayout()
        spin_row.addWidget(QLabel(f"请选择要使用的{self._method_label}索引:"))
        self._spin = QSpinBox()
        self._spin.setRange(0, self._max_band)
        self._spin.setValue(self._max_band)
        self._spin.setSuffix(f" / 0~{self._max_band}")
        spin_row.addWidget(self._spin)
        spin_row.addStretch(1)
        layout.addLayout(spin_row)

        # 4) Countdown explanation.
        countdown_tip = QLabel(
            f"⏱ {self._timeout_seconds} 秒内不操作,将自动使用最大波段 "
            f"({self._max_band})继续运行。"
        )
        countdown_tip.setStyleSheet("color: #555; font-size: 9px;")
        countdown_tip.setWordWrap(True)
        layout.addWidget(countdown_tip)

        # 5) Buttons: "确定 (Ns)" + "取消运行".
        btn_box = QDialogButtonBox()
        self._ok_btn = QPushButton(f"确定 ({self._remaining}s)")
        self._ok_btn.setDefault(True)
        self._ok_btn.clicked.connect(self.accept)
        self._cancel_btn = QPushButton("取消运行")
        self._cancel_btn.clicked.connect(self.reject)
        btn_box.addButton(self._ok_btn, QDialogButtonBox.AcceptRole)
        btn_box.addButton(self._cancel_btn, QDialogButtonBox.RejectRole)
        layout.addWidget(btn_box)

        # Connected last (after _ok_btn exists): any user edit restarts the
        # countdown, so auto-accept never submits a value mid-adjustment.
        self._spin.valueChanged.connect(self._restart_countdown)

    def _start_timer(self):
        """Start the 1 Hz countdown; _tick() auto-accepts when it reaches zero."""
        self._timer = QTimer(self)
        self._timer.setInterval(1000)
        self._timer.timeout.connect(self._tick)
        self._timer.start()

    def _restart_countdown(self, _value=None):
        """Reset the countdown to its full length after user interaction."""
        self._remaining = self._timeout_seconds
        self._ok_btn.setText(f"确定 ({self._remaining}s)")

    def _tick(self):
        """Refresh the button text every second; stop and accept() at zero."""
        self._remaining -= 1
        if self._remaining <= 0:
            self._timer.stop()
            self.accept()  # timeout: record the current spin box value
        else:
            self._ok_btn.setText(f"确定 ({self._remaining}s)")

    # ── Result interface for the caller ─────────────────────────────────
    def selected_band(self) -> int:
        """Return the band index chosen when the dialog was accepted.

        If the dialog was rejected this is still the initial ``max_band``.
        """
        return self._selected_band

    def accept(self):
        """OK click or countdown expiry: record the spin box value, then close."""
        self._selected_band = self._spin.value()
        self._timer.stop()
        super().accept()

    def reject(self):
        """"取消运行" click: stop the timer and close; the caller should abort."""
        self._timer.stop()
        super().reject()
# ─────────────────────────────────────────────────────────────────────────────
# AI 引擎设置对话框
# ─────────────────────────────────────────────────────────────────────────────
from PyQt5.QtCore import QSettings
# QSettings organisation / application names shared by every reader and
# writer of the AI engine configuration.
AI_SETTINGS_ORG = "IrisWaterQuality"
AI_SETTINGS_APP = "WQ_GUI"

# Recommended per-provider values. AISettingsDialog fills them in when the
# provider is switched, and falls back to them when neither QSettings nor the
# corresponding environment variable supplies a value.
AI_DEFAULTS = dict(
    ollama=dict(
        api_base_url="http://localhost:11434",
        vision_model="qwen3-vl:8b",
        text_model="qwen3-vl:8b",
    ),
    minimax=dict(
        api_base_url="https://api.minimaxi.com/v1/text/chatcompletion_v2",
        vision_model="abab6.5s-chat",
        text_model="abab6.5s-chat",
    ),
)
class AISettingsDialog(QDialog):
    """Visual configuration dialog for the AI engine, persisted in QSettings.

    The same QSettings keys are written by ``_save_and_close`` and read back by
    ``_load_settings`` and ``read_ai_config_from_settings``:
    ``ai_provider`` / ``api_base_url`` / ``api_key`` / ``vision_model`` /
    ``text_model`` / ``timeout_s``.
    """

    # Environment variables consulted when QSettings holds no value for a
    # setting, per provider.
    _ENV_FALLBACKS = {
        "ollama": {
            "api_base_url": "OLLAMA_URL",
            "vision_model": "OLLAMA_VISION_MODEL",
            "text_model": "OLLAMA_TEXT_MODEL",
        },
        "minimax": {
            "api_base_url": "MINIMAX_BASE_URL",
            "vision_model": "MINIMAX_VISION_MODEL",
            "text_model": "MINIMAX_TEXT_MODEL",
        },
    }

    # Key name formerly read by _load_settings for the API key. It is still
    # consulted (read-only) so a key stored under it is not lost.
    _LEGACY_API_KEY = "minimax_api_key"

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("AI 引擎配置")
        self.setModal(True)
        self.setMinimumWidth(520)
        self._load_settings()
        self._init_ui()

    @classmethod
    def _resolve_provider_value(cls, settings, provider, key):
        """Return stored *key*, else its provider env var, else AI_DEFAULTS.

        Unknown providers use the Minimax table.
        """
        stored = settings.value(key, "", type=str)
        if stored:
            return stored
        table = provider if provider in cls._ENV_FALLBACKS else "minimax"
        return os.environ.get(cls._ENV_FALLBACKS[table][key], AI_DEFAULTS[table][key])

    @classmethod
    def _read_api_key(cls, settings):
        """Return the saved API key (``api_key``, then the legacy key), or ""."""
        api_key = settings.value("api_key", "", type=str)
        return api_key or settings.value(cls._LEGACY_API_KEY, "", type=str)

    def _load_settings(self):
        """Read the saved configuration; fall back to env vars or defaults."""
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        self._provider = s.value("ai_provider", "minimax", type=str)
        # No default for the API key: it is sensitive and must come from the
        # user. Read the same key _save_and_close writes ("api_key"); reading
        # only "minimax_api_key" showed an empty field on reopen and a second
        # save then overwrote the stored key with "".
        self._api_key = self._read_api_key(s)
        self._api_base_url = self._resolve_provider_value(s, self._provider, "api_base_url")
        self._vision_model = self._resolve_provider_value(s, self._provider, "vision_model")
        self._text_model = self._resolve_provider_value(s, self._provider, "text_model")
        self._timeout = s.value("timeout_s", 120, type=int)

    def _init_ui(self):
        """Build provider / URL / key / model / timeout rows and the buttons."""
        layout = QVBoxLayout(self)
        layout.setSpacing(12)

        # ── Provider ──────────────────────────────────────────────────────
        provider_row = QHBoxLayout()
        provider_row.addWidget(QLabel("AI 引擎提供商:"))
        self._provider_combo = QComboBox()
        self._provider_combo.addItems(["Ollama", "Minimax"])
        self._provider_combo.setCurrentText("Ollama" if self._provider == "ollama" else "Minimax")
        # Connected after setCurrentText so the loaded values are not
        # immediately overwritten with defaults.
        self._provider_combo.currentIndexChanged.connect(self._on_provider_changed)
        provider_row.addWidget(self._provider_combo, 1)
        provider_row.addStretch(1)
        layout.addLayout(provider_row)

        # ── API Base URL ──────────────────────────────────────────────────
        url_row = QHBoxLayout()
        url_row.addWidget(QLabel("API Base URL:"))
        self._url_edit = QLineEdit(self._api_base_url)
        self._url_edit.setPlaceholderText("例如: http://localhost:11434")
        url_row.addWidget(self._url_edit, 1)
        layout.addLayout(url_row)

        # ── API Key ───────────────────────────────────────────────────────
        # QSettings does not encrypt, so the UI text must not promise it.
        key_row = QHBoxLayout()
        key_row.addWidget(QLabel("API Key:"))
        self._key_edit = QLineEdit(self._api_key)
        self._key_edit.setPlaceholderText("输入 API Key(敏感信息,以明文保存在本机配置中)")
        self._key_edit.setEchoMode(QLineEdit.Password)
        key_row.addWidget(self._key_edit, 1)
        layout.addLayout(key_row)

        # ── Model names ───────────────────────────────────────────────────
        model_row = QHBoxLayout()
        model_row.addWidget(QLabel("视觉模型:"))
        self._vision_edit = QLineEdit(self._vision_model)
        model_row.addWidget(self._vision_edit, 1)
        model_row.addSpacing(12)
        model_row.addWidget(QLabel("文本模型:"))
        self._text_edit = QLineEdit(self._text_model)
        model_row.addWidget(self._text_edit, 1)
        layout.addLayout(model_row)

        # ── Timeout ───────────────────────────────────────────────────────
        timeout_row = QHBoxLayout()
        timeout_row.addWidget(QLabel("请求超时(秒):"))
        self._timeout_spin = QSpinBox()
        self._timeout_spin.setRange(30, 3600)
        self._timeout_spin.setSingleStep(30)
        self._timeout_spin.setValue(self._timeout)
        timeout_row.addWidget(self._timeout_spin)
        timeout_row.addStretch(1)
        layout.addLayout(timeout_row)

        # ── Hint ──────────────────────────────────────────────────────────
        hint = QLabel(
            "提示:切换引擎后将自动填充推荐默认值(可手动修改)。"
            "API Key 以明文形式保存在本机配置(QSettings)中,请勿在共享设备上保存。"
        )
        hint.setStyleSheet("color: #888; font-size: 10px;")
        hint.setWordWrap(True)
        layout.addWidget(hint)

        # ── Buttons ───────────────────────────────────────────────────────
        btn_box = QDialogButtonBox()
        save_btn = QPushButton("保存")
        save_btn.setDefault(True)
        save_btn.clicked.connect(self._save_and_close)
        cancel_btn = QPushButton("取消")
        cancel_btn.clicked.connect(self.reject)
        btn_box.addButton(save_btn, QDialogButtonBox.AcceptRole)
        btn_box.addButton(cancel_btn, QDialogButtonBox.RejectRole)
        layout.addWidget(btn_box)

    def _on_provider_changed(self):
        """Fill in the recommended defaults for the newly selected provider."""
        provider = self._provider_combo.currentText().lower()
        defaults = AI_DEFAULTS.get(provider, AI_DEFAULTS["minimax"])
        self._url_edit.setText(defaults["api_base_url"])
        self._vision_edit.setText(defaults["vision_model"])
        self._text_edit.setText(defaults["text_model"])

    def _save_and_close(self):
        """Persist every field to QSettings, flush, and accept the dialog."""
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        s.setValue("ai_provider", self._provider_combo.currentText().lower())
        s.setValue("api_base_url", self._url_edit.text().strip())
        s.setValue("api_key", self._key_edit.text().strip())
        s.setValue("vision_model", self._vision_edit.text().strip())
        s.setValue("text_model", self._text_edit.text().strip())
        s.setValue("timeout_s", self._timeout_spin.value())
        s.sync()
        self.accept()

    @staticmethod
    def read_ai_config_from_settings():
        """Return the saved AI configuration as a dict.

        Used by report_generation_panel.py and similar callers. Keys:
        ai_provider / api_base_url / api_key / vision_model / text_model /
        timeout_s. Unsaved string settings are "" (``ai_provider`` defaults
        to "minimax", ``timeout_s`` to 120). ``api_key`` also falls back to
        the legacy ``minimax_api_key`` entry.
        """
        s = QSettings(AI_SETTINGS_ORG, AI_SETTINGS_APP)
        return {
            "ai_provider": s.value("ai_provider", "minimax", type=str),
            "api_base_url": s.value("api_base_url", "", type=str),
            "api_key": AISettingsDialog._read_api_key(s),
            "vision_model": s.value("vision_model", "", type=str),
            "text_model": s.value("text_model", "", type=str),
            "timeout_s": s.value("timeout_s", 120, type=int),
        }
# ─────────────────────────────────────────────────────────────────────────────
# 交互式采样点与光谱查看器
# ─────────────────────────────────────────────────────────────────────────────
import matplotlib.pyplot as _plt
# Global matplotlib font settings: list CJK-capable sans-serif fonts first so
# Chinese labels are not garbled, and disable the Unicode minus glyph so
# negative numbers are not rendered as broken characters with these fonts.
_plt.rcParams.update({
    'font.sans-serif': ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
})
class SamplingViewerDialog(QDialog):
    """Interactive viewer for sampling points and their spectra.

    Left/right split:
    - Left: matplotlib scatter plot of ``pixel_x`` / ``pixel_y``.
    - Right: coordinate info panel, a mini table of the points, and a
      spectrum sub-plot.

    Clicking anywhere on the scatter plot selects the nearest sampling point,
    highlights it, shows its coordinates and plots its spectrum.

    The CSV must contain numeric ``pixel_x`` and ``pixel_y`` columns;
    ``x_coord`` / ``y_coord`` are shown when present. Every other numeric
    column whose lower-cased name contains none of ``_EXCLUDE_PATTERNS`` is
    plotted as a band value, in column order.
    """

    # Maximum number of rows rendered in the mini table.
    _TABLE_MAX_ROWS = 200
    # Columns the scatter plot and the click lookup cannot work without.
    _REQUIRED_COLUMNS = ('pixel_x', 'pixel_y')
    # Substrings that mark a column as non-spectral. They are compared with
    # the lower-cased column name, so every entry must itself be lower-case:
    # a mixed-case entry such as 'Unnamed' never matches, which would let the
    # pandas "Unnamed: 0" index column leak into the spectrum as a band.
    _EXCLUDE_PATTERNS = (
        'pixel_x', 'pixel_y', 'x_coord', 'y_coord',
        'wqi', 'index', 'unnamed', 'id',
        'longitude', 'latitude', 'lon', 'lat',
    )

    def __init__(self, csv_path: str, parent: Optional[QWidget] = None):
        """Build the UI, load *csv_path* and draw the initial scatter plot."""
        super().__init__(parent)
        self.csv_path = csv_path
        self.df: Optional[pd.DataFrame] = None
        self.selected_idx: Optional[int] = None
        self._init_ui()
        self._load_data()
        self._draw_scatter()

    # ── UI construction ─────────────────────────────────────────────────
    def _init_ui(self):
        """Build the dialog: scatter canvas on the left, details on the right."""
        self.setWindowTitle("采样点与光谱查看器")
        self.setModal(False)
        self.resize(1100, 600)
        # Passing `self` installs the layout on the dialog; no setLayout() needed.
        main_layout = QHBoxLayout(self)

        # --- Left: matplotlib scatter canvas ---
        self._fig = Figure(figsize=(6, 5))
        self._canvas = FigureCanvasQTAgg(self._fig)
        self._ax_scatter = self._fig.add_subplot(111)
        self._reset_scatter_axes("采样点分布(点击查看详情)")
        self._fig.tight_layout()
        main_layout.addWidget(self._canvas, stretch=2)

        # --- Right: info panel + point table + spectrum plot ---
        right_widget = QWidget()
        right_layout = QVBoxLayout(right_widget)
        right_layout.setContentsMargins(0, 0, 0, 0)

        # Multi-line coordinate info panel.
        self._info_label = QLabel("点击左侧散点图选择采样点")
        self._info_label.setStyleSheet(
            "QLabel { background-color: #f0f0f0; padding: 8px; "
            "border: 1px solid #ccc; border-radius: 4px; font-size: 13px; }"
        )
        right_layout.addWidget(self._info_label)

        # Mini table of sampling points; its title is updated with the real
        # count in _draw_scatter().
        self._point_table = QTableWidget()
        self._point_table.setColumnCount(3)
        self._point_table.setHorizontalHeaderLabels(["像素 X", "像素 Y", "序号"])
        self._point_table.setMaximumHeight(120)
        self._point_table.setEditTriggers(QTableWidget.NoEditTriggers)
        self._point_table.setSelectionBehavior(QTableWidget.SelectRows)
        self._table_title = QLabel("采样点列表(共 0 个):")
        right_layout.addWidget(self._table_title)
        right_layout.addWidget(self._point_table)

        # Spectrum sub-plot.
        self._fig_right = Figure(figsize=(5, 3))
        self._ax_spectrum = self._fig_right.add_subplot(111)
        self._reset_spectrum_axes()
        self._fig_right.tight_layout()
        self._canvas_right = FigureCanvasQTAgg(self._fig_right)
        right_layout.addWidget(self._canvas_right, stretch=1)
        main_layout.addWidget(right_widget, stretch=1)

        # Mouse clicks on the scatter canvas select the nearest point.
        self._cid = self._canvas.mpl_connect('button_press_event', self._on_click)

    def _reset_scatter_axes(self, title):
        """Clear the scatter axes and restore labels, *title* and flipped Y."""
        ax = self._ax_scatter
        ax.clear()  # clear() also resets the Y inversion, so re-apply it below
        ax.set_xlabel("像素 X")
        ax.set_ylabel("像素 Y")
        ax.set_title(title)
        # Put the origin at the top-left, presumably to match image pixel rows.
        ax.invert_yaxis()

    def _reset_spectrum_axes(self):
        """Clear the spectrum axes and restore its labels and title."""
        ax = self._ax_spectrum
        ax.clear()
        ax.set_xlabel("波段序号")
        ax.set_ylabel("反射率")
        ax.set_title("光谱曲线")

    # ── Data loading / drawing ──────────────────────────────────────────
    def _load_data(self):
        """Load the CSV into ``self.df``; keep it ``None`` and explain why if unusable."""
        self.df = None
        name = os.path.basename(self.csv_path)
        if not os.path.exists(self.csv_path):
            self._info_label.setText(f"未找到采样文件:{name}")
            return
        try:
            df = pd.read_csv(self.csv_path)
        except (OSError, UnicodeDecodeError,
                pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
            self._info_label.setText(f"无法读取采样文件 {name}:{exc}")
            return
        bad_cols = [
            col for col in self._REQUIRED_COLUMNS
            if col not in df.columns or not pd.api.types.is_numeric_dtype(df[col])
        ]
        if bad_cols:
            self._info_label.setText(f"采样文件 {name} 缺少数值列:{', '.join(bad_cols)}")
            return
        if df.empty:
            # An empty frame would make the nearest-point lookup fail.
            self._info_label.setText(f"采样文件 {name} 中没有采样点")
            return
        self.df = df
        self._info_label.setText(f"共加载 {len(self.df)} 个采样点,文件:{name}")

    @staticmethod
    def _int_text(value):
        """Format a pixel coordinate as an integer string ('N/A' for NaN)."""
        return 'N/A' if pd.isna(value) else str(int(value))

    def _draw_scatter(self):
        """Draw every sampling point and refresh the mini table."""
        if self.df is None:
            self._reset_scatter_axes("采样点分布(点击查看详情)")
            self._ax_scatter.text(
                0.5, 0.5, "无采样数据",
                ha='center', va='center', transform=self._ax_scatter.transAxes,
                color='gray', fontsize=14
            )
            self._point_table.setRowCount(0)
            self._table_title.setText("采样点列表(共 0 个):")
            self._fig.canvas.draw_idle()
            return

        self._reset_scatter_axes(f"采样点分布(共 {len(self.df)} 个)")
        self._ax_scatter.scatter(
            self.df['pixel_x'].values, self.df['pixel_y'].values,
            s=15, alpha=0.7, color='#2196F3'
        )
        self._fig.canvas.draw_idle()
        self._fill_point_table()

    def _fill_point_table(self):
        """Fill the mini table with at most ``_TABLE_MAX_ROWS`` points and update its title."""
        total = len(self.df)
        shown = self.df.head(self._TABLE_MAX_ROWS)
        table = self._point_table
        table.blockSignals(True)
        try:
            table.setRowCount(len(shown))
            for i, (px, py) in enumerate(zip(shown['pixel_x'], shown['pixel_y'])):
                table.setItem(i, 0, QTableWidgetItem(self._int_text(px)))
                table.setItem(i, 1, QTableWidgetItem(self._int_text(py)))
                table.setItem(i, 2, QTableWidgetItem(str(i)))
        finally:
            table.blockSignals(False)
        if total > len(shown):
            self._table_title.setText(f"采样点列表(共 {total} 个,仅显示前 {len(shown)} 个):")
        else:
            self._table_title.setText(f"采样点列表(共 {total} 个):")

    # ── Interaction ─────────────────────────────────────────────────────
    def _on_click(self, event):
        """Select the point nearest the click, highlight it and plot its spectrum."""
        if self.df is None or event.xdata is None or event.ydata is None:
            return
        xs = self.df['pixel_x'].to_numpy(dtype=float)
        ys = self.df['pixel_y'].to_numpy(dtype=float)
        distances = np.hypot(xs - event.xdata, ys - event.ydata)
        if np.isnan(distances).all():
            return  # no point has usable coordinates
        # nanargmin skips rows with missing coordinates; plain argmin would
        # return the first NaN position instead of the nearest point.
        nearest_idx = int(np.nanargmin(distances))
        self.selected_idx = nearest_idx

        row = self.df.iloc[nearest_idx]
        pixel_x = int(row['pixel_x'])
        pixel_y = int(row['pixel_y'])
        x_coord = row.get('x_coord', 'N/A')
        y_coord = row.get('y_coord', 'N/A')
        self._info_label.setText(
            f"选中的采样点 #{nearest_idx}\n"
            f"图像像素坐标: X = {pixel_x}, Y = {pixel_y}\n"
            f"地理真实坐标: 经度(X) = {x_coord}, 纬度(Y) = {y_coord}"
        )

        # Redraw all points dimmed, then the selected one in red on top.
        self._reset_scatter_axes(f"采样点分布(共 {len(self.df)} 个)")
        self._ax_scatter.scatter(xs, ys, s=15, alpha=0.5, color='#90CAF9')
        self._ax_scatter.scatter(
            [pixel_x], [pixel_y], s=80, alpha=0.9,
            color='red', zorder=5, label=f"#{nearest_idx}"
        )
        self._ax_scatter.legend()
        self._fig.canvas.draw_idle()

        self._draw_spectrum(row)

    def _spectral_columns(self):
        """Return numeric columns whose lower-cased name matches no exclude pattern."""
        return [
            col for col in self.df.columns
            if not any(p in str(col).lower() for p in self._EXCLUDE_PATTERNS)
            and pd.api.types.is_numeric_dtype(self.df[col])
        ]

    def _draw_spectrum(self, row: pd.Series):
        """Plot the band values of *row* against their band position."""
        self._reset_spectrum_axes()
        spectral_cols = self._spectral_columns()
        if not spectral_cols:
            self._ax_spectrum.text(
                0.5, 0.5, "无可用波段数据",
                ha='center', va='center', transform=self._ax_spectrum.transAxes,
                color='gray', fontsize=11
            )
            self._fig_right.canvas.draw_idle()
            return

        # Non-convertible values become NaN and are skipped in the plot.
        band_values = pd.to_numeric(row[spectral_cols], errors='coerce').to_numpy(dtype=float)
        band_indices = np.arange(band_values.size, dtype=float)
        valid = ~np.isnan(band_values)
        if not valid.any():
            self._ax_spectrum.text(
                0.5, 0.5, "波段数据全为无效值",
                ha='center', va='center', transform=self._ax_spectrum.transAxes,
                color='gray', fontsize=11
            )
        else:
            self._ax_spectrum.plot(
                band_indices[valid], band_values[valid],
                color='#1E88E5', linewidth=1.2
            )
        self._ax_spectrum.grid(True, alpha=0.3)
        self._fig_right.canvas.draw_idle()