diff --git a/src/core/water_quality_inversion_pipeline_GUI.py b/src/core/water_quality_inversion_pipeline_GUI.py index 268508c..3395a6e 100644 --- a/src/core/water_quality_inversion_pipeline_GUI.py +++ b/src/core/water_quality_inversion_pipeline_GUI.py @@ -1661,19 +1661,22 @@ class WaterQualityInversionPipeline: geotransform, projection, width, height, n_bands = self._get_image_geo_info(img_path) print(f"影像尺寸: {width} x {height} x {n_bands}") - # 处理水域掩膜:如果是shp文件路径,需要栅格化 - # 创建一个临时数组用于获取尺寸信息(仅用于掩膜处理) - temp_shape = (height, width) + + # 加载影像数据(Kutser算法需要numpy数组) + image_array, geotransform, projection = self._load_image_as_array(img_path) + print(f"影像尺寸: {image_array.shape}") + + # 处理水域掩膜 mask_for_algorithm = self._prepare_water_mask_for_algorithm( - final_water_mask, temp_shape, geotransform, projection, img_path + final_water_mask, image_array.shape, geotransform, projection, img_path ) - - # 应用Kutser算法:直接传递文件路径,让算法类使用GDAL逐波段处理 + + # 应用Kutser算法:传递numpy数组 # 注意:kutser_shp_path参数已废弃,使用water_mask代替 - kutser = Kutser(img_path, shp_path=None, # 直接传递文件路径 + kutser = Kutser(image_array, shp_path=None, oxy_band=oxy_band, lower_oxy=lower_oxy, upper_oxy=upper_oxy, NIR_band=nir_band, - water_mask=mask_for_algorithm, output_path=output_path) # 传递output_path,算法类会保存 + water_mask=mask_for_algorithm, output_path=output_path) corrected_bands = kutser.get_corrected_bands() # 检查算法类是否已保存文件(可能保存为.bsq格式) @@ -1765,18 +1768,21 @@ class WaterQualityInversionPipeline: geotransform, projection, width, height, n_bands = self._get_image_geo_info(img_path) print(f"影像尺寸: {width} x {height} x {n_bands}") - # 处理水域掩膜:如果是shp文件路径,需要栅格化 - # 创建一个临时数组用于获取尺寸信息(仅用于掩膜处理) - temp_shape = (height, width) + + # 加载影像数据(Hedley算法需要numpy数组) + image_array, geotransform, projection = self._load_image_as_array(img_path) + print(f"影像尺寸: {image_array.shape}") + + # 处理水域掩膜 mask_for_algorithm = self._prepare_water_mask_for_algorithm( - final_water_mask, temp_shape, geotransform, projection, img_path + final_water_mask, image_array.shape, geotransform, projection, img_path ) - - # 应用Hedley算法:直接传递文件路径,让算法类使用GDAL逐波段处理 + + # 应用Hedley算法:传递numpy数组 # 注意:hedley_shp_path参数已废弃,使用water_mask代替 - hedley = Hedley(img_path, shp_path=None, # 直接传递文件路径 + hedley = Hedley(image_array, shp_path=None, NIR_band=hedley_nir_band, water_mask=mask_for_algorithm, - output_path=output_path) # 传递output_path,算法类会保存 + output_path=output_path) corrected_bands = hedley.get_corrected_bands() # 检查算法类是否已保存文件(可能保存为.bsq格式) diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index ffe87a0..31486f8 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -365,8 +365,8 @@ class WorkerThread(QThread): step_config.pop('prediction_csv_dir', None) step_config.pop('recursive_csv_scan', None) - # step5:输出路径由管线固定到工作目录,GUI 占位字段勿传入 - if step_name == 'step5': + # 拦截掉底层不需要的 GUI 专用输出路径字段,防止报错 + if step_name in ['step2', 'step3', 'step4', 'step5', 'step7', 'step8', 'step8_5', 'step8_75']: step_config.pop('output_path', None) # 参数名映射:将GUI中的参数名映射为pipeline方法期望的参数名 @@ -872,6 +872,44 @@ class FileSelectWidget(QWidget): self.setLayout(layout) + def update_from_config(self, work_dir=None, pipeline=None): + """ + 从 Step1Panel 自动填充水域掩膜路径,实现上下游数据流转 + + Args: + work_dir: 工作目录路径 + pipeline: Pipeline 实例(未使用,保留接口兼容性) + """ + # 保存工作目录引用 + if work_dir: + self.work_dir = work_dir + elif hasattr(self, 'work_dir') and self.work_dir: + pass # 保持现有工作目录 + else: + self.work_dir = None + + # 从 Step1 界面读取水域掩膜路径 + main_window = self.window() + if hasattr(main_window, 'step1_panel'): + if main_window.step1_panel.use_ndwi_radio.isChecked(): + # NDWI模式,读取输出框的路径 + mask_path = main_window.step1_panel.output_file.get_path() + else: + # 导入现有模式,读取输入框的路径 + mask_path = main_window.step1_panel.mask_file.get_path() + + if mask_path: + self.water_mask_file.set_path(mask_path) + + # 自动填充输出路径(基于工作目录) + if self.work_dir: + output_dir = os.path.join(self.work_dir, "3_deglint") + os.makedirs(output_dir, exist_ok=True) + default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/') + self.output_file.set_path(default_output_path) + else: + self.output_file.set_path("") + def browse_file(self): """浏览文件""" # 获取当前输入框中的文本,尝试从中提取初始目录 @@ -1249,7 +1287,12 @@ class Step2Panel(QWidget): # 检测方法 self.method = QComboBox() - self.method.addItems(['otsu', 'zscore', 'percentile', 'iqr', 'adaptive', 'multi_band']) + self.method.addItem("Otsu 阈值法", "otsu") + self.method.addItem("Z-Score 方法", "zscore") + self.method.addItem("百分位数法", "percentile") + self.method.addItem("IQR 四分位距法", "iqr") + self.method.addItem("自适应阈值法", "adaptive") + self.method.addItem("多波段综合法", "multi_band") params_layout.addRow("检测方法:", self.method) # 最大连通域面积 @@ -1290,13 +1333,13 @@ class Step2Panel(QWidget): layout.addStretch() self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def get_config(self): """获取配置""" config = { 'img_path': self.img_file.get_path(), 'glint_wave': self.glint_wave.value(), - 'method': self.method.currentText(), + 'method': self.method.currentData(), # 使用 currentData() 获取英文ID } if self.max_area.value() > 0: config['max_area'] = self.max_area.value() @@ -1319,7 +1362,7 @@ class Step2Panel(QWidget): if 'glint_wave' in config: self.glint_wave.setValue(config['glint_wave']) if 'method' in config: - idx = self.method.findText(config['method']) + idx = self.method.findData(config['method']) # 使用 findData() if idx >= 0: self.method.setCurrentIndex(idx) if 'max_area' in config: @@ -1333,7 +1376,7 @@ class Step2Panel(QWidget): def update_from_config(self, work_dir=None, pipeline=None): """ - 从全局配置/Pipeline 自动填充路径,实现上下游数据流转 + 从全局配置/Pipeline 或 Step1Panel 自动填充路径,实现上下游数据流转 Args: work_dir: 工作目录路径 @@ -1347,11 +1390,26 @@ class Step2Panel(QWidget): else: self.work_dir = None - # 1. 自动填充水域掩膜路径(从步骤1的输出获取) + # 1. 尝试从 Pipeline 获取 + mask_path = None if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path: - self.water_mask_file.set_path(pipeline.water_mask_path) + mask_path = pipeline.water_mask_path + + # 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取(关键修复) + main_window = self.window() + if not mask_path and hasattr(main_window, 'step1_panel'): + if main_window.step1_panel.use_ndwi_radio.isChecked(): + # NDWI模式,读取输出框的路径 + mask_path = main_window.step1_panel.output_file.get_path() + else: + # 导入现有模式,读取输入框的路径 + mask_path = main_window.step1_panel.mask_file.get_path() - # 2. 自动填充输出路径(基于工作目录) + # 填充获取到的路径 + if mask_path: + self.water_mask_file.set_path(mask_path) + + # 3. 自动填充输出路径(基于工作目录) if self.work_dir: # 生成输出耀斑掩膜的标准路径:workspace/2_glint_mask/glint_mask_out.dat output_dir = os.path.join(self.work_dir, "2_glint_mask") @@ -1414,8 +1472,10 @@ class Step3Panel(QWidget): method_layout = QVBoxLayout() self.method = QComboBox() - self.method.addItems(['goodman', 'kutser', 'hedley', 'sugar']) - self.method.currentTextChanged.connect(self.on_method_changed) + for text, data in [('Goodman方法', 'goodman'), ('Kutser方法', 'kutser'), + ('Hedley方法', 'hedley'), ('SUGAR算法', 'sugar')]: + self.method.addItem(text, data) + self.method.currentIndexChanged.connect(self._on_method_changed) method_layout.addWidget(self.method) method_group.setLayout(method_layout) @@ -1540,8 +1600,10 @@ class Step3Panel(QWidget): interp_layout.addRow("", self.interpolate_zeros) self.interp_method = QComboBox() - self.interp_method.addItems(['nearest', 'bilinear', 'spline', 'kriging']) - self.interp_method.setCurrentText('bilinear') + for text, data in [('最近邻插值', 'nearest'), ('双线性插值', 'bilinear'), + ('样条插值', 'spline'), ('克里金插值', 'kriging')]: + self.interp_method.addItem(text, data) + self.interp_method.setCurrentIndex(1) # 默认双线性插值 interp_layout.addRow("插值方法:", self.interp_method) interp_group.setLayout(interp_layout) @@ -1568,22 +1630,87 @@ class Step3Panel(QWidget): layout.addStretch() self.setLayout(layout) + # 信号连接:影像文件路径变化时动态更新波段范围 + self.img_file.line_edit.textChanged.connect(self._update_band_ranges) - def on_method_changed(self, method): + + def _update_band_ranges(self, file_path): + """根据选择的影像动态限制波段索引的输入范围""" + import os + from osgeo import gdal + + if not file_path or not os.path.isfile(file_path): + return + + try: + dataset = gdal.Open(file_path) + if dataset is None: + return + raster_count = dataset.RasterCount + max_band = max(0, raster_count - 1) + self.nir_lower.setMaximum(max_band) + self.nir_upper.setMaximum(max_band) + self.oxy_band.setMaximum(max_band) + self.nir_band.setMaximum(max_band) + self.hedley_nir_band.setMaximum(max_band) + dataset = None + except Exception: + pass + + def update_from_config(self, work_dir=None, pipeline=None): + """ + 从 Step1Panel 自动填充水域掩膜路径,实现上下游数据流转 + + Args: + work_dir: 工作目录路径 + pipeline: Pipeline 实例(未使用,保留接口兼容性) + """ + # 保存工作目录引用 + if work_dir: + self.work_dir = work_dir + elif hasattr(self, 'work_dir') and self.work_dir: + pass # 保持现有工作目录 + else: + self.work_dir = None + + # 从 Step1 界面读取水域掩膜路径 + main_window = self.window() + if hasattr(main_window, 'step1_panel'): + if main_window.step1_panel.use_ndwi_radio.isChecked(): + # NDWI模式,读取输出框的路径 + mask_path = main_window.step1_panel.output_file.get_path() + else: + # 导入现有模式,读取输入框的路径 + mask_path = main_window.step1_panel.mask_file.get_path() + + if mask_path: + self.water_mask_file.set_path(mask_path) + + # 自动填充输出路径(基于工作目录) + if self.work_dir: + output_dir = os.path.join(self.work_dir, "3_deglint") + os.makedirs(output_dir, exist_ok=True) + default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/') + self.output_file.set_path(default_output_path) + else: + self.output_file.set_path("") + + def _on_method_changed(self, index): """方法改变时更新参数显示""" - self.goodman_group.setVisible(method == 'goodman') - self.kutser_group.setVisible(method == 'kutser') - self.hedley_group.setVisible(method == 'hedley') - self.sugar_group.setVisible(method == 'sugar') + method_id = self.method.currentData() + self.goodman_group.setVisible(method_id == 'goodman') + self.kutser_group.setVisible(method_id == 'kutser') + self.hedley_group.setVisible(method_id == 'hedley') + self.sugar_group.setVisible(method_id == 'sugar') def get_config(self): """获取配置""" config = { 'img_path': self.img_file.get_path(), - 'method': self.method.currentText(), + 'method': self.method.currentData(), # 使用 currentData() 获取英文ID 'enabled': self.enable_checkbox.isChecked(), 'interpolate_zeros': self.interpolate_zeros.isChecked(), - 'interpolation_method': self.interp_method.currentText(), + 'interpolation_method': self.interp_method.currentData(), # 使用 currentData() } water_mask_path = self.water_mask_file.get_path() if water_mask_path: @@ -1592,7 +1719,7 @@ class Step3Panel(QWidget): if output_path: config['output_path'] = output_path - method = self.method.currentText() + method = self.method.currentData() # 使用 currentData() if method == 'goodman': config['nir_lower'] = self.nir_lower.value() @@ -1613,7 +1740,7 @@ class Step3Panel(QWidget): config['sugar_iter'] = self.sugar_iter.value() if self.sugar_iter.value() > 0 else None config['sugar_sigma'] = self.sugar_sigma.value() config['sugar_estimate_background'] = self.sugar_estimate_background.isChecked() - config['sugar_glint_mask_method'] = self.sugar_glint_mask_method.currentText() + config['sugar_glint_mask_method'] = self.sugar_glint_mask_method.currentData() config['sugar_termination_thresh'] = self.sugar_termination_thresh.value() # 解析bounds字符串 try: @@ -1633,7 +1760,7 @@ class Step3Panel(QWidget): if 'output_path' in config: self.output_file.set_path(config['output_path']) if 'method' in config: - idx = self.method.findText(config['method']) + idx = self.method.findData(config['method']) # 使用 findData() if idx >= 0: self.method.setCurrentIndex(idx) if 'enabled' in config: @@ -1641,7 +1768,7 @@ class Step3Panel(QWidget): if 'interpolate_zeros' in config: self.interpolate_zeros.setChecked(config['interpolate_zeros']) if 'interpolation_method' in config: - idx = self.interp_method.findText(config['interpolation_method']) + idx = self.interp_method.findData(config['interpolation_method']) # 使用 findData() if idx >= 0: self.interp_method.setCurrentIndex(idx) @@ -1677,7 +1804,7 @@ class Step3Panel(QWidget): if 'sugar_estimate_background' in config: self.sugar_estimate_background.setChecked(config['sugar_estimate_background']) if 'sugar_glint_mask_method' in config: - idx = self.sugar_glint_mask_method.findText(config['sugar_glint_mask_method']) + idx = self.sugar_glint_mask_method.findData(config['sugar_glint_mask_method']) # 使用 findData() if idx >= 0: self.sugar_glint_mask_method.setCurrentIndex(idx) if 'sugar_termination_thresh' in config: @@ -1937,7 +2064,7 @@ class Step5Panel(QWidget): layout.addStretch() self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def get_config(self): """获取配置""" config = { @@ -2375,7 +2502,7 @@ class Step6Panel(QWidget): layout.addStretch() self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def create_ml_page(self): """创建机器学习模型页面""" layout = QVBoxLayout() @@ -2676,7 +2803,7 @@ class Step7Panel(QWidget): layout.addStretch() self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def get_config(self): """获取配置""" config = { @@ -3528,7 +3655,7 @@ class ChartViewerDialog(QDialog): layout.addLayout(btn_layout) self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def display_image(self, image_path): """显示图片""" self.figure.clear() @@ -4869,7 +4996,7 @@ class ChartBrowserDialog(QDialog): layout.addLayout(btn_layout) self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def show_chart(self, index): """显示指定索引的图表""" if 0 <= index < len(self.chart_files): @@ -5057,7 +5184,7 @@ class Step6_5Panel(QWidget): layout.addStretch() self.setLayout(layout) - + # 信号连接:影像文件路径变化时动态更新波段范围 def get_config(self): """获取配置""" selected_algorithms = [ @@ -6381,6 +6508,10 @@ class WaterQualityGUI(QMainWindow): if index == 1: self.step2_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + # Step3 切换时自动填充数据流转路径 + elif index == 2: + self.step3_panel.update_from_config(work_dir=self.work_dir) + def apply_stylesheet(self): """应用样式表 - 应用现代化设计风格""" # 应用主样式表