# -*- coding: utf-8 -*-
"""
Visualization module - scatter plot generation (true vs. predicted values).
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union

from src.core.prediction.inference_batch import WaterQualityInference
from src.postprocessing.visualization_reports import WaterQualityVisualization


def _generate_enhanced_scatter_plots(
    scatter_batch,
    models_dir: str,
    training_csv_path: str,
    output_dir: str,
    metric: str,
    feature_start_column: Union[str, int],
    test_size: float,
    random_state: int,
) -> Dict[str, str]:
    """Draw confidence-interval scatter plots via ``scatter_batch.batch_plot_scatter``.

    Any exception (missing optional module, plotting failure, malformed
    result) propagates to the caller, which decides whether to fall back.

    Returns:
        Mapping of target name -> ``save_path`` for targets whose result
        has ``status == 'success'``.
    """
    if scatter_batch is None:
        # Imported lazily (and only when no instance was supplied) so a
        # missing optional module only disables this path.
        from src.core.prediction.sctter_batch import WaterQualityScatterBatch
        scatter_batch = WaterQualityScatterBatch()

    results = scatter_batch.batch_plot_scatter(
        models_root_dir=models_dir,
        csv_path=training_csv_path,
        output_dir=output_dir,
        metric=metric,
        target_column=None,
        feature_start_column=feature_start_column,
        test_size=test_size,
        random_state=random_state,
    )

    paths: Dict[str, str] = {}
    for target_name, result in results.items():
        if result.get('status') == 'success':
            save_path = result.get('save_path', '')
            paths[target_name] = save_path
            print(f" ✓ {target_name}: {save_path}")
        else:
            print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")
    return paths


def _plot_basic_target(
    target_folder: Path,
    training_csv_path: str,
    visualizer,
    metric: str,
    test_size: float,
    random_state: int,
) -> Optional[str]:
    """Evaluate one target's model with an SPXY split and plot true vs. predicted.

    Returns:
        Whatever ``visualizer.plot_scatter_true_vs_pred`` returns, or None
        when the evaluation result lacks any of the four prediction arrays.
    """
    target_name = target_folder.name
    inferencer = WaterQualityInference(str(target_folder))
    eval_result = inferencer.evaluate_with_split(
        data_csv_path=training_csv_path,
        split_method="spxy",
        test_size=test_size,
        random_state=random_state,
        metric=metric,
    )

    predictions = eval_result.get('predictions') or {}
    y_train_true = predictions.get('y_train_true')
    y_train_pred = predictions.get('y_train_pred')
    y_test_true = predictions.get('y_test_true')
    y_test_pred = predictions.get('y_test_pred')
    # All four arrays are required; checking only the "true" arrays would let
    # a missing prediction array fail inside np.concatenate with an opaque error.
    if any(arr is None for arr in (y_train_true, y_train_pred, y_test_true, y_test_pred)):
        print(f"  跳过 {target_name}: 评估结果中缺少预测数据")
        return None

    # Train samples occupy positions [0, n_train), test samples the remainder.
    y_all_true = np.concatenate([y_train_true, y_test_true])
    y_all_pred = np.concatenate([y_train_pred, y_test_pred])
    n_train = len(y_train_true)
    train_indices = np.arange(n_train)
    test_indices = np.arange(n_train, len(y_all_true))

    train_metrics = eval_result.get('train_metrics') or {}
    test_metrics = eval_result.get('test_metrics') or {}
    return visualizer.plot_scatter_true_vs_pred(
        y_true=y_all_true,
        y_pred=y_all_pred,
        target_name=target_name,
        train_indices=train_indices,
        test_indices=test_indices,
        metrics={
            'train_r2': train_metrics.get('r2', 0),
            'test_r2': test_metrics.get('r2', 0),
            'train_rmse': train_metrics.get('rmse', 0),
            'test_rmse': test_metrics.get('rmse', 0),
        },
    )


def _generate_basic_scatter_plots(
    models_path: Path,
    training_csv_path: str,
    visualizer,
    metric: str,
    test_size: float,
    random_state: int,
) -> Dict[str, str]:
    """Plot every sub-directory of ``models_path`` with the basic visualizer.

    Each sub-directory is treated as one target's model folder. A failure on
    one target is printed and does not stop the remaining targets.
    """
    paths: Dict[str, str] = {}
    # Sorted so targets are processed (and logged) in a stable order;
    # Path.iterdir() order is filesystem-dependent.
    target_folders = sorted(p for p in models_path.iterdir() if p.is_dir())
    for target_folder in target_folders:
        target_name = target_folder.name
        print(f"\n处理目标参数: {target_name}")
        try:
            scatter_path = _plot_basic_target(
                target_folder,
                training_csv_path,
                visualizer,
                metric,
                test_size,
                random_state,
            )
        except Exception as e:
            # Best-effort per target: report and continue with the next one.
            print(f"处理目标参数 {target_name} 时出错: {e}")
            continue
        if scatter_path is not None:
            paths[target_name] = scatter_path
    return paths


def generate_model_scatter_plots(
    models_dir: str,
    training_csv_path: str,
    output_dir: Optional[str] = None,
    metric: str = 'test_r2',
    use_enhanced: bool = True,
    feature_start_column: Union[str, int] = 13,
    test_size: float = 0.2,
    random_state: int = 42,
    scatter_batch=None,  # optional: an already-instantiated scatter_batch object
) -> Dict[str, str]:
    """Generate model-evaluation scatter plots (true values vs. predictions).

    The enhanced plotter (with confidence intervals, from ``sctter_batch``) is
    tried first when ``use_enhanced`` is True. The basic plotter runs instead
    when enhanced mode is off, raises, or produces no successful plot.

    Args:
        models_dir: Directory containing one saved-model sub-folder per target.
        training_csv_path: Path to the training-data CSV.
        output_dir: Output directory; None selects
            ``<models_dir>/../14_visualization/scatter_plots``.
        metric: Metric used to select the best model.
        use_enhanced: Whether to try the enhanced scatter plots first.
        feature_start_column: Column name or index where features start
            (forwarded to the enhanced plotter only).
        test_size: Test-set fraction.
        random_state: Random seed.
        scatter_batch: Optional pre-built ``WaterQualityScatterBatch``; when
            given, the ``sctter_batch`` module is not imported.

    Returns:
        Mapping of target parameter name -> scatter plot file path.

    Raises:
        ValueError: If ``training_csv_path`` is missing/empty or ``models_dir``
            is not an existing directory.
    """
    print("\n" + "="*80)
    print("生成模型评估散点图")
    print("="*80)

    if not training_csv_path:
        raise ValueError("请提供 training_csv_path")

    models_path = Path(models_dir)
    # is_dir() rather than exists(): a file path would otherwise pass this
    # check and fail later inside iterdir().
    if not models_path.is_dir():
        raise ValueError(f"模型目录不存在: {models_dir}")

    if output_dir is None:
        output_dir = str(models_path.parent / "14_visualization" / "scatter_plots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Created up front (as before) even though only the basic path plots with
    # it. NOTE(review): kept eager in case its constructor sets up shared
    # plotting state the enhanced path relies on. Confirm, and make it lazy if not.
    visualizer = WaterQualityVisualization(output_dir)

    scatter_paths: Dict[str, str] = {}
    if use_enhanced:
        print("使用增强版散点图(带置信区间)")
        try:
            scatter_paths = _generate_enhanced_scatter_plots(
                scatter_batch,
                models_dir,
                training_csv_path,
                output_dir,
                metric,
                feature_start_column,
                test_size,
                random_state,
            )
        except Exception as e:
            # Deliberate best-effort: any enhanced-path failure falls back.
            print(f"使用增强版散点图时出错: {e}")
            print("回退到基础版散点图")

    # The basic plots run when enhanced mode is off, raised, or produced nothing.
    if not scatter_paths:
        print("使用基础版散点图")
        scatter_paths = _generate_basic_scatter_plots(
            models_path,
            training_csv_path,
            visualizer,
            metric,
            test_size,
            random_state,
        )

    print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
    return scatter_paths