fix(step8): 修复外部模型字典透传断链 + 规范化 loaded_model_data 防 Ridge subscriptable 崩溃
This commit is contained in:
@@ -41,10 +41,17 @@ class WaterQualityInference:
|
||||
print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")
|
||||
|
||||
self.best_model_info = None
|
||||
self.loaded_model_data = None
|
||||
self.external_model = external_model
|
||||
self.external_model_path = external_model_path
|
||||
|
||||
# 规范化 loaded_model_data:始终为 dict,确保 ['model'] 访问不崩溃
|
||||
if external_model is not None:
|
||||
# 外部传入的是裸模型对象 → 包装为 dict,统一后续 .get('model') 访问
|
||||
self.loaded_model_data = {'model': external_model, 'preprocess_method': 'None'}
|
||||
print(f" 外部模型已规范化: type={type(external_model).__name__}")
|
||||
else:
|
||||
self.loaded_model_data = None
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
加载sampling生成的CSV数据(兼容 WQI 增强版 CSV)
|
||||
@@ -751,8 +758,7 @@ class WaterQualityInference:
|
||||
print("\n步骤1: 加载模型")
|
||||
print("-" * 40)
|
||||
if self.external_model is not None:
|
||||
# 外部预训练模型已注入,直接使用,跳过磁盘加载
|
||||
self.loaded_model_data = self.external_model
|
||||
# 已在 __init__ 中规范化,无需重复赋值
|
||||
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
|
||||
elif model_file_path:
|
||||
self.load_specific_model(model_file_path)
|
||||
@@ -802,8 +808,8 @@ class WaterQualityInference:
|
||||
|
||||
info = {
|
||||
"status": "model_loaded",
|
||||
"preprocess_method": self.loaded_model_data['preprocess_method'],
|
||||
"model_name": self.loaded_model_data['model_name'],
|
||||
"preprocess_method": self.loaded_model_data.get('preprocess_method', 'Unknown'),
|
||||
"model_name": self.loaded_model_data.get('model_name', type(self.external_model).__name__ if self.external_model else 'Unknown'),
|
||||
"model_type": str(type(self.loaded_model_data['model'])),
|
||||
"metadata": self.loaded_model_data.get('metadata', {})
|
||||
}
|
||||
@@ -877,7 +883,8 @@ class WaterQualityInference:
|
||||
prediction_column: str = 'prediction',
|
||||
output_format: str = 'csv',
|
||||
external_model=None,
|
||||
external_model_path=None):
|
||||
external_model_path=None,
|
||||
external_models_dict=None):
|
||||
"""
|
||||
使用多个子文件夹中的模型进行批量推理
|
||||
|
||||
@@ -893,45 +900,61 @@ class WaterQualityInference:
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 优先级:外部预训练模型 > 从磁盘加载
|
||||
if external_model is not None:
|
||||
effective_model = external_model
|
||||
model_desc = (
|
||||
f"外部导入模型 ({external_model_path or 'unknown'}), "
|
||||
f"type={type(external_model).__name__}"
|
||||
)
|
||||
print(f"\n使用外部预训练模型: {model_desc}")
|
||||
else:
|
||||
effective_model = None
|
||||
|
||||
# 查找所有子文件夹
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return
|
||||
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
print(f"输出格式: {output_format.upper()}")
|
||||
|
||||
all_results = {}
|
||||
|
||||
for subdir in subdirs:
|
||||
|
||||
# 优先级 1:_external_models_dict 非空 → 直接用字典的 keys 作为 targets,不扫描磁盘
|
||||
print(f"[BatchInference] 终于收到字典啦!包含模型: {list(external_models_dict.keys()) if external_models_dict else 'None'}")
|
||||
if external_models_dict is not None and len(external_models_dict) > 0:
|
||||
targets = list(external_models_dict.keys())
|
||||
print(f"\n使用外部导入模型字典({len(targets)} 个模型)")
|
||||
print(f"检测到外部导入模型,将预测以下参数: {targets}")
|
||||
elif external_model is not None:
|
||||
print(f"\n使用外部预训练模型: {external_model_path or 'unknown'}")
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return {}
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
targets = [d.name for d in subdirs]
|
||||
else:
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return {}
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
targets = [d.name for d in subdirs]
|
||||
|
||||
print(f"输出格式: {output_format.upper()}")
|
||||
|
||||
for subdir_name in targets:
|
||||
try:
|
||||
subdir_name = subdir.name
|
||||
print(f"\n{'='*60}")
|
||||
print(f"处理模型文件夹: {subdir_name}")
|
||||
print(f"处理模型: {subdir_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 创建推理实例:外部模型优先注入,跳过磁盘查找
|
||||
# 优先级:字典中该 target 的模型 > 共享单模型 > 磁盘加载
|
||||
effective_model = None
|
||||
if external_models_dict and subdir_name in external_models_dict:
|
||||
effective_model = external_models_dict[subdir_name]
|
||||
print(f" → 使用字典中模型: {type(effective_model).__name__}")
|
||||
elif external_model is not None:
|
||||
effective_model = external_model
|
||||
print(f" → 使用共享外部模型: {type(effective_model).__name__}")
|
||||
|
||||
# artifacts_dir:字典模式优先用 placeholder "./",否则用真实子目录
|
||||
artifacts_dir = (
|
||||
str(models_root / subdir_name)
|
||||
if (models_root / subdir_name).is_dir()
|
||||
else str(models_root)
|
||||
)
|
||||
if effective_model is not None:
|
||||
model_inferencer = WaterQualityInference(
|
||||
str(subdir),
|
||||
artifacts_dir,
|
||||
external_model=effective_model,
|
||||
external_model_path=external_model_path,
|
||||
external_model_path=external_model_path or "",
|
||||
)
|
||||
else:
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
model_inferencer = WaterQualityInference(artifacts_dir)
|
||||
|
||||
# 根据输出格式设置文件扩展名
|
||||
file_ext = f".{output_format}"
|
||||
@@ -960,10 +983,10 @@ class WaterQualityInference:
|
||||
}
|
||||
}
|
||||
|
||||
print(f"子文件夹 {subdir_name} 处理完成")
|
||||
print(f"模型 {subdir_name} 处理完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理子文件夹 {subdir_name} 失败: {e}")
|
||||
print(f"处理模型 {subdir_name} 失败: {e}")
|
||||
all_results[subdir_name] = {
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
|
||||
Reference in New Issue
Block a user