#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step9 panel - machine-learning prediction (pipeline step ``step9_ml_predict``).
"""
import os
import sys
from pathlib import Path

# Path-normalisation helpers (counterpart of pipeline.get_step_output_dir).
_HERE = os.path.dirname(os.path.abspath(__file__))
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)

from _step_path_resolver import get_step_output_path, resolve_step_widget, resolve_subdir

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QGroupBox, QFormLayout, QPushButton, QCheckBox,
    QComboBox, QLineEdit, QMessageBox, QFileDialog, QRadioButton,
    QListWidget, QAbstractItemView, QHBoxLayout, QListWidgetItem,
)
from PyQt5.QtCore import Qt

from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


class Step9MlPredictPanel(QWidget):
    """Step 9: machine-learning prediction panel.

    The user chooses the model source: either the models of the current
    training run (via a models directory) or pre-trained ``.joblib`` models
    scanned from the sub-folders of a local parent folder. The panel also
    collects the sampling-spectra CSV, the prediction parameters and the
    output path, and can request a standalone run of ``step9_ml_predict``.
    """

    # Pipeline step key used in configs and single-step run requests.
    _STEP_KEY = 'step9_ml_predict'

    def __init__(self, parent=None):
        super().__init__(parent)
        # {subdir_name: model_obj, ...} for models loaded from the external folder.
        self.external_models_dict = {}
        # Parent folder the external models were scanned from (not shown in the UI).
        self.external_model_dir = ""
        # Working directory; filled by update_from_config().
        self.work_dir = None
        self.init_ui()

    # ------------------------------------------------------------------ UI

    def init_ui(self):
        """Build all widgets and layouts of the panel."""
        layout = QVBoxLayout()

        # -------- Model source (radio-button group) --------
        source_group = QGroupBox("模型来源")
        source_layout = QVBoxLayout()
        self.use_trained_model = QRadioButton("使用当前训练流程的模型")
        self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
        self.use_trained_model.setChecked(True)
        source_layout.addWidget(self.use_trained_model)
        source_layout.addWidget(self.use_external_model)
        self.use_trained_model.toggled.connect(self._on_model_source_changed)
        self.use_external_model.toggled.connect(self._on_model_source_changed)
        source_group.setStyleSheet("""
            QRadioButton {
                font-size: 13px;
                spacing: 8px;
            }
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border-radius: 9px;
                border: 2px solid #A0A0A0;
                background-color: #FFFFFF;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #0078D7;
            }
            QRadioButton::indicator:checked {
                background-color: #0078D7;
                border: 2px solid #0078D7;
            }
        """)
        source_group.setLayout(source_layout)
        layout.addWidget(source_group)

        # -------- External model parent folder (only for the external source) --------
        self.external_model_widget = FileSelectWidget(
            "模型母文件夹:",
            "Directories"
        )
        # Replace the widget's default browse behaviour with folder scanning.
        self.external_model_widget.browse_btn.clicked.disconnect()
        self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
        self.external_model_widget.setVisible(False)
        layout.addWidget(self.external_model_widget)

        # -------- Scanned model list (only for the external source) --------
        self.model_list_group = QGroupBox("选择参与预测的模型")
        self.model_list_group.setVisible(False)
        model_list_layout = QVBoxLayout()

        self.model_list = QListWidget()
        self.model_list.setMaximumHeight(130)
        # Items are toggled via their check boxes, not via row selection.
        self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
        self.model_list.setStyleSheet("""
            QListWidget {
                border: 1px solid #C0C0C0;
                border-radius: 4px;
                background-color: #FFFFFF;
                font-size: 12px;
            }
            QListWidget::item {
                padding: 4px 6px;
                border-bottom: 1px solid #F0F0F0;
            }
            QListWidget::item:selected {
                background-color: transparent;
            }
        """)
        model_list_layout.addWidget(self.model_list)

        btn_row = QHBoxLayout()
        self.btn_select_all = QPushButton("全选")
        self.btn_select_all.setMaximumWidth(80)
        self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_all.clicked.connect(self._select_all_models)

        self.btn_select_none = QPushButton("全不选")
        self.btn_select_none.setMaximumWidth(80)
        self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_none.clicked.connect(self._select_none_models)

        btn_row.addWidget(self.btn_select_all)
        btn_row.addWidget(self.btn_select_none)
        btn_row.addStretch()
        model_list_layout.addLayout(btn_row)

        self.model_list_group.setLayout(model_list_layout)
        layout.addWidget(self.model_list_group)

        # -------- Sampling-spectra CSV (for standalone runs) --------
        self.sampling_csv_file = FileSelectWidget(
            "采样光谱CSV:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.sampling_csv_file)

        # -------- Models directory (for standalone runs) --------
        self.models_dir_file = FileSelectWidget(
            "模型目录:",
            "Directories;;All Files (*.*)"
        )
        self.models_dir_file.label.setText("模型目录:")
        # Browse for a directory instead of a file.
        self.models_dir_file.browse_btn.clicked.disconnect()
        self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
        layout.addWidget(self.models_dir_file)

        # -------- Prediction parameters --------
        params_group = QGroupBox("预测参数")
        params_layout = QFormLayout()

        self.metric = QComboBox()
        self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
        params_layout.addRow("模型选择指标:", self.metric)

        self.prediction_column = QLineEdit()
        self.prediction_column.setText("prediction")
        params_layout.addRow("预测列名:", self.prediction_column)

        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # -------- Output path --------
        self.output_file = FileSelectWidget(
            "输出路径:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.output_file)

        # -------- Enable step --------
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # -------- Standalone-run button --------
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self._on_run_single_clicked)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    # ------------------------------------------------------ model source

    def _on_model_source_changed(self, checked: bool):
        """Show/hide the external-model widgets when the source radio changes.

        Both radio buttons are connected to this slot. The un-checked button
        (``checked == False``) is ignored, so the logic runs once per switch.
        Switching back to the trained-model source discards any external
        models, including their on-screen summary.
        """
        if not checked:
            return
        is_external = self.use_external_model.isChecked()
        self.external_model_widget.setVisible(is_external)
        self.model_list_group.setVisible(is_external)
        if not is_external:
            self._reset_external_models()

    def _reset_external_models(self):
        """Forget externally loaded models and clear their list and summary text."""
        self.external_models_dict = {}
        self.external_model_dir = ""
        self._clear_model_list()
        self.external_model_widget.set_path("")
        # Undo the highlight applied after a successful scan.
        self.external_model_widget.line_edit.setStyleSheet("")

    @staticmethod
    def _load_models_from_dir(dir_path):
        """Load one ``.joblib`` model from each direct sub-folder of *dir_path*.

        Sub-folders are visited in name order. Inside a sub-folder, the
        first ``.joblib`` file in name order is used, so the result is
        deterministic when several files are present. A loaded object is
        accepted if it is a dict with a ``"model"`` key (that value is
        used) or has a ``predict`` attribute.

        Returns:
            (models_found, errors): ``{subdir_name: model_obj}`` and a list of
            per-sub-folder error strings.

        Raises:
            Whatever ``import joblib`` or ``os.scandir`` raise. Per-file load
            failures are collected in ``errors`` instead.
        """
        import joblib

        models_found = {}
        errors = []
        with os.scandir(dir_path) as entries:
            subdirs = sorted((e for e in entries if e.is_dir()), key=lambda e: e.name)

        for subentry in subdirs:
            subdir_name = subentry.name
            with os.scandir(subentry.path) as files:
                joblib_paths = sorted(
                    f.path for f in files
                    if f.is_file() and f.name.lower().endswith(".joblib")
                )
            if not joblib_paths:
                continue
            try:
                # NOTE(security): joblib.load unpickles the file and can execute
                # arbitrary code; only load models from trusted folders.
                loaded = joblib.load(joblib_paths[0])
            except Exception as e:
                errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
                continue

            if isinstance(loaded, dict) and "model" in loaded:
                models_found[subdir_name] = loaded["model"]
            elif hasattr(loaded, "predict"):
                models_found[subdir_name] = loaded
            else:
                errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
        return models_found, errors

    def _scan_external_model_dir(self):
        """Pick a parent folder and load the ``.joblib`` models in its sub-folders.

        On success, the loaded models replace the previous ones and are shown
        in the check list, all checked. If no model could be loaded, the
        external state is cleared. If scanning itself fails, the previously
        loaded models (and their folder) are kept unchanged.
        """
        default = self._get_default_work_dir()
        if default:
            default = resolve_subdir(default, 'ml_prediction')
        dir_path = QFileDialog.getExistingDirectory(
            self,
            "选择模型母文件夹",
            default,
        )
        if not dir_path:
            return

        try:
            models_found, errors = self._load_models_from_dir(dir_path)
        except Exception as e:
            QMessageBox.warning(
                self,
                "扫描失败",
                f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
            )
            return

        if not models_found:
            if errors:
                # .joblib files exist but none could be loaded: show why.
                message = (
                    f"在「{dir_path}」的子目录中未能成功加载任何 .joblib 模型:\n"
                    + "\n".join(errors)
                )
            else:
                message = (
                    f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
                    "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。"
                )
            QMessageBox.warning(self, "未找到模型", message)
            self._reset_external_models()
            return

        # Only commit the folder once its models were actually loaded.
        self.external_model_dir = dir_path
        self.external_models_dict = models_found
        self._populate_model_list(models_found)

        names = sorted(models_found)
        display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
        self.external_model_widget.set_path(display)
        self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")

        err_lines = "\n".join(errors) if errors else "无"
        QMessageBox.information(
            self,
            "模型扫描完成",
            f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
            f"加载失败 {len(errors)} 个:\n{err_lines}",
        )

    # -------------------------------------------------------- model list

    def _populate_model_list(self, models_dict):
        """Fill the list with one checkable item per model name (sorted, all checked)."""
        self.model_list.clear()
        for name in sorted(models_dict):
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked)
            self.model_list.addItem(item)

    def _clear_model_list(self):
        """Remove all items from the model list."""
        self.model_list.clear()

    def _set_all_check_states(self, state):
        """Apply *state* (a ``Qt.CheckState``) to every item in the model list."""
        for i in range(self.model_list.count()):
            self.model_list.item(i).setCheckState(state)

    def _select_all_models(self):
        """Check every model."""
        self._set_all_check_states(Qt.Checked)

    def _select_none_models(self):
        """Uncheck every model."""
        self._set_all_check_states(Qt.Unchecked)

    def _get_checked_models_dict(self):
        """Return ``{name: model_obj}`` for the checked items that have a loaded model."""
        result = {}
        for i in range(self.model_list.count()):
            item = self.model_list.item(i)
            if item.checkState() != Qt.Checked:
                continue
            name = item.text()
            if name in self.external_models_dict:
                result[name] = self.external_models_dict[name]
        return result

    # ------------------------------------------------------------ config

    @staticmethod
    def _fill_if_empty(widget, value):
        """Set *value* on a FileSelectWidget only if it currently holds no path."""
        existing = widget.get_path()
        if not existing or not existing.strip():
            widget.set_path(value)

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the sampling CSV, models directory and output path.

        Only empty fields are filled. Failures are printed and otherwise
        ignored (best effort).

        Args:
            work_dir: Working directory. If omitted, a previously stored
                ``self.work_dir`` is kept.
            pipeline: Pipeline instance (unused; kept for interface compatibility).
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None

            main_window = self.window()

            # 1. Sampling CSV from the Step4 (sampling layout) panel's output widget.
            if main_window and hasattr(main_window, 'step4_sampling_panel'):
                step4_widget = getattr(main_window.step4_sampling_panel, 'output_file', None)
                step4_output_path = ""
                if hasattr(step4_widget, 'get_path'):
                    step4_output_path = step4_widget.get_path() or ""
                elif hasattr(step4_widget, 'text'):
                    step4_output_path = step4_widget.text() or ""
                if step4_output_path:
                    if not os.path.isabs(step4_output_path):
                        step4_output_path = os.path.join(
                            self.work_dir or '', step4_output_path
                        ).replace('\\', '/')
                    self._fill_if_empty(self.sampling_csv_file, step4_output_path)

            # 2. Models directory from Step8 (supervised training); earlier code
            #    referenced a non-existent main_window.step9_panel here.
            step8_models_dir = get_step_output_path(
                main_window, 'models_dir',
                work_dir=self.work_dir,
                widget_attr='output_dir',
                fallback_key='step8_ml_train',
            )
            if step8_models_dir:
                self._fill_if_empty(self.models_dir_file, step8_models_dir)

            # 3. Output path: the ML-prediction sub-directory of the work dir
            #    (created if missing), resolved via resolve_subdir('ml_prediction').
            if self.work_dir:
                output_dir = resolve_subdir(self.work_dir, 'ml_prediction')
                os.makedirs(output_dir, exist_ok=True)
                self._fill_if_empty(self.output_file, output_dir)
        except Exception as e:
            import traceback
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    def _get_default_work_dir(self):
        """Return the work dir: this panel's own value first, else the main window's, else ""."""
        if getattr(self, 'work_dir', None):
            return str(self.work_dir)
        mw = self.window()
        if mw and hasattr(mw, 'work_dir') and mw.work_dir:
            return str(mw.work_dir)
        return ""

    def browse_models_dir(self):
        """Let the user pick the models directory."""
        default = self._get_default_work_dir()
        if default:
            default = resolve_subdir(default, 'ml_prediction')
        dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
        if dir_path:
            self.models_dir_file.set_path(dir_path)

    def get_config(self):
        """Return the step config. Path keys are included only when non-empty."""
        config = {
            'metric': self.metric.currentText(),
            'prediction_column': self.prediction_column.text(),
        }
        sampling_csv_path = self.sampling_csv_file.get_path()
        if sampling_csv_path:
            config['sampling_csv_path'] = sampling_csv_path
        models_dir = self.models_dir_file.get_path()
        if models_dir:
            config['models_dir'] = models_dir
        output_path = self.output_file.get_path()
        if output_path:
            config['output_path'] = output_path
        return config

    def set_config(self, config):
        """Apply the keys present in *config* to the widgets; unknown metrics are ignored."""
        if 'metric' in config:
            idx = self.metric.findText(config['metric'])
            if idx >= 0:
                self.metric.setCurrentIndex(idx)
        if 'prediction_column' in config:
            self.prediction_column.setText(config['prediction_column'])
        if 'sampling_csv_path' in config:
            self.sampling_csv_file.set_path(config['sampling_csv_path'])
        if 'models_dir' in config:
            self.models_dir_file.set_path(config['models_dir'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])

    # -------------------------------------------------------- single run

    def _build_single_run_config(self):
        """Validate inputs for a standalone run and build its run config.

        Returns:
            ``{'step9_ml_predict': ...}``, plus ``'_external_models_dict'`` and
            ``'_external_model_dir'`` when the external model source is
            selected. Returns ``None`` after showing a warning dialog if a
            required input is missing.
        """
        sampling_csv_path = self.sampling_csv_file.get_path()
        if not sampling_csv_path:
            QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
            return None

        # External models take precedence when "import local models" is selected.
        if self.use_external_model.isChecked():
            if not self.external_models_dict:
                QMessageBox.warning(
                    self,
                    "模型未加载",
                    "请先点击「浏览...」按钮选择模型母文件夹!",
                )
                return None
            checked_dict = self._get_checked_models_dict()
            if not checked_dict:
                QMessageBox.warning(
                    self,
                    "未选择模型",
                    "请至少勾选一个模型参与预测!",
                )
                return None
            return {
                self._STEP_KEY: self.get_config(),
                '_external_models_dict': checked_dict,
                '_external_model_dir': self.external_model_dir,
            }

        # Default flow: models from the models directory.
        models_dir = self.models_dir_file.get_path()
        if not models_dir:
            QMessageBox.warning(self, "输入错误", "请选择模型目录!")
            return None
        return {self._STEP_KEY: self.get_config()}

    def _on_run_single_clicked(self):
        """Publish a single-step run request on the EventBus (decouples panel and executor)."""
        from src.gui.core.event_bus import global_event_bus

        config = self._build_single_run_config()
        if config is None:
            return
        global_event_bus.publish('RequestRunSingleStep', {
            'step_name': self._STEP_KEY,
            'config': config,
        })

    def run_step(self):
        """Run this step standalone via ``main_window.run_single_step`` (legacy path, kept for compatibility).

        Does nothing after validation if the top-level window has no
        ``run_single_step`` method.
        """
        config = self._build_single_run_config()
        if config is None:
            return
        main_window = self.window()
        if hasattr(main_window, 'run_single_step'):
            main_window.run_single_step(self._STEP_KEY, config)