结构修改,后端文件跟前端内容进行适配

This commit is contained in:
DXC
2026-06-11 17:44:24 +08:00
parent 3584c07b67
commit e59703f163
12 changed files with 1311 additions and 495 deletions

View File

@ -272,10 +272,10 @@ class WorkerThread(QThread):
ctx = PipelineContext( ctx = PipelineContext(
img_path=self.config.get('step1', {}).get('img_path'), img_path=self.config.get('step1', {}).get('img_path'),
water_mask_path=self.config.get('step1', {}).get('mask_path'), water_mask_path=self.config.get('step1', {}).get('mask_path'),
csv_path=self.config.get('step4', {}).get('csv_path'), csv_path=self.config.get('step4_sampling', {}).get('csv_path'),
boundary_path=self.config.get('step5', {}).get('boundary_path'), boundary_path=self.config.get('step5_clean', {}).get('boundary_path'),
boundary_shp_path=self.config.get('step14', {}).get('boundary_shp_path'), boundary_shp_path=self.config.get('step11_map', {}).get('boundary_shp_path'),
formula_csv_path=self.config.get('step12', {}).get('formula_csv_path'), formula_csv_path=self.config.get('step8_non_empirical_modeling', {}).get('formula_csv_path'),
work_dir=self.work_dir, work_dir=self.work_dir,
user_config=self.config user_config=self.config
) )
@ -323,21 +323,16 @@ class WorkerThread(QThread):
'step1': 'step1_generate_water_mask', 'step1': 'step1_generate_water_mask',
'step2': 'step2_find_glint_area', 'step2': 'step2_find_glint_area',
'step3': 'step3_remove_glint', 'step3': 'step3_remove_glint',
'step4': 'step5_process_csv', 'step4_sampling': 'step4_sampling',
'step5': 'step6_extract_spectra', 'step5_clean': 'step5_process_csv',
'step7': 'step7_calc_indices', 'step6_feature': 'step6_extract_spectra',
'step7_index': 'step7_calc_indices', 'step7_index': 'step7_calc_indices',
'step8': 'step8_train_ml',
'step8_ml_train': 'step8_train_ml', 'step8_ml_train': 'step8_train_ml',
'step8_non_empirical_modeling': 'step8_non_empirical_modeling', 'step8_non_empirical_modeling': 'step8_non_empirical_modeling',
'step8_qaa': 'step8_qaa_inversion', 'step8_qaa': 'step8_qaa_inversion',
'step9': 'step9_watercolor_inversion',
'step9_ml_predict': 'step9_predict_ml', 'step9_ml_predict': 'step9_predict_ml',
'step10': 'step4_sampling', 'step10_watercolor': 'step9_watercolor_inversion',
'step10_map': 'step10_map', 'step11_map': 'step10_map',
'step11_ml': 'step9_predict_ml',
'step11': 'step11_non_empirical_prediction',
'step14': 'step10_map'
} }
if step_name not in step_method_map: if step_name not in step_method_map:
@ -367,17 +362,9 @@ class WorkerThread(QThread):
step_config['skip_dependency_check'] = True step_config['skip_dependency_check'] = True
if step_name == 'step14': if step_name in ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step7_index', 'step9_ml_predict']:
step_config.pop('step9_batch_mode', None)
step_config.pop('prediction_csv_dir', None)
step_config.pop('recursive_csv_scan', None)
if step_name in ['step2', 'step3', 'step4', 'step5', 'step7', 'step10', 'step11_ml', 'step11', 'step12']:
step_config.pop('output_path', None) step_config.pop('output_path', None)
if step_name == 'step11' and 'models_dir' in step_config:
step_config['non_empirical_models_dir'] = step_config.pop('models_dir')
method = getattr(self.pipeline, method_name) method = getattr(self.pipeline, method_name)
result = method(**step_config) result = method(**step_config)
@ -436,100 +423,76 @@ class WorkerThread(QThread):
" → 请确认「耀斑去除」已成功运行,或重新配置路径。" " → 请确认「耀斑去除」已成功运行,或重新配置路径。"
) )
# ── 步骤4实测水质数据 CSV ── # ── 步骤4_sampling:实测水质数据 CSV ──
step4_cfg = config.get('step4', {}) step4_cfg = config.get('step4_sampling', {})
csv_path = step4_cfg.get('csv_path') csv_path = step4_cfg.get('csv_path')
if csv_path and not os.path.isfile(csv_path): if csv_path and not os.path.isfile(csv_path):
errors.append( errors.append(
f"步骤 4实测水质数据文件不存在\n {csv_path}\n" f"步骤 4_sampling:实测水质数据文件不存在:\n {csv_path}\n"
" → 请检查 CSV 路径是否正确,或重新上传数据文件。" " → 请检查 CSV 路径是否正确,或重新上传数据文件。"
) )
# ── 步骤5采样点平均光谱提取 ── # ── 步骤5_clean:采样点平均光谱提取 ──
step5_cfg = config.get('step5', {}) step5_cfg = config.get('step5_clean', {})
step5_csv = step5_cfg.get('csv_path') step5_csv = step5_cfg.get('csv_path')
boundary_path = step5_cfg.get('boundary_path') boundary_path = step5_cfg.get('boundary_path')
if step5_csv and not os.path.isfile(step5_csv): if step5_csv and not os.path.isfile(step5_csv):
errors.append( errors.append(
f"步骤 5实测水质数据文件不存在\n {step5_csv}\n" f"步骤 5_clean:实测水质数据文件不存在:\n {step5_csv}\n"
" → 请检查「流程步骤-阶段五」中的 CSV 路径。" " → 请检查「流程步骤-阶段五」中的 CSV 路径。"
) )
if boundary_path and not os.path.isfile(boundary_path): if boundary_path and not os.path.isfile(boundary_path):
errors.append( errors.append(
f"步骤 5边界矢量文件不存在\n {boundary_path}\n" f"步骤 5_clean:边界矢量文件不存在:\n {boundary_path}\n"
" → 请确认「流程步骤-阶段五」中已填写有效的边界 shp 路径。" " → 请确认「流程步骤-阶段五」中已填写有效的边界 shp 路径。"
) )
# ── 步骤6水质光谱指数训练光谱 CSV ── # ── 步骤6_feature(水质光谱指数):训练光谱 CSV ──
step6_cfg = config.get('step6', {}) step6_cfg = config.get('step6_feature', {})
training_csv = step6_cfg.get('training_csv_path') training_csv = step6_cfg.get('training_csv_path')
if training_csv and not os.path.isfile(training_csv): if training_csv and not os.path.isfile(training_csv):
errors.append( errors.append(
f"步骤 6水质光谱指数训练光谱文件不存在\n {training_csv}\n" f"步骤 6_feature(水质光谱指数):训练光谱文件不存在:\n {training_csv}\n"
" → 请确认步骤 5 已成功运行并生成了训练光谱。" " → 请确认步骤 5_clean 已成功运行并生成了训练光谱。"
) )
# ── 步骤7ML 建模) ── # ── 步骤8_ml_trainML 建模) ──
step7_cfg = config.get('step7', {}) step7_cfg = config.get('step8_ml_train', {})
step7_csv = step7_cfg.get('training_csv_path') step7_csv = step7_cfg.get('training_csv_path')
if step7_csv and not os.path.isfile(step7_csv): if step7_csv and not os.path.isfile(step7_csv):
errors.append( errors.append(
f"步骤 7ML 建模):训练光谱文件不存在:\n {step7_csv}\n" f"步骤 8_ml_trainML 建模):训练光谱文件不存在:\n {step7_csv}\n"
" → 请确认步骤 5 已成功运行并生成了训练光谱。" " → 请确认步骤 5_clean 已成功运行并生成了训练光谱。"
) )
# ── 步骤11 ML 预测:密集采样点 CSV + 模型目录 ── # ── 步骤9_ml_predict:密集采样点 CSV + 模型目录 ──
step11_ml_cfg = config.get('step11_ml', {}) step9_cfg = config.get('step9_ml_predict', {})
ml_csv = step11_ml_cfg.get('sampling_csv_path') ml_csv = step9_cfg.get('sampling_csv_path')
models_dir = step11_ml_cfg.get('models_dir') models_dir = step9_cfg.get('models_dir')
if ml_csv and not os.path.isfile(ml_csv): if ml_csv and not os.path.isfile(ml_csv):
errors.append( errors.append(
f"步骤 11 ML 预测:采样点 CSV 不存在:\n {ml_csv}\n" f"步骤 9_ml_predict:采样点 CSV 不存在:\n {ml_csv}\n"
" → 请确认「流程步骤-阶段七(采样点布设)」已成功运行。" " → 请确认「流程步骤-阶段七(采样点布设)」已成功运行。"
) )
if models_dir and not os.path.isdir(models_dir): if models_dir and not os.path.isdir(models_dir):
errors.append( errors.append(
f"步骤 11 ML 预测:模型目录不存在:\n {models_dir}\n" f"步骤 9_ml_predict:模型目录不存在:\n {models_dir}\n"
" → 请确认「流程步骤-阶段六(机器学习建模)」已成功运行。" " → 请确认「流程步骤-阶段六(机器学习建模)」已成功运行。"
) )
# ── 步骤11 回归预测:模型目录 ── # ── 步骤11_map 专题图:预测结果 CSV + 边界 shp ──
step11_cfg = config.get('step11', {}) step11_cfg = config.get('step11_map', {})
step11_csv = step11_cfg.get('sampling_csv_path') pred_csv = step11_cfg.get('prediction_csv_path')
step11_dir = step11_cfg.get('models_dir') boundary_shp = step11_cfg.get('boundary_shp_path')
if step11_csv and not os.path.isfile(step11_csv):
errors.append(
f"步骤 11 回归预测:采样点 CSV 不存在:\n {step11_csv}\n"
" → 请确认「流程步骤-阶段七(采样点布设)」已成功运行。"
)
if step11_dir and not os.path.isdir(step11_dir):
errors.append(
f"步骤 11 回归预测:模型目录不存在:\n {step11_dir}\n"
" → 请确认「流程步骤-阶段八(非经验建模)」已成功运行。"
)
# ── 步骤12 自定义回归预测:公式 CSV ──
step12_cfg = config.get('step12', {})
formula_csv = step12_cfg.get('formula_csv_path')
if formula_csv and not os.path.isfile(formula_csv):
errors.append(
f"步骤 12自定义回归预测公式 CSV 文件不存在:\n {formula_csv}\n"
" → 请确认「流程步骤-阶段十二」中已填写有效的公式文件路径。"
)
# ── 步骤14 专题图:预测结果 CSV + 边界 shp ──
step14_cfg = config.get('step14', {})
pred_csv = step14_cfg.get('prediction_csv_path')
boundary_shp = step14_cfg.get('boundary_shp_path')
if pred_csv and not os.path.isfile(pred_csv): if pred_csv and not os.path.isfile(pred_csv):
errors.append( errors.append(
f"步骤 14(专题图):预测结果 CSV 不存在:\n {pred_csv}\n" f"步骤 11_map(专题图):预测结果 CSV 不存在:\n {pred_csv}\n"
" → 请确认机器学习或回归预测步骤已成功运行。" " → 请确认机器学习或回归预测步骤已成功运行。"
) )
if boundary_shp and not os.path.isfile(boundary_shp): if boundary_shp and not os.path.isfile(boundary_shp):
errors.append( errors.append(
f"步骤 14(专题图):边界 shp 文件不存在:\n {boundary_shp}\n" f"步骤 11_map(专题图):边界 shp 文件不存在:\n {boundary_shp}\n"
" → 请确认「流程步骤-阶段十」中已填写有效的边界矢量文件路径。" " → 请确认「流程步骤-阶段十」中已填写有效的边界矢量文件路径。"
) )
# ── 汇总报错:任一缺失立即抛出 PipelineHalt ── # ── 汇总报错:任一缺失立即抛出 PipelineHalt ──

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Step9 面板 - 水色指数反演直接处理去耀斑 BSQ 影像 Step10 面板 - 水色指数反演直接处理去耀斑 BSQ 影像
waterindex.csv 中的公式直接应用于去耀斑高光谱影像 waterindex.csv 中的公式直接应用于去耀斑高光谱影像
输出各水质参数指数的 GeoTIFF 栅格图像 输出各水质参数指数的 GeoTIFF 栅格图像
@ -98,8 +98,8 @@ class WaterIndexWorker(QThread):
self.progress.emit(msg, pct) self.progress.emit(msg, pct)
class Step10MapPanel(QWidget): class Step10WatercolorPanel(QWidget):
"""步骤10专题图生成""" """步骤10水色指数反演(直接处理 BSQ 影像)"""
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
@ -115,7 +115,7 @@ class Step10MapPanel(QWidget):
layout = QVBoxLayout() layout = QVBoxLayout()
# ---- 标题 ---- # ---- 标题 ----
title = QLabel("步骤9:水色指数反演(高光谱影像直接处理)") title = QLabel("步骤10:水色指数反演(高光谱影像直接处理)")
title.setFont(QFont("Arial", 12, QFont.Bold)) title.setFont(QFont("Arial", 12, QFont.Bold))
layout.addWidget(title) layout.addWidget(title)

View File

@ -0,0 +1,826 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step10 面板 - 专题图生成
"""
import os
import traceback
from pathlib import Path
from typing import List, Optional
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
QRadioButton, QButtonGroup, QMessageBox, QFileDialog, QComboBox,
QProgressBar,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
try:
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
PIPELINE_AVAILABLE = True
except ImportError:
PIPELINE_AVAILABLE = False
class Step11MapBatchThread(QThread):
"""专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
finished_ok = pyqtSignal(int)
failed = pyqtSignal(str)
log_message = pyqtSignal(str, str)
progress = pyqtSignal(int, int) # (current, total)
def __init__(self, work_dir: str, csv_paths: List[str], step10_kwargs: dict, output_dir_optional: Optional[str]):
super().__init__()
self.work_dir = work_dir
self.csv_paths = csv_paths
self.step10_kwargs = step10_kwargs
self.output_dir_optional = (output_dir_optional or "").strip() or None
def run(self):
mpl_prev = None
try:
import matplotlib
mpl_prev = matplotlib.get_backend()
except Exception:
pass
try:
import matplotlib.pyplot as plt
plt.switch_backend("Agg")
except Exception:
mpl_prev = None
try:
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
n = len(self.csv_paths)
for i, csv_p in enumerate(self.csv_paths):
self.progress.emit(i + 1, n)
self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
kw = {**self.step10_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
if self.output_dir_optional:
stem = Path(csv_p).stem
kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
else:
kw["output_image_path"] = None
pipeline.step10_map(**kw)
self.finished_ok.emit(n)
except Exception as e:
self.failed.emit(f"{e}\n{traceback.format_exc()}")
finally:
if mpl_prev:
try:
import matplotlib.pyplot as plt
plt.switch_backend(mpl_prev)
except Exception:
pass
class Step11GeoTIFFBatchThread(QThread):
"""GeoTIFF 批量渲染:遍历文件夹下所有 .tif/.bsq 逐一渲染成分布图 PNG。"""
finished_ok = pyqtSignal(int)
failed = pyqtSignal(str)
log_message = pyqtSignal(str, str)
progress = pyqtSignal(int, int) # (current, total)
def __init__(
self,
tif_paths: List[str],
output_dir: str,
boundary_shp_path: Optional[str],
input_crs: str,
output_crs: str,
):
super().__init__()
self.tif_paths = tif_paths
self.output_dir = output_dir
self.boundary_shp_path = boundary_shp_path
self.input_crs = input_crs
self.output_crs = output_crs
def run(self):
mpl_prev = None
try:
import matplotlib
mpl_prev = matplotlib.get_backend()
except Exception:
pass
try:
import matplotlib.pyplot as plt
plt.switch_backend("Agg")
except Exception:
mpl_prev = None
try:
from src.postprocessing.map import ContentMapper
mapper = ContentMapper()
n = len(self.tif_paths)
for i, tif_path in enumerate(self.tif_paths):
self.progress.emit(i + 1, n)
tif_stem = Path(tif_path).stem
chinese_name = mapper._get_chinese_title(tif_stem)
output_png = str(Path(self.output_dir) / f"{chinese_name}_专题图.png")
self.log_message.emit(f"GeoTIFF 渲染 [{i + 1}/{n}] {tif_stem}", "info")
try:
mapper.visualize_raster(
raster_tif_path=tif_path,
output_file=output_png,
boundary_shp_path=self.boundary_shp_path,
nodata_value=-9999.0,
figsize=(14, 10),
alpha=0.9,
)
except Exception as vis_err:
self.log_message.emit(f" ⚠️ 渲染失败,跳过: {vis_err}", "warning")
continue
self.finished_ok.emit(n)
except Exception as e:
self.failed.emit(f"{e}\n{traceback.format_exc()}")
finally:
if mpl_prev:
try:
import matplotlib.pyplot as plt
plt.switch_backend(mpl_prev)
except Exception:
pass
class Step11MapPanel(QWidget):
"""步骤11专题图生成"""
def __init__(self, parent=None):
super().__init__(parent)
self._batch_thread = None
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
hint = QLabel(
"独立运行:可选「单个 CSV」或「文件夹批量」扫描目录下所有 .csv"
"GeoTIFF 栅格模式下亦支持批量渲染步骤8输出的所有水色指数 GeoTIFF 文件。"
)
hint.setWordWrap(True)
hint.setStyleSheet(
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
)
layout.addWidget(hint)
mode_row = QHBoxLayout()
self.mode_single_rb = QRadioButton("单个 CSV 文件")
self.mode_folder_rb = QRadioButton("文件夹批量")
self._mode_group = QButtonGroup(self)
self._mode_group.addButton(self.mode_single_rb, 0)
self._mode_group.addButton(self.mode_folder_rb, 1)
mode_row.addWidget(self.mode_single_rb)
mode_row.addWidget(self.mode_folder_rb)
mode_row.addStretch()
layout.addLayout(mode_row)
# ---------- 渲染模式选择器CSV vs GeoTIFF ----------
render_row = QHBoxLayout()
render_row.addWidget(QLabel("渲染模式:"))
self.render_mode_combo = QComboBox()
self.render_mode_combo.addItems(["CSV 插值模式", "GeoTIFF 栅格模式"])
self.render_mode_combo.setMinimumWidth(180)
self.render_mode_combo.currentTextChanged.connect(self._toggle_input_mode)
render_row.addWidget(self.render_mode_combo)
render_row.addStretch()
layout.addLayout(render_row)
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
radio_style = """
QRadioButton {
font-size: 14px;
spacing: 8px;
color: #333333;
}
QRadioButton::indicator {
width: 16px;
height: 16px;
border: 2px solid #999999;
border-radius: 3px;
background-color: white;
}
QRadioButton::indicator:checked {
border: 2px solid #0078d4;
background-color: #0078d4;
image: none;
}
QRadioButton::indicator:hover {
border: 2px solid #005a9e;
}
"""
self.mode_single_rb.setStyleSheet(radio_style)
self.mode_folder_rb.setStyleSheet(radio_style)
self.prediction_csv_file = FileSelectWidget(
"预测结果CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.prediction_csv_file)
folder_row = QHBoxLayout()
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
self.prediction_csv_dir_label.setMinimumWidth(120)
self.prediction_csv_dir_edit = QLineEdit()
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
pred_dir_btn = QPushButton("浏览…")
pred_dir_btn.setMaximumWidth(80)
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
folder_row.addWidget(self.prediction_csv_dir_label)
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
folder_row.addWidget(pred_dir_btn)
self._folder_row_widget = QWidget()
self._folder_row_widget.setLayout(folder_row)
layout.addWidget(self._folder_row_widget)
# ---------- GeoTIFF 栅格文件选择器 ----------
self.geotiff_file = FileSelectWidget(
"水色指数 GeoTIFF:",
"GeoTIFF Files (*.tif);;All Files (*.*)"
)
self.geotiff_file.line_edit.setPlaceholderText("选择步骤8输出的水色指数 GeoTIFF 文件…")
self.geotiff_file.setVisible(False)
layout.addWidget(self.geotiff_file)
# ---------- GeoTIFF 文件夹批量选择器GeoTIFF + 文件夹模式时显示) ----------
geotiff_dir_row = QHBoxLayout()
self.geotiff_dir_label = QLabel("水色指数目录:")
self.geotiff_dir_label.setMinimumWidth(120)
self.geotiff_dir_edit = QLineEdit()
self.geotiff_dir_edit.setPlaceholderText("选择 8_WaterIndex_Images 文件夹(批量渲染)…")
geotiff_dir_btn = QPushButton("浏览…")
geotiff_dir_btn.setMaximumWidth(80)
geotiff_dir_btn.clicked.connect(self.browse_geotiff_dir)
geotiff_dir_row.addWidget(self.geotiff_dir_label)
geotiff_dir_row.addWidget(self.geotiff_dir_edit, 1)
geotiff_dir_row.addWidget(geotiff_dir_btn)
self._geotiff_dir_widget = QWidget()
self._geotiff_dir_widget.setLayout(geotiff_dir_row)
self._geotiff_dir_widget.setVisible(False)
layout.addWidget(self._geotiff_dir_widget)
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv")
layout.addWidget(self.recursive_csv_cb)
self.boundary_file = FileSelectWidget(
"边界文件:",
"Shapefiles (*.shp);;All Files (*.*)"
)
layout.addWidget(self.boundary_file)
# 参数设置
params_group = QGroupBox("生成参数")
params_layout = QFormLayout()
self.resolution = QDoubleSpinBox()
self.resolution.setRange(1, 1000)
self.resolution.setValue(30)
params_layout.addRow("分辨率(米):", self.resolution)
self.input_crs = QLineEdit()
self.input_crs.setText("EPSG:32651")
params_layout.addRow("输入坐标系:", self.input_crs)
self.output_crs = QLineEdit()
self.output_crs.setText("EPSG:4326")
params_layout.addRow("输出坐标系:", self.output_crs)
self.show_points = QCheckBox("显示采样点")
params_layout.addRow("", self.show_points)
self.use_diffusion = QCheckBox("启用距离扩散")
self.use_diffusion.setChecked(True)
params_layout.addRow("", self.use_diffusion)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出目录
self.output_dir = FileSelectWidget(
"输出分布图目录:",
"Directories;;All Files (*.*)"
)
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
self.output_dir.browse_btn.clicked.disconnect()
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
layout.addWidget(self.output_dir)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
layout.addWidget(self.run_button)
# 批量渲染进度条
self.progress_bar = QProgressBar()
self.progress_bar.setVisible(False)
self.progress_bar.setMinimum(0)
self.progress_bar.setMaximum(100)
self.progress_bar.setValue(0)
layout.addWidget(self.progress_bar)
layout.addStretch()
self.setLayout(layout)
# 信号绑定与初始状态
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
self._toggle_input_mode() # 根据默认值设置初始显示状态
def _toggle_input_mode(self):
"""槽函数:根据渲染模式和输入模式动态显示/隐藏对应的输入组件。"""
geotiff_mode = self.render_mode_combo.currentText() == "GeoTIFF 栅格模式"
folder_mode = self.mode_folder_rb.isChecked()
# CSV 插值模式
if not geotiff_mode:
self.prediction_csv_file.setVisible(not folder_mode)
self._folder_row_widget.setVisible(folder_mode)
self.recursive_csv_cb.setVisible(folder_mode)
self.geotiff_file.setVisible(False)
self._geotiff_dir_widget.setVisible(False)
# GeoTIFF 栅格模式
else:
self.prediction_csv_file.setVisible(False)
self._folder_row_widget.setVisible(False)
self.recursive_csv_cb.setVisible(False)
# GeoTIFF + 文件夹批量 → 显示文件夹选择器;否则 → 显示单文件选择器
self.geotiff_file.setVisible(not folder_mode)
self._geotiff_dir_widget.setVisible(folder_mode)
def _get_default_work_dir(self):
"""获取 work_dir优先用 panel 自身缓存的,否则尝试从主窗口取"""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_prediction_csv_dir(self):
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "11_12_13_predictions")
d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
if d:
self.prediction_csv_dir_edit.setText(d)
def _collect_csv_paths_from_folder(self) -> List[str]:
folder = (self.prediction_csv_dir_edit.text() or "").strip()
if not folder or not os.path.isdir(folder):
return []
root = Path(folder)
if self.recursive_csv_cb.isChecked():
files = sorted(root.rglob("*.csv"))
else:
files = sorted(root.glob("*.csv"))
return [str(p) for p in files if p.is_file()]
def browse_geotiff_dir(self):
"""浏览 GeoTIFF 文件夹(批量模式)"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "8_WaterIndex_Images")
d = QFileDialog.getExistingDirectory(
self, "选择水色指数 GeoTIFF 文件夹", default
)
if d:
self.geotiff_dir_edit.setText(d)
def _collect_tif_paths_from_folder(self) -> List[str]:
"""扫描所选文件夹,收集所有 .tif 和 .bsq 文件路径"""
folder = (self.geotiff_dir_edit.text() or "").strip()
if not folder or not os.path.isdir(folder):
return []
root = Path(folder)
tif_files = sorted(root.glob("*.tif"))
bsq_files = sorted(root.glob("*.bsq"))
return [str(p) for p in tif_files + bsq_files if p.is_file()]
def _step10_base_pipeline_kwargs(self) -> dict:
return {
'boundary_shp_path': self.boundary_file.get_path(),
'resolution': self.resolution.value(),
'input_crs': self.input_crs.text(),
'output_crs': self.output_crs.text(),
'show_sample_points': self.show_points.isChecked(),
'use_distance_diffusion': self.use_diffusion.isChecked(),
}
def get_config(self):
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
folder_mode = self.mode_folder_rb.isChecked()
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
geotiff_path = (self.geotiff_file.get_path() or "").strip()
config = {
'step10_batch_mode': 'folder' if folder_mode else 'single',
'render_mode': self.render_mode_combo.currentText(),
'prediction_csv_dir': pred_dir if pred_dir else None,
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
'geotiff_path': geotiff_path if geotiff_path else None,
'geotiff_dir': (self.geotiff_dir_edit.text() or "").strip() or None,
'boundary_shp_path': self.boundary_file.get_path(),
'resolution': self.resolution.value(),
'input_crs': self.input_crs.text(),
'output_crs': self.output_crs.text(),
'show_sample_points': self.show_points.isChecked(),
'use_distance_diffusion': self.use_diffusion.isChecked(),
}
out_dir = (self.output_dir.get_path() or "").strip()
if not folder_mode and pred_csv and out_dir:
stem = Path(pred_csv).stem
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
else:
config['output_image_path'] = None
return config
def set_config(self, config):
mode = config.get('step10_batch_mode', 'single')
if mode == 'folder':
self.mode_folder_rb.setChecked(True)
else:
self.mode_single_rb.setChecked(True)
render_mode = config.get('render_mode', 'CSV 插值模式')
idx = self.render_mode_combo.findText(render_mode)
if idx >= 0:
self.render_mode_combo.setCurrentIndex(idx)
if config.get('prediction_csv_dir'):
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
if 'recursive_csv_scan' in config:
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
if 'prediction_csv_path' in config and config['prediction_csv_path']:
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
if 'geotiff_path' in config and config['geotiff_path']:
self.geotiff_file.set_path(str(config['geotiff_path']))
if 'geotiff_dir' in config and config['geotiff_dir']:
self.geotiff_dir_edit.setText(str(config['geotiff_dir']))
if 'boundary_shp_path' in config:
self.boundary_file.set_path(config['boundary_shp_path'])
if 'resolution' in config:
self.resolution.setValue(config['resolution'])
if 'input_crs' in config:
self.input_crs.setText(config['input_crs'])
if 'output_crs' in config:
self.output_crs.setText(config['output_crs'])
if 'show_sample_points' in config:
self.show_points.setChecked(config['show_sample_points'])
if 'use_distance_diffusion' in config:
self.use_diffusion.setChecked(config['use_distance_diffusion'])
if 'output_dir' in config and config['output_dir']:
self.output_dir.set_path(str(config['output_dir']))
elif config.get('output_image_path'):
p = Path(str(config['output_image_path']))
if p.parent and str(p.parent) != '.':
self.output_dir.set_path(str(p.parent))
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充预测结果目录
优先使用 Step8机器学习预测的输出目录作为待预测 CSV 目录;
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
try:
import traceback
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
main_window = self.window()
if not main_window:
return
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
pred_dir = None
if hasattr(main_window, 'step11_prediction_panel'):
step8_widget = getattr(main_window.step11_prediction_panel, 'output_file', None)
step8_output = ""
if hasattr(step8_widget, 'get_path'):
step8_output = step8_widget.get_path() or ""
elif hasattr(step8_widget, 'text'):
step8_output = step8_widget.text() or ""
if step8_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step8_output):
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
# 提取父目录后追加 Machine_Learning_Prediction最底层真实子目录
base_pred_dir = str(Path(step8_output).parent)
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
# 2. 备选:从 Step11 界面读取非经验预测输出目录
if not pred_dir and hasattr(main_window, 'step11_panel'):
step8_5_widget = getattr(main_window.step11_panel, 'output_file', None)
step8_5_output = ""
if hasattr(step8_5_widget, 'get_path'):
step8_5_output = step8_5_widget.get_path() or ""
elif hasattr(step8_5_widget, 'text'):
step8_5_output = step8_5_widget.text() or ""
if step8_5_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step8_5_output):
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
pred_dir = str(Path(step8_5_output).parent)
# 3. 备选:从 Step12 界面读取自定义回归预测输出目录
if not pred_dir and hasattr(main_window, 'step12_panel'):
step8_75_widget = getattr(main_window.step12_panel, 'output_dir_widget', None)
step8_75_output = ""
if hasattr(step8_75_widget, 'get_path'):
step8_75_output = step8_75_widget.get_path() or ""
elif hasattr(step8_75_widget, 'text'):
step8_75_output = step8_75_widget.text() or ""
if step8_75_output:
pred_dir = step8_75_output
# 自动填入"预测CSV目录"(文件夹批量模式)
if pred_dir:
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
if not existing_dir:
self.prediction_csv_dir_edit.setText(pred_dir)
# 切换到文件夹批量模式
self.mode_folder_rb.setChecked(True)
# 4. 自动填充输出目录14_visualization
if self.work_dir:
output_dir = os.path.join(self.work_dir, "14_visualization")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_dir.get_path()
if not existing_out or not existing_out.strip():
self.output_dir.set_path(output_dir)
# 5. 自动探测原始矢量边界文件(.shp作为专题图底图
# 优先回溯 input-test/roi.shpgeopandas.read_file 仅支持矢量格式
if self.work_dir:
possible_shp = None
candidates = [
Path(self.work_dir).parent / "input-test" / "roi.shp",
Path(self.work_dir) / "roi.shp",
Path(self.work_dir).parent / "roi.shp",
]
for candidate in candidates:
if candidate.exists() and candidate.suffix.lower() == ".shp":
possible_shp = candidate
break
existing_boundary = (self.boundary_file.get_path() or "").strip()
if not existing_boundary and possible_shp:
self.boundary_file.set_path(str(possible_shp))
elif not existing_boundary:
self.boundary_file.set_path("")
print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
# 6. 自动探测 Step 8 输出的水色指数 GeoTIFFGeoTIFF 渲染模式)
step8_out_dir = Path(self.work_dir) / "8_WaterIndex_Images" if self.work_dir else None
if step8_out_dir and step8_out_dir.is_dir():
# GeoTIFF 批量模式:填充目录供批量渲染
if not (self.geotiff_dir_edit.text() or "").strip():
self.geotiff_dir_edit.setText(str(step8_out_dir))
# GeoTIFF 单文件模式:默认选中第一个
tif_files = sorted(step8_out_dir.glob("*.tif"))
if tif_files and not (self.geotiff_file.get_path() or "").strip():
self.geotiff_file.set_path(str(tif_files[0]))
except Exception as e:
import traceback
print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def browse_output_dir(self):
"""浏览输出目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "14_visualization")
dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
if dir_path:
self.output_dir.set_path(dir_path)
def _start_batch_run(self, csv_list, work_dir, base_kw, out_dir_opt, parent):
"""封装 CSV 批量启动逻辑,统一处理信号连接和进度条"""
self.run_button.setEnabled(False)
self.progress_bar.setVisible(True)
self.progress_bar.setValue(0)
self._batch_thread = Step11MapBatchThread(work_dir, csv_list, base_kw, out_dir_opt)
main_win = parent
def _batch_log(msg, lvl):
if hasattr(main_win, "log_message"):
main_win.log_message(msg, lvl)
def _on_progress(cur, total):
if total > 0:
self.progress_bar.setMaximum(total)
self.progress_bar.setValue(cur)
self.progress_bar.setFormat(f"{cur}/{total} 张 (%p%)")
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
self._batch_thread.progress.connect(_on_progress, Qt.QueuedConnection)
self._batch_thread.finished_ok.connect(self._on_step10_batch_ok, Qt.QueuedConnection)
self._batch_thread.failed.connect(self._on_step10_batch_fail, Qt.QueuedConnection)
self._batch_thread.finished.connect(
lambda: (self.run_button.setEnabled(True), self.progress_bar.setVisible(False)),
Qt.QueuedConnection,
)
self._batch_thread.start()
if hasattr(parent, "log_message"):
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV工作目录 {work_dir}", "info")
def run_step(self):
"""独立运行步骤11"""
if self._batch_thread and self._batch_thread.isRunning():
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
return
boundary_shp_path = self.boundary_file.get_path()
if not boundary_shp_path:
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
return
if not os.path.exists(boundary_shp_path):
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
return
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if not parent or not hasattr(parent, 'run_single_step'):
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
return
if self.mode_folder_rb.isChecked():
# -------- CSV 插值批量 --------
if self.render_mode_combo.currentText() != "GeoTIFF 栅格模式":
csv_list = self._collect_csv_paths_from_folder()
if not csv_list:
QMessageBox.warning(
self,
"输入验证失败",
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
"可勾选「包含子文件夹」以递归扫描。",
)
return
if not PIPELINE_AVAILABLE:
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
return
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
work_dir = str(work_dir)
base_kw = self._step10_base_pipeline_kwargs()
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
self._start_batch_run(csv_list, work_dir, base_kw, out_dir_opt, parent)
return
# -------- GeoTIFF 栅格批量 --------
tif_list = self._collect_tif_paths_from_folder()
if not tif_list:
QMessageBox.warning(
self,
"输入验证失败",
"所选文件夹中未找到 .tif / .bsq 文件,\n"
"请确认目录包含步骤8输出的水色指数 GeoTIFF 文件。",
)
return
out_dir = (self.output_dir.get_path() or "").strip()
if not out_dir:
out_dir = os.path.join(self._get_default_work_dir(), "14_visualization")
os.makedirs(out_dir, exist_ok=True)
self.run_button.setEnabled(False)
self.progress_bar.setVisible(True)
self.progress_bar.setValue(0)
self._batch_thread = Step11GeoTIFFBatchThread(
tif_paths=tif_list,
output_dir=out_dir,
boundary_shp_path=boundary_shp_path,
input_crs=self.input_crs.text(),
output_crs=self.output_crs.text(),
)
main_win = parent
def _batch_log(msg, lvl):
if hasattr(main_win, "log_message"):
main_win.log_message(msg, lvl)
def _on_progress(cur, total):
if total > 0:
pct = int(cur / total * 100)
self.progress_bar.setMaximum(total)
self.progress_bar.setValue(cur)
self.progress_bar.setFormat(f"{cur}/{total} 张 (%p%)")
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
self._batch_thread.progress.connect(_on_progress, Qt.QueuedConnection)
self._batch_thread.finished_ok.connect(self._on_step10_batch_ok, Qt.QueuedConnection)
self._batch_thread.failed.connect(self._on_step10_batch_fail, Qt.QueuedConnection)
self._batch_thread.finished.connect(
lambda: (self.run_button.setEnabled(True), self.progress_bar.setVisible(False)),
Qt.QueuedConnection,
)
self._batch_thread.start()
if hasattr(parent, "log_message"):
parent.log_message(f"GeoTIFF 批量渲染:共 {len(tif_list)} 个文件 → {out_dir}", "info")
return
# -------- GeoTIFF 栅格单文件模式 --------
if self.render_mode_combo.currentText() == "GeoTIFF 栅格模式":
geotiff_path = (self.geotiff_file.get_path() or "").strip()
if not geotiff_path:
QMessageBox.warning(self, "输入验证失败", "请选择水色指数 GeoTIFF 文件")
return
if not os.path.isfile(geotiff_path):
QMessageBox.warning(self, "输入验证失败", f"GeoTIFF 文件不存在:\n{geotiff_path}")
return
boundary_shp_path = self.boundary_file.get_path()
input_crs = self.input_crs.text()
output_crs = self.output_crs.text()
# 构造输出路径
out_dir = (self.output_dir.get_path() or "").strip()
if not out_dir:
out_dir = os.path.join(self._get_default_work_dir(), "14_visualization")
os.makedirs(out_dir, exist_ok=True)
tif_stem = Path(geotiff_path).stem
chinese_name = mapper._get_chinese_title(tif_stem)
output_png = os.path.join(out_dir, f"{chinese_name}_专题图.png")
self.run_button.setEnabled(False)
try:
from src.postprocessing.map import ContentMapper
mapper = ContentMapper()
result_path = mapper.visualize_raster(
raster_tif_path=geotiff_path,
output_file=output_png,
boundary_shp_path=boundary_shp_path if boundary_shp_path else None,
nodata_value=-9999.0,
figsize=(14, 10),
alpha=0.9,
)
self.run_button.setEnabled(True)
QMessageBox.information(
self, "完成",
f"GeoTIFF 栅格渲染完成!\n{result_path}"
)
if hasattr(parent, "log_message"):
parent.log_message(f"Step10 GeoTIFF 渲染完成 → {result_path}", "info")
except Exception as e:
self.run_button.setEnabled(True)
QMessageBox.critical(self, "渲染失败", f"{e}\n{traceback.format_exc()[:500]}")
if hasattr(parent, "log_message"):
parent.log_message(str(e), "error")
return
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
if not prediction_csv_path:
QMessageBox.warning(
self,
"输入验证失败",
"请选择「预测结果 CSV」文件或切换到「文件夹批量」。",
)
return
if not os.path.isfile(prediction_csv_path):
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
return
config = self.get_config()
parent.run_single_step('step11_map', {'step11_map': config})
def _on_step10_batch_ok(self, n: int):
self.progress_bar.setVisible(False)
QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
parent = self.parent()
while parent and not hasattr(parent, "log_message"):
parent = parent.parent()
if parent and hasattr(parent, "log_message"):
parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
def _on_step10_batch_fail(self, err: str):
self.progress_bar.setVisible(False)
QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
parent = self.parent()
while parent and not hasattr(parent, "log_message"):
parent = parent.parent()
if parent and hasattr(parent, "log_message"):
parent.log_message(err, "error")

View File

@ -1211,8 +1211,8 @@ class ChartBrowserDialog(QDialog):
QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}") QMessageBox.critical(self, "错误", f"保存失败:\n{str(e)}")
class Step11VizPanel(QWidget): class Step12VizPanel(QWidget):
"""步骤11:可视化展示""" """步骤12:可视化展示"""
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self.work_dir = None self.work_dir = None

View File

@ -225,6 +225,6 @@ class Step12Panel(QWidget):
parent = parent.parent() parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'): if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step12', {'step12': config}) parent.run_single_step('step13_report', {'step13_report': config})
else: else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象") QMessageBox.critical(self, "错误", "无法找到父级GUI对象")

View File

@ -79,8 +79,8 @@ class ReportGenerateThread(QThread):
self.failed.emit(f"{e}\n{traceback.format_exc()}") self.failed.emit(f"{e}\n{traceback.format_exc()}")
class Step12ReportPanel(QWidget): class Step13ReportPanel(QWidget):
"""步骤12分析报告生成。AI 配置统一由 AISettingsDialog 管理,本面板不持有配置状态。""" """步骤13分析报告生成。AI 配置统一由 AISettingsDialog 管理,本面板不持有配置状态。"""
def __init__(self, main_window=None, parent=None): def __init__(self, main_window=None, parent=None):
super().__init__(parent) super().__init__(parent)

View File

@ -223,8 +223,8 @@ class Step4SamplingPanel(QWidget):
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = {'step4': self.get_config()} config = {'step4_sampling': self.get_config()}
main_window.run_single_step('step4', config) main_window.run_single_step('step4_sampling', config)
def _check_csv_exists(self): def _check_csv_exists(self):
"""检查 output csv 是否存在,驱动预览按钮启停""" """检查 output csv 是否存在,驱动预览按钮启停"""

View File

@ -143,8 +143,8 @@ class Step5CleanPanel(QWidget):
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = {'step5': self.get_config()} config = {'step5_clean': self.get_config()}
main_window.run_single_step('step5', config) main_window.run_single_step('step5_clean', config)
def reset_preview(self, message="请选择CSV文件并点击刷新预览"): def reset_preview(self, message="请选择CSV文件并点击刷新预览"):
"""重置预览表格""" """重置预览表格"""

View File

@ -235,5 +235,5 @@ class Step6FeaturePanel(QWidget):
# 获取主窗口并运行步骤 # 获取主窗口并运行步骤
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = {'step5': self.get_config()} config = {'step6_feature': self.get_config()}
main_window.run_single_step('step5', config) main_window.run_single_step('step6_feature', config)

View File

@ -5,345 +5,411 @@ Step8 面板 - 机器学习建模
""" """
import os import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QGridLayout, QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
QScrollArea, QListWidget, QListWidgetItem, QAbstractItemView, QPushButton, QFileDialog, QMessageBox,
QRadioButton, QButtonGroup
) )
from PyQt5.QtCore import Qt from PyQt5.QtCore import Qt
from PyQt5.QtGui import QColor, QBrush, QFont
from src.gui.components.custom_widgets import FileSelectWidget from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet from src.gui.styles import ModernStylesheet
def get_resource_path(relative_path: str) -> str: # ============================================================
"""适配开发与 PyInstaller 环境的路径获取逻辑。""" # 中文映射表(内部键名 -> 显示文本)
if hasattr(sys, '_MEIPASS'): # ============================================================
internal = os.path.join(sys._MEIPASS, '_internal', relative_path)
if os.path.exists(internal):
return internal
return os.path.join(sys._MEIPASS, relative_path)
exe_dir = os.path.dirname(sys.executable) # 预处理方法:内部键 -> 显示文本
internal = os.path.join(exe_dir, '_internal', relative_path) PREPROC_CHINESE = {
if os.path.exists(internal): 'None': '无 (None)',
return internal 'MMS': '最小-最大归一化 (MMS)',
'SS': '标度化 (SS)',
'SNV': '标准正态变换 (SNV)',
'MA': '移动平均 (MA)',
'SG': 'Savitzky-Golay (SG)',
'MSC': '多元散射校正 (MSC)',
'D1': '一阶导数 (D1)',
'D2': '二阶导数 (D2)',
'DT': '去趋势 (DT)',
'CT': '中心化 (CT)',
}
base_dir = Path(__file__).resolve().parent.parent / "model" # 模型类型:内部键 -> 显示文本
return str(base_dir / os.path.basename(relative_path)) MODEL_CHINESE = {
# 线性模型
'LinearRegression': '多元线性回归 (MLR)',
'Ridge': '岭回归 (Ridge)',
'Lasso': '套索回归 (Lasso)',
'ElasticNet': '弹性网络 (ElasticNet)',
'PLS': '偏最小二乘 (PLSR)',
# 树模型
'DecisionTree': '决策树 (CART)',
'RF': '随机森林 (RF)',
'ExtraTrees': '极端随机树 (ET)',
'XGBoost': '极值梯度提升 (XGBoost)',
'LightGBM': '轻量梯度提升 (LightGBM)',
'CatBoost': '类别梯度提升 (CatBoost)',
# 集成学习
'GradientBoosting': '梯度提升树 (GBDT)',
'AdaBoost': '自适应提升 (AdaBoost)',
# 其他模型
'SVR': '支持向量回归 (SVR)',
'KNN': 'K近邻回归 (KNN)',
'MLP': '多层感知机 (BP神经网络)',
}
# 数据划分方法:内部键 -> 显示文本
SPLIT_CHINESE = {
'spxy': 'SPXY 算法 (考量X-Y空间)',
'ks': 'KS 算法 (考量X空间)',
'random': '随机划分 (Random)',
}
class Step8MlTrainPanel(QWidget): class Step8MlTrainPanel(QWidget):
"""步骤8机器学习建模""" """步骤8机器学习建模"""
COLOR_RATIO = QColor(255, 255, 255)
COLOR_CONCENTRATION = QColor(220, 240, 255)
COLOR_HEADER = QColor(245, 245, 245)
def __init__(self, parent=None): def __init__(self, parent=None):
super().__init__(parent) super().__init__(parent)
self.index_checkboxes: Dict[str, QListWidgetItem] = {}
self.work_dir: Optional[str] = None
self.builtin_formula_path = get_resource_path("waterindex.csv")
self._formula_type_map: Dict[str, str] = {}
self._formula_color_map: Dict[str, QColor] = {}
self._formula_coef_map: Dict[str, List[float]] = {}
self.init_ui() self.init_ui()
self._auto_load_formulas()
def init_ui(self): def init_ui(self):
main_layout = QVBoxLayout() layout = QVBoxLayout()
main_layout.setContentsMargins(20, 20, 20, 20)
main_layout.setSpacing(10)
# 1. 公式配置源 (只读) # 标题
path_group = QGroupBox("公式配置源 (内置)")
path_layout = QVBoxLayout()
self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
self.formula_csv_widget.set_path(self.builtin_formula_path)
self.formula_csv_widget.set_read_only(True)
self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
path_layout.addWidget(self.formula_csv_widget)
path_group.setLayout(path_layout)
main_layout.addWidget(path_group)
# 2. 训练数据输入
input_group = QGroupBox("输入样本数据")
input_layout = QVBoxLayout()
self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
input_layout.addWidget(self.training_data_widget)
input_group.setLayout(input_layout)
main_layout.addWidget(input_group)
# 3. 公式选择区 (分组 ListWidget) # 训练数据文件(用于独立运行)
self.formula_group = QGroupBox("待计算水质指数勾选") self.training_csv_file = FileSelectWidget(
formula_outer_layout = QVBoxLayout() "训练数据:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.training_csv_file)
btn_layout = QHBoxLayout() # 机器学习模型页面
self.select_all_btn = QPushButton("全选") self.ml_page = QWidget()
self.deselect_all_btn = QPushButton("清空") self.create_ml_page()
self.select_ratio_btn = QPushButton("仅选比值型") layout.addWidget(self.ml_page)
self.select_conc_btn = QPushButton("仅选浓度型")
self.select_all_btn.clicked.connect(self.select_all_formulas)
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
self.select_ratio_btn.clicked.connect(self._select_ratio_only)
self.select_conc_btn.clicked.connect(self._select_conc_only)
btn_layout.addWidget(self.select_all_btn)
btn_layout.addWidget(self.deselect_all_btn)
btn_layout.addWidget(self.select_ratio_btn)
btn_layout.addWidget(self.select_conc_btn)
btn_layout.addStretch()
self.refresh_button = QPushButton("重新加载") # 输出文件路径
self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False)) self.output_path = FileSelectWidget(
btn_layout.addWidget(self.refresh_button) "输出文件:",
"CSV Files (*.csv);;All Files (*.*)",
mode="save"
)
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
self.output_path.browse_btn.clicked.disconnect()
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
layout.addWidget(self.output_path)
formula_outer_layout.addLayout(btn_layout) # 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(False)
layout.addWidget(self.enable_checkbox)
scroll = QScrollArea() # 独立运行按钮
scroll.setWidgetResizable(True) self.run_btn = QPushButton("独立运行此步骤")
scroll.setMinimumHeight(280) self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.scroll_content = QWidget() self.run_btn.clicked.connect(self.run_step)
self.formula_layout = QVBoxLayout(self.scroll_content) layout.addWidget(self.run_btn)
self.formula_layout.setContentsMargins(4, 4, 4, 4)
self.formula_layout.setSpacing(2)
self.formula_layout.setAlignment(Qt.AlignTop)
self.formula_list = QListWidget() layout.addStretch()
self.formula_list.setSelectionMode(QAbstractItemView.MultiSelection) self.setLayout(layout)
self.formula_list.itemChanged.connect(self._on_item_changed)
self.formula_layout.addWidget(self.formula_list)
scroll.setWidget(self.scroll_content) def create_ml_page(self):
formula_outer_layout.addWidget(scroll) """创建机器学习模型页面"""
layout = QVBoxLayout()
self.formula_group.setLayout(formula_outer_layout) # 参数设置
main_layout.addWidget(self.formula_group) params_group = QGroupBox("训练参数")
params_layout = QFormLayout()
# 4. 输出选项 self.feature_start = QLineEdit()
output_group = QGroupBox("输出模式") self.feature_start.setText("374.285004")
output_layout = QVBoxLayout() params_layout.addRow("特征起始列:", self.feature_start)
mode_layout = QHBoxLayout() self.cv_folds = QSpinBox()
self.mode_group = QButtonGroup() self.cv_folds.setRange(2, 10)
self.radio_both = QRadioButton("两者皆出") self.cv_folds.setValue(3)
self.radio_wide = QRadioButton("仅宽表") params_layout.addRow("交叉验证折数:", self.cv_folds)
self.radio_single = QRadioButton("仅单文件")
self.mode_group.addButton(self.radio_both, 0)
self.mode_group.addButton(self.radio_wide, 1)
self.mode_group.addButton(self.radio_single, 2)
self.radio_both.setChecked(True)
mode_layout.addWidget(self.radio_both)
mode_layout.addWidget(self.radio_wide)
mode_layout.addWidget(self.radio_single)
mode_layout.addStretch()
output_layout.addLayout(mode_layout)
self.enable_checkbox = QCheckBox("启用计算流程") params_group.setLayout(params_layout)
self.enable_checkbox.setChecked(True) layout.addWidget(params_group)
output_layout.addWidget(self.enable_checkbox)
output_group.setLayout(output_layout) # 预处理方法 - 多选
main_layout.addWidget(output_group) preproc_group = QGroupBox("预处理方法 (可多选)")
preproc_layout = QVBoxLayout()
# 5. 运行按钮 preproc_grid = QGridLayout()
self.run_button = QPushButton("立即执行计算") self.preproc_checkboxes = {}
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success')) preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
self.run_button.setMinimumHeight(40)
self.run_button.clicked.connect(self.run_step)
main_layout.addWidget(self.run_button)
self.setLayout(main_layout) for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
checkbox.setChecked(False)
self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4)
def _on_item_changed(self, item: QListWidgetItem): button_layout = QHBoxLayout()
if item.checkState() == Qt.Checked: select_all_btn = QPushButton("全选")
bg_color = self.COLOR_RATIO deselect_all_btn = QPushButton("全不选")
for name, ref_item in self.index_checkboxes.items(): select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
if ref_item is item: deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
bg_color = self._formula_color_map.get(name, self.COLOR_RATIO) button_layout.addWidget(select_all_btn)
break button_layout.addWidget(deselect_all_btn)
item.setBackground(QBrush(bg_color)) button_layout.addStretch()
else:
item.setBackground(QBrush(self.COLOR_RATIO))
def _auto_load_formulas(self): preproc_layout.addLayout(preproc_grid)
if os.path.exists(self.builtin_formula_path): preproc_layout.addLayout(button_layout)
self.refresh_formulas(silent=True) preproc_group.setLayout(preproc_layout)
else: layout.addWidget(preproc_group)
print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
def refresh_formulas(self, silent=False): # 模型选择 - 多选
path = self.builtin_formula_path model_group = QGroupBox("模型类型 (可多选)")
if not os.path.exists(path): model_layout = QVBoxLayout()
if not silent:
QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
return
try: model_grid = QGridLayout()
df = None self.model_checkboxes = {}
for enc in ('utf-8', 'gbk', 'utf-8-sig'):
try:
df = pd.read_csv(path, encoding=enc)
if 'Formula_Name' in df.columns:
break
except Exception:
continue
if df is None or 'Formula_Name' not in df.columns: model_groups = [
if not silent: ("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
QMessageBox.critical(self, "错误", "CSV缺少 'Formula_Name'") ("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
return ("【集成学习】", ['GradientBoosting', 'AdaBoost']),
("【其他模型】", ['SVR', 'KNN', 'MLP'])
self._formula_type_map.clear()
self._formula_coef_map.clear()
for _, row in df.iterrows():
name = str(row['Formula_Name']).strip()
if not name:
continue
ftype = str(row.get('Formula_Type', 'ratio')).strip().lower()
self._formula_type_map[name] = ftype
coef_str = str(row.get('Coefficient', '')).strip()
if coef_str:
try:
coeffs = [float(c.strip()) for c in coef_str.split(',') if c.strip()]
self._formula_coef_map[name] = coeffs
except Exception:
self._formula_coef_map[name] = []
else:
self._formula_coef_map[name] = []
self.formula_list.clear()
self.index_checkboxes.clear()
self._formula_color_map.clear()
for name, ftype in self._formula_type_map.items():
item = QListWidgetItem(name, self.formula_list)
item.setCheckState(Qt.Checked)
if ftype == 'concentration':
bg_color = QColor(220, 240, 255)
else:
bg_color = self.COLOR_RATIO
self._formula_color_map[name] = bg_color
item.setBackground(QBrush(bg_color))
self.index_checkboxes[name] = item
self.formula_list.adjustSize()
print(f"✅ 加载 {len(self.index_checkboxes)} 个公式")
except Exception as e:
if not silent:
QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
def _select_ratio_only(self):
for name, item in self.index_checkboxes.items():
ftype = self._formula_type_map.get(name, 'ratio')
item.setCheckState(Qt.Checked if ftype == 'ratio' else Qt.Unchecked)
def _select_conc_only(self):
for name, item in self.index_checkboxes.items():
ftype = self._formula_type_map.get(name, 'ratio')
item.setCheckState(Qt.Checked if ftype == 'concentration' else Qt.Unchecked)
def select_all_formulas(self):
for item in self.index_checkboxes.values():
item.setCheckState(Qt.Checked)
def deselect_all_formulas(self):
for item in self.index_checkboxes.values():
item.setCheckState(Qt.Unchecked)
def get_config(self) -> Dict:
selected = [
name for name, item in self.index_checkboxes.items()
if item.checkState() == Qt.Checked
] ]
formula_coefficients = {
name: self._formula_coef_map.get(name, [])
for name in selected
}
return {
'training_csv_path': self.training_data_widget.get_path(),
'formula_csv_file': self.builtin_formula_path,
'formula_names': selected,
'formula_coefficients': formula_coefficients,
'enabled': self.enable_checkbox.isChecked(),
'output_mode': self.mode_group.checkedId(),
}
def set_config(self, config: Dict): row = 0
for group_name, models in model_groups:
group_label = QLabel(f"<b>{group_name}</b>")
group_label.setStyleSheet(
f"background-color: {ModernStylesheet.COLORS['hover']}; "
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
f"border-radius: 3px;"
)
model_grid.addWidget(group_label, row, 0, 1, 4)
row += 1
for i, model in enumerate(models):
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
checkbox.setChecked(False)
self.model_checkboxes[model] = checkbox
model_grid.addWidget(checkbox, row, i % 4)
if (i + 1) % 4 == 0:
row += 1
row += 1
model_button_layout = QHBoxLayout()
model_select_all = QPushButton("全选")
model_deselect_all = QPushButton("全不选")
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
model_button_layout.addWidget(model_select_all)
model_button_layout.addWidget(model_deselect_all)
model_button_layout.addStretch()
model_layout.addLayout(model_grid)
model_layout.addLayout(model_button_layout)
model_group.setLayout(model_layout)
layout.addWidget(model_group)
# 数据划分方法 - 多选
split_group = QGroupBox("数据划分方法 (可多选)")
split_layout = QVBoxLayout()
split_grid = QGridLayout()
self.split_checkboxes = {}
split_methods = ['spxy', 'ks', 'random']
for i, method in enumerate(split_methods):
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
checkbox.setChecked(False)
self.split_checkboxes[method] = checkbox
split_grid.addWidget(checkbox, 0, i)
split_button_layout = QHBoxLayout()
split_select_all = QPushButton("全选")
split_deselect_all = QPushButton("全不选")
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
split_button_layout.addWidget(split_select_all)
split_button_layout.addWidget(split_deselect_all)
split_button_layout.addStretch()
split_layout.addLayout(split_grid)
split_layout.addLayout(split_button_layout)
split_group.setLayout(split_layout)
layout.addWidget(split_group)
self.ml_page.setLayout(layout)
def _toggle_checkboxes(self, checkboxes_dict, checked):
"""统一设置checkbox状态"""
for checkbox in checkboxes_dict.values():
checkbox.setChecked(checked)
def _get_default_work_dir(self):
"""获取 work_dir优先用 panel 自身缓存的,否则尝试从主窗口取"""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_output_path(self):
"""浏览输出文件路径(保存对话框)"""
current = self.output_path.get_path().strip()
if current:
initial_dir = os.path.dirname(current)
initial_file = os.path.basename(current)
else:
initial_dir = ""
initial_file = ""
if not initial_dir or not os.path.isdir(initial_dir):
# 默认定位到 indices 目录
work_dir = self._get_default_work_dir()
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
if initial_dir and not os.path.isdir(initial_dir):
os.makedirs(initial_dir, exist_ok=True)
file_path, _ = QFileDialog.getSaveFileName(
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
"CSV Files (*.csv);;All Files (*.*)"
)
if file_path:
self.output_path.set_path(file_path)
def get_config(self):
"""获取配置"""
preprocessing_methods = [
method for method, checkbox in self.preproc_checkboxes.items()
if checkbox.isChecked()
]
model_names = [
model for model, checkbox in self.model_checkboxes.items()
if checkbox.isChecked()
]
split_methods = [
method for method, checkbox in self.split_checkboxes.items()
if checkbox.isChecked()
]
config = {
'feature_start_column': self.feature_start.text(),
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
'model_names': model_names if model_names else ['SVR'],
'split_methods': split_methods if split_methods else ['random'],
'cv_folds': self.cv_folds.value()
}
training_csv_path = self.training_csv_file.get_path()
if training_csv_path:
config['training_csv_path'] = training_csv_path
output_path = self.output_path.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'feature_start_column' in config:
self.feature_start.setText(str(config['feature_start_column']))
if 'cv_folds' in config:
self.cv_folds.setValue(config['cv_folds'])
if 'preprocessing_methods' in config:
methods = config['preprocessing_methods']
for method, checkbox in self.preproc_checkboxes.items():
checkbox.setChecked(method in methods)
if 'model_names' in config:
models = config['model_names']
for model, checkbox in self.model_checkboxes.items():
checkbox.setChecked(model in models)
if 'split_methods' in config:
methods = config['split_methods']
for method, checkbox in self.split_checkboxes.items():
checkbox.setChecked(method in methods)
if 'training_csv_path' in config: if 'training_csv_path' in config:
self.training_data_widget.set_path(config['training_csv_path']) self.training_csv_file.set_path(config['training_csv_path'])
if 'formula_names' in config: if 'output_path' in config:
sel = set(config['formula_names']) self.output_path.set_path(config['output_path'])
for name, item in self.index_checkboxes.items():
item.setCheckState(Qt.Checked if name in sel else Qt.Unchecked)
self.enable_checkbox.setChecked(config.get('enabled', True))
if 'output_mode' in config:
btn = self.mode_group.button(config['output_mode'])
if btn:
btn.setChecked(True)
def update_from_config(self, work_dir=None, pipeline=None): def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
if work_dir: if work_dir:
self.work_dir = work_dir self.work_dir = work_dir
main = self.window() elif hasattr(self, 'work_dir') and self.work_dir:
if hasattr(main, 'step5_panel'): pass
p5 = main.step5_panel.output_file.get_path() else:
if p5: self.work_dir = None
if not os.path.isabs(p5):
p5 = os.path.join(self.work_dir or '', p5)
p5 = p5.replace('\\', '/')
self.training_data_widget.set_path(p5)
def _get_work_dir(self) -> Optional[str]: # 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
main_window = self.window()
if hasattr(main_window, 'step5_panel'):
# 优先直接从 Step5 的输出 widget 读取
step5_output = main_window.step5_panel.output_file.get_path()
if step5_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_output):
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
self.training_csv_file.set_path(step5_output)
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
# 回退:从 Step5 的 config 字典中查找可能的键名
step5_cfg = main_window.step5_panel.get_config()
step5_csv = (
step5_cfg.get('training_csv_path')
or step5_cfg.get('output_file')
or step5_cfg.get('csv_path')
or step5_cfg.get('output_csv')
)
if step5_csv:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_csv):
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
self.training_csv_file.set_path(step5_csv)
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
if self.work_dir: if self.work_dir:
return self.work_dir indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
main = self.window() os.makedirs(indices_dir, exist_ok=True)
if hasattr(main, 'work_dir') and main.work_dir: training_csv = self.training_csv_file.get_path()
return main.work_dir if training_csv:
return None basename = os.path.splitext(os.path.basename(training_csv))[0]
output_file = f"{basename}_indices.csv"
def _get_coord_cols(self, df: pd.DataFrame) -> Tuple[str, str]: else:
coord_candidates = ['lon', 'lng', 'longitude', '经度', 'x', 'lon_utm', 'utm_x', 'pixel_x'] output_file = "water_quality_indices.csv"
lat_candidates = ['lat', 'latitude', '纬度', 'y', 'lat_utc', 'utm_y', 'pixel_y'] output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
self.output_path.set_path(output_path)
x_col, y_col = None, None else:
for col in df.columns: self.output_path.set_path("")
cl = col.lower()
if x_col is None and any(c in cl for c in coord_candidates):
x_col = col
if y_col is None and any(c in cl for c in lat_candidates):
y_col = col
if x_col is None and len(df.columns) >= 2:
x_col = df.columns[0]
if y_col is None and len(df.columns) >= 2:
y_col = df.columns[1]
return x_col or 'x_coord', y_col or 'y_coord'
def run_step(self): def run_step(self):
config = self.get_config() """独立运行步骤8"""
training_csv_path = self.training_csv_file.get_path()
if not config['enabled']: if not training_csv_path:
QMessageBox.information(self, "提示", "已禁用计算流程(启用计算流程未勾选)") QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件")
return
training_path = config['training_csv_path']
if not training_path or not os.path.exists(training_path):
QMessageBox.warning(self, "提示", "请先选择输入特征提取CSV文件")
return return
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
pipeline_config = {'step8_ml_train': config} config = {'step8_ml_train': self.get_config()}
main_window.run_single_step('step8_ml_train', pipeline_config) main_window.run_single_step('step8_ml_train', config)
def get_training_params(self):
"""获取模型训练参数"""
return {
'pipeline_type': 'machine_learning',
'feature_start': float(self.feature_start.text()),
'cv_folds': self.cv_folds.value(),
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
}

View File

@ -439,11 +439,11 @@ class Step9MlPredictPanel(QWidget):
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = { config = {
'step11_ml': self.get_config(), 'step9_ml_predict': self.get_config(),
'_external_models_dict': checked_dict, '_external_models_dict': checked_dict,
'_external_model_dir': self.external_model_dir, '_external_model_dir': self.external_model_dir,
} }
main_window.run_single_step('step11_ml', config) main_window.run_single_step('step9_ml_predict', config)
return return
# 默认流程:使用模型目录 # 默认流程:使用模型目录
@ -454,5 +454,5 @@ class Step9MlPredictPanel(QWidget):
main_window = self.window() main_window = self.window()
if hasattr(main_window, 'run_single_step'): if hasattr(main_window, 'run_single_step'):
config = {'step11_ml': self.get_config()} config = {'step9_ml_predict': self.get_config()}
main_window.run_single_step('step11_ml', config) main_window.run_single_step('step9_ml_predict', config)

View File

@ -117,17 +117,17 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.panels.step1_panel import Step1Panel from src.gui.panels.step1_panel import Step1Panel
from src.gui.panels.step2_panel import Step2Panel from src.gui.panels.step2_panel import Step2Panel
from src.gui.panels.step3_panel import Step3Panel from src.gui.panels.step3_panel import Step3Panel
from src.gui.panels.step4_sampling_panel import Step4SamplingPanel # 采样点布设原step10→新step4 from src.gui.panels.step4_sampling_panel import Step4SamplingPanel # 采样点布设
from src.gui.panels.step5_clean_panel import Step5CleanPanel # 数据清洗 from src.gui.panels.step5_clean_panel import Step5CleanPanel # 数据清洗
from src.gui.panels.step6_feature_panel import Step6FeaturePanel # 光谱特征 from src.gui.panels.step6_feature_panel import Step6FeaturePanel # 光谱特征
from src.gui.panels.step7_index_panel import Step7IndexPanel # 水质光谱指数 from src.gui.panels.step7_index_panel import Step7IndexPanel # 水质光谱指数
from src.gui.panels.step10_map_panel import Step10MapPanel # 水色指数反演 from src.gui.panels.step10_watercolor_panel import Step10WatercolorPanel # 水色指数反演
from src.gui.panels.step8_ml_train_panel import Step8MlTrainPanel # 机器学习建模 from src.gui.panels.step8_ml_train_panel import Step8MlTrainPanel # 机器学习建模
from src.gui.panels.step9_ml_predict_panel import Step9MlPredictPanel # 机器学习预测 from src.gui.panels.step9_ml_predict_panel import Step9MlPredictPanel # 机器学习预测
from src.gui.panels.step14_panel import Step14Panel
from src.gui.dialogs import BandConfirmDialog, AISettingsDialog from src.gui.dialogs import BandConfirmDialog, AISettingsDialog
from src.gui.panels.step11_viz_panel import Step11VizPanel # 可视化(覆盖旧 step11_viz_panel.py from src.gui.panels.step11_map_panel import Step11MapPanel # 专题图生成
from src.gui.panels.step12_report_panel import Step12ReportPanel # 报告生成 from src.gui.panels.step12_viz_panel import Step12VizPanel # 可视化
from src.gui.panels.step13_report_panel import Step13ReportPanel # 报告生成
# Pipeline 核心异常(用于预检弹窗) # Pipeline 核心异常(用于预检弹窗)
from src.core.pipeline.runner import PipelineHalt from src.core.pipeline.runner import PipelineHalt
@ -1399,7 +1399,7 @@ class WaterQualityGUI(QMainWindow):
'step9_ml_predict': { 'step9_ml_predict': {
'predictions': '11_12_13_predictions/Machine_Learning_Prediction/' 'predictions': '11_12_13_predictions/Machine_Learning_Prediction/'
}, },
'step10_map': { 'step11_map': {
'distribution_maps': '14_visualization/' 'distribution_maps': '14_visualization/'
} }
} }
@ -1435,7 +1435,7 @@ class WaterQualityGUI(QMainWindow):
'sampling_csv_path': ('step4_sampling', 'sampling_points', 'sampling_csv_file'), 'sampling_csv_path': ('step4_sampling', 'sampling_points', 'sampling_csv_file'),
'models_dir': ('step8_ml_train', 'models', 'models_dir_file') 'models_dir': ('step8_ml_train', 'models', 'models_dir_file')
}, },
'step10_map': { 'step11_map': {
'prediction_csv_path': ('step9_ml_predict', 'predictions', 'prediction_csv_file') 'prediction_csv_path': ('step9_ml_predict', 'predictions', 'prediction_csv_file')
} }
} }
@ -1826,9 +1826,10 @@ class WaterQualityGUI(QMainWindow):
], ],
"阶段四:预测与成果输出": [ "阶段四:预测与成果输出": [
("step9_ml_predict", "9. 机器学习预测"), ("step9_ml_predict", "9. 机器学习预测"),
("step10_map", "10. 专题图生成"), ("step10_watercolor", "10. 水色指数反演"),
("step11_viz", "11. 可视化展示"), ("step11_map", "11. 专题图生成"),
("step12_report", "12. 分析报告生成") ("step12_viz", "12. 可视化展示"),
("step13_report", "13. 分析报告生成")
] ]
} }
@ -1936,17 +1937,17 @@ class WaterQualityGUI(QMainWindow):
self.step9_ml_predict_panel = Step9MlPredictPanel() self.step9_ml_predict_panel = Step9MlPredictPanel()
self.step_stack.addTab(self.create_scroll_area(self.step9_ml_predict_panel), QIcon(self.get_icon_path("10.png")), "机器学习预测") self.step_stack.addTab(self.create_scroll_area(self.step9_ml_predict_panel), QIcon(self.get_icon_path("10.png")), "机器学习预测")
self.step10_map_panel = Step10MapPanel() self.step10_watercolor_panel = Step10WatercolorPanel()
self.step_stack.addTab(self.create_scroll_area(self.step10_map_panel), QIcon(self.get_icon_path("10.png")), "专题图生成") self.step_stack.addTab(self.create_scroll_area(self.step10_watercolor_panel), QIcon(self.get_icon_path("10.png")), "水色指数反演")
self.step14_panel = Step14Panel() self.step11_map_panel = Step11MapPanel()
self.step_stack.addTab(self.create_scroll_area(self.step14_panel), QIcon(self.get_icon_path("11.png")), "专题图生成") self.step_stack.addTab(self.create_scroll_area(self.step11_map_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
self.step11_viz_panel = Step11VizPanel() self.step12_viz_panel = Step12VizPanel()
self.step_stack.addTab(self.create_scroll_area(self.step11_viz_panel), QIcon(self.get_icon_path("9.png")), "可视化") self.step_stack.addTab(self.create_scroll_area(self.step12_viz_panel), QIcon(self.get_icon_path("9.png")), "可视化")
self.step12_report_panel = Step12ReportPanel(main_window=self) self.step13_report_panel = Step13ReportPanel(main_window=self)
self.step_stack.addTab(self.create_scroll_area(self.step12_report_panel), QIcon(self.get_icon_path("10.png")), "报告生成") self.step_stack.addTab(self.create_scroll_area(self.step13_report_panel), QIcon(self.get_icon_path("10.png")), "报告生成")
# 连接Tab切换信号实现双向同步必须在step_stack创建后 # 连接Tab切换信号实现双向同步必须在step_stack创建后
self.step_stack.currentChanged.connect(self.on_tab_changed) self.step_stack.currentChanged.connect(self.on_tab_changed)
@ -2085,18 +2086,11 @@ class WaterQualityGUI(QMainWindow):
# 根据步骤ID查找对应的tab索引 # 根据步骤ID查找对应的tab索引
step_id_to_tab = { step_id_to_tab = {
'step1': 0, 'step1': 0, 'step2': 1, 'step3': 2, 'step4_sampling': 3,
'step2': 1, 'step5_clean': 4, 'step6_feature': 5, 'step7_index': 6,
'step3': 2, 'step8_ml_train': 7, 'step9_ml_predict': 8,
'step4_sampling': 3, 'step10_watercolor': 9, 'step11_map': 10,
'step5_clean': 4, 'step12_viz': 11, 'step13_report': 12,
'step6_feature': 5,
'step7_index': 6,
'step8_ml_train': 7,
'step9_ml_predict': 8,
'step10_map': 9,
'step11_viz': 11,
'step12_report': 12,
} }
if item_data in step_id_to_tab: if item_data in step_id_to_tab:
@ -2110,21 +2104,13 @@ class WaterQualityGUI(QMainWindow):
if index < 0: if index < 0:
return return
# Tab索引到步骤ID的反向映射 # Tab索引到步骤ID的反向映射13个Tabindex 0-12
tab_to_step_id = { tab_to_step_id = {
0: 'step1', 0: 'step1', 1: 'step2', 2: 'step3', 3: 'step4_sampling',
1: 'step2', 4: 'step5_clean', 5: 'step6_feature', 6: 'step7_index',
2: 'step3', 7: 'step8_ml_train', 8: 'step9_ml_predict',
3: 'step4_sampling', 9: 'step10_watercolor', 10: 'step11_map',
4: 'step5_clean', 11: 'step12_viz', 12: 'step13_report',
5: 'step6_feature',
6: 'step7_index',
7: 'step8_ml_train',
8: 'step9_ml_predict',
9: 'step10_map',
10: None, # 遗留 step14_panel保留 tab 但不加入 process_stages
11: 'step11_viz',
12: 'step12_report',
} }
if index not in tab_to_step_id: if index not in tab_to_step_id:
@ -2143,49 +2129,27 @@ class WaterQualityGUI(QMainWindow):
self.step_list.setCurrentRow(row) self.step_list.setCurrentRow(row)
break break
# Step2 切换时自动填充数据流转路径 # 面板自动填充:统一 mapping 覆盖 index 0-12
if index == 1: mapping = {
self.step2_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) 0: (self.step1_panel, "Step1"),
1: (self.step2_panel, "Step2"),
2: (self.step3_panel, "Step3"),
3: (self.step4_sampling_panel, "Step4"),
4: (self.step5_clean_panel, "Step5"),
5: (self.step6_feature_panel, "Step6"),
6: (self.step7_index_panel, "Step7"),
7: (self.step8_ml_train_panel, "Step8"),
8: (self.step9_ml_predict_panel, "Step9"),
9: (self.step10_watercolor_panel, "Step10"), # 水色指数反演
10: (self.step11_map_panel, "Step11"), # 专题图生成
11: (self.step12_viz_panel, "Step12"),
12: (self.step13_report_panel, "Step13")
}
# Step3 切换时自动填充数据流转路径 if index in mapping:
elif index == 2: panel, _ = mapping[index]
self.step3_panel.update_from_config(work_dir=self.work_dir) if hasattr(panel, 'update_from_config'):
panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step4采样点布设切换时自动填充输出路径
elif index == 3:
self.step4_sampling_panel.update_from_config(work_dir=self.work_dir)
# Step5数据清洗切换时自动填充数据流转路径
elif index == 4:
self.step5_clean_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step6光谱特征切换时自动填充输出路径
elif index == 5:
self.step6_feature_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step7水质光谱指数计算切换时自动填充水质参数 CSV
elif index == 6:
self.step7_index_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step8机器学习建模切换时自动填充训练数据和输出路径
elif index == 7:
self.step8_ml_train_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step9机器学习预测切换时自动填充采样光谱和模型目录
elif index == 8:
self.step9_ml_predict_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step10水色指数反演切换时自动填充光谱数据和输出路径
elif index == 9:
self.step10_map_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step12专题图生成切换时自动填充预测结果目录
elif index == 10:
self.step14_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
# Step13可视化分析切换时自动推断图像目录并加载目录树
elif index == 11:
self.step11_viz_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline)
def apply_stylesheet(self): def apply_stylesheet(self):
"""应用样式表 - 应用现代化设计风格""" """应用样式表 - 应用现代化设计风格"""
@ -2234,12 +2198,12 @@ class WaterQualityGUI(QMainWindow):
self.step7_index_panel.set_config(config['step7_index']) self.step7_index_panel.set_config(config['step7_index'])
if 'step9_ml_predict' in config: if 'step9_ml_predict' in config:
self.step9_ml_predict_panel.set_config(config['step9_ml_predict']) self.step9_ml_predict_panel.set_config(config['step9_ml_predict'])
if 'step14' in config: if 'step11_map' in config:
self.step14_panel.set_config(config['step14']) self.step11_map_panel.set_config(config['step11_map'])
if 'step11_viz' in config: if 'step12_viz' in config:
self.step11_viz_panel.set_config(config['step11_viz']) self.step12_viz_panel.set_config(config['step12_viz'])
if 'step12_report' in config: if 'step13_report' in config:
self.step12_report_panel.set_config(config['step12_report']) self.step13_report_panel.set_config(config['step13_report'])
self.config_file = file_path self.config_file = file_path
self.log_message(f"已加载配置: {file_path}", "info") self.log_message(f"已加载配置: {file_path}", "info")
@ -2282,10 +2246,9 @@ class WaterQualityGUI(QMainWindow):
'step7_index': self.step7_index_panel.get_config(), 'step7_index': self.step7_index_panel.get_config(),
'step8_ml_train': self.step8_ml_train_panel.get_config(), 'step8_ml_train': self.step8_ml_train_panel.get_config(),
'step9_ml_predict': self.step9_ml_predict_panel.get_config(), 'step9_ml_predict': self.step9_ml_predict_panel.get_config(),
'step10_map': self.step10_map_panel.get_config(), 'step11_map': self.step11_map_panel.get_config(),
'step11_viz': self.step11_viz_panel.get_config(), 'step12_viz': self.step12_viz_panel.get_config(),
'step12_report': self.step12_report_panel.get_config(), 'step13_report': self.step13_report_panel.get_config(),
'step14': self.step14_panel.get_config(),
} }
return config return config
@ -2338,10 +2301,9 @@ class WaterQualityGUI(QMainWindow):
'step7_index': self.step7_index_panel, 'step7_index': self.step7_index_panel,
'step8_ml_train': self.step8_ml_train_panel, 'step8_ml_train': self.step8_ml_train_panel,
'step9_ml_predict': self.step9_ml_predict_panel, 'step9_ml_predict': self.step9_ml_predict_panel,
'step10_map': self.step10_map_panel, 'step11_map': self.step11_map_panel,
'step11_viz': self.step11_viz_panel, 'step12_viz': self.step12_viz_panel,
'step12_report': self.step12_report_panel, 'step13_report': self.step13_report_panel,
'step14': self.step14_panel,
} }
return panel_map.get(step_id) return panel_map.get(step_id)
@ -2440,10 +2402,10 @@ class WaterQualityGUI(QMainWindow):
'8_Regression_Modeling': 'step8_ml_train', '8_Regression_Modeling': 'step8_ml_train',
'9_Custom_Regression_Modeling': 'step9_ml_predict', '9_Custom_Regression_Modeling': 'step9_ml_predict',
'11_12_13_predictions/Machine_Learning_Prediction': 'step9_ml_predict', '11_12_13_predictions/Machine_Learning_Prediction': 'step9_ml_predict',
'11_12_13_predictions/Non_Empirical_Prediction': 'step10_map', '11_12_13_predictions/Non_Empirical_Prediction': 'step11_map',
'11_12_13_predictions/Custom_Regression_Prediction': 'step11_viz', '11_12_13_predictions/Custom_Regression_Prediction': 'step12_viz',
'14_visualization': 'step12_report', '14_visualization': 'step13_report',
'14_geotiff_batch_rendering': 'step14' '10_geotiff_batch_rendering': 'step11_map'
} }
for subdir, step_ids in subdirs.items(): for subdir, step_ids in subdirs.items():
@ -2493,7 +2455,7 @@ class WaterQualityGUI(QMainWindow):
discovered_outputs[step_id]['water_indices'] = str(file_path) discovered_outputs[step_id]['water_indices'] = str(file_path)
elif 'sampling_spectra' in file_name and step_id == 'step4_sampling': elif 'sampling_spectra' in file_name and step_id == 'step4_sampling':
discovered_outputs[step_id]['sampling_points'] = str(file_path) discovered_outputs[step_id]['sampling_points'] = str(file_path)
elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step10_map', 'step11_viz']: elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']:
discovered_outputs[step_id]['predictions'] = str(file_path) discovered_outputs[step_id]['predictions'] = str(file_path)
# 更新内部记录 # 更新内部记录
@ -2517,7 +2479,7 @@ class WaterQualityGUI(QMainWindow):
self.scan_work_directory_for_files(work_path) self.scan_work_directory_for_files(work_path)
step_order = ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step6_feature', 'step7_index', step_order = ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step6_feature', 'step7_index',
'step8_ml_train', 'step9_ml_predict', 'step10_map', 'step11_viz', 'step12_report', 'step14'] 'step8_ml_train', 'step9_ml_predict', 'step11_map', 'step12_viz', 'step13_report']
filled_count = 0 filled_count = 0
for step_id in step_order: for step_id in step_order:
@ -2544,10 +2506,9 @@ class WaterQualityGUI(QMainWindow):
('step7_index', self.step7_index_panel), ('step7_index', self.step7_index_panel),
('step8_ml_train', self.step8_ml_train_panel), ('step8_ml_train', self.step8_ml_train_panel),
('step9_ml_predict', self.step9_ml_predict_panel), ('step9_ml_predict', self.step9_ml_predict_panel),
('step10_map', self.step10_map_panel), ('step11_map', self.step11_map_panel),
('step11_viz', self.step11_viz_panel), ('step12_viz', self.step12_viz_panel),
('step12_report', self.step12_report_panel), ('step13_report', self.step13_report_panel),
('step14', self.step14_panel)
] ]
for step_id, panel in panels_with_dependencies: for step_id, panel in panels_with_dependencies:
@ -2617,10 +2578,10 @@ class WaterQualityGUI(QMainWindow):
self.statusBar().showMessage(f"工作目录: {dir_path}") self.statusBar().showMessage(f"工作目录: {dir_path}")
# 同步到可视化面板 # 同步到可视化面板
if hasattr(self, 'step11_viz_panel'): if hasattr(self, 'step12_viz_panel'):
self.step11_viz_panel.set_work_dir(dir_path) self.step12_viz_panel.set_work_dir(dir_path)
if hasattr(self, 'step12_report_panel'): if hasattr(self, 'step13_report_panel'):
self.step12_report_panel.set_work_dir(dir_path) self.step13_report_panel.set_work_dir(dir_path)
def open_work_directory(self): def open_work_directory(self):
"""打开工作目录""" """打开工作目录"""
@ -3179,7 +3140,7 @@ class WaterQualityGUI(QMainWindow):
step_id_to_tab_training = { step_id_to_tab_training = {
'step1': 0, 'step2': 1, 'step3': 2, 'step4_sampling': 3, 'step1': 0, 'step2': 1, 'step3': 2, 'step4_sampling': 3,
'step5_clean': 4, 'step6_feature': 5, 'step7_index': 6, 'step9_ml_predict': 7, 'step5_clean': 4, 'step6_feature': 5, 'step7_index': 6, 'step9_ml_predict': 7,
'step10_map': 8, 'step14': 10, 'step11_viz': 11, 'step12_report': 12 'step10_watercolor': 9, 'step11_map': 10, 'step12_viz': 11, 'step13_report': 12
} }
for step_id in disabled_step_ids: for step_id in disabled_step_ids: