From 1949711cda06aad01efb5adbed00b125752464dc Mon Sep 17 00:00:00 2001 From: DXC Date: Wed, 17 Jun 2026 15:35:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8F=90=E5=8F=96=20WorkspaceManag?= =?UTF-8?q?er=EF=BC=8C=E5=B0=86=E6=96=87=E4=BB=B6=E6=89=AB=E6=8F=8F?= =?UTF-8?q?=E4=B8=8E=E8=B7=AF=E5=BE=84=E4=B8=9A=E5=8A=A1=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E4=BB=8E=E4=B8=BB=20GUI=20=E8=A7=A3=E8=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/workspace_manager.py | 231 ++++++++++++++++++++++++++++ src/gui/water_quality_gui.py | 275 +++------------------------------- 2 files changed, 254 insertions(+), 252 deletions(-) create mode 100644 src/core/workspace_manager.py diff --git a/src/core/workspace_manager.py b/src/core/workspace_manager.py new file mode 100644 index 0000000..52f86e0 --- /dev/null +++ b/src/core/workspace_manager.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +工作空间管理器 + +负责工作目录文件扫描、步骤输出路径发现、配置裁剪等业务逻辑, +与 GUI 组件解耦,不直接引用任何 UI 类。 +""" + +import copy +from pathlib import Path + + +class WorkspaceManager: + """管理步骤默认输出路径、文件扫描与配置裁剪""" + + # 白名单:科学数据格式后缀 + SCIENTIFIC_EXTENSIONS = {'.dat', '.tif', '.tiff', '.shp'} + # 临时文件关键词黑名单 + TMP_KEYWORDS = ('__tmp', '_tmp') + # 掩膜类型集合 + MASK_TYPES = {'water_mask', 'glint_mask', 'boundary_mask'} + + def __init__(self): + self.step_default_outputs = { + 'step1': "1_water_mask/water_mask_from_ndwi.dat", + 'step2': "2_Glint_Detection/severe_glint_area.dat", + 'step3': "3_deglint/deglint_goodman.bsq", + 'step4_sampling': "4_sampling/sampling_spectra.csv", + 'step5_clean': "5_Data_Cleaning/processed_data.csv", + 'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv", + 'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv", + 'step8_ml_train': "8_Supervised_Model_Training/", + 'step9_ml_predict': "9_ML_Prediction/", + 'step10_watercolor': "10_WaterIndex_Images/", + 'step11_map': "14_visualization/" + } + self.step_outputs = {} + + @staticmethod + def _is_scientific_mask(path_str): + """白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式""" + p = Path(path_str) + name_lower = str(path_str).lower() + if any(kw in name_lower for kw in WorkspaceManager.TMP_KEYWORDS): + return False + return p.suffix.lower() in WorkspaceManager.SCIENTIFIC_EXTENSIONS + + def find_step_output(self, work_path, step_id, output_type, ref_img_path=None): + """查找指定步骤的输出文件 + + Args: + work_path: 工作目录 Path 对象 + step_id: 步骤 ID + output_type: 输出类型(如 'water_mask', 'deglint_image' 等) + ref_img_path: 参考影像路径(仅 output_type='reference_img' 时需要) + + Returns: + 找到的文件路径字符串,或 None + """ + if step_id not in self.step_default_outputs: + return None + + raw = self.step_default_outputs[step_id] + + rel_path = None + if isinstance(raw, str): + rel_path = raw + elif isinstance(raw, dict): + rel_path = raw.get(output_type) or list(raw.values())[0] + + if not rel_path: + return None + + # 特殊处理:从 step_outputs 记录中查找实际输出路径 + if step_id in self.step_outputs: + actual_outputs = self.step_outputs[step_id] + if output_type in actual_outputs: + candidate = actual_outputs[output_type] + if output_type in self.MASK_TYPES and not self._is_scientific_mask(candidate): + pass + else: + return candidate + + if output_type == 'water_mask': + if rel_path: + mask_path = work_path / rel_path + if mask_path.exists(): + return str(mask_path) + elif output_type == 'reference_img': + if ref_img_path and Path(ref_img_path).exists(): + return ref_img_path + elif output_type == 'deglint_image': + if rel_path: + deglint_path = work_path / rel_path + if deglint_path.exists(): + return str(deglint_path) + deglint_dir = work_path / "3_deglint" + if deglint_dir.exists(): + for file_path in deglint_dir.glob("deglint_*.bsq"): + return str(file_path) + for file_path in deglint_dir.glob("interpolated_*.bsq"): + return str(file_path) + elif rel_path: + if rel_path.endswith('/'): + output_path = work_path / rel_path.rstrip('/') + if output_path.exists() and output_path.is_dir(): + return str(output_path) + else: + output_path = work_path / rel_path + if output_path.exists(): + return str(output_path) + + return None + + def scan_work_directory_for_files(self, work_path): + """扫描工作目录,自动发现各步骤的输出文件 + + Returns: + discovered_outputs: dict, {step_id: {output_type: path_str}} + """ + discovered_outputs = {} + + subdirs = { + '1_water_mask': 'step1', + '2_Glint_Detection': 'step2', + '3_deglint': 'step3', + '5_Data_Cleaning': 'step5_clean', + '6_Spectral_Feature_Extraction': 'step6_feature', + '7_Water_Quality_Indices': 'step7_index', + '8_Supervised_Model_Training': 'step8_ml_train', + '8_Regression_Modeling': 'step8_ml_train', + '13_Custom_Regression': 'step13', + '9_ML_Prediction': 'step9_ml_predict', + '11_12_13_predictions/Non_Empirical_Prediction': 'step11_map', + '13_Custom_Regression/Custom_Regression_Prediction': 'step13', + '14_visualization': 'step13_report', + '10_geotiff_batch_rendering': 'step11_map' + } + + for subdir, step_ids in subdirs.items(): + subdir_path = work_path / subdir + if not subdir_path.exists(): + continue + + if isinstance(step_ids, str): + step_ids = [step_ids] + + for file_path in subdir_path.rglob('*'): + if file_path.is_file(): + file_name = file_path.name.lower() + + for step_id in step_ids: + if step_id not in discovered_outputs: + discovered_outputs[step_id] = {} + + if 'water_mask' in file_name and step_id == 'step1': + if self._is_scientific_mask(file_path): + discovered_outputs[step_id]['water_mask'] = str(file_path) + elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2': + if self._is_scientific_mask(file_path): + discovered_outputs[step_id]['glint_mask'] = str(file_path) + elif 'deglint' in file_name and step_id == 'step3': + discovered_outputs[step_id]['deglint_image'] = str(file_path) + elif 'processed_data' in file_name and step_id == 'step4_sampling': + discovered_outputs[step_id]['processed_data'] = str(file_path) + elif 'training_spectra' in file_name and step_id == 'step5_clean': + discovered_outputs[step_id]['training_spectra'] = str(file_path) + elif 'water_quality_indices' in file_name and step_id == 'step6_feature': + discovered_outputs[step_id]['water_indices'] = str(file_path) + elif 'sampling_spectra' in file_name and step_id == 'step4_sampling': + discovered_outputs[step_id]['sampling_points'] = str(file_path) + elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']: + discovered_outputs[step_id]['predictions'] = str(file_path) + + for step_id, outputs in discovered_outputs.items(): + if step_id not in self.step_outputs: + self.step_outputs[step_id] = {} + self.step_outputs[step_id].update(outputs) + + return discovered_outputs + + def update_step_outputs(self, step_name, work_path): + """更新指定步骤的输出路径记录""" + if step_name not in self.step_default_outputs: + return + + step_outputs = self.step_default_outputs[step_name] + + for output_type, relative_path in step_outputs.items(): + if '*' in relative_path: + pattern_path = work_path / relative_path.replace('*', '*') + matching_files = list(pattern_path.parent.glob(pattern_path.name)) + if matching_files: + latest_file = max(matching_files, key=lambda p: p.stat().st_mtime) + self.step_outputs[step_name][output_type] = str(latest_file) + else: + output_path = work_path / relative_path + if output_path.exists(): + self.step_outputs[step_name][output_type] = str(output_path) + + @staticmethod + def prune_config_for_prediction_mode(config: dict) -> dict: + """Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。 + + 被禁用的 step dict 中统一写入 'enabled': False, + 这些配置最终传给 PipelineRunner,Runner 会跳过它们。 + 同时,被跳过的步骤的 required_input_files 在 build_missing_items + 中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。 + + Args: + config: 完整配置字典(来自 get_current_config) + + Returns: + 裁剪后的 config(深拷贝,原 config 不被修改) + """ + cfg = copy.deepcopy(config) + + training_steps = [ + "step4", + "step5", + "step7", + "step6", + "step8_non_empirical_modeling", + "step9", + ] + for step_id in training_steps: + step_cfg = cfg.setdefault(step_id, {}) + step_cfg["enabled"] = False + + return cfg diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index 7fecefd..88417a1 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -158,6 +158,7 @@ from src.gui.core.worker_thread import ( from src.gui.core.preflight_dialog import PreflightDialog from src.gui.core.pipeline_mode_dialog import PipelineModeDialog from src.gui.core.viz_thread import VisualizationWorkerThread, _viz_training_spectra_csv_path +from src.core.workspace_manager import WorkspaceManager class WaterQualityGUI(QMainWindow): @@ -183,10 +184,10 @@ class WaterQualityGUI(QMainWindow): # 训练数据模式状态 self.has_training_data = True # 默认有训练数据 - # 步骤输出路径记录 - self.step_outputs = {} # 记录每个步骤的输出路径 + # 工作空间管理器(文件扫描、路径发现、配置裁剪) + self.workspace_manager = WorkspaceManager() - # 定义步骤依赖关系和标准输出路径 + # 定义步骤依赖关系 self._init_step_dependencies() self.init_ui() @@ -198,22 +199,7 @@ class WaterQualityGUI(QMainWindow): QTimer.singleShot(100, self.init_workspace) def _init_step_dependencies(self): - """初始化步骤依赖关系和标准输出路径""" - # 定义每个步骤的标准输出路径模式(相对于工作目录) - self.step_default_outputs = { - 'step1': "1_water_mask/water_mask_from_ndwi.dat", - 'step2': "2_Glint_Detection/severe_glint_area.dat", - 'step3': "3_deglint/deglint_goodman.bsq", - 'step4_sampling': "4_sampling/sampling_spectra.csv", - 'step5_clean': "5_Data_Cleaning/processed_data.csv", - 'step6_feature': "6_Spectral_Feature_Extraction/training_spectra.csv", - 'step7_index': "7_Water_Quality_Indices/training_spectra_indices.csv", - 'step8_ml_train': "8_Supervised_Model_Training/", - 'step9_ml_predict': "9_ML_Prediction/", - 'step10_watercolor': "10_WaterIndex_Images/", - 'step11_map': "14_visualization/" - } - + """初始化步骤依赖关系""" # 依赖关系字典结构: # '当前步骤ID': { '依赖参数名': ('上游步骤ID', '上游输出类型/Key', '当前步骤接收该路径的组件属性名') } self.step_dependencies = { @@ -1081,7 +1067,11 @@ class WaterQualityGUI(QMainWindow): dependencies = self.step_dependencies[step_id] filled_count = 0 - + + ref_img_path = None + if hasattr(self, 'step1_panel'): + ref_img_path = self.step1_panel.img_file.get_path() + for input_field, (dep_step, output_type, panel_attr) in dependencies.items(): # 检查面板是否有对应的属性 if not hasattr(panel, panel_attr): @@ -1101,7 +1091,7 @@ class WaterQualityGUI(QMainWindow): continue # 查找依赖步骤的输出文件 - output_path = self.find_step_output(work_path, dep_step, output_type) + output_path = self.workspace_manager.find_step_output(work_path, dep_step, output_type, ref_img_path=ref_img_path) if output_path and Path(output_path).exists(): # ★ 兼容 FileSelectWidget 与原生 QLineEdit @@ -1132,173 +1122,6 @@ class WaterQualityGUI(QMainWindow): } return panel_map.get(step_id) - def find_step_output(self, work_path, step_id, output_type): - """查找指定步骤的输出文件""" - if step_id not in self.step_default_outputs: - return None - - raw = self.step_default_outputs[step_id] - - # ★ 兼容扁平化后的纯字符串路径格式 - rel_path = None - if isinstance(raw, str): - rel_path = raw - elif isinstance(raw, dict): - rel_path = raw.get(output_type) or list(raw.values())[0] - - if not rel_path: - return None - - # ★ 掩膜类型列表:这些类型只接受科学数据格式 - mask_types = {'water_mask', 'glint_mask', 'boundary_mask'} - # ★ 白名单机制:只允许 .dat .tif .tiff .shp,拒绝其他一切格式 - scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'} - # ★ 临时文件关键词黑名单 - tmp_keywords = ('__tmp', '_tmp') - - def _is_scientific_mask(path_str): - """白名单判断:只有 .dat .tif .tiff .shp 才算科学数据格式""" - p = Path(path_str) - name_lower = str(path_str).lower() - # 拒绝临时文件 - if any(kw in name_lower for kw in tmp_keywords): - return False - # 白名单校验 - return p.suffix.lower() in scientific_extensions - - # 特殊处理:从step_outputs记录中查找实际输出路径 - if step_id in self.step_outputs: - actual_outputs = self.step_outputs[step_id] - if output_type in actual_outputs: - candidate = actual_outputs[output_type] - # ★ 掩膜类型白名单二次校验:不在白名单内的一律拒绝 - if output_type in mask_types and not _is_scientific_mask(candidate): - # 非科学格式被拒绝,不使用 step_outputs 中的值 - pass - else: - return candidate - - # 根据输出类型查找对应的文件 - if output_type == 'water_mask': - # 水域掩膜:直接用统一路径 - if rel_path: - mask_path = work_path / rel_path - if mask_path.exists(): - return str(mask_path) - elif output_type == 'reference_img': - # 参考影像:从step1的配置中获取用户输入的影像路径 - if hasattr(self, 'step1_panel'): - img_path = self.step1_panel.img_file.get_path() - if img_path and Path(img_path).exists(): - return img_path - elif output_type == 'deglint_image': - # 去耀斑影像:直接用统一路径 - if rel_path: - deglint_path = work_path / rel_path - if deglint_path.exists(): - return str(deglint_path) - # 还要检查 Kutser 算法输出与插值方法生成的文件 - deglint_dir = work_path / "3_deglint" - if deglint_dir.exists(): - for file_path in deglint_dir.glob("deglint_*.bsq"): - return str(file_path) - for file_path in deglint_dir.glob("interpolated_*.bsq"): - return str(file_path) - elif rel_path: - # 直接匹配的输出类型(统一使用 rel_path) - if rel_path.endswith('/'): - # 是目录 - output_path = work_path / rel_path.rstrip('/') - if output_path.exists() and output_path.is_dir(): - return str(output_path) - else: - # 是文件 - output_path = work_path / rel_path - if output_path.exists(): - return str(output_path) - - return None - - def scan_work_directory_for_files(self, work_path): - """扫描工作目录,自动发现各步骤的输出文件""" - discovered_outputs = {} - - # 扫描各个子目录 - subdirs = { - '1_water_mask': 'step1', - '2_Glint_Detection': 'step2', - '3_deglint': 'step3', - '5_Data_Cleaning': 'step5_clean', - '6_Spectral_Feature_Extraction': 'step6_feature', - '7_Water_Quality_Indices': 'step7_index', - '8_Supervised_Model_Training': 'step8_ml_train', - '8_Regression_Modeling': 'step8_ml_train', - '13_Custom_Regression': 'step13', - '9_ML_Prediction': 'step9_ml_predict', - '11_12_13_predictions/Non_Empirical_Prediction': 'step11_map', - '13_Custom_Regression/Custom_Regression_Prediction': 'step13', - '14_visualization': 'step13_report', - '10_geotiff_batch_rendering': 'step11_map' - } - - for subdir, step_ids in subdirs.items(): - subdir_path = work_path / subdir - if not subdir_path.exists(): - continue - - if isinstance(step_ids, str): - step_ids = [step_ids] - - # 扫描该目录下的文件 - for file_path in subdir_path.rglob('*'): - if file_path.is_file(): - file_name = file_path.name.lower() - - # 根据文件名模式判断输出类型 - for step_id in step_ids: - if step_id not in discovered_outputs: - discovered_outputs[step_id] = {} - - # ★ 掩膜文件白名单过滤:只有 .dat .tif .tiff .shp 才通过,拒绝 .hdr .xml .png 等 - scientific_extensions = {'.dat', '.tif', '.tiff', '.shp'} - tmp_keywords = ('__tmp', '_tmp') - - def _is_scientific_mask(path_str): - """白名单判断:拒绝 .hdr .xml 临时文件等,只接受科学数据格式""" - p = Path(path_str) - name_lower = str(path_str).lower() - if any(kw in name_lower for kw in tmp_keywords): - return False - return p.suffix.lower() in scientific_extensions - - # 匹配不同的文件类型 - if 'water_mask' in file_name and step_id == 'step1': - if _is_scientific_mask(file_path): - discovered_outputs[step_id]['water_mask'] = str(file_path) - elif 'glint' in file_name and 'mask' in file_name and step_id == 'step2': - if _is_scientific_mask(file_path): - discovered_outputs[step_id]['glint_mask'] = str(file_path) - elif 'deglint' in file_name and step_id == 'step3': - discovered_outputs[step_id]['deglint_image'] = str(file_path) - elif 'processed_data' in file_name and step_id == 'step4_sampling': - discovered_outputs[step_id]['processed_data'] = str(file_path) - elif 'training_spectra' in file_name and step_id == 'step5_clean': - discovered_outputs[step_id]['training_spectra'] = str(file_path) - elif 'water_quality_indices' in file_name and step_id == 'step6_feature': - discovered_outputs[step_id]['water_indices'] = str(file_path) - elif 'sampling_spectra' in file_name and step_id == 'step4_sampling': - discovered_outputs[step_id]['sampling_points'] = str(file_path) - elif file_name.endswith('.csv') and step_id in ['step9_ml_predict', 'step11_map', 'step12_viz']: - discovered_outputs[step_id]['predictions'] = str(file_path) - - # 更新内部记录 - for step_id, outputs in discovered_outputs.items(): - if step_id not in self.step_outputs: - self.step_outputs[step_id] = {} - self.step_outputs[step_id].update(outputs) - - return discovered_outputs - def auto_populate_all_steps(self): """自动填充所有步骤的输入路径""" work_dir = getattr(self, 'work_dir', './work_dir') @@ -1309,7 +1132,7 @@ class WaterQualityGUI(QMainWindow): return # 首先扫描工作目录发现已有的输出文件 - self.scan_work_directory_for_files(work_path) + self.workspace_manager.scan_work_directory_for_files(work_path) step_order = ['step2', 'step3', 'step4_sampling', 'step5_clean', 'step6_feature', 'step7_index', 'step8_ml_train', 'step9_ml_predict', 'step11_map', 'step12_viz', 'step13_report'] @@ -1612,41 +1435,6 @@ class WaterQualityGUI(QMainWindow): return True - # ------------------------------------------------------------------ - # ★ 全流程模式动态裁剪 - # ------------------------------------------------------------------ - - def _prune_config_for_prediction_mode(self, config: dict) -> dict: - """Prediction-only 模式:禁用训练相关步骤,保留预测和成图步骤。 - - 被禁用的 step dict 中统一写入 'enabled': False, - 这些配置最终传给 PipelineRunner,Runner 会跳过它们。 - 同时,被跳过的步骤的 required_input_files 在 build_missing_items - 中不会被检查,从而自然规避了"CSV 缺失"等训练模式下的误报。 - - Args: - config: 完整配置字典(来自 get_current_config) - - Returns: - 裁剪后的 config(深拷贝,原 config 不被修改) - """ - cfg = copy.deepcopy(config) - - # 在每个训练相关步骤的 dict 中写入 enabled=False - training_steps = [ - "step4", # CSV 实测数据清洗 - "step5", # 实测点光谱提取(→ training_csv_path) - "step7", # ML 监督建模 - "step6", # 水质指数计算(辅助训练) - "step8_non_empirical_modeling", # 非经验回归建模 - "step9", # 自定义回归建模 - ] - for step_id in training_steps: - step_cfg = cfg.setdefault(step_id, {}) - step_cfg["enabled"] = False - - return cfg - def run_full_pipeline(self): """运行完整流程""" if not PIPELINE_AVAILABLE: @@ -1665,7 +1453,7 @@ class WaterQualityGUI(QMainWindow): # ── 1) 运行前智能预检与自动回填(硬盘已有产物自动跳过) ── work_path = Path(work_dir) self.log_message("正在进行运行前环境预检与自动扫描...", "info") - self.scan_work_directory_for_files(work_path) + self.workspace_manager.scan_work_directory_for_files(work_path) self.auto_populate_all_steps() self.log_message("✓ 预检完成:已扫描工作目录并自动回填已落盘的产物", "info") @@ -1685,7 +1473,7 @@ class WaterQualityGUI(QMainWindow): # ── 2.1) ★ 根据模式动态裁剪配置 ── if selected_mode == "prediction_only": - config = self._prune_config_for_prediction_mode(config) + config = self.workspace_manager.prune_config_for_prediction_mode(config) self.log_message("[模式选择] 已裁剪训练相关步骤(step4/5/7/8),进入仅预测模式", "info") # ── 3) ★ 一次性全预检 + 用户交互式决策 ── @@ -1798,38 +1586,21 @@ class WaterQualityGUI(QMainWindow): work_path = Path(work_dir) # 根据步骤名称和约定路径,记录实际输出 - if step_name not in self.step_outputs: - self.step_outputs[step_name] = {} - + if step_name not in self.workspace_manager.step_outputs: + self.workspace_manager.step_outputs[step_name] = {} + # 扫描工作目录,更新该步骤的输出路径 - self.update_step_outputs(step_name, work_path) + self.workspace_manager.update_step_outputs(step_name, work_path) # 自动填充依赖该步骤输出的后续步骤 self.auto_populate_dependent_steps(step_name) - def update_step_outputs(self, step_name, work_path): - """更新指定步骤的输出路径记录""" - if step_name not in self.step_default_outputs: - return - - step_outputs = self.step_default_outputs[step_name] - - for output_type, relative_path in step_outputs.items(): - if '*' in relative_path: - # 处理通配符路径 - pattern_path = work_path / relative_path.replace('*', '*') - matching_files = list(pattern_path.parent.glob(pattern_path.name)) - if matching_files: - # 选择最新的文件 - latest_file = max(matching_files, key=lambda p: p.stat().st_mtime) - self.step_outputs[step_name][output_type] = str(latest_file) - else: - output_path = work_path / relative_path - if output_path.exists(): - self.step_outputs[step_name][output_type] = str(output_path) - def auto_populate_dependent_steps(self, completed_step): """自动填充依赖于已完成步骤的后续步骤""" + ref_img_path = None + if hasattr(self, 'step1_panel'): + ref_img_path = self.step1_panel.img_file.get_path() + for step_id, dependencies in self.step_dependencies.items(): for input_field, (dep_step, output_type, panel_attr) in dependencies.items(): if dep_step == completed_step: @@ -1841,7 +1612,7 @@ class WaterQualityGUI(QMainWindow): if not file_widget.get_path().strip(): work_dir = getattr(self, 'work_dir', './work_dir') work_path = Path(work_dir) - output_path = self.find_step_output(work_path, dep_step, output_type) + output_path = self.workspace_manager.find_step_output(work_path, dep_step, output_type, ref_img_path=ref_img_path) if output_path and Path(output_path).exists(): file_widget.set_path(output_path) self.log_message(f"步骤完成后自动填充 {step_id}.{input_field}: {output_path}", "info")