refactor: 渐进式模块化重构 — 剥离可视化层、工具层、算法层到独立模块

This commit is contained in:
DXC
2026-05-09 17:18:34 +08:00
parent b2b90050dc
commit dcbcc043e4
17 changed files with 2673 additions and 948 deletions

View File

@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 散点图生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.core.prediction.inference_batch import WaterQualityInference
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_model_scatter_plots(
models_dir: str,
training_csv_path: str,
output_dir: Optional[str] = None,
metric: str = 'test_r2',
use_enhanced: bool = True,
feature_start_column: Union[str, int] = 13,
test_size: float = 0.2,
random_state: int = 42,
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
) -> Dict[str, str]:
"""
生成模型评估散点图真实值vs预测值
Args:
models_dir: 模型保存目录
training_csv_path: 训练数据CSV路径
output_dir: 输出目录None则使用默认
metric: 选择最佳模型的指标
use_enhanced: 是否使用增强版散点图带置信区间使用sctter_batch
feature_start_column: 特征开始列名或索引
test_size: 测试集比例
random_state: 随机种子
scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象
Returns:
散点图文件路径字典(键为目标参数名)
"""
print("\n" + "="*80)
print("生成模型评估散点图")
print("="*80)
if training_csv_path is None:
raise ValueError("请提供 training_csv_path")
models_path = Path(models_dir)
if not models_path.exists():
raise ValueError(f"模型目录不存在: {models_dir}")
# 确定输出目录
if output_dir is None:
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 实例化可视化器
visualizer = WaterQualityVisualization(output_dir)
scatter_paths = {}
# 增强版散点图
if use_enhanced:
print("使用增强版散点图(带置信区间)")
try:
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
if scatter_batch is None:
scatter_batch = WaterQualityScatterBatch()
results = scatter_batch.batch_plot_scatter(
models_root_dir=models_dir,
csv_path=training_csv_path,
output_dir=output_dir,
metric=metric,
target_column=None,
feature_start_column=feature_start_column,
test_size=test_size,
random_state=random_state
)
for target_name, result in results.items():
if result.get('status') == 'success':
scatter_paths[target_name] = result.get('save_path', '')
print(f"{target_name}: {result.get('save_path', '')}")
else:
print(f"{target_name}: 失败 - {result.get('error', '未知错误')}")
except Exception as e:
print(f"使用增强版散点图时出错: {e}")
print("回退到基础版散点图")
use_enhanced = False
# 基础版散点图
if not use_enhanced or not scatter_paths:
print("使用基础版散点图")
for target_folder in models_path.iterdir():
if not target_folder.is_dir():
continue
target_name = target_folder.name
print(f"\n处理目标参数: {target_name}")
try:
inferencer = WaterQualityInference(str(target_folder))
eval_result = inferencer.evaluate_with_split(
data_csv_path=training_csv_path,
split_method="spxy",
test_size=test_size,
random_state=random_state,
metric=metric
)
predictions = eval_result.get('predictions', {})
if predictions:
y_train_true = predictions.get('y_train_true')
y_train_pred = predictions.get('y_train_pred')
y_test_true = predictions.get('y_test_true')
y_test_pred = predictions.get('y_test_pred')
metrics = eval_result.get('test_metrics', {})
if y_train_true is not None and y_test_true is not None:
y_all_true = np.concatenate([y_train_true, y_test_true])
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
train_indices = np.arange(len(y_train_true))
test_indices = np.arange(len(y_train_true), len(y_all_true))
scatter_path = visualizer.plot_scatter_true_vs_pred(
y_true=y_all_true,
y_pred=y_all_pred,
target_name=target_name,
train_indices=train_indices,
test_indices=test_indices,
metrics={
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
'test_r2': metrics.get('r2', 0),
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
'test_rmse': metrics.get('rmse', 0)
}
)
scatter_paths[target_name] = scatter_path
except Exception as e:
print(f"处理目标参数 {target_name} 时出错: {e}")
continue
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
return scatter_paths