diff --git a/src/gui/core/viz_thread.py b/src/gui/core/viz_thread.py index ca657e9..fc1ac06 100644 --- a/src/gui/core/viz_thread.py +++ b/src/gui/core/viz_thread.py @@ -209,11 +209,19 @@ class VisualizationWorkerThread(QThread): viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization")) parts = [] - training_csv = wp / "5_training_spectra" / "training_spectra.csv" + training_csv_path = (self.extra.get("training_csv_path") or "").strip() + if training_csv_path: + training_csv = Path(training_csv_path) + else: + training_csv = wp / "5_training_spectra" / "training_spectra.csv" if self.extra.get("gen_scatter"): if training_csv.is_file(): - models_dir = wp / "7_Supervised_Model_Training" + models_dir_str = (self.extra.get("models_dir") or "").strip() + if models_dir_str: + models_dir = Path(models_dir_str) + else: + models_dir = wp / "7_Supervised_Model_Training" if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()): from src.core.visualization.scatter_plot import generate_model_scatter_plots scatter_paths = generate_model_scatter_plots(