修复训练摘要报告无法识别 .joblib 模型的 Bug
This commit is contained in:
@ -1004,66 +1004,83 @@ class ReportGenerator:
|
|||||||
保存的文件路径
|
保存的文件路径
|
||||||
"""
|
"""
|
||||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||||
|
import joblib
|
||||||
|
|
||||||
modeler = WaterQualityModelingBatch(models_dir)
|
modeler = WaterQualityModelingBatch(models_dir)
|
||||||
|
|
||||||
# 需要先加载训练结果
|
|
||||||
# 这里假设results已经存储在modeler中,或者需要从保存的文件中读取
|
|
||||||
# 由于modeling_batch.py的结构,我们需要另一种方式来获取所有结果
|
|
||||||
|
|
||||||
# 尝试遍历模型目录,查找所有保存的结果
|
|
||||||
models_path = Path(models_dir)
|
models_path = Path(models_dir)
|
||||||
all_results = []
|
all_results = []
|
||||||
|
|
||||||
# 遍历所有目标参数文件夹
|
# 递归扫描 *.joblib 和 *.pkl,兼容 artifacts_dir/target_name/ 的所有子目录层级
|
||||||
for target_folder in models_path.iterdir():
|
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
|
||||||
if not target_folder.is_dir():
|
|
||||||
continue
|
for model_file in model_files:
|
||||||
|
# 目标参数取直系父目录名(符合 artifacts_dir/target_name/ 结构)
|
||||||
target_name = target_folder.name
|
target_name = model_file.parent.name
|
||||||
|
stem = model_file.stem
|
||||||
# 查找所有模型文件
|
|
||||||
for model_file in target_folder.rglob("*.pkl"):
|
# 文件名格式:{safe_target}_{preprocess}_{model_name}.joblib
|
||||||
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl)
|
# 使用 split('_', 2) 最多切 3 段,目标 1 段、预处理 1 段、模型 1 段
|
||||||
model_info = {
|
parts = stem.split('_', 2)
|
||||||
'target': target_name,
|
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
|
||||||
'model_file': str(model_file),
|
model_name_str = parts[2] if len(parts) > 2 else stem
|
||||||
'preprocess': 'Unknown',
|
|
||||||
'model': 'Unknown',
|
# 尝试从 joblib/pkl 读取元数据,提取性能指标
|
||||||
'split_method': 'Unknown'
|
metrics = {}
|
||||||
|
try:
|
||||||
|
data = joblib.load(model_file)
|
||||||
|
metadata = data.get('metadata', {})
|
||||||
|
metrics = {
|
||||||
|
'train_r2': metadata.get('train_r2', 'N/A'),
|
||||||
|
'test_r2': metadata.get('test_r2', 'N/A'),
|
||||||
|
'test_rmse': metadata.get('test_rmse', 'N/A'),
|
||||||
|
'train_rmse': metadata.get('train_rmse', 'N/A'),
|
||||||
|
'train_mae': metadata.get('train_mae', 'N/A'),
|
||||||
|
'test_mae': metadata.get('test_mae', 'N/A'),
|
||||||
|
'cv_mean': metadata.get('cv_mean', 'N/A'),
|
||||||
}
|
}
|
||||||
|
except Exception:
|
||||||
# 尝试从文件名解析
|
pass # 加载失败时 metrics 保持为空字典,摘要中该列为 N/A
|
||||||
parts = model_file.stem.split('_')
|
|
||||||
if len(parts) >= 3:
|
all_results.append({
|
||||||
model_info['preprocess'] = parts[0]
|
'target': target_name,
|
||||||
model_info['model'] = parts[1]
|
'model_file': str(model_file),
|
||||||
model_info['split_method'] = parts[2]
|
'preprocess': preprocess,
|
||||||
|
'model': model_name_str,
|
||||||
all_results.append(model_info)
|
**metrics,
|
||||||
|
})
|
||||||
# 如果有训练结果数据,使用实际指标
|
|
||||||
# 否则创建一个基本的摘要
|
|
||||||
summary_data = []
|
summary_data = []
|
||||||
for result in all_results:
|
for result in all_results:
|
||||||
summary_data.append({
|
summary_data.append({
|
||||||
'目标参数': result['target'],
|
'目标参数': result['target'],
|
||||||
'预处理方法': result['preprocess'],
|
'预处理方法': result['preprocess'],
|
||||||
'模型名称': result['model'],
|
'模型名称': result['model'],
|
||||||
'划分方法': result['split_method'],
|
'模型文件': result['model_file'],
|
||||||
'模型文件': result['model_file']
|
'训练集R²': result.get('train_r2', 'N/A'),
|
||||||
|
'测试集R²': result.get('test_r2', 'N/A'),
|
||||||
|
'测试集RMSE': result.get('test_rmse', 'N/A'),
|
||||||
|
'训练集RMSE': result.get('train_rmse', 'N/A'),
|
||||||
|
'训练集MAE': result.get('train_mae', 'N/A'),
|
||||||
|
'测试集MAE': result.get('test_mae', 'N/A'),
|
||||||
|
'CV均值': result.get('cv_mean', 'N/A'),
|
||||||
})
|
})
|
||||||
|
|
||||||
if not summary_data:
|
if not summary_data:
|
||||||
print("警告:未找到模型文件,生成空摘要")
|
print("警告:未找到模型文件,生成空摘要")
|
||||||
summary_data = [{
|
summary_data = [{
|
||||||
'目标参数': 'No Data',
|
'目标参数': 'No Data',
|
||||||
'预处理方法': 'N/A',
|
'预处理方法': 'N/A',
|
||||||
'模型名称': 'N/A',
|
'模型名称': 'N/A',
|
||||||
'划分方法': 'N/A',
|
'模型文件': 'N/A',
|
||||||
'模型文件': 'N/A'
|
'训练集R²': 'N/A',
|
||||||
|
'测试集R²': 'N/A',
|
||||||
|
'测试集RMSE': 'N/A',
|
||||||
|
'训练集RMSE': 'N/A',
|
||||||
|
'训练集MAE': 'N/A',
|
||||||
|
'测试集MAE': 'N/A',
|
||||||
|
'CV均值': 'N/A',
|
||||||
}]
|
}]
|
||||||
|
|
||||||
df_summary = pd.DataFrame(summary_data)
|
df_summary = pd.DataFrame(summary_data)
|
||||||
|
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user