Files
WQ_GUI/src/gui/panels/visualization_panel.py
2026-05-07 16:49:24 +08:00

1578 lines
65 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
VisualizationPanel - 可视化分析面板
左侧目录树 + 右侧图像查看器,支持多种图表生成。
"""
import os
import traceback
from pathlib import Path
from typing import Optional, List, Union
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal, QAbstractTableModel
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QMessageBox,
QFileDialog, QFrame, QSizePolicy,
QDialog, QTreeWidget, QListWidget, QAbstractItemView, QHeaderView,QTreeWidgetItem,QScrollArea
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
try:
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
PIPELINE_AVAILABLE = True
except ImportError:
PIPELINE_AVAILABLE = False
def _viz_training_spectra_csv_path(work_path: Path) -> Path:
"""可视化光谱/统计及模型散点图使用的训练光谱表路径与步骤5输出一致
注意步骤5.5水质指数计算执行后会覆盖此文件为94维增强版本
因此下游步骤无需任何修改,直接读取此路径即可。
"""
return work_path / "5_training_spectra" / "training_spectra.csv"
def _viz_infer_wavelength_start_column(df: pd.DataFrame) -> Union[str, int]:
"""推断光谱起始列training_spectra 通常以波长数值为列名,未必含 UTM_Y"""
for i, col in enumerate(df.columns):
name = str(col).strip().lstrip("\ufeff")
try:
v = float(name)
except ValueError:
continue
if 200.0 <= v <= 3000.0:
return i
if "UTM_Y" in df.columns:
return "UTM_Y"
return 0
class VisualizationWorkerThread(QThread):
    """Run slow visualization jobs off the GUI thread.

    The job is chosen by ``task``:

    * ``"mask_glint"``            - glint / de-glint preview thumbnails
    * ``"sampling_map"``          - sampling-point map over the hyperspectral image
    * ``"spectrum"``              - spectra grouped by one or more parameter columns
    * ``"statistics"``            - statistical charts for parameter columns
    * ``"scatter"``               - model-evaluation scatter plots
    * ``"generate_all_selected"`` - every job whose ``gen_*`` flag is set in ``extra``

    While the job runs, pyplot is switched to the non-interactive ``Agg``
    backend so the main window stays responsive; the previous backend is
    restored afterwards.

    Signals:
        finished_ok(object): result ``dict``; its ``"task"`` key names the job.
        failed(str): human-readable error message (may include a traceback).
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    # Raster patterns searched for the hyperspectral image, in priority order.
    _IMAGE_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")
    # Recursive fallback patterns for the sampling-point CSV, in priority order.
    # NOTE(review): the last pattern accepts any CSV in the work directory.
    _CSV_FALLBACK_PATTERNS = ("**/*sampling*.csv", "**/*point*.csv", "**/*.csv")
    # Column-name fragments that mark a coordinate column (substring match).
    _COORD_SUBSTRINGS = ("utm", "lon", "lat", "coord")
    # Axis names that mark a coordinate column only as a whole token
    # (e.g. "x", "UTM_X", "POINT_Y"). A substring match here would wrongly drop
    # parameters such as "Chlorophyll", "Turbidity" or "Oxygen".
    _COORD_TOKENS = frozenset({"x", "y"})

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    # ------------------------------------------------------------------ entry
    def run(self):
        """Dispatch ``self.task``; every outcome is reported through a signal."""
        handlers = {
            "mask_glint": self._run_mask_glint,
            "sampling_map": self._run_sampling_map,
            "spectrum": self._run_spectrum,
            "statistics": self._run_statistics,
            "scatter": self._run_scatter,
            "generate_all_selected": self._run_generate_all,
        }
        previous_backend = self._switch_to_agg()
        try:
            handler = handlers.get(self.task)
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(Path(self.work_dir))
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            self._restore_backend(previous_backend)

    # -------------------------------------------------------- backend helpers
    @staticmethod
    def _switch_to_agg() -> Optional[str]:
        """Switch pyplot to Agg and return the backend to restore (None if nothing was switched)."""
        previous = None
        try:
            import matplotlib
            previous = matplotlib.get_backend()
        except Exception:
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            # The switch failed, so there is nothing to restore later.
            return None
        return previous

    @staticmethod
    def _restore_backend(previous: Optional[str]) -> None:
        """Best-effort switch back to *previous*; failures are ignored."""
        if not previous:
            return
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend(previous)
        except Exception:
            pass

    # ----------------------------------------------------------- file lookup
    @staticmethod
    def _first_glob_match(root: Path, patterns) -> Optional[Path]:
        """Return the first hit of the first pattern (in order) that matches anything under *root*."""
        for pattern in patterns:
            # glob() is lazy, so this stops at the first hit instead of walking the whole tree.
            match = next(root.glob(pattern), None)
            if match is not None:
                return match
        return None

    @classmethod
    def _find_hyperspectral_image(cls, wp: Path) -> Optional[Path]:
        """Locate the hyperspectral raster, preferring ``3_deglint`` over a recursive search of *wp*."""
        found = None
        deglint_dir = wp / "3_deglint"
        if deglint_dir.exists():
            found = cls._first_glob_match(deglint_dir, cls._IMAGE_PATTERNS)
        if found is None:
            found = cls._first_glob_match(wp, [f"**/{p}" for p in cls._IMAGE_PATTERNS])
        return found

    @classmethod
    def _find_sampling_csv(cls, wp: Path) -> Optional[Path]:
        """Locate the sampling-point CSV, preferring ``4_processed_data`` over a recursive search."""
        found = None
        processed_dir = wp / "4_processed_data"
        if processed_dir.exists():
            found = next(processed_dir.glob("*.csv"), None)
        if found is None:
            found = cls._first_glob_match(wp, cls._CSV_FALLBACK_PATTERNS)
        return found

    def _extra_csv_path(self) -> Optional[str]:
        """Return ``extra["csv_path"]`` as a string if it names an existing file, else None."""
        csv_file = self.extra.get("csv_path")
        if not csv_file or not Path(str(csv_file)).is_file():
            return None
        return str(csv_file)

    # -------------------------------------------------------- column helpers
    @staticmethod
    def _spectra_start_index(df: pd.DataFrame, wl) -> int:
        """Translate the result of ``_viz_infer_wavelength_start_column`` into a column index.

        A label (``"UTM_Y"``) means the spectra start right after that column.
        """
        if isinstance(wl, str):
            return int(df.columns.get_loc(wl)) + 1
        return int(wl)

    @classmethod
    def _is_coordinate_column(cls, col) -> bool:
        """True if *col* looks like a coordinate column rather than a water-quality parameter."""
        import re

        name = str(col).lower()
        if any(fragment in name for fragment in cls._COORD_SUBSTRINGS):
            return True
        return any(token in cls._COORD_TOKENS for token in re.split(r"[^0-9a-z]+", name))

    # ------------------------------------------------------ generator helpers
    @staticmethod
    def _make_visualizer(wp: Path):
        """Create the report visualizer writing into ``<wp>/14_visualization``."""
        from src.postprocessing.visualization_reports import WaterQualityVisualization
        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @staticmethod
    def _create_sampling_map(wp: Path, hyperspectral_path, csv_path):
        """Render the sampling-point map and return whatever the generator returns as its path."""
        from src.postprocessing.point_map import SamplingPointMap
        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=str(hyperspectral_path),
            csv_path=str(csv_path),
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    @staticmethod
    def _generate_scatter_plots(wp: Path, training_csv_path, models_dir):
        """Run the pipeline's model scatter-plot generation."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        return pipeline.generate_model_scatter_plots(
            training_csv_path=str(training_csv_path),
            models_dir=str(models_dir),
        )

    # ------------------------------------------------------------ single jobs
    def _run_mask_glint(self, wp: Path) -> None:
        viz = self._make_visualizer(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})

    def _run_sampling_map(self, wp: Path) -> None:
        hyperspectral_path = self._find_hyperspectral_image(wp)
        if hyperspectral_path is None:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)")
            return
        csv_path = self._find_sampling_csv(wp)
        if csv_path is None:
            self.failed.emit("未找到采样点 CSV 文件。")
            return
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": str(hyperspectral_path),
                "csv_path": str(csv_path),
            }
        )

    def _run_spectrum(self, wp: Path) -> None:
        """Plot spectra for ``extra["param_cols"]`` (or the single ``extra["param_col"]``)."""
        csv_file = self._extra_csv_path()
        if csv_file is None:
            self.failed.emit(f"光谱 CSV 无效或不存在: {self.extra.get('csv_path')}")
            return
        viz = self._make_visualizer(wp)
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []

        def plot(param_col):
            return viz.plot_spectrum_by_parameter(
                csv_path=csv_file,
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )

        if not param_cols:
            param_col = self.extra.get("param_col")
            out = plot(param_col)
            self.finished_ok.emit({"task": "spectrum", "output_path": out, "param_col": param_col})
            return

        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            try:
                output_paths.append(plot(param_col))
            except Exception as exc:
                # One bad column must not abort the remaining plots.
                err_lines.append(f"{param_col}: {exc}")
        if not output_paths:
            self.failed.emit("所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20]))
            return
        self.finished_ok.emit({"task": "spectrum", "output_paths": output_paths, "errors": err_lines})

    def _run_statistics(self, wp: Path) -> None:
        csv_file = self._extra_csv_path()
        if csv_file is None:
            self.failed.emit(f"统计 CSV 无效或不存在: {self.extra.get('csv_path')}")
            return
        viz = self._make_visualizer(wp)
        output_paths = viz.plot_statistical_charts(
            csv_path=csv_file,
            parameter_columns=self.extra.get("param_cols") or [],
        )
        self.finished_ok.emit({"task": "statistics", "output_paths": output_paths})

    def _run_scatter(self, wp: Path) -> None:
        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return
        scatter_paths = self._generate_scatter_plots(wp, training_csv_path, models_dir)
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ------------------------------------------------------------- batch job
    def _run_generate_all(self, wp: Path) -> None:
        """Run every generator whose ``gen_*`` flag is set; emit one summary line per step."""
        viz = self._make_visualizer(wp)
        training_csv = _viz_training_spectra_csv_path(wp)
        steps = (
            ("gen_scatter", "散点图", self._batch_scatter),
            ("gen_spectrum", "光谱图", self._batch_spectrum),
            ("gen_boxplots", "统计图", self._batch_statistics),
            ("gen_mask_glint", "掩膜/耀斑预览", self._batch_mask_glint),
            ("gen_sampling_map", "采样点图", self._batch_sampling_map),
        )
        parts: List[str] = []
        for flag, label, step in steps:
            if not self.extra.get(flag):
                continue
            try:
                parts.append(step(wp, viz, training_csv))
            except Exception as exc:
                # Isolate failures so one broken generator does not discard the others' results.
                traceback.print_exc()
                parts.append(f"{label}: 失败 ({exc})")
        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    def _batch_scatter(self, wp: Path, _viz, training_csv: Path) -> str:
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        if not (models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir())):
            return "散点图: 跳过(无模型目录)"
        scatter_paths = self._generate_scatter_plots(wp, training_csv, models_dir)
        return f"散点图: {len(scatter_paths) if scatter_paths else 0}"

    def _batch_spectrum(self, _wp: Path, viz, training_csv: Path) -> str:
        if not training_csv.is_file():
            return "光谱图: 跳过(无训练数据)"
        df = pd.read_csv(training_csv)
        wl_col = _viz_infer_wavelength_start_column(df)
        idx = self._spectra_start_index(df, wl_col)
        param_cols = []
        if 0 < idx < len(df.columns):
            # Numeric, non-empty, non-coordinate columns before the spectra.
            param_cols = [
                c for c in df.columns[:idx]
                if df[c].dtype.kind in "iuf"
                and df[c].notna().sum() > 0
                and not self._is_coordinate_column(c)
            ]
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"
        spectrum_paths = []
        for param_col in param_cols:
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
                if path:
                    spectrum_paths.append(path)
            except Exception as e:
                print(f"生成光谱图失败 ({param_col}): {e}")
        return f"光谱图: {len(spectrum_paths)}"

    def _batch_statistics(self, _wp: Path, viz, training_csv: Path) -> str:
        if not training_csv.is_file():
            return "统计图: 跳过(无训练数据)"
        df = pd.read_csv(training_csv)
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not self._is_coordinate_column(c)
        ]
        idx = self._spectra_start_index(df, _viz_infer_wavelength_start_column(df))
        if 0 < idx < len(df.columns):
            # Keep only metadata columns, i.e. drop the spectral bands.
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]
        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"
        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = sum(1 for v in output_dict.values() if v) if output_dict else 0
        return f"统计图: {count}"

    def _batch_mask_glint(self, wp: Path, viz, _training_csv: Path) -> str:
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        return f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}"

    def _batch_sampling_map(self, wp: Path, _viz, _training_csv: Path) -> str:
        hyperspectral_path = self._find_hyperspectral_image(wp)
        if hyperspectral_path is None:
            return "采样点图: 跳过(无影像)"
        csv_path = self._find_sampling_csv(wp)
        if csv_path is None:
            return "采样点图: 跳过(无CSV)"
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        if not map_path:
            return "采样点图: 未生成"
        return f"采样点图: {Path(map_path).name}"
class PandasTableModel(QAbstractTableModel):
    """Read-only Qt table model backed by a copy of a pandas DataFrame.

    Cells are rendered with ``str()``; missing scalar values (NaN / None /
    NaT / pd.NA) render as an empty string. Rows are labelled 1..N in the
    vertical header, and columns use the stringified column labels.
    """

    def __init__(self, data_frame: pd.DataFrame):
        super().__init__()
        # Copy so later mutation of the caller's frame cannot desync the view.
        self._data = data_frame.copy()
        if self._data.empty:
            # An empty frame (no rows or no columns) is normalised to a frame
            # with no columns at all, so no header is shown either.
            self._data = pd.DataFrame()
        # NOTE: missing values are handled in data(); an eager fillna("") would
        # raise on categorical / nullable-integer columns.
        self._columns = [str(col) for col in self._data.columns]

    def rowCount(self, parent=None):
        # A flat table has no children under a valid parent index.
        if parent is not None and parent.isValid():
            return 0
        return len(self._data)

    def columnCount(self, parent=None):
        if parent is not None and parent.isValid():
            return 0
        return len(self._columns)

    def data(self, index, role=Qt.DisplayRole):
        if not index.isValid() or role != Qt.DisplayRole:
            return None
        value = self._data.iat[index.row(), index.column()]
        # pd.isna() returns an array for list-like cells, so only test scalars.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        if role != Qt.DisplayRole:
            return None
        if orientation == Qt.Horizontal:
            if section < len(self._columns):
                return self._columns[section]
            return str(section)
        # Vertical header: 1-based row numbers.
        return str(section + 1)

    def flags(self, index):
        if not index.isValid():
            return Qt.NoItemFlags
        return Qt.ItemIsEnabled | Qt.ItemIsSelectable
class ChartViewerDialog(QDialog):
    """Dialog showing one image file (or a custom matplotlib drawing) with a toolbar and a save button."""

    def __init__(self, title="图表查看器", parent=None):
        super().__init__(parent)
        # Path of the image currently shown; None until display_image() succeeds.
        self.current_image_path = None
        self.setWindowTitle(title)
        self.resize(1000, 700)
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        self.figure = Figure(figsize=(10, 7))
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas)
        btn_layout = QHBoxLayout()
        self.save_btn = QPushButton("保存图表")
        self.save_btn.clicked.connect(self.save_chart)
        btn_layout.addWidget(self.save_btn)
        btn_layout.addStretch()
        self.close_btn = QPushButton("关闭")
        self.close_btn.clicked.connect(self.close)
        btn_layout.addWidget(self.close_btn)
        layout.addLayout(btn_layout)
        self.setLayout(layout)

    def display_image(self, image_path):
        """Show *image_path* on the canvas; on failure draw the error text instead."""
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        try:
            import matplotlib.image as mpimg
            img = mpimg.imread(image_path)
            ax.imshow(img)
            ax.axis('off')
            self.figure.tight_layout()
            self.canvas.draw()
            self.current_image_path = image_path
        except Exception as e:
            # Do not leave a stale path pointing at an image that is no longer shown.
            self.current_image_path = None
            ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
                    ha='center', va='center', transform=ax.transAxes)
            self.canvas.draw()

    def display_custom_plot(self, plot_func):
        """Call ``plot_func(figure)`` to draw; on failure draw the error text instead."""
        self.figure.clear()
        # The canvas no longer shows an image file.
        self.current_image_path = None
        try:
            plot_func(self.figure)
            self.canvas.draw()
        except Exception as e:
            ax = self.figure.add_subplot(111)
            ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}',
                    ha='center', va='center', transform=ax.transAxes)
            self.canvas.draw()

    def save_chart(self):
        """Save the current figure at 300 dpi to a user-chosen file."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图表", "",
            "PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
        )
        if file_path:
            try:
                self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
                QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
            except Exception as e:
                QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class ImageCategoryTree(QTreeWidget):
    """Tree of image files grouped into fixed categories by file-name keywords."""

    # (category name, lower-case filename keywords, icon). Categories are tried
    # in order and the first one whose keyword occurs in the file name wins;
    # the keyword-less last entry is the catch-all.
    CATEGORIES = [
        ("模型评估", ["scatter", "regression", "validation", "r2", "rmse"], "📊"),
        ("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
        ("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
        ("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
        ("含量分布图", [], "📁"),
    ]

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setHeaderLabel("图像目录")
        self.setMaximumWidth(300)
        self.setMinimumWidth(250)
        self.setup_categories()
        self.setStyleSheet("""
            QTreeWidget {
                border: 1px solid #ddd;
                border-radius: 5px;
                background-color: #f8f9fa;
            }
            QTreeWidget::item {
                padding: 5px;
                border-radius: 3px;
            }
            QTreeWidget::item:selected {
                background-color: #0078D4;
                color: white;
            }
            QTreeWidget::item:hover {
                background-color: #e3f2fd;
            }
        """)

    def setup_categories(self):
        """Create one expanded top-level node per category."""
        self.category_items = {}
        for category_name, keywords, icon in self.CATEGORIES:
            item = QTreeWidgetItem(self)
            item.setText(0, f"{icon} {category_name}")
            item.setData(0, Qt.UserRole, {"type": "category", "keywords": keywords, "name": category_name})
            item.setExpanded(True)
            self.category_items[category_name] = item

    def _refresh_category_labels(self):
        """Label each category "<icon> <name>", suffixed with " (N)" when it holds N > 0 images."""
        for category_name, _keywords, icon in self.CATEGORIES:
            item = self.category_items[category_name]
            label = f"{icon} {category_name}"
            count = item.childCount()
            # Always rewrite the label so an emptied category loses its old count.
            item.setText(0, f"{label} ({count})" if count > 0 else label)

    def clear_all_images(self):
        """Remove every image item and reset the category labels."""
        for category_item in self.category_items.values():
            category_item.takeChildren()
        self._refresh_category_labels()

    def add_image(self, file_path: Path, display_name: Optional[str] = None):
        """Add *file_path* under its category and return the new tree item."""
        if display_name is None:
            display_name = file_path.stem
        category = self._determine_category(file_path.name)
        category_item = self.category_items.get(category, self.category_items["含量分布图"])
        image_item = QTreeWidgetItem(category_item)
        image_item.setText(0, f" └─ {display_name}")
        image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path)})
        image_item.setToolTip(0, str(file_path))
        return image_item

    def _determine_category(self, filename: str) -> str:
        """Return the first category whose keyword occurs in *filename* (case-insensitive)."""
        filename_lower = filename.lower()
        for category_name, keywords, _ in self.CATEGORIES:
            if any(keyword in filename_lower for keyword in keywords):
                return category_name
        return "含量分布图"

    def scan_directory(self, work_dir: str):
        """Repopulate the tree with images found under *work_dir*.

        Only ``14_visualization`` and ``1_water_mask`` are scanned when either
        exists; otherwise the whole directory is scanned recursively. Files
        are de-duplicated by resolved path; hidden files and names containing
        "thumb" are skipped.
        """
        self.clear_all_images()
        work_path = Path(work_dir)
        if not work_path.exists():
            return
        image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp']
        scan_roots: List[Path] = [
            root
            for root in (work_path / "14_visualization", work_path / "1_water_mask")
            if root.is_dir()
        ]
        if not scan_roots:
            scan_roots.append(work_path)
        seen_norm: set = set()
        image_files: List[Path] = []
        for root in scan_roots:
            for ext in image_extensions:
                for p in root.glob(f"**/{ext}"):
                    # Normalise case/separators so the same file is listed once on Windows.
                    key = os.path.normcase(os.path.normpath(str(p.resolve())))
                    if key in seen_norm:
                        continue
                    seen_norm.add(key)
                    image_files.append(p)
        for img_file in sorted(image_files):
            if img_file.name.startswith('.') or 'thumb' in img_file.name.lower():
                continue
            self.add_image(img_file)
        self._refresh_category_labels()

    def get_selected_image_path(self) -> Optional[str]:
        """Return the path of the selected image item, or None if nothing/a category is selected."""
        selected_item = self.currentItem()
        if not selected_item:
            return None
        data = selected_item.data(0, Qt.UserRole)
        if data and data.get("type") == "image":
            return data.get("path")
        return None
class ImageViewerWidget(QWidget):
    """Scrollable image viewer with zoom (buttons and mouse wheel), fit-to-window and save.

    Rescaling is debounced. Zoom actions only record the requested scale and
    restart a 50 ms single-shot timer; the pixmap is rescaled once, when the
    timer fires.
    """

    # Zoom limits shared by the buttons and the mouse wheel.
    _MIN_SCALE = 0.1
    _MAX_SCALE = 5.0

    def __init__(self, parent=None):
        super().__init__(parent)
        self.current_image_path = None
        self.scale_factor = 1.0
        # Null pixmap until an image loads successfully.
        self.original_pixmap = QPixmap()
        self._fit_scale = None
        self._update_timer = QTimer(self)
        self._update_timer.setSingleShot(True)
        self._update_timer.timeout.connect(self._do_update_display)
        self._pending_scale = None
        self.setup_ui()

    def setup_ui(self):
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        toolbar = QHBoxLayout()
        self.refresh_btn = QPushButton("🔄 刷新目录")
        self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
        toolbar.addWidget(self.refresh_btn)
        separator = QFrame()
        separator.setFrameShape(QFrame.VLine)
        separator.setFrameShadow(QFrame.Sunken)
        toolbar.addWidget(separator)
        self.zoom_in_btn = QPushButton("🔍+")
        self.zoom_in_btn.setToolTip("放大")
        self.zoom_in_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_in_btn)
        self.zoom_out_btn = QPushButton("🔍-")
        self.zoom_out_btn.setToolTip("缩小")
        self.zoom_out_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_out_btn)
        self.fit_btn = QPushButton("⬜ 适应窗口")
        self.fit_btn.setToolTip("适应窗口大小")
        toolbar.addWidget(self.fit_btn)
        self.original_btn = QPushButton("1:1 原始大小")
        self.original_btn.setToolTip("原始大小")
        toolbar.addWidget(self.original_btn)
        toolbar.addStretch()
        self.save_btn = QPushButton("💾 保存")
        self.save_btn.setToolTip("保存当前图像")
        toolbar.addWidget(self.save_btn)
        layout.addLayout(toolbar)
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setStyleSheet("background-color: white;")
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("background-color: white;")
        self.scroll_area.setWidget(self.image_label)
        layout.addWidget(self.scroll_area, 1)
        status_layout = QHBoxLayout()
        self.status_label = QLabel("就绪")
        self.status_label.setStyleSheet("color: #666; font-size: 11px;")
        status_layout.addWidget(self.status_label)
        status_layout.addStretch()
        layout.addLayout(status_layout)
        self.setLayout(layout)
        self.zoom_in_btn.clicked.connect(self.zoom_in)
        self.zoom_out_btn.clicked.connect(self.zoom_out)
        self.fit_btn.clicked.connect(self.fit_to_window)
        self.original_btn.clicked.connect(self.original_size)
        self.save_btn.clicked.connect(self.save_image)

    def _clear_image(self, label_text: str, status_text: str) -> None:
        """Drop the current pixmap and show *label_text* in its place.

        This prevents a later zoom from repainting the previous image over the error text.
        """
        self._update_timer.stop()
        self._pending_scale = None
        self.original_pixmap = QPixmap()
        self.image_label.setText(label_text)
        self.status_label.setText(status_text)

    def load_image(self, image_path: str):
        """Load *image_path*, fit it to the viewport and show size/name in the status bar."""
        if not image_path or not Path(image_path).exists():
            self.current_image_path = None
            self._clear_image("图像不存在", "图像加载失败")
            return
        self.current_image_path = image_path
        self.scale_factor = 1.0
        pixmap = QPixmap(image_path)
        if pixmap.isNull():
            self._clear_image("无法加载图像", "图像格式不支持")
            return
        self.original_pixmap = pixmap
        self.fit_to_window()
        size_mb = Path(image_path).stat().st_size / (1024 * 1024)
        self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")

    def update_image_display(self):
        """Schedule a rescale to ``scale_factor``; rapid calls collapse into one redraw."""
        self._update_timer.stop()
        self._pending_scale = self.scale_factor
        self._update_timer.start(50)

    def _do_update_display(self):
        """Rescale the original pixmap to the pending scale and show it."""
        if self.original_pixmap.isNull() or self._pending_scale is None:
            return
        scale = self._pending_scale
        # Smooth filtering is costly; use it only near 1:1 where it is visible.
        if scale > 2.0 or scale < 0.5:
            transform = Qt.FastTransformation
        else:
            transform = Qt.SmoothTransformation
        # Never request a 0-pixel size, which would yield a null pixmap.
        width = max(1, int(self.original_pixmap.width() * scale))
        height = max(1, int(self.original_pixmap.height() * scale))
        scaled_pixmap = self.original_pixmap.scaled(width, height, Qt.KeepAspectRatio, transform)
        self.image_label.setPixmap(scaled_pixmap)
        self._pending_scale = None

    def wheelEvent(self, event):
        """Zoom in/out by 10% per wheel step."""
        delta = event.angleDelta().y()
        if delta > 0:
            if self.scale_factor < self._MAX_SCALE:
                self.scale_factor = min(self.scale_factor * 1.1, self._MAX_SCALE)
                self.update_image_display()
        else:
            if self.scale_factor > self._MIN_SCALE:
                self.scale_factor = max(self.scale_factor / 1.1, self._MIN_SCALE)
                self.update_image_display()
        event.accept()

    def zoom_in(self):
        """Zoom in by 25%, capped at the maximum scale."""
        if self.scale_factor < self._MAX_SCALE:
            self.scale_factor = min(self.scale_factor * 1.25, self._MAX_SCALE)
            self.update_image_display()

    def zoom_out(self):
        """Zoom out by 25%, floored at the minimum scale."""
        if self.scale_factor > self._MIN_SCALE:
            self.scale_factor = max(self.scale_factor / 1.25, self._MIN_SCALE)
            self.update_image_display()

    def fit_to_window(self):
        """Scale the image so it fits entirely inside the scroll-area viewport."""
        if self.original_pixmap.isNull():
            return
        view_size = self.scroll_area.viewport().size()
        img_size = self.original_pixmap.size()
        scale = min(view_size.width() / img_size.width(),
                    view_size.height() / img_size.height())
        if scale <= 0:
            # Viewport not laid out yet (e.g. widget hidden). Use 1:1 so that
            # multiplicative zoom does not get stuck at 0.
            scale = 1.0
        self._fit_scale = scale
        self.scale_factor = scale
        self.update_image_display()
        self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")

    def original_size(self):
        """Show the image at 100%."""
        self.scale_factor = 1.0
        self._fit_scale = None
        self.update_image_display()
        self.status_label.setText("原始大小 | 缩放: 100%")

    def save_image(self):
        """Copy the current image file to a user-chosen location."""
        if not self.current_image_path:
            return
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图像", Path(self.current_image_path).name,
            "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
        )
        if file_path:
            try:
                import shutil
                shutil.copy(self.current_image_path, file_path)
            except Exception as e:
                QMessageBox.critical(self, "错误", f"保存失败: {e}")
class ChartBrowserDialog(QDialog):
    """Dialog for paging through chart image files, newest (by mtime) first."""

    def __init__(self, chart_files, parent=None):
        super().__init__(parent)
        # chart_files are expected to be pathlib.Path objects (stat()/name are used).
        self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True)
        self.current_index = 0
        self.setWindowTitle("图表浏览器")
        self.resize(1200, 800)
        self.init_ui()
        self.show_chart(0)

    def init_ui(self):
        layout = QVBoxLayout()
        list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
        list_layout = QHBoxLayout()
        self.chart_list = QListWidget()
        self.chart_list.setMaximumHeight(150)
        for chart_file in self.chart_files:
            self.chart_list.addItem(chart_file.name)
        self.chart_list.currentRowChanged.connect(self.show_chart)
        list_layout.addWidget(self.chart_list)
        list_group.setLayout(list_layout)
        layout.addWidget(list_group)
        self.figure = Figure(figsize=(12, 8))
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas, 1)
        btn_layout = QHBoxLayout()
        self.prev_btn = QPushButton("◀ 上一个")
        self.prev_btn.clicked.connect(self.prev_chart)
        btn_layout.addWidget(self.prev_btn)
        self.next_btn = QPushButton("下一个 >")
        self.next_btn.clicked.connect(self.next_chart)
        btn_layout.addWidget(self.next_btn)
        btn_layout.addStretch()
        self.save_btn = QPushButton("💾 保存当前图表")
        self.save_btn.clicked.connect(self.save_current_chart)
        btn_layout.addWidget(self.save_btn)
        self.close_btn = QPushButton("关闭")
        self.close_btn.clicked.connect(self.close)
        btn_layout.addWidget(self.close_btn)
        layout.addLayout(btn_layout)
        self.setLayout(layout)

    def show_chart(self, index):
        """Render chart *index*, sync the list selection and the prev/next buttons."""
        if not 0 <= index < len(self.chart_files):
            return
        self.current_index = index
        if self.chart_list.currentRow() != index:
            # Sync the list without re-entering show_chart via currentRowChanged,
            # which would read and render the same image twice.
            was_blocked = self.chart_list.blockSignals(True)
            self.chart_list.setCurrentRow(index)
            self.chart_list.blockSignals(was_blocked)
        chart_file = self.chart_files[index]
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        try:
            import matplotlib.image as mpimg
            img = mpimg.imread(str(chart_file))
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(chart_file.name, fontsize=12, pad=10)
            self.figure.tight_layout()
            self.canvas.draw()
        except Exception as e:
            ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
                    ha='center', va='center', transform=ax.transAxes)
            self.canvas.draw()
        self.prev_btn.setEnabled(index > 0)
        self.next_btn.setEnabled(index < len(self.chart_files) - 1)

    def prev_chart(self):
        """Show the previous chart, if any."""
        if self.current_index > 0:
            self.show_chart(self.current_index - 1)

    def next_chart(self):
        """Show the next chart, if any."""
        if self.current_index < len(self.chart_files) - 1:
            self.show_chart(self.current_index + 1)

    def save_current_chart(self):
        """Copy the current chart file to a user-chosen location."""
        if 0 <= self.current_index < len(self.chart_files):
            current_file = self.chart_files[self.current_index]
            file_path, _ = QFileDialog.getSaveFileName(
                self, "保存图表", current_file.name,
                "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
            )
            if file_path:
                try:
                    import shutil
                    shutil.copy(str(current_file), file_path)
                    QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
                except Exception as e:
                    QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class VisualizationPanel(QWidget):
"""可视化分析面板 - 重构版:左侧目录树 + 右侧图像查看器"""
def __init__(self, parent=None):
    """Create the panel with no working directory, chart viewer or worker thread yet."""
    super().__init__(parent)
    # State must exist before init_ui(), which connects widgets to methods that read it.
    self.work_dir = self.chart_viewer = self._viz_thread = None
    self.init_ui()
def _viz_set_busy(self, busy: bool):
    """Disable the "generate all" and "scan" buttons while busy; re-enable them otherwise.

    Buttons that have not been created yet are skipped.
    """
    enabled = not busy
    for attr_name in ("gen_all_btn", "scan_btn"):
        button = getattr(self, attr_name, None)
        if button is None:
            continue
        button.setEnabled(enabled)
def _start_visualization_thread(self, task: str, extra: Optional[dict] = None) -> bool:
    """Start a background VisualizationWorkerThread for *task*.

    Returns True when the worker was started. Returns False, after showing a
    message box, when no working directory is set, it does not exist, or a
    previous job is still running.
    """

    def refuse(text):
        QMessageBox.warning(self, "警告", text)
        return False

    if not self.work_dir:
        return refuse("请先选择工作目录!")
    work_path = Path(self.work_dir)
    if not work_path.exists():
        return refuse("工作目录不存在!")
    current = self._viz_thread
    if current and current.isRunning():
        QMessageBox.information(self, "提示", "可视化任务正在运行,请稍候。")
        return False

    worker = VisualizationWorkerThread(task, str(work_path), extra or {})
    self._viz_thread = worker
    # Queued connections: the slots run on the GUI thread, not the worker.
    worker.finished_ok.connect(self._on_visualization_worker_ok, Qt.QueuedConnection)
    worker.failed.connect(self._on_visualization_worker_fail, Qt.QueuedConnection)
    worker.finished.connect(lambda: self._viz_set_busy(False), Qt.QueuedConnection)
    self._viz_set_busy(True)
    worker.start()
    return True
def _spectrum_meta_param_columns(self, df: pd.DataFrame) -> List[str]:
    """Return the water-quality parameter columns that can group spectra.

    These are the numeric columns before the first spectral column. When the
    spectral start cannot be located, all numeric columns are used instead.
    Coordinate columns (names containing utm/lat/lon, or with an "x"/"y" name
    token such as UTM_X or POINT_Y) are never returned.
    """
    import re

    def is_coordinate(col) -> bool:
        name = str(col).lower()
        if any(fragment in name for fragment in ("utm", "lat", "lon")):
            return True
        # Match x/y only as whole tokens. A substring test would drop names
        # such as "Chlorophyll", "Turbidity" or "Oxygen".
        return any(token in ("x", "y") for token in re.split(r"[^0-9a-z]+", name))

    wl = _viz_infer_wavelength_start_column(df)
    idx = int(df.columns.get_loc(wl)) + 1 if isinstance(wl, str) else int(wl)
    if 0 < idx < len(df.columns):
        candidates = [c for c in df.columns[:idx] if pd.api.types.is_numeric_dtype(df[c])]
    else:
        candidates = df.select_dtypes(include=[np.number]).columns.tolist()
    return [c for c in candidates if not is_coordinate(c)]
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
    """Return the numeric water-quality parameter columns to chart statistics for.

    Spectral (wavelength) columns are excluded by keeping only columns before
    the first spectral column; if that position cannot be located, all numeric
    columns are considered. Coordinate columns (names containing
    utm/lat/lon, or with an "x"/"y" name token such as UTM_X) are excluded.
    """
    import re

    def is_coordinate(col) -> bool:
        name = str(col).lower()
        if any(fragment in name for fragment in ("utm", "lat", "lon")):
            return True
        # Whole-token match only: "Chlorophyll"/"Turbidity" contain a 'y' but are parameters.
        return any(token in ("x", "y") for token in re.split(r"[^0-9a-z]+", name))

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    wl = _viz_infer_wavelength_start_column(df)
    idx = int(df.columns.get_loc(wl)) + 1 if isinstance(wl, str) else int(wl)
    if 0 < idx < len(df.columns):
        meta_set = set(df.columns[:idx])
        numeric_cols = [col for col in numeric_cols if col in meta_set]
    return [col for col in numeric_cols if not is_coordinate(col)]
def _on_visualization_worker_ok(self, payload):
    """Summarise a finished worker job for the user, then rescan the work directory.

    *payload* is the dict emitted by ``VisualizationWorkerThread.finished_ok``;
    its ``"task"`` key selects the summary dialog and optional preview. A
    non-dict payload only triggers the rescan.
    """
    if not isinstance(payload, dict):
        self.scan_work_directory()
        return
    t = payload.get("task")
    if t == "mask_glint":
        cnt = int(payload.get("count") or 0)
        if cnt > 0:
            QMessageBox.information(
                self,
                "成功",
                f"掩膜和耀斑缩略图生成完成,共 {cnt} 个预览图。\n"
                "保存位置: 14_visualization/glint_deglint_previews/",
            )
        else:
            QMessageBox.warning(
                self,
                "警告",
                "未找到可处理的影像文件(2_glint/3_deglint 等)。",
            )
    elif t == "sampling_map":
        map_path = payload.get("map_path")
        QMessageBox.information(
            self,
            "成功",
            "采样点地图生成完成。\n"
            f"输出: {Path(map_path).name if map_path else ''}\n"
            "路径: 14_visualization/sampling_maps/",
        )
        if map_path:
            self.show_chart_viewer(map_path, "采样点分布图")
    elif t == "spectrum":
        multi = payload.get("output_paths")
        if isinstance(multi, list) and multi:
            ok_paths = [p for p in multi if p and Path(str(p)).is_file()]
            errs = payload.get("errors") or []
            msg = (
                f"已为 {len(ok_paths)} 个水质参数生成光谱对比图。\n"
                "保存目录: 工作目录/14_visualization/"
            )
            if errs:
                msg += f"\n\n以下列未生成或出错 ({len(errs)} 项,详见日志):\n"
                msg += "\n".join(str(e) for e in errs[:8])
                if len(errs) > 8:
                    msg += "\n..."
            QMessageBox.information(self, "成功", msg)
            if ok_paths:
                self.show_chart_viewer(ok_paths[0], "光谱曲线对比(首张)")
        else:
            outp = payload.get("output_path")
            param = payload.get("param_col") or ""
            QMessageBox.information(self, "成功", f"光谱图已生成:\n{outp}")
            if outp:
                self.show_chart_viewer(outp, f"{param} - 光谱曲线对比")
    elif t == "statistics":
        outp = payload.get("output_paths") or {}
        if isinstance(outp, dict):
            values = list(outp.values())
        elif isinstance(outp, (list, tuple, set)):
            values = list(outp)
        else:
            values = [outp]
        # Count only charts that were actually produced (entries may be None/empty).
        generated = [p for p in values if p]
        QMessageBox.information(
            self, "成功", f"统计图表已生成,共 {len(generated)} 项。"
        )
        boxplot = outp.get("boxplot") if isinstance(outp, dict) else None
        if boxplot:
            self.show_chart_viewer(boxplot, "水质参数箱线图")
    elif t == "scatter":
        paths = payload.get("scatter_paths") or {}
        ok_paths = [p for p in paths.values() if p and Path(str(p)).is_file()]
        if ok_paths:
            QMessageBox.information(
                self,
                "成功",
                f"已生成 {len(ok_paths)} 个模型评估散点图。\n"
                "保存位置: 14_visualization/scatter_plots/",
            )
            self.show_chart_viewer(ok_paths[0], "模型评估散点图")
        else:
            QMessageBox.warning(
                self,
                "提示",
                "未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
                "且训练 CSV 与建模时一致。",
            )
    elif t == "generate_all_selected":
        parts = payload.get("parts") or []
        summary = ("批量可视化已执行:\n" + "\n".join(parts)) if parts else "(无选中项或已跳过)"
        QMessageBox.information(self, "完成", summary)
    self.scan_work_directory()
def _on_visualization_worker_fail(self, err: str):
    """Show a worker failure in an error dialog, truncated to its first 1200 characters."""
    shown = err[:1200]
    QMessageBox.critical(self, "错误", f"可视化任务失败:\n{shown}")
def init_ui(self):
    """Build the panel: a control column (max 350 px) on the left and the image viewer on the right.

    Top to bottom, the left column holds:

    * the working-directory picker;
    * the image-directory picker;
    * the categorised image tree, which takes the spare vertical space;
    * the chart-type checkboxes and the "generate all" / "scan" buttons.
    """

    def make_dir_picker(title, placeholder, on_browse):
        # Group box holding a read-only path field followed by a "浏览" button.
        box = QGroupBox(title)
        row = QHBoxLayout()
        field = QLineEdit()
        field.setPlaceholderText(placeholder)
        field.setReadOnly(True)
        browse_button = QPushButton("浏览")
        browse_button.clicked.connect(on_browse)
        row.addWidget(field, 1)
        row.addWidget(browse_button)
        box.setLayout(row)
        return box, field

    root = QHBoxLayout()
    root.setSpacing(10)
    root.setContentsMargins(10, 10, 10, 10)

    # ----- left column -----
    left_panel = QWidget()
    left = QVBoxLayout()
    left.setContentsMargins(0, 0, 0, 0)

    work_box, self.work_dir_edit = make_dir_picker("工作目录", "选择工作目录...", self.browse_work_dir)
    left.addWidget(work_box)

    # The image directory is meant to point at the prediction-results folder.
    img_box, self.img_dir_edit = make_dir_picker("图像目录", "预测结果目录(自动填充)…", self.browse_img_dir)
    left.addWidget(img_box)

    tree_box = QGroupBox("图像目录")
    tree_layout = QVBoxLayout()
    self.image_tree = ImageCategoryTree()
    self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
    tree_layout.addWidget(self.image_tree)
    tree_box.setLayout(tree_layout)
    left.addWidget(tree_box, 1)

    options_box = QGroupBox("可视化配置")
    options = QVBoxLayout()
    chart_types = (
        ("gen_scatter", "模型评估散点图"),
        ("gen_spectrum", "光谱曲线图"),
        ("gen_boxplots", "统计图表"),
        ("gen_mask_glint", "掩膜和耀斑缩略图"),
        ("gen_sampling_map", "采样点地图"),
    )
    for attr_name, caption in chart_types:
        checkbox = QCheckBox(caption)
        checkbox.setChecked(True)
        setattr(self, attr_name, checkbox)
        options.addWidget(checkbox)

    divider = QFrame()
    divider.setFrameShape(QFrame.HLine)
    divider.setStyleSheet("color: #ddd;")
    options.addSpacing(10)
    options.addWidget(divider)
    options.addSpacing(10)

    self.gen_all_btn = QPushButton("🚀 生成全部")
    self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
    self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
    self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
    options.addWidget(self.gen_all_btn)

    self.scan_btn = QPushButton("📁 扫描目录")
    self.scan_btn.setToolTip("扫描工作目录中的图像文件")
    self.scan_btn.clicked.connect(self.scan_work_directory)
    options.addWidget(self.scan_btn)

    options_box.setLayout(options)
    left.addWidget(options_box)
    left_panel.setLayout(left)
    left_panel.setMaximumWidth(350)
    root.addWidget(left_panel, 0)

    # ----- right column -----
    right_panel = QWidget()
    right = QVBoxLayout()
    right.setContentsMargins(0, 0, 0, 0)
    self.image_viewer = ImageViewerWidget()
    self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
    right.addWidget(self.image_viewer, 1)
    right_panel.setLayout(right)
    root.addWidget(right_panel, 1)

    self.setLayout(root)
def set_work_dir(self, work_dir):
"""设置工作目录"""
self.work_dir = work_dir
self.work_dir_edit.setText(str(work_dir))
if work_dir:
QTimer.singleShot(100, self.scan_work_directory)
def _get_default_work_dir(self):
"""获取 work_dir优先用 panel 自身缓存的,否则尝试从主窗口取"""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_work_dir(self):
"""浏览工作目录"""
default = self._get_default_work_dir()
dir_path = QFileDialog.getExistingDirectory(self, "选择工作目录", default)
if dir_path:
self.work_dir = dir_path
self.work_dir_edit.setText(dir_path)
self.scan_work_directory()
def browse_img_dir(self):
"""手动浏览图像目录"""
default = self._get_default_work_dir()
dir_path = QFileDialog.getExistingDirectory(self, "选择图像目录", default)
if dir_path:
self.img_dir_edit.setText(dir_path)
self.image_tree.scan_directory(dir_path)
self._load_first_image_from_tree()
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动推断并填入图像目录,然后自动加载目录内容。
推断优先级:
1. {work_dir}/11_12_13_predictions/Machine_Learning_Prediction机器学习预测
2. {work_dir}/11_12_13_predictions/Non_Empirical_Prediction普通回归预测
3. {work_dir}/11_12_13_predictions/Custom_Regression_Prediction自定义回归预测
4. {work_dir}/14_visualization可视化目录
5. {work_dir}(工作目录根)
"""
if work_dir:
self.work_dir = work_dir
self.work_dir_edit.setText(str(work_dir))
elif not self.work_dir:
return
work_path = Path(self.work_dir)
pred_dir = work_path / "11_12_13_predictions"
# 按优先级寻找存在的目录
candidates = [
pred_dir / "Machine_Learning_Prediction",
pred_dir / "Non_Empirical_Prediction",
pred_dir / "Custom_Regression_Prediction",
work_path / "14_visualization",
work_path,
]
detected_dir = None
for candidate in candidates:
if candidate.exists() and candidate.is_dir():
detected_dir = candidate
break
if detected_dir:
detected_str = str(detected_dir)
self.img_dir_edit.setText(detected_str)
self.image_tree.scan_directory(detected_str)
else:
# 无预测目录时扫描整个工作目录
self.image_tree.scan_directory(self.work_dir)
# 自动触发加载第一张图像
self._load_first_image_from_tree()
def _load_first_image_from_tree(self):
"""从目录树中加载第一张图像到右侧查看器"""
tree = self.image_tree
if tree is None:
return
for category_item in tree.category_items.values():
for i in range(category_item.childCount()):
child = category_item.child(i)
data = child.data(0, Qt.UserRole)
if data and data.get("type") == "image":
img_path = data.get("path")
if img_path and Path(img_path).exists():
self.image_viewer.load_image(img_path)
return
def scan_work_directory(self):
"""扫描工作目录中的图像文件"""
if not self.work_dir:
return
work_path = Path(self.work_dir)
if not work_path.exists():
return
print(f"扫描工作目录: {work_path}")
self.image_tree.scan_directory(str(work_path))
self._setup_prediction_output_dirs(work_path)
viz_dir = work_path / "14_visualization"
if viz_dir.exists():
image_files = list(viz_dir.glob("**/*.png")) + list(viz_dir.glob("**/*.jpg"))
if image_files:
self.image_viewer.load_image(str(image_files[0]))
def _setup_prediction_output_dirs(self, work_path: Path):
"""设置三个预测步骤的默认输出目录"""
try:
base_prediction_dir = work_path / "11_12_13_predictions"
ml_dir = base_prediction_dir / "Machine_Learning_Prediction"
reg_dir = base_prediction_dir / "Regression_Model_Prediction"
custom_dir = base_prediction_dir / "Custom_Regression_Prediction"
ml_dir.mkdir(parents=True, exist_ok=True)
reg_dir.mkdir(parents=True, exist_ok=True)
custom_dir.mkdir(parents=True, exist_ok=True)
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
self.step8_panel.output_file.set_path(str(ml_dir))
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
self.step8_5_panel.output_file.set_path(str(reg_dir))
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
except Exception as e:
print(f"设置预测输出目录失败: {e}")
def on_tree_item_clicked(self, item, column):
"""目录树项点击事件"""
data = item.data(0, Qt.UserRole)
if not data:
return
if data.get("type") == "image":
image_path = data.get("path")
if image_path and Path(image_path).exists():
self.image_viewer.load_image(image_path)
def generate_all_visualizations(self):
"""生成所有可视化图表"""
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
if not work_path.exists():
QMessageBox.warning(self, "警告", "工作目录不存在!")
return
if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or
self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or
self.gen_sampling_map.isChecked()):
QMessageBox.information(self, "提示", "请至少勾选一项可视化配置选项以生成图表。")
return
reply = QMessageBox.question(
self, "确认生成",
"将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
QMessageBox.Yes | QMessageBox.No
)
if reply != QMessageBox.Yes:
return
extra = {
"gen_scatter": self.gen_scatter.isChecked(),
"gen_spectrum": self.gen_spectrum.isChecked(),
"gen_boxplots": self.gen_boxplots.isChecked(),
"gen_mask_glint": self.gen_mask_glint.isChecked(),
"gen_sampling_map": self.gen_sampling_map.isChecked(),
}
self._start_visualization_thread("generate_all_selected", extra)
def generate_chart(self, chart_type):
"""生成图表"""
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
if not work_path.exists():
QMessageBox.warning(self, "警告", "工作目录不存在!")
return
try:
training_spectra_csv = _viz_training_spectra_csv_path(work_path)
if chart_type == 'scatter':
if not training_spectra_csv.is_file():
QMessageBox.warning(
self, "警告",
"未找到 5_training_spectra\\training_spectra.csv。\n"
"请先执行步骤5光谱特征提取生成该文件。",
)
return
training_csv = training_spectra_csv
models_dir = work_path / "7_Supervised_Model_Training"
if not models_dir.is_dir() or not any(d.is_dir() for d in models_dir.iterdir()):
mdir = QFileDialog.getExistingDirectory(
self, "选择模型根目录(内含各水质参数子文件夹)", str(work_path))
if not mdir:
return
models_dir = Path(mdir)
self._start_visualization_thread(
"scatter",
{"training_csv_path": str(training_csv), "models_dir": str(models_dir)},
)
return
if chart_type == 'spectrum':
if not training_spectra_csv.is_file():
QMessageBox.warning(
self, "警告",
"未找到 5_training_spectra\\training_spectra.csv。\n"
"光谱分析固定使用该文件请先执行步骤5光谱特征提取",
)
return
csv_file = training_spectra_csv
df = pd.read_csv(csv_file)
columns = self._spectrum_meta_param_columns(df)
if not columns:
QMessageBox.warning(
self, "警告",
"当前 CSV 中没有可用的数值型水质参数列,无法按参数分组绘制光谱图。",
)
return
wl_col = _viz_infer_wavelength_start_column(df)
self._start_visualization_thread(
"spectrum",
{"csv_path": str(csv_file), "param_cols": columns,
"wavelength_start_column": wl_col, "n_groups": 5},
)
return
if chart_type == 'statistics':
if not training_spectra_csv.is_file():
QMessageBox.warning(
self, "警告",
"未找到 5_training_spectra\\training_spectra.csv。\n"
"统计分析固定使用该文件请先执行步骤5光谱特征提取",
)
return
csv_file = training_spectra_csv
df = pd.read_csv(csv_file)
param_cols = self._statistics_param_columns(df)
if not param_cols:
QMessageBox.warning(self, "警告", "未找到可用的水质参数列!")
return
self._start_visualization_thread(
"statistics",
{"csv_path": str(csv_file), "param_cols": param_cols},
)
return
if chart_type == 'sampling_map':
self.generate_sampling_point_map()
return
except ImportError:
QMessageBox.critical(
self, "错误",
"无法导入可视化模块!\n请确保 visualization_reports.py 文件存在。",
)
except Exception as e:
QMessageBox.critical(
self, "错误",
f"生成图表时出错:\n{str(e)}\n\n{traceback.format_exc()}",
)
def generate_mask_glint_previews(self):
"""生成掩膜和耀斑缩略图"""
self._start_visualization_thread("mask_glint")
def generate_sampling_point_map(self):
"""生成采样点地图"""
self._start_visualization_thread("sampling_map")
def view_chart(self, chart_type):
"""查看图表"""
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
viz_dir = work_path / "14_visualization"
viz_dir2 = viz_dir / "boxplots"
viz_dir3 = viz_dir / "scatter_plots"
if not viz_dir.exists():
QMessageBox.warning(self, "警告", f"可视化目录不存在:\n{viz_dir}\n\n请先生成图表。")
return
chart_files = []
if chart_type == 'scatter':
chart_files = list(viz_dir3.glob("*scatter*.png"))
elif chart_type == 'spectrum':
chart_files = list(viz_dir.glob("*spectrum*.png"))
elif chart_type == 'statistics':
chart_files = list(viz_dir2.glob("*boxplot.png")) + \
list(viz_dir.glob("*histogram.png")) + \
list(viz_dir.glob("*heatmap.png"))
elif chart_type == 'distribution':
chart_files = list(viz_dir.glob("**/*distribution.png"))
elif chart_type == 'mask_glint':
glint_dir = viz_dir / "glint_deglint_previews"
chart_files = list(glint_dir.glob("*preview.png")) if glint_dir.exists() else \
list(viz_dir.glob("*preview.png")) + \
list(viz_dir.glob("*glint*.png")) + \
list(viz_dir.glob("*mask*.png"))
elif chart_type == 'sampling_map':
sampling_dir = viz_dir / "sampling_maps"
chart_files = list(sampling_dir.glob("*sampling_map.png")) if sampling_dir.exists() else \
list(viz_dir.glob("*sampling*.png"))
if not chart_files:
QMessageBox.warning(self, "警告", f"未找到{chart_type}类型的图表文件!\n\n请先生成图表。")
return
if len(chart_files) > 1:
from PyQt5.QtWidgets import QInputDialog
file_names = [f.name for f in chart_files]
file_name, ok = QInputDialog.getItem(
self, "选择图表", "请选择要查看的图表:", file_names, 0, False)
if ok:
selected_file = next(f for f in chart_files if f.name == file_name)
self.show_chart_viewer(str(selected_file), file_name)
else:
self.show_chart_viewer(str(chart_files[0]), chart_files[0].name)
def browse_all_charts(self):
"""浏览所有图表"""
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
chart_files = list(work_path.glob("**/*.png")) + list(work_path.glob("**/*.jpg"))
if not chart_files:
QMessageBox.warning(self, "警告", "未找到图表文件!")
return
dialog = ChartBrowserDialog(chart_files, self)
dialog.exec_()
def show_chart_viewer(self, image_path, title="图表查看器"):
"""显示图表查看器"""
viewer = ChartViewerDialog(title=title, parent=self)
viewer.display_image(image_path)
viewer.exec_()
def get_config(self):
"""获取配置"""
return {
'generate_scatter': self.gen_scatter.isChecked(),
'generate_boxplots': self.gen_boxplots.isChecked(),
'generate_spectrum': self.gen_spectrum.isChecked(),
'generate_glint_previews': self.gen_mask_glint.isChecked(),
'generate_sampling_maps': self.gen_sampling_map.isChecked(),
'scatter_config': {
'metric': 'test_r2', 'feature_start_column': 13,
'test_size': 0.2, 'random_state': 42
},
'boxplot_config': {
'data_start_column': 4, 'save_individual': True, 'use_seaborn': True
},
'glint_preview_config': {
'work_dir': None, 'output_subdir': 'glint_deglint_previews',
'generate_glint': True, 'generate_deglint': True
}
}
def set_config(self, config):
"""设置配置"""
if not config:
return
if 'generate_scatter' in config:
self.gen_scatter.setChecked(config['generate_scatter'])
if 'generate_boxplots' in config:
self.gen_boxplots.setChecked(config['generate_boxplots'])
if 'generate_spectrum' in config:
self.gen_spectrum.setChecked(config['generate_spectrum'])
if 'generate_glint_previews' in config:
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
if 'generate_sampling_maps' in config:
self.gen_sampling_map.setChecked(config.get('generate_sampling_maps', True))