#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Step 8 panel - machine learning modeling.

Defines :class:`Step8MlTrainPanel`, a Qt widget for choosing the training
CSV, preprocessing methods, model types and data split methods of the
machine-learning training step, and for launching that step on its own.
"""

import os

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
    QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox, QPushButton,
    QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt

from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


# ============================================================
# Display-text tables (internal key -> text shown in the UI)
# ============================================================

# Preprocessing methods: internal key -> display text
PREPROC_CHINESE = {
    'None': '无 (None)',
    'MMS': '最小-最大归一化 (MMS)',
    'SS': '标度化 (SS)',
    'SNV': '标准正态变换 (SNV)',
    'MA': '移动平均 (MA)',
    'SG': 'Savitzky-Golay (SG)',
    'MSC': '多元散射校正 (MSC)',
    'D1': '一阶导数 (D1)',
    'D2': '二阶导数 (D2)',
    'DT': '去趋势 (DT)',
    'CT': '中心化 (CT)',
}

# Model types: internal key -> display text
MODEL_CHINESE = {
    # Linear models
    'LinearRegression': '多元线性回归 (MLR)',
    'Ridge': '岭回归 (Ridge)',
    'Lasso': '套索回归 (Lasso)',
    'ElasticNet': '弹性网络 (ElasticNet)',
    'PLS': '偏最小二乘 (PLSR)',
    # Tree models
    'DecisionTree': '决策树 (CART)',
    'RF': '随机森林 (RF)',
    'ExtraTrees': '极端随机树 (ET)',
    'XGBoost': '极值梯度提升 (XGBoost)',
    'LightGBM': '轻量梯度提升 (LightGBM)',
    'CatBoost': '类别梯度提升 (CatBoost)',
    # Ensemble learning
    'GradientBoosting': '梯度提升树 (GBDT)',
    'AdaBoost': '自适应提升 (AdaBoost)',
    # Other models
    'SVR': '支持向量回归 (SVR)',
    'KNN': 'K近邻回归 (KNN)',
    'MLP': '多层感知机 (BP神经网络)',
}

# Data split methods: internal key -> display text
SPLIT_CHINESE = {
    'spxy': 'SPXY 算法 (考量X-Y空间)',
    'ks': 'KS 算法 (考量X空间)',
    'random': '随机划分 (Random)',
}


class Step8MlTrainPanel(QWidget):
    """Step 8: configuration panel for machine-learning model training."""

    # Number of checkbox columns in the preprocessing and model grids.
    _GRID_COLUMNS = 4
    # Sub-directory of the work dir used for the auto-generated output file.
    _OUTPUT_SUBDIR = "6_water_quality_indices"
    # Output file name used when no training CSV is selected.
    _DEFAULT_OUTPUT_NAME = "water_quality_indices.csv"
    _CSV_FILTER = "CSV Files (*.csv);;All Files (*.*)"

    def __init__(self, parent=None):
        super().__init__(parent)
        # Cached work directory; filled in by update_from_config().
        self.work_dir = None
        self.init_ui()

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------
    def init_ui(self):
        """Build the top-level layout of the panel."""
        layout = QVBoxLayout()

        # Training data file (needed when the step is run on its own).
        self.training_csv_file = FileSelectWidget("训练数据:", self._CSV_FILTER)
        layout.addWidget(self.training_csv_file)

        # Machine-learning options page.
        self.ml_page = QWidget()
        self.create_ml_page()
        layout.addWidget(self.ml_page)

        # Output file path. The widget's own browse handler is replaced so
        # the save dialog can start in the default output directory.
        self.output_path = FileSelectWidget(
            "输出文件:", self._CSV_FILTER, mode="save"
        )
        self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
        self.output_path.browse_btn.clicked.disconnect()
        self.output_path.browse_btn.clicked.connect(self.browse_output_path)
        layout.addWidget(self.output_path)

        # Enable/disable this step.
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(False)
        layout.addWidget(self.enable_checkbox)

        # Button that runs only this step.
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    def create_ml_page(self):
        """Create the machine-learning options page inside ``self.ml_page``."""
        layout = QVBoxLayout()
        layout.addWidget(self._build_params_group())
        layout.addWidget(self._build_preproc_group())
        layout.addWidget(self._build_model_group())
        layout.addWidget(self._build_split_group())
        self.ml_page.setLayout(layout)

    def _build_params_group(self):
        """Return the group box holding the scalar training parameters."""
        group = QGroupBox("训练参数")
        form = QFormLayout()

        self.feature_start = QLineEdit()
        self.feature_start.setText("374.285004")
        form.addRow("特征起始列:", self.feature_start)

        self.cv_folds = QSpinBox()
        self.cv_folds.setRange(2, 10)
        self.cv_folds.setValue(3)
        form.addRow("交叉验证折数:", self.cv_folds)

        group.setLayout(form)
        return group

    def _build_preproc_group(self):
        """Return the multi-select group box for preprocessing methods."""
        group = QGroupBox("预处理方法 (可多选)")
        grid = QGridLayout()
        self.preproc_checkboxes = {}
        preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG',
                           'MSC', 'D1', 'D2', 'DT', 'CT']
        for index, method in enumerate(preproc_methods):
            checkbox = self._make_checkbox(method, PREPROC_CHINESE)
            self.preproc_checkboxes[method] = checkbox
            grid.addWidget(checkbox,
                           index // self._GRID_COLUMNS,
                           index % self._GRID_COLUMNS)

        box = QVBoxLayout()
        box.addLayout(grid)
        box.addLayout(self._build_select_buttons(self.preproc_checkboxes))
        group.setLayout(box)
        return group

    def _build_model_group(self):
        """Return the multi-select group box for model types."""
        group = QGroupBox("模型类型 (可多选)")
        grid = QGridLayout()
        self.model_checkboxes = {}
        model_groups = [
            ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
            ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
            ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
            ("【其他模型】", ['SVR', 'KNN', 'MLP']),
        ]
        columns = self._GRID_COLUMNS
        row = 0
        for group_name, models in model_groups:
            # Header label spanning the full grid width.
            group_label = QLabel(group_name)
            group_label.setStyleSheet(
                f"background-color: {ModernStylesheet.COLORS['hover']}; "
                f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
                f"border-radius: 3px;"
            )
            grid.addWidget(group_label, row, 0, 1, columns)
            row += 1

            for index, model in enumerate(models):
                checkbox = self._make_checkbox(model, MODEL_CHINESE)
                self.model_checkboxes[model] = checkbox
                grid.addWidget(checkbox, row + index // columns, index % columns)
            # Advance by the number of rows actually used (ceiling division),
            # so a group with an exact multiple of `columns` entries does not
            # leave an empty row behind.
            row += (len(models) + columns - 1) // columns

        box = QVBoxLayout()
        box.addLayout(grid)
        box.addLayout(self._build_select_buttons(self.model_checkboxes))
        group.setLayout(box)
        return group

    def _build_split_group(self):
        """Return the multi-select group box for data split methods."""
        group = QGroupBox("数据划分方法 (可多选)")
        grid = QGridLayout()
        self.split_checkboxes = {}
        split_methods = ['spxy', 'ks', 'random']
        for column, method in enumerate(split_methods):
            checkbox = self._make_checkbox(method, SPLIT_CHINESE)
            self.split_checkboxes[method] = checkbox
            grid.addWidget(checkbox, 0, column)

        box = QVBoxLayout()
        box.addLayout(grid)
        box.addLayout(self._build_select_buttons(self.split_checkboxes))
        group.setLayout(box)
        return group

    @staticmethod
    def _make_checkbox(key, labels):
        """Return an unchecked checkbox labelled ``labels[key]`` (or ``key``)."""
        checkbox = QCheckBox(labels.get(key, key))
        checkbox.setChecked(False)
        return checkbox

    def _build_select_buttons(self, checkboxes):
        """Return a "select all / deselect all" button row for ``checkboxes``."""
        row = QHBoxLayout()
        select_all_btn = QPushButton("全选")
        deselect_all_btn = QPushButton("全不选")
        # Zero-argument lambdas, as in the original wiring: the slot ignores
        # the bool that `clicked` carries.
        select_all_btn.clicked.connect(
            lambda: self._toggle_checkboxes(checkboxes, True))
        deselect_all_btn.clicked.connect(
            lambda: self._toggle_checkboxes(checkboxes, False))
        row.addWidget(select_all_btn)
        row.addWidget(deselect_all_btn)
        row.addStretch()
        return row

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _toggle_checkboxes(self, checkboxes_dict, checked):
        """Set every checkbox in ``checkboxes_dict`` to ``checked``."""
        for checkbox in checkboxes_dict.values():
            checkbox.setChecked(checked)

    @staticmethod
    def _checked_keys(checkboxes):
        """Return the keys of ``checkboxes`` whose checkbox is ticked."""
        return [key for key, checkbox in checkboxes.items() if checkbox.isChecked()]

    @staticmethod
    def _apply_selection(checkboxes, selected):
        """Tick exactly those checkboxes whose key is in ``selected``."""
        selected = selected or ()
        for key, checkbox in checkboxes.items():
            checkbox.setChecked(key in selected)

    def _get_default_work_dir(self):
        """Return the work dir as a string, or ``""`` if none is known.

        The panel's own cached ``work_dir`` wins; otherwise the ``work_dir``
        attribute of the top-level window is used when it is set.
        """
        for owner in (self, self.window()):
            candidate = getattr(owner, 'work_dir', None)
            if candidate:
                return str(candidate)
        return ""

    def _resolve_against_work_dir(self, path):
        """Return ``path`` unchanged if absolute, else joined onto the work dir.

        Joined paths use forward slashes. Without a work dir the path stays
        relative.
        """
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    def _find_step5_output(self):
        """Return the Step 5 output CSV path, or ``""`` when unavailable.

        The Step 5 panel's ``output_file`` widget is consulted first. If it
        is missing or empty, the panel's ``get_config()`` dict is searched
        under several candidate key names.
        """
        step5_panel = getattr(self.window(), 'step5_panel', None)
        if step5_panel is None:
            return ""

        output_widget = getattr(step5_panel, 'output_file', None)
        if output_widget is not None:
            widget_path = output_widget.get_path()
            if widget_path:
                return widget_path

        if hasattr(step5_panel, 'get_config'):
            step5_cfg = step5_panel.get_config() or {}
            return (
                step5_cfg.get('training_csv_path')
                or step5_cfg.get('output_file')
                or step5_cfg.get('csv_path')
                or step5_cfg.get('output_csv')
                or ""
            )
        return ""

    # ------------------------------------------------------------------
    # Slots and public API
    # ------------------------------------------------------------------
    def browse_output_path(self):
        """Open a save dialog and store the chosen output file path.

        The dialog starts at the current output path when its directory
        exists; otherwise at ``<work_dir>/6_water_quality_indices`` (created
        if missing) when a work dir is known.
        """
        current = (self.output_path.get_path() or "").strip()
        initial_dir, initial_file = os.path.split(current)

        if not initial_dir or not os.path.isdir(initial_dir):
            work_dir = self._get_default_work_dir()
            initial_dir = os.path.join(work_dir, self._OUTPUT_SUBDIR) if work_dir else ""
            if initial_dir and not os.path.isdir(initial_dir):
                try:
                    os.makedirs(initial_dir, exist_ok=True)
                except OSError:
                    # Could not create the default directory; let the dialog
                    # pick its own starting location instead of failing.
                    initial_dir = ""

        start_path = os.path.join(initial_dir, initial_file) if initial_file else initial_dir
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存输出文件", start_path, self._CSV_FILTER
        )
        if file_path:
            self.output_path.set_path(file_path)

    def get_config(self):
        """Return the panel state as a config dict.

        Empty selections fall back to ``['None']`` (preprocessing),
        ``['SVR']`` (models) and ``['random']`` (split methods).
        ``training_csv_path`` and ``output_path`` are included only when
        non-empty.
        """
        config = {
            'feature_start_column': self.feature_start.text(),
            'preprocessing_methods': self._checked_keys(self.preproc_checkboxes) or ['None'],
            'model_names': self._checked_keys(self.model_checkboxes) or ['SVR'],
            'split_methods': self._checked_keys(self.split_checkboxes) or ['random'],
            'cv_folds': self.cv_folds.value(),
        }

        training_csv_path = self.training_csv_file.get_path()
        if training_csv_path:
            config['training_csv_path'] = training_csv_path

        output_path = self.output_path.get_path()
        if output_path:
            config['output_path'] = output_path

        return config

    def set_config(self, config):
        """Apply ``config`` to the widgets; keys that are absent are skipped."""
        if 'feature_start_column' in config:
            self.feature_start.setText(str(config['feature_start_column']))
        if 'cv_folds' in config:
            # QSpinBox.setValue only accepts int.
            self.cv_folds.setValue(int(config['cv_folds']))
        if 'preprocessing_methods' in config:
            self._apply_selection(self.preproc_checkboxes, config['preprocessing_methods'])
        if 'model_names' in config:
            self._apply_selection(self.model_checkboxes, config['model_names'])
        if 'split_methods' in config:
            self._apply_selection(self.split_checkboxes, config['split_methods'])
        if 'training_csv_path' in config:
            self.training_csv_file.set_path(config['training_csv_path'])
        if 'output_path' in config:
            self.output_path.set_path(config['output_path'])

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the training data and output paths.

        Args:
            work_dir: Work directory path. When falsy, the previously cached
                value (if any) is kept.
            pipeline: Pipeline instance (unused; kept for interface
                compatibility).
        """
        if work_dir:
            self.work_dir = work_dir
        elif not getattr(self, 'work_dir', None):
            self.work_dir = None

        # 1. Training data: take the Step 5 output, made absolute against the
        #    work dir when it is relative.
        step5_csv = self._find_step5_output()
        if step5_csv:
            self.training_csv_file.set_path(self._resolve_against_work_dir(step5_csv))

        # 2. Output file: <work_dir>/6_water_quality_indices/<input stem>_indices.csv
        #    e.g. training_spectra.csv -> training_spectra_indices.csv
        if not self.work_dir:
            self.output_path.set_path("")
            return

        indices_dir = os.path.join(self.work_dir, self._OUTPUT_SUBDIR)
        os.makedirs(indices_dir, exist_ok=True)

        training_csv = self.training_csv_file.get_path()
        if training_csv:
            stem = os.path.splitext(os.path.basename(training_csv))[0]
            output_file = f"{stem}_indices.csv"
        else:
            output_file = self._DEFAULT_OUTPUT_NAME

        output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
        self.output_path.set_path(output_path)

    def run_step(self):
        """Run step 8 on its own through the main window.

        Shows a warning and returns when no training CSV is selected. Does
        nothing when the top-level window has no ``run_single_step``.
        """
        training_csv_path = self.training_csv_file.get_path()
        if not training_csv_path:
            QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
            return

        main_window = self.window()
        if hasattr(main_window, 'run_single_step'):
            config = {'step8_ml_train': self.get_config()}
            main_window.run_single_step('step8_ml_train', config)

    def get_training_params(self):
        """Return the model training parameters.

        Unlike :meth:`get_config`, empty selections are returned as empty
        lists and the feature start value is converted to ``float``.

        Raises:
            ValueError: If the feature start field is not a valid number.
        """
        return {
            'pipeline_type': 'machine_learning',
            'feature_start': float(self.feature_start.text()),
            'cv_folds': self.cv_folds.value(),
            'preprocess_methods': self._checked_keys(self.preproc_checkboxes),
            'model_types': self._checked_keys(self.model_checkboxes),
            'split_methods': self._checked_keys(self.split_checkboxes),
        }