Files
WQ_GUI/src/gui/panels/visualization_panel.py

1791 lines
74 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
VisualizationPanel - 可视化分析面板
左侧目录树 + 右侧图像查看器,支持多种图表生成。
"""
import os
import traceback
from pathlib import Path
from typing import Optional, List, Union
import numpy as np
import pandas as pd
from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal, QAbstractTableModel
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QMessageBox,
QFileDialog, QFrame, QSizePolicy,
QDialog, QTreeWidget, QListWidget, QAbstractItemView, QHeaderView,QTreeWidgetItem,QScrollArea
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
# Probe once at import time whether the inversion pipeline can be imported;
# PIPELINE_AVAILABLE records the outcome so callers can degrade gracefully.
try:
    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
except ImportError:
    PIPELINE_AVAILABLE = False
else:
    PIPELINE_AVAILABLE = True
def _viz_training_spectra_csv_path(work_path: Path) -> Path:
    """Return the training-spectra CSV path used by the visualization views.

    The spectrum/statistics charts and the model scatter plots all read this
    file, which is the output location of pipeline step 5
    (``<work_path>/5_training_spectra/training_spectra.csv``).

    NOTE(review): the original author notes that step 5.5 (water-quality index
    computation) overwrites this same file with an enhanced 94-dimension
    version, so downstream consumers need no change; that is not verifiable
    from this module.
    """
    return work_path.joinpath("5_training_spectra", "training_spectra.csv")
def _viz_infer_wavelength_start_column(df: pd.DataFrame) -> Union[str, int]:
    """Guess where the spectral (wavelength) columns of *df* start.

    Training-spectra tables usually use numeric wavelengths as column names,
    and they do not necessarily contain a ``UTM_Y`` column.

    Returns:
        The positional index (int) of the first column whose name, after
        stripping whitespace and a leading BOM, parses as a number in
        [200, 3000]. Failing that, the label ``"UTM_Y"`` when present (the
        spectra are then taken to start right after it), otherwise ``0``.
    """
    def _parse_label(label) -> Optional[float]:
        text = str(label).strip().lstrip("\ufeff")
        try:
            return float(text)
        except ValueError:
            return None

    for position, label in enumerate(df.columns):
        value = _parse_label(label)
        if value is not None and 200.0 <= value <= 3000.0:
            return position
    return "UTM_Y" if "UTM_Y" in df.columns else 0
class VisualizationWorkerThread(QThread):
    """Run slow visualization jobs on a background thread.

    For the duration of ``run()`` pyplot is switched to the non-interactive
    ``Agg`` backend so rendering never touches the GUI; the previous backend
    is restored afterwards on a best-effort basis.

    ``task`` selects the job: ``"mask_glint"``, ``"sampling_map"``,
    ``"spectrum"``, ``"statistics"``, ``"scatter"`` or
    ``"generate_all_selected"``. ``extra`` carries task-specific options read
    by the handlers below (``csv_path``, ``param_col``/``param_cols``,
    ``wavelength_start_column``, ``n_groups``, ``training_csv_path``,
    ``models_dir`` and the ``gen_*`` flags).

    Signals:
        finished_ok(object): emitted with a result dict whose ``"task"`` key
            names the job that completed.
        failed(str): emitted with a user-facing error message; unexpected
            exceptions include the formatted traceback.
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    # Hyperspectral image patterns, searched in this priority order.
    _HSI_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")
    # Fallback patterns for the sampling-point CSV, searched in this order.
    _SAMPLING_CSV_PATTERNS = ("**/*sampling*.csv", "**/*point*.csv", "**/*.csv")
    # Whole-word tokens marking a coordinate column (see _is_coordinate_column).
    _COORD_TOKENS = frozenset(
        {"x", "y", "lon", "lat", "longitude", "latitude", "coord", "coordinate", "coordinates"}
    )

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        """Dispatch ``self.task`` to its handler and report via the signals."""
        mpl_prev = self._switch_to_agg_backend()
        try:
            handlers = {
                "mask_glint": self._run_mask_glint,
                "sampling_map": self._run_sampling_map,
                "spectrum": self._run_spectrum,
                "statistics": self._run_statistics,
                "scatter": self._run_scatter,
                "generate_all_selected": self._run_generate_all_selected,
            }
            handler = handlers.get(self.task)
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(Path(self.work_dir))
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            self._restore_backend(mpl_prev)

    # ---- matplotlib backend handling ------------------------------------

    @staticmethod
    def _switch_to_agg_backend() -> Optional[str]:
        """Switch pyplot to Agg and return the backend to restore afterwards.

        Returns None when the previous backend is unknown or the switch
        failed; in that case nothing should be restored.
        """
        try:
            import matplotlib
            previous = matplotlib.get_backend()
        except Exception:
            previous = None
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            return None
        return previous

    @staticmethod
    def _restore_backend(previous: Optional[str]) -> None:
        """Best-effort restore of the pyplot backend saved before the job."""
        if not previous:
            return
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend(previous)
        except Exception as exc:
            # Restoring is best effort; never let it mask the job's result.
            print(f"恢复 matplotlib 后端失败: {exc}")

    # ---- shared helpers -------------------------------------------------

    @staticmethod
    def _make_visualizer(wp: Path):
        """Create the report visualizer writing into ``14_visualization``."""
        from src.postprocessing.visualization_reports import WaterQualityVisualization
        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @classmethod
    def _find_hyperspectral_image(cls, wp: Path) -> Optional[str]:
        """Return the first hyperspectral image, preferring ``3_deglint``.

        Patterns are tried in ``_HSI_PATTERNS`` order; matches within one
        pattern are sorted so the choice is deterministic across runs.
        """
        deglint_dir = wp / "3_deglint"
        if deglint_dir.exists():
            for pattern in cls._HSI_PATTERNS:
                matches = sorted(deglint_dir.glob(pattern))
                if matches:
                    return str(matches[0])
        for pattern in cls._HSI_PATTERNS:
            matches = sorted(wp.glob(f"**/{pattern}"))
            if matches:
                return str(matches[0])
        return None

    @classmethod
    def _find_sampling_csv(cls, wp: Path) -> Optional[str]:
        """Return the sampling-point CSV, preferring ``4_processed_data``.

        Falls back to a recursive search of the work directory, stopping at
        the first pattern of ``_SAMPLING_CSV_PATTERNS`` that matches.
        """
        processed_dir = wp / "4_processed_data"
        if processed_dir.exists():
            matches = sorted(processed_dir.glob("*.csv"))
            if matches:
                return str(matches[0])
        for pattern in cls._SAMPLING_CSV_PATTERNS:
            matches = sorted(wp.glob(pattern))
            if matches:
                return str(matches[0])
        return None

    @staticmethod
    def _create_sampling_map(wp: Path, hyperspectral_path: str, csv_path: str):
        """Render the sampling-point map and return whatever the generator returns."""
        from src.postprocessing.point_map import SamplingPointMap
        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=hyperspectral_path,
            csv_path=csv_path,
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    @staticmethod
    def _meta_end_index(df: pd.DataFrame, wl_col: Union[str, int]) -> int:
        """Position where spectral columns begin (metadata columns precede it)."""
        if isinstance(wl_col, str):
            return int(df.columns.get_loc(wl_col)) + 1
        return int(wl_col)

    @classmethod
    def _is_coordinate_column(cls, column) -> bool:
        """Return True if *column* names a coordinate column.

        The name is split on non-alphanumeric characters and tokens are
        compared as whole words, so ``UTM_X``, ``lat`` or ``Longitude`` match
        while ``Chlorophyll``/``Turbidity`` (which merely contain "y") do not.
        """
        cleaned = "".join(ch if ch.isalnum() else " " for ch in str(column).lower())
        return any(
            token in cls._COORD_TOKENS or token.startswith("utm")
            for token in cleaned.split()
        )

    @classmethod
    def _spectrum_param_columns(cls, df: pd.DataFrame, wl_col: Union[str, int]) -> List[str]:
        """Numeric, non-empty metadata columns preceding the spectral block."""
        idx = cls._meta_end_index(df, wl_col)
        if not 0 < idx < len(df.columns):
            return []
        return [
            c for c in df.columns[:idx]
            if df[c].dtype.kind in "iuf" and df[c].notna().sum() > 0
        ]

    @classmethod
    def _statistics_param_columns(cls, df: pd.DataFrame) -> List[str]:
        """Numeric non-coordinate columns, restricted to the metadata block if found."""
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not cls._is_coordinate_column(c)
        ]
        idx = cls._meta_end_index(df, _viz_infer_wavelength_start_column(df))
        if 0 < idx < len(df.columns):
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]
        return param_cols

    # ---- single-task handlers -------------------------------------------

    def _run_mask_glint(self, wp: Path) -> None:
        """Generate mask / glint / deglint preview thumbnails."""
        viz = self._make_visualizer(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})

    def _run_sampling_map(self, wp: Path) -> None:
        """Locate image + sampling CSV and render the sampling-point map."""
        hyperspectral_path = self._find_hyperspectral_image(wp)
        if hyperspectral_path is None:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)")
            return
        csv_path = self._find_sampling_csv(wp)
        if csv_path is None:
            self.failed.emit("未找到采样点 CSV 文件。")
            return
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": hyperspectral_path,
                "csv_path": csv_path,
            }
        )

    def _run_spectrum(self, wp: Path) -> None:
        """Plot spectra grouped by one parameter (``param_col``) or several (``param_cols``)."""
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []
        if not param_cols:
            param_col = self.extra.get("param_col")
            out = viz.plot_spectrum_by_parameter(
                csv_path=str(csv_file),
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )
            self.finished_ok.emit({"task": "spectrum", "output_path": out, "param_col": param_col})
            return
        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            try:
                out = viz.plot_spectrum_by_parameter(
                    csv_path=str(csv_file),
                    parameter_column=param_col,
                    wavelength_start_column=wl,
                    n_groups=n_groups,
                )
            except Exception as exc:
                # Keep going: one bad column should not abort the others.
                err_lines.append(f"{param_col}: {exc}")
                continue
            output_paths.append(out)
        if not output_paths:
            self.failed.emit("所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20]))
            return
        self.finished_ok.emit(
            {"task": "spectrum", "output_paths": output_paths, "errors": err_lines}
        )

    def _run_statistics(self, wp: Path) -> None:
        """Plot statistical charts for ``extra['param_cols']``."""
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        param_cols = self.extra.get("param_cols") or []
        output_paths = viz.plot_statistical_charts(
            csv_path=str(csv_file),
            parameter_columns=param_cols,
        )
        self.finished_ok.emit({"task": "statistics", "output_paths": output_paths})

    def _run_scatter(self, wp: Path) -> None:
        """Generate model scatter plots from an explicit CSV and model directory."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=training_csv_path,
            models_dir=models_dir,
        )
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ---- "generate all selected" ----------------------------------------

    def _run_generate_all_selected(self, wp: Path) -> None:
        """Run every job whose ``gen_*`` flag is set and emit a summary list."""
        viz = self._make_visualizer(wp)
        training_csv = _viz_training_spectra_csv_path(wp)
        parts: List[str] = []
        if self.extra.get("gen_scatter"):
            parts.append(self._all_scatter(wp, training_csv))
        want_spectrum = bool(self.extra.get("gen_spectrum"))
        want_boxplots = bool(self.extra.get("gen_boxplots"))
        training_df: Optional[pd.DataFrame] = None
        if (want_spectrum or want_boxplots) and training_csv.is_file():
            # Read the training table once for both the spectrum and box-plot steps.
            training_df = pd.read_csv(training_csv)
        if want_spectrum:
            parts.append(self._all_spectra(viz, training_csv, training_df))
        if want_boxplots:
            parts.append(self._all_statistics(viz, training_csv, training_df))
        if self.extra.get("gen_mask_glint"):
            preview_paths = viz.generate_glint_deglint_previews(
                work_dir=str(wp),
                output_subdir="glint_deglint_previews",
            )
            parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
        if self.extra.get("gen_sampling_map"):
            parts.append(self._all_sampling_map(wp))
        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    @staticmethod
    def _all_scatter(wp: Path, training_csv: Path) -> str:
        """Scatter-plot step of generate-all; returns its summary line."""
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        if not (models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir())):
            return "散点图: 跳过(无模型目录)"
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=str(training_csv),
            models_dir=str(models_dir),
        )
        return f"散点图: {len(scatter_paths) if scatter_paths else 0}"

    @classmethod
    def _all_spectra(cls, viz, training_csv: Path, df: Optional[pd.DataFrame]) -> str:
        """Spectrum step of generate-all; returns its summary line."""
        if df is None:
            return "光谱图: 跳过(无训练数据)"
        wl_col = _viz_infer_wavelength_start_column(df)
        param_cols = cls._spectrum_param_columns(df, wl_col)
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"
        spectrum_paths = []
        for param_col in param_cols:
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
            except Exception as e:
                print(f"生成光谱图失败 ({param_col}): {e}")
                continue
            if path:
                spectrum_paths.append(path)
        return f"光谱图: {len(spectrum_paths)}"

    @classmethod
    def _all_statistics(cls, viz, training_csv: Path, df: Optional[pd.DataFrame]) -> str:
        """Box-plot/statistics step of generate-all; returns its summary line."""
        if df is None:
            return "统计图: 跳过(无训练数据)"
        param_cols = cls._statistics_param_columns(df)
        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"
        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = sum(1 for v in output_dict.values() if v) if output_dict else 0
        return f"统计图: {count}"

    @classmethod
    def _all_sampling_map(cls, wp: Path) -> str:
        """Sampling-map step of generate-all; returns its summary line."""
        hyperspectral_path = cls._find_hyperspectral_image(wp)
        if hyperspectral_path is None:
            return "采样点图: 跳过(无影像)"
        csv_path = cls._find_sampling_csv(wp)
        if csv_path is None:
            return "采样点图: 跳过(无CSV)"
        map_path = cls._create_sampling_map(wp, hyperspectral_path, csv_path)
        return f"采样点图: {Path(map_path).name if map_path else ''}"
class PandasTableModel(QAbstractTableModel):
    """Read-only Qt table model backed by a private copy of a DataFrame.

    Missing values display as empty strings; every other cell is shown as
    ``str(value)``. Rows are labelled 1..n in the vertical header.
    """

    def __init__(self, data_frame: pd.DataFrame):
        super().__init__()
        # Copy so later changes to the caller's frame do not affect the view.
        self._data = data_frame.copy()
        if self._data.empty:
            self._data = pd.DataFrame()
        self._data.fillna("", inplace=True)
        self._columns = [str(col) for col in self._data.columns]

    def rowCount(self, parent=None):
        # Flat table: per Qt convention only the invisible root has children.
        if parent is not None and parent.isValid():
            return 0
        return len(self._data)

    def columnCount(self, parent=None):
        if parent is not None and parent.isValid():
            return 0
        return len(self._columns)

    def data(self, index, role=Qt.DisplayRole):
        if not index.isValid() or role != Qt.DisplayRole:
            return None
        value = self._data.iat[index.row(), index.column()]
        # pd.isna on list/array cells returns an array (ambiguous truth value);
        # only treat true scalars as possibly missing.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        if role != Qt.DisplayRole:
            return None
        if orientation == Qt.Horizontal:
            if section < len(self._columns):
                return self._columns[section]
            return str(section)
        return str(section + 1)

    def flags(self, index):
        if not index.isValid():
            return Qt.NoItemFlags
        # Selectable but not editable.
        return Qt.ItemIsEnabled | Qt.ItemIsSelectable
class ChartViewerDialog(QDialog):
    """Dialog showing one chart on an embedded matplotlib canvas with a save button."""

    def __init__(self, title="图表查看器", parent=None):
        super().__init__(parent)
        # Path of the image last shown successfully by display_image(); None until then.
        self.current_image_path = None
        self.setWindowTitle(title)
        self.resize(1000, 700)
        self.init_ui()

    def init_ui(self):
        """Build toolbar, canvas and the save/close button row."""
        layout = QVBoxLayout()
        self.figure = Figure(figsize=(10, 7))
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas)
        btn_layout = QHBoxLayout()
        self.save_btn = QPushButton("保存图表")
        self.save_btn.clicked.connect(self.save_chart)
        btn_layout.addWidget(self.save_btn)
        btn_layout.addStretch()
        self.close_btn = QPushButton("关闭")
        self.close_btn.clicked.connect(self.close)
        btn_layout.addWidget(self.close_btn)
        layout.addLayout(btn_layout)
        self.setLayout(layout)

    def _show_error(self, message: str) -> None:
        """Replace whatever is on the figure with a centred error message."""
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
        self.canvas.draw()

    def display_image(self, image_path):
        """Show the image file at *image_path*; on failure show the error instead."""
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        try:
            import matplotlib.image as mpimg
            img = mpimg.imread(image_path)
            ax.imshow(img)
            ax.axis('off')
            self.figure.tight_layout()
            self.canvas.draw()
            self.current_image_path = image_path
        except Exception as e:
            self._show_error(f'加载图片失败:\n{str(e)}')

    def display_custom_plot(self, plot_func):
        """Call ``plot_func(figure)`` to draw; on failure show the error instead.

        The figure is cleared before the error is drawn so partial output
        from *plot_func* does not overlap the message.
        """
        self.figure.clear()
        try:
            plot_func(self.figure)
            self.canvas.draw()
        except Exception as e:
            self._show_error(f'绘图失败:\n{str(e)}')

    def save_chart(self):
        """Ask for a destination and save the current figure at 300 dpi."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图表", "",
            "PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)"
        )
        if not file_path:
            return
        try:
            self.figure.savefig(file_path, dpi=300, bbox_inches='tight')
            QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class ImageCategoryTree(QTreeWidget):
    """Tree listing the images found under a work directory.

    Each top-level folder of the work directory becomes a root node. Every
    deeper folder that directly contains images becomes one child node of
    that root, labelled with its last path component. Files directly inside a
    top-level folder hang off the root node itself, and files at the work
    root go under a "根目录" node. Known folder names are translated via
    ``DIR_MAPPING`` and file names via ``_translate_filename``.
    """

    # Filename-prefix → Chinese display name. Not referenced inside this
    # class; NOTE(review): presumably used elsewhere — verify before removing.
    NAME_MAPPING = {
        "hsi_preview": "高光谱影像预览",
        "hsi_original": "原始高光谱影像",
        "hsi_deglint": "去耀斑高光谱影像",
        "water_mask_overlay": "水域掩膜叠加图",
        "water_mask": "水域掩膜图",
        "glint_mask": "耀斑掩膜预览",
        "glint_overlay": "耀斑叠加对比图",
        "deglint_comparison": "去耀斑前后对比",
        "training_spectra": "训练光谱特征",
        "spectrum_by_param": "参数光谱图",
        "model_evaluation": "模型评估散点图",
        "model_scatter": "模型散点图",
        "regression": "回归分析图",
        "validation": "验证结果图",
        "spatial_distribution": "参数空间分布图",
        "distribution_map": "分布图",
        "thematic_map": "水质专题图",
        "water_quality_map": "水质分布图",
        "prediction_map": "预测结果图",
        "inversion_map": "反演结果图",
        "correlation_matrix": "特征相关性矩阵",
        "feature_correlation": "特征相关性",
        "sampling_point_map": "采样点分布图",
        "sampling_points": "采样点图",
        "point_locations": "采样位置图",
        "boxplot": "箱线图",
        "histogram": "直方图",
        "statistics": "统计图表",
        "statistical_chart": "统计图",
        "error_analysis": "误差分析图",
        "rmse": "RMSE评估图",
        "r2_score": "R²得分图",
        "flight": "飞行轨迹图",
        "path": "轨迹图",
        "trajectory": "轨迹图",
        "glint_deglint": "耀斑去耀斑影像",
        "enhanced": "增强分布图",
        "content": "含量分布图",
        "distribution": "分布图",
        "prediction": "预测图",
        "inversion": "反演图",
        "scatter_true_vs_pred": "真值-预测散点图",
        "true_vs_pred": "真值-预测散点图",
        "correlation_heatmap": "相关性热力图",
        "parameter_boxplot": "水质参数箱线图",
        "spectrum_comparison": "光谱曲线对比图",
        "scatter": "散点图",
    }

    # Folder name → Chinese display name.
    DIR_MAPPING = {
        "14_visualization": "统计与分析报表",
        "1_water_mask": "水域掩膜识别",
        "2_glint": "耀斑区域检测",
        "3_deglint": "去耀斑影像结果",
        "5_training_spectra": "训练光谱特征",
        "8_Regression_Modeling": "回归建模分析",
        "9_water_quality_prediction": "水质预测结果",
        "10_feature_construction": "特征构建散点",
        "11_12_13_predictions": "空间分布专题图",
        "glint_deglint_previews": "耀斑处理预览",
        "sampling_maps": "采样点空间分布",
        "flight_maps": "无人机飞行轨迹",
        "Machine_Learning_Prediction": "机器学习预测",
        "Non_Empirical_Prediction": "非经验模型预测",
        "Custom_Regression_Prediction": "自定义回归预测",
        "boxplot_dir": "水质参数箱线图",
        "boxplot": "水质参数箱线图",
        "output_dir": "输出目录",
        "8_spatial_inversion": "空间反演",
        "4_processed_data": "处理数据",
    }

    # Chart-type fragments in file names → Chinese label.
    _TYPE_MAPPING = {
        '_scatter_true_vs_pred': ' 真值预测散点图',
        '_spectrum_comparison': ' 光谱曲线对比图',
        '_spectrum': ' 光谱特征图',
        '_histogram': ' 分布直方图',
        '_boxplot_seaborn': ' Seaborn箱线图',
        '_boxplot': ' 箱线图',
        '_distribution_enhanced': ' 增强空间分布图',
        '_distribution': ' 空间分布图',
        '_sampling_map': ' 采样点地图',
        '_flight_paths': ' 飞行轨迹图',
        '_preview': ' 效果预览图',
        'water_mask_overlay': '水域掩膜叠加图',
        'hsi_preview': '原始影像预览',
        'correlation_heatmap': '特征相关性热力图',
        'parameter_boxplot': '水质参数汇总箱线图',
        'all_parameters_boxplot': '全参数汇总箱线图',
        'content_map': '含量分布专题图',
        '_scatter_with_confidence': ' 置信区间散点图',
    }
    # Longest fragment first, so specific entries ('hsi_preview',
    # 'all_parameters_boxplot', ...) are not pre-empted by generic ones
    # ('_preview', '_boxplot'). sorted() is stable, so ties keep dict order.
    _TYPE_RULES = tuple(
        sorted(_TYPE_MAPPING.items(), key=lambda item: len(item[0]), reverse=True)
    )

    # Water-quality parameter codes → Chinese names (replaced as whole words).
    _PARAM_MAPPING = {
        'Chlorophyll': '叶绿素', 'Chl_a': '叶绿素a', 'Chla': '叶绿素a',
        'Turbidity': '浊度', 'Temperature': '温度', 'spCond': '电导率',
        'COD': '化学需氧量', 'DO': '溶解氧', 'PH': 'pH值', 'TDS': '总溶解固体',
        'BGA': '蓝绿藻', 'TT': '透明度', 'NH3-N': '氨氮', 'NO3-N': '硝酸盐氮',
        'severe_glint_area': '重度耀斑区域', 'deglint_goodman': 'Goodman算法去耀斑',
    }

    # Image extensions picked up by scan_directory (compared case-insensitively).
    _IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"})

    # Work-directory sub-folders that are scanned for images.
    _SCAN_SUBDIRS = (
        "14_visualization",
        "11_12_13_predictions",
        "8_Regression_Modeling",
        "10_feature_construction",
        "5_training_spectra",
        "2_glint",
        "3_deglint",
        "1_water_mask",
        "9_water_quality_prediction",
    )

    def __init__(self, parent=None):
        super().__init__(parent)
        # Node lookup: "__root__<top folder>" / "__root__" for root nodes,
        # relative folder path string for sub-folder nodes.
        self._dir_node_map: dict = {}
        self._work_path: Optional[Path] = None
        self.setHeaderLabel("图像目录")
        self.setMaximumWidth(300)
        self.setMinimumWidth(250)
        self.setStyleSheet("""
            QTreeWidget {
                border: 1px solid #ddd;
                border-radius: 5px;
                background-color: #f8f9fa;
            }
            QTreeWidget::item {
                padding: 5px;
                border-radius: 3px;
            }
            QTreeWidget::item:selected {
                background-color: #0078D4;
                color: white;
            }
            QTreeWidget::item:hover {
                background-color: #e3f2fd;
            }
        """)

    def clear_all_images(self):
        """Remove every node from the tree and forget the node lookup."""
        try:
            self.clear()
            self._dir_node_map.clear()
        except Exception as e:
            print(f"清空树状图出错: {e}")
            traceback.print_exc()

    def _translate_dir_name(self, dir_name: str) -> str:
        """Translate a folder name via DIR_MAPPING (unchanged if unknown)."""
        return self.DIR_MAPPING.get(dir_name, dir_name)

    def _translate_filename(self, filename: str) -> str:
        """Translate a file stem into a readable Chinese display name.

        First chart-type fragments are replaced (longest fragment first), then
        parameter codes are replaced only where they stand as whole words, so
        e.g. ``DO`` inside ``DOC`` is left alone. Leading/trailing
        underscores are stripped from the result.
        """
        import re

        name = filename
        for eng, chn in self._TYPE_RULES:
            if eng in name:
                name = name.replace(eng, chn)
        for eng, chn in self._PARAM_MAPPING.items():
            pattern = r"(?<![A-Za-z0-9])" + re.escape(eng) + r"(?![A-Za-z0-9])"
            name = re.sub(pattern, chn, name)
        return name.strip('_')

    def _get_or_create_node(self, key: str, parent, text: str, data: dict):
        """Return the node cached under *key*, creating it under *parent* if needed.

        ``parent=None`` creates a top-level node.
        """
        node = self._dir_node_map.get(key)
        if node is None:
            node = QTreeWidgetItem(self if parent is None else parent)
            node.setText(0, text)
            node.setData(0, Qt.UserRole, data)
            node.setExpanded(True)
            self._dir_node_map[key] = node
        return node

    def add_image_by_dir(self, file_path: Path, work_path: Path):
        """Attach an image node below the node for its folder.

        Args:
            file_path: full path of the image file.
            work_path: work-directory root that folder nodes are relative to.

        Returns:
            The created image QTreeWidgetItem.
        """
        try:
            rel_path = file_path.relative_to(work_path)
        except ValueError:
            rel_path = Path(file_path.name)
        dir_parts = rel_path.parts[:-1]

        if not dir_parts:
            # File sits at the work root (or outside it).
            parent_item = self._get_or_create_node(
                "__root__", None, "📁 根目录",
                {"type": "root_dir", "path": str(work_path), "display_name": "根目录"},
            )
        else:
            top = dir_parts[0]
            top_display = self._translate_dir_name(top)
            parent_item = self._get_or_create_node(
                f"__root__{top}", None, f"📁 {top_display}",
                {"type": "root_dir", "path": str(work_path / top), "display_name": top_display},
            )
            if len(dir_parts) > 1:
                # Deeper folders become one node under the root, labelled by
                # their last path component and keyed by the full relative path.
                sub_key = str(Path(*dir_parts))
                sub_display = self._translate_dir_name(dir_parts[-1])
                parent_item = self._get_or_create_node(
                    sub_key, parent_item, f" 📂 {sub_display}",
                    {"type": "sub_dir", "path": str(work_path / sub_key), "display_name": sub_display},
                )

        display_name = self._translate_filename(file_path.stem) + file_path.suffix
        icon = "🖼️"  # default
        if "散点" in display_name:
            icon = "📊"
        elif "光谱" in display_name or "曲线" in display_name:
            icon = "📈"
        elif "箱线" in display_name or "直方" in display_name:
            icon = "📉"
        elif "分布" in display_name or "地图" in display_name or "轨迹" in display_name:
            icon = "🗺️"
        image_item = QTreeWidgetItem(parent_item)
        image_item.setText(0, f" {icon} {display_name}")
        image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path), "display_name": display_name})
        image_item.setToolTip(0, str(file_path))
        return image_item

    def _collect_image_files(self, work_path: Path) -> List[Path]:
        """Return sorted, de-duplicated image files under the known sub-folders.

        Falls back to the whole work directory when none of the sub-folders
        exist. Extensions are matched case-insensitively in a single pass per
        root; hidden files and names containing "thumb" are skipped.
        """
        roots = [work_path / name for name in self._SCAN_SUBDIRS]
        roots = [p for p in roots if p.is_dir()] or [work_path]
        seen: set = set()
        found: List[Path] = []
        for root in roots:
            for path in root.rglob("*"):
                if path.suffix.lower() not in self._IMAGE_SUFFIXES or not path.is_file():
                    continue
                if path.name.startswith('.') or 'thumb' in path.name.lower():
                    continue
                key = os.path.normcase(os.path.normpath(str(path.resolve())))
                if key in seen:
                    continue
                seen.add(key)
                found.append(path)
        return sorted(found)

    def _update_dir_counts(self) -> None:
        """Append the number of images to each sub-folder node's label."""
        for item in self._dir_node_map.values():
            data = item.data(0, Qt.UserRole) or {}
            if data.get("type") != "sub_dir":
                continue
            count = item.childCount()
            if count > 0:
                item.setText(0, f" 📂 {data.get('display_name', '')} ({count})")

    def scan_directory(self, work_dir: str):
        """Rebuild the tree from the images found under *work_dir*."""
        if not work_dir:
            print("可视化面板:工作目录为空,跳过扫描")
            return
        try:
            self._work_path = Path(work_dir)
            # Block signals while clearing so selection slots never see
            # deleted items; restore the previous blocking state afterwards.
            was_blocked = self.blockSignals(True)
            try:
                self.clear_all_images()
            finally:
                self.blockSignals(was_blocked)
            if not self._work_path.exists():
                return
        except Exception as e:
            print(f"可视化面板初始化扫描出错: {e}")
            traceback.print_exc()
            return
        try:
            for img_file in self._collect_image_files(self._work_path):
                self.add_image_by_dir(img_file, self._work_path)
            self._update_dir_counts()
        except Exception as e:
            print(f"可视化面板图片挂载出错: {e}")
            traceback.print_exc()

    def get_selected_image_path(self) -> Optional[str]:
        """Return the file path of the selected image node, or None."""
        selected_item = self.currentItem()
        if not selected_item:
            return None
        data = selected_item.data(0, Qt.UserRole)
        if data and data.get("type") == "image":
            return data.get("path")
        return None
class ImageViewerWidget(QWidget):
    """Scrollable image viewer with wheel/button zoom, fit-to-window and save."""

    # Zoom limits shared by the wheel handler and the zoom buttons.
    MIN_SCALE = 0.1
    MAX_SCALE = 5.0

    def __init__(self, parent=None):
        super().__init__(parent)
        self.current_image_path = None
        self.scale_factor = 1.0
        # Null pixmap until an image has been loaded successfully.
        self.original_pixmap = QPixmap()
        self._fit_scale = None
        # Parented to the widget so the debounce timer dies with it.
        self._update_timer = QTimer(self)
        self._update_timer.setSingleShot(True)
        self._update_timer.timeout.connect(self._do_update_display)
        self._pending_scale = None
        self.setup_ui()

    def setup_ui(self):
        """Build the toolbar, the scrollable image area and the status line."""
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        toolbar = QHBoxLayout()
        self.refresh_btn = QPushButton("🔄 刷新目录")
        self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件")
        toolbar.addWidget(self.refresh_btn)
        separator = QFrame()
        separator.setFrameShape(QFrame.VLine)
        separator.setFrameShadow(QFrame.Sunken)
        toolbar.addWidget(separator)
        self.zoom_in_btn = QPushButton("🔍+")
        self.zoom_in_btn.setToolTip("放大")
        self.zoom_in_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_in_btn)
        self.zoom_out_btn = QPushButton("🔍-")
        self.zoom_out_btn.setToolTip("缩小")
        self.zoom_out_btn.setMaximumWidth(50)
        toolbar.addWidget(self.zoom_out_btn)
        self.fit_btn = QPushButton("⬜ 适应窗口")
        self.fit_btn.setToolTip("适应窗口大小")
        toolbar.addWidget(self.fit_btn)
        self.original_btn = QPushButton("1:1 原始大小")
        self.original_btn.setToolTip("原始大小")
        toolbar.addWidget(self.original_btn)
        toolbar.addStretch()
        self.save_btn = QPushButton("💾 保存")
        self.save_btn.setToolTip("保存当前图像")
        toolbar.addWidget(self.save_btn)
        layout.addLayout(toolbar)
        self.scroll_area = QScrollArea()
        self.scroll_area.setWidgetResizable(True)
        self.scroll_area.setStyleSheet("background-color: white;")
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("background-color: white;")
        self.scroll_area.setWidget(self.image_label)
        layout.addWidget(self.scroll_area, 1)
        status_layout = QHBoxLayout()
        self.status_label = QLabel("就绪")
        self.status_label.setStyleSheet("color: #666; font-size: 11px;")
        status_layout.addWidget(self.status_label)
        status_layout.addStretch()
        layout.addLayout(status_layout)
        self.setLayout(layout)
        self.zoom_in_btn.clicked.connect(self.zoom_in)
        self.zoom_out_btn.clicked.connect(self.zoom_out)
        self.fit_btn.clicked.connect(self.fit_to_window)
        self.original_btn.clicked.connect(self.original_size)
        self.save_btn.clicked.connect(self.save_image)

    def _reset_image(self) -> None:
        """Forget the current image so zoom/save cannot resurrect a stale one."""
        self._update_timer.stop()
        self._pending_scale = None
        self.original_pixmap = QPixmap()
        self.current_image_path = None

    def load_image(self, image_path: str):
        """Load *image_path* and show it fitted to the window."""
        if not image_path or not Path(image_path).exists():
            self._reset_image()
            self.image_label.setText("图像不存在")
            self.status_label.setText("图像加载失败")
            return
        pixmap = QPixmap(image_path)
        if pixmap.isNull():
            self._reset_image()
            self.image_label.setText("无法加载图像")
            self.status_label.setText("图像格式不支持")
            return
        self.current_image_path = image_path
        self.scale_factor = 1.0
        self.original_pixmap = pixmap
        self.fit_to_window()
        size_mb = Path(image_path).stat().st_size / (1024 * 1024)
        self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口")

    def update_image_display(self):
        """Schedule a redraw at the current scale (debounced by 50 ms)."""
        self._update_timer.stop()
        self._pending_scale = self.scale_factor
        self._update_timer.start(50)

    def _do_update_display(self):
        """Render the pending scale; fast transform at extreme zoom levels."""
        if self.original_pixmap.isNull() or self._pending_scale is None:
            return
        if self._pending_scale > 2.0 or self._pending_scale < 0.5:
            transform = Qt.FastTransformation
        else:
            transform = Qt.SmoothTransformation
        # Clamp to >= 1 px: a 0-sized request yields a null pixmap.
        width = max(1, int(self.original_pixmap.width() * self._pending_scale))
        height = max(1, int(self.original_pixmap.height() * self._pending_scale))
        scaled_pixmap = self.original_pixmap.scaled(width, height, Qt.KeepAspectRatio, transform)
        self.image_label.setPixmap(scaled_pixmap)
        self._pending_scale = None

    def _zoom(self, zoom_in: bool, step: float) -> None:
        """Multiply (zoom in) or divide (zoom out) the scale by *step*, within limits."""
        if zoom_in:
            if self.scale_factor < self.MAX_SCALE:
                self.scale_factor = min(self.scale_factor * step, self.MAX_SCALE)
                self.update_image_display()
        elif self.scale_factor > self.MIN_SCALE:
            self.scale_factor = max(self.scale_factor / step, self.MIN_SCALE)
            self.update_image_display()

    def wheelEvent(self, event):
        """Zoom with the vertical wheel; purely horizontal scrolling is ignored."""
        delta = event.angleDelta().y()
        if delta == 0:
            event.ignore()
            return
        self._zoom(delta > 0, 1.1)
        event.accept()

    def zoom_in(self):
        """Zoom in by 25 %."""
        self._zoom(True, 1.25)

    def zoom_out(self):
        """Zoom out by 25 %."""
        self._zoom(False, 1.25)

    def fit_to_window(self):
        """Scale the image to fit entirely inside the visible viewport."""
        if self.original_pixmap.isNull():
            return
        view_size = self.scroll_area.viewport().size()
        img_size = self.original_pixmap.size()
        scale_w = view_size.width() / img_size.width()
        scale_h = view_size.height() / img_size.height()
        # A hidden/unsized viewport reports 0; fall back to 1:1 instead of 0 %.
        self._fit_scale = min(scale_w, scale_h) or 1.0
        self.scale_factor = self._fit_scale
        self.update_image_display()
        self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}")

    def original_size(self):
        """Show the image at 100 %."""
        self.scale_factor = 1.0
        self._fit_scale = None
        self.update_image_display()
        self.status_label.setText("原始大小 | 缩放: 100%")

    def save_image(self):
        """Copy the current image file to a user-chosen location."""
        if not self.current_image_path:
            return
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图像", Path(self.current_image_path).name,
            "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
        )
        if file_path:
            try:
                import shutil
                shutil.copy(self.current_image_path, file_path)
            except Exception as e:
                QMessageBox.critical(self, "错误", f"保存失败: {e}")
class ChartBrowserDialog(QDialog):
    """Dialog for paging through chart image files, newest first."""

    def __init__(self, chart_files, parent=None):
        super().__init__(parent)
        # Accept str or Path entries; newest first. Unreadable files sort last
        # instead of crashing the constructor.
        self.chart_files = sorted(
            (Path(f) for f in chart_files), key=self._mtime_or_zero, reverse=True
        )
        self.current_index = 0
        self.setWindowTitle("图表浏览器")
        self.resize(1200, 800)
        self.init_ui()
        self.show_chart(0)
        self._update_nav_buttons()

    @staticmethod
    def _mtime_or_zero(path: Path) -> float:
        """Modification time of *path*, or 0.0 if it cannot be stat'ed."""
        try:
            return path.stat().st_mtime
        except OSError:
            return 0.0

    def init_ui(self):
        """Build the file list, canvas/toolbar and the navigation buttons."""
        layout = QVBoxLayout()
        list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)")
        list_layout = QHBoxLayout()
        self.chart_list = QListWidget()
        self.chart_list.setMaximumHeight(150)
        for chart_file in self.chart_files:
            self.chart_list.addItem(chart_file.name)
        self.chart_list.currentRowChanged.connect(self.show_chart)
        list_layout.addWidget(self.chart_list)
        list_group.setLayout(list_layout)
        layout.addWidget(list_group)
        self.figure = Figure(figsize=(12, 8))
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
        self.toolbar = NavigationToolbar(self.canvas, self)
        layout.addWidget(self.toolbar)
        layout.addWidget(self.canvas, 1)
        btn_layout = QHBoxLayout()
        self.prev_btn = QPushButton("◀ 上一个")
        self.prev_btn.clicked.connect(self.prev_chart)
        btn_layout.addWidget(self.prev_btn)
        self.next_btn = QPushButton("下一个 >")
        self.next_btn.clicked.connect(self.next_chart)
        btn_layout.addWidget(self.next_btn)
        btn_layout.addStretch()
        self.save_btn = QPushButton("💾 保存当前图表")
        self.save_btn.clicked.connect(self.save_current_chart)
        btn_layout.addWidget(self.save_btn)
        self.close_btn = QPushButton("关闭")
        self.close_btn.clicked.connect(self.close)
        btn_layout.addWidget(self.close_btn)
        layout.addLayout(btn_layout)
        self.setLayout(layout)

    def _update_nav_buttons(self) -> None:
        """Enable prev/next/save only where they can act."""
        count = len(self.chart_files)
        self.prev_btn.setEnabled(self.current_index > 0)
        self.next_btn.setEnabled(self.current_index < count - 1)
        self.save_btn.setEnabled(count > 0)

    def show_chart(self, index):
        """Render the chart at *index*; out-of-range indices are ignored."""
        if not 0 <= index < len(self.chart_files):
            return
        self.current_index = index
        if self.chart_list.currentRow() != index:
            # Without blocking, setCurrentRow emits currentRowChanged, which
            # re-enters show_chart and renders the same image twice.
            was_blocked = self.chart_list.blockSignals(True)
            try:
                self.chart_list.setCurrentRow(index)
            finally:
                self.chart_list.blockSignals(was_blocked)
        chart_file = self.chart_files[index]
        self.figure.clear()
        ax = self.figure.add_subplot(111)
        try:
            import matplotlib.image as mpimg
            img = mpimg.imread(str(chart_file))
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(chart_file.name, fontsize=12, pad=10)
            self.figure.tight_layout()
            self.canvas.draw()
        except Exception as e:
            ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}',
                    ha='center', va='center', transform=ax.transAxes)
            self.canvas.draw()
        self._update_nav_buttons()

    def prev_chart(self):
        """Show the previous chart, if any."""
        if self.current_index > 0:
            self.show_chart(self.current_index - 1)

    def next_chart(self):
        """Show the next chart, if any."""
        if self.current_index < len(self.chart_files) - 1:
            self.show_chart(self.current_index + 1)

    def save_current_chart(self):
        """Copy the current chart file to a user-chosen location."""
        if not 0 <= self.current_index < len(self.chart_files):
            return
        current_file = self.chart_files[self.current_index]
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存图表", current_file.name,
            "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)"
        )
        if not file_path:
            return
        try:
            import shutil
            shutil.copy(str(current_file), file_path)
            QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}")
        except Exception as e:
            QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class VisualizationPanel(QWidget):
"""可视化分析面板 - 重构版:左侧目录树 + 右侧图像查看器"""
def __init__(self, parent=None):
    """Create the panel with no work directory, viewer or worker yet, then build the UI."""
    super().__init__(parent)
    # work_dir: selected work directory; chart_viewer: chart dialog handle;
    # _viz_thread: most recently started VisualizationWorkerThread.
    self.work_dir = self.chart_viewer = self._viz_thread = None
    self.init_ui()
def _viz_set_busy(self, busy: bool):
    """Disable (busy=True) or re-enable the buttons that start visualization work.

    Buttons that have not been created yet are skipped.
    """
    enabled = not busy
    for attr_name in ("gen_all_btn", "scan_btn"):
        button = getattr(self, attr_name, None)
        if button is None:
            continue
        button.setEnabled(enabled)
def _start_visualization_thread(self, task: str, extra: Optional[dict] = None) -> bool:
    """Start a VisualizationWorkerThread for *task*.

    Returns True when the worker was started. Returns False, after telling
    the user why, when no work directory is set, it does not exist, or a
    previous worker is still running. Results arrive via queued connections
    to the panel's handlers; the busy state is cleared when the thread ends.
    """
    if not self.work_dir:
        QMessageBox.warning(self, "警告", "请先选择工作目录!")
        return False
    work_path = Path(self.work_dir)
    if not work_path.exists():
        QMessageBox.warning(self, "警告", "工作目录不存在!")
        return False
    still_running = self._viz_thread is not None and self._viz_thread.isRunning()
    if still_running:
        QMessageBox.information(self, "提示", "可视化任务正在运行,请稍候。")
        return False
    worker = VisualizationWorkerThread(task, str(work_path), extra or {})
    self._viz_thread = worker
    worker.finished_ok.connect(self._on_visualization_worker_ok, Qt.QueuedConnection)
    worker.failed.connect(self._on_visualization_worker_fail, Qt.QueuedConnection)
    worker.finished.connect(lambda: self._viz_set_busy(False), Qt.QueuedConnection)
    self._viz_set_busy(True)
    worker.start()
    return True
def _spectrum_meta_param_columns(self, df: pd.DataFrame) -> List[str]:
"""光谱图可选的水质参数列(光谱波段列之前、且为数值型)。"""
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
idx = int(df.columns.get_loc(wl)) + 1
else:
idx = int(wl)
if idx <= 0 or idx >= len(df.columns):
numeric = df.select_dtypes(include=[np.number]).columns.tolist()
return [
c
for c in numeric
if not any(x in str(c).lower() for x in ("utm", "lat", "lon", "x", "y"))
]
meta = list(df.columns[:idx])
return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])]
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
    """Return the numeric water-quality columns to summarise in statistics charts.

    When the start of the spectral bands (as located by
    ``_viz_infer_wavelength_start_column``) lies strictly inside the table,
    only numeric columns before the bands are kept. Otherwise every numeric
    column is a candidate. Coordinate-like columns (names containing
    "utm"/"lat"/"lon", or named exactly "x"/"y") are always excluded.

    Args:
        df: The training-spectra table.

    Returns:
        Column labels in table order.
    """
    def is_coordinate(col) -> bool:
        name = str(col).strip().lower()
        # "x"/"y" must match the whole name: a substring test would also drop
        # real parameters such as "Turbidity" or "Oxygen".
        return name in ("x", "y") or any(k in name for k in ("utm", "lat", "lon"))

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    wl = _viz_infer_wavelength_start_column(df)
    if isinstance(wl, str):
        # Named anchor column (e.g. "UTM_Y"): the bands start right after it.
        idx = int(df.columns.get_loc(wl)) + 1
    else:
        idx = int(wl)
    if 0 < idx < len(df.columns):
        meta_set = set(df.columns[:idx])
        return [c for c in numeric_cols if c in meta_set and not is_coordinate(c)]
    return [c for c in numeric_cols if not is_coordinate(c)]
def _on_visualization_worker_ok(self, payload):
    """Handle a successful worker result on the GUI thread.

    ``payload`` is expected to be a dict whose ``"task"`` key names the job
    that finished. The other keys depend on the task (paths, counts, error
    lists). A summary dialog is shown and, when the task produced a chart
    file that exists on disk, the chart is opened in the viewer. The work
    directory is rescanned at the end in every case, including for a
    non-dict payload.
    """
    def is_existing_file(p) -> bool:
        # Worker paths may be str, Path or None; only open real files so the
        # viewer is never handed a missing path.
        return bool(p) and Path(str(p)).is_file()

    if not isinstance(payload, dict):
        self.scan_work_directory()
        return
    t = payload.get("task")
    if t == "mask_glint":
        cnt = int(payload.get("count") or 0)
        if cnt > 0:
            QMessageBox.information(
                self,
                "成功",
                f"掩膜和耀斑缩略图生成完成,共 {cnt} 个预览图。\n"
                f"保存位置: 14_visualization/glint_deglint_previews/",
            )
        else:
            QMessageBox.warning(
                self,
                "警告",
                "未找到可处理的影像文件2_glint/3_deglint 等)。",
            )
    elif t == "sampling_map":
        map_path = payload.get("map_path")
        QMessageBox.information(
            self,
            "成功",
            "采样点地图生成完成。\n"
            f"输出: {Path(map_path).name if map_path else ''}\n"
            "路径: 14_visualization/sampling_maps/",
        )
        if is_existing_file(map_path):
            self.show_chart_viewer(map_path, "采样点分布图")
    elif t == "spectrum":
        multi = payload.get("output_paths")
        if isinstance(multi, list) and multi:
            # One chart per water-quality parameter.
            ok_paths = [p for p in multi if is_existing_file(p)]
            errs = payload.get("errors") or []
            msg = (
                f"已为 {len(ok_paths)} 个水质参数生成光谱对比图。\n"
                f"保存目录: 工作目录/14_visualization/"
            )
            if errs:
                msg += f"\n\n以下列未生成或出错 ({len(errs)} 项,详见日志):\n"
                # Show at most 8 errors to keep the dialog readable.
                msg += "\n".join(str(e) for e in errs[:8])
                if len(errs) > 8:
                    msg += "\n..."
            QMessageBox.information(self, "成功", msg)
            if ok_paths:
                self.show_chart_viewer(ok_paths[0], "光谱曲线对比(首张)")
        else:
            outp = payload.get("output_path")
            param = payload.get("param_col", "")
            QMessageBox.information(self, "成功", f"光谱图已生成:\n{outp}")
            if is_existing_file(outp):
                self.show_chart_viewer(outp, f"{param} - 光谱曲线对比")
    elif t == "statistics":
        outp = payload.get("output_paths") or {}
        QMessageBox.information(
            self, "成功", f"统计图表已生成,共 {len(outp)} 项。"
        )
        if isinstance(outp, dict) and is_existing_file(outp.get("boxplot")):
            self.show_chart_viewer(outp["boxplot"], "水质参数箱线图")
    elif t == "scatter":
        paths = payload.get("scatter_paths") or {}
        ok_paths = [p for p in paths.values() if is_existing_file(p)]
        if ok_paths:
            QMessageBox.information(
                self,
                "成功",
                f"已生成 {len(ok_paths)} 个模型评估散点图。\n"
                f"保存位置: 14_visualization/scatter_plots/",
            )
            self.show_chart_viewer(ok_paths[0], "模型评估散点图")
        else:
            QMessageBox.warning(
                self,
                "提示",
                "未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
                "且训练 CSV 与建模时一致。",
            )
    elif t == "generate_all_selected":
        parts = payload.get("parts") or []
        summary = (
            ("批量可视化已执行:\n" + "\n".join(parts)) if parts else "(无选中项或已跳过)"
        )
        QMessageBox.information(self, "完成", summary)
    self.scan_work_directory()
def _on_visualization_worker_fail(self, err: str):
    """Report a worker failure in an error dialog.

    Args:
        err: Error text from the worker; only the first 1200 characters
            are shown so the dialog stays a reasonable size.
    """
    shown = err[:1200]
    QMessageBox.critical(self, "错误", f"可视化任务失败:\n{shown}")
def init_ui(self):
    """Build the two-column layout: controls on the left, image viewer on the right."""
    root = QHBoxLayout()
    root.setSpacing(10)
    root.setContentsMargins(10, 10, 10, 10)
    root.addWidget(self._ui_build_left_panel(), 0)
    root.addWidget(self._ui_build_right_panel(), 1)
    self.setLayout(root)

def _ui_path_group(self, title, placeholder, on_browse):
    """Return ``(group, line_edit)`` for a read-only path field with a browse button."""
    group = QGroupBox(title)
    row = QHBoxLayout()
    line_edit = QLineEdit()
    line_edit.setPlaceholderText(placeholder)
    line_edit.setReadOnly(True)
    browse = QPushButton("浏览")
    browse.clicked.connect(on_browse)
    row.addWidget(line_edit, 1)
    row.addWidget(browse)
    group.setLayout(row)
    return group, line_edit

def _ui_build_left_panel(self):
    """Create the left column: directory pickers, image tree and chart options."""
    panel = QWidget()
    column = QVBoxLayout()
    column.setContentsMargins(0, 0, 0, 0)

    work_group, self.work_dir_edit = self._ui_path_group(
        "工作目录", "选择工作目录...", self.browse_work_dir)
    column.addWidget(work_group)

    # Image directory, normally auto-filled with the prediction output folder.
    img_group, self.img_dir_edit = self._ui_path_group(
        "图像目录", "预测结果目录(自动填充)…", self.browse_img_dir)
    column.addWidget(img_group)

    tree_group = QGroupBox("图像目录")
    tree_box = QVBoxLayout()
    self.image_tree = ImageCategoryTree()
    self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
    tree_box.addWidget(self.image_tree)
    tree_group.setLayout(tree_box)
    column.addWidget(tree_group, 1)

    column.addWidget(self._ui_build_config_group())

    panel.setLayout(column)
    panel.setMaximumWidth(350)
    return panel

def _ui_build_config_group(self):
    """Create the chart-type checkboxes (all checked) and the action buttons."""
    group = QGroupBox("可视化配置")
    box = QVBoxLayout()

    def checked_box(label):
        checkbox = QCheckBox(label)
        checkbox.setChecked(True)
        box.addWidget(checkbox)
        return checkbox

    self.gen_scatter = checked_box("模型评估散点图")
    self.gen_spectrum = checked_box("光谱曲线图")
    self.gen_boxplots = checked_box("统计图表")
    self.gen_mask_glint = checked_box("掩膜和耀斑缩略图")
    self.gen_sampling_map = checked_box("采样点地图")

    separator = QFrame()
    separator.setFrameShape(QFrame.HLine)
    separator.setStyleSheet("color: #ddd;")
    box.addSpacing(10)
    box.addWidget(separator)
    box.addSpacing(10)

    self.gen_all_btn = QPushButton("🚀 生成全部")
    self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
    self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
    self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
    box.addWidget(self.gen_all_btn)

    self.scan_btn = QPushButton("📁 扫描目录")
    self.scan_btn.setToolTip("扫描工作目录中的图像文件")
    self.scan_btn.clicked.connect(self.scan_work_directory)
    box.addWidget(self.scan_btn)

    group.setLayout(box)
    return group

def _ui_build_right_panel(self):
    """Create the right column holding the image viewer."""
    panel = QWidget()
    column = QVBoxLayout()
    column.setContentsMargins(0, 0, 0, 0)
    self.image_viewer = ImageViewerWidget()
    self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
    column.addWidget(self.image_viewer, 1)
    panel.setLayout(column)
    return panel
def set_work_dir(self, work_dir):
    """Set the work directory and schedule a rescan.

    Args:
        work_dir: Directory path (str or Path). A falsy value clears the
            path field instead of displaying the literal text "None", and
            no rescan is scheduled.
    """
    self.work_dir = work_dir
    self.work_dir_edit.setText(str(work_dir) if work_dir else "")
    if work_dir:
        # Deferred through the event loop (100 ms) so the scan runs after
        # the caller's current handler has returned.
        QTimer.singleShot(100, self.scan_work_directory)
def _get_default_work_dir(self):
    """Return the best default directory for file dialogs, as a string.

    Prefers this panel's own ``work_dir``. Otherwise uses the top-level
    window's ``work_dir`` attribute if it has a truthy one, and returns an
    empty string when neither is available.
    """
    own = getattr(self, 'work_dir', None)
    if own:
        return str(own)
    host = self.window()
    host_dir = getattr(host, 'work_dir', None) if host else None
    return str(host_dir) if host_dir else ""
def browse_work_dir(self):
    """Ask the user for a work directory, store it and rescan it."""
    chosen = QFileDialog.getExistingDirectory(
        self, "选择工作目录", self._get_default_work_dir())
    if not chosen:
        return
    self.work_dir = chosen
    self.work_dir_edit.setText(chosen)
    self.scan_work_directory()
def browse_img_dir(self):
    """Let the user pick an image directory, list it in the tree and show its first image."""
    chosen = QFileDialog.getExistingDirectory(
        self, "选择图像目录", self._get_default_work_dir())
    if not chosen:
        return
    self.img_dir_edit.setText(chosen)
    self.image_tree.scan_directory(chosen)
    self._load_first_image_from_tree()
def update_from_config(self, work_dir=None, pipeline=None):
    """Infer the image directory from the work directory and load its contents.

    Candidate directories, in priority order:
      1. {work_dir}/11_12_13_predictions/Machine_Learning_Prediction (ML prediction)
      2. {work_dir}/11_12_13_predictions/Non_Empirical_Prediction (regression prediction)
      3. {work_dir}/11_12_13_predictions/Regression_Model_Prediction
         (regression output folder created by ``_setup_prediction_output_dirs``)
      4. {work_dir}/11_12_13_predictions/Custom_Regression_Prediction (custom regression)
      5. {work_dir}/14_visualization (visualization output)
      6. {work_dir} (work directory root)

    The first existing candidate that is not empty wins. Empty folders are
    skipped because ``_setup_prediction_output_dirs`` creates the prediction
    folders (empty) on every scan. Selecting such a folder would hide results
    stored at a lower-priority location. If every existing candidate is
    empty, the first existing one is used.

    Args:
        work_dir: New work directory. If falsy, the cached ``self.work_dir``
            is used, and the method does nothing when that is unset too.
        pipeline: Accepted for interface compatibility; not used here.

    Errors are printed with a traceback rather than raised.
    """
    def has_entries(folder: Path) -> bool:
        # Unreadable folders are treated as empty rather than aborting.
        try:
            return any(folder.iterdir())
        except OSError:
            return False

    try:
        if work_dir:
            self.work_dir = work_dir
            self.work_dir_edit.setText(str(work_dir))
        elif not self.work_dir:
            return
        work_path = Path(self.work_dir)
        pred_dir = work_path / "11_12_13_predictions"
        candidates = [
            pred_dir / "Machine_Learning_Prediction",
            pred_dir / "Non_Empirical_Prediction",
            pred_dir / "Regression_Model_Prediction",
            pred_dir / "Custom_Regression_Prediction",
            work_path / "14_visualization",
            work_path,
        ]
        existing = [c for c in candidates if c.is_dir()]
        detected_dir = next((c for c in existing if has_entries(c)), None)
        if detected_dir is None and existing:
            detected_dir = existing[0]
        if detected_dir is not None:
            detected_str = str(detected_dir)
            self.img_dir_edit.setText(detected_str)
            self.image_tree.scan_directory(detected_str)
        else:
            # Not even the work directory exists; let the tree handle it.
            self.image_tree.scan_directory(self.work_dir)
        # Show the first image of whatever was loaded.
        self._load_first_image_from_tree()
    except Exception as e:
        print(f"可视化面板 update_from_config 出错: {e}")
        traceback.print_exc()
def _load_first_image_from_tree(self):
    """Select and display the first image node of ``self.image_tree``.

    Nodes are visited depth-first in display order: a node before its
    children, and its children before its next sibling. A node counts as an
    image when its ``Qt.UserRole`` data is a dict with ``"type" == "image"``.
    Errors are printed with a traceback, not raised.
    """
    try:
        tree = getattr(self, 'image_tree', None)
        if not tree:
            return
        # Explicit stack; items are pushed in reverse so pop() yields display order.
        pending = [tree.topLevelItem(i) for i in reversed(range(tree.topLevelItemCount()))]
        while pending:
            node = pending.pop()
            node_data = node.data(0, Qt.UserRole)
            if isinstance(node_data, dict) and node_data.get("type") == "image":
                tree.setCurrentItem(node)
                # Reuse the click handler so the viewer renders the image.
                self.on_tree_item_clicked(node, 0)
                return
            pending.extend(node.child(j) for j in reversed(range(node.childCount())))
    except Exception as e:
        print(f"自动加载首张图片失败: {e}")
        traceback.print_exc()
def scan_work_directory(self):
    """Rescan the work directory and preview an image from 14_visualization.

    Refreshes the image tree, (re)creates the prediction output folders via
    ``_setup_prediction_output_dirs``, then loads the first PNG found
    recursively under ``14_visualization``, or the first JPG if there is
    no PNG (glob order). Does nothing without an existing work directory.
    """
    if not self.work_dir:
        return
    root = Path(self.work_dir)
    if not root.exists():
        return
    print(f"扫描工作目录: {root}")
    self.image_tree.scan_directory(str(root))
    self._setup_prediction_output_dirs(root)
    viz_dir = root / "14_visualization"
    if not viz_dir.exists():
        return
    candidates = [*viz_dir.glob("**/*.png"), *viz_dir.glob("**/*.jpg")]
    if candidates:
        self.image_viewer.load_image(str(candidates[0]))
def _setup_prediction_output_dirs(self, work_path: Path):
    """Create the three prediction output folders and hand them to step panels.

    The folders live under ``work_path/11_12_13_predictions``. Each folder
    is forwarded to a step panel's path widget only if that panel attribute
    and its widget exist on this object (``hasattr`` guards). Failures are
    printed, not raised.

    Args:
        work_path: The work directory.
    """
    try:
        base = work_path / "11_12_13_predictions"
        ml_dir = base / "Machine_Learning_Prediction"
        reg_dir = base / "Regression_Model_Prediction"
        custom_dir = base / "Custom_Regression_Prediction"
        for folder in (ml_dir, reg_dir, custom_dir):
            folder.mkdir(parents=True, exist_ok=True)
        bindings = (
            ('step8_panel', 'output_file', ml_dir),
            ('step8_5_panel', 'output_file', reg_dir),
            ('step8_75_panel', 'output_dir_widget', custom_dir),
        )
        for panel_attr, widget_attr, folder in bindings:
            if not hasattr(self, panel_attr):
                continue
            panel = getattr(self, panel_attr)
            if hasattr(panel, widget_attr):
                getattr(panel, widget_attr).set_path(str(folder))
        print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
    except Exception as e:
        print(f"设置预测输出目录失败: {e}")
def on_tree_item_clicked(self, item, column):
    """Show the clicked tree item in the image viewer if it is an image file.

    Only items whose ``Qt.UserRole`` data is a dict with ``"type" == "image"``
    and a ``"path"`` pointing to an existing file are loaded. Other node
    data, such as category nodes with non-dict data, is ignored instead of
    raising ``AttributeError`` on ``.get``.

    Args:
        item: The clicked ``QTreeWidgetItem``.
        column: Clicked column (unused; data is always read from column 0).
    """
    data = item.data(0, Qt.UserRole)
    if not isinstance(data, dict) or data.get("type") != "image":
        return
    image_path = data.get("path")
    # is_file(), not exists(): a directory path must not reach the image loader.
    if image_path and Path(image_path).is_file():
        self.image_viewer.load_image(image_path)
def generate_all_visualizations(self):
    """Run every checked visualization in a single background job.

    Validates the work directory, requires at least one checked option and
    asks for confirmation. It then starts the "generate_all_selected" worker
    task with one boolean flag per option.
    """
    if not self.work_dir:
        QMessageBox.warning(self, "警告", "请先选择工作目录!")
        return
    if not Path(self.work_dir).exists():
        QMessageBox.warning(self, "警告", "工作目录不存在!")
        return
    toggles = (
        ("gen_scatter", self.gen_scatter),
        ("gen_spectrum", self.gen_spectrum),
        ("gen_boxplots", self.gen_boxplots),
        ("gen_mask_glint", self.gen_mask_glint),
        ("gen_sampling_map", self.gen_sampling_map),
    )
    if not any(box.isChecked() for _, box in toggles):
        QMessageBox.information(self, "提示", "请至少勾选一项可视化配置选项以生成图表。")
        return
    answer = QMessageBox.question(
        self, "确认生成",
        "将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
        QMessageBox.Yes | QMessageBox.No
    )
    if answer != QMessageBox.Yes:
        return
    flags = {key: box.isChecked() for key, box in toggles}
    self._start_visualization_thread("generate_all_selected", flags)
def generate_chart(self, chart_type):
    """Generate one chart type in the background worker.

    'scatter', 'spectrum' and 'statistics' read the step-5 table
    ``5_training_spectra/training_spectra.csv`` and warn if it is missing.
    'sampling_map' delegates to ``generate_sampling_point_map``. Other
    values do nothing. Unexpected errors are reported in a dialog.

    Args:
        chart_type: 'scatter', 'spectrum', 'statistics' or 'sampling_map'.
    """
    if not self.work_dir:
        QMessageBox.warning(self, "警告", "请先选择工作目录!")
        return
    work_path = Path(self.work_dir)
    if not work_path.exists():
        QMessageBox.warning(self, "警告", "工作目录不存在!")
        return
    # Second line of the "training CSV missing" warning, per chart type.
    missing_hint = {
        'scatter': "请先执行步骤5光谱特征提取生成该文件。",
        'spectrum': "光谱分析固定使用该文件请先执行步骤5光谱特征提取",
        'statistics': "统计分析固定使用该文件请先执行步骤5光谱特征提取",
    }
    try:
        csv_path = _viz_training_spectra_csv_path(work_path)
        if chart_type in ('scatter', 'spectrum', 'statistics') and not csv_path.is_file():
            QMessageBox.warning(
                self, "警告",
                "未找到 5_training_spectra\\training_spectra.csv。\n"
                + missing_hint[chart_type],
            )
            return
        if chart_type == 'scatter':
            models_dir = work_path / "7_Supervised_Model_Training"
            has_param_dirs = models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir())
            if not has_param_dirs:
                # No per-parameter model folders: let the user locate them.
                picked = QFileDialog.getExistingDirectory(
                    self, "选择模型根目录(内含各水质参数子文件夹)", str(work_path))
                if not picked:
                    return
                models_dir = Path(picked)
            self._start_visualization_thread(
                "scatter",
                {"training_csv_path": str(csv_path), "models_dir": str(models_dir)},
            )
        elif chart_type == 'spectrum':
            table = pd.read_csv(csv_path)
            param_cols = self._spectrum_meta_param_columns(table)
            if not param_cols:
                QMessageBox.warning(
                    self, "警告",
                    "当前 CSV 中没有可用的数值型水质参数列,无法按参数分组绘制光谱图。",
                )
                return
            self._start_visualization_thread(
                "spectrum",
                {"csv_path": str(csv_path), "param_cols": param_cols,
                 "wavelength_start_column": _viz_infer_wavelength_start_column(table),
                 "n_groups": 5},
            )
        elif chart_type == 'statistics':
            table = pd.read_csv(csv_path)
            param_cols = self._statistics_param_columns(table)
            if not param_cols:
                QMessageBox.warning(self, "警告", "未找到可用的水质参数列!")
                return
            self._start_visualization_thread(
                "statistics",
                {"csv_path": str(csv_path), "param_cols": param_cols},
            )
        elif chart_type == 'sampling_map':
            self.generate_sampling_point_map()
    except ImportError:
        QMessageBox.critical(
            self, "错误",
            "无法导入可视化模块!\n请确保 visualization_reports.py 文件存在。",
        )
    except Exception as e:
        QMessageBox.critical(
            self, "错误",
            f"生成图表时出错:\n{str(e)}\n\n{traceback.format_exc()}",
        )
def generate_mask_glint_previews(self):
    """Start the background job that renders mask/glint preview thumbnails.

    The launcher's started/not-started result is ignored on purpose; the
    launcher already informs the user when it cannot start.
    """
    self._start_visualization_thread(task="mask_glint")
def generate_sampling_point_map(self):
    """Start the background job that draws the sampling-point map.

    The launcher's started/not-started result is ignored on purpose; the
    launcher already informs the user when it cannot start.
    """
    self._start_visualization_thread(task="sampling_map")
def view_chart(self, chart_type):
    """Let the user pick and open an existing chart of ``chart_type``.

    Searches ``<work_dir>/14_visualization`` and the relevant subfolders.
    Matches are de-duplicated (the fallback patterns can overlap, e.g.
    "x_glint_preview.png") and sorted for a stable listing. When several
    files match, the user picks from a list. Files whose names are unique
    are labelled by name, and files sharing a name are labelled by their
    path relative to ``14_visualization``, so each entry opens its own file.

    Args:
        chart_type: 'scatter', 'spectrum', 'statistics', 'distribution',
            'mask_glint' or 'sampling_map'. Other values find nothing.
    """
    from collections import Counter
    from PyQt5.QtWidgets import QInputDialog

    if not self.work_dir:
        QMessageBox.warning(self, "警告", "请先选择工作目录!")
        return
    viz_dir = Path(self.work_dir) / "14_visualization"
    if not viz_dir.exists():
        QMessageBox.warning(self, "警告", f"可视化目录不存在:\n{viz_dir}\n\n请先生成图表。")
        return
    found = []
    if chart_type == 'scatter':
        found = list((viz_dir / "scatter_plots").glob("*scatter*.png"))
    elif chart_type == 'spectrum':
        found = list(viz_dir.glob("*spectrum*.png"))
    elif chart_type == 'statistics':
        found = (list((viz_dir / "boxplots").glob("*boxplot.png"))
                 + list(viz_dir.glob("*histogram.png"))
                 + list(viz_dir.glob("*heatmap.png")))
    elif chart_type == 'distribution':
        found = list(viz_dir.glob("**/*distribution.png"))
    elif chart_type == 'mask_glint':
        glint_dir = viz_dir / "glint_deglint_previews"
        if glint_dir.exists():
            found = list(glint_dir.glob("*preview.png"))
        else:
            found = (list(viz_dir.glob("*preview.png"))
                     + list(viz_dir.glob("*glint*.png"))
                     + list(viz_dir.glob("*mask*.png")))
    elif chart_type == 'sampling_map':
        sampling_dir = viz_dir / "sampling_maps"
        if sampling_dir.exists():
            found = list(sampling_dir.glob("*sampling_map.png"))
        else:
            found = list(viz_dir.glob("*sampling*.png"))
    chart_files = sorted(set(found))
    if not chart_files:
        QMessageBox.warning(self, "警告", f"未找到{chart_type}类型的图表文件!\n\n请先生成图表。")
        return
    if len(chart_files) == 1:
        self.show_chart_viewer(str(chart_files[0]), chart_files[0].name)
        return
    name_counts = Counter(f.name for f in chart_files)
    choices = {}
    for f in chart_files:
        # Every match lies under viz_dir, so relative_to() cannot fail here.
        label = f.name if name_counts[f.name] == 1 else f.relative_to(viz_dir).as_posix()
        choices[label] = f
    label, ok = QInputDialog.getItem(
        self, "选择图表", "请选择要查看的图表:", list(choices), 0, False)
    if ok and label in choices:
        selected = choices[label]
        self.show_chart_viewer(str(selected), selected.name)
def browse_all_charts(self):
    """Open a browser dialog listing every PNG and JPG below the work directory.

    PNG files are listed before JPG files (each group in glob order).
    """
    if not self.work_dir:
        QMessageBox.warning(self, "警告", "请先选择工作目录!")
        return
    root = Path(self.work_dir)
    images = [path for pattern in ("**/*.png", "**/*.jpg") for path in root.glob(pattern)]
    if not images:
        QMessageBox.warning(self, "警告", "未找到图表文件!")
        return
    ChartBrowserDialog(images, self).exec_()
def show_chart_viewer(self, image_path, title="图表查看器"):
    """Display ``image_path`` in a modal ChartViewerDialog.

    Args:
        image_path: Image file to show.
        title: Dialog title.

    Blocks until the dialog is closed (``exec_``).
    """
    dialog = ChartViewerDialog(title=title, parent=self)
    dialog.display_image(image_path)
    dialog.exec_()
def get_config(self):
    """Return the panel settings as a plain dict.

    The five ``generate_*`` flags mirror the checkboxes. The nested
    ``*_config`` dicts are fixed defaults that no widget in this panel
    changes. A fresh dict is built on every call.
    """
    scatter_defaults = {
        'metric': 'test_r2', 'feature_start_column': 13,
        'test_size': 0.2, 'random_state': 42,
    }
    boxplot_defaults = {
        'data_start_column': 4, 'save_individual': True, 'use_seaborn': True,
    }
    glint_defaults = {
        'work_dir': None, 'output_subdir': 'glint_deglint_previews',
        'generate_glint': True, 'generate_deglint': True,
    }
    config = {
        'generate_scatter': self.gen_scatter.isChecked(),
        'generate_boxplots': self.gen_boxplots.isChecked(),
        'generate_spectrum': self.gen_spectrum.isChecked(),
        'generate_glint_previews': self.gen_mask_glint.isChecked(),
        'generate_sampling_maps': self.gen_sampling_map.isChecked(),
    }
    config['scatter_config'] = scatter_defaults
    config['boxplot_config'] = boxplot_defaults
    config['glint_preview_config'] = glint_defaults
    return config
def set_config(self, config):
"""设置配置"""
if not config:
return
if 'generate_scatter' in config:
self.gen_scatter.setChecked(config['generate_scatter'])
if 'generate_boxplots' in config:
self.gen_boxplots.setChecked(config['generate_boxplots'])
if 'generate_spectrum' in config:
self.gen_spectrum.setChecked(config['generate_spectrum'])
if 'generate_glint_previews' in config:
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
if 'generate_sampling_maps' in config:
self.gen_sampling_map.setChecked(config.get('generate_sampling_maps', True))