feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描

This commit is contained in:
DXC
2026-06-08 09:56:02 +08:00
parent 4efe5b871e
commit 2b76d7908f
12 changed files with 935 additions and 29 deletions

View File

@ -103,6 +103,10 @@ class PredictionStep:
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
callback: Optional[Callable] = None,
_report_generator=None,
_external_model=None,
_external_model_path=None,
_external_models_dict=None,
_external_model_dir=None,
) -> Dict[str, str]:
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
from src.core.prediction.inference_batch import WaterQualityInference
@ -149,19 +153,48 @@ class PredictionStep:
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
inferencer = WaterQualityInference(models_dir)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
sampling_csv_path=sampling_csv_path,
output_dir=str(ml_prediction_dir),
metric=metric,
prediction_column=prediction_column,
output_format="csv",
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
if _external_models_dict:
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数,
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
for target_name, model_obj in _external_models_dict.items():
try:
output_file = ml_prediction_dir / f"{target_name}.csv"
model_inferencer = WaterQualityInference(
models_dir or "./",
external_model=model_obj,
external_model_path=_external_model_dir or "",
)
predictions, result_df = model_inferencer.inference_pipeline(
sampling_csv_path=sampling_csv_path,
output_csv_path=str(output_file),
metric=metric,
prediction_column=prediction_column,
)
prediction_files[target_name] = str(output_file)
print(f"{target_name}: {len(predictions)} 个预测值")
except Exception as e:
print(f"{target_name}: 失败 — {type(e).__name__}: {e}")
prediction_files[target_name] = None
else:
inferencer = WaterQualityInference(
models_dir,
external_model=_external_model,
external_model_path=_external_model_path,
)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
sampling_csv_path=sampling_csv_path,
output_dir=str(ml_prediction_dir),
metric=metric,
prediction_column=prediction_column,
output_format="csv",
external_model=_external_model,
external_model_path=_external_model_path,
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
print(f"预测完成,结果保存在: {ml_prediction_dir}")