#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step 8 panel - machine learning prediction.
"""

import os
from pathlib import Path

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QGroupBox, QFormLayout, QPushButton,
    QCheckBox, QComboBox, QLineEdit, QMessageBox, QFileDialog,
    QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
    QListWidgetItem,
)
from PyQt5.QtCore import Qt

from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


class Step8Panel(QWidget):
    """Step 8: machine learning prediction.

    The panel lets the user pick where the models come from (the models
    directory of the current workflow, or pre-trained ``.joblib`` files
    scanned from a parent folder), choose the sampling-spectra CSV, the
    prediction parameters and the output path, and launch the step through
    the main window's ``run_single_step``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.external_models_dict = {}  # {subdir_name: model_obj, ...}
        self.external_model_dir = ""    # parent folder of the scanned models (not displayed)
        # Cached working directory; filled in by update_from_config().
        self.work_dir = None
        self.init_ui()

    def init_ui(self):
        """Build all child widgets and wire their signals."""
        layout = QVBoxLayout()

        # -------- Model source selection (radio button group) --------
        source_group = QGroupBox("模型来源")
        source_layout = QVBoxLayout()
        self.use_trained_model = QRadioButton("使用当前训练流程的模型")
        self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
        self.use_trained_model.setChecked(True)
        source_layout.addWidget(self.use_trained_model)
        source_layout.addWidget(self.use_external_model)
        # Both buttons share one slot; the slot ignores the "unchecked" half
        # of each toggle so it effectively runs once per switch.
        self.use_trained_model.toggled.connect(self._on_model_source_changed)
        self.use_external_model.toggled.connect(self._on_model_source_changed)
        source_group.setStyleSheet("""
            QRadioButton {
                font-size: 13px;
                spacing: 8px;
            }
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border-radius: 9px;
                border: 2px solid #A0A0A0;
                background-color: #FFFFFF;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #0078D7;
            }
            QRadioButton::indicator:checked {
                background-color: #0078D7;
                border: 2px solid #0078D7;
            }
        """)
        source_group.setLayout(source_layout)
        layout.addWidget(source_group)

        # -------- External model folder selector (shown conditionally) --------
        self.external_model_widget = FileSelectWidget(
            "模型母文件夹:", "Directories"
        )
        # Replace the widget's own browse handler with the folder scanner.
        self.external_model_widget.browse_btn.clicked.disconnect()
        self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
        self.external_model_widget.setVisible(False)
        layout.addWidget(self.external_model_widget)

        # -------- Scanned model list (shown conditionally) --------
        self.model_list_group = QGroupBox("选择参与预测的模型")
        self.model_list_group.setVisible(False)
        model_list_layout = QVBoxLayout()
        self.model_list = QListWidget()
        self.model_list.setMaximumHeight(130)
        # Items are toggled through their check boxes, not through selection.
        self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
        self.model_list.setStyleSheet("""
            QListWidget {
                border: 1px solid #C0C0C0;
                border-radius: 4px;
                background-color: #FFFFFF;
                font-size: 12px;
            }
            QListWidget::item {
                padding: 4px 6px;
                border-bottom: 1px solid #F0F0F0;
            }
            QListWidget::item:selected {
                background-color: transparent;
            }
        """)
        model_list_layout.addWidget(self.model_list)

        btn_row = QHBoxLayout()
        self.btn_select_all = QPushButton("全选")
        self.btn_select_all.setMaximumWidth(80)
        self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_all.clicked.connect(self._select_all_models)
        self.btn_select_none = QPushButton("全不选")
        self.btn_select_none.setMaximumWidth(80)
        self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_none.clicked.connect(self._select_none_models)
        btn_row.addWidget(self.btn_select_all)
        btn_row.addWidget(self.btn_select_none)
        btn_row.addStretch()
        model_list_layout.addLayout(btn_row)
        self.model_list_group.setLayout(model_list_layout)
        layout.addWidget(self.model_list_group)

        # -------- Sampling spectra CSV (for standalone runs) --------
        self.sampling_csv_file = FileSelectWidget(
            "采样光谱CSV:", "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.sampling_csv_file)

        # Models directory (for standalone runs)
        self.models_dir_file = FileSelectWidget(
            "模型目录:", "Directories;;All Files (*.*)"
        )
        self.models_dir_file.label.setText("模型目录:")
        # Use a directory chooser instead of the widget's default handler.
        self.models_dir_file.browse_btn.clicked.disconnect()
        self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
        layout.addWidget(self.models_dir_file)

        # Prediction parameters
        params_group = QGroupBox("预测参数")
        params_layout = QFormLayout()
        self.metric = QComboBox()
        self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
        params_layout.addRow("模型选择指标:", self.metric)
        self.prediction_column = QLineEdit()
        self.prediction_column.setText("prediction")
        params_layout.addRow("预测列名:", self.prediction_column)
        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # Output path
        self.output_file = FileSelectWidget(
            "输出路径:", "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.output_file)

        # Enable / disable this step
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Standalone run button
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    def _on_model_source_changed(self, checked: bool):
        """Show or hide the external-model widgets when the radio choice changes.

        Args:
            checked: New state of the button that emitted ``toggled``. The
                "unchecked" emission is ignored so the body runs once per
                switch.
        """
        if not checked:
            return
        is_external = self.use_external_model.isChecked()
        self.external_model_widget.setVisible(is_external)
        self.model_list_group.setVisible(is_external)
        if not is_external:
            # Drop everything loaded from the external folder, including the
            # summary text in the path field, so that nothing stale is shown
            # if the user switches back to external models later.
            self.external_models_dict = {}
            self.external_model_dir = ""
            self.external_model_widget.set_path("")
            self._clear_model_list()

    def _scan_external_model_dir(self):
        """Ask for a parent folder and load one ``.joblib`` model per sub-folder.

        Each direct sub-directory contributes at most one model, keyed by the
        sub-directory name. A loaded object is accepted when it is a dict
        with a ``"model"`` key (that value is used) or when it has a
        ``predict`` attribute. Panel state (``external_model_dir``,
        ``external_models_dict``, the check list) is only updated once the
        scan has finished, so a failed scan leaves the previous state intact.
        """
        default = self._get_default_work_dir()
        if default:
            default = os.path.join(default, "7_Supervised_Model_Training")
        dir_path = QFileDialog.getExistingDirectory(
            self, "选择模型母文件夹", default,
        )
        if not dir_path:
            return

        models_found = {}
        errors = []
        try:
            import joblib
            with os.scandir(dir_path) as subentries:
                for subentry in subentries:
                    if not subentry.is_dir():
                        continue
                    subdir_name = subentry.name
                    with os.scandir(subentry.path) as files:
                        joblib_paths = sorted(
                            f.path for f in files
                            if f.is_file() and f.name.lower().endswith(".joblib")
                        )
                    if not joblib_paths:
                        continue
                    # Only one .joblib file per sub-directory is used. The
                    # candidates are sorted so the choice does not depend on
                    # the directory listing order.
                    # NOTE(review): the original comment says this mirrors
                    # the batch logic - confirm that code picks the same file.
                    joblib_path = joblib_paths[0]
                    try:
                        # NOTE(security): joblib.load is pickle-based and can
                        # execute arbitrary code; only load trusted files.
                        loaded = joblib.load(joblib_path)
                        if isinstance(loaded, dict) and "model" in loaded:
                            model_obj = loaded["model"]
                        elif hasattr(loaded, "predict"):
                            model_obj = loaded
                        else:
                            errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
                            continue
                        models_found[subdir_name] = model_obj
                    except Exception as e:
                        # Best effort: a broken file must not abort the scan.
                        errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
        except Exception as e:
            QMessageBox.warning(
                self, "扫描失败",
                f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
            )
            return

        if not models_found:
            if errors:
                # Files were found but none could be used: show the reasons
                # instead of claiming that no .joblib file exists.
                err_lines = "\n".join(errors)
                QMessageBox.warning(
                    self, "模型加载失败",
                    f"在「{dir_path}」的子目录中发现了 .joblib 文件,但全部加载失败:\n{err_lines}",
                )
            else:
                QMessageBox.warning(
                    self, "未找到模型",
                    f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
                    "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
                )
            self.external_model_widget.set_path("")
            self.external_models_dict = {}
            self.external_model_dir = ""
            self._clear_model_list()
            return

        self.external_model_dir = dir_path
        self.external_models_dict = models_found
        self._populate_model_list(models_found)

        names = sorted(models_found.keys())
        display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
        # The path field shows a summary rather than the folder path.
        self.external_model_widget.set_path(display)
        self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")

        err_lines = "\n".join(errors) if errors else "无"
        QMessageBox.information(
            self, "模型扫描完成",
            f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
            f"加载失败 {len(errors)} 个:\n{err_lines}",
        )

    def _populate_model_list(self, models_dict):
        """Fill the list widget with one checkable, pre-checked item per model.

        Args:
            models_dict: Mapping whose keys are the model names to list; the
                names are shown in sorted order.
        """
        self.model_list.clear()
        for name in sorted(models_dict.keys()):
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked)
            self.model_list.addItem(item)

    def _clear_model_list(self):
        """Remove every item from the model list."""
        self.model_list.clear()

    def _select_all_models(self):
        """Set every list item to Checked."""
        for i in range(self.model_list.count()):
            self.model_list.item(i).setCheckState(Qt.Checked)

    def _select_none_models(self):
        """Set every list item to Unchecked."""
        for i in range(self.model_list.count()):
            self.model_list.item(i).setCheckState(Qt.Unchecked)

    def _get_checked_models_dict(self):
        """Return ``{name: model}`` for the items the user left checked.

        Names that are no longer present in ``external_models_dict`` are
        skipped.
        """
        result = {}
        for i in range(self.model_list.count()):
            item = self.model_list.item(i)
            if item.checkState() == Qt.Checked:
                name = item.text()
                if name in self.external_models_dict:
                    result[name] = self.external_models_dict[name]
        return result

    @staticmethod
    def _read_widget_path(widget):
        """Return the path held by *widget* via ``get_path()`` or ``text()``, else ""."""
        if hasattr(widget, 'get_path'):
            return widget.get_path() or ""
        if hasattr(widget, 'text'):
            return widget.text() or ""
        return ""

    def _autofill_from_panel(self, panel, attr_name, target):
        """Copy the path of ``panel.<attr_name>`` into *target* if *target* is empty."""
        source_path = self._read_widget_path(getattr(panel, attr_name, None))
        if not source_path:
            return
        # Relative paths are resolved against work_dir.
        if not os.path.isabs(source_path):
            source_path = os.path.join(self.work_dir or '', source_path).replace('\\', '/')
        existing = target.get_path()
        if not existing or not existing.strip():
            target.set_path(source_path)

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the sampling CSV, models directory and output path.

        Fields that already contain text are never overwritten. Any error is
        printed with its traceback and otherwise ignored, so auto-fill can
        never break the caller.

        Args:
            work_dir: Working directory path. When falsy, the previously
                cached value (if any) is kept.
            pipeline: Pipeline instance (unused, kept for interface
                compatibility).
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None

            main_window = self.window()

            # 1. Sampling CSV: taken from the main window's step7_panel.output_file.
            if main_window and hasattr(main_window, 'step7_panel'):
                self._autofill_from_panel(
                    main_window.step7_panel, 'output_file', self.sampling_csv_file
                )

            # 2. Models directory: taken from the main window's step6_panel.output_dir.
            if main_window and hasattr(main_window, 'step6_panel'):
                self._autofill_from_panel(
                    main_window.step6_panel, 'output_dir', self.models_dir_file
                )

            # 3. Output path: the machine-learning prediction directory under
            #    work_dir (created if missing).
            if self.work_dir:
                output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
                os.makedirs(output_dir, exist_ok=True)
                existing_out = self.output_file.get_path()
                if not existing_out or not existing_out.strip():
                    self.output_file.set_path(output_dir)
        except Exception as e:
            import traceback
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    def _get_default_work_dir(self):
        """Return the working directory as a string, or "" when unknown.

        The panel's own cached ``work_dir`` wins; otherwise the ``work_dir``
        attribute of the top-level window is used.
        """
        own = getattr(self, 'work_dir', None)
        if own:
            return str(own)
        mw = self.window()
        if mw and hasattr(mw, 'work_dir') and mw.work_dir:
            return str(mw.work_dir)
        return ""

    def browse_models_dir(self):
        """Open a directory chooser and store the selection as the models directory."""
        default = self._get_default_work_dir()
        if default:
            default = os.path.join(default, "7_Supervised_Model_Training")
        dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
        if dir_path:
            self.models_dir_file.set_path(dir_path)

    def get_config(self):
        """Return the panel settings as a dict.

        ``metric`` and ``prediction_column`` are always present;
        ``sampling_csv_path``, ``models_dir`` and ``output_path`` are included
        only when the corresponding field is non-empty.
        """
        config = {
            'metric': self.metric.currentText(),
            'prediction_column': self.prediction_column.text(),
        }
        sampling_csv_path = self.sampling_csv_file.get_path()
        if sampling_csv_path:
            config['sampling_csv_path'] = sampling_csv_path
        models_dir = self.models_dir_file.get_path()
        if models_dir:
            config['models_dir'] = models_dir
        output_path = self.output_file.get_path()
        if output_path:
            config['output_path'] = output_path
        return config

    def set_config(self, config):
        """Apply the keys present in *config* to the widgets.

        Args:
            config: Dict using the same keys as :meth:`get_config`. Missing
                keys leave the widget untouched; an unknown ``metric`` value
                is ignored. ``None`` or an empty dict is a no-op.
        """
        if not config:
            return
        if 'metric' in config:
            idx = self.metric.findText(config['metric'])
            if idx >= 0:
                self.metric.setCurrentIndex(idx)
        if 'prediction_column' in config:
            self.prediction_column.setText(config['prediction_column'])
        if 'sampling_csv_path' in config:
            self.sampling_csv_file.set_path(config['sampling_csv_path'])
        if 'models_dir' in config:
            self.models_dir_file.set_path(config['models_dir'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])

    def run_step(self):
        """Validate the inputs and run step 8 on its own via the main window.

        With "external models" selected, the checked models and their parent
        folder are passed under ``_external_models_dict`` and
        ``_external_model_dir``; otherwise the models directory from the panel
        configuration is used. Nothing is launched when the top-level window
        has no ``run_single_step``.
        """
        sampling_csv_path = self.sampling_csv_file.get_path()
        if not sampling_csv_path:
            QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
            return

        main_window = self.window()

        # External models take precedence when the user chose to import them.
        if self.use_external_model.isChecked():
            if not self.external_models_dict:
                QMessageBox.warning(
                    self, "模型未加载",
                    "请先点击「浏览...」按钮选择模型母文件夹!",
                )
                return
            # Only the models the user left checked are passed on.
            checked_dict = self._get_checked_models_dict()
            if not checked_dict:
                QMessageBox.warning(
                    self, "未选择模型",
                    "请至少勾选一个模型参与预测!",
                )
                return
            if hasattr(main_window, 'run_single_step'):
                config = {
                    'step8': self.get_config(),
                    '_external_models_dict': checked_dict,
                    '_external_model_dir': self.external_model_dir,
                }
                main_window.run_single_step('step8', config)
            # Never fall through to the models-directory flow from here.
            return

        # Default flow: use the models directory.
        models_dir = self.models_dir_file.get_path()
        if not models_dir:
            QMessageBox.warning(self, "输入错误", "请选择模型目录!")
            return
        if hasattr(main_window, 'run_single_step'):
            config = {'step8': self.get_config()}
            main_window.run_single_step('step8', config)