fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动
This commit is contained in:
@ -660,7 +660,7 @@ class VisualizationWorkerThread(QThread):
|
||||
self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
|
||||
return
|
||||
if not models_dir or not Path(models_dir).is_dir():
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_models 下的参数子文件夹。")
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
|
||||
return
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
@ -672,12 +672,111 @@ class VisualizationWorkerThread(QThread):
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
parts = []
|
||||
|
||||
# 获取训练数据CSV路径(多个图表类型共用)
|
||||
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
|
||||
|
||||
# 生成散点图
|
||||
if self.extra.get("gen_scatter"):
|
||||
if training_csv.is_file():
|
||||
models_dir = wp / "7_Supervised_Model_Training"
|
||||
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=str(training_csv),
|
||||
models_dir=str(models_dir),
|
||||
)
|
||||
count = len(scatter_paths) if scatter_paths else 0
|
||||
parts.append(f"散点图: {count} 个")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无模型目录)")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无训练数据)")
|
||||
|
||||
# 生成光谱图
|
||||
if self.extra.get("gen_spectrum"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
# 推断水质参数列(光谱波段列之前的数值型列)
|
||||
wl_col = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl_col, str):
|
||||
idx = int(df.columns.get_loc(wl_col)) + 1
|
||||
else:
|
||||
idx = int(wl_col)
|
||||
param_cols = []
|
||||
if idx > 0 and idx < len(df.columns):
|
||||
param_cols = [
|
||||
c for c in df.columns[:idx]
|
||||
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
|
||||
]
|
||||
if param_cols:
|
||||
# plot_spectrum_by_parameter 接受单个参数列,逐个调用
|
||||
spectrum_paths = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
path = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(training_csv),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl_col,
|
||||
n_groups=5,
|
||||
)
|
||||
if path:
|
||||
spectrum_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"生成光谱图失败 ({param_col}): {e}")
|
||||
count = len(spectrum_paths)
|
||||
parts.append(f"光谱图: {count} 个")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无可用参数列)")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无训练数据)")
|
||||
|
||||
# 生成统计图
|
||||
if self.extra.get("gen_boxplots"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
# **只统计水质参数列(数值型),排除波长列**
|
||||
# 获取水质参数列(数值型且不是波长、不是坐标列)
|
||||
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
|
||||
param_cols = [
|
||||
c for c in df.select_dtypes(include=[np.number]).columns
|
||||
if not any(exc in c.lower() for exc in exclude_cols)
|
||||
]
|
||||
# 排除光谱波长列:找到波长开始位置,只取之前的数值列
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
idx = int(df.columns.get_loc(wl)) + 1
|
||||
else:
|
||||
idx = int(wl)
|
||||
if 0 < idx < len(df.columns):
|
||||
meta_set = set(df.columns[:idx])
|
||||
param_cols = [c for c in param_cols if c in meta_set]
|
||||
|
||||
if param_cols:
|
||||
output_dict = viz.plot_statistical_charts(
|
||||
csv_path=str(training_csv),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
# plot_statistical_charts 返回字典,统计值非空
|
||||
count = len([v for v in output_dict.values() if v]) if output_dict else 0
|
||||
parts.append(f"统计图: {count} 个")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无可用水质参数列)")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无训练数据)")
|
||||
|
||||
# 生成掩膜/耀斑预览图
|
||||
if self.extra.get("gen_mask_glint"):
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
parts.append(f"掩膜/耀斑预览 {len(preview_paths) if preview_paths else 0} 个")
|
||||
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
|
||||
|
||||
# 生成采样点地图
|
||||
if self.extra.get("gen_sampling_map"):
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
@ -3103,35 +3202,37 @@ class Step8_75Panel(QWidget):
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 公式CSV文件选择
|
||||
# 自定义回归模型目录选择(9_Custom_Regression_Modeling)
|
||||
self.regression_models_dir = FileSelectWidget(
|
||||
"回归模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.regression_models_dir.label.setText("回归模型目录:")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.regression_models_dir.browse_btn.clicked.disconnect()
|
||||
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
|
||||
self.regression_models_dir.set_path("9_Custom_Regression_Modeling") # 设置默认值
|
||||
layout.addWidget(self.regression_models_dir)
|
||||
|
||||
# 公式CSV文件选择(用于查找index_formula)
|
||||
self.formula_csv_file = FileSelectWidget(
|
||||
"公式CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.formula_csv_file.label.setText("公式CSV文件:")
|
||||
layout.addWidget(self.formula_csv_file)
|
||||
|
||||
# 模型目录选择
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
# 输出目录选择
|
||||
self.output_dir_widget = FileSelectWidget(
|
||||
"输出目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.output_dir_widget.label.setText("输出目录:")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 预测列名
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
self.output_dir_widget.browse_btn.clicked.disconnect()
|
||||
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
|
||||
layout.addWidget(self.output_dir_widget)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
@ -3162,45 +3263,59 @@ class Step8_75Panel(QWidget):
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", "")
|
||||
def browse_regression_models_dir(self):
|
||||
"""浏览回归模型目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", "")
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
self.regression_models_dir.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", "")
|
||||
if dir_path:
|
||||
self.output_dir_widget.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
# 添加采样光谱CSV路径
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
|
||||
# 添加回归模型目录路径
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if regression_models_dir:
|
||||
config['custom_regression_dir'] = regression_models_dir
|
||||
|
||||
# 添加公式CSV文件路径
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if formula_csv_path:
|
||||
config['formula_csv_file'] = formula_csv_path
|
||||
# 添加模型目录路径
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['custom_regression_dir'] = models_dir
|
||||
config['formula_csv_path'] = formula_csv_path
|
||||
|
||||
# 添加输出目录路径
|
||||
output_dir = self.output_dir_widget.get_path()
|
||||
if output_dir:
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
|
||||
if 'formula_csv_file' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_file'])
|
||||
|
||||
if 'custom_regression_dir' in config:
|
||||
self.models_dir_file.set_path(config['custom_regression_dir'])
|
||||
self.regression_models_dir.set_path(config['custom_regression_dir'])
|
||||
|
||||
if 'formula_csv_path' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_path'])
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir_widget.set_path(config['output_dir'])
|
||||
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
@ -3213,9 +3328,9 @@ class Step8_75Panel(QWidget):
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if not formula_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择公式CSV文件!")
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if not regression_models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择回归模型目录!")
|
||||
return
|
||||
|
||||
# 获取配置
|
||||
@ -3323,8 +3438,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
|
||||
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
|
||||
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
|
||||
("采样分析", ["sampling", "flight_path", "point_map", "trajectory"], "📍"),
|
||||
("其他图表", [], "📁"),
|
||||
("含量分布图", [], "📁"),
|
||||
]
|
||||
|
||||
def __init__(self, parent=None):
|
||||
@ -3376,7 +3490,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
|
||||
# 根据文件名关键词确定类别
|
||||
category = self._determine_category(file_path.name)
|
||||
category_item = self.category_items.get(category, self.category_items["其他图表"])
|
||||
category_item = self.category_items.get(category, self.category_items["含量分布图"])
|
||||
|
||||
# 创建图像项
|
||||
image_item = QTreeWidgetItem(category_item)
|
||||
@ -3394,7 +3508,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
if any(keyword in filename_lower for keyword in keywords):
|
||||
return category_name
|
||||
|
||||
return "其他图表"
|
||||
return "含量分布图"
|
||||
|
||||
def scan_directory(self, work_dir: str):
|
||||
"""扫描目录中的所有图像文件"""
|
||||
@ -3682,11 +3796,7 @@ class VisualizationPanel(QWidget):
|
||||
def _viz_set_busy(self, busy: bool):
|
||||
for w in (
|
||||
getattr(self, "gen_all_btn", None),
|
||||
getattr(self, "gen_scatter_btn", None),
|
||||
getattr(self, "gen_spectrum_btn", None),
|
||||
getattr(self, "gen_stats_btn", None),
|
||||
getattr(self, "gen_mask_glint_btn", None),
|
||||
getattr(self, "gen_sampling_map_btn", None),
|
||||
getattr(self, "scan_btn", None),
|
||||
):
|
||||
if w is not None:
|
||||
w.setEnabled(not busy)
|
||||
@ -3728,7 +3838,13 @@ class VisualizationPanel(QWidget):
|
||||
return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])]
|
||||
|
||||
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
|
||||
"""统计图用的参数列;若存在光谱波段,则只统计波段前的字段。"""
|
||||
"""统计图用的参数列:**只统计水质参数列(数值型),排除波长列**。
|
||||
- 包括:数值型的水质参数(浓度、含量等)
|
||||
- 排除:光谱波长列(虽然也是数值型,但不是水质参数)
|
||||
- 排除:坐标列(UTM_X, UTM_Y, lat, lon等)
|
||||
若存在光谱波段,则只统计波段前的数值字段。
|
||||
"""
|
||||
# 选择数值类型列
|
||||
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
@ -3737,6 +3853,7 @@ class VisualizationPanel(QWidget):
|
||||
idx = int(wl)
|
||||
coord_kw = ("utm", "lat", "lon")
|
||||
if 0 < idx < len(df.columns):
|
||||
# 只取波长开始之前的列(水质参数区域)
|
||||
meta_set = set(df.columns[:idx])
|
||||
return [
|
||||
col
|
||||
@ -3744,6 +3861,7 @@ class VisualizationPanel(QWidget):
|
||||
if col in meta_set and not any(x in str(col).lower() for x in coord_kw)
|
||||
]
|
||||
return [
|
||||
# 如果没有找到波长列,排除坐标相关列
|
||||
col
|
||||
for col in numeric_cols
|
||||
if not any(x in str(col).lower() for x in coord_kw + ("x", "y"))
|
||||
@ -3825,7 +3943,7 @@ class VisualizationPanel(QWidget):
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"提示",
|
||||
"未生成任何散点图。请确认 7_models 下已有各参数子目录及模型文件,"
|
||||
"未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
|
||||
"且训练 CSV 与建模时一致。",
|
||||
)
|
||||
elif t == "generate_all_selected":
|
||||
@ -3870,36 +3988,22 @@ class VisualizationPanel(QWidget):
|
||||
self.image_tree = ImageCategoryTree()
|
||||
self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
|
||||
tree_layout.addWidget(self.image_tree)
|
||||
|
||||
# 生成按钮组
|
||||
gen_btn_layout = QHBoxLayout()
|
||||
self.gen_all_btn = QPushButton("🚀 生成全部")
|
||||
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
|
||||
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
|
||||
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
|
||||
gen_btn_layout.addWidget(self.gen_all_btn)
|
||||
|
||||
self.scan_btn = QPushButton("📁 扫描")
|
||||
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
|
||||
self.scan_btn.clicked.connect(self.scan_work_directory)
|
||||
gen_btn_layout.addWidget(self.scan_btn)
|
||||
|
||||
tree_layout.addLayout(gen_btn_layout)
|
||||
|
||||
tree_group.setLayout(tree_layout)
|
||||
left_layout.addWidget(tree_group, 1)
|
||||
|
||||
|
||||
# 可视化配置
|
||||
config_group = QGroupBox("可视化配置")
|
||||
config_layout = QVBoxLayout()
|
||||
|
||||
|
||||
self.gen_scatter = QCheckBox("模型评估散点图")
|
||||
self.gen_scatter.setChecked(True)
|
||||
config_layout.addWidget(self.gen_scatter)
|
||||
|
||||
|
||||
self.gen_spectrum = QCheckBox("光谱曲线图")
|
||||
self.gen_spectrum.setChecked(True)
|
||||
config_layout.addWidget(self.gen_spectrum)
|
||||
|
||||
|
||||
self.gen_boxplots = QCheckBox("统计图表")
|
||||
self.gen_boxplots.setChecked(True)
|
||||
config_layout.addWidget(self.gen_boxplots)
|
||||
@ -3912,6 +4016,27 @@ class VisualizationPanel(QWidget):
|
||||
self.gen_sampling_map.setChecked(True)
|
||||
config_layout.addWidget(self.gen_sampling_map)
|
||||
|
||||
# 添加分隔线
|
||||
config_layout.addSpacing(10)
|
||||
line = QFrame()
|
||||
line.setFrameShape(QFrame.HLine)
|
||||
line.setStyleSheet("color: #ddd;")
|
||||
config_layout.addWidget(line)
|
||||
config_layout.addSpacing(10)
|
||||
|
||||
# 生成全部按钮
|
||||
self.gen_all_btn = QPushButton("🚀 生成全部")
|
||||
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
|
||||
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
|
||||
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
|
||||
config_layout.addWidget(self.gen_all_btn)
|
||||
|
||||
# 扫描按钮
|
||||
self.scan_btn = QPushButton("📁 扫描目录")
|
||||
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
|
||||
self.scan_btn.clicked.connect(self.scan_work_directory)
|
||||
config_layout.addWidget(self.scan_btn)
|
||||
|
||||
config_group.setLayout(config_layout)
|
||||
left_layout.addWidget(config_group)
|
||||
|
||||
@ -3928,43 +4053,7 @@ class VisualizationPanel(QWidget):
|
||||
self.image_viewer = ImageViewerWidget()
|
||||
self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
|
||||
right_layout.addWidget(self.image_viewer, 1)
|
||||
|
||||
# 生成特定图表按钮组
|
||||
specific_group = QGroupBox("生成特定图表")
|
||||
specific_layout = QHBoxLayout()
|
||||
|
||||
self.gen_scatter_btn = QPushButton("📊 散点图")
|
||||
self.gen_scatter_btn.setToolTip(
|
||||
"基于工作目录下 5_training_spectra/training_spectra.csv 与 7_models 生成模型评估散点图"
|
||||
)
|
||||
self.gen_scatter_btn.clicked.connect(lambda: self.generate_chart('scatter'))
|
||||
specific_layout.addWidget(self.gen_scatter_btn)
|
||||
|
||||
self.gen_spectrum_btn = QPushButton("📈 光谱图")
|
||||
self.gen_spectrum_btn.setToolTip(
|
||||
"基于 5_training_spectra/training_spectra.csv,为每个数值型水质参数各生成一张光谱对比图(无需选择)"
|
||||
)
|
||||
self.gen_spectrum_btn.clicked.connect(lambda: self.generate_chart('spectrum'))
|
||||
specific_layout.addWidget(self.gen_spectrum_btn)
|
||||
|
||||
self.gen_stats_btn = QPushButton("📉 统计图")
|
||||
self.gen_stats_btn.setToolTip(
|
||||
"基于工作目录下 5_training_spectra/training_spectra.csv 生成箱线图、直方图与相关性热力图"
|
||||
)
|
||||
self.gen_stats_btn.clicked.connect(lambda: self.generate_chart('statistics'))
|
||||
specific_layout.addWidget(self.gen_stats_btn)
|
||||
|
||||
self.gen_mask_glint_btn = QPushButton("🖼️ 掩膜图")
|
||||
self.gen_mask_glint_btn.clicked.connect(lambda: self.generate_mask_glint_previews())
|
||||
specific_layout.addWidget(self.gen_mask_glint_btn)
|
||||
|
||||
self.gen_sampling_map_btn = QPushButton("📍 采样点图")
|
||||
self.gen_sampling_map_btn.clicked.connect(lambda: self.generate_sampling_point_map())
|
||||
specific_layout.addWidget(self.gen_sampling_map_btn)
|
||||
|
||||
specific_group.setLayout(specific_layout)
|
||||
right_layout.addWidget(specific_group)
|
||||
|
||||
right_panel.setLayout(right_layout)
|
||||
main_layout.addWidget(right_panel, 1)
|
||||
|
||||
@ -3999,6 +4088,9 @@ class VisualizationPanel(QWidget):
|
||||
print(f"扫描工作目录: {work_path}")
|
||||
self.image_tree.scan_directory(str(work_path))
|
||||
|
||||
# 设置三个预测步骤的默认输出路径
|
||||
self._setup_prediction_output_dirs(work_path)
|
||||
|
||||
# 如果有图像,自动选择第一个
|
||||
viz_dir = work_path / "14_visualization"
|
||||
if viz_dir.exists():
|
||||
@ -4006,6 +4098,41 @@ class VisualizationPanel(QWidget):
|
||||
if image_files:
|
||||
self.image_viewer.load_image(str(image_files[0]))
|
||||
|
||||
def _setup_prediction_output_dirs(self, work_path: Path):
|
||||
"""
|
||||
设置三个预测步骤的默认输出目录
|
||||
在11_12_13_predictions下创建三个子文件夹
|
||||
"""
|
||||
try:
|
||||
# 基础预测目录
|
||||
base_prediction_dir = work_path / "11_12_13_predictions"
|
||||
|
||||
# 三个子文件夹路径
|
||||
ml_dir = base_prediction_dir / "Machine_Learning_Prediction"
|
||||
reg_dir = base_prediction_dir / "Regression_Model_Prediction"
|
||||
custom_dir = base_prediction_dir / "Custom_Regression_Prediction"
|
||||
|
||||
# 创建目录(如果不存在)
|
||||
ml_dir.mkdir(parents=True, exist_ok=True)
|
||||
reg_dir.mkdir(parents=True, exist_ok=True)
|
||||
custom_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置Step8Panel(机器学习预测)的默认输出路径
|
||||
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
|
||||
self.step8_panel.output_file.set_path(str(ml_dir))
|
||||
|
||||
# 设置Step8_5Panel(回归模型预测)的默认输出路径
|
||||
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
|
||||
self.step8_5_panel.output_file.set_path(str(reg_dir))
|
||||
|
||||
# 设置Step8_75Panel(自定义回归预测)的默认输出路径
|
||||
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
|
||||
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
|
||||
|
||||
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
|
||||
except Exception as e:
|
||||
print(f"设置预测输出目录失败: {e}")
|
||||
|
||||
def on_tree_item_clicked(self, item, column):
|
||||
"""目录树项点击事件"""
|
||||
data = item.data(0, Qt.UserRole)
|
||||
@ -4022,37 +4149,37 @@ class VisualizationPanel(QWidget):
|
||||
if not self.work_dir:
|
||||
QMessageBox.warning(self, "警告", "请先选择工作目录!")
|
||||
return
|
||||
|
||||
|
||||
work_path = Path(self.work_dir)
|
||||
if not work_path.exists():
|
||||
QMessageBox.warning(self, "警告", "工作目录不存在!")
|
||||
return
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self, "确认生成",
|
||||
"将按左侧勾选项在后台生成可视化(掩膜/耀斑预览、采样点图等),可能需要较长时间。\n是否继续?",
|
||||
QMessageBox.Yes | QMessageBox.No
|
||||
)
|
||||
|
||||
if reply != QMessageBox.Yes:
|
||||
return
|
||||
|
||||
if self.gen_scatter.isChecked():
|
||||
print("生成散点图...(占位,请用建模/可视化流程生成)")
|
||||
if self.gen_spectrum.isChecked():
|
||||
print("生成光谱图...(占位,请用下方「光谱图」按钮)")
|
||||
if self.gen_boxplots.isChecked():
|
||||
print("生成统计图...(占位,请用下方「统计图」按钮)")
|
||||
|
||||
if not self.gen_mask_glint.isChecked() and not self.gen_sampling_map.isChecked():
|
||||
# 检查是否有任何选项被勾选
|
||||
if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or
|
||||
self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or
|
||||
self.gen_sampling_map.isChecked()):
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"提示",
|
||||
"请至少勾选「掩膜和耀斑缩略图」或「采样点地图」以执行后台批量任务。",
|
||||
"请至少勾选一项可视化配置选项以生成图表。",
|
||||
)
|
||||
return
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self, "确认生成",
|
||||
"将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
|
||||
QMessageBox.Yes | QMessageBox.No
|
||||
)
|
||||
|
||||
if reply != QMessageBox.Yes:
|
||||
return
|
||||
|
||||
# 收集所有选中的任务
|
||||
extra = {
|
||||
"gen_scatter": self.gen_scatter.isChecked(),
|
||||
"gen_spectrum": self.gen_spectrum.isChecked(),
|
||||
"gen_boxplots": self.gen_boxplots.isChecked(),
|
||||
"gen_mask_glint": self.gen_mask_glint.isChecked(),
|
||||
"gen_sampling_map": self.gen_sampling_map.isChecked(),
|
||||
}
|
||||
@ -4082,7 +4209,7 @@ class VisualizationPanel(QWidget):
|
||||
)
|
||||
return
|
||||
training_csv = training_spectra_csv
|
||||
models_dir = work_path / "7_models"
|
||||
models_dir = work_path / "7_Supervised_Model_Training"
|
||||
if not models_dir.is_dir() or not any(
|
||||
d.is_dir() for d in models_dir.iterdir()
|
||||
):
|
||||
@ -4284,7 +4411,6 @@ class VisualizationPanel(QWidget):
|
||||
'generate_scatter': self.gen_scatter.isChecked(),
|
||||
'generate_boxplots': self.gen_boxplots.isChecked(),
|
||||
'generate_spectrum': self.gen_spectrum.isChecked(),
|
||||
'generate_statistics': self.gen_stats_btn.isChecked(),
|
||||
'generate_glint_previews': self.gen_mask_glint.isChecked(),
|
||||
'generate_sampling_maps': self.gen_sampling_map.isChecked(),
|
||||
'scatter_config': {
|
||||
@ -4314,8 +4440,6 @@ class VisualizationPanel(QWidget):
|
||||
self.gen_boxplots.setChecked(config['generate_boxplots'])
|
||||
if 'generate_spectrum' in config:
|
||||
self.gen_spectrum.setChecked(config['generate_spectrum'])
|
||||
if 'generate_statistics' in config:
|
||||
self.gen_stats_btn.setChecked(config['generate_statistics'])
|
||||
if 'generate_glint_previews' in config:
|
||||
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
|
||||
if 'generate_sampling_maps' in config:
|
||||
@ -4755,7 +4879,7 @@ class Step6_5Panel(QWidget):
|
||||
"输出模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("8_non_empirical_models")
|
||||
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
@ -4825,9 +4949,9 @@ class Step6_5Panel(QWidget):
|
||||
# 如果output_dir为空,使用工作目录或当前目录
|
||||
main_window = self.parent().window()
|
||||
if hasattr(main_window, 'work_dir') and main_window.work_dir:
|
||||
output_dir = str(Path(main_window.work_dir) / "8_non_empirical_models")
|
||||
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
|
||||
else:
|
||||
output_dir = str(Path.cwd() / "8_non_empirical_models")
|
||||
output_dir = str(Path.cwd() / "8_Regression_Modeling")
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
# 添加训练数据路径(用于独立运行)
|
||||
@ -4952,7 +5076,8 @@ class Step6_75Panel(QWidget):
|
||||
# 创建滚动区域来容纳自变量选择
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMaximumHeight(200)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
@ -4982,7 +5107,8 @@ class Step6_75Panel(QWidget):
|
||||
# 创建滚动区域来容纳因变量选择
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMaximumHeight(150)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
@ -5044,7 +5170,7 @@ class Step6_75Panel(QWidget):
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("9_custom_regression")
|
||||
self.output_dir.setText("9_Custom_Regression_Modeling")
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
@ -5202,7 +5328,7 @@ class Step6_75Panel(QWidget):
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_custom_regression")
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
@ -5574,26 +5700,28 @@ class WaterQualityGUI(QMainWindow):
|
||||
self.step_list = QListWidget()
|
||||
self.step_list.setStyleSheet(ModernStylesheet.get_sidebar_stylesheet())
|
||||
|
||||
# 定义三阶段结构
|
||||
# 定义四阶段结构
|
||||
self.process_stages = {
|
||||
"阶段一:数据预处理": [
|
||||
"阶段一:影像预处理": [
|
||||
("step1", "1. 水域掩膜生成"),
|
||||
("step2", "2. 耀斑区域识别"),
|
||||
("step3", "3. 耀斑去除与修复"),
|
||||
("step4", "4. 数据标准化处理"),
|
||||
],
|
||||
"阶段二:特征提取与建模": [
|
||||
"阶段二:样本数据准备 ": [
|
||||
("step4", "4. 数据标准化处理"),
|
||||
("step5", "5. 光谱特征提取"),
|
||||
("step5_5", "6. 水质参数指数计算"),
|
||||
("step6", "7. 监督学习模型训练"),
|
||||
("step6_5", "8. 经验统计回归"),
|
||||
("step6_75", "9. 自定义回归模型"),
|
||||
],
|
||||
"阶段三:应用与可视化": [
|
||||
"阶段三:模型构建与训练": [
|
||||
("step6", "7. 机器学习模型训练"),
|
||||
("step6_5", "8. 回归模型训练"),
|
||||
("step6_75", "9. 自定义回归模型训练"),
|
||||
],
|
||||
"阶段四:预测与成果输出 ": [
|
||||
("step7", "10. 采样点布设"),
|
||||
("step8", "11. 基于监督学习预测"),
|
||||
("step8_5", "12. 基于统计回归预测"),
|
||||
("step8_75", "13. 基于自定义回归预测"),
|
||||
("step8", "11. 机器学习学习预测"),
|
||||
("step8_5", "12. 回归预测"),
|
||||
("step8_75", "13. 自定义回归预测"),
|
||||
("step9", "14. 专题图生成"),
|
||||
("step9_viz", "15. 可视化分析"),
|
||||
("step_report", "16. 分析报告生成"),
|
||||
|
||||
Reference in New Issue
Block a user