269 lines
11 KiB
Python
269 lines
11 KiB
Python
import json
|
|
import os
|
|
import warnings
|
|
import sys
|
|
import ray
|
|
import numpy as np
|
|
|
|
import hytools as ht
|
|
from hytools.io.envi import *
|
|
from hytools.io.netcdf import *
|
|
from hytools.masks import mask_dict
|
|
|
|
warnings.filterwarnings("ignore")
|
|
|
|
def main():
    '''Estimate traits for every image listed in a JSON configuration file.

    Command line arguments (read from ``sys.argv``):
        1. Path to the JSON configuration file (required).
        2. Path to a JSON file with outside metadata (optional). When given,
           its content replaces ``config_dict["outside_metadata"]``.

    When no second argument is given, ``config_dict["outside_metadata"]`` may
    be a dict (used as is), a path to a JSON file (loaded and replaced by its
    content), or missing / null (stored as ``None``).

    The configuration is validated and defaulted ("export_type" -> "envi",
    "use_glt" -> False) before Ray is started, one actor per input image is
    created, the images are loaded and ``apply_trait_models`` is run on every
    actor. Ray is always shut down before returning, also on errors.
    '''
    config_file = sys.argv[1]
    with open(config_file, 'r') as config_handle:
        config_dict = json.load(config_handle)

    # Resolve "outside_metadata" to either a dict or None.
    if len(sys.argv) > 2:
        with open(sys.argv[2], 'r') as meta_handle:
            config_dict["outside_metadata"] = json.load(meta_handle)
    else:
        outside_metadata = config_dict.get("outside_metadata")
        if outside_metadata is not None and not isinstance(outside_metadata, dict):
            # A non-dict value is a path: load the JSON and replace it by a dict.
            with open(outside_metadata, 'r') as meta_handle:
                outside_metadata = json.load(meta_handle)
        config_dict["outside_metadata"] = outside_metadata

    # Validate the configuration up front so that no Ray cluster is started
    # (and no image is loaded) for a run that cannot succeed.
    file_type = config_dict['file_type']
    if file_type not in ('envi', 'neon', 'emit', 'ncav'):
        print("Image file type is not recognized.")
        return

    if "export_type" in config_dict:
        if config_dict["export_type"] not in ("envi", "netcdf"):
            print("Image export file type is not recognized.")
            return
    else:
        config_dict["export_type"] = "envi"

    if "use_glt" not in config_dict:
        config_dict["use_glt"] = False

    images = config_dict["input_files"]

    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])

    try:
        HyTools = ray.remote(ht.HyTools)
        actors = [HyTools.remote() for _ in images]

        # Load data
        if file_type == 'envi':
            anc_files = config_dict["anc_files"]
            ray.get([a.read_file.remote(image, file_type, anc_files[image])
                     for a, image in zip(actors, images)])
        elif file_type == 'neon':
            ray.get([a.read_file.remote(image, file_type)
                     for a, image in zip(actors, images)])
        else:
            # 'emit' or 'ncav': ancillary files are required, GLT files are
            # optional (missing or empty means "no GLT").
            anc_files = config_dict["anc_files"]
            glt_files = config_dict.get("glt_files")
            if glt_files:
                ray.get([a.read_file.remote(image, file_type,
                                            anc_path=anc_files[image],
                                            glt_path=glt_files[image])
                         for a, image in zip(actors, images)])
            else:
                ray.get([a.read_file.remote(image, file_type,
                                            anc_path=anc_files[image])
                         for a, image in zip(actors, images)])

        print("Estimating %s traits:" % len(config_dict['trait_models']))
        for trait in config_dict['trait_models']:
            with open(trait, 'r') as json_file:
                trait_model = json.load(json_file)
            print("\t %s" % trait_model["name"])

        ray.get([a.do.remote(apply_trait_models, config_dict) for a in actors])
    finally:
        # Release the actors and workers even if loading or processing failed.
        ray.shutdown()
|
|
|
|
def apply_trait_models(hy_obj,config_dict):
    '''Apply trait model(s) to image and export to file.

    For every trait model file listed in ``config_dict['trait_models']`` the
    image is read chunk by chunk, the model's spectrum transforms and
    coefficients are applied, and the following bands are accumulated for the
    whole image before being written:

        0: mean of the model predictions
        1: standard deviation (ddof=1) of the model predictions
        2: range mask (1 where the mean lies strictly between the model's
           diagnostic 'min' and 'max')
        3+: one band per entry of ``config_dict['masks']``

    Pixels flagged as no data get -9999 in bands 0-1 and 255 in the mask bands.

    Args:
        hy_obj: HyTools image object (already loaded by the caller).
        config_dict (dict): Run configuration. Keys read here: 'bad_bands',
            'corrections', 'topo' / 'brdf' (when listed in 'corrections'),
            'resampling', 'trait_models', 'masks', 'output_dir',
            'export_type', 'file_type', 'outside_metadata' and, optionally,
            'use_glt'.

    Returns:
        None

    Raises:
        ValueError: If ``config_dict['file_type']`` is not one of 'envi',
            'emit', 'neon' or 'ncav'.
    '''

    hy_obj.create_bad_bands(config_dict['bad_bands'])
    hy_obj.corrections = config_dict['corrections']

    # Load correction coefficients
    if 'topo' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['topo'][hy_obj.file_name],'topo')

    if 'brdf' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['brdf'][hy_obj.file_name],'brdf')

    hy_obj.resampler['type'] = config_dict["resampling"]['type']

    file_type = config_dict['file_type']
    # Normalise once so the header, the writers and the output branch all
    # agree on whether GLT-warped output is requested.
    use_glt_output_bool = bool(config_dict.get('use_glt', False))
    base_name = os.path.splitext(os.path.basename(hy_obj.file_name))[0]

    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        coeffs = np.array(trait_model['model']['coefficients'])
        intercept = np.array(trait_model['model']['intercepts'])
        model_waves = np.array(trait_model['wavelengths'])

        # Check if wavelengths match; resample only when the image lacks
        # at least one of the model wavelengths.
        resample = not all(x in hy_obj.wavelengths for x in model_waves)

        if resample:
            hy_obj.resampler['out_waves'] = model_waves
            hy_obj.resampler['out_fwhm'] = trait_model['fwhm']
        else:
            # Band index of every model wavelength within the image.
            wave_mask = [np.argwhere(x==hy_obj.wavelengths)[0][0] for x in model_waves]

        # Fail before any output file is created if the type is unsupported.
        chunk_size = _trait_chunk_size(hy_obj, file_type)

        if use_glt_output_bool:
            header_dict = hy_obj.get_header(warp_glt=True)
        else:
            header_dict = hy_obj.get_header()

        # Build trait image file
        mask_names = [mask[0] for mask in config_dict['masks']]
        header_dict['wavelength'] = []
        header_dict['data ignore value'] = -9999
        header_dict['data type'] = 4
        header_dict['trait unit'] = trait_model['units']
        header_dict['band names'] = ["%s_mean" % trait_model["name"],
                                     "%s_std" % trait_model["name"],
                                     'range_mask'] + mask_names
        header_dict['bands'] = len(header_dict['band names'])

        header_dict['file_type'] = file_type
        header_dict['transform'] = hy_obj.transform
        header_dict['projection'] = hy_obj.projection

        # Generate masks
        for mask_name,args in config_dict['masks']:
            hy_obj.gen_mask(mask_dict[mask_name],mask_name,args)

        output_name = config_dict['output_dir'] + base_name + "_%s" % trait_model["name"]

        if config_dict["export_type"]=="envi":
            writer = WriteENVI(output_name,header_dict)
        else:
            output_name += ".nc"
            header_dict['lines_glt'] = hy_obj.lines_glt
            header_dict['samples_glt'] = hy_obj.columns_glt
            writer = WriteNetCDF(output_name,header_dict,
                                 attr_dict=None,
                                 glt_bool=use_glt_output_bool,
                                 type_tag="trait",
                                 band_name=trait_model["name"])

            if (not use_glt_output_bool) and file_type == 'emit':
                writer.write_glt_dataset(hy_obj.glt_x,hy_obj.glt_y,
                                         dim_x_name="ortho_x",dim_y_name="ortho_y")

        iterator = hy_obj.iterate(by = 'chunk',
                                  chunk_size = chunk_size,
                                  corrections = hy_obj.corrections,
                                  resample=resample)

        # Allocate directly as float32: avoids a transient float64 copy of
        # the full (bands, lines, samples) stack.
        out_stack = np.zeros((header_dict['bands'],
                              header_dict['lines'],
                              header_dict['samples']),dtype=np.float32)

        while not iterator.complete:
            chunk = iterator.read_next()
            if not resample:
                chunk = chunk[:,:,wave_mask]

            chunk = _apply_spectrum_transforms(chunk,trait_model['model']["transform"])

            trait_est = np.zeros((chunk.shape[0],
                                  chunk.shape[1],
                                  header_dict['bands']))

            # One prediction per coefficient set (axis 2), then summarise.
            trait_pred = np.einsum('jkl,ml->jkm',chunk,coeffs, optimize='optimal')
            trait_pred = trait_pred + intercept
            trait_est[:,:,0] = trait_pred.mean(axis=2)
            trait_est[:,:,1] = trait_pred.std(ddof=1,axis=2)

            range_mask = (trait_est[:,:,0] > trait_model["model_diagnostics"]['min']) & \
                         (trait_est[:,:,0] < trait_model["model_diagnostics"]['max'])
            trait_est[:,:,2] = range_mask.astype(int)

            # Window of the full image covered by this chunk.
            y_start = iterator.current_line
            y_end = y_start + chunk.shape[0]
            x_start = iterator.current_column
            x_end = x_start + chunk.shape[1]

            # Subset and assign custom masks
            for i,mask_name in enumerate(mask_names):
                trait_est[:,:,3+i] = hy_obj.mask[mask_name][y_start:y_end,x_start:x_end].astype(int)

            nd_mask = hy_obj.mask['no_data'][y_start:y_end,x_start:x_end]
            trait_est[~nd_mask,:2] = -9999
            trait_est[~nd_mask,2:] = 255

            out_stack[:,y_start:y_end,x_start:x_end] = np.moveaxis(trait_est,-1,0)

        # Write the two trait bands (mean, std) with the writer opened above.
        if use_glt_output_bool:
            # GLT indices are stored 1-based, hence the -1.
            # NOTE(review): presumably so; confirm against the GLT reader.
            glt_indices = (hy_obj.glt_y[hy_obj.fill_mask]-1,
                           hy_obj.glt_x[hy_obj.fill_mask]-1)
            for iband in range(2):
                writer.write_netcdf_band_glt(out_stack[iband,:,:],iband,
                                             glt_indices,hy_obj.fill_mask)
        else:
            for iband in range(2):
                writer.write_band(out_stack[iband,:,:],iband)
        writer.close()

        # Write range mask and custom masks, one writer per mask band.
        # NOTE(review): mask bands always go through WriteNetCDF, also when
        # export_type is "envi" -- confirm this is intended for ENVI output.
        for iband,band_name in enumerate(header_dict['band names'][2:]):
            writer = WriteNetCDF(output_name,header_dict,
                                 attr_dict=config_dict["outside_metadata"],
                                 glt_bool=use_glt_output_bool,
                                 type_tag="mask",
                                 band_name=band_name)
            if use_glt_output_bool:
                writer.write_mask_band_glt(out_stack[2+iband,:,:],
                                           glt_indices,hy_obj.fill_mask)
            else:
                writer.write_mask_band(out_stack[2+iband,:,:])
            writer.close()


def _trait_chunk_size(hy_obj,file_type):
    '''Return the (lines, columns) chunk size used to iterate an image of file_type.'''
    if file_type in ('envi','emit'):
        return (2,hy_obj.columns)
    if file_type == 'neon':
        return (int(np.ceil(hy_obj.lines/32)),int(np.ceil(hy_obj.columns/32)))
    if file_type == 'ncav':
        return (256,hy_obj.columns)
    raise ValueError("Image file type is not recognized: %s" % file_type)


def _apply_spectrum_transforms(chunk,transforms):
    '''Apply the named transforms, in order, along the last axis of a 3D chunk.'''
    for transform in transforms:
        if transform == "vector":
            norm = np.linalg.norm(chunk,axis=2)
            chunk = chunk/norm[:,:,np.newaxis]
        elif transform == "absorb":
            chunk = np.log(1/chunk)
        elif transform == "mean":
            mean = chunk.mean(axis=2)
            chunk = chunk/mean[:,:,np.newaxis]
    return chunk
|
|
|
|
|
|
# Run the trait estimation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|