Compare commits

43 Commits

Author SHA1 Message Date
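DXC
2a4a7ec7be refactor(packaging): unify PyInstaller resource paths via get_resource_path 2026-05-10 18:02:59 +08:00
DXC
5a55be286f refactor(gui): thematic map UI polish / directory lookup / mask inheritance / hide redundant regression steps 2026-05-10 17:02:58 +08:00
DXC
9ba39a7bff fix(step7): eliminate UnboundLocalError: move the Path/os imports to the first line of the function to avoid clashing with later local assignments 2026-05-10 16:45:07 +08:00
DXC
d15a7a1e2b fix(step7): remove smart backtracking of the glint path: replace automatic interception of the .dat placeholder with a glob search for the real .bsq output in 3_deglint 2026-05-10 16:34:04 +08:00
DXC
6d4d802ffe fix(step5/step5.5): substitute .dat for the mask .shp; inject np.nan/np.inf into the band_math.eval namespace 2026-05-10 16:20:51 +08:00
DXC
abac272b31 fix(step3/step7): close path gaps: _safe_rename now uses os.rename; generate_sampling_points enforces a .bsq suffix check on entry 2026-05-10 16:04:26 +08:00
DXC
95d30d8d81 Fix bug where the training summary report failed to recognize .joblib models 2026-05-10 15:45:56 +08:00
DXC
375fea77b9 Fix broken import paths in the post-processing module 2026-05-10 15:24:50 +08:00
DXC
8c7c995985 Fix broken path in Step 3 deglinting + standardize UI default paths 2026-05-10 15:11:01 +08:00
DXC
f96c55f361 refactor(step6): Chinese localization of the Step 6 machine-learning modeling UI + nothing selected by default + clean-up via reverse mapping in the backend 2026-05-10 14:58:57 +08:00
DXC
14278739bf refactor(step4): extract the Steps layer - move step4~step9 business logic into standalone modules 2026-05-09 17:55:58 +08:00
DXC
d0eb458392 refactor(step4): extract the Steps layer - move step1~step3 business logic into standalone modules 2026-05-09 17:30:49 +08:00
DXC
605ec86108 Rename 2026-05-09 17:25:40 +08:00
DXC
dcbcc043e4 refactor: incremental modular refactoring: split the visualization, utility, and algorithm layers into standalone modules 2026-05-09 17:18:34 +08:00
DXC
b2b90050dc Rename 2026-05-09 16:32:55 +08:00
DXC
9d39e61161 fix: automatically convert the .shp water mask to .dat, resolving the interpolation function error 2026-05-09 14:35:58 +08:00
DXC
82af2d75d3 feat: rework the Kutser algorithm for chunked read/write + rename the GUI title to Mega Water 2026-05-09 13:30:33 +08:00
DXC
820986d975 fix: SUGAR _collect_glint_pixel_values: fix out-of-range error caused by del on a list inside a loop 2026-05-09 12:02:49 +08:00
DXC
a14d40f28d fix: chunked read/write rework: fix Hedley covariance shape broadcasting error and SUGAR list out-of-range error 2026-05-09 11:58:40 +08:00
DXC
56de4b6fc4 Fix three issues in the Step3 deglint module: variable name clash in the SUGAR branch, wrong Kutser parameter defaults, band index types uniformly cast to int 2026-05-09 09:58:50 +08:00
DXC
4d23a65a21 Fix Hedley band index type: cast NIR_band to int in __init__ 2026-05-09 09:51:37 +08:00
DXC
27d6db3141 Fix Kutser band index types: cast oxy/lower_oxy/upper_oxy/NIR_band to int in __init__ 2026-05-09 09:50:34 +08:00
DXC
6d6bb6e402 Fix NumPy compatibility: np.percentile parameter interpolation changed to method (Kutser/Hedley/SUGAR) 2026-05-09 09:36:11 +08:00
DXC
d7b5c45dd4 Fix Step7 sampling-point layout path reading: GDAL environment variable protection + path normalization + FileNotFoundError check + fallback path scan for the water mask 2026-05-08 18:05:11 +08:00
DXC
3c0bd29275 UI: restyle Step9 radio buttons (solid blue dot when checked + hover effect) 2026-05-08 18:04:02 +08:00
DXC
ca12517d41 Visualization panel: install event filters throughout to fully block scroll pass-through on Ctrl+wheel 2026-05-08 16:42:04 +08:00
DXC
33b6a918aa Visualization panel: add toolbar hint text (Ctrl+wheel to zoom / wheel to scroll / drag to pan) 2026-05-08 16:28:05 +08:00
DXC
8c7458bbe4 Panel interaction enhancements: UI linkage improvements and stability fixes for the Step6.75/8/8.5/8.75/9 panels 2026-05-08 16:17:19 +08:00
DXC
9b9365d823 Visualization panel: complete Chinese localization of glint file names; fix wheel zoom logic (Ctrl+wheel zooms, plain wheel scrolls) 2026-05-08 16:14:22 +08:00
DXC
7cadd7e437 Visualization panel refactor: dual-engine Chinese localization of directory/file names, smart emoji icon assignment 2026-05-08 14:43:32 +08:00
DXC
f24aa4f555 Fix PyQt 0xC0000409 crash: fix window attribute name clash, add global exception hook, robust refactor of the visualization panel 2026-05-08 14:21:50 +08:00
DXC
5af466b2d3 UX upgrade: path memory, deep scan for visualization, Chinese localization of file names 2026-05-08 13:33:19 +08:00
DXC
a4e6747b54 Fix relative-path passing in all Panels to prevent FileNotFoundError 2026-05-08 09:40:33 +08:00
DXC
0f36da742f Fix parameter passing and file reading in multi-step runs 2026-05-08 09:27:07 +08:00
DXC
742bc392a5 UI improvements 2026-05-07 16:49:24 +08:00
DXC
a645c64987 UI improvements 2026-05-07 14:46:59 +08:00
DXC
c12b9d8d8a UI improvements 2026-05-07 14:23:58 +08:00
DXC
dc33ee260d fix(Step3): fix Step3 band-range signal misplaced in multiple classes; add dynamic band-range limits; improve deglint algorithm invocation 2026-05-06 14:41:41 +08:00
DXC
6e51d1482c feat(Step2Panel): improve interaction in the glint detection step 2026-05-06 13:06:30 +08:00
DXC
9cc89bcd69 feat(FileSelectWidget): open the file dialog in the directory of the current input path 2026-05-06 11:54:00 +08:00
DXC
15cc14b8e1 chore: simplify requirements.txt and add the pykrige dependency 2026-05-06 11:41:51 +08:00
DXC
8d36c23524 refactor: in-depth rework of Step1Panel UI linkage logic 2026-05-06 11:41:21 +08:00
DXC
71e3aaa8cd feat: core UX upgrade for the water quality analysis system 2026-05-06 11:33:35 +08:00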
60 changed files with 13917 additions and 7929 deletions


@ -0,0 +1,336 @@
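# Step1Panel UI Linkage Logic: Optimization Notes
## 📋 Overview of Changes
This change reworks the UI linkage logic of Step1Panel (the water mask generation step). It addresses:
1. ✅ Binding the output-mask widget to the radio buttons
2. ✅ Mixed slashes in displayed paths
3. ✅ Compatibility with the underlying run logic
---
## 🎯 Core Improvements
### 1. Output-mask visibility bound to the radio buttons
#### Location: `Step1Panel.update_ui_state()`
**Problems with the old logic**
- The output-mask field was shown in both modes, which does not match the business logic
- No output path is needed in "Use existing mask file" mode
**New logic**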
```python
def update_ui_state(self):
    """Update the UI state for the selected mask generation mode (show/hide widgets)."""
    use_ndwi = self.use_ndwi_radio.isChecked()
    # Show or hide widgets dynamically
    if use_ndwi:
        # NDWI mode: hide the mask file; show the NDWI parameters and the output mask
        self.mask_file.setVisible(False)
        self.ndwi_group.setVisible(True)
        self.output_file.setVisible(True)  # show the output-mask path
        # On switching to NDWI mode, auto-fill the output path if a work directory is set
        if hasattr(self, 'work_dir') and self.work_dir:
            self._auto_fill_output_path()
    else:
        # Existing-mask mode: show the mask file; hide the NDWI parameters and the output mask
        self.mask_file.setVisible(True)
        self.ndwi_group.setVisible(False)
        self.output_file.setVisible(False)  # hide the output-mask path
    # The reference image is shown in both modes
    self.img_file.setVisible(True)
```
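**Behavior**
| Mode | Mask file | NDWI parameter group | Output mask | Reference image |
|------|---------|-----------|---------|---------|
| Use existing mask file | ✅ Shown | ❌ Hidden | ❌ Hidden | ✅ Shown |
| Generate automatically with NDWI | ❌ Hidden | ✅ Shown | ✅ Shown | ✅ Shown |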
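
The table above can also be written as a small Qt-free function, which makes the mode-to-visibility mapping easy to unit-test. This is only a minimal sketch: `step1_visibility` and its dictionary keys are hypothetical names chosen to mirror the widget attributes used in `update_ui_state()`; they are not part of the actual code base.

```python
from typing import Dict


def step1_visibility(use_ndwi: bool) -> Dict[str, bool]:
    """Return which Step1Panel widgets should be visible for the given mode.

    Hypothetical helper for illustration; keys mirror the widget attribute
    names that appear in update_ui_state().
    """
    return {
        "mask_file": not use_ndwi,   # existing-mask input: only in existing-mask mode
        "ndwi_group": use_ndwi,      # NDWI parameters: only in NDWI mode
        "output_file": use_ndwi,     # output-mask path: only in NDWI mode
        "img_file": True,            # reference image: shown in both modes
    }
```

With such a function, `update_ui_state()` could loop over the result and call `setVisible()` on each widget, keeping the table and the code in one place.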
---
### 2. Fix mixed slashes in paths
#### Location 1: `Step1Panel._auto_fill_output_path()` (new method)
**Core improvement**: use forward slashes `/` throughout, avoiding a mix of `\` and `/` on Windows
```python
def _auto_fill_output_path(self):
    """
    Auto-fill the output-mask path (NDWI mode only).
    The path always uses forward slashes, so separators are never mixed.
    """
    if not hasattr(self, 'work_dir') or not self.work_dir:
        return
    # Build the full path of the output mask
    output_dir = os.path.join(self.work_dir, "1_water_mask")
    os.makedirs(output_dir, exist_ok=True)  # make sure the directory exists
    # Normalize to forward slashes to avoid mixing \ and /
    default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
    self.output_file.set_path(default_output_path)
```
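**Key technical points**
- Build the path with `os.path.join()` (portable across operating systems)
- Convert to forward slashes at the end with `.replace('\\', '/')`
- Do the conversion before the path is displayed, so the user always sees a consistent path
#### Location 2: `Step1Panel.update_work_directory()`
**Problems with the old logic**
- The path was filled in at startup regardless of the currently selected mode
- Slashes were not normalized
**New logic**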
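```python
def update_work_directory(self, work_dir):
    """
    Store a reference to the work directory for later path auto-fill.
    Args:
        work_dir: path of the work directory
    """
    if not work_dir:
        return
    # Store the work directory reference
    self.work_dir = work_dir
    # If NDWI mode is currently selected, fill the output path right away
    if self.use_ndwi_radio.isChecked():
        self._auto_fill_output_path()
```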
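**Notes on the change**
- The method stores the work directory reference and no longer fills the path unconditionally
- `_auto_fill_output_path()` is called only in NDWI mode
- This works together with the mode-switch trigger in `update_ui_state()`
---
### 3. Compatibility with the underlying run logic
#### Location: `Step1Panel.get_config()`
**Problems with the old logic**
- `output_path` was passed to the underlying Pipeline in both modes
- Passing an empty path in "Use existing mask file" mode could cause errors downstream
**New logic**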
```python
def get_config(self):
    """Collect the configuration."""
    use_ndwi = self.use_ndwi_radio.isChecked()
    config = {
        'mask_path': None if use_ndwi else self.mask_file.get_path(),
        'use_ndwi': use_ndwi,
        'ndwi_threshold': self.ndwi_threshold.value()
    }
    # Reference image path (may be needed in both modes)
    img_path = self.img_file.get_path()
    if img_path:
        config['img_path'] = img_path
    # The output path is only meaningful in NDWI mode
    if use_ndwi:
        output_path = self.output_file.get_path()
        if output_path:
            config['output_path'] = output_path
    else:
        # With an existing mask, pass no output_path so the Pipeline does not try to save a file
        config['output_path'] = None
    return config
```
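**Key improvements**
- Whether `output_path` is passed now depends on the `use_ndwi` mode
- "Use existing mask file" mode: `output_path` is forced to `None`
- "Generate automatically with NDWI" mode: the user-selected path is passed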
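
To check the two modes without starting Qt, the same decision logic can be isolated in a plain function. The sketch below is an assumption-laden illustration: `build_step1_config` is a hypothetical name, and its arguments stand in for the values that `get_config()` reads from the widgets.

```python
from typing import Any, Dict, Optional


def build_step1_config(
    use_ndwi: bool,
    mask_path: Optional[str],
    img_path: Optional[str],
    output_path: Optional[str],
    ndwi_threshold: float,
) -> Dict[str, Any]:
    """Hypothetical Qt-free mirror of the get_config() logic shown above."""
    config: Dict[str, Any] = {
        "mask_path": None if use_ndwi else mask_path,
        "use_ndwi": use_ndwi,
        "ndwi_threshold": ndwi_threshold,
    }
    if img_path:
        config["img_path"] = img_path
    if use_ndwi:
        # Only a non-empty path is passed on; otherwise the key is left out
        if output_path:
            config["output_path"] = output_path
    else:
        # Existing-mask mode: explicitly None so nothing is saved downstream
        config["output_path"] = None
    return config
```

Note that in NDWI mode with an empty output field the `output_path` key is absent rather than `None`, so a consumer would be safer reading it with `config.get("output_path")`.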
**Compatibility with the lower layer**
```python
# Handling inside the Pipeline (implemented in an earlier commit; excerpt, other parameters elided)
def step1_generate_water_mask(..., output_path: Optional[str] = None):
    if use_ndwi:
        if output_path:
            ndwi_output_path = output_path  # use the user-specified path
        else:
            ndwi_output_path = str(self.water_mask_dir / "water_mask_from_ndwi.dat")
```
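
The fallback in that excerpt can be read as a standalone rule: use the caller's path if one was given, otherwise fall back to a default file inside the water-mask directory. A minimal runnable sketch follows; the function name `resolve_ndwi_output_path` is hypothetical, while the default file name is taken from the excerpt.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_ndwi_output_path(
    output_path: Optional[str],
    water_mask_dir: Union[str, Path],
) -> str:
    """Pick the NDWI mask output path (hypothetical helper, for illustration)."""
    if output_path:
        # A non-empty user-specified path wins
        return output_path
    # None or "" falls back to the default name inside the water-mask directory
    return str(Path(water_mask_dir) / "water_mask_from_ndwi.dat")
```

Because both `None` and an empty string take the fallback branch, a missing `output_path` key in the GUI config does not have to be treated as an error.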
---
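### 4. Main window initialization logic
#### Location: `WaterQualityGUI._auto_fill_output_paths()`
**Problems with the old logic**
- At startup, `set_path()` was called directly to fill the output-mask path
- The current radio-button state was not taken into account
**New logic**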
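```python
def _auto_fill_output_paths(self):
    """
    Auto-fill the output paths of each step from the work directory.
    Note: the Step1 output path is controlled by update_work_directory() according to the mode.
    """
    if not self.work_dir:
        return
    # Step1: pass only the work directory reference; do not fill the path directly.
    # Step1Panel fills the path itself, depending on the radio-button state.
    if hasattr(self, 'step1_panel'):
        self.step1_panel.update_work_directory(self.work_dir)
```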
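**Notes on the change**
- The main window only passes the work directory reference
- Step1Panel decides internally, based on the mode, whether to fill the path
- This decouples the main window from the sub-panel's logic
---
## 🔄 Full Interaction Flow
### Scenario 1: startup, with "Use existing mask file" selected by default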
```
1. Main window starts → QTimer.singleShot(100) delays the dialog
2. User selects the work directory D:\work
3. _auto_fill_output_paths() calls step1_panel.update_work_directory(work_dir)
4. Step1Panel.update_work_directory() stores self.work_dir = "D:\work"
5. Check the current mode: use_existing_radio.isChecked() = True
6. _auto_fill_output_path() is not called; the output mask stays hidden
7. What the user sees:
   ✅ Mask file input (shown)
   ✅ Reference image input (shown)
   ❌ NDWI parameter group (hidden)
   ❌ Output mask input (hidden)
```
### Scenario 2: the user switches to "Generate automatically with NDWI"
```
1. User clicks the "Generate automatically with NDWI" radio button
2. The toggled signal fires → update_ui_state()
3. use_ndwi = True
4. Show/hide logic runs:
   - self.mask_file.setVisible(False)   # hide the mask file
   - self.ndwi_group.setVisible(True)   # show the NDWI parameter group
   - self.output_file.setVisible(True)  # show the output mask
5. Check the work directory: hasattr(self, 'work_dir') = True
6. Call self._auto_fill_output_path()
7. Build the path:
   output_dir = os.path.join("D:\work", "1_water_mask")
   path = os.path.join(output_dir, "water_mask_out.dat")
   formatted_path = path.replace('\\', '/')
   # Result: D:/work/1_water_mask/water_mask_out.dat
8. The path is filled into the output mask input
9. What the user sees:
   ❌ Mask file input (hidden)
   ✅ Reference image input (shown)
   ✅ NDWI parameter group (shown)
   ✅ Output mask input (shown, filled with D:/work/1_water_mask/water_mask_out.dat)
```
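
Scenarios 1 and 2 can be replayed without a GUI by modelling only the state involved. The class below is a hedged sketch, not the real panel: `Step1State` is a hypothetical stand-in, it uses `ntpath` so that Windows-style joining is reproduced on any operating system (the real code uses `os.path`), and it skips the `os.makedirs()` call.

```python
import ntpath
from typing import Optional


class Step1State:
    """Qt-free stand-in for Step1Panel's mode and path state (hypothetical)."""

    def __init__(self) -> None:
        self.use_ndwi = False                      # default: use existing mask file
        self.work_dir: Optional[str] = None
        self.output_path: Optional[str] = None

    def _auto_fill_output_path(self) -> None:
        if not self.work_dir:
            return
        # ntpath.join mimics os.path.join on Windows; then normalize the slashes
        path = ntpath.join(self.work_dir, "1_water_mask", "water_mask_out.dat")
        self.output_path = path.replace("\\", "/")

    def update_work_directory(self, work_dir: str) -> None:
        if not work_dir:
            return
        self.work_dir = work_dir
        if self.use_ndwi:                          # fill only in NDWI mode
            self._auto_fill_output_path()

    def set_mode(self, use_ndwi: bool) -> None:
        """Plays the role of the toggled signal followed by update_ui_state()."""
        self.use_ndwi = use_ndwi
        if use_ndwi:
            self._auto_fill_output_path()


state = Step1State()
state.update_work_directory(r"D:\work")   # scenario 1: nothing is filled yet
state.set_mode(True)                      # scenario 2: the path is filled on switching
print(state.output_path)
```

Keeping the state transitions in a plain object like this is one way to test the flow in isolation; the real panel would still need the widget calls shown earlier.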
### Scenario 3: the user clicks "Run this step independently"
#### Current selection: "Use existing mask file"
```python
config = {
    'mask_path': "D:/data/existing_mask.dat",  # existing mask chosen by the user
    'use_ndwi': False,
    'ndwi_threshold': 0.4,
    'img_path': "D:/data/image.dat",
    'output_path': None  # ✅ forced to None; no file is saved
}
```
#### Current selection: "Generate automatically with NDWI"
```python
config = {
    'mask_path': None,  # NDWI mode needs no existing mask
    'use_ndwi': True,
    'ndwi_threshold': 0.4,
    'img_path': "D:/data/image.dat",
    'output_path': "D:/work/1_water_mask/water_mask_out.dat"  # ✅ output path is passed
}
```
---
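## ✅ Test Checklist
### UI display tests
- [ ] After startup, "Use existing mask file" is selected by default and the output mask input is hidden
- [ ] Switching to "Generate automatically with NDWI" shows the output mask input and auto-fills its path
- [ ] Switching back to "Use existing mask file" hides the output mask input again
- [ ] All auto-filled paths use forward slashes `/`, with no `\` mixed in
### Path format tests
- [ ] Work directory `D:\work` → the output path should display as `D:/work/1_water_mask/water_mask_out.dat`
- [ ] Work directory `C:\Users\Test\Documents` → the output path should display as `C:/Users/Test/Documents/1_water_mask/water_mask_out.dat`
### Run logic tests
- [ ] Run in "Use existing mask file" mode: verify `config['output_path'] is None`
- [ ] Run in "Generate automatically with NDWI" mode: verify `config['output_path']` is a valid path string
- [ ] The underlying Pipeline does not raise when it receives `output_path=None`
---
## 📝 Summary of Code Changes
| File | Change | Lines changed |
|------|---------|---------|
| `src/gui/water_quality_gui.py` | Step1Panel.update_ui_state() | +6 / -3 |
| `src/gui/water_quality_gui.py` | Step1Panel.update_work_directory() | +10 / -8 |
| `src/gui/water_quality_gui.py` | Step1Panel._auto_fill_output_path() (new) | +15 / 0 |
| `src/gui/water_quality_gui.py` | Step1Panel.get_config() | +12 / -6 |
| `src/gui/water_quality_gui.py` | WaterQualityGUI._auto_fill_output_paths() | +3 / -4 |
**Total**: about **+46 / -21** lines
---
## 🎯 Results
### User experience
1. ✅ Simpler UI: widgets that are not needed are hidden automatically
2. ✅ Consistent paths: all displayed paths use forward slashes
3. ✅ More automation: the output path is filled in automatically when the mode changes
### Code quality
1. ✅ Separation of responsibilities: the main window no longer fills the sub-panel's paths directly
2. ✅ Cohesion: Step1Panel manages its own visibility and paths
3. ✅ Compatibility: the underlying Pipeline never receives an invalid output_path
### Maintainability
1. ✅ New `_auto_fill_output_path()` method with a single responsibility
2. ✅ Path formatting logic lives in one place, so it is easy to change
3. ✅ Clear comments describe the behavior in each mode
---
## 🔧 Possible Follow-ups
If other steps need similar path filling, consider:
1. Extracting a shared `format_path_separator(path)` helper into a utility class
2. Adding built-in path formatting support to the `FileSelectWidget` class
3. Adding uniform validation and formatting to all path inputs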
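
As an illustration of the first follow-up, a shared helper might look like the sketch below. Only the name `format_path_separator` comes from this document; the signature and behavior are assumptions. `ntpath` is used in the demo so that Windows joining behavior is reproduced on any OS.

```python
import ntpath
from os import PathLike, fspath
from typing import Union


def format_path_separator(path: Union[str, PathLike]) -> str:
    """Return the path as a string that uses forward slashes only (sketch)."""
    return fspath(path).replace("\\", "/")


# Windows-style join keeps the separators already present in its arguments
# and adds backslashes between components, which is how mixed slashes arise.
mixed = ntpath.join("D:/work", "1_water_mask", "water_mask_out.dat")
print(mixed)                          # D:/work\1_water_mask\water_mask_out.dat
print(format_path_separator(mixed))   # D:/work/1_water_mask/water_mask_out.dat
```

One caveat with this approach: on POSIX systems a backslash is a legal file-name character, so a blanket replacement is only appropriate for paths that are known to be Windows-style or purely for display.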
---
**Document generated**: 2026-05-06
**Author of changes**: DXC
**Related commit**: (pending)

Binary file not shown.



@ -1,58 +1,26 @@
# Water Quality Parameter Retrieval and Analysis System - Python dependencies
# Install: pip install -r requirements.txt
#
# Notes:
# - If installing GDAL via pip fails on Windows, conda-forge is recommended: conda install -c conda-forge gdal
#   or use a prebuilt GDAL wheel / OSGeo4W, and make sure it matches the rasterio version.
# - Word reports (report_word) and the GUI "Report generation" page depend on python-docx; AI interpretation goes through the Ollama HTTP API,
#   so no extra pip package is needed (a local or remote Ollama deployment is enough).
# ---------- GUI ----------
PyQt5>=5.15.0
# ---------- Scientific computing ----------
# Note: the project's packaging/run logs show Python 3.12, so lower bounds are set to Py3.12-compatible versions
numpy>=1.26.0
scipy>=1.11.0
pandas>=2.0.0
# ---------- Machine learning ----------
scikit-learn>=1.4.0
# xgboost>=2.0.0 # optional; the spec bundles it automatically only if it is installed in the environment
# lightgbm>=4.0.0 # optional; not enabled in the current pipeline by default
# ---------- Geospatial ----------
rasterio>=1.3.9
fiona>=1.9.5
shapely>=2.0.0
geopandas>=0.14.0
pyproj>=3.6.0
spectral>=0.22.0
# ---------- Imaging ----------
opencv-python>=4.5.0
Pillow>=8.0.0
scikit-image>=0.22.0
# ---------- Visualization ----------
matplotlib>=3.8.0
seaborn>=0.11.0
matplotlib-scalebar>=0.8.0
# ---------- Signal processing ----------
PyWavelets>=1.1.0
# ---------- General utilities ----------
joblib>=1.1.0
tqdm>=4.62.0
PyYAML>=6.0
# ---------- Spreadsheet export (.xlsx) ----------
openpyxl>=3.0.0
# ---------- Word report generation ----------
python-docx>=1.1.0
lxml>=4.9.0
# ---------- Packaging (optional; only needed when building the exe) ----------
pyinstaller>=6.0.0
pykrige>=1.7.3


@ -0,0 +1,36 @@
"""
Algorithm layer.
Contains the core numerical routines, such as the interpolation and glint detection algorithms.
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
from src.core.algorithms.glint_detection.detectors import (
    otsu_threshold,
    zscore_threshold,
    percentile_threshold,
    iqr_outlier_detection,
    adaptive_threshold,
    multi_band_glint_detection,
    percentile_stretch,
    filter_large_components,
    create_shoreline_buffer,
    remove_shoreline_buffer,
    calculate_glint_mask,
)
__all__ = [
    # Interpolation
    'interpolate_pixels',
    'interpolate_zero_pixels_batch',
    # Glint detection
    'otsu_threshold',
    'zscore_threshold',
    'percentile_threshold',
    'iqr_outlier_detection',
    'adaptive_threshold',
    'multi_band_glint_detection',
    'percentile_stretch',
    'filter_large_components',
    'create_shoreline_buffer',
    'remove_shoreline_buffer',
    'calculate_glint_mask',
]


@ -0,0 +1,31 @@
"""
Glint detection algorithms.
Contains the core numerical functions for the various glint detection methods.
"""
from src.core.algorithms.glint_detection.detectors import (
    otsu_threshold,
    zscore_threshold,
    percentile_threshold,
    iqr_outlier_detection,
    adaptive_threshold,
    multi_band_glint_detection,
    percentile_stretch,
    filter_large_components,
    create_shoreline_buffer,
    remove_shoreline_buffer,
    calculate_glint_mask,
)
__all__ = [
    'otsu_threshold',
    'zscore_threshold',
    'percentile_threshold',
    'iqr_outlier_detection',
    'adaptive_threshold',
    'multi_band_glint_detection',
    'percentile_stretch',
    'filter_large_components',
    'create_shoreline_buffer',
    'remove_shoreline_buffer',
    'calculate_glint_mask',
]


@ -0,0 +1,595 @@
"""
Glint detection algorithms.
Core numerical functions for glint detection: pure math, no file I/O.
Supported methods: otsu, zscore, percentile, iqr, adaptive, multi_band
This module holds the core algorithms extracted from src/utils/find_severe_glint_area.py.
"""
import numpy as np
from typing import Optional, List, Tuple
from functools import wraps
try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
def timeit(func):
    """Decorator: measure a function's execution time."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        import time
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"[{func.__name__}] elapsed: {end - start:.2f}s")
        return result
    return wrapper
# =============================================================================
# Percentile stretch
# =============================================================================
def percentile_stretch(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    lower_percentile: float = 2,
    upper_percentile: float = 98,
    output_range: Tuple[int, int] = (0, 255)
) -> np.ndarray:
    """
    Normalize with percentile clipping; suited to low-reflectance data.
    Excluding extreme values makes better use of the data's dynamic range.
    Args:
        img: input image array (reflectance values, typically in 0-1)
        data_water_mask: water mask
        lower_percentile: lower percentile used to clip the minimum, default 2
        upper_percentile: upper percentile used to clip the maximum, default 98
        output_range: output range, default (0, 255)
    Returns:
        Normalized image array (integer type)
    """
    valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if len(valid_pixels) == 0:
        return img.astype(np.int32)
    p_lower = np.percentile(valid_pixels, lower_percentile)
    p_upper = np.percentile(valid_pixels, upper_percentile)
    # Widen the percentiles, then fall back to min/max, if the range collapses
    if p_lower >= p_upper:
        p_lower = np.percentile(valid_pixels, 1)
        p_upper = np.percentile(valid_pixels, 99)
        if p_lower >= p_upper:
            p_upper = valid_pixels.max()
            p_lower = valid_pixels.min()
    img_clipped = np.clip(img, p_lower, p_upper)
    if p_upper > p_lower:
        img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
            output_range[1] - output_range[0]
        ) + output_range[0]
    else:
        img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
    return img_stretched.astype(np.int32)
# =============================================================================
# Otsu thresholding
# =============================================================================
def otsu_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    ignore_value: int = 0,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Automatic threshold segmentation based on Otsu's method.
    Finds the best threshold by maximizing the between-class variance.
    Args:
        img: input image array (integer values)
        data_water_mask: water mask
        ignore_value: value to ignore, default 0
        foreground: value for glint pixels, default 1
        background: background value, default 0
    Returns:
        Binary detection result array
    """
    height, width = img.shape
    # Guard: np.max raises on an empty selection when no pixel exceeds ignore_value
    candidates = img[img > ignore_value]
    if candidates.size == 0:
        return np.zeros_like(img, dtype=np.int32)
    max_value = int(np.max(candidates)) + 1
    if max_value < 2:
        max_value = 256
    hist = np.zeros([max_value], np.float32)
    invalid_counter = 0
    for i in range(height):
        for j in range(width):
            if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
                invalid_counter += 1
                continue
            hist[img[i, j]] += 1
    total_valid = height * width - invalid_counter
    if total_valid <= 0:
        return np.zeros_like(img, dtype=np.int32)
    hist /= total_valid
    threshold = 0
    deltaMax = 0
    for i in range(max_value):
        # Class weights on either side of candidate threshold i
        wA = sum(hist[:i + 1])
        wB = sum(hist[i + 1:])
        if wA == 0:
            wA = 1e-10
        if wB == 0:
            wB = 1e-10
        uAtmp = sum(j * hist[j] for j in range(i + 1))
        uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
        uA = uAtmp / wA
        uB = uBtmp / wB
        u = uAtmp + uBtmp
        # Between-class variance
        deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
        if deltaTmp > deltaMax:
            deltaMax = deltaTmp
            threshold = i
    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[img > threshold] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# Z-score threshold detection
# =============================================================================
def zscore_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    z_threshold: float = 2.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Glint detection based on the Z-score (standard score).
    Flags unusually bright pixels statistically, relative to the mean and
    standard deviation of the valid water pixels.
    Args:
        img: input image array
        data_water_mask: water mask
        z_threshold: Z-score threshold, default 2.5 (more than 2.5 standard deviations above the mean)
        foreground: foreground value
        background: background value
    Returns:
        Binary detection result
    """
    valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if len(valid_pixels) == 0:
        return np.zeros_like(img, dtype=np.int32)
    mean_val = np.mean(valid_pixels)
    std_val = np.std(valid_pixels)
    if std_val == 0:
        return np.zeros_like(img, dtype=np.int32)
    z_scores = np.zeros_like(img, dtype=np.float32)
    valid_mask = (data_water_mask > 0) & np.isfinite(img)
    z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[z_scores > z_threshold] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# Percentile threshold detection
# =============================================================================
def percentile_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    percentile: float = 95,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Glint detection based on a percentile.
    Uses a percentile as the threshold, which is more robust to outliers.
    Args:
        img: input image array
        data_water_mask: water mask
        percentile: percentile threshold, default 95 (brighter than 95% of pixel values)
        foreground: foreground value
        background: background value
    Returns:
        Binary detection result
    """
    valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if len(valid_pixels) == 0:
        return np.zeros_like(img, dtype=np.int32)
    threshold_val = np.percentile(valid_pixels, percentile)
    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[img > threshold_val] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# IQR outlier detection
# =============================================================================
def iqr_outlier_detection(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    iqr_multiplier: float = 1.5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Outlier detection based on the IQR (interquartile range).
    Uses the interquartile range to flag unusually bright pixels; insensitive to the data distribution.
    Args:
        img: input image array
        data_water_mask: water mask
        iqr_multiplier: IQR multiplier, default 1.5 (standard outlier rule)
        foreground: foreground value
        background: background value
    Returns:
        Binary detection result
    """
    valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
    if len(valid_pixels) == 0:
        return np.zeros_like(img, dtype=np.int32)
    q1 = np.percentile(valid_pixels, 25)
    q3 = np.percentile(valid_pixels, 75)
    iqr = q3 - q1
    upper_bound = q3 + iqr_multiplier * iqr
    det_img = np.zeros_like(img, dtype=np.int32)
    det_img[img > upper_bound] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# Adaptive threshold detection
# =============================================================================
def adaptive_threshold(
    img: np.ndarray,
    data_water_mask: np.ndarray,
    window_size: int = 15,
    percentile: float = 90,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Adaptive thresholding.
    Thresholds on local statistics, which is more robust to illumination changes.
    Args:
        img: input image array
        data_water_mask: water mask
        window_size: local window size (odd)
        percentile: local percentile threshold
        foreground: foreground value
        background: background value
    Returns:
        Binary detection result
    """
    height, width = img.shape
    if window_size % 2 == 0:
        window_size += 1
    half_window = window_size // 2
    det_img = np.zeros_like(img, dtype=np.int32)
    # Pixels closer than half_window to the image border are not evaluated
    for i in range(half_window, height - half_window):
        for j in range(half_window, width - half_window):
            if data_water_mask[i, j] == 0:
                continue
            local_window = img[i - half_window:i + half_window + 1,
                               j - half_window:j + half_window + 1]
            local_mask = data_water_mask[i - half_window:i + half_window + 1,
                                         j - half_window:j + half_window + 1]
            valid_pixels = local_window[local_mask > 0]
            if len(valid_pixels) > 0:
                local_th = np.percentile(valid_pixels, percentile)
                if img[i, j] > local_th:
                    det_img[i, j] = foreground
    det_img[data_water_mask == 0] = background
    return det_img
# =============================================================================
# Multi-band fused glint detection
# =============================================================================
def multi_band_glint_detection(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    glint_waves: List[float],
    weights: Optional[List[float]] = None,
    method: str = 'zscore',
    z_threshold: float = 2.5,
    percentile: float = 95,
    sub_band_arrays: Optional[List[np.ndarray]] = None
) -> np.ndarray:
    """
    Glint detection that fuses several bands.
    Combining the glint signal of multiple bands makes detection more robust.
    Args:
        nir_band: near-infrared band array (primary band, kept for compatibility)
        water_mask: water mask array
        glint_waves: wavelengths used for detection, e.g. [750, 800, 850]
        weights: per-band weights; equal weights are used if None
        method: detection method ('zscore', 'percentile', 'otsu')
        z_threshold: Z-score threshold (used when method='zscore')
        percentile: percentile threshold (used when method='percentile')
        sub_band_arrays: list of sub-band arrays (if given, one per entry in glint_waves)
    Returns:
        Binary detection result
    """
    if weights is None:
        weights = [1.0 / len(glint_waves)] * len(glint_waves)
    if len(weights) != len(glint_waves):
        raise ValueError("The number of weights must equal the number of wavelengths")
    fused_band = None
    if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
        # Weighted sum of the sub-bands
        for i, band_array in enumerate(sub_band_arrays):
            if fused_band is None:
                fused_band = (band_array * weights[i]).astype(np.float32)
            else:
                fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
    else:
        fused_band = nir_band.astype(np.float32)
    if method == 'otsu':
        stretched = percentile_stretch(fused_band, water_mask, 2, 98)
        return otsu_threshold(stretched, water_mask)
    elif method == 'zscore':
        return zscore_threshold(fused_band, water_mask, z_threshold)
    elif method == 'percentile':
        return percentile_threshold(fused_band, water_mask, percentile)
    else:
        raise ValueError(f"Unsupported method: {method}")
# =============================================================================
# Connected-component filtering
# =============================================================================
def filter_large_components(
    binary_img: np.ndarray,
    max_area: Optional[int] = None,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Remove connected components whose area exceeds a threshold.
    Used to drop large regions (shoreline, shallow water, algal blooms, etc.) and keep small glint patches.
    Args:
        binary_img: binary image
        max_area: maximum component area in pixels; larger components are removed
        foreground: foreground value
        background: background value
    Returns:
        Filtered binary image
    """
    if max_area is None or max_area <= 0:
        return binary_img
    if CV2_AVAILABLE:
        binary_for_label = (binary_img == foreground).astype(np.uint8)
        num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
            binary_for_label, connectivity=8
        )
        if num_features == 0:
            return binary_img
        # Row 0 of stats is the background label, so skip it
        component_sizes = stats[1:, cv2.CC_STAT_AREA]
        keep_labels = np.where(component_sizes <= max_area)[0] + 1
        keep_mask = np.isin(labeled_array, keep_labels)
        filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
        filtered[keep_mask] = foreground
        return filtered
    else:
        from scipy import ndimage
        labeled_array, num_features = ndimage.label(
            (binary_img == foreground).astype(np.int32)
        )
        if num_features == 0:
            return binary_img
        # Area of each component: sum a mask of ones over every label.
        # (The original referenced an undefined loop variable `i` here.)
        component_sizes = ndimage.sum(
            (labeled_array > 0).astype(np.int32),
            labeled_array,
            range(1, num_features + 1)
        )
        keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
        filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
        filtered[keep_mask] = foreground
        return filtered
# =============================================================================
# Shoreline buffer handling
# =============================================================================
def create_shoreline_buffer(
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Create a shoreline buffer mask (buffered inward).
    Used to remove false glint detections near the shore.
    Method: erode the water mask, then subtract the eroded water from the original water,
    which leaves the band along the water's edge.
    Args:
        water_mask: water mask array (water=1, non-water=0)
        buffer_size: buffer size in pixels, default 5
        foreground: foreground value
        background: background value
    Returns:
        Shoreline buffer mask (buffer=1, elsewhere=0)
    """
    if buffer_size <= 0:
        return np.zeros_like(water_mask, dtype=np.int32)
    water_binary = (water_mask > 0).astype(np.uint8)
    structure_size = buffer_size * 2 + 1
    structure = np.ones((structure_size, structure_size), dtype=np.uint8)
    if CV2_AVAILABLE:
        eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
    else:
        from scipy import ndimage
        eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
    buffer_mask = (water_binary - eroded_water).astype(np.int32)
    return buffer_mask
def remove_shoreline_buffer(
    glint_mask: np.ndarray,
    water_mask: np.ndarray,
    buffer_size: int = 5,
    foreground: int = 1,
    background: int = 0
) -> np.ndarray:
    """
    Remove the pixels inside the shoreline buffer from the glint mask.
    Args:
        glint_mask: glint mask array
        water_mask: water mask array
        buffer_size: buffer size in pixels, default 5
        foreground: foreground value
        background: background value
    Returns:
        Glint mask with the shoreline buffer removed
    """
    if buffer_size <= 0:
        return glint_mask
    buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
    cleaned = glint_mask.copy()
    cleaned[buffer_mask > 0] = background
    return cleaned
# =============================================================================
# High-level combined function
# =============================================================================
def calculate_glint_mask(
    nir_band: np.ndarray,
    water_mask: np.ndarray,
    method: str = 'otsu',
    z_threshold: float = 2.5,
    percentile: float = 95,
    iqr_multiplier: float = 1.5,
    window_size: int = 15,
    apply_percentile_stretch: bool = True
) -> np.ndarray:
    """
    Single entry point for computing the glint mask.
    Args:
        nir_band: near-infrared band array
        water_mask: water mask
        method: detection method ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
        z_threshold: Z-score threshold
        percentile: percentile threshold
        iqr_multiplier: IQR multiplier
        window_size: window size for adaptive thresholding
        apply_percentile_stretch: whether to apply the percentile stretch for the otsu and adaptive methods
    Returns:
        Binary glint mask
    """
    if method == 'otsu':
        if apply_percentile_stretch:
            stretched = percentile_stretch(nir_band, water_mask, 2, 98)
            return otsu_threshold(stretched, water_mask)
        else:
            return otsu_threshold(nir_band.astype(np.int32), water_mask)
    elif method == 'zscore':
        return zscore_threshold(nir_band, water_mask, z_threshold)
    elif method == 'percentile':
        return percentile_threshold(nir_band, water_mask, percentile)
    elif method == 'iqr':
        return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
    elif method == 'adaptive':
        if apply_percentile_stretch:
            stretched = percentile_stretch(nir_band, water_mask, 2, 98)
            return adaptive_threshold(stretched, water_mask, window_size, percentile)
        else:
            return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
    else:
        raise ValueError(f"Unsupported method: {method}")
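
Reviewer note on `otsu_threshold`: the histogram is built with a per-pixel Python loop and the class statistics are recomputed from scratch for every candidate threshold, which may be slow on large scenes. The sketch below shows one possible vectorized way to find the same kind of threshold using `np.bincount` and cumulative sums. It is an illustration under assumptions, not the module's implementation: `otsu_threshold_value` is a hypothetical name, the input is assumed to be the already-filtered valid pixels as non-negative integers, and results may differ from the loop version in degenerate cases such as a single occupied level.

```python
import numpy as np


def otsu_threshold_value(values: np.ndarray, n_levels: int) -> int:
    """Return an Otsu threshold for non-negative integer values in [0, n_levels)."""
    hist = np.bincount(values.ravel(), minlength=n_levels).astype(np.float64)
    total = hist.sum()
    if total == 0:
        return 0
    prob = hist / total
    levels = np.arange(prob.size)
    w_a = np.cumsum(prob)                 # weight of class A (levels <= t)
    w_b = 1.0 - w_a                       # weight of class B (levels > t)
    cum_mean = np.cumsum(prob * levels)   # un-normalized mean of class A
    mu_total = cum_mean[-1]
    # Between-class variance: (mu_total * w_a - cum_mean)^2 / (w_a * w_b)
    denom = w_a * w_b
    sigma_b = np.zeros(prob.size)
    ok = denom > 1e-12                    # skip thresholds with an empty class
    sigma_b[ok] = (mu_total * w_a[ok] - cum_mean[ok]) ** 2 / denom[ok]
    return int(np.argmax(sigma_b))


# Two well-separated groups of pixel values
demo = np.array([10] * 50 + [200] * 50)
t = otsu_threshold_value(demo, 256)
print(t, int((demo > t).sum()))
```

In the detector, the input would be something like `img[(data_water_mask > 0) & (img > ignore_value)]`, and the returned threshold would be applied with `img > threshold` as in the existing code.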


@ -0,0 +1,7 @@
"""
Interpolation algorithms.
Contains the core numerical logic for interpolating zero-valued pixels.
"""
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']


@ -0,0 +1,320 @@
"""
Pixel interpolation algorithms.
Core numerical logic for interpolating pixels whose value is 0 in every band.
Supported interpolation methods: nearest, bilinear, spline (RBF), kriging.
"""
import numpy as np
from typing import Optional, Union, Tuple, List
from pathlib import Path
try:
    from scipy import ndimage
    from scipy.interpolate import griddata, RBFInterpolator
    from scipy.spatial import cKDTree
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False
try:
    from osgeo import gdal
    GDAL_AVAILABLE = True
except ImportError:
    GDAL_AVAILABLE = False
def interpolate_pixels(
    image_stack: np.ndarray,
    zero_coords: np.ndarray,
    valid_coords: np.ndarray,
    valid_values: np.ndarray,
    interpolation_method: str = 'nearest',
    water_mask: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Interpolate the pixels at the given coordinates (core numerical function, no file I/O).
    Args:
        image_stack: image data stack, float32 array of shape (height, width, n_bands)
        zero_coords: coordinates of the pixels to interpolate, shape (n_zero, 2), each row is [x, y]
        valid_coords: coordinates of the valid pixels, shape (n_valid, 2)
        valid_values: values of the valid pixels, shape (n_valid,) or (n_valid, n_bands)
        interpolation_method: one of 'nearest', 'bilinear', 'spline', 'kriging'
        water_mask: optional water mask array
    Returns:
        Interpolated copy of the image, same shape as image_stack
    """
    if not SCIPY_AVAILABLE:
        raise ImportError("scipy is not installed; cannot interpolate zero-valued pixels")
    height, width, n_bands = image_stack.shape
    result = image_stack.copy()
    # Accept Chinese labels and other spellings of the method name
    raw_method = str(interpolation_method).lower()
    if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
        method = 'nearest'
    elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
        method = 'bilinear'
    elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
        method = 'spline'
    elif 'kriging' in raw_method or '克里金' in raw_method:
        method = 'kriging'
    else:
        method = 'nearest'
    if len(valid_values) == 0:
        return result
    is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
    if is_multiband:
        for band_idx in range(n_bands):
            band_valid_values = valid_values[:, band_idx]
            interpolated_values = _interpolate_single_band(
                zero_coords, valid_coords, band_valid_values, method
            )
            y_coords = zero_coords[:, 1].astype(int)
            x_coords = zero_coords[:, 0].astype(int)
            result[y_coords, x_coords, band_idx] = interpolated_values
    else:
        interpolated_values = _interpolate_single_band(
            zero_coords, valid_coords, valid_values, method
        )
        y_coords = zero_coords[:, 1].astype(int)
        x_coords = zero_coords[:, 0].astype(int)
        # Reshape to (n_zero, 1) so one set of values broadcasts across the band axis;
        # a 1-D (n_zero,) array does not broadcast against the (n_zero, n_bands) target
        result[y_coords, x_coords] = np.asarray(interpolated_values).reshape(len(zero_coords), 1)
    return result
def _interpolate_single_band(
zero_coords: np.ndarray,
valid_coords: np.ndarray,
valid_values: np.ndarray,
method: str
) -> np.ndarray:
"""对单个波段执行插值计算"""
if method == 'nearest':
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords)
return valid_values[indices]
elif method == 'bilinear':
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'spline':
try:
max_points = 10000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
interpolated = rbf(zero_coords)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
elif method == 'kriging':
try:
from src.utils.kriging import KrigingInterpolator
interpolator = KrigingInterpolator()
max_points = 5000
if len(valid_values) > max_points:
indices = np.random.choice(len(valid_values), max_points, replace=False)
sample_coords = valid_coords[indices]
sample_values = valid_values[indices]
else:
sample_coords = valid_coords
sample_values = valid_values
interpolated = griddata(
sample_coords, sample_values, zero_coords,
method='cubic', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
except Exception:
interpolated = griddata(
valid_coords, valid_values, zero_coords,
method='linear', fill_value=0.0
)
nan_mask = np.isnan(interpolated)
if np.any(nan_mask):
tree = cKDTree(valid_coords)
_, indices = tree.query(zero_coords[nan_mask])
interpolated[nan_mask] = valid_values[indices]
return interpolated
return np.zeros(len(zero_coords))
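One caveat about the fallbacks above: `griddata` is called with `fill_value=0.0`, so query points outside the convex hull of the valid pixels come back as 0.0 rather than NaN, and the `cKDTree` nearest-neighbour fill is not triggered for them. A minimal sketch of the repeated pattern factored into one helper, assuming the intent is to fill such points from the nearest valid pixel; `linear_with_nearest_fallback` is a hypothetical name, not part of the module:

```python
import numpy as np
from scipy.interpolate import griddata
from scipy.spatial import cKDTree


def linear_with_nearest_fallback(valid_coords, valid_values, query_coords):
    """Linear interpolation; points griddata cannot reach get the nearest valid value."""
    # fill_value=np.nan marks points outside the convex hull instead of zeroing them
    result = griddata(valid_coords, valid_values, query_coords,
                      method='linear', fill_value=np.nan)
    nan_mask = np.isnan(result)
    if np.any(nan_mask):
        tree = cKDTree(valid_coords)
        _, nearest = tree.query(query_coords[nan_mask])
        result[nan_mask] = valid_values[nearest]
    return result


coords = np.array([[0, 0], [2, 0], [0, 2], [2, 2], [1, 0.5]], dtype=float)
values = coords[:, 0] + coords[:, 1]          # a planar field, so linear interpolation is exact
queries = np.array([[1.0, 1.0], [5.0, 5.0]])  # one inside the hull, one outside
filled = linear_with_nearest_fallback(coords, values, queries)
```

The same helper could serve as the shared fallback for the 'spline' and 'kriging' branches, which currently repeat the block verbatim.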
def interpolate_zero_pixels_batch(
img_path: str,
interpolation_method: str = 'nearest',
output_path: Optional[str] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
deglint_dir: Optional[str] = None,
callback_progress: Optional[callable] = None
) -> Tuple[str, Optional[np.ndarray]]:
"""
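Interpolate pixels that are 0 in every band (full workflow, including file I/O).
Args:
img_path: Path to the input image file.
interpolation_method: Interpolation method; one of 'nearest', 'bilinear', 'spline', 'kriging'.
output_path: Output file path. If None and deglint_dir is given, a default path is generated.
water_mask: Water mask (file path or array).
deglint_dir: Deglint directory (used to build the default output path).
callback_progress: Progress callback function.
Returns:
Tuple (output_path, interpolated_image_stack); the stack is None when the output file already exists.
"""
if not SCIPY_AVAILABLE:
raise ImportError("scipy is not installed; cannot interpolate zero-valued pixels")
if not GDAL_AVAILABLE:
raise ImportError("GDAL is not installed; cannot read image files")
# Determine the output path
if output_path is None and deglint_dir is not None:
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
# Skip the work if the output file already exists
if output_path and Path(output_path).exists():
return output_path, None
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"Cannot open image file: {img_path}")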
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
# Read all band data
all_bands = []
for band_idx in range(1, n_bands + 1):
band = dataset.GetRasterBand(band_idx)
band_data = band.ReadAsArray().astype(np.float32)
all_bands.append(band_data)
image_stack = np.dstack(all_bands)
# Read the water mask
mask_array = None
if water_mask is not None:
if isinstance(water_mask, str):
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
elif isinstance(water_mask, np.ndarray):
mask_array = water_mask
# Find pixels that are 0 in every band
all_bands_zero = np.all(image_stack == 0, axis=2)
if mask_array is not None:
all_bands_zero = all_bands_zero & (mask_array > 0)
zero_pixel_count = np.sum(all_bands_zero)
if zero_pixel_count == 0:
# Nothing to interpolate; save the data unchanged
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(all_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
return output_path, image_stack
# Build the coordinate arrays
zero_y, zero_x = np.where(all_bands_zero)
zero_coords = np.column_stack([zero_x, zero_y])
valid_mask = ~all_bands_zero
valid_y, valid_x = np.where(valid_mask)
valid_coords = np.column_stack([valid_x, valid_y])
if len(valid_coords) == 0:
raise ValueError("No valid pixels available for interpolation")
# Interpolate band by band
interpolated_bands = []
for band_idx in range(n_bands):
if callback_progress:
callback_progress(f"Processing band {band_idx + 1}/{n_bands}...")
band_data = all_bands[band_idx].copy()
valid_values_band = band_data[valid_mask]
if len(valid_values_band) == 0:
interpolated_bands.append(band_data)
continue
band_result = _interpolate_single_band(
zero_coords, valid_coords, valid_values_band, interpolation_method
)
band_data[all_bands_zero] = band_result
interpolated_bands.append(band_data)
# Save the result
if output_path:
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
for i, band_data in enumerate(interpolated_bands):
out_band = out_dataset.GetRasterBand(i + 1)
out_band.WriteArray(band_data)
out_band.FlushCache()
out_dataset = None
result_stack = np.dstack(interpolated_bands)
return output_path, result_stack
finally:
dataset = None
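The zero-pixel detection and coordinate layout used by `interpolate_zero_pixels_batch` can be checked on a toy array. This is a minimal sketch with made-up data; `find_zero_pixels` is a hypothetical helper, and the (x, y) column order mirrors the `np.column_stack([zero_x, zero_y])` call above.

```python
import numpy as np


def find_zero_pixels(image_stack, mask_array=None):
    """Return a boolean map and (x, y) coordinates of pixels that are 0 in every band."""
    all_bands_zero = np.all(image_stack == 0, axis=2)
    if mask_array is not None:
        # Restrict candidates to water pixels (mask > 0)
        all_bands_zero &= (mask_array > 0)
    ys, xs = np.where(all_bands_zero)
    return all_bands_zero, np.column_stack([xs, ys])


# 2 rows x 3 columns x 2 bands
stack = np.ones((2, 3, 2), dtype=np.float32)
stack[0, 2, :] = 0      # zero in every band: a gap to fill
stack[1, 0, 0] = 0      # zero in one band only: left alone
stack[1, 1, :] = 0      # zero in every band, but outside the mask below
mask = np.array([[1, 1, 1], [1, 0, 1]], dtype=np.uint8)

zero_map, zero_coords = find_zero_pixels(stack, mask)
```

Only the pixel at row 0, column 2 qualifies here, so a single (x, y) pair is returned.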


@@ -1,5 +1,4 @@
import numpy as np
# import preprocessing
import os
try:
@@ -8,283 +7,301 @@ try:
except ImportError:
GDAL_AVAILABLE = False
class Hedley:
def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
output_path=None, block_size=1000):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
:param NIR_band (int): index of the NIR band at 842.36 nm, which is close to the MicaSense NIR band
:param water_mask (np.ndarray or str or None): water mask, where 1 marks water and 0 marks non-water.
May be a numpy array, a raster file path (.dat/.tif), or a shapefile path (.shp).
If None, the whole image is processed.
:param output_path (str or None): output file path; if given, the corrected image is saved.
If None, nothing is saved.
"""
self.im_aligned = im_aligned
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
self.NIR_band = NIR_band
self.n_bands = im_aligned.shape[-1]
self.height = im_aligned.shape[0]
self.width = im_aligned.shape[1]
self.output_path = output_path
# Load the water mask
self.water_mask = self._load_water_mask(water_mask)
# Use ravel() rather than flatten() to avoid an unnecessary copy
# If a water mask exists, compute R_min only inside the mask
if self.water_mask is not None:
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
else:
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
def _read_shp_to_bbox(self, shp_path):
"""
Read a shapefile and extract its bounding box.
:param shp_path (str): path to the shapefile
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
"""
if not os.path.exists(shp_path):
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
try:
try:
import geopandas as gpd
gdf = gpd.read_file(shp_path)
# Total bounding box over all geometries
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
min_x, min_y, max_x, max_y = bounds
except ImportError:
# Fall back to fiona when geopandas is unavailable
import fiona
from shapely.geometry import shape
min_x = float('inf')
min_y = float('inf')
max_x = float('-inf')
max_y = float('-inf')
with fiona.open(shp_path) as shp:
for feature in shp:
geom = shape(feature['geometry'])
if geom:
bounds = geom.bounds
min_x = min(min_x, bounds[0])
min_y = min(min_y, bounds[1])
max_x = max(max_x, bounds[2])
max_y = max(max_y, bounds[3])
# Convert to integer pixel coordinates
x1 = max(0, int(min_x))
y1 = max(0, int(min_y))
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
return ((x1, y1), (x2, y2))
except Exception as e:
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
def _load_water_mask(self, water_mask):
"""
Load the water mask.
:param water_mask: None, a numpy array, a raster file path (.dat/.tif), or a shapefile path (.shp)
:return: numpy array or None, where 1 marks water and 0 marks non-water
"""
if water_mask is None:
return None
# Already a numpy array
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (self.height, self.width):
raise ValueError(f"Mask shape {water_mask.shape[:2]} does not match image shape {(self.height, self.width)}")
return (water_mask > 0).astype(np.uint8) # force a 0/1 mask
# A file path
if isinstance(water_mask, str):
try:
from osgeo import gdal, ogr
except ImportError:
raise ValueError("GDAL must be installed to use a file path as the mask")
# Check whether it is a shapefile
if water_mask.lower().endswith('.shp'):
# Building a mask from a .shp file needs a reference raster; im_aligned only supplies the size
# Note: with a numpy array as input there is no raster reference, so the .shp cannot be rasterized here
raise ValueError("Hedley cannot build a mask from a .shp file when the input is a numpy array. Rasterize the shapefile first or pass a numpy array mask")
else:
# Raster file
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"Cannot open mask file: {water_mask}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(f"Mask shape {mask_array.shape} does not match image shape {(self.height, self.width)}")
return (mask_array > 0).astype(np.uint8)
raise ValueError(f"Unsupported mask type: {type(water_mask)}")
def covariance_NIR(self,NIR,b):
"""
NIR & b are vectors
reflectance for band i
"""
n = len(NIR)
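# Optimization: avoid repeated work by using vectorized numpy operations
nir_mean = np.mean(NIR)
b_mean = np.mean(b)
# Covariance and variance computed as means of centered products
pij = np.mean((NIR - nir_mean) * (b - b_mean))
pjj = np.mean((NIR - nir_mean) ** 2)
# Avoid division by zero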
return pij / pjj if pjj != 0 else 0.0
def correlation_bands_reflectance(self):
"""
calculate correlation between NIR and other bands for reflectance
NIR_band is 750 nm
"""
# If bbox is None, use the entire image
if self.bbox is None:
# Use ravel() rather than flatten() to avoid an unnecessary copy
# Work on a view and only flatten when needed
im_region = self.im_aligned
mask_region = self.water_mask
else:
((x1,y1),(x2,y2)) = self.bbox
im_region = self.im_aligned[y1:y2,x1:x2,:]
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
# If a water mask exists, compute the coefficients only inside the mask
if mask_region is not None:
mask_bool = mask_region.astype(bool)
if mask_bool.any():
# Extract data only inside the mask
NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
else:
# No valid pixels inside the mask; fall back to the whole region
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
mask_bool = None
else:
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
mask_bool = None
# Optimization: compute the coefficient for every band in one pass to reduce loop overhead
corr_list = []
for v in range(self.n_bands):
if mask_bool is not None and mask_bool.any():
band_reflectance = im_region[:,:,v][mask_bool]
else:
band_reflectance = im_region[:,:,v].ravel()
corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
corr_list.append(corr)
return corr_list
def _save_corrected_bands(self, corrected_bands):
"""
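Save the corrected bands to a file (BSQ layout, ENVI format).
:param corrected_bands: list of corrected bands
Hedley sun-glint removal algorithm, block-wise band-by-band version.
:param img_path (str): path to the input image file (any GDAL-readable format)
:param shp_path (str, optional): deep-water region shapefile (deprecated; use water_mask instead)
:param NIR_band (int): NIR band index, default 47 (842.36 nm)
:param water_mask (np.ndarray or str or None): water mask
:param output_path (str): output file path (required; blocks are written to it directly)
:param block_size (int): block size, default 1000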
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL is not installed; cannot save image files")
if self.output_path is None:
return
# Make sure the output directory exists
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# Stack the band list into one array
corrected_array = np.stack(corrected_bands, axis=2)
# No georeferencing is available, so use defaults
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# Force ENVI format (BSQ layout) and make sure the file extension is .bsq
base_path, ext = os.path.splitext(self.output_path)
# If the extension is not .bsq, append .bsq to the base path
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
raise ImportError("GDAL is not installed; cannot read image files")
self.img_path = img_path
self.NIR_band = int(float(NIR_band))
self.water_mask = None
self.water_mask_path = water_mask
self.output_path = output_path
self.block_size = block_size
self.R_min = None
self.corr_list = None # global list of covariance coefficients
# Open the image
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if self.dataset is None:
raise ValueError(f"Cannot open image file: {img_path}")
self.width = self.dataset.RasterXSize
self.height = self.dataset.RasterYSize
self.n_bands = self.dataset.RasterCount
def _load_water_mask(self):
"""Lazily load the water mask, caching it in self.water_mask after the first call."""
if self.water_mask is not None:
return self.water_mask
if self.water_mask_path is None:
return None
if isinstance(self.water_mask_path, np.ndarray):
if self.water_mask_path.shape[:2] != (self.height, self.width):
raise ValueError(
f"Mask shape {self.water_mask_path.shape[:2]} does not match image shape {(self.height, self.width)}"
)
self.water_mask = (self.water_mask_path > 0).astype(np.uint8)
return self.water_mask
if isinstance(self.water_mask_path, str):
if self.water_mask_path.lower().endswith('.shp'):
raise ValueError("Rasterize the shapefile into a raster mask file first")
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"Cannot open mask file: {self.water_mask_path}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(
f"Mask shape {mask_array.shape} does not match image shape {(self.height, self.width)}"
)
self.water_mask = (mask_array > 0).astype(np.uint8)
return self.water_mask
return None
def covariance_NIR(self, NIR, b):
"""Compute the regression slope of band b on NIR: b_i = Cov(NIR, b) / Var(NIR)."""
nir_mean = np.mean(NIR)
b_mean = np.mean(b)
pij = np.mean((NIR - nir_mean) * (b - b_mean))
pjj = np.mean((NIR - nir_mean) ** 2)
return pij / pjj if pjj != 0 else 0.0
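The quantity returned by `covariance_NIR` is the ordinary least-squares slope of band `b` regressed on NIR. A quick sanity check, using synthetic data and a standalone copy of the formula (`nir_slope` is a hypothetical name), compares it against `np.polyfit`:

```python
import numpy as np


def nir_slope(nir, band):
    """Cov(NIR, band) / Var(NIR), with 0.0 returned when NIR is constant."""
    nir_c = nir - nir.mean()
    var = np.mean(nir_c ** 2)
    return float(np.mean(nir_c * (band - band.mean())) / var) if var != 0 else 0.0


rng = np.random.default_rng(0)
nir = rng.uniform(0.01, 0.2, size=1000)
band = 0.5 * nir + 0.03            # glint-like linear dependence on NIR
slope = nir_slope(nir, band)
ols_slope = np.polyfit(nir, band, 1)[0]
```

Because both numerator and denominator are divided by the same sample count, the choice between population and sample normalisation does not affect the slope.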
def _scan_global_stats(self, sample_step=20):
"""
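Scan the whole image to obtain the global R_min.
The image is read in row blocks and only every sample_step-th masked pixel is kept, which greatly reduces memory use.
"""
print(f"[Hedley] Scanning global statistics (sample step={sample_step})...")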
water_mask = self._load_water_mask()
nir_samples = []
sample_count = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
nir_band = None
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
else:
mask_bool = np.ones((block_height, self.width), dtype=bool)
if mask_bool.any():
nir_sampled = nir_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
sample_count += nir_sampled.size
del nir_block, mask_bool # mask_block is undefined when no water mask is given
if sample_count == 0:
self.R_min = 0.0
else:
bsq_path = self.output_path
# Use the ENVI driver, which writes BSQ by default
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("Cannot create an ENVI file: the ENVI driver is unavailable")
height, width, n_bands = corrected_array.shape
# Create the ENVI dataset (the .hdr file is generated automatically)
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"Cannot create output file: {bsq_path}")
try:
# Set the geotransform and projection
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# Write each band (BSQ stores the bands sequentially)
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_array[:, :, i])
band.FlushCache()
finally:
dataset = None
# Check whether the .hdr file was created
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"Corrected image saved to: {bsq_path} (BSQ format)")
print(f"Header file saved to: {hdr_path}")
all_nir = np.concatenate(nir_samples)
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
del all_nir
print(f"[Hedley] Global R_min={self.R_min:.4f}")
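The block-wise sampling in `_scan_global_stats` can be tried without GDAL. The sketch below assumes an in-memory array stands in for the raster band; `sampled_percentile` and its parameters are illustrative names only. Note that the `[::sample_step]` stride restarts in every block, so the sample is a per-block subsample of the masked pixels rather than a regular grid.

```python
import numpy as np


def sampled_percentile(band, mask, block_rows, sample_step, q=5):
    """Estimate a percentile from every sample_step-th masked pixel of each row block."""
    samples = []
    for y_off in range(0, band.shape[0], block_rows):
        block = band[y_off:y_off + block_rows]
        block_mask = mask[y_off:y_off + block_rows]
        if block_mask.any():
            samples.append(block[block_mask][::sample_step])
    if not samples:
        return 0.0, 0          # same fallback as the class: R_min = 0.0
    all_samples = np.concatenate(samples)
    # method='nearest' needs NumPy >= 1.22 (it replaced the old interpolation= keyword)
    return float(np.percentile(all_samples, q, method='nearest')), all_samples.size


band = np.arange(100, dtype=np.float32).reshape(10, 10)
mask = np.ones_like(band, dtype=bool)
exact, n_all = sampled_percentile(band, mask, block_rows=5, sample_step=1)
approx, n_sub = sampled_percentile(band, mask, block_rows=5, sample_step=20)
```

With `sample_step=1` the estimate matches a direct `np.percentile` call; larger steps trade accuracy for memory.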
def _compute_corr_list(self, sample_step=5):
"""
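Compute each band's covariance coefficient with NIR: corr_list[b] = Cov(NIR, band_b) / Var(NIR).
Every block is scanned one band at a time, so each read holds about width × block_size pixels.
The coefficient estimate needs plenty of samples, hence the denser sample_step=5.
"""
print(f"[Hedley] Computing the global covariance coefficient list (sample step={sample_step})...")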
water_mask = self._load_water_mask()
# Collect samples of NIR and of every band up front
nir_samples = []
band_samples = [[] for _ in range(self.n_bands)]
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# Read the NIR band (once per block)
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
nir_band = None
# Take NIR samples (once per block, outside the band loop)
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
else:
mask_bool = np.ones((block_height, self.width), dtype=bool)
if mask_bool.any():
nir_sampled = nir_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
# Read and sample band by band (all_band strictly uses single-band slices)
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
band = None
if mask_bool.any():
band_sampled = block[mask_bool][::sample_step]
band_samples[b].append(band_sampled)
del block
del nir_block
# Aggregate the samples and compute the coefficients
if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0:
self.corr_list = [0.0] * self.n_bands
else:
print(f"Corrected image saved to: {bsq_path} (BSQ format)")
print(f"Warning: no .hdr file detected, although GDAL should have created it automatically")
all_nir = np.concatenate(nir_samples)
self.corr_list = []
for b in range(self.n_bands):
all_band = np.concatenate(band_samples[b])
corr = self.covariance_NIR(all_nir, all_band)
self.corr_list.append(float(corr))
del all_nir
for b in range(self.n_bands):
band_samples[b] = None
print(f"[Hedley] Covariance coefficients: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")
def _process_block(self, x_off, y_off, x_size, y_size):
"""
Process a single block.
Returns:
list of np.ndarray: the corrected bands
"""
# Read the NIR band
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
nir_band = None
# Precompute NIR - R_min
NIR_diff = NIR - self.R_min
# Fetch the mask
water_mask = self._load_water_mask()
if water_mask is not None:
y_end = y_off + y_size
x_end = x_off + x_size
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
else:
mask_block = None
# Process band by band
corrected_bands = []
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
band = None
corr = self.corr_list[b]
# Hedley correction formula: R_corrected = R - corr * (NIR - R_min)
corrected = R - corr * NIR_diff
if mask_block is not None:
corrected = np.where(mask_block, corrected, R)
corrected_bands.append(corrected)
del R
del NIR, NIR_diff
return corrected_bands
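The per-block correction in `_process_block` is a single vectorised expression per band. A minimal sketch on synthetic values, assuming a slope `corr` and `r_min` have already been estimated (all names and numbers here are illustrative):

```python
import numpy as np


def hedley_correct(band, nir, corr, r_min, water=None):
    """R' = R - corr * (NIR - R_min); pixels outside the water mask keep their value."""
    corrected = band - corr * (nir - r_min)
    if water is not None:
        corrected = np.where(water, corrected, band)
    return corrected


band = np.array([[1.0, 2.0]], dtype=np.float32)
nir = np.array([[0.6, 0.7]], dtype=np.float32)
water = np.array([[True, False]])
out = hedley_correct(band, nir, corr=2.0, r_min=0.5, water=water)
```

The first pixel is reduced by 2.0 × (0.6 − 0.5); the second lies outside the mask and is passed through unchanged.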
def get_corrected_bands(self):
"""
correction is done in reflectance
:return: list of corrected bands
Run the block-wise processing; results are written to output_path and an empty list is returned.
"""
corr = self.correlation_bands_reflectance()
NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
# Precompute NIR - R_min so it is not recomputed inside the loop
NIR_diff = NIR_reflectance - self.R_min
# Fetch the water mask, if any
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
if self.output_path is None:
raise ValueError("output_path is required: block-wise processing writes directly to the file")
corrected_bands = []
for band_number in range(self.n_bands): #iterate across bands
b = corr[band_number]
R = self.im_aligned[:,:,band_number]
# Optimization: create fewer intermediate arrays
corrected_band = R - b * NIR_diff
# If a water mask exists, apply the correction only to water pixels
if water_mask_bool is not None:
corrected_band = np.where(water_mask_bool, corrected_band, R)
corrected_bands.append(corrected_band)
# Save the result if an output path was given
if self.output_path is not None:
self._save_corrected_bands(corrected_bands)
return corrected_bands
# Step 1: scan for the global R_min
self._scan_global_stats(sample_step=20)
# Step 2: compute the covariance coefficient list
self._compute_corr_list(sample_step=5)
# Step 3: create the output file
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
base_path, ext = os.path.splitext(self.output_path)
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
geotransform = self.dataset.GetGeoTransform()
projection = self.dataset.GetProjection()
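driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("Cannot create an ENVI file: the ENVI driver is unavailable")
out_dataset = driver.Create(bsq_path, self.width, self.height,
self.n_bands, gdal.GDT_Float32)
if out_dataset is None:
raise ValueError(f"Cannot create output file: {bsq_path}")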
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
# Step 4: process block by block
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
total_blocks = n_blocks_x * n_blocks_y
print(f"[Hedley] Starting block-wise processing: {total_blocks} blocks ({n_blocks_x}×{n_blocks_y}), block size={self.block_size}")
block_idx = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
y_size = y_end - y_off
for x_off in range(0, self.width, self.block_size):
x_end = min(x_off + self.block_size, self.width)
x_size = x_end - x_off
block_idx += 1
print(f"[Hedley] Processing block {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
for b in range(self.n_bands):
out_band = out_dataset.GetRasterBand(b + 1)
out_band.WriteArray(corrected_bands[b], x_off, y_off)
out_band.FlushCache()
del corrected_bands
out_dataset = None
self.dataset = None
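# GDAL's ENVI driver replaces the extension with .hdr by default, so check both possible header names
hdr_candidates = [base_path + '.hdr', bsq_path + '.hdr']
if any(os.path.exists(p) for p in hdr_candidates):
print(f"[Hedley] Correction finished; saved to: {bsq_path}")
else:
print(f"[Hedley] Correction finished; saved to: {bsq_path} (warning: no .hdr file detected)")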
return []
def __del__(self):
# __init__ may raise before self.dataset is assigned, so look it up defensively
if getattr(self, 'dataset', None) is not None:
self.dataset = None
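The tiling loop in `get_corrected_bands` can be isolated into a small generator, which makes the edge handling easy to test. A sketch under the assumption that windows are visited row-major exactly as above; `iter_windows` is a hypothetical helper:

```python
def iter_windows(width, height, block_size):
    """Yield (x_off, y_off, x_size, y_size) windows that tile a width x height raster."""
    for y_off in range(0, height, block_size):
        y_size = min(block_size, height - y_off)
        for x_off in range(0, width, block_size):
            x_size = min(block_size, width - x_off)
            yield x_off, y_off, x_size, y_size


windows = list(iter_windows(width=2500, height=1200, block_size=1000))
n_blocks = ((2500 + 999) // 1000) * ((1200 + 999) // 1000)   # ceil-division count, as in the class
covered = sum(xs * ys for _, _, xs, ys in windows)
```

The partial windows on the right and bottom edges are clipped, so the windows cover every pixel exactly once.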


@@ -1,5 +1,4 @@
import numpy as np
# import preprocessing
import os
try:
@@ -8,306 +7,333 @@ try:
except ImportError:
GDAL_AVAILABLE = False
class Kutser:
def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
"""
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
:param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
:param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
:param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
see Kutser, Vahtmäe and Praks
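:param water_mask (np.ndarray or str or None): water mask, where 1 marks water and 0 marks non-water.
May be a numpy array, a raster file path (.dat/.tif), or a shapefile path (.shp).
If None, the whole image is processed.
:param output_path (str or None): output file path; if given, the corrected image is saved.
If None, nothing is saved.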
"""
self.im_aligned = im_aligned
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
self.oxy_band = oxy_band
self.lower_oxy = lower_oxy
self.upper_oxy = upper_oxy
self.NIR_band = NIR_band
self.n_bands = im_aligned.shape[-1]
self.height = im_aligned.shape[0]
self.width = im_aligned.shape[1]
self.output_path = output_path
# Load the water mask
self.water_mask = self._load_water_mask(water_mask)
# Use ravel() rather than flatten() to avoid an unnecessary copy
# If a water mask exists, compute R_min only inside the mask
if self.water_mask is not None:
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
else:
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
def _read_shp_to_bbox(self, shp_path):
"""
Read a shapefile and extract its bounding box.
:param shp_path (str): path to the shapefile
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
"""
if not os.path.exists(shp_path):
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
try:
try:
import geopandas as gpd
gdf = gpd.read_file(shp_path)
# Total bounding box over all geometries
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
min_x, min_y, max_x, max_y = bounds
except ImportError:
# Fall back to fiona when geopandas is unavailable
import fiona
from shapely.geometry import shape
min_x = float('inf')
min_y = float('inf')
max_x = float('-inf')
max_y = float('-inf')
with fiona.open(shp_path) as shp:
for feature in shp:
geom = shape(feature['geometry'])
if geom:
bounds = geom.bounds
min_x = min(min_x, bounds[0])
min_y = min(min_y, bounds[1])
max_x = max(max_x, bounds[2])
max_y = max(max_y, bounds[3])
# Convert to integer pixel coordinates
x1 = max(0, int(min_x))
y1 = max(0, int(min_y))
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
return ((x1, y1), (x2, y2))
except Exception as e:
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
def _load_water_mask(self, water_mask):
"""
Load the water mask.
:param water_mask: None, a numpy array, a raster file path (.dat/.tif), or a shapefile path (.shp)
:return: numpy array or None, where 1 marks water and 0 marks non-water
"""
if water_mask is None:
return None
# Already a numpy array
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (self.height, self.width):
raise ValueError(f"Mask shape {water_mask.shape[:2]} does not match image shape {(self.height, self.width)}")
return (water_mask > 0).astype(np.uint8) # force a 0/1 mask
# A file path
if isinstance(water_mask, str):
try:
from osgeo import gdal, ogr
except ImportError:
raise ValueError("GDAL must be installed to use a file path as the mask")
# Check whether it is a shapefile
if water_mask.lower().endswith('.shp'):
# Building a mask from a .shp file needs a reference raster; im_aligned only supplies the size
# Note: with a numpy array as input there is no raster reference, so the .shp cannot be rasterized here
raise ValueError("Kutser cannot build a mask from a .shp file when the input is a numpy array. Rasterize the shapefile first or pass a numpy array mask")
else:
# Raster file
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"Cannot open mask file: {water_mask}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(f"Mask shape {mask_array.shape} does not match image shape {(self.height, self.width)}")
return (mask_array > 0).astype(np.uint8)
raise ValueError(f"Unsupported mask type: {type(water_mask)}")
def get_depth_D(self):
"""
Assume the amount of glint is proportional to the depth of the oxygen absorption feature, D
returns the normalised D by dividing it by the maximum D found in a deep water region
"""
# Optimization: create fewer intermediate arrays
lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
oxy_band = self.im_aligned[:,:,self.oxy_band]
D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
# Choose the region used to compute D_max
if self.bbox is None:
search_region = D
else:
((x1,y1),(x2,y2)) = self.bbox
search_region = D[y1:y2,x1:x2]
# If a water mask exists, search for the maximum only inside the mask
if self.water_mask is not None:
if self.bbox is None:
mask_region = self.water_mask.astype(bool)
else:
((x1,y1),(x2,y2)) = self.bbox
mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
if mask_region.any():
D_max = search_region[mask_region].max()
else:
D_max = search_region.max()
else:
D_max = search_region.max() # assumed to be the maximum glint value
# Avoid division by zero
if D_max == 0:
return np.zeros_like(D)
return D / D_max
def get_glint_G(self):
"""
The spectral variation of glint G is found by subtracting the spectrum at the darkest (ie. lowest D) NIR deep-water pixel from the brightest
returns G as a function of wavelength
"""
# If bbox is None, use the entire image
if self.bbox is None:
im_region = self.im_aligned
mask_region = self.water_mask
else:
((x1,y1),(x2,y2)) = self.bbox
im_region = self.im_aligned[y1:y2,x1:x2,:]
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
# If a water mask exists, compute the extrema only inside the mask
if mask_region is not None:
mask_bool = mask_region.astype(bool)
if mask_bool.any():
# For each band, compute the maximum and minimum inside the mask only
G_list = []
for i in range(self.n_bands):
band_data = im_region[:,:,i]
G_max = band_data[mask_bool].max()
G_min = band_data[mask_bool].min()
G_list.append(G_max - G_min)
else:
# No valid pixels inside the mask; fall back to the whole region
G_max = np.amax(im_region, axis=(0, 1))
G_min = np.amin(im_region, axis=(0, 1))
G_list = (G_max - G_min).tolist()
else:
# Optimization: compute the extrema of all bands at once to reduce loop overhead
# np.amax and np.amin reduce over the two spatial axes
G_max = np.amax(im_region, axis=(0, 1)) # maximum over the spatial dimensions
G_min = np.amin(im_region, axis=(0, 1)) # minimum over the spatial dimensions
G_list = (G_max - G_min).tolist()
return G_list
def _save_corrected_bands(self, corrected_bands):
class Kutser:
def __init__(self, img_path, shp_path=None, oxy_band=38, lower_oxy=36,
upper_oxy=49, NIR_band=47, water_mask=None, output_path=None,
block_size=1000):
"""
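Save the corrected bands to a file (BSQ layout, ENVI format).
:param corrected_bands: list of corrected bands
Kutser sun-glint removal algorithm, block-wise band-by-band version.
:param img_path (str): path to the input image file (any GDAL-readable format)
:param shp_path (str, optional): deep-water region shapefile (deprecated; use water_mask instead)
:param oxy_band (int): index of the oxygen absorption band, default 38 (760.6 nm)
:param lower_oxy (int): index of the band below the oxygen absorption feature, default 36 (742.39 nm)
:param upper_oxy (int): index of the band above the oxygen absorption feature, default 49 (860.48 nm)
:param NIR_band (int): NIR band index, default 47 (842.36 nm)
:param water_mask (np.ndarray or str or None): water mask
:param output_path (str): output file path (required; blocks are written to it directly)
:param block_size (int): block size, default 1000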
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL is not installed; cannot save image files")
if self.output_path is None:
return
# Make sure the output directory exists
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# Stack the band list into one array
corrected_array = np.stack(corrected_bands, axis=2)
# No georeferencing is available, so use defaults
geotransform = (0, 1, 0, 0, 0, -1)
projection = ""
# Force ENVI format (BSQ layout) and make sure the file extension is .bsq
base_path, ext = os.path.splitext(self.output_path)
# If the extension is not .bsq, append .bsq to the base path
if ext.lower() != '.bsq':
bsq_path = base_path + '.bsq'
raise ImportError("GDAL is not installed; cannot read image files")
self.img_path = img_path
self.oxy_band = int(float(oxy_band))
self.lower_oxy = int(float(lower_oxy))
self.upper_oxy = int(float(upper_oxy))
self.NIR_band = int(float(NIR_band))
self.water_mask = None # loaded lazily, before processing starts
self.water_mask_path = water_mask
self.output_path = output_path
self.block_size = block_size
self.R_min = None # global R_min, from the sampled scan
self.D_max = None # global D_max, from the sampled scan
self.G_list = None # global list of G values
# Open the image to read its basic properties
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if self.dataset is None:
raise ValueError(f"Cannot open image file: {img_path}")
self.width = self.dataset.RasterXSize
self.height = self.dataset.RasterYSize
self.n_bands = self.dataset.RasterCount
def _load_water_mask(self):
"""Lazily load the water mask, caching it in self.water_mask after the first call."""
if self.water_mask is not None:
return self.water_mask
if self.water_mask_path is None:
return None
if isinstance(self.water_mask_path, np.ndarray):
if self.water_mask_path.shape[:2] != (self.height, self.width):
raise ValueError(
f"Mask shape {self.water_mask_path.shape[:2]} does not match image shape {(self.height, self.width)}"
)
self.water_mask = (self.water_mask_path > 0).astype(np.uint8)
return self.water_mask
if isinstance(self.water_mask_path, str):
if self.water_mask_path.lower().endswith('.shp'):
raise ValueError("Rasterize the shapefile into a raster mask file first")
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"Cannot open mask file: {self.water_mask_path}")
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
if mask_array.shape != (self.height, self.width):
raise ValueError(
f"Mask shape {mask_array.shape} does not match image shape {(self.height, self.width)}"
)
self.water_mask = (mask_array > 0).astype(np.uint8)
return self.water_mask
return None
def _scan_global_stats(self, sample_step=20):
"""
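Obtain the global statistics R_min and D_max from a sampled scan of the whole image.
The image is read in row blocks and every sample_step-th masked pixel is kept, which greatly reduces memory use.
Peak memory is roughly a few single-band blocks plus the mask, i.e. on the order of width × block_size × 4 bytes per float32 block.
"""
print(f"[Kutser] Scanning global statistics (sample step={sample_step})...")
water_mask = self._load_water_mask()
# Sample lists for the NIR band and the D values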
nir_samples = []
d_samples = []
sample_count = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# Read the NIR band (used for R_min)
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
nir_band = None
# Read the bands around the oxygen absorption feature (used for D_max)
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
lower_block = lower_band.ReadAsArray(0, y_off, self.width, block_height)
lower_band = None
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
upper_block = upper_band.ReadAsArray(0, y_off, self.width, block_height)
upper_band = None
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
oxy_block = oxy_band_obj.ReadAsArray(0, y_off, self.width, block_height)
oxy_band_obj = None
# Compute D = (lower + upper) * 0.5 - oxy
d_block = (lower_block.astype(np.float32) + upper_block.astype(np.float32)) * 0.5 - oxy_block.astype(np.float32)
# Fetch the mask for the whole block
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
else:
mask_block = np.ones((block_height, self.width), dtype=np.uint8)
# Sample inside the masked region
mask_bool = mask_block.astype(bool)
if mask_bool.any():
# Keep every sample_step-th masked pixel
nir_sampled = nir_block[mask_bool][::sample_step]
d_sampled = d_block[mask_bool][::sample_step]
nir_samples.append(nir_sampled)
d_samples.append(d_sampled)
sample_count += nir_sampled.size
# Release the block arrays explicitly
del nir_block, lower_block, upper_block, oxy_block, d_block, mask_block
# Aggregate
if sample_count == 0:
self.R_min = 0.0
self.D_max = 1.0
else:
bsq_path = self.output_path
# Use the ENVI driver, which writes BSQ by default
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("Cannot create an ENVI file: the ENVI driver is unavailable")
height, width, n_bands = corrected_array.shape
# Create the ENVI dataset (the .hdr file is generated automatically)
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
if dataset is None:
raise ValueError(f"Cannot create output file: {bsq_path}")
try:
# Set the geotransform and projection
if geotransform:
dataset.SetGeoTransform(geotransform)
if projection:
dataset.SetProjection(projection)
# Write each band (BSQ stores the bands sequentially)
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(corrected_array[:, :, i])
band.FlushCache()
finally:
dataset = None
# Check whether the .hdr file was created
hdr_path = bsq_path + '.hdr'
if os.path.exists(hdr_path):
print(f"Corrected image saved to: {bsq_path} (BSQ format)")
print(f"Header file saved to: {hdr_path}")
all_nir = np.concatenate(nir_samples)
all_d = np.concatenate(d_samples)
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
self.D_max = float(all_d.max())
del all_nir, all_d
print(f"[Kutser] Global R_min={self.R_min:.4f}, D_max={self.D_max:.4f}")
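The oxygen-absorption depth used above is the gap between the absorption band and the mean of its two shoulder bands. A minimal numeric sketch with made-up reflectances (`oxygen_depth` is an illustrative name); it mirrors the zero guard applied later in `_process_block`:

```python
import numpy as np


def oxygen_depth(lower, upper, oxy, d_max=None):
    """D = (lower + upper) / 2 - oxy, optionally normalised by a global D_max."""
    d = (lower + upper) * 0.5 - oxy
    if d_max is None:
        return d
    return np.zeros_like(d) if d_max == 0 else d / d_max


lower = np.array([0.10, 0.20], dtype=np.float32)
upper = np.array([0.12, 0.24], dtype=np.float32)
oxy = np.array([0.09, 0.12], dtype=np.float32)

d = oxygen_depth(lower, upper, oxy)                    # raw depth per pixel
d_norm = oxygen_depth(lower, upper, oxy, float(d.max()))
```

Because `D_max` here comes from a subsample, normalised values slightly above 1 are possible for pixels that were not sampled.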
def _compute_G_list(self):
"""
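Compute the global list of G values (one per band).
G = G_max - G_min, the spread between the extreme values over all water pixels.
Scans at full resolution but reads one band at a time, so each read holds about width × block_size pixels.
"""
print(f"[Kutser] Computing the global G list (n_bands={self.n_bands})...")
water_mask = self._load_water_mask()
# Initialise G_max and G_min to the opposite extremes
g_max = np.full(self.n_bands, -np.inf, dtype=np.float32)
g_min = np.full(self.n_bands, np.inf, dtype=np.float32)
# Scan block by block
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
block_height = y_end - y_off
# Read the current block of every band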
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
band = None
if water_mask is not None:
mask_block = water_mask[y_off:y_end, :]
mask_bool = mask_block.astype(bool)
if mask_bool.any():
band_masked = block[mask_bool]
g_max[b] = max(g_max[b], band_masked.max())
g_min[b] = min(g_min[b], band_masked.min())
else:
g_max[b] = max(g_max[b], block.max())
g_min[b] = min(g_min[b], block.min())
del block
self.G_list = (g_max - g_min).tolist()
print(f"[Kutser] G值范围: min={min(self.G_list):.4f}, max={max(self.G_list):.4f}")
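def _process_block(self, x_off, y_off, x_size, y_size):
"""
Process one block: read the data -> compute D -> correct band by band -> return the block result.
Returns:
list of np.ndarray: corrected bands, each of shape (y_size, x_size)
"""
# Read the bands around the oxygen absorption feature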
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
lower_block = lower_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
lower_band = None
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
upper_block = upper_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
upper_band = None
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
oxy_block = oxy_band_obj.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
oxy_band_obj = None
# Compute D
D = (lower_block + upper_block) * 0.5 - oxy_block
# Avoid division by zero
if self.D_max == 0:
D_normalized = np.zeros_like(D)
else:
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
print(f"警告: 未检测到.hdr文件但GDAL应该已自动创建")
D_normalized = D / self.D_max
# Release the temporary blocks
del lower_block, upper_block, oxy_block, D
# Get the water mask for the current block
water_mask = self._load_water_mask()
if water_mask is not None:
y_end = y_off + y_size
x_end = x_off + x_size
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
else:
mask_block = None
# Process band by band
corrected_bands = []
for b in range(self.n_bands):
band = self.dataset.GetRasterBand(b + 1)
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
band = None
G = self.G_list[b]
# Correction formula: R_corrected = R - G * D_normalized
corrected = R - G * D_normalized
# Apply the correction to water pixels only
if mask_block is not None:
corrected = np.where(mask_block, corrected, R)
corrected_bands.append(corrected)
del R
# Release the D block
del D_normalized
return corrected_bands
def get_corrected_bands(self):
"""
The correction is done in reflectance.
:return: list of corrected bands
"""
g_list = self.get_glint_G()
D = self.get_depth_D()
# Get the water mask, if one exists
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
Run the chunked processing. The corrected bands are written straight to output_path and an empty list is returned.
corrected_bands = []
for band_number in range(self.n_bands): #iterate across bands
G = g_list[band_number]
R = self.im_aligned[:,:,band_number]
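# Optimisation: compute directly to avoid creating intermediate arrays
corrected_band = R - G * D
# If a water mask exists, apply the correction to water pixels only
if water_mask_bool is not None:
corrected_band = np.where(water_mask_bool, corrected_band, R)
corrected_bands.append(corrected_band)
# If an output path was provided, save the result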
if self.output_path is not None:
self._save_corrected_bands(corrected_bands)
return corrected_bands
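Peak memory ≈ one single-band block plus a few auxiliary arrays ≈ 1000×1000×4 B × 3 ≈ 12 MB (the list returned by _process_block additionally holds one such array per band)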
"""
if self.output_path is None:
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
# Step 1: scan the global statistics (R_min, D_max)
self._scan_global_stats(sample_step=20)
# Step 2: compute the global G list
self._compute_G_list()
# Step 3: create the output file
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
base_path, ext = os.path.splitext(self.output_path)
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
# Get the georeferencing information
geotransform = self.dataset.GetGeoTransform()
projection = self.dataset.GetProjection()
driver = gdal.GetDriverByName('ENVI')
if driver is None:
raise ValueError("ENVI driver is not available")
out_dataset = driver.Create(bsq_path, self.width, self.height,
self.n_bands, gdal.GDT_Float32)
if out_dataset is None:
raise ValueError(f"无法创建输出文件: {bsq_path}")
out_dataset.SetGeoTransform(geotransform)
out_dataset.SetProjection(projection)
# Step 4: process block by block
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
total_blocks = n_blocks_x * n_blocks_y
print(f"[Kutser] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
block_idx = 0
for y_off in range(0, self.height, self.block_size):
y_end = min(y_off + self.block_size, self.height)
y_size = y_end - y_off
for x_off in range(0, self.width, self.block_size):
x_end = min(x_off + self.block_size, self.width)
x_size = x_end - x_off
block_idx += 1
print(f"[Kutser] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
# Process the current block
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
# Write to the output file
for b in range(self.n_bands):
out_band = out_dataset.GetRasterBand(b + 1)
out_band.WriteArray(corrected_bands[b], x_off, y_off)
out_band.FlushCache()
del corrected_bands
out_dataset = None
self.dataset = None
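# Check for the .hdr file; GDAL's ENVI driver by default replaces the extension instead of appending .hdr
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
if os.path.exists(hdr_path):
print(f"[Kutser] 校正完成,已保存至: {bsq_path}")
else:
print(f"[Kutser] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
# Return an empty list (the result has already been written to the file)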
return []
def __del__(self):
if getattr(self, "dataset", None) is not None:
self.dataset = None
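
The tiling loop in the chunked `get_corrected_bands` above can be tried in isolation. The following is a minimal, GDAL-free sketch; `iter_blocks` is a hypothetical helper name that does not appear in the commit, and it assumes the same row-major order and edge clipping as the nested loops in the diff.

```python
from typing import Iterator, Tuple


def iter_blocks(width: int, height: int, block_size: int) -> Iterator[Tuple[int, int, int, int]]:
    """Yield (x_off, y_off, x_size, y_size) windows that tile a raster.

    Hypothetical helper: mirrors the nested y/x loops in the diff and clips
    the last row and column of blocks to the raster edge.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    for y_off in range(0, height, block_size):
        y_size = min(y_off + block_size, height) - y_off
        for x_off in range(0, width, block_size):
            x_size = min(x_off + block_size, width) - x_off
            yield x_off, y_off, x_size, y_size


if __name__ == "__main__":
    blocks = list(iter_blocks(width=2500, height=1200, block_size=1000))
    # Ceiling division gives the block counts used for progress reporting.
    n_blocks_x = (2500 + 1000 - 1) // 1000
    n_blocks_y = (1200 + 1000 - 1) // 1000
    print(len(blocks), n_blocks_x * n_blocks_y)
```

Each window can be passed unchanged to `ReadAsArray(x_off, y_off, x_size, y_size)` and to `WriteArray(block, x_off, y_off)`, which is how the diff uses the same four values.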

File diff suppressed because it is too large


@ -19,6 +19,7 @@ from sklearn.cross_decomposition import PLSRegression
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.neural_network import MLPRegressor
from joblib import parallel_backend
# Third-party model imports
# try:
# import lightgbm as lgb
@ -42,11 +43,6 @@ import os
from src.preprocessing.spectral_Preprocessing import Preprocessing
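def _sklearn_parallel_n_jobs() -> int:
"""In packaged environments such as PyInstaller, joblib's loky backend launches the current exe again, which produces several processes with the same name."""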
return 1 if getattr(sys, "frozen", False) else -1
class WaterQualityModelingBatch:
"""Batch modeling class for water quality parameter retrieval"""
@ -642,26 +638,25 @@ class WaterQualityModelingBatch:
# Grid search: use KFold instead of StratifiedKFold
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
_n_jobs = _sklearn_parallel_n_jobs()
grid_search = GridSearchCV(
base_model,
config['params'],
cv=cv_strategy,
scoring=scoring,
n_jobs=_n_jobs,
n_jobs=-1,
verbose=1
)
# Fit the model on the training set
# with parallel_backend("threading", n_jobs=-1):
# grid_search.fit(X_train, y_train)
grid_search.fit(X_train, y_train)
# Get the best model
best_model = grid_search.best_estimator_
# Cross-validation evaluation (on the training set)
cv_scores = cross_val_score(
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
)
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring)
# Compute regression metrics on the training set
y_train_pred = best_model.predict(X_train)
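
As far as this hunk shows, the commit drops `_sklearn_parallel_n_jobs` and passes `n_jobs=-1` directly. For readers weighing that change, here is a sketch of the check the removed helper performed. `pick_n_jobs` and its `frozen` override parameter are hypothetical names added here so the logic can be exercised without building an executable; `sys.frozen` is the attribute PyInstaller sets on bundled applications.

```python
import sys
from typing import Optional


def pick_n_jobs(frozen: Optional[bool] = None) -> int:
    """Return 1 inside a frozen (e.g. PyInstaller) executable, else -1.

    `frozen` is a hypothetical override for testing; when it is None the
    value is read from sys.frozen, as the removed helper did.
    """
    if frozen is None:
        frozen = bool(getattr(sys, "frozen", False))
    return 1 if frozen else -1


if __name__ == "__main__":
    print(pick_n_jobs())
```

The returned value would be passed as `n_jobs` to `GridSearchCV` and `cross_val_score`, which is what the removed lines in the hunk did.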


@ -555,7 +555,13 @@ class WaterQualityInference:
print(f"输入数据形状: {spectra_processed.shape}")
try:
predictions = model.predict(spectra_processed)
# Clean NaN / Inf values so that models such as SVR do not raise errors
spectra_clean = np.nan_to_num(spectra_processed, nan=0.0, posinf=0.0, neginf=0.0)
if np.any(np.isnan(spectra_clean)) or np.any(np.isinf(spectra_clean)):
print("警告: 清洗后数据中仍存在 NaN/Inf已重置为 0")
spectra_clean = np.nan_to_num(spectra_clean, nan=0.0, posinf=0.0, neginf=0.0)
predictions = model.predict(spectra_clean)
print(f"预测完成,结果形状: {predictions.shape}")
print(f"预测值范围: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]")
print(f"预测值统计: 均值={np.mean(predictions):.4f}, 标准差={np.std(predictions):.4f}")
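
The cleaning step in this hunk relies on `np.nan_to_num(..., nan=0.0, posinf=0.0, neginf=0.0)`. The sketch below restates those element-wise semantics in plain Python, purely as an illustration; `clean_values` is a hypothetical name and is not part of the commit. It also shows that a single pass leaves no non-finite values behind, so a second pass over the cleaned data would not change anything.

```python
import math
from typing import Iterable, List


def clean_values(values: Iterable[float], fill: float = 0.0) -> List[float]:
    """Replace NaN, +Inf and -Inf with `fill`; finite values pass through."""
    return [v if math.isfinite(v) else fill for v in values]


if __name__ == "__main__":
    row = [0.12, float("nan"), float("inf"), float("-inf"), 0.30]
    cleaned = clean_values(row)
    # After one pass every value is finite.
    print(cleaned, all(math.isfinite(v) for v in cleaned))
```

Filling with 0 keeps the array shape intact for `model.predict`, at the cost of feeding the model a value it may never have seen during training; whether that is acceptable depends on the preprocessing used.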


@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""Business step layer modules"""
from src.core.steps.water_mask_step import WaterMaskStep
from src.core.steps.glint_detection_step import GlintDetectionStep
from src.core.steps.glint_removal_step import GlintRemovalStep
from src.core.steps.data_preparation_step import DataPreparationStep
from src.core.steps.modeling_step import ModelingStep
from src.core.steps.prediction_step import PredictionStep
from src.core.steps.mapping_step import MappingStep
__all__ = [
"WaterMaskStep",
"GlintDetectionStep",
"GlintRemovalStep",
"DataPreparationStep",
"ModelingStep",
"PredictionStep",
"MappingStep",
]


@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
"""
Data preparation steps
Includes step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
import pandas as pd
import numpy as np
class DataPreparationStep:
"""Data preparation steps"""
# ---- Step 4: process the CSV file ----
@staticmethod
def process_csv(
csv_path: str,
output_dir: Union[str, Path] = "./4_processed_data",
callback: Optional[Callable] = None,
) -> str:
"""Process the CSV file: filter it and remove outliers"""
from src.preprocessing.process_water_quality_data import process_water_quality_data
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "processed_data.csv")
def notify(status, msg=""):
if callback:
callback("步骤4", status, msg)
print("\n" + "=" * 80)
print("步骤4: 处理CSV文件筛选剔除异常值")
print("=" * 80)
step_start_time = time.time()
if Path(output_path).exists():
print(f"检测到已存在的处理后CSV文件直接使用: {output_path}")
notify("skipped", f"处理后的CSV文件已设置: {output_path}")
return output_path
process_water_quality_data(csv_path, output_path)
notify("completed", f"处理后的CSV文件已保存: {output_path}")
return output_path
# ---- Step 5: extract spectra at the training sample points ----
@staticmethod
def extract_training_spectra(
deglint_img_path: Optional[str] = None,
radius: int = 5,
source_epsg: int = 4326,
csv_path: Optional[str] = None,
boundary_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
water_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./5_training_spectra",
callback: Optional[Callable] = None,
) -> str:
"""Extract the mean spectrum from the deglinted image at each sampling point coordinate"""
from src.core.glint_removal.get_spectral import get_spectral_in_coor
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "training_spectra.csv")
def notify(status, msg=""):
if callback:
callback("步骤5", status, msg)
print("\n" + "=" * 80)
print("步骤5: 提取训练样本点的平均光谱")
print("=" * 80)
step_start_time = time.time()
if deglint_img_path is None:
raise ValueError("必须提供 deglint_img_path 参数")
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if Path(output_path).exists():
print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
notify("skipped", f"训练光谱数据已设置: {output_path}")
return output_path
# Make sure a water mask is available
final_boundary_path = boundary_path
if final_boundary_path is None and water_mask_path is not None:
final_boundary_path = water_mask_path
# [Added safeguard] Intercept a vector .shp mask and substitute the raster .dat generated in step 1
if final_boundary_path and str(final_boundary_path).lower().endswith('.shp'):
# Walk up the directory tree to find the .dat substitute under 1_water_mask
possible_dat = Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
if not possible_dat.exists() and output_path:
possible_dat = Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
if possible_dat.exists():
print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp自动切换为已生成的栅格掩膜: {possible_dat}")
final_boundary_path = str(possible_dat)
else:
print(f"⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")
flare_path = glint_mask_path
if flare_path:
print(f"光谱提取使用耀斑掩膜: {flare_path}")
get_spectral_in_coor(
deglint_img_path, csv_path, output_path,
radius=radius, flare_path=flare_path,
boundary_path=final_boundary_path, source_epsg=source_epsg
)
notify("completed", f"训练光谱数据已保存: {output_path}")
return output_path
# ---- Step 5.5: compute water quality spectral indices ----
@staticmethod
def calculate_water_quality_indices(
training_spectra_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
enabled: bool = True,
output_dir: Union[str, Path] = "./6_water_quality_indices",
callback: Optional[Callable] = None,
) -> Optional[str]:
"""Compute water quality spectral indices from the training spectra (using the band_math method)"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤5.5", status, msg)
print("\n" + "=" * 80)
print("步骤5.5: 计算水质光谱指数使用band_math方法")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过水质指数计算enabled=False")
notify("skipped", "跳过水质指数计算")
return None
if training_spectra_path is None:
raise ValueError("必须提供 training_spectra_path 参数")
if formula_csv_file is None:
raise ValueError("必须提供 formula_csv_file 参数")
if output_file:
output_path = str(Path(output_file))
else:
output_path = str(output_dir / "water_quality_indices.csv")
if Path(output_path).exists():
print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
notify("skipped", f"水质指数数据已设置: {output_path}")
return output_path
from src.utils.band_math import BandMathCalculator
calculator = BandMathCalculator(training_spectra_path)
result_df = calculator.process_formulas_from_csv(
formula_csv_file=formula_csv_file,
formula_names=formula_names,
output_file=output_path
)
if result_df is None:
raise ValueError("计算水质指数失败请检查公式CSV文件格式")
notify("completed", f"水质指数已保存: {output_path}")
return output_path
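
The `.shp` to `.dat` substitution in `extract_training_spectra` can be read as a small path-resolution rule. The sketch below isolates the first lookup only (relative to the deglinted image); the second fallback via `output_path` in the diff is left out for brevity. `resolve_mask_path` is a hypothetical helper name, while the directory name `1_water_mask` and the file name `water_mask_from_shp.dat` are taken from the diff.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_mask_path(mask_path: Optional[str], deglint_img_path: str,
                      dat_name: str = "water_mask_from_shp.dat") -> Optional[str]:
    """Swap a .shp mask for the rasterized .dat from step 1 when it exists.

    Hypothetical helper: anything that is not a .shp path is returned as-is,
    and a .shp path is kept when no .dat substitute is found.
    """
    if not mask_path or not str(mask_path).lower().endswith(".shp"):
        return mask_path
    # <work_dir>/3_deglint/<image>  ->  <work_dir>/1_water_mask/<dat_name>
    candidate = Path(deglint_img_path).parent.parent / "1_water_mask" / dat_name
    return str(candidate) if candidate.exists() else mask_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        (Path(root) / "1_water_mask").mkdir()
        (Path(root) / "1_water_mask" / "water_mask_from_shp.dat").write_bytes(b"")
        image = str(Path(root) / "3_deglint" / "deglint_kutser.bsq")
        print(resolve_mask_path("lake.shp", image))
```

Keeping the rule in one function would let steps 3 and 5 share it instead of repeating the lookup, though that refactoring is only a suggestion and is not something the commit does.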


@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
Step 2: glint region detection
Supported detection methods: otsu, zscore, percentile, iqr, adaptive, multi_band
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable
class GlintDetectionStep:
"""Glint region detection step"""
@staticmethod
def run(
img_path: str,
glint_wave: float = 750.0,
method: str = "otsu",
z_threshold: float = 2.5,
percentile: float = 95.0,
iqr_multiplier: float = 1.5,
window_size: int = 15,
multi_band_waves: Optional[List[float]] = None,
sub_method: str = "zscore",
weights: Optional[List[float]] = None,
max_area: Optional[int] = None,
buffer_size: Optional[int] = None,
water_mask_path: Optional[str] = None,
glint_dir: Union[str, Path] = "./2_glint",
callback: Optional[Callable] = None,
) -> str:
"""
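Run glint region detection.
Args:
img_path: path to the input image file
glint_wave: wavelength (nm) of the band used for glint detection
method: detection method ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
z_threshold: threshold for the Z-score method (default 2.5)
percentile: percentile threshold (default 95.0)
iqr_multiplier: IQR multiplier (default 1.5)
window_size: window size for adaptive thresholding (default 15)
multi_band_waves: list of wavelengths for the multi-band method, e.g. [750, 800, 850]
sub_method: sub-method used by the multi-band method (default 'zscore')
weights: list of weights for the multi-band method (None means equal weights)
max_area: maximum connected-component area in pixels; larger components are filtered out
buffer_size: shoreline buffer size in pixels, used to remove false mask pixels near the shore
water_mask_path: path to the water mask file (.dat format preferred)
glint_dir: working directory
callback: callback function
Returns:
Path to the glint mask file (.dat)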
"""
from src.utils.find_severe_glint_area import find_severe_glint_area
glint_dir = Path(glint_dir)
glint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤2", status, msg)
print("\n" + "=" * 80)
print("步骤2: 找到耀斑区域")
print("=" * 80)
step_start_time = time.time()
# Determine the water mask path
if water_mask_path is not None and Path(water_mask_path).exists():
final_water_mask_path = water_mask_path
else:
final_water_mask_path = None
output_path = str(glint_dir / "severe_glint_area.dat")
# Skip the step if the output file already exists
if Path(output_path).exists():
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
notify("skipped", f"耀斑掩膜已设置: {output_path}")
return output_path
# Build the dictionary of detection parameters
kwargs = {
"method": method,
"z_threshold": z_threshold,
"percentile": percentile,
"iqr_multiplier": iqr_multiplier,
"window_size": window_size,
}
if method == "multi_band":
if multi_band_waves is not None:
kwargs["multi_band_waves"] = multi_band_waves
if sub_method is not None:
kwargs["sub_method"] = sub_method
if weights is not None:
kwargs["weights"] = weights
if max_area is not None:
kwargs["max_area"] = max_area
if buffer_size is not None:
kwargs["buffer_size"] = buffer_size
glint_mask_path = find_severe_glint_area(
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
)
print(f"耀斑掩膜已生成: {glint_mask_path}")
print(f"使用检测方法: {method}")
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
return glint_mask_path
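
Every step class in this changeset defines the same small `notify` closure that forwards `(step_name, status, msg)` to an optional callback. A sketch of how that closure could be factored out is given below; `make_notifier` and the `StepCallback` alias are hypothetical names, and the three-argument callback shape is inferred from the calls visible in the diff.

```python
from typing import Callable, List, Optional, Tuple

# Assumed callback shape, inferred from callback("步骤2", status, msg) in the diff.
StepCallback = Callable[[str, str, str], None]


def make_notifier(step_name: str, callback: Optional[StepCallback]) -> Callable[..., None]:
    """Bind a step label to an optional callback, as the inline closures do."""
    def notify(status: str, msg: str = "") -> None:
        if callback:
            callback(step_name, status, msg)
    return notify


if __name__ == "__main__":
    log: List[Tuple[str, str, str]] = []
    notify = make_notifier("步骤2", lambda step, status, msg: log.append((step, status, msg)))
    notify("skipped", "mask already exists")
    notify("completed")
    print(log)
```

A GUI could pass a callback that updates a progress widget, while batch scripts pass `None` and get a no-op.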


@ -0,0 +1,375 @@
# -*- coding: utf-8 -*-
"""
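Step 3: glint removal
Methods named here: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
(GlintRemovalStep.run below dispatches only kutser, goodman, hedley and sugar and raises ValueError for the others).
Each method:
1. Prepares the water mask (a .shp is converted to .dat automatically)
2. Calls the corresponding algorithm class to run the processing
3. Copies the hdr file to the output image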
"""
import os
import time
from pathlib import Path
from typing import Optional, List, Union, Callable
import numpy as np
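def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
"""Rename the .bsq + .hdr pair written to the hard-coded location to the user-specified output_path.
An existing destination file is removed with os.remove and the source is then moved with os.rename.
Remove-then-rename is two separate operations, so the overwrite is not atomic; os.replace overwrites
in a single call, and either function may fail across filesystems. The resolve() comparison
short-circuits the case where source and destination are the same path, which would otherwise
delete the file before the rename.
Returns:
The dest_bsq path
"""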
src_bsq_p = Path(src_bsq)
src_hdr_p = Path(src_hdr)
dest_bsq_p = Path(dest_bsq)
dest_hdr_p = Path(dest_hdr)
if str(src_bsq_p.resolve()) == str(dest_bsq_p.resolve()):
return dest_bsq
if dest_bsq_p.exists():
os.remove(dest_bsq_p)
if dest_hdr_p.exists():
os.remove(dest_hdr_p)
if src_bsq_p.exists():
os.rename(src_bsq_p, dest_bsq_p)
if src_hdr_p.exists():
os.rename(src_hdr_p, dest_hdr_p)
return dest_bsq
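
For comparison with `_safe_rename`, the sketch below moves the same file pair with `os.replace`, which overwrites an existing destination file in one call. It is an alternative under stated assumptions and is not the commit's code: `move_pair` is a hypothetical name, and unlike `_safe_rename` it leaves an existing destination untouched when the source file is missing.

```python
import os
import tempfile
from pathlib import Path


def move_pair(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
    """Move a .bsq/.hdr pair, overwriting any existing destination files."""
    for src, dest in ((src_bsq, dest_bsq), (src_hdr, dest_hdr)):
        src_p, dest_p = Path(src), Path(dest)
        if not src_p.exists() or src_p.resolve() == dest_p.resolve():
            continue  # nothing to move, or source and destination coincide
        os.replace(src_p, dest_p)  # overwrites dest_p if it already exists
    return dest_bsq


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        src, src_h = os.path.join(d, "deglint_kutser.bsq"), os.path.join(d, "deglint_kutser.hdr")
        dst, dst_h = os.path.join(d, "out.bsq"), os.path.join(d, "out.hdr")
        Path(src).write_text("data")
        Path(src_h).write_text("header")
        print(move_pair(src, src_h, dst, dst_h), sorted(os.listdir(d)))
```

Like `os.rename`, `os.replace` may fail when source and destination are on different filesystems, so a copy-and-delete fallback would still be needed for that case.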
class GlintRemovalStep:
"""Glint removal step"""
@staticmethod
def run(
img_path: str,
method: str = "subtract_nir",
start_wave: Optional[float] = None,
end_wave: Optional[float] = None,
json_path: Optional[str] = None,
left_shoulder_wave: Optional[float] = None,
valley_wave: Optional[float] = None,
right_shoulder_wave: Optional[float] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
interpolated_img_path: Optional[str] = None,
interpolate_zeros: bool = False,
interpolation_method: str = "nearest",
enabled: bool = True,
# Kutser parameters
kutser_shp_path: Optional[str] = None,
oxy_band: int = 38,
lower_oxy: int = 36,
upper_oxy: int = 49,
nir_band: int = 47,
# Goodman parameters
nir_lower: int = 25,
nir_upper: int = 37,
goodman_A: float = 0.000019,
goodman_B: float = 0.1,
# Hedley parameters
hedley_shp_path: Optional[str] = None,
hedley_nir_band: int = 47,
# SUGAR parameters
sugar_bounds: Optional[List[tuple]] = None,
sugar_sigma: float = 1.0,
sugar_estimate_background: bool = True,
sugar_glint_mask_method: str = "cdf",
sugar_iter: Optional[int] = 3,
sugar_termination_thresh: float = 20.0,
# Internal helper functions (injectable for testing)
_get_image_geo_info=None,
_load_image_as_array=None,
_save_bands_as_image=None,
_copy_hdr_info=None,
_prepare_water_mask_for_algorithm=None,
_interpolate_zero_pixels_batch=None,
deglint_dir: Union[str, Path] = "./3_deglint",
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
output_path: Optional[str] = None,
) -> str:
"""
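Run the glint removal processing.
Args:
img_path: path to the input image file
method: glint removal method
... (the remaining parameters match step3_remove_glint on the main class)
Returns:
Path to the glint-corrected image file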
"""
from src.core.glint_removal.Kutser import Kutser
from src.core.glint_removal.Goodman import Goodman
from src.core.glint_removal.Hedley import Hedley
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
from src.core.utils.gdal_helper import (
get_image_geo_info as _default_get_geo,
load_image_as_array as _default_load,
save_bands_as_image as _default_save_bands,
copy_hdr_info as _default_copy_hdr,
)
from src.core.utils.mask_converter import (
prepare_water_mask_for_algorithm as _default_prepare,
)
# Use the injected functions, falling back to the defaults
if _get_image_geo_info is None:
_get_image_geo_info = _default_get_geo
if _load_image_as_array is None:
_load_image_as_array = _default_load
if _save_bands_as_image is None:
_save_bands_as_image = _default_save_bands
if _copy_hdr_info is None:
_copy_hdr_info = _default_copy_hdr
if _prepare_water_mask_for_algorithm is None:
_prepare_water_mask_for_algorithm = _default_prepare
deglint_dir = Path(deglint_dir)
deglint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤3", status, msg)
print("\n" + "=" * 80)
print("步骤3: 去除耀斑")
print("=" * 80)
step_start_time = time.time()
# Normalize the method name
raw_method = str(method).lower()
if "kutser" in raw_method:
method = "kutser"
elif "goodman" in raw_method:
method = "goodman"
elif "hedley" in raw_method:
method = "hedley"
elif "sugar" in raw_method:
method = "sugar"
# If the step is disabled, return the original image unchanged
if not enabled:
print("已设置跳过去除耀斑enabled=False将直接使用原始影像。")
notify("skipped", "跳过去耀斑,使用原始影像")
return img_path
# ---- Determine the water mask ----
final_water_mask = water_mask
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
# A .shp mask is replaced by the .dat raster automatically
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask).exists():
print(f"检测到输入掩膜为 .shp自动替换为栅格掩膜: {dat_mask}")
final_water_mask = dat_mask
if final_water_mask is None:
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask_default).exists():
final_water_mask = dat_mask_default
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
# ---- Step 3.1: interpolate zero-valued pixels ----
if interpolate_zeros:
print("\n" + "-" * 80)
print("步骤3.1: 对0值像素进行插值")
print("-" * 80)
interp_start_time = time.time()
if _interpolate_zero_pixels_batch is None:
from src.core.algorithms.interpolation.interpolator import (
interpolate_zero_pixels_batch as _interp_batch,
)
_interpolate_zero_pixels_batch = _interp_batch
interp_result, _ = _interpolate_zero_pixels_batch(
img_path=img_path,
interpolation_method=interpolation_method,
output_path=None,
water_mask=final_water_mask,
deglint_dir=str(deglint_dir),
callback_progress=lambda msg: print(f" {msg}"),
)
img_path = interp_result
interp_end_time = time.time()
print(f"插值完成,使用插值后的影像: {img_path}")
# ---- Get the image information ----
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
print(f"影像尺寸: {width} x {height} x {n_bands}")
mask_for_algorithm = _prepare_water_mask_for_algorithm(
final_water_mask, (height, width), geotransform, projection, img_path
)
# ==================== Kutser ====================
if method == "kutser":
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
hardcoded_bsq = str(deglint_dir / "deglint_kutser.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
# Normalize the user-specified output_path to a .bsq path
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
kutser = Kutser(
img_path,
shp_path=None,
oxy_band=oxy_band,
lower_oxy=lower_oxy,
upper_oxy=upper_oxy,
NIR_band=nir_band,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
kutser.get_corrected_bands()
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"Kutser算法未生成输出文件: {hardcoded_bsq}")
# ==================== Goodman ====================
elif method == "goodman":
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
hardcoded_bsq = str(deglint_dir / "deglint_goodman.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
goodman = Goodman(
img_path,
NIR_lower=nir_lower,
NIR_upper=nir_upper,
A=goodman_A,
B=goodman_B,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
corrected_bands = goodman.get_corrected_bands()
if not Path(hardcoded_bsq).exists():
_save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
_copy_hdr_info(img_path, hardcoded_bsq)
del corrected_bands
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
# ==================== Hedley ====================
elif method == "hedley":
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
hardcoded_bsq = str(deglint_dir / "deglint_hedley.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
hedley = Hedley(
img_path,
shp_path=None,
NIR_band=hedley_nir_band,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
hedley.get_corrected_bands()
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"Hedley算法未生成输出文件: {hardcoded_bsq}")
# ==================== SUGAR ====================
elif method == "sugar":
glint_method_raw = str(sugar_glint_mask_method).lower()
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
sugar_glint_mask_method_fixed = "cdf"
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
sugar_glint_mask_method_fixed = "otsu"
else:
sugar_glint_mask_method_fixed = "cdf"
print(
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
)
hardcoded_bsq = str(deglint_dir / "deglint_sugar.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
if sugar_bounds is None:
sugar_bounds = [(1, 2)]
correction_iterative(
img_path,
iter=sugar_iter,
bounds=sugar_bounds,
estimate_background=sugar_estimate_background,
glint_mask_method=sugar_glint_mask_method_fixed,
termination_thresh=sugar_termination_thresh,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"SUGAR算法未生成输出文件: {hardcoded_bsq}")
else:
raise ValueError(
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
)
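
Each algorithm branch above derives the final path with `output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')`, which rewrites every occurrence of those substrings, including ones inside directory names. The sketch below shows a suffix-only variant for comparison. `to_bsq_path` is a hypothetical helper; it assumes forward-slash paths (hence `PurePosixPath`) and, unlike the chained `replace`, matches the extension case-insensitively.

```python
from pathlib import PurePosixPath


def to_bsq_path(output_path: str) -> str:
    """Swap only a trailing .dat/.tif extension for .bsq."""
    p = PurePosixPath(output_path)
    if p.suffix.lower() in (".dat", ".tif"):
        return str(p.with_suffix(".bsq"))
    return output_path


if __name__ == "__main__":
    tricky = "/data.dat.d/out.tif"
    # Chained str.replace also rewrites the directory name.
    print(tricky.replace(".dat", ".bsq").replace(".tif", ".bsq"))
    # The suffix-only variant touches the file extension alone.
    print(to_bsq_path(tricky))
```

On Windows-style paths `pathlib.Path` (or `PureWindowsPath`) would be the natural substitute; the same helper could also replace the four copies of the path-normalization block in the Kutser, Goodman, Hedley and SUGAR branches.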


@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""
Mapping step
Includes step9_generate_distribution_map
"""
import time
from pathlib import Path
from typing import Optional, Union, Callable
class MappingStep:
"""Mapping step"""
@staticmethod
def generate_distribution_map(
prediction_csv_path: str,
boundary_shp_path: str,
output_image_path: Optional[str] = None,
resolution: float = 30,
input_crs: str = "EPSG:32651",
output_crs: str = "EPSG:4326",
show_sample_points: bool = False,
base_map_tif: Optional[str] = None,
use_distance_diffusion: bool = True,
max_diffusion_distance: Optional[float] = None,
diffusion_power: float = 2,
diffusion_n_neighbors: int = 15,
cmap: Optional[str] = None,
expand_ratio: float = 0.05,
output_dir: Union[str, Path] = "./14_visualization",
callback: Optional[Callable] = None,
) -> str:
"""
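Produce a distribution map of a water quality parameter by interpolating the retrieved parameter values at the sampling point coordinates.
Args:
prediction_csv_path: path to the prediction CSV file (first two columns are longitude and latitude, third column is the predicted value)
boundary_shp_path: path to the boundary shapefile
output_image_path: output image path (generated automatically if None)
resolution: interpolation grid resolution (metres)
input_crs: input coordinate reference system
output_crs: output coordinate reference system
show_sample_points: whether to draw the sampling points on the map
base_map_tif: path to the base map TIF
use_distance_diffusion: whether to enable distance diffusion to fill in the boundary
max_diffusion_distance: maximum distance for distance diffusion (metres)
diffusion_power: power parameter for distance diffusion
diffusion_n_neighbors: number of nearest neighbours for distance diffusion
cmap: colormap name (None means detect automatically)
expand_ratio: boundary expansion ratio (between 0 and 1)
output_dir: output directory
callback: callback function
Returns:
Path to the distribution map file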
"""
from src.postprocessing.map import ContentMapper
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤9", status, msg)
print("\n" + "=" * 80)
print("步骤9: 生成水质参数可视化分布图")
print("=" * 80)
step_start_time = time.time()
if output_image_path is None:
csv_name = Path(prediction_csv_path).stem
output_image_path = str(output_dir / f"{csv_name}_distribution.png")
if Path(output_image_path).exists():
print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
notify("skipped", f"可视化分布图已设置: {output_image_path}")
return output_image_path
mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
mapper_kwargs = {
"resolution": resolution,
"show_sample_points": show_sample_points,
"use_distance_diffusion": use_distance_diffusion,
"diffusion_power": diffusion_power,
"diffusion_n_neighbors": diffusion_n_neighbors,
"expand_ratio": expand_ratio,
}
optional_kwargs = {
"base_map_tif": base_map_tif,
"max_diffusion_distance": max_diffusion_distance,
"cmap": cmap,
}
mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})
mapper.process_data(
csv_file=prediction_csv_path,
shp_file=boundary_shp_path,
output_file=output_image_path,
**mapper_kwargs,
)
notify("completed", f"可视化分布图已保存: {output_image_path}")
return output_image_path
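
`generate_distribution_map` forwards optional arguments only when they are set, so that `ContentMapper.process_data` keeps its own defaults otherwise. The pattern can be sketched on its own as follows; `merge_optional` is a hypothetical name, and the keys in the demo are taken from the diff.

```python
from typing import Any, Dict


def merge_optional(base: Dict[str, Any], optional: Dict[str, Any]) -> Dict[str, Any]:
    """Return base plus only those optional entries whose value is not None."""
    merged = dict(base)  # copy, so the caller's dict is left unchanged
    merged.update({k: v for k, v in optional.items() if v is not None})
    return merged


if __name__ == "__main__":
    kwargs = merge_optional(
        {"resolution": 30, "expand_ratio": 0.05},
        {"base_map_tif": None, "max_diffusion_distance": 0.0, "cmap": "viridis"},
    )
    print(kwargs)
```

The filter tests `is not None` and not truthiness, so a legitimate falsy value such as `0.0` is still forwarded.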


@ -0,0 +1,497 @@
# -*- coding: utf-8 -*-
"""
Modeling steps
Includes step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
"""
import time
import json
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
import pandas as pd
import numpy as np
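# ============================================================
# Reverse mappings from localized (Chinese) text to English (UI checkbox display text -> underlying algorithm key)
# ============================================================
# Model names: Chinese (abbreviation) -> English key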
MODEL_NAME_MAP = {
"多元线性回归 (MLR)": "LinearRegression",
"岭回归 (Ridge)": "Ridge",
"套索回归 (Lasso)": "Lasso",
"弹性网络 (ElasticNet)": "ElasticNet",
"偏最小二乘 (PLSR)": "PLS",
"决策树 (CART)": "DecisionTree",
"随机森林 (RF)": "RF",
"极端随机树 (ET)": "ExtraTrees",
"极值梯度提升 (XGBoost)": "XGBoost",
"轻量梯度提升 (LightGBM)": "LightGBM",
"类别梯度提升 (CatBoost)": "CatBoost",
"梯度提升树 (GBDT)": "GradientBoosting",
"自适应提升 (AdaBoost)": "AdaBoost",
"支持向量回归 (SVR)": "SVR",
"K近邻回归 (KNN)": "KNN",
"多层感知机 (BP神经网络)": "MLP",
}
# Preprocessing methods: every expected Chinese variant -> canonical key
PREPROC_NAME_MAP = {
# No preprocessing
"无 (None)": "None",
"None": "None",
# MMS
"最小-最大归一化 (MMS)": "MMS",
"MMS": "MMS",
# SS
"标度化 (SS)": "SS",
"SS": "SS",
# SNV
"标准正态变换 (SNV)": "SNV",
"SNV": "SNV",
# MA
"移动平均 (MA)": "MA",
"MA": "MA",
# SG
"Savitzky-Golay (SG)": "SG",
"SG": "SG",
# MSC
"多元散射校正 (MSC)": "MSC",
"MSC": "MSC",
# D1
"一阶导数 (D1)": "D1",
"D1": "D1",
# D2
"二阶导数 (D2)": "D2",
"D2": "D2",
# DT
"去趋势 (DT)": "DT",
"DT": "DT",
# CT
"中心化 (CT)": "CT",
"CT": "CT",
}
# Data split methods: every expected Chinese variant -> canonical key
SPLIT_NAME_MAP = {
"SPXY 算法 (考量X-Y空间)": "spxy",
"spxy": "spxy",
"KS 算法 (考量X空间)": "ks",
"ks": "ks",
"随机划分 (Random)": "random",
"random": "random",
}
def _normalize_model_names(model_names: List[str]) -> List[str]:
"""Clean the model name list: map localized display text back to the English keys"""
result = []
for name in model_names:
if name in MODEL_NAME_MAP:
result.append(MODEL_NAME_MAP[name])
else:
# Already an English key; keep it as-is
result.append(name)
return result
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
"""Clean the preprocessing method list: map localized display text back to the canonical keys"""
result = []
for method in methods:
if method in PREPROC_NAME_MAP:
result.append(PREPROC_NAME_MAP[method])
else:
# Already a canonical key; keep it as-is
result.append(method)
return result
def _normalize_split_methods(methods: List[str]) -> List[str]:
"""Clean the data split method list: map localized display text back to the canonical keys"""
result = []
for method in methods:
if method in SPLIT_NAME_MAP:
result.append(SPLIT_NAME_MAP[method])
else:
# Already a canonical key; keep it as-is
result.append(method)
return result
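
The three `_normalize_*` functions above share one pattern: look the label up in a mapping and fall back to the label itself. A compact sketch of that pattern is shown below. `normalize_names` is a hypothetical generic helper, and `MODEL_NAME_MAP_EXCERPT` copies just three entries from `MODEL_NAME_MAP` for illustration.

```python
from typing import Dict, List

# Three entries copied from MODEL_NAME_MAP in the diff, for illustration only.
MODEL_NAME_MAP_EXCERPT: Dict[str, str] = {
    "岭回归 (Ridge)": "Ridge",
    "随机森林 (RF)": "RF",
    "支持向量回归 (SVR)": "SVR",
}


def normalize_names(names: List[str], mapping: Dict[str, str]) -> List[str]:
    """Map UI display labels to internal keys; unknown labels pass through."""
    return [mapping.get(name, name) for name in names]


if __name__ == "__main__":
    print(normalize_names(["岭回归 (Ridge)", "RF", "支持向量回归 (SVR)"], MODEL_NAME_MAP_EXCERPT))
```

Because unknown labels pass through silently, a typo in a UI label would reach the modeling layer unchanged; validating the result against the set of supported keys would be one way to catch that early.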
class ModelingStep:
"""Modeling steps"""
# ---- Step 6: train machine learning models ----
@staticmethod
def train_models(
feature_start_column: str = "374.285004",
preprocessing_methods: Optional[List[str]] = None,
model_names: Optional[List[str]] = None,
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
training_csv_path: Optional[str] = None,
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
callback: Optional[Callable] = None,
_report_generator=None,
) -> str:
"""Build machine learning models from the sampling point spectra and measured values"""
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤6", status, msg)
print("\n" + "=" * 80)
print("步骤6: 训练机器学习模型")
print("=" * 80)
step_start_time = time.time()
if training_csv_path is None:
raise ValueError("必须提供 training_csv_path 参数")
# Check whether the model directory already contains models
if output_dir.exists() and any(output_dir.iterdir()):
has_models = False
for item in output_dir.iterdir():
if item.is_dir():
model_files = (
list(item.glob("*.pkl"))
+ list(item.glob("*.joblib"))
+ list(item.glob("*.h5"))
)
if model_files:
has_models = True
break
if has_models:
print(f"检测到已存在的模型文件,直接使用: {output_dir}")
notify("skipped", f"模型目录已设置: {output_dir}")
return str(output_dir)
if preprocessing_methods is None:
preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
if model_names is None:
model_names = ["SVR", "RF", "Ridge", "Lasso"]
if split_methods is None:
split_methods = ["spxy", "ks", "random"]
# ---- 汉化清洗:将 UI 传来的中文/混合名称转换为底层英文键名 ----
preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
model_names = _normalize_model_names(model_names)
split_methods = _normalize_split_methods(split_methods)
print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
print(f"[参数清洗] 模型名称: {model_names}")
print(f"[参数清洗] 划分方法: {split_methods}")
modeler = WaterQualityModelingBatch(str(output_dir))
modeler.train_models_batch(
csv_path=training_csv_path,
feature_start_column=feature_start_column,
preprocessing_methods=preprocessing_methods,
model_names=model_names,
split_methods=split_methods,
cv_folds=cv_folds,
)
print(f"模型训练完成,结果保存在: {output_dir}")
if _report_generator is not None:
try:
summary_path = _report_generator.generate_training_summary(str(output_dir))
print(f"训练摘要报告已生成: {summary_path}")
except Exception as e:
print(f"生成训练摘要报告时出错: {e}")
notify("completed", f"模型训练完成: {output_dir}")
return str(output_dir)
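Step 6 skips training when any immediate subdirectory of the output directory already holds a `.pkl`, `.joblib`, or `.h5` file. The check can be sketched in isolation as below. `has_trained_models` and `demo` are hypothetical names, and the directory layout in the demo is made up for illustration.

```python
import tempfile
from pathlib import Path

MODEL_PATTERNS = ("*.pkl", "*.joblib", "*.h5")


def has_trained_models(output_dir, patterns=MODEL_PATTERNS) -> bool:
    """Return True if any immediate subdirectory contains a model file."""
    root = Path(output_dir)
    if not root.is_dir():
        return False
    return any(
        next(sub.glob(pattern), None) is not None
        for sub in root.iterdir() if sub.is_dir()
        for pattern in patterns
    )


def demo():
    """Check an empty target folder, then the same folder with one model."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "chl_a").mkdir()
        before = has_trained_models(root)
        (root / "chl_a" / "SVR.joblib").touch()
        return before, has_trained_models(root)


if __name__ == "__main__":
    print(demo())
```

A single leftover model file is enough to skip the whole step, so a partially completed run is not resumed.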
# ---- Step 6.5: 非经验统计回归模型训练 ----
@staticmethod
def train_non_empirical_models(
csv_path: Optional[str] = None,
preprocessing_methods: Optional[List[str]] = None,
algorithms: Optional[List[str]] = None,
value_cols: Union[int, Dict[str, int]] = 0,
spectral_start_col: int = 1,
spectral_end_col: Optional[int] = None,
window: int = 5,
output_dir: Optional[str] = None,
enabled: bool = True,
callback: Optional[Callable] = None,
) -> Dict[str, str]:
"""非经验统计回归模型训练"""
def notify(status, msg=""):
if callback:
callback("步骤6.5", status, msg)
print("\n" + "=" * 80)
print("步骤6.5: 非经验统计回归模型训练")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("Non-empirical model training is disabled (enabled=False); skipping")
notify("skipped", "Skipped non-empirical model training")
return {}
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if output_dir is not None:
non_empirical_dir = Path(output_dir)
else:
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
non_empirical_dir.mkdir(parents=True, exist_ok=True)
if preprocessing_methods is None:
preprocessing_methods = ["None"]
if algorithms is None:
algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]
if isinstance(value_cols, int):
value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
elif isinstance(value_cols, dict):
value_cols_dict = value_cols
else:
raise ValueError("value_cols 参数必须是整数或字典")
if spectral_end_col is None:
df = pd.read_csv(csv_path)
spectral_end_col = len(df.columns) - 1
all_model_results = {}
for preprocess in preprocessing_methods:
preprocess_dir = non_empirical_dir / preprocess
preprocess_dir.mkdir(parents=True, exist_ok=True)
processed_csv_path = _apply_preprocessing_internal(
csv_path, preprocess, preprocess_dir, spectral_start_col
)
for algorithm in algorithms:
algorithm_value_col = value_cols_dict[algorithm]
print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")
model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")
if Path(model_outpath).exists():
print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
continue
try:
from src.core.non_empirical_model_correction import run_model_correction
run_model_correction(
algorithm=algorithm,
csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
value_col=algorithm_value_col,
spectral_start=spectral_start_col,
spectral_end=spectral_end_col,
model_info_outpath=model_outpath,
window=window,
)
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
print(f"模型训练完成: {model_outpath}")
except Exception as e:
print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
continue
summary_path = _generate_non_empirical_summary(all_model_results, non_empirical_dir)
notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
return all_model_results
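`value_cols` accepts either one column index shared by all algorithms or a per-algorithm dictionary. In the loop above, a dictionary that lacks an entry for some algorithm fails at the lookup, which sits outside the `try` block. The sketch below expands and validates the argument up front. `expand_value_cols` is a hypothetical helper, and the early validation is an assumption about desired behaviour, not something the commit does.

```python
from typing import Dict, List, Union


def expand_value_cols(
    value_cols: Union[int, Dict[str, int]], algorithms: List[str]
) -> Dict[str, int]:
    """Return a {algorithm: column index} mapping covering every algorithm."""
    if isinstance(value_cols, int):
        # One shared measured-value column for all algorithms.
        return {name: value_cols for name in algorithms}
    if isinstance(value_cols, dict):
        missing = [name for name in algorithms if name not in value_cols]
        if missing:
            raise ValueError(f"value_cols has no entry for: {missing}")
        return dict(value_cols)
    raise ValueError("value_cols must be an int or a dict")


if __name__ == "__main__":
    print(expand_value_cols(0, ["chl_a", "tn"]))
```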
# ---- Step 6.75: 自定义回归分析 ----
@staticmethod
def custom_regression(
csv_path: Optional[str] = None,
x_columns: Optional[Union[str, List[str]]] = None,
y_columns: Optional[Union[str, List[str]]] = None,
methods: Union[str, List[str]] = "all",
output_dir: Optional[str] = None,
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Optional[str]:
"""使用自定义回归方法分析指标与目标参数之间的关系"""
def notify(status, msg=""):
if callback:
callback("步骤6.75", status, msg)
print("\n" + "=" * 80)
print("步骤6.75: 自定义回归分析")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("Custom regression analysis is disabled (enabled=False); skipping")
notify("skipped", "跳过自定义回归分析")
return None
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if y_columns is None:
raise ValueError("必须指定 y_columns")
if x_columns is None:
raise ValueError("必须指定 x_columns")
if isinstance(x_columns, str):
x_columns = [x_columns]
if isinstance(y_columns, str):
y_columns = [y_columns]
df = pd.read_csv(csv_path)
missing_x = [col for col in x_columns if col not in df.columns]
missing_y = [col for col in y_columns if col not in df.columns]
if missing_x:
raise ValueError(f"自变量列不存在: {missing_x}")
if missing_y:
raise ValueError(f"因变量列不存在: {missing_y}")
if output_dir is None:
custom_regression_dir = Path(work_dir) / "9_Custom_Regression_Modeling"
else:
custom_regression_dir = Path(work_dir) / output_dir
custom_regression_dir.mkdir(parents=True, exist_ok=True)
from src.core.modeling.regression import SingleVariableRegressionAnalysis
analyzer = SingleVariableRegressionAnalysis()
analyzer.batch_single_variable_regression(
data=df,
x_columns=x_columns,
y_columns=y_columns,
methods=methods,
output_dir=str(custom_regression_dir),
)
notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
return str(custom_regression_dir)
# ============================================================
# 内部辅助函数(供 ModelingStep 内部使用)
# ============================================================
def _apply_preprocessing_internal(
    csv_path: str,
    preprocess_method: str,
    output_dir: Path,
    spectral_start_col: int = 4,
) -> str:
    """Apply a spectral preprocessing method to the CSV data (internal helper)."""
    raw_p = str(preprocess_method).strip().lower()
    # Resolve UI labels and aliases to canonical keys. Specific aliases are tested
    # before general ones: "标准正态" (SNV) contains "标准" (SS), so SNV comes first.
    # An empty name is treated as "None".
    if raw_p in ("none", "") or "跳过" in raw_p:
        preprocess_method = "None"
    elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
        preprocess_method = "MMS"
    elif raw_p == "snv" or "标准正态" in raw_p:
        preprocess_method = "SNV"
    elif raw_p == "ss" or "标准" in raw_p:
        preprocess_method = "SS"
    elif raw_p == "ma" or "移动" in raw_p:
        preprocess_method = "MA"
    elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
        preprocess_method = "SG"
    elif raw_p == "msc" or "多元散射" in raw_p:
        preprocess_method = "MSC"
    elif raw_p == "d1" or "一阶导数" in raw_p:
        preprocess_method = "D1"
    elif raw_p == "d2" or "二阶导数" in raw_p:
        preprocess_method = "D2"
    elif raw_p == "dt" or "去趋势" in raw_p:
        # "去趋势" (detrend) is DT in PREPROC_NAME_MAP, not CT.
        preprocess_method = "DT"
    elif raw_p == "ct" or "中心化" in raw_p:
        preprocess_method = "CT"
    if preprocess_method == "None":
        return csv_path
    output_path = str(output_dir / f"preprocessed_{preprocess_method}.csv")
    if Path(output_path).exists():
        print(f"Found an existing preprocessed file, reusing it: {output_path}")
        return output_path
    df = pd.read_csv(csv_path)
    non_spectral_cols = df.iloc[:, :spectral_start_col]
    spectral_data = df.iloc[:, spectral_start_col:]
    from src.preprocessing.spectral_Preprocessing import Preprocessing
    save_path = None
    if preprocess_method == "SS":
        # Save the fitted scaler under the supervised-model directory.
        models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
        models_dir.mkdir(parents=True, exist_ok=True)
        save_path = str(models_dir / "scaler_params.pkl")
        print(f"SS preprocessing: the scaler will be saved to {save_path}")
    processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
    if not isinstance(processed_spectral, pd.DataFrame):
        # Wrap raw arrays so the column names and row index are preserved.
        processed_spectral = pd.DataFrame(
            processed_spectral, columns=spectral_data.columns, index=spectral_data.index
        )
    processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
    processed_df.to_csv(output_path, index=False)
    print(f"Preprocessing finished: {output_path}")
    return output_path
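Substring-based alias matching is order-sensitive: 标准 is a substring of 标准正态, and short codes can hide inside longer aliases ("ma" occurs in "minmax"). One way to make this explicit is a rule table that separates exact codes from substring fragments. This is a sketch only: `RULES` and `resolve_preprocess_name` are hypothetical names, and the fragments are taken from the labels and aliases that appear in this file.

```python
from typing import List, Set, Tuple

# (canonical key, exact codes, substring fragments); order matters.
RULES: List[Tuple[str, Set[str], Tuple[str, ...]]] = [
    ("None", {"none"}, ("跳过",)),
    ("MMS", {"mms"}, ("minmax", "最大最小")),
    ("SNV", {"snv"}, ("标准正态",)),  # must precede SS, whose fragment is a prefix
    ("SS", {"ss"}, ("标准",)),
    ("MA", {"ma"}, ("移动",)),
    ("SG", {"sg"}, ("savitzky", "平滑")),
    ("MSC", {"msc"}, ("多元散射",)),
    ("D1", {"d1"}, ("一阶导数",)),
    ("D2", {"d2"}, ("二阶导数",)),
    ("DT", {"dt"}, ("去趋势",)),
    ("CT", {"ct"}, ("中心化",)),
]


def resolve_preprocess_name(label: str) -> str:
    """Return the canonical key for a label, or the label itself if unknown."""
    raw = str(label).strip().lower()
    for key, codes, fragments in RULES:
        # Codes match exactly; only the longer fragments use substring tests.
        if raw in codes or any(fragment in raw for fragment in fragments):
            return key
    return label


if __name__ == "__main__":
    for name in ("去趋势 (DT)", "minmax", "MA"):
        print(name, "->", resolve_preprocess_name(name))
```

Keeping the short codes as exact matches avoids accidental hits, and the table order documents which alias wins when fragments overlap.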
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
"""生成非经验模型训练结果汇总CSV"""
summary_path = str(output_dir / "non_empirical_models_summary.csv")
summary_data = []
for model_key, model_path in model_results.items():
try:
parts = model_key.split("_")
preprocess_method = parts[0]
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
with open(model_path, "r", encoding="utf-8") as f:
model_info = json.load(f)
accuracy_list = model_info.get("accuracy", [])
summary_row = {
"Preprocessing Method": preprocess_method,
"Algorithm Name": algorithm_name,
"Model Type": model_info.get("model_type", ""),
"Coefficient Count": len(model_info.get("model_info", [])),
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
"Sample Count": len(model_info.get("long", [])),
"Model File": model_path,
}
coefficients = model_info.get("model_info", [])
for i, coeff in enumerate(coefficients[:5]):
summary_row[f"系数_{i+1}"] = coeff
summary_data.append(summary_row)
except Exception as e:
print(f"读取模型文件 {model_path} 时出错: {e}")
continue
if summary_data:
df_summary = pd.DataFrame(summary_data)
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
print(f"汇总文件已生成: {summary_path}")
else:
print("警告: 没有有效的模型数据可汇总")
summary_path = ""
return summary_path
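The summary builder relies on two small conventions: a model key has the form `<preprocess>_<algorithm>`, where the algorithm name may itself contain underscores, and an empty accuracy list is reported as zeros. A minimal sketch of both follows. `split_model_key` and `summarize_accuracy` are hypothetical helpers, and the sketch assumes the preprocessing key never contains an underscore.

```python
from statistics import fmean
from typing import Dict, List, Tuple


def split_model_key(model_key: str) -> Tuple[str, str]:
    """Split "SNV_chl_a" into ("SNV", "chl_a") at the first underscore."""
    preprocess, _, algorithm = model_key.partition("_")
    return preprocess, algorithm


def summarize_accuracy(accuracy: List[float]) -> Dict[str, float]:
    """Mean/min/max of the accuracy list, or zeros when the list is empty."""
    if not accuracy:
        return {"mean": 0, "min": 0, "max": 0}
    return {"mean": fmean(accuracy), "min": min(accuracy), "max": max(accuracy)}


if __name__ == "__main__":
    print(split_model_key("SNV_chl_a"))
    print(summarize_accuracy([80.0, 90.0]))
```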

@@ -0,0 +1,350 @@
# -*- coding: utf-8 -*-
"""
Prediction steps.
Contains step7_generate_sampling_points, step8_predict_water_quality,
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
class PredictionStep:
"""预测步骤"""
# ---- Step 7: 生成采样点并提取光谱 ----
@staticmethod
def generate_sampling_points(
deglint_img_path: Optional[str] = None,
interval: int = 50,
sample_radius: int = 5,
chunk_size: int = 1000,
water_mask_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./10_sampling",
callback: Optional[Callable] = None,
) -> str:
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
from pathlib import Path
from src.utils.sampling import get_spectral_sampling_points_chunked
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "sampling_spectra.csv")
def notify(status, msg=""):
if callback:
callback("步骤7", status, msg)
print("\n" + "=" * 80)
print("步骤7: 生成预测采样点并提取光谱")
print("=" * 80)
step_start_time = time.time()
if deglint_img_path is None:
raise ValueError("必须提供 deglint_img_path 参数")
# 1. 初始归一化与安全转换
original_path = Path(deglint_img_path)
final_deglint_path = original_path
# 2. 智能回溯探测:如果当前路径不存在,或者后缀是前端死板的 .dat
if not final_deglint_path.exists() or final_deglint_path.suffix.lower() == '.dat':
print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")
# 定位到预期的 3_deglint 根目录
possible_dir = original_path.parent
if possible_dir.name != '3_deglint' and Path(output_path).parent.parent.exists():
possible_dir = Path(output_path).parent.parent / "3_deglint"
if possible_dir.exists():
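# Collect every .bsq file that actually exists in this directory (covers the hard-coded goodman/sugar/kutser/hedley outputs).
# Sort the matches so the choice is deterministic; glob order depends on the filesystem.
existing_bsqs = sorted(possible_dir.glob("*.bsq"))
if existing_bsqs:
final_deglint_path = existing_bsqs[0]
print(f"💡 Found the actual deglinted image: {final_deglint_path}")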
else:
final_deglint_path = original_path.with_suffix('.bsq')
else:
final_deglint_path = original_path.with_suffix('.bsq')
deglint_img_str = str(final_deglint_path)
if Path(output_path).exists():
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
notify("skipped", f"采样点光谱数据已设置: {output_path}")
return output_path
glint_mask_to_use = glint_mask_path
if glint_mask_to_use is None:
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
# 传递极度安全的 deglint_img_str 进底层
get_spectral_sampling_points_chunked(
deglint_img_str, water_mask_path, glint_mask_to_use,
output_path, interval, sample_radius, chunk_size
)
notify("completed", f"采样点光谱数据已保存: {output_path}")
return output_path
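Step 7 treats a missing path or a `.dat` placeholder as a cue to look for a real `.bsq` product in the deglint directory, and otherwise falls back to the same path with a `.bsq` suffix. A reduced sketch of that fallback is below. `resolve_deglint_image` and `demo` are hypothetical names, the file names are made up, and sorting the matches is an assumption added to keep the result reproducible.

```python
import tempfile
from pathlib import Path


def resolve_deglint_image(requested, search_dir) -> Path:
    """Return the requested image, or a real .bsq product if it is a placeholder."""
    requested = Path(requested)
    if requested.exists() and requested.suffix.lower() != ".dat":
        return requested
    candidates = sorted(Path(search_dir).glob("*.bsq"))
    if candidates:
        return candidates[0]
    # Nothing found: fall back to the same name with a .bsq suffix.
    return requested.with_suffix(".bsq")


def demo():
    """Resolve a .dat placeholder before and after .bsq products exist."""
    with tempfile.TemporaryDirectory() as tmp:
        deglint_dir = Path(tmp) / "3_deglint"
        deglint_dir.mkdir()
        placeholder = deglint_dir / "deglint.dat"
        before = resolve_deglint_image(placeholder, deglint_dir).name
        (deglint_dir / "b_kutser.bsq").touch()
        (deglint_dir / "a_hedley.bsq").touch()
        after = resolve_deglint_image(placeholder, deglint_dir).name
        return before, after


if __name__ == "__main__":
    print(demo())
```

When several products from different deglint algorithms coexist, the first match is an arbitrary choice, so passing the exact path remains the safer option.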
# ---- Step 8: 机器学习模型预测水质参数 ----
@staticmethod
def predict_water_quality(
sampling_csv_path: str,
models_dir: Optional[str] = None,
metric: str = "test_r2",
prediction_column: str = "prediction",
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
callback: Optional[Callable] = None,
_report_generator=None,
) -> Dict[str, str]:
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
from src.core.prediction.inference_batch import WaterQualityInference
def notify(status, msg=""):
if callback:
callback("步骤8", status, msg)
print("\n" + "=" * 80)
print("步骤8: 预测水质参数")
print("=" * 80)
step_start_time = time.time()
if models_dir is None:
raise ValueError("必须提供 models_dir 参数")
ml_prediction_dir = Path(output_dir)
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_files = {}
if ml_prediction_dir.exists():
csv_files = list(ml_prediction_dir.glob("*.csv"))
for csv_file in csv_files:
file_stem = csv_file.stem
if "_prediction" in file_stem:
target_name = file_stem.replace("_prediction", "")
elif "_pred" in file_stem:
target_name = file_stem.replace("_pred", "")
else:
target_name = file_stem
prediction_files[target_name] = str(csv_file)
# 检查是否所有目标参数都有预测文件
if prediction_files:
models_path_obj = Path(models_dir)
if models_path_obj.exists():
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
missing_targets = [t for t in target_folders if t not in prediction_files]
if not missing_targets:
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
return prediction_files
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
inferencer = WaterQualityInference(models_dir)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
sampling_csv_path=sampling_csv_path,
output_dir=str(ml_prediction_dir),
metric=metric,
prediction_column=prediction_column,
output_format="csv",
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
print(f"预测完成,结果保存在: {ml_prediction_dir}")
if _report_generator is not None:
try:
report_path = _report_generator.generate_prediction_report(prediction_files)
print(f"预测结果报告已生成: {report_path}")
except Exception as e:
print(f"生成预测结果报告时出错: {e}")
notify("completed", f"预测完成: {ml_prediction_dir}")
return prediction_files
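Step 8 decides whether to skip by deriving a target name from each existing CSV stem and comparing the result with the target folders under the models directory. The sketch below isolates that bookkeeping. `target_from_stem` and `missing_targets` are hypothetical helpers. Unlike the `str.replace` calls above, the sketch strips only a trailing `_prediction` or `_pred`, which is an assumption about the intended naming.

```python
from typing import Iterable, List


def target_from_stem(stem: str) -> str:
    """Drop a trailing "_prediction" or "_pred" from a file stem."""
    for suffix in ("_prediction", "_pred"):
        if stem.endswith(suffix):
            return stem[: -len(suffix)]
    return stem


def missing_targets(target_folders: Iterable[str], prediction_stems: Iterable[str]) -> List[str]:
    """Targets that have a model folder but no prediction file yet."""
    done = {target_from_stem(stem) for stem in prediction_stems}
    return [target for target in target_folders if target not in done]


if __name__ == "__main__":
    print(missing_targets(["chl_a", "tn", "tp"], ["chl_a_prediction", "tn_pred"]))
```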
# ---- Step 8.5: 非经验模型预测 ----
@staticmethod
def predict_with_non_empirical_models(
sampling_csv_path: str,
non_empirical_models_dir: Optional[str] = None,
output_dir: Optional[str] = None,
metric: str = "Average Accuracy(%)",
prediction_column: str = "prediction",
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Dict[str, str]:
"""使用非经验统计回归模型进行参数预测"""
def notify(status, msg=""):
if callback:
callback("步骤8.5", status, msg)
print("\n" + "=" * 80)
print("步骤8.5: 使用非经验模型进行参数预测")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("Non-empirical model prediction is disabled (enabled=False); skipping")
notify("skipped", "跳过非经验模型预测")
return {}
if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir
else:
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
if Path(default_models_dir).exists():
final_models_dir = default_models_dir
else:
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
if output_dir is not None:
non_empirical_prediction_dir = Path(output_dir)
else:
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_files = {}
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
if not summary_path.exists():
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
import pandas as pd
df_summary = pd.read_csv(summary_path)
best_models = {}
for algorithm in df_summary["Algorithm Name"].unique():
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
if metric in algorithm_df.columns:
best_model_row = algorithm_df.nlargest(1, metric)
else:
best_model_row = algorithm_df.iloc[[0]]
best_model_path = best_model_row["Model File"].values[0]
best_preprocess = best_model_row["Preprocessing Method"].values[0]
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
best_models[algorithm] = {
"model_path": best_model_path,
"preprocess_method": best_preprocess,
"accuracy": best_accuracy,
}
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
pd.read_csv(sampling_csv_path) # just to validate
for algorithm, model_info in best_models.items():
print(f"\n使用 {algorithm} 算法进行预测...")
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
if Path(output_path).exists():
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
prediction_files[algorithm] = output_path
continue
try:
from src.core.non_empirical_retrieval import non_empirical_retrieval
non_empirical_retrieval(
algorithm=algorithm,
model_info_path=model_info["model_path"],
coor_spectral_path=sampling_csv_path,
output_path=output_path,
wave_radius=5,
)
prediction_files[algorithm] = output_path
print(f"预测完成: {output_path}")
except Exception as e:
print(f"使用 {algorithm} 算法预测时出错: {e}")
continue
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
return prediction_files
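Step 8.5 picks, for each algorithm, the summary row with the largest metric value and falls back to the first row when the metric column is absent. The sketch below reproduces that rule on plain dictionaries instead of a pandas DataFrame. `pick_best_models` is a hypothetical helper, and the sample rows in the usage line are invented.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def pick_best_models(rows: List[Row], metric: str = "Average Accuracy(%)") -> Dict[str, Row]:
    """Keep one row per algorithm: the highest metric, first row on ties."""
    best: Dict[str, Row] = {}
    for row in rows:
        name = row["Algorithm Name"]
        current = best.get(name)
        if current is None:
            best[name] = row
        elif metric in row and metric in current and row[metric] > current[metric]:
            # Strict comparison keeps the earlier row when values are equal.
            best[name] = row
    return best


if __name__ == "__main__":
    rows = [
        {"Algorithm Name": "chl_a", "Preprocessing Method": "None", "Average Accuracy(%)": 71.5},
        {"Algorithm Name": "chl_a", "Preprocessing Method": "SNV", "Average Accuracy(%)": 83.0},
    ]
    print(pick_best_models(rows)["chl_a"]["Preprocessing Method"])
```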
# ---- Step 8.75: 自定义回归模型预测 ----
@staticmethod
def predict_with_custom_regression(
sampling_csv_path: str,
custom_regression_dir: Optional[str] = None,
formula_csv_path: Optional[str] = None,
coordinate_columns: Optional[List[str]] = None,
output_dir: Optional[str] = None,
filename_prefix: str = "custom_regression_prediction",
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Dict[str, str]:
"""使用自定义回归模型进行参数预测"""
def notify(status, msg=""):
if callback:
callback("步骤8.75", status, msg)
print("\n" + "=" * 80)
print("步骤8.75: 使用自定义回归模型进行参数预测")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("Custom regression prediction is disabled (enabled=False); skipping")
notify("skipped", "跳过自定义回归预测")
return {}
if not Path(sampling_csv_path).exists():
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
if custom_regression_dir is not None:
final_regression_dir = custom_regression_dir
else:
final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
if not Path(final_regression_dir).exists():
raise ValueError(
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
)
if output_dir is None:
custom_regression_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_output_dir = str(custom_regression_prediction_dir)
else:
prediction_output_dir = output_dir
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
predictor = CustomRegressionPredictor(
regression_csv_dir=final_regression_dir,
formula_csv_path=formula_csv_path,
)
print(f"开始使用自定义回归模块进行批量预测...")
print(f" 采样点数据: {sampling_csv_path}")
print(f" 回归模型目录: {final_regression_dir}")
print(f" 输出目录: {prediction_output_dir}")
saved_files = predictor.run_batch_prediction(
input_csv_path=sampling_csv_path,
coordinate_columns=coordinate_columns,
filename_prefix=filename_prefix,
)
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
for param_name, filepath in saved_files.items():
print(f" {param_name}: {filepath}")
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
return saved_files

@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
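"""
Step 1: water mask generation.
Three modes are supported:
1. Rasterizing a .shp file
2. Using an existing raster mask file (.dat/.tif)
3. Generating the water mask automatically from the image with NDWI
"""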
import os
import time
from pathlib import Path
from typing import Optional, List, Callable, Union
import numpy as np
class WaterMaskStep:
"""水域掩膜生成步骤"""
@staticmethod
def run(
mask_path: Optional[str] = None,
img_path: Optional[str] = None,
ndwi_threshold: float = 0.4,
use_ndwi: bool = False,
generate_png: bool = True,
output_path: Optional[str] = None,
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
) -> str:
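"""
Generate or register the water mask.
Args:
    mask_path: Water mask file path. Supports .shp (requires img_path) or .dat/.tif (used directly).
    img_path: Input image path (required when mask_path is a .shp file or use_ndwi=True).
    ndwi_threshold: NDWI threshold (used when use_ndwi=True).
    use_ndwi: Whether to generate the water mask from the image with the NDWI method.
    generate_png: Whether to generate PNG previews (default True).
    output_path: Path where the output mask file is saved (optional).
    water_mask_dir: Working directory.
    callback: Callback with the signature callback(step, status, message).
Returns:
    Path to the water mask file (.dat for generated masks; the input path when a raster mask is used directly).
"""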
from src.utils.extract_water_area import rasterize_shp, ndwi
from src.core.utils.preview_generator import (
generate_image_preview,
generate_water_mask_overlay,
)
water_mask_dir = Path(water_mask_dir)
water_mask_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤1", status, msg)
print("\n" + "=" * 80)
print("步骤1: 生成或设置水域mask")
print("=" * 80)
step_start_time = time.time()
# 生成影像预览图
if generate_png and img_path is not None and Path(img_path).exists():
preview_path = str(water_mask_dir / "hsi_preview.png")
generate_image_preview(
img_path=img_path,
output_path=preview_path,
title="影像预览: RGB波段(基于波长)"
)
# ---- NDWI 方法 ----
if use_ndwi:
if img_path is None:
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
if not Path(img_path).exists():
raise ValueError(f"影像文件不存在: {img_path}")
print(f"Generating the water mask from the image with NDWI (threshold={ndwi_threshold})...")
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
if Path(ndwi_output_path).exists():
print(f"Found an existing NDWI mask file, reusing it: {ndwi_output_path}")
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
return ndwi_output_path
ndwi(img_path, ndwi_threshold, ndwi_output_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
)
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
return ndwi_output_path
# ---- 必须提供 mask_path ----
if mask_path is None:
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
if not Path(mask_path).exists():
raise ValueError(f"文件不存在: {mask_path}")
file_ext = Path(mask_path).suffix.lower()
# ---- SHP 栅格化 ----
if file_ext == ".shp":
if img_path is None:
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
print("Detected a .shp water mask; converting it to .dat format...")
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
if Path(shp_output_path).exists():
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
if not overlay_path.exists():
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
return shp_output_path
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
rasterize_shp(safe_mask_path, shp_output_path, img_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
return shp_output_path
# ---- 栅格格式直接使用 ----
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
if generate_png and img_path is not None and Path(img_path).exists():
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
notify("completed", f"水域掩膜已设置: {mask_path}")
return mask_path
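`WaterMaskStep.run` dispatches between three modes in a fixed order: NDWI first, then shapefile rasterization, then direct use of a raster mask. The decision logic alone can be sketched as below. `select_mask_mode` and its return labels are hypothetical, and the file-existence checks from the step are left out.

```python
from pathlib import Path
from typing import Optional


def select_mask_mode(
    mask_path: Optional[str] = None,
    img_path: Optional[str] = None,
    use_ndwi: bool = False,
) -> str:
    """Return "ndwi", "rasterize_shp", or "raster" for the given arguments."""
    if use_ndwi:
        # NDWI takes priority and ignores mask_path entirely.
        if img_path is None:
            raise ValueError("img_path is required when use_ndwi=True")
        return "ndwi"
    if mask_path is None:
        raise ValueError("provide mask_path or set use_ndwi=True")
    if Path(mask_path).suffix.lower() == ".shp":
        if img_path is None:
            raise ValueError("img_path is required when mask_path is a .shp file")
        return "rasterize_shp"
    return "raster"


if __name__ == "__main__":
    print(select_mask_mode(mask_path="lake.SHP", img_path="scene.bsq"))
```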

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""
Utility module: unified export interface
"""
from src.core.utils.gdal_helper import (
get_image_geo_info,
load_image_as_array,
save_array_as_image,
save_bands_as_image,
copy_hdr_info,
read_band_as_array,
read_multiple_bands,
)
from src.core.utils.mask_converter import (
prepare_water_mask_for_algorithm,
ensure_water_mask_dat,
)
from src.core.utils.preview_generator import (
generate_image_preview,
generate_water_mask_overlay,
select_rgb_bands_by_wavelength,
get_wavelength_info,
)
__all__ = [
# GDAL IO
'get_image_geo_info',
'load_image_as_array',
'save_array_as_image',
'save_bands_as_image',
'copy_hdr_info',
'read_band_as_array',
'read_multiple_bands',
# Mask conversion
'prepare_water_mask_for_algorithm',
'ensure_water_mask_dat',
# Preview generation
'generate_image_preview',
'generate_water_mask_overlay',
'select_rgb_bands_by_wavelength',
'get_wavelength_info',
]

@@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
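"""
Low-level GDAL I/O utilities.
Provides low-level GDAL operations such as reading, writing, and format conversion of remote sensing imagery.
These functions do not depend on any business logic and can be reused independently in other projects.
"""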
import os
from pathlib import Path
from typing import Tuple, Optional
import numpy as np
# GDAL 导入(可选)
try:
from osgeo import gdal, ogr, gdal_array
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# hdr 文件工具
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
UTIL_AVAILABLE = True
except ImportError:
UTIL_AVAILABLE = False
write_fields_to_hdrfile = None
get_hdr_file_path = None
# ============================================================
# 影像信息读取
# ============================================================
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
"""
获取影像的地理信息(不加载图像数据,节省内存)
Args:
img_path: 影像文件路径
Returns:
tuple: (geotransform, projection, width, height, n_bands)
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
return geotransform, projection, width, height, n_bands
finally:
dataset = None
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
"""
加载影像文件为numpy数组
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
建议直接传递文件路径给算法类让算法类使用GDAL逐波段处理。
Args:
img_path: 影像文件路径
Returns:
tuple: (image_array, geotransform, projection)
image_array: numpy数组形状为(height, width, bands)
geotransform: 地理变换参数
projection: 投影信息
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取影像文件")
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
n_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
projection = dataset.GetProjection()
image_bands = []
for i in range(1, n_bands + 1):
band = dataset.GetRasterBand(i)
band_data = band.ReadAsArray()
image_bands.append(band_data)
image_array = np.dstack(image_bands)
return image_array, geotransform, projection
finally:
dataset = None
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
    """
    Read a single band as a numpy array.
    Args:
        img_path: Path to the image file.
        band_index: Band index (0-based).
    Returns:
        numpy array of shape (height, width).
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL is not installed; cannot read image files")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"Cannot open image file: {img_path}")
    try:
        # GDAL band numbers are 1-based, so shift the 0-based index.
        band = dataset.GetRasterBand(band_index + 1)
        return band.ReadAsArray()
    finally:
        dataset = None
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
    """
    Read several bands into a list.
    Args:
        img_path: Path to the image file.
        band_indices: List of band indices (0-based).
    Returns:
        tuple: (band_list, geotransform, projection)
    """
    if not GDAL_AVAILABLE:
        raise ImportError("GDAL is not installed; cannot read image files")
    dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
    if dataset is None:
        raise ValueError(f"Cannot open image file: {img_path}")
    try:
        geotransform = dataset.GetGeoTransform()
        projection = dataset.GetProjection()
        bands = []
        for idx in band_indices:
            # GDAL band numbers are 1-based, so shift the 0-based index.
            band = dataset.GetRasterBand(idx + 1)
            bands.append(band.ReadAsArray())
        return bands, geotransform, projection
    finally:
        dataset = None
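Both band readers take 0-based indices and add one before calling `GetRasterBand`, because GDAL numbers bands from 1. They do not check the index against `RasterCount`. A small sketch of the conversion with an explicit range check is below. `to_gdal_band_number` is a hypothetical helper that is not part of this module, and the error type is an assumption.

```python
def to_gdal_band_number(band_index: int, n_bands: int) -> int:
    """Convert a 0-based band index to GDAL's 1-based band number."""
    if not 0 <= band_index < n_bands:
        raise IndexError(
            f"band index {band_index} is out of range for an image with {n_bands} bands"
        )
    return band_index + 1


if __name__ == "__main__":
    # The first and last band of a 3-band image.
    print(to_gdal_band_number(0, 3), to_gdal_band_number(2, 3))
```

Validating before the GDAL call turns an out-of-range index into a clear message instead of a failure on the band object.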
# ============================================================
# 影像写入
# ============================================================
def save_array_as_image(image_array: np.ndarray, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
将numpy数组保存为影像文件
Args:
image_array: numpy数组形状为(height, width, bands) 或 (height, width)
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型默认 gdal.GDT_Float32
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if dtype is None:
dtype = gdal.GDT_Float32
if image_array.ndim == 2:
height, width = image_array.shape
n_bands = 1
else:
height, width, n_bands = image_array.shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
if n_bands == 1:
band = dataset.GetRasterBand(1)
band.WriteArray(image_array)
band.FlushCache()
else:
for i in range(n_bands):
band = dataset.GetRasterBand(i + 1)
band.WriteArray(image_array[:, :, i])
band.FlushCache()
finally:
dataset = None
return output_path
def save_bands_as_image(corrected_bands: list, output_path: str,
geotransform: tuple, projection: str,
dtype=None) -> str:
"""
直接从波段列表保存影像文件(避免堆叠,节省内存)
Args:
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
output_path: 输出文件路径
geotransform: 地理变换参数
projection: 投影信息
dtype: GDAL数据类型
Returns:
输出文件路径
"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法保存影像文件")
if not corrected_bands:
raise ValueError("波段列表为空")
if dtype is None:
dtype = gdal.GDT_Float32
n_bands = len(corrected_bands)
height, width = corrected_bands[0].shape
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise ValueError("无法创建影像文件,没有可用的驱动")
dataset = driver.Create(output_path, width, height, n_bands, dtype)
if dataset is None:
raise ValueError(f"无法创建输出文件: {output_path}")
try:
dataset.SetGeoTransform(geotransform)
dataset.SetProjection(projection)
for i, band_array in enumerate(corrected_bands):
if band_array.shape != (height, width):
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
band = dataset.GetRasterBand(i + 1)
band.WriteArray(band_array)
band.FlushCache()
finally:
dataset = None
return output_path
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
"""
复制原始影像的hdr文件信息如波长等到目标影像的hdr文件
Args:
source_img_path: 源影像文件路径原始bsq文件
dest_img_path: 目标影像文件路径去耀斑后的bsq文件
Returns:
bool: 是否成功
"""
if not UTIL_AVAILABLE:
print("警告: util模块未导入无法复制hdr文件信息")
return False
try:
source_hdr_path = get_hdr_file_path(source_img_path)
dest_hdr_path = get_hdr_file_path(dest_img_path)
if not Path(source_hdr_path).exists():
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
return False
if not Path(dest_hdr_path).exists():
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
return False
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
return True
except Exception as e:
print(f"警告: 复制hdr文件信息时出错: {e}")
return False

@@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
"""
Mask conversion utilities.
Converts masks between shapefile / ndarray / dat / tif formats
and implements the water mask preprocessing logic.
"""
import os
from pathlib import Path
from typing import Optional, Union
import numpy as np
try:
from osgeo import gdal, ogr
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
def prepare_water_mask_for_algorithm(
water_mask: Optional[Union[str, np.ndarray]],
image_shape: Union[tuple, np.ndarray],
geotransform: tuple,
projection: str,
img_path: str,
water_mask_dir: Optional[str] = None,
callback=None
) -> Optional[np.ndarray]:
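"""
Prepare the water mask for use by an algorithm.
Supported inputs:
- None: no mask is applied (returns None)
- numpy.ndarray: returned after being converted to 0/1 values
- .dat / .tif and other raster files: read and returned
- .shp file: rasterized first, then read and returned
Args:
    water_mask: Mask source.
    image_shape: Image shape, (height, width) or (height, width, channels).
    geotransform: GDAL geotransform parameters.
    projection: Projection information.
    img_path: Image path (used for shapefile rasterization).
    water_mask_dir: Water mask directory (used to cache the rasterized shapefile).
    callback: Progress callback (optional).
Returns:
    numpy array (dtype=uint8, 0 = non-water, 1 = water), or None.
"""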
img_height, img_width = image_shape[0], image_shape[1]
if water_mask is None:
return None
# numpy 数组直接返回
if isinstance(water_mask, np.ndarray):
if water_mask.shape[:2] != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (water_mask > 0).astype(np.uint8)
# 字符串路径
if isinstance(water_mask, str):
ext = Path(water_mask).suffix.lower()
# shapefile 格式
if ext == '.shp':
return _convert_shp_to_mask(
shp_path=water_mask,
img_path=img_path,
image_shape=image_shape,
geotransform=geotransform,
projection=projection,
water_mask_dir=water_mask_dir,
callback=callback
)
# 栅格文件格式
return _load_raster_mask(water_mask, img_height, img_width)
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
def _convert_shp_to_mask(shp_path: str, img_path: str,
image_shape: tuple,
geotransform: tuple,
projection: str,
water_mask_dir: Optional[str] = None,
callback=None) -> np.ndarray:
"""将 shapefile 栅格化为掩膜数组"""
from src.utils.extract_water_area import rasterize_shp
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
shp_name = Path(safe_shp_path).stem
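if water_mask_dir:
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
else:
# No cache directory given: use the platform temp directory ("/tmp" usually does not exist on Windows)
import tempfile
temp_mask_path = str(Path(tempfile.gettempdir()) / f"water_mask_{shp_name}.dat")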
# 缓存:已栅格化则直接读取
if Path(temp_mask_path).exists():
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
# 需要栅格化
if img_path is None:
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
print(f"正在将 SHP 栅格化: {safe_shp_path}")
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
"""从栅格文件加载掩膜"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法读取掩膜文件")
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
if mask_dataset is None:
raise ValueError(f"无法打开掩膜文件: {mask_path}")
try:
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
finally:
mask_dataset = None
if mask_array.shape != (img_height, img_width):
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
return (mask_array > 0).astype(np.uint8)
def ensure_water_mask_dat(img_path: str,
existing_dat_path: Optional[str] = None,
output_dir: Optional[str] = None) -> str:
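"""
Resolve the path of a .dat water mask file (used by the algorithms in steps 3/4).
If existing_dat_path is a .dat file that exists, it is returned as-is.
If "<image name>_water_mask.dat" already exists in output_dir, that path is returned.
If a mask in another raster format (.tif / .img / .tiff) exists there, it is converted to .dat.
Otherwise the target .dat path is returned without creating the file; the caller decides whether to generate it.
Args:
img_path: path of the remote sensing image the mask belongs to
existing_dat_path: path of an existing .dat mask (optional)
output_dir: output directory (optional; defaults to the image's directory)
Returns:
Path of the .dat mask file (the file may not exist yet)
"""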
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
if Path(existing_dat_path).exists():
return existing_dat_path
img_name = Path(img_path).stem
if output_dir is None:
output_dir = str(Path(img_path).parent)
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
if Path(dat_path).exists():
return dat_path
# 如果已有其他格式的掩膜,转换为 dat
for ext in ['.tif', '.img', '.tiff']:
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
if Path(alt_path).exists():
return _convert_to_dat(alt_path, dat_path)
return dat_path # 返回目标路径,让调用方决定是否需要生成
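
The lookup order used by `ensure_water_mask_dat` can be sketched without GDAL. The helper name `find_water_mask_candidate` and its `(path, exists)` return shape are illustrative assumptions, not part of this module; the sketch only reports which file would be picked up and leaves any format conversion to the caller.

```python
from pathlib import Path
from typing import Optional, Tuple


def find_water_mask_candidate(img_path: str,
                              output_dir: Optional[str] = None
                              ) -> Tuple[str, bool]:
    """Return (path, exists) for the mask file that would be picked up.

    Hypothetical helper: follows the "<image stem>_water_mask<ext>" naming
    convention, checking .dat first and then the other raster extensions.
    """
    base_dir = Path(output_dir) if output_dir else Path(img_path).parent
    stem = Path(img_path).stem
    for ext in ('.dat', '.tif', '.img', '.tiff'):
        candidate = base_dir / f"{stem}_water_mask{ext}"
        if candidate.exists():
            return str(candidate), True
    # Nothing on disk yet: report the target .dat path so the caller can create it.
    return str(base_dir / f"{stem}_water_mask.dat"), False
```

Returning the existence flag alongside the path makes the "file may not exist yet" case explicit, so a caller does not have to re-check the filesystem.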
def _convert_to_dat(src_path: str, dest_path: str) -> str:
"""将其他栅格格式转换为 ENVI dat 格式"""
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法转换格式")
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
if src_ds is None:
raise ValueError(f"无法打开源掩膜文件: {src_path}")
try:
geotransform = src_ds.GetGeoTransform()
projection = src_ds.GetProjection()
band = src_ds.GetRasterBand(1)
array = band.ReadAsArray()
driver = gdal.GetDriverByName('ENVI')
if driver is None:
driver = gdal.GetDriverByName('GTiff')
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
if dest_ds is None:
raise ValueError(f"无法创建输出文件: {dest_path}")
try:
dest_ds.SetGeoTransform(geotransform)
dest_ds.SetProjection(projection)
dest_band = dest_ds.GetRasterBand(1)
dest_band.WriteArray((array > 0).astype(np.uint8))
dest_band.FlushCache()
finally:
dest_ds = None
return dest_path
finally:
src_ds = None


@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
Remote sensing image preview utilities.
Provides visualizations for hyperspectral imagery, such as RGB previews and water mask overlays.
"""
import numpy as np
from pathlib import Path
from typing import Optional, List
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
# matplotlib is imported at module level (preview_generator is a visualization tool)
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
# ============================================================
# 辅助函数:波段选择
# ============================================================
def select_rgb_bands_by_wavelength(band_count: int,
wavelength_info: Optional[List[float]] = None,
fallback_bands: Optional[List[int]] = None) -> List[int]:
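"""
Select RGB bands automatically by wavelength.
Args:
band_count: total number of bands
wavelength_info: wavelength of each band in nm (length band_count)
fallback_bands: fallback band indices [R, G, B] used when no wavelength information is available
Returns:
List of band indices [R_index, G_index, B_index] (0-based)
"""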
if fallback_bands is None:
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
if wavelength_info is None:
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
# 目标波长nm
TARGET_R = 650
TARGET_G = 550
TARGET_B = 460
def find_closest(target: float) -> int:
min_dist = float('inf')
best_idx = 0
for i, wl in enumerate(wavelength_info):
dist = abs(wl - target)
if dist < min_dist:
min_dist = dist
best_idx = i
return best_idx
try:
r_idx = find_closest(TARGET_R)
g_idx = find_closest(TARGET_G)
b_idx = find_closest(TARGET_B)
return [r_idx, g_idx, b_idx]
except Exception:
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
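
As a compact illustration of the nearest-wavelength rule used above, the following sketch picks the band closest to each target. The helper names `closest_band` and `pick_rgb` are hypothetical; the target wavelengths (650, 550, 460 nm) are the ones defined in `select_rgb_bands_by_wavelength`.

```python
from typing import List, Sequence


def closest_band(wavelengths: Sequence[float], target_nm: float) -> int:
    """Index (0-based) of the band whose wavelength is nearest to target_nm."""
    if not wavelengths:
        raise ValueError("wavelengths must not be empty")
    # min() keeps the first index on ties, like a strict "<" comparison loop.
    return min(range(len(wavelengths)),
               key=lambda i: abs(wavelengths[i] - target_nm))


def pick_rgb(wavelengths: Sequence[float]) -> List[int]:
    """Return [R, G, B] band indices for targets 650 / 550 / 460 nm."""
    return [closest_band(wavelengths, t) for t in (650.0, 550.0, 460.0)]


wl = [400.0 + 10.0 * i for i in range(31)]   # 400, 410, ..., 700 nm
print(pick_rgb(wl))                          # [25, 15, 6]
```

The indices are 0-based, so a caller passing them to GDAL's `GetRasterBand` would add 1, as the preview functions in this file do.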
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
    """Read the wavelength list from the ENVI .hdr file next to the image."""
    try:
        hdr_path = Path(img_path).with_suffix('.hdr')
        if not hdr_path.exists():
            return None
        wavelengths = []
        in_wl = False
        with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                line = line.strip()
                if not in_wl:
                    # Only the "wavelength = {" field starts the list
                    # ("wavelength units = ..." must not match).
                    key, sep, value = line.partition('=')
                    if not sep or key.strip().lower() != 'wavelength':
                        continue
                    in_wl = True
                    line = value.strip()
                # A closing brace ends the list, possibly on the opening line.
                closed = '}' in line
                line = line.replace('{', ' ').split('}', 1)[0]
                # Parse comma- or whitespace-separated values.
                for token in line.replace(',', ' ').split():
                    try:
                        wavelengths.append(float(token))
                    except ValueError:
                        pass
                if closed:
                    break
        return wavelengths if wavelengths else None
    except Exception:
        return None
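
For reference, the `wavelength = { ... }` field of an ENVI header can also be extracted from in-memory text with a regular expression. This is a minimal sketch under the assumption that the field is brace-delimited and may span several lines; `parse_envi_wavelengths` is a hypothetical name, not a function of this module.

```python
import re
from typing import List, Optional


def parse_envi_wavelengths(header_text: str) -> Optional[List[float]]:
    """Extract the values of the `wavelength = { ... }` field from header text."""
    # (?im): case-insensitive, ^ anchors at each line start. The key must be
    # exactly "wavelength", so "wavelength units = ..." is not matched.
    match = re.search(r"(?im)^\s*wavelength\s*=\s*\{([^}]*)\}", header_text)
    if match is None:
        return None
    values = []
    for token in match.group(1).replace(",", " ").split():
        try:
            values.append(float(token))
        except ValueError:
            continue  # skip anything that is not a number
    return values or None


hdr = ("ENVI\nsamples = 100\nbands = 3\n"
       "wavelength units = Nanometers\n"
       "wavelength = {\n 460.0, 550.0,\n 650.0}\n")
print(parse_envi_wavelengths(hdr))   # [460.0, 550.0, 650.0]
```

Working on text rather than a file path keeps the parsing logic testable without any image on disk.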
# ============================================================
# 核心预览图生成函数
# ============================================================
def generate_image_preview(img_path: str,
output_path: str,
bands: Optional[List[int]] = None,
title: str = "影像预览") -> str:
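"""
Generate a PNG preview of a hyperspectral image.
Args:
img_path: input image path
output_path: output PNG file path
bands: RGB band indices [R, G, B] (0-based); None selects them automatically
title: figure title
Returns:
Path of the generated PNG file (an existing file at output_path is reused, not regenerated)
"""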
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法生成影像预览图")
if Path(output_path).exists():
print(f"检测到已存在的预览图,跳过生成: {output_path}")
return output_path
dataset = gdal.Open(img_path)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
band_count = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
# 自动选择波段
if bands is None:
if band_count >= 3:
wl_info = get_wavelength_info(img_path)
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
else:
bands = [0, 0, 0]
# 读取波段
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
r_data[r_data <= 0] = np.nan
if band_count > 1:
g_data[g_data <= 0] = np.nan
if band_count > 2:
b_data[b_data <= 0] = np.nan
# 线性拉伸
def linear_stretch(data, low=2, high=98):
valid = data[~np.isnan(data)]
if len(valid) == 0:
return np.zeros_like(data)
lo = np.percentile(valid, low)
hi = np.percentile(valid, high)
if hi - lo < 1e-10:
return np.zeros_like(data)
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
return np.nan_to_num(stretched, nan=0.0)
r_s = linear_stretch(r_data)
g_s = linear_stretch(g_data) if band_count > 1 else r_s
b_s = linear_stretch(b_data) if band_count > 2 else r_s
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
# 绘图
fig, ax = plt.subplots(figsize=(12, 10))
ax.imshow(rgb_image)
ax.set_title(title, fontsize=12, fontweight='bold')
ax.axis('off')
geotransform = dataset.GetGeoTransform()
if geotransform and geotransform[1] != 0:
pixel_size_x = abs(geotransform[1])
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
color='white',
bbox=dict(facecolor='black', alpha=0.6,
boxstyle='round,pad=0.3'))
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.close(fig)
return output_path
finally:
dataset = None
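
The 2%-98% percentile stretch applied to each band can be illustrated in pure Python. This is a sketch on flat lists rather than NumPy arrays; the `percentile` helper is an assumption written to follow linear interpolation between order statistics, which is also NumPy's default percentile method.

```python
import math
from typing import List


def percentile(sorted_vals: List[float], q: float) -> float:
    """Linearly interpolated percentile of an already sorted, non-empty list."""
    pos = (len(sorted_vals) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def linear_stretch(values: List[float], low: float = 2, high: float = 98) -> List[float]:
    """Map values to [0, 1] between the low/high percentiles; NaN becomes 0."""
    valid = sorted(v for v in values if not math.isnan(v))
    if not valid:
        return [0.0] * len(values)
    lo, hi = percentile(valid, low), percentile(valid, high)
    if hi - lo < 1e-10:
        return [0.0] * len(values)      # constant band: nothing to stretch
    out = []
    for v in values:
        if math.isnan(v):
            out.append(0.0)             # no-data pixels are drawn black
        else:
            out.append(min(1.0, max(0.0, (v - lo) / (hi - lo))))
    return out
```

Clipping at the 2nd and 98th percentiles keeps a few very bright or very dark pixels from compressing the rest of the image into a narrow gray range.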
def generate_water_mask_overlay(img_path: str,
mask_path: str,
output_path: str,
bands: Optional[List[int]] = None,
mask_color: tuple = (0, 100, 255),
mask_alpha: float = 0.5) -> str:
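"""
Generate a PNG with the water mask overlaid on the image.
Args:
img_path: input image path
mask_path: water mask file path
output_path: output PNG path
bands: RGB band indices; None selects them automatically
mask_color: overlay color for the mask (R, G, B)
mask_alpha: mask opacity (0 = fully transparent, 1 = fully opaque)
Returns:
Path of the generated PNG file
"""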
if not GDAL_AVAILABLE:
raise ImportError("GDAL未安装无法生成叠加图")
if Path(output_path).exists():
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
return output_path
dataset = gdal.Open(img_path)
if dataset is None:
raise ValueError(f"无法打开影像文件: {img_path}")
try:
width = dataset.RasterXSize
height = dataset.RasterYSize
band_count = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
# 自动选择波段
if bands is None:
if band_count >= 3:
wl_info = get_wavelength_info(img_path)
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
else:
bands = [0, 0, 0]
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
r_data[r_data <= 0] = np.nan
if band_count > 1:
g_data[g_data <= 0] = np.nan
if band_count > 2:
b_data[b_data <= 0] = np.nan
def linear_stretch(data, low=2, high=98):
valid = data[~np.isnan(data)]
if len(valid) == 0:
return np.zeros_like(data)
lo = np.percentile(valid, low)
hi = np.percentile(valid, high)
if hi - lo < 1e-10:
return np.zeros_like(data)
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
return np.nan_to_num(stretched, nan=0.0)
r_s = linear_stretch(r_data)
g_s = linear_stretch(g_data) if band_count > 1 else r_s
b_s = linear_stretch(b_data) if band_count > 2 else r_s
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
rgb_image = rgb_image.astype(np.uint8)
# 读取掩膜
mask_dataset = gdal.Open(mask_path)
if mask_dataset is not None:
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
mask_dataset = None
else:
print(f"警告: 无法打开掩膜文件: {mask_path}")
mask_data = None
# Alpha blending: any non-zero mask value counts as water, so 0/1 and 0/255 masks render the same
blended = rgb_image.astype(np.float32)
if mask_data is not None:
alpha = (mask_data > 0).astype(np.float32) * mask_alpha
for c in range(3):
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
blended = blended.astype(np.uint8)
# 绘图
fig, ax = plt.subplots(figsize=(14, 10))
ax.imshow(blended)
ax.axis('off')
legend_elements = [
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
edgecolor='black', alpha=mask_alpha, label='水域范围')
]
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
# 面积计算
if geotransform and geotransform[1] != 0:
pixel_size_x = abs(geotransform[1])
pixel_size_y = abs(geotransform[5])
pixel_area = pixel_size_x * pixel_size_y
if mask_data is not None:
water_pixels = np.sum(mask_data > 0)
valid_pixels = np.sum(mask_data >= 0)
water_km2 = water_pixels * pixel_area / 1_000_000
valid_km2 = valid_pixels * pixel_area / 1_000_000
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
area_text = (f'水域面积: {water_km2:.2f} km² | '
f'影像总面积: {valid_km2:.2f} km² | '
f'占比: {pct:.1f}%')
ax.text(0.02, 0.98, area_text,
transform=ax.transAxes, fontsize=11,
color='white', fontweight='bold',
bbox=dict(facecolor='#0064FF', alpha=0.8,
edgecolor='black', boxstyle='round,pad=0.5'),
verticalalignment='top')
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.close(fig)
return output_path
finally:
dataset = None
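
The overlay is a standard alpha blend, `out = rgb * (1 - a) + mask_color * a`. The per-pixel sketch below makes the arithmetic concrete; `blend_pixel` is a hypothetical helper, and it assumes that any non-zero mask value marks water so that 0/1 and 0/255 masks are tinted identically.

```python
from typing import Tuple


def blend_pixel(rgb: Tuple[int, int, int],
                mask_value: int,
                mask_color: Tuple[int, int, int] = (0, 100, 255),
                mask_alpha: float = 0.5) -> Tuple[int, ...]:
    """Blend one pixel: out = rgb * (1 - a) + mask_color * a."""
    # Assumption: non-zero mask means water; non-water pixels are left unchanged.
    a = mask_alpha if mask_value > 0 else 0.0
    # int() truncates, like casting a float array to uint8.
    return tuple(int(c * (1 - a) + m * a) for c, m in zip(rgb, mask_color))


print(blend_pixel((200, 200, 200), 0))   # (200, 200, 200)
print(blend_pixel((200, 200, 200), 1))   # (100, 150, 227)
```

With `mask_alpha = 0.5` a water pixel ends up halfway between the image color and the overlay color, which keeps the underlying texture visible.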


@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 统一导出接口
本模块从各子模块导入可视化函数,提供统一的导出接口。
"""
from src.core.visualization.scatter_plot import generate_model_scatter_plots
from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots
from src.core.visualization.boxplot import generate_boxplots
from src.core.visualization.statistics import generate_statistical_charts
from src.core.visualization.preview import generate_glint_deglint_previews
from src.core.visualization.report import generate_pipeline_report
__all__ = [
'generate_model_scatter_plots',
'generate_spectrum_comparison_plots',
'generate_boxplots',
'generate_statistical_charts',
'generate_glint_deglint_previews',
'generate_pipeline_report',
]


@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 箱型图生成
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Optional, Dict, List
sns.set_style("whitegrid")
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False
def generate_boxplots(
csv_path: str,
parameter_columns: Optional[List[str]] = None,
data_start_column: int = 4,
save_individual: bool = True,
use_seaborn: bool = True,
output_dir: Optional[str] = None
) -> Dict[str, str]:
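"""
Generate box plots of the water quality parameters.
Args:
csv_path: CSV file path
parameter_columns: list of parameter column names; None detects them automatically
data_start_column: index of the first data column (default 4, i.e. the fifth column)
save_individual: whether to save a separate box plot for each parameter
use_seaborn: whether to draw with seaborn (nicer styling)
output_dir: output directory; None uses the default
Returns:
Dict of box plot file paths
"""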
print("\n" + "="*80)
print("生成水质参数箱型图")
print("="*80)
if csv_path is None:
raise ValueError("请提供 csv_path")
# 确定输出目录
if output_dir is None:
csv_dir = Path(csv_path).parent
output_dir = str(csv_dir / "visualization" / "boxplots")
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 读取数据
df = pd.read_csv(csv_path)
# 确定参数列
if parameter_columns is None:
data_columns = df.iloc[:, data_start_column:]
parameter_columns = list(data_columns.columns)
else:
parameter_columns = [col for col in parameter_columns if col in df.columns]
if not parameter_columns:
print("警告: 未找到有效的参数列")
return {}
boxplot_dir = Path(output_dir)
boxplot_paths = {}
if save_individual:
print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)")
for column in parameter_columns:
if column not in df.columns:
continue
clean_data = df[column].dropna()
if len(clean_data) == 0:
print(f"跳过列 '{column}': 没有有效数据")
continue
try:
plt.figure(figsize=(8, 6))
if use_seaborn:
plot_data = pd.DataFrame({
'参数': [column] * len(clean_data),
'数值': clean_data
})
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
sns.stripplot(data=plot_data, x='参数', y='数值',
color='red', alpha=0.6, size=5, jitter=True)
else:
box_plot = plt.boxplot([clean_data], labels=[column],
patch_artist=True, showfliers=False)
box_plot['boxes'][0].set_facecolor('lightblue')
box_plot['boxes'][0].set_alpha(0.7)
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
edgecolors='black', linewidth=0.5, zorder=3)
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
plt.xlabel('参数', fontsize=12)
plt.ylabel('数值', fontsize=12)
stats_text = (f'数据点数: {len(clean_data)}\n'
f'均值: {clean_data.mean():.2f}\n'
f'中位数: {clean_data.median():.2f}\n'
f'标准差: {clean_data.std():.2f}')
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
verticalalignment='top',
bbox=dict(boxstyle='round',
facecolor='wheat' if not use_seaborn else 'lightgreen',
alpha=0.8))
plt.grid(True, alpha=0.3, linestyle='--')
plt.tight_layout()
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
save_path = boxplot_dir / f'{safe_column_name}_boxplot.png'
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
boxplot_paths[column] = str(save_path)
print(f" 已保存: {save_path.name}")
except Exception as e:
print(f" 处理参数 {column} 时出错: {e}")
continue
# 综合箱型图
try:
print("\n生成综合箱型图(所有参数在一张图上)")
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
box_data = []
labels = []
for column in parameter_columns:
if column in df.columns:
clean_data = df[column].dropna()
if len(clean_data) > 0:
box_data.append(clean_data)
labels.append(column)
if box_data:
if use_seaborn:
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
melted_data = melted_data.dropna()
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
sns.stripplot(data=melted_data, x='参数', y='数值',
color='red', alpha=0.6, size=4, jitter=True)
else:
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
for patch, color in zip(box_plot['boxes'], colors):
patch.set_facecolor(color)
patch.set_alpha(0.7)
for i, data in enumerate(box_data):
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
edgecolors='black', linewidth=0.5, zorder=3)
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
plt.xlabel('参数', fontsize=12)
plt.ylabel('数值', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(True, alpha=0.3, linestyle='--')
plt.tight_layout()
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
plt.close()
boxplot_paths['all_parameters'] = str(combined_path)
print(f" 已保存综合箱型图: {combined_path.name}")
except Exception as e:
print(f"生成综合箱型图时出错: {e}")
print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表")
return boxplot_paths


@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 耀斑影像预览图生成
"""
from pathlib import Path
from typing import Optional, Dict
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_glint_deglint_previews(
work_dir: str,
output_subdir: str = "glint_deglint_previews",
generate_glint: bool = True,
generate_deglint: bool = True,
output_dir: Optional[str] = None
) -> Dict[str, str]:
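"""
Generate PNG previews of the image files in the 2_glint and 3_deglint folders.
Args:
work_dir: working directory
output_subdir: name of the output subdirectory
generate_glint: whether to process the 2_glint folder
generate_deglint: whether to process the 3_deglint folder
output_dir: output directory; None uses the default
Returns:
Dict of generated preview paths
"""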
print(f"\n{'='*70}")
print("步骤: 生成耀斑分析影像预览图")
print(f"{'='*70}")
if work_dir is None:
raise ValueError("请提供 work_dir")
# 确定输出目录
if output_dir is None:
output_dir = str(Path(work_dir) / "visualization" / output_subdir)
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 实例化可视化器
visualizer = WaterQualityVisualization(output_dir)
try:
preview_paths = visualizer.generate_glint_deglint_previews(
work_dir=work_dir,
output_subdir=output_subdir,
generate_glint=generate_glint,
generate_deglint=generate_deglint
)
print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图")
return preview_paths
except Exception as e:
print(f"生成耀斑分析影像预览图时出错: {e}")
return {}


@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 流程执行报告生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict
from datetime import datetime
def generate_pipeline_report(
step_timings: Dict,
pipeline_start_time: Optional[float] = None,
pipeline_end_time: Optional[float] = None,
output_path: Optional[str] = None
) -> str:
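"""
Generate the pipeline execution report, including per-step timing statistics.
Args:
step_timings: step timing dict, in the form {step_name: {start_time, end_time, elapsed_seconds, elapsed_formatted, status, error}}
pipeline_start_time: pipeline start timestamp
pipeline_end_time: pipeline end timestamp
output_path: output file path; None generates one automatically
Returns:
Path of the CSV report (a .txt report is written next to it)
"""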
if output_path is None:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
def _format_time(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.2f}秒"
elif seconds < 3600:
minutes = int(seconds // 60)
secs = seconds % 60
return f"{minutes}分{secs:.2f}秒"
else:
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = seconds % 60
return f"{hours}小时{minutes}分{secs:.2f}秒"
# 准备报告数据
report_data = []
total_time = 0.0
step_order = [
"步骤1: 生成水域mask",
"步骤2: 找到耀斑区域",
"步骤3: 去除耀斑",
"步骤4: 处理CSV文件",
"步骤5: 提取训练样本点光谱",
"步骤5.5: 计算水质光谱指数",
"步骤6: 训练机器学习模型",
"步骤6.5: 非经验模型训练",
"步骤6.75: 自定义回归",
"步骤7: 生成预测采样点",
"步骤8: 预测水质参数",
"步骤9: 生成分布图"
]
for step_name in step_order:
if step_name in step_timings:
timing_info = step_timings[step_name]
report_data.append({
'步骤': step_name,
'开始时间': timing_info['start_time'],
'结束时间': timing_info['end_time'],
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
'耗时(格式化)': timing_info['elapsed_formatted'],
'状态': timing_info['status'],
'错误信息': timing_info.get('error', '')
})
if timing_info['status'] == 'completed':
total_time += timing_info['elapsed_seconds']
if pipeline_start_time and pipeline_end_time:
pipeline_total = pipeline_end_time - pipeline_start_time
report_data.append({
'步骤': '总计',
'开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
'结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
'耗时(秒)': f"{pipeline_total:.2f}",
'耗时(格式化)': _format_time(pipeline_total),
'状态': 'completed',
'错误信息': ''
})
df_report = pd.DataFrame(report_data)
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
# 文本格式报告
txt_output_path = str(Path(output_path).with_suffix('.txt'))
with open(txt_output_path, 'w', encoding='utf-8') as f:
f.write("="*80 + "\n")
f.write("水质参数反演流程执行报告\n")
f.write("="*80 + "\n\n")
if pipeline_start_time and pipeline_end_time:
f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n")
f.write("-"*80 + "\n")
f.write("各步骤执行详情:\n")
f.write("-"*80 + "\n\n")
for step_name in step_order:
if step_name in step_timings:
timing_info = step_timings[step_name]
f.write(f"{step_name}\n")
f.write(f" 开始时间: {timing_info['start_time']}\n")
f.write(f" 结束时间: {timing_info['end_time']}\n")
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
f.write(f" 状态: {timing_info['status']}\n")
if timing_info.get('error'):
f.write(f" 错误: {timing_info['error']}\n")
f.write("\n")
f.write("-"*80 + "\n")
f.write("统计摘要:\n")
f.write("-"*80 + "\n")
completed_steps = [s for s in step_timings.values() if s['status'] == 'completed']
failed_steps = [s for s in step_timings.values() if s['status'] == 'failed']
skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped']
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
f.write(f"失败的步骤: {len(failed_steps)}\n")
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
if completed_steps:
completed_times = [s['elapsed_seconds'] for s in completed_steps]
f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n")
f.write(f"最长耗时: {_format_time(np.max(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.max(completed_times)][0]})\n")
f.write(f"最短耗时: {_format_time(np.min(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.min(completed_times)][0]})\n")
print(f"\n流程报告已生成:")
print(f" CSV格式: {output_path}")
print(f" 文本格式: {txt_output_path}")
return output_path
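
The `step_timings` argument is documented only by its key names, so a small builder for one entry may help when calling `generate_pipeline_report`. This is a sketch under assumptions: `make_step_timing` is a hypothetical helper, and the `elapsed_formatted` string uses a plain seconds format rather than the module's own formatter.

```python
from datetime import datetime
from typing import Dict, Optional


def make_step_timing(start_ts: float, end_ts: float,
                     status: str = "completed",
                     error: Optional[str] = None) -> Dict[str, object]:
    """Build one step_timings entry from two POSIX timestamps."""
    elapsed = end_ts - start_ts
    fmt = "%Y-%m-%d %H:%M:%S"
    return {
        "start_time": datetime.fromtimestamp(start_ts).strftime(fmt),
        "end_time": datetime.fromtimestamp(end_ts).strftime(fmt),
        "elapsed_seconds": elapsed,
        "elapsed_formatted": f"{elapsed:.2f} s",   # assumed display format
        "status": status,                          # 'completed', 'failed' or 'skipped'
        "error": error or "",
    }


# Keys must match the entries of step_order exactly to appear in the report.
step_timings = {"步骤1: 生成水域mask": make_step_timing(1000.0, 1012.5)}
```

Because the report looks steps up by name, a step whose key differs from the `step_order` strings is counted in the summary but omitted from the per-step table.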


@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 散点图生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.core.prediction.inference_batch import WaterQualityInference
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_model_scatter_plots(
models_dir: str,
training_csv_path: str,
output_dir: Optional[str] = None,
metric: str = 'test_r2',
use_enhanced: bool = True,
feature_start_column: Union[str, int] = 13,
test_size: float = 0.2,
random_state: int = 42,
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
) -> Dict[str, str]:
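"""
Generate model evaluation scatter plots (true vs. predicted values).
Args:
models_dir: directory where the models are saved
training_csv_path: path of the training data CSV
output_dir: output directory; None uses the default
metric: metric used to select the best model
use_enhanced: whether to use the enhanced scatter plot (with confidence intervals, via sctter_batch)
feature_start_column: name or index of the first feature column
test_size: test set fraction
random_state: random seed
scatter_batch: optional, an already instantiated WaterQualityScatterBatch object
Returns:
Dict of scatter plot file paths (keyed by target parameter name)
"""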
print("\n" + "="*80)
print("生成模型评估散点图")
print("="*80)
if training_csv_path is None:
raise ValueError("请提供 training_csv_path")
models_path = Path(models_dir)
if not models_path.exists():
raise ValueError(f"模型目录不存在: {models_dir}")
# 确定输出目录
if output_dir is None:
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 实例化可视化器
visualizer = WaterQualityVisualization(output_dir)
scatter_paths = {}
# 增强版散点图
if use_enhanced:
print("使用增强版散点图(带置信区间)")
try:
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
if scatter_batch is None:
scatter_batch = WaterQualityScatterBatch()
results = scatter_batch.batch_plot_scatter(
models_root_dir=models_dir,
csv_path=training_csv_path,
output_dir=output_dir,
metric=metric,
target_column=None,
feature_start_column=feature_start_column,
test_size=test_size,
random_state=random_state
)
for target_name, result in results.items():
if result.get('status') == 'success':
scatter_paths[target_name] = result.get('save_path', '')
print(f"{target_name}: {result.get('save_path', '')}")
else:
print(f"{target_name}: 失败 - {result.get('error', '未知错误')}")
except Exception as e:
print(f"使用增强版散点图时出错: {e}")
print("回退到基础版散点图")
use_enhanced = False
# 基础版散点图
if not use_enhanced or not scatter_paths:
print("使用基础版散点图")
for target_folder in models_path.iterdir():
if not target_folder.is_dir():
continue
target_name = target_folder.name
print(f"\n处理目标参数: {target_name}")
try:
inferencer = WaterQualityInference(str(target_folder))
eval_result = inferencer.evaluate_with_split(
data_csv_path=training_csv_path,
split_method="spxy",
test_size=test_size,
random_state=random_state,
metric=metric
)
predictions = eval_result.get('predictions', {})
if predictions:
y_train_true = predictions.get('y_train_true')
y_train_pred = predictions.get('y_train_pred')
y_test_true = predictions.get('y_test_true')
y_test_pred = predictions.get('y_test_pred')
metrics = eval_result.get('test_metrics', {})
if y_train_true is not None and y_test_true is not None:
y_all_true = np.concatenate([y_train_true, y_test_true])
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
train_indices = np.arange(len(y_train_true))
test_indices = np.arange(len(y_train_true), len(y_all_true))
scatter_path = visualizer.plot_scatter_true_vs_pred(
y_true=y_all_true,
y_pred=y_all_pred,
target_name=target_name,
train_indices=train_indices,
test_indices=test_indices,
metrics={
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
'test_r2': metrics.get('r2', 0),
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
'test_rmse': metrics.get('rmse', 0)
}
)
scatter_paths[target_name] = scatter_path
except Exception as e:
print(f"处理目标参数 {target_name} 时出错: {e}")
continue
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
return scatter_paths


@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 光谱曲线图生成
"""
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List, Union
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_spectrum_comparison_plots(
csv_path: str,
parameter_columns: Optional[List[str]] = None,
wavelength_start_column: Union[str, int] = "UTM_Y",
output_dir: Optional[str] = None
) -> Dict[str, str]:
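"""
Generate spectrum comparison plots (spectral curves compared across parameter values).
Args:
csv_path: path of the CSV file containing spectra and parameter values
parameter_columns: list of parameter column names; None detects them automatically
wavelength_start_column: name or index of the first wavelength column
output_dir: output directory; None uses the default
Returns:
Dict of spectrum plot file paths (keyed by parameter name)
"""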
print("\n" + "="*80)
print("生成光谱曲线对比图")
print("="*80)
if csv_path is None:
raise ValueError("请提供 csv_path")
# 确定输出目录
if output_dir is None:
csv_dir = Path(csv_path).parent
output_dir = str(csv_dir / "visualization" / "spectrum_plots")
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 实例化可视化器
visualizer = WaterQualityVisualization(output_dir)
# 读取数据以检测参数列
df = pd.read_csv(csv_path)
if parameter_columns is None:
if isinstance(wavelength_start_column, str):
try:
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
except KeyError:
# Column not found: fall back to the default start index
wavelength_start_idx = 13
else:
wavelength_start_idx = wavelength_start_column
parameter_columns = list(df.columns[:wavelength_start_idx])
if len(parameter_columns) > 2:
parameter_columns = parameter_columns[2:]
spectrum_paths = {}
for param_col in parameter_columns:
if param_col not in df.columns:
continue
print(f"\n处理参数: {param_col}")
try:
spectrum_path = visualizer.plot_spectrum_by_parameter(
csv_path=csv_path,
parameter_column=param_col,
wavelength_start_column=wavelength_start_column,
n_groups=5
)
spectrum_paths[param_col] = spectrum_path
except Exception as e:
print(f"处理参数 {param_col} 时出错: {e}")
continue
print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表")
return spectrum_paths


@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""
可视化模块 - 统计图表生成
"""
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Optional, Dict, List
from src.postprocessing.visualization_reports import WaterQualityVisualization
def generate_statistical_charts(
csv_path: str,
parameter_columns: Optional[List[str]] = None,
output_dir: Optional[str] = None
) -> Dict[str, str]:
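"""
Generate statistical charts (box plots, histograms, correlation heatmap).
Args:
csv_path: CSV file path
parameter_columns: list of parameter column names; None detects them automatically
output_dir: output directory; None uses the default
Returns:
Dict of statistical chart file paths
"""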
print("\n" + "="*80)
print("生成统计图表")
print("="*80)
if csv_path is None:
raise ValueError("请提供 csv_path")
# 确定输出目录
if output_dir is None:
csv_dir = Path(csv_path).parent
output_dir = str(csv_dir / "visualization" / "statistical_charts")
Path(output_dir).mkdir(parents=True, exist_ok=True)
# 实例化可视化器
visualizer = WaterQualityVisualization(output_dir)
# 读取数据以检测参数列
df = pd.read_csv(csv_path)
if parameter_columns is None:
parameter_columns = list(df.columns[2:])
parameter_columns = [col for col in parameter_columns
if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
chart_paths = visualizer.plot_statistical_charts(
csv_path=csv_path,
parameter_columns=parameter_columns
)
print(f"\n统计图表生成完成")
return chart_paths

File diff suppressed because it is too large


@ -0,0 +1 @@
# src.gui.components package


@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
自定义组件 - 文件选择控件等公共组件
"""
import os
from PyQt5.QtWidgets import (
QWidget, QHBoxLayout, QLabel, QLineEdit, QPushButton, QFileDialog,
)
from PyQt5.QtCore import Qt
class DirSelectWidget(QWidget):
"""目录选择组件"""
def __init__(self, label_text, parent=None):
"""
Initialize the directory selection widget.
Args:
label_text: label text
parent: parent widget
"""
super().__init__(parent)
self.init_ui(label_text)
def init_ui(self, label_text):
layout = QHBoxLayout()
layout.setContentsMargins(0, 0, 0, 0)
self.label = QLabel(label_text)
self.label.setMinimumWidth(120)
self.line_edit = QLineEdit()
self.line_edit.setPlaceholderText("请选择目录...")
self.browse_btn = QPushButton("浏览...")
self.browse_btn.setMaximumWidth(80)
self.browse_btn.clicked.connect(self.browse_dir)
layout.addWidget(self.label)
layout.addWidget(self.line_edit, 1)
layout.addWidget(self.browse_btn)
self.setLayout(layout)
def browse_dir(self):
"""浏览目录 - 智能记忆上次选择位置"""
current_text = self.line_edit.text().strip()
initial_dir = ""
# 最高优先级:输入框已有路径存在
if current_text:
if os.path.isdir(current_text):
initial_dir = current_text
else:
dir_path = os.path.dirname(current_text)
if dir_path and os.path.exists(dir_path):
initial_dir = dir_path
# 调用目录选择对话框
dir_path = QFileDialog.getExistingDirectory(
self, "选择目录", initial_dir
)
if dir_path:
self.line_edit.setText(dir_path)
def get_path(self):
"""获取路径"""
return self.line_edit.text()
def set_path(self, path):
"""设置路径"""
self.line_edit.setText(str(path))


class FileSelectWidget(QWidget):
    """File selection widget."""

    def __init__(self, label_text, file_filter="All Files (*.*)", mode="open", parent=None):
        """
        Initialize the file selection widget.

        Args:
            label_text: Label text.
            file_filter: File filter.
            mode: Selection mode, "open" (open a file) or "save" (save a file).
            parent: Parent widget.
        """
        super().__init__(parent)
        self.file_filter = file_filter
        self.mode = mode  # "open" or "save"
        self.init_ui(label_text)

    def init_ui(self, label_text):
        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        self.label = QLabel(label_text)
        self.label.setMinimumWidth(120)

        self.line_edit = QLineEdit()
        placeholder = "Select a save path..." if self.mode == "save" else "Select a file..."
        self.line_edit.setPlaceholderText(placeholder)

        self.browse_btn = QPushButton("Browse...")
        self.browse_btn.setMaximumWidth(80)
        self.browse_btn.clicked.connect(self.browse_file)

        layout.addWidget(self.label)
        layout.addWidget(self.line_edit, 1)
        layout.addWidget(self.browse_btn)
        self.setLayout(layout)

    def browse_file(self):
        """Browse for a file, starting from the previously chosen location."""
        current_text = self.line_edit.text().strip()
        initial_dir = ""
        # Highest priority: the path already in the input box, if it exists
        if current_text:
            if os.path.isdir(current_text):
                initial_dir = current_text
            else:
                dir_path = os.path.dirname(current_text)
                if dir_path and os.path.exists(dir_path):
                    initial_dir = dir_path

        if self.mode == "save":
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Save File", initial_dir, self.file_filter
            )
        else:
            file_path, _ = QFileDialog.getOpenFileName(
                self, "Select File", initial_dir, self.file_filter
            )
        if file_path:
            self.line_edit.setText(file_path)

    def get_path(self):
        """Return the current path."""
        return self.line_edit.text()

    def set_path(self, path):
        """Set the current path."""
        self.line_edit.setText(str(path))
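
Both `browse_dir` and `browse_file` repeat the same lookup for the dialog's starting directory. As a minimal sketch only, that lookup could live in one pure function that is easy to test without Qt. The name `resolve_initial_dir` is hypothetical and does not appear in the commit.

```python
import os
import tempfile


def resolve_initial_dir(current_text: str) -> str:
    """Return the directory a file dialog should open in, or "" for the default.

    Hypothetical helper mirroring the logic in browse_dir/browse_file:
    use the text itself when it is a directory, otherwise fall back to
    its parent when that parent exists.
    """
    text = current_text.strip()
    if not text:
        return ""
    if os.path.isdir(text):
        return text
    parent = os.path.dirname(text)
    if parent and os.path.exists(parent):
        return parent
    return ""


# Small demonstration against a throwaway directory
with tempfile.TemporaryDirectory() as demo_dir:
    print(resolve_initial_dir(demo_dir) == demo_dir)
    print(resolve_initial_dir(os.path.join(demo_dir, "result.csv")) == demo_dir)
```

With such a helper, each widget method would reduce to one call followed by the matching `QFileDialog` static method.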

src/gui/core/__init__.py Normal file

@ -0,0 +1 @@
# src.gui.core


@ -0,0 +1,332 @@
# -*- coding: utf-8 -*-
"""
Background thread module: the pipeline execution thread and diagnostic logic.
"""
import importlib
import traceback

from PyQt5.QtCore import QThread, pyqtSignal


# =============================================================================
# Dependency diagnostics
# =============================================================================
def check_pipeline_dependencies():
    """Check the dependencies of the pipeline module.

    Returns:
        (missing_deps, dep_errors): the import names that failed to import and
        a mapping from each failed name to repr() of the raised exception.
    """
    missing_deps = []
    dep_errors = {}
    # Import names, not PyPI names (e.g. PIL for pillow, cv2 for opencv)
    required_packages = [
        'numpy', 'pandas', 'scipy', 'matplotlib', 'sklearn',
        'joblib', 'PIL', 'cv2', 'rasterio', 'geopandas'
    ]
    for package in required_packages:
        try:
            importlib.import_module(package)
        except Exception as e:
            # Catch everything: a broken native library may raise more than ImportError
            missing_deps.append(package)
            dep_errors[package] = repr(e)
    return missing_deps, dep_errors
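
For comparison, the sketch below checks for dependencies without importing them, using `importlib.util.find_spec` from the standard library. It is an illustration rather than the project's method: the function name `find_missing_packages` is hypothetical, and unlike `check_pipeline_dependencies` it cannot detect a package that is installed but fails while importing (for example because a native DLL is missing).

```python
import importlib.util
from typing import Iterable, List


def find_missing_packages(import_names: Iterable[str]) -> List[str]:
    """Return the import names for which no module spec can be located.

    Hypothetical helper: only looks the module up, it does not run its code.
    """
    missing = []
    for name in import_names:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # find_spec can raise for dotted names or modules without a spec
            spec = None
        if spec is None:
            missing.append(name)
    return missing


# "not_a_real_package_xyz" is a made-up name used only for the demonstration
print(find_missing_packages(["json", "not_a_real_package_xyz"]))
```

Because nothing is executed, this kind of check is cheap, so it suits a quick pre-flight test; the import-based check above remains the stronger diagnostic.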


def diagnose_pipeline_import_error():
    """Diagnose pipeline import errors and return a list of report lines."""
    import sys
    import os

    error_info = []
    is_frozen = getattr(sys, "frozen", False) or bool(getattr(sys, "_MEIPASS", None))
    if is_frozen:
        error_info.append(
            "[INFO] PyInstaller environment: the pipeline is loaded from the bundled package; "
            "skipping the on-disk check of repository path src/core/*.py"
        )
    else:
        pipeline_file = os.path.normpath(
            os.path.join(os.path.dirname(__file__), "..", "..", "core", "water_quality_inversion_pipeline_GUI.py")
        )
        if not os.path.exists(pipeline_file):
            error_info.append(f"[ERROR] Pipeline file does not exist: {pipeline_file}")
            error_info.append(
                "  Fix: make sure the project structure is complete and that src/core/ contains "
                "water_quality_inversion_pipeline_GUI.py"
            )
        else:
            error_info.append(f"[OK] Pipeline file exists: {pipeline_file}")

    # Make the project root importable
    current_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    if current_dir not in sys.path:
        sys.path.insert(0, current_dir)
        error_info.append(f"[INFO] Added path to sys.path: {current_dir}")

    missing_deps, dep_errors = check_pipeline_dependencies()
    if missing_deps:
        error_info.append(f"[ERROR] Missing required packages: {', '.join(missing_deps)}")
        for pkg in missing_deps:
            if pkg in dep_errors:
                error_info.append(f"  - {pkg} failed to import: {dep_errors[pkg]}")
        error_info.append("  Fix: install the dependencies with:")
        error_info.append("  pip install -r requirements.txt")
        error_info.append("  or with conda:")
        error_info.append("  conda install numpy pandas scipy matplotlib scikit-learn joblib pillow opencv rasterio geopandas")
    else:
        error_info.append("[OK] All main dependencies are installed")

    # The original retried the identical osgeo import in a nested fallback; one attempt is equivalent
    try:
        from osgeo import gdal  # noqa: F401
        error_info.append("[OK] GDAL (osgeo) is available")
    except ImportError:
        error_info.append("[WARNING] GDAL/osgeo is unavailable; raster and geodata processing will be affected")
        error_info.append("  Development environment: conda install gdal")
        error_info.append("  Packaged environment: build inside the Conda environment used for packaging, and make sure the spec collects the dependent DLLs from Library/bin")

    try:
        import unittest  # noqa: F401
        error_info.append("[OK] unittest module is available")
    except ImportError:
        error_info.append("[WARNING] unittest module is unavailable; this may be caused by the PyInstaller packaged environment")
        error_info.append("  This does not affect the main functionality, but may affect some test-related features")

    return error_info


# =============================================================================
# Pipeline availability flags (module-level state)
# =============================================================================
PIPELINE_AVAILABLE = False
PIPELINE_ERROR_INFO = []

try:
    error_info = diagnose_pipeline_import_error()
    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
    PIPELINE_AVAILABLE = True
    print("[OK] Pipeline module imported successfully")
    PIPELINE_ERROR_INFO = error_info
except ImportError as e:
    PIPELINE_AVAILABLE = False
    error_info = diagnose_pipeline_import_error()
    print("=" * 60)
    print("[ERROR] Pipeline import failed - detailed diagnostics:")
    print("=" * 60)
    for info in error_info:
        print(info)
    print("-" * 60)
    print(f"Original ImportError: {e}")
    print("-" * 60)
    if "unittest" in str(e):
        print("[INFO] unittest module is missing - this usually happens in a PyInstaller packaged environment")
        print("Fix:")
        print("  1. This does not affect the main functionality; the program can still run normally")
        print("  2. To repair it, add the unittest module in the .spec file:")
        print("     a = Analysis(..., hiddenimports=['unittest', 'unittest.mock'])")
        print("  3. Or add to the PyInstaller command: --hidden-import unittest")
    elif "water_quality_inversion_pipeline_GUI" in str(e):
        print("[INFO] Possible fixes:")
        print("  1. Check that src/core/water_quality_inversion_pipeline_GUI.py exists")
        print("  2. Make sure the Python path is set correctly")
        print("  3. Try reinstalling the dependencies: pip install -r requirements.txt")
        print("  4. Check that the Python version is compatible (Python 3.8-3.11 recommended)")
    # traceback is already imported at module level
    print("\nFull traceback:")
    traceback.print_exc()
    print("=" * 60)
    PIPELINE_ERROR_INFO = error_info
except Exception as e:
    PIPELINE_AVAILABLE = False
    error_info = diagnose_pipeline_import_error()
    print("=" * 60)
    print("[ERROR] Pipeline import failed - other error:")
    print("=" * 60)
    for info in error_info:
        print(info)
    print("-" * 60)
    print(f"Original error: {e}")
    print("-" * 60)
    print("[INFO] Possible fixes:")
    print("  1. Check the Python environment and dependency versions")
    print("  2. Try reinstalling all dependencies")
    print("  3. Check for syntax errors or other module import problems")
    print("\nFull traceback:")
    traceback.print_exc()
    print("=" * 60)
    PIPELINE_ERROR_INFO = error_info


# =============================================================================
# WorkerThread
# =============================================================================
class WorkerThread(QThread):
    """Background worker thread for long-running tasks.

    The Pipeline is created inside the worker thread to avoid blocking the UI.
    """

    progress_update = pyqtSignal(int, str)       # progress update (percentage, message)
    log_message = pyqtSignal(str, str)           # log message (message, level: 'info'/'warning'/'error')
    step_completed = pyqtSignal(str, bool, str)  # step completed (step_name, success, message)
    # Note: this shadows QThread's built-in no-argument finished signal
    finished = pyqtSignal(bool, str)             # run finished (success, message)

    def __init__(self, work_dir: str, config, mode='full', step_name=None):
        super().__init__()
        self.work_dir = str(work_dir)
        self.config = config
        self.mode = mode            # 'full' or 'single_step'
        self.step_name = step_name  # step name when running a single step
        self.pipeline = None
        self.is_running = True
        self.current_step = None
        self.step_count = 0
        self.total_steps = 9

    def pipeline_callback(self, step_name, status, message=""):
        """Pipeline callback that receives step status updates."""
        if status == "start":
            self.log_message.emit(f"[START] Starting: {step_name}", "info")
            # Clamp to 100 in case more steps report than total_steps
            progress = min(100, int((self.step_count / self.total_steps) * 100))
            self.progress_update.emit(progress, f"Running: {step_name}")
        elif status == "completed":
            self.step_count += 1
            self.log_message.emit(f"[DONE] Completed: {step_name} {message}", "info")
            self.step_completed.emit(step_name, True, message)
            progress = min(100, int((self.step_count / self.total_steps) * 100))
            self.progress_update.emit(progress, f"Completed: {step_name}")
        elif status == "skipped":
            self.step_count += 1
            self.log_message.emit(f"[SKIP] Skipped: {step_name} {message}", "warning")
            self.step_completed.emit(step_name, True, f"Skipped: {message}")
            progress = min(100, int((self.step_count / self.total_steps) * 100))
            self.progress_update.emit(progress, f"Skipped: {step_name}")
        elif status == "error":
            self.log_message.emit(f"[ERROR] Error: {step_name} - {message}", "error")
            self.step_completed.emit(step_name, False, message)
        elif status == "info":
            self.log_message.emit(f"  {message}", "info")
        elif status == "warning":
            self.log_message.emit(f"  [WARNING] {message}", "warning")

    def run(self):
        """Run the pipeline.

        Matplotlib is switched to the Agg backend inside the worker thread so that
        Qt5Agg drawing from a background thread does not freeze the UI.
        """
        import os

        # GDAL environment variable guards (set first to prevent path/encoding problems)
        os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
        os.environ['SHAPE_ENCODING'] = 'UTF-8'

        # Remember the current backend so it can be restored afterwards
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            pass
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            mpl_prev = None

        try:
            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
            self.pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
            if self.mode == 'full':
                self.log_message.emit("Starting the full pipeline...", "info")
                self.step_count = 0
                if hasattr(self.pipeline, 'set_callback'):
                    self.pipeline.set_callback(self.pipeline_callback)
                self.pipeline.run_full_pipeline(self.config)
                self.progress_update.emit(100, "Pipeline finished")
                self.finished.emit(True, "Full pipeline completed successfully!")
            else:
                self.log_message.emit(f"Starting standalone step: {self.step_name}", "info")
                self.progress_update.emit(0, f"Running: {self.step_name}")
                if hasattr(self.pipeline, 'set_callback'):
                    self.pipeline.set_callback(self.pipeline_callback)
                self.run_single_step(self.step_name, self.config)
                self.progress_update.emit(100, f"Step {self.step_name} finished")
                self.finished.emit(True, f"Step {self.step_name} ran successfully on its own!")
        except Exception as e:
            error_msg = f"Execution failed: {e}\n{traceback.format_exc()}"
            self.log_message.emit(error_msg, "error")
            self.finished.emit(False, error_msg)
        finally:
            # Restore the previous Matplotlib backend if one was recorded
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    pass

    def run_single_step(self, step_name, config):
        """Run a single step by dispatching to the matching pipeline method."""
        step_method_map = {
            'step1': 'step1_generate_water_mask',
            'step2': 'step2_find_glint_area',
            'step3': 'step3_remove_glint',
            'step4': 'step4_process_csv',
            'step5': 'step5_extract_training_spectra',
            'step5_5': 'step5_5_calculate_water_quality_indices',
            'step6': 'step6_train_models',
            'step6_5': 'step6_5_non_empirical_modeling',
            'step6_75': 'step6_75_custom_regression',
            'step7': 'step7_generate_sampling_points',
            'step8': 'step8_predict_water_quality',
            'step8_5': 'step8_5_predict_with_non_empirical_models',
            'step8_75': 'step8_75_predict_with_custom_regression',
            'step9': 'step9_generate_distribution_map'
        }
        if step_name not in step_method_map:
            raise ValueError(f"Unknown step name: {step_name}")

        method_name = step_method_map[step_name]
        # Copy the step's config so the caller's dictionary is not mutated
        step_config = dict(config.get(step_name, {}))
        step_config['skip_dependency_check'] = True

        # Drop GUI-only keys that the pipeline methods do not accept
        if step_name == 'step9':
            step_config.pop('step9_batch_mode', None)
            step_config.pop('prediction_csv_dir', None)
            step_config.pop('recursive_csv_scan', None)
        if step_name in ['step2', 'step3', 'step4', 'step5', 'step6', 'step7', 'step8', 'step8_5', 'step8_75']:
            step_config.pop('output_path', None)
        # step8_5 expects the directory under a different keyword
        if step_name == 'step8_5' and 'models_dir' in step_config:
            step_config['non_empirical_models_dir'] = step_config.pop('models_dir')

        method = getattr(self.pipeline, method_name)
        result = method(**step_config)
        return result

    def stop(self):
        """Stop execution.

        QThread.terminate() ends the thread abruptly without cleanup; Qt's
        documentation discourages its use.
        """
        self.is_running = False
        self.terminate()
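
The dispatch in `run_single_step` is a name-to-method lookup followed by a keyword call. The sketch below isolates that pattern with the standard library only. `DemoPipeline`, `dispatch_step`, and the `dpi` config key are hypothetical stand-ins for illustration, and only two of the step names from the table above are reproduced.

```python
from typing import Any, Dict

# Two entries copied from step_method_map; the rest are omitted for brevity
STEP_METHODS = {
    "step1": "step1_generate_water_mask",
    "step9": "step9_generate_distribution_map",
}


class DemoPipeline:
    """Hypothetical stand-in for the real pipeline; it only echoes its kwargs."""

    def step1_generate_water_mask(self, **kwargs: Any):
        return "mask", kwargs

    def step9_generate_distribution_map(self, **kwargs: Any):
        return "map", kwargs


def dispatch_step(pipeline: Any, step_name: str, config: Dict[str, Dict[str, Any]]):
    """Look up the method for step_name and call it with a cleaned config copy."""
    if step_name not in STEP_METHODS:
        raise ValueError(f"Unknown step name: {step_name}")
    step_config = dict(config.get(step_name, {}))  # copy, never mutate the caller's dict
    step_config["skip_dependency_check"] = True
    if step_name == "step9":
        # GUI-only keys are removed before the call
        for key in ("step9_batch_mode", "prediction_csv_dir", "recursive_csv_scan"):
            step_config.pop(key, None)
    return getattr(pipeline, STEP_METHODS[step_name])(**step_config)


demo_config = {"step9": {"step9_batch_mode": True, "dpi": 300}}
print(dispatch_step(DemoPipeline(), "step9", demo_config))
```

Copying the per-step dictionary before popping keys matters here: the GUI can reuse the same `config` object for a later run without finding keys missing.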


@ -0,0 +1,46 @@
Formula_Name,Category,Formula,Reference
BGA_Am09KBBI,Phycocyanin (BGA_PC),(w686 - w658) / (w686 + w658),"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.; Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13."
BGA_Be162B643sub629,Phycocyanin (BGA_PC),w644 - w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
BGA_Be162B700sub601,Phycocyanin (BGA_PC),w700 - w601,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
BGA_Be162BsubPhy,Phycocyanin (BGA_PC),w715 - w615,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
BGA_Be16FLHBlueRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w458 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
BGA_Be16FLHGreenRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w558 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
BGA_Be16FLHVioletRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w444 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
BGA_Be16MPI,Phycocyanin (BGA_PC),(w615 - w601) - (w644 - w601),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
BGA_Be16NDPhyI,Phycocyanin (BGA_PC),(w700 - w622) / (w700 + w622),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
BGA_Be16NDPhyI644over615,Phycocyanin (BGA_PC),(w644 - w615) / (w644 + w615),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541."
BGA_Be16NDPhyI644over629,Phycocyanin (BGA_PC),(w644 - w629) / (w644 + w629),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542."
BGA_Be16Phy2BDA644over629,Phycocyanin (BGA_PC),w644 / w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545."
BGA_Da052BDA,Phycocyanin (BGA_PC),w714 / w672,"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
BGA_Go04MCI,Phycocyanin (BGA_PC),w709 - w681 - (w753 - w681),"Gower, J.F.R.; Brown, L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17-25."
BGA_HU103BDA,Phycocyanin (BGA_PC),(((1 / w615) - (1 / w600)) - w725),"Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406"
BGA_Ku15PhyCI,Phycocyanin (BGA_PC),(-1 * (W681 - W665 - (W709 - W665))),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10."
BGA_Ku15SLH,Phycocyanin (BGA_PC),(w715 - w658) + (w715 - w658),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11."
BGA_MI092BDA,Phycocyanin (BGA_PC),w700 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758-775."
BGA_MM092BDA,Phycocyanin (BGA_PC),w724 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758-775."
BGA_MM12NDCIalt,Phycocyanin (BGA_PC),(w700 - w658) / (w700 + w658),"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003"
BGA_MM143BDAopt,Phycocyanin (BGA_PC),((1 / w629) - (1 / w659)) * w724,"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004"
BGA_SI052BDA,Phycocyanin (BGA_PC),w709 / w620,"Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237-245"
BGA_SM122BDA,Phycocyanin (BGA_PC),w709 / w600,"Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012."
BGA_SY002BDA,Phycocyanin (BGA_PC),w650 / w625,"Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55, 153-168"
BGA_Wy08CI,Phycocyanin (BGA_PC),(-1 * (W686 - W672 - (W715 - W672))),"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
Chl_Al10SABI,chlorophyll_a,(w857 - w644) / (w458 + w529),"Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825."
Chl_Am092Bsub,chlorophyll_a,w681 - w665,"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126-9144."
Chl_Be16FLHblue,chlorophyll_a,w529 - (w644 + (w458 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
Chl_Be16FLHviolet,chlorophyll_a,w529 - (w644 + (w429 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
Chl_Be16NDTIblue,chlorophyll_a,(w658 - w458) / (w658 + w458),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543."
Chl_Be16NDTIviolet,chlorophyll_a,(w658 - w444) / (w658 + w444),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544."
Chl_De933BDA,chlorophyll_a,w600 - w648 - w625,"Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam."
Chl_Gi033BDA,chlorophyll_a,((1 / w672) - (1 / w715)) * w757,"Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282."
Chl_Kn07KIVU,chlorophyll_a,(w458 - w644) / w529,"Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23-27 April 2007 (ESA SP-636, July 2007)."
Chl_MM12NDCI,chlorophyll_a,(w715 - w686) / (w715 + w686),"Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406"
Chl_Zh10FLH,chlorophyll_a,w686 - (w715 + (w672 - w751)),"Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48"
Turb_Be16GreenPlusRedBothOverViolet,Turbidity,(w558 + w658) / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538"
Turb_Be16RedOverViolet,Turbidity,w658 / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539"
Turb_Bow06RedOverGreen,Turbidity,w658 / w558,"Bowers, D. G., and C. E. Binding. 2006. “The Optical Properties of Mineral Suspended Particles: A Review and Synthesis.” Estuarine Coastal and Shelf Science 67 (1-2): 219-230. doi:10.1016/j.ecss.2005.11.010"
Turb_Chip09NIROverGreen,Turbidity,w857 / w558,"Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009."
Turb_Dox02NIRoverRed,Turbidity,w857 / w658,"Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085"
Turb_Frohn09GreenPlusRedBothOverBlue,Turbidity,(w558 + w658) / w458,"Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency."
Turb_Harr92NIR,Turbidity,w857,"Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments: the Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487-1509"
Turb_Lath91RedOverBlue,Turbidity,w658 / w458,"Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing"
Turb_Moore80Red,Turbidity,w658,"Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422"
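
Each `Formula` entry above is an arithmetic expression over tokens such as `w686`, where the number reads as a wavelength in nanometres. As a rough sketch of how one row could be evaluated, the code below walks the expression with Python's `ast` module and looks up each band in a mapping. This is an assumption-laden illustration, not the project's band-math code: `evaluate_index` is a hypothetical name, the mapping must contain the exact wavelengths used (no nearest-band matching), and both `w` and `W` prefixes are accepted because the table mixes them.

```python
import ast
import operator
import re
from typing import Mapping

# Only the four arithmetic operators that appear in the formula table
_BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}
_BAND_TOKEN = re.compile(r"^[wW](\d+)$")


def evaluate_index(formula: str, reflectance: Mapping[int, float]) -> float:
    """Evaluate a band-math formula against a {wavelength_nm: value} mapping."""

    def visit(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](visit(node.left), visit(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -visit(node.operand)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Name):
            match = _BAND_TOKEN.match(node.id)
            if match:
                return reflectance[int(match.group(1))]
        # Anything else (calls, attributes, unknown names) is rejected
        raise ValueError(f"Unsupported element in formula: {ast.dump(node)}")

    return visit(ast.parse(formula, mode="eval"))


# Made-up reflectance values, for demonstration only
demo_bands = {686: 0.5, 658: 0.25}
print(evaluate_index("(w686 - w658) / (w686 + w658)", demo_bands))
```

Walking the syntax tree instead of calling `eval` means a malformed or malicious `Formula` cell raises `ValueError` rather than executing arbitrary code. Applied to NumPy arrays instead of floats, the same operators would work element-wise, though that is untested here.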
13 BGA_Be16Phy2BDA644over629 Phycocyanin (BGA_PC) w644 / w629 Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545.
14 BGA_Da052BDA Phycocyanin (BGA_PC) w714 / w672 Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672.
15 BGA_Go04MCI Phycocyanin (BGA_PC) w709 - w681 - (w753 - w681) Gower, J.F.R.; Brown, L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17-25.
16 BGA_HU103BDA Phycocyanin (BGA_PC) (((1 / w615) - (1 / w600)) - w725) Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406
17 BGA_Ku15PhyCI Phycocyanin (BGA_PC) (-1 * (W681 - W665 - (W709 - W665))) Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10.
18 BGA_Ku15SLH Phycocyanin (BGA_PC) (w715 - w658) + (w715 - w658) Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11.
19 BGA_MI092BDA Phycocyanin (BGA_PC) w700 / w600 Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758-775.
20 BGA_MM092BDA Phycocyanin (BGA_PC) w724 / w600 Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758-775.
21 BGA_MM12NDCIalt Phycocyanin (BGA_PC) (w700 - w658) / (w700 + w658) Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003
22 BGA_MM143BDAopt Phycocyanin (BGA_PC) ((1 / w629) - (1 / w659)) * w724 Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004
23 BGA_SI052BDA Phycocyanin (BGA_PC) w709 / w620 Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237-245
24 BGA_SM122BDA Phycocyanin (BGA_PC) w709 / w600 Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012.
25 BGA_SY002BDA Phycocyanin (BGA_PC) w650 / w625 Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55, 153-168
26 BGA_Wy08CI Phycocyanin (BGA_PC) (-1 * (W686 - W672 - (W715 - W672))) Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672.
27 Chl_Al10SABI chlorophyll_a (w857 - w644) / (w458 + w529) Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825.
28 Chl_Am092Bsub chlorophyll_a w681 - w665 Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126-9144.
29 Chl_Be16FLHblue chlorophyll_a w529 - (w644 + (w458 - w644)) Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30.
30 Chl_Be16FLHviolet chlorophyll_a w529 - (w644 + (w429 - w644)) Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30.
31 Chl_Be16NDTIblue chlorophyll_a (w658 - w458) / (w658 + w458) Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543.
32 Chl_Be16NDTIviolet chlorophyll_a (w658 - w444) / (w658 + w444) Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544.
33 Chl_De933BDA chlorophyll_a w600 - w648 - w625 Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam.
34 Chl_Gi033BDA chlorophyll_a ((1 / w672) - (1 / w715)) * w757 Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282.
35 Chl_Kn07KIVU chlorophyll_a (w458 - w644) / w529 Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23-27 April 2007 (ESA SP-636, July 2007).
36 Chl_MM12NDCI chlorophyll_a (w715 - w686) / (w715 + w686) Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406
37 Chl_Zh10FLH chlorophyll_a w686 - (w715 + (w672 - w751)) Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48
38 Turb_Be16GreenPlusRedBothOverViolet Turbidity (w558 + w658) / w444 Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538
39 Turb_Be16RedOverViolet Turbidity w658 / w444 Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539
40 Turb_Bow06RedOverGreen Turbidity w658 / w558 Bowers, D. G., and C. E. Binding. 2006. "The Optical Properties of Mineral Suspended Particles: A Review and Synthesis." Estuarine Coastal and Shelf Science 67 (1-2): 219-230. doi:10.1016/j.ecss.2005.11.010
41 Turb_Chip09NIROverGreen Turbidity w857 / w558 Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009.
42 Turb_Dox02NIRoverRed Turbidity w857 / w658 Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085
43 Turb_Frohn09GreenPlusRedBothOverBlue Turbidity (w558 + w658) / w458 Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency.
44 Turb_Harr92NIR Turbidity w857 Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments: the Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487-1509
45 Turb_Lath91RedOverBlue Turbidity w658 / w458 Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing
46 Turb_Moore80Red Turbidity w658 Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422
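
Each row above pairs an index name with a band-math expression over reflectance at named wavelengths (for example, `w700` is the reflectance near 700 nm). The following is a minimal sketch of how one such expression could be evaluated for a single pixel. The helper name `evaluate_index` and the nearest-band lookup are assumptions made for illustration; they are not taken from the repository.

```python
import re

# Hypothetical helper for illustration; not part of the repository.
_TOKEN = re.compile(r"[wW](\d+(?:\.\d+)?)")   # matches w700, W681, ...
_SAFE = re.compile(r"[\d\s.eE+\-*/()]+")      # numbers, operators, parentheses


def evaluate_index(formula, spectrum):
    """Evaluate a band-math formula for one pixel.

    formula:  expression such as "(w700 - w622) / (w700 + w622)"
    spectrum: dict mapping band-centre wavelength (nm) to reflectance
    """
    def substitute(match):
        target = float(match.group(1))
        # Assumption: use the band whose centre is closest to the requested wavelength
        nearest = min(spectrum, key=lambda wl: abs(wl - target))
        return f"({float(spectrum[nearest])!r})"

    expression = _TOKEN.sub(substitute, formula)
    # After substitution only numbers, operators and parentheses may remain
    if not _SAFE.fullmatch(expression):
        raise ValueError(f"unsupported token in formula: {formula!r}")
    return eval(expression, {"__builtins__": {}}, {})


pixel = {622.0: 0.02, 700.0: 0.06}
ndphyi = evaluate_index("(w700 - w622) / (w700 + w622)", pixel)
print(round(ndphyi, 6))
```

The sketch matches `w` and `W` alike because the table mixes both spellings. A zero denominator still raises `ZeroDivisionError`, so a caller working on whole images would need to mask such pixels first.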


@ -0,0 +1,315 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
ReportGenerationPanel - panel for generating the Word analysis report
"""
import os
import traceback
from pathlib import Path
from typing import Optional

from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout,
    QLabel, QCheckBox, QPushButton, QLineEdit, QSpinBox,
    QMessageBox, QFileDialog,
)

from src.gui.styles import ModernStylesheet


class ReportGenerateThread(QThread):
    """Generate the Word report in the background (avoids blocking the UI)."""

    finished_ok = pyqtSignal(str)
    failed = pyqtSignal(str)
    log_message = pyqtSignal(str, str)

    def __init__(self, work_dir: str, output_dir: Optional[str], report_title: str, options: dict):
        super().__init__()
        self.work_dir = work_dir
        self.output_dir = output_dir
        self.report_title = report_title
        self.options = options

    def run(self):
        try:
            from src.postprocessing.report_word import WaterQualityReportGenerator, ReportGenerationConfig

            # Blank text fields are passed on as None
            url = (self.options.get("ollama_url") or "").strip() or None
            vision = (self.options.get("ollama_vision_model") or "").strip() or None
            text = (self.options.get("ollama_text_model") or "").strip() or None
            if self.options.get("text_same_as_vision"):
                text = vision
            timeout = self.options.get("ollama_timeout_s")
            enable_ai = self.options.get("enable_ai_analysis")
            ai_cfg = ReportGenerationConfig(
                ollama_base_url=url,
                ollama_vision_model=vision,
                ollama_text_model=text,
                ollama_timeout_s=int(timeout) if timeout is not None else None,
                enable_ai_analysis=bool(enable_ai),
            )
            self.log_message.emit(
                f"Report generation: work_dir={self.work_dir}, AI={'on' if enable_ai else 'off'}, "
                f"model URL={url or '(environment variable OLLAMA_URL)'}",
                "info",
            )
            gen = WaterQualityReportGenerator(
                work_dir=self.work_dir,
                output_dir=self.output_dir,
                ai_config=ai_cfg,
            )
            out_path = gen.generate_report(
                work_dir=self.work_dir,
                report_title=self.report_title or "Water Quality Parameter Retrieval Analysis Report",
            )
            self.finished_ok.emit(str(out_path))
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")


class ReportGenerationPanel(QWidget):
    """Word report generation: working directory, output directory, Ollama URL/models, whether AI is enabled, etc."""

    def __init__(self, main_window=None, parent=None):
        super().__init__(parent)
        self.main_window = main_window
        self._report_thread = None
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        layout.setContentsMargins(10, 10, 10, 10)
        layout.setSpacing(10)

        intro = QLabel(
            "Generates a Word analysis report from the visualization results in the "
            "working directory (14_visualization, etc.). The visualization charts must "
            "already exist; AI analysis calls a local or remote service through Ollama /api/chat."
        )
        intro.setWordWrap(True)
        intro.setStyleSheet(
            f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
        )
        layout.addWidget(intro)

        path_group = QGroupBox("Paths")
        path_form = QFormLayout()
        wd_row = QHBoxLayout()
        self.work_dir_edit = QLineEdit()
        self.work_dir_edit.setPlaceholderText(
            "Select the pipeline working directory (containing 14_visualization)"
        )
        wd_browse = QPushButton("Browse…")
        wd_browse.clicked.connect(self.browse_work_dir)
        sync_btn = QPushButton("Sync from main window")
        sync_btn.clicked.connect(self.sync_work_dir_from_main)
        wd_row.addWidget(self.work_dir_edit, 1)
        wd_row.addWidget(wd_browse)
        wd_row.addWidget(sync_btn)
        path_form.addRow("Working directory:", wd_row)

        out_row = QHBoxLayout()
        self.output_dir_edit = QLineEdit()
        self.output_dir_edit.setPlaceholderText(
            "Leave blank to save to <working directory>/14_visualization"
        )
        out_browse = QPushButton("Browse…")
        out_browse.clicked.connect(self.browse_output_dir)
        out_row.addWidget(self.output_dir_edit, 1)
        out_row.addWidget(out_browse)
        path_form.addRow("Report output directory:", out_row)

        self.report_title_edit = QLineEdit()
        self.report_title_edit.setText("Water Quality Parameter Retrieval Analysis Report")
        path_form.addRow("Report title:", self.report_title_edit)
        path_group.setLayout(path_form)
        layout.addWidget(path_group)

        ai_group = QGroupBox("AI Analysis (Ollama)")
        ai_form = QFormLayout()
        self.enable_ai_cb = QCheckBox("Enable AI chart interpretation and overall summary")
        self.enable_ai_cb.setChecked(
            os.environ.get("ENABLE_AI_ANALYSIS", "1") not in {"0", "false", "False"}
        )
        ai_form.addRow(self.enable_ai_cb)

        self.ollama_url_edit = QLineEdit()
        self.ollama_url_edit.setText(
            os.environ.get("OLLAMA_URL", "http://localhost:11434").rstrip("/")
        )
        ai_form.addRow("Service URL:", self.ollama_url_edit)

        self.vision_model_edit = QLineEdit()
        self.vision_model_edit.setText(
            os.environ.get("OLLAMA_VISION_MODEL", "qwen3-vl:8b")
        )
        ai_form.addRow("Vision model:", self.vision_model_edit)

        self.same_text_model_cb = QCheckBox("Use the same model for text summary and vision")
        self.same_text_model_cb.setChecked(True)
        ai_form.addRow(self.same_text_model_cb)

        self.text_model_edit = QLineEdit()
        self.text_model_edit.setText(
            os.environ.get(
                "OLLAMA_TEXT_MODEL",
                self.vision_model_edit.text() or "qwen3-vl:8b"
            )
        )
        self.text_model_edit.setEnabled(False)
        self.same_text_model_cb.toggled.connect(self._on_same_text_toggled)
        self.vision_model_edit.textChanged.connect(self._sync_text_model_if_linked)
        ai_form.addRow("Text model:", self.text_model_edit)

        self.timeout_spin = QSpinBox()
        self.timeout_spin.setRange(30, 3600)
        self.timeout_spin.setSingleStep(30)
        self.timeout_spin.setValue(int(os.environ.get("OLLAMA_TIMEOUT_S", "120")))
        ai_form.addRow("Request timeout (s):", self.timeout_spin)
        ai_group.setLayout(ai_form)
        layout.addWidget(ai_group)

        btn_row = QHBoxLayout()
        self.generate_btn = QPushButton("Generate Word report")
        self.generate_btn.setStyleSheet(
            ModernStylesheet.get_button_stylesheet("success")
        )
        self.generate_btn.clicked.connect(self.on_generate_clicked)
        btn_row.addWidget(self.generate_btn)
        btn_row.addStretch()
        layout.addLayout(btn_row)
        layout.addStretch()
        self.setLayout(layout)

    def _on_same_text_toggled(self, checked: bool):
        self.text_model_edit.setEnabled(not checked)
        if checked:
            self.text_model_edit.setText(self.vision_model_edit.text())

    def _sync_text_model_if_linked(self, _t=None):
        if self.same_text_model_cb.isChecked():
            self.text_model_edit.blockSignals(True)
            self.text_model_edit.setText(self.vision_model_edit.text())
            self.text_model_edit.blockSignals(False)

    def _get_default_work_dir(self):
        """Return work_dir, preferring the one cached on the main window."""
        if self.main_window and hasattr(self.main_window, 'work_dir') and self.main_window.work_dir:
            return str(self.main_window.work_dir)
        return ""

    def browse_work_dir(self):
        default = self._get_default_work_dir()
        d = QFileDialog.getExistingDirectory(self, "Select working directory", default)
        if d:
            self.work_dir_edit.setText(d)

    def browse_output_dir(self):
        default = self._get_default_work_dir()
        if default:
            default = os.path.join(default, "14_visualization")
        d = QFileDialog.getExistingDirectory(self, "Select report output directory", default)
        if d:
            self.output_dir_edit.setText(d)

    def sync_work_dir_from_main(self):
        mw = self.main_window
        if mw is not None and getattr(mw, "work_dir", None):
            self.work_dir_edit.setText(str(mw.work_dir))
        else:
            QMessageBox.information(self, "Notice", "The main window has no working directory set yet.")

    def set_work_dir(self, work_dir):
        if work_dir:
            self.work_dir_edit.setText(str(work_dir))

    def get_config(self):
        return {
            "work_dir": self.work_dir_edit.text().strip() or None,
            "output_dir": self.output_dir_edit.text().strip() or None,
            "report_title": self.report_title_edit.text().strip()
            or "Water Quality Parameter Retrieval Analysis Report",
            "ollama_url": self.ollama_url_edit.text().strip(),
            "ollama_vision_model": self.vision_model_edit.text().strip(),
            "ollama_text_model": self.text_model_edit.text().strip(),
            "text_same_as_vision": self.same_text_model_cb.isChecked(),
            "ollama_timeout_s": self.timeout_spin.value(),
            "enable_ai_analysis": self.enable_ai_cb.isChecked(),
        }

    def set_config(self, config):
        if not config:
            return
        if config.get("work_dir"):
            self.work_dir_edit.setText(str(config["work_dir"]))
        if "output_dir" in config:
            self.output_dir_edit.setText(str(config["output_dir"] or ""))
        if config.get("report_title"):
            self.report_title_edit.setText(str(config["report_title"]))
        if config.get("ollama_url"):
            self.ollama_url_edit.setText(str(config["ollama_url"]))
        if config.get("ollama_vision_model"):
            self.vision_model_edit.setText(str(config["ollama_vision_model"]))
        if "text_same_as_vision" in config:
            self.same_text_model_cb.setChecked(bool(config["text_same_as_vision"]))
        if config.get("ollama_text_model"):
            self.text_model_edit.setText(str(config["ollama_text_model"]))
        if config.get("ollama_timeout_s") is not None:
            self.timeout_spin.setValue(int(config["ollama_timeout_s"]))
        if "enable_ai_analysis" in config:
            self.enable_ai_cb.setChecked(bool(config["enable_ai_analysis"]))

    def on_generate_clicked(self):
        wd = self.work_dir_edit.text().strip()
        if not wd or not os.path.isdir(wd):
            QMessageBox.warning(self, "Notice", "Please select a valid working directory.")
            return
        viz = Path(wd) / "14_visualization"
        if not viz.is_dir():
            QMessageBox.warning(
                self,
                "Notice",
                f"Visualization directory not found:\n{viz}\n"
                "Please run the pipeline or generate the visualizations first.",
            )
            return
        if self._report_thread and self._report_thread.isRunning():
            QMessageBox.information(self, "Notice", "A report is already being generated; please wait.")
            return
        out = self.output_dir_edit.text().strip() or None
        title = self.report_title_edit.text().strip() or "Water Quality Parameter Retrieval Analysis Report"
        opts = {
            "ollama_url": self.ollama_url_edit.text().strip(),
            "ollama_vision_model": self.vision_model_edit.text().strip(),
            "ollama_text_model": self.text_model_edit.text().strip(),
            "text_same_as_vision": self.same_text_model_cb.isChecked(),
            "ollama_timeout_s": self.timeout_spin.value(),
            "enable_ai_analysis": self.enable_ai_cb.isChecked(),
        }
        # Disable the button while the worker runs; it is re-enabled on `finished`
        self.generate_btn.setEnabled(False)
        self._report_thread = ReportGenerateThread(wd, out, title, opts)
        self._report_thread.log_message.connect(self._forward_log, Qt.QueuedConnection)
        self._report_thread.finished_ok.connect(self._on_report_ok, Qt.QueuedConnection)
        self._report_thread.failed.connect(self._on_report_fail, Qt.QueuedConnection)
        self._report_thread.finished.connect(
            lambda: self.generate_btn.setEnabled(True), Qt.QueuedConnection
        )
        self._report_thread.start()
        self._forward_log("Started generating the Word report…", "info")

    def _forward_log(self, msg: str, level: str):
        mw = self.main_window
        if mw is not None and hasattr(mw, "log_message"):
            mw.log_message(msg, level)
        else:
            print(f"[{level}] {msg}")

    def _on_report_ok(self, path: str):
        QMessageBox.information(self, "Done", f"Report generated:\n{path}")
        self._forward_log(f"Word report saved: {path}", "info")

    def _on_report_fail(self, err: str):
        QMessageBox.critical(self, "Failed", f"Report generation failed:\n{err[:800]}")
        self._forward_log(err, "error")
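
The option handling at the top of `ReportGenerateThread.run()` can be exercised without Qt. The sketch below restates that normalisation as a plain function; the names `_blank_to_none` and `resolve_ollama_options` are hypothetical and exist only for this illustration, while the dictionary keys are the ones the panel already uses.

```python
from typing import Optional


def _blank_to_none(value) -> Optional[str]:
    """Strip a text field and map empty or missing input to None."""
    return (value or "").strip() or None


def resolve_ollama_options(options: dict) -> dict:
    """Restate the normalisation done at the top of ReportGenerateThread.run()."""
    url = _blank_to_none(options.get("ollama_url"))
    vision = _blank_to_none(options.get("ollama_vision_model"))
    text = _blank_to_none(options.get("ollama_text_model"))
    if options.get("text_same_as_vision"):
        text = vision  # the checkbox overrides whatever the text field holds
    timeout = options.get("ollama_timeout_s")
    return {
        "ollama_base_url": url,
        "ollama_vision_model": vision,
        "ollama_text_model": text,
        "ollama_timeout_s": int(timeout) if timeout is not None else None,
        "enable_ai_analysis": bool(options.get("enable_ai_analysis")),
    }


resolved = resolve_ollama_options({
    "ollama_url": "  ",
    "ollama_vision_model": "qwen3-vl:8b",
    "ollama_text_model": "some-other-model",
    "text_same_as_vision": True,
    "ollama_timeout_s": 120,
    "enable_ai_analysis": True,
})
print(resolved)
```

Keeping this step as a pure function would make it testable on its own, since the thread class itself needs a running Qt environment and the report generator.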


@ -0,0 +1,282 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step1 panel - water mask generation
"""
import os

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QLabel,
    QDoubleSpinBox, QCheckBox, QPushButton, QFormLayout, QRadioButton,
    QMessageBox,
)
from PyQt5.QtCore import Qt

# Imported from the shared component library
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


class Step1Panel(QWidget):
    """1. Water mask generation"""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        # Title
        # Mask generation method
        method_group = QGroupBox("Mask generation method")
        method_layout = QVBoxLayout()
        # Use an existing mask file
        self.use_existing_radio = QRadioButton("Use an existing mask file")
        self.use_existing_radio.setChecked(True)
        method_layout.addWidget(self.use_existing_radio)
        # Generate automatically with NDWI
        self.use_ndwi_radio = QRadioButton("Generate automatically with NDWI")
        method_layout.addWidget(self.use_ndwi_radio)
        # Apply the QRadioButton style (solid dot when checked)
        radio_style = """
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border-radius: 8px;
                border: 2px solid #999;
            }
            QRadioButton::indicator:checked {
                background-color: #0078D7;
                border: 2px solid #0078D7;
            }
            QRadioButton::indicator:unchecked {
                background-color: white;
                border: 2px solid #999;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #0078D7;
            }
        """
        self.use_existing_radio.setStyleSheet(radio_style)
        self.use_ndwi_radio.setStyleSheet(radio_style)
        method_group.setLayout(method_layout)
        layout.addWidget(method_group)

        # Mask file selection
        self.mask_file = FileSelectWidget(
            "Mask file:",
            "Shapefiles (*.shp);;Raster Files (*.dat *.tif);;All Files (*.*)"
        )
        layout.addWidget(self.mask_file)

        # Image file selection (used for shapefile rasterization or NDWI generation)
        self.img_file = FileSelectWidget(
            "Reference image:",
            "Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
        )
        layout.addWidget(self.img_file)

        # NDWI parameters
        self.ndwi_group = QGroupBox("NDWI parameters")
        ndwi_layout = QVBoxLayout()
        # NDWI threshold
        threshold_layout = QHBoxLayout()
        threshold_layout.addWidget(QLabel("NDWI threshold:"))
        self.ndwi_threshold = QDoubleSpinBox()
        self.ndwi_threshold.setRange(0.0, 1.0)
        self.ndwi_threshold.setSingleStep(0.05)
        self.ndwi_threshold.setValue(0.4)
        self.ndwi_threshold.setDecimals(2)
        threshold_layout.addWidget(self.ndwi_threshold)
        threshold_layout.addStretch()
        ndwi_layout.addLayout(threshold_layout)
        self.ndwi_group.setLayout(ndwi_layout)
        layout.addWidget(self.ndwi_group)

        # Output file path (uses save mode)
        self.output_file = FileSelectWidget(
            "Output mask:",
            "Mask Files (*.dat *.tif);;All Files (*.*)",
            mode="save"
        )
        self.output_file.line_edit.setPlaceholderText("water_mask.dat")
        layout.addWidget(self.output_file)

        # Hint, styled as an info alert
        hint = QLabel(
            "💡 Tip: If the mask file is a Shapefile (.shp), a reference image is required "
            "for rasterization; if NDWI auto-generation is used, only the reference image is required."
        )
        hint.setWordWrap(True)  # allow automatic line wrapping
        hint.setStyleSheet("""
            QLabel {
                color: #0055D4;
                font-size: 13px;
                font-weight: bold;
                background-color: #E8F4FF;
                border: 2px solid #0055D4;
                border-radius: 8px;
                padding: 12px 16px;
                margin: 8px 0px;
            }
        """)
        layout.addWidget(hint)

        # Enable this step
        self.enable_checkbox = QCheckBox("Enable this step")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Standalone run button
        self.run_btn = QPushButton("Run this step on its own")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)

        # Connect signals
        self.use_existing_radio.toggled.connect(self.update_ui_state)
        self.use_ndwi_radio.toggled.connect(self.update_ui_state)
        layout.addStretch()
        self.setLayout(layout)
        # Initial UI state
        self.update_ui_state()
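
    def update_ui_state(self):
        """Update the UI for the selected mask generation method (by showing/hiding widgets)."""
        use_ndwi = self.use_ndwi_radio.isChecked()
        # Show or hide widgets dynamically
        if use_ndwi:
            # NDWI mode: hide the mask file; show the NDWI parameters and the output mask
            self.mask_file.setVisible(False)
            self.ndwi_group.setVisible(True)
            self.output_file.setVisible(True)  # show the output mask path
            # On switching to NDWI mode, auto-fill the output path if a working directory is set
            if hasattr(self, 'work_dir') and self.work_dir:
                self._auto_fill_output_path()
        else:
            # Existing-mask mode: show the mask file; hide the NDWI parameters and the output mask
            self.mask_file.setVisible(True)
            self.ndwi_group.setVisible(False)
            self.output_file.setVisible(False)  # hide the output mask path
        # The reference image is shown in both modes
        self.img_file.setVisible(True)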

    def update_work_directory(self, work_dir):
        """
        Keep a reference to the working directory for later path auto-fill.

        Args:
            work_dir: path of the working directory
        """
        if not work_dir:
            return
        # Keep a reference to the working directory
        self.work_dir = work_dir
        # If NDWI mode is currently selected, fill the output path right away
        if self.use_ndwi_radio.isChecked():
            self._auto_fill_output_path()

    def _auto_fill_output_path(self):
        """
        Auto-fill the output mask path (NDWI mode only).

        The path always uses forward slashes to avoid mixed separators.
        """
        if not hasattr(self, 'work_dir') or not self.work_dir:
            return
        # Build the full path of the output mask
        output_dir = os.path.join(self.work_dir, "1_water_mask")
        os.makedirs(output_dir, exist_ok=True)  # make sure the directory exists
        # Use forward slashes throughout to avoid mixing \ and /
        default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
        self.output_file.set_path(default_output_path)

    def get_config(self):
        """Return the configuration."""
        use_ndwi = self.use_ndwi_radio.isChecked()
        config = {
            'mask_path': None if use_ndwi else self.mask_file.get_path(),
            'use_ndwi': use_ndwi,
            'ndwi_threshold': self.ndwi_threshold.value()
        }
        # Reference image path (either mode may need it)
        img_path = self.img_file.get_path()
        if img_path:
            config['img_path'] = img_path
        # Output path (only meaningful in NDWI mode)
        if use_ndwi:
            output_path = self.output_file.get_path()
            if output_path:
                config['output_path'] = output_path
        else:
            # With an existing mask, pass no output_path so the backend does not wrongly try to save a file
            config['output_path'] = None
        return config

    def set_config(self, config):
        """Apply a configuration."""
        if 'mask_path' in config:
            self.mask_file.set_path(config['mask_path'])
        if 'img_path' in config:
            self.img_file.set_path(config['img_path'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])
        if 'use_ndwi' in config:
            if config['use_ndwi']:
                self.use_ndwi_radio.setChecked(True)
            else:
                self.use_existing_radio.setChecked(True)
        if 'ndwi_threshold' in config:
            self.ndwi_threshold.setValue(config['ndwi_threshold'])
        self.update_ui_state()

    def run_step(self):
        """Run step 1 on its own."""
        # Validate the inputs
        if self.use_ndwi_radio.isChecked():
            # NDWI mode: an image file is required
            img_path = self.img_file.get_path()
            if not img_path:
                QMessageBox.warning(self, "Input error", "Please select a reference image file!")
                return
        else:
            # Existing-mask mode: a mask file is required
            mask_path = self.mask_file.get_path()
            if not mask_path:
                QMessageBox.warning(self, "Input error", "Please select a mask file!")
                return
            # A .shp file additionally needs an image file
            if mask_path.lower().endswith('.shp'):
                img_path = self.img_file.get_path()
                if not img_path:
                    QMessageBox.warning(
                        self, "Input error",
                        "A reference image is required for rasterization when using a .shp file."
                    )
                    return
        # Walk up the parent chain to the window that can run the step
        parent = self.parent()
        while parent and not hasattr(parent, 'run_single_step'):
            parent = parent.parent()
        if parent and hasattr(parent, 'run_single_step'):
            config = {'step1': self.get_config()}
            parent.run_single_step("step1", config)
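
The panel only collects an NDWI threshold (default 0.4); the masking itself happens in the backend, which is not shown here. As a rough sketch of what that threshold is for, the example below assumes the common McFeeters definition NDWI = (Green - NIR) / (Green + NIR) and a strict "greater than" comparison. Both the formula choice and the function name `ndwi_water_mask` are assumptions for illustration, not the repository's implementation.

```python
def ndwi_water_mask(green, nir, threshold=0.4):
    """Return a 0/1 mask marking pixels whose NDWI exceeds `threshold`.

    green, nir: equally shaped 2-D lists of reflectance values.
    Assumes NDWI = (green - nir) / (green + nir); pixels whose two bands
    sum to zero are treated as non-water.
    """
    mask = []
    for green_row, nir_row in zip(green, nir):
        row = []
        for g, n in zip(green_row, nir_row):
            total = g + n
            ndwi = (g - n) / total if total else 0.0
            row.append(1 if ndwi > threshold else 0)
        mask.append(row)
    return mask


green_band = [[0.3, 0.1], [0.0, 0.2]]
nir_band = [[0.1, 0.3], [0.0, 0.05]]
print(ndwi_water_mask(green_band, nir_band))
```

Water absorbs strongly in the near infrared, so water pixels tend toward positive NDWI; raising the threshold makes the mask more conservative.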


@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step2 panel - glint region detection
"""
import os

from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QGroupBox, QFormLayout,
    QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
    QMessageBox,
)
from PyQt5.QtCore import Qt

# Imported from the shared component library
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet


class Step2Panel(QWidget):
    """2. Glint region detection"""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.work_dir = None
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        # Title
        # Image file
        self.img_file = FileSelectWidget(
            "Image file:",
            "Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
        )
        layout.addWidget(self.img_file)

        # Water mask file (optional, used for standalone runs)
        self.water_mask_file = FileSelectWidget(
            "Water mask:",
            "Mask Files (*.dat *.tif);;All Files (*.*)"
        )
        self.water_mask_file.label.setText("Water mask:")
        layout.addWidget(self.water_mask_file)

        # Parameters
        params_group = QGroupBox("Detection parameters")
        params_layout = QFormLayout()
        # Glint wavelength
        self.glint_wave = QDoubleSpinBox()
        self.glint_wave.setRange(300, 1000)
        self.glint_wave.setValue(750.0)
        self.glint_wave.setSuffix(" nm")
        params_layout.addRow("Glint detection wavelength:", self.glint_wave)
        # Detection method (display text, internal ID)
        self.method = QComboBox()
        self.method.addItem("Otsu thresholding", "otsu")
        self.method.addItem("Z-score method", "zscore")
        self.method.addItem("Percentile method", "percentile")
        self.method.addItem("IQR (interquartile range) method", "iqr")
        self.method.addItem("Adaptive thresholding", "adaptive")
        self.method.addItem("Multi-band combined method", "multi_band")
        params_layout.addRow("Detection method:", self.method)
        # Maximum connected-component area
        self.max_area = QSpinBox()
        self.max_area.setRange(0, 100000)
        self.max_area.setValue(50)
        self.max_area.setSpecialValueText("No filtering")
        params_layout.addRow("Max connected-component area:", self.max_area)
        # Shoreline buffer
        self.buffer_size = QSpinBox()
        self.buffer_size.setRange(0, 200)
        self.buffer_size.setValue(10)
        self.buffer_size.setSpecialValueText("Not set")
        params_layout.addRow("Shoreline buffer size:", self.buffer_size)
        params_group.setLayout(params_layout)
        layout.addWidget(params_group)

        # Output file path
        self.output_file = FileSelectWidget(
            "Output glint mask:",
            "Mask Files (*.dat *.tif);;All Files (*.*)"
        )
        self.output_file.line_edit.setPlaceholderText("")
        layout.addWidget(self.output_file)

        # Enable this step
        self.enable_checkbox = QCheckBox("Enable this step")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)

        # Standalone run button
        self.run_btn = QPushButton("Run this step on its own")
        self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_btn.clicked.connect(self.run_step)
        layout.addWidget(self.run_btn)
        layout.addStretch()
        self.setLayout(layout)
        # Signal connection: update the band range dynamically when the image file path changes

    def get_config(self):
        """Return the configuration."""
        config = {
            'img_path': self.img_file.get_path(),
            'glint_wave': self.glint_wave.value(),
            'method': self.method.currentData(),  # currentData() returns the English ID
        }
        # A value of 0 means "disabled", so the key is omitted
        if self.max_area.value() > 0:
            config['max_area'] = self.max_area.value()
        if self.buffer_size.value() > 0:
            config['buffer_size'] = self.buffer_size.value()
        # Water mask path (used for standalone runs)
        water_mask_path = self.water_mask_file.get_path()
        if water_mask_path:
            config['water_mask_path'] = water_mask_path
        # Output path
        output_path = self.output_file.get_path()
        if output_path:
            config['output_path'] = output_path
        return config

    def set_config(self, config):
        """Apply a configuration."""
        if 'img_path' in config:
            self.img_file.set_path(config['img_path'])
        if 'glint_wave' in config:
            self.glint_wave.setValue(config['glint_wave'])
        if 'method' in config:
            idx = self.method.findData(config['method'])  # look up by ID with findData()
            if idx >= 0:
                self.method.setCurrentIndex(idx)
        if 'max_area' in config:
            self.max_area.setValue(config['max_area'])
        if 'buffer_size' in config:
            self.buffer_size.setValue(config['buffer_size'])
        if 'water_mask_path' in config:
            self.water_mask_file.set_path(config['water_mask_path'])
        if 'output_path' in config:
            self.output_file.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""
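Auto-fill paths from the global config/Pipeline or from Step1Panel so that data flows from upstream to downstream steps.
Args:
work_dir: working directory path
pipeline: Pipeline instance, used to read the water-mask path produced by step 1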
"""
# 保存工作目录引用
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass # 保持现有工作目录
else:
self.work_dir = None
# 1. 尝试从 Pipeline 获取
mask_path = None
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
mask_path = pipeline.water_mask_path
# 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取(关键修复)
main_window = self.window()
if not mask_path and hasattr(main_window, 'step1_panel'):
if main_window.step1_panel.use_ndwi_radio.isChecked():
# NDWI模式读取输出框的路径
mask_path = main_window.step1_panel.output_file.get_path()
else:
# 导入现有模式,读取输入框的路径
mask_path = main_window.step1_panel.mask_file.get_path()
# 填充获取到的路径
if mask_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(mask_path):
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
self.water_mask_file.set_path(mask_path)
# 3. 自动填充输出路径(基于工作目录)
if self.work_dir:
# Default glint-mask output path: <work_dir>/2_glint_mask/glint_mask_out.dat
output_dir = os.path.join(self.work_dir, "2_glint_mask")
os.makedirs(output_dir, exist_ok=True)
default_output_path = os.path.join(output_dir, "glint_mask_out.dat").replace('\\', '/')
self.output_file.set_path(default_output_path)
else:
# 没有工作目录时,清空输出路径
self.output_file.set_path("")
def run_step(self):
"""独立运行步骤2"""
# 验证输入
img_path = self.img_file.get_path()
if not img_path:
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
return
# 获取主窗口并运行步骤
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step2': self.get_config()}
main_window.run_single_step('step2', config)
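
The path handling in `update_from_config` above (resolve a relative mask path against the working directory, then build a default output path under a numbered sub-folder) can be sketched in isolation. This is a minimal sketch, not the panel's actual code: the helper names `resolve_against_work_dir` and `default_output_path` are hypothetical, and unlike the panel the sketch does not call `os.makedirs`.

```python
import os


def resolve_against_work_dir(path, work_dir):
    """Return `path` unchanged if it is absolute, else join it onto `work_dir`."""
    if os.path.isabs(path):
        return path
    # Normalise to forward slashes, as the panel code does
    return os.path.join(work_dir or "", path).replace("\\", "/")


def default_output_path(work_dir, subdir, filename):
    """Build <work_dir>/<subdir>/<filename> without touching the file system."""
    return os.path.join(work_dir, subdir, filename).replace("\\", "/")


print(resolve_against_work_dir("1_water_mask/mask.dat", "workspace"))
print(default_output_path("workspace", "2_glint_mask", "glint_mask_out.dat"))
```

Keeping these as pure functions would let the same logic be shared by the Step2, Step3 and Step5_5 panels, which currently repeat it inline.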


@ -0,0 +1,451 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step3 面板 - 耀斑去除
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
QLabel, QLineEdit, QMessageBox,
)
from PyQt5.QtCore import Qt
# 从公共组件库导入
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step3Panel(QWidget):
"""Step 3: glint removal."""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 标题
# 影像文件
self.img_file = FileSelectWidget(
"影像文件:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.img_file)
# Water mask/boundary: generated automatically by step 1 in the full pipeline; must be chosen manually when this step runs standalone
self.water_mask_file = FileSelectWidget(
"水域掩膜/边界:",
"Mask/Boundary (*.dat *.tif *.shp);;All Files (*.*)"
)
layout.addWidget(self.water_mask_file)
step3_mask_hint = QLabel(
"提示:独立运行本步骤时必须选择水域掩膜或边界(与影像同区域的 .dat/.tif 掩膜,或 .shp 矢量)。"
)
step3_mask_hint.setWordWrap(True)
step3_mask_hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(step3_mask_hint)
# 方法选择
method_group = QGroupBox("去耀斑方法")
method_layout = QVBoxLayout()
self.method = QComboBox()
for text, data in [('Goodman方法', 'goodman'), ('Kutser方法', 'kutser'),
('Hedley方法', 'hedley'), ('SUGAR算法', 'sugar')]:
self.method.addItem(text, data)
self.method.currentIndexChanged.connect(self._on_method_changed)
method_layout.addWidget(self.method)
method_group.setLayout(method_layout)
layout.addWidget(method_group)
# Goodman参数组
self.goodman_group = QGroupBox("Goodman方法参数")
goodman_layout = QFormLayout()
self.nir_lower = QSpinBox()
self.nir_lower.setRange(0, 200)
self.nir_lower.setValue(65)
goodman_layout.addRow("NIR下波段索引:", self.nir_lower)
self.nir_upper = QSpinBox()
self.nir_upper.setRange(0, 200)
self.nir_upper.setValue(91)
goodman_layout.addRow("NIR上波段索引:", self.nir_upper)
self.goodman_a = QDoubleSpinBox()
self.goodman_a.setDecimals(6)
self.goodman_a.setRange(0, 1)
self.goodman_a.setValue(0.000019)
goodman_layout.addRow("参数A:", self.goodman_a)
self.goodman_b = QDoubleSpinBox()
self.goodman_b.setDecimals(2)
self.goodman_b.setRange(0, 1)
self.goodman_b.setValue(0.1)
goodman_layout.addRow("参数B:", self.goodman_b)
self.goodman_group.setLayout(goodman_layout)
layout.addWidget(self.goodman_group)
# Kutser参数组
self.kutser_group = QGroupBox("Kutser方法参数")
kutser_layout = QFormLayout()
self.oxy_band = QSpinBox()
self.oxy_band.setRange(0, 200)
self.oxy_band.setValue(38)
kutser_layout.addRow("氧吸收波段索引:", self.oxy_band)
self.lower_oxy = QSpinBox()
self.lower_oxy.setRange(0, 200)
self.lower_oxy.setValue(36)
kutser_layout.addRow("下氧吸收波段索引:", self.lower_oxy)
self.upper_oxy = QSpinBox()
self.upper_oxy.setRange(0, 200)
self.upper_oxy.setValue(49)
kutser_layout.addRow("上氧吸收波段索引:", self.upper_oxy)
self.nir_band = QSpinBox()
self.nir_band.setRange(0, 200)
self.nir_band.setValue(47)
kutser_layout.addRow("NIR波段索引:", self.nir_band)
self.kutser_group.setLayout(kutser_layout)
self.kutser_group.setVisible(False)
layout.addWidget(self.kutser_group)
# Hedley参数组
self.hedley_group = QGroupBox("Hedley方法参数")
hedley_layout = QFormLayout()
self.hedley_nir_band = QSpinBox()
self.hedley_nir_band.setRange(0, 200)
self.hedley_nir_band.setValue(47)
hedley_layout.addRow("NIR波段索引:", self.hedley_nir_band)
self.hedley_group.setLayout(hedley_layout)
self.hedley_group.setVisible(False)
layout.addWidget(self.hedley_group)
# SUGAR参数组
self.sugar_group = QGroupBox("SUGAR方法参数")
sugar_layout = QFormLayout()
self.sugar_iter = QSpinBox()
self.sugar_iter.setRange(0, 20)  # 0 is the minimum, so it displays the special "auto" text and maps to None in get_config()
self.sugar_iter.setValue(3)
self.sugar_iter.setSpecialValueText("自动")
sugar_layout.addRow("迭代次数:", self.sugar_iter)
self.sugar_sigma = QDoubleSpinBox()
self.sugar_sigma.setDecimals(2)
self.sugar_sigma.setRange(0.1, 10)
self.sugar_sigma.setValue(1.0)
sugar_layout.addRow("LoG平滑σ:", self.sugar_sigma)
self.sugar_estimate_background = QCheckBox()
self.sugar_estimate_background.setChecked(True)
sugar_layout.addRow("估计背景光谱:", self.sugar_estimate_background)
self.sugar_glint_mask_method = QComboBox()
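# Attach the method ID as item data so that currentData()/findData() work in get_config()/set_config()
self.sugar_glint_mask_method.addItem('cdf', 'cdf')
self.sugar_glint_mask_method.addItem('otsu', 'otsu')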
self.sugar_glint_mask_method.setCurrentText('cdf')
sugar_layout.addRow("耀斑掩膜方法:", self.sugar_glint_mask_method)
self.sugar_termination_thresh = QDoubleSpinBox()
self.sugar_termination_thresh.setDecimals(2)
self.sugar_termination_thresh.setRange(1, 100)
self.sugar_termination_thresh.setValue(20.0)
sugar_layout.addRow("终止阈值:", self.sugar_termination_thresh)
self.sugar_bounds = QLineEdit()
self.sugar_bounds.setText("[(1, 2)]")
sugar_layout.addRow("优化边界:", self.sugar_bounds)
self.sugar_group.setLayout(sugar_layout)
self.sugar_group.setVisible(False)
layout.addWidget(self.sugar_group)
# 插值选项
interp_group = QGroupBox("0值像素插值")
interp_layout = QFormLayout()
self.interpolate_zeros = QCheckBox("启用插值")
interp_layout.addRow("", self.interpolate_zeros)
self.interp_method = QComboBox()
for text, data in [('最近邻插值', 'nearest'), ('双线性插值', 'bilinear'),
('样条插值', 'spline'), ('克里金插值', 'kriging')]:
self.interp_method.addItem(text, data)
self.interp_method.setCurrentIndex(1) # 默认双线性插值
interp_layout.addRow("插值方法:", self.interp_method)
interp_group.setLayout(interp_layout)
layout.addWidget(interp_group)
# # 实测经纬度参考点
# self.ref_csv_file = FileSelectWidget(
# "实测经纬度CSV:",
# "CSV Files (*.csv);;All Files (*.*)"
# )
# self.ref_csv_file.line_edit.setPlaceholderText("可选:包含 Lon/Lat 列的 CSV 文件")
# layout.addWidget(self.ref_csv_file)
# 交互式预览按钮
# self.preview_btn = QPushButton("👁️ 打开交互式影像预览")
# self.preview_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('info'))
# self.preview_btn.clicked.connect(self.open_interactive_viewer)
# layout.addWidget(self.preview_btn)
# 输出文件路径
self.output_file = FileSelectWidget(
"输出影像:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("deglint_image.bsq")
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
# 信号连接:影像文件路径变化时动态更新波段范围
self.img_file.line_edit.textChanged.connect(self._update_band_ranges)
def open_interactive_viewer(self):
"""打开交互式影像预览"""
from src.gui.water_quality_gui import InteractiveViewerDialog
img_path = self.img_file.get_path()
if not img_path or not os.path.isfile(img_path):
QMessageBox.warning(self, "警告", "请先选择影像文件!")
return
water_mask = self.water_mask_file.get_path()
dialog = InteractiveViewerDialog(img_path, self)
if water_mask and os.path.isfile(water_mask):
dialog.load_water_mask(water_mask)
dialog.exec_()
def _update_band_ranges(self, file_path):
"""根据选择的影像动态限制波段索引的输入范围"""
from osgeo import gdal
if not file_path or not os.path.isfile(file_path):
return
try:
dataset = gdal.Open(file_path)
if dataset is None:
return
raster_count = dataset.RasterCount
max_band = max(0, raster_count - 1)
self.nir_lower.setMaximum(max_band)
self.nir_upper.setMaximum(max_band)
self.oxy_band.setMaximum(max_band)
self.lower_oxy.setMaximum(max_band)
self.upper_oxy.setMaximum(max_band)
self.nir_band.setMaximum(max_band)
self.hedley_nir_band.setMaximum(max_band)
dataset = None
except Exception:
pass
def update_from_config(self, work_dir=None, pipeline=None):
"""
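Auto-fill the water-mask path from Step1Panel so that data flows from upstream to downstream steps.
Args:
work_dir: working directory path
pipeline: Pipeline instance (unused; kept for interface compatibility)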
"""
# 保存工作目录引用
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass # 保持现有工作目录
else:
self.work_dir = None
# 从 Step1 界面读取水域掩膜路径
main_window = self.window()
if hasattr(main_window, 'step1_panel'):
if main_window.step1_panel.use_ndwi_radio.isChecked():
# NDWI模式读取输出框的路径
mask_path = main_window.step1_panel.output_file.get_path()
else:
# 导入现有模式,读取输入框的路径
mask_path = main_window.step1_panel.mask_file.get_path()
if mask_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(mask_path):
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
self.water_mask_file.set_path(mask_path)
# 自动填充输出路径(基于工作目录)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "3_deglint")
os.makedirs(output_dir, exist_ok=True)
default_output_path = os.path.join(output_dir, "deglint_image.bsq").replace('\\', '/')
self.output_file.set_path(default_output_path)
else:
self.output_file.set_path("")
def _on_method_changed(self, index):
"""方法改变时更新参数显示"""
method_id = self.method.currentData()
self.goodman_group.setVisible(method_id == 'goodman')
self.kutser_group.setVisible(method_id == 'kutser')
self.hedley_group.setVisible(method_id == 'hedley')
self.sugar_group.setVisible(method_id == 'sugar')
def get_config(self):
"""获取配置"""
config = {
'img_path': self.img_file.get_path(),
'method': self.method.currentData(), # 使用 currentData() 获取英文ID
'enabled': self.enable_checkbox.isChecked(),
'interpolate_zeros': self.interpolate_zeros.isChecked(),
'interpolation_method': self.interp_method.currentData(), # 使用 currentData()
}
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask'] = water_mask_path
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
method = self.method.currentData() # 使用 currentData()
if method == 'goodman':
config['nir_lower'] = self.nir_lower.value()
config['nir_upper'] = self.nir_upper.value()
config['goodman_A'] = self.goodman_a.value()
config['goodman_B'] = self.goodman_b.value()
elif method == 'kutser':
config['oxy_band'] = self.oxy_band.value()
config['lower_oxy'] = self.lower_oxy.value()
config['upper_oxy'] = self.upper_oxy.value()
config['nir_band'] = self.nir_band.value()
elif method == 'hedley':
config['hedley_nir_band'] = self.hedley_nir_band.value()
elif method == 'sugar':
config['sugar_iter'] = self.sugar_iter.value() if self.sugar_iter.value() > 0 else None
config['sugar_sigma'] = self.sugar_sigma.value()
config['sugar_estimate_background'] = self.sugar_estimate_background.isChecked()
config['sugar_glint_mask_method'] = self.sugar_glint_mask_method.currentData()
config['sugar_termination_thresh'] = self.sugar_termination_thresh.value()
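# Parse the bounds string; fall back to the default if it is not a valid Python literal
try:
import ast
config['sugar_bounds'] = ast.literal_eval(self.sugar_bounds.text())
except Exception:  # literal_eval may raise ValueError, SyntaxError, TypeError, ...
config['sugar_bounds'] = [(1, 2)] # default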
return config
def set_config(self, config):
"""设置配置"""
if 'img_path' in config:
self.img_file.set_path(config['img_path'])
if 'water_mask' in config:
self.water_mask_file.set_path(config['water_mask'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
if 'reference_csv' in config and hasattr(self, 'ref_csv_file'):  # ref_csv_file is currently commented out in init_ui
self.ref_csv_file.set_path(config['reference_csv'])
if 'method' in config:
idx = self.method.findData(config['method']) # 使用 findData()
if idx >= 0:
self.method.setCurrentIndex(idx)
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
if 'interpolate_zeros' in config:
self.interpolate_zeros.setChecked(config['interpolate_zeros'])
if 'interpolation_method' in config:
idx = self.interp_method.findData(config['interpolation_method']) # 使用 findData()
if idx >= 0:
self.interp_method.setCurrentIndex(idx)
# Goodman参数
if 'nir_lower' in config:
self.nir_lower.setValue(config['nir_lower'])
if 'nir_upper' in config:
self.nir_upper.setValue(config['nir_upper'])
if 'goodman_A' in config:
self.goodman_a.setValue(config['goodman_A'])
if 'goodman_B' in config:
self.goodman_b.setValue(config['goodman_B'])
# Kutser参数
if 'oxy_band' in config:
self.oxy_band.setValue(config['oxy_band'])
if 'lower_oxy' in config:
self.lower_oxy.setValue(config['lower_oxy'])
if 'upper_oxy' in config:
self.upper_oxy.setValue(config['upper_oxy'])
if 'nir_band' in config:
self.nir_band.setValue(config['nir_band'])
# Hedley参数
if 'hedley_nir_band' in config:
self.hedley_nir_band.setValue(config['hedley_nir_band'])
# SUGAR参数
if 'sugar_iter' in config:
self.sugar_iter.setValue(config['sugar_iter'] if config['sugar_iter'] is not None else 0)
if 'sugar_sigma' in config:
self.sugar_sigma.setValue(config['sugar_sigma'])
if 'sugar_estimate_background' in config:
self.sugar_estimate_background.setChecked(config['sugar_estimate_background'])
if 'sugar_glint_mask_method' in config:
idx = self.sugar_glint_mask_method.findData(config['sugar_glint_mask_method']) # 使用 findData()
if idx >= 0:
self.sugar_glint_mask_method.setCurrentIndex(idx)
if 'sugar_termination_thresh' in config:
self.sugar_termination_thresh.setValue(config['sugar_termination_thresh'])
if 'sugar_bounds' in config:
self.sugar_bounds.setText(str(config['sugar_bounds']))
def run_step(self):
"""独立运行步骤3"""
# 验证输入
img_path = self.img_file.get_path()
if not img_path:
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
return
if self.enable_checkbox.isChecked():
water_mask_path = self.water_mask_file.get_path()
if not water_mask_path:
QMessageBox.warning(
self,
"输入错误",
"独立运行耀斑去除时,必须选择水域掩膜或边界文件。\n\n"
"请提供与当前影像空间一致的水域栅格掩膜(.dat/.tif)或水域矢量边界(.shp)。\n"
"若刚跑过完整流程,可使用步骤1生成的水域掩膜文件。",
)
return
# 获取主窗口并运行步骤
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step3': self.get_config()}
main_window.run_single_step('step3', config)
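
`get_config` above turns the free-text "优化边界" field into Python data with `ast.literal_eval` and falls back to `[(1, 2)]` on failure. A stricter variant is sketched below, assuming the bounds are meant to be a non-empty list of numeric `(low, high)` pairs. That shape check is an assumption, not something the source enforces, and `parse_bounds` is a hypothetical helper name.

```python
import ast

DEFAULT_BOUNDS = [(1, 2)]


def parse_bounds(text, default=DEFAULT_BOUNDS):
    """Parse a string such as "[(1, 2)]" into a list of (low, high) tuples."""
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return list(default)
    if not isinstance(value, (list, tuple)) or not value:
        return list(default)
    bounds = []
    for item in value:
        # Each entry must be a (low, high) pair of real numbers with low <= high
        if not (isinstance(item, (list, tuple)) and len(item) == 2):
            return list(default)
        low, high = item
        if not all(isinstance(v, (int, float)) and not isinstance(v, bool)
                   for v in (low, high)):
            return list(default)
        if low > high:
            return list(default)
        bounds.append((low, high))
    return bounds


print(parse_bounds("[(1, 2)]"))
print(parse_bounds("not a literal"))
```

Because `literal_eval` only accepts literals, input such as a function call is rejected rather than executed, and the extra shape check keeps malformed bounds from reaching the optimizer.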


@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step4 面板 - 数据预处理
"""
import os
import pandas as pd
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QHBoxLayout, QLabel,
QSpinBox, QPushButton, QCheckBox, QTableView,
QAbstractItemView, QHeaderView, QMessageBox,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step4Panel(QWidget):
"""Step 4: data preprocessing."""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 标题
# CSV文件
self.csv_file = FileSelectWidget(
"水质参数文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.csv_file)
hint = QLabel("提示: 处理CSV文件筛选剔除异常值")
hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(hint)
preview_group = QGroupBox("CSV数据预览")
preview_layout = QVBoxLayout()
controls_layout = QHBoxLayout()
controls_layout.addWidget(QLabel("预览行数:"))
self.preview_rows_spin = QSpinBox()
self.preview_rows_spin.setRange(1, 200)
self.preview_rows_spin.setValue(10)
controls_layout.addWidget(self.preview_rows_spin)
self.preview_btn = QPushButton("刷新预览")
self.preview_btn.clicked.connect(self.load_csv_preview)
controls_layout.addWidget(self.preview_btn)
controls_layout.addStretch()
self.preview_table = QTableView()
self.preview_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.preview_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.preview_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.preview_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
self.preview_table.verticalHeader().setVisible(False)
self.preview_table.setMinimumHeight(200)
self.preview_status_label = QLabel("请选择CSV文件并点击刷新预览")
self.preview_status_label.setStyleSheet("color: #666; font-size: 11px;")
preview_layout.addLayout(controls_layout)
preview_layout.addWidget(self.preview_table)
preview_layout.addWidget(self.preview_status_label)
preview_group.setLayout(preview_layout)
layout.addWidget(preview_group)
# 输出文件路径
self.output_file = FileSelectWidget(
"输出处理后CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("processed_data.csv")
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
self.reset_preview()
def get_config(self):
"""获取配置"""
config = {
'csv_path': self.csv_file.get_path(),
}
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'csv_path' in config:
self.csv_file.set_path(config['csv_path'])
self.load_csv_preview()
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充输出路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
if self.work_dir:
output_dir = os.path.join(self.work_dir, "4_processed_data")
os.makedirs(output_dir, exist_ok=True)
default_output_path = os.path.join(output_dir, "processed_data.csv").replace('\\', '/')
self.output_file.set_path(default_output_path)
else:
self.output_file.set_path("")
def run_step(self):
"""独立运行步骤4"""
# 验证输入
csv_path = self.csv_file.get_path()
if not csv_path:
QMessageBox.warning(self, "输入错误", "请选择水质参数文件!")
return
# 获取主窗口并运行步骤
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step4': self.get_config()}
main_window.run_single_step('step4', config)
def reset_preview(self, message="请选择CSV文件并点击刷新预览"):
"""重置预览表格"""
from src.gui.water_quality_gui import PandasTableModel
empty_model = PandasTableModel(pd.DataFrame())
self.preview_table.setModel(empty_model)
self.preview_status_label.setText(message)
def load_csv_preview(self):
"""加载CSV预览数据"""
from src.gui.water_quality_gui import PandasTableModel
csv_path = self.csv_file.get_path()
if not csv_path:
self.reset_preview("请先选择CSV文件")
return
if not os.path.exists(csv_path):
self.reset_preview("文件不存在,请检查路径")
return
try:
rows_to_preview = max(1, self.preview_rows_spin.value())
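# dtype=object keeps every column untyped (values are read as strings), avoiding dtype errors on empty or mixed-type cells
df = pd.read_csv(csv_path, nrows=rows_to_preview, dtype=object)
# PandasTableModel.__init__ already calls fillna; this is a second, defensive pass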
df = df.fillna('')
if df.empty:
self.reset_preview("CSV文件为空")
return
model = PandasTableModel(df)
self.preview_table.setModel(model)
self.preview_status_label.setText(
f"预览 {len(df)} 行,{len(df.columns)} 列(总行数可能更多)"
)
except Exception as exc:
self.reset_preview(f"加载失败: {exc}")
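
The preview logic in `load_csv_preview` (read only the first N rows, treat every cell as text, blank out missing values) can be illustrated without a GUI. The sketch below deliberately swaps pandas for the standard `csv` module so that it runs with no third-party dependencies. `preview_rows` is a hypothetical helper, not part of the project.

```python
import csv
import io
from itertools import islice


def preview_rows(handle, max_rows=10):
    """Return (header, rows) with at most `max_rows` data rows, all values as str."""
    reader = csv.reader(handle)
    header = next(reader, [])
    rows = []
    for row in islice(reader, max(1, max_rows)):
        # Pad short rows so every row has one cell per header column
        padded = row + [""] * (len(header) - len(row))
        rows.append(padded)
    return header, rows


sample = io.StringIO("Lon,Lat,Chl_a\n120.1,30.2,5.6\n120.3,30.4,\n120.5,30.6,7.1\n")
header, rows = preview_rows(sample, max_rows=2)
print(header)
print(rows)
```

As with `nrows=` in `pd.read_csv`, the reader stops after the requested number of rows, so previewing a very large file stays cheap.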


@ -0,0 +1,408 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step5_5 面板 - 水质指数计算
"""
import os
import sys
from pathlib import Path
from typing import Dict, List, Union
def get_resource_path(relative_path: str) -> str:
"""获取资源的绝对路径,适配 PyInstaller 打包环境。"""
if hasattr(sys, '_MEIPASS'):
return os.path.join(sys._MEIPASS, relative_path)
return os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), relative_path)
)
import pandas as pd
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QComboBox, QCheckBox,
QPushButton, QMessageBox,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step5_5Panel(QWidget):
"""步骤5.5:水质指数计算"""
def __init__(self, parent=None):
super().__init__(parent)
self.index_checkboxes: Dict[str, QCheckBox] = {}
self.csv_columns = [] # 存储CSV文件列名
self.init_ui()
def init_ui(self):
main_layout = QVBoxLayout()
# 标题
# 数据文件选择
data_group = QGroupBox("数据文件")
data_layout = QVBoxLayout()
# 训练数据CSV文件选择
self.training_data_widget = FileSelectWidget("训练数据CSV文件:", "CSV Files (*.csv)")
data_layout.addWidget(self.training_data_widget)
# 公式CSV文件选择
self.formula_csv_widget = FileSelectWidget("公式CSV文件:", "CSV Files (*.csv)")
data_layout.addWidget(self.formula_csv_widget)
# 刷新公式按钮
refresh_layout = QHBoxLayout()
self.refresh_button = QPushButton("刷新公式列表")
self.refresh_button.clicked.connect(self.refresh_formulas)
refresh_layout.addWidget(self.refresh_button)
refresh_layout.addStretch()
data_layout.addLayout(refresh_layout)
data_group.setLayout(data_layout)
main_layout.addWidget(data_group)
# 公式选择区域
self.formula_group = QGroupBox("选择要计算的公式")
formula_outer_layout = QVBoxLayout()
# 按钮控制区域
button_layout = QHBoxLayout()
self.select_all_btn = QPushButton("全选")
self.select_all_btn.clicked.connect(self.select_all_formulas)
self.deselect_all_btn = QPushButton("清空")
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
button_layout.addWidget(self.select_all_btn)
button_layout.addWidget(self.deselect_all_btn)
button_layout.addStretch()
formula_outer_layout.addLayout(button_layout)
# 公式勾选框网格布局
self.formula_layout = QGridLayout()
formula_outer_layout.addLayout(self.formula_layout)
self.formula_group.setLayout(formula_outer_layout)
main_layout.addWidget(self.formula_group)
# 输出文件设置
output_group = QGroupBox("输出设置")
output_layout = QVBoxLayout()
self.output_file_widget = FileSelectWidget(
"输出文件:", "CSV Files (*.csv)", mode="save"
)
output_layout.addWidget(self.output_file_widget)
output_group.setLayout(output_layout)
main_layout.addWidget(output_group)
# 启用选项
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
main_layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
main_layout.addWidget(self.run_button)
# 公式编辑区域
formula_edit_group = QGroupBox("添加自定义公式")
formula_edit_layout = QFormLayout()
self.formula_name_edit = QLineEdit()
# 公式类别下拉选择框
self.formula_category_combo = QComboBox()
self.formula_category_combo.addItems([
"chlorophyll_a",
"Phycocyanin (BGA_PC)",
"Total Nitrogen (TN)",
"Total Phosphorus (TP)",
"Orthophosphate",
"COD",
"BOD",
"TOC",
"Dissolved Oxygen (DO)",
"E. coli",
"Total Coliforms",
"Turbidity",
"Total Suspended Solids (TSS)",
"Color",
"pH",
"Temperature",
"Conductivity",
"Total Dissolved Solids (TDS)"
])
self.formula_category_combo.setEditable(True) # 允许用户输入自定义类别
self.formula_expression_edit = QLineEdit()
self.formula_reference_edit = QLineEdit()
formula_edit_layout.addRow("公式名称:", self.formula_name_edit)
formula_edit_layout.addRow("公式类别:", self.formula_category_combo)
formula_edit_layout.addRow("公式表达式:", self.formula_expression_edit)
formula_edit_layout.addRow("参考文献:", self.formula_reference_edit)
add_button = QPushButton("添加公式")
add_button.clicked.connect(self.add_custom_formula)
formula_edit_layout.addRow(add_button)
formula_edit_group.setLayout(formula_edit_layout)
main_layout.addWidget(formula_edit_group)
main_layout.addStretch()
self.setLayout(main_layout)
# 自动加载内置公式文件
formula_csv_path = get_resource_path("data/sub/waterindex.csv")
if os.path.isfile(formula_csv_path):
self.formula_csv_widget.set_path(str(formula_csv_path))
self.refresh_formulas()
def refresh_formulas(self):
"""刷新公式列表"""
formula_csv_path = self.formula_csv_widget.get_path()
if not formula_csv_path or not os.path.exists(formula_csv_path):
QMessageBox.warning(self, "警告", "请先选择有效的公式CSV文件")
return
try:
# 清除现有的勾选框
for checkbox in self.index_checkboxes.values():
self.formula_layout.removeWidget(checkbox)
checkbox.deleteLater()
self.index_checkboxes.clear()
# 读取公式CSV文件
df = pd.read_csv(formula_csv_path)
if df.empty or 'Formula_Name' not in df.columns:
QMessageBox.warning(self, "警告", "公式CSV文件格式不正确")
return
# Collect all formula names, skipping the first data row
formula_names = df['Formula_Name'].tolist()[1:]
# 创建3列布局的勾选框
row, col = 0, 0
for formula_name in formula_names:
if pd.isna(formula_name) or not formula_name.strip():
continue
checkbox = QCheckBox(formula_name.strip())
checkbox.setChecked(True)
self.index_checkboxes[formula_name.strip()] = checkbox
self.formula_layout.addWidget(checkbox, row, col)
col += 1
if col >= 3: # 每行3列
col = 0
row += 1
except Exception as e:
QMessageBox.critical(self, "错误", f"读取公式文件失败: {str(e)}")
def add_custom_formula(self):
"""添加自定义公式到公式CSV文件"""
formula_csv_path = self.formula_csv_widget.get_path()
if not formula_csv_path:
QMessageBox.warning(self, "警告", "请先选择公式CSV文件")
return
formula_name = self.formula_name_edit.text().strip()
formula_category = self.formula_category_combo.currentText().strip()
formula_expression = self.formula_expression_edit.text().strip()
formula_reference = self.formula_reference_edit.text().strip()
if not all([formula_name, formula_category, formula_expression]):
QMessageBox.warning(self, "警告", "请填写公式名称、类别和表达式")
return
try:
# 读取现有公式文件或创建新文件
if os.path.exists(formula_csv_path):
df = pd.read_csv(formula_csv_path)
else:
df = pd.DataFrame(columns=['Formula_Name', 'Category', 'Formula', 'Reference'])
# 添加新公式
new_row = pd.DataFrame({
'Formula_Name': [formula_name],
'Category': [formula_category],
'Formula': [formula_expression],
'Reference': [formula_reference]
})
df = pd.concat([df, new_row], ignore_index=True)
# 保存文件
df.to_csv(formula_csv_path, index=False, encoding='utf-8')
# 清空输入框
self.formula_name_edit.clear()
self.formula_category_combo.setCurrentIndex(0) # 重置到第一个选项
self.formula_expression_edit.clear()
self.formula_reference_edit.clear()
# 刷新公式列表
self.refresh_formulas()
QMessageBox.information(self, "成功", "公式添加成功")
except Exception as e:
QMessageBox.critical(self, "错误", f"添加公式失败: {str(e)}")
def get_config(self) -> Dict[str, Union[List[str], str, bool, None]]:
"""获取配置"""
selected = [
name for name, checkbox in self.index_checkboxes.items()
if checkbox.isChecked()
]
output_path = self.output_file_widget.get_path()
return {
'training_spectra_path': self.training_data_widget.get_path() or None,
'formula_csv_file': self.formula_csv_widget.get_path() or None,
'formula_names': selected,
'output_file': output_path or None,
'enabled': self.enable_checkbox.isChecked()
}
def set_config(self, config):
"""设置配置"""
if 'training_spectra_path' in config:
self.training_data_widget.set_path(config['training_spectra_path'])
if 'formula_csv_file' in config:
self.formula_csv_widget.set_path(config['formula_csv_file'])
self.refresh_formulas()
if 'formula_names' in config:
selected_formulas = set(config['formula_names'])
for name, checkbox in self.index_checkboxes.items():
checkbox.setChecked(name in selected_formulas)
if 'output_file' in config and config['output_file']:
self.output_file_widget.set_path(config['output_file'])
elif 'output_filename' in config and config['output_filename']:
self.output_file_widget.set_path(config['output_filename'])
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
# 1. 自动填入训练数据路径(从 Step5 的输出中获取)
# 优先级:直接 widget > pipeline.step_outputs 回退
main_window = self.window()
if hasattr(main_window, 'step5_panel'):
# 优先直接从 Step5 的输出 widget 读取(已运行的最新输出)
step5_output = main_window.step5_panel.output_file.get_path()
if step5_output:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_output):
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
self.training_data_widget.set_path(step5_output)
else:
# 退而求其次,使用 Step5 的输入 CSV
step5_csv = main_window.step5_panel.csv_file.get_path()
if step5_csv:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step5_csv):
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
self.training_data_widget.set_path(step5_csv)
# 如果上述都没找到,尝试从 pipeline.step_outputs 回退
if not self.training_data_widget.get_path() and pipeline and hasattr(pipeline, 'step_outputs'):
step5_outputs = getattr(pipeline, 'step_outputs', {}).get('step5', {})
training_path = step5_outputs.get('training_spectra')
if training_path:
self.training_data_widget.set_path(training_path)
# 2. 自动填入输出文件的绝对路径
if self.work_dir:
output_abs = os.path.join(self.work_dir, "6_water_quality_indices",
"training_spectra_indices.csv").replace('\\', '/')
self.output_file_widget.set_path(output_abs)
def is_enabled(self) -> bool:
return self.enable_checkbox.isChecked()
def select_all_formulas(self):
"""全选所有公式"""
for checkbox in self.index_checkboxes.values():
checkbox.setChecked(True)
def deselect_all_formulas(self):
"""清空所有公式"""
for checkbox in self.index_checkboxes.values():
checkbox.setChecked(False)
def run_step(self):
"""Run step 5.5 standalone: compute water-quality indices.
The output file name is derived from the input CSV file name and written to output_file_widget.
For example: training_spectra.csv → training_spectra_indices.csv
sampling_spectra.csv → sampling_spectra_indices.csv
"""
# 验证输入
training_csv_path = self.training_data_widget.get_path()
formula_csv_path = self.formula_csv_widget.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入验证失败", "请选择训练数据CSV文件")
return
if not formula_csv_path:
QMessageBox.warning(self, "输入验证失败", "请选择公式CSV文件")
return
if not os.path.exists(training_csv_path):
QMessageBox.warning(self, "输入验证失败", "训练数据CSV文件不存在")
return
if not os.path.exists(formula_csv_path):
QMessageBox.warning(self, "输入验证失败", "公式CSV文件不存在")
return
# 动态生成输出文件:自动拼接 _indices 后缀
input_name = Path(training_csv_path).stem
dynamic_output = f"{input_name}_indices.csv"
# Build the full output path under work_dir (a fallback derived from training_csv_path is not implemented here)
work_dir = getattr(self, 'work_dir', None)
if work_dir:
dynamic_output = os.path.join(
work_dir, "6_water_quality_indices", dynamic_output
).replace('\\', '/')
self.output_file_widget.set_path(dynamic_output)
# 获取配置
config = self.get_config()
# 调用GUI的run_single_step方法
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step5_5', {'step5_5': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
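
The output-name derivation in `run_step` above (`<input stem>_indices.csv`, placed under `6_water_quality_indices` in the working directory) can be sketched as a small pure function. The fallback of writing next to the input file when no working directory is set is an assumption taken from the original code comment, not behaviour present in the source. `derive_indices_output` is a hypothetical name.

```python
from pathlib import PurePosixPath


def derive_indices_output(input_csv, work_dir=None, subdir="6_water_quality_indices"):
    """Return the output path for the indices CSV derived from `input_csv`."""
    # Normalise separators first so the result always uses forward slashes
    src = PurePosixPath(input_csv.replace("\\", "/"))
    name = f"{src.stem}_indices.csv"
    if work_dir:
        base = PurePosixPath(work_dir.replace("\\", "/")) / subdir
    else:
        # Hypothetical fallback: write next to the input file
        base = src.parent
    return str(base / name)


print(derive_indices_output("data/training_spectra.csv"))
print(derive_indices_output("sampling_spectra.csv", work_dir="workspace"))
```

Using `PurePosixPath` keeps the function free of file-system access, so it can be unit-tested without creating directories.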


@ -0,0 +1,239 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step5 面板 - 光谱提取
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QLabel,
QSpinBox, QPushButton, QCheckBox, QMessageBox,
)
from PyQt5.QtGui import QFont
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step5Panel(QWidget):
"""Step 5: spectral extraction."""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# Title
title = QLabel("步骤5:训练样本光谱提取")
title.setFont(QFont("Arial", 12, QFont.Bold))
layout.addWidget(title)
# Deglinted image file (used for standalone runs)
self.deglint_img_file = FileSelectWidget(
"去耀斑影像:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.deglint_img_file)
# Processed CSV file (used for standalone runs)
self.csv_file = FileSelectWidget(
"处理后CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.csv_file)
# Water mask file (optional, used for standalone runs)
self.water_mask_file = FileSelectWidget(
"水体掩膜:",
"Mask Files (*.dat *.tif);;All Files (*.*)"
)
self.water_mask_file.line_edit.setPlaceholderText("可选,如不选择则自动生成")
layout.addWidget(self.water_mask_file)
self.glint_mask_file = FileSelectWidget(
"耀斑掩膜:",
"Mask Files (*.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.glint_mask_file)
step5_glint_hint = QLabel(
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
)
step5_glint_hint.setWordWrap(True)
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
layout.addWidget(step5_glint_hint)
# Parameter settings
params_group = QGroupBox("提取参数")
params_layout = QFormLayout()
self.radius = QSpinBox()
self.radius.setRange(1, 50)
self.radius.setValue(5)
params_layout.addRow("采样半径(像素):", self.radius)
self.source_epsg = QSpinBox()
self.source_epsg.setRange(1000, 99999)
self.source_epsg.setValue(4326)
params_layout.addRow("源坐标系EPSG:", self.source_epsg)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# Output file path
self.output_file = FileSelectWidget(
"输出训练数据:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("training_spectra.csv")
layout.addWidget(self.output_file)
# Enable this step
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# Standalone run button
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
# Signal connections: none here (the band-range update on image path change has no code in this method)
def get_config(self):
"""Return the panel configuration."""
config = {
'radius': self.radius.value(),
'source_epsg': self.source_epsg.value(),
}
# Add the file paths needed for a standalone run
deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path:
config['deglint_img_path'] = deglint_img_path
csv_path = self.csv_file.get_path()
if csv_path:
config['csv_path'] = csv_path
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['boundary_path'] = water_mask_path
glint_mask_path = self.glint_mask_file.get_path()
if glint_mask_path:
config['glint_mask_path'] = glint_mask_path
# Note: step5_extract_training_spectra does not accept output_path / training_spectra_path
# arguments; the pipeline derives the output path internally from training_spectra_dir.
return config
def set_config(self, config):
"""Apply a configuration to the panel."""
if 'radius' in config:
self.radius.setValue(config['radius'])
if 'source_epsg' in config:
self.source_epsg.setValue(config['source_epsg'])
if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path'])
if 'csv_path' in config:
self.csv_file.set_path(config['csv_path'])
if 'boundary_path' in config:
self.water_mask_file.set_path(config['boundary_path'])
if 'glint_mask_path' in config:
self.glint_mask_file.set_path(config['glint_mask_path'])
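def update_from_config(self, work_dir=None, pipeline=None):
"""Auto-fill paths from the global config/Pipeline or from Step1Panel so that data flows between upstream and downstream steps.
Args:
work_dir: working directory path
pipeline: Pipeline instance, used to get the water mask path produced by step 1
"""
# Keep a reference to the working directory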
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
# 1. Try to get the water mask path from the Pipeline
mask_path = None
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
mask_path = pipeline.water_mask_path
# 2. If the Pipeline has none, read it directly from the Step1 panel
main_window = self.window()
if not mask_path and hasattr(main_window, 'step1_panel'):
if main_window.step1_panel.use_ndwi_radio.isChecked():
mask_path = main_window.step1_panel.output_file.get_path()
else:
mask_path = main_window.step1_panel.mask_file.get_path()
# Fill in the water mask path; a relative path is resolved against work_dir (done once here)
if mask_path:
if not os.path.isabs(mask_path):
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
self.water_mask_file.set_path(mask_path)
# 3. Try to read the glint mask path from the Step2 panel
main_window = self.window()
if hasattr(main_window, 'step2_panel'):
glint_path = main_window.step2_panel.output_file.get_path()
if glint_path:
# A relative path is resolved against work_dir
if not os.path.isabs(glint_path):
glint_path = os.path.join(self.work_dir or '', glint_path).replace('\\', '/')
self.glint_mask_file.set_path(glint_path)
# 4. Auto-fill the output path (based on the working directory)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "5_training_spectra")
os.makedirs(output_dir, exist_ok=True)
default_output_path = os.path.join(output_dir, "training_spectra.csv").replace('\\', '/')
self.output_file.set_path(default_output_path)
else:
self.output_file.set_path("")
# 5. Try to read the processed water quality CSV path from the Step4 panel and fill it in here
main_window = self.window()
if main_window and hasattr(main_window, 'step4_panel'):
step4_output_path = main_window.step4_panel.output_file.get_path()
if step4_output_path:
# A relative path is resolved against work_dir
if not os.path.isabs(step4_output_path):
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
existing_csv = self.csv_file.get_path()
if not existing_csv or not existing_csv.strip():
self.csv_file.set_path(step4_output_path)
def run_step(self):
"""Run step 5 standalone."""
# Validate inputs
deglint_img_path = self.deglint_img_file.get_path()
csv_path = self.csv_file.get_path()
if not deglint_img_path:
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
return
if not csv_path:
QMessageBox.warning(self, "输入错误", "请选择处理后的CSV文件")
return
if not self.glint_mask_file.get_path():
QMessageBox.warning(
self,
"输入错误",
"独立运行光谱特征提取时,必须选择耀斑掩膜文件。\n\n"
"请提供与去耀斑影像对应的耀斑二值掩膜(一般为步骤2输出的 severe_glint_area.dat)。",
)
return
# Get the main window and run the step
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step5': self.get_config()}
main_window.run_single_step('step5', config)
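
Several branches of update_from_config above repeat one rule: a relative path is joined onto work_dir and backslashes are normalized to forward slashes. A hypothetical helper (the name `resolve_against_work_dir` does not appear in the source) might factor that rule out. This is a sketch under that assumption:

```python
import os
from typing import Optional


def resolve_against_work_dir(path: Optional[str], work_dir: Optional[str]) -> str:
    """Resolve a possibly relative path against work_dir and normalize separators."""
    if not path:
        # Nothing to resolve; callers can skip set_path() on an empty result
        return ""
    if not os.path.isabs(path):
        # Same expression the panel uses: an unset work_dir leaves the path relative
        path = os.path.join(work_dir or "", path)
    return path.replace("\\", "/")
```

With such a helper, each auto-fill step would reduce to one call followed by `set_path`, and the rule would live in a single place.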

@@ -0,0 +1,307 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6_5 panel - non-empirical statistical regression modeling
"""
import os
from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QCheckBox, QSpinBox, QPushButton,
QFileDialog, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step6_5Panel(QWidget):
"""Step 6.5: non-empirical statistical regression modeling."""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# Training data file (used for standalone runs)
self.training_csv_file = FileSelectWidget(
"训练数据CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.training_csv_file)
# Parameter settings
params_group = QGroupBox("模型参数")
params_layout = QFormLayout()
# Preprocessing methods
self.preproc_checkboxes = {}
preproc_group = QGroupBox("预处理方法 (可多选)")
preproc_layout = QVBoxLayout()
preproc_grid = QGridLayout()
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(method)
checkbox.setChecked(True)
self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4)
button_layout = QHBoxLayout()
select_all_btn = QPushButton("全选")
deselect_all_btn = QPushButton("全不选")
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
button_layout.addWidget(select_all_btn)
button_layout.addWidget(deselect_all_btn)
button_layout.addStretch()
preproc_layout.addLayout(preproc_grid)
preproc_layout.addLayout(button_layout)
preproc_group.setLayout(preproc_layout)
params_layout.addRow(preproc_group)
# Algorithm selection (multiple choices allowed)
self.algorithm_inputs = {}
algorithms_widget = QWidget()
algorithms_layout = QVBoxLayout()
algorithms_layout.setContentsMargins(0, 0, 0, 0)
algorithms_layout.setSpacing(4)
algorithm_list = ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss']
for algorithm in algorithm_list:
row_widget = QWidget()
row_layout = QHBoxLayout()
row_layout.setContentsMargins(0, 0, 0, 0)
checkbox = QCheckBox(algorithm)
checkbox.setChecked(True)
spinbox = QSpinBox()
spinbox.setRange(0, 500)
spinbox.setValue(0)
spinbox.setMaximumWidth(90)
row_layout.addWidget(checkbox)
row_layout.addWidget(QLabel("对应值列索引:"))
row_layout.addWidget(spinbox)
row_layout.addStretch()
row_widget.setLayout(row_layout)
algorithms_layout.addWidget(row_widget)
self.algorithm_inputs[algorithm] = (checkbox, spinbox)
algorithms_widget.setLayout(algorithms_layout)
params_layout.addRow("非经验算法选择:", algorithms_widget)
# First spectral column
self.spectral_start_col = QSpinBox()
self.spectral_start_col.setRange(0, 100)
self.spectral_start_col.setValue(1)
params_layout.addRow("光谱起始列索引:", self.spectral_start_col)
# Window size (attribute renamed so it no longer shadows QWidget.window)
self.window_size_spinbox = QSpinBox()
self.window_size_spinbox.setRange(1, 20)
self.window_size_spinbox.setValue(5)
params_layout.addRow("窗口大小:", self.window_size_spinbox)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# Output directory path
self.output_dir = FileSelectWidget(
"输出模型目录:",
"Directories;;All Files (*.*)"
)
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
self.output_dir.browse_btn.clicked.disconnect()
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
layout.addWidget(self.output_dir)
# Enable this step
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# Standalone run button
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
layout.addWidget(self.run_button)
layout.addStretch()
self.setLayout(layout)
def get_config(self):
"""Return the panel configuration."""
selected_algorithms = [
name for name, (checkbox, _) in self.algorithm_inputs.items()
if checkbox.isChecked()
]
if not selected_algorithms:
selected_algorithms = list(self.algorithm_inputs.keys())
value_cols = {
name: spinbox.value()
for name, (_, spinbox) in self.algorithm_inputs.items()
if name in selected_algorithms
}
preprocessing_methods = [
method for method, checkbox in self.preproc_checkboxes.items()
if checkbox.isChecked()
] or ['None']
config = {
'preprocessing_methods': preprocessing_methods,
'algorithms': selected_algorithms,
'value_cols': value_cols,
'spectral_start_col': self.spectral_start_col.value(),
'window': self.window_size_spinbox.value(),
'enabled': self.enable_checkbox.isChecked()
}
output_dir = self.output_dir.get_path()
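if not output_dir:
# self.parent() may be None when the panel has no parent, so guard before calling window()
parent_widget = self.parentWidget()
main_window = parent_widget.window() if parent_widget else None
if main_window and hasattr(main_window, 'work_dir') and main_window.work_dir: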
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
else:
output_dir = str(Path.cwd() / "8_Regression_Modeling")
config['output_dir'] = output_dir
training_csv_path = self.training_csv_file.get_path()
if training_csv_path:
config['csv_path'] = training_csv_path
return config
def set_config(self, config):
"""Apply a configuration to the panel."""
if 'preprocessing_methods' in config:
methods = config['preprocessing_methods']
for method, checkbox in self.preproc_checkboxes.items():
checkbox.setChecked(method in methods)
if 'algorithms' in config:
algorithm_values = config['algorithms']
for algorithm, (checkbox, spinbox) in self.algorithm_inputs.items():
checkbox.setChecked(algorithm in algorithm_values)
if 'value_cols' in config:
value_cols = config['value_cols']
if isinstance(value_cols, dict):
for algorithm, (_, spinbox) in self.algorithm_inputs.items():
if algorithm in value_cols:
spinbox.setValue(value_cols[algorithm])
else:
for _, spinbox in self.algorithm_inputs.values():
spinbox.setValue(value_cols)
if 'spectral_start_col' in config:
self.spectral_start_col.setValue(config['spectral_start_col'])
if 'window' in config:
self.window_size_spinbox.setValue(config['window'])
if 'output_dir' in config:
self.output_dir.set_path(config['output_dir'])
if 'csv_path' in config:
self.training_csv_file.set_path(config['csv_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""Auto-fill the training data and output paths from the global config.
Args:
work_dir: working directory path
pipeline: Pipeline instance (unused; kept for interface compatibility)
"""
try:
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
# Call window() on the parent widget to safely sidestep the name clash in this class
parent_widget = self.parentWidget()
main_window = parent_widget.window() if parent_widget else None
if main_window and hasattr(main_window, 'step5_panel'):
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
step5_output_path = ""
if hasattr(step5_widget, 'get_path'):
step5_output_path = step5_widget.get_path() or ""
elif hasattr(step5_widget, 'text'):
step5_output_path = step5_widget.text() or ""
if step5_output_path:
# A relative path is resolved against work_dir
if not os.path.isabs(step5_output_path):
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
existing = self.training_csv_file.get_path()
if not existing or not existing.strip():
self.training_csv_file.set_path(step5_output_path)
# 2. Auto-fill the output directory (8_Regression_Modeling)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "8_Regression_Modeling")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_dir.get_path()
if not existing_out or not existing_out.strip():
self.output_dir.set_path(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def _get_default_work_dir(self):
"""Return work_dir: prefer the value cached on this panel, otherwise try the main window."""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
# Call window() on the parent widget to safely sidestep the name clash in this class
parent_widget = self.parentWidget()
mw = parent_widget.window() if parent_widget else None
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_output_dir(self):
"""Browse for the output directory."""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "8_Regression_Modeling")
dir_path = QFileDialog.getExistingDirectory(self, "选择输出模型目录", default)
if dir_path:
self.output_dir.set_path(dir_path)
def run_step(self):
"""Run step 6.5 standalone."""
training_csv_path = self.training_csv_file.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件")
return
if not os.path.exists(training_csv_path):
QMessageBox.warning(self, "输入错误", "训练数据CSV文件不存在")
return
config = self.get_config()
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step6_5', {'step6_5': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
def _toggle_checkboxes(self, checkboxes_dict, checked):
"""Set the checked state of every checkbox in the given dict."""
for checkbox in checkboxes_dict.values():
checkbox.setChecked(checked)
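
The fallback rules in Step6_5Panel.get_config above are easy to miss: when no algorithm is checked, all of them are used, and when no preprocessing method is checked, the list becomes ['None']. The sketch below restates those rules without Qt, using plain values in place of widgets. `build_regression_config` and the tuple layout `(checked, column_index)` are hypothetical, chosen only for illustration.

```python
from typing import Dict, Tuple


def build_regression_config(
    algorithm_state: Dict[str, Tuple[bool, int]],  # name -> (checked, value column index)
    preproc_state: Dict[str, bool],                # method -> checked
) -> dict:
    """Mirror the selection fallbacks of the panel's get_config."""
    selected = [name for name, (checked, _) in algorithm_state.items() if checked]
    if not selected:
        # Nothing checked: fall back to every algorithm
        selected = list(algorithm_state)
    # Only the selected algorithms contribute a value column
    value_cols = {name: algorithm_state[name][1] for name in selected}
    # Nothing checked: fall back to the 'None' preprocessing method
    methods = [m for m, checked in preproc_state.items() if checked] or ["None"]
    return {
        "preprocessing_methods": methods,
        "algorithms": selected,
        "value_cols": value_cols,
    }
```

One consequence worth noting: clicking "全不选" and then running does not produce an empty job; it runs every algorithm with the 'None' preprocessing method.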

@@ -0,0 +1,374 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6_75 panel - custom regression analysis
"""
import os
from typing import Dict
import pandas as pd
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
QScrollArea, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step6_75Panel(QWidget):
"""Step 6.75: custom regression analysis."""
def __init__(self, parent=None):
super().__init__(parent)
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
self.method_checkboxes: Dict[str, QCheckBox] = {}
self.csv_columns = []
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
hint.setStyleSheet("color: #666; font-size: 11px;")
layout.addWidget(hint)
# CSV file selection
csv_group = QGroupBox("数据文件")
csv_layout = QVBoxLayout()
self.csv_file = FileSelectWidget(
"输入CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
csv_layout.addWidget(self.csv_file)
self.refresh_btn = QPushButton("刷新列信息")
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
csv_layout.addWidget(self.refresh_btn)
csv_group.setLayout(csv_layout)
layout.addWidget(csv_group)
# Independent variable selection
x_group = QGroupBox("自变量列选择 (可多选)")
x_layout = QVBoxLayout()
x_scroll = QScrollArea()
x_scroll.setWidgetResizable(True)
x_scroll.setMinimumHeight(250)
x_scroll.setMaximumHeight(350)
x_widget = QWidget()
self.x_columns_layout = QGridLayout()
x_widget.setLayout(self.x_columns_layout)
x_scroll.setWidget(x_widget)
x_layout.addWidget(x_scroll)
x_btn_layout = QHBoxLayout()
self.x_select_all = QPushButton("全选")
self.x_deselect_all = QPushButton("全不选")
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
x_btn_layout.addWidget(self.x_select_all)
x_btn_layout.addWidget(self.x_deselect_all)
x_btn_layout.addStretch()
x_layout.addLayout(x_btn_layout)
x_group.setLayout(x_layout)
layout.addWidget(x_group)
# Dependent variable selection
y_group = QGroupBox("因变量列选择 (可多选)")
y_layout = QVBoxLayout()
y_scroll = QScrollArea()
y_scroll.setWidgetResizable(True)
y_scroll.setMinimumHeight(200)
y_scroll.setMaximumHeight(300)
y_widget = QWidget()
self.y_columns_layout = QGridLayout()
y_widget.setLayout(self.y_columns_layout)
y_scroll.setWidget(y_widget)
y_layout.addWidget(y_scroll)
y_btn_layout = QHBoxLayout()
self.y_select_all = QPushButton("全选")
self.y_deselect_all = QPushButton("全不选")
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
y_btn_layout.addWidget(self.y_select_all)
y_btn_layout.addWidget(self.y_deselect_all)
y_btn_layout.addStretch()
y_layout.addLayout(y_btn_layout)
y_group.setLayout(y_layout)
layout.addWidget(y_group)
# Regression method selection
method_group = QGroupBox("回归方法选择 (可多选)")
method_layout = QVBoxLayout()
method_grid = QGridLayout()
regression_methods = [
'linear', 'exponential', 'power', 'logarithmic',
'polynomial', 'hyperbolic', 'sigmoidal'
]
for i, method in enumerate(regression_methods):
checkbox = QCheckBox(method)
if method in ['linear', 'exponential', 'power', 'logarithmic']:
checkbox.setChecked(True)
self.method_checkboxes[method] = checkbox
method_grid.addWidget(checkbox, i // 3, i % 3)
method_layout.addLayout(method_grid)
method_btn_layout = QHBoxLayout()
self.method_select_all = QPushButton("全选")
self.method_deselect_all = QPushButton("全不选")
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
method_btn_layout.addWidget(self.method_select_all)
method_btn_layout.addWidget(self.method_deselect_all)
method_btn_layout.addStretch()
method_layout.addLayout(method_btn_layout)
method_group.setLayout(method_layout)
layout.addWidget(method_group)
# Output directory
output_group = QGroupBox("输出设置")
output_layout = QFormLayout()
self.output_dir = QLineEdit()
self.output_dir.setText("") # The path is filled in by update_from_config based on work_dir
output_layout.addRow("输出目录名:", self.output_dir)
output_group.setLayout(output_layout)
layout.addWidget(output_group)
# Enable this step
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# Standalone run button
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
layout.addWidget(self.run_button)
layout.addStretch()
self.setLayout(layout)
def toggle_checkboxes(self, checkboxes_dict, checked):
"""Set the checked state of every checkbox in the given dict."""
for checkbox in checkboxes_dict.values():
checkbox.setChecked(checked)
def on_csv_file_changed(self):
"""Refresh the column list automatically when the CSV file changes."""
self.refresh_csv_columns()
def refresh_csv_columns(self):
"""Reload the column names from the CSV file."""
csv_path = self.csv_file.get_path()
if not csv_path or not os.path.exists(csv_path):
self.csv_columns = []
self.update_column_widgets()
return
try:
df = pd.read_csv(csv_path, nrows=0)
self.csv_columns = list(df.columns)
self.update_column_widgets()
except Exception as e:
self.csv_columns = []
self.update_column_widgets()
print(f"读取CSV列信息失败: {e}")
def update_column_widgets(self):
"""Rebuild the column selection checkboxes."""
for checkbox in self.x_column_checkboxes.values():
checkbox.setParent(None)
self.x_column_checkboxes.clear()
for checkbox in self.y_column_checkboxes.values():
checkbox.setParent(None)
self.y_column_checkboxes.clear()
if not self.csv_columns:
return
for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col)
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
checkbox.setChecked(True)
self.x_column_checkboxes[col] = checkbox
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
for i, col in enumerate(self.csv_columns):
checkbox = QCheckBox(col)
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
checkbox.setChecked(True)
self.y_column_checkboxes[col] = checkbox
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
self.x_columns_layout.update()
self.y_columns_layout.update()
def get_config(self):
selected_x_columns = [
col for col, checkbox in self.x_column_checkboxes.items()
if checkbox.isChecked()
]
selected_y_columns = [
col for col, checkbox in self.y_column_checkboxes.items()
if checkbox.isChecked()
]
selected_methods = [
method for method, checkbox in self.method_checkboxes.items()
if checkbox.isChecked()
]
if not selected_methods:
selected_methods = 'all'
return {
'csv_path': self.csv_file.get_path() or None,
'x_columns': selected_x_columns,
'y_columns': selected_y_columns,
'methods': selected_methods,
'output_dir': self.output_dir.text().strip() or None,
'enabled': self.enable_checkbox.isChecked()
}
def set_config(self, config):
if 'csv_path' in config:
self.csv_file.set_path(config['csv_path'])
self.refresh_csv_columns()
if 'x_columns' in config:
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
for col, checkbox in self.x_column_checkboxes.items():
checkbox.setChecked(col in selected_x)
if 'y_columns' in config:
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
for col, checkbox in self.y_column_checkboxes.items():
checkbox.setChecked(col in selected_y)
if 'methods' in config:
methods = config['methods']
if isinstance(methods, list):
selected_methods = set(methods)
elif methods == 'all':
selected_methods = set(self.method_checkboxes.keys())
else:
selected_methods = set()
for method, checkbox in self.method_checkboxes.items():
checkbox.setChecked(method in selected_methods)
if 'output_dir' in config:
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
def update_from_config(self, work_dir=None, pipeline=None):
"""Auto-fill the training data and output paths from the global config.
Args:
work_dir: working directory path
pipeline: Pipeline instance (unused; kept for interface compatibility)
"""
try:
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
# 1. Try to read the training spectra CSV path from the Step5 panel
main_window = self.window()
if main_window and hasattr(main_window, 'step5_panel'):
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
step5_output_path = ""
if hasattr(step5_widget, 'get_path'):
step5_output_path = step5_widget.get_path() or ""
elif hasattr(step5_widget, 'text'):
step5_output_path = step5_widget.text() or ""
if step5_output_path:
# A relative path is resolved against work_dir
if not os.path.isabs(step5_output_path):
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
existing = self.csv_file.get_path()
if not existing or not existing.strip():
self.csv_file.set_path(step5_output_path)
# 2. Auto-fill the output directory (9_Custom_Regression_Modeling)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_dir.text().strip()
if not existing_out:
self.output_dir.setText(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def run_step(self):
"""Run step 6.75 standalone."""
csv_path = self.csv_file.get_path()
if not csv_path:
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
return
if not os.path.exists(csv_path):
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
return
selected_x_columns = [
col for col, checkbox in self.x_column_checkboxes.items()
if checkbox.isChecked()
]
if not selected_x_columns:
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
return
selected_y_columns = [
col for col, checkbox in self.y_column_checkboxes.items()
if checkbox.isChecked()
]
if not selected_y_columns:
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
return
selected_methods = [
method for method, checkbox in self.method_checkboxes.items()
if checkbox.isChecked()
]
if not selected_methods:
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
return
config = self.get_config()
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step6_75', {'step6_75': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
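
update_column_widgets above pre-checks columns by substring match against two keyword lists. The following sketch reproduces that matching outside Qt so its behavior can be inspected. The keyword tuples are copied from the panel, while `preselect_columns` is a hypothetical name used only for illustration.

```python
from typing import Iterable, List

# Keyword lists copied from update_column_widgets
X_KEYWORDS = ('index', 'ratio', 'normalized', 'nd', 'b')
Y_KEYWORDS = ('chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity')


def preselect_columns(columns: Iterable[str], keywords: Iterable[str]) -> List[str]:
    """Return the columns whose lowercased name contains any keyword as a substring."""
    keywords = tuple(keywords)
    return [col for col in columns if any(k in col.lower() for k in keywords)]


columns = ['B1', 'NDWI', 'Chl_a', 'lat', 'Turbidity']
x_default = preselect_columns(columns, X_KEYWORDS)
y_default = preselect_columns(columns, Y_KEYWORDS)
```

Because the test is a plain substring check, short keywords such as 'b' match widely: in this sample 'Turbidity' ends up pre-selected on both the X and Y sides, so the defaults are a starting point that the user is expected to review.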

@@ -0,0 +1,415 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step6 panel - machine learning modeling
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
QPushButton, QFileDialog, QMessageBox,
)
from PyQt5.QtCore import Qt
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# ============================================================
# Chinese label mapping tables (internal key -> display text)
# ============================================================
# Preprocessing methods: internal key -> display text
PREPROC_CHINESE = {
'None': '无 (None)',
'MMS': '最小-最大归一化 (MMS)',
'SS': '标度化 (SS)',
'SNV': '标准正态变换 (SNV)',
'MA': '移动平均 (MA)',
'SG': 'Savitzky-Golay (SG)',
'MSC': '多元散射校正 (MSC)',
'D1': '一阶导数 (D1)',
'D2': '二阶导数 (D2)',
'DT': '去趋势 (DT)',
'CT': '中心化 (CT)',
}
# Model types: internal key -> display text
MODEL_CHINESE = {
# Linear models
'LinearRegression': '多元线性回归 (MLR)',
'Ridge': '岭回归 (Ridge)',
'Lasso': '套索回归 (Lasso)',
'ElasticNet': '弹性网络 (ElasticNet)',
'PLS': '偏最小二乘 (PLSR)',
# Tree models
'DecisionTree': '决策树 (CART)',
'RF': '随机森林 (RF)',
'ExtraTrees': '极端随机树 (ET)',
'XGBoost': '极值梯度提升 (XGBoost)',
'LightGBM': '轻量梯度提升 (LightGBM)',
'CatBoost': '类别梯度提升 (CatBoost)',
# Ensemble learning
'GradientBoosting': '梯度提升树 (GBDT)',
'AdaBoost': '自适应提升 (AdaBoost)',
# Other models
'SVR': '支持向量回归 (SVR)',
'KNN': 'K近邻回归 (KNN)',
'MLP': '多层感知机 (BP神经网络)',
}
# Data split methods: internal key -> display text
SPLIT_CHINESE = {
'spxy': 'SPXY 算法 (考量X-Y空间)',
'ks': 'KS 算法 (考量X空间)',
'random': '随机划分 (Random)',
}
class Step6Panel(QWidget):
"""Step 6: machine learning modeling."""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# Training data file (used for standalone runs)
self.training_csv_file = FileSelectWidget(
"训练数据:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.training_csv_file)
# Machine learning model page
self.ml_page = QWidget()
self.create_ml_page()
layout.addWidget(self.ml_page)
# Output file path
self.output_path = FileSelectWidget(
"输出文件:",
"CSV Files (*.csv);;All Files (*.*)",
mode="save"
)
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
self.output_path.browse_btn.clicked.disconnect()
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
layout.addWidget(self.output_path)
# Enable this step
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(False)
layout.addWidget(self.enable_checkbox)
# Standalone run button
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
def create_ml_page(self):
"""Build the machine learning model page."""
layout = QVBoxLayout()
# Parameter settings
params_group = QGroupBox("训练参数")
params_layout = QFormLayout()
self.feature_start = QLineEdit()
self.feature_start.setText("374.285004")
params_layout.addRow("特征起始列:", self.feature_start)
self.cv_folds = QSpinBox()
self.cv_folds.setRange(2, 10)
self.cv_folds.setValue(3)
params_layout.addRow("交叉验证折数:", self.cv_folds)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# Preprocessing methods - multiple choice
preproc_group = QGroupBox("预处理方法 (可多选)")
preproc_layout = QVBoxLayout()
preproc_grid = QGridLayout()
self.preproc_checkboxes = {}
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
checkbox.setChecked(False)
self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4)
button_layout = QHBoxLayout()
select_all_btn = QPushButton("全选")
deselect_all_btn = QPushButton("全不选")
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
button_layout.addWidget(select_all_btn)
button_layout.addWidget(deselect_all_btn)
button_layout.addStretch()
preproc_layout.addLayout(preproc_grid)
preproc_layout.addLayout(button_layout)
preproc_group.setLayout(preproc_layout)
layout.addWidget(preproc_group)
# Model selection - multiple choice
model_group = QGroupBox("模型类型 (可多选)")
model_layout = QVBoxLayout()
model_grid = QGridLayout()
self.model_checkboxes = {}
model_groups = [
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
("【其他模型】", ['SVR', 'KNN', 'MLP'])
]
row = 0
for group_name, models in model_groups:
group_label = QLabel(f"<b>{group_name}</b>")
group_label.setStyleSheet(
f"background-color: {ModernStylesheet.COLORS['hover']}; "
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
f"border-radius: 3px;"
)
model_grid.addWidget(group_label, row, 0, 1, 4)
row += 1
for i, model in enumerate(models):
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
checkbox.setChecked(False)
self.model_checkboxes[model] = checkbox
model_grid.addWidget(checkbox, row, i % 4)
if (i + 1) % 4 == 0:
row += 1
row += 1
model_button_layout = QHBoxLayout()
model_select_all = QPushButton("全选")
model_deselect_all = QPushButton("全不选")
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
model_button_layout.addWidget(model_select_all)
model_button_layout.addWidget(model_deselect_all)
model_button_layout.addStretch()
model_layout.addLayout(model_grid)
model_layout.addLayout(model_button_layout)
model_group.setLayout(model_layout)
layout.addWidget(model_group)
# Data split methods - multiple choice
split_group = QGroupBox("数据划分方法 (可多选)")
split_layout = QVBoxLayout()
split_grid = QGridLayout()
self.split_checkboxes = {}
split_methods = ['spxy', 'ks', 'random']
for i, method in enumerate(split_methods):
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
checkbox.setChecked(False)
self.split_checkboxes[method] = checkbox
split_grid.addWidget(checkbox, 0, i)
split_button_layout = QHBoxLayout()
split_select_all = QPushButton("全选")
split_deselect_all = QPushButton("全不选")
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
split_button_layout.addWidget(split_select_all)
split_button_layout.addWidget(split_deselect_all)
split_button_layout.addStretch()
split_layout.addLayout(split_grid)
split_layout.addLayout(split_button_layout)
split_group.setLayout(split_layout)
layout.addWidget(split_group)
self.ml_page.setLayout(layout)
def _toggle_checkboxes(self, checkboxes_dict, checked):
"""统一设置checkbox状态"""
for checkbox in checkboxes_dict.values():
checkbox.setChecked(checked)
def _get_default_work_dir(self):
"""Return work_dir: prefer the value cached on this panel, otherwise take it from the main window."""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_output_path(self):
"""浏览输出文件路径(保存对话框)"""
current = self.output_path.get_path().strip()
if current:
initial_dir = os.path.dirname(current)
initial_file = os.path.basename(current)
else:
initial_dir = ""
initial_file = ""
if not initial_dir or not os.path.isdir(initial_dir):
# 默认定位到 indices 目录
work_dir = self._get_default_work_dir()
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
if initial_dir and not os.path.isdir(initial_dir):
os.makedirs(initial_dir, exist_ok=True)
file_path, _ = QFileDialog.getSaveFileName(
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
"CSV Files (*.csv);;All Files (*.*)"
)
if file_path:
self.output_path.set_path(file_path)
def get_config(self):
"""获取配置"""
preprocessing_methods = [
method for method, checkbox in self.preproc_checkboxes.items()
if checkbox.isChecked()
]
model_names = [
model for model, checkbox in self.model_checkboxes.items()
if checkbox.isChecked()
]
split_methods = [
method for method, checkbox in self.split_checkboxes.items()
if checkbox.isChecked()
]
config = {
'feature_start_column': self.feature_start.text(),
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
'model_names': model_names if model_names else ['SVR'],
'split_methods': split_methods if split_methods else ['random'],
'cv_folds': self.cv_folds.value()
}
training_csv_path = self.training_csv_file.get_path()
if training_csv_path:
config['training_csv_path'] = training_csv_path
output_path = self.output_path.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'feature_start_column' in config:
self.feature_start.setText(str(config['feature_start_column']))
if 'cv_folds' in config:
self.cv_folds.setValue(config['cv_folds'])
if 'preprocessing_methods' in config:
methods = config['preprocessing_methods']
for method, checkbox in self.preproc_checkboxes.items():
checkbox.setChecked(method in methods)
if 'model_names' in config:
models = config['model_names']
for model, checkbox in self.model_checkboxes.items():
checkbox.setChecked(model in models)
if 'split_methods' in config:
methods = config['split_methods']
for method, checkbox in self.split_checkboxes.items():
checkbox.setChecked(method in methods)
if 'training_csv_path' in config:
self.training_csv_file.set_path(config['training_csv_path'])
if 'output_path' in config:
self.output_path.set_path(config['output_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充训练数据和输出路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
# 1. Read the training-data path from the Step5 panel and make it absolute
main_window = self.window()
step5_panel = getattr(main_window, 'step5_panel', None)
# Prefer the path shown in Step5's output widget
step5_csv = step5_panel.output_file.get_path() if step5_panel is not None else ""
# Fallback when the widget is empty: try the likely key names in Step5's config dict
if not step5_csv and hasattr(step5_panel, 'get_config'):
step5_cfg = step5_panel.get_config()
step5_csv = (
step5_cfg.get('training_spectra_path')
or step5_cfg.get('output_file')
or step5_cfg.get('csv_path')
or step5_cfg.get('output_csv')
)
# Relative path: join it onto work_dir to make it absolute
if step5_csv and not os.path.isabs(step5_csv):
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
if step5_csv:
self.training_csv_file.set_path(step5_csv)
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
if self.work_dir:
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
os.makedirs(indices_dir, exist_ok=True)
training_csv = self.training_csv_file.get_path()
if training_csv:
basename = os.path.splitext(os.path.basename(training_csv))[0]
output_file = f"{basename}_indices.csv"
else:
output_file = "water_quality_indices.csv"
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
self.output_path.set_path(output_path)
else:
self.output_path.set_path("")
def run_step(self):
"""独立运行步骤6"""
training_csv_path = self.training_csv_file.get_path()
if not training_csv_path:
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件")
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step6': self.get_config()}
main_window.run_single_step('step6', config)
def get_training_params(self):
"""获取模型训练参数"""
return {
'pipeline_type': 'machine_learning',
'feature_start': float(self.feature_start.text()),
'cv_folds': self.cv_folds.value(),
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
}
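
The path handling in `update_from_config` above can be sketched without any Qt dependency. This is a minimal illustration, not the panel's actual implementation: the helper names `resolve_against_work_dir` and `default_indices_output` are hypothetical, and only the directory name `6_water_quality_indices` and the `_indices.csv` suffix are taken from the code above.

```python
import os


def resolve_against_work_dir(path, work_dir):
    """Return `path` unchanged if it is absolute, else join it onto `work_dir`.

    Hypothetical helper; mirrors the os.path.isabs / os.path.join pattern
    used by the panel, including the backslash normalisation.
    """
    if os.path.isabs(path):
        return path
    return os.path.join(work_dir or "", path).replace("\\", "/")


def default_indices_output(work_dir, training_csv):
    """Build <work_dir>/6_water_quality_indices/<stem>_indices.csv.

    Falls back to water_quality_indices.csv when no training CSV is set.
    """
    if training_csv:
        stem = os.path.splitext(os.path.basename(training_csv))[0]
        name = f"{stem}_indices.csv"
    else:
        name = "water_quality_indices.csv"
    return os.path.join(work_dir, "6_water_quality_indices", name).replace("\\", "/")
```

Keeping these two steps as plain functions would make the naming rule testable on its own, separate from the widgets that display the result.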

@ -0,0 +1,208 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step7 面板 - 采样点生成
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QSpinBox, QMessageBox,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step7Panel(QWidget):
"""Step 7: sampling-point generation"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 去耀斑影像文件(用于独立运行)
self.deglint_img_file = FileSelectWidget(
"去耀斑影像:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
layout.addWidget(self.deglint_img_file)
# 水域掩膜文件(可选,用于独立运行)
self.water_mask_file = FileSelectWidget(
"水域掩膜:",
"Mask Files (*.dat *.tif);;All Files (*.*)"
)
self.water_mask_file.label.setText("水域掩膜:")
layout.addWidget(self.water_mask_file)
# 参数设置
params_group = QGroupBox("采样参数")
params_layout = QFormLayout()
self.interval = QSpinBox()
self.interval.setRange(10, 500)
self.interval.setValue(50)
params_layout.addRow("采样点间隔(像素):", self.interval)
self.sample_radius = QSpinBox()
self.sample_radius.setRange(1, 50)
self.sample_radius.setValue(5)
params_layout.addRow("采样半径(像素):", self.sample_radius)
self.chunk_size = QSpinBox()
self.chunk_size.setRange(100, 10000)
self.chunk_size.setValue(1000)
params_layout.addRow("处理块大小:", self.chunk_size)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出文件路径
self.output_file = FileSelectWidget(
"输出采样点:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
def get_config(self):
"""获取配置"""
config = {
'interval': self.interval.value(),
'sample_radius': self.sample_radius.value(),
'chunk_size': self.chunk_size.value(),
}
deglint_img_path = self.deglint_img_file.get_path()
if deglint_img_path:
config['deglint_img_path'] = deglint_img_path
water_mask_path = self.water_mask_file.get_path()
if water_mask_path:
config['water_mask_path'] = water_mask_path
# Note: step7_generate_sampling_points does not accept an output_path argument; the pipeline generates the output path internally
return config
def set_config(self, config):
"""设置配置"""
if 'interval' in config:
self.interval.setValue(config['interval'])
if 'sample_radius' in config:
self.sample_radius.setValue(config['sample_radius'])
if 'chunk_size' in config:
self.chunk_size.setValue(config['chunk_size'])
if 'deglint_img_path' in config:
self.deglint_img_file.set_path(config['deglint_img_path'])
if 'water_mask_path' in config:
self.water_mask_file.set_path(config['water_mask_path'])
if 'glint_mask_path' in config and hasattr(self, 'glint_mask_file'):
self.glint_mask_file.set_path(config['glint_mask_path'])
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充去耀斑影像和掩膜路径
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
"""
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
main_window = self.window()
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
deglint_path = None
if pipeline and hasattr(pipeline, 'step_outputs'):
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
deglint_path = (
step3_outputs.get('deglint_image')
or step3_outputs.get('output_path')
or step3_outputs.get('output_file')
or step3_outputs.get('deglint_img_path')
)
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
if not deglint_path and hasattr(main_window, 'step3_panel'):
deglint_path = main_window.step3_panel.output_file.get_path()
if deglint_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(deglint_path):
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
self.deglint_img_file.set_path(deglint_path)
# 2. Fill the water-mask path; priority: pipeline.step_outputs > step1_panel > 1_water_mask > input-test
water_mask_path = None
if pipeline and hasattr(pipeline, 'step_outputs'):
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
water_mask_path = (
step1_outputs.get('water_mask')
or step1_outputs.get('output_path')
or step1_outputs.get('output_file')
)
# 回退:从 step1 面板 widget 直接读取
if not water_mask_path and hasattr(main_window, 'step1_panel'):
water_mask_path = main_window.step1_panel.output_file.get_path()
# 备选:扫描 1_water_mask 目录下的 .dat 文件
if not water_mask_path and self.work_dir:
mask_dir = os.path.join(self.work_dir, "1_water_mask")
if os.path.isdir(mask_dir):
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
if dat_files:
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
# Fallback: scan the input-test directory (prefer water_mask_from_shp.dat)
if not water_mask_path and self.work_dir:
input_test_dir = os.path.join(self.work_dir, "input-test")
if os.path.isdir(input_test_dir):
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
# 优先匹配 water_mask_from_shp.dat
for f in dat_files:
if 'water_mask_from_shp' in f.lower():
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
break
# 否则取第一个 .dat 文件
if not water_mask_path and dat_files:
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
if water_mask_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(water_mask_path):
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
self.water_mask_file.set_path(water_mask_path)
# 3. 自动填充输出路径(绝对路径)
if self.work_dir:
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
self.output_file.set_path(output_path.replace('\\', '/'))
def run_step(self):
"""独立运行步骤7"""
deglint_img_path = self.deglint_img_file.get_path()
if not deglint_img_path:
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step7': self.get_config()}
main_window.run_single_step('step7', config)
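
The directory-scan fallbacks for the water mask in `update_from_config` above can be sketched as a small standalone function. This is an illustration under stated assumptions: `pick_dat_file` and `find_water_mask` are hypothetical names, and the sketch sorts the directory listing so the "first" `.dat` file is deterministic, whereas the panel code takes whatever order `os.listdir` returns.

```python
import os


def pick_dat_file(directory, preferred_substring=None):
    """Return a .dat file from `directory`, or None if there is none.

    A file whose lowercased name contains `preferred_substring` wins;
    otherwise the alphabetically first .dat file is returned.
    """
    if not os.path.isdir(directory):
        return None
    dat_files = sorted(f for f in os.listdir(directory) if f.lower().endswith(".dat"))
    if preferred_substring:
        for name in dat_files:
            if preferred_substring in name.lower():
                return os.path.join(directory, name).replace("\\", "/")
    if dat_files:
        return os.path.join(directory, dat_files[0]).replace("\\", "/")
    return None


def find_water_mask(work_dir):
    """Try 1_water_mask first, then input-test (preferring water_mask_from_shp)."""
    candidates = (
        ("1_water_mask", None),
        ("input-test", "water_mask_from_shp"),
    )
    for subdir, preferred in candidates:
        found = pick_dat_file(os.path.join(work_dir, subdir), preferred)
        if found:
            return found
    return None
```

Expressing the priority order as a tuple of candidates keeps each fallback to one line, so adding or reordering a search location does not require another nested `if` block.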

@ -0,0 +1,226 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8_5 面板 - 非经验模型预测
"""
import os
from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8_5Panel(QWidget):
"""步骤8.5:非经验模型预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 采样光谱CSV文件选择
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.sampling_csv_file)
# 模型目录选择
self.models_dir_file = FileSelectWidget(
"模型目录:",
"Directories;;All Files (*.*)"
)
self.models_dir_file.label.setText("模型目录:")
self.models_dir_file.browse_btn.clicked.disconnect()
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
layout.addWidget(self.models_dir_file)
# 参数设置
params_group = QGroupBox("预测参数")
params_layout = QFormLayout()
self.metric = QComboBox()
self.metric.addItems(['Average Accuracy(%)', 'Min Accuracy(%)', 'Max Accuracy(%)'])
params_layout.addRow("模型选择指标:", self.metric)
self.prediction_column = QLineEdit()
self.prediction_column.setText("prediction")
params_layout.addRow("预测列名:", self.prediction_column)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出路径
self.output_file = FileSelectWidget(
"输出文件夹:",
"Directories;;All Files (*.*)"
)
self.output_file.label.setText("输出文件夹:")
self.output_file.browse_btn.clicked.disconnect()
self.output_file.browse_btn.clicked.connect(self.browse_output_dir)
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
layout.addWidget(self.run_button)
layout.addStretch()
self.setLayout(layout)
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充采样光谱和回归模型目录
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
try:
import traceback
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
main_window = self.window()
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
if main_window and hasattr(main_window, 'step7_panel'):
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
step7_output_path = ""
if hasattr(step7_widget, 'get_path'):
step7_output_path = step7_widget.get_path() or ""
elif hasattr(step7_widget, 'text'):
step7_output_path = step7_widget.text() or ""
if step7_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step7_output_path):
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
existing = self.sampling_csv_file.get_path()
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
# 2. 尝试从 Step6.5 界面读取回归模型目录
if main_window and hasattr(main_window, 'step6_5_panel'):
step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
step6_5_models_dir = ""
if hasattr(step6_5_widget, 'get_path'):
step6_5_models_dir = step6_5_widget.get_path() or ""
elif hasattr(step6_5_widget, 'text'):
step6_5_models_dir = step6_5_widget.text() or ""
if step6_5_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step6_5_models_dir):
step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
existing_models = self.models_dir_file.get_path()
if not existing_models or not existing_models.strip():
self.models_dir_file.set_path(step6_5_models_dir)
# 3. 自动填充输出路径(非经验模型预测目录)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Non_Empirical_Prediction")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_file.get_path()
if not existing_out or not existing_out.strip():
self.output_file.set_path(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def _get_default_work_dir(self):
"""Return work_dir: prefer the value cached on this panel, otherwise take it from the main window."""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_models_dir(self):
"""浏览模型目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "8_Regression_Modeling")
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
if dir_path:
self.models_dir_file.set_path(dir_path)
def browse_output_dir(self):
"""浏览输出目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "11_12_13_predictions/Non_Empirical_Prediction")
dir_path = QFileDialog.getExistingDirectory(self, "选择输出文件夹", default)
if dir_path:
self.output_file.set_path(dir_path)
def get_config(self):
"""获取配置"""
config = {
'metric': self.metric.currentText(),
'prediction_column': self.prediction_column.text(),
'enabled': self.enable_checkbox.isChecked()
}
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
models_dir = self.models_dir_file.get_path()
if models_dir:
config['models_dir'] = models_dir
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'metric' in config:
idx = self.metric.findText(config['metric'])
if idx >= 0:
self.metric.setCurrentIndex(idx)
if 'prediction_column' in config:
self.prediction_column.setText(config['prediction_column'])
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
if 'models_dir' in config:
self.models_dir_file.set_path(config['models_dir'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
"""独立运行步骤8.5"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
config = self.get_config()
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step8_5', {'step8_5': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
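
The auto-fill logic in `update_from_config` above follows two small rules: read a path from whatever accessor the upstream widget offers, and only write it into the target when the target is still blank. A minimal Qt-free sketch of those two rules follows; `read_path`, `fill_if_empty`, and `FakePathWidget` are hypothetical names used only for illustration, and `FakePathWidget` merely stands in for `FileSelectWidget`.

```python
def read_path(widget):
    """Read a path from a widget exposing get_path() or text(); '' otherwise."""
    if hasattr(widget, "get_path"):
        return widget.get_path() or ""
    if hasattr(widget, "text"):
        return widget.text() or ""
    return ""


def fill_if_empty(target, value):
    """Set the target's path only when it is currently blank.

    Returns True if the target was updated, so user-entered values
    are never overwritten by auto-fill.
    """
    existing = target.get_path()
    if value and (not existing or not existing.strip()):
        target.set_path(value)
        return True
    return False


class FakePathWidget:
    """Hypothetical stand-in for FileSelectWidget, for illustration only."""

    def __init__(self, path=""):
        self._path = path

    def get_path(self):
        return self._path

    def set_path(self, path):
        self._path = path
```

Since the same read-then-fill sequence appears in several prediction panels, a shared pair of helpers like this could replace the repeated blocks, assuming the panels can import a common utility module.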

@ -0,0 +1,230 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8_75 面板 - 自定义回归预测
"""
import os
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox,
QPushButton, QCheckBox, QMessageBox, QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8_75Panel(QWidget):
"""步骤8.75:自定义回归预测"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# 采样光谱CSV文件选择
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.sampling_csv_file)
# Custom regression model directory (9_Custom_Regression_Modeling)
self.regression_models_dir = FileSelectWidget(
"回归模型目录:",
"Directories;;All Files (*.*)"
)
self.regression_models_dir.label.setText("回归模型目录:")
self.regression_models_dir.browse_btn.clicked.disconnect()
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
self.regression_models_dir.set_path("") # 路径由 update_from_config 根据 work_dir 自动填充
layout.addWidget(self.regression_models_dir)
# Formula CSV file (used to look up index_formula)
self.formula_csv_file = FileSelectWidget(
"公式CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.formula_csv_file.label.setText("公式CSV文件:")
layout.addWidget(self.formula_csv_file)
# 输出目录选择
self.output_dir_widget = FileSelectWidget(
"输出目录:",
"Directories;;All Files (*.*)"
)
self.output_dir_widget.label.setText("输出目录:")
self.output_dir_widget.browse_btn.clicked.disconnect()
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
layout.addWidget(self.output_dir_widget)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_button = QPushButton("独立运行此步骤")
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_button.clicked.connect(self.run_step)
layout.addWidget(self.run_button)
layout.addStretch()
self.setLayout(layout)
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充采样光谱和自定义回归模型目录
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
try:
import traceback
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
main_window = self.window()
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
if main_window and hasattr(main_window, 'step7_panel'):
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
step7_output_path = ""
if hasattr(step7_widget, 'get_path'):
step7_output_path = step7_widget.get_path() or ""
elif hasattr(step7_widget, 'text'):
step7_output_path = step7_widget.text() or ""
if step7_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step7_output_path):
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
existing = self.sampling_csv_file.get_path()
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
# 2. 尝试从 Step6.75 界面读取自定义回归模型目录
if main_window and hasattr(main_window, 'step6_75_panel'):
step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
step6_75_models_dir = ""
if hasattr(step6_75_widget, 'get_path'):
step6_75_models_dir = step6_75_widget.get_path() or ""
elif hasattr(step6_75_widget, 'text'):
step6_75_models_dir = step6_75_widget.text() or ""
step6_75_models_dir = step6_75_models_dir.strip()
if step6_75_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step6_75_models_dir):
step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
existing_models = self.regression_models_dir.get_path()
if not existing_models or not existing_models.strip():
self.regression_models_dir.set_path(step6_75_models_dir)
# 3. 自动填充回归模型目录(如果 step6_75 未提供)
if self.work_dir:
models_dir = self.regression_models_dir.get_path().strip()
if not models_dir:
default_models_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling").replace('\\', '/')
self.regression_models_dir.set_path(default_models_dir)
# 4. 自动填充输出目录(自定义回归预测目录)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Custom_Regression_Prediction")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_dir_widget.get_path()
if not existing_out or not existing_out.strip():
self.output_dir_widget.set_path(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def _get_default_work_dir(self):
"""Return work_dir: prefer the value cached on this panel, otherwise take it from the main window."""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_regression_models_dir(self):
"""浏览回归模型目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "9_Custom_Regression_Modeling")
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", default)
if dir_path:
self.regression_models_dir.set_path(dir_path)
def browse_output_dir(self):
"""浏览输出目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "11_12_13_predictions/Custom_Regression_Prediction")
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", default)
if dir_path:
self.output_dir_widget.set_path(dir_path)
def get_config(self):
"""获取配置"""
config = {
'enabled': self.enable_checkbox.isChecked()
}
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
regression_models_dir = self.regression_models_dir.get_path()
if regression_models_dir:
config['custom_regression_dir'] = regression_models_dir
formula_csv_path = self.formula_csv_file.get_path()
if formula_csv_path:
config['formula_csv_path'] = formula_csv_path
output_dir = self.output_dir_widget.get_path()
if output_dir:
config['output_dir'] = output_dir
return config
def set_config(self, config):
"""设置配置"""
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
if 'custom_regression_dir' in config:
self.regression_models_dir.set_path(config['custom_regression_dir'])
if 'formula_csv_path' in config:
self.formula_csv_file.set_path(config['formula_csv_path'])
if 'output_dir' in config:
self.output_dir_widget.set_path(config['output_dir'])
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
def run_step(self):
"""独立运行步骤8.75"""
sampling_csv_path = self.sampling_csv_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
regression_models_dir = self.regression_models_dir.get_path()
if not regression_models_dir:
QMessageBox.warning(self, "输入错误", "请选择回归模型目录!")
return
config = self.get_config()
parent = self.parent()
while parent and not hasattr(parent, 'run_single_step'):
parent = parent.parent()
if parent and hasattr(parent, 'run_single_step'):
parent.run_single_step('step8_75', {'step8_75': config})
else:
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
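
`run_step` above locates the object that owns `run_single_step` by walking up the `parent()` chain. The loop can be sketched in isolation as follows. This is only an illustration: `find_ancestor_with`, `Node`, and `Host` are hypothetical names, and `Node` imitates just the `parent()` accessor of a Qt widget.

```python
def find_ancestor_with(obj, attr_name):
    """Walk obj.parent() upwards and return the first ancestor that has
    `attr_name`, or None when the chain ends without a match."""
    current = obj.parent()
    while current is not None and not hasattr(current, attr_name):
        current = current.parent()
    return current


class Node:
    """Hypothetical stand-in for a widget: only exposes parent()."""

    def __init__(self, parent=None):
        self._parent = parent

    def parent(self):
        return self._parent


class Host(Node):
    """Hypothetical top-level object that can run a single step."""

    def run_single_step(self, step, config):
        return step, config
```

Usage under these assumptions: `find_ancestor_with(panel, 'run_single_step')` returns the host or `None`, and the caller shows the error dialog in the `None` case. The other panels in this change call `self.window()` instead, which skips the walk but assumes the top-level window itself provides `run_single_step`.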

@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step8 面板 - 机器学习预测
"""
import os
from pathlib import Path
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
class Step8Panel(QWidget):
"""Step 8: machine-learning prediction"""
def __init__(self, parent=None):
super().__init__(parent)
self.init_ui()
def init_ui(self):
layout = QVBoxLayout()
# Sampling-spectra CSV file (for standalone runs)
self.sampling_csv_file = FileSelectWidget(
"采样光谱CSV:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.sampling_csv_file)
# 模型目录(用于独立运行)
self.models_dir_file = FileSelectWidget(
"模型目录:",
"Directories;;All Files (*.*)"
)
self.models_dir_file.label.setText("模型目录:")
self.models_dir_file.browse_btn.clicked.disconnect()
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
layout.addWidget(self.models_dir_file)
# 参数设置
params_group = QGroupBox("预测参数")
params_layout = QFormLayout()
self.metric = QComboBox()
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
params_layout.addRow("模型选择指标:", self.metric)
self.prediction_column = QLineEdit()
self.prediction_column.setText("prediction")
params_layout.addRow("预测列名:", self.prediction_column)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
# 输出路径
self.output_file = FileSelectWidget(
"输出路径:",
"CSV Files (*.csv);;All Files (*.*)"
)
layout.addWidget(self.output_file)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
self.run_btn = QPushButton("独立运行此步骤")
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
self.run_btn.clicked.connect(self.run_step)
layout.addWidget(self.run_btn)
layout.addStretch()
self.setLayout(layout)
def update_from_config(self, work_dir=None, pipeline=None):
"""从全局配置自动填充采样光谱和模型目录
Args:
work_dir: 工作目录路径
pipeline: Pipeline 实例(未使用,保留接口兼容性)
"""
try:
import traceback
if work_dir:
self.work_dir = work_dir
elif hasattr(self, 'work_dir') and self.work_dir:
pass
else:
self.work_dir = None
main_window = self.window()
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
if main_window and hasattr(main_window, 'step7_panel'):
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
step7_output_path = ""
if hasattr(step7_widget, 'get_path'):
step7_output_path = step7_widget.get_path() or ""
elif hasattr(step7_widget, 'text'):
step7_output_path = step7_widget.text() or ""
if step7_output_path:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step7_output_path):
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
existing = self.sampling_csv_file.get_path()
if not existing or not existing.strip():
self.sampling_csv_file.set_path(step7_output_path)
# 2. 尝试从 Step6 界面读取监督模型目录
if main_window and hasattr(main_window, 'step6_panel'):
step6_widget = getattr(main_window.step6_panel, 'output_dir', None)
step6_models_dir = ""
if hasattr(step6_widget, 'get_path'):
step6_models_dir = step6_widget.get_path() or ""
elif hasattr(step6_widget, 'text'):
step6_models_dir = step6_widget.text() or ""
if step6_models_dir:
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step6_models_dir):
step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
existing_models = self.models_dir_file.get_path()
if not existing_models or not existing_models.strip():
self.models_dir_file.set_path(step6_models_dir)
# 3. 自动填充输出路径(机器学习预测目录)
if self.work_dir:
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
os.makedirs(output_dir, exist_ok=True)
existing_out = self.output_file.get_path()
if not existing_out or not existing_out.strip():
self.output_file.set_path(output_dir)
except Exception as e:
import traceback
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
traceback.print_exc()
def _get_default_work_dir(self):
"""Return work_dir: prefer the value cached on this panel, otherwise take it from the main window."""
if hasattr(self, 'work_dir') and self.work_dir:
return str(self.work_dir)
mw = self.window()
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
return str(mw.work_dir)
return ""
def browse_models_dir(self):
"""浏览模型目录"""
default = self._get_default_work_dir()
if default:
default = os.path.join(default, "7_Supervised_Model_Training")
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
if dir_path:
self.models_dir_file.set_path(dir_path)
def get_config(self):
"""获取配置"""
config = {
'metric': self.metric.currentText(),
'prediction_column': self.prediction_column.text(),
}
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
models_dir = self.models_dir_file.get_path()
if models_dir:
config['models_dir'] = models_dir
output_path = self.output_file.get_path()
if output_path:
config['output_path'] = output_path
return config
def set_config(self, config):
"""设置配置"""
if 'metric' in config:
idx = self.metric.findText(config['metric'])
if idx >= 0:
self.metric.setCurrentIndex(idx)
if 'prediction_column' in config:
self.prediction_column.setText(config['prediction_column'])
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
if 'models_dir' in config:
self.models_dir_file.set_path(config['models_dir'])
if 'output_path' in config:
self.output_file.set_path(config['output_path'])
def run_step(self):
"""独立运行步骤8"""
sampling_csv_path = self.sampling_csv_file.get_path()
models_dir = self.models_dir_file.get_path()
if not sampling_csv_path:
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
if not models_dir:
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
return
main_window = self.window()
if hasattr(main_window, 'run_single_step'):
config = {'step8': self.get_config()}
main_window.run_single_step('step8', config)
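
`get_config` above always emits the parameter values but includes path keys only when the corresponding field is non-empty, so downstream code can tell "not provided" from an empty string. A minimal sketch of that convention, with the hypothetical helper name `build_config`:

```python
def build_config(required, optional):
    """Merge always-present settings with optional ones, dropping
    optional entries whose value is empty (None, '' and so on)."""
    config = dict(required)
    config.update({key: value for key, value in optional.items() if value})
    return config
```

For example, `build_config({'metric': 'test_r2'}, {'models_dir': '', 'output_path': '/out'})` keeps `metric` and `output_path` and omits `models_dir`, which is the shape the panel's `set_config` expects when it checks `'models_dir' in config`.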

@ -0,0 +1,533 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Step9 面板 - 分布图生成
"""
import os
import traceback
from pathlib import Path
from typing import List, Optional
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
)
from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# Pipeline availability (kept consistent with core/worker_thread.py)
try:
    from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
    PIPELINE_AVAILABLE = True
except ImportError:
    PIPELINE_AVAILABLE = False
class Step9BatchThread(QThread):
    """Thematic maps: batch-generate distribution maps from the prediction CSVs in a folder."""

    finished_ok = pyqtSignal(int)
    failed = pyqtSignal(str)
    log_message = pyqtSignal(str, str)

    def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
        super().__init__()
        self.work_dir = work_dir
        self.csv_paths = csv_paths
        self.step9_kwargs = step9_kwargs
        self.output_dir_optional = (output_dir_optional or "").strip() or None

    def run(self):
        # Remember the active matplotlib backend so it can be restored afterwards.
        mpl_prev = None
        try:
            import matplotlib
            mpl_prev = matplotlib.get_backend()
        except Exception:
            pass
        # Switch to the non-interactive Agg backend while rendering in this worker thread.
        try:
            import matplotlib.pyplot as plt
            plt.switch_backend("Agg")
        except Exception:
            mpl_prev = None
        try:
            from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
            pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
            n = len(self.csv_paths)
            for i, csv_p in enumerate(self.csv_paths):
                self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
                kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
                if self.output_dir_optional:
                    stem = Path(csv_p).stem
                    kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
                else:
                    kw["output_image_path"] = None
                pipeline.step9_generate_distribution_map(**kw)
            self.finished_ok.emit(n)
        except Exception as e:
            self.failed.emit(f"{e}\n{traceback.format_exc()}")
        finally:
            # Restore the previous backend, if one was recorded.
            if mpl_prev:
                try:
                    import matplotlib.pyplot as plt
                    plt.switch_backend(mpl_prev)
                except Exception:
                    pass
class Step9Panel(QWidget):
    """Step 9: distribution map generation."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self._batch_thread = None
        self.init_ui()
    def init_ui(self):
        layout = QVBoxLayout()
        hint = QLabel(
            "独立运行:可选「单个 CSV」或「文件夹批量」扫描目录下所有 .csv。"
            "完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
        )
        hint.setWordWrap(True)
        hint.setStyleSheet(
            f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
        )
        layout.addWidget(hint)
        mode_row = QHBoxLayout()
        self.mode_single_rb = QRadioButton("单个 CSV 文件")
        self.mode_folder_rb = QRadioButton("文件夹批量")
        self._mode_group = QButtonGroup(self)
        self._mode_group.addButton(self.mode_single_rb, 0)
        self._mode_group.addButton(self.mode_folder_rb, 1)
        mode_row.addWidget(self.mode_single_rb)
        mode_row.addWidget(self.mode_folder_rb)
        mode_row.addStretch()
        layout.addLayout(mode_row)
        # ---------- RadioButton styling (checked state is a solid square, matching the main window) ----------
        radio_style = """
            QRadioButton {
                font-size: 14px;
                spacing: 8px;
                color: #333333;
            }
            QRadioButton::indicator {
                width: 16px;
                height: 16px;
                border: 2px solid #999999;
                border-radius: 3px;
                background-color: white;
            }
            QRadioButton::indicator:checked {
                border: 2px solid #0078d4;
                background-color: #0078d4;
                image: none;
            }
            QRadioButton::indicator:hover {
                border: 2px solid #005a9e;
            }
        """
        self.mode_single_rb.setStyleSheet(radio_style)
        self.mode_folder_rb.setStyleSheet(radio_style)
        self.prediction_csv_file = FileSelectWidget(
            "预测结果CSV:",
            "CSV Files (*.csv);;All Files (*.*)"
        )
        layout.addWidget(self.prediction_csv_file)
        folder_row = QHBoxLayout()
        self.prediction_csv_dir_label = QLabel("预测CSV目录:")
        self.prediction_csv_dir_label.setMinimumWidth(120)
        self.prediction_csv_dir_edit = QLineEdit()
        self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
        pred_dir_btn = QPushButton("浏览…")
        pred_dir_btn.setMaximumWidth(80)
        pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
        folder_row.addWidget(self.prediction_csv_dir_label)
        folder_row.addWidget(self.prediction_csv_dir_edit, 1)
        folder_row.addWidget(pred_dir_btn)
        self._folder_row_widget = QWidget()
        self._folder_row_widget.setLayout(folder_row)
        layout.addWidget(self._folder_row_widget)
        self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
        layout.addWidget(self.recursive_csv_cb)
        self.boundary_file = FileSelectWidget(
            "边界文件:",
            "Shapefiles (*.shp);;All Files (*.*)"
        )
        layout.addWidget(self.boundary_file)
        # Generation parameters
        params_group = QGroupBox("生成参数")
        params_layout = QFormLayout()
        self.resolution = QDoubleSpinBox()
        self.resolution.setRange(1, 1000)
        self.resolution.setValue(30)
        params_layout.addRow("分辨率(米):", self.resolution)
        self.input_crs = QLineEdit()
        self.input_crs.setText("EPSG:32651")
        params_layout.addRow("输入坐标系:", self.input_crs)
        self.output_crs = QLineEdit()
        self.output_crs.setText("EPSG:4326")
        params_layout.addRow("输出坐标系:", self.output_crs)
        self.show_points = QCheckBox("显示采样点")
        params_layout.addRow("", self.show_points)
        self.use_diffusion = QCheckBox("启用距离扩散")
        self.use_diffusion.setChecked(True)
        params_layout.addRow("", self.use_diffusion)
        params_group.setLayout(params_layout)
        layout.addWidget(params_group)
        # Output directory
        self.output_dir = FileSelectWidget(
            "输出分布图目录:",
            "Directories;;All Files (*.*)"
        )
        self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
        # Replace the default file-browse handler with a directory picker.
        self.output_dir.browse_btn.clicked.disconnect()
        self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
        layout.addWidget(self.output_dir)
        # Enable this step
        self.enable_checkbox = QCheckBox("启用此步骤")
        self.enable_checkbox.setChecked(True)
        layout.addWidget(self.enable_checkbox)
        # Standalone run button
        self.run_button = QPushButton("独立运行此步骤")
        self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
        self.run_button.clicked.connect(self.run_step)
        layout.addWidget(self.run_button)
        layout.addStretch()
        self.setLayout(layout)
        # Signal wiring and initial state
        self.mode_single_rb.toggled.connect(self._toggle_input_mode)
        self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
        self.mode_single_rb.setChecked(True)  # "Single CSV" is selected by default
        self._toggle_input_mode()  # Set the initial visibility from the default selection
    def _toggle_input_mode(self):
        """Slot: show or hide the input widgets according to the selected radio button."""
        folder_mode = self.mode_folder_rb.isChecked()
        # Single-CSV mode: show the file picker, hide the folder picker
        self.prediction_csv_file.setVisible(not folder_mode)
        # Folder-batch mode: show the folder picker and the recursive option, hide the file picker
        self._folder_row_widget.setVisible(folder_mode)
        self.recursive_csv_cb.setVisible(folder_mode)

    def _get_default_work_dir(self):
        """Return work_dir, preferring the value cached on the panel and falling back to the main window's."""
        if hasattr(self, 'work_dir') and self.work_dir:
            return str(self.work_dir)
        mw = self.window()
        if mw and hasattr(mw, 'work_dir') and mw.work_dir:
            return str(mw.work_dir)
        return ""

    def browse_prediction_csv_dir(self):
        default = self._get_default_work_dir()
        if default:
            default = os.path.join(default, "11_12_13_predictions")
        d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
        if d:
            self.prediction_csv_dir_edit.setText(d)

    def _collect_csv_paths_from_folder(self) -> List[str]:
        folder = (self.prediction_csv_dir_edit.text() or "").strip()
        if not folder or not os.path.isdir(folder):
            return []
        root = Path(folder)
        if self.recursive_csv_cb.isChecked():
            files = sorted(root.rglob("*.csv"))
        else:
            files = sorted(root.glob("*.csv"))
        return [str(p) for p in files if p.is_file()]
    def _step9_base_pipeline_kwargs(self) -> dict:
        return {
            'boundary_shp_path': self.boundary_file.get_path(),
            'resolution': self.resolution.value(),
            'input_crs': self.input_crs.text(),
            'output_crs': self.output_crs.text(),
            'show_sample_points': self.show_points.isChecked(),
            'use_distance_diffusion': self.use_diffusion.isChecked(),
        }

    def get_config(self):
        pred_csv = (self.prediction_csv_file.get_path() or "").strip()
        folder_mode = self.mode_folder_rb.isChecked()
        pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
        config = {
            'step9_batch_mode': 'folder' if folder_mode else 'single',
            'prediction_csv_dir': pred_dir if pred_dir else None,
            'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
            'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
            'boundary_shp_path': self.boundary_file.get_path(),
            'resolution': self.resolution.value(),
            'input_crs': self.input_crs.text(),
            'output_crs': self.output_crs.text(),
            'show_sample_points': self.show_points.isChecked(),
            'use_distance_diffusion': self.use_diffusion.isChecked(),
        }
        out_dir = (self.output_dir.get_path() or "").strip()
        if not folder_mode and pred_csv and out_dir:
            stem = Path(pred_csv).stem
            config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
        else:
            config['output_image_path'] = None
        return config
    def set_config(self, config):
        mode = config.get('step9_batch_mode', 'single')
        if mode == 'folder':
            self.mode_folder_rb.setChecked(True)
        else:
            self.mode_single_rb.setChecked(True)
        if config.get('prediction_csv_dir'):
            self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
        if 'recursive_csv_scan' in config:
            self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
        if 'prediction_csv_path' in config and config['prediction_csv_path']:
            self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
        if 'boundary_shp_path' in config:
            self.boundary_file.set_path(config['boundary_shp_path'])
        if 'resolution' in config:
            self.resolution.setValue(config['resolution'])
        if 'input_crs' in config:
            self.input_crs.setText(config['input_crs'])
        if 'output_crs' in config:
            self.output_crs.setText(config['output_crs'])
        if 'show_sample_points' in config:
            self.show_points.setChecked(config['show_sample_points'])
        if 'use_distance_diffusion' in config:
            self.use_diffusion.setChecked(config['use_distance_diffusion'])
        if 'output_dir' in config and config['output_dir']:
            self.output_dir.set_path(str(config['output_dir']))
        elif config.get('output_image_path'):
            p = Path(str(config['output_image_path']))
            if p.parent and str(p.parent) != '.':
                self.output_dir.set_path(str(p.parent))
    def update_from_config(self, work_dir=None, pipeline=None):
        """Auto-fill the prediction result directory from the global configuration.

        The output directory of Step8 (machine-learning prediction) is preferred as the
        prediction CSV directory; otherwise fall back to the output directory of
        Step8.5 (regression prediction) or Step8.75 (custom regression prediction).

        Args:
            work_dir: Working directory path.
            pipeline: Pipeline instance (unused; kept for interface compatibility).
        """
        try:
            if work_dir:
                self.work_dir = work_dir
            elif not getattr(self, 'work_dir', None):
                self.work_dir = None
            main_window = self.window()
            if not main_window:
                return
            # 1. Try the machine-learning prediction output set in the Step8 panel (highest priority)
            pred_dir = None
            if hasattr(main_window, 'step8_panel'):
                step8_widget = getattr(main_window.step8_panel, 'output_file', None)
                step8_output = ""
                if hasattr(step8_widget, 'get_path'):
                    step8_output = step8_widget.get_path() or ""
                elif hasattr(step8_widget, 'text'):
                    step8_output = step8_widget.text() or ""
                if step8_output:
                    # Resolve a relative path against work_dir
                    if not os.path.isabs(step8_output):
                        step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
                    # Take the parent directory, preferring its Machine_Learning_Prediction
                    # subdirectory (the deepest real output folder) when it exists
                    base_pred_dir = str(Path(step8_output).parent)
                    ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
                    pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
            # 2. Fallback: the non-empirical prediction output set in the Step8.5 panel
            if not pred_dir and hasattr(main_window, 'step8_5_panel'):
                step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
                step8_5_output = ""
                if hasattr(step8_5_widget, 'get_path'):
                    step8_5_output = step8_5_widget.get_path() or ""
                elif hasattr(step8_5_widget, 'text'):
                    step8_5_output = step8_5_widget.text() or ""
                if step8_5_output:
                    # Resolve a relative path against work_dir
                    if not os.path.isabs(step8_5_output):
                        step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
                    pred_dir = str(Path(step8_5_output).parent)
            # 3. Fallback: the custom regression prediction output directory set in the Step8.75 panel
            if not pred_dir and hasattr(main_window, 'step8_75_panel'):
                step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
                step8_75_output = ""
                if hasattr(step8_75_widget, 'get_path'):
                    step8_75_output = step8_75_widget.get_path() or ""
                elif hasattr(step8_75_widget, 'text'):
                    step8_75_output = step8_75_widget.text() or ""
                if step8_75_output:
                    pred_dir = step8_75_output
            # Fill in the prediction CSV directory (folder-batch mode) unless the user already set one
            if pred_dir:
                existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
                if not existing_dir:
                    self.prediction_csv_dir_edit.setText(pred_dir)
                    # Switch to folder-batch mode
                    self.mode_folder_rb.setChecked(True)
            # 4. Auto-fill the output directory (14_visualization)
            if self.work_dir:
                output_dir = os.path.join(self.work_dir, "14_visualization")
                os.makedirs(output_dir, exist_ok=True)
                existing_out = self.output_dir.get_path()
                if not existing_out or not existing_out.strip():
                    self.output_dir.set_path(output_dir)
            # 5. Look for the original vector boundary file (.shp) to use as the map base layer.
            #    input-test/roi.shp is tried first; geopandas.read_file only accepts vector formats here.
            if self.work_dir:
                possible_shp = None
                candidates = [
                    Path(self.work_dir).parent / "input-test" / "roi.shp",
                    Path(self.work_dir) / "roi.shp",
                    Path(self.work_dir).parent / "roi.shp",
                ]
                for candidate in candidates:
                    if candidate.exists() and candidate.suffix.lower() == ".shp":
                        possible_shp = candidate
                        break
                existing_boundary = (self.boundary_file.get_path() or "").strip()
                if not existing_boundary and possible_shp:
                    self.boundary_file.set_path(str(possible_shp))
                elif not existing_boundary:
                    # No .shp found: clear the field and ask the user to pick a vector file manually
                    self.boundary_file.set_path("")
                    print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
        except Exception as e:
            print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
            traceback.print_exc()
    def browse_output_dir(self):
        """Browse for the output directory."""
        default = self._get_default_work_dir()
        if default:
            default = os.path.join(default, "14_visualization")
        dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
        if dir_path:
            self.output_dir.set_path(dir_path)
    def run_step(self):
        """Run step 9 on its own."""
        if self._batch_thread and self._batch_thread.isRunning():
            QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
            return
        boundary_shp_path = self.boundary_file.get_path()
        if not boundary_shp_path:
            QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
            return
        if not os.path.exists(boundary_shp_path):
            QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
            return
        # Walk up the parent chain to the widget that can run a single step.
        parent = self.parent()
        while parent and not hasattr(parent, 'run_single_step'):
            parent = parent.parent()
        if not parent or not hasattr(parent, 'run_single_step'):
            QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
            return
        if self.mode_folder_rb.isChecked():
            csv_list = self._collect_csv_paths_from_folder()
            if not csv_list:
                QMessageBox.warning(
                    self,
                    "输入验证失败",
                    "所选文件夹中未找到 .csv 文件,或目录无效。\n"
                    "可勾选「包含子文件夹」以递归扫描。",
                )
                return
            if not PIPELINE_AVAILABLE:
                QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
                return
            work_dir = getattr(parent, "work_dir", None) or "./work_dir"
            work_dir = str(work_dir)
            base_kw = self._step9_base_pipeline_kwargs()
            out_dir_opt = (self.output_dir.get_path() or "").strip() or None
            self.run_button.setEnabled(False)
            self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
            main_win = parent

            def _batch_log(msg, lvl):
                if hasattr(main_win, "log_message"):
                    main_win.log_message(msg, lvl)

            self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
            self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
            self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
            self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
            self._batch_thread.start()
            if hasattr(parent, "log_message"):
                parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
            return
        prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
        if not prediction_csv_path:
            QMessageBox.warning(
                self,
                "输入验证失败",
                "请选择「预测结果 CSV」文件或切换到「文件夹批量」。",
            )
            return
        if not os.path.isfile(prediction_csv_path):
            QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
            return
        config = self.get_config()
        parent.run_single_step('step9', {'step9': config})
    def _on_step9_batch_ok(self, n: int):
        QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
        parent = self.parent()
        while parent and not hasattr(parent, "log_message"):
            parent = parent.parent()
        if parent and hasattr(parent, "log_message"):
            parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")

    def _on_step9_batch_fail(self, err: str):
        QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
        parent = self.parent()
        while parent and not hasattr(parent, "log_message"):
            parent = parent.parent()
        if parent and hasattr(parent, "log_message"):
            parent.log_message(err, "error")
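
Review note: the folder-batch mode in this panel comes down to two path operations, collecting the CSV files and deriving one output image name per CSV. The sketch below isolates them so they can be exercised without Qt. The helper names `collect_csv_paths` and `distribution_png_path` are hypothetical (they are not part of the panel); they only mirror `_collect_csv_paths_from_folder` and the `{stem}_distribution.png` naming seen in the diff.

```python
from pathlib import Path
from typing import List, Optional


def collect_csv_paths(folder: str, recursive: bool = False) -> List[str]:
    """Return the sorted *.csv files under folder; an empty list if folder is invalid."""
    folder = (folder or "").strip()
    if not folder:
        return []
    root = Path(folder)
    if not root.is_dir():
        return []
    # rglob descends into subdirectories, glob only looks at the top level.
    matches = root.rglob("*.csv") if recursive else root.glob("*.csv")
    return [str(p) for p in sorted(matches) if p.is_file()]


def distribution_png_path(csv_path: str, out_dir: Optional[str]) -> Optional[str]:
    """Build <out_dir>/<csv stem>_distribution.png, or None when no directory is given."""
    if not out_dir:
        return None  # hypothetical convention: the caller decides the default location
    return str(Path(out_dir) / f"{Path(csv_path).stem}_distribution.png")
```

Keeping these two steps free of widget state would also make it straightforward to reuse them from `Step9BatchThread`, should the author want one place for the naming rule.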

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -6,6 +6,7 @@
"""
import os
import sys
import base64
import json
from dataclasses import dataclass
@ -19,6 +20,15 @@ from docx.shared import Inches, Pt, Cm
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.section import WD_SECTION
from docx.oxml.ns import qn
def get_resource_path(relative_path: str) -> str:
    """Return the absolute path of a resource; also works inside a PyInstaller bundle."""
    if hasattr(sys, '_MEIPASS'):
        return os.path.join(sys._MEIPASS, relative_path)
    return os.path.abspath(
        os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), relative_path)
    )
from docx.oxml import OxmlElement
from docx.shared import RGBColor
import pandas as pd
@ -848,8 +858,8 @@ class WaterQualityReportGenerator:
section.different_first_page_header_footer = True
# 1. 左上角图片(增大) - 使用相对路径
cover_top_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "lica.png"
if cover_top_img_path.exists():
cover_top_img_path = get_resource_path("data/icons/word/lica.png")
if os.path.isfile(cover_top_img_path):
try:
p = doc.add_paragraph()
p.alignment = WD_ALIGN_PARAGRAPH.LEFT
@ -897,8 +907,8 @@ class WaterQualityReportGenerator:
# 4. 底部图片(增大) - 使用相对路径
cover_bottom_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "fenmian.png"
if cover_bottom_img_path.exists():
cover_bottom_img_path = get_resource_path("data/icons/word/fenmian.png")
if os.path.isfile(cover_bottom_img_path):
try:
p = doc.add_paragraph()
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
@ -960,8 +970,8 @@ class WaterQualityReportGenerator:
# 第一张图片 - 使用相对路径
img1_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "屏幕截图 2026-03-31 144131.png"
if img1_path.exists():
img1_path = get_resource_path("data/icons/word/屏幕截图 2026-03-31 144131.png")
if os.path.isfile(img1_path):
p = doc.add_paragraph()
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
p.add_run().add_picture(str(img1_path), width=Inches(6.0))
@ -999,8 +1009,8 @@ class WaterQualityReportGenerator:
self._style_heading(h, level=1)
# 插入图片 - 使用相对路径
processing_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "liucheng.png"
if processing_img_path.exists():
processing_img_path = get_resource_path("data/icons/word/liucheng.png")
if os.path.isfile(processing_img_path):
p = doc.add_paragraph()
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
p.add_run().add_picture(str(processing_img_path), width=Inches(6.5))
@ -1356,8 +1366,8 @@ class WaterQualityReportGenerator:
header_para = header.add_paragraph()
# 1. 最左侧图片 - 使用相对路径
header_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "lica.png"
if header_img_path.exists():
header_img_path = get_resource_path("data/icons/word/lica.png")
if os.path.isfile(header_img_path):
try:
run_img = header_para.add_run()
run_img.add_picture(str(header_img_path), width=Inches(1.6))
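
Review note: the hunks above replace `Path(__file__)`-relative lookups with `get_resource_path`, which switches to `sys._MEIPASS` when the program runs from a PyInstaller bundle. A minimal sketch of the same idea follows, with the project root passed in explicitly so the function can be tested outside the repository layout. `resolve_resource` and its `project_root` parameter are hypothetical names, not part of the commit.

```python
import os
import sys


def resolve_resource(relative_path: str, project_root: str) -> str:
    """Resolve a bundled resource for both frozen (PyInstaller) and source runs."""
    # PyInstaller's bootloader sets sys._MEIPASS to the folder holding the unpacked data files;
    # when running from source the attribute is absent and project_root is used instead.
    base = getattr(sys, "_MEIPASS", project_root)
    # Split on "/" so callers can keep writing POSIX-style relative paths on any OS.
    return os.path.join(base, *relative_path.split("/"))
```

This only resolves the path. The image files still have to be included in the bundle at build time (for example through PyInstaller's `--add-data` option), otherwise the `os.path.isfile` checks in the report generator will simply skip them.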


@ -1003,67 +1003,84 @@ class ReportGenerator:
Returns:
保存的文件路径
"""
from modeling_batch import WaterQualityModelingBatch
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
import joblib
modeler = WaterQualityModelingBatch(models_dir)
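# The training results need to be loaded first
# Assume the results are already stored in modeler, or have to be read back from the saved files
# Given how modeling_batch.py is structured, another way is needed to collect all results
# Try walking the models directory to find every saved result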
models_path = Path(models_dir)
all_results = []
# 遍历所有目标参数文件夹
for target_folder in models_path.iterdir():
if not target_folder.is_dir():
continue
target_name = target_folder.name
# 查找所有模型文件
for model_file in target_folder.rglob("*.pkl"):
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl
model_info = {
'target': target_name,
'model_file': str(model_file),
'preprocess': 'Unknown',
'model': 'Unknown',
'split_method': 'Unknown'
# Recursively scan for *.joblib and *.pkl, covering every subdirectory level under artifacts_dir/target_name/
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
for model_file in model_files:
# The target parameter is the immediate parent directory name (matching the artifacts_dir/target_name/ layout)
target_name = model_file.parent.name
stem = model_file.stem
# File name format: {safe_target}_{preprocess}_{model_name}.joblib
# split('_', 2) yields at most 3 parts: one for the target, one for the preprocessing, one for the model
parts = stem.split('_', 2)
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
model_name_str = parts[2] if len(parts) > 2 else stem
# Try to read the metadata from the joblib/pkl file to extract performance metrics
metrics = {}
try:
data = joblib.load(model_file)
metadata = data.get('metadata', {})
metrics = {
'train_r2': metadata.get('train_r2', 'N/A'),
'test_r2': metadata.get('test_r2', 'N/A'),
'test_rmse': metadata.get('test_rmse', 'N/A'),
'train_rmse': metadata.get('train_rmse', 'N/A'),
'train_mae': metadata.get('train_mae', 'N/A'),
'test_mae': metadata.get('test_mae', 'N/A'),
'cv_mean': metadata.get('cv_mean', 'N/A'),
}
# 尝试从文件名解析
parts = model_file.stem.split('_')
if len(parts) >= 3:
model_info['preprocess'] = parts[0]
model_info['model'] = parts[1]
model_info['split_method'] = parts[2]
all_results.append(model_info)
# 如果有训练结果数据,使用实际指标
# 否则创建一个基本的摘要
except Exception:
pass  # On load failure metrics stays an empty dict, so these summary columns show N/A
all_results.append({
'target': target_name,
'model_file': str(model_file),
'preprocess': preprocess,
'model': model_name_str,
**metrics,
})
summary_data = []
for result in all_results:
summary_data.append({
'目标参数': result['target'],
'预处理方法': result['preprocess'],
'模型名称': result['model'],
'划分方法': result['split_method'],
'模型文件': result['model_file']
'模型文件': result['model_file'],
'训练集R²': result.get('train_r2', 'N/A'),
'测试集R²': result.get('test_r2', 'N/A'),
'测试集RMSE': result.get('test_rmse', 'N/A'),
'训练集RMSE': result.get('train_rmse', 'N/A'),
'训练集MAE': result.get('train_mae', 'N/A'),
'测试集MAE': result.get('test_mae', 'N/A'),
'CV均值': result.get('cv_mean', 'N/A'),
})
if not summary_data:
print("警告:未找到模型文件,生成空摘要")
summary_data = [{
'目标参数': 'No Data',
'预处理方法': 'N/A',
'模型名称': 'N/A',
'划分方法': 'N/A',
'模型文件': 'N/A'
'模型文件': 'N/A',
'训练集R²': 'N/A',
'测试集R²': 'N/A',
'测试集RMSE': 'N/A',
'训练集RMSE': 'N/A',
'训练集MAE': 'N/A',
'测试集MAE': 'N/A',
'CV均值': 'N/A',
}]
df_summary = pd.DataFrame(summary_data)
if output_path is None:
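
Review note: the new parsing relies on `stem.split('_', 2)`, which assumes the target name contains no underscore. Since the target is also available as the parent directory name, one alternative is to strip that known prefix first. The sketch below is only an illustration under that assumption (that the directory name equals the `safe_target` prefix of the file name, which the comments suggest but the diff does not confirm). `parse_model_stem` is a hypothetical helper, and the sample names in the usage comment are made up.

```python
from typing import Tuple


def parse_model_stem(stem: str, target: str) -> Tuple[str, str]:
    """Split '<target>_<preprocess>_<model>' into (preprocess, model)."""
    prefix = f"{target}_"
    if target and stem.startswith(prefix):
        # Known target prefix: everything after it is '<preprocess>_<model>'.
        rest = stem[len(prefix):]
    else:
        # Fall back to the diff's behaviour: treat the first token as the target.
        rest = stem.split("_", 1)[1] if "_" in stem else ""
    if not rest:
        return "Unknown", stem
    preprocess, sep, model = rest.partition("_")
    # Like the original code, use the whole stem when no model part is present.
    return preprocess, (model if sep else stem)


# Hypothetical example: a target whose name itself contains an underscore.
# parse_model_stem("Chl_a_SNV_RandomForest", "Chl_a") returns ("SNV", "RandomForest")
```

A preprocessing name that itself contains an underscore would still be split at its first underscore; storing both fields in the model's metadata would avoid parsing file names altogether.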


@ -96,8 +96,14 @@ class BandMathCalculator:
print(f"计算表达式: {calc_expression}")
# 安全地计算表达式
result = eval(calc_expression)
# [Added safeguard] Evaluate with a restricted namespace so that eval can resolve bare nan and inf
import numpy as np
try:
    # A literal "nan" inside calc_expression now resolves to np.nan
    result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
except Exception as e:
    print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
    result = np.nan
# 返回结果
if var_name:
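
Review note: the change above evaluates the substituted formula with an explicit namespace and maps any failure to NaN. The standalone sketch below shows the same pattern using only the standard library (`math.nan` and `math.inf` stand in for the NumPy values, which is a substitution made here so the example runs without NumPy). `eval_band_expression` is a hypothetical name.

```python
import math


def eval_band_expression(expression: str) -> float:
    """Evaluate an already-substituted arithmetic expression; return NaN on any failure."""
    allowed_names = {"nan": math.nan, "inf": math.inf}
    try:
        # An empty __builtins__ mapping hides open(), __import__(), etc. from the expression.
        return float(eval(expression, {"__builtins__": {}}, allowed_names))
    except Exception:
        # Covers syntax errors, unknown names, division by zero and non-numeric results.
        return math.nan
```

Restricting the namespace narrows what a formula can reach, but it is not a real sandbox, since attribute access on literals can still reach arbitrary objects. If formulas can come from untrusted input, parsing them with the `ast` module and allowing only arithmetic nodes would be a safer direction.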


@ -14,17 +14,67 @@ def rasterize_envi_xml(shp_filepath):
@timeit
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=None):
    # ---------- Defensive handling: normalize the path ----------
    shp_filepath = os.path.abspath(shp_filepath).replace('\\', '/')
    print(f"[DEBUG rasterize_shp] 标准化后的 SHP 路径: {shp_filepath}")
    # Check that the companion files are present
    shp_base = os.path.splitext(shp_filepath)[0]
    for ext in ['.dbf', '.shx', '.prj']:
        companion = shp_base + ext
        if os.path.exists(companion):
            print(f"[DEBUG rasterize_shp] 伴随文件存在: {companion}")
        else:
            print(f"[WARNING rasterize_shp] 伴随文件缺失: {companion}")
    # Make sure the GDAL/OGR drivers are registered
    gdal.AllRegister()
    ogr.RegisterAll()
    # Check for the ESRI Shapefile driver
    driver = ogr.GetDriverByName("ESRI Shapefile")
    if driver is None:
        raise RuntimeError(
            "系统中未找到 ESRI Shapefile 驱动!请检查 GDAL 是否正确安装及是否包含 Shapefile 支持。"
        )
    print(f"[DEBUG rasterize_shp] ESRI Shapefile 驱动: {driver.GetName()}")
    # Open the reference image to read its dimensions
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise ValueError(f"无法打开参考影像文件: {img_path}")
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    geotransform = dataset.GetGeoTransform()
    imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
    del dataset
# Open the data source and read in the extent
# ---------- 打开 SHP 文件(双重尝试获取详细错误) ----------
source_ds = gdal.OpenEx(shp_filepath, gdal.OF_VECTOR)
if source_ds is None:
raise ValueError(f"无法打开shapefile: {shp_filepath}")
        # gdal.OpenEx failed; try ogr.Open to obtain a more detailed error message
        try:
            ogr_ds = ogr.Open(shp_filepath)
        except Exception as ogr_err:
            raise RuntimeError(
                f"GDAL/OGR 无法打开 SHP 文件(详细原因):\n"
                f" ogr.Open 抛出异常: {str(ogr_err)}\n"
                f" 文件路径: {shp_filepath}\n"
                f"常见原因:\n"
                f" 1. 路径包含中文/空格/特殊字符(建议复制到纯英文路径下重试)\n"
                f" 2. .dbf 或 .shx 伴随文件缺失或损坏\n"
                f" 3. GDAL 未注册 ESRI Shapefile 驱动\n"
                f" 4. 文件被其他程序锁定"
            )
        if ogr_ds is None:
            raise RuntimeError(
                f"ogr.Open 和 gdal.OpenEx 均返回 None,无法打开 SHP 文件。\n"
                f"文件路径: {shp_filepath}\n"
                f"请检查:\n"
                f" 1. 所有伴随文件(.dbf/.shx/.prj)是否齐全\n"
                f" 2. 文件是否被其他程序占用\n"
                f" 3. 路径中是否存在不支持的字符"
            )
        # ogr.Open succeeded where gdal.OpenEx failed: use the OGR handle so the code below has a valid data source
        source_ds = ogr_ds
# 检查图层数量,如果有多层,指定使用第一层
layer_count = source_ds.GetLayerCount()
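
Review note: the companion-file check in this hunk only prints its findings. If the caller wants to act on the result (for example to abort early with a single message), the check can be factored into a small function that returns the missing paths. A minimal sketch, assuming a hypothetical helper name `missing_sidecars`; it needs only the standard library, not GDAL.

```python
import os
from typing import List, Sequence


def missing_sidecars(shp_path: str, extensions: Sequence[str] = (".dbf", ".shx", ".prj")) -> List[str]:
    """Return the companion files of shp_path that do not exist on disk."""
    base, _ = os.path.splitext(shp_path)
    # A shapefile is a set of files sharing one base name; .shp, .shx and .dbf are mandatory,
    # while .prj (the coordinate system definition) is optional but usually expected.
    return [base + ext for ext in extensions if not os.path.exists(base + ext)]
```

With this, `rasterize_shp` could raise before touching GDAL when `.dbf` or `.shx` is missing and merely warn about `.prj`, which would make the later "cannot open" error less likely to be the first sign of trouble.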


@ -1,10 +1,21 @@
from src.utils.util import *
import math
# -*- coding: utf-8 -*-
"""
Sampling point generation module: chunked sampling and spectral data extraction
"""
import os
import math
# GDAL environment variables (set before anything else to avoid path/encoding problems)
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
os.environ['SHAPE_ENCODING'] = 'UTF-8'
import numpy as np
from osgeo import gdal, ogr
import spectral
from scipy import ndimage
from src.utils.util import write_bands
try:
from skimage import morphology
from skimage.morphology import skeletonize, medial_axis
@ -87,6 +98,12 @@ def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=
ogr.UseExceptions()
try:
# ---------- Normalize the path and check that it exists ----------
bil_file = os.path.abspath(bil_file).replace('\\', '/')
print(f"[路径检查] 去耀斑影像: {bil_file}")
if not os.path.exists(bil_file):
    raise FileNotFoundError(f"【后端错误】无法在磁盘上找到指定的去耀斑影像: {bil_file}")
# 打开bil文件
dataset_bil = gdal.Open(bil_file)
if dataset_bil is None: