feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描

This commit is contained in:
DXC
2026-06-08 09:56:02 +08:00
parent 4efe5b871e
commit 2b76d7908f
12 changed files with 935 additions and 29 deletions

View File

@ -26,19 +26,24 @@ from sklearn.model_selection import train_test_split
class WaterQualityInference:
"""水质参数反演推理类"""
def __init__(self, artifacts_dir: str = "models/artifacts"):
def __init__(self, artifacts_dir: str = "models/artifacts",
external_model=None, external_model_path=None):
"""
初始化推理类
Args:
artifacts_dir: 模型保存目录
external_model: 外部预训练模型对象(来自 GUI 导入,跳过磁盘加载)
external_model_path: 外部模型文件路径(仅用于日志)
"""
self.artifacts_dir = Path(artifacts_dir)
if not self.artifacts_dir.exists():
print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")
self.best_model_info = None
self.loaded_model_data = None
self.external_model = external_model
self.external_model_path = external_model_path
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
@ -745,7 +750,11 @@ class WaterQualityInference:
# 1. 加载模型
print("\n步骤1: 加载模型")
print("-" * 40)
if model_file_path:
if self.external_model is not None:
# 外部预训练模型已注入,直接使用,跳过磁盘加载
self.loaded_model_data = self.external_model
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
elif model_file_path:
self.load_specific_model(model_file_path)
else:
self.load_best_model(metric=metric)
@ -863,10 +872,12 @@ class WaterQualityInference:
print(f"\n批量推理完成,共处理 {len(csv_files)} 个文件")
return results
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
output_dir: str, metric: str = 'test_r2',
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
output_dir: str, metric: str = 'test_r2',
prediction_column: str = 'prediction',
output_format: str = 'csv'):
output_format: str = 'csv',
external_model=None,
external_model_path=None):
"""
使用多个子文件夹中的模型进行批量推理
@ -881,7 +892,18 @@ class WaterQualityInference:
models_root = Path(models_root_dir)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 优先级:外部预训练模型 > 从磁盘加载
if external_model is not None:
effective_model = external_model
model_desc = (
f"外部导入模型 ({external_model_path or 'unknown'}), "
f"type={type(external_model).__name__}"
)
print(f"\n使用外部预训练模型: {model_desc}")
else:
effective_model = None
# 查找所有子文件夹
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
@ -900,9 +922,16 @@ class WaterQualityInference:
print(f"\n{'='*60}")
print(f"处理模型文件夹: {subdir_name}")
print(f"{'='*60}")
# 创建新的推理实例使用当前子文件夹作为artifacts_dir
model_inferencer = WaterQualityInference(str(subdir))
# 创建推理实例:外部模型优先注入,跳过磁盘查找
if effective_model is not None:
model_inferencer = WaterQualityInference(
str(subdir),
external_model=effective_model,
external_model_path=external_model_path,
)
else:
model_inferencer = WaterQualityInference(str(subdir))
# 根据输出格式设置文件扩展名
file_ext = f".{output_format}"