147 lines
5.7 KiB
Python
147 lines
5.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
可视化模块 - 散点图生成
|
||
"""
|
||
import numpy as np
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
from typing import Optional, Dict, List, Union
|
||
from src.core.prediction.inference_batch import WaterQualityInference
|
||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||
|
||
|
||
def generate_model_scatter_plots(
    models_dir: str,
    training_csv_path: str,
    output_dir: Optional[str] = None,
    metric: str = 'test_r2',
    use_enhanced: bool = True,
    feature_start_column: Union[str, int] = 13,
    test_size: float = 0.2,
    random_state: int = 42,
    scatter_batch=None,  # optional: an already-instantiated scatter-batch object
) -> Dict[str, str]:
    """Generate model-evaluation scatter plots (true values vs. predictions).

    Two rendering paths exist. The enhanced path delegates the whole job to
    ``scatter_batch.batch_plot_scatter``. The basic path walks every
    sub-directory of ``models_dir``, evaluates it with
    ``WaterQualityInference.evaluate_with_split`` (split method ``"spxy"``)
    and draws the plot with
    ``WaterQualityVisualization.plot_scatter_true_vs_pred``. The basic path
    runs when ``use_enhanced`` is false, when the enhanced path raises, or
    when the enhanced path yields no successful plot.

    Args:
        models_dir: Directory holding the saved models; in the basic path each
            sub-directory is treated as one target parameter.
        training_csv_path: Path of the training-data CSV.
        output_dir: Where plots are written. ``None`` selects
            ``<parent of models_dir>/14_visualization/scatter_plots``.
        metric: Metric name forwarded to both back ends (described by the
            original author as the criterion for selecting the best model).
        use_enhanced: Try the enhanced plots first (described by the original
            author as including confidence intervals).
        feature_start_column: Name or index of the first feature column;
            only forwarded to the enhanced path.
        test_size: Test-set fraction forwarded to both back ends.
        random_state: Random seed forwarded to both back ends.
        scatter_batch: Optional object exposing ``batch_plot_scatter``. When
            ``None``, a ``WaterQualityScatterBatch`` is created on demand.

    Returns:
        Mapping from target-parameter name to the path of its scatter plot.

    Raises:
        ValueError: If ``training_csv_path`` is ``None``, or ``models_dir``
            does not exist or is not a directory.
    """
    print("\n" + "=" * 80)
    print("生成模型评估散点图")
    print("=" * 80)

    if training_csv_path is None:
        raise ValueError("请提供 training_csv_path")

    models_path = Path(models_dir)
    if not models_path.exists():
        raise ValueError(f"模型目录不存在: {models_dir}")
    if not models_path.is_dir():
        # Reject a regular file here: otherwise the output directory would be
        # created first and iterdir() would fail later with an uncaught
        # NotADirectoryError.
        raise ValueError(f"模型路径不是目录: {models_dir}")

    # Resolve and create the output directory.
    if output_dir is None:
        output_dir = str(models_path.parent / "14_visualization" / "scatter_plots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Built up front (as in the original flow) although only the basic path
    # draws with it.
    visualizer = WaterQualityVisualization(output_dir)

    scatter_paths: Dict[str, str] = {}

    # Enhanced scatter plots.
    if use_enhanced:
        print("使用增强版散点图(带置信区间)")
        try:
            if scatter_batch is None:
                # Import only when no instance was supplied, so that a
                # caller-provided object still works if this module cannot
                # be imported.
                from src.core.prediction.sctter_batch import WaterQualityScatterBatch

                scatter_batch = WaterQualityScatterBatch()

            results = scatter_batch.batch_plot_scatter(
                models_root_dir=models_dir,
                csv_path=training_csv_path,
                output_dir=output_dir,
                metric=metric,
                target_column=None,
                feature_start_column=feature_start_column,
                test_size=test_size,
                random_state=random_state,
            )

            for target_name, result in results.items():
                if result.get('status') == 'success':
                    save_path = result.get('save_path', '')
                    scatter_paths[target_name] = save_path
                    print(f" ✓ {target_name}: {save_path}")
                else:
                    print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")

        except Exception as e:
            # Deliberate best-effort: any failure of the enhanced back end
            # degrades to the basic plots instead of aborting.
            print(f"使用增强版散点图时出错: {e}")
            print("回退到基础版散点图")
            use_enhanced = False

    # Basic scatter plots (also the fallback when the enhanced path produced
    # nothing).
    if not use_enhanced or not scatter_paths:
        print("使用基础版散点图")
        for target_folder in models_path.iterdir():
            if not target_folder.is_dir():
                continue

            target_name = target_folder.name
            print(f"\n处理目标参数: {target_name}")

            try:
                inferencer = WaterQualityInference(str(target_folder))
                eval_result = inferencer.evaluate_with_split(
                    data_csv_path=training_csv_path,
                    split_method="spxy",
                    test_size=test_size,
                    random_state=random_state,
                    metric=metric,
                )

                predictions = eval_result.get('predictions', {})
                y_train_true = predictions.get('y_train_true') if predictions else None
                y_train_pred = predictions.get('y_train_pred') if predictions else None
                y_test_true = predictions.get('y_test_true') if predictions else None
                y_test_pred = predictions.get('y_test_pred') if predictions else None

                # All four arrays are required for the concatenations below;
                # report the skip instead of failing inside numpy or staying
                # silent.
                if any(
                    arr is None
                    for arr in (y_train_true, y_train_pred, y_test_true, y_test_pred)
                ):
                    print(f" ✗ {target_name}: 评估结果缺少预测数据 - 已跳过")
                    continue

                # Train samples first, then test samples; the index arrays
                # tell the plotter which slice belongs to which split.
                y_all_true = np.concatenate([y_train_true, y_test_true])
                y_all_pred = np.concatenate([y_train_pred, y_test_pred])
                n_train = len(y_train_true)
                train_indices = np.arange(n_train)
                test_indices = np.arange(n_train, len(y_all_true))

                train_metrics = eval_result.get('train_metrics', {})
                test_metrics = eval_result.get('test_metrics', {})

                scatter_path = visualizer.plot_scatter_true_vs_pred(
                    y_true=y_all_true,
                    y_pred=y_all_pred,
                    target_name=target_name,
                    train_indices=train_indices,
                    test_indices=test_indices,
                    metrics={
                        'train_r2': train_metrics.get('r2', 0),
                        'test_r2': test_metrics.get('r2', 0),
                        'train_rmse': train_metrics.get('rmse', 0),
                        'test_rmse': test_metrics.get('rmse', 0),
                    },
                )
                scatter_paths[target_name] = scatter_path
            except Exception as e:
                # One failing target must not stop the remaining ones.
                print(f"处理目标参数 {target_name} 时出错: {e}")
                continue

    print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
    return scatter_paths