Compare commits
48 Commits
master
...
170d347e21
| Author | SHA1 | Date | |
|---|---|---|---|
| 170d347e21 | |||
| bf4237b160 | |||
| cf387c40ab | |||
| 94ed2f1f8d | |||
| 2c52ca19c5 | |||
| 2a4a7ec7be | |||
| 5a55be286f | |||
| 9ba39a7bff | |||
| d15a7a1e2b | |||
| 6d4d802ffe | |||
| abac272b31 | |||
| 95d30d8d81 | |||
| 375fea77b9 | |||
| 8c7c995985 | |||
| f96c55f361 | |||
| 14278739bf | |||
| d0eb458392 | |||
| 605ec86108 | |||
| dcbcc043e4 | |||
| b2b90050dc | |||
| 9d39e61161 | |||
| 82af2d75d3 | |||
| 820986d975 | |||
| a14d40f28d | |||
| 56de4b6fc4 | |||
| 4d23a65a21 | |||
| 27d6db3141 | |||
| 6d6bb6e402 | |||
| d7b5c45dd4 | |||
| 3c0bd29275 | |||
| ca12517d41 | |||
| 33b6a918aa | |||
| 8c7458bbe4 | |||
| 9b9365d823 | |||
| 7cadd7e437 | |||
| f24aa4f555 | |||
| 5af466b2d3 | |||
| a4e6747b54 | |||
| 0f36da742f | |||
| 742bc392a5 | |||
| a645c64987 | |||
| c12b9d8d8a | |||
| dc33ee260d | |||
| 6e51d1482c | |||
| 9cc89bcd69 | |||
| 15cc14b8e1 | |||
| 8d36c23524 | |||
| 71e3aaa8cd |
336
Step1Panel_UI联动优化说明.md
Normal file
336
Step1Panel_UI联动优化说明.md
Normal file
@ -0,0 +1,336 @@
|
||||
# Step1Panel UI 联动逻辑优化说明
|
||||
|
||||
## 📋 修改概览
|
||||
|
||||
本次优化针对 Step1Panel(水域掩膜生成步骤)的 UI 联动逻辑进行了深度重构,主要解决了:
|
||||
1. ✅ 输出掩膜组件与单选按钮的深度绑定
|
||||
2. ✅ 路径显示的斜杠混用问题
|
||||
3. ✅ 底层运行逻辑的兼容性
|
||||
|
||||
---
|
||||
|
||||
## 🎯 核心改进
|
||||
|
||||
### 1. 输出掩膜的显示/隐藏与单选按钮深度绑定
|
||||
|
||||
#### 修改位置:`Step1Panel.update_ui_state()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 输出掩膜在两种模式下都显示,不符合业务逻辑
|
||||
- "使用现有掩膜文件"时不需要指定输出路径
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def update_ui_state(self):
|
||||
"""根据选择的掩膜生成方式更新UI状态(使用显示/隐藏控制)"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
# 动态显示/隐藏组件
|
||||
if use_ndwi:
|
||||
# 使用NDWI模式:隐藏掩膜文件,显示NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(False)
|
||||
self.ndwi_group.setVisible(True)
|
||||
self.output_file.setVisible(True) # 显示输出掩膜路径
|
||||
|
||||
# 当切换到NDWI模式时,如果工作目录已设置,自动填充输出路径
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
self._auto_fill_output_path()
|
||||
else:
|
||||
# 使用现有掩膜模式:显示掩膜文件,隐藏NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(True)
|
||||
self.ndwi_group.setVisible(False)
|
||||
self.output_file.setVisible(False) # 隐藏输出掩膜路径
|
||||
|
||||
# 参考影像在两种模式下都显示
|
||||
self.img_file.setVisible(True)
|
||||
```
|
||||
|
||||
**行为说明**:
|
||||
| 模式 | 掩膜文件 | NDWI参数组 | 输出掩膜 | 参考影像 |
|
||||
|------|---------|-----------|---------|---------|
|
||||
| 使用现有掩膜文件 | ✅ 显示 | ❌ 隐藏 | ❌ 隐藏 | ✅ 显示 |
|
||||
| 使用NDWI自动生成 | ❌ 隐藏 | ✅ 显示 | ✅ 显示 | ✅ 显示 |
|
||||
|
||||
---
|
||||
|
||||
### 2. 修复路径斜杠混用问题
|
||||
|
||||
#### 修改位置 1:`Step1Panel._auto_fill_output_path()` (新增方法)
|
||||
|
||||
**核心改进**:统一使用正斜杠 `/`,避免 Windows 下的 `\` 和 `/` 混用
|
||||
|
||||
```python
|
||||
def _auto_fill_output_path(self):
|
||||
"""
|
||||
自动填充输出掩膜路径(仅在NDWI模式下)
|
||||
确保路径使用正斜杠,避免斜杠混用
|
||||
"""
|
||||
if not hasattr(self, 'work_dir') or not self.work_dir:
|
||||
return
|
||||
|
||||
# 生成输出掩膜的完整路径
|
||||
output_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
os.makedirs(output_dir, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 统一使用正斜杠,避免 \ 和 / 混用
|
||||
default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
```
|
||||
|
||||
**关键技术点**:
|
||||
- 使用 `os.path.join()` 构建路径(适配不同操作系统)
|
||||
- 最终通过 `.replace('\\', '/')` 统一转换为正斜杠
|
||||
- 在界面显示前完成转换,确保用户看到的路径一致
|
||||
|
||||
#### 修改位置 2:`Step1Panel.update_work_directory()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 开机时直接填充路径,不考虑当前选择的模式
|
||||
- 没有斜杠格式化
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def update_work_directory(self, work_dir):
|
||||
"""
|
||||
保存工作目录引用,用于后续自动填充路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
"""
|
||||
if not work_dir:
|
||||
return
|
||||
|
||||
# 保存工作目录引用
|
||||
self.work_dir = work_dir
|
||||
|
||||
# 如果当前选中的是NDWI模式,立即填充输出路径
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
self._auto_fill_output_path()
|
||||
```
|
||||
|
||||
**改进说明**:
|
||||
- 只保存工作目录引用,不立即填充
|
||||
- 仅在 NDWI 模式下才调用 `_auto_fill_output_path()`
|
||||
- 配合 `update_ui_state()` 中的切换触发逻辑
|
||||
|
||||
---
|
||||
|
||||
### 3. 底层运行逻辑的兼容性保障
|
||||
|
||||
#### 修改位置:`Step1Panel.get_config()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 无论哪种模式,都传递 `output_path` 给底层 Pipeline
|
||||
- "使用现有掩膜"模式下传递空路径可能导致底层错误
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
config = {
|
||||
'mask_path': None if use_ndwi else self.mask_file.get_path(),
|
||||
'use_ndwi': use_ndwi,
|
||||
'ndwi_threshold': self.ndwi_threshold.value()
|
||||
}
|
||||
|
||||
# 参考影像路径(两种模式都可能需要)
|
||||
img_path = self.img_file.get_path()
|
||||
if img_path:
|
||||
config['img_path'] = img_path
|
||||
|
||||
# 输出路径:仅在NDWI模式下有效
|
||||
if use_ndwi:
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
else:
|
||||
# 使用现有掩膜时,不传递output_path,避免底层错误尝试保存文件
|
||||
config['output_path'] = None
|
||||
|
||||
return config
|
||||
```
|
||||
|
||||
**关键改进**:
|
||||
- 根据 `use_ndwi` 模式动态决定是否传递 `output_path`
|
||||
- "使用现有掩膜"模式:强制 `output_path = None`
|
||||
- "NDWI自动生成"模式:传递用户选择的路径
|
||||
|
||||
**底层兼容性**:
|
||||
```python
|
||||
# Pipeline 中的处理逻辑(已在之前的提交中实现)
|
||||
def step1_generate_water_mask(..., output_path: Optional[str] = None):
|
||||
if use_ndwi:
|
||||
if output_path:
|
||||
ndwi_output_path = output_path # 使用用户指定路径
|
||||
else:
|
||||
ndwi_output_path = str(self.water_mask_dir / "water_mask_from_ndwi.dat")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 主窗口初始化逻辑优化
|
||||
|
||||
#### 修改位置:`WaterQualityGUI._auto_fill_output_paths()`
|
||||
|
||||
**原逻辑问题**:
|
||||
- 开机时直接调用 `set_path()` 填充输出掩膜路径
|
||||
- 不考虑当前的单选按钮状态
|
||||
|
||||
**新逻辑**:
|
||||
```python
|
||||
def _auto_fill_output_paths(self):
|
||||
"""
|
||||
根据工作目录自动填充各步骤的输出路径
|
||||
注意:Step1 的输出路径由 update_work_directory() 根据模式自动控制
|
||||
"""
|
||||
if not self.work_dir:
|
||||
return
|
||||
|
||||
# Step1: 只传递工作目录引用,不直接填充路径
|
||||
# 路径填充由 Step1Panel 根据单选按钮状态自动控制
|
||||
if hasattr(self, 'step1_panel'):
|
||||
self.step1_panel.update_work_directory(self.work_dir)
|
||||
```
|
||||
|
||||
**改进说明**:
|
||||
- 主窗口只传递工作目录引用
|
||||
- Step1Panel 内部根据模式自主决定是否填充路径
|
||||
- 解耦主窗口和子面板的逻辑依赖
|
||||
|
||||
---
|
||||
|
||||
## 🔄 完整的交互流程
|
||||
|
||||
### 场景 1:开机启动(默认选中"使用现有掩膜文件")
|
||||
|
||||
```
|
||||
1. 主窗口启动 → QTimer.singleShot(100) 延迟弹窗
|
||||
2. 用户选择工作目录 D:\work
|
||||
3. _auto_fill_output_paths() 调用 step1_panel.update_work_directory(work_dir)
|
||||
4. Step1Panel.update_work_directory() 保存 self.work_dir = "D:\work"
|
||||
5. 检查当前模式:use_existing_radio.isChecked() = True
|
||||
6. 不调用 _auto_fill_output_path(),输出掩膜保持隐藏
|
||||
7. 用户看到的界面:
|
||||
✅ 掩膜文件输入框(显示)
|
||||
✅ 参考影像输入框(显示)
|
||||
❌ NDWI参数组(隐藏)
|
||||
❌ 输出掩膜输入框(隐藏)
|
||||
```
|
||||
|
||||
### 场景 2:用户切换到"使用NDWI自动生成"
|
||||
|
||||
```
|
||||
1. 用户点击"使用NDWI自动生成"单选按钮
|
||||
2. 触发 toggled 信号 → update_ui_state()
|
||||
3. use_ndwi = True
|
||||
4. 执行显示/隐藏逻辑:
|
||||
- self.mask_file.setVisible(False) # 隐藏掩膜文件
|
||||
- self.ndwi_group.setVisible(True) # 显示NDWI参数组
|
||||
- self.output_file.setVisible(True) # 显示输出掩膜
|
||||
5. 检查工作目录:hasattr(self, 'work_dir') = True
|
||||
6. 调用 self._auto_fill_output_path()
|
||||
7. 生成路径:
|
||||
output_dir = os.path.join("D:\work", "1_water_mask")
|
||||
path = os.path.join(output_dir, "water_mask_out.dat")
|
||||
formatted_path = path.replace('\\', '/')
|
||||
# 结果:D:/work/1_water_mask/water_mask_out.dat
|
||||
8. 自动填充到输出掩膜输入框
|
||||
9. 用户看到的界面:
|
||||
❌ 掩膜文件输入框(隐藏)
|
||||
✅ 参考影像输入框(显示)
|
||||
✅ NDWI参数组(显示)
|
||||
✅ 输出掩膜输入框(显示,已填充:D:/work/1_water_mask/water_mask_out.dat)
|
||||
```
|
||||
|
||||
### 场景 3:用户点击"独立运行此步骤"
|
||||
|
||||
#### 当前选择:"使用现有掩膜文件"
|
||||
```python
|
||||
config = {
|
||||
'mask_path': "D:/data/existing_mask.dat", # 用户选择的现有掩膜
|
||||
'use_ndwi': False,
|
||||
'ndwi_threshold': 0.4,
|
||||
'img_path': "D:/data/image.dat",
|
||||
'output_path': None # ✅ 强制为 None,不尝试保存
|
||||
}
|
||||
```
|
||||
|
||||
#### 当前选择:"使用NDWI自动生成"
|
||||
```python
|
||||
config = {
|
||||
'mask_path': None, # NDWI模式不需要现有掩膜
|
||||
'use_ndwi': True,
|
||||
'ndwi_threshold': 0.4,
|
||||
'img_path': "D:/data/image.dat",
|
||||
'output_path': "D:/work/1_water_mask/water_mask_out.dat" # ✅ 传递输出路径
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 测试检查点
|
||||
|
||||
### UI 显示测试
|
||||
- [ ] 开机启动后,默认选中"使用现有掩膜文件",输出掩膜输入框应隐藏
|
||||
- [ ] 切换到"使用NDWI自动生成",输出掩膜输入框应显示,并自动填充路径
|
||||
- [ ] 切换回"使用现有掩膜文件",输出掩膜输入框应再次隐藏
|
||||
- [ ] 所有自动填充的路径应使用正斜杠 `/`,无 `\` 混用
|
||||
|
||||
### 路径格式测试
|
||||
- [ ] 工作目录:`D:\work` → 输出路径应显示为:`D:/work/1_water_mask/water_mask_out.dat`
|
||||
- [ ] 工作目录:`C:\Users\Test\Documents` → 输出路径应显示为:`C:/Users/Test/Documents/1_water_mask/water_mask_out.dat`
|
||||
|
||||
### 运行逻辑测试
|
||||
- [ ] "使用现有掩膜"模式运行:验证 `config['output_path'] == None`
|
||||
- [ ] "NDWI自动生成"模式运行:验证 `config['output_path']` 为有效路径字符串
|
||||
- [ ] 底层 Pipeline 接收 `output_path=None` 时不报错
|
||||
|
||||
---
|
||||
|
||||
## 📝 代码修改总结
|
||||
|
||||
| 文件 | 修改内容 | 行数变化 |
|
||||
|------|---------|---------|
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.update_ui_state() | +6 / -3 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.update_work_directory() | +10 / -8 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel._auto_fill_output_path() (新增) | +15 / 0 |
|
||||
| `src/gui/water_quality_gui.py` | Step1Panel.get_config() | +12 / -6 |
|
||||
| `src/gui/water_quality_gui.py` | WaterQualityGUI._auto_fill_output_paths() | +3 / -4 |
|
||||
|
||||
**总计**:约 **+46 / -21** 行
|
||||
|
||||
---
|
||||
|
||||
## 🎯 优化效果
|
||||
|
||||
### 用户体验提升
|
||||
1. ✅ UI 更简洁:不需要的组件自动隐藏
|
||||
2. ✅ 路径一致性:所有路径显示统一使用正斜杠
|
||||
3. ✅ 自动化程度提高:切换模式时自动填充/清空路径
|
||||
|
||||
### 代码质量提升
|
||||
1. ✅ 职责分离:主窗口不直接操作子面板的路径填充
|
||||
2. ✅ 逻辑内聚:Step1Panel 内部自主管理显示和路径
|
||||
3. ✅ 兼容性保障:底层 Pipeline 不会收到无效的 output_path
|
||||
|
||||
### 维护性提升
|
||||
1. ✅ 新增 `_auto_fill_output_path()` 方法,单一职责
|
||||
2. ✅ 路径格式化逻辑集中在一处,便于修改
|
||||
3. ✅ 注释清晰,说明了各模式下的行为
|
||||
|
||||
---
|
||||
|
||||
## 🔧 后续可能的扩展
|
||||
|
||||
如果其他步骤也有类似的路径填充需求,可以考虑:
|
||||
1. 提取公共方法 `format_path_separator(path)` 到工具类
|
||||
2. 在 `FileSelectWidget` 类中增加路径格式化的内置支持
|
||||
3. 为所有路径输入框添加统一的验证和格式化逻辑
|
||||
|
||||
---
|
||||
|
||||
**文档生成时间**: 2026-05-06
|
||||
**修改人员**: DXC
|
||||
**关联提交**: (待提交)
|
||||
BIN
data/icons/Mega Water 1.0.jpg
Normal file
BIN
data/icons/Mega Water 1.0.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 30 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 3.0 MiB |
@ -1,58 +1,26 @@
|
||||
# 水质参数反演分析系统 - Python 依赖
|
||||
# 安装: pip install -r requirements.txt
|
||||
#
|
||||
# 说明:
|
||||
# - Windows 下 GDAL 若 pip 安装失败,建议用 conda-forge: conda install -c conda-forge gdal
|
||||
# 或使用已编译的 GDAL wheel / OSGeo4W,并保证与 rasterio 版本匹配。
|
||||
# - Word 报告(report_word)与 GUI「报告生成」页依赖 python-docx;AI 解读走 Ollama HTTP API,
|
||||
# 无需额外 pip 包(本地或远程部署 Ollama 即可)。
|
||||
|
||||
# ---------- GUI ----------
|
||||
PyQt5>=5.15.0
|
||||
|
||||
# ---------- 科学计算 ----------
|
||||
# 注:当前工程打包/运行日志显示使用 Python 3.12,因此下限按 Py3.12 兼容版本设置
|
||||
numpy>=1.26.0
|
||||
scipy>=1.11.0
|
||||
pandas>=2.0.0
|
||||
|
||||
# ---------- 机器学习 ----------
|
||||
scikit-learn>=1.4.0
|
||||
# xgboost>=2.0.0 # 可选;仅在环境已安装时 spec 会自动打入
|
||||
# lightgbm>=4.0.0 # 可选;当前流水线默认未启用
|
||||
|
||||
# ---------- 地理空间 ----------
|
||||
rasterio>=1.3.9
|
||||
fiona>=1.9.5
|
||||
shapely>=2.0.0
|
||||
geopandas>=0.14.0
|
||||
pyproj>=3.6.0
|
||||
spectral>=0.22.0
|
||||
|
||||
# ---------- 图像 ----------
|
||||
opencv-python>=4.5.0
|
||||
Pillow>=8.0.0
|
||||
scikit-image>=0.22.0
|
||||
|
||||
# ---------- 可视化 ----------
|
||||
matplotlib>=3.8.0
|
||||
seaborn>=0.11.0
|
||||
matplotlib-scalebar>=0.8.0
|
||||
|
||||
# ---------- 信号处理 ----------
|
||||
PyWavelets>=1.1.0
|
||||
|
||||
# ---------- 通用工具 ----------
|
||||
joblib>=1.1.0
|
||||
tqdm>=4.62.0
|
||||
PyYAML>=6.0
|
||||
|
||||
# ---------- 表格导出(.xlsx)----------
|
||||
openpyxl>=3.0.0
|
||||
|
||||
# ---------- Word 报告生成 ----------
|
||||
python-docx>=1.1.0
|
||||
lxml>=4.9.0
|
||||
|
||||
# ---------- 打包(可选,仅构建 exe 时需要)----------
|
||||
pyinstaller>=6.0.0
|
||||
pykrige>=1.7.3
|
||||
@ -5,11 +5,12 @@ import sys
|
||||
def _safe_add(path: str) -> None:
|
||||
if not path or not os.path.isdir(path):
|
||||
return
|
||||
try:
|
||||
if hasattr(os, "add_dll_directory"):
|
||||
if hasattr(os, "add_dll_directory"):
|
||||
try:
|
||||
os.add_dll_directory(path)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.environ["PATH"] = path + os.pathsep + os.environ.get("PATH", "")
|
||||
except Exception:
|
||||
@ -21,5 +22,4 @@ base = getattr(sys, "_MEIPASS", None)
|
||||
if base:
|
||||
_safe_add(base)
|
||||
_safe_add(os.path.join(base, "lib-dynload"))
|
||||
_safe_add(os.path.join(base, "DLLs"))
|
||||
|
||||
_safe_add(os.path.join(base, "DLLs"))
|
||||
4
src/auth/__init__.py
Normal file
4
src/auth/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
授权认证模块
|
||||
"""
|
||||
223
src/auth/keygen_gui.py
Normal file
223
src/auth/keygen_gui.py
Normal file
@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mega Water - 离线授权发卡器 (开发者专用)
|
||||
生成绑定特定机器码的 .lic 授权文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 确保 src.auth 在 path 中
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
_project_root = os.path.abspath(os.path.join(_current_dir, "..", ".."))
|
||||
if _project_root not in sys.path:
|
||||
sys.path.insert(0, _project_root)
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit,
|
||||
QPushButton, QFileDialog, QMessageBox, QApplication, QDateEdit, QCheckBox
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QDate
|
||||
|
||||
from src.auth.license_manager import generate_license
|
||||
|
||||
# 永久授权的标识日期
|
||||
PERMANENT_EXPIRY = "2099-12-31"
|
||||
|
||||
|
||||
class LicenseKeygenWindow(QWidget):
|
||||
"""授权发卡器主窗口"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Mega Water - 离线授权发卡器 (开发者专用)")
|
||||
self.setMinimumSize(640, 360)
|
||||
self.move(400, 280)
|
||||
|
||||
self._default_save_path = os.path.join(_project_root, "license.lic")
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
# ── 全局字体:无衬线,清晰 ──
|
||||
font_family = "Microsoft YaHei" if sys.platform == "win32" else "Segoe UI"
|
||||
self.setStyleSheet(f"""
|
||||
* {{
|
||||
font-family: {font_family}, 'Segoe UI', sans-serif;
|
||||
font-size: 11pt;
|
||||
}}
|
||||
QLabel#titleLabel {{
|
||||
font-size: 16pt;
|
||||
font-weight: bold;
|
||||
color: #2c3e50;
|
||||
}}
|
||||
QLabel#tipLabel {{
|
||||
font-size: 10pt;
|
||||
color: #95a5a6;
|
||||
}}
|
||||
""")
|
||||
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.setContentsMargins(45, 40, 45, 40)
|
||||
main_layout.setSpacing(18)
|
||||
|
||||
# ── 标题 ──
|
||||
title_label = QLabel("离线授权发卡器 (开发者专用)")
|
||||
title_label.setObjectName("titleLabel")
|
||||
title_label.setAlignment(Qt.AlignCenter)
|
||||
main_layout.addWidget(title_label)
|
||||
|
||||
# ── 机器码输入行 ──
|
||||
mc_layout = QHBoxLayout()
|
||||
mc_layout.setSpacing(12)
|
||||
mc_label = QLabel("机器码:")
|
||||
mc_label.setFixedWidth(90)
|
||||
self.mc_input = QLineEdit()
|
||||
self.mc_input.setPlaceholderText("粘贴用户发来的 32 位机器码")
|
||||
self.mc_input.setMinimumHeight(36)
|
||||
self.mc_input.setMinimumWidth(400)
|
||||
mc_layout.addWidget(mc_label, 0)
|
||||
mc_layout.addWidget(self.mc_input, 1)
|
||||
main_layout.addLayout(mc_layout)
|
||||
|
||||
# ── 到期时间选择行 ──
|
||||
exp_layout = QHBoxLayout()
|
||||
exp_layout.setSpacing(14)
|
||||
exp_label = QLabel("到期时间:")
|
||||
exp_label.setFixedWidth(90)
|
||||
self.exp_edit = QDateEdit()
|
||||
self.exp_edit.setCalendarPopup(True)
|
||||
self.exp_edit.setMinimumHeight(36)
|
||||
self.exp_edit.setMinimumWidth(160)
|
||||
self.exp_edit.setDate(QDate.currentDate().addYears(1))
|
||||
|
||||
self.perm_check = QCheckBox("永久授权 (不限时)")
|
||||
self.perm_check.setMinimumHeight(36)
|
||||
self.perm_check.stateChanged.connect(self._on_perm_changed)
|
||||
|
||||
exp_layout.addWidget(exp_label, 0)
|
||||
exp_layout.addWidget(self.exp_edit, 0)
|
||||
exp_layout.addWidget(self.perm_check, 0)
|
||||
exp_layout.addStretch(1)
|
||||
main_layout.addLayout(exp_layout)
|
||||
|
||||
# ── 保存路径行 ──
|
||||
path_layout = QHBoxLayout()
|
||||
path_layout.setSpacing(12)
|
||||
path_label = QLabel("保存路径:")
|
||||
path_label.setFixedWidth(90)
|
||||
self.path_input = QLineEdit()
|
||||
self.path_input.setReadOnly(True)
|
||||
self.path_input.setMinimumHeight(36)
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMinimumHeight(36)
|
||||
self.browse_btn.setFixedWidth(80)
|
||||
self.browse_btn.clicked.connect(self._on_browse)
|
||||
path_layout.addWidget(path_label, 0)
|
||||
path_layout.addWidget(self.path_input, 1)
|
||||
path_layout.addWidget(self.browse_btn, 0)
|
||||
main_layout.addLayout(path_layout)
|
||||
|
||||
# ── 弹性空间 ──
|
||||
main_layout.addSpacing(10)
|
||||
|
||||
# ── 生成按钮 ──
|
||||
self.gen_btn = QPushButton("生成授权文件 (.lic)")
|
||||
self.gen_btn.setMinimumHeight(48)
|
||||
self.gen_btn.setStyleSheet("""
|
||||
QPushButton {
|
||||
background-color: #27ae60;
|
||||
color: white;
|
||||
font-size: 13pt;
|
||||
font-weight: bold;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background-color: #2ecc71;
|
||||
}
|
||||
QPushButton:pressed {
|
||||
background-color: #1e8449;
|
||||
}
|
||||
""")
|
||||
self.gen_btn.clicked.connect(self._on_generate)
|
||||
main_layout.addWidget(self.gen_btn)
|
||||
|
||||
# ── 底部提示 ──
|
||||
tip_label = QLabel("生成后请将 license.lic 文件发给用户,放置到软件安装目录下即可。")
|
||||
tip_label.setObjectName("tipLabel")
|
||||
tip_label.setAlignment(Qt.AlignCenter)
|
||||
main_layout.addWidget(tip_label)
|
||||
|
||||
self.setLayout(main_layout)
|
||||
|
||||
def _on_perm_changed(self, state):
|
||||
"""永久授权复选框状态变化时,联动日期选择器"""
|
||||
if state == Qt.Checked:
|
||||
self.exp_edit.setEnabled(False)
|
||||
else:
|
||||
self.exp_edit.setEnabled(True)
|
||||
|
||||
def _on_browse(self):
|
||||
"""打开文件对话框选择保存路径"""
|
||||
path, _ = QFileDialog.getSaveFileName(
|
||||
self,
|
||||
"选择授权文件保存位置",
|
||||
self._default_save_path,
|
||||
"授权文件 (*.lic)"
|
||||
)
|
||||
if path:
|
||||
if not path.lower().endswith(".lic"):
|
||||
path += ".lic"
|
||||
self.path_input.setText(path)
|
||||
|
||||
def _on_generate(self):
|
||||
"""点击生成按钮,调用授权管理器"""
|
||||
machine_code = self.mc_input.text().strip()
|
||||
if not machine_code:
|
||||
QMessageBox.warning(self, "输入错误", "请输入机器码")
|
||||
return
|
||||
|
||||
output_path = self.path_input.text().strip()
|
||||
if not output_path:
|
||||
QMessageBox.warning(self, "输入错误", "请设置保存路径")
|
||||
return
|
||||
|
||||
# 根据是否勾选永久授权决定日期
|
||||
if self.perm_check.isChecked():
|
||||
expiry_date = PERMANENT_EXPIRY
|
||||
else:
|
||||
expiry_date = self.exp_edit.date().toString("yyyy-MM-dd")
|
||||
|
||||
ok, msg = generate_license(
|
||||
machine_code=machine_code,
|
||||
output_path=output_path,
|
||||
expiry_date=expiry_date
|
||||
)
|
||||
|
||||
if ok:
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"生成成功",
|
||||
f"✅ 授权文件已成功生成!\n\n保存路径:\n{output_path}\n\n请将此文件发给用户即可。",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
else:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"生成失败",
|
||||
f"❌ {msg}",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ── 高 DPI 自适应(必须放在 QApplication 实例化之前)──
|
||||
from PyQt5.QtCore import Qt
|
||||
QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
|
||||
QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True)
|
||||
|
||||
app = QApplication(sys.argv)
|
||||
app.setApplicationName("LicenseKeygen")
|
||||
window = LicenseKeygenWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
||||
255
src/auth/license_dialog.py
Normal file
255
src/auth/license_dialog.py
Normal file
@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LicenseDialog - PyQt5 授权拦截弹窗
|
||||
当授权验证失败时弹出,提示用户导入授权文件。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
|
||||
QTextEdit, QFileDialog, QMessageBox, QApplication
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QTimer
|
||||
from PyQt5.QtGui import QFont, QIcon, QGuiApplication
|
||||
|
||||
# 导入授权管理器
|
||||
from src.auth.license_manager import get_machine_code, verify_license, get_license_path
|
||||
|
||||
|
||||
class LicenseDialog(QDialog):
|
||||
"""
|
||||
授权验证弹窗
|
||||
- 显示本机机器码(只读文本框)
|
||||
- 提供"一键复制"功能
|
||||
- 提供"导入授权文件"按钮
|
||||
- 导入成功后提示重启
|
||||
"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setWindowTitle("授权验证")
|
||||
self.setWindowFlags(
|
||||
Qt.Dialog |
|
||||
Qt.WindowTitleHint |
|
||||
Qt.WindowCloseButtonHint
|
||||
)
|
||||
self.setModal(True)
|
||||
self.setMinimumWidth(540)
|
||||
|
||||
self._init_ui()
|
||||
self._load_machine_code()
|
||||
|
||||
# 窗口居中
|
||||
QTimer.singleShot(0, self._center_on_screen)
|
||||
|
||||
def _center_on_screen(self):
|
||||
"""将窗口居中到屏幕"""
|
||||
screen = QGuiApplication.primaryScreen()
|
||||
if screen:
|
||||
geo = screen.geometry()
|
||||
self.move(
|
||||
(geo.width() - self.width()) // 2,
|
||||
(geo.height() - self.height()) // 2
|
||||
)
|
||||
|
||||
def _init_ui(self):
|
||||
main_layout = QVBoxLayout(self)
|
||||
main_layout.setContentsMargins(30, 30, 30, 20)
|
||||
main_layout.setSpacing(16)
|
||||
|
||||
# ── 标题区 ──
|
||||
title_font = QFont("Microsoft YaHei", 14, QFont.Bold)
|
||||
title_label = QLabel("本软件需要授权方可运行")
|
||||
title_label.setFont(title_font)
|
||||
title_label.setAlignment(Qt.AlignCenter)
|
||||
title_label.setStyleSheet("color: #2c3e50;")
|
||||
main_layout.addWidget(title_label)
|
||||
|
||||
# ── 说明文字 ──
|
||||
hint_label = QLabel(
|
||||
"请获取授权文件(license.lic)后导入,"
|
||||
"或联系技术支持获取授权。"
|
||||
)
|
||||
hint_label.setAlignment(Qt.AlignCenter)
|
||||
hint_label.setStyleSheet("color: #7f8c8d; font-size: 12px;")
|
||||
main_layout.addWidget(hint_label)
|
||||
|
||||
# ── 机器码标签 ──
|
||||
code_label = QLabel("本机机器码(用于申请授权):")
|
||||
code_label.setStyleSheet("font-weight: bold; color: #34495e;")
|
||||
main_layout.addWidget(code_label)
|
||||
|
||||
# ── 机器码文本框 + 复制按钮 ──
|
||||
code_layout = QHBoxLayout()
|
||||
code_layout.setSpacing(8)
|
||||
|
||||
self.code_edit = QTextEdit()
|
||||
self.code_edit.setReadOnly(True)
|
||||
self.code_edit.setMaximumHeight(72)
|
||||
self.code_edit.setFont(QFont("Consolas", 13))
|
||||
self.code_edit.setStyleSheet(
|
||||
"QTextEdit {"
|
||||
" background-color: #ecf0f1;"
|
||||
" border: 1px solid #bdc3c7;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px;"
|
||||
" color: #2c3e50;"
|
||||
"}"
|
||||
)
|
||||
code_layout.addWidget(self.code_edit, 1)
|
||||
|
||||
copy_btn = QPushButton("复制")
|
||||
copy_btn.setFixedWidth(72)
|
||||
copy_btn.setCursor(Qt.PointingHandCursor)
|
||||
copy_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #3498db;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px 4px;"
|
||||
" font-weight: bold;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #2980b9; }"
|
||||
"QPushButton:pressed { background-color: #21618c; }"
|
||||
)
|
||||
copy_btn.clicked.connect(self._copy_code)
|
||||
code_layout.addWidget(copy_btn)
|
||||
|
||||
main_layout.addLayout(code_layout)
|
||||
|
||||
# ── 导入授权文件按钮 ──
|
||||
import_btn = QPushButton("导入授权文件 (.lic)")
|
||||
import_btn.setCursor(Qt.PointingHandCursor)
|
||||
import_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #27ae60;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 6px;"
|
||||
" padding: 12px;"
|
||||
" font-size: 14px;"
|
||||
" font-weight: bold;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #229954; }"
|
||||
"QPushButton:pressed { background-color: #1e8449; }"
|
||||
)
|
||||
import_btn.clicked.connect(self._import_license)
|
||||
main_layout.addWidget(import_btn)
|
||||
|
||||
# ── 提示文字 ──
|
||||
tip_label = QLabel(
|
||||
"导入后软件将自动重启生效。"
|
||||
)
|
||||
tip_label.setAlignment(Qt.AlignCenter)
|
||||
tip_label.setStyleSheet("color: #95a5a6; font-size: 11px;")
|
||||
main_layout.addWidget(tip_label)
|
||||
|
||||
# ── 取消按钮(退出程序)──
|
||||
cancel_btn = QPushButton("退出")
|
||||
cancel_btn.setCursor(Qt.PointingHandCursor)
|
||||
cancel_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #95a5a6;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px 20px;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #7f8c8d; }"
|
||||
)
|
||||
cancel_btn.clicked.connect(self._quit_app)
|
||||
main_layout.addWidget(cancel_btn, 0, Qt.AlignRight)
|
||||
|
||||
main_layout.addStretch()
|
||||
|
||||
def _load_machine_code(self):
|
||||
"""读取并显示本机机器码"""
|
||||
try:
|
||||
code = get_machine_code(32)
|
||||
self.code_edit.setPlainText(code)
|
||||
except Exception as e:
|
||||
self.code_edit.setPlainText(f"读取失败: {e}")
|
||||
|
||||
def _copy_code(self):
|
||||
"""复制机器码到剪贴板"""
|
||||
clipboard = QApplication.clipboard()
|
||||
clipboard.setText(self.code_edit.toPlainText().strip())
|
||||
|
||||
# 显示反馈
|
||||
QMessageBox.information(self, "已复制", "机器码已复制到剪贴板。")
|
||||
|
||||
def _import_license(self):
|
||||
"""打开文件选择对话框,导入 .lic 授权文件"""
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self,
|
||||
"选择授权文件",
|
||||
"",
|
||||
"授权文件 (*.lic);;所有文件 (*.*)"
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
return
|
||||
|
||||
# 验证授权文件
|
||||
ok, msg = verify_license(file_path)
|
||||
if not ok:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"授权文件无效",
|
||||
f"验证失败: {msg}\n\n请确认选择了正确的授权文件。"
|
||||
)
|
||||
return
|
||||
|
||||
# 复制授权文件到标准路径
|
||||
dest_path = get_license_path()
|
||||
try:
|
||||
dest_dir = os.path.dirname(dest_path)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
shutil.copy2(file_path, dest_path)
|
||||
except OSError as e:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"保存失败",
|
||||
f"无法保存授权文件: {e}"
|
||||
)
|
||||
return
|
||||
|
||||
# 成功提示,重启程序
|
||||
reply = QMessageBox.information(
|
||||
self,
|
||||
"导入成功",
|
||||
"授权文件已成功导入。\n\n软件将自动重启以应用授权。"
|
||||
if False else # 占位,维持下面的逻辑
|
||||
"授权文件已成功导入。\n软件将自动重启以应用授权。",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
|
||||
self.accept()
|
||||
self._restart_app()
|
||||
|
||||
def _quit_app(self):
|
||||
"""退出程序"""
|
||||
self.reject()
|
||||
sys.exit(0)
|
||||
|
||||
def _restart_app(self):
|
||||
"""重启程序"""
|
||||
self.close()
|
||||
QApplication.quit()
|
||||
|
||||
# 延迟重启(确保 QApplication 完全退出)
|
||||
import subprocess
|
||||
import sys as _sys
|
||||
executable = _sys.executable
|
||||
if getattr(_sys, 'frozen', False):
|
||||
# PyInstaller 打包环境下
|
||||
subprocess.Popen([executable] + _sys.argv[1:])
|
||||
else:
|
||||
# 开发环境
|
||||
subprocess.Popen([executable, __file__])
|
||||
328
src/auth/license_manager.py
Normal file
328
src/auth/license_manager.py
Normal file
@ -0,0 +1,328 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
License Manager - 离线授权管理模块
|
||||
使用 HMAC-SHA256 + 盐值签名防止篡改
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import uuid
|
||||
import hashlib as _hashlib
|
||||
import subprocess
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
# ============================================================
|
||||
# 第一部分:硬件指纹提取(内嵌 get_machine_code)
|
||||
# ============================================================
|
||||
|
||||
def get_cpu_id() -> Optional[str]:
|
||||
"""读取 CPU 序列号(Processor ID)"""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
result = subprocess.run(
|
||||
["wmic", "cpu", "get", "ProcessorId"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
cpu_id = result.stdout.strip().split("\n")[-1].strip()
|
||||
if cpu_id:
|
||||
return cpu_id
|
||||
else:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
for line in f:
|
||||
if "Serial" in line or "processor" in line:
|
||||
cpu_id = line.split(":")[-1].strip()
|
||||
if cpu_id:
|
||||
return cpu_id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_motherboard_uuid() -> Optional[str]:
|
||||
"""读取主板 UUID(BaseBoard Serial Number)"""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
result = subprocess.run(
|
||||
["wmic", "baseboard", "get", "SerialNumber"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
board_uuid = result.stdout.strip().split("\n")[-1].strip()
|
||||
board_uuid = re.sub(r'[^a-zA-Z0-9\-]', '', board_uuid)
|
||||
if board_uuid and board_uuid not in ("To be filled", "None"):
|
||||
return board_uuid
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["cat", "/sys/class/dmi/id/product_uuid"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_machine_code(code_length: int = 32) -> str:
|
||||
"""
|
||||
生成唯一的机器码(硬件指纹)
|
||||
参数:
|
||||
code_length: 机器码长度,支持 16/24/32/48/64 位,默认 32 位
|
||||
返回:
|
||||
全大写字母+数字的机器码字符串
|
||||
"""
|
||||
cpu_id = get_cpu_id() or ""
|
||||
board_uuid = get_motherboard_uuid() or ""
|
||||
raw_hardware = f"{cpu_id}-{board_uuid}"
|
||||
|
||||
if not raw_hardware.strip("-") or len(raw_hardware) < 8:
|
||||
try:
|
||||
machine_name = uuid.gethostname() or ""
|
||||
mac = ':'.join(re.findall('..', '%012x' % uuid.getnode()))
|
||||
raw_hardware = f"{machine_name}-{mac}"
|
||||
except Exception:
|
||||
raw_hardware = str(uuid.getnode())
|
||||
|
||||
raw_hardware = re.sub(r'[^a-zA-Z0-9]', '', raw_hardware)
|
||||
hash_hex = hashlib.sha256(raw_hardware.encode('utf-8')).hexdigest().upper()
|
||||
hash_hex = hash_hex.replace('O', 'X').replace('L', 'Y').replace('I', 'Z')
|
||||
|
||||
valid_lengths = [16, 24, 32, 48, 64]
|
||||
if code_length not in valid_lengths:
|
||||
code_length = 32
|
||||
|
||||
return hash_hex[:code_length]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第二部分:授权文件格式与签名机制
|
||||
# ============================================================
|
||||
|
||||
# 开发者密钥(硬编码在软件中,用于验证授权文件)
|
||||
# 注意:实际部署时建议对密钥进行简单混淆或从外部文件加载
|
||||
DEVELOPER_SECRET = b"WaterQuality_v1_2025_SecretKey"
|
||||
LICENSE_VERSION = "1.0"
|
||||
|
||||
|
||||
def _compute_signature(payload_json: str) -> str:
|
||||
"""
|
||||
计算 HMAC-SHA256 签名
|
||||
payload_json: JSON 序列化后的字符串(不含 signature 字段)
|
||||
"""
|
||||
sig = hmac.new(
|
||||
DEVELOPER_SECRET,
|
||||
payload_json.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest().upper()
|
||||
return sig
|
||||
|
||||
|
||||
def _clean_hash(s: str) -> str:
|
||||
"""清洗哈希字符串,避免混淆字符"""
|
||||
return s.replace('O', 'X').replace('L', 'Y').replace('I', 'Z')
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第三部分:核心 API
|
||||
# ============================================================
|
||||
|
||||
def get_license_path() -> str:
|
||||
"""获取授权文件的标准存放路径(程序根目录)"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
base_dir = os.path.dirname(sys.executable)
|
||||
else:
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
return os.path.join(base_dir, "license.lic")
|
||||
|
||||
|
||||
def verify_license(license_path: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
校验授权文件是否匹配本机硬件指纹。
|
||||
|
||||
参数:
|
||||
license_path: 授权文件路径,默认使用标准路径
|
||||
|
||||
返回:
|
||||
(is_valid, message)
|
||||
- is_valid=True 表示授权有效
|
||||
- is_valid=False 表示授权无效,message 为具体原因
|
||||
"""
|
||||
if license_path is None:
|
||||
license_path = get_license_path()
|
||||
|
||||
# Step 1: 文件是否存在
|
||||
if not os.path.isfile(license_path):
|
||||
return False, "授权文件不存在"
|
||||
|
||||
# Step 2: 读取并解析 JSON
|
||||
try:
|
||||
with open(license_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
lic_data = json.loads(content)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
return False, f"授权文件格式错误: {e}"
|
||||
|
||||
# Step 3: 校验版本号
|
||||
version = lic_data.get("version", "")
|
||||
if version != LICENSE_VERSION:
|
||||
return False, f"授权文件版本不匹配 (期望 {LICENSE_VERSION})"
|
||||
|
||||
# Step 4: 校验过期时间
|
||||
expiry_str = lic_data.get("expiry", "")
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry_dt = datetime.strptime(expiry_str, "%Y-%m-%d")
|
||||
if datetime.now() > expiry_dt:
|
||||
return False, "授权已过期"
|
||||
except ValueError:
|
||||
return False, "授权文件日期格式错误"
|
||||
|
||||
# Step 5: 提取 payload(不含 signature)
|
||||
payload_for_verify = {k: v for k, v in lic_data.items() if k != "signature"}
|
||||
payload_json = json.dumps(payload_for_verify, sort_keys=True, ensure_ascii=False)
|
||||
|
||||
# Step 6: 校验签名完整性(防篡改)
|
||||
expected_sig = _compute_signature(payload_json)
|
||||
stored_sig = lic_data.get("signature", "").upper()
|
||||
if not hmac.compare_digest(expected_sig, stored_sig):
|
||||
return False, "授权文件签名校验失败(可能被篡改)"
|
||||
|
||||
# Step 7: 校验机器码绑定
|
||||
bound_machine = lic_data.get("machine_code", "")
|
||||
current_machine = get_machine_code(32)
|
||||
if not hmac.compare_digest(bound_machine, current_machine):
|
||||
return False, "机器码不匹配(授权文件与本机不兼容)"
|
||||
|
||||
return True, "授权验证通过"
|
||||
|
||||
|
||||
def generate_license(
|
||||
machine_code: str,
|
||||
output_path: str,
|
||||
expiry_date: Optional[str] = None,
|
||||
product_name: str = "WaterQualityInversion",
|
||||
max_uses: Optional[int] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
为指定机器码生成合法的授权文件(供开发者使用)。
|
||||
|
||||
参数:
|
||||
machine_code: 目标机器的机器码(32位)
|
||||
output_path: 授权文件输出路径(含文件名,如 "D:/license.lic")
|
||||
expiry_date: 有效期截止日期,格式 "YYYY-MM-DD",默认永久
|
||||
product_name: 产品名称
|
||||
max_uses: 最大使用次数(可选,默认不限制)
|
||||
|
||||
返回:
|
||||
(success, message)
|
||||
"""
|
||||
if len(machine_code) not in (16, 24, 32, 48, 64):
|
||||
return False, f"机器码长度无效(期望 16/24/32/48/64,实际 {len(machine_code)})"
|
||||
|
||||
# 构建 payload
|
||||
payload = {
|
||||
"version": LICENSE_VERSION,
|
||||
"product": product_name,
|
||||
"machine_code": machine_code.upper(),
|
||||
"generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"expiry": expiry_date or "",
|
||||
}
|
||||
if max_uses is not None:
|
||||
payload["max_uses"] = max_uses
|
||||
|
||||
# 计算签名
|
||||
payload_json = json.dumps(payload, sort_keys=True, ensure_ascii=False)
|
||||
signature = _compute_signature(payload_json)
|
||||
|
||||
# 完整授权文件内容
|
||||
lic_content = json.dumps({**payload, "signature": signature}, indent=2, ensure_ascii=False)
|
||||
|
||||
# 写入文件
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(lic_content)
|
||||
return True, f"授权文件已生成: {output_path}"
|
||||
except OSError as e:
|
||||
return False, f"写入授权文件失败: {e}"
|
||||
|
||||
|
||||
def get_machine_info() -> dict:
|
||||
"""获取完整机器信息(调试用)"""
|
||||
return {
|
||||
"cpu_id": get_cpu_id(),
|
||||
"motherboard_uuid": get_motherboard_uuid(),
|
||||
"machine_code_16": get_machine_code(16),
|
||||
"machine_code_32": get_machine_code(32),
|
||||
"license_path": get_license_path(),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第四部分:便捷入口(支持直接运行)
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="WaterQuality 授权管理工具")
|
||||
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
||||
|
||||
# 子命令:verify
|
||||
p_verify = subparsers.add_parser("verify", help="验证本机授权")
|
||||
p_verify.add_argument("-f", "--file", default=None, help="授权文件路径")
|
||||
|
||||
# 子命令:gen / generate
|
||||
p_gen = subparsers.add_parser("generate", help="为指定机器码生成授权文件")
|
||||
p_gen.add_argument("-m", "--machine", required=True, help="目标机器的机器码")
|
||||
p_gen.add_argument("-o", "--output", required=True, help="输出文件路径")
|
||||
p_gen.add_argument("-e", "--expiry", default=None, help="有效期截止日期 YYYY-MM-DD")
|
||||
p_gen.add_argument("-n", "--name", default="WaterQualityInversion", help="产品名称")
|
||||
|
||||
# 子命令:info
|
||||
subparsers.add_parser("info", help="显示本机机器信息")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "verify":
|
||||
ok, msg = verify_license(args.file)
|
||||
print(f"[{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
elif args.command == "generate":
|
||||
ok, msg = generate_license(args.machine, args.output, args.expiry, args.name)
|
||||
print(f"[{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
elif args.command == "info":
|
||||
info = get_machine_info()
|
||||
print("=" * 50)
|
||||
print("硬件指纹信息")
|
||||
print("=" * 50)
|
||||
for k, v in info.items():
|
||||
print(f" {k}: {v or '(读取失败)'}")
|
||||
print("=" * 50)
|
||||
# 同时演示验证
|
||||
ok, msg = verify_license()
|
||||
print(f"\n授权验证: [{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
36
src/core/algorithms/__init__.py
Normal file
36
src/core/algorithms/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""
|
||||
算法层模块
|
||||
包含插值算法和耀斑检测算法等核心数学计算
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 插值
|
||||
'interpolate_pixels',
|
||||
'interpolate_zero_pixels_batch',
|
||||
# 耀斑检测
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
31
src/core/algorithms/glint_detection/__init__.py
Normal file
31
src/core/algorithms/glint_detection/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
包含各种耀斑检测的核心数学计算函数
|
||||
"""
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
595
src/core/algorithms/glint_detection/detectors.py
Normal file
595
src/core/algorithms/glint_detection/detectors.py
Normal file
@ -0,0 +1,595 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
|
||||
包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。
|
||||
支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
|
||||
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
|
||||
def timeit(func):
|
||||
"""装饰器:测量函数执行时间"""
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
import time
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
print(f"[{func.__name__}] 耗时: {end - start:.2f}s")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数拉伸
|
||||
# =============================================================================
|
||||
|
||||
def percentile_stretch(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
lower_percentile: float = 2,
|
||||
upper_percentile: float = 98,
|
||||
output_range: Tuple[int, int] = (0, 255)
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用百分位数裁剪进行归一化,适用于低反射率数据
|
||||
通过排除极值,更好地利用数据的动态范围
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(反射率值,通常在0-1之间)
|
||||
data_water_mask: 水域掩膜
|
||||
lower_percentile: 下百分位数,用于裁剪最小值(默认2)
|
||||
upper_percentile: 上百分位数,用于裁剪最大值(默认98)
|
||||
output_range: 输出范围,默认(0, 255)
|
||||
|
||||
Returns:
|
||||
归一化后的图像数组(整数类型)
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return img.astype(np.int32)
|
||||
|
||||
p_lower = np.percentile(valid_pixels, lower_percentile)
|
||||
p_upper = np.percentile(valid_pixels, upper_percentile)
|
||||
|
||||
if p_lower >= p_upper:
|
||||
p_lower = np.percentile(valid_pixels, 1)
|
||||
p_upper = np.percentile(valid_pixels, 99)
|
||||
if p_lower >= p_upper:
|
||||
p_upper = valid_pixels.max()
|
||||
p_lower = valid_pixels.min()
|
||||
|
||||
img_clipped = np.clip(img, p_lower, p_upper)
|
||||
|
||||
if p_upper > p_lower:
|
||||
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
|
||||
output_range[1] - output_range[0]
|
||||
) + output_range[0]
|
||||
else:
|
||||
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
|
||||
|
||||
return img_stretched.astype(np.int32)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Otsu阈值分割
|
||||
# =============================================================================
|
||||
|
||||
def otsu_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
ignore_value: int = 0,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Otsu方法的自动阈值分割
|
||||
通过最大化类间方差找到最佳分割阈值
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(整数值)
|
||||
data_water_mask: 水域掩膜
|
||||
ignore_value: 忽略的值(默认为0)
|
||||
foreground: 耀斑区域值(默认1)
|
||||
background: 背景值(默认0)
|
||||
|
||||
Returns:
|
||||
二值化检测结果数组
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
max_value = int(np.max(img[img > ignore_value])) + 1
|
||||
if max_value < 2:
|
||||
max_value = 256
|
||||
|
||||
hist = np.zeros([max_value], np.float32)
|
||||
|
||||
invalid_counter = 0
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
|
||||
invalid_counter += 1
|
||||
continue
|
||||
hist[img[i, j]] += 1
|
||||
|
||||
total_valid = height * width - invalid_counter
|
||||
if total_valid <= 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
hist /= total_valid
|
||||
|
||||
threshold = 0
|
||||
deltaMax = 0
|
||||
|
||||
for i in range(max_value):
|
||||
wA = sum(hist[:i + 1])
|
||||
wB = sum(hist[i + 1:])
|
||||
if wA == 0:
|
||||
wA = 1e-10
|
||||
if wB == 0:
|
||||
wB = 1e-10
|
||||
|
||||
uAtmp = sum(j * hist[j] for j in range(i + 1))
|
||||
uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
|
||||
uA = uAtmp / wA
|
||||
uB = uBtmp / wB
|
||||
u = uAtmp + uBtmp
|
||||
|
||||
deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
|
||||
if deltaTmp > deltaMax:
|
||||
deltaMax = deltaTmp
|
||||
threshold = i
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z-score阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def zscore_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
z_threshold: float = 2.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Z-score(标准化分数)的耀斑检测方法
|
||||
使用统计方法识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
mean_val = np.mean(valid_pixels)
|
||||
std_val = np.std(valid_pixels)
|
||||
|
||||
if std_val == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
z_scores = np.zeros_like(img, dtype=np.float32)
|
||||
valid_mask = (data_water_mask > 0) & np.isfinite(img)
|
||||
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[z_scores > z_threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def percentile_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
percentile: float = 95,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于百分位数的耀斑检测方法
|
||||
使用百分位数作为阈值,对异常值更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
percentile: 百分位数阈值,默认95(即超过95%的像素值)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
threshold_val = np.percentile(valid_pixels, percentile)
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold_val] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IQR异常值检测
|
||||
# =============================================================================
|
||||
|
||||
def iqr_outlier_detection(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
iqr_multiplier: float = 1.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于IQR(四分位距)的异常值检测方法
|
||||
使用四分位距识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
iqr_multiplier: IQR倍数,默认1.5(标准异常值检测)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
q1 = np.percentile(valid_pixels, 25)
|
||||
q3 = np.percentile(valid_pixels, 75)
|
||||
iqr = q3 - q1
|
||||
|
||||
upper_bound = q3 + iqr_multiplier * iqr
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > upper_bound] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 自适应阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def adaptive_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
window_size: int = 15,
|
||||
percentile: float = 90,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
自适应阈值方法
|
||||
基于局部统计特性进行阈值分割,对光照变化更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
window_size: 局部窗口大小(奇数)
|
||||
percentile: 局部百分位数阈值
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
if window_size % 2 == 0:
|
||||
window_size += 1
|
||||
|
||||
half_window = window_size // 2
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
for i in range(half_window, height - half_window):
|
||||
for j in range(half_window, width - half_window):
|
||||
if data_water_mask[i, j] == 0:
|
||||
continue
|
||||
|
||||
local_window = img[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
local_mask = data_water_mask[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
|
||||
valid_pixels = local_window[local_mask > 0]
|
||||
|
||||
if len(valid_pixels) > 0:
|
||||
local_th = np.percentile(valid_pixels, percentile)
|
||||
if img[i, j] > local_th:
|
||||
det_img[i, j] = foreground
|
||||
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 多波段融合耀斑检测
|
||||
# =============================================================================
|
||||
|
||||
def multi_band_glint_detection(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
glint_waves: List[float],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'zscore',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
sub_band_arrays: Optional[List[np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
多波段融合的耀斑检测方法
|
||||
结合多个波段的耀斑特征,提高检测的稳健性
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组(主波段,用于兼容性)
|
||||
water_mask: 水域掩膜数组
|
||||
glint_waves: 用于检测的波长列表,如[750, 800, 850]
|
||||
weights: 各波段的权重,如果为None则使用等权重
|
||||
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
|
||||
z_threshold: Z-score阈值(当method='zscore'时使用)
|
||||
percentile: 百分位数阈值(当method='percentile'时使用)
|
||||
sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应)
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
if weights is None:
|
||||
weights = [1.0 / len(glint_waves)] * len(glint_waves)
|
||||
|
||||
if len(weights) != len(glint_waves):
|
||||
raise ValueError("权重数量必须与波长数量相同")
|
||||
|
||||
fused_band = None
|
||||
|
||||
if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
|
||||
for i, band_array in enumerate(sub_band_arrays):
|
||||
if fused_band is None:
|
||||
fused_band = (band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = nir_band.astype(np.float32)
|
||||
|
||||
if method == 'otsu':
|
||||
stretched = percentile_stretch(fused_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(fused_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(fused_band, water_mask, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 连通域过滤
|
||||
# =============================================================================
|
||||
|
||||
def filter_large_components(
|
||||
binary_img: np.ndarray,
|
||||
max_area: Optional[int] = None,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
过滤掉面积超过阈值的连通域
|
||||
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
|
||||
|
||||
Args:
|
||||
binary_img: 二值化图像
|
||||
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
过滤后的二值化图像
|
||||
"""
|
||||
if max_area is None or max_area <= 0:
|
||||
return binary_img
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
binary_for_label = (binary_img == foreground).astype(np.uint8)
|
||||
num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
|
||||
binary_for_label, connectivity=8
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = stats[1:, cv2.CC_STAT_AREA]
|
||||
keep_labels = np.where(component_sizes <= max_area)[0] + 1
|
||||
|
||||
keep_mask = np.isin(labeled_array, keep_labels)
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
else:
|
||||
from scipy import ndimage
|
||||
labeled_array, num_features = ndimage.label(
|
||||
(binary_img == foreground).astype(np.int32)
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = ndimage.sum(
|
||||
(labeled_array == i).astype(np.int32),
|
||||
labeled_array,
|
||||
range(1, num_features + 1)
|
||||
)
|
||||
|
||||
keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 岸边缓冲区处理
|
||||
# =============================================================================
|
||||
|
||||
def create_shoreline_buffer(
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
创建岸边缓冲区掩膜(向内缓冲)
|
||||
用于去除岸边附近的错误耀斑检测区域
|
||||
|
||||
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
|
||||
|
||||
Args:
|
||||
water_mask: 水域掩膜数组(水域=1,非水域=0)
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
岸边缓冲区掩膜(缓冲区区域=1,其他=0)
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return np.zeros_like(water_mask, dtype=np.int32)
|
||||
|
||||
water_binary = (water_mask > 0).astype(np.uint8)
|
||||
structure_size = buffer_size * 2 + 1
|
||||
structure = np.ones((structure_size, structure_size), dtype=np.uint8)
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
|
||||
else:
|
||||
from scipy import ndimage
|
||||
eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
|
||||
|
||||
buffer_mask = (water_binary - eroded_water).astype(np.int32)
|
||||
|
||||
return buffer_mask
|
||||
|
||||
|
||||
def remove_shoreline_buffer(
|
||||
glint_mask: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
从耀斑掩膜中去除岸边缓冲区内的区域
|
||||
|
||||
Args:
|
||||
glint_mask: 耀斑掩膜数组
|
||||
water_mask: 水域掩膜数组
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
去除岸边缓冲区后的耀斑掩膜
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return glint_mask
|
||||
|
||||
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
|
||||
|
||||
cleaned = glint_mask.copy()
|
||||
cleaned[buffer_mask > 0] = background
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 高级组合函数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_glint_mask(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
method: str = 'otsu',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
apply_percentile_stretch: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算耀斑掩膜的统一入口函数
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组
|
||||
water_mask: 水域掩膜
|
||||
method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
|
||||
z_threshold: Z-score阈值
|
||||
percentile: 百分位数阈值
|
||||
iqr_multiplier: IQR倍数
|
||||
window_size: 自适应阈值窗口大小
|
||||
apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸
|
||||
|
||||
Returns:
|
||||
二值化耀斑掩膜
|
||||
"""
|
||||
if method == 'otsu':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
else:
|
||||
return otsu_threshold(nir_band.astype(np.int32), water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(nir_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(nir_band, water_mask, percentile)
|
||||
elif method == 'iqr':
|
||||
return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
|
||||
elif method == 'adaptive':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return adaptive_threshold(stretched, water_mask, window_size, percentile)
|
||||
else:
|
||||
return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
7
src/core/algorithms/interpolation/__init__.py
Normal file
7
src/core/algorithms/interpolation/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
插值算法模块
|
||||
包含0值像素插值的核心数学逻辑
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
|
||||
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']
|
||||
320
src/core/algorithms/interpolation/interpolator.py
Normal file
320
src/core/algorithms/interpolation/interpolator.py
Normal file
@ -0,0 +1,320 @@
|
||||
"""
|
||||
像素插值算法模块
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import griddata, RBFInterpolator
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
interpolation_method: str = 'nearest',
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对指定坐标的像素进行插值(核心数学函数,不涉及文件I/O)
|
||||
|
||||
Args:
|
||||
image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组
|
||||
zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y]
|
||||
valid_coords: 有效像素坐标,形状为 (n_valid, 2)
|
||||
valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands)
|
||||
interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
water_mask: 可选的水域掩膜数组
|
||||
|
||||
Returns:
|
||||
插值后的影像副本,形状与 image_stack 相同
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
# 兼容中文和各种格式的method参数
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
|
||||
method = 'bilinear'
|
||||
elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
|
||||
method = 'spline'
|
||||
elif 'kriging' in raw_method or '克里金' in raw_method:
|
||||
method = 'kriging'
|
||||
else:
|
||||
method = 'nearest'
|
||||
|
||||
if len(valid_values) == 0:
|
||||
return result
|
||||
|
||||
is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
|
||||
|
||||
if is_multiband:
|
||||
for band_idx in range(n_bands):
|
||||
band_valid_values = valid_values[:, band_idx]
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, band_valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords, band_idx] = interpolated_values
|
||||
else:
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords] = interpolated_values
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _interpolate_single_band(
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
method: str
|
||||
) -> np.ndarray:
|
||||
"""对单个波段执行插值计算"""
|
||||
if method == 'nearest':
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords)
|
||||
return valid_values[indices]
|
||||
|
||||
elif method == 'bilinear':
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'spline':
|
||||
try:
|
||||
max_points = 10000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
|
||||
interpolated = rbf(zero_coords)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'kriging':
|
||||
try:
|
||||
from src.utils.kriging import KrigingInterpolator
|
||||
interpolator = KrigingInterpolator()
|
||||
max_points = 5000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
interpolated = griddata(
|
||||
sample_coords, sample_values, zero_coords,
|
||||
method='cubic', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组)
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数
|
||||
|
||||
Returns:
|
||||
(output_path, interpolated_image_stack) 元组
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
# 确定输出路径
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{interpolation_method}.bsq")
|
||||
|
||||
# 检查文件是否已存在
|
||||
if output_path and Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
# 读取所有波段数据
|
||||
all_bands = []
|
||||
for band_idx in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(band_idx)
|
||||
band_data = band.ReadAsArray().astype(np.float32)
|
||||
all_bands.append(band_data)
|
||||
|
||||
image_stack = np.dstack(all_bands)
|
||||
|
||||
# 读取水域掩膜
|
||||
mask_array = None
|
||||
if water_mask is not None:
|
||||
if isinstance(water_mask, str):
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
|
||||
# 找出所有波段都为0的像素点
|
||||
all_bands_zero = np.all(image_stack == 0, axis=2)
|
||||
|
||||
if mask_array is not None:
|
||||
all_bands_zero = all_bands_zero & (mask_array > 0)
|
||||
|
||||
zero_pixel_count = np.sum(all_bands_zero)
|
||||
if zero_pixel_count == 0:
|
||||
# 无需插值,直接保存
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(all_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
return output_path, image_stack
|
||||
|
||||
# 获取坐标
|
||||
zero_y, zero_x = np.where(all_bands_zero)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_bands_zero
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
raise ValueError("没有有效像素可用于插值")
|
||||
|
||||
# 逐波段插值
|
||||
interpolated_bands = []
|
||||
for band_idx in range(n_bands):
|
||||
if callback_progress:
|
||||
callback_progress(f"处理波段 {band_idx + 1}/{n_bands}...")
|
||||
band_data = all_bands[band_idx].copy()
|
||||
valid_values_band = band_data[valid_mask]
|
||||
|
||||
if len(valid_values_band) == 0:
|
||||
interpolated_bands.append(band_data)
|
||||
continue
|
||||
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, interpolation_method
|
||||
)
|
||||
band_data[all_bands_zero] = band_result
|
||||
interpolated_bands.append(band_data)
|
||||
|
||||
# 保存结果
|
||||
if output_path:
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
out_dataset = driver.Create(output_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
for i, band_data in enumerate(interpolated_bands):
|
||||
out_band = out_dataset.GetRasterBand(i + 1)
|
||||
out_band.WriteArray(band_data)
|
||||
out_band.FlushCache()
|
||||
out_dataset = None
|
||||
|
||||
result_stack = np.dstack(interpolated_bands)
|
||||
return output_path, result_stack
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
@ -8,283 +7,301 @@ try:
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
class Hedley:
|
||||
def __init__(self, im_aligned, shp_path=None, NIR_band = 47, water_mask=None, output_path=None):
|
||||
def __init__(self, img_path, shp_path=None, NIR_band=47, water_mask=None,
|
||||
output_path=None, block_size=1000):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param NIR_band (int): band index for NIR band which corresponds to 842.36nm, which corresponds closely to the NIR band in Micasense
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.NIR_band = NIR_band
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Hedley类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def covariance_NIR(self,NIR,b):
|
||||
"""
|
||||
NIR & b are vectors
|
||||
reflectance for band i
|
||||
"""
|
||||
n = len(NIR)
|
||||
# 优化:减少重复计算,使用更高效的numpy操作
|
||||
nir_mean = np.mean(NIR)
|
||||
b_mean = np.mean(b)
|
||||
# 使用更高效的协方差计算
|
||||
pij = np.mean((NIR - nir_mean) * (b - b_mean))
|
||||
pjj = np.mean((NIR - nir_mean) ** 2)
|
||||
# 避免除零错误
|
||||
return pij / pjj if pjj != 0 else 0.0
|
||||
|
||||
def correlation_bands_reflectance(self):
|
||||
"""
|
||||
calculate correlation between NIR and other bands for reflectance
|
||||
NIR_band is 750 nm
|
||||
"""
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 直接使用视图,只在需要时创建扁平数组
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算相关性
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 只在掩膜内提取数据
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band][mask_bool]
|
||||
else:
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
else:
|
||||
NIR_reflectance = im_region[:,:,self.NIR_band].ravel()
|
||||
mask_bool = None
|
||||
|
||||
# 优化:一次性计算所有波段的相关性,减少循环开销
|
||||
corr_list = []
|
||||
for v in range(self.n_bands):
|
||||
if mask_bool is not None and mask_bool.any():
|
||||
band_reflectance = im_region[:,:,v][mask_bool]
|
||||
else:
|
||||
band_reflectance = im_region[:,:,v].ravel()
|
||||
corr = self.covariance_NIR(NIR_reflectance, band_reflectance)
|
||||
corr_list.append(corr)
|
||||
|
||||
return corr_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
Hedley 耀斑去除算法 - 分块逐波段处理版本
|
||||
|
||||
:param img_path (str): 输入影像文件路径(GDAL可读取的格式)
|
||||
:param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask
|
||||
:param NIR_band (int): NIR波段索引(默认47,对应842.36nm)
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜
|
||||
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
|
||||
:param block_size (int): 分块大小(默认1000)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
self.img_path = img_path
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.water_mask = None
|
||||
self.water_mask_path = water_mask
|
||||
self.output_path = output_path
|
||||
self.block_size = block_size
|
||||
self.R_min = None
|
||||
self.corr_list = None # 全局协方差系数列表
|
||||
|
||||
# 打开影像
|
||||
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if self.dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.n_bands = self.dataset.RasterCount
|
||||
|
||||
def _load_water_mask(self):
|
||||
"""延迟加载水域掩膜"""
|
||||
if self.water_mask_path is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.water_mask_path, np.ndarray):
|
||||
if self.water_mask_path.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (self.water_mask_path > 0).astype(np.uint8)
|
||||
|
||||
if isinstance(self.water_mask_path, str):
|
||||
if self.water_mask_path.lower().endswith('.shp'):
|
||||
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
|
||||
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
return None
|
||||
|
||||
def covariance_NIR(self, NIR, b):
|
||||
"""计算 NIR 与波段 b 之间的协方差系数 b_i = Cov(NIR,b) / Var(NIR)"""
|
||||
n = len(NIR)
|
||||
nir_mean = np.mean(NIR)
|
||||
b_mean = np.mean(b)
|
||||
pij = np.mean((NIR - nir_mean) * (b - b_mean))
|
||||
pjj = np.mean((NIR - nir_mean) ** 2)
|
||||
return pij / pjj if pjj != 0 else 0.0
|
||||
|
||||
def _scan_global_stats(self, sample_step=20):
|
||||
"""
|
||||
扫描全图获取全局 R_min
|
||||
|
||||
使用重采样方式扫描,大幅降低内存占用。
|
||||
"""
|
||||
print(f"[Hedley] 扫描全局统计量(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
nir_samples = []
|
||||
sample_count = 0
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
nir_band = None
|
||||
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
sample_count += nir_sampled.size
|
||||
|
||||
del nir_block, mask_block
|
||||
|
||||
if sample_count == 0:
|
||||
self.R_min = 0.0
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
|
||||
del all_nir
|
||||
|
||||
print(f"[Hedley] 全局 R_min={self.R_min:.4f}")
|
||||
|
||||
def _compute_corr_list(self, sample_step=5):
|
||||
"""
|
||||
计算每个波段与NIR的协方差系数 corr_list[b] = Cov(NIR, band_b) / Var(NIR)
|
||||
|
||||
全分辨率扫描,逐波段读取,每波段内存 ≈ block_size²
|
||||
由于需要相关性计算,需要足够多的样本,取sample_step=5
|
||||
"""
|
||||
print(f"[Hedley] 计算全局协方差系数列表(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 预收集NIR和每个波段的样本数据
|
||||
nir_samples = []
|
||||
band_samples = [[] for _ in range(self.n_bands)]
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取NIR波段(每块只读一次)
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 取 NIR 样本(每块只取一次,放在波段循环外)
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
else:
|
||||
mask_bool = np.ones((block_height, self.width), dtype=bool)
|
||||
|
||||
if mask_bool.any():
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
|
||||
# 逐波段读取并采样(all_band 严格使用单波段切片)
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
band = None
|
||||
|
||||
if mask_bool.any():
|
||||
band_sampled = block[mask_bool][::sample_step]
|
||||
band_samples[b].append(band_sampled)
|
||||
|
||||
del block
|
||||
|
||||
del nir_block
|
||||
|
||||
# 汇总并计算相关系数
|
||||
if len(nir_samples) == 0 or sum(len(s) for s in nir_samples) == 0:
|
||||
self.corr_list = [0.0] * self.n_bands
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
self.corr_list = []
|
||||
for b in range(self.n_bands):
|
||||
all_band = np.concatenate(band_samples[b])
|
||||
corr = self.covariance_NIR(all_nir, all_band)
|
||||
self.corr_list.append(float(corr))
|
||||
|
||||
del all_nir
|
||||
for b in range(self.n_bands):
|
||||
band_samples[b] = None
|
||||
|
||||
print(f"[Hedley] 协方差系数: min={min(self.corr_list):.4f}, max={max(self.corr_list):.4f}")
|
||||
|
||||
def _process_block(self, x_off, y_off, x_size, y_size):
|
||||
"""
|
||||
处理单个分块
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: 校正后的波段列表
|
||||
"""
|
||||
# 读取NIR波段
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
NIR = nir_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
nir_band = None
|
||||
|
||||
# 预计算 NIR - R_min
|
||||
NIR_diff = NIR - self.R_min
|
||||
|
||||
# 获取掩膜
|
||||
water_mask = self._load_water_mask()
|
||||
if water_mask is not None:
|
||||
y_end = y_off + y_size
|
||||
x_end = x_off + x_size
|
||||
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
|
||||
else:
|
||||
mask_block = None
|
||||
|
||||
# 逐波段处理
|
||||
corrected_bands = []
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
band = None
|
||||
|
||||
corr = self.corr_list[b]
|
||||
# Hedley 校正公式:R_corrected = R - corr * (NIR - R_min)
|
||||
corrected = R - corr * NIR_diff
|
||||
|
||||
if mask_block is not None:
|
||||
corrected = np.where(mask_block, corrected, R)
|
||||
|
||||
corrected_bands.append(corrected)
|
||||
del R
|
||||
|
||||
del NIR, NIR_diff
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
correction is done in reflectance
|
||||
|
||||
:return: 校正后的波段列表
|
||||
执行分块处理,返回校正后的波段列表
|
||||
"""
|
||||
corr = self.correlation_bands_reflectance()
|
||||
NIR_reflectance = self.im_aligned[:,:,self.NIR_band]
|
||||
# 预计算NIR-R_min,避免在循环中重复计算
|
||||
NIR_diff = NIR_reflectance - self.R_min
|
||||
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
if self.output_path is None:
|
||||
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
|
||||
|
||||
corrected_bands = []
|
||||
for band_number in range(self.n_bands): #iterate across bands
|
||||
b = corr[band_number]
|
||||
R = self.im_aligned[:,:,band_number]
|
||||
# 优化:减少中间数组创建
|
||||
corrected_band = R - b * NIR_diff
|
||||
|
||||
# 如果存在水域掩膜,只对水域区域应用校正
|
||||
if water_mask_bool is not None:
|
||||
corrected_band = np.where(water_mask_bool, corrected_band, R)
|
||||
|
||||
corrected_bands.append(corrected_band)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
# Step 1: 扫描全局 R_min
|
||||
self._scan_global_stats(sample_step=20)
|
||||
|
||||
# Step 2: 计算协方差系数列表
|
||||
self._compute_corr_list(sample_step=5)
|
||||
|
||||
# Step 3: 创建输出文件
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
|
||||
|
||||
geotransform = self.dataset.GetGeoTransform()
|
||||
projection = self.dataset.GetProjection()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
out_dataset = driver.Create(bsq_path, self.width, self.height,
|
||||
self.n_bands, gdal.GDT_Float32)
|
||||
if out_dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
# Step 4: 分块处理
|
||||
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
|
||||
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
|
||||
total_blocks = n_blocks_x * n_blocks_y
|
||||
|
||||
print(f"[Hedley] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
|
||||
|
||||
block_idx = 0
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
y_size = y_end - y_off
|
||||
|
||||
for x_off in range(0, self.width, self.block_size):
|
||||
x_end = min(x_off + self.block_size, self.width)
|
||||
x_size = x_end - x_off
|
||||
block_idx += 1
|
||||
|
||||
print(f"[Hedley] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
|
||||
|
||||
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
|
||||
|
||||
for b in range(self.n_bands):
|
||||
out_band = out_dataset.GetRasterBand(b + 1)
|
||||
out_band.WriteArray(corrected_bands[b], x_off, y_off)
|
||||
out_band.FlushCache()
|
||||
|
||||
del corrected_bands
|
||||
|
||||
out_dataset = None
|
||||
self.dataset = None
|
||||
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}")
|
||||
else:
|
||||
print(f"[Hedley] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
|
||||
|
||||
return []
|
||||
|
||||
def __del__(self):
|
||||
if self.dataset is not None:
|
||||
self.dataset = None
|
||||
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
# import preprocessing
|
||||
import os
|
||||
|
||||
try:
|
||||
@ -8,306 +7,333 @@ try:
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
class Kutser:
|
||||
def __init__(self, im_aligned, shp_path=None, oxy_band = 38,lower_oxy = 36, upper_oxy = 49, NIR_band = 47, water_mask=None, output_path=None):
|
||||
"""
|
||||
:param im_aligned (np.ndarray): band aligned and calibrated & corrected reflectance image
|
||||
:param shp_path (str, optional): path to shapefile (.shp) defining the region containing the glint region in deep water.
|
||||
If None, uses the entire image. The shapefile can use pixel coordinates or geographic coordinates.
|
||||
:param oxy_band (int): band index for oxygen absorption band, which corresponds to 760.6nm
|
||||
:param lower_oxy (int): band index for outside oxygen absorption band, which corresponds to 742.39nm
|
||||
:param upper_oxy (int): band index for outside oxygen absorption band, which corresponds to 860.48nm
|
||||
see Kutser, Vahtmäe and Praks
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜,1表示水域,0表示非水域
|
||||
可以是numpy数组、栅格文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
如果为None,则处理全图
|
||||
:param output_path (str or None): 输出文件路径,如果提供则保存校正后的图像
|
||||
如果为None,则不保存
|
||||
"""
|
||||
self.im_aligned = im_aligned
|
||||
self.bbox = self._read_shp_to_bbox(shp_path) if shp_path else None
|
||||
self.oxy_band = oxy_band
|
||||
self.lower_oxy = lower_oxy
|
||||
self.upper_oxy = upper_oxy
|
||||
self.NIR_band = NIR_band
|
||||
self.n_bands = im_aligned.shape[-1]
|
||||
self.height = im_aligned.shape[0]
|
||||
self.width = im_aligned.shape[1]
|
||||
self.output_path = output_path
|
||||
|
||||
# 加载水域掩膜
|
||||
self.water_mask = self._load_water_mask(water_mask)
|
||||
|
||||
# 使用ravel()而不是flatten(),避免不必要的复制
|
||||
# 如果存在水域掩膜,只在掩膜内计算R_min
|
||||
if self.water_mask is not None:
|
||||
nir_band_masked = self.im_aligned[:,:,self.NIR_band][self.water_mask.astype(bool)]
|
||||
self.R_min = np.percentile(nir_band_masked, 5, interpolation='nearest') if nir_band_masked.size > 0 else 0
|
||||
else:
|
||||
self.R_min = np.percentile(self.im_aligned[:,:,self.NIR_band].ravel(), 5, interpolation='nearest')
|
||||
|
||||
def _read_shp_to_bbox(self, shp_path):
|
||||
"""
|
||||
读取shapefile并提取边界框
|
||||
|
||||
:param shp_path (str): shapefile文件路径
|
||||
:return: tuple: ((x1,y1),(x2,y2)), where x1,y1 is the upper left corner, x2,y2 is the lower right corner
|
||||
"""
|
||||
if not os.path.exists(shp_path):
|
||||
raise FileNotFoundError(f"Shapefile not found: {shp_path}")
|
||||
|
||||
try:
|
||||
try:
|
||||
import geopandas as gpd
|
||||
gdf = gpd.read_file(shp_path)
|
||||
# 获取所有几何体的总边界框
|
||||
bounds = gdf.total_bounds # [minx, miny, maxx, maxy]
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
except ImportError:
|
||||
# 如果geopandas不可用,尝试使用fiona
|
||||
import fiona
|
||||
from shapely.geometry import shape
|
||||
|
||||
min_x = float('inf')
|
||||
min_y = float('inf')
|
||||
max_x = float('-inf')
|
||||
max_y = float('-inf')
|
||||
|
||||
with fiona.open(shp_path) as shp:
|
||||
for feature in shp:
|
||||
geom = shape(feature['geometry'])
|
||||
if geom:
|
||||
bounds = geom.bounds
|
||||
min_x = min(min_x, bounds[0])
|
||||
min_y = min(min_y, bounds[1])
|
||||
max_x = max(max_x, bounds[2])
|
||||
max_y = max(max_y, bounds[3])
|
||||
|
||||
# 转换为整数像素坐标
|
||||
x1 = max(0, int(min_x))
|
||||
y1 = max(0, int(min_y))
|
||||
x2 = min(self.im_aligned.shape[1], int(max_x) + 1)
|
||||
y2 = min(self.im_aligned.shape[0], int(max_y) + 1)
|
||||
|
||||
return ((x1, y1), (x2, y2))
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading shapefile {shp_path}: {e}")
|
||||
|
||||
def _load_water_mask(self, water_mask):
|
||||
"""
|
||||
加载水域掩膜
|
||||
|
||||
:param water_mask: 可以是None、numpy数组、文件路径(.dat/.tif)或shapefile路径(.shp)
|
||||
:return: numpy数组或None,1表示水域,0表示非水域
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# 如果已经是numpy数组
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8) # 确保是0/1掩膜
|
||||
|
||||
# 如果是文件路径
|
||||
if isinstance(water_mask, str):
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
except ImportError:
|
||||
raise ValueError("使用文件路径作为掩膜时,必须安装GDAL")
|
||||
|
||||
# 检查是否为shapefile
|
||||
if water_mask.lower().endswith('.shp'):
|
||||
# 从shp文件创建掩膜(需要参考图像,这里假设使用im_aligned的尺寸)
|
||||
# 注意:如果输入是numpy数组,无法从shp创建掩膜,需要提供栅格参考
|
||||
raise ValueError("Kutser类输入为numpy数组时,无法从shp文件创建掩膜。请先栅格化shp文件或提供numpy数组掩膜")
|
||||
else:
|
||||
# 栅格文件
|
||||
mask_dataset = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {water_mask}")
|
||||
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
def get_depth_D(self):
|
||||
"""
|
||||
Assume the amount of glint is proportional to the depth of the oxygen absorption feature, D
|
||||
returns the normalised D by dividing it by the maximum D found in a deep water region
|
||||
"""
|
||||
# 优化:减少中间数组创建,使用更高效的计算
|
||||
lower_oxy_band = self.im_aligned[:,:,self.lower_oxy]
|
||||
upper_oxy_band = self.im_aligned[:,:,self.upper_oxy]
|
||||
oxy_band = self.im_aligned[:,:,self.oxy_band]
|
||||
D = (lower_oxy_band + upper_oxy_band) * 0.5 - oxy_band
|
||||
|
||||
# 确定用于计算D_max的区域
|
||||
if self.bbox is None:
|
||||
search_region = D
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
search_region = D[y1:y2,x1:x2]
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内搜索最大值
|
||||
if self.water_mask is not None:
|
||||
if self.bbox is None:
|
||||
mask_region = self.water_mask.astype(bool)
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
mask_region = self.water_mask[y1:y2,x1:x2].astype(bool)
|
||||
|
||||
if mask_region.any():
|
||||
D_max = search_region[mask_region].max()
|
||||
else:
|
||||
D_max = search_region.max()
|
||||
else:
|
||||
D_max = search_region.max() # assumed to be the maximum glint value
|
||||
|
||||
# 避免除零错误
|
||||
if D_max == 0:
|
||||
return np.zeros_like(D)
|
||||
return D / D_max
|
||||
|
||||
def get_glint_G(self):
|
||||
"""
|
||||
The spectral variation of glint G is found by subtracting the spectrum at the darkest (ie. lowest D) NIR deep-water pixel from the brightest
|
||||
returns G as a function of wavelength
|
||||
"""
|
||||
# If bbox is None, use the entire image
|
||||
if self.bbox is None:
|
||||
im_region = self.im_aligned
|
||||
mask_region = self.water_mask
|
||||
else:
|
||||
((x1,y1),(x2,y2)) = self.bbox
|
||||
im_region = self.im_aligned[y1:y2,x1:x2,:]
|
||||
mask_region = self.water_mask[y1:y2,x1:x2] if self.water_mask is not None else None
|
||||
|
||||
# 如果存在水域掩膜,只在掩膜内计算最大最小值
|
||||
if mask_region is not None:
|
||||
mask_bool = mask_region.astype(bool)
|
||||
if mask_bool.any():
|
||||
# 对每个波段,只在掩膜内计算最大最小值
|
||||
G_list = []
|
||||
for i in range(self.n_bands):
|
||||
band_data = im_region[:,:,i]
|
||||
G_max = band_data[mask_bool].max()
|
||||
G_min = band_data[mask_bool].min()
|
||||
G_list.append(G_max - G_min)
|
||||
else:
|
||||
# 如果掩膜内没有有效像素,使用全区域
|
||||
G_max = np.amax(im_region, axis=(0, 1))
|
||||
G_min = np.amin(im_region, axis=(0, 1))
|
||||
G_list = (G_max - G_min).tolist()
|
||||
else:
|
||||
# 优化:一次性计算所有波段的最大最小值,减少循环开销
|
||||
# 使用numpy的amax和amin沿最后一个轴计算
|
||||
G_max = np.amax(im_region, axis=(0, 1)) # 沿空间维度计算最大值
|
||||
G_min = np.amin(im_region, axis=(0, 1)) # 沿空间维度计算最小值
|
||||
G_list = (G_max - G_min).tolist()
|
||||
return G_list
|
||||
|
||||
def _save_corrected_bands(self, corrected_bands):
|
||||
class Kutser:
|
||||
def __init__(self, img_path, shp_path=None, oxy_band=38, lower_oxy=36,
|
||||
upper_oxy=49, NIR_band=47, water_mask=None, output_path=None,
|
||||
block_size=1000):
|
||||
"""
|
||||
保存校正后的波段到文件(BSQ格式,ENVI格式)
|
||||
|
||||
:param corrected_bands: 校正后的波段列表
|
||||
Kutser 耀斑去除算法 - 分块逐波段处理版本
|
||||
|
||||
:param img_path (str): 输入影像文件路径(GDAL可读取的格式)
|
||||
:param shp_path (str, optional): 深水区域shapefile,已废弃,请使用water_mask
|
||||
:param oxy_band (int): 氧吸收波段索引(默认38,对应760.6nm)
|
||||
:param lower_oxy (int): 氧吸收下方波段索引(默认36,对应742.39nm)
|
||||
:param upper_oxy (int): 氧吸收上方波段索引(默认49,对应860.48nm)
|
||||
:param NIR_band (int): NIR波段索引(默认47,对应842.36nm)
|
||||
:param water_mask (np.ndarray or str or None): 水域掩膜
|
||||
:param output_path (str): 输出文件路径(必须提供,用于分块写入)
|
||||
:param block_size (int): 分块大小(默认1000)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if self.output_path is None:
|
||||
return
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 将波段列表转换为数组
|
||||
corrected_array = np.stack(corrected_bands, axis=2)
|
||||
|
||||
# 如果没有地理信息,使用默认值
|
||||
geotransform = (0, 1, 0, 0, 0, -1)
|
||||
projection = ""
|
||||
|
||||
# 强制使用ENVI格式(BSQ格式),确保文件扩展名为.bsq
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
# 如果扩展名不是.bsq,使用基础路径添加.bsq
|
||||
if ext.lower() != '.bsq':
|
||||
bsq_path = base_path + '.bsq'
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
self.img_path = img_path
|
||||
self.oxy_band = int(float(oxy_band))
|
||||
self.lower_oxy = int(float(lower_oxy))
|
||||
self.upper_oxy = int(float(upper_oxy))
|
||||
self.NIR_band = int(float(NIR_band))
|
||||
self.water_mask = None # 延迟加载,在处理前初始化
|
||||
self.water_mask_path = water_mask
|
||||
self.output_path = output_path
|
||||
self.block_size = block_size
|
||||
self.R_min = None # 全局R_min(来自重采样扫描)
|
||||
self.D_max = None # 全局D_max(来自重采样扫描)
|
||||
self.G_list = None # 全局G值列表
|
||||
|
||||
# 打开影像获取基本信息
|
||||
self.dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if self.dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
self.width = self.dataset.RasterXSize
|
||||
self.height = self.dataset.RasterYSize
|
||||
self.n_bands = self.dataset.RasterCount
|
||||
|
||||
def _load_water_mask(self):
|
||||
"""延迟加载水域掩膜"""
|
||||
if self.water_mask_path is None:
|
||||
return None
|
||||
|
||||
if isinstance(self.water_mask_path, np.ndarray):
|
||||
if self.water_mask_path.shape[:2] != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {self.water_mask_path.shape[:2]} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (self.water_mask_path > 0).astype(np.uint8)
|
||||
|
||||
if isinstance(self.water_mask_path, str):
|
||||
if self.water_mask_path.lower().endswith('.shp'):
|
||||
raise ValueError("请先栅格化shapefile为栅格掩膜文件")
|
||||
mask_dataset = gdal.Open(self.water_mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {self.water_mask_path}")
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
if mask_array.shape != (self.height, self.width):
|
||||
raise ValueError(
|
||||
f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(self.height, self.width)} 不匹配"
|
||||
)
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
return None
|
||||
|
||||
def _scan_global_stats(self, sample_step=20):
|
||||
"""
|
||||
通过重采样扫描获取全图全局统计量:R_min 和 D_max
|
||||
|
||||
分块读取,按sample_step跳行/跳列采样,大幅降低内存占用。
|
||||
内存峰值 ≈ 单波段块大小 + 几个掩膜数组 ≈ block_size² × 4~8MB
|
||||
"""
|
||||
print(f"[Kutser] 扫描全局统计量(采样步长={sample_step})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 预分配采样数组(NIR波段和D值)
|
||||
nir_samples = []
|
||||
d_samples = []
|
||||
sample_count = 0
|
||||
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取NIR波段(用于R_min)
|
||||
nir_band = self.dataset.GetRasterBand(self.NIR_band + 1)
|
||||
nir_block = nir_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
nir_band = None
|
||||
|
||||
# 读取氧吸收相关波段(用于D_max)
|
||||
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
|
||||
lower_block = lower_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
lower_band = None
|
||||
|
||||
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
|
||||
upper_block = upper_band.ReadAsArray(0, y_off, self.width, block_height)
|
||||
upper_band = None
|
||||
|
||||
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
|
||||
oxy_block = oxy_band_obj.ReadAsArray(0, y_off, self.width, block_height)
|
||||
oxy_band_obj = None
|
||||
|
||||
# 计算D = (lower + upper) * 0.5 - oxy
|
||||
d_block = (lower_block.astype(np.float32) + upper_block.astype(np.float32)) * 0.5 - oxy_block.astype(np.float32)
|
||||
|
||||
# 获取掩膜(整块)
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
else:
|
||||
mask_block = np.ones((block_height, self.width), dtype=np.uint8)
|
||||
|
||||
# 对掩膜区域进行采样
|
||||
mask_bool = mask_block.astype(bool)
|
||||
|
||||
if mask_bool.any():
|
||||
# 按步长采样
|
||||
nir_sampled = nir_block[mask_bool][::sample_step]
|
||||
d_sampled = d_block[mask_bool][::sample_step]
|
||||
nir_samples.append(nir_sampled)
|
||||
d_samples.append(d_sampled)
|
||||
sample_count += nir_sampled.size
|
||||
|
||||
# 显式释放块内存
|
||||
del nir_block, lower_block, upper_block, oxy_block, d_block, mask_block
|
||||
|
||||
# 汇总
|
||||
if sample_count == 0:
|
||||
self.R_min = 0.0
|
||||
self.D_max = 1.0
|
||||
else:
|
||||
bsq_path = self.output_path
|
||||
|
||||
# 使用ENVI驱动(默认就是BSQ格式)
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
raise ValueError("无法创建ENVI格式文件,ENVI驱动不可用")
|
||||
|
||||
height, width, n_bands = corrected_array.shape
|
||||
# 创建ENVI格式数据集(会自动生成.hdr文件)
|
||||
dataset = driver.Create(bsq_path, width, height, n_bands, gdal.GDT_Float32)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
try:
|
||||
# 设置地理变换和投影
|
||||
if geotransform:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
if projection:
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
# 写入每个波段(BSQ格式:按波段顺序存储)
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(corrected_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
# 检查.hdr文件是否已创建
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"头文件已保存至: {hdr_path}")
|
||||
all_nir = np.concatenate(nir_samples)
|
||||
all_d = np.concatenate(d_samples)
|
||||
self.R_min = float(np.percentile(all_nir, 5, method='nearest'))
|
||||
self.D_max = float(all_d.max())
|
||||
del all_nir, all_d
|
||||
|
||||
print(f"[Kutser] 全局 R_min={self.R_min:.4f}, D_max={self.D_max:.4f}")
|
||||
|
||||
def _compute_G_list(self):
|
||||
"""
|
||||
计算全局G值列表(每个波段)
|
||||
|
||||
G = G_max - G_min(所有水域像素的极值差异)
|
||||
使用全分辨率扫描,但逐波段读取,每波段内存 ≈ block_size²
|
||||
"""
|
||||
print(f"[Kutser] 计算全局G值列表(n_bands={self.n_bands})...")
|
||||
water_mask = self._load_water_mask()
|
||||
|
||||
# 初始化G_max和G_min为极值
|
||||
g_max = np.full(self.n_bands, -np.inf, dtype=np.float32)
|
||||
g_min = np.full(self.n_bands, np.inf, dtype=np.float32)
|
||||
|
||||
# 逐块扫描
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
block_height = y_end - y_off
|
||||
|
||||
# 读取所有波段的当前块
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
block = band.ReadAsArray(0, y_off, self.width, block_height).astype(np.float32)
|
||||
band = None
|
||||
|
||||
if water_mask is not None:
|
||||
mask_block = water_mask[y_off:y_end, :]
|
||||
mask_bool = mask_block.astype(bool)
|
||||
if mask_bool.any():
|
||||
band_masked = block[mask_bool]
|
||||
g_max[b] = max(g_max[b], band_masked.max())
|
||||
g_min[b] = min(g_min[b], band_masked.min())
|
||||
else:
|
||||
g_max[b] = max(g_max[b], block.max())
|
||||
g_min[b] = min(g_min[b], block.min())
|
||||
|
||||
del block
|
||||
|
||||
self.G_list = (g_max - g_min).tolist()
|
||||
print(f"[Kutser] G值范围: min={min(self.G_list):.4f}, max={max(self.G_list):.4f}")
|
||||
|
||||
def _process_block(self, x_off, y_off, x_size, y_size):
|
||||
"""
|
||||
处理单个分块:读取数据 -> 计算D -> 逐波段校正 -> 返回块结果
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: 校正后的波段列表(每波段形状为 y_size x x_size)
|
||||
"""
|
||||
# 读取氧吸收相关波段
|
||||
lower_band = self.dataset.GetRasterBand(self.lower_oxy + 1)
|
||||
lower_block = lower_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
lower_band = None
|
||||
|
||||
upper_band = self.dataset.GetRasterBand(self.upper_oxy + 1)
|
||||
upper_block = upper_band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
upper_band = None
|
||||
|
||||
oxy_band_obj = self.dataset.GetRasterBand(self.oxy_band + 1)
|
||||
oxy_block = oxy_band_obj.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
oxy_band_obj = None
|
||||
|
||||
# 计算D
|
||||
D = (lower_block + upper_block) * 0.5 - oxy_block
|
||||
|
||||
# 避免除零
|
||||
if self.D_max == 0:
|
||||
D_normalized = np.zeros_like(D)
|
||||
else:
|
||||
print(f"校正后的图像已保存至: {bsq_path} (BSQ格式)")
|
||||
print(f"警告: 未检测到.hdr文件,但GDAL应该已自动创建")
|
||||
D_normalized = D / self.D_max
|
||||
|
||||
# 释放临时块
|
||||
del lower_block, upper_block, oxy_block, D
|
||||
|
||||
# 获取当前块的水域掩膜
|
||||
water_mask = self._load_water_mask()
|
||||
if water_mask is not None:
|
||||
y_end = y_off + y_size
|
||||
x_end = x_off + x_size
|
||||
mask_block = water_mask[y_off:y_end, x_off:x_end].astype(bool)
|
||||
else:
|
||||
mask_block = None
|
||||
|
||||
# 逐波段处理
|
||||
corrected_bands = []
|
||||
for b in range(self.n_bands):
|
||||
band = self.dataset.GetRasterBand(b + 1)
|
||||
R = band.ReadAsArray(x_off, y_off, x_size, y_size).astype(np.float32)
|
||||
band = None
|
||||
|
||||
G = self.G_list[b]
|
||||
# 校正公式:R_corrected = R - G * D_normalized
|
||||
corrected = R - G * D_normalized
|
||||
|
||||
# 只对水域区域应用校正
|
||||
if mask_block is not None:
|
||||
corrected = np.where(mask_block, corrected, R)
|
||||
|
||||
corrected_bands.append(corrected)
|
||||
del R
|
||||
|
||||
# 释放D块
|
||||
del D_normalized
|
||||
|
||||
return corrected_bands
|
||||
|
||||
def get_corrected_bands(self):
|
||||
"""
|
||||
correction is done in reflectance
|
||||
|
||||
:return: 校正后的波段列表
|
||||
"""
|
||||
g_list = self.get_glint_G()
|
||||
D = self.get_depth_D()
|
||||
|
||||
# 获取水域掩膜(如果存在)
|
||||
water_mask_bool = self.water_mask.astype(bool) if self.water_mask is not None else None
|
||||
执行分块处理,返回校正后的波段列表
|
||||
|
||||
corrected_bands = []
|
||||
for band_number in range(self.n_bands): #iterate across bands
|
||||
G = g_list[band_number]
|
||||
R = self.im_aligned[:,:,band_number]
|
||||
# 优化:减少中间数组创建,直接计算
|
||||
corrected_band = R - G * D
|
||||
|
||||
# 如果存在水域掩膜,只对水域区域应用校正
|
||||
if water_mask_bool is not None:
|
||||
corrected_band = np.where(water_mask_bool, corrected_band, R)
|
||||
|
||||
corrected_bands.append(corrected_band)
|
||||
|
||||
# 如果提供了输出路径,保存结果
|
||||
if self.output_path is not None:
|
||||
self._save_corrected_bands(corrected_bands)
|
||||
|
||||
return corrected_bands
|
||||
内存峰值 ≈ 单波段块大小 + 几个辅助数组 ≈ 1000×1000×4B × 3 ≈ 12MB
|
||||
"""
|
||||
if self.output_path is None:
|
||||
raise ValueError("output_path 必须提供,分块处理需要直接写入文件")
|
||||
|
||||
# Step 1: 扫描全局统计量(R_min, D_max)
|
||||
self._scan_global_stats(sample_step=20)
|
||||
|
||||
# Step 2: 计算全局G列表
|
||||
self._compute_G_list()
|
||||
|
||||
# Step 3: 创建输出文件
|
||||
output_dir = os.path.dirname(self.output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
base_path, ext = os.path.splitext(self.output_path)
|
||||
bsq_path = base_path + '.bsq' if ext.lower() != '.bsq' else self.output_path
|
||||
|
||||
# 获取地理信息
|
||||
geotransform = self.dataset.GetGeoTransform()
|
||||
projection = self.dataset.GetProjection()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
out_dataset = driver.Create(bsq_path, self.width, self.height,
|
||||
self.n_bands, gdal.GDT_Float32)
|
||||
if out_dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {bsq_path}")
|
||||
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
# Step 4: 分块处理
|
||||
n_blocks_x = (self.width + self.block_size - 1) // self.block_size
|
||||
n_blocks_y = (self.height + self.block_size - 1) // self.block_size
|
||||
total_blocks = n_blocks_x * n_blocks_y
|
||||
|
||||
print(f"[Kutser] 开始分块处理,共 {total_blocks} 块 ({n_blocks_x}×{n_blocks_y}),块大小={self.block_size}")
|
||||
|
||||
block_idx = 0
|
||||
for y_off in range(0, self.height, self.block_size):
|
||||
y_end = min(y_off + self.block_size, self.height)
|
||||
y_size = y_end - y_off
|
||||
|
||||
for x_off in range(0, self.width, self.block_size):
|
||||
x_end = min(x_off + self.block_size, self.width)
|
||||
x_size = x_end - x_off
|
||||
block_idx += 1
|
||||
|
||||
print(f"[Kutser] 处理块 {block_idx}/{total_blocks} (y={y_off}, x={x_off})")
|
||||
|
||||
# 处理当前块
|
||||
corrected_bands = self._process_block(x_off, y_off, x_size, y_size)
|
||||
|
||||
# 写入输出文件
|
||||
for b in range(self.n_bands):
|
||||
out_band = out_dataset.GetRasterBand(b + 1)
|
||||
out_band.WriteArray(corrected_bands[b], x_off, y_off)
|
||||
out_band.FlushCache()
|
||||
|
||||
del corrected_bands
|
||||
|
||||
out_dataset = None
|
||||
self.dataset = None
|
||||
|
||||
# 检查.hdr文件
|
||||
hdr_path = bsq_path + '.hdr'
|
||||
if os.path.exists(hdr_path):
|
||||
print(f"[Kutser] 校正完成,已保存至: {bsq_path}")
|
||||
else:
|
||||
print(f"[Kutser] 校正完成,已保存至: {bsq_path}(警告: 未检测到.hdr文件)")
|
||||
|
||||
# 返回空列表(结果已直接写入文件)
|
||||
return []
|
||||
|
||||
def __del__(self):
|
||||
if self.dataset is not None:
|
||||
self.dataset = None
|
||||
File diff suppressed because it is too large
Load Diff
@ -19,6 +19,7 @@ from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from joblib import parallel_backend
|
||||
# 第三方模型导入
|
||||
# try:
|
||||
# import lightgbm as lgb
|
||||
@ -39,14 +40,13 @@ CB_AVAILABLE = False # 注释掉catboost
|
||||
import sys
|
||||
import os
|
||||
|
||||
# PyInstaller 打包环境感知:EXE 模式下强制单核,防止 Windows 派生无限重启
|
||||
is_frozen_env = getattr(sys, 'frozen', False)
|
||||
safe_n_jobs = 1 if is_frozen_env else -1
|
||||
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
|
||||
def _sklearn_parallel_n_jobs() -> int:
|
||||
"""PyInstaller 等打包环境下,joblib loky 会再次启动当前 exe,出现多个同名进程。"""
|
||||
return 1 if getattr(sys, "frozen", False) else -1
|
||||
|
||||
|
||||
class WaterQualityModelingBatch:
|
||||
"""水质参数反演批量建模类"""
|
||||
|
||||
@ -642,26 +642,25 @@ class WaterQualityModelingBatch:
|
||||
# 网格搜索 - 使用KFold代替StratifiedKFold
|
||||
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
|
||||
_n_jobs = _sklearn_parallel_n_jobs()
|
||||
grid_search = GridSearchCV(
|
||||
base_model,
|
||||
config['params'],
|
||||
cv=cv_strategy,
|
||||
scoring=scoring,
|
||||
n_jobs=_n_jobs,
|
||||
n_jobs=safe_n_jobs,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
# 在训练集上训练模型
|
||||
# with parallel_backend("threading", n_jobs=-1):
|
||||
# grid_search.fit(X_train, y_train)
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
# 获取最佳模型
|
||||
best_model = grid_search.best_estimator_
|
||||
|
||||
# 交叉验证评估(在训练集上)
|
||||
cv_scores = cross_val_score(
|
||||
best_model, X_train, y_train, cv=cv_strategy, scoring=scoring, n_jobs=_n_jobs
|
||||
)
|
||||
cv_scores = cross_val_score(best_model, X_train, y_train, cv=cv_strategy, scoring=scoring)
|
||||
|
||||
# 计算训练集上的回归指标
|
||||
y_train_pred = best_model.predict(X_train)
|
||||
|
||||
@ -40,16 +40,19 @@ class WaterQualityInference:
|
||||
self.best_model_info = None
|
||||
self.loaded_model_data = None
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
加载sampling生成的CSV数据
|
||||
加载sampling生成的CSV数据(兼容 WQI 增强版 CSV)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径,前两列为经纬度,其余列为光谱数据
|
||||
csv_path: CSV文件路径
|
||||
旧版:x_coord,y_coord,pixel_x,pixel_y,波长...
|
||||
新版:x_coord,y_coord,WQI_...,波长...
|
||||
|
||||
Returns:
|
||||
coords: 经纬度数据 (DataFrame)
|
||||
spectra: 光谱数据 (DataFrame)
|
||||
coords: 经纬度数据 (DataFrame, 2列)
|
||||
spectra: 纯光谱数据 (DataFrame, 跳过 WQI 列)
|
||||
wqi_df: WQI 指数列 (DataFrame, 0或45列)
|
||||
"""
|
||||
print(f"正在加载采样数据: {csv_path}")
|
||||
|
||||
@ -71,15 +74,35 @@ class WaterQualityInference:
|
||||
coords = data.iloc[:, :2].copy()
|
||||
coords.columns = ['longitude', 'latitude']
|
||||
|
||||
# 从第5列开始为光谱数据(跳过第2、3、4列的其他信息)
|
||||
spectra = data.iloc[:, 4:].copy()
|
||||
# 动态识别光谱列(兼容 sampling_spectra.csv 列顺序变更)
|
||||
# 列名约定:波长为纯数字字符串如 "374.285004";WQI 为 "WQI_xxx" 前缀
|
||||
# 旧版 CSV(无WQI):x_coord,y_coord,pixel_x,pixel_y,波长... → 取 [4:]
|
||||
# 新版 CSV(有WQI):x_coord,y_coord,WQI_...,波长... → 过滤 WQI 列后取光谱
|
||||
all_cols = list(data.columns)
|
||||
spectral_col_indices = []
|
||||
wqi_col_indices = []
|
||||
for i, col in enumerate(all_cols):
|
||||
col_str = str(col)
|
||||
if col_str.startswith('WQI_'):
|
||||
wqi_col_indices.append(i)
|
||||
elif col_str.replace('.', '').lstrip('-').isdigit():
|
||||
# 波长列:纯数字字符串
|
||||
spectral_col_indices.append(i)
|
||||
else:
|
||||
# 其他元数据列(x_coord/y_coord/pixel_x/pixel_y),由 coords 接收
|
||||
pass
|
||||
|
||||
# 光谱列 = 纯数字列(WQI 已被排除)
|
||||
spectra = data.iloc[:, spectral_col_indices].copy() if spectral_col_indices else data.iloc[:, 4:].copy()
|
||||
# WQI 列(用于追加到预测结果输出)
|
||||
wqi_df = data.iloc[:, wqi_col_indices].copy() if wqi_col_indices else pd.DataFrame()
|
||||
|
||||
print(f" 经纬度数据形状: {coords.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape} (自动识别波长列,排除 {len(wqi_col_indices)} 个WQI列)")
|
||||
print(f" 经纬度范围: 经度[{coords['longitude'].min():.6f}, {coords['longitude'].max():.6f}], "
|
||||
f"纬度[{coords['latitude'].min():.6f}, {coords['latitude'].max():.6f}]")
|
||||
|
||||
return coords, spectra
|
||||
return coords, spectra, wqi_df
|
||||
|
||||
def random(self, data, label, test_ratio=0.2, random_state=123):
|
||||
"""
|
||||
@ -519,6 +542,69 @@ class WaterQualityInference:
|
||||
print(f"正在应用预处理方法: {actual_preprocess_method}")
|
||||
print(f"原始光谱数据形状: {spectra.shape}")
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 95 维(WQI 指数) ----
|
||||
# 触发条件:模型期望 n_features_in_ 个特征,但当前 spectra 列数不足
|
||||
# 原因:training_spectra.csv 含 50 光谱 + 45 WQI;sampling_spectra.csv 只有 50 光谱
|
||||
# 做法:与训练端(calculate_all_indices)完全一致的算法列表,实时补全缺失的 45 个 WQI 列
|
||||
model = self.loaded_model_data['model']
|
||||
expected_features = getattr(model, 'n_features_in_', None)
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 n_features_in_ 维(WQI 指数) ----
|
||||
if expected_features is not None and spectra.shape[1] < expected_features:
|
||||
print(f"[特征补全] 检测到特征缺口:当前 {spectra.shape[1]} 列 < 模型期望 {expected_features} 列,"
|
||||
f"正在从光谱数据实时计算 WQI 指数...")
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
calc = WaterQualityIndexCalculator()
|
||||
|
||||
# 提取纯计算方法(排除 find_closest_wavelength 和 calculate_all_indices,
|
||||
# 以及不返回 Series 的辅助方法)
|
||||
algorithm_methods = []
|
||||
for m in dir(calc):
|
||||
if m.startswith('_'):
|
||||
continue
|
||||
if m in ['find_closest_wavelength', 'calculate_all_indices']:
|
||||
continue
|
||||
attr = getattr(calc, m)
|
||||
if callable(attr):
|
||||
algorithm_methods.append(m)
|
||||
|
||||
original_col_count = spectra.shape[1]
|
||||
for algo_name in algorithm_methods:
|
||||
try:
|
||||
algo_func = getattr(calc, algo_name)
|
||||
result = algo_func(spectra)
|
||||
# 只追加返回 Series 且长度为样本数的合法结果
|
||||
if isinstance(result, pd.Series) and len(result) == len(spectra):
|
||||
spectra[algo_name] = result.values
|
||||
else:
|
||||
spectra[algo_name] = np.nan
|
||||
except Exception:
|
||||
spectra[algo_name] = np.nan
|
||||
|
||||
print(f"[特征补全] 完成!光谱列已扩充至 {spectra.shape[1]} 列"
|
||||
f"(追加了 {spectra.shape[1] - original_col_count} 个 WQI 指数)")
|
||||
except Exception as e:
|
||||
print(f"[特征补全] 失败,将使用原始光谱特征: {e}")
|
||||
|
||||
# ---- 防线 1:强制维度对齐(物理截断)----
|
||||
if expected_features is not None and spectra.shape[1] > expected_features:
|
||||
print(f"[精准对齐] 正在将 {spectra.shape[1]} 维特征截断为模型要求的 {expected_features} 维")
|
||||
spectra = spectra.iloc[:, :expected_features]
|
||||
elif expected_features is not None and spectra.shape[1] < expected_features:
|
||||
# 维度不足时填充 0
|
||||
padding_cols = expected_features - spectra.shape[1]
|
||||
for i in range(padding_cols):
|
||||
spectra[f'_padding_{i}'] = 0.0
|
||||
print(f"[精准对齐] 特征不足,填充 {padding_cols} 列 0")
|
||||
|
||||
# ---- 防线 2:彻底清洗无穷大数值----
|
||||
# 防止 WQI 计算中除零/溢出产生 np.inf / -np.inf 导致预处理崩溃
|
||||
spectra = spectra.replace([np.inf, -np.inf], np.nan)
|
||||
spectra = spectra.fillna(0)
|
||||
|
||||
print(f"[特征对齐] 最终输入维度: {spectra.shape}")
|
||||
|
||||
try:
|
||||
# 应用预处理
|
||||
spectra_processed = Preprocessing(actual_preprocess_method, spectra)
|
||||
@ -555,7 +641,13 @@ class WaterQualityInference:
|
||||
print(f"输入数据形状: {spectra_processed.shape}")
|
||||
|
||||
try:
|
||||
predictions = model.predict(spectra_processed)
|
||||
# 清洗 NaN / Inf,防止 SVR 等模型报错
|
||||
spectra_clean = np.nan_to_num(spectra_processed, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
if np.any(np.isnan(spectra_clean)) or np.any(np.isinf(spectra_clean)):
|
||||
print("警告: 清洗后数据中仍存在 NaN/Inf,已重置为 0")
|
||||
spectra_clean = np.nan_to_num(spectra_clean, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
predictions = model.predict(spectra_clean)
|
||||
print(f"预测完成,结果形状: {predictions.shape}")
|
||||
print(f"预测值范围: [{np.min(predictions):.4f}, {np.max(predictions):.4f}]")
|
||||
print(f"预测值统计: 均值={np.mean(predictions):.4f}, 标准差={np.std(predictions):.4f}")
|
||||
@ -567,7 +659,8 @@ class WaterQualityInference:
|
||||
raise
|
||||
|
||||
def save_predictions(self, coords: pd.DataFrame, predictions: np.ndarray,
|
||||
output_path: str, prediction_column: str = 'prediction'):
|
||||
output_path: str, prediction_column: str = 'prediction',
|
||||
wqi_columns: Optional[pd.DataFrame] = None):
|
||||
"""
|
||||
保存预测结果
|
||||
|
||||
@ -576,11 +669,15 @@ class WaterQualityInference:
|
||||
predictions: 预测结果
|
||||
output_path: 输出文件路径
|
||||
prediction_column: 预测列名称
|
||||
wqi_columns: Optional[pd.DataFrame] = None
|
||||
"""
|
||||
print(f"正在保存预测结果到: {output_path}")
|
||||
|
||||
# 创建结果DataFrame
|
||||
result_df = coords.copy()
|
||||
# 追加 WQI 水质指数列(如 sampling_spectra.csv 注入了 45 列指数)
|
||||
if wqi_columns is not None and not wqi_columns.empty:
|
||||
result_df = pd.concat([result_df, wqi_columns.reset_index(drop=True)], axis=1)
|
||||
result_df[prediction_column] = predictions
|
||||
|
||||
# 确保输出目录存在
|
||||
@ -653,10 +750,10 @@ class WaterQualityInference:
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
|
||||
# 2. 加载采样数据
|
||||
# 2. 加载采样数据(coords=坐标, spectra=纯光谱, wqi_df=45个WQI指数列)
|
||||
print("\n步骤2: 加载采样数据")
|
||||
print("-" * 40)
|
||||
coords, spectra = self.load_sampling_data(sampling_csv_path)
|
||||
coords, spectra, wqi_df = self.load_sampling_data(sampling_csv_path)
|
||||
|
||||
# 3. 数据预处理
|
||||
print("\n步骤3: 数据预处理")
|
||||
@ -668,10 +765,11 @@ class WaterQualityInference:
|
||||
print("-" * 40)
|
||||
predictions = self.predict(spectra_processed)
|
||||
|
||||
# 5. 保存预测结果
|
||||
# 5. 保存预测结果(透传 WQI 列至最终输出文件)
|
||||
print("\n步骤5: 保存预测结果")
|
||||
print("-" * 40)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path, prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path,
|
||||
prediction_column, wqi_df)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("推理流程完成!")
|
||||
@ -741,10 +839,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"prediction_{csv_file.name}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[csv_file.name] = {
|
||||
'output_file': str(output_file),
|
||||
@ -902,10 +1001,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"{file_stem}{file_ext}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[file_stem] = {
|
||||
'input_file': str(csv_file),
|
||||
|
||||
20
src/core/steps/__init__.py
Normal file
20
src/core/steps/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""业务步骤层模块"""
|
||||
|
||||
from src.core.steps.water_mask_step import WaterMaskStep
|
||||
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
from src.core.steps.modeling_step import ModelingStep
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
from src.core.steps.mapping_step import MappingStep
|
||||
|
||||
__all__ = [
|
||||
"WaterMaskStep",
|
||||
"GlintDetectionStep",
|
||||
"GlintRemovalStep",
|
||||
"DataPreparationStep",
|
||||
"ModelingStep",
|
||||
"PredictionStep",
|
||||
"MappingStep",
|
||||
]
|
||||
184
src/core/steps/data_preparation_step.py
Normal file
184
src/core/steps/data_preparation_step.py
Normal file
@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备步骤
|
||||
|
||||
包含 step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataPreparationStep:
|
||||
"""数据准备步骤"""
|
||||
|
||||
# ---- Step 4: 处理CSV文件 ----
|
||||
|
||||
@staticmethod
|
||||
def process_csv(
|
||||
csv_path: str,
|
||||
output_dir: Union[str, Path] = "./4_processed_data",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""处理CSV文件(筛选剔除异常值)"""
|
||||
from src.preprocessing.process_water_quality_data import process_water_quality_data
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "processed_data.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤4", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤4: 处理CSV文件,筛选剔除异常值")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的处理后CSV文件,直接使用: {output_path}")
|
||||
notify("skipped", f"处理后的CSV文件已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
process_water_quality_data(csv_path, output_path)
|
||||
notify("completed", f"处理后的CSV文件已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 5: 提取训练样本点光谱 ----
|
||||
|
||||
@staticmethod
|
||||
def extract_training_spectra(
|
||||
deglint_img_path: Optional[str] = None,
|
||||
radius: int = 5,
|
||||
source_epsg: int = 4326,
|
||||
csv_path: Optional[str] = None,
|
||||
boundary_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./5_training_spectra",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||
from src.core.glint_removal.get_spectral import get_spectral_in_coor
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "training_spectra.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤5: 提取训练样本点的平均光谱")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if deglint_img_path is None:
|
||||
raise ValueError("必须提供 deglint_img_path 参数")
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
|
||||
notify("skipped", f"训练光谱数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
# 确保水体掩膜存在
|
||||
final_boundary_path = boundary_path
|
||||
if final_boundary_path is None and water_mask_path is not None:
|
||||
final_boundary_path = water_mask_path
|
||||
|
||||
# 【新增安全防护】智能拦截矢量 .shp,强制替换为步骤 1 生成的栅格 .dat
|
||||
if final_boundary_path and str(final_boundary_path).lower().endswith('.shp'):
|
||||
# 向上追溯查找 1_water_mask 目录下的 dat 替身
|
||||
possible_dat = Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||
if not possible_dat.exists() and output_path:
|
||||
possible_dat = Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||
|
||||
if possible_dat.exists():
|
||||
print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp,自动切换为已生成的栅格掩膜: {possible_dat}")
|
||||
final_boundary_path = str(possible_dat)
|
||||
else:
|
||||
print(f"⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")
|
||||
|
||||
flare_path = glint_mask_path
|
||||
if flare_path:
|
||||
print(f"光谱提取使用耀斑掩膜: {flare_path}")
|
||||
|
||||
get_spectral_in_coor(
|
||||
deglint_img_path, csv_path, output_path,
|
||||
radius=radius, flare_path=flare_path,
|
||||
boundary_path=final_boundary_path, source_epsg=source_epsg
|
||||
)
|
||||
|
||||
notify("completed", f"训练光谱数据已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 5.5: 计算水质光谱指数 ----
|
||||
|
||||
@staticmethod
|
||||
def calculate_water_quality_indices(
|
||||
training_spectra_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
output_dir: Union[str, Path] = "./6_water_quality_indices",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤5.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤5.5: 计算水质光谱指数(使用band_math方法)")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过水质指数计算(enabled=False)。")
|
||||
notify("skipped", "跳过水质指数计算")
|
||||
return None
|
||||
|
||||
if training_spectra_path is None:
|
||||
raise ValueError("必须提供 training_spectra_path 参数")
|
||||
if formula_csv_file is None:
|
||||
raise ValueError("必须提供 formula_csv_file 参数")
|
||||
|
||||
if output_file:
|
||||
output_path = str(Path(output_file))
|
||||
else:
|
||||
output_path = str(output_dir / "water_quality_indices.csv")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
|
||||
notify("skipped", f"水质指数数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
from src.utils.band_math import BandMathCalculator
|
||||
|
||||
calculator = BandMathCalculator(training_spectra_path)
|
||||
result_df = calculator.process_formulas_from_csv(
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
output_file=output_path
|
||||
)
|
||||
|
||||
if result_df is None:
|
||||
raise ValueError("计算水质指数失败,请检查公式CSV文件格式")
|
||||
|
||||
notify("completed", f"水质指数已保存: {output_path}")
|
||||
return output_path
|
||||
113
src/core/steps/glint_detection_step.py
Normal file
113
src/core/steps/glint_detection_step.py
Normal file
@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤2: 耀斑区域检测
|
||||
|
||||
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
|
||||
class GlintDetectionStep:
|
||||
"""耀斑区域检测步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
img_path: str,
|
||||
glint_wave: float = 750.0,
|
||||
method: str = "otsu",
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95.0,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
multi_band_waves: Optional[List[float]] = None,
|
||||
sub_method: str = "zscore",
|
||||
weights: Optional[List[float]] = None,
|
||||
max_area: Optional[int] = None,
|
||||
buffer_size: Optional[int] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
glint_dir: Union[str, Path] = "./2_glint",
|
||||
callback: Optional[callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行耀斑区域检测
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
glint_wave: 用于耀斑检测的波段波长(nm)
|
||||
method: 检测方法 ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
|
||||
z_threshold: Z-score 方法阈值(默认 2.5)
|
||||
percentile: 百分位数阈值(默认 95.0)
|
||||
iqr_multiplier: IQR 倍数(默认 1.5)
|
||||
window_size: 自适应阈值窗口大小(默认 15)
|
||||
multi_band_waves: 多波段方法的波长列表,如 [750, 800, 850]
|
||||
sub_method: 多波段方法的子方法(默认 'zscore')
|
||||
weights: 多波段方法的权重列表(None 表示等权重)
|
||||
max_area: 最大连通域面积阈值(像素),超过则过滤
|
||||
buffer_size: 岸边缓冲区大小(像素),用于去除岸边附近错误掩膜
|
||||
water_mask_path: 水域掩膜文件路径(dat 格式优先)
|
||||
glint_dir: 工作目录
|
||||
callback: 回调函数
|
||||
|
||||
Returns:
|
||||
耀斑掩膜文件路径 (.dat)
|
||||
"""
|
||||
from src.utils.find_severe_glint_area import find_severe_glint_area
|
||||
|
||||
glint_dir = Path(glint_dir)
|
||||
glint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤2", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤2: 找到耀斑区域")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 确定水体掩膜路径
|
||||
if water_mask_path is not None and Path(water_mask_path).exists():
|
||||
final_water_mask_path = water_mask_path
|
||||
else:
|
||||
final_water_mask_path = None
|
||||
|
||||
output_path = str(glint_dir / "severe_glint_area.dat")
|
||||
|
||||
# 跳过已存在的文件
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
|
||||
notify("skipped", f"耀斑掩膜已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
# 构建检测参数字典
|
||||
kwargs = {
|
||||
"method": method,
|
||||
"z_threshold": z_threshold,
|
||||
"percentile": percentile,
|
||||
"iqr_multiplier": iqr_multiplier,
|
||||
"window_size": window_size,
|
||||
}
|
||||
if method == "multi_band":
|
||||
if multi_band_waves is not None:
|
||||
kwargs["multi_band_waves"] = multi_band_waves
|
||||
if sub_method is not None:
|
||||
kwargs["sub_method"] = sub_method
|
||||
if weights is not None:
|
||||
kwargs["weights"] = weights
|
||||
if max_area is not None:
|
||||
kwargs["max_area"] = max_area
|
||||
if buffer_size is not None:
|
||||
kwargs["buffer_size"] = buffer_size
|
||||
|
||||
glint_mask_path = find_severe_glint_area(
|
||||
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
|
||||
)
|
||||
|
||||
print(f"耀斑掩膜已生成: {glint_mask_path}")
|
||||
print(f"使用检测方法: {method}")
|
||||
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
|
||||
return glint_mask_path
|
||||
375
src/core/steps/glint_removal_step.py
Normal file
375
src/core/steps/glint_removal_step.py
Normal file
@ -0,0 +1,375 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤3: 去除耀斑
|
||||
|
||||
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
|
||||
|
||||
每种方法都会:
|
||||
1. 准备水域掩膜(支持 shp 自动转 dat)
|
||||
2. 调用对应的算法类执行处理
|
||||
3. 复制 hdr 文件到输出影像
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
|
||||
"""将底层硬编码生成的 .bsq + .hdr 文件对重命名到用户指定的 output_path
|
||||
|
||||
使用 os.remove + os.rename 确保原子覆盖(不等 os.replace 的跨设备行为),
|
||||
resolve() 断路防止同路径 self-rename 报错。
|
||||
|
||||
Returns:
|
||||
dest_bsq 路径
|
||||
"""
|
||||
src_bsq_p = Path(src_bsq)
|
||||
src_hdr_p = Path(src_hdr)
|
||||
dest_bsq_p = Path(dest_bsq)
|
||||
dest_hdr_p = Path(dest_hdr)
|
||||
|
||||
if str(src_bsq_p.resolve()) == str(dest_bsq_p.resolve()):
|
||||
return dest_bsq
|
||||
|
||||
if dest_bsq_p.exists():
|
||||
os.remove(dest_bsq_p)
|
||||
if dest_hdr_p.exists():
|
||||
os.remove(dest_hdr_p)
|
||||
|
||||
if src_bsq_p.exists():
|
||||
os.rename(src_bsq_p, dest_bsq_p)
|
||||
if src_hdr_p.exists():
|
||||
os.rename(src_hdr_p, dest_hdr_p)
|
||||
|
||||
return dest_bsq
|
||||
|
||||
|
||||
class GlintRemovalStep:
|
||||
"""去除耀斑步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
img_path: str,
|
||||
method: str = "subtract_nir",
|
||||
start_wave: Optional[float] = None,
|
||||
end_wave: Optional[float] = None,
|
||||
json_path: Optional[str] = None,
|
||||
left_shoulder_wave: Optional[float] = None,
|
||||
valley_wave: Optional[float] = None,
|
||||
right_shoulder_wave: Optional[float] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
interpolated_img_path: Optional[str] = None,
|
||||
interpolate_zeros: bool = False,
|
||||
interpolation_method: str = "nearest",
|
||||
enabled: bool = True,
|
||||
# Kutser 参数
|
||||
kutser_shp_path: Optional[str] = None,
|
||||
oxy_band: int = 38,
|
||||
lower_oxy: int = 36,
|
||||
upper_oxy: int = 49,
|
||||
nir_band: int = 47,
|
||||
# Goodman 参数
|
||||
nir_lower: int = 25,
|
||||
nir_upper: int = 37,
|
||||
goodman_A: float = 0.000019,
|
||||
goodman_B: float = 0.1,
|
||||
# Hedley 参数
|
||||
hedley_shp_path: Optional[str] = None,
|
||||
hedley_nir_band: int = 47,
|
||||
# SUGAR 参数
|
||||
sugar_bounds: Optional[List[tuple]] = None,
|
||||
sugar_sigma: float = 1.0,
|
||||
sugar_estimate_background: bool = True,
|
||||
sugar_glint_mask_method: str = "cdf",
|
||||
sugar_iter: Optional[int] = 3,
|
||||
sugar_termination_thresh: float = 20.0,
|
||||
# 内部工具函数
|
||||
_get_image_geo_info=None,
|
||||
_load_image_as_array=None,
|
||||
_save_bands_as_image=None,
|
||||
_copy_hdr_info=None,
|
||||
_prepare_water_mask_for_algorithm=None,
|
||||
_interpolate_zero_pixels_batch=None,
|
||||
deglint_dir: Union[str, Path] = "./3_deglint",
|
||||
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||
callback: Optional[Callable] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行去除耀斑处理
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
method: 去耀斑方法
|
||||
...(其余参数同主类 step3_remove_glint)
|
||||
|
||||
Returns:
|
||||
去除耀斑后的影像文件路径
|
||||
"""
|
||||
from src.core.glint_removal.Kutser import Kutser
|
||||
from src.core.glint_removal.Goodman import Goodman
|
||||
from src.core.glint_removal.Hedley import Hedley
|
||||
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info as _default_get_geo,
|
||||
load_image_as_array as _default_load,
|
||||
save_bands_as_image as _default_save_bands,
|
||||
copy_hdr_info as _default_copy_hdr,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm as _default_prepare,
|
||||
)
|
||||
|
||||
# 使用提供的函数或默认函数
|
||||
if _get_image_geo_info is None:
|
||||
_get_image_geo_info = _default_get_geo
|
||||
if _load_image_as_array is None:
|
||||
_load_image_as_array = _default_load
|
||||
if _save_bands_as_image is None:
|
||||
_save_bands_as_image = _default_save_bands
|
||||
if _copy_hdr_info is None:
|
||||
_copy_hdr_info = _default_copy_hdr
|
||||
if _prepare_water_mask_for_algorithm is None:
|
||||
_prepare_water_mask_for_algorithm = _default_prepare
|
||||
|
||||
deglint_dir = Path(deglint_dir)
|
||||
deglint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤3", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤3: 去除耀斑")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 方法名标准化
|
||||
raw_method = str(method).lower()
|
||||
if "kutser" in raw_method:
|
||||
method = "kutser"
|
||||
elif "goodman" in raw_method:
|
||||
method = "goodman"
|
||||
elif "hedley" in raw_method:
|
||||
method = "hedley"
|
||||
elif "sugar" in raw_method:
|
||||
method = "sugar"
|
||||
|
||||
# 如果未启用,直接返回原始影像
|
||||
if not enabled:
|
||||
print("已设置跳过去除耀斑(enabled=False),将直接使用原始影像。")
|
||||
notify("skipped", "跳过去耀斑,使用原始影像")
|
||||
return img_path
|
||||
|
||||
# ---- 确定水域掩膜 ----
|
||||
final_water_mask = water_mask
|
||||
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
|
||||
# shp 自动替换为 dat
|
||||
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||
if Path(dat_mask).exists():
|
||||
print(f"检测到输入掩膜为 .shp,自动替换为栅格掩膜: {dat_mask}")
|
||||
final_water_mask = dat_mask
|
||||
|
||||
if final_water_mask is None:
|
||||
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||
if Path(dat_mask_default).exists():
|
||||
final_water_mask = dat_mask_default
|
||||
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
|
||||
|
||||
# ---- 步骤3.1: 0值像素插值 ----
|
||||
if interpolate_zeros:
|
||||
print("\n" + "-" * 80)
|
||||
print("步骤3.1: 对0值像素进行插值")
|
||||
print("-" * 80)
|
||||
interp_start_time = time.time()
|
||||
|
||||
if _interpolate_zero_pixels_batch is None:
|
||||
from src.core.algorithms.interpolation.interpolator import (
|
||||
interpolate_zero_pixels_batch as _interp_batch,
|
||||
)
|
||||
_interpolate_zero_pixels_batch = _interp_batch
|
||||
|
||||
interp_result, _ = _interpolate_zero_pixels_batch(
|
||||
img_path=img_path,
|
||||
interpolation_method=interpolation_method,
|
||||
output_path=None,
|
||||
water_mask=final_water_mask,
|
||||
deglint_dir=str(deglint_dir),
|
||||
callback_progress=lambda msg: print(f" {msg}"),
|
||||
)
|
||||
img_path = interp_result
|
||||
interp_end_time = time.time()
|
||||
print(f"插值完成,使用插值后的影像: {img_path}")
|
||||
|
||||
# ---- 获取影像信息 ----
|
||||
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
|
||||
print(f"影像尺寸: {width} x {height} x {n_bands}")
|
||||
|
||||
mask_for_algorithm = _prepare_water_mask_for_algorithm(
|
||||
final_water_mask, (height, width), geotransform, projection, img_path
|
||||
)
|
||||
|
||||
# ==================== Kutser ====================
|
||||
if method == "kutser":
|
||||
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_kutser.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
# 将用户指定的 output_path 标准化为 .bsq 路径
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
kutser = Kutser(
|
||||
img_path,
|
||||
shp_path=None,
|
||||
oxy_band=oxy_band,
|
||||
lower_oxy=lower_oxy,
|
||||
upper_oxy=upper_oxy,
|
||||
NIR_band=nir_band,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
kutser.get_corrected_bands()
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"Kutser算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
# ==================== Goodman ====================
|
||||
elif method == "goodman":
|
||||
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_goodman.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
goodman = Goodman(
|
||||
img_path,
|
||||
NIR_lower=nir_lower,
|
||||
NIR_upper=nir_upper,
|
||||
A=goodman_A,
|
||||
B=goodman_B,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
corrected_bands = goodman.get_corrected_bands()
|
||||
|
||||
if not Path(hardcoded_bsq).exists():
|
||||
_save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
del corrected_bands
|
||||
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
|
||||
# ==================== Hedley ====================
|
||||
elif method == "hedley":
|
||||
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_hedley.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
hedley = Hedley(
|
||||
img_path,
|
||||
shp_path=None,
|
||||
NIR_band=hedley_nir_band,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
hedley.get_corrected_bands()
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"Hedley算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
# ==================== SUGAR ====================
|
||||
elif method == "sugar":
|
||||
glint_method_raw = str(sugar_glint_mask_method).lower()
|
||||
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
|
||||
sugar_glint_mask_method_fixed = "cdf"
|
||||
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
|
||||
sugar_glint_mask_method_fixed = "otsu"
|
||||
else:
|
||||
sugar_glint_mask_method_fixed = "cdf"
|
||||
|
||||
print(
|
||||
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
|
||||
)
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_sugar.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
if sugar_bounds is None:
|
||||
sugar_bounds = [(1, 2)]
|
||||
|
||||
correction_iterative(
|
||||
img_path,
|
||||
iter=sugar_iter,
|
||||
bounds=sugar_bounds,
|
||||
estimate_background=sugar_estimate_background,
|
||||
glint_mask_method=sugar_glint_mask_method_fixed,
|
||||
termination_thresh=sugar_termination_thresh,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"SUGAR算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
|
||||
)
|
||||
109
src/core/steps/mapping_step.py
Normal file
109
src/core/steps/mapping_step.py
Normal file
@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
成图步骤
|
||||
|
||||
包含 step9_generate_distribution_map
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable
|
||||
|
||||
|
||||
class MappingStep:
|
||||
"""成图步骤"""
|
||||
|
||||
@staticmethod
|
||||
def generate_distribution_map(
|
||||
prediction_csv_path: str,
|
||||
boundary_shp_path: str,
|
||||
output_image_path: Optional[str] = None,
|
||||
resolution: float = 30,
|
||||
input_crs: str = "EPSG:32651",
|
||||
output_crs: str = "EPSG:4326",
|
||||
show_sample_points: bool = False,
|
||||
base_map_tif: Optional[str] = None,
|
||||
use_distance_diffusion: bool = True,
|
||||
max_diffusion_distance: Optional[float] = None,
|
||||
diffusion_power: float = 2,
|
||||
diffusion_n_neighbors: int = 15,
|
||||
cmap: Optional[str] = None,
|
||||
expand_ratio: float = 0.05,
|
||||
output_dir: Union[str, Path] = "./14_visualization",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
根据采样点的坐标和反演的实测参数,通过插值方法得到水质参数可视化分布图
|
||||
|
||||
Args:
|
||||
prediction_csv_path: 预测结果CSV文件路径(前两列为经纬度,第三列为预测值)
|
||||
boundary_shp_path: 边界shapefile文件路径
|
||||
output_image_path: 输出图片路径(如果为None,自动生成)
|
||||
resolution: 插值网格分辨率(米)
|
||||
input_crs: 输入坐标系
|
||||
output_crs: 输出坐标系
|
||||
show_sample_points: 是否在图上显示采样点
|
||||
base_map_tif: 底图TIF路径
|
||||
use_distance_diffusion: 是否启用距离扩散补全边界
|
||||
max_diffusion_distance: 距离扩散最大距离(米)
|
||||
diffusion_power: 距离扩散幂参数
|
||||
diffusion_n_neighbors: 距离扩散最近邻数量
|
||||
cmap: 颜色映射名称(None表示自动识别)
|
||||
expand_ratio: 边界外扩比例(0-1之间)
|
||||
output_dir: 输出目录
|
||||
callback: 回调函数
|
||||
|
||||
Returns:
|
||||
可视化分布图文件路径
|
||||
"""
|
||||
from src.postprocessing.map import ContentMapper
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤9", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤9: 生成水质参数可视化分布图")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if output_image_path is None:
|
||||
csv_name = Path(prediction_csv_path).stem
|
||||
output_image_path = str(output_dir / f"{csv_name}_distribution.png")
|
||||
|
||||
if Path(output_image_path).exists():
|
||||
print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
|
||||
notify("skipped", f"可视化分布图已设置: {output_image_path}")
|
||||
return output_image_path
|
||||
|
||||
mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
|
||||
|
||||
mapper_kwargs = {
|
||||
"resolution": resolution,
|
||||
"show_sample_points": show_sample_points,
|
||||
"use_distance_diffusion": use_distance_diffusion,
|
||||
"diffusion_power": diffusion_power,
|
||||
"diffusion_n_neighbors": diffusion_n_neighbors,
|
||||
"expand_ratio": expand_ratio,
|
||||
}
|
||||
|
||||
optional_kwargs = {
|
||||
"base_map_tif": base_map_tif,
|
||||
"max_diffusion_distance": max_diffusion_distance,
|
||||
"cmap": cmap,
|
||||
}
|
||||
mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})
|
||||
|
||||
mapper.process_data(
|
||||
csv_file=prediction_csv_path,
|
||||
shp_file=boundary_shp_path,
|
||||
output_file=output_image_path,
|
||||
**mapper_kwargs,
|
||||
)
|
||||
|
||||
notify("completed", f"可视化分布图已保存: {output_image_path}")
|
||||
return output_image_path
|
||||
497
src/core/steps/modeling_step.py
Normal file
497
src/core/steps/modeling_step.py
Normal file
@ -0,0 +1,497 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
建模步骤
|
||||
|
||||
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 汉化 -> 英文 反向映射字典(UI 复选框显示文本 -> 底层算法键名)
|
||||
# ============================================================
|
||||
|
||||
# 模型名称:中文 (缩写) -> 英文键名
|
||||
MODEL_NAME_MAP = {
|
||||
"多元线性回归 (MLR)": "LinearRegression",
|
||||
"岭回归 (Ridge)": "Ridge",
|
||||
"套索回归 (Lasso)": "Lasso",
|
||||
"弹性网络 (ElasticNet)": "ElasticNet",
|
||||
"偏最小二乘 (PLSR)": "PLS",
|
||||
"决策树 (CART)": "DecisionTree",
|
||||
"随机森林 (RF)": "RF",
|
||||
"极端随机树 (ET)": "ExtraTrees",
|
||||
"极值梯度提升 (XGBoost)": "XGBoost",
|
||||
"轻量梯度提升 (LightGBM)": "LightGBM",
|
||||
"类别梯度提升 (CatBoost)": "CatBoost",
|
||||
"梯度提升树 (GBDT)": "GradientBoosting",
|
||||
"自适应提升 (AdaBoost)": "AdaBoost",
|
||||
"支持向量回归 (SVR)": "SVR",
|
||||
"K近邻回归 (KNN)": "KNN",
|
||||
"多层感知机 (BP神经网络)": "MLP",
|
||||
}
|
||||
|
||||
# 预处理方法:各种可能的中文变体 -> 标准键名
|
||||
PREPROC_NAME_MAP = {
|
||||
# 无处理
|
||||
"无 (None)": "None",
|
||||
"None": "None",
|
||||
# MMS
|
||||
"最小-最大归一化 (MMS)": "MMS",
|
||||
"MMS": "MMS",
|
||||
# SS
|
||||
"标度化 (SS)": "SS",
|
||||
"SS": "SS",
|
||||
# SNV
|
||||
"标准正态变换 (SNV)": "SNV",
|
||||
"SNV": "SNV",
|
||||
# MA
|
||||
"移动平均 (MA)": "MA",
|
||||
"MA": "MA",
|
||||
# SG
|
||||
"Savitzky-Golay (SG)": "SG",
|
||||
"SG": "SG",
|
||||
# MSC
|
||||
"多元散射校正 (MSC)": "MSC",
|
||||
"MSC": "MSC",
|
||||
# D1
|
||||
"一阶导数 (D1)": "D1",
|
||||
"D1": "D1",
|
||||
# D2
|
||||
"二阶导数 (D2)": "D2",
|
||||
"D2": "D2",
|
||||
# DT
|
||||
"去趋势 (DT)": "DT",
|
||||
"DT": "DT",
|
||||
# CT
|
||||
"中心化 (CT)": "CT",
|
||||
"CT": "CT",
|
||||
}
|
||||
|
||||
# 数据划分方法:各种可能的中文变体 -> 标准键名
|
||||
SPLIT_NAME_MAP = {
|
||||
"SPXY 算法 (考量X-Y空间)": "spxy",
|
||||
"spxy": "spxy",
|
||||
"KS 算法 (考量X空间)": "ks",
|
||||
"ks": "ks",
|
||||
"随机划分 (Random)": "random",
|
||||
"random": "random",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_model_names(model_names: List[str]) -> List[str]:
|
||||
"""清洗模型名称列表:将汉化显示文本还原为英文键名"""
|
||||
result = []
|
||||
for name in model_names:
|
||||
if name in MODEL_NAME_MAP:
|
||||
result.append(MODEL_NAME_MAP[name])
|
||||
else:
|
||||
# 已经是英文键名,直接保留
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
|
||||
"""清洗预处理方法列表:将汉化显示文本还原为标准键名"""
|
||||
result = []
|
||||
for method in methods:
|
||||
if method in PREPROC_NAME_MAP:
|
||||
result.append(PREPROC_NAME_MAP[method])
|
||||
else:
|
||||
# 已经是标准键名,直接保留
|
||||
result.append(method)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_split_methods(methods: List[str]) -> List[str]:
|
||||
"""清洗数据划分方法列表:将汉化显示文本还原为标准键名"""
|
||||
result = []
|
||||
for method in methods:
|
||||
if method in SPLIT_NAME_MAP:
|
||||
result.append(SPLIT_NAME_MAP[method])
|
||||
else:
|
||||
# 已经是标准键名,直接保留
|
||||
result.append(method)
|
||||
return result
|
||||
|
||||
|
||||
class ModelingStep:
|
||||
"""建模步骤"""
|
||||
|
||||
# ---- Step 6: 训练机器学习模型 ----
|
||||
|
||||
@staticmethod
|
||||
def train_models(
|
||||
feature_start_column: str = "374.285004",
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
model_names: Optional[List[str]] = None,
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
training_csv_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
) -> str:
|
||||
"""使用采样点光谱和实测值建立机器学习模型"""
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6: 训练机器学习模型")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("必须提供 training_csv_path 参数")
|
||||
|
||||
# 检查模型目录是否已有模型
|
||||
if output_dir.exists() and any(output_dir.iterdir()):
|
||||
has_models = False
|
||||
for item in output_dir.iterdir():
|
||||
if item.is_dir():
|
||||
model_files = (
|
||||
list(item.glob("*.pkl"))
|
||||
+ list(item.glob("*.joblib"))
|
||||
+ list(item.glob("*.h5"))
|
||||
)
|
||||
if model_files:
|
||||
has_models = True
|
||||
break
|
||||
if has_models:
|
||||
print(f"检测到已存在的模型文件,直接使用: {output_dir}")
|
||||
notify("skipped", f"模型目录已设置: {output_dir}")
|
||||
return str(output_dir)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
|
||||
if model_names is None:
|
||||
model_names = ["SVR", "RF", "Ridge", "Lasso"]
|
||||
if split_methods is None:
|
||||
split_methods = ["spxy", "ks", "random"]
|
||||
|
||||
# ---- 汉化清洗:将 UI 传来的中文/混合名称转换为底层英文键名 ----
|
||||
preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
|
||||
model_names = _normalize_model_names(model_names)
|
||||
split_methods = _normalize_split_methods(split_methods)
|
||||
|
||||
print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
|
||||
print(f"[参数清洗] 模型名称: {model_names}")
|
||||
print(f"[参数清洗] 划分方法: {split_methods}")
|
||||
|
||||
modeler = WaterQualityModelingBatch(str(output_dir))
|
||||
modeler.train_models_batch(
|
||||
csv_path=training_csv_path,
|
||||
feature_start_column=feature_start_column,
|
||||
preprocessing_methods=preprocessing_methods,
|
||||
model_names=model_names,
|
||||
split_methods=split_methods,
|
||||
cv_folds=cv_folds,
|
||||
)
|
||||
|
||||
print(f"模型训练完成,结果保存在: {output_dir}")
|
||||
|
||||
if _report_generator is not None:
|
||||
try:
|
||||
summary_path = _report_generator.generate_training_summary(str(output_dir))
|
||||
print(f"训练摘要报告已生成: {summary_path}")
|
||||
except Exception as e:
|
||||
print(f"生成训练摘要报告时出错: {e}")
|
||||
|
||||
notify("completed", f"模型训练完成: {output_dir}")
|
||||
return str(output_dir)
|
||||
|
||||
# ---- Step 6.5: 非经验统计回归模型训练 ----
|
||||
|
||||
@staticmethod
|
||||
def train_non_empirical_models(
|
||||
csv_path: Optional[str] = None,
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
algorithms: Optional[List[str]] = None,
|
||||
value_cols: Union[int, Dict[str, int]] = 0,
|
||||
spectral_start_col: int = 1,
|
||||
spectral_end_col: Optional[int] = None,
|
||||
window: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""非经验统计回归模型训练"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6.5: 非经验统计回归模型训练")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过非经验模型训练(enabled=False)。")
|
||||
notify("skipped", "跳过的经验模型训练")
|
||||
return {}
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
|
||||
if output_dir is not None:
|
||||
non_empirical_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["None"]
|
||||
if algorithms is None:
|
||||
algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]
|
||||
|
||||
if isinstance(value_cols, int):
|
||||
value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
|
||||
elif isinstance(value_cols, dict):
|
||||
value_cols_dict = value_cols
|
||||
else:
|
||||
raise ValueError("value_cols 参数必须是整数或字典")
|
||||
|
||||
if spectral_end_col is None:
|
||||
df = pd.read_csv(csv_path)
|
||||
spectral_end_col = len(df.columns) - 1
|
||||
|
||||
all_model_results = {}
|
||||
|
||||
for preprocess in preprocessing_methods:
|
||||
preprocess_dir = non_empirical_dir / preprocess
|
||||
preprocess_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
processed_csv_path = _apply_preprocessing_internal(
|
||||
csv_path, preprocess, preprocess_dir, spectral_start_col
|
||||
)
|
||||
|
||||
for algorithm in algorithms:
|
||||
algorithm_value_col = value_cols_dict[algorithm]
|
||||
print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")
|
||||
|
||||
model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")
|
||||
|
||||
if Path(model_outpath).exists():
|
||||
print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
|
||||
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.core.non_empirical_model_correction import run_model_correction
|
||||
run_model_correction(
|
||||
algorithm=algorithm,
|
||||
csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
|
||||
value_col=algorithm_value_col,
|
||||
spectral_start=spectral_start_col,
|
||||
spectral_end=spectral_end_col,
|
||||
model_info_outpath=model_outpath,
|
||||
window=window,
|
||||
)
|
||||
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||
print(f"模型训练完成: {model_outpath}")
|
||||
except Exception as e:
|
||||
print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
|
||||
continue
|
||||
|
||||
summary_path = _generate_non_empirical_summary(all_model_results, non_empirical_dir)
|
||||
notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
|
||||
return all_model_results
|
||||
|
||||
# ---- Step 6.75: 自定义回归分析 ----
|
||||
|
||||
@staticmethod
|
||||
def custom_regression(
|
||||
csv_path: Optional[str] = None,
|
||||
x_columns: Optional[Union[str, List[str]]] = None,
|
||||
y_columns: Optional[Union[str, List[str]]] = None,
|
||||
methods: Union[str, List[str]] = "all",
|
||||
output_dir: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Optional[str]:
|
||||
"""使用自定义回归方法分析指标与目标参数之间的关系"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6.75", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6.75: 自定义回归分析")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过自定义回归分析(enabled=False)。")
|
||||
notify("skipped", "跳过自定义回归分析")
|
||||
return None
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
if y_columns is None:
|
||||
raise ValueError("必须指定 y_columns")
|
||||
if x_columns is None:
|
||||
raise ValueError("必须指定 x_columns")
|
||||
|
||||
if isinstance(x_columns, str):
|
||||
x_columns = [x_columns]
|
||||
if isinstance(y_columns, str):
|
||||
y_columns = [y_columns]
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
missing_x = [col for col in x_columns if col not in df.columns]
|
||||
missing_y = [col for col in y_columns if col not in df.columns]
|
||||
if missing_x:
|
||||
raise ValueError(f"自变量列不存在: {missing_x}")
|
||||
if missing_y:
|
||||
raise ValueError(f"因变量列不存在: {missing_y}")
|
||||
|
||||
if output_dir is None:
|
||||
custom_regression_dir = Path(work_dir) / "9_Custom_Regression_Modeling"
|
||||
else:
|
||||
custom_regression_dir = Path(work_dir) / output_dir
|
||||
custom_regression_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from src.core.modeling.regression import SingleVariableRegressionAnalysis
|
||||
analyzer = SingleVariableRegressionAnalysis()
|
||||
analyzer.batch_single_variable_regression(
|
||||
data=df,
|
||||
x_columns=x_columns,
|
||||
y_columns=y_columns,
|
||||
methods=methods,
|
||||
output_dir=str(custom_regression_dir),
|
||||
)
|
||||
|
||||
notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
|
||||
return str(custom_regression_dir)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 内部辅助函数(供 ModelingStep 内部使用)
|
||||
# ============================================================
|
||||
|
||||
def _apply_preprocessing_internal(
|
||||
csv_path: str,
|
||||
preprocess_method: str,
|
||||
output_dir: Path,
|
||||
spectral_start_col: int = 4,
|
||||
) -> str:
|
||||
"""应用预处理到CSV数据(内部函数)"""
|
||||
raw_p = str(preprocess_method).lower()
|
||||
if raw_p == "none" or "无" in raw_p or "跳过" in raw_p:
|
||||
preprocess_method = "None"
|
||||
elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
|
||||
preprocess_method = "MMS"
|
||||
elif raw_p == "ss" or "标准" in raw_p or "标准化" in raw_p:
|
||||
preprocess_method = "SS"
|
||||
elif raw_p == "snv" or "标准正态" in raw_p:
|
||||
preprocess_method = "SNV"
|
||||
elif raw_p == "ma" or "移动" in raw_p:
|
||||
preprocess_method = "MA"
|
||||
elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
|
||||
preprocess_method = "SG"
|
||||
elif raw_p == "msc" or "多元散射" in raw_p:
|
||||
preprocess_method = "MSC"
|
||||
elif raw_p in ("d1", "d2", "dt"):
|
||||
preprocess_method = {"d1": "D1", "d2": "D2", "dt": "DT"}.get(raw_p, raw_p.upper())
|
||||
elif raw_p == "ct" or "去趋势" in raw_p:
|
||||
preprocess_method = "CT"
|
||||
|
||||
if preprocess_method == "None":
|
||||
return csv_path
|
||||
|
||||
output_filename = f"preprocessed_{preprocess_method}.csv"
|
||||
output_path = str(output_dir / output_filename)
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预处理文件,直接使用: {output_path}")
|
||||
return output_path
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
non_spectral_cols = df.iloc[:, :spectral_start_col]
|
||||
spectral_data = df.iloc[:, spectral_start_col:]
|
||||
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
save_path = None
|
||||
if preprocess_method == "SS":
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
|
||||
|
||||
if isinstance(processed_spectral, pd.DataFrame):
|
||||
processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
|
||||
else:
|
||||
processed_spectral_df = pd.DataFrame(
|
||||
processed_spectral, columns=spectral_data.columns, index=spectral_data.index
|
||||
)
|
||||
processed_df = pd.concat([non_spectral_cols, processed_spectral_df], axis=1)
|
||||
|
||||
processed_df.to_csv(output_path, index=False)
|
||||
print(f"预处理完成: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
|
||||
"""生成非经验模型训练结果汇总CSV"""
|
||||
summary_path = str(output_dir / "non_empirical_models_summary.csv")
|
||||
summary_data = []
|
||||
|
||||
for model_key, model_path in model_results.items():
|
||||
try:
|
||||
parts = model_key.split("_")
|
||||
preprocess_method = parts[0]
|
||||
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
|
||||
|
||||
with open(model_path, "r", encoding="utf-8") as f:
|
||||
model_info = json.load(f)
|
||||
|
||||
accuracy_list = model_info.get("accuracy", [])
|
||||
summary_row = {
|
||||
"Preprocessing Method": preprocess_method,
|
||||
"Algorithm Name": algorithm_name,
|
||||
"Model Type": model_info.get("model_type", ""),
|
||||
"Coefficient Count": len(model_info.get("model_info", [])),
|
||||
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
|
||||
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
|
||||
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
|
||||
"Sample Count": len(model_info.get("long", [])),
|
||||
"Model File": model_path,
|
||||
}
|
||||
|
||||
coefficients = model_info.get("model_info", [])
|
||||
for i, coeff in enumerate(coefficients[:5]):
|
||||
summary_row[f"系数_{i+1}"] = coeff
|
||||
|
||||
summary_data.append(summary_row)
|
||||
except Exception as e:
|
||||
print(f"读取模型文件 {model_path} 时出错: {e}")
|
||||
continue
|
||||
|
||||
if summary_data:
|
||||
df_summary = pd.DataFrame(summary_data)
|
||||
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
|
||||
print(f"汇总文件已生成: {summary_path}")
|
||||
else:
|
||||
print("警告: 没有有效的模型数据可汇总")
|
||||
summary_path = ""
|
||||
|
||||
return summary_path
|
||||
350
src/core/steps/prediction_step.py
Normal file
350
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,350 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
预测步骤
|
||||
|
||||
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
|
||||
class PredictionStep:
|
||||
"""预测步骤"""
|
||||
|
||||
# ---- Step 7: 生成采样点并提取光谱 ----
|
||||
|
||||
@staticmethod
|
||||
def generate_sampling_points(
|
||||
deglint_img_path: Optional[str] = None,
|
||||
interval: int = 50,
|
||||
sample_radius: int = 5,
|
||||
chunk_size: int = 1000,
|
||||
water_mask_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./10_sampling",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
|
||||
from pathlib import Path
|
||||
from src.utils.sampling import get_spectral_sampling_points_chunked
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "sampling_spectra.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤7", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤7: 生成预测采样点并提取光谱")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if deglint_img_path is None:
|
||||
raise ValueError("必须提供 deglint_img_path 参数")
|
||||
|
||||
# 1. 初始归一化与安全转换
|
||||
original_path = Path(deglint_img_path)
|
||||
final_deglint_path = original_path
|
||||
|
||||
# 2. 智能回溯探测:如果当前路径不存在,或者后缀是前端死板的 .dat
|
||||
if not final_deglint_path.exists() or final_deglint_path.suffix.lower() == '.dat':
|
||||
print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")
|
||||
|
||||
# 定位到预期的 3_deglint 根目录
|
||||
possible_dir = original_path.parent
|
||||
if possible_dir.name != '3_deglint' and Path(output_path).parent.parent.exists():
|
||||
possible_dir = Path(output_path).parent.parent / "3_deglint"
|
||||
|
||||
if possible_dir.exists():
|
||||
# 搜寻该目录下所有真实存在的 .bsq 文件(接管 goodman/sugar/kutser/hedley 的硬编码产物)
|
||||
existing_bsqs = list(possible_dir.glob("*.bsq"))
|
||||
if existing_bsqs:
|
||||
final_deglint_path = existing_bsqs[0]
|
||||
print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {final_deglint_path}")
|
||||
else:
|
||||
final_deglint_path = original_path.with_suffix('.bsq')
|
||||
else:
|
||||
final_deglint_path = original_path.with_suffix('.bsq')
|
||||
|
||||
deglint_img_str = str(final_deglint_path)
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
|
||||
notify("skipped", f"采样点光谱数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
glint_mask_to_use = glint_mask_path
|
||||
if glint_mask_to_use is None:
|
||||
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
|
||||
|
||||
# 传递极度安全的 deglint_img_str 进底层
|
||||
get_spectral_sampling_points_chunked(
|
||||
deglint_img_str, water_mask_path, glint_mask_to_use,
|
||||
output_path, interval, sample_radius, chunk_size
|
||||
)
|
||||
|
||||
notify("completed", f"采样点光谱数据已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 8: 机器学习模型预测水质参数 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_water_quality(
|
||||
sampling_csv_path: str,
|
||||
models_dir: Optional[str] = None,
|
||||
metric: str = "test_r2",
|
||||
prediction_column: str = "prediction",
|
||||
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
) -> Dict[str, str]:
|
||||
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8: 预测水质参数")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if models_dir is None:
|
||||
raise ValueError("必须提供 models_dir 参数")
|
||||
|
||||
ml_prediction_dir = Path(output_dir)
|
||||
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prediction_files = {}
|
||||
if ml_prediction_dir.exists():
|
||||
csv_files = list(ml_prediction_dir.glob("*.csv"))
|
||||
for csv_file in csv_files:
|
||||
file_stem = csv_file.stem
|
||||
if "_prediction" in file_stem:
|
||||
target_name = file_stem.replace("_prediction", "")
|
||||
elif "_pred" in file_stem:
|
||||
target_name = file_stem.replace("_pred", "")
|
||||
else:
|
||||
target_name = file_stem
|
||||
prediction_files[target_name] = str(csv_file)
|
||||
|
||||
# 检查是否所有目标参数都有预测文件
|
||||
if prediction_files:
|
||||
models_path_obj = Path(models_dir)
|
||||
if models_path_obj.exists():
|
||||
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
|
||||
missing_targets = [t for t in target_folders if t not in prediction_files]
|
||||
if not missing_targets:
|
||||
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
|
||||
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
|
||||
return prediction_files
|
||||
else:
|
||||
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||
|
||||
inferencer = WaterQualityInference(models_dir)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
)
|
||||
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
|
||||
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||
|
||||
if _report_generator is not None:
|
||||
try:
|
||||
report_path = _report_generator.generate_prediction_report(prediction_files)
|
||||
print(f"预测结果报告已生成: {report_path}")
|
||||
except Exception as e:
|
||||
print(f"生成预测结果报告时出错: {e}")
|
||||
|
||||
notify("completed", f"预测完成: {ml_prediction_dir}")
|
||||
return prediction_files
|
||||
|
||||
# ---- Step 8.5: 非经验模型预测 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_with_non_empirical_models(
|
||||
sampling_csv_path: str,
|
||||
non_empirical_models_dir: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = "Average Accuracy(%)",
|
||||
prediction_column: str = "prediction",
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Dict[str, str]:
|
||||
"""使用非经验统计回归模型进行参数预测"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8.5: 使用非经验模型进行参数预测")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过非经验模型预测(enabled=False)。")
|
||||
notify("skipped", "跳过非经验模型预测")
|
||||
return {}
|
||||
|
||||
if non_empirical_models_dir is not None:
|
||||
final_models_dir = non_empirical_models_dir
|
||||
else:
|
||||
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
||||
if Path(default_models_dir).exists():
|
||||
final_models_dir = default_models_dir
|
||||
else:
|
||||
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
|
||||
|
||||
if output_dir is not None:
|
||||
non_empirical_prediction_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
|
||||
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prediction_files = {}
|
||||
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
|
||||
if not summary_path.exists():
|
||||
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
|
||||
|
||||
import pandas as pd
|
||||
df_summary = pd.read_csv(summary_path)
|
||||
|
||||
best_models = {}
|
||||
for algorithm in df_summary["Algorithm Name"].unique():
|
||||
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
|
||||
if metric in algorithm_df.columns:
|
||||
best_model_row = algorithm_df.nlargest(1, metric)
|
||||
else:
|
||||
best_model_row = algorithm_df.iloc[[0]]
|
||||
|
||||
best_model_path = best_model_row["Model File"].values[0]
|
||||
best_preprocess = best_model_row["Preprocessing Method"].values[0]
|
||||
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
|
||||
|
||||
best_models[algorithm] = {
|
||||
"model_path": best_model_path,
|
||||
"preprocess_method": best_preprocess,
|
||||
"accuracy": best_accuracy,
|
||||
}
|
||||
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
|
||||
|
||||
pd.read_csv(sampling_csv_path) # just to validate
|
||||
|
||||
for algorithm, model_info in best_models.items():
|
||||
print(f"\n使用 {algorithm} 算法进行预测...")
|
||||
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
|
||||
prediction_files[algorithm] = output_path
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.core.non_empirical_retrieval import non_empirical_retrieval
|
||||
non_empirical_retrieval(
|
||||
algorithm=algorithm,
|
||||
model_info_path=model_info["model_path"],
|
||||
coor_spectral_path=sampling_csv_path,
|
||||
output_path=output_path,
|
||||
wave_radius=5,
|
||||
)
|
||||
prediction_files[algorithm] = output_path
|
||||
print(f"预测完成: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"使用 {algorithm} 算法预测时出错: {e}")
|
||||
continue
|
||||
|
||||
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
|
||||
return prediction_files
|
||||
|
||||
# ---- Step 8.75: 自定义回归模型预测 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_with_custom_regression(
|
||||
sampling_csv_path: str,
|
||||
custom_regression_dir: Optional[str] = None,
|
||||
formula_csv_path: Optional[str] = None,
|
||||
coordinate_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
filename_prefix: str = "custom_regression_prediction",
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Dict[str, str]:
|
||||
"""使用自定义回归模型进行参数预测"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8.75", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8.75: 使用自定义回归模型进行参数预测")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过自定义回归模型预测(enabled=False)。")
|
||||
notify("skipped", "跳过自定义回归预测")
|
||||
return {}
|
||||
|
||||
if not Path(sampling_csv_path).exists():
|
||||
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
|
||||
|
||||
if custom_regression_dir is not None:
|
||||
final_regression_dir = custom_regression_dir
|
||||
else:
|
||||
final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
|
||||
if not Path(final_regression_dir).exists():
|
||||
raise ValueError(
|
||||
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
|
||||
)
|
||||
|
||||
if output_dir is None:
|
||||
custom_regression_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
|
||||
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
prediction_output_dir = str(custom_regression_prediction_dir)
|
||||
else:
|
||||
prediction_output_dir = output_dir
|
||||
|
||||
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
|
||||
|
||||
predictor = CustomRegressionPredictor(
|
||||
regression_csv_dir=final_regression_dir,
|
||||
formula_csv_path=formula_csv_path,
|
||||
)
|
||||
|
||||
print(f"开始使用自定义回归模块进行批量预测...")
|
||||
print(f" 采样点数据: {sampling_csv_path}")
|
||||
print(f" 回归模型目录: {final_regression_dir}")
|
||||
print(f" 输出目录: {prediction_output_dir}")
|
||||
|
||||
saved_files = predictor.run_batch_prediction(
|
||||
input_csv_path=sampling_csv_path,
|
||||
coordinate_columns=coordinate_columns,
|
||||
filename_prefix=filename_prefix,
|
||||
)
|
||||
|
||||
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
|
||||
for param_name, filepath in saved_files.items():
|
||||
print(f" {param_name}: {filepath}")
|
||||
|
||||
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
|
||||
return saved_files
|
||||
148
src/core/steps/water_mask_step.py
Normal file
148
src/core/steps/water_mask_step.py
Normal file
@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤1: 水域掩膜生成
|
||||
|
||||
支持三种方式:
|
||||
1. 基于 shp 文件栅格化
|
||||
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
|
||||
3. 基于 NDWI 从影像自动生成水体掩膜
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Callable, Union
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WaterMaskStep:
|
||||
"""水域掩膜生成步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
mask_path: Optional[str] = None,
|
||||
img_path: Optional[str] = None,
|
||||
ndwi_threshold: float = 0.4,
|
||||
use_ndwi: bool = False,
|
||||
generate_png: bool = True,
|
||||
output_path: Optional[str] = None,
|
||||
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行水域掩膜生成
|
||||
|
||||
Args:
|
||||
mask_path: 水体掩膜文件路径,支持 .shp(需 img_path)或 .dat/.tif(直接使用)
|
||||
img_path: 输入影像文件路径(当 mask_path 为 shp 或 use_ndwi=True 时必须提供)
|
||||
ndwi_threshold: NDWI 阈值(use_ndwi=True 时使用)
|
||||
use_ndwi: 是否使用 NDWI 方法从影像生成水体掩膜
|
||||
generate_png: 是否生成 PNG 预览图(默认 True)
|
||||
output_path: 指定输出掩膜文件的保存路径(可选)
|
||||
water_mask_dir: 工作目录
|
||||
callback: 回调函数,签名为 callback(step, status, message)
|
||||
|
||||
Returns:
|
||||
dat 格式的水域掩膜文件路径
|
||||
"""
|
||||
from src.utils.extract_water_area import rasterize_shp, ndwi
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
)
|
||||
|
||||
water_mask_dir = Path(water_mask_dir)
|
||||
water_mask_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤1", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤1: 生成或设置水域mask")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 生成影像预览图
|
||||
if generate_png and img_path is not None and Path(img_path).exists():
|
||||
preview_path = str(water_mask_dir / "hsi_preview.png")
|
||||
generate_image_preview(
|
||||
img_path=img_path,
|
||||
output_path=preview_path,
|
||||
title="影像预览: RGB波段(基于波长)"
|
||||
)
|
||||
|
||||
# ---- NDWI 方法 ----
|
||||
if use_ndwi:
|
||||
if img_path is None:
|
||||
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
|
||||
if not Path(img_path).exists():
|
||||
raise ValueError(f"影像文件不存在: {img_path}")
|
||||
|
||||
print(f"使用NDWI方法从影像生成水体掩膜,阈值={ndwi_threshold}...")
|
||||
|
||||
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
|
||||
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
|
||||
|
||||
if Path(ndwi_output_path).exists():
|
||||
print(f"检测到已存在的NDWI掩膜文件,直接使用: {ndwi_output_path}")
|
||||
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
|
||||
return ndwi_output_path
|
||||
|
||||
ndwi(img_path, ndwi_threshold, ndwi_output_path)
|
||||
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(
|
||||
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
|
||||
)
|
||||
|
||||
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
|
||||
return ndwi_output_path
|
||||
|
||||
# ---- 必须提供 mask_path ----
|
||||
if mask_path is None:
|
||||
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
|
||||
if not Path(mask_path).exists():
|
||||
raise ValueError(f"文件不存在: {mask_path}")
|
||||
|
||||
file_ext = Path(mask_path).suffix.lower()
|
||||
|
||||
# ---- SHP 栅格化 ----
|
||||
if file_ext == ".shp":
|
||||
if img_path is None:
|
||||
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
|
||||
|
||||
print(f"检测到shp格式的水体掩膜,正在转换为dat格式...")
|
||||
|
||||
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
|
||||
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
|
||||
|
||||
if Path(shp_output_path).exists():
|
||||
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
|
||||
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
if not overlay_path.exists():
|
||||
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||
return shp_output_path
|
||||
|
||||
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
|
||||
rasterize_shp(safe_mask_path, shp_output_path, img_path)
|
||||
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||
|
||||
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
|
||||
return shp_output_path
|
||||
|
||||
# ---- 栅格格式直接使用 ----
|
||||
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
|
||||
if generate_png and img_path is not None and Path(img_path).exists():
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
|
||||
|
||||
notify("completed", f"水域掩膜已设置: {mask_path}")
|
||||
return mask_path
|
||||
42
src/core/utils/__init__.py
Normal file
42
src/core/utils/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块 - 统一导出接口
|
||||
"""
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info,
|
||||
load_image_as_array,
|
||||
save_array_as_image,
|
||||
save_bands_as_image,
|
||||
copy_hdr_info,
|
||||
read_band_as_array,
|
||||
read_multiple_bands,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm,
|
||||
ensure_water_mask_dat,
|
||||
)
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
select_rgb_bands_by_wavelength,
|
||||
get_wavelength_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# GDAL IO
|
||||
'get_image_geo_info',
|
||||
'load_image_as_array',
|
||||
'save_array_as_image',
|
||||
'save_bands_as_image',
|
||||
'copy_hdr_info',
|
||||
'read_band_as_array',
|
||||
'read_multiple_bands',
|
||||
# 掩膜转换
|
||||
'prepare_water_mask_for_algorithm',
|
||||
'ensure_water_mask_dat',
|
||||
# 预览图生成
|
||||
'generate_image_preview',
|
||||
'generate_water_mask_overlay',
|
||||
'select_rgb_bands_by_wavelength',
|
||||
'get_wavelength_info',
|
||||
]
|
||||
309
src/core/utils/gdal_helper.py
Normal file
309
src/core/utils/gdal_helper.py
Normal file
@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GDAL 底层 IO 工具模块
|
||||
|
||||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# GDAL 导入(可选)
|
||||
try:
|
||||
from osgeo import gdal, ogr, gdal_array
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# hdr 文件工具
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||||
UTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
UTIL_AVAILABLE = False
|
||||
write_fields_to_hdrfile = None
|
||||
get_hdr_file_path = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像信息读取
|
||||
# ============================================================
|
||||
|
||||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||||
"""
|
||||
获取影像的地理信息(不加载图像数据,节省内存)
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (geotransform, projection, width, height, n_bands)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
return geotransform, projection, width, height, n_bands
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||||
"""
|
||||
加载影像文件为numpy数组
|
||||
|
||||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (image_array, geotransform, projection)
|
||||
image_array: numpy数组,形状为(height, width, bands)
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
image_bands = []
|
||||
for i in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(i)
|
||||
band_data = band.ReadAsArray()
|
||||
image_bands.append(band_data)
|
||||
|
||||
image_array = np.dstack(image_bands)
|
||||
return image_array, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
读取单个波段为 numpy 数组
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_index: 波段索引(从 0 开始)
|
||||
|
||||
Returns:
|
||||
numpy 数组,形状为 (height, width)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
band = dataset.GetRasterBand(band_index + 1)
|
||||
return band.ReadAsArray()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||||
"""
|
||||
读取多个指定波段为列表
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_indices: 波段索引列表
|
||||
|
||||
Returns:
|
||||
tuple: (band_list, geotransform, projection)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
bands = []
|
||||
for idx in band_indices:
|
||||
band = dataset.GetRasterBand(idx + 1)
|
||||
bands.append(band.ReadAsArray())
|
||||
return bands, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像写入
|
||||
# ============================================================
|
||||
|
||||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
将numpy数组保存为影像文件
|
||||
|
||||
Args:
|
||||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
if image_array.ndim == 2:
|
||||
height, width = image_array.shape
|
||||
n_bands = 1
|
||||
else:
|
||||
height, width, n_bands = image_array.shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
if n_bands == 1:
|
||||
band = dataset.GetRasterBand(1)
|
||||
band.WriteArray(image_array)
|
||||
band.FlushCache()
|
||||
else:
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(image_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||||
|
||||
Args:
|
||||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if not corrected_bands:
|
||||
raise ValueError("波段列表为空")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
n_bands = len(corrected_bands)
|
||||
height, width = corrected_bands[0].shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
for i, band_array in enumerate(corrected_bands):
|
||||
if band_array.shape != (height, width):
|
||||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(band_array)
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||||
"""
|
||||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||||
|
||||
Args:
|
||||
source_img_path: 源影像文件路径(原始bsq文件)
|
||||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not UTIL_AVAILABLE:
|
||||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||||
return False
|
||||
|
||||
try:
|
||||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||||
|
||||
if not Path(source_hdr_path).exists():
|
||||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||||
return False
|
||||
|
||||
if not Path(dest_hdr_path).exists():
|
||||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||||
return False
|
||||
|
||||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||||
return False
|
||||
210
src/core/utils/mask_converter.py
Normal file
210
src/core/utils/mask_converter.py
Normal file
@ -0,0 +1,210 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
掩膜转换工具模块
|
||||
|
||||
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
|
||||
以及水体掩膜的预处理逻辑。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_water_mask_for_algorithm(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
image_shape: Union[tuple, np.ndarray],
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
img_path: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
准备水域掩膜供算法使用
|
||||
|
||||
支持格式:
|
||||
- None:自动使用预先生成的 dat 格式掩膜
|
||||
- numpy.ndarray:直接返回(确保是 0/1 格式)
|
||||
- .dat / .tif 等栅格文件:读取并返回
|
||||
- .shp 文件:先栅格化,再读取返回
|
||||
|
||||
Args:
|
||||
water_mask: 掩膜来源
|
||||
image_shape: 影像形状 (height, width) 或 (height, width, channels)
|
||||
geotransform: GDAL 地理变换参数
|
||||
projection: 投影信息
|
||||
img_path: 影像路径(用于 shp 栅格化)
|
||||
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
|
||||
callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
numpy数组(dtype=uint8,0=非水域,1=水域)或 None
|
||||
"""
|
||||
img_height, img_width = image_shape[0], image_shape[1]
|
||||
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# numpy 数组直接返回
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8)
|
||||
|
||||
# 字符串路径
|
||||
if isinstance(water_mask, str):
|
||||
ext = Path(water_mask).suffix.lower()
|
||||
|
||||
# shapefile 格式
|
||||
if ext == '.shp':
|
||||
return _convert_shp_to_mask(
|
||||
shp_path=water_mask,
|
||||
img_path=img_path,
|
||||
image_shape=image_shape,
|
||||
geotransform=geotransform,
|
||||
projection=projection,
|
||||
water_mask_dir=water_mask_dir,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
# 栅格文件格式
|
||||
return _load_raster_mask(water_mask, img_height, img_width)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
|
||||
def _convert_shp_to_mask(shp_path: str, img_path: str,
|
||||
image_shape: tuple,
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None) -> np.ndarray:
|
||||
"""将 shapefile 栅格化为掩膜数组"""
|
||||
from src.utils.extract_water_area import rasterize_shp
|
||||
|
||||
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
|
||||
shp_name = Path(safe_shp_path).stem
|
||||
|
||||
if water_mask_dir:
|
||||
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
|
||||
else:
|
||||
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
|
||||
|
||||
# 缓存:已栅格化则直接读取
|
||||
if Path(temp_mask_path).exists():
|
||||
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
# 需要栅格化
|
||||
if img_path is None:
|
||||
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
|
||||
|
||||
print(f"正在将 SHP 栅格化: {safe_shp_path}")
|
||||
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
|
||||
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
|
||||
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
|
||||
"""从栅格文件加载掩膜"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取掩膜文件")
|
||||
|
||||
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {mask_path}")
|
||||
|
||||
try:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def ensure_water_mask_dat(img_path: str,
|
||||
existing_dat_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法)
|
||||
|
||||
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
|
||||
如果存在同名 .dat 文件,直接返回。
|
||||
否则从 img_path 生成并保存到 output_dir。
|
||||
|
||||
Args:
|
||||
img_path: 用于生成掩膜的遥感影像路径
|
||||
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
|
||||
output_dir: 输出目录(可选)
|
||||
|
||||
Returns:
|
||||
dat 格式掩膜文件路径
|
||||
"""
|
||||
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
|
||||
if Path(existing_dat_path).exists():
|
||||
return existing_dat_path
|
||||
|
||||
img_name = Path(img_path).stem
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(img_path).parent)
|
||||
|
||||
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
|
||||
|
||||
if Path(dat_path).exists():
|
||||
return dat_path
|
||||
|
||||
# 如果已有其他格式的掩膜,转换为 dat
|
||||
for ext in ['.tif', '.img', '.tiff']:
|
||||
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
|
||||
if Path(alt_path).exists():
|
||||
return _convert_to_dat(alt_path, dat_path)
|
||||
|
||||
return dat_path # 返回目标路径,让调用方决定是否需要生成
|
||||
|
||||
|
||||
def _convert_to_dat(src_path: str, dest_path: str) -> str:
|
||||
"""将其他栅格格式转换为 ENVI dat 格式"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法转换格式")
|
||||
|
||||
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
|
||||
if src_ds is None:
|
||||
raise ValueError(f"无法打开源掩膜文件: {src_path}")
|
||||
|
||||
try:
|
||||
geotransform = src_ds.GetGeoTransform()
|
||||
projection = src_ds.GetProjection()
|
||||
band = src_ds.GetRasterBand(1)
|
||||
array = band.ReadAsArray()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
|
||||
if dest_ds is None:
|
||||
raise ValueError(f"无法创建输出文件: {dest_path}")
|
||||
|
||||
try:
|
||||
dest_ds.SetGeoTransform(geotransform)
|
||||
dest_ds.SetProjection(projection)
|
||||
dest_band = dest_ds.GetRasterBand(1)
|
||||
dest_band.WriteArray((array > 0).astype(np.uint8))
|
||||
dest_band.FlushCache()
|
||||
finally:
|
||||
dest_ds = None
|
||||
|
||||
return dest_path
|
||||
finally:
|
||||
src_ds = None
|
||||
339
src/core/utils/preview_generator.py
Normal file
339
src/core/utils/preview_generator.py
Normal file
@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
遥感影像预览图生成工具模块
|
||||
|
||||
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具)
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 辅助函数:波段选择
|
||||
# ============================================================
|
||||
|
||||
def select_rgb_bands_by_wavelength(band_count: int,
|
||||
wavelength_info: Optional[List[float]] = None,
|
||||
fallback_bands: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
根据波长自动选择 RGB 波段
|
||||
|
||||
Args:
|
||||
band_count: 总波段数
|
||||
wavelength_info: 各波段波长列表(nm),长度为 band_count
|
||||
fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B]
|
||||
|
||||
Returns:
|
||||
波段索引列表 [R_index, G_index, B_index](0-based)
|
||||
"""
|
||||
if fallback_bands is None:
|
||||
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
|
||||
|
||||
if wavelength_info is None:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
# 目标波长(nm)
|
||||
TARGET_R = 650
|
||||
TARGET_G = 550
|
||||
TARGET_B = 460
|
||||
|
||||
def find_closest(target: float) -> int:
|
||||
min_dist = float('inf')
|
||||
best_idx = 0
|
||||
for i, wl in enumerate(wavelength_info):
|
||||
dist = abs(wl - target)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
try:
|
||||
r_idx = find_closest(TARGET_R)
|
||||
g_idx = find_closest(TARGET_G)
|
||||
b_idx = find_closest(TARGET_B)
|
||||
return [r_idx, g_idx, b_idx]
|
||||
except Exception:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
|
||||
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
|
||||
"""从 hdr 文件读取波长信息"""
|
||||
try:
|
||||
hdr_path = Path(img_path).with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
in_wl = False
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength ='):
|
||||
in_wl = True
|
||||
line = line.split('=', 1)[1].strip()
|
||||
elif in_wl:
|
||||
if line.startswith('{'):
|
||||
line = line[1:]
|
||||
if line.endswith('}'):
|
||||
line = line[:-1]
|
||||
in_wl = False
|
||||
# 解析逗号分隔的数值
|
||||
for token in line.replace(',', ' ').split():
|
||||
try:
|
||||
wavelengths.append(float(token))
|
||||
except ValueError:
|
||||
pass
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心预览图生成函数
|
||||
# ============================================================
|
||||
|
||||
def generate_image_preview(img_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
title: str = "影像预览") -> str:
|
||||
"""
|
||||
生成高光谱影像的 PNG 预览图
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
output_path: 输出 PNG 文件路径
|
||||
bands: RGB 波段索引 [R, G, B],None 则自动选择
|
||||
title: 图片标题
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成影像预览图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预览图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
# 读取波段
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
# 线性拉伸
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.imshow(rgb_image)
|
||||
ax.set_title(title, fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
|
||||
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
|
||||
color='white',
|
||||
bbox=dict(facecolor='black', alpha=0.6,
|
||||
boxstyle='round,pad=0.3'))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def generate_water_mask_overlay(img_path: str,
|
||||
mask_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
mask_color: tuple = (0, 100, 255),
|
||||
mask_alpha: float = 0.5) -> str:
|
||||
"""
|
||||
生成水域掩膜叠加到原图的 PNG 图像
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
mask_path: 水域掩膜文件路径
|
||||
output_path: 输出 PNG 路径
|
||||
bands: RGB 波段索引,None 则自动选择
|
||||
mask_color: 掩膜叠加颜色 (R, G, B)
|
||||
mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明)
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成叠加图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
|
||||
rgb_image = rgb_image.astype(np.uint8)
|
||||
|
||||
# 读取掩膜
|
||||
mask_dataset = gdal.Open(mask_path)
|
||||
if mask_dataset is not None:
|
||||
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
else:
|
||||
print(f"警告: 无法打开掩膜文件: {mask_path}")
|
||||
mask_data = None
|
||||
|
||||
# Alpha 混合
|
||||
overlay = np.zeros((height, width, 4), dtype=np.uint8)
|
||||
overlay[:, :, 0:3] = mask_color
|
||||
overlay[:, :, 3] = 255 # 全不透明
|
||||
|
||||
blended = rgb_image.astype(np.float32)
|
||||
if mask_data is not None:
|
||||
alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha
|
||||
for c in range(3):
|
||||
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
|
||||
blended = blended.astype(np.uint8)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
ax.imshow(blended)
|
||||
ax.axis('off')
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
|
||||
edgecolor='black', alpha=mask_alpha, label='水域范围')
|
||||
]
|
||||
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
|
||||
|
||||
# 面积计算
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
pixel_size_y = abs(geotransform[5])
|
||||
pixel_area = pixel_size_x * pixel_size_y
|
||||
|
||||
if mask_data is not None:
|
||||
water_pixels = np.sum(mask_data > 0)
|
||||
valid_pixels = np.sum(mask_data >= 0)
|
||||
water_km2 = water_pixels * pixel_area / 1_000_000
|
||||
valid_km2 = valid_pixels * pixel_area / 1_000_000
|
||||
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
|
||||
|
||||
area_text = (f'水域面积: {water_km2:.2f} km² | '
|
||||
f'影像总面积: {valid_km2:.2f} km² | '
|
||||
f'占比: {pct:.1f}%')
|
||||
ax.text(0.02, 0.98, area_text,
|
||||
transform=ax.transAxes, fontsize=11,
|
||||
color='white', fontweight='bold',
|
||||
bbox=dict(facecolor='#0064FF', alpha=0.8,
|
||||
edgecolor='black', boxstyle='round,pad=0.5'),
|
||||
verticalalignment='top')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
21
src/core/visualization/__init__.py
Normal file
21
src/core/visualization/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统一导出接口
|
||||
|
||||
本模块从各子模块导入可视化函数,提供统一的导出接口。
|
||||
"""
|
||||
from src.core.visualization.scatter_plot import generate_model_scatter_plots
|
||||
from src.core.visualization.spectrum_plot import generate_spectrum_comparison_plots
|
||||
from src.core.visualization.boxplot import generate_boxplots
|
||||
from src.core.visualization.statistics import generate_statistical_charts
|
||||
from src.core.visualization.preview import generate_glint_deglint_previews
|
||||
from src.core.visualization.report import generate_pipeline_report
|
||||
|
||||
__all__ = [
|
||||
'generate_model_scatter_plots',
|
||||
'generate_spectrum_comparison_plots',
|
||||
'generate_boxplots',
|
||||
'generate_statistical_charts',
|
||||
'generate_glint_deglint_previews',
|
||||
'generate_pipeline_report',
|
||||
]
|
||||
183
src/core/visualization/boxplot.py
Normal file
183
src/core/visualization/boxplot.py
Normal file
@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 箱型图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
sns.set_style("whitegrid")
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
def generate_boxplots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
data_start_column: int = 4,
|
||||
save_individual: bool = True,
|
||||
use_seaborn: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成水质参数的箱型图
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
data_start_column: 数据开始列索引(从第几列开始,默认第5列,索引为4)
|
||||
save_individual: 是否为每个参数单独保存箱型图
|
||||
use_seaborn: 是否使用seaborn绘制(更美观)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
箱型图文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成水质参数箱型图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "boxplots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
# 确定参数列
|
||||
if parameter_columns is None:
|
||||
data_columns = df.iloc[:, data_start_column:]
|
||||
parameter_columns = list(data_columns.columns)
|
||||
else:
|
||||
parameter_columns = [col for col in parameter_columns if col in df.columns]
|
||||
|
||||
if not parameter_columns:
|
||||
print("警告: 未找到有效的参数列")
|
||||
return {}
|
||||
|
||||
boxplot_dir = Path(output_dir)
|
||||
boxplot_paths = {}
|
||||
|
||||
if save_individual:
|
||||
print(f"为每个参数单独绘制箱型图(共 {len(parameter_columns)} 个参数)")
|
||||
|
||||
for column in parameter_columns:
|
||||
if column not in df.columns:
|
||||
continue
|
||||
|
||||
clean_data = df[column].dropna()
|
||||
|
||||
if len(clean_data) == 0:
|
||||
print(f"跳过列 '{column}': 没有有效数据")
|
||||
continue
|
||||
|
||||
try:
|
||||
plt.figure(figsize=(8, 6))
|
||||
|
||||
if use_seaborn:
|
||||
plot_data = pd.DataFrame({
|
||||
'参数': [column] * len(clean_data),
|
||||
'数值': clean_data
|
||||
})
|
||||
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot([clean_data], labels=[column],
|
||||
patch_artist=True, showfliers=False)
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
|
||||
stats_text = (f'数据点数: {len(clean_data)}\n'
|
||||
f'均值: {clean_data.mean():.2f}\n'
|
||||
f'中位数: {clean_data.median():.2f}\n'
|
||||
f'标准差: {clean_data.std():.2f}')
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round',
|
||||
facecolor='wheat' if not use_seaborn else 'lightgreen',
|
||||
alpha=0.8))
|
||||
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
safe_column_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = boxplot_dir / f'{safe_column_name}_boxplot.png'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths[column] = str(save_path)
|
||||
print(f" 已保存: {save_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 处理参数 {column} 时出错: {e}")
|
||||
continue
|
||||
|
||||
# 综合箱型图
|
||||
try:
|
||||
print("\n生成综合箱型图(所有参数在一张图上)")
|
||||
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
|
||||
|
||||
box_data = []
|
||||
labels = []
|
||||
for column in parameter_columns:
|
||||
if column in df.columns:
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
|
||||
if box_data:
|
||||
if use_seaborn:
|
||||
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
|
||||
melted_data = melted_data.dropna()
|
||||
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
|
||||
for i, data in enumerate(box_data):
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
|
||||
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
|
||||
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
|
||||
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
boxplot_paths['all_parameters'] = str(combined_path)
|
||||
print(f" 已保存综合箱型图: {combined_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成综合箱型图时出错: {e}")
|
||||
|
||||
print(f"\n箱型图生成完成,共生成 {len(boxplot_paths)} 个图表")
|
||||
return boxplot_paths
|
||||
59
src/core/visualization/preview.py
Normal file
59
src/core/visualization/preview.py
Normal file
@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 耀斑影像预览图生成
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_glint_deglint_previews(
|
||||
work_dir: str,
|
||||
output_subdir: str = "glint_deglint_previews",
|
||||
generate_glint: bool = True,
|
||||
generate_deglint: bool = True,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成2_glint和3_deglint文件夹中影像文件的PNG预览图
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录
|
||||
output_subdir: 输出子目录名称
|
||||
generate_glint: 是否处理2_glint文件夹
|
||||
generate_deglint: 是否处理3_deglint文件夹
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
生成的预览图路径字典
|
||||
"""
|
||||
print(f"\n{'='*70}")
|
||||
print("步骤: 生成耀斑分析影像预览图")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if work_dir is None:
|
||||
raise ValueError("请提供 work_dir")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(work_dir) / "visualization" / output_subdir)
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
try:
|
||||
preview_paths = visualizer.generate_glint_deglint_previews(
|
||||
work_dir=work_dir,
|
||||
output_subdir=output_subdir,
|
||||
generate_glint=generate_glint,
|
||||
generate_deglint=generate_deglint
|
||||
)
|
||||
|
||||
print(f"耀斑分析影像预览图生成完成,共生成 {len(preview_paths)} 个预览图")
|
||||
return preview_paths
|
||||
|
||||
except Exception as e:
|
||||
print(f"生成耀斑分析影像预览图时出错: {e}")
|
||||
return {}
|
||||
147
src/core/visualization/report.py
Normal file
147
src/core/visualization/report.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 流程执行报告生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def generate_pipeline_report(
|
||||
step_timings: Dict,
|
||||
pipeline_start_time: Optional[float] = None,
|
||||
pipeline_end_time: Optional[float] = None,
|
||||
output_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
生成流程执行报告,包含每步的耗时统计
|
||||
|
||||
Args:
|
||||
step_timings: 步骤耗时字典(格式:{step_name: {start_time, end_time, elapsed_seconds, elapsed_formatted, status, error}})
|
||||
pipeline_start_time: 流程开始时间戳
|
||||
pipeline_end_time: 流程结束时间戳
|
||||
output_path: 输出文件路径(如果为None,自动生成)
|
||||
|
||||
Returns:
|
||||
报告文件路径
|
||||
"""
|
||||
if output_path is None:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
output_path = str(Path.cwd() / "reports" / f"pipeline_report_{timestamp}.csv")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _format_time(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.2f}秒"
|
||||
elif seconds < 3600:
|
||||
minutes = int(seconds // 60)
|
||||
secs = seconds % 60
|
||||
return f"{minutes}分{secs:.2f}秒"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = seconds % 60
|
||||
return f"{hours}小时{minutes}分{secs:.2f}秒"
|
||||
|
||||
# 准备报告数据
|
||||
report_data = []
|
||||
total_time = 0.0
|
||||
|
||||
step_order = [
|
||||
"步骤1: 生成水域mask",
|
||||
"步骤2: 找到耀斑区域",
|
||||
"步骤3: 去除耀斑",
|
||||
"步骤4: 处理CSV文件",
|
||||
"步骤5: 提取训练样本点光谱",
|
||||
"步骤5.5: 计算水质光谱指数",
|
||||
"步骤6: 训练机器学习模型",
|
||||
"步骤6.5: 非经验模型训练",
|
||||
"步骤6.75: 自定义回归",
|
||||
"步骤7: 生成预测采样点",
|
||||
"步骤8: 预测水质参数",
|
||||
"步骤9: 生成分布图"
|
||||
]
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
report_data.append({
|
||||
'步骤': step_name,
|
||||
'开始时间': timing_info['start_time'],
|
||||
'结束时间': timing_info['end_time'],
|
||||
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
|
||||
'耗时(格式化)': timing_info['elapsed_formatted'],
|
||||
'状态': timing_info['status'],
|
||||
'错误信息': timing_info.get('error', '')
|
||||
})
|
||||
if timing_info['status'] == 'completed':
|
||||
total_time += timing_info['elapsed_seconds']
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
pipeline_total = pipeline_end_time - pipeline_start_time
|
||||
report_data.append({
|
||||
'步骤': '总计',
|
||||
'开始时间': datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'结束时间': datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'耗时(秒)': f"{pipeline_total:.2f}",
|
||||
'耗时(格式化)': _format_time(pipeline_total),
|
||||
'状态': 'completed',
|
||||
'错误信息': ''
|
||||
})
|
||||
|
||||
df_report = pd.DataFrame(report_data)
|
||||
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
# 文本格式报告
|
||||
txt_output_path = str(Path(output_path).with_suffix('.txt'))
|
||||
with open(txt_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write("="*80 + "\n")
|
||||
f.write("水质参数反演流程执行报告\n")
|
||||
f.write("="*80 + "\n\n")
|
||||
|
||||
if pipeline_start_time and pipeline_end_time:
|
||||
f.write(f"流程开始时间: {datetime.fromtimestamp(pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"流程结束时间: {datetime.fromtimestamp(pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"总耗时: {_format_time(pipeline_end_time - pipeline_start_time)}\n\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("各步骤执行详情:\n")
|
||||
f.write("-"*80 + "\n\n")
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in step_timings:
|
||||
timing_info = step_timings[step_name]
|
||||
f.write(f"{step_name}\n")
|
||||
f.write(f" 开始时间: {timing_info['start_time']}\n")
|
||||
f.write(f" 结束时间: {timing_info['end_time']}\n")
|
||||
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
|
||||
f.write(f" 状态: {timing_info['status']}\n")
|
||||
if timing_info.get('error'):
|
||||
f.write(f" 错误: {timing_info['error']}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("-"*80 + "\n")
|
||||
f.write("统计摘要:\n")
|
||||
f.write("-"*80 + "\n")
|
||||
completed_steps = [s for s in step_timings.values() if s['status'] == 'completed']
|
||||
failed_steps = [s for s in step_timings.values() if s['status'] == 'failed']
|
||||
skipped_steps = [s for s in step_timings.values() if s['status'] == 'skipped']
|
||||
|
||||
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
|
||||
f.write(f"失败的步骤: {len(failed_steps)}\n")
|
||||
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
|
||||
|
||||
if completed_steps:
|
||||
completed_times = [s['elapsed_seconds'] for s in completed_steps]
|
||||
f.write(f"平均耗时: {_format_time(np.mean(completed_times))}\n")
|
||||
f.write(f"最长耗时: {_format_time(np.max(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.max(completed_times)][0]})\n")
|
||||
f.write(f"最短耗时: {_format_time(np.min(completed_times))} ({[s['elapsed_formatted'] for s in completed_steps if s['elapsed_seconds'] == np.min(completed_times)][0]})\n")
|
||||
|
||||
print(f"\n流程报告已生成:")
|
||||
print(f" CSV格式: {output_path}")
|
||||
print(f" 文本格式: {txt_output_path}")
|
||||
|
||||
return output_path
|
||||
147
src/core/visualization/scatter_plot.py
Normal file
147
src/core/visualization/scatter_plot.py
Normal file
@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 散点图生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_model_scatter_plots(
|
||||
models_dir: str,
|
||||
training_csv_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = 'test_r2',
|
||||
use_enhanced: bool = True,
|
||||
feature_start_column: Union[str, int] = 13,
|
||||
test_size: float = 0.2,
|
||||
random_state: int = 42,
|
||||
scatter_batch=None # 可选:传入已实例化的 scatter_batch 对象
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成模型评估散点图(真实值vs预测值)
|
||||
|
||||
Args:
|
||||
models_dir: 模型保存目录
|
||||
training_csv_path: 训练数据CSV路径
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
metric: 选择最佳模型的指标
|
||||
use_enhanced: 是否使用增强版散点图(带置信区间,使用sctter_batch)
|
||||
feature_start_column: 特征开始列名或索引
|
||||
test_size: 测试集比例
|
||||
random_state: 随机种子
|
||||
scatter_batch: 可选,已实例化的 WaterQualityScatterBatch 对象
|
||||
|
||||
Returns:
|
||||
散点图文件路径字典(键为目标参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成模型评估散点图")
|
||||
print("="*80)
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("请提供 training_csv_path")
|
||||
|
||||
models_path = Path(models_dir)
|
||||
if not models_path.exists():
|
||||
raise ValueError(f"模型目录不存在: {models_dir}")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(models_dir).parent / "14_visualization" / "scatter_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
# 增强版散点图
|
||||
if use_enhanced:
|
||||
print("使用增强版散点图(带置信区间)")
|
||||
try:
|
||||
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
|
||||
if scatter_batch is None:
|
||||
scatter_batch = WaterQualityScatterBatch()
|
||||
|
||||
results = scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=output_dir,
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state
|
||||
)
|
||||
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
print(f" ✓ {target_name}: {result.get('save_path', '')}")
|
||||
else:
|
||||
print(f" ✗ {target_name}: 失败 - {result.get('error', '未知错误')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用增强版散点图时出错: {e}")
|
||||
print("回退到基础版散点图")
|
||||
use_enhanced = False
|
||||
|
||||
# 基础版散点图
|
||||
if not use_enhanced or not scatter_paths:
|
||||
print("使用基础版散点图")
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
print(f"\n处理目标参数: {target_name}")
|
||||
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric
|
||||
)
|
||||
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
|
||||
scatter_path = visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0)
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception as e:
|
||||
print(f"处理目标参数 {target_name} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n散点图生成完成,共生成 {len(scatter_paths)} 个图表")
|
||||
return scatter_paths
|
||||
80
src/core/visualization/spectrum_plot.py
Normal file
80
src/core/visualization/spectrum_plot.py
Normal file
@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 光谱曲线图生成
|
||||
"""
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Union
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_spectrum_comparison_plots(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
wavelength_start_column: Union[str, int] = "UTM_Y",
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成光谱曲线对比图(不同参数值的光谱曲线对比)
|
||||
|
||||
Args:
|
||||
csv_path: 包含光谱和参数值的CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
wavelength_start_column: 波长开始列名或索引
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
光谱曲线图文件路径字典(键为参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成光谱曲线对比图")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "spectrum_plots")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
if isinstance(wavelength_start_column, str):
|
||||
try:
|
||||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||||
except:
|
||||
wavelength_start_idx = 13
|
||||
else:
|
||||
wavelength_start_idx = wavelength_start_column
|
||||
|
||||
parameter_columns = list(df.columns[:wavelength_start_idx])
|
||||
if len(parameter_columns) > 2:
|
||||
parameter_columns = parameter_columns[2:]
|
||||
|
||||
spectrum_paths = {}
|
||||
for param_col in parameter_columns:
|
||||
if param_col not in df.columns:
|
||||
continue
|
||||
|
||||
print(f"\n处理参数: {param_col}")
|
||||
try:
|
||||
spectrum_path = visualizer.plot_spectrum_by_parameter(
|
||||
csv_path=csv_path,
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wavelength_start_column,
|
||||
n_groups=5
|
||||
)
|
||||
spectrum_paths[param_col] = spectrum_path
|
||||
except Exception as e:
|
||||
print(f"处理参数 {param_col} 时出错: {e}")
|
||||
continue
|
||||
|
||||
print(f"\n光谱曲线图生成完成,共生成 {len(spectrum_paths)} 个图表")
|
||||
return spectrum_paths
|
||||
59
src/core/visualization/statistics.py
Normal file
59
src/core/visualization/statistics.py
Normal file
@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化模块 - 统计图表生成
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
|
||||
|
||||
def generate_statistical_charts(
|
||||
csv_path: str,
|
||||
parameter_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
生成统计图表(箱线图、直方图、相关性热力图)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表(如果为None,自动检测)
|
||||
output_dir: 输出目录(None则使用默认)
|
||||
|
||||
Returns:
|
||||
统计图表文件路径字典
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("生成统计图表")
|
||||
print("="*80)
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("请提供 csv_path")
|
||||
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
csv_dir = Path(csv_path).parent
|
||||
output_dir = str(csv_dir / "visualization" / "statistical_charts")
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 实例化可视化器
|
||||
visualizer = WaterQualityVisualization(output_dir)
|
||||
|
||||
# 读取数据以检测参数列
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
parameter_columns = list(df.columns[2:])
|
||||
parameter_columns = [col for col in parameter_columns
|
||||
if df[col].dtype in [np.float64, np.int64]]
|
||||
|
||||
chart_paths = visualizer.plot_statistical_charts(
|
||||
csv_path=csv_path,
|
||||
parameter_columns=parameter_columns
|
||||
)
|
||||
|
||||
print(f"\n统计图表生成完成")
|
||||
return chart_paths
|
||||
File diff suppressed because it is too large
Load Diff
1
src/gui/components/__init__.py
Normal file
1
src/gui/components/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# src.gui.components package
|
||||
148
src/gui/components/custom_widgets.py
Normal file
148
src/gui/components/custom_widgets.py
Normal file
@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自定义组件 - 文件选择控件等公共组件
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QHBoxLayout, QLabel, QLineEdit, QPushButton, QFileDialog,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
|
||||
class DirSelectWidget(QWidget):
|
||||
"""目录选择组件"""
|
||||
def __init__(self, label_text, parent=None):
|
||||
"""
|
||||
初始化目录选择组件
|
||||
|
||||
Args:
|
||||
label_text: 标签文本
|
||||
parent: 父控件
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self.init_ui(label_text)
|
||||
|
||||
def init_ui(self, label_text):
|
||||
layout = QHBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
self.label = QLabel(label_text)
|
||||
self.label.setMinimumWidth(120)
|
||||
self.line_edit = QLineEdit()
|
||||
self.line_edit.setPlaceholderText("请选择目录...")
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMaximumWidth(80)
|
||||
self.browse_btn.clicked.connect(self.browse_dir)
|
||||
|
||||
layout.addWidget(self.label)
|
||||
layout.addWidget(self.line_edit, 1)
|
||||
layout.addWidget(self.browse_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_dir(self):
|
||||
"""浏览目录 - 智能记忆上次选择位置"""
|
||||
current_text = self.line_edit.text().strip()
|
||||
initial_dir = ""
|
||||
|
||||
# 最高优先级:输入框已有路径存在
|
||||
if current_text:
|
||||
if os.path.isdir(current_text):
|
||||
initial_dir = current_text
|
||||
else:
|
||||
dir_path = os.path.dirname(current_text)
|
||||
if dir_path and os.path.exists(dir_path):
|
||||
initial_dir = dir_path
|
||||
|
||||
# 调用目录选择对话框
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self, "选择目录", initial_dir
|
||||
)
|
||||
if dir_path:
|
||||
self.line_edit.setText(dir_path)
|
||||
|
||||
def get_path(self):
|
||||
"""获取路径"""
|
||||
return self.line_edit.text()
|
||||
|
||||
def set_path(self, path):
|
||||
"""设置路径"""
|
||||
self.line_edit.setText(str(path))
|
||||
|
||||
|
||||
class FileSelectWidget(QWidget):
|
||||
"""文件选择组件"""
|
||||
def __init__(self, label_text, file_filter="All Files (*.*)", mode="open", parent=None):
|
||||
"""
|
||||
初始化文件选择组件
|
||||
|
||||
Args:
|
||||
label_text: 标签文本
|
||||
file_filter: 文件过滤器
|
||||
mode: 选择模式 - "open"(打开文件) 或 "save"(保存文件)
|
||||
parent: 父控件
|
||||
"""
|
||||
super().__init__(parent)
|
||||
self.file_filter = file_filter
|
||||
self.mode = mode # "open" 或 "save"
|
||||
self.init_ui(label_text)
|
||||
|
||||
def init_ui(self, label_text):
|
||||
layout = QHBoxLayout()
|
||||
layout.setContentsMargins(0, 0, 0, 0)
|
||||
|
||||
self.label = QLabel(label_text)
|
||||
self.label.setMinimumWidth(120)
|
||||
self.line_edit = QLineEdit()
|
||||
placeholder = "请选择保存路径..." if self.mode == "save" else "请选择文件..."
|
||||
self.line_edit.setPlaceholderText(placeholder)
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMaximumWidth(80)
|
||||
self.browse_btn.clicked.connect(self.browse_file)
|
||||
|
||||
layout.addWidget(self.label)
|
||||
layout.addWidget(self.line_edit, 1)
|
||||
layout.addWidget(self.browse_btn)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_file(self):
|
||||
"""浏览文件 - 智能记忆上次选择位置"""
|
||||
current_text = self.line_edit.text().strip()
|
||||
initial_dir = ""
|
||||
|
||||
# 最高优先级:输入框已有路径存在
|
||||
if current_text:
|
||||
if os.path.isdir(current_text):
|
||||
initial_dir = current_text
|
||||
else:
|
||||
dir_path = os.path.dirname(current_text)
|
||||
if dir_path and os.path.exists(dir_path):
|
||||
initial_dir = dir_path
|
||||
|
||||
if self.mode == "save":
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存文件", initial_dir, self.file_filter
|
||||
)
|
||||
else:
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self, "选择文件", initial_dir, self.file_filter
|
||||
)
|
||||
if file_path:
|
||||
self.line_edit.setText(file_path)
|
||||
|
||||
def get_path(self):
|
||||
"""获取路径"""
|
||||
return self.line_edit.text()
|
||||
|
||||
def set_path(self, path):
|
||||
"""设置路径"""
|
||||
self.line_edit.setText(str(path))
|
||||
|
||||
def set_read_only(self, read_only=True):
|
||||
"""设置文件选择框为只读,并禁用浏览按钮。"""
|
||||
self.line_edit.setReadOnly(read_only)
|
||||
self.browse_btn.setEnabled(not read_only)
|
||||
1
src/gui/core/__init__.py
Normal file
1
src/gui/core/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# src.gui.core
|
||||
332
src/gui/core/worker_thread.py
Normal file
332
src/gui/core/worker_thread.py
Normal file
@ -0,0 +1,332 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
后台线程模块:Pipeline 执行线程与诊断逻辑。
|
||||
"""
|
||||
import traceback
|
||||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 依赖诊断
|
||||
# =============================================================================
|
||||
|
||||
def check_pipeline_dependencies():
|
||||
"""检查pipeline模块的依赖项"""
|
||||
missing_deps = []
|
||||
dep_errors = {}
|
||||
|
||||
required_packages = [
|
||||
'numpy', 'pandas', 'scipy', 'matplotlib', 'sklearn',
|
||||
'joblib', 'PIL', 'cv2', 'rasterio', 'geopandas'
|
||||
]
|
||||
|
||||
for package in required_packages:
|
||||
try:
|
||||
if package == 'PIL':
|
||||
import PIL
|
||||
elif package == 'cv2':
|
||||
import cv2
|
||||
else:
|
||||
__import__(package)
|
||||
except Exception as e:
|
||||
missing_deps.append(package)
|
||||
dep_errors[package] = repr(e)
|
||||
|
||||
return missing_deps, dep_errors
|
||||
|
||||
|
||||
def diagnose_pipeline_import_error():
|
||||
"""诊断pipeline导入错误"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
error_info = []
|
||||
|
||||
is_frozen = getattr(sys, "frozen", False) or bool(getattr(sys, "_MEIPASS", None))
|
||||
|
||||
if is_frozen:
|
||||
error_info.append(
|
||||
"[INFO] PyInstaller 环境:Pipeline 从程序内置包加载,跳过对仓库路径 src/core/*.py 的磁盘检查"
|
||||
)
|
||||
else:
|
||||
pipeline_file = os.path.normpath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "core", "water_quality_inversion_pipeline_GUI.py")
|
||||
)
|
||||
if not os.path.exists(pipeline_file):
|
||||
error_info.append(f"[ERROR] Pipeline文件不存在: {pipeline_file}")
|
||||
error_info.append(
|
||||
" 解决方案: 请确保项目结构完整,检查 src/core/ 下是否有 water_quality_inversion_pipeline_GUI.py"
|
||||
)
|
||||
else:
|
||||
error_info.append(f"[OK] Pipeline文件存在: {pipeline_file}")
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
error_info.append(f"[INFO] 已添加路径到sys.path: {current_dir}")
|
||||
|
||||
missing_deps, dep_errors = check_pipeline_dependencies()
|
||||
if missing_deps:
|
||||
error_info.append(f"[ERROR] 缺少必需的依赖包: {', '.join(missing_deps)}")
|
||||
for pkg in missing_deps:
|
||||
if pkg in dep_errors:
|
||||
error_info.append(f" - {pkg} 导入失败原因: {dep_errors[pkg]}")
|
||||
error_info.append(" 解决方案: 请运行以下命令安装依赖:")
|
||||
error_info.append(" pip install -r requirements.txt")
|
||||
error_info.append(" 或使用conda:")
|
||||
error_info.append(" conda install numpy pandas scipy matplotlib scikit-learn joblib pillow opencv-python rasterio geopandas")
|
||||
else:
|
||||
error_info.append("[OK] 主要依赖包均已安装")
|
||||
|
||||
try:
|
||||
from osgeo import gdal # noqa: F401
|
||||
error_info.append("[OK] GDAL (osgeo) 可用")
|
||||
except ImportError:
|
||||
try:
|
||||
from osgeo import gdal # noqa: F401
|
||||
error_info.append("[OK] GDAL 可用")
|
||||
except ImportError:
|
||||
error_info.append("[WARNING] GDAL/osgeo 不可用,将影响栅格与地理数据处理")
|
||||
error_info.append(" 开发环境: conda install gdal")
|
||||
error_info.append(" 打包环境: 请在构建所用 Conda 环境中打包,并确保 spec 已收集 Library/bin 中依赖 DLL")
|
||||
|
||||
try:
|
||||
import unittest
|
||||
error_info.append("[OK] unittest模块可用")
|
||||
except ImportError:
|
||||
error_info.append("[WARNING] unittest模块不可用,这可能是PyInstaller打包环境导致的")
|
||||
error_info.append(" 这不会影响主要功能,但可能影响某些测试相关特性")
|
||||
|
||||
return error_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pipeline 可用性标志(模块级状态)
|
||||
# =============================================================================
|
||||
|
||||
PIPELINE_AVAILABLE = False
|
||||
PIPELINE_ERROR_INFO = []
|
||||
|
||||
try:
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
print("[OK] 成功导入pipeline模块")
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
except ImportError as e:
|
||||
PIPELINE_AVAILABLE = False
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
|
||||
print("="*60)
|
||||
print("[ERROR] PIPELINE导入失败 - 详细诊断信息:")
|
||||
print("="*60)
|
||||
|
||||
for info in error_info:
|
||||
print(info)
|
||||
|
||||
print("-"*60)
|
||||
print(f"原始ImportError: {str(e)}")
|
||||
print("-"*60)
|
||||
|
||||
if "unittest" in str(e):
|
||||
print("[INFO] unittest模块缺失 - 这通常在PyInstaller打包环境中发生")
|
||||
print("解决方案:")
|
||||
print(" 1. 这不会影响主要功能,程序仍可正常运行")
|
||||
print(" 2. 如果需要修复,可以在.spec文件中添加unittest模块:")
|
||||
print(" a = Analysis(..., hiddenimports=['unittest', 'unittest.mock'])")
|
||||
print(" 3. 或在PyInstaller命令中添加: --hidden-import unittest")
|
||||
elif "water_quality_inversion_pipeline_GUI" in str(e):
|
||||
print("[INFO] 可能的解决方案:")
|
||||
print(" 1. 检查src/core/water_quality_inversion_pipeline_GUI.py文件是否存在")
|
||||
print(" 2. 确保Python路径设置正确")
|
||||
print(" 3. 尝试重新安装依赖: pip install -r requirements.txt")
|
||||
print(" 4. 检查Python版本是否兼容(推荐Python 3.8-3.11)")
|
||||
|
||||
import traceback
|
||||
print("\n完整错误追踪:")
|
||||
traceback.print_exc()
|
||||
print("="*60)
|
||||
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
except Exception as e:
|
||||
PIPELINE_AVAILABLE = False
|
||||
error_info = diagnose_pipeline_import_error()
|
||||
|
||||
print("="*60)
|
||||
print("[ERROR] PIPELINE导入失败 - 其他错误:")
|
||||
print("="*60)
|
||||
|
||||
for info in error_info:
|
||||
print(info)
|
||||
|
||||
print("-"*60)
|
||||
print(f"原始错误: {str(e)}")
|
||||
print("-"*60)
|
||||
|
||||
print("[INFO] 可能的解决方案:")
|
||||
print(" 1. 检查Python环境和依赖包版本")
|
||||
print(" 2. 尝试重新安装所有依赖")
|
||||
print(" 3. 检查是否有语法错误或其他模块导入问题")
|
||||
|
||||
import traceback
|
||||
print("\n完整错误追踪:")
|
||||
traceback.print_exc()
|
||||
print("="*60)
|
||||
|
||||
PIPELINE_ERROR_INFO = error_info
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WorkerThread
|
||||
# =============================================================================
|
||||
|
||||
class WorkerThread(QThread):
|
||||
"""后台工作线程,用于执行耗时任务(在工作线程内创建 Pipeline,避免阻塞 UI)。"""
|
||||
progress_update = pyqtSignal(int, str) # 进度更新信号 (percentage, message)
|
||||
log_message = pyqtSignal(str, str) # 日志消息信号 (message, level: 'info'/'warning'/'error')
|
||||
step_completed = pyqtSignal(str, bool, str) # 步骤完成信号 (step_name, success, message)
|
||||
finished = pyqtSignal(bool, str) # 完成信号 (success, message)
|
||||
|
||||
def __init__(self, work_dir: str, config, mode='full', step_name=None):
|
||||
super().__init__()
|
||||
self.work_dir = str(work_dir)
|
||||
self.config = config
|
||||
self.mode = mode # 'full' 或 'single_step'
|
||||
self.step_name = step_name # 单步执行时的步骤名称
|
||||
self.pipeline = None
|
||||
self.is_running = True
|
||||
self.current_step = None
|
||||
self.step_count = 0
|
||||
self.total_steps = 9
|
||||
|
||||
def pipeline_callback(self, step_name, status, message=""):
|
||||
"""Pipeline回调函数,用于接收步骤状态"""
|
||||
if status == "start":
|
||||
self.log_message.emit(f"[START] 开始执行: {step_name}", "info")
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"正在执行: {step_name}")
|
||||
elif status == "completed":
|
||||
self.step_count += 1
|
||||
self.log_message.emit(f"[DONE] 完成: {step_name} {message}", "info")
|
||||
self.step_completed.emit(step_name, True, message)
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"已完成: {step_name}")
|
||||
elif status == "skipped":
|
||||
self.step_count += 1
|
||||
self.log_message.emit(f"[SKIP] 跳过: {step_name} {message}", "warning")
|
||||
self.step_completed.emit(step_name, True, f"跳过: {message}")
|
||||
progress = int((self.step_count / self.total_steps) * 100)
|
||||
self.progress_update.emit(progress, f"已跳过: {step_name}")
|
||||
elif status == "error":
|
||||
self.log_message.emit(f"[ERROR] 错误: {step_name} - {message}", "error")
|
||||
self.step_completed.emit(step_name, False, message)
|
||||
elif status == "info":
|
||||
self.log_message.emit(f" {message}", "info")
|
||||
elif status == "warning":
|
||||
self.log_message.emit(f" [WARNING] {message}", "warning")
|
||||
|
||||
def run(self):
|
||||
"""运行 pipeline:子线程内切换 Matplotlib 为 Agg,避免 Qt5Agg 在后台线程绘图导致界面卡死。"""
|
||||
import os
|
||||
# GDAL 环境变量保护(放在最前面,防止路径/编码问题)
|
||||
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
|
||||
os.environ['SHAPE_ENCODING'] = 'UTF-8'
|
||||
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
self.pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
|
||||
if self.mode == 'full':
|
||||
self.log_message.emit("开始运行完整流程...", "info")
|
||||
self.step_count = 0
|
||||
|
||||
if hasattr(self.pipeline, 'set_callback'):
|
||||
self.pipeline.set_callback(self.pipeline_callback)
|
||||
|
||||
self.pipeline.run_full_pipeline(self.config)
|
||||
|
||||
self.progress_update.emit(100, "流程执行完成")
|
||||
self.finished.emit(True, "完整流程执行成功!")
|
||||
else:
|
||||
self.log_message.emit(f"开始独立运行步骤: {self.step_name}", "info")
|
||||
self.progress_update.emit(0, f"正在执行: {self.step_name}")
|
||||
|
||||
if hasattr(self.pipeline, 'set_callback'):
|
||||
self.pipeline.set_callback(self.pipeline_callback)
|
||||
|
||||
self.run_single_step(self.step_name, self.config)
|
||||
|
||||
self.progress_update.emit(100, f"步骤 {self.step_name} 执行完成")
|
||||
self.finished.emit(True, f"步骤 {self.step_name} 独立运行成功!")
|
||||
except Exception as e:
|
||||
error_msg = f"执行失败: {str(e)}\n{traceback.format_exc()}"
|
||||
self.log_message.emit(error_msg, "error")
|
||||
self.finished.emit(False, error_msg)
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def run_single_step(self, step_name, config):
|
||||
"""运行单个步骤"""
|
||||
step_method_map = {
|
||||
'step1': 'step1_generate_water_mask',
|
||||
'step2': 'step2_find_glint_area',
|
||||
'step3': 'step3_remove_glint',
|
||||
'step4': 'step4_process_csv',
|
||||
'step5': 'step5_extract_training_spectra',
|
||||
'step5_5': 'step5_5_calculate_water_quality_indices',
|
||||
'step6': 'step6_train_models',
|
||||
'step6_5': 'step6_5_non_empirical_modeling',
|
||||
'step6_75': 'step6_75_custom_regression',
|
||||
'step7': 'step7_generate_sampling_points',
|
||||
'step8': 'step8_predict_water_quality',
|
||||
'step8_5': 'step8_5_predict_with_non_empirical_models',
|
||||
'step8_75': 'step8_75_predict_with_custom_regression',
|
||||
'step9': 'step9_generate_distribution_map'
|
||||
}
|
||||
|
||||
if step_name not in step_method_map:
|
||||
raise ValueError(f"未知的步骤名称: {step_name}")
|
||||
|
||||
method_name = step_method_map[step_name]
|
||||
step_config = dict(config.get(step_name, {}))
|
||||
|
||||
step_config['skip_dependency_check'] = True
|
||||
|
||||
if step_name == 'step9':
|
||||
step_config.pop('step9_batch_mode', None)
|
||||
step_config.pop('prediction_csv_dir', None)
|
||||
step_config.pop('recursive_csv_scan', None)
|
||||
|
||||
if step_name in ['step2', 'step3', 'step4', 'step5', 'step6', 'step7', 'step8', 'step8_5', 'step8_75']:
|
||||
step_config.pop('output_path', None)
|
||||
|
||||
if step_name == 'step8_5' and 'models_dir' in step_config:
|
||||
step_config['non_empirical_models_dir'] = step_config.pop('models_dir')
|
||||
|
||||
method = getattr(self.pipeline, method_name)
|
||||
result = method(**step_config)
|
||||
|
||||
return result
|
||||
|
||||
def stop(self):
|
||||
"""停止执行"""
|
||||
self.is_running = False
|
||||
self.terminate()
|
||||
46
src/gui/model/waterindex.csv
Normal file
46
src/gui/model/waterindex.csv
Normal file
@ -0,0 +1,46 @@
|
||||
Formula_Name,Category,Formula,Reference
|
||||
BGA_Am09KBBI,Phycocyanin (BGA_PC),(w686 - w658) / (w686 + w658),"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.; Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13."
|
||||
BGA_Be162B643sub629,Phycocyanin (BGA_PC),w644 - w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be162B700sub601,Phycocyanin (BGA_PC),w700 - w601,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be162BsubPhy,Phycocyanin (BGA_PC),w715 - w615,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16FLHBlueRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w458 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16FLHGreenRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w558 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16FLHVioletRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w444 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16MPI,Phycocyanin (BGA_PC),(w615 - w601) - (w644 - w601),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16NDPhyI,Phycocyanin (BGA_PC),(w700 - w622) / (w700 + w622),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16NDPhyI644over615,Phycocyanin (BGA_PC),(w644 - w615) / (w644 + w615),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541."
|
||||
BGA_Be16NDPhyI644over629,Phycocyanin (BGA_PC),(w644 - w629) / (w644 + w629),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542."
|
||||
BGA_Be16Phy2BDA644over629,Phycocyanin (BGA_PC),w644 / w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545."
|
||||
BGA_Da052BDA,Phycocyanin (BGA_PC),w714 / w672,"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
BGA_Go04MCI,Phycocyanin (BGA_PC),w709 - w681 - (w753 - w681),"Gower, J.F.R.; Brown,L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17闁?5."
|
||||
BGA_HU103BDA,Phycocyanin (BGA_PC),(((1 / w615) - (1 / w600)) - w725),"Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406"
|
||||
BGA_Ku15PhyCI,Phycocyanin (BGA_PC),(-1 * (W681 - W665 - (W709 - W665))),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10."
|
||||
BGA_Ku15SLH,Phycocyanin (BGA_PC),(w715 - w658) + (w715 - w658),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11."
|
||||
BGA_MI092BDA,Phycocyanin (BGA_PC),w700 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758闁?75."
|
||||
BGA_MM092BDA,Phycocyanin (BGA_PC),w724 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758闁?76."
|
||||
BGA_MM12NDCIalt,Phycocyanin (BGA_PC),(w700 - w658) / (w700 + w658),"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003"
|
||||
BGA_MM143BDAopt,Phycocyanin (BGA_PC),((1 / w629) - (1 / w659)) * w724,"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004"
|
||||
BGA_SI052BDA,Phycocyanin (BGA_PC),w709 / w620,"Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237闁?45"
|
||||
BGA_SM122BDA,Phycocyanin (BGA_PC),w709 / w600,"Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012."
|
||||
BGA_SY002BDA,Phycocyanin (BGA_PC),w650 / w625,"Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55,153闁?68"
|
||||
BGA_Wy08CI,Phycocyanin (BGA_PC),(-1 * (W686 - W672 - (W715 - W672))),"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
Chl_Al10SABI,chlorophyll_a,(w857 - w644) / (w458 + w529),"Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825."
|
||||
Chl_Am092Bsub,chlorophyll_a,w681 - w665,"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126闁?144."
|
||||
Chl_Be16FLHblue,chlorophyll_a,w529 - (w644 + (w458 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16FLHviolet,chlorophyll_a,w529 - (w644 + (w429 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16NDTIblue,chlorophyll_a,(w658 - w458) / (w658 + w458),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543."
|
||||
Chl_Be16NDTIviolet,chlorophyll_a,(w658 - w444) / (w658 + w444),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544."
|
||||
Chl_De933BDA,chlorophyll_a,w600 - w648 - w625,"Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam."
|
||||
Chl_Gi033BDA,chlorophyll_a,((1 / w672) - (1 / w715)) * w757,"Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282."
|
||||
Chl_Kn07KIVU,chlorophyll_a,(w458 - w644) / w529,"Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23闁?7 April 2007 (ESA SP-636, July 2007)."
|
||||
Chl_MM12NDCI,chlorophyll_a,(w715 - w686) / (w715 + w686),"Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406"
|
||||
Chl_Zh10FLH,chlorophyll_a,w686 - (w715 + (w672 - w751)),"Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48"
|
||||
Turb_Be16GreenPlusRedBothOverViolet,Turbidity,(w558 + w658) / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538"
|
||||
Turb_Be16RedOverViolet,Turbidity,w658 / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539"
|
||||
Turb_Bow06RedOverGreen,Turbidity,w658 / w558,"Bowers, D. G., and C. E. Binding. 2006. 闁炽儲缈籬e Optical Properties of Mineral Suspended Particles: A Review and Synthesis.闁?Estuarine Coastal and Shelf Science 67 (1闁?): 219闁?30. doi:10.1016/j.ecss.2005.11.010"
|
||||
Turb_Chip09NIROverGreen,Turbidity,w857 / w558,"Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009."
|
||||
Turb_Dox02NIRoverRed,Turbidity,w857 / w658,"Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085"
|
||||
Turb_Frohn09GreenPlusRedBothOverBlue,Turbidity,(w558 + w658) / w458,"Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency."
|
||||
Turb_Harr92NIR,Turbidity,w857,"Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments闁炽儲鏁刪e Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487闁?509"
|
||||
Turb_Lath91RedOverBlue,Turbidity,w658 / w458,"Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing"
|
||||
Turb_Moore80Red,Turbidity,w658,"Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422"
|
||||
|
315
src/gui/panels/report_generation_panel.py
Normal file
315
src/gui/panels/report_generation_panel.py
Normal file
@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
ReportGenerationPanel - Word 分析报告生成面板
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QFormLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QSpinBox,
|
||||
QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class ReportGenerateThread(QThread):
|
||||
"""后台生成 Word 报告(避免阻塞 UI)。"""
|
||||
finished_ok = pyqtSignal(str)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, output_dir: Optional[str], report_title: str, options: dict):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.output_dir = output_dir
|
||||
self.report_title = report_title
|
||||
self.options = options
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
from src.postprocessing.report_word import WaterQualityReportGenerator, ReportGenerationConfig
|
||||
|
||||
url = (self.options.get("ollama_url") or "").strip() or None
|
||||
vision = (self.options.get("ollama_vision_model") or "").strip() or None
|
||||
text = (self.options.get("ollama_text_model") or "").strip() or None
|
||||
if self.options.get("text_same_as_vision"):
|
||||
text = vision
|
||||
timeout = self.options.get("ollama_timeout_s")
|
||||
enable_ai = self.options.get("enable_ai_analysis")
|
||||
|
||||
ai_cfg = ReportGenerationConfig(
|
||||
ollama_base_url=url,
|
||||
ollama_vision_model=vision,
|
||||
ollama_text_model=text,
|
||||
ollama_timeout_s=int(timeout) if timeout is not None else None,
|
||||
enable_ai_analysis=bool(enable_ai),
|
||||
)
|
||||
self.log_message.emit(
|
||||
f"报告生成:工作目录={self.work_dir},AI={'开' if enable_ai else '关'},"
|
||||
f"模型URL={url or '(环境变量 OLLAMA_URL)'}",
|
||||
"info",
|
||||
)
|
||||
gen = WaterQualityReportGenerator(
|
||||
work_dir=self.work_dir,
|
||||
output_dir=self.output_dir,
|
||||
ai_config=ai_cfg,
|
||||
)
|
||||
out_path = gen.generate_report(
|
||||
work_dir=self.work_dir,
|
||||
report_title=self.report_title or "水质参数反演分析报告",
|
||||
)
|
||||
self.finished_ok.emit(str(out_path))
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
class ReportGenerationPanel(QWidget):
|
||||
"""Word 报告生成:工作目录、输出目录、Ollama URL/模型、是否启用 AI 等。"""
|
||||
|
||||
def __init__(self, main_window=None, parent=None):
|
||||
super().__init__(parent)
|
||||
self.main_window = main_window
|
||||
self._report_thread = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
layout.setContentsMargins(10, 10, 10, 10)
|
||||
layout.setSpacing(10)
|
||||
|
||||
intro = QLabel(
|
||||
"根据工作目录下的可视化结果(14_visualization 等)生成 Word 分析报告。"
|
||||
"需已存在可视化图表;AI 分析通过 Ollama /api/chat 调用本地或远程服务。"
|
||||
)
|
||||
intro.setWordWrap(True)
|
||||
intro.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
layout.addWidget(intro)
|
||||
|
||||
path_group = QGroupBox("路径")
|
||||
path_form = QFormLayout()
|
||||
|
||||
wd_row = QHBoxLayout()
|
||||
self.work_dir_edit = QLineEdit()
|
||||
self.work_dir_edit.setPlaceholderText("选择流程工作目录(含 14_visualization)…")
|
||||
wd_browse = QPushButton("浏览…")
|
||||
wd_browse.clicked.connect(self.browse_work_dir)
|
||||
sync_btn = QPushButton("同步主窗口工作目录")
|
||||
sync_btn.clicked.connect(self.sync_work_dir_from_main)
|
||||
wd_row.addWidget(self.work_dir_edit, 1)
|
||||
wd_row.addWidget(wd_browse)
|
||||
wd_row.addWidget(sync_btn)
|
||||
path_form.addRow("工作目录:", wd_row)
|
||||
|
||||
out_row = QHBoxLayout()
|
||||
self.output_dir_edit = QLineEdit()
|
||||
self.output_dir_edit.setPlaceholderText("留空则保存到 工作目录/14_visualization")
|
||||
out_browse = QPushButton("浏览…")
|
||||
out_browse.clicked.connect(self.browse_output_dir)
|
||||
out_row.addWidget(self.output_dir_edit, 1)
|
||||
out_row.addWidget(out_browse)
|
||||
path_form.addRow("报告输出目录:", out_row)
|
||||
|
||||
self.report_title_edit = QLineEdit()
|
||||
self.report_title_edit.setText("水质参数反演分析报告")
|
||||
path_form.addRow("报告标题:", self.report_title_edit)
|
||||
|
||||
path_group.setLayout(path_form)
|
||||
layout.addWidget(path_group)
|
||||
|
||||
ai_group = QGroupBox("AI 分析(Ollama)")
|
||||
ai_form = QFormLayout()
|
||||
|
||||
self.enable_ai_cb = QCheckBox("启用 AI 图表解读与综合总结")
|
||||
self.enable_ai_cb.setChecked(
|
||||
os.environ.get("ENABLE_AI_ANALYSIS", "1") not in {"0", "false", "False"}
|
||||
)
|
||||
ai_form.addRow(self.enable_ai_cb)
|
||||
|
||||
self.ollama_url_edit = QLineEdit()
|
||||
self.ollama_url_edit.setText(
|
||||
os.environ.get("OLLAMA_URL", "http://localhost:11434").rstrip("/")
|
||||
)
|
||||
ai_form.addRow("服务 URL:", self.ollama_url_edit)
|
||||
|
||||
self.vision_model_edit = QLineEdit()
|
||||
self.vision_model_edit.setText(
|
||||
os.environ.get("OLLAMA_VISION_MODEL", "qwen3-vl:8b")
|
||||
)
|
||||
ai_form.addRow("视觉模型:", self.vision_model_edit)
|
||||
|
||||
self.same_text_model_cb = QCheckBox("文本总结与视觉使用同一模型")
|
||||
self.same_text_model_cb.setChecked(True)
|
||||
ai_form.addRow(self.same_text_model_cb)
|
||||
|
||||
self.text_model_edit = QLineEdit()
|
||||
self.text_model_edit.setText(
|
||||
os.environ.get(
|
||||
"OLLAMA_TEXT_MODEL",
|
||||
self.vision_model_edit.text() or "qwen3-vl:8b"
|
||||
)
|
||||
)
|
||||
self.text_model_edit.setEnabled(False)
|
||||
self.same_text_model_cb.toggled.connect(self._on_same_text_toggled)
|
||||
self.vision_model_edit.textChanged.connect(self._sync_text_model_if_linked)
|
||||
ai_form.addRow("文本模型:", self.text_model_edit)
|
||||
|
||||
self.timeout_spin = QSpinBox()
|
||||
self.timeout_spin.setRange(30, 3600)
|
||||
self.timeout_spin.setSingleStep(30)
|
||||
self.timeout_spin.setValue(int(os.environ.get("OLLAMA_TIMEOUT_S", "120")))
|
||||
ai_form.addRow("请求超时(秒):", self.timeout_spin)
|
||||
|
||||
ai_group.setLayout(ai_form)
|
||||
layout.addWidget(ai_group)
|
||||
|
||||
btn_row = QHBoxLayout()
|
||||
self.generate_btn = QPushButton("生成 Word 报告")
|
||||
self.generate_btn.setStyleSheet(
|
||||
ModernStylesheet.get_button_stylesheet("success")
|
||||
)
|
||||
self.generate_btn.clicked.connect(self.on_generate_clicked)
|
||||
btn_row.addWidget(self.generate_btn)
|
||||
btn_row.addStretch()
|
||||
layout.addLayout(btn_row)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def _on_same_text_toggled(self, checked: bool):
|
||||
self.text_model_edit.setEnabled(not checked)
|
||||
if checked:
|
||||
self.text_model_edit.setText(self.vision_model_edit.text())
|
||||
|
||||
def _sync_text_model_if_linked(self, _t=None):
|
||||
if self.same_text_model_cb.isChecked():
|
||||
self.text_model_edit.blockSignals(True)
|
||||
self.text_model_edit.setText(self.vision_model_edit.text())
|
||||
self.text_model_edit.blockSignals(False)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用主窗口缓存的 work_dir"""
|
||||
if self.main_window and hasattr(self.main_window, 'work_dir') and self.main_window.work_dir:
|
||||
return str(self.main_window.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_work_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
d = QFileDialog.getExistingDirectory(self, "选择工作目录", default)
|
||||
if d:
|
||||
self.work_dir_edit.setText(d)
|
||||
|
||||
def browse_output_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择报告输出目录", default)
|
||||
if d:
|
||||
self.output_dir_edit.setText(d)
|
||||
|
||||
def sync_work_dir_from_main(self):
|
||||
mw = self.main_window
|
||||
if mw is not None and getattr(mw, "work_dir", None):
|
||||
self.work_dir_edit.setText(str(mw.work_dir))
|
||||
else:
|
||||
QMessageBox.information(self, "提示", "主窗口尚未设置工作目录。")
|
||||
|
||||
def set_work_dir(self, work_dir):
|
||||
if work_dir:
|
||||
self.work_dir_edit.setText(str(work_dir))
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"work_dir": self.work_dir_edit.text().strip() or None,
|
||||
"output_dir": self.output_dir_edit.text().strip() or None,
|
||||
"report_title": self.report_title_edit.text().strip() or "水质参数反演分析报告",
|
||||
"ollama_url": self.ollama_url_edit.text().strip(),
|
||||
"ollama_vision_model": self.vision_model_edit.text().strip(),
|
||||
"ollama_text_model": self.text_model_edit.text().strip(),
|
||||
"text_same_as_vision": self.same_text_model_cb.isChecked(),
|
||||
"ollama_timeout_s": self.timeout_spin.value(),
|
||||
"enable_ai_analysis": self.enable_ai_cb.isChecked(),
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if not config:
|
||||
return
|
||||
if config.get("work_dir"):
|
||||
self.work_dir_edit.setText(str(config["work_dir"]))
|
||||
if "output_dir" in config:
|
||||
self.output_dir_edit.setText(str(config["output_dir"] or ""))
|
||||
if config.get("report_title"):
|
||||
self.report_title_edit.setText(str(config["report_title"]))
|
||||
if config.get("ollama_url"):
|
||||
self.ollama_url_edit.setText(str(config["ollama_url"]))
|
||||
if config.get("ollama_vision_model"):
|
||||
self.vision_model_edit.setText(str(config["ollama_vision_model"]))
|
||||
if "text_same_as_vision" in config:
|
||||
self.same_text_model_cb.setChecked(bool(config["text_same_as_vision"]))
|
||||
if config.get("ollama_text_model"):
|
||||
self.text_model_edit.setText(str(config["ollama_text_model"]))
|
||||
if config.get("ollama_timeout_s") is not None:
|
||||
self.timeout_spin.setValue(int(config["ollama_timeout_s"]))
|
||||
if "enable_ai_analysis" in config:
|
||||
self.enable_ai_cb.setChecked(bool(config["enable_ai_analysis"]))
|
||||
|
||||
def on_generate_clicked(self):
|
||||
wd = self.work_dir_edit.text().strip()
|
||||
if not wd or not os.path.isdir(wd):
|
||||
QMessageBox.warning(self, "提示", "请选择有效的工作目录。")
|
||||
return
|
||||
viz = Path(wd) / "14_visualization"
|
||||
if not viz.is_dir():
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"提示",
|
||||
f"未找到可视化目录:\n{viz}\n请先完成流程或生成可视化。",
|
||||
)
|
||||
return
|
||||
if self._report_thread and self._report_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "报告正在生成中,请稍候。")
|
||||
return
|
||||
|
||||
out = self.output_dir_edit.text().strip() or None
|
||||
title = self.report_title_edit.text().strip() or "水质参数反演分析报告"
|
||||
opts = {
|
||||
"ollama_url": self.ollama_url_edit.text().strip(),
|
||||
"ollama_vision_model": self.vision_model_edit.text().strip(),
|
||||
"ollama_text_model": self.text_model_edit.text().strip(),
|
||||
"text_same_as_vision": self.same_text_model_cb.isChecked(),
|
||||
"ollama_timeout_s": self.timeout_spin.value(),
|
||||
"enable_ai_analysis": self.enable_ai_cb.isChecked(),
|
||||
}
|
||||
self.generate_btn.setEnabled(False)
|
||||
self._report_thread = ReportGenerateThread(wd, out, title, opts)
|
||||
self._report_thread.log_message.connect(self._forward_log, Qt.QueuedConnection)
|
||||
self._report_thread.finished_ok.connect(self._on_report_ok, Qt.QueuedConnection)
|
||||
self._report_thread.failed.connect(self._on_report_fail, Qt.QueuedConnection)
|
||||
self._report_thread.finished.connect(
|
||||
lambda: self.generate_btn.setEnabled(True), Qt.QueuedConnection
|
||||
)
|
||||
self._report_thread.start()
|
||||
self._forward_log("已开始生成 Word 报告…", "info")
|
||||
|
||||
def _forward_log(self, msg: str, level: str):
|
||||
mw = self.main_window
|
||||
if mw is not None and hasattr(mw, "log_message"):
|
||||
mw.log_message(msg, level)
|
||||
else:
|
||||
print(f"[{level}] {msg}")
|
||||
|
||||
def _on_report_ok(self, path: str):
|
||||
QMessageBox.information(self, "完成", f"报告已生成:\n{path}")
|
||||
self._forward_log(f"Word 报告已保存: {path}", "info")
|
||||
|
||||
def _on_report_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"报告生成失败:\n{err[:800]}")
|
||||
self._forward_log(err, "error")
|
||||
282
src/gui/panels/step1_panel.py
Normal file
282
src/gui/panels/step1_panel.py
Normal file
@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step1 面板 - 水域掩膜生成
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QGroupBox, QLabel,
|
||||
QDoubleSpinBox, QCheckBox, QPushButton, QFormLayout, QRadioButton,
|
||||
QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step1Panel(QWidget):
|
||||
"""1. 水域掩膜生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 掩膜生成方式选择
|
||||
method_group = QGroupBox("掩膜生成方式")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
# 使用现有掩膜文件
|
||||
self.use_existing_radio = QRadioButton("使用现有掩膜文件")
|
||||
self.use_existing_radio.setChecked(True)
|
||||
method_layout.addWidget(self.use_existing_radio)
|
||||
|
||||
# 使用NDWI自动生成
|
||||
self.use_ndwi_radio = QRadioButton("使用NDWI自动生成")
|
||||
method_layout.addWidget(self.use_ndwi_radio)
|
||||
|
||||
# 应用QRadioButton样式(实心选中点)
|
||||
radio_style = """
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border-radius: 8px;
|
||||
border: 2px solid #999;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
background-color: #0078D7;
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
QRadioButton::indicator:unchecked {
|
||||
background-color: white;
|
||||
border: 2px solid #999;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #0078D7;
|
||||
}
|
||||
"""
|
||||
self.use_existing_radio.setStyleSheet(radio_style)
|
||||
self.use_ndwi_radio.setStyleSheet(radio_style)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 掩膜文件选择
|
||||
self.mask_file = FileSelectWidget(
|
||||
"掩膜文件:",
|
||||
"Shapefiles (*.shp);;Raster Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.mask_file)
|
||||
|
||||
# 影像文件选择(用于shp栅格化或NDWI生成)
|
||||
self.img_file = FileSelectWidget(
|
||||
"参考影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# NDWI参数设置
|
||||
self.ndwi_group = QGroupBox("NDWI参数设置")
|
||||
ndwi_layout = QVBoxLayout()
|
||||
|
||||
# NDWI阈值
|
||||
threshold_layout = QHBoxLayout()
|
||||
threshold_layout.addWidget(QLabel("NDWI阈值:"))
|
||||
self.ndwi_threshold = QDoubleSpinBox()
|
||||
self.ndwi_threshold.setRange(0.0, 1.0)
|
||||
self.ndwi_threshold.setSingleStep(0.05)
|
||||
self.ndwi_threshold.setValue(0.4)
|
||||
self.ndwi_threshold.setDecimals(2)
|
||||
threshold_layout.addWidget(self.ndwi_threshold)
|
||||
threshold_layout.addStretch()
|
||||
ndwi_layout.addLayout(threshold_layout)
|
||||
|
||||
self.ndwi_group.setLayout(ndwi_layout)
|
||||
layout.addWidget(self.ndwi_group)
|
||||
|
||||
# 输出文件路径(使用save模式)
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("water_mask.dat")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 提示信息 - 专业的 Info Alert 样式
|
||||
hint = QLabel("💡 提示: 如果掩膜文件是Shapefile(.shp),需要提供参考影像用于栅格化;如果使用NDWI自动生成,只需要提供参考影像")
|
||||
hint.setWordWrap(True) # 允许自动换行
|
||||
hint.setStyleSheet("""
|
||||
QLabel {
|
||||
color: #0055D4;
|
||||
font-size: 13px;
|
||||
font-weight: bold;
|
||||
background-color: #E8F4FF;
|
||||
border: 2px solid #0055D4;
|
||||
border-radius: 8px;
|
||||
padding: 12px 16px;
|
||||
margin: 8px 0px;
|
||||
}
|
||||
""")
|
||||
layout.addWidget(hint)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
# 连接信号
|
||||
self.use_existing_radio.toggled.connect(self.update_ui_state)
|
||||
self.use_ndwi_radio.toggled.connect(self.update_ui_state)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 初始UI状态
|
||||
self.update_ui_state()
|
||||
|
||||
def update_ui_state(self):
|
||||
"""根据选择的掩膜生成方式更新UI状态(使用显示/隐藏控制)"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
# 动态显示/隐藏组件
|
||||
if use_ndwi:
|
||||
# 使用NDWI模式:隐藏掩膜文件,显示NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(False)
|
||||
self.ndwi_group.setVisible(True)
|
||||
self.output_file.setVisible(True) # 显示输出掩膜路径
|
||||
|
||||
# 当切换到NDWI模式时,如果工作目录已设置,自动填充输出路径
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
self._auto_fill_output_path()
|
||||
else:
|
||||
# 使用现有掩膜模式:显示掩膜文件,隐藏NDWI参数和输出掩膜
|
||||
self.mask_file.setVisible(True)
|
||||
self.ndwi_group.setVisible(False)
|
||||
self.output_file.setVisible(False) # 隐藏输出掩膜路径
|
||||
|
||||
# 参考影像在两种模式下都显示
|
||||
self.img_file.setVisible(True)
|
||||
|
||||
def update_work_directory(self, work_dir):
|
||||
"""
|
||||
保存工作目录引用,用于后续自动填充路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
"""
|
||||
if not work_dir:
|
||||
return
|
||||
|
||||
# 保存工作目录引用
|
||||
self.work_dir = work_dir
|
||||
|
||||
# 如果当前选中的是NDWI模式,立即填充输出路径
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
self._auto_fill_output_path()
|
||||
|
||||
def _auto_fill_output_path(self):
|
||||
"""
|
||||
自动填充输出掩膜路径(仅在NDWI模式下)
|
||||
确保路径使用正斜杠,避免斜杠混用
|
||||
"""
|
||||
if not hasattr(self, 'work_dir') or not self.work_dir:
|
||||
return
|
||||
|
||||
# 生成输出掩膜的完整路径
|
||||
output_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
os.makedirs(output_dir, exist_ok=True) # 确保目录存在
|
||||
|
||||
# 统一使用正斜杠,避免 \ 和 / 混用
|
||||
default_output_path = os.path.join(output_dir, "water_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
use_ndwi = self.use_ndwi_radio.isChecked()
|
||||
|
||||
config = {
|
||||
'mask_path': None if use_ndwi else self.mask_file.get_path(),
|
||||
'use_ndwi': use_ndwi,
|
||||
'ndwi_threshold': self.ndwi_threshold.value()
|
||||
}
|
||||
|
||||
# 参考影像路径(两种模式都可能需要)
|
||||
img_path = self.img_file.get_path()
|
||||
if img_path:
|
||||
config['img_path'] = img_path
|
||||
|
||||
# 输出路径:仅在NDWI模式下有效
|
||||
if use_ndwi:
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
else:
|
||||
# 使用现有掩膜时,不传递output_path,避免底层错误尝试保存文件
|
||||
config['output_path'] = None
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'mask_path' in config:
|
||||
self.mask_file.set_path(config['mask_path'])
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'use_ndwi' in config:
|
||||
if config['use_ndwi']:
|
||||
self.use_ndwi_radio.setChecked(True)
|
||||
else:
|
||||
self.use_existing_radio.setChecked(True)
|
||||
if 'ndwi_threshold' in config:
|
||||
self.ndwi_threshold.setValue(config['ndwi_threshold'])
|
||||
|
||||
self.update_ui_state()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤1"""
|
||||
# 验证输入
|
||||
if self.use_ndwi_radio.isChecked():
|
||||
# NDWI模式:需要影像文件
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择参考影像文件!")
|
||||
return
|
||||
else:
|
||||
# 现有掩膜模式:需要掩膜文件
|
||||
mask_path = self.mask_file.get_path()
|
||||
if not mask_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择掩膜文件!")
|
||||
return
|
||||
|
||||
# 如果是shp文件,还需要影像文件
|
||||
if mask_path.lower().endswith('.shp'):
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "当使用shp文件时,需要提供参考影像用于栅格化!")
|
||||
return
|
||||
|
||||
# 获取父窗口并运行步骤
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
config = {'step1': self.get_config()}
|
||||
parent.run_single_step("step1", config)
|
||||
210
src/gui/panels/step2_panel.py
Normal file
210
src/gui/panels/step2_panel.py
Normal file
@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step2 面板 - 耀斑区域识别
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
|
||||
QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step2Panel(QWidget):
|
||||
"""2. 耀斑区域识别"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.work_dir = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 影像文件
|
||||
self.img_file = FileSelectWidget(
|
||||
"影像文件:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("检测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 耀斑波长
|
||||
self.glint_wave = QDoubleSpinBox()
|
||||
self.glint_wave.setRange(300, 1000)
|
||||
self.glint_wave.setValue(750.0)
|
||||
self.glint_wave.setSuffix(" nm")
|
||||
params_layout.addRow("耀斑检测波长:", self.glint_wave)
|
||||
|
||||
# 检测方法
|
||||
self.method = QComboBox()
|
||||
self.method.addItem("Otsu 阈值法", "otsu")
|
||||
self.method.addItem("Z-Score 方法", "zscore")
|
||||
self.method.addItem("百分位数法", "percentile")
|
||||
self.method.addItem("IQR 四分位距法", "iqr")
|
||||
self.method.addItem("自适应阈值法", "adaptive")
|
||||
self.method.addItem("多波段综合法", "multi_band")
|
||||
params_layout.addRow("检测方法:", self.method)
|
||||
|
||||
# 最大连通域面积
|
||||
self.max_area = QSpinBox()
|
||||
self.max_area.setRange(0, 100000)
|
||||
self.max_area.setValue(50)
|
||||
self.max_area.setSpecialValueText("不过滤")
|
||||
params_layout.addRow("最大连通域面积:", self.max_area)
|
||||
|
||||
# 岸边缓冲区
|
||||
self.buffer_size = QSpinBox()
|
||||
self.buffer_size.setRange(0, 200)
|
||||
self.buffer_size.setValue(10)
|
||||
self.buffer_size.setSpecialValueText("不设置")
|
||||
params_layout.addRow("岸边缓冲区大小:", self.buffer_size)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出耀斑掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'img_path': self.img_file.get_path(),
|
||||
'glint_wave': self.glint_wave.value(),
|
||||
'method': self.method.currentData(), # 使用 currentData() 获取英文ID
|
||||
}
|
||||
if self.max_area.value() > 0:
|
||||
config['max_area'] = self.max_area.value()
|
||||
if self.buffer_size.value() > 0:
|
||||
config['buffer_size'] = self.buffer_size.value()
|
||||
# 添加水域掩膜路径(用于独立运行)
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
# 添加输出路径
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'glint_wave' in config:
|
||||
self.glint_wave.setValue(config['glint_wave'])
|
||||
if 'method' in config:
|
||||
idx = self.method.findData(config['method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.method.setCurrentIndex(idx)
|
||||
if 'max_area' in config:
|
||||
self.max_area.setValue(config['max_area'])
|
||||
if 'buffer_size' in config:
|
||||
self.buffer_size.setValue(config['buffer_size'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""
|
||||
从全局配置/Pipeline 或 Step1Panel 自动填充路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例,用于获取步骤1生成的水域掩膜路径
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass # 保持现有工作目录
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Pipeline 获取
|
||||
mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
|
||||
mask_path = pipeline.water_mask_path
|
||||
|
||||
# 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取(关键修复)
|
||||
main_window = self.window()
|
||||
if not mask_path and hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
# NDWI模式,读取输出框的路径
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
# 导入现有模式,读取输入框的路径
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
|
||||
# 填充获取到的路径
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
# 生成输出耀斑掩膜的标准路径:workspace/2_glint_mask/glint_mask_out.dat
|
||||
output_dir = os.path.join(self.work_dir, "2_glint_mask")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "glint_mask_out.dat").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
# 没有工作目录时,清空输出路径
|
||||
self.output_file.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤2"""
|
||||
# 验证输入
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step2': self.get_config()}
|
||||
main_window.run_single_step('step2', config)
|
||||
451
src/gui/panels/step3_panel.py
Normal file
451
src/gui/panels/step3_panel.py
Normal file
@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step3 面板 - 耀斑去除
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QDoubleSpinBox, QSpinBox, QComboBox, QCheckBox, QPushButton,
|
||||
QLabel, QLineEdit, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
# 从公共组件库导入
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step3Panel(QWidget):
|
||||
"""步骤3:耀斑去除"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 影像文件
|
||||
self.img_file = FileSelectWidget(
|
||||
"影像文件:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.img_file)
|
||||
|
||||
# 水域掩膜/边界:完整流程可由步骤1自动生成;独立单步运行时须手动指定
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜/边界:",
|
||||
"Mask/Boundary (*.dat *.tif *.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.water_mask_file)
|
||||
step3_mask_hint = QLabel(
|
||||
"提示:独立运行本步骤时必须选择水域掩膜或边界(与影像同区域的 .dat/.tif 掩膜,或 .shp 矢量)。"
|
||||
)
|
||||
step3_mask_hint.setWordWrap(True)
|
||||
step3_mask_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step3_mask_hint)
|
||||
|
||||
# 方法选择
|
||||
method_group = QGroupBox("去耀斑方法")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
self.method = QComboBox()
|
||||
for text, data in [('Goodman方法', 'goodman'), ('Kutser方法', 'kutser'),
|
||||
('Hedley方法', 'hedley'), ('SUGAR算法', 'sugar')]:
|
||||
self.method.addItem(text, data)
|
||||
self.method.currentIndexChanged.connect(self._on_method_changed)
|
||||
method_layout.addWidget(self.method)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# Goodman参数组
|
||||
self.goodman_group = QGroupBox("Goodman方法参数")
|
||||
goodman_layout = QFormLayout()
|
||||
|
||||
self.nir_lower = QSpinBox()
|
||||
self.nir_lower.setRange(0, 200)
|
||||
self.nir_lower.setValue(65)
|
||||
goodman_layout.addRow("NIR下波段索引:", self.nir_lower)
|
||||
|
||||
self.nir_upper = QSpinBox()
|
||||
self.nir_upper.setRange(0, 200)
|
||||
self.nir_upper.setValue(91)
|
||||
goodman_layout.addRow("NIR上波段索引:", self.nir_upper)
|
||||
|
||||
self.goodman_a = QDoubleSpinBox()
|
||||
self.goodman_a.setDecimals(6)
|
||||
self.goodman_a.setRange(0, 1)
|
||||
self.goodman_a.setValue(0.000019)
|
||||
goodman_layout.addRow("参数A:", self.goodman_a)
|
||||
|
||||
self.goodman_b = QDoubleSpinBox()
|
||||
self.goodman_b.setDecimals(2)
|
||||
self.goodman_b.setRange(0, 1)
|
||||
self.goodman_b.setValue(0.1)
|
||||
goodman_layout.addRow("参数B:", self.goodman_b)
|
||||
|
||||
self.goodman_group.setLayout(goodman_layout)
|
||||
layout.addWidget(self.goodman_group)
|
||||
|
||||
# Kutser参数组
|
||||
self.kutser_group = QGroupBox("Kutser方法参数")
|
||||
kutser_layout = QFormLayout()
|
||||
|
||||
self.oxy_band = QSpinBox()
|
||||
self.oxy_band.setRange(0, 200)
|
||||
self.oxy_band.setValue(38)
|
||||
kutser_layout.addRow("氧吸收波段索引:", self.oxy_band)
|
||||
|
||||
self.lower_oxy = QSpinBox()
|
||||
self.lower_oxy.setRange(0, 200)
|
||||
self.lower_oxy.setValue(36)
|
||||
kutser_layout.addRow("下氧吸收波段索引:", self.lower_oxy)
|
||||
|
||||
self.upper_oxy = QSpinBox()
|
||||
self.upper_oxy.setRange(0, 200)
|
||||
self.upper_oxy.setValue(49)
|
||||
kutser_layout.addRow("上氧吸收波段索引:", self.upper_oxy)
|
||||
|
||||
self.nir_band = QSpinBox()
|
||||
self.nir_band.setRange(0, 200)
|
||||
self.nir_band.setValue(47)
|
||||
kutser_layout.addRow("NIR波段索引:", self.nir_band)
|
||||
|
||||
self.kutser_group.setLayout(kutser_layout)
|
||||
self.kutser_group.setVisible(False)
|
||||
layout.addWidget(self.kutser_group)
|
||||
|
||||
# Hedley参数组
|
||||
self.hedley_group = QGroupBox("Hedley方法参数")
|
||||
hedley_layout = QFormLayout()
|
||||
|
||||
self.hedley_nir_band = QSpinBox()
|
||||
self.hedley_nir_band.setRange(0, 200)
|
||||
self.hedley_nir_band.setValue(47)
|
||||
hedley_layout.addRow("NIR波段索引:", self.hedley_nir_band)
|
||||
|
||||
self.hedley_group.setLayout(hedley_layout)
|
||||
self.hedley_group.setVisible(False)
|
||||
layout.addWidget(self.hedley_group)
|
||||
|
||||
# SUGAR参数组
|
||||
self.sugar_group = QGroupBox("SUGAR方法参数")
|
||||
sugar_layout = QFormLayout()
|
||||
|
||||
self.sugar_iter = QSpinBox()
|
||||
self.sugar_iter.setRange(1, 20)
|
||||
self.sugar_iter.setValue(3)
|
||||
self.sugar_iter.setSpecialValueText("自动")
|
||||
sugar_layout.addRow("迭代次数:", self.sugar_iter)
|
||||
|
||||
self.sugar_sigma = QDoubleSpinBox()
|
||||
self.sugar_sigma.setDecimals(2)
|
||||
self.sugar_sigma.setRange(0.1, 10)
|
||||
self.sugar_sigma.setValue(1.0)
|
||||
sugar_layout.addRow("LoG平滑σ:", self.sugar_sigma)
|
||||
|
||||
self.sugar_estimate_background = QCheckBox()
|
||||
self.sugar_estimate_background.setChecked(True)
|
||||
sugar_layout.addRow("估计背景光谱:", self.sugar_estimate_background)
|
||||
|
||||
self.sugar_glint_mask_method = QComboBox()
|
||||
self.sugar_glint_mask_method.addItems(['cdf', 'otsu'])
|
||||
self.sugar_glint_mask_method.setCurrentText('cdf')
|
||||
sugar_layout.addRow("耀斑掩膜方法:", self.sugar_glint_mask_method)
|
||||
|
||||
self.sugar_termination_thresh = QDoubleSpinBox()
|
||||
self.sugar_termination_thresh.setDecimals(2)
|
||||
self.sugar_termination_thresh.setRange(1, 100)
|
||||
self.sugar_termination_thresh.setValue(20.0)
|
||||
sugar_layout.addRow("终止阈值:", self.sugar_termination_thresh)
|
||||
|
||||
self.sugar_bounds = QLineEdit()
|
||||
self.sugar_bounds.setText("[(1, 2)]")
|
||||
sugar_layout.addRow("优化边界:", self.sugar_bounds)
|
||||
|
||||
self.sugar_group.setLayout(sugar_layout)
|
||||
self.sugar_group.setVisible(False)
|
||||
layout.addWidget(self.sugar_group)
|
||||
|
||||
# 插值选项
|
||||
interp_group = QGroupBox("0值像素插值")
|
||||
interp_layout = QFormLayout()
|
||||
|
||||
self.interpolate_zeros = QCheckBox("启用插值")
|
||||
interp_layout.addRow("", self.interpolate_zeros)
|
||||
|
||||
self.interp_method = QComboBox()
|
||||
for text, data in [('最近邻插值', 'nearest'), ('双线性插值', 'bilinear'),
|
||||
('样条插值', 'spline'), ('克里金插值', 'kriging')]:
|
||||
self.interp_method.addItem(text, data)
|
||||
self.interp_method.setCurrentIndex(1) # 默认双线性插值
|
||||
interp_layout.addRow("插值方法:", self.interp_method)
|
||||
|
||||
interp_group.setLayout(interp_layout)
|
||||
layout.addWidget(interp_group)
|
||||
|
||||
# # 实测经纬度参考点
|
||||
# self.ref_csv_file = FileSelectWidget(
|
||||
# "实测经纬度CSV:",
|
||||
# "CSV Files (*.csv);;All Files (*.*)"
|
||||
# )
|
||||
# self.ref_csv_file.line_edit.setPlaceholderText("可选:包含 Lon/Lat 列的 CSV 文件")
|
||||
# layout.addWidget(self.ref_csv_file)
|
||||
|
||||
# 交互式预览按钮
|
||||
# self.preview_btn = QPushButton("👁️ 打开交互式影像预览")
|
||||
# self.preview_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('info'))
|
||||
# self.preview_btn.clicked.connect(self.open_interactive_viewer)
|
||||
# layout.addWidget(self.preview_btn)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("deglint_image.bsq")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
self.img_file.line_edit.textChanged.connect(self._update_band_ranges)
|
||||
|
||||
def open_interactive_viewer(self):
|
||||
"""打开交互式影像预览"""
|
||||
from src.gui.water_quality_gui import InteractiveViewerDialog
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path or not os.path.isfile(img_path):
|
||||
QMessageBox.warning(self, "警告", "请先选择影像文件!")
|
||||
return
|
||||
|
||||
water_mask = self.water_mask_file.get_path()
|
||||
|
||||
dialog = InteractiveViewerDialog(img_path, self)
|
||||
if water_mask and os.path.isfile(water_mask):
|
||||
dialog.load_water_mask(water_mask)
|
||||
dialog.exec_()
|
||||
|
||||
def _update_band_ranges(self, file_path):
|
||||
"""根据选择的影像动态限制波段索引的输入范围"""
|
||||
from osgeo import gdal
|
||||
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
return
|
||||
|
||||
try:
|
||||
dataset = gdal.Open(file_path)
|
||||
if dataset is None:
|
||||
return
|
||||
raster_count = dataset.RasterCount
|
||||
max_band = max(0, raster_count - 1)
|
||||
self.nir_lower.setMaximum(max_band)
|
||||
self.nir_upper.setMaximum(max_band)
|
||||
self.oxy_band.setMaximum(max_band)
|
||||
self.nir_band.setMaximum(max_band)
|
||||
self.hedley_nir_band.setMaximum(max_band)
|
||||
dataset = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""
|
||||
从 Step1Panel 自动填充水域掩膜路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass # 保持现有工作目录
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 从 Step1 界面读取水域掩膜路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
# NDWI模式,读取输出框的路径
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
# 导入现有模式,读取输入框的路径
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "3_deglint")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "deglint_image.bsq").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
def _on_method_changed(self, index):
|
||||
"""方法改变时更新参数显示"""
|
||||
method_id = self.method.currentData()
|
||||
self.goodman_group.setVisible(method_id == 'goodman')
|
||||
self.kutser_group.setVisible(method_id == 'kutser')
|
||||
self.hedley_group.setVisible(method_id == 'hedley')
|
||||
self.sugar_group.setVisible(method_id == 'sugar')
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'img_path': self.img_file.get_path(),
|
||||
'method': self.method.currentData(), # 使用 currentData() 获取英文ID
|
||||
'enabled': self.enable_checkbox.isChecked(),
|
||||
'interpolate_zeros': self.interpolate_zeros.isChecked(),
|
||||
'interpolation_method': self.interp_method.currentData(), # 使用 currentData()
|
||||
}
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask'] = water_mask_path
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
|
||||
method = self.method.currentData() # 使用 currentData()
|
||||
|
||||
if method == 'goodman':
|
||||
config['nir_lower'] = self.nir_lower.value()
|
||||
config['nir_upper'] = self.nir_upper.value()
|
||||
config['goodman_A'] = self.goodman_a.value()
|
||||
config['goodman_B'] = self.goodman_b.value()
|
||||
|
||||
elif method == 'kutser':
|
||||
config['oxy_band'] = self.oxy_band.value()
|
||||
config['lower_oxy'] = self.lower_oxy.value()
|
||||
config['upper_oxy'] = self.upper_oxy.value()
|
||||
config['nir_band'] = self.nir_band.value()
|
||||
|
||||
elif method == 'hedley':
|
||||
config['hedley_nir_band'] = self.hedley_nir_band.value()
|
||||
|
||||
elif method == 'sugar':
|
||||
config['sugar_iter'] = self.sugar_iter.value() if self.sugar_iter.value() > 0 else None
|
||||
config['sugar_sigma'] = self.sugar_sigma.value()
|
||||
config['sugar_estimate_background'] = self.sugar_estimate_background.isChecked()
|
||||
config['sugar_glint_mask_method'] = self.sugar_glint_mask_method.currentData()
|
||||
config['sugar_termination_thresh'] = self.sugar_termination_thresh.value()
|
||||
# 解析bounds字符串
|
||||
try:
|
||||
import ast
|
||||
config['sugar_bounds'] = ast.literal_eval(self.sugar_bounds.text())
|
||||
except:
|
||||
config['sugar_bounds'] = [(1, 2)] # 默认值
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'img_path' in config:
|
||||
self.img_file.set_path(config['img_path'])
|
||||
if 'water_mask' in config:
|
||||
self.water_mask_file.set_path(config['water_mask'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
if 'reference_csv' in config:
|
||||
self.ref_csv_file.set_path(config['reference_csv'])
|
||||
if 'method' in config:
|
||||
idx = self.method.findData(config['method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.method.setCurrentIndex(idx)
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
if 'interpolate_zeros' in config:
|
||||
self.interpolate_zeros.setChecked(config['interpolate_zeros'])
|
||||
if 'interpolation_method' in config:
|
||||
idx = self.interp_method.findData(config['interpolation_method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.interp_method.setCurrentIndex(idx)
|
||||
|
||||
# Goodman参数
|
||||
if 'nir_lower' in config:
|
||||
self.nir_lower.setValue(config['nir_lower'])
|
||||
if 'nir_upper' in config:
|
||||
self.nir_upper.setValue(config['nir_upper'])
|
||||
if 'goodman_A' in config:
|
||||
self.goodman_a.setValue(config['goodman_A'])
|
||||
if 'goodman_B' in config:
|
||||
self.goodman_b.setValue(config['goodman_B'])
|
||||
|
||||
# Kutser参数
|
||||
if 'oxy_band' in config:
|
||||
self.oxy_band.setValue(config['oxy_band'])
|
||||
if 'lower_oxy' in config:
|
||||
self.lower_oxy.setValue(config['lower_oxy'])
|
||||
if 'upper_oxy' in config:
|
||||
self.upper_oxy.setValue(config['upper_oxy'])
|
||||
if 'nir_band' in config:
|
||||
self.nir_band.setValue(config['nir_band'])
|
||||
|
||||
# Hedley参数
|
||||
if 'hedley_nir_band' in config:
|
||||
self.hedley_nir_band.setValue(config['hedley_nir_band'])
|
||||
|
||||
# SUGAR参数
|
||||
if 'sugar_iter' in config:
|
||||
self.sugar_iter.setValue(config['sugar_iter'] if config['sugar_iter'] is not None else 0)
|
||||
if 'sugar_sigma' in config:
|
||||
self.sugar_sigma.setValue(config['sugar_sigma'])
|
||||
if 'sugar_estimate_background' in config:
|
||||
self.sugar_estimate_background.setChecked(config['sugar_estimate_background'])
|
||||
if 'sugar_glint_mask_method' in config:
|
||||
idx = self.sugar_glint_mask_method.findData(config['sugar_glint_mask_method']) # 使用 findData()
|
||||
if idx >= 0:
|
||||
self.sugar_glint_mask_method.setCurrentIndex(idx)
|
||||
if 'sugar_termination_thresh' in config:
|
||||
self.sugar_termination_thresh.setValue(config['sugar_termination_thresh'])
|
||||
if 'sugar_bounds' in config:
|
||||
self.sugar_bounds.setText(str(config['sugar_bounds']))
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤3"""
|
||||
# 验证输入
|
||||
img_path = self.img_file.get_path()
|
||||
if not img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择影像文件!")
|
||||
return
|
||||
if self.enable_checkbox.isChecked():
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if not water_mask_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入错误",
|
||||
"独立运行耀斑去除时,必须选择水域掩膜或边界文件。\n\n"
|
||||
"请提供与当前影像空间一致的水域栅格掩膜(.dat/.tif),或水域矢量边界(.shp)。\n"
|
||||
"若刚跑过完整流程,可使用步骤1生成的水域掩膜文件。",
|
||||
)
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step3': self.get_config()}
|
||||
main_window.run_single_step('step3', config)
|
||||
185
src/gui/panels/step4_panel.py
Normal file
185
src/gui/panels/step4_panel.py
Normal file
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step4 面板 - 数据预处理
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QHBoxLayout, QLabel,
|
||||
QSpinBox, QPushButton, QCheckBox, QTableView,
|
||||
QAbstractItemView, QHeaderView, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step4Panel(QWidget):
|
||||
"""步骤4:数据预处理"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
# CSV文件
|
||||
self.csv_file = FileSelectWidget(
|
||||
"水质参数文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.csv_file)
|
||||
|
||||
hint = QLabel("提示: 处理CSV文件,筛选剔除异常值")
|
||||
hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
preview_group = QGroupBox("CSV数据预览")
|
||||
preview_layout = QVBoxLayout()
|
||||
|
||||
controls_layout = QHBoxLayout()
|
||||
controls_layout.addWidget(QLabel("预览行数:"))
|
||||
self.preview_rows_spin = QSpinBox()
|
||||
self.preview_rows_spin.setRange(1, 200)
|
||||
self.preview_rows_spin.setValue(10)
|
||||
controls_layout.addWidget(self.preview_rows_spin)
|
||||
self.preview_btn = QPushButton("刷新预览")
|
||||
self.preview_btn.clicked.connect(self.load_csv_preview)
|
||||
controls_layout.addWidget(self.preview_btn)
|
||||
controls_layout.addStretch()
|
||||
|
||||
self.preview_table = QTableView()
|
||||
self.preview_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||
self.preview_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||
self.preview_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||
self.preview_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
|
||||
self.preview_table.verticalHeader().setVisible(False)
|
||||
self.preview_table.setMinimumHeight(200)
|
||||
|
||||
self.preview_status_label = QLabel("请选择CSV文件并点击刷新预览")
|
||||
self.preview_status_label.setStyleSheet("color: #666; font-size: 11px;")
|
||||
|
||||
preview_layout.addLayout(controls_layout)
|
||||
preview_layout.addWidget(self.preview_table)
|
||||
preview_layout.addWidget(self.preview_status_label)
|
||||
preview_group.setLayout(preview_layout)
|
||||
layout.addWidget(preview_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出处理后CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("processed_data.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
self.reset_preview()
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'csv_path': self.csv_file.get_path(),
|
||||
}
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.load_csv_preview()
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "4_processed_data")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "processed_data.csv").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤4"""
|
||||
# 验证输入
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择水质参数文件!")
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step4': self.get_config()}
|
||||
main_window.run_single_step('step4', config)
|
||||
|
||||
def reset_preview(self, message="请选择CSV文件并点击刷新预览"):
|
||||
"""重置预览表格"""
|
||||
from src.gui.water_quality_gui import PandasTableModel
|
||||
empty_model = PandasTableModel(pd.DataFrame())
|
||||
self.preview_table.setModel(empty_model)
|
||||
self.preview_status_label.setText(message)
|
||||
|
||||
def load_csv_preview(self):
|
||||
"""加载CSV预览数据"""
|
||||
from src.gui.water_quality_gui import PandasTableModel
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path:
|
||||
self.reset_preview("请先选择CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
self.reset_preview("文件不存在,请检查路径")
|
||||
return
|
||||
|
||||
try:
|
||||
rows_to_preview = max(1, self.preview_rows_spin.value())
|
||||
# dtype=object 确保所有列以字符串读取,避免空值/混合类型导致 dtype 报错
|
||||
df = pd.read_csv(csv_path, nrows=rows_to_preview, dtype=object)
|
||||
# fillna 在 PandasTableModel.__init__ 中已执行,此处再次防御性处理
|
||||
df = df.fillna('')
|
||||
if df.empty:
|
||||
self.reset_preview("CSV文件为空")
|
||||
return
|
||||
|
||||
model = PandasTableModel(df)
|
||||
self.preview_table.setModel(model)
|
||||
self.preview_status_label.setText(
|
||||
f"预览 {len(df)} 行,{len(df.columns)} 列(总行数可能更多)"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.reset_preview(f"加载失败: {exc}")
|
||||
225
src/gui/panels/step5_5_panel.py
Normal file
225
src/gui/panels/step5_5_panel.py
Normal file
@ -0,0 +1,225 @@
|
||||
import os
|
||||
import sys
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QGridLayout,
|
||||
QHBoxLayout, QLabel, QCheckBox, QPushButton, QMessageBox, QScrollArea
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
def get_resource_path(relative_path: str) -> str:
|
||||
"""适配开发与 PyInstaller 环境的路径获取逻辑。
|
||||
支持两种打包模式:
|
||||
1. --onedir 模式:文件在 exe_root/_internal/ 下 → 检查 _internal 目录
|
||||
2. --onefile 模式:文件在 sys._MEIPASS 平铺目录
|
||||
"""
|
||||
# 优先检查 PyInstaller onefile 模式(文件平铺在 _MEIPASS 下)
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
internal_path = os.path.join(sys._MEIPASS, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
return os.path.join(sys._MEIPASS, relative_path)
|
||||
|
||||
# 兼容 PyInstaller onedir 模式的 _internal 目录(exe 同级目录下)
|
||||
exe_dir = os.path.dirname(sys.executable)
|
||||
internal_path = os.path.join(exe_dir, '_internal', relative_path)
|
||||
if os.path.exists(internal_path):
|
||||
return internal_path
|
||||
|
||||
# 开发环境下:基于当前文件 (step5_5_panel.py) 的绝对路径进行回溯
|
||||
# 当前在 src/gui/panels/,目标在 src/gui/model/
|
||||
base_dir = Path(__file__).resolve().parent.parent / "model"
|
||||
target_path = base_dir / os.path.basename(relative_path)
|
||||
return str(target_path)
|
||||
|
||||
class Step5_5Panel(QWidget):
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.index_checkboxes: Dict[str, QCheckBox] = {}
|
||||
# 标识为 waterindex.csv,目录跳转逻辑在 get_resource_path 中
|
||||
self.builtin_formula_path = get_resource_path("waterindex.csv")
|
||||
|
||||
self.init_ui()
|
||||
# 延迟一小会儿加载,确保UI框架已就绪
|
||||
self._auto_load_formulas()
|
||||
|
||||
def init_ui(self):
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.setContentsMargins(20, 20, 20, 20)
|
||||
main_layout.setSpacing(10)
|
||||
|
||||
# 1. 路径展示区 (半透明只读)
|
||||
path_group = QGroupBox("公式配置源 (内置)")
|
||||
path_layout = QVBoxLayout()
|
||||
self.formula_csv_widget = FileSelectWidget("内置CSV路径:", "CSV Files (*.csv)")
|
||||
self.formula_csv_widget.set_path(self.builtin_formula_path)
|
||||
self.formula_csv_widget.set_read_only(True)
|
||||
# 视觉微调:提示用户这是内置的
|
||||
self.formula_csv_widget.line_edit.setStyleSheet("background-color: #f0f0f0; color: #666;")
|
||||
path_layout.addWidget(self.formula_csv_widget)
|
||||
path_group.setLayout(path_layout)
|
||||
main_layout.addWidget(path_group)
|
||||
|
||||
# 2. 训练数据输入
|
||||
input_group = QGroupBox("输入样本数据")
|
||||
input_layout = QVBoxLayout()
|
||||
self.training_data_widget = FileSelectWidget("特征提取CSV:", "CSV Files (*.csv)")
|
||||
input_layout.addWidget(self.training_data_widget)
|
||||
input_group.setLayout(input_layout)
|
||||
main_layout.addWidget(input_group)
|
||||
|
||||
# 3. 公式选择区
|
||||
self.formula_group = QGroupBox("待计算水质指数勾选")
|
||||
formula_outer_layout = QVBoxLayout()
|
||||
|
||||
btn_layout = QHBoxLayout()
|
||||
self.select_all_btn = QPushButton("全选")
|
||||
self.deselect_all_btn = QPushButton("清空")
|
||||
self.select_all_btn.clicked.connect(self.select_all_formulas)
|
||||
self.deselect_all_btn.clicked.connect(self.deselect_all_formulas)
|
||||
btn_layout.addWidget(self.select_all_btn)
|
||||
btn_layout.addWidget(self.deselect_all_btn)
|
||||
btn_layout.addStretch()
|
||||
|
||||
self.refresh_button = QPushButton("手动重新加载公式")
|
||||
self.refresh_button.clicked.connect(lambda: self.refresh_formulas(silent=False))
|
||||
btn_layout.addWidget(self.refresh_button)
|
||||
|
||||
formula_outer_layout.addLayout(btn_layout)
|
||||
|
||||
# 核心滚动区
|
||||
scroll = QScrollArea()
|
||||
scroll.setWidgetResizable(True)
|
||||
scroll.setMinimumHeight(300) # 强制最小高度,防止塌陷
|
||||
self.scroll_content = QWidget()
|
||||
self.formula_layout = QGridLayout(self.scroll_content)
|
||||
self.formula_layout.setAlignment(Qt.AlignTop) # 靠顶对齐
|
||||
scroll.setWidget(self.scroll_content)
|
||||
formula_outer_layout.addWidget(scroll)
|
||||
|
||||
self.formula_group.setLayout(formula_outer_layout)
|
||||
main_layout.addWidget(self.formula_group)
|
||||
|
||||
# 4. 输出与运行
|
||||
output_group = QGroupBox("结果输出")
|
||||
output_layout = QVBoxLayout()
|
||||
self.output_file_widget = FileSelectWidget("保存路径:", "CSV Files (*.csv)", mode="save")
|
||||
output_layout.addWidget(self.output_file_widget)
|
||||
output_group.setLayout(output_layout)
|
||||
main_layout.addWidget(output_group)
|
||||
|
||||
self.enable_checkbox = QCheckBox("启用计算流程")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
main_layout.addWidget(self.enable_checkbox)
|
||||
|
||||
self.run_button = QPushButton("立即执行计算")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.setMinimumHeight(40)
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
main_layout.addWidget(self.run_button)
|
||||
|
||||
self.setLayout(main_layout)
|
||||
|
||||
def _auto_load_formulas(self):
|
||||
"""启动时自动加载逻辑"""
|
||||
if os.path.exists(self.builtin_formula_path):
|
||||
self.refresh_formulas(silent=True)
|
||||
else:
|
||||
print(f"DEBUG: 自动加载失败,路径不存在: {self.builtin_formula_path}")
|
||||
|
||||
def refresh_formulas(self, silent=False):
|
||||
path = self.builtin_formula_path
|
||||
if not os.path.exists(path):
|
||||
if not silent: QMessageBox.warning(self, "错误", f"找不到内置公式文件:\n{path}")
|
||||
return
|
||||
|
||||
try:
|
||||
# 清理旧列表
|
||||
for i in reversed(range(self.formula_layout.count())):
|
||||
widget = self.formula_layout.itemAt(i).widget()
|
||||
if widget: widget.deleteLater()
|
||||
self.index_checkboxes.clear()
|
||||
|
||||
# 鲁棒性读取:尝试不同编码
|
||||
for encoding in ['utf-8', 'gbk', 'utf-8-sig']:
|
||||
try:
|
||||
df = pd.read_csv(path, encoding=encoding)
|
||||
if 'Formula_Name' in df.columns: break
|
||||
except: continue
|
||||
|
||||
if 'Formula_Name' not in df.columns:
|
||||
if not silent: QMessageBox.critical(self, "错误", "CSV文件缺少 'Formula_Name' 列")
|
||||
return
|
||||
|
||||
names = df['Formula_Name'].dropna().unique().tolist()
|
||||
|
||||
row, col = 0, 0
|
||||
for name in names:
|
||||
name = str(name).strip()
|
||||
if not name: continue
|
||||
cb = QCheckBox(name)
|
||||
cb.setChecked(True)
|
||||
self.index_checkboxes[name] = cb
|
||||
self.formula_layout.addWidget(cb, row, col)
|
||||
col += 1
|
||||
if col >= 3:
|
||||
col = 0
|
||||
row += 1
|
||||
|
||||
# 强制UI更新
|
||||
self.scroll_content.adjustSize()
|
||||
print(f"✅ 成功加载 {len(self.index_checkboxes)} 个公式")
|
||||
|
||||
except Exception as e:
|
||||
if not silent: QMessageBox.critical(self, "加载失败", f"原因: {str(e)}")
|
||||
|
||||
def select_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(True)
|
||||
|
||||
def deselect_all_formulas(self):
|
||||
for cb in self.index_checkboxes.values(): cb.setChecked(False)
|
||||
|
||||
def get_config(self):
|
||||
selected = [n for n, cb in self.index_checkboxes.items() if cb.isChecked()]
|
||||
return {
|
||||
'training_spectra_path': self.training_data_widget.get_path(),
|
||||
'formula_csv_file': self.builtin_formula_path,
|
||||
'formula_names': selected,
|
||||
'output_file': self.output_file_widget.get_path(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'training_spectra_path' in config: self.training_data_widget.set_path(config['training_spectra_path'])
|
||||
if 'formula_names' in config:
|
||||
sel = set(config['formula_names'])
|
||||
for n, cb in self.index_checkboxes.items(): cb.setChecked(n in sel)
|
||||
if 'output_file' in config: self.output_file_widget.set_path(config['output_file'])
|
||||
self.enable_checkbox.setChecked(config.get('enabled', True))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
if work_dir: self.work_dir = work_dir
|
||||
main = self.window()
|
||||
if hasattr(main, 'step5_panel'):
|
||||
p5 = main.step5_panel.output_file.get_path() # 修正:变量名对齐
|
||||
if p5:
|
||||
if not os.path.isabs(p5): p5 = os.path.join(self.work_dir or '', p5).replace('\\', '/')
|
||||
self.training_data_widget.set_path(p5)
|
||||
|
||||
if self.work_dir:
|
||||
out = os.path.join(self.work_dir, "6_water_quality_indices", "training_spectra_indices.csv").replace('\\', '/')
|
||||
self.output_file_widget.set_path(out)
|
||||
|
||||
def run_step(self):
|
||||
config = self.get_config()
|
||||
if not config['training_spectra_path']:
|
||||
QMessageBox.warning(self, "提示", "请先选择输入数据")
|
||||
return
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'): parent = parent.parent()
|
||||
if parent: parent.run_single_step('step5_5', {'step5_5': config})
|
||||
239
src/gui/panels/step5_panel.py
Normal file
239
src/gui/panels/step5_panel.py
Normal file
@ -0,0 +1,239 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step5 面板 - 光谱提取
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QLabel,
|
||||
QSpinBox, QPushButton, QCheckBox, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtGui import QFont
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step5Panel(QWidget):
|
||||
"""步骤5:光谱提取"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
title = QLabel("步骤5:训练样本光谱提取")
|
||||
title.setFont(QFont("Arial", 12, QFont.Bold))
|
||||
layout.addWidget(title)
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
|
||||
# 处理后的CSV文件(用于独立运行)
|
||||
self.csv_file = FileSelectWidget(
|
||||
"处理后CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.csv_file)
|
||||
|
||||
# 水体掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水体掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.line_edit.setPlaceholderText("可选,如不选择则自动生成")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
self.glint_mask_file = FileSelectWidget(
|
||||
"耀斑掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.glint_mask_file)
|
||||
step5_glint_hint = QLabel(
|
||||
"提示:独立运行本步骤时必须选择耀斑掩膜(通常为步骤2输出的 severe_glint_area.dat),用于在采样时避开耀斑像元。"
|
||||
)
|
||||
step5_glint_hint.setWordWrap(True)
|
||||
step5_glint_hint.setStyleSheet("color: #666; font-size: 10px;")
|
||||
layout.addWidget(step5_glint_hint)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("提取参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.radius = QSpinBox()
|
||||
self.radius.setRange(1, 50)
|
||||
self.radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.radius)
|
||||
|
||||
self.source_epsg = QSpinBox()
|
||||
self.source_epsg.setRange(1000, 99999)
|
||||
self.source_epsg.setValue(4326)
|
||||
params_layout.addRow("源坐标系EPSG:", self.source_epsg)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("training_spectra.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
# 信号连接:影像文件路径变化时动态更新波段范围
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'radius': self.radius.value(),
|
||||
'source_epsg': self.source_epsg.value(),
|
||||
}
|
||||
# 添加独立运行所需的文件路径
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
csv_path = self.csv_file.get_path()
|
||||
if csv_path:
|
||||
config['csv_path'] = csv_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['boundary_path'] = water_mask_path
|
||||
glint_mask_path = self.glint_mask_file.get_path()
|
||||
if glint_mask_path:
|
||||
config['glint_mask_path'] = glint_mask_path
|
||||
# 注意:step5_extract_training_spectra 不接受 output_path / training_spectra_path
|
||||
# 参数,输出路径由 pipeline 内部根据 training_spectra_dir 自动生成。
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'radius' in config:
|
||||
self.radius.setValue(config['radius'])
|
||||
if 'source_epsg' in config:
|
||||
self.source_epsg.setValue(config['source_epsg'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
if 'boundary_path' in config:
|
||||
self.water_mask_file.set_path(config['boundary_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置/Pipeline 或 Step1Panel 自动填充路径,实现上下游数据流转
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例,用于获取步骤1生成的水域掩膜路径
|
||||
"""
|
||||
# 保存工作目录引用
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Pipeline 获取水体掩膜路径
|
||||
mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'water_mask_path') and pipeline.water_mask_path:
|
||||
mask_path = pipeline.water_mask_path
|
||||
|
||||
# 2. 如果 Pipeline 中没有,则尝试直接从 Step1 界面读取
|
||||
main_window = self.window()
|
||||
if not mask_path and hasattr(main_window, 'step1_panel'):
|
||||
if main_window.step1_panel.use_ndwi_radio.isChecked():
|
||||
mask_path = main_window.step1_panel.output_file.get_path()
|
||||
else:
|
||||
mask_path = main_window.step1_panel.mask_file.get_path()
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if mask_path and not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
|
||||
# 填充水体掩膜路径
|
||||
if mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(mask_path):
|
||||
mask_path = os.path.join(self.work_dir or '', mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(mask_path)
|
||||
|
||||
# 3. 尝试从 Step2 界面读取耀斑掩膜路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step2_panel'):
|
||||
glint_path = main_window.step2_panel.output_file.get_path()
|
||||
if glint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(glint_path):
|
||||
glint_path = os.path.join(self.work_dir or '', glint_path).replace('\\', '/')
|
||||
self.glint_mask_file.set_path(glint_path)
|
||||
|
||||
# 4. 自动填充输出路径(基于工作目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "5_training_spectra")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "training_spectra.csv").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
# 5. 尝试从 Step4 界面读取已处理的水质参数 CSV 路径,自动填入本面板
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step4_panel'):
|
||||
step4_output_path = main_window.step4_panel.output_file.get_path()
|
||||
if step4_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step4_output_path):
|
||||
step4_output_path = os.path.join(self.work_dir or '', step4_output_path).replace('\\', '/')
|
||||
existing_csv = self.csv_file.get_path()
|
||||
if not existing_csv or not existing_csv.strip():
|
||||
self.csv_file.set_path(step4_output_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤5"""
|
||||
# 验证输入
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
return
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择处理后的CSV文件!")
|
||||
return
|
||||
if not self.glint_mask_file.get_path():
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入错误",
|
||||
"独立运行光谱特征提取时,必须选择耀斑掩膜文件。\n\n"
|
||||
"请提供与去耀斑影像对应的耀斑二值掩膜(一般为步骤2输出的 severe_glint_area.dat)。",
|
||||
)
|
||||
return
|
||||
|
||||
# 获取主窗口并运行步骤
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step5': self.get_config()}
|
||||
main_window.run_single_step('step5', config)
|
||||
307
src/gui/panels/step6_5_panel.py
Normal file
307
src/gui/panels/step6_5_panel.py
Normal file
@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_5 面板 - 非经验统计回归建模
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QCheckBox, QSpinBox, QPushButton,
|
||||
QFileDialog, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_5Panel(QWidget):
|
||||
"""步骤6.5:非经验统计回归建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("模型参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 预处理方法
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
preproc_grid = QGridLayout()
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
params_layout.addRow(preproc_group)
|
||||
|
||||
# 算法选择(可多选)
|
||||
self.algorithm_inputs = {}
|
||||
algorithms_widget = QWidget()
|
||||
algorithms_layout = QVBoxLayout()
|
||||
algorithms_layout.setContentsMargins(0, 0, 0, 0)
|
||||
algorithms_layout.setSpacing(4)
|
||||
|
||||
algorithm_list = ['chl_a', 'nh3', 'mno4', 'tn', 'tp', 'tss']
|
||||
for algorithm in algorithm_list:
|
||||
row_widget = QWidget()
|
||||
row_layout = QHBoxLayout()
|
||||
row_layout.setContentsMargins(0, 0, 0, 0)
|
||||
checkbox = QCheckBox(algorithm)
|
||||
checkbox.setChecked(True)
|
||||
spinbox = QSpinBox()
|
||||
spinbox.setRange(0, 500)
|
||||
spinbox.setValue(0)
|
||||
spinbox.setMaximumWidth(90)
|
||||
row_layout.addWidget(checkbox)
|
||||
row_layout.addWidget(QLabel("对应值列索引:"))
|
||||
row_layout.addWidget(spinbox)
|
||||
row_layout.addStretch()
|
||||
row_widget.setLayout(row_layout)
|
||||
algorithms_layout.addWidget(row_widget)
|
||||
self.algorithm_inputs[algorithm] = (checkbox, spinbox)
|
||||
|
||||
algorithms_widget.setLayout(algorithms_layout)
|
||||
params_layout.addRow("非经验算法选择:", algorithms_widget)
|
||||
|
||||
# 光谱起始列
|
||||
self.spectral_start_col = QSpinBox()
|
||||
self.spectral_start_col.setRange(0, 100)
|
||||
self.spectral_start_col.setValue(1)
|
||||
params_layout.addRow("光谱起始列索引:", self.spectral_start_col)
|
||||
|
||||
# 窗口大小 (变量名已修正,避免覆盖 QWidget.window)
|
||||
self.window_size_spinbox = QSpinBox()
|
||||
self.window_size_spinbox.setRange(1, 20)
|
||||
self.window_size_spinbox.setValue(5)
|
||||
params_layout.addRow("窗口大小:", self.window_size_spinbox)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
selected_algorithms = [
|
||||
name for name, (checkbox, _) in self.algorithm_inputs.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_algorithms:
|
||||
selected_algorithms = list(self.algorithm_inputs.keys())
|
||||
|
||||
value_cols = {
|
||||
name: spinbox.value()
|
||||
for name, (_, spinbox) in self.algorithm_inputs.items()
|
||||
if name in selected_algorithms
|
||||
}
|
||||
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
] or ['None']
|
||||
|
||||
config = {
|
||||
'preprocessing_methods': preprocessing_methods,
|
||||
'algorithms': selected_algorithms,
|
||||
'value_cols': value_cols,
|
||||
'spectral_start_col': self.spectral_start_col.value(),
|
||||
'window': self.window_size_spinbox.value(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
output_dir = self.output_dir.get_path()
|
||||
if not output_dir:
|
||||
main_window = self.parent().window()
|
||||
if hasattr(main_window, 'work_dir') and main_window.work_dir:
|
||||
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
|
||||
else:
|
||||
output_dir = str(Path.cwd() / "8_Regression_Modeling")
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['csv_path'] = training_csv_path
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
|
||||
if 'algorithms' in config:
|
||||
algorithm_values = config['algorithms']
|
||||
for algorithm, (checkbox, spinbox) in self.algorithm_inputs.items():
|
||||
checkbox.setChecked(algorithm in algorithm_values)
|
||||
|
||||
if 'value_cols' in config:
|
||||
value_cols = config['value_cols']
|
||||
if isinstance(value_cols, dict):
|
||||
for algorithm, (_, spinbox) in self.algorithm_inputs.items():
|
||||
if algorithm in value_cols:
|
||||
spinbox.setValue(value_cols[algorithm])
|
||||
else:
|
||||
for _, spinbox in self.algorithm_inputs.values():
|
||||
spinbox.setValue(value_cols)
|
||||
|
||||
if 'spectral_start_col' in config:
|
||||
self.spectral_start_col.setValue(config['spectral_start_col'])
|
||||
|
||||
if 'window' in config:
|
||||
self.window_size_spinbox.setValue(config['window'])
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.set_path(config['output_dir'])
|
||||
if 'csv_path' in config:
|
||||
self.training_csv_file.set_path(config['csv_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 借用父组件的 window() 方法,安全绕过当前类的命名冲突
|
||||
parent_widget = self.parentWidget()
|
||||
main_window = parent_widget.window() if parent_widget else None
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.training_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.training_csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 自动填充输出目录(8_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "8_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
# 借用父组件的 window() 方法,安全绕过当前类的命名冲突
|
||||
parent_widget = self.parentWidget()
|
||||
mw = parent_widget.window() if parent_widget else None
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "8_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出模型目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.5"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
if not os.path.exists(training_csv_path):
|
||||
QMessageBox.warning(self, "输入错误", "训练数据CSV文件不存在!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_5', {'step6_5': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置预处理checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
374
src/gui/panels/step6_75_panel.py
Normal file
374
src/gui/panels/step6_75_panel.py
Normal file
@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6_75 面板 - 自定义回归分析
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pandas as pd
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QCheckBox, QPushButton,
|
||||
QScrollArea, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step6_75Panel(QWidget):
|
||||
"""步骤6.75:自定义回归分析"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.x_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.y_column_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.method_checkboxes: Dict[str, QCheckBox] = {}
|
||||
self.csv_columns = []
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel("指定自变量与因变量列,批量尝试不同回归方法")
|
||||
hint.setStyleSheet("color: #666; font-size: 11px;")
|
||||
layout.addWidget(hint)
|
||||
|
||||
# CSV文件选择
|
||||
csv_group = QGroupBox("数据文件")
|
||||
csv_layout = QVBoxLayout()
|
||||
|
||||
self.csv_file = FileSelectWidget(
|
||||
"输入CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.csv_file.line_edit.textChanged.connect(self.on_csv_file_changed)
|
||||
csv_layout.addWidget(self.csv_file)
|
||||
|
||||
self.refresh_btn = QPushButton("刷新列信息")
|
||||
self.refresh_btn.clicked.connect(self.refresh_csv_columns)
|
||||
csv_layout.addWidget(self.refresh_btn)
|
||||
|
||||
csv_group.setLayout(csv_layout)
|
||||
layout.addWidget(csv_group)
|
||||
|
||||
# 自变量选择
|
||||
x_group = QGroupBox("自变量列选择 (可多选)")
|
||||
x_layout = QVBoxLayout()
|
||||
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
x_widget.setLayout(self.x_columns_layout)
|
||||
|
||||
x_scroll.setWidget(x_widget)
|
||||
x_layout.addWidget(x_scroll)
|
||||
|
||||
x_btn_layout = QHBoxLayout()
|
||||
self.x_select_all = QPushButton("全选")
|
||||
self.x_deselect_all = QPushButton("全不选")
|
||||
self.x_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, True))
|
||||
self.x_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.x_column_checkboxes, False))
|
||||
x_btn_layout.addWidget(self.x_select_all)
|
||||
x_btn_layout.addWidget(self.x_deselect_all)
|
||||
x_btn_layout.addStretch()
|
||||
x_layout.addLayout(x_btn_layout)
|
||||
|
||||
x_group.setLayout(x_layout)
|
||||
layout.addWidget(x_group)
|
||||
|
||||
# 因变量选择
|
||||
y_group = QGroupBox("因变量列选择 (可多选)")
|
||||
y_layout = QVBoxLayout()
|
||||
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
y_widget.setLayout(self.y_columns_layout)
|
||||
|
||||
y_scroll.setWidget(y_widget)
|
||||
y_layout.addWidget(y_scroll)
|
||||
|
||||
y_btn_layout = QHBoxLayout()
|
||||
self.y_select_all = QPushButton("全选")
|
||||
self.y_deselect_all = QPushButton("全不选")
|
||||
self.y_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, True))
|
||||
self.y_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.y_column_checkboxes, False))
|
||||
y_btn_layout.addWidget(self.y_select_all)
|
||||
y_btn_layout.addWidget(self.y_deselect_all)
|
||||
y_btn_layout.addStretch()
|
||||
y_layout.addLayout(y_btn_layout)
|
||||
|
||||
y_group.setLayout(y_layout)
|
||||
layout.addWidget(y_group)
|
||||
|
||||
# 回归方法选择
|
||||
method_group = QGroupBox("回归方法选择 (可多选)")
|
||||
method_layout = QVBoxLayout()
|
||||
|
||||
method_grid = QGridLayout()
|
||||
regression_methods = [
|
||||
'linear', 'exponential', 'power', 'logarithmic',
|
||||
'polynomial', 'hyperbolic', 'sigmoidal'
|
||||
]
|
||||
|
||||
for i, method in enumerate(regression_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
if method in ['linear', 'exponential', 'power', 'logarithmic']:
|
||||
checkbox.setChecked(True)
|
||||
self.method_checkboxes[method] = checkbox
|
||||
method_grid.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
method_layout.addLayout(method_grid)
|
||||
|
||||
method_btn_layout = QHBoxLayout()
|
||||
self.method_select_all = QPushButton("全选")
|
||||
self.method_deselect_all = QPushButton("全不选")
|
||||
self.method_select_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, True))
|
||||
self.method_deselect_all.clicked.connect(lambda: self.toggle_checkboxes(self.method_checkboxes, False))
|
||||
method_btn_layout.addWidget(self.method_select_all)
|
||||
method_btn_layout.addWidget(self.method_deselect_all)
|
||||
method_btn_layout.addStretch()
|
||||
method_layout.addLayout(method_btn_layout)
|
||||
|
||||
method_group.setLayout(method_layout)
|
||||
layout.addWidget(method_group)
|
||||
|
||||
# 输出目录
|
||||
output_group = QGroupBox("输出设置")
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
layout.addWidget(output_group)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def on_csv_file_changed(self):
|
||||
"""CSV文件改变时自动刷新列信息"""
|
||||
self.refresh_csv_columns()
|
||||
|
||||
def refresh_csv_columns(self):
|
||||
"""刷新CSV文件的列信息"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
if not csv_path or not os.path.exists(csv_path):
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
return
|
||||
|
||||
try:
|
||||
df = pd.read_csv(csv_path, nrows=0)
|
||||
self.csv_columns = list(df.columns)
|
||||
self.update_column_widgets()
|
||||
except Exception as e:
|
||||
self.csv_columns = []
|
||||
self.update_column_widgets()
|
||||
print(f"读取CSV列信息失败: {e}")
|
||||
|
||||
def update_column_widgets(self):
|
||||
"""更新列选择组件"""
|
||||
for checkbox in self.x_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.x_column_checkboxes.clear()
|
||||
|
||||
for checkbox in self.y_column_checkboxes.values():
|
||||
checkbox.setParent(None)
|
||||
self.y_column_checkboxes.clear()
|
||||
|
||||
if not self.csv_columns:
|
||||
return
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['index', 'ratio', 'normalized', 'nd', 'b']):
|
||||
checkbox.setChecked(True)
|
||||
self.x_column_checkboxes[col] = checkbox
|
||||
self.x_columns_layout.addWidget(checkbox, i // 3, i % 3)
|
||||
|
||||
for i, col in enumerate(self.csv_columns):
|
||||
checkbox = QCheckBox(col)
|
||||
if any(keyword in col.lower() for keyword in ['chl', 'tn', 'tp', 'turbidity', 'do', 'ph', 'conductivity']):
|
||||
checkbox.setChecked(True)
|
||||
self.y_column_checkboxes[col] = checkbox
|
||||
self.y_columns_layout.addWidget(checkbox, i // 2, i % 2)
|
||||
|
||||
self.x_columns_layout.update()
|
||||
self.y_columns_layout.update()
|
||||
|
||||
def get_config(self):
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
selected_methods = 'all'
|
||||
|
||||
return {
|
||||
'csv_path': self.csv_file.get_path() or None,
|
||||
'x_columns': selected_x_columns,
|
||||
'y_columns': selected_y_columns,
|
||||
'methods': selected_methods,
|
||||
'output_dir': self.output_dir.text().strip() or None,
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
def set_config(self, config):
|
||||
if 'csv_path' in config:
|
||||
self.csv_file.set_path(config['csv_path'])
|
||||
self.refresh_csv_columns()
|
||||
|
||||
if 'x_columns' in config:
|
||||
selected_x = set(config['x_columns']) if isinstance(config['x_columns'], list) else set()
|
||||
for col, checkbox in self.x_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_x)
|
||||
|
||||
if 'y_columns' in config:
|
||||
selected_y = set(config['y_columns']) if isinstance(config['y_columns'], list) else set()
|
||||
for col, checkbox in self.y_column_checkboxes.items():
|
||||
checkbox.setChecked(col in selected_y)
|
||||
|
||||
if 'methods' in config:
|
||||
methods = config['methods']
|
||||
if isinstance(methods, list):
|
||||
selected_methods = set(methods)
|
||||
elif methods == 'all':
|
||||
selected_methods = set(self.method_checkboxes.keys())
|
||||
else:
|
||||
selected_methods = set()
|
||||
for method, checkbox in self.method_checkboxes.items():
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练光谱 CSV 路径
|
||||
main_window = self.window()
|
||||
if main_window and hasattr(main_window, 'step5_panel'):
|
||||
step5_widget = getattr(main_window.step5_panel, 'output_file', None)
|
||||
step5_output_path = ""
|
||||
if hasattr(step5_widget, 'get_path'):
|
||||
step5_output_path = step5_widget.get_path() or ""
|
||||
elif hasattr(step5_widget, 'text'):
|
||||
step5_output_path = step5_widget.text() or ""
|
||||
|
||||
if step5_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output_path):
|
||||
step5_output_path = os.path.join(self.work_dir or '', step5_output_path).replace('\\', '/')
|
||||
existing = self.csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.csv_file.set_path(step5_output_path)
|
||||
|
||||
# 2. 自动填充输出目录(9_Custom_Regression_Modeling)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.text().strip()
|
||||
if not existing_out:
|
||||
self.output_dir.setText(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6.75"""
|
||||
csv_path = self.csv_file.get_path()
|
||||
|
||||
if not csv_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择输入CSV文件")
|
||||
return
|
||||
if not os.path.exists(csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "输入CSV文件不存在")
|
||||
return
|
||||
|
||||
selected_x_columns = [
|
||||
col for col, checkbox in self.x_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_x_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个自变量列")
|
||||
return
|
||||
|
||||
selected_y_columns = [
|
||||
col for col, checkbox in self.y_column_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_y_columns:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一个因变量列")
|
||||
return
|
||||
|
||||
selected_methods = [
|
||||
method for method, checkbox in self.method_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
if not selected_methods:
|
||||
QMessageBox.warning(self, "输入验证失败", "请至少选择一种回归方法")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step6_75', {'step6_75': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
415
src/gui/panels/step6_panel.py
Normal file
415
src/gui/panels/step6_panel.py
Normal file
@ -0,0 +1,415 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6 面板 - 机器学习建模
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QGridLayout,
|
||||
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||||
QPushButton, QFileDialog, QMessageBox,
|
||||
)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 中文映射表(内部键名 -> 显示文本)
|
||||
# ============================================================
|
||||
|
||||
# 预处理方法:内部键 -> 显示文本
|
||||
PREPROC_CHINESE = {
|
||||
'None': '无 (None)',
|
||||
'MMS': '最小-最大归一化 (MMS)',
|
||||
'SS': '标度化 (SS)',
|
||||
'SNV': '标准正态变换 (SNV)',
|
||||
'MA': '移动平均 (MA)',
|
||||
'SG': 'Savitzky-Golay (SG)',
|
||||
'MSC': '多元散射校正 (MSC)',
|
||||
'D1': '一阶导数 (D1)',
|
||||
'D2': '二阶导数 (D2)',
|
||||
'DT': '去趋势 (DT)',
|
||||
'CT': '中心化 (CT)',
|
||||
}
|
||||
|
||||
# 模型类型:内部键 -> 显示文本
|
||||
MODEL_CHINESE = {
|
||||
# 线性模型
|
||||
'LinearRegression': '多元线性回归 (MLR)',
|
||||
'Ridge': '岭回归 (Ridge)',
|
||||
'Lasso': '套索回归 (Lasso)',
|
||||
'ElasticNet': '弹性网络 (ElasticNet)',
|
||||
'PLS': '偏最小二乘 (PLSR)',
|
||||
# 树模型
|
||||
'DecisionTree': '决策树 (CART)',
|
||||
'RF': '随机森林 (RF)',
|
||||
'ExtraTrees': '极端随机树 (ET)',
|
||||
'XGBoost': '极值梯度提升 (XGBoost)',
|
||||
'LightGBM': '轻量梯度提升 (LightGBM)',
|
||||
'CatBoost': '类别梯度提升 (CatBoost)',
|
||||
# 集成学习
|
||||
'GradientBoosting': '梯度提升树 (GBDT)',
|
||||
'AdaBoost': '自适应提升 (AdaBoost)',
|
||||
# 其他模型
|
||||
'SVR': '支持向量回归 (SVR)',
|
||||
'KNN': 'K近邻回归 (KNN)',
|
||||
'MLP': '多层感知机 (BP神经网络)',
|
||||
}
|
||||
|
||||
# 数据划分方法:内部键 -> 显示文本
|
||||
SPLIT_CHINESE = {
|
||||
'spxy': 'SPXY 算法 (考量X-Y空间)',
|
||||
'ks': 'KS 算法 (考量X空间)',
|
||||
'random': '随机划分 (Random)',
|
||||
}
|
||||
|
||||
|
||||
class Step6Panel(QWidget):
|
||||
"""步骤6:机器学习建模"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 标题
|
||||
|
||||
|
||||
# 训练数据文件(用于独立运行)
|
||||
self.training_csv_file = FileSelectWidget(
|
||||
"训练数据:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.training_csv_file)
|
||||
|
||||
# 机器学习模型页面
|
||||
self.ml_page = QWidget()
|
||||
self.create_ml_page()
|
||||
layout.addWidget(self.ml_page)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_path = FileSelectWidget(
|
||||
"输出文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)",
|
||||
mode="save"
|
||||
)
|
||||
self.output_path.line_edit.setPlaceholderText("自动生成,或手动指定输出文件路径...")
|
||||
self.output_path.browse_btn.clicked.disconnect()
|
||||
self.output_path.browse_btn.clicked.connect(self.browse_output_path)
|
||||
layout.addWidget(self.output_path)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(False)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def create_ml_page(self):
|
||||
"""创建机器学习模型页面"""
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("训练参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.feature_start = QLineEdit()
|
||||
self.feature_start.setText("374.285004")
|
||||
params_layout.addRow("特征起始列:", self.feature_start)
|
||||
|
||||
self.cv_folds = QSpinBox()
|
||||
self.cv_folds.setRange(2, 10)
|
||||
self.cv_folds.setValue(3)
|
||||
params_layout.addRow("交叉验证折数:", self.cv_folds)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 预处理方法 - 多选
|
||||
preproc_group = QGroupBox("预处理方法 (可多选)")
|
||||
preproc_layout = QVBoxLayout()
|
||||
|
||||
preproc_grid = QGridLayout()
|
||||
self.preproc_checkboxes = {}
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
button_layout = QHBoxLayout()
|
||||
select_all_btn = QPushButton("全选")
|
||||
deselect_all_btn = QPushButton("全不选")
|
||||
select_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, True))
|
||||
deselect_all_btn.clicked.connect(lambda: self._toggle_checkboxes(self.preproc_checkboxes, False))
|
||||
button_layout.addWidget(select_all_btn)
|
||||
button_layout.addWidget(deselect_all_btn)
|
||||
button_layout.addStretch()
|
||||
|
||||
preproc_layout.addLayout(preproc_grid)
|
||||
preproc_layout.addLayout(button_layout)
|
||||
preproc_group.setLayout(preproc_layout)
|
||||
layout.addWidget(preproc_group)
|
||||
|
||||
# 模型选择 - 多选
|
||||
model_group = QGroupBox("模型类型 (可多选)")
|
||||
model_layout = QVBoxLayout()
|
||||
|
||||
model_grid = QGridLayout()
|
||||
self.model_checkboxes = {}
|
||||
|
||||
model_groups = [
|
||||
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
|
||||
("【其他模型】", ['SVR', 'KNN', 'MLP'])
|
||||
]
|
||||
|
||||
row = 0
|
||||
for group_name, models in model_groups:
|
||||
group_label = QLabel(f"<b>{group_name}</b>")
|
||||
group_label.setStyleSheet(
|
||||
f"background-color: {ModernStylesheet.COLORS['hover']}; "
|
||||
f"padding: 5px; border: 1px solid {ModernStylesheet.COLORS['border_light']}; "
|
||||
f"border-radius: 3px;"
|
||||
)
|
||||
model_grid.addWidget(group_label, row, 0, 1, 4)
|
||||
row += 1
|
||||
|
||||
for i, model in enumerate(models):
|
||||
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
|
||||
checkbox.setChecked(False)
|
||||
self.model_checkboxes[model] = checkbox
|
||||
model_grid.addWidget(checkbox, row, i % 4)
|
||||
if (i + 1) % 4 == 0:
|
||||
row += 1
|
||||
|
||||
row += 1
|
||||
|
||||
model_button_layout = QHBoxLayout()
|
||||
model_select_all = QPushButton("全选")
|
||||
model_deselect_all = QPushButton("全不选")
|
||||
model_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, True))
|
||||
model_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.model_checkboxes, False))
|
||||
model_button_layout.addWidget(model_select_all)
|
||||
model_button_layout.addWidget(model_deselect_all)
|
||||
model_button_layout.addStretch()
|
||||
|
||||
model_layout.addLayout(model_grid)
|
||||
model_layout.addLayout(model_button_layout)
|
||||
model_group.setLayout(model_layout)
|
||||
layout.addWidget(model_group)
|
||||
|
||||
# 数据划分方法 - 多选
|
||||
split_group = QGroupBox("数据划分方法 (可多选)")
|
||||
split_layout = QVBoxLayout()
|
||||
|
||||
split_grid = QGridLayout()
|
||||
self.split_checkboxes = {}
|
||||
split_methods = ['spxy', 'ks', 'random']
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
split_button_layout = QHBoxLayout()
|
||||
split_select_all = QPushButton("全选")
|
||||
split_deselect_all = QPushButton("全不选")
|
||||
split_select_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, True))
|
||||
split_deselect_all.clicked.connect(lambda: self._toggle_checkboxes(self.split_checkboxes, False))
|
||||
split_button_layout.addWidget(split_select_all)
|
||||
split_button_layout.addWidget(split_deselect_all)
|
||||
split_button_layout.addStretch()
|
||||
|
||||
split_layout.addLayout(split_grid)
|
||||
split_layout.addLayout(split_button_layout)
|
||||
split_group.setLayout(split_layout)
|
||||
layout.addWidget(split_group)
|
||||
|
||||
self.ml_page.setLayout(layout)
|
||||
|
||||
def _toggle_checkboxes(self, checkboxes_dict, checked):
|
||||
"""统一设置checkbox状态"""
|
||||
for checkbox in checkboxes_dict.values():
|
||||
checkbox.setChecked(checked)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_output_path(self):
|
||||
"""浏览输出文件路径(保存对话框)"""
|
||||
current = self.output_path.get_path().strip()
|
||||
if current:
|
||||
initial_dir = os.path.dirname(current)
|
||||
initial_file = os.path.basename(current)
|
||||
else:
|
||||
initial_dir = ""
|
||||
initial_file = ""
|
||||
|
||||
if not initial_dir or not os.path.isdir(initial_dir):
|
||||
# 默认定位到 indices 目录
|
||||
work_dir = self._get_default_work_dir()
|
||||
initial_dir = os.path.join(work_dir, "6_water_quality_indices") if work_dir else ""
|
||||
if initial_dir and not os.path.isdir(initial_dir):
|
||||
os.makedirs(initial_dir, exist_ok=True)
|
||||
|
||||
file_path, _ = QFileDialog.getSaveFileName(
|
||||
self, "保存输出文件", os.path.join(initial_dir, initial_file) if initial_file else initial_dir,
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
if file_path:
|
||||
self.output_path.set_path(file_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
preprocessing_methods = [
|
||||
method for method, checkbox in self.preproc_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
model_names = [
|
||||
model for model, checkbox in self.model_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
split_methods = [
|
||||
method for method, checkbox in self.split_checkboxes.items()
|
||||
if checkbox.isChecked()
|
||||
]
|
||||
|
||||
config = {
|
||||
'feature_start_column': self.feature_start.text(),
|
||||
'preprocessing_methods': preprocessing_methods if preprocessing_methods else ['None'],
|
||||
'model_names': model_names if model_names else ['SVR'],
|
||||
'split_methods': split_methods if split_methods else ['random'],
|
||||
'cv_folds': self.cv_folds.value()
|
||||
}
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if training_csv_path:
|
||||
config['training_csv_path'] = training_csv_path
|
||||
output_path = self.output_path.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'feature_start_column' in config:
|
||||
self.feature_start.setText(str(config['feature_start_column']))
|
||||
if 'cv_folds' in config:
|
||||
self.cv_folds.setValue(config['cv_folds'])
|
||||
if 'preprocessing_methods' in config:
|
||||
methods = config['preprocessing_methods']
|
||||
for method, checkbox in self.preproc_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'model_names' in config:
|
||||
models = config['model_names']
|
||||
for model, checkbox in self.model_checkboxes.items():
|
||||
checkbox.setChecked(model in models)
|
||||
if 'split_methods' in config:
|
||||
methods = config['split_methods']
|
||||
for method, checkbox in self.split_checkboxes.items():
|
||||
checkbox.setChecked(method in methods)
|
||||
if 'training_csv_path' in config:
|
||||
self.training_csv_file.set_path(config['training_csv_path'])
|
||||
if 'output_path' in config:
|
||||
self.output_path.set_path(config['output_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充训练数据和输出路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
# 1. 尝试从 Step5 界面读取训练数据路径,并确保为绝对路径
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'step5_panel'):
|
||||
# 优先直接从 Step5 的输出 widget 读取
|
||||
step5_output = main_window.step5_panel.output_file.get_path()
|
||||
if step5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_output):
|
||||
step5_output = os.path.join(self.work_dir or '', step5_output).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_output)
|
||||
elif hasattr(main_window, 'step5_panel') and hasattr(main_window.step5_panel, 'get_config'):
|
||||
# 回退:从 Step5 的 config 字典中查找可能的键名
|
||||
step5_cfg = main_window.step5_panel.get_config()
|
||||
step5_csv = (
|
||||
step5_cfg.get('training_spectra_path')
|
||||
or step5_cfg.get('output_file')
|
||||
or step5_cfg.get('csv_path')
|
||||
or step5_cfg.get('output_csv')
|
||||
)
|
||||
if step5_csv:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step5_csv):
|
||||
step5_csv = os.path.join(self.work_dir or '', step5_csv).replace('\\', '/')
|
||||
self.training_csv_file.set_path(step5_csv)
|
||||
|
||||
# 2. 自动填充输出文件路径(基于工作目录和输入文件名)
|
||||
# 输入是 training_spectra.csv → 输出 {work_dir}/6_water_quality_indices/training_spectra_indices.csv
|
||||
# 输入是 sampling_spectra.csv → 输出 {work_dir}/6_water_quality_indices/sampling_spectra_indices.csv
|
||||
if self.work_dir:
|
||||
indices_dir = os.path.join(self.work_dir, "6_water_quality_indices")
|
||||
os.makedirs(indices_dir, exist_ok=True)
|
||||
training_csv = self.training_csv_file.get_path()
|
||||
if training_csv:
|
||||
basename = os.path.splitext(os.path.basename(training_csv))[0]
|
||||
output_file = f"{basename}_indices.csv"
|
||||
else:
|
||||
output_file = "water_quality_indices.csv"
|
||||
output_path = os.path.join(indices_dir, output_file).replace('\\', '/')
|
||||
self.output_path.set_path(output_path)
|
||||
else:
|
||||
self.output_path.set_path("")
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤6"""
|
||||
training_csv_path = self.training_csv_file.get_path()
|
||||
if not training_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择训练数据CSV文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step6': self.get_config()}
|
||||
main_window.run_single_step('step6', config)
|
||||
|
||||
def get_training_params(self):
|
||||
"""获取模型训练参数"""
|
||||
return {
|
||||
'pipeline_type': 'machine_learning',
|
||||
'feature_start': float(self.feature_start.text()),
|
||||
'cv_folds': self.cv_folds.value(),
|
||||
'preprocess_methods': [method for method, cb in self.preproc_checkboxes.items() if cb.isChecked()],
|
||||
'model_types': [model for model, cb in self.model_checkboxes.items() if cb.isChecked()],
|
||||
'split_methods': [method for method, cb in self.split_checkboxes.items() if cb.isChecked()]
|
||||
}
|
||||
208
src/gui/panels/step7_panel.py
Normal file
208
src/gui/panels/step7_panel.py
Normal file
@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step7 面板 - 采样点生成
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QSpinBox, QMessageBox,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step7Panel(QWidget):
|
||||
"""步骤7:采样点生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 去耀斑影像文件(用于独立运行)
|
||||
self.deglint_img_file = FileSelectWidget(
|
||||
"去耀斑影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.deglint_img_file)
|
||||
|
||||
# 水域掩膜文件(可选,用于独立运行)
|
||||
self.water_mask_file = FileSelectWidget(
|
||||
"水域掩膜:",
|
||||
"Mask Files (*.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.water_mask_file.label.setText("水域掩膜:")
|
||||
layout.addWidget(self.water_mask_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("采样参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.interval = QSpinBox()
|
||||
self.interval.setRange(10, 500)
|
||||
self.interval.setValue(50)
|
||||
params_layout.addRow("采样点间隔(像素):", self.interval)
|
||||
|
||||
self.sample_radius = QSpinBox()
|
||||
self.sample_radius.setRange(1, 50)
|
||||
self.sample_radius.setValue(5)
|
||||
params_layout.addRow("采样半径(像素):", self.sample_radius)
|
||||
|
||||
self.chunk_size = QSpinBox()
|
||||
self.chunk_size.setRange(100, 10000)
|
||||
self.chunk_size.setValue(1000)
|
||||
params_layout.addRow("处理块大小:", self.chunk_size)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出文件路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出采样点:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("sampling_points.csv")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'interval': self.interval.value(),
|
||||
'sample_radius': self.sample_radius.value(),
|
||||
'chunk_size': self.chunk_size.value(),
|
||||
}
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if deglint_img_path:
|
||||
config['deglint_img_path'] = deglint_img_path
|
||||
water_mask_path = self.water_mask_file.get_path()
|
||||
if water_mask_path:
|
||||
config['water_mask_path'] = water_mask_path
|
||||
# 注意:step7_generate_sampling_points 不接受 output_path 参数,输出路径由 pipeline 内部自动生成
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'interval' in config:
|
||||
self.interval.setValue(config['interval'])
|
||||
if 'sample_radius' in config:
|
||||
self.sample_radius.setValue(config['sample_radius'])
|
||||
if 'chunk_size' in config:
|
||||
self.chunk_size.setValue(config['chunk_size'])
|
||||
if 'deglint_img_path' in config:
|
||||
self.deglint_img_file.set_path(config['deglint_img_path'])
|
||||
if 'water_mask_path' in config:
|
||||
self.water_mask_file.set_path(config['water_mask_path'])
|
||||
if 'glint_mask_path' in config:
|
||||
self.glint_mask_file.set_path(config['glint_mask_path'])
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充去耀斑影像和掩膜路径
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(用于从 step_outputs 获取绝对路径)
|
||||
"""
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 填充去耀斑影像路径(优先从 pipeline.step_outputs 获取绝对路径)
|
||||
deglint_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step3_outputs = getattr(pipeline, 'step_outputs', {}).get('step3', {})
|
||||
deglint_path = (
|
||||
step3_outputs.get('deglint_image')
|
||||
or step3_outputs.get('output_path')
|
||||
or step3_outputs.get('output_file')
|
||||
or step3_outputs.get('deglint_img_path')
|
||||
)
|
||||
# 回退:从 step3 面板 widget 直接读取(可能是相对路径)
|
||||
if not deglint_path and hasattr(main_window, 'step3_panel'):
|
||||
deglint_path = main_window.step3_panel.output_file.get_path()
|
||||
|
||||
if deglint_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(deglint_path):
|
||||
deglint_path = os.path.join(self.work_dir or '', deglint_path).replace('\\', '/')
|
||||
self.deglint_img_file.set_path(deglint_path)
|
||||
|
||||
# 2. 填充水域掩膜路径(优先级:pipeline.step_outputs > step1_panel > 1_water_mask > input-test)
|
||||
water_mask_path = None
|
||||
if pipeline and hasattr(pipeline, 'step_outputs'):
|
||||
step1_outputs = getattr(pipeline, 'step_outputs', {}).get('step1', {})
|
||||
water_mask_path = (
|
||||
step1_outputs.get('water_mask')
|
||||
or step1_outputs.get('output_path')
|
||||
or step1_outputs.get('output_file')
|
||||
)
|
||||
# 回退:从 step1 面板 widget 直接读取
|
||||
if not water_mask_path and hasattr(main_window, 'step1_panel'):
|
||||
water_mask_path = main_window.step1_panel.output_file.get_path()
|
||||
# 备选:扫描 1_water_mask 目录下的 .dat 文件
|
||||
if not water_mask_path and self.work_dir:
|
||||
mask_dir = os.path.join(self.work_dir, "1_water_mask")
|
||||
if os.path.isdir(mask_dir):
|
||||
dat_files = [f for f in os.listdir(mask_dir) if f.lower().endswith('.dat')]
|
||||
if dat_files:
|
||||
water_mask_path = os.path.join(mask_dir, dat_files[0]).replace('\\', '/')
|
||||
# 备选:扫描 input-test 目录(优先匹配 water_mask_from_shp.dat)
|
||||
if not water_mask_path and self.work_dir:
|
||||
input_test_dir = os.path.join(self.work_dir, "input-test")
|
||||
if os.path.isdir(input_test_dir):
|
||||
dat_files = [f for f in os.listdir(input_test_dir) if f.lower().endswith('.dat')]
|
||||
# 优先匹配 water_mask_from_shp.dat
|
||||
for f in dat_files:
|
||||
if 'water_mask_from_shp' in f.lower():
|
||||
water_mask_path = os.path.join(input_test_dir, f).replace('\\', '/')
|
||||
break
|
||||
# 否则取第一个 .dat 文件
|
||||
if not water_mask_path and dat_files:
|
||||
water_mask_path = os.path.join(input_test_dir, dat_files[0]).replace('\\', '/')
|
||||
|
||||
if water_mask_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(water_mask_path):
|
||||
water_mask_path = os.path.join(self.work_dir or '', water_mask_path).replace('\\', '/')
|
||||
self.water_mask_file.set_path(water_mask_path)
|
||||
|
||||
# 3. 自动填充输出路径(绝对路径)
|
||||
if self.work_dir:
|
||||
output_path = os.path.join(self.work_dir, "10_sampling", "sampling_spectra.csv")
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
self.output_file.set_path(output_path.replace('\\', '/'))
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤7"""
|
||||
deglint_img_path = self.deglint_img_file.get_path()
|
||||
if not deglint_img_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择去耀斑影像文件!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step7': self.get_config()}
|
||||
main_window.run_single_step('step7', config)
|
||||
226
src/gui/panels/step8_5_panel.py
Normal file
226
src/gui/panels/step8_5_panel.py
Normal file
@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_5 面板 - 非经验模型预测
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_5Panel(QWidget):
|
||||
"""步骤8.5:非经验模型预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件选择
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录选择
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['Average Accuracy(%)', 'Min Accuracy(%)', 'Max Accuracy(%)'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出文件夹:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_file.label.setText("输出文件夹:")
|
||||
self.output_file.browse_btn.clicked.disconnect()
|
||||
self.output_file.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和回归模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.5 界面读取回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_5_panel'):
|
||||
step6_5_widget = getattr(main_window.step6_5_panel, 'output_dir', None)
|
||||
step6_5_models_dir = ""
|
||||
if hasattr(step6_5_widget, 'get_path'):
|
||||
step6_5_models_dir = step6_5_widget.get_path() or ""
|
||||
elif hasattr(step6_5_widget, 'text'):
|
||||
step6_5_models_dir = step6_5_widget.text() or ""
|
||||
|
||||
if step6_5_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_5_models_dir):
|
||||
step6_5_models_dir = os.path.join(self.work_dir or '', step6_5_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_5_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(非经验模型预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Non_Empirical_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "8_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions/Non_Empirical_Prediction")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出文件夹", default)
|
||||
if dir_path:
|
||||
self.output_file.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.5"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_5', {'step8_5': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
230
src/gui/panels/step8_75_panel.py
Normal file
230
src/gui/panels/step8_75_panel.py
Normal file
@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8_75 面板 - 自定义回归预测
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox,
|
||||
QPushButton, QCheckBox, QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8_75Panel(QWidget):
|
||||
"""步骤8.75:自定义回归预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件选择
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 自定义回归模型目录选择(9_Custom_Regression_Modeling)
|
||||
self.regression_models_dir = FileSelectWidget(
|
||||
"回归模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.regression_models_dir.label.setText("回归模型目录:")
|
||||
self.regression_models_dir.browse_btn.clicked.disconnect()
|
||||
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
|
||||
self.regression_models_dir.set_path("") # 路径由 update_from_config 根据 work_dir 自动填充
|
||||
layout.addWidget(self.regression_models_dir)
|
||||
|
||||
# 公式CSV文件选择(用于查找index_formula)
|
||||
self.formula_csv_file = FileSelectWidget(
|
||||
"公式CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.formula_csv_file.label.setText("公式CSV文件:")
|
||||
layout.addWidget(self.formula_csv_file)
|
||||
|
||||
# 输出目录选择
|
||||
self.output_dir_widget = FileSelectWidget(
|
||||
"输出目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir_widget.label.setText("输出目录:")
|
||||
self.output_dir_widget.browse_btn.clicked.disconnect()
|
||||
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
|
||||
layout.addWidget(self.output_dir_widget)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和自定义回归模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6.75 界面读取自定义回归模型目录
|
||||
if main_window and hasattr(main_window, 'step6_75_panel'):
|
||||
step6_75_widget = getattr(main_window.step6_75_panel, 'output_dir', None)
|
||||
step6_75_models_dir = ""
|
||||
if hasattr(step6_75_widget, 'get_path'):
|
||||
step6_75_models_dir = step6_75_widget.get_path() or ""
|
||||
elif hasattr(step6_75_widget, 'text'):
|
||||
step6_75_models_dir = step6_75_widget.text() or ""
|
||||
step6_75_models_dir = step6_75_models_dir.strip()
|
||||
|
||||
if step6_75_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_75_models_dir):
|
||||
step6_75_models_dir = os.path.join(self.work_dir or '', step6_75_models_dir).replace('\\', '/')
|
||||
existing_models = self.regression_models_dir.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.regression_models_dir.set_path(step6_75_models_dir)
|
||||
|
||||
# 3. 自动填充回归模型目录(如果 step6_75 未提供)
|
||||
if self.work_dir:
|
||||
models_dir = self.regression_models_dir.get_path().strip()
|
||||
if not models_dir:
|
||||
default_models_dir = os.path.join(self.work_dir, "9_Custom_Regression_Modeling").replace('\\', '/')
|
||||
self.regression_models_dir.set_path(default_models_dir)
|
||||
|
||||
# 4. 自动填充输出目录(自定义回归预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Custom_Regression_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir_widget.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir_widget.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_regression_models_dir(self):
|
||||
"""浏览回归模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "9_Custom_Regression_Modeling")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", default)
|
||||
if dir_path:
|
||||
self.regression_models_dir.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions/Custom_Regression_Prediction")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", default)
|
||||
if dir_path:
|
||||
self.output_dir_widget.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if regression_models_dir:
|
||||
config['custom_regression_dir'] = regression_models_dir
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if formula_csv_path:
|
||||
config['formula_csv_path'] = formula_csv_path
|
||||
output_dir = self.output_dir_widget.get_path()
|
||||
if output_dir:
|
||||
config['output_dir'] = output_dir
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'custom_regression_dir' in config:
|
||||
self.regression_models_dir.set_path(config['custom_regression_dir'])
|
||||
if 'formula_csv_path' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_path'])
|
||||
if 'output_dir' in config:
|
||||
self.output_dir_widget.set_path(config['output_dir'])
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8.75"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if not regression_models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择回归模型目录!")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if parent and hasattr(parent, 'run_single_step'):
|
||||
parent.run_single_step('step8_75', {'step8_75': config})
|
||||
else:
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
211
src/gui/panels/step8_panel.py
Normal file
211
src/gui/panels/step8_panel.py
Normal file
@ -0,0 +1,211 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 面板 - 机器学习预测
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout,
|
||||
QPushButton, QCheckBox, QComboBox, QLineEdit, QMessageBox,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
class Step8Panel(QWidget):
|
||||
"""步骤8:机器学习预测"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
# 采样光谱CSV文件(用于独立运行)
|
||||
self.sampling_csv_file = FileSelectWidget(
|
||||
"采样光谱CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 模型目录(用于独立运行)
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.metric = QComboBox()
|
||||
self.metric.addItems(['test_r2', 'test_rmse', 'test_mae'])
|
||||
params_layout.addRow("模型选择指标:", self.metric)
|
||||
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出路径
|
||||
self.output_file = FileSelectWidget(
|
||||
"输出路径:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_btn = QPushButton("独立运行此步骤")
|
||||
self.run_btn.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_btn.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_btn)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充采样光谱和模型目录
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
|
||||
# 1. 尝试从 Step7 界面读取全湖采样点 CSV 路径
|
||||
if main_window and hasattr(main_window, 'step7_panel'):
|
||||
step7_widget = getattr(main_window.step7_panel, 'output_file', None)
|
||||
step7_output_path = ""
|
||||
if hasattr(step7_widget, 'get_path'):
|
||||
step7_output_path = step7_widget.get_path() or ""
|
||||
elif hasattr(step7_widget, 'text'):
|
||||
step7_output_path = step7_widget.text() or ""
|
||||
|
||||
if step7_output_path:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step7_output_path):
|
||||
step7_output_path = os.path.join(self.work_dir or '', step7_output_path).replace('\\', '/')
|
||||
existing = self.sampling_csv_file.get_path()
|
||||
if not existing or not existing.strip():
|
||||
self.sampling_csv_file.set_path(step7_output_path)
|
||||
|
||||
# 2. 尝试从 Step6 界面读取监督模型目录
|
||||
if main_window and hasattr(main_window, 'step6_panel'):
|
||||
step6_widget = getattr(main_window.step6_panel, 'output_dir', None)
|
||||
step6_models_dir = ""
|
||||
if hasattr(step6_widget, 'get_path'):
|
||||
step6_models_dir = step6_widget.get_path() or ""
|
||||
elif hasattr(step6_widget, 'text'):
|
||||
step6_models_dir = step6_widget.text() or ""
|
||||
|
||||
if step6_models_dir:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step6_models_dir):
|
||||
step6_models_dir = os.path.join(self.work_dir or '', step6_models_dir).replace('\\', '/')
|
||||
existing_models = self.models_dir_file.get_path()
|
||||
if not existing_models or not existing_models.strip():
|
||||
self.models_dir_file.set_path(step6_models_dir)
|
||||
|
||||
# 3. 自动填充输出路径(机器学习预测目录)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "11_12_13_predictions/Machine_Learning_Prediction")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_file.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_file.set_path(output_dir)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "7_Supervised_Model_Training")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", default)
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'metric': self.metric.currentText(),
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
}
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['models_dir'] = models_dir
|
||||
output_path = self.output_file.get_path()
|
||||
if output_path:
|
||||
config['output_path'] = output_path
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'metric' in config:
|
||||
idx = self.metric.findText(config['metric'])
|
||||
if idx >= 0:
|
||||
self.metric.setCurrentIndex(idx)
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
if 'models_dir' in config:
|
||||
self.models_dir_file.set_path(config['models_dir'])
|
||||
if 'output_path' in config:
|
||||
self.output_file.set_path(config['output_path'])
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤8"""
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if not sampling_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
if not models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择模型目录!")
|
||||
return
|
||||
|
||||
main_window = self.window()
|
||||
if hasattr(main_window, 'run_single_step'):
|
||||
config = {'step8': self.get_config()}
|
||||
main_window.run_single_step('step8', config)
|
||||
533
src/gui/panels/step9_panel.py
Normal file
533
src/gui/panels/step9_panel.py
Normal file
@ -0,0 +1,533 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step9 面板 - 分布图生成
|
||||
"""
|
||||
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QGroupBox, QFormLayout, QHBoxLayout,
|
||||
QLabel, QCheckBox, QPushButton, QLineEdit, QDoubleSpinBox,
|
||||
QRadioButton, QButtonGroup, QMessageBox, QFileDialog,
|
||||
)
|
||||
|
||||
from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
# Pipeline 可用性(与 core/worker_thread.py 保持一致)
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
PIPELINE_AVAILABLE = True
|
||||
except ImportError:
|
||||
PIPELINE_AVAILABLE = False
|
||||
|
||||
|
||||
class Step9BatchThread(QThread):
|
||||
"""专题图:按文件夹内多个预测 CSV 批量生成分布图。"""
|
||||
|
||||
finished_ok = pyqtSignal(int)
|
||||
failed = pyqtSignal(str)
|
||||
log_message = pyqtSignal(str, str)
|
||||
|
||||
def __init__(self, work_dir: str, csv_paths: List[str], step9_kwargs: dict, output_dir_optional: Optional[str]):
|
||||
super().__init__()
|
||||
self.work_dir = work_dir
|
||||
self.csv_paths = csv_paths
|
||||
self.step9_kwargs = step9_kwargs
|
||||
self.output_dir_optional = (output_dir_optional or "").strip() or None
|
||||
|
||||
def run(self):
|
||||
mpl_prev = None
|
||||
try:
|
||||
import matplotlib
|
||||
mpl_prev = matplotlib.get_backend()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend("Agg")
|
||||
except Exception:
|
||||
mpl_prev = None
|
||||
try:
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=self.work_dir)
|
||||
n = len(self.csv_paths)
|
||||
for i, csv_p in enumerate(self.csv_paths):
|
||||
self.log_message.emit(f"专题图 [{i + 1}/{n}] {csv_p}", "info")
|
||||
kw = {**self.step9_kwargs, "prediction_csv_path": csv_p, "skip_dependency_check": True}
|
||||
if self.output_dir_optional:
|
||||
stem = Path(csv_p).stem
|
||||
kw["output_image_path"] = str(Path(self.output_dir_optional) / f"{stem}_distribution.png")
|
||||
else:
|
||||
kw["output_image_path"] = None
|
||||
pipeline.step9_generate_distribution_map(**kw)
|
||||
self.finished_ok.emit(n)
|
||||
except Exception as e:
|
||||
self.failed.emit(f"{e}\n{traceback.format_exc()}")
|
||||
finally:
|
||||
if mpl_prev:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend(mpl_prev)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Step9Panel(QWidget):
|
||||
"""步骤9:分布图生成"""
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self._batch_thread = None
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QVBoxLayout()
|
||||
|
||||
hint = QLabel(
|
||||
"独立运行:可选「单个 CSV」或「文件夹批量」(扫描目录下所有 .csv)。"
|
||||
"完整流程中预测 CSV 由步骤11、12、13 自动传入,无需在此选择。"
|
||||
)
|
||||
hint.setWordWrap(True)
|
||||
hint.setStyleSheet(
|
||||
f"color: {ModernStylesheet.COLORS.get('text_secondary', '#666')};"
|
||||
)
|
||||
layout.addWidget(hint)
|
||||
|
||||
mode_row = QHBoxLayout()
|
||||
self.mode_single_rb = QRadioButton("单个 CSV 文件")
|
||||
self.mode_folder_rb = QRadioButton("文件夹批量")
|
||||
self._mode_group = QButtonGroup(self)
|
||||
self._mode_group.addButton(self.mode_single_rb, 0)
|
||||
self._mode_group.addButton(self.mode_folder_rb, 1)
|
||||
mode_row.addWidget(self.mode_single_rb)
|
||||
mode_row.addWidget(self.mode_folder_rb)
|
||||
mode_row.addStretch()
|
||||
layout.addLayout(mode_row)
|
||||
|
||||
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
|
||||
radio_style = """
|
||||
QRadioButton {
|
||||
font-size: 14px;
|
||||
spacing: 8px;
|
||||
color: #333333;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border: 2px solid #999999;
|
||||
border-radius: 3px;
|
||||
background-color: white;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
border: 2px solid #0078d4;
|
||||
background-color: #0078d4;
|
||||
image: none;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #005a9e;
|
||||
}
|
||||
"""
|
||||
self.mode_single_rb.setStyleSheet(radio_style)
|
||||
self.mode_folder_rb.setStyleSheet(radio_style)
|
||||
|
||||
self.prediction_csv_file = FileSelectWidget(
|
||||
"预测结果CSV:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.prediction_csv_file)
|
||||
|
||||
folder_row = QHBoxLayout()
|
||||
self.prediction_csv_dir_label = QLabel("预测CSV目录:")
|
||||
self.prediction_csv_dir_label.setMinimumWidth(120)
|
||||
self.prediction_csv_dir_edit = QLineEdit()
|
||||
self.prediction_csv_dir_edit.setPlaceholderText("选择含多个预测结果 CSV 的文件夹…")
|
||||
pred_dir_btn = QPushButton("浏览…")
|
||||
pred_dir_btn.setMaximumWidth(80)
|
||||
pred_dir_btn.clicked.connect(self.browse_prediction_csv_dir)
|
||||
folder_row.addWidget(self.prediction_csv_dir_label)
|
||||
folder_row.addWidget(self.prediction_csv_dir_edit, 1)
|
||||
folder_row.addWidget(pred_dir_btn)
|
||||
self._folder_row_widget = QWidget()
|
||||
self._folder_row_widget.setLayout(folder_row)
|
||||
layout.addWidget(self._folder_row_widget)
|
||||
|
||||
self.recursive_csv_cb = QCheckBox("包含子文件夹(递归扫描 *.csv)")
|
||||
layout.addWidget(self.recursive_csv_cb)
|
||||
|
||||
self.boundary_file = FileSelectWidget(
|
||||
"边界文件:",
|
||||
"Shapefiles (*.shp);;All Files (*.*)"
|
||||
)
|
||||
layout.addWidget(self.boundary_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("生成参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
self.resolution = QDoubleSpinBox()
|
||||
self.resolution.setRange(1, 1000)
|
||||
self.resolution.setValue(30)
|
||||
params_layout.addRow("分辨率(米):", self.resolution)
|
||||
|
||||
self.input_crs = QLineEdit()
|
||||
self.input_crs.setText("EPSG:32651")
|
||||
params_layout.addRow("输入坐标系:", self.input_crs)
|
||||
|
||||
self.output_crs = QLineEdit()
|
||||
self.output_crs.setText("EPSG:4326")
|
||||
params_layout.addRow("输出坐标系:", self.output_crs)
|
||||
|
||||
self.show_points = QCheckBox("显示采样点")
|
||||
params_layout.addRow("", self.show_points)
|
||||
|
||||
self.use_diffusion = QCheckBox("启用距离扩散")
|
||||
self.use_diffusion.setChecked(True)
|
||||
params_layout.addRow("", self.use_diffusion)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
|
||||
# 输出目录
|
||||
self.output_dir = FileSelectWidget(
|
||||
"输出分布图目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("留空→工作目录/14_visualization")
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
layout.addWidget(self.output_dir)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
self.run_button = QPushButton("独立运行此步骤")
|
||||
self.run_button.setStyleSheet(ModernStylesheet.get_button_stylesheet('success'))
|
||||
self.run_button.clicked.connect(self.run_step)
|
||||
layout.addWidget(self.run_button)
|
||||
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
# 信号绑定与初始状态
|
||||
self.mode_single_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_folder_rb.toggled.connect(self._toggle_input_mode)
|
||||
self.mode_single_rb.setChecked(True) # 默认选中"单个 CSV"
|
||||
self._toggle_input_mode() # 根据默认值设置初始显示状态
|
||||
|
||||
def _toggle_input_mode(self):
|
||||
"""槽函数:根据单选框状态动态显示/隐藏对应的输入组件。"""
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
# 单个 CSV 模式:显示单文件选择,隐藏文件夹选择
|
||||
self.prediction_csv_file.setVisible(not folder_mode)
|
||||
# 文件夹批量模式:显示文件夹选择 + 递归选项,隐藏单文件选择
|
||||
self._folder_row_widget.setVisible(folder_mode)
|
||||
self.recursive_csv_cb.setVisible(folder_mode)
|
||||
|
||||
def _get_default_work_dir(self):
|
||||
"""获取 work_dir,优先用 panel 自身缓存的,否则尝试从主窗口取"""
|
||||
if hasattr(self, 'work_dir') and self.work_dir:
|
||||
return str(self.work_dir)
|
||||
mw = self.window()
|
||||
if mw and hasattr(mw, 'work_dir') and mw.work_dir:
|
||||
return str(mw.work_dir)
|
||||
return ""
|
||||
|
||||
def browse_prediction_csv_dir(self):
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "11_12_13_predictions")
|
||||
d = QFileDialog.getExistingDirectory(self, "选择预测结果 CSV 所在文件夹", default)
|
||||
if d:
|
||||
self.prediction_csv_dir_edit.setText(d)
|
||||
|
||||
def _collect_csv_paths_from_folder(self) -> List[str]:
|
||||
folder = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not folder or not os.path.isdir(folder):
|
||||
return []
|
||||
root = Path(folder)
|
||||
if self.recursive_csv_cb.isChecked():
|
||||
files = sorted(root.rglob("*.csv"))
|
||||
else:
|
||||
files = sorted(root.glob("*.csv"))
|
||||
return [str(p) for p in files if p.is_file()]
|
||||
|
||||
def _step9_base_pipeline_kwargs(self) -> dict:
|
||||
return {
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
|
||||
def get_config(self):
|
||||
pred_csv = (self.prediction_csv_file.get_path() or "").strip()
|
||||
folder_mode = self.mode_folder_rb.isChecked()
|
||||
pred_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
config = {
|
||||
'step9_batch_mode': 'folder' if folder_mode else 'single',
|
||||
'prediction_csv_dir': pred_dir if pred_dir else None,
|
||||
'recursive_csv_scan': self.recursive_csv_cb.isChecked(),
|
||||
'prediction_csv_path': None if folder_mode else (pred_csv if pred_csv else None),
|
||||
'boundary_shp_path': self.boundary_file.get_path(),
|
||||
'resolution': self.resolution.value(),
|
||||
'input_crs': self.input_crs.text(),
|
||||
'output_crs': self.output_crs.text(),
|
||||
'show_sample_points': self.show_points.isChecked(),
|
||||
'use_distance_diffusion': self.use_diffusion.isChecked(),
|
||||
}
|
||||
out_dir = (self.output_dir.get_path() or "").strip()
|
||||
if not folder_mode and pred_csv and out_dir:
|
||||
stem = Path(pred_csv).stem
|
||||
config['output_image_path'] = str(Path(out_dir) / f"{stem}_distribution.png")
|
||||
else:
|
||||
config['output_image_path'] = None
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
mode = config.get('step9_batch_mode', 'single')
|
||||
if mode == 'folder':
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
else:
|
||||
self.mode_single_rb.setChecked(True)
|
||||
if config.get('prediction_csv_dir'):
|
||||
self.prediction_csv_dir_edit.setText(str(config['prediction_csv_dir']))
|
||||
if 'recursive_csv_scan' in config:
|
||||
self.recursive_csv_cb.setChecked(bool(config['recursive_csv_scan']))
|
||||
if 'prediction_csv_path' in config and config['prediction_csv_path']:
|
||||
self.prediction_csv_file.set_path(str(config['prediction_csv_path']))
|
||||
if 'boundary_shp_path' in config:
|
||||
self.boundary_file.set_path(config['boundary_shp_path'])
|
||||
if 'resolution' in config:
|
||||
self.resolution.setValue(config['resolution'])
|
||||
if 'input_crs' in config:
|
||||
self.input_crs.setText(config['input_crs'])
|
||||
if 'output_crs' in config:
|
||||
self.output_crs.setText(config['output_crs'])
|
||||
if 'show_sample_points' in config:
|
||||
self.show_points.setChecked(config['show_sample_points'])
|
||||
if 'use_distance_diffusion' in config:
|
||||
self.use_diffusion.setChecked(config['use_distance_diffusion'])
|
||||
if 'output_dir' in config and config['output_dir']:
|
||||
self.output_dir.set_path(str(config['output_dir']))
|
||||
elif config.get('output_image_path'):
|
||||
p = Path(str(config['output_image_path']))
|
||||
if p.parent and str(p.parent) != '.':
|
||||
self.output_dir.set_path(str(p.parent))
|
||||
|
||||
def update_from_config(self, work_dir=None, pipeline=None):
|
||||
"""从全局配置自动填充预测结果目录
|
||||
|
||||
优先使用 Step8(机器学习预测)的输出目录作为待预测 CSV 目录;
|
||||
其次回退到 Step8.5(回归预测)或 Step8.75(自定义回归预测)的输出目录。
|
||||
|
||||
Args:
|
||||
work_dir: 工作目录路径
|
||||
pipeline: Pipeline 实例(未使用,保留接口兼容性)
|
||||
"""
|
||||
try:
|
||||
import traceback
|
||||
|
||||
if work_dir:
|
||||
self.work_dir = work_dir
|
||||
elif hasattr(self, 'work_dir') and self.work_dir:
|
||||
pass
|
||||
else:
|
||||
self.work_dir = None
|
||||
|
||||
main_window = self.window()
|
||||
if not main_window:
|
||||
return
|
||||
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
|
||||
pred_dir = None
|
||||
if hasattr(main_window, 'step8_panel'):
|
||||
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
||||
step8_output = ""
|
||||
if hasattr(step8_widget, 'get_path'):
|
||||
step8_output = step8_widget.get_path() or ""
|
||||
elif hasattr(step8_widget, 'text'):
|
||||
step8_output = step8_widget.text() or ""
|
||||
|
||||
if step8_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_output):
|
||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||
# 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
|
||||
base_pred_dir = str(Path(step8_output).parent)
|
||||
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
|
||||
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
|
||||
|
||||
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
||||
step8_5_widget = getattr(main_window.step8_5_panel, 'output_file', None)
|
||||
step8_5_output = ""
|
||||
if hasattr(step8_5_widget, 'get_path'):
|
||||
step8_5_output = step8_5_widget.get_path() or ""
|
||||
elif hasattr(step8_5_widget, 'text'):
|
||||
step8_5_output = step8_5_widget.text() or ""
|
||||
|
||||
if step8_5_output:
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_5_output):
|
||||
step8_5_output = os.path.join(self.work_dir or '', step8_5_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_5_output).parent)
|
||||
|
||||
# 3. 备选:从 Step8.75 界面读取自定义回归预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_75_panel'):
|
||||
step8_75_widget = getattr(main_window.step8_75_panel, 'output_dir_widget', None)
|
||||
step8_75_output = ""
|
||||
if hasattr(step8_75_widget, 'get_path'):
|
||||
step8_75_output = step8_75_widget.get_path() or ""
|
||||
elif hasattr(step8_75_widget, 'text'):
|
||||
step8_75_output = step8_75_widget.text() or ""
|
||||
|
||||
if step8_75_output:
|
||||
pred_dir = step8_75_output
|
||||
|
||||
# 自动填入"预测CSV目录"(文件夹批量模式)
|
||||
if pred_dir:
|
||||
existing_dir = (self.prediction_csv_dir_edit.text() or "").strip()
|
||||
if not existing_dir:
|
||||
self.prediction_csv_dir_edit.setText(pred_dir)
|
||||
# 切换到文件夹批量模式
|
||||
self.mode_folder_rb.setChecked(True)
|
||||
|
||||
# 4. 自动填充输出目录(14_visualization)
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "14_visualization")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
|
||||
# 5. 自动探测原始矢量边界文件(.shp)作为专题图底图
|
||||
# 优先回溯 input-test/roi.shp,geopandas.read_file 仅支持矢量格式
|
||||
if self.work_dir:
|
||||
possible_shp = None
|
||||
candidates = [
|
||||
Path(self.work_dir).parent / "input-test" / "roi.shp",
|
||||
Path(self.work_dir) / "roi.shp",
|
||||
Path(self.work_dir).parent / "roi.shp",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists() and candidate.suffix.lower() == ".shp":
|
||||
possible_shp = candidate
|
||||
break
|
||||
|
||||
existing_boundary = (self.boundary_file.get_path() or "").strip()
|
||||
if not existing_boundary and possible_shp:
|
||||
self.boundary_file.set_path(str(possible_shp))
|
||||
elif not existing_boundary:
|
||||
# 未找到 .shp 时清空并提示用户手动选择矢量文件
|
||||
self.boundary_file.set_path("")
|
||||
print("⚠️ 提示:专题图生成模块需传入标准矢量边界文件 (.shp),请手动选择。")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
default = self._get_default_work_dir()
|
||||
if default:
|
||||
default = os.path.join(default, "14_visualization")
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出分布图目录", default)
|
||||
if dir_path:
|
||||
self.output_dir.set_path(dir_path)
|
||||
|
||||
def run_step(self):
|
||||
"""独立运行步骤9"""
|
||||
if self._batch_thread and self._batch_thread.isRunning():
|
||||
QMessageBox.information(self, "提示", "批量任务正在运行,请稍候。")
|
||||
return
|
||||
|
||||
boundary_shp_path = self.boundary_file.get_path()
|
||||
if not boundary_shp_path:
|
||||
QMessageBox.warning(self, "输入验证失败", "请选择边界文件")
|
||||
return
|
||||
if not os.path.exists(boundary_shp_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "边界文件不存在")
|
||||
return
|
||||
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, 'run_single_step'):
|
||||
parent = parent.parent()
|
||||
|
||||
if not parent or not hasattr(parent, 'run_single_step'):
|
||||
QMessageBox.critical(self, "错误", "无法找到父级GUI对象")
|
||||
return
|
||||
|
||||
if self.mode_folder_rb.isChecked():
|
||||
csv_list = self._collect_csv_paths_from_folder()
|
||||
if not csv_list:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"所选文件夹中未找到 .csv 文件,或目录无效。\n"
|
||||
"可勾选「包含子文件夹」以递归扫描。",
|
||||
)
|
||||
return
|
||||
if not PIPELINE_AVAILABLE:
|
||||
QMessageBox.critical(self, "错误", "Pipeline 模块不可用,无法批量生成专题图。")
|
||||
return
|
||||
work_dir = getattr(parent, "work_dir", None) or "./work_dir"
|
||||
work_dir = str(work_dir)
|
||||
base_kw = self._step9_base_pipeline_kwargs()
|
||||
out_dir_opt = (self.output_dir.get_path() or "").strip() or None
|
||||
self.run_button.setEnabled(False)
|
||||
self._batch_thread = Step9BatchThread(work_dir, csv_list, base_kw, out_dir_opt)
|
||||
main_win = parent
|
||||
|
||||
def _batch_log(msg, lvl):
|
||||
if hasattr(main_win, "log_message"):
|
||||
main_win.log_message(msg, lvl)
|
||||
|
||||
self._batch_thread.log_message.connect(_batch_log, Qt.QueuedConnection)
|
||||
self._batch_thread.finished_ok.connect(self._on_step9_batch_ok, Qt.QueuedConnection)
|
||||
self._batch_thread.failed.connect(self._on_step9_batch_fail, Qt.QueuedConnection)
|
||||
self._batch_thread.finished.connect(lambda: self.run_button.setEnabled(True), Qt.QueuedConnection)
|
||||
self._batch_thread.start()
|
||||
if hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量:共 {len(csv_list)} 个 CSV,工作目录 {work_dir}", "info")
|
||||
return
|
||||
|
||||
prediction_csv_path = (self.prediction_csv_file.get_path() or "").strip()
|
||||
if not prediction_csv_path:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"输入验证失败",
|
||||
"请选择「预测结果 CSV」文件,或切换到「文件夹批量」。",
|
||||
)
|
||||
return
|
||||
if not os.path.isfile(prediction_csv_path):
|
||||
QMessageBox.warning(self, "输入验证失败", "预测结果 CSV 不存在或不是文件")
|
||||
return
|
||||
|
||||
config = self.get_config()
|
||||
parent.run_single_step('step9', {'step9': config})
|
||||
|
||||
def _on_step9_batch_ok(self, n: int):
|
||||
QMessageBox.information(self, "完成", f"已批量生成 {n} 个分布图。")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(f"专题图批量完成,共 {n} 个文件。", "info")
|
||||
|
||||
def _on_step9_batch_fail(self, err: str):
|
||||
QMessageBox.critical(self, "失败", f"批量生成中断:\n{err[:900]}")
|
||||
parent = self.parent()
|
||||
while parent and not hasattr(parent, "log_message"):
|
||||
parent = parent.parent()
|
||||
if parent and hasattr(parent, "log_message"):
|
||||
parent.log_message(err, "error")
|
||||
1817
src/gui/panels/visualization_panel.py
Normal file
1817
src/gui/panels/visualization_panel.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -6,6 +6,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
@ -19,6 +20,15 @@ from docx.shared import Inches, Pt, Cm
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
from docx.enum.section import WD_SECTION
|
||||
from docx.oxml.ns import qn
|
||||
|
||||
|
||||
def get_resource_path(relative_path: str) -> str:
|
||||
"""获取资源的绝对路径,适配 PyInstaller 打包环境。"""
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
return os.path.join(sys._MEIPASS, relative_path)
|
||||
return os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), relative_path)
|
||||
)
|
||||
from docx.oxml import OxmlElement
|
||||
from docx.shared import RGBColor
|
||||
import pandas as pd
|
||||
@ -848,8 +858,8 @@ class WaterQualityReportGenerator:
|
||||
section.different_first_page_header_footer = True
|
||||
|
||||
# 1. 左上角图片(增大) - 使用相对路径
|
||||
cover_top_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "lica.png"
|
||||
if cover_top_img_path.exists():
|
||||
cover_top_img_path = get_resource_path("data/icons/word/lica.png")
|
||||
if os.path.isfile(cover_top_img_path):
|
||||
try:
|
||||
p = doc.add_paragraph()
|
||||
p.alignment = WD_ALIGN_PARAGRAPH.LEFT
|
||||
@ -897,8 +907,8 @@ class WaterQualityReportGenerator:
|
||||
|
||||
|
||||
# 4. 底部图片(增大) - 使用相对路径
|
||||
cover_bottom_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "fenmian.png"
|
||||
if cover_bottom_img_path.exists():
|
||||
cover_bottom_img_path = get_resource_path("data/icons/word/fenmian.png")
|
||||
if os.path.isfile(cover_bottom_img_path):
|
||||
try:
|
||||
p = doc.add_paragraph()
|
||||
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
@ -960,8 +970,8 @@ class WaterQualityReportGenerator:
|
||||
|
||||
|
||||
# 第一张图片 - 使用相对路径
|
||||
img1_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "屏幕截图 2026-03-31 144131.png"
|
||||
if img1_path.exists():
|
||||
img1_path = get_resource_path("data/icons/word/屏幕截图 2026-03-31 144131.png")
|
||||
if os.path.isfile(img1_path):
|
||||
p = doc.add_paragraph()
|
||||
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
p.add_run().add_picture(str(img1_path), width=Inches(6.0))
|
||||
@ -999,8 +1009,8 @@ class WaterQualityReportGenerator:
|
||||
self._style_heading(h, level=1)
|
||||
|
||||
# 插入图片 - 使用相对路径
|
||||
processing_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "liucheng.png"
|
||||
if processing_img_path.exists():
|
||||
processing_img_path = get_resource_path("data/icons/word/liucheng.png")
|
||||
if os.path.isfile(processing_img_path):
|
||||
p = doc.add_paragraph()
|
||||
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
p.add_run().add_picture(str(processing_img_path), width=Inches(6.5))
|
||||
@ -1356,8 +1366,8 @@ class WaterQualityReportGenerator:
|
||||
header_para = header.add_paragraph()
|
||||
|
||||
# 1. 最左侧图片 - 使用相对路径
|
||||
header_img_path = Path(__file__).parent.parent.parent / "data" / "icons" / "word" / "lica.png"
|
||||
if header_img_path.exists():
|
||||
header_img_path = get_resource_path("data/icons/word/lica.png")
|
||||
if os.path.isfile(header_img_path):
|
||||
try:
|
||||
run_img = header_para.add_run()
|
||||
run_img.add_picture(str(header_img_path), width=Inches(1.6))
|
||||
|
||||
@ -1003,67 +1003,84 @@ class ReportGenerator:
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
from modeling_batch import WaterQualityModelingBatch
|
||||
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
import joblib
|
||||
|
||||
modeler = WaterQualityModelingBatch(models_dir)
|
||||
|
||||
# 需要先加载训练结果
|
||||
# 这里假设results已经存储在modeler中,或者需要从保存的文件中读取
|
||||
# 由于modeling_batch.py的结构,我们需要另一种方式来获取所有结果
|
||||
|
||||
# 尝试遍历模型目录,查找所有保存的结果
|
||||
models_path = Path(models_dir)
|
||||
all_results = []
|
||||
|
||||
# 遍历所有目标参数文件夹
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
|
||||
# 查找所有模型文件
|
||||
for model_file in target_folder.rglob("*.pkl"):
|
||||
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl)
|
||||
model_info = {
|
||||
'target': target_name,
|
||||
'model_file': str(model_file),
|
||||
'preprocess': 'Unknown',
|
||||
'model': 'Unknown',
|
||||
'split_method': 'Unknown'
|
||||
|
||||
# 递归扫描 *.joblib 和 *.pkl,兼容 artifacts_dir/target_name/ 的所有子目录层级
|
||||
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
|
||||
|
||||
for model_file in model_files:
|
||||
# 目标参数取直系父目录名(符合 artifacts_dir/target_name/ 结构)
|
||||
target_name = model_file.parent.name
|
||||
stem = model_file.stem
|
||||
|
||||
# 文件名格式:{safe_target}_{preprocess}_{model_name}.joblib
|
||||
# 使用 split('_', 2) 最多切 3 段,目标 1 段、预处理 1 段、模型 1 段
|
||||
parts = stem.split('_', 2)
|
||||
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
|
||||
model_name_str = parts[2] if len(parts) > 2 else stem
|
||||
|
||||
# 尝试从 joblib/pkl 读取元数据,提取性能指标
|
||||
metrics = {}
|
||||
try:
|
||||
data = joblib.load(model_file)
|
||||
metadata = data.get('metadata', {})
|
||||
metrics = {
|
||||
'train_r2': metadata.get('train_r2', 'N/A'),
|
||||
'test_r2': metadata.get('test_r2', 'N/A'),
|
||||
'test_rmse': metadata.get('test_rmse', 'N/A'),
|
||||
'train_rmse': metadata.get('train_rmse', 'N/A'),
|
||||
'train_mae': metadata.get('train_mae', 'N/A'),
|
||||
'test_mae': metadata.get('test_mae', 'N/A'),
|
||||
'cv_mean': metadata.get('cv_mean', 'N/A'),
|
||||
}
|
||||
|
||||
# 尝试从文件名解析
|
||||
parts = model_file.stem.split('_')
|
||||
if len(parts) >= 3:
|
||||
model_info['preprocess'] = parts[0]
|
||||
model_info['model'] = parts[1]
|
||||
model_info['split_method'] = parts[2]
|
||||
|
||||
all_results.append(model_info)
|
||||
|
||||
# 如果有训练结果数据,使用实际指标
|
||||
# 否则创建一个基本的摘要
|
||||
except Exception:
|
||||
pass # 加载失败时 metrics 保持为空字典,摘要中该列为 N/A
|
||||
|
||||
all_results.append({
|
||||
'target': target_name,
|
||||
'model_file': str(model_file),
|
||||
'preprocess': preprocess,
|
||||
'model': model_name_str,
|
||||
**metrics,
|
||||
})
|
||||
|
||||
summary_data = []
|
||||
for result in all_results:
|
||||
summary_data.append({
|
||||
'目标参数': result['target'],
|
||||
'预处理方法': result['preprocess'],
|
||||
'模型名称': result['model'],
|
||||
'划分方法': result['split_method'],
|
||||
'模型文件': result['model_file']
|
||||
'模型文件': result['model_file'],
|
||||
'训练集R²': result.get('train_r2', 'N/A'),
|
||||
'测试集R²': result.get('test_r2', 'N/A'),
|
||||
'测试集RMSE': result.get('test_rmse', 'N/A'),
|
||||
'训练集RMSE': result.get('train_rmse', 'N/A'),
|
||||
'训练集MAE': result.get('train_mae', 'N/A'),
|
||||
'测试集MAE': result.get('test_mae', 'N/A'),
|
||||
'CV均值': result.get('cv_mean', 'N/A'),
|
||||
})
|
||||
|
||||
|
||||
if not summary_data:
|
||||
print("警告:未找到模型文件,生成空摘要")
|
||||
summary_data = [{
|
||||
'目标参数': 'No Data',
|
||||
'预处理方法': 'N/A',
|
||||
'模型名称': 'N/A',
|
||||
'划分方法': 'N/A',
|
||||
'模型文件': 'N/A'
|
||||
'模型文件': 'N/A',
|
||||
'训练集R²': 'N/A',
|
||||
'测试集R²': 'N/A',
|
||||
'测试集RMSE': 'N/A',
|
||||
'训练集RMSE': 'N/A',
|
||||
'训练集MAE': 'N/A',
|
||||
'测试集MAE': 'N/A',
|
||||
'CV均值': 'N/A',
|
||||
}]
|
||||
|
||||
|
||||
df_summary = pd.DataFrame(summary_data)
|
||||
|
||||
if output_path is None:
|
||||
|
||||
@ -96,8 +96,14 @@ class BandMathCalculator:
|
||||
|
||||
print(f"计算表达式: {calc_expression}")
|
||||
|
||||
# 安全地计算表达式
|
||||
result = eval(calc_expression)
|
||||
# 【新增安全防护】引入 numpy 命名空间,让 eval 引擎安全识别 nan 与 inf
|
||||
import numpy as np
|
||||
try:
|
||||
# 即使 calc_expression 含有纯字符 nan,也能被 np.nan 安全接管
|
||||
result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
|
||||
except Exception as e:
|
||||
print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
|
||||
result = np.nan
|
||||
|
||||
# 返回结果
|
||||
if var_name:
|
||||
|
||||
@ -14,17 +14,67 @@ def rasterize_envi_xml(shp_filepath):
|
||||
|
||||
@timeit
|
||||
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=None):
|
||||
# ---------- 防御性处理:路径标准化 ----------
|
||||
shp_filepath = os.path.abspath(shp_filepath).replace('\\', '/')
|
||||
print(f"[DEBUG rasterize_shp] 标准化后的 SHP 路径: {shp_filepath}")
|
||||
|
||||
# 检查伴随文件完整性
|
||||
shp_base = os.path.splitext(shp_filepath)[0]
|
||||
for ext in ['.dbf', '.shx', '.prj']:
|
||||
companion = shp_base + ext
|
||||
if os.path.exists(companion):
|
||||
print(f"[DEBUG rasterize_shp] 伴随文件存在: {companion}")
|
||||
else:
|
||||
print(f"[WARNING rasterize_shp] 伴随文件缺失: {companion}")
|
||||
|
||||
# 确保 GDAL/OGR 驱动已注册
|
||||
gdal.AllRegister()
|
||||
ogr.RegisterAll()
|
||||
|
||||
# 检查 ESRI Shapefile 驱动
|
||||
driver = ogr.GetDriverByName("ESRI Shapefile")
|
||||
if driver is None:
|
||||
raise RuntimeError(
|
||||
"系统中未找到 ESRI Shapefile 驱动!请检查 GDAL 是否正确安装及是否包含 Shapefile 支持。"
|
||||
)
|
||||
print(f"[DEBUG rasterize_shp] ESRI Shapefile 驱动: {driver.GetName()}")
|
||||
|
||||
# 打开参考影像获取尺寸信息
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开参考影像文件: {img_path}")
|
||||
im_width = dataset.RasterXSize
|
||||
im_height = dataset.RasterYSize
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
|
||||
del dataset
|
||||
|
||||
# Open the data source and read in the extent
|
||||
# ---------- 打开 SHP 文件(双重尝试获取详细错误) ----------
|
||||
source_ds = gdal.OpenEx(shp_filepath, gdal.OF_VECTOR)
|
||||
if source_ds is None:
|
||||
raise ValueError(f"无法打开shapefile: {shp_filepath}")
|
||||
# gdal.OpenEx 失败,尝试 ogr.Open 获取更详细的错误信息
|
||||
try:
|
||||
ogr_ds = ogr.Open(shp_filepath)
|
||||
except Exception as ogr_err:
|
||||
raise RuntimeError(
|
||||
f"GDAL/OGR 无法打开 SHP 文件(详细原因):\n"
|
||||
f" ogr.Open 抛出异常: {str(ogr_err)}\n"
|
||||
f" 文件路径: {shp_filepath}\n"
|
||||
f"常见原因:\n"
|
||||
f" 1. 路径包含中文/空格/特殊字符(建议复制到纯英文路径下重试)\n"
|
||||
f" 2. .dbf 或 .shx 伴随文件缺失或损坏\n"
|
||||
f" 3. GDAL 未注册 ESRI Shapefile 驱动\n"
|
||||
f" 4. 文件被其他程序锁定"
|
||||
)
|
||||
if ogr_ds is None:
|
||||
raise RuntimeError(
|
||||
f"ogr.Open 和 gdal.OpenEx 均返回 None,无法打开 SHP 文件。\n"
|
||||
f"文件路径: {shp_filepath}\n"
|
||||
f"请检查:\n"
|
||||
f" 1. 所有伴随文件(.dbf/.shx/.prg)是否齐全\n"
|
||||
f" 2. 文件是否被其他程序占用\n"
|
||||
f" 3. 路径中是否存在不支持的字符"
|
||||
)
|
||||
|
||||
# 检查图层数量,如果有多层,指定使用第一层
|
||||
layer_count = source_ds.GetLayerCount()
|
||||
|
||||
@ -1,10 +1,22 @@
|
||||
from src.utils.util import *
|
||||
import math
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
采样点生成模块 - 提供分块采样和光谱数据提取功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
|
||||
# GDAL 环境变量保护(放在最前面,防止路径/编码问题)
|
||||
os.environ['GDAL_FILENAME_IS_UTF8'] = 'YES'
|
||||
os.environ['SHAPE_ENCODING'] = 'UTF-8'
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from osgeo import gdal, ogr
|
||||
import spectral
|
||||
from scipy import ndimage
|
||||
from src.utils.util import write_bands
|
||||
|
||||
try:
|
||||
from skimage import morphology
|
||||
from skimage.morphology import skeletonize, medial_axis
|
||||
@ -87,6 +99,12 @@ def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=
|
||||
ogr.UseExceptions()
|
||||
|
||||
try:
|
||||
# ---------- 路径归一化 + 存在性检查 ----------
|
||||
bil_file = os.path.abspath(bil_file).replace('\\', '/')
|
||||
print(f"[路径检查] 去耀斑影像: {bil_file}")
|
||||
if not os.path.exists(bil_file):
|
||||
raise FileNotFoundError(f"【后端错误】无法在磁盘上找到指定的去耀斑影像: {bil_file}")
|
||||
|
||||
# 打开bil文件
|
||||
dataset_bil = gdal.Open(bil_file)
|
||||
if dataset_bil is None:
|
||||
@ -337,6 +355,45 @@ def get_spectral_sampling_points_chunked(bil_file, water_mask_shp, severe_glint=
|
||||
if f:
|
||||
f.close()
|
||||
|
||||
# ==============================================================================
|
||||
# 🚀 终极手术植入点:带强行环境净化的特征引擎挂载
|
||||
# ==============================================================================
|
||||
|
||||
# 2. 安全校验路径落盘状态
|
||||
if output_csvpath and os.path.exists(str(output_csvpath)):
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
|
||||
print("\n[特征引擎挂载] 正在为采样点自动追加 45 个水质指数衍生特征...")
|
||||
|
||||
# 读取基础底座(50列光谱)
|
||||
base_df = pd.read_csv(output_csvpath)
|
||||
|
||||
# 实例化计算器
|
||||
calc = WaterQualityIndexCalculator()
|
||||
|
||||
# 提取有效算法
|
||||
algorithm_methods = [
|
||||
m for m in dir(calc)
|
||||
if not m.startswith('_') and m not in ['find_closest_wavelength', 'calculate_all_indices']
|
||||
]
|
||||
|
||||
# 就地追加 45 列衍生指数
|
||||
for algo_name in algorithm_methods:
|
||||
try:
|
||||
algo_func = getattr(calc, algo_name)
|
||||
base_df[algo_name] = algo_func(base_df)
|
||||
except Exception:
|
||||
base_df[algo_name] = np.nan
|
||||
|
||||
# 覆盖重写最终结果!
|
||||
base_df.to_csv(output_csvpath, index=False, encoding='utf-8-sig')
|
||||
print(f"✓ 特征扩充大功告成!当前文件总维度完美适配模型: {base_df.shape}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠ 警告:追加特征失败,保留原基础光谱。死因: {e}")
|
||||
# ==============================================================================
|
||||
|
||||
return x_out, y_out, np.array(spectral_out)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user