#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Visualization background-thread module.

Contains the VisualizationWorkerThread worker class and its helper functions.
"""

import re
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Union

from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np


# Image file patterns, in priority order, used to locate a hyperspectral image.
_VIZ_IMAGE_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")

# Recursive fallback patterns, in priority order, used to locate a sampling CSV.
_VIZ_CSV_FALLBACK_PATTERNS = ("**/*sampling*.csv", "**/*point*.csv", "**/*.csv")

# Keywords that mark a coordinate column wherever they occur in the name.
_VIZ_COORD_SUBSTRINGS = ("longitude", "latitude", "coord")

# Short keywords that mark a coordinate column only when they form a whole
# word of the name (e.g. "UTM_X"), never when merely embedded in a longer
# word (e.g. the "y" in "Turbidity").
_VIZ_COORD_TOKENS = frozenset({"x", "y", "lon", "lat"})


def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
    """Infer where the spectral columns of ``df`` start.

    Training-spectra tables usually use the wavelength value itself as the
    column name and do not necessarily contain a ``UTM_Y`` column.

    Returns:
        The positional index of the first column whose name parses as a
        number within [200, 3000]; otherwise the string ``"UTM_Y"`` when
        such a column exists; otherwise ``0``.
    """
    for position, column in enumerate(df.columns):
        # A UTF-8 BOM may be glued to the first header cell of a CSV.
        label = str(column).strip().lstrip("\ufeff")
        try:
            value = float(label)
        except ValueError:
            continue
        if 200.0 <= value <= 3000.0:
            return position
    if "UTM_Y" in df.columns:
        return "UTM_Y"
    return 0


def _viz_spectral_start_index(df, wl_col: Union[str, int]) -> int:
    """Convert an inferred start column into a positional split index.

    A string is treated as the name of the last metadata column (the split
    is placed right after it); an integer is used as the index directly.
    """
    if isinstance(wl_col, str):
        return int(df.columns.get_loc(wl_col)) + 1
    return int(wl_col)


def _viz_is_coordinate_column(name) -> bool:
    """Return True when a column name looks like a coordinate column.

    Long keywords (``longitude``, ``latitude``, ``coord``) match anywhere in
    the name. Short keywords (``x``, ``y``, ``lon``, ``lat``) match only as a
    whole alphabetic word, so ``UTM_X``, ``lat1`` and ``Y`` are coordinates
    while ``Turbidity``, ``Oxygen`` or ``volatile_phenol`` are not.
    """
    lowered = str(name).strip().lstrip("\ufeff").lower()
    if any(keyword in lowered for keyword in _VIZ_COORD_SUBSTRINGS):
        return True
    # Split on everything that is not an ASCII letter so that separators and
    # digits both delimit words ("utm_x" -> utm, x; "lat1" -> lat).
    words = re.split(r"[^a-z]+", lowered)
    return any(word in _VIZ_COORD_TOKENS for word in words)


def _viz_first_match(base: Path, patterns: Iterable[str]) -> Optional[Path]:
    """Return the first path yielded by ``base.glob`` for the first pattern
    that matches anything, or ``None`` when no pattern matches.

    Only the first hit is consumed, so a recursive search stops as soon as
    a match is found instead of listing the whole tree.
    """
    for pattern in patterns:
        found = next(iter(base.glob(pattern)), None)
        if found is not None:
            return found
    return None


def _viz_find_hyperspectral_image(work_dir: Path) -> Optional[Path]:
    """Locate a hyperspectral image for the sampling-point map.

    ``<work_dir>/3_deglint`` is searched first (non-recursively); when it
    yields nothing the whole work directory is searched recursively.
    """
    deglint_dir = work_dir / "3_deglint"
    if deglint_dir.exists():
        found = _viz_first_match(deglint_dir, _VIZ_IMAGE_PATTERNS)
        if found is not None:
            return found
    return _viz_first_match(
        work_dir, tuple(f"**/{pattern}" for pattern in _VIZ_IMAGE_PATTERNS)
    )


def _viz_find_sampling_csv(work_dir: Path) -> Optional[Path]:
    """Locate a sampling-point CSV for the sampling-point map.

    ``<work_dir>/4_processed_data`` is searched first (non-recursively);
    when it yields nothing the work directory is searched recursively,
    preferring names containing "sampling", then "point", then any CSV.
    """
    processed_dir = work_dir / "4_processed_data"
    if processed_dir.exists():
        found = _viz_first_match(processed_dir, ("*.csv",))
        if found is not None:
            return found
    return _viz_first_match(work_dir, _VIZ_CSV_FALLBACK_PATTERNS)


class VisualizationWorkerThread(QThread):
    """Run time-consuming visualization work in a background thread.

    The matplotlib backend is temporarily switched to Agg while the task
    runs so that the main window stays responsive.

    Signals:
        finished_ok: emitted with a result ``dict`` (always carrying a
            ``"task"`` key) when the task completes.
        failed: emitted with an error message when the task cannot run or
            raises.
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        """Store the task name, the work directory and optional task options.

        Args:
            task: One of ``"mask_glint"``, ``"sampling_map"``, ``"spectrum"``,
                ``"statistics"``, ``"scatter"`` or ``"generate_all_selected"``.
            work_dir: Root directory of the processing run.
            extra: Task-specific options; ``None`` is treated as empty.
        """
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        """Thread entry point: dispatch ``self.task`` and report the outcome.

        Every exception raised by a task is converted into a ``failed``
        signal carrying the message and the traceback; the previously active
        matplotlib backend is restored afterwards.
        """
        previous_backend = self._switch_to_agg_backend()
        try:
            wp = Path(self.work_dir)
            handlers = self._task_handlers()
            # Guard the lookup so an unhashable task still reports "unknown".
            handler = handlers.get(self.task) if isinstance(self.task, str) else None
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(wp)
        except Exception as e:
            import traceback

            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            self._restore_backend(previous_backend)

    # ------------------------------------------------------------------
    # Backend handling
    # ------------------------------------------------------------------
    @staticmethod
    def _switch_to_agg_backend() -> Optional[str]:
        """Switch matplotlib to Agg and return the backend to restore.

        Returns ``None`` when the current backend could not be read or the
        switch failed, in which case no restore is attempted later.
        """
        previous = None
        try:
            import matplotlib

            previous = matplotlib.get_backend()
        except Exception:
            # Best effort: without matplotlib there is nothing to restore.
            pass
        try:
            import matplotlib.pyplot as plt

            plt.switch_backend("Agg")
        except Exception:
            previous = None
        return previous

    @staticmethod
    def _restore_backend(previous: Optional[str]) -> None:
        """Switch matplotlib back to ``previous``; failures are ignored."""
        if not previous:
            return
        try:
            import matplotlib.pyplot as plt

            plt.switch_backend(previous)
        except Exception:
            # Best effort: a failed restore must not mask the task result.
            pass

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------
    def _task_handlers(self) -> Dict[str, Callable[[Path], None]]:
        """Map each supported task name to the method that executes it."""
        return {
            "mask_glint": self._task_mask_glint,
            "sampling_map": self._task_sampling_map,
            "spectrum": self._task_spectrum,
            "statistics": self._task_statistics,
            "scatter": self._task_scatter,
            "generate_all_selected": self._task_generate_all_selected,
        }

    @staticmethod
    def _make_visualizer(wp: Path):
        """Create the report visualizer writing to ``<wp>/14_visualization``."""
        from src.postprocessing.visualization_reports import WaterQualityVisualization

        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @staticmethod
    def _create_sampling_map(wp: Path, hyperspectral_path: str, csv_path: str):
        """Render the sampling-point map and return what the generator returns."""
        from src.postprocessing.point_map import SamplingPointMap

        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=hyperspectral_path,
            csv_path=csv_path,
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    @staticmethod
    def _generate_scatter_plots(wp: Path, training_csv_path: str, models_dir: str):
        """Run the pipeline's model scatter-plot generation."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline

        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        return pipeline.generate_model_scatter_plots(
            training_csv_path=training_csv_path,
            models_dir=models_dir,
        )

    # ------------------------------------------------------------------
    # Single tasks
    # ------------------------------------------------------------------
    def _task_mask_glint(self, wp: Path) -> None:
        """Generate the mask / glint / deglint preview images."""
        viz = self._make_visualizer(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit(
            {"task": "mask_glint", "count": cnt, "preview_paths": preview_paths}
        )

    def _task_sampling_map(self, wp: Path) -> None:
        """Generate the sampling-point map from auto-discovered inputs."""
        image = _viz_find_hyperspectral_image(wp)
        if image is None:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。")
            return
        csv_file = _viz_find_sampling_csv(wp)
        if csv_file is None:
            self.failed.emit("未找到采样点 CSV 文件。")
            return

        hyperspectral_path = str(image)
        csv_path = str(csv_file)
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": hyperspectral_path,
                "csv_path": csv_path,
            }
        )

    def _task_spectrum(self, wp: Path) -> None:
        """Plot spectra grouped by one parameter column or by several.

        With ``extra["param_cols"]`` every column is plotted independently
        and per-column failures are collected; the task only fails when no
        column succeeds. Otherwise ``extra["param_col"]`` is plotted alone.
        """
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        if not csv_file:
            # Without this check str(None) == "None" would be handed on as a path.
            self.failed.emit("未提供 CSV 文件路径。")
            return
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []

        if not param_cols:
            param_col = self.extra.get("param_col")
            out = viz.plot_spectrum_by_parameter(
                csv_path=str(csv_file),
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )
            self.finished_ok.emit(
                {"task": "spectrum", "output_path": out, "param_col": param_col}
            )
            return

        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            try:
                out = viz.plot_spectrum_by_parameter(
                    csv_path=str(csv_file),
                    parameter_column=param_col,
                    wavelength_start_column=wl,
                    n_groups=n_groups,
                )
                output_paths.append(out)
            except Exception as _ex:
                err_lines.append(f"{param_col}: {_ex}")

        if not output_paths:
            # Cap the report so a very wide table cannot flood the dialog.
            self.failed.emit(
                "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
            )
            return
        self.finished_ok.emit(
            {
                "task": "spectrum",
                "output_paths": output_paths,
                "errors": err_lines,
            }
        )

    def _task_statistics(self, wp: Path) -> None:
        """Plot statistical charts for the requested parameter columns."""
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        if not csv_file:
            # Without this check str(None) == "None" would be handed on as a path.
            self.failed.emit("未提供 CSV 文件路径。")
            return
        param_cols = self.extra.get("param_cols") or []
        output_paths = viz.plot_statistical_charts(
            csv_path=str(csv_file),
            parameter_columns=param_cols,
        )
        self.finished_ok.emit({"task": "statistics", "output_paths": output_paths})

    def _task_scatter(self, wp: Path) -> None:
        """Generate model scatter plots from explicit CSV / model paths."""
        # Imported before validation so a broken pipeline module is reported
        # even when the inputs are also invalid.
        from src.core.water_quality_inversion_pipeline_GUI import (  # noqa: F401
            WaterQualityInversionPipeline,
        )

        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return

        scatter_paths = self._generate_scatter_plots(wp, training_csv_path, models_dir)
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ------------------------------------------------------------------
    # "Generate all selected" task
    # ------------------------------------------------------------------
    def _task_generate_all_selected(self, wp: Path) -> None:
        """Run every output enabled by a ``gen_*`` flag in ``extra``.

        Each enabled output contributes one human-readable summary line to
        the ``"parts"`` list of the emitted result.
        """
        viz = self._make_visualizer(wp)
        training_csv = wp / "5_training_spectra" / "training_spectra.csv"
        parts: List[str] = []
        loaded: list = []  # holds at most one DataFrame

        def load_training_df():
            # Parse the CSV on first use only, so it is read at most once
            # even when both the spectrum and the box-plot parts need it.
            if not loaded:
                import pandas as pd

                loaded.append(pd.read_csv(training_csv))
            return loaded[0]

        if self.extra.get("gen_scatter"):
            parts.append(self._summarize_scatter(wp, training_csv))
        if self.extra.get("gen_spectrum"):
            parts.append(self._summarize_spectrum(viz, training_csv, load_training_df))
        if self.extra.get("gen_boxplots"):
            parts.append(self._summarize_boxplots(viz, training_csv, load_training_df))
        if self.extra.get("gen_mask_glint"):
            preview_paths = viz.generate_glint_deglint_previews(
                work_dir=str(wp),
                output_subdir="glint_deglint_previews",
            )
            parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
        if self.extra.get("gen_sampling_map"):
            parts.append(self._summarize_sampling_map(wp))

        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    def _summarize_scatter(self, wp: Path, training_csv: Path) -> str:
        """Generate scatter plots when possible and return the summary line."""
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        # At least one sub-folder must exist under the model directory.
        if not (models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir())):
            return "散点图: 跳过(无模型目录)"
        scatter_paths = self._generate_scatter_plots(wp, str(training_csv), str(models_dir))
        count = len(scatter_paths) if scatter_paths else 0
        return f"散点图: {count} 个"

    @staticmethod
    def _summarize_spectrum(viz, training_csv: Path, load_df) -> str:
        """Plot spectra for every usable metadata column; return the summary.

        Usable columns are the numeric columns located before the inferred
        spectral start that contain at least one non-null value. A failure
        on one column is printed and does not stop the remaining columns.
        """
        if not training_csv.is_file():
            return "光谱图: 跳过(无训练数据)"
        df = load_df()
        wl_col = _viz_infer_wavelength_start_column(df)
        idx = _viz_spectral_start_index(df, wl_col)

        param_cols = []
        if 0 < idx < len(df.columns):
            param_cols = [
                c for c in df.columns[:idx]
                if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
            ]
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"

        spectrum_paths = []
        for param_col in param_cols:
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
            except Exception as e:
                print(f"生成光谱图失败 ({param_col}): {e}")
                continue
            if path:
                spectrum_paths.append(path)
        return f"光谱图: {len(spectrum_paths)} 个"

    @staticmethod
    def _summarize_boxplots(viz, training_csv: Path, load_df) -> str:
        """Plot statistical charts for the parameter columns; return the summary.

        Candidate columns are the numeric, non-coordinate columns; when a
        spectral start can be located they are further restricted to the
        columns that precede it.
        """
        if not training_csv.is_file():
            return "统计图: 跳过(无训练数据)"
        df = load_df()
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not _viz_is_coordinate_column(c)
        ]
        wl = _viz_infer_wavelength_start_column(df)
        idx = _viz_spectral_start_index(df, wl)
        if 0 < idx < len(df.columns):
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]
        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"

        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = sum(1 for v in output_dict.values() if v) if output_dict else 0
        return f"统计图: {count} 个"

    def _summarize_sampling_map(self, wp: Path) -> str:
        """Render the sampling-point map when inputs exist; return the summary."""
        image = _viz_find_hyperspectral_image(wp)
        if image is None:
            return "采样点图: 跳过(无影像)"
        csv_file = _viz_find_sampling_csv(wp)
        if csv_file is None:
            return "采样点图: 跳过(无CSV)"
        map_path = self._create_sampling_map(wp, str(image), str(csv_file))
        return f"采样点图: {Path(map_path).name}"