#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8 面板 - 机器学习建模
"""
import os
|
||
|
||
from PyQt5.QtWidgets import (
|
||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||
QPushButton, QFileDialog, QMessageBox,
|
||
)
|
||
from PyQt5.QtCore import Qt
|
||
|
||
from src.gui.components.custom_widgets import FileSelectWidget
|
||
from src.gui.styles import ModernStylesheet
|
||
|
||
|
||
# ============================================================
# Chinese display tables (internal key -> UI display text)
# ============================================================

# Preprocessing methods: internal key -> checkbox label.
# Built from an ordered list of pairs so the key order (which drives the
# checkbox order elsewhere) is explicit.
PREPROC_CHINESE = dict([
    ('None', '无 (None)'),
    ('MMS', '最小-最大归一化 (MMS)'),
    ('SS', '标度化 (SS)'),
    ('SNV', '标准正态变换 (SNV)'),
    ('MA', '移动平均 (MA)'),
    ('SG', 'Savitzky-Golay (SG)'),
    ('MSC', '多元散射校正 (MSC)'),
    ('D1', '一阶导数 (D1)'),
    ('D2', '二阶导数 (D2)'),
    ('DT', '去趋势 (DT)'),
    ('CT', '中心化 (CT)'),
])
# Model types: internal key -> checkbox label, listed category by category.
MODEL_CHINESE = dict([
    # Linear models
    ('LinearRegression', '多元线性回归 (MLR)'),
    ('Ridge', '岭回归 (Ridge)'),
    ('Lasso', '套索回归 (Lasso)'),
    ('ElasticNet', '弹性网络 (ElasticNet)'),
    ('PLS', '偏最小二乘 (PLSR)'),
    # Tree models
    ('DecisionTree', '决策树 (CART)'),
    ('RF', '随机森林 (RF)'),
    ('ExtraTrees', '极端随机树 (ET)'),
    ('XGBoost', '极值梯度提升 (XGBoost)'),
    ('LightGBM', '轻量梯度提升 (LightGBM)'),
    ('CatBoost', '类别梯度提升 (CatBoost)'),
    # Ensemble learning
    ('GradientBoosting', '梯度提升树 (GBDT)'),
    ('AdaBoost', '自适应提升 (AdaBoost)'),
    # Other models
    ('SVR', '支持向量回归 (SVR)'),
    ('KNN', 'K近邻回归 (KNN)'),
    ('MLP', '多层感知机 (BP神经网络)'),
])
# Train/test split methods: internal key -> checkbox label.
# Every key is a valid identifier, so keyword arguments keep this compact
# while preserving insertion order.
SPLIT_CHINESE = dict(
    spxy='SPXY 算法 (考量X-Y空间)',
    ks='KS 算法 (考量X空间)',
    random='随机划分 (Random)',
)
class Step8MlTrainPanel(QWidget):
    """Step 8: machine-learning model training panel.

    Collects the training CSV, the feature start column, the number of CV
    folds, and multi-select sets of preprocessing methods, model types and
    data split methods, plus an output path. Configuration is exchanged via
    :meth:`get_config` / :meth:`set_config`; :meth:`run_step` forwards the
    configuration to the main window's ``run_single_step`` when available.
    """

    # Number of checkbox columns in the preprocessing / model grids.
    _GRID_COLUMNS = 4

    def __init__(self, parent=None):
        super().__init__(parent)
        # Working directory cached by update_from_config(); None until known.
        self.work_dir = None
        self.init_ui()

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------
    def init_ui(self):
        """Build the panel's top-level layout."""
        layout = QVBoxLayout()

        # Training data file (used when running this step standalone)
        self.training_csv_file = FileSelectWidget(
            "训练数据:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.training_csv_file)

        # Machine-learning options page
        self.ml_page = QWidget()
        self.create_ml_page()
        layout.addWidget(self.ml_page)

        # Output file path
        self.output_path = FileSelectWidget(
            "输出文件:",
            "CSV Files (*.csv);;All Files (*.*)",
            mode="save"
        )
        self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
        # Swap the widget's own browse handler for one that proposes a default
        # directory. PyQt raises TypeError from disconnect() when the signal
        # has no connections, so that case is tolerated deliberately.
        try:
            self.output_path.browse_btn.clicked.disconnect()
        except TypeError:
            pass
        self.output_path.browse_btn.clicked.connect(self.browse_output_path)
        layout.addWidget(self.output_path)

        # Enable-step toggle
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(False)
        layout.addWidget(self.enable_checkbox)

        # Standalone run button
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    def create_ml_page(self):
        """Populate ``self.ml_page`` with the parameter and multi-select groups."""
        layout = QVBoxLayout()
        layout.addWidget(self._create_params_group())
        layout.addWidget(self._create_preproc_group())
        layout.addWidget(self._create_model_group())
        layout.addWidget(self._create_split_group())
        self.ml_page.setLayout(layout)

    def _create_params_group(self):
        """Build the training-parameter group (feature start column, CV folds)."""
        params_group = QGroupBox("训练参数")
        params_layout = QFormLayout()

        self.feature_start = QLineEdit()
        self.feature_start.setText("374.285004")
        params_layout.addRow("特征起始列:", self.feature_start)

        self.cv_folds = QSpinBox()
        self.cv_folds.setRange(2, 10)
        self.cv_folds.setValue(3)
        params_layout.addRow("交叉验证折数:", self.cv_folds)

        params_group.setLayout(params_layout)
        return params_group

    def _create_preproc_group(self):
        """Build the multi-select preprocessing-method group."""
        preproc_group = QGroupBox("预处理方法 (可多选)")
        preproc_layout = QVBoxLayout()

        preproc_grid = QGridLayout()
        self.preproc_checkboxes = {}
        preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
        cols = self._GRID_COLUMNS
        for i, method in enumerate(preproc_methods):
            checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
            checkbox.setChecked(False)
            self.preproc_checkboxes[method] = checkbox
            preproc_grid.addWidget(checkbox, i // cols, i % cols)

        preproc_layout.addLayout(preproc_grid)
        preproc_layout.addLayout(self._make_toggle_buttons(self.preproc_checkboxes))
        preproc_group.setLayout(preproc_layout)
        return preproc_group

    def _create_model_group(self):
        """Build the multi-select model group, with one labelled section per category."""
        model_group = QGroupBox("模型类型 (可多选)")
        model_layout = QVBoxLayout()

        model_grid = QGridLayout()
        self.model_checkboxes = {}
        model_groups = [
            ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
            ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
            ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
            ("【其他模型】", ['SVR', 'KNN', 'MLP'])
        ]

        cols = self._GRID_COLUMNS
        row = 0
        for group_name, models in model_groups:
            group_label = QLabel(f"<b>{group_name}</b>")
            group_label.setStyleSheet(
                f"background-color: {ModernStylesheet.COLORS['hover']}; "
                f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
                f"border-radius: 3px;"
            )
            # The category header spans the full grid width.
            model_grid.addWidget(group_label, row, 0, 1, cols)
            row += 1

            for i, model in enumerate(models):
                checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
                checkbox.setChecked(False)
                self.model_checkboxes[model] = checkbox
                model_grid.addWidget(checkbox, row + i // cols, i % cols)

            # Skip exactly the rows this category's checkboxes occupied
            # (ceiling division), so a group whose size is a multiple of
            # `cols` does not leave an empty row behind.
            row += -(-len(models) // cols)

        model_layout.addLayout(model_grid)
        model_layout.addLayout(self._make_toggle_buttons(self.model_checkboxes))
        model_group.setLayout(model_layout)
        return model_group

    def _create_split_group(self):
        """Build the multi-select data-split-method group (single row)."""
        split_group = QGroupBox("数据划分方法 (可多选)")
        split_layout = QVBoxLayout()

        split_grid = QGridLayout()
        self.split_checkboxes = {}
        split_methods = ['spxy', 'ks', 'random']
        for i, method in enumerate(split_methods):
            checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
            checkbox.setChecked(False)
            self.split_checkboxes[method] = checkbox
            split_grid.addWidget(checkbox, 0, i)

        split_layout.addLayout(split_grid)
        split_layout.addLayout(self._make_toggle_buttons(self.split_checkboxes))
        split_group.setLayout(split_layout)
        return split_group

    def _make_toggle_buttons(self, checkboxes):
        """Return a row with select-all / select-none buttons bound to *checkboxes*."""
        button_layout = QHBoxLayout()
        select_all_btn = QPushButton("全选")
        deselect_all_btn = QPushButton("全不选")
        select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(checkboxes, True))
        deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(checkboxes, False))
        button_layout.addWidget(select_all_btn)
        button_layout.addWidget(deselect_all_btn)
        button_layout.addStretch()
        return button_layout

    def _toggle_checkboxes(self, checkboxes_dict, checked):
        """Set every checkbox in *checkboxes_dict* to *checked*."""
        for checkbox in checkboxes_dict.values():
            checkbox.setChecked(checked)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _checked_keys(checkboxes):
        """Return the keys of *checkboxes* whose checkbox is checked, in dict order."""
        return [key for key, checkbox in checkboxes.items() if checkbox.isChecked()]

    @staticmethod
    def _apply_selection(checkboxes, selected):
        """Check exactly the checkboxes whose key appears in *selected*.

        A bare string is treated as a one-element selection; without this,
        ``key in "SNV"`` would do substring matching. ``None`` selects nothing.
        """
        if selected is None:
            selected = ()
        elif isinstance(selected, str):
            selected = (selected,)
        wanted = set(selected)
        for key, checkbox in checkboxes.items():
            checkbox.setChecked(key in wanted)

    @staticmethod
    def _path_text(widget):
        """Return the widget's path with surrounding whitespace removed ("" if unset)."""
        return (widget.get_path() or "").strip()

    def _resolve_path(self, path):
        """Join a relative *path* onto ``work_dir`` (forward slashes); return absolute paths unchanged."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    def _get_default_work_dir(self):
        """Return work_dir: the panel's cached value first, else the main window's, else ""."""
        own_dir = getattr(self, 'work_dir', None)
        if own_dir:
            return str(own_dir)
        main_window = self.window()
        main_dir = getattr(main_window, 'work_dir', None) if main_window else None
        if main_dir:
            return str(main_dir)
        return ""

    def _find_step5_output(self):
        """Return the CSV path produced by the Step 5 panel, or "" if none is found.

        Prefers Step 5's ``output_file`` widget. When that widget is missing
        or empty, it falls back to well-known keys of Step 5's ``get_config()``.
        """
        step5 = getattr(self.window(), 'step5_panel', None)
        if step5 is None:
            return ""

        output_widget = getattr(step5, 'output_file', None)
        path = output_widget.get_path() if output_widget is not None else ""
        if path:
            return path

        if hasattr(step5, 'get_config'):
            step5_cfg = step5.get_config() or {}
            return (
                step5_cfg.get('training_csv_path')
                or step5_cfg.get('output_file')
                or step5_cfg.get('csv_path')
                or step5_cfg.get('output_csv')
                or ""
            )
        return ""

    # ------------------------------------------------------------------
    # Slots / public API
    # ------------------------------------------------------------------
    def browse_output_path(self):
        """Open a save dialog for the output file path.

        Starts from the current path's directory when it exists. Otherwise it
        starts from ``<work_dir>/6_water_quality_indices``, which is created
        on a best-effort basis.
        """
        current = self._path_text(self.output_path)
        initial_dir = os.path.dirname(current) if current else ""
        initial_file = os.path.basename(current) if current else ""

        if not initial_dir or not os.path.isdir(initial_dir):
            # Default to the indices directory
            work_dir = self._get_default_work_dir()
            initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
            if initial_dir and not os.path.isdir(initial_dir):
                try:
                    os.makedirs(initial_dir, exist_ok=True)
                except OSError:
                    # An exception escaping a PyQt5 slot aborts the app; if
                    # the directory cannot be created, start from an empty
                    # directory instead.
                    initial_dir = ""

        start_path = os.path.join(initial_dir, initial_file) if initial_file else initial_dir
        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存输出文件", start_path,
            "CSV Files (*.csv);;All Files (*.*)"
        )
        if file_path:
            self.output_path.set_path(file_path)

    def get_config(self):
        """Return the step configuration dict.

        Empty multi-selections fall back to ``['None']``, ``['SVR']`` and
        ``['random']``. ``training_csv_path`` and ``output_path`` are included
        only when non-blank.
        """
        preprocessing_methods = self._checked_keys(self.preproc_checkboxes)
        model_names = self._checked_keys(self.model_checkboxes)
        split_methods = self._checked_keys(self.split_checkboxes)

        config = {
            'feature_start_column': self.feature_start.text(),
            'preprocessing_methods': preprocessing_methods or ['None'],
            'model_names': model_names or ['SVR'],
            'split_methods': split_methods or ['random'],
            'cv_folds': self.cv_folds.value()
        }
        training_csv_path = self._path_text(self.training_csv_file)
        if training_csv_path:
            config['training_csv_path'] = training_csv_path
        output_path = self._path_text(self.output_path)
        if output_path:
            config['output_path'] = output_path
        return config

    def set_config(self, config):
        """Apply a configuration dict; keys that are absent leave their widgets untouched."""
        if 'feature_start_column' in config:
            self.feature_start.setText(str(config['feature_start_column']))
        if 'cv_folds' in config:
            # QSpinBox.setValue() accepts only int; configs read from text
            # formats may carry "5" or 5.0.
            self.cv_folds.setValue(int(config['cv_folds']))
        if 'preprocessing_methods' in config:
            self._apply_selection(self.preproc_checkboxes, config['preprocessing_methods'])
        if 'model_names' in config:
            self._apply_selection(self.model_checkboxes, config['model_names'])
        if 'split_methods' in config:
            self._apply_selection(self.split_checkboxes, config['split_methods'])
        if 'training_csv_path' in config:
            self.training_csv_file.set_path(config['training_csv_path'])
        if 'output_path' in config:
            self.output_path.set_path(config['output_path'])

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the training data and output paths from global state.

        Args:
            work_dir: Working directory path. When falsy, a previously cached
                ``work_dir`` is kept.
            pipeline: Pipeline instance (unused; kept for interface compatibility).
        """
        self.work_dir = work_dir or getattr(self, 'work_dir', None) or None

        # 1. Take the training data path from Step 5, making relative paths
        #    absolute against work_dir.
        step5_csv = self._find_step5_output()
        if step5_csv:
            self.training_csv_file.set_path(self._resolve_path(step5_csv))

        # 2. Derive the output path from work_dir and the input file name:
        #    training_spectra.csv -> {work_dir}/6_water_quality_indices/training_spectra_indices.csv
        #    sampling_spectra.csv -> {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
        if not self.work_dir:
            self.output_path.set_path("")
            return

        indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
        try:
            os.makedirs(indices_dir, exist_ok=True)
        except OSError:
            # Best effort: the suggested path below is still useful even if
            # the directory cannot be created right now.
            pass

        training_csv = self._path_text(self.training_csv_file)
        if training_csv:
            basename = os.path.splitext(os.path.basename(training_csv))[0]
            output_file = f"{basename}_indices.csv"
        else:
            output_file = "water_quality_indices.csv"
        self.output_path.set_path(os.path.join(indices_dir, output_file).replace('\\', '/'))

    def run_step(self):
        """Run step 8 standalone through the main window's ``run_single_step``."""
        if not self._path_text(self.training_csv_file):
            QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
            return

        main_window = self.window()
        if hasattr(main_window, 'run_single_step'):
            config = {'step8_ml_train': self.get_config()}
            main_window.run_single_step('step8_ml_train', config)

    def get_training_params(self):
        """Return the model-training parameters.

        Unlike :meth:`get_config`, empty selections are returned as empty
        lists, and the feature start column is converted to ``float``.

        Raises:
            ValueError: if the feature start column text is not numeric.
        """
        feature_text = self.feature_start.text()
        try:
            feature_start = float(feature_text)
        except ValueError as err:
            raise ValueError(f"特征起始列必须是数值: {feature_text!r}") from err

        return {
            'pipeline_type': 'machine_learning',
            'feature_start': feature_start,
            'cv_folds': self.cv_folds.value(),
            'preprocess_methods': self._checked_keys(self.preproc_checkboxes),
            'model_types': self._checked_keys(self.model_checkboxes),
            'split_methods': self._checked_keys(self.split_checkboxes)
        }