#!/usr/bin/env python # -*- coding: utf-8 -*- """ VisualizationPanel - 可视化分析面板 左侧目录树 + 右侧图像查看器,支持多种图表生成。 """ import os import sys import traceback from pathlib import Path from typing import Optional, List, Union import numpy as np import pandas as pd # 路径归一化 helper(与 pipeline.get_step_output_dir 互为表里) _HERE = os.path.dirname(os.path.abspath(__file__)) if _HERE not in sys.path: sys.path.insert(0, _HERE) from _step_path_resolver import get_step_output_path, resolve_step_widget, resolve_subdir from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal, QAbstractTableModel from PyQt5.QtGui import QPixmap from PyQt5.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout, QLabel, QCheckBox, QPushButton, QLineEdit, QMessageBox, QFileDialog, QFrame, QSizePolicy, QDialog, QTreeWidget, QListWidget, QAbstractItemView, QHeaderView,QTreeWidgetItem,QScrollArea ) from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar from matplotlib.figure import Figure PIPELINE_AVAILABLE = True def _viz_training_spectra_csv_path(work_path: Path) -> Path: """可视化光谱/统计及模型散点图使用的训练光谱表路径(与步骤5输出一致)。 注意:步骤5.5(水质指数计算)执行后会覆盖此文件为94维增强版本, 因此下游步骤无需任何修改,直接读取此路径即可。 """ return work_path / "5_training_spectra" / "training_spectra.csv" def _viz_infer_wavelength_start_column(df: pd.DataFrame) -> Union[str, int]: """推断光谱起始列(training_spectra 通常以波长数值为列名,未必含 UTM_Y)。""" for i, col in enumerate(df.columns): name = str(col).strip().lstrip("\ufeff") try: v = float(name) except ValueError: continue if 200.0 <= v <= 3000.0: return i if "UTM_Y" in df.columns: return "UTM_Y" return 0 class VisualizationWorkerThread(QThread): """可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。""" finished_ok = pyqtSignal(object) failed = pyqtSignal(str) def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None): super().__init__() self.task = task self.work_dir = str(work_dir) self.extra = extra or {} def run(self): mpl_prev = None try: import matplotlib mpl_prev = matplotlib.get_backend() except Exception: pass try: import matplotlib.pyplot as plt plt.switch_backend("Agg") except Exception: mpl_prev = None try: wp = Path(self.work_dir) if self.task == "mask_glint": from src.postprocessing.visualization_reports import WaterQualityVisualization viz = WaterQualityVisualization(output_dir=str(resolve_subdir(self.work_dir, 'visualization'))) preview_paths = viz.generate_glint_deglint_previews( work_dir=str(wp), output_subdir="glint_deglint_previews", ) cnt = len(preview_paths) if preview_paths else 0 self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths}) elif self.task == "sampling_map": hyperspectral_files = [] deglint_dir = Path(resolve_subdir(self.work_dir, 'deglint')) if deglint_dir.exists(): for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): hyperspectral_files.extend(list(deglint_dir.glob(ext))) if not hyperspectral_files: for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): hyperspectral_files.extend(list(wp.glob(f"**/{ext}"))) if not hyperspectral_files: self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。") return hyperspectral_path = str(hyperspectral_files[0]) csv_files = [] processed_dir = wp / "4_processed_data" if processed_dir.exists(): csv_files = list(processed_dir.glob("*.csv")) if not csv_files: csv_files = ( list(wp.glob("**/*sampling*.csv")) + list(wp.glob("**/*point*.csv")) + list(wp.glob("**/*.csv")) ) if not csv_files: self.failed.emit("未找到采样点 CSV 文件。") return csv_path = str(csv_files[0]) from src.postprocessing.point_map import SamplingPointMap map_generator = SamplingPointMap( output_dir=str(Path(resolve_subdir(self.work_dir, 'visualization')) / "sampling_maps"), fast_mode=True, ) map_path = map_generator.create_sampling_point_map( hyperspectral_path=hyperspectral_path, csv_path=csv_path, point_color="red", point_size=100, point_alpha=0.9, show_north_arrow=True, show_scale_bar=True, show_legend=True, downsample=True, dpi=180, ) self.finished_ok.emit( { "task": "sampling_map", "map_path": map_path, "hyperspectral_path": hyperspectral_path, "csv_path": csv_path, } ) elif self.task == "spectrum": from src.postprocessing.visualization_reports import WaterQualityVisualization viz = WaterQualityVisualization(output_dir=str(resolve_subdir(self.work_dir, 'visualization'))) csv_file = self.extra.get("csv_path") wl = self.extra.get("wavelength_start_column", "UTM_Y") n_groups = int(self.extra.get("n_groups", 5)) param_cols = self.extra.get("param_cols") or [] if param_cols: output_paths: List[str] = [] err_lines: List[str] = [] for param_col in param_cols: try: out = viz.plot_spectrum_by_parameter( csv_path=str(csv_file), parameter_column=param_col, wavelength_start_column=wl, n_groups=n_groups, ) output_paths.append(out) except Exception as _ex: err_lines.append(f"{param_col}: {_ex}") if not output_paths: self.failed.emit( "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20]) ) return self.finished_ok.emit( { "task": "spectrum", "output_paths": output_paths, "errors": err_lines, } ) else: param_col = self.extra.get("param_col") out = viz.plot_spectrum_by_parameter( csv_path=str(csv_file), parameter_column=param_col, wavelength_start_column=wl, n_groups=n_groups, ) self.finished_ok.emit( {"task": "spectrum", "output_path": out, "param_col": param_col} ) elif self.task == "statistics": from src.postprocessing.visualization_reports import WaterQualityVisualization viz = WaterQualityVisualization(output_dir=str(resolve_subdir(self.work_dir, 'visualization'))) csv_file = self.extra.get("csv_path") param_cols = self.extra.get("param_cols") or [] output_paths = viz.plot_statistical_charts( csv_path=str(csv_file), parameter_columns=param_cols, ) self.finished_ok.emit( {"task": "statistics", "output_paths": output_paths} ) elif self.task == "scatter": from src.core.visualization.scatter_plot import generate_model_scatter_plots training_csv_path = (self.extra.get("training_csv_path") or "").strip() models_dir = (self.extra.get("models_dir") or "").strip() if not training_csv_path or not Path(training_csv_path).is_file(): self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。") return if not models_dir or not Path(models_dir).is_dir(): self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。") return scatter_paths = generate_model_scatter_plots( models_dir=models_dir, training_csv_path=training_csv_path, ) self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}}) elif self.task == "generate_all_selected": from src.postprocessing.visualization_reports import WaterQualityVisualization viz = WaterQualityVisualization(output_dir=str(resolve_subdir(self.work_dir, 'visualization'))) parts = [] training_csv_path = (self.extra.get("training_csv_path") or "").strip() if training_csv_path: training_csv = Path(training_csv_path) else: training_csv = wp / "5_training_spectra" / "training_spectra.csv" if self.extra.get("gen_scatter"): if training_csv.is_file(): models_dir_str = (self.extra.get("models_dir") or "").strip() if models_dir_str: models_dir = Path(models_dir_str) else: models_dir = wp / "8_Supervised_Model_Training" if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()): from src.core.visualization.scatter_plot import generate_model_scatter_plots scatter_paths = generate_model_scatter_plots( models_dir=str(models_dir), training_csv_path=str(training_csv), ) count = len(scatter_paths) if scatter_paths else 0 parts.append(f"散点图: {count} 个") else: parts.append("散点图: 跳过(无模型目录)") else: parts.append("散点图: 跳过(无训练数据)") if self.extra.get("gen_spectrum"): if training_csv.is_file(): import pandas as pd df = pd.read_csv(training_csv) wl_col = _viz_infer_wavelength_start_column(df) if isinstance(wl_col, str): idx = int(df.columns.get_loc(wl_col)) + 1 else: idx = int(wl_col) param_cols = [] if idx > 0 and idx < len(df.columns): param_cols = [ c for c in df.columns[:idx] if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0 ] if param_cols: spectrum_paths = [] for param_col in param_cols: try: path = viz.plot_spectrum_by_parameter( csv_path=str(training_csv), parameter_column=param_col, wavelength_start_column=wl_col, n_groups=5, ) if path: spectrum_paths.append(path) except Exception as e: print(f"生成光谱图失败 ({param_col}): {e}") count = len(spectrum_paths) parts.append(f"光谱图: {count} 个") else: parts.append("光谱图: 跳过(无可用参数列)") else: parts.append("光谱图: 跳过(无训练数据)") if self.extra.get("gen_boxplots"): if training_csv.is_file(): import pandas as pd df = pd.read_csv(training_csv) exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate'] param_cols = [ c for c in df.select_dtypes(include=[np.number]).columns if not any(exc in c.lower() for exc in exclude_cols) ] wl = _viz_infer_wavelength_start_column(df) if isinstance(wl, str): idx = int(df.columns.get_loc(wl)) + 1 else: idx = int(wl) if 0 < idx < len(df.columns): meta_set = set(df.columns[:idx]) param_cols = [c for c in param_cols if c in meta_set] if param_cols: output_dict = viz.plot_statistical_charts( csv_path=str(training_csv), parameter_columns=param_cols, ) count = len([v for v in output_dict.values() if v]) if output_dict else 0 parts.append(f"统计图: {count} 个") else: parts.append("统计图: 跳过(无可用水质参数列)") else: parts.append("统计图: 跳过(无训练数据)") if self.extra.get("gen_mask_glint"): preview_paths = viz.generate_glint_deglint_previews( work_dir=str(wp), output_subdir="glint_deglint_previews", ) parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个") if self.extra.get("gen_sampling_map"): hyperspectral_files = [] deglint_dir = Path(resolve_subdir(self.work_dir, 'deglint')) if deglint_dir.exists(): for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): hyperspectral_files.extend(list(deglint_dir.glob(ext))) if not hyperspectral_files: for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"): hyperspectral_files.extend(list(wp.glob(f"**/{ext}"))) if hyperspectral_files: hyperspectral_path = str(hyperspectral_files[0]) csv_files = [] processed_dir = wp / "4_processed_data" if processed_dir.exists(): csv_files = list(processed_dir.glob("*.csv")) if not csv_files: csv_files = ( list(wp.glob("**/*sampling*.csv")) + list(wp.glob("**/*point*.csv")) + list(wp.glob("**/*.csv")) ) if csv_files: csv_path = str(csv_files[0]) from src.postprocessing.point_map import SamplingPointMap map_generator = SamplingPointMap( output_dir=str(Path(resolve_subdir(self.work_dir, 'visualization')) / "sampling_maps"), fast_mode=True, ) map_path = map_generator.create_sampling_point_map( hyperspectral_path=hyperspectral_path, csv_path=csv_path, point_color="red", point_size=100, point_alpha=0.9, show_north_arrow=True, show_scale_bar=True, show_legend=True, downsample=True, dpi=180, ) parts.append(f"采样点图: {Path(map_path).name}") else: parts.append("采样点图: 跳过(无CSV)") else: parts.append("采样点图: 跳过(无影像)") if self.extra.get("gen_concentration"): conc_dir = wp / "9_Concentration" conc_csv = conc_dir / "final_concentrations.csv" if conc_csv.is_file(): charts_dir = conc_dir / "charts" charts_dir.mkdir(parents=True, exist_ok=True) try: import pandas as pd df = pd.read_csv(conc_csv) exclude_kw = ( "wavelength", "lon", "lat", "utm_x", "utm_y", "x", "y", "coord", "longitude", "latitude", "sample_id", "id", "index", "name", "pixel", ) conc_cols = [ c for c in df.select_dtypes(include=[np.number]).columns if not any(k in str(c).lower() for k in exclude_kw) ] if conc_cols: orig_out = viz.output_dir viz.output_dir = str(charts_dir) output_dict = viz.plot_statistical_charts( csv_path=str(conc_csv), parameter_columns=conc_cols, ) viz.output_dir = orig_out count = len([v for v in output_dict.values() if v]) parts.append(f"浓度统计图: {count} 个") stats_rows = [] for col in conc_cols: s = df[col].dropna() if len(s) == 0: continue stats_rows.append({ "参数": col, "点位数": len(s), "最小值": round(float(s.min()), 4), "最大值": round(float(s.max()), 4), "均值": round(float(s.mean()), 4), "中位数": round(float(s.median()), 4), "标准差": round(float(s.std()), 4) if len(s) > 1 else 0.0, }) if stats_rows: pd.DataFrame(stats_rows).to_csv( conc_dir / "statistics_summary.csv", index=False, encoding="utf-8-sig", ) parts.append("浓度统计表: 已生成") else: parts.append("浓度统计表: 跳过(无可用列)") else: parts.append("浓度统计图: 跳过(无可用浓度列)") except Exception as e: parts.append(f"浓度统计图: 失败({e})") else: parts.append("浓度统计图: 跳过(无浓度CSV)") self.finished_ok.emit({"task": "generate_all_selected", "parts": parts}) else: self.failed.emit(f"未知可视化任务: {self.task}") except Exception as e: self.failed.emit(f"{e}\n{traceback.format_exc()}") finally: if mpl_prev: try: import matplotlib.pyplot as plt plt.switch_backend(mpl_prev) except Exception: pass class PandasTableModel(QAbstractTableModel): """支持DataFrame的表格模型""" def __init__(self, data_frame: pd.DataFrame): super().__init__() self._data = data_frame.copy() if self._data.empty: self._data = pd.DataFrame() self._data.fillna("", inplace=True) self._columns = [str(col) for col in self._data.columns] def rowCount(self, parent=None): return len(self._data) def columnCount(self, parent=None): return len(self._columns) def data(self, index, role=Qt.DisplayRole): if not index.isValid() or role != Qt.DisplayRole: return None value = self._data.iat[index.row(), index.column()] if pd.isna(value): return "" return str(value) def headerData(self, section, orientation, role=Qt.DisplayRole): if role != Qt.DisplayRole: return None if orientation == Qt.Horizontal: if section < len(self._columns): return self._columns[section] return str(section) return str(section + 1) def flags(self, index): if not index.isValid(): return Qt.NoItemFlags return Qt.ItemIsEnabled | Qt.ItemIsSelectable class ChartViewerDialog(QDialog): """图表查看器对话框""" def __init__(self, title="图表查看器", parent=None): super().__init__(parent) self.setWindowTitle(title) self.resize(1000, 700) self.init_ui() def init_ui(self): layout = QVBoxLayout() self.figure = Figure(figsize=(10, 7)) self.canvas = FigureCanvas(self.figure) self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) self.toolbar = NavigationToolbar(self.canvas, self) layout.addWidget(self.toolbar) layout.addWidget(self.canvas) btn_layout = QHBoxLayout() self.save_btn = QPushButton("保存图表") self.save_btn.clicked.connect(self.save_chart) btn_layout.addWidget(self.save_btn) btn_layout.addStretch() self.close_btn = QPushButton("关闭") self.close_btn.clicked.connect(self.close) btn_layout.addWidget(self.close_btn) layout.addLayout(btn_layout) self.setLayout(layout) def display_image(self, image_path): """显示图片""" self.figure.clear() ax = self.figure.add_subplot(111) try: import matplotlib.image as mpimg img = mpimg.imread(image_path) ax.imshow(img) ax.axis('off') self.figure.tight_layout() self.canvas.draw() self.current_image_path = image_path except Exception as e: ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}', ha='center', va='center', transform=ax.transAxes) self.canvas.draw() def display_custom_plot(self, plot_func): """显示自定义绘图函数""" self.figure.clear() try: plot_func(self.figure) self.canvas.draw() except Exception as e: ax = self.figure.add_subplot(111) ax.text(0.5, 0.5, f'绘图失败:\n{str(e)}', ha='center', va='center', transform=ax.transAxes) self.canvas.draw() def save_chart(self): """保存图表""" file_path, _ = QFileDialog.getSaveFileName( self, "保存图表", "", "PNG图片 (*.png);;JPG图片 (*.jpg);;PDF文件 (*.pdf);;所有文件 (*.*)" ) if file_path: try: self.figure.savefig(file_path, dpi=300, bbox_inches='tight') QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}") except Exception as e: QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}") class ImageCategoryTree(QTreeWidget): """图像分类目录树 - 按真实物理文件夹结构组织图像文件""" # 文件名中文翻译映射(key: 文件名前缀 → 中文显示名) NAME_MAPPING = { "hsi_preview": "高光谱影像预览", "hsi_original": "原始高光谱影像", "hsi_deglint": "去耀斑高光谱影像", "water_mask_overlay": "水域掩膜叠加图", "water_mask": "水域掩膜图", "glint_mask": "耀斑掩膜预览", "glint_overlay": "耀斑叠加对比图", "deglint_comparison": "去耀斑前后对比", "training_spectra": "训练光谱特征", "spectrum_by_param": "参数光谱图", "model_evaluation": "模型评估散点图", "model_scatter": "模型散点图", "regression": "回归分析图", "validation": "验证结果图", "spatial_distribution": "参数空间分布图", "distribution_map": "分布图", "thematic_map": "水质专题图", "water_quality_map": "水质分布图", "prediction_map": "预测结果图", "inversion_map": "反演结果图", "correlation_matrix": "特征相关性矩阵", "feature_correlation": "特征相关性", "sampling_point_map": "采样点分布图", "sampling_points": "采样点图", "point_locations": "采样位置图", "boxplot": "箱线图", "histogram": "直方图", "statistics": "统计图表", "statistical_chart": "统计图", "error_analysis": "误差分析图", "rmse": "RMSE评估图", "r2_score": "R²得分图", "flight": "飞行轨迹图", "path": "轨迹图", "trajectory": "轨迹图", "glint_deglint": "耀斑去耀斑影像", "enhanced": "增强分布图", "content": "含量分布图", "distribution": "分布图", "prediction": "预测图", "inversion": "反演图", "scatter_true_vs_pred": "真值-预测散点图", "true_vs_pred": "真值-预测散点图", "correlation_heatmap": "相关性热力图", "parameter_boxplot": "水质参数箱线图", "spectrum_comparison": "光谱曲线对比图", "scatter": "散点图", } # 目录层级中文翻译 DIR_MAPPING = { "14_visualization": "统计与分析报表", "1_water_mask": "水域掩膜识别", "2_Glint_Detection": "耀斑区域检测", "3_deglint": "去耀斑影像结果", "5_training_spectra": "训练光谱特征", "8_Regression_Modeling": "回归建模分析", "9_water_quality_prediction": "水质预测结果", "10_feature_construction": "特征构建散点", "11_12_13_predictions": "空间分布专题图", "glint_deglint_previews": "耀斑处理预览", "sampling_maps": "采样点空间分布", "flight_maps": "无人机飞行轨迹", "9_ML_Prediction": "机器学习预测", "Non_Empirical_Prediction": "非经验模型预测", "Custom_Regression_Prediction": "自定义回归预测", "boxplot_dir": "水质参数箱线图", "boxplot": "水质参数箱线图", "output_dir": "输出目录", "8_spatial_inversion": "空间反演", "4_processed_data": "处理数据", "9_Concentration": "物理反演浓度分布", } def __init__(self, parent=None): super().__init__(parent) self._dir_node_map: dict = {} # 目录路径字符串 → QTreeWidgetItem self._work_path: Optional[Path] = None self.setHeaderLabel("图像目录") self.setMaximumWidth(300) self.setMinimumWidth(250) self.setStyleSheet(""" QTreeWidget { border: 1px solid #ddd; border-radius: 5px; background-color: #f8f9fa; } QTreeWidget::item { padding: 5px; border-radius: 3px; } QTreeWidget::item:selected { background-color: #0078D4; color: white; } QTreeWidget::item:hover { background-color: #e3f2fd; } """) def clear_all_images(self): """清除所有图像项""" try: self.invisibleRootItem().takeChildren() if hasattr(self, '_dir_node_map'): self._dir_node_map.clear() except Exception as e: print(f"清空树状图出错: {e}") import traceback traceback.print_exc() def _translate_dir_name(self, dir_name: str) -> str: """翻译目录名为中文""" return self.DIR_MAPPING.get(dir_name, dir_name) def _translate_filename(self, filename: str) -> str: # 1. 后缀替换 (图表类型) type_mapping = { '_scatter_true_vs_pred': ' 真值预测散点图', '_spectrum_comparison': ' 光谱曲线对比图', '_spectrum': ' 光谱特征图', '_histogram': ' 分布直方图', '_boxplot_seaborn': ' Seaborn箱线图', '_boxplot': ' 箱线图', '_distribution_enhanced': ' 增强空间分布图', '_distribution': ' 空间分布图', '_sampling_map': ' 采样点地图', '_flight_paths': ' 飞行轨迹图', '_preview': ' 效果预览图', 'water_mask_overlay': '水域掩膜叠加图', 'hsi_preview': '原始影像预览', 'correlation_heatmap': '特征相关性热力图', 'parameter_boxplot': '水质参数汇总箱线图', 'all_parameters_boxplot': '全参数汇总箱线图', 'content_map': '含量分布专题图', '_scatter_with_confidence': ' 置信区间散点图' } name = filename for eng, chn in type_mapping.items(): if eng in name: name = name.replace(eng, chn) # 2. 常见水质参数前缀替换 param_mapping = { 'Chlorophyll': '叶绿素', 'Chl_a': '叶绿素a', 'Chla': '叶绿素a', 'Turbidity': '浊度', 'Temperature': '温度', 'spCond': '电导率', 'COD': '化学需氧量', 'DO': '溶解氧', 'PH': 'pH值', 'TDS': '总溶解固体', 'BGA': '蓝绿藻', 'TT': '透明度', 'NH3-N': '氨氮', 'NO3-N': '硝酸盐氮', 'glint_severe_glint_area': '重度耀斑区域', 'severe_glint_area': '重度耀斑区域', 'deglint_goodman': 'Goodman算法去耀斑', 'deglint_Goodman': 'Goodman算法去耀斑', 'glint_': '耀斑检测_', 'deglint_': '耀斑去除_', } for eng, chn in param_mapping.items(): if name.startswith(eng + ' ') or name.startswith(eng + '_'): name = name.replace(eng, chn, 1) elif eng in name: name = name.replace(eng, chn) return name.strip('_') def add_image_by_dir(self, file_path: Path, work_path: Path): """按真实物理目录层级挂载图片节点 Args: file_path: 图片文件的完整路径 work_path: 工作目录根路径 """ # 计算相对路径 try: rel_path = file_path.relative_to(work_path) except ValueError: rel_path = Path(file_path.name) # 分离父目录链和文件名 parts = rel_path.parts if len(parts) <= 1: parent_key = "__root__" parent_display = "根目录" else: # 父目录路径(相对于work_path) parent_key = str(Path(*parts[:-1])) # 取最后一层目录名作为显示名 parent_display = self._translate_dir_name(parts[-2]) # 根目录节点特殊处理 root_display = self._translate_dir_name(parts[0]) if parts else "根目录" # 获取或创建根目录节点 if root_display not in self._dir_node_map: root_item = QTreeWidgetItem(self) root_item.setText(0, f"📁 {root_display}") root_item.setData(0, Qt.UserRole, {"type": "root_dir", "path": str(work_path / parts[0])}) root_item.setExpanded(True) self._dir_node_map[root_display] = root_item self._dir_node_map[f"__root__{root_display}"] = root_item root_item = self._dir_node_map.get(f"__root__{root_display}") if len(parts) > 1: # 获取或创建子目录节点 if parent_key not in self._dir_node_map: dir_item = QTreeWidgetItem(root_item) dir_item.setText(0, f" 📂 {parent_display}") dir_item.setData(0, Qt.UserRole, {"type": "sub_dir", "path": str(work_path / parent_key)}) dir_item.setExpanded(True) self._dir_node_map[parent_key] = dir_item parent_item = self._dir_node_map[parent_key] else: parent_item = root_item # 创建图片节点(根据翻译后的名称分配图标) display_name = self._translate_filename(file_path.stem) + file_path.suffix icon = "🖼️" # 默认 if "散点" in display_name: icon = "📊" elif "光谱" in display_name or "曲线" in display_name: icon = "📈" elif "箱线" in display_name or "直方" in display_name: icon = "📉" elif "分布" in display_name or "地图" in display_name or "轨迹" in display_name: icon = "🗺️" image_item = QTreeWidgetItem(parent_item) image_item.setText(0, f" {icon} {display_name}") image_item.setData(0, Qt.UserRole, {"type": "image", "path": str(file_path), "display_name": display_name}) image_item.setToolTip(0, str(file_path)) return image_item def scan_directory(self, work_dir: str): """扫描目录中的所有图像文件(深度递归扫描)—— 按真实物理目录结构挂载""" try: if not work_dir: print("可视化面板:工作目录为空,跳过扫描") return self._work_path = Path(work_dir) # 阻塞信号,防止在清空树状图时触发 selected 槽函数导致崩溃 # 因为当前类继承自 QTreeWidget,所以 self 本身就是树 self.blockSignals(True) self.clear_all_images() self.blockSignals(False) if not self._work_path.exists(): return except Exception as e: import traceback print(f"可视化面板初始化扫描出错: {e}") traceback.print_exc() # 确保信号锁被解开 self.blockSignals(False) return try: image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.tif', '*.tiff', '*.bmp'] # 拓宽扫描根目录列表(新增多个遗漏目录) scan_roots: List[Path] = [ Path(resolve_subdir(str(self._work_path), 'visualization')), Path(resolve_subdir(str(self._work_path), 'prediction_dir')), Path(resolve_subdir(str(self._work_path), 'regression_modeling')), self._work_path / "10_feature_construction", self._work_path / "5_training_spectra", Path(resolve_subdir(str(self._work_path), 'glint_detection')), Path(resolve_subdir(str(self._work_path), 'deglint')), Path(resolve_subdir(str(self._work_path), 'water_mask')), self._work_path / "9_water_quality_prediction", self._work_path / "9_Concentration", ] # 只保留存在的目录,并补充工作根目录作为兜底 scan_roots = [p for p in scan_roots if p.is_dir()] if not scan_roots: scan_roots.append(self._work_path) seen_norm: set = set() image_files: List[Path] = [] for root in scan_roots: for ext in image_extensions: for p in root.rglob(ext): key = os.path.normcase(os.path.normpath(str(p.resolve()))) if key in seen_norm: continue seen_norm.add(key) image_files.append(p) for img_file in sorted(image_files): if img_file.name.startswith('.') or 'thumb' in img_file.name.lower(): continue self.add_image_by_dir(img_file, self._work_path) # 更新目录节点计数 for key, item in self._dir_node_map.items(): if key.startswith("__root__"): continue if item.data(0, Qt.UserRole).get("type") == "sub_dir": count = item.childCount() name = item.text(0) if count > 0 and f"({count})" not in name: # 从目录名中提取显示名并附加计数 display = name.strip() item.setText(0, f" 📂 {display} ({count})") except Exception as e: import traceback print(f"可视化面板图片挂载出错: {e}") traceback.print_exc() def get_selected_image_path(self) -> Optional[str]: """获取当前选中的图像路径""" selected_item = self.currentItem() if not selected_item: return None data = selected_item.data(0, Qt.UserRole) if data and data.get("type") == "image": return data.get("path") return None class ImageViewerWidget(QWidget): """图像查看器组件 - 支持缩放、平移""" def __init__(self, parent=None): super().__init__(parent) self.current_image_path = None self.scale_factor = 1.0 self._update_timer = QTimer() self._update_timer.setSingleShot(True) self._update_timer.timeout.connect(self._do_update_display) self._pending_scale = None self.setup_ui() def setup_ui(self): layout = QVBoxLayout() layout.setContentsMargins(0, 0, 0, 0) toolbar = QHBoxLayout() self.refresh_btn = QPushButton("🔄 刷新目录") self.refresh_btn.setToolTip("重新扫描工作目录中的图像文件") toolbar.addWidget(self.refresh_btn) separator = QFrame() separator.setFrameShape(QFrame.VLine) separator.setFrameShadow(QFrame.Sunken) toolbar.addWidget(separator) self.zoom_in_btn = QPushButton("🔍+") self.zoom_in_btn.setToolTip("放大") self.zoom_in_btn.setMaximumWidth(50) toolbar.addWidget(self.zoom_in_btn) self.zoom_out_btn = QPushButton("🔍-") self.zoom_out_btn.setToolTip("缩小") self.zoom_out_btn.setMaximumWidth(50) toolbar.addWidget(self.zoom_out_btn) self.fit_btn = QPushButton("⬜ 适应窗口") self.fit_btn.setToolTip("适应窗口大小") toolbar.addWidget(self.fit_btn) self.original_btn = QPushButton("1:1 原始大小") self.original_btn.setToolTip("原始大小") toolbar.addWidget(self.original_btn) self.hint_label = QLabel("💡 提示: 按住 Ctrl+滚轮 可以实现放大缩小") self.hint_label.setStyleSheet(""" QLabel { color: #444444; font-size: 14px; font-weight: bold; padding-left: 15px; } """) toolbar.addWidget(self.hint_label) toolbar.addStretch() self.save_btn = QPushButton("💾 保存") self.save_btn.setToolTip("保存当前图像") toolbar.addWidget(self.save_btn) layout.addLayout(toolbar) self.scroll_area = QScrollArea() self.scroll_area.setWidgetResizable(True) self.scroll_area.setStyleSheet("background-color: white;") self.image_label = QLabel() self.image_label.setAlignment(Qt.AlignCenter) self.image_label.setStyleSheet("background-color: white;") self.scroll_area.setWidget(self.image_label) # 全方位事件拦截:给所有可能触发滚轮的子组件全部挂载过滤器 self.image_label.installEventFilter(self) self.scroll_area.viewport().installEventFilter(self) self.scroll_area.installEventFilter(self) self.scroll_area.verticalScrollBar().installEventFilter(self) self.scroll_area.horizontalScrollBar().installEventFilter(self) layout.addWidget(self.scroll_area, 1) status_layout = QHBoxLayout() self.status_label = QLabel("就绪") self.status_label.setStyleSheet("color: #666; font-size: 11px;") status_layout.addWidget(self.status_label) status_layout.addStretch() layout.addLayout(status_layout) self.setLayout(layout) self.zoom_in_btn.clicked.connect(self.zoom_in) self.zoom_out_btn.clicked.connect(self.zoom_out) self.fit_btn.clicked.connect(self.fit_to_window) self.original_btn.clicked.connect(self.original_size) self.save_btn.clicked.connect(self.save_image) def load_image(self, image_path: str): """加载并显示图像""" if not image_path or not Path(image_path).exists(): self.image_label.setText("图像不存在") self.status_label.setText("图像加载失败") return self.current_image_path = image_path self.scale_factor = 1.0 pixmap = QPixmap(image_path) if pixmap.isNull(): self.image_label.setText("无法加载图像") self.status_label.setText("图像格式不支持") return self.original_pixmap = pixmap self.fit_to_window() file_info = Path(image_path).stat() size_mb = file_info.st_size / (1024 * 1024) self.status_label.setText(f"{pixmap.width()}x{pixmap.height()} | {size_mb:.2f} MB | {Path(image_path).name} | 适应窗口") def update_image_display(self): """更新图像显示 - 使用防抖避免频繁重绘卡顿""" self._update_timer.stop() self._pending_scale = self.scale_factor self._update_timer.start(50) def _do_update_display(self): """实际执行图像更新""" if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): return if self._pending_scale is None: return if self._pending_scale > 2.0 or self._pending_scale < 0.5: transform = Qt.FastTransformation else: transform = Qt.SmoothTransformation scaled_pixmap = self.original_pixmap.scaled( int(self.original_pixmap.width() * self._pending_scale), int(self.original_pixmap.height() * self._pending_scale), Qt.KeepAspectRatio, transform ) self.image_label.setPixmap(scaled_pixmap) self._pending_scale = None def eventFilter(self, obj, event): from PyQt5.QtCore import QEvent, Qt if event.type() == QEvent.Wheel: if event.modifiers() == Qt.ControlModifier: if obj is self.scroll_area.viewport() or obj is self.image_label: delta = event.angleDelta().y() if delta > 0: if getattr(self, 'scale_factor', 1.0) < 5.0: self.scale_factor = min(self.scale_factor * 1.1, 5.0) self.update_image_display() else: if getattr(self, 'scale_factor', 1.0) > 0.1: self.scale_factor = max(self.scale_factor / 1.1, 0.1) self.update_image_display() return True return super().eventFilter(obj, event) def zoom_in(self): """放大""" if self.scale_factor < 5.0: self.scale_factor = min(self.scale_factor * 1.25, 5.0) self.update_image_display() def zoom_out(self): """缩小""" if self.scale_factor > 0.1: self.scale_factor = max(self.scale_factor / 1.25, 0.1) self.update_image_display() def fit_to_window(self): """适应窗口""" if not hasattr(self, 'original_pixmap') or self.original_pixmap.isNull(): return view_size = self.scroll_area.viewport().size() img_size = self.original_pixmap.size() scale_w = view_size.width() / img_size.width() scale_h = view_size.height() / img_size.height() self._fit_scale = min(scale_w, scale_h) self.scale_factor = self._fit_scale self.update_image_display() self.status_label.setText(f"适应窗口 | 缩放: {self.scale_factor:.1%}") def original_size(self): """原始大小""" self.scale_factor = 1.0 self._fit_scale = None self.update_image_display() self.status_label.setText("原始大小 | 缩放: 100%") def save_image(self): """保存图像""" if not self.current_image_path: return file_path, _ = QFileDialog.getSaveFileName( self, "保存图像", Path(self.current_image_path).name, "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)" ) if file_path: try: import shutil shutil.copy(self.current_image_path, file_path) except Exception as e: QMessageBox.critical(self, "错误", f"保存失败: {e}") class ChartBrowserDialog(QDialog): """图表浏览器对话框""" def __init__(self, chart_files, parent=None): super().__init__(parent) self.chart_files = sorted(chart_files, key=lambda x: x.stat().st_mtime, reverse=True) self.current_index = 0 self.setWindowTitle("图表浏览器") self.resize(1200, 800) self.init_ui() self.show_chart(0) def init_ui(self): layout = QVBoxLayout() list_group = QGroupBox(f"图表列表 (共 {len(self.chart_files)} 个)") list_layout = QHBoxLayout() self.chart_list = QListWidget() self.chart_list.setMaximumHeight(150) for chart_file in self.chart_files: self.chart_list.addItem(chart_file.name) self.chart_list.currentRowChanged.connect(self.show_chart) list_layout.addWidget(self.chart_list) list_group.setLayout(list_layout) layout.addWidget(list_group) self.figure = Figure(figsize=(12, 8)) self.canvas = FigureCanvas(self.figure) self.canvas.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) self.toolbar = NavigationToolbar(self.canvas, self) layout.addWidget(self.toolbar) layout.addWidget(self.canvas, 1) btn_layout = QHBoxLayout() self.prev_btn = QPushButton("◀ 上一个") self.prev_btn.clicked.connect(self.prev_chart) btn_layout.addWidget(self.prev_btn) self.next_btn = QPushButton("下一个 >") self.next_btn.clicked.connect(self.next_chart) btn_layout.addWidget(self.next_btn) btn_layout.addStretch() self.save_btn = QPushButton("💾 保存当前图表") self.save_btn.clicked.connect(self.save_current_chart) btn_layout.addWidget(self.save_btn) self.close_btn = QPushButton("关闭") self.close_btn.clicked.connect(self.close) btn_layout.addWidget(self.close_btn) layout.addLayout(btn_layout) self.setLayout(layout) def show_chart(self, index): """显示指定索引的图表""" if 0 <= index < len(self.chart_files): self.current_index = index self.chart_list.setCurrentRow(index) chart_file = self.chart_files[index] self.figure.clear() ax = self.figure.add_subplot(111) try: import matplotlib.image as mpimg img = mpimg.imread(str(chart_file)) ax.imshow(img) ax.axis('off') ax.set_title(chart_file.name, fontsize=12, pad=10) self.figure.tight_layout() self.canvas.draw() except Exception as e: ax.text(0.5, 0.5, f'加载图片失败:\n{str(e)}', ha='center', va='center', transform=ax.transAxes) self.canvas.draw() self.prev_btn.setEnabled(index > 0) self.next_btn.setEnabled(index < len(self.chart_files) - 1) def prev_chart(self): """上一个图表""" if self.current_index > 0: self.show_chart(self.current_index - 1) def next_chart(self): """下一个图表""" if self.current_index < len(self.chart_files) - 1: self.show_chart(self.current_index + 1) def save_current_chart(self): """保存当前图表""" if 0 <= self.current_index < len(self.chart_files): current_file = self.chart_files[self.current_index] file_path, _ = QFileDialog.getSaveFileName( self, "保存图表", current_file.name, "PNG图片 (*.png);;JPG图片 (*.jpg);;所有文件 (*.*)" ) if file_path: try: import shutil shutil.copy(str(current_file), file_path) QMessageBox.information(self, "成功", f"图表已保存到:\n{file_path}") except Exception as e: QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}") class Step12VizPanel(QWidget): """步骤12:可视化展示""" def __init__(self, parent=None): super().__init__(parent) self.work_dir = None self.chart_viewer = None self._viz_thread = None self.init_ui() def _viz_set_busy(self, busy: bool): for w in ( getattr(self, "gen_all_btn", None), getattr(self, "scan_btn", None), ): if w is not None: w.setEnabled(not busy) def _start_visualization_thread(self, task: str, extra: Optional[dict] = None) -> bool: if not self.work_dir: QMessageBox.warning(self, "警告", "请先选择工作目录!") return False work_path = Path(self.work_dir) if not work_path.exists(): QMessageBox.warning(self, "警告", "工作目录不存在!") return False if self._viz_thread and self._viz_thread.isRunning(): QMessageBox.information(self, "提示", "可视化任务正在运行,请稍候。") return False self._viz_thread = VisualizationWorkerThread(task, str(work_path), extra or {}) self._viz_thread.finished_ok.connect(self._on_visualization_worker_ok, Qt.QueuedConnection) self._viz_thread.failed.connect(self._on_visualization_worker_fail, Qt.QueuedConnection) self._viz_thread.finished.connect(lambda: self._viz_set_busy(False), Qt.QueuedConnection) self._viz_set_busy(True) self._viz_thread.start() return True def _spectrum_meta_param_columns(self, df: pd.DataFrame) -> List[str]: """光谱图可选的水质参数列(光谱波段列之前、且为数值型)。""" wl = _viz_infer_wavelength_start_column(df) if isinstance(wl, str): idx = int(df.columns.get_loc(wl)) + 1 else: idx = int(wl) if idx <= 0 or idx >= len(df.columns): numeric = df.select_dtypes(include=[np.number]).columns.tolist() return [ c for c in numeric if not any(x in str(c).lower() for x in ("utm", "lat", "lon", "x", "y")) ] meta = list(df.columns[:idx]) return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])] def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]: """统计图用的参数列:只统计水质参数列(数值型),排除波长列。""" numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist() wl = _viz_infer_wavelength_start_column(df) if isinstance(wl, str): idx = int(df.columns.get_loc(wl)) + 1 else: idx = int(wl) coord_kw = ("utm", "lat", "lon") if 0 < idx < len(df.columns): meta_set = set(df.columns[:idx]) return [ col for col in numeric_cols if col in meta_set and not any(x in str(col).lower() for x in coord_kw) ] return [ col for col in numeric_cols if not any(x in str(col).lower() for x in coord_kw + ("x", "y")) ] def _on_visualization_worker_ok(self, payload): if not isinstance(payload, dict): self.scan_work_directory() return t = payload.get("task") if t == "mask_glint": cnt = int(payload.get("count") or 0) if cnt > 0: QMessageBox.information( self, "成功", f"掩膜和耀斑缩略图生成完成,共 {cnt} 个预览图。\n" f"保存位置: 14_visualization/glint_deglint_previews/", ) else: QMessageBox.warning( self, "警告", "未找到可处理的影像文件(2_Glint_Detection/3_deglint 等)。", ) elif t == "sampling_map": map_path = payload.get("map_path") QMessageBox.information( self, "成功", "采样点地图生成完成。\n" f"输出: {Path(map_path).name if map_path else ''}\n" "路径: 14_visualization/sampling_maps/", ) if map_path: self.show_chart_viewer(map_path, "采样点分布图") elif t == "spectrum": multi = payload.get("output_paths") if isinstance(multi, list) and multi: ok_paths = [p for p in multi if p and Path(str(p)).is_file()] errs = payload.get("errors") or [] msg = ( f"已为 {len(ok_paths)} 个水质参数生成光谱对比图。\n" f"保存目录: 工作目录/14_visualization/" ) if errs: msg += f"\n\n以下列未生成或出错 ({len(errs)} 项,详见日志):\n" msg += "\n".join(str(e) for e in errs[:8]) if len(errs) > 8: msg += "\n..." QMessageBox.information(self, "成功", msg) if ok_paths: self.show_chart_viewer(ok_paths[0], "光谱曲线对比(首张)") else: outp = payload.get("output_path") param = payload.get("param_col", "") QMessageBox.information(self, "成功", f"光谱图已生成:\n{outp}") if outp: self.show_chart_viewer(outp, f"{param} - 光谱曲线对比") elif t == "statistics": outp = payload.get("output_paths") or {} QMessageBox.information( self, "成功", f"统计图表已生成,共 {len(outp)} 项。" ) if isinstance(outp, dict) and "boxplot" in outp: self.show_chart_viewer(outp["boxplot"], "水质参数箱线图") elif t == "scatter": paths = payload.get("scatter_paths") or {} ok_paths = [p for p in paths.values() if p and Path(str(p)).is_file()] if ok_paths: QMessageBox.information( self, "成功", f"已生成 {len(ok_paths)} 个模型评估散点图。\n" f"保存位置: 14_visualization/scatter_plots/", ) self.show_chart_viewer(ok_paths[0], "模型评估散点图") else: QMessageBox.warning( self, "提示", "未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件," "且训练 CSV 与建模时一致。", ) elif t == "generate_all_selected": parts = payload.get("parts") or [] QMessageBox.information( self, "完成", "批量可视化已执行:\n" + "\n".join(parts) if parts else "(无选中项或已跳过)", ) self.scan_work_directory() def _on_visualization_worker_fail(self, err: str): QMessageBox.critical(self, "错误", f"可视化任务失败:\n{err[:1200]}") def init_ui(self): """初始化UI - 使用左右分栏布局""" main_layout = QHBoxLayout() main_layout.setSpacing(10) main_layout.setContentsMargins(10, 10, 10, 10) # ===== 左侧面板 ===== left_panel = QWidget() left_layout = QVBoxLayout() left_layout.setContentsMargins(0, 0, 0, 0) # 工作目录选择 dir_group = QGroupBox("工作目录") dir_layout = QHBoxLayout() self.work_dir_edit = QLineEdit() self.work_dir_edit.setPlaceholderText("选择工作目录...") self.work_dir_edit.setReadOnly(True) dir_browse_btn = QPushButton("浏览") dir_browse_btn.clicked.connect(self.browse_work_dir) dir_layout.addWidget(self.work_dir_edit, 1) dir_layout.addWidget(dir_browse_btn) dir_group.setLayout(dir_layout) left_layout.addWidget(dir_group) # 图像目录选择(优先指向预测结果目录) img_dir_group = QGroupBox("图像目录") img_dir_layout = QHBoxLayout() self.img_dir_edit = QLineEdit() self.img_dir_edit.setPlaceholderText("预测结果目录(自动填充)…") self.img_dir_edit.setReadOnly(True) img_dir_browse_btn = QPushButton("浏览") img_dir_browse_btn.clicked.connect(self.browse_img_dir) img_dir_layout.addWidget(self.img_dir_edit, 1) img_dir_layout.addWidget(img_dir_browse_btn) img_dir_group.setLayout(img_dir_layout) left_layout.addWidget(img_dir_group) # 图像目录树 tree_group = QGroupBox("图像目录") tree_layout = QVBoxLayout() self.image_tree = ImageCategoryTree() self.image_tree.itemClicked.connect(self.on_tree_item_clicked) tree_layout.addWidget(self.image_tree) tree_group.setLayout(tree_layout) left_layout.addWidget(tree_group, 1) # 可视化配置 config_group = QGroupBox("可视化配置") config_layout = QVBoxLayout() self.gen_scatter = QCheckBox("模型评估散点图") self.gen_scatter.setChecked(True) config_layout.addWidget(self.gen_scatter) self.gen_spectrum = QCheckBox("光谱曲线图") self.gen_spectrum.setChecked(True) config_layout.addWidget(self.gen_spectrum) self.gen_boxplots = QCheckBox("统计图表") self.gen_boxplots.setChecked(True) config_layout.addWidget(self.gen_boxplots) self.gen_mask_glint = QCheckBox("掩膜和耀斑缩略图") self.gen_mask_glint.setChecked(True) config_layout.addWidget(self.gen_mask_glint) self.gen_sampling_map = QCheckBox("采样点地图") self.gen_sampling_map.setChecked(True) config_layout.addWidget(self.gen_sampling_map) config_layout.addSpacing(10) line = QFrame() line.setFrameShape(QFrame.HLine) line.setStyleSheet("color: #ddd;") config_layout.addWidget(line) config_layout.addSpacing(10) self.gen_all_btn = QPushButton("🚀 生成全部") self.gen_all_btn.setToolTip("生成所有类型的可视化图表") self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;") self.gen_all_btn.clicked.connect(self.generate_all_visualizations) config_layout.addWidget(self.gen_all_btn) self.scan_btn = QPushButton("📁 扫描目录") self.scan_btn.setToolTip("扫描工作目录中的图像文件") self.scan_btn.clicked.connect(self.scan_work_directory) config_layout.addWidget(self.scan_btn) config_group.setLayout(config_layout) left_layout.addWidget(config_group) left_panel.setLayout(left_layout) left_panel.setMaximumWidth(350) main_layout.addWidget(left_panel, 0) # ===== 右侧面板 ===== right_panel = QWidget() right_layout = QVBoxLayout() right_layout.setContentsMargins(0, 0, 0, 0) self.image_viewer = ImageViewerWidget() self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory) right_layout.addWidget(self.image_viewer, 1) right_panel.setLayout(right_layout) main_layout.addWidget(right_panel, 1) self.setLayout(main_layout) def set_work_dir(self, work_dir): """设置工作目录""" self.work_dir = work_dir self.work_dir_edit.setText(str(work_dir)) if work_dir: QTimer.singleShot(100, self.scan_work_directory) def _get_default_work_dir(self): """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" if hasattr(self, 'work_dir') and self.work_dir: return str(self.work_dir) mw = self.window() if mw and hasattr(mw, 'work_dir') and mw.work_dir: return str(mw.work_dir) return "" def browse_work_dir(self): """浏览工作目录""" default = self._get_default_work_dir() dir_path = QFileDialog.getExistingDirectory(self, "选择工作目录", default) if dir_path: self.work_dir = dir_path self.work_dir_edit.setText(dir_path) self.scan_work_directory() def browse_img_dir(self): """手动浏览图像目录""" default = self._get_default_work_dir() dir_path = QFileDialog.getExistingDirectory(self, "选择图像目录", default) if dir_path: self.img_dir_edit.setText(dir_path) self.image_tree.scan_directory(dir_path) self._load_first_image_from_tree() def update_from_config(self, work_dir=None, pipeline=None): """从全局配置自动推断并填入图像目录,然后自动加载目录内容。 推断优先级: 1. {work_dir}/9_ML_Prediction(机器学习预测) 2. {work_dir}/11_12_13_predictions/Non_Empirical_Prediction(普通回归预测) 3. {work_dir}/13_Custom_Regression/Custom_Regression_Prediction(自定义回归预测) 4. {work_dir}/14_visualization(可视化目录) 5. {work_dir}(工作目录根) """ try: if work_dir: self.work_dir = work_dir self.work_dir_edit.setText(str(work_dir)) elif not self.work_dir: return work_path = Path(self.work_dir) pred_dir = Path(resolve_subdir(self.work_dir, 'prediction_dir')) # 按优先级寻找存在的目录 candidates = [ Path(resolve_subdir(self.work_dir, 'ml_prediction')), pred_dir / "Non_Empirical_Prediction", Path(resolve_subdir(self.work_dir, 'custom_regression')) / "Custom_Regression_Prediction", Path(resolve_subdir(self.work_dir, 'visualization')), work_path, ] detected_dir = None for candidate in candidates: if candidate.exists() and candidate.is_dir(): detected_dir = candidate break if detected_dir: detected_str = str(detected_dir) self.img_dir_edit.setText(detected_str) self.image_tree.scan_directory(detected_str) else: # 无预测目录时扫描整个工作目录 self.image_tree.scan_directory(self.work_dir) # 自动触发加载第一张图像 self._load_first_image_from_tree() except Exception as e: import traceback print(f"可视化面板 update_from_config 出错: {e}") traceback.print_exc() def _load_first_image_from_tree(self): """自动加载树状图中的第一张有效图片(兼容物理目录层级结构)""" try: tree = getattr(self, 'image_tree', None) if not tree: return from PyQt5.QtCore import Qt def find_first_image(item): # 检查当前节点是否是图片节点 data = item.data(0, Qt.UserRole) if isinstance(data, dict) and data.get("type") == "image": return item # 如果不是,递归检查所有子节点 for i in range(item.childCount()): found = find_first_image(item.child(i)) if found: return found return None # 遍历所有顶层节点 for i in range(tree.topLevelItemCount()): first_img_item = find_first_image(tree.topLevelItem(i)) if first_img_item: tree.setCurrentItem(first_img_item) # 主动触发一次点击槽函数,以在右侧渲染图片 self.on_tree_item_clicked(first_img_item, 0) return except Exception as e: import traceback print(f"自动加载首张图片失败: {e}") traceback.print_exc() def scan_work_directory(self): """扫描工作目录中的图像文件""" if not self.work_dir: return work_path = Path(self.work_dir) if not work_path.exists(): return print(f"扫描工作目录: {work_path}") self.image_tree.scan_directory(str(work_path)) self._setup_prediction_output_dirs(work_path) viz_dir = Path(resolve_subdir(str(work_path), 'visualization')) if viz_dir.exists(): image_files = list(viz_dir.glob("**/*.png")) + list(viz_dir.glob("**/*.jpg")) if image_files: self.image_viewer.load_image(str(image_files[0])) def _setup_prediction_output_dirs(self, work_path: Path): """设置三个预测步骤的默认输出目录""" try: base_prediction_dir = Path(resolve_subdir(str(work_path), 'prediction_dir')) ml_dir = Path(resolve_subdir(str(work_path), 'ml_prediction')) reg_dir = base_prediction_dir / "Regression_Model_Prediction" custom_dir = Path(resolve_subdir(str(work_path), 'custom_regression')) / "Custom_Regression_Prediction" ml_dir.mkdir(parents=True, exist_ok=True) reg_dir.mkdir(parents=True, exist_ok=True) custom_dir.mkdir(parents=True, exist_ok=True) # 旧的 self.step11_ml_panel/step11_panel/step12_panel 在 Step12VizPanel 上不存在,是死代码。 # 三个目录的真实默认值在用户首次浏览 / 自动填充时由各 panel 自己的 _get_default_work_dir 路径产出。 # 这里仅做目录创建 + 提示输出,便于用户在工作目录树中能看到预测输出位置。 print(f"预测输出目录已创建:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}") except Exception as e: print(f"设置预测输出目录失败: {e}") def on_tree_item_clicked(self, item, column): """目录树项点击事件""" data = item.data(0, Qt.UserRole) if not data: return if data.get("type") == "image": image_path = data.get("path") if image_path and Path(image_path).exists(): self.image_viewer.load_image(image_path) def generate_all_visualizations(self): """生成所有可视化图表""" if not self.work_dir: QMessageBox.warning(self, "警告", "请先选择工作目录!") return work_path = Path(self.work_dir) if not work_path.exists(): QMessageBox.warning(self, "警告", "工作目录不存在!") return if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or self.gen_sampling_map.isChecked()): QMessageBox.information(self, "提示", "请至少勾选一项可视化配置选项以生成图表。") return reply = QMessageBox.question( self, "确认生成", "将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?", QMessageBox.Yes | QMessageBox.No ) if reply != QMessageBox.Yes: return extra = { "gen_scatter": self.gen_scatter.isChecked(), "gen_spectrum": self.gen_spectrum.isChecked(), "gen_boxplots": self.gen_boxplots.isChecked(), "gen_mask_glint": self.gen_mask_glint.isChecked(), "gen_sampling_map": self.gen_sampling_map.isChecked(), } main_window = self.window() factory = getattr(main_window, '_panel_factory', None) if main_window else None step5_panel = factory.get_panel('step5_clean') if factory else None if step5_panel and getattr(step5_panel, 'output_file', None): _resolved_csv = step5_panel.output_file.get_path() if _resolved_csv: extra["training_csv_path"] = _resolved_csv step8_panel = factory.get_panel('step8_ml_train') if factory else None if step8_panel and getattr(step8_panel, 'output_path', None): _resolved_models_dir = step8_panel.output_path.get_path() if _resolved_models_dir: extra["models_dir"] = _resolved_models_dir self._start_visualization_thread("generate_all_selected", extra) def generate_chart(self, chart_type): """生成图表""" if not self.work_dir: QMessageBox.warning(self, "警告", "请先选择工作目录!") return work_path = Path(self.work_dir) if not work_path.exists(): QMessageBox.warning(self, "警告", "工作目录不存在!") return try: main_window = self.window() factory = getattr(main_window, '_panel_factory', None) if main_window else None step5_panel = factory.get_panel('step5_clean') if factory else None if step5_panel and getattr(step5_panel, 'output_file', None) and step5_panel.output_file.get_path(): training_spectra_csv = Path(step5_panel.output_file.get_path()) else: training_spectra_csv = _viz_training_spectra_csv_path(work_path) if chart_type == 'scatter': if not training_spectra_csv.is_file(): QMessageBox.warning( self, "警告", "未找到 5_training_spectra\\training_spectra.csv。\n" "请先执行步骤5(光谱特征提取)生成该文件。", ) return training_csv = training_spectra_csv models_dir = work_path / "7_Supervised_Model_Training" if not models_dir.is_dir() or not any(d.is_dir() for d in models_dir.iterdir()): mdir = QFileDialog.getExistingDirectory( self, "选择模型根目录(内含各水质参数子文件夹)", str(work_path)) if not mdir: return models_dir = Path(mdir) self._start_visualization_thread( "scatter", {"training_csv_path": str(training_csv), "models_dir": str(models_dir)}, ) return if chart_type == 'spectrum': if not training_spectra_csv.is_file(): QMessageBox.warning( self, "警告", "未找到 5_training_spectra\\training_spectra.csv。\n" "光谱分析固定使用该文件,请先执行步骤5(光谱特征提取)。", ) return csv_file = training_spectra_csv df = pd.read_csv(csv_file) columns = self._spectrum_meta_param_columns(df) if not columns: QMessageBox.warning( self, "警告", "当前 CSV 中没有可用的数值型水质参数列,无法按参数分组绘制光谱图。", ) return wl_col = _viz_infer_wavelength_start_column(df) self._start_visualization_thread( "spectrum", {"csv_path": str(csv_file), "param_cols": columns, "wavelength_start_column": wl_col, "n_groups": 5}, ) return if chart_type == 'statistics': if not training_spectra_csv.is_file(): QMessageBox.warning( self, "警告", "未找到 5_training_spectra\\training_spectra.csv。\n" "统计分析固定使用该文件,请先执行步骤5(光谱特征提取)。", ) return csv_file = training_spectra_csv df = pd.read_csv(csv_file) param_cols = self._statistics_param_columns(df) if not param_cols: QMessageBox.warning(self, "警告", "未找到可用的水质参数列!") return self._start_visualization_thread( "statistics", {"csv_path": str(csv_file), "param_cols": param_cols}, ) return if chart_type == 'sampling_map': self.generate_sampling_point_map() return except ImportError: QMessageBox.critical( self, "错误", "无法导入可视化模块!\n请确保 visualization_reports.py 文件存在。", ) except Exception as e: QMessageBox.critical( self, "错误", f"生成图表时出错:\n{str(e)}\n\n{traceback.format_exc()}", ) def generate_mask_glint_previews(self): """生成掩膜和耀斑缩略图""" self._start_visualization_thread("mask_glint") def generate_sampling_point_map(self): """生成采样点地图""" self._start_visualization_thread("sampling_map") def view_chart(self, chart_type): """查看图表""" if not self.work_dir: QMessageBox.warning(self, "警告", "请先选择工作目录!") return work_path = Path(self.work_dir) viz_dir = Path(resolve_subdir(self.work_dir, 'visualization')) viz_dir2 = viz_dir / "boxplots" viz_dir3 = viz_dir / "scatter_plots" if not viz_dir.exists(): QMessageBox.warning(self, "警告", f"可视化目录不存在:\n{viz_dir}\n\n请先生成图表。") return chart_files = [] if chart_type == 'scatter': chart_files = list(viz_dir3.glob("*scatter*.png")) elif chart_type == 'spectrum': chart_files = list(viz_dir.glob("*spectrum*.png")) elif chart_type == 'statistics': chart_files = list(viz_dir2.glob("*boxplot.png")) + \ list(viz_dir.glob("*histogram.png")) + \ list(viz_dir.glob("*heatmap.png")) elif chart_type == 'distribution': chart_files = list(viz_dir.glob("**/*distribution.png")) elif chart_type == 'mask_glint': glint_dir = viz_dir / "glint_deglint_previews" chart_files = list(glint_dir.glob("*preview.png")) if glint_dir.exists() else \ list(viz_dir.glob("*preview.png")) + \ list(viz_dir.glob("*glint*.png")) + \ list(viz_dir.glob("*mask*.png")) elif chart_type == 'sampling_map': sampling_dir = viz_dir / "sampling_maps" chart_files = list(sampling_dir.glob("*sampling_map.png")) if sampling_dir.exists() else \ list(viz_dir.glob("*sampling*.png")) if not chart_files: QMessageBox.warning(self, "警告", f"未找到{chart_type}类型的图表文件!\n\n请先生成图表。") return if len(chart_files) > 1: from PyQt5.QtWidgets import QInputDialog file_names = [f.name for f in chart_files] file_name, ok = QInputDialog.getItem( self, "选择图表", "请选择要查看的图表:", file_names, 0, False) if ok: selected_file = next(f for f in chart_files if f.name == file_name) self.show_chart_viewer(str(selected_file), file_name) else: self.show_chart_viewer(str(chart_files[0]), chart_files[0].name) def browse_all_charts(self): """浏览所有图表""" if not self.work_dir: QMessageBox.warning(self, "警告", "请先选择工作目录!") return work_path = Path(self.work_dir) chart_files = list(work_path.glob("**/*.png")) + list(work_path.glob("**/*.jpg")) if not chart_files: QMessageBox.warning(self, "警告", "未找到图表文件!") return dialog = ChartBrowserDialog(chart_files, self) dialog.exec_() def show_chart_viewer(self, image_path, title="图表查看器"): """显示图表查看器""" viewer = ChartViewerDialog(title=title, parent=self) viewer.display_image(image_path) viewer.exec_() def get_config(self): """获取配置""" return { 'generate_scatter': self.gen_scatter.isChecked(), 'generate_boxplots': self.gen_boxplots.isChecked(), 'generate_spectrum': self.gen_spectrum.isChecked(), 'generate_glint_previews': self.gen_mask_glint.isChecked(), 'generate_sampling_maps': self.gen_sampling_map.isChecked(), 'scatter_config': { 'metric': 'test_r2', 'feature_start_column': 13, 'test_size': 0.2, 'random_state': 42 }, 'boxplot_config': { 'data_start_column': 4, 'save_individual': True, 'use_seaborn': True }, 'glint_preview_config': { 'work_dir': None, 'output_subdir': 'glint_deglint_previews', 'generate_glint': True, 'generate_deglint': True } } def set_config(self, config): """设置配置""" if not config: return if 'generate_scatter' in config: self.gen_scatter.setChecked(config['generate_scatter']) if 'generate_boxplots' in config: self.gen_boxplots.setChecked(config['generate_boxplots']) if 'generate_spectrum' in config: self.gen_spectrum.setChecked(config['generate_spectrum']) if 'generate_glint_previews' in config: self.gen_mask_glint.setChecked(config['generate_glint_previews']) if 'generate_sampling_maps' in config: self.gen_sampling_map.setChecked(config.get('generate_sampling_maps', True))