From e3debbcb15eef04d1fe052c8155531f087976e08 Mon Sep 17 00:00:00 2001 From: DXC Date: Mon, 8 Jun 2026 11:36:36 +0800 Subject: [PATCH] =?UTF-8?q?fix(step8):=20=E4=BF=AE=E5=A4=8D=E5=A4=96?= =?UTF-8?q?=E9=83=A8=E6=A8=A1=E5=9E=8B=E5=AD=97=E5=85=B8=E9=80=8F=E4=BC=A0?= =?UTF-8?q?=E6=96=AD=E9=93=BE=20+=20=E8=A7=84=E8=8C=83=E5=8C=96=20loaded?= =?UTF-8?q?=5Fmodel=5Fdata=20=E9=98=B2=20Ridge=20subscriptable=20=E5=B4=A9?= =?UTF-8?q?=E6=BA=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/prediction/inference_batch.py | 97 ++++++++++++------- src/core/steps/prediction_step.py | 22 ++++- .../water_quality_inversion_pipeline_GUI.py | 13 ++- src/gui/core/worker_thread.py | 4 + src/gui/panels/step8_panel.py | 97 ++++++++++++++++++- 5 files changed, 189 insertions(+), 44 deletions(-) diff --git a/src/core/prediction/inference_batch.py b/src/core/prediction/inference_batch.py index 39a8900..837b241 100644 --- a/src/core/prediction/inference_batch.py +++ b/src/core/prediction/inference_batch.py @@ -41,10 +41,17 @@ class WaterQualityInference: print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建") self.best_model_info = None - self.loaded_model_data = None self.external_model = external_model self.external_model_path = external_model_path + # 规范化 loaded_model_data:始终为 dict,确保 ['model'] 访问不崩溃 + if external_model is not None: + # 外部传入的是裸模型对象 → 包装为 dict,统一后续 .get('model') 访问 + self.loaded_model_data = {'model': external_model, 'preprocess_method': 'None'} + print(f" 外部模型已规范化: type={type(external_model).__name__}") + else: + self.loaded_model_data = None + def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ 加载sampling生成的CSV数据(兼容 WQI 增强版 CSV) @@ -751,8 +758,7 @@ class WaterQualityInference: print("\n步骤1: 加载模型") print("-" * 40) if self.external_model is not None: - # 外部预训练模型已注入,直接使用,跳过磁盘加载 - self.loaded_model_data = self.external_model + # 已在 __init__ 中规范化,无需重复赋值 print(f" 使用外部预训练模型: type={type(self.external_model).__name__}") elif model_file_path: self.load_specific_model(model_file_path) @@ -802,8 +808,8 @@ class WaterQualityInference: info = { "status": "model_loaded", - "preprocess_method": self.loaded_model_data['preprocess_method'], - "model_name": self.loaded_model_data['model_name'], + "preprocess_method": self.loaded_model_data.get('preprocess_method', 'Unknown'), + "model_name": self.loaded_model_data.get('model_name', type(self.external_model).__name__ if self.external_model else 'Unknown'), "model_type": str(type(self.loaded_model_data['model'])), "metadata": self.loaded_model_data.get('metadata', {}) } @@ -877,7 +883,8 @@ class WaterQualityInference: prediction_column: str = 'prediction', output_format: str = 'csv', external_model=None, - external_model_path=None): + external_model_path=None, + external_models_dict=None): """ 使用多个子文件夹中的模型进行批量推理 @@ -893,45 +900,61 @@ class WaterQualityInference: output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) - # 优先级:外部预训练模型 > 从磁盘加载 - if external_model is not None: - effective_model = external_model - model_desc = ( - f"外部导入模型 ({external_model_path or 'unknown'}), " - f"type={type(external_model).__name__}" - ) - print(f"\n使用外部预训练模型: {model_desc}") - else: - effective_model = None - - # 查找所有子文件夹 - subdirs = [d for d in models_root.iterdir() if d.is_dir()] - - if not subdirs: - print(f"在目录 {models_root_dir} 中未找到子文件夹") - return - - print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理") - print(f"输出格式: {output_format.upper()}") - all_results = {} - - for subdir in subdirs: + + # 优先级 1:_external_models_dict 非空 → 直接用字典的 keys 作为 targets,不扫描磁盘 + print(f"[BatchInference] 终于收到字典啦!包含模型: {list(external_models_dict.keys()) if external_models_dict else 'None'}") + if external_models_dict is not None and len(external_models_dict) > 0: + targets = list(external_models_dict.keys()) + print(f"\n使用外部导入模型字典({len(targets)} 个模型)") + print(f"检测到外部导入模型,将预测以下参数: {targets}") + elif external_model is not None: + print(f"\n使用外部预训练模型: {external_model_path or 'unknown'}") + subdirs = [d for d in models_root.iterdir() if d.is_dir()] + if not subdirs: + print(f"在目录 {models_root_dir} 中未找到子文件夹") + return {} + print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理") + targets = [d.name for d in subdirs] + else: + subdirs = [d for d in models_root.iterdir() if d.is_dir()] + if not subdirs: + print(f"在目录 {models_root_dir} 中未找到子文件夹") + return {} + print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理") + targets = [d.name for d in subdirs] + + print(f"输出格式: {output_format.upper()}") + + for subdir_name in targets: try: - subdir_name = subdir.name print(f"\n{'='*60}") - print(f"处理模型文件夹: {subdir_name}") + print(f"处理模型: {subdir_name}") print(f"{'='*60}") - # 创建推理实例:外部模型优先注入,跳过磁盘查找 + # 优先级:字典中该 target 的模型 > 共享单模型 > 磁盘加载 + effective_model = None + if external_models_dict and subdir_name in external_models_dict: + effective_model = external_models_dict[subdir_name] + print(f" → 使用字典中模型: {type(effective_model).__name__}") + elif external_model is not None: + effective_model = external_model + print(f" → 使用共享外部模型: {type(effective_model).__name__}") + + # artifacts_dir:字典模式优先用 placeholder "./",否则用真实子目录 + artifacts_dir = ( + str(models_root / subdir_name) + if (models_root / subdir_name).is_dir() + else str(models_root) + ) if effective_model is not None: model_inferencer = WaterQualityInference( - str(subdir), + artifacts_dir, external_model=effective_model, - external_model_path=external_model_path, + external_model_path=external_model_path or "", ) else: - model_inferencer = WaterQualityInference(str(subdir)) + model_inferencer = WaterQualityInference(artifacts_dir) # 根据输出格式设置文件扩展名 file_ext = f".{output_format}" @@ -960,10 +983,10 @@ class WaterQualityInference: } } - print(f"子文件夹 {subdir_name} 处理完成") + print(f"模型 {subdir_name} 处理完成") except Exception as e: - print(f"处理子文件夹 {subdir_name} 失败: {e}") + print(f"处理模型 {subdir_name} 失败: {e}") all_results[subdir_name] = { 'status': 'error', 'error': str(e) diff --git a/src/core/steps/prediction_step.py b/src/core/steps/prediction_step.py index 948c477..12262e2 100644 --- a/src/core/steps/prediction_step.py +++ b/src/core/steps/prediction_step.py @@ -118,6 +118,8 @@ class PredictionStep: print("\n" + "=" * 80) print("步骤8: 预测水质参数") print("=" * 80) + print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}" + f", 单模型状态: {'Yes' if _external_model else 'No'}") step_start_time = time.time() @@ -153,8 +155,10 @@ class PredictionStep: else: print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...") + all_results = {} + if _external_models_dict: - # 外部模型字典优先:每个 {subdir_name: model_obj} 对应一个水质参数, + # 外部模型字典优先:直接用字典的 keys 作为 targets 列表, # 手动为每个模型创建 inference 实例并调用 inference_pipeline。 print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...") for target_name, model_obj in _external_models_dict.items(): @@ -172,11 +176,18 @@ class PredictionStep: prediction_column=prediction_column, ) prediction_files[target_name] = str(output_file) + all_results[target_name] = { + "status": "success", + "output_file": str(output_file), + "sample_count": len(predictions), + } print(f" ✓ {target_name}: {len(predictions)} 个预测值") except Exception as e: print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}") prediction_files[target_name] = None + all_results[target_name] = {"status": "error", "error": str(e)} else: + # 字典为空或不存在:回退到扫描 models_dir 子目录的传统逻辑 inferencer = WaterQualityInference( models_dir, external_model=_external_model, @@ -191,10 +202,13 @@ class PredictionStep: output_format="csv", external_model=_external_model, external_model_path=_external_model_path, + external_models_dict=_external_models_dict, ) - for target_name, result in all_results.items(): - if result.get("status") == "success": - prediction_files[target_name] = result["output_file"] + # batch_inference_multi_models 已确保返回字典,永不返回 None + if all_results: + for target_name, result in all_results.items(): + if result.get("status") == "success": + prediction_files[target_name] = result["output_file"] print(f"预测完成,结果保存在: {ml_prediction_dir}") diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index 307f6da..8e7a8ab 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -808,6 +808,13 @@ class WaterQualityInversionPipeline: Returns: 预测结果文件路径字典(键为目标列名) """ + _external_models_dict = kwargs.get('_external_models_dict') + _external_model = kwargs.get('_external_model') + _external_model_path = kwargs.get('_external_model_path') + _external_model_dir = kwargs.get('_external_model_dir') + print(f"[Pipeline] 收到字典: {'Yes' if _external_models_dict else 'No'}" + f", 收到单模型: {'Yes' if _external_model else 'No'}") + self._notify("started", "步骤8: 预测水质参数") result = PredictionStep.predict_water_quality( sampling_csv_path=sampling_csv_path, @@ -816,11 +823,15 @@ class WaterQualityInversionPipeline: prediction_column=prediction_column, output_dir=str(self.prediction_dir / "Machine_Learning_Prediction"), _report_generator=self.report_generator, + _external_model=_external_model, + _external_model_path=_external_model_path, + _external_models_dict=_external_models_dict, + _external_model_dir=_external_model_dir, ) self._record_step_time("步骤8: 预测水质参数", 0, 0) self._notify("completed", f"预测完成,结果保存在: {self.prediction_dir}") return result - + def step9_generate_distribution_map(self, prediction_csv_path: str, boundary_shp_path: str, output_image_path: Optional[str] = None, diff --git a/src/gui/core/worker_thread.py b/src/gui/core/worker_thread.py index 8ae50d7..efd90fa 100644 --- a/src/gui/core/worker_thread.py +++ b/src/gui/core/worker_thread.py @@ -333,6 +333,10 @@ class WorkerThread(QThread): val = config.get(key) if val is not None and val != "": step_config[key] = val + if key == '_external_models_dict': + print(f"[Worker] 提取到的外部字典 Keys: {list(val.keys())}") + else: + print(f"[Worker] 透传 {key}: {val}") step_config['skip_dependency_check'] = True diff --git a/src/gui/panels/step8_panel.py b/src/gui/panels/step8_panel.py index 87a75ad..727cbb1 100644 --- a/src/gui/panels/step8_panel.py +++ b/src/gui/panels/step8_panel.py @@ -10,8 +10,10 @@ from pathlib import Path from PyQt5.QtWidgets import ( QWidget, QVBoxLayout, QGroupBox, QFormLayout, QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox, - QFileDialog, QRadioButton, + QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout, + QListWidgetItem, ) +from PyQt5.QtCore import Qt from src.gui.components.custom_widgets import FileSelectWidget from src.gui.styles import ModernStylesheet @@ -75,6 +77,50 @@ class Step8Panel(QWidget): self.external_model_widget.setVisible(False) layout.addWidget(self.external_model_widget) + # -------- 已扫描模型列表(条件显示) -------- + self.model_list_group = QGroupBox("选择参与预测的模型") + self.model_list_group.setVisible(False) + model_list_layout = QVBoxLayout() + + self.model_list = QListWidget() + self.model_list.setMaximumHeight(130) + self.model_list.setSelectionMode(QAbstractItemView.NoSelection) + self.model_list.setStyleSheet(""" + QListWidget { + border: 1px solid #C0C0C0; + border-radius: 4px; + background-color: #FFFFFF; + font-size: 12px; + } + QListWidget::item { + padding: 4px 6px; + border-bottom: 1px solid #F0F0F0; + } + QListWidget::item:selected { + background-color: transparent; + } + """) + model_list_layout.addWidget(self.model_list) + + btn_row = QHBoxLayout() + self.btn_select_all = QPushButton("全选") + self.btn_select_all.setMaximumWidth(80) + self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) + self.btn_select_all.clicked.connect(self._select_all_models) + + self.btn_select_none = QPushButton("全不选") + self.btn_select_none.setMaximumWidth(80) + self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default')) + self.btn_select_none.clicked.connect(self._select_none_models) + + btn_row.addWidget(self.btn_select_all) + btn_row.addWidget(self.btn_select_none) + btn_row.addStretch() + model_list_layout.addLayout(btn_row) + + self.model_list_group.setLayout(model_list_layout) + layout.addWidget(self.model_list_group) + # -------- 采样光谱CSV文件(用于独立运行)-------- self.sampling_csv_file = FileSelectWidget( "采样光谱CSV:", @@ -134,9 +180,11 @@ class Step8Panel(QWidget): return is_external = self.use_external_model.isChecked() self.external_model_widget.setVisible(is_external) + self.model_list_group.setVisible(is_external) if not is_external: self.external_models_dict = {} self.external_model_dir = "" + self._clear_model_list() def _scan_external_model_dir(self): """浏览模型母文件夹,自动扫描子目录中的 .joblib 文件""" @@ -200,9 +248,11 @@ class Step8Panel(QWidget): ) self.external_model_widget.set_path("") self.external_models_dict = {} + self._clear_model_list() return self.external_models_dict = models_found + self._populate_model_list(models_found) names = sorted(models_found.keys()) display = f"已识别到 {len(names)} 个模型: {', '.join(names)}" self.external_model_widget.set_path(display) @@ -216,6 +266,40 @@ class Step8Panel(QWidget): f"加载失败 {len(errors)} 个:\n{err_lines}", ) + def _populate_model_list(self, models_dict): + """将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选""" + self.model_list.clear() + for name in sorted(models_dict.keys()): + item = QListWidgetItem(name) + item.setFlags(item.flags() | Qt.ItemIsUserCheckable) + item.setCheckState(Qt.Checked) + self.model_list.addItem(item) + + def _clear_model_list(self): + """清空模型列表""" + self.model_list.clear() + + def _select_all_models(self): + """全选:设置所有条目为 Checked""" + for i in range(self.model_list.count()): + self.model_list.item(i).setCheckState(Qt.Checked) + + def _select_none_models(self): + """全不选:设置所有条目为 Unchecked""" + for i in range(self.model_list.count()): + self.model_list.item(i).setCheckState(Qt.Unchecked) + + def _get_checked_models_dict(self): + """从列表中提取用户勾选的模型,组装成字典返回""" + result = {} + for i in range(self.model_list.count()): + item = self.model_list.item(i) + if item.checkState() == Qt.Checked: + name = item.text() + if name in self.external_models_dict: + result[name] = self.external_models_dict[name] + return result + def update_from_config(self, work_dir=None, pipeline=None): """从全局配置自动填充采样光谱和模型目录 @@ -347,11 +431,20 @@ class Step8Panel(QWidget): "请先点击「浏览...」按钮选择模型母文件夹!", ) return + # 只传递用户勾选的模型 + checked_dict = self._get_checked_models_dict() + if not checked_dict: + QMessageBox.warning( + self, + "未选择模型", + "请至少勾选一个模型参与预测!", + ) + return main_window = self.window() if hasattr(main_window, 'run_single_step'): config = { 'step8': self.get_config(), - '_external_models_dict': self.external_models_dict, + '_external_models_dict': checked_dict, '_external_model_dir': self.external_model_dir, } main_window.run_single_step('step8', config)