256 lines
9.2 KiB
Python
256 lines
9.2 KiB
Python
import json
|
||
import os
|
||
import warnings
|
||
import sys
|
||
import time
|
||
import ray
|
||
import numpy as np
|
||
import hytools as ht
|
||
from hytools.io.envi import *
|
||
from hytools.topo import calc_topo_coeffs
|
||
from hytools.brdf import calc_brdf_coeffs
|
||
from hytools.glint import set_glint_parameters
|
||
from hytools.masks import mask_create
|
||
|
||
# Process-wide noise suppression: NumPy floating-point errors raised by
# division by zero and invalid operations (e.g. 0/0 -> NaN) are ignored,
# and all Python warnings are filtered out.
np.seterr(invalid='ignore', divide='ignore')
warnings.filterwarnings("ignore")
|
||
|
||
|
||
# ---------- Module-level functions (invoked remotely on Ray actors) ----------

def export_coeffs(hy_obj, export_dict):
    """Export each correction's coefficients to its own JSON file.

    One file is written per entry of ``hy_obj.corrections``, named
    ``<output_dir><image basename>_<correction>_coeffs_<suffix>.json``:

    * ``'topo'``  -> dumps ``hy_obj.topo``
    * ``'glint'`` -> skipped entirely (no file is created)
    * anything else -> dumps ``hy_obj.brdf``

    :param hy_obj: image object exposing ``corrections``, ``file_name``,
        ``topo`` and ``brdf``.
    :param export_dict: the ``export`` section of the config; reads
        ``output_dir`` and ``suffix``.
    """
    # output_dir is used as a plain string prefix (not os.path.join), so it is
    # expected to end with a path separator.
    base_name = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    prefix = export_dict['output_dir'] + base_name

    for correction in hy_obj.corrections:
        # Glint has no coefficients to export. Decide this *before* opening
        # the output file; otherwise an empty (invalid) JSON file is left behind.
        if correction == 'glint':
            continue
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf

        coeff_file = "%s_%s_coeffs_%s.json" % (prefix, correction, export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
|
||
|
||
|
||
def _output_base(hy_obj, export_cfg):
    """Return ``<output_dir><image basename>_<suffix>`` for *hy_obj*."""
    # output_dir is a plain string prefix, so it is expected to end with a separator.
    base_name = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    return "%s%s_%s" % (export_cfg['output_dir'], base_name, export_cfg["suffix"])


def _write_all_bands(hy_obj, config_dict, header_dict, output_name):
    """Stream every (optionally resampled) corrected band to disk line by line."""
    if config_dict["resample"]:
        hy_obj.resampler = config_dict['resampler']
        waves = hy_obj.resampler['out_waves']
    else:
        waves = hy_obj.wavelengths

    header_dict['bands'] = len(waves)
    header_dict['wavelength'] = waves

    writer = WriteENVI(output_name, header_dict)
    try:
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line, iterator.current_line)
    finally:
        # Always release the output file, even if reading/correcting fails.
        writer.close()


def _write_band_subset(hy_obj, config_dict, header_dict, output_name):
    """Write only the bands that ``wave_to_band`` maps ``export.subset_waves`` to."""
    bands = [hy_obj.wave_to_band(x) for x in config_dict['export']['subset_waves']]
    header_dict['bands'] = len(bands)
    # Header lists the image's own wavelength for each chosen band, not the
    # requested value.
    header_dict['wavelength'] = [round(hy_obj.wavelengths[x], 2) for x in bands]

    writer = WriteENVI(output_name, header_dict)
    try:
        for out_index, band_num in enumerate(bands):
            band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
            writer.write_band(band, out_index)
    finally:
        writer.close()


def _write_masks(hy_obj, config_dict, header_dict, output_name):
    """Write one band per ``<correction>.apply_mask`` entry into a byte ENVI file."""
    masks = []
    mask_names = []
    for correction in config_dict["corrections"]:
        for mask_type in config_dict[correction]['apply_mask']:
            mask_names.append(correction + '_' + mask_type[0])
            masks.append(mask_create(hy_obj, [mask_type]))

    # Reuse the image header but describe a band-less (no wavelength) byte file.
    header_dict['data type'] = 1
    header_dict['bands'] = len(masks)
    header_dict['band names'] = mask_names
    header_dict['samples'] = hy_obj.columns
    header_dict['lines'] = hy_obj.lines
    header_dict['wavelength'] = []
    header_dict['fwhm'] = []
    header_dict['wavelength units'] = ''
    header_dict['data ignore value'] = 255

    writer = WriteENVI(output_name, header_dict)
    try:
        for band_num, mask in enumerate(masks):
            # uint8 matches the byte data type declared above (the old int cast
            # used 8 bytes per pixel for 0/1/255 values).
            mask = mask.astype(np.uint8)
            # Pixels where hy_obj.mask['no_data'] is False are marked with the
            # ignore value 255.
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask, band_num)
    finally:
        # Previously this writer was never closed.
        writer.close()


def apply_corrections(hy_obj, config_dict):
    """Apply the image's corrections and export the result as ENVI file(s).

    Writes ``<output_dir><image basename>_<suffix>`` containing either every
    band (optionally resampled) when ``export.subset_waves`` is empty, or only
    the requested bands otherwise. When ``export.masks`` is truthy and at least
    one correction is configured, also writes ``..._<suffix>_mask`` with one
    band per configured ``apply_mask`` entry.

    :param hy_obj: image object (called inside a Ray actor via ``actor.do``).
    :param config_dict: full configuration; reads ``export`` (``output_dir``,
        ``suffix``, ``subset_waves``, ``masks``), ``resample``, ``resampler``
        (when resampling), ``corrections`` and each correction's ``apply_mask``.
    """
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # ENVI data type code 12 (unsigned 16-bit in the ENVI header convention).
    header_dict['data type'] = 12

    export_cfg = config_dict['export']
    output_name = _output_base(hy_obj, export_cfg)

    if len(export_cfg['subset_waves']) == 0:
        _write_all_bands(hy_obj, config_dict, header_dict, output_name)
    else:
        _write_band_subset(hy_obj, config_dict, header_dict, output_name)

    if export_cfg['masks'] and len(config_dict["corrections"]) > 0:
        _write_masks(hy_obj, config_dict, header_dict, output_name + "_mask")
|
||
|
||
|
||
# ---------- Main class ----------

class HyToolsCorrector:
    """Hyperspectral image correction processor.

    Usage::

        corrector = HyToolsCorrector()
        corrector.run("path/to/config.json")

        # or as a one-shot call
        HyToolsCorrector.process_file("path/to/config.json")

        # or as a context manager (Ray is shut down on exit only if this
        # instance started it)
        with HyToolsCorrector() as corrector:
            corrector.run("path/to/config.json")
    """

    def __init__(self):
        """Create a processor; Ray is NOT started until :meth:`run`."""
        self.config_dict = None
        # True once Ray is known to be running for this processor, whether this
        # instance started it or found it already running.
        self.ray_initialized = False
        # True only when *this* instance called ray.init(). Only then may it
        # shut Ray down, so an externally managed Ray session is left alone.
        self._owns_ray = False

    def load_config(self, config_file: str) -> dict:
        """Load the JSON config file, store it on the instance and return it."""
        with open(config_file, 'r') as f:
            self.config_dict = json.load(f)
        return self.config_dict

    def _init_ray(self):
        """Start Ray unless it is already running.

        ``num_cpus`` comes from the config (``None`` lets Ray choose).
        """
        if not ray.is_initialized():
            num_cpus = self.config_dict.get('num_cpus', None)
            print(f"Initializing Ray with {num_cpus} CPUs.")
            ray.init(num_cpus=num_cpus)
            self._owns_ray = True
        else:
            print("Ray already initialized, reusing existing instance.")
            # Started elsewhere: reuse it, but do not claim ownership.
            # Previously this was treated as owned and shut down after run().
            self._owns_ray = False
        self.ray_initialized = True

    def _shutdown_ray(self):
        """Shut Ray down, but only if this instance started it."""
        if self._owns_ray and ray.is_initialized():
            ray.shutdown()
            self.ray_initialized = False
        self._owns_ray = False

    def run(self, config_file: str = None):
        """Run the full correction pipeline: load, correct, export.

        :param config_file: path to the JSON config; if None,
            :meth:`load_config` must have been called beforehand.
        :raises ValueError: if no config is loaded or ``file_type`` is
            unsupported.
        """
        if config_file is not None:
            self.load_config(config_file)
        if self.config_dict is None:
            raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")

        print("Starting image correction process...")
        total_start = time.perf_counter()

        images = self.config_dict["input_files"]

        self._init_ray()
        try:
            # One remote HyTools actor per input image.
            HyTools = ray.remote(ht.HyTools)
            actors = [HyTools.remote() for _ in images]

            self._load_images(actors, images)
            self._compute_corrections(actors)
            self._export_results(actors)

            total_time = time.perf_counter() - total_start
            print(f"全部任务完成,总耗时 {total_time:.2f} 秒")
        finally:
            # Release Ray even when a step fails (no-op if Ray was started
            # by someone else).
            self._shutdown_ray()

    def _load_images(self, actors, images):
        """Read each image into its actor, then build each actor's bad-band list."""
        print(f"Loading {len(images)} image(s)...")
        load_start = time.perf_counter()

        file_type = self.config_dict['file_type']
        if file_type == 'envi':
            # Optional per-image extra files from config['anc_files'], keyed by
            # image path (presumably ancillary data — confirm against config docs).
            anc_files = self.config_dict.get("anc_files", {})
            ray.get([a.read_file.remote(img, file_type, anc_files.get(img, None))
                     for a, img in zip(actors, images)])
        elif file_type == 'neon':
            ray.get([a.read_file.remote(img, file_type)
                     for a, img in zip(actors, images)])
        else:
            raise ValueError(f"不支持的文件类型: {file_type}")

        ray.get([a.create_bad_bands.remote(self.config_dict['bad_bands']) for a in actors])

        load_time = time.perf_counter() - load_start
        print(f"加载完成,耗时 {load_time:.2f} 秒")

    def _compute_corrections(self, actors):
        """Compute/set the configured corrections (topo, brdf, glint) on all actors."""
        corr_list = self.config_dict.get("corrections")
        if not corr_list:
            return

        print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
        correction_start = time.perf_counter()

        for correction in corr_list:
            if correction == 'topo':
                calc_topo_coeffs(actors, self.config_dict['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors, self.config_dict)
            elif correction == 'glint':
                set_glint_parameters(actors, self.config_dict)
            else:
                # Previously ignored silently; warnings are globally filtered
                # in this module, so report via print instead.
                print(f"Unknown correction '{correction}' ignored.")

        correction_time = time.perf_counter() - correction_start
        print(f"校正完成,耗时 {correction_time:.2f} 秒")

    def _export_results(self, actors):
        """Export coefficients and/or corrected images as requested by ``export``."""
        export_cfg = self.config_dict.get('export', {})

        if export_cfg.get('coeffs') and self.config_dict.get("corrections"):
            print("Exporting correction coefficients...")
            coeff_start = time.perf_counter()
            ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
            print(f"系数导出完成,耗时 {time.perf_counter() - coeff_start:.2f} 秒")

        if export_cfg.get('image'):
            print("Exporting corrected images...")
            export_start = time.perf_counter()
            ray.get([a.do.remote(apply_corrections, self.config_dict) for a in actors])
            print(f"图像导出完成,耗时 {time.perf_counter() - export_start:.2f} 秒")

    @staticmethod
    def process_file(config_file: str):
        """Process a single config file in one call; Ray lifecycle is handled automatically."""
        with HyToolsCorrector() as corrector:
            corrector.run(config_file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Shut down Ray only if this instance started it; exceptions propagate.
        self._shutdown_ray()
|
||
|
||
|
||
# ---------- 向后兼容的 main 入口 ----------
|
||
def main():
|
||
"""保持与原始脚本的兼容性"""
|
||
if len(sys.argv) > 1:
|
||
config_path = sys.argv[1]
|
||
else:
|
||
config_path = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
|
||
HyToolsCorrector.process_file(config_path)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |