feat(gui): 全流程面板合并 + 一键式运行 GUI 入口集成
This commit is contained in:
252
src/gui/panels/step10_panel.py
Normal file
252
src/gui/panels/step10_panel.py
Normal file
@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step10 面板 - 采样点生成
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QSpinBox, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.dialogs import SamplingViewerDialog
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step10Panel(QWidget):
|
||||
"""步骤10:采样点生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("采样参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.interval = QSpinBox()
|
||||
self.interval.setRange(10, 500)
|
||||
self.interval.setValue(50)
|
||||
params_layout.addRow("采样点间隔(像素):", self.interval)
|
||||
|
||||
self.sample_radius = QSpinBox()
|
||||
self.sample_radius.setRange(1, 50)
|
||||
self.sample_radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.sample_radius)
|
||||
|
||||
self.chunk_size = QSpinBox()
|
||||
self.chunk_size.setRange(100, 10000)
|
||||
self.chunk_size.setValue(1000)
|
||||
params_layout.addRow("处理块大小:", self.chunk_size)
|
||||
|
||||
self.use_adaptive_sampling = QCheckBox("启用自适应采样")
|
||||
self.use_adaptive_sampling.setChecked(True)
|
||||
params_layout.addRow("采样模式:", self.use_adaptive_sampling)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出采样点:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
# 交互式预览按钮
|
||||
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
|
||||
self.preview_btn.setEnabled(False)
|
||||
self.preview_btn.clicked.connect(self._open_sampling_viewer)
|
||||
layout.addWidget(self.preview_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 监听输出路径变化,实时更新预览按钮状态
|
||||
self.output_file.line_edit.textChanged.connect(self._on_output_changed)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'interval': self.interval.value(),
|
||||
'sample_radius': self.sample_radius.value(),
|
||||
'chunk_size': self.chunk_size.value(),
|
||||
'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
|
||||
}
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'interval' in config:
|
||||
self.interval.setValue(config['interval'])
|
||||
if 'sample_radius' in config:
|
||||
self.sample_radius.setValue(config['sample_radius'])
|
||||
if 'chunk_size' in config:
|
||||
self.chunk_size.setValue(config['chunk_size'])
|
||||
if 'use_adaptive_sampling' in config:
|
||||
self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充去耀斑影像和掩膜路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
|
||||
deglint_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
|
||||
deglint_path = (
|
||||
step3_outputs.get('deglint_image')
|
||||
or step3_outputs.get('output_path')
|
||||
or step3_outputs.get('output_file')
|
||||
or step3_outputs.get('deglint_img_path')
|
||||
)
|
||||
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
|
||||
if not deglint_path and hasattr(main_window, 'step3_panel'):
|
||||
deglint_path = main_window.step3_panel.output_file.get_path()
|
||||
|
||||
if deglint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(deglint_path):
|
||||
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
|
||||
self.deglint_img_file.set_path(deglint_path)
|
||||
|
||||
# 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
|
||||
water_mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
|
||||
water_mask_path = (
|
||||
step1_outputs.get('water_mask')
|
||||
or step1_outputs.get('output_path')
|
||||
or step1_outputs.get('output_file')
|
||||
)
|
||||
# 回退:从 step1 面板 widget 直接读取
|
||||
if not water_mask_path and hasattr(main_window, 'step1_panel'):
|
||||
water_mask_path = main_window.step1_panel.output_file.get_path()
|
||||
# 备选:扫描 1_water_mask 目录下的 .dat 文件
|
||||
if not water_mask_path and self.work_dir:
|
||||
mask_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
if os.path.isdir(mask_dir):
|
||||
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
|
||||
if dat_files:
|
||||
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
|
||||
# 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
|
||||
if not water_mask_path and self.work_dir:
|
||||
input_test_dir = os.path.join(self.work_dir, "input-test")
|
||||
if os.path.isdir(input_test_dir):
|
||||
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
|
||||
# 优先匹配 water_mask_from_shp.dat
|
||||
for f in dat_files:
|
||||
if 'water_mask_from_shp' in f.lower():
|
||||
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
|
||||
break
|
||||
# 否则取第一个 .dat 文件
|
||||
if not water_mask_path and dat_files:
|
||||
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
|
||||
|
||||
if water_mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(water_mask_path):
|
||||
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(water_mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(绝对路径)
|
||||
if self.work_dir:
|
||||
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
self.output_file.set_path(output_path.replace('\\', '/'))
|
||||
|
||||
# 4. 同步更新预览按钮状态(路径可能已自动填充)
|
||||
self._check_csv_exists()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤10"""
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step10': self.get_config()}
|
||||
main_window.run_single_step('step10', config)
|
||||
|
||||
def _check_csv_exists(self):
|
||||
"""检查 output csv 是否存在,驱动预览按钮启停"""
|
||||
csv_path = self.output_file.get_path()
|
||||
enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
|
||||
self.preview_btn.setEnabled(enabled)
|
||||
return enabled
|
||||
|
||||
def _on_output_changed(self, _text=None):
|
||||
"""输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)"""
|
||||
self._check_csv_exists()
|
||||
|
||||
def _open_sampling_viewer(self):
|
||||
"""打开交互式采样点查看器弹窗"""
|
||||
csv_path = self.output_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
QMessageBox.warning(
|
||||
self, "文件不存在",
|
||||
f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤10生成数据。"
|
||||
)
|
||||
return
|
||||
dialog = SamplingViewerDialog(csv_path, self)
|
||||
dialog.exec_()
|
||||
# 弹窗关闭后再次检查状态(可能文件被覆盖等)
|
||||
self._check_csv_exists()
|
||||
462
src/gui/panels/step11_ml_panel.py
Normal file
462
src/gui/panels/step11_ml_panel.py
Normal file
@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 面板 - 机器学习预测
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
|
||||
QListWidgetItem,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step11MlPanel(QWidget):
|
||||
"""步骤11:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.external_models_dict = {} # {subdir_name: model_obj, ...}
|
||||
self.external_model_dir = "" # 母文件夹路径(隐藏)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# -------- 模型来源选择(单选按钮组) --------
|
||||
source_group = QGroupBox("模型来源")
|
||||
source_layout = QVBoxLayout()
|
||||
|
||||
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
|
||||
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
|
||||
self.use_trained_model.setChecked(True)
|
||||
source_layout.addWidget(self.use_trained_model)
|
||||
source_layout.addWidget(self.use_external_model)
|
||||
|
||||
self.use_trained_model.toggled.connect(self._on_model_source_changed)
|
||||
self.use_external_model.toggled.connect(self._on_model_source_changed)
|
||||
|
||||
source_group.setStyleSheet("""
|
||||
QRadioButton {
|
||||
font-size: 13px;
|
||||
spacing: 8px;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 9px;
|
||||
border: 2px solid #A0A0A0;
|
||||
background-color: #FFFFFF;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
""")
|
||||
|
||||
source_group.setLayout(source_layout)
|
||||
layout.addWidget(source_group)
|
||||
|
||||
# -------- 外部模型文件选择(条件显示) --------
|
||||
self.external_model_widget = FileSelectWidget(
|
||||
"模型母文件夹:",
|
||||
"Directories"
|
||||
)
|
||||
self.external_model_widget.browse_btn.clicked.disconnect()
|
||||
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
|
||||
self.external_model_widget.setVisible(False)
|
||||
layout.addWidget(self.external_model_widget)
|
||||
|
||||
# -------- 已扫描模型列表(条件显示) --------
|
||||
self.model_list_group = QGroupBox("选择参与预测的模型")
|
||||
self.model_list_group.setVisible(False)
|
||||
model_list_layout = QVBoxLayout()
|
||||
|
||||
self.model_list = QListWidget()
|
||||
self.model_list.setMaximumHeight(130)
|
||||
self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
|
||||
self.model_list.setStyleSheet("""
|
||||
QListWidget {
|
||||
border: 1px solid #C0C0C0;
|
||||
border-radius: 4px;
|
||||
background-color: #FFFFFF;
|
||||
font-size: 12px;
|
||||
}
|
||||
QListWidget::item {
|
||||
padding: 4px 6px;
|
||||
border-bottom: 1px solid #F0F0F0;
|
||||
}
|
||||
QListWidget::item:selected {
|
||||
background-color: transparent;
|
||||
}
|
||||
""")
|
||||
model_list_layout.addWidget(self.model_list)
|
||||
|
||||
btn_row = QHBoxLayout()
|
||||
self.btn_select_all = QPushButton("全选")
|
||||
self.btn_select_all.setMaximumWidth(80)
|
||||
self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
|
||||
self.btn_select_all.clicked.connect(self._select_all_models)
|
||||
|
||||
self.btn_select_none = QPushButton("全不选")
|
||||
self.btn_select_none.setMaximumWidth(80)
|
||||
self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
|
||||
self.btn_select_none.clicked.connect(self._select_none_models)
|
||||
|
||||
btn_row.addWidget(self.btn_select_all)
|
||||
btn_row.addWidget(self.btn_select_none)
|
||||
btn_row.addStretch()
|
||||
model_list_layout.addLayout(btn_row)
|
||||
|
||||
self.model_list_group.setLayout(model_list_layout)
|
||||
layout.addWidget(self.model_list_group)
|
||||
|
||||
# -------- 采样光谱CSV文件(用于独立运行)--------
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录(用于独立运行)
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出路径:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_model_source_changed(self, checked: bool):
|
||||
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
|
||||
if not checked:
|
||||
return
|
||||
is_external = self.use_external_model.isChecked()
|
||||
self.external_model_widget.setVisible(is_external)
|
||||
self.model_list_group.setVisible(is_external)
|
||||
if not is_external:
|
||||
self.external_models_dict = {}
|
||||
self.external_model_dir = ""
|
||||
self._clear_model_list()
|
||||
|
||||
def _scan_external_model_dir(self):
|
||||
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self,
|
||||
"选择模型母文件夹",
|
||||
default,
|
||||
)
|
||||
if not dir_path:
|
||||
return
|
||||
|
||||
self.external_model_dir = dir_path
|
||||
models_found = {}
|
||||
errors = []
|
||||
|
||||
try:
|
||||
import joblib
|
||||
|
||||
for subentry in os.scandir(dir_path):
|
||||
if not subentry.is_dir():
|
||||
continue
|
||||
subdir_name = subentry.name
|
||||
joblib_files = [
|
||||
f for f in os.scandir(subentry.path)
|
||||
if f.is_file() and f.name.lower().endswith(".joblib")
|
||||
]
|
||||
if not joblib_files:
|
||||
continue
|
||||
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
|
||||
joblib_path = joblib_files[0].path
|
||||
try:
|
||||
loaded = joblib.load(joblib_path)
|
||||
if isinstance(loaded, dict) and "model" in loaded:
|
||||
model_obj = loaded["model"]
|
||||
elif hasattr(loaded, "predict"):
|
||||
model_obj = loaded
|
||||
else:
|
||||
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
|
||||
continue
|
||||
models_found[subdir_name] = model_obj
|
||||
except Exception as e:
|
||||
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"扫描失败",
|
||||
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
|
||||
)
|
||||
return
|
||||
|
||||
if not models_found:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未找到模型",
|
||||
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
|
||||
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
|
||||
)
|
||||
self.external_model_widget.set_path("")
|
||||
self.external_models_dict = {}
|
||||
self._clear_model_list()
|
||||
return
|
||||
|
||||
self.external_models_dict = models_found
|
||||
self._populate_model_list(models_found)
|
||||
names = sorted(models_found.keys())
|
||||
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
|
||||
self.external_model_widget.set_path(display)
|
||||
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
|
||||
|
||||
err_lines = "\n".join(errors) if errors else "无"
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"模型扫描完成",
|
||||
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
|
||||
f"加载失败 {len(errors)} 个:\n{err_lines}",
|
||||
)
|
||||
|
||||
def _populate_model_list(self, models_dict):
|
||||
"""将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选"""
|
||||
self.model_list.clear()
|
||||
for name in sorted(models_dict.keys()):
|
||||
item = QListWidgetItem(name)
|
||||
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
|
||||
item.setCheckState(Qt.Checked)
|
||||
self.model_list.addItem(item)
|
||||
|
||||
def _clear_model_list(self):
|
||||
"""清空模型列表"""
|
||||
self.model_list.clear()
|
||||
|
||||
def _select_all_models(self):
|
||||
"""全选:设置所有条目为 Checked"""
|
||||
for i in range(self.model_list.count()):
|
||||
self.model_list.item(i).setCheckState(Qt.Checked)
|
||||
|
||||
def _select_none_models(self):
|
||||
"""全不选:设置所有条目为 Unchecked"""
|
||||
for i in range(self.model_list.count()):
|
||||
self.model_list.item(i).setCheckState(Qt.Unchecked)
|
||||
|
||||
def _get_checked_models_dict(self):
|
||||
"""从列表中提取用户勾选的模型,组装成字典返回"""
|
||||
result = {}
|
||||
for i in range(self.model_list.count()):
|
||||
item = self.model_list.item(i)
|
||||
if item.checkState() == Qt.Checked:
|
||||
name = item.text()
|
||||
if name in self.external_models_dict:
|
||||
result[name] = self.external_models_dict[name]
|
||||
return result
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step10_panel'):
|
||||
step7_widget = getattr(main_window.step10_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6 界面读取监督模型目录
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step6_widget = getattr(main_window.step7_panel, 'output_dir', None)
|
||||
step6_models_dir = ""
|
||||
if hasattr(step6_widget, 'get_path'):
|
||||
step6_models_dir = step6_widget.get_path() or ""
|
||||
elif hasattr(step6_widget, 'text'):
|
||||
step6_models_dir = step6_widget.text() or ""
|
||||
|
||||
if step6_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_models_dir):
|
||||
step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(机器学习预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
# 外部模型优先:用户选择了"导入本地预训练模型"
|
||||
if self.use_external_model.isChecked():
|
||||
if not self.external_models_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型未加载",
|
||||
"请先点击「浏览...」按钮选择模型母文件夹!",
|
||||
)
|
||||
return
|
||||
# 只传递用户勾选的模型
|
||||
checked_dict = self._get_checked_models_dict()
|
||||
if not checked_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未选择模型",
|
||||
"请至少勾选一个模型参与预测!",
|
||||
)
|
||||
return
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {
|
||||
'step11_ml': self.get_config(),
|
||||
'_external_models_dict': checked_dict,
|
||||
'_external_model_dir': self.external_model_dir,
|
||||
}
|
||||
main_window.run_single_step('step11_ml', config)
|
||||
return
|
||||
|
||||
# 默认流程:使用模型目录
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step11_ml': self.get_config()}
|
||||
main_window.run_single_step('step11_ml', config)
|
||||
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_5 面板 - 非经验模型预测
|
||||
Step11 面板 - 非经验模型预测
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_5Panel(QWidget):
|
||||
"""步骤8.5:非经验模型预测"""
|
||||
class Step11Panel(QWidget):
|
||||
"""步骤11:非经验模型预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
@ -118,22 +118,22 @@ class Step8_5Panel(QWidget):
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.5 界面读取回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_5_panel'):
|
||||
step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
|
||||
step6_5_models_dir = ""
|
||||
if hasattr(step6_5_widget, 'get_path'):
|
||||
step6_5_models_dir = step6_5_widget.get_path() or ""
|
||||
elif hasattr(step6_5_widget, 'text'):
|
||||
step6_5_models_dir = step6_5_widget.text() or ""
|
||||
# 2. 尝试从 Step8_Non_Empirical 界面读取回归模型目录
|
||||
if main_window and hasattr(main_window, 'step8_non_empirical_panel'):
|
||||
step8_non_empirical_widget = getattr(main_window.step8_non_empirical_panel, 'output_dir', None)
|
||||
step8_non_empirical_models_dir = ""
|
||||
if hasattr(step8_non_empirical_widget, 'get_path'):
|
||||
step8_non_empirical_models_dir = step8_non_empirical_widget.get_path() or ""
|
||||
elif hasattr(step8_non_empirical_widget, 'text'):
|
||||
step8_non_empirical_models_dir = step8_non_empirical_widget.text() or ""
|
||||
|
||||
if step6_5_models_dir:
|
||||
if step8_non_empirical_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_5_models_dir):
|
||||
step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
|
||||
if not os.path.isabs(step8_non_empirical_models_dir):
|
||||
step8_non_empirical_models_dir = os.path.join(self.work_dir or '', step8_non_empirical_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_5_models_dir)
|
||||
self.models_dir_file.set_path(step8_non_empirical_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(非经验模型预测目录)
|
||||
if self.work_dir:
|
||||
@ -208,7 +208,7 @@ class Step8_5Panel(QWidget):
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.5"""
|
||||
"""独立运行步骤11"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
@ -221,6 +221,6 @@ class Step8_5Panel(QWidget):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_5', {'step8_5': config})
|
||||
parent.run_single_step('step11', {'step11': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_75 面板 - 自定义回归预测
|
||||
Step12 面板 - 自定义回归预测
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -15,8 +15,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_75Panel(QWidget):
|
||||
"""步骤8.75:自定义回归预测"""
|
||||
class Step12Panel(QWidget):
|
||||
"""步骤12:自定义回归预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
@ -111,25 +111,25 @@ class Step8_75Panel(QWidget):
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.75 界面读取自定义回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_75_panel'):
|
||||
step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
|
||||
step6_75_models_dir = ""
|
||||
if hasattr(step6_75_widget, 'get_path'):
|
||||
step6_75_models_dir = step6_75_widget.get_path() or ""
|
||||
elif hasattr(step6_75_widget, 'text'):
|
||||
step6_75_models_dir = step6_75_widget.text() or ""
|
||||
step6_75_models_dir = step6_75_models_dir.strip()
|
||||
# 2. 尝试从 Step9 界面读取自定义回归模型目录
|
||||
if main_window and hasattr(main_window, 'step12_panel'):
|
||||
step9_widget = getattr(main_window.step9_panel, 'output_dir', None)
|
||||
step9_models_dir = ""
|
||||
if hasattr(step9_widget, 'get_path'):
|
||||
step9_models_dir = step9_widget.get_path() or ""
|
||||
elif hasattr(step9_widget, 'text'):
|
||||
step9_models_dir = step9_widget.text() or ""
|
||||
step9_models_dir = step9_models_dir.strip()
|
||||
|
||||
if step6_75_models_dir:
|
||||
if step9_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_75_models_dir):
|
||||
step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
|
||||
if not os.path.isabs(step9_models_dir):
|
||||
step9_models_dir = os.path.join(self.work_dir or '', step9_models_dir).replace('\\', '/')
|
||||
existing_models = self.regression_models_dir.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.regression_models_dir.set_path(step6_75_models_dir)
|
||||
self.regression_models_dir.set_path(step9_models_dir)
|
||||
|
||||
# 3. 自动填充回归模型目录(如果 step6_75 未提供)
|
||||
# 3. 自动填充回归模型目录(如果 step9 未提供)
|
||||
if self.work_dir:
|
||||
models_dir = self.regression_models_dir.get_path().strip()
|
||||
if not models_dir:
|
||||
@ -208,7 +208,7 @@ class Step8_75Panel(QWidget):
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.75"""
|
||||
"""独立运行步骤12"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
@ -225,6 +225,6 @@ class Step8_75Panel(QWidget):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_75', {'step8_75': config})
|
||||
parent.run_single_step('step12', {'step12': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
533
src/gui/panels/step14_panel.py
Normal file
533
src/gui/panels/step14_panel.py
Normal file
@ -0,0 +1,533 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step14 面板 - 分布图生成
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
|
||||
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIPELINE_AVAILABLE = False
|
||||
|
||||
|
||||
class Step14BatchThread(QThread):
|
||||
"""专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
|
||||
|
||||
finished_ok = pyqtSignal(int)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, csv_paths: List[str], step14_kwargs: dict, output_dir_optional: Optional[str]):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.csv_paths = csv_paths
|
||||
self.step14_kwargs = step14_kwargs
|
||||
self.output_dir_optional = (output_dir_optional or "").strip() or None
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
n = len(self.csv_paths)
|
||||
for i, csv_p in enumerate(self.csv_paths):
|
||||
self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
|
||||
kw = {**self.step14_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
|
||||
if self.output_dir_optional:
|
||||
stem = Path(csv_p).stem
|
||||
kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
|
||||
else:
|
||||
kw["output_image_path"] = None
|
||||
pipeline.step9_generate_distribution_map(**kw)
|
||||
self.finished_ok.emit(n)
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Step14Panel(QWidget):
|
||||
"""步骤14:分布图生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._batch_thread = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel(
|
||||
"独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
|
||||
"完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
|
||||
)
|
||||
hint.setWordWrap(True)
|
||||
hint.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
layout.addWidget(hint)
|
||||
|
||||
mode_row = QHBoxLayout()
|
||||
self.mode_single_rb = QRadioButton("单个 CSV 文件")
|
||||
self.mode_folder_rb = QRadioButton("文件夹批量")
|
||||
self._mode_group = QButtonGroup(self)
|
||||
self._mode_group.addButton(self.mode_single_rb, 0)
|
||||
self._mode_group.addButton(self.mode_folder_rb, 1)
|
||||
mode_row.addWidget(self.mode_single_rb)
|
||||
mode_row.addWidget(self.mode_folder_rb)
|
||||
mode_row.addStretch()
|
||||
layout.addLayout(mode_row)
|
||||
|
||||
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
|
||||
radio_style = """
|
||||
QRadioButton {
|
||||
font-size: 14px;
|
||||
spacing: 8px;
|
||||
color: #333333;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border: 2px solid #999999;
|
||||
border-radius: 3px;
|
||||
background-color: white;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
border: 2px solid #0078d4;
|
||||
background-color: #0078d4;
|
||||
image: none;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #005a9e;
|
||||
}
|
||||
"""
|
||||
self.mode_single_rb.setStyleSheet(radio_style)
|
||||
self.mode_folder_rb.setStyleSheet(radio_style)
|
||||
|
||||
self.prediction_csv_file = FileSelectWidget(
|
||||
"预测结果CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.prediction_csv_file)
|
||||
|
||||
folder_row = QHBoxLayout()
|
||||
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
|
||||
self.prediction_csv_dir_label.setMinimumWidth(120)
|
||||
self.prediction_csv_dir_edit = QLineEdit()
|
||||
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
|
||||
pred_dir_btn = QPushButton("浏览…")
|
||||
pred_dir_btn.setMaximumWidth(80)
|
||||
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
|
||||
folder_row.addWidget(self.prediction_csv_dir_label)
|
||||
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
|
||||
folder_row.addWidget(pred_dir_btn)
|
||||
self._folder_row_widget = QWidget()
|
||||
self._folder_row_widget.setLayout(folder_row)
|
||||
layout.addWidget(self._folder_row_widget)
|
||||
|
||||
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
|
||||
layout.addWidget(self.recursive_csv_cb)
|
||||
|
||||
self.boundary_file = FileSelectWidget(
|
||||
"边界文件:",
|
||||
"Shapefiles (*.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.boundary_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("生成参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.resolution = QDoubleSpinBox()
|
||||
self.resolution.setRange(1, 1000)
|
||||
self.resolution.setValue(30)
|
||||
params_layout.addRow("分辨率(米):", self.resolution)
|
||||
|
||||
self.input_crs = QLineEdit()
|
||||
self.input_crs.setText("EPSG:32651")
|
||||
params_layout.addRow("输入坐标系:", self.input_crs)
|
||||
|
||||
self.output_crs = QLineEdit()
|
||||
self.output_crs.setText("EPSG:4326")
|
||||
params_layout.addRow("输出坐标系:", self.output_crs)
|
||||
|
||||
self.show_points = QCheckBox("显示采样点")
|
||||
params_layout.addRow("", self.show_points)
|
||||
|
||||
self.use_diffusion = QCheckBox("启用距离扩散")
|
||||
self.use_diffusion.setChecked(True)
|
||||
params_layout.addRow("", self.use_diffusion)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出目录
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出分布图目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 信号绑定与初始状态
|
||||
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
|
||||
self._toggle_input_mode() # 根据默认值设置初始显示状态
|
||||
|
||||
def _toggle_input_mode(self):
|
||||
"""槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
# 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
|
||||
self.prediction_csv_file.setVisible(not folder_mode)
|
||||
# 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
|
||||
self._folder_row_widget.setVisible(folder_mode)
|
||||
self.recursive_csv_cb.setVisible(folder_mode)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_prediction_csv_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
|
||||
if d:
|
||||
self.prediction_csv_dir_edit.setText(d)
|
||||
|
||||
def _collect_csv_paths_from_folder(self) -> List[str]:
|
||||
folder = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not folder or not os.path.isdir(folder):
|
||||
return []
|
||||
root = Path(folder)
|
||||
if self.recursive_csv_cb.isChecked():
|
||||
files = sorted(root.rglob("*.csv"))
|
||||
else:
|
||||
files = sorted(root.glob("*.csv"))
|
||||
return [str(p) for p in files if p.is_file()]
|
||||
|
||||
def _step14_base_pipeline_kwargs(self) -> dict:
|
||||
return {
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
|
||||
def get_config(self):
|
||||
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
config = {
|
||||
'step14_batch_mode': 'folder' if folder_mode else 'single',
|
||||
'prediction_csv_dir': pred_dir if pred_dir else None,
|
||||
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
|
||||
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
out_dir = (self.output_dir.get_path() or "").strip()
|
||||
if not folder_mode and pred_csv and out_dir:
|
||||
stem = Path(pred_csv).stem
|
||||
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
|
||||
else:
|
||||
config['output_image_path'] = None
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
mode = config.get('step14_batch_mode', 'single')
|
||||
if mode == 'folder':
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
else:
|
||||
self.mode_single_rb.setChecked(True)
|
||||
if config.get('prediction_csv_dir'):
|
||||
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
|
||||
if 'recursive_csv_scan' in config:
|
||||
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
|
||||
if 'prediction_csv_path' in config and config['prediction_csv_path']:
|
||||
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
|
||||
if 'boundary_shp_path' in config:
|
||||
self.boundary_file.set_path(config['boundary_shp_path'])
|
||||
if 'resolution' in config:
|
||||
self.resolution.setValue(config['resolution'])
|
||||
if 'input_crs' in config:
|
||||
self.input_crs.setText(config['input_crs'])
|
||||
if 'output_crs' in config:
|
||||
self.output_crs.setText(config['output_crs'])
|
||||
if 'show_sample_points' in config:
|
||||
self.show_points.setChecked(config['show_sample_points'])
|
||||
if 'use_distance_diffusion' in config:
|
||||
self.use_diffusion.setChecked(config['use_distance_diffusion'])
|
||||
if 'output_dir' in config and config['output_dir']:
|
||||
self.output_dir.set_path(str(config['output_dir']))
|
||||
elif config.get('output_image_path'):
|
||||
p = Path(str(config['output_image_path']))
|
||||
if p.parent and str(p.parent) != '.':
|
||||
self.output_dir.set_path(str(p.parent))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充预测结果目录
|
||||
|
||||
优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录;
|
||||
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
if not main_window:
|
||||
return
|
||||
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
|
||||
pred_dir = None
|
||||
if hasattr(main_window, 'step11_prediction_panel'):
|
||||
step8_widget = getattr(main_window.step11_prediction_panel, 'output_file', None)
|
||||
step8_output = ""
|
||||
if hasattr(step8_widget, 'get_path'):
|
||||
step8_output = step8_widget.get_path() or ""
|
||||
elif hasattr(step8_widget, 'text'):
|
||||
step8_output = step8_widget.text() or ""
|
||||
|
||||
if step8_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_output):
|
||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||
# 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
|
||||
base_pred_dir = str(Path(step8_output).parent)
|
||||
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
|
||||
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
|
||||
|
||||
# 2. 备选:从 Step11 界面读取非经验预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step11_panel'):
|
||||
step8_5_widget = getattr(main_window.step11_panel, 'output_file', None)
|
||||
step8_5_output = ""
|
||||
if hasattr(step8_5_widget, 'get_path'):
|
||||
step8_5_output = step8_5_widget.get_path() or ""
|
||||
elif hasattr(step8_5_widget, 'text'):
|
||||
step8_5_output = step8_5_widget.text() or ""
|
||||
|
||||
if step8_5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_5_output):
|
||||
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_5_output).parent)
|
||||
|
||||
# 3. 备选:从 Step12 界面读取自定义回归预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step12_panel'):
|
||||
step8_75_widget = getattr(main_window.step12_panel, 'output_dir_widget', None)
|
||||
step8_75_output = ""
|
||||
if hasattr(step8_75_widget, 'get_path'):
|
||||
step8_75_output = step8_75_widget.get_path() or ""
|
||||
elif hasattr(step8_75_widget, 'text'):
|
||||
step8_75_output = step8_75_widget.text() or ""
|
||||
|
||||
if step8_75_output:
|
||||
pred_dir = step8_75_output
|
||||
|
||||
# 自动填入"预测CSV目录"(文件夹批量模式)
|
||||
if pred_dir:
|
||||
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not existing_dir:
|
||||
self.prediction_csv_dir_edit.setText(pred_dir)
|
||||
# 切换到文件夹批量模式
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
|
||||
# 4. 自动填充输出目录(14_visualization)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "14_visualization")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
|
||||
# 5. 自动探测原始矢量边界文件(.shp)作为专题图底图
|
||||
# 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式
|
||||
if self.work_dir:
|
||||
possible_shp = None
|
||||
candidates = [
|
||||
Path(self.work_dir).parent / "input-test" / "roi.shp",
|
||||
Path(self.work_dir) / "roi.shp",
|
||||
Path(self.work_dir).parent / "roi.shp",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists() and candidate.suffix.lower() == ".shp":
|
||||
possible_shp = candidate
|
||||
break
|
||||
|
||||
existing_boundary = (self.boundary_file.get_path() or "").strip()
|
||||
if not existing_boundary and possible_shp:
|
||||
self.boundary_file.set_path(str(possible_shp))
|
||||
elif not existing_boundary:
|
||||
# 未找到 .shp 时清空并提示用户手动选择矢量文件
|
||||
self.boundary_file.set_path("")
|
||||
print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤14"""
|
||||
if self._batch_thread and self._batch_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
|
||||
return
|
||||
|
||||
boundary_shp_path = self.boundary_file.get_path()
|
||||
if not boundary_shp_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
|
||||
return
|
||||
if not os.path.exists(boundary_shp_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
|
||||
return
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if not parent or not hasattr(parent, 'run_single_step'):
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
return
|
||||
|
||||
if self.mode_folder_rb.isChecked():
|
||||
csv_list = self._collect_csv_paths_from_folder()
|
||||
if not csv_list:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
|
||||
"可勾选「包含子文件夹」以递归扫描。",
|
||||
)
|
||||
return
|
||||
if not PIPELINE_AVAILABLE:
|
||||
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
|
||||
return
|
||||
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
|
||||
work_dir = str(work_dir)
|
||||
base_kw = self._step14_base_pipeline_kwargs()
|
||||
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
|
||||
self.run_button.setEnabled(False)
|
||||
self._batch_thread = Step14BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
|
||||
main_win = parent
|
||||
|
||||
def _batch_log(msg, lvl):
|
||||
if hasattr(main_win, "log_message"):
|
||||
main_win.log_message(msg, lvl)
|
||||
|
||||
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
|
||||
self._batch_thread.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection)
|
||||
self._batch_thread.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection)
|
||||
self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
|
||||
self._batch_thread.start()
|
||||
if hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
|
||||
return
|
||||
|
||||
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
|
||||
if not prediction_csv_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
|
||||
)
|
||||
return
|
||||
if not os.path.isfile(prediction_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
parent.run_single_step('step14', {'step14': config})
|
||||
|
||||
def _on_step14_batch_ok(self, n: int):
|
||||
QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
|
||||
|
||||
def _on_step14_batch_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(err, "error")
|
||||
@ -1,225 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QGridLayout,
|
||||
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
def get_resource_path(relative_path: str) -> str:
|
||||
"""适配开发与 PyInstaller 环境的路径获取逻辑。
|
||||
支持两种打包模式:
|
||||
1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录
|
||||
2. --onefile 模式:文件在 sys._MEIPASS 平铺目录
|
||||
"""
|
||||
# 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下)
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
return os.path.join(sys._MEIPASS, relative_path)
|
||||
|
||||
# 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下)
|
||||
exe_dir = os.path.dirname(sys.executable)
|
||||
internal_path = os.path.join(exe_dir, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
|
||||
# 开发环境下:基于当前文件 (step5_5_panel.py) 的绝对路径进行回溯
|
||||
# 当前在 src/gui/panels/,目标在 src/gui/model/
|
||||
base_dir = Path(__file__).resolve().parent.parent / "model"
|
||||
target_path = base_dir / os.path.basename(relative_path)
|
||||
return str(target_path)
|
||||
|
||||
class Step5_5Panel(QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.index_checkboxes: Dict[str, QCheckBox] = {}
|
||||
# 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中
|
||||
self.builtin_formula_path = get_resource_path("waterindex.csv")
|
||||
|
||||
self.init_ui()
|
||||
# 延迟一小会儿加载,确保UI框架已就绪
|
||||
self._auto_load_formulas()
|
||||
|
||||
def init_ui(self):
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.setContentsMargins(20, 20, 20, 20)
|
||||
main_layout.setSpacing(10)
|
||||
|
||||
# 1. 路径展示区 (半透明只读)
|
||||
path_group = QGroupBox("公式配置源 (内置)")
|
||||
path_layout = QVBoxLayout()
|
||||
self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
|
||||
self.formula_csv_widget.set_path(self.builtin_formula_path)
|
||||
self.formula_csv_widget.set_read_only(True)
|
||||
# 视觉微调:提示用户这是内置的
|
||||
self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
|
||||
path_layout.addWidget(self.formula_csv_widget)
|
||||
path_group.setLayout(path_layout)
|
||||
main_layout.addWidget(path_group)
|
||||
|
||||
# 2. 训练数据输入
|
||||
input_group = QGroupBox("输入样本数据")
|
||||
input_layout = QVBoxLayout()
|
||||
self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
|
||||
input_layout.addWidget(self.training_data_widget)
|
||||
input_group.setLayout(input_layout)
|
||||
main_layout.addWidget(input_group)
|
||||
|
||||
# 3. 公式选择区
|
||||
self.formula_group = QGroupBox("待计算水质指数勾选")
|
||||
formula_outer_layout = QVBoxLayout()
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
self.select_all_btn = QPushButton("全选")
|
||||
self.deselect_all_btn = QPushButton("清空")
|
||||
self.select_all_btn.clicked.connect(self.select_all_formulas)
|
||||
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
|
||||
btn_layout.addWidget(self.select_all_btn)
|
||||
btn_layout.addWidget(self.deselect_all_btn)
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.refresh_button = QPushButton("手动重新加载公式")
|
||||
self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
|
||||
btn_layout.addWidget(self.refresh_button)
|
||||
|
||||
formula_outer_layout.addLayout(btn_layout)
|
||||
|
||||
# 核心滚动区
|
||||
scroll = QScrollArea()
|
||||
scroll.setWidgetResizable(True)
|
||||
scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
|
||||
self.scroll_content = QWidget()
|
||||
self.formula_layout = QGridLayout(self.scroll_content)
|
||||
self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
|
||||
scroll.setWidget(self.scroll_content)
|
||||
formula_outer_layout.addWidget(scroll)
|
||||
|
||||
self.formula_group.setLayout(formula_outer_layout)
|
||||
main_layout.addWidget(self.formula_group)
|
||||
|
||||
# 4. 输出与运行
|
||||
output_group = QGroupBox("结果输出")
|
||||
output_layout = QVBoxLayout()
|
||||
self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
|
||||
output_layout.addWidget(self.output_file_widget)
|
||||
output_group.setLayout(output_layout)
|
||||
main_layout.addWidget(output_group)
|
||||
|
||||
self.enable_checkbox = QCheckBox("启用计算流程")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
main_layout.addWidget(self.enable_checkbox)
|
||||
|
||||
self.run_button = QPushButton("立即执行计算")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.setMinimumHeight(40)
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
main_layout.addWidget(self.run_button)
|
||||
|
||||
self.setLayout(main_layout)
|
||||
|
||||
def _auto_load_formulas(self):
|
||||
"""启动时自动加载逻辑"""
|
||||
if os.path.exists(self.builtin_formula_path):
|
||||
self.refresh_formulas(silent=True)
|
||||
else:
|
||||
print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
|
||||
|
||||
def refresh_formulas(self, silent=False):
|
||||
path = self.builtin_formula_path
|
||||
if not os.path.exists(path):
|
||||
if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
|
||||
return
|
||||
|
||||
try:
|
||||
# 清理旧列表
|
||||
for i in reversed(range(self.formula_layout.count())):
|
||||
widget = self.formula_layout.itemAt(i).widget()
|
||||
if widget: widget.deleteLater()
|
||||
self.index_checkboxes.clear()
|
||||
|
||||
# 鲁棒性读取:尝试不同编码
|
||||
for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
|
||||
try:
|
||||
df = pd.read_csv(path, encoding=encoding)
|
||||
if 'Formula_Name' in df.columns: break
|
||||
except: continue
|
||||
|
||||
if 'Formula_Name' not in df.columns:
|
||||
if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列")
|
||||
return
|
||||
|
||||
names = df['Formula_Name'].dropna().unique().tolist()
|
||||
|
||||
row, col = 0, 0
|
||||
for name in names:
|
||||
name = str(name).strip()
|
||||
if not name: continue
|
||||
cb = QCheckBox(name)
|
||||
cb.setChecked(True)
|
||||
self.index_checkboxes[name] = cb
|
||||
self.formula_layout.addWidget(cb, row, col)
|
||||
col += 1
|
||||
if col >= 3:
|
||||
col = 0
|
||||
row += 1
|
||||
|
||||
# 强制UI更新
|
||||
self.scroll_content.adjustSize()
|
||||
print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
|
||||
|
||||
except Exception as e:
|
||||
if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
|
||||
|
||||
def select_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(True)
|
||||
|
||||
def deselect_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(False)
|
||||
|
||||
def get_config(self):
|
||||
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
|
||||
return {
|
||||
'training_csv_path': self.training_data_widget.get_path(),
|
||||
'formula_csv_file': self.builtin_formula_path,
|
||||
'formula_names': selected,
|
||||
'output_file': self.output_file_widget.get_path(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
|
||||
if 'formula_names' in config:
|
||||
sel = set(config['formula_names'])
|
||||
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
|
||||
if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
|
||||
self.enable_checkbox.setChecked(config.get('enabled', True))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
if work_dir: self.work_dir = work_dir
|
||||
main = self.window()
|
||||
if hasattr(main, 'step5_panel'):
|
||||
p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐
|
||||
if p5:
|
||||
if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
|
||||
self.training_data_widget.set_path(p5)
|
||||
|
||||
if self.work_dir:
|
||||
out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/')
|
||||
self.output_file_widget.set_path(out)
|
||||
|
||||
def run_step(self):
|
||||
config = self.get_config()
|
||||
if not config['training_csv_path']:
|
||||
QMessageBox.warning(self, "提示", "请先选择输入数据")
|
||||
return
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
|
||||
if parent: parent.run_single_step('step5_5', {'step5_5': config})
|
||||
@ -1,374 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_75 面板 - 自定义回归分析
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
|
||||
QScrollArea, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_75Panel(QWidget):
|
||||
"""步骤6.75:自定义回归分析"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.method_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.csv_columns = []
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
|
||||
hint.setStyleSheet("color: #666; font-size: 11px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
# CSV文件选择
|
||||
csv_group = QGroupBox("数据文件")
|
||||
csv_layout = QVBoxLayout()
|
||||
|
||||
self.csv_file = FileSelectWidget(
|
||||
"输入CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
|
||||
csv_layout.addWidget(self.csv_file)
|
||||
|
||||
self.refresh_btn = QPushButton("刷新列信息")
|
||||
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
|
||||
csv_layout.addWidget(self.refresh_btn)
|
||||
|
||||
csv_group.setLayout(csv_layout)
|
||||
layout.addWidget(csv_group)
|
||||
|
||||
# 自变量选择
|
||||
x_group = QGroupBox("自变量列选择 (可多选)")
|
||||
x_layout = QVBoxLayout()
|
||||
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
x_widget.setLayout(self.x_columns_layout)
|
||||
|
||||
x_scroll.setWidget(x_widget)
|
||||
x_layout.addWidget(x_scroll)
|
||||
|
||||
x_btn_layout = QHBoxLayout()
|
||||
self.x_select_all = QPushButton("全选")
|
||||
self.x_deselect_all = QPushButton("全不选")
|
||||
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
|
||||
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
|
||||
x_btn_layout.addWidget(self.x_select_all)
|
||||
x_btn_layout.addWidget(self.x_deselect_all)
|
||||
x_btn_layout.addStretch()
|
||||
x_layout.addLayout(x_btn_layout)
|
||||
|
||||
x_group.setLayout(x_layout)
|
||||
layout.addWidget(x_group)
|
||||
|
||||
# 因变量选择
|
||||
y_group = QGroupBox("因变量列选择 (可多选)")
|
||||
y_layout = QVBoxLayout()
|
||||
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
y_widget.setLayout(self.y_columns_layout)
|
||||
|
||||
y_scroll.setWidget(y_widget)
|
||||
y_layout.addWidget(y_scroll)
|
||||
|
||||
y_btn_layout = QHBoxLayout()
|
||||
self.y_select_all = QPushButton("全选")
|
||||
self.y_deselect_all = QPushButton("全不选")
|
||||
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
|
||||
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
|
||||
y_btn_layout.addWidget(self.y_select_all)
|
||||
y_btn_layout.addWidget(self.y_deselect_all)
|
||||
y_btn_layout.addStretch()
|
||||
y_layout.addLayout(y_btn_layout)
|
||||
|
||||
y_group.setLayout(y_layout)
|
||||
layout.addWidget(y_group)
|
||||
|
||||
# 回归方法选择
|
||||
method_group = QGroupBox("回归方法选择 (可多选)")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
method_grid = QGridLayout()
|
||||
regression_methods = [
|
||||
'linear', 'exponential', 'power', 'logarithmic',
|
||||
'polynomial', 'hyperbolic', 'sigmoidal'
|
||||
]
|
||||
|
||||
for i, method in enumerate(regression_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
||||
checkbox.setChecked(True)
|
||||
self.method_checkboxes[method] = checkbox
|
||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
method_layout.addLayout(method_grid)
|
||||
|
||||
method_btn_layout = QHBoxLayout()
|
||||
self.method_select_all = QPushButton("全选")
|
||||
self.method_deselect_all = QPushButton("全不选")
|
||||
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
|
||||
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
|
||||
method_btn_layout.addWidget(self.method_select_all)
|
||||
method_btn_layout.addWidget(self.method_deselect_all)
|
||||
method_btn_layout.addStretch()
|
||||
method_layout.addLayout(method_btn_layout)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 输出目录
|
||||
output_group = QGroupBox("输出设置")
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
layout.addWidget(output_group)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def on_csv_file_changed(self):
|
||||
"""CSV文件改变时自动刷新列信息"""
|
||||
self.refresh_csv_columns()
|
||||
|
||||
def refresh_csv_columns(self):
|
||||
"""刷新CSV文件的列信息"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
return
|
||||
|
||||
try:
|
||||
df = pd.read_csv(csv_path, nrows=0)
|
||||
self.csv_columns = list(df.columns)
|
||||
self.update_column_widgets()
|
||||
except Exception as e:
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
print(f"读取CSV列信息失败: {e}")
|
||||
|
||||
def update_column_widgets(self):
|
||||
"""更新列选择组件"""
|
||||
for checkbox in self.x_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.x_column_checkboxes.clear()
|
||||
|
||||
for checkbox in self.y_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.y_column_checkboxes.clear()
|
||||
|
||||
if not self.csv_columns:
|
||||
return
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
||||
checkbox.setChecked(True)
|
||||
self.x_column_checkboxes[col] = checkbox
|
||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
||||
checkbox.setChecked(True)
|
||||
self.y_column_checkboxes[col] = checkbox
|
||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||
|
||||
self.x_columns_layout.update()
|
||||
self.y_columns_layout.update()
|
||||
|
||||
def get_config(self):
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
selected_methods = 'all'
|
||||
|
||||
return {
|
||||
'csv_path': self.csv_file.get_path() or None,
|
||||
'x_columns': selected_x_columns,
|
||||
'y_columns': selected_y_columns,
|
||||
'methods': selected_methods,
|
||||
'output_dir': self.output_dir.text().strip() or None,
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.refresh_csv_columns()
|
||||
|
||||
if 'x_columns' in config:
|
||||
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
|
||||
for col, checkbox in self.x_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_x)
|
||||
|
||||
if 'y_columns' in config:
|
||||
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
|
||||
for col, checkbox in self.y_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_y)
|
||||
|
||||
if 'methods' in config:
|
||||
methods = config['methods']
|
||||
if isinstance(methods, list):
|
||||
selected_methods = set(methods)
|
||||
elif methods == 'all':
|
||||
selected_methods = set(self.method_checkboxes.keys())
|
||||
else:
|
||||
selected_methods = set()
|
||||
for method, checkbox in self.method_checkboxes.items():
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 自动填充输出目录(9_Custom_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.text().strip()
|
||||
if not existing_out:
|
||||
self.output_dir.setText(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.75"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
|
||||
return
|
||||
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_x_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
|
||||
return
|
||||
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_y_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
|
||||
return
|
||||
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_75', {'step6_75': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,415 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6 面板 - 机器学习建模
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||||
QPushButton, QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 中文映射表(内部键名 -> 显示文本)
|
||||
# ============================================================
|
||||
|
||||
# 预处理方法:内部键 -> 显示文本
|
||||
PREPROC_CHINESE = {
|
||||
'None': '无 (None)',
|
||||
'MMS': '最小-最大归一化 (MMS)',
|
||||
'SS': '标度化 (SS)',
|
||||
'SNV': '标准正态变换 (SNV)',
|
||||
'MA': '移动平均 (MA)',
|
||||
'SG': 'Savitzky-Golay (SG)',
|
||||
'MSC': '多元散射校正 (MSC)',
|
||||
'D1': '一阶导数 (D1)',
|
||||
'D2': '二阶导数 (D2)',
|
||||
'DT': '去趋势 (DT)',
|
||||
'CT': '中心化 (CT)',
|
||||
}
|
||||
|
||||
# 模型类型:内部键 -> 显示文本
|
||||
MODEL_CHINESE = {
|
||||
# 线性模型
|
||||
'LinearRegression': '多元线性回归 (MLR)',
|
||||
'Ridge': '岭回归 (Ridge)',
|
||||
'Lasso': '套索回归 (Lasso)',
|
||||
'ElasticNet': '弹性网络 (ElasticNet)',
|
||||
'PLS': '偏最小二乘 (PLSR)',
|
||||
# 树模型
|
||||
'DecisionTree': '决策树 (CART)',
|
||||
'RF': '随机森林 (RF)',
|
||||
'ExtraTrees': '极端随机树 (ET)',
|
||||
'XGBoost': '极值梯度提升 (XGBoost)',
|
||||
'LightGBM': '轻量梯度提升 (LightGBM)',
|
||||
'CatBoost': '类别梯度提升 (CatBoost)',
|
||||
# 集成学习
|
||||
'GradientBoosting': '梯度提升树 (GBDT)',
|
||||
'AdaBoost': '自适应提升 (AdaBoost)',
|
||||
# 其他模型
|
||||
'SVR': '支持向量回归 (SVR)',
|
||||
'KNN': 'K近邻回归 (KNN)',
|
||||
'MLP': '多层感知机 (BP神经网络)',
|
||||
}
|
||||
|
||||
# 数据划分方法:内部键 -> 显示文本
|
||||
SPLIT_CHINESE = {
|
||||
'spxy': 'SPXY 算法 (考量X-Y空间)',
|
||||
'ks': 'KS 算法 (考量X空间)',
|
||||
'random': '随机划分 (Random)',
|
||||
}
|
||||
|
||||
|
||||
class Step6Panel(QWidget):
|
||||
"""步骤6:机器学习建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 机器学习模型页面
|
||||
self.ml_page = QWidget()
|
||||
self.create_ml_page()
|
||||
layout.addWidget(self.ml_page)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_path = FileSelectWidget(
|
||||
"输出文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
|
||||
self.output_path.browse_btn.clicked.disconnect()
|
||||
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
|
||||
layout.addWidget(self.output_path)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(False)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def create_ml_page(self):
|
||||
"""创建机器学习模型页面"""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("训练参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.feature_start = QLineEdit()
|
||||
self.feature_start.setText("374.285004")
|
||||
params_layout.addRow("特征起始列:", self.feature_start)
|
||||
|
||||
self.cv_folds = QSpinBox()
|
||||
self.cv_folds.setRange(2, 10)
|
||||
self.cv_folds.setValue(3)
|
||||
params_layout.addRow("交叉验证折数:", self.cv_folds)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 预处理方法 - 多选
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
|
||||
preproc_grid = QGridLayout()
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
layout.addWidget(preproc_group)
|
||||
|
||||
# 模型选择 - 多选
|
||||
model_group = QGroupBox("模型类型 (可多选)")
|
||||
model_layout = QVBoxLayout()
|
||||
|
||||
model_grid = QGridLayout()
|
||||
self.model_checkboxes = {}
|
||||
|
||||
model_groups = [
|
||||
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
|
||||
("【其他模型】", ['SVR', 'KNN', 'MLP'])
|
||||
]
|
||||
|
||||
row = 0
|
||||
for group_name, models in model_groups:
|
||||
group_label = QLabel(f"<b>{group_name}</b>")
|
||||
group_label.setStyleSheet(
|
||||
f"background-color: {ModernStylesheet.COLORS['hover']}; "
|
||||
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
|
||||
f"border-radius: 3px;"
|
||||
)
|
||||
model_grid.addWidget(group_label, row, 0, 1, 4)
|
||||
row += 1
|
||||
|
||||
for i, model in enumerate(models):
|
||||
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
|
||||
checkbox.setChecked(False)
|
||||
self.model_checkboxes[model] = checkbox
|
||||
model_grid.addWidget(checkbox, row, i % 4)
|
||||
if (i + 1) % 4 == 0:
|
||||
row += 1
|
||||
|
||||
row += 1
|
||||
|
||||
model_button_layout = QHBoxLayout()
|
||||
model_select_all = QPushButton("全选")
|
||||
model_deselect_all = QPushButton("全不选")
|
||||
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
|
||||
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
|
||||
model_button_layout.addWidget(model_select_all)
|
||||
model_button_layout.addWidget(model_deselect_all)
|
||||
model_button_layout.addStretch()
|
||||
|
||||
model_layout.addLayout(model_grid)
|
||||
model_layout.addLayout(model_button_layout)
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# 数据划分方法 - 多选
|
||||
split_group = QGroupBox("数据划分方法 (可多选)")
|
||||
split_layout = QVBoxLayout()
|
||||
|
||||
split_grid = QGridLayout()
|
||||
self.split_checkboxes = {}
|
||||
split_methods = ['spxy', 'ks', 'random']
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
split_button_layout = QHBoxLayout()
|
||||
split_select_all = QPushButton("全选")
|
||||
split_deselect_all = QPushButton("全不选")
|
||||
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
|
||||
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
|
||||
split_button_layout.addWidget(split_select_all)
|
||||
split_button_layout.addWidget(split_deselect_all)
|
||||
split_button_layout.addStretch()
|
||||
|
||||
split_layout.addLayout(split_grid)
|
||||
split_layout.addLayout(split_button_layout)
|
||||
split_group.setLayout(split_layout)
|
||||
layout.addWidget(split_group)
|
||||
|
||||
self.ml_page.setLayout(layout)
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_path(self):
|
||||
"""浏览输出文件路径(保存对话框)"""
|
||||
current = self.output_path.get_path().strip()
|
||||
if current:
|
||||
initial_dir = os.path.dirname(current)
|
||||
initial_file = os.path.basename(current)
|
||||
else:
|
||||
initial_dir = ""
|
||||
initial_file = ""
|
||||
|
||||
if not initial_dir or not os.path.isdir(initial_dir):
|
||||
# 默认定位到 indices 目录
|
||||
work_dir = self._get_default_work_dir()
|
||||
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
|
||||
if initial_dir and not os.path.isdir(initial_dir):
|
||||
os.makedirs(initial_dir, exist_ok=True)
|
||||
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
self.output_path.set_path(file_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
model_names = [
|
||||
model for model, checkbox in self.model_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
split_methods = [
|
||||
method for method, checkbox in self.split_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
|
||||
config = {
|
||||
'feature_start_column': self.feature_start.text(),
|
||||
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
|
||||
'model_names': model_names if model_names else ['SVR'],
|
||||
'split_methods': split_methods if split_methods else ['random'],
|
||||
'cv_folds': self.cv_folds.value()
|
||||
}
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['training_csv_path'] = training_csv_path
|
||||
output_path = self.output_path.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'feature_start_column' in config:
|
||||
self.feature_start.setText(str(config['feature_start_column']))
|
||||
if 'cv_folds' in config:
|
||||
self.cv_folds.setValue(config['cv_folds'])
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'model_names' in config:
|
||||
models = config['model_names']
|
||||
for model, checkbox in self.model_checkboxes.items():
|
||||
checkbox.setChecked(model in models)
|
||||
if 'split_methods' in config:
|
||||
methods = config['split_methods']
|
||||
for method, checkbox in self.split_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'training_csv_path' in config:
|
||||
self.training_csv_file.set_path(config['training_csv_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_path.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step5_panel'):
|
||||
# 优先直接从 Step5 的输出 widget 读取
|
||||
step5_output = main_window.step5_panel.output_file.get_path()
|
||||
if step5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output):
|
||||
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_output)
|
||||
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
|
||||
# 回退:从 Step5 的 config 字典中查找可能的键名
|
||||
step5_cfg = main_window.step5_panel.get_config()
|
||||
step5_csv = (
|
||||
step5_cfg.get('training_csv_path')
|
||||
or step5_cfg.get('output_file')
|
||||
or step5_cfg.get('csv_path')
|
||||
or step5_cfg.get('output_csv')
|
||||
)
|
||||
if step5_csv:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_csv):
|
||||
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_csv)
|
||||
|
||||
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
|
||||
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
|
||||
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
|
||||
if self.work_dir:
|
||||
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
|
||||
os.makedirs(indices_dir, exist_ok=True)
|
||||
training_csv = self.training_csv_file.get_path()
|
||||
if training_csv:
|
||||
basename = os.path.splitext(os.path.basename(training_csv))[0]
|
||||
output_file = f"{basename}_indices.csv"
|
||||
else:
|
||||
output_file = "water_quality_indices.csv"
|
||||
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
|
||||
self.output_path.set_path(output_path)
|
||||
else:
|
||||
self.output_path.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step6': self.get_config()}
|
||||
main_window.run_single_step('step6', config)
|
||||
|
||||
def get_training_params(self):
|
||||
"""获取模型训练参数"""
|
||||
return {
|
||||
'pipeline_type': 'machine_learning',
|
||||
'feature_start': float(self.feature_start.text()),
|
||||
'cv_folds': self.cv_folds.value(),
|
||||
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
|
||||
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
|
||||
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
|
||||
}
|
||||
@ -1,23 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step7 面板 - 采样点生成
|
||||
Step7 面板 - 机器学习建模
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QSpinBox, QMessageBox,
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||||
QPushButton, QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.dialogs import SamplingViewerDialog
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 中文映射表(内部键名 -> 显示文本)
|
||||
# ============================================================
|
||||
|
||||
# 预处理方法:内部键 -> 显示文本
|
||||
PREPROC_CHINESE = {
|
||||
'None': '无 (None)',
|
||||
'MMS': '最小-最大归一化 (MMS)',
|
||||
'SS': '标度化 (SS)',
|
||||
'SNV': '标准正态变换 (SNV)',
|
||||
'MA': '移动平均 (MA)',
|
||||
'SG': 'Savitzky-Golay (SG)',
|
||||
'MSC': '多元散射校正 (MSC)',
|
||||
'D1': '一阶导数 (D1)',
|
||||
'D2': '二阶导数 (D2)',
|
||||
'DT': '去趋势 (DT)',
|
||||
'CT': '中心化 (CT)',
|
||||
}
|
||||
|
||||
# 模型类型:内部键 -> 显示文本
|
||||
MODEL_CHINESE = {
|
||||
# 线性模型
|
||||
'LinearRegression': '多元线性回归 (MLR)',
|
||||
'Ridge': '岭回归 (Ridge)',
|
||||
'Lasso': '套索回归 (Lasso)',
|
||||
'ElasticNet': '弹性网络 (ElasticNet)',
|
||||
'PLS': '偏最小二乘 (PLSR)',
|
||||
# 树模型
|
||||
'DecisionTree': '决策树 (CART)',
|
||||
'RF': '随机森林 (RF)',
|
||||
'ExtraTrees': '极端随机树 (ET)',
|
||||
'XGBoost': '极值梯度提升 (XGBoost)',
|
||||
'LightGBM': '轻量梯度提升 (LightGBM)',
|
||||
'CatBoost': '类别梯度提升 (CatBoost)',
|
||||
# 集成学习
|
||||
'GradientBoosting': '梯度提升树 (GBDT)',
|
||||
'AdaBoost': '自适应提升 (AdaBoost)',
|
||||
# 其他模型
|
||||
'SVR': '支持向量回归 (SVR)',
|
||||
'KNN': 'K近邻回归 (KNN)',
|
||||
'MLP': '多层感知机 (BP神经网络)',
|
||||
}
|
||||
|
||||
# 数据划分方法:内部键 -> 显示文本
|
||||
SPLIT_CHINESE = {
|
||||
'spxy': 'SPXY 算法 (考量X-Y空间)',
|
||||
'ks': 'KS 算法 (考量X空间)',
|
||||
'random': '随机划分 (Random)',
|
||||
}
|
||||
|
||||
|
||||
class Step7Panel(QWidget):
|
||||
"""步骤7:采样点生成"""
|
||||
"""步骤7:机器学习建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
@ -25,58 +77,35 @@ class Step7Panel(QWidget):
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
# 标题
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("采样参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.interval = QSpinBox()
|
||||
self.interval.setRange(10, 500)
|
||||
self.interval.setValue(50)
|
||||
params_layout.addRow("采样点间隔(像素):", self.interval)
|
||||
|
||||
self.sample_radius = QSpinBox()
|
||||
self.sample_radius.setRange(1, 50)
|
||||
self.sample_radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.sample_radius)
|
||||
|
||||
self.chunk_size = QSpinBox()
|
||||
self.chunk_size.setRange(100, 10000)
|
||||
self.chunk_size.setValue(1000)
|
||||
params_layout.addRow("处理块大小:", self.chunk_size)
|
||||
|
||||
self.use_adaptive_sampling = QCheckBox("启用自适应采样")
|
||||
self.use_adaptive_sampling.setChecked(True)
|
||||
params_layout.addRow("采样模式:", self.use_adaptive_sampling)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出采样点:",
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 机器学习模型页面
|
||||
self.ml_page = QWidget()
|
||||
self.create_ml_page()
|
||||
layout.addWidget(self.ml_page)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_path = FileSelectWidget(
|
||||
"输出文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
|
||||
self.output_path.browse_btn.clicked.disconnect()
|
||||
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
|
||||
layout.addWidget(self.output_path)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
self.enable_checkbox.setChecked(False)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
@ -85,57 +114,233 @@ class Step7Panel(QWidget):
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
# 交互式预览按钮
|
||||
self.preview_btn = QPushButton("📊 交互式预览采样点与光谱")
|
||||
self.preview_btn.setEnabled(False)
|
||||
self.preview_btn.clicked.connect(self._open_sampling_viewer)
|
||||
layout.addWidget(self.preview_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 监听输出路径变化,实时更新预览按钮状态
|
||||
self.output_file.line_edit.textChanged.connect(self._on_output_changed)
|
||||
def create_ml_page(self):
|
||||
"""创建机器学习模型页面"""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("训练参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.feature_start = QLineEdit()
|
||||
self.feature_start.setText("374.285004")
|
||||
params_layout.addRow("特征起始列:", self.feature_start)
|
||||
|
||||
self.cv_folds = QSpinBox()
|
||||
self.cv_folds.setRange(2, 10)
|
||||
self.cv_folds.setValue(3)
|
||||
params_layout.addRow("交叉验证折数:", self.cv_folds)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 预处理方法 - 多选
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
|
||||
preproc_grid = QGridLayout()
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
layout.addWidget(preproc_group)
|
||||
|
||||
# 模型选择 - 多选
|
||||
model_group = QGroupBox("模型类型 (可多选)")
|
||||
model_layout = QVBoxLayout()
|
||||
|
||||
model_grid = QGridLayout()
|
||||
self.model_checkboxes = {}
|
||||
|
||||
model_groups = [
|
||||
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
|
||||
("【其他模型】", ['SVR', 'KNN', 'MLP'])
|
||||
]
|
||||
|
||||
row = 0
|
||||
for group_name, models in model_groups:
|
||||
group_label = QLabel(f"<b>{group_name}</b>")
|
||||
group_label.setStyleSheet(
|
||||
f"background-color: {ModernStylesheet.COLORS['hover']}; "
|
||||
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
|
||||
f"border-radius: 3px;"
|
||||
)
|
||||
model_grid.addWidget(group_label, row, 0, 1, 4)
|
||||
row += 1
|
||||
|
||||
for i, model in enumerate(models):
|
||||
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
|
||||
checkbox.setChecked(False)
|
||||
self.model_checkboxes[model] = checkbox
|
||||
model_grid.addWidget(checkbox, row, i % 4)
|
||||
if (i + 1) % 4 == 0:
|
||||
row += 1
|
||||
|
||||
row += 1
|
||||
|
||||
model_button_layout = QHBoxLayout()
|
||||
model_select_all = QPushButton("全选")
|
||||
model_deselect_all = QPushButton("全不选")
|
||||
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
|
||||
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
|
||||
model_button_layout.addWidget(model_select_all)
|
||||
model_button_layout.addWidget(model_deselect_all)
|
||||
model_button_layout.addStretch()
|
||||
|
||||
model_layout.addLayout(model_grid)
|
||||
model_layout.addLayout(model_button_layout)
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# 数据划分方法 - 多选
|
||||
split_group = QGroupBox("数据划分方法 (可多选)")
|
||||
split_layout = QVBoxLayout()
|
||||
|
||||
split_grid = QGridLayout()
|
||||
self.split_checkboxes = {}
|
||||
split_methods = ['spxy', 'ks', 'random']
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
split_button_layout = QHBoxLayout()
|
||||
split_select_all = QPushButton("全选")
|
||||
split_deselect_all = QPushButton("全不选")
|
||||
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
|
||||
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
|
||||
split_button_layout.addWidget(split_select_all)
|
||||
split_button_layout.addWidget(split_deselect_all)
|
||||
split_button_layout.addStretch()
|
||||
|
||||
split_layout.addLayout(split_grid)
|
||||
split_layout.addLayout(split_button_layout)
|
||||
split_group.setLayout(split_layout)
|
||||
layout.addWidget(split_group)
|
||||
|
||||
self.ml_page.setLayout(layout)
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_path(self):
|
||||
"""浏览输出文件路径(保存对话框)"""
|
||||
current = self.output_path.get_path().strip()
|
||||
if current:
|
||||
initial_dir = os.path.dirname(current)
|
||||
initial_file = os.path.basename(current)
|
||||
else:
|
||||
initial_dir = ""
|
||||
initial_file = ""
|
||||
|
||||
if not initial_dir or not os.path.isdir(initial_dir):
|
||||
# 默认定位到 indices 目录
|
||||
work_dir = self._get_default_work_dir()
|
||||
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
|
||||
if initial_dir and not os.path.isdir(initial_dir):
|
||||
os.makedirs(initial_dir, exist_ok=True)
|
||||
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
self.output_path.set_path(file_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
model_names = [
|
||||
model for model, checkbox in self.model_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
split_methods = [
|
||||
method for method, checkbox in self.split_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
|
||||
config = {
|
||||
'interval': self.interval.value(),
|
||||
'sample_radius': self.sample_radius.value(),
|
||||
'chunk_size': self.chunk_size.value(),
|
||||
'use_adaptive_sampling': self.use_adaptive_sampling.isChecked(),
|
||||
'feature_start_column': self.feature_start.text(),
|
||||
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
|
||||
'model_names': model_names if model_names else ['SVR'],
|
||||
'split_methods': split_methods if split_methods else ['random'],
|
||||
'cv_folds': self.cv_folds.value()
|
||||
}
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['training_csv_path'] = training_csv_path
|
||||
output_path = self.output_path.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'interval' in config:
|
||||
self.interval.setValue(config['interval'])
|
||||
if 'sample_radius' in config:
|
||||
self.sample_radius.setValue(config['sample_radius'])
|
||||
if 'chunk_size' in config:
|
||||
self.chunk_size.setValue(config['chunk_size'])
|
||||
if 'use_adaptive_sampling' in config:
|
||||
self.use_adaptive_sampling.setChecked(config['use_adaptive_sampling'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
if 'feature_start_column' in config:
|
||||
self.feature_start.setText(str(config['feature_start_column']))
|
||||
if 'cv_folds' in config:
|
||||
self.cv_folds.setValue(config['cv_folds'])
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'model_names' in config:
|
||||
models = config['model_names']
|
||||
for model, checkbox in self.model_checkboxes.items():
|
||||
checkbox.setChecked(model in models)
|
||||
if 'split_methods' in config:
|
||||
methods = config['split_methods']
|
||||
for method, checkbox in self.split_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'training_csv_path' in config:
|
||||
self.training_csv_file.set_path(config['training_csv_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_path.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充去耀斑影像和掩膜路径
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
@ -144,81 +349,53 @@ class Step7Panel(QWidget):
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step5_panel'):
|
||||
# 优先直接从 Step5 的输出 widget 读取
|
||||
step5_output = main_window.step5_panel.output_file.get_path()
|
||||
if step5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output):
|
||||
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_output)
|
||||
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
|
||||
# 回退:从 Step5 的 config 字典中查找可能的键名
|
||||
step5_cfg = main_window.step5_panel.get_config()
|
||||
step5_csv = (
|
||||
step5_cfg.get('training_csv_path')
|
||||
or step5_cfg.get('output_file')
|
||||
or step5_cfg.get('csv_path')
|
||||
or step5_cfg.get('output_csv')
|
||||
)
|
||||
if step5_csv:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_csv):
|
||||
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_csv)
|
||||
|
||||
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
|
||||
deglint_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
|
||||
deglint_path = (
|
||||
step3_outputs.get('deglint_image')
|
||||
or step3_outputs.get('output_path')
|
||||
or step3_outputs.get('output_file')
|
||||
or step3_outputs.get('deglint_img_path')
|
||||
)
|
||||
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
|
||||
if not deglint_path and hasattr(main_window, 'step3_panel'):
|
||||
deglint_path = main_window.step3_panel.output_file.get_path()
|
||||
|
||||
if deglint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(deglint_path):
|
||||
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
|
||||
self.deglint_img_file.set_path(deglint_path)
|
||||
|
||||
# 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
|
||||
water_mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
|
||||
water_mask_path = (
|
||||
step1_outputs.get('water_mask')
|
||||
or step1_outputs.get('output_path')
|
||||
or step1_outputs.get('output_file')
|
||||
)
|
||||
# 回退:从 step1 面板 widget 直接读取
|
||||
if not water_mask_path and hasattr(main_window, 'step1_panel'):
|
||||
water_mask_path = main_window.step1_panel.output_file.get_path()
|
||||
# 备选:扫描 1_water_mask 目录下的 .dat 文件
|
||||
if not water_mask_path and self.work_dir:
|
||||
mask_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
if os.path.isdir(mask_dir):
|
||||
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
|
||||
if dat_files:
|
||||
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
|
||||
# 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
|
||||
if not water_mask_path and self.work_dir:
|
||||
input_test_dir = os.path.join(self.work_dir, "input-test")
|
||||
if os.path.isdir(input_test_dir):
|
||||
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
|
||||
# 优先匹配 water_mask_from_shp.dat
|
||||
for f in dat_files:
|
||||
if 'water_mask_from_shp' in f.lower():
|
||||
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
|
||||
break
|
||||
# 否则取第一个 .dat 文件
|
||||
if not water_mask_path and dat_files:
|
||||
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
|
||||
|
||||
if water_mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(water_mask_path):
|
||||
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(water_mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(绝对路径)
|
||||
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
|
||||
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
|
||||
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
|
||||
if self.work_dir:
|
||||
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
self.output_file.set_path(output_path.replace('\\', '/'))
|
||||
|
||||
# 4. 同步更新预览按钮状态(路径可能已自动填充)
|
||||
self._check_csv_exists()
|
||||
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
|
||||
os.makedirs(indices_dir, exist_ok=True)
|
||||
training_csv = self.training_csv_file.get_path()
|
||||
if training_csv:
|
||||
basename = os.path.splitext(os.path.basename(training_csv))[0]
|
||||
output_file = f"{basename}_indices.csv"
|
||||
else:
|
||||
output_file = "water_quality_indices.csv"
|
||||
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
|
||||
self.output_path.set_path(output_path)
|
||||
else:
|
||||
self.output_path.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤7"""
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
@ -226,27 +403,13 @@ class Step7Panel(QWidget):
|
||||
config = {'step7': self.get_config()}
|
||||
main_window.run_single_step('step7', config)
|
||||
|
||||
def _check_csv_exists(self):
|
||||
"""检查 output csv 是否存在,驱动预览按钮启停"""
|
||||
csv_path = self.output_file.get_path()
|
||||
enabled = bool(csv_path and os.path.isabs(csv_path) and os.path.exists(csv_path))
|
||||
self.preview_btn.setEnabled(enabled)
|
||||
return enabled
|
||||
|
||||
def _on_output_changed(self, _text=None):
|
||||
"""输出路径输入框内容变化时调用(_text 为 line_edit.textChanged 信号参数)"""
|
||||
self._check_csv_exists()
|
||||
|
||||
def _open_sampling_viewer(self):
|
||||
"""打开交互式采样点查看器弹窗"""
|
||||
csv_path = self.output_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
QMessageBox.warning(
|
||||
self, "文件不存在",
|
||||
f"采样点 CSV 文件不存在:{csv_path}\n请先运行步骤7生成数据。"
|
||||
)
|
||||
return
|
||||
dialog = SamplingViewerDialog(csv_path, self)
|
||||
dialog.exec_()
|
||||
# 弹窗关闭后再次检查状态(可能文件被覆盖等)
|
||||
self._check_csv_exists()
|
||||
def get_training_params(self):
|
||||
"""获取模型训练参数"""
|
||||
return {
|
||||
'pipeline_type': 'machine_learning',
|
||||
'feature_start': float(self.feature_start.text()),
|
||||
'cv_folds': self.cv_folds.value(),
|
||||
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
|
||||
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
|
||||
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_5 面板 - 非经验统计回归建模
|
||||
Step8 面板 - 非经验统计回归建模
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -17,8 +17,8 @@ from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_5Panel(QWidget):
|
||||
"""步骤6.5:非经验统计回归建模"""
|
||||
class Step8NonEmpiricalPanel(QWidget):
|
||||
"""步骤8:非经验统计回归建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
@ -280,7 +280,7 @@ class Step6_5Panel(QWidget):
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.5"""
|
||||
"""独立运行步骤8"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
@ -297,7 +297,7 @@ class Step6_5Panel(QWidget):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_5', {'step6_5': config})
|
||||
parent.run_single_step('step8_non_empirical_modeling', {'step8_non_empirical_modeling': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
|
||||
@ -1,462 +1,225 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 面板 - 机器学习预测
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog, QRadioButton, QListWidget, QAbstractItemView, QHBoxLayout,
|
||||
QListWidgetItem,
|
||||
QWidget, QVBoxLayout, QGroupBox, QGridLayout,
|
||||
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
def get_resource_path(relative_path: str) -> str:
|
||||
"""适配开发与 PyInstaller 环境的路径获取逻辑。
|
||||
支持两种打包模式:
|
||||
1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录
|
||||
2. --onefile 模式:文件在 sys._MEIPASS 平铺目录
|
||||
"""
|
||||
# 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下)
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
return os.path.join(sys._MEIPASS, relative_path)
|
||||
|
||||
# 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下)
|
||||
exe_dir = os.path.dirname(sys.executable)
|
||||
internal_path = os.path.join(exe_dir, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
|
||||
# 开发环境下:基于当前文件 (step8_panel.py) 的绝对路径进行回溯
|
||||
# 当前在 src/gui/panels/,目标在 src/gui/model/
|
||||
base_dir = Path(__file__).resolve().parent.parent / "model"
|
||||
target_path = base_dir / os.path.basename(relative_path)
|
||||
return str(target_path)
|
||||
|
||||
class Step8Panel(QWidget):
|
||||
"""步骤8:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.external_models_dict = {} # {subdir_name: model_obj, ...}
|
||||
self.external_model_dir = "" # 母文件夹路径(隐藏)
|
||||
self.index_checkboxes: Dict[str, QCheckBox] = {}
|
||||
# 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中
|
||||
self.builtin_formula_path = get_resource_path("waterindex.csv")
|
||||
|
||||
self.init_ui()
|
||||
# 延迟一小会儿加载,确保UI框架已就绪
|
||||
self._auto_load_formulas()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.setContentsMargins(20, 20, 20, 20)
|
||||
main_layout.setSpacing(10)
|
||||
|
||||
# -------- 模型来源选择(单选按钮组) --------
|
||||
source_group = QGroupBox("模型来源")
|
||||
source_layout = QVBoxLayout()
|
||||
# 1. 路径展示区 (半透明只读)
|
||||
path_group = QGroupBox("公式配置源 (内置)")
|
||||
path_layout = QVBoxLayout()
|
||||
self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
|
||||
self.formula_csv_widget.set_path(self.builtin_formula_path)
|
||||
self.formula_csv_widget.set_read_only(True)
|
||||
# 视觉微调:提示用户这是内置的
|
||||
self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
|
||||
path_layout.addWidget(self.formula_csv_widget)
|
||||
path_group.setLayout(path_layout)
|
||||
main_layout.addWidget(path_group)
|
||||
|
||||
self.use_trained_model = QRadioButton("使用当前训练流程的模型")
|
||||
self.use_external_model = QRadioButton("导入本地预训练模型 (.joblib)")
|
||||
self.use_trained_model.setChecked(True)
|
||||
source_layout.addWidget(self.use_trained_model)
|
||||
source_layout.addWidget(self.use_external_model)
|
||||
# 2. 训练数据输入
|
||||
input_group = QGroupBox("输入样本数据")
|
||||
input_layout = QVBoxLayout()
|
||||
self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
|
||||
input_layout.addWidget(self.training_data_widget)
|
||||
input_group.setLayout(input_layout)
|
||||
main_layout.addWidget(input_group)
|
||||
|
||||
self.use_trained_model.toggled.connect(self._on_model_source_changed)
|
||||
self.use_external_model.toggled.connect(self._on_model_source_changed)
|
||||
# 3. 公式选择区
|
||||
self.formula_group = QGroupBox("待计算水质指数勾选")
|
||||
formula_outer_layout = QVBoxLayout()
|
||||
|
||||
source_group.setStyleSheet("""
|
||||
QRadioButton {
|
||||
font-size: 13px;
|
||||
spacing: 8px;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 9px;
|
||||
border: 2px solid #A0A0A0;
|
||||
background-color: #FFFFFF;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
""")
|
||||
btn_layout = QHBoxLayout()
|
||||
self.select_all_btn = QPushButton("全选")
|
||||
self.deselect_all_btn = QPushButton("清空")
|
||||
self.select_all_btn.clicked.connect(self.select_all_formulas)
|
||||
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
|
||||
btn_layout.addWidget(self.select_all_btn)
|
||||
btn_layout.addWidget(self.deselect_all_btn)
|
||||
btn_layout.addStretch()
|
||||
|
||||
source_group.setLayout(source_layout)
|
||||
layout.addWidget(source_group)
|
||||
self.refresh_button = QPushButton("手动重新加载公式")
|
||||
self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
|
||||
btn_layout.addWidget(self.refresh_button)
|
||||
|
||||
# -------- 外部模型文件选择(条件显示) --------
|
||||
self.external_model_widget = FileSelectWidget(
|
||||
"模型母文件夹:",
|
||||
"Directories"
|
||||
)
|
||||
self.external_model_widget.browse_btn.clicked.disconnect()
|
||||
self.external_model_widget.browse_btn.clicked.connect(self._scan_external_model_dir)
|
||||
self.external_model_widget.setVisible(False)
|
||||
layout.addWidget(self.external_model_widget)
|
||||
formula_outer_layout.addLayout(btn_layout)
|
||||
|
||||
# -------- 已扫描模型列表(条件显示) --------
|
||||
self.model_list_group = QGroupBox("选择参与预测的模型")
|
||||
self.model_list_group.setVisible(False)
|
||||
model_list_layout = QVBoxLayout()
|
||||
# 核心滚动区
|
||||
scroll = QScrollArea()
|
||||
scroll.setWidgetResizable(True)
|
||||
scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
|
||||
self.scroll_content = QWidget()
|
||||
self.formula_layout = QGridLayout(self.scroll_content)
|
||||
self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
|
||||
scroll.setWidget(self.scroll_content)
|
||||
formula_outer_layout.addWidget(scroll)
|
||||
|
||||
self.model_list = QListWidget()
|
||||
self.model_list.setMaximumHeight(130)
|
||||
self.model_list.setSelectionMode(QAbstractItemView.NoSelection)
|
||||
self.model_list.setStyleSheet("""
|
||||
QListWidget {
|
||||
border: 1px solid #C0C0C0;
|
||||
border-radius: 4px;
|
||||
background-color: #FFFFFF;
|
||||
font-size: 12px;
|
||||
}
|
||||
QListWidget::item {
|
||||
padding: 4px 6px;
|
||||
border-bottom: 1px solid #F0F0F0;
|
||||
}
|
||||
QListWidget::item:selected {
|
||||
background-color: transparent;
|
||||
}
|
||||
""")
|
||||
model_list_layout.addWidget(self.model_list)
|
||||
self.formula_group.setLayout(formula_outer_layout)
|
||||
main_layout.addWidget(self.formula_group)
|
||||
|
||||
btn_row = QHBoxLayout()
|
||||
self.btn_select_all = QPushButton("全选")
|
||||
self.btn_select_all.setMaximumWidth(80)
|
||||
self.btn_select_all.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
|
||||
self.btn_select_all.clicked.connect(self._select_all_models)
|
||||
# 4. 输出与运行
|
||||
output_group = QGroupBox("结果输出")
|
||||
output_layout = QVBoxLayout()
|
||||
self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
|
||||
output_layout.addWidget(self.output_file_widget)
|
||||
output_group.setLayout(output_layout)
|
||||
main_layout.addWidget(output_group)
|
||||
|
||||
self.btn_select_none = QPushButton("全不选")
|
||||
self.btn_select_none.setMaximumWidth(80)
|
||||
self.btn_select_none.setStyleSheet(ModernStylesheet.get_button_stylesheet('default'))
|
||||
self.btn_select_none.clicked.connect(self._select_none_models)
|
||||
|
||||
btn_row.addWidget(self.btn_select_all)
|
||||
btn_row.addWidget(self.btn_select_none)
|
||||
btn_row.addStretch()
|
||||
model_list_layout.addLayout(btn_row)
|
||||
|
||||
self.model_list_group.setLayout(model_list_layout)
|
||||
layout.addWidget(self.model_list_group)
|
||||
|
||||
# -------- 采样光谱CSV文件(用于独立运行)--------
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录(用于独立运行)
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出路径:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox = QCheckBox("启用计算流程")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
main_layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
self.run_button = QPushButton("立即执行计算")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.setMinimumHeight(40)
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
main_layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
self.setLayout(main_layout)
|
||||
|
||||
def _on_model_source_changed(self, checked: bool):
|
||||
"""单选按钮切换:控制外部模型文件选择控件的显示/隐藏"""
|
||||
if not checked:
|
||||
def _auto_load_formulas(self):
|
||||
"""启动时自动加载逻辑"""
|
||||
if os.path.exists(self.builtin_formula_path):
|
||||
self.refresh_formulas(silent=True)
|
||||
else:
|
||||
print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
|
||||
|
||||
def refresh_formulas(self, silent=False):
|
||||
path = self.builtin_formula_path
|
||||
if not os.path.exists(path):
|
||||
if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
|
||||
return
|
||||
is_external = self.use_external_model.isChecked()
|
||||
self.external_model_widget.setVisible(is_external)
|
||||
self.model_list_group.setVisible(is_external)
|
||||
if not is_external:
|
||||
self.external_models_dict = {}
|
||||
self.external_model_dir = ""
|
||||
self._clear_model_list()
|
||||
|
||||
def _scan_external_model_dir(self):
|
||||
"""浏览模型母文件夹,自动扫描子目录中的 .joblib 文件"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self,
|
||||
"选择模型母文件夹",
|
||||
default,
|
||||
)
|
||||
if not dir_path:
|
||||
return
|
||||
|
||||
self.external_model_dir = dir_path
|
||||
models_found = {}
|
||||
errors = []
|
||||
|
||||
try:
|
||||
import joblib
|
||||
# 清理旧列表
|
||||
for i in reversed(range(self.formula_layout.count())):
|
||||
widget = self.formula_layout.itemAt(i).widget()
|
||||
if widget: widget.deleteLater()
|
||||
self.index_checkboxes.clear()
|
||||
|
||||
for subentry in os.scandir(dir_path):
|
||||
if not subentry.is_dir():
|
||||
continue
|
||||
subdir_name = subentry.name
|
||||
joblib_files = [
|
||||
f for f in os.scandir(subentry.path)
|
||||
if f.is_file() and f.name.lower().endswith(".joblib")
|
||||
]
|
||||
if not joblib_files:
|
||||
continue
|
||||
# 每个子目录只取第一个 .joblib 文件(与 batch 逻辑一致)
|
||||
joblib_path = joblib_files[0].path
|
||||
# 鲁棒性读取:尝试不同编码
|
||||
for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
|
||||
try:
|
||||
loaded = joblib.load(joblib_path)
|
||||
if isinstance(loaded, dict) and "model" in loaded:
|
||||
model_obj = loaded["model"]
|
||||
elif hasattr(loaded, "predict"):
|
||||
model_obj = loaded
|
||||
else:
|
||||
errors.append(f"{subdir_name}: 无法识别的格式 {type(loaded).__name__}")
|
||||
continue
|
||||
models_found[subdir_name] = model_obj
|
||||
except Exception as e:
|
||||
errors.append(f"{subdir_name}: {type(e).__name__}: {e}")
|
||||
df = pd.read_csv(path, encoding=encoding)
|
||||
if 'Formula_Name' in df.columns: break
|
||||
except: continue
|
||||
|
||||
if 'Formula_Name' not in df.columns:
|
||||
if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列")
|
||||
return
|
||||
|
||||
names = df['Formula_Name'].dropna().unique().tolist()
|
||||
|
||||
row, col = 0, 0
|
||||
for name in names:
|
||||
name = str(name).strip()
|
||||
if not name: continue
|
||||
cb = QCheckBox(name)
|
||||
cb.setChecked(True)
|
||||
self.index_checkboxes[name] = cb
|
||||
self.formula_layout.addWidget(cb, row, col)
|
||||
col += 1
|
||||
if col >= 3:
|
||||
col = 0
|
||||
row += 1
|
||||
|
||||
# 强制UI更新
|
||||
self.scroll_content.adjustSize()
|
||||
print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"扫描失败",
|
||||
f"遍历模型目录时发生错误:\n{type(e).__name__}: {e}",
|
||||
)
|
||||
return
|
||||
if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
|
||||
|
||||
if not models_found:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未找到模型",
|
||||
f"在「{dir_path}」的子目录中未发现任何 .joblib 文件。\n"
|
||||
"请确认每个水质参数对应一个子文件夹,内含 .joblib 模型文件。",
|
||||
)
|
||||
self.external_model_widget.set_path("")
|
||||
self.external_models_dict = {}
|
||||
self._clear_model_list()
|
||||
return
|
||||
def select_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(True)
|
||||
|
||||
self.external_models_dict = models_found
|
||||
self._populate_model_list(models_found)
|
||||
names = sorted(models_found.keys())
|
||||
display = f"已识别到 {len(names)} 个模型: {', '.join(names)}"
|
||||
self.external_model_widget.set_path(display)
|
||||
self.external_model_widget.line_edit.setStyleSheet("color: #0078D7; font-weight: bold;")
|
||||
|
||||
err_lines = "\n".join(errors) if errors else "无"
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"模型扫描完成",
|
||||
f"成功加载 {len(models_found)} 个模型:\n{display}\n\n"
|
||||
f"加载失败 {len(errors)} 个:\n{err_lines}",
|
||||
)
|
||||
|
||||
def _populate_model_list(self, models_dict):
|
||||
"""将扫描到的模型填充到 QListWidget,每个条目可勾选,默认全选"""
|
||||
self.model_list.clear()
|
||||
for name in sorted(models_dict.keys()):
|
||||
item = QListWidgetItem(name)
|
||||
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
|
||||
item.setCheckState(Qt.Checked)
|
||||
self.model_list.addItem(item)
|
||||
|
||||
def _clear_model_list(self):
|
||||
"""清空模型列表"""
|
||||
self.model_list.clear()
|
||||
|
||||
def _select_all_models(self):
|
||||
"""全选:设置所有条目为 Checked"""
|
||||
for i in range(self.model_list.count()):
|
||||
self.model_list.item(i).setCheckState(Qt.Checked)
|
||||
|
||||
def _select_none_models(self):
|
||||
"""全不选:设置所有条目为 Unchecked"""
|
||||
for i in range(self.model_list.count()):
|
||||
self.model_list.item(i).setCheckState(Qt.Unchecked)
|
||||
|
||||
def _get_checked_models_dict(self):
|
||||
"""从列表中提取用户勾选的模型,组装成字典返回"""
|
||||
result = {}
|
||||
for i in range(self.model_list.count()):
|
||||
item = self.model_list.item(i)
|
||||
if item.checkState() == Qt.Checked:
|
||||
name = item.text()
|
||||
if name in self.external_models_dict:
|
||||
result[name] = self.external_models_dict[name]
|
||||
return result
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6 界面读取监督模型目录
|
||||
if main_window and hasattr(main_window, 'step6_panel'):
|
||||
step6_widget = getattr(main_window.step6_panel, 'output_dir', None)
|
||||
step6_models_dir = ""
|
||||
if hasattr(step6_widget, 'get_path'):
|
||||
step6_models_dir = step6_widget.get_path() or ""
|
||||
elif hasattr(step6_widget, 'text'):
|
||||
step6_models_dir = step6_widget.text() or ""
|
||||
|
||||
if step6_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_models_dir):
|
||||
step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(机器学习预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
def deselect_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(False)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
|
||||
return {
|
||||
'training_csv_path': self.training_data_widget.get_path(),
|
||||
'formula_csv_file': self.builtin_formula_path,
|
||||
'formula_names': selected,
|
||||
'output_file': self.output_file_widget.get_path(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'training_csv_path' in config: self.training_data_widget.set_path(config['training_csv_path'])
|
||||
if 'formula_names' in config:
|
||||
sel = set(config['formula_names'])
|
||||
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
|
||||
if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
|
||||
self.enable_checkbox.setChecked(config.get('enabled', True))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
if work_dir: self.work_dir = work_dir
|
||||
main = self.window()
|
||||
if hasattr(main, 'step5_panel'):
|
||||
p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐
|
||||
if p5:
|
||||
if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
|
||||
self.training_data_widget.set_path(p5)
|
||||
|
||||
if self.work_dir:
|
||||
out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/')
|
||||
self.output_file_widget.set_path(out)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
config = self.get_config()
|
||||
if not config['training_csv_path']:
|
||||
QMessageBox.warning(self, "提示", "请先选择输入数据")
|
||||
return
|
||||
|
||||
# 外部模型优先:用户选择了"导入本地预训练模型"
|
||||
if self.use_external_model.isChecked():
|
||||
if not self.external_models_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"模型未加载",
|
||||
"请先点击「浏览...」按钮选择模型母文件夹!",
|
||||
)
|
||||
return
|
||||
# 只传递用户勾选的模型
|
||||
checked_dict = self._get_checked_models_dict()
|
||||
if not checked_dict:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"未选择模型",
|
||||
"请至少勾选一个模型参与预测!",
|
||||
)
|
||||
return
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {
|
||||
'step8': self.get_config(),
|
||||
'_external_models_dict': checked_dict,
|
||||
'_external_model_dir': self.external_model_dir,
|
||||
}
|
||||
main_window.run_single_step('step8', config)
|
||||
return
|
||||
|
||||
# 默认流程:使用模型目录
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step8': self.get_config()}
|
||||
main_window.run_single_step('step8', config)
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
|
||||
if parent: parent.run_single_step('step8', {'step8': config})
|
||||
@ -1,206 +1,158 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step9 面板 - 分布图生成
|
||||
Step9 面板 - 自定义回归分析
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Dict
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
|
||||
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
|
||||
QScrollArea, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIPELINE_AVAILABLE = False
|
||||
|
||||
|
||||
class Step9BatchThread(QThread):
|
||||
"""专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
|
||||
|
||||
finished_ok = pyqtSignal(int)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.csv_paths = csv_paths
|
||||
self.step9_kwargs = step9_kwargs
|
||||
self.output_dir_optional = (output_dir_optional or "").strip() or None
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
n = len(self.csv_paths)
|
||||
for i, csv_p in enumerate(self.csv_paths):
|
||||
self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
|
||||
kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
|
||||
if self.output_dir_optional:
|
||||
stem = Path(csv_p).stem
|
||||
kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
|
||||
else:
|
||||
kw["output_image_path"] = None
|
||||
pipeline.step9_generate_distribution_map(**kw)
|
||||
self.finished_ok.emit(n)
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Step9Panel(QWidget):
|
||||
"""步骤9:分布图生成"""
|
||||
"""步骤9:自定义回归分析"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._batch_thread = None
|
||||
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.method_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.csv_columns = []
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel(
|
||||
"独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
|
||||
"完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
|
||||
)
|
||||
hint.setWordWrap(True)
|
||||
hint.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
|
||||
hint.setStyleSheet("color: #666; font-size: 11px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
mode_row = QHBoxLayout()
|
||||
self.mode_single_rb = QRadioButton("单个 CSV 文件")
|
||||
self.mode_folder_rb = QRadioButton("文件夹批量")
|
||||
self._mode_group = QButtonGroup(self)
|
||||
self._mode_group.addButton(self.mode_single_rb, 0)
|
||||
self._mode_group.addButton(self.mode_folder_rb, 1)
|
||||
mode_row.addWidget(self.mode_single_rb)
|
||||
mode_row.addWidget(self.mode_folder_rb)
|
||||
mode_row.addStretch()
|
||||
layout.addLayout(mode_row)
|
||||
# CSV文件选择
|
||||
csv_group = QGroupBox("数据文件")
|
||||
csv_layout = QVBoxLayout()
|
||||
|
||||
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
|
||||
radio_style = """
|
||||
QRadioButton {
|
||||
font-size: 14px;
|
||||
spacing: 8px;
|
||||
color: #333333;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border: 2px solid #999999;
|
||||
border-radius: 3px;
|
||||
background-color: white;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
border: 2px solid #0078d4;
|
||||
background-color: #0078d4;
|
||||
image: none;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #005a9e;
|
||||
}
|
||||
"""
|
||||
self.mode_single_rb.setStyleSheet(radio_style)
|
||||
self.mode_folder_rb.setStyleSheet(radio_style)
|
||||
|
||||
self.prediction_csv_file = FileSelectWidget(
|
||||
"预测结果CSV:",
|
||||
self.csv_file = FileSelectWidget(
|
||||
"输入CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.prediction_csv_file)
|
||||
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
|
||||
csv_layout.addWidget(self.csv_file)
|
||||
|
||||
folder_row = QHBoxLayout()
|
||||
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
|
||||
self.prediction_csv_dir_label.setMinimumWidth(120)
|
||||
self.prediction_csv_dir_edit = QLineEdit()
|
||||
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
|
||||
pred_dir_btn = QPushButton("浏览…")
|
||||
pred_dir_btn.setMaximumWidth(80)
|
||||
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
|
||||
folder_row.addWidget(self.prediction_csv_dir_label)
|
||||
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
|
||||
folder_row.addWidget(pred_dir_btn)
|
||||
self._folder_row_widget = QWidget()
|
||||
self._folder_row_widget.setLayout(folder_row)
|
||||
layout.addWidget(self._folder_row_widget)
|
||||
self.refresh_btn = QPushButton("刷新列信息")
|
||||
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
|
||||
csv_layout.addWidget(self.refresh_btn)
|
||||
|
||||
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
|
||||
layout.addWidget(self.recursive_csv_cb)
|
||||
csv_group.setLayout(csv_layout)
|
||||
layout.addWidget(csv_group)
|
||||
|
||||
self.boundary_file = FileSelectWidget(
|
||||
"边界文件:",
|
||||
"Shapefiles (*.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.boundary_file)
|
||||
# 自变量选择
|
||||
x_group = QGroupBox("自变量列选择 (可多选)")
|
||||
x_layout = QVBoxLayout()
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("生成参数")
|
||||
params_layout = QFormLayout()
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
self.resolution = QDoubleSpinBox()
|
||||
self.resolution.setRange(1, 1000)
|
||||
self.resolution.setValue(30)
|
||||
params_layout.addRow("分辨率(米):", self.resolution)
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
x_widget.setLayout(self.x_columns_layout)
|
||||
|
||||
self.input_crs = QLineEdit()
|
||||
self.input_crs.setText("EPSG:32651")
|
||||
params_layout.addRow("输入坐标系:", self.input_crs)
|
||||
x_scroll.setWidget(x_widget)
|
||||
x_layout.addWidget(x_scroll)
|
||||
|
||||
self.output_crs = QLineEdit()
|
||||
self.output_crs.setText("EPSG:4326")
|
||||
params_layout.addRow("输出坐标系:", self.output_crs)
|
||||
x_btn_layout = QHBoxLayout()
|
||||
self.x_select_all = QPushButton("全选")
|
||||
self.x_deselect_all = QPushButton("全不选")
|
||||
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
|
||||
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
|
||||
x_btn_layout.addWidget(self.x_select_all)
|
||||
x_btn_layout.addWidget(self.x_deselect_all)
|
||||
x_btn_layout.addStretch()
|
||||
x_layout.addLayout(x_btn_layout)
|
||||
|
||||
self.show_points = QCheckBox("显示采样点")
|
||||
params_layout.addRow("", self.show_points)
|
||||
x_group.setLayout(x_layout)
|
||||
layout.addWidget(x_group)
|
||||
|
||||
self.use_diffusion = QCheckBox("启用距离扩散")
|
||||
self.use_diffusion.setChecked(True)
|
||||
params_layout.addRow("", self.use_diffusion)
|
||||
# 因变量选择
|
||||
y_group = QGroupBox("因变量列选择 (可多选)")
|
||||
y_layout = QVBoxLayout()
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
y_widget.setLayout(self.y_columns_layout)
|
||||
|
||||
y_scroll.setWidget(y_widget)
|
||||
y_layout.addWidget(y_scroll)
|
||||
|
||||
y_btn_layout = QHBoxLayout()
|
||||
self.y_select_all = QPushButton("全选")
|
||||
self.y_deselect_all = QPushButton("全不选")
|
||||
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
|
||||
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
|
||||
y_btn_layout.addWidget(self.y_select_all)
|
||||
y_btn_layout.addWidget(self.y_deselect_all)
|
||||
y_btn_layout.addStretch()
|
||||
y_layout.addLayout(y_btn_layout)
|
||||
|
||||
y_group.setLayout(y_layout)
|
||||
layout.addWidget(y_group)
|
||||
|
||||
# 回归方法选择
|
||||
method_group = QGroupBox("回归方法选择 (可多选)")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
method_grid = QGridLayout()
|
||||
regression_methods = [
|
||||
'linear', 'exponential', 'power', 'logarithmic',
|
||||
'polynomial', 'hyperbolic', 'sigmoidal'
|
||||
]
|
||||
|
||||
for i, method in enumerate(regression_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
||||
checkbox.setChecked(True)
|
||||
self.method_checkboxes[method] = checkbox
|
||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
method_layout.addLayout(method_grid)
|
||||
|
||||
method_btn_layout = QHBoxLayout()
|
||||
self.method_select_all = QPushButton("全选")
|
||||
self.method_deselect_all = QPushButton("全不选")
|
||||
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
|
||||
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
|
||||
method_btn_layout.addWidget(self.method_select_all)
|
||||
method_btn_layout.addWidget(self.method_deselect_all)
|
||||
method_btn_layout.addStretch()
|
||||
method_layout.addLayout(method_btn_layout)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 输出目录
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出分布图目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
output_group = QGroupBox("输出设置")
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
layout.addWidget(output_group)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
@ -216,119 +168,120 @@ class Step9Panel(QWidget):
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 信号绑定与初始状态
|
||||
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
|
||||
self._toggle_input_mode() # 根据默认值设置初始显示状态
|
||||
def toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def _toggle_input_mode(self):
|
||||
"""槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
# 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
|
||||
self.prediction_csv_file.setVisible(not folder_mode)
|
||||
# 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
|
||||
self._folder_row_widget.setVisible(folder_mode)
|
||||
self.recursive_csv_cb.setVisible(folder_mode)
|
||||
def on_csv_file_changed(self):
|
||||
"""CSV文件改变时自动刷新列信息"""
|
||||
self.refresh_csv_columns()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
def refresh_csv_columns(self):
|
||||
"""刷新CSV文件的列信息"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
return
|
||||
|
||||
def browse_prediction_csv_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
|
||||
if d:
|
||||
self.prediction_csv_dir_edit.setText(d)
|
||||
try:
|
||||
df = pd.read_csv(csv_path, nrows=0)
|
||||
self.csv_columns = list(df.columns)
|
||||
self.update_column_widgets()
|
||||
except Exception as e:
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
print(f"读取CSV列信息失败: {e}")
|
||||
|
||||
def _collect_csv_paths_from_folder(self) -> List[str]:
|
||||
folder = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not folder or not os.path.isdir(folder):
|
||||
return []
|
||||
root = Path(folder)
|
||||
if self.recursive_csv_cb.isChecked():
|
||||
files = sorted(root.rglob("*.csv"))
|
||||
else:
|
||||
files = sorted(root.glob("*.csv"))
|
||||
return [str(p) for p in files if p.is_file()]
|
||||
def update_column_widgets(self):
|
||||
"""更新列选择组件"""
|
||||
for checkbox in self.x_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.x_column_checkboxes.clear()
|
||||
|
||||
def _step9_base_pipeline_kwargs(self) -> dict:
|
||||
return {
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
for checkbox in self.y_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.y_column_checkboxes.clear()
|
||||
|
||||
if not self.csv_columns:
|
||||
return
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
||||
checkbox.setChecked(True)
|
||||
self.x_column_checkboxes[col] = checkbox
|
||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
||||
checkbox.setChecked(True)
|
||||
self.y_column_checkboxes[col] = checkbox
|
||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||
|
||||
self.x_columns_layout.update()
|
||||
self.y_columns_layout.update()
|
||||
|
||||
def get_config(self):
|
||||
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
config = {
|
||||
'step9_batch_mode': 'folder' if folder_mode else 'single',
|
||||
'prediction_csv_dir': pred_dir if pred_dir else None,
|
||||
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
|
||||
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
selected_methods = 'all'
|
||||
|
||||
return {
|
||||
'csv_path': self.csv_file.get_path() or None,
|
||||
'x_columns': selected_x_columns,
|
||||
'y_columns': selected_y_columns,
|
||||
'methods': selected_methods,
|
||||
'output_dir': self.output_dir.text().strip() or None,
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
out_dir = (self.output_dir.get_path() or "").strip()
|
||||
if not folder_mode and pred_csv and out_dir:
|
||||
stem = Path(pred_csv).stem
|
||||
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
|
||||
else:
|
||||
config['output_image_path'] = None
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
mode = config.get('step9_batch_mode', 'single')
|
||||
if mode == 'folder':
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
else:
|
||||
self.mode_single_rb.setChecked(True)
|
||||
if config.get('prediction_csv_dir'):
|
||||
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
|
||||
if 'recursive_csv_scan' in config:
|
||||
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
|
||||
if 'prediction_csv_path' in config and config['prediction_csv_path']:
|
||||
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
|
||||
if 'boundary_shp_path' in config:
|
||||
self.boundary_file.set_path(config['boundary_shp_path'])
|
||||
if 'resolution' in config:
|
||||
self.resolution.setValue(config['resolution'])
|
||||
if 'input_crs' in config:
|
||||
self.input_crs.setText(config['input_crs'])
|
||||
if 'output_crs' in config:
|
||||
self.output_crs.setText(config['output_crs'])
|
||||
if 'show_sample_points' in config:
|
||||
self.show_points.setChecked(config['show_sample_points'])
|
||||
if 'use_distance_diffusion' in config:
|
||||
self.use_diffusion.setChecked(config['use_distance_diffusion'])
|
||||
if 'output_dir' in config and config['output_dir']:
|
||||
self.output_dir.set_path(str(config['output_dir']))
|
||||
elif config.get('output_image_path'):
|
||||
p = Path(str(config['output_image_path']))
|
||||
if p.parent and str(p.parent) != '.':
|
||||
self.output_dir.set_path(str(p.parent))
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.refresh_csv_columns()
|
||||
|
||||
if 'x_columns' in config:
|
||||
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
|
||||
for col, checkbox in self.x_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_x)
|
||||
|
||||
if 'y_columns' in config:
|
||||
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
|
||||
for col, checkbox in self.y_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_y)
|
||||
|
||||
if 'methods' in config:
|
||||
methods = config['methods']
|
||||
if isinstance(methods, list):
|
||||
selected_methods = set(methods)
|
||||
elif methods == 'all':
|
||||
selected_methods = set(self.method_checkboxes.keys())
|
||||
else:
|
||||
selected_methods = set()
|
||||
for method, checkbox in self.method_checkboxes.items():
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充预测结果目录
|
||||
|
||||
优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录;
|
||||
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
@ -344,190 +297,78 @@ class Step9Panel(QWidget):
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
|
||||
main_window = self.window()
|
||||
if not main_window:
|
||||
return
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
|
||||
pred_dir = None
|
||||
if hasattr(main_window, 'step8_panel'):
|
||||
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
||||
step8_output = ""
|
||||
if hasattr(step8_widget, 'get_path'):
|
||||
step8_output = step8_widget.get_path() or ""
|
||||
elif hasattr(step8_widget, 'text'):
|
||||
step8_output = step8_widget.text() or ""
|
||||
|
||||
if step8_output:
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_output):
|
||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||
# 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
|
||||
base_pred_dir = str(Path(step8_output).parent)
|
||||
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
|
||||
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
||||
step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
|
||||
step8_5_output = ""
|
||||
if hasattr(step8_5_widget, 'get_path'):
|
||||
step8_5_output = step8_5_widget.get_path() or ""
|
||||
elif hasattr(step8_5_widget, 'text'):
|
||||
step8_5_output = step8_5_widget.text() or ""
|
||||
|
||||
if step8_5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_5_output):
|
||||
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_5_output).parent)
|
||||
|
||||
# 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_75_panel'):
|
||||
step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
|
||||
step8_75_output = ""
|
||||
if hasattr(step8_75_widget, 'get_path'):
|
||||
step8_75_output = step8_75_widget.get_path() or ""
|
||||
elif hasattr(step8_75_widget, 'text'):
|
||||
step8_75_output = step8_75_widget.text() or ""
|
||||
|
||||
if step8_75_output:
|
||||
pred_dir = step8_75_output
|
||||
|
||||
# 自动填入"预测CSV目录"(文件夹批量模式)
|
||||
if pred_dir:
|
||||
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not existing_dir:
|
||||
self.prediction_csv_dir_edit.setText(pred_dir)
|
||||
# 切换到文件夹批量模式
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
|
||||
# 4. 自动填充输出目录(14_visualization)
|
||||
# 2. 自动填充输出目录(9_Custom_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "14_visualization")
|
||||
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
|
||||
# 5. 自动探测原始矢量边界文件(.shp)作为专题图底图
|
||||
# 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式
|
||||
if self.work_dir:
|
||||
possible_shp = None
|
||||
candidates = [
|
||||
Path(self.work_dir).parent / "input-test" / "roi.shp",
|
||||
Path(self.work_dir) / "roi.shp",
|
||||
Path(self.work_dir).parent / "roi.shp",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists() and candidate.suffix.lower() == ".shp":
|
||||
possible_shp = candidate
|
||||
break
|
||||
|
||||
existing_boundary = (self.boundary_file.get_path() or "").strip()
|
||||
if not existing_boundary and possible_shp:
|
||||
self.boundary_file.set_path(str(possible_shp))
|
||||
elif not existing_boundary:
|
||||
# 未找到 .shp 时清空并提示用户手动选择矢量文件
|
||||
self.boundary_file.set_path("")
|
||||
print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
|
||||
existing_out = self.output_dir.text().strip()
|
||||
if not existing_out:
|
||||
self.output_dir.setText(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤9"""
|
||||
if self._batch_thread and self._batch_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
|
||||
csv_path = self.csv_file.get_path()
|
||||
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
|
||||
return
|
||||
|
||||
boundary_shp_path = self.boundary_file.get_path()
|
||||
if not boundary_shp_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_x_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
|
||||
return
|
||||
if not os.path.exists(boundary_shp_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
|
||||
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_y_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
|
||||
return
|
||||
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if not parent or not hasattr(parent, 'run_single_step'):
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step9', {'step9': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
return
|
||||
|
||||
if self.mode_folder_rb.isChecked():
|
||||
csv_list = self._collect_csv_paths_from_folder()
|
||||
if not csv_list:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
|
||||
"可勾选「包含子文件夹」以递归扫描。",
|
||||
)
|
||||
return
|
||||
if not PIPELINE_AVAILABLE:
|
||||
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
|
||||
return
|
||||
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
|
||||
work_dir = str(work_dir)
|
||||
base_kw = self._step9_base_pipeline_kwargs()
|
||||
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
|
||||
self.run_button.setEnabled(False)
|
||||
self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
|
||||
main_win = parent
|
||||
|
||||
def _batch_log(msg, lvl):
|
||||
if hasattr(main_win, "log_message"):
|
||||
main_win.log_message(msg, lvl)
|
||||
|
||||
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
|
||||
self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
|
||||
self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
|
||||
self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
|
||||
self._batch_thread.start()
|
||||
if hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
|
||||
return
|
||||
|
||||
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
|
||||
if not prediction_csv_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
|
||||
)
|
||||
return
|
||||
if not os.path.isfile(prediction_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
parent.run_single_step('step9', {'step9': config})
|
||||
|
||||
def _on_step9_batch_ok(self, n: int):
|
||||
QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
|
||||
|
||||
def _on_step9_batch_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(err, "error")
|
||||
|
||||
@ -1567,12 +1567,12 @@ class VisualizationPanel(QWidget):
|
||||
ml_dir.mkdir(parents=True, exist_ok=True)
|
||||
reg_dir.mkdir(parents=True, exist_ok=True)
|
||||
custom_dir.mkdir(parents=True, exist_ok=True)
|
||||
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
|
||||
self.step8_panel.output_file.set_path(str(ml_dir))
|
||||
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
|
||||
self.step8_5_panel.output_file.set_path(str(reg_dir))
|
||||
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
|
||||
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
|
||||
if hasattr(self, 'step11_ml_panel') and hasattr(self.step11_ml_panel, 'output_file'):
|
||||
self.step11_ml_panel.output_file.set_path(str(ml_dir))
|
||||
if hasattr(self, 'step11_panel') and hasattr(self.step11_panel, 'output_file'):
|
||||
self.step11_panel.output_file.set_path(str(reg_dir))
|
||||
if hasattr(self, 'step12_panel') and hasattr(self.step12_panel, 'output_dir_widget'):
|
||||
self.step12_panel.output_dir_widget.set_path(str(custom_dir))
|
||||
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
|
||||
except Exception as e:
|
||||
print(f"设置预测输出目录失败: {e}")
|
||||
|
||||
Reference in New Issue
Block a user