Compare commits
117 Commits
82af2d75d3
...
Mega-1.1
| Author | SHA1 | Date | |
|---|---|---|---|
| b2435d66c3 | |||
| e5bb9c5cd9 | |||
| f93dbeb848 | |||
| f61a3dfb1d | |||
| d6c003a211 | |||
| 3ee4e90b31 | |||
| 3f217e95b0 | |||
| 2261b4b30e | |||
| 2d45610aa6 | |||
| f1cc339d4a | |||
| f6455b71ba | |||
| 39e8c29913 | |||
| 19c86e6e44 | |||
| bb5c2a50f8 | |||
| a58744cfbb | |||
| 1949711cda | |||
| 191a4b681d | |||
| 91881d564a | |||
| c2740c2bde | |||
| b3a6855881 | |||
| 6a962f5e8f | |||
| 9cb3c8ed0d | |||
| 48668c9e74 | |||
| 6fc0394fe2 | |||
| f8d5ea2eb8 | |||
| ef3de632d3 | |||
| 3d4462f4e9 | |||
| 84f0f6058f | |||
| 61bd8582e5 | |||
| bd4263d2ca | |||
| afe9eaff2c | |||
| e993a184bd | |||
| 2a89fdc62c | |||
| e62f53bf77 | |||
| 1e0e7d1973 | |||
| 15547bddfb | |||
| 027981e9a6 | |||
| 5084f7d049 | |||
| 0238aa66ab | |||
| 03c788a16c | |||
| d41262aa18 | |||
| 0a0ede2e02 | |||
| 60a9d7d922 | |||
| 82e0b92af6 | |||
| a9e77d2ad0 | |||
| f73a7d8999 | |||
| be47b70594 | |||
| 4c9ca2aa03 | |||
| 89bdcbc27a | |||
| 04669bdee8 | |||
| e59703f163 | |||
| 3584c07b67 | |||
| 1ad4c54b80 | |||
| 5d75d3371b | |||
| d3262ae80d | |||
| 7c7a31ce00 | |||
| 604886abb3 | |||
| 3c4d4081a4 | |||
| 184f5fe9f4 | |||
| aa539db9bd | |||
| 016c895803 | |||
| 16fc92648b | |||
| 0493ba7916 | |||
| 2671c0837a | |||
| 320f2f18f2 | |||
| cfe4c50c31 | |||
| 7571762e63 | |||
| 04a321d225 | |||
| fa9c940074 | |||
| c3cc2ef77e | |||
| 4ca90b0e79 | |||
| 6d49e80c7e | |||
| 9ebe4fe4d3 | |||
| 41c6a64628 | |||
| 2872788cc3 | |||
| 90ba5a5fe2 | |||
| c9b9eded84 | |||
| 47cbb4a013 | |||
| 593719e7d0 | |||
| bf2496badc | |||
| 28394f2eda | |||
| aefc9d5aac | |||
| 624a5bdcd4 | |||
| 371e7a2745 | |||
| d22414bf7d | |||
| e57fdb4f75 | |||
| d5dd2ba1da | |||
| 1cbd38a8e0 | |||
| e3debbcb15 | |||
| 2b76d7908f | |||
| 4efe5b871e | |||
| 2139715829 | |||
| 64aa5b8f40 | |||
| 343e316799 | |||
| 517bb28611 | |||
| 60a2a15188 | |||
| 170d347e21 | |||
| bf4237b160 | |||
| cf387c40ab | |||
| 94ed2f1f8d | |||
| 2c52ca19c5 | |||
| 2a4a7ec7be | |||
| 5a55be286f | |||
| 9ba39a7bff | |||
| d15a7a1e2b | |||
| 6d4d802ffe | |||
| abac272b31 | |||
| 95d30d8d81 | |||
| 375fea77b9 | |||
| 8c7c995985 | |||
| f96c55f361 | |||
| 14278739bf | |||
| d0eb458392 | |||
| 605ec86108 | |||
| dcbcc043e4 | |||
| b2b90050dc | |||
| 9d39e61161 |
50
.gitignore
vendored
@ -155,3 +155,53 @@ tmp/
|
||||
*.bak
|
||||
*.backup
|
||||
*~
|
||||
|
||||
# ============================================================
|
||||
# 不应进入版本控制的文件类型
|
||||
# ============================================================
|
||||
|
||||
# Qwen Code 用户配置(个人环境,每次 clone 都不同)
|
||||
.qwen/settings.json
|
||||
.qwen/settings.json.orig
|
||||
|
||||
# Qwen Code 自动生成的 skill 文件(每次会话重新生成)
|
||||
.qwen/skills/
|
||||
|
||||
# GUI 运行时生成的文件
|
||||
src/gui/scaler_params.pkl
|
||||
src/gui/crash_dump.txt
|
||||
|
||||
# 临时/调试脚本(根目录)
|
||||
降采样光谱.py
|
||||
1.py
|
||||
tset.py
|
||||
|
||||
# 报告与文档(本地工作产物)
|
||||
封装问题分析报告.md
|
||||
软件说明.md
|
||||
软件说明2.md
|
||||
|
||||
# 数据子目录中非 .gitkeep 的生成文件
|
||||
data/sub/waterindex*.csv
|
||||
data/sub/waterindex*.xlsx
|
||||
data/sub/png/watermask.png
|
||||
|
||||
# 图标文件(仅需保留 vector/svg,删除像素图标压缩包副本)
|
||||
data/icons-1/
|
||||
data/icons/
|
||||
|
||||
# 旧版脚手架(遗留实验代码)
|
||||
new/
|
||||
|
||||
# 精确放行 src/new/(端到端模块化新架构)
|
||||
!/src/new/
|
||||
!/src/new/**
|
||||
!/src/new/core/**
|
||||
!/src/new/services/**
|
||||
!/src/new/views/**
|
||||
|
||||
# 前端脚手架(未集成的独立 Vue 项目)
|
||||
frontend/
|
||||
|
||||
# 面板备份目录(运行中自动生成)
|
||||
_archive_panels_backup_/
|
||||
|
||||
4
1.py
@ -1,4 +0,0 @@
|
||||
|
||||
|
||||
new_wavelengths = [np.mean(wavelengths[i:i+3]) for i in range(0, len(wavelengths), 3)]
|
||||
print(new_wavelengths)
|
||||
83
README_new_arch.md
Normal file
@ -0,0 +1,83 @@
|
||||
# 端到端模块化新架构(src/new/)
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
src/new/
|
||||
├── __init__.py
|
||||
├── core/
|
||||
│ ├── __init__.py
|
||||
│ └── base_view.py # 基础通讯接口(继承 QWidget + dispatch_execute)
|
||||
├── services/ # 独立后端大脑
|
||||
│ ├── __init__.py
|
||||
│ ├── step1_service.py # Step 1 真实服务(execute_step1)
|
||||
│ └── placeholder_service.py # step2-step13 占位服务
|
||||
├── views/ # 独立前端皮囊
|
||||
│ ├── __init__.py
|
||||
│ ├── step1_view.py # Step 1 真实视图(继承 BaseView)
|
||||
│ └── placeholder_view.py # step2-step13 占位视图
|
||||
└── main_view.py # 路由与调度壳(QMainWindow + QThread)
|
||||
```
|
||||
|
||||
## 端到端调用链
|
||||
|
||||
```
|
||||
Step1View._on_run_clicked (绿色按钮)
|
||||
│ self.dispatch_execute("step1", self.get_config())
|
||||
▼
|
||||
BaseView.dispatch_execute (沿父链上溯)
|
||||
│ ancestor.run_single_step(step_id, config)
|
||||
▼
|
||||
MainView.run_single_step (查 ROUTES 表 → 注入 work_dir)
|
||||
│ TaskWorker(service_func, config).start()
|
||||
▼
|
||||
services.step1_service.execute_step1(config)
|
||||
│ 调 WaterMaskStep.run(...) → 包装成 dict 返回
|
||||
▼
|
||||
MainView._on_step_done (按 status 写日志)
|
||||
```
|
||||
|
||||
## 运行验证
|
||||
|
||||
### 1. 三层冒烟(推荐先跑)
|
||||
|
||||
```cmd
|
||||
cd D:\111\office\ZHLduijie\1.WQ\WQ_GUI
|
||||
python _smoke_new_arch.py
|
||||
```
|
||||
|
||||
预期输出 `汇总:54/54 通过`。
|
||||
|
||||
### 2. 启动路由主窗口
|
||||
|
||||
```cmd
|
||||
cd D:\111\office\ZHLduijie\1.WQ\WQ_GUI
|
||||
python -m src.new.main_view
|
||||
```
|
||||
|
||||
或:
|
||||
|
||||
```cmd
|
||||
python src\new\main_view.py
|
||||
```
|
||||
|
||||
启动后:
|
||||
|
||||
* 左侧 `QListWidget` 显示 13 个 step(step1 真实,其余占位)
|
||||
* 点击 `执行 Step 1: 水域掩膜` → 绿色按钮 → `dispatch_execute`
|
||||
* 底部 `QTextEdit` 实时打印 `[Router]` / `[Service]` 日志
|
||||
|
||||
## 关键设计原则
|
||||
|
||||
1. **view 零业务**:`src/new/views/*.py` 绝不 import 任何 `src/core/`、`src/services/`
|
||||
2. **service 零 PyQt**:`src/new/services/*.py` 不 import 任何 PyQt、不读写全局
|
||||
3. **唯一跨界通道**:`BaseView.dispatch_execute` 把 (step_id, config) 推给主窗口
|
||||
4. **后台执行不阻塞 UI**:`TaskWorker(QThread)` 子线程跑 service
|
||||
5. **错误兜底**:service 任何异常都被 TaskWorker 捕获并转成 `{status: "error", ...}`
|
||||
|
||||
## 当前状态
|
||||
|
||||
| step | view | service | 状态 |
|
||||
|--------|---------------------|------------------------|---------------------|
|
||||
| step1 | `Step1View` 真实 | `execute_step1` 真实 | ✅ 已迁移 |
|
||||
| step2-13 | `PlaceholderView` | `execute_placeholder` | 🚧 占位待迁移 |
|
||||
4
_check_qaa.py
Normal file
@ -0,0 +1,4 @@
|
||||
import sys
|
||||
sys.path.insert(0, r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI')
|
||||
from src.core.algorithms.qaa import QAABaselineSolver
|
||||
print("QAABaselineSolver imported OK")
|
||||
0
_run_gen_csv.py
Normal file
6
check_lines.py
Normal file
@ -0,0 +1,6 @@
|
||||
import sys
|
||||
with open(r'D:\111\office\ZHLduijie\1.WQ\WQ_GUI\src\gui\water_quality_gui.py', 'rb') as f:
|
||||
content = f.read()
|
||||
lines = content.split(b'\r\n')
|
||||
for i, line in enumerate(lines[2918:2955], start=2919):
|
||||
sys.stdout.buffer.write(f'{i}: {repr(line[:120])}'.encode('utf-8') + b'\n')
|
||||
BIN
data/icons/1.png
|
Before Width: | Height: | Size: 67 KiB |
|
Before Width: | Height: | Size: 1.4 MiB |
|
Before Width: | Height: | Size: 3.4 MiB |
BIN
data/icons/2.png
|
Before Width: | Height: | Size: 70 KiB |
BIN
data/icons/3.png
|
Before Width: | Height: | Size: 55 KiB |
BIN
data/icons/4.png
|
Before Width: | Height: | Size: 51 KiB |
BIN
data/icons/5.png
|
Before Width: | Height: | Size: 46 KiB |
BIN
data/icons/6.png
|
Before Width: | Height: | Size: 51 KiB |
BIN
data/icons/7.png
|
Before Width: | Height: | Size: 78 KiB |
BIN
data/icons/8.png
|
Before Width: | Height: | Size: 59 KiB |
BIN
data/icons/9.png
|
Before Width: | Height: | Size: 76 KiB |
|
Before Width: | Height: | Size: 950 KiB |
|
Before Width: | Height: | Size: 30 KiB |
|
Before Width: | Height: | Size: 3.0 MiB |
|
Before Width: | Height: | Size: 1.6 MiB |
|
Before Width: | Height: | Size: 2.2 MiB |
|
Before Width: | Height: | Size: 5.3 MiB |
|
Before Width: | Height: | Size: 6.2 MiB |
|
Before Width: | Height: | Size: 978 KiB |
|
Before Width: | Height: | Size: 1.9 MiB |
|
Before Width: | Height: | Size: 300 KiB |
|
Before Width: | Height: | Size: 3.1 MiB |
|
Before Width: | Height: | Size: 16 MiB |
|
Before Width: | Height: | Size: 6.4 MiB |
|
Before Width: | Height: | Size: 250 KiB |
|
Before Width: | Height: | Size: 2.9 MiB |
|
Before Width: | Height: | Size: 3.1 MiB |
|
Before Width: | Height: | Size: 884 KiB |
|
Before Width: | Height: | Size: 18 KiB |
@ -1,46 +0,0 @@
|
||||
Formula_Name,Category,Formula,Reference
|
||||
BGA_Am09KBBI,Phycocyanin (BGA_PC),(w686 - w658) / (w686 + w658),"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.; Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13."
|
||||
BGA_Be162B643sub629,Phycocyanin (BGA_PC),w644 - w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be162B700sub601,Phycocyanin (BGA_PC),w700 - w601,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be162BsubPhy,Phycocyanin (BGA_PC),w715 - w615,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16FLHBlueRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w458 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16FLHGreenRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w558 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16FLHVioletRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w444 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16MPI,Phycocyanin (BGA_PC),(w615 - w601) - (w644 - w601),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16NDPhyI,Phycocyanin (BGA_PC),(w700 - w622) / (w700 + w622),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16NDPhyI644over615,Phycocyanin (BGA_PC),(w644 - w615) / (w644 + w615),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541."
|
||||
BGA_Be16NDPhyI644over629,Phycocyanin (BGA_PC),(w644 - w629) / (w644 + w629),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542."
|
||||
BGA_Be16Phy2BDA644over629,Phycocyanin (BGA_PC),w644 / w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545."
|
||||
BGA_Da052BDA,Phycocyanin (BGA_PC),w714 / w672,"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
BGA_Go04MCI,Phycocyanin (BGA_PC),w709 - w681 - (w753 - w681),"Gower, J.F.R.; Brown,L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17<31><37>?5."
|
||||
BGA_HU103BDA,Phycocyanin (BGA_PC),(((1 / w615) - (1 / w600)) - w725),"Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406"
|
||||
BGA_Ku15PhyCI,Phycocyanin (BGA_PC),(-1 * (W681 - W665 - (W709 - W665))),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10."
|
||||
BGA_Ku15SLH,Phycocyanin (BGA_PC),(w715 - w658) + (w715 - w658),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11."
|
||||
BGA_MI092BDA,Phycocyanin (BGA_PC),w700 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758<35><38>?75."
|
||||
BGA_MM092BDA,Phycocyanin (BGA_PC),w724 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758<35><38>?76."
|
||||
BGA_MM12NDCIalt,Phycocyanin (BGA_PC),(w700 - w658) / (w700 + w658),"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003"
|
||||
BGA_MM143BDAopt,Phycocyanin (BGA_PC),((1 / w629) - (1 / w659)) * w724,"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004"
|
||||
BGA_SI052BDA,Phycocyanin (BGA_PC),w709 / w620,"Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237<33><37>?45"
|
||||
BGA_SM122BDA,Phycocyanin (BGA_PC),w709 / w600,"Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012."
|
||||
BGA_SY002BDA,Phycocyanin (BGA_PC),w650 / w625,"Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55,153<35><33>?68"
|
||||
BGA_Wy08CI,Phycocyanin (BGA_PC),(-1 * (W686 - W672 - (W715 - W672))),"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
Chl_Al10SABI,chlorophyll_a,(w857 - w644) / (w458 + w529),"Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825."
|
||||
Chl_Am092Bsub,chlorophyll_a,w681 - w665,"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126<32><36>?144."
|
||||
Chl_Be16FLHblue,chlorophyll_a,w529 - (w644 + (w458 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16FLHviolet,chlorophyll_a,w529 - (w644 + (w429 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16NDTIblue,chlorophyll_a,(w658 - w458) / (w658 + w458),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543."
|
||||
Chl_Be16NDTIviolet,chlorophyll_a,(w658 - w444) / (w658 + w444),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544."
|
||||
Chl_De933BDA,chlorophyll_a,w600 - w648 - w625,"Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam."
|
||||
Chl_Gi033BDA,chlorophyll_a,((1 / w672) - (1 / w715)) * w757,"Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282."
|
||||
Chl_Kn07KIVU,chlorophyll_a,(w458 - w644) / w529,"Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23<32><33>?7 April 2007 (ESA SP-636, July 2007)."
|
||||
Chl_MM12NDCI,chlorophyll_a,(w715 - w686) / (w715 + w686),"Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406"
|
||||
Chl_Zh10FLH,chlorophyll_a,w686 - (w715 + (w672 - w751)),"Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48"
|
||||
Turb_Be16GreenPlusRedBothOverViolet,Turbidity,(w558 + w658) / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538"
|
||||
Turb_Be16RedOverViolet,Turbidity,w658 / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539"
|
||||
Turb_Bow06RedOverGreen,Turbidity,w658 / w558,"Bowers, D. G., and C. E. Binding. 2006. 闁炽儲缈籬e Optical Properties of Mineral Suspended Particles: A Review and Synthesis.<2E><>?Estuarine Coastal and Shelf Science 67 (1<><31>?): 219<31><39>?30. doi:10.1016/j.ecss.2005.11.010"
|
||||
Turb_Chip09NIROverGreen,Turbidity,w857 / w558,"Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009."
|
||||
Turb_Dox02NIRoverRed,Turbidity,w857 / w658,"Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085"
|
||||
Turb_Frohn09GreenPlusRedBothOverBlue,Turbidity,(w558 + w658) / w458,"Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency."
|
||||
Turb_Harr92NIR,Turbidity,w857,"Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments闁炽儲鏁刪e Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487<38><37>?509"
|
||||
Turb_Lath91RedOverBlue,Turbidity,w658 / w458,"Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing"
|
||||
Turb_Moore80Red,Turbidity,w658,"Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422"
|
||||
|
@ -1,46 +0,0 @@
|
||||
Formula_Name,Category,Formula,Reference
|
||||
BGA_Am09KBBI,Phycocyanin (BGA_PC),(w686 - w658) / (w686 + w658),"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.; Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13."
|
||||
BGA_Be162B643sub629,Phycocyanin (BGA_PC),w644 - w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be162B700sub601,Phycocyanin (BGA_PC),w700 - w601,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be162BsubPhy,Phycocyanin (BGA_PC),w715 - w615,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16FLHBlueRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w458 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16FLHGreenRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w558 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16FLHVioletRedNIR,Phycocyanin (BGA_PC),w658 - (w857 + (w444 - w857)),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538."
|
||||
BGA_Be16MPI,Phycocyanin (BGA_PC),(w615 - w601) - (w644 - w601),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539."
|
||||
BGA_Be16NDPhyI,Phycocyanin (BGA_PC),(w700 - w622) / (w700 + w622),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540."
|
||||
BGA_Be16NDPhyI644over615,Phycocyanin (BGA_PC),(w644 - w615) / (w644 + w615),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541."
|
||||
BGA_Be16NDPhyI644over629,Phycocyanin (BGA_PC),(w644 - w629) / (w644 + w629),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542."
|
||||
BGA_Be16Phy2BDA644over629,Phycocyanin (BGA_PC),w644 / w629,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545."
|
||||
BGA_Da052BDA,Phycocyanin (BGA_PC),w714 / w672,"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
BGA_Go04MCI,Phycocyanin (BGA_PC),w709 - w681 - (w753 - w681),"Gower, J.F.R.; Brown,L.; Borstad, G.A.; Observation of chlorophyll fluorescence in west coast waters of Canada using the MODIS satellite sensor. Can. J. Remote Sens., 2004, 30 (1), 17<31><37>?5."
|
||||
BGA_HU103BDA,Phycocyanin (BGA_PC),(((1 / w615) - (1 / w600)) - w725),"Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial dynamics of vertical migration by Microcystis aeruginosa in a eutrophic shallow lake: A case study using high spatial resolution time-series airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406"
|
||||
BGA_Ku15PhyCI,Phycocyanin (BGA_PC),-1 * (W681 - W665 - (W709 - W665)),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10."
|
||||
BGA_Ku15SLH,Phycocyanin (BGA_PC),(w715 - w658) + (w715 - w658),"Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K., Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11."
|
||||
BGA_MI092BDA,Phycocyanin (BGA_PC),w700 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758<35><38>?75."
|
||||
BGA_MM092BDA,Phycocyanin (BGA_PC),w724 / w600,"Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for predicting PC concentrations in cyanobacteria: A proximal hyperspectral remote sensing approach. Remote Sens., 2009, 1, 758<35><38>?76."
|
||||
BGA_MM12NDCIalt,Phycocyanin (BGA_PC),(w700 - w658) / (w700 + w658),"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003"
|
||||
BGA_MM143BDAopt,Phycocyanin (BGA_PC),((1 / w629) - (1 / w659)) * w724,"Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett., 2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004"
|
||||
BGA_SI052BDA,Phycocyanin (BGA_PC),w709 / w620,"Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237<33><37>?45"
|
||||
BGA_SM122BDA,Phycocyanin (BGA_PC),w709 / w600,"Mishra, S. Remote sensing of cyanobacteria in turbid productive waters, PhD Dissertation. Mississippi State University, USA. 2012."
|
||||
BGA_SY002BDA,Phycocyanin (BGA_PC),w650 / w625,"Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters. Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55,153<35><33>?68"
|
||||
BGA_Wy08CI,Phycocyanin (BGA_PC),-1 * (W686 - W672 - (W715 - W672)),"Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A., Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672."
|
||||
Chl_Al10SABI,chlorophyll_a,(w857 - w644) / (w458 + w529),"Alawadi, F. Detection of surface algal blooms using the newly developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825."
|
||||
Chl_Am092Bsub,chlorophyll_a,w681 - w665,"Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S. Novel optical techniques for detecting and classifying toxic dinoflagellate Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 9126<32><36>?144."
|
||||
Chl_Be16FLHblue,chlorophyll_a,w529 - (w644 + (w458 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16FLHviolet,chlorophyll_a,w529 - (w644 + (w429 - w644)),"Beck, R.A. and 22 others; Comparison of satellite reflectance algorithms for estimating chlorophyll-a in a temperate reservoir using coincident hyperspectral aircraft imagery and dense coincident surface observations, Remote Sens. Environ., 2016, 178, 15-30."
|
||||
Chl_Be16NDTIblue,chlorophyll_a,(w658 - w458) / (w658 + w458),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543."
|
||||
Chl_Be16NDTIviolet,chlorophyll_a,(w658 - w444) / (w658 + w444),"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544."
|
||||
Chl_De933BDA,chlorophyll_a,w600 - w648 - w625,"Dekker, A.; Detection of the optical water quality parameters for eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993, Free University, Amsterdam."
|
||||
Chl_Gi033BDA,chlorophyll_a,((1 / w672) - (1 / w715)) * w757,"Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between leaf chlorophyll content and spectral reflectance and algorithms for non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282."
|
||||
Chl_Kn07KIVU,chlorophyll_a,(w458 - w644) / w529,"Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.; Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007, Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 23<32><33>?7 April 2007 (ESA SP-636, July 2007)."
|
||||
Chl_MM12NDCI,chlorophyll_a,(w715 - w686) / (w715 + w686),"Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll index: A novel model for remote estimation of chlorophyll-a concentration in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406"
|
||||
Chl_Zh10FLH,chlorophyll_a,w686 - (w715 + (w672 - w751)),"Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of chlorophyll-a concentration with the reflectance peak near 700 nm in algae-dominated waters and sensitivity of fluorescence algorithms for detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48"
|
||||
Turb_Be16GreenPlusRedBothOverViolet,Turbidity,(w558 + w658) / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538"
|
||||
Turb_Be16RedOverViolet,Turbidity,w658 / w444,"Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.; Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.; Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.; Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539"
|
||||
Turb_Bow06RedOverGreen,Turbidity,w658 / w558,"Bowers, D. G., and C. E. Binding. 2006. 鈥淭he Optical Properties of Mineral Suspended Particles: A Review and Synthesis.<2E><>?Estuarine Coastal and Shelf Science 67 (1<><31>?): 219<31><39>?30. doi:10.1016/j.ecss.2005.11.010"
|
||||
Turb_Chip09NIROverGreen,Turbidity,w857 / w558,"Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing methods for lake management: A guide for resource managers and decision-makers. 2009."
|
||||
Turb_Dox02NIRoverRed,Turbidity,w857 / w658,"Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band ratio used to estimate suspended matter concentrations in sediment-dominated coastal waters, Remote Sens., 2002, 23, 5079-5085"
|
||||
Turb_Frohn09GreenPlusRedBothOverBlue,Turbidity,(w558 + w658) / w458,"Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in the Ohio River using new indices for turbidity and chlorophyll-a with Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency."
|
||||
Turb_Harr92NIR,Turbidity,w857,"Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of Suspended Sediments鈥攖he Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:1487<38><37>?509"
|
||||
Turb_Lath91RedOverBlue,Turbidity,w658 / w458,"Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991. Testing the utility of simple multi-date Thematic Mapper calibration algorithms for monitoring turbid inland waters. International Journal of Remote Sensing"
|
||||
Turb_Moore80Red,Turbidity,w658,"Moore, G.K., Satellite remote sensing of water turbidity, Hydrological Sciences, 1980, 25, 4, 407-422"
|
||||
|
85
data/格式转化.py
Normal file
@ -0,0 +1,85 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def batch_convert_to_ico(source_dirs, output_dir, target_size=(256, 256)):
|
||||
"""
|
||||
批量将指定目录下的图像文件转换为 ICO 格式。
|
||||
|
||||
:param source_dirs: 包含源文件夹路径的列表
|
||||
:param output_dir: 转换后 ICO 文件的保存目录
|
||||
:param target_size: 输出 ICO 的尺寸,默认 256x256
|
||||
"""
|
||||
# 支持的常见输入图像后缀
|
||||
supported_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tiff'}
|
||||
|
||||
# 确保输出目录存在,若无则自动创建
|
||||
out_path = Path(output_dir)
|
||||
out_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_converted = 0
|
||||
total_failed = 0
|
||||
|
||||
print("=" * 50)
|
||||
print(f"🚀 开始批量转换 ICO 图标...")
|
||||
print(f"📁 目标输出目录: {out_path}")
|
||||
print("=" * 50)
|
||||
|
||||
# 遍历所有传入的源目录
|
||||
for folder in source_dirs:
|
||||
folder_path = Path(folder)
|
||||
|
||||
if not folder_path.exists():
|
||||
print(f"⚠️ 警告: 源目录不存在,已跳过 -> {folder_path}")
|
||||
continue
|
||||
|
||||
print(f"\n📂 正在扫描目录: {folder_path}")
|
||||
|
||||
# 遍历目录下的所有文件
|
||||
for file_path in folder_path.iterdir():
|
||||
# 仅处理普通文件且后缀在支持列表内(忽略大小写)
|
||||
if file_path.is_file() and file_path.suffix.lower() in supported_extensions:
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
# 处理透明通道问题:
|
||||
# 如果图片支持透明通道 (RGBA/P/LA),转为 RGBA 确保透明背景不丢失
|
||||
# 如果是普通 RGB (如 JPG),转为 RGB
|
||||
if img.mode in ('RGBA', 'LA') or (img.mode == 'P' and 'transparency' in img.info):
|
||||
img_clean = img.convert('RGBA')
|
||||
else:
|
||||
img_clean = img.convert('RGB')
|
||||
|
||||
# 构造输出文件名 (原文件名.ico)
|
||||
new_filename = f"{file_path.stem}.ico"
|
||||
save_path = out_path / new_filename
|
||||
|
||||
# 如果目标文件夹中已存在同名文件,为了防止覆盖,可以在文件名后加个标识
|
||||
# 但通常图标库同名直接覆盖较符合需求,这里默认直接保存
|
||||
img_clean.save(save_path, format="ICO", sizes=[target_size])
|
||||
|
||||
print(f" ✅ 成功: {file_path.name} -> {new_filename}")
|
||||
total_converted += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 失败: 无法转换 {file_path.name},错误信息: {e}")
|
||||
total_failed += 1
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🎉 转换任务结束!")
|
||||
print(f"统计: 成功转换 {total_converted} 个文件,失败 {total_failed} 个。")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 1. 定义你要读取的两个源文件夹路径列表
|
||||
SOURCES = [
|
||||
r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons",
|
||||
r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons\word"
|
||||
]
|
||||
|
||||
# 2. 定义统一输出的目标文件夹路径
|
||||
OUTPUT = r"D:\111\office\ZHLduijie\1.WQ\WQ_GUI\data\icons-1"
|
||||
|
||||
# 执行转换
|
||||
batch_convert_to_ico(SOURCES, OUTPUT)
|
||||
350
docs/SMOKE_TEST_ROUTE_B_MVP.md
Normal file
@ -0,0 +1,350 @@
|
||||
# Smoke Test — 路线 B MVP(PipelineContext + AutoML + 软取消 + GUI 缝合)
|
||||
|
||||
> 适用范围:路线 B 重构 4 部分(pipeline 包 / AutoML 训练器 / WorkerThread 软取消 / GUI 一键全自动)落盘后的端到端点火试飞清单。
|
||||
> 目标:**用最小数据集(1 个 BSQ + 1 个 CSV)在 10–20 分钟内验证全链路打通**。
|
||||
|
||||
---
|
||||
|
||||
## 0. 前置准备(5 分钟)
|
||||
|
||||
### 0.1 装 Optuna
|
||||
|
||||
`environment.yml` 当前**未列** optuna(属于本次重构新增依赖)。若不装,Step 6 会自动降级到老 GridSearchCV(仍能跑通,但会触发 fallback 日志)。
|
||||
|
||||
```bash
|
||||
call venv\Scripts\activate.bat
|
||||
pip install "optuna>=3.6,<4.0"
|
||||
```
|
||||
|
||||
写入 `environment.yml` 的 patch(提交时改):
|
||||
|
||||
```yaml
|
||||
# 路线 B AutoML 防爆引擎(可选;未装时 Step 6 走老 GridSearchCV 降级路径)
|
||||
- optuna>=3.6
|
||||
```
|
||||
|
||||
### 0.2 准备最小数据集
|
||||
|
||||
```text
|
||||
work_dir_smoke/
|
||||
├── raw/
|
||||
│ ├── sample.b # 假彩色 BSQ(任意小分辨率都行,建议 50×50×6 波段)
|
||||
│ ├── sample_mask.tif # (可选)水域掩膜;不提供则 Step 1 自动生成 NDWI
|
||||
│ └── sample.csv # 含 3–6 个水质参数目标列(Chl-a / TSS / SD / TN / TP / COD…)+ 6 列波段反射率
|
||||
└── (其他文件由流程自动生成)
|
||||
```
|
||||
|
||||
**CSV 模板示例**(`feature_start_column` 默认为第一列;目标列必须**在特征列之前**):
|
||||
|
||||
```csv
|
||||
Chl-a,TSS,SD,B1,B2,B3,B4,B5,B6
|
||||
12.3,15.1,0.8,0.045,0.052,0.038,0.061,0.072,0.085
|
||||
11.8,14.2,0.9,0.044,0.051,0.037,0.060,0.071,0.084
|
||||
... (≥ 200 行;AutoML 智能子采样 N>5000 时才生效)
|
||||
```
|
||||
|
||||
### 0.3 启动 venv
|
||||
|
||||
```bash
|
||||
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=src;%PYTHONPATH%
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. CLI 烟雾(最快路径,3 分钟)— **A 级:必跑**
|
||||
|
||||
跳过 GUI,直接验证 `automl_trainer.py` 自身可独立运行 + Optuna 子采样 + 降级路径:
|
||||
|
||||
```bash
|
||||
python -m src.core.prediction.automl_trainer ^
|
||||
--csv work_dir_smoke/raw/sample.csv ^
|
||||
--feature-start 6 ^
|
||||
--n-trials 5 ^
|
||||
--timeout 60.0 ^
|
||||
--out work_dir_smoke/7_Supervised_Model_Training_AutoML
|
||||
```
|
||||
|
||||
**通过标准**:
|
||||
|
||||
- [ ] 进程退出码 0
|
||||
- [ ] 控制台打印 `AutoML: 目标列 X 共尝试 N 个 trial,最佳 CV R²=…`
|
||||
- [ ] `<out>/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` 存在
|
||||
- [ ] `<out>/automl_summary.json` 存在且 `success=true`
|
||||
|
||||
**若 Optuna 未装**,期待看到:
|
||||
|
||||
```
|
||||
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
|
||||
```
|
||||
|
||||
产物文件名带 `_AUTOML` 后缀的逻辑此时**不会触发**(fallback 走老路径),属正常。
|
||||
|
||||
---
|
||||
|
||||
## 2. GUI 端到端 9 步(核心场景,10–20 分钟)— **S 级:必跑**
|
||||
|
||||
### 2.1 启动 GUI
|
||||
|
||||
```bash
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=src;%PYTHONPATH%
|
||||
python -m src.gui.water_quality_gui
|
||||
```
|
||||
|
||||
### 2.2 UI 配置
|
||||
|
||||
| 步骤 | 操作 | 期望 |
|
||||
| ----- | -------------------------------------------------------------------- | ------------------------------------------------------------------------------------ |
|
||||
| 1/9 | 点"选择工作目录" → 选 `work_dir_smoke/` | 左侧步骤列表高亮,UI 不报错 |
|
||||
| 2/9 | 在 Step 1 面板选 `sample.b`;**掩膜留空**(验证 NDWI 自动生成路径) | 掩膜文本框保持空白 |
|
||||
| 3/9 | 在 Step 4 面板选 `sample.csv` | CSV 路径显示正确 |
|
||||
| 4/9 | **关键**:其他步骤(2/3/5/5.5/6/7/8/9)保持默认,不改任何参数 | AutoML 默认开启(use_automl=True) |
|
||||
| 5/9 | 点 **▶ 运行完整流程**(不要用老 `run_full_pipeline` 槽) | 弹出**二次确认窗**,文案显示:<br>• 掩膜:`未指定(将自动生成 NDWI 水域掩膜)`<br>• 去耀斑:开启<br>• AutoML:开启(Optuna 子采样寻优) |
|
||||
| 6/9 | 点"是(Y)" | "运行"按钮变灰,"停止"按钮亮起;进度条归零 |
|
||||
|
||||
### 2.3 观察日志(重点 4 大检查点)
|
||||
|
||||
#### ✅ 检查点 1:ctx 路径传递
|
||||
|
||||
启动后**第一秒**应看到类似:
|
||||
|
||||
```
|
||||
[Runner] ctx 已构造:14 路径字段,4 目录字段
|
||||
[Runner] 步骤 1/14:step1_generate_water_mask(requires=['raw_img_path', 'water_mask_path'])
|
||||
[Runner] 步骤 2/14:step2_find_glint_area(requires=['raw_img_path', 'water_mask_path', 'output_dir'])
|
||||
...
|
||||
[Runner] ctx 路径校准:water_mask_path = ...\work_dir_smoke\2_Glint_Area_Mask\glint_mask.tif
|
||||
```
|
||||
|
||||
→ **若没有 `[Runner]` 日志**,说明 v1 旧路径被走到了,**`inspect.signature` duck-type 没探测到 v2**,回去检查 `worker_thread.py:run()`。
|
||||
|
||||
#### ✅ 检查点 2:Step 1 NDWI 自动生成
|
||||
|
||||
```
|
||||
[Step1] 未指定 mask_path,自动基于 NDWI 生成水域掩膜
|
||||
[Step1] NDWI 阈值=0.4,写入 1_Water_Mask/water_mask.tif
|
||||
```
|
||||
|
||||
→ 验证 `<work_dir>/1_Water_Mask/water_mask.tif` 文件存在且非空。
|
||||
|
||||
#### ✅ 检查点 3:AutoML 启用
|
||||
|
||||
```
|
||||
[Step6] AutoML 启用 Optuna 子采样寻优(timeout=300s, n_trials=20, max_samples=5000)
|
||||
[Step6] 目标列 'Chl-a' 共 3 个候选模型,最佳 R²=0.812(model=RandomForest)
|
||||
[Step6] 目标列 'TSS' 共 3 个候选模型,最佳 R²=0.745(model=XGBoost)
|
||||
[Step6] 训练完成,产物写入 7_Supervised_Model_Training_AutoML/
|
||||
[Step6] automl_summary.json 写入完成
|
||||
```
|
||||
|
||||
→ 验证产物:
|
||||
- [ ] `7_Supervised_Model_Training_AutoML/<preprocess>/<target>_<preprocess>_<model>_AUTOML.joblib` ≥ 1 个
|
||||
- [ ] `7_Supervised_Model_Training_AutoML/automl_summary.json` 含 `automl: true` 字段
|
||||
- [ ] 老目录 `7_Supervised_Model_Training/` **不应该被创建**(AutoML 路径独立)
|
||||
|
||||
#### ✅ 检查点 4:AutoML 降级(仅未装 Optuna 时)
|
||||
|
||||
```
|
||||
[AutoML] optuna 未安装,全目标列回退老 GridSearchCV
|
||||
[Step6] 降级路径:调用 WaterQualityModelingBatch.train_models_batch(132 组 GridSearchCV)
|
||||
```
|
||||
|
||||
→ 跑通即可(仍能产生模型文件),但**降级**属于非优选路径。
|
||||
|
||||
### 2.4 9 步全程观察清单
|
||||
|
||||
| 步 | 期望产物(路径相对 `work_dir`) | 期望耗时(50×50 测试数据) |
|
||||
| ---- | -------------------------------------------------------------- | -------------------------- |
|
||||
| 1 | `1_Water_Mask/water_mask.tif` | < 5 s |
|
||||
| 2 | `2_Glint_Area_Mask/glint_mask.tif` | < 5 s |
|
||||
| 3 | `3_Remove_Glint_Image/deglint_image.tif` | < 5 s |
|
||||
| 4 | `4_Process_CSV/processed_data.csv` | < 2 s |
|
||||
| 5 | `5_Training_Sample/training_spectra.csv` | < 5 s |
|
||||
| 5.5 | `5_5_Calculate_Indices/indices.csv`(如启用) | < 2 s |
|
||||
| **6**| `7_Supervised_Model_Training_AutoML/`(**新路径!**) | **< 5 min(Optuna 5 trial)** |
|
||||
| 6.5 | `6_5_Non_Empirical_Modeling/`(如启用) | 1–2 min |
|
||||
| 6.75 | `6_75_Custom_Regression/`(如启用) | 1–2 min |
|
||||
| 7 | `7_Sampling_Points/sampling_points.csv` | < 3 s |
|
||||
| 8 | `8_Prediction/predicted_values.csv` | < 5 s |
|
||||
| 8.5 | `8_5_Prediction_Non_Empirical/predicted.csv`(如启用) | < 5 s |
|
||||
| 8.75 | `8_75_Prediction_Custom/predicted.csv`(如启用) | < 5 s |
|
||||
| 9 | `9_Kriging_Distribution_Map/distribution_map.tif` | 5–30 s(纯 Python 慢) |
|
||||
|
||||
### 2.5 流程结束
|
||||
|
||||
- [ ] 进度条到 100%
|
||||
- [ ] "运行"按钮恢复可点
|
||||
- [ ] "停止"按钮变灰
|
||||
- [ ] 日志末行出现 `=== 流程执行完成 ===` 或 `=== 流程被取消 ===`(取决于是否点过停止)
|
||||
- [ ] 控制台 `on_pipeline_finished` 触发:UI 状态被统一恢复
|
||||
|
||||
---
|
||||
|
||||
## 3. 软取消测试(3 分钟)— **A 级:必跑**
|
||||
|
||||
验证 `threading.Event` 软取消链路(不再用 `terminate()`)。
|
||||
|
||||
### 3.1 启动完整流程
|
||||
|
||||
如 2.2 启动流程。
|
||||
|
||||
### 3.2 中途点"停止"
|
||||
|
||||
**时机**:在 Step 6 AutoML 跑 trials 的中途(看到 `[Step6] 目标列 'Chl-a' 共 N 个候选模型` 之后任意时刻)点"停止"。
|
||||
|
||||
**期望看到**:
|
||||
|
||||
```
|
||||
[STOP] 用户请求软取消
|
||||
[Step6] 检测到 cancel_event,本 trial 完成后退出
|
||||
[Step6] AutoML 在 trial #X 中止,已完成 5/20 trial
|
||||
[Runner] 软取消:跳过剩余 8 个 step
|
||||
=== 流程被取消 ===
|
||||
```
|
||||
|
||||
UI 状态:
|
||||
|
||||
- [ ] "运行"按钮重新亮起
|
||||
- [ ] "停止"按钮变灰
|
||||
- [ ] 进度条保留在中断时的百分比(**不**归零)
|
||||
- [ ] `on_pipeline_finished` 触发(用 `success=False, cancelled=True` 区分)
|
||||
- [ ] **Python 进程不退出**(GUI 仍可继续点"运行"开新流程)
|
||||
|
||||
**反例(不应该发生)**:
|
||||
|
||||
- ❌ `QThread: Destroyed while thread is still running` 警告
|
||||
- ❌ Python 解释器直接崩溃
|
||||
- ❌ UI 永远卡死(`run_all_btn` 一直是灰的)
|
||||
|
||||
### 3.3 旧 `stop()` 路径回归
|
||||
|
||||
为防老代码忘了改,临时把 `water_quality_gui.py:stop_pipeline` 改回 `self.worker.stop()`,跑一次完整流程,看是否出现:
|
||||
|
||||
```
|
||||
[DEPRECATED] WorkerThread.stop() 已弃用,请改用 soft_stop()。
|
||||
```
|
||||
|
||||
**这是预期行为**(弃用方法保留但打 warning),流程仍能完成即视为通过。
|
||||
|
||||
---
|
||||
|
||||
## 4. 失败 / 降级场景(5 分钟)— **B 级:选跑**
|
||||
|
||||
### 4.1 未填掩膜 + NDWI 阈值设极端值
|
||||
|
||||
把 NDWI 阈值设到 `0.9`(几乎无水域),Step 1 应给出 warning 但不崩:
|
||||
|
||||
```
|
||||
[Step1] NDWI 阈值=0.9,水域覆盖率 < 1%,请检查影像
|
||||
```
|
||||
|
||||
### 4.2 CSV 完全无目标列
|
||||
|
||||
准备一个**没有目标列的 CSV**(全特征列),点运行:
|
||||
|
||||
```
|
||||
[AutoML] 训练 CSV 不存在或无目标列:未识别出目标列
|
||||
[Step6] AutoML 全部失败,所有目标列返回 success=False
|
||||
```
|
||||
|
||||
→ UI 不会崩,会在 `automl_summary.json` 写 `error: "未识别出目标列"`。
|
||||
|
||||
### 4.3 Step 1 路径不存在
|
||||
|
||||
Step 1 选了一个**不存在的 .bsq 文件**:
|
||||
|
||||
```
|
||||
[Runner] step1_generate_water_mask 异常:FileNotFoundError
|
||||
[STOP] 流程中止在 step 1
|
||||
```
|
||||
|
||||
→ UI 弹错误窗 + 把左侧步骤列表 `setCurrentRow(0)` 自动定位到 Step 1(`_focus_step` 起效)。
|
||||
|
||||
### 4.4 Optuna 版本冲突
|
||||
|
||||
装一个 `optuna==2.10`(API 大改),跑 GUI:
|
||||
|
||||
```
|
||||
[AutoML] optuna API 不兼容(>=3.6 要求):<error>
|
||||
[AutoML] 全目标列回退老 GridSearchCV
|
||||
```
|
||||
|
||||
→ 降级路径生效即视为通过。
|
||||
|
||||
---
|
||||
|
||||
## 5. 验证矩阵 Checklist
|
||||
|
||||
复制以下到 PR 描述 / 验收单:
|
||||
|
||||
```markdown
|
||||
## 路线 B MVP 验证矩阵
|
||||
|
||||
### 代码落盘
|
||||
- [ ] src/core/pipeline/__init__.py(17 行,4 export)
|
||||
- [ ] src/core/pipeline/context.py(PipelineContext dataclass)
|
||||
- [ ] src/core/pipeline/runner.py(StepSpec + PIPELINE_STEPS + PipelineRunner)
|
||||
- [ ] src/core/prediction/__init__.py(追加 train_with_automl export)
|
||||
- [ ] src/core/prediction/automl_trainer.py(AutoMLResult + train_with_automl + CLI)
|
||||
- [ ] src/core/steps/modeling_step.py(use_automl 分支 + _train_models_automl)
|
||||
- [ ] src/core/water_quality_inversion_pipeline_GUI.py(run_full_pipeline_v2 + LEGACY_ATTR_MAP + _sync_legacy_attrs_from_context)
|
||||
- [ ] src/gui/core/worker_thread.py(cancel_event + soft_stop + run() duck-type)
|
||||
- [ ] src/gui/water_quality_gui.py(on_run_all_clicked + _collect_minimal_config + 按钮重连)
|
||||
|
||||
### CLI 自测
|
||||
- [ ] A.1 `python -m src.core.prediction.automl_trainer --csv ...` 退出码 0
|
||||
- [ ] A.2 产物 .joblib 含 `_AUTOML` 后缀
|
||||
- [ ] A.3 automl_summary.json 含 success=true
|
||||
|
||||
### GUI 端到端
|
||||
- [ ] B.1 启动无 ImportError
|
||||
- [ ] B.2 二次确认窗文案含 mask 提示 + AutoML 状态
|
||||
- [ ] B.3 日志含 [Runner] 前缀(v2 路径生效)
|
||||
- [ ] B.4 Step 1 NDWI 自动生成路径生效
|
||||
- [ ] B.5 9 步产物路径全部存在
|
||||
- [ ] B.6 流程结束后 UI 状态恢复(运行按钮亮、停止按钮灰)
|
||||
|
||||
### 软取消
|
||||
- [ ] C.1 流程中途点停止,cancel_event 触发
|
||||
- [ ] C.2 流程被取消而非崩溃
|
||||
- [ ] C.3 UI 状态由 on_pipeline_finished 统一恢复
|
||||
- [ ] C.4 旧 stop() 调用打 [DEPRECATED] warning
|
||||
|
||||
### 降级
|
||||
- [ ] D.1 Optuna 未装 → 全目标列回退老 GridSearchCV
|
||||
- [ ] D.2 无目标列 CSV → 写 error 到 summary,不崩 UI
|
||||
- [ ] D.3 不存在文件 → _focus_step 定位到对应 step
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 已知未做(不在本次范围)
|
||||
|
||||
- [ ] Kriging 多进程并行(当前 backend="loop" 纯 Python)
|
||||
- [ ] Step 5 radius==0 内存优化(整波段读入)
|
||||
- [ ] 进度条 sub-step 粒度(当前只到 step 级)
|
||||
- [ ] Step 8 全图预测(当前只对采样点预测)
|
||||
- [ ] 全项目搜替换老 `self.worker.stop()` 调用(仅本会话改了 `water_quality_gui.py` 的 stop_pipeline)
|
||||
- [ ] `requirements.txt` 同步 Optuna(仅 `environment.yml` 写)
|
||||
- [ ] 单元测试套件(`tests/` 目录为空;建议用 pytest 覆盖 train_with_automl / PipelineRunner)
|
||||
|
||||
---
|
||||
|
||||
## 7. 出问题找哪里
|
||||
|
||||
| 现象 | 看哪里 |
|
||||
| --------------------------------------------- | ------------------------------------------------------- |
|
||||
| `[Runner]` 日志没出来 | `worker_thread.py:run()` 的 `inspect.signature` 探测 |
|
||||
| `[AutoML]` 完全没打 | `modeling_step.py:170` 的 `if use_automl` 是否进了 |
|
||||
| AutoML 报 `optuna API 不兼容` | `automl_trainer.py:236` 的 `try import` 块 |
|
||||
| 软取消无反应 | `worker_thread.py:run()` 末尾的 `cancel_event.is_set()` |
|
||||
| 二次确认窗没出来 | `water_quality_gui.py:on_run_all_clicked` line ~2848 |
|
||||
| 9 步产物路径错位 | `pipeline/runner.py:PIPELINE_STEPS` 的 `output` 字段 |
|
||||
| 老 v1 路径被走到 | `_sync_legacy_attrs_from_context` 没调,或 v2 异常 |
|
||||
|
||||
---
|
||||
|
||||
> **作者注**:本清单对应**路线 B 一键全自动重构 4 部分全部落盘**的验收场景,编号与 todo 8 同步。
|
||||
> 跑通 §1 + §2 + §3 三段即视为 MVP 验收通过;§4 用于鲁棒性抽查。
|
||||
8
license.lic
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"version": "1.0",
|
||||
"product": "WaterQualityInversion",
|
||||
"machine_code": "76E4992A5CF08BA570D6150908E04755",
|
||||
"generated_at": "2026-05-28 14:21:35",
|
||||
"expiry": "2099-12-31",
|
||||
"signature": "DC9AB900D7033A281E54F41F3F76D026FFA75D635484D40C7F6FC1F6023E02AB"
|
||||
}
|
||||
6
run_smoke.bat
Normal file
@ -0,0 +1,6 @@
|
||||
@echo off
|
||||
cd /d "D:\111\office\ZHLduijie\1.WQ\WQ_GUI"
|
||||
call venv\Scripts\activate.bat
|
||||
set PYTHONPATH=new\app\api;%PYTHONPATH%
|
||||
python -c "import _smoke_test_train; _smoke_test_train.test_load_train_df(); _smoke_test_train.test_get_model_pipeline_all_types(); _smoke_test_train.test_run_train_sync_linearregression_fast(); _smoke_test_train.test_run_train_sync_bad_csv(); _smoke_test_train.test_run_train_sync_bad_target(); print('OK')" > %TEMP%\smoke_log.txt 2>&1
|
||||
type %TEMP%\smoke_log.txt
|
||||
@ -5,11 +5,12 @@ import sys
|
||||
def _safe_add(path: str) -> None:
|
||||
if not path or not os.path.isdir(path):
|
||||
return
|
||||
try:
|
||||
if hasattr(os, "add_dll_directory"):
|
||||
if hasattr(os, "add_dll_directory"):
|
||||
try:
|
||||
os.add_dll_directory(path)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.environ["PATH"] = path + os.pathsep + os.environ.get("PATH", "")
|
||||
except Exception:
|
||||
@ -21,5 +22,4 @@ base = getattr(sys, "_MEIPASS", None)
|
||||
if base:
|
||||
_safe_add(base)
|
||||
_safe_add(os.path.join(base, "lib-dynload"))
|
||||
_safe_add(os.path.join(base, "DLLs"))
|
||||
|
||||
_safe_add(os.path.join(base, "DLLs"))
|
||||
4
src/auth/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
授权认证模块
|
||||
"""
|
||||
223
src/auth/keygen_gui.py
Normal file
@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mega Water - 离线授权发卡器 (开发者专用)
|
||||
生成绑定特定机器码的 .lic 授权文件
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 确保 src.auth 在 path 中
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
_project_root = os.path.abspath(os.path.join(_current_dir, "..", ".."))
|
||||
if _project_root not in sys.path:
|
||||
sys.path.insert(0, _project_root)
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit,
|
||||
QPushButton, QFileDialog, QMessageBox, QApplication, QDateEdit, QCheckBox
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QDate
|
||||
|
||||
from src.auth.license_manager import generate_license
|
||||
|
||||
# 永久授权的标识日期
|
||||
PERMANENT_EXPIRY = "2099-12-31"
|
||||
|
||||
|
||||
class LicenseKeygenWindow(QWidget):
|
||||
"""授权发卡器主窗口"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Mega Water - 离线授权发卡器 (开发者专用)")
|
||||
self.setMinimumSize(640, 360)
|
||||
self.move(400, 280)
|
||||
|
||||
self._default_save_path = os.path.join(_project_root, "license.lic")
|
||||
self._setup_ui()
|
||||
|
||||
def _setup_ui(self):
|
||||
# ── 全局字体:无衬线,清晰 ──
|
||||
font_family = "Microsoft YaHei" if sys.platform == "win32" else "Segoe UI"
|
||||
self.setStyleSheet(f"""
|
||||
* {{
|
||||
font-family: {font_family}, 'Segoe UI', sans-serif;
|
||||
font-size: 11pt;
|
||||
}}
|
||||
QLabel#titleLabel {{
|
||||
font-size: 16pt;
|
||||
font-weight: bold;
|
||||
color: #2c3e50;
|
||||
}}
|
||||
QLabel#tipLabel {{
|
||||
font-size: 10pt;
|
||||
color: #95a5a6;
|
||||
}}
|
||||
""")
|
||||
|
||||
main_layout = QVBoxLayout()
|
||||
main_layout.setContentsMargins(45, 40, 45, 40)
|
||||
main_layout.setSpacing(18)
|
||||
|
||||
# ── 标题 ──
|
||||
title_label = QLabel("离线授权发卡器 (开发者专用)")
|
||||
title_label.setObjectName("titleLabel")
|
||||
title_label.setAlignment(Qt.AlignCenter)
|
||||
main_layout.addWidget(title_label)
|
||||
|
||||
# ── 机器码输入行 ──
|
||||
mc_layout = QHBoxLayout()
|
||||
mc_layout.setSpacing(12)
|
||||
mc_label = QLabel("机器码:")
|
||||
mc_label.setFixedWidth(90)
|
||||
self.mc_input = QLineEdit()
|
||||
self.mc_input.setPlaceholderText("粘贴用户发来的 32 位机器码")
|
||||
self.mc_input.setMinimumHeight(36)
|
||||
self.mc_input.setMinimumWidth(400)
|
||||
mc_layout.addWidget(mc_label, 0)
|
||||
mc_layout.addWidget(self.mc_input, 1)
|
||||
main_layout.addLayout(mc_layout)
|
||||
|
||||
# ── 到期时间选择行 ──
|
||||
exp_layout = QHBoxLayout()
|
||||
exp_layout.setSpacing(14)
|
||||
exp_label = QLabel("到期时间:")
|
||||
exp_label.setFixedWidth(90)
|
||||
self.exp_edit = QDateEdit()
|
||||
self.exp_edit.setCalendarPopup(True)
|
||||
self.exp_edit.setMinimumHeight(36)
|
||||
self.exp_edit.setMinimumWidth(160)
|
||||
self.exp_edit.setDate(QDate.currentDate().addYears(1))
|
||||
|
||||
self.perm_check = QCheckBox("永久授权 (不限时)")
|
||||
self.perm_check.setMinimumHeight(36)
|
||||
self.perm_check.stateChanged.connect(self._on_perm_changed)
|
||||
|
||||
exp_layout.addWidget(exp_label, 0)
|
||||
exp_layout.addWidget(self.exp_edit, 0)
|
||||
exp_layout.addWidget(self.perm_check, 0)
|
||||
exp_layout.addStretch(1)
|
||||
main_layout.addLayout(exp_layout)
|
||||
|
||||
# ── 保存路径行 ──
|
||||
path_layout = QHBoxLayout()
|
||||
path_layout.setSpacing(12)
|
||||
path_label = QLabel("保存路径:")
|
||||
path_label.setFixedWidth(90)
|
||||
self.path_input = QLineEdit()
|
||||
self.path_input.setReadOnly(True)
|
||||
self.path_input.setMinimumHeight(36)
|
||||
self.browse_btn = QPushButton("浏览...")
|
||||
self.browse_btn.setMinimumHeight(36)
|
||||
self.browse_btn.setFixedWidth(80)
|
||||
self.browse_btn.clicked.connect(self._on_browse)
|
||||
path_layout.addWidget(path_label, 0)
|
||||
path_layout.addWidget(self.path_input, 1)
|
||||
path_layout.addWidget(self.browse_btn, 0)
|
||||
main_layout.addLayout(path_layout)
|
||||
|
||||
# ── 弹性空间 ──
|
||||
main_layout.addSpacing(10)
|
||||
|
||||
# ── 生成按钮 ──
|
||||
self.gen_btn = QPushButton("生成授权文件 (.lic)")
|
||||
self.gen_btn.setMinimumHeight(48)
|
||||
self.gen_btn.setStyleSheet("""
|
||||
QPushButton {
|
||||
background-color: #27ae60;
|
||||
color: white;
|
||||
font-size: 13pt;
|
||||
font-weight: bold;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background-color: #2ecc71;
|
||||
}
|
||||
QPushButton:pressed {
|
||||
background-color: #1e8449;
|
||||
}
|
||||
""")
|
||||
self.gen_btn.clicked.connect(self._on_generate)
|
||||
main_layout.addWidget(self.gen_btn)
|
||||
|
||||
# ── 底部提示 ──
|
||||
tip_label = QLabel("生成后请将 license.lic 文件发给用户,放置到软件安装目录下即可。")
|
||||
tip_label.setObjectName("tipLabel")
|
||||
tip_label.setAlignment(Qt.AlignCenter)
|
||||
main_layout.addWidget(tip_label)
|
||||
|
||||
self.setLayout(main_layout)
|
||||
|
||||
def _on_perm_changed(self, state):
|
||||
"""永久授权复选框状态变化时,联动日期选择器"""
|
||||
if state == Qt.Checked:
|
||||
self.exp_edit.setEnabled(False)
|
||||
else:
|
||||
self.exp_edit.setEnabled(True)
|
||||
|
||||
def _on_browse(self):
|
||||
"""打开文件对话框选择保存路径"""
|
||||
path, _ = QFileDialog.getSaveFileName(
|
||||
self,
|
||||
"选择授权文件保存位置",
|
||||
self._default_save_path,
|
||||
"授权文件 (*.lic)"
|
||||
)
|
||||
if path:
|
||||
if not path.lower().endswith(".lic"):
|
||||
path += ".lic"
|
||||
self.path_input.setText(path)
|
||||
|
||||
def _on_generate(self):
|
||||
"""点击生成按钮,调用授权管理器"""
|
||||
machine_code = self.mc_input.text().strip()
|
||||
if not machine_code:
|
||||
QMessageBox.warning(self, "输入错误", "请输入机器码")
|
||||
return
|
||||
|
||||
output_path = self.path_input.text().strip()
|
||||
if not output_path:
|
||||
QMessageBox.warning(self, "输入错误", "请设置保存路径")
|
||||
return
|
||||
|
||||
# 根据是否勾选永久授权决定日期
|
||||
if self.perm_check.isChecked():
|
||||
expiry_date = PERMANENT_EXPIRY
|
||||
else:
|
||||
expiry_date = self.exp_edit.date().toString("yyyy-MM-dd")
|
||||
|
||||
ok, msg = generate_license(
|
||||
machine_code=machine_code,
|
||||
output_path=output_path,
|
||||
expiry_date=expiry_date
|
||||
)
|
||||
|
||||
if ok:
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"生成成功",
|
||||
f"✅ 授权文件已成功生成!\n\n保存路径:\n{output_path}\n\n请将此文件发给用户即可。",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
else:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"生成失败",
|
||||
f"❌ {msg}",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# ── 高 DPI 自适应(必须放在 QApplication 实例化之前)──
|
||||
from PyQt5.QtCore import Qt
|
||||
QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True)
|
||||
QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True)
|
||||
|
||||
app = QApplication(sys.argv)
|
||||
app.setApplicationName("LicenseKeygen")
|
||||
window = LicenseKeygenWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
||||
254
src/auth/license_dialog.py
Normal file
@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LicenseDialog - PyQt5 授权拦截弹窗
|
||||
当授权验证失败时弹出,提示用户导入授权文件。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from PyQt5.QtWidgets import (
|
||||
QDialog, QVBoxLayout, QHBoxLayout, QLabel, QPushButton,
|
||||
QTextEdit, QFileDialog, QMessageBox, QApplication
|
||||
)
|
||||
from PyQt5.QtCore import Qt, QTimer
|
||||
from PyQt5.QtGui import QFont, QIcon, QGuiApplication
|
||||
|
||||
# 导入授权管理器
|
||||
from src.auth.license_manager import get_machine_code, verify_license, get_license_path
|
||||
|
||||
|
||||
class LicenseDialog(QDialog):
|
||||
"""
|
||||
授权验证弹窗
|
||||
- 显示本机机器码(只读文本框)
|
||||
- 提供"一键复制"功能
|
||||
- 提供"导入授权文件"按钮
|
||||
- 导入成功后提示重启
|
||||
"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.setWindowTitle("授权验证")
|
||||
self.setWindowFlags(
|
||||
Qt.Dialog |
|
||||
Qt.WindowTitleHint |
|
||||
Qt.WindowCloseButtonHint
|
||||
)
|
||||
self.setModal(True)
|
||||
self.setMinimumWidth(540)
|
||||
|
||||
self._init_ui()
|
||||
self._load_machine_code()
|
||||
|
||||
# 窗口居中
|
||||
QTimer.singleShot(0, self._center_on_screen)
|
||||
|
||||
def _center_on_screen(self):
|
||||
"""将窗口居中到屏幕"""
|
||||
screen = QGuiApplication.primaryScreen()
|
||||
if screen:
|
||||
geo = screen.geometry()
|
||||
self.move(
|
||||
(geo.width() - self.width()) // 2,
|
||||
(geo.height() - self.height()) // 2
|
||||
)
|
||||
|
||||
def _init_ui(self):
|
||||
main_layout = QVBoxLayout(self)
|
||||
main_layout.setContentsMargins(30, 30, 30, 20)
|
||||
main_layout.setSpacing(16)
|
||||
|
||||
# ── 标题区 ──
|
||||
title_font = QFont("Microsoft YaHei", 14, QFont.Bold)
|
||||
title_label = QLabel("本软件需要授权方可运行")
|
||||
title_label.setFont(title_font)
|
||||
title_label.setAlignment(Qt.AlignCenter)
|
||||
title_label.setStyleSheet("color: #2c3e50;")
|
||||
main_layout.addWidget(title_label)
|
||||
|
||||
# ── 机器码标签 ──
|
||||
code_label = QLabel("本机机器码(用于申请授权):")
|
||||
code_label.setStyleSheet("font-weight: bold; color: #34495e;")
|
||||
main_layout.addWidget(code_label)
|
||||
|
||||
# ── 机器码文本框 + 复制按钮 ──
|
||||
code_layout = QHBoxLayout()
|
||||
code_layout.setSpacing(8)
|
||||
|
||||
self.code_edit = QTextEdit()
|
||||
self.code_edit.setReadOnly(True)
|
||||
self.code_edit.setMaximumHeight(72)
|
||||
self.code_edit.setFont(QFont("Consolas", 13))
|
||||
self.code_edit.setStyleSheet(
|
||||
"QTextEdit {"
|
||||
" background-color: #ecf0f1;"
|
||||
" border: 1px solid #bdc3c7;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px;"
|
||||
" color: #2c3e50;"
|
||||
"}"
|
||||
)
|
||||
code_layout.addWidget(self.code_edit, 1)
|
||||
|
||||
copy_btn = QPushButton("复制")
|
||||
copy_btn.setFixedWidth(72)
|
||||
copy_btn.setCursor(Qt.PointingHandCursor)
|
||||
copy_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #3498db;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px 4px;"
|
||||
" font-weight: bold;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #2980b9; }"
|
||||
"QPushButton:pressed { background-color: #21618c; }"
|
||||
)
|
||||
copy_btn.clicked.connect(self._copy_code)
|
||||
code_layout.addWidget(copy_btn)
|
||||
|
||||
main_layout.addLayout(code_layout)
|
||||
|
||||
# ── 导入授权文件按钮 ──
|
||||
import_btn = QPushButton("导入授权文件 (.lic)")
|
||||
import_btn.setCursor(Qt.PointingHandCursor)
|
||||
import_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #27ae60;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 6px;"
|
||||
" padding: 12px;"
|
||||
" font-size: 14px;"
|
||||
" font-weight: bold;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #229954; }"
|
||||
"QPushButton:pressed { background-color: #1e8449; }"
|
||||
)
|
||||
import_btn.clicked.connect(self._import_license)
|
||||
main_layout.addWidget(import_btn)
|
||||
|
||||
# ── 提示文字 ──
|
||||
tip_label = QLabel(
|
||||
"导入后软件将自动重启生效。"
|
||||
)
|
||||
tip_label.setAlignment(Qt.AlignCenter)
|
||||
tip_label.setStyleSheet("color: #95a5a6; font-size: 11px;")
|
||||
main_layout.addWidget(tip_label)
|
||||
|
||||
# ── 授权公司联系信息 ──
|
||||
company_label = QLabel(
|
||||
"请联系授权公司:北京理加联合科技有限公司 或者 北京依锐思遥感技术有限公司"
|
||||
)
|
||||
company_label.setAlignment(Qt.AlignCenter)
|
||||
company_label.setStyleSheet("color: #e74c3c; font-size: 12px; font-weight: bold;")
|
||||
main_layout.addWidget(company_label)
|
||||
|
||||
# ── 取消按钮(退出程序)──
|
||||
cancel_btn = QPushButton("退出")
|
||||
cancel_btn.setCursor(Qt.PointingHandCursor)
|
||||
cancel_btn.setStyleSheet(
|
||||
"QPushButton {"
|
||||
" background-color: #95a5a6;"
|
||||
" color: white;"
|
||||
" border: none;"
|
||||
" border-radius: 4px;"
|
||||
" padding: 8px 20px;"
|
||||
"}"
|
||||
"QPushButton:hover { background-color: #7f8c8d; }"
|
||||
)
|
||||
cancel_btn.clicked.connect(self._quit_app)
|
||||
main_layout.addWidget(cancel_btn, 0, Qt.AlignRight)
|
||||
|
||||
main_layout.addStretch()
|
||||
|
||||
def _load_machine_code(self):
|
||||
"""读取并显示本机机器码"""
|
||||
try:
|
||||
code = get_machine_code(32)
|
||||
self.code_edit.setPlainText(code)
|
||||
except Exception as e:
|
||||
self.code_edit.setPlainText(f"读取失败: {e}")
|
||||
|
||||
def _copy_code(self):
|
||||
"""复制机器码到剪贴板"""
|
||||
clipboard = QApplication.clipboard()
|
||||
clipboard.setText(self.code_edit.toPlainText().strip())
|
||||
|
||||
# 显示反馈
|
||||
QMessageBox.information(self, "已复制", "机器码已复制到剪贴板。")
|
||||
|
||||
def _import_license(self):
|
||||
"""打开文件选择对话框,导入 .lic 授权文件"""
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self,
|
||||
"选择授权文件",
|
||||
"",
|
||||
"授权文件 (*.lic);;所有文件 (*.*)"
|
||||
)
|
||||
|
||||
if not file_path:
|
||||
return
|
||||
|
||||
# 验证授权文件
|
||||
ok, msg = verify_license(file_path)
|
||||
if not ok:
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"授权文件无效",
|
||||
f"验证失败: {msg}\n\n请确认选择了正确的授权文件。"
|
||||
)
|
||||
return
|
||||
|
||||
# 复制授权文件到标准路径
|
||||
dest_path = get_license_path()
|
||||
try:
|
||||
dest_dir = os.path.dirname(dest_path)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
shutil.copy2(file_path, dest_path)
|
||||
except OSError as e:
|
||||
QMessageBox.critical(
|
||||
self,
|
||||
"保存失败",
|
||||
f"无法保存授权文件: {e}"
|
||||
)
|
||||
return
|
||||
|
||||
# 成功提示,重启程序
|
||||
reply = QMessageBox.information(
|
||||
self,
|
||||
"导入成功",
|
||||
"授权文件已成功导入。\n\n软件将自动重启以应用授权。"
|
||||
if False else # 占位,维持下面的逻辑
|
||||
"授权文件已成功导入。\n软件将自动重启以应用授权。",
|
||||
QMessageBox.Ok
|
||||
)
|
||||
|
||||
self.accept()
|
||||
self._restart_app()
|
||||
|
||||
def _quit_app(self):
|
||||
"""退出程序"""
|
||||
self.reject()
|
||||
sys.exit(0)
|
||||
|
||||
def _restart_app(self):
|
||||
"""重启程序"""
|
||||
self.close()
|
||||
QApplication.quit()
|
||||
|
||||
# 延迟重启(确保 QApplication 完全退出)
|
||||
import subprocess
|
||||
import sys as _sys
|
||||
executable = _sys.executable
|
||||
if getattr(_sys, 'frozen', False):
|
||||
# PyInstaller 打包环境下
|
||||
subprocess.Popen([executable] + _sys.argv[1:])
|
||||
else:
|
||||
# 开发环境
|
||||
subprocess.Popen([executable, __file__])
|
||||
328
src/auth/license_manager.py
Normal file
@ -0,0 +1,328 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
License Manager - 离线授权管理模块
|
||||
使用 HMAC-SHA256 + 盐值签名防止篡改
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import uuid
|
||||
import hashlib as _hashlib
|
||||
import subprocess
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
# ============================================================
|
||||
# 第一部分:硬件指纹提取(内嵌 get_machine_code)
|
||||
# ============================================================
|
||||
|
||||
def get_cpu_id() -> Optional[str]:
|
||||
"""读取 CPU 序列号(Processor ID)"""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
result = subprocess.run(
|
||||
["wmic", "cpu", "get", "ProcessorId"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
cpu_id = result.stdout.strip().split("\n")[-1].strip()
|
||||
if cpu_id:
|
||||
return cpu_id
|
||||
else:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
for line in f:
|
||||
if "Serial" in line or "processor" in line:
|
||||
cpu_id = line.split(":")[-1].strip()
|
||||
if cpu_id:
|
||||
return cpu_id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_motherboard_uuid() -> Optional[str]:
|
||||
"""读取主板 UUID(BaseBoard Serial Number)"""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
result = subprocess.run(
|
||||
["wmic", "baseboard", "get", "SerialNumber"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
creationflags=subprocess.CREATE_NO_WINDOW,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
board_uuid = result.stdout.strip().split("\n")[-1].strip()
|
||||
board_uuid = re.sub(r'[^a-zA-Z0-9\-]', '', board_uuid)
|
||||
if board_uuid and board_uuid not in ("To be filled", "None"):
|
||||
return board_uuid
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["cat", "/sys/class/dmi/id/product_uuid"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
stdin=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def get_machine_code(code_length: int = 32) -> str:
|
||||
"""
|
||||
生成唯一的机器码(硬件指纹)
|
||||
参数:
|
||||
code_length: 机器码长度,支持 16/24/32/48/64 位,默认 32 位
|
||||
返回:
|
||||
全大写字母+数字的机器码字符串
|
||||
"""
|
||||
cpu_id = get_cpu_id() or ""
|
||||
board_uuid = get_motherboard_uuid() or ""
|
||||
raw_hardware = f"{cpu_id}-{board_uuid}"
|
||||
|
||||
if not raw_hardware.strip("-") or len(raw_hardware) < 8:
|
||||
try:
|
||||
machine_name = uuid.gethostname() or ""
|
||||
mac = ':'.join(re.findall('..', '%012x' % uuid.getnode()))
|
||||
raw_hardware = f"{machine_name}-{mac}"
|
||||
except Exception:
|
||||
raw_hardware = str(uuid.getnode())
|
||||
|
||||
raw_hardware = re.sub(r'[^a-zA-Z0-9]', '', raw_hardware)
|
||||
hash_hex = hashlib.sha256(raw_hardware.encode('utf-8')).hexdigest().upper()
|
||||
hash_hex = hash_hex.replace('O', 'X').replace('L', 'Y').replace('I', 'Z')
|
||||
|
||||
valid_lengths = [16, 24, 32, 48, 64]
|
||||
if code_length not in valid_lengths:
|
||||
code_length = 32
|
||||
|
||||
return hash_hex[:code_length]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第二部分:授权文件格式与签名机制
|
||||
# ============================================================
|
||||
|
||||
# 开发者密钥(硬编码在软件中,用于验证授权文件)
|
||||
# 注意:实际部署时建议对密钥进行简单混淆或从外部文件加载
|
||||
DEVELOPER_SECRET = b"WaterQuality_v1_2025_SecretKey"
|
||||
LICENSE_VERSION = "1.0"
|
||||
|
||||
|
||||
def _compute_signature(payload_json: str) -> str:
|
||||
"""
|
||||
计算 HMAC-SHA256 签名
|
||||
payload_json: JSON 序列化后的字符串(不含 signature 字段)
|
||||
"""
|
||||
sig = hmac.new(
|
||||
DEVELOPER_SECRET,
|
||||
payload_json.encode('utf-8'),
|
||||
hashlib.sha256
|
||||
).hexdigest().upper()
|
||||
return sig
|
||||
|
||||
|
||||
def _clean_hash(s: str) -> str:
|
||||
"""清洗哈希字符串,避免混淆字符"""
|
||||
return s.replace('O', 'X').replace('L', 'Y').replace('I', 'Z')
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第三部分:核心 API
|
||||
# ============================================================
|
||||
|
||||
def get_license_path() -> str:
|
||||
"""获取授权文件的标准存放路径(程序根目录)"""
|
||||
if getattr(sys, 'frozen', False):
|
||||
base_dir = os.path.dirname(sys.executable)
|
||||
else:
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
return os.path.join(base_dir, "license.lic")
|
||||
|
||||
|
||||
def verify_license(license_path: Optional[str] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
校验授权文件是否匹配本机硬件指纹。
|
||||
|
||||
参数:
|
||||
license_path: 授权文件路径,默认使用标准路径
|
||||
|
||||
返回:
|
||||
(is_valid, message)
|
||||
- is_valid=True 表示授权有效
|
||||
- is_valid=False 表示授权无效,message 为具体原因
|
||||
"""
|
||||
if license_path is None:
|
||||
license_path = get_license_path()
|
||||
|
||||
# Step 1: 文件是否存在
|
||||
if not os.path.isfile(license_path):
|
||||
return False, "授权文件不存在"
|
||||
|
||||
# Step 2: 读取并解析 JSON
|
||||
try:
|
||||
with open(license_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
lic_data = json.loads(content)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
return False, f"授权文件格式错误: {e}"
|
||||
|
||||
# Step 3: 校验版本号
|
||||
version = lic_data.get("version", "")
|
||||
if version != LICENSE_VERSION:
|
||||
return False, f"授权文件版本不匹配 (期望 {LICENSE_VERSION})"
|
||||
|
||||
# Step 4: 校验过期时间
|
||||
expiry_str = lic_data.get("expiry", "")
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry_dt = datetime.strptime(expiry_str, "%Y-%m-%d")
|
||||
if datetime.now() > expiry_dt:
|
||||
return False, "授权已过期"
|
||||
except ValueError:
|
||||
return False, "授权文件日期格式错误"
|
||||
|
||||
# Step 5: 提取 payload(不含 signature)
|
||||
payload_for_verify = {k: v for k, v in lic_data.items() if k != "signature"}
|
||||
payload_json = json.dumps(payload_for_verify, sort_keys=True, ensure_ascii=False)
|
||||
|
||||
# Step 6: 校验签名完整性(防篡改)
|
||||
expected_sig = _compute_signature(payload_json)
|
||||
stored_sig = lic_data.get("signature", "").upper()
|
||||
if not hmac.compare_digest(expected_sig, stored_sig):
|
||||
return False, "授权文件签名校验失败(可能被篡改)"
|
||||
|
||||
# Step 7: 校验机器码绑定
|
||||
bound_machine = lic_data.get("machine_code", "")
|
||||
current_machine = get_machine_code(32)
|
||||
if not hmac.compare_digest(bound_machine, current_machine):
|
||||
return False, "机器码不匹配(授权文件与本机不兼容)"
|
||||
|
||||
return True, "授权验证通过"
|
||||
|
||||
|
||||
def generate_license(
|
||||
machine_code: str,
|
||||
output_path: str,
|
||||
expiry_date: Optional[str] = None,
|
||||
product_name: str = "WaterQualityInversion",
|
||||
max_uses: Optional[int] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
为指定机器码生成合法的授权文件(供开发者使用)。
|
||||
|
||||
参数:
|
||||
machine_code: 目标机器的机器码(32位)
|
||||
output_path: 授权文件输出路径(含文件名,如 "D:/license.lic")
|
||||
expiry_date: 有效期截止日期,格式 "YYYY-MM-DD",默认永久
|
||||
product_name: 产品名称
|
||||
max_uses: 最大使用次数(可选,默认不限制)
|
||||
|
||||
返回:
|
||||
(success, message)
|
||||
"""
|
||||
if len(machine_code) not in (16, 24, 32, 48, 64):
|
||||
return False, f"机器码长度无效(期望 16/24/32/48/64,实际 {len(machine_code)})"
|
||||
|
||||
# 构建 payload
|
||||
payload = {
|
||||
"version": LICENSE_VERSION,
|
||||
"product": product_name,
|
||||
"machine_code": machine_code.upper(),
|
||||
"generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"expiry": expiry_date or "",
|
||||
}
|
||||
if max_uses is not None:
|
||||
payload["max_uses"] = max_uses
|
||||
|
||||
# 计算签名
|
||||
payload_json = json.dumps(payload, sort_keys=True, ensure_ascii=False)
|
||||
signature = _compute_signature(payload_json)
|
||||
|
||||
# 完整授权文件内容
|
||||
lic_content = json.dumps({**payload, "signature": signature}, indent=2, ensure_ascii=False)
|
||||
|
||||
# 写入文件
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(lic_content)
|
||||
return True, f"授权文件已生成: {output_path}"
|
||||
except OSError as e:
|
||||
return False, f"写入授权文件失败: {e}"
|
||||
|
||||
|
||||
def get_machine_info() -> dict:
|
||||
"""获取完整机器信息(调试用)"""
|
||||
return {
|
||||
"cpu_id": get_cpu_id(),
|
||||
"motherboard_uuid": get_motherboard_uuid(),
|
||||
"machine_code_16": get_machine_code(16),
|
||||
"machine_code_32": get_machine_code(32),
|
||||
"license_path": get_license_path(),
|
||||
}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 第四部分:便捷入口(支持直接运行)
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="WaterQuality 授权管理工具")
|
||||
subparsers = parser.add_subparsers(dest="command", help="子命令")
|
||||
|
||||
# 子命令:verify
|
||||
p_verify = subparsers.add_parser("verify", help="验证本机授权")
|
||||
p_verify.add_argument("-f", "--file", default=None, help="授权文件路径")
|
||||
|
||||
# 子命令:gen / generate
|
||||
p_gen = subparsers.add_parser("generate", help="为指定机器码生成授权文件")
|
||||
p_gen.add_argument("-m", "--machine", required=True, help="目标机器的机器码")
|
||||
p_gen.add_argument("-o", "--output", required=True, help="输出文件路径")
|
||||
p_gen.add_argument("-e", "--expiry", default=None, help="有效期截止日期 YYYY-MM-DD")
|
||||
p_gen.add_argument("-n", "--name", default="WaterQualityInversion", help="产品名称")
|
||||
|
||||
# 子命令:info
|
||||
subparsers.add_parser("info", help="显示本机机器信息")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "verify":
|
||||
ok, msg = verify_license(args.file)
|
||||
print(f"[{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
elif args.command == "generate":
|
||||
ok, msg = generate_license(args.machine, args.output, args.expiry, args.name)
|
||||
print(f"[{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
elif args.command == "info":
|
||||
info = get_machine_info()
|
||||
print("=" * 50)
|
||||
print("硬件指纹信息")
|
||||
print("=" * 50)
|
||||
for k, v in info.items():
|
||||
print(f" {k}: {v or '(读取失败)'}")
|
||||
print("=" * 50)
|
||||
# 同时演示验证
|
||||
ok, msg = verify_license()
|
||||
print(f"\n授权验证: [{'OK' if ok else 'FAIL'}] {msg}")
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
54
src/core/algorithms/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""
|
||||
算法层模块
|
||||
包含插值算法和耀斑检测算法等核心数学计算
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
from src.core.algorithms.qaa.qaas_baseline import QAABaselineSolver
|
||||
from src.core.algorithms.concentration_inversion import (
|
||||
ChlorophyllInversion,
|
||||
CDOMInversion,
|
||||
TurbidityInversion,
|
||||
TotalNitrogenInversion,
|
||||
TotalPhosphorusInversion,
|
||||
ConcentrationPipeline,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 插值
|
||||
'interpolate_pixels',
|
||||
'interpolate_zero_pixels_batch',
|
||||
# 耀斑检测
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
# QAA
|
||||
'QAABaselineSolver',
|
||||
# 浓度反演
|
||||
'ChlorophyllInversion',
|
||||
'CDOMInversion',
|
||||
'TurbidityInversion',
|
||||
'TotalNitrogenInversion',
|
||||
'TotalPhosphorusInversion',
|
||||
'ConcentrationPipeline',
|
||||
]
|
||||
670
src/core/algorithms/concentration_inversion.py
Normal file
@ -0,0 +1,670 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
水质浓度反演模块
|
||||
|
||||
基于 QAA Step 8 输出的光谱吸收/散射系数 (a_lambda, bb_lambda),
|
||||
通过生物光学模型反演水质参数浓度。
|
||||
|
||||
主要反演目标:
|
||||
- 叶绿素 A (Chl-a):675nm 吸收峰法
|
||||
- 浊度 (Turbidity):后向散射系数法
|
||||
- CDOM 吸收系数 a_dg(440):指数衰减法
|
||||
- 总氮 (TN) / 总磷 (TP):光学代理回归框架
|
||||
|
||||
参考:
|
||||
- Lee, Z.P. et al. (2002/2010/2014) QAA 系列
|
||||
- Bricaud, A. et al. (1998) Limnol. Oceanogr. — 叶绿素比吸收系数
|
||||
- Carder, K.L. et al. (1999) Marine Technology Society — CDOM 指数衰减
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公共系数表(来自 Bricaud et al. 1998 等文献,内陆水体典型值)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 叶绿素比吸收系数 a*_ph(675) 单位:m²/mg
|
||||
# 随叶绿素浓度范围变化,Bricaud 经验值
|
||||
CHLA_SPECIFIC_ABSORPTION: Dict[str, float] = {
|
||||
"low": 0.055, # 寡营养水体,Chla < 5 mg/m³
|
||||
"medium": 0.040, # 中营养,Chla 5-30 mg/m³
|
||||
"high": 0.028, # 富营养,Chla 30-100 mg/m³
|
||||
"bloom": 0.020, # 藻华,Chla > 100 mg/m³
|
||||
}
|
||||
|
||||
# CDOM 指数衰减斜率 S(单位:nm⁻¹),内陆水体典型范围 0.010-0.025
|
||||
CDOM_S_LOOKUP: Dict[str, float] = {
|
||||
"low_turbidity": 0.010, # 清澈寡营养
|
||||
"medium_turbidity": 0.015, # 中等浊度
|
||||
"high_turbidity": 0.020, # 高浊度富营养
|
||||
"bloom": 0.025, # 藻华主导
|
||||
}
|
||||
|
||||
# 纯水吸收系数表(400-800nm,Babin et al. 2003 简化值,单位:m⁻¹)
|
||||
PURE_WATER_A: Dict[int, float] = {
|
||||
400: 0.0064, 410: 0.0066, 420: 0.0068, 430: 0.0072,
|
||||
440: 0.0080, 450: 0.0092, 460: 0.0105, 470: 0.0120,
|
||||
480: 0.0135, 490: 0.0155, 500: 0.0175, 510: 0.0200,
|
||||
520: 0.0230, 530: 0.0270, 540: 0.0315, 550: 0.0370,
|
||||
560: 0.0435, 570: 0.0510, 580: 0.0600, 590: 0.0710,
|
||||
600: 0.0830, 610: 0.0960, 620: 0.1110, 630: 0.1280,
|
||||
640: 0.1470, 650: 0.1680, 660: 0.1920, 670: 0.2180,
|
||||
675: 0.2450, 680: 0.2750, 690: 0.3100, 700: 0.3500,
|
||||
710: 0.3950, 720: 0.4450, 730: 0.5000, 740: 0.5600,
|
||||
750: 0.6250, 760: 0.6950, 770: 0.7700, 780: 0.8500,
|
||||
790: 0.9300, 800: 1.0100,
|
||||
}
|
||||
|
||||
|
||||
def _interp_pure_water_a(wavelength: float) -> float:
|
||||
"""线性插值获取纯水吸收系数"""
|
||||
wl_int = {k for k in PURE_WATER_A if k <= int(wavelength)}
|
||||
if not wl_int:
|
||||
return PURE_WATER_A[min(PURE_WATER_A.keys())]
|
||||
k_low = max(wl_int)
|
||||
k_high = min({k for k in PURE_WATER_A if k >= int(wavelength)} or {k_low})
|
||||
if k_low == k_high:
|
||||
return float(PURE_WATER_A[k_low])
|
||||
w = (wavelength - k_low) / (k_high - k_low)
|
||||
return float(PURE_WATER_A[k_low]) * (1 - w) + float(PURE_WATER_A[k_high]) * w
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 叶绿素反演器
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class ChlorophyllInversion:
|
||||
"""
|
||||
基于 675nm 吸收峰法的叶绿素 A 浓度反演。
|
||||
|
||||
原理:
|
||||
总吸收 a(675) = a_w(675) + a_ph(675) + a_dg(675)
|
||||
其中 a_ph(675) 是叶绿素特征吸收峰,
|
||||
a_dg(675) ≈ a_dg(440) * exp(-S * (675-440))
|
||||
|
||||
步骤:
|
||||
1. 从 a(λ) 减去纯水吸收 a_w(λ)
|
||||
2. 用线性基线法估算 a_dg(675):baseline(675) = mean[a(665), a(685)]
|
||||
3. a_ph(675) = a(675) - a_w(675) - baseline(675)
|
||||
4. Chla = a_ph(675) / a*_ph(675)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
specific_absorption : float, optional
|
||||
叶绿素比吸收系数 a*_ph(675),单位 m²/mg。
|
||||
若为 None,使用浓度自适应估算逻辑。
|
||||
lake_case : str, optional
|
||||
水体类型标识,用于自动选择比吸收系数,
|
||||
支持 "oligotrophic_clear" / "medium" / "bloom_dominant" / "turbid_mixed"。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specific_absorption: Optional[float] = None,
|
||||
lake_case: Optional[str] = None
|
||||
):
|
||||
self.specific_absorption = specific_absorption
|
||||
self.lake_case = lake_case or "medium"
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
wavelengths: np.ndarray,
|
||||
a_lambda: np.ndarray,
|
||||
bb_lambda: Optional[np.ndarray] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
执行叶绿素 A 反演。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelengths : np.ndarray
|
||||
波长数组(nm),形状 (n_bands,)。
|
||||
a_lambda : np.ndarray
|
||||
总吸收系数 a(λ),形状 (n_bands,)。
|
||||
bb_lambda : np.ndarray, optional
|
||||
后向散射系数(暂未使用,保留扩展接口)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
包含键:
|
||||
- chla_mg_m3 : 叶绿素 A 浓度(mg/m³)
|
||||
- a_ph_675 : 675nm 处叶绿素吸收(m⁻¹)
|
||||
- baseline_675 : 675nm 处 CDOM+NAP 基线(m⁻¹)
|
||||
- a_w_675 : 纯水吸收(m⁻¹)
|
||||
"""
|
||||
wavelengths = np.asarray(wavelengths, dtype=np.float64)
|
||||
a_lambda = np.asarray(a_lambda, dtype=np.float64)
|
||||
|
||||
aw_675 = _interp_pure_water_a(675.0)
|
||||
|
||||
wl_arr = wavelengths
|
||||
a_arr = a_lambda
|
||||
|
||||
a_665 = float(np.interp(665, wl_arr, a_arr, left=np.nan, right=np.nan))
|
||||
a_675 = float(np.interp(675, wl_arr, a_arr, left=np.nan, right=np.nan))
|
||||
a_685 = float(np.interp(685, wl_arr, a_arr, left=np.nan, right=np.nan))
|
||||
|
||||
if not np.isfinite(a_665) or not np.isfinite(a_675) or not np.isfinite(a_685):
|
||||
return {
|
||||
"chla_mg_m3": np.nan,
|
||||
"a_ph_675": np.nan,
|
||||
"baseline_675": np.nan,
|
||||
"a_w_675": aw_675,
|
||||
"warning": "675nm 波段缺失,无法进行叶绿素反演",
|
||||
}
|
||||
|
||||
baseline_675 = (a_665 + a_685) / 2.0
|
||||
a_ph_675 = max(a_675 - aw_675 - baseline_675, 0.0)
|
||||
|
||||
if self.specific_absorption is not None:
|
||||
a_star = self.specific_absorption
|
||||
else:
|
||||
a_star = self._adaptive_specific_absorption(a_ph_675)
|
||||
|
||||
if a_star <= 0:
|
||||
return {
|
||||
"chla_mg_m3": np.nan,
|
||||
"a_ph_675": a_ph_675,
|
||||
"baseline_675": baseline_675,
|
||||
"a_w_675": aw_675,
|
||||
"warning": "比吸收系数为非正值",
|
||||
}
|
||||
|
||||
chla = a_ph_675 / a_star
|
||||
return {
|
||||
"chla_mg_m3": chla,
|
||||
"a_ph_675": a_ph_675,
|
||||
"baseline_675": baseline_675,
|
||||
"a_w_675": aw_675,
|
||||
}
|
||||
|
||||
def _adaptive_specific_absorption(self, a_ph_675: float) -> float:
|
||||
"""根据 a_ph(675) 量级自适应选择比吸收系数"""
|
||||
if a_ph_675 < 0.05:
|
||||
return CHLA_SPECIFIC_ABSORPTION["low"]
|
||||
elif a_ph_675 < 0.2:
|
||||
return CHLA_SPECIFIC_ABSORPTION["medium"]
|
||||
elif a_ph_675 < 0.5:
|
||||
return CHLA_SPECIFIC_ABSORPTION["high"]
|
||||
else:
|
||||
return CHLA_SPECIFIC_ABSORPTION["bloom"]
|
||||
|
||||
def invert_to_csv(
|
||||
self,
|
||||
input_csv: str,
|
||||
output_csv: str,
|
||||
sample_id_col: str = "sample_id"
|
||||
) -> str:
|
||||
"""
|
||||
从 a_lambda_results.csv 批量反演叶绿素并保存结果。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_csv : str
|
||||
Step 8 输出的 a_lambda_results.csv 路径。
|
||||
output_csv : str
|
||||
保存路径。
|
||||
sample_id_col : str
|
||||
样本 ID 列名。
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
输出文件路径。
|
||||
"""
|
||||
df = pd.read_csv(input_csv, encoding="utf-8-sig")
|
||||
df = df.sort_values([sample_id_col, "Wavelength"])
|
||||
|
||||
results = []
|
||||
for sid, group in df.groupby(sample_id_col, sort=False):
|
||||
wl = group["Wavelength"].values.astype(np.float64)
|
||||
a = group["a_lambda"].values.astype(np.float64)
|
||||
res = self.run_inversion(wl, a)
|
||||
res[sample_id_col] = sid
|
||||
results.append(res)
|
||||
|
||||
out_df = pd.DataFrame(results)
|
||||
cols = [sample_id_col, "chla_mg_m3", "a_ph_675", "baseline_675", "a_w_675"]
|
||||
cols = [c for c in cols if c in out_df.columns]
|
||||
out_df = out_df[cols]
|
||||
os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
|
||||
out_df.to_csv(output_csv, index=False, float_format="%.6f")
|
||||
return output_csv
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CDOM 反演器
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class CDOMInversion:
|
||||
"""
|
||||
基于指数衰减模型的 CDOM 吸收系数反演。
|
||||
|
||||
原理:
|
||||
a_dg(λ) = a_dg(λ₀) * exp(-S * (λ - λ₀))
|
||||
|
||||
取 λ₀ = 440nm(蓝光峰),S 由水体类型决定,
|
||||
通过 a(550) ≈ a_w(550) + a_dg(550) 反推 a_dg(440)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
S : float, optional
|
||||
CDOM 指数衰减斜率(nm⁻¹)。若为 None,根据 lake_case 自动选择。
|
||||
reference_wavelength : int
|
||||
参考波长,默认 440nm。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
S: Optional[float] = None,
|
||||
reference_wavelength: int = 440
|
||||
):
|
||||
self.S = S
|
||||
self.ref_wl = reference_wavelength
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
wavelengths: np.ndarray,
|
||||
a_lambda: np.ndarray
|
||||
) -> Dict:
|
||||
"""
|
||||
执行 CDOM 反演。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelengths : np.ndarray
|
||||
波长数组。
|
||||
a_lambda : np.ndarray
|
||||
总吸收系数 a(λ)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
包含键:
|
||||
- a_dg_440 : 440nm 处 CDOM 吸收(m⁻¹)
|
||||
- S : 使用的衰减斜率
|
||||
"""
|
||||
wavelengths = np.asarray(wavelengths, dtype=np.float64)
|
||||
a_lambda = np.asarray(a_lambda, dtype=np.float64)
|
||||
|
||||
if self.S is None:
|
||||
S = CDOM_S_LOOKUP["medium_turbidity"]
|
||||
else:
|
||||
S = self.S
|
||||
|
||||
a_440 = float(np.interp(440, wavelengths, a_lambda, left=np.nan, right=np.nan))
|
||||
a_550 = float(np.interp(550, wavelengths, a_lambda, left=np.nan, right=np.nan))
|
||||
aw_440 = _interp_pure_water_a(440.0)
|
||||
aw_550 = _interp_pure_water_a(550.0)
|
||||
|
||||
a_dg_550 = max(a_550 - aw_550, 0.0)
|
||||
delta_wl = 550 - self.ref_wl
|
||||
a_dg_440 = a_dg_550 * np.exp(S * delta_wl)
|
||||
|
||||
return {
|
||||
"a_dg_440": a_dg_440,
|
||||
"a_dg_550": a_dg_550,
|
||||
"S": S,
|
||||
}
|
||||
|
||||
def invert_to_csv(
|
||||
self,
|
||||
input_csv: str,
|
||||
output_csv: str,
|
||||
sample_id_col: str = "sample_id"
|
||||
) -> str:
|
||||
"""从 a_lambda_results.csv 批量反演 CDOM 并保存结果。"""
|
||||
df = pd.read_csv(input_csv, encoding="utf-8-sig")
|
||||
df = df.sort_values([sample_id_col, "Wavelength"])
|
||||
|
||||
results = []
|
||||
for sid, group in df.groupby(sample_id_col, sort=False):
|
||||
wl = group["Wavelength"].values.astype(np.float64)
|
||||
a = group["a_lambda"].values.astype(np.float64)
|
||||
res = self.run_inversion(wl, a)
|
||||
res[sample_id_col] = sid
|
||||
results.append(res)
|
||||
|
||||
out_df = pd.DataFrame(results)
|
||||
cols = [sample_id_col, "a_dg_440", "a_dg_550", "S"]
|
||||
cols = [c for c in cols if c in out_df.columns]
|
||||
out_df = out_df[cols]
|
||||
os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
|
||||
out_df.to_csv(output_csv, index=False, float_format="%.6f")
|
||||
return output_csv
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 浊度反演器
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TurbidityInversion:
|
||||
"""
|
||||
基于后向散射系数的光学浊度反演。
|
||||
|
||||
原理(简化模型):
|
||||
Turbidity (NTU) ≈ k * b_b(550)
|
||||
|
||||
其中 b_b(550) 是 550nm 处的后向散射系数,
|
||||
k 为经验系数(内陆水体典型值 1.0-3.0)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
k : float
|
||||
经验系数。默认值 2.0。
|
||||
reference_wavelength : int
|
||||
参考波段,默认 550nm。
|
||||
"""
|
||||
|
||||
def __init__(self, k: float = 2.0, reference_wavelength: int = 550):
|
||||
self.k = k
|
||||
self.ref_wl = reference_wavelength
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
wavelengths: np.ndarray,
|
||||
bb_lambda: np.ndarray
|
||||
) -> Dict:
|
||||
"""
|
||||
执行浊度反演。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelengths : np.ndarray
|
||||
波长数组。
|
||||
bb_lambda : np.ndarray
|
||||
后向散射系数 b_b(λ)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
包含键:
|
||||
- turbidity_ntu : 浊度(NTU)
|
||||
- bb_ref : 参考波段处的 b_b 值
|
||||
"""
|
||||
wavelengths = np.asarray(wavelengths, dtype=np.float64)
|
||||
bb_lambda = np.asarray(bb_lambda, dtype=np.float64)
|
||||
|
||||
bb_ref = float(np.interp(
|
||||
self.ref_wl, wavelengths, bb_lambda, left=np.nan, right=np.nan
|
||||
))
|
||||
turbidity = self.k * bb_ref
|
||||
|
||||
return {
|
||||
"turbidity_ntu": turbidity,
|
||||
"bb_ref": bb_ref,
|
||||
}
|
||||
|
||||
def invert_to_csv(
|
||||
self,
|
||||
input_csv: str,
|
||||
output_csv: str,
|
||||
sample_id_col: str = "sample_id"
|
||||
) -> str:
|
||||
"""从 a_lambda_results.csv 批量反演浊度并保存结果。"""
|
||||
df = pd.read_csv(input_csv, encoding="utf-8-sig")
|
||||
if "bb_lambda" not in df.columns:
|
||||
raise ValueError("输入 CSV 中缺少 bb_lambda 列")
|
||||
df = df.sort_values([sample_id_col, "Wavelength"])
|
||||
|
||||
results = []
|
||||
for sid, group in df.groupby(sample_id_col, sort=False):
|
||||
wl = group["Wavelength"].values.astype(np.float64)
|
||||
bb = group["bb_lambda"].values.astype(np.float64)
|
||||
res = self.run_inversion(wl, bb)
|
||||
res[sample_id_col] = sid
|
||||
results.append(res)
|
||||
|
||||
out_df = pd.DataFrame(results)
|
||||
cols = [sample_id_col, "turbidity_ntu", "bb_ref"]
|
||||
cols = [c for c in cols if c in out_df.columns]
|
||||
out_df = out_df[cols]
|
||||
os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
|
||||
out_df.to_csv(output_csv, index=False, float_format="%.6f")
|
||||
return output_csv
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 总氮 / 总磷反演器(光学代理回归框架)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TotalNitrogenInversion:
|
||||
"""
|
||||
总氮 (TN) 光学代理回归模型。
|
||||
|
||||
框架说明:
|
||||
TN 与 Chla 之间通常存在正相关(R² ≈ 0.5-0.7),
|
||||
本类提供回归框架,实际系数需由实测数据标定。
|
||||
|
||||
公式(线性代理):
|
||||
TN (mg/L) = α * Chla + β * Turbidity + γ
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alpha : float
|
||||
Chla 系数。默认 0.05。
|
||||
beta : float
|
||||
浊度系数。默认 0.10。
|
||||
gamma : float
|
||||
截距。默认 0.20。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float = 0.05,
|
||||
beta: float = 0.10,
|
||||
gamma: float = 0.20
|
||||
):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
chla_mg_m3: float,
|
||||
turbidity_ntu: float
|
||||
) -> Dict:
|
||||
"""执行总氮反演(光学代理法)。"""
|
||||
tn = self.alpha * chla_mg_m3 + self.beta * turbidity_ntu + self.gamma
|
||||
return {"tn_mg_L": tn}
|
||||
|
||||
def calibrate(
|
||||
self,
|
||||
samples: List[Dict]
|
||||
) -> None:
|
||||
"""
|
||||
用实测样本标定回归系数。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samples : list[dict]
|
||||
样本列表,每项包含 'chla', 'turbidity', 'tn' 键。
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
X = np.array([[s["chla"], s["turbidity"]] for s in samples])
|
||||
y = np.array([s["tn"] for s in samples])
|
||||
coeffs, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
|
||||
self.alpha, self.beta = coeffs
|
||||
self.gamma = float(np.mean(y - self.alpha * X[:, 0] - self.beta * X[:, 1]))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"标定失败: {e}")
|
||||
|
||||
|
||||
class TotalPhosphorusInversion:
|
||||
"""
|
||||
总磷 (TP) 光学代理回归模型。
|
||||
|
||||
框架说明:
|
||||
TP 与 Chla / 浊度均相关(湖泊富营养化阶段尤为明显),
|
||||
提供双变量线性回归框架,实际系数需由实测数据标定。
|
||||
|
||||
公式(线性代理):
|
||||
TP (mg/L) = α * Chla + β * Turbidity + γ
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alpha : float
|
||||
Chla 系数。默认 0.002。
|
||||
beta : float
|
||||
浊度系数。默认 0.005。
|
||||
gamma : float
|
||||
截距。默认 0.010。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha: float = 0.002,
|
||||
beta: float = 0.005,
|
||||
gamma: float = 0.010
|
||||
):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
chla_mg_m3: float,
|
||||
turbidity_ntu: float
|
||||
) -> Dict:
|
||||
"""执行总磷反演(光学代理法)。"""
|
||||
tp = self.alpha * chla_mg_m3 + self.beta * turbidity_ntu + self.gamma
|
||||
return {"tp_mg_L": tp}
|
||||
|
||||
def calibrate(
|
||||
self,
|
||||
samples: List[Dict]
|
||||
) -> None:
|
||||
"""用实测样本标定回归系数。"""
|
||||
try:
|
||||
import numpy as np
|
||||
X = np.array([[s["chla"], s["turbidity"]] for s in samples])
|
||||
y = np.array([s["tp"] for s in samples])
|
||||
coeffs, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
|
||||
self.alpha, self.beta = coeffs
|
||||
self.gamma = float(np.mean(y - self.alpha * X[:, 0] - self.beta * X[:, 1]))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"标定失败: {e}")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 一站式浓度反演流水线
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class ConcentrationPipeline:
|
||||
"""
|
||||
整合 Chlorophyll / CDOM / Turbidity / TN / TP 反演的一站式流水线。
|
||||
|
||||
接收 Step 8 输出的 a_lambda_results.csv,
|
||||
输出 final_concentrations.csv(含所有水质参数浓度列)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lake_case : str, optional
|
||||
水体类型,用于 Chla 比吸收系数自适应选择。
|
||||
S_cdom : float, optional
|
||||
CDOM 衰减斜率(若为 None,自动选择)。
|
||||
k_turbidity : float
|
||||
浊度经验系数。
|
||||
tn_params : dict, optional
|
||||
总氮反演初始参数。
|
||||
tp_params : dict, optional
|
||||
总磷反演初始参数。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lake_case: str = "medium",
|
||||
S_cdom: Optional[float] = None,
|
||||
k_turbidity: float = 2.0,
|
||||
tn_params: Optional[Dict] = None,
|
||||
tp_params: Optional[Dict] = None,
|
||||
):
|
||||
self.lake_case = lake_case
|
||||
self.chla_inv = ChlorophyllInversion(lake_case=lake_case)
|
||||
self.cdom_inv = CDOMInversion(S=S_cdom)
|
||||
self.turb_inv = TurbidityInversion(k=k_turbidity)
|
||||
self.tn_inv = TotalNitrogenInversion(**(tn_params or {}))
|
||||
self.tp_inv = TotalPhosphorusInversion(**(tp_params or {}))
|
||||
|
||||
def run_pipeline(
|
||||
self,
|
||||
input_csv: str,
|
||||
output_csv: str,
|
||||
sample_id_col: str = "sample_id"
|
||||
) -> str:
|
||||
"""
|
||||
执行完整浓度反演流水线。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_csv : str
|
||||
Step 8 输出的 a_lambda_results.csv 路径。
|
||||
output_csv : str
|
||||
输出 final_concentrations.csv 路径。
|
||||
sample_id_col : str
|
||||
样本 ID 列名。
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
输出文件路径。
|
||||
"""
|
||||
df = pd.read_csv(input_csv, encoding="utf-8-sig")
|
||||
if "bb_lambda" not in df.columns:
|
||||
df["bb_lambda"] = np.nan
|
||||
|
||||
# ── 保留原始坐标列:按 sample_id 取第一条记录的非光谱列 ───────────
|
||||
wl_col = "Wavelength"
|
||||
coord_meta_cols = [c for c in df.columns if c not in (sample_id_col, wl_col, "a_lambda", "bb_lambda")]
|
||||
coord_df = df.groupby(sample_id_col, sort=False)[coord_meta_cols].first().reset_index()
|
||||
|
||||
df = df.sort_values([sample_id_col, "Wavelength"])
|
||||
|
||||
results = []
|
||||
for sid, group in df.groupby(sample_id_col, sort=False):
|
||||
wl = group["Wavelength"].values.astype(np.float64)
|
||||
a = group["a_lambda"].values.astype(np.float64)
|
||||
bb = group["bb_lambda"].values.astype(np.float64) \
|
||||
if "bb_lambda" in group.columns and group["bb_lambda"].notna().any() \
|
||||
else None
|
||||
|
||||
chla_res = self.chla_inv.run_inversion(wl, a)
|
||||
cdom_res = self.cdom_inv.run_inversion(wl, a)
|
||||
if bb is not None and np.any(np.isfinite(bb)):
|
||||
turb_res = self.turb_inv.run_inversion(wl, bb)
|
||||
else:
|
||||
turb_res = {"turbidity_ntu": np.nan, "bb_ref": np.nan}
|
||||
|
||||
chla_val = chla_res.get("chla_mg_m3", np.nan)
|
||||
turb_val = turb_res.get("turbidity_ntu", np.nan)
|
||||
|
||||
tn_res = self.tn_inv.run_inversion(chla_val, turb_val)
|
||||
tp_res = self.tp_inv.run_inversion(chla_val, turb_val)
|
||||
|
||||
row = {
|
||||
sample_id_col: sid,
|
||||
"Chla_mg_m3": chla_val,
|
||||
"a_ph_675_m1": chla_res.get("a_ph_675", np.nan),
|
||||
"CDOM_a_dg_440_m1": cdom_res.get("a_dg_440", np.nan),
|
||||
"Turbidity_NTU": turb_val,
|
||||
"TN_mg_L": tn_res.get("tn_mg_L", np.nan),
|
||||
"TP_mg_L": tp_res.get("tp_mg_L", np.nan),
|
||||
}
|
||||
results.append(row)
|
||||
|
||||
out_df = pd.DataFrame(results)
|
||||
# ── 将原始坐标列按 sample_id 合并到浓度结果左侧 ───────────────────
|
||||
if not coord_df.empty and sample_id_col in coord_df.columns:
|
||||
out_df = coord_df.merge(out_df, on=sample_id_col, how="left")
|
||||
os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
|
||||
out_df.to_csv(output_csv, index=False, float_format="%.6f")
|
||||
return output_csv
|
||||
31
src/core/algorithms/glint_detection/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
包含各种耀斑检测的核心数学计算函数
|
||||
"""
|
||||
from src.core.algorithms.glint_detection.detectors import (
|
||||
otsu_threshold,
|
||||
zscore_threshold,
|
||||
percentile_threshold,
|
||||
iqr_outlier_detection,
|
||||
adaptive_threshold,
|
||||
multi_band_glint_detection,
|
||||
percentile_stretch,
|
||||
filter_large_components,
|
||||
create_shoreline_buffer,
|
||||
remove_shoreline_buffer,
|
||||
calculate_glint_mask,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'otsu_threshold',
|
||||
'zscore_threshold',
|
||||
'percentile_threshold',
|
||||
'iqr_outlier_detection',
|
||||
'adaptive_threshold',
|
||||
'multi_band_glint_detection',
|
||||
'percentile_stretch',
|
||||
'filter_large_components',
|
||||
'create_shoreline_buffer',
|
||||
'remove_shoreline_buffer',
|
||||
'calculate_glint_mask',
|
||||
]
|
||||
595
src/core/algorithms/glint_detection/detectors.py
Normal file
@ -0,0 +1,595 @@
|
||||
"""
|
||||
耀斑检测算法模块
|
||||
|
||||
包含各种耀斑检测的核心数学计算函数,纯数学逻辑,不涉及文件I/O。
|
||||
支持的方法:otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
|
||||
本模块是从 src/utils/find_severe_glint_area.py 抽取出来的核心算法部分。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
import cv2
|
||||
CV2_AVAILABLE = True
|
||||
except ImportError:
|
||||
CV2_AVAILABLE = False
|
||||
|
||||
|
||||
def timeit(func):
|
||||
"""装饰器:测量函数执行时间"""
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
import time
|
||||
start = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
print(f"[{func.__name__}] 耗时: {end - start:.2f}s")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数拉伸
|
||||
# =============================================================================
|
||||
|
||||
def percentile_stretch(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
lower_percentile: float = 2,
|
||||
upper_percentile: float = 98,
|
||||
output_range: Tuple[int, int] = (0, 255)
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用百分位数裁剪进行归一化,适用于低反射率数据
|
||||
通过排除极值,更好地利用数据的动态范围
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(反射率值,通常在0-1之间)
|
||||
data_water_mask: 水域掩膜
|
||||
lower_percentile: 下百分位数,用于裁剪最小值(默认2)
|
||||
upper_percentile: 上百分位数,用于裁剪最大值(默认98)
|
||||
output_range: 输出范围,默认(0, 255)
|
||||
|
||||
Returns:
|
||||
归一化后的图像数组(整数类型)
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return img.astype(np.int32)
|
||||
|
||||
p_lower = np.percentile(valid_pixels, lower_percentile)
|
||||
p_upper = np.percentile(valid_pixels, upper_percentile)
|
||||
|
||||
if p_lower >= p_upper:
|
||||
p_lower = np.percentile(valid_pixels, 1)
|
||||
p_upper = np.percentile(valid_pixels, 99)
|
||||
if p_lower >= p_upper:
|
||||
p_upper = valid_pixels.max()
|
||||
p_lower = valid_pixels.min()
|
||||
|
||||
img_clipped = np.clip(img, p_lower, p_upper)
|
||||
|
||||
if p_upper > p_lower:
|
||||
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (
|
||||
output_range[1] - output_range[0]
|
||||
) + output_range[0]
|
||||
else:
|
||||
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
|
||||
|
||||
return img_stretched.astype(np.int32)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Otsu阈值分割
|
||||
# =============================================================================
|
||||
|
||||
def otsu_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
ignore_value: int = 0,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Otsu方法的自动阈值分割
|
||||
通过最大化类间方差找到最佳分割阈值
|
||||
|
||||
Args:
|
||||
img: 输入图像数组(整数值)
|
||||
data_water_mask: 水域掩膜
|
||||
ignore_value: 忽略的值(默认为0)
|
||||
foreground: 耀斑区域值(默认1)
|
||||
background: 背景值(默认0)
|
||||
|
||||
Returns:
|
||||
二值化检测结果数组
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
max_value = int(np.max(img[img > ignore_value])) + 1
|
||||
if max_value < 2:
|
||||
max_value = 256
|
||||
|
||||
hist = np.zeros([max_value], np.float32)
|
||||
|
||||
invalid_counter = 0
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
|
||||
invalid_counter += 1
|
||||
continue
|
||||
hist[img[i, j]] += 1
|
||||
|
||||
total_valid = height * width - invalid_counter
|
||||
if total_valid <= 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
hist /= total_valid
|
||||
|
||||
threshold = 0
|
||||
deltaMax = 0
|
||||
|
||||
for i in range(max_value):
|
||||
wA = sum(hist[:i + 1])
|
||||
wB = sum(hist[i + 1:])
|
||||
if wA == 0:
|
||||
wA = 1e-10
|
||||
if wB == 0:
|
||||
wB = 1e-10
|
||||
|
||||
uAtmp = sum(j * hist[j] for j in range(i + 1))
|
||||
uBtmp = sum(j * hist[j] for j in range(i + 1, max_value))
|
||||
uA = uAtmp / wA
|
||||
uB = uBtmp / wB
|
||||
u = uAtmp + uBtmp
|
||||
|
||||
deltaTmp = wA * ((uA - u) ** 2) + wB * ((uB - u) ** 2)
|
||||
if deltaTmp > deltaMax:
|
||||
deltaMax = deltaTmp
|
||||
threshold = i
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z-score阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def zscore_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
z_threshold: float = 2.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于Z-score(标准化分数)的耀斑检测方法
|
||||
使用统计方法识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
z_threshold: Z-score阈值,默认2.5(即超过均值2.5个标准差)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
mean_val = np.mean(valid_pixels)
|
||||
std_val = np.std(valid_pixels)
|
||||
|
||||
if std_val == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
z_scores = np.zeros_like(img, dtype=np.float32)
|
||||
valid_mask = (data_water_mask > 0) & np.isfinite(img)
|
||||
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[z_scores > z_threshold] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 百分位数阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def percentile_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
percentile: float = 95,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于百分位数的耀斑检测方法
|
||||
使用百分位数作为阈值,对异常值更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
percentile: 百分位数阈值,默认95(即超过95%的像素值)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
threshold_val = np.percentile(valid_pixels, percentile)
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > threshold_val] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IQR异常值检测
|
||||
# =============================================================================
|
||||
|
||||
def iqr_outlier_detection(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
iqr_multiplier: float = 1.5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
基于IQR(四分位距)的异常值检测方法
|
||||
使用四分位距识别异常高亮的像素,对数据分布不敏感
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
iqr_multiplier: IQR倍数,默认1.5(标准异常值检测)
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
|
||||
|
||||
if len(valid_pixels) == 0:
|
||||
return np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
q1 = np.percentile(valid_pixels, 25)
|
||||
q3 = np.percentile(valid_pixels, 75)
|
||||
iqr = q3 - q1
|
||||
|
||||
upper_bound = q3 + iqr_multiplier * iqr
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
det_img[img > upper_bound] = foreground
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 自适应阈值检测
|
||||
# =============================================================================
|
||||
|
||||
def adaptive_threshold(
|
||||
img: np.ndarray,
|
||||
data_water_mask: np.ndarray,
|
||||
window_size: int = 15,
|
||||
percentile: float = 90,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
自适应阈值方法
|
||||
基于局部统计特性进行阈值分割,对光照变化更稳健
|
||||
|
||||
Args:
|
||||
img: 输入图像数组
|
||||
data_water_mask: 水域掩膜
|
||||
window_size: 局部窗口大小(奇数)
|
||||
percentile: 局部百分位数阈值
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
height, width = img.shape
|
||||
|
||||
if window_size % 2 == 0:
|
||||
window_size += 1
|
||||
|
||||
half_window = window_size // 2
|
||||
|
||||
det_img = np.zeros_like(img, dtype=np.int32)
|
||||
|
||||
for i in range(half_window, height - half_window):
|
||||
for j in range(half_window, width - half_window):
|
||||
if data_water_mask[i, j] == 0:
|
||||
continue
|
||||
|
||||
local_window = img[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
local_mask = data_water_mask[i - half_window:i + half_window + 1,
|
||||
j - half_window:j + half_window + 1]
|
||||
|
||||
valid_pixels = local_window[local_mask > 0]
|
||||
|
||||
if len(valid_pixels) > 0:
|
||||
local_th = np.percentile(valid_pixels, percentile)
|
||||
if img[i, j] > local_th:
|
||||
det_img[i, j] = foreground
|
||||
|
||||
det_img[data_water_mask == 0] = background
|
||||
|
||||
return det_img
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 多波段融合耀斑检测
|
||||
# =============================================================================
|
||||
|
||||
def multi_band_glint_detection(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
glint_waves: List[float],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'zscore',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
sub_band_arrays: Optional[List[np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
多波段融合的耀斑检测方法
|
||||
结合多个波段的耀斑特征,提高检测的稳健性
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组(主波段,用于兼容性)
|
||||
water_mask: 水域掩膜数组
|
||||
glint_waves: 用于检测的波长列表,如[750, 800, 850]
|
||||
weights: 各波段的权重,如果为None则使用等权重
|
||||
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
|
||||
z_threshold: Z-score阈值(当method='zscore'时使用)
|
||||
percentile: 百分位数阈值(当method='percentile'时使用)
|
||||
sub_band_arrays: 子波段数组列表(如果提供,与 glint_waves 一一对应)
|
||||
|
||||
Returns:
|
||||
二值化检测结果
|
||||
"""
|
||||
if weights is None:
|
||||
weights = [1.0 / len(glint_waves)] * len(glint_waves)
|
||||
|
||||
if len(weights) != len(glint_waves):
|
||||
raise ValueError("权重数量必须与波长数量相同")
|
||||
|
||||
fused_band = None
|
||||
|
||||
if sub_band_arrays is not None and len(sub_band_arrays) == len(glint_waves):
|
||||
for i, band_array in enumerate(sub_band_arrays):
|
||||
if fused_band is None:
|
||||
fused_band = (band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = (fused_band + band_array * weights[i]).astype(np.float32)
|
||||
else:
|
||||
fused_band = nir_band.astype(np.float32)
|
||||
|
||||
if method == 'otsu':
|
||||
stretched = percentile_stretch(fused_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(fused_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(fused_band, water_mask, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 连通域过滤
|
||||
# =============================================================================
|
||||
|
||||
def filter_large_components(
|
||||
binary_img: np.ndarray,
|
||||
max_area: Optional[int] = None,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
过滤掉面积超过阈值的连通域
|
||||
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
|
||||
|
||||
Args:
|
||||
binary_img: 二值化图像
|
||||
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
过滤后的二值化图像
|
||||
"""
|
||||
if max_area is None or max_area <= 0:
|
||||
return binary_img
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
binary_for_label = (binary_img == foreground).astype(np.uint8)
|
||||
num_features, labeled_array, stats, _ = cv2.connectedComponentsWithStats(
|
||||
binary_for_label, connectivity=8
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = stats[1:, cv2.CC_STAT_AREA]
|
||||
keep_labels = np.where(component_sizes <= max_area)[0] + 1
|
||||
|
||||
keep_mask = np.isin(labeled_array, keep_labels)
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
else:
|
||||
from scipy import ndimage
|
||||
labeled_array, num_features = ndimage.label(
|
||||
(binary_img == foreground).astype(np.int32)
|
||||
)
|
||||
|
||||
if num_features == 0:
|
||||
return binary_img
|
||||
|
||||
component_sizes = ndimage.sum(
|
||||
(labeled_array == i).astype(np.int32),
|
||||
labeled_array,
|
||||
range(1, num_features + 1)
|
||||
)
|
||||
|
||||
keep_mask = np.isin(labeled_array, [i + 1 for i, s in enumerate(component_sizes) if s <= max_area])
|
||||
filtered = np.zeros_like(binary_img, dtype=binary_img.dtype)
|
||||
filtered[keep_mask] = foreground
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 岸边缓冲区处理
|
||||
# =============================================================================
|
||||
|
||||
def create_shoreline_buffer(
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
创建岸边缓冲区掩膜(向内缓冲)
|
||||
用于去除岸边附近的错误耀斑检测区域
|
||||
|
||||
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
|
||||
|
||||
Args:
|
||||
water_mask: 水域掩膜数组(水域=1,非水域=0)
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
岸边缓冲区掩膜(缓冲区区域=1,其他=0)
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return np.zeros_like(water_mask, dtype=np.int32)
|
||||
|
||||
water_binary = (water_mask > 0).astype(np.uint8)
|
||||
structure_size = buffer_size * 2 + 1
|
||||
structure = np.ones((structure_size, structure_size), dtype=np.uint8)
|
||||
|
||||
if CV2_AVAILABLE:
|
||||
eroded_water = cv2.erode(water_binary, structure).astype(np.int32)
|
||||
else:
|
||||
from scipy import ndimage
|
||||
eroded_water = ndimage.binary_erosion(water_binary, structure).astype(np.int32)
|
||||
|
||||
buffer_mask = (water_binary - eroded_water).astype(np.int32)
|
||||
|
||||
return buffer_mask
|
||||
|
||||
|
||||
def remove_shoreline_buffer(
|
||||
glint_mask: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
buffer_size: int = 5,
|
||||
foreground: int = 1,
|
||||
background: int = 0
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
从耀斑掩膜中去除岸边缓冲区内的区域
|
||||
|
||||
Args:
|
||||
glint_mask: 耀斑掩膜数组
|
||||
water_mask: 水域掩膜数组
|
||||
buffer_size: 缓冲区大小(像素数),默认5像素
|
||||
foreground: 前景值
|
||||
background: 背景值
|
||||
|
||||
Returns:
|
||||
去除岸边缓冲区后的耀斑掩膜
|
||||
"""
|
||||
if buffer_size <= 0:
|
||||
return glint_mask
|
||||
|
||||
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
|
||||
|
||||
cleaned = glint_mask.copy()
|
||||
cleaned[buffer_mask > 0] = background
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 高级组合函数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_glint_mask(
|
||||
nir_band: np.ndarray,
|
||||
water_mask: np.ndarray,
|
||||
method: str = 'otsu',
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
apply_percentile_stretch: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算耀斑掩膜的统一入口函数
|
||||
|
||||
Args:
|
||||
nir_band: 近红外波段数组
|
||||
water_mask: 水域掩膜
|
||||
method: 检测方法 ('otsu', 'zscore', 'percentile', 'iqr', 'adaptive')
|
||||
z_threshold: Z-score阈值
|
||||
percentile: 百分位数阈值
|
||||
iqr_multiplier: IQR倍数
|
||||
window_size: 自适应阈值窗口大小
|
||||
apply_percentile_stretch: 是否对otsu和adaptive方法应用百分位数拉伸
|
||||
|
||||
Returns:
|
||||
二值化耀斑掩膜
|
||||
"""
|
||||
if method == 'otsu':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return otsu_threshold(stretched, water_mask)
|
||||
else:
|
||||
return otsu_threshold(nir_band.astype(np.int32), water_mask)
|
||||
elif method == 'zscore':
|
||||
return zscore_threshold(nir_band, water_mask, z_threshold)
|
||||
elif method == 'percentile':
|
||||
return percentile_threshold(nir_band, water_mask, percentile)
|
||||
elif method == 'iqr':
|
||||
return iqr_outlier_detection(nir_band, water_mask, iqr_multiplier)
|
||||
elif method == 'adaptive':
|
||||
if apply_percentile_stretch:
|
||||
stretched = percentile_stretch(nir_band, water_mask, 2, 98)
|
||||
return adaptive_threshold(stretched, water_mask, window_size, percentile)
|
||||
else:
|
||||
return adaptive_threshold(nir_band.astype(np.int32), water_mask, window_size, percentile)
|
||||
else:
|
||||
raise ValueError(f"不支持的方法: {method}")
|
||||
7
src/core/algorithms/interpolation/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""
|
||||
插值算法模块
|
||||
包含0值像素插值的核心数学逻辑
|
||||
"""
|
||||
from src.core.algorithms.interpolation.interpolator import interpolate_pixels, interpolate_zero_pixels_batch
|
||||
|
||||
__all__ = ['interpolate_pixels', 'interpolate_zero_pixels_batch']
|
||||
602
src/core/algorithms/interpolation/interpolator.py
Normal file
@ -0,0 +1,602 @@
|
||||
"""
|
||||
像素插值算法模块
|
||||
|
||||
提供对影像中所有波段都为0的像素点进行插值的核心数学逻辑。
|
||||
支持多种插值方法:nearest, bilinear, spline (RBF), kriging。
|
||||
|
||||
本模块使用多进程并行分块 IO 加速(Plan A):
|
||||
- ProcessPoolExecutor 为每个 worker 进程打开一次源影像(initializer 阶段),
|
||||
避免每块重复 gdal.Open 带来的开销(Windows 上 ~50ms/次)
|
||||
- 主进程统一负责输出文件的写入,避免多进程写锁竞争
|
||||
- 分块大小(block_size)默认 1024,内存充足可调至 2048 / 4096
|
||||
|
||||
注意:
|
||||
- GDAL Dataset / Rasterio Dataset 对象不能跨进程传递(picking 不支持),
|
||||
所以 worker 必须在 init 阶段自己独立打开源文件
|
||||
- 每个 worker 强制设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 多线程
|
||||
造成的 CPU 过订阅
|
||||
- 关闭多进程:传 ``use_multiprocessing=False`` 或 ``n_workers=1``
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import griddata, RBFInterpolator
|
||||
from scipy.spatial import cKDTree
|
||||
SCIPY_AVAILABLE = True
|
||||
except ImportError:
|
||||
SCIPY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
_worker_dataset: Optional["gdal.Dataset"] = None
|
||||
|
||||
|
||||
def interpolate_pixels(
|
||||
image_stack: np.ndarray,
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
interpolation_method: str = 'nearest',
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
对指定坐标的像素进行插值(核心数学函数,不涉及文件I/O)
|
||||
|
||||
Args:
|
||||
image_stack: 影像数据堆叠,形状为 (height, width, n_bands) 的 float32 数组
|
||||
zero_coords: 需要插值的像素坐标,形状为 (n_zero, 2),每行是 [x, y]
|
||||
valid_coords: 有效像素坐标,形状为 (n_valid, 2)
|
||||
valid_values: 有效像素对应的值,形状为 (n_valid,) 或 (n_valid, n_bands)
|
||||
interpolation_method: 插值方法,可选 'nearest', 'bilinear', 'spline', 'kriging'
|
||||
water_mask: 可选的水域掩膜数组
|
||||
|
||||
Returns:
|
||||
插值后的影像副本,形状与 image_stack 相同
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
|
||||
height, width, n_bands = image_stack.shape
|
||||
result = image_stack.copy()
|
||||
|
||||
raw_method = str(interpolation_method).lower()
|
||||
if 'nearest' in raw_method or '邻近' in raw_method or '最邻近' in raw_method:
|
||||
method = 'nearest'
|
||||
elif 'bilinear' in raw_method or '线性' in raw_method or '双线性' in raw_method:
|
||||
method = 'bilinear'
|
||||
elif 'spline' in raw_method or '样条' in raw_method or 'rbf' in raw_method:
|
||||
method = 'spline'
|
||||
elif 'kriging' in raw_method or '克里金' in raw_method:
|
||||
method = 'kriging'
|
||||
else:
|
||||
method = 'nearest'
|
||||
|
||||
if len(valid_values) == 0:
|
||||
return result
|
||||
|
||||
is_multiband = len(valid_values.shape) > 1 and valid_values.shape[1] > 1
|
||||
|
||||
if is_multiband:
|
||||
for band_idx in range(n_bands):
|
||||
band_valid_values = valid_values[:, band_idx]
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, band_valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords, band_idx] = interpolated_values
|
||||
else:
|
||||
interpolated_values = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values, method
|
||||
)
|
||||
y_coords = zero_coords[:, 1].astype(int)
|
||||
x_coords = zero_coords[:, 0].astype(int)
|
||||
result[y_coords, x_coords] = interpolated_values
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _interpolate_single_band(
|
||||
zero_coords: np.ndarray,
|
||||
valid_coords: np.ndarray,
|
||||
valid_values: np.ndarray,
|
||||
method: str
|
||||
) -> np.ndarray:
|
||||
"""对单个波段执行插值计算"""
|
||||
if method == 'nearest':
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords)
|
||||
return valid_values[indices]
|
||||
|
||||
elif method == 'bilinear':
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'spline':
|
||||
try:
|
||||
max_points = 10000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
rbf = RBFInterpolator(sample_coords, sample_values, kernel='thin_plate_spline')
|
||||
interpolated = rbf(zero_coords)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
elif method == 'kriging':
|
||||
try:
|
||||
from src.utils.kriging import KrigingInterpolator
|
||||
interpolator = KrigingInterpolator()
|
||||
max_points = 5000
|
||||
if len(valid_values) > max_points:
|
||||
indices = np.random.choice(len(valid_values), max_points, replace=False)
|
||||
sample_coords = valid_coords[indices]
|
||||
sample_values = valid_values[indices]
|
||||
else:
|
||||
sample_coords = valid_coords
|
||||
sample_values = valid_values
|
||||
interpolated = griddata(
|
||||
sample_coords, sample_values, zero_coords,
|
||||
method='cubic', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
except Exception:
|
||||
interpolated = griddata(
|
||||
valid_coords, valid_values, zero_coords,
|
||||
method='linear', fill_value=0.0
|
||||
)
|
||||
nan_mask = np.isnan(interpolated)
|
||||
if np.any(nan_mask):
|
||||
tree = cKDTree(valid_coords)
|
||||
_, indices = tree.query(zero_coords[nan_mask])
|
||||
interpolated[nan_mask] = valid_values[indices]
|
||||
return interpolated
|
||||
|
||||
return np.zeros(len(zero_coords))
|
||||
|
||||
|
||||
def _normalize_interpolation_method(method: str) -> str:
|
||||
"""将中文/英文混用的插值方法名归一化为内部标准名
|
||||
|
||||
支持: 'nearest'/'邻近'/'最邻近','bilinear'/'线性'/'双线性',
|
||||
'spline'/'样条'/'rbf','kriging'/'克里金'。
|
||||
"""
|
||||
raw = str(method).lower()
|
||||
if 'nearest' in raw or '邻近' in raw or '最邻近' in raw:
|
||||
return 'nearest'
|
||||
if 'bilinear' in raw or '线性' in raw or '双线性' in raw:
|
||||
return 'bilinear'
|
||||
if 'spline' in raw or '样条' in raw or 'rbf' in raw:
|
||||
return 'spline'
|
||||
if 'kriging' in raw or '克里金' in raw:
|
||||
return 'kriging'
|
||||
return 'nearest'
|
||||
|
||||
|
||||
def _read_water_mask_to_array(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
expected_height: int,
|
||||
expected_width: int,
|
||||
) -> Optional[np.ndarray]:
|
||||
"""读取水域掩膜为 numpy 数组(单波段,bool/int 均可)
|
||||
|
||||
None 或空字符串直接返回 None。形状不匹配时给出告警但不抛错,
|
||||
让调用方按"无掩膜"路径继续。
|
||||
"""
|
||||
if water_mask is None:
|
||||
return None
|
||||
if isinstance(water_mask, str):
|
||||
if not water_mask.strip():
|
||||
return None
|
||||
mask_ds = gdal.Open(water_mask, gdal.GA_ReadOnly)
|
||||
if mask_ds is None:
|
||||
print(f" [warn] 无法打开水域掩膜 {water_mask},按无掩膜处理")
|
||||
return None
|
||||
try:
|
||||
mask_array = mask_ds.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_ds = None
|
||||
elif isinstance(water_mask, np.ndarray):
|
||||
mask_array = water_mask
|
||||
else:
|
||||
return None
|
||||
|
||||
if mask_array.shape != (expected_height, expected_width):
|
||||
print(
|
||||
f" [warn] 水域掩膜形状 {mask_array.shape} 与影像 "
|
||||
f"({expected_height}, {expected_width}) 不匹配,按无掩膜处理"
|
||||
)
|
||||
return None
|
||||
return mask_array
|
||||
|
||||
|
||||
def _init_worker(img_path: str) -> None:
|
||||
"""ProcessPoolExecutor initializer: 每个 worker 进程只调用一次
|
||||
|
||||
在 worker 进程启动时打开源影像 dataset 并缓存在模块全局变量
|
||||
``_worker_dataset`` 中。后续所有块处理直接复用这个 dataset,
|
||||
避免每块重复 ``gdal.Open``(Windows 上约 50ms/次,100 块即 5s)。
|
||||
|
||||
同时设置 ``GDAL_NUM_THREADS=1``,避免 8 worker × GDAL 默认多线程
|
||||
造成的 CPU 过订阅。
|
||||
"""
|
||||
global _worker_dataset
|
||||
gdal.SetConfigOption('GDAL_NUM_THREADS', '1')
|
||||
if hasattr(gdal, 'UseExceptions'):
|
||||
gdal.UseExceptions()
|
||||
_worker_dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if _worker_dataset is None:
|
||||
raise RuntimeError(f"Worker failed to open source image: {img_path}")
|
||||
|
||||
|
||||
def _interpolate_block_worker(task: tuple) -> tuple:
|
||||
"""ProcessPoolExecutor worker: 处理单个块并返回结果
|
||||
|
||||
该函数必须保持模块级(可被 pickle),不持有任何外部状态——
|
||||
源 dataset 通过 ``_worker_dataset`` 模块全局变量获取。
|
||||
|
||||
Returns:
|
||||
``(x0, y0, inner_bands, zero_count, error_msg)`` 元组:
|
||||
- x0, y0: 块在影像中的写入起点
|
||||
- inner_bands: ``List[np.ndarray]``,每个元素是 (inner_h, inner_w)
|
||||
float32 数组(每个波段一个),或失败时为 None
|
||||
- zero_count: 该扩展块中识别到的零像素数(含 halo 范围)
|
||||
- error_msg: None 表示成功,str 表示错误信息
|
||||
"""
|
||||
(
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
) = task
|
||||
if _worker_dataset is None:
|
||||
return (x0, y0, None, 0, "Worker dataset not initialized")
|
||||
try:
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
_worker_dataset, x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
)
|
||||
return (x0, y0, inner_bands, zero_count, None)
|
||||
except Exception as e:
|
||||
return (x0, y0, None, 0, str(e))
|
||||
|
||||
|
||||
def _process_one_block(
|
||||
dataset: "gdal.Dataset",
|
||||
x0: int, y0: int,
|
||||
ey0: int, ex0: int, ey1: int, ex1: int,
|
||||
row_offset: int, col_offset: int,
|
||||
inner_h: int, inner_w: int,
|
||||
mask_segment_ext: Optional[np.ndarray],
|
||||
method: str,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""处理单个扩展块(纯计算核心,dataset 显式传入)
|
||||
|
||||
串行模式和并行模式共用此函数。并行模式下 dataset 来自 worker 的
|
||||
缓存(``_worker_dataset``),串行模式下 dataset 由主函数传入。
|
||||
|
||||
Args:
|
||||
dataset: 已打开的源影像 dataset
|
||||
x0, y0: 内部块左上角(写入位置)
|
||||
ey0, ex0, ey1, ex1: 扩展块(含 halo)坐标
|
||||
row_offset, col_offset: 内部块在扩展块中的偏移
|
||||
inner_h, inner_w: 内部块尺寸
|
||||
mask_segment_ext: 扩展块对应的水域掩膜(None 表示不应用)
|
||||
method: 插值方法(已归一化)
|
||||
|
||||
Returns:
|
||||
``(inner_bands, zero_count)`` 元组:
|
||||
- inner_bands: ``List[np.ndarray]``,长度 = n_bands,每个元素形状为
|
||||
``(inner_h, inner_w)`` 的 float32 数组
|
||||
- zero_count: 扩展块中识别到的零像素数
|
||||
"""
|
||||
n_bands = dataset.RasterCount
|
||||
ext_bands: List[np.ndarray] = []
|
||||
for b in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(b)
|
||||
ext_bands.append(
|
||||
band.ReadAsArray(ex0, ey0, ex1 - ex0, ey1 - ey0).astype(np.float32)
|
||||
)
|
||||
band = None
|
||||
|
||||
try:
|
||||
ext_h, ext_w = ey1 - ey0, ex1 - ex0
|
||||
|
||||
all_zero_ext = np.ones((ext_h, ext_w), dtype=bool)
|
||||
for b_data in ext_bands:
|
||||
all_zero_ext &= (b_data == 0)
|
||||
|
||||
if mask_segment_ext is not None:
|
||||
all_zero_ext &= (mask_segment_ext > 0)
|
||||
|
||||
zero_count = int(np.sum(all_zero_ext))
|
||||
|
||||
if zero_count == 0:
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, 0
|
||||
|
||||
zero_y, zero_x = np.where(all_zero_ext)
|
||||
zero_coords = np.column_stack([zero_x, zero_y])
|
||||
|
||||
valid_mask = ~all_zero_ext
|
||||
valid_y, valid_x = np.where(valid_mask)
|
||||
valid_coords = np.column_stack([valid_x, valid_y])
|
||||
|
||||
if len(valid_coords) == 0:
|
||||
print(
|
||||
f" [warn] 块 (y={y0}-{y0 + inner_h}, x={x0}-{x0 + inner_w}) "
|
||||
f"无有效像素可作插值上下文,已跳过"
|
||||
)
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
|
||||
for b in range(n_bands):
|
||||
ext_band = ext_bands[b]
|
||||
valid_values_band = ext_band[valid_mask]
|
||||
if len(valid_values_band) == 0:
|
||||
continue
|
||||
band_result = _interpolate_single_band(
|
||||
zero_coords, valid_coords, valid_values_band, method
|
||||
)
|
||||
ext_band[zero_y, zero_x] = band_result
|
||||
|
||||
inner_bands = [
|
||||
ext_bands[b][
|
||||
row_offset:row_offset + inner_h,
|
||||
col_offset:col_offset + inner_w,
|
||||
]
|
||||
for b in range(n_bands)
|
||||
]
|
||||
return inner_bands, zero_count
|
||||
finally:
|
||||
del ext_bands
|
||||
|
||||
|
||||
def interpolate_zero_pixels_batch(
|
||||
img_path: str,
|
||||
interpolation_method: str = 'nearest',
|
||||
output_path: Optional[str] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
deglint_dir: Optional[str] = None,
|
||||
callback_progress: Optional[callable] = None,
|
||||
block_size: int = 1024,
|
||||
halo_size: int = 64,
|
||||
n_workers: Optional[int] = None,
|
||||
use_multiprocessing: bool = True,
|
||||
) -> Tuple[str, Optional[np.ndarray]]:
|
||||
"""
|
||||
对影像中所有波段都为0的像素点进行插值(完整流程,含文件I/O)。
|
||||
|
||||
采用 **分块 IO + 多进程并行** 策略:
|
||||
1. 影像按 ``block_size`` × ``block_size`` 分块,每块边界外扩展
|
||||
``halo_size`` 像素作为插值上下文,避免块边缘插值退化
|
||||
2. 多进程并行(默认 ``ProcessPoolExecutor``,worker 数 = CPU 核心数)
|
||||
并发处理所有块;GDAL Dataset 不能跨进程传递,所以每个 worker
|
||||
在 ``initializer`` 阶段独立打开源文件一次并缓存
|
||||
3. 主进程按块序接收处理结果并统一写入输出文件,避免写锁竞争
|
||||
4. 该方案可彻底避免一次性读取 50 波段整景影像时的 OOM 隐患
|
||||
(50 波段 × 4000×4000 × float32 ≈ 3GB 的 np.dstack)
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
interpolation_method: 插值方法,支持 'nearest', 'bilinear', 'spline',
|
||||
'kriging' 及其中文别名('邻近'/'最邻近'/'线性'/'双线性'/'样条'/'克里金')
|
||||
output_path: 输出文件路径(如果为 None 且 deglint_dir 提供,自动生成)
|
||||
water_mask: 水域掩膜(文件路径或数组),形状须与影像高宽一致
|
||||
deglint_dir: 去耀斑目录(用于生成默认输出路径)
|
||||
callback_progress: 进度回调函数,签名 ``callback(msg: str)``
|
||||
block_size: 分块大小(像素),默认 1024;内存充足可调 2048/4096
|
||||
halo_size: 上下文 halo 宽度(像素),默认 64
|
||||
n_workers: 并行 worker 进程数;None = ``multiprocessing.cpu_count()``;
|
||||
传 1 等价于串行模式
|
||||
use_multiprocessing: 是否启用多进程;False 时强制串行
|
||||
|
||||
Returns:
|
||||
``(output_path, None)`` 元组。第二个值固定为 ``None``(与原版语义保留
|
||||
兼容;返回完整内存堆叠会重新引入 OOM 风险,故不再提供)。
|
||||
"""
|
||||
if not SCIPY_AVAILABLE:
|
||||
raise ImportError("scipy未安装,无法进行0值像素插值")
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
method = _normalize_interpolation_method(interpolation_method)
|
||||
|
||||
if output_path is None and deglint_dir is not None:
|
||||
output_path = str(Path(deglint_dir) / f"interpolated_{method}.bsq")
|
||||
if output_path is None:
|
||||
raise ValueError("output_path 和 deglint_dir 至少需要指定一个")
|
||||
|
||||
if Path(output_path).exists():
|
||||
return output_path, None
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
if width <= 0 or height <= 0 or n_bands <= 0:
|
||||
raise ValueError(
|
||||
f"影像尺寸异常: width={width}, height={height}, n_bands={n_bands}"
|
||||
)
|
||||
|
||||
mask_array = _read_water_mask_to_array(water_mask, height, width)
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
if driver is None:
|
||||
raise RuntimeError("未找到可用的栅格驱动(ENVI / GTiff 都不存在)")
|
||||
|
||||
out_dataset = driver.Create(
|
||||
output_path, width, height, n_bands, gdal.GDT_Float32
|
||||
)
|
||||
if out_dataset is None:
|
||||
raise RuntimeError(f"无法创建输出文件: {output_path}")
|
||||
out_dataset.SetGeoTransform(geotransform)
|
||||
out_dataset.SetProjection(projection)
|
||||
|
||||
try:
|
||||
if not use_multiprocessing:
|
||||
effective_workers = 1
|
||||
elif n_workers is not None and n_workers >= 1:
|
||||
effective_workers = int(n_workers)
|
||||
else:
|
||||
try:
|
||||
cpu_count = multiprocessing.cpu_count() or 1
|
||||
except (NotImplementedError, OSError):
|
||||
cpu_count = 1
|
||||
# 为了内存安全,强制将物理进程数限制在最高 6 个
|
||||
effective_workers = min(6, max(1, cpu_count))
|
||||
|
||||
n_blocks_y = (height + block_size - 1) // block_size
|
||||
n_blocks_x = (width + block_size - 1) // block_size
|
||||
total_blocks = n_blocks_y * n_blocks_x
|
||||
|
||||
tasks = []
|
||||
for by in range(n_blocks_y):
|
||||
y0 = by * block_size
|
||||
y1 = min(y0 + block_size, height)
|
||||
inner_h = y1 - y0
|
||||
ey0 = max(0, y0 - halo_size)
|
||||
ey1 = min(height, y1 + halo_size)
|
||||
for bx in range(n_blocks_x):
|
||||
x0 = bx * block_size
|
||||
x1 = min(x0 + block_size, width)
|
||||
inner_w = x1 - x0
|
||||
ex0 = max(0, x0 - halo_size)
|
||||
ex1 = min(width, x1 + halo_size)
|
||||
row_offset = y0 - ey0
|
||||
col_offset = x0 - ex0
|
||||
mask_segment_ext = None
|
||||
if mask_array is not None:
|
||||
mask_segment_ext = mask_array[ey0:ey1, ex0:ex1].copy()
|
||||
tasks.append((
|
||||
x0, y0, ey0, ex0, ey1, ex1,
|
||||
row_offset, col_offset, inner_h, inner_w,
|
||||
mask_segment_ext, method,
|
||||
))
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值开始: 共 {total_blocks} 块 "
|
||||
f"(block_size={block_size}, halo={halo_size}, method={method}, "
|
||||
f"workers={effective_workers})"
|
||||
)
|
||||
|
||||
total_zero_pixels = 0
|
||||
|
||||
if effective_workers <= 1:
|
||||
for block_idx, task in enumerate(tasks, 1):
|
||||
x0_t, y0_t = task[0], task[1]
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"块 {block_idx}/{total_blocks} "
|
||||
f"y=[{y0_t},{y0_t + task[8]}) x=[{x0_t},{x0_t + task[9]})"
|
||||
)
|
||||
inner_bands, zero_count = _process_one_block(
|
||||
dataset, *task
|
||||
)
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
else:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=effective_workers,
|
||||
initializer=_init_worker,
|
||||
initargs=(img_path,),
|
||||
) as executor:
|
||||
futures = [
|
||||
executor.submit(_interpolate_block_worker, task)
|
||||
for task in tasks
|
||||
]
|
||||
for block_idx, future in enumerate(futures, 1):
|
||||
x0_t, y0_t, inner_bands, zero_count, error = future.result()
|
||||
if error is not None:
|
||||
raise RuntimeError(
|
||||
f"块 (y={y0_t}, x={x0_t}) 处理失败: {error}"
|
||||
)
|
||||
if inner_bands is not None:
|
||||
for b_idx, band_data in enumerate(inner_bands):
|
||||
out_dataset.GetRasterBand(b_idx + 1).WriteArray(
|
||||
band_data, xoff=x0_t, yoff=y0_t
|
||||
)
|
||||
total_zero_pixels += zero_count
|
||||
if callback_progress:
|
||||
callback_progress(f"已写入块 {block_idx}/{total_blocks}")
|
||||
|
||||
if callback_progress:
|
||||
callback_progress(
|
||||
f"分块插值完成: 共处理 {total_zero_pixels} 个零像素 "
|
||||
f"({total_blocks} 块,方法 {method},workers={effective_workers})"
|
||||
)
|
||||
|
||||
return output_path, None
|
||||
finally:
|
||||
out_dataset = None
|
||||
finally:
|
||||
dataset = None
|
||||
7
src/core/algorithms/qaa/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
QAA 准解析反演算法模块
|
||||
"""
|
||||
from src.core.algorithms.qaa.qaas_baseline import QAABaselineSolver
|
||||
|
||||
__all__ = ['QAABaselineSolver']
|
||||
345
src/core/algorithms/qaa/qaas_baseline.py
Normal file
@ -0,0 +1,345 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
QAA 准解析算法基线求解器 (QAABaselineSolver)
|
||||
|
||||
实现 QAA-v5 / QAA-v6 核心步骤:
|
||||
1. Rrs(λ) → r_rs(λ)(水面以下遥感反射率转换)
|
||||
2. 计算中间变量 u(λ)(固有光学性质比值)
|
||||
3. λ₀ 锚点查表获取纯水吸收 aw(λ₀) 和后向散射 bbw(λ₀)
|
||||
4. 估算全波段 b_b(λ)(后向散射系数)
|
||||
5. 推导全波段 a(λ)(总吸收系数)
|
||||
|
||||
参考:
|
||||
- Lee, Z.P. et al. (2002) JGR-Oceans, 107(C4), 9-1~9-18 (QAA-v4)
|
||||
- Lee, Z.P. et al. (2010) Applied Optics, 49(4), 617-623 (QAA-v5)
|
||||
- Lee, Z.P. et al. (2014) Applied Optics, 53(4), 598-611 (QAA-v6)
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class QAABaselineSolver:
|
||||
"""
|
||||
QAA 准解析算法基线求解器。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pure_water_csv : str, optional
|
||||
纯水 IOPs 表路径,默认使用 src/utils/pure_water_iops.csv。
|
||||
qaa_version : str, default "QAA-v6"
|
||||
算法版本,支持 "QAA-v5" 或 "QAA-v6"。
|
||||
|
||||
Attributes
|
||||
----------
|
||||
iops_df : pd.DataFrame
|
||||
纯水 IOPs 表,含 Wavelength / aw / bbw 三列。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pure_water_csv: Optional[str] = None,
|
||||
qaa_version: str = "QAA-v6"
|
||||
):
|
||||
if pure_water_csv is None:
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'utils')
|
||||
)
|
||||
pure_water_csv = os.path.join(project_root, 'pure_water_iops.csv')
|
||||
|
||||
if not os.path.exists(pure_water_csv):
|
||||
raise FileNotFoundError(f"纯水 IOPs 表不存在: {pure_water_csv}")
|
||||
|
||||
self.iops_df = pd.read_csv(pure_water_csv)
|
||||
self.qaa_version = qaa_version
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心 QAA 步骤
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _rrs_to_rrs_subsurface(rrs: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
将水面遥感反射率 Rrs 转换为水面以下遥感反射率 r_rs。
|
||||
|
||||
转换公式(Lee et al. 1999):
|
||||
r_rs = Rrs / (0.52 + 1.7 * Rrs)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rrs : np.ndarray
|
||||
水面遥感反射率 Rrs,形状 (N,) 或 (N, n_bands)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
水面以下遥感反射率 r_rs。
|
||||
"""
|
||||
rrs = np.asarray(rrs, dtype=np.float64)
|
||||
denom = 0.52 + 1.7 * rrs
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
result = rrs / denom
|
||||
result[~np.isfinite(result)] = np.nan
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _compute_u(rrs_subsurface: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
计算中间变量 u = b_b / (a + b_b)。
|
||||
|
||||
QAA-v5/v6 经验关系(Lee et al. 2002):
|
||||
u = r_rs / (0.5 * r_rs + sqrt(0.25 * r_rs^2 + 0.1 * r_rs))
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rrs_subsurface : np.ndarray
|
||||
水面以下遥感反射率 r_rs。
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
u 值,范围 [0, 1)。
|
||||
"""
|
||||
rs = np.asarray(rrs_subsurface, dtype=np.float64)
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
result = rs / (0.5 * rs + np.sqrt(0.25 * rs ** 2 + 0.1 * rs))
|
||||
result[~np.isfinite(result)] = np.nan
|
||||
return result
|
||||
|
||||
def _get_pure_water_iops(self, wavelength: Union[int, float]) -> Tuple[float, float]:
|
||||
"""
|
||||
根据波长从纯水 IOPs 表中插值获取 aw 和 bbw。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelength : float
|
||||
波长(nm),范围应在 400-800nm 内。
|
||||
|
||||
Returns
|
||||
-------
|
||||
(aw, bbw) : tuple
|
||||
纯水吸收系数 (m^-1) 和后向散射系数 (m^-1)。
|
||||
"""
|
||||
df = self.iops_df
|
||||
wl_arr = df['Wavelength'].values
|
||||
aw_arr = df['aw'].values
|
||||
bbw_arr = df['bbw'].values
|
||||
|
||||
aw = float(np.interp(wavelength, wl_arr, aw_arr))
|
||||
bbw = float(np.interp(wavelength, wl_arr, bbw_arr))
|
||||
return aw, bbw
|
||||
|
||||
@staticmethod
|
||||
def _compute_bb(
|
||||
u: np.ndarray,
|
||||
bbw_0: float,
|
||||
wavelength: np.ndarray,
|
||||
lambda_0: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
估算全波段后向散射系数 b_b(λ)。
|
||||
|
||||
经验光谱形状(Lee et al. 2002, QAA-v4):
|
||||
b_b(λ) = b_bw(λ₀) * (λ₀ / λ)^S
|
||||
|
||||
其中 S 为经验光谱斜率参数(QAA-v5 中默认 0.5,
|
||||
QAA-v6 中随 λ₀ 自适应调整)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
u : np.ndarray
|
||||
中间变量 u。
|
||||
bbw_0 : float
|
||||
λ₀ 处的纯水后向散射系数。
|
||||
wavelength : np.ndarray
|
||||
全波段波长数组。
|
||||
lambda_0 : int
|
||||
参考波长(锚点)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
全波段后向散射系数 b_b。
|
||||
"""
|
||||
S = 0.5 if lambda_0 < 600 else 0.0
|
||||
wavelength = np.asarray(wavelength, dtype=np.float64)
|
||||
ratio = (float(lambda_0) / wavelength) ** S
|
||||
bb = u * bbw_0 / (1.0 - u) * ratio
|
||||
bb = np.maximum(bb, 0.0)
|
||||
return bb
|
||||
|
||||
@staticmethod
|
||||
def _compute_a(
|
||||
u: np.ndarray,
|
||||
aw_0: float,
|
||||
bbw_0: float,
|
||||
wavelength: np.ndarray,
|
||||
lambda_0: int
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
推导全波段总吸收系数 a(λ)。
|
||||
|
||||
由 u = b_b / (a + b_b) 推导:
|
||||
a = b_b * (1 - u) / u
|
||||
|
||||
Parameters
|
||||
----------
|
||||
u : np.ndarray
|
||||
中间变量 u。
|
||||
aw_0 : float
|
||||
λ₀ 处的纯水吸收系数。
|
||||
bbw_0 : float
|
||||
λ₀ 处的纯水后向散射系数。
|
||||
wavelength : np.ndarray
|
||||
全波段波长数组。
|
||||
lambda_0 : int
|
||||
参考波长(锚点)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
全波段总吸收系数 a。
|
||||
"""
|
||||
S = 0.5 if lambda_0 < 600 else 0.0
|
||||
wavelength = np.asarray(wavelength, dtype=np.float64)
|
||||
ratio = (float(lambda_0) / wavelength) ** S
|
||||
bbw = bbw_0 * ratio
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
a = bbw * (1.0 - u) / u + aw_0
|
||||
a[~np.isfinite(a)] = np.nan
|
||||
return a
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 主入口
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
wavelengths: np.ndarray,
|
||||
Rrs_spectrum: np.ndarray,
|
||||
lambda_0: int
|
||||
) -> dict:
|
||||
"""
|
||||
执行 QAA 核心反演。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelengths : np.ndarray
|
||||
光谱波长数组(nm),形状 (n_bands,) 或 (n_samples, n_bands)。
|
||||
Rrs_spectrum : np.ndarray
|
||||
水面遥感反射率光谱数据,形状 (n_bands,) 或 (n_samples, n_bands)。
|
||||
若为 2D,每行为一个样本的光谱。
|
||||
lambda_0 : int
|
||||
参考波长(锚点),用于查表获取纯水 IOPs。
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
包含以下键的字典:
|
||||
- wavelengths : 波长数组
|
||||
- Rrs : 输入 Rrs
|
||||
- r_rs_subsurface : 水下遥感反射率
|
||||
- u : 中间变量
|
||||
- a_lambda : 总吸收系数 a(λ)
|
||||
- bb_lambda : 后向散射系数 b_b(λ)
|
||||
- aw : λ₀ 处纯水吸收
|
||||
- bbw : λ₀ 处纯水后向散射
|
||||
"""
|
||||
wavelengths = np.asarray(wavelengths, dtype=np.float64)
|
||||
Rrs_spectrum = np.asarray(Rrs_spectrum, dtype=np.float64)
|
||||
|
||||
if Rrs_spectrum.ndim == 1:
|
||||
Rrs_spectrum = Rrs_spectrum[np.newaxis, :]
|
||||
|
||||
aw_0, bbw_0 = self._get_pure_water_iops(lambda_0)
|
||||
|
||||
results = []
|
||||
for row in Rrs_spectrum:
|
||||
rrs_sub = self._rrs_to_rrs_subsurface(row)
|
||||
u = self._compute_u(rrs_sub)
|
||||
bb = self._compute_bb(u, bbw_0, wavelengths, lambda_0)
|
||||
a = self._compute_a(u, aw_0, bbw_0, wavelengths, lambda_0)
|
||||
results.append({
|
||||
'wavelengths': wavelengths,
|
||||
'Rrs': row,
|
||||
'r_rs_subsurface': rrs_sub,
|
||||
'u': u,
|
||||
'a_lambda': a,
|
||||
'bb_lambda': bb,
|
||||
'aw_0': aw_0,
|
||||
'bbw_0': bbw_0,
|
||||
})
|
||||
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
return results
|
||||
|
||||
def invert_to_csv(
|
||||
self,
|
||||
wavelengths: np.ndarray,
|
||||
Rrs_spectrum: np.ndarray,
|
||||
lambda_0: int,
|
||||
output_csv: str,
|
||||
wavelength_col: str = "Wavelength",
|
||||
sample_ids: Optional[list] = None
|
||||
) -> str:
|
||||
"""
|
||||
执行反演并将结果保存为 CSV 文件。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wavelengths : np.ndarray
|
||||
波长数组(n_bands,)。
|
||||
Rrs_spectrum : np.ndarray
|
||||
光谱数据,形状 (n_bands,) 或 (n_samples, n_bands)。
|
||||
lambda_0 : int
|
||||
参考波长。
|
||||
output_csv : str
|
||||
输出 CSV 文件路径。
|
||||
wavelength_col : str
|
||||
输出 CSV 中波长列的列名前缀。
|
||||
sample_ids : list, optional
|
||||
样本 ID 列表(若为 None,使用 row_0, row_1, ...)。
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
输出文件路径。
|
||||
"""
|
||||
wavelengths = np.asarray(wavelengths, dtype=np.float64)
|
||||
Rrs_spectrum = np.asarray(Rrs_spectrum, dtype=np.float64)
|
||||
|
||||
if Rrs_spectrum.ndim == 1:
|
||||
Rrs_spectrum = Rrs_spectrum[np.newaxis, :]
|
||||
|
||||
n_samples = Rrs_spectrum.shape[0]
|
||||
if sample_ids is None:
|
||||
sample_ids = [f"sample_{i}" for i in range(n_samples)]
|
||||
|
||||
aw_0, bbw_0 = self._get_pure_water_iops(lambda_0)
|
||||
|
||||
rows_out = []
|
||||
for i, row in enumerate(Rrs_spectrum):
|
||||
rrs_sub = self._rrs_to_rrs_subsurface(row)
|
||||
u = self._compute_u(rrs_sub)
|
||||
bb = self._compute_bb(u, bbw_0, wavelengths, lambda_0)
|
||||
a = self._compute_a(u, aw_0, bbw_0, wavelengths, lambda_0)
|
||||
for j, wl in enumerate(wavelengths):
|
||||
rows_out.append({
|
||||
'sample_id': sample_ids[i],
|
||||
'Wavelength': wl,
|
||||
'Rrs': row[j],
|
||||
'r_rs': rrs_sub[j],
|
||||
'u': u[j],
|
||||
'a_lambda': a[j],
|
||||
'bb_lambda': bb[j],
|
||||
})
|
||||
|
||||
df = pd.DataFrame(rows_out)
|
||||
os.makedirs(os.path.dirname(output_csv) or '.', exist_ok=True)
|
||||
df.to_csv(output_csv, index=False, float_format='%.8f')
|
||||
return output_csv
|
||||
22
src/core/algorithms/waterindex_inversion.py
Normal file
@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
水色指数反演模块(包入口)
|
||||
|
||||
从 waterindex.csv 读取公式,对去耀斑 BSQ 高光谱影像进行全图矩阵运算,
|
||||
输出带完整坐标信息的 GeoTIFF。
|
||||
|
||||
公式格式(waterindex.csv):
|
||||
- 波长占位符:w{nm},如 w686, w708, w665
|
||||
- 支持混合大小写:w686 / W665 均可
|
||||
- 示例:NDCI = (w708 - w665) / (w708 + w665)
|
||||
|
||||
输出:
|
||||
- GeoTIFF (Float32),LZW 压缩,带 Tile
|
||||
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
|
||||
- Step 14 可直接用 rasterio 读取数组和空间范围
|
||||
"""
|
||||
|
||||
# 重新导出 WaterIndexProcessor(向后兼容所有已有 import)
|
||||
from src.core.algorithms.waterindex_inversion import WaterIndexProcessor
|
||||
|
||||
__all__ = ['WaterIndexProcessor']
|
||||
646
src/core/algorithms/waterindex_inversion/__init__.py
Normal file
@ -0,0 +1,646 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
水色指数反演模块
|
||||
|
||||
直接读取去耀斑高光谱 BSQ 影像,应用 waterindex.csv 中的公式,
|
||||
输出各水质参数指数的 GeoTIFF 栅格图像。
|
||||
|
||||
公式格式(waterindex.csv):
|
||||
- 波长占位符:w{nm},如 w686, w708, w665
|
||||
- 支持混合大小写:w686 / W665 均可
|
||||
- 示例:NDCI = (w708 - w665) / (w708 + w665)
|
||||
BGA_Am09KBBI = (w686 - w658) / (w686 + w658)
|
||||
|
||||
输出:
|
||||
- GeoTIFF (Float32),LZW 压缩,带 Tile
|
||||
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
|
||||
- Step 14 可直接用 rasterio 读取进行克里金插值
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from osgeo import gdal, osr
|
||||
|
||||
# GDAL 驱动注册
|
||||
gdal.UseExceptions()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公共工具
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_resource_path(relative_path: str) -> str:
|
||||
"""获取 waterindex.csv 等资源的绝对路径,兼容 PyInstaller 打包。"""
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
base = sys._MEIPASS
|
||||
else:
|
||||
base = os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', '..', '..')
|
||||
)
|
||||
return os.path.join(base, relative_path)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WaterIndexProcessor
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class WaterIndexProcessor:
|
||||
"""
|
||||
水色指数处理器
|
||||
|
||||
读取 waterindex.csv 中的公式,应用于 BSQ 高光谱影像,
|
||||
输出带完整坐标信息的 GeoTIFF 指数图。
|
||||
|
||||
核心能力:
|
||||
- 公式解析:w{nm} 占位符 → 实际波段 2D numpy 数组
|
||||
- 矩阵运算:全影像批量计算,无需逐点循环
|
||||
- 地理信息保持:克隆原始 BSQ 的 GeoTransform / Projection
|
||||
- NoData 处理:运算中产生的 NaN/Inf 统一标记为 -9999
|
||||
"""
|
||||
|
||||
# 内置安全命名空间(公式 eval 白名单)
|
||||
_SAFE_NS: Dict[str, Any] = {
|
||||
'np': np,
|
||||
'nan': np.nan,
|
||||
'inf': np.inf,
|
||||
'pi': np.pi,
|
||||
'e': np.e,
|
||||
}
|
||||
|
||||
def __init__(self, waterindex_csv_path: Optional[str] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
waterindex_csv_path : str, optional
|
||||
waterindex.csv 路径。
|
||||
若为 None,尝试从默认位置加载:
|
||||
1. src/gui/model/waterindex.csv(开发环境)
|
||||
2. _MEIPASS/src/gui/model/waterindex.csv(打包环境)
|
||||
"""
|
||||
self.csv_path: Optional[str] = None
|
||||
self.formulas: List[Dict[str, Any]] = []
|
||||
|
||||
if waterindex_csv_path:
|
||||
self.csv_path = waterindex_csv_path
|
||||
else:
|
||||
candidates = [
|
||||
os.path.join(os.path.dirname(__file__), '..', '..', 'gui', 'model', 'waterindex.csv'),
|
||||
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'gui', 'model', 'waterindex.csv'),
|
||||
]
|
||||
for p in candidates:
|
||||
if os.path.isfile(p):
|
||||
self.csv_path = p
|
||||
break
|
||||
|
||||
if self.csv_path:
|
||||
self._parse_csv()
|
||||
else:
|
||||
self.formulas = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _parse_csv(self) -> None:
|
||||
"""解析 waterindex.csv,加载所有公式。"""
|
||||
if not os.path.isfile(self.csv_path):
|
||||
raise FileNotFoundError(f"公式配置文件不存在: {self.csv_path}")
|
||||
|
||||
# ★★★ 防止多次调用时公式翻倍叠加 ★★★
|
||||
self.formulas.clear()
|
||||
|
||||
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
self.formulas.append(dict(row))
|
||||
|
||||
print(f"[WaterIndexProcessor] 加载 {len(self.formulas)} 条公式 ← {self.csv_path}")
|
||||
|
||||
def reload(self, waterindex_csv_path: str) -> None:
|
||||
"""重新加载公式配置文件。"""
|
||||
self.csv_path = waterindex_csv_path
|
||||
self._parse_csv()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_formulas(self) -> List[Dict[str, Any]]:
|
||||
"""返回所有公式的列表。"""
|
||||
return list(self.formulas)
|
||||
|
||||
def list_formula_names(self) -> List[str]:
|
||||
"""返回所有公式名称列表。"""
|
||||
return [f.get('Formula_Name', '') for f in self.formulas]
|
||||
|
||||
def get_formula(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""按名称查找单个公式。"""
|
||||
for f in self.formulas:
|
||||
if f.get('Formula_Name', '').strip() == name.strip():
|
||||
return f
|
||||
return None
|
||||
|
||||
def list_categories(self) -> List[str]:
|
||||
"""返回所有公式类别(去重排序)。"""
|
||||
cats = set()
|
||||
for f in self.formulas:
|
||||
c = f.get('Category', '').strip()
|
||||
if c:
|
||||
cats.add(c)
|
||||
return sorted(cats)
|
||||
|
||||
def get_formulas_by_category(self, category: str) -> List[Dict[str, Any]]:
|
||||
"""按类别筛选公式。"""
|
||||
return [f for f in self.formulas
|
||||
if f.get('Category', '').strip().lower() == category.strip().lower()]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 影像元数据
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_image_metadata(self, bsq_path: str, hdr_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""获取影像元数据(GDAL + ENVI HDR 双重保障)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bsq_path : str
|
||||
BSQ 影像路径
|
||||
hdr_path : str, optional
|
||||
ENVI HDR 路径(None → 自动构造)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
含 keys: width, height, bands, wavelengths, wavelength_range,
|
||||
geotransform, projection, driver
|
||||
"""
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# 1. GDAL 优先(获取空间信息)
|
||||
try:
|
||||
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
|
||||
if ds is not None:
|
||||
meta['width'] = ds.RasterXSize
|
||||
meta['height'] = ds.RasterYSize
|
||||
meta['bands'] = ds.RasterCount
|
||||
meta['driver'] = ds.GetDriver().ShortName
|
||||
gt = ds.GetGeoTransform()
|
||||
proj = ds.GetProjection()
|
||||
if gt and gt != (0, 1, 0, 0, 0, 1):
|
||||
meta['geotransform'] = gt
|
||||
if proj:
|
||||
meta['projection'] = proj
|
||||
ds = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. HDR 补充波长信息
|
||||
if hdr_path is None:
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
if os.path.isfile(hdr_path):
|
||||
wl = self._parse_wavelengths_from_hdr(hdr_path)
|
||||
if wl:
|
||||
meta['wavelengths'] = wl
|
||||
if len(wl) >= 2:
|
||||
meta['wavelength_range'] = f"{wl[0]:.1f}–{wl[-1]:.1f} nm ({len(wl)} 波段)"
|
||||
elif meta.get('bands', 0) > 0:
|
||||
meta['wavelength_range'] = f"{meta['bands']} 波段(波长信息缺失)"
|
||||
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _parse_wavelengths_from_hdr(hdr_path: str) -> Optional[List[float]]:
|
||||
"""从 ENVI .hdr 文件中解析波长列表。"""
|
||||
try:
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# 格式1:wavelength = { 400, 401, ... }
|
||||
m = re.search(r'wavelength\s*=\s*\{([^}]+)\}', content, re.DOTALL)
|
||||
if m:
|
||||
vals = [float(v) for v in re.findall(r'[\d.]+', m.group(1)) if v.strip()]
|
||||
if vals:
|
||||
return vals
|
||||
|
||||
# 格式2:逐行罗列
|
||||
wavelengths: List[float] = []
|
||||
in_wl = False
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength'):
|
||||
in_wl = True
|
||||
continue
|
||||
if in_wl:
|
||||
if line.startswith('{'):
|
||||
continue
|
||||
try:
|
||||
wavelengths.append(float(line))
|
||||
except ValueError:
|
||||
if '}' in line:
|
||||
in_wl = False
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式解析:w{nm} 占位符 → 实际波段数据
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_nearest_band_index(self, target_wv: float,
|
||||
wavelengths: List[float]) -> int:
|
||||
"""找到最接近目标波长的 GDAL 波段索引(1-based)。"""
|
||||
if not wavelengths:
|
||||
raise ValueError("波长列表为空,无法匹配波段")
|
||||
nearest = min(range(len(wavelengths)),
|
||||
key=lambda i: abs(wavelengths[i] - target_wv))
|
||||
return nearest + 1 # GDAL 波段从 1 开始
|
||||
|
||||
def _parse_formula_wavelengths(self, formula: str) -> List[int]:
|
||||
"""从公式字符串中提取所有波长值(去重,int)。"""
|
||||
raw = re.findall(r'[wW](\d+)', formula)
|
||||
seen = set()
|
||||
result: List[int] = []
|
||||
for r in raw:
|
||||
v = int(r)
|
||||
if v not in seen:
|
||||
seen.add(v)
|
||||
result.append(v)
|
||||
return result
|
||||
|
||||
def _eval_formula_fast(self, formula: str,
|
||||
band_data: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
|
||||
"""快速公式求值(预处理后直接 eval)。
|
||||
|
||||
band_data: {波长int: 2D 数组}
|
||||
formula 示例: "(w708 - w665) / (w708 + w665)"
|
||||
"""
|
||||
# 预处理:w708 → _B708(避免与 Python 关键字冲突)
|
||||
processed = re.sub(r'[wW](\d+)', r'_B\1', formula)
|
||||
|
||||
# 构建局部变量表:_B708 = band_data[708]
|
||||
local_vars = {f"_B{wv}": arr for wv, arr in band_data.items()}
|
||||
local_vars.update(self._SAFE_NS)
|
||||
|
||||
try:
|
||||
result = eval(processed, {"__builtins__": {}}, local_vars)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f" ⚠ 公式求值失败 [{formula}]: {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 单波段读取(带 NoData 处理)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _read_band_as_float(bsq_path: str, band_idx: int) -> np.ndarray:
|
||||
"""读取 BSQ 指定波段(1-based),返回 float64,NaN 替换 NoData。"""
|
||||
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
|
||||
if ds is None:
|
||||
raise RuntimeError(f"无法用 GDAL 打开影像: {bsq_path}")
|
||||
|
||||
band = ds.GetRasterBand(band_idx)
|
||||
arr = band.ReadAsArray()
|
||||
nodata = band.GetNoDataValue()
|
||||
ds = None
|
||||
|
||||
arr = arr.astype(np.float64)
|
||||
if nodata is not None:
|
||||
arr = np.where(arr == nodata, np.nan, arr)
|
||||
|
||||
return arr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心处理:逐公式矩阵运算 + GeoTIFF 输出
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def process_bsq(
|
||||
self,
|
||||
bsq_path: str,
|
||||
hdr_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
water_mask: Optional[np.ndarray] = None,
|
||||
nodata_value: float = -9999.0,
|
||||
progress_callback: Optional[Callable[[str, float], None]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""逐公式处理 BSQ 影像,输出 GeoTIFF。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bsq_path : str
|
||||
去耀斑 BSQ 影像路径
|
||||
hdr_path : str, optional
|
||||
ENVI HDR 文件路径(None → 自动构造)
|
||||
output_dir : str, optional
|
||||
输出目录(None → 与 bsq_path 同目录下的 10_WaterIndex_Images/)
|
||||
formula_names : list, optional
|
||||
要处理的公式名列表(None → 处理全部)
|
||||
water_mask : np.ndarray, optional
|
||||
水域掩膜数组(与 BSQ 同形状),掩膜值为 0 表示陆地,
|
||||
将被强制赋值为 nodata_value
|
||||
nodata_value : float
|
||||
NoData 标记值
|
||||
progress_callback : callable, optional
|
||||
回调 (msg: str, pct: float)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{公式名: 输出 GeoTIFF 路径}
|
||||
"""
|
||||
# ── 自动构造 HDR 路径 ────────────────────────────────────────────
|
||||
if hdr_path is None:
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
# ── 自动构造输出目录 ────────────────────────────────────────────
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(os.path.dirname(bsq_path), '10_WaterIndex_Images')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def progress(msg: str, pct: float):
|
||||
if progress_callback:
|
||||
progress_callback(msg, pct)
|
||||
|
||||
# ── 获取影像元数据 ───────────────────────────────────────────────
|
||||
progress("正在打开影像并读取元数据…", 2)
|
||||
meta = self.get_image_metadata(bsq_path, hdr_path)
|
||||
|
||||
width = meta.get('width', 0)
|
||||
height = meta.get('height', 0)
|
||||
n_bands = meta.get('bands', 0)
|
||||
wavelengths = meta.get('wavelengths', [])
|
||||
geotransform = meta.get('geotransform')
|
||||
projection = meta.get('projection')
|
||||
|
||||
if n_bands == 0 or width == 0 or height == 0:
|
||||
raise ValueError(f"影像元数据无效,无法处理: {bsq_path}")
|
||||
|
||||
if not wavelengths:
|
||||
raise ValueError(f"无法从 {hdr_path} 读取波长信息,公式无法解析")
|
||||
|
||||
progress(
|
||||
f"影像: {width}×{height}像素, {n_bands}波段, "
|
||||
f"波长 {wavelengths[0]:.1f}–{wavelengths[-1]:.1f}nm",
|
||||
5
|
||||
)
|
||||
|
||||
# ── 过滤要处理的公式 ──────────────────────────────────────────────
|
||||
if formula_names:
|
||||
formulas_to_run = [
|
||||
f for f in self.formulas
|
||||
if f.get('Formula_Name', '').strip() in formula_names
|
||||
]
|
||||
else:
|
||||
formulas_to_run = list(self.formulas)
|
||||
|
||||
results: Dict[str, str] = {}
|
||||
total = len(formulas_to_run)
|
||||
|
||||
# ── 逐公式处理 ───────────────────────────────────────────────────
|
||||
for i, formula_row in enumerate(formulas_to_run):
|
||||
fname = formula_row.get('Formula_Name', '').strip()
|
||||
fstr = formula_row.get('Formula', '').strip()
|
||||
category = formula_row.get('Category', '').strip()
|
||||
ftype = formula_row.get('Formula_Type', '').strip()
|
||||
|
||||
if not fname or not fstr:
|
||||
continue
|
||||
|
||||
progress(
|
||||
f"[{i + 1}/{total}] {fname} ({category})",
|
||||
5 + 90 * i / total
|
||||
)
|
||||
|
||||
try:
|
||||
# 1) 提取公式所需的波长列表
|
||||
required_wvs = self._parse_formula_wavelengths(fstr)
|
||||
|
||||
# 2) 按需读取波段数据(相同波长只读一次)
|
||||
band_data: Dict[int, np.ndarray] = {}
|
||||
for wv in required_wvs:
|
||||
if wv not in band_data:
|
||||
band_idx = self._find_nearest_band_index(wv, wavelengths)
|
||||
if not (0 < band_idx <= n_bands):
|
||||
print(f" ⚠ 公式 '{fname}' 引用波段 {band_idx},超出范围 ({n_bands}),跳过")
|
||||
raise ValueError(f"波段 {band_idx} 超出影像范围")
|
||||
band_data[wv] = self._read_band_as_float(bsq_path, band_idx)
|
||||
|
||||
# 3) 矩阵运算
|
||||
index_arr = self._eval_formula_fast(fstr, band_data)
|
||||
if index_arr is None:
|
||||
print(f" ⚠ 公式 '{fname}' 计算失败,跳过")
|
||||
continue
|
||||
|
||||
# 4) NoData 处理:NaN / Inf → nodata_value
|
||||
index_arr = np.where(np.isfinite(index_arr), index_arr, nodata_value)
|
||||
|
||||
# 4b) 水域掩膜拦截:陆地像素(mask==0)强制赋 NoData
|
||||
if water_mask is not None:
|
||||
land_pixels = (water_mask == 0)
|
||||
land_count = int(land_pixels.sum())
|
||||
if land_count > 0:
|
||||
index_arr = np.where(land_pixels, nodata_value, index_arr)
|
||||
print(f" 🗺 掩膜处理:陆地像素 {land_count:,} 个已设为 NoData")
|
||||
|
||||
# 5) 输出 GeoTIFF
|
||||
safe_fname = re.sub(r'[^\w\u4e00-\u9fff-]', '_', fname)
|
||||
out_tif = os.path.join(output_dir, f"{safe_fname}.tif")
|
||||
|
||||
self._write_geotiff(
|
||||
out_path=out_tif,
|
||||
data=index_arr,
|
||||
reference_bsq=bsq_path,
|
||||
nodata_value=nodata_value,
|
||||
description=f"{fname}|{category}|{ftype}|{fstr}",
|
||||
)
|
||||
|
||||
results[fname] = out_tif
|
||||
valid = index_arr[index_arr != nodata_value]
|
||||
mean_val = float(np.mean(valid)) if valid.size else np.nan
|
||||
print(f" ✅ {fname} → {out_tif} (mean={mean_val:.4f})")
|
||||
|
||||
except ValueError as ve:
|
||||
print(f" ⏭ 跳过 '{fname}': {ve}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f" ❌ 公式 '{fname}' 失败: {e}\n{traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
progress(f"完成!共输出 {len(results)} / {total} 个指数图", 100)
|
||||
return results
|
||||
|
||||
def _write_geotiff(
|
||||
self,
|
||||
out_path: str,
|
||||
data: np.ndarray,
|
||||
reference_bsq: str,
|
||||
nodata_value: float = -9999.0,
|
||||
description: str = "",
|
||||
) -> None:
|
||||
"""将数组写入 GeoTIFF,克隆原始 BSQ 的地理信息。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
out_path : str
|
||||
输出 GeoTIFF 路径
|
||||
data : np.ndarray
|
||||
2D 数据数组(height, width)
|
||||
reference_bsq : str
|
||||
参考 BSQ 影像路径(用于克隆 GeoTransform / Projection)
|
||||
nodata_value : float
|
||||
NoData 标记值
|
||||
description : str
|
||||
GDAL 数据集描述
|
||||
"""
|
||||
height, width = data.shape
|
||||
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
if driver is None:
|
||||
raise RuntimeError("GDAL GTiff 驱动不可用")
|
||||
|
||||
out_ds = driver.Create(
|
||||
out_path,
|
||||
width, height,
|
||||
1,
|
||||
gdal.GDT_Float32,
|
||||
options=['COMPRESS=LZW', 'TILED=YES', 'BIGTIFF=IF_SAFER'],
|
||||
)
|
||||
if out_ds is None:
|
||||
raise RuntimeError(f"无法创建 GeoTIFF: {out_path}")
|
||||
|
||||
# 写入数据
|
||||
out_band = out_ds.GetRasterBand(1)
|
||||
out_band.SetNoDataValue(nodata_value)
|
||||
out_band.WriteArray(data)
|
||||
out_band.FlushCache()
|
||||
|
||||
# 写入描述
|
||||
if description:
|
||||
out_band.SetDescription(description)
|
||||
|
||||
# ★★★ 克隆原始 BSQ 的 GeoTransform 和 Projection ★★★
|
||||
ref_ds = gdal.Open(reference_bsq, gdal.GA_ReadOnly)
|
||||
if ref_ds is not None:
|
||||
gt = ref_ds.GetGeoTransform()
|
||||
proj = ref_ds.GetProjection()
|
||||
if gt and gt != (0, 1, 0, 0, 0, 1):
|
||||
out_ds.SetGeoTransform(gt)
|
||||
if proj:
|
||||
out_ds.SetProjection(proj)
|
||||
ref_ds = None
|
||||
|
||||
out_ds = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline 入口(供 PipelineRunner 调用)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
deglint_img_path: str,
|
||||
work_dir: str,
|
||||
formula_csv_path: Optional[str] = None,
|
||||
selected_formulas: Optional[List[str]] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
nodata_value: float = -9999.0,
|
||||
callback: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, str]:
|
||||
"""Pipeline 入口方法。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
deglint_img_path : str
|
||||
去耀斑影像 BSQ 路径
|
||||
work_dir : str
|
||||
工作目录
|
||||
formula_csv_path : str, optional
|
||||
waterindex.csv 路径(None → 使用初始化时的路径)
|
||||
selected_formulas : list, optional
|
||||
要处理的公式列表
|
||||
water_mask_path : str, optional
|
||||
水域掩膜路径(如 1_water_mask/water_mask.dat),
|
||||
掩膜中为 0 的像素视为陆地区域,其指数值将被强制设为 NoData。
|
||||
nodata_value : float
|
||||
NoData 标记值,默认 -9999.0
|
||||
callback : callable, optional
|
||||
进度回调
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{公式名: 输出 GeoTIFF 路径}
|
||||
"""
|
||||
# 重新加载公式(如指定了新路径)
|
||||
if formula_csv_path:
|
||||
self.reload(formula_csv_path)
|
||||
elif not self.formulas:
|
||||
raise RuntimeError("WaterIndexProcessor 未加载公式,请指定 formula_csv_path")
|
||||
|
||||
def notify(msg: str, pct: float):
|
||||
if callback:
|
||||
callback(msg, pct)
|
||||
|
||||
notify("开始水色指数反演", 0)
|
||||
|
||||
bsq_path = deglint_img_path
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
output_dir = os.path.join(work_dir, "10_WaterIndex_Images")
|
||||
|
||||
# ── 加载水域掩膜(可选)───────────────────────────────────────
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
if water_mask_path:
|
||||
if os.path.isfile(water_mask_path):
|
||||
try:
|
||||
import rasterio
|
||||
with rasterio.open(water_mask_path) as msrc:
|
||||
water_mask = msrc.read(1)
|
||||
print(f"[run_inversion] 水域掩膜已加载: {water_mask_path},"
|
||||
f"形状={water_mask.shape},"
|
||||
f"陆地区域(0)={int((water_mask == 0).sum())},"
|
||||
f"水区域(>0)={int((water_mask > 0).sum())}")
|
||||
except Exception as mask_err:
|
||||
print(f"[run_inversion] ⚠ 掩膜加载失败,跳过掩膜处理: {mask_err}")
|
||||
water_mask = None
|
||||
else:
|
||||
print(f"[run_inversion] ⚠ 水域掩膜文件不存在: {water_mask_path},跳过掩膜处理")
|
||||
|
||||
notify("水色指数处理中…", 20)
|
||||
|
||||
results = self.process_bsq(
|
||||
bsq_path=bsq_path,
|
||||
hdr_path=hdr_path,
|
||||
output_dir=output_dir,
|
||||
formula_names=selected_formulas,
|
||||
water_mask=water_mask,
|
||||
nodata_value=nodata_value,
|
||||
progress_callback=lambda m, p: notify(m, 20 + 70 * p / 100),
|
||||
)
|
||||
|
||||
notify("水色指数反演完成", 100)
|
||||
return results
|
||||
@ -899,11 +899,11 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"# BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\test/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
boundary_path ="D:\BaiduNetdiskDownload\yaobao\water_mask.dat" # 边界掩膜文件路径(可选,None表示不使用)
|
||||
source_epsg = 4326 # 源坐标系EPSG代码,默认为4326 (WGS84地理坐标系)
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from osgeo import gdal, osr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import re
|
||||
import spectral
|
||||
from math import sin, cos, tan, sqrt, radians
|
||||
|
||||
@ -212,16 +213,83 @@ def load_mask_file(mask_path):
|
||||
def get_hdr_file_path(file_path):
|
||||
"""
|
||||
获取HDR文件路径
|
||||
|
||||
|
||||
Args:
|
||||
file_path: 影像文件路径
|
||||
|
||||
|
||||
Returns:
|
||||
HDR文件路径
|
||||
"""
|
||||
return os.path.splitext(file_path)[0] + ".hdr"
|
||||
|
||||
|
||||
def load_wavelength_columns(imgpath, num_bands):
|
||||
"""
|
||||
加载 wavelength 列名(鲁棒版:三级回退)
|
||||
|
||||
优先级:
|
||||
1) spectral.envi.read_envi_header(标准库解析,依赖 ENVI 头完整性)
|
||||
2) 纯文本暴力解析 .hdr(兜底,绕过 spectral 对 band names / 波段数一致性的校验)
|
||||
—— 解决 .hdr 中 band names 数量与 bands 不符导致的标准库解析失败问题
|
||||
3) 最后回退:band_1, band_2, ..., band_N
|
||||
|
||||
Args:
|
||||
imgpath: 影像文件路径(.bsq / .bil / .bip 等)
|
||||
num_bands: 影像实际波段数(用于回退列名长度 & 不一致警告)
|
||||
|
||||
Returns:
|
||||
spectral_columns: 长度为 num_bands 的字符串列表(与原代码列名格式一致:纯数字字符串)
|
||||
"""
|
||||
hdr_path = get_hdr_file_path(imgpath)
|
||||
|
||||
# 1) 标准库解析
|
||||
try:
|
||||
in_hdr_dict = spectral.envi.read_envi_header(hdr_path)
|
||||
wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
|
||||
spectral_columns = [str(wl) for wl in wavelengths]
|
||||
print(f"[wavelength] 标准库解析成功,从 {hdr_path} 提取 {len(spectral_columns)} 个波长")
|
||||
if len(spectral_columns) != num_bands:
|
||||
print(f"[wavelength] 警告: 解析波长数 ({len(spectral_columns)}) 与影像波段数 ({num_bands}) 不一致,将以 num_bands 为准截断/补齐")
|
||||
if len(spectral_columns) > num_bands:
|
||||
spectral_columns = spectral_columns[:num_bands]
|
||||
elif len(spectral_columns) < num_bands:
|
||||
spectral_columns = spectral_columns + [f"band_{j+1}" for j in range(len(spectral_columns), num_bands)]
|
||||
return spectral_columns
|
||||
except Exception as e_std:
|
||||
print(f"[wavelength] 标准库解析失败: {str(e_std)},将尝试文本兜底解析")
|
||||
|
||||
# 2) 兜底:纯文本暴力解析
|
||||
try:
|
||||
if not os.path.isfile(hdr_path):
|
||||
print(f"[wavelength] 文本兜底失败: {hdr_path} 不存在")
|
||||
else:
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
hdr_text = f.read()
|
||||
pattern = r'wavelength\s*=\s*\{([^}]+)\}'
|
||||
m = re.search(pattern, hdr_text, flags=re.IGNORECASE | re.DOTALL)
|
||||
if m:
|
||||
inner = m.group(1)
|
||||
tokens = [t.strip() for t in inner.split(',') if t.strip()]
|
||||
if tokens:
|
||||
if len(tokens) != num_bands:
|
||||
print(f"[wavelength] 文本解析波长数 ({len(tokens)}) 与影像波段数 ({num_bands}) 不一致,将以 num_bands 为准截断/补齐")
|
||||
if len(tokens) > num_bands:
|
||||
tokens = tokens[:num_bands]
|
||||
elif len(tokens) < num_bands:
|
||||
tokens = tokens + [f"band_{j+1}" for j in range(len(tokens), num_bands)]
|
||||
print(f"[wavelength] 文本暴力解析成功,从 {hdr_path} 提取 {len(tokens)} 个真实波长")
|
||||
return tokens
|
||||
print(f"[wavelength] 文本兜底: 已匹配到 wavelength = {{ ... }},但内部为空")
|
||||
else:
|
||||
print(f"[wavelength] 文本兜底: 未在 {hdr_path} 中匹配到 wavelength = {{ ... }} 字段")
|
||||
except Exception as e_txt:
|
||||
print(f"[wavelength] 文本兜底解析异常: {str(e_txt)}")
|
||||
|
||||
# 3) 全部失败,最后回退
|
||||
print(f"[wavelength] 所有解析路径均失败,回退到 band_1..band_{num_bands}")
|
||||
return ["band_" + str(j + 1) for j in range(num_bands)]
|
||||
|
||||
|
||||
def calculate_utm_zone(longitude):
|
||||
"""
|
||||
根据经度计算UTM分区号
|
||||
@ -473,9 +541,56 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
for i in range(min(3, coor_data.shape[0])):
|
||||
print(f" 行{i + 1}: {coor_data[i, :min(5, coor_data.shape[1])]}") # 只显示前5列
|
||||
|
||||
# 提取原始坐标
|
||||
lat_array = coor_data[:, 0] # 第1列是纬度
|
||||
lon_array = coor_data[:, 1] # 第2列是经度
|
||||
# 提取原始坐标(使用智能坐标列检测)
|
||||
lon_patterns = [
|
||||
r'^lon', r'^lng', r'^longitude', r'经度', r'^x$', r'^utm_x$', r'^pixel_x$'
|
||||
]
|
||||
lat_patterns = [
|
||||
r'^lat', r'^latitude', r'纬度', r'^y$', r'^utm_y$', r'^pixel_y$'
|
||||
]
|
||||
|
||||
x_col_name, y_col_name = None, None
|
||||
|
||||
if coor_df is not None and hasattr(coor_df, 'columns'):
|
||||
for col in coor_df.columns:
|
||||
col_str = str(col).lower().strip()
|
||||
if x_col_name is None and any(re.search(p, col_str) for p in lon_patterns):
|
||||
x_col_name = col
|
||||
if y_col_name is None and any(re.search(p, col_str) for p in lat_patterns):
|
||||
y_col_name = col
|
||||
|
||||
if x_col_name and y_col_name and x_col_name in coor_df.columns and y_col_name in coor_df.columns:
|
||||
lon_array = coor_df[x_col_name].values
|
||||
lat_array = coor_df[y_col_name].values
|
||||
print(f"💡 坐标列名检测: X/经度=[{x_col_name}], Y/纬度=[{y_col_name}]")
|
||||
else:
|
||||
numeric_cols = coor_df.select_dtypes(include=[np.number]).columns.tolist() if coor_df is not None else []
|
||||
if len(numeric_cols) >= 2:
|
||||
col1, col2 = numeric_cols[0], numeric_cols[1]
|
||||
mean1 = coor_df[col1].head(10).mean()
|
||||
mean2 = coor_df[col2].head(10).mean()
|
||||
if abs(mean1) <= 90 and abs(mean2) > 90:
|
||||
y_col_name, x_col_name = col1, col2
|
||||
lon_array = coor_df[x_col_name].values
|
||||
lat_array = coor_df[y_col_name].values
|
||||
elif abs(mean2) <= 90 and abs(mean1) > 90:
|
||||
x_col_name, y_col_name = col1, col2
|
||||
lon_array = coor_df[x_col_name].values
|
||||
lat_array = coor_df[y_col_name].values
|
||||
else:
|
||||
if mean1 > mean2:
|
||||
x_col_name, y_col_name = col1, col2
|
||||
else:
|
||||
x_col_name, y_col_name = col2, col1
|
||||
lon_array = coor_df[x_col_name].values
|
||||
lat_array = coor_df[y_col_name].values
|
||||
print(f"💡 触发智能数值推断坐标列: X/经度=[{x_col_name}], Y/纬度=[{y_col_name}]")
|
||||
else:
|
||||
if coor_data is not None and coor_data.shape[1] >= 3:
|
||||
lat_array = coor_data[:, 1]
|
||||
lon_array = coor_data[:, 2]
|
||||
else:
|
||||
raise Exception("坐标文件格式错误:需要至少2列数据,且最好包含坐标列名(如lon/lat/经度/纬度)")
|
||||
|
||||
print(f"\n=== 原始坐标信息 ===")
|
||||
print(f"原始坐标范围: 经度 {np.min(lon_array):.6f} ~ {np.max(lon_array):.6f}, 纬度 {np.min(lat_array):.6f} ~ {np.max(lat_array):.6f}")
|
||||
@ -711,17 +826,8 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
else:
|
||||
original_columns = []
|
||||
|
||||
# 读取波长信息,用作光谱列名
|
||||
wavelengths = None
|
||||
try:
|
||||
in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
|
||||
wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
|
||||
# 将波长值转换为字符串作为列名
|
||||
spectral_columns = [str(wl) for wl in wavelengths]
|
||||
print(f"成功读取波长信息,共 {len(spectral_columns)} 个波段")
|
||||
except Exception as e:
|
||||
print(f"警告: 无法读取波长信息 ({str(e)}),使用默认列名 band_1, band_2, ...")
|
||||
spectral_columns = ["band_" + str(j + 1) for j in range(num_bands)]
|
||||
# 读取波长信息,用作光谱列名(三级回退:spectral 解析 → 文本暴力解析 → band_N 兜底)
|
||||
spectral_columns = load_wavelength_columns(imgpath, num_bands)
|
||||
|
||||
# 构建输出列名(不包含前两列坐标列和UTM列)
|
||||
all_columns = original_columns + spectral_columns
|
||||
@ -758,11 +864,11 @@ def get_spectral_in_coor(imgpath, coorpath, outpath, radius=0, flare_path=None,
|
||||
if __name__ == '__main__':
|
||||
# 在这里直接设置参数
|
||||
imgpath = r"E:\code\WQ\封装\work_dir\3_deglint\deglint_goodman.bsq" # BIL格式影像文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\4_processed_data\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\5_training_spectra/yangdian_output.csv" # CSV格式输出文件路径
|
||||
coorpath = r"E:\code\WQ\封装\work_dir\5_Data_Cleaning\processed_data.csv"# CSV格式坐标文件路径(第1、2列为纬度和经度)
|
||||
output_path = r"E:\code\WQ\封装\work_dir\6_Spectral_Feature_Extraction/yangdian_output.csv" # CSV格式输出文件路径
|
||||
|
||||
radius = 5 # 采样半径(像素),0表示单点采样,>0表示半径内平均
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_glint\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
flare_path = r"E:\code\WQ\封装\work_dir\2_Glint_Detection\severe_glint_area.dat" # 耀斑掩膜文件路径(可选,None表示不使用)
|
||||
boundary_path = r"D:\BaiduNetdiskDownload\yaobao\water_mask.dat" # 边界掩膜文件路径(可选,None表示不使用)
|
||||
source_epsg = 4326 # 源坐标系EPSG代码,默认为4326 (WGS84地理坐标系)
|
||||
|
||||
|
||||
46
src/core/handlers/__init__.py
Normal file
@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤处理器包
|
||||
|
||||
将 WaterQualityInversionPipeline 的 14 个巨型 step* 方法
|
||||
拆分为独立的 Handler 类,每个 Handler 实现 BaseStepHandler 接口。
|
||||
|
||||
调度器(PipelineScheduler)仅维护执行上下文并根据 step_key
|
||||
从注册表查找对应 Handler 执行,自身不再包含任何算法逻辑。
|
||||
"""
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.handlers.step1_water_mask import Step1WaterMaskHandler
|
||||
from src.core.handlers.step2_glint_detection import Step2GlintDetectionHandler
|
||||
from src.core.handlers.step3_glint_removal import Step3GlintRemovalHandler
|
||||
from src.core.handlers.step4_sampling import Step4SamplingHandler
|
||||
from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler
|
||||
from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler
|
||||
from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler
|
||||
from src.core.handlers.step8_ml_train import Step8MlTrainHandler
|
||||
from src.core.handlers.step9_ml_predict import Step9MlPredictHandler
|
||||
from src.core.handlers.step10_qaa_inversion import Step10QaaInversionHandler
|
||||
from src.core.handlers.step11_concentration import Step11ConcentrationHandler
|
||||
from src.core.handlers.step12_kriging import Step12KrigingHandler
|
||||
from src.core.handlers.step13_visualization import Step13VisualizationHandler
|
||||
from src.core.handlers.step14_report import Step14ReportHandler
|
||||
|
||||
__all__ = [
|
||||
'BaseStepHandler',
|
||||
'PipelineContext',
|
||||
'Step1WaterMaskHandler',
|
||||
'Step2GlintDetectionHandler',
|
||||
'Step3GlintRemovalHandler',
|
||||
'Step4SamplingHandler',
|
||||
'Step5ProcessCsvHandler',
|
||||
'Step6ExtractSpectraHandler',
|
||||
'Step7CalcIndicesHandler',
|
||||
'Step8MlTrainHandler',
|
||||
'Step9MlPredictHandler',
|
||||
'Step10QaaInversionHandler',
|
||||
'Step11ConcentrationHandler',
|
||||
'Step12KrigingHandler',
|
||||
'Step13VisualizationHandler',
|
||||
'Step14ReportHandler',
|
||||
]
|
||||
282
src/core/handlers/base.py
Normal file
@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Handler 基类与 Pipeline 执行上下文
|
||||
|
||||
BaseStepHandler —— 所有步骤 Handler 的抽象基类,定义统一的 execute 接口。
|
||||
PipelineContext —— 在 Handler 之间传递的共享状态容器(路径、计时、回调等)。
|
||||
|
||||
设计原则:
|
||||
- Handler 只负责"执行一个步骤的算法逻辑",不管理调度/依赖/跳过。
|
||||
- Context 是 Handler 之间唯一的共享状态通道。
|
||||
- 调度器(PipelineScheduler)负责遍历 config、查找 Handler、调用 execute。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
class PipelineContext:
|
||||
"""管道执行上下文 —— Handler 之间共享状态的唯一载体。
|
||||
|
||||
包含:
|
||||
- 工作目录及子目录
|
||||
- 中间结果路径(water_mask_path, glint_mask_path, ...)
|
||||
- 步骤计时记录
|
||||
- 回调函数(用于 GUI 进度通知)
|
||||
- 可视化/报告生成器实例
|
||||
"""
|
||||
|
||||
def __init__(self, work_dir: str = "./work_dir"):
|
||||
self.work_dir = Path(work_dir)
|
||||
self.work_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── 子目录 ──
|
||||
self.water_mask_dir = self.work_dir / "1_water_mask"
|
||||
self.glint_dir = self.work_dir / "2_Glint_Detection"
|
||||
self.deglint_dir = self.work_dir / "3_deglint"
|
||||
self.processed_data_dir = self.work_dir / "5_Data_Cleaning"
|
||||
self.training_spectra_dir = self.work_dir / "6_Spectral_Feature_Extraction"
|
||||
self.indices_dir = self.work_dir / "7_Water_Quality_Indices"
|
||||
self.models_dir = self.work_dir / "8_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Non_Empirical_Regression"
|
||||
self.custom_regression_dir = self.work_dir / "13_Custom_Regression"
|
||||
self.sampling_dir = self.work_dir / "4_sampling"
|
||||
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
||||
self.visualization_dir = self.work_dir / "14_visualization"
|
||||
self.reports_dir = self.work_dir / "reports"
|
||||
|
||||
for d in [self.water_mask_dir, self.glint_dir, self.deglint_dir,
|
||||
self.processed_data_dir, self.training_spectra_dir,
|
||||
self.indices_dir, self.models_dir, self.non_empirical_models_dir,
|
||||
self.custom_regression_dir, self.sampling_dir, self.prediction_dir,
|
||||
self.visualization_dir, self.reports_dir]:
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── 中间结果路径 ──
|
||||
self.water_mask_path: Optional[str] = None
|
||||
self.glint_mask_path: Optional[str] = None
|
||||
self.interpolated_img_path: Optional[str] = None
|
||||
self.deglint_img_path: Optional[str] = None
|
||||
self.processed_csv_path: Optional[str] = None
|
||||
self.training_csv_path: Optional[str] = None
|
||||
self.indices_path: Optional[str] = None
|
||||
self.custom_regression_path: Optional[str] = None
|
||||
self.sampling_csv_path: Optional[str] = None
|
||||
self.prediction_files: Dict[str, str] = {}
|
||||
self.distribution_map_path: Optional[str] = None
|
||||
self.qaa_output_path: Optional[str] = None
|
||||
self.concentration_output_path: Optional[str] = None
|
||||
|
||||
# ── 计时 ──
|
||||
self.step_timings: Dict[str, dict] = {}
|
||||
self.pipeline_start_time: Optional[float] = None
|
||||
self.pipeline_end_time: Optional[float] = None
|
||||
|
||||
# ── 回调 ──
|
||||
self._callback: Optional[Callable] = None
|
||||
|
||||
# ── 可视化组件(延迟导入避免循环依赖)──
|
||||
self._visualizer = None
|
||||
self._report_generator = None
|
||||
self._scatter_batch = None
|
||||
|
||||
# ── matplotlib 中文字体 ──
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei',
|
||||
'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 回调
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def set_callback(self, callback: Callable):
|
||||
"""设置回调函数,用于向 GUI 报告进度。
|
||||
|
||||
Args:
|
||||
callback: 签名为 callback(step_name, status, message="")
|
||||
status: 'start' | 'completed' | 'skipped' | 'error' | 'info' | 'warning'
|
||||
"""
|
||||
self._callback = callback
|
||||
|
||||
def notify(self, step_name: str, status: str, message: str = ""):
|
||||
"""通知回调函数。"""
|
||||
if self._callback:
|
||||
try:
|
||||
self._callback(step_name, status, message)
|
||||
except Exception as e:
|
||||
print(f"回调函数执行失败: {e}")
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 计时
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def record_step_time(self, step_name: str, start_time: float, end_time: float,
|
||||
status: str = "completed", error: Optional[str] = None):
|
||||
elapsed = end_time - start_time
|
||||
self.step_timings[step_name] = {
|
||||
'start_time': datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'end_time': datetime.fromtimestamp(end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'elapsed_seconds': elapsed,
|
||||
'elapsed_formatted': self._format_time(elapsed),
|
||||
'status': status,
|
||||
'error': error,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_time(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.2f}秒"
|
||||
elif seconds < 3600:
|
||||
minutes = int(seconds // 60)
|
||||
secs = seconds % 60
|
||||
return f"{minutes}分{secs:.2f}秒"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = seconds % 60
|
||||
return f"{hours}小时{minutes}分{secs:.2f}秒"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 可视化组件(延迟导入)
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
@property
|
||||
def visualizer(self):
|
||||
if self._visualizer is None:
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
self._visualizer = WaterQualityVisualization(str(self.visualization_dir))
|
||||
return self._visualizer
|
||||
|
||||
@property
|
||||
def report_generator(self):
|
||||
if self._report_generator is None:
|
||||
from src.postprocessing.visualization_reports import ReportGenerator
|
||||
self._report_generator = ReportGenerator(str(self.reports_dir))
|
||||
return self._report_generator
|
||||
|
||||
@property
|
||||
def scatter_batch(self):
|
||||
if self._scatter_batch is None:
|
||||
from src.core.prediction.sctter_batch import WaterQualityScatterBatch
|
||||
self._scatter_batch = WaterQualityScatterBatch()
|
||||
return self._scatter_batch
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 步骤输出目录查找(兼容旧接口)
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
_STEP_OUTPUT_DIR_MAP: Optional[Dict[str, Path]] = None
|
||||
|
||||
def _ensure_step_dir_map(self) -> Dict[str, Path]:
|
||||
if PipelineContext._STEP_OUTPUT_DIR_MAP is not None:
|
||||
return PipelineContext._STEP_OUTPUT_DIR_MAP
|
||||
wp = self.work_dir
|
||||
m = {
|
||||
'step1': wp / '1_water_mask',
|
||||
'step2': wp / '2_Glint_Detection',
|
||||
'step3': wp / '3_deglint',
|
||||
'step4_sampling': wp / '4_sampling',
|
||||
'step5_clean': wp / '5_Data_Cleaning',
|
||||
'step6_feature': wp / '6_Spectral_Feature_Extraction',
|
||||
'step7_index': wp / '7_Water_Quality_Indices',
|
||||
'step8_ml_train': wp / '8_Supervised_Model_Training',
|
||||
'step9_ml_predict': wp / '8_Non_Empirical_Regression',
|
||||
'step10_watercolor': wp / '10_WaterIndex_Images',
|
||||
'step11_map': wp / '14_visualization',
|
||||
'step12_viz': wp / '14_visualization',
|
||||
'step13_report': wp / '14_visualization',
|
||||
'step11_predictions': wp / '11_12_13_predictions',
|
||||
'step12_predictions': wp / '11_12_13_predictions',
|
||||
'step13_predictions': wp / '11_12_13_predictions',
|
||||
'custom_regression': wp / '13_Custom_Regression',
|
||||
'prediction_dir': wp / '11_12_13_predictions',
|
||||
'visualization': wp / '14_visualization',
|
||||
'reports': wp / 'reports',
|
||||
'step8': wp / '8_Supervised_Model_Training',
|
||||
'step9': wp / '8_Non_Empirical_Regression',
|
||||
'step10': wp / '10_WaterIndex_Images',
|
||||
'step11': wp / '11_12_13_predictions',
|
||||
'step12': wp / '13_Custom_Regression',
|
||||
'step13': wp / 'reports',
|
||||
'step14': wp / '14_visualization',
|
||||
}
|
||||
PipelineContext._STEP_OUTPUT_DIR_MAP = m
|
||||
return m
|
||||
|
||||
def get_step_output_dir(self, step_name: str) -> Path:
|
||||
mapping = self._ensure_step_dir_map()
|
||||
key = (step_name or '').strip()
|
||||
if key in mapping:
|
||||
return mapping[key]
|
||||
print(f"[PipelineContext.get_step_output_dir] 未知 step_name={key!r},回退到 work_dir")
|
||||
return self.work_dir
|
||||
|
||||
|
||||
class BaseStepHandler(ABC):
|
||||
"""步骤处理器抽象基类。
|
||||
|
||||
所有步骤 Handler 必须实现:
|
||||
- step_key: 类属性,对应 config 中的 key(如 'step1', 'step2', ...)
|
||||
- execute(context, config): 执行步骤逻辑,返回结果字典
|
||||
|
||||
用法示例::
|
||||
|
||||
class Step1WaterMaskHandler(BaseStepHandler):
|
||||
step_key = 'step1'
|
||||
|
||||
def execute(self, ctx, config):
|
||||
result = WaterMaskStep.run(...)
|
||||
ctx.water_mask_path = result
|
||||
return {'water_mask_path': result}
|
||||
"""
|
||||
|
||||
# 子类必须定义:对应 config 字典中的 key
|
||||
step_key: str = None
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, context: PipelineContext, config: dict) -> dict:
|
||||
"""执行步骤逻辑。
|
||||
|
||||
Args:
|
||||
context: 管道执行上下文(共享状态)
|
||||
config: 该步骤的配置字典(即 config[self.step_key])
|
||||
|
||||
Returns:
|
||||
结果字典,包含该步骤产生的输出路径等信息。
|
||||
调度器会将返回值合并到全局结果中。
|
||||
|
||||
Raises:
|
||||
Exception: 任何异常都会由调度器捕获并记录。
|
||||
"""
|
||||
...
|
||||
|
||||
def _resolve_path(self, explicit: Optional[str], fallback: Optional[str],
|
||||
label: str = "path") -> Optional[str]:
|
||||
"""解析路径:优先使用显式传入值,否则回退到上下文中的缓存值。
|
||||
|
||||
Args:
|
||||
explicit: 调用方显式传入的路径
|
||||
fallback: 上下文中的缓存路径
|
||||
label: 用于日志的标签
|
||||
|
||||
Returns:
|
||||
解析后的路径,若两者均为 None 则返回 None
|
||||
"""
|
||||
if explicit is not None:
|
||||
return explicit
|
||||
if fallback is not None:
|
||||
return fallback
|
||||
return None
|
||||
199
src/core/handlers/pipeline_scheduler.py
Normal file
@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
极简管道调度器
|
||||
|
||||
替代原 WaterQualityInversionPipeline(2598 行上帝类)的调度核心。
|
||||
调度器自身不包含任何算法逻辑,仅负责:
|
||||
1. 维护 PipelineContext(共享状态)
|
||||
2. 根据 config key 从 Handler 注册表查找对应处理器
|
||||
3. 按序调用 handler.execute(ctx, config),收集结果
|
||||
4. 异常时记录错误并继续(或中止,取决于配置)
|
||||
|
||||
Handler 注册表是 step_key → BaseStepHandler 的映射。
|
||||
新增步骤只需:写一个 Handler 类 + 在注册表中加一行。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class PipelineScheduler:
|
||||
"""极简管道调度器。
|
||||
|
||||
用法::
|
||||
|
||||
scheduler = PipelineScheduler(work_dir="./work_dir")
|
||||
scheduler.register_handler(Step1WaterMaskHandler())
|
||||
scheduler.register_handler(Step2GlintDetectionHandler())
|
||||
# ... 注册所有步骤 ...
|
||||
|
||||
scheduler.set_callback(my_callback) # 可选:GUI 进度回调
|
||||
|
||||
result = scheduler.run_full_pipeline(config)
|
||||
# result['step1'] → {'water_mask_path': ...}
|
||||
# result['step2'] → {'glint_mask_path': ...}
|
||||
# ...
|
||||
"""
|
||||
|
||||
def __init__(self, work_dir: str = "./work_dir"):
|
||||
self.ctx = PipelineContext(work_dir)
|
||||
self._handlers: Dict[str, BaseStepHandler] = {}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# Handler 注册
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def register_handler(self, handler: BaseStepHandler):
|
||||
"""注册一个步骤处理器。
|
||||
|
||||
Args:
|
||||
handler: BaseStepHandler 实例(其 step_key 类属性决定 config 中的 key)
|
||||
"""
|
||||
if handler.step_key is None:
|
||||
raise ValueError(
|
||||
f"Handler {type(handler).__name__} 未定义 step_key 类属性"
|
||||
)
|
||||
self._handlers[handler.step_key] = handler
|
||||
|
||||
def register_handlers(self, handlers: List[BaseStepHandler]):
|
||||
"""批量注册步骤处理器。"""
|
||||
for h in handlers:
|
||||
self.register_handler(h)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 回调
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def set_callback(self, callback: Callable):
|
||||
"""设置 GUI 进度回调,代理到 PipelineContext。"""
|
||||
self.ctx.set_callback(callback)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 单步执行
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def run_step(self, step_key: str, config: dict) -> Dict[str, Any]:
|
||||
"""执行单个步骤。
|
||||
|
||||
Args:
|
||||
step_key: 步骤 key(如 'step1', 'step2', ...)
|
||||
config: 该步骤的配置字典
|
||||
|
||||
Returns:
|
||||
步骤执行结果字典
|
||||
|
||||
Raises:
|
||||
KeyError: 如果 step_key 未注册 Handler
|
||||
Exception: 步骤执行中的任何异常
|
||||
"""
|
||||
handler = self._handlers.get(step_key)
|
||||
if handler is None:
|
||||
raise KeyError(
|
||||
f"未注册的步骤: {step_key!r}。"
|
||||
f"已注册: {list(self._handlers.keys())}"
|
||||
)
|
||||
|
||||
self.ctx.notify(handler.step_key, 'start')
|
||||
result = handler.execute(self.ctx, config)
|
||||
self.ctx.notify(handler.step_key, 'completed')
|
||||
return result
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 全流程执行
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
def run_full_pipeline(self, config: Dict[str, dict]) -> Dict[str, Any]:
|
||||
"""按 config 中的 key 顺序执行全流程。
|
||||
|
||||
遍历 config 的顶层 key,对每个 key:
|
||||
- 如果已注册 Handler → 执行并收集结果
|
||||
- 如果未注册 → 跳过并通知
|
||||
- 如果执行失败 → 记录错误,继续执行后续步骤(不中止)
|
||||
|
||||
Args:
|
||||
config: 全流程配置字典,格式为 {step_key: step_config, ...}
|
||||
例如: {'step1': {...}, 'step2': {...}, ...}
|
||||
|
||||
Returns:
|
||||
{
|
||||
'step_results': {step_key: result_dict, ...},
|
||||
'step_timings': {...},
|
||||
'total_elapsed': float,
|
||||
'errors': {step_key: error_message, ...},
|
||||
}
|
||||
"""
|
||||
self.ctx.pipeline_start_time = time.time()
|
||||
|
||||
step_results: Dict[str, Any] = {}
|
||||
errors: Dict[str, str] = {}
|
||||
|
||||
# 按 config 中的顺序遍历(Python 3.7+ dict 保序)
|
||||
for step_key, step_config in config.items():
|
||||
handler = self._handlers.get(step_key)
|
||||
|
||||
if handler is None:
|
||||
self.ctx.notify(step_key, 'skipped', '未注册 Handler')
|
||||
continue
|
||||
|
||||
try:
|
||||
result = handler.execute(self.ctx, step_config)
|
||||
step_results[step_key] = result
|
||||
self.ctx.notify(step_key, 'completed', str(result))
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {e}"
|
||||
errors[step_key] = error_msg
|
||||
step_results[step_key] = {'error': error_msg}
|
||||
self.ctx.notify(step_key, 'error', error_msg)
|
||||
# 不中止,继续执行后续步骤
|
||||
|
||||
self.ctx.pipeline_end_time = time.time()
|
||||
total_elapsed = self.ctx.pipeline_end_time - self.ctx.pipeline_start_time
|
||||
|
||||
return {
|
||||
'step_results': step_results,
|
||||
'step_timings': self.ctx.step_timings,
|
||||
'total_elapsed': total_elapsed,
|
||||
'total_elapsed_formatted': self.ctx._format_time(total_elapsed),
|
||||
'errors': errors,
|
||||
}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# 便捷属性(代理到 PipelineContext)
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
@property
|
||||
def work_dir(self) -> Path:
|
||||
return self.ctx.work_dir
|
||||
|
||||
@property
|
||||
def water_mask_path(self) -> Optional[str]:
|
||||
return self.ctx.water_mask_path
|
||||
|
||||
@property
|
||||
def glint_mask_path(self) -> Optional[str]:
|
||||
return self.ctx.glint_mask_path
|
||||
|
||||
@property
|
||||
def deglint_img_path(self) -> Optional[str]:
|
||||
return self.ctx.deglint_img_path
|
||||
|
||||
@property
|
||||
def processed_csv_path(self) -> Optional[str]:
|
||||
return self.ctx.processed_csv_path
|
||||
|
||||
@property
|
||||
def training_csv_path(self) -> Optional[str]:
|
||||
return self.ctx.training_csv_path
|
||||
|
||||
@property
|
||||
def indices_path(self) -> Optional[str]:
|
||||
return self.ctx.indices_path
|
||||
|
||||
def get_step_output_dir(self, step_name: str) -> Path:
|
||||
return self.ctx.get_step_output_dir(step_name)
|
||||
57
src/core/handlers/register_handlers.py
Normal file
@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Handler 注册辅助函数
|
||||
|
||||
将所有步骤 Handler 一次性注册到 PipelineScheduler。
|
||||
新增步骤只需在此函数中加一行 register_handler() 调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.core.handlers.step1_water_mask import Step1WaterMaskHandler
|
||||
from src.core.handlers.step2_glint_detection import Step2GlintDetectionHandler
|
||||
from src.core.handlers.step3_glint_removal import Step3GlintRemovalHandler
|
||||
from src.core.handlers.step4_sampling import Step4SamplingHandler
|
||||
from src.core.handlers.step5_process_csv import Step5ProcessCsvHandler
|
||||
from src.core.handlers.step6_extract_spectra import Step6ExtractSpectraHandler
|
||||
from src.core.handlers.step7_calc_indices import Step7CalcIndicesHandler
|
||||
from src.core.handlers.step8_ml_train import Step8MlTrainHandler
|
||||
from src.core.handlers.step9_ml_predict import Step9MlPredictHandler
|
||||
from src.core.handlers.step10_qaa_inversion import Step10QaaInversionHandler
|
||||
from src.core.handlers.step11_concentration import Step11ConcentrationHandler
|
||||
from src.core.handlers.step12_kriging import Step12KrigingHandler
|
||||
from src.core.handlers.step13_visualization import Step13VisualizationHandler
|
||||
from src.core.handlers.step14_report import Step14ReportHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.core.handlers.pipeline_scheduler import PipelineScheduler
|
||||
|
||||
|
||||
def register_all_handlers(scheduler: PipelineScheduler):
|
||||
"""将所有已实现的步骤 Handler 注册到调度器。
|
||||
|
||||
用法::
|
||||
|
||||
scheduler = PipelineScheduler(work_dir="./work_dir")
|
||||
register_all_handlers(scheduler)
|
||||
result = scheduler.run_full_pipeline(config)
|
||||
|
||||
新增步骤时,在此函数中追加一行 register_handler() 即可。
|
||||
"""
|
||||
scheduler.register_handler(Step1WaterMaskHandler())
|
||||
scheduler.register_handler(Step2GlintDetectionHandler())
|
||||
scheduler.register_handler(Step3GlintRemovalHandler())
|
||||
scheduler.register_handler(Step4SamplingHandler())
|
||||
scheduler.register_handler(Step5ProcessCsvHandler())
|
||||
scheduler.register_handler(Step6ExtractSpectraHandler())
|
||||
scheduler.register_handler(Step7CalcIndicesHandler())
|
||||
scheduler.register_handler(Step8MlTrainHandler())
|
||||
scheduler.register_handler(Step9MlPredictHandler())
|
||||
scheduler.register_handler(Step10QaaInversionHandler())
|
||||
scheduler.register_handler(Step11ConcentrationHandler())
|
||||
scheduler.register_handler(Step12KrigingHandler())
|
||||
scheduler.register_handler(Step13VisualizationHandler())
|
||||
scheduler.register_handler(Step14ReportHandler())
|
||||
137
src/core/handlers/step10_qaa_inversion.py
Normal file
@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step10 处理器:QAA 准解析算法反演
|
||||
|
||||
将原 WaterQualityInversionPipeline.step8_qaa_inversion() 方法
|
||||
剥离为独立的 Step10QaaInversionHandler。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step10QaaInversionHandler(BaseStepHandler):
|
||||
"""步骤10:QAA 准解析算法反演(非经验模型)。
|
||||
|
||||
对应 config key: 'step10_qaa'
|
||||
直接使用 QAABaselineSolver 进行物理推导。
|
||||
"""
|
||||
|
||||
step_key = 'step10_qaa'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
from src.core.algorithms.qaa.qaas_baseline import QAABaselineSolver
|
||||
from src.utils.water_owt_config import get_lambda_0
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
lake_name = config.get('lake_name', 'Unknown')
|
||||
lambda_0 = config.get('lambda_0', get_lambda_0(lake_name))
|
||||
output_dir = os.path.join(context.work_dir, "10_QAA_Inversion")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = config.get('output_path') or os.path.join(output_dir, "a_lambda_results.csv")
|
||||
|
||||
spectrum_csv = config.get('spectrum_csv_path')
|
||||
if not spectrum_csv:
|
||||
spectrum_csv = context.training_csv_path
|
||||
if not spectrum_csv or not os.path.exists(spectrum_csv):
|
||||
fallback_candidates = []
|
||||
step6_dir = os.path.join(context.work_dir, "6_Spectral_Feature_Extraction")
|
||||
if os.path.isdir(step6_dir):
|
||||
for f in sorted(os.listdir(step6_dir)):
|
||||
if f.lower().endswith('.csv'):
|
||||
fallback_candidates.append(os.path.join(step6_dir, f))
|
||||
if fallback_candidates:
|
||||
spectrum_csv = fallback_candidates[0]
|
||||
context.notify('step10_qaa', 'info',
|
||||
f'spectrum_csv_path 为空,已自动回退到 step6 产物: {spectrum_csv}')
|
||||
else:
|
||||
msg = f'训练光谱 CSV 不存在或路径为空: {spectrum_csv}'
|
||||
context.notify('step10_qaa', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
try:
|
||||
df = pd.read_csv(spectrum_csv, encoding="utf-8-sig")
|
||||
col_names = df.columns.tolist()
|
||||
|
||||
wavelength_col_idx = None
|
||||
for i, col in enumerate(col_names):
|
||||
try:
|
||||
float(col)
|
||||
wavelength_col_idx = i
|
||||
break
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if wavelength_col_idx is None:
|
||||
msg = "无法从 CSV 列名中识别波长信息"
|
||||
context.notify('step10_qaa', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
meta_df = df.iloc[:, :wavelength_col_idx].copy()
|
||||
wavelengths = np.array([float(c) for c in col_names[wavelength_col_idx:]], dtype=np.float64)
|
||||
data_matrix = df.iloc[:, wavelength_col_idx:].values.astype(np.float64)
|
||||
if data_matrix.ndim == 1:
|
||||
data_matrix = data_matrix[np.newaxis, :]
|
||||
|
||||
solver = QAABaselineSolver()
|
||||
raw_result = solver.run_inversion(wavelengths, data_matrix, lambda_0)
|
||||
|
||||
if isinstance(raw_result, list):
|
||||
sample_results = raw_result
|
||||
else:
|
||||
sample_results = [raw_result]
|
||||
|
||||
rows_out = []
|
||||
for i, sample_result in enumerate(sample_results):
|
||||
wl_arr = wavelengths
|
||||
a_arr = sample_result['a_lambda']
|
||||
bb_arr = sample_result['bb_lambda']
|
||||
meta_row = meta_df.iloc[i].to_dict() if i < len(meta_df) else {}
|
||||
for j, wl in enumerate(wl_arr):
|
||||
rows_out.append({
|
||||
'sample_id': f"sample_{i}",
|
||||
'Wavelength': wl,
|
||||
'a_lambda': a_arr[j],
|
||||
'bb_lambda': bb_arr[j],
|
||||
**meta_row,
|
||||
})
|
||||
|
||||
result_df = pd.DataFrame(rows_out)
|
||||
result_df.to_csv(output_path, index=False, float_format='%.8f')
|
||||
|
||||
context.qaa_output_path = output_path
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time
|
||||
)
|
||||
context.notify('step10_qaa', 'completed',
|
||||
f"QAA 反演完毕,水域={lake_name},λ₀={lambda_0}nm")
|
||||
|
||||
return {'qaa_output_path': output_path}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤10: QAA 反演", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
71
src/core/handlers/step11_concentration.py
Normal file
@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step11 处理器:浓度反演
|
||||
|
||||
将原 WaterQualityInversionPipeline.step9_concentration_inversion() 方法
|
||||
剥离为独立的 Step11ConcentrationHandler。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step11ConcentrationHandler(BaseStepHandler):
|
||||
"""步骤11:浓度反演(基于 QAA Step10 输出的 a_lambda/bb_lambda)。
|
||||
|
||||
对应 config key: 'step11_concentration'
|
||||
直接使用 ConcentrationPipeline 进行浓度反演。
|
||||
"""
|
||||
|
||||
step_key = 'step11_concentration'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
from src.core.algorithms.concentration_inversion import ConcentrationPipeline
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
input_csv = config.get('input_csv') or context.qaa_output_path
|
||||
output_csv = config.get('output_csv')
|
||||
lake_case = config.get('lake_case', 'medium')
|
||||
|
||||
if not input_csv or not os.path.exists(input_csv):
|
||||
msg = f"QAA 结果文件不存在或路径为空: {input_csv}"
|
||||
context.notify('step11_concentration', 'error', msg)
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time,
|
||||
status="failed", error=msg
|
||||
)
|
||||
return {'error': msg}
|
||||
|
||||
if not output_csv:
|
||||
output_dir = os.path.join(context.work_dir, "11_Concentration")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_csv = os.path.join(output_dir, "final_concentrations.csv")
|
||||
|
||||
try:
|
||||
pipeline = ConcentrationPipeline(lake_case=lake_case)
|
||||
result_csv = pipeline.run_pipeline(input_csv, output_csv)
|
||||
|
||||
context.concentration_output_path = result_csv
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time
|
||||
)
|
||||
context.notify('step11_concentration', 'completed',
|
||||
f"浓度反演完毕,结果保存于: {result_csv}")
|
||||
|
||||
return {'concentration_output_path': result_csv}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤11: 浓度反演", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
81
src/core/handlers/step12_kriging.py
Normal file
@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step12 处理器:克里金空间插值与分布图生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.step10_map() 方法
|
||||
剥离为独立的 Step12KrigingHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.mapping_step import MappingStep
|
||||
|
||||
|
||||
class Step12KrigingHandler(BaseStepHandler):
|
||||
"""步骤12:克里金空间插值与分布图生成。
|
||||
|
||||
对应 config key: 'step12_kriging'
|
||||
委托类: MappingStep.generate_distribution_map()
|
||||
"""
|
||||
|
||||
step_key = 'step12_kriging'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
prediction_csv_path = config.get('prediction_csv_path')
|
||||
boundary_shp_path = config.get('boundary_shp_path')
|
||||
|
||||
# 强制输出到 visualization_dir
|
||||
csv_name = Path(prediction_csv_path).stem if prediction_csv_path else "distribution"
|
||||
forced_image_path = str(context.visualization_dir / f"{csv_name}_distribution.png")
|
||||
viz_dir_resolved = str(context.visualization_dir)
|
||||
|
||||
output_image_path = config.get('output_image_path')
|
||||
if output_image_path and output_image_path != forced_image_path:
|
||||
norm_user = output_image_path.replace('\\', '/').rstrip('/')
|
||||
norm_viz = viz_dir_resolved.replace('\\', '/').rstrip('/')
|
||||
if not norm_user.startswith(norm_viz + '/') and norm_user != norm_viz:
|
||||
output_image_path = forced_image_path
|
||||
else:
|
||||
output_image_path = forced_image_path
|
||||
|
||||
try:
|
||||
result = MappingStep.generate_distribution_map(
|
||||
prediction_csv_path=prediction_csv_path,
|
||||
boundary_shp_path=boundary_shp_path,
|
||||
output_image_path=output_image_path,
|
||||
resolution=config.get('resolution', 30),
|
||||
input_crs=config.get('input_crs', 'EPSG:32651'),
|
||||
output_crs=config.get('output_crs', 'EPSG:4326'),
|
||||
show_sample_points=config.get('show_sample_points', False),
|
||||
base_map_tif=config.get('base_map_tif'),
|
||||
use_distance_diffusion=config.get('use_distance_diffusion', True),
|
||||
max_diffusion_distance=config.get('max_diffusion_distance'),
|
||||
diffusion_power=config.get('diffusion_power', 2),
|
||||
diffusion_n_neighbors=config.get('diffusion_n_neighbors', 15),
|
||||
cmap=config.get('cmap'),
|
||||
expand_ratio=config.get('expand_ratio', 0.05),
|
||||
output_dir=str(context.visualization_dir),
|
||||
)
|
||||
|
||||
context.distribution_map_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤12: 克里金插值与分布图", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'distribution_map_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤12: 克里金插值与分布图", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
349
src/core/handlers/step13_visualization.py
Normal file
@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step13 处理器:可视化成图
|
||||
|
||||
将原 WaterQualityInversionPipeline 中的可视化方法
|
||||
(散点图、箱型图、光谱曲线、统计图表、耀斑预览)
|
||||
剥离为独立的 Step13VisualizationHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step13VisualizationHandler(BaseStepHandler):
|
||||
"""步骤13:可视化成图。
|
||||
|
||||
对应 config key: 'step13_visualization'
|
||||
包含:散点图、箱型图、光谱曲线、统计图表、耀斑预览。
|
||||
"""
|
||||
|
||||
step_key = 'step13_visualization'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
output_files: Dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
# ── 散点图 ──
|
||||
if config.get('generate_scatter', True):
|
||||
if context.training_csv_path and context.models_dir.exists():
|
||||
try:
|
||||
scatter_config = config.get('scatter_config', {})
|
||||
scatter_paths = self._generate_scatter_plots(context, scatter_config)
|
||||
output_files['scatter_plots'] = scatter_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成散点图时出错: {e}")
|
||||
|
||||
# ── 箱型图 ──
|
||||
if config.get('generate_boxplots', True):
|
||||
if context.processed_csv_path:
|
||||
try:
|
||||
boxplot_config = config.get('boxplot_config', {})
|
||||
boxplot_paths = self._generate_boxplots(context, boxplot_config)
|
||||
output_files['boxplots'] = boxplot_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成箱型图时出错: {e}")
|
||||
|
||||
# ── 光谱曲线 ──
|
||||
if config.get('generate_spectrum', True):
|
||||
if context.training_csv_path:
|
||||
try:
|
||||
spectrum_paths = self._generate_spectrum_plots(context, config)
|
||||
output_files['spectrum_plots'] = spectrum_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成光谱曲线图时出错: {e}")
|
||||
|
||||
# ── 统计图表 ──
|
||||
if config.get('generate_statistics', True):
|
||||
if context.processed_csv_path:
|
||||
try:
|
||||
stat_charts = self._generate_statistics(context)
|
||||
output_files['statistical_charts'] = stat_charts
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成统计图表时出错: {e}")
|
||||
|
||||
# ── 耀斑预览 ──
|
||||
if config.get('generate_glint_previews', True):
|
||||
try:
|
||||
glint_config = config.get('glint_preview_config', {})
|
||||
preview_paths = context.visualizer.generate_glint_deglint_previews(
|
||||
work_dir=glint_config.get('work_dir') or str(context.work_dir),
|
||||
output_subdir=glint_config.get('output_subdir', 'glint_deglint_previews'),
|
||||
generate_glint=glint_config.get('generate_glint', True),
|
||||
generate_deglint=glint_config.get('generate_deglint', True),
|
||||
)
|
||||
output_files['glint_deglint_previews'] = preview_paths
|
||||
except Exception as e:
|
||||
context.notify('step13_visualization', 'warning',
|
||||
f"生成耀斑预览图时出错: {e}")
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤13: 可视化成图", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'visualization_outputs': output_files}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤13: 可视化成图", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
# ── 散点图 ──
|
||||
|
||||
def _generate_scatter_plots(self, context: PipelineContext,
|
||||
scatter_config: dict) -> Dict[str, str]:
|
||||
training_csv_path = context.training_csv_path
|
||||
models_dir = str(context.models_dir)
|
||||
metric = scatter_config.get('metric', 'test_r2')
|
||||
use_enhanced = scatter_config.get('use_enhanced', True)
|
||||
feature_start_column = scatter_config.get('feature_start_column', 13)
|
||||
test_size = scatter_config.get('test_size', 0.2)
|
||||
random_state = scatter_config.get('random_state', 42)
|
||||
|
||||
scatter_paths = {}
|
||||
|
||||
if use_enhanced:
|
||||
try:
|
||||
results = context.scatter_batch.batch_plot_scatter(
|
||||
models_root_dir=models_dir,
|
||||
csv_path=training_csv_path,
|
||||
output_dir=str(context.visualization_dir / "scatter_plots"),
|
||||
metric=metric,
|
||||
target_column=None,
|
||||
feature_start_column=feature_start_column,
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
)
|
||||
for target_name, result in results.items():
|
||||
if result.get('status') == 'success':
|
||||
scatter_paths[target_name] = result.get('save_path', '')
|
||||
except Exception:
|
||||
use_enhanced = False
|
||||
|
||||
if not use_enhanced or not scatter_paths:
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
models_path = Path(models_dir)
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
target_name = target_folder.name
|
||||
try:
|
||||
inferencer = WaterQualityInference(str(target_folder))
|
||||
eval_result = inferencer.evaluate_with_split(
|
||||
data_csv_path=training_csv_path,
|
||||
split_method="spxy",
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
metric=metric,
|
||||
)
|
||||
predictions = eval_result.get('predictions', {})
|
||||
if predictions:
|
||||
y_train_true = predictions.get('y_train_true')
|
||||
y_train_pred = predictions.get('y_train_pred')
|
||||
y_test_true = predictions.get('y_test_true')
|
||||
y_test_pred = predictions.get('y_test_pred')
|
||||
metrics = eval_result.get('test_metrics', {})
|
||||
if y_train_true is not None and y_test_true is not None:
|
||||
y_all_true = np.concatenate([y_train_true, y_test_true])
|
||||
y_all_pred = np.concatenate([y_train_pred, y_test_pred])
|
||||
train_indices = np.arange(len(y_train_true))
|
||||
test_indices = np.arange(len(y_train_true), len(y_all_true))
|
||||
scatter_path = context.visualizer.plot_scatter_true_vs_pred(
|
||||
y_true=y_all_true,
|
||||
y_pred=y_all_pred,
|
||||
target_name=target_name,
|
||||
train_indices=train_indices,
|
||||
test_indices=test_indices,
|
||||
metrics={
|
||||
'train_r2': eval_result.get('train_metrics', {}).get('r2', 0),
|
||||
'test_r2': metrics.get('r2', 0),
|
||||
'train_rmse': eval_result.get('train_metrics', {}).get('rmse', 0),
|
||||
'test_rmse': metrics.get('rmse', 0),
|
||||
}
|
||||
)
|
||||
scatter_paths[target_name] = scatter_path
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return scatter_paths
|
||||
|
||||
# ── 箱型图 ──
|
||||
|
||||
def _generate_boxplots(self, context: PipelineContext,
|
||||
boxplot_config: dict) -> Dict[str, str]:
|
||||
csv_path = context.processed_csv_path
|
||||
parameter_columns = boxplot_config.get('parameter_columns')
|
||||
data_start_column = boxplot_config.get('data_start_column', 4)
|
||||
save_individual = boxplot_config.get('save_individual', True)
|
||||
use_seaborn = boxplot_config.get('use_seaborn', True)
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
if parameter_columns is None:
|
||||
data_columns = df.iloc[:, data_start_column:]
|
||||
parameter_columns = list(data_columns.columns)
|
||||
else:
|
||||
parameter_columns = [col for col in parameter_columns if col in df.columns]
|
||||
|
||||
if not parameter_columns:
|
||||
return {}
|
||||
|
||||
boxplot_dir = context.visualization_dir / "boxplots"
|
||||
boxplot_dir.mkdir(parents=True, exist_ok=True)
|
||||
boxplot_paths = {}
|
||||
|
||||
if save_individual:
|
||||
for column in parameter_columns:
|
||||
if column not in df.columns:
|
||||
continue
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) == 0:
|
||||
continue
|
||||
try:
|
||||
plt.figure(figsize=(8, 6))
|
||||
if use_seaborn:
|
||||
plot_data = pd.DataFrame({'参数': [column] * len(clean_data), '数值': clean_data})
|
||||
sns.boxplot(data=plot_data, x='参数', y='数值', palette='Set2')
|
||||
sns.stripplot(data=plot_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=5, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot([clean_data], labels=[column],
|
||||
patch_artist=True, showfliers=False)
|
||||
box_plot['boxes'][0].set_facecolor('lightblue')
|
||||
box_plot['boxes'][0].set_alpha(0.7)
|
||||
x_pos = np.random.normal(1, 0.04, size=len(clean_data))
|
||||
plt.scatter(x_pos, clean_data, alpha=0.6, s=30, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
plt.title(f'{column} - 箱型图', fontsize=14, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
stats_text = (f'数据点数: {len(clean_data)}\n'
|
||||
f'均值: {clean_data.mean():.2f}\n'
|
||||
f'中位数: {clean_data.median():.2f}\n'
|
||||
f'标准差: {clean_data.std():.2f}')
|
||||
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round',
|
||||
facecolor='wheat' if not use_seaborn else 'lightgreen',
|
||||
alpha=0.8))
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
safe_name = column.replace('/', '_').replace('\\', '_').replace(':', '_')
|
||||
save_path = boxplot_dir / f'{safe_name}_boxplot.png'
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
boxplot_paths[column] = str(save_path)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 综合箱型图
|
||||
try:
|
||||
plt.figure(figsize=(max(12, len(parameter_columns) * 0.8), 8))
|
||||
box_data = []
|
||||
labels = []
|
||||
for column in parameter_columns:
|
||||
if column in df.columns:
|
||||
clean_data = df[column].dropna()
|
||||
if len(clean_data) > 0:
|
||||
box_data.append(clean_data)
|
||||
labels.append(column)
|
||||
if box_data:
|
||||
if use_seaborn:
|
||||
melted_data = pd.melt(df[labels], var_name='参数', value_name='数值')
|
||||
melted_data = melted_data.dropna()
|
||||
sns.boxplot(data=melted_data, x='参数', y='数值', palette='Set3')
|
||||
sns.stripplot(data=melted_data, x='参数', y='数值',
|
||||
color='red', alpha=0.6, size=4, jitter=True)
|
||||
else:
|
||||
box_plot = plt.boxplot(box_data, labels=labels, patch_artist=True, showfliers=False)
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(box_data)))
|
||||
for patch, color in zip(box_plot['boxes'], colors):
|
||||
patch.set_facecolor(color)
|
||||
patch.set_alpha(0.7)
|
||||
for i, data in enumerate(box_data):
|
||||
x_pos = np.random.normal(i + 1, 0.04, size=len(data))
|
||||
plt.scatter(x_pos, data, alpha=0.6, s=20, color='red',
|
||||
edgecolors='black', linewidth=0.5, zorder=3)
|
||||
plt.title('水质参数箱型图(综合)', fontsize=16, fontweight='bold')
|
||||
plt.xlabel('参数', fontsize=12)
|
||||
plt.ylabel('数值', fontsize=12)
|
||||
plt.xticks(rotation=45, ha='right')
|
||||
plt.grid(True, alpha=0.3, linestyle='--')
|
||||
plt.tight_layout()
|
||||
combined_path = boxplot_dir / 'all_parameters_boxplot.png'
|
||||
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
boxplot_paths['all_parameters'] = str(combined_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return boxplot_paths
|
||||
|
||||
# ── 光谱曲线 ──
|
||||
|
||||
def _generate_spectrum_plots(self, context: PipelineContext,
|
||||
config: dict) -> Dict[str, str]:
|
||||
csv_path = context.training_csv_path
|
||||
wavelength_start_column = config.get('feature_start_column', 'UTM_Y')
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
if isinstance(wavelength_start_column, str):
|
||||
try:
|
||||
wavelength_start_idx = df.columns.get_loc(wavelength_start_column)
|
||||
except KeyError:
|
||||
wavelength_start_idx = 13
|
||||
else:
|
||||
wavelength_start_idx = wavelength_start_column
|
||||
|
||||
parameter_columns = list(df.columns[:wavelength_start_idx])
|
||||
if len(parameter_columns) > 2:
|
||||
parameter_columns = parameter_columns[2:]
|
||||
|
||||
spectrum_paths = {}
|
||||
for param_col in parameter_columns:
|
||||
if param_col not in df.columns:
|
||||
continue
|
||||
try:
|
||||
spectrum_path = context.visualizer.plot_spectrum_by_parameter(
|
||||
csv_path=csv_path,
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wavelength_start_column,
|
||||
n_groups=5,
|
||||
)
|
||||
spectrum_paths[param_col] = spectrum_path
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return spectrum_paths
|
||||
|
||||
# ── 统计图表 ──
|
||||
|
||||
def _generate_statistics(self, context: PipelineContext) -> Dict[str, str]:
|
||||
csv_path = context.processed_csv_path
|
||||
df = pd.read_csv(csv_path)
|
||||
parameter_columns = list(df.columns[2:])
|
||||
parameter_columns = [col for col in parameter_columns
|
||||
if df[col].dtype in [np.float64, np.int64]]
|
||||
|
||||
return context.visualizer.plot_statistical_charts(
|
||||
csv_path=csv_path,
|
||||
parameter_columns=parameter_columns,
|
||||
)
|
||||
142
src/core/handlers/step14_report.py
Normal file
@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step14 处理器:报告生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.generate_pipeline_report() 方法
|
||||
剥离为独立的 Step14ReportHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
|
||||
|
||||
class Step14ReportHandler(BaseStepHandler):
|
||||
"""步骤14:流程执行报告生成。
|
||||
|
||||
对应 config key: 'step14_report'
|
||||
生成 CSV 和 TXT 格式的流程执行报告。
|
||||
"""
|
||||
|
||||
step_key = 'step14_report'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
try:
|
||||
output_path = config.get('output_path')
|
||||
if output_path is None:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
output_path = str(context.reports_dir / f"pipeline_report_{timestamp}.csv")
|
||||
|
||||
report_data = []
|
||||
total_time = 0.0
|
||||
|
||||
step_order = [
|
||||
"步骤1: 水域掩膜生成",
|
||||
"步骤2: 耀斑区域检测",
|
||||
"步骤3: 耀斑去除",
|
||||
"步骤4: 数据预处理",
|
||||
"步骤5: 光谱提取",
|
||||
"步骤6: 水质光谱指数计算",
|
||||
"步骤7: 机器学习建模与训练",
|
||||
"步骤8: 非经验模型训练",
|
||||
"步骤9: 自定义回归",
|
||||
"步骤10: 采样点生成",
|
||||
"步骤11: 参数预测",
|
||||
"步骤12: 分布图生成",
|
||||
]
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in context.step_timings:
|
||||
timing_info = context.step_timings[step_name]
|
||||
report_data.append({
|
||||
'步骤': step_name,
|
||||
'开始时间': timing_info['start_time'],
|
||||
'结束时间': timing_info['end_time'],
|
||||
'耗时(秒)': f"{timing_info['elapsed_seconds']:.2f}",
|
||||
'耗时(格式化)': timing_info['elapsed_formatted'],
|
||||
'状态': timing_info['status'],
|
||||
'错误信息': timing_info.get('error', '')
|
||||
})
|
||||
if timing_info['status'] == 'completed':
|
||||
total_time += timing_info['elapsed_seconds']
|
||||
|
||||
if context.pipeline_start_time and context.pipeline_end_time:
|
||||
pipeline_total = context.pipeline_end_time - context.pipeline_start_time
|
||||
report_data.append({
|
||||
'步骤': '总计',
|
||||
'开始时间': datetime.fromtimestamp(context.pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'结束时间': datetime.fromtimestamp(context.pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'耗时(秒)': f"{pipeline_total:.2f}",
|
||||
'耗时(格式化)': context._format_time(pipeline_total),
|
||||
'状态': 'completed',
|
||||
'错误信息': ''
|
||||
})
|
||||
|
||||
df_report = pd.DataFrame(report_data)
|
||||
df_report.to_csv(output_path, index=False, encoding='utf-8-sig')
|
||||
|
||||
txt_output_path = str(Path(output_path).with_suffix('.txt'))
|
||||
with open(txt_output_path, 'w', encoding='utf-8') as f:
|
||||
f.write("=" * 80 + "\n")
|
||||
f.write("水质参数反演流程执行报告\n")
|
||||
f.write("=" * 80 + "\n\n")
|
||||
|
||||
if context.pipeline_start_time and context.pipeline_end_time:
|
||||
f.write(f"流程开始时间: {datetime.fromtimestamp(context.pipeline_start_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"流程结束时间: {datetime.fromtimestamp(context.pipeline_end_time).strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"总耗时: {context._format_time(context.pipeline_end_time - context.pipeline_start_time)}\n\n")
|
||||
|
||||
f.write("-" * 80 + "\n")
|
||||
f.write("各步骤执行详情:\n")
|
||||
f.write("-" * 80 + "\n\n")
|
||||
|
||||
for step_name in step_order:
|
||||
if step_name in context.step_timings:
|
||||
timing_info = context.step_timings[step_name]
|
||||
f.write(f"{step_name}\n")
|
||||
f.write(f" 开始时间: {timing_info['start_time']}\n")
|
||||
f.write(f" 结束时间: {timing_info['end_time']}\n")
|
||||
f.write(f" 耗时: {timing_info['elapsed_formatted']} ({timing_info['elapsed_seconds']:.2f}秒)\n")
|
||||
f.write(f" 状态: {timing_info['status']}\n")
|
||||
if timing_info.get('error'):
|
||||
f.write(f" 错误: {timing_info['error']}\n")
|
||||
f.write("\n")
|
||||
|
||||
f.write("-" * 80 + "\n")
|
||||
f.write("统计摘要:\n")
|
||||
f.write("-" * 80 + "\n")
|
||||
completed_steps = [s for s in context.step_timings.values() if s['status'] == 'completed']
|
||||
failed_steps = [s for s in context.step_timings.values() if s['status'] == 'failed']
|
||||
skipped_steps = [s for s in context.step_timings.values() if s['status'] == 'skipped']
|
||||
f.write(f"成功完成的步骤: {len(completed_steps)}\n")
|
||||
f.write(f"失败的步骤: {len(failed_steps)}\n")
|
||||
f.write(f"跳过的步骤: {len(skipped_steps)}\n")
|
||||
if completed_steps:
|
||||
completed_times = [s['elapsed_seconds'] for s in completed_steps]
|
||||
f.write(f"平均耗时: {context._format_time(np.mean(completed_times))}\n")
|
||||
f.write(f"最长耗时: {context._format_time(np.max(completed_times))}\n")
|
||||
f.write(f"最短耗时: {context._format_time(np.min(completed_times))}\n")
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤14: 报告生成", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'report_csv': output_path, 'report_txt': txt_output_path}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤14: 报告生成", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
83
src/core/handlers/step1_water_mask.py
Normal file
@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step1 处理器:水域掩膜生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.step1_generate_water_mask() 方法
|
||||
剥离为独立的 Step1WaterMaskHandler。
|
||||
|
||||
这是 14 个步骤 Handler 的**打样模板**,其余步骤照此模式拆分:
|
||||
1. 继承 BaseStepHandler,设置 step_key 类属性
|
||||
2. 实现 execute(ctx, config) → 调用对应 Step 类的静态方法
|
||||
3. 将输出路径写入 ctx(上下文共享)
|
||||
4. 记录步骤耗时
|
||||
5. 返回结果字典
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.water_mask_step import WaterMaskStep
|
||||
|
||||
|
||||
class Step1WaterMaskHandler(BaseStepHandler):
|
||||
"""步骤1:水域掩膜生成。
|
||||
|
||||
对应 config key: 'step1'
|
||||
委托类: WaterMaskStep.run()
|
||||
|
||||
用法::
|
||||
|
||||
handler = Step1WaterMaskHandler()
|
||||
result = handler.execute(ctx, config['step1'])
|
||||
# ctx.water_mask_path 已被更新
|
||||
"""
|
||||
|
||||
step_key = 'step1'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
"""执行水域掩膜生成。
|
||||
|
||||
config 可包含的键(全部透传给 WaterMaskStep.run()):
|
||||
- mask_path: 水体掩膜文件路径(.shp / .dat / .tif)
|
||||
- img_path: 输入影像路径(shp 栅格化或 NDWI 时需要)
|
||||
- ndwi_threshold: NDWI 阈值(默认 0.4)
|
||||
- use_ndwi: 是否使用 NDWI 方法(默认 False)
|
||||
- generate_png: 是否生成 PNG 预览(默认 True)
|
||||
- output_path: 指定输出路径(可选)
|
||||
|
||||
Returns:
|
||||
{'water_mask_path': str}
|
||||
"""
|
||||
step_start_time = time.time()
|
||||
|
||||
try:
|
||||
result = WaterMaskStep.run(
|
||||
mask_path=config.get('mask_path'),
|
||||
img_path=config.get('img_path'),
|
||||
ndwi_threshold=config.get('ndwi_threshold', 0.4),
|
||||
use_ndwi=config.get('use_ndwi', False),
|
||||
generate_png=config.get('generate_png', True),
|
||||
output_path=config.get('output_path'),
|
||||
water_mask_dir=str(context.water_mask_dir),
|
||||
callback=context.notify,
|
||||
)
|
||||
|
||||
# 将输出路径写入上下文,供后续步骤使用
|
||||
context.water_mask_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤1: 水域掩膜生成", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'water_mask_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤1: 水域掩膜生成", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
67
src/core/handlers/step2_glint_detection.py
Normal file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step2 处理器:耀斑区域检测
|
||||
|
||||
将原 WaterQualityInversionPipeline.step2_find_glint_area() 方法
|
||||
剥离为独立的 Step2GlintDetectionHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||
|
||||
|
||||
class Step2GlintDetectionHandler(BaseStepHandler):
|
||||
"""步骤2:耀斑区域检测。
|
||||
|
||||
对应 config key: 'step2'
|
||||
委托类: GlintDetectionStep.run()
|
||||
"""
|
||||
|
||||
step_key = 'step2'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
water_mask_path = self._resolve_path(
|
||||
config.get('water_mask_path'), context.water_mask_path, 'water_mask'
|
||||
)
|
||||
|
||||
try:
|
||||
result = GlintDetectionStep.run(
|
||||
img_path=config.get('img_path'),
|
||||
glint_wave=config.get('glint_wave', 750.0),
|
||||
method=config.get('method', 'otsu'),
|
||||
z_threshold=config.get('z_threshold', 2.5),
|
||||
percentile=config.get('percentile', 95.0),
|
||||
iqr_multiplier=config.get('iqr_multiplier', 1.5),
|
||||
window_size=config.get('window_size', 15),
|
||||
multi_band_waves=config.get('multi_band_waves'),
|
||||
sub_method=config.get('sub_method', 'zscore'),
|
||||
weights=config.get('weights'),
|
||||
max_area=config.get('max_area'),
|
||||
buffer_size=config.get('buffer_size'),
|
||||
water_mask_path=water_mask_path,
|
||||
glint_dir=str(context.glint_dir),
|
||||
callback=context.notify,
|
||||
)
|
||||
|
||||
context.glint_mask_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤2: 耀斑区域检测", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'glint_mask_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤2: 耀斑区域检测", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
85
src/core/handlers/step3_glint_removal.py
Normal file
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step3 处理器:耀斑去除
|
||||
|
||||
将原 WaterQualityInversionPipeline.step3_remove_glint() 方法
|
||||
剥离为独立的 Step3GlintRemovalHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||
|
||||
|
||||
class Step3GlintRemovalHandler(BaseStepHandler):
|
||||
"""步骤3:耀斑去除。
|
||||
|
||||
对应 config key: 'step3'
|
||||
委托类: GlintRemovalStep.run()
|
||||
"""
|
||||
|
||||
step_key = 'step3'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
water_mask_path = self._resolve_path(
|
||||
config.get('water_mask_path'), context.water_mask_path, 'water_mask'
|
||||
)
|
||||
|
||||
try:
|
||||
result = GlintRemovalStep.run(
|
||||
img_path=config.get('img_path'),
|
||||
method=config.get('method', 'subtract_nir'),
|
||||
start_wave=config.get('start_wave'),
|
||||
end_wave=config.get('end_wave'),
|
||||
json_path=config.get('json_path'),
|
||||
left_shoulder_wave=config.get('left_shoulder_wave'),
|
||||
valley_wave=config.get('valley_wave'),
|
||||
right_shoulder_wave=config.get('right_shoulder_wave'),
|
||||
water_mask=water_mask_path,
|
||||
interpolate_zeros=config.get('interpolate_zeros', False),
|
||||
interpolation_method=config.get('interpolation_method', 'nearest'),
|
||||
enabled=config.get('enabled', True),
|
||||
kutser_shp_path=config.get('kutser_shp_path'),
|
||||
oxy_band=config.get('oxy_band', 38),
|
||||
lower_oxy=config.get('lower_oxy', 36),
|
||||
upper_oxy=config.get('upper_oxy', 49),
|
||||
nir_band=config.get('nir_band', 47),
|
||||
nir_lower=config.get('nir_lower', 25),
|
||||
nir_upper=config.get('nir_upper', 37),
|
||||
goodman_A=config.get('goodman_A', 0.000019),
|
||||
goodman_B=config.get('goodman_B', 0.1),
|
||||
hedley_shp_path=config.get('hedley_shp_path'),
|
||||
hedley_nir_band=config.get('hedley_nir_band', 47),
|
||||
sugar_bounds=config.get('sugar_bounds'),
|
||||
sugar_sigma=config.get('sugar_sigma', 1.0),
|
||||
sugar_estimate_background=config.get('sugar_estimate_background', True),
|
||||
sugar_glint_mask_method=config.get('sugar_glint_mask_method', 'cdf'),
|
||||
sugar_iter=config.get('sugar_iter', 3),
|
||||
sugar_termination_thresh=config.get('sugar_termination_thresh', 20.0),
|
||||
deglint_dir=str(context.deglint_dir),
|
||||
water_mask_dir=str(context.water_mask_dir),
|
||||
callback=context.notify,
|
||||
output_path=config.get('output_path'),
|
||||
)
|
||||
|
||||
context.deglint_img_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤3: 耀斑去除", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'deglint_img_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤3: 耀斑去除", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
64
src/core/handlers/step4_sampling.py
Normal file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step4 处理器:预测采样点生成
|
||||
|
||||
将原 WaterQualityInversionPipeline.step4_sampling() 方法
|
||||
剥离为独立的 Step4SamplingHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
|
||||
|
||||
class Step4SamplingHandler(BaseStepHandler):
|
||||
"""步骤4:生成预测采样点并提取光谱。
|
||||
|
||||
对应 config key: 'step4_sampling'
|
||||
委托类: PredictionStep.generate_sampling_points()
|
||||
"""
|
||||
|
||||
step_key = 'step4_sampling'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
deglint_img_path = self._resolve_path(
|
||||
config.get('deglint_img_path'), context.deglint_img_path, 'deglint_img'
|
||||
)
|
||||
water_mask_path = self._resolve_path(
|
||||
config.get('water_mask_path'), context.water_mask_path, 'water_mask'
|
||||
)
|
||||
glint_mask_path = self._resolve_path(
|
||||
config.get('glint_mask_path'), context.glint_mask_path, 'glint_mask'
|
||||
)
|
||||
|
||||
try:
|
||||
result = PredictionStep.generate_sampling_points(
|
||||
deglint_img_path=deglint_img_path,
|
||||
interval=config.get('interval', 50),
|
||||
sample_radius=config.get('sample_radius', 5),
|
||||
chunk_size=config.get('chunk_size', 1000),
|
||||
water_mask_path=water_mask_path,
|
||||
glint_mask_path=glint_mask_path,
|
||||
output_dir=str(context.sampling_dir),
|
||||
use_adaptive_sampling=config.get('use_adaptive_sampling', True),
|
||||
)
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤4: 生成预测采样点", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'sampling_csv_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤4: 生成预测采样点", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
50
src/core/handlers/step5_process_csv.py
Normal file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step5 处理器:CSV 数据处理
|
||||
|
||||
将原 WaterQualityInversionPipeline.step5_process_csv() 方法
|
||||
剥离为独立的 Step5ProcessCsvHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
|
||||
|
||||
class Step5ProcessCsvHandler(BaseStepHandler):
|
||||
"""步骤5:处理 CSV 文件,筛选剔除异常值。
|
||||
|
||||
对应 config key: 'step5_clean'
|
||||
委托类: DataPreparationStep.process_csv()
|
||||
"""
|
||||
|
||||
step_key = 'step5_clean'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
try:
|
||||
result = DataPreparationStep.process_csv(
|
||||
csv_path=config.get('csv_path'),
|
||||
output_dir=str(context.processed_data_dir),
|
||||
)
|
||||
|
||||
context.processed_csv_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤5: 处理CSV文件", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'processed_csv_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤5: 处理CSV文件", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
66
src/core/handlers/step6_extract_spectra.py
Normal file
@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step6 处理器:训练样本点光谱提取
|
||||
|
||||
将原 WaterQualityInversionPipeline.step6_extract_spectra() 方法
|
||||
剥离为独立的 Step6ExtractSpectraHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
|
||||
|
||||
class Step6ExtractSpectraHandler(BaseStepHandler):
|
||||
"""步骤6:根据采样点坐标在去耀斑影像中提取平均光谱。
|
||||
|
||||
对应 config key: 'step6_feature'
|
||||
委托类: DataPreparationStep.extract_training_spectra()
|
||||
"""
|
||||
|
||||
step_key = 'step6_feature'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
deglint_img_path = self._resolve_path(
|
||||
config.get('deglint_img_path'), context.deglint_img_path, 'deglint_img'
|
||||
)
|
||||
csv_path = self._resolve_path(
|
||||
config.get('csv_path'), context.processed_csv_path, 'csv'
|
||||
)
|
||||
glint_mask_path = self._resolve_path(
|
||||
config.get('glint_mask_path'), context.glint_mask_path, 'glint_mask'
|
||||
)
|
||||
|
||||
try:
|
||||
result = DataPreparationStep.extract_training_spectra(
|
||||
deglint_img_path=deglint_img_path,
|
||||
radius=config.get('radius', 5),
|
||||
source_epsg=config.get('source_epsg', 4326),
|
||||
csv_path=csv_path,
|
||||
boundary_path=config.get('boundary_path'),
|
||||
glint_mask_path=glint_mask_path,
|
||||
water_mask_path=context.water_mask_path,
|
||||
output_dir=str(context.training_spectra_dir),
|
||||
)
|
||||
|
||||
context.training_csv_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤6: 提取训练样本点光谱", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'training_csv_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤6: 提取训练样本点光谱", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
58
src/core/handlers/step7_calc_indices.py
Normal file
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step7 处理器:水质光谱指数计算
|
||||
|
||||
将原 WaterQualityInversionPipeline.step7_calc_indices() 方法
|
||||
剥离为独立的 Step7CalcIndicesHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
|
||||
|
||||
class Step7CalcIndicesHandler(BaseStepHandler):
|
||||
"""步骤7:根据训练光谱计算水质光谱指数。
|
||||
|
||||
对应 config key: 'step7_index'
|
||||
委托类: DataPreparationStep.calculate_water_quality_indices()
|
||||
"""
|
||||
|
||||
step_key = 'step7_index'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
training_csv_path = self._resolve_path(
|
||||
config.get('training_csv_path'), context.training_csv_path, 'training_csv'
|
||||
)
|
||||
|
||||
try:
|
||||
result = DataPreparationStep.calculate_water_quality_indices(
|
||||
training_csv_path=training_csv_path,
|
||||
formula_csv_file=config.get('formula_csv_file'),
|
||||
formula_names=config.get('formula_names'),
|
||||
output_file=config.get('output_file'),
|
||||
enabled=config.get('enabled', True),
|
||||
output_dir=str(context.indices_dir),
|
||||
)
|
||||
|
||||
context.indices_path = result
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤7: 计算水质光谱指数", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'indices_path': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤7: 计算水质光谱指数", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
58
src/core/handlers/step8_ml_train.py
Normal file
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step8 处理器:机器学习建模与训练
|
||||
|
||||
将原 WaterQualityInversionPipeline.step8_train_ml() 方法
|
||||
剥离为独立的 Step8MlTrainHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.modeling_step import ModelingStep
|
||||
|
||||
|
||||
class Step8MlTrainHandler(BaseStepHandler):
|
||||
"""步骤8:机器学习建模与训练。
|
||||
|
||||
对应 config key: 'step8_ml_train'
|
||||
委托类: ModelingStep.train_models()
|
||||
"""
|
||||
|
||||
step_key = 'step8_ml_train'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
training_csv_path = self._resolve_path(
|
||||
config.get('training_csv_path'), context.training_csv_path, 'training_csv'
|
||||
)
|
||||
|
||||
try:
|
||||
result = ModelingStep.train_models(
|
||||
feature_start_column=config.get('feature_start_column', '374.285004'),
|
||||
preprocessing_methods=config.get('preprocessing_methods'),
|
||||
model_names=config.get('model_names'),
|
||||
split_methods=config.get('split_methods'),
|
||||
cv_folds=config.get('cv_folds', 5),
|
||||
training_csv_path=training_csv_path,
|
||||
output_dir=str(context.models_dir),
|
||||
_report_generator=context.report_generator,
|
||||
)
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤8: 机器学习建模与训练", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'models_dir': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤8: 机器学习建模与训练", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
64
src/core/handlers/step9_ml_predict.py
Normal file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Step9 处理器:机器学习推理预测
|
||||
|
||||
将原 WaterQualityInversionPipeline.step9_predict_ml() 方法
|
||||
剥离为独立的 Step9MlPredictHandler。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.core.handlers.base import BaseStepHandler, PipelineContext
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
|
||||
|
||||
class Step9MlPredictHandler(BaseStepHandler):
|
||||
"""步骤9:机器学习推理预测。
|
||||
|
||||
对应 config key: 'step9_ml_predict'
|
||||
委托类: PredictionStep.predict_water_quality()
|
||||
"""
|
||||
|
||||
step_key = 'step9_ml_predict'
|
||||
|
||||
def execute(self, context: PipelineContext, config: dict) -> Dict[str, Any]:
|
||||
step_start_time = time.time()
|
||||
|
||||
sampling_csv_path = self._resolve_path(
|
||||
config.get('sampling_csv_path'), context.sampling_csv_path, 'sampling_csv'
|
||||
)
|
||||
|
||||
models_dir = config.get('models_dir') or str(context.models_dir)
|
||||
|
||||
try:
|
||||
result = PredictionStep.predict_water_quality(
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
models_dir=models_dir,
|
||||
metric=config.get('metric', 'test_r2'),
|
||||
prediction_column=config.get('prediction_column', 'prediction'),
|
||||
output_dir=str(context.prediction_dir / "9_ML_Prediction"),
|
||||
_report_generator=context.report_generator,
|
||||
_external_model=config.get('_external_model'),
|
||||
_external_model_path=config.get('_external_model_path'),
|
||||
_external_models_dict=config.get('_external_models_dict'),
|
||||
_external_model_dir=config.get('_external_model_dir'),
|
||||
)
|
||||
|
||||
context.prediction_files.update(result)
|
||||
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤9: 机器学习推理预测", step_start_time, step_end_time
|
||||
)
|
||||
|
||||
return {'prediction_files': result}
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
context.record_step_time(
|
||||
"步骤9: 机器学习推理预测", step_start_time, step_end_time,
|
||||
status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
@ -13,7 +13,7 @@ from sklearn.svm import SVR
|
||||
from sklearn.ensemble import RandomForestRegressor
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
|
||||
from sklearn.model_selection import GridSearchCV, cross_val_score, KFold, train_test_split
|
||||
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, KFold, train_test_split
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from sklearn.ensemble import GradientBoostingRegressor, AdaBoostRegressor, ExtraTreesRegressor
|
||||
@ -40,7 +40,12 @@ CB_AVAILABLE = False # 注释掉catboost
|
||||
import sys
|
||||
import os
|
||||
|
||||
# PyInstaller 打包环境感知:EXE 模式下强制单核,防止 Windows 派生无限重启
|
||||
is_frozen_env = getattr(sys, 'frozen', False)
|
||||
safe_n_jobs = 1 if is_frozen_env else -1
|
||||
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
from src.core.utils.split_methods import spxy, ks
|
||||
|
||||
|
||||
class WaterQualityModelingBatch:
|
||||
@ -284,11 +289,24 @@ class WaterQualityModelingBatch:
|
||||
# 提取所有目标列(从0列到feature_start_index-1列)
|
||||
y_dict = {}
|
||||
target_columns = data.columns[:feature_start_index]
|
||||
|
||||
print(f"检测到的目标列: {list(target_columns)}")
|
||||
|
||||
print(f"检测到的潜在目标列: {list(target_columns)}")
|
||||
|
||||
# 新增:跳过非预测目标的系统保留列
|
||||
ignore_cols = {'ID', 'id', 'Id', 'Longitude', 'Latitude', 'Lon', 'Lat', 'longitude', 'latitude', 'lon', 'lat', 'Station', 'station'}
|
||||
|
||||
for col_name in target_columns:
|
||||
# 过滤黑名单列
|
||||
if col_name in ignore_cols:
|
||||
print(f" 跳过目标列 '{col_name}': 属于系统保留列或空间坐标")
|
||||
continue
|
||||
|
||||
y_series = data[col_name]
|
||||
|
||||
# 过滤非数值类型列 (避免将纯文本备注等拿去回归)
|
||||
if not pd.api.types.is_numeric_dtype(y_series):
|
||||
print(f" 跳过目标列 '{col_name}': 非数值类型")
|
||||
continue
|
||||
|
||||
# 检查是否有非空值
|
||||
if not y_series.isna().all():
|
||||
y_dict[col_name] = y_series
|
||||
@ -403,159 +421,12 @@ class WaterQualityModelingBatch:
|
||||
return X_train, X_test, y_train, y_test
|
||||
|
||||
def spxy(self, data, label, test_size=0.2):
|
||||
"""
|
||||
SPXY算法划分数据集(考虑X和Y空间的距离)
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features)
|
||||
label: shape (n_samples, )
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
# 备份原始数据和标签
|
||||
x_backup = data
|
||||
y_backup = label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
# 归一化标签数据
|
||||
label = (label - np.mean(label)) / np.std(label)
|
||||
D = np.zeros((M, M))
|
||||
Dy = np.zeros((M, M))
|
||||
|
||||
# 计算样本之间的距离
|
||||
for i in range(M - 1):
|
||||
xa = data[i, :]
|
||||
ya = label[i]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
yb = label[j]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
Dy[i, j] = np.linalg.norm(ya - yb)
|
||||
|
||||
# 距离归一化
|
||||
Dmax = np.max(D)
|
||||
Dymax = np.max(Dy)
|
||||
D = D / Dmax + Dy / Dymax
|
||||
|
||||
# 找到最远的两个点
|
||||
maxD = D.max(axis=0)
|
||||
index_row = D.argmax(axis=0)
|
||||
index_column = maxD.argmax()
|
||||
|
||||
m = np.zeros(N, dtype=int)
|
||||
m[0] = index_row[index_column]
|
||||
m[1] = index_column
|
||||
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
# 根据距离选择训练集
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros(M - i)
|
||||
for j in range(M - i):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(samples, m)
|
||||
|
||||
# 划分训练集和测试集
|
||||
X_train = data[m, :]
|
||||
y_train = y_backup[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = y_backup[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""SPXY算法划分数据集(委托至 src.core.utils.split_methods.spxy)"""
|
||||
return spxy(data, label, test_size=test_size)
|
||||
|
||||
def ks(self, data, label, test_size=0.2):
|
||||
"""
|
||||
Kennard-Stone算法划分数据集
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features)
|
||||
label: shape (n_sample, )
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
D = np.zeros((M, M))
|
||||
|
||||
for i in range((M - 1)):
|
||||
xa = data[i, :]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
|
||||
maxD = np.max(D, axis=0)
|
||||
index_row = np.argmax(D, axis=0)
|
||||
index_column = np.argmax(maxD)
|
||||
|
||||
m = np.zeros(N)
|
||||
m[0] = np.array(index_row[index_column])
|
||||
m[1] = np.array(index_column)
|
||||
m = m.astype(int)
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros((M - i))
|
||||
for j in range((M - i)):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(np.arange(data.shape[0]), m)
|
||||
|
||||
X_train = data[m, :]
|
||||
y_train = label[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = label[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""Kennard-Stone算法划分数据集(委托至 src.core.utils.split_methods.ks)"""
|
||||
return ks(data, label, test_size=test_size)
|
||||
|
||||
def split_data(self, X: np.ndarray, y: pd.Series, method: str = "random",
|
||||
test_size: float = 0.2, random_state: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
@ -635,23 +506,22 @@ class WaterQualityModelingBatch:
|
||||
elif model_name == 'LightGBM':
|
||||
base_model.set_params(verbose=-1)
|
||||
|
||||
# 网格搜索 - 使用KFold代替StratifiedKFold
|
||||
# 随机搜索 —— 替代穷举式 GridSearchCV,大幅降低寻优时间
|
||||
cv_strategy = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
|
||||
grid_search = GridSearchCV(
|
||||
grid_search = RandomizedSearchCV(
|
||||
base_model,
|
||||
config['params'],
|
||||
n_iter=10,
|
||||
cv=cv_strategy,
|
||||
scoring=scoring,
|
||||
n_jobs=-1,
|
||||
verbose=1
|
||||
n_jobs=safe_n_jobs,
|
||||
random_state=random_state,
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
# 在训练集上训练模型
|
||||
# with parallel_backend("threading", n_jobs=-1):
|
||||
# grid_search.fit(X_train, y_train)
|
||||
grid_search.fit(X_train, y_train)
|
||||
|
||||
|
||||
# 获取最佳模型
|
||||
best_model = grid_search.best_estimator_
|
||||
|
||||
|
||||
@ -315,7 +315,7 @@ def main():
|
||||
|
||||
# 示例1: 使用所有回归方法分析光谱指数
|
||||
print("\n1. 光谱指数与叶绿素a的回归分析:")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv")
|
||||
sample_data = pd.read_csv(r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\water_quality_results.csv")
|
||||
spectral_indices = ['Al10SABI','Am092Bsub']
|
||||
|
||||
results1 = analyzer.batch_single_variable_regression(
|
||||
@ -323,7 +323,7 @@ def main():
|
||||
x_columns=spectral_indices,
|
||||
y_column='Chlorophyll',
|
||||
methods='all',
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\spectral_indices_regression.csv'
|
||||
output_file=r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\spectral_indices_regression.csv'
|
||||
)
|
||||
|
||||
# # 示例2: 使用特定方法分析反射率波段
|
||||
@ -343,7 +343,7 @@ def main():
|
||||
best_models = analyzer.get_best_models_summary()
|
||||
if not best_models.empty:
|
||||
print(best_models[['x_variable', 'regression_method', 'r_squared', 'equation']].to_string(index=False))
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\5_training_spectra\best_models_summary.csv', index=False)
|
||||
best_models.to_csv(r'E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\best_models_summary.csv', index=False)
|
||||
print("\n最佳模型汇总已保存到 'best_models_summary.csv'")
|
||||
#
|
||||
# def advanced_usage_example():
|
||||
|
||||
@ -246,8 +246,8 @@ def non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, outp
|
||||
|
||||
if __name__ == "__main__":
|
||||
algorithm= "chl_a"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\10_sampling\sampling_spectra.csv"
|
||||
model_info_path= r"E:\code\WQ\pipeline_result\work_dir\6_Spectral_Feature_Extraction\8_non_empirical_models\SS\SS_chl_a.json"
|
||||
coor_spectral_path= r"E:\code\WQ\pipeline_result\work_dir\4_sampling\sampling_spectra.csv"
|
||||
output_path= r"E:\code\WQ\pipeline_result\work_dir\11_12_13_predictions\SS_chl_a.csv"
|
||||
wave_radius=5.0
|
||||
non_empirical_retrieval(algorithm, model_info_path, coor_spectral_path, output_path, wave_radius)
|
||||
24
src/core/pipeline/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pipeline 调度核心:基于 Context 的内存级依赖注入。
|
||||
|
||||
设计目标:
|
||||
- 用 PipelineContext 替代 dict 散落传参(9 步主路径 + 14 个 step 共享同一份 ctx)
|
||||
- 14 个 step 声明式描述(StepSpec),便于 Web / 异步 / 单元测试复用
|
||||
- 不绑定具体 Pipeline 实现(duck-typed),WorkerThread / Web API / 单测可共用
|
||||
"""
|
||||
|
||||
from .context import (
|
||||
PipelineContext,
|
||||
STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD,
|
||||
resolve_step_id, ALL_STEP_IDS,
|
||||
)
|
||||
from .runner import (
|
||||
StepSpec, PIPELINE_STEPS, PipelineRunner, PipelineHalt,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PipelineContext", "StepSpec", "PIPELINE_STEPS", "PipelineRunner", "PipelineHalt",
|
||||
"STEP_MAP_OLD_TO_NEW", "STEP_MAP_NEW_TO_OLD",
|
||||
"resolve_step_id", "ALL_STEP_IDS",
|
||||
]
|
||||
148
src/core/pipeline/context.py
Normal file
@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
PipelineContext:内存级数据载体,跨 14 个 step 传递路径与元信息。
|
||||
|
||||
设计原则:
|
||||
- 所有路径字段以 `_path` 为后缀(与 step 方法形参命名约定一致)
|
||||
- 字段值可缺省(None),由 StepSpec.requires 在调度时注入
|
||||
- dataclass + field(default_factory=dict) 支持原地增删
|
||||
- 不放 GUI 状态(避免循环依赖)
|
||||
- 不绑具体 step 方法(duck-typed cancellation / log append)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 步骤命名映射(定义在叶子节点,打破循环依赖)
|
||||
# ============================================================
|
||||
|
||||
STEP_MAP_OLD_TO_NEW: Dict[str, str] = {
|
||||
"step5_5": "step7",
|
||||
"step6_5": "step8_non_empirical_modeling",
|
||||
"step6_75": "step9",
|
||||
"step8_5": "step11",
|
||||
"step7": "step8",
|
||||
"step8": "step7",
|
||||
"step9": "step14",
|
||||
"step10": "step4",
|
||||
"step11_ml": "step10",
|
||||
"step11": "step11",
|
||||
}
|
||||
|
||||
STEP_MAP_NEW_TO_OLD: Dict[str, str] = {v: k for k, v in STEP_MAP_OLD_TO_NEW.items()}
|
||||
|
||||
ALL_STEP_IDS: Set[str] = set(STEP_MAP_OLD_TO_NEW.keys()) | set(STEP_MAP_OLD_TO_NEW.values())
|
||||
|
||||
|
||||
def resolve_step_id(step_id: str) -> str:
|
||||
"""将任意 step_id 转换为标准新格式。"""
|
||||
if step_id in STEP_MAP_OLD_TO_NEW:
|
||||
return STEP_MAP_OLD_TO_NEW[step_id]
|
||||
return step_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineContext:
|
||||
"""流水线运行上下文(在 14 个 step 之间传递的内存字典)
|
||||
|
||||
字段命名约定:
|
||||
- 路径类字段名 = panel key 名 = step 形参名(全链路无翻译)
|
||||
- 训练/产物 CSV 用 `_path` 后缀(如 training_csv_path / water_mask_path)
|
||||
- 入参影像/CSV 沿用 panel 原名(img_path / csv_path),无 `_path` 后缀
|
||||
- 目录类字段无 `_path` 后缀(如 models_dir / prediction_dir)
|
||||
- 元信息字段无后缀(如 user_config / status / log)
|
||||
"""
|
||||
|
||||
# ── 11 个 step 的入参/产物(按 step 顺序排列;字段名 = panel key = step 形参) ──
|
||||
img_path: Optional[str] = None # Step 1/2/3 入参:原始影像
|
||||
water_mask_path: Optional[str] = None # Step 1 出 → Step 2/3/7 入
|
||||
glint_mask_path: Optional[str] = None # Step 2 出 → Step 3/7 入
|
||||
deglint_img_path: Optional[str] = None # Step 3 出 → Step 5/7 入
|
||||
csv_path: Optional[str] = None # Step 4/5/6_5/6_75 入参:原始/训练 CSV
|
||||
processed_csv_path: Optional[str] = None # Step 4 出 → Step 5 入
|
||||
training_csv_path: Optional[str] = None # Step 5 出 → Step 5_5/6/6_5/6_75 入
|
||||
boundary_path: Optional[str] = None # Step 5 入参:边界 SHP(panel step5 名)
|
||||
indices_path: Optional[str] = None # Step 5.5 出
|
||||
sampling_csv_path: Optional[str] = None # Step 7 出 → Step 8/8_5/8_75/9 入
|
||||
prediction_csv_path: Optional[str] = None # Step 8 出 → Step 9 入
|
||||
distribution_map_path: Optional[str] = None # Step 9 出
|
||||
boundary_shp_path: Optional[str] = None # Step 9 入参:边界 SHP(panel step9 名)
|
||||
formula_csv_path: Optional[str] = None # Step 8_75 入参:公式 CSV
|
||||
|
||||
# ── 目录类(命名不带 _path 以示区别) ──
|
||||
models_dir: Optional[str] = None
|
||||
prediction_dir: Optional[str] = None
|
||||
work_dir: Optional[str] = None
|
||||
|
||||
# ── Step 6 训练产物(AutoML 模式有,常规模式为空) ──
|
||||
model_files: List[str] = field(default_factory=list)
|
||||
|
||||
# ── 元信息(三件套:用户传的配置 / 取消事件 / 状态) ──
|
||||
user_config: Dict[str, Any] = field(default_factory=dict)
|
||||
cancel_event: Optional[Any] = None # duck-typed threading.Event / asyncio.Event
|
||||
status: Dict[str, str] = field(default_factory=dict) # {step_id: 'start'/'completed'/'skipped'/'error'}
|
||||
log: List[str] = field(default_factory=list)
|
||||
|
||||
# ── 诊断 ──
|
||||
step_timings: Dict[str, float] = field(default_factory=dict)
|
||||
pipeline_start_time: Optional[float] = None
|
||||
pipeline_end_time: Optional[float] = None
|
||||
last_error: Optional[str] = None
|
||||
|
||||
# ── 错误汇总(全流程结束后可用) ──
|
||||
error_summary: List[tuple[str, str]] = field(default_factory=list)
|
||||
# ── 出错时立即停止全流程(默认 False:继续后续步骤) ──
|
||||
breakpoint_on_error: bool = False
|
||||
# ── ★ 智能补全锁定步骤列表(由 _auto_fill_missing_steps 自动开启的步骤) ──
|
||||
# GUI 层读取此字段,在运行期间禁用对应面板的启用复选框
|
||||
locked_steps: List[str] = field(default_factory=list)
|
||||
|
||||
# ============================================================
|
||||
# 读写辅助
|
||||
# ============================================================
|
||||
|
||||
def step_id(self, step_id: str) -> str:
|
||||
"""将任意 step_id(可能是旧名)转换为标准新格式。
|
||||
|
||||
用法示例:
|
||||
ctx.status[ctx.step_id('step6_5')] # 'step8_non_empirical_modeling'
|
||||
ctx.user_config[ctx.step_id('step8_5')] # 'step11'
|
||||
"""
|
||||
if step_id in STEP_MAP_OLD_TO_NEW:
|
||||
return STEP_MAP_OLD_TO_NEW[step_id]
|
||||
return step_id
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""原地写入任意属性。
|
||||
|
||||
允许动态字段(如 'report_path')直接挂在 __dict__ 上,
|
||||
避免因静态字段缺失而抛 AttributeError。
|
||||
"""
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""原地读出,缺 key 不抛错。"""
|
||||
return getattr(self, key, default)
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
"""统一软取消检查入口(duck-typed)。
|
||||
|
||||
支持:
|
||||
- threading.Event(.is_set())
|
||||
- asyncio.Event(loop-bound,is_set 同步接口存在)
|
||||
- 自定义 .is_set() / .cancelled 属性
|
||||
"""
|
||||
ev = self.cancel_event
|
||||
if ev is None:
|
||||
return False
|
||||
is_set = getattr(ev, "is_set", None)
|
||||
if callable(is_set):
|
||||
return bool(is_set())
|
||||
return bool(getattr(ev, "cancelled", False))
|
||||
|
||||
def append_log(self, msg: str) -> None:
|
||||
"""写入日志列表(也用于主进程 stdout 调试)。"""
|
||||
self.log.append(msg)
|
||||
650
src/core/pipeline/runner.py
Normal file
@ -0,0 +1,650 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
PipelineRunner:基于 StepSpec 声明式调度 14 个 step。
|
||||
|
||||
设计要点:
|
||||
- StepSpec 声明 requires(ctx 字段名列表)+ produces(ctx 字段名列表)
|
||||
- 命名约定:ctx 字段名 == panel key 名 == step 形参名(全链路无翻译)
|
||||
- 步骤命名:step_id 格式为 stepN 或 stepN_suffix(无小数位),method_name 与 step_id 对齐
|
||||
- 调度顺序:按 PIPELINE_STEPS 列表顺序,requires 缺则 skip
|
||||
- 软取消:在每个 step 前检查 ctx.is_cancelled()
|
||||
- 断点续跑:spec.output_file 已落盘则跳过执行
|
||||
- 错误汇总:全流程结束后 error_summary 记录所有 step 的异常
|
||||
- 预检:run() 入口硬校验 step1 img_path;其余依赖通过智能补全 + 软警告处理
|
||||
- PipelineHalt:外层 run() 不 catch,触发循环 break,实现硬终止
|
||||
- STEP_MAP:旧 step_id → 新 step_id 双向映射,供 GUI 配置兼容使用
|
||||
- duck-typed pipeline:runner 只调 getattr(pipeline, method_name),不强依赖类层级
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from .context import PipelineContext, STEP_MAP_OLD_TO_NEW, STEP_MAP_NEW_TO_OLD, resolve_step_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 终止异常(外层 run() 不 catch,触发循环 break)
|
||||
# ============================================================
|
||||
|
||||
class PipelineHalt(Exception):
|
||||
"""不可恢复的错误,在 run() 循环中抛出后直接 break,不走 Exception 处理分支。
|
||||
|
||||
适用场景:
|
||||
- GUI 层通过 _notify 弹窗拦截后主动抛出的硬终止信号
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# StepSpec 声明式描述
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class StepSpec:
|
||||
"""单个 step 的元信息(声明式,避免硬编码)"""
|
||||
step_id: str
|
||||
method_name: str
|
||||
requires: List[str] # PipelineContext 字段名列表
|
||||
produces: List[str] = field(default_factory=list) # 写入 ctx 的字段名列表
|
||||
enabled: bool = True
|
||||
parameter_map: Dict[str, str] = field(default_factory=dict)
|
||||
# 当 requires 中任一字段为 None 时是否跳过;默认 True(缺输入就 skip)
|
||||
skip_when_missing: bool = True
|
||||
# 备注(仅用于文档生成 / 调试输出)
|
||||
description: str = ""
|
||||
# ★ 断点续跑:产物文件路径,支持 {work_dir} 占位符(运行时解析)
|
||||
output_file: Optional[str] = None
|
||||
# ★ 预检用:需要验证磁盘文件实际存在的 ctx key 列表
|
||||
required_input_files: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 14 个 step 的声明表(顺序即调度顺序)
|
||||
# step_id / method_name 均不含小数位,与前端显示对齐
|
||||
# output_file / required_input_files 使用 {work_dir} 占位符,由 _resolve_path 展开
|
||||
# ============================================================
|
||||
|
||||
PIPELINE_STEPS: List[StepSpec] = [
|
||||
StepSpec(
|
||||
step_id="step1", method_name="step1_generate_water_mask",
|
||||
requires=["img_path"], produces=["water_mask_path"],
|
||||
required_input_files=["img_path"],
|
||||
output_file="{work_dir}/1_water_mask/water_mask.dat",
|
||||
description="水域掩膜生成(NDWI 或 SHP)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step2", method_name="step2_find_glint_area",
|
||||
requires=["img_path", "water_mask_path"], produces=["glint_mask_path"],
|
||||
required_input_files=["img_path", "water_mask_path"],
|
||||
output_file="{work_dir}/2_Glint_Detection/severe_glint_area.dat",
|
||||
description="耀斑区域检测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step3", method_name="step3_remove_glint",
|
||||
requires=["img_path", "water_mask_path", "glint_mask_path"],
|
||||
produces=["deglint_img_path"],
|
||||
required_input_files=["img_path", "water_mask_path", "glint_mask_path"],
|
||||
output_file="{work_dir}/3_deglint/deglint.bsq",
|
||||
description="耀斑去除",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step4", method_name="step5_process_csv",
|
||||
requires=["csv_path"], produces=["processed_csv_path"],
|
||||
required_input_files=["csv_path"],
|
||||
output_file="{work_dir}/5_Data_Cleaning/processed_data.csv",
|
||||
description="CSV 异常值清洗",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step5", method_name="step6_extract_spectra",
|
||||
requires=["deglint_img_path", "processed_csv_path", "csv_path", "boundary_path", "glint_mask_path"],
|
||||
produces=["training_csv_path"],
|
||||
parameter_map={
|
||||
"processed_csv_path": "csv_path",
|
||||
"csv_path": "_raw_csv_ignored",
|
||||
},
|
||||
skip_when_missing=False,
|
||||
required_input_files=["deglint_img_path", "processed_csv_path", "boundary_path", "glint_mask_path"],
|
||||
output_file="{work_dir}/6_Spectral_Feature_Extraction/training_spectra.csv",
|
||||
description="实测样本点光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step7", method_name="step7_calc_indices",
|
||||
requires=["training_csv_path"], produces=["indices_path", "trad_indices_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/7_Water_Quality_Indices/training_spectra_indices.csv",
|
||||
description="水质参数指数计算(双轨输出:A轨宽表 + B轨单文件)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8", method_name="step8_train_ml",
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/8_Supervised_Model_Training/best_models.pkl",
|
||||
description="ML 建模(GridSearchCV / AutoML)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step8_non_empirical_modeling",
|
||||
method_name="step8_non_empirical_modeling",
|
||||
requires=["training_csv_path"], produces=["models_dir"],
|
||||
parameter_map={"training_csv_path": "csv_path"},
|
||||
required_input_files=["training_csv_path"],
|
||||
output_file="{work_dir}/8_Non_Empirical_Regression/non_empirical_models.pkl",
|
||||
description="非经验统计回归",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step9", method_name="step9_watercolor_inversion",
|
||||
requires=["deglint_img_path", "water_mask_path"], produces=["watercolor_index_dir"],
|
||||
required_input_files=["deglint_img_path"],
|
||||
output_file="{work_dir}/9_WaterColor_Index_Images",
|
||||
description="水色指数反演(BSQ 影像直接处理)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step10", method_name="step4_sampling",
|
||||
requires=["deglint_img_path", "water_mask_path"], produces=["sampling_csv_path"],
|
||||
required_input_files=["deglint_img_path", "water_mask_path"],
|
||||
output_file="{work_dir}/4_sampling/sampling_spectra.csv",
|
||||
description="整景密集采样点生成 + 光谱提取",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step11_ml", method_name="step9_predict_ml",
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_csv_path"],
|
||||
required_input_files=["sampling_csv_path", "models_dir"],
|
||||
output_file="{work_dir}/11_12_13_predictions/prediction_results.csv",
|
||||
description="ML 模型预测(采样点)",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step11", method_name="step11_non_empirical_prediction",
|
||||
requires=["sampling_csv_path", "models_dir"], produces=["prediction_dir"],
|
||||
parameter_map={"models_dir": "non_empirical_models_dir"},
|
||||
required_input_files=["sampling_csv_path", "models_dir"],
|
||||
output_file="{work_dir}/11_12_13_predictions/non_empirical_predictions",
|
||||
description="非经验模型预测",
|
||||
),
|
||||
StepSpec(
|
||||
step_id="step14", method_name="step10_map",
|
||||
requires=["prediction_csv_path", "boundary_shp_path"],
|
||||
produces=["distribution_map_path"],
|
||||
required_input_files=["prediction_csv_path", "boundary_shp_path"],
|
||||
output_file="{work_dir}/distribution_map.png",
|
||||
description="克里金插值成图",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# PipelineRunner:执行者
|
||||
# ============================================================
|
||||
|
||||
class PipelineRunner:
|
||||
"""按 StepSpec 调度 14 个 step 方法,支持软取消 + 断点续跑 + 错误汇总。
|
||||
|
||||
用法:
|
||||
ctx = PipelineContext(img_path=..., work_dir=..., user_config=config)
|
||||
runner = PipelineRunner(pipeline_instance)
|
||||
result_ctx = runner.run(ctx, config=config) # 预检通过后开始执行
|
||||
print(result_ctx.error_summary) # [(step_id, error_msg), ...]
|
||||
"""
|
||||
|
||||
def __init__(self, pipeline, steps: Optional[Sequence[StepSpec]] = None):
|
||||
self.pipeline = pipeline
|
||||
self.steps: List[StepSpec] = list(steps) if steps else list(PIPELINE_STEPS)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 主入口
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run(self, ctx: PipelineContext, config=None, skip_list: Optional[List[str]] = None) -> PipelineContext:
|
||||
self.config = config or {}
|
||||
skip_list = skip_list or []
|
||||
logger.info("开始运行完整流程 (Runner 调度模式)...")
|
||||
|
||||
ctx.pipeline_start_time = time.time()
|
||||
error_summary: List[tuple[str, str]] = []
|
||||
skip_set = set(skip_list) if skip_list else set()
|
||||
|
||||
# ── ★ Step1 img_path 硬校验(缺失则立即终止整个流程) ──
|
||||
if not ctx.get("img_path"):
|
||||
msg = "【全流程预检失败】缺少参考影像路径 (img_path),流程无法启动。"
|
||||
ctx.append_log(f"[RUNNER] {msg}")
|
||||
self._notify_step("全流程", "error", msg)
|
||||
ctx.last_error = msg
|
||||
ctx.pipeline_end_time = time.time()
|
||||
return ctx
|
||||
|
||||
# ── ★ 智能补全:扫描 work_dir 默认产物路径,回填 ctx ──
|
||||
self._scan_workdir_outputs(ctx)
|
||||
|
||||
# ── ★ 自动补全缺失步骤:work_dir 有产物则强制开启 + 回填路径 ──
|
||||
self._auto_fill_missing_steps(ctx)
|
||||
|
||||
# ── 软预检警告(不再阻断,仅记录日志)──
|
||||
self._preflight_warnings(ctx)
|
||||
|
||||
# 断点续跑预扫描:ctx 已有产物则记录诊断日志
|
||||
self._restore_outputs_from_ctx(ctx)
|
||||
|
||||
# 1. 暴力上下文注入:将 GUI config 中的所有参数强行塞入 ctx,防丢失
|
||||
for step_id, cfg in self.config.items():
|
||||
if isinstance(cfg, dict):
|
||||
for k, v in cfg.items():
|
||||
if k != 'enabled' and v:
|
||||
setattr(ctx, k, v)
|
||||
|
||||
# 2. 构建依赖提供者映射 (Provider Map)
|
||||
provider_map = {}
|
||||
for step in self.steps:
|
||||
for prod in step.produces:
|
||||
provider_map[prod] = step
|
||||
|
||||
# 3. 强力依赖级联唤醒 (Auto-Wakeup Engine)
|
||||
changed = True
|
||||
woke_up_steps = []
|
||||
while changed:
|
||||
changed = False
|
||||
for step in self.steps:
|
||||
if step.step_id in skip_set:
|
||||
continue # 用户强踢的,绝不唤醒
|
||||
|
||||
step_cfg = self.config.setdefault(step.step_id, {})
|
||||
if not step_cfg.get('enabled', True):
|
||||
continue
|
||||
|
||||
for req in step.requires:
|
||||
# 如果上下文缺这个参数
|
||||
if not (hasattr(ctx, req) and getattr(ctx, req)):
|
||||
provider = provider_map.get(req)
|
||||
if provider and provider.step_id not in skip_set:
|
||||
prov_cfg = self.config.setdefault(provider.step_id, {})
|
||||
if not prov_cfg.get('enabled', True):
|
||||
prov_cfg['enabled'] = True
|
||||
changed = True
|
||||
woke_up_steps.append(provider.step_id)
|
||||
logger.info(f"[*] 自动唤醒: {provider.step_id} (为下游提供 {req})")
|
||||
|
||||
if woke_up_steps:
|
||||
logger.info(f"★ 依赖唤醒完成,共唤醒 {len(woke_up_steps)} 个次/步骤")
|
||||
|
||||
# 4. 正式执行流水线
|
||||
for step in self.steps:
|
||||
# ── 软取消 ──
|
||||
if ctx.is_cancelled():
|
||||
ctx.append_log(f"[RUNNER] 收到取消信号,提前终止 @ {step.step_id}")
|
||||
break
|
||||
|
||||
if step.step_id in skip_set:
|
||||
ctx.status[step.step_id] = "user_skipped"
|
||||
ctx.append_log(
|
||||
f"\n{'='*60}\n"
|
||||
f" ⚠ 用户强制跳过: {step.step_id}({step.description})\n"
|
||||
f" 原因:用户在预检弹窗中勾选「忽略」,已确认跳过\n"
|
||||
f"{'='*60}\n"
|
||||
)
|
||||
self._notify_step(step.step_id, "skipped", "用户强制跳过(预检弹窗)")
|
||||
continue
|
||||
|
||||
step_cfg = self.config.get(step.step_id, {})
|
||||
if not step_cfg.get('enabled', True):
|
||||
continue
|
||||
|
||||
# 4.1 检查磁盘产物:如果已落盘,恢复上下文并跳过(拒绝静默跳过,必须打日志)
|
||||
if step.output_file and os.path.exists(step.output_file):
|
||||
for prod in step.produces:
|
||||
if not (hasattr(ctx, prod) and getattr(ctx, prod)):
|
||||
setattr(ctx, prod, step.output_file)
|
||||
ctx.status[step.step_id] = "skipped"
|
||||
ctx.append_log(f"[CACHE] 产物已存在,跳过运行并恢复上下文: {step.step_id}")
|
||||
self._notify_step(step.step_id, "skipped", "产物已存在(断点续跑)")
|
||||
continue
|
||||
|
||||
# 4.2 依赖死线检查
|
||||
missing = [req for req in step.requires if not (hasattr(ctx, req) and getattr(ctx, req))]
|
||||
if missing:
|
||||
ctx.status[step.step_id] = "skipped"
|
||||
reason = f"缺少必要的上下文参数,自动跳过: {missing}"
|
||||
ctx.append_log(f"[RUNNER] 跳过 {step.step_id},仍缺少必要参数: {missing}")
|
||||
self._notify_step(step.step_id, "skipped", reason)
|
||||
continue
|
||||
|
||||
# 4.3 真正执行
|
||||
ctx.append_log(f"[START] 正在执行步骤: {step.step_id}")
|
||||
self._notify_step(step.step_id, "running", f"正在执行: {step.description}")
|
||||
try:
|
||||
method = getattr(self.pipeline, step.method_name)
|
||||
|
||||
sig = inspect.signature(method)
|
||||
kwargs = {}
|
||||
current_step_cfg = self.config.get(step.step_id, {})
|
||||
|
||||
for param_name in sig.parameters:
|
||||
# 优先级 1:直接使用当前步骤专属配置中的值
|
||||
if param_name in current_step_cfg:
|
||||
kwargs[param_name] = current_step_cfg[param_name]
|
||||
continue
|
||||
|
||||
# 优先级 1.5:【核心修复】硬隔离 output_file,防止被其他步骤的同名变量污染
|
||||
if param_name == 'output_file' and hasattr(step, 'output_file') and step.output_file:
|
||||
work_dir = getattr(ctx, 'work_dir', '')
|
||||
kwargs[param_name] = step.output_file.format(work_dir=work_dir)
|
||||
continue
|
||||
|
||||
# 优先级 2:处理跨步骤的映射逻辑
|
||||
ctx_key = param_name
|
||||
if hasattr(step, 'parameter_map') and step.parameter_map:
|
||||
for k, v in step.parameter_map.items():
|
||||
if v == param_name:
|
||||
ctx_key = k
|
||||
break
|
||||
# 优先级 3:从全局大背包 ctx 中取(排在最后)
|
||||
if hasattr(ctx, ctx_key):
|
||||
kwargs[param_name] = getattr(ctx, ctx_key)
|
||||
|
||||
# 使用解包后的关键字参数调用底层函数
|
||||
result = method(**kwargs)
|
||||
|
||||
# 【产物接力 1】:如果底层函数返回了字典,直接合并到上下文
|
||||
if isinstance(result, dict):
|
||||
for k, v in result.items():
|
||||
setattr(ctx, k, v)
|
||||
|
||||
# 【产物接力 2】:强制通过 StepSpec 的 output_file 模板注入
|
||||
if hasattr(step, 'output_file') and step.output_file:
|
||||
work_dir = getattr(ctx, 'work_dir', '')
|
||||
actual_out_path = step.output_file.format(work_dir=work_dir)
|
||||
for prod in step.produces:
|
||||
if not hasattr(ctx, prod) or not getattr(ctx, prod):
|
||||
setattr(ctx, prod, actual_out_path)
|
||||
logger.info(f"[产物接力] 登记 {prod} = {actual_out_path}")
|
||||
except PipelineHalt:
|
||||
ctx.status[step.step_id] = "error"
|
||||
ctx.append_log(f"[RUNNER] PipelineHalt 硬终止 @ {step.step_id}")
|
||||
self._notify_step(step.step_id, "error", "预检失败,硬终止")
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.status[step.step_id] = "error"
|
||||
error_summary.append((step.step_id, str(e)))
|
||||
ctx.last_error = f"{step.step_id}: {e!r}"
|
||||
ctx.append_log(f"[ERROR] 步骤 {step.step_id} 执行崩溃: {str(e)}")
|
||||
self._notify_step(step.step_id, "error", str(e))
|
||||
break
|
||||
|
||||
ctx.pipeline_end_time = time.time()
|
||||
ctx.error_summary = error_summary
|
||||
return ctx
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ★ 智能补全:工作目录产物扫描
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _scan_workdir_outputs(self, ctx: PipelineContext) -> None:
|
||||
"""扫描 work_dir 下所有步骤的默认产物路径,若存在则回填 ctx。
|
||||
|
||||
利用 spec.output_file 的 {work_dir} 占位符,展开为实际绝对路径。
|
||||
存在则写入对应的 ctx 字段(produces),供后续步骤直接使用。
|
||||
已在 ctx 中有值的字段不会被覆盖。
|
||||
"""
|
||||
work_dir = ctx.get("work_dir") or ""
|
||||
if not work_dir:
|
||||
return
|
||||
|
||||
for spec in self.steps:
|
||||
if not spec.produces:
|
||||
continue
|
||||
for produce_key in spec.produces:
|
||||
if ctx.get(produce_key):
|
||||
continue # 已有人工填写的值,不覆盖
|
||||
resolved = self._resolve_path(spec.output_file, ctx)
|
||||
if resolved and os.path.exists(resolved):
|
||||
ctx.set(produce_key, resolved)
|
||||
ctx.append_log(
|
||||
f"[AUTO_FILL] 检测到已有产物,回填 {produce_key} = {resolved}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ★ 智能补全:强制开启被静默跳过的步骤
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _auto_fill_missing_steps(self, ctx: PipelineContext) -> None:
|
||||
"""检查所有 disabled 步骤。
|
||||
|
||||
若某步骤的 output_file 已在 work_dir 落盘(断点续跑),
|
||||
说明该步骤之前已完成但被用户在 GUI 中禁用了。
|
||||
此时系统自动重开启该步骤(forced=True),并将其加入 locked_steps。
|
||||
|
||||
同时,将已落盘的产物路径回填到对应的 ctx 字段,
|
||||
确保下游步骤能正常拿到输入。
|
||||
|
||||
阻断性缺失(step1 img_path)已在 run() 入口硬校验,此处不处理。
|
||||
"""
|
||||
newly_locked: List[str] = []
|
||||
|
||||
for spec in self.steps:
|
||||
if spec.enabled:
|
||||
continue # 用户主动开启的步骤不受影响
|
||||
skip_set = getattr(ctx, '_skip_set', set())
|
||||
if spec.step_id in skip_set:
|
||||
continue # 用户在 PreflightDialog 中手动忽略的步骤不自动补全
|
||||
|
||||
resolved = self._resolve_path(spec.output_file, ctx)
|
||||
if resolved and os.path.exists(resolved):
|
||||
# ── 该步骤已有产物但被禁用 → 自动开启 ──
|
||||
spec.enabled = True
|
||||
ctx.locked_steps.append(spec.step_id)
|
||||
newly_locked.append(spec.step_id)
|
||||
|
||||
# 回填所有产物字段到 ctx
|
||||
for produce_key in spec.produces:
|
||||
if not ctx.get(produce_key):
|
||||
ctx.set(produce_key, resolved)
|
||||
ctx.append_log(
|
||||
f"[AUTO_FILL] 强制开启并回填 {spec.step_id} 产物 {produce_key} = {resolved}"
|
||||
)
|
||||
|
||||
ctx.append_log(
|
||||
f"\n{'='*60}\n"
|
||||
f" ⚡ 智能补全:步骤 {spec.step_id}({spec.description})\n"
|
||||
f" 原因:该步骤在 work_dir 中已有产物但被您在 GUI 中禁用了。\n"
|
||||
f" 操作:系统已自动开启该步骤,产物路径已回填。\n"
|
||||
f" 注意:运行期间该步骤已被锁定,您无法临时关闭。\n"
|
||||
f"{'='*60}\n"
|
||||
)
|
||||
|
||||
if newly_locked:
|
||||
self._notify_step(
|
||||
"全流程",
|
||||
"info",
|
||||
f"智能补全已自动开启 {len(newly_locked)} 个步骤:{newly_locked}"
|
||||
)
|
||||
|
||||
def _resolve_output_for_key(
|
||||
self, produce_key: str, ctx: PipelineContext
|
||||
) -> Optional[str]:
|
||||
"""根据 produces key 查找对应步骤的 output_file 并展开路径。"""
|
||||
for spec in self.steps:
|
||||
if produce_key in spec.produces:
|
||||
return self._resolve_path(spec.output_file, ctx)
|
||||
return None
|
||||
|
||||
def _scan_single_step_outputs(
|
||||
self, spec: StepSpec, ctx: PipelineContext
|
||||
) -> None:
|
||||
"""扫描单个步骤的 work_dir 产物,回填 ctx(不覆盖已有值)。"""
|
||||
if not spec.produces:
|
||||
return
|
||||
for produce_key in spec.produces:
|
||||
if ctx.get(produce_key):
|
||||
continue
|
||||
resolved = self._resolve_path(spec.output_file, ctx)
|
||||
if resolved and os.path.exists(resolved):
|
||||
ctx.set(produce_key, resolved)
|
||||
ctx.append_log(
|
||||
f"[AUTO_FILL] 依赖唤醒后检测到产物,回填 {produce_key} = {resolved}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 软预检警告(不再阻断)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _preflight_warnings(self, ctx: PipelineContext) -> None:
|
||||
"""软预检警告:遍历所有步骤,检测可预见的运行时跳过。
|
||||
|
||||
所有缺失均以 warning 记录日志,不抛异常,不阻止执行。
|
||||
GUI 层可通过回调函数 _notify_step 向用户展示警告列表。
|
||||
"""
|
||||
warnings: List[str] = []
|
||||
|
||||
for spec in self.steps:
|
||||
if not spec.enabled:
|
||||
continue
|
||||
|
||||
# ── Step4 csv_path 缺失警告 ──
|
||||
if spec.step_id == "step4":
|
||||
if not ctx.get("csv_path"):
|
||||
warnings.append(
|
||||
f"[{spec.step_id}] 缺少实测水质数据 (csv_path),"
|
||||
"步骤 5-9 将被自动跳过"
|
||||
)
|
||||
|
||||
# ── 磁盘文件缺失警告(已填充 ctx 但文件实际不存在)──
|
||||
for ctx_key in spec.required_input_files:
|
||||
value = ctx.get(ctx_key)
|
||||
if not value:
|
||||
continue
|
||||
if not os.path.exists(value):
|
||||
warnings.append(
|
||||
f"[{spec.step_id}] 磁盘文件缺失(但 ctx 已回填): {ctx_key} = {value}"
|
||||
)
|
||||
|
||||
if warnings:
|
||||
detail = "\n".join(f" - {w}" for w in warnings)
|
||||
ctx.append_log(
|
||||
f"[RUNNER] 【软预检警告】(流程将继续执行,缺失项将被自动跳过)\n{detail}"
|
||||
)
|
||||
self._notify_step("全流程", "warning", f"预检警告:{len(warnings)} 项\n{detail}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 单步调用
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _invoke(self, spec: StepSpec, ctx: PipelineContext) -> None:
|
||||
"""调一个 step 方法:ctx 路径 → 形参;产出 → ctx 字段。"""
|
||||
ctx.append_log(
|
||||
f"[DEBUG] Step {spec.step_id} requires: {spec.requires}, "
|
||||
f"actual ctx data: {[ctx.get(k) for k in spec.requires]}"
|
||||
)
|
||||
method = getattr(self.pipeline, spec.method_name, None)
|
||||
if method is None:
|
||||
ctx.append_log(f"[RUNNER] 步骤方法缺失: {spec.method_name}(跳过)")
|
||||
ctx.status[spec.step_id] = "skipped"
|
||||
return
|
||||
|
||||
# 1) 把 ctx 路径作为形参注入
|
||||
kwargs: Dict[str, Any] = {}
|
||||
for ctx_key in spec.requires:
|
||||
param_name = spec.parameter_map.get(ctx_key, self._default_param_name(ctx_key))
|
||||
kwargs[param_name] = ctx.get(ctx_key)
|
||||
|
||||
# 2) 允许用户在 ctx.user_config[step_id] 覆盖/补充(非空值才覆盖)
|
||||
user_overrides = ctx.user_config.get(spec.step_id) or {}
|
||||
if isinstance(user_overrides, dict):
|
||||
for k, v in user_overrides.items():
|
||||
if v is not None and v != "":
|
||||
kwargs[k] = v
|
||||
|
||||
# 3) 状态置 start
|
||||
ctx.append_log(
|
||||
f"[RUNNER] -> {spec.method_name}({list(kwargs.keys())})"
|
||||
)
|
||||
ctx.status[spec.step_id] = "start"
|
||||
self._notify_step(spec.step_id, "start", spec.method_name)
|
||||
|
||||
# 4) 执行(外层 run() 统一捕获异常)
|
||||
t0 = time.time()
|
||||
result = method(**kwargs)
|
||||
ctx.status[spec.step_id] = "completed"
|
||||
ctx.step_timings[spec.step_id] = time.time() - t0
|
||||
|
||||
# 5) 产出收割
|
||||
self._harvest(spec, result, ctx)
|
||||
self._notify_step(
|
||||
spec.step_id, "completed",
|
||||
str(result)[:200] if result is not None else "",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 产出收割
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _harvest(self, spec: StepSpec, result: Any, ctx: PipelineContext) -> None:
|
||||
"""把 step 方法返回值灌入 ctx 的 produces 字段。"""
|
||||
if not spec.produces:
|
||||
return
|
||||
if isinstance(result, dict):
|
||||
for produce_key in spec.produces:
|
||||
if produce_key in result:
|
||||
ctx.set(produce_key, result[produce_key])
|
||||
elif result is not None:
|
||||
ctx.set(spec.produces[0], result)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 断点续跑辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_path(
|
||||
self, template: Optional[str], ctx: PipelineContext
|
||||
) -> Optional[str]:
|
||||
"""解析模板中的 {work_dir} 占位符,返回展开后的绝对路径或 None。"""
|
||||
if not template:
|
||||
return None
|
||||
work_dir = ctx.get("work_dir") or ""
|
||||
try:
|
||||
return template.format(work_dir=work_dir)
|
||||
except (KeyError, ValueError):
|
||||
return template
|
||||
|
||||
def _restore_outputs_from_ctx(self, ctx: PipelineContext) -> None:
|
||||
"""诊断日志:记录 ctx 中已有的非 None 产物。"""
|
||||
for spec in self.steps:
|
||||
if not (spec.enabled and spec.produces):
|
||||
continue
|
||||
for key in spec.produces:
|
||||
val = ctx.get(key)
|
||||
if val:
|
||||
ctx.append_log(
|
||||
f"[RUNNER] 断点续跑检测: {spec.step_id} 已有 {key} = {val}"
|
||||
)
|
||||
|
||||
def _restore_ctx_from_output(
|
||||
self, spec: StepSpec, resolved_path: str, ctx: PipelineContext
|
||||
) -> None:
|
||||
"""断点跳过时:将已存在的 output_file 写回 ctx 所有 produces 字段,供下游使用。
|
||||
|
||||
接力棒断链修复:遍历 spec.produces 逐一注册,不遗漏任何下游可能依赖的 key。
|
||||
"""
|
||||
if not spec.produces:
|
||||
return
|
||||
for produce_key in spec.produces:
|
||||
ctx.set(produce_key, resolved_path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 工具
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _default_param_name(ctx_key: str) -> str:
|
||||
"""默认原样返回 ctx 键名作为形参名。特殊缩写由 parameter_map 显式处理。"""
|
||||
return ctx_key
|
||||
|
||||
def _notify_step(self, step_id: str, status: str, message: str) -> None:
|
||||
"""通过 pipeline.callback 通知 GUI 当前步骤状态。"""
|
||||
notify = getattr(self.pipeline, "_notify", None)
|
||||
if callable(notify):
|
||||
try:
|
||||
notify(step_id, status, message)
|
||||
except Exception:
|
||||
pass
|
||||
544
src/core/prediction/automl_trainer.py
Normal file
@ -0,0 +1,544 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Optuna + 智能子采样 AutoML 训练器(路线 B 防爆引擎)。
|
||||
|
||||
为什么需要这个:
|
||||
- 老路径:11 预处理 × 4 模型 × 3 划分 = 132 组 GridSearchCV
|
||||
对中小数据集 10 分钟+,对大数据集 5w+ 行 直接 OOM
|
||||
- AutoML 路径:1 预处理 × N 模型(Optuna 调超参),用智能子采样避开 OOM
|
||||
再用最优超参在**全量数据**上 refit,最终保存单一模型
|
||||
|
||||
设计要点:
|
||||
- 入口 train_with_automl(csv, feature_start_column, model_names, ...)
|
||||
- AutoMLResult dataclass 返回(每个目标列一份)
|
||||
- smart_subsample:N > max_samples 时随机下采样
|
||||
- 失败兜底:optuna 未装 / 全 trial 失败 → fallback 到 WaterQualityModelingBatch
|
||||
- 文件命名规范:{target}_{preprocess}_{model}_AUTOML.joblib
|
||||
- save_data["metadata"]["automl"] = True 标记
|
||||
|
||||
调用:
|
||||
from src.core.prediction.automl_trainer import train_with_automl
|
||||
results = train_with_automl(
|
||||
training_csv_path=".../training_spectra.csv",
|
||||
feature_start_column="374.285004",
|
||||
model_names=["RF", "SVR", "Ridge"],
|
||||
n_trials=20,
|
||||
timeout_sec=300,
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 常量
|
||||
# ============================================================
|
||||
|
||||
# AutoML 寻优阶段允许的最大样本数(避免 OOM)
|
||||
# 5000 样本对 RF/SVR/Ridge 的 Optuna 寻优足够给出稳定 CV
|
||||
DEFAULT_MAX_SAMPLES = 5000
|
||||
|
||||
# 单次 Optuna trial 的默认超时(秒)
|
||||
DEFAULT_TIMEOUT = 300.0
|
||||
|
||||
# 默认 trial 数
|
||||
DEFAULT_N_TRIALS = 20
|
||||
|
||||
# AutoML 输出目录名后缀
|
||||
AUTOML_DIR_SUFFIX = "_AutoML"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 数据类
|
||||
# ============================================================
|
||||
|
||||
@dataclass
|
||||
class AutoMLResult:
|
||||
"""单个目标列的 AutoML 训练结果"""
|
||||
success: bool = False
|
||||
model_path: Optional[str] = None
|
||||
cv_score: float = -float("inf")
|
||||
best_params: Optional[Dict[str, Any]] = None
|
||||
target_column: str = ""
|
||||
preprocessing: str = ""
|
||||
model_name: str = ""
|
||||
n_trials_done: int = 0
|
||||
n_samples_used: int = 0
|
||||
fallback_used: bool = False
|
||||
elapsed_sec: float = 0.0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 智能子采样
|
||||
# ============================================================
|
||||
|
||||
def smart_subsample(
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
) -> Tuple[np.ndarray, np.ndarray, bool]:
|
||||
"""当 N > max_samples 时随机下采样;否则原样返回。
|
||||
|
||||
Returns:
|
||||
(X_sub, y_sub, was_subsampled)
|
||||
"""
|
||||
n = X.shape[0]
|
||||
if n <= max_samples:
|
||||
return X, y, False
|
||||
rng = np.random.default_rng(random_state)
|
||||
idx = rng.choice(n, size=max_samples, replace=False)
|
||||
return X[idx], y[idx], True
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 模型工厂
|
||||
# ============================================================
|
||||
|
||||
def _build_model(model_name: str, random_state: int = 42):
|
||||
"""根据英文模型键名构造 sklearn-compatible 模型实例(factory)。"""
|
||||
from sklearn.ensemble import (
|
||||
AdaBoostRegressor, ExtraTreesRegressor, GradientBoostingRegressor,
|
||||
RandomForestRegressor,
|
||||
)
|
||||
from sklearn.linear_model import (
|
||||
ElasticNet, Lasso, LinearRegression, Ridge,
|
||||
)
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.neural_network import MLPRegressor
|
||||
from sklearn.svm import SVR
|
||||
from sklearn.tree import DecisionTreeRegressor
|
||||
|
||||
factory = {
|
||||
"RF": lambda **kw: RandomForestRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"ET": lambda **kw: ExtraTreesRegressor(random_state=random_state, n_jobs=1, **kw),
|
||||
"GradientBoosting": lambda **kw: GradientBoostingRegressor(random_state=random_state, **kw),
|
||||
"AdaBoost": lambda **kw: AdaBoostRegressor(random_state=random_state, **kw),
|
||||
"Ridge": lambda **kw: Ridge(**kw),
|
||||
"Lasso": lambda **kw: Lasso(max_iter=5000, **kw),
|
||||
"ElasticNet": lambda **kw: ElasticNet(max_iter=5000, **kw),
|
||||
"LinearRegression": lambda **kw: LinearRegression(**kw),
|
||||
"SVR": lambda **kw: SVR(**kw),
|
||||
"KNN": lambda **kw: KNeighborsRegressor(n_jobs=1, **kw),
|
||||
"MLP": lambda **kw: MLPRegressor(max_iter=500, random_state=random_state, **kw),
|
||||
"DecisionTree": lambda **kw: DecisionTreeRegressor(random_state=random_state, **kw),
|
||||
"PLS": None, # sklearn.cross_decomposition.PLSRegression 暂未集成
|
||||
}
|
||||
builder = factory.get(model_name)
|
||||
if builder is None:
|
||||
return None
|
||||
return builder
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Optuna 超参 search space
|
||||
# ============================================================
|
||||
|
||||
def _get_search_space(model_name: str, trial) -> Dict[str, Any]:
|
||||
"""按模型名返回 Optuna 超参 search space。"""
|
||||
sp: Dict[str, Any] = {}
|
||||
if model_name == "RF":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
sp["min_samples_leaf"] = trial.suggest_int("min_samples_leaf", 1, 5)
|
||||
elif model_name == "ET":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
elif model_name == "GradientBoosting":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 300, step=50)
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 8)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
|
||||
elif model_name == "SVR":
|
||||
sp["C"] = trial.suggest_float("C", 0.1, 100.0, log=True)
|
||||
sp["epsilon"] = trial.suggest_float("epsilon", 0.001, 1.0, log=True)
|
||||
sp["kernel"] = trial.suggest_categorical("kernel", ["rbf", "linear"])
|
||||
elif model_name == "KNN":
|
||||
sp["n_neighbors"] = trial.suggest_int("n_neighbors", 3, 20)
|
||||
sp["weights"] = trial.suggest_categorical("weights", ["uniform", "distance"])
|
||||
elif model_name in ("Ridge", "Lasso", "ElasticNet"):
|
||||
sp["alpha"] = trial.suggest_float("alpha", 0.01, 100.0, log=True)
|
||||
if model_name == "ElasticNet":
|
||||
sp["l1_ratio"] = trial.suggest_float("l1_ratio", 0.0, 1.0)
|
||||
elif model_name == "MLP":
|
||||
sp["hidden_layer_sizes"] = trial.suggest_categorical(
|
||||
"hidden_layer_sizes", [(50,), (100,), (50, 50), (100, 50)]
|
||||
)
|
||||
sp["alpha"] = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
|
||||
sp["learning_rate_init"] = trial.suggest_float("learning_rate_init", 1e-4, 1e-2, log=True)
|
||||
elif model_name == "DecisionTree":
|
||||
sp["max_depth"] = trial.suggest_int("max_depth", 3, 20)
|
||||
sp["min_samples_split"] = trial.suggest_int("min_samples_split", 2, 10)
|
||||
elif model_name == "AdaBoost":
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 30, 200, step=30)
|
||||
sp["learning_rate"] = trial.suggest_float("learning_rate", 0.01, 1.0, log=True)
|
||||
else:
|
||||
sp["n_estimators"] = trial.suggest_int("n_estimators", 50, 200, step=50)
|
||||
return sp
|
||||
|
||||
|
||||
def _make_objective(model_name: str, X: np.ndarray, y: np.ndarray,
|
||||
cv_folds: int, random_state: int):
|
||||
"""构造 Optuna objective(5 折 CV R²)。"""
|
||||
from sklearn.model_selection import KFold, cross_val_score
|
||||
|
||||
def objective(trial):
|
||||
params = _get_search_space(model_name, trial)
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return -1.0
|
||||
model = builder(**params)
|
||||
kf = KFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
|
||||
scores = cross_val_score(model, X, y, cv=kf, scoring="r2", n_jobs=1)
|
||||
return float(np.mean(scores))
|
||||
except Exception:
|
||||
return -1.0
|
||||
|
||||
return objective
|
||||
|
||||
|
||||
def _refit_full(model_name: str, best_params: Dict[str, Any],
|
||||
X: np.ndarray, y: np.ndarray, random_state: int):
|
||||
"""用 best params 在**全量数据**上 refit。"""
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
return None
|
||||
model = builder(**best_params)
|
||||
model.fit(X, y)
|
||||
return model
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 失败兜底(回退到老 GridSearchCV 路径)
|
||||
# ============================================================
|
||||
|
||||
def _fallback_train(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing: str,
|
||||
model_name: str,
|
||||
split_method: str,
|
||||
cv_folds: int,
|
||||
output_dir: Path,
|
||||
target_column: str,
|
||||
) -> AutoMLResult:
|
||||
"""AutoML 失败时调老 WaterQualityModelingBatch。
|
||||
|
||||
返回的 AutoMLResult.fallback_used=True。
|
||||
"""
|
||||
try:
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
except ImportError as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 导入失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
try:
|
||||
out_dir = output_dir / preprocessing
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
modeler = WaterQualityModelingBatch(str(out_dir))
|
||||
modeler.train_models_batch(
|
||||
csv_path=training_csv_path,
|
||||
feature_start_column=feature_start_column,
|
||||
preprocessing_methods=[preprocessing],
|
||||
model_names=[model_name],
|
||||
split_methods=[split_method],
|
||||
cv_folds=cv_folds,
|
||||
)
|
||||
# 找产出
|
||||
candidates = list(out_dir.rglob(f"{target_column}_{preprocessing}_{model_name}.joblib"))
|
||||
model_path = str(candidates[0]) if candidates else None
|
||||
return AutoMLResult(
|
||||
success=model_path is not None,
|
||||
model_path=model_path,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
fallback_used=True,
|
||||
metadata={"source": "WaterQualityModelingBatch"},
|
||||
)
|
||||
except Exception as e:
|
||||
return AutoMLResult(
|
||||
success=False, error=f"fallback 失败: {e!r}", fallback_used=True,
|
||||
target_column=target_column, preprocessing=preprocessing, model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 主入口
|
||||
# ============================================================
|
||||
|
||||
def train_with_automl(
|
||||
training_csv_path: str,
|
||||
feature_start_column,
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
model_names: Optional[List[str]] = None,
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
n_trials: int = DEFAULT_N_TRIALS,
|
||||
timeout_sec: float = DEFAULT_TIMEOUT,
|
||||
max_samples: int = DEFAULT_MAX_SAMPLES,
|
||||
random_state: int = 42,
|
||||
callback: Optional[Callable[[str, str, str], None]] = None,
|
||||
) -> List[AutoMLResult]:
|
||||
"""用 Optuna + 子采样跑 AutoML。失败时自动回退到 GridSearchCV。
|
||||
|
||||
Args:
|
||||
training_csv_path: 训练用 CSV(Step 5 产物 training_spectra.csv)
|
||||
feature_start_column: 特征起始列名或索引(之前所有列视为目标 y)
|
||||
preprocessing_methods: 候选预处理列表(**仅用第 1 个**,避免笛卡尔爆炸)
|
||||
model_names: 候选模型列表(每个都会跑一遍 Optuna)
|
||||
split_methods: 候选数据划分列表(AutoML 仅用第 1 个)
|
||||
cv_folds: 交叉验证折数
|
||||
output_dir: 输出目录(默认 <models_dir>_AutoML)
|
||||
n_trials: 单模型 Optuna trial 数
|
||||
timeout_sec: 单模型超时(秒),到时强制停止
|
||||
max_samples: 寻优阶段允许的最大样本数
|
||||
callback: 状态回调 callback(step_name, status, message)
|
||||
|
||||
Returns:
|
||||
List[AutoMLResult],每个目标列一份结果
|
||||
"""
|
||||
def notify(status: str, msg: str = "") -> None:
|
||||
if callback:
|
||||
callback("步骤6_AutoML", status, msg)
|
||||
|
||||
# ---- 1) 参数默认值 ----
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["MMS"]
|
||||
if model_names is None:
|
||||
model_names = ["RF", "SVR", "Ridge"]
|
||||
if split_methods is None:
|
||||
split_methods = ["spxy"]
|
||||
|
||||
# 决策:仅用第一个预处理 + 第一个划分,避免笛卡尔爆炸
|
||||
preproc = preprocessing_methods[0]
|
||||
split_method = split_methods[0]
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = "./8_Supervised_Model_Training_AutoML"
|
||||
out_dir = Path(output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
preproc_dir = out_dir / preproc
|
||||
preproc_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ---- 2) 加载数据 ----
|
||||
notify("start", f"AutoML 训练开始 (n_trials={n_trials}, timeout={timeout_sec}s, max_samples={max_samples})")
|
||||
if not Path(training_csv_path).exists():
|
||||
return [AutoMLResult(success=False, error=f"训练 CSV 不存在: {training_csv_path}")]
|
||||
|
||||
df = pd.read_csv(training_csv_path)
|
||||
|
||||
# 提取目标列(feature_start_column 之前所有数值列)
|
||||
if isinstance(feature_start_column, int):
|
||||
y_cols = [c for c in df.columns[:feature_start_column]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
else:
|
||||
try:
|
||||
idx = list(df.columns).index(feature_start_column)
|
||||
y_cols = [c for c in df.columns[:idx]
|
||||
if pd.api.types.is_numeric_dtype(df[c])]
|
||||
except ValueError:
|
||||
y_cols = []
|
||||
|
||||
if not y_cols:
|
||||
notify("error", "AutoML: 未识别出目标列(feature_start_column 之前的所有数值列)")
|
||||
return [AutoMLResult(success=False, error="未识别出目标列")]
|
||||
|
||||
feat_cols = [c for c in df.columns if c not in y_cols]
|
||||
X_all = df[feat_cols].values.astype(np.float64)
|
||||
|
||||
# ---- 3) 预处理(仅第一项) ----
|
||||
if preproc != "None":
|
||||
try:
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
processed = Preprocessing(preproc, df[feat_cols])
|
||||
if isinstance(processed, pd.DataFrame):
|
||||
X_all = processed.values.astype(np.float64)
|
||||
else:
|
||||
X_all = np.asarray(processed, dtype=np.float64)
|
||||
except Exception as e:
|
||||
notify("warning", f"预处理 {preproc} 失败: {e!r},改用 None")
|
||||
preproc = "None"
|
||||
|
||||
# ---- 4) 检查 Optuna 是否可用 ----
|
||||
try:
|
||||
import optuna
|
||||
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
||||
optuna_available = True
|
||||
except ImportError:
|
||||
optuna_available = False
|
||||
notify("warning", "optuna 未安装,全目标列回退到 GridSearchCV(pip install \"optuna>=3.6\")")
|
||||
|
||||
# ---- 5) 逐 target 跑 ----
|
||||
results: List[AutoMLResult] = []
|
||||
total = len(y_cols)
|
||||
per_model_timeout = max(10.0, timeout_sec / max(1, len(model_names)))
|
||||
|
||||
for ti, tgt in enumerate(y_cols, 1):
|
||||
t0 = time.time()
|
||||
yv = df[tgt].values.astype(np.float64)
|
||||
mask = ~np.isnan(yv)
|
||||
X_t = X_all[mask]
|
||||
y_t = yv[mask]
|
||||
|
||||
if X_t.shape[0] < cv_folds * 2:
|
||||
notify("warning", f"目标 {tgt}: 有效样本 {X_t.shape[0]} 不足,跳过")
|
||||
results.append(AutoMLResult(
|
||||
success=False, target_column=tgt, error=f"样本不足({X_t.shape[0]})",
|
||||
preprocessing=preproc,
|
||||
))
|
||||
continue
|
||||
|
||||
X_sub, y_sub, was_sub = smart_subsample(X_t, y_t, max_samples=max_samples, random_state=random_state)
|
||||
if was_sub:
|
||||
notify("info", f"目标 {tgt}: {X_t.shape[0]} 样本 → 子采样 {X_sub.shape[0]}(寻优用)")
|
||||
|
||||
best_overall = AutoMLResult(success=False, target_column=tgt, preprocessing=preproc)
|
||||
|
||||
if not optuna_available:
|
||||
# 全目标列一次性 fallback
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
else:
|
||||
for model_name in model_names:
|
||||
try:
|
||||
builder = _build_model(model_name, random_state=random_state)
|
||||
if builder is None:
|
||||
notify("warning", f"模型 {model_name} 暂不支持 AutoML 寻优")
|
||||
continue
|
||||
|
||||
study = optuna.create_study(
|
||||
direction="maximize",
|
||||
sampler=optuna.samplers.TPESampler(seed=random_state),
|
||||
)
|
||||
study.optimize(
|
||||
_make_objective(model_name, X_sub, y_sub, cv_folds, random_state),
|
||||
n_trials=n_trials,
|
||||
timeout=per_model_timeout,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
if study.best_value is None or study.best_value <= -1.0:
|
||||
notify("warning", f"{tgt}/{model_name}: 全部 trial 失败(CV 全部 <= -1)")
|
||||
continue
|
||||
|
||||
# refit on FULL
|
||||
final_model = _refit_full(model_name, study.best_params, X_t, y_t, random_state)
|
||||
if final_model is None:
|
||||
continue
|
||||
|
||||
# 保存
|
||||
import joblib
|
||||
fname = f"{tgt}_{preproc}_{model_name}_AUTOML.joblib"
|
||||
fpath = preproc_dir / fname
|
||||
joblib.dump({
|
||||
"model": final_model,
|
||||
"target_column_name": tgt,
|
||||
"preprocess_method": preproc,
|
||||
"model_name": model_name,
|
||||
"metadata": {
|
||||
"automl": True,
|
||||
"best_params": study.best_params,
|
||||
"cv_score": float(study.best_value),
|
||||
"n_trials_done": len(study.trials),
|
||||
"n_samples_used_full": int(X_t.shape[0]),
|
||||
"n_samples_used_for_search": int(X_sub.shape[0]),
|
||||
"was_subsampled": was_sub,
|
||||
"split_method": split_method,
|
||||
},
|
||||
}, fpath)
|
||||
|
||||
cand = AutoMLResult(
|
||||
success=True,
|
||||
model_path=str(fpath),
|
||||
cv_score=float(study.best_value),
|
||||
best_params=study.best_params,
|
||||
target_column=tgt,
|
||||
preprocessing=preproc,
|
||||
model_name=model_name,
|
||||
n_trials_done=len(study.trials),
|
||||
n_samples_used=int(X_sub.shape[0]),
|
||||
metadata={"refit_on_full": True, "n_samples_full": int(X_t.shape[0])},
|
||||
)
|
||||
if cand.cv_score > best_overall.cv_score:
|
||||
best_overall = cand
|
||||
except Exception as e:
|
||||
notify("warning", f"目标 {tgt} / 模型 {model_name} 失败: {e!r}")
|
||||
continue
|
||||
|
||||
if not best_overall.success:
|
||||
notify("warning", f"目标 {tgt} 全部 Optuna trial 失败,回退 GridSearchCV")
|
||||
best_overall = _fallback_train(
|
||||
training_csv_path, feature_start_column, preproc, model_names[0], split_method,
|
||||
cv_folds, out_dir, tgt,
|
||||
)
|
||||
|
||||
best_overall.elapsed_sec = time.time() - t0
|
||||
results.append(best_overall)
|
||||
notify("info", f"AutoML 目标 {tgt} 完成 ({ti}/{total}) cv={best_overall.cv_score:.4f}")
|
||||
|
||||
# ---- 6) 汇总 json ----
|
||||
summary_path = out_dir / "automl_summary.json"
|
||||
try:
|
||||
with open(summary_path, "w", encoding="utf-8") as f:
|
||||
json.dump([asdict(r) for r in results], f, ensure_ascii=False, indent=2, default=str)
|
||||
except Exception as e:
|
||||
notify("warning", f"写 automl_summary.json 失败: {e!r}")
|
||||
|
||||
success_n = sum(1 for r in results if r.success)
|
||||
fallback_n = sum(1 for r in results if r.fallback_used)
|
||||
notify("completed", f"AutoML 训练完成 {success_n}/{len(results)} 成功({fallback_n} 走 fallback),汇总 {summary_path}")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI 自测
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(description="AutoML 训练器 CLI 自测")
|
||||
p.add_argument("--csv", required=True, help="训练用 CSV(feature_start_column 之前的列为目标 y)")
|
||||
p.add_argument("--feature-start", default="0", help="特征起始列名或索引(默认 0)")
|
||||
p.add_argument("--n-trials", type=int, default=DEFAULT_N_TRIALS)
|
||||
p.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT)
|
||||
p.add_argument("--max-samples", type=int, default=DEFAULT_MAX_SAMPLES)
|
||||
p.add_argument("--out", default="./8_Supervised_Model_Training_AutoML")
|
||||
args = p.parse_args()
|
||||
|
||||
# 智能推断 feature_start_column 类型
|
||||
fsc: Any = args.feature_start
|
||||
try:
|
||||
fsc = int(fsc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
res = train_with_automl(
|
||||
training_csv_path=args.csv,
|
||||
feature_start_column=fsc,
|
||||
n_trials=args.n_trials,
|
||||
timeout_sec=args.timeout,
|
||||
max_samples=args.max_samples,
|
||||
output_dir=args.out,
|
||||
)
|
||||
print(f"\n训练完成 {len(res)} 个目标")
|
||||
for r in res:
|
||||
marker = "✓" if r.success else "✗"
|
||||
fb = " [fallback]" if r.fallback_used else ""
|
||||
print(f" {marker} {r.target_column}: cv={r.cv_score:.4f} path={r.model_path}{fb}")
|
||||
@ -3,9 +3,9 @@
|
||||
"""
|
||||
自定义回归预测模块
|
||||
|
||||
该模块根据9_Custom_Regression_Modeling文件夹中的CSV信息,批量预测水质指数。
|
||||
该模块根据13_Custom_Regression文件夹中的CSV信息,批量预测水质指数。
|
||||
处理流程:
|
||||
1. 读取9_Custom_Regression_Modeling文件夹中的CSV文件
|
||||
1. 读取13_Custom_Regression文件夹中的CSV文件
|
||||
2. 根据r_squared选择最佳模型(指数公式+反演公式)
|
||||
3. 使用指数公式计算光谱指数值
|
||||
4. 使用反演公式计算水质参数值
|
||||
@ -38,12 +38,12 @@ class CustomRegressionPredictor:
|
||||
"""
|
||||
自定义回归预测器
|
||||
|
||||
基于9_Custom_Regression_Modeling文件夹中的回归模型CSV文件,
|
||||
基于13_Custom_Regression文件夹中的回归模型CSV文件,
|
||||
进行水质参数的批量预测。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
regression_models_dir: str = "9_Custom_Regression_Modeling",
|
||||
regression_models_dir: str = "13_Custom_Regression",
|
||||
formula_csv_path: Optional[str] = None,
|
||||
output_dir: str = "prediction_results",
|
||||
log_level: int = logging.INFO):
|
||||
@ -102,7 +102,7 @@ class CustomRegressionPredictor:
|
||||
|
||||
def load_regression_models(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
加载9_Custom_Regression_Modeling文件夹中的所有CSV文件
|
||||
加载13_Custom_Regression文件夹中的所有CSV文件
|
||||
|
||||
支持的CSV格式:
|
||||
- 回归结果CSV包含列:y_variable, x_variable, equation, r_squared
|
||||
@ -621,7 +621,7 @@ def main():
|
||||
|
||||
parser = argparse.ArgumentParser(description='自定义回归预测模块')
|
||||
parser.add_argument('--input_csv', required=True, help='输入的光谱采样CSV文件路径')
|
||||
parser.add_argument('--models_dir', default='9_Custom_Regression_Modeling',
|
||||
parser.add_argument('--models_dir', default='13_Custom_Regression',
|
||||
help='回归模型CSV文件目录')
|
||||
parser.add_argument('--output_dir', default='prediction_results',
|
||||
help='预测结果输出目录')
|
||||
|
||||
@ -13,6 +13,7 @@ import sys
|
||||
import os
|
||||
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
from src.core.utils.split_methods import spxy, ks
|
||||
|
||||
# try:
|
||||
# from modeling import WaterQualityModeling
|
||||
@ -26,30 +27,45 @@ from sklearn.model_selection import train_test_split
|
||||
class WaterQualityInference:
|
||||
"""水质参数反演推理类"""
|
||||
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts"):
|
||||
def __init__(self, artifacts_dir: str = "models/artifacts",
|
||||
external_model=None, external_model_path=None):
|
||||
"""
|
||||
初始化推理类
|
||||
|
||||
Args:
|
||||
artifacts_dir: 模型保存目录
|
||||
external_model: 外部预训练模型对象(来自 GUI 导入,跳过磁盘加载)
|
||||
external_model_path: 外部模型文件路径(仅用于日志)
|
||||
"""
|
||||
self.artifacts_dir = Path(artifacts_dir)
|
||||
if not self.artifacts_dir.exists():
|
||||
print(f"警告: 模型目录不存在: {artifacts_dir},将在需要时创建")
|
||||
|
||||
self.best_model_info = None
|
||||
self.loaded_model_data = None
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
self.best_model_info = None
|
||||
self.external_model = external_model
|
||||
self.external_model_path = external_model_path
|
||||
|
||||
# 规范化 loaded_model_data:始终为 dict,确保 ['model'] 访问不崩溃
|
||||
if external_model is not None:
|
||||
# 外部传入的是裸模型对象 → 包装为 dict,统一后续 .get('model') 访问
|
||||
self.loaded_model_data = {'model': external_model, 'preprocess_method': 'None'}
|
||||
print(f" 外部模型已规范化: type={type(external_model).__name__}")
|
||||
else:
|
||||
self.loaded_model_data = None
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
加载sampling生成的CSV数据
|
||||
加载sampling生成的CSV数据(兼容 WQI 增强版 CSV)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径,前两列为经纬度,其余列为光谱数据
|
||||
csv_path: CSV文件路径
|
||||
旧版:x_coord,y_coord,pixel_x,pixel_y,波长...
|
||||
新版:x_coord,y_coord,WQI_...,波长...
|
||||
|
||||
Returns:
|
||||
coords: 经纬度数据 (DataFrame)
|
||||
spectra: 光谱数据 (DataFrame)
|
||||
coords: 经纬度数据 (DataFrame, 2列)
|
||||
spectra: 纯光谱数据 (DataFrame, 跳过 WQI 列)
|
||||
wqi_df: WQI 指数列 (DataFrame, 0或45列)
|
||||
"""
|
||||
print(f"正在加载采样数据: {csv_path}")
|
||||
|
||||
@ -71,15 +87,35 @@ class WaterQualityInference:
|
||||
coords = data.iloc[:, :2].copy()
|
||||
coords.columns = ['longitude', 'latitude']
|
||||
|
||||
# 从第5列开始为光谱数据(跳过第2、3、4列的其他信息)
|
||||
spectra = data.iloc[:, 4:].copy()
|
||||
# 动态识别光谱列(兼容 sampling_spectra.csv 列顺序变更)
|
||||
# 列名约定:波长为纯数字字符串如 "374.285004";WQI 为 "WQI_xxx" 前缀
|
||||
# 旧版 CSV(无WQI):x_coord,y_coord,pixel_x,pixel_y,波长... → 取 [4:]
|
||||
# 新版 CSV(有WQI):x_coord,y_coord,WQI_...,波长... → 过滤 WQI 列后取光谱
|
||||
all_cols = list(data.columns)
|
||||
spectral_col_indices = []
|
||||
wqi_col_indices = []
|
||||
for i, col in enumerate(all_cols):
|
||||
col_str = str(col)
|
||||
if col_str.startswith('WQI_'):
|
||||
wqi_col_indices.append(i)
|
||||
elif col_str.replace('.', '').lstrip('-').isdigit():
|
||||
# 波长列:纯数字字符串
|
||||
spectral_col_indices.append(i)
|
||||
else:
|
||||
# 其他元数据列(x_coord/y_coord/pixel_x/pixel_y),由 coords 接收
|
||||
pass
|
||||
|
||||
# 光谱列 = 纯数字列(WQI 已被排除)
|
||||
spectra = data.iloc[:, spectral_col_indices].copy() if spectral_col_indices else data.iloc[:, 4:].copy()
|
||||
# WQI 列(用于追加到预测结果输出)
|
||||
wqi_df = data.iloc[:, wqi_col_indices].copy() if wqi_col_indices else pd.DataFrame()
|
||||
|
||||
print(f" 经纬度数据形状: {coords.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape} (自动识别波长列,排除 {len(wqi_col_indices)} 个WQI列)")
|
||||
print(f" 经纬度范围: 经度[{coords['longitude'].min():.6f}, {coords['longitude'].max():.6f}], "
|
||||
f"纬度[{coords['latitude'].min():.6f}, {coords['latitude'].max():.6f}]")
|
||||
|
||||
return coords, spectra
|
||||
return coords, spectra, wqi_df
|
||||
|
||||
def random(self, data, label, test_ratio=0.2, random_state=123):
|
||||
"""
|
||||
@ -103,159 +139,12 @@ class WaterQualityInference:
|
||||
return X_train, X_test, y_train, y_test
|
||||
|
||||
def spxy(self, data, label, test_size=0.2):
|
||||
"""
|
||||
SPXY算法划分数据集(考虑X和Y空间的距离)
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features)
|
||||
label: shape (n_samples, )
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
# 备份原始数据和标签
|
||||
x_backup = data
|
||||
y_backup = label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
# 归一化标签数据
|
||||
label = (label - np.mean(label)) / np.std(label)
|
||||
D = np.zeros((M, M))
|
||||
Dy = np.zeros((M, M))
|
||||
|
||||
# 计算样本之间的距离
|
||||
for i in range(M - 1):
|
||||
xa = data[i, :]
|
||||
ya = label[i]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
yb = label[j]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
Dy[i, j] = np.linalg.norm(ya - yb)
|
||||
|
||||
# 距离归一化
|
||||
Dmax = np.max(D)
|
||||
Dymax = np.max(Dy)
|
||||
D = D / Dmax + Dy / Dymax
|
||||
|
||||
# 找到最远的两个点
|
||||
maxD = D.max(axis=0)
|
||||
index_row = D.argmax(axis=0)
|
||||
index_column = maxD.argmax()
|
||||
|
||||
m = np.zeros(N, dtype=int)
|
||||
m[0] = index_row[index_column]
|
||||
m[1] = index_column
|
||||
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
# 根据距离选择训练集
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros(M - i)
|
||||
for j in range(M - i):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(samples, m)
|
||||
|
||||
# 划分训练集和测试集
|
||||
X_train = data[m, :]
|
||||
y_train = y_backup[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = y_backup[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""SPXY算法划分数据集(委托至 src.core.utils.split_methods.spxy)"""
|
||||
return spxy(data, label, test_size=test_size)
|
||||
|
||||
def ks(self, data, label, test_size=0.2):
|
||||
"""
|
||||
Kennard-Stone算法划分数据集
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features)
|
||||
label: shape (n_sample, )
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
D = np.zeros((M, M))
|
||||
|
||||
for i in range((M - 1)):
|
||||
xa = data[i, :]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
|
||||
maxD = np.max(D, axis=0)
|
||||
index_row = np.argmax(D, axis=0)
|
||||
index_column = np.argmax(maxD)
|
||||
|
||||
m = np.zeros(N)
|
||||
m[0] = np.array(index_row[index_column])
|
||||
m[1] = np.array(index_column)
|
||||
m = m.astype(int)
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros((M - i))
|
||||
for j in range((M - i)):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(np.arange(data.shape[0]), m)
|
||||
|
||||
X_train = data[m, :]
|
||||
y_train = label[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = label[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""Kennard-Stone算法划分数据集(委托至 src.core.utils.split_methods.ks)"""
|
||||
return ks(data, label, test_size=test_size)
|
||||
|
||||
def split_data(self, X: np.ndarray, y: pd.Series, method: str = "random",
|
||||
test_size: float = 0.2, random_state: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
@ -519,6 +408,69 @@ class WaterQualityInference:
|
||||
print(f"正在应用预处理方法: {actual_preprocess_method}")
|
||||
print(f"原始光谱数据形状: {spectra.shape}")
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 95 维(WQI 指数) ----
|
||||
# 触发条件:模型期望 n_features_in_ 个特征,但当前 spectra 列数不足
|
||||
# 原因:training_spectra.csv 含 50 光谱 + 45 WQI;sampling_spectra.csv 只有 50 光谱
|
||||
# 做法:与训练端(calculate_all_indices)完全一致的算法列表,实时补全缺失的 45 个 WQI 列
|
||||
model = self.loaded_model_data['model']
|
||||
expected_features = getattr(model, 'n_features_in_', None)
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 n_features_in_ 维(WQI 指数) ----
|
||||
if expected_features is not None and spectra.shape[1] < expected_features:
|
||||
print(f"[特征补全] 检测到特征缺口:当前 {spectra.shape[1]} 列 < 模型期望 {expected_features} 列,"
|
||||
f"正在从光谱数据实时计算 WQI 指数...")
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
calc = WaterQualityIndexCalculator()
|
||||
|
||||
# 提取纯计算方法(排除 find_closest_wavelength 和 calculate_all_indices,
|
||||
# 以及不返回 Series 的辅助方法)
|
||||
algorithm_methods = []
|
||||
for m in dir(calc):
|
||||
if m.startswith('_'):
|
||||
continue
|
||||
if m in ['find_closest_wavelength', 'calculate_all_indices']:
|
||||
continue
|
||||
attr = getattr(calc, m)
|
||||
if callable(attr):
|
||||
algorithm_methods.append(m)
|
||||
|
||||
original_col_count = spectra.shape[1]
|
||||
for algo_name in algorithm_methods:
|
||||
try:
|
||||
algo_func = getattr(calc, algo_name)
|
||||
result = algo_func(spectra)
|
||||
# 只追加返回 Series 且长度为样本数的合法结果
|
||||
if isinstance(result, pd.Series) and len(result) == len(spectra):
|
||||
spectra[algo_name] = result.values
|
||||
else:
|
||||
spectra[algo_name] = np.nan
|
||||
except Exception:
|
||||
spectra[algo_name] = np.nan
|
||||
|
||||
print(f"[特征补全] 完成!光谱列已扩充至 {spectra.shape[1]} 列"
|
||||
f"(追加了 {spectra.shape[1] - original_col_count} 个 WQI 指数)")
|
||||
except Exception as e:
|
||||
print(f"[特征补全] 失败,将使用原始光谱特征: {e}")
|
||||
|
||||
# ---- 防线 1:强制维度对齐(物理截断)----
|
||||
if expected_features is not None and spectra.shape[1] > expected_features:
|
||||
print(f"[精准对齐] 正在将 {spectra.shape[1]} 维特征截断为模型要求的 {expected_features} 维")
|
||||
spectra = spectra.iloc[:, :expected_features]
|
||||
elif expected_features is not None and spectra.shape[1] < expected_features:
|
||||
# 维度不足时填充 0
|
||||
padding_cols = expected_features - spectra.shape[1]
|
||||
for i in range(padding_cols):
|
||||
spectra[f'_padding_{i}'] = 0.0
|
||||
print(f"[精准对齐] 特征不足,填充 {padding_cols} 列 0")
|
||||
|
||||
# ---- 防线 2:彻底清洗无穷大数值----
|
||||
# 防止 WQI 计算中除零/溢出产生 np.inf / -np.inf 导致预处理崩溃
|
||||
spectra = spectra.replace([np.inf, -np.inf], np.nan)
|
||||
spectra = spectra.fillna(0)
|
||||
|
||||
print(f"[特征对齐] 最终输入维度: {spectra.shape}")
|
||||
|
||||
try:
|
||||
# 应用预处理
|
||||
spectra_processed = Preprocessing(actual_preprocess_method, spectra)
|
||||
@ -573,7 +525,8 @@ class WaterQualityInference:
|
||||
raise
|
||||
|
||||
def save_predictions(self, coords: pd.DataFrame, predictions: np.ndarray,
|
||||
output_path: str, prediction_column: str = 'prediction'):
|
||||
output_path: str, prediction_column: str = 'prediction',
|
||||
wqi_columns: Optional[pd.DataFrame] = None):
|
||||
"""
|
||||
保存预测结果
|
||||
|
||||
@ -582,11 +535,15 @@ class WaterQualityInference:
|
||||
predictions: 预测结果
|
||||
output_path: 输出文件路径
|
||||
prediction_column: 预测列名称
|
||||
wqi_columns: Optional[pd.DataFrame] = None
|
||||
"""
|
||||
print(f"正在保存预测结果到: {output_path}")
|
||||
|
||||
# 创建结果DataFrame
|
||||
result_df = coords.copy()
|
||||
# 追加 WQI 水质指数列(如 sampling_spectra.csv 注入了 45 列指数)
|
||||
if wqi_columns is not None and not wqi_columns.empty:
|
||||
result_df = pd.concat([result_df, wqi_columns.reset_index(drop=True)], axis=1)
|
||||
result_df[prediction_column] = predictions
|
||||
|
||||
# 确保输出目录存在
|
||||
@ -654,15 +611,18 @@ class WaterQualityInference:
|
||||
# 1. 加载模型
|
||||
print("\n步骤1: 加载模型")
|
||||
print("-" * 40)
|
||||
if model_file_path:
|
||||
if self.external_model is not None:
|
||||
# 已在 __init__ 中规范化,无需重复赋值
|
||||
print(f" 使用外部预训练模型: type={type(self.external_model).__name__}")
|
||||
elif model_file_path:
|
||||
self.load_specific_model(model_file_path)
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
|
||||
# 2. 加载采样数据
|
||||
# 2. 加载采样数据(coords=坐标, spectra=纯光谱, wqi_df=45个WQI指数列)
|
||||
print("\n步骤2: 加载采样数据")
|
||||
print("-" * 40)
|
||||
coords, spectra = self.load_sampling_data(sampling_csv_path)
|
||||
coords, spectra, wqi_df = self.load_sampling_data(sampling_csv_path)
|
||||
|
||||
# 3. 数据预处理
|
||||
print("\n步骤3: 数据预处理")
|
||||
@ -674,10 +634,11 @@ class WaterQualityInference:
|
||||
print("-" * 40)
|
||||
predictions = self.predict(spectra_processed)
|
||||
|
||||
# 5. 保存预测结果
|
||||
# 5. 保存预测结果(透传 WQI 列至最终输出文件)
|
||||
print("\n步骤5: 保存预测结果")
|
||||
print("-" * 40)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path, prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path,
|
||||
prediction_column, wqi_df)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("推理流程完成!")
|
||||
@ -701,8 +662,8 @@ class WaterQualityInference:
|
||||
|
||||
info = {
|
||||
"status": "model_loaded",
|
||||
"preprocess_method": self.loaded_model_data['preprocess_method'],
|
||||
"model_name": self.loaded_model_data['model_name'],
|
||||
"preprocess_method": self.loaded_model_data.get('preprocess_method', 'Unknown'),
|
||||
"model_name": self.loaded_model_data.get('model_name', type(self.external_model).__name__ if self.external_model else 'Unknown'),
|
||||
"model_type": str(type(self.loaded_model_data['model'])),
|
||||
"metadata": self.loaded_model_data.get('metadata', {})
|
||||
}
|
||||
@ -747,10 +708,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"prediction_{csv_file.name}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[csv_file.name] = {
|
||||
'output_file': str(output_file),
|
||||
@ -770,10 +732,13 @@ class WaterQualityInference:
|
||||
print(f"\n批量推理完成,共处理 {len(csv_files)} 个文件")
|
||||
return results
|
||||
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
def batch_inference_multi_models(self, models_root_dir: str, sampling_csv_path: str,
|
||||
output_dir: str, metric: str = 'test_r2',
|
||||
prediction_column: str = 'prediction',
|
||||
output_format: str = 'csv'):
|
||||
output_format: str = 'csv',
|
||||
external_model=None,
|
||||
external_model_path=None,
|
||||
external_models_dict=None):
|
||||
"""
|
||||
使用多个子文件夹中的模型进行批量推理
|
||||
|
||||
@ -788,28 +753,62 @@ class WaterQualityInference:
|
||||
models_root = Path(models_root_dir)
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 查找所有子文件夹
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return
|
||||
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
print(f"输出格式: {output_format.upper()}")
|
||||
|
||||
|
||||
all_results = {}
|
||||
|
||||
for subdir in subdirs:
|
||||
|
||||
# 优先级 1:_external_models_dict 非空 → 直接用字典的 keys 作为 targets,不扫描磁盘
|
||||
print(f"[BatchInference] 终于收到字典啦!包含模型: {list(external_models_dict.keys()) if external_models_dict else 'None'}")
|
||||
if external_models_dict is not None and len(external_models_dict) > 0:
|
||||
targets = list(external_models_dict.keys())
|
||||
print(f"\n使用外部导入模型字典({len(targets)} 个模型)")
|
||||
print(f"检测到外部导入模型,将预测以下参数: {targets}")
|
||||
elif external_model is not None:
|
||||
print(f"\n使用外部预训练模型: {external_model_path or 'unknown'}")
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return {}
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
targets = [d.name for d in subdirs]
|
||||
else:
|
||||
subdirs = [d for d in models_root.iterdir() if d.is_dir()]
|
||||
if not subdirs:
|
||||
print(f"在目录 {models_root_dir} 中未找到子文件夹")
|
||||
return {}
|
||||
print(f"找到 {len(subdirs)} 个模型子文件夹进行批量推理")
|
||||
targets = [d.name for d in subdirs]
|
||||
|
||||
print(f"输出格式: {output_format.upper()}")
|
||||
|
||||
for subdir_name in targets:
|
||||
try:
|
||||
subdir_name = subdir.name
|
||||
print(f"\n{'='*60}")
|
||||
print(f"处理模型文件夹: {subdir_name}")
|
||||
print(f"处理模型: {subdir_name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# 创建新的推理实例,使用当前子文件夹作为artifacts_dir
|
||||
model_inferencer = WaterQualityInference(str(subdir))
|
||||
|
||||
# 优先级:字典中该 target 的模型 > 共享单模型 > 磁盘加载
|
||||
effective_model = None
|
||||
if external_models_dict and subdir_name in external_models_dict:
|
||||
effective_model = external_models_dict[subdir_name]
|
||||
print(f" → 使用字典中模型: {type(effective_model).__name__}")
|
||||
elif external_model is not None:
|
||||
effective_model = external_model
|
||||
print(f" → 使用共享外部模型: {type(effective_model).__name__}")
|
||||
|
||||
# artifacts_dir:字典模式优先用 placeholder "./",否则用真实子目录
|
||||
artifacts_dir = (
|
||||
str(models_root / subdir_name)
|
||||
if (models_root / subdir_name).is_dir()
|
||||
else str(models_root)
|
||||
)
|
||||
if effective_model is not None:
|
||||
model_inferencer = WaterQualityInference(
|
||||
artifacts_dir,
|
||||
external_model=effective_model,
|
||||
external_model_path=external_model_path or "",
|
||||
)
|
||||
else:
|
||||
model_inferencer = WaterQualityInference(artifacts_dir)
|
||||
|
||||
# 根据输出格式设置文件扩展名
|
||||
file_ext = f".{output_format}"
|
||||
@ -838,10 +837,10 @@ class WaterQualityInference:
|
||||
}
|
||||
}
|
||||
|
||||
print(f"子文件夹 {subdir_name} 处理完成")
|
||||
print(f"模型 {subdir_name} 处理完成")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理子文件夹 {subdir_name} 失败: {e}")
|
||||
print(f"处理模型 {subdir_name} 失败: {e}")
|
||||
all_results[subdir_name] = {
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
@ -908,10 +907,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"{file_stem}{file_ext}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[file_stem] = {
|
||||
'input_file': str(csv_file),
|
||||
|
||||
@ -24,6 +24,7 @@ from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
|
||||
from sklearn.model_selection import GridSearchCV, cross_val_score, KFold, train_test_split
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
from sklearn.cross_decomposition import PLSRegression
|
||||
from src.core.utils.split_methods import spxy, ks
|
||||
|
||||
# 第三方模型导入
|
||||
# try:
|
||||
@ -256,133 +257,12 @@ class WaterQualityScatterBatch:
|
||||
return X_train, X_test, y_train, y_test
|
||||
|
||||
def spxy(self, data, label, test_size=0.2):
|
||||
"""SPXY算法划分数据集"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
# 备份原始数据和标签
|
||||
x_backup = data
|
||||
y_backup = label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
# 归一化标签数据
|
||||
label = (label - np.mean(label)) / np.std(label)
|
||||
D = np.zeros((M, M))
|
||||
Dy = np.zeros((M, M))
|
||||
|
||||
# 计算样本之间的距离
|
||||
for i in range(M - 1):
|
||||
xa = data[i, :]
|
||||
ya = label[i]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
yb = label[j]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
Dy[i, j] = np.linalg.norm(ya - yb)
|
||||
|
||||
# 距离归一化
|
||||
Dmax = np.max(D)
|
||||
Dymax = np.max(Dy)
|
||||
D = D / Dmax + Dy / Dymax
|
||||
|
||||
# 找到最远的两个点
|
||||
maxD = D.max(axis=0)
|
||||
index_row = D.argmax(axis=0)
|
||||
index_column = maxD.argmax()
|
||||
|
||||
m = np.zeros(N, dtype=int)
|
||||
m[0] = index_row[index_column]
|
||||
m[1] = index_column
|
||||
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
# 根据距离选择训练集
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros(M - i)
|
||||
for j in range(M - i):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(samples, m)
|
||||
|
||||
# 划分训练集和测试集
|
||||
X_train = data[m, :]
|
||||
y_train = y_backup[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = y_backup[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""SPXY算法划分数据集(委托至 src.core.utils.split_methods.spxy)"""
|
||||
return spxy(data, label, test_size=test_size)
|
||||
|
||||
def ks(self, data, label, test_size=0.2):
|
||||
"""Kennard-Stone算法划分数据集"""
|
||||
# 确保 data 和 label 是 NumPy 数组
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
D = np.zeros((M, M))
|
||||
|
||||
for i in range((M - 1)):
|
||||
xa = data[i, :]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
|
||||
maxD = np.max(D, axis=0)
|
||||
index_row = np.argmax(D, axis=0)
|
||||
index_column = np.argmax(maxD)
|
||||
|
||||
m = np.zeros(N)
|
||||
m[0] = np.array(index_row[index_column])
|
||||
m[1] = np.array(index_column)
|
||||
m = m.astype(int)
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros((M - i))
|
||||
for j in range((M - i)):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(np.arange(data.shape[0]), m)
|
||||
|
||||
X_train = data[m, :]
|
||||
y_train = label[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = label[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
"""Kennard-Stone算法划分数据集(委托至 src.core.utils.split_methods.ks)"""
|
||||
return ks(data, label, test_size=test_size)
|
||||
|
||||
def split_data(self, X: np.ndarray, y: pd.Series, method: str = "random",
|
||||
test_size: float = 0.2, random_state: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
|
||||
20
src/core/steps/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""业务步骤层模块"""
|
||||
|
||||
from src.core.steps.water_mask_step import WaterMaskStep
|
||||
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
from src.core.steps.modeling_step import ModelingStep
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
from src.core.steps.mapping_step import MappingStep
|
||||
|
||||
__all__ = [
|
||||
"WaterMaskStep",
|
||||
"GlintDetectionStep",
|
||||
"GlintRemovalStep",
|
||||
"DataPreparationStep",
|
||||
"ModelingStep",
|
||||
"PredictionStep",
|
||||
"MappingStep",
|
||||
]
|
||||
184
src/core/steps/data_preparation_step.py
Normal file
@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备步骤
|
||||
|
||||
包含 step5_process_csv, step6_extract_spectra, step5_5_calculate_water_quality_indices
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataPreparationStep:
|
||||
"""数据准备步骤"""
|
||||
|
||||
# ---- Step 4: 处理CSV文件 ----
|
||||
|
||||
@staticmethod
|
||||
def process_csv(
|
||||
csv_path: str,
|
||||
output_dir: Union[str, Path] = "./5_Data_Cleaning",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""处理CSV文件(筛选剔除异常值)"""
|
||||
from src.preprocessing.process_water_quality_data import process_water_quality_data
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "processed_data.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤4", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤4: 处理CSV文件,筛选剔除异常值")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的处理后CSV文件,直接使用: {output_path}")
|
||||
notify("skipped", f"处理后的CSV文件已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
process_water_quality_data(csv_path, output_path)
|
||||
notify("completed", f"处理后的CSV文件已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 5: 提取训练样本点光谱 ----
|
||||
|
||||
@staticmethod
|
||||
def extract_training_spectra(
|
||||
deglint_img_path: Optional[str] = None,
|
||||
radius: int = 5,
|
||||
source_epsg: int = 4326,
|
||||
csv_path: Optional[str] = None,
|
||||
boundary_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./6_Spectral_Feature_Extraction",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
|
||||
from src.core.glint_removal.get_spectral import get_spectral_in_coor
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "training_spectra.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤5: 提取训练样本点的平均光谱")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if deglint_img_path is None:
|
||||
raise ValueError("必须提供 deglint_img_path 参数")
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
|
||||
notify("skipped", f"训练光谱数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
# 确保水体掩膜存在
|
||||
final_boundary_path = boundary_path
|
||||
if final_boundary_path is None and water_mask_path is not None:
|
||||
final_boundary_path = water_mask_path
|
||||
|
||||
# 【新增安全防护】智能拦截矢量 .shp,强制替换为步骤 1 生成的栅格 .dat
|
||||
if final_boundary_path and str(final_boundary_path).lower().endswith('.shp'):
|
||||
# 向上追溯查找 1_water_mask 目录下的 dat 替身
|
||||
possible_dat = Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||
if not possible_dat.exists() and output_path:
|
||||
possible_dat = Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
|
||||
|
||||
if possible_dat.exists():
|
||||
print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp,自动切换为已生成的栅格掩膜: {possible_dat}")
|
||||
final_boundary_path = str(possible_dat)
|
||||
else:
|
||||
print(f"⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")
|
||||
|
||||
flare_path = glint_mask_path
|
||||
if flare_path:
|
||||
print(f"光谱提取使用耀斑掩膜: {flare_path}")
|
||||
|
||||
get_spectral_in_coor(
|
||||
deglint_img_path, csv_path, output_path,
|
||||
radius=radius, flare_path=flare_path,
|
||||
boundary_path=final_boundary_path, source_epsg=source_epsg
|
||||
)
|
||||
|
||||
notify("completed", f"训练光谱数据已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 5.5: 计算水质光谱指数 ----
|
||||
|
||||
@staticmethod
|
||||
def calculate_water_quality_indices(
|
||||
training_csv_path: Optional[str] = None,
|
||||
formula_csv_file: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
output_file: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
output_dir: Union[str, Path] = "./7_Water_Quality_Indices",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤5.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤5.5: 计算水质光谱指数(使用band_math方法)")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过水质指数计算(enabled=False)。")
|
||||
notify("skipped", "跳过水质指数计算")
|
||||
return None
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("必须提供 training_csv_path 参数")
|
||||
if formula_csv_file is None:
|
||||
raise ValueError("必须提供 formula_csv_file 参数")
|
||||
|
||||
if output_file:
|
||||
output_path = str(Path(output_file))
|
||||
else:
|
||||
output_path = str(output_dir / "training_spectra_indices.csv")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
|
||||
notify("skipped", f"水质指数数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
from src.utils.band_math import BandMathCalculator
|
||||
|
||||
calculator = BandMathCalculator(training_csv_path)
|
||||
result_df = calculator.process_formulas_from_csv(
|
||||
formula_csv_file=formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
output_file=output_path
|
||||
)
|
||||
|
||||
if result_df is None:
|
||||
raise ValueError("计算水质指数失败,请检查公式CSV文件格式")
|
||||
|
||||
notify("completed", f"水质指数已保存: {output_path}")
|
||||
return output_path
|
||||
113
src/core/steps/glint_detection_step.py
Normal file
@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤2: 耀斑区域检测
|
||||
|
||||
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union
|
||||
|
||||
|
||||
class GlintDetectionStep:
|
||||
"""耀斑区域检测步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
img_path: str,
|
||||
glint_wave: float = 750.0,
|
||||
method: str = "otsu",
|
||||
z_threshold: float = 2.5,
|
||||
percentile: float = 95.0,
|
||||
iqr_multiplier: float = 1.5,
|
||||
window_size: int = 15,
|
||||
multi_band_waves: Optional[List[float]] = None,
|
||||
sub_method: str = "zscore",
|
||||
weights: Optional[List[float]] = None,
|
||||
max_area: Optional[int] = None,
|
||||
buffer_size: Optional[int] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
glint_dir: Union[str, Path] = "./2_Glint_Detection",
|
||||
callback: Optional[callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行耀斑区域检测
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
glint_wave: 用于耀斑检测的波段波长(nm)
|
||||
method: 检测方法 ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
|
||||
z_threshold: Z-score 方法阈值(默认 2.5)
|
||||
percentile: 百分位数阈值(默认 95.0)
|
||||
iqr_multiplier: IQR 倍数(默认 1.5)
|
||||
window_size: 自适应阈值窗口大小(默认 15)
|
||||
multi_band_waves: 多波段方法的波长列表,如 [750, 800, 850]
|
||||
sub_method: 多波段方法的子方法(默认 'zscore')
|
||||
weights: 多波段方法的权重列表(None 表示等权重)
|
||||
max_area: 最大连通域面积阈值(像素),超过则过滤
|
||||
buffer_size: 岸边缓冲区大小(像素),用于去除岸边附近错误掩膜
|
||||
water_mask_path: 水域掩膜文件路径(dat 格式优先)
|
||||
glint_dir: 工作目录
|
||||
callback: 回调函数
|
||||
|
||||
Returns:
|
||||
耀斑掩膜文件路径 (.dat)
|
||||
"""
|
||||
from src.utils.find_severe_glint_area import find_severe_glint_area
|
||||
|
||||
glint_dir = Path(glint_dir)
|
||||
glint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤2", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤2: 找到耀斑区域")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 确定水体掩膜路径
|
||||
if water_mask_path is not None and Path(water_mask_path).exists():
|
||||
final_water_mask_path = water_mask_path
|
||||
else:
|
||||
final_water_mask_path = None
|
||||
|
||||
output_path = str(glint_dir / "severe_glint_area.dat")
|
||||
|
||||
# 跳过已存在的文件
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
|
||||
notify("skipped", f"耀斑掩膜已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
# 构建检测参数字典
|
||||
kwargs = {
|
||||
"method": method,
|
||||
"z_threshold": z_threshold,
|
||||
"percentile": percentile,
|
||||
"iqr_multiplier": iqr_multiplier,
|
||||
"window_size": window_size,
|
||||
}
|
||||
if method == "multi_band":
|
||||
if multi_band_waves is not None:
|
||||
kwargs["multi_band_waves"] = multi_band_waves
|
||||
if sub_method is not None:
|
||||
kwargs["sub_method"] = sub_method
|
||||
if weights is not None:
|
||||
kwargs["weights"] = weights
|
||||
if max_area is not None:
|
||||
kwargs["max_area"] = max_area
|
||||
if buffer_size is not None:
|
||||
kwargs["buffer_size"] = buffer_size
|
||||
|
||||
glint_mask_path = find_severe_glint_area(
|
||||
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
|
||||
)
|
||||
|
||||
print(f"耀斑掩膜已生成: {glint_mask_path}")
|
||||
print(f"使用检测方法: {method}")
|
||||
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
|
||||
return glint_mask_path
|
||||
375
src/core/steps/glint_removal_step.py
Normal file
@ -0,0 +1,375 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤3: 去除耀斑
|
||||
|
||||
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
|
||||
|
||||
每种方法都会:
|
||||
1. 准备水域掩膜(支持 shp 自动转 dat)
|
||||
2. 调用对应的算法类执行处理
|
||||
3. 复制 hdr 文件到输出影像
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
|
||||
"""将底层硬编码生成的 .bsq + .hdr 文件对重命名到用户指定的 output_path
|
||||
|
||||
使用 os.remove + os.rename 确保原子覆盖(不等 os.replace 的跨设备行为),
|
||||
resolve() 断路防止同路径 self-rename 报错。
|
||||
|
||||
Returns:
|
||||
dest_bsq 路径
|
||||
"""
|
||||
src_bsq_p = Path(src_bsq)
|
||||
src_hdr_p = Path(src_hdr)
|
||||
dest_bsq_p = Path(dest_bsq)
|
||||
dest_hdr_p = Path(dest_hdr)
|
||||
|
||||
if str(src_bsq_p.resolve()) == str(dest_bsq_p.resolve()):
|
||||
return dest_bsq
|
||||
|
||||
if dest_bsq_p.exists():
|
||||
os.remove(dest_bsq_p)
|
||||
if dest_hdr_p.exists():
|
||||
os.remove(dest_hdr_p)
|
||||
|
||||
if src_bsq_p.exists():
|
||||
os.rename(src_bsq_p, dest_bsq_p)
|
||||
if src_hdr_p.exists():
|
||||
os.rename(src_hdr_p, dest_hdr_p)
|
||||
|
||||
return dest_bsq
|
||||
|
||||
|
||||
class GlintRemovalStep:
|
||||
"""去除耀斑步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
img_path: str,
|
||||
method: str = "subtract_nir",
|
||||
start_wave: Optional[float] = None,
|
||||
end_wave: Optional[float] = None,
|
||||
json_path: Optional[str] = None,
|
||||
left_shoulder_wave: Optional[float] = None,
|
||||
valley_wave: Optional[float] = None,
|
||||
right_shoulder_wave: Optional[float] = None,
|
||||
water_mask: Optional[Union[str, np.ndarray]] = None,
|
||||
interpolated_img_path: Optional[str] = None,
|
||||
interpolate_zeros: bool = False,
|
||||
interpolation_method: str = "nearest",
|
||||
enabled: bool = True,
|
||||
# Kutser 参数
|
||||
kutser_shp_path: Optional[str] = None,
|
||||
oxy_band: int = 38,
|
||||
lower_oxy: int = 36,
|
||||
upper_oxy: int = 49,
|
||||
nir_band: int = 47,
|
||||
# Goodman 参数
|
||||
nir_lower: int = 25,
|
||||
nir_upper: int = 37,
|
||||
goodman_A: float = 0.000019,
|
||||
goodman_B: float = 0.1,
|
||||
# Hedley 参数
|
||||
hedley_shp_path: Optional[str] = None,
|
||||
hedley_nir_band: int = 47,
|
||||
# SUGAR 参数
|
||||
sugar_bounds: Optional[List[tuple]] = None,
|
||||
sugar_sigma: float = 1.0,
|
||||
sugar_estimate_background: bool = True,
|
||||
sugar_glint_mask_method: str = "cdf",
|
||||
sugar_iter: Optional[int] = 3,
|
||||
sugar_termination_thresh: float = 20.0,
|
||||
# 内部工具函数
|
||||
_get_image_geo_info=None,
|
||||
_load_image_as_array=None,
|
||||
_save_bands_as_image=None,
|
||||
_copy_hdr_info=None,
|
||||
_prepare_water_mask_for_algorithm=None,
|
||||
_interpolate_zero_pixels_batch=None,
|
||||
deglint_dir: Union[str, Path] = "./3_deglint",
|
||||
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||
callback: Optional[Callable] = None,
|
||||
output_path: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行去除耀斑处理
|
||||
|
||||
Args:
|
||||
img_path: 输入影像文件路径
|
||||
method: 去耀斑方法
|
||||
...(其余参数同主类 step3_remove_glint)
|
||||
|
||||
Returns:
|
||||
去除耀斑后的影像文件路径
|
||||
"""
|
||||
from src.core.glint_removal.Kutser import Kutser
|
||||
from src.core.glint_removal.Goodman import Goodman
|
||||
from src.core.glint_removal.Hedley import Hedley
|
||||
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info as _default_get_geo,
|
||||
load_image_as_array as _default_load,
|
||||
save_bands_as_image as _default_save_bands,
|
||||
copy_hdr_info as _default_copy_hdr,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm as _default_prepare,
|
||||
)
|
||||
|
||||
# 使用提供的函数或默认函数
|
||||
if _get_image_geo_info is None:
|
||||
_get_image_geo_info = _default_get_geo
|
||||
if _load_image_as_array is None:
|
||||
_load_image_as_array = _default_load
|
||||
if _save_bands_as_image is None:
|
||||
_save_bands_as_image = _default_save_bands
|
||||
if _copy_hdr_info is None:
|
||||
_copy_hdr_info = _default_copy_hdr
|
||||
if _prepare_water_mask_for_algorithm is None:
|
||||
_prepare_water_mask_for_algorithm = _default_prepare
|
||||
|
||||
deglint_dir = Path(deglint_dir)
|
||||
deglint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤3", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤3: 去除耀斑")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 方法名标准化
|
||||
raw_method = str(method).lower()
|
||||
if "kutser" in raw_method:
|
||||
method = "kutser"
|
||||
elif "goodman" in raw_method:
|
||||
method = "goodman"
|
||||
elif "hedley" in raw_method:
|
||||
method = "hedley"
|
||||
elif "sugar" in raw_method:
|
||||
method = "sugar"
|
||||
|
||||
# 如果未启用,直接返回原始影像
|
||||
if not enabled:
|
||||
print("已设置跳过去除耀斑(enabled=False),将直接使用原始影像。")
|
||||
notify("skipped", "跳过去耀斑,使用原始影像")
|
||||
return img_path
|
||||
|
||||
# ---- 确定水域掩膜 ----
|
||||
final_water_mask = water_mask
|
||||
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
|
||||
# shp 自动替换为 dat
|
||||
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||
if Path(dat_mask).exists():
|
||||
print(f"检测到输入掩膜为 .shp,自动替换为栅格掩膜: {dat_mask}")
|
||||
final_water_mask = dat_mask
|
||||
|
||||
if final_water_mask is None:
|
||||
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
|
||||
if Path(dat_mask_default).exists():
|
||||
final_water_mask = dat_mask_default
|
||||
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
|
||||
|
||||
# ---- 步骤3.1: 0值像素插值 ----
|
||||
if interpolate_zeros:
|
||||
print("\n" + "-" * 80)
|
||||
print("步骤3.1: 对0值像素进行插值")
|
||||
print("-" * 80)
|
||||
interp_start_time = time.time()
|
||||
|
||||
if _interpolate_zero_pixels_batch is None:
|
||||
from src.core.algorithms.interpolation.interpolator import (
|
||||
interpolate_zero_pixels_batch as _interp_batch,
|
||||
)
|
||||
_interpolate_zero_pixels_batch = _interp_batch
|
||||
|
||||
interp_result, _ = _interpolate_zero_pixels_batch(
|
||||
img_path=img_path,
|
||||
interpolation_method=interpolation_method,
|
||||
output_path=None,
|
||||
water_mask=final_water_mask,
|
||||
deglint_dir=str(deglint_dir),
|
||||
callback_progress=lambda msg: print(f" {msg}"),
|
||||
)
|
||||
img_path = interp_result
|
||||
interp_end_time = time.time()
|
||||
print(f"插值完成,使用插值后的影像: {img_path}")
|
||||
|
||||
# ---- 获取影像信息 ----
|
||||
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
|
||||
print(f"影像尺寸: {width} x {height} x {n_bands}")
|
||||
|
||||
mask_for_algorithm = _prepare_water_mask_for_algorithm(
|
||||
final_water_mask, (height, width), geotransform, projection, img_path
|
||||
)
|
||||
|
||||
# ==================== Kutser ====================
|
||||
if method == "kutser":
|
||||
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_kutser.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
# 将用户指定的 output_path 标准化为 .bsq 路径
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
kutser = Kutser(
|
||||
img_path,
|
||||
shp_path=None,
|
||||
oxy_band=oxy_band,
|
||||
lower_oxy=lower_oxy,
|
||||
upper_oxy=upper_oxy,
|
||||
NIR_band=nir_band,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
kutser.get_corrected_bands()
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"Kutser算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
# ==================== Goodman ====================
|
||||
elif method == "goodman":
|
||||
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_goodman.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
goodman = Goodman(
|
||||
img_path,
|
||||
NIR_lower=nir_lower,
|
||||
NIR_upper=nir_upper,
|
||||
A=goodman_A,
|
||||
B=goodman_B,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
corrected_bands = goodman.get_corrected_bands()
|
||||
|
||||
if not Path(hardcoded_bsq).exists():
|
||||
_save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
del corrected_bands
|
||||
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
|
||||
# ==================== Hedley ====================
|
||||
elif method == "hedley":
|
||||
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_hedley.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
hedley = Hedley(
|
||||
img_path,
|
||||
shp_path=None,
|
||||
NIR_band=hedley_nir_band,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
hedley.get_corrected_bands()
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"Hedley算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
# ==================== SUGAR ====================
|
||||
elif method == "sugar":
|
||||
glint_method_raw = str(sugar_glint_mask_method).lower()
|
||||
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
|
||||
sugar_glint_mask_method_fixed = "cdf"
|
||||
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
|
||||
sugar_glint_mask_method_fixed = "otsu"
|
||||
else:
|
||||
sugar_glint_mask_method_fixed = "cdf"
|
||||
|
||||
print(
|
||||
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
|
||||
)
|
||||
hardcoded_bsq = str(deglint_dir / "deglint_sugar.bsq")
|
||||
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
|
||||
if output_path:
|
||||
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
|
||||
final_hdr = final_bsq.replace(".bsq", ".hdr")
|
||||
else:
|
||||
final_bsq = hardcoded_bsq
|
||||
final_hdr = hardcoded_hdr
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
|
||||
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
|
||||
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
|
||||
if sugar_bounds is None:
|
||||
sugar_bounds = [(1, 2)]
|
||||
|
||||
correction_iterative(
|
||||
img_path,
|
||||
iter=sugar_iter,
|
||||
bounds=sugar_bounds,
|
||||
estimate_background=sugar_estimate_background,
|
||||
glint_mask_method=sugar_glint_mask_method_fixed,
|
||||
termination_thresh=sugar_termination_thresh,
|
||||
water_mask=mask_for_algorithm,
|
||||
output_path=hardcoded_bsq,
|
||||
)
|
||||
|
||||
if Path(hardcoded_bsq).exists():
|
||||
_copy_hdr_info(img_path, hardcoded_bsq)
|
||||
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
|
||||
notify("completed", f"去耀斑影像已生成: {final}")
|
||||
return final
|
||||
raise RuntimeError(f"SUGAR算法未生成输出文件: {hardcoded_bsq}")
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
|
||||
)
|
||||
109
src/core/steps/mapping_step.py
Normal file
@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
成图步骤
|
||||
|
||||
包含 step9_generate_distribution_map
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable
|
||||
|
||||
|
||||
class MappingStep:
|
||||
"""成图步骤"""
|
||||
|
||||
@staticmethod
|
||||
def generate_distribution_map(
|
||||
prediction_csv_path: str,
|
||||
boundary_shp_path: str,
|
||||
output_image_path: Optional[str] = None,
|
||||
resolution: float = 30,
|
||||
input_crs: str = "EPSG:32651",
|
||||
output_crs: str = "EPSG:4326",
|
||||
show_sample_points: bool = False,
|
||||
base_map_tif: Optional[str] = None,
|
||||
use_distance_diffusion: bool = True,
|
||||
max_diffusion_distance: Optional[float] = None,
|
||||
diffusion_power: float = 2,
|
||||
diffusion_n_neighbors: int = 15,
|
||||
cmap: Optional[str] = None,
|
||||
expand_ratio: float = 0.05,
|
||||
output_dir: Union[str, Path] = "./14_visualization",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
根据采样点的坐标和反演的实测参数,通过插值方法得到水质参数可视化分布图
|
||||
|
||||
Args:
|
||||
prediction_csv_path: 预测结果CSV文件路径(前两列为经纬度,第三列为预测值)
|
||||
boundary_shp_path: 边界shapefile文件路径
|
||||
output_image_path: 输出图片路径(如果为None,自动生成)
|
||||
resolution: 插值网格分辨率(米)
|
||||
input_crs: 输入坐标系
|
||||
output_crs: 输出坐标系
|
||||
show_sample_points: 是否在图上显示采样点
|
||||
base_map_tif: 底图TIF路径
|
||||
use_distance_diffusion: 是否启用距离扩散补全边界
|
||||
max_diffusion_distance: 距离扩散最大距离(米)
|
||||
diffusion_power: 距离扩散幂参数
|
||||
diffusion_n_neighbors: 距离扩散最近邻数量
|
||||
cmap: 颜色映射名称(None表示自动识别)
|
||||
expand_ratio: 边界外扩比例(0-1之间)
|
||||
output_dir: 输出目录
|
||||
callback: 回调函数
|
||||
|
||||
Returns:
|
||||
可视化分布图文件路径
|
||||
"""
|
||||
from src.postprocessing.map import ContentMapper
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤9", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤9: 生成水质参数可视化分布图")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if output_image_path is None:
|
||||
csv_name = Path(prediction_csv_path).stem
|
||||
output_image_path = str(output_dir / f"{csv_name}_distribution.png")
|
||||
|
||||
if Path(output_image_path).exists():
|
||||
print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
|
||||
notify("skipped", f"可视化分布图已设置: {output_image_path}")
|
||||
return output_image_path
|
||||
|
||||
mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
|
||||
|
||||
mapper_kwargs = {
|
||||
"resolution": resolution,
|
||||
"show_sample_points": show_sample_points,
|
||||
"use_distance_diffusion": use_distance_diffusion,
|
||||
"diffusion_power": diffusion_power,
|
||||
"diffusion_n_neighbors": diffusion_n_neighbors,
|
||||
"expand_ratio": expand_ratio,
|
||||
}
|
||||
|
||||
optional_kwargs = {
|
||||
"base_map_tif": base_map_tif,
|
||||
"max_diffusion_distance": max_diffusion_distance,
|
||||
"cmap": cmap,
|
||||
}
|
||||
mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})
|
||||
|
||||
mapper.process_data(
|
||||
csv_file=prediction_csv_path,
|
||||
shp_file=boundary_shp_path,
|
||||
output_file=output_image_path,
|
||||
**mapper_kwargs,
|
||||
)
|
||||
|
||||
notify("completed", f"可视化分布图已保存: {output_image_path}")
|
||||
return output_image_path
|
||||
497
src/core/steps/modeling_step.py
Normal file
@ -0,0 +1,497 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
建模步骤
|
||||
|
||||
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 汉化 -> 英文 反向映射字典(UI 复选框显示文本 -> 底层算法键名)
|
||||
# ============================================================
|
||||
|
||||
# 模型名称:中文 (缩写) -> 英文键名
|
||||
MODEL_NAME_MAP = {
|
||||
"多元线性回归 (MLR)": "LinearRegression",
|
||||
"岭回归 (Ridge)": "Ridge",
|
||||
"套索回归 (Lasso)": "Lasso",
|
||||
"弹性网络 (ElasticNet)": "ElasticNet",
|
||||
"偏最小二乘 (PLSR)": "PLS",
|
||||
"决策树 (CART)": "DecisionTree",
|
||||
"随机森林 (RF)": "RF",
|
||||
"极端随机树 (ET)": "ExtraTrees",
|
||||
"极值梯度提升 (XGBoost)": "XGBoost",
|
||||
"轻量梯度提升 (LightGBM)": "LightGBM",
|
||||
"类别梯度提升 (CatBoost)": "CatBoost",
|
||||
"梯度提升树 (GBDT)": "GradientBoosting",
|
||||
"自适应提升 (AdaBoost)": "AdaBoost",
|
||||
"支持向量回归 (SVR)": "SVR",
|
||||
"K近邻回归 (KNN)": "KNN",
|
||||
"多层感知机 (BP神经网络)": "MLP",
|
||||
}
|
||||
|
||||
# 预处理方法:各种可能的中文变体 -> 标准键名
|
||||
PREPROC_NAME_MAP = {
|
||||
# 无处理
|
||||
"无 (None)": "None",
|
||||
"None": "None",
|
||||
# MMS
|
||||
"最小-最大归一化 (MMS)": "MMS",
|
||||
"MMS": "MMS",
|
||||
# SS
|
||||
"标度化 (SS)": "SS",
|
||||
"SS": "SS",
|
||||
# SNV
|
||||
"标准正态变换 (SNV)": "SNV",
|
||||
"SNV": "SNV",
|
||||
# MA
|
||||
"移动平均 (MA)": "MA",
|
||||
"MA": "MA",
|
||||
# SG
|
||||
"Savitzky-Golay (SG)": "SG",
|
||||
"SG": "SG",
|
||||
# MSC
|
||||
"多元散射校正 (MSC)": "MSC",
|
||||
"MSC": "MSC",
|
||||
# D1
|
||||
"一阶导数 (D1)": "D1",
|
||||
"D1": "D1",
|
||||
# D2
|
||||
"二阶导数 (D2)": "D2",
|
||||
"D2": "D2",
|
||||
# DT
|
||||
"去趋势 (DT)": "DT",
|
||||
"DT": "DT",
|
||||
# CT
|
||||
"中心化 (CT)": "CT",
|
||||
"CT": "CT",
|
||||
}
|
||||
|
||||
# 数据划分方法:各种可能的中文变体 -> 标准键名
|
||||
SPLIT_NAME_MAP = {
|
||||
"SPXY 算法 (考量X-Y空间)": "spxy",
|
||||
"spxy": "spxy",
|
||||
"KS 算法 (考量X空间)": "ks",
|
||||
"ks": "ks",
|
||||
"随机划分 (Random)": "random",
|
||||
"random": "random",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_model_names(model_names: List[str]) -> List[str]:
|
||||
"""清洗模型名称列表:将汉化显示文本还原为英文键名"""
|
||||
result = []
|
||||
for name in model_names:
|
||||
if name in MODEL_NAME_MAP:
|
||||
result.append(MODEL_NAME_MAP[name])
|
||||
else:
|
||||
# 已经是英文键名,直接保留
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
|
||||
"""清洗预处理方法列表:将汉化显示文本还原为标准键名"""
|
||||
result = []
|
||||
for method in methods:
|
||||
if method in PREPROC_NAME_MAP:
|
||||
result.append(PREPROC_NAME_MAP[method])
|
||||
else:
|
||||
# 已经是标准键名,直接保留
|
||||
result.append(method)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_split_methods(methods: List[str]) -> List[str]:
|
||||
"""清洗数据划分方法列表:将汉化显示文本还原为标准键名"""
|
||||
result = []
|
||||
for method in methods:
|
||||
if method in SPLIT_NAME_MAP:
|
||||
result.append(SPLIT_NAME_MAP[method])
|
||||
else:
|
||||
# 已经是标准键名,直接保留
|
||||
result.append(method)
|
||||
return result
|
||||
|
||||
|
||||
class ModelingStep:
|
||||
"""建模步骤"""
|
||||
|
||||
# ---- Step 6: 训练机器学习模型 ----
|
||||
|
||||
@staticmethod
|
||||
def train_models(
|
||||
feature_start_column: str = "374.285004",
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
model_names: Optional[List[str]] = None,
|
||||
split_methods: Optional[List[str]] = None,
|
||||
cv_folds: int = 5,
|
||||
training_csv_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./8_Supervised_Model_Training",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
) -> str:
|
||||
"""使用采样点光谱和实测值建立机器学习模型"""
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6: 训练机器学习模型")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if training_csv_path is None:
|
||||
raise ValueError("必须提供 training_csv_path 参数")
|
||||
|
||||
# 检查模型目录是否已有模型
|
||||
if output_dir.exists() and any(output_dir.iterdir()):
|
||||
has_models = False
|
||||
for item in output_dir.iterdir():
|
||||
if item.is_dir():
|
||||
model_files = (
|
||||
list(item.glob("*.pkl"))
|
||||
+ list(item.glob("*.joblib"))
|
||||
+ list(item.glob("*.h5"))
|
||||
)
|
||||
if model_files:
|
||||
has_models = True
|
||||
break
|
||||
if has_models:
|
||||
print(f"检测到已存在的模型文件,直接使用: {output_dir}")
|
||||
notify("skipped", f"模型目录已设置: {output_dir}")
|
||||
return str(output_dir)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
|
||||
if model_names is None:
|
||||
model_names = ["SVR", "RF", "Ridge", "Lasso"]
|
||||
if split_methods is None:
|
||||
split_methods = ["spxy", "ks", "random"]
|
||||
|
||||
# ---- 汉化清洗:将 UI 传来的中文/混合名称转换为底层英文键名 ----
|
||||
preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
|
||||
model_names = _normalize_model_names(model_names)
|
||||
split_methods = _normalize_split_methods(split_methods)
|
||||
|
||||
print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
|
||||
print(f"[参数清洗] 模型名称: {model_names}")
|
||||
print(f"[参数清洗] 划分方法: {split_methods}")
|
||||
|
||||
modeler = WaterQualityModelingBatch(str(output_dir))
|
||||
modeler.train_models_batch(
|
||||
csv_path=training_csv_path,
|
||||
feature_start_column=feature_start_column,
|
||||
preprocessing_methods=preprocessing_methods,
|
||||
model_names=model_names,
|
||||
split_methods=split_methods,
|
||||
cv_folds=cv_folds,
|
||||
)
|
||||
|
||||
print(f"模型训练完成,结果保存在: {output_dir}")
|
||||
|
||||
if _report_generator is not None:
|
||||
try:
|
||||
summary_path = _report_generator.generate_training_summary(str(output_dir))
|
||||
print(f"训练摘要报告已生成: {summary_path}")
|
||||
except Exception as e:
|
||||
print(f"生成训练摘要报告时出错: {e}")
|
||||
|
||||
notify("completed", f"模型训练完成: {output_dir}")
|
||||
return str(output_dir)
|
||||
|
||||
# ---- Step 6.5: 非经验统计回归模型训练 ----
|
||||
|
||||
@staticmethod
|
||||
def train_non_empirical_models(
|
||||
csv_path: Optional[str] = None,
|
||||
preprocessing_methods: Optional[List[str]] = None,
|
||||
algorithms: Optional[List[str]] = None,
|
||||
value_cols: Union[int, Dict[str, int]] = 0,
|
||||
spectral_start_col: int = 1,
|
||||
spectral_end_col: Optional[int] = None,
|
||||
window: int = 5,
|
||||
output_dir: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""非经验统计回归模型训练"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6.5: 非经验统计回归模型训练")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过非经验模型训练(enabled=False)。")
|
||||
notify("skipped", "跳过的经验模型训练")
|
||||
return {}
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
|
||||
if output_dir is not None:
|
||||
non_empirical_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_dir = Path.cwd() / "8_Non_Empirical_Regression"
|
||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if preprocessing_methods is None:
|
||||
preprocessing_methods = ["None"]
|
||||
if algorithms is None:
|
||||
algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]
|
||||
|
||||
if isinstance(value_cols, int):
|
||||
value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
|
||||
elif isinstance(value_cols, dict):
|
||||
value_cols_dict = value_cols
|
||||
else:
|
||||
raise ValueError("value_cols 参数必须是整数或字典")
|
||||
|
||||
if spectral_end_col is None:
|
||||
df = pd.read_csv(csv_path)
|
||||
spectral_end_col = len(df.columns) - 1
|
||||
|
||||
all_model_results = {}
|
||||
|
||||
for preprocess in preprocessing_methods:
|
||||
preprocess_dir = non_empirical_dir / preprocess
|
||||
preprocess_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
processed_csv_path = _apply_preprocessing_internal(
|
||||
csv_path, preprocess, preprocess_dir, spectral_start_col
|
||||
)
|
||||
|
||||
for algorithm in algorithms:
|
||||
algorithm_value_col = value_cols_dict[algorithm]
|
||||
print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")
|
||||
|
||||
model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")
|
||||
|
||||
if Path(model_outpath).exists():
|
||||
print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
|
||||
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.core.non_empirical_model_correction import run_model_correction
|
||||
run_model_correction(
|
||||
algorithm=algorithm,
|
||||
csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
|
||||
value_col=algorithm_value_col,
|
||||
spectral_start=spectral_start_col,
|
||||
spectral_end=spectral_end_col,
|
||||
model_info_outpath=model_outpath,
|
||||
window=window,
|
||||
)
|
||||
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
|
||||
print(f"模型训练完成: {model_outpath}")
|
||||
except Exception as e:
|
||||
print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
|
||||
continue
|
||||
|
||||
summary_path = _generate_non_empirical_summary(all_model_results, non_empirical_dir)
|
||||
notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
|
||||
return all_model_results
|
||||
|
||||
# ---- Step 6.75: 自定义回归分析 ----
|
||||
|
||||
@staticmethod
|
||||
def custom_regression(
|
||||
csv_path: Optional[str] = None,
|
||||
x_columns: Optional[Union[str, List[str]]] = None,
|
||||
y_columns: Optional[Union[str, List[str]]] = None,
|
||||
methods: Union[str, List[str]] = "all",
|
||||
output_dir: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Optional[str]:
|
||||
"""使用自定义回归方法分析指标与目标参数之间的关系"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤6.75", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤6.75: 自定义回归分析")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过自定义回归分析(enabled=False)。")
|
||||
notify("skipped", "跳过自定义回归分析")
|
||||
return None
|
||||
|
||||
if csv_path is None:
|
||||
raise ValueError("必须提供 csv_path 参数")
|
||||
if y_columns is None:
|
||||
raise ValueError("必须指定 y_columns")
|
||||
if x_columns is None:
|
||||
raise ValueError("必须指定 x_columns")
|
||||
|
||||
if isinstance(x_columns, str):
|
||||
x_columns = [x_columns]
|
||||
if isinstance(y_columns, str):
|
||||
y_columns = [y_columns]
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
missing_x = [col for col in x_columns if col not in df.columns]
|
||||
missing_y = [col for col in y_columns if col not in df.columns]
|
||||
if missing_x:
|
||||
raise ValueError(f"自变量列不存在: {missing_x}")
|
||||
if missing_y:
|
||||
raise ValueError(f"因变量列不存在: {missing_y}")
|
||||
|
||||
if output_dir is None:
|
||||
custom_regression_dir = Path(work_dir) / "13_Custom_Regression"
|
||||
else:
|
||||
custom_regression_dir = Path(work_dir) / output_dir
|
||||
custom_regression_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from src.core.modeling.regression import SingleVariableRegressionAnalysis
|
||||
analyzer = SingleVariableRegressionAnalysis()
|
||||
analyzer.batch_single_variable_regression(
|
||||
data=df,
|
||||
x_columns=x_columns,
|
||||
y_columns=y_columns,
|
||||
methods=methods,
|
||||
output_dir=str(custom_regression_dir),
|
||||
)
|
||||
|
||||
notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
|
||||
return str(custom_regression_dir)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 内部辅助函数(供 ModelingStep 内部使用)
|
||||
# ============================================================
|
||||
|
||||
def _apply_preprocessing_internal(
|
||||
csv_path: str,
|
||||
preprocess_method: str,
|
||||
output_dir: Path,
|
||||
spectral_start_col: int = 4,
|
||||
) -> str:
|
||||
"""应用预处理到CSV数据(内部函数)"""
|
||||
raw_p = str(preprocess_method).lower()
|
||||
if raw_p == "none" or "无" in raw_p or "跳过" in raw_p:
|
||||
preprocess_method = "None"
|
||||
elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
|
||||
preprocess_method = "MMS"
|
||||
elif raw_p == "ss" or "标准" in raw_p or "标准化" in raw_p:
|
||||
preprocess_method = "SS"
|
||||
elif raw_p == "snv" or "标准正态" in raw_p:
|
||||
preprocess_method = "SNV"
|
||||
elif raw_p == "ma" or "移动" in raw_p:
|
||||
preprocess_method = "MA"
|
||||
elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
|
||||
preprocess_method = "SG"
|
||||
elif raw_p == "msc" or "多元散射" in raw_p:
|
||||
preprocess_method = "MSC"
|
||||
elif raw_p in ("d1", "d2", "dt"):
|
||||
preprocess_method = {"d1": "D1", "d2": "D2", "dt": "DT"}.get(raw_p, raw_p.upper())
|
||||
elif raw_p == "ct" or "去趋势" in raw_p:
|
||||
preprocess_method = "CT"
|
||||
|
||||
if preprocess_method == "None":
|
||||
return csv_path
|
||||
|
||||
output_filename = f"preprocessed_{preprocess_method}.csv"
|
||||
output_path = str(output_dir / output_filename)
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预处理文件,直接使用: {output_path}")
|
||||
return output_path
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
non_spectral_cols = df.iloc[:, :spectral_start_col]
|
||||
spectral_data = df.iloc[:, spectral_start_col:]
|
||||
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
save_path = None
|
||||
if preprocess_method == "SS":
|
||||
models_dir = output_dir.parent.parent / "8_Supervised_Model_Training"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
|
||||
|
||||
if isinstance(processed_spectral, pd.DataFrame):
|
||||
processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
|
||||
else:
|
||||
processed_spectral_df = pd.DataFrame(
|
||||
processed_spectral, columns=spectral_data.columns, index=spectral_data.index
|
||||
)
|
||||
processed_df = pd.concat([non_spectral_cols, processed_spectral_df], axis=1)
|
||||
|
||||
processed_df.to_csv(output_path, index=False)
|
||||
print(f"预处理完成: {output_path}")
|
||||
return output_path
|
||||
|
||||
|
||||
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
|
||||
"""生成非经验模型训练结果汇总CSV"""
|
||||
summary_path = str(output_dir / "non_empirical_models_summary.csv")
|
||||
summary_data = []
|
||||
|
||||
for model_key, model_path in model_results.items():
|
||||
try:
|
||||
parts = model_key.split("_")
|
||||
preprocess_method = parts[0]
|
||||
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
|
||||
|
||||
with open(model_path, "r", encoding="utf-8") as f:
|
||||
model_info = json.load(f)
|
||||
|
||||
accuracy_list = model_info.get("accuracy", [])
|
||||
summary_row = {
|
||||
"Preprocessing Method": preprocess_method,
|
||||
"Algorithm Name": algorithm_name,
|
||||
"Model Type": model_info.get("model_type", ""),
|
||||
"Coefficient Count": len(model_info.get("model_info", [])),
|
||||
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
|
||||
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
|
||||
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
|
||||
"Sample Count": len(model_info.get("long", [])),
|
||||
"Model File": model_path,
|
||||
}
|
||||
|
||||
coefficients = model_info.get("model_info", [])
|
||||
for i, coeff in enumerate(coefficients[:5]):
|
||||
summary_row[f"系数_{i+1}"] = coeff
|
||||
|
||||
summary_data.append(summary_row)
|
||||
except Exception as e:
|
||||
print(f"读取模型文件 {model_path} 时出错: {e}")
|
||||
continue
|
||||
|
||||
if summary_data:
|
||||
df_summary = pd.DataFrame(summary_data)
|
||||
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
|
||||
print(f"汇总文件已生成: {summary_path}")
|
||||
else:
|
||||
print("警告: 没有有效的模型数据可汇总")
|
||||
summary_path = ""
|
||||
|
||||
return summary_path
|
||||
402
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,402 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
预测步骤
|
||||
|
||||
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
|
||||
class PredictionStep:
|
||||
"""预测步骤"""
|
||||
|
||||
# ---- Step 7: 生成采样点并提取光谱 ----
|
||||
|
||||
@staticmethod
|
||||
def generate_sampling_points(
|
||||
deglint_img_path: Optional[str] = None,
|
||||
interval: int = 50,
|
||||
sample_radius: int = 5,
|
||||
chunk_size: int = 1000,
|
||||
water_mask_path: Optional[str] = None,
|
||||
glint_mask_path: Optional[str] = None,
|
||||
output_dir: Union[str, Path] = "./4_sampling",
|
||||
callback: Optional[Callable] = None,
|
||||
use_adaptive_sampling: bool = True,
|
||||
) -> str:
|
||||
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
|
||||
from pathlib import Path
|
||||
from src.utils.sampling import get_spectral_sampling_points_chunked
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = str(output_dir / "sampling_spectra.csv")
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤7", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤7: 生成预测采样点并提取光谱")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if deglint_img_path is None:
|
||||
raise ValueError("必须提供 deglint_img_path 参数")
|
||||
|
||||
# 1. 初始归一化与安全转换
|
||||
original_path = Path(deglint_img_path)
|
||||
final_deglint_path = original_path
|
||||
|
||||
# 2. 智能回溯探测:如果当前路径不存在,或者后缀是前端死板的 .dat
|
||||
if not final_deglint_path.exists() or final_deglint_path.suffix.lower() == '.dat':
|
||||
print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")
|
||||
|
||||
# 定位到预期的 3_deglint 根目录
|
||||
possible_dir = original_path.parent
|
||||
if possible_dir.name != '3_deglint' and Path(output_path).parent.parent.exists():
|
||||
possible_dir = Path(output_path).parent.parent / "3_deglint"
|
||||
|
||||
if possible_dir.exists():
|
||||
# 搜寻该目录下所有真实存在的 .bsq 文件(接管 goodman/sugar/kutser/hedley 的硬编码产物)
|
||||
existing_bsqs = list(possible_dir.glob("*.bsq"))
|
||||
if existing_bsqs:
|
||||
final_deglint_path = existing_bsqs[0]
|
||||
print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {final_deglint_path}")
|
||||
else:
|
||||
final_deglint_path = original_path.with_suffix('.bsq')
|
||||
else:
|
||||
final_deglint_path = original_path.with_suffix('.bsq')
|
||||
|
||||
deglint_img_str = str(final_deglint_path)
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
|
||||
notify("skipped", f"采样点光谱数据已设置: {output_path}")
|
||||
return output_path
|
||||
|
||||
glint_mask_to_use = glint_mask_path
|
||||
if glint_mask_to_use is None:
|
||||
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
|
||||
|
||||
# 传递极度安全的 deglint_img_str 进底层(关键字传参,避免 positional 参数顺序陷阱)
|
||||
get_spectral_sampling_points_chunked(
|
||||
deglint_img_str, water_mask_path, glint_mask_to_use,
|
||||
output_path,
|
||||
interval=interval,
|
||||
sample_radius=sample_radius,
|
||||
chunk_size=chunk_size,
|
||||
use_adaptive_sampling=use_adaptive_sampling,
|
||||
)
|
||||
|
||||
notify("completed", f"采样点光谱数据已保存: {output_path}")
|
||||
return output_path
|
||||
|
||||
# ---- Step 8: 机器学习模型预测水质参数 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_water_quality(
|
||||
sampling_csv_path: str,
|
||||
models_dir: Optional[str] = None,
|
||||
metric: str = "test_r2",
|
||||
prediction_column: str = "prediction",
|
||||
output_dir: Union[str, Path] = "./9_ML_Prediction",
|
||||
callback: Optional[Callable] = None,
|
||||
_report_generator=None,
|
||||
_external_model=None,
|
||||
_external_model_path=None,
|
||||
_external_models_dict=None,
|
||||
_external_model_dir=None,
|
||||
) -> Dict[str, str]:
|
||||
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||
from src.core.prediction.inference_batch import WaterQualityInference
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8: 预测水质参数")
|
||||
print("=" * 80)
|
||||
print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}"
|
||||
f", 单模型状态: {'Yes' if _external_model else 'No'}")
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if models_dir is None:
|
||||
raise ValueError("必须提供 models_dir 参数")
|
||||
|
||||
ml_prediction_dir = Path(output_dir)
|
||||
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prediction_files = {}
|
||||
if ml_prediction_dir.exists():
|
||||
csv_files = list(ml_prediction_dir.glob("*.csv"))
|
||||
for csv_file in csv_files:
|
||||
file_stem = csv_file.stem
|
||||
if "_prediction" in file_stem:
|
||||
target_name = file_stem.replace("_prediction", "")
|
||||
elif "_pred" in file_stem:
|
||||
target_name = file_stem.replace("_pred", "")
|
||||
else:
|
||||
target_name = file_stem
|
||||
prediction_files[target_name] = str(csv_file)
|
||||
|
||||
# 检查是否所有目标参数都有预测文件
|
||||
if prediction_files:
|
||||
models_path_obj = Path(models_dir)
|
||||
if models_path_obj.exists():
|
||||
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
|
||||
missing_targets = [t for t in target_folders if t not in prediction_files]
|
||||
if not missing_targets:
|
||||
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
|
||||
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
|
||||
return prediction_files
|
||||
else:
|
||||
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||
|
||||
all_results = {}
|
||||
|
||||
if _external_models_dict:
|
||||
# 外部模型字典优先:直接用字典的 keys 作为 targets 列表,
|
||||
# 手动为每个模型创建 inference 实例并调用 inference_pipeline。
|
||||
print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
|
||||
for target_name, model_obj in _external_models_dict.items():
|
||||
try:
|
||||
output_file = ml_prediction_dir / f"{target_name}.csv"
|
||||
model_inferencer = WaterQualityInference(
|
||||
models_dir or "./",
|
||||
external_model=model_obj,
|
||||
external_model_path=_external_model_dir or "",
|
||||
)
|
||||
predictions, result_df = model_inferencer.inference_pipeline(
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_csv_path=str(output_file),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
)
|
||||
prediction_files[target_name] = str(output_file)
|
||||
all_results[target_name] = {
|
||||
"status": "success",
|
||||
"output_file": str(output_file),
|
||||
"sample_count": len(predictions),
|
||||
}
|
||||
print(f" ✓ {target_name}: {len(predictions)} 个预测值")
|
||||
except Exception as e:
|
||||
print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}")
|
||||
prediction_files[target_name] = None
|
||||
all_results[target_name] = {"status": "error", "error": str(e)}
|
||||
else:
|
||||
# 字典为空或不存在:回退到扫描 models_dir 子目录的传统逻辑
|
||||
inferencer = WaterQualityInference(
|
||||
models_dir,
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
)
|
||||
all_results = inferencer.batch_inference_multi_models(
|
||||
models_root_dir=models_dir,
|
||||
sampling_csv_path=sampling_csv_path,
|
||||
output_dir=str(ml_prediction_dir),
|
||||
metric=metric,
|
||||
prediction_column=prediction_column,
|
||||
output_format="csv",
|
||||
external_model=_external_model,
|
||||
external_model_path=_external_model_path,
|
||||
external_models_dict=_external_models_dict,
|
||||
)
|
||||
# batch_inference_multi_models 已确保返回字典,永不返回 None
|
||||
if all_results:
|
||||
for target_name, result in all_results.items():
|
||||
if result.get("status") == "success":
|
||||
prediction_files[target_name] = result["output_file"]
|
||||
|
||||
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||
|
||||
if _report_generator is not None:
|
||||
try:
|
||||
report_path = _report_generator.generate_prediction_report(prediction_files)
|
||||
print(f"预测结果报告已生成: {report_path}")
|
||||
except Exception as e:
|
||||
print(f"生成预测结果报告时出错: {e}")
|
||||
|
||||
notify("completed", f"预测完成: {ml_prediction_dir}")
|
||||
return prediction_files
|
||||
|
||||
# ---- Step 8.5: 非经验模型预测 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_with_non_empirical_models(
|
||||
sampling_csv_path: str,
|
||||
non_empirical_models_dir: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
metric: str = "Average Accuracy(%)",
|
||||
prediction_column: str = "prediction",
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Dict[str, str]:
|
||||
"""使用非经验统计回归模型进行参数预测"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8.5", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8.5: 使用非经验模型进行参数预测")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过非经验模型预测(enabled=False)。")
|
||||
notify("skipped", "跳过非经验模型预测")
|
||||
return {}
|
||||
|
||||
if non_empirical_models_dir is not None:
|
||||
final_models_dir = non_empirical_models_dir
|
||||
else:
|
||||
default_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
|
||||
if Path(default_models_dir).exists():
|
||||
final_models_dir = default_models_dir
|
||||
else:
|
||||
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
|
||||
|
||||
if output_dir is not None:
|
||||
non_empirical_prediction_dir = Path(output_dir)
|
||||
else:
|
||||
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
|
||||
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prediction_files = {}
|
||||
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
|
||||
if not summary_path.exists():
|
||||
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
|
||||
|
||||
import pandas as pd
|
||||
df_summary = pd.read_csv(summary_path)
|
||||
|
||||
best_models = {}
|
||||
for algorithm in df_summary["Algorithm Name"].unique():
|
||||
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
|
||||
if metric in algorithm_df.columns:
|
||||
best_model_row = algorithm_df.nlargest(1, metric)
|
||||
else:
|
||||
best_model_row = algorithm_df.iloc[[0]]
|
||||
|
||||
best_model_path = best_model_row["Model File"].values[0]
|
||||
best_preprocess = best_model_row["Preprocessing Method"].values[0]
|
||||
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
|
||||
|
||||
best_models[algorithm] = {
|
||||
"model_path": best_model_path,
|
||||
"preprocess_method": best_preprocess,
|
||||
"accuracy": best_accuracy,
|
||||
}
|
||||
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
|
||||
|
||||
pd.read_csv(sampling_csv_path) # just to validate
|
||||
|
||||
for algorithm, model_info in best_models.items():
|
||||
print(f"\n使用 {algorithm} 算法进行预测...")
|
||||
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
|
||||
prediction_files[algorithm] = output_path
|
||||
continue
|
||||
|
||||
try:
|
||||
from src.core.non_empirical_retrieval import non_empirical_retrieval
|
||||
non_empirical_retrieval(
|
||||
algorithm=algorithm,
|
||||
model_info_path=model_info["model_path"],
|
||||
coor_spectral_path=sampling_csv_path,
|
||||
output_path=output_path,
|
||||
wave_radius=5,
|
||||
)
|
||||
prediction_files[algorithm] = output_path
|
||||
print(f"预测完成: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"使用 {algorithm} 算法预测时出错: {e}")
|
||||
continue
|
||||
|
||||
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
|
||||
return prediction_files
|
||||
|
||||
# ---- Step 8.75: 自定义回归模型预测 ----
|
||||
|
||||
@staticmethod
|
||||
def predict_with_custom_regression(
|
||||
sampling_csv_path: str,
|
||||
custom_regression_dir: Optional[str] = None,
|
||||
formula_csv_path: Optional[str] = None,
|
||||
coordinate_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
filename_prefix: str = "custom_regression_prediction",
|
||||
enabled: bool = True,
|
||||
callback: Optional[Callable] = None,
|
||||
work_dir: Union[str, Path] = "./work_dir",
|
||||
) -> Dict[str, str]:
|
||||
"""使用自定义回归模型进行参数预测"""
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤8.75", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤8.75: 使用自定义回归模型进行参数预测")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
if not enabled:
|
||||
print("已设置跳过自定义回归模型预测(enabled=False)。")
|
||||
notify("skipped", "跳过自定义回归预测")
|
||||
return {}
|
||||
|
||||
if not Path(sampling_csv_path).exists():
|
||||
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
|
||||
|
||||
if custom_regression_dir is not None:
|
||||
final_regression_dir = custom_regression_dir
|
||||
else:
|
||||
final_regression_dir = str(Path(work_dir) / "13_Custom_Regression")
|
||||
if not Path(final_regression_dir).exists():
|
||||
raise ValueError(
|
||||
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
|
||||
)
|
||||
|
||||
if output_dir is None:
|
||||
custom_regression_prediction_dir = Path(work_dir) / "13_Custom_Regression" / "Custom_Regression_Prediction"
|
||||
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
prediction_output_dir = str(custom_regression_prediction_dir)
|
||||
else:
|
||||
prediction_output_dir = output_dir
|
||||
|
||||
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
|
||||
|
||||
predictor = CustomRegressionPredictor(
|
||||
regression_csv_dir=final_regression_dir,
|
||||
formula_csv_path=formula_csv_path,
|
||||
)
|
||||
|
||||
print(f"开始使用自定义回归模块进行批量预测...")
|
||||
print(f" 采样点数据: {sampling_csv_path}")
|
||||
print(f" 回归模型目录: {final_regression_dir}")
|
||||
print(f" 输出目录: {prediction_output_dir}")
|
||||
|
||||
saved_files = predictor.run_batch_prediction(
|
||||
input_csv_path=sampling_csv_path,
|
||||
coordinate_columns=coordinate_columns,
|
||||
filename_prefix=filename_prefix,
|
||||
)
|
||||
|
||||
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
|
||||
for param_name, filepath in saved_files.items():
|
||||
print(f" {param_name}: {filepath}")
|
||||
|
||||
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
|
||||
return saved_files
|
||||
148
src/core/steps/water_mask_step.py
Normal file
@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤1: 水域掩膜生成
|
||||
|
||||
支持三种方式:
|
||||
1. 基于 shp 文件栅格化
|
||||
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
|
||||
3. 基于 NDWI 从影像自动生成水体掩膜
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Callable, Union
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WaterMaskStep:
|
||||
"""水域掩膜生成步骤"""
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
mask_path: Optional[str] = None,
|
||||
img_path: Optional[str] = None,
|
||||
ndwi_threshold: float = 0.4,
|
||||
use_ndwi: bool = False,
|
||||
generate_png: bool = True,
|
||||
output_path: Optional[str] = None,
|
||||
water_mask_dir: Union[str, Path] = "./1_water_mask",
|
||||
callback: Optional[Callable] = None,
|
||||
) -> str:
|
||||
"""
|
||||
执行水域掩膜生成
|
||||
|
||||
Args:
|
||||
mask_path: 水体掩膜文件路径,支持 .shp(需 img_path)或 .dat/.tif(直接使用)
|
||||
img_path: 输入影像文件路径(当 mask_path 为 shp 或 use_ndwi=True 时必须提供)
|
||||
ndwi_threshold: NDWI 阈值(use_ndwi=True 时使用)
|
||||
use_ndwi: 是否使用 NDWI 方法从影像生成水体掩膜
|
||||
generate_png: 是否生成 PNG 预览图(默认 True)
|
||||
output_path: 指定输出掩膜文件的保存路径(可选)
|
||||
water_mask_dir: 工作目录
|
||||
callback: 回调函数,签名为 callback(step, status, message)
|
||||
|
||||
Returns:
|
||||
dat 格式的水域掩膜文件路径
|
||||
"""
|
||||
from src.utils.extract_water_area import rasterize_shp, ndwi
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
)
|
||||
|
||||
water_mask_dir = Path(water_mask_dir)
|
||||
water_mask_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def notify(status, msg=""):
|
||||
if callback:
|
||||
callback("步骤1", status, msg)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤1: 生成或设置水域mask")
|
||||
print("=" * 80)
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
# 生成影像预览图
|
||||
if generate_png and img_path is not None and Path(img_path).exists():
|
||||
preview_path = str(water_mask_dir / "hsi_preview.png")
|
||||
generate_image_preview(
|
||||
img_path=img_path,
|
||||
output_path=preview_path,
|
||||
title="影像预览: RGB波段(基于波长)"
|
||||
)
|
||||
|
||||
# ---- NDWI 方法 ----
|
||||
if use_ndwi:
|
||||
if img_path is None:
|
||||
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
|
||||
if not Path(img_path).exists():
|
||||
raise ValueError(f"影像文件不存在: {img_path}")
|
||||
|
||||
print(f"使用NDWI方法从影像生成水体掩膜,阈值={ndwi_threshold}...")
|
||||
|
||||
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
|
||||
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
|
||||
|
||||
if Path(ndwi_output_path).exists():
|
||||
print(f"检测到已存在的NDWI掩膜文件,直接使用: {ndwi_output_path}")
|
||||
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
|
||||
return ndwi_output_path
|
||||
|
||||
ndwi(img_path, ndwi_threshold, ndwi_output_path)
|
||||
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(
|
||||
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
|
||||
)
|
||||
|
||||
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
|
||||
return ndwi_output_path
|
||||
|
||||
# ---- 必须提供 mask_path ----
|
||||
if mask_path is None:
|
||||
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
|
||||
if not Path(mask_path).exists():
|
||||
raise ValueError(f"文件不存在: {mask_path}")
|
||||
|
||||
file_ext = Path(mask_path).suffix.lower()
|
||||
|
||||
# ---- SHP 栅格化 ----
|
||||
if file_ext == ".shp":
|
||||
if img_path is None:
|
||||
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
|
||||
|
||||
print(f"检测到shp格式的水体掩膜,正在转换为dat格式...")
|
||||
|
||||
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
|
||||
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
|
||||
|
||||
if Path(shp_output_path).exists():
|
||||
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
|
||||
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
if not overlay_path.exists():
|
||||
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||
return shp_output_path
|
||||
|
||||
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
|
||||
rasterize_shp(safe_mask_path, shp_output_path, img_path)
|
||||
|
||||
if generate_png:
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
|
||||
|
||||
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
|
||||
return shp_output_path
|
||||
|
||||
# ---- 栅格格式直接使用 ----
|
||||
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
|
||||
if generate_png and img_path is not None and Path(img_path).exists():
|
||||
overlay_path = water_mask_dir / "water_mask_overlay.png"
|
||||
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
|
||||
|
||||
notify("completed", f"水域掩膜已设置: {mask_path}")
|
||||
return mask_path
|
||||
42
src/core/utils/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
工具模块 - 统一导出接口
|
||||
"""
|
||||
from src.core.utils.gdal_helper import (
|
||||
get_image_geo_info,
|
||||
load_image_as_array,
|
||||
save_array_as_image,
|
||||
save_bands_as_image,
|
||||
copy_hdr_info,
|
||||
read_band_as_array,
|
||||
read_multiple_bands,
|
||||
)
|
||||
from src.core.utils.mask_converter import (
|
||||
prepare_water_mask_for_algorithm,
|
||||
ensure_water_mask_dat,
|
||||
)
|
||||
from src.core.utils.preview_generator import (
|
||||
generate_image_preview,
|
||||
generate_water_mask_overlay,
|
||||
select_rgb_bands_by_wavelength,
|
||||
get_wavelength_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# GDAL IO
|
||||
'get_image_geo_info',
|
||||
'load_image_as_array',
|
||||
'save_array_as_image',
|
||||
'save_bands_as_image',
|
||||
'copy_hdr_info',
|
||||
'read_band_as_array',
|
||||
'read_multiple_bands',
|
||||
# 掩膜转换
|
||||
'prepare_water_mask_for_algorithm',
|
||||
'ensure_water_mask_dat',
|
||||
# 预览图生成
|
||||
'generate_image_preview',
|
||||
'generate_water_mask_overlay',
|
||||
'select_rgb_bands_by_wavelength',
|
||||
'get_wavelength_info',
|
||||
]
|
||||
309
src/core/utils/gdal_helper.py
Normal file
@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GDAL 底层 IO 工具模块
|
||||
|
||||
提供遥感影像读写、格式转换等底层 GDAL 操作功能。
|
||||
这些函数不依赖任何业务逻辑,可在其他项目中独立复用。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# GDAL 导入(可选)
|
||||
try:
|
||||
from osgeo import gdal, ogr, gdal_array
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# hdr 文件工具
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path
|
||||
UTIL_AVAILABLE = True
|
||||
except ImportError:
|
||||
UTIL_AVAILABLE = False
|
||||
write_fields_to_hdrfile = None
|
||||
get_hdr_file_path = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像信息读取
|
||||
# ============================================================
|
||||
|
||||
def get_image_geo_info(img_path: str) -> Tuple[tuple, str, int, int, int]:
|
||||
"""
|
||||
获取影像的地理信息(不加载图像数据,节省内存)
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (geotransform, projection, width, height, n_bands)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
return geotransform, projection, width, height, n_bands
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def load_image_as_array(img_path: str) -> Tuple[np.ndarray, tuple, str]:
|
||||
"""
|
||||
加载影像文件为numpy数组
|
||||
|
||||
注意:此方法会将所有波段加载到内存,对于大图像会消耗大量内存。
|
||||
建议直接传递文件路径给算法类,让算法类使用GDAL逐波段处理。
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
|
||||
Returns:
|
||||
tuple: (image_array, geotransform, projection)
|
||||
image_array: numpy数组,形状为(height, width, bands)
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
n_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
|
||||
image_bands = []
|
||||
for i in range(1, n_bands + 1):
|
||||
band = dataset.GetRasterBand(i)
|
||||
band_data = band.ReadAsArray()
|
||||
image_bands.append(band_data)
|
||||
|
||||
image_array = np.dstack(image_bands)
|
||||
return image_array, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_band_as_array(img_path: str, band_index: int) -> np.ndarray:
|
||||
"""
|
||||
读取单个波段为 numpy 数组
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_index: 波段索引(从 0 开始)
|
||||
|
||||
Returns:
|
||||
numpy 数组,形状为 (height, width)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
band = dataset.GetRasterBand(band_index + 1)
|
||||
return band.ReadAsArray()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def read_multiple_bands(img_path: str, band_indices: list) -> Tuple[list, tuple, str]:
|
||||
"""
|
||||
读取多个指定波段为列表
|
||||
|
||||
Args:
|
||||
img_path: 影像文件路径
|
||||
band_indices: 波段索引列表
|
||||
|
||||
Returns:
|
||||
tuple: (band_list, geotransform, projection)
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取影像文件")
|
||||
|
||||
dataset = gdal.Open(img_path, gdal.GA_ReadOnly)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
projection = dataset.GetProjection()
|
||||
bands = []
|
||||
for idx in band_indices:
|
||||
band = dataset.GetRasterBand(idx + 1)
|
||||
bands.append(band.ReadAsArray())
|
||||
return bands, geotransform, projection
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 影像写入
|
||||
# ============================================================
|
||||
|
||||
def save_array_as_image(image_array: np.ndarray, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
将numpy数组保存为影像文件
|
||||
|
||||
Args:
|
||||
image_array: numpy数组,形状为(height, width, bands) 或 (height, width)
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型(默认 gdal.GDT_Float32)
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
if image_array.ndim == 2:
|
||||
height, width = image_array.shape
|
||||
n_bands = 1
|
||||
else:
|
||||
height, width, n_bands = image_array.shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
if n_bands == 1:
|
||||
band = dataset.GetRasterBand(1)
|
||||
band.WriteArray(image_array)
|
||||
band.FlushCache()
|
||||
else:
|
||||
for i in range(n_bands):
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(image_array[:, :, i])
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def save_bands_as_image(corrected_bands: list, output_path: str,
|
||||
geotransform: tuple, projection: str,
|
||||
dtype=None) -> str:
|
||||
"""
|
||||
直接从波段列表保存影像文件(避免堆叠,节省内存)
|
||||
|
||||
Args:
|
||||
corrected_bands: 校正后的波段列表,每个元素是一个(height, width)的numpy数组
|
||||
output_path: 输出文件路径
|
||||
geotransform: 地理变换参数
|
||||
projection: 投影信息
|
||||
dtype: GDAL数据类型
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法保存影像文件")
|
||||
|
||||
if not corrected_bands:
|
||||
raise ValueError("波段列表为空")
|
||||
|
||||
if dtype is None:
|
||||
dtype = gdal.GDT_Float32
|
||||
|
||||
n_bands = len(corrected_bands)
|
||||
height, width = corrected_bands[0].shape
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
if driver is None:
|
||||
raise ValueError("无法创建影像文件,没有可用的驱动")
|
||||
|
||||
dataset = driver.Create(output_path, width, height, n_bands, dtype)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||||
|
||||
try:
|
||||
dataset.SetGeoTransform(geotransform)
|
||||
dataset.SetProjection(projection)
|
||||
|
||||
for i, band_array in enumerate(corrected_bands):
|
||||
if band_array.shape != (height, width):
|
||||
raise ValueError(f"波段 {i} 的尺寸 {band_array.shape} 与预期 {(height, width)} 不匹配")
|
||||
band = dataset.GetRasterBand(i + 1)
|
||||
band.WriteArray(band_array)
|
||||
band.FlushCache()
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def copy_hdr_info(source_img_path: str, dest_img_path: str) -> bool:
|
||||
"""
|
||||
复制原始影像的hdr文件信息(如波长等)到目标影像的hdr文件
|
||||
|
||||
Args:
|
||||
source_img_path: 源影像文件路径(原始bsq文件)
|
||||
dest_img_path: 目标影像文件路径(去耀斑后的bsq文件)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not UTIL_AVAILABLE:
|
||||
print("警告: util模块未导入,无法复制hdr文件信息")
|
||||
return False
|
||||
|
||||
try:
|
||||
source_hdr_path = get_hdr_file_path(source_img_path)
|
||||
dest_hdr_path = get_hdr_file_path(dest_img_path)
|
||||
|
||||
if not Path(source_hdr_path).exists():
|
||||
print(f"警告: 源hdr文件不存在: {source_hdr_path}")
|
||||
return False
|
||||
|
||||
if not Path(dest_hdr_path).exists():
|
||||
print(f"警告: 目标hdr文件不存在: {dest_hdr_path}")
|
||||
return False
|
||||
|
||||
write_fields_to_hdrfile(source_hdr_path, dest_hdr_path)
|
||||
print(f"已复制原始hdr文件信息到: {dest_hdr_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 复制hdr文件信息时出错: {e}")
|
||||
return False
|
||||
210
src/core/utils/mask_converter.py
Normal file
@ -0,0 +1,210 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
掩膜转换工具模块
|
||||
|
||||
提供 shapefile / ndarray / dat / tif 等多种格式掩膜之间的相互转换,
|
||||
以及水体掩膜的预处理逻辑。
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from osgeo import gdal, ogr
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
|
||||
def prepare_water_mask_for_algorithm(
|
||||
water_mask: Optional[Union[str, np.ndarray]],
|
||||
image_shape: Union[tuple, np.ndarray],
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
img_path: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
准备水域掩膜供算法使用
|
||||
|
||||
支持格式:
|
||||
- None:自动使用预先生成的 dat 格式掩膜
|
||||
- numpy.ndarray:直接返回(确保是 0/1 格式)
|
||||
- .dat / .tif 等栅格文件:读取并返回
|
||||
- .shp 文件:先栅格化,再读取返回
|
||||
|
||||
Args:
|
||||
water_mask: 掩膜来源
|
||||
image_shape: 影像形状 (height, width) 或 (height, width, channels)
|
||||
geotransform: GDAL 地理变换参数
|
||||
projection: 投影信息
|
||||
img_path: 影像路径(用于 shp 栅格化)
|
||||
water_mask_dir: 水体掩膜目录(用于缓存栅格化的 shp 结果)
|
||||
callback: 进度回调函数(可选)
|
||||
|
||||
Returns:
|
||||
numpy数组(dtype=uint8,0=非水域,1=水域)或 None
|
||||
"""
|
||||
img_height, img_width = image_shape[0], image_shape[1]
|
||||
|
||||
if water_mask is None:
|
||||
return None
|
||||
|
||||
# numpy 数组直接返回
|
||||
if isinstance(water_mask, np.ndarray):
|
||||
if water_mask.shape[:2] != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {water_mask.shape[:2]} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
return (water_mask > 0).astype(np.uint8)
|
||||
|
||||
# 字符串路径
|
||||
if isinstance(water_mask, str):
|
||||
ext = Path(water_mask).suffix.lower()
|
||||
|
||||
# shapefile 格式
|
||||
if ext == '.shp':
|
||||
return _convert_shp_to_mask(
|
||||
shp_path=water_mask,
|
||||
img_path=img_path,
|
||||
image_shape=image_shape,
|
||||
geotransform=geotransform,
|
||||
projection=projection,
|
||||
water_mask_dir=water_mask_dir,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
# 栅格文件格式
|
||||
return _load_raster_mask(water_mask, img_height, img_width)
|
||||
|
||||
raise ValueError(f"不支持的掩膜类型: {type(water_mask)}")
|
||||
|
||||
|
||||
def _convert_shp_to_mask(shp_path: str, img_path: str,
|
||||
image_shape: tuple,
|
||||
geotransform: tuple,
|
||||
projection: str,
|
||||
water_mask_dir: Optional[str] = None,
|
||||
callback=None) -> np.ndarray:
|
||||
"""将 shapefile 栅格化为掩膜数组"""
|
||||
from src.utils.extract_water_area import rasterize_shp
|
||||
|
||||
safe_shp_path = os.path.abspath(shp_path).replace('\\', '/')
|
||||
shp_name = Path(safe_shp_path).stem
|
||||
|
||||
if water_mask_dir:
|
||||
temp_mask_path = str(Path(water_mask_dir) / f"water_mask_{shp_name}.dat")
|
||||
else:
|
||||
temp_mask_path = f"/tmp/water_mask_{shp_name}.dat"
|
||||
|
||||
# 缓存:已栅格化则直接读取
|
||||
if Path(temp_mask_path).exists():
|
||||
print(f"使用已存在的栅格化掩膜: {temp_mask_path}")
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
# 需要栅格化
|
||||
if img_path is None:
|
||||
raise ValueError("当 water_mask 为 shp 格式时,需要提供 img_path 参数用于栅格化")
|
||||
|
||||
print(f"正在将 SHP 栅格化: {safe_shp_path}")
|
||||
rasterize_shp(safe_shp_path, temp_mask_path, img_path)
|
||||
|
||||
return _load_raster_mask(temp_mask_path, image_shape[0], image_shape[1])
|
||||
|
||||
|
||||
def _load_raster_mask(mask_path: str, img_height: int, img_width: int) -> np.ndarray:
|
||||
"""从栅格文件加载掩膜"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法读取掩膜文件")
|
||||
|
||||
mask_dataset = gdal.Open(mask_path, gdal.GA_ReadOnly)
|
||||
if mask_dataset is None:
|
||||
raise ValueError(f"无法打开掩膜文件: {mask_path}")
|
||||
|
||||
try:
|
||||
mask_array = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
finally:
|
||||
mask_dataset = None
|
||||
|
||||
if mask_array.shape != (img_height, img_width):
|
||||
raise ValueError(f"掩膜尺寸 {mask_array.shape} 与图像尺寸 {(img_height, img_width)} 不匹配")
|
||||
|
||||
return (mask_array > 0).astype(np.uint8)
|
||||
|
||||
|
||||
def ensure_water_mask_dat(img_path: str,
|
||||
existing_dat_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None) -> str:
|
||||
"""
|
||||
确保存在 dat 格式的水体掩膜文件(用于步骤3/4中的算法)
|
||||
|
||||
如果 existing_dat_path 存在且是 .dat 文件,直接返回。
|
||||
如果存在同名 .dat 文件,直接返回。
|
||||
否则从 img_path 生成并保存到 output_dir。
|
||||
|
||||
Args:
|
||||
img_path: 用于生成掩膜的遥感影像路径
|
||||
existing_dat_path: 已有的 dat 格式掩膜路径(可选)
|
||||
output_dir: 输出目录(可选)
|
||||
|
||||
Returns:
|
||||
dat 格式掩膜文件路径
|
||||
"""
|
||||
if existing_dat_path and Path(existing_dat_path).suffix.lower() == '.dat':
|
||||
if Path(existing_dat_path).exists():
|
||||
return existing_dat_path
|
||||
|
||||
img_name = Path(img_path).stem
|
||||
if output_dir is None:
|
||||
output_dir = str(Path(img_path).parent)
|
||||
|
||||
dat_path = str(Path(output_dir) / f"{img_name}_water_mask.dat")
|
||||
|
||||
if Path(dat_path).exists():
|
||||
return dat_path
|
||||
|
||||
# 如果已有其他格式的掩膜,转换为 dat
|
||||
for ext in ['.tif', '.img', '.tiff']:
|
||||
alt_path = str(Path(output_dir) / f"{img_name}_water_mask{ext}")
|
||||
if Path(alt_path).exists():
|
||||
return _convert_to_dat(alt_path, dat_path)
|
||||
|
||||
return dat_path # 返回目标路径,让调用方决定是否需要生成
|
||||
|
||||
|
||||
def _convert_to_dat(src_path: str, dest_path: str) -> str:
|
||||
"""将其他栅格格式转换为 ENVI dat 格式"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法转换格式")
|
||||
|
||||
src_ds = gdal.Open(src_path, gdal.GA_ReadOnly)
|
||||
if src_ds is None:
|
||||
raise ValueError(f"无法打开源掩膜文件: {src_path}")
|
||||
|
||||
try:
|
||||
geotransform = src_ds.GetGeoTransform()
|
||||
projection = src_ds.GetProjection()
|
||||
band = src_ds.GetRasterBand(1)
|
||||
array = band.ReadAsArray()
|
||||
|
||||
driver = gdal.GetDriverByName('ENVI')
|
||||
if driver is None:
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
|
||||
dest_ds = driver.Create(dest_path, src_ds.RasterXSize, src_ds.RasterYSize, 1, gdal.GDT_Byte)
|
||||
if dest_ds is None:
|
||||
raise ValueError(f"无法创建输出文件: {dest_path}")
|
||||
|
||||
try:
|
||||
dest_ds.SetGeoTransform(geotransform)
|
||||
dest_ds.SetProjection(projection)
|
||||
dest_band = dest_ds.GetRasterBand(1)
|
||||
dest_band.WriteArray((array > 0).astype(np.uint8))
|
||||
dest_band.FlushCache()
|
||||
finally:
|
||||
dest_ds = None
|
||||
|
||||
return dest_path
|
||||
finally:
|
||||
src_ds = None
|
||||
339
src/core/utils/preview_generator.py
Normal file
@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
遥感影像预览图生成工具模块
|
||||
|
||||
提供高光谱影像的 RGB 预览图、水域掩膜叠加图等可视化功能。
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
try:
|
||||
from osgeo import gdal
|
||||
GDAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
GDAL_AVAILABLE = False
|
||||
|
||||
# matplotlib 仅在实际使用时导入(preview_generator 是可视化工具)
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans', 'Arial Unicode MS']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 辅助函数:波段选择
|
||||
# ============================================================
|
||||
|
||||
def select_rgb_bands_by_wavelength(band_count: int,
|
||||
wavelength_info: Optional[List[float]] = None,
|
||||
fallback_bands: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
根据波长自动选择 RGB 波段
|
||||
|
||||
Args:
|
||||
band_count: 总波段数
|
||||
wavelength_info: 各波段波长列表(nm),长度为 band_count
|
||||
fallback_bands: 当无法通过波长选择时的回退波段索引 [R, G, B]
|
||||
|
||||
Returns:
|
||||
波段索引列表 [R_index, G_index, B_index](0-based)
|
||||
"""
|
||||
if fallback_bands is None:
|
||||
fallback_bands = [band_count - 3, band_count - 2, band_count - 1]
|
||||
|
||||
if wavelength_info is None:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
# 目标波长(nm)
|
||||
TARGET_R = 650
|
||||
TARGET_G = 550
|
||||
TARGET_B = 460
|
||||
|
||||
def find_closest(target: float) -> int:
|
||||
min_dist = float('inf')
|
||||
best_idx = 0
|
||||
for i, wl in enumerate(wavelength_info):
|
||||
dist = abs(wl - target)
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
best_idx = i
|
||||
return best_idx
|
||||
|
||||
try:
|
||||
r_idx = find_closest(TARGET_R)
|
||||
g_idx = find_closest(TARGET_G)
|
||||
b_idx = find_closest(TARGET_B)
|
||||
return [r_idx, g_idx, b_idx]
|
||||
except Exception:
|
||||
return [max(0, min(i, band_count - 1)) for i in fallback_bands]
|
||||
|
||||
|
||||
def get_wavelength_info(img_path: str) -> Optional[List[float]]:
|
||||
"""从 hdr 文件读取波长信息"""
|
||||
try:
|
||||
hdr_path = Path(img_path).with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
in_wl = False
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength ='):
|
||||
in_wl = True
|
||||
line = line.split('=', 1)[1].strip()
|
||||
elif in_wl:
|
||||
if line.startswith('{'):
|
||||
line = line[1:]
|
||||
if line.endswith('}'):
|
||||
line = line[:-1]
|
||||
in_wl = False
|
||||
# 解析逗号分隔的数值
|
||||
for token in line.replace(',', ' ').split():
|
||||
try:
|
||||
wavelengths.append(float(token))
|
||||
except ValueError:
|
||||
pass
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 核心预览图生成函数
|
||||
# ============================================================
|
||||
|
||||
def generate_image_preview(img_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
title: str = "影像预览") -> str:
|
||||
"""
|
||||
生成高光谱影像的 PNG 预览图
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
output_path: 输出 PNG 文件路径
|
||||
bands: RGB 波段索引 [R, G, B],None 则自动选择
|
||||
title: 图片标题
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成影像预览图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的预览图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
# 读取波段
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
# 线性拉伸
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.stack([r_s, g_s, b_s], axis=2)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.imshow(rgb_image)
|
||||
ax.set_title(title, fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
scale_text = f"分辨率: {pixel_size_x:.2f} m/px | 尺寸: {width} x {height} px"
|
||||
fig.text(0.5, 0.02, scale_text, ha='center', fontsize=9,
|
||||
color='white',
|
||||
bbox=dict(facecolor='black', alpha=0.6,
|
||||
boxstyle='round,pad=0.3'))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
|
||||
|
||||
def generate_water_mask_overlay(img_path: str,
|
||||
mask_path: str,
|
||||
output_path: str,
|
||||
bands: Optional[List[int]] = None,
|
||||
mask_color: tuple = (0, 100, 255),
|
||||
mask_alpha: float = 0.5) -> str:
|
||||
"""
|
||||
生成水域掩膜叠加到原图的 PNG 图像
|
||||
|
||||
Args:
|
||||
img_path: 输入影像路径
|
||||
mask_path: 水域掩膜文件路径
|
||||
output_path: 输出 PNG 路径
|
||||
bands: RGB 波段索引,None 则自动选择
|
||||
mask_color: 掩膜叠加颜色 (R, G, B)
|
||||
mask_alpha: 掩膜透明度(0=完全透明,1=完全不透明)
|
||||
|
||||
Returns:
|
||||
生成的 PNG 文件路径
|
||||
"""
|
||||
if not GDAL_AVAILABLE:
|
||||
raise ImportError("GDAL未安装,无法生成叠加图")
|
||||
|
||||
if Path(output_path).exists():
|
||||
print(f"检测到已存在的叠加图,跳过生成: {output_path}")
|
||||
return output_path
|
||||
|
||||
dataset = gdal.Open(img_path)
|
||||
if dataset is None:
|
||||
raise ValueError(f"无法打开影像文件: {img_path}")
|
||||
|
||||
try:
|
||||
width = dataset.RasterXSize
|
||||
height = dataset.RasterYSize
|
||||
band_count = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
|
||||
# 自动选择波段
|
||||
if bands is None:
|
||||
if band_count >= 3:
|
||||
wl_info = get_wavelength_info(img_path)
|
||||
bands = select_rgb_bands_by_wavelength(band_count, wl_info)
|
||||
else:
|
||||
bands = [0, 0, 0]
|
||||
|
||||
r_data = dataset.GetRasterBand(bands[0] + 1).ReadAsArray().astype(np.float32)
|
||||
g_data = r_data if band_count == 1 else dataset.GetRasterBand(bands[1] + 1).ReadAsArray().astype(np.float32)
|
||||
b_data = r_data if band_count <= 2 else dataset.GetRasterBand(bands[2] + 1).ReadAsArray().astype(np.float32)
|
||||
|
||||
r_data[r_data <= 0] = np.nan
|
||||
if band_count > 1:
|
||||
g_data[g_data <= 0] = np.nan
|
||||
if band_count > 2:
|
||||
b_data[b_data <= 0] = np.nan
|
||||
|
||||
def linear_stretch(data, low=2, high=98):
|
||||
valid = data[~np.isnan(data)]
|
||||
if len(valid) == 0:
|
||||
return np.zeros_like(data)
|
||||
lo = np.percentile(valid, low)
|
||||
hi = np.percentile(valid, high)
|
||||
if hi - lo < 1e-10:
|
||||
return np.zeros_like(data)
|
||||
stretched = np.clip((data - lo) / (hi - lo), 0, 1)
|
||||
return np.nan_to_num(stretched, nan=0.0)
|
||||
|
||||
r_s = linear_stretch(r_data)
|
||||
g_s = linear_stretch(g_data) if band_count > 1 else r_s
|
||||
b_s = linear_stretch(b_data) if band_count > 2 else r_s
|
||||
|
||||
rgb_image = np.nan_to_num(np.stack([r_s, g_s, b_s], axis=2)) * 255
|
||||
rgb_image = rgb_image.astype(np.uint8)
|
||||
|
||||
# 读取掩膜
|
||||
mask_dataset = gdal.Open(mask_path)
|
||||
if mask_dataset is not None:
|
||||
mask_data = mask_dataset.GetRasterBand(1).ReadAsArray()
|
||||
mask_dataset = None
|
||||
else:
|
||||
print(f"警告: 无法打开掩膜文件: {mask_path}")
|
||||
mask_data = None
|
||||
|
||||
# Alpha 混合
|
||||
overlay = np.zeros((height, width, 4), dtype=np.uint8)
|
||||
overlay[:, :, 0:3] = mask_color
|
||||
overlay[:, :, 3] = 255 # 全不透明
|
||||
|
||||
blended = rgb_image.astype(np.float32)
|
||||
if mask_data is not None:
|
||||
alpha = mask_data.astype(np.float32) / 255.0 * mask_alpha
|
||||
for c in range(3):
|
||||
blended[:, :, c] = rgb_image[:, :, c].astype(np.float32) * (1 - alpha) + mask_color[c] * alpha
|
||||
blended = blended.astype(np.uint8)
|
||||
|
||||
# 绘图
|
||||
fig, ax = plt.subplots(figsize=(14, 10))
|
||||
ax.imshow(blended)
|
||||
ax.axis('off')
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor=f'#{mask_color[0]:02x}{mask_color[1]:02x}{mask_color[2]:02x}',
|
||||
edgecolor='black', alpha=mask_alpha, label='水域范围')
|
||||
]
|
||||
ax.legend(handles=legend_elements, loc='upper right', framealpha=0.9)
|
||||
|
||||
# 面积计算
|
||||
if geotransform and geotransform[1] != 0:
|
||||
pixel_size_x = abs(geotransform[1])
|
||||
pixel_size_y = abs(geotransform[5])
|
||||
pixel_area = pixel_size_x * pixel_size_y
|
||||
|
||||
if mask_data is not None:
|
||||
water_pixels = np.sum(mask_data > 0)
|
||||
valid_pixels = np.sum(mask_data >= 0)
|
||||
water_km2 = water_pixels * pixel_area / 1_000_000
|
||||
valid_km2 = valid_pixels * pixel_area / 1_000_000
|
||||
pct = (water_pixels / valid_pixels * 100) if valid_pixels > 0 else 0
|
||||
|
||||
area_text = (f'水域面积: {water_km2:.2f} km² | '
|
||||
f'影像总面积: {valid_km2:.2f} km² | '
|
||||
f'占比: {pct:.1f}%')
|
||||
ax.text(0.02, 0.98, area_text,
|
||||
transform=ax.transAxes, fontsize=11,
|
||||
color='white', fontweight='bold',
|
||||
bbox=dict(facecolor='#0064FF', alpha=0.8,
|
||||
edgecolor='black', boxstyle='round,pad=0.5'),
|
||||
verticalalignment='top')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight', pad_inches=0.1)
|
||||
plt.close(fig)
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
dataset = None
|
||||
158
src/core/utils/split_methods.py
Normal file
@ -0,0 +1,158 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据集划分算法 —— SPXY / Kennard-Stone
|
||||
|
||||
从 modeling_batch.py / inference_batch.py / sctter_batch.py 中抽离,
|
||||
消除三处完全相同的重复实现。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def spxy(data, label, test_size=0.2):
|
||||
"""
|
||||
SPXY算法划分数据集(考虑X和Y空间的距离)
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features) —— np.ndarray 或 pd.DataFrame
|
||||
label: shape (n_samples, ) —— np.ndarray 或 pd.Series
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
x_backup = data
|
||||
y_backup = label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
label = (label - np.mean(label)) / np.std(label)
|
||||
D = np.zeros((M, M))
|
||||
Dy = np.zeros((M, M))
|
||||
|
||||
for i in range(M - 1):
|
||||
xa = data[i, :]
|
||||
ya = label[i]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
yb = label[j]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
Dy[i, j] = np.linalg.norm(ya - yb)
|
||||
|
||||
Dmax = np.max(D)
|
||||
Dymax = np.max(Dy)
|
||||
D = D / Dmax + Dy / Dymax
|
||||
|
||||
maxD = D.max(axis=0)
|
||||
index_row = D.argmax(axis=0)
|
||||
index_column = maxD.argmax()
|
||||
|
||||
m = np.zeros(N, dtype=int)
|
||||
m[0] = index_row[index_column]
|
||||
m[1] = index_column
|
||||
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros(M - i)
|
||||
for j in range(M - i):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(samples, m)
|
||||
|
||||
X_train = data[m, :]
|
||||
y_train = y_backup[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = y_backup[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||
|
||||
|
||||
def ks(data, label, test_size=0.2):
|
||||
"""
|
||||
Kennard-Stone算法划分数据集
|
||||
|
||||
Args:
|
||||
data: shape (n_samples, n_features) —— np.ndarray 或 pd.DataFrame
|
||||
label: shape (n_samples, ) —— np.ndarray 或 pd.Series
|
||||
test_size: 测试集比例,默认: 0.2
|
||||
|
||||
Returns:
|
||||
X_train: (n_samples, n_features)
|
||||
X_test: (n_samples, n_features)
|
||||
y_train: (n_samples, )
|
||||
y_test: (n_samples, )
|
||||
"""
|
||||
data = data.to_numpy() if isinstance(data, pd.DataFrame) else data
|
||||
label = label.to_numpy() if isinstance(label, pd.Series) else label
|
||||
|
||||
M = data.shape[0]
|
||||
N = round((1 - test_size) * M)
|
||||
samples = np.arange(M)
|
||||
|
||||
D = np.zeros((M, M))
|
||||
|
||||
for i in range((M - 1)):
|
||||
xa = data[i, :]
|
||||
for j in range((i + 1), M):
|
||||
xb = data[j, :]
|
||||
D[i, j] = np.linalg.norm(xa - xb)
|
||||
|
||||
maxD = np.max(D, axis=0)
|
||||
index_row = np.argmax(D, axis=0)
|
||||
index_column = np.argmax(maxD)
|
||||
|
||||
m = np.zeros(N)
|
||||
m[0] = np.array(index_row[index_column])
|
||||
m[1] = np.array(index_column)
|
||||
m = m.astype(int)
|
||||
dminmax = np.zeros(N)
|
||||
dminmax[1] = D[m[0], m[1]]
|
||||
|
||||
for i in range(2, N):
|
||||
pool = np.delete(samples, m[:i])
|
||||
dmin = np.zeros((M - i))
|
||||
for j in range((M - i)):
|
||||
indexa = pool[j]
|
||||
d = np.zeros(i)
|
||||
for k in range(i):
|
||||
indexb = m[k]
|
||||
if indexa < indexb:
|
||||
d[k] = D[indexa, indexb]
|
||||
else:
|
||||
d[k] = D[indexb, indexa]
|
||||
dmin[j] = np.min(d)
|
||||
dminmax[i] = np.max(dmin)
|
||||
index = np.argmax(dmin)
|
||||
m[i] = pool[index]
|
||||
|
||||
m_complement = np.delete(np.arange(data.shape[0]), m)
|
||||
|
||||
X_train = data[m, :]
|
||||
y_train = label[m]
|
||||
X_test = data[m_complement, :]
|
||||
y_test = label[m_complement]
|
||||
|
||||
return X_train, X_test, y_train, y_test
|
||||