Files
BRDF/Flexbrdf/scripts/image_correct2.py
2026-04-10 16:46:45 +08:00

443 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import warnings
import sys
import time
import argparse
import ray
import numpy as np
import Flexbrdf.hytools as ht
from Flexbrdf.hytools.io.envi import *
from Flexbrdf.hytools.topo import calc_topo_coeffs
from Flexbrdf.hytools.brdf import calc_brdf_coeffs
from Flexbrdf.hytools.glint import set_glint_parameters
from Flexbrdf.hytools.masks import mask_create
# Process-wide setup: suppress all Python warnings, and tell numpy not to
# warn on floating-point division by zero or invalid operations.
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
# Default configuration values.
# build_config_from_args() starts from a copy of this dict and overrides
# entries from the command line; a JSON config file passed directly to main()
# does not use these defaults at all.
DEFAULT_CONFIG = {
    'file_type': 'envi',
    'num_cpus': 16,
    'bad_bands': [],
    'corrections': [],
    'resample': False,
    'resampler': {'type': 'cubic', 'out_waves': []},
    'export': {
        'coeffs':True,
        'image': True,
        'masks': False,
        'subset_waves': [],
        'output_dir': './output/',
        'suffix': 'BRDF'
    },
    # BRDF settings. The command line overrides only 'type', 'grouped',
    # 'geometric', 'volume' and 'num_bins'; everything else comes from here.
    'brdf': {
        'type': 'flex',
        'grouped': True,
        'geometric': 'li_dense_r',
        'volume': 'ross_thick',
        # NOTE(review): presumably the Li kernel crown-shape (b/r) and
        # height (h/b) ratios — confirm against the kernel implementation.
        'b/r': 2.5,
        'h/b': 2.0,
        # NOTE(review): presumably the fraction of pixels sampled for fitting.
        'sample_perc': 0.1,
        'interp_kind': 'linear',
        # Masks used when fitting coefficients.
        'calc_mask': [
            # NOTE(review): the meaning/units of 'threshold' (290) are defined
            # by mask_create's 'water' mask — confirm.
            ['water', {'band_1': 850, 'band_2': 660,"threshold": 290}],
            ['kernel_finite', {}],
            # 0.03490658503988659 rad == 2 degrees.
            ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
        ],
        # Masks used when applying the correction (and exported as mask bands).
        'apply_mask': [
            ['water', {'band_1': 850, 'band_2': 660,"threshold": 290}]
        ],
        'bin_type': 'dynamic',
        'num_bins': 18,
        'ndvi_bin_min': 0.05,
        'ndvi_bin_max': 1.0,
        'ndvi_perc_min': 10,
        'ndvi_perc_max': 95,
        'solar_zn_type': 'scene'
    },
    # The 'topo' and 'glint' defaults below are commented out, so this dict
    # defines no settings for those corrections.
    # 'topo': {
    # 'type': 'scs+c',
    # 'calc_mask': [
    # ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
    # ['kernel_finite', {}],
    # ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
    # ],
    # 'apply_mask': [
    # ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}]
    # ],
    # 'sample_perc': 0.1,
    # 'subgrouped': False,
    # 'subgroup': {}
    # },
    # 'glint': {
    # 'type': 'hochberg',
    # 'correction_band': 560,
    # 'deep_water_sample': {},
    # 'calc_mask': [
    # ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
    # ['kernel_finite', {}]
    # ],
    # 'apply_mask': [
    # ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}]
    # ]
    # }
}
# Extensions (case-insensitive) treated as input images when input_dir is a
# directory.
_INPUT_EXTENSIONS = ('.tif', '.tiff', '.envi', '.img', '.dat')

# Suffix of the per-image angles file: the first nine '_'-separated fields of
# the image name, followed by this string and '.bip'.
_ANGLES_SUFFIX = "_rad_rgbxyz_geo_angles_registered"

# Band index of each ancillary variable inside the angles .bip file (key order
# is preserved in exported configs). Several variables share band 0.
# NOTE(review): presumably band 0 is a placeholder because the angles file has
# no dedicated band for these variables — confirm.
_ANGLES_BANDS = {
    "path_length": 0,
    "sensor_az": 9,
    "sensor_zn": 8,
    "solar_az": 7,
    "solar_zn": 6,
    "phase": 0,
    "slope": 0,
    "aspect": 0,
    "cosine_i": 0,
    "utc_time": 0,
}

# Settings used when the topo/glint corrections are requested but
# DEFAULT_CONFIG defines no section for them (they are commented out there).
# The values mirror those commented-out defaults.
_FALLBACK_TOPO_CONFIG = {
    'type': 'scs+c',
    'calc_mask': [
        ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
        ['kernel_finite', {}],
        ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}],
    ],
    'apply_mask': [
        ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
    ],
    'sample_perc': 0.1,
    'subgrouped': False,
    'subgroup': {},
}
_FALLBACK_GLINT_CONFIG = {
    'type': 'hochberg',
    'correction_band': 560,
    'deep_water_sample': {},
    'calc_mask': [
        ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
        ['kernel_finite', {}],
    ],
    'apply_mask': [
        ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
    ],
}


def _discover_input_files(input_path):
    """Return the image files to process for *input_path*.

    A directory yields its regular files whose extension is in
    _INPUT_EXTENSIONS, sorted by name for a reproducible order. An existing
    file yields itself. Anything else yields an empty list.
    """
    if os.path.isdir(input_path):
        return [
            os.path.join(input_path, name)
            for name in sorted(os.listdir(input_path))
            if name.lower().endswith(_INPUT_EXTENSIONS)
            and os.path.isfile(os.path.join(input_path, name))
        ]
    return [input_path] if os.path.isfile(input_path) else []


def _find_ancillary(input_file, anc_dir):
    """Return the anc_files entry for *input_file*, or None if nothing is found.

    Prefers the registered angles .bip file, returned as a
    {variable: [path, band]} mapping. Otherwise falls back to
    '<image basename>_obs' in *anc_dir*, returned as a plain path.
    """
    base_name = os.path.splitext(os.path.basename(input_file))[0]
    # Example: 2025_9_2_3_53_45_202592_35252_0_rad_geo_corrected_reflectance.dat
    #   -> 2025_9_2_3_53_45_202592_35252_0_rad_rgbxyz_geo_angles_registered.bip
    # NOTE(review): an earlier comment named the angles file
    # ..._rad_rgbxyz_geo_registered_angles.bip; confirm which suffix the
    # upstream pipeline actually writes.
    parts = base_name.split('_')
    if len(parts) >= 9:
        anc_base = '_'.join(parts[:9]) + _ANGLES_SUFFIX
        # Let os.path choose the separator, so the lookup works on both
        # Windows and POSIX. Hand-converting '/' to '\\' yields invalid
        # paths on POSIX.
        anc_path = os.path.normpath(os.path.join(anc_dir, anc_base + ".bip"))
        if os.path.exists(anc_path):
            return {name: [anc_path, band] for name, band in _ANGLES_BANDS.items()}
    # No angles file found: fall back to the default obs file.
    obs_path = os.path.join(anc_dir, base_name + "_obs")
    if os.path.exists(obs_path):
        return obs_path
    return None


def build_config_from_args(args):
    """Build a configuration dict from parsed command-line arguments.

    Input images are discovered from ``args.input_dir``, which may be a
    directory or a single file. Their ancillary files are looked up in
    ``args.anc_dir``. The result starts from a deep copy of DEFAULT_CONFIG,
    with the CLI values applied on top.

    Parameters
    ----------
    args : argparse.Namespace
        Namespace produced by the script's argument parser.

    Returns
    -------
    dict
        Configuration in the same layout as a JSON config file.

    Raises
    ------
    ValueError
        If no input image files are found at ``args.input_dir``.
    """
    import copy

    input_files = _discover_input_files(args.input_dir)
    if not input_files:
        raise ValueError(f"No input files found in {args.input_dir}")

    anc_files = {}
    for input_file in input_files:
        anc_entry = _find_ancillary(input_file, args.anc_dir)
        if anc_entry is None:
            # Report instead of failing, so --output-config still works.
            # Processing an ENVI image without an entry fails later.
            print(f"Warning: no ancillary file found for {input_file}", file=sys.stderr)
        else:
            anc_files[input_file] = anc_entry

    corrections = list(args.corrections) if args.corrections else []

    # Deep copy, so nested defaults (mask lists, sub-dicts) are never shared
    # with DEFAULT_CONFIG.
    config_dict = copy.deepcopy(DEFAULT_CONFIG)
    config_dict.update({
        'input_files': input_files,
        'anc_files': anc_files,
        'file_type': 'envi',  # all discovered inputs are assumed to be ENVI
        'num_cpus': args.num_cpus,
        'bad_bands': args.bad_bands if args.bad_bands else [],
        'corrections': corrections,
        'export': {
            'coeffs': args.export_coeffs,
            'image': args.export_image,
            'masks': args.export_masks,
            'subset_waves': args.subset_waves if args.subset_waves else [],
            'output_dir': args.output_dir,
            'suffix': args.suffix,
        },
    })

    # Apply BRDF overrides from the command line.
    if 'brdf' in corrections:
        brdf_config = config_dict['brdf']
        if getattr(args, 'brdf_type', None):
            brdf_config['type'] = args.brdf_type
        if getattr(args, 'grouped', None) is not None:
            brdf_config['grouped'] = args.grouped
        if getattr(args, 'geometric', None):
            brdf_config['geometric'] = args.geometric
        if getattr(args, 'volume', None):
            brdf_config['volume'] = args.volume
        if getattr(args, 'num_bins', None):
            brdf_config['num_bins'] = args.num_bins

    # Apply TOPO overrides. DEFAULT_CONFIG may lack a 'topo' section, so fall
    # back to built-in defaults instead of raising KeyError.
    if 'topo' in corrections:
        topo_config = copy.deepcopy(config_dict.get('topo') or _FALLBACK_TOPO_CONFIG)
        if getattr(args, 'topo_type', None):
            topo_config['type'] = args.topo_type
        config_dict['topo'] = topo_config

    # Apply glint overrides, with the same fallback as TOPO.
    if 'glint' in corrections:
        glint_config = copy.deepcopy(config_dict.get('glint') or _FALLBACK_GLINT_CONFIG)
        if getattr(args, 'glint_type', None):
            glint_config['type'] = args.glint_type
        config_dict['glint'] = glint_config

    return config_dict
def _build_arg_parser():
    """Create the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser(description="High-resolution image correction with automatic configuration")
    # Required arguments
    parser.add_argument('input_dir', help='Input directory containing image files or single image file path')
    parser.add_argument('anc_dir', help='Ancillary directory containing angle/obs files')
    parser.add_argument('--output-dir', required=True, help='Output directory for corrected images')
    # Optional arguments - general settings
    parser.add_argument('--num-cpus', type=int, default=16, help='Number of CPUs to use (default: 16)')
    parser.add_argument('--bad-bands', nargs='*', type=int, default=[], help='Bad band indices to exclude')
    # Correction types
    parser.add_argument('--corrections', nargs='*', choices=['topo', 'brdf', 'glint'],
                        default=['brdf'], help='Correction types to apply (default: brdf)')
    # BRDF parameters
    parser.add_argument('--brdf-type', choices=['universal', 'flex'], default='flex',
                        help='BRDF correction type (default: flex)')
    parser.add_argument('--grouped', action='store_true', default=True,
                        help='Group images for BRDF correction (default: True)')
    parser.add_argument('--no-grouped', action='store_false', dest='grouped',
                        help='Do not group images for BRDF correction')
    parser.add_argument('--geometric', default='li_dense_r',
                        choices=['li_sparse', 'li_dense', 'li_dense_r', 'roujean'],
                        help='Geometric kernel type (default: li_dense_r)')
    parser.add_argument('--volume', default='ross_thick',
                        choices=['ross_thin', 'ross_thick', 'hotspot', 'roujean'],
                        help='Volume kernel type (default: ross_thick)')
    parser.add_argument('--num-bins', type=int, default=18,
                        help='Number of NDVI bins for FlexBRDF (default: 18)')
    # TOPO parameters
    parser.add_argument('--topo-type', default='scs+c',
                        choices=['scs', 'scs+c', 'c', 'cosine', 'mod_minneart'],
                        help='TOPO correction type (default: scs+c)')
    # Glint parameters
    parser.add_argument('--glint-type', default='hochberg',
                        choices=['hochberg', 'gao', 'hedley'],
                        help='Glint correction type (default: hochberg)')
    # Output options. Each export flag defaults to True. store_true with
    # default=True alone could never switch it off, so each has a matching
    # --no-... flag.
    parser.add_argument('--export-coeffs', action='store_true', default=True,
                        help='Export correction coefficients')
    parser.add_argument('--no-export-coeffs', action='store_false', dest='export_coeffs',
                        help='Do not export correction coefficients')
    parser.add_argument('--export-image', action='store_true', default=True,
                        help='Export corrected images (default: True)')
    parser.add_argument('--no-export-image', action='store_false', dest='export_image',
                        help='Do not export corrected images')
    parser.add_argument('--export-masks', action='store_true', default=True,
                        help='Export correction masks (default: True)')
    parser.add_argument('--no-export-masks', action='store_false', dest='export_masks',
                        help='Do not export correction masks')
    parser.add_argument('--subset-waves', nargs='*', type=float, default=[],
                        help='Subset of wavelengths to export (empty for all)')
    parser.add_argument('--suffix', default='corrected',
                        help='Suffix for output files (default: corrected)')
    parser.add_argument('--output-config', type=str, default=None,
                        help='Output current configuration to JSON file')
    return parser


def _write_config(config_dict, path):
    """Write *config_dict* to *path* as UTF-8 JSON with a fixed key order."""
    print(f"Outputting configuration to {path}...")
    # Reorder the configuration to match the layout of hand-written configs.
    ordered_config = {
        "bad_bands": config_dict.get("bad_bands", []),
        "file_type": config_dict.get("file_type", "envi"),
        "input_files": config_dict.get("input_files", []),
        "anc_files": config_dict.get("anc_files", {}),
        "num_cpus": config_dict.get("num_cpus", 4),
        "export": config_dict.get("export", {}),
        "corrections": config_dict.get("corrections", []),
        "brdf": config_dict.get("brdf", {}),
        "topo": config_dict.get("topo", {}),
        "glint": config_dict.get("glint", {}),
        "resample": config_dict.get("resample", False),
        "resampler": config_dict.get("resampler", {"type": "cubic", "out_waves": []}),
    }
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(ordered_config, f, indent=2, ensure_ascii=False)
    print("Configuration output complete.")


def main():
    """Command-line entry point: build or load a configuration, then run it.

    Two invocation styles are supported:

    * ``script.py config.json``: legacy mode. The whole configuration is
      read from the JSON file.
    * the argparse interface from _build_arg_parser(). With
      ``--output-config`` the generated configuration is written to JSON
      and no images are processed.

    Processing loads every image into its own Ray HyTools actor, fits the
    requested corrections, and exports coefficients and/or corrected images.

    Raises
    ------
    ValueError
        If the file type is unsupported, or an ENVI image has no ancillary
        entry.
    """
    parser = _build_arg_parser()
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # Backward compatibility: a single *.json argument is a complete
        # configuration file.
        with open(sys.argv[1], 'r') as config_file:
            config_dict = json.load(config_file)
    else:
        args = parser.parse_args()
        config_dict = build_config_from_args(args)
        # With --output-config, only dump the configuration and exit. This
        # check stays inside the CLI branch, because `args` does not exist
        # in legacy mode.
        if args.output_config:
            _write_config(config_dict, args.output_config)
            return

    print("Starting image correction process...")
    total_start = time.time()
    images = config_dict["input_files"]
    corrections = config_dict["corrections"]
    export = config_dict['export']
    file_type = config_dict['file_type']

    # Validate before starting Ray, so configuration errors surface early
    # and clearly.
    if file_type not in ('envi', 'neon'):
        raise ValueError(f"Unsupported file_type {file_type!r}; expected 'envi' or 'neon'")
    if file_type == 'envi':
        anc_files = config_dict["anc_files"]
        missing = [image for image in images if image not in anc_files]
        if missing:
            raise ValueError("No ancillary file configured for: " + ", ".join(missing))

    # The export helpers write into output_dir; make sure it exists.
    if export.get('output_dir'):
        os.makedirs(export['output_dir'], exist_ok=True)

    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])
    try:
        HyTools = ray.remote(ht.HyTools)
        # One remote HyTools actor per input image.
        actors = [HyTools.remote() for _ in images]

        print(f"Loading {len(images)} image(s)...")
        load_start = time.time()
        if file_type == 'envi':
            ray.get([a.read_file.remote(image, file_type, anc_files[image])
                     for a, image in zip(actors, images)])
        else:  # 'neon'
            ray.get([a.read_file.remote(image, file_type)
                     for a, image in zip(actors, images)])
        ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
        print(f"Loaded {len(images)} image(s) in {time.time() - load_start:.2f} s")

        # Fit correction coefficients.
        if corrections:
            print(f"Applying {len(corrections)} correction(s): {', '.join(corrections)}")
            correction_start = time.time()
            for correction in corrections:
                if correction == 'topo':
                    calc_topo_coeffs(actors, config_dict['topo'])
                elif correction == 'brdf':
                    calc_brdf_coeffs(actors, config_dict)
                elif correction == 'glint':
                    set_glint_parameters(actors, config_dict)
            print(f"Correction coefficients computed in {time.time() - correction_start:.2f} s")

        if export['coeffs'] and corrections:
            print("Exporting correction coefficients...")
            coeff_start = time.time()
            ray.get([a.do.remote(export_coeffs, export) for a in actors])
            print(f"Coefficients exported in {time.time() - coeff_start:.2f} s")

        if export['image']:
            print("Exporting corrected images...")
            export_start = time.time()
            ray.get([a.do.remote(apply_corrections, config_dict) for a in actors])
            print(f"Corrected images exported in {time.time() - export_start:.2f} s")

        print(f"Total processing time: {time.time() - total_start:.2f} s")
    finally:
        # Always release the Ray runtime, even if a step above failed.
        ray.shutdown()
def export_coeffs(hy_obj, export_dict):
    """Export correction coefficients to JSON files.

    One file per entry in ``hy_obj.corrections`` is written to
    ``<output_dir>/<image basename>_<correction>_coeffs_<suffix>.json``.
    'topo' writes ``hy_obj.topo``. 'glint' is skipped and no file is created
    for it. Any other correction writes ``hy_obj.brdf``.

    Parameters
    ----------
    hy_obj : object
        Image object exposing ``corrections``, ``file_name``, and the
        ``topo`` / ``brdf`` coefficient dicts that are read here.
    export_dict : dict
        Export settings; ``'output_dir'`` and ``'suffix'`` are used.
    """
    image_base = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    for correction in hy_obj.corrections:
        # Decide before opening the file, so no empty .json is left behind
        # for corrections that have nothing to export.
        if correction == 'glint':
            continue
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        # os.path.join works whether or not output_dir ends with a separator.
        coeff_file = os.path.join(
            export_dict['output_dir'],
            "%s_%s_coeffs_%s.json" % (image_base, correction, export_dict["suffix"]),
        )
        with open(coeff_file, 'w', encoding='utf-8') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj, config_dict):
    """Apply the configured corrections to one image and export the result.

    Writes an ENVI image to ``<output_dir>/<image basename>_<suffix>``:

    * When ``export.subset_waves`` is empty, every band is written line by
      line. The bands are resampled to ``resampler['out_waves']`` when
      ``resample`` is set.
    * Otherwise, only the bands returned by ``hy_obj.wave_to_band`` for the
      requested wavelengths are written.

    When ``export.masks`` is set and corrections are configured, each
    correction's ``apply_mask`` masks are also written, as byte bands, to
    ``<output_dir>/<image basename>_<suffix>_mask``.

    Parameters
    ----------
    hy_obj : HyTools object
        Image whose corrections have already been fitted or configured.
    config_dict : dict
        Full run configuration. Uses 'export', 'resample', 'resampler',
        'corrections', and each correction's own section.
    """
    export = config_dict['export']
    image_base = os.path.splitext(os.path.basename(hy_obj.file_name))[0]

    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 12
    # os.path.join works whether or not output_dir ends with a separator.
    output_name = os.path.join(export['output_dir'], "%s_%s" % (image_base, export["suffix"]))

    if len(export['subset_waves']) == 0:
        # Export all wavelengths (optionally resampled), line by line.
        if config_dict["resample"]:
            hy_obj.resampler = config_dict['resampler']
            waves = hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name, header_dict)
        try:
            iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                      resample=config_dict['resample'])
            while not iterator.complete:
                line = iterator.read_next()
                writer.write_line(line, iterator.current_line)
        finally:
            writer.close()
    else:
        # Export only the requested subset of wavelengths, band by band.
        bands = [hy_obj.wave_to_band(x) for x in export['subset_waves']]
        waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name, header_dict)
        try:
            for out_index, band_num in enumerate(bands):
                band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
                writer.write_band(band, out_index)
        finally:
            writer.close()

    # Export the apply-masks of every configured correction.
    if export['masks'] and config_dict["corrections"]:
        masks = []
        mask_names = []
        for correction in config_dict["corrections"]:
            # A missing section or missing 'apply_mask' simply contributes
            # no masks.
            for mask_type in config_dict.get(correction, {}).get('apply_mask', []):
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        if masks:
            header_dict['data type'] = 1
            header_dict['bands'] = len(masks)
            header_dict['band names'] = mask_names
            header_dict['samples'] = hy_obj.columns
            header_dict['lines'] = hy_obj.lines
            header_dict['wavelength'] = []
            header_dict['fwhm'] = []
            header_dict['wavelength units'] = ''
            header_dict['data ignore value'] = 255
            mask_name = os.path.join(export['output_dir'],
                                     "%s_%s_mask" % (image_base, export["suffix"]))
            writer = WriteENVI(mask_name, header_dict)
            try:
                for band_num, mask in enumerate(masks):
                    mask = mask.astype(int)
                    # Pixels where hy_obj.mask['no_data'] is False get the
                    # ignore value 255 written in the header above.
                    mask[~hy_obj.mask['no_data']] = 255
                    writer.write_band(mask, band_num)
            finally:
                # Close the mask writer, as the image writers above do.
                writer.close()
        del masks
# Run the command-line interface only when executed as a script, not on import.
if __name__== "__main__":
    main()