feat(step8): 外部模型从单文件升级为母文件夹多模型字典扫描

This commit is contained in:
DXC
2026-06-08 09:56:02 +08:00
parent 4efe5b871e
commit 2b76d7908f
12 changed files with 935 additions and 29 deletions

View File

@ -326,6 +326,14 @@ class WorkerThread(QThread):
method_name = step_method_map[step_name]
step_config = dict(config.get(step_name, {}))
# 透传面板顶层传入的外部预训练模型GUI step8_panel 通过 config['_external_model'] 传入)
# 非空才覆盖(遵循 feedback_never_overwrite_with_empty 原则)
for key in ('_external_model', '_external_model_path',
'_external_models_dict', '_external_model_dir'):
val = config.get(key)
if val is not None and val != "":
step_config[key] = val
step_config['skip_dependency_check'] = True
if step_name == 'step9':

View File

@ -91,3 +91,13 @@ Traceback (most recent call last):
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt
============================================================
[2026-06-04 09:54:07]
Traceback (most recent call last):
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3237, in <module>
main()
File "D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py", line 3214, in main
sys.exit(app.exec_())
^^^^^^^^^^^
KeyboardInterrupt

View File

@ -10,7 +10,7 @@ from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog,
QFileDialog, QRadioButton,
)
from src.gui.components.custom_widgets import FileSelectWidget
@ -21,12 +21,61 @@ class Step8Panel(QWidget):
"""步骤8机器学习预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.external_models_dict = {} # {subdir_name: model_obj, ...}
self.external_model_dir = "" # 母文件夹路径(隐藏)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 采样光谱CSV文件用于独立运行
# -------- 模型来源选择(单选按钮组) --------
source_group = QGroupBox("模型来源")
source_layout = QVBoxLayout()
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
self.use_trained_model.setChecked(True)
source_layout.addWidget(self.use_trained_model)
source_layout.addWidget(self.use_external_model)
self.use_trained_model.toggled.connect(self._on_model_source_changed)
self.use_external_model.toggled.connect(self._on_model_source_changed)
source_group.setStyleSheet("""
QRadioButton {
font-size: 13px;
spacing: 8px;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border-radius: 9px;
border: 2px solid #A0A0A0;
background-color: #FFFFFF;
}
QRadioButton::indicator:hover {
border: 2px solid #0078D7;
}
QRadioButton::indicator:checked {
background-color: #0078D7;
border: 2px solid #0078D7;
}
""")
source_group.setLayout(source_layout)
layout.addWidget(source_group)
# -------- 外部模型文件选择(条件显示) --------
self.external_model_widget = FileSelectWidget(
"模型母文件夹:",
"Directories"
)
self.external_model_widget.browse_btn.clicked.disconnect()
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
self.external_model_widget.setVisible(False)
layout.addWidget(self.external_model_widget)
# -------- 采样光谱CSV文件用于独立运行--------
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
@ -79,6 +128,94 @@ class Step8Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
def _on_model_source_changed(self, checked: bool):
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
if not checked:
return
is_external = self.use_external_model.isChecked()
self.external_model_widget.setVisible(is_external)
if not is_external:
self.external_models_dict = {}
self.external_model_dir = ""
def _scan_external_model_dir(self):
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "7_Supervised_Model_Training")
dir_path = QFileDialog.getExistingDirectory(
self,
"选择模型母文件夹",
default,
)
if not dir_path:
return
self.external_model_dir = dir_path
models_found = {}
errors = []
try:
import joblib
for subentry in os.scandir(dir_path):
if not subentry.is_dir():
continue
subdir_name = subentry.name
joblib_files = [
f for f in os.scandir(subentry.path)
if f.is_file() and f.name.lower().endswith(".joblib")
]
if not joblib_files:
continue
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
joblib_path = joblib_files[0].path
try:
loaded = joblib.load(joblib_path)
if isinstance(loaded, dict) and "model" in loaded:
model_obj = loaded["model"]
elif hasattr(loaded, "predict"):
model_obj = loaded
else:
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
continue
models_found[subdir_name] = model_obj
except Exception as e:
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
except Exception as e:
QMessageBox.warning(
self,
"扫描失败",
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
)
return
if not models_found:
QMessageBox.warning(
self,
"未找到模型",
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
)
self.external_model_widget.set_path("")
self.external_models_dict = {}
return
self.external_models_dict = models_found
names = sorted(models_found.keys())
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
self.external_model_widget.set_path(display)
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
err_lines = "\n".join(errors) if errors else ""
QMessageBox.information(
self,
"模型扫描完成",
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
f"加载失败 {len(errors)} 个:\n{err_lines}",
)
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充采样光谱和模型目录
@ -197,10 +334,31 @@ class Step8Panel(QWidget):
def run_step(self):
"""独立运行步骤8"""
sampling_csv_path = self.sampling_csv_file.get_path()
models_dir = self.models_dir_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
# 外部模型优先:用户选择了"导入本地预训练模型"
if self.use_external_model.isChecked():
if not self.external_models_dict:
QMessageBox.warning(
self,
"模型未加载",
"请先点击「浏览...」按钮选择模型母文件夹!",
)
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {
'step8': self.get_config(),
'_external_models_dict': self.external_models_dict,
'_external_model_dir': self.external_model_dir,
}
main_window.run_single_step('step8', config)
return
# 默认流程:使用模型目录
models_dir = self.models_dir_file.get_path()
if not models_dir:
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
return

Binary file not shown.