Files
WQ_GUI/src/gui/core/viz_thread.py

346 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
可视化后台线程模块
包含 VisualizationWorkerThread 后台线程类和辅助函数。
"""
from pathlib import Path
from typing import Optional, List, Union
from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
    """Infer where the spectral (wavelength) columns start in a table.

    Training-spectra tables usually use the wavelength value itself as the
    column name and do not necessarily contain a ``UTM_Y`` column, so the
    column names are inspected directly.

    Args:
        df: Any object exposing an iterable ``columns`` attribute that
            supports ``in`` (e.g. a pandas ``DataFrame``).

    Returns:
        * the positional index (``int``) of the first column whose name
          parses as a float within ``[200, 3000]``; otherwise
        * the string ``"UTM_Y"`` if a column with that name exists; otherwise
        * ``0`` as a last-resort fallback.
    """
    for position, column in enumerate(df.columns):
        # A UTF-8 BOM can end up glued to the first header cell; drop it so
        # that a leading numeric column name still parses.
        label = str(column).strip().lstrip("\ufeff")
        try:
            value = float(label)
        except ValueError:
            continue
        if 200.0 <= value <= 3000.0:
            return position
    if "UTM_Y" in df.columns:
        return "UTM_Y"
    return 0
class VisualizationWorkerThread(QThread):
    """Run a time-consuming visualization task on a background thread.

    While the task runs, matplotlib is temporarily switched to the ``Agg``
    backend so that plotting does not block the main UI; the previous backend
    is restored afterwards (best effort).

    Signals:
        finished_ok(object): emitted with a ``dict`` whose ``"task"`` key
            echoes the task name, plus task-specific result keys.
        failed(str): emitted with an error message (and a traceback when an
            unexpected exception was raised).

    Supported ``task`` values: ``"mask_glint"``, ``"sampling_map"``,
    ``"spectrum"``, ``"statistics"``, ``"scatter"`` and
    ``"generate_all_selected"``. Any other value emits ``failed``.
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    # Glob patterns used to locate hyperspectral image files.
    _IMAGE_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")
    # Long markers are unambiguous, so a plain substring match is safe.
    _COORD_SUBSTRINGS = ("longitude", "latitude", "coord")
    # Short markers must match a whole alphabetic token; a substring match on
    # "x"/"y" would also discard names such as "Turbidity" or "Oxygen".
    _COORD_TOKENS = frozenset({"lon", "lat", "x", "y"})

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        """Store the task description.

        Args:
            task: Name of the visualization task to run.
            work_dir: Root working directory of the project run.
            extra: Optional task-specific options; ``None`` becomes ``{}``.
        """
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        """Thread entry point: dispatch ``self.task`` and report via signals."""
        previous_backend = self._switch_to_agg()
        try:
            handlers = {
                "mask_glint": self._run_mask_glint,
                "sampling_map": self._run_sampling_map,
                "spectrum": self._run_spectrum,
                "statistics": self._run_statistics,
                "scatter": self._run_scatter,
                "generate_all_selected": self._run_generate_all_selected,
            }
            handler = handlers.get(self.task)
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(Path(self.work_dir))
        except Exception as e:
            # Top-level boundary of the worker thread: surface every failure
            # to the UI instead of letting the thread die silently.
            import traceback
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            self._restore_backend(previous_backend)

    # ------------------------------------------------------------------
    # matplotlib backend handling
    # ------------------------------------------------------------------
    @staticmethod
    def _switch_to_agg():
        """Switch matplotlib to ``Agg``; return the previous backend or None.

        ``None`` is returned when the previous backend could not be read or
        the switch itself failed, in which case nothing is restored later.
        """
        previous = None
        try:
            import matplotlib
            previous = matplotlib.get_backend()
        except Exception:
            # Best effort: without a known previous backend nothing is restored.
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            previous = None
        return previous

    @staticmethod
    def _restore_backend(previous):
        """Switch matplotlib back to ``previous`` if one was recorded."""
        if not previous:
            return
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend(previous)
        except Exception:
            # Deliberate best effort: a failed restore must not mask the
            # task result that has already been emitted.
            pass

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_viz(wp: Path):
        """Create the project's ``WaterQualityVisualization`` for ``wp``."""
        # Imported lazily, as in the per-task code this helper replaces.
        from src.postprocessing.visualization_reports import WaterQualityVisualization
        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @classmethod
    def _find_hyperspectral_files(cls, wp: Path) -> List[Path]:
        """Find image files in ``3_deglint``; else search ``wp`` recursively."""
        found: List[Path] = []
        deglint_dir = wp / "3_deglint"
        if deglint_dir.exists():
            for pattern in cls._IMAGE_PATTERNS:
                found.extend(deglint_dir.glob(pattern))
        if not found:
            for pattern in cls._IMAGE_PATTERNS:
                found.extend(wp.glob(f"**/{pattern}"))
        return found

    @staticmethod
    def _find_sampling_csv_files(wp: Path) -> List[Path]:
        """Find CSVs in ``4_processed_data``; else search ``wp`` recursively.

        The recursive fallback lists ``*sampling*`` matches first, then
        ``*point*`` matches, then every CSV, so callers taking the first
        element prefer the more specific names.
        """
        found: List[Path] = []
        processed_dir = wp / "4_processed_data"
        if processed_dir.exists():
            found = list(processed_dir.glob("*.csv"))
        if not found:
            found = (
                list(wp.glob("**/*sampling*.csv"))
                + list(wp.glob("**/*point*.csv"))
                + list(wp.glob("**/*.csv"))
            )
        return found

    @staticmethod
    def _create_sampling_map(wp: Path, hyperspectral_path: str, csv_path: str):
        """Render the sampling-point map and return what the generator returns."""
        from src.postprocessing.point_map import SamplingPointMap
        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=hyperspectral_path,
            csv_path=csv_path,
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    @staticmethod
    def _metadata_end_index(df, wl_col: Union[str, int]) -> int:
        """Return the positional index at which spectral columns begin.

        A string marker (e.g. ``"UTM_Y"``) is itself a metadata column, so
        the spectra start one position after it; an int is used as is.
        """
        if isinstance(wl_col, str):
            return int(df.columns.get_loc(wl_col)) + 1
        return int(wl_col)

    @classmethod
    def _is_coordinate_column(cls, name) -> bool:
        """Tell whether a column name looks like a coordinate column.

        Long markers match as substrings. Short markers (``lon``, ``lat``,
        ``x``, ``y``) only match when they form a complete run of ASCII
        letters, e.g. ``UTM_X`` or ``x1`` but not ``Oxygen``.
        """
        lowered = str(name).lower()
        if any(marker in lowered for marker in cls._COORD_SUBSTRINGS):
            return True
        # Replace every non a-z character by a space, then split into tokens.
        tokens = "".join(ch if "a" <= ch <= "z" else " " for ch in lowered).split()
        return any(token in cls._COORD_TOKENS for token in tokens)

    # ------------------------------------------------------------------
    # single-task handlers
    # ------------------------------------------------------------------
    def _run_mask_glint(self, wp: Path):
        """Generate mask/glint previews and emit their count and paths."""
        viz = self._make_viz(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})

    def _run_sampling_map(self, wp: Path):
        """Locate an image and a CSV, render the sampling map, emit paths."""
        images = self._find_hyperspectral_files(wp)
        if not images:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif")
            return
        hyperspectral_path = str(images[0])
        csv_files = self._find_sampling_csv_files(wp)
        if not csv_files:
            self.failed.emit("未找到采样点 CSV 文件。")
            return
        csv_path = str(csv_files[0])
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": hyperspectral_path,
                "csv_path": csv_path,
            }
        )

    def _run_spectrum(self, wp: Path):
        """Plot spectra grouped by one parameter or by each of several.

        Reads ``csv_path``, ``wavelength_start_column`` (default
        ``"UTM_Y"``), ``n_groups`` (default 5) and either ``param_cols``
        (list) or ``param_col`` (single) from ``self.extra``.
        """
        viz = self._make_viz(wp)
        csv_file = self.extra.get("csv_path")
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []

        if not param_cols:
            param_col = self.extra.get("param_col")
            out = viz.plot_spectrum_by_parameter(
                csv_path=str(csv_file),
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )
            self.finished_ok.emit(
                {"task": "spectrum", "output_path": out, "param_col": param_col}
            )
            return

        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            # One failing column must not abort the remaining ones.
            try:
                out = viz.plot_spectrum_by_parameter(
                    csv_path=str(csv_file),
                    parameter_column=param_col,
                    wavelength_start_column=wl,
                    n_groups=n_groups,
                )
                output_paths.append(out)
            except Exception as _ex:
                err_lines.append(f"{param_col}: {_ex}")
        if not output_paths:
            # Cap the report at 20 lines to keep the message readable.
            self.failed.emit(
                "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
            )
            return
        self.finished_ok.emit(
            {
                "task": "spectrum",
                "output_paths": output_paths,
                "errors": err_lines,
            }
        )

    def _run_statistics(self, wp: Path):
        """Plot statistical charts for the columns given in ``self.extra``."""
        viz = self._make_viz(wp)
        csv_file = self.extra.get("csv_path")
        param_cols = self.extra.get("param_cols") or []
        output_paths = viz.plot_statistical_charts(
            csv_path=str(csv_file),
            parameter_columns=param_cols,
        )
        self.finished_ok.emit(
            {"task": "statistics", "output_paths": output_paths}
        )

    def _run_scatter(self, wp: Path):
        """Validate the inputs in ``self.extra`` and build model scatter plots."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=training_csv_path,
            models_dir=models_dir,
        )
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ------------------------------------------------------------------
    # "generate_all_selected" and its per-section summaries
    # ------------------------------------------------------------------
    def _run_generate_all_selected(self, wp: Path):
        """Run every section enabled in ``self.extra`` and emit summaries.

        Sections are toggled by the ``gen_scatter``, ``gen_spectrum``,
        ``gen_boxplots``, ``gen_mask_glint`` and ``gen_sampling_map`` flags;
        each enabled section contributes one summary line to ``"parts"``.
        """
        viz = self._make_viz(wp)
        training_csv = wp / "5_training_spectra" / "training_spectra.csv"
        parts = []
        if self.extra.get("gen_scatter"):
            parts.append(self._summarize_scatter(wp, training_csv))
        if self.extra.get("gen_spectrum"):
            parts.append(self._summarize_spectrum(viz, training_csv))
        if self.extra.get("gen_boxplots"):
            parts.append(self._summarize_boxplots(viz, training_csv))
        if self.extra.get("gen_mask_glint"):
            parts.append(self._summarize_mask_glint(viz, wp))
        if self.extra.get("gen_sampling_map"):
            parts.append(self._summarize_sampling_map(wp))
        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    @staticmethod
    def _summarize_scatter(wp: Path, training_csv: Path) -> str:
        """Build scatter plots when training data and model folders exist."""
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        has_models = models_dir.is_dir() and any(
            d.is_dir() for d in models_dir.iterdir()
        )
        if not has_models:
            return "散点图: 跳过(无模型目录)"
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=str(training_csv),
            models_dir=str(models_dir),
        )
        count = len(scatter_paths) if scatter_paths else 0
        return f"散点图: {count}"

    def _summarize_spectrum(self, viz, training_csv: Path) -> str:
        """Plot spectra for each usable numeric column before the spectra."""
        if not training_csv.is_file():
            return "光谱图: 跳过(无训练数据)"
        import pandas as pd
        df = pd.read_csv(training_csv)
        wl_col = _viz_infer_wavelength_start_column(df)
        idx = self._metadata_end_index(df, wl_col)
        param_cols = []
        if 0 < idx < len(df.columns):
            # Keep int/uint/float columns that hold at least one value.
            param_cols = [
                c for c in df.columns[:idx]
                if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
            ]
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"
        spectrum_paths = []
        for param_col in param_cols:
            # A failure for one column is logged and the loop continues.
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
                if path:
                    spectrum_paths.append(path)
            except Exception as e:
                print(f"生成光谱图失败 ({param_col}): {e}")
        count = len(spectrum_paths)
        return f"光谱图: {count}"

    def _summarize_boxplots(self, viz, training_csv: Path) -> str:
        """Plot statistical charts for numeric, non-coordinate columns."""
        if not training_csv.is_file():
            return "统计图: 跳过(无训练数据)"
        import pandas as pd
        df = pd.read_csv(training_csv)
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not self._is_coordinate_column(c)
        ]
        wl = _viz_infer_wavelength_start_column(df)
        idx = self._metadata_end_index(df, wl)
        if 0 < idx < len(df.columns):
            # Restrict to the columns located before the spectral block.
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]
        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"
        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = len([v for v in output_dict.values() if v]) if output_dict else 0
        return f"统计图: {count}"

    @staticmethod
    def _summarize_mask_glint(viz, wp: Path) -> str:
        """Generate mask/glint previews and summarize how many were made."""
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        return f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}"

    def _summarize_sampling_map(self, wp: Path) -> str:
        """Render the sampling map when both an image and a CSV are found."""
        images = self._find_hyperspectral_files(wp)
        if not images:
            return "采样点图: 跳过(无影像)"
        csv_files = self._find_sampling_csv_files(wp)
        if not csv_files:
            return "采样点图: 跳过(无CSV)"
        map_path = self._create_sampling_map(wp, str(images[0]), str(csv_files[0]))
        return f"采样点图: {Path(map_path).name}"