#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8 panel - machine learning prediction.
"""

import os
from pathlib import Path

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QGroupBox, QFormLayout,
    QPushButton, QCheckBox, QComboBox, QLineEdit,
    QMessageBox, QFileDialog,
)

from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


class Step8Panel(QWidget):
    """Step 8: machine learning prediction.

    Collects the sampled-spectra CSV, the trained-model directory, the
    prediction parameters and the output path. It can hand them to the main
    window's ``run_single_step`` to run this step on its own.
    """

    # Used when the user leaves the prediction-column field blank; matches
    # the initial text that init_ui puts into that field.
    DEFAULT_PREDICTION_COLUMN = "prediction"

    # Sub-directory of work_dir that update_from_config pre-fills as output.
    DEFAULT_OUTPUT_SUBDIR = "11_12_13_predictions/Machine_Learning_Prediction"

    # Sub-directory of work_dir where the model-directory browser starts.
    DEFAULT_MODELS_SUBDIR = "7_Supervised_Model_Training"

    def __init__(self, parent=None):
        super().__init__(parent)
        self.init_ui()

    def init_ui(self):
        """Build the panel's widgets and layout."""
        layout = QVBoxLayout()

        # Sampled-spectra CSV file (used when running this step standalone).
        self.sampling_csv_file = FileSelectWidget(
            "采样光谱CSV:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.sampling_csv_file)

        # Model directory (used when running this step standalone). The
        # widget's default file browser is replaced with a directory picker.
        self.models_dir_file = FileSelectWidget(
            "模型目录:",
            "Directories;;All Files (*.*)"
        )
        self.models_dir_file.label.setText("模型目录:")
        self.models_dir_file.browse_btn.clicked.disconnect()
        self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
        layout.addWidget(self.models_dir_file)

        # Prediction parameters.
        params_group = QGroupBox("预测参数")
        params_layout = QFormLayout()

        self.metric = QComboBox()
        self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
        params_layout.addRow("模型选择指标:", self.metric)

        self.prediction_column = QLineEdit()
        self.prediction_column.setText(self.DEFAULT_PREDICTION_COLUMN)
        params_layout.addRow("预测列名:", self.prediction_column)

        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # Output path.
        self.output_file = FileSelectWidget(
            "输出路径:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.output_file)

        # Enable/disable this step.
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Button to run this step standalone.
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the sampled-spectra CSV, model directory and output path.

        Only fields that are currently empty are filled, so values the user
        entered are never overwritten. Any failure is printed and swallowed
        because auto-fill is best-effort.

        Args:
            work_dir: Working directory path. When falsy, a previously cached
                ``self.work_dir`` is kept (or ``None`` if there is none).
            pipeline: Pipeline instance (unused; kept for interface
                compatibility).
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None

            main_window = self.window()

            # 1. Sampled-spectra CSV: reuse the Step7 panel's output path
            #    (whole-lake sampling points).
            step7_csv = self._read_panel_path(main_window, 'step7_panel', 'output_file')
            if step7_csv:
                self._fill_if_empty(
                    self.sampling_csv_file, self._resolve_against_work_dir(step7_csv)
                )

            # 2. Model directory: reuse the Step6 panel's output directory
            #    (supervised models).
            step6_dir = self._read_panel_path(main_window, 'step6_panel', 'output_dir')
            if step6_dir:
                self._fill_if_empty(
                    self.models_dir_file, self._resolve_against_work_dir(step6_dir)
                )

            # 3. Output path: the machine-learning prediction folder under
            #    work_dir. The folder is created even if the field is already
            #    filled (same as the original behavior).
            if self.work_dir:
                output_dir = os.path.join(self.work_dir, self.DEFAULT_OUTPUT_SUBDIR)
                os.makedirs(output_dir, exist_ok=True)
                self._fill_if_empty(self.output_file, output_dir)
        except Exception as e:
            import traceback
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    @staticmethod
    def _read_panel_path(main_window, panel_attr, widget_attr):
        """Return the stripped path text of ``main_window.<panel_attr>.<widget_attr>``, or ""."""
        if not main_window or not hasattr(main_window, panel_attr):
            return ""
        widget = getattr(getattr(main_window, panel_attr), widget_attr, None)
        if hasattr(widget, 'get_path'):
            value = widget.get_path() or ""
        elif hasattr(widget, 'text'):
            value = widget.text() or ""
        else:
            value = ""
        return value.strip()

    def _resolve_against_work_dir(self, path):
        """Make a relative ``path`` absolute against ``self.work_dir`` (forward slashes)."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    @staticmethod
    def _fill_if_empty(widget, path):
        """Set ``path`` on ``widget`` only when the widget currently holds a blank path."""
        existing = widget.get_path()
        if not existing or not existing.strip():
            widget.set_path(path)

    def _get_default_work_dir(self):
        """Return work_dir as a string.

        Prefers the value cached on this panel, then the main window's
        ``work_dir``; returns "" when neither is set.
        """
        if getattr(self, 'work_dir', None):
            return str(self.work_dir)
        mw = self.window()
        if mw and getattr(mw, 'work_dir', None):
            return str(mw.work_dir)
        return ""

    def browse_models_dir(self):
        """Let the user pick the model directory with a directory dialog.

        The dialog starts in the currently selected directory if it exists.
        Otherwise it starts in ``<work_dir>/7_Supervised_Model_Training``,
        falling back to ``work_dir`` itself when that sub-directory is missing.
        """
        start_dir = (self.models_dir_file.get_path() or "").strip()
        if not os.path.isdir(start_dir):
            start_dir = self._get_default_work_dir()
            if start_dir:
                candidate = os.path.join(start_dir, self.DEFAULT_MODELS_SUBDIR)
                if os.path.isdir(candidate):
                    start_dir = candidate
        dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", start_dir)
        if dir_path:
            self.models_dir_file.set_path(dir_path)

    def get_config(self):
        """Return this step's configuration dict.

        ``metric`` and ``prediction_column`` are always present. A blank
        prediction column falls back to the default name. The path keys
        (``sampling_csv_path``, ``models_dir``, ``output_path``) are included
        only when the corresponding field is non-blank.
        """
        config = {
            'metric': self.metric.currentText(),
            'prediction_column': (
                self.prediction_column.text().strip() or self.DEFAULT_PREDICTION_COLUMN
            ),
        }
        for key, widget in (
            ('sampling_csv_path', self.sampling_csv_file),
            ('models_dir', self.models_dir_file),
            ('output_path', self.output_file),
        ):
            path = (widget.get_path() or "").strip()
            if path:
                config[key] = path
        return config

    def set_config(self, config):
        """Load the widgets from a config dict produced by ``get_config``.

        Missing keys are left untouched; a ``None``/empty config is a no-op.
        An unknown ``metric`` value keeps the current selection.
        """
        if not config:
            return
        if 'metric' in config:
            idx = self.metric.findText(config['metric'])
            if idx >= 0:
                self.metric.setCurrentIndex(idx)
        if 'prediction_column' in config:
            self.prediction_column.setText(config['prediction_column'])
        if 'sampling_csv_path' in config:
            self.sampling_csv_file.set_path(config['sampling_csv_path'])
        if 'models_dir' in config:
            self.models_dir_file.set_path(config['models_dir'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])

    def run_step(self):
        """Validate the inputs and ask the main window to run step 8 on its own.

        Warns the user, instead of doing nothing, when a required path is
        blank or the main window offers no ``run_single_step`` method.
        """
        config = self.get_config()

        if not config.get('sampling_csv_path'):
            QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
            return
        if not config.get('models_dir'):
            QMessageBox.warning(self, "输入错误", "请选择模型目录!")
            return

        main_window = self.window()
        run_single_step = getattr(main_window, 'run_single_step', None)
        if not callable(run_single_step):
            QMessageBox.warning(self, "运行错误", "未找到主窗口的运行接口,无法独立运行此步骤!")
            return
        run_single_step('step8', {'step8': config})