feat(step8): 新增水色指数反演模块 waterindex_inversion + CSV 公式驱动架构
This commit is contained in:
22
src/core/algorithms/waterindex_inversion.py
Normal file
22
src/core/algorithms/waterindex_inversion.py
Normal file
@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
水色指数反演模块(包入口)
|
||||
|
||||
从 waterindex.csv 读取公式,对去耀斑 BSQ 高光谱影像进行全图矩阵运算,
|
||||
输出带完整坐标信息的 GeoTIFF。
|
||||
|
||||
公式格式(waterindex.csv):
|
||||
- 波长占位符:w{nm},如 w686, w708, w665
|
||||
- 支持混合大小写:w686 / W665 均可
|
||||
- 示例:NDCI = (w708 - w665) / (w708 + w665)
|
||||
|
||||
输出:
|
||||
- GeoTIFF (Float32),LZW 压缩,带 Tile
|
||||
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
|
||||
- Step 14 可直接用 rasterio 读取数组和空间范围
|
||||
"""
|
||||
|
||||
# 重新导出 WaterIndexProcessor(向后兼容所有已有 import)
|
||||
from src.core.algorithms.waterindex_inversion import WaterIndexProcessor
|
||||
|
||||
__all__ = ['WaterIndexProcessor']
|
||||
646
src/core/algorithms/waterindex_inversion/__init__.py
Normal file
646
src/core/algorithms/waterindex_inversion/__init__.py
Normal file
@ -0,0 +1,646 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
水色指数反演模块
|
||||
|
||||
直接读取去耀斑高光谱 BSQ 影像,应用 waterindex.csv 中的公式,
|
||||
输出各水质参数指数的 GeoTIFF 栅格图像。
|
||||
|
||||
公式格式(waterindex.csv):
|
||||
- 波长占位符:w{nm},如 w686, w708, w665
|
||||
- 支持混合大小写:w686 / W665 均可
|
||||
- 示例:NDCI = (w708 - w665) / (w708 + w665)
|
||||
BGA_Am09KBBI = (w686 - w658) / (w686 + w658)
|
||||
|
||||
输出:
|
||||
- GeoTIFF (Float32),LZW 压缩,带 Tile
|
||||
- 完整克隆原始 BSQ 的 GeoTransform / Projection / NoData
|
||||
- Step 14 可直接用 rasterio 读取进行克里金插值
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from osgeo import gdal, osr
|
||||
|
||||
# GDAL 驱动注册
|
||||
gdal.UseExceptions()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公共工具
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_resource_path(relative_path: str) -> str:
|
||||
"""获取 waterindex.csv 等资源的绝对路径,兼容 PyInstaller 打包。"""
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
base = sys._MEIPASS
|
||||
else:
|
||||
base = os.path.abspath(
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), '..', '..', '..')
|
||||
)
|
||||
return os.path.join(base, relative_path)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WaterIndexProcessor
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class WaterIndexProcessor:
|
||||
"""
|
||||
水色指数处理器
|
||||
|
||||
读取 waterindex.csv 中的公式,应用于 BSQ 高光谱影像,
|
||||
输出带完整坐标信息的 GeoTIFF 指数图。
|
||||
|
||||
核心能力:
|
||||
- 公式解析:w{nm} 占位符 → 实际波段 2D numpy 数组
|
||||
- 矩阵运算:全影像批量计算,无需逐点循环
|
||||
- 地理信息保持:克隆原始 BSQ 的 GeoTransform / Projection
|
||||
- NoData 处理:运算中产生的 NaN/Inf 统一标记为 -9999
|
||||
"""
|
||||
|
||||
# 内置安全命名空间(公式 eval 白名单)
|
||||
_SAFE_NS: Dict[str, Any] = {
|
||||
'np': np,
|
||||
'nan': np.nan,
|
||||
'inf': np.inf,
|
||||
'pi': np.pi,
|
||||
'e': np.e,
|
||||
}
|
||||
|
||||
def __init__(self, waterindex_csv_path: Optional[str] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
waterindex_csv_path : str, optional
|
||||
waterindex.csv 路径。
|
||||
若为 None,尝试从默认位置加载:
|
||||
1. src/gui/model/waterindex.csv(开发环境)
|
||||
2. _MEIPASS/src/gui/model/waterindex.csv(打包环境)
|
||||
"""
|
||||
self.csv_path: Optional[str] = None
|
||||
self.formulas: List[Dict[str, Any]] = []
|
||||
|
||||
if waterindex_csv_path:
|
||||
self.csv_path = waterindex_csv_path
|
||||
else:
|
||||
candidates = [
|
||||
os.path.join(os.path.dirname(__file__), '..', '..', 'gui', 'model', 'waterindex.csv'),
|
||||
os.path.join(os.path.dirname(__file__), '..', '..', '..', 'gui', 'model', 'waterindex.csv'),
|
||||
]
|
||||
for p in candidates:
|
||||
if os.path.isfile(p):
|
||||
self.csv_path = p
|
||||
break
|
||||
|
||||
if self.csv_path:
|
||||
self._parse_csv()
|
||||
else:
|
||||
self.formulas = []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _parse_csv(self) -> None:
|
||||
"""解析 waterindex.csv,加载所有公式。"""
|
||||
if not os.path.isfile(self.csv_path):
|
||||
raise FileNotFoundError(f"公式配置文件不存在: {self.csv_path}")
|
||||
|
||||
# ★★★ 防止多次调用时公式翻倍叠加 ★★★
|
||||
self.formulas.clear()
|
||||
|
||||
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
self.formulas.append(dict(row))
|
||||
|
||||
print(f"[WaterIndexProcessor] 加载 {len(self.formulas)} 条公式 ← {self.csv_path}")
|
||||
|
||||
def reload(self, waterindex_csv_path: str) -> None:
|
||||
"""重新加载公式配置文件。"""
|
||||
self.csv_path = waterindex_csv_path
|
||||
self._parse_csv()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式查询
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_formulas(self) -> List[Dict[str, Any]]:
|
||||
"""返回所有公式的列表。"""
|
||||
return list(self.formulas)
|
||||
|
||||
def list_formula_names(self) -> List[str]:
|
||||
"""返回所有公式名称列表。"""
|
||||
return [f.get('Formula_Name', '') for f in self.formulas]
|
||||
|
||||
def get_formula(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""按名称查找单个公式。"""
|
||||
for f in self.formulas:
|
||||
if f.get('Formula_Name', '').strip() == name.strip():
|
||||
return f
|
||||
return None
|
||||
|
||||
def list_categories(self) -> List[str]:
|
||||
"""返回所有公式类别(去重排序)。"""
|
||||
cats = set()
|
||||
for f in self.formulas:
|
||||
c = f.get('Category', '').strip()
|
||||
if c:
|
||||
cats.add(c)
|
||||
return sorted(cats)
|
||||
|
||||
def get_formulas_by_category(self, category: str) -> List[Dict[str, Any]]:
|
||||
"""按类别筛选公式。"""
|
||||
return [f for f in self.formulas
|
||||
if f.get('Category', '').strip().lower() == category.strip().lower()]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 影像元数据
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_image_metadata(self, bsq_path: str, hdr_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""获取影像元数据(GDAL + ENVI HDR 双重保障)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bsq_path : str
|
||||
BSQ 影像路径
|
||||
hdr_path : str, optional
|
||||
ENVI HDR 路径(None → 自动构造)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
含 keys: width, height, bands, wavelengths, wavelength_range,
|
||||
geotransform, projection, driver
|
||||
"""
|
||||
meta: Dict[str, Any] = {}
|
||||
|
||||
# 1. GDAL 优先(获取空间信息)
|
||||
try:
|
||||
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
|
||||
if ds is not None:
|
||||
meta['width'] = ds.RasterXSize
|
||||
meta['height'] = ds.RasterYSize
|
||||
meta['bands'] = ds.RasterCount
|
||||
meta['driver'] = ds.GetDriver().ShortName
|
||||
gt = ds.GetGeoTransform()
|
||||
proj = ds.GetProjection()
|
||||
if gt and gt != (0, 1, 0, 0, 0, 1):
|
||||
meta['geotransform'] = gt
|
||||
if proj:
|
||||
meta['projection'] = proj
|
||||
ds = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. HDR 补充波长信息
|
||||
if hdr_path is None:
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
if os.path.isfile(hdr_path):
|
||||
wl = self._parse_wavelengths_from_hdr(hdr_path)
|
||||
if wl:
|
||||
meta['wavelengths'] = wl
|
||||
if len(wl) >= 2:
|
||||
meta['wavelength_range'] = f"{wl[0]:.1f}–{wl[-1]:.1f} nm ({len(wl)} 波段)"
|
||||
elif meta.get('bands', 0) > 0:
|
||||
meta['wavelength_range'] = f"{meta['bands']} 波段(波长信息缺失)"
|
||||
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _parse_wavelengths_from_hdr(hdr_path: str) -> Optional[List[float]]:
|
||||
"""从 ENVI .hdr 文件中解析波长列表。"""
|
||||
try:
|
||||
with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
content = f.read()
|
||||
|
||||
# 格式1:wavelength = { 400, 401, ... }
|
||||
m = re.search(r'wavelength\s*=\s*\{([^}]+)\}', content, re.DOTALL)
|
||||
if m:
|
||||
vals = [float(v) for v in re.findall(r'[\d.]+', m.group(1)) if v.strip()]
|
||||
if vals:
|
||||
return vals
|
||||
|
||||
# 格式2:逐行罗列
|
||||
wavelengths: List[float] = []
|
||||
in_wl = False
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('wavelength'):
|
||||
in_wl = True
|
||||
continue
|
||||
if in_wl:
|
||||
if line.startswith('{'):
|
||||
continue
|
||||
try:
|
||||
wavelengths.append(float(line))
|
||||
except ValueError:
|
||||
if '}' in line:
|
||||
in_wl = False
|
||||
return wavelengths if wavelengths else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公式解析:w{nm} 占位符 → 实际波段数据
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_nearest_band_index(self, target_wv: float,
|
||||
wavelengths: List[float]) -> int:
|
||||
"""找到最接近目标波长的 GDAL 波段索引(1-based)。"""
|
||||
if not wavelengths:
|
||||
raise ValueError("波长列表为空,无法匹配波段")
|
||||
nearest = min(range(len(wavelengths)),
|
||||
key=lambda i: abs(wavelengths[i] - target_wv))
|
||||
return nearest + 1 # GDAL 波段从 1 开始
|
||||
|
||||
def _parse_formula_wavelengths(self, formula: str) -> List[int]:
|
||||
"""从公式字符串中提取所有波长值(去重,int)。"""
|
||||
raw = re.findall(r'[wW](\d+)', formula)
|
||||
seen = set()
|
||||
result: List[int] = []
|
||||
for r in raw:
|
||||
v = int(r)
|
||||
if v not in seen:
|
||||
seen.add(v)
|
||||
result.append(v)
|
||||
return result
|
||||
|
||||
def _eval_formula_fast(self, formula: str,
|
||||
band_data: Dict[int, np.ndarray]) -> Optional[np.ndarray]:
|
||||
"""快速公式求值(预处理后直接 eval)。
|
||||
|
||||
band_data: {波长int: 2D 数组}
|
||||
formula 示例: "(w708 - w665) / (w708 + w665)"
|
||||
"""
|
||||
# 预处理:w708 → _B708(避免与 Python 关键字冲突)
|
||||
processed = re.sub(r'[wW](\d+)', r'_B\1', formula)
|
||||
|
||||
# 构建局部变量表:_B708 = band_data[708]
|
||||
local_vars = {f"_B{wv}": arr for wv, arr in band_data.items()}
|
||||
local_vars.update(self._SAFE_NS)
|
||||
|
||||
try:
|
||||
result = eval(processed, {"__builtins__": {}}, local_vars)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f" ⚠ 公式求值失败 [{formula}]: {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 单波段读取(带 NoData 处理)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _read_band_as_float(bsq_path: str, band_idx: int) -> np.ndarray:
|
||||
"""读取 BSQ 指定波段(1-based),返回 float64,NaN 替换 NoData。"""
|
||||
ds = gdal.Open(bsq_path, gdal.GA_ReadOnly)
|
||||
if ds is None:
|
||||
raise RuntimeError(f"无法用 GDAL 打开影像: {bsq_path}")
|
||||
|
||||
band = ds.GetRasterBand(band_idx)
|
||||
arr = band.ReadAsArray()
|
||||
nodata = band.GetNoDataValue()
|
||||
ds = None
|
||||
|
||||
arr = arr.astype(np.float64)
|
||||
if nodata is not None:
|
||||
arr = np.where(arr == nodata, np.nan, arr)
|
||||
|
||||
return arr
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 核心处理:逐公式矩阵运算 + GeoTIFF 输出
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def process_bsq(
|
||||
self,
|
||||
bsq_path: str,
|
||||
hdr_path: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
water_mask: Optional[np.ndarray] = None,
|
||||
nodata_value: float = -9999.0,
|
||||
progress_callback: Optional[Callable[[str, float], None]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""逐公式处理 BSQ 影像,输出 GeoTIFF。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bsq_path : str
|
||||
去耀斑 BSQ 影像路径
|
||||
hdr_path : str, optional
|
||||
ENVI HDR 文件路径(None → 自动构造)
|
||||
output_dir : str, optional
|
||||
输出目录(None → 与 bsq_path 同目录下的 8_WaterIndex_Images/)
|
||||
formula_names : list, optional
|
||||
要处理的公式名列表(None → 处理全部)
|
||||
water_mask : np.ndarray, optional
|
||||
水域掩膜数组(与 BSQ 同形状),掩膜值为 0 表示陆地,
|
||||
将被强制赋值为 nodata_value
|
||||
nodata_value : float
|
||||
NoData 标记值
|
||||
progress_callback : callable, optional
|
||||
回调 (msg: str, pct: float)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{公式名: 输出 GeoTIFF 路径}
|
||||
"""
|
||||
# ── 自动构造 HDR 路径 ────────────────────────────────────────────
|
||||
if hdr_path is None:
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
# ── 自动构造输出目录 ────────────────────────────────────────────
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(os.path.dirname(bsq_path), '8_WaterIndex_Images')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def progress(msg: str, pct: float):
|
||||
if progress_callback:
|
||||
progress_callback(msg, pct)
|
||||
|
||||
# ── 获取影像元数据 ───────────────────────────────────────────────
|
||||
progress("正在打开影像并读取元数据…", 2)
|
||||
meta = self.get_image_metadata(bsq_path, hdr_path)
|
||||
|
||||
width = meta.get('width', 0)
|
||||
height = meta.get('height', 0)
|
||||
n_bands = meta.get('bands', 0)
|
||||
wavelengths = meta.get('wavelengths', [])
|
||||
geotransform = meta.get('geotransform')
|
||||
projection = meta.get('projection')
|
||||
|
||||
if n_bands == 0 or width == 0 or height == 0:
|
||||
raise ValueError(f"影像元数据无效,无法处理: {bsq_path}")
|
||||
|
||||
if not wavelengths:
|
||||
raise ValueError(f"无法从 {hdr_path} 读取波长信息,公式无法解析")
|
||||
|
||||
progress(
|
||||
f"影像: {width}×{height}像素, {n_bands}波段, "
|
||||
f"波长 {wavelengths[0]:.1f}–{wavelengths[-1]:.1f}nm",
|
||||
5
|
||||
)
|
||||
|
||||
# ── 过滤要处理的公式 ──────────────────────────────────────────────
|
||||
if formula_names:
|
||||
formulas_to_run = [
|
||||
f for f in self.formulas
|
||||
if f.get('Formula_Name', '').strip() in formula_names
|
||||
]
|
||||
else:
|
||||
formulas_to_run = list(self.formulas)
|
||||
|
||||
results: Dict[str, str] = {}
|
||||
total = len(formulas_to_run)
|
||||
|
||||
# ── 逐公式处理 ───────────────────────────────────────────────────
|
||||
for i, formula_row in enumerate(formulas_to_run):
|
||||
fname = formula_row.get('Formula_Name', '').strip()
|
||||
fstr = formula_row.get('Formula', '').strip()
|
||||
category = formula_row.get('Category', '').strip()
|
||||
ftype = formula_row.get('Formula_Type', '').strip()
|
||||
|
||||
if not fname or not fstr:
|
||||
continue
|
||||
|
||||
progress(
|
||||
f"[{i + 1}/{total}] {fname} ({category})",
|
||||
5 + 90 * i / total
|
||||
)
|
||||
|
||||
try:
|
||||
# 1) 提取公式所需的波长列表
|
||||
required_wvs = self._parse_formula_wavelengths(fstr)
|
||||
|
||||
# 2) 按需读取波段数据(相同波长只读一次)
|
||||
band_data: Dict[int, np.ndarray] = {}
|
||||
for wv in required_wvs:
|
||||
if wv not in band_data:
|
||||
band_idx = self._find_nearest_band_index(wv, wavelengths)
|
||||
if not (0 < band_idx <= n_bands):
|
||||
print(f" ⚠ 公式 '{fname}' 引用波段 {band_idx},超出范围 ({n_bands}),跳过")
|
||||
raise ValueError(f"波段 {band_idx} 超出影像范围")
|
||||
band_data[wv] = self._read_band_as_float(bsq_path, band_idx)
|
||||
|
||||
# 3) 矩阵运算
|
||||
index_arr = self._eval_formula_fast(fstr, band_data)
|
||||
if index_arr is None:
|
||||
print(f" ⚠ 公式 '{fname}' 计算失败,跳过")
|
||||
continue
|
||||
|
||||
# 4) NoData 处理:NaN / Inf → nodata_value
|
||||
index_arr = np.where(np.isfinite(index_arr), index_arr, nodata_value)
|
||||
|
||||
# 4b) 水域掩膜拦截:陆地像素(mask==0)强制赋 NoData
|
||||
if water_mask is not None:
|
||||
land_pixels = (water_mask == 0)
|
||||
land_count = int(land_pixels.sum())
|
||||
if land_count > 0:
|
||||
index_arr = np.where(land_pixels, nodata_value, index_arr)
|
||||
print(f" 🗺 掩膜处理:陆地像素 {land_count:,} 个已设为 NoData")
|
||||
|
||||
# 5) 输出 GeoTIFF
|
||||
safe_fname = re.sub(r'[^\w\u4e00-\u9fff-]', '_', fname)
|
||||
out_tif = os.path.join(output_dir, f"{safe_fname}.tif")
|
||||
|
||||
self._write_geotiff(
|
||||
out_path=out_tif,
|
||||
data=index_arr,
|
||||
reference_bsq=bsq_path,
|
||||
nodata_value=nodata_value,
|
||||
description=f"{fname}|{category}|{ftype}|{fstr}",
|
||||
)
|
||||
|
||||
results[fname] = out_tif
|
||||
valid = index_arr[index_arr != nodata_value]
|
||||
mean_val = float(np.mean(valid)) if valid.size else np.nan
|
||||
print(f" ✅ {fname} → {out_tif} (mean={mean_val:.4f})")
|
||||
|
||||
except ValueError as ve:
|
||||
print(f" ⏭ 跳过 '{fname}': {ve}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f" ❌ 公式 '{fname}' 失败: {e}\n{traceback.format_exc()}")
|
||||
continue
|
||||
|
||||
progress(f"完成!共输出 {len(results)} / {total} 个指数图", 100)
|
||||
return results
|
||||
|
||||
def _write_geotiff(
|
||||
self,
|
||||
out_path: str,
|
||||
data: np.ndarray,
|
||||
reference_bsq: str,
|
||||
nodata_value: float = -9999.0,
|
||||
description: str = "",
|
||||
) -> None:
|
||||
"""将数组写入 GeoTIFF,克隆原始 BSQ 的地理信息。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
out_path : str
|
||||
输出 GeoTIFF 路径
|
||||
data : np.ndarray
|
||||
2D 数据数组(height, width)
|
||||
reference_bsq : str
|
||||
参考 BSQ 影像路径(用于克隆 GeoTransform / Projection)
|
||||
nodata_value : float
|
||||
NoData 标记值
|
||||
description : str
|
||||
GDAL 数据集描述
|
||||
"""
|
||||
height, width = data.shape
|
||||
|
||||
driver = gdal.GetDriverByName('GTiff')
|
||||
if driver is None:
|
||||
raise RuntimeError("GDAL GTiff 驱动不可用")
|
||||
|
||||
out_ds = driver.Create(
|
||||
out_path,
|
||||
width, height,
|
||||
1,
|
||||
gdal.GDT_Float32,
|
||||
options=['COMPRESS=LZW', 'TILED=YES', 'BIGTIFF=IF_SAFER'],
|
||||
)
|
||||
if out_ds is None:
|
||||
raise RuntimeError(f"无法创建 GeoTIFF: {out_path}")
|
||||
|
||||
# 写入数据
|
||||
out_band = out_ds.GetRasterBand(1)
|
||||
out_band.SetNoDataValue(nodata_value)
|
||||
out_band.WriteArray(data)
|
||||
out_band.FlushCache()
|
||||
|
||||
# 写入描述
|
||||
if description:
|
||||
out_band.SetDescription(description)
|
||||
|
||||
# ★★★ 克隆原始 BSQ 的 GeoTransform 和 Projection ★★★
|
||||
ref_ds = gdal.Open(reference_bsq, gdal.GA_ReadOnly)
|
||||
if ref_ds is not None:
|
||||
gt = ref_ds.GetGeoTransform()
|
||||
proj = ref_ds.GetProjection()
|
||||
if gt and gt != (0, 1, 0, 0, 0, 1):
|
||||
out_ds.SetGeoTransform(gt)
|
||||
if proj:
|
||||
out_ds.SetProjection(proj)
|
||||
ref_ds = None
|
||||
|
||||
out_ds = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline 入口(供 PipelineRunner 调用)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def run_inversion(
|
||||
self,
|
||||
deglint_img_path: str,
|
||||
work_dir: str,
|
||||
formula_csv_path: Optional[str] = None,
|
||||
selected_formulas: Optional[List[str]] = None,
|
||||
water_mask_path: Optional[str] = None,
|
||||
nodata_value: float = -9999.0,
|
||||
callback: Optional[Callable] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, str]:
|
||||
"""Pipeline 入口方法。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
deglint_img_path : str
|
||||
去耀斑影像 BSQ 路径
|
||||
work_dir : str
|
||||
工作目录
|
||||
formula_csv_path : str, optional
|
||||
waterindex.csv 路径(None → 使用初始化时的路径)
|
||||
selected_formulas : list, optional
|
||||
要处理的公式列表
|
||||
water_mask_path : str, optional
|
||||
水域掩膜路径(如 1_water_mask/water_mask.dat),
|
||||
掩膜中为 0 的像素视为陆地区域,其指数值将被强制设为 NoData。
|
||||
nodata_value : float
|
||||
NoData 标记值,默认 -9999.0
|
||||
callback : callable, optional
|
||||
进度回调
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
{公式名: 输出 GeoTIFF 路径}
|
||||
"""
|
||||
# 重新加载公式(如指定了新路径)
|
||||
if formula_csv_path:
|
||||
self.reload(formula_csv_path)
|
||||
elif not self.formulas:
|
||||
raise RuntimeError("WaterIndexProcessor 未加载公式,请指定 formula_csv_path")
|
||||
|
||||
def notify(msg: str, pct: float):
|
||||
if callback:
|
||||
callback(msg, pct)
|
||||
|
||||
notify("开始水色指数反演", 0)
|
||||
|
||||
bsq_path = deglint_img_path
|
||||
hdr_path = os.path.splitext(bsq_path)[0] + '.hdr'
|
||||
if not os.path.isfile(hdr_path):
|
||||
hdr_path_alt = os.path.splitext(bsq_path)[0] + '.HDR'
|
||||
if os.path.isfile(hdr_path_alt):
|
||||
hdr_path = hdr_path_alt
|
||||
|
||||
output_dir = os.path.join(work_dir, "8_WaterIndex_Images")
|
||||
|
||||
# ── 加载水域掩膜(可选)───────────────────────────────────────
|
||||
water_mask: Optional[np.ndarray] = None
|
||||
if water_mask_path:
|
||||
if os.path.isfile(water_mask_path):
|
||||
try:
|
||||
import rasterio
|
||||
with rasterio.open(water_mask_path) as msrc:
|
||||
water_mask = msrc.read(1)
|
||||
print(f"[run_inversion] 水域掩膜已加载: {water_mask_path},"
|
||||
f"形状={water_mask.shape},"
|
||||
f"陆地区域(0)={int((water_mask == 0).sum())},"
|
||||
f"水区域(>0)={int((water_mask > 0).sum())}")
|
||||
except Exception as mask_err:
|
||||
print(f"[run_inversion] ⚠ 掩膜加载失败,跳过掩膜处理: {mask_err}")
|
||||
water_mask = None
|
||||
else:
|
||||
print(f"[run_inversion] ⚠ 水域掩膜文件不存在: {water_mask_path},跳过掩膜处理")
|
||||
|
||||
notify("水色指数处理中…", 20)
|
||||
|
||||
results = self.process_bsq(
|
||||
bsq_path=bsq_path,
|
||||
hdr_path=hdr_path,
|
||||
output_dir=output_dir,
|
||||
formula_names=selected_formulas,
|
||||
water_mask=water_mask,
|
||||
nodata_value=nodata_value,
|
||||
progress_callback=lambda m, p: notify(m, 20 + 70 * p / 100),
|
||||
)
|
||||
|
||||
notify("水色指数反演完成", 100)
|
||||
return results
|
||||
Reference in New Issue
Block a user