Files
WQ_GUI/src/core/visualization/scatter_plot.py

147 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
可视化模块 - 散点图生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.core.prediction.inference_batch import WaterQualityInference
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_model_scatter_plots(
    models_dir: str,
    training_csv_path: str,
    output_dir: Optional[str] = None,
    metric: str = 'test_r2',
    use_enhanced: bool = True,
    feature_start_column: Union[str, int] = 13,
    test_size: float = 0.2,
    random_state: int = 42,
    scatter_batch=None  # optional: an already-instantiated scatter_batch object
) -> Dict[str, str]:
    """Generate true-vs-predicted scatter plots for every trained target model.

    Two strategies are supported:

    * Enhanced (``use_enhanced=True``): delegates to
      ``WaterQualityScatterBatch.batch_plot_scatter`` from
      ``src.core.prediction.sctter_batch`` (plots with confidence intervals).
    * Basic: for each sub-directory of ``models_dir``, evaluates the model with
      ``WaterQualityInference.evaluate_with_split`` (SPXY split) and draws the
      plot with ``WaterQualityVisualization.plot_scatter_true_vs_pred``.

    The basic strategy runs when the enhanced one is disabled, raises, or
    produces no successful plot. Targets that fail in enhanced mode are NOT
    retried in basic mode if at least one other target succeeded.

    Args:
        models_dir: Directory holding one sub-directory per target model.
        training_csv_path: Path of the training-data CSV.
        output_dir: Output directory. If None,
            ``<models_dir parent>/14_visualization/scatter_plots`` is used.
        metric: Metric used to select the best model.
        use_enhanced: Whether to try the enhanced (confidence-interval) plots first.
        feature_start_column: Name or index of the first feature column
            (enhanced mode only).
        test_size: Test-set fraction.
        random_state: Random seed.
        scatter_batch: Optional pre-built ``WaterQualityScatterBatch``. When
            given, the scatter-batch module is not imported.

    Returns:
        Mapping of target name -> saved scatter-plot path.

    Raises:
        ValueError: If ``training_csv_path`` is None, or ``models_dir`` does
            not exist or is not a directory.
    """
    print("\n" + "="*80)
    print("生成模型评估散点图")
    print("="*80)

    if training_csv_path is None:
        raise ValueError("请提供 training_csv_path")

    models_path = Path(models_dir)
    if not models_path.exists():
        raise ValueError(f"模型目录不存在: {models_dir}")
    if not models_path.is_dir():
        # Without this check, iterdir() below would fail with an uncaught
        # NotADirectoryError after output directories had already been created.
        raise ValueError(f"模型路径不是目录: {models_dir}")

    # Resolve and create the output directory.
    if output_dir is None:
        output_dir = str(models_path.parent / "14_visualization" / "scatter_plots")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    visualizer = WaterQualityVisualization(output_dir)
    scatter_paths: Dict[str, str] = {}

    # Enhanced scatter plots (with confidence intervals).
    if use_enhanced:
        print("使用增强版散点图(带置信区间)")
        try:
            if scatter_batch is None:
                # Imported lazily: only needed when the caller did not supply
                # an instance.
                from src.core.prediction.sctter_batch import WaterQualityScatterBatch
                scatter_batch = WaterQualityScatterBatch()

            results = scatter_batch.batch_plot_scatter(
                models_root_dir=models_dir,
                csv_path=training_csv_path,
                output_dir=output_dir,
                metric=metric,
                target_column=None,
                feature_start_column=feature_start_column,
                test_size=test_size,
                random_state=random_state
            )

            for target_name, result in results.items():
                if result.get('status') == 'success':
                    scatter_paths[target_name] = result.get('save_path', '')
                    print(f"{target_name}: {result.get('save_path', '')}")
                else:
                    print(f"{target_name}: 失败 - {result.get('error', '未知错误')}")
        except Exception as e:
            # Best-effort: any failure in the enhanced path falls back to basic plots.
            print(f"使用增强版散点图时出错: {e}")
            print("回退到基础版散点图")
            use_enhanced = False

    # Basic scatter plots: used when enhanced mode is off, failed, or yielded nothing.
    if not use_enhanced or not scatter_paths:
        print("使用基础版散点图")
        # Sorted so processing / output order is deterministic across platforms.
        for target_folder in sorted(models_path.iterdir()):
            if not target_folder.is_dir():
                continue

            target_name = target_folder.name
            print(f"\n处理目标参数: {target_name}")

            try:
                scatter_path = _plot_basic_scatter_for_target(
                    target_folder=target_folder,
                    training_csv_path=training_csv_path,
                    visualizer=visualizer,
                    metric=metric,
                    test_size=test_size,
                    random_state=random_state,
                )
            except Exception as e:
                # One failing target must not abort the remaining ones.
                print(f"处理目标参数 {target_name} 时出错: {e}")
                continue

            if scatter_path is not None:
                scatter_paths[target_name] = scatter_path

    print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
    return scatter_paths


def _plot_basic_scatter_for_target(
    target_folder: Path,
    training_csv_path: str,
    visualizer: "WaterQualityVisualization",
    metric: str,
    test_size: float,
    random_state: int,
) -> Optional[str]:
    """Evaluate one target model with an SPXY split and draw its basic scatter plot.

    Returns the value returned by ``visualizer.plot_scatter_true_vs_pred``, or
    None when the evaluation result lacks any of the four prediction arrays
    (a message naming the missing arrays is printed). Other errors propagate
    to the caller.
    """
    target_name = target_folder.name
    inferencer = WaterQualityInference(str(target_folder))
    eval_result = inferencer.evaluate_with_split(
        data_csv_path=training_csv_path,
        split_method="spxy",
        test_size=test_size,
        random_state=random_state,
        metric=metric
    )

    # `or {}` also guards against keys that are present but set to None.
    predictions = eval_result.get('predictions') or {}
    keys = ('y_train_true', 'y_train_pred', 'y_test_true', 'y_test_pred')
    missing = [key for key in keys if predictions.get(key) is None]
    if missing:
        print(f"{target_name}: 缺少预测结果 {missing},跳过")
        return None

    y_train_true, y_train_pred, y_test_true, y_test_pred = [
        predictions[key] for key in keys
    ]

    # Train samples first, then test samples; the index arrays tell the
    # visualizer which points belong to which split.
    y_all_true = np.concatenate([y_train_true, y_test_true])
    y_all_pred = np.concatenate([y_train_pred, y_test_pred])
    n_train = len(y_train_true)

    train_metrics = eval_result.get('train_metrics') or {}
    test_metrics = eval_result.get('test_metrics') or {}

    return visualizer.plot_scatter_true_vs_pred(
        y_true=y_all_true,
        y_pred=y_all_pred,
        target_name=target_name,
        train_indices=np.arange(n_train),
        test_indices=np.arange(n_train, len(y_all_true)),
        metrics={
            'train_r2': train_metrics.get('r2', 0),
            'test_r2': test_metrics.get('r2', 0),
            'train_rmse': train_metrics.get('rmse', 0),
            'test_rmse': test_metrics.get('rmse', 0)
        }
    )