feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描
This commit is contained in:
@ -26,19 +26,24 @@ from sklearn.model_selection import train_test_split
|
||||
class WaterQualityInference:
|
||||
"""水质参数反演推理类"""
|
||||
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts"):
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts",
             external_model=None, external_model_path=None):
    """
    Set up the inference helper.

    Args:
        artifacts_dir: Directory that holds the saved model artifacts.
            A missing directory only triggers a printed warning here;
            nothing is created by this constructor.
        external_model: Optional pre-trained model object supplied by the
            caller (e.g. imported through the GUI). It is stored as-is on
            the instance.
        external_model_path: Optional path of the external model file.
            It is stored as-is on the instance; presumably only used for
            log output -- confirm against the callers.
    """
    self.artifacts_dir = Path(artifacts_dir)
    if not self.artifacts_dir.exists():
        # Warn only: a missing directory is tolerated at construction time.
        print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")

    # Model-selection state, initially empty.
    self.best_model_info = None
    self.loaded_model_data = None

    # Externally injected model and its (informational) source path.
    self.external_model = external_model
    self.external_model_path = external_model_path
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
@ -745,7 +750,11 @@ class WaterQualityInference:
|
||||
# 1. 加载模型
|
||||
print("\n步骤1: 加载模型")
|
||||
print("-" * 40)
|
||||
if model_file_path:
|
||||
if self.external_model is not None:
|
||||
# 外部预训练模型已注入,直接使用,跳过磁盘加载
|
||||
self.loaded_model_data = self.external_model
|
||||
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
|
||||
elif model_file_path:
|
||||
self.load_specific_model(model_file_path)
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
@ -863,10 +872,12 @@ class WaterQualityInference:
|
||||
print(f"\n批量推理完成,共处理 {len(csv_files)} 个文件")
|
||||
return results
|
||||
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
prediction_column: str = 'prediction',
|
||||
output_format: str = 'csv'):
|
||||
output_format: str = 'csv',
|
||||
external_model=None,
|
||||
external_model_path=None):
|
||||
"""
|
||||
使用多个子文件夹中的模型进行批量推理
|
||||
|
||||
@ -881,7 +892,18 @@ class WaterQualityInference:
|
||||
models_root = Path(models_root_dir)
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 优先级:外部预训练模型 > 从磁盘加载
|
||||
if external_model is not None:
|
||||
effective_model = external_model
|
||||
model_desc = (
|
||||
f"外部导入模型 ({external_model_path or 'unknown'}), "
|
||||
f"type={type(external_model).__name__}"
|
||||
)
|
||||
print(f"\n使用外部预训练模型: {model_desc}")
|
||||
else:
|
||||
effective_model = None
|
||||
|
||||
# 查找所有子文件夹
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
|
||||
@ -900,9 +922,16 @@ class WaterQualityInference:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"处理模型文件夹: {subdir_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 创建新的推理实例,使用当前子文件夹作为artifacts_dir
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
|
||||
# 创建推理实例:外部模型优先注入,跳过磁盘查找
|
||||
if effective_model is not None:
|
||||
model_inferencer = WaterQualityInference(
|
||||
str(subdir),
|
||||
external_model=effective_model,
|
||||
external_model_path=external_model_path,
|
||||
)
|
||||
else:
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
|
||||
# 根据输出格式设置文件扩展名
|
||||
file_ext = f".{output_format}"
|
||||
|
||||
@ -103,6 +103,10 @@ class PredictionStep:
|
||||
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
_external_model=None,
|
||||
_external_model_path=None,
|
||||
_external_models_dict=None,
|
||||
_external_model_dir=None,
|
||||
) -> Dict[str, str]:
|
||||
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
@ -149,19 +153,48 @@ class PredictionStep:
|
||||
else:
|
||||
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||
|
||||
inferencer = WaterQualityInference(models_dir)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
)
|
||||
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
if _external_models_dict:
|
||||
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数,
|
||||
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
|
||||
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
|
||||
for target_name, model_obj in _external_models_dict.items():
|
||||
try:
|
||||
output_file = ml_prediction_dir / f"{target_name}.csv"
|
||||
model_inferencer = WaterQualityInference(
|
||||
models_dir or "./",
|
||||
external_model=model_obj,
|
||||
external_model_path=_external_model_dir or "",
|
||||
)
|
||||
predictions, result_df = model_inferencer.inference_pipeline(
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_csv_path=str(output_file),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
)
|
||||
prediction_files[target_name] = str(output_file)
|
||||
print(f" ✓ {target_name}: {len(predictions)} 个预测值")
|
||||
except Exception as e:
|
||||
print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}")
|
||||
prediction_files[target_name] = None
|
||||
else:
|
||||
inferencer = WaterQualityInference(
|
||||
models_dir,
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
|
||||
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user