Files
BRDF/ocbrdf/aviris.py
2026-04-10 16:46:45 +08:00

978 lines
36 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gc
import os
import numpy as np
import sys
import xarray as xr
import rasterio
from rasterio.windows import Window
import spectral
from brdf_model_M02 import M02
from brdf_model_M02SeaDAS import M02SeaDAS
from brdf_model_L11 import L11
from brdf_model_O25 import O25
from brdf_utils import ADF_OCP, squeeze_trivial_dims
# Optional dependency: tqdm provides progress bars; fall back to None when it
# is not installed (callers check `tqdm is not None` before using it).
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

# Module-level cache mapping LUT path -> opened xarray.Dataset so the
# uncertainty NetCDF is opened only once per process (see brdf_uncertainty).
UNC_LUT_CACHE = {}

# NOTE(review): the string below sits after the imports, so it is NOT the
# module __doc__. It documents (in Chinese) the BRDF correction contract:
# required input fields Rw/sza/vza/raa on a "bands" dimension, optional
# Rw_unc, and outputs nrrs/rho_ex_w/omega_b/eta_b/C_brdf/brdf_unc/nrrs_unc,
# plus provenance of the EUMETSAT brdf_hypercp study it derives from.
"""
主 BRDF 校正模块
输入为 xarray 数据集
所需的光谱维度为 "bands",其他维度自由
输入数据集中必填的字段:
Rw: 方向性海洋反射率
sza: 太阳天顶角
vza: 观测天顶角
raa: 相对方位角当太阳和观测在同一侧时raa=0
输入数据集中可选的字段:
Rw_unc: Rw 的不确定度(若缺失,设为零)
输出数据集中的字段:
nrrs: 完全归一化的遥感反射率
rho_ex_w: nrrs * PI
omega_b: bb/(a+bb)
eta_b: bbw/bb
C_brdf: BRDF 校正因子
brdf_unc: C_brdf 的不确定度
nrrs_unc: nrrs 的不确定度
使用和适配 brdf_hypercp 模块需保留的信息:
Brdf_hypercp 是 EUMETSAT 研究项目的一部分
"BRDF correction of S3 OLCI water reflectance products"S3 OLCI 水体反射率产品的 BRDF 校正),
合同号RB_EUM-CO-21-4600002626-JIG。
研究团队成员Davide D'Alimonte (davide.dalimonte@aequora.org)
Tamito Kajiyama (tamito.kajiyama@aequora.org)
Jaime Pitarch (jaime.pitarchportero@artov.ismar.cnr.it)
Vittorio Brando (vittorio.brando@cnr.it)
Marco Talone (talone@icm.csic.es) 以及
Constant Mazeran (constant.mazeran@solvo.fr)。
BRDF 查找表中的相对方位角遵循 OLCI 约定。详见 https://www.eumetsat.int/media/50720图 6。
"""
# ---------------------------------------------------------------------------
# File I/O helpers
# ---------------------------------------------------------------------------
def _resolve_envi_header(filepath):
"""Return the ENVI header path accepted by spectral."""
if filepath.lower().endswith(".hdr"):
return filepath
hdr_path = filepath + ".hdr"
if os.path.exists(hdr_path):
return hdr_path
base, _ext = os.path.splitext(filepath)
hdr_path = base + ".hdr"
if os.path.exists(hdr_path):
return hdr_path
return filepath
def read_bsq(filepath, wavelengths=None, dtype=np.float32):
    """Read an ENVI BSQ hyperspectral image.

    Parameters
    ----------
    filepath : str
        Path to the ENVI header (.hdr) or image file.
    wavelengths : array-like, optional
        Band centre wavelengths in nm. When *None*, wavelengths are read
        from the ENVI header if available.
    dtype : numpy dtype, optional
        Dtype the cube and wavelengths are cast to (default float32).

    Returns
    -------
    data : ndarray
        Image cube in the file's native (source) interleave order; for a
        BSQ file this is (bands, rows, cols).
    wavelengths : ndarray, shape (bands,)
    img : spectral image object (metadata / transform access)
    """
    img = spectral.open_image(_resolve_envi_header(filepath))
    # interleave="source" keeps the on-disk axis order; for a BSQ file that
    # is already (bands, rows, cols), hence the transpose is commented out.
    data = np.asarray(img.open_memmap(interleave="source"), dtype=dtype)
    # data = np.transpose(data, (2, 0, 1)) # -> (bands, rows, cols)
    if wavelengths is None:
        if hasattr(img, 'bands') and img.bands.centers is not None:
            wavelengths = np.array(img.bands.centers, dtype=np.float32)
        else:
            # Fallback: number the bands 1..N.
            # NOTE(review): assumes axis 0 is the band axis, which only holds
            # for BSQ sources -- confirm if BIL/BIP inputs are ever passed.
            n_bands = data.shape[0]
            wavelengths = np.arange(1, n_bands + 1, dtype=np.float32)
    return data, np.asarray(wavelengths, dtype=dtype), img
def relative_azimuth_olci(saa_deg, vaa_deg):
    """Relative azimuth (RAA / delta_phi) in [0, 180] deg, OLCI convention.

    Same convention as the module docstring: sun and sensor on the same side -> RAA = 0.
    """
    sun_az = np.asarray(saa_deg, dtype=np.float64)
    view_az = np.asarray(vaa_deg, dtype=np.float64)
    # Wrap the absolute azimuth difference onto [0, 180].
    separation = np.abs(sun_az - view_az)
    wrapped = np.minimum(separation, 360.0 - separation)
    return wrapped.astype(np.float32)
def read_bip_angles(filepath, dtype=np.float32):
    """Read an ENVI BIP geometry file (10 bands) and build BRDF angles.

    Band layout (1-indexed), last dimension = band (BIP):
        1  path length
        2  to-sensor azimuth (VAA)
        3  to-sensor zenith (VZA)
        4  to-sun azimuth (SAA)
        5  to-sun zenith (SZA)
        6  phase angle
        7  slope
        8  aspect
        9  cosine i
        10 UTC time

    RAA is not stored; it is computed from SAA and VAA via :func:`relative_azimuth_olci`.

    Parameters
    ----------
    filepath : str
        Path to the ENVI BIP file (header or image file).

    Returns
    -------
    dict with keys 'sza', 'saa', 'vza', 'vaa', 'raa'.
    Each value is a 2-D ndarray of shape (rows, cols).

    Raises
    ------
    ValueError : the file does not have exactly 10 bands.
    """
    img = spectral.open_image(_resolve_envi_header(filepath))
    # interleave="source" keeps the on-disk (rows, cols, bands) order for BIP.
    data = np.asarray(img.open_memmap(interleave="source"), dtype=dtype)
    n_bands = data.shape[2]
    if n_bands != 10:
        raise ValueError(
            f"Geometry file must have exactly 10 bands (got {n_bands}): "
            "path length, VAA, VZA, SAA, SZA, phase, slope, aspect, cos i, UTC."
        )
    # .copy() detaches the 2-D slices from the (possibly memmapped) cube.
    vaa = data[:, :, 1].copy()
    vza = data[:, :, 2].copy()
    saa = data[:, :, 3].copy()
    sza = data[:, :, 4].copy()
    raa = relative_azimuth_olci(saa, vaa)
    return {
        'sza': sza,
        'saa': saa,
        'vza': vza,
        'vaa': vaa,
        'raa': raa,
    }
def read_water_mask(filepath):
    """Read a single-band GeoTIFF water mask.

    Pixels with value 1 indicate water; all other values are non-water.

    Parameters
    ----------
    filepath : str
        Path to the GeoTIFF file.

    Returns
    -------
    mask : ndarray, shape (rows, cols), int8
    """
    with rasterio.open(filepath) as dataset:
        return dataset.read(1).astype(np.int8)
def _block_windows(src, block_size=512):
    """Yield rasterio windows tiling *src* in square blocks of *block_size*.

    Edge windows are clipped to the raster extent; iteration order is
    row-major (top-to-bottom, left-to-right).
    """
    height, width = src.height, src.width
    for row_off in range(0, height, block_size):
        block_rows = min(block_size, height - row_off)
        for col_off in range(0, width, block_size):
            block_cols = min(block_size, width - col_off)
            yield Window(col_off, row_off, block_cols, block_rows)
def _resolve_output_envi_path(output_file):
"""Return the ENVI image path used for rasterio writes."""
if output_file.lower().endswith(".img"):
return output_file
if output_file.lower().endswith(".hdr"):
return output_file[:-4] + ".img"
return output_file + ".img"
def _read_wavelengths_from_header(hyperspectral_file):
    """Read wavelength metadata from the ENVI header if available.

    Returns a float32 ndarray of band-centre wavelengths, or None when the
    header carries no wavelength information.
    """
    img = spectral.open_image(_resolve_envi_header(hyperspectral_file))
    if hasattr(img, 'bands') and img.bands.centers is not None:
        wavelengths = img.bands.centers
    else:
        # Fall back to the raw 'wavelength' metadata entry (list of strings,
        # converted by the asarray cast below).
        wavelengths = img.metadata.get('wavelength')
    if wavelengths is None:
        return None
    return np.asarray(wavelengths, dtype=np.float32)
def save_brdf_result(
    ds_out,
    output_file,
    source_file=None,
    output_var="Rw_brdf",
    output_format="ENVI",
    dtype=np.float32,
):
    """Save one BRDF result variable to ENVI (.hdr/.img) or NetCDF (.nc).

    Parameters
    ----------
    ds_out : xarray.Dataset
        Result dataset containing *output_var* with dims (bands, y, x).
    output_file : str
        Output path; the appropriate extension is appended when missing.
    source_file : str, optional
        Recorded in the ENVI metadata for provenance.
    output_var : str
        Name of the variable of *ds_out* to write.
    output_format : str
        'ENVI' or 'NETCDF' (case-insensitive).

    Returns
    -------
    str : path of the written ENVI header or NetCDF file.

    Raises
    ------
    KeyError   : *output_var* is not present in *ds_out*.
    ValueError : unsupported *output_format*.
    """
    output_format = output_format.upper()
    if output_var not in ds_out:
        raise KeyError(f"Output variable '{output_var}' not found in result dataset.")
    if output_format == "NETCDF":
        if not output_file.lower().endswith(".nc"):
            output_file = output_file + ".nc"
        ds_out.to_netcdf(output_file)
        return output_file
    if output_format != "ENVI":
        raise ValueError("output_format must be 'ENVI' or 'NETCDF'.")
    # ENVI expects (lines, samples, bands); the dataset stores (bands, y, x).
    cube = np.asarray(ds_out[output_var].values, dtype=dtype)
    cube = np.transpose(cube, (1, 2, 0))
    if output_file.lower().endswith(".hdr"):
        hdr_path = output_file
    else:
        hdr_path = output_file + ".hdr"
    metadata = {
        "interleave": "bsq",
        "description": f"BRDF corrected result: {output_var}",
        "bands": cube.shape[2],
        "lines": cube.shape[0],
        "samples": cube.shape[1],
    }
    # Propagate wavelengths and provenance when available.
    if "bands" in ds_out.coords:
        metadata["wavelength"] = [str(v) for v in ds_out.coords["bands"].values.tolist()]
    if source_file is not None:
        metadata["source_file"] = source_file
    spectral.envi.save_image(
        hdr_path,
        cube,
        dtype=dtype,
        interleave="bsq",
        metadata=metadata,
        force=True,
    )
    return hdr_path
def _get_output_filename(base_output_file: str, var_name: str) -> str:
"""Generate output filename with variable suffix.
Examples:
input.hdr + Rw_brdf -> input_Rw_brdf.hdr
input.hdr + brdf_unc -> input_brdf_unc.hdr
"""
base, ext = os.path.splitext(base_output_file)
if ext.lower() in ('.hdr', '.img'):
base = base # keep base without extension
suffix = f"_{var_name}" if var_name.lower() != "rw_brdf" else ""
return f"{base}{suffix}{ext}"
def save_multiple_brdf_results(
    ds_out,
    output_file,
    source_file=None,
    output_vars=None,
    output_format="ENVI",
    dtype=np.float32,
):
    """Save each requested output variable to its own file.

    Variables missing from *ds_out* are reported and skipped. Returns a list
    of ``(variable_name, saved_path)`` tuples for the files actually written.
    """
    if not output_vars:
        output_vars = ["Rw_brdf"]
    saved_files = []
    for var_name in output_vars:
        target = _get_output_filename(output_file, var_name)
        try:
            path = save_brdf_result(
                ds_out=ds_out,
                output_file=target,
                source_file=source_file,
                output_var=var_name,
                output_format=output_format,
                dtype=dtype,
            )
            saved_files.append((var_name, path))
            print(f" Saved {var_name} to: {path}")
        except KeyError as err:
            print(f" Warning: {err}. Skipping {var_name}.")
    return saved_files
def run_brdf_correction_block(
    hyperspectral_file,
    angle_file,
    output_file,
    mask_file=None,
    wavelengths=None,
    brdf_model='L11',
    output_vars=None,  # a single variable name or a list of names
    block_size=512,
    compute_uncertainty=False,
    progress_bar=True,
):
    """
    Run BRDF correction block-wise and write multiple variables to separate ENVI files.

    ``angle_file`` must be a 10-band raster (same grid as hyperspectral); RAA is derived
    from bands 2 (VAA) and 4 (SAA); see :func:`read_bip_angles` for band order.

    If ``mask_file`` is None, every pixel is corrected (full scene); otherwise only pixels
    with mask value 1 are corrected (other pixels pass through unchanged for Rw_brdf).

    Returns the list of written ENVI image paths (one per requested variable).
    """
    if output_vars is None:
        output_vars = ["Rw_brdf"]
    if isinstance(output_vars, str):
        output_vars = [output_vars]
    # 1. Open all inputs
    # NOTE(review): datasets are not wrapped in try/finally, so an exception
    # mid-run leaves them open -- consider contextlib.ExitStack.
    src_rw = rasterio.open(hyperspectral_file)
    n_bands = src_rw.count
    n_rows = src_rw.height
    n_cols = src_rw.width
    src_ang = rasterio.open(angle_file)
    if src_ang.height != n_rows or src_ang.width != n_cols:
        raise ValueError("Angle file dimensions do not match hyperspectral file.")
    if src_ang.count != 10:
        raise ValueError(
            f"Geometry file must have 10 bands (got {src_ang.count}). "
            "Expected: path length, VAA, VZA, SAA, SZA, phase, slope, aspect, cos i, UTC."
        )
    use_mask = mask_file is not None
    src_mask = rasterio.open(mask_file) if use_mask else None
    if use_mask and (src_mask.height != n_rows or src_mask.width != n_cols):
        raise ValueError("Mask file dimensions do not match hyperspectral file.")
    # 2. Wavelengths: caller-supplied, else ENVI header, else 1..n_bands.
    if wavelengths is None:
        wavelengths = _read_wavelengths_from_header(hyperspectral_file)
    if wavelengths is None:
        wavelengths = np.arange(1, n_bands + 1, dtype=np.float32)
    bands_xr = xr.DataArray(wavelengths, dims=('bands',), coords={'bands': wavelengths}, name='bands')
    # 3. Prepare multiple output files (one per variable)
    out_meta = src_rw.meta.copy()
    out_meta.update({
        'driver': 'ENVI',
        'dtype': 'float32',
        'count': n_bands,
        'interleave': 'bsq',
    })
    # Open one output file for each variable
    output_dsts = {}
    output_paths = {}
    for var_name in output_vars:
        var_file = _get_output_filename(output_file, var_name)
        out_path = _resolve_output_envi_path(var_file)
        output_dsts[var_name] = rasterio.open(out_path, 'w', **out_meta)
        output_paths[var_name] = out_path
        print(f" Writing {var_name} to {out_path}")
    # 4. Band indices in geometry file (1-based); RAA computed from SAA and VAA
    band_vaa = 2
    band_vza = 3
    band_saa = 4
    band_sza = 5
    # 5. Map output_var to ds_corr field (support uncertainty variables)
    var_map = {
        'Rw_brdf': 'rho_ex_w',
        'rho_ex_w': 'rho_ex_w',
        'nrrs': 'nrrs',
        'C_brdf': 'C_brdf',
        'brdf_unc': 'brdf_unc',
        'nrrs_unc': 'nrrs_unc',
    }
    # Validate all requested output variables
    for v in output_vars:
        if v not in var_map and v not in ['Rw_brdf']:
            raise ValueError(
                f"Unknown output_var '{v}'. Allowed: {list(var_map.keys())}"
            )
    # 6. Progress bar setup
    iterator = _block_windows(src_rw, block_size=block_size)
    if progress_bar and tqdm is not None:
        # Ceiling division per axis gives the total number of tiles.
        total_blocks = ((n_rows + block_size - 1) // block_size) * ((n_cols + block_size - 1) // block_size)
        iterator = tqdm(iterator, total=total_blocks, desc="Processing blocks", unit="block")
    elif progress_bar:
        print("Warning: tqdm not installed, progress bar disabled.")
    # Statistics for performance monitoring
    blocks_processed = 0
    blocks_skipped = 0
    total_corrected_pixels = 0
    # 7. Block-wise processing with early filtering
    for window in iterator:
        # Step 1: Read mask FIRST to enable early exit for pure land blocks (mask mode only)
        if use_mask:
            mask_block = src_mask.read(1, window=window).astype(np.int8)
            water_mask = (mask_block == 1)
        else:
            mask_block = None
            water_mask = np.ones((window.height, window.width), dtype=bool)
        n_water = np.count_nonzero(water_mask)
        if n_water == 0:
            # Fast path: pure land/no-water block - minimal processing (only when mask is used)
            blocks_skipped += 1
            # Still need to read hyperspectral for Rw_brdf output, but skip angles
            rw_block = src_rw.read(window=window).astype(np.float32)
            output_blocks = {}
            for var_name in output_vars:
                if var_name == 'Rw_brdf':
                    # Rw_brdf passes the input through unchanged on land.
                    output_blocks[var_name] = rw_block.copy()
                else:
                    output_blocks[var_name] = np.full_like(rw_block, np.nan, dtype=np.float32)
            for var_name, block_data in output_blocks.items():
                output_dsts[var_name].write(block_data, window=window)
            del rw_block, water_mask
            gc.collect()
            continue
        # Step 2: Has water pixels - full processing
        blocks_processed += 1
        total_corrected_pixels += n_water
        # Read remaining data only for blocks that need processing
        rw_block = src_rw.read(window=window).astype(np.float32)
        sza_block = src_ang.read(band_sza, window=window).astype(np.float32)
        vza_block = src_ang.read(band_vza, window=window).astype(np.float32)
        saa_block = src_ang.read(band_saa, window=window).astype(np.float32)
        vaa_block = src_ang.read(band_vaa, window=window).astype(np.float32)
        raa_block = relative_azimuth_olci(saa_block, vaa_block)
        water_idx = np.where(water_mask)
        # Prepare output blocks for all variables
        output_blocks = {}
        for var_name in output_vars:
            if var_name == 'Rw_brdf':
                output_blocks[var_name] = rw_block.copy()
            else:
                output_blocks[var_name] = np.full_like(rw_block, np.nan, dtype=np.float32)
        # Process water pixels: gather per-pixel spectra and angles as 1-D batches.
        rw_water = rw_block[:, water_idx[0], water_idx[1]].T  # (n_water, bands)
        sza_water = sza_block[water_idx]
        vza_water = vza_block[water_idx]
        raa_water = raa_block[water_idx]
        pixel_ids = np.arange(n_water)
        ds_block = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_water, dims=('n', 'bands'), coords={'n': pixel_ids, 'bands': bands_xr}),
                'sza': xr.DataArray(sza_water, dims='n', coords={'n': pixel_ids}),
                'vza': xr.DataArray(vza_water, dims='n', coords={'n': pixel_ids}),
                'raa': xr.DataArray(raa_water, dims='n', coords={'n': pixel_ids}),
            }
        )
        ds_corr = brdf_prototype(
            ds_block,
            brdf_model=brdf_model,
            compute_uncertainty=compute_uncertainty,
        )
        # Write each requested variable to the water pixels
        for var_name in output_vars:
            field_name = var_map.get(var_name, 'rho_ex_w')
            if var_name == 'Rw_brdf':
                field_name = 'rho_ex_w'
            if field_name in ds_corr:
                values = ds_corr[field_name].values.T
                output_blocks[var_name][:, water_idx[0], water_idx[1]] = values
        # Write all variables for this block
        for var_name, block_data in output_blocks.items():
            output_dsts[var_name].write(block_data, window=window)
        # Drop large intermediates before the next tile to limit peak memory.
        del (rw_block, sza_block, vza_block, saa_block, vaa_block, raa_block,
             output_blocks, water_mask, water_idx, rw_water, sza_water, vza_water,
             raa_water, pixel_ids, ds_block, ds_corr)
        if mask_block is not None:
            del mask_block
        gc.collect()
    # Close all files
    src_rw.close()
    src_ang.close()
    if src_mask is not None:
        src_mask.close()
    for dst in output_dsts.values():
        dst.close()
    # Print performance statistics
    total_blocks = blocks_processed + blocks_skipped
    if total_blocks > 0:
        skip_ratio = blocks_skipped / total_blocks * 100
        print(f"Block processing completed:")
        print(f" - Total blocks: {total_blocks}")
        print(f" - Processed blocks (corrected): {blocks_processed}")
        print(f" - Skipped blocks (no mask pixels): {blocks_skipped} ({skip_ratio:.1f}%)")
        print(f" - Total pixels BRDF-corrected: {total_corrected_pixels:,}")
    return list(output_paths.values())
# ---------------------------------------------------------------------------
# Top-level pipeline
# ---------------------------------------------------------------------------
def brdf_uncertainty(ds, adf=None):
    """Compute the uncertainty of the BRDF factor and propagate it to nrrs.

    Adds 'brdf_unc', 'brdf_unc_fail' and 'nrrs_unc' to *ds* and returns it.
    *adf* is a printf-style template for auxiliary data file paths; defaults
    to the package-level ADF_OCP.
    """
    # Read LUT (cached at module level so the NetCDF is opened only once).
    if adf is None:
        adf = ADF_OCP
    unc_lut_path = adf % 'UNC'
    if unc_lut_path not in UNC_LUT_CACHE:
        UNC_LUT_CACHE[unc_lut_path] = xr.open_dataset(unc_lut_path, engine='netcdf4')
    LUT = UNC_LUT_CACHE[unc_lut_path]
    # Interpolate relative uncertainty at the pixel band/geometry coordinates.
    unc = LUT['unc'].interp(lambda_unc=ds.bands, theta_s_unc=ds.sza, theta_v_unc=ds.vza,
                            delta_phi_unc=ds.raa)
    # Compute absolute uncertainty of factor
    ds['brdf_unc'] = unc * ds['C_brdf']
    # Flag BRDF_unc where NaN and set to 0
    ds['brdf_unc_fail'] = np.isnan(ds['brdf_unc'])
    ds['brdf_unc'] = xr.where(ds['brdf_unc_fail'], 0, ds['brdf_unc'])
    # Propagate to nrrs: var = unc(C)^2 * Rw^2 (+ C^2 * unc(Rw)^2 when given).
    Rwex_unc2 = ds['brdf_unc'] * ds['brdf_unc'] * ds['Rw'] * ds['Rw']
    if 'Rw_unc' in ds:
        Rwex_unc2 += ds['C_brdf'] * ds['C_brdf'] * ds['Rw_unc'] * ds['Rw_unc']
    ds['nrrs_unc'] = np.sqrt(Rwex_unc2)/np.pi
    return ds
def brdf_prototype(ds, adf=None, brdf_model='L11', compute_uncertainty=True):
    """Apply the selected BRDF model to a dataset of directional reflectances.

    *ds* must carry 'Rw', 'sza', 'vza' and 'raa' on a spectral dimension
    named 'bands'. Adds 'nrrs', 'rho_ex_w', 'C_brdf', 'C_brdf_fail',
    'convergeFlag', 'flags_level2' (and the uncertainty fields when
    *compute_uncertainty* is True), then returns the dataset.
    """
    # Test BRDF model not supported in the GUI: force override
    # brdf_model = 'M02SeaDAS'
    # Squeeze trivial dimensions (e.g. number of casts, extractions, ...) to
    # avoid interpolation problems; they are re-expanded before returning.
    ds, squeezedDims = squeeze_trivial_dims(ds)
    # Initialise model
    if brdf_model == 'M02':
        BRDF_model = M02(bands=ds.bands, aot=ds.aot, wind=ds.wind, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'M02SeaDAS':
        BRDF_model = M02SeaDAS(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'L11':
        BRDF_model = L11(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'O25':
        BRDF_model = O25(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    else:
        # NOTE(review): sys.exit in a library routine aborts the caller;
        # raising ValueError would be friendlier -- confirm before changing.
        print("BRDF model %s not supported" % brdf_model)
        sys.exit(1)
    # Init pixel
    BRDF_model.init_pixels(ds['sza'], ds['vza'], ds['raa'])
    # Compute IOP and normalize by iterating
    ds['nrrs'] = ds['Rw'] / np.pi
    ds['convergeFlag'] = (0 * ds['sza']).astype(bool)
    ds['C_brdf'] = 0 * ds['nrrs'] + 1
    for iter_brdf in range(int(BRDF_model.niter)):
        # M02: Initialise chl_iter
        if brdf_model in ['M02', 'M02SeaDAS'] and (iter_brdf == 0):
            chl_iter = {}
            ds['log10_chl'] = 0 * ds['sza'] + float(np.log10(BRDF_model.OC4MEchl0))
            chl_iter[-1] = 0 * ds['sza'] + float(BRDF_model.OC4MEchl0)
        ds = BRDF_model.backward(ds, iter_brdf)
        # M02: Check convergence (dummy for M02SeaDAS for the moment... epsilon set to 0)
        if brdf_model in ['M02', 'M02SeaDAS']:
            chl_iter[iter_brdf] = 10 ** ds['log10_chl']
            # Check if convergence is reached |chl_old-chl_new| < epsilon * chl_new
            ds['convergeFlag'] = (ds['convergeFlag']) | (
                (np.abs(chl_iter[iter_brdf - 1] - chl_iter[iter_brdf]) < float(BRDF_model.OC4MEepsilon) * chl_iter[
                    iter_brdf]))
        # Apply forward model in both geometries
        forward_mod = BRDF_model.forward(ds)
        forward_mod0 = BRDF_model.forward(ds, normalized=True)
        ratio = forward_mod0 / forward_mod
        # Drop remnant coordinates to avoid ambiguities in the update of the BRDF factor.
        for coord in ratio.coords:
            if coord not in ds['C_brdf'].coords:
                ratio = ratio.drop(coord)
        # Normalize reflectance (converged pixels keep their previous factor).
        ds['C_brdf'] = xr.where(ds['convergeFlag'], ds['C_brdf'], ratio)
        ds['nrrs'] = ds['Rw'] / np.pi * ds['C_brdf']
    # Flag BRDF where NaN and set to 1 (no correction applied).
    ds['C_brdf_fail'] = np.isnan(ds['C_brdf'])
    ds['C_brdf'] = xr.where(ds['C_brdf_fail'], 1, ds['C_brdf'])
    ds['nrrs'] = xr.where(ds['C_brdf_fail'], ds['Rw'] / np.pi, ds['nrrs'])
    # If QAA_fail is raised, raise C_brdf_fail (but still apply C_brdf).
    if 'QAA_fail' in ds:
        ds['C_brdf_fail'] = (ds['C_brdf_fail']) | (ds['QAA_fail'])
    # Compute uncertainty only if requested
    if compute_uncertainty:
        ds = brdf_uncertainty(ds)
    # Compute flag
    ds['flags_level2'] = ds['Rw'] * 0  # TODO
    # Convert to reflectance unit
    ds['rho_ex_w'] = ds['nrrs'] * np.pi
    # Expand squeezed trivial dimensions
    for dim, d0 in squeezedDims.items():
        ds = ds.expand_dims(dim, axis=d0)
    return ds
def run_brdf_correction(
    hyperspectral_file,
    angle_file,
    mask_file=None,
    wavelengths=None,
    brdf_model='L11',
    chunk_size=4096,
    output_var=None,  # if None, keep full output (original behaviour)
    compute_uncertainty=True,  # compute uncertainty only if requested
    progress_bar=True,  # whether to display a progress bar
):
    """BRDF correction pipeline for water pixels in a hyperspectral scene.

    ``angle_file`` must be a 10-band ENVI geometry cube (see :func:`read_bip_angles`);
    relative azimuth is computed from to-sun and to-sensor azimuths.

    If ``mask_file`` is None, a mask of all ones is used (full-scene correction).

    Returns an xarray.Dataset holding either the full product set
    (*output_var* is None) or just the requested variable plus geometry/mask.
    Whole-scene arrays are held in memory; use the block-wise variant for
    large scenes.
    """
    # ------------------------------------------------------------------
    # 1. Read all inputs
    # ------------------------------------------------------------------
    rw_data, wvl, _img_meta = read_bsq(hyperspectral_file, wavelengths, dtype=np.float32)
    angles = read_bip_angles(angle_file, dtype=np.float32)
    n_bands, n_rows, n_cols = rw_data.shape
    if mask_file is None:
        water_mask = np.ones((n_rows, n_cols), dtype=np.int8)
    else:
        water_mask = read_water_mask(mask_file)
    # Sanity-check spatial dimensions
    if angles['sza'].shape != (n_rows, n_cols):
        raise ValueError(
            f"Angle file spatial shape {angles['sza'].shape} does not match "
            f"hyperspectral image shape ({n_rows}, {n_cols})."
        )
    if water_mask.shape != (n_rows, n_cols):
        raise ValueError(
            f"Mask shape {water_mask.shape} does not match "
            f"hyperspectral image shape ({n_rows}, {n_cols})."
        )
    # ------------------------------------------------------------------
    # 2. Identify water pixels
    # ------------------------------------------------------------------
    water_rows, water_cols = np.where(water_mask == 1)
    n_water = water_rows.size
    if n_water == 0:
        raise RuntimeError("No water pixels found in the mask (no pixels == 1).")
    # ------------------------------------------------------------------
    # 3. Determine which output variable(s) we need to keep
    # ------------------------------------------------------------------
    # If output_var is None, we keep everything (original behaviour)
    keep_all = (output_var is None)
    # Map output_var to the field in ds_corr
    var_map = {
        'Rw_brdf': 'rho_ex_w',
        'rho_ex_w': 'rho_ex_w',
        'nrrs': 'nrrs',
        'C_brdf': 'C_brdf'
    }
    if not keep_all:
        if output_var not in var_map:
            raise ValueError(f"Unknown output_var '{output_var}'. Allowed: {list(var_map.keys())}")
        field_name = var_map[output_var]
    else:
        field_name = None
    # ------------------------------------------------------------------
    # 4. Allocate output arrays (only those needed)
    # ------------------------------------------------------------------
    shape3 = (n_bands, n_rows, n_cols)
    shape2 = (n_rows, n_cols)
    if keep_all:
        # Original behaviour: allocate all arrays
        nrrs_full = np.full(shape3, np.nan, dtype=np.float32)
        rho_ex_w_full = np.full(shape3, np.nan, dtype=np.float32)
        C_brdf_full = np.full(shape3, np.nan, dtype=np.float32)
        brdf_unc_full = np.full(shape3, np.nan, dtype=np.float32)
        nrrs_unc_full = np.full(shape3, np.nan, dtype=np.float32)
        C_brdf_fail_full = np.zeros(shape2, dtype=bool)
        rw_brdf_full = rw_data.copy()
    else:
        # Only allocate the requested output array
        output_full = np.full(shape3, np.nan, dtype=np.float32)
        # We may also need C_brdf_fail if requested? Not for output_var, but could be used internally? Skip.
        # For simplicity, we do not allocate extra arrays.
    # ------------------------------------------------------------------
    # 5-6. Chunked BRDF correction
    # ------------------------------------------------------------------
    n_chunks = int(np.ceil(n_water / chunk_size))
    # Initialise the progress bar
    iterator = range(n_chunks)
    if progress_bar and tqdm is not None:
        iterator = tqdm(iterator, desc="Processing chunks", unit="chunk")
    elif progress_bar:
        print("Warning: tqdm not installed, progress bar disabled.")
    for chunk_idx in iterator:
        lo = chunk_idx * chunk_size
        hi = min(lo + chunk_size, n_water)
        rows_c = water_rows[lo:hi]
        cols_c = water_cols[lo:hi]
        # Extract batch vectors
        rw_c = rw_data[:, rows_c, cols_c].T  # (chunk, bands)
        sza_c = angles['sza'][rows_c, cols_c]  # (chunk,)
        vza_c = angles['vza'][rows_c, cols_c]
        raa_c = angles['raa'][rows_c, cols_c]
        pixel_idx_c = np.arange(hi - lo)
        ds_chunk = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_c, dims=('n', 'bands'),
                                   coords={'n': pixel_idx_c, 'bands': wvl}),
                'sza': xr.DataArray(sza_c, dims='n',
                                    coords={'n': pixel_idx_c}),
                'vza': xr.DataArray(vza_c, dims='n',
                                    coords={'n': pixel_idx_c}),
                'raa': xr.DataArray(raa_c, dims='n',
                                    coords={'n': pixel_idx_c}),
            }
        )
        # Run BRDF model with uncertainty control
        ds_corr = brdf_prototype(ds_chunk, brdf_model=brdf_model, compute_uncertainty=compute_uncertainty)
        if keep_all:
            # Write all results back
            nrrs_full[:, rows_c, cols_c] = ds_corr['nrrs'].values.T
            rho_ex_w_full[:, rows_c, cols_c] = ds_corr['rho_ex_w'].values.T
            C_brdf_full[:, rows_c, cols_c] = ds_corr['C_brdf'].values.T
            if compute_uncertainty:
                brdf_unc_full[:, rows_c, cols_c] = ds_corr['brdf_unc'].values.T
                nrrs_unc_full[:, rows_c, cols_c] = ds_corr['nrrs_unc'].values.T
            C_brdf_fail_full[rows_c, cols_c] = ds_corr['C_brdf_fail'].values
            rw_brdf_full[:, rows_c, cols_c] = ds_corr['rho_ex_w'].values.T
        else:
            # Only write the requested output field
            values = ds_corr[field_name].values.T
            output_full[:, rows_c, cols_c] = values
        # Release batch objects so GC can reclaim memory
        del rw_c, sza_c, vza_c, raa_c, pixel_idx_c, ds_chunk, ds_corr
        gc.collect()
    # ------------------------------------------------------------------
    # 7. Assemble full-scene output Dataset
    # ------------------------------------------------------------------
    full_coords = {'bands': wvl, 'y': np.arange(n_rows), 'x': np.arange(n_cols)}
    dims3 = ('bands', 'y', 'x')
    dims2 = ('y', 'x')
    if keep_all:
        # Original behaviour: return all variables
        ds_out = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_data, dims=dims3, coords=full_coords),
                'Rw_brdf': xr.DataArray(rw_brdf_full, dims=dims3, coords=full_coords),
                'nrrs': xr.DataArray(nrrs_full, dims=dims3, coords=full_coords),
                'rho_ex_w': xr.DataArray(rho_ex_w_full, dims=dims3, coords=full_coords),
                'C_brdf': xr.DataArray(C_brdf_full, dims=dims3, coords=full_coords),
                'C_brdf_fail': xr.DataArray(C_brdf_fail_full, dims=dims2),
                'water_mask': xr.DataArray(water_mask, dims=dims2),
                'sza': xr.DataArray(angles['sza'], dims=dims2),
                'saa': xr.DataArray(angles['saa'], dims=dims2),
                'vza': xr.DataArray(angles['vza'], dims=dims2),
                'vaa': xr.DataArray(angles['vaa'], dims=dims2),
                'raa': xr.DataArray(angles['raa'], dims=dims2),
            }
        )
        if compute_uncertainty:
            ds_out['brdf_unc'] = xr.DataArray(brdf_unc_full, dims=dims3, coords=full_coords)
            ds_out['nrrs_unc'] = xr.DataArray(nrrs_unc_full, dims=dims3, coords=full_coords)
    else:
        # Return only the requested output variable + geometry + mask
        ds_out = xr.Dataset(
            {
                output_var: xr.DataArray(output_full, dims=dims3, coords=full_coords),
                'Rw': xr.DataArray(rw_data, dims=dims3, coords=full_coords),
                'water_mask': xr.DataArray(water_mask, dims=dims2),
                'sza': xr.DataArray(angles['sza'], dims=dims2),
                'saa': xr.DataArray(angles['saa'], dims=dims2),
                'vza': xr.DataArray(angles['vza'], dims=dims2),
                'vaa': xr.DataArray(angles['vaa'], dims=dims2),
                'raa': xr.DataArray(angles['raa'], dims=dims2),
            }
        )
        # If uncertainties were computed, we could add them optionally, but user didn't request them.
        # For memory efficiency, we skip them here.
    return ds_out
def process_brdf_files(
    hyperspectral_file,
    angle_file,
    output_file,
    mask_file=None,
    wavelengths=None,
    brdf_model="L11",
    output_vars=None,
    output_format="ENVI",
    chunk_size=4096,
    block_size=512,
    compute_uncertainty=False,
    progress_bar=True,
):
    """Run BRDF correction from input files and write each variable to its own file.

    The geometry file must contain 10 bands as documented for :func:`read_bip_angles`.
    When ``mask_file`` is None the full scene is corrected. ``chunk_size`` is
    accepted for interface compatibility but unused by the block-wise path.
    Returns a ``(None, saved_paths)`` tuple.
    """
    if output_format.upper() != "ENVI":
        raise NotImplementedError("Only ENVI output format is supported in block-wise version.")
    # Normalise the variable request to a list.
    requested = output_vars
    if requested is None:
        requested = ["Rw_brdf"]
    elif isinstance(requested, str):
        requested = [requested]
    print(f"Computing BRDF correction with variables: {requested}")
    print(f"Output base name: {output_file}")
    written = run_brdf_correction_block(
        hyperspectral_file=hyperspectral_file,
        angle_file=angle_file,
        mask_file=mask_file,
        output_file=output_file,
        wavelengths=wavelengths,
        brdf_model=brdf_model,
        output_vars=requested,
        block_size=block_size,
        compute_uncertainty=compute_uncertainty,
        progress_bar=progress_bar,
    )
    return None, written
def main():
    """Command-line entry point for the water-region BRDF correction."""
    import argparse
    parser = argparse.ArgumentParser(description="Water-region BRDF correction for hyperspectral data.")
    parser.add_argument("hyperspectral_file", help="Input ENVI BSQ hyperspectral file or .hdr path.")
    parser.add_argument(
        "angle_file",
        help="10-band ENVI geometry file or .hdr: path length, VAA, VZA, SAA, SZA, "
             "phase, slope, aspect, cos i, UTC (RAA computed from VAA and SAA).",
    )
    parser.add_argument("output_file", help="Output path prefix. ENVI writes .hdr/.img, NetCDF writes .nc.")
    parser.add_argument(
        "--mask",
        dest="mask_file",
        default=None,
        metavar="PATH",
        help="Optional water mask GeoTIFF; pixels == 1 are corrected. "
             "If omitted, BRDF correction is applied to the full image.",
    )
    parser.add_argument("--brdf-model", default="L11", choices=["L11", "M02", "M02SeaDAS", "O25"])
    parser.add_argument("--output-var", nargs='+', default=["Rw_brdf"],
                        help="Output variables to save (one or more). "
                             "Supported: Rw_brdf, rho_ex_w, nrrs, C_brdf, brdf_unc, nrrs_unc. "
                             "Example: --output-var Rw_brdf nrrs brdf_unc")
    parser.add_argument("--output-format", default="ENVI", choices=["ENVI"])
    parser.add_argument("--chunk-size", type=int, default=4096,
                        help="Number of water pixels per processing batch (default: 4096).")
    parser.add_argument("--block-size", type=int, default=512,
                        help="Spatial block size for block-wise processing (default: 512).")
    parser.add_argument("--compute-uncertainty", action="store_true",
                        help="Compute BRDF uncertainty (may increase runtime and output size).")
    parser.add_argument("--no-progress-bar", action="store_true",
                        help="Disable progress bar (useful for non-interactive environments).")
    args = parser.parse_args()
    progress_bar = not args.no_progress_bar
    _ds_out, saved_paths = process_brdf_files(
        hyperspectral_file=args.hyperspectral_file,
        angle_file=args.angle_file,
        mask_file=args.mask_file,
        output_file=args.output_file,
        brdf_model=args.brdf_model,
        output_vars=args.output_var,
        output_format=args.output_format,
        chunk_size=args.chunk_size,
        block_size=args.block_size,
        compute_uncertainty=args.compute_uncertainty,
        progress_bar=progress_bar,
    )
    # The block-wise pipeline returns a list of written files; report them.
    if isinstance(saved_paths, list):
        print(f"BRDF correction finished. Saved {len(saved_paths)} files:")
        for path in saved_paths:
            print(f" - {path}")
    else:
        print(f"BRDF correction finished. Saved to: {saved_paths}")
# Script entry point.
if __name__ == "__main__":
    main()