"""Command-line driver for topographic, BRDF and glint image correction.

The script can be started in two ways:

* ``script.py config.json`` - legacy mode, the whole configuration is read
  from a JSON file.
* ``script.py INPUT_DIR ANC_DIR --output-dir DIR [options]`` - the
  configuration is assembled from command-line arguments (see
  ``build_config_from_args``).
"""

import argparse
import copy
import json
import os
import sys
import time
import warnings

import numpy as np
import ray

import Flexbrdf.hytools as ht
from Flexbrdf.hytools.io.envi import *  # provides WriteENVI used below
from Flexbrdf.hytools.topo import calc_topo_coeffs
from Flexbrdf.hytools.brdf import calc_brdf_coeffs
from Flexbrdf.hytools.glint import set_glint_parameters
from Flexbrdf.hytools.masks import mask_create

warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')

# File extensions treated as input images when scanning a directory.
IMAGE_EXTENSIONS = ('.tif', '.tiff', '.envi', '.img', '.dat')

# Default configuration values.
# NOTE(review): 0.03490658503988659 equals radians(2); presumably a 2 degree
# minimum sensor zenith angle - confirm the unit against the mask code.
DEFAULT_CONFIG = {
    'file_type': 'envi',
    'num_cpus': 16,
    'bad_bands': [],
    'corrections': [],
    'resample': False,
    'resampler': {'type': 'cubic', 'out_waves': []},
    'export': {
        'coeffs': True,
        'image': True,
        'masks': False,
        'subset_waves': [],
        'output_dir': './output/',
        'suffix': 'BRDF',
    },
    'brdf': {
        'type': 'flex',
        'grouped': True,
        'geometric': 'li_dense_r',
        'volume': 'ross_thick',
        'b/r': 2.5,
        'h/b': 2.0,
        'sample_perc': 0.1,
        'interp_kind': 'linear',
        'calc_mask': [
            ['water', {'band_1': 850, 'band_2': 660, "threshold": 290}],
            ['kernel_finite', {}],
            ['ancillary', {'name': 'sensor_zn',
                           'min': 0.03490658503988659,
                           'max': 'inf'}],
        ],
        'apply_mask': [
            ['water', {'band_1': 850, 'band_2': 660, "threshold": 290}],
        ],
        'bin_type': 'dynamic',
        'num_bins': 18,
        'ndvi_bin_min': 0.05,
        'ndvi_bin_max': 1.0,
        'ndvi_perc_min': 10,
        'ndvi_perc_max': 95,
        'solar_zn_type': 'scene',
    },
}

# Defaults for the optional corrections.  They are kept out of DEFAULT_CONFIG
# so that a configuration only contains a 'topo' / 'glint' section when that
# correction was actually requested.
DEFAULT_TOPO_CONFIG = {
    'type': 'scs+c',
    'calc_mask': [
        ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
        ['kernel_finite', {}],
        ['ancillary', {'name': 'sensor_zn',
                       'min': 0.03490658503988659,
                       'max': 'inf'}],
    ],
    'apply_mask': [
        ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
    ],
    'sample_perc': 0.1,
    'subgrouped': False,
    'subgroup': {},
}

DEFAULT_GLINT_CONFIG = {
    'type': 'hochberg',
    'correction_band': 560,
    'deep_water_sample': {},
    'calc_mask': [
        ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
        ['kernel_finite', {}],
    ],
    'apply_mask': [
        ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
    ],
}


def _discover_input_files(input_path):
    """Return the image files found at ``input_path`` (directory or file)."""
    if os.path.isdir(input_path):
        # Sorted so the processing order is reproducible across platforms.
        return [os.path.join(input_path, name)
                for name in sorted(os.listdir(input_path))
                if name.lower().endswith(IMAGE_EXTENSIONS)]
    if os.path.isfile(input_path):
        return [input_path]
    return []


def _find_ancillary(input_file, anc_dir):
    """Locate the ancillary data belonging to ``input_file``.

    Returns the per-dataset ``[path, band]`` mapping when a matching angles
    file exists, the path of a ``<name>_obs`` file as a fallback, or ``None``
    when neither is present in ``anc_dir``.
    """
    base_name = os.path.splitext(os.path.basename(input_file))[0]

    # The angles file name is derived from the first nine '_'-separated
    # fields of the image name, e.g.
    #   2025_9_2_3_53_45_202592_35252_0_rad_geo_corrected_reflectance.dat
    #   -> 2025_9_2_3_53_45_202592_35252_0_rad_rgbxyz_geo_angles_registered.bip
    # NOTE(review): the original comment spelled the suffix
    # "geo_registered_angles" while the code builds "geo_angles_registered";
    # the code's spelling is kept - confirm against the real file names.
    parts = base_name.split('_')
    if len(parts) >= 9:
        anc_base = '_'.join(parts[:9]) + "_rad_rgbxyz_geo_angles_registered"
        # normpath yields the native separator on every platform; forcing
        # backslashes by hand made the lookup fail outside Windows.
        anc_path = os.path.normpath(os.path.join(anc_dir, anc_base + ".bip"))
        if os.path.exists(anc_path):
            # Only the four angle datasets use a dedicated band; every other
            # entry points at band 0.
            return {
                "path_length": [anc_path, 0],
                "sensor_az": [anc_path, 9],
                "sensor_zn": [anc_path, 8],
                "solar_az": [anc_path, 7],
                "solar_zn": [anc_path, 6],
                "phase": [anc_path, 0],
                "slope": [anc_path, 0],
                "aspect": [anc_path, 0],
                "cosine_i": [anc_path, 0],
                "utc_time": [anc_path, 0],
            }

    # No angles file found: fall back to the default obs file.
    obs_path = os.path.join(anc_dir, base_name + "_obs")
    if os.path.exists(obs_path):
        return obs_path
    return None


def build_config_from_args(args):
    """Build the configuration dictionary from command-line arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments as produced by the parser created in ``main``.

    Returns
    -------
    dict
        Configuration with the same layout as the legacy JSON files.
        ``anc_files`` only has entries for images whose ancillary data
        was found.

    Raises
    ------
    ValueError
        If ``args.input_dir`` contains no supported image file.
    """
    # Discover the input files automatically.
    input_files = _discover_input_files(args.input_dir)
    if not input_files:
        raise ValueError(f"No input files found in {args.input_dir}")

    # Generate anc_files automatically.
    anc_files = {}
    for input_file in input_files:
        ancillary = _find_ancillary(input_file, args.anc_dir)
        if ancillary is not None:
            anc_files[input_file] = ancillary

    corrections = args.corrections if args.corrections else []

    # Deep copy: the nested default dicts/lists must never be shared with
    # (and mutated through) the returned configuration.
    config_dict = copy.deepcopy(DEFAULT_CONFIG)
    config_dict.update({
        'input_files': input_files,
        'anc_files': anc_files,
        'file_type': 'envi',  # all inputs are assumed to be ENVI files
        'num_cpus': args.num_cpus,
        'bad_bands': args.bad_bands if args.bad_bands else [],
        'corrections': corrections,
        'export': {
            'coeffs': args.export_coeffs,
            'image': args.export_image,
            'masks': args.export_masks,
            'subset_waves': args.subset_waves if args.subset_waves else [],
            'output_dir': args.output_dir,
            'suffix': args.suffix,
        },
    })

    # Update the BRDF section from the command line.
    if 'brdf' in corrections:
        brdf_config = config_dict['brdf']
        if getattr(args, 'brdf_type', None):
            brdf_config['type'] = args.brdf_type
        if getattr(args, 'grouped', None) is not None:
            brdf_config['grouped'] = args.grouped
        if getattr(args, 'geometric', None):
            brdf_config['geometric'] = args.geometric
        if getattr(args, 'volume', None):
            brdf_config['volume'] = args.volume
        if getattr(args, 'num_bins', None):
            brdf_config['num_bins'] = args.num_bins

    # Update the TOPO section from the command line.  DEFAULT_CONFIG has no
    # 'topo' entry, so the dedicated defaults are used as the starting point.
    if 'topo' in corrections:
        topo_config = copy.deepcopy(
            config_dict.get('topo', DEFAULT_TOPO_CONFIG))
        if getattr(args, 'topo_type', None):
            topo_config['type'] = args.topo_type
        config_dict['topo'] = topo_config

    # Update the glint section from the command line (same reasoning).
    if 'glint' in corrections:
        glint_config = copy.deepcopy(
            config_dict.get('glint', DEFAULT_GLINT_CONFIG))
        if getattr(args, 'glint_type', None):
            glint_config['type'] = args.glint_type
        config_dict['glint'] = glint_config

    return config_dict


def _build_parser():
    """Create the argument parser for the command-line mode."""
    parser = argparse.ArgumentParser(
        description="High-resolution image correction with automatic "
                    "configuration")

    # Required arguments
    parser.add_argument('input_dir',
                        help='Input directory containing image files or '
                             'single image file path')
    parser.add_argument('anc_dir',
                        help='Ancillary directory containing angle/obs files')
    parser.add_argument('--output-dir', required=True,
                        help='Output directory for corrected images')

    # Optional arguments - basic settings
    parser.add_argument('--num-cpus', type=int, default=16,
                        help='Number of CPUs to use (default: 16)')
    parser.add_argument('--bad-bands', nargs='*', type=int, default=[],
                        help='Bad band indices to exclude')

    # Correction types
    parser.add_argument('--corrections', nargs='*',
                        choices=['topo', 'brdf', 'glint'], default=['brdf'],
                        help='Correction types to apply (default: brdf)')

    # BRDF parameters
    parser.add_argument('--brdf-type', choices=['universal', 'flex'],
                        default='flex',
                        help='BRDF correction type (default: flex)')
    parser.add_argument('--grouped', action='store_true', default=True,
                        help='Group images for BRDF correction '
                             '(default: True)')
    parser.add_argument('--no-grouped', action='store_false', dest='grouped',
                        help='Do not group images for BRDF correction')
    parser.add_argument('--geometric', default='li_dense_r',
                        choices=['li_sparse', 'li_dense', 'li_dense_r',
                                 'roujean'],
                        help='Geometric kernel type (default: li_dense_r)')
    parser.add_argument('--volume', default='ross_thick',
                        choices=['ross_thin', 'ross_thick', 'hotspot',
                                 'roujean'],
                        help='Volume kernel type (default: ross_thick)')
    parser.add_argument('--num-bins', type=int, default=18,
                        help='Number of NDVI bins for FlexBRDF (default: 18)')

    # TOPO parameters
    parser.add_argument('--topo-type', default='scs+c',
                        choices=['scs', 'scs+c', 'c', 'cosine',
                                 'mod_minneart'],
                        help='TOPO correction type (default: scs+c)')

    # Glint parameters
    parser.add_argument('--glint-type', default='hochberg',
                        choices=['hochberg', 'gao', 'hedley'],
                        help='Glint correction type (default: hochberg)')

    # Output options.  Every export defaults to on, so each one gets a
    # --no-* switch; without it the option could never be disabled.
    parser.add_argument('--export-coeffs', action='store_true', default=True,
                        help='Export correction coefficients (default: True)')
    parser.add_argument('--no-export-coeffs', action='store_false',
                        dest='export_coeffs',
                        help='Do not export correction coefficients')
    parser.add_argument('--export-image', action='store_true', default=True,
                        help='Export corrected images (default: True)')
    parser.add_argument('--no-export-image', action='store_false',
                        dest='export_image',
                        help='Do not export corrected images')
    parser.add_argument('--export-masks', action='store_true', default=True,
                        help='Export correction masks (default: True)')
    parser.add_argument('--no-export-masks', action='store_false',
                        dest='export_masks',
                        help='Do not export correction masks')
    parser.add_argument('--subset-waves', nargs='*', type=float, default=[],
                        help='Subset of wavelengths to export '
                             '(empty for all)')
    parser.add_argument('--suffix', default='corrected',
                        help='Suffix for output files (default: corrected)')
    parser.add_argument('--output-config', type=str, default=None,
                        help='Output current configuration to JSON file')
    return parser


def _write_config(config_dict, path):
    """Write ``config_dict`` to ``path`` as JSON using the legacy key order."""
    print(f"Outputting configuration to {path}...")
    # Re-order the keys to match the expected layout of the config files.
    ordered_config = {
        "bad_bands": config_dict.get("bad_bands", []),
        "file_type": config_dict.get("file_type", "envi"),
        "input_files": config_dict.get("input_files", []),
        "anc_files": config_dict.get("anc_files", {}),
        "num_cpus": config_dict.get("num_cpus", 4),
        "export": config_dict.get("export", {}),
        "corrections": config_dict.get("corrections", []),
        "brdf": config_dict.get("brdf", {}),
        "topo": config_dict.get("topo", {}),
        "glint": config_dict.get("glint", {}),
        "resample": config_dict.get("resample", False),
        "resampler": config_dict.get("resampler",
                                     {"type": "cubic", "out_waves": []}),
    }
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(ordered_config, f, indent=2, ensure_ascii=False)
    print("Configuration output complete.")


def main():
    """Entry point: build the configuration and run the correction workflow.

    Raises
    ------
    ValueError
        If no input file is found, or if an ENVI input image has no
        ancillary entry in the configuration.
    """
    parser = _build_parser()

    # Backward compatibility: a single argument ending in '.json' selects the
    # legacy mode where the whole configuration comes from that file.
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        # utf-8 matches the encoding used when a config is written out.
        with open(sys.argv[1], 'r', encoding='utf-8') as infile:
            config_dict = json.load(infile)
    else:
        args = parser.parse_args()
        config_dict = build_config_from_args(args)

        # With --output-config the configuration is only written out and the
        # program exits without processing any image.
        if args.output_config:
            _write_config(config_dict, args.output_config)
            return

    print("Starting image correction process...")
    total_start = time.time()

    images = config_dict["input_files"]

    # Fail early, before any worker is started, when ancillary data is
    # missing; otherwise the load step dies with a bare KeyError.
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        missing = [image for image in images if image not in anc_files]
        if missing:
            raise ValueError(
                "No ancillary file found for: %s" % ', '.join(missing))

    export_dict = config_dict['export']
    if (export_dict['coeffs'] or export_dict['image']) \
            and export_dict['output_dir']:
        os.makedirs(export_dict['output_dir'], exist_ok=True)

    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])

    try:
        HyTools = ray.remote(ht.HyTools)
        # One actor per input image.
        actors = [HyTools.remote() for _ in images]

        print(f"Loading {len(images)} image(s)...")
        load_start = time.time()

        if config_dict['file_type'] == 'envi':
            anc_files = config_dict["anc_files"]
            _ = ray.get([a.read_file.remote(image, config_dict['file_type'],
                                            anc_files[image])
                         for a, image in zip(actors, images)])
        elif config_dict['file_type'] == 'neon':
            _ = ray.get([a.read_file.remote(image, config_dict['file_type'])
                         for a, image in zip(actors, images)])

        _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands'])
                     for a in actors])

        load_time = time.time() - load_start
        print(f"Images loaded in {load_time:.2f} seconds")

        # Apply corrections
        if config_dict["corrections"]:
            print(f"Applying {len(config_dict['corrections'])} correction(s): "
                  f"{', '.join(config_dict['corrections'])}")
            correction_start = time.time()

            for correction in config_dict["corrections"]:
                if correction == 'topo':
                    calc_topo_coeffs(actors, config_dict['topo'])
                elif correction == 'brdf':
                    calc_brdf_coeffs(actors, config_dict)
                elif correction == 'glint':
                    set_glint_parameters(actors, config_dict)

            correction_time = time.time() - correction_start
            print(f"Corrections calculated in {correction_time:.2f} seconds")

        if export_dict['coeffs'] and len(config_dict["corrections"]) > 0:
            print("Exporting correction coefficients...")
            coeff_start = time.time()
            _ = ray.get([a.do.remote(export_coeffs, export_dict)
                         for a in actors])
            coeff_time = time.time() - coeff_start
            print(f"Coefficients exported in {coeff_time:.2f} seconds")

        if export_dict['image']:
            print("Exporting corrected images...")
            export_start = time.time()
            _ = ray.get([a.do.remote(apply_corrections, config_dict)
                         for a in actors])
            export_time = time.time() - export_start
            print(f"Images exported in {export_time:.2f} seconds")

        total_time = time.time() - total_start
        print(f"Total processing time: {total_time:.2f} seconds")
    finally:
        # Release the Ray workers even when a step above fails.
        ray.shutdown()


def export_coeffs(hy_obj, export_dict):
    '''Export correction coefficients to file.

    One JSON file per correction in ``hy_obj.corrections`` is written to
    ``export_dict['output_dir']``, named
    ``<image name>_<correction>_coeffs_<suffix>.json``.  Nothing is written
    for 'glint'.
    '''
    base_name = os.path.splitext(os.path.basename(hy_obj.file_name))[0]

    for correction in hy_obj.corrections:
        # Skip glint before opening the file so that no empty JSON file is
        # left behind.
        if correction == 'glint':
            continue

        if correction == 'topo':
            corr_dict = hy_obj.topo
        else:
            corr_dict = hy_obj.brdf

        # os.path.join keeps the path valid when output_dir has no trailing
        # separator.
        coeff_file = os.path.join(
            export_dict['output_dir'],
            base_name + "_%s_coeffs_%s.json" % (correction,
                                                export_dict["suffix"]))
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)


def apply_corrections(hy_obj, config_dict):
    '''Apply correction to image and export to file.

    Writes the corrected image (all wavelengths, or only
    ``config_dict['export']['subset_waves']`` when that list is non-empty)
    and, when ``config_dict['export']['masks']`` is set and corrections are
    configured, a second file with one band per 'apply_mask' entry.
    '''
    export_dict = config_dict['export']
    base_name = os.path.splitext(os.path.basename(hy_obj.file_name))[0]

    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 12

    output_name = os.path.join(export_dict['output_dir'],
                               base_name + "_%s" % export_dict["suffix"])

    # Export all wavelengths
    if len(export_dict['subset_waves']) == 0:

        # Explicit comparison kept from the original so that only True (or 1)
        # switches the output wavelengths to the resampler's.
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves = hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths

        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves

        writer = WriteENVI(output_name, header_dict)
        iterator = hy_obj.iterate(by='line',
                                  corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line, iterator.current_line)
        writer.close()

    # Export subset of wavelengths
    else:
        waves = export_dict['subset_waves']
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves

        writer = WriteENVI(output_name, header_dict)
        for b, band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()

    # Export masks
    if export_dict['masks'] and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []

        for correction in config_dict["corrections"]:
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))

        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255

        output_name = os.path.join(
            export_dict['output_dir'],
            base_name + "_%s_mask" % export_dict["suffix"])

        writer = WriteENVI(output_name, header_dict)

        for band_num, mask in enumerate(masks):
            mask = mask.astype(int)
            # Pixels where hy_obj.mask['no_data'] is False get the ignore
            # value declared in the header.
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask, band_num)
        # Close the writer like the two image writers above.
        writer.close()

        del masks


if __name__ == "__main__":
    main()