refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块
This commit is contained in:
147
src/core/visualization/scatter_plot.py
Normal file
147
src/core/visualization/scatter_plot.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 散点图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_model_scatter_plots(
|
||||
models_dir: str,
|
||||
training_csv_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = 'test_r2',
|
||||
use_enhanced: bool = True,
|
||||
feature_start_column: Union[str, int] = 13,
|
||||
test_size: float = 0.2,
|
||||
random_state: int = 42,
|
||||
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成模型评估散点图(真实值vs预测值)
|
||||
|
||||
Args:
|
||||
models_dir: 模型保存目录
|
||||
training_csv_path: 训练数据CSV路径
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
metric: 选择最佳模型的指标
|
||||
use_enhanced: 是否使用增强版散点图(带置信区间,使用sctter_batch)
|
||||
feature_start_column: 特征开始列名或索引
|
||||
test_size: 测试集比例
|
||||
random_state: 随机种子
|
||||
scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象
|
||||
|
||||
Returns:
|
||||
散点图文件路径字典(键为目标参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成模型评估散点图")
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供 training_csv_path")
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
raise ValueError(f"模型目录不存在: {models_dir}")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
# 增强版散点图
|
||||
if use_enhanced:
|
||||
print("使用增强版散点图(带置信区间)")
|
||||
try:
|
||||
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
|
||||
if scatter_batch is None:
|
||||
scatter_batch = WaterQualityScatterBatch()
|
||||
|
||||
results = scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=output_dir,
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state
|
||||
)
|
||||
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
print(f" ✓ {target_name}: {result.get('save_path', '')}")
|
||||
else:
|
||||
print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用增强版散点图时出错: {e}")
|
||||
print("回退到基础版散点图")
|
||||
use_enhanced = False
|
||||
|
||||
# 基础版散点图
|
||||
if not use_enhanced or not scatter_paths:
|
||||
print("使用基础版散点图")
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
print(f"\n处理目标参数: {target_name}")
|
||||
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric
|
||||
)
|
||||
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
|
||||
scatter_path = visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0)
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception as e:
|
||||
print(f"处理目标参数 {target_name} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
|
||||
return scatter_paths
|
||||
Reference in New Issue
Block a user