fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动

This commit is contained in:
2026-04-14 17:41:38 +08:00
parent b0a94ba1e7
commit 9b7bcfadd1
17 changed files with 12470 additions and 3113 deletions

View File

@ -660,7 +660,7 @@ class VisualizationWorkerThread(QThread):
self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
return
if not models_dir or not Path(models_dir).is_dir():
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_models 下的参数子文件夹。")
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
return
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
@ -672,12 +672,111 @@ class VisualizationWorkerThread(QThread):
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
parts = []
# 获取训练数据CSV路径多个图表类型共用
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
# 生成散点图
if self.extra.get("gen_scatter"):
if training_csv.is_file():
models_dir = wp / "7_Supervised_Model_Training"
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=str(training_csv),
models_dir=str(models_dir),
)
count = len(scatter_paths) if scatter_paths else 0
parts.append(f"散点图: {count}")
else:
parts.append("散点图: 跳过(无模型目录)")
else:
parts.append("散点图: 跳过(无训练数据)")
# 生成光谱图
if self.extra.get("gen_spectrum"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
# 推断水质参数列(光谱波段列之前的数值型列)
wl_col = _viz_infer_wavelength_start_column(df)
if isinstance(wl_col, str):
idx = int(df.columns.get_loc(wl_col)) + 1
else:
idx = int(wl_col)
param_cols = []
if idx > 0 and idx < len(df.columns):
param_cols = [
c for c in df.columns[:idx]
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
]
if param_cols:
# plot_spectrum_by_parameter 接受单个参数列,逐个调用
spectrum_paths = []
for param_col in param_cols:
try:
path = viz.plot_spectrum_by_parameter(
csv_path=str(training_csv),
parameter_column=param_col,
wavelength_start_column=wl_col,
n_groups=5,
)
if path:
spectrum_paths.append(path)
except Exception as e:
print(f"生成光谱图失败 ({param_col}): {e}")
count = len(spectrum_paths)
parts.append(f"光谱图: {count}")
else:
parts.append("光谱图: 跳过(无可用参数列)")
else:
parts.append("光谱图: 跳过(无训练数据)")
# 生成统计图
if self.extra.get("gen_boxplots"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
# **只统计水质参数列(数值型),排除波长列**
# 获取水质参数列(数值型且不是波长、不是坐标列)
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
param_cols = [
c for c in df.select_dtypes(include=[np.number]).columns
if not any(exc in c.lower() for exc in exclude_cols)
]
# 排除光谱波长列:找到波长开始位置,只取之前的数值列
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
idx = int(df.columns.get_loc(wl)) + 1
else:
idx = int(wl)
if 0 < idx < len(df.columns):
meta_set = set(df.columns[:idx])
param_cols = [c for c in param_cols if c in meta_set]
if param_cols:
output_dict = viz.plot_statistical_charts(
csv_path=str(training_csv),
parameter_columns=param_cols,
)
# plot_statistical_charts 返回字典,统计值非空
count = len([v for v in output_dict.values() if v]) if output_dict else 0
parts.append(f"统计图: {count}")
else:
parts.append("统计图: 跳过(无可用水质参数列)")
else:
parts.append("统计图: 跳过(无训练数据)")
# 生成掩膜/耀斑预览图
if self.extra.get("gen_mask_glint"):
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
parts.append(f"掩膜/耀斑预览 {len(preview_paths) if preview_paths else 0}")
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
# 生成采样点地图
if self.extra.get("gen_sampling_map"):
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
@ -3103,35 +3202,37 @@ class Step8_75Panel(QWidget):
)
layout.addWidget(self.sampling_csv_file)
# 公式CSV文件选择
# 自定义回归模型目录选择9_Custom_Regression_Modeling
self.regression_models_dir = FileSelectWidget(
"回归模型目录:",
"Directories;;All Files (*.*)"
)
self.regression_models_dir.label.setText("回归模型目录:")
# 修改浏览按钮为选择目录
self.regression_models_dir.browse_btn.clicked.disconnect()
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
self.regression_models_dir.set_path("9_Custom_Regression_Modeling") # 设置默认值
layout.addWidget(self.regression_models_dir)
# 公式CSV文件选择用于查找index_formula
self.formula_csv_file = FileSelectWidget(
"公式CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.formula_csv_file.label.setText("公式CSV文件:")
layout.addWidget(self.formula_csv_file)
# 模型目录选择
self.models_dir_file = FileSelectWidget(
"模型目录:",
# 输出目录选择
self.output_dir_widget = FileSelectWidget(
"输出目录:",
"Directories;;All Files (*.*)"
)
self.models_dir_file.label.setText("模型目录:")
self.output_dir_widget.label.setText("输出目录:")
# 修改浏览按钮为选择目录
self.models_dir_file.browse_btn.clicked.disconnect()
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
layout.addWidget(self.models_dir_file)
# 参数设置
params_group = QGroupBox("预测参数")
params_layout = QFormLayout()
# 预测列名
self.prediction_column = QLineEdit()
self.prediction_column.setText("prediction")
params_layout.addRow("预测列名:", self.prediction_column)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
self.output_dir_widget.browse_btn.clicked.disconnect()
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
layout.addWidget(self.output_dir_widget)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
@ -3162,45 +3263,59 @@ class Step8_75Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
def browse_models_dir(self):
"""浏览模型目录"""
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", "")
def browse_regression_models_dir(self):
"""浏览回归模型目录"""
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", "")
if dir_path:
self.models_dir_file.set_path(dir_path)
self.regression_models_dir.set_path(dir_path)
def browse_output_dir(self):
"""浏览输出目录"""
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", "")
if dir_path:
self.output_dir_widget.set_path(dir_path)
def get_config(self):
"""获取配置"""
config = {
'prediction_column': self.prediction_column.text(),
'enabled': self.enable_checkbox.isChecked()
}
# 添加采样光谱CSV路径
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
# 添加回归模型目录路径
regression_models_dir = self.regression_models_dir.get_path()
if regression_models_dir:
config['custom_regression_dir'] = regression_models_dir
# 添加公式CSV文件路径
formula_csv_path = self.formula_csv_file.get_path()
if formula_csv_path:
config['formula_csv_file'] = formula_csv_path
# 添加模型目录路径
models_dir = self.models_dir_file.get_path()
if models_dir:
config['custom_regression_dir'] = models_dir
config['formula_csv_path'] = formula_csv_path
# 添加输出目录路径
output_dir = self.output_dir_widget.get_path()
if output_dir:
config['output_dir'] = output_dir
return config
def set_config(self, config):
"""设置配置"""
if 'prediction_column' in config:
self.prediction_column.setText(config['prediction_column'])
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
if 'formula_csv_file' in config:
self.formula_csv_file.set_path(config['formula_csv_file'])
if 'custom_regression_dir' in config:
self.models_dir_file.set_path(config['custom_regression_dir'])
self.regression_models_dir.set_path(config['custom_regression_dir'])
if 'formula_csv_path' in config:
self.formula_csv_file.set_path(config['formula_csv_path'])
if 'output_dir' in config:
self.output_dir_widget.set_path(config['output_dir'])
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
@ -3213,9 +3328,9 @@ class Step8_75Panel(QWidget):
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
formula_csv_path = self.formula_csv_file.get_path()
if not formula_csv_path:
QMessageBox.warning(self, "输入错误", "请选择公式CSV文件")
regression_models_dir = self.regression_models_dir.get_path()
if not regression_models_dir:
QMessageBox.warning(self, "输入错误", "请选择回归模型目录")
return
# 获取配置
@ -3323,8 +3438,7 @@ class ImageCategoryTree(QTreeWidget):
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
("采样分析", ["sampling", "flight_path", "point_map", "trajectory"], "📍"),
("其他图表", [], "📁"),
("含量分布图", [], "📁"),
]
def __init__(self, parent=None):
@ -3376,7 +3490,7 @@ class ImageCategoryTree(QTreeWidget):
# 根据文件名关键词确定类别
category = self._determine_category(file_path.name)
category_item = self.category_items.get(category, self.category_items["其他图表"])
category_item = self.category_items.get(category, self.category_items["含量分布图"])
# 创建图像项
image_item = QTreeWidgetItem(category_item)
@ -3394,7 +3508,7 @@ class ImageCategoryTree(QTreeWidget):
if any(keyword in filename_lower for keyword in keywords):
return category_name
return "其他图表"
return "含量分布图"
def scan_directory(self, work_dir: str):
"""扫描目录中的所有图像文件"""
@ -3682,11 +3796,7 @@ class VisualizationPanel(QWidget):
def _viz_set_busy(self, busy: bool):
for w in (
getattr(self, "gen_all_btn", None),
getattr(self, "gen_scatter_btn", None),
getattr(self, "gen_spectrum_btn", None),
getattr(self, "gen_stats_btn", None),
getattr(self, "gen_mask_glint_btn", None),
getattr(self, "gen_sampling_map_btn", None),
getattr(self, "scan_btn", None),
):
if w is not None:
w.setEnabled(not busy)
@ -3728,7 +3838,13 @@ class VisualizationPanel(QWidget):
return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])]
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
"""统计图用的参数列;若存在光谱波段,则只统计波段前的字段。"""
"""统计图用的参数列**只统计水质参数列(数值型),排除波长列**。
- 包括:数值型的水质参数(浓度、含量等)
- 排除:光谱波长列(虽然也是数值型,但不是水质参数)
- 排除坐标列UTM_X, UTM_Y, lat, lon等
若存在光谱波段,则只统计波段前的数值字段。
"""
# 选择数值类型列
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
@ -3737,6 +3853,7 @@ class VisualizationPanel(QWidget):
idx = int(wl)
coord_kw = ("utm", "lat", "lon")
if 0 < idx < len(df.columns):
# 只取波长开始之前的列(水质参数区域)
meta_set = set(df.columns[:idx])
return [
col
@ -3744,6 +3861,7 @@ class VisualizationPanel(QWidget):
if col in meta_set and not any(x in str(col).lower() for x in coord_kw)
]
return [
# 如果没有找到波长列,排除坐标相关列
col
for col in numeric_cols
if not any(x in str(col).lower() for x in coord_kw + ("x", "y"))
@ -3825,7 +3943,7 @@ class VisualizationPanel(QWidget):
QMessageBox.warning(
self,
"提示",
"未生成任何散点图。请确认 7_models 下已有各参数子目录及模型文件,"
"未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
"且训练 CSV 与建模时一致。",
)
elif t == "generate_all_selected":
@ -3870,36 +3988,22 @@ class VisualizationPanel(QWidget):
self.image_tree = ImageCategoryTree()
self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
tree_layout.addWidget(self.image_tree)
# 生成按钮组
gen_btn_layout = QHBoxLayout()
self.gen_all_btn = QPushButton("🚀 生成全部")
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
gen_btn_layout.addWidget(self.gen_all_btn)
self.scan_btn = QPushButton("📁 扫描")
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
self.scan_btn.clicked.connect(self.scan_work_directory)
gen_btn_layout.addWidget(self.scan_btn)
tree_layout.addLayout(gen_btn_layout)
tree_group.setLayout(tree_layout)
left_layout.addWidget(tree_group, 1)
# 可视化配置
config_group = QGroupBox("可视化配置")
config_layout = QVBoxLayout()
self.gen_scatter = QCheckBox("模型评估散点图")
self.gen_scatter.setChecked(True)
config_layout.addWidget(self.gen_scatter)
self.gen_spectrum = QCheckBox("光谱曲线图")
self.gen_spectrum.setChecked(True)
config_layout.addWidget(self.gen_spectrum)
self.gen_boxplots = QCheckBox("统计图表")
self.gen_boxplots.setChecked(True)
config_layout.addWidget(self.gen_boxplots)
@ -3912,6 +4016,27 @@ class VisualizationPanel(QWidget):
self.gen_sampling_map.setChecked(True)
config_layout.addWidget(self.gen_sampling_map)
# 添加分隔线
config_layout.addSpacing(10)
line = QFrame()
line.setFrameShape(QFrame.HLine)
line.setStyleSheet("color: #ddd;")
config_layout.addWidget(line)
config_layout.addSpacing(10)
# 生成全部按钮
self.gen_all_btn = QPushButton("🚀 生成全部")
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
config_layout.addWidget(self.gen_all_btn)
# 扫描按钮
self.scan_btn = QPushButton("📁 扫描目录")
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
self.scan_btn.clicked.connect(self.scan_work_directory)
config_layout.addWidget(self.scan_btn)
config_group.setLayout(config_layout)
left_layout.addWidget(config_group)
@ -3928,43 +4053,7 @@ class VisualizationPanel(QWidget):
self.image_viewer = ImageViewerWidget()
self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
right_layout.addWidget(self.image_viewer, 1)
# 生成特定图表按钮组
specific_group = QGroupBox("生成特定图表")
specific_layout = QHBoxLayout()
self.gen_scatter_btn = QPushButton("📊 散点图")
self.gen_scatter_btn.setToolTip(
"基于工作目录下 5_training_spectra/training_spectra.csv 与 7_models 生成模型评估散点图"
)
self.gen_scatter_btn.clicked.connect(lambda: self.generate_chart('scatter'))
specific_layout.addWidget(self.gen_scatter_btn)
self.gen_spectrum_btn = QPushButton("📈 光谱图")
self.gen_spectrum_btn.setToolTip(
"基于 5_training_spectra/training_spectra.csv为每个数值型水质参数各生成一张光谱对比图无需选择"
)
self.gen_spectrum_btn.clicked.connect(lambda: self.generate_chart('spectrum'))
specific_layout.addWidget(self.gen_spectrum_btn)
self.gen_stats_btn = QPushButton("📉 统计图")
self.gen_stats_btn.setToolTip(
"基于工作目录下 5_training_spectra/training_spectra.csv 生成箱线图、直方图与相关性热力图"
)
self.gen_stats_btn.clicked.connect(lambda: self.generate_chart('statistics'))
specific_layout.addWidget(self.gen_stats_btn)
self.gen_mask_glint_btn = QPushButton("🖼️ 掩膜图")
self.gen_mask_glint_btn.clicked.connect(lambda: self.generate_mask_glint_previews())
specific_layout.addWidget(self.gen_mask_glint_btn)
self.gen_sampling_map_btn = QPushButton("📍 采样点图")
self.gen_sampling_map_btn.clicked.connect(lambda: self.generate_sampling_point_map())
specific_layout.addWidget(self.gen_sampling_map_btn)
specific_group.setLayout(specific_layout)
right_layout.addWidget(specific_group)
right_panel.setLayout(right_layout)
main_layout.addWidget(right_panel, 1)
@ -3999,6 +4088,9 @@ class VisualizationPanel(QWidget):
print(f"扫描工作目录: {work_path}")
self.image_tree.scan_directory(str(work_path))
# 设置三个预测步骤的默认输出路径
self._setup_prediction_output_dirs(work_path)
# 如果有图像,自动选择第一个
viz_dir = work_path / "14_visualization"
if viz_dir.exists():
@ -4006,6 +4098,41 @@ class VisualizationPanel(QWidget):
if image_files:
self.image_viewer.load_image(str(image_files[0]))
def _setup_prediction_output_dirs(self, work_path: Path):
"""
设置三个预测步骤的默认输出目录
在11_12_13_predictions下创建三个子文件夹
"""
try:
# 基础预测目录
base_prediction_dir = work_path / "11_12_13_predictions"
# 三个子文件夹路径
ml_dir = base_prediction_dir / "Machine_Learning_Prediction"
reg_dir = base_prediction_dir / "Regression_Model_Prediction"
custom_dir = base_prediction_dir / "Custom_Regression_Prediction"
# 创建目录(如果不存在)
ml_dir.mkdir(parents=True, exist_ok=True)
reg_dir.mkdir(parents=True, exist_ok=True)
custom_dir.mkdir(parents=True, exist_ok=True)
# 设置Step8Panel机器学习预测的默认输出路径
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
self.step8_panel.output_file.set_path(str(ml_dir))
# 设置Step8_5Panel回归模型预测的默认输出路径
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
self.step8_5_panel.output_file.set_path(str(reg_dir))
# 设置Step8_75Panel自定义回归预测的默认输出路径
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
except Exception as e:
print(f"设置预测输出目录失败: {e}")
def on_tree_item_clicked(self, item, column):
"""目录树项点击事件"""
data = item.data(0, Qt.UserRole)
@ -4022,37 +4149,37 @@ class VisualizationPanel(QWidget):
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
if not work_path.exists():
QMessageBox.warning(self, "警告", "工作目录不存在!")
return
reply = QMessageBox.question(
self, "确认生成",
"将按左侧勾选项在后台生成可视化(掩膜/耀斑预览、采样点图等),可能需要较长时间。\n是否继续?",
QMessageBox.Yes | QMessageBox.No
)
if reply != QMessageBox.Yes:
return
if self.gen_scatter.isChecked():
print("生成散点图...(占位,请用建模/可视化流程生成)")
if self.gen_spectrum.isChecked():
print("生成光谱图...(占位,请用下方「光谱图」按钮)")
if self.gen_boxplots.isChecked():
print("生成统计图...(占位,请用下方「统计图」按钮)")
if not self.gen_mask_glint.isChecked() and not self.gen_sampling_map.isChecked():
# 检查是否有任何选项被勾选
if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or
self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or
self.gen_sampling_map.isChecked()):
QMessageBox.information(
self,
"提示",
"请至少勾选「掩膜和耀斑缩略图」或「采样点地图」以执行后台批量任务",
"请至少勾选一项可视化配置选项以生成图表",
)
return
reply = QMessageBox.question(
self, "确认生成",
"将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
QMessageBox.Yes | QMessageBox.No
)
if reply != QMessageBox.Yes:
return
# 收集所有选中的任务
extra = {
"gen_scatter": self.gen_scatter.isChecked(),
"gen_spectrum": self.gen_spectrum.isChecked(),
"gen_boxplots": self.gen_boxplots.isChecked(),
"gen_mask_glint": self.gen_mask_glint.isChecked(),
"gen_sampling_map": self.gen_sampling_map.isChecked(),
}
@ -4082,7 +4209,7 @@ class VisualizationPanel(QWidget):
)
return
training_csv = training_spectra_csv
models_dir = work_path / "7_models"
models_dir = work_path / "7_Supervised_Model_Training"
if not models_dir.is_dir() or not any(
d.is_dir() for d in models_dir.iterdir()
):
@ -4284,7 +4411,6 @@ class VisualizationPanel(QWidget):
'generate_scatter': self.gen_scatter.isChecked(),
'generate_boxplots': self.gen_boxplots.isChecked(),
'generate_spectrum': self.gen_spectrum.isChecked(),
'generate_statistics': self.gen_stats_btn.isChecked(),
'generate_glint_previews': self.gen_mask_glint.isChecked(),
'generate_sampling_maps': self.gen_sampling_map.isChecked(),
'scatter_config': {
@ -4314,8 +4440,6 @@ class VisualizationPanel(QWidget):
self.gen_boxplots.setChecked(config['generate_boxplots'])
if 'generate_spectrum' in config:
self.gen_spectrum.setChecked(config['generate_spectrum'])
if 'generate_statistics' in config:
self.gen_stats_btn.setChecked(config['generate_statistics'])
if 'generate_glint_previews' in config:
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
if 'generate_sampling_maps' in config:
@ -4755,7 +4879,7 @@ class Step6_5Panel(QWidget):
"输出模型目录:",
"Directories;;All Files (*.*)"
)
self.output_dir.line_edit.setPlaceholderText("8_non_empirical_models")
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
# 修改浏览按钮为选择目录
self.output_dir.browse_btn.clicked.disconnect()
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
@ -4825,9 +4949,9 @@ class Step6_5Panel(QWidget):
# 如果output_dir为空使用工作目录或当前目录
main_window = self.parent().window()
if hasattr(main_window, 'work_dir') and main_window.work_dir:
output_dir = str(Path(main_window.work_dir) / "8_non_empirical_models")
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
else:
output_dir = str(Path.cwd() / "8_non_empirical_models")
output_dir = str(Path.cwd() / "8_Regression_Modeling")
config['output_dir'] = output_dir
# 添加训练数据路径(用于独立运行)
@ -4952,7 +5076,8 @@ class Step6_75Panel(QWidget):
# 创建滚动区域来容纳自变量选择
x_scroll = QScrollArea()
x_scroll.setWidgetResizable(True)
x_scroll.setMaximumHeight(200)
x_scroll.setMinimumHeight(250)
x_scroll.setMaximumHeight(350)
x_widget = QWidget()
self.x_columns_layout = QGridLayout()
@ -4982,7 +5107,8 @@ class Step6_75Panel(QWidget):
# 创建滚动区域来容纳因变量选择
y_scroll = QScrollArea()
y_scroll.setWidgetResizable(True)
y_scroll.setMaximumHeight(150)
y_scroll.setMinimumHeight(200)
y_scroll.setMaximumHeight(300)
y_widget = QWidget()
self.y_columns_layout = QGridLayout()
@ -5044,7 +5170,7 @@ class Step6_75Panel(QWidget):
output_layout = QFormLayout()
self.output_dir = QLineEdit()
self.output_dir.setText("9_custom_regression")
self.output_dir.setText("9_Custom_Regression_Modeling")
output_layout.addRow("输出目录名:", self.output_dir)
output_group.setLayout(output_layout)
@ -5202,7 +5328,7 @@ class Step6_75Panel(QWidget):
checkbox.setChecked(method in selected_methods)
if 'output_dir' in config:
self.output_dir.setText(config['output_dir'] or "9_custom_regression")
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
@ -5574,26 +5700,28 @@ class WaterQualityGUI(QMainWindow):
self.step_list = QListWidget()
self.step_list.setStyleSheet(ModernStylesheet.get_sidebar_stylesheet())
# 定义阶段结构
# 定义阶段结构
self.process_stages = {
"阶段一:数据预处理": [
"阶段一:影像预处理": [
("step1", "1. 水域掩膜生成"),
("step2", "2. 耀斑区域识别"),
("step3", "3. 耀斑去除与修复"),
("step4", "4. 数据标准化处理"),
],
"阶段二:特征提取与建模": [
"阶段二:样本数据准备 ": [
("step4", "4. 数据标准化处理"),
("step5", "5. 光谱特征提取"),
("step5_5", "6. 水质参数指数计算"),
("step6", "7. 监督学习模型训练"),
("step6_5", "8. 经验统计回归"),
("step6_75", "9. 自定义回归模型"),
],
"阶段三:应用与可视化": [
"阶段三:模型构建与训练": [
("step6", "7. 机器学习模型训练"),
("step6_5", "8. 回归模型训练"),
("step6_75", "9. 自定义回归模型训练"),
],
"阶段四:预测与成果输出 ": [
("step7", "10. 采样点布设"),
("step8", "11. 基于监督学习预测"),
("step8_5", "12. 基于统计回归预测"),
("step8_75", "13. 基于自定义回归预测"),
("step8", "11. 机器学习学习预测"),
("step8_5", "12. 回归预测"),
("step8_75", "13. 自定义回归预测"),
("step9", "14. 专题图生成"),
("step9_viz", "15. 可视化分析"),
("step_report", "16. 分析报告生成"),