refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名
This commit is contained in:
346
src/gui/core/viz_thread.py
Normal file
346
src/gui/core/viz_thread.py
Normal file
@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化后台线程模块
|
||||
|
||||
包含 VisualizationWorkerThread 后台线程类和辅助函数。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
|
||||
"""推断光谱起始列(training_spectra 通常以波长数值为列名,未必含 UTM_Y)。"""
|
||||
import pandas as pd
|
||||
for i, col in enumerate(df.columns):
|
||||
name = str(col).strip().lstrip("\ufeff")
|
||||
try:
|
||||
v = float(name)
|
||||
except ValueError:
|
||||
continue
|
||||
if 200.0 <= v <= 3000.0:
|
||||
return i
|
||||
if "UTM_Y" in df.columns:
|
||||
return "UTM_Y"
|
||||
return 0
|
||||
|
||||
|
||||
class VisualizationWorkerThread(QThread):
|
||||
"""可视化耗时计算放入后台线程,并临时使用 Agg 后端,避免主界面未响应。"""
|
||||
|
||||
finished_ok = pyqtSignal(object)
|
||||
failed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
|
||||
super().__init__()
|
||||
self.task = task
|
||||
self.work_dir = str(work_dir)
|
||||
self.extra = extra or {}
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
wp = Path(self.work_dir)
|
||||
if self.task == "mask_glint":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
cnt = len(preview_paths) if preview_paths else 0
|
||||
self.finished_ok.emit({"task": "mask_glint", "count": cnt, "preview_paths": preview_paths})
|
||||
elif self.task == "sampling_map":
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
if deglint_dir.exists():
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||||
if not hyperspectral_files:
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
|
||||
if not hyperspectral_files:
|
||||
self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。")
|
||||
return
|
||||
hyperspectral_path = str(hyperspectral_files[0])
|
||||
csv_files = []
|
||||
processed_dir = wp / "4_processed_data"
|
||||
if processed_dir.exists():
|
||||
csv_files = list(processed_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
csv_files = (
|
||||
list(wp.glob("**/*sampling*.csv"))
|
||||
+ list(wp.glob("**/*point*.csv"))
|
||||
+ list(wp.glob("**/*.csv"))
|
||||
)
|
||||
if not csv_files:
|
||||
self.failed.emit("未找到采样点 CSV 文件。")
|
||||
return
|
||||
csv_path = str(csv_files[0])
|
||||
from src.postprocessing.point_map import SamplingPointMap
|
||||
map_generator = SamplingPointMap(
|
||||
output_dir=str(wp / "14_visualization" / "sampling_maps"),
|
||||
fast_mode=True,
|
||||
)
|
||||
map_path = map_generator.create_sampling_point_map(
|
||||
hyperspectral_path=hyperspectral_path,
|
||||
csv_path=csv_path,
|
||||
point_color="red",
|
||||
point_size=100,
|
||||
point_alpha=0.9,
|
||||
show_north_arrow=True,
|
||||
show_scale_bar=True,
|
||||
show_legend=True,
|
||||
downsample=True,
|
||||
dpi=180,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{
|
||||
"task": "sampling_map",
|
||||
"map_path": map_path,
|
||||
"hyperspectral_path": hyperspectral_path,
|
||||
"csv_path": csv_path,
|
||||
}
|
||||
)
|
||||
elif self.task == "spectrum":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
csv_file = self.extra.get("csv_path")
|
||||
wl = self.extra.get("wavelength_start_column", "UTM_Y")
|
||||
n_groups = int(self.extra.get("n_groups", 5))
|
||||
param_cols = self.extra.get("param_cols") or []
|
||||
if param_cols:
|
||||
output_paths: List[str] = []
|
||||
err_lines: List[str] = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
out = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(csv_file),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl,
|
||||
n_groups=n_groups,
|
||||
)
|
||||
output_paths.append(out)
|
||||
except Exception as _ex:
|
||||
err_lines.append(f"{param_col}: {_ex}")
|
||||
if not output_paths:
|
||||
self.failed.emit(
|
||||
"所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
|
||||
)
|
||||
return
|
||||
self.finished_ok.emit(
|
||||
{
|
||||
"task": "spectrum",
|
||||
"output_paths": output_paths,
|
||||
"errors": err_lines,
|
||||
}
|
||||
)
|
||||
else:
|
||||
param_col = self.extra.get("param_col")
|
||||
out = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(csv_file),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl,
|
||||
n_groups=n_groups,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{"task": "spectrum", "output_path": out, "param_col": param_col}
|
||||
)
|
||||
elif self.task == "statistics":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
csv_file = self.extra.get("csv_path")
|
||||
param_cols = self.extra.get("param_cols") or []
|
||||
output_paths = viz.plot_statistical_charts(
|
||||
csv_path=str(csv_file),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
self.finished_ok.emit(
|
||||
{"task": "statistics", "output_paths": output_paths}
|
||||
)
|
||||
elif self.task == "scatter":
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
|
||||
training_csv_path = (self.extra.get("training_csv_path") or "").strip()
|
||||
models_dir = (self.extra.get("models_dir") or "").strip()
|
||||
if not training_csv_path or not Path(training_csv_path).is_file():
|
||||
self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
|
||||
return
|
||||
if not models_dir or not Path(models_dir).is_dir():
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
|
||||
return
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=training_csv_path,
|
||||
models_dir=models_dir,
|
||||
)
|
||||
self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})
|
||||
elif self.task == "generate_all_selected":
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
parts = []
|
||||
|
||||
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
|
||||
|
||||
if self.extra.get("gen_scatter"):
|
||||
if training_csv.is_file():
|
||||
models_dir = wp / "7_Supervised_Model_Training"
|
||||
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=str(training_csv),
|
||||
models_dir=str(models_dir),
|
||||
)
|
||||
count = len(scatter_paths) if scatter_paths else 0
|
||||
parts.append(f"散点图: {count} 个")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无模型目录)")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_spectrum"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
wl_col = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl_col, str):
|
||||
idx = int(df.columns.get_loc(wl_col)) + 1
|
||||
else:
|
||||
idx = int(wl_col)
|
||||
param_cols = []
|
||||
if idx > 0 and idx < len(df.columns):
|
||||
param_cols = [
|
||||
c for c in df.columns[:idx]
|
||||
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
|
||||
]
|
||||
if param_cols:
|
||||
spectrum_paths = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
path = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(training_csv),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl_col,
|
||||
n_groups=5,
|
||||
)
|
||||
if path:
|
||||
spectrum_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"生成光谱图失败 ({param_col}): {e}")
|
||||
count = len(spectrum_paths)
|
||||
parts.append(f"光谱图: {count} 个")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无可用参数列)")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_boxplots"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
|
||||
param_cols = [
|
||||
c for c in df.select_dtypes(include=[np.number]).columns
|
||||
if not any(exc in c.lower() for exc in exclude_cols)
|
||||
]
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
idx = int(df.columns.get_loc(wl)) + 1
|
||||
else:
|
||||
idx = int(wl)
|
||||
if 0 < idx < len(df.columns):
|
||||
meta_set = set(df.columns[:idx])
|
||||
param_cols = [c for c in param_cols if c in meta_set]
|
||||
|
||||
if param_cols:
|
||||
output_dict = viz.plot_statistical_charts(
|
||||
csv_path=str(training_csv),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
count = len([v for v in output_dict.values() if v]) if output_dict else 0
|
||||
parts.append(f"统计图: {count} 个")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无可用水质参数列)")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无训练数据)")
|
||||
|
||||
if self.extra.get("gen_mask_glint"):
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
|
||||
|
||||
if self.extra.get("gen_sampling_map"):
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
if deglint_dir.exists():
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(deglint_dir.glob(ext)))
|
||||
if not hyperspectral_files:
|
||||
for ext in ("*.dat", "*.bsq", "*.tif", "*.tiff"):
|
||||
hyperspectral_files.extend(list(wp.glob(f"**/{ext}")))
|
||||
if hyperspectral_files:
|
||||
hyperspectral_path = str(hyperspectral_files[0])
|
||||
csv_files = []
|
||||
processed_dir = wp / "4_processed_data"
|
||||
if processed_dir.exists():
|
||||
csv_files = list(processed_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
csv_files = (
|
||||
list(wp.glob("**/*sampling*.csv"))
|
||||
+ list(wp.glob("**/*point*.csv"))
|
||||
+ list(wp.glob("**/*.csv"))
|
||||
)
|
||||
if csv_files:
|
||||
csv_path = str(csv_files[0])
|
||||
from src.postprocessing.point_map import SamplingPointMap
|
||||
map_generator = SamplingPointMap(
|
||||
output_dir=str(wp / "14_visualization" / "sampling_maps"),
|
||||
fast_mode=True,
|
||||
)
|
||||
map_path = map_generator.create_sampling_point_map(
|
||||
hyperspectral_path=hyperspectral_path,
|
||||
csv_path=csv_path,
|
||||
point_color="red",
|
||||
point_size=100,
|
||||
point_alpha=0.9,
|
||||
show_north_arrow=True,
|
||||
show_scale_bar=True,
|
||||
show_legend=True,
|
||||
downsample=True,
|
||||
dpi=180,
|
||||
)
|
||||
parts.append(f"采样点图: {Path(map_path).name}")
|
||||
else:
|
||||
parts.append("采样点图: 跳过(无CSV)")
|
||||
else:
|
||||
parts.append("采样点图: 跳过(无影像)")
|
||||
self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})
|
||||
else:
|
||||
self.failed.emit(f"未知可视化任务: {self.task}")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
Reference in New Issue
Block a user