refactor: 提取 WorkspaceManager,将文件扫描与路径业务逻辑从主 GUI 解耦
This commit is contained in:
231
src/core/workspace_manager.py
Normal file
231
src/core/workspace_manager.py
Normal file
@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工作空间管理器
|
||||
|
||||
负责工作目录文件扫描、步骤输出路径发现、配置裁剪等业务逻辑,
|
||||
与 GUI 组件解耦,不直接引用任何 UI 类。
|
||||
"""
|
||||
|
||||
import copy
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
"""管理步骤默认输出路径、文件扫描与配置裁剪"""
|
||||
|
||||
# 白名单:科学数据格式后缀
|
||||
SCIENTIFIC_EXTENSIONS = {'.dat', '.tif', '.tiff', '.shp'}
|
||||
# 临时文件关键词黑名单
|
||||
TMP_KEYWORDS = ('__tmp', '_tmp')
|
||||
# 掩膜类型集合
|
||||
MASK_TYPES = {'water_mask', 'glint_mask', 'boundary_mask'}
|
||||
|
||||
def __init__(self):
|
||||
self.step_default_outputs = {
|
||||
'step1': "1_water_mask/water_mask_from_ndwi.dat",
|
||||
'step2': "2_Glint_Detection/severe_glint_area.dat",
|
||||
'step3': "3_deglint/deglint_goodman.bsq",
|
||||
'step4_sampling': "4_sampling/sampling_spectra.csv",
|
||||
'step5_clean': "5_Data_Cleaning/processed_data.csv",
|
||||
'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv",
|
||||
'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv",
|
||||
'step8_ml_train': "8_Supervised_Model_Training/",
|
||||
'step9_ml_predict': "9_ML_Prediction/",
|
||||
'step10_watercolor': "10_WaterIndex_Images/",
|
||||
'step11_map': "14_visualization/"
|
||||
}
|
||||
self.step_outputs = {}
|
||||
|
||||
@staticmethod
|
||||
def _is_scientific_mask(path_str):
|
||||
"""白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式"""
|
||||
p = Path(path_str)
|
||||
name_lower = str(path_str).lower()
|
||||
if any(kw in name_lower for kw in WorkspaceManager.TMP_KEYWORDS):
|
||||
return False
|
||||
return p.suffix.lower() in WorkspaceManager.SCIENTIFIC_EXTENSIONS
|
||||
|
||||
def find_step_output(self, work_path, step_id, output_type, ref_img_path=None):
|
||||
"""查找指定步骤的输出文件
|
||||
|
||||
Args:
|
||||
work_path: 工作目录 Path 对象
|
||||
step_id: 步骤 ID
|
||||
output_type: 输出类型(如 'water_mask', 'deglint_image' 等)
|
||||
ref_img_path: 参考影像路径(仅 output_type='reference_img' 时需要)
|
||||
|
||||
Returns:
|
||||
找到的文件路径字符串,或 None
|
||||
"""
|
||||
if step_id not in self.step_default_outputs:
|
||||
return None
|
||||
|
||||
raw = self.step_default_outputs[step_id]
|
||||
|
||||
rel_path = None
|
||||
if isinstance(raw, str):
|
||||
rel_path = raw
|
||||
elif isinstance(raw, dict):
|
||||
rel_path = raw.get(output_type) or list(raw.values())[0]
|
||||
|
||||
if not rel_path:
|
||||
return None
|
||||
|
||||
# 特殊处理:从 step_outputs 记录中查找实际输出路径
|
||||
if step_id in self.step_outputs:
|
||||
actual_outputs = self.step_outputs[step_id]
|
||||
if output_type in actual_outputs:
|
||||
candidate = actual_outputs[output_type]
|
||||
if output_type in self.MASK_TYPES and not self._is_scientific_mask(candidate):
|
||||
pass
|
||||
else:
|
||||
return candidate
|
||||
|
||||
if output_type == 'water_mask':
|
||||
if rel_path:
|
||||
mask_path = work_path / rel_path
|
||||
if mask_path.exists():
|
||||
return str(mask_path)
|
||||
elif output_type == 'reference_img':
|
||||
if ref_img_path and Path(ref_img_path).exists():
|
||||
return ref_img_path
|
||||
elif output_type == 'deglint_image':
|
||||
if rel_path:
|
||||
deglint_path = work_path / rel_path
|
||||
if deglint_path.exists():
|
||||
return str(deglint_path)
|
||||
deglint_dir = work_path / "3_deglint"
|
||||
if deglint_dir.exists():
|
||||
for file_path in deglint_dir.glob("deglint_*.bsq"):
|
||||
return str(file_path)
|
||||
for file_path in deglint_dir.glob("interpolated_*.bsq"):
|
||||
return str(file_path)
|
||||
elif rel_path:
|
||||
if rel_path.endswith('/'):
|
||||
output_path = work_path / rel_path.rstrip('/')
|
||||
if output_path.exists() and output_path.is_dir():
|
||||
return str(output_path)
|
||||
else:
|
||||
output_path = work_path / rel_path
|
||||
if output_path.exists():
|
||||
return str(output_path)
|
||||
|
||||
return None
|
||||
|
||||
def scan_work_directory_for_files(self, work_path):
|
||||
"""扫描工作目录,自动发现各步骤的输出文件
|
||||
|
||||
Returns:
|
||||
discovered_outputs: dict, {step_id: {output_type: path_str}}
|
||||
"""
|
||||
discovered_outputs = {}
|
||||
|
||||
subdirs = {
|
||||
'1_water_mask': 'step1',
|
||||
'2_Glint_Detection': 'step2',
|
||||
'3_deglint': 'step3',
|
||||
'5_Data_Cleaning': 'step5_clean',
|
||||
'6_Spectral_Feature_Extraction': 'step6_feature',
|
||||
'7_Water_Quality_Indices': 'step7_index',
|
||||
'8_Supervised_Model_Training': 'step8_ml_train',
|
||||
'8_Regression_Modeling': 'step8_ml_train',
|
||||
'13_Custom_Regression': 'step13',
|
||||
'9_ML_Prediction': 'step9_ml_predict',
|
||||
'11_12_13_predictions/Non_Empirical_Prediction': 'step11_map',
|
||||
'13_Custom_Regression/Custom_Regression_Prediction': 'step13',
|
||||
'14_visualization': 'step13_report',
|
||||
'10_geotiff_batch_rendering': 'step11_map'
|
||||
}
|
||||
|
||||
for subdir, step_ids in subdirs.items():
|
||||
subdir_path = work_path / subdir
|
||||
if not subdir_path.exists():
|
||||
continue
|
||||
|
||||
if isinstance(step_ids, str):
|
||||
step_ids = [step_ids]
|
||||
|
||||
for file_path in subdir_path.rglob('*'):
|
||||
if file_path.is_file():
|
||||
file_name = file_path.name.lower()
|
||||
|
||||
for step_id in step_ids:
|
||||
if step_id not in discovered_outputs:
|
||||
discovered_outputs[step_id] = {}
|
||||
|
||||
if 'water_mask' in file_name and step_id == 'step1':
|
||||
if self._is_scientific_mask(file_path):
|
||||
discovered_outputs[step_id]['water_mask'] = str(file_path)
|
||||
elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2':
|
||||
if self._is_scientific_mask(file_path):
|
||||
discovered_outputs[step_id]['glint_mask'] = str(file_path)
|
||||
elif 'deglint' in file_name and step_id == 'step3':
|
||||
discovered_outputs[step_id]['deglint_image'] = str(file_path)
|
||||
elif 'processed_data' in file_name and step_id == 'step4_sampling':
|
||||
discovered_outputs[step_id]['processed_data'] = str(file_path)
|
||||
elif 'training_spectra' in file_name and step_id == 'step5_clean':
|
||||
discovered_outputs[step_id]['training_spectra'] = str(file_path)
|
||||
elif 'water_quality_indices' in file_name and step_id == 'step6_feature':
|
||||
discovered_outputs[step_id]['water_indices'] = str(file_path)
|
||||
elif 'sampling_spectra' in file_name and step_id == 'step4_sampling':
|
||||
discovered_outputs[step_id]['sampling_points'] = str(file_path)
|
||||
elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']:
|
||||
discovered_outputs[step_id]['predictions'] = str(file_path)
|
||||
|
||||
for step_id, outputs in discovered_outputs.items():
|
||||
if step_id not in self.step_outputs:
|
||||
self.step_outputs[step_id] = {}
|
||||
self.step_outputs[step_id].update(outputs)
|
||||
|
||||
return discovered_outputs
|
||||
|
||||
def update_step_outputs(self, step_name, work_path):
|
||||
"""更新指定步骤的输出路径记录"""
|
||||
if step_name not in self.step_default_outputs:
|
||||
return
|
||||
|
||||
step_outputs = self.step_default_outputs[step_name]
|
||||
|
||||
for output_type, relative_path in step_outputs.items():
|
||||
if '*' in relative_path:
|
||||
pattern_path = work_path / relative_path.replace('*', '*')
|
||||
matching_files = list(pattern_path.parent.glob(pattern_path.name))
|
||||
if matching_files:
|
||||
latest_file = max(matching_files, key=lambda p: p.stat().st_mtime)
|
||||
self.step_outputs[step_name][output_type] = str(latest_file)
|
||||
else:
|
||||
output_path = work_path / relative_path
|
||||
if output_path.exists():
|
||||
self.step_outputs[step_name][output_type] = str(output_path)
|
||||
|
||||
@staticmethod
|
||||
def prune_config_for_prediction_mode(config: dict) -> dict:
|
||||
"""Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。
|
||||
|
||||
被禁用的 step dict 中统一写入 'enabled': False,
|
||||
这些配置最终传给 PipelineRunner,Runner 会跳过它们。
|
||||
同时,被跳过的步骤的 required_input_files 在 build_missing_items
|
||||
中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。
|
||||
|
||||
Args:
|
||||
config: 完整配置字典(来自 get_current_config)
|
||||
|
||||
Returns:
|
||||
裁剪后的 config(深拷贝,原 config 不被修改)
|
||||
"""
|
||||
cfg = copy.deepcopy(config)
|
||||
|
||||
training_steps = [
|
||||
"step4",
|
||||
"step5",
|
||||
"step7",
|
||||
"step6",
|
||||
"step8_non_empirical_modeling",
|
||||
"step9",
|
||||
]
|
||||
for step_id in training_steps:
|
||||
step_cfg = cfg.setdefault(step_id, {})
|
||||
step_cfg["enabled"] = False
|
||||
|
||||
return cfg
|
||||
Reference in New Issue
Block a user