Initial commit

This commit is contained in:
2026-04-10 16:46:45 +08:00
commit 4fd1b0a203
165 changed files with 25698 additions and 0 deletions

View File

@ -0,0 +1,256 @@
import json
import os
import warnings
import sys
import time
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
# ---------- 全局函数:供 Ray 远程调用 ----------
def export_coeffs(hy_obj, export_dict):
"""Export correction coefficients to file."""
for correction in hy_obj.corrections:
coeff_file = export_dict['output_dir']
coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
with open(coeff_file, 'w') as outfile:
if correction == 'topo':
corr_dict = hy_obj.topo
elif correction == 'glint':
continue
else:
corr_dict = hy_obj.brdf
json.dump(corr_dict, outfile)
def apply_corrections(hy_obj, config_dict):
"""Apply correction to image and export to file."""
header_dict = hy_obj.get_header()
header_dict['data ignore value'] = hy_obj.no_data
header_dict['data type'] = 12
output_name = config_dict['export']['output_dir']
output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
output_name += "_%s" % config_dict['export']["suffix"]
# Export all wavelengths
if len(config_dict['export']['subset_waves']) == 0:
if config_dict["resample"]:
hy_obj.resampler = config_dict['resampler']
waves = hy_obj.resampler['out_waves']
else:
waves = hy_obj.wavelengths
header_dict['bands'] = len(waves)
header_dict['wavelength'] = waves
writer = WriteENVI(output_name, header_dict)
iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
resample=config_dict['resample'])
while not iterator.complete:
line = iterator.read_next()
writer.write_line(line, iterator.current_line)
writer.close()
# Export subset of wavelengths
else:
waves = config_dict['export']['subset_waves']
bands = [hy_obj.wave_to_band(x) for x in waves]
waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
header_dict['bands'] = len(bands)
header_dict['wavelength'] = waves
writer = WriteENVI(output_name, header_dict)
for b, band_num in enumerate(bands):
band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
writer.write_band(band, b)
writer.close()
# Export masks
if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
masks = []
mask_names = []
for correction in config_dict["corrections"]:
for mask_type in config_dict[correction]['apply_mask']:
mask_names.append(correction + '_' + mask_type[0])
masks.append(mask_create(hy_obj, [mask_type]))
header_dict['data type'] = 1
header_dict['bands'] = len(masks)
header_dict['band names'] = mask_names
header_dict['samples'] = hy_obj.columns
header_dict['lines'] = hy_obj.lines
header_dict['wavelength'] = []
header_dict['fwhm'] = []
header_dict['wavelength units'] = ''
header_dict['data ignore value'] = 255
output_name = config_dict['export']['output_dir']
output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
output_name += "_%s_mask" % config_dict['export']["suffix"]
writer = WriteENVI(output_name, header_dict)
for band_num, mask in enumerate(masks):
mask = mask.astype(int)
mask[~hy_obj.mask['no_data']] = 255
writer.write_band(mask, band_num)
del masks
# ---------- 主类 ----------
class HyToolsCorrector:
"""
高光谱图像校正处理器
用法:
corrector = HyToolsCorrector()
corrector.run("path/to/config.json")
# 或者单次调用
HyToolsCorrector.process_file("path/to/config.json")
"""
def __init__(self):
"""初始化处理器(不自动启动 Ray"""
self.config_dict = None
self.ray_initialized = False
def load_config(self, config_file: str) -> dict:
"""加载 JSON 配置文件"""
with open(config_file, 'r') as f:
self.config_dict = json.load(f)
return self.config_dict
def _init_ray(self):
"""初始化 Ray如果尚未初始化"""
if not ray.is_initialized():
num_cpus = self.config_dict.get('num_cpus', None)
print(f"Initializing Ray with {num_cpus} CPUs.")
ray.init(num_cpus=num_cpus)
self.ray_initialized = True
else:
print("Ray already initialized, reusing existing instance.")
self.ray_initialized = True # 假设外部已初始化
def _shutdown_ray(self):
"""关闭 Ray仅在由本类启动时关闭"""
if self.ray_initialized and ray.is_initialized():
ray.shutdown()
self.ray_initialized = False
def run(self, config_file: str = None):
"""
执行完整的校正流程
:param config_file: JSON 配置文件路径,若为 None 则需先调用 load_config()
"""
if config_file is not None:
self.load_config(config_file)
if self.config_dict is None:
raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")
print("Starting image correction process...")
total_start = time.time()
images = self.config_dict["input_files"]
# 初始化 Ray
self._init_ray()
# 创建远程 actors
HyTools = ray.remote(ht.HyTools)
actors = [HyTools.remote() for _ in images]
# 加载图像
print(f"Loading {len(images)} image(s)...")
load_start = time.time()
if self.config_dict['file_type'] == 'envi':
anc_files = self.config_dict.get("anc_files", {})
_ = ray.get([a.read_file.remote(img, self.config_dict['file_type'],
anc_files.get(img, None))
for a, img in zip(actors, images)])
elif self.config_dict['file_type'] == 'neon':
_ = ray.get([a.read_file.remote(img, self.config_dict['file_type'])
for a, img in zip(actors, images)])
else:
raise ValueError(f"不支持的文件类型: {self.config_dict['file_type']}")
_ = ray.get([a.create_bad_bands.remote(self.config_dict['bad_bands']) for a in actors])
load_time = time.time() - load_start
print(f"加载完成,耗时 {load_time:.2f}")
# 应用校正
if self.config_dict.get("corrections"):
corr_list = self.config_dict["corrections"]
print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
correction_start = time.time()
for correction in corr_list:
if correction == 'topo':
calc_topo_coeffs(actors, self.config_dict['topo'])
elif correction == 'brdf':
calc_brdf_coeffs(actors, self.config_dict)
elif correction == 'glint':
set_glint_parameters(actors, self.config_dict)
correction_time = time.time() - correction_start
print(f"校正完成,耗时 {correction_time:.2f}")
# 导出系数
export_cfg = self.config_dict.get('export', {})
if export_cfg.get('coeffs') and self.config_dict.get("corrections"):
print("Exporting correction coefficients...")
coeff_start = time.time()
_ = ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
print(f"系数导出完成,耗时 {time.time() - coeff_start:.2f}")
# 导出图像
if export_cfg.get('image'):
print("Exporting corrected images...")
export_start = time.time()
_ = ray.get([a.do.remote(apply_corrections, self.config_dict) for a in actors])
print(f"图像导出完成,耗时 {time.time() - export_start:.2f}")
total_time = time.time() - total_start
print(f"全部任务完成,总耗时 {total_time:.2f}")
# 关闭 Ray可选若希望复用则不关闭
self._shutdown_ray()
@staticmethod
def process_file(config_file: str):
"""静态方法:单次处理一个配置文件(自动管理 Ray 生命周期)"""
corrector = HyToolsCorrector()
corrector.run(config_file)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._shutdown_ray()
# ---------- 向后兼容的 main 入口 ----------
def main():
"""保持与原始脚本的兼容性"""
if len(sys.argv) > 1:
config_path = sys.argv[1]
else:
config_path = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
HyToolsCorrector.process_file(config_path)
if __name__ == "__main__":
main()