fix(step8): 修复外部模型字典透传断链 + 规范化 loaded_model_data 防 Ridge subscriptable 崩溃

This commit is contained in:
DXC
2026-06-08 11:36:36 +08:00
parent 2b76d7908f
commit e3debbcb15
5 changed files with 189 additions and 44 deletions

View File

@ -41,10 +41,17 @@ class WaterQualityInference:
print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")
self.best_model_info = None
self.loaded_model_data = None
self.external_model = external_model
self.external_model_path = external_model_path
# 规范化 loaded_model_data始终为 dict确保 ['model'] 访问不崩溃
if external_model is not None:
# 外部传入的是裸模型对象 → 包装为 dict统一后续 .get('model') 访问
self.loaded_model_data = {'model': external_model, 'preprocess_method': 'None'}
print(f" 外部模型已规范化: type={type(external_model).__name__}")
else:
self.loaded_model_data = None
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
加载sampling生成的CSV数据兼容 WQI 增强版 CSV
@ -751,8 +758,7 @@ class WaterQualityInference:
print("\n步骤1: 加载模型")
print("-" * 40)
if self.external_model is not None:
# 外部预训练模型已注入,直接使用,跳过磁盘加载
self.loaded_model_data = self.external_model
# 已在 __init__ 中规范化,无需重复赋值
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
elif model_file_path:
self.load_specific_model(model_file_path)
@ -802,8 +808,8 @@ class WaterQualityInference:
info = {
"status": "model_loaded",
"preprocess_method": self.loaded_model_data['preprocess_method'],
"model_name": self.loaded_model_data['model_name'],
"preprocess_method": self.loaded_model_data.get('preprocess_method', 'Unknown'),
"model_name": self.loaded_model_data.get('model_name', type(self.external_model).__name__ if self.external_model else 'Unknown'),
"model_type": str(type(self.loaded_model_data['model'])),
"metadata": self.loaded_model_data.get('metadata', {})
}
@ -877,7 +883,8 @@ class WaterQualityInference:
prediction_column: str = 'prediction',
output_format: str = 'csv',
external_model=None,
external_model_path=None):
external_model_path=None,
external_models_dict=None):
"""
使用多个子文件夹中的模型进行批量推理
@ -893,45 +900,61 @@ class WaterQualityInference:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 优先级:外部预训练模型 > 从磁盘加载
if external_model is not None:
effective_model = external_model
model_desc = (
f"外部导入模型 ({external_model_path or 'unknown'}), "
f"type={type(external_model).__name__}"
)
print(f"\n使用外部预训练模型: {model_desc}")
else:
effective_model = None
# 查找所有子文件夹
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
if not subdirs:
print(f"在目录 {models_root_dir} 中未找到子文件夹")
return
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
print(f"输出格式: {output_format.upper()}")
all_results = {}
for subdir in subdirs:
# 优先级 1_external_models_dict 非空 → 直接用字典的 keys 作为 targets不扫描磁盘
print(f"[BatchInference] 终于收到字典啦!包含模型: {list(external_models_dict.keys()) if external_models_dict else 'None'}")
if external_models_dict is not None and len(external_models_dict) > 0:
targets = list(external_models_dict.keys())
print(f"\n使用外部导入模型字典({len(targets)} 个模型)")
print(f"检测到外部导入模型,将预测以下参数: {targets}")
elif external_model is not None:
print(f"\n使用外部预训练模型: {external_model_path or 'unknown'}")
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
if not subdirs:
print(f"在目录 {models_root_dir} 中未找到子文件夹")
return {}
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
targets = [d.name for d in subdirs]
else:
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
if not subdirs:
print(f"在目录 {models_root_dir} 中未找到子文件夹")
return {}
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
targets = [d.name for d in subdirs]
print(f"输出格式: {output_format.upper()}")
for subdir_name in targets:
try:
subdir_name = subdir.name
print(f"\n{'='*60}")
print(f"处理模型文件夹: {subdir_name}")
print(f"处理模型: {subdir_name}")
print(f"{'='*60}")
# 创建推理实例:外部模型优先注入,跳过磁盘查找
# 优先级:字典中该 target 的模型 > 共享单模型 > 磁盘加载
effective_model = None
if external_models_dict and subdir_name in external_models_dict:
effective_model = external_models_dict[subdir_name]
print(f" → 使用字典中模型: {type(effective_model).__name__}")
elif external_model is not None:
effective_model = external_model
print(f" → 使用共享外部模型: {type(effective_model).__name__}")
# artifacts_dir字典模式优先用 placeholder "./",否则用真实子目录
artifacts_dir = (
str(models_root / subdir_name)
if (models_root / subdir_name).is_dir()
else str(models_root)
)
if effective_model is not None:
model_inferencer = WaterQualityInference(
str(subdir),
artifacts_dir,
external_model=effective_model,
external_model_path=external_model_path,
external_model_path=external_model_path or "",
)
else:
model_inferencer = WaterQualityInference(str(subdir))
model_inferencer = WaterQualityInference(artifacts_dir)
# 根据输出格式设置文件扩展名
file_ext = f".{output_format}"
@ -960,10 +983,10 @@ class WaterQualityInference:
}
}
print(f"子文件夹 {subdir_name} 处理完成")
print(f"模型 {subdir_name} 处理完成")
except Exception as e:
print(f"处理子文件夹 {subdir_name} 失败: {e}")
print(f"处理模型 {subdir_name} 失败: {e}")
all_results[subdir_name] = {
'status': 'error',
'error': str(e)

View File

@ -118,6 +118,8 @@ class PredictionStep:
print("\n" + "=" * 80)
print("步骤8: 预测水质参数")
print("=" * 80)
print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}"
f", 单模型状态: {'Yes' if _external_model else 'No'}")
step_start_time = time.time()
@ -153,8 +155,10 @@ class PredictionStep:
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
all_results = {}
if _external_models_dict:
# 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数
# 外部模型字典优先:直接用字典的 keys 作为 targets 列表
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
for target_name, model_obj in _external_models_dict.items():
@ -172,11 +176,18 @@ class PredictionStep:
prediction_column=prediction_column,
)
prediction_files[target_name] = str(output_file)
all_results[target_name] = {
"status": "success",
"output_file": str(output_file),
"sample_count": len(predictions),
}
print(f"{target_name}: {len(predictions)} 个预测值")
except Exception as e:
print(f"{target_name}: 失败 — {type(e).__name__}: {e}")
prediction_files[target_name] = None
all_results[target_name] = {"status": "error", "error": str(e)}
else:
# 字典为空或不存在:回退到扫描 models_dir 子目录的传统逻辑
inferencer = WaterQualityInference(
models_dir,
external_model=_external_model,
@ -191,10 +202,13 @@ class PredictionStep:
output_format="csv",
external_model=_external_model,
external_model_path=_external_model_path,
external_models_dict=_external_models_dict,
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
# batch_inference_multi_models 已确保返回字典,永不返回 None
if all_results:
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
print(f"预测完成,结果保存在: {ml_prediction_dir}")

View File

@ -808,6 +808,13 @@ class WaterQualityInversionPipeline:
Returns:
预测结果文件路径字典(键为目标列名)
"""
_external_models_dict = kwargs.get('_external_models_dict')
_external_model = kwargs.get('_external_model')
_external_model_path = kwargs.get('_external_model_path')
_external_model_dir = kwargs.get('_external_model_dir')
print(f"[Pipeline] 收到字典: {'Yes' if _external_models_dict else 'No'}"
f", 收到单模型: {'Yes' if _external_model else 'No'}")
self._notify("started", "步骤8: 预测水质参数")
result = PredictionStep.predict_water_quality(
sampling_csv_path=sampling_csv_path,
@ -816,11 +823,15 @@ class WaterQualityInversionPipeline:
prediction_column=prediction_column,
output_dir=str(self.prediction_dir / "Machine_Learning_Prediction"),
_report_generator=self.report_generator,
_external_model=_external_model,
_external_model_path=_external_model_path,
_external_models_dict=_external_models_dict,
_external_model_dir=_external_model_dir,
)
self._record_step_time("步骤8: 预测水质参数", 0, 0)
self._notify("completed", f"预测完成,结果保存在: {self.prediction_dir}")
return result
def step9_generate_distribution_map(self, prediction_csv_path: str,
boundary_shp_path: str,
output_image_path: Optional[str] = None,