- 9 个面板(step1~step6/step8_ml_train/step8_qaa/step9_ml_predict/step10)单步执行按钮从 parent 链上溯改为 global_event_bus.publish('RequestRunSingleStep')
- PipelineExecutor 新增 _on_request_run_single_step 订阅
- 新增 Handler: step8_ml_train / step9_ml_predict / step10_qaa_inversion / step11_concentration / step12_kriging / step13_visualization / step14_report
- 删除旧 water_quality_inversion_pipeline_GUI.py(上帝类已肢解完毕)
509 lines
20 KiB
Python
509 lines
20 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
Step9 面板 - 机器学习预测
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Make the directory that holds this file importable, so the sibling
# ``_step_path_resolver`` module (path-normalisation helpers that are the
# counterpart of ``pipeline.get_step_output_dir``) can be imported below.
_HERE = os.path.split(os.path.abspath(__file__))[0]
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)
|
||
from _step_path_resolver import get_step_output_path, resolve_step_widget, resolve_subdir
|
||
|
||
from PyQt5.QtWidgets import (
|
||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
|
||
QListWidgetItem,
|
||
)
|
||
from PyQt5.QtCore import Qt
|
||
|
||
from src.gui.components.custom_widgets import FileSelectWidget
|
||
from src.gui.styles import ModernStylesheet
|
||
|
||
|
||
class Step9MlPredictPanel(QWidget):
    """Step 9 panel: machine-learning prediction.

    Lets the user choose between the model produced by the current training
    run and locally stored pre-trained ``.joblib`` models, collects the
    prediction parameters and requests a single-step run of
    ``step9_ml_predict``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        # Externally loaded models keyed by the sub-directory they were found
        # in: {subdir_name: model_obj, ...}
        self.external_models_dict = {}
        # Parent ("mother") directory that was scanned for external models;
        # not shown in the UI.
        self.external_model_dir = ""
        self.init_ui()

    def init_ui(self):
        """Build the widget tree of the panel."""
        layout = QVBoxLayout()

        # -------- Model source (radio-button group) --------
        source_group = QGroupBox("模型来源")
        source_layout = QVBoxLayout()

        self.use_trained_model = QRadioButton("使用当前训练流程的模型")
        self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
        self.use_trained_model.setChecked(True)
        source_layout.addWidget(self.use_trained_model)
        source_layout.addWidget(self.use_external_model)

        # Both buttons are connected; the handler ignores the "unchecked"
        # emission so it acts once per switch.
        self.use_trained_model.toggled.connect(self._on_model_source_changed)
        self.use_external_model.toggled.connect(self._on_model_source_changed)

        source_group.setStyleSheet("""
            QRadioButton {
                font-size: 13px;
                spacing: 8px;
            }
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border-radius: 9px;
                border: 2px solid #A0A0A0;
                background-color: #FFFFFF;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #0078D7;
            }
            QRadioButton::indicator:checked {
                background-color: #0078D7;
                border: 2px solid #0078D7;
            }
        """)

        source_group.setLayout(source_layout)
        layout.addWidget(source_group)

        # -------- External model parent directory (external source only) --------
        self.external_model_widget = FileSelectWidget(
            "模型母文件夹:",
            "Directories"
        )
        # Replace the widget's default browse behaviour with a directory scan.
        self.external_model_widget.browse_btn.clicked.disconnect()
        self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
        self.external_model_widget.setVisible(False)
        layout.addWidget(self.external_model_widget)

        # -------- Scanned model list (external source only) --------
        self.model_list_group = QGroupBox("选择参与预测的模型")
        self.model_list_group.setVisible(False)
        model_list_layout = QVBoxLayout()

        self.model_list = QListWidget()
        self.model_list.setMaximumHeight(130)
        # Items are chosen via their check boxes, not via row selection.
        self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
        self.model_list.setStyleSheet("""
            QListWidget {
                border: 1px solid #C0C0C0;
                border-radius: 4px;
                background-color: #FFFFFF;
                font-size: 12px;
            }
            QListWidget::item {
                padding: 4px 6px;
                border-bottom: 1px solid #F0F0F0;
            }
            QListWidget::item:selected {
                background-color: transparent;
            }
        """)
        model_list_layout.addWidget(self.model_list)

        btn_row = QHBoxLayout()
        self.btn_select_all = QPushButton("全选")
        self.btn_select_all.setMaximumWidth(80)
        self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_all.clicked.connect(self._select_all_models)

        self.btn_select_none = QPushButton("全不选")
        self.btn_select_none.setMaximumWidth(80)
        self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
        self.btn_select_none.clicked.connect(self._select_none_models)

        btn_row.addWidget(self.btn_select_all)
        btn_row.addWidget(self.btn_select_none)
        btn_row.addStretch()
        model_list_layout.addLayout(btn_row)

        self.model_list_group.setLayout(model_list_layout)
        layout.addWidget(self.model_list_group)

        # -------- Sampling spectra CSV (used for standalone runs) --------
        self.sampling_csv_file = FileSelectWidget(
            "采样光谱CSV:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.sampling_csv_file)

        # Model directory (used for standalone runs with the trained-model source)
        self.models_dir_file = FileSelectWidget(
            "模型目录:",
            "Directories;;All Files (*.*)"
        )
        self.models_dir_file.label.setText("模型目录:")
        self.models_dir_file.browse_btn.clicked.disconnect()
        self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
        layout.addWidget(self.models_dir_file)

        # Prediction parameters
        params_group = QGroupBox("预测参数")
        params_layout = QFormLayout()

        self.metric = QComboBox()
        self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
        params_layout.addRow("模型选择指标:", self.metric)

        self.prediction_column = QLineEdit()
        self.prediction_column.setText("prediction")
        params_layout.addRow("预测列名:", self.prediction_column)

        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # Output path
        self.output_file = FileSelectWidget(
            "输出路径:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.output_file)

        # Enable/disable this step
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Standalone run button
        self.run_btn = QPushButton("独立运行此步骤")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self._on_run_single_clicked)
        layout.addWidget(self.run_btn)

        layout.addStretch()
        self.setLayout(layout)

    def _on_model_source_changed(self, checked: bool):
        """Show/hide the external-model widgets when the model source changes.

        Connected to ``toggled`` of both radio buttons; only the ``checked``
        emission is handled. Switching back to the trained-model source
        discards any externally loaded models.
        """
        if not checked:
            return
        is_external = self.use_external_model.isChecked()
        self.external_model_widget.setVisible(is_external)
        self.model_list_group.setVisible(is_external)
        if not is_external:
            self._reset_external_models()

    def _reset_external_models(self):
        """Forget scanned external models and clear their on-screen summary."""
        self.external_models_dict = {}
        self.external_model_dir = ""
        self._clear_model_list()
        # Also clear the "N models found" summary and its highlight, otherwise
        # the widget keeps advertising models that are no longer loaded.
        self.external_model_widget.set_path("")
        self.external_model_widget.line_edit.setStyleSheet("")

    def _scan_external_model_dir(self):
        """Browse for a parent folder and load one ``.joblib`` model per sub-folder.

        Each immediate sub-directory of the chosen folder is expected to hold
        the model of one water-quality parameter; the sub-directory name is
        used as the model key. A loaded object is accepted either as a dict
        with a ``"model"`` entry or as any object exposing ``predict``.

        Security note: ``joblib.load`` unpickles the file and can execute
        arbitrary code -- only load model files from trusted locations.
        """
        default = self._get_default_work_dir()
        if default:
            default = resolve_subdir(default, 'ml_prediction')
        dir_path = QFileDialog.getExistingDirectory(
            self,
            "选择模型母文件夹",
            default,
        )
        if not dir_path:
            return

        models_found = {}
        errors = []

        try:
            import joblib

            # Context manager closes the directory handle deterministically.
            with os.scandir(dir_path) as entries:
                subdirs = sorted(
                    (entry for entry in entries if entry.is_dir()),
                    key=lambda entry: entry.name,
                )

            for subentry in subdirs:
                subdir_name = subentry.name
                try:
                    with os.scandir(subentry.path) as files:
                        # Sorted so the pick is deterministic when a folder
                        # holds several .joblib files (scandir order is arbitrary).
                        joblib_paths = sorted(
                            f.path for f in files
                            if f.is_file() and f.name.lower().endswith(".joblib")
                        )
                    if not joblib_paths:
                        continue
                    loaded = joblib.load(joblib_paths[0])
                    if isinstance(loaded, dict) and "model" in loaded:
                        model_obj = loaded["model"]
                    elif hasattr(loaded, "predict"):
                        model_obj = loaded
                    else:
                        errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
                        continue
                    models_found[subdir_name] = model_obj
                except Exception as e:
                    # One unreadable/broken sub-folder must not abort the scan.
                    errors.append(f"{subdir_name}: {type(e).__name__}: {e}")

        except Exception as e:
            # Keep the previous (consistent) state; nothing was committed yet.
            QMessageBox.warning(
                self,
                "扫描失败",
                f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
            )
            return

        if not models_found:
            if errors:
                # Files were found but none could be used: show why.
                detail = "\n".join(errors)
                message = f"未能从「{dir_path}」的子目录中加载任何模型:\n{detail}"
            else:
                message = (
                    f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
                    "请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。"
                )
            QMessageBox.warning(self, "未找到模型", message)
            self._reset_external_models()
            return

        self.external_model_dir = dir_path
        self.external_models_dict = models_found
        self._populate_model_list(models_found)
        names = sorted(models_found)
        display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
        self.external_model_widget.set_path(display)
        self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")

        err_lines = "\n".join(errors) if errors else "无"
        QMessageBox.information(
            self,
            "模型扫描完成",
            f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
            f"加载失败 {len(errors)} 个:\n{err_lines}",
        )

    def _populate_model_list(self, models_dict):
        """Fill the list with one checkable item per model name, all checked."""
        self.model_list.clear()
        for name in sorted(models_dict):
            item = QListWidgetItem(name)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked)
            self.model_list.addItem(item)

    def _clear_model_list(self):
        """Remove all items from the model list."""
        self.model_list.clear()

    def _set_all_check_states(self, state):
        """Apply *state* (``Qt.Checked``/``Qt.Unchecked``) to every list item."""
        for i in range(self.model_list.count()):
            self.model_list.item(i).setCheckState(state)

    def _select_all_models(self):
        """Check every model in the list."""
        self._set_all_check_states(Qt.Checked)

    def _select_none_models(self):
        """Uncheck every model in the list."""
        self._set_all_check_states(Qt.Unchecked)

    def _get_checked_models_dict(self):
        """Return ``{name: model}`` for the checked items that are loaded."""
        items = (self.model_list.item(i) for i in range(self.model_list.count()))
        return {
            item.text(): self.external_models_dict[item.text()]
            for item in items
            if item.checkState() == Qt.Checked and item.text() in self.external_models_dict
        }

    @staticmethod
    def _is_blank(text):
        """Return True when *text* is empty/None or whitespace only."""
        return not text or not text.strip()

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the sampling CSV, model directory and output path.

        Only empty fields are filled, so user edits are never overwritten.
        Failures are reported on stdout and otherwise ignored (best effort).

        Args:
            work_dir: Work directory. When falsy, a previously cached
                ``self.work_dir`` is kept (or ``None`` if there is none).
            pipeline: Pipeline instance (unused; kept for interface compatibility).
        """
        try:
            # Explicit argument first, then the cached value, else None.
            self.work_dir = work_dir or getattr(self, 'work_dir', None) or None

            main_window = self.window()

            # 1. Sampling CSV: the whole-lake sampling-point CSV written by Step4.
            if main_window and hasattr(main_window, 'step4_sampling_panel'):
                step4_widget = getattr(main_window.step4_sampling_panel, 'output_file', None)
                step4_output_path = ""
                if hasattr(step4_widget, 'get_path'):
                    step4_output_path = step4_widget.get_path() or ""
                elif hasattr(step4_widget, 'text'):
                    step4_output_path = step4_widget.text() or ""

                if step4_output_path:
                    if not os.path.isabs(step4_output_path):
                        # Relative paths are interpreted against the work directory.
                        step4_output_path = os.path.join(
                            self.work_dir or '', step4_output_path
                        ).replace('\\', '/')
                    if self._is_blank(self.sampling_csv_file.get_path()):
                        self.sampling_csv_file.set_path(step4_output_path)

            # 2. Model directory: output of Step8 (supervised model training).
            step8_models_dir = get_step_output_path(
                main_window, 'models_dir', work_dir=self.work_dir,
                widget_attr='output_dir', fallback_key='step8_ml_train',
            )
            if step8_models_dir and self._is_blank(self.models_dir_file.get_path()):
                self.models_dir_file.set_path(step8_models_dir)

            # 3. Output path: the ML-prediction sub-directory of the work dir
            #    (9_ML_Prediction lives under prediction_dir; local convention).
            if self.work_dir:
                output_dir = resolve_subdir(self.work_dir, 'ml_prediction')
                os.makedirs(output_dir, exist_ok=True)
                if self._is_blank(self.output_file.get_path()):
                    self.output_file.set_path(output_dir)
        except Exception as e:
            import traceback
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    def _get_default_work_dir(self):
        """Return the work dir: the panel's cached one, else the main window's, else ""."""
        if hasattr(self, 'work_dir') and self.work_dir:
            return str(self.work_dir)
        mw = self.window()
        if mw and hasattr(mw, 'work_dir') and mw.work_dir:
            return str(mw.work_dir)
        return ""

    def browse_models_dir(self):
        """Let the user pick the model directory."""
        default = self._get_default_work_dir()
        if default:
            default = resolve_subdir(default, 'ml_prediction')
        dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
        if dir_path:
            self.models_dir_file.set_path(dir_path)

    def get_config(self):
        """Return the panel configuration; empty paths are omitted."""
        config = {
            'metric': self.metric.currentText(),
            'prediction_column': self.prediction_column.text(),
        }
        sampling_csv_path = self.sampling_csv_file.get_path()
        if sampling_csv_path:
            config['sampling_csv_path'] = sampling_csv_path
        models_dir = self.models_dir_file.get_path()
        if models_dir:
            config['models_dir'] = models_dir
        output_path = self.output_file.get_path()
        if output_path:
            config['output_path'] = output_path
        return config

    def set_config(self, config):
        """Apply a configuration dict; missing keys leave widgets unchanged."""
        if 'metric' in config:
            idx = self.metric.findText(config['metric'])
            if idx >= 0:
                self.metric.setCurrentIndex(idx)
        if 'prediction_column' in config:
            self.prediction_column.setText(config['prediction_column'])
        if 'sampling_csv_path' in config:
            self.sampling_csv_file.set_path(config['sampling_csv_path'])
        if 'models_dir' in config:
            self.models_dir_file.set_path(config['models_dir'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])

    def _build_single_step_config(self):
        """Validate the inputs and build the config for a single-step run.

        Shows a warning dialog and returns ``None`` when inputs are missing.
        Otherwise returns ``{'step9_ml_predict': ...}``, plus the checked
        external models and their parent directory when the external-model
        source is selected.
        """
        sampling_csv_path = self.sampling_csv_file.get_path()
        if not sampling_csv_path:
            QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
            return None

        # External models take precedence when the user chose to import them.
        if self.use_external_model.isChecked():
            if not self.external_models_dict:
                QMessageBox.warning(
                    self,
                    "模型未加载",
                    "请先点击「浏览...」按钮选择模型母文件夹!",
                )
                return None
            checked_dict = self._get_checked_models_dict()
            if not checked_dict:
                QMessageBox.warning(
                    self,
                    "未选择模型",
                    "请至少勾选一个模型参与预测!",
                )
                return None
            return {
                'step9_ml_predict': self.get_config(),
                '_external_models_dict': checked_dict,
                '_external_model_dir': self.external_model_dir,
            }

        # Default flow: models are read from the model directory.
        models_dir = self.models_dir_file.get_path()
        if not models_dir:
            QMessageBox.warning(self, "输入错误", "请选择模型目录!")
            return None
        return {'step9_ml_predict': self.get_config()}

    @staticmethod
    def _publish_run_single_step(config):
        """Publish a ``RequestRunSingleStep`` event for ``step9_ml_predict``."""
        from src.gui.core.event_bus import global_event_bus

        global_event_bus.publish('RequestRunSingleStep', {
            'step_name': 'step9_ml_predict',
            'config': config,
        })

    def _on_run_single_clicked(self):
        """Request a single-step run via the EventBus (decoupled from PipelineExecutor)."""
        config = self._build_single_step_config()
        if config is None:
            return
        self._publish_run_single_step(config)

    def run_step(self):
        """Run step 9 on its own (legacy parent-chain entry point, kept for compatibility).

        Uses the main window's ``run_single_step`` when it still exists and
        otherwise falls back to the EventBus request, so the call is never a
        silent no-op.
        """
        config = self._build_single_step_config()
        if config is None:
            return
        main_window = self.window()
        if hasattr(main_window, 'run_single_step'):
            main_window.run_single_step('step9_ml_predict', config)
        else:
            self._publish_run_single_step(config)