346 lines
17 KiB
Python
346 lines
17 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
可视化后台线程模块
|
||
|
||
包含 VisualizationWorkerThread 后台线程类和辅助函数。
|
||
"""
|
||
|
||
from pathlib import Path
|
||
from typing import Optional, List, Union
|
||
|
||
from PyQt5.QtCore import QThread, pyqtSignal
|
||
import numpy as np
|
||
|
||
|
||
def _viz_infer_wavelength_start_column(df) -> Union[str, int]:
|
||
"""推断光谱起始列(training_spectra 通常以波长数值为列名,未必含 UTM_Y)。"""
|
||
import pandas as pd
|
||
for i, col in enumerate(df.columns):
|
||
name = str(col).strip().lstrip("\ufeff")
|
||
try:
|
||
v = float(name)
|
||
except ValueError:
|
||
continue
|
||
if 200.0 <= v <= 3000.0:
|
||
return i
|
||
if "UTM_Y" in df.columns:
|
||
return "UTM_Y"
|
||
return 0
|
||
|
||
|
||
class VisualizationWorkerThread(QThread):
    """Run time-consuming visualization jobs in a background thread.

    While a job runs, matplotlib is temporarily switched to the "Agg"
    backend so that the main window does not become unresponsive; the
    previous backend is restored afterwards on a best-effort basis.

    Signals:
        finished_ok(object): emitted with a ``dict`` whose ``"task"`` key
            names the finished task, plus task-specific result fields.
        failed(str): emitted with a human-readable error message (including
            a traceback for unexpected exceptions).

    Supported ``task`` values: ``"mask_glint"``, ``"sampling_map"``,
    ``"spectrum"``, ``"statistics"``, ``"scatter"`` and
    ``"generate_all_selected"``.
    """

    finished_ok = pyqtSignal(object)
    failed = pyqtSignal(str)

    # Image patterns, in priority order, used to locate a hyperspectral image.
    _RASTER_PATTERNS = ("*.dat", "*.bsq", "*.tif", "*.tiff")
    # Long keywords are unambiguous, so a substring match is safe for them.
    _COORD_SUBSTRINGS = ("longitude", "latitude", "coord")
    # Short keywords must match a whole token; as substrings, "x"/"y"/"lat"
    # would also hit names such as "Turbidity" or "Oxygen".
    _COORD_TOKENS = frozenset({"lon", "lat", "x", "y"})

    def __init__(self, task: str, work_dir: str, extra: Optional[dict] = None):
        """Store the job description.

        Args:
            task: Name of the visualization task to run (see class docstring).
            work_dir: Project working directory; stored as ``str``.
            extra: Optional task-specific options; ``None`` becomes ``{}``.
        """
        super().__init__()
        self.task = task
        self.work_dir = str(work_dir)
        self.extra = extra or {}

    def run(self):
        """Thread entry point: dispatch ``self.task`` to its handler.

        Every outcome is reported through ``finished_ok`` or ``failed``;
        unexpected exceptions are converted into a ``failed`` emission that
        carries the traceback.
        """
        mpl_prev = self._switch_to_agg_backend()
        try:
            handlers = {
                "mask_glint": self._run_mask_glint,
                "sampling_map": self._run_sampling_map,
                "spectrum": self._run_spectrum,
                "statistics": self._run_statistics,
                "scatter": self._run_scatter,
                "generate_all_selected": self._run_generate_all_selected,
            }
            handler = handlers.get(self.task)
            if handler is None:
                self.failed.emit(f"未知可视化任务: {self.task}")
            else:
                handler(Path(self.work_dir))
        except Exception as e:
            import traceback

            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt

                    plt.switch_backend(mpl_prev)
                except Exception:
                    # Best effort only: failing to restore the backend must
                    # not mask the task result that was already emitted.
                    pass

    # ------------------------------------------------------------------
    # matplotlib backend handling
    # ------------------------------------------------------------------
    @staticmethod
    def _switch_to_agg_backend() -> Optional[str]:
        """Switch pyplot to "Agg" and return the previous backend name.

        Returns ``None`` when the previous backend could not be determined
        or the switch failed, in which case nothing is restored later.
        """
        previous = None
        try:
            import matplotlib

            previous = matplotlib.get_backend()
        except Exception:
            # matplotlib unavailable or misconfigured; the task itself will
            # surface any real problem.
            pass
        try:
            import matplotlib.pyplot as plt

            plt.switch_backend("Agg")
        except Exception:
            return None
        return previous

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_visualizer(wp: Path):
        """Create a WaterQualityVisualization writing to ``14_visualization``."""
        # Imported lazily so the module is only loaded when a task needs it.
        from src.postprocessing.visualization_reports import WaterQualityVisualization

        return WaterQualityVisualization(output_dir=str(wp / "14_visualization"))

    @classmethod
    def _find_hyperspectral_file(cls, wp: Path) -> Optional[str]:
        """Return the first hyperspectral image path under ``wp`` or ``None``.

        ``3_deglint`` is searched first (non-recursively); if it yields
        nothing the whole work directory is searched recursively. Matches
        are sorted within each pattern because glob order is arbitrary.
        """
        candidates = []
        deglint_dir = wp / "3_deglint"
        if deglint_dir.exists():
            for pattern in cls._RASTER_PATTERNS:
                candidates.extend(sorted(deglint_dir.glob(pattern)))
        if not candidates:
            for pattern in cls._RASTER_PATTERNS:
                candidates.extend(sorted(wp.glob(f"**/{pattern}")))
        return str(candidates[0]) if candidates else None

    @staticmethod
    def _find_sampling_csv(wp: Path) -> Optional[str]:
        """Return the first sampling-point CSV path under ``wp`` or ``None``.

        ``4_processed_data/*.csv`` is preferred; otherwise the work directory
        is searched recursively, preferring names containing "sampling",
        then "point", then any CSV. Matches are sorted within each pattern.
        """
        csv_files = []
        processed_dir = wp / "4_processed_data"
        if processed_dir.exists():
            csv_files = sorted(processed_dir.glob("*.csv"))
        if not csv_files:
            csv_files = (
                sorted(wp.glob("**/*sampling*.csv"))
                + sorted(wp.glob("**/*point*.csv"))
                + sorted(wp.glob("**/*.csv"))
            )
        return str(csv_files[0]) if csv_files else None

    @staticmethod
    def _create_sampling_map(wp: Path, hyperspectral_path: str, csv_path: str):
        """Render the sampling-point map and return what the generator returns."""
        from src.postprocessing.point_map import SamplingPointMap

        map_generator = SamplingPointMap(
            output_dir=str(wp / "14_visualization" / "sampling_maps"),
            fast_mode=True,
        )
        return map_generator.create_sampling_point_map(
            hyperspectral_path=hyperspectral_path,
            csv_path=csv_path,
            point_color="red",
            point_size=100,
            point_alpha=0.9,
            show_north_arrow=True,
            show_scale_bar=True,
            show_legend=True,
            downsample=True,
            dpi=180,
        )

    @staticmethod
    def _split_metadata_columns(df):
        """Return ``(wl_col, idx)`` for a training-spectra frame.

        ``wl_col`` is the raw result of ``_viz_infer_wavelength_start_column``.
        ``idx`` is the number of leading columns treated as non-spectral:
        the position after a named marker column, or the wavelength column's
        own position when an integer was returned.
        """
        wl_col = _viz_infer_wavelength_start_column(df)
        if isinstance(wl_col, str):
            idx = int(df.columns.get_loc(wl_col)) + 1
        else:
            idx = int(wl_col)
        return wl_col, idx

    @classmethod
    def _is_coordinate_column(cls, name) -> bool:
        """Tell whether a column label looks like a coordinate column.

        Long keywords match as substrings; short ones ("x", "y", "lat",
        "lon") must match a whole alphanumeric token, so "UTM_X" matches
        while "Turbidity" does not.
        """
        lowered = str(name).lower()
        if any(key in lowered for key in cls._COORD_SUBSTRINGS):
            return True
        tokens = "".join(ch if ch.isalnum() else " " for ch in lowered).split()
        return any(token in cls._COORD_TOKENS for token in tokens)

    # ------------------------------------------------------------------
    # single-task handlers
    # ------------------------------------------------------------------
    def _run_mask_glint(self, wp: Path) -> None:
        """Generate glint/deglint preview images and report their paths."""
        viz = self._make_visualizer(wp)
        preview_paths = viz.generate_glint_deglint_previews(
            work_dir=str(wp),
            output_subdir="glint_deglint_previews",
        )
        cnt = len(preview_paths) if preview_paths else 0
        self.finished_ok.emit(
            {"task": "mask_glint", "count": cnt, "preview_paths": preview_paths}
        )

    def _run_sampling_map(self, wp: Path) -> None:
        """Locate an image and a CSV under ``wp`` and render the sampling map."""
        hyperspectral_path = self._find_hyperspectral_file(wp)
        if hyperspectral_path is None:
            self.failed.emit("未找到高光谱影像文件(.dat/.bsq/.tif)。")
            return
        csv_path = self._find_sampling_csv(wp)
        if csv_path is None:
            self.failed.emit("未找到采样点 CSV 文件。")
            return
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        self.finished_ok.emit(
            {
                "task": "sampling_map",
                "map_path": map_path,
                "hyperspectral_path": hyperspectral_path,
                "csv_path": csv_path,
            }
        )

    def _run_spectrum(self, wp: Path) -> None:
        """Plot spectra grouped by one parameter or by each of several.

        Reads ``csv_path``, ``wavelength_start_column`` (default "UTM_Y"),
        ``n_groups`` (default 5) and either ``param_cols`` (list) or
        ``param_col`` (single) from ``self.extra``.
        """
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        wl = self.extra.get("wavelength_start_column", "UTM_Y")
        n_groups = int(self.extra.get("n_groups", 5))
        param_cols = self.extra.get("param_cols") or []

        if not param_cols:
            param_col = self.extra.get("param_col")
            out = viz.plot_spectrum_by_parameter(
                csv_path=str(csv_file),
                parameter_column=param_col,
                wavelength_start_column=wl,
                n_groups=n_groups,
            )
            self.finished_ok.emit(
                {"task": "spectrum", "output_path": out, "param_col": param_col}
            )
            return

        output_paths: List[str] = []
        err_lines: List[str] = []
        for param_col in param_cols:
            # One failing column must not abort the remaining ones.
            try:
                out = viz.plot_spectrum_by_parameter(
                    csv_path=str(csv_file),
                    parameter_column=param_col,
                    wavelength_start_column=wl,
                    n_groups=n_groups,
                )
                output_paths.append(out)
            except Exception as _ex:
                err_lines.append(f"{param_col}: {_ex}")
        if not output_paths:
            # Cap the report at 20 lines to keep the message readable.
            self.failed.emit(
                "所有参数列的光谱图均生成失败:\n" + "\n".join(err_lines[:20])
            )
            return
        self.finished_ok.emit(
            {
                "task": "spectrum",
                "output_paths": output_paths,
                "errors": err_lines,
            }
        )

    def _run_statistics(self, wp: Path) -> None:
        """Plot statistical charts for ``extra["param_cols"]`` of ``extra["csv_path"]``."""
        viz = self._make_visualizer(wp)
        csv_file = self.extra.get("csv_path")
        param_cols = self.extra.get("param_cols") or []
        output_paths = viz.plot_statistical_charts(
            csv_path=str(csv_file),
            parameter_columns=param_cols,
        )
        self.finished_ok.emit({"task": "statistics", "output_paths": output_paths})

    def _run_scatter(self, wp: Path) -> None:
        """Generate model scatter plots after validating both input paths."""
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline

        training_csv_path = (self.extra.get("training_csv_path") or "").strip()
        models_dir = (self.extra.get("models_dir") or "").strip()
        if not training_csv_path or not Path(training_csv_path).is_file():
            self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
            return
        if not models_dir or not Path(models_dir).is_dir():
            self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
            return
        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=training_csv_path,
            models_dir=models_dir,
        )
        self.finished_ok.emit({"task": "scatter", "scatter_paths": scatter_paths or {}})

    # ------------------------------------------------------------------
    # "generate_all_selected" and its per-section helpers
    # ------------------------------------------------------------------
    def _run_generate_all_selected(self, wp: Path) -> None:
        """Run every section enabled via ``gen_*`` flags in ``self.extra``.

        Each enabled section contributes one summary line to ``parts``;
        missing inputs produce a "skipped" line instead of a failure.
        """
        viz = self._make_visualizer(wp)
        training_csv = wp / "5_training_spectra" / "training_spectra.csv"
        parts = []

        if self.extra.get("gen_scatter"):
            parts.append(self._summarize_scatter(wp, training_csv))
        if self.extra.get("gen_spectrum"):
            parts.append(self._summarize_spectrum(viz, training_csv))
        if self.extra.get("gen_boxplots"):
            parts.append(self._summarize_statistics(viz, training_csv))
        if self.extra.get("gen_mask_glint"):
            preview_paths = viz.generate_glint_deglint_previews(
                work_dir=str(wp),
                output_subdir="glint_deglint_previews",
            )
            parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
        if self.extra.get("gen_sampling_map"):
            parts.append(self._summarize_sampling_map(wp))

        self.finished_ok.emit({"task": "generate_all_selected", "parts": parts})

    @staticmethod
    def _summarize_scatter(wp: Path, training_csv: Path) -> str:
        """Generate scatter plots from the default folders; return a summary."""
        if not training_csv.is_file():
            return "散点图: 跳过(无训练数据)"
        models_dir = wp / "7_Supervised_Model_Training"
        # At least one sub-directory is required inside the models folder.
        if not (models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir())):
            return "散点图: 跳过(无模型目录)"
        from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline

        pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
        scatter_paths = pipeline.generate_model_scatter_plots(
            training_csv_path=str(training_csv),
            models_dir=str(models_dir),
        )
        count = len(scatter_paths) if scatter_paths else 0
        return f"散点图: {count} 个"

    def _summarize_spectrum(self, viz, training_csv: Path) -> str:
        """Plot spectra for each usable leading numeric column; return a summary."""
        if not training_csv.is_file():
            return "光谱图: 跳过(无训练数据)"
        import pandas as pd

        df = pd.read_csv(training_csv)
        wl_col, idx = self._split_metadata_columns(df)
        param_cols = []
        if 0 < idx < len(df.columns):
            # Numeric (int/uint/float) leading columns with at least one value.
            param_cols = [
                c for c in df.columns[:idx]
                if df[c].dtype.kind in "iuf" and df[c].notna().sum() > 0
            ]
        if not param_cols:
            return "光谱图: 跳过(无可用参数列)"

        spectrum_paths = []
        for param_col in param_cols:
            # A failure for one parameter is reported and skipped.
            try:
                path = viz.plot_spectrum_by_parameter(
                    csv_path=str(training_csv),
                    parameter_column=param_col,
                    wavelength_start_column=wl_col,
                    n_groups=5,
                )
                if path:
                    spectrum_paths.append(path)
            except Exception as e:
                print(f"生成光谱图失败 ({param_col}): {e}")
        count = len(spectrum_paths)
        return f"光谱图: {count} 个"

    def _summarize_statistics(self, viz, training_csv: Path) -> str:
        """Plot statistical charts for non-coordinate numeric columns."""
        if not training_csv.is_file():
            return "统计图: 跳过(无训练数据)"
        import pandas as pd

        df = pd.read_csv(training_csv)
        param_cols = [
            c for c in df.select_dtypes(include=[np.number]).columns
            if not self._is_coordinate_column(c)
        ]
        _, idx = self._split_metadata_columns(df)
        if 0 < idx < len(df.columns):
            # Keep only columns located before the spectral block.
            meta_set = set(df.columns[:idx])
            param_cols = [c for c in param_cols if c in meta_set]

        if not param_cols:
            return "统计图: 跳过(无可用水质参数列)"
        output_dict = viz.plot_statistical_charts(
            csv_path=str(training_csv),
            parameter_columns=param_cols,
        )
        count = len([v for v in output_dict.values() if v]) if output_dict else 0
        return f"统计图: {count} 个"

    def _summarize_sampling_map(self, wp: Path) -> str:
        """Render the sampling-point map if inputs exist; return a summary."""
        hyperspectral_path = self._find_hyperspectral_file(wp)
        if hyperspectral_path is None:
            return "采样点图: 跳过(无影像)"
        csv_path = self._find_sampling_csv(wp)
        if csv_path is None:
            return "采样点图: 跳过(无CSV)"
        map_path = self._create_sampling_map(wp, hyperspectral_path, csv_path)
        return f"采样点图: {Path(map_path).name}"