refactor: 提取 WorkspaceManager,将文件扫描与路径业务逻辑从主 GUI 解耦

This commit is contained in:
DXC
2026-06-17 15:35:02 +08:00
parent 191a4b681d
commit 1949711cda
2 changed files with 254 additions and 252 deletions

View File

@ -0,0 +1,231 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
工作空间管理器
负责工作目录文件扫描、步骤输出路径发现、配置裁剪等业务逻辑,
与 GUI 组件解耦,不直接引用任何 UI 类。
"""
import copy
from pathlib import Path
class WorkspaceManager:
"""管理步骤默认输出路径、文件扫描与配置裁剪"""
# 白名单:科学数据格式后缀
SCIENTIFIC_EXTENSIONS = {'.dat', '.tif', '.tiff', '.shp'}
# 临时文件关键词黑名单
TMP_KEYWORDS = ('__tmp', '_tmp')
# 掩膜类型集合
MASK_TYPES = {'water_mask', 'glint_mask', 'boundary_mask'}
def __init__(self):
self.step_default_outputs = {
'step1': "1_water_mask/water_mask_from_ndwi.dat",
'step2': "2_Glint_Detection/severe_glint_area.dat",
'step3': "3_deglint/deglint_goodman.bsq",
'step4_sampling': "4_sampling/sampling_spectra.csv",
'step5_clean': "5_Data_Cleaning/processed_data.csv",
'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv",
'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv",
'step8_ml_train': "8_Supervised_Model_Training/",
'step9_ml_predict': "9_ML_Prediction/",
'step10_watercolor': "10_WaterIndex_Images/",
'step11_map': "14_visualization/"
}
self.step_outputs = {}
@staticmethod
def _is_scientific_mask(path_str):
"""白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式"""
p = Path(path_str)
name_lower = str(path_str).lower()
if any(kw in name_lower for kw in WorkspaceManager.TMP_KEYWORDS):
return False
return p.suffix.lower() in WorkspaceManager.SCIENTIFIC_EXTENSIONS
def find_step_output(self, work_path, step_id, output_type, ref_img_path=None):
"""查找指定步骤的输出文件
Args:
work_path: 工作目录 Path 对象
step_id: 步骤 ID
output_type: 输出类型(如 'water_mask', 'deglint_image' 等)
ref_img_path: 参考影像路径(仅 output_type='reference_img' 时需要)
Returns:
找到的文件路径字符串,或 None
"""
if step_id not in self.step_default_outputs:
return None
raw = self.step_default_outputs[step_id]
rel_path = None
if isinstance(raw, str):
rel_path = raw
elif isinstance(raw, dict):
rel_path = raw.get(output_type) or list(raw.values())[0]
if not rel_path:
return None
# 特殊处理:从 step_outputs 记录中查找实际输出路径
if step_id in self.step_outputs:
actual_outputs = self.step_outputs[step_id]
if output_type in actual_outputs:
candidate = actual_outputs[output_type]
if output_type in self.MASK_TYPES and not self._is_scientific_mask(candidate):
pass
else:
return candidate
if output_type == 'water_mask':
if rel_path:
mask_path = work_path / rel_path
if mask_path.exists():
return str(mask_path)
elif output_type == 'reference_img':
if ref_img_path and Path(ref_img_path).exists():
return ref_img_path
elif output_type == 'deglint_image':
if rel_path:
deglint_path = work_path / rel_path
if deglint_path.exists():
return str(deglint_path)
deglint_dir = work_path / "3_deglint"
if deglint_dir.exists():
for file_path in deglint_dir.glob("deglint_*.bsq"):
return str(file_path)
for file_path in deglint_dir.glob("interpolated_*.bsq"):
return str(file_path)
elif rel_path:
if rel_path.endswith('/'):
output_path = work_path / rel_path.rstrip('/')
if output_path.exists() and output_path.is_dir():
return str(output_path)
else:
output_path = work_path / rel_path
if output_path.exists():
return str(output_path)
return None
def scan_work_directory_for_files(self, work_path):
"""扫描工作目录,自动发现各步骤的输出文件
Returns:
discovered_outputs: dict, {step_id: {output_type: path_str}}
"""
discovered_outputs = {}
subdirs = {
'1_water_mask': 'step1',
'2_Glint_Detection': 'step2',
'3_deglint': 'step3',
'5_Data_Cleaning': 'step5_clean',
'6_Spectral_Feature_Extraction': 'step6_feature',
'7_Water_Quality_Indices': 'step7_index',
'8_Supervised_Model_Training': 'step8_ml_train',
'8_Regression_Modeling': 'step8_ml_train',
'13_Custom_Regression': 'step13',
'9_ML_Prediction': 'step9_ml_predict',
'11_12_13_predictions/Non_Empirical_Prediction': 'step11_map',
'13_Custom_Regression/Custom_Regression_Prediction': 'step13',
'14_visualization': 'step13_report',
'10_geotiff_batch_rendering': 'step11_map'
}
for subdir, step_ids in subdirs.items():
subdir_path = work_path / subdir
if not subdir_path.exists():
continue
if isinstance(step_ids, str):
step_ids = [step_ids]
for file_path in subdir_path.rglob('*'):
if file_path.is_file():
file_name = file_path.name.lower()
for step_id in step_ids:
if step_id not in discovered_outputs:
discovered_outputs[step_id] = {}
if 'water_mask' in file_name and step_id == 'step1':
if self._is_scientific_mask(file_path):
discovered_outputs[step_id]['water_mask'] = str(file_path)
elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2':
if self._is_scientific_mask(file_path):
discovered_outputs[step_id]['glint_mask'] = str(file_path)
elif 'deglint' in file_name and step_id == 'step3':
discovered_outputs[step_id]['deglint_image'] = str(file_path)
elif 'processed_data' in file_name and step_id == 'step4_sampling':
discovered_outputs[step_id]['processed_data'] = str(file_path)
elif 'training_spectra' in file_name and step_id == 'step5_clean':
discovered_outputs[step_id]['training_spectra'] = str(file_path)
elif 'water_quality_indices' in file_name and step_id == 'step6_feature':
discovered_outputs[step_id]['water_indices'] = str(file_path)
elif 'sampling_spectra' in file_name and step_id == 'step4_sampling':
discovered_outputs[step_id]['sampling_points'] = str(file_path)
elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']:
discovered_outputs[step_id]['predictions'] = str(file_path)
for step_id, outputs in discovered_outputs.items():
if step_id not in self.step_outputs:
self.step_outputs[step_id] = {}
self.step_outputs[step_id].update(outputs)
return discovered_outputs
def update_step_outputs(self, step_name, work_path):
"""更新指定步骤的输出路径记录"""
if step_name not in self.step_default_outputs:
return
step_outputs = self.step_default_outputs[step_name]
for output_type, relative_path in step_outputs.items():
if '*' in relative_path:
pattern_path = work_path / relative_path.replace('*', '*')
matching_files = list(pattern_path.parent.glob(pattern_path.name))
if matching_files:
latest_file = max(matching_files, key=lambda p: p.stat().st_mtime)
self.step_outputs[step_name][output_type] = str(latest_file)
else:
output_path = work_path / relative_path
if output_path.exists():
self.step_outputs[step_name][output_type] = str(output_path)
@staticmethod
def prune_config_for_prediction_mode(config: dict) -> dict:
"""Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。
被禁用的 step dict 中统一写入 'enabled': False
这些配置最终传给 PipelineRunnerRunner 会跳过它们。
同时,被跳过的步骤的 required_input_files 在 build_missing_items
中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。
Args:
config: 完整配置字典(来自 get_current_config
Returns:
裁剪后的 config深拷贝原 config 不被修改)
"""
cfg = copy.deepcopy(config)
training_steps = [
"step4",
"step5",
"step7",
"step6",
"step8_non_empirical_modeling",
"step9",
]
for step_id in training_steps:
step_cfg = cfg.setdefault(step_id, {})
step_cfg["enabled"] = False
return cfg

View File

@ -158,6 +158,7 @@ from src.gui.core.worker_thread import (
from src.gui.core.preflight_dialog import PreflightDialog
from src.gui.core.pipeline_mode_dialog import PipelineModeDialog
from src.gui.core.viz_thread import VisualizationWorkerThread, _viz_training_spectra_csv_path
from src.core.workspace_manager import WorkspaceManager
class WaterQualityGUI(QMainWindow):
@ -183,10 +184,10 @@ class WaterQualityGUI(QMainWindow):
# 训练数据模式状态
self.has_training_data = True # 默认有训练数据
# 步骤输出路径记录
self.step_outputs = {} # 记录每个步骤的输出路径
# 工作空间管理器(文件扫描、路径发现、配置裁剪)
self.workspace_manager = WorkspaceManager()
# 定义步骤依赖关系和标准输出路径
# 定义步骤依赖关系
self._init_step_dependencies()
self.init_ui()
@ -198,22 +199,7 @@ class WaterQualityGUI(QMainWindow):
QTimer.singleShot(100, self.init_workspace)
def _init_step_dependencies(self):
"""初始化步骤依赖关系和标准输出路径"""
# 定义每个步骤的标准输出路径模式(相对于工作目录)
self.step_default_outputs = {
'step1': "1_water_mask/water_mask_from_ndwi.dat",
'step2': "2_Glint_Detection/severe_glint_area.dat",
'step3': "3_deglint/deglint_goodman.bsq",
'step4_sampling': "4_sampling/sampling_spectra.csv",
'step5_clean': "5_Data_Cleaning/processed_data.csv",
'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv",
'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv",
'step8_ml_train': "8_Supervised_Model_Training/",
'step9_ml_predict': "9_ML_Prediction/",
'step10_watercolor': "10_WaterIndex_Images/",
'step11_map': "14_visualization/"
}
"""初始化步骤依赖关系"""
# 依赖关系字典结构:
# '当前步骤ID': { '依赖参数名': ('上游步骤ID', '上游输出类型/Key', '当前步骤接收该路径的组件属性名') }
self.step_dependencies = {
@ -1081,7 +1067,11 @@ class WaterQualityGUI(QMainWindow):
dependencies = self.step_dependencies[step_id]
filled_count = 0
ref_img_path = None
if hasattr(self, 'step1_panel'):
ref_img_path = self.step1_panel.img_file.get_path()
for input_field, (dep_step, output_type, panel_attr) in dependencies.items():
# 检查面板是否有对应的属性
if not hasattr(panel, panel_attr):
@ -1101,7 +1091,7 @@ class WaterQualityGUI(QMainWindow):
continue
# 查找依赖步骤的输出文件
output_path = self.find_step_output(work_path, dep_step, output_type)
output_path = self.workspace_manager.find_step_output(work_path, dep_step, output_type, ref_img_path=ref_img_path)
if output_path and Path(output_path).exists():
# ★ 兼容 FileSelectWidget 与原生 QLineEdit
@ -1132,173 +1122,6 @@ class WaterQualityGUI(QMainWindow):
}
return panel_map.get(step_id)
def find_step_output(self, work_path, step_id, output_type):
"""查找指定步骤的输出文件"""
if step_id not in self.step_default_outputs:
return None
raw = self.step_default_outputs[step_id]
# ★ 兼容扁平化后的纯字符串路径格式
rel_path = None
if isinstance(raw, str):
rel_path = raw
elif isinstance(raw, dict):
rel_path = raw.get(output_type) or list(raw.values())[0]
if not rel_path:
return None
# ★ 掩膜类型列表:这些类型只接受科学数据格式
mask_types = {'water_mask', 'glint_mask', 'boundary_mask'}
# ★ 白名单机制:只允许 .dat .tif .tiff .shp拒绝其他一切格式
scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
# ★ 临时文件关键词黑名单
tmp_keywords = ('__tmp', '_tmp')
def _is_scientific_mask(path_str):
"""白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式"""
p = Path(path_str)
name_lower = str(path_str).lower()
# 拒绝临时文件
if any(kw in name_lower for kw in tmp_keywords):
return False
# 白名单校验
return p.suffix.lower() in scientific_extensions
# 特殊处理从step_outputs记录中查找实际输出路径
if step_id in self.step_outputs:
actual_outputs = self.step_outputs[step_id]
if output_type in actual_outputs:
candidate = actual_outputs[output_type]
# ★ 掩膜类型白名单二次校验:不在白名单内的一律拒绝
if output_type in mask_types and not _is_scientific_mask(candidate):
# 非科学格式被拒绝,不使用 step_outputs 中的值
pass
else:
return candidate
# 根据输出类型查找对应的文件
if output_type == 'water_mask':
# 水域掩膜:直接用统一路径
if rel_path:
mask_path = work_path / rel_path
if mask_path.exists():
return str(mask_path)
elif output_type == 'reference_img':
# 参考影像从step1的配置中获取用户输入的影像路径
if hasattr(self, 'step1_panel'):
img_path = self.step1_panel.img_file.get_path()
if img_path and Path(img_path).exists():
return img_path
elif output_type == 'deglint_image':
# 去耀斑影像:直接用统一路径
if rel_path:
deglint_path = work_path / rel_path
if deglint_path.exists():
return str(deglint_path)
# 还要检查 Kutser 算法输出与插值方法生成的文件
deglint_dir = work_path / "3_deglint"
if deglint_dir.exists():
for file_path in deglint_dir.glob("deglint_*.bsq"):
return str(file_path)
for file_path in deglint_dir.glob("interpolated_*.bsq"):
return str(file_path)
elif rel_path:
# 直接匹配的输出类型(统一使用 rel_path
if rel_path.endswith('/'):
# 是目录
output_path = work_path / rel_path.rstrip('/')
if output_path.exists() and output_path.is_dir():
return str(output_path)
else:
# 是文件
output_path = work_path / rel_path
if output_path.exists():
return str(output_path)
return None
def scan_work_directory_for_files(self, work_path):
"""扫描工作目录,自动发现各步骤的输出文件"""
discovered_outputs = {}
# 扫描各个子目录
subdirs = {
'1_water_mask': 'step1',
'2_Glint_Detection': 'step2',
'3_deglint': 'step3',
'5_Data_Cleaning': 'step5_clean',
'6_Spectral_Feature_Extraction': 'step6_feature',
'7_Water_Quality_Indices': 'step7_index',
'8_Supervised_Model_Training': 'step8_ml_train',
'8_Regression_Modeling': 'step8_ml_train',
'13_Custom_Regression': 'step13',
'9_ML_Prediction': 'step9_ml_predict',
'11_12_13_predictions/Non_Empirical_Prediction': 'step11_map',
'13_Custom_Regression/Custom_Regression_Prediction': 'step13',
'14_visualization': 'step13_report',
'10_geotiff_batch_rendering': 'step11_map'
}
for subdir, step_ids in subdirs.items():
subdir_path = work_path / subdir
if not subdir_path.exists():
continue
if isinstance(step_ids, str):
step_ids = [step_ids]
# 扫描该目录下的文件
for file_path in subdir_path.rglob('*'):
if file_path.is_file():
file_name = file_path.name.lower()
# 根据文件名模式判断输出类型
for step_id in step_ids:
if step_id not in discovered_outputs:
discovered_outputs[step_id] = {}
# ★ 掩膜文件白名单过滤:只有 .dat .tif .tiff .shp 才通过,拒绝 .hdr .xml .png 等
scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'}
tmp_keywords = ('__tmp', '_tmp')
def _is_scientific_mask(path_str):
"""白名单判断:拒绝 .hdr .xml 临时文件等,只接受科学数据格式"""
p = Path(path_str)
name_lower = str(path_str).lower()
if any(kw in name_lower for kw in tmp_keywords):
return False
return p.suffix.lower() in scientific_extensions
# 匹配不同的文件类型
if 'water_mask' in file_name and step_id == 'step1':
if _is_scientific_mask(file_path):
discovered_outputs[step_id]['water_mask'] = str(file_path)
elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2':
if _is_scientific_mask(file_path):
discovered_outputs[step_id]['glint_mask'] = str(file_path)
elif 'deglint' in file_name and step_id == 'step3':
discovered_outputs[step_id]['deglint_image'] = str(file_path)
elif 'processed_data' in file_name and step_id == 'step4_sampling':
discovered_outputs[step_id]['processed_data'] = str(file_path)
elif 'training_spectra' in file_name and step_id == 'step5_clean':
discovered_outputs[step_id]['training_spectra'] = str(file_path)
elif 'water_quality_indices' in file_name and step_id == 'step6_feature':
discovered_outputs[step_id]['water_indices'] = str(file_path)
elif 'sampling_spectra' in file_name and step_id == 'step4_sampling':
discovered_outputs[step_id]['sampling_points'] = str(file_path)
elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']:
discovered_outputs[step_id]['predictions'] = str(file_path)
# 更新内部记录
for step_id, outputs in discovered_outputs.items():
if step_id not in self.step_outputs:
self.step_outputs[step_id] = {}
self.step_outputs[step_id].update(outputs)
return discovered_outputs
def auto_populate_all_steps(self):
"""自动填充所有步骤的输入路径"""
work_dir = getattr(self, 'work_dir', './work_dir')
@ -1309,7 +1132,7 @@ class WaterQualityGUI(QMainWindow):
return
# 首先扫描工作目录发现已有的输出文件
self.scan_work_directory_for_files(work_path)
self.workspace_manager.scan_work_directory_for_files(work_path)
step_order = ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step6_feature', 'step7_index',
'step8_ml_train', 'step9_ml_predict', 'step11_map', 'step12_viz', 'step13_report']
@ -1612,41 +1435,6 @@ class WaterQualityGUI(QMainWindow):
return True
# ------------------------------------------------------------------
# ★ 全流程模式动态裁剪
# ------------------------------------------------------------------
def _prune_config_for_prediction_mode(self, config: dict) -> dict:
"""Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。
被禁用的 step dict 中统一写入 'enabled': False
这些配置最终传给 PipelineRunnerRunner 会跳过它们。
同时,被跳过的步骤的 required_input_files 在 build_missing_items
中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。
Args:
config: 完整配置字典(来自 get_current_config
Returns:
裁剪后的 config深拷贝原 config 不被修改)
"""
cfg = copy.deepcopy(config)
# 在每个训练相关步骤的 dict 中写入 enabled=False
training_steps = [
"step4", # CSV 实测数据清洗
"step5", # 实测点光谱提取(→ training_csv_path
"step7", # ML 监督建模
"step6", # 水质指数计算(辅助训练)
"step8_non_empirical_modeling", # 非经验回归建模
"step9", # 自定义回归建模
]
for step_id in training_steps:
step_cfg = cfg.setdefault(step_id, {})
step_cfg["enabled"] = False
return cfg
def run_full_pipeline(self):
"""运行完整流程"""
if not PIPELINE_AVAILABLE:
@ -1665,7 +1453,7 @@ class WaterQualityGUI(QMainWindow):
# ── 1) 运行前智能预检与自动回填(硬盘已有产物自动跳过) ──
work_path = Path(work_dir)
self.log_message("正在进行运行前环境预检与自动扫描...", "info")
self.scan_work_directory_for_files(work_path)
self.workspace_manager.scan_work_directory_for_files(work_path)
self.auto_populate_all_steps()
self.log_message("✓ 预检完成:已扫描工作目录并自动回填已落盘的产物", "info")
@ -1685,7 +1473,7 @@ class WaterQualityGUI(QMainWindow):
# ── 2.1) ★ 根据模式动态裁剪配置 ──
if selected_mode == "prediction_only":
config = self._prune_config_for_prediction_mode(config)
config = self.workspace_manager.prune_config_for_prediction_mode(config)
self.log_message("[模式选择] 已裁剪训练相关步骤step4/5/7/8进入仅预测模式", "info")
# ── 3) ★ 一次性全预检 + 用户交互式决策 ──
@ -1798,38 +1586,21 @@ class WaterQualityGUI(QMainWindow):
work_path = Path(work_dir)
# 根据步骤名称和约定路径,记录实际输出
if step_name not in self.step_outputs:
self.step_outputs[step_name] = {}
if step_name not in self.workspace_manager.step_outputs:
self.workspace_manager.step_outputs[step_name] = {}
# 扫描工作目录,更新该步骤的输出路径
self.update_step_outputs(step_name, work_path)
self.workspace_manager.update_step_outputs(step_name, work_path)
# 自动填充依赖该步骤输出的后续步骤
self.auto_populate_dependent_steps(step_name)
def update_step_outputs(self, step_name, work_path):
"""更新指定步骤的输出路径记录"""
if step_name not in self.step_default_outputs:
return
step_outputs = self.step_default_outputs[step_name]
for output_type, relative_path in step_outputs.items():
if '*' in relative_path:
# 处理通配符路径
pattern_path = work_path / relative_path.replace('*', '*')
matching_files = list(pattern_path.parent.glob(pattern_path.name))
if matching_files:
# 选择最新的文件
latest_file = max(matching_files, key=lambda p: p.stat().st_mtime)
self.step_outputs[step_name][output_type] = str(latest_file)
else:
output_path = work_path / relative_path
if output_path.exists():
self.step_outputs[step_name][output_type] = str(output_path)
def auto_populate_dependent_steps(self, completed_step):
"""自动填充依赖于已完成步骤的后续步骤"""
ref_img_path = None
if hasattr(self, 'step1_panel'):
ref_img_path = self.step1_panel.img_file.get_path()
for step_id, dependencies in self.step_dependencies.items():
for input_field, (dep_step, output_type, panel_attr) in dependencies.items():
if dep_step == completed_step:
@ -1841,7 +1612,7 @@ class WaterQualityGUI(QMainWindow):
if not file_widget.get_path().strip():
work_dir = getattr(self, 'work_dir', './work_dir')
work_path = Path(work_dir)
output_path = self.find_step_output(work_path, dep_step, output_type)
output_path = self.workspace_manager.find_step_output(work_path, dep_step, output_type, ref_img_path=ref_img_path)
if output_path and Path(output_path).exists():
file_widget.set_path(output_path)
self.log_message(f"步骤完成后自动填充 {step_id}.{input_field}: {output_path}", "info")