feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描
This commit is contained in:
@ -326,6 +326,14 @@ class WorkerThread(QThread):
|
||||
method_name = step_method_map[step_name]
|
||||
step_config = dict(config.get(step_name, {}))
|
||||
|
||||
# 透传面板顶层传入的外部预训练模型(GUI step8_panel 通过 config['_external_model'] 传入)
|
||||
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
|
||||
for key in ('_external_model', '_external_model_path',
|
||||
'_external_models_dict', '_external_model_dir'):
|
||||
val = config.get(key)
|
||||
if val is not None and val != "":
|
||||
step_config[key] = val
|
||||
|
||||
step_config['skip_dependency_check'] = True
|
||||
|
||||
if step_name == 'step9':
|
||||
|
||||
@ -91,3 +91,13 @@ Traceback (most recent call last):
|
||||
sys.exit(app.exec_())
|
||||
^^^^^^^^^^^
|
||||
KeyboardInterrupt
|
||||
|
||||
============================================================
|
||||
[2026-06-04 09:54:07]
|
||||
Traceback (most recent call last):
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3237, in <module>
|
||||
main()
|
||||
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3214, in main
|
||||
sys.exit(app.exec_())
|
||||
^^^^^^^^^^^
|
||||
KeyboardInterrupt
|
||||
|
||||
@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
QFileDialog, QRadioButton,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
@ -21,12 +21,61 @@ class Step8Panel(QWidget):
|
||||
"""步骤8:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.external_models_dict = {} # {subdir_name: model_obj, ...}
|
||||
self.external_model_dir = "" # 母文件夹路径(隐藏)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件(用于独立运行)
|
||||
# -------- 模型来源选择(单选按钮组) --------
|
||||
source_group = QGroupBox("模型来源")
|
||||
source_layout = QVBoxLayout()
|
||||
|
||||
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
|
||||
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
|
||||
self.use_trained_model.setChecked(True)
|
||||
source_layout.addWidget(self.use_trained_model)
|
||||
source_layout.addWidget(self.use_external_model)
|
||||
|
||||
self.use_trained_model.toggled.connect(self._on_model_source_changed)
|
||||
self.use_external_model.toggled.connect(self._on_model_source_changed)
|
||||
|
||||
source_group.setStyleSheet("""
|
||||
QRadioButton {
|
||||
font-size: 13px;
|
||||
spacing: 8px;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 9px;
|
||||
border: 2px solid #A0A0A0;
|
||||
background-color: #FFFFFF;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
""")
|
||||
|
||||
source_group.setLayout(source_layout)
|
||||
layout.addWidget(source_group)
|
||||
|
||||
# -------- 外部模型文件选择(条件显示) --------
|
||||
self.external_model_widget = FileSelectWidget(
|
||||
"模型母文件夹:",
|
||||
"Directories"
|
||||
)
|
||||
self.external_model_widget.browse_btn.clicked.disconnect()
|
||||
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
|
||||
self.external_model_widget.setVisible(False)
|
||||
layout.addWidget(self.external_model_widget)
|
||||
|
||||
# -------- 采样光谱CSV文件(用于独立运行)--------
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
@ -79,6 +128,94 @@ class Step8Panel(QWidget):
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_model_source_changed(self, checked: bool):
|
||||
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
|
||||
if not checked:
|
||||
return
|
||||
is_external = self.use_external_model.isChecked()
|
||||
self.external_model_widget.setVisible(is_external)
|
||||
if not is_external:
|
||||
self.external_models_dict = {}
|
||||
self.external_model_dir = ""
|
||||
|
||||
def _scan_external_model_dir(self):
|
||||
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self,
|
||||
"选择模型母文件夹",
|
||||
default,
|
||||
)
|
||||
if not dir_path:
|
||||
return
|
||||
|
||||
self.external_model_dir = dir_path
|
||||
models_found = {}
|
||||
errors = []
|
||||
|
||||
try:
|
||||
import joblib
|
||||
|
||||
for subentry in os.scandir(dir_path):
|
||||
if not subentry.is_dir():
|
||||
continue
|
||||
subdir_name = subentry.name
|
||||
joblib_files = [
|
||||
f for f in os.scandir(subentry.path)
|
||||
if f.is_file() and f.name.lower().endswith(".joblib")
|
||||
]
|
||||
if not joblib_files:
|
||||
continue
|
||||
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
|
||||
joblib_path = joblib_files[0].path
|
||||
try:
|
||||
loaded = joblib.load(joblib_path)
|
||||
if isinstance(loaded, dict) and "model" in loaded:
|
||||
model_obj = loaded["model"]
|
||||
elif hasattr(loaded, "predict"):
|
||||
model_obj = loaded
|
||||
else:
|
||||
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
|
||||
continue
|
||||
models_found[subdir_name] = model_obj
|
||||
except Exception as e:
|
||||
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"扫描失败",
|
||||
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
|
||||
)
|
||||
return
|
||||
|
||||
if not models_found:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未找到模型",
|
||||
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
|
||||
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
|
||||
)
|
||||
self.external_model_widget.set_path("")
|
||||
self.external_models_dict = {}
|
||||
return
|
||||
|
||||
self.external_models_dict = models_found
|
||||
names = sorted(models_found.keys())
|
||||
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
|
||||
self.external_model_widget.set_path(display)
|
||||
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
|
||||
|
||||
err_lines = "\n".join(errors) if errors else "无"
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"模型扫描完成",
|
||||
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
|
||||
f"加载失败 {len(errors)} 个:\n{err_lines}",
|
||||
)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
@ -197,10 +334,31 @@ class Step8Panel(QWidget):
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
# 外部模型优先:用户选择了"导入本地预训练模型"
|
||||
if self.use_external_model.isChecked():
|
||||
if not self.external_models_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型未加载",
|
||||
"请先点击「浏览...」按钮选择模型母文件夹!",
|
||||
)
|
||||
return
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {
|
||||
'step8': self.get_config(),
|
||||
'_external_models_dict': self.external_models_dict,
|
||||
'_external_model_dir': self.external_model_dir,
|
||||
}
|
||||
main_window.run_single_step('step8', config)
|
||||
return
|
||||
|
||||
# 默认流程:使用模型目录
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user