refactor(pipeline): 路径直接传输 — 统一 ctx 字段名/panel key/step 形参名

This commit is contained in:
DXC
2026-06-03 17:29:41 +08:00
parent 517bb28611
commit 343e316799
99 changed files with 9127 additions and 91 deletions

346
src/gui/core/viz_thread.py Normal file
View File

@ -0,0 +1,346 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
可视化后台线程模块
包含 VisualizationWorkerThread 后台线程类和辅助函数。
"""
from pathlib import Path
from typing import Optional, List, Union
from PyQt5.QtCore import QThread, pyqtSignal
import numpy as np
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
    """Infer where the spectral columns start in a training-spectra table.

    Column names are scanned left to right. The first name that parses as a
    number in the 200-3000 range (after stripping surrounding whitespace and a
    leading byte-order mark) is treated as the first wavelength column.

    Args:
        df: Any object exposing a ``columns`` iterable (e.g. a pandas
            DataFrame).

    Returns:
        * ``int`` - positional index of the first wavelength-like column;
        * ``"UTM_Y"`` - when no wavelength-like column exists but a column
          named ``UTM_Y`` does (callers treat the spectrum as starting right
          after it);
        * ``0`` - when neither is found.
    """
    for position, column in enumerate(df.columns):
        # A CSV saved with a BOM can leave "\ufeff" glued to the first header.
        label = str(column).strip().lstrip("\ufeff")
        try:
            wavelength = float(label)
        except ValueError:
            continue
        # "nan"/"inf" parse as floats but never satisfy the range check.
        if 200.0 <= wavelength <= 3000.0:
            return position
    if "UTM_Y" in df.columns:
        return "UTM_Y"
    return 0
class VisualizationWorkerThread(QThread):
    """Run time-consuming visualization jobs on a background thread.

    While a job runs, matplotlib is temporarily switched to the ``Agg``
    backend so that plotting does not make the main window unresponsive; the
    previous backend is restored afterwards on a best-effort basis.

    Supported ``task`` values and the ``extra`` keys they read:

    * ``mask_glint`` - no extra keys.
    * ``sampling_map`` - no extra keys; inputs are discovered under
      ``work_dir``.
    * ``spectrum`` - ``csv_path`` (required), ``wavelength_start_column``
      (default ``"UTM_Y"``), ``n_groups`` (default 5), and either
      ``param_cols`` (list) or ``param_col`` (single column).
    * ``statistics`` - ``csv_path`` (required), ``param_cols``.
    * ``scatter`` - ``training_csv_path``, ``models_dir``.
    * ``generate_all_selected`` - boolean flags ``gen_scatter``,
      ``gen_spectrum``, ``gen_boxplots``, ``gen_mask_glint``,
      ``gen_sampling_map``.

    Signals:
        finished_ok(object): emitted with a result dict whose ``"task"`` key
            names the completed task.
        failed(str): emitted with a human-readable error message (including a
            traceback when an unexpected exception was caught).
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    # Glob patterns tried, in priority order, when looking for an image.
    _IMAGE_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")
    # Short coordinate keywords must match a whole name token; matching them
    # as substrings would discard columns such as "Turbidity" or "Oxygen".
    _COORD_EXACT_TOKENS = frozenset({"x", "y", "lon", "lat"})
    # Long keywords are unambiguous enough to keep substring matching.
    _COORD_SUBSTRINGS = ("longitude", "latitude", "coord")

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        """Store the job description; no work happens until the thread runs.

        Args:
            task: Name of the visualization task (see the class docstring).
            work_dir: Pipeline working directory containing the step folders.
            extra: Optional task-specific options; ``None`` means no options.
        """
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        """Thread entry point: dispatch ``self.task`` and report via signals."""
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            # matplotlib unavailable: nothing to remember or restore.
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            # The switch did not happen, so do not attempt a restore later.
            mpl_prev = None
        try:
            wp = Path(self.work_dir)
            handlers = {
                "mask_glint": self._run_mask_glint,
                "sampling_map": self._run_sampling_map,
                "spectrum": self._run_spectrum,
                "statistics": self._run_statistics,
                "scatter": self._run_scatter,
                "generate_all_selected": self._run_generate_all_selected,
            }
            handler = handlers.get(self.task) if isinstance(self.task, str) else None
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(wp)
        except Exception as e:
            import traceback
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    # Best effort only; a failed restore must not mask results.
                    pass

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _new_visualizer(wp: Path):
        """Create the report visualizer writing into ``14_visualization``."""
        from src.postprocessing.visualization_reports import WaterQualityVisualization
        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @staticmethod
    def _first_match(directory: Path, patterns) -> Optional[Path]:
        """Return the first path matching the first pattern that has a hit.

        Equivalent to concatenating every pattern's full glob result and
        taking element 0, but stops as soon as a hit is found.
        """
        for pattern in patterns:
            for candidate in directory.glob(pattern):
                return candidate
        return None

    def _find_hyperspectral_file(self, wp: Path) -> Optional[Path]:
        """Locate an image, preferring ``3_deglint`` over a recursive search."""
        found = None
        deglint_dir = wp / "3_deglint"
        if deglint_dir.exists():
            found = self._first_match(deglint_dir, self._IMAGE_PATTERNS)
        if found is None:
            found = self._first_match(
                wp, [f"**/{pattern}" for pattern in self._IMAGE_PATTERNS]
            )
        return found

    def _find_sampling_csv(self, wp: Path) -> Optional[Path]:
        """Locate a sampling-point CSV, preferring ``4_processed_data``."""
        found = None
        processed_dir = wp / "4_processed_data"
        if processed_dir.exists():
            found = self._first_match(processed_dir, ("*.csv",))
        if found is None:
            # Fall back from the most specific file name to any CSV at all.
            found = self._first_match(
                wp, ("**/*sampling*.csv", "**/*point*.csv", "**/*.csv")
            )
        return found

    @staticmethod
    def _render_sampling_map(wp: Path, hyperspectral_path: str, csv_path: str):
        """Render the sampling-point map and return whatever path it reports."""
        from src.postprocessing.point_map import SamplingPointMap
        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=hyperspectral_path,
            csv_path=csv_path,
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    def _require_csv_path(self) -> Optional[str]:
        """Return ``extra['csv_path']`` as ``str``, or emit ``failed``.

        Without this check a missing path would be passed on as the literal
        string ``"None"`` and fail later with a misleading error.
        """
        csv_file = self.extra.get("csv_path")
        if not csv_file:
            self.failed.emit("未指定 CSV 文件(缺少 csv_path)。")
            return None
        return str(csv_file)

    @classmethod
    def _is_coordinate_column(cls, name) -> bool:
        """Tell whether a column name denotes a coordinate, not a parameter."""
        import re
        lowered = str(name).lower()
        if any(keyword in lowered for keyword in cls._COORD_SUBSTRINGS):
            return True
        tokens = re.split(r"[^a-z0-9]+", lowered)
        return any(token in cls._COORD_EXACT_TOKENS for token in tokens)

    @staticmethod
    def _spectrum_start(df):
        """Return ``(wavelength_start_column, metadata_column_count)``."""
        wl_col = _viz_infer_wavelength_start_column(df)
        if isinstance(wl_col, str):
            # A named anchor column is itself metadata, hence the +1.
            idx = int(df.columns.get_loc(wl_col)) + 1
        else:
            idx = int(wl_col)
        return wl_col, idx

    # ------------------------------------------------------------------
    # Single-task handlers (each emits exactly one signal)
    # ------------------------------------------------------------------
    def _run_mask_glint(self, wp: Path) -> None:
        """Generate glint/deglint preview images."""
        viz = self._new_visualizer(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit(
            {"task": "mask_glint", "count": cnt, "preview_paths": preview_paths}
        )

    def _run_sampling_map(self, wp: Path) -> None:
        """Generate the sampling-point map from discovered inputs."""
        image = self._find_hyperspectral_file(wp)
        if image is None:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)")
            return
        csv_file = self._find_sampling_csv(wp)
        if csv_file is None:
            self.failed.emit("未找到采样点 CSV 文件。")
            return
        hyperspectral_path = str(image)
        csv_path = str(csv_file)
        map_path = self._render_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": hyperspectral_path,
                "csv_path": csv_path,
            }
        )

    def _run_spectrum(self, wp: Path) -> None:
        """Plot spectra grouped by one or several parameter columns."""
        csv_file = self._require_csv_path()
        if csv_file is None:
            return
        viz = self._new_visualizer(wp)
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []

        if not param_cols:
            # Single-column mode: any error propagates to run()'s handler.
            param_col = self.extra.get("param_col")
            out = viz.plot_spectrum_by_parameter(
                csv_path=csv_file,
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )
            self.finished_ok.emit(
                {"task": "spectrum", "output_path": out, "param_col": param_col}
            )
            return

        # Multi-column mode: one bad column must not abort the others.
        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            try:
                out = viz.plot_spectrum_by_parameter(
                    csv_path=csv_file,
                    parameter_column=param_col,
                    wavelength_start_column=wl,
                    n_groups=n_groups,
                )
                output_paths.append(out)
            except Exception as _ex:
                err_lines.append(f"{param_col}: {_ex}")
        if not output_paths:
            self.failed.emit(
                "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
            )
            return
        self.finished_ok.emit(
            {
                "task": "spectrum",
                "output_paths": output_paths,
                "errors": err_lines,
            }
        )

    def _run_statistics(self, wp: Path) -> None:
        """Plot statistical charts for the requested parameter columns."""
        csv_file = self._require_csv_path()
        if csv_file is None:
            return
        viz = self._new_visualizer(wp)
        param_cols = self.extra.get("param_cols") or []
        output_paths = viz.plot_statistical_charts(
            csv_path=csv_file,
            parameter_columns=param_cols,
        )
        self.finished_ok.emit({"task": "statistics", "output_paths": output_paths})

    def _run_scatter(self, wp: Path) -> None:
        """Generate model scatter plots from explicit CSV / model paths."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=training_csv_path,
            models_dir=models_dir,
        )
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ------------------------------------------------------------------
    # "generate_all_selected": each step returns a one-line summary
    # ------------------------------------------------------------------
    def _run_generate_all_selected(self, wp: Path) -> None:
        """Run every chart type whose ``gen_*`` flag is set in ``extra``."""
        viz = self._new_visualizer(wp)
        parts = []
        training_csv = wp / "5_training_spectra" / "training_spectra.csv"

        if self.extra.get("gen_scatter"):
            parts.append(self._batch_scatter(wp, training_csv))

        want_spectrum = self.extra.get("gen_spectrum")
        want_boxplots = self.extra.get("gen_boxplots")
        training_df = None
        if (want_spectrum or want_boxplots) and training_csv.is_file():
            # Parse the table once and share it between both steps.
            import pandas as pd
            training_df = pd.read_csv(training_csv)

        if want_spectrum:
            parts.append(self._batch_spectrum(viz, training_csv, training_df))
        if want_boxplots:
            parts.append(self._batch_boxplots(viz, training_csv, training_df))
        if self.extra.get("gen_mask_glint"):
            preview_paths = viz.generate_glint_deglint_previews(
                work_dir=str(wp),
                output_subdir="glint_deglint_previews",
            )
            parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
        if self.extra.get("gen_sampling_map"):
            parts.append(self._batch_sampling_map(wp))

        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    @staticmethod
    def _batch_scatter(wp: Path, training_csv: Path) -> str:
        """Generate scatter plots from the default pipeline folders."""
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        has_models = models_dir.is_dir() and any(
            d.is_dir() for d in models_dir.iterdir()
        )
        if not has_models:
            return "散点图: 跳过(无模型目录)"
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=str(training_csv),
            models_dir=str(models_dir),
        )
        count = len(scatter_paths) if scatter_paths else 0
        return f"散点图: {count}"

    def _batch_spectrum(self, viz, training_csv: Path, df) -> str:
        """Plot spectra for every usable numeric metadata column."""
        if df is None:
            return "光谱图: 跳过(无训练数据)"
        wl_col, idx = self._spectrum_start(df)
        param_cols = []
        if 0 < idx < len(df.columns):
            # Only numeric columns with at least one non-missing value.
            param_cols = [
                c for c in df.columns[:idx]
                if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
            ]
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"
        spectrum_paths = []
        for param_col in param_cols:
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
                if path:
                    spectrum_paths.append(path)
            except Exception as e:
                # One failing column must not abort the remaining plots.
                print(f"生成光谱图失败 ({param_col}): {e}")
        return f"光谱图: {len(spectrum_paths)}"

    def _batch_boxplots(self, viz, training_csv: Path, df) -> str:
        """Plot statistical charts for numeric, non-coordinate columns."""
        if df is None:
            return "统计图: 跳过(无训练数据)"
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not self._is_coordinate_column(c)
        ]
        _, idx = self._spectrum_start(df)
        if 0 < idx < len(df.columns):
            # Keep only columns located before the spectral block.
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]
        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"
        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = len([v for v in output_dict.values() if v]) if output_dict else 0
        return f"统计图: {count}"

    def _batch_sampling_map(self, wp: Path) -> str:
        """Render the sampling-point map if both inputs can be found."""
        image = self._find_hyperspectral_file(wp)
        if image is None:
            return "采样点图: 跳过(无影像)"
        csv_file = self._find_sampling_csv(wp)
        if csv_file is None:
            return "采样点图: 跳过(无CSV)"
        map_path = self._render_sampling_map(wp, str(image), str(csv_file))
        if not map_path:
            # Defensive: Path(None) would abort the whole batch with TypeError.
            return "采样点图: 生成失败"
        return f"采样点图: {Path(map_path).name}"