"""Command-line tool applying TOPO/BRDF/glint corrections to imagery using HyTools and Ray."""
import json
|
||
import os
|
||
import warnings
|
||
import sys
|
||
import time
|
||
import argparse
|
||
import ray
|
||
import numpy as np
|
||
import Flexbrdf.hytools as ht
|
||
from Flexbrdf.hytools.io.envi import *
|
||
from Flexbrdf.hytools.topo import calc_topo_coeffs
|
||
from Flexbrdf.hytools.brdf import calc_brdf_coeffs
|
||
from Flexbrdf.hytools.glint import set_glint_parameters
|
||
from Flexbrdf.hytools.masks import mask_create
|
||
|
||
# Silence noisy third-party warnings and NumPy divide/invalid-value
# RuntimeWarnings: the reflectance math below routinely divides by
# masked/zero pixels, which is expected and handled via no-data values.
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
# Default configuration values.
#
# NOTE(review): the 'topo' and 'glint' sections were previously commented
# out, which made build_config_from_args() raise a KeyError whenever those
# corrections were requested on the command line (it unconditionally does
# config_dict['topo'].copy() / config_dict['glint'].copy()).  They are
# restored here with the exact values the commented-out block carried.
DEFAULT_CONFIG = {
    'file_type': 'envi',
    'num_cpus': 16,
    'bad_bands': [],
    'corrections': [],
    'resample': False,
    'resampler': {'type': 'cubic', 'out_waves': []},
    'export': {
        'coeffs': True,
        'image': True,
        'masks': False,
        'subset_waves': [],       # empty list -> export all wavelengths
        'output_dir': './output/',
        'suffix': 'BRDF'
    },
    'brdf': {
        'type': 'flex',
        'grouped': True,
        'geometric': 'li_dense_r',
        'volume': 'ross_thick',
        'b/r': 2.5,               # Li kernel shape ratios
        'h/b': 2.0,
        'sample_perc': 0.1,
        'interp_kind': 'linear',
        'calc_mask': [
            ['water', {'band_1': 850, 'band_2': 660, "threshold": 290}],
            ['kernel_finite', {}],
            ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
        ],
        'apply_mask': [
            ['water', {'band_1': 850, 'band_2': 660, "threshold": 290}]
        ],
        'bin_type': 'dynamic',
        'num_bins': 18,
        'ndvi_bin_min': 0.05,
        'ndvi_bin_max': 1.0,
        'ndvi_perc_min': 10,
        'ndvi_perc_max': 95,
        'solar_zn_type': 'scene'
    },
    'topo': {
        'type': 'scs+c',
        'calc_mask': [
            ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
            ['kernel_finite', {}],
            ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
        ],
        'apply_mask': [
            ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}]
        ],
        'sample_perc': 0.1,
        'subgrouped': False,
        'subgroup': {}
    },
    'glint': {
        'type': 'hochberg',
        'correction_band': 560,
        'deep_water_sample': {},
        'calc_mask': [
            ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
            ['kernel_finite', {}]
        ],
        'apply_mask': [
            ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}]
        ]
    }
}
def build_config_from_args(args):
    """Build a HyTools-style configuration dictionary from CLI arguments.

    Discovers input images in ``args.input_dir``, pairs each image with an
    ancillary angles file found in ``args.anc_dir``, and overlays the
    command-line options on top of ``DEFAULT_CONFIG``.

    Raises:
        ValueError: if no input image files are found.
    """
    # Discover input files: a directory is scanned for known raster
    # extensions, otherwise the argument is treated as a single file path.
    extensions = ('.tif', '.tiff', '.envi', '.img', '.dat')
    if os.path.isdir(args.input_dir):
        input_files = [os.path.join(args.input_dir, name)
                       for name in os.listdir(args.input_dir)
                       if name.lower().endswith(extensions)]
    else:
        input_files = [args.input_dir] if os.path.isfile(args.input_dir) else []

    if not input_files:
        raise ValueError(f"No input files found in {args.input_dir}")

    # Pair each image with its ancillary angles file.
    # Example image name:
    #   2025_9_2_3_53_45_202592_35252_0_rad_geo_corrected_reflectance.dat
    # Expected angles file:
    #   2025_9_2_3_53_45_202592_35252_0_rad_rgbxyz_geo_angles_registered.bip
    anc_files = {}
    for input_file in input_files:
        base_name = os.path.splitext(os.path.basename(input_file))[0]
        parts = base_name.split('_')
        if len(parts) >= 9:  # need the first nine timestamp/id fields
            anc_base = '_'.join(parts[:9]) + "_rad_rgbxyz_geo_angles_registered"
            # os.path.normpath handles separators portably.  (Previously '/'
            # was force-replaced with '\\', which corrupted POSIX paths.)
            anc_path = os.path.normpath(os.path.join(args.anc_dir, anc_base + ".bip"))

            if os.path.exists(anc_path):
                # Band indices into the angles file; datasets this sensor
                # does not provide point at band 0 as placeholders —
                # TODO confirm band layout against the angles-file writer.
                anc_files[input_file] = {
                    "path_length": [anc_path, 0],
                    "sensor_az": [anc_path, 9],
                    "sensor_zn": [anc_path, 8],
                    "solar_az": [anc_path, 7],
                    "solar_zn": [anc_path, 6],
                    "phase": [anc_path, 0],
                    "slope": [anc_path, 0],
                    "aspect": [anc_path, 0],
                    "cosine_i": [anc_path, 0],
                    "utc_time": [anc_path, 0]
                }
            else:
                # Fall back to an obs file when no matching angles file exists.
                obs_path = os.path.join(args.anc_dir, base_name + "_obs")
                if os.path.exists(obs_path):
                    anc_files[input_file] = obs_path

    # Overlay CLI options on the defaults.  A shallow copy is safe here:
    # every nested dict modified below is re-copied before mutation.
    config_dict = DEFAULT_CONFIG.copy()
    config_dict.update({
        'input_files': input_files,
        'anc_files': anc_files,
        'file_type': 'envi',  # all inputs are assumed to be ENVI format
        'num_cpus': args.num_cpus,
        'bad_bands': args.bad_bands if args.bad_bands else [],
        'corrections': args.corrections if args.corrections else [],
        'export': {
            'coeffs': args.export_coeffs,
            'image': args.export_image,
            'masks': args.export_masks,
            'subset_waves': args.subset_waves if args.subset_waves else [],
            'output_dir': args.output_dir,
            'suffix': args.suffix
        }
    })

    # Per-correction overrides.  .get() guards against a missing default
    # section so a requested correction cannot raise a KeyError here.
    if 'brdf' in args.corrections:
        brdf_config = config_dict.get('brdf', {}).copy()
        for attr, key in (('brdf_type', 'type'), ('geometric', 'geometric'),
                          ('volume', 'volume'), ('num_bins', 'num_bins')):
            value = getattr(args, attr, None)
            if value:
                brdf_config[key] = value
        if getattr(args, 'grouped', None) is not None:
            brdf_config['grouped'] = args.grouped
        config_dict['brdf'] = brdf_config

    if 'topo' in args.corrections:
        topo_config = config_dict.get('topo', {}).copy()
        if getattr(args, 'topo_type', None):
            topo_config['type'] = args.topo_type
        config_dict['topo'] = topo_config

    if 'glint' in args.corrections:
        glint_config = config_dict.get('glint', {}).copy()
        if getattr(args, 'glint_type', None):
            glint_config['type'] = args.glint_type
        config_dict['glint'] = glint_config

    return config_dict
def _build_parser():
    """Build the command-line argument parser for the correction tool."""
    parser = argparse.ArgumentParser(description="High-resolution image correction with automatic configuration")

    # Required arguments
    parser.add_argument('input_dir', help='Input directory containing image files or single image file path')
    parser.add_argument('anc_dir', help='Ancillary directory containing angle/obs files')
    parser.add_argument('--output-dir', required=True, help='Output directory for corrected images')

    # Basic settings (help text previously said "default: 4" — the actual
    # default is 16)
    parser.add_argument('--num-cpus', type=int, default=16, help='Number of CPUs to use (default: 16)')
    parser.add_argument('--bad-bands', nargs='*', type=int, default=[], help='Bad band indices to exclude')

    # Correction types
    parser.add_argument('--corrections', nargs='*', choices=['topo', 'brdf', 'glint'],
                        default=['brdf'], help='Correction types to apply (default: brdf)')

    # BRDF options
    parser.add_argument('--brdf-type', choices=['universal', 'flex'], default='flex',
                        help='BRDF correction type (default: flex)')
    parser.add_argument('--grouped', action='store_true', default=True,
                        help='Group images for BRDF correction (default: True)')
    parser.add_argument('--no-grouped', action='store_false', dest='grouped',
                        help='Do not group images for BRDF correction')
    parser.add_argument('--geometric', default='li_dense_r',
                        choices=['li_sparse', 'li_dense', 'li_dense_r', 'roujean'],
                        help='Geometric kernel type (default: li_dense_r)')
    parser.add_argument('--volume', default='ross_thick',
                        choices=['ross_thin', 'ross_thick', 'hotspot', 'roujean'],
                        help='Volume kernel type (default: ross_thick)')
    parser.add_argument('--num-bins', type=int, default=18,
                        help='Number of NDVI bins for FlexBRDF (default: 18)')

    # TOPO options
    parser.add_argument('--topo-type', default='scs+c',
                        choices=['scs', 'scs+c', 'c', 'cosine', 'mod_minneart'],
                        help='TOPO correction type (default: scs+c)')

    # Glint options
    parser.add_argument('--glint-type', default='hochberg',
                        choices=['hochberg', 'gao', 'hedley'],
                        help='Glint correction type (default: hochberg)')

    # Export options.  NOTE(review): these flags default to True AND use
    # action='store_true', so they can never be switched off from the CLI;
    # kept as-is for interface compatibility.
    parser.add_argument('--export-coeffs', action='store_true', default=True,
                        help='Export correction coefficients')
    parser.add_argument('--export-image', action='store_true', default=True,
                        help='Export corrected images (default: True)')
    parser.add_argument('--export-masks', action='store_true', default=True,
                        help='Export correction masks (default: True)')
    parser.add_argument('--subset-waves', nargs='*', type=float, default=[],
                        help='Subset of wavelengths to export (empty for all)')
    parser.add_argument('--suffix', default='corrected',
                        help='Suffix for output files (default: corrected)')
    parser.add_argument('--output-config', type=str, default=None,
                        help='Output current configuration to JSON file')
    return parser


def main():
    """Parse arguments (or a legacy JSON config), then run the correction
    pipeline across Ray actors: load images, apply corrections, export
    coefficients, images and masks."""
    parser = _build_parser()

    # Backward compatibility: a single JSON-file argument loads a complete
    # configuration directly instead of building one from CLI options.
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        with open(sys.argv[1], 'r') as config_file:
            config_dict = json.load(config_file)
    else:
        args = parser.parse_args()
        config_dict = build_config_from_args(args)

        # Optionally dump the resolved configuration and exit.  (This block
        # previously sat outside the else-branch, raising a NameError on the
        # JSON path above, where `args` is never defined.)
        if args.output_config:
            print(f"Outputting configuration to {args.output_config}...")
            # Reorder keys to match the expected config-file layout.
            ordered_config = {
                "bad_bands": config_dict.get("bad_bands", []),
                "file_type": config_dict.get("file_type", "envi"),
                "input_files": config_dict.get("input_files", []),
                "anc_files": config_dict.get("anc_files", {}),
                "num_cpus": config_dict.get("num_cpus", 4),
                "export": config_dict.get("export", {}),
                "corrections": config_dict.get("corrections", []),
                "brdf": config_dict.get("brdf", {}),
                "topo": config_dict.get("topo", {}),
                "glint": config_dict.get("glint", {}),
                "resample": config_dict.get("resample", False),
                "resampler": config_dict.get("resampler", {"type": "cubic", "out_waves": []})
            }
            with open(args.output_config, 'w', encoding='utf-8') as f:
                json.dump(ordered_config, f, indent=2, ensure_ascii=False)
            print("Configuration output complete.")
            return  # config-dump mode: do not run the pipeline

    print("Starting image correction process...")
    total_start = time.time()

    images = config_dict["input_files"]

    # Make sure the output directory exists before workers try to write.
    os.makedirs(config_dict['export']['output_dir'], exist_ok=True)

    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])

    # One remote HyTools actor per image.
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for _ in images]

    print(f"Loading {len(images)} image(s)...")
    load_start = time.time()

    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        _ = ray.get([actor.read_file.remote(image, config_dict['file_type'], anc_files[image])
                     for actor, image in zip(actors, images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([actor.read_file.remote(image, config_dict['file_type'])
                     for actor, image in zip(actors, images)])

    _ = ray.get([actor.create_bad_bands.remote(config_dict['bad_bands']) for actor in actors])

    # The timing prints below were all `print(".2f")` — a literal string,
    # not a format — so no timing was ever reported.
    print(f"Loading took {time.time() - load_start:.2f}s")

    # Apply corrections
    if config_dict["corrections"]:
        print(f"Applying {len(config_dict['corrections'])} correction(s): {', '.join(config_dict['corrections'])}")
        correction_start = time.time()

        for correction in config_dict["corrections"]:
            if correction == 'topo':
                calc_topo_coeffs(actors, config_dict['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors, config_dict)
            elif correction == 'glint':
                set_glint_parameters(actors, config_dict)

        print(f"Corrections took {time.time() - correction_start:.2f}s")

    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients...")
        coeff_start = time.time()
        _ = ray.get([actor.do.remote(export_coeffs, config_dict['export']) for actor in actors])
        print(f"Coefficient export took {time.time() - coeff_start:.2f}s")

    if config_dict['export']['image']:
        print("Exporting corrected images...")
        export_start = time.time()
        _ = ray.get([actor.do.remote(apply_corrections, config_dict) for actor in actors])
        print(f"Image export took {time.time() - export_start:.2f}s")

    print(f"Total runtime: {time.time() - total_start:.2f}s")

    ray.shutdown()
def export_coeffs(hy_obj, export_dict):
    '''Export correction coefficients to file.

    Writes one JSON file per correction into ``export_dict['output_dir']``,
    named ``<image>_<correction>_coeffs_<suffix>.json``.  Glint has no
    coefficient dictionary and is skipped.
    '''
    for correction in hy_obj.corrections:
        # Skip glint BEFORE opening the file: previously the `continue`
        # happened inside the `with open(...)` block, leaving behind an
        # empty, truncated JSON file for every glint correction.
        if correction == 'glint':
            continue

        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf

        # os.path.join is robust to a missing trailing separator on
        # output_dir (plain string concatenation was not).
        base = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file = os.path.join(
            export_dict['output_dir'],
            "%s_%s_coeffs_%s.json" % (base, correction, export_dict["suffix"]))

        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj, config_dict):
    '''Apply the configured corrections to an image and export it to file.

    Exports either all wavelengths (line-by-line, optionally resampled) or
    a wavelength subset (band-by-band), then optionally writes a mask stack.
    Runs inside a Ray actor via `hy_obj.do(...)`.

    NOTE(review): output paths are built by string concatenation, so
    export 'output_dir' must end with a path separator — confirm callers.
    '''

    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # ENVI data type 12 = unsigned 16-bit integer (per the ENVI header spec).
    header_dict['data type'] = 12

    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]

    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:

        # When resampling, the output wavelength grid comes from the
        # resampler config rather than the source image.
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths

        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves

        # Stream corrected lines straight into the ENVI writer.
        writer = WriteENVI(output_name,header_dict)
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line,iterator.current_line)
        writer.close()

    #Export subset of wavelengths
    else:
        # Map requested wavelengths to band indices, then record the actual
        # (rounded) wavelengths of those bands in the header.
        waves = config_dict['export']['subset_waves']
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves

        writer = WriteENVI(output_name,header_dict)
        for b,band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()

    #Export masks
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []

        # One mask band per (correction, apply_mask entry) pair.
        # NOTE(review): assumes config_dict[correction] defines 'apply_mask';
        # a glint config without that key would raise KeyError — confirm.
        for correction in config_dict["corrections"]:
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))

        # Rewrite the header for a byte-typed mask stack
        # (ENVI data type 1 = 8-bit unsigned byte); 255 marks no-data.
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255


        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]

        writer = WriteENVI(output_name,header_dict)

        for band_num,mask in enumerate(masks):
            mask = mask.astype(int)
            # Pixels outside the image footprint get the ignore value.
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask,band_num)

        # Free the mask stack promptly inside the long-lived actor.
        del masks
# Script entry point: only run the pipeline when executed directly.
if __name__== "__main__":
    main()