Initial commit
This commit is contained in:
256
Flexbrdf/scripts/flex_main.py
Normal file
256
Flexbrdf/scripts/flex_main.py
Normal file
@ -0,0 +1,256 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
import sys
|
||||
import time
|
||||
import ray
|
||||
import numpy as np
|
||||
import hytools as ht
|
||||
from hytools.io.envi import *
|
||||
from hytools.topo import calc_topo_coeffs
|
||||
from hytools.brdf import calc_brdf_coeffs
|
||||
from hytools.glint import set_glint_parameters
|
||||
from hytools.masks import mask_create
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
np.seterr(divide='ignore', invalid='ignore')
|
||||
|
||||
|
||||
# ---------- 全局函数:供 Ray 远程调用 ----------
|
||||
def export_coeffs(hy_obj, export_dict):
|
||||
"""Export correction coefficients to file."""
|
||||
for correction in hy_obj.corrections:
|
||||
coeff_file = export_dict['output_dir']
|
||||
coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
|
||||
coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
|
||||
|
||||
with open(coeff_file, 'w') as outfile:
|
||||
if correction == 'topo':
|
||||
corr_dict = hy_obj.topo
|
||||
elif correction == 'glint':
|
||||
continue
|
||||
else:
|
||||
corr_dict = hy_obj.brdf
|
||||
json.dump(corr_dict, outfile)
|
||||
|
||||
|
||||
def apply_corrections(hy_obj, config_dict):
|
||||
"""Apply correction to image and export to file."""
|
||||
header_dict = hy_obj.get_header()
|
||||
header_dict['data ignore value'] = hy_obj.no_data
|
||||
header_dict['data type'] = 12
|
||||
|
||||
output_name = config_dict['export']['output_dir']
|
||||
output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
|
||||
output_name += "_%s" % config_dict['export']["suffix"]
|
||||
|
||||
# Export all wavelengths
|
||||
if len(config_dict['export']['subset_waves']) == 0:
|
||||
if config_dict["resample"]:
|
||||
hy_obj.resampler = config_dict['resampler']
|
||||
waves = hy_obj.resampler['out_waves']
|
||||
else:
|
||||
waves = hy_obj.wavelengths
|
||||
|
||||
header_dict['bands'] = len(waves)
|
||||
header_dict['wavelength'] = waves
|
||||
|
||||
writer = WriteENVI(output_name, header_dict)
|
||||
iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
|
||||
resample=config_dict['resample'])
|
||||
while not iterator.complete:
|
||||
line = iterator.read_next()
|
||||
writer.write_line(line, iterator.current_line)
|
||||
writer.close()
|
||||
|
||||
# Export subset of wavelengths
|
||||
else:
|
||||
waves = config_dict['export']['subset_waves']
|
||||
bands = [hy_obj.wave_to_band(x) for x in waves]
|
||||
waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
|
||||
header_dict['bands'] = len(bands)
|
||||
header_dict['wavelength'] = waves
|
||||
|
||||
writer = WriteENVI(output_name, header_dict)
|
||||
for b, band_num in enumerate(bands):
|
||||
band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
|
||||
writer.write_band(band, b)
|
||||
writer.close()
|
||||
|
||||
# Export masks
|
||||
if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
|
||||
masks = []
|
||||
mask_names = []
|
||||
|
||||
for correction in config_dict["corrections"]:
|
||||
for mask_type in config_dict[correction]['apply_mask']:
|
||||
mask_names.append(correction + '_' + mask_type[0])
|
||||
masks.append(mask_create(hy_obj, [mask_type]))
|
||||
|
||||
header_dict['data type'] = 1
|
||||
header_dict['bands'] = len(masks)
|
||||
header_dict['band names'] = mask_names
|
||||
header_dict['samples'] = hy_obj.columns
|
||||
header_dict['lines'] = hy_obj.lines
|
||||
header_dict['wavelength'] = []
|
||||
header_dict['fwhm'] = []
|
||||
header_dict['wavelength units'] = ''
|
||||
header_dict['data ignore value'] = 255
|
||||
|
||||
output_name = config_dict['export']['output_dir']
|
||||
output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
|
||||
output_name += "_%s_mask" % config_dict['export']["suffix"]
|
||||
|
||||
writer = WriteENVI(output_name, header_dict)
|
||||
for band_num, mask in enumerate(masks):
|
||||
mask = mask.astype(int)
|
||||
mask[~hy_obj.mask['no_data']] = 255
|
||||
writer.write_band(mask, band_num)
|
||||
|
||||
del masks
|
||||
|
||||
|
||||
# ---------- 主类 ----------
|
||||
class HyToolsCorrector:
|
||||
"""
|
||||
高光谱图像校正处理器
|
||||
用法:
|
||||
corrector = HyToolsCorrector()
|
||||
corrector.run("path/to/config.json")
|
||||
# 或者单次调用
|
||||
HyToolsCorrector.process_file("path/to/config.json")
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化处理器(不自动启动 Ray)"""
|
||||
self.config_dict = None
|
||||
self.ray_initialized = False
|
||||
|
||||
def load_config(self, config_file: str) -> dict:
|
||||
"""加载 JSON 配置文件"""
|
||||
with open(config_file, 'r') as f:
|
||||
self.config_dict = json.load(f)
|
||||
return self.config_dict
|
||||
|
||||
def _init_ray(self):
|
||||
"""初始化 Ray(如果尚未初始化)"""
|
||||
if not ray.is_initialized():
|
||||
num_cpus = self.config_dict.get('num_cpus', None)
|
||||
print(f"Initializing Ray with {num_cpus} CPUs.")
|
||||
ray.init(num_cpus=num_cpus)
|
||||
self.ray_initialized = True
|
||||
else:
|
||||
print("Ray already initialized, reusing existing instance.")
|
||||
self.ray_initialized = True # 假设外部已初始化
|
||||
|
||||
def _shutdown_ray(self):
|
||||
"""关闭 Ray(仅在由本类启动时关闭)"""
|
||||
if self.ray_initialized and ray.is_initialized():
|
||||
ray.shutdown()
|
||||
self.ray_initialized = False
|
||||
|
||||
def run(self, config_file: str = None):
|
||||
"""
|
||||
执行完整的校正流程
|
||||
:param config_file: JSON 配置文件路径,若为 None 则需先调用 load_config()
|
||||
"""
|
||||
if config_file is not None:
|
||||
self.load_config(config_file)
|
||||
if self.config_dict is None:
|
||||
raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")
|
||||
|
||||
print("Starting image correction process...")
|
||||
total_start = time.time()
|
||||
|
||||
images = self.config_dict["input_files"]
|
||||
|
||||
# 初始化 Ray
|
||||
self._init_ray()
|
||||
|
||||
# 创建远程 actors
|
||||
HyTools = ray.remote(ht.HyTools)
|
||||
actors = [HyTools.remote() for _ in images]
|
||||
|
||||
# 加载图像
|
||||
print(f"Loading {len(images)} image(s)...")
|
||||
load_start = time.time()
|
||||
|
||||
if self.config_dict['file_type'] == 'envi':
|
||||
anc_files = self.config_dict.get("anc_files", {})
|
||||
_ = ray.get([a.read_file.remote(img, self.config_dict['file_type'],
|
||||
anc_files.get(img, None))
|
||||
for a, img in zip(actors, images)])
|
||||
elif self.config_dict['file_type'] == 'neon':
|
||||
_ = ray.get([a.read_file.remote(img, self.config_dict['file_type'])
|
||||
for a, img in zip(actors, images)])
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {self.config_dict['file_type']}")
|
||||
|
||||
_ = ray.get([a.create_bad_bands.remote(self.config_dict['bad_bands']) for a in actors])
|
||||
|
||||
load_time = time.time() - load_start
|
||||
print(f"加载完成,耗时 {load_time:.2f} 秒")
|
||||
|
||||
# 应用校正
|
||||
if self.config_dict.get("corrections"):
|
||||
corr_list = self.config_dict["corrections"]
|
||||
print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
|
||||
correction_start = time.time()
|
||||
|
||||
for correction in corr_list:
|
||||
if correction == 'topo':
|
||||
calc_topo_coeffs(actors, self.config_dict['topo'])
|
||||
elif correction == 'brdf':
|
||||
calc_brdf_coeffs(actors, self.config_dict)
|
||||
elif correction == 'glint':
|
||||
set_glint_parameters(actors, self.config_dict)
|
||||
|
||||
correction_time = time.time() - correction_start
|
||||
print(f"校正完成,耗时 {correction_time:.2f} 秒")
|
||||
|
||||
# 导出系数
|
||||
export_cfg = self.config_dict.get('export', {})
|
||||
if export_cfg.get('coeffs') and self.config_dict.get("corrections"):
|
||||
print("Exporting correction coefficients...")
|
||||
coeff_start = time.time()
|
||||
_ = ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
|
||||
print(f"系数导出完成,耗时 {time.time() - coeff_start:.2f} 秒")
|
||||
|
||||
# 导出图像
|
||||
if export_cfg.get('image'):
|
||||
print("Exporting corrected images...")
|
||||
export_start = time.time()
|
||||
_ = ray.get([a.do.remote(apply_corrections, self.config_dict) for a in actors])
|
||||
print(f"图像导出完成,耗时 {time.time() - export_start:.2f} 秒")
|
||||
|
||||
total_time = time.time() - total_start
|
||||
print(f"全部任务完成,总耗时 {total_time:.2f} 秒")
|
||||
|
||||
# 关闭 Ray(可选,若希望复用则不关闭)
|
||||
self._shutdown_ray()
|
||||
|
||||
@staticmethod
|
||||
def process_file(config_file: str):
|
||||
"""静态方法:单次处理一个配置文件(自动管理 Ray 生命周期)"""
|
||||
corrector = HyToolsCorrector()
|
||||
corrector.run(config_file)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._shutdown_ray()
|
||||
|
||||
|
||||
# ---------- 向后兼容的 main 入口 ----------
|
||||
def main():
|
||||
"""保持与原始脚本的兼容性"""
|
||||
if len(sys.argv) > 1:
|
||||
config_path = sys.argv[1]
|
||||
else:
|
||||
config_path = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
|
||||
HyToolsCorrector.process_file(config_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user