feat(step8): 新增水色指数反演模块 waterindex_inversion + CSV 公式驱动架构

This commit is contained in:
DXC
2026-06-10 17:13:25 +08:00
parent cfe4c50c31
commit 320f2f18f2
5 changed files with 1218 additions and 0 deletions

View File

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""
水色指数反演模块(包入口)
从 waterindex.csv 读取公式,对去耀斑 BSQ 高光谱影像进行全图矩阵运算,
输出带完整坐标信息的 GeoTIFF。
公式格式waterindex.csv
- 波长占位符w{nm},如 w686, w708, w665
- 支持混合大小写w686 / W665 均可
- 示例NDCI = (w708 - w665) / (w708 + w665)
输出:
- GeoTIFF (Float32)LZW 压缩,带 Tile
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
- Step 14 可直接用 rasterio 读取数组和空间范围
"""
# 重新导出 WaterIndexProcessor向后兼容所有已有 import
from src.core.algorithms.waterindex_inversion import WaterIndexProcessor
__all__ = ['WaterIndexProcessor']

View File

@ -0,0 +1,646 @@
# -*- coding: utf-8 -*-
"""
水色指数反演模块
直接读取去耀斑高光谱 BSQ 影像,应用 waterindex.csv 中的公式,
输出各水质参数指数的 GeoTIFF 栅格图像。
公式格式waterindex.csv
- 波长占位符w{nm},如 w686, w708, w665
- 支持混合大小写w686 / W665 均可
- 示例NDCI = (w708 - w665) / (w708 + w665)
BGA_Am09KBBI = (w686 - w658) / (w686 + w658)
输出:
- GeoTIFF (Float32)LZW 压缩,带 Tile
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
- Step 14 可直接用 rasterio 读取进行克里金插值
"""
from __future__ import annotations
import csv
import os
import re
import sys
import time
import traceback
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
from osgeo import gdal, osr
# GDAL 驱动注册
gdal.UseExceptions()
# ------------------------------------------------------------------
# 公共工具
# ------------------------------------------------------------------
def _get_resource_path(relative_path: str) -> str:
"""获取 waterindex.csv 等资源的绝对路径,兼容 PyInstaller 打包。"""
if hasattr(sys, '_MEIPASS'):
base = sys._MEIPASS
else:
base = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', '..', '..')
)
return os.path.join(base, relative_path)
# ------------------------------------------------------------------
# WaterIndexProcessor
# ------------------------------------------------------------------
class WaterIndexProcessor:
"""
水色指数处理器
读取 waterindex.csv 中的公式,应用于 BSQ 高光谱影像,
输出带完整坐标信息的 GeoTIFF 指数图。
核心能力:
- 公式解析w{nm} 占位符 → 实际波段 2D numpy 数组
- 矩阵运算:全影像批量计算,无需逐点循环
- 地理信息保持:克隆原始 BSQ 的 GeoTransform / Projection
- NoData 处理:运算中产生的 NaN/Inf 统一标记为 -9999
"""
# 内置安全命名空间(公式 eval 白名单)
_SAFE_NS: Dict[str, Any] = {
'np': np,
'nan': np.nan,
'inf': np.inf,
'pi': np.pi,
'e': np.e,
}
def __init__(self, waterindex_csv_path: Optional[str] = None):
"""
Parameters
----------
waterindex_csv_path : str, optional
waterindex.csv 路径。
若为 None尝试从默认位置加载
1. src/gui/model/waterindex.csv开发环境
2. _MEIPASS/src/gui/model/waterindex.csv打包环境
"""
self.csv_path: Optional[str] = None
self.formulas: List[Dict[str, Any]] = []
if waterindex_csv_path:
self.csv_path = waterindex_csv_path
else:
candidates = [
os.path.join(os.path.dirname(__file__), '..', '..', 'gui', 'model', 'waterindex.csv'),
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'gui', 'model', 'waterindex.csv'),
]
for p in candidates:
if os.path.isfile(p):
self.csv_path = p
break
if self.csv_path:
self._parse_csv()
else:
self.formulas = []
# ------------------------------------------------------------------
# 公式加载
# ------------------------------------------------------------------
def _parse_csv(self) -> None:
"""解析 waterindex.csv加载所有公式。"""
if not os.path.isfile(self.csv_path):
raise FileNotFoundError(f"公式配置文件不存在: {self.csv_path}")
# ★★★ 防止多次调用时公式翻倍叠加 ★★★
self.formulas.clear()
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
self.formulas.append(dict(row))
print(f"[WaterIndexProcessor] 加载 {len(self.formulas)} 条公式 ← {self.csv_path}")
def reload(self, waterindex_csv_path: str) -> None:
"""重新加载公式配置文件。"""
self.csv_path = waterindex_csv_path
self._parse_csv()
# ------------------------------------------------------------------
# 公式查询
# ------------------------------------------------------------------
def list_formulas(self) -> List[Dict[str, Any]]:
"""返回所有公式的列表。"""
return list(self.formulas)
def list_formula_names(self) -> List[str]:
"""返回所有公式名称列表。"""
return [f.get('Formula_Name', '') for f in self.formulas]
def get_formula(self, name: str) -> Optional[Dict[str, Any]]:
"""按名称查找单个公式。"""
for f in self.formulas:
if f.get('Formula_Name', '').strip() == name.strip():
return f
return None
def list_categories(self) -> List[str]:
"""返回所有公式类别(去重排序)。"""
cats = set()
for f in self.formulas:
c = f.get('Category', '').strip()
if c:
cats.add(c)
return sorted(cats)
def get_formulas_by_category(self, category: str) -> List[Dict[str, Any]]:
"""按类别筛选公式。"""
return [f for f in self.formulas
if f.get('Category', '').strip().lower() == category.strip().lower()]
# ------------------------------------------------------------------
# 影像元数据
# ------------------------------------------------------------------
def get_image_metadata(self, bsq_path: str, hdr_path: Optional[str] = None) -> Dict[str, Any]:
"""获取影像元数据GDAL + ENVI HDR 双重保障)。
Parameters
----------
bsq_path : str
BSQ 影像路径
hdr_path : str, optional
ENVI HDR 路径None → 自动构造)
Returns
-------
dict
含 keys: width, height, bands, wavelengths, wavelength_range,
geotransform, projection, driver
"""
meta: Dict[str, Any] = {}
# 1. GDAL 优先(获取空间信息)
try:
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
if ds is not None:
meta['width'] = ds.RasterXSize
meta['height'] = ds.RasterYSize
meta['bands'] = ds.RasterCount
meta['driver'] = ds.GetDriver().ShortName
gt = ds.GetGeoTransform()
proj = ds.GetProjection()
if gt and gt != (0, 1, 0, 0, 0, 1):
meta['geotransform'] = gt
if proj:
meta['projection'] = proj
ds = None
except Exception:
pass
# 2. HDR 补充波长信息
if hdr_path is None:
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
if not os.path.isfile(hdr_path):
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
if os.path.isfile(hdr_path_alt):
hdr_path = hdr_path_alt
if os.path.isfile(hdr_path):
wl = self._parse_wavelengths_from_hdr(hdr_path)
if wl:
meta['wavelengths'] = wl
if len(wl) >= 2:
meta['wavelength_range'] = f"{wl[0]:.1f}{wl[-1]:.1f} nm ({len(wl)} 波段)"
elif meta.get('bands', 0) > 0:
meta['wavelength_range'] = f"{meta['bands']} 波段(波长信息缺失)"
return meta
@staticmethod
def _parse_wavelengths_from_hdr(hdr_path: str) -> Optional[List[float]]:
"""从 ENVI .hdr 文件中解析波长列表。"""
try:
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# 格式1wavelength = { 400, 401, ... }
m = re.search(r'wavelength\s*=\s*\{([^}]+)\}', content, re.DOTALL)
if m:
vals = [float(v) for v in re.findall(r'[\d.]+', m.group(1)) if v.strip()]
if vals:
return vals
# 格式2逐行罗列
wavelengths: List[float] = []
in_wl = False
for line in content.split('\n'):
line = line.strip()
if line.startswith('wavelength'):
in_wl = True
continue
if in_wl:
if line.startswith('{'):
continue
try:
wavelengths.append(float(line))
except ValueError:
if '}' in line:
in_wl = False
return wavelengths if wavelengths else None
except Exception:
return None
# ------------------------------------------------------------------
# 公式解析w{nm} 占位符 → 实际波段数据
# ------------------------------------------------------------------
def _find_nearest_band_index(self, target_wv: float,
wavelengths: List[float]) -> int:
"""找到最接近目标波长的 GDAL 波段索引1-based"""
if not wavelengths:
raise ValueError("波长列表为空,无法匹配波段")
nearest = min(range(len(wavelengths)),
key=lambda i: abs(wavelengths[i] - target_wv))
return nearest + 1 # GDAL 波段从 1 开始
def _parse_formula_wavelengths(self, formula: str) -> List[int]:
"""从公式字符串中提取所有波长值去重int"""
raw = re.findall(r'[wW](\d+)', formula)
seen = set()
result: List[int] = []
for r in raw:
v = int(r)
if v not in seen:
seen.add(v)
result.append(v)
return result
def _eval_formula_fast(self, formula: str,
band_data: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
"""快速公式求值(预处理后直接 eval
band_data: {波长int: 2D 数组}
formula 示例: "(w708 - w665) / (w708 + w665)"
"""
# 预处理w708 → _B708避免与 Python 关键字冲突)
processed = re.sub(r'[wW](\d+)', r'_B\1', formula)
# 构建局部变量表_B708 = band_data[708]
local_vars = {f"_B{wv}": arr for wv, arr in band_data.items()}
local_vars.update(self._SAFE_NS)
try:
result = eval(processed, {"__builtins__": {}}, local_vars)
return result
except Exception as e:
print(f" ⚠ 公式求值失败 [{formula}]: {e}")
return None
# ------------------------------------------------------------------
# 单波段读取(带 NoData 处理)
# ------------------------------------------------------------------
@staticmethod
def _read_band_as_float(bsq_path: str, band_idx: int) -> np.ndarray:
"""读取 BSQ 指定波段1-based返回 float64NaN 替换 NoData。"""
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
if ds is None:
raise RuntimeError(f"无法用 GDAL 打开影像: {bsq_path}")
band = ds.GetRasterBand(band_idx)
arr = band.ReadAsArray()
nodata = band.GetNoDataValue()
ds = None
arr = arr.astype(np.float64)
if nodata is not None:
arr = np.where(arr == nodata, np.nan, arr)
return arr
# ------------------------------------------------------------------
# 核心处理:逐公式矩阵运算 + GeoTIFF 输出
# ------------------------------------------------------------------
def process_bsq(
self,
bsq_path: str,
hdr_path: Optional[str] = None,
output_dir: Optional[str] = None,
formula_names: Optional[List[str]] = None,
water_mask: Optional[np.ndarray] = None,
nodata_value: float = -9999.0,
progress_callback: Optional[Callable[[str, float], None]] = None,
) -> Dict[str, str]:
"""逐公式处理 BSQ 影像,输出 GeoTIFF。
Parameters
----------
bsq_path : str
去耀斑 BSQ 影像路径
hdr_path : str, optional
ENVI HDR 文件路径None → 自动构造)
output_dir : str, optional
输出目录None → 与 bsq_path 同目录下的 8_WaterIndex_Images/
formula_names : list, optional
要处理的公式名列表None → 处理全部)
water_mask : np.ndarray, optional
水域掩膜数组(与 BSQ 同形状),掩膜值为 0 表示陆地,
将被强制赋值为 nodata_value
nodata_value : float
NoData 标记值
progress_callback : callable, optional
回调 (msg: str, pct: float)
Returns
-------
dict
{公式名: 输出 GeoTIFF 路径}
"""
# ── 自动构造 HDR 路径 ────────────────────────────────────────────
if hdr_path is None:
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
if not os.path.isfile(hdr_path):
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
if os.path.isfile(hdr_path_alt):
hdr_path = hdr_path_alt
# ── 自动构造输出目录 ────────────────────────────────────────────
if output_dir is None:
output_dir = os.path.join(os.path.dirname(bsq_path), '8_WaterIndex_Images')
os.makedirs(output_dir, exist_ok=True)
def progress(msg: str, pct: float):
if progress_callback:
progress_callback(msg, pct)
# ── 获取影像元数据 ───────────────────────────────────────────────
progress("正在打开影像并读取元数据…", 2)
meta = self.get_image_metadata(bsq_path, hdr_path)
width = meta.get('width', 0)
height = meta.get('height', 0)
n_bands = meta.get('bands', 0)
wavelengths = meta.get('wavelengths', [])
geotransform = meta.get('geotransform')
projection = meta.get('projection')
if n_bands == 0 or width == 0 or height == 0:
raise ValueError(f"影像元数据无效,无法处理: {bsq_path}")
if not wavelengths:
raise ValueError(f"无法从 {hdr_path} 读取波长信息,公式无法解析")
progress(
f"影像: {width}×{height}像素, {n_bands}波段, "
f"波长 {wavelengths[0]:.1f}{wavelengths[-1]:.1f}nm",
5
)
# ── 过滤要处理的公式 ──────────────────────────────────────────────
if formula_names:
formulas_to_run = [
f for f in self.formulas
if f.get('Formula_Name', '').strip() in formula_names
]
else:
formulas_to_run = list(self.formulas)
results: Dict[str, str] = {}
total = len(formulas_to_run)
# ── 逐公式处理 ───────────────────────────────────────────────────
for i, formula_row in enumerate(formulas_to_run):
fname = formula_row.get('Formula_Name', '').strip()
fstr = formula_row.get('Formula', '').strip()
category = formula_row.get('Category', '').strip()
ftype = formula_row.get('Formula_Type', '').strip()
if not fname or not fstr:
continue
progress(
f"[{i + 1}/{total}] {fname} ({category})",
5 + 90 * i / total
)
try:
# 1) 提取公式所需的波长列表
required_wvs = self._parse_formula_wavelengths(fstr)
# 2) 按需读取波段数据(相同波长只读一次)
band_data: Dict[int, np.ndarray] = {}
for wv in required_wvs:
if wv not in band_data:
band_idx = self._find_nearest_band_index(wv, wavelengths)
if not (0 < band_idx <= n_bands):
print(f" ⚠ 公式 '{fname}' 引用波段 {band_idx},超出范围 ({n_bands}),跳过")
raise ValueError(f"波段 {band_idx} 超出影像范围")
band_data[wv] = self._read_band_as_float(bsq_path, band_idx)
# 3) 矩阵运算
index_arr = self._eval_formula_fast(fstr, band_data)
if index_arr is None:
print(f" ⚠ 公式 '{fname}' 计算失败,跳过")
continue
# 4) NoData 处理NaN / Inf → nodata_value
index_arr = np.where(np.isfinite(index_arr), index_arr, nodata_value)
# 4b) 水域掩膜拦截陆地像素mask==0强制赋 NoData
if water_mask is not None:
land_pixels = (water_mask == 0)
land_count = int(land_pixels.sum())
if land_count > 0:
index_arr = np.where(land_pixels, nodata_value, index_arr)
print(f" 🗺 掩膜处理:陆地像素 {land_count:,} 个已设为 NoData")
# 5) 输出 GeoTIFF
safe_fname = re.sub(r'[^\w\u4e00-\u9fff-]', '_', fname)
out_tif = os.path.join(output_dir, f"{safe_fname}.tif")
self._write_geotiff(
out_path=out_tif,
data=index_arr,
reference_bsq=bsq_path,
nodata_value=nodata_value,
description=f"{fname}|{category}|{ftype}|{fstr}",
)
results[fname] = out_tif
valid = index_arr[index_arr != nodata_value]
mean_val = float(np.mean(valid)) if valid.size else np.nan
print(f"{fname}{out_tif} (mean={mean_val:.4f})")
except ValueError as ve:
print(f" ⏭ 跳过 '{fname}': {ve}")
continue
except Exception as e:
print(f" ❌ 公式 '{fname}' 失败: {e}\n{traceback.format_exc()}")
continue
progress(f"完成!共输出 {len(results)} / {total} 个指数图", 100)
return results
def _write_geotiff(
self,
out_path: str,
data: np.ndarray,
reference_bsq: str,
nodata_value: float = -9999.0,
description: str = "",
) -> None:
"""将数组写入 GeoTIFF克隆原始 BSQ 的地理信息。
Parameters
----------
out_path : str
输出 GeoTIFF 路径
data : np.ndarray
2D 数据数组height, width
reference_bsq : str
参考 BSQ 影像路径(用于克隆 GeoTransform / Projection
nodata_value : float
NoData 标记值
description : str
GDAL 数据集描述
"""
height, width = data.shape
driver = gdal.GetDriverByName('GTiff')
if driver is None:
raise RuntimeError("GDAL GTiff 驱动不可用")
out_ds = driver.Create(
out_path,
width, height,
1,
gdal.GDT_Float32,
options=['COMPRESS=LZW', 'TILED=YES', 'BIGTIFF=IF_SAFER'],
)
if out_ds is None:
raise RuntimeError(f"无法创建 GeoTIFF: {out_path}")
# 写入数据
out_band = out_ds.GetRasterBand(1)
out_band.SetNoDataValue(nodata_value)
out_band.WriteArray(data)
out_band.FlushCache()
# 写入描述
if description:
out_band.SetDescription(description)
# ★★★ 克隆原始 BSQ 的 GeoTransform 和 Projection ★★★
ref_ds = gdal.Open(reference_bsq, gdal.GA_ReadOnly)
if ref_ds is not None:
gt = ref_ds.GetGeoTransform()
proj = ref_ds.GetProjection()
if gt and gt != (0, 1, 0, 0, 0, 1):
out_ds.SetGeoTransform(gt)
if proj:
out_ds.SetProjection(proj)
ref_ds = None
out_ds = None
# ------------------------------------------------------------------
# Pipeline 入口(供 PipelineRunner 调用)
# ------------------------------------------------------------------
def run_inversion(
self,
deglint_img_path: str,
work_dir: str,
formula_csv_path: Optional[str] = None,
selected_formulas: Optional[List[str]] = None,
water_mask_path: Optional[str] = None,
nodata_value: float = -9999.0,
callback: Optional[Callable] = None,
**kwargs,
) -> Dict[str, str]:
"""Pipeline 入口方法。
Parameters
----------
deglint_img_path : str
去耀斑影像 BSQ 路径
work_dir : str
工作目录
formula_csv_path : str, optional
waterindex.csv 路径None → 使用初始化时的路径)
selected_formulas : list, optional
要处理的公式列表
water_mask_path : str, optional
水域掩膜路径(如 1_water_mask/water_mask.dat
掩膜中为 0 的像素视为陆地区域,其指数值将被强制设为 NoData。
nodata_value : float
NoData 标记值,默认 -9999.0
callback : callable, optional
进度回调
Returns
-------
dict
{公式名: 输出 GeoTIFF 路径}
"""
# 重新加载公式(如指定了新路径)
if formula_csv_path:
self.reload(formula_csv_path)
elif not self.formulas:
raise RuntimeError("WaterIndexProcessor 未加载公式,请指定 formula_csv_path")
def notify(msg: str, pct: float):
if callback:
callback(msg, pct)
notify("开始水色指数反演", 0)
bsq_path = deglint_img_path
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
if not os.path.isfile(hdr_path):
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
if os.path.isfile(hdr_path_alt):
hdr_path = hdr_path_alt
output_dir = os.path.join(work_dir, "8_WaterIndex_Images")
# ── 加载水域掩膜(可选)───────────────────────────────────────
water_mask: Optional[np.ndarray] = None
if water_mask_path:
if os.path.isfile(water_mask_path):
try:
import rasterio
with rasterio.open(water_mask_path) as msrc:
water_mask = msrc.read(1)
print(f"[run_inversion] 水域掩膜已加载: {water_mask_path}"
f"形状={water_mask.shape}"
f"陆地区域(0)={int((water_mask == 0).sum())}"
f"水区域(>0)={int((water_mask > 0).sum())}")
except Exception as mask_err:
print(f"[run_inversion] ⚠ 掩膜加载失败,跳过掩膜处理: {mask_err}")
water_mask = None
else:
print(f"[run_inversion] ⚠ 水域掩膜文件不存在: {water_mask_path},跳过掩膜处理")
notify("水色指数处理中…", 20)
results = self.process_bsq(
bsq_path=bsq_path,
hdr_path=hdr_path,
output_dir=output_dir,
formula_names=selected_formulas,
water_mask=water_mask,
nodata_value=nodata_value,
progress_callback=lambda m, p: notify(m, 20 + 70 * p / 100),
)
notify("水色指数反演完成", 100)
return results