Initial commit

This commit is contained in:
2026-04-10 16:46:45 +08:00
commit 4fd1b0a203
165 changed files with 25698 additions and 0 deletions

908
ocbrdf/ocbrdf_main3.py Normal file
View File

@ -0,0 +1,908 @@
import gc
import os
import numpy as np
import sys
import xarray as xr
import rasterio
from rasterio.windows import Window
import spectral
from brdf_model_M02 import M02
from brdf_model_M02SeaDAS import M02SeaDAS
from brdf_model_L11 import L11
from brdf_model_O25 import O25
from brdf_utils import ADF_OCP, squeeze_trivial_dims
# Optional progress-bar support: fall back to None when tqdm is unavailable,
# so callers can test `tqdm is not None` before wrapping iterators.
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None
# Module-level cache of uncertainty LUT datasets, keyed by LUT file path,
# so repeated per-chunk/per-block calls do not re-open the same NetCDF file.
UNC_LUT_CACHE = {}
"""
主 BRDF 校正模块
输入为 xarray 数据集
所需的光谱维度为 "bands",其他维度自由
输入数据集中必填的字段:
Rw: 方向性海洋反射率
sza: 太阳天顶角
vza: 观测天顶角
raa: 相对方位角当太阳和观测在同一侧时raa=0
输入数据集中可选的字段:
Rw_unc: Rw 的不确定度(若缺失,设为零)
输出数据集中的字段:
nrrs: 完全归一化的遥感反射率
rho_ex_w: nrrs * PI
omega_b: bb/(a+bb)
eta_b: bbw/bb
C_brdf: BRDF 校正因子
brdf_unc: C_brdf 的不确定度
nrrs_unc: nrrs 的不确定度
使用和适配 brdf_hypercp 模块需保留的信息:
Brdf_hypercp 是 EUMETSAT 研究项目的一部分
"BRDF correction of S3 OLCI water reflectance products"S3 OLCI 水体反射率产品的 BRDF 校正),
合同号RB_EUM-CO-21-4600002626-JIG。
研究团队成员Davide D'Alimonte (davide.dalimonte@aequora.org)
Tamito Kajiyama (tamito.kajiyama@aequora.org)
Jaime Pitarch (jaime.pitarchportero@artov.ismar.cnr.it)
Vittorio Brando (vittorio.brando@cnr.it)
Marco Talone (talone@icm.csic.es) 以及
Constant Mazeran (constant.mazeran@solvo.fr)。
BRDF 查找表中的相对方位角遵循 OLCI 约定。详见 https://www.eumetsat.int/media/50720图 6。
"""
# ---------------------------------------------------------------------------
# File I/O helpers
# ---------------------------------------------------------------------------
def _resolve_envi_header(filepath):
"""Return the ENVI header path accepted by spectral."""
if filepath.lower().endswith(".hdr"):
return filepath
hdr_path = filepath + ".hdr"
if os.path.exists(hdr_path):
return hdr_path
base, _ext = os.path.splitext(filepath)
hdr_path = base + ".hdr"
if os.path.exists(hdr_path):
return hdr_path
return filepath
def read_bsq(filepath, wavelengths=None, dtype=np.float32):
    """Read an ENVI BSQ hyperspectral image.

    Parameters
    ----------
    filepath : str
        Path to the ENVI header (.hdr) or image file.
    wavelengths : array-like, optional
        Band centre wavelengths in nm. When *None*, wavelengths are read
        from the ENVI header if available, else 1-based band indices are
        used.
    dtype : numpy dtype, optional
        Output dtype of both the cube and wavelengths (default float32).

    Returns
    -------
    data : ndarray
        Image cube in the file's native ("source") interleave order.
        NOTE(review): the wavelength fallback below uses data.shape[0] as
        the band count, which assumes a (bands, rows, cols) layout --
        confirm for non-BSQ inputs.
    wavelengths : ndarray, shape (bands,)
    img : spectral image object (metadata / transform access)
    """
    img = spectral.open_image(_resolve_envi_header(filepath))
    data = np.asarray(img.open_memmap(interleave="source"), dtype=dtype)
    # data = np.transpose(data, (2, 0, 1))  # -> (bands, rows, cols)
    if wavelengths is None:
        if hasattr(img, 'bands') and img.bands.centers is not None:
            wavelengths = np.array(img.bands.centers, dtype=np.float32)
        else:
            # No header metadata: fall back to 1-based band indices.
            n_bands = data.shape[0]
            wavelengths = np.arange(1, n_bands + 1, dtype=np.float32)
    return data, np.asarray(wavelengths, dtype=dtype), img
def read_bip_angles(filepath, dtype=np.float32):
    """Read an ENVI BIP angle file and extract the five geometry bands.

    Band layout (1-indexed):
        band 7  -> SZA  Solar Zenith Angle
        band 8  -> SAA  Solar Azimuth Angle
        band 9  -> VZA  Sensor Zenith Angle
        band 10 -> VAA  Sensor Azimuth Angle
        band 11 -> RAA  Relative Azimuth Angle

    Parameters
    ----------
    filepath : str
        Path to the ENVI BIP file (header or image file).
    dtype : numpy dtype, optional
        Output dtype of the angle arrays (default float32).

    Returns
    -------
    dict with keys 'sza', 'saa', 'vza', 'vaa', 'raa'.
        Each value is a 2-D ndarray of shape (rows, cols), dtype *dtype*.

    Raises
    ------
    ValueError
        If the file has fewer than 11 bands.
    """
    img = spectral.open_image(_resolve_envi_header(filepath))
    data = np.asarray(img.open_memmap(interleave="source"), dtype=dtype)
    # Bands are on axis 2 here (the code below indexes data[:, :, band]).
    n_bands = data.shape[2]
    if n_bands < 11:
        # Bug fix: the band range in the message was garbled as "711";
        # it should read "7-11".
        raise ValueError(
            f"Angle file has only {n_bands} bands; at least 11 are required "
            "(bands 7-11 = SZA, SAA, VZA, VAA, RAA)."
        )
    # .copy() detaches each band from the large memmap-backed array so the
    # source file can be released.
    return {
        'sza': data[:, :, 6].copy(),
        'saa': data[:, :, 7].copy(),
        'vza': data[:, :, 8].copy(),
        'vaa': data[:, :, 9].copy(),
        'raa': data[:, :, 10].copy(),
    }
def read_water_mask(filepath):
    """Load a single-band GeoTIFF water mask as an int8 array.

    Parameters
    ----------
    filepath : str
        Path to the GeoTIFF file.

    Returns
    -------
    ndarray, shape (rows, cols), int8
        Pixels with value 1 indicate water; all other values are non-water.
    """
    with rasterio.open(filepath) as src:
        return src.read(1).astype(np.int8)
def _block_windows(src, block_size=512):
    """Yield rasterio Windows tiling *src* into block_size-square chunks.

    Edge tiles are clipped to the raster extent.
    """
    n_rows, n_cols = src.height, src.width
    for r0 in range(0, n_rows, block_size):
        height = min(block_size, n_rows - r0)
        for c0 in range(0, n_cols, block_size):
            width = min(block_size, n_cols - c0)
            yield Window(c0, r0, width, height)
def _resolve_output_envi_path(output_file):
"""Return the ENVI image path used for rasterio writes."""
if output_file.lower().endswith(".img"):
return output_file
if output_file.lower().endswith(".hdr"):
return output_file[:-4] + ".img"
return output_file + ".img"
def _read_wavelengths_from_header(hyperspectral_file):
    """Return band-centre wavelengths from the ENVI header, or None.

    Prefers the parsed band centres; falls back to the raw 'wavelength'
    header entry when centres are absent.
    """
    img = spectral.open_image(_resolve_envi_header(hyperspectral_file))
    centers = img.bands.centers if hasattr(img, 'bands') else None
    if centers is None:
        centers = img.metadata.get('wavelength')
    if centers is None:
        return None
    return np.asarray(centers, dtype=np.float32)
def save_brdf_result(
    ds_out,
    output_file,
    source_file=None,
    output_var="Rw_brdf",
    output_format="ENVI",
    dtype=np.float32,
):
    """Write one variable of *ds_out* to disk as ENVI (.hdr) or NetCDF (.nc).

    Returns the path actually written (header path for ENVI).
    Raises KeyError when *output_var* is missing from the dataset and
    ValueError for an unsupported *output_format*.
    """
    fmt = output_format.upper()
    if output_var not in ds_out:
        raise KeyError(f"Output variable '{output_var}' not found in result dataset.")
    if fmt == "NETCDF":
        nc_path = output_file if output_file.lower().endswith(".nc") else output_file + ".nc"
        ds_out.to_netcdf(nc_path)
        return nc_path
    if fmt != "ENVI":
        raise ValueError("output_format must be 'ENVI' or 'NETCDF'.")
    # (bands, y, x) -> (lines, samples, bands), the layout spectral expects.
    cube = np.transpose(np.asarray(ds_out[output_var].values, dtype=dtype), (1, 2, 0))
    hdr_path = output_file if output_file.lower().endswith(".hdr") else output_file + ".hdr"
    metadata = {
        "interleave": "bsq",
        "description": f"BRDF corrected result: {output_var}",
        "bands": cube.shape[2],
        "lines": cube.shape[0],
        "samples": cube.shape[1],
    }
    if "bands" in ds_out.coords:
        metadata["wavelength"] = [str(v) for v in ds_out.coords["bands"].values.tolist()]
    if source_file is not None:
        metadata["source_file"] = source_file
    spectral.envi.save_image(
        hdr_path,
        cube,
        dtype=dtype,
        interleave="bsq",
        metadata=metadata,
        force=True,
    )
    return hdr_path
def _get_output_filename(base_output_file: str, var_name: str) -> str:
"""Generate output filename with variable suffix.
Examples:
input.hdr + Rw_brdf -> input_Rw_brdf.hdr
input.hdr + brdf_unc -> input_brdf_unc.hdr
"""
base, ext = os.path.splitext(base_output_file)
if ext.lower() in ('.hdr', '.img'):
base = base # keep base without extension
suffix = f"_{var_name}" if var_name.lower() != "rw_brdf" else ""
return f"{base}{suffix}{ext}"
def save_multiple_brdf_results(
    ds_out,
    output_file,
    source_file=None,
    output_vars=None,
    output_format="ENVI",
    dtype=np.float32,
):
    """Save each requested variable to its own file.

    Missing variables are reported and skipped.  Returns a list of
    ``(variable_name, saved_path)`` tuples for the files written.
    """
    vars_to_save = output_vars if output_vars else ["Rw_brdf"]
    saved = []
    for var in vars_to_save:
        out_file = _get_output_filename(output_file, var)
        try:
            path = save_brdf_result(
                ds_out=ds_out,
                output_file=out_file,
                source_file=source_file,
                output_var=var,
                output_format=output_format,
                dtype=dtype,
            )
        except KeyError as e:
            print(f" Warning: {e}. Skipping {var}.")
        else:
            saved.append((var, path))
            print(f" Saved {var} to: {path}")
    return saved
def run_brdf_correction_block(
    hyperspectral_file,
    angle_file,
    mask_file,
    output_file,
    wavelengths=None,
    brdf_model='L11',
    output_vars=None,  # a single name or a list of names is accepted
    block_size=512,
    compute_uncertainty=False,
    progress_bar=True,
):
    """
    Run BRDF correction block-wise and write multiple variables to separate ENVI files.

    Opens the reflectance cube, the angle file (bands 7/9/11 -> SZA/VZA/RAA,
    1-based) and the water mask with rasterio, then processes the scene in
    block_size x block_size tiles.  Tiles with no water pixels take a fast
    path: Rw_brdf is copied through unchanged and all other variables are
    written as NaN.  Water pixels are stacked into an xarray Dataset and
    corrected via brdf_prototype.

    Returns the list of ENVI image paths written (one per output variable).
    Raises ValueError on dimension mismatches or an unknown output variable.
    """
    if output_vars is None:
        output_vars = ["Rw_brdf"]
    if isinstance(output_vars, str):
        output_vars = [output_vars]
    # 1. Open all inputs
    src_rw = rasterio.open(hyperspectral_file)
    n_bands = src_rw.count
    n_rows = src_rw.height
    n_cols = src_rw.width
    src_ang = rasterio.open(angle_file)
    if src_ang.height != n_rows or src_ang.width != n_cols:
        raise ValueError("Angle file dimensions do not match hyperspectral file.")
    src_mask = rasterio.open(mask_file)
    if src_mask.height != n_rows or src_mask.width != n_cols:
        raise ValueError("Mask file dimensions do not match hyperspectral file.")
    # 2. Wavelengths: explicit argument, else header metadata, else 1..n
    if wavelengths is None:
        wavelengths = _read_wavelengths_from_header(hyperspectral_file)
    if wavelengths is None:
        wavelengths = np.arange(1, n_bands + 1, dtype=np.float32)
    bands_xr = xr.DataArray(wavelengths, dims=('bands',), coords={'bands': wavelengths}, name='bands')
    # 3. Prepare multiple output files (one per variable)
    out_meta = src_rw.meta.copy()
    out_meta.update({
        'driver': 'ENVI',
        'dtype': 'float32',
        'count': n_bands,
        'interleave': 'bsq',
    })
    # Open one output file for each variable
    output_dsts = {}
    output_paths = {}
    for var_name in output_vars:
        var_file = _get_output_filename(output_file, var_name)
        out_path = _resolve_output_envi_path(var_file)
        output_dsts[var_name] = rasterio.open(out_path, 'w', **out_meta)
        output_paths[var_name] = out_path
        print(f" Writing {var_name} to {out_path}")
    # 4. Band indices in angle file (1-based, rasterio convention)
    band_sza = 7
    band_vza = 9
    band_raa = 11
    # 5. Map output_var to ds_corr field (support uncertainty variables)
    var_map = {
        'Rw_brdf': 'rho_ex_w',
        'rho_ex_w': 'rho_ex_w',
        'nrrs': 'nrrs',
        'C_brdf': 'C_brdf',
        'brdf_unc': 'brdf_unc',
        'nrrs_unc': 'nrrs_unc',
    }
    # Validate all requested output variables before any heavy work
    for v in output_vars:
        if v not in var_map and v not in ['Rw_brdf']:
            raise ValueError(
                f"Unknown output_var '{v}'. Allowed: {list(var_map.keys())}"
            )
    # 6. Progress bar setup
    iterator = _block_windows(src_rw, block_size=block_size)
    if progress_bar and tqdm is not None:
        # ceil-division per axis gives the total tile count for tqdm
        total_blocks = ((n_rows + block_size - 1) // block_size) * ((n_cols + block_size - 1) // block_size)
        iterator = tqdm(iterator, total=total_blocks, desc="Processing blocks", unit="block")
    elif progress_bar:
        print("Warning: tqdm not installed, progress bar disabled.")
    # Statistics for performance monitoring
    blocks_processed = 0
    blocks_skipped = 0
    total_water_pixels = 0
    # 7. Block-wise processing with early filtering
    for window in iterator:
        # Step 1: Read mask FIRST to enable early exit for pure land blocks
        mask_block = src_mask.read(1, window=window).astype(np.int8)
        water_mask = (mask_block == 1)
        n_water = np.count_nonzero(water_mask)
        if n_water == 0:
            # Fast path: pure land/no-water block - minimal processing
            blocks_skipped += 1
            # Still need to read hyperspectral for Rw_brdf output, but skip angles
            rw_block = src_rw.read(window=window).astype(np.float32)
            output_blocks = {}
            for var_name in output_vars:
                if var_name == 'Rw_brdf':
                    # Pass the input reflectance through unchanged
                    output_blocks[var_name] = rw_block.copy()
                else:
                    output_blocks[var_name] = np.full_like(rw_block, np.nan, dtype=np.float32)
            for var_name, block_data in output_blocks.items():
                output_dsts[var_name].write(block_data, window=window)
            del rw_block, mask_block, water_mask
            gc.collect()
            continue
        # Step 2: Has water pixels - full processing
        blocks_processed += 1
        total_water_pixels += n_water
        # Read remaining data only for blocks that need processing
        rw_block = src_rw.read(window=window).astype(np.float32)
        sza_block = src_ang.read(band_sza, window=window).astype(np.float32)
        vza_block = src_ang.read(band_vza, window=window).astype(np.float32)
        raa_block = src_ang.read(band_raa, window=window).astype(np.float32)
        water_idx = np.where(water_mask)
        # Prepare output blocks for all variables
        output_blocks = {}
        for var_name in output_vars:
            if var_name == 'Rw_brdf':
                output_blocks[var_name] = rw_block.copy()
            else:
                output_blocks[var_name] = np.full_like(rw_block, np.nan, dtype=np.float32)
        # Gather water pixels into flat per-pixel vectors
        rw_water = rw_block[:, water_idx[0], water_idx[1]].T  # (n_water, bands)
        sza_water = sza_block[water_idx]
        vza_water = vza_block[water_idx]
        raa_water = raa_block[water_idx]
        pixel_ids = np.arange(n_water)
        ds_block = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_water, dims=('n', 'bands'), coords={'n': pixel_ids, 'bands': bands_xr}),
                'sza': xr.DataArray(sza_water, dims='n', coords={'n': pixel_ids}),
                'vza': xr.DataArray(vza_water, dims='n', coords={'n': pixel_ids}),
                'raa': xr.DataArray(raa_water, dims='n', coords={'n': pixel_ids}),
            }
        )
        ds_corr = brdf_prototype(
            ds_block,
            brdf_model=brdf_model,
            compute_uncertainty=compute_uncertainty,
        )
        # Write each requested variable back to the water pixels
        for var_name in output_vars:
            field_name = var_map.get(var_name, 'rho_ex_w')
            if var_name == 'Rw_brdf':
                field_name = 'rho_ex_w'
            if field_name in ds_corr:
                values = ds_corr[field_name].values.T
                output_blocks[var_name][:, water_idx[0], water_idx[1]] = values
        # Write all variables for this block
        for var_name, block_data in output_blocks.items():
            output_dsts[var_name].write(block_data, window=window)
        # Release per-block arrays eagerly to bound peak memory
        del (rw_block, sza_block, vza_block, raa_block, mask_block, output_blocks,
             water_mask, water_idx, rw_water, sza_water, vza_water, raa_water,
             pixel_ids, ds_block, ds_corr)
        gc.collect()
    # Close all files
    src_rw.close()
    src_ang.close()
    src_mask.close()
    for dst in output_dsts.values():
        dst.close()
    # Print performance statistics
    total_blocks = blocks_processed + blocks_skipped
    if total_blocks > 0:
        skip_ratio = blocks_skipped / total_blocks * 100
        print(f"Block processing completed:")
        print(f" - Total blocks: {total_blocks}")
        print(f" - Processed blocks (with water): {blocks_processed}")
        print(f" - Skipped blocks (no water): {blocks_skipped} ({skip_ratio:.1f}%)")
        print(f" - Total water pixels processed: {total_water_pixels:,}")
    return list(output_paths.values())
# ---------------------------------------------------------------------------
# Top-level pipeline
# ---------------------------------------------------------------------------
def brdf_uncertainty(ds, adf=None):
    """Compute the uncertainty of the BRDF factor and propagate it to nrrs.

    Interpolates a relative-uncertainty look-up table at each pixel's
    band/geometry, converts it to an absolute uncertainty of C_brdf
    ('brdf_unc'), and propagates it to 'nrrs_unc' (including the Rw_unc
    contribution when present).  Mutates and returns *ds*.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain 'bands', 'sza', 'vza', 'raa', 'Rw' and 'C_brdf';
        'Rw_unc' is optional.
    adf : str, optional
        printf-style template for the auxiliary-file path; defaults to
        ADF_OCP.  It is filled with 'UNC' to locate the LUT.
    """
    # Read LUT (cached at module level so repeated chunk calls reuse it)
    if adf is None:
        adf = ADF_OCP
    unc_lut_path = adf % 'UNC'
    if unc_lut_path not in UNC_LUT_CACHE:
        UNC_LUT_CACHE[unc_lut_path] = xr.open_dataset(unc_lut_path, engine='netcdf4')
    LUT = UNC_LUT_CACHE[unc_lut_path]
    # Interpolate relative uncertainty at each pixel's band and geometry
    unc = LUT['unc'].interp(lambda_unc=ds.bands, theta_s_unc=ds.sza, theta_v_unc=ds.vza,
                            delta_phi_unc=ds.raa)
    # Compute absolute uncertainty of factor
    ds['brdf_unc'] = unc * ds['C_brdf']
    # Flag BRDF_unc where NaN and set to 0
    ds['brdf_unc_fail'] = np.isnan(ds['brdf_unc'])
    ds['brdf_unc'] = xr.where(ds['brdf_unc_fail'], 0, ds['brdf_unc'])
    # Propagate to nrrs (sum of squared contributions, then sqrt)
    Rwex_unc2 = ds['brdf_unc'] * ds['brdf_unc'] * ds['Rw'] * ds['Rw']
    if 'Rw_unc' in ds:
        Rwex_unc2 += ds['C_brdf'] * ds['C_brdf'] * ds['Rw_unc'] * ds['Rw_unc']
    ds['nrrs_unc'] = np.sqrt(Rwex_unc2) / np.pi
    return ds
def brdf_prototype(ds, adf=None, brdf_model='L11', compute_uncertainty=True):
    """Normalise directional water reflectance with the chosen BRDF model.

    Expects *ds* to contain 'Rw', 'sza', 'vza' and 'raa' with a 'bands'
    spectral dimension.  Adds 'nrrs', 'rho_ex_w', 'C_brdf' (with
    'C_brdf_fail'), 'flags_level2' and, when *compute_uncertainty* is
    True, 'brdf_unc'/'nrrs_unc'.  Exits the process for an unsupported
    *brdf_model*.
    """
    # Experimental BRDF models are not supported in the GUI: force override.
    # brdf_model = 'M02SeaDAS'
    # Squeeze trivial (length-1) dimensions (e.g. cast/extraction counts)
    # to avoid interpolation problems; they are re-expanded at the end.
    ds, squeezedDims = squeeze_trivial_dims(ds)
    # Initialise model
    if brdf_model == 'M02':
        BRDF_model = M02(bands=ds.bands, aot=ds.aot, wind=ds.wind, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'M02SeaDAS':
        BRDF_model = M02SeaDAS(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'L11':
        BRDF_model = L11(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    elif brdf_model == 'O25':
        BRDF_model = O25(bands=ds.bands, adf=None)  # Don't use brdf_py.ADF context
    else:
        print("BRDF model %s not supported" % brdf_model)
        sys.exit(1)
    # Init pixel geometry
    BRDF_model.init_pixels(ds['sza'], ds['vza'], ds['raa'])
    # Compute IOP and normalize by iterating; start from the uncorrected
    # values (C_brdf = 1, converged = False everywhere).
    ds['nrrs'] = ds['Rw'] / np.pi
    ds['convergeFlag'] = (0 * ds['sza']).astype(bool)
    ds['C_brdf'] = 0 * ds['nrrs'] + 1
    for iter_brdf in range(int(BRDF_model.niter)):
        # M02: Initialise chl_iter on the first pass
        if brdf_model in ['M02', 'M02SeaDAS'] and (iter_brdf == 0):
            chl_iter = {}
            ds['log10_chl'] = 0 * ds['sza'] + float(np.log10(BRDF_model.OC4MEchl0))
            chl_iter[-1] = 0 * ds['sza'] + float(BRDF_model.OC4MEchl0)
        ds = BRDF_model.backward(ds, iter_brdf)
        # M02: Check convergence (dummy for M02SeaDAS for the moment... epsilon set to 0)
        if brdf_model in ['M02', 'M02SeaDAS']:
            chl_iter[iter_brdf] = 10 ** ds['log10_chl']
            # Check if convergence is reached |chl_old-chl_new| < epsilon * chl_new
            ds['convergeFlag'] = (ds['convergeFlag']) | (
                (np.abs(chl_iter[iter_brdf - 1] - chl_iter[iter_brdf]) < float(BRDF_model.OC4MEepsilon) * chl_iter[
                    iter_brdf]))
        # Apply forward model in both the actual and normalized geometries
        forward_mod = BRDF_model.forward(ds)
        forward_mod0 = BRDF_model.forward(ds, normalized=True)
        ratio = forward_mod0 / forward_mod
        # Drop remnant coordinates to avoid ambiguities in the update of the BRDF factor.
        for coord in ratio.coords:
            if coord not in ds['C_brdf'].coords:
                ratio = ratio.drop(coord)
        # Update the factor only for pixels that have not yet converged
        ds['C_brdf'] = xr.where(ds['convergeFlag'], ds['C_brdf'], ratio)
        ds['nrrs'] = ds['Rw'] / np.pi * ds['C_brdf']
    # Flag BRDF where NaN and set to 1 (no correction applied).
    ds['C_brdf_fail'] = np.isnan(ds['C_brdf'])
    ds['C_brdf'] = xr.where(ds['C_brdf_fail'], 1, ds['C_brdf'])
    ds['nrrs'] = xr.where(ds['C_brdf_fail'], ds['Rw'] / np.pi, ds['nrrs'])
    # If QAA_fail is raised, raise C_brdf_fail (but still apply C_brdf).
    if 'QAA_fail' in ds:
        ds['C_brdf_fail'] = (ds['C_brdf_fail']) | (ds['QAA_fail'])
    # Compute uncertainty only if requested
    if compute_uncertainty:
        ds = brdf_uncertainty(ds)
    # Compute flag (placeholder: all zeros)
    ds['flags_level2'] = ds['Rw'] * 0  # TODO
    # Convert to reflectance unit
    ds['rho_ex_w'] = ds['nrrs'] * np.pi
    # Expand squeezed trivial dimensions back to the input layout
    for dim, d0 in squeezedDims.items():
        ds = ds.expand_dims(dim, axis=d0)
    return ds
def run_brdf_correction(
    hyperspectral_file,
    angle_file,
    mask_file,
    wavelengths=None,
    brdf_model='L11',
    chunk_size=4096,
    output_var=None,  # if None, keep full output (original behaviour)
    compute_uncertainty=True,  # compute uncertainty only if requested
    progress_bar=True,  # whether to display a progress bar
):
    """BRDF correction pipeline for water pixels in a hyperspectral scene.

    In-memory variant: loads the whole scene, corrects water pixels in
    chunks of *chunk_size* via brdf_prototype, and assembles a full-scene
    xarray Dataset.  When *output_var* is given, only that variable (plus
    Rw, geometry and the water mask) is allocated and returned, which
    saves memory.  Raises ValueError on shape mismatches or an unknown
    *output_var*, and RuntimeError when the mask contains no water.
    """
    # ------------------------------------------------------------------
    # 1. Read all inputs
    # ------------------------------------------------------------------
    rw_data, wvl, _img_meta = read_bsq(hyperspectral_file, wavelengths, dtype=np.float32)
    angles = read_bip_angles(angle_file, dtype=np.float32)
    water_mask = read_water_mask(mask_file)
    n_bands, n_rows, n_cols = rw_data.shape
    # Sanity-check spatial dimensions
    if angles['sza'].shape != (n_rows, n_cols):
        raise ValueError(
            f"Angle file spatial shape {angles['sza'].shape} does not match "
            f"hyperspectral image shape ({n_rows}, {n_cols})."
        )
    if water_mask.shape != (n_rows, n_cols):
        raise ValueError(
            f"Mask shape {water_mask.shape} does not match "
            f"hyperspectral image shape ({n_rows}, {n_cols})."
        )
    # ------------------------------------------------------------------
    # 2. Identify water pixels
    # ------------------------------------------------------------------
    water_rows, water_cols = np.where(water_mask == 1)
    n_water = water_rows.size
    if n_water == 0:
        raise RuntimeError("No water pixels found in the mask (no pixels == 1).")
    # ------------------------------------------------------------------
    # 3. Determine which output variable(s) we need to keep
    # ------------------------------------------------------------------
    # If output_var is None, we keep everything (original behaviour)
    keep_all = (output_var is None)
    # Map output_var to the field in ds_corr
    var_map = {
        'Rw_brdf': 'rho_ex_w',
        'rho_ex_w': 'rho_ex_w',
        'nrrs': 'nrrs',
        'C_brdf': 'C_brdf'
    }
    if not keep_all:
        if output_var not in var_map:
            raise ValueError(f"Unknown output_var '{output_var}'. Allowed: {list(var_map.keys())}")
        field_name = var_map[output_var]
    else:
        field_name = None
    # ------------------------------------------------------------------
    # 4. Allocate output arrays (only those needed)
    # ------------------------------------------------------------------
    shape3 = (n_bands, n_rows, n_cols)
    shape2 = (n_rows, n_cols)
    if keep_all:
        # Original behaviour: allocate all arrays
        nrrs_full = np.full(shape3, np.nan, dtype=np.float32)
        rho_ex_w_full = np.full(shape3, np.nan, dtype=np.float32)
        C_brdf_full = np.full(shape3, np.nan, dtype=np.float32)
        brdf_unc_full = np.full(shape3, np.nan, dtype=np.float32)
        nrrs_unc_full = np.full(shape3, np.nan, dtype=np.float32)
        C_brdf_fail_full = np.zeros(shape2, dtype=bool)
        rw_brdf_full = rw_data.copy()
    else:
        # Only allocate the single requested output array (memory saving);
        # no auxiliary fail/uncertainty arrays are kept in this mode.
        output_full = np.full(shape3, np.nan, dtype=np.float32)
    # ------------------------------------------------------------------
    # 5-6. Chunked BRDF correction
    # ------------------------------------------------------------------
    n_chunks = int(np.ceil(n_water / chunk_size))
    # Initialise progress bar
    iterator = range(n_chunks)
    if progress_bar and tqdm is not None:
        iterator = tqdm(iterator, desc="Processing chunks", unit="chunk")
    elif progress_bar:
        print("Warning: tqdm not installed, progress bar disabled.")
    for chunk_idx in iterator:
        lo = chunk_idx * chunk_size
        hi = min(lo + chunk_size, n_water)
        rows_c = water_rows[lo:hi]
        cols_c = water_cols[lo:hi]
        # Extract batch vectors for this chunk of water pixels
        rw_c = rw_data[:, rows_c, cols_c].T  # (chunk, bands)
        sza_c = angles['sza'][rows_c, cols_c]  # (chunk,)
        vza_c = angles['vza'][rows_c, cols_c]
        raa_c = angles['raa'][rows_c, cols_c]
        pixel_idx_c = np.arange(hi - lo)
        ds_chunk = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_c, dims=('n', 'bands'),
                                   coords={'n': pixel_idx_c, 'bands': wvl}),
                'sza': xr.DataArray(sza_c, dims='n',
                                    coords={'n': pixel_idx_c}),
                'vza': xr.DataArray(vza_c, dims='n',
                                    coords={'n': pixel_idx_c}),
                'raa': xr.DataArray(raa_c, dims='n',
                                    coords={'n': pixel_idx_c}),
            }
        )
        # Run BRDF model with uncertainty control
        ds_corr = brdf_prototype(ds_chunk, brdf_model=brdf_model, compute_uncertainty=compute_uncertainty)
        if keep_all:
            # Write all results back
            nrrs_full[:, rows_c, cols_c] = ds_corr['nrrs'].values.T
            rho_ex_w_full[:, rows_c, cols_c] = ds_corr['rho_ex_w'].values.T
            C_brdf_full[:, rows_c, cols_c] = ds_corr['C_brdf'].values.T
            if compute_uncertainty:
                brdf_unc_full[:, rows_c, cols_c] = ds_corr['brdf_unc'].values.T
                nrrs_unc_full[:, rows_c, cols_c] = ds_corr['nrrs_unc'].values.T
            C_brdf_fail_full[rows_c, cols_c] = ds_corr['C_brdf_fail'].values
            rw_brdf_full[:, rows_c, cols_c] = ds_corr['rho_ex_w'].values.T
        else:
            # Only write the requested output field
            values = ds_corr[field_name].values.T
            output_full[:, rows_c, cols_c] = values
        # Release batch objects so GC can reclaim memory
        del rw_c, sza_c, vza_c, raa_c, pixel_idx_c, ds_chunk, ds_corr
        gc.collect()
    # ------------------------------------------------------------------
    # 7. Assemble full-scene output Dataset
    # ------------------------------------------------------------------
    full_coords = {'bands': wvl, 'y': np.arange(n_rows), 'x': np.arange(n_cols)}
    dims3 = ('bands', 'y', 'x')
    dims2 = ('y', 'x')
    if keep_all:
        # Original behaviour: return all variables
        ds_out = xr.Dataset(
            {
                'Rw': xr.DataArray(rw_data, dims=dims3, coords=full_coords),
                'Rw_brdf': xr.DataArray(rw_brdf_full, dims=dims3, coords=full_coords),
                'nrrs': xr.DataArray(nrrs_full, dims=dims3, coords=full_coords),
                'rho_ex_w': xr.DataArray(rho_ex_w_full, dims=dims3, coords=full_coords),
                'C_brdf': xr.DataArray(C_brdf_full, dims=dims3, coords=full_coords),
                'C_brdf_fail': xr.DataArray(C_brdf_fail_full, dims=dims2),
                'water_mask': xr.DataArray(water_mask, dims=dims2),
                'sza': xr.DataArray(angles['sza'], dims=dims2),
                'saa': xr.DataArray(angles['saa'], dims=dims2),
                'vza': xr.DataArray(angles['vza'], dims=dims2),
                'vaa': xr.DataArray(angles['vaa'], dims=dims2),
                'raa': xr.DataArray(angles['raa'], dims=dims2),
            }
        )
        if compute_uncertainty:
            ds_out['brdf_unc'] = xr.DataArray(brdf_unc_full, dims=dims3, coords=full_coords)
            ds_out['nrrs_unc'] = xr.DataArray(nrrs_unc_full, dims=dims3, coords=full_coords)
    else:
        # Return only the requested output variable + geometry + mask
        ds_out = xr.Dataset(
            {
                output_var: xr.DataArray(output_full, dims=dims3, coords=full_coords),
                'Rw': xr.DataArray(rw_data, dims=dims3, coords=full_coords),
                'water_mask': xr.DataArray(water_mask, dims=dims2),
                'sza': xr.DataArray(angles['sza'], dims=dims2),
                'saa': xr.DataArray(angles['saa'], dims=dims2),
                'vza': xr.DataArray(angles['vza'], dims=dims2),
                'vaa': xr.DataArray(angles['vaa'], dims=dims2),
                'raa': xr.DataArray(angles['raa'], dims=dims2),
            }
        )
        # Uncertainties are intentionally omitted in this memory-saving mode.
    return ds_out
def process_brdf_files(
    hyperspectral_file,
    angle_file,
    mask_file,
    output_file,
    wavelengths=None,
    brdf_model="L11",
    output_vars=None,
    output_format="ENVI",
    chunk_size=4096,
    block_size=512,
    compute_uncertainty=False,
    progress_bar=True,
):
    """Run block-wise BRDF correction on the three input files.

    Normalises *output_vars* to a list, delegates to
    run_brdf_correction_block, and returns ``(None, saved_paths)`` where
    *saved_paths* lists the files written (one per variable).  Only the
    ENVI output format is supported.
    """
    if output_format.upper() != "ENVI":
        raise NotImplementedError("Only ENVI output format is supported in block-wise version.")
    if isinstance(output_vars, str):
        requested = [output_vars]
    elif output_vars is None:
        requested = ["Rw_brdf"]
    else:
        requested = output_vars
    print(f"Computing BRDF correction with variables: {requested}")
    print(f"Output base name: {output_file}")
    saved_paths = run_brdf_correction_block(
        hyperspectral_file=hyperspectral_file,
        angle_file=angle_file,
        mask_file=mask_file,
        output_file=output_file,
        wavelengths=wavelengths,
        brdf_model=brdf_model,
        output_vars=requested,
        block_size=block_size,
        compute_uncertainty=compute_uncertainty,
        progress_bar=progress_bar,
    )
    return None, saved_paths
def main():
    """CLI entry point: parse arguments and run the BRDF pipeline."""
    import argparse

    cli = argparse.ArgumentParser(description="Water-region BRDF correction for hyperspectral data.")
    cli.add_argument("hyperspectral_file", help="Input ENVI BSQ hyperspectral file or .hdr path.")
    cli.add_argument("angle_file", help="Input ENVI BIP angle file or .hdr path.")
    cli.add_argument("mask_file", help="Input water mask GeoTIFF; pixels == 1 are corrected.")
    cli.add_argument("output_file", help="Output path prefix. ENVI writes .hdr/.img, NetCDF writes .nc.")
    cli.add_argument("--brdf-model", default="L11", choices=["L11", "M02", "M02SeaDAS", "O25"])
    cli.add_argument(
        "--output-var",
        nargs='+',
        default=["Rw_brdf"],
        help="Output variables to save (one or more). "
             "Supported: Rw_brdf, rho_ex_w, nrrs, C_brdf, brdf_unc, nrrs_unc. "
             "Example: --output-var Rw_brdf nrrs brdf_unc",
    )
    cli.add_argument("--output-format", default="ENVI", choices=["ENVI"])
    cli.add_argument(
        "--chunk-size",
        type=int,
        default=4096,
        help="Number of water pixels per processing batch (default: 4096).",
    )
    cli.add_argument(
        "--block-size",
        type=int,
        default=512,
        help="Spatial block size for block-wise processing (default: 512).",
    )
    cli.add_argument(
        "--compute-uncertainty",
        action="store_true",
        help="Compute BRDF uncertainty (may increase runtime and output size).",
    )
    cli.add_argument(
        "--no-progress-bar",
        action="store_true",
        help="Disable progress bar (useful for non-interactive environments).",
    )
    opts = cli.parse_args()
    _ds_out, saved = process_brdf_files(
        hyperspectral_file=opts.hyperspectral_file,
        angle_file=opts.angle_file,
        mask_file=opts.mask_file,
        output_file=opts.output_file,
        brdf_model=opts.brdf_model,
        output_vars=opts.output_var,
        output_format=opts.output_format,
        chunk_size=opts.chunk_size,
        block_size=opts.block_size,
        compute_uncertainty=opts.compute_uncertainty,
        progress_bar=not opts.no_progress_bar,
    )
    if isinstance(saved, list):
        print(f"BRDF correction finished. Saved {len(saved)} files:")
        for path in saved:
            print(f" - {path}")
    else:
        print(f"BRDF correction finished. Saved to: {saved}")


if __name__ == "__main__":
    main()