Files
WQ_GUI/src/gui/panels/step8_panel.py

463 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8 面板 - 机器学习预测
"""
import os
from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
QListWidgetItem,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8Panel(QWidget):
    """Step 8 panel: machine-learning prediction.

    Lets the user choose where the models come from (the current training
    run, or pre-trained ``.joblib`` files scanned from a folder), pick the
    input/output paths and prediction parameters, and launch the step on its
    own through the host window's ``run_single_step``.
    """
def __init__(self, parent=None):
    """Create the panel state and build its widgets.

    Args:
        parent: Optional parent widget, forwarded to ``QWidget``.
    """
    super().__init__(parent)
    # Externally loaded models, keyed by sub-directory name:
    # {subdir_name: model_obj, ...}
    self.external_models_dict = {}
    # Path of the scanned parent folder. The line edit shows a summary of
    # the recognised models instead, so the raw path is only kept here.
    self.external_model_dir = ""
    # Working directory cached by update_from_config(). Declared up front
    # so the attribute always exists rather than appearing on first use.
    self.work_dir = None
    self.init_ui()
def init_ui(self):
    """Build every widget of the panel and stack them top to bottom.

    Each visual section is created by its own private helper. The helpers
    are called in display order, so reordering the calls below reorders the
    panel.
    """
    layout = QVBoxLayout()
    self._add_model_source_group(layout)
    self._add_external_model_widgets(layout)
    self._add_standalone_inputs(layout)
    self._add_prediction_params(layout)
    self._add_output_and_run_controls(layout)
    layout.addStretch()
    self.setLayout(layout)


def _add_model_source_group(self, layout):
    """Add the radio-button group that selects the model source."""
    source_group = QGroupBox("模型来源")
    source_layout = QVBoxLayout()
    self.use_trained_model = QRadioButton("使用当前训练流程的模型")
    self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
    # Check the default BEFORE connecting the signals: the slot touches
    # widgets that are only created by a later helper.
    self.use_trained_model.setChecked(True)
    source_layout.addWidget(self.use_trained_model)
    source_layout.addWidget(self.use_external_model)
    self.use_trained_model.toggled.connect(self._on_model_source_changed)
    self.use_external_model.toggled.connect(self._on_model_source_changed)
    source_group.setStyleSheet("""
        QRadioButton {
            font-size: 13px;
            spacing: 8px;
        }
        QRadioButton::indicator {
            width: 16px;
            height: 16px;
            border-radius: 9px;
            border: 2px solid #A0A0A0;
            background-color: #FFFFFF;
        }
        QRadioButton::indicator:hover {
            border: 2px solid #0078D7;
        }
        QRadioButton::indicator:checked {
            background-color: #0078D7;
            border: 2px solid #0078D7;
        }
    """)
    source_group.setLayout(source_layout)
    layout.addWidget(source_group)


def _add_external_model_widgets(self, layout):
    """Add the external-model folder picker and the checkable model list.

    Both widgets start hidden; ``_on_model_source_changed`` shows them when
    the external-model radio button is selected.
    """
    # Folder picker for the parent directory that holds the model folders.
    self.external_model_widget = FileSelectWidget(
        "模型母文件夹:",
        "Directories"
    )
    # Swap the widget's own browse handler for the folder scanner.
    self.external_model_widget.browse_btn.clicked.disconnect()
    self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
    self.external_model_widget.setVisible(False)
    layout.addWidget(self.external_model_widget)

    # List of scanned models.
    self.model_list_group = QGroupBox("选择参与预测的模型")
    self.model_list_group.setVisible(False)
    model_list_layout = QVBoxLayout()
    self.model_list = QListWidget()
    self.model_list.setMaximumHeight(130)
    # Rows are toggled through their check boxes, so row selection is off.
    self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
    self.model_list.setStyleSheet("""
        QListWidget {
            border: 1px solid #C0C0C0;
            border-radius: 4px;
            background-color: #FFFFFF;
            font-size: 12px;
        }
        QListWidget::item {
            padding: 4px 6px;
            border-bottom: 1px solid #F0F0F0;
        }
        QListWidget::item:selected {
            background-color: transparent;
        }
    """)
    model_list_layout.addWidget(self.model_list)

    # "Select all" / "select none" shortcuts under the list.
    btn_row = QHBoxLayout()
    self.btn_select_all = QPushButton("全选")
    self.btn_select_all.setMaximumWidth(80)
    self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
    self.btn_select_all.clicked.connect(self._select_all_models)
    self.btn_select_none = QPushButton("全不选")
    self.btn_select_none.setMaximumWidth(80)
    self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
    self.btn_select_none.clicked.connect(self._select_none_models)
    btn_row.addWidget(self.btn_select_all)
    btn_row.addWidget(self.btn_select_none)
    btn_row.addStretch()
    model_list_layout.addLayout(btn_row)
    self.model_list_group.setLayout(model_list_layout)
    layout.addWidget(self.model_list_group)


def _add_standalone_inputs(self, layout):
    """Add the input pickers used when the step is run on its own."""
    # Sampling-spectra CSV file.
    self.sampling_csv_file = FileSelectWidget(
        "采样光谱CSV:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    layout.addWidget(self.sampling_csv_file)

    # Models directory.
    self.models_dir_file = FileSelectWidget(
        "模型目录:",
        "Directories;;All Files (*.*)"
    )
    self.models_dir_file.label.setText("模型目录:")
    # Swap the widget's own browse handler for a directory chooser.
    self.models_dir_file.browse_btn.clicked.disconnect()
    self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
    layout.addWidget(self.models_dir_file)


def _add_prediction_params(self, layout):
    """Add the group holding the model-selection metric and column name."""
    params_group = QGroupBox("预测参数")
    params_layout = QFormLayout()
    self.metric = QComboBox()
    self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
    params_layout.addRow("模型选择指标:", self.metric)
    self.prediction_column = QLineEdit()
    self.prediction_column.setText("prediction")
    params_layout.addRow("预测列名:", self.prediction_column)
    params_group.setLayout(params_layout)
    layout.addWidget(params_group)


def _add_output_and_run_controls(self, layout):
    """Add the output picker, the enable check box and the run button."""
    # Output path.
    self.output_file = FileSelectWidget(
        "输出路径:",
        "CSV Files (*.csv);;All Files (*.*)"
    )
    layout.addWidget(self.output_file)

    # Enable/disable this step (checked by default).
    self.enable_checkbox = QCheckBox("启用此步骤")
    self.enable_checkbox.setChecked(True)
    layout.addWidget(self.enable_checkbox)

    # Button that runs only this step.
    self.run_btn = QPushButton("独立运行此步骤")
    self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
    self.run_btn.clicked.connect(self.run_step)
    layout.addWidget(self.run_btn)
def _on_model_source_changed(self, checked: bool):
    """Show or hide the external-model widgets when the source changes.

    Connected to the ``toggled`` signal of both radio buttons, so one user
    click delivers two calls; only the call for the button that became
    checked is acted on.

    Args:
        checked: New checked state of the radio button that emitted.
    """
    if not checked:
        return
    is_external = self.use_external_model.isChecked()
    self.external_model_widget.setVisible(is_external)
    self.model_list_group.setVisible(is_external)
    if not is_external:
        # Leaving external mode discards everything loaded from disk.
        self.external_models_dict = {}
        self.external_model_dir = ""
        self._clear_model_list()
        # Also clear the "N models recognised" summary in the line edit;
        # otherwise it would still claim models are loaded when the user
        # switches back to external mode.
        self.external_model_widget.set_path("")
def _scan_external_model_dir(self):
    """Ask for a parent folder and load one model from each sub-directory.

    On success the loaded models, the chosen folder, the check list and the
    summary text are all updated together. If the scan itself fails, the
    previous state is left untouched; if no model could be loaded, the
    state is cleared. The result is reported through a message box.
    """
    default = self._get_default_work_dir()
    if default:
        default = os.path.join(default, "7_Supervised_Model_Training")
    dir_path = QFileDialog.getExistingDirectory(
        self,
        "选择模型母文件夹",
        default,
    )
    if not dir_path:
        return

    try:
        models_found, errors = self._load_models_from_subdirs(dir_path)
    except Exception as e:
        # Covers a missing joblib package and an unreadable parent folder.
        QMessageBox.warning(
            self,
            "扫描失败",
            f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
        )
        return

    err_lines = "\n".join(errors) if errors else ""

    if not models_found:
        if errors:
            # Files were found but none could be used: say so and show why,
            # instead of claiming that no .joblib file exists.
            message = (
                f"在「{dir_path}」的子目录中找到了 .joblib 文件,但全部加载失败:\n"
                f"{err_lines}"
            )
        else:
            message = (
                f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
                "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。"
            )
        QMessageBox.warning(self, "未找到模型", message)
        self.external_model_widget.set_path("")
        self.external_models_dict = {}
        self.external_model_dir = ""
        self._clear_model_list()
        return

    # Commit the folder only now, so it always matches the loaded models.
    self.external_model_dir = dir_path
    self.external_models_dict = models_found
    self._populate_model_list(models_found)
    names = sorted(models_found)
    display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
    self.external_model_widget.set_path(display)
    self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
    QMessageBox.information(
        self,
        "模型扫描完成",
        f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
        f"加载失败 {len(errors)} 个:\n{err_lines}",
    )


def _load_models_from_subdirs(self, dir_path):
    """Load one ``.joblib`` model from each immediate sub-directory.

    Sub-directories and files are visited in name order, so "the first
    .joblib file" is the same file on every run and every platform.

    Args:
        dir_path: Parent folder whose sub-directories are scanned.

    Returns:
        A ``(models, errors)`` tuple: ``models`` maps sub-directory name to
        the loaded model object, ``errors`` is a list of human-readable
        messages for sub-directories that could not be used.

    Raises:
        ImportError: If ``joblib`` is not installed.
        OSError: If ``dir_path`` itself cannot be listed.
    """
    # SECURITY NOTE: joblib.load() unpickles the file and can therefore run
    # arbitrary code. Only folders from a trusted origin should be scanned.
    import joblib

    models = {}
    errors = []

    with os.scandir(dir_path) as entries:
        subdirs = sorted(
            (entry for entry in entries if entry.is_dir()),
            key=lambda entry: entry.name,
        )

    for subentry in subdirs:
        subdir_name = subentry.name
        try:
            with os.scandir(subentry.path) as files:
                joblib_paths = sorted(
                    f.path for f in files
                    if f.is_file() and f.name.lower().endswith(".joblib")
                )
        except OSError as e:
            # One unreadable sub-directory must not abort the whole scan.
            errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
            continue
        if not joblib_paths:
            continue

        # Only the first .joblib file of each sub-directory is used.
        try:
            loaded = joblib.load(joblib_paths[0])
        except Exception as e:
            errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
            continue

        # Accept either a {"model": ...} bundle or a bare estimator.
        if isinstance(loaded, dict) and "model" in loaded:
            models[subdir_name] = loaded["model"]
        elif hasattr(loaded, "predict"):
            models[subdir_name] = loaded
        else:
            errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")

    return models, errors
def _populate_model_list(self, models_dict):
    """Rebuild the list widget from the scanned models.

    One row is added per model name, in sorted order; every row is
    user-checkable and starts checked.
    """
    list_widget = self.model_list
    list_widget.clear()
    for model_name in sorted(models_dict):
        row = QListWidgetItem(model_name)
        checkable_flags = row.flags() | Qt.ItemIsUserCheckable
        row.setFlags(checkable_flags)
        row.setCheckState(Qt.Checked)
        list_widget.addItem(row)
def _clear_model_list(self):
    """Remove every row from the model list widget."""
    list_widget = self.model_list
    list_widget.clear()
def _select_all_models(self):
    """Tick the check box of every row in the model list."""
    list_widget = self.model_list
    rows = (list_widget.item(index) for index in range(list_widget.count()))
    for row in rows:
        row.setCheckState(Qt.Checked)
def _select_none_models(self):
    """Untick the check box of every row in the model list."""
    list_widget = self.model_list
    rows = (list_widget.item(index) for index in range(list_widget.count()))
    for row in rows:
        row.setCheckState(Qt.Unchecked)
def _get_checked_models_dict(self):
    """Return the loaded models whose rows are currently ticked.

    Returns:
        A dict mapping model name to model object. Rows whose name is not
        present in ``external_models_dict`` are left out.
    """
    list_widget = self.model_list
    loaded = self.external_models_dict
    rows = (list_widget.item(index) for index in range(list_widget.count()))
    ticked_names = [row.text() for row in rows if row.checkState() == Qt.Checked]
    return {name: loaded[name] for name in ticked_names if name in loaded}
def update_from_config(self, work_dir=None, pipeline=None):
    """Auto-fill the still-empty path fields of this panel.

    Fields that already contain text are never overwritten. Any failure is
    logged to stdout and otherwise ignored, because auto-fill is only a
    convenience.

    Args:
        work_dir: Working directory path. When falsy, a previously cached
            ``self.work_dir`` is kept.
        pipeline: Pipeline instance (unused; kept for interface
            compatibility).
    """
    import traceback

    try:
        if work_dir:
            self.work_dir = work_dir
        elif not getattr(self, 'work_dir', None):
            self.work_dir = None

        main_window = self.window()

        # 1. Sampling-spectra CSV <- output_file of the host's step7_panel.
        self._fill_if_blank(
            self.sampling_csv_file,
            self._sibling_panel_path(main_window, 'step7_panel', 'output_file'),
        )

        # 2. Models directory <- output_dir of the host's step6_panel.
        self._fill_if_blank(
            self.models_dir_file,
            self._sibling_panel_path(main_window, 'step6_panel', 'output_dir'),
        )

        # 3. Output path <- prediction folder inside the working directory
        #    (created if it does not exist yet).
        if self.work_dir:
            output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
            os.makedirs(output_dir, exist_ok=True)
            self._fill_if_blank(self.output_file, output_dir)
    except Exception as e:
        print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
        traceback.print_exc()


def _sibling_panel_path(self, main_window, panel_attr, widget_attr):
    """Read a path from a widget owned by another panel of the host window.

    Args:
        main_window: Host window that may own the sibling panel.
        panel_attr: Attribute name of the panel on ``main_window``.
        widget_attr: Attribute name of the path widget on that panel.

    Returns:
        The path as a string, or ``""`` when the panel, the widget or its
        value is missing. A relative path is joined onto ``self.work_dir``
        and written with forward slashes.
    """
    if not (main_window and hasattr(main_window, panel_attr)):
        return ""
    widget = getattr(getattr(main_window, panel_attr), widget_attr, None)
    # The widget may be a FileSelectWidget-like object or a plain line edit.
    if hasattr(widget, 'get_path'):
        path = widget.get_path() or ""
    elif hasattr(widget, 'text'):
        path = widget.text() or ""
    else:
        path = ""
    if path and not os.path.isabs(path):
        path = os.path.join(self.work_dir or '', path).replace('\\', '/')
    return path


def _fill_if_blank(self, selector, path):
    """Write ``path`` into ``selector`` unless it is empty or already set."""
    if not path:
        return
    current = selector.get_path()
    if not current or not current.strip():
        selector.set_path(path)
def _get_default_work_dir(self):
    """Return the working directory as a string, or ``""`` if unknown.

    The panel's own cached ``work_dir`` wins; otherwise the ``work_dir`` of
    the top-level window is used when it has one.
    """
    cached = getattr(self, 'work_dir', None)
    if cached:
        return str(cached)
    host = self.window()
    host_dir = getattr(host, 'work_dir', None) if host else None
    if host_dir:
        return str(host_dir)
    return ""
def browse_models_dir(self):
    """Open a directory chooser and store the choice as the models directory.

    The dialog starts in the ``7_Supervised_Model_Training`` folder of the
    working directory when a working directory is known. Cancelling the
    dialog leaves the current value untouched.
    """
    start_dir = self._get_default_work_dir()
    if start_dir:
        start_dir = os.path.join(start_dir, "7_Supervised_Model_Training")
    chosen = QFileDialog.getExistingDirectory(self, "选择模型目录", start_dir)
    if not chosen:
        return
    self.models_dir_file.set_path(chosen)
def get_config(self):
    """Collect the panel's current settings into a dict.

    Returns:
        A dict that always holds ``metric`` and ``prediction_column``, plus
        ``sampling_csv_path``, ``models_dir`` and ``output_path`` for each
        of those path fields that is not empty.
    """
    config = {
        'metric': self.metric.currentText(),
        'prediction_column': self.prediction_column.text(),
    }
    optional_paths = (
        ('sampling_csv_path', self.sampling_csv_file),
        ('models_dir', self.models_dir_file),
        ('output_path', self.output_file),
    )
    for key, selector in optional_paths:
        value = selector.get_path()
        if value:
            config[key] = value
    return config
def set_config(self, config):
    """Apply a settings dict to the panel's widgets.

    Keys missing from ``config`` leave the matching widget unchanged, and a
    ``metric`` value that is not in the combo box is ignored.

    Args:
        config: Mapping using the same keys that ``get_config`` produces.
    """
    if 'metric' in config:
        position = self.metric.findText(config['metric'])
        if position >= 0:
            self.metric.setCurrentIndex(position)
    if 'prediction_column' in config:
        self.prediction_column.setText(config['prediction_column'])
    path_targets = (
        ('sampling_csv_path', self.sampling_csv_file),
        ('models_dir', self.models_dir_file),
        ('output_path', self.output_file),
    )
    for key, selector in path_targets:
        if key in config:
            selector.set_path(config[key])
def run_step(self):
    """Validate the inputs and run step 8 on its own via the host window.

    In external-model mode only the ticked models are handed over, under
    the ``_external_models_dict`` / ``_external_model_dir`` keys; otherwise
    a models directory must be set. Every validation failure is reported
    with a warning box and aborts the run.
    """
    sampling_csv_path = self.sampling_csv_file.get_path()
    if not sampling_csv_path:
        QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
        return

    extra_config = {}
    if self.use_external_model.isChecked():
        # External models take priority when the user chose to import them.
        if not self.external_models_dict:
            QMessageBox.warning(
                self,
                "模型未加载",
                "请先点击「浏览...」按钮选择模型母文件夹!",
            )
            return
        # Only the models the user ticked are passed on.
        checked_dict = self._get_checked_models_dict()
        if not checked_dict:
            QMessageBox.warning(
                self,
                "未选择模型",
                "请至少勾选一个模型参与预测!",
            )
            return
        extra_config = {
            '_external_models_dict': checked_dict,
            '_external_model_dir': self.external_model_dir,
        }
    elif not self.models_dir_file.get_path():
        # Default flow: models are read from the models directory.
        QMessageBox.warning(self, "输入错误", "请选择模型目录!")
        return

    main_window = self.window()
    if not hasattr(main_window, 'run_single_step'):
        # Tell the user why nothing happens instead of ignoring the click.
        QMessageBox.warning(
            self,
            "无法运行",
            "当前窗口不支持独立运行此步骤(缺少 run_single_step)。",
        )
        return

    config = {'step8': self.get_config()}
    config.update(extra_config)
    main_window.run_single_step('step8', config)