refactor: 提取 WorkspaceManager,将文件扫描与路径业务逻辑从主 GUI 解耦

This commit is contained in:
DXC
2026-06-17 15:35:02 +08:00
parent 191a4b681d
commit 1949711cda
2 changed files with 254 additions and 252 deletions

View File

@ -0,0 +1,231 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
工作空间管理器
负责工作目录文件扫描、步骤输出路径发现、配置裁剪等业务逻辑,
与 GUI 组件解耦,不直接引用任何 UI 类。
"""
import copy
from pathlib import Path
class WorkspaceManager:
    """Locate, record and prune pipeline step outputs inside a work directory.

    The class keeps two tables:

    * ``step_default_outputs`` -- for each known step id, the path (relative
      to the work directory) where that step writes by default.  An entry is
      either a plain string or a ``{output_type: relative_path}`` dict.
    * ``step_outputs`` -- ``{step_id: {output_type: absolute_path_str}}``,
      filled in by :meth:`scan_work_directory_for_files` and
      :meth:`update_step_outputs` and consulted first by
      :meth:`find_step_output`.

    It deliberately references no UI classes.
    """

    # Whitelist: suffixes accepted as scientific data formats.
    SCIENTIFIC_EXTENSIONS = {'.dat', '.tif', '.tiff', '.shp'}
    # Blacklist: substrings that mark a path as a temporary artefact.
    TMP_KEYWORDS = ('__tmp', '_tmp')
    # Output types that must pass the scientific-format check before a
    # recorded path is trusted.
    MASK_TYPES = {'water_mask', 'glint_mask', 'boundary_mask'}
    # Output-type key used by update_step_outputs() when a step's default
    # output is a bare string (i.e. carries no output-type name of its own).
    _DEFAULT_OUTPUT_TYPE = 'default'

    def __init__(self):
        """Create the default-output table and an empty record of outputs."""
        self.step_default_outputs = {
            'step1': "1_water_mask/water_mask_from_ndwi.dat",
            'step2': "2_Glint_Detection/severe_glint_area.dat",
            'step3': "3_deglint/deglint_goodman.bsq",
            'step4_sampling': "4_sampling/sampling_spectra.csv",
            'step5_clean': "5_Data_Cleaning/processed_data.csv",
            'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv",
            'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv",
            'step8_ml_train': "8_Supervised_Model_Training/",
            'step9_ml_predict': "9_ML_Prediction/",
            'step10_watercolor': "10_WaterIndex_Images/",
            'step11_map': "14_visualization/"
        }
        self.step_outputs = {}

    @staticmethod
    def _is_scientific_mask(path_str):
        """Return True if *path_str* looks like a usable scientific data file.

        The path is rejected when any ``TMP_KEYWORDS`` entry occurs anywhere
        in the lower-cased path string (directory components included);
        otherwise it is accepted only if its suffix, compared
        case-insensitively, is in ``SCIENTIFIC_EXTENSIONS``.

        Args:
            path_str: A path as ``str`` or ``Path``.
        """
        lowered = str(path_str).lower()
        if any(keyword in lowered for keyword in WorkspaceManager.TMP_KEYWORDS):
            return False
        return Path(path_str).suffix.lower() in WorkspaceManager.SCIENTIFIC_EXTENSIONS

    def find_step_output(self, work_path, step_id, output_type, ref_img_path=None):
        """Find the output file (or directory) of a step.

        Lookup order:

        1. A path previously recorded in ``step_outputs`` for this step and
           output type.  For mask output types the recorded path is used only
           if it passes :meth:`_is_scientific_mask`.
        2. The step's default relative path under *work_path*, with special
           handling for ``'water_mask'``, ``'reference_img'`` and
           ``'deglint_image'``.

        Args:
            work_path: Work directory (``Path`` or ``str``).
            step_id: Step id; must be a key of ``step_default_outputs``.
            output_type: Output type, e.g. ``'water_mask'``,
                ``'deglint_image'``.
            ref_img_path: Reference image path; only consulted when
                ``output_type == 'reference_img'``.

        Returns:
            The located path (a string, except that *ref_img_path* is handed
            back unchanged), or ``None`` if nothing was found.
        """
        if step_id not in self.step_default_outputs:
            return None

        default_entry = self.step_default_outputs[step_id]
        rel_path = None
        if isinstance(default_entry, str):
            rel_path = default_entry
        elif isinstance(default_entry, dict):
            # Fall back to the first declared path; an empty dict yields None
            # instead of raising IndexError.
            rel_path = default_entry.get(output_type) or next(
                iter(default_entry.values()), None
            )
        if not rel_path:
            return None

        # Accept plain strings as well as Path objects.
        work_path = Path(work_path)

        # Prefer an actual output recorded earlier.
        recorded = self.step_outputs.get(step_id, {})
        if output_type in recorded:
            candidate = recorded[output_type]
            if output_type not in self.MASK_TYPES or self._is_scientific_mask(candidate):
                return candidate
            # A recorded mask in a non-scientific/temporary format is ignored
            # and the default locations below are tried instead.

        if output_type == 'water_mask':
            mask_path = work_path / rel_path
            if mask_path.exists():
                return str(mask_path)
        elif output_type == 'reference_img':
            if ref_img_path and Path(ref_img_path).exists():
                return ref_img_path
        elif output_type == 'deglint_image':
            deglint_path = work_path / rel_path
            if deglint_path.exists():
                return str(deglint_path)
            deglint_dir = work_path / "3_deglint"
            if deglint_dir.exists():
                # Directory listing order is not guaranteed; sort so the same
                # file is chosen on every call.
                for pattern in ("deglint_*.bsq", "interpolated_*.bsq"):
                    matches = sorted(deglint_dir.glob(pattern))
                    if matches:
                        return str(matches[0])
        elif rel_path.endswith('/'):
            # A trailing slash marks a directory-type default output.
            output_dir = work_path / rel_path.rstrip('/')
            if output_dir.exists() and output_dir.is_dir():
                return str(output_dir)
        else:
            output_path = work_path / rel_path
            if output_path.exists():
                return str(output_path)
        return None

    def _classify_scanned_file(self, file_path, step_id):
        """Return the output type *file_path* represents for *step_id*.

        Matching is done on the lower-cased file name.  Returns ``None`` when
        the file is not a recognised output of that step.
        """
        file_name = file_path.name.lower()
        if step_id == 'step1':
            if 'water_mask' in file_name and self._is_scientific_mask(file_path):
                return 'water_mask'
        elif step_id == 'step2':
            if ('glint' in file_name and 'mask' in file_name
                    and self._is_scientific_mask(file_path)):
                return 'glint_mask'
        elif step_id == 'step3':
            if 'deglint' in file_name:
                return 'deglint_image'
        elif step_id == 'step4_sampling':
            # NOTE(review): no scanned sub-directory maps to 'step4_sampling',
            # so this branch cannot fire today -- confirm the intended mapping.
            if 'processed_data' in file_name:
                return 'processed_data'
            if 'sampling_spectra' in file_name:
                return 'sampling_points'
        elif step_id == 'step5_clean':
            if 'training_spectra' in file_name:
                return 'training_spectra'
        elif step_id == 'step6_feature':
            if 'water_quality_indices' in file_name:
                return 'water_indices'
        elif step_id in ('step9_ml_predict', 'step11_map', 'step12_viz'):
            # NOTE(review): 'step12_viz' is likewise never produced by the
            # sub-directory mapping used by the scan.
            if file_name.endswith('.csv'):
                return 'predictions'
        return None

    def scan_work_directory_for_files(self, work_path):
        """Scan the work directory and discover each step's output files.

        Every known sub-directory that exists is walked recursively.  Each
        regular file is classified by name; recognised files are stored under
        the owning step.  When several files match the same output type, the
        one visited last (in sorted path order, later sub-directories after
        earlier ones) wins.  The result is also merged into
        ``self.step_outputs``.

        Args:
            work_path: Work directory (``Path`` or ``str``).

        Returns:
            dict: ``{step_id: {output_type: path_str}}``.  A step whose
            directory contains files but none that are recognised appears
            with an empty dict.
        """
        work_path = Path(work_path)
        discovered_outputs = {}
        subdirs = {
            '1_water_mask': 'step1',
            '2_Glint_Detection': 'step2',
            '3_deglint': 'step3',
            '5_Data_Cleaning': 'step5_clean',
            '6_Spectral_Feature_Extraction': 'step6_feature',
            '7_Water_Quality_Indices': 'step7_index',
            '8_Supervised_Model_Training': 'step8_ml_train',
            '8_Regression_Modeling': 'step8_ml_train',
            '13_Custom_Regression': 'step13',
            '9_ML_Prediction': 'step9_ml_predict',
            '11_12_13_predictions/Non_Empirical_Prediction': 'step11_map',
            '13_Custom_Regression/Custom_Regression_Prediction': 'step13',
            '14_visualization': 'step13_report',
            '10_geotiff_batch_rendering': 'step11_map'
        }
        for subdir, step_ids in subdirs.items():
            subdir_path = work_path / subdir
            if not subdir_path.exists():
                continue
            # A mapping value may be one step id or a list of them.
            if isinstance(step_ids, str):
                step_ids = [step_ids]
            # Sorted so that "last match wins" is reproducible.
            for file_path in sorted(subdir_path.rglob('*')):
                if not file_path.is_file():
                    continue
                for step_id in step_ids:
                    bucket = discovered_outputs.setdefault(step_id, {})
                    output_type = self._classify_scanned_file(file_path, step_id)
                    if output_type is not None:
                        bucket[output_type] = str(file_path)

        for step_id, outputs in discovered_outputs.items():
            self.step_outputs.setdefault(step_id, {}).update(outputs)
        return discovered_outputs

    def update_step_outputs(self, step_name, work_path):
        """Refresh the recorded output paths of one step.

        The step's default output(s) are resolved against *work_path* and,
        when they exist, written to ``self.step_outputs[step_name]``.  A
        relative path containing ``*`` is treated as a glob on its final
        component and the most recently modified match is recorded.

        A default given as a bare string is recorded under the
        ``_DEFAULT_OUTPUT_TYPE`` key; a dict default keeps its own
        output-type keys.  Unknown step names are ignored.

        Args:
            step_name: Step id to refresh.
            work_path: Work directory (``Path`` or ``str``).
        """
        if step_name not in self.step_default_outputs:
            return

        defaults = self.step_default_outputs[step_name]
        if isinstance(defaults, str):
            # String entries have no output-type name; normalise to a dict so
            # both shapes share one code path.
            defaults = {self._DEFAULT_OUTPUT_TYPE: defaults}

        work_path = Path(work_path)
        recorded = self.step_outputs.setdefault(step_name, {})
        for output_type, relative_path in defaults.items():
            target = work_path / relative_path
            if '*' in relative_path:
                matching_files = list(target.parent.glob(target.name))
                if matching_files:
                    latest_file = max(matching_files, key=lambda p: p.stat().st_mtime)
                    recorded[output_type] = str(latest_file)
            elif target.exists():
                recorded[output_type] = str(target)

    @staticmethod
    def prune_config_for_prediction_mode(config: dict) -> dict:
        """Return a copy of *config* with the training steps disabled.

        For prediction-only mode, each training-related step entry gets
        ``'enabled': False`` (the entry is created if missing); every other
        entry is left as it was.  According to the original author's note the
        pipeline runner skips steps flagged this way and their required input
        files are then not checked -- that behaviour lives outside this class.

        Args:
            config: Complete configuration dict.

        Returns:
            A deep copy of *config* with the flags applied; the argument
            itself is not modified.
        """
        cfg = copy.deepcopy(config)
        training_steps = (
            "step4",
            "step5",
            "step7",
            "step6",
            "step8_non_empirical_modeling",
            "step9",
        )
        for step_id in training_steps:
            cfg.setdefault(step_id, {})["enabled"] = False
        return cfg