From 191a4b681df0f30993fdf291489ba5d096314e86 Mon Sep 17 00:00:00 2001 From: DXC Date: Wed, 17 Jun 2026 15:16:19 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E4=B8=BB?= =?UTF-8?q?=E7=95=8C=E9=9D=A2=E9=87=8D=E5=A4=8D=E4=BB=A3=E7=A0=81=EF=BC=8C?= =?UTF-8?q?=E5=A4=8D=E7=94=A8=E7=8E=B0=E6=9C=89=E7=BB=84=E4=BB=B6=E5=B9=B6?= =?UTF-8?q?=E5=BD=BB=E5=BA=95=E6=8A=BD=E7=A6=BB=E5=9B=BE=E5=83=8F=E6=8E=A7?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/gui/components/chart_dialogs.py | 45 +- src/gui/components/image_viewer_components.py | 374 ++++++ src/gui/core/viz_thread.py | 5 + src/gui/water_quality_gui.py | 1170 +---------------- 4 files changed, 425 insertions(+), 1169 deletions(-) create mode 100644 src/gui/components/image_viewer_components.py diff --git a/src/gui/components/chart_dialogs.py b/src/gui/components/chart_dialogs.py index 61d94d0..2d8a249 100644 --- a/src/gui/components/chart_dialogs.py +++ b/src/gui/components/chart_dialogs.py @@ -7,12 +7,13 @@ """ import numpy as np +import pandas as pd from PyQt5.QtWidgets import ( QDialog, QVBoxLayout, QHBoxLayout, QPushButton, QSizePolicy, QFileDialog, QMessageBox, QGroupBox, QListWidget, QLabel, QComboBox, QCheckBox, ) -from PyQt5.QtCore import Qt +from PyQt5.QtCore import Qt, QAbstractTableModel from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar from matplotlib.figure import Figure @@ -427,4 +428,44 @@ class InteractiveViewerDialog(QDialog): f"像素: (行={py}, 列={px}) | {geo_str} | " f"波段值: {' | '.join(vals[:5])}" + (f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "") - ) \ No newline at end of file + ) + + +class PandasTableModel(QAbstractTableModel): + """支持DataFrame的表格模型""" + def __init__(self, data_frame: pd.DataFrame): + super().__init__() + self._data = data_frame.copy() + if self._data.empty: + self._data = pd.DataFrame() + self._data.fillna("", inplace=True) + self._columns = [str(col) for col in self._data.columns] + + def rowCount(self, parent=None): + return len(self._data) + + def columnCount(self, parent=None): + return len(self._columns) + + def data(self, index, role=Qt.DisplayRole): + if not index.isValid() or role != Qt.DisplayRole: + return None + + value = self._data.iat[index.row(), index.column()] + if pd.isna(value): + return "" + return str(value) + + def headerData(self, section, orientation, role=Qt.DisplayRole): + if role != Qt.DisplayRole: + return None + if orientation == Qt.Horizontal: + if section < len(self._columns): + return self._columns[section] + return str(section) + return str(section + 1) + + def flags(self, index): + if not index.isValid(): + return Qt.NoItemFlags + return Qt.ItemIsEnabled | Qt.ItemIsSelectable \ No newline at end of file diff --git a/src/gui/components/image_viewer_components.py b/src/gui/components/image_viewer_components.py new file mode 100644 index 0000000..22808ac --- /dev/null +++ b/src/gui/components/image_viewer_components.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +图像查看器组件模块 + +包含 ImageCategoryTree 和 ImageViewerWidget 类。 +""" + +import os +from pathlib import Path +from typing import Optional, List + +from PyQt5.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QPushButton, + QFrame, QScrollArea, QLabel, QFileDialog, QMessageBox, + QTreeWidget, QTreeWidgetItem, +) +from PyQt5.QtCore import Qt, QTimer +from PyQt5.QtGui import QPixmap + + +class ImageCategoryTree(QTreeWidget): + """图像分类目录树 - 按类别组织图像文件""" + + # 图像类别定义:(类别名称, 关键词列表, 图标) + CATEGORIES = [ + ("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"), + ("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"), + ("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"), + ("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"), + ("含量分布图", [], "📁"), + ] + + def __init__(self, parent=None): + super().__init__(parent) + self.setHeaderLabel("图像目录") + self.setMaximumWidth(300) + self.setMinimumWidth(250) + self.setup_categories() + self.setStyleSheet(""" + QTreeWidget { + border: 1px solid #ddd; + border-radius: 5px; + background-color: #f8f9fa; + } + QTreeWidget::item { + padding: 5px; + border-radius: 3px; + } + QTreeWidget::item:selected { + background-color: #0078D4; + color: white; + } + QTreeWidget::item:hover { + background-color: #e3f2fd; + } + """) + + def setup_categories(self): + """初始化类别节点""" + self.category_items = {} + for category_name, keywords, icon in self.CATEGORIES: + item = QTreeWidgetItem(self) + item.setText(0, f"{icon} {category_name}") + item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name}) + item.setExpanded(True) + self.category_items[category_name] = item + + def clear_all_images(self): + """清除所有图像项""" + for category_item in self.category_items.values(): + # 删除所有子项 + while category_item.childCount() > 0: + category_item.removeChild(category_item.child(0)) + + def add_image(self, file_path: Path, display_name: str = None): + """添加图像到对应的类别""" + if display_name is None: + display_name = file_path.stem + + # 根据文件名关键词确定类别 + category = self._determine_category(file_path.name) + category_item = self.category_items.get(category, self.category_items["含量分布图"]) + + # 创建图像项 + image_item = QTreeWidgetItem(category_item) + image_item.setText(0, f" └─ {display_name}") + image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)}) + image_item.setToolTip(0, str(file_path)) + + return image_item + + def _determine_category(self, filename: str) -> str: + """根据文件名确定类别""" + filename_lower = filename.lower() + + for category_name, keywords, _ in self.CATEGORIES: + if any(keyword in filename_lower for keyword in keywords): + return category_name + + return "含量分布图" + + def scan_directory(self, work_dir: str): + """扫描目录中的所有图像文件""" + self.clear_all_images() + + work_path = Path(work_dir) + if not work_path.exists(): + return + + # 查找所有图像文件:14_visualization 为主,同时扫描步骤产出目录(如 1_water_mask 下的预览/叠置图) + image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp'] + scan_roots: List[Path] = [] + _viz = work_path / "14_visualization" + if _viz.is_dir(): + scan_roots.append(_viz) + _wm = work_path / "1_water_mask" + if _wm.is_dir(): + scan_roots.append(_wm) + if not scan_roots: + scan_roots.append(work_path) + + seen_norm: set = set() + image_files: List[Path] = [] + for root in scan_roots: + for ext in image_extensions: + for p in root.glob(f"**/{ext}"): + key = os.path.normcase(os.path.normpath(str(p.resolve()))) + if key in seen_norm: + continue + seen_norm.add(key) + image_files.append(p) + + # 添加图像到树 + for img_file in sorted(image_files): + # 跳过缩略图和临时文件 + if img_file.name.startswith('.') or 'thumb' in img_file.name.lower(): + continue + self.add_image(img_file) + + # 更新类别项文本显示数量 + for category_name, item in self.category_items.items(): + count = item.childCount() + if count > 0: + for cat_name, _, icon in self.CATEGORIES: + if cat_name == category_name: + item.setText(0, f"{icon} {category_name} ({count})") + break + + def get_selected_image_path(self) -> Optional[str]: + """获取当前选中的图像路径""" + selected_item = self.currentItem() + if not selected_item: + return None + + data = selected_item.data(0, Qt.UserRole) + if data and data.get("type") == "image": + return data.get("path") + return None + + +class ImageViewerWidget(QWidget): + """图像查看器组件 - 支持缩放、平移""" + + def __init__(self, parent=None): + super().__init__(parent) + self.current_image_path = None + self.scale_factor = 1.0 + self._update_timer = QTimer() # 防抖定时器 + self._update_timer.setSingleShot(True) + self._update_timer.timeout.connect(self._do_update_display) + self._pending_scale = None # 待更新的缩放比例 + self.setup_ui() + + def setup_ui(self): + layout = QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + + # 工具栏 + toolbar = QHBoxLayout() + + self.refresh_btn = QPushButton("🔄 刷新目录") + self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件") + toolbar.addWidget(self.refresh_btn) + + # 添加分隔线 + separator = QFrame() + separator.setFrameShape(QFrame.VLine) + separator.setFrameShadow(QFrame.Sunken) + toolbar.addWidget(separator) + + self.zoom_in_btn = QPushButton("🔍+") + self.zoom_in_btn.setToolTip("放大") + self.zoom_in_btn.setMaximumWidth(50) + toolbar.addWidget(self.zoom_in_btn) + + self.zoom_out_btn = QPushButton("🔍-") + self.zoom_out_btn.setToolTip("缩小") + self.zoom_out_btn.setMaximumWidth(50) + toolbar.addWidget(self.zoom_out_btn) + + self.fit_btn = QPushButton("⬜ 适应窗口") + self.fit_btn.setToolTip("适应窗口大小") + toolbar.addWidget(self.fit_btn) + + self.original_btn = QPushButton("1:1 原始大小") + self.original_btn.setToolTip("原始大小") + toolbar.addWidget(self.original_btn) + + toolbar.addStretch() + + self.save_btn = QPushButton("💾 保存") + self.save_btn.setToolTip("保存当前图像") + toolbar.addWidget(self.save_btn) + + layout.addLayout(toolbar) + + # 图像显示区域 - 使用 QLabel + QScrollArea + self.scroll_area = QScrollArea() + self.scroll_area.setWidgetResizable(True) + self.scroll_area.setStyleSheet("background-color: white;") + + self.image_label = QLabel() + self.image_label.setAlignment(Qt.AlignCenter) + self.image_label.setStyleSheet("background-color: white;") + + self.scroll_area.setWidget(self.image_label) + layout.addWidget(self.scroll_area, 1) + + # 状态栏 + status_layout = QHBoxLayout() + self.status_label = QLabel("就绪") + self.status_label.setStyleSheet("color: #666; font-size: 11px;") + status_layout.addWidget(self.status_label) + status_layout.addStretch() + layout.addLayout(status_layout) + + self.setLayout(layout) + + # 连接信号 + self.zoom_in_btn.clicked.connect(self.zoom_in) + self.zoom_out_btn.clicked.connect(self.zoom_out) + self.fit_btn.clicked.connect(self.fit_to_window) + self.original_btn.clicked.connect(self.original_size) + self.save_btn.clicked.connect(self.save_image) + + def load_image(self, image_path: str): + """加载并显示图像""" + if not image_path or not Path(image_path).exists(): + self.image_label.setText("图像不存在") + self.status_label.setText("图像加载失败") + return + + self.current_image_path = image_path + self.scale_factor = 1.0 + + # 加载图像 + pixmap = QPixmap(image_path) + if pixmap.isNull(): + self.image_label.setText("无法加载图像") + self.status_label.setText("图像格式不支持") + return + + self.original_pixmap = pixmap + + # 默认适应窗口显示 + self.fit_to_window() + + # 更新状态 + file_info = Path(image_path).stat() + size_mb = file_info.st_size / (1024 * 1024) + self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口") + + def update_image_display(self): + """更新图像显示 - 使用防抖避免频繁重绘卡顿""" + # 取消之前的待执行更新,重新计时 + self._update_timer.stop() + self._pending_scale = self.scale_factor + self._update_timer.start(50) # 50ms后执行实际更新 + + def _do_update_display(self): + """实际执行图像更新""" + if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): + return + + if self._pending_scale is None: + return + + # 根据缩放比例选择变换模式:大幅度缩放用Fast模式提升性能 + if self._pending_scale > 2.0 or self._pending_scale < 0.5: + transform = Qt.FastTransformation + else: + transform = Qt.SmoothTransformation + + scaled_pixmap = self.original_pixmap.scaled( + int(self.original_pixmap.width() * self._pending_scale), + int(self.original_pixmap.height() * self._pending_scale), + Qt.KeepAspectRatio, + transform + ) + self.image_label.setPixmap(scaled_pixmap) + self._pending_scale = None + + def wheelEvent(self, event): + """鼠标滚轮缩放 - 实时响应""" + delta = event.angleDelta().y() + + if delta > 0: + # 向上滚动 - 放大 + if self.scale_factor < 5.0: + self.scale_factor = min(self.scale_factor * 1.1, 5.0) + self.update_image_display() + else: + # 向下滚动 - 缩小 + if self.scale_factor > 0.1: + self.scale_factor = max(self.scale_factor / 1.1, 0.1) + self.update_image_display() + + event.accept() + + def zoom_in(self): + """放大""" + if self.scale_factor < 5.0: + self.scale_factor = min(self.scale_factor * 1.25, 5.0) + self.update_image_display() + + def zoom_out(self): + """缩小""" + if self.scale_factor > 0.1: + self.scale_factor = max(self.scale_factor / 1.25, 0.1) + self.update_image_display() + + def fit_to_window(self): + """适应窗口""" + if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): + return + + # 计算适应窗口的缩放比例 + view_size = self.scroll_area.viewport().size() + img_size = self.original_pixmap.size() + + scale_w = view_size.width() / img_size.width() + scale_h = view_size.height() / img_size.height() + + # 记录适应前的比例(用于后续恢复参考) + self._fit_scale = min(scale_w, scale_h) + self.scale_factor = self._fit_scale + + self.update_image_display() + self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}") + + def original_size(self): + """原始大小""" + self.scale_factor = 1.0 + self._fit_scale = None # 清除适应记录 + self.update_image_display() + self.status_label.setText("原始大小 | 缩放: 100%") + + def save_image(self): + """保存图像""" + if not self.current_image_path: + return + + file_path, _ = QFileDialog.getSaveFileName( + self, "保存图像", Path(self.current_image_path).name, + "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)" + ) + + if file_path: + try: + import shutil + shutil.copy(self.current_image_path, file_path) + except Exception as e: + QMessageBox.critical(self, "错误", f"保存失败: {e}") diff --git a/src/gui/core/viz_thread.py b/src/gui/core/viz_thread.py index 566955d..9dfe9b3 100644 --- a/src/gui/core/viz_thread.py +++ b/src/gui/core/viz_thread.py @@ -13,6 +13,11 @@ from PyQt5.QtCore import QThread, pyqtSignal import numpy as np +def _viz_training_spectra_csv_path(work_path: Path) -> Path: + """可视化光谱/统计及模型散点图使用的训练光谱表路径(与步骤5输出一致)。""" + return work_path / "5_training_spectra" / "training_spectra.csv" + + def _viz_infer_wavelength_start_column(df) -> Union[str, int]: """推断光谱起始列(training_spectra 通常以波长数值为列名,未必含 UTM_Y)。""" import pandas as pd diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index 955ac89..7fecefd 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -112,6 +112,8 @@ except ImportError: # 导入自定义组件 from src.gui.components.custom_widgets import FileSelectWidget +from src.gui.components.chart_dialogs import ChartViewerDialog, ChartBrowserDialog, InteractiveViewerDialog, PandasTableModel +from src.gui.components.image_viewer_components import ImageCategoryTree, ImageViewerWidget # 导入面板组件 from src.gui.panels.step1_panel import Step1Panel @@ -155,1173 +157,7 @@ from src.gui.core.worker_thread import ( # 预检交互对话框 from src.gui.core.preflight_dialog import PreflightDialog from src.gui.core.pipeline_mode_dialog import PipelineModeDialog - - -def _viz_training_spectra_csv_path(work_path: Path) -> Path: - """可视化光谱/统计及模型散点图使用的训练光谱表路径(与步骤5输出一致)。""" - return work_path / "5_training_spectra" / "training_spectra.csv" - - -class VisualizationWorkerThread(QThread): - """可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。""" - - finished_ok = pyqtSignal(object) - failed = pyqtSignal(str) - - def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None): - super().__init__() - self.task = task - self.work_dir = str(work_dir) - self.extra = extra or {} - - def run(self): - mpl_prev = None - try: - import matplotlib - mpl_prev = matplotlib.get_backend() - except Exception: - pass - try: - import matplotlib.pyplot as plt - plt.switch_backend("Agg") - except Exception: - mpl_prev = None - try: - wp = Path(self.work_dir) - if self.task == "mask_glint": - from src.postprocessing.visualization_reports import WaterQualityVisualization - viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization")) - preview_paths = viz.generate_glint_deglint_previews( - work_dir=str(wp), - output_subdir="glint_deglint_previews", - ) - cnt = len(preview_paths) if preview_paths else 0 - self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths}) - elif self.task == "sampling_map": - hyperspectral_files = [] - deglint_dir = wp / "3_deglint" - if deglint_dir.exists(): - for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): - hyperspectral_files.extend(list(deglint_dir.glob(ext))) - if not hyperspectral_files: - for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): - hyperspectral_files.extend(list(wp.glob(f"**/{ext}"))) - if not hyperspectral_files: - self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。") - return - hyperspectral_path = str(hyperspectral_files[0]) - csv_files = [] - processed_dir = wp / "4_processed_data" - if processed_dir.exists(): - csv_files = list(processed_dir.glob("*.csv")) - if not csv_files: - csv_files = ( - list(wp.glob("**/*sampling*.csv")) - + list(wp.glob("**/*point*.csv")) - + list(wp.glob("**/*.csv")) - ) - if not csv_files: - self.failed.emit("未找到采样点 CSV 文件。") - return - csv_path = str(csv_files[0]) - from src.postprocessing.point_map import SamplingPointMap - map_generator = SamplingPointMap( - output_dir=str(wp / "14_visualization" / "sampling_maps"), - fast_mode=True, - ) - map_path = map_generator.create_sampling_point_map( - hyperspectral_path=hyperspectral_path, - csv_path=csv_path, - point_color="red", - point_size=100, - point_alpha=0.9, - show_north_arrow=True, - show_scale_bar=True, - show_legend=True, - downsample=True, - dpi=180, - ) - self.finished_ok.emit( - { - "task": "sampling_map", - "map_path": map_path, - "hyperspectral_path": hyperspectral_path, - "csv_path": csv_path, - } - ) - elif self.task == "spectrum": - from src.postprocessing.visualization_reports import WaterQualityVisualization - viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization")) - csv_file = self.extra.get("csv_path") - wl = self.extra.get("wavelength_start_column", "UTM_Y") - n_groups = int(self.extra.get("n_groups", 5)) - param_cols = self.extra.get("param_cols") or [] - if param_cols: - output_paths: List[str] = [] - err_lines: List[str] = [] - for param_col in param_cols: - try: - out = viz.plot_spectrum_by_parameter( - csv_path=str(csv_file), - parameter_column=param_col, - wavelength_start_column=wl, - n_groups=n_groups, - ) - output_paths.append(out) - except Exception as _ex: - err_lines.append(f"{param_col}: {_ex}") - if not output_paths: - self.failed.emit( - "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20]) - ) - return - self.finished_ok.emit( - { - "task": "spectrum", - "output_paths": output_paths, - "errors": err_lines, - } - ) - else: - param_col = self.extra.get("param_col") - out = viz.plot_spectrum_by_parameter( - csv_path=str(csv_file), - parameter_column=param_col, - wavelength_start_column=wl, - n_groups=n_groups, - ) - self.finished_ok.emit( - {"task": "spectrum", "output_path": out, "param_col": param_col} - ) - elif self.task == "statistics": - from src.postprocessing.visualization_reports import WaterQualityVisualization - viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization")) - csv_file = self.extra.get("csv_path") - param_cols = self.extra.get("param_cols") or [] - output_paths = viz.plot_statistical_charts( - csv_path=str(csv_file), - parameter_columns=param_cols, - ) - self.finished_ok.emit( - {"task": "statistics", "output_paths": output_paths} - ) - elif self.task == "scatter": - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline - - training_csv_path = (self.extra.get("training_csv_path") or "").strip() - models_dir = (self.extra.get("models_dir") or "").strip() - if not training_csv_path or not Path(training_csv_path).is_file(): - self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。") - return - if not models_dir or not Path(models_dir).is_dir(): - self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。") - return - pipeline = WaterQualityInversionPipeline(work_dir=str(wp)) - scatter_paths = pipeline.generate_model_scatter_plots( - training_csv_path=training_csv_path, - models_dir=models_dir, - ) - self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}}) - elif self.task == "generate_all_selected": - from src.postprocessing.visualization_reports import WaterQualityVisualization - viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization")) - parts = [] - - # 获取训练数据CSV路径(多个图表类型共用) - training_csv = wp / "5_training_spectra" / "training_spectra.csv" - - # 生成散点图 - if self.extra.get("gen_scatter"): - if training_csv.is_file(): - models_dir = wp / "7_Supervised_Model_Training" - if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()): - from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline - pipeline = WaterQualityInversionPipeline(work_dir=str(wp)) - scatter_paths = pipeline.generate_model_scatter_plots( - training_csv_path=str(training_csv), - models_dir=str(models_dir), - ) - count = len(scatter_paths) if scatter_paths else 0 - parts.append(f"散点图: {count} 个") - else: - parts.append("散点图: 跳过(无模型目录)") - else: - parts.append("散点图: 跳过(无训练数据)") - - # 生成光谱图 - if self.extra.get("gen_spectrum"): - if training_csv.is_file(): - import pandas as pd - df = pd.read_csv(training_csv) - # 推断水质参数列(光谱波段列之前的数值型列) - wl_col = _viz_infer_wavelength_start_column(df) - if isinstance(wl_col, str): - idx = int(df.columns.get_loc(wl_col)) + 1 - else: - idx = int(wl_col) - param_cols = [] - if idx > 0 and idx < len(df.columns): - param_cols = [ - c for c in df.columns[:idx] - if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0 - ] - if param_cols: - # plot_spectrum_by_parameter 接受单个参数列,逐个调用 - spectrum_paths = [] - for param_col in param_cols: - try: - path = viz.plot_spectrum_by_parameter( - csv_path=str(training_csv), - parameter_column=param_col, - wavelength_start_column=wl_col, - n_groups=5, - ) - if path: - spectrum_paths.append(path) - except Exception as e: - print(f"生成光谱图失败 ({param_col}): {e}") - count = len(spectrum_paths) - parts.append(f"光谱图: {count} 个") - else: - parts.append("光谱图: 跳过(无可用参数列)") - else: - parts.append("光谱图: 跳过(无训练数据)") - - # 生成统计图 - if self.extra.get("gen_boxplots"): - if training_csv.is_file(): - import pandas as pd - df = pd.read_csv(training_csv) - # **只统计水质参数列(数值型),排除波长列** - # 获取水质参数列(数值型且不是波长、不是坐标列) - exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate'] - param_cols = [ - c for c in df.select_dtypes(include=[np.number]).columns - if not any(exc in c.lower() for exc in exclude_cols) - ] - # 排除光谱波长列:找到波长开始位置,只取之前的数值列 - wl = _viz_infer_wavelength_start_column(df) - if isinstance(wl, str): - idx = int(df.columns.get_loc(wl)) + 1 - else: - idx = int(wl) - if 0 < idx < len(df.columns): - meta_set = set(df.columns[:idx]) - param_cols = [c for c in param_cols if c in meta_set] - - if param_cols: - output_dict = viz.plot_statistical_charts( - csv_path=str(training_csv), - parameter_columns=param_cols, - ) - # plot_statistical_charts 返回字典,统计值非空 - count = len([v for v in output_dict.values() if v]) if output_dict else 0 - parts.append(f"统计图: {count} 个") - else: - parts.append("统计图: 跳过(无可用水质参数列)") - else: - parts.append("统计图: 跳过(无训练数据)") - - # 生成掩膜/耀斑预览图 - if self.extra.get("gen_mask_glint"): - preview_paths = viz.generate_glint_deglint_previews( - work_dir=str(wp), - output_subdir="glint_deglint_previews", - ) - parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个") - - # 生成采样点地图 - if self.extra.get("gen_sampling_map"): - hyperspectral_files = [] - deglint_dir = wp / "3_deglint" - if deglint_dir.exists(): - for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): - hyperspectral_files.extend(list(deglint_dir.glob(ext))) - if not hyperspectral_files: - for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): - hyperspectral_files.extend(list(wp.glob(f"**/{ext}"))) - if hyperspectral_files: - hyperspectral_path = str(hyperspectral_files[0]) - csv_files = [] - processed_dir = wp / "4_processed_data" - if processed_dir.exists(): - csv_files = list(processed_dir.glob("*.csv")) - if not csv_files: - csv_files = ( - list(wp.glob("**/*sampling*.csv")) - + list(wp.glob("**/*point*.csv")) - + list(wp.glob("**/*.csv")) - ) - if csv_files: - csv_path = str(csv_files[0]) - from src.postprocessing.point_map import SamplingPointMap - map_generator = SamplingPointMap( - output_dir=str(wp / "14_visualization" / "sampling_maps"), - fast_mode=True, - ) - map_path = map_generator.create_sampling_point_map( - hyperspectral_path=hyperspectral_path, - csv_path=csv_path, - point_color="red", - point_size=100, - point_alpha=0.9, - show_north_arrow=True, - show_scale_bar=True, - show_legend=True, - downsample=True, - dpi=180, - ) - parts.append(f"采样点图: {Path(map_path).name}") - else: - parts.append("采样点图: 跳过(无CSV)") - else: - parts.append("采样点图: 跳过(无影像)") - self.finished_ok.emit({"task": "generate_all_selected", "parts": parts}) - else: - self.failed.emit(f"未知可视化任务: {self.task}") - except Exception as e: - self.failed.emit(f"{e}\n{traceback.format_exc()}") - finally: - if mpl_prev: - try: - import matplotlib.pyplot as plt - plt.switch_backend(mpl_prev) - except Exception: - pass - - -class PandasTableModel(QAbstractTableModel): - """支持DataFrame的表格模型""" - def __init__(self, data_frame: pd.DataFrame): - super().__init__() - self._data = data_frame.copy() - if self._data.empty: - self._data = pd.DataFrame() - self._data.fillna("", inplace=True) - self._columns = [str(col) for col in self._data.columns] - - def rowCount(self, parent=None): - return len(self._data) - - def columnCount(self, parent=None): - return len(self._columns) - - def data(self, index, role=Qt.DisplayRole): - if not index.isValid() or role != Qt.DisplayRole: - return None - - value = self._data.iat[index.row(), index.column()] - if pd.isna(value): - return "" - return str(value) - - def headerData(self, section, orientation, role=Qt.DisplayRole): - if role != Qt.DisplayRole: - return None - if orientation == Qt.Horizontal: - if section < len(self._columns): - return self._columns[section] - return str(section) - return str(section + 1) - - def flags(self, index): - if not index.isValid(): - return Qt.NoItemFlags - return Qt.ItemIsEnabled | Qt.ItemIsSelectable - - -class ChartViewerDialog(QDialog): - """图表查看器对话框""" - def __init__(self, title="图表查看器", parent=None): - super().__init__(parent) - self.setWindowTitle(title) - self.resize(1000, 700) - self.init_ui() - - def init_ui(self): - layout = QVBoxLayout() - - # 创建matplotlib图形 - self.figure = Figure(figsize=(10, 7)) - self.canvas = FigureCanvas(self.figure) - self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - - # 添加工具栏 - self.toolbar = NavigationToolbar(self.canvas, self) - - layout.addWidget(self.toolbar) - layout.addWidget(self.canvas) - - # 按钮组 - btn_layout = QHBoxLayout() - - self.save_btn = QPushButton("保存图表") - self.save_btn.clicked.connect(self.save_chart) - btn_layout.addWidget(self.save_btn) - - btn_layout.addStretch() - - self.close_btn = QPushButton("关闭") - self.close_btn.clicked.connect(self.close) - btn_layout.addWidget(self.close_btn) - - layout.addLayout(btn_layout) - self.setLayout(layout) - # 信号连接:影像文件路径变化时动态更新波段范围 - def display_image(self, image_path): - """显示图片""" - self.figure.clear() - ax = self.figure.add_subplot(111) - - try: - import matplotlib.image as mpimg - img = mpimg.imread(image_path) - ax.imshow(img) - ax.axis('off') - self.figure.tight_layout() - self.canvas.draw() - self.current_image_path = image_path - except Exception as e: - ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}', - ha='center', va='center', transform=ax.transAxes) - self.canvas.draw() - - def display_custom_plot(self, plot_func): - """显示自定义绘图函数""" - self.figure.clear() - try: - plot_func(self.figure) - self.canvas.draw() - except Exception as e: - ax = self.figure.add_subplot(111) - ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}', - ha='center', va='center', transform=ax.transAxes) - self.canvas.draw() - - def save_chart(self): - """保存图表""" - file_path, _ = QFileDialog.getSaveFileName( - self, "保存图表", "", - "PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)" - ) - if file_path: - try: - self.figure.savefig(file_path, dpi=300, bbox_inches='tight') - QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}") - except Exception as e: - QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}") - - -class ImageCategoryTree(QTreeWidget): - """图像分类目录树 - 按类别组织图像文件""" - - # 图像类别定义:(类别名称, 关键词列表, 图标) - CATEGORIES = [ - ("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"), - ("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"), - ("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"), - ("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"), - ("含量分布图", [], "📁"), - ] - - def __init__(self, parent=None): - super().__init__(parent) - self.setHeaderLabel("图像目录") - self.setMaximumWidth(300) - self.setMinimumWidth(250) - self.setup_categories() - self.setStyleSheet(""" - QTreeWidget { - border: 1px solid #ddd; - border-radius: 5px; - background-color: #f8f9fa; - } - QTreeWidget::item { - padding: 5px; - border-radius: 3px; - } - QTreeWidget::item:selected { - background-color: #0078D4; - color: white; - } - QTreeWidget::item:hover { - background-color: #e3f2fd; - } - """) - - def setup_categories(self): - """初始化类别节点""" - self.category_items = {} - for category_name, keywords, icon in self.CATEGORIES: - item = QTreeWidgetItem(self) - item.setText(0, f"{icon} {category_name}") - item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name}) - item.setExpanded(True) - self.category_items[category_name] = item - - def clear_all_images(self): - """清除所有图像项""" - for category_item in self.category_items.values(): - # 删除所有子项 - while category_item.childCount() > 0: - category_item.removeChild(category_item.child(0)) - - def add_image(self, file_path: Path, display_name: str = None): - """添加图像到对应的类别""" - if display_name is None: - display_name = file_path.stem - - # 根据文件名关键词确定类别 - category = self._determine_category(file_path.name) - category_item = self.category_items.get(category, self.category_items["含量分布图"]) - - # 创建图像项 - image_item = QTreeWidgetItem(category_item) - image_item.setText(0, f" └─ {display_name}") - image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)}) - image_item.setToolTip(0, str(file_path)) - - return image_item - - def _determine_category(self, filename: str) -> str: - """根据文件名确定类别""" - filename_lower = filename.lower() - - for category_name, keywords, _ in self.CATEGORIES: - if any(keyword in filename_lower for keyword in keywords): - return category_name - - return "含量分布图" - - def scan_directory(self, work_dir: str): - """扫描目录中的所有图像文件""" - self.clear_all_images() - - work_path = Path(work_dir) - if not work_path.exists(): - return - - # 查找所有图像文件:14_visualization 为主,同时扫描步骤产出目录(如 1_water_mask 下的预览/叠置图) - image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp'] - scan_roots: List[Path] = [] - _viz = work_path / "14_visualization" - if _viz.is_dir(): - scan_roots.append(_viz) - _wm = work_path / "1_water_mask" - if _wm.is_dir(): - scan_roots.append(_wm) - if not scan_roots: - scan_roots.append(work_path) - - seen_norm: set = set() - image_files: List[Path] = [] - for root in scan_roots: - for ext in image_extensions: - for p in root.glob(f"**/{ext}"): - key = os.path.normcase(os.path.normpath(str(p.resolve()))) - if key in seen_norm: - continue - seen_norm.add(key) - image_files.append(p) - - # 添加图像到树 - for img_file in sorted(image_files): - # 跳过缩略图和临时文件 - if img_file.name.startswith('.') or 'thumb' in img_file.name.lower(): - continue - self.add_image(img_file) - - # 更新类别项文本显示数量 - for category_name, item in self.category_items.items(): - count = item.childCount() - if count > 0: - for cat_name, _, icon in self.CATEGORIES: - if cat_name == category_name: - item.setText(0, f"{icon} {category_name} ({count})") - break - - def get_selected_image_path(self) -> Optional[str]: - """获取当前选中的图像路径""" - selected_item = self.currentItem() - if not selected_item: - return None - - data = selected_item.data(0, Qt.UserRole) - if data and data.get("type") == "image": - return data.get("path") - return None - - -class ImageViewerWidget(QWidget): - """图像查看器组件 - 支持缩放、平移""" - - def __init__(self, parent=None): - super().__init__(parent) - self.current_image_path = None - self.scale_factor = 1.0 - self._update_timer = QTimer() # 防抖定时器 - self._update_timer.setSingleShot(True) - self._update_timer.timeout.connect(self._do_update_display) - self._pending_scale = None # 待更新的缩放比例 - self.setup_ui() - - def setup_ui(self): - layout = QVBoxLayout() - layout.setContentsMargins(0, 0, 0, 0) - - # 工具栏 - toolbar = QHBoxLayout() - - self.refresh_btn = QPushButton("🔄 刷新目录") - self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件") - toolbar.addWidget(self.refresh_btn) - - # 添加分隔线 - separator = QFrame() - separator.setFrameShape(QFrame.VLine) - separator.setFrameShadow(QFrame.Sunken) - toolbar.addWidget(separator) - - self.zoom_in_btn = QPushButton("🔍+") - self.zoom_in_btn.setToolTip("放大") - self.zoom_in_btn.setMaximumWidth(50) - toolbar.addWidget(self.zoom_in_btn) - - self.zoom_out_btn = QPushButton("🔍-") - self.zoom_out_btn.setToolTip("缩小") - self.zoom_out_btn.setMaximumWidth(50) - toolbar.addWidget(self.zoom_out_btn) - - self.fit_btn = QPushButton("⬜ 适应窗口") - self.fit_btn.setToolTip("适应窗口大小") - toolbar.addWidget(self.fit_btn) - - self.original_btn = QPushButton("1:1 原始大小") - self.original_btn.setToolTip("原始大小") - toolbar.addWidget(self.original_btn) - - toolbar.addStretch() - - self.save_btn = QPushButton("💾 保存") - self.save_btn.setToolTip("保存当前图像") - toolbar.addWidget(self.save_btn) - - layout.addLayout(toolbar) - - # 图像显示区域 - 使用 QLabel + QScrollArea - self.scroll_area = QScrollArea() - self.scroll_area.setWidgetResizable(True) - self.scroll_area.setStyleSheet("background-color: white;") - - self.image_label = QLabel() - self.image_label.setAlignment(Qt.AlignCenter) - self.image_label.setStyleSheet("background-color: white;") - - self.scroll_area.setWidget(self.image_label) - layout.addWidget(self.scroll_area, 1) - - # 状态栏 - status_layout = QHBoxLayout() - self.status_label = QLabel("就绪") - self.status_label.setStyleSheet("color: #666; font-size: 11px;") - status_layout.addWidget(self.status_label) - status_layout.addStretch() - layout.addLayout(status_layout) - - self.setLayout(layout) - - # 连接信号 - self.zoom_in_btn.clicked.connect(self.zoom_in) - self.zoom_out_btn.clicked.connect(self.zoom_out) - self.fit_btn.clicked.connect(self.fit_to_window) - self.original_btn.clicked.connect(self.original_size) - self.save_btn.clicked.connect(self.save_image) - - def load_image(self, image_path: str): - """加载并显示图像""" - if not image_path or not Path(image_path).exists(): - self.image_label.setText("图像不存在") - self.status_label.setText("图像加载失败") - return - - self.current_image_path = image_path - self.scale_factor = 1.0 - - # 加载图像 - pixmap = QPixmap(image_path) - if pixmap.isNull(): - self.image_label.setText("无法加载图像") - self.status_label.setText("图像格式不支持") - return - - self.original_pixmap = pixmap - - # 默认适应窗口显示 - self.fit_to_window() - - # 更新状态 - file_info = Path(image_path).stat() - size_mb = file_info.st_size / (1024 * 1024) - self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口") - - def update_image_display(self): - """更新图像显示 - 使用防抖避免频繁重绘卡顿""" - # 取消之前的待执行更新,重新计时 - self._update_timer.stop() - self._pending_scale = self.scale_factor - self._update_timer.start(50) # 50ms后执行实际更新 - - def _do_update_display(self): - """实际执行图像更新""" - if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): - return - - if self._pending_scale is None: - return - - # 根据缩放比例选择变换模式:大幅度缩放用Fast模式提升性能 - if self._pending_scale > 2.0 or self._pending_scale < 0.5: - transform = Qt.FastTransformation - else: - transform = Qt.SmoothTransformation - - scaled_pixmap = self.original_pixmap.scaled( - int(self.original_pixmap.width() * self._pending_scale), - int(self.original_pixmap.height() * self._pending_scale), - Qt.KeepAspectRatio, - transform - ) - self.image_label.setPixmap(scaled_pixmap) - self._pending_scale = None - - def wheelEvent(self, event): - """鼠标滚轮缩放 - 实时响应""" - delta = event.angleDelta().y() - - if delta > 0: - # 向上滚动 - 放大 - if self.scale_factor < 5.0: - self.scale_factor = min(self.scale_factor * 1.1, 5.0) - self.update_image_display() - else: - # 向下滚动 - 缩小 - if self.scale_factor > 0.1: - self.scale_factor = max(self.scale_factor / 1.1, 0.1) - self.update_image_display() - - event.accept() - - def zoom_in(self): - """放大""" - if self.scale_factor < 5.0: - self.scale_factor = min(self.scale_factor * 1.25, 5.0) - self.update_image_display() - - def zoom_out(self): - """缩小""" - if self.scale_factor > 0.1: - self.scale_factor = max(self.scale_factor / 1.25, 0.1) - self.update_image_display() - - def fit_to_window(self): - """适应窗口""" - if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): - return - - # 计算适应窗口的缩放比例 - view_size = self.scroll_area.viewport().size() - img_size = self.original_pixmap.size() - - scale_w = view_size.width() / img_size.width() - scale_h = view_size.height() / img_size.height() - - # 记录适应前的比例(用于后续恢复参考) - self._fit_scale = min(scale_w, scale_h) - self.scale_factor = self._fit_scale - - self.update_image_display() - self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}") - - def original_size(self): - """原始大小""" - self.scale_factor = 1.0 - self._fit_scale = None # 清除适应记录 - self.update_image_display() - self.status_label.setText("原始大小 | 缩放: 100%") - - def save_image(self): - """保存图像""" - if not self.current_image_path: - return - - file_path, _ = QFileDialog.getSaveFileName( - self, "保存图像", Path(self.current_image_path).name, - "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)" - ) - - if file_path: - try: - import shutil - shutil.copy(self.current_image_path, file_path) - except Exception as e: - QMessageBox.critical(self, "错误", f"保存失败: {e}") - - -class ChartBrowserDialog(QDialog): - """图表浏览器对话框""" - def __init__(self, chart_files, parent=None): - super().__init__(parent) - self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True) - self.current_index = 0 - self.setWindowTitle("图表浏览器") - self.resize(1200, 800) - self.init_ui() - self.show_chart(0) - - def init_ui(self): - layout = QVBoxLayout() - - # 顶部:图表列表 - list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)") - list_layout = QHBoxLayout() - - self.chart_list = QListWidget() - self.chart_list.setMaximumHeight(150) - for chart_file in self.chart_files: - self.chart_list.addItem(chart_file.name) - self.chart_list.currentRowChanged.connect(self.show_chart) - - list_layout.addWidget(self.chart_list) - list_group.setLayout(list_layout) - layout.addWidget(list_group) - - # 中间:图表显示 - self.figure = Figure(figsize=(12, 8)) - self.canvas = FigureCanvas(self.figure) - self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - - self.toolbar = NavigationToolbar(self.canvas, self) - layout.addWidget(self.toolbar) - layout.addWidget(self.canvas, 1) - - # 底部:控制按钮 - btn_layout = QHBoxLayout() - - self.prev_btn = QPushButton("◀ 上一个") - self.prev_btn.clicked.connect(self.prev_chart) - btn_layout.addWidget(self.prev_btn) - - self.next_btn = QPushButton("下一个 >") - self.next_btn.clicked.connect(self.next_chart) - btn_layout.addWidget(self.next_btn) - - btn_layout.addStretch() - - self.save_btn = QPushButton("💾 保存当前图表") - self.save_btn.clicked.connect(self.save_current_chart) - btn_layout.addWidget(self.save_btn) - - self.close_btn = QPushButton("关闭") - self.close_btn.clicked.connect(self.close) - btn_layout.addWidget(self.close_btn) - - layout.addLayout(btn_layout) - self.setLayout(layout) - # 信号连接:影像文件路径变化时动态更新波段范围 - def show_chart(self, index): - """显示指定索引的图表""" - if 0 <= index < len(self.chart_files): - self.current_index = index - self.chart_list.setCurrentRow(index) - - chart_file = self.chart_files[index] - self.figure.clear() - ax = self.figure.add_subplot(111) - - try: - import matplotlib.image as mpimg - img = mpimg.imread(str(chart_file)) - ax.imshow(img) - ax.axis('off') - ax.set_title(chart_file.name, fontsize=12, pad=10) - self.figure.tight_layout() - self.canvas.draw() - except Exception as e: - ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}', - ha='center', va='center', transform=ax.transAxes) - self.canvas.draw() - - # 更新按钮状态 - self.prev_btn.setEnabled(index > 0) - self.next_btn.setEnabled(index < len(self.chart_files) - 1) - - def prev_chart(self): - """上一个图表""" - if self.current_index > 0: - self.show_chart(self.current_index - 1) - - def next_chart(self): - """下一个图表""" - if self.current_index < len(self.chart_files) - 1: - self.show_chart(self.current_index + 1) - - def save_current_chart(self): - """保存当前图表""" - if 0 <= self.current_index < len(self.chart_files): - current_file = self.chart_files[self.current_index] - file_path, _ = QFileDialog.getSaveFileName( - self, "保存图表", current_file.name, - "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)" - ) - if file_path: - try: - import shutil - shutil.copy(str(current_file), file_path) - QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}") - except Exception as e: - QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}") - - -class InteractiveViewerDialog(QDialog): - """交互式影像预览对话框:显示影像、参考点散点图、点击查询坐标/值""" - - def __init__(self, parent, img_path, ref_csv=None): - super().__init__(parent) - self.img_path = img_path - self.ref_csv = ref_csv - self.geotransform = None - self.fig = None - self.canvas = None - self.ax = None - self.status_label = None - self.init_ui() - - def init_ui(self): - self.setWindowTitle("👁️ 交互式影像预览") - self.setMinimumSize(900, 700) - - layout = QVBoxLayout() - - # 工具栏 - toolbar = QHBoxLayout() - self.band_combo = QComboBox() - self.band_combo.currentIndexChanged.connect(self.on_band_changed) - toolbar.addWidget(QLabel("显示波段:")) - toolbar.addWidget(self.band_combo) - - self.gray_check = QCheckBox("灰度显示") - self.gray_check.stateChanged.connect(self.on_band_changed) - toolbar.addWidget(self.gray_check) - toolbar.addStretch() - layout.addLayout(toolbar) - - # Matplotlib 画布 - try: - from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas - from matplotlib.figure import Figure - import matplotlib - matplotlib.use('Qt5Agg') - - self.fig = Figure(figsize=(10, 8)) - self.canvas = FigureCanvas(self.fig) - self.ax = self.fig.add_subplot(111) - self.fig.tight_layout() - layout.addWidget(self.canvas) - - # 读取影像并初始化显示 - self.load_and_display() - - except ImportError as e: - layout.addWidget(QLabel(f"Matplotlib 未安装: {e}")) - - # 状态栏 - self.status_label = QLabel("点击影像查看像素坐标和经纬度") - self.status_label.setStyleSheet("background:#f0f0f0;padding:4px;font-size:12px;") - self.status_label.setWordWrap(True) - layout.addWidget(self.status_label) - - # 关闭按钮 - close_btn = QPushButton("关闭") - close_btn.clicked.connect(self.close) - layout.addWidget(close_btn) - - self.setLayout(layout) - - def load_and_display(self): - """加载影像并显示""" - from osgeo import gdal - import numpy as np - - dataset = gdal.Open(self.img_path) - if dataset is None: - self.status_label.setText(f"无法打开影像: {self.img_path}") - return - - self.geotransform = dataset.GetGeoTransform() - self.projection = dataset.GetProjection() - n_bands = dataset.RasterCount - self.height = dataset.RasterYSize - self.width = dataset.RasterXSize - - # 填充波段选择下拉框 - self.band_combo.clear() - if n_bands >= 3: - for i in range(1, n_bands + 1): - self.band_combo.addItem(f"RGB (B{i-0}, G{i-1}, R{i-2})" if i >= 3 else f"波段 {i}", i) - self.band_combo.addItem(f"单波段 (B1)", 0) - else: - for i in range(1, n_bands + 1): - self.band_combo.addItem(f"波段 {i}", i - 1) - self.band_combo.setCurrentIndex(0) - - self.dataset = dataset - self.display_band(0, is_gray=False) - self.load_ref_points() - - def display_band(self, band_idx, is_gray=False): - """显示指定波段组合""" - from osgeo import gdal - import numpy as np - from matplotlib.pyplot import Normalize - from matplotlib.cm import ScalarMappable - - dataset = self.dataset - self.ax.clear() - - if is_gray or (self.band_combo.currentData() == 0 and dataset.RasterCount == 1): - # 灰度显示 - band = dataset.GetRasterBand(1 if band_idx == 0 else band_idx + 1) - data = band.ReadAsArray() - data = np.nan_to_num(data, nan=0.0) - self.ax.imshow(data, cmap='gray') - self.ax.set_title(f"波段 {band_idx + 1} (灰度)") - else: - # 彩色显示(取前3个波段) - n = min(3, dataset.RasterCount) - bands_data = [] - for i in range(n): - b = dataset.GetRasterBand(i + 1) - bd = b.ReadAsArray() - bd = np.nan_to_num(bd, nan=0.0) - bands_data.append(bd) - rgb = np.dstack(bands_data) - - # 归一化到 [0, 1] - for i in range(rgb.shape[2]): - p2, p98 = np.percentile(rgb[:, :, i], [2, 98]) - if p98 > p2: - rgb[:, :, i] = np.clip((rgb[:, :, i] - p2) / (p98 - p2), 0, 1) - else: - rgb[:, :, i] = np.clip(rgb[:, :, i] / (p98 + 1e-6), 0, 1) - - self.ax.imshow(rgb) - self.ax.set_title(f"RGB 显示") - - self.ax.set_xlabel("列 (Column)") - self.ax.set_ylabel("行 (Row)") - self.fig.tight_layout() - self.canvas.draw() - - # 绑定点击事件 - self.cid = self.canvas.mpl_connect('button_press_event', self.on_click) - - def on_band_changed(self): - """波段选择变化时更新显示""" - if not hasattr(self, 'dataset'): - return - is_gray = self.gray_check.isChecked() - band_data = self.band_combo.currentData() - self.display_band(band_data if band_data != 0 else 0, is_gray=is_gray) - - def load_ref_points(self): - """加载并显示参考点""" - import os - if not self.ref_csv or not os.path.isfile(self.ref_csv): - return - - try: - import csv - lon_list, lat_list = [], [] - with open(self.ref_csv, 'r', encoding='utf-8-sig') as f: - reader = csv.DictReader(f) - for row in reader: - try: - lon = float(row.get('Lon', row.get('lon', row.get('LON', 0)))) - lat = float(row.get('Lat', row.get('lat', row.get('LAT', 0)))) - if lon and lat: - lon_list.append(lon) - lat_list.append(lat) - except (ValueError, TypeError): - continue - - if not lon_list: - return - - # 逆变换:经纬度 -> 像素坐标 - px_list, py_list = [], [] - gt = self.geotransform - if gt and (gt[1] != 0 or gt[5] != 0): - # GeoTransform: (originX, pixSizeX, rotX, originY, rotY, pixSizeY) - # pixel_x = (lon - gt[0]) / gt[1] - # line_y = (lat - gt[3]) / gt[5] - for lon, lat in zip(lon_list, lat_list): - px = (lon - gt[0]) / gt[1] - py = (lat - gt[3]) / gt[5] - if 0 <= px < self.width and 0 <= py < self.height: - px_list.append(px) - py_list.append(py) - - if px_list: - self.ax.scatter(px_list, py_list, c='red', s=40, marker='o', - edgecolors='white', linewidths=0.8, zorder=5, alpha=0.9, - label=f'参考点 ({len(px_list)}个)') - self.ax.legend(loc='upper right', fontsize=9) - self.fig.tight_layout() - self.canvas.draw() - self.status_label.setText( - f"已加载 {len(px_list)} 个参考点(仅显示在影像范围内的点)" - ) - except Exception as e: - self.status_label.setText(f"加载参考点失败: {e}") - - def pixel_to_geo(self, px, py): - """像素坐标转经纬度""" - gt = self.geotransform - if gt is None: - return None, None - lon = gt[0] + px * gt[1] + py * gt[2] - lat = gt[3] + px * gt[4] + py * gt[5] - return lon, lat - - def on_click(self, event): - """鼠标点击事件""" - if event.inaxes != self.ax or event.xdata is None or event.ydata is None: - return - - px, py = int(round(event.xdata)), int(round(event.ydata)) - if not (0 <= px < self.width and 0 <= py < self.height): - return - - # 获取该像素在各波段的值 - from osgeo import gdal - import numpy as np - dataset = self.dataset - n_bands = dataset.RasterCount - vals = [] - for b in range(1, n_bands + 1): - val = dataset.GetRasterBand(b).ReadAsArray()[py, px] - vals.append(f"{val:.4f}" if isinstance(val, float) else str(val)) - - # 经纬度转换 - lon, lat = self.pixel_to_geo(px, py) - geo_str = f"Lon={lon:.6f}, Lat={lat:.6f}" if lon is not None else "无地理参考" - - self.status_label.setText( - f"像素: (行={py}, 列={px}) | {geo_str} | " - f"波段值: {' | '.join(vals[:5])}" + - (f" ... ({n_bands}波段的更多信息)" if n_bands > 5 else "") - ) - +from src.gui.core.viz_thread import VisualizationWorkerThread, _viz_training_spectra_csv_path class WaterQualityGUI(QMainWindow):