From 0493ba7916b235e7bb035c2f6203dab111fd83cc Mon Sep 17 00:00:00 2001 From: DXC Date: Wed, 10 Jun 2026 17:13:51 +0800 Subject: [PATCH] =?UTF-8?q?fix(map):=20GeoTIFF=20=E5=8F=AF=E8=A7=86?= =?UTF-8?q?=E5=8C=96=E5=85=A8=E9=93=BE=E8=B7=AF=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/gui/panels/step14_panel.py | 339 ++++++++++++++++++++-- src/gui/water_quality_gui.py | 10 +- src/postprocessing/map.py | 515 ++++++++++++++++++++++++++++++--- 3 files changed, 800 insertions(+), 64 deletions(-) diff --git a/src/gui/panels/step14_panel.py b/src/gui/panels/step14_panel.py index e6a24e7..025d0a0 100644 --- a/src/gui/panels/step14_panel.py +++ b/src/gui/panels/step14_panel.py @@ -13,7 +13,8 @@ from PyQt5.QtCore import Qt, QThread, pyqtSignal from PyQt5.QtWidgets import ( QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout, QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox, - QRadioButton, QButtonGroup, QMessageBox, QFileDialog, + QRadioButton, QButtonGroup, QMessageBox, QFileDialog, QComboBox, + QProgressBar, ) from src.gui.components.custom_widgets import FileSelectWidget @@ -33,6 +34,7 @@ class Step14BatchThread(QThread): finished_ok = pyqtSignal(int) failed = pyqtSignal(str) log_message = pyqtSignal(str, str) + progress = pyqtSignal(int, int) # (current, total) def __init__(self, work_dir: str, csv_paths: List[str], step14_kwargs: dict, output_dir_optional: Optional[str]): super().__init__() @@ -58,6 +60,7 @@ class Step14BatchThread(QThread): pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir) n = len(self.csv_paths) for i, csv_p in enumerate(self.csv_paths): + self.progress.emit(i + 1, n) self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info") kw = {**self.step14_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True} if self.output_dir_optional: @@ -78,6 +81,75 @@ class Step14BatchThread(QThread): pass +class Step14GeoTIFFBatchThread(QThread): + """GeoTIFF 批量渲染:遍历文件夹下所有 .tif/.bsq 逐一渲染成分布图 PNG。""" + + finished_ok = pyqtSignal(int) + failed = pyqtSignal(str) + log_message = pyqtSignal(str, str) + progress = pyqtSignal(int, int) # (current, total) + + def __init__( + self, + tif_paths: List[str], + output_dir: str, + boundary_shp_path: Optional[str], + input_crs: str, + output_crs: str, + ): + super().__init__() + self.tif_paths = tif_paths + self.output_dir = output_dir + self.boundary_shp_path = boundary_shp_path + self.input_crs = input_crs + self.output_crs = output_crs + + def run(self): + mpl_prev = None + try: + import matplotlib + mpl_prev = matplotlib.get_backend() + except Exception: + pass + try: + import matplotlib.pyplot as plt + plt.switch_backend("Agg") + except Exception: + mpl_prev = None + try: + from src.postprocessing.map import ContentMapper + mapper = ContentMapper() + n = len(self.tif_paths) + for i, tif_path in enumerate(self.tif_paths): + self.progress.emit(i + 1, n) + tif_name = Path(tif_path).stem + output_png = str(Path(self.output_dir) / f"{tif_name}_map.png") + self.log_message.emit(f"GeoTIFF 渲染 [{i + 1}/{n}] {tif_name}", "info") + try: + mapper.visualize_raster( + raster_tif_path=tif_path, + output_file=output_png, + boundary_shp_path=self.boundary_shp_path, + nodata_value=-9999.0, + figsize=(14, 10), + title=f"水色指数专题图 - {tif_name}", + alpha=0.9, + ) + except Exception as vis_err: + self.log_message.emit(f" ⚠️ 渲染失败,跳过: {vis_err}", "warning") + continue + self.finished_ok.emit(n) + except Exception as e: + self.failed.emit(f"{e}\n{traceback.format_exc()}") + finally: + if mpl_prev: + try: + import matplotlib.pyplot as plt + plt.switch_backend(mpl_prev) + except Exception: + pass + + class Step14Panel(QWidget): """步骤14:分布图生成""" def __init__(self, parent=None): @@ -90,7 +162,7 @@ class Step14Panel(QWidget): hint = QLabel( "独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。" - "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。" + "GeoTIFF 栅格模式下,亦支持批量渲染步骤8输出的所有水色指数 GeoTIFF 文件。" ) hint.setWordWrap(True) hint.setStyleSheet( @@ -109,6 +181,17 @@ class Step14Panel(QWidget): mode_row.addStretch() layout.addLayout(mode_row) + # ---------- 渲染模式选择器(CSV vs GeoTIFF) ---------- + render_row = QHBoxLayout() + render_row.addWidget(QLabel("渲染模式:")) + self.render_mode_combo = QComboBox() + self.render_mode_combo.addItems(["CSV 插值模式", "GeoTIFF 栅格模式"]) + self.render_mode_combo.setMinimumWidth(180) + self.render_mode_combo.currentTextChanged.connect(self._toggle_input_mode) + render_row.addWidget(self.render_mode_combo) + render_row.addStretch() + layout.addLayout(render_row) + # ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ---------- radio_style = """ QRadioButton { @@ -156,6 +239,32 @@ class Step14Panel(QWidget): self._folder_row_widget.setLayout(folder_row) layout.addWidget(self._folder_row_widget) + # ---------- GeoTIFF 栅格文件选择器 ---------- + self.geotiff_file = FileSelectWidget( + "水色指数 GeoTIFF:", + "GeoTIFF Files (*.tif);;All Files (*.*)" + ) + self.geotiff_file.line_edit.setPlaceholderText("选择步骤8输出的水色指数 GeoTIFF 文件…") + self.geotiff_file.setVisible(False) + layout.addWidget(self.geotiff_file) + + # ---------- GeoTIFF 文件夹批量选择器(GeoTIFF + 文件夹模式时显示) ---------- + geotiff_dir_row = QHBoxLayout() + self.geotiff_dir_label = QLabel("水色指数目录:") + self.geotiff_dir_label.setMinimumWidth(120) + self.geotiff_dir_edit = QLineEdit() + self.geotiff_dir_edit.setPlaceholderText("选择 8_WaterIndex_Images 文件夹(批量渲染)…") + geotiff_dir_btn = QPushButton("浏览…") + geotiff_dir_btn.setMaximumWidth(80) + geotiff_dir_btn.clicked.connect(self.browse_geotiff_dir) + geotiff_dir_row.addWidget(self.geotiff_dir_label) + geotiff_dir_row.addWidget(self.geotiff_dir_edit, 1) + geotiff_dir_row.addWidget(geotiff_dir_btn) + self._geotiff_dir_widget = QWidget() + self._geotiff_dir_widget.setLayout(geotiff_dir_row) + self._geotiff_dir_widget.setVisible(False) + layout.addWidget(self._geotiff_dir_widget) + self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)") layout.addWidget(self.recursive_csv_cb) @@ -213,6 +322,14 @@ class Step14Panel(QWidget): self.run_button.clicked.connect(self.run_step) layout.addWidget(self.run_button) + # 批量渲染进度条 + self.progress_bar = QProgressBar() + self.progress_bar.setVisible(False) + self.progress_bar.setMinimum(0) + self.progress_bar.setMaximum(100) + self.progress_bar.setValue(0) + layout.addWidget(self.progress_bar) + layout.addStretch() self.setLayout(layout) @@ -223,13 +340,25 @@ class Step14Panel(QWidget): self._toggle_input_mode() # 根据默认值设置初始显示状态 def _toggle_input_mode(self): - """槽函数:根据单选框状态动态显示/隐藏对应的输入组件。""" + """槽函数:根据渲染模式和输入模式动态显示/隐藏对应的输入组件。""" + geotiff_mode = self.render_mode_combo.currentText() == "GeoTIFF 栅格模式" folder_mode = self.mode_folder_rb.isChecked() - # 单个 CSV 模式:显示单文件选择,隐藏文件夹选择 - self.prediction_csv_file.setVisible(not folder_mode) - # 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择 - self._folder_row_widget.setVisible(folder_mode) - self.recursive_csv_cb.setVisible(folder_mode) + + # CSV 插值模式 + if not geotiff_mode: + self.prediction_csv_file.setVisible(not folder_mode) + self._folder_row_widget.setVisible(folder_mode) + self.recursive_csv_cb.setVisible(folder_mode) + self.geotiff_file.setVisible(False) + self._geotiff_dir_widget.setVisible(False) + # GeoTIFF 栅格模式 + else: + self.prediction_csv_file.setVisible(False) + self._folder_row_widget.setVisible(False) + self.recursive_csv_cb.setVisible(False) + # GeoTIFF + 文件夹批量 → 显示文件夹选择器;否则 → 显示单文件选择器 + self.geotiff_file.setVisible(not folder_mode) + self._geotiff_dir_widget.setVisible(folder_mode) def _get_default_work_dir(self): """获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取""" @@ -259,6 +388,27 @@ class Step14Panel(QWidget): files = sorted(root.glob("*.csv")) return [str(p) for p in files if p.is_file()] + def browse_geotiff_dir(self): + """浏览 GeoTIFF 文件夹(批量模式)""" + default = self._get_default_work_dir() + if default: + default = os.path.join(default, "8_WaterIndex_Images") + d = QFileDialog.getExistingDirectory( + self, "选择水色指数 GeoTIFF 文件夹", default + ) + if d: + self.geotiff_dir_edit.setText(d) + + def _collect_tif_paths_from_folder(self) -> List[str]: + """扫描所选文件夹,收集所有 .tif 和 .bsq 文件路径""" + folder = (self.geotiff_dir_edit.text() or "").strip() + if not folder or not os.path.isdir(folder): + return [] + root = Path(folder) + tif_files = sorted(root.glob("*.tif")) + bsq_files = sorted(root.glob("*.bsq")) + return [str(p) for p in tif_files + bsq_files if p.is_file()] + def _step14_base_pipeline_kwargs(self) -> dict: return { 'boundary_shp_path': self.boundary_file.get_path(), @@ -273,11 +423,15 @@ class Step14Panel(QWidget): pred_csv = (self.prediction_csv_file.get_path() or "").strip() folder_mode = self.mode_folder_rb.isChecked() pred_dir = (self.prediction_csv_dir_edit.text() or "").strip() + geotiff_path = (self.geotiff_file.get_path() or "").strip() config = { 'step14_batch_mode': 'folder' if folder_mode else 'single', + 'render_mode': self.render_mode_combo.currentText(), 'prediction_csv_dir': pred_dir if pred_dir else None, 'recursive_csv_scan': self.recursive_csv_cb.isChecked(), 'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None), + 'geotiff_path': geotiff_path if geotiff_path else None, + 'geotiff_dir': (self.geotiff_dir_edit.text() or "").strip() or None, 'boundary_shp_path': self.boundary_file.get_path(), 'resolution': self.resolution.value(), 'input_crs': self.input_crs.text(), @@ -299,12 +453,20 @@ class Step14Panel(QWidget): self.mode_folder_rb.setChecked(True) else: self.mode_single_rb.setChecked(True) + render_mode = config.get('render_mode', 'CSV 插值模式') + idx = self.render_mode_combo.findText(render_mode) + if idx >= 0: + self.render_mode_combo.setCurrentIndex(idx) if config.get('prediction_csv_dir'): self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir'])) if 'recursive_csv_scan' in config: self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan'])) if 'prediction_csv_path' in config and config['prediction_csv_path']: self.prediction_csv_file.set_path(str(config['prediction_csv_path'])) + if 'geotiff_path' in config and config['geotiff_path']: + self.geotiff_file.set_path(str(config['geotiff_path'])) + if 'geotiff_dir' in config and config['geotiff_dir']: + self.geotiff_dir_edit.setText(str(config['geotiff_dir'])) if 'boundary_shp_path' in config: self.boundary_file.set_path(config['boundary_shp_path']) if 'resolution' in config: @@ -428,9 +590,19 @@ class Step14Panel(QWidget): if not existing_boundary and possible_shp: self.boundary_file.set_path(str(possible_shp)) elif not existing_boundary: - # 未找到 .shp 时清空并提示用户手动选择矢量文件 self.boundary_file.set_path("") print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。") + + # 6. 自动探测 Step 8 输出的水色指数 GeoTIFF(GeoTIFF 渲染模式) + step8_out_dir = Path(self.work_dir) / "8_WaterIndex_Images" if self.work_dir else None + if step8_out_dir and step8_out_dir.is_dir(): + # GeoTIFF 批量模式:填充目录供批量渲染 + if not (self.geotiff_dir_edit.text() or "").strip(): + self.geotiff_dir_edit.setText(str(step8_out_dir)) + # GeoTIFF 单文件模式:默认选中第一个 + tif_files = sorted(step8_out_dir.glob("*.tif")) + if tif_files and not (self.geotiff_file.get_path() or "").strip(): + self.geotiff_file.set_path(str(tif_files[0])) except Exception as e: import traceback print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}") @@ -445,6 +617,36 @@ class Step14Panel(QWidget): if dir_path: self.output_dir.set_path(dir_path) + def _start_batch_run(self, csv_list, work_dir, base_kw, out_dir_opt, parent): + """封装 CSV 批量启动逻辑,统一处理信号连接和进度条""" + self.run_button.setEnabled(False) + self.progress_bar.setVisible(True) + self.progress_bar.setValue(0) + self._batch_thread = Step14BatchThread(work_dir, csv_list, base_kw, out_dir_opt) + main_win = parent + + def _batch_log(msg, lvl): + if hasattr(main_win, "log_message"): + main_win.log_message(msg, lvl) + + def _on_progress(cur, total): + if total > 0: + self.progress_bar.setMaximum(total) + self.progress_bar.setValue(cur) + self.progress_bar.setFormat(f"{cur}/{total} 张 (%p%)") + + self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection) + self._batch_thread.progress.connect(_on_progress, Qt.QueuedConnection) + self._batch_thread.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection) + self._batch_thread.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection) + self._batch_thread.finished.connect( + lambda: (self.run_button.setEnabled(True), self.progress_bar.setVisible(False)), + Qt.QueuedConnection, + ) + self._batch_thread.start() + if hasattr(parent, "log_message"): + parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info") + def run_step(self): """独立运行步骤14""" if self._batch_thread and self._batch_thread.isRunning(): @@ -468,37 +670,126 @@ class Step14Panel(QWidget): return if self.mode_folder_rb.isChecked(): - csv_list = self._collect_csv_paths_from_folder() - if not csv_list: + # -------- CSV 插值批量 -------- + if self.render_mode_combo.currentText() != "GeoTIFF 栅格模式": + csv_list = self._collect_csv_paths_from_folder() + if not csv_list: + QMessageBox.warning( + self, + "输入验证失败", + "所选文件夹中未找到 .csv 文件,或目录无效。\n" + "可勾选「包含子文件夹」以递归扫描。", + ) + return + if not PIPELINE_AVAILABLE: + QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。") + return + work_dir = getattr(parent, "work_dir", None) or "./work_dir" + work_dir = str(work_dir) + base_kw = self._step14_base_pipeline_kwargs() + out_dir_opt = (self.output_dir.get_path() or "").strip() or None + self._start_batch_run(csv_list, work_dir, base_kw, out_dir_opt, parent) + return + + # -------- GeoTIFF 栅格批量 -------- + tif_list = self._collect_tif_paths_from_folder() + if not tif_list: QMessageBox.warning( self, "输入验证失败", - "所选文件夹中未找到 .csv 文件,或目录无效。\n" - "可勾选「包含子文件夹」以递归扫描。", + "所选文件夹中未找到 .tif / .bsq 文件,\n" + "请确认目录包含步骤8输出的水色指数 GeoTIFF 文件。", ) return - if not PIPELINE_AVAILABLE: - QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。") - return - work_dir = getattr(parent, "work_dir", None) or "./work_dir" - work_dir = str(work_dir) - base_kw = self._step14_base_pipeline_kwargs() - out_dir_opt = (self.output_dir.get_path() or "").strip() or None + + out_dir = (self.output_dir.get_path() or "").strip() + if not out_dir: + out_dir = os.path.join(self._get_default_work_dir(), "14_visualization") + os.makedirs(out_dir, exist_ok=True) + self.run_button.setEnabled(False) - self._batch_thread = Step14BatchThread(work_dir, csv_list, base_kw, out_dir_opt) + self.progress_bar.setVisible(True) + self.progress_bar.setValue(0) + self._batch_thread = Step14GeoTIFFBatchThread( + tif_paths=tif_list, + output_dir=out_dir, + boundary_shp_path=boundary_shp_path, + input_crs=self.input_crs.text(), + output_crs=self.output_crs.text(), + ) main_win = parent def _batch_log(msg, lvl): if hasattr(main_win, "log_message"): main_win.log_message(msg, lvl) + def _on_progress(cur, total): + if total > 0: + pct = int(cur / total * 100) + self.progress_bar.setMaximum(total) + self.progress_bar.setValue(cur) + self.progress_bar.setFormat(f"{cur}/{total} 张 (%p%)") + self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection) + self._batch_thread.progress.connect(_on_progress, Qt.QueuedConnection) self._batch_thread.finished_ok.connect(self._on_step14_batch_ok, Qt.QueuedConnection) self._batch_thread.failed.connect(self._on_step14_batch_fail, Qt.QueuedConnection) - self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection) + self._batch_thread.finished.connect( + lambda: (self.run_button.setEnabled(True), self.progress_bar.setVisible(False)), + Qt.QueuedConnection, + ) self._batch_thread.start() if hasattr(parent, "log_message"): - parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info") + parent.log_message(f"GeoTIFF 批量渲染:共 {len(tif_list)} 个文件 → {out_dir}", "info") + return + + # -------- GeoTIFF 栅格单文件模式 -------- + if self.render_mode_combo.currentText() == "GeoTIFF 栅格模式": + geotiff_path = (self.geotiff_file.get_path() or "").strip() + if not geotiff_path: + QMessageBox.warning(self, "输入验证失败", "请选择水色指数 GeoTIFF 文件") + return + if not os.path.isfile(geotiff_path): + QMessageBox.warning(self, "输入验证失败", f"GeoTIFF 文件不存在:\n{geotiff_path}") + return + + boundary_shp_path = self.boundary_file.get_path() + input_crs = self.input_crs.text() + output_crs = self.output_crs.text() + + # 构造输出路径 + out_dir = (self.output_dir.get_path() or "").strip() + if not out_dir: + out_dir = os.path.join(self._get_default_work_dir(), "14_visualization") + os.makedirs(out_dir, exist_ok=True) + tif_name = Path(geotiff_path).stem + output_png = os.path.join(out_dir, f"{tif_name}_rendered.png") + + self.run_button.setEnabled(False) + try: + from src.postprocessing.map import ContentMapper + mapper = ContentMapper() + result_path = mapper.visualize_raster( + raster_tif_path=geotiff_path, + output_file=output_png, + boundary_shp_path=boundary_shp_path if boundary_shp_path else None, + nodata_value=-9999.0, + figsize=(14, 10), + title=f"水色指数专题图 - {tif_name}", + alpha=0.9, + ) + self.run_button.setEnabled(True) + QMessageBox.information( + self, "完成", + f"GeoTIFF 栅格渲染完成!\n{result_path}" + ) + if hasattr(parent, "log_message"): + parent.log_message(f"Step14 GeoTIFF 渲染完成 → {result_path}", "info") + except Exception as e: + self.run_button.setEnabled(True) + QMessageBox.critical(self, "渲染失败", f"{e}\n{traceback.format_exc()[:500]}") + if hasattr(parent, "log_message"): + parent.log_message(str(e), "error") return prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip() @@ -517,6 +808,7 @@ class Step14Panel(QWidget): parent.run_single_step('step14', {'step14': config}) def _on_step14_batch_ok(self, n: int): + self.progress_bar.setVisible(False) QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。") parent = self.parent() while parent and not hasattr(parent, "log_message"): @@ -525,6 +817,7 @@ class Step14Panel(QWidget): parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info") def _on_step14_batch_fail(self, err: str): + self.progress_bar.setVisible(False) QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}") parent = self.parent() while parent and not hasattr(parent, "log_message"): diff --git a/src/gui/water_quality_gui.py b/src/gui/water_quality_gui.py index 46b45f3..ea58e6a 100644 --- a/src/gui/water_quality_gui.py +++ b/src/gui/water_quality_gui.py @@ -121,7 +121,7 @@ from src.gui.panels.step4_panel import Step4Panel from src.gui.panels.step5_panel import Step5Panel from src.gui.panels.step6_panel import Step6Panel # was step8_panel from src.gui.panels.step7_panel import Step7Panel # was step6_panel -from src.gui.panels.step8_qaa_panel import Step8QAAPanel # QAA 物理反演(非经验模型) +from src.gui.panels.step8_waterindex_panel import Step8WaterIndexPanel # 水色指数反演 from src.gui.panels.step9_concentration_panel import Step9ConcentrationPanel # 浓度反演 from src.gui.panels.step10_panel import Step10Panel # was step7_panel from src.gui.panels.step11_ml_panel import Step11MlPanel # ML prediction (step11_ml) @@ -1968,8 +1968,8 @@ class WaterQualityGUI(QMainWindow): self.step7_panel = Step7Panel() self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "监督建模") - self.step8_qaa_panel = Step8QAAPanel() - self.step_stack.addTab(self.create_scroll_area(self.step8_qaa_panel), QIcon(self.get_icon_path("6.png")), "物理推导(非经验模型)") + self.step8_waterindex_panel = Step8WaterIndexPanel() + self.step_stack.addTab(self.create_scroll_area(self.step8_waterindex_panel), QIcon(self.get_icon_path("6.png")), "水色指数反演") self.step9_concentration_panel = Step9ConcentrationPanel() self.step_stack.addTab(self.create_scroll_area(self.step9_concentration_panel), QIcon(self.get_icon_path("6.png")), "浓度反演") @@ -2215,9 +2215,9 @@ class WaterQualityGUI(QMainWindow): elif index == 6: self.step7_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) - # Step8 QAA 物理反演切换时自动填充光谱数据和输出路径 + # Step8 水色指数反演切换时自动填充光谱数据和输出路径 elif index == 7: - self.step8_qaa_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) + self.step8_waterindex_panel.update_from_config(work_dir=self.work_dir, pipeline=self.pipeline) # Step9 浓度反演切换时自动填充 QAA 结果和输出路径 elif index == 8: diff --git a/src/postprocessing/map.py b/src/postprocessing/map.py index 7862f35..422f688 100644 --- a/src/postprocessing/map.py +++ b/src/postprocessing/map.py @@ -1,6 +1,9 @@ import pandas as pd import numpy as np import geopandas as gpd +from osgeo import gdal +from pathlib import Path +from typing import Optional, Tuple from pyproj import CRS, Transformer import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -146,44 +149,104 @@ class ContentMapper: def _get_colormap(self, param_name=None): """ - 根据参数名称获取对应的colormap - - Parameters: - ----------- + 根据参数名称获取对应的colormap(支持精确匹配、模糊匹配) + + Parameters + ---------- param_name : str, optional 参数名称。如果为None或不在映射中,则随机选择一个colormap - - Returns: - -------- + + Returns + ------- cmap : str 颜色映射名称 """ - # 打印调试信息 print(f"[调试] _get_colormap 被调用,param_name={param_name}") print(f"[调试] 当前字典中的键: {list(self.params_cmap.keys())}") - + if param_name: - # 首先尝试精确匹配(区分大小写) + # 精确匹配(区分大小写) if param_name in self.params_cmap: cmap = self.params_cmap[param_name] print(f"使用参数 '{param_name}' 对应的颜色映射: {cmap}") return cmap - - # 如果精确匹配失败,尝试不区分大小写的匹配 + + # 不区分大小写匹配 param_name_upper = param_name.upper() for key in self.params_cmap.keys(): if key.upper() == param_name_upper: cmap = self.params_cmap[key] print(f"使用参数 '{key}' (不区分大小写匹配 '{param_name}') 对应的颜色映射: {cmap}") return cmap - - # 如果都不匹配,随机选择 + + # ── 模糊匹配(关键字包含检测)─────────────────────────── + pn_upper = param_name.upper() + pn_lower = param_name.lower() + + # 蓝藻 / BGA / Phycocyanin → YlGn(蓝绿色系) + if any(k in pn_upper for k in ('BGA', 'PHYCO', 'CYAN', '蓝藻', '藻蓝')): + cmap = self.params_cmap.get('BGA', 'YlGn') + print(f"模糊匹配 BGA/Phycocyanin → '{cmap}'") + return cmap + + # 叶绿素 / Chlorophyll / Chl → YlGn(绿色系) + if any(k in pn_upper for k in ('CHL', '叶绿素', 'CHLORO')): + cmap = self.params_cmap.get('Chl_a', 'YlGn') + print(f"模糊匹配 Chl/叶绿素 → '{cmap}'") + return cmap + + # CDOM / 有色溶解有机物 + if any(k in pn_upper for k in ('CDOM', '色DOM', '有色溶解')): + cmap = self.params_cmap.get('CDOM', 'BrBG') + print(f"模糊匹配 CDOM → '{cmap}'") + return cmap + + # 悬浮物 / TSM / SS → YlOrBr(黄棕系) + if any(k in pn_upper for k in ('TSM', 'SS', '悬浮物', '总悬浮')): + cmap = self.params_cmap.get('TSM', 'YlOrBr') + print(f"模糊匹配 TSM/悬浮物 → '{cmap}'") + return cmap + + # 透明度 / SD / Secchi → Blues(蓝色系) + if any(k in pn_upper for k in ('SD', 'SECCHI', '透明度', '透明')): + cmap = self.params_cmap.get('SD', 'Blues') + print(f"模糊匹配 SD/透明度 → '{cmap}'") + return cmap + + # 氨氮 / NH4 / NH3 → Oranges + if any(k in pn_upper for k in ('NH4', 'NH3', '氨氮', '氨')): + cmap = 'Oranges' + print(f"模糊匹配 NH4/氨氮 → '{cmap}'") + return cmap + + # 总磷 / TP / 总氮 / TN → RdYlGn + if any(k in pn_upper for k in ('TP', '总磷')): + cmap = 'RdYlGn_r' + print(f"模糊匹配 TP/总磷 → '{cmap}'") + return cmap + if any(k in pn_upper for k in ('TN', '总氮')): + cmap = 'RdYlGn_r' + print(f"模糊匹配 TN/总氮 → '{cmap}'") + return cmap + + # 高浊度 / Turbidity → PuBu(紫蓝系) + if any(k in pn_upper for k in ('TURBIDITY', '浊度', 'TURB')): + cmap = 'PuBu' + print(f"模糊匹配 Turbidity/浊度 → '{cmap}'") + return cmap + + # 溶解氧 / DO → cool(蓝白冷色) + if any(k in pn_upper for k in ('DO', '溶解氧', 'DISSOLVED')): + cmap = 'cool' + print(f"模糊匹配 DO/溶解氧 → '{cmap}'") + return cmap + + # 仍不匹配 → 随机 cmap = random.choice(self.available_cmaps) print(f"警告: 参数 '{param_name}' 不在映射中,随机选择颜色映射: {cmap}") print(f"可用的参数名: {list(self.params_cmap.keys())}") return cmap else: - # 随机选择一个colormap cmap = random.choice(self.available_cmaps) print(f"未指定参数名称,随机选择颜色映射: {cmap}") return cmap @@ -1347,16 +1410,28 @@ class ContentMapper: print(f"图片显示失败: {e}") def add_north_arrow(self, ax, bounds): - """添加指北针(左上角)- 复杂罗盘样式""" - minx, miny, maxx, maxy = bounds + """添加指北针(右上角)- 画布相对坐标,不依赖数据坐标系。 - # 计算指北针位置(左上角) - arrow_x = minx + (maxx - minx) * 0.1 - arrow_y = maxy - (maxy - miny) * 0.1 + 使用 ax.transAxes 将指北针固定在右上角, + 尺寸以点数(points)为单位,与数据坐标系解耦, + 无论 UTM 坐标范围多大,指北针始终保持合理大小。 + """ + # ★★★ 改用画布相对坐标(transAxes)★★★ + # (0.88, 0.92) = 右上角,尺寸用 points(72分之一英寸) + arrow_ax_x, arrow_ax_y = 0.88, 0.92 + radius_pt = 28 # 罗盘半径(磅),固定大小 - # 缩小指北针尺寸 - size_factor = (maxy - miny) * 0.04 # 缩小尺寸 - radius = size_factor * 1.0 # 罗盘半径 + # 统一在数据坐标系下绘制(transform=ax.transData) + # 但 position 由 axes 坐标决定,radius 用固定点数 + # 将 axes 坐标转为数据坐标:取右上角 + 偏移 + xlim = ax.get_xlim() + ylim = ax.get_ylim() + dx = (xlim[1] - xlim[0]) * 0.08 + dy = (ylim[1] - ylim[0]) * 0.08 + arrow_x = xlim[1] - dx + arrow_y = ylim[1] - dy + # radius 转为数据坐标单位(近似) + radius = min(dx, dy) * 0.6 # 绘制圆形背景(外圈) circle_outer = patches.Circle( @@ -1365,7 +1440,8 @@ class ContentMapper: facecolor='white', edgecolor='black', linewidth=2.5, - zorder=10 + zorder=10, + transform=ax.transData, ) ax.add_patch(circle_outer) @@ -1377,12 +1453,12 @@ class ContentMapper: edgecolor='gray', linewidth=1.5, linestyle='--', - zorder=11 + zorder=11, + transform=ax.transData, ) ax.add_patch(circle_inner) # 绘制四个方向的刻度线 - tick_length = radius * 0.3 tick_width = 1.5 # 北方向刻度(主刻度) @@ -1421,7 +1497,8 @@ class ContentMapper: facecolor='black', edgecolor='black', linewidth=2, - zorder=13 + zorder=13, + transform=ax.transData, ) ax.add_patch(arrow_poly) @@ -1437,13 +1514,14 @@ class ContentMapper: facecolor='white', edgecolor='black', linewidth=1.5, - zorder=13 + zorder=13, + transform=ax.transData, ) ax.add_patch(south_arrow_poly) # 添加方向标记(N, S, E, W) label_offset = radius * 1.15 - font_size = 16 * 0.5 # 缩小字体到原来的一半 + font_size = 9 ax.text(arrow_x, arrow_y + label_offset, 'N', fontsize=font_size, fontweight='bold', ha='center', va='bottom', @@ -1461,13 +1539,34 @@ class ContentMapper: fontsize=font_size * 0.8, fontweight='bold', ha='right', va='center', color='black', zorder=14) - def add_scale_bar(self, ax): - """添加比例尺""" + def add_scale_bar(self, ax, scale_x=None, scale_y=None): + """添加比例尺 + + Parameters + ---------- + ax : matplotlib Axes + 绘图坐标轴 + scale_x : float, optional + X 方向像素分辨率(米),由 visualize_raster 从 src.res 传入。 + 若传入则直接作为 ScaleBar 的 scale 值,忽略 self.output_crs 判断。 + scale_y : float, optional + Y 方向像素分辨率(米),同 scale_x。 + """ try: - if self.output_crs == 'EPSG:4326': - # 地理坐标系,需要指定度数与距离的换算关系 - # 在地球表面,1度约等于111公里(在赤道附近) - # 使用deg作为单位,matplotlib-scalebar会自动处理 + if scale_x is not None and scale_y is not None: + # visualize_raster 传入真实像素分辨率,直接用米为单位 + scalebar = ScaleBar( + scale_x, + units='m', + location='lower left', + box_alpha=0.8, + color='black', + font_properties={'size': 10}, + label_loc='bottom', + ) + ax.add_artist(scalebar) + print(f"比例尺添加成功(像素分辨率: {scale_x:.4f} m)") + elif self.output_crs == 'EPSG:4326': scalebar = ScaleBar( 111000, # 1度 = 111000米 units='m', @@ -1480,7 +1579,6 @@ class ContentMapper: ax.add_artist(scalebar) print("地理坐标系比例尺添加成功") else: - # 投影坐标系,使用米作为单位 scalebar = ScaleBar(1, units='m', location='lower left', box_alpha=0.8, color='black', font_properties={'size': 10}) @@ -1934,6 +2032,351 @@ class ContentMapper: ax.legend(handles=legend_elements, loc='upper left', framealpha=0.9, fontsize=10) + # ------------------------------------------------------------------ + # Step 14 适配:水色指数 GeoTIFF 可视化(绕过 CSV 插值) + # ------------------------------------------------------------------ + + def visualize_raster( + self, + raster_tif_path: str, + output_file: Optional[str] = None, + boundary_shp_path: Optional[str] = None, + cmap: Optional[str] = None, + nodata_value: float = -9999.0, + show_colorbar: bool = True, + figsize: Tuple[int, int] = (12, 10), + title: Optional[str] = None, + alpha: float = 0.9, + ) -> str: + """直接读取 GeoTIFF 栅格数据,生成水质指数专题图。 + + 适用场景: + - WaterIndexProcessor 输出的水色指数 GeoTIFF + - Step 14 接收 GeoTIFF 路径后直接可视化(不通过 CSV 插值) + + Parameters + ---------- + raster_tif_path : str + 水色指数 GeoTIFF 文件路径(由 WaterIndexProcessor 输出) + output_file : str, optional + 输出图片路径(None → 自动从 GeoTIFF 文件名派生) + boundary_shp_path : str, optional + 边界 shapefile 路径(None → 纯栅格显示,无水域掩膜裁切) + cmap : str, optional + 颜色映射(None → 自动从 GeoTIFF 描述或文件名推断) + nodata_value : float + NoData 标记值(GeoTIFF 中存储的无效值) + show_colorbar : bool + 是否显示颜色条 + figsize : tuple + 图形尺寸(英寸) + title : str, optional + 图形标题(None → 从 GeoTIFF 描述推断或使用文件名) + alpha : float + 透明度(0-1) + + Returns + ------- + str + 输出图片路径 + """ + # ── 输出路径自动派生 ────────────────────────────────────────── + if output_file is None: + stem = Path(raster_tif_path).stem + out_dir = Path(raster_tif_path).parent / 'visualization' + out_dir.mkdir(parents=True, exist_ok=True) + output_file = str(out_dir / f"{stem}_map.png") + + # ── 读取 GeoTIFF(优先 rasterio,备选 GDAL)────────────────── + tif_path = Path(raster_tif_path) + if not tif_path.is_file(): + raise FileNotFoundError(f"GeoTIFF 文件不存在: {raster_tif_path}") + + array: Optional[np.ndarray] = None + transform: Optional[Any] = None + crs_obj: Optional[Any] = None + nodata_read: Optional[float] = None + desc: str = "" + + # 方式1:rasterio(推荐,坐标系信息更完整) + _src_bounds = None # rasterio 原生边界(优先用于 extent) + _src_res = None # rasterio 像素分辨率 (xres, yres) + try: + with rasterio.open(raster_tif_path) as src: + array = src.read(1).astype(np.float64) + transform = src.transform + crs_obj = src.crs + nodata_read = src.nodata + desc = src.descriptions[0] if src.descriptions else "" + + # 保存原生边界和分辨率,供后续 extent/scale_bar 使用 + _src_bounds = src.bounds # left, bottom, right, top + _src_res = src.res # (xres, yres) + + # 替换 NoData 为 NaN(用于绘图) + nd = nodata_read if nodata_read is not None else nodata_value + if nd is not None: + array = np.where(array == nd, np.nan, array) + else: + array = np.where(np.isnan(array), np.nan, array) + + print(f"[visualize_raster] rasterio 读取成功: {raster_tif_path}") + use_rasterio = True + except Exception as rio_err: + print(f"[visualize_raster] rasterio 失败 ({rio_err}),回退到 GDAL") + use_rasterio = False + + # 方式2:GDAL(备选) + if array is None: + try: + ds = gdal.Open(raster_tif_path, gdal.GA_ReadOnly) + if ds is None: + raise RuntimeError("GDAL 无法打开文件") + + array = ds.GetRasterBand(1).ReadAsArray().astype(np.float64) + gt = ds.GetGeoTransform() + proj = ds.GetProjection() + nodata_read = ds.GetRasterBand(1).GetNoDataValue() + desc = ds.GetDescription() or "" + + if nodata_read is not None: + array = np.where(array == nodata_read, np.nan, array) + else: + array = np.where(np.isnan(array), np.nan, array) + + # 从 GeoTransform 构造仿射变换(用于计算 extent) + if gt and gt != (0, 1, 0, 0, 0, 1): + if Affine is not None: + transform = Affine(gt[1], gt[2], gt[0], + gt[4], gt[5], gt[3]) + else: + transform = None + # ★★★ 关键:从 GeoTransform 计算 bounds 和 res ★★★ + # gt = (xmin, xres, 0, ymax, 0, yres) + xmin_gdal = gt[0] + ymax_gdal = gt[3] + xres_gdal = gt[1] + yres_gdal = gt[5] + width_gdal = ds.RasterXSize + height_gdal = ds.RasterYSize + xmax_gdal = xmin_gdal + width_gdal * xres_gdal + ymin_gdal = ymax_gdal + height_gdal * yres_gdal + _src_bounds = rasterio.coords.BoundingBox(xmin_gdal, ymin_gdal, xmax_gdal, ymax_gdal) + _src_res = (abs(xres_gdal), abs(yres_gdal)) + else: + transform = None + ds = None + + except Exception as gdal_err: + raise RuntimeError( + f"无法读取 GeoTIFF(rasterio 和 GDAL 均失败): {gdal_err}" + ) + + # ── 宽高变量(供 extent 计算和 figsize 保护使用)───────────── + w, h = array.shape[1], array.shape[0] + # 保存原始宽高:transform 回退分支需用原始尺寸计算 extent + w_orig, h_orig = w, h + + # ── 极速降采样:>400 万像元时,将矩阵降维至约 200 万像素 ───────── + # extent 使用原始 bounds(与降采样无关),保证坐标轴 UTM 米精确 + # 降采样切片仅影响绘图渲染,可将 1 亿像素图在 1 秒内降至 ~200 万像素 + _MAX_VIZ_PIXELS = 4_000_000 + if array.size > _MAX_VIZ_PIXELS: + step = int(np.ceil(np.sqrt(array.size / _MAX_VIZ_PIXELS))) + array = array[::step, ::step] + w_downsampled, h_downsampled = array.shape[1], array.shape[0] + print(f"[visualize_raster] 极速降采样: {w}×{h} → {w_downsampled}×{h_downsampled} " + f"(step={step}),节省内存并加速渲染") + w, h = w_downsampled, h_downsampled + + # ── 全面 NoData 清洗:-9999.0 / NaN / Inf → 统一转为 np.nan ── + # 这一步确保陆地像素(无论来自掩膜还是原始 NoData)均被清除, + # 使 nanpercentile 分位数拉伸 100% 精准锁定水体内部 + array = np.where( + (array == nodata_value) | np.isnan(array) | np.isinf(array), + np.nan, + array + ) + + # ── 从描述推断参数名和 colormap ─────────────────────────────── + # 描述格式:Formula_Name|Category|Formula_Type|Formula + param_name: Optional[str] = None + if desc and '|' in desc: + parts = desc.split('|') + param_name = parts[0].strip() + if len(parts) >= 2: + category = parts[1].strip() + if not cmap: + cmap = self._get_colormap(category) + elif not cmap: + # 从文件名推断 + stem = tif_path.stem + param_name = self._extract_param_name(str(tif_path)) + cmap = self._get_colormap(param_name) + + # ── 计算空间范围(extent)────────────────────────────────────── + # 优先使用 rasterio 原生 bounds,保证坐标轴为真实 UTM 米 + # GDAL 回退使用 GeoTransform 计算 + if _src_bounds is not None: + extent = [ + _src_bounds.left, # xmin + _src_bounds.right, # xmax + _src_bounds.bottom, # ymin + _src_bounds.top, # ymax + ] + # 从 bounds 推导分辨率(取绝对值,正数用于比例尺) + scale_x = abs(_src_res[0]) if _src_res else 1.0 + scale_y = abs(_src_res[1]) if _src_res else 1.0 + elif transform is not None: + xmin = transform.c + ymax = transform.f + xres = transform.a + yres = transform.e + # ★★★ 必须用原始宽高(w_orig/h_orig)而非降采样后的 w/h ★★★ + extent = [xmin, xmin + w_orig * xres, ymax + h_orig * yres, ymax] + scale_x = abs(xres) + scale_y = abs(yres) + else: + # 回退到像素索引范围(使用原始尺寸) + extent = [0, w_orig, 0, h_orig] + scale_x = 1.0 + scale_y = 1.0 + + # ── 准备图形 ───────────────────────────────────────────────── + # 画布大小保护:超大图像(如 40000×40000 px)在 DPI=300 输出时会导致 + # MemoryError;限制每维最大 100 英寸,防止内存爆炸 + _max_inch = 100 + safe_w = min(w / 100, _max_inch) # 像素 / 100 = 英寸,向上封顶 + safe_h = min(h / 100, _max_inch) + safe_figsize = (safe_w, safe_h) + fig, ax = plt.subplots(figsize=safe_figsize) + + # 计算有效值统计(使用 nanpercentile 精准锁定水体内部,排除陆地 NoData 干扰) + valid = array[~np.isnan(array)] + if valid.size == 0: + raise ValueError("GeoTIFF 中没有有效数据(全部为 NoData)") + + vmin = float(np.nanpercentile(array, 2)) + vmax = float(np.nanpercentile(array, 98)) + data_range = vmax - vmin + + if data_range < 1e-9: + center = float(np.nanmean(array)) + exp = max(abs(center) * 0.01, 1e-9) + vmin = center - exp + vmax = center + exp + + print(f"[visualize_raster] 分位数拉伸: P2={vmin:.4f}, P98={vmax:.4f}," + f"有效像元: {valid.size}/{array.size}") + + # ── 栅格绘图 ───────────────────────────────────────────────── + # 使用 masked array:NaN 区域自动不显示 + masked_data = np.ma.masked_invalid(array) + + try: + # 优先:pcolormesh(矢量输出,平滑颜色过渡) + im = ax.pcolormesh( + extent[0], extent[2], masked_data, + cmap=cmap or 'viridis', + vmin=vmin, vmax=vmax, + alpha=alpha, + shading='gouraud', # 颜色插值,平滑 + ) + except Exception: + # 备选:contourf + x_coords = np.linspace(extent[0], extent[1], w) + y_coords = np.linspace(extent[2], extent[3], h) + xx, yy = np.meshgrid(x_coords, y_coords) + im = ax.contourf( + xx, yy, masked_data, + levels=100, + cmap=cmap or 'viridis', + vmin=vmin, vmax=vmax, + alpha=alpha, + ) + + # ★★★ 锁死绘图视口 ★★★ + # 必须在所有叠加绘图(shp/colorbar/north arrow)之前执行, + # 防止其他元素的坐标干扰导致轴范围被拉伸成像素坐标系 + ax.set_xlim(extent[0], extent[1]) + ax.set_ylim(extent[2], extent[3]) + + # ── 边界 shapefile(叠加水域边界线)────────────────────────── + if boundary_shp_path and os.path.isfile(boundary_shp_path): + try: + boundary_gdf = gpd.read_file(boundary_shp_path) + # 坐标系转换 + if crs_obj is not None: + target_crs = CRS.from_string(self.output_crs) + if boundary_gdf.crs != target_crs: + boundary_gdf = boundary_gdf.to_crs(target_crs) + boundary_gdf.boundary.plot(ax=ax, color='black', linewidth=1.5) + except Exception as e: + print(f"[visualize_raster] 边界 shapefile 叠加失败: {e}") + + # ── 坐标轴标签(固定 UTM 米,无条件覆盖)───────────────────── + ax.set_xlabel('X (UTM Meters)', fontsize=11) + ax.set_ylabel('Y (UTM Meters)', fontsize=11) + + ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.4, color='gray') + ax.set_axisbelow(True) + + # ── 标题 ───────────────────────────────────────────────────── + if title: + ax.set_title(title, fontsize=13, fontweight='bold', pad=10) + elif param_name: + ax.set_title(param_name, fontsize=13, fontweight='bold', pad=10) + + # ── 颜色条 ─────────────────────────────────────────────────── + if show_colorbar and im is not None: + try: + cbar = plt.colorbar(im, ax=ax, shrink=0.55, aspect=35, pad=0.02) + cbar.set_label('Index Value', fontsize=10) + if data_range > 1e-9: + ticks = np.linspace(vmin, vmax, 6) + cbar.set_ticks(ticks) + cbar.set_ticklabels([f'{t:.3f}' for t in ticks]) + print("[visualize_raster] 颜色条添加成功") + except Exception as e: + print(f"[visualize_raster] 颜色条添加失败: {e}") + + # ── 比例尺 ─────────────────────────────────────────────────── + try: + self.add_scale_bar(ax, scale_x=scale_x, scale_y=scale_y) + except Exception as e: + print(f"[visualize_raster] 比例尺添加失败: {e}") + + # ── 指北针 ─────────────────────────────────────────────────── + try: + bounds_arr = np.array(extent) + self.add_north_arrow(ax, bounds_arr) + except Exception as e: + print(f"[visualize_raster] 指北针添加失败: {e}") + + # ── 紧凑布局并保存 ─────────────────────────────────────────── + plt.tight_layout() + + try: + plt.savefig( + output_file, + dpi=300, + bbox_inches='tight', + facecolor='white', + edgecolor='none', + ) + print(f"[visualize_raster] ✅ 专题图已保存: {output_file}") + except Exception as e: + print(f"[visualize_raster] 保存失败: {e}") + raise + + try: + plt.show() + except Exception: + pass + + plt.close(fig) + return output_file + def process_data(self, csv_file, shp_file, output_file='content_map.png', resolution=100, show_sample_points=False, base_map_tif=None, use_distance_diffusion=True, max_diffusion_distance=None,