import json
import os
import sys
import time
import warnings
from typing import Optional

import ray
import numpy as np

import hytools as ht
from hytools.io.envi import *  # noqa: F401,F403  (supplies WriteENVI used below)
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create

# Process-wide settings: silence all Python warnings and numpy
# divide-by-zero / invalid-value errors for every importer of this module.
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')

# Example config used by main() when no path is given on the command line.
DEFAULT_CONFIG_PATH = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"


# ---------- Module-level functions: shipped to Ray actors via HyTools.do ----------

def _output_prefix(hy_obj, export_dict):
    """Return ``export_dict['output_dir']`` followed by the input file's stem.

    The directory is concatenated verbatim (not ``os.path.join``-ed), so
    ``output_dir`` is expected to end with a path separator.
    """
    stem = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    return export_dict['output_dir'] + stem


def export_coeffs(hy_obj, export_dict):
    """Export correction coefficients to JSON files.

    One file is written per entry in ``hy_obj.corrections``, named
    ``<output_dir><stem>_<correction>_coeffs_<suffix>.json``:
    ``'topo'`` dumps ``hy_obj.topo``, any other correction except
    ``'glint'`` dumps ``hy_obj.brdf``. ``'glint'`` has nothing to export
    and is skipped without touching the filesystem.

    :param hy_obj: HyTools object whose correction coefficients are set.
    :param export_dict: The ``export`` section of the config; reads
        ``output_dir`` and ``suffix``.
    """
    for correction in hy_obj.corrections:
        if correction == 'glint':
            # Checked BEFORE opening the file so no empty JSON is left behind.
            continue
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        coeff_file = _output_prefix(hy_obj, export_dict)
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)


def _write_all_bands(hy_obj, config_dict, header_dict, output_name):
    """Write every band (resampled when ``config_dict['resample']``) line by line."""
    resample = config_dict["resample"]
    if resample:
        hy_obj.resampler = config_dict['resampler']
        waves = hy_obj.resampler['out_waves']
    else:
        waves = hy_obj.wavelengths
    header_dict['bands'] = len(waves)
    header_dict['wavelength'] = waves

    writer = WriteENVI(output_name, header_dict)
    try:
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=resample)
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line, iterator.current_line)
    finally:
        writer.close()


def _write_band_subset(hy_obj, subset_waves, header_dict, output_name):
    """Write only the bands that ``hy_obj.wave_to_band`` maps *subset_waves* to."""
    bands = [hy_obj.wave_to_band(wave) for wave in subset_waves]
    header_dict['bands'] = len(bands)
    # The header records the image's own wavelengths for the chosen bands,
    # rounded to 2 decimals, rather than the requested values.
    header_dict['wavelength'] = [round(hy_obj.wavelengths[b], 2) for b in bands]

    writer = WriteENVI(output_name, header_dict)
    try:
        for out_index, band_num in enumerate(bands):
            band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
            writer.write_band(band, out_index)
    finally:
        writer.close()


def _write_masks(hy_obj, config_dict, header_dict):
    """Write one byte band per (correction, mask type) to ``<prefix>_<suffix>_mask``."""
    export_cfg = config_dict['export']
    masks = []
    mask_names = []
    for correction in config_dict["corrections"]:
        for mask_type in config_dict[correction]['apply_mask']:
            mask_names.append(correction + '_' + mask_type[0])
            masks.append(mask_create(hy_obj, [mask_type]))

    header_dict['data type'] = 1
    header_dict['bands'] = len(masks)
    header_dict['band names'] = mask_names
    header_dict['samples'] = hy_obj.columns
    header_dict['lines'] = hy_obj.lines
    header_dict['wavelength'] = []
    header_dict['fwhm'] = []
    header_dict['wavelength units'] = ''
    header_dict['data ignore value'] = 255

    output_name = _output_prefix(hy_obj, export_cfg)
    output_name += "_%s_mask" % export_cfg["suffix"]

    writer = WriteENVI(output_name, header_dict)
    try:
        for band_num, mask in enumerate(masks):
            mask = mask.astype(int)
            # Pixels where hy_obj.mask['no_data'] is False get the ignore value.
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask, band_num)
    finally:
        # Previously missing: the mask writer was never closed.
        writer.close()


def apply_corrections(hy_obj, config_dict):
    """Apply the configured corrections to an image and export it as ENVI.

    Writes ``<output_dir><stem>_<suffix>`` containing either all bands
    (when ``export['subset_waves']`` is empty) or only the requested subset.
    If ``export['masks']`` is truthy and corrections are configured, the
    masks used by each correction are written to a ``..._mask`` file.

    :param hy_obj: HyTools object with corrections already configured.
    :param config_dict: Full run configuration.
    """
    export_cfg = config_dict['export']
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # NOTE(review): ENVI data type 12 is unsigned 16-bit integer; confirm the
    # corrected values are meant to be stored in that type.
    header_dict['data type'] = 12

    output_name = _output_prefix(hy_obj, export_cfg)
    output_name += "_%s" % export_cfg["suffix"]

    if len(export_cfg['subset_waves']) == 0:
        _write_all_bands(hy_obj, config_dict, header_dict, output_name)
    else:
        _write_band_subset(hy_obj, export_cfg['subset_waves'], header_dict, output_name)

    if export_cfg['masks'] and config_dict.get("corrections"):
        _write_masks(hy_obj, config_dict, header_dict)


# ---------- Main class ----------

class HyToolsCorrector:
    """Hyperspectral image correction driver built on HyTools and Ray.

    Usage::

        corrector = HyToolsCorrector()
        corrector.run("path/to/config.json")

        # or as a single call
        HyToolsCorrector.process_file("path/to/config.json")

    The instance is also a context manager; leaving the ``with`` block shuts
    Ray down if (and only if) this instance started it.
    """

    def __init__(self):
        """Create the corrector without starting Ray."""
        self.config_dict = None
        # True once _init_ray() has made Ray available to this instance.
        self.ray_initialized = False
        # True only when this instance itself called ray.init(); a Ray session
        # started elsewhere is reused but never shut down by this class.
        self._owns_ray = False

    def load_config(self, config_file: str) -> dict:
        """Load the JSON config file, store it on the instance and return it."""
        with open(config_file, 'r') as f:
            self.config_dict = json.load(f)
        return self.config_dict

    def _init_ray(self):
        """Start Ray unless it is already running, recording who owns it."""
        if not ray.is_initialized():
            num_cpus = self.config_dict.get('num_cpus', None)
            print(f"Initializing Ray with {num_cpus} CPUs.")
            ray.init(num_cpus=num_cpus)
            self._owns_ray = True
        else:
            # Externally started session: reuse it, keep ownership unchanged.
            print("Ray already initialized, reusing existing instance.")
        self.ray_initialized = True

    def _shutdown_ray(self):
        """Shut Ray down, but only if this instance started it."""
        if self._owns_ray and ray.is_initialized():
            ray.shutdown()
        self._owns_ray = False
        self.ray_initialized = False

    def run(self, config_file: Optional[str] = None):
        """Run the full correction pipeline.

        Loads each input image into its own Ray actor, computes the configured
        corrections, then exports coefficients and/or corrected images as the
        ``export`` section requests. Ray is shut down afterwards (also when a
        step raises) if this call started it.

        :param config_file: Path to the JSON config; if None, load_config()
            must have been called first.
        :raises ValueError: If no config is loaded, or ``file_type`` is
            neither ``'envi'`` nor ``'neon'``.
        """
        if config_file is not None:
            self.load_config(config_file)
        if self.config_dict is None:
            raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")
        config = self.config_dict

        print("Starting image correction process...")
        total_start = time.time()
        images = config["input_files"]

        self._init_ray()
        try:
            actors = self._load_images(config, images)
            self._compute_corrections(actors, config)
            self._export_results(actors, config)

            total_time = time.time() - total_start
            print(f"全部任务完成,总耗时 {total_time:.2f} 秒")
        finally:
            # Release Ray even on failure (no-op when Ray is not ours).
            self._shutdown_ray()

    @staticmethod
    def _load_images(config, images):
        """Create one HyTools actor per image, read the files, flag bad bands."""
        remote_hytools = ray.remote(ht.HyTools)
        actors = [remote_hytools.remote() for _ in images]

        print(f"Loading {len(images)} image(s)...")
        load_start = time.time()
        file_type = config['file_type']
        if file_type == 'envi':
            anc_files = config.get("anc_files", {})
            futures = [a.read_file.remote(img, file_type, anc_files.get(img, None))
                       for a, img in zip(actors, images)]
        elif file_type == 'neon':
            futures = [a.read_file.remote(img, file_type)
                       for a, img in zip(actors, images)]
        else:
            raise ValueError(f"不支持的文件类型: {file_type}")
        ray.get(futures)
        ray.get([a.create_bad_bands.remote(config['bad_bands']) for a in actors])

        load_time = time.time() - load_start
        print(f"加载完成,耗时 {load_time:.2f} 秒")
        return actors

    @staticmethod
    def _compute_corrections(actors, config):
        """Compute coefficients for each configured correction, in order."""
        corr_list = config.get("corrections")
        if not corr_list:
            return
        print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
        correction_start = time.time()
        for correction in corr_list:
            if correction == 'topo':
                calc_topo_coeffs(actors, config['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors, config)
            elif correction == 'glint':
                set_glint_parameters(actors, config)
            else:
                # Previously ignored silently.
                print(f"Unknown correction '{correction}' skipped.")
        correction_time = time.time() - correction_start
        print(f"校正完成,耗时 {correction_time:.2f} 秒")

    @staticmethod
    def _export_results(actors, config):
        """Export coefficients and/or corrected images per the ``export`` section."""
        export_cfg = config.get('export', {})
        if export_cfg.get('coeffs') and config.get("corrections"):
            print("Exporting correction coefficients...")
            coeff_start = time.time()
            ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
            print(f"系数导出完成,耗时 {time.time() - coeff_start:.2f} 秒")

        if export_cfg.get('image'):
            print("Exporting corrected images...")
            export_start = time.time()
            ray.get([a.do.remote(apply_corrections, config) for a in actors])
            print(f"图像导出完成,耗时 {time.time() - export_start:.2f} 秒")

    @staticmethod
    def process_file(config_file: str):
        """Process one config file with a fresh corrector (Ray handled by run())."""
        corrector = HyToolsCorrector()
        corrector.run(config_file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions are not suppressed; Ray is released only if we own it.
        self._shutdown_ray()


# ---------- Backward-compatible script entry point ----------

def main():
    """Command-line entry point compatible with the original script.

    Uses the first CLI argument as the config path, or DEFAULT_CONFIG_PATH.
    """
    config_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_PATH
    HyToolsCorrector.process_file(config_path)


if __name__ == "__main__":
    main()