fix(step8): 修复外部模型字典透传断链 + 规范化 loaded_model_data 防 Ridge subscriptable 崩溃

This commit is contained in:
DXC
2026-06-08 11:36:36 +08:00
parent 2b76d7908f
commit e3debbcb15
5 changed files with 189 additions and 44 deletions

View File

@ -118,6 +118,8 @@ class PredictionStep:
print("\n" + "=" * 80)
print("步骤8: 预测水质参数")
print("=" * 80)
print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}"
f", 单模型状态: {'Yes' if _external_model else 'No'}")
step_start_time = time.time()
@ -153,8 +155,10 @@ class PredictionStep:
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
all_results = {}
if _external_models_dict:
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数
# 外部模型字典优先:直接用字典的 keys 作为 targets 列表
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
for target_name, model_obj in _external_models_dict.items():
@ -172,11 +176,18 @@ class PredictionStep:
prediction_column=prediction_column,
)
prediction_files[target_name] = str(output_file)
all_results[target_name] = {
"status": "success",
"output_file": str(output_file),
"sample_count": len(predictions),
}
print(f"{target_name}: {len(predictions)} 个预测值")
except Exception as e:
print(f"{target_name}: 失败 — {type(e).__name__}: {e}")
prediction_files[target_name] = None
all_results[target_name] = {"status": "error", "error": str(e)}
else:
# 字典为空或不存在:回退到扫描 models_dir 子目录的传统逻辑
inferencer = WaterQualityInference(
models_dir,
external_model=_external_model,
@ -191,10 +202,13 @@ class PredictionStep:
output_format="csv",
external_model=_external_model,
external_model_path=_external_model_path,
external_models_dict=_external_models_dict,
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
# batch_inference_multi_models 已确保返回字典,永不返回 None
if all_results:
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
print(f"预测完成,结果保存在: {ml_prediction_dir}")