Files
WQ_GUI/src/gui/panels/step9_panel.py

534 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step9 面板 - 分布图生成
"""
import os
import traceback
from pathlib import Path
from typing import List, Optional
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# Pipeline availability flag (kept consistent with core/worker_thread.py).
# The import is optional so this module can still be loaded when the pipeline
# package is missing; callers must check PIPELINE_AVAILABLE before using it.
try:
    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
    PIPELINE_AVAILABLE = True
except ImportError:
    # Bind the name anyway so a stray reference yields None rather than a
    # NameError when the pipeline module cannot be imported.
    WaterQualityInversionPipeline = None
    PIPELINE_AVAILABLE = False
class Step9BatchThread(QThread):
    """Worker thread that renders one distribution map per prediction CSV.

    Signals:
        finished_ok(int): emitted with the number of CSV files once every
            file in the batch has been rendered.
        failed(str): emitted with the error text followed by the traceback
            when any file raises; the remaining files are not processed.
        log_message(str, str): progress text and a level string.
    """

    finished_ok = pyqtSignal(int)
    failed = pyqtSignal(str)
    log_message = pyqtSignal(str, str)

    def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict,
                 output_dir_optional: Optional[str]):
        """Store the batch parameters.

        Args:
            work_dir: working directory handed to the pipeline constructor.
            csv_paths: prediction CSV files to render, in processing order.
            step9_kwargs: keyword arguments shared by every render call.
            output_dir_optional: directory for the generated images; blank or
                None means ``output_image_path=None`` is passed to the
                pipeline for every file.
        """
        super().__init__()
        self.work_dir = work_dir
        self.csv_paths = csv_paths
        self.step9_kwargs = step9_kwargs
        self.output_dir_optional = (output_dir_optional or "").strip() or None

    @staticmethod
    def _unique_image_name(csv_path: str, used: set) -> str:
        """Return an image file name for *csv_path* not yet present in *used*.

        The first CSV with a given stem keeps ``<stem>_distribution.png``.
        Later CSVs with the same stem (possible when scanning sub-folders)
        get ``<stem>_2_distribution.png``, ``<stem>_3_...`` and so on, so one
        map never overwrites another. Names are compared case-insensitively
        because the target file system may not distinguish case.
        """
        stem = Path(csv_path).stem
        candidate = f"{stem}_distribution.png"
        counter = 2
        while candidate.casefold() in used:
            candidate = f"{stem}_{counter}_distribution.png"
            counter += 1
        used.add(candidate.casefold())
        return candidate

    def run(self):
        """Render every CSV in order, reporting progress through signals."""
        # Remember the current matplotlib backend so it can be restored later.
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            # Best effort: without matplotlib there is nothing to restore.
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            # The switch failed, so do not attempt to switch back afterwards.
            mpl_prev = None

        try:
            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
            pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)

            out_dir = None
            if self.output_dir_optional:
                out_dir = Path(self.output_dir_optional)
                # Create the target directory up front so the first render
                # does not fail on a missing folder.
                out_dir.mkdir(parents=True, exist_ok=True)

            used_names = set()
            n = len(self.csv_paths)
            for index, csv_p in enumerate(self.csv_paths, start=1):
                self.log_message.emit(f"专题图 [{index}/{n}] {csv_p}", "info")
                kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
                if out_dir is not None:
                    image_name = self._unique_image_name(csv_p, used_names)
                    kw["output_image_path"] = str(out_dir / image_name)
                else:
                    kw["output_image_path"] = None
                pipeline.step9_generate_distribution_map(**kw)
            self.finished_ok.emit(n)
        except Exception as e:
            # Thread boundary: forward the failure to the GUI instead of
            # letting the exception escape run().
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    # Restoring the backend is best effort only.
                    pass
class Step9Panel(QWidget):
    """Step 9 panel: distribution map generation.

    Offers two input modes: a single prediction CSV, which is handed to the
    ancestor's ``run_single_step``, or a folder of CSVs, which is rendered in
    a ``Step9BatchThread``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._batch_thread = None
        # Cached working directory; filled in by update_from_config().
        self.work_dir = None
        self.init_ui()

    # ------------------------------------------------------------------ UI
    def init_ui(self):
        """Build all widgets and wire their signals."""
        layout = QVBoxLayout()

        hint = QLabel(
            "独立运行:可选「单个 CSV」或「文件夹批量」扫描目录下所有 .csv"
            "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
        )
        hint.setWordWrap(True)
        hint.setStyleSheet(
            f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
        )
        layout.addWidget(hint)

        # Input mode selection (single file / folder batch).
        mode_row = QHBoxLayout()
        self.mode_single_rb = QRadioButton("单个 CSV 文件")
        self.mode_folder_rb = QRadioButton("文件夹批量")
        self._mode_group = QButtonGroup(self)
        self._mode_group.addButton(self.mode_single_rb, 0)
        self._mode_group.addButton(self.mode_folder_rb, 1)
        mode_row.addWidget(self.mode_single_rb)
        mode_row.addWidget(self.mode_folder_rb)
        mode_row.addStretch()
        layout.addLayout(mode_row)

        # Radio button styling: the checked indicator is a solid square.
        radio_style = """
            QRadioButton {
                font-size: 14px;
                spacing: 8px;
                color: #333333;
            }
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border: 2px solid #999999;
                border-radius: 3px;
                background-color: white;
            }
            QRadioButton::indicator:checked {
                border: 2px solid #0078d4;
                background-color: #0078d4;
                image: none;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #005a9e;
            }
        """
        self.mode_single_rb.setStyleSheet(radio_style)
        self.mode_folder_rb.setStyleSheet(radio_style)

        # Single-file input.
        self.prediction_csv_file = FileSelectWidget(
            "预测结果CSV:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.prediction_csv_file)

        # Folder input, wrapped in a container widget so the whole row can be
        # shown or hidden at once.
        folder_row = QHBoxLayout()
        self.prediction_csv_dir_label = QLabel("预测CSV目录:")
        self.prediction_csv_dir_label.setMinimumWidth(120)
        self.prediction_csv_dir_edit = QLineEdit()
        self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
        pred_dir_btn = QPushButton("浏览…")
        pred_dir_btn.setMaximumWidth(80)
        pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
        folder_row.addWidget(self.prediction_csv_dir_label)
        folder_row.addWidget(self.prediction_csv_dir_edit, 1)
        folder_row.addWidget(pred_dir_btn)
        self._folder_row_widget = QWidget()
        self._folder_row_widget.setLayout(folder_row)
        layout.addWidget(self._folder_row_widget)

        # Fix: the label was missing its closing parenthesis.
        self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
        layout.addWidget(self.recursive_csv_cb)

        self.boundary_file = FileSelectWidget(
            "边界文件:",
            "Shapefiles (*.shp);;All Files (*.*)"
        )
        layout.addWidget(self.boundary_file)

        # Generation parameters.
        params_group = QGroupBox("生成参数")
        params_layout = QFormLayout()
        self.resolution = QDoubleSpinBox()
        self.resolution.setRange(1, 1000)
        self.resolution.setValue(30)
        params_layout.addRow("分辨率(米):", self.resolution)
        self.input_crs = QLineEdit()
        self.input_crs.setText("EPSG:32651")
        params_layout.addRow("输入坐标系:", self.input_crs)
        self.output_crs = QLineEdit()
        self.output_crs.setText("EPSG:4326")
        params_layout.addRow("输出坐标系:", self.output_crs)
        self.show_points = QCheckBox("显示采样点")
        params_layout.addRow("", self.show_points)
        self.use_diffusion = QCheckBox("启用距离扩散")
        self.use_diffusion.setChecked(True)
        params_layout.addRow("", self.use_diffusion)
        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # Output directory: the widget's own browse handler is replaced by a
        # directory chooser.
        self.output_dir = FileSelectWidget(
            "输出分布图目录:",
            "Directories;;All Files (*.*)"
        )
        self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
        self.output_dir.browse_btn.clicked.disconnect()
        self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
        layout.addWidget(self.output_dir)

        # Enable/disable this step.
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Stand-alone run button.
        self.run_button = QPushButton("独立运行此步骤")
        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_button.clicked.connect(self.run_step)
        layout.addWidget(self.run_button)

        layout.addStretch()
        self.setLayout(layout)

        # Signal wiring and initial state: single-CSV mode is the default.
        self.mode_single_rb.toggled.connect(self._toggle_input_mode)
        self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
        self.mode_single_rb.setChecked(True)
        self._toggle_input_mode()

    def _toggle_input_mode(self):
        """Slot: show the input widgets that belong to the selected mode."""
        folder_mode = self.mode_folder_rb.isChecked()
        self.prediction_csv_file.setVisible(not folder_mode)
        self._folder_row_widget.setVisible(folder_mode)
        self.recursive_csv_cb.setVisible(folder_mode)

    # ------------------------------------------------------------- helpers
    def _find_ancestor_with(self, attr_name: str):
        """Return the nearest ancestor that has *attr_name*, or None.

        Compares against None explicitly instead of relying on truthiness,
        which could end the walk early on a falsy container widget.
        """
        node = self.parent()
        while node is not None and not hasattr(node, attr_name):
            node = node.parent()
        return node

    def _get_default_work_dir(self) -> str:
        """Return the panel's cached work_dir, else the top-level window's, else ""."""
        own = getattr(self, 'work_dir', None)
        if own:
            return str(own)
        top = self.window()
        top_dir = getattr(top, 'work_dir', None) if top is not None else None
        return str(top_dir) if top_dir else ""

    @staticmethod
    def _read_path_widget(widget) -> str:
        """Read a path from a widget exposing ``get_path()`` or ``text()``."""
        if hasattr(widget, 'get_path'):
            return widget.get_path() or ""
        if hasattr(widget, 'text'):
            return widget.text() or ""
        return ""

    def _absolutize(self, path: str) -> str:
        """Join a relative *path* onto work_dir; absolute paths pass through."""
        if os.path.isabs(path):
            return path
        return os.path.join(self.work_dir or '', path).replace('\\', '/')

    def _detect_prediction_dir(self, main_window) -> Optional[str]:
        """Derive the prediction CSV directory from sibling panels.

        Priority: Step8 output file, then Step8.5 output file, then the
        Step8.75 output directory widget.
        """
        # 1. Step8: the parent directory of its output file, preferring the
        #    "Machine_Learning_Prediction" sub-directory when it exists.
        step8_out = self._read_path_widget(
            getattr(getattr(main_window, 'step8_panel', None), 'output_file', None)
        )
        if step8_out:
            base_dir = Path(self._absolutize(step8_out)).parent
            ml_dir = base_dir / "Machine_Learning_Prediction"
            return str(ml_dir) if ml_dir.exists() else str(base_dir)

        # 2. Step8.5: the parent directory of its output file.
        step8_5_out = self._read_path_widget(
            getattr(getattr(main_window, 'step8_5_panel', None), 'output_file', None)
        )
        if step8_5_out:
            return str(Path(self._absolutize(step8_5_out)).parent)

        # 3. Step8.75: its output directory, used as entered.
        step8_75_out = self._read_path_widget(
            getattr(getattr(main_window, 'step8_75_panel', None), 'output_dir_widget', None)
        )
        return step8_75_out or None

    def _detect_boundary_shp(self) -> Optional[Path]:
        """Return the first existing roi.shp near work_dir, or None."""
        base = Path(self.work_dir)
        candidates = (
            base.parent / "input-test" / "roi.shp",
            base / "roi.shp",
            base.parent / "roi.shp",
        )
        for candidate in candidates:
            if candidate.exists() and candidate.suffix.lower() == ".shp":
                return candidate
        return None

    def browse_prediction_csv_dir(self):
        """Let the user pick the folder that holds the prediction CSVs."""
        start_dir = self._get_default_work_dir()
        if start_dir:
            start_dir = os.path.join(start_dir, "11_12_13_predictions")
        chosen = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", start_dir)
        if chosen:
            self.prediction_csv_dir_edit.setText(chosen)

    def _collect_csv_paths_from_folder(self) -> List[str]:
        """List the CSV files in the chosen folder, optionally recursing.

        The extension is matched case-insensitively so "*.CSV" files are
        found on case-sensitive file systems too. Returns an empty list when
        the folder is unset or not a directory.
        """
        folder = (self.prediction_csv_dir_edit.text() or "").strip()
        if not folder or not os.path.isdir(folder):
            return []
        root = Path(folder)
        entries = root.rglob("*") if self.recursive_csv_cb.isChecked() else root.iterdir()
        matches = sorted(
            p for p in entries
            if p.name.lower().endswith(".csv") and p.is_file()
        )
        return [str(p) for p in matches]

    def _step9_base_pipeline_kwargs(self) -> dict:
        """Return the rendering parameters shared by single and batch runs."""
        return {
            'boundary_shp_path': self.boundary_file.get_path(),
            'resolution': self.resolution.value(),
            'input_crs': self.input_crs.text(),
            'output_crs': self.output_crs.text(),
            'show_sample_points': self.show_points.isChecked(),
            'use_distance_diffusion': self.use_diffusion.isChecked(),
        }

    # -------------------------------------------------------------- config
    def get_config(self):
        """Collect the panel state into a configuration dictionary.

        ``output_image_path`` is only set in single-file mode when both a CSV
        and an output directory are given; otherwise it is None.
        """
        pred_csv = (self.prediction_csv_file.get_path() or "").strip()
        folder_mode = self.mode_folder_rb.isChecked()
        pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()

        config = {
            'step9_batch_mode': 'folder' if folder_mode else 'single',
            'prediction_csv_dir': pred_dir or None,
            'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
            'prediction_csv_path': None if folder_mode else (pred_csv or None),
        }
        # Reuse the shared parameter set so the two cannot drift apart.
        config.update(self._step9_base_pipeline_kwargs())

        # NOTE(review): set_config() reads an 'output_dir' key that is never
        # written here, so the output directory is not round-tripped in
        # folder mode. Confirm how the consumer uses this dict before adding it.
        out_dir = (self.output_dir.get_path() or "").strip()
        if not folder_mode and pred_csv and out_dir:
            stem = Path(pred_csv).stem
            config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
        else:
            config['output_image_path'] = None
        return config

    def set_config(self, config):
        """Apply a configuration dictionary to the widgets.

        Values are coerced to the type each Qt setter expects, so ``None`` or
        non-string entries in a loaded configuration do not raise.
        """
        if config.get('step9_batch_mode', 'single') == 'folder':
            self.mode_folder_rb.setChecked(True)
        else:
            self.mode_single_rb.setChecked(True)

        if config.get('prediction_csv_dir'):
            self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
        if 'recursive_csv_scan' in config:
            self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
        if config.get('prediction_csv_path'):
            self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
        if 'boundary_shp_path' in config:
            self.boundary_file.set_path(str(config['boundary_shp_path'] or ""))
        if config.get('resolution') is not None:
            self.resolution.setValue(float(config['resolution']))
        if config.get('input_crs') is not None:
            self.input_crs.setText(str(config['input_crs']))
        if config.get('output_crs') is not None:
            self.output_crs.setText(str(config['output_crs']))
        if 'show_sample_points' in config:
            self.show_points.setChecked(bool(config['show_sample_points']))
        if 'use_distance_diffusion' in config:
            self.use_diffusion.setChecked(bool(config['use_distance_diffusion']))

        if config.get('output_dir'):
            self.output_dir.set_path(str(config['output_dir']))
        elif config.get('output_image_path'):
            # Fall back to the directory of a stored image path.
            image_path = Path(str(config['output_image_path']))
            if image_path.parent and str(image_path.parent) != '.':
                self.output_dir.set_path(str(image_path.parent))

    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill empty inputs from the working directory and sibling panels.

        Fields that already hold a value are left untouched. Any exception is
        printed and swallowed so a failed auto-fill does not propagate to the
        caller.

        Args:
            work_dir: working directory path; when falsy the cached value is kept.
            pipeline: unused, kept for interface compatibility.
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None

            main_window = self.window()
            if main_window is None:
                return

            # Prediction CSV directory (folder batch mode).
            pred_dir = self._detect_prediction_dir(main_window)
            if pred_dir:
                existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
                if not existing_dir:
                    self.prediction_csv_dir_edit.setText(pred_dir)
                    # NOTE(review): the original nesting of this mode switch
                    # is ambiguous; it is applied only when the directory was
                    # auto-filled. Confirm against the upstream file.
                    self.mode_folder_rb.setChecked(True)

            if self.work_dir:
                # Default output directory: <work_dir>/14_visualization.
                output_dir = os.path.join(self.work_dir, "14_visualization")
                os.makedirs(output_dir, exist_ok=True)
                if not (self.output_dir.get_path() or "").strip():
                    self.output_dir.set_path(output_dir)

                # Boundary shapefile: look for roi.shp around work_dir.
                existing_boundary = (self.boundary_file.get_path() or "").strip()
                if not existing_boundary:
                    detected = self._detect_boundary_shp()
                    if detected is not None:
                        self.boundary_file.set_path(str(detected))
                    else:
                        # Nothing found: keep the field empty and ask the
                        # user to pick a vector file manually.
                        self.boundary_file.set_path("")
                        print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
        except Exception as e:
            # Fix: the message was missing its opening bracket.
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()

    def browse_output_dir(self):
        """Let the user pick the output directory for the generated maps."""
        start_dir = self._get_default_work_dir()
        if start_dir:
            start_dir = os.path.join(start_dir, "14_visualization")
        dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", start_dir)
        if dir_path:
            self.output_dir.set_path(dir_path)

    # ----------------------------------------------------------------- run
    def run_step(self):
        """Run step 9 on its own, in single-file or folder batch mode."""
        if self._batch_thread is not None and self._batch_thread.isRunning():
            QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
            return

        boundary_shp_path = self.boundary_file.get_path()
        if not boundary_shp_path:
            QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
            return
        if not os.path.exists(boundary_shp_path):
            QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
            return

        parent = self._find_ancestor_with('run_single_step')
        if parent is None:
            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
            return

        if self.mode_folder_rb.isChecked():
            self._run_folder_batch(parent)
            return

        prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
        if not prediction_csv_path:
            QMessageBox.warning(
                self,
                "输入验证失败",
                "请选择「预测结果 CSV」文件或切换到「文件夹批量」。",
            )
            return
        if not os.path.isfile(prediction_csv_path):
            QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
            return

        parent.run_single_step('step9', {'step9': self.get_config()})

    def _run_folder_batch(self, parent):
        """Validate the folder input and start the batch worker thread."""
        csv_list = self._collect_csv_paths_from_folder()
        if not csv_list:
            QMessageBox.warning(
                self,
                "输入验证失败",
                "所选文件夹中未找到 .csv 文件,或目录无效。\n"
                "可勾选「包含子文件夹」以递归扫描。",
            )
            return
        if not PIPELINE_AVAILABLE:
            QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
            return

        work_dir = str(getattr(parent, "work_dir", None) or "./work_dir")
        base_kw = self._step9_base_pipeline_kwargs()
        out_dir_opt = (self.output_dir.get_path() or "").strip() or None

        thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
        self._batch_thread = thread
        main_win = parent

        def _batch_log(msg, lvl):
            if hasattr(main_win, "log_message"):
                main_win.log_message(msg, lvl)

        # Queued connections: the slots touch widgets, while the signals are
        # emitted from the worker thread.
        thread.log_message.connect(_batch_log, Qt.QueuedConnection)
        thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
        thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
        thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)

        # Disable the button only once the thread is fully wired, so an error
        # above cannot leave it disabled.
        self.run_button.setEnabled(False)
        thread.start()
        if hasattr(parent, "log_message"):
            parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV工作目录 {work_dir}", "info")

    def _on_step9_batch_ok(self, n: int):
        """Slot: the batch finished; notify the user and the main log."""
        QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
        logger_owner = self._find_ancestor_with("log_message")
        if logger_owner is not None:
            logger_owner.log_message(f"专题图批量完成,共 {n} 个文件。", "info")

    def _on_step9_batch_fail(self, err: str):
        """Slot: the batch aborted; show a truncated error and log it in full."""
        QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
        logger_owner = self._find_ancestor_with("log_message")
        if logger_owner is not None:
            logger_owner.log_message(err, "error")