Files
BRDF/Flexbrdf/scripts/flex_main.py
2026-04-10 16:46:45 +08:00

256 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import warnings
import sys
import time
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
# Process-wide noise suppression applied at import time: ignore every Python
# warning category, and make numpy silently produce inf/nan for
# divide-by-zero and invalid floating-point operations instead of warning.
warnings.simplefilter("ignore")
np.seterr(invalid='ignore', divide='ignore')
# ---------- 全局函数:供 Ray 远程调用 ----------
def export_coeffs(hy_obj, export_dict):
    """Export correction coefficients to JSON files.

    One file is written for each entry of ``hy_obj.corrections``, named
    ``<output_dir><stem>_<correction>_coeffs_<suffix>.json``. Here ``stem`` is
    the image file name without its directory or extension. ``'topo'`` dumps
    ``hy_obj.topo``, ``'glint'`` is skipped and writes no file, and any other
    correction dumps ``hy_obj.brdf``.

    :param hy_obj: image object exposing ``corrections``, ``file_name``,
        ``topo`` and ``brdf`` (executed on a Ray actor via ``actor.do.remote``).
    :param export_dict: the config's ``export`` section. Reads ``output_dir``
        and ``suffix``. ``output_dir`` is used as a raw string prefix, so it
        should end with a path separator.
    """
    for correction in hy_obj.corrections:
        if correction == 'glint':
            # Glint has nothing to export. Skip it *before* opening the output
            # file so that no empty JSON file is left behind.
            continue
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        coeff_file = export_dict['output_dir']
        coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj, config_dict):
    """Apply the configured corrections to an image and export it as ENVI.

    The corrected image is written to ``<output_dir><stem>_<suffix>``, where
    ``stem`` is the input file name without its directory or extension. It
    contains either every band or only the bands ``hy_obj.wave_to_band``
    picks for ``export['subset_waves']``. When exporting every band and
    ``config_dict['resample']`` is truthy, the image is resampled to
    ``resampler['out_waves']``.

    When ``export['masks']`` is set and at least one correction is
    configured, a companion ``<output_dir><stem>_<suffix>_mask`` file is also
    written. It has one byte band per ``apply_mask`` entry of each
    correction.

    :param hy_obj: HyTools image object. The corrections listed in
        ``hy_obj.corrections`` are applied while reading.
    :param config_dict: full correction config. Reads ``export``
        (``output_dir``, ``suffix``, ``subset_waves``, ``masks``),
        ``resample``, ``resampler`` (only when resampling), ``corrections``
        and ``<correction>['apply_mask']``.
    """
    export_cfg = config_dict['export']
    # output_dir is used as a raw string prefix (not os.path.join), so it is
    # expected to end with a path separator.
    base_name = (export_cfg['output_dir']
                 + os.path.splitext(os.path.basename(hy_obj.file_name))[0])

    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # NOTE(review): ENVI data type 12 is unsigned 16-bit integer. A negative
    # no-data value (e.g. -9999) or negative/fractional reflectance cannot be
    # represented in it. Confirm this is intended (4 would be float32).
    header_dict['data type'] = 12
    output_name = base_name + "_%s" % export_cfg["suffix"]

    if len(export_cfg['subset_waves']) == 0:
        # Export all wavelengths (optionally resampled), line by line.
        if config_dict["resample"]:
            hy_obj.resampler = config_dict['resampler']
            waves = hy_obj.resampler['out_waves']
            # NOTE(review): the header 'fwhm' still describes the input bands
            # here. Confirm whether the resampler provides output FWHM.
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name, header_dict)
        try:
            iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                      resample=config_dict['resample'])
            while not iterator.complete:
                line = iterator.read_next()
                writer.write_line(line, iterator.current_line)
        finally:
            writer.close()
    else:
        # Export only the bands selected by wave_to_band for each wavelength.
        bands = [hy_obj.wave_to_band(x) for x in export_cfg['subset_waves']]
        waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        # Keep the per-band FWHM list aligned with the exported subset.
        # Otherwise the header would list FWHM for every input band while
        # declaring only len(bands) bands.
        fwhm = header_dict.get('fwhm')
        if fwhm is not None and len(fwhm) > max(bands):
            header_dict['fwhm'] = [float(fwhm[b]) for b in bands]
        writer = WriteENVI(output_name, header_dict)
        try:
            for b, band_num in enumerate(bands):
                band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
                writer.write_band(band, b)
        finally:
            writer.close()

    # Export masks
    corrections = config_dict.get("corrections") or []
    if export_cfg.get('masks') and len(corrections) > 0:
        masks = []
        mask_names = []
        for correction in corrections:
            # Each apply_mask entry is passed to mask_create unchanged, and
            # its first element is used in the band name.
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # Reuse the image header, rewritten to describe a band-named byte file.
        header_dict['data type'] = 1  # ENVI byte
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        writer = WriteENVI(base_name + "_%s_mask" % export_cfg["suffix"], header_dict)
        try:
            for band_num, mask in enumerate(masks):
                mask = mask.astype(int)
                # Pixels where hy_obj.mask['no_data'] is False are set to the
                # ignore value 255 (presumably the image's no-data pixels;
                # confirm).
                mask[~hy_obj.mask['no_data']] = 255
                writer.write_band(mask, band_num)
        finally:
            # Close explicitly, as with the image writers above, so the mask
            # data is released deterministically.
            writer.close()
        del masks
# ---------- 主类 ----------
class HyToolsCorrector:
    """Hyperspectral image correction pipeline driven by a JSON config file.

    Each image listed in the config's ``input_files`` is loaded into its own
    Ray actor. The corrections listed under ``corrections`` (``topo``,
    ``brdf``, ``glint``) are then computed. Depending on ``export``,
    coefficient files and corrected images are written.

    Usage::

        corrector = HyToolsCorrector()
        corrector.run("path/to/config.json")
        # or as a single call
        HyToolsCorrector.process_file("path/to/config.json")
    """

    def __init__(self):
        """Create an idle corrector. Ray is not started here."""
        self.config_dict = None
        # True while Ray is usable by this instance (started here or reused).
        self.ray_initialized = False
        # True only if this instance called ray.init(). Only then does
        # _shutdown_ray() shut Ray down; a pre-existing session is left alone.
        self._owns_ray = False

    def load_config(self, config_file: str) -> dict:
        """Load *config_file* as JSON, store it on the instance and return it.

        The file is decoded as UTF-8 (the encoding JSON requires) rather than
        the platform default encoding. A leading byte-order mark, as written
        by some Windows editors, is accepted.
        """
        with open(config_file, 'r', encoding='utf-8-sig') as f:
            self.config_dict = json.load(f)
        return self.config_dict

    def _init_ray(self):
        """Start Ray with the config's ``num_cpus``, or reuse a running session."""
        if ray.is_initialized():
            print("Ray already initialized, reusing existing instance.")
            # Ownership is not claimed here: whoever started this session is
            # responsible for shutting it down.
        else:
            num_cpus = self.config_dict.get('num_cpus', None)
            print(f"Initializing Ray with {num_cpus} CPUs.")
            ray.init(num_cpus=num_cpus)
            self._owns_ray = True
        self.ray_initialized = True

    def _shutdown_ray(self):
        """Shut Ray down, but only if this instance started it."""
        # getattr keeps subclasses that skip super().__init__() working.
        if getattr(self, '_owns_ray', False) and ray.is_initialized():
            ray.shutdown()
        self._owns_ray = False
        self.ray_initialized = False

    def run(self, config_file: str = None):
        """Run the full correction workflow.

        :param config_file: path to the JSON config. When None, a config must
            already have been loaded with :meth:`load_config`.
        :raises ValueError: if no config is loaded, or if ``file_type`` is
            neither ``'envi'`` nor ``'neon'``.
        """
        if config_file is not None:
            self.load_config(config_file)
        if self.config_dict is None:
            raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")
        print("Starting image correction process...")
        total_start = time.perf_counter()
        images = self.config_dict["input_files"]
        self._init_ray()
        try:
            actors = self._load_images(images)
            self._calc_corrections(actors)
            self._export_results(actors)
            total_time = time.perf_counter() - total_start
            print(f"全部任务完成,总耗时 {total_time:.2f}")
        finally:
            # Release Ray even if a step fails, so a crashed run does not
            # leave behind a Ray session that this instance started.
            self._shutdown_ray()

    def _load_images(self, images):
        """Create one HyTools actor per image, read the files and set bad bands."""
        file_type = self.config_dict['file_type']
        # Validate before spawning any actors.
        if file_type not in ('envi', 'neon'):
            raise ValueError(f"不支持的文件类型: {self.config_dict['file_type']}")
        HyTools = ray.remote(ht.HyTools)
        actors = [HyTools.remote() for _ in images]
        print(f"Loading {len(images)} image(s)...")
        load_start = time.perf_counter()
        if file_type == 'envi':
            # Optional ancillary file per image, looked up by image path.
            anc_files = self.config_dict.get("anc_files") or {}
            ray.get([a.read_file.remote(img, file_type, anc_files.get(img, None))
                     for a, img in zip(actors, images)])
        else:
            ray.get([a.read_file.remote(img, file_type)
                     for a, img in zip(actors, images)])
        ray.get([a.create_bad_bands.remote(self.config_dict['bad_bands']) for a in actors])
        load_time = time.perf_counter() - load_start
        print(f"加载完成,耗时 {load_time:.2f}")
        return actors

    def _calc_corrections(self, actors):
        """Compute every correction listed in the config, in order."""
        corr_list = self.config_dict.get("corrections")
        if not corr_list:
            return
        print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
        correction_start = time.perf_counter()
        for correction in corr_list:
            if correction == 'topo':
                calc_topo_coeffs(actors, self.config_dict['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors, self.config_dict)
            elif correction == 'glint':
                set_glint_parameters(actors, self.config_dict)
            else:
                print(f"Unknown correction '{correction}' ignored.")
        correction_time = time.perf_counter() - correction_start
        print(f"校正完成,耗时 {correction_time:.2f}")

    def _export_results(self, actors):
        """Export coefficients and/or corrected images as enabled in ``export``."""
        export_cfg = self.config_dict.get('export') or {}
        if export_cfg.get('coeffs') and self.config_dict.get("corrections"):
            print("Exporting correction coefficients...")
            coeff_start = time.perf_counter()
            ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
            print(f"系数导出完成,耗时 {time.perf_counter() - coeff_start:.2f}")
        if export_cfg.get('image'):
            print("Exporting corrected images...")
            export_start = time.perf_counter()
            ray.get([a.do.remote(apply_corrections, self.config_dict) for a in actors])
            print(f"图像导出完成,耗时 {time.perf_counter() - export_start:.2f}")

    @staticmethod
    def process_file(config_file: str):
        """Process one config file with a fresh corrector (Ray lifecycle managed)."""
        with HyToolsCorrector() as corrector:
            corrector.run(config_file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut Ray down if this instance started it; exceptions propagate."""
        self._shutdown_ray()
# ---------- 向后兼容的 main 入口 ----------
def main(argv=None):
    """Command-line entry point: run the correction pipeline on one config.

    :param argv: argument list without the program name. Defaults to
        ``sys.argv[1:]``. Its first element, if present, is the path of the
        JSON config file.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if args:
        config_path = args[0]
    else:
        # NOTE(review): developer-machine fallback, kept for backward
        # compatibility. Pass a config path on the command line instead.
        config_path = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
    HyToolsCorrector.process_file(config_path)


if __name__ == "__main__":
    main()