362 lines
13 KiB
Python
362 lines
13 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
Step6_75 面板 - 自定义回归分析
|
||
"""
|
||
|
||
import os
|
||
from typing import Dict
|
||
|
||
import pandas as pd
|
||
from PyQt5.QtWidgets import (
|
||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
|
||
QScrollArea, QMessageBox,
|
||
)
|
||
|
||
from src.gui.components.custom_widgets import FileSelectWidget
|
||
from src.gui.styles import ModernStylesheet
|
||
|
||
|
||
class Step6_75Panel(QWidget):
    """Step 6.75 panel: custom regression analysis.

    Lets the user pick an input CSV, choose independent (X) and dependent (Y)
    columns, and select one or more regression methods. The selection can be
    exported with ``get_config``, restored with ``set_config``, or executed on
    its own through ``run_step``.
    """

    # Regression methods offered in the UI, in display order.
    _REGRESSION_METHODS = (
        'linear', 'exponential', 'power', 'logarithmic',
        'polynomial', 'hyperbolic', 'sigmoidal',
    )
    # Methods that start out checked.
    _DEFAULT_METHODS = frozenset({'linear', 'exponential', 'power', 'logarithmic'})

    # Case-insensitive substrings used to pre-check columns after a CSV loads.
    # NOTE(review): these are plain substring tests, so short keys such as
    # 'b', 'nd', 'do' or 'ph' also match unrelated column names - confirm
    # that this breadth is intended.
    _X_COLUMN_KEYWORDS = ('index', 'ratio', 'normalized', 'nd', 'b')
    _Y_COLUMN_KEYWORDS = ('chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity')

    # Sub-directory name used for the default / fallback output location.
    _DEFAULT_OUTPUT_DIRNAME = "9_Custom_Regression_Modeling"

    def __init__(self, parent=None):
        super().__init__(parent)
        # Column name -> checkbox for the independent (X) variables.
        self.x_column_checkboxes: Dict[str, QCheckBox] = {}
        # Column name -> checkbox for the dependent (Y) variables.
        self.y_column_checkboxes: Dict[str, QCheckBox] = {}
        # Regression method name -> checkbox.
        self.method_checkboxes: Dict[str, QCheckBox] = {}
        # Header of the currently loaded CSV; empty when nothing could be read.
        self.csv_columns = []
        # Working directory last supplied to update_from_config (None until then).
        self.work_dir = None
        self.init_ui()

    # ------------------------------------------------------------------ UI

    def init_ui(self):
        """Create all child widgets and stack them vertically."""
        layout = QVBoxLayout()

        hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
        hint.setStyleSheet("color: #666; font-size: 11px;")
        layout.addWidget(hint)

        layout.addWidget(self._build_csv_group())

        # Independent-variable (X) columns; filled by update_column_widgets.
        (x_group, self.x_columns_layout,
         self.x_select_all, self.x_deselect_all) = self._build_column_group(
            "自变量列选择 (可多选)", self.x_column_checkboxes, 250, 350)
        layout.addWidget(x_group)

        # Dependent-variable (Y) columns; filled by update_column_widgets.
        (y_group, self.y_columns_layout,
         self.y_select_all, self.y_deselect_all) = self._build_column_group(
            "因变量列选择 (可多选)", self.y_column_checkboxes, 200, 300)
        layout.addWidget(y_group)

        layout.addWidget(self._build_method_group())
        layout.addWidget(self._build_output_group())

        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        self.run_button = QPushButton("独立运行此步骤")
        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_button.clicked.connect(self.run_step)
        layout.addWidget(self.run_button)

        layout.addStretch()
        self.setLayout(layout)

    def _build_csv_group(self):
        """Build the input-file group (path selector plus a refresh button)."""
        group = QGroupBox("数据文件")
        group_layout = QVBoxLayout()

        self.csv_file = FileSelectWidget(
            "输入CSV文件:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        # Re-read the CSV header whenever the path text changes.
        self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
        group_layout.addWidget(self.csv_file)

        self.refresh_btn = QPushButton("刷新列信息")
        self.refresh_btn.clicked.connect(self.refresh_csv_columns)
        group_layout.addWidget(self.refresh_btn)

        group.setLayout(group_layout)
        return group

    def _build_column_group(self, title, checkboxes, min_height, max_height):
        """Build a scrollable, initially empty group for column checkboxes.

        Args:
            title: Group box title.
            checkboxes: Dict the select-all / deselect-all buttons act on.
            min_height: Minimum height of the scroll area in pixels.
            max_height: Maximum height of the scroll area in pixels.

        Returns:
            Tuple ``(group, grid, select_all_btn, deselect_all_btn)``; ``grid``
            is the QGridLayout that update_column_widgets later fills.
        """
        group = QGroupBox(title)
        group_layout = QVBoxLayout()

        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        scroll.setMinimumHeight(min_height)
        scroll.setMaximumHeight(max_height)

        container = QWidget()
        grid = QGridLayout()
        container.setLayout(grid)
        scroll.setWidget(container)
        group_layout.addWidget(scroll)

        row, select_all, deselect_all = self._make_toggle_row(checkboxes)
        group_layout.addLayout(row)

        group.setLayout(group_layout)
        return group, grid, select_all, deselect_all

    def _build_method_group(self):
        """Build the regression-method checkbox grid (3 per row)."""
        group = QGroupBox("回归方法选择 (可多选)")
        group_layout = QVBoxLayout()

        grid = QGridLayout()
        for i, method in enumerate(self._REGRESSION_METHODS):
            checkbox = QCheckBox(method)
            if method in self._DEFAULT_METHODS:
                checkbox.setChecked(True)
            self.method_checkboxes[method] = checkbox
            grid.addWidget(checkbox, i // 3, i % 3)
        group_layout.addLayout(grid)

        row, self.method_select_all, self.method_deselect_all = \
            self._make_toggle_row(self.method_checkboxes)
        group_layout.addLayout(row)

        group.setLayout(group_layout)
        return group

    def _build_output_group(self):
        """Build the output-settings group holding the output directory field."""
        group = QGroupBox("输出设置")
        form = QFormLayout()

        self.output_dir = QLineEdit()
        # Intentionally empty: update_from_config fills it from work_dir.
        self.output_dir.setText("")
        form.addRow("输出目录名:", self.output_dir)

        group.setLayout(form)
        return group

    def _make_toggle_row(self, checkboxes):
        """Return ``(row_layout, select_all_btn, deselect_all_btn)`` for *checkboxes*."""
        select_all = QPushButton("全选")
        deselect_all = QPushButton("全不选")
        # The lambdas capture the dict object itself. It is only cleared and
        # refilled in place (never rebound), so the buttons stay in sync.
        select_all.clicked.connect(lambda: self.toggle_checkboxes(checkboxes, True))
        deselect_all.clicked.connect(lambda: self.toggle_checkboxes(checkboxes, False))

        row = QHBoxLayout()
        row.addWidget(select_all)
        row.addWidget(deselect_all)
        row.addStretch()
        return row, select_all, deselect_all

    # ----------------------------------------------------------- checkboxes

    def toggle_checkboxes(self, checkboxes_dict, checked):
        """Set every checkbox in *checkboxes_dict* to *checked*."""
        for checkbox in checkboxes_dict.values():
            checkbox.setChecked(checked)

    @staticmethod
    def _checked_keys(checkboxes):
        """Return the keys of *checkboxes* whose box is checked, in insertion order."""
        return [key for key, box in checkboxes.items() if box.isChecked()]

    @staticmethod
    def _as_name_set(value):
        """Return *value* as a set when it is a list/tuple/set, else an empty set."""
        if isinstance(value, (list, tuple, set, frozenset)):
            return set(value)
        return set()

    @staticmethod
    def _apply_selection(checkboxes, selected):
        """Check exactly the boxes whose key is in *selected*; uncheck the rest."""
        for name, box in checkboxes.items():
            box.setChecked(name in selected)

    @staticmethod
    def _clear_checkboxes(checkboxes):
        """Detach every checkbox from its parent/grid and empty the dict."""
        for box in checkboxes.values():
            box.setParent(None)
        checkboxes.clear()

    # -------------------------------------------------------------- columns

    def on_csv_file_changed(self):
        """Slot for the path field's textChanged signal: re-read the CSV header."""
        self.refresh_csv_columns()

    def refresh_csv_columns(self):
        """Reload the column list from the selected CSV and rebuild the checkboxes.

        Only the header row is read (``nrows=0``). A missing path, a path that is
        not a regular file, or any read error results in an empty column list;
        read errors are printed rather than raised.
        """
        csv_path = self.csv_file.get_path()
        columns = []
        if csv_path and os.path.isfile(csv_path):
            try:
                columns = list(pd.read_csv(csv_path, nrows=0).columns)
            except Exception as e:  # GUI boundary: pandas may raise many types
                print(f"读取CSV列信息失败: {e}")
        self.csv_columns = columns
        self.update_column_widgets()

    def update_column_widgets(self):
        """Rebuild the X/Y column checkboxes from ``self.csv_columns``.

        Existing checkboxes are discarded, so any previous selection is reset to
        the keyword-based defaults. X columns are laid out 3 per row, Y columns
        2 per row.
        """
        self._clear_checkboxes(self.x_column_checkboxes)
        self._clear_checkboxes(self.y_column_checkboxes)

        if not self.csv_columns:
            return

        self._populate_columns(self.x_column_checkboxes, self.x_columns_layout,
                               self._X_COLUMN_KEYWORDS, 3)
        self._populate_columns(self.y_column_checkboxes, self.y_columns_layout,
                               self._Y_COLUMN_KEYWORDS, 2)

        self.x_columns_layout.update()
        self.y_columns_layout.update()

    def _populate_columns(self, checkboxes, grid, keywords, per_row):
        """Add one checkbox per CSV column to *grid*, pre-checking keyword matches."""
        for i, col in enumerate(self.csv_columns):
            checkbox = QCheckBox(col)
            lowered = col.lower()
            if any(keyword in lowered for keyword in keywords):
                checkbox.setChecked(True)
            checkboxes[col] = checkbox
            grid.addWidget(checkbox, i // per_row, i % per_row)

    # --------------------------------------------------------------- config

    def get_config(self):
        """Return this step's settings as a plain dict.

        Keys: 'csv_path' (str or None), 'x_columns' / 'y_columns' (lists of
        checked column names), 'methods' (list of checked method names, or the
        string 'all' when none is checked), 'output_dir' (stripped str or None)
        and 'enabled' (bool).
        """
        return {
            'csv_path': self.csv_file.get_path() or None,
            'x_columns': self._checked_keys(self.x_column_checkboxes),
            'y_columns': self._checked_keys(self.y_column_checkboxes),
            'methods': self._checked_keys(self.method_checkboxes) or 'all',
            'output_dir': self.output_dir.text().strip() or None,
            'enabled': self.enable_checkbox.isChecked(),
        }

    def set_config(self, config):
        """Restore the widgets from a dict shaped like get_config()'s output.

        Only keys present in *config* are applied. Column/method selections
        accept a list (or tuple/set); 'methods' may also be the string 'all'.
        Any other selection value unchecks everything in that group. A falsy
        'output_dir' falls back to the default directory name.
        """
        if 'csv_path' in config:
            # get_config() stores None for an empty path; hand set_path "" instead.
            self.csv_file.set_path(config['csv_path'] or "")
            # Refresh explicitly rather than relying on textChanged, which may
            # not fire (e.g. when the path is unchanged).
            self.refresh_csv_columns()

        if 'x_columns' in config:
            self._apply_selection(self.x_column_checkboxes,
                                  self._as_name_set(config['x_columns']))

        if 'y_columns' in config:
            self._apply_selection(self.y_column_checkboxes,
                                  self._as_name_set(config['y_columns']))

        if 'methods' in config:
            methods = config['methods']
            if isinstance(methods, (list, tuple, set, frozenset)):
                selected_methods = set(methods)
            elif methods == 'all':
                selected_methods = set(self.method_checkboxes)
            else:
                selected_methods = set()
            self._apply_selection(self.method_checkboxes, selected_methods)

        if 'output_dir' in config:
            self.output_dir.setText(config['output_dir'] or self._DEFAULT_OUTPUT_DIRNAME)
        if 'enabled' in config:
            self.enable_checkbox.setChecked(config['enabled'])

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the training CSV path and output directory from global settings.

        Both fields are only filled when currently empty, so user input wins.

        Args:
            work_dir: Working directory. When falsy, a previously stored value
                is kept.
            pipeline: Pipeline instance (unused; kept for interface compatibility).
        """
        if work_dir:
            self.work_dir = work_dir
        elif not getattr(self, 'work_dir', None):
            self.work_dir = None

        # 1. Take the training-spectra CSV path from the Step 5 panel, if present.
        main_window = self.window()
        step5_panel = getattr(main_window, 'step5_panel', None) if main_window else None
        if step5_panel is not None:
            step5_output_path = step5_panel.output_file.get_path()
            if step5_output_path:
                # Resolve relative paths against work_dir (forward slashes).
                if not os.path.isabs(step5_output_path):
                    step5_output_path = os.path.join(
                        self.work_dir or '', step5_output_path).replace('\\', '/')
                existing = self.csv_file.get_path()
                if not existing or not existing.strip():
                    self.csv_file.set_path(step5_output_path)

        # 2. Default the output directory to <work_dir>/9_Custom_Regression_Modeling.
        #    The directory is only created when we actually fill it in, and a
        #    creation failure is reported instead of escaping this GUI callback.
        if self.work_dir and not self.output_dir.text().strip():
            output_dir = os.path.join(self.work_dir, self._DEFAULT_OUTPUT_DIRNAME)
            try:
                os.makedirs(output_dir, exist_ok=True)
            except OSError as e:
                print(f"创建输出目录失败: {e}")
            self.output_dir.setText(output_dir)

    # ------------------------------------------------------------------ run

    def run_step(self):
        """Validate the current selection and run step 6.75 on its own.

        Shows a warning and returns on the first failed check (missing CSV,
        no X column, no Y column, no method). Otherwise walks up the parent
        chain to the first object exposing ``run_single_step`` and calls it with
        ``('step6_75', {'step6_75': self.get_config()})``.
        """
        csv_path = self.csv_file.get_path()
        if not csv_path:
            QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
            return
        # isfile (not exists) so a directory path is rejected as well.
        if not os.path.isfile(csv_path):
            QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
            return

        required_selections = (
            (self.x_column_checkboxes, "请至少选择一个自变量列"),
            (self.y_column_checkboxes, "请至少选择一个因变量列"),
            (self.method_checkboxes, "请至少选择一种回归方法"),
        )
        for checkboxes, message in required_selections:
            if not self._checked_keys(checkboxes):
                QMessageBox.warning(self, "输入验证失败", message)
                return

        config = self.get_config()

        runner = self.parent()
        while runner is not None and not hasattr(runner, 'run_single_step'):
            runner = runner.parent()

        if runner is None:
            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
            return
        runner.run_single_step('step6_75', {'step6_75': config})