Compare commits
1 Commits
605ec86108
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 9ce17df28a |
@ -1,336 +0,0 @@
|
||||
# Step1Panel UI 联动逻辑优化说明
|
||||
|
||||
## 📋 修改概览
|
||||
|
||||
本次优化针对 Step1Panel(水域掩膜生成步骤)的 UI 联动逻辑进行了深度重构,主要解决了:
|
||||
1. ✅ 输出掩膜组件与单选按钮的深度绑定
|
||||
2. ✅ 路径显示的斜杠混用问题
|
||||
3. ✅ 底层运行逻辑的兼容性
|
||||
|
||||
---
|
||||
|
||||
## 🎯 核心改进
|
||||
|
||||
### 1. 输出掩膜的显示/隐藏与单选按钮深度绑定
|
||||
|
||||
#### 修改位置:`Step1Panel.update_ui_state()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 输出掩膜在两种模式下都显示,不符合业务逻辑
|
||||
- "使用现有掩膜文件"时不需要指定输出路径
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def update_ui_state(self):
|
||||
"""根据选择的掩膜生成方式更新UI状态(使用显示/隐藏控制)"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
# 动态显示/隐藏组件
|
||||
if use_ndwi:
|
||||
# 使用NDWI模式:隐藏掩膜文件,显示NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(False)
|
||||
self.ndwi_group.setVisible(True)
|
||||
self.output_file.setVisible(True) # 显示输出掩膜路径
|
||||
|
||||
# 当切换到NDWI模式时,如果工作目录已设置,自动填充输出路径
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
self._auto_fill_output_path()
|
||||
else:
|
||||
# 使用现有掩膜模式:显示掩膜文件,隐藏NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(True)
|
||||
self.ndwi_group.setVisible(False)
|
||||
self.output_file.setVisible(False) # 隐藏输出掩膜路径
|
||||
|
||||
# 参考影像在两种模式下都显示
|
||||
self.img_file.setVisible(True)
|
||||
```
|
||||
|
||||
**行为说明**:
|
||||
| 模式 | 掩膜文件 | NDWI参数组 | 输出掩膜 | 参考影像 |
|
||||
|------|---------|-----------|---------|---------|
|
||||
| 使用现有掩膜文件 | ✅ 显示 | ❌ 隐藏 | ❌ 隐藏 | ✅ 显示 |
|
||||
| 使用NDWI自动生成 | ❌ 隐藏 | ✅ 显示 | ✅ 显示 | ✅ 显示 |
|
||||
|
||||
---
|
||||
|
||||
### 2. 修复路径斜杠混用问题
|
||||
|
||||
#### 修改位置 1:`Step1Panel._auto_fill_output_path()` (新增方法)
|
||||
|
||||
**核心改进**:统一使用正斜杠 `/`,避免 Windows 下的 `\` 和 `/` 混用
|
||||
|
||||
```python
|
||||
def _auto_fill_output_path(self):
|
||||
"""
|
||||
自动填充输出掩膜路径(仅在NDWI模式下)
|
||||
确保路径使用正斜杠,避免斜杠混用
|
||||
"""
|
||||
if not hasattr(self, 'work_dir') or not self.work_dir:
|
||||
return
|
||||
|
||||
# 生成输出掩膜的完整路径
|
||||
output_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
os.makedirs(output_dir, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 统一使用正斜杠,避免 \ 和 / 混用
|
||||
default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
```
|
||||
|
||||
**关键技术点**:
|
||||
- 使用 `os.path.join()` 构建路径(适配不同操作系统)
|
||||
- 最终通过 `.replace('\\', '/')` 统一转换为正斜杠
|
||||
- 在界面显示前完成转换,确保用户看到的路径一致
|
||||
|
||||
#### 修改位置 2:`Step1Panel.update_work_directory()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 开机时直接填充路径,不考虑当前选择的模式
|
||||
- 没有斜杠格式化
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def update_work_directory(self, work_dir):
|
||||
"""
|
||||
保存工作目录引用,用于后续自动填充路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
"""
|
||||
if not work_dir:
|
||||
return
|
||||
|
||||
# 保存工作目录引用
|
||||
self.work_dir = work_dir
|
||||
|
||||
# 如果当前选中的是NDWI模式,立即填充输出路径
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
self._auto_fill_output_path()
|
||||
```
|
||||
|
||||
**改进说明**:
|
||||
- 只保存工作目录引用,不立即填充
|
||||
- 仅在 NDWI 模式下才调用 `_auto_fill_output_path()`
|
||||
- 配合 `update_ui_state()` 中的切换触发逻辑
|
||||
|
||||
---
|
||||
|
||||
### 3. 底层运行逻辑的兼容性保障
|
||||
|
||||
#### 修改位置:`Step1Panel.get_config()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 无论哪种模式,都传递 `output_path` 给底层 Pipeline
|
||||
- "使用现有掩膜"模式下传递空路径可能导致底层错误
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
config = {
|
||||
'mask_path': None if use_ndwi else self.mask_file.get_path(),
|
||||
'use_ndwi': use_ndwi,
|
||||
'ndwi_threshold': self.ndwi_threshold.value()
|
||||
}
|
||||
|
||||
# 参考影像路径(两种模式都可能需要)
|
||||
img_path = self.img_file.get_path()
|
||||
if img_path:
|
||||
config['img_path'] = img_path
|
||||
|
||||
# 输出路径:仅在NDWI模式下有效
|
||||
if use_ndwi:
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
else:
|
||||
# 使用现有掩膜时,不传递output_path,避免底层错误尝试保存文件
|
||||
config['output_path'] = None
|
||||
|
||||
return config
|
||||
```
|
||||
|
||||
**关键改进**:
|
||||
- 根据 `use_ndwi` 模式动态决定是否传递 `output_path`
|
||||
- "使用现有掩膜"模式:强制 `output_path = None`
|
||||
- "NDWI自动生成"模式:传递用户选择的路径
|
||||
|
||||
**底层兼容性**:
|
||||
```python
|
||||
# Pipeline 中的处理逻辑(已在之前的提交中实现)
|
||||
def step1_generate_water_mask(..., output_path: Optional[str] = None):
|
||||
if use_ndwi:
|
||||
if output_path:
|
||||
ndwi_output_path = output_path # 使用用户指定路径
|
||||
else:
|
||||
ndwi_output_path = str(self.water_mask_dir / "water_mask_from_ndwi.dat")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 主窗口初始化逻辑优化
|
||||
|
||||
#### 修改位置:`WaterQualityGUI._auto_fill_output_paths()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 开机时直接调用 `set_path()` 填充输出掩膜路径
|
||||
- 不考虑当前的单选按钮状态
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def _auto_fill_output_paths(self):
|
||||
"""
|
||||
根据工作目录自动填充各步骤的输出路径
|
||||
注意:Step1 的输出路径由 update_work_directory() 根据模式自动控制
|
||||
"""
|
||||
if not self.work_dir:
|
||||
return
|
||||
|
||||
# Step1: 只传递工作目录引用,不直接填充路径
|
||||
# 路径填充由 Step1Panel 根据单选按钮状态自动控制
|
||||
if hasattr(self, 'step1_panel'):
|
||||
self.step1_panel.update_work_directory(self.work_dir)
|
||||
```
|
||||
|
||||
**改进说明**:
|
||||
- 主窗口只传递工作目录引用
|
||||
- Step1Panel 内部根据模式自主决定是否填充路径
|
||||
- 解耦主窗口和子面板的逻辑依赖
|
||||
|
||||
---
|
||||
|
||||
## 🔄 完整的交互流程
|
||||
|
||||
### 场景 1:开机启动(默认选中"使用现有掩膜文件")
|
||||
|
||||
```
|
||||
1. 主窗口启动 → QTimer.singleShot(100) 延迟弹窗
|
||||
2. 用户选择工作目录 D:\work
|
||||
3. _auto_fill_output_paths() 调用 step1_panel.update_work_directory(work_dir)
|
||||
4. Step1Panel.update_work_directory() 保存 self.work_dir = "D:\work"
|
||||
5. 检查当前模式:use_existing_radio.isChecked() = True
|
||||
6. 不调用 _auto_fill_output_path(),输出掩膜保持隐藏
|
||||
7. 用户看到的界面:
|
||||
✅ 掩膜文件输入框(显示)
|
||||
✅ 参考影像输入框(显示)
|
||||
❌ NDWI参数组(隐藏)
|
||||
❌ 输出掩膜输入框(隐藏)
|
||||
```
|
||||
|
||||
### 场景 2:用户切换到"使用NDWI自动生成"
|
||||
|
||||
```
|
||||
1. 用户点击"使用NDWI自动生成"单选按钮
|
||||
2. 触发 toggled 信号 → update_ui_state()
|
||||
3. use_ndwi = True
|
||||
4. 执行显示/隐藏逻辑:
|
||||
- self.mask_file.setVisible(False) # 隐藏掩膜文件
|
||||
- self.ndwi_group.setVisible(True) # 显示NDWI参数组
|
||||
- self.output_file.setVisible(True) # 显示输出掩膜
|
||||
5. 检查工作目录:hasattr(self, 'work_dir') = True
|
||||
6. 调用 self._auto_fill_output_path()
|
||||
7. 生成路径:
|
||||
output_dir = os.path.join("D:\work", "1_water_mask")
|
||||
path = os.path.join(output_dir, "water_mask_out.dat")
|
||||
formatted_path = path.replace('\\', '/')
|
||||
# 结果:D:/work/1_water_mask/water_mask_out.dat
|
||||
8. 自动填充到输出掩膜输入框
|
||||
9. 用户看到的界面:
|
||||
❌ 掩膜文件输入框(隐藏)
|
||||
✅ 参考影像输入框(显示)
|
||||
✅ NDWI参数组(显示)
|
||||
✅ 输出掩膜输入框(显示,已填充:D:/work/1_water_mask/water_mask_out.dat)
|
||||
```
|
||||
|
||||
### 场景 3:用户点击"独立运行此步骤"
|
||||
|
||||
#### 当前选择:"使用现有掩膜文件"
|
||||
```python
|
||||
config = {
|
||||
'mask_path': "D:/data/existing_mask.dat", # 用户选择的现有掩膜
|
||||
'use_ndwi': False,
|
||||
'ndwi_threshold': 0.4,
|
||||
'img_path': "D:/data/image.dat",
|
||||
'output_path': None # ✅ 强制为 None,不尝试保存
|
||||
}
|
||||
```
|
||||
|
||||
#### 当前选择:"使用NDWI自动生成"
|
||||
```python
|
||||
config = {
|
||||
'mask_path': None, # NDWI模式不需要现有掩膜
|
||||
'use_ndwi': True,
|
||||
'ndwi_threshold': 0.4,
|
||||
'img_path': "D:/data/image.dat",
|
||||
'output_path': "D:/work/1_water_mask/water_mask_out.dat" # ✅ 传递输出路径
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 测试检查点
|
||||
|
||||
### UI 显示测试
|
||||
- [ ] 开机启动后,默认选中"使用现有掩膜文件",输出掩膜输入框应隐藏
|
||||
- [ ] 切换到"使用NDWI自动生成",输出掩膜输入框应显示,并自动填充路径
|
||||
- [ ] 切换回"使用现有掩膜文件",输出掩膜输入框应再次隐藏
|
||||
- [ ] 所有自动填充的路径应使用正斜杠 `/`,无 `\` 混用
|
||||
|
||||
### 路径格式测试
|
||||
- [ ] 工作目录:`D:\work` → 输出路径应显示为:`D:/work/1_water_mask/water_mask_out.dat`
|
||||
- [ ] 工作目录:`C:\Users\Test\Documents` → 输出路径应显示为:`C:/Users/Test/Documents/1_water_mask/water_mask_out.dat`
|
||||
|
||||
### 运行逻辑测试
|
||||
- [ ] "使用现有掩膜"模式运行:验证 `config['output_path'] == None`
|
||||
- [ ] "NDWI自动生成"模式运行:验证 `config['output_path']` 为有效路径字符串
|
||||
- [ ] 底层 Pipeline 接收 `output_path=None` 时不报错
|
||||
|
||||
---
|
||||
|
||||
## 📝 代码修改总结
|
||||
|
||||
| 文件 | 修改内容 | 行数变化 |
|
||||
|------|---------|---------|
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.update_ui_state() | +6 / -3 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.update_work_directory() | +10 / -8 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel._auto_fill_output_path() (新增) | +15 / 0 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.get_config() | +12 / -6 |
|
||||
| `src/gui/water_quality_gui.py` | WaterQualityGUI._auto_fill_output_paths() | +3 / -4 |
|
||||
|
||||
**总计**:约 **+46 / -21** 行
|
||||
|
||||
---
|
||||
|
||||
## 🎯 优化效果
|
||||
|
||||
### 用户体验提升
|
||||
1. ✅ UI 更简洁:不需要的组件自动隐藏
|
||||
2. ✅ 路径一致性:所有路径显示统一使用正斜杠
|
||||
3. ✅ 自动化程度提高:切换模式时自动填充/清空路径
|
||||
|
||||
### 代码质量提升
|
||||
1. ✅ 职责分离:主窗口不直接操作子面板的路径填充
|
||||
2. ✅ 逻辑内聚:Step1Panel 内部自主管理显示和路径
|
||||
3. ✅ 兼容性保障:底层 Pipeline 不会收到无效的 output_path
|
||||
|
||||
### 维护性提升
|
||||
1. ✅ 新增 `_auto_fill_output_path()` 方法,单一职责
|
||||
2. ✅ 路径格式化逻辑集中在一处,便于修改
|
||||
3. ✅ 注释清晰,说明了各模式下的行为
|
||||
|
||||
---
|
||||
|
||||
## 🔧 后续可能的扩展
|
||||
|
||||
如果其他步骤也有类似的路径填充需求,可以考虑:
|
||||
1. 提取公共方法 `format_path_separator(path)` 到工具类
|
||||
2. 在 `FileSelectWidget` 类中增加路径格式化的内置支持
|
||||
3. 为所有路径输入框添加统一的验证和格式化逻辑
|
||||
|
||||
---
|
||||
|
||||
**文档生成时间**: 2026-05-06
|
||||
**修改人员**: DXC
|
||||
**关联提交**: (待提交)
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 30 KiB |
@ -1,26 +1,58 @@
|
||||
# 水质参数反演分析系统 - Python 依赖
|
||||
# 安装: pip install -r requirements.txt
|
||||
#
|
||||
# 说明:
|
||||
# - Windows 下 GDAL 若 pip 安装失败,建议用 conda-forge: conda install -c conda-forge gdal
|
||||
# 或使用已编译的 GDAL wheel / OSGeo4W,并保证与 rasterio 版本匹配。
|
||||
# - Word 报告(report_word)与 GUI「报告生成」页依赖 python-docx;AI 解读走 Ollama HTTP API,
|
||||
# 无需额外 pip 包(本地或远程部署 Ollama 即可)。
|
||||
|
||||
# ---------- GUI ----------
|
||||
PyQt5>=5.15.0
|
||||
|
||||
# ---------- 科学计算 ----------
|
||||
# 注:当前工程打包/运行日志显示使用 Python 3.12,因此下限按 Py3.12 兼容版本设置
|
||||
numpy>=1.26.0
|
||||
scipy>=1.11.0
|
||||
pandas>=2.0.0
|
||||
|
||||
# ---------- 机器学习 ----------
|
||||
scikit-learn>=1.4.0
|
||||
# xgboost>=2.0.0 # 可选;仅在环境已安装时 spec 会自动打入
|
||||
# lightgbm>=4.0.0 # 可选;当前流水线默认未启用
|
||||
|
||||
# ---------- 地理空间 ----------
|
||||
rasterio>=1.3.9
|
||||
fiona>=1.9.5
|
||||
shapely>=2.0.0
|
||||
geopandas>=0.14.0
|
||||
pyproj>=3.6.0
|
||||
spectral>=0.22.0
|
||||
|
||||
# ---------- 图像 ----------
|
||||
opencv-python>=4.5.0
|
||||
Pillow>=8.0.0
|
||||
scikit-image>=0.22.0
|
||||
|
||||
# ---------- 可视化 ----------
|
||||
matplotlib>=3.8.0
|
||||
seaborn>=0.11.0
|
||||
matplotlib-scalebar>=0.8.0
|
||||
|
||||
# ---------- 信号处理 ----------
|
||||
PyWavelets>=1.1.0
|
||||
|
||||
# ---------- 通用工具 ----------
|
||||
joblib>=1.1.0
|
||||
tqdm>=4.62.0
|
||||
PyYAML>=6.0
|
||||
|
||||
# ---------- 表格导出(.xlsx)----------
|
||||
openpyxl>=3.0.0
|
||||
|
||||
# ---------- Word 报告生成 ----------
|
||||
python-docx>=1.1.0
|
||||
lxml>=4.9.0
|
||||
|
||||
# ---------- 打包(可选,仅构建 exe 时需要)----------
|
||||
pyinstaller>=6.0.0
|
||||
pykrige>=1.7.3
|
||||
@ -1,36 +0,0 @@
|
||||
"""
|
||||
算法层模块
|
||||
包含插值算法和耀斑检测算法等核心数学计算
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 插值
|
||||
'interpolate_pixels',
|
||||
'interpolate_zero_pixels_batch',
|
||||
# 耀斑检测
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
@ -1,31 +0,0 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
包含各种耀斑检测的核心数学计算函数
|
||||
"""
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
@ -1,595 +0,0 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
|
||||
包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。
|
||||
支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
|
||||
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
|
||||
def timeit(func):
|
||||
"""装饰器:测量函数执行时间"""
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
import time
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
print(f"[{func.__name__}] 耗时: {end - start:.2f}s")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数拉伸
|
||||
# =============================================================================
|
||||
|
||||
def percentile_stretch(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
lower_percentile: float = 2,
|
||||
upper_percentile: float = 98,
|
||||
output_range: Tuple[int, int] = (0, 255)
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用百分位数裁剪进行归一化,适用于低反射率数据
|
||||
通过排除极值,更好地利用数据的动态范围
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(反射率值,通常在0-1之间)
|
||||
data_water_mask: 水域掩膜
|
||||
lower_percentile: 下百分位数,用于裁剪最小值(默认2)
|
||||
upper_percentile: 上百分位数,用于裁剪最大值(默认98)
|
||||
output_range: 输出范围,默认(0, 255)
|
||||
|
||||
Returns:
|
||||
归一化后的图像数组(整数类型)
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return img.astype(np.int32)
|
||||
|
||||
p_lower = np.percentile(valid_pixels, lower_percentile)
|
||||
p_upper = np.percentile(valid_pixels, upper_percentile)
|
||||
|
||||
if p_lower >= p_upper:
|
||||
p_lower = np.percentile(valid_pixels, 1)
|
||||
p_upper = np.percentile(valid_pixels, 99)
|
||||
if p_lower >= p_upper:
|
||||
p_upper = valid_pixels.max()
|
||||
p_lower = valid_pixels.min()
|
||||
|
||||
img_clipped = np.clip(img, p_lower, p_upper)
|
||||
|
||||
if p_upper > p_lower:
|
||||
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
|
||||
output_range[1] - output_range[0]
|
||||
) + output_range[0]
|
||||
else:
|
||||
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
|
||||
|
||||
return img_stretched.astype(np.int32)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Otsu阈值分割
|
||||
# =============================================================================
|
||||
|
||||
def otsu_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
ignore_value: int = 0,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Otsu方法的自动阈值分割
|
||||
通过最大化类间方差找到最佳分割阈值
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(整数值)
|
||||
data_water_mask: 水域掩膜
|
||||
ignore_value: 忽略的值(默认为0)
|
||||
foreground: 耀斑区域值(默认1)
|
||||
background: 背景值(默认0)
|
||||
|
||||
Returns:
|
||||
二值化检测结果数组
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
max_value = int(np.max(img[img > ignore_value])) + 1
|
||||
if max_value < 2:
|
||||
max_value = 256
|
||||
|
||||
hist = np.zeros([max_value], np.float32)
|
||||
|
||||
invalid_counter = 0
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
|
||||
invalid_counter += 1
|
||||
continue
|
||||
hist[img[i, j]] += 1
|
||||
|
||||
total_valid = height * width - invalid_counter
|
||||
if total_valid <= 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
hist /= total_valid
|
||||
|
||||
threshold = 0
|
||||
deltaMax = 0
|
||||
|
||||
for i in range(max_value):
|
||||
wA = sum(hist[:i + 1])
|
||||
wB = sum(hist[i + 1:])
|
||||
if wA == 0:
|
||||
wA = 1e-10
|
||||
if wB == 0:
|
||||
wB = 1e-10
|
||||
|
||||
uAtmp = sum(j * hist[j] for j in range(i + 1))
|
||||
uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
|
||||
uA = uAtmp / wA
|
||||
uB = uBtmp / wB
|
||||
u = uAtmp + uBtmp
|
||||
|
||||
deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
|
||||
if deltaTmp > deltaMax:
|
||||
deltaMax = deltaTmp
|
||||
threshold = i
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z-score阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def zscore_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
z_threshold: float = 2.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Z-score(标准化分数)的耀斑检测方法
|
||||
使用统计方法识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
mean_val = np.mean(valid_pixels)
|
||||
std_val = np.std(valid_pixels)
|
||||
|
||||
if std_val == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
z_scores = np.zeros_like(img, dtype=np.float32)
|
||||
valid_mask = (data_water_mask > 0) & np.isfinite(img)
|
||||
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[z_scores > z_threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def percentile_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
percentile: float = 95,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于百分位数的耀斑检测方法
|
||||
使用百分位数作为阈值,对异常值更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
percentile: 百分位数阈值,默认95(即超过95%的像素值)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
threshold_val = np.percentile(valid_pixels, percentile)
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold_val] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IQR异常值检测
|
||||
# =============================================================================
|
||||
|
||||
def iqr_outlier_detection(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
iqr_multiplier: float = 1.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于IQR(四分位距)的异常值检测方法
|
||||
使用四分位距识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
iqr_multiplier: IQR倍数,默认1.5(标准异常值检测)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
q1 = np.percentile(valid_pixels, 25)
|
||||
q3 = np.percentile(valid_pixels, 75)
|
||||
iqr = q3 - q1
|
||||
|
||||
upper_bound = q3 + iqr_multiplier * iqr
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > upper_bound] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 自适应阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def adaptive_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
window_size: int = 15,
|
||||
percentile: float = 90,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
自适应阈值方法
|
||||
基于局部统计特性进行阈值分割,对光照变化更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
window_size: 局部窗口大小(奇数)
|
||||
percentile: 局部百分位数阈值
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
if window_size % 2 == 0:
|
||||
window_size += 1
|
||||
|
||||
half_window = window_size // 2
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
for i in range(half_window, height - half_window):
|
||||
for j in range(half_window, width - half_window):
|
||||
if data_water_mask[i, j] == 0:
|
||||
continue
|
||||
|
||||
local_window = img[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
local_mask = data_water_mask[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
|
||||
valid_pixels = local_window[local_mask > 0]
|
||||
|
||||
if len(valid_pixels) > 0:
|
||||
local_th = np.percentile(valid_pixels, percentile)
|
||||
if img[i, j] > local_th:
|
||||
det_img[i, j] = foreground
|
||||
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 多波段融合耀斑检测
|
||||
# =============================================================================
|
||||
|
||||
def multi_band_glint_detection(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
glint_waves: List[float],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'zscore',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
sub_band_arrays: Optional[List[np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
多波段融合的耀斑检测方法
|
||||
结合多个波段的耀斑特征,提高检测的稳健性
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组(主波段,用于兼容性)
|
||||
water_mask: 水域掩膜数组
|
||||
glint_waves: 用于检测的波长列表,如[750, 800, 850]
|
||||
weights: 各波段的权重,如果为None则使用等权重
|
||||
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
|
||||
z_threshold: Z-score阈值(当method='zscore'时使用)
|
||||
percentile: 百分位数阈值(当method='percentile'时使用)
|
||||
sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应)
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
if weights is None:
|
||||
weights = [1.0 / len(glint_waves)] * len(glint_waves)
|
||||
|
||||
if len(weights) != len(glint_waves):
|
||||
raise ValueError("权重数量必须与波长数量相同")
|
||||
|
||||
fused_band = None
|
||||
|
||||
if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
|
||||
for i, band_array in enumerate(sub_band_arrays):
|
||||
if fused_band is None:
|
||||
fused_band = (band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = nir_band.astype(np.float32)
|
||||
|
||||
if method == 'otsu':
|
||||
stretched = percentile_stretch(fused_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(fused_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(fused_band, water_mask, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 连通域过滤
|
||||
# =============================================================================
|
||||
|
||||
def filter_large_components(
|
||||
binary_img: np.ndarray,
|
||||
max_area: Optional[int] = None,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
过滤掉面积超过阈值的连通域
|
||||
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
|
||||
|
||||
Args:
|
||||
binary_img: 二值化图像
|
||||
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
过滤后的二值化图像
|
||||
"""
|
||||
if max_area is None or max_area <= 0:
|
||||
return binary_img
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
binary_for_label = (binary_img == foreground).astype(np.uint8)
|
||||
num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
|
||||
binary_for_label, connectivity=8
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = stats[1:, cv2.CC_STAT_AREA]
|
||||
keep_labels = np.where(component_sizes <= max_area)[0] + 1
|
||||
|
||||
keep_mask = np.isin(labeled_array, keep_labels)
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
else:
|
||||
from scipy import ndimage
|
||||
labeled_array, num_features = ndimage.label(
|
||||
(binary_img == foreground).astype(np.int32)
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = ndimage.sum(
|
||||
(labeled_array == i).astype(np.int32),
|
||||
labeled_array,
|
||||
range(1, num_features + 1)
|
||||
)
|
||||
|
||||
keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 岸边缓冲区处理
|
||||
# =============================================================================
|
||||
|
||||
def create_shoreline_buffer(
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
创建岸边缓冲区掩膜(向内缓冲)
|
||||
用于去除岸边附近的错误耀斑检测区域
|
||||
|
||||
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
|
||||
|
||||
Args:
|
||||
water_mask: 水域掩膜数组(水域=1,非水域=0)
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
岸边缓冲区掩膜(缓冲区区域=1,其他=0)
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return np.zeros_like(water_mask, dtype=np.int32)
|
||||
|
||||
water_binary = (water_mask > 0).astype(np.uint8)
|
||||
structure_size = buffer_size * 2 + 1
|
||||
structure = np.ones((structure_size, structure_size), dtype=np.uint8)
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
|
||||
else:
|
||||
from scipy import ndimage
|
||||
eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
|
||||
|
||||
buffer_mask = (water_binary - eroded_water).astype(np.int32)
|
||||
|
||||
return buffer_mask
|
||||
|
||||
|
||||
def remove_shoreline_buffer(
|
||||
glint_mask: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
从耀斑掩膜中去除岸边缓冲区内的区域
|
||||
|
||||
Args:
|
||||
glint_mask: 耀斑掩膜数组
|
||||
water_mask: 水域掩膜数组
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
去除岸边缓冲区后的耀斑掩膜
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return glint_mask
|
||||
|
||||
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
|
||||
|
||||
cleaned = glint_mask.copy()
|
||||
cleaned[buffer_mask > 0] = background
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 高级组合函数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_glint_mask(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
method: str = 'otsu',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
apply_percentile_stretch: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算耀斑掩膜的统一入口函数
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组
|
||||
water_mask: 水域掩膜
|
||||
method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
|
||||
z_threshold: Z-score阈值
|
||||
percentile: 百分位数阈值
|
||||
iqr_multiplier: IQR倍数
|
||||
window_size: 自适应阈值窗口大小
|
||||
apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸
|
||||
|
||||
Returns:
|
||||
二值化耀斑掩膜
|
||||
"""
|
||||
if method == 'otsu':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
else:
|
||||
return otsu_threshold(nir_band.astype(np.int32), water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(nir_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(nir_band, water_mask, percentile)
|
||||
elif method == 'iqr':
|
||||
return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
|
||||
elif method == 'adaptive':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return adaptive_threshold(stretched, water_mask, window_size, percentile)
|
||||
else:
|
||||
return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
@ -1,7 +0,0 @@
|
||||
"""
|
||||
插值算法模块
|
||||
包含0值像素插值的核心数学逻辑
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
|
||||
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']
|
||||
@ -1,320 +0,0 @@
|
||||
"""
|
||||
像素插值算法模块
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import griddata, RBFInterpolator
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
interpolation_method: str = 'nearest',
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对指定坐标的像素进行插值(核心数学函数,不涉及文件I/O)
|
||||
|
||||
Args:
|
||||
image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组
|
||||
zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y]
|
||||
valid_coords: 有效像素坐标,形状为 (n_valid, 2)
|
||||
valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands)
|
||||
interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
water_mask: 可选的水域掩膜数组
|
||||
|
||||
Returns:
|
||||
插值后的影像副本,形状与 image_stack 相同
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
# 兼容中文和各种格式的method参数
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
|
||||
method = 'bilinear'
|
||||
elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
|
||||
method = 'spline'
|
||||
elif 'kriging' in raw_method or '克里金' in raw_method:
|
||||
method = 'kriging'
|
||||
else:
|
||||
method = 'nearest'
|
||||
|
||||
if len(valid_values) == 0:
|
||||
return result
|
||||
|
||||
is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
|
||||
|
||||
if is_multiband:
|
||||
for band_idx in range(n_bands):
|
||||
band_valid_values = valid_values[:, band_idx]
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, band_valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords, band_idx] = interpolated_values
|
||||
else:
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords] = interpolated_values
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _interpolate_single_band(
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
method: str
|
||||
) -> np.ndarray:
|
||||
"""对单个波段执行插值计算"""
|
||||
if method == 'nearest':
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords)
|
||||
return valid_values[indices]
|
||||
|
||||
elif method == 'bilinear':
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'spline':
|
||||
try:
|
||||
max_points = 10000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
|
||||
interpolated = rbf(zero_coords)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'kriging':
|
||||
try:
|
||||
from src.utils.kriging import KrigingInterpolator
|
||||
interpolator = KrigingInterpolator()
|
||||
max_points = 5000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
interpolated = griddata(
|
||||
sample_coords, sample_values, zero_coords,
|
||||
method='cubic', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组)
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数
|
||||
|
||||
Returns:
|
||||
(output_path, interpolated_image_stack) 元组
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
||||
|
||||
# 检查文件是否已存在
|
||||
if output_path and Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 读取所有波段数据
|
||||
all_bands = []
|
||||
for band_idx in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(band_idx)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
all_bands.append(band_data)
|
||||
|
||||
image_stack = np.dstack(all_bands)
|
||||
|
||||
# 读取水域掩膜
|
||||
mask_array = None
|
||||
if water_mask is not None:
|
||||
if isinstance(water_mask, str):
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
|
||||
# 找出所有波段都为0的像素点
|
||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
||||
|
||||
if mask_array is not None:
|
||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
||||
|
||||
zero_pixel_count = np.sum(all_bands_zero)
|
||||
if zero_pixel_count == 0:
|
||||
# 无需插值,直接保存
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(all_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
return output_path, image_stack
|
||||
|
||||
# 获取坐标
|
||||
zero_y, zero_x = np.where(all_bands_zero)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_bands_zero
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
raise ValueError("没有有效像素可用于插值")
|
||||
|
||||
# 逐波段插值
|
||||
interpolated_bands = []
|
||||
for band_idx in range(n_bands):
|
||||
if callback_progress:
|
||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
||||
band_data = all_bands[band_idx].copy()
|
||||
valid_values_band = band_data[valid_mask]
|
||||
|
||||
if len(valid_values_band) == 0:
|
||||
interpolated_bands.append(band_data)
|
||||
continue
|
||||
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
||||
)
|
||||
band_data[all_bands_zero] = band_result
|
||||
interpolated_bands.append(band_data)
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(interpolated_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
|
||||
result_stack = np.dstack(interpolated_bands)
|
||||
return output_path, result_stack
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
@ -7,301 +8,283 @@ try:
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
class Hedley:
|
||||
def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
|
||||
output_path=None, block_size=1000):
|
||||
def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
|
||||
"""
|
||||
Hedley 耀斑去除算法 - 分块逐波段处理版本
|
||||
|
||||
:param img_path (str): 输入影像文件路径(GDAL可读取的格式)
|
||||
:param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask
|
||||
:param NIR_band (int): NIR波段索引(默认47,对应842.36nm)
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜
|
||||
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
|
||||
:param block_size (int): 分块大小(默认1000)
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which corresponds closely to the NIR band in Micasense
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
self.img_path = img_path
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.water_mask = None
|
||||
self.water_mask_path = water_mask
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.NIR_band = NIR_band
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
self.block_size = block_size
|
||||
self.R_min = None
|
||||
self.corr_list = None # 全局协方差系数列表
|
||||
|
||||
# 打开影像
|
||||
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if self.dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.n_bands = self.dataset.RasterCount
|
||||
|
||||
def _load_water_mask(self):
|
||||
"""延迟加载水域掩膜"""
|
||||
if self.water_mask_path is None:
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.water_mask_path, np.ndarray):
|
||||
if self.water_mask_path.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (self.water_mask_path > 0).astype(np.uint8)
|
||||
|
||||
if isinstance(self.water_mask_path, str):
|
||||
if self.water_mask_path.lower().endswith('.shp'):
|
||||
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
|
||||
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
return None
|
||||
|
||||
def covariance_NIR(self, NIR, b):
|
||||
"""计算 NIR 与波段 b 之间的协方差系数 b_i = Cov(NIR,b) / Var(NIR)"""
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def covariance_NIR(self,NIR,b):
|
||||
"""
|
||||
NIR & b are vectors
|
||||
reflectance for band i
|
||||
"""
|
||||
n = len(NIR)
|
||||
# 优化:减少重复计算,使用更高效的numpy操作
|
||||
nir_mean = np.mean(NIR)
|
||||
b_mean = np.mean(b)
|
||||
# 使用更高效的协方差计算
|
||||
pij = np.mean((NIR - nir_mean) * (b - b_mean))
|
||||
pjj = np.mean((NIR - nir_mean) ** 2)
|
||||
# 避免除零错误
|
||||
return pij / pjj if pjj != 0 else 0.0
|
||||
|
||||
def _scan_global_stats(self, sample_step=20):
|
||||
|
||||
def correlation_bands_reflectance(self):
|
||||
"""
|
||||
扫描全图获取全局 R_min
|
||||
|
||||
使用重采样方式扫描,大幅降低内存占用。
|
||||
calculate correlation between NIR and other bands for reflectance
|
||||
NIR_band is 750 nm
|
||||
"""
|
||||
print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
nir_samples = []
|
||||
sample_count = 0
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
nir_band = None
|
||||
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 直接使用视图,只在需要时创建扁平数组
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算相关性
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
sample_count += nir_sampled.size
|
||||
|
||||
del nir_block, mask_block
|
||||
|
||||
if sample_count == 0:
|
||||
self.R_min = 0.0
|
||||
else:
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
|
||||
del all_nir
|
||||
|
||||
print(f"[Hedley] 全局 R_min={self.R_min:.4f}")
|
||||
|
||||
def _compute_corr_list(self, sample_step=5):
|
||||
"""
|
||||
计算每个波段与NIR的协方差系数 corr_list[b] = Cov(NIR, band_b) / Var(NIR)
|
||||
|
||||
全分辨率扫描,逐波段读取,每波段内存 ≈ block_size²
|
||||
由于需要相关性计算,需要足够多的样本,取sample_step=5
|
||||
"""
|
||||
print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 预收集NIR和每个波段的样本数据
|
||||
nir_samples = []
|
||||
band_samples = [[] for _ in range(self.n_bands)]
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取NIR波段(每块只读一次)
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 取 NIR 样本(每块只取一次,放在波段循环外)
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
# 只在掩膜内提取数据
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
|
||||
# 逐波段读取并采样(all_band 严格使用单波段切片)
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
band = None
|
||||
|
||||
if mask_bool.any():
|
||||
band_sampled = block[mask_bool][::sample_step]
|
||||
band_samples[b].append(band_sampled)
|
||||
|
||||
del block
|
||||
|
||||
del nir_block
|
||||
|
||||
# 汇总并计算相关系数
|
||||
if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0:
|
||||
self.corr_list = [0.0] * self.n_bands
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
else:
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.corr_list = []
|
||||
for b in range(self.n_bands):
|
||||
all_band = np.concatenate(band_samples[b])
|
||||
corr = self.covariance_NIR(all_nir, all_band)
|
||||
self.corr_list.append(float(corr))
|
||||
|
||||
del all_nir
|
||||
for b in range(self.n_bands):
|
||||
band_samples[b] = None
|
||||
|
||||
print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")
|
||||
|
||||
def _process_block(self, x_off, y_off, x_size, y_size):
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
|
||||
# 优化:一次性计算所有波段的相关性,减少循环开销
|
||||
corr_list = []
|
||||
for v in range(self.n_bands):
|
||||
if mask_bool is not None and mask_bool.any():
|
||||
band_reflectance = im_region[:,:,v][mask_bool]
|
||||
else:
|
||||
band_reflectance = im_region[:,:,v].ravel()
|
||||
corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
|
||||
corr_list.append(corr)
|
||||
|
||||
return corr_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
处理单个分块
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: 校正后的波段列表
|
||||
"""
|
||||
# 读取NIR波段
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 预计算 NIR - R_min
|
||||
NIR_diff = NIR - self.R_min
|
||||
|
||||
# 获取掩膜
|
||||
water_mask = self._load_water_mask()
|
||||
if water_mask is not None:
|
||||
y_end = y_off + y_size
|
||||
x_end = x_off + x_size
|
||||
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
|
||||
else:
|
||||
mask_block = None
|
||||
|
||||
# 逐波段处理
|
||||
corrected_bands = []
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
band = None
|
||||
|
||||
corr = self.corr_list[b]
|
||||
# Hedley 校正公式:R_corrected = R - corr * (NIR - R_min)
|
||||
corrected = R - corr * NIR_diff
|
||||
|
||||
if mask_block is not None:
|
||||
corrected = np.where(mask_block, corrected, R)
|
||||
|
||||
corrected_bands.append(corrected)
|
||||
del R
|
||||
|
||||
del NIR, NIR_diff
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
执行分块处理,返回校正后的波段列表
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
|
||||
|
||||
# Step 1: 扫描全局 R_min
|
||||
self._scan_global_stats(sample_step=20)
|
||||
|
||||
# Step 2: 计算协方差系数列表
|
||||
self._compute_corr_list(sample_step=5)
|
||||
|
||||
# Step 3: 创建输出文件
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
|
||||
|
||||
geotransform = self.dataset.GetGeoTransform()
|
||||
projection = self.dataset.GetProjection()
|
||||
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
out_dataset = driver.Create(bsq_path, self.width, self.height,
|
||||
self.n_bands, gdal.GDT_Float32)
|
||||
if out_dataset is None:
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
# Step 4: 分块处理
|
||||
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
|
||||
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
|
||||
total_blocks = n_blocks_x * n_blocks_y
|
||||
|
||||
print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
|
||||
|
||||
block_idx = 0
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
y_size = y_end - y_off
|
||||
|
||||
for x_off in range(0, self.width, self.block_size):
|
||||
x_end = min(x_off + self.block_size, self.width)
|
||||
x_size = x_end - x_off
|
||||
block_idx += 1
|
||||
|
||||
print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
|
||||
|
||||
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
|
||||
|
||||
for b in range(self.n_bands):
|
||||
out_band = out_dataset.GetRasterBand(b + 1)
|
||||
out_band.WriteArray(corrected_bands[b], x_off, y_off)
|
||||
out_band.FlushCache()
|
||||
|
||||
del corrected_bands
|
||||
|
||||
out_dataset = None
|
||||
self.dataset = None
|
||||
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
else:
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
|
||||
return []
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
correction is done in reflectance
|
||||
|
||||
:return: 校正后的波段列表
|
||||
"""
|
||||
corr = self.correlation_bands_reflectance()
|
||||
NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
|
||||
# 预计算NIR-R_min,避免在循环中重复计算
|
||||
NIR_diff = NIR_reflectance - self.R_min
|
||||
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
|
||||
def __del__(self):
|
||||
if self.dataset is not None:
|
||||
self.dataset = None
|
||||
corrected_bands = []
|
||||
for band_number in range(self.n_bands): #iterate across bands
|
||||
b = corr[band_number]
|
||||
R = self.im_aligned[:,:,band_number]
|
||||
# 优化:减少中间数组创建
|
||||
corrected_band = R - b * NIR_diff
|
||||
|
||||
# 如果存在水域掩膜,只对水域区域应用校正
|
||||
if water_mask_bool is not None:
|
||||
corrected_band = np.where(water_mask_bool, corrected_band, R)
|
||||
|
||||
corrected_bands.append(corrected_band)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
@ -7,333 +8,306 @@ try:
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
class Kutser:
|
||||
def __init__(self, img_path, shp_path=None, oxy_band=38, lower_oxy=36,
|
||||
upper_oxy=49, NIR_band=47, water_mask=None, output_path=None,
|
||||
block_size=1000):
|
||||
def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
|
||||
"""
|
||||
Kutser 耀斑去除算法 - 分块逐波段处理版本
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
|
||||
:param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
|
||||
:param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
|
||||
see Kutser, Vahtmäe and Praks
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.oxy_band = oxy_band
|
||||
self.lower_oxy = lower_oxy
|
||||
self.upper_oxy = upper_oxy
|
||||
self.NIR_band = NIR_band
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Kutser类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def get_depth_D(self):
|
||||
"""
|
||||
Assume the amount of glint is proportional to the depth of the oxygen absorption feature, D
|
||||
returns the normalised D by dividing it by the maximum D found in a deep water region
|
||||
"""
|
||||
# 优化:减少中间数组创建,使用更高效的计算
|
||||
lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
|
||||
upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
|
||||
oxy_band = self.im_aligned[:,:,self.oxy_band]
|
||||
D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
|
||||
|
||||
# 确定用于计算D_max的区域
|
||||
if self.bbox is None:
|
||||
search_region = D
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
search_region = D[y1:y2,x1:x2]
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内搜索最大值
|
||||
if self.water_mask is not None:
|
||||
if self.bbox is None:
|
||||
mask_region = self.water_mask.astype(bool)
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
|
||||
|
||||
if mask_region.any():
|
||||
D_max = search_region[mask_region].max()
|
||||
else:
|
||||
D_max = search_region.max()
|
||||
else:
|
||||
D_max = search_region.max() # assumed to be the maximum glint value
|
||||
|
||||
# 避免除零错误
|
||||
if D_max == 0:
|
||||
return np.zeros_like(D)
|
||||
return D / D_max
|
||||
|
||||
def get_glint_G(self):
|
||||
"""
|
||||
The spectral variation of glint G is found by subtracting the spectrum at the darkest (ie. lowest D) NIR deep-water pixel from the brightest
|
||||
returns G as a function of wavelength
|
||||
"""
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
:param img_path (str): 输入影像文件路径(GDAL可读取的格式)
|
||||
:param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask
|
||||
:param oxy_band (int): 氧吸收波段索引(默认38,对应760.6nm)
|
||||
:param lower_oxy (int): 氧吸收下方波段索引(默认36,对应742.39nm)
|
||||
:param upper_oxy (int): 氧吸收上方波段索引(默认49,对应860.48nm)
|
||||
:param NIR_band (int): NIR波段索引(默认47,对应842.36nm)
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜
|
||||
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
|
||||
:param block_size (int): 分块大小(默认1000)
|
||||
# 如果存在水域掩膜,只在掩膜内计算最大最小值
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 对每个波段,只在掩膜内计算最大最小值
|
||||
G_list = []
|
||||
for i in range(self.n_bands):
|
||||
band_data = im_region[:,:,i]
|
||||
G_max = band_data[mask_bool].max()
|
||||
G_min = band_data[mask_bool].min()
|
||||
G_list.append(G_max - G_min)
|
||||
else:
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
G_max = np.amax(im_region, axis=(0, 1))
|
||||
G_min = np.amin(im_region, axis=(0, 1))
|
||||
G_list = (G_max - G_min).tolist()
|
||||
else:
|
||||
# 优化:一次性计算所有波段的最大最小值,减少循环开销
|
||||
# 使用numpy的amax和amin沿最后一个轴计算
|
||||
G_max = np.amax(im_region, axis=(0, 1)) # 沿空间维度计算最大值
|
||||
G_min = np.amin(im_region, axis=(0, 1)) # 沿空间维度计算最小值
|
||||
G_list = (G_max - G_min).tolist()
|
||||
return G_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
self.img_path = img_path
|
||||
self.oxy_band = int(float(oxy_band))
|
||||
self.lower_oxy = int(float(lower_oxy))
|
||||
self.upper_oxy = int(float(upper_oxy))
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.water_mask = None # 延迟加载,在处理前初始化
|
||||
self.water_mask_path = water_mask
|
||||
self.output_path = output_path
|
||||
self.block_size = block_size
|
||||
self.R_min = None # 全局R_min(来自重采样扫描)
|
||||
self.D_max = None # 全局D_max(来自重采样扫描)
|
||||
self.G_list = None # 全局G值列表
|
||||
|
||||
# 打开影像获取基本信息
|
||||
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if self.dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.n_bands = self.dataset.RasterCount
|
||||
|
||||
def _load_water_mask(self):
|
||||
"""延迟加载水域掩膜"""
|
||||
if self.water_mask_path is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.water_mask_path, np.ndarray):
|
||||
if self.water_mask_path.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (self.water_mask_path > 0).astype(np.uint8)
|
||||
|
||||
if isinstance(self.water_mask_path, str):
|
||||
if self.water_mask_path.lower().endswith('.shp'):
|
||||
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
|
||||
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
return None
|
||||
|
||||
def _scan_global_stats(self, sample_step=20):
|
||||
"""
|
||||
通过重采样扫描获取全图全局统计量:R_min 和 D_max
|
||||
|
||||
分块读取,按sample_step跳行/跳列采样,大幅降低内存占用。
|
||||
内存峰值 ≈ 单波段块大小 + 几个掩膜数组 ≈ block_size² × 4~8MB
|
||||
"""
|
||||
print(f"[Kutser] 扫描全局统计量(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 预分配采样数组(NIR波段和D值)
|
||||
nir_samples = []
|
||||
d_samples = []
|
||||
sample_count = 0
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取NIR波段(用于R_min)
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
nir_band = None
|
||||
|
||||
# 读取氧吸收相关波段(用于D_max)
|
||||
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
|
||||
lower_block = lower_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
lower_band = None
|
||||
|
||||
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
|
||||
upper_block = upper_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
upper_band = None
|
||||
|
||||
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
|
||||
oxy_block = oxy_band_obj.ReadAsArray(0, y_off, self.width, block_height)
|
||||
oxy_band_obj = None
|
||||
|
||||
# 计算D = (lower + upper) * 0.5 - oxy
|
||||
d_block = (lower_block.astype(np.float32) + upper_block.astype(np.float32)) * 0.5 - oxy_block.astype(np.float32)
|
||||
|
||||
# 获取掩膜(整块)
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
else:
|
||||
mask_block = np.ones((block_height, self.width), dtype=np.uint8)
|
||||
|
||||
# 对掩膜区域进行采样
|
||||
mask_bool = mask_block.astype(bool)
|
||||
|
||||
if mask_bool.any():
|
||||
# 按步长采样
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
d_sampled = d_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
d_samples.append(d_sampled)
|
||||
sample_count += nir_sampled.size
|
||||
|
||||
# 显式释放块内存
|
||||
del nir_block, lower_block, upper_block, oxy_block, d_block, mask_block
|
||||
|
||||
# 汇总
|
||||
if sample_count == 0:
|
||||
self.R_min = 0.0
|
||||
self.D_max = 1.0
|
||||
else:
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
all_d = np.concatenate(d_samples)
|
||||
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
|
||||
self.D_max = float(all_d.max())
|
||||
del all_nir, all_d
|
||||
|
||||
print(f"[Kutser] 全局 R_min={self.R_min:.4f}, D_max={self.D_max:.4f}")
|
||||
|
||||
def _compute_G_list(self):
|
||||
"""
|
||||
计算全局G值列表(每个波段)
|
||||
|
||||
G = G_max - G_min(所有水域像素的极值差异)
|
||||
使用全分辨率扫描,但逐波段读取,每波段内存 ≈ block_size²
|
||||
"""
|
||||
print(f"[Kutser] 计算全局G值列表(n_bands={self.n_bands})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 初始化G_max和G_min为极值
|
||||
g_max = np.full(self.n_bands, -np.inf, dtype=np.float32)
|
||||
g_min = np.full(self.n_bands, np.inf, dtype=np.float32)
|
||||
|
||||
# 逐块扫描
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取所有波段的当前块
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
band = None
|
||||
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
if mask_bool.any():
|
||||
band_masked = block[mask_bool]
|
||||
g_max[b] = max(g_max[b], band_masked.max())
|
||||
g_min[b] = min(g_min[b], band_masked.min())
|
||||
else:
|
||||
g_max[b] = max(g_max[b], block.max())
|
||||
g_min[b] = min(g_min[b], block.min())
|
||||
|
||||
del block
|
||||
|
||||
self.G_list = (g_max - g_min).tolist()
|
||||
print(f"[Kutser] G值范围: min={min(self.G_list):.4f}, max={max(self.G_list):.4f}")
|
||||
|
||||
def _process_block(self, x_off, y_off, x_size, y_size):
|
||||
"""
|
||||
处理单个分块:读取数据 -> 计算D -> 逐波段校正 -> 返回块结果
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: 校正后的波段列表(每波段形状为 y_size x x_size)
|
||||
"""
|
||||
# 读取氧吸收相关波段
|
||||
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
|
||||
lower_block = lower_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
lower_band = None
|
||||
|
||||
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
|
||||
upper_block = upper_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
upper_band = None
|
||||
|
||||
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
|
||||
oxy_block = oxy_band_obj.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
oxy_band_obj = None
|
||||
|
||||
# 计算D
|
||||
D = (lower_block + upper_block) * 0.5 - oxy_block
|
||||
|
||||
# 避免除零
|
||||
if self.D_max == 0:
|
||||
D_normalized = np.zeros_like(D)
|
||||
else:
|
||||
D_normalized = D / self.D_max
|
||||
|
||||
# 释放临时块
|
||||
del lower_block, upper_block, oxy_block, D
|
||||
|
||||
# 获取当前块的水域掩膜
|
||||
water_mask = self._load_water_mask()
|
||||
if water_mask is not None:
|
||||
y_end = y_off + y_size
|
||||
x_end = x_off + x_size
|
||||
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
|
||||
else:
|
||||
mask_block = None
|
||||
|
||||
# 逐波段处理
|
||||
corrected_bands = []
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
band = None
|
||||
|
||||
G = self.G_list[b]
|
||||
# 校正公式:R_corrected = R - G * D_normalized
|
||||
corrected = R - G * D_normalized
|
||||
|
||||
# 只对水域区域应用校正
|
||||
if mask_block is not None:
|
||||
corrected = np.where(mask_block, corrected, R)
|
||||
|
||||
corrected_bands.append(corrected)
|
||||
del R
|
||||
|
||||
# 释放D块
|
||||
del D_normalized
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
执行分块处理,返回校正后的波段列表
|
||||
|
||||
内存峰值 ≈ 单波段块大小 + 几个辅助数组 ≈ 1000×1000×4B × 3 ≈ 12MB
|
||||
"""
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
|
||||
|
||||
# Step 1: 扫描全局统计量(R_min, D_max)
|
||||
self._scan_global_stats(sample_step=20)
|
||||
|
||||
# Step 2: 计算全局G列表
|
||||
self._compute_G_list()
|
||||
|
||||
# Step 3: 创建输出文件
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
|
||||
|
||||
# 获取地理信息
|
||||
geotransform = self.dataset.GetGeoTransform()
|
||||
projection = self.dataset.GetProjection()
|
||||
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
out_dataset = driver.Create(bsq_path, self.width, self.height,
|
||||
self.n_bands, gdal.GDT_Float32)
|
||||
if out_dataset is None:
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
# Step 4: 分块处理
|
||||
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
|
||||
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
|
||||
total_blocks = n_blocks_x * n_blocks_y
|
||||
|
||||
print(f"[Kutser] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
|
||||
|
||||
block_idx = 0
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
y_size = y_end - y_off
|
||||
|
||||
for x_off in range(0, self.width, self.block_size):
|
||||
x_end = min(x_off + self.block_size, self.width)
|
||||
x_size = x_end - x_off
|
||||
block_idx += 1
|
||||
|
||||
print(f"[Kutser] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
|
||||
|
||||
# 处理当前块
|
||||
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
|
||||
|
||||
# 写入输出文件
|
||||
for b in range(self.n_bands):
|
||||
out_band = out_dataset.GetRasterBand(b + 1)
|
||||
out_band.WriteArray(corrected_bands[b], x_off, y_off)
|
||||
out_band.FlushCache()
|
||||
|
||||
del corrected_bands
|
||||
|
||||
out_dataset = None
|
||||
self.dataset = None
|
||||
|
||||
# 检查.hdr文件
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"[Kutser] 校正完成,已保存至: {bsq_path}")
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
else:
|
||||
print(f"[Kutser] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
|
||||
# 返回空列表(结果已直接写入文件)
|
||||
return []
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
correction is done in reflectance
|
||||
|
||||
:return: 校正后的波段列表
|
||||
"""
|
||||
g_list = self.get_glint_G()
|
||||
D = self.get_depth_D()
|
||||
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
|
||||
def __del__(self):
|
||||
if self.dataset is not None:
|
||||
self.dataset = None
|
||||
corrected_bands = []
|
||||
for band_number in range(self.n_bands): #iterate across bands
|
||||
G = g_list[band_number]
|
||||
R = self.im_aligned[:,:,band_number]
|
||||
# 优化:减少中间数组创建,直接计算
|
||||
corrected_band = R - G * D
|
||||
|
||||
# 如果存在水域掩膜,只对水域区域应用校正
|
||||
if water_mask_bool is not None:
|
||||
corrected_band = np.where(water_mask_bool, corrected_band, R)
|
||||
|
||||
corrected_bands.append(corrected_band)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -19,7 +19,6 @@ from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from joblib import parallel_backend
|
||||
# 第三方模型导入
|
||||
# try:
|
||||
# import lightgbm as lgb
|
||||
@ -43,6 +42,11 @@ import os
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
|
||||
def _sklearn_parallel_n_jobs() -> int:
|
||||
"""PyInstaller 等打包环境下,joblib loky 会再次启动当前 exe,出现多个同名进程。"""
|
||||
return 1 if getattr(sys, "frozen", False) else -1
|
||||
|
||||
|
||||
class WaterQualityModelingBatch:
|
||||
"""水质参数反演批量建模类"""
|
||||
|
||||
@ -638,25 +642,26 @@ class WaterQualityModelingBatch:
|
||||
# 网格搜索 - 使用KFold代替StratifiedKFold
|
||||
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
|
||||
_n_jobs = _sklearn_parallel_n_jobs()
|
||||
grid_search = GridSearchCV(
|
||||
base_model,
|
||||
config['params'],
|
||||
cv=cv_strategy,
|
||||
scoring=scoring,
|
||||
n_jobs=-1,
|
||||
n_jobs=_n_jobs,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
# 在训练集上训练模型
|
||||
# with parallel_backend("threading", n_jobs=-1):
|
||||
# grid_search.fit(X_train, y_train)
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
# 获取最佳模型
|
||||
best_model = grid_search.best_estimator_
|
||||
|
||||
# 交叉验证评估(在训练集上)
|
||||
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring)
|
||||
cv_scores = cross_val_score(
|
||||
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
|
||||
)
|
||||
|
||||
# 计算训练集上的回归指标
|
||||
y_train_pred = best_model.predict(X_train)
|
||||
|
||||
@ -555,13 +555,7 @@ class WaterQualityInference:
|
||||
print(f"输入数据形状: {spectra_processed.shape}")
|
||||
|
||||
try:
|
||||
# 清洗 NaN / Inf,防止 SVR 等模型报错
|
||||
spectra_clean = np.nan_to_num(spectra_processed, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
if np.any(np.isnan(spectra_clean)) or np.any(np.isinf(spectra_clean)):
|
||||
print("警告: 清洗后数据中仍存在 NaN/Inf,已重置为 0")
|
||||
spectra_clean = np.nan_to_num(spectra_clean, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
predictions = model.predict(spectra_clean)
|
||||
predictions = model.predict(spectra_processed)
|
||||
print(f"预测完成,结果形状: {predictions.shape}")
|
||||
print(f"预测值范围: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]")
|
||||
print(f"预测值统计: 均值={np.mean(predictions):.4f}, 标准差={np.std(predictions):.4f}")
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块 - 统一导出接口
|
||||
"""
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info,
|
||||
load_image_as_array,
|
||||
save_array_as_image,
|
||||
save_bands_as_image,
|
||||
copy_hdr_info,
|
||||
read_band_as_array,
|
||||
read_multiple_bands,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm,
|
||||
ensure_water_mask_dat,
|
||||
)
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
select_rgb_bands_by_wavelength,
|
||||
get_wavelength_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# GDAL IO
|
||||
'get_image_geo_info',
|
||||
'load_image_as_array',
|
||||
'save_array_as_image',
|
||||
'save_bands_as_image',
|
||||
'copy_hdr_info',
|
||||
'read_band_as_array',
|
||||
'read_multiple_bands',
|
||||
# 掩膜转换
|
||||
'prepare_water_mask_for_algorithm',
|
||||
'ensure_water_mask_dat',
|
||||
# 预览图生成
|
||||
'generate_image_preview',
|
||||
'generate_water_mask_overlay',
|
||||
'select_rgb_bands_by_wavelength',
|
||||
'get_wavelength_info',
|
||||
]
|
||||
@ -1,309 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GDAL 底层 IO 工具模块
|
||||
|
||||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# GDAL 导入(可选)
|
||||
try:
|
||||
from osgeo import gdal, ogr, gdal_array
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# hdr 文件工具
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||||
UTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
UTIL_AVAILABLE = False
|
||||
write_fields_to_hdrfile = None
|
||||
get_hdr_file_path = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像信息读取
|
||||
# ============================================================
|
||||
|
||||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||||
"""
|
||||
获取影像的地理信息(不加载图像数据,节省内存)
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (geotransform, projection, width, height, n_bands)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
return geotransform, projection, width, height, n_bands
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||||
"""
|
||||
加载影像文件为numpy数组
|
||||
|
||||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (image_array, geotransform, projection)
|
||||
image_array: numpy数组,形状为(height, width, bands)
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
image_bands = []
|
||||
for i in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(i)
|
||||
band_data = band.ReadAsArray()
|
||||
image_bands.append(band_data)
|
||||
|
||||
image_array = np.dstack(image_bands)
|
||||
return image_array, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
读取单个波段为 numpy 数组
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_index: 波段索引(从 0 开始)
|
||||
|
||||
Returns:
|
||||
numpy 数组,形状为 (height, width)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
band = dataset.GetRasterBand(band_index + 1)
|
||||
return band.ReadAsArray()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||||
"""
|
||||
读取多个指定波段为列表
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_indices: 波段索引列表
|
||||
|
||||
Returns:
|
||||
tuple: (band_list, geotransform, projection)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
bands = []
|
||||
for idx in band_indices:
|
||||
band = dataset.GetRasterBand(idx + 1)
|
||||
bands.append(band.ReadAsArray())
|
||||
return bands, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像写入
|
||||
# ============================================================
|
||||
|
||||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
将numpy数组保存为影像文件
|
||||
|
||||
Args:
|
||||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
if image_array.ndim == 2:
|
||||
height, width = image_array.shape
|
||||
n_bands = 1
|
||||
else:
|
||||
height, width, n_bands = image_array.shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
if n_bands == 1:
|
||||
band = dataset.GetRasterBand(1)
|
||||
band.WriteArray(image_array)
|
||||
band.FlushCache()
|
||||
else:
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(image_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||||
|
||||
Args:
|
||||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if not corrected_bands:
|
||||
raise ValueError("波段列表为空")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
n_bands = len(corrected_bands)
|
||||
height, width = corrected_bands[0].shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
for i, band_array in enumerate(corrected_bands):
|
||||
if band_array.shape != (height, width):
|
||||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(band_array)
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||||
"""
|
||||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||||
|
||||
Args:
|
||||
source_img_path: 源影像文件路径(原始bsq文件)
|
||||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not UTIL_AVAILABLE:
|
||||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||||
return False
|
||||
|
||||
try:
|
||||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||||
|
||||
if not Path(source_hdr_path).exists():
|
||||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||||
return False
|
||||
|
||||
if not Path(dest_hdr_path).exists():
|
||||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||||
return False
|
||||
|
||||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||||
return False
|
||||
@ -1,210 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
掩膜转换工具模块
|
||||
|
||||
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
|
||||
以及水体掩膜的预处理逻辑。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_water_mask_for_algorithm(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
image_shape: Union[tuple, np.ndarray],
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
img_path: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
准备水域掩膜供算法使用
|
||||
|
||||
支持格式:
|
||||
- None:自动使用预先生成的 dat 格式掩膜
|
||||
- numpy.ndarray:直接返回(确保是 0/1 格式)
|
||||
- .dat / .tif 等栅格文件:读取并返回
|
||||
- .shp 文件:先栅格化,再读取返回
|
||||
|
||||
Args:
|
||||
water_mask: 掩膜来源
|
||||
image_shape: 影像形状 (height, width) 或 (height, width, channels)
|
||||
geotransform: GDAL 地理变换参数
|
||||
projection: 投影信息
|
||||
img_path: 影像路径(用于 shp 栅格化)
|
||||
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
|
||||
callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
numpy数组(dtype=uint8,0=非水域,1=水域)或 None
|
||||
"""
|
||||
img_height, img_width = image_shape[0], image_shape[1]
|
||||
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# numpy 数组直接返回
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8)
|
||||
|
||||
# 字符串路径
|
||||
if isinstance(water_mask, str):
|
||||
ext = Path(water_mask).suffix.lower()
|
||||
|
||||
# shapefile 格式
|
||||
if ext == '.shp':
|
||||
return _convert_shp_to_mask(
|
||||
shp_path=water_mask,
|
||||
img_path=img_path,
|
||||
image_shape=image_shape,
|
||||
geotransform=geotransform,
|
||||
projection=projection,
|
||||
water_mask_dir=water_mask_dir,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
# 栅格文件格式
|
||||
return _load_raster_mask(water_mask, img_height, img_width)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
|
||||
def _convert_shp_to_mask(shp_path: str, img_path: str,
|
||||
image_shape: tuple,
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None) -> np.ndarray:
|
||||
"""将 shapefile 栅格化为掩膜数组"""
|
||||
from src.utils.extract_water_area import rasterize_shp
|
||||
|
||||
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
|
||||
shp_name = Path(safe_shp_path).stem
|
||||
|
||||
if water_mask_dir:
|
||||
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
|
||||
else:
|
||||
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
|
||||
|
||||
# 缓存:已栅格化则直接读取
|
||||
if Path(temp_mask_path).exists():
|
||||
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
# 需要栅格化
|
||||
if img_path is None:
|
||||
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
|
||||
|
||||
print(f"正在将 SHP 栅格化: {safe_shp_path}")
|
||||
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
|
||||
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
|
||||
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
|
||||
"""从栅格文件加载掩膜"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取掩膜文件")
|
||||
|
||||
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {mask_path}")
|
||||
|
||||
try:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def ensure_water_mask_dat(img_path: str,
|
||||
existing_dat_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法)
|
||||
|
||||
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
|
||||
如果存在同名 .dat 文件,直接返回。
|
||||
否则从 img_path 生成并保存到 output_dir。
|
||||
|
||||
Args:
|
||||
img_path: 用于生成掩膜的遥感影像路径
|
||||
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
|
||||
output_dir: 输出目录(可选)
|
||||
|
||||
Returns:
|
||||
dat 格式掩膜文件路径
|
||||
"""
|
||||
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
|
||||
if Path(existing_dat_path).exists():
|
||||
return existing_dat_path
|
||||
|
||||
img_name = Path(img_path).stem
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(img_path).parent)
|
||||
|
||||
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
|
||||
|
||||
if Path(dat_path).exists():
|
||||
return dat_path
|
||||
|
||||
# 如果已有其他格式的掩膜,转换为 dat
|
||||
for ext in ['.tif', '.img', '.tiff']:
|
||||
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
|
||||
if Path(alt_path).exists():
|
||||
return _convert_to_dat(alt_path, dat_path)
|
||||
|
||||
return dat_path # 返回目标路径,让调用方决定是否需要生成
|
||||
|
||||
|
||||
def _convert_to_dat(src_path: str, dest_path: str) -> str:
|
||||
"""将其他栅格格式转换为 ENVI dat 格式"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法转换格式")
|
||||
|
||||
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
|
||||
if src_ds is None:
|
||||
raise ValueError(f"无法打开源掩膜文件: {src_path}")
|
||||
|
||||
try:
|
||||
geotransform = src_ds.GetGeoTransform()
|
||||
projection = src_ds.GetProjection()
|
||||
band = src_ds.GetRasterBand(1)
|
||||
array = band.ReadAsArray()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
|
||||
if dest_ds is None:
|
||||
raise ValueError(f"无法创建输出文件: {dest_path}")
|
||||
|
||||
try:
|
||||
dest_ds.SetGeoTransform(geotransform)
|
||||
dest_ds.SetProjection(projection)
|
||||
dest_band = dest_ds.GetRasterBand(1)
|
||||
dest_band.WriteArray((array > 0).astype(np.uint8))
|
||||
dest_band.FlushCache()
|
||||
finally:
|
||||
dest_ds = None
|
||||
|
||||
return dest_path
|
||||
finally:
|
||||
src_ds = None
|
||||
@ -1,339 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
遥感影像预览图生成工具模块
|
||||
|
||||
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具)
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 辅助函数:波段选择
|
||||
# ============================================================
|
||||
|
||||
def select_rgb_bands_by_wavelength(band_count: int,
|
||||
wavelength_info: Optional[List[float]] = None,
|
||||
fallback_bands: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
根据波长自动选择 RGB 波段
|
||||
|
||||
Args:
|
||||
band_count: 总波段数
|
||||
wavelength_info: 各波段波长列表(nm),长度为 band_count
|
||||
fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B]
|
||||
|
||||
Returns:
|
||||
波段索引列表 [R_index, G_index, B_index](0-based)
|
||||
"""
|
||||
if fallback_bands is None:
|
||||
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
|
||||
|
||||
if wavelength_info is None:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
# 目标波长(nm)
|
||||
TARGET_R = 650
|
||||
TARGET_G = 550
|
||||
TARGET_B = 460
|
||||
|
||||
def find_closest(target: float) -> int:
|
||||
min_dist = float('inf')
|
||||
best_idx = 0
|
||||
for i, wl in enumerate(wavelength_info):
|
||||
dist = abs(wl - target)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
try:
|
||||
r_idx = find_closest(TARGET_R)
|
||||
g_idx = find_closest(TARGET_G)
|
||||
b_idx = find_closest(TARGET_B)
|
||||
return [r_idx, g_idx, b_idx]
|
||||
except Exception:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
|
||||
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
|
||||
"""从 hdr 文件读取波长信息"""
|
||||
try:
|
||||
hdr_path = Path(img_path).with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
in_wl = False
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength ='):
|
||||
in_wl = True
|
||||
line = line.split('=', 1)[1].strip()
|
||||
elif in_wl:
|
||||
if line.startswith('{'):
|
||||
line = line[1:]
|
||||
if line.endswith('}'):
|
||||
line = line[:-1]
|
||||
in_wl = False
|
||||
# 解析逗号分隔的数值
|
||||
for token in line.replace(',', ' ').split():
|
||||
try:
|
||||
wavelengths.append(float(token))
|
||||
except ValueError:
|
||||
pass
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心预览图生成函数
|
||||
# ============================================================
|
||||
|
||||
def generate_image_preview(img_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
title: str = "影像预览") -> str:
|
||||
"""
|
||||
生成高光谱影像的 PNG 预览图
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
output_path: 输出 PNG 文件路径
|
||||
bands: RGB 波段索引 [R, G, B],None 则自动选择
|
||||
title: 图片标题
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成影像预览图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预览图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
# 读取波段
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
# 线性拉伸
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.imshow(rgb_image)
|
||||
ax.set_title(title, fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
|
||||
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
|
||||
color='white',
|
||||
bbox=dict(facecolor='black', alpha=0.6,
|
||||
boxstyle='round,pad=0.3'))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def generate_water_mask_overlay(img_path: str,
|
||||
mask_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
mask_color: tuple = (0, 100, 255),
|
||||
mask_alpha: float = 0.5) -> str:
|
||||
"""
|
||||
生成水域掩膜叠加到原图的 PNG 图像
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
mask_path: 水域掩膜文件路径
|
||||
output_path: 输出 PNG 路径
|
||||
bands: RGB 波段索引,None 则自动选择
|
||||
mask_color: 掩膜叠加颜色 (R, G, B)
|
||||
mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明)
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成叠加图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
|
||||
rgb_image = rgb_image.astype(np.uint8)
|
||||
|
||||
# 读取掩膜
|
||||
mask_dataset = gdal.Open(mask_path)
|
||||
if mask_dataset is not None:
|
||||
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
else:
|
||||
print(f"警告: 无法打开掩膜文件: {mask_path}")
|
||||
mask_data = None
|
||||
|
||||
# Alpha 混合
|
||||
overlay = np.zeros((height, width, 4), dtype=np.uint8)
|
||||
overlay[:, :, 0:3] = mask_color
|
||||
overlay[:, :, 3] = 255 # 全不透明
|
||||
|
||||
blended = rgb_image.astype(np.float32)
|
||||
if mask_data is not None:
|
||||
alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha
|
||||
for c in range(3):
|
||||
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
|
||||
blended = blended.astype(np.uint8)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
ax.imshow(blended)
|
||||
ax.axis('off')
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
|
||||
edgecolor='black', alpha=mask_alpha, label='水域范围')
|
||||
]
|
||||
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
|
||||
|
||||
# 面积计算
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
pixel_size_y = abs(geotransform[5])
|
||||
pixel_area = pixel_size_x * pixel_size_y
|
||||
|
||||
if mask_data is not None:
|
||||
water_pixels = np.sum(mask_data > 0)
|
||||
valid_pixels = np.sum(mask_data >= 0)
|
||||
water_km2 = water_pixels * pixel_area / 1_000_000
|
||||
valid_km2 = valid_pixels * pixel_area / 1_000_000
|
||||
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
|
||||
|
||||
area_text = (f'水域面积: {water_km2:.2f} km² | '
|
||||
f'影像总面积: {valid_km2:.2f} km² | '
|
||||
f'占比: {pct:.1f}%')
|
||||
ax.text(0.02, 0.98, area_text,
|
||||
transform=ax.transAxes, fontsize=11,
|
||||
color='white', fontweight='bold',
|
||||
bbox=dict(facecolor='#0064FF', alpha=0.8,
|
||||
edgecolor='black', boxstyle='round,pad=0.5'),
|
||||
verticalalignment='top')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
@ -1,21 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统一导出接口
|
||||
|
||||
本模块从各子模块导入可视化函数,提供统一的导出接口。
|
||||
"""
|
||||
from src.core.visualization.scatter_plot import generate_model_scatter_plots
|
||||
from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots
|
||||
from src.core.visualization.boxplot import generate_boxplots
|
||||
from src.core.visualization.statistics import generate_statistical_charts
|
||||
from src.core.visualization.preview import generate_glint_deglint_previews
|
||||
from src.core.visualization.report import generate_pipeline_report
|
||||
|
||||
__all__ = [
|
||||
'generate_model_scatter_plots',
|
||||
'generate_spectrum_comparison_plots',
|
||||
'generate_boxplots',
|
||||
'generate_statistical_charts',
|
||||
'generate_glint_deglint_previews',
|
||||
'generate_pipeline_report',
|
||||
]
|
||||
@ -1,183 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 箱型图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
sns.set_style("whitegrid")
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
def generate_boxplots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
data_start_column: int = 4,
|
||||
save_individual: bool = True,
|
||||
use_seaborn: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成水质参数的箱型图
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
data_start_column: 数据开始列索引(从第几列开始,默认第5列,索引为4)
|
||||
save_individual: 是否为每个参数单独保存箱型图
|
||||
use_seaborn: 是否使用seaborn绘制(更美观)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
箱型图文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成水质参数箱型图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "boxplots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# 确定参数列
|
||||
if parameter_columns is None:
|
||||
data_columns = df.iloc[:, data_start_column:]
|
||||
parameter_columns = list(data_columns.columns)
|
||||
else:
|
||||
parameter_columns = [col for col in parameter_columns if col in df.columns]
|
||||
|
||||
if not parameter_columns:
|
||||
print("警告: 未找到有效的参数列")
|
||||
return {}
|
||||
|
||||
boxplot_dir = Path(output_dir)
|
||||
boxplot_paths = {}
|
||||
|
||||
if save_individual:
|
||||
print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)")
|
||||
|
||||
for column in parameter_columns:
|
||||
if column not in df.columns:
|
||||
continue
|
||||
|
||||
clean_data = df[column].dropna()
|
||||
|
||||
if len(clean_data) == 0:
|
||||
print(f"跳过列 '{column}': 没有有效数据")
|
||||
continue
|
||||
|
||||
try:
|
||||
plt.figure(figsize=(8, 6))
|
||||
|
||||
if use_seaborn:
|
||||
plot_data = pd.DataFrame({
|
||||
'参数': [column] * len(clean_data),
|
||||
'数值': clean_data
|
||||
})
|
||||
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot([clean_data], labels=[column],
|
||||
patch_artist=True, showfliers=False)
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
stats_text = (f'数据点数: {len(clean_data)}\n'
|
||||
f'均值: {clean_data.mean():.2f}\n'
|
||||
f'中位数: {clean_data.median():.2f}\n'
|
||||
f'标准差: {clean_data.std():.2f}')
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round',
|
||||
facecolor='wheat' if not use_seaborn else 'lightgreen',
|
||||
alpha=0.8))
|
||||
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = boxplot_dir / f'{safe_column_name}_boxplot.png'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths[column] = str(save_path)
|
||||
print(f" 已保存: {save_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 处理参数 {column} 时出错: {e}")
|
||||
continue
|
||||
|
||||
# 综合箱型图
|
||||
try:
|
||||
print("\n生成综合箱型图(所有参数在一张图上)")
|
||||
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
|
||||
|
||||
box_data = []
|
||||
labels = []
|
||||
for column in parameter_columns:
|
||||
if column in df.columns:
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
|
||||
if box_data:
|
||||
if use_seaborn:
|
||||
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
|
||||
melted_data = melted_data.dropna()
|
||||
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
|
||||
for i, data in enumerate(box_data):
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
|
||||
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths['all_parameters'] = str(combined_path)
|
||||
print(f" 已保存综合箱型图: {combined_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成综合箱型图时出错: {e}")
|
||||
|
||||
print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表")
|
||||
return boxplot_paths
|
||||
@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 耀斑影像预览图生成
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_glint_deglint_previews(
|
||||
work_dir: str,
|
||||
output_subdir: str = "glint_deglint_previews",
|
||||
generate_glint: bool = True,
|
||||
generate_deglint: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成2_glint和3_deglint文件夹中影像文件的PNG预览图
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录
|
||||
output_subdir: 输出子目录名称
|
||||
generate_glint: 是否处理2_glint文件夹
|
||||
generate_deglint: 是否处理3_deglint文件夹
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
生成的预览图路径字典
|
||||
"""
|
||||
print(f"\n{'='*70}")
|
||||
print("步骤: 生成耀斑分析影像预览图")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if work_dir is None:
|
||||
raise ValueError("请提供 work_dir")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(work_dir) / "visualization" / output_subdir)
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
try:
|
||||
preview_paths = visualizer.generate_glint_deglint_previews(
|
||||
work_dir=work_dir,
|
||||
output_subdir=output_subdir,
|
||||
generate_glint=generate_glint,
|
||||
generate_deglint=generate_deglint
|
||||
)
|
||||
|
||||
print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图")
|
||||
return preview_paths
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成耀斑分析影像预览图时出错: {e}")
|
||||
return {}
|
||||
@ -1,147 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 流程执行报告生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def generate_pipeline_report(
|
||||
step_timings: Dict,
|
||||
pipeline_start_time: Optional[float] = None,
|
||||
pipeline_end_time: Optional[float] = None,
|
||||
output_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成流程执行报告,包含每步的耗时统计
|
||||
|
||||
Args:
|
||||
step_timings: 步骤耗时字典(格式:{step_name: {start_time, end_time, elapsed_seconds, elapsed_formatted, status, error}})
|
||||
pipeline_start_time: 流程开始时间戳
|
||||
pipeline_end_time: 流程结束时间戳
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
|
||||
Returns:
|
||||
报告文件路径
|
||||
"""
|
||||
if output_path is None:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _format_time(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.2f}秒"
|
||||
elif seconds < 3600:
|
||||
minutes = int(seconds // 60)
|
||||
secs = seconds % 60
|
||||
return f"{minutes}分{secs:.2f}秒"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = seconds % 60
|
||||
return f"{hours}小时{minutes}分{secs:.2f}秒"
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
total_time = 0.0
|
||||
|
||||
step_order = [
|
||||
"步骤1: 生成水域mask",
|
||||
"步骤2: 找到耀斑区域",
|
||||
"步骤3: 去除耀斑",
|
||||
"步骤4: 处理CSV文件",
|
||||
"步骤5: 提取训练样本点光谱",
|
||||
"步骤5.5: 计算水质光谱指数",
|
||||
"步骤6: 训练机器学习模型",
|
||||
"步骤6.5: 非经验模型训练",
|
||||
"步骤6.75: 自定义回归",
|
||||
"步骤7: 生成预测采样点",
|
||||
"步骤8: 预测水质参数",
|
||||
"步骤9: 生成分布图"
|
||||
]
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
report_data.append({
|
||||
'步骤': step_name,
|
||||
'开始时间': timing_info['start_time'],
|
||||
'结束时间': timing_info['end_time'],
|
||||
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
|
||||
'耗时(格式化)': timing_info['elapsed_formatted'],
|
||||
'状态': timing_info['status'],
|
||||
'错误信息': timing_info.get('error', '')
|
||||
})
|
||||
if timing_info['status'] == 'completed':
|
||||
total_time += timing_info['elapsed_seconds']
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
pipeline_total = pipeline_end_time - pipeline_start_time
|
||||
report_data.append({
|
||||
'步骤': '总计',
|
||||
'开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'耗时(秒)': f"{pipeline_total:.2f}",
|
||||
'耗时(格式化)': _format_time(pipeline_total),
|
||||
'状态': 'completed',
|
||||
'错误信息': ''
|
||||
})
|
||||
|
||||
df_report = pd.DataFrame(report_data)
|
||||
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
# 文本格式报告
|
||||
txt_output_path = str(Path(output_path).with_suffix('.txt'))
|
||||
with open(txt_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write("="*80 + "\n")
|
||||
f.write("水质参数反演流程执行报告\n")
|
||||
f.write("="*80 + "\n\n")
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("各步骤执行详情:\n")
|
||||
f.write("-"*80 + "\n\n")
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
f.write(f"{step_name}\n")
|
||||
f.write(f" 开始时间: {timing_info['start_time']}\n")
|
||||
f.write(f" 结束时间: {timing_info['end_time']}\n")
|
||||
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
|
||||
f.write(f" 状态: {timing_info['status']}\n")
|
||||
if timing_info.get('error'):
|
||||
f.write(f" 错误: {timing_info['error']}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("统计摘要:\n")
|
||||
f.write("-"*80 + "\n")
|
||||
completed_steps = [s for s in step_timings.values() if s['status'] == 'completed']
|
||||
failed_steps = [s for s in step_timings.values() if s['status'] == 'failed']
|
||||
skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped']
|
||||
|
||||
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
|
||||
f.write(f"失败的步骤: {len(failed_steps)}\n")
|
||||
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
|
||||
|
||||
if completed_steps:
|
||||
completed_times = [s['elapsed_seconds'] for s in completed_steps]
|
||||
f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n")
|
||||
f.write(f"最长耗时: {_format_time(np.max(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.max(completed_times)][0]})\n")
|
||||
f.write(f"最短耗时: {_format_time(np.min(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.min(completed_times)][0]})\n")
|
||||
|
||||
print(f"\n流程报告已生成:")
|
||||
print(f" CSV格式: {output_path}")
|
||||
print(f" 文本格式: {txt_output_path}")
|
||||
|
||||
return output_path
|
||||
@ -1,147 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 散点图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_model_scatter_plots(
|
||||
models_dir: str,
|
||||
training_csv_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = 'test_r2',
|
||||
use_enhanced: bool = True,
|
||||
feature_start_column: Union[str, int] = 13,
|
||||
test_size: float = 0.2,
|
||||
random_state: int = 42,
|
||||
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成模型评估散点图(真实值vs预测值)
|
||||
|
||||
Args:
|
||||
models_dir: 模型保存目录
|
||||
training_csv_path: 训练数据CSV路径
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
metric: 选择最佳模型的指标
|
||||
use_enhanced: 是否使用增强版散点图(带置信区间,使用sctter_batch)
|
||||
feature_start_column: 特征开始列名或索引
|
||||
test_size: 测试集比例
|
||||
random_state: 随机种子
|
||||
scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象
|
||||
|
||||
Returns:
|
||||
散点图文件路径字典(键为目标参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成模型评估散点图")
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供 training_csv_path")
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
raise ValueError(f"模型目录不存在: {models_dir}")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
# 增强版散点图
|
||||
if use_enhanced:
|
||||
print("使用增强版散点图(带置信区间)")
|
||||
try:
|
||||
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
|
||||
if scatter_batch is None:
|
||||
scatter_batch = WaterQualityScatterBatch()
|
||||
|
||||
results = scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=output_dir,
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state
|
||||
)
|
||||
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
print(f" ✓ {target_name}: {result.get('save_path', '')}")
|
||||
else:
|
||||
print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用增强版散点图时出错: {e}")
|
||||
print("回退到基础版散点图")
|
||||
use_enhanced = False
|
||||
|
||||
# 基础版散点图
|
||||
if not use_enhanced or not scatter_paths:
|
||||
print("使用基础版散点图")
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
print(f"\n处理目标参数: {target_name}")
|
||||
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric
|
||||
)
|
||||
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
|
||||
scatter_path = visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0)
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception as e:
|
||||
print(f"处理目标参数 {target_name} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
|
||||
return scatter_paths
|
||||
@ -1,80 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 光谱曲线图生成
|
||||
"""
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_spectrum_comparison_plots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
wavelength_start_column: Union[str, int] = "UTM_Y",
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成光谱曲线对比图(不同参数值的光谱曲线对比)
|
||||
|
||||
Args:
|
||||
csv_path: 包含光谱和参数值的CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
wavelength_start_column: 波长开始列名或索引
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
光谱曲线图文件路径字典(键为参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成光谱曲线对比图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "spectrum_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
if isinstance(wavelength_start_column, str):
|
||||
try:
|
||||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||||
except:
|
||||
wavelength_start_idx = 13
|
||||
else:
|
||||
wavelength_start_idx = wavelength_start_column
|
||||
|
||||
parameter_columns = list(df.columns[:wavelength_start_idx])
|
||||
if len(parameter_columns) > 2:
|
||||
parameter_columns = parameter_columns[2:]
|
||||
|
||||
spectrum_paths = {}
|
||||
for param_col in parameter_columns:
|
||||
if param_col not in df.columns:
|
||||
continue
|
||||
|
||||
print(f"\n处理参数: {param_col}")
|
||||
try:
|
||||
spectrum_path = visualizer.plot_spectrum_by_parameter(
|
||||
csv_path=csv_path,
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wavelength_start_column,
|
||||
n_groups=5
|
||||
)
|
||||
spectrum_paths[param_col] = spectrum_path
|
||||
except Exception as e:
|
||||
print(f"处理参数 {param_col} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表")
|
||||
return spectrum_paths
|
||||
@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统计图表生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_statistical_charts(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成统计图表(箱线图、直方图、相关性热力图)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
统计图表文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成统计图表")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "statistical_charts")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
parameter_columns = list(df.columns[2:])
|
||||
parameter_columns = [col for col in parameter_columns
|
||||
if df[col].dtype in [np.float64, np.int64]]
|
||||
|
||||
chart_paths = visualizer.plot_statistical_charts(
|
||||
csv_path=csv_path,
|
||||
parameter_columns=parameter_columns
|
||||
)
|
||||
|
||||
print(f"\n统计图表生成完成")
|
||||
return chart_paths
|
||||
File diff suppressed because it is too large
Load Diff
@ -1 +0,0 @@
|
||||
# src.gui.components package
|
||||
@ -1,143 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自定义组件 - 文件选择控件等公共组件
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QHBoxLayout, QLabel, QLineEdit, QPushButton, QFileDialog,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
|
||||
class DirSelectWidget(QWidget):
|
||||
"""目录选择组件"""
|
||||
def __init__(self, label_text, parent=None):
|
||||
"""
|
||||
初始化目录选择组件
|
||||
|
||||
Args:
|
||||
label_text: 标签文本
|
||||
parent: 父控件
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self.init_ui(label_text)
|
||||
|
||||
def init_ui(self, label_text):
|
||||
layout = QHBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
self.label = QLabel(label_text)
|
||||
self.label.setMinimumWidth(120)
|
||||
self.line_edit = QLineEdit()
|
||||
self.line_edit.setPlaceholderText("请选择目录...")
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMaximumWidth(80)
|
||||
self.browse_btn.clicked.connect(self.browse_dir)
|
||||
|
||||
layout.addWidget(self.label)
|
||||
layout.addWidget(self.line_edit, 1)
|
||||
layout.addWidget(self.browse_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_dir(self):
|
||||
"""浏览目录 - 智能记忆上次选择位置"""
|
||||
current_text = self.line_edit.text().strip()
|
||||
initial_dir = ""
|
||||
|
||||
# 最高优先级:输入框已有路径存在
|
||||
if current_text:
|
||||
if os.path.isdir(current_text):
|
||||
initial_dir = current_text
|
||||
else:
|
||||
dir_path = os.path.dirname(current_text)
|
||||
if dir_path and os.path.exists(dir_path):
|
||||
initial_dir = dir_path
|
||||
|
||||
# 调用目录选择对话框
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self, "选择目录", initial_dir
|
||||
)
|
||||
if dir_path:
|
||||
self.line_edit.setText(dir_path)
|
||||
|
||||
def get_path(self):
|
||||
"""获取路径"""
|
||||
return self.line_edit.text()
|
||||
|
||||
def set_path(self, path):
|
||||
"""设置路径"""
|
||||
self.line_edit.setText(str(path))
|
||||
|
||||
|
||||
class FileSelectWidget(QWidget):
|
||||
"""文件选择组件"""
|
||||
def __init__(self, label_text, file_filter="All Files (*.*)", mode="open", parent=None):
|
||||
"""
|
||||
初始化文件选择组件
|
||||
|
||||
Args:
|
||||
label_text: 标签文本
|
||||
file_filter: 文件过滤器
|
||||
mode: 选择模式 - "open"(打开文件) 或 "save"(保存文件)
|
||||
parent: 父控件
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self.file_filter = file_filter
|
||||
self.mode = mode # "open" 或 "save"
|
||||
self.init_ui(label_text)
|
||||
|
||||
def init_ui(self, label_text):
|
||||
layout = QHBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
self.label = QLabel(label_text)
|
||||
self.label.setMinimumWidth(120)
|
||||
self.line_edit = QLineEdit()
|
||||
placeholder = "请选择保存路径..." if self.mode == "save" else "请选择文件..."
|
||||
self.line_edit.setPlaceholderText(placeholder)
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMaximumWidth(80)
|
||||
self.browse_btn.clicked.connect(self.browse_file)
|
||||
|
||||
layout.addWidget(self.label)
|
||||
layout.addWidget(self.line_edit, 1)
|
||||
layout.addWidget(self.browse_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_file(self):
|
||||
"""浏览文件 - 智能记忆上次选择位置"""
|
||||
current_text = self.line_edit.text().strip()
|
||||
initial_dir = ""
|
||||
|
||||
# 最高优先级:输入框已有路径存在
|
||||
if current_text:
|
||||
if os.path.isdir(current_text):
|
||||
initial_dir = current_text
|
||||
else:
|
||||
dir_path = os.path.dirname(current_text)
|
||||
if dir_path and os.path.exists(dir_path):
|
||||
initial_dir = dir_path
|
||||
|
||||
if self.mode == "save":
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存文件", initial_dir, self.file_filter
|
||||
)
|
||||
else:
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self, "选择文件", initial_dir, self.file_filter
|
||||
)
|
||||
if file_path:
|
||||
self.line_edit.setText(file_path)
|
||||
|
||||
def get_path(self):
|
||||
"""获取路径"""
|
||||
return self.line_edit.text()
|
||||
|
||||
def set_path(self, path):
|
||||
"""设置路径"""
|
||||
self.line_edit.setText(str(path))
|
||||
@ -1 +0,0 @@
|
||||
# src.gui.core
|
||||
@ -1,332 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
后台线程模块:Pipeline 执行线程与诊断逻辑。
|
||||
"""
|
||||
import traceback
|
||||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 依赖诊断
|
||||
# =============================================================================
|
||||
|
||||
def check_pipeline_dependencies():
|
||||
"""检查pipeline模块的依赖项"""
|
||||
missing_deps = []
|
||||
dep_errors = {}
|
||||
|
||||
required_packages = [
|
||||
'numpy', 'pandas', 'scipy', 'matplotlib', 'sklearn',
|
||||
'joblib', 'PIL', 'cv2', 'rasterio', 'geopandas'
|
||||
]
|
||||
|
||||
for package in required_packages:
|
||||
try:
|
||||
if package == 'PIL':
|
||||
import PIL
|
||||
elif package == 'cv2':
|
||||
import cv2
|
||||
else:
|
||||
__import__(package)
|
||||
except Exception as e:
|
||||
missing_deps.append(package)
|
||||
dep_errors[package] = repr(e)
|
||||
|
||||
return missing_deps, dep_errors
|
||||
|
||||
|
||||
def diagnose_pipeline_import_error():
|
||||
"""诊断pipeline导入错误"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
error_info = []
|
||||
|
||||
is_frozen = getattr(sys, "frozen", False) or bool(getattr(sys, "_MEIPASS", None))
|
||||
|
||||
if is_frozen:
|
||||
error_info.append(
|
||||
"[INFO] PyInstaller 环境:Pipeline 从程序内置包加载,跳过对仓库路径 src/core/*.py 的磁盘检查"
|
||||
)
|
||||
else:
|
||||
pipeline_file = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "core", "water_quality_inversion_pipeline_GUI.py")
|
||||
)
|
||||
if not os.path.exists(pipeline_file):
|
||||
error_info.append(f"[ERROR] Pipeline文件不存在: {pipeline_file}")
|
||||
error_info.append(
|
||||
" 解决方案: 请确保项目结构完整,检查 src/core/ 下是否有 water_quality_inversion_pipeline_GUI.py"
|
||||
)
|
||||
else:
|
||||
error_info.append(f"[OK] Pipeline文件存在: {pipeline_file}")
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
error_info.append(f"[INFO] 已添加路径到sys.path: {current_dir}")
|
||||
|
||||
missing_deps, dep_errors = check_pipeline_dependencies()
|
||||
if missing_deps:
|
||||
error_info.append(f"[ERROR] 缺少必需的依赖包: {', '.join(missing_deps)}")
|
||||
for pkg in missing_deps:
|
||||
if pkg in dep_errors:
|
||||
error_info.append(f" - {pkg} 导入失败原因: {dep_errors[pkg]}")
|
||||
error_info.append(" 解决方案: 请运行以下命令安装依赖:")
|
||||
error_info.append(" pip install -r requirements.txt")
|
||||
error_info.append(" 或使用conda:")
|
||||
error_info.append(" conda install numpy pandas scipy matplotlib scikit-learn joblib pillow opencv-python rasterio geopandas")
|
||||
else:
|
||||
error_info.append("[OK] 主要依赖包均已安装")
|
||||
|
||||
try:
|
||||
from osgeo import gdal # noqa: F401
|
||||
error_info.append("[OK] GDAL (osgeo) 可用")
|
||||
except ImportError:
|
||||
try:
|
||||
from osgeo import gdal # noqa: F401
|
||||
error_info.append("[OK] GDAL 可用")
|
||||
except ImportError:
|
||||
error_info.append("[WARNING] GDAL/osgeo 不可用,将影响栅格与地理数据处理")
|
||||
error_info.append(" 开发环境: conda install gdal")
|
||||
error_info.append(" 打包环境: 请在构建所用 Conda 环境中打包,并确保 spec 已收集 Library/bin 中依赖 DLL")
|
||||
|
||||
try:
|
||||
import unittest
|
||||
error_info.append("[OK] unittest模块可用")
|
||||
except ImportError:
|
||||
error_info.append("[WARNING] unittest模块不可用,这可能是PyInstaller打包环境导致的")
|
||||
error_info.append(" 这不会影响主要功能,但可能影响某些测试相关特性")
|
||||
|
||||
return error_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline 可用性标志(模块级状态)
|
||||
# =============================================================================
|
||||
|
||||
PIPELINE_AVAILABLE = False
|
||||
PIPELINE_ERROR_INFO = []
|
||||
|
||||
try:
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
print("[OK] 成功导入pipeline模块")
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
except ImportError as e:
|
||||
PIPELINE_AVAILABLE = False
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
|
||||
print("="*60)
|
||||
print("[ERROR] PIPELINE导入失败 - 详细诊断信息:")
|
||||
print("="*60)
|
||||
|
||||
for info in error_info:
|
||||
print(info)
|
||||
|
||||
print("-"*60)
|
||||
print(f"原始ImportError: {str(e)}")
|
||||
print("-"*60)
|
||||
|
||||
if "unittest" in str(e):
|
||||
print("[INFO] unittest模块缺失 - 这通常在PyInstaller打包环境中发生")
|
||||
print("解决方案:")
|
||||
print(" 1. 这不会影响主要功能,程序仍可正常运行")
|
||||
print(" 2. 如果需要修复,可以在.spec文件中添加unittest模块:")
|
||||
print(" a = Analysis(..., hiddenimports=['unittest', 'unittest.mock'])")
|
||||
print(" 3. 或在PyInstaller命令中添加: --hidden-import unittest")
|
||||
elif "water_quality_inversion_pipeline_GUI" in str(e):
|
||||
print("[INFO] 可能的解决方案:")
|
||||
print(" 1. 检查src/core/water_quality_inversion_pipeline_GUI.py文件是否存在")
|
||||
print(" 2. 确保Python路径设置正确")
|
||||
print(" 3. 尝试重新安装依赖: pip install -r requirements.txt")
|
||||
print(" 4. 检查Python版本是否兼容(推荐Python 3.8-3.11)")
|
||||
|
||||
import traceback
|
||||
print("\n完整错误追踪:")
|
||||
traceback.print_exc()
|
||||
print("="*60)
|
||||
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
except Exception as e:
|
||||
PIPELINE_AVAILABLE = False
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
|
||||
print("="*60)
|
||||
print("[ERROR] PIPELINE导入失败 - 其他错误:")
|
||||
print("="*60)
|
||||
|
||||
for info in error_info:
|
||||
print(info)
|
||||
|
||||
print("-"*60)
|
||||
print(f"原始错误: {str(e)}")
|
||||
print("-"*60)
|
||||
|
||||
print("[INFO] 可能的解决方案:")
|
||||
print(" 1. 检查Python环境和依赖包版本")
|
||||
print(" 2. 尝试重新安装所有依赖")
|
||||
print(" 3. 检查是否有语法错误或其他模块导入问题")
|
||||
|
||||
import traceback
|
||||
print("\n完整错误追踪:")
|
||||
traceback.print_exc()
|
||||
print("="*60)
|
||||
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WorkerThread
|
||||
# =============================================================================
|
||||
|
||||
class WorkerThread(QThread):
|
||||
"""后台工作线程,用于执行耗时任务(在工作线程内创建 Pipeline,避免阻塞 UI)。"""
|
||||
progress_update = pyqtSignal(int, str) # 进度更新信号 (percentage, message)
|
||||
log_message = pyqtSignal(str, str) # 日志消息信号 (message, level: 'info'/'warning'/'error')
|
||||
step_completed = pyqtSignal(str, bool, str) # 步骤完成信号 (step_name, success, message)
|
||||
finished = pyqtSignal(bool, str) # 完成信号 (success, message)
|
||||
|
||||
def __init__(self, work_dir: str, config, mode='full', step_name=None):
|
||||
super().__init__()
|
||||
self.work_dir = str(work_dir)
|
||||
self.config = config
|
||||
self.mode = mode # 'full' 或 'single_step'
|
||||
self.step_name = step_name # 单步执行时的步骤名称
|
||||
self.pipeline = None
|
||||
self.is_running = True
|
||||
self.current_step = None
|
||||
self.step_count = 0
|
||||
self.total_steps = 9
|
||||
|
||||
def pipeline_callback(self, step_name, status, message=""):
|
||||
"""Pipeline回调函数,用于接收步骤状态"""
|
||||
if status == "start":
|
||||
self.log_message.emit(f"[START] 开始执行: {step_name}", "info")
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"正在执行: {step_name}")
|
||||
elif status == "completed":
|
||||
self.step_count += 1
|
||||
self.log_message.emit(f"[DONE] 完成: {step_name} {message}", "info")
|
||||
self.step_completed.emit(step_name, True, message)
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"已完成: {step_name}")
|
||||
elif status == "skipped":
|
||||
self.step_count += 1
|
||||
self.log_message.emit(f"[SKIP] 跳过: {step_name} {message}", "warning")
|
||||
self.step_completed.emit(step_name, True, f"跳过: {message}")
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"已跳过: {step_name}")
|
||||
elif status == "error":
|
||||
self.log_message.emit(f"[ERROR] 错误: {step_name} - {message}", "error")
|
||||
self.step_completed.emit(step_name, False, message)
|
||||
elif status == "info":
|
||||
self.log_message.emit(f" {message}", "info")
|
||||
elif status == "warning":
|
||||
self.log_message.emit(f" [WARNING] {message}", "warning")
|
||||
|
||||
def run(self):
|
||||
"""运行 pipeline:子线程内切换 Matplotlib 为 Agg,避免 Qt5Agg 在后台线程绘图导致界面卡死。"""
|
||||
import os
|
||||
# GDAL 环境变量保护(放在最前面,防止路径/编码问题)
|
||||
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
|
||||
os.environ['SHAPE_ENCODING'] = 'UTF-8'
|
||||
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
self.pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
|
||||
if self.mode == 'full':
|
||||
self.log_message.emit("开始运行完整流程...", "info")
|
||||
self.step_count = 0
|
||||
|
||||
if hasattr(self.pipeline, 'set_callback'):
|
||||
self.pipeline.set_callback(self.pipeline_callback)
|
||||
|
||||
self.pipeline.run_full_pipeline(self.config)
|
||||
|
||||
self.progress_update.emit(100, "流程执行完成")
|
||||
self.finished.emit(True, "完整流程执行成功!")
|
||||
else:
|
||||
self.log_message.emit(f"开始独立运行步骤: {self.step_name}", "info")
|
||||
self.progress_update.emit(0, f"正在执行: {self.step_name}")
|
||||
|
||||
if hasattr(self.pipeline, 'set_callback'):
|
||||
self.pipeline.set_callback(self.pipeline_callback)
|
||||
|
||||
self.run_single_step(self.step_name, self.config)
|
||||
|
||||
self.progress_update.emit(100, f"步骤 {self.step_name} 执行完成")
|
||||
self.finished.emit(True, f"步骤 {self.step_name} 独立运行成功!")
|
||||
except Exception as e:
|
||||
error_msg = f"执行失败: {str(e)}\n{traceback.format_exc()}"
|
||||
self.log_message.emit(error_msg, "error")
|
||||
self.finished.emit(False, error_msg)
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def run_single_step(self, step_name, config):
|
||||
"""运行单个步骤"""
|
||||
step_method_map = {
|
||||
'step1': 'step1_generate_water_mask',
|
||||
'step2': 'step2_find_glint_area',
|
||||
'step3': 'step3_remove_glint',
|
||||
'step4': 'step4_process_csv',
|
||||
'step5': 'step5_extract_training_spectra',
|
||||
'step5_5': 'step5_5_calculate_water_quality_indices',
|
||||
'step6': 'step6_train_models',
|
||||
'step6_5': 'step6_5_non_empirical_modeling',
|
||||
'step6_75': 'step6_75_custom_regression',
|
||||
'step7': 'step7_generate_sampling_points',
|
||||
'step8': 'step8_predict_water_quality',
|
||||
'step8_5': 'step8_5_predict_with_non_empirical_models',
|
||||
'step8_75': 'step8_75_predict_with_custom_regression',
|
||||
'step9': 'step9_generate_distribution_map'
|
||||
}
|
||||
|
||||
if step_name not in step_method_map:
|
||||
raise ValueError(f"未知的步骤名称: {step_name}")
|
||||
|
||||
method_name = step_method_map[step_name]
|
||||
step_config = dict(config.get(step_name, {}))
|
||||
|
||||
step_config['skip_dependency_check'] = True
|
||||
|
||||
if step_name == 'step9':
|
||||
step_config.pop('step9_batch_mode', None)
|
||||
step_config.pop('prediction_csv_dir', None)
|
||||
step_config.pop('recursive_csv_scan', None)
|
||||
|
||||
if step_name in ['step2', 'step3', 'step4', 'step5', 'step6', 'step7', 'step8', 'step8_5', 'step8_75']:
|
||||
step_config.pop('output_path', None)
|
||||
|
||||
if step_name == 'step8_5' and 'models_dir' in step_config:
|
||||
step_config['non_empirical_models_dir'] = step_config.pop('models_dir')
|
||||
|
||||
method = getattr(self.pipeline, method_name)
|
||||
result = method(**step_config)
|
||||
|
||||
return result
|
||||
|
||||
def stop(self):
|
||||
"""停止执行"""
|
||||
self.is_running = False
|
||||
self.terminate()
|
||||
@ -1,46 +0,0 @@
|
||||
Formula_Name,Category,Formula,Reference
|
||||
BGA_Am09KBBI,Phycocyanin (BGA_PC),(w686 - w658) / (w686 + w658),"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.; Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13."
|
||||
BGA_Be162B643sub629,Phycocyanin (BGA_PC),w644 - w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be162B700sub601,Phycocyanin (BGA_PC),w700 - w601,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be162BsubPhy,Phycocyanin (BGA_PC),w715 - w615,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16FLHBlueRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w458 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16FLHGreenRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w558 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16FLHVioletRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w444 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16MPI,Phycocyanin (BGA_PC),(w615 - w601) - (w644 - w601),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16NDPhyI,Phycocyanin (BGA_PC),(w700 - w622) / (w700 + w622),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16NDPhyI644over615,Phycocyanin (BGA_PC),(w644 - w615) / (w644 + w615),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541."
|
||||
BGA_Be16NDPhyI644over629,Phycocyanin (BGA_PC),(w644 - w629) / (w644 + w629),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542."
|
||||
BGA_Be16Phy2BDA644over629,Phycocyanin (BGA_PC),w644 / w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545."
|
||||
BGA_Da052BDA,Phycocyanin (BGA_PC),w714 / w672,"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
BGA_Go04MCI,Phycocyanin (BGA_PC),w709 - w681 - (w753 - w681),"Gower, J.F.R.; Brown,L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17闁?5."
|
||||
BGA_HU103BDA,Phycocyanin (BGA_PC),(((1 / w615) - (1 / w600)) - w725),"Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406"
|
||||
BGA_Ku15PhyCI,Phycocyanin (BGA_PC),(-1 * (W681 - W665 - (W709 - W665))),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10."
|
||||
BGA_Ku15SLH,Phycocyanin (BGA_PC),(w715 - w658) + (w715 - w658),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11."
|
||||
BGA_MI092BDA,Phycocyanin (BGA_PC),w700 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758闁?75."
|
||||
BGA_MM092BDA,Phycocyanin (BGA_PC),w724 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758闁?76."
|
||||
BGA_MM12NDCIalt,Phycocyanin (BGA_PC),(w700 - w658) / (w700 + w658),"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003"
|
||||
BGA_MM143BDAopt,Phycocyanin (BGA_PC),((1 / w629) - (1 / w659)) * w724,"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004"
|
||||
BGA_SI052BDA,Phycocyanin (BGA_PC),w709 / w620,"Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237闁?45"
|
||||
BGA_SM122BDA,Phycocyanin (BGA_PC),w709 / w600,"Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012."
|
||||
BGA_SY002BDA,Phycocyanin (BGA_PC),w650 / w625,"Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55,153闁?68"
|
||||
BGA_Wy08CI,Phycocyanin (BGA_PC),(-1 * (W686 - W672 - (W715 - W672))),"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
Chl_Al10SABI,chlorophyll_a,(w857 - w644) / (w458 + w529),"Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825."
|
||||
Chl_Am092Bsub,chlorophyll_a,w681 - w665,"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126闁?144."
|
||||
Chl_Be16FLHblue,chlorophyll_a,w529 - (w644 + (w458 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16FLHviolet,chlorophyll_a,w529 - (w644 + (w429 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16NDTIblue,chlorophyll_a,(w658 - w458) / (w658 + w458),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543."
|
||||
Chl_Be16NDTIviolet,chlorophyll_a,(w658 - w444) / (w658 + w444),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544."
|
||||
Chl_De933BDA,chlorophyll_a,w600 - w648 - w625,"Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam."
|
||||
Chl_Gi033BDA,chlorophyll_a,((1 / w672) - (1 / w715)) * w757,"Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282."
|
||||
Chl_Kn07KIVU,chlorophyll_a,(w458 - w644) / w529,"Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23闁?7 April 2007 (ESA SP-636, July 2007)."
|
||||
Chl_MM12NDCI,chlorophyll_a,(w715 - w686) / (w715 + w686),"Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406"
|
||||
Chl_Zh10FLH,chlorophyll_a,w686 - (w715 + (w672 - w751)),"Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48"
|
||||
Turb_Be16GreenPlusRedBothOverViolet,Turbidity,(w558 + w658) / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538"
|
||||
Turb_Be16RedOverViolet,Turbidity,w658 / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539"
|
||||
Turb_Bow06RedOverGreen,Turbidity,w658 / w558,"Bowers, D. G., and C. E. Binding. 2006. 闁炽儲缈籬e Optical Properties of Mineral Suspended Particles: A Review and Synthesis.闁?Estuarine Coastal and Shelf Science 67 (1闁?): 219闁?30. doi:10.1016/j.ecss.2005.11.010"
|
||||
Turb_Chip09NIROverGreen,Turbidity,w857 / w558,"Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009."
|
||||
Turb_Dox02NIRoverRed,Turbidity,w857 / w658,"Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085"
|
||||
Turb_Frohn09GreenPlusRedBothOverBlue,Turbidity,(w558 + w658) / w458,"Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency."
|
||||
Turb_Harr92NIR,Turbidity,w857,"Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments闁炽儲鏁刪e Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487闁?509"
|
||||
Turb_Lath91RedOverBlue,Turbidity,w658 / w458,"Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing"
|
||||
Turb_Moore80Red,Turbidity,w658,"Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422"
|
||||
|
@ -1,315 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ReportGenerationPanel - Word 分析报告生成面板
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QSpinBox,
|
||||
QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class ReportGenerateThread(QThread):
|
||||
"""后台生成 Word 报告(避免阻塞 UI)。"""
|
||||
finished_ok = pyqtSignal(str)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, output_dir: Optional[str], report_title: str, options: dict):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.output_dir = output_dir
|
||||
self.report_title = report_title
|
||||
self.options = options
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
from src.postprocessing.report_word import WaterQualityReportGenerator, ReportGenerationConfig
|
||||
|
||||
url = (self.options.get("ollama_url") or "").strip() or None
|
||||
vision = (self.options.get("ollama_vision_model") or "").strip() or None
|
||||
text = (self.options.get("ollama_text_model") or "").strip() or None
|
||||
if self.options.get("text_same_as_vision"):
|
||||
text = vision
|
||||
timeout = self.options.get("ollama_timeout_s")
|
||||
enable_ai = self.options.get("enable_ai_analysis")
|
||||
|
||||
ai_cfg = ReportGenerationConfig(
|
||||
ollama_base_url=url,
|
||||
ollama_vision_model=vision,
|
||||
ollama_text_model=text,
|
||||
ollama_timeout_s=int(timeout) if timeout is not None else None,
|
||||
enable_ai_analysis=bool(enable_ai),
|
||||
)
|
||||
self.log_message.emit(
|
||||
f"报告生成:工作目录={self.work_dir},AI={'开' if enable_ai else '关'},"
|
||||
f"模型URL={url or '(环境变量 OLLAMA_URL)'}",
|
||||
"info",
|
||||
)
|
||||
gen = WaterQualityReportGenerator(
|
||||
work_dir=self.work_dir,
|
||||
output_dir=self.output_dir,
|
||||
ai_config=ai_cfg,
|
||||
)
|
||||
out_path = gen.generate_report(
|
||||
work_dir=self.work_dir,
|
||||
report_title=self.report_title or "水质参数反演分析报告",
|
||||
)
|
||||
self.finished_ok.emit(str(out_path))
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
class ReportGenerationPanel(QWidget):
|
||||
"""Word 报告生成:工作目录、输出目录、Ollama URL/模型、是否启用 AI 等。"""
|
||||
|
||||
def __init__(self, main_window=None, parent=None):
|
||||
super().__init__(parent)
|
||||
self.main_window = main_window
|
||||
self._report_thread = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
layout.setContentsMargins(10, 10, 10, 10)
|
||||
layout.setSpacing(10)
|
||||
|
||||
intro = QLabel(
|
||||
"根据工作目录下的可视化结果(14_visualization 等)生成 Word 分析报告。"
|
||||
"需已存在可视化图表;AI 分析通过 Ollama /api/chat 调用本地或远程服务。"
|
||||
)
|
||||
intro.setWordWrap(True)
|
||||
intro.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
layout.addWidget(intro)
|
||||
|
||||
path_group = QGroupBox("路径")
|
||||
path_form = QFormLayout()
|
||||
|
||||
wd_row = QHBoxLayout()
|
||||
self.work_dir_edit = QLineEdit()
|
||||
self.work_dir_edit.setPlaceholderText("选择流程工作目录(含 14_visualization)…")
|
||||
wd_browse = QPushButton("浏览…")
|
||||
wd_browse.clicked.connect(self.browse_work_dir)
|
||||
sync_btn = QPushButton("同步主窗口工作目录")
|
||||
sync_btn.clicked.connect(self.sync_work_dir_from_main)
|
||||
wd_row.addWidget(self.work_dir_edit, 1)
|
||||
wd_row.addWidget(wd_browse)
|
||||
wd_row.addWidget(sync_btn)
|
||||
path_form.addRow("工作目录:", wd_row)
|
||||
|
||||
out_row = QHBoxLayout()
|
||||
self.output_dir_edit = QLineEdit()
|
||||
self.output_dir_edit.setPlaceholderText("留空则保存到 工作目录/14_visualization")
|
||||
out_browse = QPushButton("浏览…")
|
||||
out_browse.clicked.connect(self.browse_output_dir)
|
||||
out_row.addWidget(self.output_dir_edit, 1)
|
||||
out_row.addWidget(out_browse)
|
||||
path_form.addRow("报告输出目录:", out_row)
|
||||
|
||||
self.report_title_edit = QLineEdit()
|
||||
self.report_title_edit.setText("水质参数反演分析报告")
|
||||
path_form.addRow("报告标题:", self.report_title_edit)
|
||||
|
||||
path_group.setLayout(path_form)
|
||||
layout.addWidget(path_group)
|
||||
|
||||
ai_group = QGroupBox("AI 分析(Ollama)")
|
||||
ai_form = QFormLayout()
|
||||
|
||||
self.enable_ai_cb = QCheckBox("启用 AI 图表解读与综合总结")
|
||||
self.enable_ai_cb.setChecked(
|
||||
os.environ.get("ENABLE_AI_ANALYSIS", "1") not in {"0", "false", "False"}
|
||||
)
|
||||
ai_form.addRow(self.enable_ai_cb)
|
||||
|
||||
self.ollama_url_edit = QLineEdit()
|
||||
self.ollama_url_edit.setText(
|
||||
os.environ.get("OLLAMA_URL", "http://localhost:11434").rstrip("/")
|
||||
)
|
||||
ai_form.addRow("服务 URL:", self.ollama_url_edit)
|
||||
|
||||
self.vision_model_edit = QLineEdit()
|
||||
self.vision_model_edit.setText(
|
||||
os.environ.get("OLLAMA_VISION_MODEL", "qwen3-vl:8b")
|
||||
)
|
||||
ai_form.addRow("视觉模型:", self.vision_model_edit)
|
||||
|
||||
self.same_text_model_cb = QCheckBox("文本总结与视觉使用同一模型")
|
||||
self.same_text_model_cb.setChecked(True)
|
||||
ai_form.addRow(self.same_text_model_cb)
|
||||
|
||||
self.text_model_edit = QLineEdit()
|
||||
self.text_model_edit.setText(
|
||||
os.environ.get(
|
||||
"OLLAMA_TEXT_MODEL",
|
||||
self.vision_model_edit.text() or "qwen3-vl:8b"
|
||||
)
|
||||
)
|
||||
self.text_model_edit.setEnabled(False)
|
||||
self.same_text_model_cb.toggled.connect(self._on_same_text_toggled)
|
||||
self.vision_model_edit.textChanged.connect(self._sync_text_model_if_linked)
|
||||
ai_form.addRow("文本模型:", self.text_model_edit)
|
||||
|
||||
self.timeout_spin = QSpinBox()
|
||||
self.timeout_spin.setRange(30, 3600)
|
||||
self.timeout_spin.setSingleStep(30)
|
||||
self.timeout_spin.setValue(int(os.environ.get("OLLAMA_TIMEOUT_S", "120")))
|
||||
ai_form.addRow("请求超时(秒):", self.timeout_spin)
|
||||
|
||||
ai_group.setLayout(ai_form)
|
||||
layout.addWidget(ai_group)
|
||||
|
||||
btn_row = QHBoxLayout()
|
||||
self.generate_btn = QPushButton("生成 Word 报告")
|
||||
self.generate_btn.setStyleSheet(
|
||||
ModernStylesheet.get_button_stylesheet("success")
|
||||
)
|
||||
self.generate_btn.clicked.connect(self.on_generate_clicked)
|
||||
btn_row.addWidget(self.generate_btn)
|
||||
btn_row.addStretch()
|
||||
layout.addLayout(btn_row)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_same_text_toggled(self, checked: bool):
|
||||
self.text_model_edit.setEnabled(not checked)
|
||||
if checked:
|
||||
self.text_model_edit.setText(self.vision_model_edit.text())
|
||||
|
||||
def _sync_text_model_if_linked(self, _t=None):
|
||||
if self.same_text_model_cb.isChecked():
|
||||
self.text_model_edit.blockSignals(True)
|
||||
self.text_model_edit.setText(self.vision_model_edit.text())
|
||||
self.text_model_edit.blockSignals(False)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用主窗口缓存的 work_dir"""
|
||||
if self.main_window and hasattr(self.main_window, 'work_dir') and self.main_window.work_dir:
|
||||
return str(self.main_window.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_work_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
d = QFileDialog.getExistingDirectory(self, "选择工作目录", default)
|
||||
if d:
|
||||
self.work_dir_edit.setText(d)
|
||||
|
||||
def browse_output_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择报告输出目录", default)
|
||||
if d:
|
||||
self.output_dir_edit.setText(d)
|
||||
|
||||
def sync_work_dir_from_main(self):
|
||||
mw = self.main_window
|
||||
if mw is not None and getattr(mw, "work_dir", None):
|
||||
self.work_dir_edit.setText(str(mw.work_dir))
|
||||
else:
|
||||
QMessageBox.information(self, "提示", "主窗口尚未设置工作目录。")
|
||||
|
||||
def set_work_dir(self, work_dir):
|
||||
if work_dir:
|
||||
self.work_dir_edit.setText(str(work_dir))
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"work_dir": self.work_dir_edit.text().strip() or None,
|
||||
"output_dir": self.output_dir_edit.text().strip() or None,
|
||||
"report_title": self.report_title_edit.text().strip() or "水质参数反演分析报告",
|
||||
"ollama_url": self.ollama_url_edit.text().strip(),
|
||||
"ollama_vision_model": self.vision_model_edit.text().strip(),
|
||||
"ollama_text_model": self.text_model_edit.text().strip(),
|
||||
"text_same_as_vision": self.same_text_model_cb.isChecked(),
|
||||
"ollama_timeout_s": self.timeout_spin.value(),
|
||||
"enable_ai_analysis": self.enable_ai_cb.isChecked(),
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if not config:
|
||||
return
|
||||
if config.get("work_dir"):
|
||||
self.work_dir_edit.setText(str(config["work_dir"]))
|
||||
if "output_dir" in config:
|
||||
self.output_dir_edit.setText(str(config["output_dir"] or ""))
|
||||
if config.get("report_title"):
|
||||
self.report_title_edit.setText(str(config["report_title"]))
|
||||
if config.get("ollama_url"):
|
||||
self.ollama_url_edit.setText(str(config["ollama_url"]))
|
||||
if config.get("ollama_vision_model"):
|
||||
self.vision_model_edit.setText(str(config["ollama_vision_model"]))
|
||||
if "text_same_as_vision" in config:
|
||||
self.same_text_model_cb.setChecked(bool(config["text_same_as_vision"]))
|
||||
if config.get("ollama_text_model"):
|
||||
self.text_model_edit.setText(str(config["ollama_text_model"]))
|
||||
if config.get("ollama_timeout_s") is not None:
|
||||
self.timeout_spin.setValue(int(config["ollama_timeout_s"]))
|
||||
if "enable_ai_analysis" in config:
|
||||
self.enable_ai_cb.setChecked(bool(config["enable_ai_analysis"]))
|
||||
|
||||
def on_generate_clicked(self):
|
||||
wd = self.work_dir_edit.text().strip()
|
||||
if not wd or not os.path.isdir(wd):
|
||||
QMessageBox.warning(self, "提示", "请选择有效的工作目录。")
|
||||
return
|
||||
viz = Path(wd) / "14_visualization"
|
||||
if not viz.is_dir():
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"提示",
|
||||
f"未找到可视化目录:\n{viz}\n请先完成流程或生成可视化。",
|
||||
)
|
||||
return
|
||||
if self._report_thread and self._report_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "报告正在生成中,请稍候。")
|
||||
return
|
||||
|
||||
out = self.output_dir_edit.text().strip() or None
|
||||
title = self.report_title_edit.text().strip() or "水质参数反演分析报告"
|
||||
opts = {
|
||||
"ollama_url": self.ollama_url_edit.text().strip(),
|
||||
"ollama_vision_model": self.vision_model_edit.text().strip(),
|
||||
"ollama_text_model": self.text_model_edit.text().strip(),
|
||||
"text_same_as_vision": self.same_text_model_cb.isChecked(),
|
||||
"ollama_timeout_s": self.timeout_spin.value(),
|
||||
"enable_ai_analysis": self.enable_ai_cb.isChecked(),
|
||||
}
|
||||
self.generate_btn.setEnabled(False)
|
||||
self._report_thread = ReportGenerateThread(wd, out, title, opts)
|
||||
self._report_thread.log_message.connect(self._forward_log, Qt.QueuedConnection)
|
||||
self._report_thread.finished_ok.connect(self._on_report_ok, Qt.QueuedConnection)
|
||||
self._report_thread.failed.connect(self._on_report_fail, Qt.QueuedConnection)
|
||||
self._report_thread.finished.connect(
|
||||
lambda: self.generate_btn.setEnabled(True), Qt.QueuedConnection
|
||||
)
|
||||
self._report_thread.start()
|
||||
self._forward_log("已开始生成 Word 报告…", "info")
|
||||
|
||||
def _forward_log(self, msg: str, level: str):
|
||||
mw = self.main_window
|
||||
if mw is not None and hasattr(mw, "log_message"):
|
||||
mw.log_message(msg, level)
|
||||
else:
|
||||
print(f"[{level}] {msg}")
|
||||
|
||||
def _on_report_ok(self, path: str):
|
||||
QMessageBox.information(self, "完成", f"报告已生成:\n{path}")
|
||||
self._forward_log(f"Word 报告已保存: {path}", "info")
|
||||
|
||||
def _on_report_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"报告生成失败:\n{err[:800]}")
|
||||
self._forward_log(err, "error")
|
||||
@ -1,282 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step1 面板 - 水域掩膜生成
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QLabel,
|
||||
QDoubleSpinBox, QCheckBox, QPushButton, QFormLayout, QRadioButton,
|
||||
QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step1Panel(QWidget):
|
||||
"""1. 水域掩膜生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 掩膜生成方式选择
|
||||
method_group = QGroupBox("掩膜生成方式")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
# 使用现有掩膜文件
|
||||
self.use_existing_radio = QRadioButton("使用现有掩膜文件")
|
||||
self.use_existing_radio.setChecked(True)
|
||||
method_layout.addWidget(self.use_existing_radio)
|
||||
|
||||
# 使用NDWI自动生成
|
||||
self.use_ndwi_radio = QRadioButton("使用NDWI自动生成")
|
||||
method_layout.addWidget(self.use_ndwi_radio)
|
||||
|
||||
# 应用QRadioButton样式(实心选中点)
|
||||
radio_style = """
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 8px;
|
||||
border: 2px solid #999;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:unchecked {
|
||||
background-color: white;
|
||||
border: 2px solid #999;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
"""
|
||||
self.use_existing_radio.setStyleSheet(radio_style)
|
||||
self.use_ndwi_radio.setStyleSheet(radio_style)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 掩膜文件选择
|
||||
self.mask_file = FileSelectWidget(
|
||||
"掩膜文件:",
|
||||
"Shapefiles (*.shp);;Raster Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.mask_file)
|
||||
|
||||
# 影像文件选择(用于shp栅格化或NDWI生成)
|
||||
self.img_file = FileSelectWidget(
|
||||
"参考影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# NDWI参数设置
|
||||
self.ndwi_group = QGroupBox("NDWI参数设置")
|
||||
ndwi_layout = QVBoxLayout()
|
||||
|
||||
# NDWI阈值
|
||||
threshold_layout = QHBoxLayout()
|
||||
threshold_layout.addWidget(QLabel("NDWI阈值:"))
|
||||
self.ndwi_threshold = QDoubleSpinBox()
|
||||
self.ndwi_threshold.setRange(0.0, 1.0)
|
||||
self.ndwi_threshold.setSingleStep(0.05)
|
||||
self.ndwi_threshold.setValue(0.4)
|
||||
self.ndwi_threshold.setDecimals(2)
|
||||
threshold_layout.addWidget(self.ndwi_threshold)
|
||||
threshold_layout.addStretch()
|
||||
ndwi_layout.addLayout(threshold_layout)
|
||||
|
||||
self.ndwi_group.setLayout(ndwi_layout)
|
||||
layout.addWidget(self.ndwi_group)
|
||||
|
||||
# 输出文件路径(使用save模式)
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("water_mask.dat")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 提示信息 - 专业的 Info Alert 样式
|
||||
hint = QLabel("💡 提示: 如果掩膜文件是Shapefile(.shp),需要提供参考影像用于栅格化;如果使用NDWI自动生成,只需要提供参考影像")
|
||||
hint.setWordWrap(True) # 允许自动换行
|
||||
hint.setStyleSheet("""
|
||||
QLabel {
|
||||
color: #0055D4;
|
||||
font-size: 13px;
|
||||
font-weight: bold;
|
||||
background-color: #E8F4FF;
|
||||
border: 2px solid #0055D4;
|
||||
border-radius: 8px;
|
||||
padding: 12px 16px;
|
||||
margin: 8px 0px;
|
||||
}
|
||||
""")
|
||||
layout.addWidget(hint)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
# 连接信号
|
||||
self.use_existing_radio.toggled.connect(self.update_ui_state)
|
||||
self.use_ndwi_radio.toggled.connect(self.update_ui_state)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 初始UI状态
|
||||
self.update_ui_state()
|
||||
|
||||
def update_ui_state(self):
|
||||
"""根据选择的掩膜生成方式更新UI状态(使用显示/隐藏控制)"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
# 动态显示/隐藏组件
|
||||
if use_ndwi:
|
||||
# 使用NDWI模式:隐藏掩膜文件,显示NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(False)
|
||||
self.ndwi_group.setVisible(True)
|
||||
self.output_file.setVisible(True) # 显示输出掩膜路径
|
||||
|
||||
# 当切换到NDWI模式时,如果工作目录已设置,自动填充输出路径
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
self._auto_fill_output_path()
|
||||
else:
|
||||
# 使用现有掩膜模式:显示掩膜文件,隐藏NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(True)
|
||||
self.ndwi_group.setVisible(False)
|
||||
self.output_file.setVisible(False) # 隐藏输出掩膜路径
|
||||
|
||||
# 参考影像在两种模式下都显示
|
||||
self.img_file.setVisible(True)
|
||||
|
||||
def update_work_directory(self, work_dir):
|
||||
"""
|
||||
保存工作目录引用,用于后续自动填充路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
"""
|
||||
if not work_dir:
|
||||
return
|
||||
|
||||
# 保存工作目录引用
|
||||
self.work_dir = work_dir
|
||||
|
||||
# 如果当前选中的是NDWI模式,立即填充输出路径
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
self._auto_fill_output_path()
|
||||
|
||||
def _auto_fill_output_path(self):
|
||||
"""
|
||||
自动填充输出掩膜路径(仅在NDWI模式下)
|
||||
确保路径使用正斜杠,避免斜杠混用
|
||||
"""
|
||||
if not hasattr(self, 'work_dir') or not self.work_dir:
|
||||
return
|
||||
|
||||
# 生成输出掩膜的完整路径
|
||||
output_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
os.makedirs(output_dir, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 统一使用正斜杠,避免 \ 和 / 混用
|
||||
default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
config = {
|
||||
'mask_path': None if use_ndwi else self.mask_file.get_path(),
|
||||
'use_ndwi': use_ndwi,
|
||||
'ndwi_threshold': self.ndwi_threshold.value()
|
||||
}
|
||||
|
||||
# 参考影像路径(两种模式都可能需要)
|
||||
img_path = self.img_file.get_path()
|
||||
if img_path:
|
||||
config['img_path'] = img_path
|
||||
|
||||
# 输出路径:仅在NDWI模式下有效
|
||||
if use_ndwi:
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
else:
|
||||
# 使用现有掩膜时,不传递output_path,避免底层错误尝试保存文件
|
||||
config['output_path'] = None
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'mask_path' in config:
|
||||
self.mask_file.set_path(config['mask_path'])
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'use_ndwi' in config:
|
||||
if config['use_ndwi']:
|
||||
self.use_ndwi_radio.setChecked(True)
|
||||
else:
|
||||
self.use_existing_radio.setChecked(True)
|
||||
if 'ndwi_threshold' in config:
|
||||
self.ndwi_threshold.setValue(config['ndwi_threshold'])
|
||||
|
||||
self.update_ui_state()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤1"""
|
||||
# 验证输入
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
# NDWI模式:需要影像文件
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择参考影像文件!")
|
||||
return
|
||||
else:
|
||||
# 现有掩膜模式:需要掩膜文件
|
||||
mask_path = self.mask_file.get_path()
|
||||
if not mask_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择掩膜文件!")
|
||||
return
|
||||
|
||||
# 如果是shp文件,还需要影像文件
|
||||
if mask_path.lower().endswith('.shp'):
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "当使用shp文件时,需要提供参考影像用于栅格化!")
|
||||
return
|
||||
|
||||
# 获取父窗口并运行步骤
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
config = {'step1': self.get_config()}
|
||||
parent.run_single_step("step1", config)
|
||||
@ -1,210 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step2 面板 - 耀斑区域识别
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
|
||||
QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step2Panel(QWidget):
|
||||
"""2. 耀斑区域识别"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.work_dir = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 影像文件
|
||||
self.img_file = FileSelectWidget(
|
||||
"影像文件:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("检测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 耀斑波长
|
||||
self.glint_wave = QDoubleSpinBox()
|
||||
self.glint_wave.setRange(300, 1000)
|
||||
self.glint_wave.setValue(750.0)
|
||||
self.glint_wave.setSuffix(" nm")
|
||||
params_layout.addRow("耀斑检测波长:", self.glint_wave)
|
||||
|
||||
# 检测方法
|
||||
self.method = QComboBox()
|
||||
self.method.addItem("Otsu 阈值法", "otsu")
|
||||
self.method.addItem("Z-Score 方法", "zscore")
|
||||
self.method.addItem("百分位数法", "percentile")
|
||||
self.method.addItem("IQR 四分位距法", "iqr")
|
||||
self.method.addItem("自适应阈值法", "adaptive")
|
||||
self.method.addItem("多波段综合法", "multi_band")
|
||||
params_layout.addRow("检测方法:", self.method)
|
||||
|
||||
# 最大连通域面积
|
||||
self.max_area = QSpinBox()
|
||||
self.max_area.setRange(0, 100000)
|
||||
self.max_area.setValue(50)
|
||||
self.max_area.setSpecialValueText("不过滤")
|
||||
params_layout.addRow("最大连通域面积:", self.max_area)
|
||||
|
||||
# 岸边缓冲区
|
||||
self.buffer_size = QSpinBox()
|
||||
self.buffer_size.setRange(0, 200)
|
||||
self.buffer_size.setValue(10)
|
||||
self.buffer_size.setSpecialValueText("不设置")
|
||||
params_layout.addRow("岸边缓冲区大小:", self.buffer_size)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出耀斑掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'img_path': self.img_file.get_path(),
|
||||
'glint_wave': self.glint_wave.value(),
|
||||
'method': self.method.currentData(), # 使用 currentData() 获取英文ID
|
||||
}
|
||||
if self.max_area.value() > 0:
|
||||
config['max_area'] = self.max_area.value()
|
||||
if self.buffer_size.value() > 0:
|
||||
config['buffer_size'] = self.buffer_size.value()
|
||||
# 添加水域掩膜路径(用于独立运行)
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
# 添加输出路径
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'glint_wave' in config:
|
||||
self.glint_wave.setValue(config['glint_wave'])
|
||||
if 'method' in config:
|
||||
idx = self.method.findData(config['method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.method.setCurrentIndex(idx)
|
||||
if 'max_area' in config:
|
||||
self.max_area.setValue(config['max_area'])
|
||||
if 'buffer_size' in config:
|
||||
self.buffer_size.setValue(config['buffer_size'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""
|
||||
从全局配置/Pipeline 或 Step1Panel 自动填充路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例,用于获取步骤1生成的水域掩膜路径
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass # 保持现有工作目录
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Pipeline 获取
|
||||
mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
|
||||
mask_path = pipeline.water_mask_path
|
||||
|
||||
# 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取(关键修复)
|
||||
main_window = self.window()
|
||||
if not mask_path and hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
# NDWI模式,读取输出框的路径
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
# 导入现有模式,读取输入框的路径
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
|
||||
# 填充获取到的路径
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
# 生成输出耀斑掩膜的标准路径:workspace/2_glint_mask/glint_mask_out.dat
|
||||
output_dir = os.path.join(self.work_dir, "2_glint_mask")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "glint_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
# 没有工作目录时,清空输出路径
|
||||
self.output_file.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤2"""
|
||||
# 验证输入
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step2': self.get_config()}
|
||||
main_window.run_single_step('step2', config)
|
||||
@ -1,451 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step3 面板 - 耀斑去除
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
|
||||
QLabel, QLineEdit, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step3Panel(QWidget):
|
||||
"""步骤3:耀斑去除"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 影像文件
|
||||
self.img_file = FileSelectWidget(
|
||||
"影像文件:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# 水域掩膜/边界:完整流程可由步骤1自动生成;独立单步运行时须手动指定
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜/边界:",
|
||||
"Mask/Boundary (*.dat *.tif *.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.water_mask_file)
|
||||
step3_mask_hint = QLabel(
|
||||
"提示:独立运行本步骤时必须选择水域掩膜或边界(与影像同区域的 .dat/.tif 掩膜,或 .shp 矢量)。"
|
||||
)
|
||||
step3_mask_hint.setWordWrap(True)
|
||||
step3_mask_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step3_mask_hint)
|
||||
|
||||
# 方法选择
|
||||
method_group = QGroupBox("去耀斑方法")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
self.method = QComboBox()
|
||||
for text, data in [('Goodman方法', 'goodman'), ('Kutser方法', 'kutser'),
|
||||
('Hedley方法', 'hedley'), ('SUGAR算法', 'sugar')]:
|
||||
self.method.addItem(text, data)
|
||||
self.method.currentIndexChanged.connect(self._on_method_changed)
|
||||
method_layout.addWidget(self.method)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# Goodman参数组
|
||||
self.goodman_group = QGroupBox("Goodman方法参数")
|
||||
goodman_layout = QFormLayout()
|
||||
|
||||
self.nir_lower = QSpinBox()
|
||||
self.nir_lower.setRange(0, 200)
|
||||
self.nir_lower.setValue(65)
|
||||
goodman_layout.addRow("NIR下波段索引:", self.nir_lower)
|
||||
|
||||
self.nir_upper = QSpinBox()
|
||||
self.nir_upper.setRange(0, 200)
|
||||
self.nir_upper.setValue(91)
|
||||
goodman_layout.addRow("NIR上波段索引:", self.nir_upper)
|
||||
|
||||
self.goodman_a = QDoubleSpinBox()
|
||||
self.goodman_a.setDecimals(6)
|
||||
self.goodman_a.setRange(0, 1)
|
||||
self.goodman_a.setValue(0.000019)
|
||||
goodman_layout.addRow("参数A:", self.goodman_a)
|
||||
|
||||
self.goodman_b = QDoubleSpinBox()
|
||||
self.goodman_b.setDecimals(2)
|
||||
self.goodman_b.setRange(0, 1)
|
||||
self.goodman_b.setValue(0.1)
|
||||
goodman_layout.addRow("参数B:", self.goodman_b)
|
||||
|
||||
self.goodman_group.setLayout(goodman_layout)
|
||||
layout.addWidget(self.goodman_group)
|
||||
|
||||
# Kutser参数组
|
||||
self.kutser_group = QGroupBox("Kutser方法参数")
|
||||
kutser_layout = QFormLayout()
|
||||
|
||||
self.oxy_band = QSpinBox()
|
||||
self.oxy_band.setRange(0, 200)
|
||||
self.oxy_band.setValue(38)
|
||||
kutser_layout.addRow("氧吸收波段索引:", self.oxy_band)
|
||||
|
||||
self.lower_oxy = QSpinBox()
|
||||
self.lower_oxy.setRange(0, 200)
|
||||
self.lower_oxy.setValue(36)
|
||||
kutser_layout.addRow("下氧吸收波段索引:", self.lower_oxy)
|
||||
|
||||
self.upper_oxy = QSpinBox()
|
||||
self.upper_oxy.setRange(0, 200)
|
||||
self.upper_oxy.setValue(49)
|
||||
kutser_layout.addRow("上氧吸收波段索引:", self.upper_oxy)
|
||||
|
||||
self.nir_band = QSpinBox()
|
||||
self.nir_band.setRange(0, 200)
|
||||
self.nir_band.setValue(47)
|
||||
kutser_layout.addRow("NIR波段索引:", self.nir_band)
|
||||
|
||||
self.kutser_group.setLayout(kutser_layout)
|
||||
self.kutser_group.setVisible(False)
|
||||
layout.addWidget(self.kutser_group)
|
||||
|
||||
# Hedley参数组
|
||||
self.hedley_group = QGroupBox("Hedley方法参数")
|
||||
hedley_layout = QFormLayout()
|
||||
|
||||
self.hedley_nir_band = QSpinBox()
|
||||
self.hedley_nir_band.setRange(0, 200)
|
||||
self.hedley_nir_band.setValue(47)
|
||||
hedley_layout.addRow("NIR波段索引:", self.hedley_nir_band)
|
||||
|
||||
self.hedley_group.setLayout(hedley_layout)
|
||||
self.hedley_group.setVisible(False)
|
||||
layout.addWidget(self.hedley_group)
|
||||
|
||||
# SUGAR参数组
|
||||
self.sugar_group = QGroupBox("SUGAR方法参数")
|
||||
sugar_layout = QFormLayout()
|
||||
|
||||
self.sugar_iter = QSpinBox()
|
||||
self.sugar_iter.setRange(1, 20)
|
||||
self.sugar_iter.setValue(3)
|
||||
self.sugar_iter.setSpecialValueText("自动")
|
||||
sugar_layout.addRow("迭代次数:", self.sugar_iter)
|
||||
|
||||
self.sugar_sigma = QDoubleSpinBox()
|
||||
self.sugar_sigma.setDecimals(2)
|
||||
self.sugar_sigma.setRange(0.1, 10)
|
||||
self.sugar_sigma.setValue(1.0)
|
||||
sugar_layout.addRow("LoG平滑σ:", self.sugar_sigma)
|
||||
|
||||
self.sugar_estimate_background = QCheckBox()
|
||||
self.sugar_estimate_background.setChecked(True)
|
||||
sugar_layout.addRow("估计背景光谱:", self.sugar_estimate_background)
|
||||
|
||||
self.sugar_glint_mask_method = QComboBox()
|
||||
self.sugar_glint_mask_method.addItems(['cdf', 'otsu'])
|
||||
self.sugar_glint_mask_method.setCurrentText('cdf')
|
||||
sugar_layout.addRow("耀斑掩膜方法:", self.sugar_glint_mask_method)
|
||||
|
||||
self.sugar_termination_thresh = QDoubleSpinBox()
|
||||
self.sugar_termination_thresh.setDecimals(2)
|
||||
self.sugar_termination_thresh.setRange(1, 100)
|
||||
self.sugar_termination_thresh.setValue(20.0)
|
||||
sugar_layout.addRow("终止阈值:", self.sugar_termination_thresh)
|
||||
|
||||
self.sugar_bounds = QLineEdit()
|
||||
self.sugar_bounds.setText("[(1, 2)]")
|
||||
sugar_layout.addRow("优化边界:", self.sugar_bounds)
|
||||
|
||||
self.sugar_group.setLayout(sugar_layout)
|
||||
self.sugar_group.setVisible(False)
|
||||
layout.addWidget(self.sugar_group)
|
||||
|
||||
# 插值选项
|
||||
interp_group = QGroupBox("0值像素插值")
|
||||
interp_layout = QFormLayout()
|
||||
|
||||
self.interpolate_zeros = QCheckBox("启用插值")
|
||||
interp_layout.addRow("", self.interpolate_zeros)
|
||||
|
||||
self.interp_method = QComboBox()
|
||||
for text, data in [('最近邻插值', 'nearest'), ('双线性插值', 'bilinear'),
|
||||
('样条插值', 'spline'), ('克里金插值', 'kriging')]:
|
||||
self.interp_method.addItem(text, data)
|
||||
self.interp_method.setCurrentIndex(1) # 默认双线性插值
|
||||
interp_layout.addRow("插值方法:", self.interp_method)
|
||||
|
||||
interp_group.setLayout(interp_layout)
|
||||
layout.addWidget(interp_group)
|
||||
|
||||
# # 实测经纬度参考点
|
||||
# self.ref_csv_file = FileSelectWidget(
|
||||
# "实测经纬度CSV:",
|
||||
# "CSV Files (*.csv);;All Files (*.*)"
|
||||
# )
|
||||
# self.ref_csv_file.line_edit.setPlaceholderText("可选:包含 Lon/Lat 列的 CSV 文件")
|
||||
# layout.addWidget(self.ref_csv_file)
|
||||
|
||||
# 交互式预览按钮
|
||||
# self.preview_btn = QPushButton("👁️ 打开交互式影像预览")
|
||||
# self.preview_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('info'))
|
||||
# self.preview_btn.clicked.connect(self.open_interactive_viewer)
|
||||
# layout.addWidget(self.preview_btn)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("deglint_image.dat")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
self.img_file.line_edit.textChanged.connect(self._update_band_ranges)
|
||||
|
||||
def open_interactive_viewer(self):
|
||||
"""打开交互式影像预览"""
|
||||
from src.gui.water_quality_gui import InteractiveViewerDialog
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path or not os.path.isfile(img_path):
|
||||
QMessageBox.warning(self, "警告", "请先选择影像文件!")
|
||||
return
|
||||
|
||||
water_mask = self.water_mask_file.get_path()
|
||||
|
||||
dialog = InteractiveViewerDialog(img_path, self)
|
||||
if water_mask and os.path.isfile(water_mask):
|
||||
dialog.load_water_mask(water_mask)
|
||||
dialog.exec_()
|
||||
|
||||
def _update_band_ranges(self, file_path):
|
||||
"""根据选择的影像动态限制波段索引的输入范围"""
|
||||
from osgeo import gdal
|
||||
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
return
|
||||
|
||||
try:
|
||||
dataset = gdal.Open(file_path)
|
||||
if dataset is None:
|
||||
return
|
||||
raster_count = dataset.RasterCount
|
||||
max_band = max(0, raster_count - 1)
|
||||
self.nir_lower.setMaximum(max_band)
|
||||
self.nir_upper.setMaximum(max_band)
|
||||
self.oxy_band.setMaximum(max_band)
|
||||
self.nir_band.setMaximum(max_band)
|
||||
self.hedley_nir_band.setMaximum(max_band)
|
||||
dataset = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""
|
||||
从 Step1Panel 自动填充水域掩膜路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass # 保持现有工作目录
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 从 Step1 界面读取水域掩膜路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
# NDWI模式,读取输出框的路径
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
# 导入现有模式,读取输入框的路径
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "3_deglint")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
def _on_method_changed(self, index):
|
||||
"""方法改变时更新参数显示"""
|
||||
method_id = self.method.currentData()
|
||||
self.goodman_group.setVisible(method_id == 'goodman')
|
||||
self.kutser_group.setVisible(method_id == 'kutser')
|
||||
self.hedley_group.setVisible(method_id == 'hedley')
|
||||
self.sugar_group.setVisible(method_id == 'sugar')
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'img_path': self.img_file.get_path(),
|
||||
'method': self.method.currentData(), # 使用 currentData() 获取英文ID
|
||||
'enabled': self.enable_checkbox.isChecked(),
|
||||
'interpolate_zeros': self.interpolate_zeros.isChecked(),
|
||||
'interpolation_method': self.interp_method.currentData(), # 使用 currentData()
|
||||
}
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask'] = water_mask_path
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
|
||||
method = self.method.currentData() # 使用 currentData()
|
||||
|
||||
if method == 'goodman':
|
||||
config['nir_lower'] = self.nir_lower.value()
|
||||
config['nir_upper'] = self.nir_upper.value()
|
||||
config['goodman_A'] = self.goodman_a.value()
|
||||
config['goodman_B'] = self.goodman_b.value()
|
||||
|
||||
elif method == 'kutser':
|
||||
config['oxy_band'] = self.oxy_band.value()
|
||||
config['lower_oxy'] = self.lower_oxy.value()
|
||||
config['upper_oxy'] = self.upper_oxy.value()
|
||||
config['nir_band'] = self.nir_band.value()
|
||||
|
||||
elif method == 'hedley':
|
||||
config['hedley_nir_band'] = self.hedley_nir_band.value()
|
||||
|
||||
elif method == 'sugar':
|
||||
config['sugar_iter'] = self.sugar_iter.value() if self.sugar_iter.value() > 0 else None
|
||||
config['sugar_sigma'] = self.sugar_sigma.value()
|
||||
config['sugar_estimate_background'] = self.sugar_estimate_background.isChecked()
|
||||
config['sugar_glint_mask_method'] = self.sugar_glint_mask_method.currentData()
|
||||
config['sugar_termination_thresh'] = self.sugar_termination_thresh.value()
|
||||
# 解析bounds字符串
|
||||
try:
|
||||
import ast
|
||||
config['sugar_bounds'] = ast.literal_eval(self.sugar_bounds.text())
|
||||
except:
|
||||
config['sugar_bounds'] = [(1, 2)] # 默认值
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'water_mask' in config:
|
||||
self.water_mask_file.set_path(config['water_mask'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'reference_csv' in config:
|
||||
self.ref_csv_file.set_path(config['reference_csv'])
|
||||
if 'method' in config:
|
||||
idx = self.method.findData(config['method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.method.setCurrentIndex(idx)
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
if 'interpolate_zeros' in config:
|
||||
self.interpolate_zeros.setChecked(config['interpolate_zeros'])
|
||||
if 'interpolation_method' in config:
|
||||
idx = self.interp_method.findData(config['interpolation_method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.interp_method.setCurrentIndex(idx)
|
||||
|
||||
# Goodman参数
|
||||
if 'nir_lower' in config:
|
||||
self.nir_lower.setValue(config['nir_lower'])
|
||||
if 'nir_upper' in config:
|
||||
self.nir_upper.setValue(config['nir_upper'])
|
||||
if 'goodman_A' in config:
|
||||
self.goodman_a.setValue(config['goodman_A'])
|
||||
if 'goodman_B' in config:
|
||||
self.goodman_b.setValue(config['goodman_B'])
|
||||
|
||||
# Kutser参数
|
||||
if 'oxy_band' in config:
|
||||
self.oxy_band.setValue(config['oxy_band'])
|
||||
if 'lower_oxy' in config:
|
||||
self.lower_oxy.setValue(config['lower_oxy'])
|
||||
if 'upper_oxy' in config:
|
||||
self.upper_oxy.setValue(config['upper_oxy'])
|
||||
if 'nir_band' in config:
|
||||
self.nir_band.setValue(config['nir_band'])
|
||||
|
||||
# Hedley参数
|
||||
if 'hedley_nir_band' in config:
|
||||
self.hedley_nir_band.setValue(config['hedley_nir_band'])
|
||||
|
||||
# SUGAR参数
|
||||
if 'sugar_iter' in config:
|
||||
self.sugar_iter.setValue(config['sugar_iter'] if config['sugar_iter'] is not None else 0)
|
||||
if 'sugar_sigma' in config:
|
||||
self.sugar_sigma.setValue(config['sugar_sigma'])
|
||||
if 'sugar_estimate_background' in config:
|
||||
self.sugar_estimate_background.setChecked(config['sugar_estimate_background'])
|
||||
if 'sugar_glint_mask_method' in config:
|
||||
idx = self.sugar_glint_mask_method.findData(config['sugar_glint_mask_method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.sugar_glint_mask_method.setCurrentIndex(idx)
|
||||
if 'sugar_termination_thresh' in config:
|
||||
self.sugar_termination_thresh.setValue(config['sugar_termination_thresh'])
|
||||
if 'sugar_bounds' in config:
|
||||
self.sugar_bounds.setText(str(config['sugar_bounds']))
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤3"""
|
||||
# 验证输入
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
|
||||
return
|
||||
if self.enable_checkbox.isChecked():
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if not water_mask_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入错误",
|
||||
"独立运行耀斑去除时,必须选择水域掩膜或边界文件。\n\n"
|
||||
"请提供与当前影像空间一致的水域栅格掩膜(.dat/.tif),或水域矢量边界(.shp)。\n"
|
||||
"若刚跑过完整流程,可使用步骤1生成的水域掩膜文件。",
|
||||
)
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step3': self.get_config()}
|
||||
main_window.run_single_step('step3', config)
|
||||
@ -1,185 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step4 面板 - 数据预处理
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QHBoxLayout, QLabel,
|
||||
QSpinBox, QPushButton, QCheckBox, QTableView,
|
||||
QAbstractItemView, QHeaderView, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step4Panel(QWidget):
|
||||
"""步骤4:数据预处理"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
# CSV文件
|
||||
self.csv_file = FileSelectWidget(
|
||||
"水质参数文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.csv_file)
|
||||
|
||||
hint = QLabel("提示: 处理CSV文件,筛选剔除异常值")
|
||||
hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
preview_group = QGroupBox("CSV数据预览")
|
||||
preview_layout = QVBoxLayout()
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
controls_layout.addWidget(QLabel("预览行数:"))
|
||||
self.preview_rows_spin = QSpinBox()
|
||||
self.preview_rows_spin.setRange(1, 200)
|
||||
self.preview_rows_spin.setValue(10)
|
||||
controls_layout.addWidget(self.preview_rows_spin)
|
||||
self.preview_btn = QPushButton("刷新预览")
|
||||
self.preview_btn.clicked.connect(self.load_csv_preview)
|
||||
controls_layout.addWidget(self.preview_btn)
|
||||
controls_layout.addStretch()
|
||||
|
||||
self.preview_table = QTableView()
|
||||
self.preview_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||
self.preview_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||
self.preview_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||
self.preview_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
|
||||
self.preview_table.verticalHeader().setVisible(False)
|
||||
self.preview_table.setMinimumHeight(200)
|
||||
|
||||
self.preview_status_label = QLabel("请选择CSV文件并点击刷新预览")
|
||||
self.preview_status_label.setStyleSheet("color: #666; font-size: 11px;")
|
||||
|
||||
preview_layout.addLayout(controls_layout)
|
||||
preview_layout.addWidget(self.preview_table)
|
||||
preview_layout.addWidget(self.preview_status_label)
|
||||
preview_group.setLayout(preview_layout)
|
||||
layout.addWidget(preview_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出处理后CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("processed_data.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
self.reset_preview()
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'csv_path': self.csv_file.get_path(),
|
||||
}
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.load_csv_preview()
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "4_processed_data")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "processed_data.csv").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤4"""
|
||||
# 验证输入
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择水质参数文件!")
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step4': self.get_config()}
|
||||
main_window.run_single_step('step4', config)
|
||||
|
||||
def reset_preview(self, message="请选择CSV文件并点击刷新预览"):
|
||||
"""重置预览表格"""
|
||||
from src.gui.water_quality_gui import PandasTableModel
|
||||
empty_model = PandasTableModel(pd.DataFrame())
|
||||
self.preview_table.setModel(empty_model)
|
||||
self.preview_status_label.setText(message)
|
||||
|
||||
def load_csv_preview(self):
|
||||
"""加载CSV预览数据"""
|
||||
from src.gui.water_quality_gui import PandasTableModel
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path:
|
||||
self.reset_preview("请先选择CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
self.reset_preview("文件不存在,请检查路径")
|
||||
return
|
||||
|
||||
try:
|
||||
rows_to_preview = max(1, self.preview_rows_spin.value())
|
||||
# dtype=object 确保所有列以字符串读取,避免空值/混合类型导致 dtype 报错
|
||||
df = pd.read_csv(csv_path, nrows=rows_to_preview, dtype=object)
|
||||
# fillna 在 PandasTableModel.__init__ 中已执行,此处再次防御性处理
|
||||
df = df.fillna('')
|
||||
if df.empty:
|
||||
self.reset_preview("CSV文件为空")
|
||||
return
|
||||
|
||||
model = PandasTableModel(df)
|
||||
self.preview_table.setModel(model)
|
||||
self.preview_status_label.setText(
|
||||
f"预览 {len(df)} 行,{len(df.columns)} 列(总行数可能更多)"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.reset_preview(f"加载失败: {exc}")
|
||||
@ -1,399 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step5_5 面板 - 水质指数计算
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QComboBox, QCheckBox,
|
||||
QPushButton, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step5_5Panel(QWidget):
|
||||
"""步骤5.5:水质指数计算"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.index_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.csv_columns = [] # 存储CSV文件列名
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
main_layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 数据文件选择
|
||||
data_group = QGroupBox("数据文件")
|
||||
data_layout = QVBoxLayout()
|
||||
|
||||
# 训练数据CSV文件选择
|
||||
self.training_data_widget = FileSelectWidget("训练数据CSV文件:", "CSV Files (*.csv)")
|
||||
data_layout.addWidget(self.training_data_widget)
|
||||
|
||||
# 公式CSV文件选择
|
||||
self.formula_csv_widget = FileSelectWidget("公式CSV文件:", "CSV Files (*.csv)")
|
||||
data_layout.addWidget(self.formula_csv_widget)
|
||||
|
||||
# 刷新公式按钮
|
||||
refresh_layout = QHBoxLayout()
|
||||
self.refresh_button = QPushButton("刷新公式列表")
|
||||
self.refresh_button.clicked.connect(self.refresh_formulas)
|
||||
refresh_layout.addWidget(self.refresh_button)
|
||||
refresh_layout.addStretch()
|
||||
data_layout.addLayout(refresh_layout)
|
||||
|
||||
data_group.setLayout(data_layout)
|
||||
main_layout.addWidget(data_group)
|
||||
|
||||
# 公式选择区域
|
||||
self.formula_group = QGroupBox("选择要计算的公式")
|
||||
formula_outer_layout = QVBoxLayout()
|
||||
|
||||
# 按钮控制区域
|
||||
button_layout = QHBoxLayout()
|
||||
self.select_all_btn = QPushButton("全选")
|
||||
self.select_all_btn.clicked.connect(self.select_all_formulas)
|
||||
self.deselect_all_btn = QPushButton("清空")
|
||||
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
|
||||
button_layout.addWidget(self.select_all_btn)
|
||||
button_layout.addWidget(self.deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
formula_outer_layout.addLayout(button_layout)
|
||||
|
||||
# 公式勾选框网格布局
|
||||
self.formula_layout = QGridLayout()
|
||||
formula_outer_layout.addLayout(self.formula_layout)
|
||||
|
||||
self.formula_group.setLayout(formula_outer_layout)
|
||||
main_layout.addWidget(self.formula_group)
|
||||
|
||||
# 输出文件设置
|
||||
output_group = QGroupBox("输出设置")
|
||||
output_layout = QVBoxLayout()
|
||||
|
||||
self.output_file_widget = FileSelectWidget(
|
||||
"输出文件:", "CSV Files (*.csv)", mode="save"
|
||||
)
|
||||
output_layout.addWidget(self.output_file_widget)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
main_layout.addWidget(output_group)
|
||||
|
||||
# 启用选项
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
main_layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
main_layout.addWidget(self.run_button)
|
||||
|
||||
# 公式编辑区域
|
||||
formula_edit_group = QGroupBox("添加自定义公式")
|
||||
formula_edit_layout = QFormLayout()
|
||||
|
||||
self.formula_name_edit = QLineEdit()
|
||||
|
||||
# 公式类别下拉选择框
|
||||
self.formula_category_combo = QComboBox()
|
||||
self.formula_category_combo.addItems([
|
||||
"chlorophyll_a",
|
||||
"Phycocyanin (BGA_PC)",
|
||||
"Total Nitrogen (TN)",
|
||||
"Total Phosphorus (TP)",
|
||||
"Orthophosphate",
|
||||
"COD",
|
||||
"BOD",
|
||||
"TOC",
|
||||
"Dissolved Oxygen (DO)",
|
||||
"E. coli",
|
||||
"Total Coliforms",
|
||||
"Turbidity",
|
||||
"Total Suspended Solids (TSS)",
|
||||
"Color",
|
||||
"pH",
|
||||
"Temperature",
|
||||
"Conductivity",
|
||||
"Total Dissolved Solids (TDS)"
|
||||
])
|
||||
self.formula_category_combo.setEditable(True) # 允许用户输入自定义类别
|
||||
|
||||
self.formula_expression_edit = QLineEdit()
|
||||
self.formula_reference_edit = QLineEdit()
|
||||
|
||||
formula_edit_layout.addRow("公式名称:", self.formula_name_edit)
|
||||
formula_edit_layout.addRow("公式类别:", self.formula_category_combo)
|
||||
formula_edit_layout.addRow("公式表达式:", self.formula_expression_edit)
|
||||
formula_edit_layout.addRow("参考文献:", self.formula_reference_edit)
|
||||
|
||||
add_button = QPushButton("添加公式")
|
||||
add_button.clicked.connect(self.add_custom_formula)
|
||||
formula_edit_layout.addRow(add_button)
|
||||
|
||||
formula_edit_group.setLayout(formula_edit_layout)
|
||||
main_layout.addWidget(formula_edit_group)
|
||||
|
||||
main_layout.addStretch()
|
||||
self.setLayout(main_layout)
|
||||
|
||||
# 自动加载内置公式文件
|
||||
formula_csv_path = (
|
||||
Path(__file__).resolve().parent.parent / "model" / "waterindex.csv"
|
||||
)
|
||||
if formula_csv_path.is_file():
|
||||
self.formula_csv_widget.set_path(str(formula_csv_path))
|
||||
self.refresh_formulas()
|
||||
|
||||
def refresh_formulas(self):
|
||||
"""刷新公式列表"""
|
||||
formula_csv_path = self.formula_csv_widget.get_path()
|
||||
if not formula_csv_path or not os.path.exists(formula_csv_path):
|
||||
QMessageBox.warning(self, "警告", "请先选择有效的公式CSV文件")
|
||||
return
|
||||
|
||||
try:
|
||||
# 清除现有的勾选框
|
||||
for checkbox in self.index_checkboxes.values():
|
||||
self.formula_layout.removeWidget(checkbox)
|
||||
checkbox.deleteLater()
|
||||
self.index_checkboxes.clear()
|
||||
|
||||
# 读取公式CSV文件
|
||||
df = pd.read_csv(formula_csv_path)
|
||||
if df.empty or 'Formula_Name' not in df.columns:
|
||||
QMessageBox.warning(self, "警告", "公式CSV文件格式不正确")
|
||||
return
|
||||
|
||||
# 获取所有公式名称(跳过第一行)
|
||||
formula_names = df['Formula_Name'].tolist()[1:]
|
||||
|
||||
# 创建3列布局的勾选框
|
||||
row, col = 0, 0
|
||||
for formula_name in formula_names:
|
||||
if pd.isna(formula_name) or not formula_name.strip():
|
||||
continue
|
||||
|
||||
checkbox = QCheckBox(formula_name.strip())
|
||||
checkbox.setChecked(True)
|
||||
self.index_checkboxes[formula_name.strip()] = checkbox
|
||||
self.formula_layout.addWidget(checkbox, row, col)
|
||||
|
||||
col += 1
|
||||
if col >= 3: # 每行3列
|
||||
col = 0
|
||||
row += 1
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"读取公式文件失败: {str(e)}")
|
||||
|
||||
def add_custom_formula(self):
|
||||
"""添加自定义公式到公式CSV文件"""
|
||||
formula_csv_path = self.formula_csv_widget.get_path()
|
||||
if not formula_csv_path:
|
||||
QMessageBox.warning(self, "警告", "请先选择公式CSV文件")
|
||||
return
|
||||
|
||||
formula_name = self.formula_name_edit.text().strip()
|
||||
formula_category = self.formula_category_combo.currentText().strip()
|
||||
formula_expression = self.formula_expression_edit.text().strip()
|
||||
formula_reference = self.formula_reference_edit.text().strip()
|
||||
|
||||
if not all([formula_name, formula_category, formula_expression]):
|
||||
QMessageBox.warning(self, "警告", "请填写公式名称、类别和表达式")
|
||||
return
|
||||
|
||||
try:
|
||||
# 读取现有公式文件或创建新文件
|
||||
if os.path.exists(formula_csv_path):
|
||||
df = pd.read_csv(formula_csv_path)
|
||||
else:
|
||||
df = pd.DataFrame(columns=['Formula_Name', 'Category', 'Formula', 'Reference'])
|
||||
|
||||
# 添加新公式
|
||||
new_row = pd.DataFrame({
|
||||
'Formula_Name': [formula_name],
|
||||
'Category': [formula_category],
|
||||
'Formula': [formula_expression],
|
||||
'Reference': [formula_reference]
|
||||
})
|
||||
df = pd.concat([df, new_row], ignore_index=True)
|
||||
|
||||
# 保存文件
|
||||
df.to_csv(formula_csv_path, index=False, encoding='utf-8')
|
||||
|
||||
# 清空输入框
|
||||
self.formula_name_edit.clear()
|
||||
self.formula_category_combo.setCurrentIndex(0) # 重置到第一个选项
|
||||
self.formula_expression_edit.clear()
|
||||
self.formula_reference_edit.clear()
|
||||
|
||||
# 刷新公式列表
|
||||
self.refresh_formulas()
|
||||
|
||||
QMessageBox.information(self, "成功", "公式添加成功")
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "错误", f"添加公式失败: {str(e)}")
|
||||
|
||||
def get_config(self) -> Dict[str, Union[List[str], str, bool]]:
|
||||
"""获取配置"""
|
||||
selected = [
|
||||
name for name, checkbox in self.index_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
output_path = self.output_file_widget.get_path()
|
||||
return {
|
||||
'training_spectra_path': self.training_data_widget.get_path() or None,
|
||||
'formula_csv_file': self.formula_csv_widget.get_path() or None,
|
||||
'formula_names': selected,
|
||||
'output_file': output_path or None,
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'training_spectra_path' in config:
|
||||
self.training_data_widget.set_path(config['training_spectra_path'])
|
||||
|
||||
if 'formula_csv_file' in config:
|
||||
self.formula_csv_widget.set_path(config['formula_csv_file'])
|
||||
self.refresh_formulas()
|
||||
|
||||
if 'formula_names' in config:
|
||||
selected_formulas = set(config['formula_names'])
|
||||
for name, checkbox in self.index_checkboxes.items():
|
||||
checkbox.setChecked(name in selected_formulas)
|
||||
|
||||
if 'output_file' in config and config['output_file']:
|
||||
self.output_file_widget.set_path(config['output_file'])
|
||||
elif 'output_filename' in config and config['output_filename']:
|
||||
self.output_file_widget.set_path(config['output_filename'])
|
||||
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 自动填入训练数据路径(从 Step5 的输出中获取)
|
||||
# 优先级:直接 widget > pipeline.step_outputs 回退
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step5_panel'):
|
||||
# 优先直接从 Step5 的输出 widget 读取(已运行的最新输出)
|
||||
step5_output = main_window.step5_panel.output_file.get_path()
|
||||
if step5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output):
|
||||
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
|
||||
self.training_data_widget.set_path(step5_output)
|
||||
else:
|
||||
# 退而求其次,使用 Step5 的输入 CSV
|
||||
step5_csv = main_window.step5_panel.csv_file.get_path()
|
||||
if step5_csv:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_csv):
|
||||
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
|
||||
self.training_data_widget.set_path(step5_csv)
|
||||
|
||||
# 如果上述都没找到,尝试从 pipeline.step_outputs 回退
|
||||
if not self.training_data_widget.get_path() and pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step5_outputs = getattr(pipeline, 'step_outputs', {}).get('step5', {})
|
||||
training_path = step5_outputs.get('training_spectra')
|
||||
if training_path:
|
||||
self.training_data_widget.set_path(training_path)
|
||||
|
||||
# 2. 自动填入输出文件的绝对路径
|
||||
if self.work_dir:
|
||||
output_abs = os.path.join(self.work_dir, "6_water_quality_indices",
|
||||
"training_spectra_indices.csv").replace('\\', '/')
|
||||
self.output_file_widget.set_path(output_abs)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
return self.enable_checkbox.isChecked()
|
||||
|
||||
def select_all_formulas(self):
|
||||
"""全选所有公式"""
|
||||
for checkbox in self.index_checkboxes.values():
|
||||
checkbox.setChecked(True)
|
||||
|
||||
def deselect_all_formulas(self):
|
||||
"""清空所有公式"""
|
||||
for checkbox in self.index_checkboxes.values():
|
||||
checkbox.setChecked(False)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤5.5:计算水质指数。
|
||||
|
||||
动态根据输入 CSV 文件名生成输出文件名,自动填入 output_file_widget。
|
||||
例如:training_spectra.csv → training_spectra_indices.csv
|
||||
sampling_spectra.csv → sampling_spectra_indices.csv
|
||||
"""
|
||||
# 验证输入
|
||||
training_csv_path = self.training_data_widget.get_path()
|
||||
formula_csv_path = self.formula_csv_widget.get_path()
|
||||
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择训练数据CSV文件")
|
||||
return
|
||||
if not formula_csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择公式CSV文件")
|
||||
return
|
||||
if not os.path.exists(training_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "训练数据CSV文件不存在")
|
||||
return
|
||||
if not os.path.exists(formula_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "公式CSV文件不存在")
|
||||
return
|
||||
|
||||
# 动态生成输出文件:自动拼接 _indices 后缀
|
||||
input_name = Path(training_csv_path).stem
|
||||
dynamic_output = f"{input_name}_indices.csv"
|
||||
|
||||
# 合成完整绝对路径(优先使用 work_dir,其次从 training_csv_path 推导)
|
||||
work_dir = getattr(self, 'work_dir', None)
|
||||
if work_dir:
|
||||
dynamic_output = os.path.join(
|
||||
work_dir, "6_water_quality_indices", dynamic_output
|
||||
).replace('\\', '/')
|
||||
|
||||
self.output_file_widget.set_path(dynamic_output)
|
||||
|
||||
# 获取配置
|
||||
config = self.get_config()
|
||||
|
||||
# 调用GUI的run_single_step方法
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step5_5', {'step5_5': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,239 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step5 面板 - 光谱提取
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QLabel,
|
||||
QSpinBox, QPushButton, QCheckBox, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtGui import QFont
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step5Panel(QWidget):
|
||||
"""步骤5:光谱提取"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
title = QLabel("步骤5:训练样本光谱提取")
|
||||
title.setFont(QFont("Arial", 12, QFont.Bold))
|
||||
layout.addWidget(title)
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
|
||||
# 处理后的CSV文件(用于独立运行)
|
||||
self.csv_file = FileSelectWidget(
|
||||
"处理后CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.csv_file)
|
||||
|
||||
# 水体掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水体掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.line_edit.setPlaceholderText("可选,如不选择则自动生成")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
self.glint_mask_file = FileSelectWidget(
|
||||
"耀斑掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.glint_mask_file)
|
||||
step5_glint_hint = QLabel(
|
||||
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
|
||||
)
|
||||
step5_glint_hint.setWordWrap(True)
|
||||
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step5_glint_hint)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("提取参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.radius = QSpinBox()
|
||||
self.radius.setRange(1, 50)
|
||||
self.radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.radius)
|
||||
|
||||
self.source_epsg = QSpinBox()
|
||||
self.source_epsg.setRange(1000, 99999)
|
||||
self.source_epsg.setValue(4326)
|
||||
params_layout.addRow("源坐标系EPSG:", self.source_epsg)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("training_spectra.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'radius': self.radius.value(),
|
||||
'source_epsg': self.source_epsg.value(),
|
||||
}
|
||||
# 添加独立运行所需的文件路径
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
csv_path = self.csv_file.get_path()
|
||||
if csv_path:
|
||||
config['csv_path'] = csv_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['boundary_path'] = water_mask_path
|
||||
glint_mask_path = self.glint_mask_file.get_path()
|
||||
if glint_mask_path:
|
||||
config['glint_mask_path'] = glint_mask_path
|
||||
# 注意:step5_extract_training_spectra 不接受 output_path / training_spectra_path
|
||||
# 参数,输出路径由 pipeline 内部根据 training_spectra_dir 自动生成。
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'radius' in config:
|
||||
self.radius.setValue(config['radius'])
|
||||
if 'source_epsg' in config:
|
||||
self.source_epsg.setValue(config['source_epsg'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
if 'boundary_path' in config:
|
||||
self.water_mask_file.set_path(config['boundary_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置/Pipeline 或 Step1Panel 自动填充路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例,用于获取步骤1生成的水域掩膜路径
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Pipeline 获取水体掩膜路径
|
||||
mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
|
||||
mask_path = pipeline.water_mask_path
|
||||
|
||||
# 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取
|
||||
main_window = self.window()
|
||||
if not mask_path and hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if mask_path and not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
|
||||
# 填充水体掩膜路径
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 3. 尝试从 Step2 界面读取耀斑掩膜路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step2_panel'):
|
||||
glint_path = main_window.step2_panel.output_file.get_path()
|
||||
if glint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(glint_path):
|
||||
glint_path = os.path.join(self.work_dir or '', glint_path).replace('\\', '/')
|
||||
self.glint_mask_file.set_path(glint_path)
|
||||
|
||||
# 4. 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "5_training_spectra")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "training_spectra.csv").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step4_panel'):
|
||||
step4_output_path = main_window.step4_panel.output_file.get_path()
|
||||
if step4_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step4_output_path):
|
||||
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
|
||||
existing_csv = self.csv_file.get_path()
|
||||
if not existing_csv or not existing_csv.strip():
|
||||
self.csv_file.set_path(step4_output_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤5"""
|
||||
# 验证输入
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
return
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择处理后的CSV文件!")
|
||||
return
|
||||
if not self.glint_mask_file.get_path():
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入错误",
|
||||
"独立运行光谱特征提取时,必须选择耀斑掩膜文件。\n\n"
|
||||
"请提供与去耀斑影像对应的耀斑二值掩膜(一般为步骤2输出的 severe_glint_area.dat)。",
|
||||
)
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step5': self.get_config()}
|
||||
main_window.run_single_step('step5', config)
|
||||
@ -1,307 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_5 面板 - 非经验统计回归建模
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QCheckBox, QSpinBox, QPushButton,
|
||||
QFileDialog, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_5Panel(QWidget):
|
||||
"""步骤6.5:非经验统计回归建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("模型参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 预处理方法
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
preproc_grid = QGridLayout()
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
params_layout.addRow(preproc_group)
|
||||
|
||||
# 算法选择(可多选)
|
||||
self.algorithm_inputs = {}
|
||||
algorithms_widget = QWidget()
|
||||
algorithms_layout = QVBoxLayout()
|
||||
algorithms_layout.setContentsMargins(0, 0, 0, 0)
|
||||
algorithms_layout.setSpacing(4)
|
||||
|
||||
algorithm_list = ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss']
|
||||
for algorithm in algorithm_list:
|
||||
row_widget = QWidget()
|
||||
row_layout = QHBoxLayout()
|
||||
row_layout.setContentsMargins(0, 0, 0, 0)
|
||||
checkbox = QCheckBox(algorithm)
|
||||
checkbox.setChecked(True)
|
||||
spinbox = QSpinBox()
|
||||
spinbox.setRange(0, 500)
|
||||
spinbox.setValue(0)
|
||||
spinbox.setMaximumWidth(90)
|
||||
row_layout.addWidget(checkbox)
|
||||
row_layout.addWidget(QLabel("对应值列索引:"))
|
||||
row_layout.addWidget(spinbox)
|
||||
row_layout.addStretch()
|
||||
row_widget.setLayout(row_layout)
|
||||
algorithms_layout.addWidget(row_widget)
|
||||
self.algorithm_inputs[algorithm] = (checkbox, spinbox)
|
||||
|
||||
algorithms_widget.setLayout(algorithms_layout)
|
||||
params_layout.addRow("非经验算法选择:", algorithms_widget)
|
||||
|
||||
# 光谱起始列
|
||||
self.spectral_start_col = QSpinBox()
|
||||
self.spectral_start_col.setRange(0, 100)
|
||||
self.spectral_start_col.setValue(1)
|
||||
params_layout.addRow("光谱起始列索引:", self.spectral_start_col)
|
||||
|
||||
# 窗口大小 (变量名已修正,避免覆盖 QWidget.window)
|
||||
self.window_size_spinbox = QSpinBox()
|
||||
self.window_size_spinbox.setRange(1, 20)
|
||||
self.window_size_spinbox.setValue(5)
|
||||
params_layout.addRow("窗口大小:", self.window_size_spinbox)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
selected_algorithms = [
|
||||
name for name, (checkbox, _) in self.algorithm_inputs.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_algorithms:
|
||||
selected_algorithms = list(self.algorithm_inputs.keys())
|
||||
|
||||
value_cols = {
|
||||
name: spinbox.value()
|
||||
for name, (_, spinbox) in self.algorithm_inputs.items()
|
||||
if name in selected_algorithms
|
||||
}
|
||||
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
] or ['None']
|
||||
|
||||
config = {
|
||||
'preprocessing_methods': preprocessing_methods,
|
||||
'algorithms': selected_algorithms,
|
||||
'value_cols': value_cols,
|
||||
'spectral_start_col': self.spectral_start_col.value(),
|
||||
'window': self.window_size_spinbox.value(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
output_dir = self.output_dir.get_path()
|
||||
if not output_dir:
|
||||
main_window = self.parent().window()
|
||||
if hasattr(main_window, 'work_dir') and main_window.work_dir:
|
||||
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
|
||||
else:
|
||||
output_dir = str(Path.cwd() / "8_Regression_Modeling")
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['csv_path'] = training_csv_path
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
|
||||
if 'algorithms' in config:
|
||||
algorithm_values = config['algorithms']
|
||||
for algorithm, (checkbox, spinbox) in self.algorithm_inputs.items():
|
||||
checkbox.setChecked(algorithm in algorithm_values)
|
||||
|
||||
if 'value_cols' in config:
|
||||
value_cols = config['value_cols']
|
||||
if isinstance(value_cols, dict):
|
||||
for algorithm, (_, spinbox) in self.algorithm_inputs.items():
|
||||
if algorithm in value_cols:
|
||||
spinbox.setValue(value_cols[algorithm])
|
||||
else:
|
||||
for _, spinbox in self.algorithm_inputs.values():
|
||||
spinbox.setValue(value_cols)
|
||||
|
||||
if 'spectral_start_col' in config:
|
||||
self.spectral_start_col.setValue(config['spectral_start_col'])
|
||||
|
||||
if 'window' in config:
|
||||
self.window_size_spinbox.setValue(config['window'])
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.set_path(config['output_dir'])
|
||||
if 'csv_path' in config:
|
||||
self.training_csv_file.set_path(config['csv_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 借用父组件的 window() 方法,安全绕过当前类的命名冲突
|
||||
parent_widget = self.parentWidget()
|
||||
main_window = parent_widget.window() if parent_widget else None
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.training_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.training_csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 自动填充输出目录(8_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "8_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
# 借用父组件的 window() 方法,安全绕过当前类的命名冲突
|
||||
parent_widget = self.parentWidget()
|
||||
mw = parent_widget.window() if parent_widget else None
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "8_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出模型目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.5"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
if not os.path.exists(training_csv_path):
|
||||
QMessageBox.warning(self, "输入错误", "训练数据CSV文件不存在!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_5', {'step6_5': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置预处理checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
@ -1,374 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_75 面板 - 自定义回归分析
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
|
||||
QScrollArea, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_75Panel(QWidget):
|
||||
"""步骤6.75:自定义回归分析"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.method_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.csv_columns = []
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
|
||||
hint.setStyleSheet("color: #666; font-size: 11px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
# CSV文件选择
|
||||
csv_group = QGroupBox("数据文件")
|
||||
csv_layout = QVBoxLayout()
|
||||
|
||||
self.csv_file = FileSelectWidget(
|
||||
"输入CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
|
||||
csv_layout.addWidget(self.csv_file)
|
||||
|
||||
self.refresh_btn = QPushButton("刷新列信息")
|
||||
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
|
||||
csv_layout.addWidget(self.refresh_btn)
|
||||
|
||||
csv_group.setLayout(csv_layout)
|
||||
layout.addWidget(csv_group)
|
||||
|
||||
# 自变量选择
|
||||
x_group = QGroupBox("自变量列选择 (可多选)")
|
||||
x_layout = QVBoxLayout()
|
||||
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
x_widget.setLayout(self.x_columns_layout)
|
||||
|
||||
x_scroll.setWidget(x_widget)
|
||||
x_layout.addWidget(x_scroll)
|
||||
|
||||
x_btn_layout = QHBoxLayout()
|
||||
self.x_select_all = QPushButton("全选")
|
||||
self.x_deselect_all = QPushButton("全不选")
|
||||
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
|
||||
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
|
||||
x_btn_layout.addWidget(self.x_select_all)
|
||||
x_btn_layout.addWidget(self.x_deselect_all)
|
||||
x_btn_layout.addStretch()
|
||||
x_layout.addLayout(x_btn_layout)
|
||||
|
||||
x_group.setLayout(x_layout)
|
||||
layout.addWidget(x_group)
|
||||
|
||||
# 因变量选择
|
||||
y_group = QGroupBox("因变量列选择 (可多选)")
|
||||
y_layout = QVBoxLayout()
|
||||
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
y_widget.setLayout(self.y_columns_layout)
|
||||
|
||||
y_scroll.setWidget(y_widget)
|
||||
y_layout.addWidget(y_scroll)
|
||||
|
||||
y_btn_layout = QHBoxLayout()
|
||||
self.y_select_all = QPushButton("全选")
|
||||
self.y_deselect_all = QPushButton("全不选")
|
||||
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
|
||||
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
|
||||
y_btn_layout.addWidget(self.y_select_all)
|
||||
y_btn_layout.addWidget(self.y_deselect_all)
|
||||
y_btn_layout.addStretch()
|
||||
y_layout.addLayout(y_btn_layout)
|
||||
|
||||
y_group.setLayout(y_layout)
|
||||
layout.addWidget(y_group)
|
||||
|
||||
# 回归方法选择
|
||||
method_group = QGroupBox("回归方法选择 (可多选)")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
method_grid = QGridLayout()
|
||||
regression_methods = [
|
||||
'linear', 'exponential', 'power', 'logarithmic',
|
||||
'polynomial', 'hyperbolic', 'sigmoidal'
|
||||
]
|
||||
|
||||
for i, method in enumerate(regression_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
||||
checkbox.setChecked(True)
|
||||
self.method_checkboxes[method] = checkbox
|
||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
method_layout.addLayout(method_grid)
|
||||
|
||||
method_btn_layout = QHBoxLayout()
|
||||
self.method_select_all = QPushButton("全选")
|
||||
self.method_deselect_all = QPushButton("全不选")
|
||||
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
|
||||
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
|
||||
method_btn_layout.addWidget(self.method_select_all)
|
||||
method_btn_layout.addWidget(self.method_deselect_all)
|
||||
method_btn_layout.addStretch()
|
||||
method_layout.addLayout(method_btn_layout)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 输出目录
|
||||
output_group = QGroupBox("输出设置")
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
layout.addWidget(output_group)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def on_csv_file_changed(self):
|
||||
"""CSV文件改变时自动刷新列信息"""
|
||||
self.refresh_csv_columns()
|
||||
|
||||
def refresh_csv_columns(self):
|
||||
"""刷新CSV文件的列信息"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
return
|
||||
|
||||
try:
|
||||
df = pd.read_csv(csv_path, nrows=0)
|
||||
self.csv_columns = list(df.columns)
|
||||
self.update_column_widgets()
|
||||
except Exception as e:
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
print(f"读取CSV列信息失败: {e}")
|
||||
|
||||
def update_column_widgets(self):
|
||||
"""更新列选择组件"""
|
||||
for checkbox in self.x_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.x_column_checkboxes.clear()
|
||||
|
||||
for checkbox in self.y_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.y_column_checkboxes.clear()
|
||||
|
||||
if not self.csv_columns:
|
||||
return
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
||||
checkbox.setChecked(True)
|
||||
self.x_column_checkboxes[col] = checkbox
|
||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
||||
checkbox.setChecked(True)
|
||||
self.y_column_checkboxes[col] = checkbox
|
||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||
|
||||
self.x_columns_layout.update()
|
||||
self.y_columns_layout.update()
|
||||
|
||||
def get_config(self):
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
selected_methods = 'all'
|
||||
|
||||
return {
|
||||
'csv_path': self.csv_file.get_path() or None,
|
||||
'x_columns': selected_x_columns,
|
||||
'y_columns': selected_y_columns,
|
||||
'methods': selected_methods,
|
||||
'output_dir': self.output_dir.text().strip() or None,
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.refresh_csv_columns()
|
||||
|
||||
if 'x_columns' in config:
|
||||
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
|
||||
for col, checkbox in self.x_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_x)
|
||||
|
||||
if 'y_columns' in config:
|
||||
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
|
||||
for col, checkbox in self.y_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_y)
|
||||
|
||||
if 'methods' in config:
|
||||
methods = config['methods']
|
||||
if isinstance(methods, list):
|
||||
selected_methods = set(methods)
|
||||
elif methods == 'all':
|
||||
selected_methods = set(self.method_checkboxes.keys())
|
||||
else:
|
||||
selected_methods = set()
|
||||
for method, checkbox in self.method_checkboxes.items():
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 自动填充输出目录(9_Custom_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.text().strip()
|
||||
if not existing_out:
|
||||
self.output_dir.setText(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.75"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
|
||||
return
|
||||
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_x_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
|
||||
return
|
||||
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_y_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
|
||||
return
|
||||
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_75', {'step6_75': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,364 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6 面板 - 机器学习建模
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||||
QPushButton, QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6Panel(QWidget):
|
||||
"""步骤6:机器学习建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 机器学习模型页面
|
||||
self.ml_page = QWidget()
|
||||
self.create_ml_page()
|
||||
layout.addWidget(self.ml_page)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_path = FileSelectWidget(
|
||||
"输出文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
|
||||
self.output_path.browse_btn.clicked.disconnect()
|
||||
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
|
||||
layout.addWidget(self.output_path)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def create_ml_page(self):
|
||||
"""创建机器学习模型页面"""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("训练参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.feature_start = QLineEdit()
|
||||
self.feature_start.setText("374.285004")
|
||||
params_layout.addRow("特征起始列:", self.feature_start)
|
||||
|
||||
self.cv_folds = QSpinBox()
|
||||
self.cv_folds.setRange(2, 10)
|
||||
self.cv_folds.setValue(3)
|
||||
params_layout.addRow("交叉验证折数:", self.cv_folds)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 预处理方法 - 多选
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
|
||||
preproc_grid = QGridLayout()
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
layout.addWidget(preproc_group)
|
||||
|
||||
# 模型选择 - 多选
|
||||
model_group = QGroupBox("模型类型 (可多选)")
|
||||
model_layout = QVBoxLayout()
|
||||
|
||||
model_grid = QGridLayout()
|
||||
self.model_checkboxes = {}
|
||||
|
||||
model_groups = [
|
||||
("线性模型", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("树模型", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("集成学习", ['GradientBoosting', 'AdaBoost']),
|
||||
("其他模型", ['SVR', 'KNN', 'MLP'])
|
||||
]
|
||||
|
||||
row = 0
|
||||
for group_name, models in model_groups:
|
||||
group_label = QLabel(f"<b>{group_name}</b>")
|
||||
group_label.setStyleSheet(
|
||||
f"background-color: {ModernStylesheet.COLORS['hover']}; "
|
||||
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
|
||||
f"border-radius: 3px;"
|
||||
)
|
||||
model_grid.addWidget(group_label, row, 0, 1, 4)
|
||||
row += 1
|
||||
|
||||
for i, model in enumerate(models):
|
||||
checkbox = QCheckBox(model)
|
||||
checkbox.setChecked(model in ['SVR', 'RF', 'Ridge', 'Lasso'])
|
||||
self.model_checkboxes[model] = checkbox
|
||||
model_grid.addWidget(checkbox, row, i % 4)
|
||||
if (i + 1) % 4 == 0:
|
||||
row += 1
|
||||
|
||||
row += 1
|
||||
|
||||
model_button_layout = QHBoxLayout()
|
||||
model_select_all = QPushButton("全选")
|
||||
model_deselect_all = QPushButton("全不选")
|
||||
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
|
||||
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
|
||||
model_button_layout.addWidget(model_select_all)
|
||||
model_button_layout.addWidget(model_deselect_all)
|
||||
model_button_layout.addStretch()
|
||||
|
||||
model_layout.addLayout(model_grid)
|
||||
model_layout.addLayout(model_button_layout)
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# 数据划分方法 - 多选
|
||||
split_group = QGroupBox("数据划分方法 (可多选)")
|
||||
split_layout = QVBoxLayout()
|
||||
|
||||
split_grid = QGridLayout()
|
||||
self.split_checkboxes = {}
|
||||
split_methods = ['spxy', 'ks', 'random']
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
split_button_layout = QHBoxLayout()
|
||||
split_select_all = QPushButton("全选")
|
||||
split_deselect_all = QPushButton("全不选")
|
||||
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
|
||||
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
|
||||
split_button_layout.addWidget(split_select_all)
|
||||
split_button_layout.addWidget(split_deselect_all)
|
||||
split_button_layout.addStretch()
|
||||
|
||||
split_layout.addLayout(split_grid)
|
||||
split_layout.addLayout(split_button_layout)
|
||||
split_group.setLayout(split_layout)
|
||||
layout.addWidget(split_group)
|
||||
|
||||
self.ml_page.setLayout(layout)
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_path(self):
|
||||
"""浏览输出文件路径(保存对话框)"""
|
||||
current = self.output_path.get_path().strip()
|
||||
if current:
|
||||
initial_dir = os.path.dirname(current)
|
||||
initial_file = os.path.basename(current)
|
||||
else:
|
||||
initial_dir = ""
|
||||
initial_file = ""
|
||||
|
||||
if not initial_dir or not os.path.isdir(initial_dir):
|
||||
# 默认定位到 indices 目录
|
||||
work_dir = self._get_default_work_dir()
|
||||
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
|
||||
if initial_dir and not os.path.isdir(initial_dir):
|
||||
os.makedirs(initial_dir, exist_ok=True)
|
||||
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
self.output_path.set_path(file_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
model_names = [
|
||||
model for model, checkbox in self.model_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
split_methods = [
|
||||
method for method, checkbox in self.split_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
|
||||
config = {
|
||||
'feature_start_column': self.feature_start.text(),
|
||||
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
|
||||
'model_names': model_names if model_names else ['SVR'],
|
||||
'split_methods': split_methods if split_methods else ['random'],
|
||||
'cv_folds': self.cv_folds.value()
|
||||
}
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['training_csv_path'] = training_csv_path
|
||||
output_path = self.output_path.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'feature_start_column' in config:
|
||||
self.feature_start.setText(str(config['feature_start_column']))
|
||||
if 'cv_folds' in config:
|
||||
self.cv_folds.setValue(config['cv_folds'])
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'model_names' in config:
|
||||
models = config['model_names']
|
||||
for model, checkbox in self.model_checkboxes.items():
|
||||
checkbox.setChecked(model in models)
|
||||
if 'split_methods' in config:
|
||||
methods = config['split_methods']
|
||||
for method, checkbox in self.split_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'training_csv_path' in config:
|
||||
self.training_csv_file.set_path(config['training_csv_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_path.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step5_panel'):
|
||||
# 优先直接从 Step5 的输出 widget 读取
|
||||
step5_output = main_window.step5_panel.output_file.get_path()
|
||||
if step5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output):
|
||||
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_output)
|
||||
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
|
||||
# 回退:从 Step5 的 config 字典中查找可能的键名
|
||||
step5_cfg = main_window.step5_panel.get_config()
|
||||
step5_csv = (
|
||||
step5_cfg.get('training_spectra_path')
|
||||
or step5_cfg.get('output_file')
|
||||
or step5_cfg.get('csv_path')
|
||||
or step5_cfg.get('output_csv')
|
||||
)
|
||||
if step5_csv:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_csv):
|
||||
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_csv)
|
||||
|
||||
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
|
||||
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
|
||||
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
|
||||
if self.work_dir:
|
||||
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
|
||||
os.makedirs(indices_dir, exist_ok=True)
|
||||
training_csv = self.training_csv_file.get_path()
|
||||
if training_csv:
|
||||
basename = os.path.splitext(os.path.basename(training_csv))[0]
|
||||
output_file = f"{basename}_indices.csv"
|
||||
else:
|
||||
output_file = "water_quality_indices.csv"
|
||||
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
|
||||
self.output_path.set_path(output_path)
|
||||
else:
|
||||
self.output_path.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step6': self.get_config()}
|
||||
main_window.run_single_step('step6', config)
|
||||
|
||||
def get_training_params(self):
|
||||
"""获取模型训练参数"""
|
||||
return {
|
||||
'pipeline_type': 'machine_learning',
|
||||
'feature_start': float(self.feature_start.text()),
|
||||
'cv_folds': self.cv_folds.value(),
|
||||
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
|
||||
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
|
||||
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
|
||||
}
|
||||
@ -1,208 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step7 面板 - 采样点生成
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QSpinBox, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step7Panel(QWidget):
|
||||
"""步骤7:采样点生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("采样参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.interval = QSpinBox()
|
||||
self.interval.setRange(10, 500)
|
||||
self.interval.setValue(50)
|
||||
params_layout.addRow("采样点间隔(像素):", self.interval)
|
||||
|
||||
self.sample_radius = QSpinBox()
|
||||
self.sample_radius.setRange(1, 50)
|
||||
self.sample_radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.sample_radius)
|
||||
|
||||
self.chunk_size = QSpinBox()
|
||||
self.chunk_size.setRange(100, 10000)
|
||||
self.chunk_size.setValue(1000)
|
||||
params_layout.addRow("处理块大小:", self.chunk_size)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出采样点:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'interval': self.interval.value(),
|
||||
'sample_radius': self.sample_radius.value(),
|
||||
'chunk_size': self.chunk_size.value(),
|
||||
}
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
# 注意:step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'interval' in config:
|
||||
self.interval.setValue(config['interval'])
|
||||
if 'sample_radius' in config:
|
||||
self.sample_radius.setValue(config['sample_radius'])
|
||||
if 'chunk_size' in config:
|
||||
self.chunk_size.setValue(config['chunk_size'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充去耀斑影像和掩膜路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
|
||||
deglint_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
|
||||
deglint_path = (
|
||||
step3_outputs.get('deglint_image')
|
||||
or step3_outputs.get('output_path')
|
||||
or step3_outputs.get('output_file')
|
||||
or step3_outputs.get('deglint_img_path')
|
||||
)
|
||||
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
|
||||
if not deglint_path and hasattr(main_window, 'step3_panel'):
|
||||
deglint_path = main_window.step3_panel.output_file.get_path()
|
||||
|
||||
if deglint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(deglint_path):
|
||||
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
|
||||
self.deglint_img_file.set_path(deglint_path)
|
||||
|
||||
# 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
|
||||
water_mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
|
||||
water_mask_path = (
|
||||
step1_outputs.get('water_mask')
|
||||
or step1_outputs.get('output_path')
|
||||
or step1_outputs.get('output_file')
|
||||
)
|
||||
# 回退:从 step1 面板 widget 直接读取
|
||||
if not water_mask_path and hasattr(main_window, 'step1_panel'):
|
||||
water_mask_path = main_window.step1_panel.output_file.get_path()
|
||||
# 备选:扫描 1_water_mask 目录下的 .dat 文件
|
||||
if not water_mask_path and self.work_dir:
|
||||
mask_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
if os.path.isdir(mask_dir):
|
||||
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
|
||||
if dat_files:
|
||||
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
|
||||
# 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
|
||||
if not water_mask_path and self.work_dir:
|
||||
input_test_dir = os.path.join(self.work_dir, "input-test")
|
||||
if os.path.isdir(input_test_dir):
|
||||
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
|
||||
# 优先匹配 water_mask_from_shp.dat
|
||||
for f in dat_files:
|
||||
if 'water_mask_from_shp' in f.lower():
|
||||
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
|
||||
break
|
||||
# 否则取第一个 .dat 文件
|
||||
if not water_mask_path and dat_files:
|
||||
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
|
||||
|
||||
if water_mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(water_mask_path):
|
||||
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(water_mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(绝对路径)
|
||||
if self.work_dir:
|
||||
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
self.output_file.set_path(output_path.replace('\\', '/'))
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤7"""
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step7': self.get_config()}
|
||||
main_window.run_single_step('step7', config)
|
||||
@ -1,226 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_5 面板 - 非经验模型预测
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_5Panel(QWidget):
|
||||
"""步骤8.5:非经验模型预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件选择
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录选择
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['Average Accuracy(%)', 'Min Accuracy(%)', 'Max Accuracy(%)'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出文件夹:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_file.label.setText("输出文件夹:")
|
||||
self.output_file.browse_btn.clicked.disconnect()
|
||||
self.output_file.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和回归模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.5 界面读取回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_5_panel'):
|
||||
step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
|
||||
step6_5_models_dir = ""
|
||||
if hasattr(step6_5_widget, 'get_path'):
|
||||
step6_5_models_dir = step6_5_widget.get_path() or ""
|
||||
elif hasattr(step6_5_widget, 'text'):
|
||||
step6_5_models_dir = step6_5_widget.text() or ""
|
||||
|
||||
if step6_5_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_5_models_dir):
|
||||
step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_5_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(非经验模型预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Non_Empirical_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "8_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions/Non_Empirical_Prediction")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出文件夹", default)
|
||||
if dir_path:
|
||||
self.output_file.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.5"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_5', {'step8_5': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,230 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_75 面板 - 自定义回归预测
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox,
|
||||
QPushButton, QCheckBox, QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_75Panel(QWidget):
|
||||
"""步骤8.75:自定义回归预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件选择
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 自定义回归模型目录选择(9_Custom_Regression_Modeling)
|
||||
self.regression_models_dir = FileSelectWidget(
|
||||
"回归模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.regression_models_dir.label.setText("回归模型目录:")
|
||||
self.regression_models_dir.browse_btn.clicked.disconnect()
|
||||
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
|
||||
self.regression_models_dir.set_path("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
layout.addWidget(self.regression_models_dir)
|
||||
|
||||
# 公式CSV文件选择(用于查找index_formula)
|
||||
self.formula_csv_file = FileSelectWidget(
|
||||
"公式CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.formula_csv_file.label.setText("公式CSV文件:")
|
||||
layout.addWidget(self.formula_csv_file)
|
||||
|
||||
# 输出目录选择
|
||||
self.output_dir_widget = FileSelectWidget(
|
||||
"输出目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir_widget.label.setText("输出目录:")
|
||||
self.output_dir_widget.browse_btn.clicked.disconnect()
|
||||
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
|
||||
layout.addWidget(self.output_dir_widget)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和自定义回归模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.75 界面读取自定义回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_75_panel'):
|
||||
step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
|
||||
step6_75_models_dir = ""
|
||||
if hasattr(step6_75_widget, 'get_path'):
|
||||
step6_75_models_dir = step6_75_widget.get_path() or ""
|
||||
elif hasattr(step6_75_widget, 'text'):
|
||||
step6_75_models_dir = step6_75_widget.text() or ""
|
||||
step6_75_models_dir = step6_75_models_dir.strip()
|
||||
|
||||
if step6_75_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_75_models_dir):
|
||||
step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
|
||||
existing_models = self.regression_models_dir.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.regression_models_dir.set_path(step6_75_models_dir)
|
||||
|
||||
# 3. 自动填充回归模型目录(如果 step6_75 未提供)
|
||||
if self.work_dir:
|
||||
models_dir = self.regression_models_dir.get_path().strip()
|
||||
if not models_dir:
|
||||
default_models_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling").replace('\\', '/')
|
||||
self.regression_models_dir.set_path(default_models_dir)
|
||||
|
||||
# 4. 自动填充输出目录(自定义回归预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Custom_Regression_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir_widget.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir_widget.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_regression_models_dir(self):
|
||||
"""浏览回归模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "9_Custom_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", default)
|
||||
if dir_path:
|
||||
self.regression_models_dir.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions/Custom_Regression_Prediction")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", default)
|
||||
if dir_path:
|
||||
self.output_dir_widget.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if regression_models_dir:
|
||||
config['custom_regression_dir'] = regression_models_dir
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if formula_csv_path:
|
||||
config['formula_csv_path'] = formula_csv_path
|
||||
output_dir = self.output_dir_widget.get_path()
|
||||
if output_dir:
|
||||
config['output_dir'] = output_dir
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'custom_regression_dir' in config:
|
||||
self.regression_models_dir.set_path(config['custom_regression_dir'])
|
||||
if 'formula_csv_path' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_path'])
|
||||
if 'output_dir' in config:
|
||||
self.output_dir_widget.set_path(config['output_dir'])
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.75"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if not regression_models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择回归模型目录!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_75', {'step8_75': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 面板 - 机器学习预测
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8Panel(QWidget):
|
||||
"""步骤8:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件(用于独立运行)
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录(用于独立运行)
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出路径:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6 界面读取监督模型目录
|
||||
if main_window and hasattr(main_window, 'step6_panel'):
|
||||
step6_widget = getattr(main_window.step6_panel, 'output_dir', None)
|
||||
step6_models_dir = ""
|
||||
if hasattr(step6_widget, 'get_path'):
|
||||
step6_models_dir = step6_widget.get_path() or ""
|
||||
elif hasattr(step6_widget, 'text'):
|
||||
step6_models_dir = step6_widget.text() or ""
|
||||
|
||||
if step6_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_models_dir):
|
||||
step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(机器学习预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step8': self.get_config()}
|
||||
main_window.run_single_step('step8', config)
|
||||
@ -1,513 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step9 面板 - 分布图生成
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
|
||||
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIPELINE_AVAILABLE = False
|
||||
|
||||
|
||||
class Step9BatchThread(QThread):
|
||||
"""专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
|
||||
|
||||
finished_ok = pyqtSignal(int)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.csv_paths = csv_paths
|
||||
self.step9_kwargs = step9_kwargs
|
||||
self.output_dir_optional = (output_dir_optional or "").strip() or None
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
n = len(self.csv_paths)
|
||||
for i, csv_p in enumerate(self.csv_paths):
|
||||
self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
|
||||
kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
|
||||
if self.output_dir_optional:
|
||||
stem = Path(csv_p).stem
|
||||
kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
|
||||
else:
|
||||
kw["output_image_path"] = None
|
||||
pipeline.step9_generate_distribution_map(**kw)
|
||||
self.finished_ok.emit(n)
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Step9Panel(QWidget):
|
||||
"""步骤9:分布图生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._batch_thread = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel(
|
||||
"独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
|
||||
"完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
|
||||
)
|
||||
hint.setWordWrap(True)
|
||||
hint.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
layout.addWidget(hint)
|
||||
|
||||
mode_row = QHBoxLayout()
|
||||
self.mode_single_rb = QRadioButton("单个 CSV 文件")
|
||||
self.mode_folder_rb = QRadioButton("文件夹批量")
|
||||
self._mode_group = QButtonGroup(self)
|
||||
self._mode_group.addButton(self.mode_single_rb, 0)
|
||||
self._mode_group.addButton(self.mode_folder_rb, 1)
|
||||
mode_row.addWidget(self.mode_single_rb)
|
||||
mode_row.addWidget(self.mode_folder_rb)
|
||||
mode_row.addStretch()
|
||||
layout.addLayout(mode_row)
|
||||
|
||||
# ---------- RadioButton 美化样式(选中状态更醒目) ----------
|
||||
radio_style = """
|
||||
QRadioButton {
|
||||
font-size: 14px;
|
||||
spacing: 8px;
|
||||
color: #333333;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
border: 2px solid #999999;
|
||||
border-radius: 9px;
|
||||
background-color: white;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
border: 2px solid #0078d4;
|
||||
background-color: qradialgradient(
|
||||
cx:0.5, cy:0.5, radius:0.5,
|
||||
fx:0.5, fy:0.5,
|
||||
stop:0 #0078d4,
|
||||
stop:0.6 white,
|
||||
stop:1.0 white
|
||||
);
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #005a9e;
|
||||
}
|
||||
"""
|
||||
self.mode_single_rb.setStyleSheet(radio_style)
|
||||
self.mode_folder_rb.setStyleSheet(radio_style)
|
||||
|
||||
self.prediction_csv_file = FileSelectWidget(
|
||||
"预测结果CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.prediction_csv_file)
|
||||
|
||||
folder_row = QHBoxLayout()
|
||||
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
|
||||
self.prediction_csv_dir_label.setMinimumWidth(120)
|
||||
self.prediction_csv_dir_edit = QLineEdit()
|
||||
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
|
||||
pred_dir_btn = QPushButton("浏览…")
|
||||
pred_dir_btn.setMaximumWidth(80)
|
||||
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
|
||||
folder_row.addWidget(self.prediction_csv_dir_label)
|
||||
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
|
||||
folder_row.addWidget(pred_dir_btn)
|
||||
self._folder_row_widget = QWidget()
|
||||
self._folder_row_widget.setLayout(folder_row)
|
||||
layout.addWidget(self._folder_row_widget)
|
||||
|
||||
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
|
||||
layout.addWidget(self.recursive_csv_cb)
|
||||
|
||||
self.boundary_file = FileSelectWidget(
|
||||
"边界文件:",
|
||||
"Shapefiles (*.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.boundary_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("生成参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.resolution = QDoubleSpinBox()
|
||||
self.resolution.setRange(1, 1000)
|
||||
self.resolution.setValue(30)
|
||||
params_layout.addRow("分辨率(米):", self.resolution)
|
||||
|
||||
self.input_crs = QLineEdit()
|
||||
self.input_crs.setText("EPSG:32651")
|
||||
params_layout.addRow("输入坐标系:", self.input_crs)
|
||||
|
||||
self.output_crs = QLineEdit()
|
||||
self.output_crs.setText("EPSG:4326")
|
||||
params_layout.addRow("输出坐标系:", self.output_crs)
|
||||
|
||||
self.show_points = QCheckBox("显示采样点")
|
||||
params_layout.addRow("", self.show_points)
|
||||
|
||||
self.use_diffusion = QCheckBox("启用距离扩散")
|
||||
self.use_diffusion.setChecked(True)
|
||||
params_layout.addRow("", self.use_diffusion)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出目录
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出分布图目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 信号绑定与初始状态
|
||||
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
|
||||
self._toggle_input_mode() # 根据默认值设置初始显示状态
|
||||
|
||||
def _toggle_input_mode(self):
|
||||
"""槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
# 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
|
||||
self.prediction_csv_file.setVisible(not folder_mode)
|
||||
# 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
|
||||
self._folder_row_widget.setVisible(folder_mode)
|
||||
self.recursive_csv_cb.setVisible(folder_mode)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_prediction_csv_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
|
||||
if d:
|
||||
self.prediction_csv_dir_edit.setText(d)
|
||||
|
||||
def _collect_csv_paths_from_folder(self) -> List[str]:
|
||||
folder = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not folder or not os.path.isdir(folder):
|
||||
return []
|
||||
root = Path(folder)
|
||||
if self.recursive_csv_cb.isChecked():
|
||||
files = sorted(root.rglob("*.csv"))
|
||||
else:
|
||||
files = sorted(root.glob("*.csv"))
|
||||
return [str(p) for p in files if p.is_file()]
|
||||
|
||||
def _step9_base_pipeline_kwargs(self) -> dict:
|
||||
return {
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
|
||||
def get_config(self):
|
||||
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
config = {
|
||||
'step9_batch_mode': 'folder' if folder_mode else 'single',
|
||||
'prediction_csv_dir': pred_dir if pred_dir else None,
|
||||
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
|
||||
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
out_dir = (self.output_dir.get_path() or "").strip()
|
||||
if not folder_mode and pred_csv and out_dir:
|
||||
stem = Path(pred_csv).stem
|
||||
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
|
||||
else:
|
||||
config['output_image_path'] = None
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
mode = config.get('step9_batch_mode', 'single')
|
||||
if mode == 'folder':
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
else:
|
||||
self.mode_single_rb.setChecked(True)
|
||||
if config.get('prediction_csv_dir'):
|
||||
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
|
||||
if 'recursive_csv_scan' in config:
|
||||
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
|
||||
if 'prediction_csv_path' in config and config['prediction_csv_path']:
|
||||
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
|
||||
if 'boundary_shp_path' in config:
|
||||
self.boundary_file.set_path(config['boundary_shp_path'])
|
||||
if 'resolution' in config:
|
||||
self.resolution.setValue(config['resolution'])
|
||||
if 'input_crs' in config:
|
||||
self.input_crs.setText(config['input_crs'])
|
||||
if 'output_crs' in config:
|
||||
self.output_crs.setText(config['output_crs'])
|
||||
if 'show_sample_points' in config:
|
||||
self.show_points.setChecked(config['show_sample_points'])
|
||||
if 'use_distance_diffusion' in config:
|
||||
self.use_diffusion.setChecked(config['use_distance_diffusion'])
|
||||
if 'output_dir' in config and config['output_dir']:
|
||||
self.output_dir.set_path(str(config['output_dir']))
|
||||
elif config.get('output_image_path'):
|
||||
p = Path(str(config['output_image_path']))
|
||||
if p.parent and str(p.parent) != '.':
|
||||
self.output_dir.set_path(str(p.parent))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充预测结果目录
|
||||
|
||||
优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录;
|
||||
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
if not main_window:
|
||||
return
|
||||
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(优先)
|
||||
pred_dir = None
|
||||
if hasattr(main_window, 'step8_panel'):
|
||||
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
||||
step8_output = ""
|
||||
if hasattr(step8_widget, 'get_path'):
|
||||
step8_output = step8_widget.get_path() or ""
|
||||
elif hasattr(step8_widget, 'text'):
|
||||
step8_output = step8_widget.text() or ""
|
||||
|
||||
if step8_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_output):
|
||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_output).parent)
|
||||
|
||||
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
||||
step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
|
||||
step8_5_output = ""
|
||||
if hasattr(step8_5_widget, 'get_path'):
|
||||
step8_5_output = step8_5_widget.get_path() or ""
|
||||
elif hasattr(step8_5_widget, 'text'):
|
||||
step8_5_output = step8_5_widget.text() or ""
|
||||
|
||||
if step8_5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_5_output):
|
||||
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_5_output).parent)
|
||||
|
||||
# 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_75_panel'):
|
||||
step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
|
||||
step8_75_output = ""
|
||||
if hasattr(step8_75_widget, 'get_path'):
|
||||
step8_75_output = step8_75_widget.get_path() or ""
|
||||
elif hasattr(step8_75_widget, 'text'):
|
||||
step8_75_output = step8_75_widget.text() or ""
|
||||
|
||||
if step8_75_output:
|
||||
pred_dir = step8_75_output
|
||||
|
||||
# 自动填入"预测CSV目录"(文件夹批量模式)
|
||||
if pred_dir:
|
||||
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not existing_dir:
|
||||
self.prediction_csv_dir_edit.setText(pred_dir)
|
||||
# 切换到文件夹批量模式
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
|
||||
# 4. 自动填充输出目录(14_visualization)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "14_visualization")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤9"""
|
||||
if self._batch_thread and self._batch_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
|
||||
return
|
||||
|
||||
boundary_shp_path = self.boundary_file.get_path()
|
||||
if not boundary_shp_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
|
||||
return
|
||||
if not os.path.exists(boundary_shp_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
|
||||
return
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if not parent or not hasattr(parent, 'run_single_step'):
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
return
|
||||
|
||||
if self.mode_folder_rb.isChecked():
|
||||
csv_list = self._collect_csv_paths_from_folder()
|
||||
if not csv_list:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
|
||||
"可勾选「包含子文件夹」以递归扫描。",
|
||||
)
|
||||
return
|
||||
if not PIPELINE_AVAILABLE:
|
||||
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
|
||||
return
|
||||
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
|
||||
work_dir = str(work_dir)
|
||||
base_kw = self._step9_base_pipeline_kwargs()
|
||||
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
|
||||
self.run_button.setEnabled(False)
|
||||
self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
|
||||
main_win = parent
|
||||
|
||||
def _batch_log(msg, lvl):
|
||||
if hasattr(main_win, "log_message"):
|
||||
main_win.log_message(msg, lvl)
|
||||
|
||||
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
|
||||
self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
|
||||
self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
|
||||
self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
|
||||
self._batch_thread.start()
|
||||
if hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
|
||||
return
|
||||
|
||||
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
|
||||
if not prediction_csv_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
|
||||
)
|
||||
return
|
||||
if not os.path.isfile(prediction_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
parent.run_single_step('step9', {'step9': config})
|
||||
|
||||
def _on_step9_batch_ok(self, n: int):
|
||||
QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
|
||||
|
||||
def _on_step9_batch_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(err, "error")
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -14,67 +14,17 @@ def rasterize_envi_xml(shp_filepath):
|
||||
|
||||
@timeit
|
||||
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=None):
|
||||
# ---------- 防御性处理:路径标准化 ----------
|
||||
shp_filepath = os.path.abspath(shp_filepath).replace('\\', '/')
|
||||
print(f"[DEBUG rasterize_shp] 标准化后的 SHP 路径: {shp_filepath}")
|
||||
|
||||
# 检查伴随文件完整性
|
||||
shp_base = os.path.splitext(shp_filepath)[0]
|
||||
for ext in ['.dbf', '.shx', '.prj']:
|
||||
companion = shp_base + ext
|
||||
if os.path.exists(companion):
|
||||
print(f"[DEBUG rasterize_shp] 伴随文件存在: {companion}")
|
||||
else:
|
||||
print(f"[WARNING rasterize_shp] 伴随文件缺失: {companion}")
|
||||
|
||||
# 确保 GDAL/OGR 驱动已注册
|
||||
gdal.AllRegister()
|
||||
ogr.RegisterAll()
|
||||
|
||||
# 检查 ESRI Shapefile 驱动
|
||||
driver = ogr.GetDriverByName("ESRI Shapefile")
|
||||
if driver is None:
|
||||
raise RuntimeError(
|
||||
"系统中未找到 ESRI Shapefile 驱动!请检查 GDAL 是否正确安装及是否包含 Shapefile 支持。"
|
||||
)
|
||||
print(f"[DEBUG rasterize_shp] ESRI Shapefile 驱动: {driver.GetName()}")
|
||||
|
||||
# 打开参考影像获取尺寸信息
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开参考影像文件: {img_path}")
|
||||
im_width = dataset.RasterXSize
|
||||
im_height = dataset.RasterYSize
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
|
||||
del dataset
|
||||
|
||||
# ---------- 打开 SHP 文件(双重尝试获取详细错误) ----------
|
||||
# Open the data source and read in the extent
|
||||
source_ds = gdal.OpenEx(shp_filepath, gdal.OF_VECTOR)
|
||||
if source_ds is None:
|
||||
# gdal.OpenEx 失败,尝试 ogr.Open 获取更详细的错误信息
|
||||
try:
|
||||
ogr_ds = ogr.Open(shp_filepath)
|
||||
except Exception as ogr_err:
|
||||
raise RuntimeError(
|
||||
f"GDAL/OGR 无法打开 SHP 文件(详细原因):\n"
|
||||
f" ogr.Open 抛出异常: {str(ogr_err)}\n"
|
||||
f" 文件路径: {shp_filepath}\n"
|
||||
f"常见原因:\n"
|
||||
f" 1. 路径包含中文/空格/特殊字符(建议复制到纯英文路径下重试)\n"
|
||||
f" 2. .dbf 或 .shx 伴随文件缺失或损坏\n"
|
||||
f" 3. GDAL 未注册 ESRI Shapefile 驱动\n"
|
||||
f" 4. 文件被其他程序锁定"
|
||||
)
|
||||
if ogr_ds is None:
|
||||
raise RuntimeError(
|
||||
f"ogr.Open 和 gdal.OpenEx 均返回 None,无法打开 SHP 文件。\n"
|
||||
f"文件路径: {shp_filepath}\n"
|
||||
f"请检查:\n"
|
||||
f" 1. 所有伴随文件(.dbf/.shx/.prg)是否齐全\n"
|
||||
f" 2. 文件是否被其他程序占用\n"
|
||||
f" 3. 路径中是否存在不支持的字符"
|
||||
)
|
||||
raise ValueError(f"无法打开shapefile: {shp_filepath}")
|
||||
|
||||
# 检查图层数量,如果有多层,指定使用第一层
|
||||
layer_count = source_ds.GetLayerCount()
|
||||
|
||||
@ -1,21 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
采样点生成模块 - 提供分块采样和光谱数据提取功能
|
||||
"""
|
||||
|
||||
import os
|
||||
from src.utils.util import *
|
||||
import math
|
||||
|
||||
# GDAL 环境变量保护(放在最前面,防止路径/编码问题)
|
||||
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
|
||||
os.environ['SHAPE_ENCODING'] = 'UTF-8'
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from osgeo import gdal, ogr
|
||||
import spectral
|
||||
from scipy import ndimage
|
||||
from src.utils.util import write_bands
|
||||
|
||||
try:
|
||||
from skimage import morphology
|
||||
from skimage.morphology import skeletonize, medial_axis
|
||||
@ -98,12 +87,6 @@ def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=
|
||||
ogr.UseExceptions()
|
||||
|
||||
try:
|
||||
# ---------- 路径归一化 + 存在性检查 ----------
|
||||
bil_file = os.path.abspath(bil_file).replace('\\', '/')
|
||||
print(f"[路径检查] 去耀斑影像: {bil_file}")
|
||||
if not os.path.exists(bil_file):
|
||||
raise FileNotFoundError(f"【后端错误】无法在磁盘上找到指定的去耀斑影像: {bil_file}")
|
||||
|
||||
# 打开bil文件
|
||||
dataset_bil = gdal.Open(bil_file)
|
||||
if dataset_bil is None:
|
||||
|
||||
Reference in New Issue
Block a user