Initial commit

This commit is contained in:
2026-04-10 16:46:45 +08:00
commit 4fd1b0a203
165 changed files with 25698 additions and 0 deletions

View File

@ -0,0 +1,137 @@
import argparse
import json
import os
import ray
import hytools as ht
from hytools.io.envi import WriteENVI
from hytools.brdf import calc_brdf_coeffs
def build_anc_mapping(anc_file, anc_names):
    """Map each ancillary dataset name to a [file path, band index] pair.

    Band indices are assigned positionally from the order of *anc_names*.
    """
    mapping = {}
    for band_index, name in enumerate(anc_names):
        mapping[name] = [anc_file, band_index]
    return mapping
def build_brdf_config(args):
    """Translate parsed CLI arguments into a hytools BRDF configuration dict.

    The same NDI mask definition is used for both coefficient calculation
    and application. Type-specific keys are appended for the 'universal'
    and 'flex' correction variants.
    """
    ndi_mask = [["ndi", {"band_1": args.ndi_band_1,
                         "band_2": args.ndi_band_2,
                         "min": args.ndi_min,
                         "max": args.ndi_max}]]
    brdf = {}
    brdf["type"] = args.brdf_type
    brdf["grouped"] = False
    brdf["geometric"] = args.geometric
    brdf["volume"] = args.volume
    brdf["b/r"] = args.b_r
    brdf["h/b"] = args.h_b
    brdf["sample_perc"] = args.sample_perc
    brdf["calc_mask"] = ndi_mask
    brdf["apply_mask"] = ndi_mask
    brdf["solar_zn_type"] = args.solar_zn_type
    if args.brdf_type == "universal":
        # Diagnostics are disabled for scripted (non-interactive) runs.
        brdf["diagnostic_plots"] = False
        brdf["diagnostic_waves"] = []
    elif args.brdf_type == "flex":
        brdf["interp_kind"] = "linear"
        brdf["bin_type"] = "dynamic"
        brdf["num_bins"] = args.num_bins
        brdf["ndvi_bin_min"] = args.ndvi_bin_min
        brdf["ndvi_bin_max"] = args.ndvi_bin_max
        brdf["ndvi_perc_min"] = args.ndvi_perc_min
        brdf["ndvi_perc_max"] = args.ndvi_perc_max
    return brdf
def export_brdf_corrected(hy_obj, args):
    """Write the BRDF-corrected image line by line to an ENVI file.

    args is a dict with 'output_dir' and 'suffix' keys. The output is
    written as 32-bit float (ENVI data type 4).
    """
    header = hy_obj.get_header()
    header["data ignore value"] = hy_obj.no_data
    header["data type"] = 4  # ENVI code 4 = 32-bit float
    base = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    out_path = os.path.join(args["output_dir"], f"{base}_{args['suffix']}")
    writer = WriteENVI(out_path, header)
    iterator = hy_obj.iterate(by="line", corrections=hy_obj.corrections)
    while not iterator.complete:
        # read_next() advances the iterator; current_line is read afterwards,
        # matching the original statement order.
        writer.write_line(iterator.read_next(), iterator.current_line)
    writer.close()
def export_brdf_coeffs(hy_obj, args):
    """Dump the fitted BRDF coefficients for an image to a JSON file.

    The file is named <image basename>_brdf_coeffs_<suffix>.json and is
    placed in args['output_dir'].
    """
    base = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    out_path = os.path.join(args["output_dir"],
                            f"{base}_brdf_coeffs_{args['suffix']}.json")
    with open(out_path, "w") as handle:
        json.dump(hy_obj.brdf, handle)
def main():
    """CLI entry point: fit and apply a BRDF correction to one ENVI BIP image.

    Parses arguments, builds the hytools configuration, fits the BRDF
    coefficients through a Ray actor and exports the corrected image
    (and optionally the fitted coefficients) to the output directory.
    """
    parser = argparse.ArgumentParser(description="BRDF correction for ENVI BIP images")
    parser.add_argument("image", type=str)
    parser.add_argument("anc_file", type=str)
    parser.add_argument("output_dir", type=str)
    # Optional JSON file overriding the default positional ancillary mapping.
    parser.add_argument("--anc-map", type=str, default="")
    parser.add_argument("--anc-names", type=str, default="path_length,sensor_az,sensor_zn,solar_az,solar_zn,phase,slope,aspect,cosine_i,utc_time")
    parser.add_argument("--brdf-type", type=str, default="universal", choices=["universal", "flex"])
    parser.add_argument("--suffix", type=str, default="brdf")
    parser.add_argument("--num-cpus", type=int, default=1)
    # JSON-encoded list of [start, end] wavelength ranges to exclude.
    parser.add_argument("--bad-bands-json", type=str, default="")
    parser.add_argument("--solar-zn-type", type=str, default="scene")
    parser.add_argument("--geometric", type=str, default="li_dense_r")
    parser.add_argument("--volume", type=str, default="ross_thick")
    parser.add_argument("--b-r", type=float, default=2.5)
    parser.add_argument("--h-b", type=float, default=2.0)
    parser.add_argument("--sample-perc", type=float, default=0.1)
    parser.add_argument("--ndi-band-1", type=float, default=850.0)
    parser.add_argument("--ndi-band-2", type=float, default=660.0)
    parser.add_argument("--ndi-min", type=float, default=0.05)
    parser.add_argument("--ndi-max", type=float, default=1.0)
    parser.add_argument("--num-bins", type=int, default=18)
    parser.add_argument("--ndvi-bin-min", type=float, default=0.05)
    parser.add_argument("--ndvi-bin-max", type=float, default=1.0)
    parser.add_argument("--ndvi-perc-min", type=float, default=10.0)
    parser.add_argument("--ndvi-perc-max", type=float, default=95.0)
    parser.add_argument("--export-coeffs", action="store_true")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    # Ancillary mapping: either loaded from a user-supplied JSON file, or
    # built positionally from the comma-separated --anc-names list.
    if args.anc_map:
        with open(args.anc_map, "r") as infile:
            anc_map = json.load(infile)
    else:
        anc_names = [x.strip() for x in args.anc_names.split(",") if x.strip()]
        anc_map = build_anc_mapping(args.anc_file, anc_names)
    config_dict = {
        "file_type": "envi",
        "input_files": [args.image],
        "anc_files": {args.image: anc_map},
        "corrections": ["brdf"],
        "brdf": build_brdf_config(args),
        "num_cpus": args.num_cpus
    }
    if args.bad_bands_json:
        config_dict["bad_bands"] = json.loads(args.bad_bands_json)
    # Restart Ray so the requested CPU count takes effect.
    if ray.is_initialized():
        ray.shutdown()
    ray.init(num_cpus=config_dict["num_cpus"])
    HyTools = ray.remote(ht.HyTools)
    actor = HyTools.remote()
    ray.get(actor.read_file.remote(args.image, "envi", anc_map))
    if "bad_bands" in config_dict:
        ray.get(actor.create_bad_bands.remote(config_dict["bad_bands"]))
    # Fit the BRDF coefficients, then stream the corrected image to disk.
    calc_brdf_coeffs([actor], config_dict)
    ray.get(actor.do.remote(export_brdf_corrected, {"output_dir": args.output_dir, "suffix": args.suffix}))
    if args.export_coeffs:
        ray.get(actor.do.remote(export_brdf_coeffs, {"output_dir": args.output_dir, "suffix": args.suffix}))
    ray.shutdown()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,294 @@
'''Template script for generating image_correct configuration JSON files.

These settings are meant only as an example, are not appropriate for
all situations and may need to be adjusted
'''
import json
import glob
import numpy as np

# Output path for configuration file
config_file = "/data1/temp/ht_test/ic_test.json"

config_dict = {}

# Only coefficients for good bands will be calculated
config_dict['bad_bands'] =[[300,400],[1337,1430],[1800,1960],[2450,2600]]
#config_dict['bad_bands'] =[[300,400],[900,2600]] # Subset for testing

# Input data settings for NEON
#################################################################
# config_dict['file_type'] = 'neon'
# images= glob.glob("/data1/temp/ht_test/*.h5")
# images.sort()
# config_dict["input_files"] = images

# Input data settings for ENVI
#################################################################
''' Only difference between ENVI and NEON settings is the specification
of the ancillary datasets (ex. viewing and solar geometry). All hytools
functions assume that the ancillary data and the image data are the same
size, spatially, and are ENVI formatted files.

The ancillary parameter is a dictionary with a key per image. Each value
per image is also a dictionary where the key is the dataset name and the
value is a list consisting of the file path and the band number.
'''
config_dict['file_type'] = 'envi'
aviris_anc_names = ['path_length','sensor_az','sensor_zn',
                    'solar_az', 'solar_zn','phase','slope',
                    'aspect', 'cosine_i','utc_time']
images= glob.glob("/data1/temp/ht_test/ang20190707t203417_rfl_v2v2_img")
images.sort()
config_dict["input_files"] = images

config_dict["anc_files"] = {}
anc_files = glob.glob("/data1/temp/ht_test/ang20190707t203417_rdn_v2v2_obs_ort_corr")
anc_files.sort()
# Ancillary bands are matched positionally to aviris_anc_names.
for i,image in enumerate(images):
    config_dict["anc_files"][image] = dict(zip(aviris_anc_names,
                                               [[anc_files[i],a] for a in range(len(aviris_anc_names))]))

# Export settings
#################################################################
''' Options for subset waves:
    1. List of subset wavelengths
    2. Empty list, this will output all good bands, if resampler is
       set it will also resample.
    - Currently resampler cannot be used in conjunction with option 1
'''
config_dict['export'] = {}
config_dict['export']['coeffs'] = True
config_dict['export']['image'] = True
config_dict['export']['masks'] = True
config_dict['export']['subset_waves'] = []
config_dict['export']['output_dir'] = "/data1/temp/ht_test/"
config_dict['export']["suffix"] = 'anc_nodata_test'

#Corrections
#################################################################
''' Specify correction(s) to be applied, corrections will be applied
in the order they are specified.

Options include:
    ['topo']
    ['brdf']
    ['glint']
    ['topo','brdf']
    ['brdf','topo']
    ['brdf','topo','glint']
    [] <---Export uncorrected images
'''
config_dict["corrections"] = ['brdf']

#Topographic Correction options
#################################################################
'''
Types supported:
    - 'cosine'
    - 'c'
    - 'scs'
    - 'scs+c'
    - 'mod_minneart'
    - 'precomputed'

Apply and calc masks are only needed for C and SCS+C corrections. They will
be ignored in all other cases and correction will be applied to all
non no-data pixels.

'c_fit_type' is only applicable for the C or SCS+C correction type. Options
include 'ols' or 'nnls'. Choosing 'nnls' can limit overcorrection.

For precomputed topographic coefficients 'coeff_files' is a
dictionary where each key is the full image path and value
is the full path to the coefficients file, one per image.
'''
config_dict["topo"] = {}
config_dict["topo"]['type'] = 'scs+c'
config_dict["topo"]['calc_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
                                             'min': 0.1,'max': 1.0}],
                                    ['ancillary',{'name':'slope',
                                                  'min': np.radians(5),'max':'+inf' }],
                                    ['ancillary',{'name':'cosine_i',
                                                  'min': 0.12,'max':'+inf' }],
                                    ['cloud',{'method':'zhai_2018',
                                              'cloud':True,'shadow':True,
                                              'T1': 0.01,'t2': 1/10,'t3': 1/4,
                                              't4': 1/2,'T7': 9,'T8': 9}]]
config_dict["topo"]['apply_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
                                              'min': 0.1,'max': 1.0}],
                                     ['ancillary',{'name':'slope',
                                                   'min': np.radians(5),'max':'+inf' }],
                                     ['ancillary',{'name':'cosine_i',
                                                   'min': 0.12,'max':'+inf' }]]
config_dict["topo"]['c_fit_type'] = 'nnls'

# config_dict["topo"]['type'] = 'precomputed'
# config_dict["topo"]['coeff_files'] = {}

#BRDF Correction options
#################################################################
'''
Types supported:
    - 'universal': Simple kernel multiplicative correction.
    - 'local': Correction by class. (Future.....)
    - 'flex' : Correction by NDVI class
    - 'precomputed' : Use precomputed coefficients

If 'bin_type' == 'user'
'bins' should be a list of lists, each list the NDVI bounds [low,high]

Object shapes ('h/b','b/r') only needed for Li kernels.

For precomputed BRDF coefficients 'coeff_files' is a
dictionary where each key is the full image path and value
is the full path to the coefficients file, one per image.
'''
config_dict["brdf"] = {}

# Options are 'line','scene', or a float for a custom solar zn
# Custom solar zenith angle should be in radians
config_dict["brdf"]['solar_zn_type'] ='scene'

# Universal BRDF config
#----------------------
# config_dict["brdf"]['type'] = 'universal'
# config_dict["brdf"]['grouped'] = True
# config_dict["brdf"]['sample_perc'] = 0.1
# config_dict["brdf"]['geometric'] = 'li_dense_r'
# config_dict["brdf"]['volume'] = 'ross_thick'
# config_dict["brdf"]["b/r"] = 2.5
# config_dict["brdf"]["h/b"] = 2
# config_dict["brdf"]['calc_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
#                                              'min': 0.1,'max': 1.0}]]
# config_dict["brdf"]['apply_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
#                                               'min': 0.1,'max': 1.0}]]
# config_dict["brdf"]['diagnostic_plots'] = True
# config_dict["brdf"]['diagnostic_waves'] = [440,550,660,850]
#----------------------

# ## Flex BRDF configs
# ##------------------
config_dict["brdf"]['type'] = 'flex'
config_dict["brdf"]['grouped'] = True
config_dict["brdf"]['geometric'] = 'li_dense_r'
config_dict["brdf"]['volume'] = 'ross_thick'
config_dict["brdf"]["b/r"] = 2.5
config_dict["brdf"]["h/b"] = 2
config_dict["brdf"]['sample_perc'] = 0.1
config_dict["brdf"]['interp_kind'] = 'linear'
config_dict["brdf"]['calc_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
                                             'min': 0.1,'max': 1.0}],
                                    ['kernel_finite',{}],
                                    ['ancillary',{'name':'sensor_zn',
                                                  'min':np.radians(2),'max':'inf' }],
                                    ['neon_edge',{'radius': 30}],
                                    ['cloud',{'method':'zhai_2018',
                                              'cloud':True,'shadow':True,
                                              'T1': 0.01,'t2': 1/10,'t3': 1/4,
                                              't4': 1/2,'T7': 9,'T8': 9}]]
config_dict["brdf"]['apply_mask'] = [["ndi", {'band_1': 850,'band_2': 660,
                                              'min': 0.05,'max': 1.0}]]

# ## Flex dynamic NDVI params
config_dict["brdf"]['bin_type'] = 'dynamic'
config_dict["brdf"]['num_bins'] = 18
config_dict["brdf"]['ndvi_bin_min'] = 0.05
config_dict["brdf"]['ndvi_bin_max'] = 1.0
config_dict["brdf"]['ndvi_perc_min'] = 10
config_dict["brdf"]['ndvi_perc_max'] = 95

# ## Flex fixed bins specified by user
# config_dict["brdf"]['bin_type'] = 'user'
# config_dict["brdf"]['bins'] = [[0.1,.25],[.25,.75],[.75,1]]
# ##-----------------

## Precomputed BRDF coefficients
##------------------------------
# config_dict["brdf"]['type'] = 'precomputed'
# config_dict["brdf"]['coeff_files'] = {}
##------------------------------

#Glint Correction options
#################################################################
'''
Types supported:
    - hochberg
    - hedley
    - gao

Common reference bands include:
    - 860nm (NIR)
    - 1650nm (SWIR)
    - 2190nm (SWIR)

The Hedley-specific config would be in the form of:
    [ImagePath]: [y1, y2, x1, x2]
e.g.:
    config_dict["glint"]["deep_water_sample"] = {
        "/path_to_image1": [
            137, 574, 8034, 8470
        ],
        "/path_to_image2": [
            48, 393, 5780, 5925
        ],
    }
'''
config_dict["glint"] = {}
config_dict['glint']['type'] = 'hedley'
config_dict['glint']['correction_wave'] = 1650

# External masks for glint correction
mask_files = glob.glob("/data2/prisma/rfl/PRS_20210629153937_20210629153942_0001_modtran/*_cls")
mask_files.sort()
file_dict = dict(zip(images,mask_files))
config_dict['glint']['apply_mask'] = [["external", {'class' : 1,
                                                    'files' : file_dict}]]
# Deep-water sample window [y1, y2, x1, x2] for the Hedley correction.
config_dict["glint"]["deep_water_sample"] = {
    images[0]: [
        225,250,240,260
    ]}

#Wavelength resampling options
##############################
'''
Types supported:
    - 'gaussian': needs output waves and output FWHM
    - 'linear', 'nearest', 'nearest-up',
      'zero', 'slinear', 'quadratic','cubic': Piecewise
      interpolation using Scipy interp1d

config_dict["resampler"] only needed when resampling == True
'''
config_dict["resample"] = False
# config_dict["resampler"] = {}
# config_dict["resampler"]['type'] = 'cubic'
# config_dict["resampler"]['out_waves'] = []
# config_dict["resampler"]['out_fwhm'] = []

# Remove bad bands from output waves
# for wavelength in range(450,660,100):
#     bad=False
#     for start,end in config_dict['bad_bands']:
#         bad = ((wavelength >= start) & (wavelength <=end)) or bad
#     if not bad:
#         config_dict["resampler"]['out_waves'].append(wavelength)

# One worker per input image.
config_dict['num_cpus'] = len(images)

with open(config_file, 'w') as outfile:
    json.dump(config_dict,outfile,indent=3)

View File

@ -0,0 +1,338 @@
# python c:/mydata/image_correct_json_generate_gui.py
import os
import json
import glob
import numpy as np
import tkinter as tk
from tkinter.filedialog import asksaveasfilename, askdirectory
def fill_config(images, anc_files, out_coef_dir, img_file_type, corr_list,
                flag_pre_compute=False, topo_coeff=None, brdf_coeff=None):
    """Build a hytools image_correct configuration dictionary.

    Parameters
    ----------
    images : list of str
        Image file paths (sorted in place).
    anc_files : list of str
        Ancillary (obs_ort) file paths, one per image; only used when
        img_file_type is 'envi' (sorted in place in that case).
    out_coef_dir : str
        Output directory for exported coefficients/images.
    img_file_type : str
        'envi' or 'neon'.
    corr_list : list of str
        Corrections to apply in order, e.g. ['topo', 'brdf'].
    flag_pre_compute : bool
        If True, use precomputed coefficient JSON files and export the
        corrected image instead of coefficients.
    topo_coeff, brdf_coeff : list of str, optional
        Precomputed coefficient files, one per image; only honoured when
        flag_pre_compute is True and the list length matches `images`.

    Returns
    -------
    dict
        Configuration dictionary ready to be serialized to JSON.
    """
    # Fix: the original used mutable default arguments (topo_coeff=[],
    # brdf_coeff=[]); use None sentinels instead. Behavior is unchanged.
    topo_coeff = [] if topo_coeff is None else topo_coeff
    brdf_coeff = [] if brdf_coeff is None else brdf_coeff

    config_dict = {}
    # Only coefficients for good bands will be calculated
    config_dict['bad_bands'] = [[300, 400], [1337, 1430], [1800, 1960], [2450, 2600]]
    config_dict['file_type'] = img_file_type  # 'envi' or 'neon'
    aviris_anc_names = ['path_length', 'sensor_az', 'sensor_zn',
                        'solar_az', 'solar_zn', 'phase', 'slope',
                        'aspect', 'cosine_i', 'utc_time']
    images.sort()
    config_dict["input_files"] = images
    if img_file_type == 'envi':
        # Ancillary bands are matched positionally to aviris_anc_names.
        config_dict["anc_files"] = {}
        anc_files.sort()
        for i, image in enumerate(images):
            config_dict["anc_files"][image] = dict(zip(aviris_anc_names,
                                                       [[anc_files[i], a] for a in range(len(aviris_anc_names))]))
    config_dict['export'] = {}
    config_dict["topo"] = {}
    config_dict["brdf"] = {}
    if flag_pre_compute:
        # Coefficients already exist: export corrected imagery only.
        config_dict['export']['coeffs'] = False
        config_dict['export']['image'] = True
    else:
        config_dict['export']['coeffs'] = True
        config_dict['export']['image'] = False
    config_dict['export']['masks'] = False
    config_dict['export']['subset_waves'] = [660, 550, 440]  # RGB preview bands
    config_dict['export']['output_dir'] = out_coef_dir
    print('_'.join(corr_list))
    # Output filename suffix mirrors the applied corrections.
    if len(corr_list) > 0:
        config_dict['export']["suffix"] = '_'.join(corr_list)
    else:
        config_dict['export']["suffix"] = 'raw'
    config_dict["corrections"] = corr_list
    if flag_pre_compute and len(topo_coeff) == len(images):
        # Coefficient files are matched to images by sorted order.
        config_dict["topo"]['type'] = 'precomputed'
        topo_files = sorted(topo_coeff)
        config_dict["topo"]['coeff_files'] = dict(zip(images, topo_files))
    else:
        config_dict["topo"]['type'] = 'scs+c'
        config_dict["topo"]['calc_mask'] = [["ndi", {'band_1': 850, 'band_2': 660,
                                                     'min': 0.05, 'max': 1.0}]]
        config_dict["topo"]['apply_mask'] = [["ndi", {'band_1': 850, 'band_2': 660,
                                                      'min': 0.05, 'max': 1.0}]]
        config_dict["topo"]['c_fit_type'] = 'nnls'  # alternative: 'ols'
    if flag_pre_compute and len(brdf_coeff) == len(images):
        config_dict["brdf"]['type'] = 'precomputed'
        brdf_files = sorted(brdf_coeff)
        config_dict["brdf"]['coeff_files'] = dict(zip(images, brdf_files))
    else:
        # Options are 'line','scene', or a float for a custom solar zenith
        # angle (custom angle should be in radians).
        config_dict["brdf"]['solar_zn_type'] = 'scene'
        # Flex BRDF configuration
        config_dict["brdf"]['type'] = 'flex'
        config_dict["brdf"]['grouped'] = True
        config_dict["brdf"]['geometric'] = 'li_sparse_r'
        config_dict["brdf"]['volume'] = 'ross_thick'
        config_dict["brdf"]["b/r"] = 2.5
        config_dict["brdf"]["h/b"] = 2
        config_dict["brdf"]['sample_perc'] = 0.1
        config_dict["brdf"]['interp_kind'] = 'linear'
        config_dict["brdf"]['calc_mask'] = [["ndi", {'band_1': 850, 'band_2': 660,
                                                     'min': 0.05, 'max': 1.0}],
                                            ['kernel_finite', {}],
                                            ['ancillary', {'name': 'sensor_zn',
                                                           'min': np.radians(2), 'max': 'inf'}]]
        config_dict["brdf"]['apply_mask'] = [["ndi", {'band_1': 850, 'band_2': 660,
                                                      'min': 0.05, 'max': 1.0}]]
        # Flex dynamic NDVI binning parameters
        config_dict["brdf"]['bin_type'] = 'dynamic'
        config_dict["brdf"]['num_bins'] = 18
        config_dict["brdf"]['ndvi_bin_min'] = 0.05
        config_dict["brdf"]['ndvi_bin_max'] = 1.0
        config_dict["brdf"]['ndvi_perc_min'] = 10
        config_dict["brdf"]['ndvi_perc_max'] = 95
    config_dict["resample"] = False
    config_dict['num_cpus'] = len(images)
    return config_dict
'''
def update_corr_list(corr_list_):
    #print(chk_topo.get(),chk_brdf.get())
    corr_list_ = ['topo']*chk_topo.get()+['brdf']*chk_brdf.get()
    #print(corr_list)
'''
def gen_config(entry_outdir, entry_outjson,img_list_out, obs_list_out, radio_f_type, corr_list,chk_precompute, topo_list_out,brdf_list_out):
    """Collect values from the GUI widgets, build the config dict via
    fill_config and save it to the chosen JSON file."""
    outdir_name = entry_outdir.get()+'/'
    out_json = entry_outjson.get()
    # The list Labels hold one file path per line.
    images = img_list_out['text'].split('\n')
    anc_files = obs_list_out['text'].split('\n')
    img_file_type = str(radio_f_type.get()).lower()
    # corr_list is [topo_checkbox_var, brdf_checkbox_var]; each .get() is 0 or 1.
    corr_list_str = ['topo']*(corr_list[0].get()) + ['brdf']*(corr_list[1].get())
    flag_pre_compute = bool(chk_precompute.get())
    #print(flag_pre_compute,chk_precompute.get())
    if flag_pre_compute:
        topo_json_list = topo_list_out['text'].split('\n')
        brdf_json_list = brdf_list_out['text'].split('\n')
        return_json_dict = fill_config(images, anc_files,outdir_name,img_file_type,corr_list_str,flag_pre_compute=flag_pre_compute, topo_coeff = topo_json_list, brdf_coeff=brdf_json_list)
    else:
        return_json_dict = fill_config(images, anc_files,outdir_name,img_file_type,corr_list_str)
    with open(out_json, 'w') as outfile:
        json.dump(return_json_dict,outfile,indent=3)
    # window is the module-level Tk root.
    window.title(f"File saved- {out_json}")
def save_file(txt_out_json):
    """Prompt for an output JSON path and place it in the given Entry widget."""
    filepath = asksaveasfilename(
        defaultextension=".json",
        filetypes=[("JSON Files", "*.json"), ("All Files", "*.*")],
    )
    if not filepath:
        return  # dialog cancelled
    txt_out_json.delete(0,tk.END)
    txt_out_json.insert(0,filepath)
    window.title(f"Export Configuration JSON - {filepath}")
def open_folder(out_component):
    """Prompt for a directory and place its path in the given Entry widget."""
    in_img_dir = askdirectory()
    if not in_img_dir:
        return  # dialog cancelled
    #print(in_img_dir)
    out_component.delete(0, tk.END)
    out_component.insert(0,in_img_dir)
    window.title(f"Folder- {in_img_dir}")
def open_file(out_component, list_component, radio_f_type=None, pattern=None):
    """Prompt for a folder and list the files in it matching *pattern*.

    The folder path is shown in out_component (a Label) and the matching
    files, one per line, in list_component. When pattern is None it is
    derived from the selected file type ('envi' -> '*img', 'neon' -> '*.h5').
    """
    in_img_dir = askdirectory()
    if not in_img_dir:
        return  # dialog cancelled
    #out_component.delete("1.0", tk.END)
    #out_component.insert("1.0",in_img_dir)
    out_component["text"] = in_img_dir
    if pattern is None:
        img_file_type = str(radio_f_type.get()).lower()
        file_ext_list = {'envi':'*img','neon':'*.h5'}
        pattern = file_ext_list[img_file_type]
    #if pattern is None:
    #    return in_img_dir
    file_list = sorted(glob.glob(in_img_dir+'/'+pattern))
    if len(file_list)==0:
        list_component['text']= 'No files selected'
        return
    list_component['text']= '\n'.join([ os.path.normpath(x) for x in file_list])
    #print(list_component['text'])
    window.title(f"Folder- {in_img_dir}")
# ---------- GUI construction (module level) ----------
window = tk.Tk()
window.title("Setup image correction configuration file")
#window.rowconfigure(0, minsize=400, weight=1)
#window.columnconfigure(1, minsize=600, weight=1)
#txt_edit = tk.Text(window)

# Frames: image folder, obs folder, output dir, output JSON path,
# file type, correction type, precomputed-coefficient pickers, Generate button.
frm_img_buttons = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_img_buttons.columnconfigure(1, minsize=300, weight=1)
frm_img_buttons.rowconfigure(1, minsize=50, weight=1)
frm_obs_buttons = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_obs_buttons.columnconfigure(1, minsize=300, weight=1)
frm_obs_buttons.rowconfigure(1, minsize=50, weight=1)
frm_out_buttons = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_out_buttons.columnconfigure(1, weight=1) #, minsize=300
frm_out_json_buttons = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_out_json_buttons.columnconfigure(1, weight=1) #minsize=300,
frm_file_type = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_corr_type = tk.Frame(window, relief=tk.RAISED, bd=2)
frm_precomp = tk.Frame(window, relief=tk.RAISED, bd=2,highlightbackground="grey",highlightthickness=5,padx=5, pady=5)
frm_precomp.columnconfigure(1,weight=1)
frm_precomp.rowconfigure(1, minsize=50, weight=1)
frm_pre_topo_buttons = tk.Frame(frm_precomp, relief=tk.RAISED, bd=2)
frm_pre_topo_buttons.columnconfigure(1, minsize=280, weight=1)
#frm_pre_topo_buttons.rowconfigure(1, minsize=50, weight=1)
frm_pre_brdf_buttons = tk.Frame(frm_precomp, relief=tk.RAISED, bd=2)
frm_pre_brdf_buttons.columnconfigure(1, minsize=280, weight=1)
#frm_pre_brdf_buttons.rowconfigure(1, minsize=50, weight=1)
frm_final_gen = tk.Frame(window, relief=tk.RAISED, bd=2,highlightbackground="grey",highlightthickness=5)
frm_final_gen.rowconfigure(0, minsize=50, weight=0)
frm_final_gen.columnconfigure(0, minsize=100, weight=1) #
frm_final_gen.columnconfigure(1, minsize=50, weight=0)
frm_final_gen.columnconfigure(2, minsize=100, weight=1)

# Folder labels, path entries and file-list display labels.
label_img_dir = tk.Label(frm_img_buttons,text="image", fg="white", bg="black")
label_obs_dir = tk.Label(frm_obs_buttons,text="image", fg="white", bg="black")
txt_outdir = tk.Entry(frm_out_buttons)
img_list_out = tk.Label(frm_img_buttons,bg="grey",anchor="w")
obs_list_out = tk.Label(frm_obs_buttons,bg="grey",anchor="w")
btn_open1 = tk.Button(frm_img_buttons, text="Image Folder...", command=lambda: open_file(label_img_dir,img_list_out,radio_f_type=f_type) ) #'*img'
btn_open2 = tk.Button(frm_obs_buttons, text="Obs_ort Folder...", command=lambda: open_file(label_obs_dir,obs_list_out,pattern='*obs_ort'))
btn_open3 = tk.Button(frm_out_buttons, text="Output coeff Folder...", command=lambda: open_folder(txt_outdir)) #, command=open_file(txt_edit, None)
#btn_open1 = tk.Button(frm_img_buttons, text="Image Folder...", command=lambda: open_file(label_img_dir,img_list_out,radio_f_type=f_type) ) #'*img'
#btn_open2 = tk.Button(frm_obs_buttons, text="Obs_ort Folder...", command=lambda: open_file(label_obs_dir,obs_list_out,pattern='*obs_ort'))
btn_open_topo = tk.Button(frm_pre_topo_buttons, text="TOPO json Folder...", command=lambda: open_file(label_pre_topo_dir,topo_list_out,pattern='*topo_coeffs*.json'))
btn_open_brdf = tk.Button(frm_pre_brdf_buttons, text="BRDF json Folder...", command=lambda: open_file(label_pre_brdf_dir,brdf_list_out,pattern='*brdf_coeffs*.json'))
txt_out_json = tk.Entry(frm_out_json_buttons)
btn_save = tk.Button(frm_out_json_buttons, text="Save As...", command=lambda: save_file(txt_out_json))

# Grid layout inside each frame.
btn_open1.grid(row=0, column=0, sticky="ew", padx=5, pady=5)
label_img_dir.grid(row=0, column=1, sticky="ew", padx=5, pady=5)
img_list_out.grid(row=1, columnspan=2, sticky="nsew", padx=5, pady=5)
btn_open2.grid(row=0, column=0, sticky="ew", padx=5, pady=5)
label_obs_dir.grid(row=0, column=1, sticky="ew", padx=5, pady=5)
obs_list_out.grid(row=1, columnspan=2, sticky="nsew", padx=5, pady=5)
btn_open3.grid(row=0, column=0, sticky="ew", padx=5, pady=5)
txt_outdir.grid(row=0, column=1, sticky="ew", padx=5, pady=5)
btn_save.grid(row=0, column=0, sticky="ew", padx=5)
txt_out_json.grid(row=0, column=1, sticky="ew", padx=5, pady=5)

# Input file type selector (ENVI vs NEON HDF5).
f_type = tk.StringVar(value="envi")
ev_btn = tk.Radiobutton(frm_file_type, text='ENVI (*img)', variable=f_type, value='envi')
h5_btn = tk.Radiobutton(frm_file_type, text='NEON HDF5 (*.h5)', variable=f_type, value='neon')
ev_btn.grid(row=0, column=0, sticky="ew", padx=5)
h5_btn.grid(row=0, column=1, sticky="ew", padx=5)

# Correction toggles; BRDF enabled by default.
chk_topo = tk.IntVar(value=0)
chk_brdf = tk.IntVar(value=1)
chk_topo_btn = tk.Checkbutton(frm_corr_type, text='TOPO', variable=chk_topo, onvalue=1, offvalue=0) #, command=lambda: update_corr_list(corr_list))
chk_brdf_btn = tk.Checkbutton(frm_corr_type, text='BRDF', variable=chk_brdf, onvalue=1, offvalue=0) #, command=lambda: update_corr_list(corr_list))
chk_topo_btn.grid(row=0, column=0, sticky="ew", padx=5)
chk_brdf_btn.grid(row=0, column=1, sticky="ew", padx=5)
corr_list = [chk_topo,chk_brdf]

# Precomputed-coefficient selection widgets.
chk_precompute = tk.IntVar(value=0)
coeff_list= []
label_precomput = tk.Checkbutton(frm_precomp, text='Load Precomputed Coefficients', variable=chk_precompute, onvalue=1, offvalue=0,anchor="center") #tk.Label(frm_precomp,text="Precomputed Coefficients")
label_precomput.grid(row=0, columnspan=2, sticky="ew", padx=2, pady=2)
btn_open_topo.grid(row=0, column=0, sticky="ew", padx=2, pady=2)
btn_open_brdf.grid(row=0, column=0, sticky="ew", padx=2, pady=2)
label_pre_topo_dir = tk.Label(frm_pre_topo_buttons,text="topo json", fg="white", bg="black")
label_pre_topo_dir.grid(row=0, column=1, sticky="ew", padx=2, pady=2)
label_pre_brdf_dir = tk.Label(frm_pre_brdf_buttons,text="brdf json", fg="white", bg="black")
label_pre_brdf_dir.grid(row=0, column=1, sticky="ew", padx=2, pady=2)
topo_list_out = tk.Label(frm_pre_topo_buttons,bg="grey",anchor="w")
topo_list_out.grid(row=1, columnspan=2, sticky="ew", padx=2,pady=2)
brdf_list_out = tk.Label(frm_pre_brdf_buttons,bg="grey",anchor="w")
brdf_list_out.grid(row=1, columnspan=2, sticky="ew", padx=2,pady=2)
frm_pre_topo_buttons.grid(row=1, column=0, sticky="ew")
frm_pre_brdf_buttons.grid(row=1, column=1, sticky="ew")
#img_list_out = tk.Label(frm_img_buttons,bg="grey",anchor="w")
#obs_list_out = tk.Label(frm_obs_buttons,bg="grey",anchor="w")

# Final "Generate" button writes the configuration JSON.
btn_gen = tk.Button(frm_final_gen, text="Generate", font=("Calibri",12,"bold"), command=lambda: gen_config(txt_outdir,txt_out_json,img_list_out, obs_list_out,f_type,corr_list,chk_precompute,topo_list_out, brdf_list_out))
#btn_gen.place(relx=.5, rely=.5,anchor= 'e')
btn_gen.grid(row=0,column=1,sticky="wens") #anchor='center',
#btn_gen.place(relx=0.5, rely=0.95, anchor=tk.CENTER)

# Top-level frame layout.
frm_img_buttons.grid(row=0, column=0, sticky="ew")
frm_obs_buttons.grid(row=0, column=1, sticky="ew")
frm_out_buttons.grid(row=1, columnspan=2, sticky="we")
frm_file_type.grid(row=2,column=0, sticky="ew")
frm_corr_type.grid(row=2,column=1, sticky="ew")
frm_precomp.grid(row=3,columnspan=2, sticky="nsew")
frm_out_json_buttons.grid(row=4, columnspan=2, sticky="ew")
frm_final_gen.grid(row=6, columnspan=2, sticky="nsew") #, sticky="nsew"
window.mainloop()

View File

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
HyTools: Hyperspectral image processing library
Copyright (C) 2021 University of Wisconsin
Authors: Adam Chlus, Zhiwei Ye, Philip Townsend.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
JSON format for PLSR trait models v0.1
This version assumes that all transforms are applied using all model wavelengths.
TODO: Develop standard codes for spectrometer, both airborne/spaceborne and field.
TODO: Allow for more options for specifying wavelength subsets
"""
import json
model_dict = {}

# Metadata
#####################################
'''
trait : Trait name (str)
units : Trait units (str)
description: Model description (str)
wavelength_units: Wavelength units (str)
wavelengths : Model wavelengths (list)
            Only wavelengths used in the model should be
            included in the list of wavelengths.
fwhm : Model fwhm (list)
type : Model type (str)
'''
model_dict["name"] = ''
model_dict["units"] = ''
model_dict["description"] = ''
model_dict["wavelength_units"] = ''
model_dict["wavelengths"] = []
model_dict["fwhm"] = []
model_dict["spectrometer"] = ''
model_dict["type"] = ''

# Diagnostics
#####################################
'''Currently the only required diagnostics are 'min'
and 'max', these are the min and max values of the
dataset used to build the model and are used to generate
the data range mask, which identifies pixels with predictions
outside of the model dataset range.
'''
model_dict["model_diagnostics"] = {}
model_dict["model_diagnostics"]["rmse"] = 0.0
model_dict["model_diagnostics"]["r_squared"] = 0.0
model_dict["model_diagnostics"]["min"] = 0.0
model_dict["model_diagnostics"]["max"] = 0.0

# Model
#####################################
'''
transform: List of transforms to be applied in order of application.
    Options:
        - 'vector': vector norm using np.linalg.norm
        - 'mean' : Normalize to mean
        - 'absorb' : log(1/R)
    Examples:
        ['vector','absorb']
        Empty list for no transforms ([])
coefficients: List of lists, sublists are the coefficients for
    model iterations.
intercepts : Permuted model intercepts (list)
components : Number of model components (int)
'''
model_dict['model'] = {}
model_dict['model']["components"] = 0
model_dict['model']["transform"] = ['mean']
model_dict['model']["intercepts"] = []
model_dict['model']["coefficients"] =[[],[]]

# NOTE(review): '*.json' is a placeholder output path — replace with a real
# file path before running this template.
model_path = '*.json'
with open(model_path, 'w') as outfile:
    json.dump(model_dict,outfile)

View File

@ -0,0 +1,99 @@
'''Template script for generating trait_estimate configuration JSON files.
'''
import os
import json
import glob
home = os.path.expanduser("~")

# Output path for configuration file
# NOTE(review): "/.json" is a placeholder — set a real path before running.
config_file = "/.json"
config_dict = {}
config_dict['file_type'] = 'envi'
config_dict["output_dir"] = './'
config_dict['bad_bands'] =[[300,400],[1337,1430],[1800,1960],[2450,2600]]

# Input data settings for NEON
#################################################################
# config_dict['file_type'] = 'neon'
# images= glob.glob("*.h5")
# images.sort()
# config_dict["input_files"] = images

# Input data settings for ENVI
#################################################################
''' Only difference between ENVI and NEON settings is the specification
of the ancillary datasets (ex. viewing and solar geometry). All hytools
functions assume that the ancillary data and the image data are the same
size, spatially, and are ENVI formatted files.

The ancillary parameter is a dictionary with a key per image. Each value
per image is also a dictionary where the key is the dataset name and the
value is a list consisting of the file path and the band number.
'''
config_dict['file_type'] = 'envi'
aviris_anc_names = ['path_length','sensor_az','sensor_zn',
                    'solar_az', 'solar_zn','phase','slope',
                    'aspect', 'cosine_i','utc_time']
images= glob.glob("*img")
images.sort()
config_dict["input_files"] = images

config_dict["anc_files"] = {}
anc_files = glob.glob("*ort")
anc_files.sort()
# Ancillary bands are matched positionally to aviris_anc_names.
for i,image in enumerate(images):
    config_dict["anc_files"][image] = dict(zip(aviris_anc_names,
                                               [[anc_files[i],a] for a in range(len(aviris_anc_names))]))
config_dict['num_cpus'] = len(images)

# Assign correction coefficients
##########################################################
''' Specify correction(s) to apply and paths to coefficients.
'''
config_dict['corrections'] = ['topo','brdf']

# Coefficient files are matched to images by sorted order.
topo_files = glob.glob("*topo.json")
topo_files.sort()
config_dict["topo"] = dict(zip(images,topo_files))

brdf_files = glob.glob("*brdf.json")
brdf_files.sort()
config_dict["brdf"] = dict(zip(images,brdf_files))

# Select wavelength resampling type
##########################################################
'''Wavelength resampler will only be used if image wavelengths
and model wavelengths do not match exactly

See image_correct_json_generate.py for options.
'''
config_dict["resampling"] = {}
config_dict["resampling"]['type'] = 'cubic'

# Masks
##########################################################
'''Specify list of masking layers to be appended to the
trait map. Each will be placed in a separate layer.

For no masks provide an empty list: []
'''
config_dict["masks"] = [["ndi", {'band_1': 850,'band_2': 660,
                                 'min': 0.1,'max': 1.0}],
                        ['neon_edge',{'radius': 30}]]

# Define trait coefficients
##########################################################
models = glob.glob('*.json')
models.sort()
config_dict["trait_models"] = models

with open(config_file, 'w') as outfile:
    json.dump(config_dict,outfile)

View File

@ -0,0 +1,256 @@
import json
import os
import warnings
import sys
import time
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
# ---------- 全局函数:供 Ray 远程调用 ----------
def export_coeffs(hy_obj, export_dict):
    """Export topo/BRDF correction coefficients to per-image JSON files.

    One JSON file is written per correction in ``hy_obj.corrections``.
    Glint has no per-image coefficient dictionary and is skipped.

    :param hy_obj: HyTools image object carrying ``topo``/``brdf`` dicts.
    :param export_dict: export settings with ``output_dir`` and ``suffix``.
    """
    for correction in hy_obj.corrections:
        # Bug fix: skip glint *before* opening the output file; the original
        # opened the file first and left an empty JSON file behind.
        if correction == 'glint':
            continue
        coeff_file = export_dict['output_dir']
        coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj, config_dict):
    """Apply the configured corrections to an image and export it as ENVI.

    Writes either all wavelengths (line by line, optionally resampled) or a
    requested subset of wavelengths (band by band), then optionally writes a
    byte-mask image per applied correction.

    :param hy_obj: HyTools image object with corrections already configured.
    :param config_dict: full pipeline configuration (``export`` section used).
    """
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # ENVI data type code 12 = unsigned 16-bit integer output.
    header_dict['data type'] = 12
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]
    # Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        if config_dict["resample"]:
            hy_obj.resampler = config_dict['resampler']
            waves = hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name, header_dict)
        # Stream the image line by line with corrections applied on the fly.
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line, iterator.current_line)
        writer.close()
    # Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to nearest bands, then report the
        # actual band-center wavelengths in the output header.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x], 2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name, header_dict)
        for b, band_num in enumerate(bands):
            band = hy_obj.get_band(band_num, corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()
    # Export masks
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        # One mask layer per (correction, apply_mask entry) pair.
        for correction in config_dict["corrections"]:
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # ENVI data type code 1 = 8-bit byte; 255 marks no-data pixels.
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name, header_dict)
        for band_num, mask in enumerate(masks):
            mask = mask.astype(int)
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask, band_num)
        del masks
# ---------- 主类 ----------
class HyToolsCorrector:
    """Hyperspectral image correction driver built on hytools + Ray.

    Usage::

        corrector = HyToolsCorrector()
        corrector.run("path/to/config.json")
        # or as a one-shot call
        HyToolsCorrector.process_file("path/to/config.json")

    May also be used as a context manager; Ray is shut down on exit only
    when it was started by this instance.
    """

    def __init__(self):
        """Create the processor without starting Ray."""
        self.config_dict = None
        # True only when *this* instance started Ray and therefore owns
        # its shutdown; an externally started Ray is never shut down here.
        self.ray_initialized = False

    def load_config(self, config_file: str) -> dict:
        """Load a JSON configuration file and keep it on the instance."""
        with open(config_file, 'r') as f:
            self.config_dict = json.load(f)
        return self.config_dict

    def _init_ray(self):
        """Initialize Ray if it is not already running."""
        if not ray.is_initialized():
            num_cpus = self.config_dict.get('num_cpus', None)
            print(f"Initializing Ray with {num_cpus} CPUs.")
            ray.init(num_cpus=num_cpus)
            self.ray_initialized = True
        else:
            # Bug fix: do NOT claim ownership of a Ray instance started
            # elsewhere, otherwise _shutdown_ray() would tear down a
            # runtime this object did not create.
            print("Ray already initialized, reusing existing instance.")

    def _shutdown_ray(self):
        """Shut down Ray, but only if it was started by this instance."""
        if self.ray_initialized and ray.is_initialized():
            ray.shutdown()
            self.ray_initialized = False

    def run(self, config_file: str = None):
        """Run the full correction pipeline.

        :param config_file: path to a JSON config; if None, load_config()
            must have been called beforehand.
        :raises ValueError: if no config is loaded or the file type is
            unsupported.
        """
        if config_file is not None:
            self.load_config(config_file)
        if self.config_dict is None:
            raise ValueError("配置文件未加载,请先调用 load_config() 或传入文件路径。")
        print("Starting image correction process...")
        total_start = time.time()
        images = self.config_dict["input_files"]
        # Start (or reuse) Ray.
        self._init_ray()
        # One remote HyTools actor per input image.
        HyTools = ray.remote(ht.HyTools)
        actors = [HyTools.remote() for _ in images]
        # Load imagery; ENVI inputs also take an ancillary file mapping.
        print(f"Loading {len(images)} image(s)...")
        load_start = time.time()
        if self.config_dict['file_type'] == 'envi':
            anc_files = self.config_dict.get("anc_files", {})
            _ = ray.get([a.read_file.remote(img, self.config_dict['file_type'],
                                            anc_files.get(img, None))
                         for a, img in zip(actors, images)])
        elif self.config_dict['file_type'] == 'neon':
            _ = ray.get([a.read_file.remote(img, self.config_dict['file_type'])
                         for a, img in zip(actors, images)])
        else:
            raise ValueError(f"不支持的文件类型: {self.config_dict['file_type']}")
        _ = ray.get([a.create_bad_bands.remote(self.config_dict['bad_bands']) for a in actors])
        load_time = time.time() - load_start
        print(f"加载完成,耗时 {load_time:.2f} 秒")
        # Apply the configured corrections in order.
        if self.config_dict.get("corrections"):
            corr_list = self.config_dict["corrections"]
            print(f"Applying {len(corr_list)} correction(s): {', '.join(corr_list)}")
            correction_start = time.time()
            for correction in corr_list:
                if correction == 'topo':
                    calc_topo_coeffs(actors, self.config_dict['topo'])
                elif correction == 'brdf':
                    calc_brdf_coeffs(actors, self.config_dict)
                elif correction == 'glint':
                    set_glint_parameters(actors, self.config_dict)
            correction_time = time.time() - correction_start
            print(f"校正完成,耗时 {correction_time:.2f} 秒")
        # Export correction coefficients, if requested.
        export_cfg = self.config_dict.get('export', {})
        if export_cfg.get('coeffs') and self.config_dict.get("corrections"):
            print("Exporting correction coefficients...")
            coeff_start = time.time()
            _ = ray.get([a.do.remote(export_coeffs, export_cfg) for a in actors])
            print(f"系数导出完成,耗时 {time.time() - coeff_start:.2f} 秒")
        # Export corrected imagery, if requested.
        if export_cfg.get('image'):
            print("Exporting corrected images...")
            export_start = time.time()
            _ = ray.get([a.do.remote(apply_corrections, self.config_dict) for a in actors])
            print(f"图像导出完成,耗时 {time.time() - export_start:.2f} 秒")
        total_time = time.time() - total_start
        print(f"全部任务完成,总耗时 {total_time:.2f} 秒")
        # Shut down Ray only if this instance started it.
        self._shutdown_ray()

    @staticmethod
    def process_file(config_file: str):
        """One-shot helper: process a single config file end to end."""
        corrector = HyToolsCorrector()
        corrector.run(config_file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._shutdown_ray()
# ---------- 向后兼容的 main 入口 ----------
def main():
    """Entry point kept for compatibility with the original script.

    Uses the first CLI argument as the config path when given, otherwise
    falls back to the bundled example configuration.
    """
    default_cfg = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_cfg
    HyToolsCorrector.process_file(config_path)
# Script entry point.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,189 @@
import json
import os
import warnings
import sys
import time
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    """Run the correction pipeline for a hard-coded configuration file.

    Loads the JSON config, starts Ray with the configured CPU count, spins
    up one HyTools actor per input image, applies the configured
    topo/BRDF/glint corrections, and exports coefficients and corrected
    imagery as requested by the config's ``export`` section.
    """
    config_file = r"E:\code\hytools-master\hytools-master\examples\configs\aviris.json"
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    print("Starting image correction process...")
    total_start = time.time()
    images = config_dict["input_files"]
    # Restart Ray so it runs with the configured CPU count.
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for image in images]
    print(f"Loading {len(images)} image(s)...")
    load_start = time.time()
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                        anc_files[image]) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
    load_time = time.time() - load_start
    # Bug fix: the timing prints below were literal ".2f" strings instead
    # of f-strings, so no timing information was ever reported.
    print(f"Loading finished in {load_time:.2f} s.")
    # Apply corrections
    if config_dict["corrections"]:
        print(f"Applying {len(config_dict['corrections'])} correction(s): {', '.join(config_dict['corrections'])}")
        correction_start = time.time()
        for correction in config_dict["corrections"]:
            if correction =='topo':
                calc_topo_coeffs(actors,config_dict['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors,config_dict)
            elif correction == 'glint':
                set_glint_parameters(actors,config_dict)
        correction_time = time.time() - correction_start
        print(f"Corrections finished in {correction_time:.2f} s.")
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients...")
        coeff_start = time.time()
        _ = ray.get([a.do.remote(export_coeffs,config_dict['export']) for a in actors])
        coeff_time = time.time() - coeff_start
        print(f"Coefficient export finished in {coeff_time:.2f} s.")
    if config_dict['export']['image']:
        print("Exporting corrected images...")
        export_start = time.time()
        _ = ray.get([a.do.remote(apply_corrections,config_dict) for a in actors])
        export_time = time.time() - export_start
        print(f"Image export finished in {export_time:.2f} s.")
    total_time = time.time() - total_start
    print(f"Total time: {total_time:.2f} s.")
    ray.shutdown()
def export_coeffs(hy_obj, export_dict):
    '''Export topo/BRDF correction coefficients to per-image JSON files.

    One JSON file is written per correction in ``hy_obj.corrections``.
    Glint has no per-image coefficient dictionary and is skipped.
    '''
    for correction in hy_obj.corrections:
        # Bug fix: skip glint *before* opening the output file; the original
        # opened the file first and left an empty JSON file behind.
        if correction == 'glint':
            continue
        coeff_file = export_dict['output_dir']
        coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj,config_dict):
    '''Apply the configured corrections to an image and export it as ENVI.

    Writes either all wavelengths (line by line, optionally resampled) or a
    requested subset of wavelengths (band by band), then optionally writes a
    byte-mask image per applied correction.
    '''
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # ENVI data type code 12 = unsigned 16-bit integer output.
    header_dict['data type'] = 12
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]
    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        # Stream the image line by line with corrections applied on the fly.
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line,iterator.current_line)
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to nearest bands, then report the
        # actual band-center wavelengths in the output header.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        for b,band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()
    #Export masks
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        # One mask layer per (correction, apply_mask entry) pair.
        for correction in config_dict["corrections"]:
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # ENVI data type code 1 = 8-bit byte; 255 marks no-data pixels.
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name,header_dict)
        for band_num,mask in enumerate(masks):
            mask = mask.astype(int)
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask,band_num)
        del masks
# Script entry point.
if __name__== "__main__":
    main()

View File

@ -0,0 +1,442 @@
import json
import os
import warnings
import sys
import time
import argparse
import ray
import numpy as np
import Flexbrdf.hytools as ht
from Flexbrdf.hytools.io.envi import *
from Flexbrdf.hytools.topo import calc_topo_coeffs
from Flexbrdf.hytools.brdf import calc_brdf_coeffs
from Flexbrdf.hytools.glint import set_glint_parameters
from Flexbrdf.hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
# Default configuration values
# NOTE(review): the 'topo' and 'glint' sections below are commented out,
# so DEFAULT_CONFIG carries no 'topo'/'glint' keys — code that indexes
# config_dict['topo'] directly will raise KeyError.
DEFAULT_CONFIG = {
    'file_type': 'envi',
    'num_cpus': 16,
    'bad_bands': [],
    'corrections': [],
    'resample': False,
    'resampler': {'type': 'cubic', 'out_waves': []},
    'export': {
        'coeffs':True,
        'image': True,
        'masks': False,
        'subset_waves': [],  # empty list = export all wavelengths
        'output_dir': './output/',
        'suffix': 'BRDF'
    },
    'brdf': {
        'type': 'flex',
        'grouped': True,
        'geometric': 'li_dense_r',  # geometric kernel name
        'volume': 'ross_thick',     # volumetric kernel name
        'b/r': 2.5,
        'h/b': 2.0,
        'sample_perc': 0.1,
        'interp_kind': 'linear',
        'calc_mask': [
            ['water', {'band_1': 850, 'band_2': 660,"threshold": 290}],
            ['kernel_finite', {}],
            # 0.03490658503988659 rad == 2 degrees minimum sensor zenith
            ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
        ],
        'apply_mask': [
            ['water', {'band_1': 850, 'band_2': 660,"threshold": 290}]
        ],
        'bin_type': 'dynamic',
        'num_bins': 18,
        'ndvi_bin_min': 0.05,
        'ndvi_bin_max': 1.0,
        'ndvi_perc_min': 10,
        'ndvi_perc_max': 95,
        'solar_zn_type': 'scene'
    },
    # 'topo': {
    #     'type': 'scs+c',
    #     'calc_mask': [
    #         ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}],
    #         ['kernel_finite', {}],
    #         ['ancillary', {'name': 'sensor_zn', 'min': 0.03490658503988659, 'max': 'inf'}]
    #     ],
    #     'apply_mask': [
    #         ['ndi', {'band_1': 850, 'band_2': 660, 'min': 0.05, 'max': 1.0}]
    #     ],
    #     'sample_perc': 0.1,
    #     'subgrouped': False,
    #     'subgroup': {}
    # },
    # 'glint': {
    #     'type': 'hochberg',
    #     'correction_band': 560,
    #     'deep_water_sample': {},
    #     'calc_mask': [
    #         ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}],
    #         ['kernel_finite', {}]
    #     ],
    #     'apply_mask': [
    #         ['ndi', {'band_1': 550, 'band_2': 2150, 'min': -1, 'max': 0}]
    #     ]
    # }
}
def build_config_from_args(args):
    """Build a hytools configuration dictionary from parsed CLI arguments.

    Discovers input images in ``args.input_dir``, pairs each with an
    ancillary angles file found in ``args.anc_dir``, then layers the
    command-line options on top of ``DEFAULT_CONFIG``.

    :param args: argparse.Namespace produced by main()'s parser.
    :returns: configuration dict consumable by the correction pipeline.
    :raises ValueError: if no input files are found.
    """
    # Discover input files (a directory is scanned; a file is used as-is).
    input_files = []
    if os.path.isdir(args.input_dir):
        extensions = ['.tif', '.tiff', '.envi', '.img', '.dat']
        for file in os.listdir(args.input_dir):
            if any(file.lower().endswith(ext) for ext in extensions):
                input_files.append(os.path.join(args.input_dir, file))
    else:
        input_files = [args.input_dir] if os.path.isfile(args.input_dir) else []
    if not input_files:
        raise ValueError(f"No input files found in {args.input_dir}")
    # Derive the ancillary (angles) file for each image from its name.
    # Example image:  2025_9_2_3_53_45_202592_35252_0_rad_geo_corrected_reflectance.dat
    # Matching file:  2025_9_2_3_53_45_202592_35252_0_rad_rgbxyz_geo_angles_registered.bip
    anc_files = {}
    for input_file in input_files:
        base_name = os.path.splitext(os.path.basename(input_file))[0]
        parts = base_name.split('_')
        if len(parts) >= 9:  # need the full timestamp/id prefix
            anc_base = f"{parts[0]}_{parts[1]}_{parts[2]}_{parts[3]}_{parts[4]}_{parts[5]}_{parts[6]}_{parts[7]}_{parts[8]}_rad_rgbxyz_geo_angles_registered"
            # NOTE(review): forces Windows-style separators — confirm this
            # tool is only ever run on Windows before porting.
            anc_dir_clean = args.anc_dir.replace('/', '\\')
            anc_path = os.path.join(anc_dir_clean, anc_base + ".bip")
            anc_path = os.path.normpath(anc_path)
            if os.path.exists(anc_path):
                # Band layout of the angles file; datasets without a
                # dedicated band are mapped to band 0 as placeholders.
                anc_files[input_file] = {
                    "path_length": [anc_path, 0],
                    "sensor_az": [anc_path, 9],
                    "sensor_zn": [anc_path, 8],
                    "solar_az": [anc_path, 7],
                    "solar_zn": [anc_path, 6],
                    "phase": [anc_path, 0],
                    "slope": [anc_path, 0],
                    "aspect": [anc_path, 0],
                    "cosine_i": [anc_path, 0],
                    "utc_time": [anc_path, 0]
                }
            else:
                # Fall back to a conventional *_obs file when present.
                obs_path = os.path.join(args.anc_dir, base_name + "_obs")
                if os.path.exists(obs_path):
                    anc_files[input_file] = obs_path
    # Layer CLI options over the defaults.
    # NOTE: .copy() is shallow; nested sections are re-copied before they
    # are mutated below so DEFAULT_CONFIG itself stays untouched.
    config_dict = DEFAULT_CONFIG.copy()
    config_dict.update({
        'input_files': input_files,
        'anc_files': anc_files,
        'file_type': 'envi',  # all discovered inputs are assumed ENVI
        'num_cpus': args.num_cpus,
        'bad_bands': args.bad_bands if args.bad_bands else [],
        'corrections': args.corrections if args.corrections else [],
        'export': {
            'coeffs': args.export_coeffs,
            'image': args.export_image,
            'masks': args.export_masks,
            'subset_waves': args.subset_waves if args.subset_waves else [],
            'output_dir': args.output_dir,
            'suffix': args.suffix
        }
    })
    # BRDF overrides from the command line.
    if 'brdf' in args.corrections:
        brdf_config = config_dict['brdf'].copy()
        if hasattr(args, 'brdf_type') and args.brdf_type:
            brdf_config['type'] = args.brdf_type
        if hasattr(args, 'grouped') and args.grouped is not None:
            brdf_config['grouped'] = args.grouped
        if hasattr(args, 'geometric') and args.geometric:
            brdf_config['geometric'] = args.geometric
        if hasattr(args, 'volume') and args.volume:
            brdf_config['volume'] = args.volume
        if hasattr(args, 'num_bins') and args.num_bins:
            brdf_config['num_bins'] = args.num_bins
        config_dict['brdf'] = brdf_config
    # TOPO overrides. Bug fix: DEFAULT_CONFIG ships with its 'topo'
    # section commented out, so direct indexing raised KeyError; use
    # .get() with an empty-dict default instead.
    if 'topo' in args.corrections:
        topo_config = config_dict.get('topo', {}).copy()
        if hasattr(args, 'topo_type') and args.topo_type:
            topo_config['type'] = args.topo_type
        config_dict['topo'] = topo_config
    # Glint overrides. Same KeyError fix as for 'topo'.
    if 'glint' in args.corrections:
        glint_config = config_dict.get('glint', {}).copy()
        if hasattr(args, 'glint_type') and args.glint_type:
            glint_config['type'] = args.glint_type
        config_dict['glint'] = glint_config
    return config_dict
def main():
    """CLI entry point.

    Two modes:
      * ``script.py config.json`` — legacy mode, load a ready-made config.
      * ``script.py INPUT_DIR ANC_DIR --output-dir DIR [...]`` — build the
        config from command-line options (optionally dump it with
        ``--output-config`` and exit without processing).
    """
    parser = argparse.ArgumentParser(description="High-resolution image correction with automatic configuration")
    # Required arguments
    parser.add_argument('input_dir', help='Input directory containing image files or single image file path')
    parser.add_argument('anc_dir', help='Ancillary directory containing angle/obs files')
    parser.add_argument('--output-dir', required=True, help='Output directory for corrected images')
    # Basic settings
    # Bug fix: help text previously claimed "default: 4" while default=16.
    parser.add_argument('--num-cpus', type=int, default=16, help='Number of CPUs to use (default: 16)')
    parser.add_argument('--bad-bands', nargs='*', type=int, default=[], help='Bad band indices to exclude')
    # Correction types
    parser.add_argument('--corrections', nargs='*', choices=['topo', 'brdf', 'glint'],
                       default=['brdf'], help='Correction types to apply (default: brdf)')
    # BRDF options
    parser.add_argument('--brdf-type', choices=['universal', 'flex'], default='flex',
                       help='BRDF correction type (default: flex)')
    parser.add_argument('--grouped', action='store_true', default=True,
                       help='Group images for BRDF correction (default: True)')
    parser.add_argument('--no-grouped', action='store_false', dest='grouped',
                       help='Do not group images for BRDF correction')
    parser.add_argument('--geometric', default='li_dense_r',
                       choices=['li_sparse', 'li_dense', 'li_dense_r', 'roujean'],
                       help='Geometric kernel type (default: li_dense_r)')
    parser.add_argument('--volume', default='ross_thick',
                       choices=['ross_thin', 'ross_thick', 'hotspot', 'roujean'],
                       help='Volume kernel type (default: ross_thick)')
    parser.add_argument('--num-bins', type=int, default=18,
                       help='Number of NDVI bins for FlexBRDF (default: 18)')
    # TOPO options
    parser.add_argument('--topo-type', default='scs+c',
                       choices=['scs', 'scs+c', 'c', 'cosine', 'mod_minneart'],
                       help='TOPO correction type (default: scs+c)')
    # Glint options
    parser.add_argument('--glint-type', default='hochberg',
                       choices=['hochberg', 'gao', 'hedley'],
                       help='Glint correction type (default: hochberg)')
    # Output options
    # NOTE(review): store_true combined with default=True means these
    # flags can never be switched off from the CLI — confirm intended.
    parser.add_argument('--export-coeffs', action='store_true', default=True,
                       help='Export correction coefficients')
    parser.add_argument('--export-image', action='store_true', default=True,
                       help='Export corrected images (default: True)')
    parser.add_argument('--export-masks', action='store_true', default=True,
                       help='Export correction masks (default: True)')
    parser.add_argument('--subset-waves', nargs='*', type=float, default=[],
                       help='Subset of wavelengths to export (empty for all)')
    parser.add_argument('--suffix', default='corrected',
                       help='Suffix for output files (default: corrected)')
    parser.add_argument('--output-config', type=str, default=None,
                       help='Output current configuration to JSON file')
    # Legacy mode: a single JSON argument is treated as a ready-made config.
    if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
        config_file = sys.argv[1]
        with open(config_file, 'r') as outfile:
            config_dict = json.load(outfile)
    else:
        args = parser.parse_args()
        config_dict = build_config_from_args(args)
        # Optionally dump the generated configuration and exit.
        if hasattr(args, 'output_config') and args.output_config:
            print(f"Outputting configuration to {args.output_config}...")
            # Re-order keys to match the expected config file layout.
            ordered_config = {
                "bad_bands": config_dict.get("bad_bands", []),
                "file_type": config_dict.get("file_type", "envi"),
                "input_files": config_dict.get("input_files", []),
                "anc_files": config_dict.get("anc_files", {}),
                "num_cpus": config_dict.get("num_cpus", 4),
                "export": config_dict.get("export", {}),
                "corrections": config_dict.get("corrections", []),
                "brdf": config_dict.get("brdf", {}),
                "topo": config_dict.get("topo", {}),
                "glint": config_dict.get("glint", {}),
                "resample": config_dict.get("resample", False),
                "resampler": config_dict.get("resampler", {"type": "cubic", "out_waves": []})
            }
            with open(args.output_config, 'w', encoding='utf-8') as f:
                json.dump(ordered_config, f, indent=2, ensure_ascii=False)
            print("Configuration output complete.")
            return  # config-dump only; do not run the pipeline
    print("Starting image correction process...")
    total_start = time.time()
    images = config_dict["input_files"]
    # Restart Ray so it runs with the configured CPU count.
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for image in images]
    print(f"Loading {len(images)} image(s)...")
    load_start = time.time()
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                        anc_files[image]) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
    load_time = time.time() - load_start
    # Bug fix: the timing prints below were literal ".2f" strings instead
    # of f-strings, so no timing information was ever reported.
    print(f"Loading finished in {load_time:.2f} s.")
    # Apply corrections
    if config_dict["corrections"]:
        print(f"Applying {len(config_dict['corrections'])} correction(s): {', '.join(config_dict['corrections'])}")
        correction_start = time.time()
        for correction in config_dict["corrections"]:
            if correction =='topo':
                calc_topo_coeffs(actors,config_dict['topo'])
            elif correction == 'brdf':
                calc_brdf_coeffs(actors,config_dict)
            elif correction == 'glint':
                set_glint_parameters(actors,config_dict)
        correction_time = time.time() - correction_start
        print(f"Corrections finished in {correction_time:.2f} s.")
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients...")
        coeff_start = time.time()
        _ = ray.get([a.do.remote(export_coeffs,config_dict['export']) for a in actors])
        coeff_time = time.time() - coeff_start
        print(f"Coefficient export finished in {coeff_time:.2f} s.")
    if config_dict['export']['image']:
        print("Exporting corrected images...")
        export_start = time.time()
        _ = ray.get([a.do.remote(apply_corrections,config_dict) for a in actors])
        export_time = time.time() - export_start
        print(f"Image export finished in {export_time:.2f} s.")
    total_time = time.time() - total_start
    print(f"Total time: {total_time:.2f} s.")
    ray.shutdown()
def export_coeffs(hy_obj, export_dict):
    '''Export topo/BRDF correction coefficients to per-image JSON files.

    One JSON file is written per correction in ``hy_obj.corrections``.
    Glint has no per-image coefficient dictionary and is skipped.
    '''
    for correction in hy_obj.corrections:
        # Bug fix: skip glint *before* opening the output file; the original
        # opened the file first and left an empty JSON file behind.
        if correction == 'glint':
            continue
        coeff_file = export_dict['output_dir']
        coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj,config_dict):
    '''Apply the configured corrections to an image and export it as ENVI.

    Writes either all wavelengths (line by line, optionally resampled) or a
    requested subset of wavelengths (band by band), then optionally writes a
    byte-mask image per applied correction.
    '''
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    # ENVI data type code 12 = unsigned 16-bit integer output.
    header_dict['data type'] = 12
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]
    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        # Stream the image line by line with corrections applied on the fly.
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line,iterator.current_line)
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to nearest bands, then report the
        # actual band-center wavelengths in the output header.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        for b,band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()
    #Export masks
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        # One mask layer per (correction, apply_mask entry) pair.
        for correction in config_dict["corrections"]:
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # ENVI data type code 1 = 8-bit byte; 255 marks no-data pixels.
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name,header_dict)
        for band_num,mask in enumerate(masks):
            mask = mask.astype(int)
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask,band_num)
        del masks
# Script entry point.
if __name__== "__main__":
    main()

View File

@ -0,0 +1,376 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import time
import hytools as ht
from hytools.io.envi import *
from hytools.io.netcdf import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    """Run topo/BRDF/glint correction driven by a JSON configuration file.

    Usage: ``script.py CONFIG_JSON [METADATA_JSON]``. The optional second
    argument (or the config's ``outside_metadata`` entry, which may be a
    path or an inline dict) supplies extra attributes for NetCDF export.
    """
    time_start = time.perf_counter()
    config_file = sys.argv[1]
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    # Resolve external metadata: a CLI argument wins; otherwise a path in
    # the config is loaded in place; otherwise it is set to None.
    if len(sys.argv)>2:
        meta_file = sys.argv[2]
        with open(meta_file, 'r') as outfile:
            meta_dict = json.load(outfile)
        config_dict["outside_metadata"] = meta_dict
    else:
        if "outside_metadata" in config_dict:
            if not isinstance(config_dict["outside_metadata"],dict):
                with open(config_dict["outside_metadata"], 'r') as outfile:
                    # load json and replace it by a dict
                    meta_dict = json.load(outfile)
                    config_dict["outside_metadata"] = meta_dict
        else:
            config_dict["outside_metadata"] = None
    images= config_dict["input_files"]
    # Validate input/output format choices before starting Ray.
    if not config_dict['file_type'].lower() in ['envi','emit','ncav','neon']:
        print("Image type is not recognized.")
        return
    if 'image_format' in config_dict['export']:
        if not config_dict['export']['image_format'].lower() in ['netcdf','envi']:
            print("Export image type is not recognized.")
            return
    else:
        config_dict['export']['image_format']='envi'
    # Restart Ray so it runs with the configured CPU count.
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    # One remote HyTools actor per input image.
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for image in images]
    # Read imagery. ENVI-like inputs need ancillary files when topo/brdf
    # corrections are requested, and may also carry GLT files; NEON files
    # are self-contained.
    if config_dict['file_type'] in ['envi','emit','ncav']:
        anc_files = config_dict["anc_files"]
        if ('topo' in config_dict['corrections']) or ('brdf' in config_dict['corrections']):
            if "glt_files" in config_dict:
                if bool(config_dict["glt_files"]):
                    glt_files = config_dict["glt_files"]
                    _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                 anc_path=anc_files[image],glt_path=glt_files[image]) for a,image in zip(actors,images)])
                else:
                    _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                 anc_path=anc_files[image]) for a,image in zip(actors,images)])
            else:
                _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                             anc_path=anc_files[image]) for a,image in zip(actors,images)])
        else:
            if "glt_files" in config_dict:
                if bool(config_dict["glt_files"]):
                    glt_files = config_dict["glt_files"]
                    _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                 glt_path=glt_files[image]) for a,image in zip(actors,images)])
                else:
                    _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
            else:
                _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
    # Apply the configured corrections, timing each stage.
    for correction in config_dict["corrections"]:
        if correction =='topo':
            time_topo_start = time.perf_counter()
            calc_topo_coeffs(actors,config_dict['topo'])
            time_topo_end = time.perf_counter()
            print("TOPO Time: {} sec.".format(time_topo_end - time_topo_start))
        elif correction == 'brdf':
            time_brdf_start = time.perf_counter()
            calc_brdf_coeffs(actors,config_dict)
            time_brdf_end = time.perf_counter()
            print("BRDF Time: {} sec.".format(time_brdf_end - time_brdf_start))
        elif correction == 'glint':
            time_glint_start = time.perf_counter()
            set_glint_parameters(actors,config_dict)
            time_glint_end = time.perf_counter()
            print("Glint Time: {} sec.".format(time_glint_end - time_glint_start))
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients.")
        _ = ray.get([a.do.remote(export_coeffs,config_dict['export']) for a in actors])
    time_export_start = time.perf_counter()
    if config_dict['export']['image']:
        print("Exporting corrected images.")
        _ = ray.get([a.do.remote(apply_corrections,config_dict) for a in actors])
    time_export_end = time.perf_counter()
    print("{} Export Time: {} sec.".format(images[0],time_export_end - time_export_start))
    ray.shutdown()
    time_end = time.perf_counter()
    print("Total Time: {} sec.".format(time_end - time_start))
def export_coeffs(hy_obj, export_dict):
    '''Export topo/BRDF correction coefficients to per-image JSON files.

    Corrections without a coefficient dictionary ('unsmooth', 'glint')
    are skipped; 'topo' dumps ``hy_obj.topo`` and everything else dumps
    ``hy_obj.brdf``.
    '''
    for correction in hy_obj.corrections:
        # Skip corrections that have no coefficients to export. Glint is
        # skipped here too, consistent with the other pipeline scripts,
        # instead of mislabeling BRDF coefficients as glint output.
        if correction in ('unsmooth', 'glint'):
            continue
        coeff_file = export_dict['output_dir']
        coeff_file += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        coeff_file += "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])
        corr_dict = hy_obj.topo if correction == 'topo' else hy_obj.brdf
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj,config_dict):
    '''Apply correction to image and export to file.

    Writes the corrected reflectance image (all wavelengths or the
    configured subset) to ENVI or NetCDF, optionally warped through a
    GLT (orthorectified geometry), and optionally exports the applied
    correction masks as an additional file.

    Parameters
    ----------
    hy_obj : HyTools file object with corrections already attached.
    config_dict : dict
        Full configuration dictionary; the 'export', 'resample',
        'file_type' and per-correction sections are read here.
    '''
    # Orthorectified (GLT-warped) output is opt-in via the export config.
    use_glt_output_bool=False
    if 'use_glt' in config_dict['export']:
        use_glt_output_bool = config_dict['export']['use_glt']
        if use_glt_output_bool==True:
            header_dict = hy_obj.get_header(warp_glt=True)
        else:
            header_dict = hy_obj.get_header()
    else:
        header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 4  # 4 = 32-bit float in ENVI headers
    # Output format defaults to ENVI; NetCDF is the only other option.
    if 'image_format' in config_dict['export']:
        outformat = config_dict['export']['image_format'].lower()
        if not outformat in ["envi", "netcdf"]:
            print("Output image format is neither 'ENVI' nor 'NetCDF', default output format is set to 'ENVI'.")
            outformat="envi"
    else:
        outformat="envi"
    # Output path: <output_dir><input basename>_<suffix>[.nc]
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    if outformat=='envi':
        output_name += "_%s" % config_dict['export']["suffix"]
    elif outformat=='netcdf':
        output_name += "_%s.nc" % config_dict['export']["suffix"]
    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        # When resampling, the output wavelength grid comes from the
        # resampler config instead of the source image.
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = hy_obj.fwhm
        header_dict['file_type'] = config_dict['file_type']
        if outformat=='envi':
            writer = WriteENVI(output_name,header_dict)
        elif outformat=='netcdf':
            if hy_obj.file_type=='emit':
                print("EMIT Full export is not supported yet.")
                return
            header_dict['lines_glt'] = hy_obj.lines_glt
            header_dict['samples_glt'] = hy_obj.columns_glt
            writer = WriteNetCDF(output_name,header_dict,
                                 type_tag="reflectance",
                                 attr_dict=config_dict["outside_metadata"],
                                 glt_bool=use_glt_output_bool)
        # BUGFIX: the original built a line iterator here and then
        # immediately rebuilt an iterator in the if/else below; the
        # redundant construction has been removed.
        if outformat=='netcdf' and hy_obj.file_type=='emit':
            iterator = hy_obj.iterate(by = 'band', corrections = hy_obj.corrections,
                                      resample = config_dict['resample'])
        else:
            iterator = hy_obj.iterate(by = 'line', corrections = hy_obj.corrections,
                                      resample = config_dict['resample'])
        if use_glt_output_bool==False:
            # Raw (un-warped) geometry: stream the image line by line.
            while not iterator.complete:
                line = iterator.read_next()
                writer.write_line(line,iterator.current_line)
            if outformat=='netcdf' and hy_obj.file_type=='emit':
                writer.write_glt_dataset(hy_obj.glt_x,hy_obj.glt_y,dim_x_name="ortho_x",dim_y_name="ortho_y")
        else:
            # GLT output: write band by band, skipping flagged bad bands.
            for b,band_num in enumerate(range(header_dict['bands'])):
                if hy_obj.bad_bands[b]==True:
                    continue
                band = hy_obj.get_band(band_num,
                                       corrections = hy_obj.corrections)
                if outformat=='envi':
                    writer.write_band_glt(band,b, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
                elif outformat=='netcdf' and hy_obj.file_type=='emit':
                    writer.write_netcdf_band_glt(band,b, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to band indices, then report the
        # actual wavelength of each selected band.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = [hy_obj.fwhm[x] for x in bands]
        header_dict['file_type'] = config_dict['file_type']
        if outformat=='envi':
            writer = WriteENVI(output_name,header_dict)
        elif outformat=='netcdf':
            header_dict['lines_glt'] = hy_obj.lines_glt
            header_dict['samples_glt'] = hy_obj.columns_glt
            writer = WriteNetCDF(output_name,header_dict,
                                 type_tag="reflectance",
                                 attr_dict=config_dict["outside_metadata"],
                                 glt_bool=use_glt_output_bool)
        if use_glt_output_bool==False:
            for b,band_num in enumerate(bands):
                band = hy_obj.get_band(band_num,
                                       corrections = hy_obj.corrections)
                writer.write_band(band, b)
            if outformat=='netcdf' and hy_obj.file_type=='emit':
                writer.write_glt_dataset(hy_obj.glt_x,hy_obj.glt_y,dim_x_name="ortho_x",dim_y_name="ortho_y")
        else:
            for b,band_num in enumerate(bands):
                band = hy_obj.get_band(band_num,
                                       corrections = hy_obj.corrections)
                if outformat=='envi':
                    writer.write_band_glt(band,b, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
                elif outformat=='netcdf':
                    writer.write_netcdf_band_glt(band,b, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
        writer.close()
    #Export masks
    # does not work for precomputed json coeffs
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        for correction in config_dict["corrections"]:
            if correction=='unsmooth':
                # Unsmoothing defines no masks.
                continue
            # For precomputed coefficients, recover the mask definitions
            # from the saved coefficient JSON.
            if config_dict[correction]["type"]=="precomputed":
                with open(config_dict[correction]['coeff_files'][hy_obj.file_name], 'r') as outfile:
                    tmp_dict = json.load(outfile)
                config_dict[correction]['apply_mask'] = tmp_dict['apply_mask']
            for mask_type in config_dict[correction]['apply_mask']:
                # Band names embed the mask parameters so repeated mask
                # types (e.g. two NDI masks) stay distinguishable.
                if mask_type[0]=='ndi':
                    b1_tag = mask_type[1]["band_1"]
                    b2_tag = mask_type[1]["band_2"]
                    mask_extend_name = f"{mask_type[0]}_{b1_tag}_{b2_tag}"
                    mask_names.append(correction + '_' + mask_extend_name)
                elif mask_type[0]=='ancillary':
                    name_tag = mask_type[1]["name"]
                    mask_extend_name = f"{mask_type[0]}_{name_tag}"
                    mask_names.append(correction + '_' + mask_extend_name)
                else:
                    mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # Reuse the image header, switched to byte masks (255 = no data).
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        header_dict['file_type'] = config_dict['file_type']
        if use_glt_output_bool==False:
            header_dict['samples'] = hy_obj.columns
            header_dict['lines'] = hy_obj.lines
        else:
            header_dict['samples'] = hy_obj.columns_glt
            header_dict['lines'] = hy_obj.lines_glt
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        if outformat=='envi':
            output_name += "_%s_mask" % config_dict['export']["suffix"]
            writer = WriteENVI(output_name,header_dict)
            if use_glt_output_bool==False:
                for band_num,mask in enumerate(masks):
                    mask =mask.astype(int)
                    mask[~hy_obj.mask['no_data']] = 255
                    writer.write_band(mask,band_num)
            else:
                for band_num,mask in enumerate(masks):
                    mask =mask.astype(int)
                    mask[~hy_obj.mask['no_data']] = 255
                    writer.write_band_glt(mask,band_num, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
            del masks
        elif outformat=='netcdf':
            output_name += "_%s.nc" % config_dict['export']["suffix"]
            # NOTE(review): a new WriteNetCDF is opened and closed per
            # mask band — confirm this appends bands to the same .nc file.
            for band_num,mask in enumerate(masks):
                mask =mask.astype(int)
                mask[~hy_obj.mask['no_data']] = 255
                writer = WriteNetCDF(output_name,header_dict,
                                     type_tag="mask",
                                     attr_dict=None,
                                     glt_bool=use_glt_output_bool,
                                     band_name = mask_names[band_num])
                if use_glt_output_bool==False:
                    writer.write_mask_band(mask)
                else:
                    writer.write_mask_band_glt(mask, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
                writer.close()
# Allow this module to be run directly as a command-line script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,210 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import time
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.masks import mask_create
from hytools.misc import update_topo_group
import logging
logging.basicConfig(format='%(asctime)s %(message)s',level=logging.INFO)
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    '''Command-line entry point.

    Reads a JSON configuration file (path given as the single
    command-line argument), applies TOPO/BRDF corrections to each input
    image in parallel with ray, and exports coefficients, corrected
    imagery and masks as requested by the config.
    '''
    time_start = time.perf_counter()
    config_file = sys.argv[1]
    with open(config_file, 'r') as infile:
        config_dict = json.load(infile)
    images= config_dict["input_files"]
    if ray.is_initialized():
        ray.shutdown()
    logging.info("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    # Optionally split the flightlines into TOPO subgroups so that
    # coefficients can be fit per group rather than per image.
    # BUGFIX: .get() guards against configs without a 'topo' section,
    # which previously raised KeyError here.
    if "subgrouped" in config_dict.get("topo", {}):
        if config_dict["topo"]["subgrouped"]:
            subgroup_list, group_tag_list = update_topo_group(config_dict["topo"]["subgroup"])
            actor_subgroup = []
            for file_name_list in subgroup_list:
                actor_subgroup+=[[HyTools.remote() for image in file_name_list]]
            actors = []
            for actor_list in actor_subgroup:
                actors+= actor_list
        else:
            actors = [HyTools.remote() for image in images]
            actor_subgroup = None
            group_tag_list=None
    else:
        actors = [HyTools.remote() for image in images]
        actor_subgroup = None
        group_tag_list=None
    # ENVI-style formats carry ancillary observables in separate files.
    if config_dict['file_type'] == 'envi' or config_dict['file_type'] == 'ncav' or config_dict['file_type'] == 'emit':
        anc_files = config_dict["anc_files"]
        _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                     anc_files[image]) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    #Here is where the outlier detection should probably happen.
    _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
    for correction in config_dict["corrections"]:
        if correction =='topo':
            time_topo_start = time.perf_counter()
            calc_topo_coeffs(actors,config_dict['topo'],actor_group_list=actor_subgroup,group_tag_list=group_tag_list)
            time_topo_end = time.perf_counter()
            logging.info("TOPO Time: {} sec.".format(time_topo_end - time_topo_start))
        elif correction == 'brdf':
            time_brdf_start = time.perf_counter()
            calc_brdf_coeffs(actors,config_dict)
            time_brdf_end = time.perf_counter()
            logging.info("BRDF Time: {} sec.".format(time_brdf_end - time_brdf_start))
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        logging.info("Exporting correction coefficients.")
        _ = ray.get([a.do.remote(export_coeffs,config_dict['export']) for a in actors])
    time_export_start = time.perf_counter()
    if config_dict['export']['image']:
        logging.info("Exporting corrected images.")
        _ = ray.get([a.do.remote(apply_corrections,config_dict) for a in actors])
    time_export_end = time.perf_counter()
    logging.info("Export Time: {} sec.".format(time_export_end - time_export_start))
    ray.shutdown()
    time_end = time.perf_counter()
    logging.info("Total Time: {} sec.".format(time_end - time_start))
def export_coeffs(hy_obj,export_dict):
    '''Export correction coefficients to file.

    One JSON file is written per active correction; 'unsmooth' carries
    no coefficients and is skipped. Output filename pattern:
    <output_dir><input basename>_<correction>_coeffs_<suffix>.json
    '''
    out_dir = export_dict['output_dir']
    suffix = export_dict["suffix"]
    stem = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    for correction in hy_obj.corrections:
        if correction == 'unsmooth':
            continue
        # 'topo' has its own coefficient store; everything else uses brdf.
        if correction == 'topo':
            corr_dict = hy_obj.topo
        else:
            corr_dict = hy_obj.brdf
        coeff_file = out_dir + stem + "_%s_coeffs_%s.json" % (correction, suffix)
        with open(coeff_file, 'w') as outfile:
            json.dump(corr_dict, outfile)
def apply_corrections(hy_obj,config_dict):
    '''Apply correction to image and export
    to file.

    Writes the corrected image (ENVI format) for either all wavelengths
    or the configured subset, then optionally writes the applied
    correction masks as a separate single-file ENVI stack.

    Parameters
    ----------
    hy_obj : HyTools file object (corrections already attached).
    config_dict : dict
        Full configuration dictionary; the 'export', 'resample' and
        per-correction sections are read here.
    '''
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 4  # 4 = 32-bit float in ENVI headers
    # Output path: <output_dir><input basename>_<suffix>
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]
    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        # When resampling, the output wavelength grid comes from the
        # resampler config instead of the source image.
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = hy_obj.fwhm
        writer = WriteENVI(output_name,header_dict)
        # Stream the corrected image line by line.
        iterator = hy_obj.iterate(by = 'line', corrections = hy_obj.corrections,
                        resample = config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line,iterator.current_line)
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to band indices, then report the
        # actual wavelength of each selected band.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = [hy_obj.fwhm[x] for x in bands]
        writer = WriteENVI(output_name,header_dict)
        for b,band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections = hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()
    #Export masks
    # does not work for precomputed json coeffs
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        for correction in config_dict["corrections"]:
            if correction=='unsmooth':
                # Unsmoothing defines no masks.
                continue
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # Reuse the image header, switched to byte masks (255 = no data).
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name,header_dict)
        for band_num,mask in enumerate(masks):
            mask =mask.astype(int)
            # Flag no-data pixels with the declared ignore value.
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask,band_num)
        del masks
# Allow this module to be run directly as a command-line script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,225 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import time
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs
from hytools.brdf import calc_brdf_coeffs
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    '''Command-line entry point.

    Reads a JSON configuration file (path given as the single
    command-line argument), applies TOPO/BRDF corrections to each input
    image in parallel with ray, and exports coefficients and corrected
    imagery as requested by the config.
    '''
    time_start = time.perf_counter()
    config_file = sys.argv[1]
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    images= config_dict["input_files"]
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    # One remote HyTools actor per input image.
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for image in images]
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                    anc_files[image]) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'emit' or config_dict['file_type'] == 'ncav':
        # EMIT/NCAV may carry an external GLT for orthorectification.
        anc_files = config_dict["anc_files"]
        if bool(config_dict["glt_files"]):
            glt_files = config_dict["glt_files"]
            _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                        anc_path=anc_files[image],glt_path=glt_files[image]) for a,image in zip(actors,images)])
        else:
            _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                        anc_path=anc_files[image]) for a,image in zip(actors,images)])
    _ = ray.get([a.create_bad_bands.remote(config_dict['bad_bands']) for a in actors])
    for correction in config_dict["corrections"]:
        if correction =='topo':
            time_topo_start = time.perf_counter()
            calc_topo_coeffs(actors,config_dict['topo'])
            time_topo_end = time.perf_counter()
            print("TOPO Time: {} sec.".format(time_topo_end - time_topo_start))
        elif correction == 'brdf':
            time_brdf_start = time.perf_counter()
            calc_brdf_coeffs(actors,config_dict)
            time_brdf_end = time.perf_counter()
            print("BRDF Time: {} sec.".format(time_brdf_end - time_brdf_start))
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients.")
        _ = ray.get([a.do.remote(export_coeffs,config_dict['export']) for a in actors])
    time_export_start = time.perf_counter()
    if config_dict['export']['image']:
        print("Exporting corrected images.")
        _ = ray.get([a.do.remote(apply_corrections,config_dict) for a in actors])
    time_export_end = time.perf_counter()
    print("{} Export Time: {} sec.".format(images[0],time_export_end - time_export_start))
    ray.shutdown()
    time_end = time.perf_counter()
    print("Total Time: {} sec.".format(time_end - time_start))
def export_coeffs(hy_obj,export_dict):
    '''Export correction coefficients to file.

    Writes hy_obj.topo (for the 'topo' correction) or hy_obj.brdf (for
    any other correction except 'unsmooth', which is skipped) to a
    per-correction JSON file in the export output directory.
    '''
    base_no_ext = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    for correction in hy_obj.corrections:
        if correction == 'unsmooth':
            continue  # nothing to export for unsmoothing
        coeff_dict = getattr(hy_obj, 'topo' if correction == 'topo' else 'brdf')
        coeff_file = "".join([export_dict['output_dir'], base_no_ext,
                              "_%s_coeffs_%s.json" % (correction, export_dict["suffix"])])
        with open(coeff_file, 'w') as handle:
            json.dump(coeff_dict, handle)
def apply_corrections(hy_obj,config_dict):
    '''Apply correction to image and export
    to file.

    ENVI-only export with optional GLT (orthorectified) output
    geometry; also optionally exports the applied correction masks as
    a separate ENVI stack.

    Parameters
    ----------
    hy_obj : HyTools file object (corrections already attached).
    config_dict : dict
        Full configuration dictionary; the 'export', 'resample' and
        per-correction sections are read here.
    '''
    # Orthorectified (GLT-warped) output is opt-in via the export config.
    use_glt_output_bool=False
    if 'use_glt' in config_dict['export']:
        use_glt_output_bool = config_dict['export']['use_glt']
        if use_glt_output_bool==True:
            header_dict = hy_obj.get_header(warp_glt=True)
        else:
            header_dict = hy_obj.get_header()
    else:
        header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 4  # 4 = 32-bit float in ENVI headers
    # Output path: <output_dir><input basename>_<suffix>
    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]
    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        # When resampling, the output wavelength grid comes from the
        # resampler config instead of the source image.
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = hy_obj.fwhm
        writer = WriteENVI(output_name,header_dict)
        iterator = hy_obj.iterate(by = 'line', corrections = hy_obj.corrections,
                        resample = config_dict['resample'])
        if use_glt_output_bool==False:
            # Raw geometry: stream the corrected image line by line.
            while not iterator.complete:
                line = iterator.read_next()
                writer.write_line(line,iterator.current_line)
        else:
            # GLT geometry: scatter each source line into the warped grid.
            while not iterator.complete:
                line = iterator.read_next()
                # Warped-grid positions that map to this source line
                # (GLT indices are one-based).
                line_ind=np.where(hy_obj.glt_y==iterator.current_line + 1)
                redundant_line = line[hy_obj.glt_x[line_ind]-1]
                # NOTE(review): only the first two columns of the warped
                # line are written here ([:,:2]) — looks suspicious;
                # confirm against write_line_glt's expected input.
                writer.write_line_glt(redundant_line[:,:2],line_ind[0],line_ind[1])
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        # Map requested wavelengths to band indices, then report the
        # actual wavelength of each selected band.
        bands = [hy_obj.wave_to_band(x) for x in waves]
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        header_dict['fwhm'] = [hy_obj.fwhm[x] for x in bands]
        writer = WriteENVI(output_name,header_dict)
        if use_glt_output_bool==False:
            for b,band_num in enumerate(bands):
                band = hy_obj.get_band(band_num,
                                       corrections = hy_obj.corrections)
                writer.write_band(band, b)
        else:
            for b,band_num in enumerate(bands):
                band = hy_obj.get_band(band_num,
                                       corrections = hy_obj.corrections)
                writer.write_band_glt(band,b, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
        writer.close()
    #Export masks
    # does not work for precomputed json coeffs
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        for correction in config_dict["corrections"]:
            if correction=='glint':
                # Glint correction has no associated masks here.
                continue
            for mask_type in config_dict[correction]['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        # Reuse the image header, switched to byte masks (255 = no data).
        header_dict['data type'] = 1
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        # NOTE(review): samples/lines are set to the RAW grid even when
        # use_glt output is enabled below — confirm the mask header is
        # correct for GLT output.
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name,header_dict)
        if use_glt_output_bool==False:
            for band_num,mask in enumerate(masks):
                mask =mask.astype(int)
                # Flag no-data pixels with the declared ignore value.
                mask[~hy_obj.mask['no_data']] = 255
                writer.write_band(mask,band_num)
        else:
            for band_num,mask in enumerate(masks):
                mask =mask.astype(int)
                mask[~hy_obj.mask['no_data']] = 255
                writer.write_band_glt(mask,band_num, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
        del masks
# Allow this module to be run directly as a command-line script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,413 @@
import os, sys
import numpy as np
import argparse
import pandas as pd
try:
from osgeo import gdal, osr, ogr
has_gdal=True
except ModuleNotFoundError:
print("No gdal installed")
has_gdal=False
import hytools as ht
from hytools.misc.point import local_point2spec, subset_band_list #*
#warnings.filterwarnings("ignore")
#np.seterr(divide='ignore', invalid='ignore')
def obs_point2spec(hyObj, img_row, img_col):
    '''Sample the observation/ancillary layers at the given pixel
    locations and return them as a dataframe (one row per point).

    NOTE(review): the 'aspect' layer is stored under the column name
    'azimuth' — preserved as-is, confirm downstream consumers expect it.
    '''
    layers = ['sensor_az', 'sensor_zn', 'solar_az', 'solar_zn', 'slope', 'aspect']
    sampled = [hyObj.get_anc(layer)[img_row, img_col] for layer in layers]
    obs_data = np.vstack([img_row, img_col] + sampled)
    columns = ['img_row', 'img_col', 'sensor_az', 'sensor_zn',
               'solar_az', 'solar_zn', 'slope', 'azimuth']
    return pd.DataFrame(obs_data.T, columns=columns)
def rasterize_polygon(hyObj,polygon_fn, key_id,use_glt_bool):
    ''' Rasterize polygon based on image georeference and boundary

    Burns each feature of the vector file into an in-memory raster that
    matches the image grid (raw or GLT-warped), assigning each feature a
    sequential integer DN, and returns a lookup from that DN to the
    user-specified unique-ID attribute.

    Parameters
    ----------
    hyObj : HyTools file object providing grid size/transform/projection.
    polygon_fn : str
        Vector filename (anything OGR can open).
    key_id : str
        Attribute-field name holding each polygon's unique ID.
    use_glt_bool : bool
        Rasterize on the GLT-warped grid instead of the raw image grid.

    Returns
    -------
    (target_ds, lookup_dict) : (gdal.Dataset, dict) or (None, None)
        In-memory single-band raster and DN -> unique-ID lookup;
        (None, None) when key_id is not a field of the vector layer.
    '''
    source_ds = ogr.Open(polygon_fn)
    source_layer = source_ds.GetLayer()
    field_list = []
    ldefn = source_layer.GetLayerDefn()
    for n in range(ldefn.GetFieldCount()):
        fdefn = ldefn.GetFieldDefn(n)
        field_list.append(fdefn.name)
    # BUGFIX: removed a leftover debug print of field_list[1], which
    # raised IndexError for layers with fewer than two fields.
    if not (key_id in field_list):
        print('Field "',key_id,'" is not in the shapefile!')
        return (None, None)
    # Copy the layer into memory so a temporary FID field can be added
    # without touching the source file.
    tmp_mem_driver=ogr.GetDriverByName('MEMORY')
    dest = tmp_mem_driver.CreateDataSource('tempData')
    mem_lyr = dest.CopyLayer(source_layer,'newlayer',['OVERWRITE=YES'])
    FeatureCount= mem_lyr.GetFeatureCount()
    # Add a new field
    new_field = ogr.FieldDefn('tempFID', ogr.OFTInteger)
    mem_lyr.CreateField(new_field)
    lookup_dict={}
    for i, feature in enumerate(mem_lyr):
        feature.SetField('tempFID', i+1) # key step1
        lookup_dict[str(i+1)]=feature.GetField(key_id)
        mem_lyr.SetFeature(feature) # key step 2
    if use_glt_bool:
        out_col = hyObj.columns_glt
        out_row = hyObj.lines_glt
        out_transform = hyObj.glt_transform
        out_proj = hyObj.glt_projection
    else:
        out_col = hyObj.columns
        out_row = hyObj.lines
        out_transform = hyObj.transform
        out_proj = hyObj.projection
    # Pick the narrowest raster type that can hold FeatureCount DNs.
    if FeatureCount<255:
        target_ds = gdal.GetDriverByName('MEM').Create('', out_col, out_row, 1, gdal.GDT_Byte)
        nodata_val=255
    else:
        if (FeatureCount>255 and FeatureCount < 32767):
            target_ds = gdal.GetDriverByName('MEM').Create('', out_col, out_row, 1, gdal.GDT_Int16)
            nodata_val=-9999
        else: # >32767
            target_ds = gdal.GetDriverByName('MEM').Create('', out_col, out_row, 1, gdal.GDT_Int32)
            nodata_val=-9999
    target_ds.SetGeoTransform(out_transform)
    target_ds.SetProjection(out_proj)
    band = target_ds.GetRasterBand(1)
    band.SetNoDataValue(nodata_val)
    gdal.RasterizeLayer(target_ds, [1], mem_lyr, options=["ATTRIBUTE=tempFID" ,"ALL_TOUCHED=FALSE"])
    return (target_ds,lookup_dict)
def gen_df_polygon(hyObj, target_ds, lookup_dict, imgsrs2latlon, uid, use_glt_bool):
    '''Generate a dataframe that stores the location, UID of all the points within the polygons

    Parameters
    ----------
    hyObj: HyTools file object
    target_ds: GDAL raster dataset
        one band raster in which each polygon has unique digital number
    lookup_dict: dictionary
        a dictionary linking polygon DN in raster (target_ds) and the UID in the polygons attribute table
    imgsrs2latlon: coordinate transformation object
        transform from georeferenced coordinates of the image to LAT LON
    uid: str
        the user specified unique polygon ID name from the attribute table of the shapefile
    use_glt_bool: bool
        whether pixel indices refer to the GLT-warped grid

    Returns
    -------
    return_df: pandas dataframe
        a dataframe that stores the location, UID of all the points within the polygons,
        or None when the polygons do not intersect the image
    '''
    poly_raster=target_ds.GetRasterBand(1).ReadAsArray()
    data_type=target_ds.GetRasterBand(1).DataType
    # Valid DNs exclude each dtype's nodata value (255 / 32767).
    if data_type == gdal.GDT_Byte:
        ind=np.where((poly_raster>0) & (poly_raster<255) )
    else:
        if data_type == gdal.GDT_Int16:
            ind=np.where((poly_raster>0) & (poly_raster<32767) )
        else:
            ind=np.where(poly_raster>0 )
    total_point=len(ind[1])
    print(total_point,' points')
    if total_point==0:
        # polygons are not intersecting the image
        print( "No intersection.")
        return None
    columns = ['new_uid',uid,'img_col_glt','img_row_glt','img_col_raw','img_row_raw','lon','lat']
    if use_glt_bool:
        ul_x, new_x_resolution, new_x_rot, ul_y, new_y_rot, new_y_resolution = hyObj.glt_transform
    else:
        ul_x, new_x_resolution, new_x_rot, ul_y, new_y_rot, new_y_resolution = hyObj.transform
    # Per-polygon running counter used to build point-level unique IDs.
    sub_id_dict = {key: 0 for key in lookup_dict.keys()}
    # PERFORMANCE FIX: accumulate rows in a list and build the dataframe
    # once at the end, instead of pd.concat inside the loop (quadratic).
    rows = []
    # add polygon ID, and point order number within the same polygon
    for index in range(total_point):
        row=ind[0][index]
        col=ind[1][index]
        if use_glt_bool:
            # GLT indices are one-based; convert to zero-based raw indices.
            row_post_glt = hyObj.glt_y[row,col] - 1
            col_post_glt = hyObj.glt_x[row,col] - 1 # zero-based
        else:
            row_post_glt = row
            col_post_glt = col
        poly_id=poly_raster[row,col]
        poly_id_code=lookup_dict[str(poly_id)]
        # Pixel-center map coordinates, then reproject to WGS84 lat/lon.
        x_coord = ul_x + (col+0.5)*new_x_resolution + (row+0.5)*new_x_rot
        y_coord = ul_y + (col+0.5)*new_y_rot + (row+0.5)*new_y_resolution
        lat, lon, _ = imgsrs2latlon.TransformPoint(x_coord, y_coord)
        sub_id = sub_id_dict[str(poly_id)]
        sub_id_dict[str(poly_id)]+=1
        rows.append(['{}_{}'.format(poly_id_code, sub_id),poly_id_code, col,row,col_post_glt,row_post_glt, lon,lat])
    return_df = pd.DataFrame(rows, columns=columns)
    return return_df
def local_polygon2spec(hyObj, poly_shp, uid, use_band_list=False, band_list=[],use_glt_bool=False):
    """Extract spectra with points within the boundary of polygons from the hyperspectral image
    Steps:
    1, Rasterize polygon based on image georeference
    2, Get locations of the points of interest from the raster
    3, Overlapping points and the hyperspectral image, and extract spectra
    Parameters
    ----------
    hyObj : HyTools file object
    poly_shp: str
        full filename of the polygon shapefile
    uid: str
        the user specified unique polygon ID name from the attribute table of the shapefile
    use_band_list: boolean
        default False; whether to use a subset of bands
    band_list: list or numpy array
        default is a blank list
        if it is a list, it should be one like [5,6,7,8,9, 12]
        if it is a numpy array, it should be the same size as hyObj.bad_bands with only True or False in the array
    use_glt_bool: boolean
        default False;
    Returns
    -------
    point_df: pandas dataframe
        it includes all the location and spectra information for all points within the polygons,
        or None if uid is missing from the shapefile or no polygon intersects the image
    """
    if use_glt_bool:
        img_srs = osr.SpatialReference(wkt=hyObj.glt_projection)
    else:
        img_srs = osr.SpatialReference(wkt=hyObj.projection)
    latlon_wgs84 = osr.SpatialReference()
    latlon_wgs84.ImportFromEPSG ( 4326 )
    # LAT LON will be the only georeferenced coordinates kept in the result
    imgsrs2latlon = osr.CoordinateTransformation (img_srs, latlon_wgs84)
    # convert polygon geometry into raster with the same size of the image, and store UID in a lookup dictionary
    target_ds, lookup_dict=rasterize_polygon(hyObj,poly_shp,uid,use_glt_bool)
    if target_ds is None:
        return None
    # generate a dataframe that stores the location, UID of all the points within the polygons
    point_df = gen_df_polygon(hyObj, target_ds, lookup_dict, imgsrs2latlon, uid, use_glt_bool)
    if point_df is None :
        return None
    # extract full spectra information from image based on points locations
    # NOTE(review): indices are cast to int16 — overflows for images
    # larger than 32767 pixels in either dimension; confirm image sizes.
    spec_data = hyObj.get_pixels(point_df['img_row_raw'].values.astype(np.int16),point_df['img_col_raw'].values.astype(np.int16))
    # determine the column names of the spectra dataframe based on wavelengths
    if hyObj.wavelength_units.lower()[:4]=='micr':
        new_band_name = ['B{:0.3f}'.format(x) for x in hyObj.wavelengths]
    elif hyObj.wavelength_units.lower()[:4]=='nano' :
        new_band_name = ['B{:04d}'.format(int(x)) for x in hyObj.wavelengths]
    else:
        # Unknown units: fall back to one-based band numbering.
        new_band_name = ['B{:d}'.format(x+1) for x in range(hyObj.bands)]
    # NOTE(review): both branches below are identical — the 'ncav'
    # special case appears to be a leftover; confirm intent.
    if hyObj.file_type in ['ncav']:
        spec_df = pd.DataFrame(spec_data, columns=new_band_name)
    else:
        spec_df = pd.DataFrame(spec_data, columns=new_band_name)
    # perform the subsetting of the columns in the dataframe according to the band_list or hyObj.bad_bands
    spec_df = subset_band_list(hyObj,spec_df,use_band_list, band_list)
    # merge location information and spectra information
    point_df = pd.concat([point_df,spec_df], axis=1, join='inner')
    return point_df
def main():
    '''Command-line tool: extract spectra (and optionally observation
    geometry) at point or polygon locations from a hyperspectral image
    and write them to CSV.

    Inputs are a CSV of point coordinates or a polygon vector file; the
    image may be ENVI, EMIT or NCAV, optionally with ancillary (OBS)
    and external GLT files.
    '''
    parser = argparse.ArgumentParser(description='Export fractional cover image by EndMember csv')
    parser.add_argument('-i', type=str, required=True,help='Input image file name')
    parser.add_argument('-pnt', type=str, required=True,help='CSV filename or shapefile')
    parser.add_argument('-od', type=str, required=True,help='Output folder')
    parser.add_argument('-uid', type=str, required=True,help='Unique ID in the vector file')
    parser.add_argument('-epsg', type=str, required=False, help='UTM EPSG code')
    parser.add_argument('-anc', type=str, required=False, help='Ancillary file / OBS file')
    parser.add_argument('-glt', type=str, default=None, required=False, help='External GLT ENVI file')
    parser.add_argument('-dt', type=str, default='envi', required=False, help="Data type of the image (default 'envi') ['envi','emit','ncav']", choices=['envi','emit','ncav'])
    parser.add_argument('-nnb', type=int, required=False, default=4,help='How many neighbors in the image should be sampled from the center', choices=[0,4,8])
    args = parser.parse_args()
    in_image_file = args.i
    out_path = args.od
    pnt_file = args.pnt
    uid = args.uid
    epsg_code = args.epsg
    n_neighbor_chose = args.nnb
    file_format = args.dt
    # External GLT: band 1 holds the sample (x) map, band 0 the line (y).
    if not args.glt is None:
        glt_dict = {
            "glt_x": [args.glt,1],
            "glt_y": [args.glt,0]
        }
    else:
        glt_dict = {}
    # Ancillary (OBS) band order follows the AVIRIS observables layout.
    if args.anc:
        anc_dict = {
            "path_length": [args.anc, 0],
            "sensor_az": [args.anc, 1],
            "sensor_zn": [args.anc, 2],
            "solar_az": [args.anc, 3],
            "solar_zn": [args.anc, 4],
            "phase": [args.anc, 5],
            "slope": [args.anc, 6],
            "aspect": [args.anc, 7],
            "cosine_i": [args.anc, 8],
            "utc_time": [args.anc, 9]
        }
    else:
        anc_dict = None
    if pnt_file.endswith('csv'):
        if args.epsg is None:
            epsg_code = None
        else:
            epsg_code = args.epsg
    hy_obj = ht.HyTools()
    hy_obj.read_file(in_image_file,file_format, glt_path=glt_dict, anc_path =anc_dict)
    # Sample through the GLT for EMIT, or whenever a GLT was supplied.
    lookup_glt_bool = False
    if file_format=='emit':
        lookup_glt_bool=True
    else:
        if glt_dict:
            lookup_glt_bool=True
    # Vector inputs need GDAL/OGR; otherwise fall back to CSV points.
    if has_gdal and (pnt_file.endswith('.shp') or pnt_file.endswith('.geojson') or pnt_file.endswith('.json')):
        out_df = local_polygon2spec(hy_obj, pnt_file, uid, use_band_list=False, band_list=[],use_glt_bool=lookup_glt_bool)
    else:
        if not pnt_file.endswith('.csv'):
            print("Point location file is not in CSV format")
            return
        pnt_df = pd.read_csv(pnt_file)
        # Prefer projected x/y columns; fall back to lat/lon columns.
        if 'x_coord' in pnt_df.columns and 'y_coord' in pnt_df.columns:
            out_df = local_point2spec(hy_obj, pnt_file, uid, 'x_coord', 'y_coord', epsg_code, n_neighbor=n_neighbor_chose, use_band_list=False, band_list=[],use_glt_bool=lookup_glt_bool)
        elif 'lat' in pnt_df.columns and 'lon' in pnt_df.columns:
            print('latlon inside point file.')
            out_df = local_point2spec(hy_obj, pnt_file, uid, 'lon', 'lat', epsg_code, n_neighbor=n_neighbor_chose, use_band_list=False, band_list=[],use_glt_bool=lookup_glt_bool)
        else:
            print("Unknown coordinates column names")
            return
    if out_df is None:
        return
    # First underscore-separated token of the basename is the flightline tag.
    img_base_name=os.path.basename(in_image_file).split('.')[0]
    out_df.insert(loc = 1,
                  column = 'flightline',
                  value = [img_base_name.split('_')[0]]*out_df.shape[0])
    # NOTE(review): out_path is concatenated directly; assumes it ends
    # with a path separator — confirm with callers.
    out_df.to_csv(out_path+img_base_name+"_spec_df_asvc.csv",index=False)
    if not args.anc is None:
        # Also export the observation geometry at the same pixels.
        img_row = out_df['img_row_raw'].values.astype(np.int16)
        img_col = out_df['img_col_raw'].values.astype(np.int16)
        out_obs_df = obs_point2spec(hy_obj, img_row, img_col)
        out_obs_df.insert(loc = 0,
                          column = 'flightline',
                          value = [img_base_name.split('_')[0]]*out_df.shape[0])
        out_obs_df.insert(loc = 0,
                          column = 'new_uid',
                          value = out_df['new_uid'])
        out_obs_df.to_csv(out_path+img_base_name+"_obs_df.csv",index=False)
# Allow this module to be run directly as a command-line script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,89 @@
'''neon2envi.py
TODO: Add phase and UTC time to the ancillary output
TODO: Implement a progress bar like the one in this example:
https://docs.ray.io/en/master/auto_examples/progress_bar.html
'''
import argparse
import os
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import WriteENVI
def main():
    '''This command line tool exports NEON AOP HDF imaging spectroscopy data
    to an ENVI formatted binary file, with the option of also exporting
    ancillary data following formatting used by NASA JPL for AVIRIS
    observables. The script utilizes ray to export images in parallel.
    '''
    parser = argparse.ArgumentParser(description = "Convert NEON AOP H5 to ENVI format")
    parser.add_argument('images',help="Input image pathnames", nargs='*')
    parser.add_argument('output_dir',help="Output directory", type = str)
    parser.add_argument("-anc", help="Output ancillary", required=False, action='store_true')
    args = parser.parse_args()

    # Normalize the output directory so paths can be built by concatenation.
    if not args.output_dir.endswith("/"):
        args.output_dir+="/"

    # One ray actor per input image; restart ray if a session is already live.
    if ray.is_initialized():
        ray.shutdown()
    ray.init(num_cpus = len(args.images))
    hytool = ray.remote(ht.HyTools)
    actors = [hytool.remote() for image in args.images]
    _ = ray.get([a.read_file.remote(image,'neon') for a,image in zip(actors,args.images)])

    def neon_to_envi(hy_obj):
        # Export one NEON H5 image to an ENVI binary file, chunk by chunk.
        basemame = os.path.basename(os.path.splitext(hy_obj.file_name)[0])
        print("Exporting %s " % basemame)
        output_name = args.output_dir+ basemame
        writer = WriteENVI(output_name,hy_obj.get_header())
        iterator = hy_obj.iterate(by = 'chunk')
        pixels_processed = 0
        while not iterator.complete:
            chunk = iterator.read_next()
            pixels_processed += chunk.shape[0]*chunk.shape[1]
            writer.write_chunk(chunk,iterator.current_line,iterator.current_column)
            if iterator.complete:
                writer.close()

    def export_anc(hy_obj):
        # Export a 10-band ancillary image in the AVIRIS observables layout.
        # Bands 5 (phase) and 9 (UTC time) are intentionally left unwritten —
        # see the module-level TODOs.
        anc_header = hy_obj.get_header()
        anc_header['bands'] = 10
        anc_header['band_names'] = ['path length', 'to-sensor azimuth',
                                    'to-sensor zenith','to-sun azimuth',
                                    'to-sun zenith','phase', 'slope',
                                    'aspect', 'cosine i','UTC time']
        anc_header['wavelength units'] = np.nan
        anc_header['wavelength'] = np.nan
        anc_header['data type'] = 4
        output_name = args.output_dir+ os.path.basename(os.path.splitext(hy_obj.file_name)[0])
        writer = WriteENVI(output_name + "_ancillary", anc_header)
        writer.write_band(hy_obj.get_anc("path_length"),0)
        writer.write_band(hy_obj.get_anc("sensor_az",radians = False),1)
        writer.write_band(hy_obj.get_anc("sensor_zn",radians = False),2)
        writer.write_band(hy_obj.get_anc("solar_az",radians = False),3)
        writer.write_band(hy_obj.get_anc("solar_zn",radians = False),4)
        #writer.write_band(hy_obj.get_anc("phase placeholder"),5)
        writer.write_band(hy_obj.get_anc("slope",radians = False),6)
        writer.write_band(hy_obj.get_anc("aspect",radians = False),7)
        writer.write_band(hy_obj.cosine_i(),8)
        #writer.write_band('UTC time placeholder',9)
        writer.close()

    _ = ray.get([a.do.remote(neon_to_envi) for a in actors])
    if args.anc:
        print("\nExporting ancillary data")
        _ = ray.get([a.do.remote(export_anc) for a in actors])
    print("Export complete.")

if __name__== "__main__":
    main()

View File

@ -0,0 +1,140 @@
import json
import os
import warnings
import sys
import numpy as np
import h5py
import hytools as ht
from hytools.io.envi import *
from hytools.brdf import calc_flex_single_post
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    '''Combine pre-BRDF sample H5 files and fit flex BRDF coefficients.

    argv[1]: JSON config file.
    argv[2]: folder holding the *_prebrdf_sample.h5 files.
    argv[3]: 0 to load whole reflectance arrays, 1 to load band by band.
    '''
    config_file = sys.argv[1]
    sample_folder = sys.argv[2]
    load_reflectance_mode = int(sys.argv[3])
    if not load_reflectance_mode in [1,0]:
        print("Please set the mode for loading H5 reflectance (0-Whole;1-By Band)")
        return
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    images = config_dict["input_files"]
    brdf_dict = config_dict['brdf']
    # Keep only images whose sample file actually exists on disk.
    sample_h5_list = []
    for image in images:
        tmp_file_name = f"{sample_folder}/{os.path.splitext(os.path.basename(image))[0]}_prebrdf_sample.h5"
        if os.path.exists(tmp_file_name):
            sample_h5_list+=[tmp_file_name]
    sample_dict = load_sample_h5(sample_h5_list,load_reflectance_mode)
    if isinstance(brdf_dict['solar_zn_type'],str):
        if brdf_dict['solar_zn_type'] == 'scene':
            brdf_dict["solar_zn_norm_radians"]=float(sample_dict['mean_solar_zn'])
            print("Scene average solar zenith angle : %s degrees" % round(np.degrees(brdf_dict["solar_zn_norm_radians"]),3))
    # Fix: JSON numbers without a decimal point parse as int, which the old
    # isinstance(..., float) check rejected. Accept both numeric types.
    elif isinstance(brdf_dict['solar_zn_type'],(int,float)):
        brdf_dict["solar_zn_norm_radians"]=float(brdf_dict['solar_zn_type'])
    else:
        print('Unrecognized solar zenith angle normalization')
        # Fix: bail out instead of fitting without a normalization angle set.
        return
    calc_flex_single_post(sample_dict,brdf_dict,load_reflectance_mode)
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients.")
        export_coeffs_brdf(sample_dict,config_dict['export'],images)
def load_sample_h5(h5_file_list, load_reflectance_mode):
    '''Load information from H5 files, and return a dictionary with all the info needed.
    '''
    refl_parts, kernel_parts = [], []
    zn_values, ndi_parts = [], []
    bad_bands = None  # taken from the first file only
    for file_idx, fname in enumerate(h5_file_list):
        with h5py.File(fname, "r") as h5:
            waves = h5["wavelengths"][()]
            zn = h5["kernels_samples"].attrs['set_solar_zn']
            refl = h5["reflectance_samples"][()]
            kernels = h5["kernels_samples"][()]
            if file_idx == 0:
                bad_bands = h5["bad_bands"][()]
        nir = refl[:, get_wave(850, waves)]
        red = refl[:, get_wave(660, waves)]
        ndi_parts.append((nir - red) / (nir + red))
        if load_reflectance_mode == 0:
            refl_parts.append(refl)
            refl = None
        else:
            # band-by-band mode: keep only the file name as a handle
            refl_parts.append([fname])
        zn_values.append(zn)
        kernel_parts.append(kernels)
    return {
        "kernels_samples": np.concatenate(kernel_parts, axis=0),
        "reflectance_samples": np.concatenate(refl_parts, axis=0),
        "ndi_samples": np.concatenate(ndi_parts),
        "mean_solar_zn": np.array(zn_values).mean(),
        "bad_bands": bad_bands,
    }
def export_coeffs_brdf(data_dict, export_dict, images):
    '''Export correction coefficients to file.
    '''
    for image in images:
        stem = os.path.splitext(os.path.basename(image))[0]
        coeff_file = "%s%s_%s_coeffs_%s_chtc.json" % (
            export_dict['output_dir'], stem, "brdf", export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(data_dict['brdf_dict'], outfile)
def get_wave(wave, wavelengths):
    """Return the band index corresponding to the input wavelength.
    If not an exact match the closest wavelength will be returned.
    Args:
        wave (float): Wavelength in image units.
        wavelengths (list): Wavelength list
    Returns:
        band index 0-based, or None when wave is outside the covered range.
    """
    if not (wavelengths.min() <= wave <= wavelengths.max()):
        print("Input wavelength outside wavelength range!")
        return None
    return np.argmin(np.abs(wavelengths - wave))

if __name__== "__main__":
    main()

View File

@ -0,0 +1,238 @@
import json
import os
import warnings
import sys
import numpy as np
from scipy.optimize import nnls
import h5py
import hytools as ht
from hytools.io.envi import *
from hytools.masks import mask_create
from hytools.topo.c import calc_c
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    # Usage: script <config.json> <sample_folder> <topo_subgroup_id>
    config_file = sys.argv[1]
    sample_folder = sys.argv[2]
    topo_subgroup_id = sys.argv[3]
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    images = []
    if config_dict["topo"]["subgrouped"]:
        topo_dict = config_dict['topo']
        subgroup = topo_dict["subgroup"]
        # Collect the pre-topo sample files belonging to this subgroup.
        sample_h5_list=[]
        for each_img_name in subgroup.keys():
            each_h5_name = f"{sample_folder}/{os.path.splitext(os.path.basename(each_img_name))[0]}_pretopo_sample.h5"
            if not os.path.exists(each_h5_name):
                print(f"File:'{each_h5_name}' is not found, skip this...")
                continue
            if subgroup[each_img_name]==topo_subgroup_id:
                sample_h5_list+=[each_h5_name]
                images+=[each_img_name]
        if len(sample_h5_list)==0:
            print(f"Cannot find subgroup '{topo_subgroup_id}', exit.")
            return
        sample_dict = load_sample_h5(sample_h5_list)
        calc_topo_single_post(sample_dict,topo_dict) #,update topo coeffs
        apply_topo_scsc(sample_dict) # update reflectance
        if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
            print("Exporting correction coefficients.")
            export_coeffs_topo(sample_dict,config_dict['export'],images,sample_h5_list)
        # Write the topo-corrected samples out as pre-BRDF sample files.
        for image_order, imagename in enumerate(images):
            export_h5(imagename,config_dict['export'],sample_dict[sample_h5_list[image_order]])
    else:
        print("No subgroup is defined, exit.")
def load_sample_h5(h5_file_list):
    '''Read each sample H5 into its own sub-dictionary, keyed by file name.'''
    out = {}
    for fname in h5_file_list:
        with h5py.File(fname, "r") as h5:
            waves = h5["wavelengths"][()]
            image_waves = h5["image_wavelengths"][()]
            zn = h5["kernels_samples"].attrs['set_solar_zn']
            refl = h5["reflectance_samples"][()]
            kernels = h5["kernels_samples"][()]
            slopes = h5["slope_samples"][()]
            cos_i = h5["cosine_i_samples"][()]
            bad_bands = h5["bad_bands"][()]
        nir = refl[:, get_wave(850, waves)]
        red = refl[:, get_wave(660, waves)]
        out[fname] = {
            "kernels_samples": kernels,
            "reflectance_samples": refl,
            "ndi_samples": (nir - red) / (nir + red),
            "bad_bands": bad_bands,
            "full_image_wavelist": image_waves,
            "set_solar_zn": zn,
            "wavelist": waves,
            "slope_samples": slopes,
            "cosine_i_samples": cos_i,
            "topo_dict": None,
        }
    return out
def export_coeffs_topo(data_dict, export_dict, images, h5_list):
    '''Export correction coefficients to file.
    '''
    for img_order, image in enumerate(images):
        stem = os.path.splitext(os.path.basename(image))[0]
        coeff_file = "%s%s_%s_coeffs_%s_chtc.json" % (
            export_dict['output_dir'], stem, "topo", export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(data_dict[h5_list[img_order]]['topo_dict'], outfile)
def get_wave(wave,wavelengths):
"""Return the band image corresponding to the input wavelength.
If not an exact match the closest wavelength will be returned.
Args:
wave (float): Wavelength in image units.
wavelengths (list): Wavelength list
Returns:
band index 0-based.
"""
if (wave > wavelengths.max()) | (wave < wavelengths.min()):
print("Input wavelength outside wavelength range!")
band_ind = None
else:
band_ind = np.argmin(np.abs(wavelengths - wave))
return band_ind
def export_h5(imagename, export_dict, obj_dict):
    '''Write topo-corrected sample arrays to a *_prebrdf_sample.h5 file.'''
    stem = os.path.splitext(os.path.basename(imagename))[0]
    out_filename = f"{export_dict['output_dir']}{stem}_prebrdf_sample.h5"
    with h5py.File(out_filename, "w") as h5:
        h5.attrs['Image Name'] = f"{stem}"
        kernels = h5.create_dataset("kernels_samples", data=obj_dict["kernels_samples"])
        h5.create_dataset("reflectance_samples", data=obj_dict["reflectance_samples"])
        h5.create_dataset("wavelengths", data=obj_dict["wavelist"])
        kernels.attrs['set_solar_zn'] = obj_dict['set_solar_zn']
        kernels.attrs['kernels_names'] = '["Volume","Geometry"]'
        kernels.attrs['Solar Zenith Unit'] = "Radians"
        h5.create_dataset("image_wavelengths", data=obj_dict["full_image_wavelist"])
        h5.create_dataset("bad_bands", data=obj_dict["bad_bands"])
    print(f"{out_filename} saved.")
def calc_topo_single_post(sample_dict,topo_dict):
    '''Fit SCS+C topographic coefficients from the pooled samples of every
    file in sample_dict, then attach the (shared) fitted topo_dict to each
    file's sub-dictionary.

    topo_dict['coeffs'] maps original band number -> C coefficient; bands
    flagged in bad_bands are skipped.
    '''
    combine_refl = []
    combine_cos_i = []
    combine_slope = []
    ndi_list = []
    for h5_name in sample_dict.keys():
        sub_dict = sample_dict[h5_name]
        ndi_list+=[sub_dict["ndi_samples"]]
        combine_refl+=[sub_dict["reflectance_samples"]]
        combine_cos_i+=[sub_dict["cosine_i_samples"]]
        combine_slope+=[sub_dict["slope_samples"]]
        bad_bands = sub_dict["bad_bands"]  # assumes identical across files — TODO confirm
    combine_refl=np.concatenate(combine_refl,axis=0)
    combine_cos_i=np.concatenate(combine_cos_i,axis=0)
    combine_slope=np.concatenate(combine_slope,axis=0)
    ndi_list=np.concatenate(ndi_list,axis=0)
    # calc_mask order in the config matters: [ndi, slope, cosine_i].
    mask = np.ones(ndi_list.shape).astype(bool)
    mask &= (ndi_list >= float(topo_dict['calc_mask'][0][1]['min'])) & (ndi_list <= float(topo_dict['calc_mask'][0][1]['max']))
    mask &= (combine_slope >= float(topo_dict['calc_mask'][1][1]['min'])) & (combine_slope <= float(topo_dict['calc_mask'][1][1]['max']))
    mask &= (combine_cos_i >= float(topo_dict['calc_mask'][2][1]['min'])) & (combine_cos_i <= float(topo_dict['calc_mask'][2][1]['max']))
    feasible_sample_count=np.count_nonzero(mask)
    if feasible_sample_count>10:
        used_reflectance_samples = combine_refl[mask==1,:]
        used_cos_i = combine_cos_i[mask==1]
        topo_dict['coeffs'] = {}
        # Sample columns contain only the good bands, so a separate cursor
        # walks them while band_num keeps the original band numbering.
        band_cursor=0
        for band_num,band in enumerate(bad_bands):
            # NOTE(review): `~band` is a bitwise NOT — correct for numpy
            # bools, wrong for Python bools (assumes bad_bands is ndarray).
            if ~band:
                topo_dict['coeffs'][band_num] = calc_c(used_reflectance_samples[:,band_cursor],used_cos_i,
                                                       fit_type=topo_dict['c_fit_type'])
                band_cursor+=1
    else:
        # Too few usable samples: fall back to a huge C, which drives the
        # SCS+C correction factor toward 1 (effectively no correction).
        topo_dict['coeffs'] = {}
        band_cursor=0
        for band_num,band in enumerate(bad_bands):
            if ~band:
                topo_dict['coeffs'][band_num] = 100000.0
                band_cursor+=1
    # Every file shares the same fitted topo_dict object.
    for h5_name in sample_dict.keys():
        sub_dict = sample_dict[h5_name]
        sub_dict["topo_dict"] = topo_dict
def apply_topo_scsc(sample_dict):
    '''Apply the SCS+C topographic correction in place to each file's samples.'''
    for data_dict in sample_dict.values():
        topo_dict = data_dict["topo_dict"]
        slope = data_dict['slope_samples']
        cos_i = data_dict['cosine_i_samples']
        ndi = data_dict['ndi_samples']
        refl = data_dict['reflectance_samples']
        c1 = np.cos(slope) * np.cos(data_dict['set_solar_zn'])
        coeffs = np.array(list(topo_dict['coeffs'].values()))
        # mask order in the config matters here
        mask = np.ones(ndi.shape).astype(bool)
        for values, (_, lim) in zip((ndi, slope, cos_i), topo_dict['apply_mask']):
            mask &= (values >= float(lim['min'])) & (values <= float(lim['max']))
        for b in range(refl.shape[1]):
            col = np.copy(refl[:, b])
            factor = (c1 + coeffs[b]) / (cos_i + coeffs[b])
            col[mask] = col[mask] * factor[mask]
            refl[:, b] = col

if __name__== "__main__":
    main()

View File

@ -0,0 +1,150 @@
import json
import os
import warnings
import sys
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.topo import load_topo_precomputed
from hytools.brdf import load_brdf_precomputed
from hytools.glint import set_glint_parameters_single
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    # Usage: script <config.json> <image index into config["input_files"]>
    config_file = sys.argv[1]
    image_order = int(sys.argv[2])
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    image = config_dict["input_files"][image_order]
    actor = ht.HyTools()
    if config_dict['file_type'] == 'envi':
        anc_file = config_dict["anc_files"][image]
        actor.read_file(image,config_dict['file_type'],anc_file)
    elif config_dict['file_type'] == 'neon':
        actor.read_file(image,config_dict['file_type'])
    actor.create_bad_bands(config_dict['bad_bands'])
    # Only precomputed coefficients are supported by this single-image tool;
    # anything else aborts before any output is written.
    for correction in config_dict["corrections"]:
        if correction =='topo':
            if config_dict['topo']['type'] == 'precomputed':
                print("Using precomputed topographic coefficients.")
                load_topo_precomputed(actor,config_dict['topo'])
                actor.corrections.append('topo')
            else:
                print('Only precomputed topographic coefficients are accepted. Quit.')
                return
        elif correction == 'brdf':
            if config_dict['brdf']['type'] == 'precomputed':
                print("Using precomputed BRDF coefficients.")
                load_brdf_precomputed(actor,config_dict['brdf'])
                actor.corrections.append('brdf')
            else:
                print('Only precomputed BRDF coefficients are accepted. Quit.')
                return
        elif correction == 'glint':
            set_glint_parameters_single(actor, config_dict)
    if config_dict['export']['image']:
        print("Exporting corrected image.")
        apply_corrections_single(actor,config_dict)
def apply_corrections_single(hy_obj,config_dict):
    '''Apply correction to image and export
    to file.
    '''
    header_dict = hy_obj.get_header()
    header_dict['data ignore value'] = hy_obj.no_data
    header_dict['data type'] = 4  # ENVI type 4 = float32

    output_name = config_dict['export']['output_dir']
    output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    output_name += "_%s" % config_dict['export']["suffix"]

    #Export all wavelengths
    if len(config_dict['export']['subset_waves']) == 0:
        if config_dict["resample"] == True:
            hy_obj.resampler = config_dict['resampler']
            waves= hy_obj.resampler['out_waves']
        else:
            waves = hy_obj.wavelengths
        header_dict['bands'] = len(waves)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        iterator = hy_obj.iterate(by='line', corrections=hy_obj.corrections,
                                  resample=config_dict['resample'])
        while not iterator.complete:
            line = iterator.read_next()
            writer.write_line(line,iterator.current_line)
        writer.close()
    #Export subset of wavelengths
    else:
        waves = config_dict['export']['subset_waves']
        bands = [hy_obj.wave_to_band(x) for x in waves]
        # Record the actual matched band-center wavelengths in the header.
        waves = [round(hy_obj.wavelengths[x],2) for x in bands]
        header_dict['bands'] = len(bands)
        header_dict['wavelength'] = waves
        writer = WriteENVI(output_name,header_dict)
        for b,band_num in enumerate(bands):
            band = hy_obj.get_band(band_num,
                                   corrections=hy_obj.corrections)
            writer.write_band(band, b)
        writer.close()

    #Export masks
    if (config_dict['export']['masks']) and (len(config_dict["corrections"]) > 0):
        masks = []
        mask_names = []
        # One mask band per (correction, apply_mask entry) pair.
        for correction in config_dict["corrections"]:
            for mask_type in getattr(hy_obj,correction)['apply_mask']:
                mask_names.append(correction + '_' + mask_type[0])
                masks.append(mask_create(hy_obj, [mask_type]))
        header_dict['data type'] = 1  # ENVI type 1 = uint8
        header_dict['bands'] = len(masks)
        header_dict['band names'] = mask_names
        header_dict['samples'] = hy_obj.columns
        header_dict['lines'] = hy_obj.lines
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
        header_dict['wavelength units'] = ''
        header_dict['data ignore value'] = 255  # 255 marks no-data pixels
        output_name = config_dict['export']['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0]
        output_name += "_%s_mask" % config_dict['export']["suffix"]
        writer = WriteENVI(output_name,header_dict)
        for band_num,mask in enumerate(masks):
            mask = mask.astype(int)
            mask[~hy_obj.mask['no_data']] = 255
            writer.write_band(mask,band_num)
        del masks

if __name__== "__main__":
    main()

View File

@ -0,0 +1,116 @@
import json
import os
import warnings
import sys
import numpy as np
import h5py
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs_single
from hytools.brdf import calc_brdf_coeffs_pre
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    # Usage: script <config.json> <flightline index> <group_topo flag (0/1)>
    config_file = sys.argv[1]
    flightline_index = int(sys.argv[2])
    group_topo_bool = bool(int(sys.argv[3]))
    print("For Group Topo (no correction applied to samples)",group_topo_bool)
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    image = config_dict["input_files"][flightline_index]
    Ht_Obj = ht.HyTools()
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        Ht_Obj.read_file(image,config_dict['file_type'],anc_files[image])
    elif config_dict['file_type'] == 'neon':
        Ht_Obj.read_file(image,config_dict['file_type'])
    elif config_dict['file_type'] == 'ncav' or config_dict['file_type'] == 'emit':
        anc_files = config_dict["anc_files"]
        Ht_Obj.read_file(image,config_dict['file_type'],anc_files[image])
    Ht_Obj.create_bad_bands(config_dict['bad_bands'])
    # Require enough non-water pixels (ndi > 0.001 — presumably NDVI,
    # verify against HyTools.ndi) before fitting anything.
    non_water_count = np.count_nonzero(Ht_Obj.ndi()[Ht_Obj.mask['no_data']]>0.001)
    if non_water_count<50:
        print("Not enough ground pixels... exit")
        return
    for correction in config_dict["corrections"]:
        if correction =='topo':
            if group_topo_bool is False: # get single line topo coeffs
                calc_topo_coeffs_single(Ht_Obj,config_dict['topo'])
        elif correction == 'brdf':
            data_dict=calc_brdf_coeffs_pre(Ht_Obj,config_dict)
            print(f"{data_dict['kernel_samples'].shape[0]} pixels extracted.")
            export_h5(Ht_Obj,config_dict['export'],data_dict,group_topo_bool)
    # In grouped-topo mode coefficients are fitted later across the group,
    # so per-line coefficients are only exported in single-line mode.
    if (config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0) and group_topo_bool is False:
        print("Exporting correction coefficients.")
        export_coeffs_topo(Ht_Obj,config_dict['export'])
def export_h5(hy_obj, export_dict, obj_dict, topo_bool):
    '''Write extracted samples to H5; the name marks the processing stage
    (pretopo when samples await grouped topo fitting, prebrdf otherwise).'''
    stem = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    stage = "pretopo" if topo_bool else "prebrdf"
    out_filename = f"{export_dict['output_dir']}{stem}_{stage}_sample.h5"
    with h5py.File(out_filename, "w") as h5:
        kernels = h5.create_dataset("kernels_samples", data=obj_dict["kernel_samples"])
        h5.create_dataset("reflectance_samples", data=obj_dict["reflectance_samples"])
        h5.create_dataset("wavelengths", data=obj_dict["used_band"])
        kernels.attrs['set_solar_zn'] = obj_dict['set_solar_zn']
        kernels.attrs['kernels_names'] = '["Volume","Geometry"]'
        kernels.attrs['Solar Zenith Unit'] = "Radians"
        h5.create_dataset("image_wavelengths", data=np.array(hy_obj.wavelengths))
        h5.create_dataset("bad_bands", data=np.array(hy_obj.bad_bands))
        slope_ds = h5.create_dataset("slope_samples", data=obj_dict["slope_samples"])
        slope_ds.attrs['Slope Unit'] = "Radians"
        h5.create_dataset("cosine_i_samples", data=obj_dict["cos_i_samples"])
    print(f"{out_filename} saved.")
def export_coeffs_topo(hy_obj, export_dict):
    '''Export the topographic correction coefficients to a JSON file.

    Only the 'topo' entry in hy_obj.corrections is exported; the original
    glint/brdf branches were unreachable because every non-'topo'
    correction is filtered out first, so they have been removed.

    Args:
        hy_obj: HyTools object carrying .corrections, .file_name and .topo.
        export_dict (dict): needs 'output_dir' (with trailing separator)
            and 'suffix' keys.
    '''
    for correction in hy_obj.corrections:
        if correction != 'topo':
            continue
        coeff_file = "%s%s_%s_coeffs_%s_chtc.json" % (
            export_dict['output_dir'],
            os.path.splitext(os.path.basename(hy_obj.file_name))[0],
            correction,
            export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(hy_obj.topo, outfile)

if __name__== "__main__":
    main()

View File

@ -0,0 +1,100 @@
import json
import os
import warnings
import sys
import numpy as np
import h5py
import hytools as ht
from hytools.io.envi import *
from hytools.topo import calc_topo_coeffs_single
from hytools.brdf import calc_brdf_coeffs_pre
from hytools.glint import set_glint_parameters
from hytools.masks import mask_create
warnings.filterwarnings("ignore")
np.seterr(divide='ignore', invalid='ignore')
def main():
    # Usage: script <config.json> <flightline index into config["input_files"]>
    config_file = sys.argv[1]
    flightline_index = int(sys.argv[2])
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    image = config_dict["input_files"][flightline_index]
    Ht_Obj = ht.HyTools()
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        Ht_Obj.read_file(image,config_dict['file_type'],anc_files[image])
    elif config_dict['file_type'] == 'neon':
        Ht_Obj.read_file(image,config_dict['file_type'])
    Ht_Obj.create_bad_bands(config_dict['bad_bands'])
    for correction in config_dict["corrections"]:
        if correction =='topo':
            # Per-line topographic coefficients (single flightline).
            calc_topo_coeffs_single(Ht_Obj,config_dict['topo'])
        elif correction == 'brdf':
            # Extract BRDF kernel/reflectance samples and dump them to H5
            # for a later combined fit across flightlines.
            data_dict=calc_brdf_coeffs_pre(Ht_Obj,config_dict)
            print(f"{data_dict['kernel_samples'].shape[0]} pixels extracted.")
            export_h5(Ht_Obj,config_dict['export'],data_dict)
    if config_dict['export']['coeffs'] and len(config_dict["corrections"]) > 0:
        print("Exporting correction coefficients.")
        export_coeffs_topo(Ht_Obj,config_dict['export'])
def export_h5(hy_obj, export_dict, obj_dict):
    '''Write extracted samples to a *_prebrdf_sample.h5 file.'''
    stem = os.path.splitext(os.path.basename(hy_obj.file_name))[0]
    out_filename = f"{export_dict['output_dir']}{stem}_prebrdf_sample.h5"
    with h5py.File(out_filename, "w") as h5:
        kernels = h5.create_dataset("kernels_samples", data=obj_dict["kernel_samples"])
        h5.create_dataset("reflectance_samples", data=obj_dict["reflectance_samples"])
        h5.create_dataset("wavelengths", data=obj_dict["used_band"])
        kernels.attrs['set_solar_zn'] = obj_dict['set_solar_zn']
        kernels.attrs['kernels_names'] = '["Volume","Geometry"]'
        h5.create_dataset("image_wavelengths", data=np.array(hy_obj.wavelengths))
        h5.create_dataset("bad_bands", data=np.array(hy_obj.bad_bands))
    print(f"{out_filename} saved.")
def export_coeffs_topo(hy_obj, export_dict):
    '''Export the topographic correction coefficients to a JSON file.

    Only the 'topo' entry in hy_obj.corrections is exported; the original
    glint/brdf branches were unreachable because every non-'topo'
    correction is filtered out first, so they have been removed.

    Args:
        hy_obj: HyTools object carrying .corrections, .file_name and .topo.
        export_dict (dict): needs 'output_dir' (with trailing separator)
            and 'suffix' keys.
    '''
    for correction in hy_obj.corrections:
        if correction != 'topo':
            continue
        coeff_file = "%s%s_%s_coeffs_%s_chtc.json" % (
            export_dict['output_dir'],
            os.path.splitext(os.path.basename(hy_obj.file_name))[0],
            correction,
            export_dict["suffix"])
        with open(coeff_file, 'w') as outfile:
            json.dump(hy_obj.topo, outfile)

if __name__== "__main__":
    main()

View File

@ -0,0 +1,35 @@
import sys, os
import multiprocessing
import subprocess, json
exec_str="python ../no_ray/image_correct_export_image.py "
def run_command(command):
    # Echo the command, then execute it in a shell subprocess (blocking).
    print(command)
    subprocess.run(command,shell=True)
def main():
    '''Fan out per-image export jobs (image_correct_export_image.py) to a
    multiprocessing pool.

    argv[1]: path to the JSON config file.
    argv[2]: number of images to export (must not exceed len(input_files)).
    '''
    config_file = sys.argv[1]
    total_count = int(sys.argv[2])
    # Fix: os.cpu_count() can return None, and on a single-core machine the
    # old cpu_count()-1 formula asked Pool for 0 workers (ValueError).
    worker_count = max(1, min((os.cpu_count() or 2) - 1, total_count))
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    if total_count > len(config_dict["input_files"]):
        print("Out of upper bound")
        return
    # Fix: create the pool only after validation so an early return does
    # not leak worker processes.
    pool = multiprocessing.Pool(processes=worker_count)
    commands = [f"{exec_str} {config_file} {order}" for order in range(total_count)]
    pool.map(run_command, commands)
    pool.close()
    pool.join() # Wait for all subprocesses to finish
    print('All image export is done.')

if __name__== "__main__":
    main()

View File

@ -0,0 +1,42 @@
import sys, os
import multiprocessing
import subprocess, json
exec_str="python ../no_ray/image_correct_get_sample_chtc.py "
merge_str="python ../no_ray/image_correct_combine_sample_chtc.py {} {}"
def run_command(command):
    # Echo the command, then execute it in a shell subprocess (blocking).
    print(command)
    subprocess.run(command,shell=True)
def main():
    '''Fan out per-flightline sample-extraction jobs to a multiprocessing
    pool, then run the single combine step.

    argv[1]: path to the JSON config file.
    argv[2]: number of flightlines (must not exceed len(input_files)).
    '''
    config_file = sys.argv[1]
    total_count = int(sys.argv[2])
    # Fix: os.cpu_count() can return None, and on a single-core machine the
    # old cpu_count()-1 formula asked Pool for 0 workers (ValueError).
    worker_count = max(1, min((os.cpu_count() or 2) - 1, total_count))
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    h5_folder=config_dict["export"]["output_dir"]
    if total_count > len(config_dict["input_files"]):
        print("Out of upper bound")
        return
    # Fix: create the pool only after validation so an early return does
    # not leak worker processes.
    pool = multiprocessing.Pool(processes=worker_count)
    commands = [f"{exec_str} {config_file} {order}" for order in range(total_count)]
    pool.map(run_command, commands)
    pool.close()
    pool.join() # Wait for all subprocesses to finish
    print('All extractions are done.')
    subprocess.run(merge_str.format(config_file,h5_folder),shell=True)

if __name__== "__main__":
    main()

View File

@ -0,0 +1,86 @@
import sys, os, time
import multiprocessing
import subprocess, json
exec_str="python ../noray/image_correct_get_raw_sample_chtc.py "
topo_group_str = "python ../noray/image_correct_combine_topo_sample_chtc.py {} {} {}"
merge_str="python ../noray/image_correct_combine_sample_chtc.py {} {}"
def parse_group_info(group_dict, full_img_list):
    '''Group image names by their topo subgroup.

    Args:
        group_dict (dict): image name -> subgroup name.
        full_img_list (list): ordered image names (keys of group_dict).

    Returns:
        tuple: (subgroup -> [image names], subgroup -> [indices into
        full_img_list]), both preserving the order of full_img_list.

    Fix: uses enumerate instead of list.index(), which was O(n^2) and
    returned the first occurrence's position for duplicate names.
    '''
    out_group_dict = {}
    out_group_dict_short = {}
    for idx, img_name in enumerate(full_img_list):
        subgroup_name = group_dict[img_name]
        out_group_dict.setdefault(subgroup_name, []).append(img_name)
        out_group_dict_short.setdefault(subgroup_name, []).append(idx)
    return out_group_dict, out_group_dict_short
def run_command(command):
    # Execute a shell command (blocking); output goes to this process's stdout.
    subprocess.run(command,shell=True)
def run_step2(params):
    # Launch the grouped-TOPO combine step for one finished subgroup.
    # params: (config_file, h5_folder, subgroup name).
    config_file,h5_folder,current_group_name=params
    #print(topo_group_str.format(config_file,h5_folder,current_group_name))
    subprocess.run(topo_group_str.format(config_file,h5_folder,current_group_name),shell=True)
def main():
    '''Two-level scheduler for grouped topographic correction.

    Level 1 extracts raw samples per flightline; as soon as every line of a
    subgroup finishes, level 2 (grouped TOPO fitting) is launched for that
    subgroup. Finally a single combine step merges everything.

    argv[1]: JSON config file; argv[2]: folder for the sample H5 files.
    '''
    config_file = sys.argv[1]
    h5_folder = sys.argv[2]
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    image_list = config_dict["input_files"]
    total_count = len(image_list)
    subgroup=config_dict['topo']['subgroup']
    group_meta_dict, worker_unfinished=parse_group_info(subgroup,image_list)
    # Fix: os.cpu_count() can return None, and on a single-core machine the
    # old cpu_count()-1 formula asked Pool for 0 workers (ValueError).
    level1_worker_count = max(1, min((os.cpu_count() or 2) - 1, total_count))
    print(level1_worker_count)
    with multiprocessing.Pool(processes=level1_worker_count) as pool:
        # subgroup -> list of level-1 AsyncResults still pending.
        subgroup_status_dict={}
        for sub_group_name in group_meta_dict:
            step1_status=[]
            for order_in_full_list in worker_unfinished[sub_group_name]:
                command = f"{exec_str} {config_file} {order_in_full_list} 1"
                step1_status += [pool.apply_async(run_command, args=(command,))]
            subgroup_status_dict[sub_group_name] = step1_status
        step2_results = []
        level1_total_unfinished = total_count
        while level1_total_unfinished:
            # check for finished subgroups every 10 sec. If the raw pixel extraction
            # is done for a subgroup, start step2-group TOPO for it.
            temp_unfinished_count=0
            subgroup_list = list(subgroup_status_dict.keys())
            for gp_name in subgroup_list:
                sub_count=sum([not r.ready() for r in subgroup_status_dict[gp_name]])
                if sub_count==0:
                    step2_results.append(
                        pool.apply_async(run_step2, args=((config_file,h5_folder,gp_name),)))
                    del subgroup_status_dict[gp_name]
                temp_unfinished_count += sub_count
            level1_total_unfinished=temp_unfinished_count
            time.sleep(10.0)
        # Fix: Pool.__exit__ calls terminate(), which would kill in-flight
        # step-2 jobs — wait for every step-2 result before leaving the pool.
        for r in step2_results:
            r.wait()
    print('All extraction is done.')
    subprocess.run(merge_str.format(config_file,h5_folder),shell=True)

if __name__== "__main__":
    main()

View File

@ -0,0 +1,45 @@
import sys, os
import multiprocessing
import subprocess, json
exec_str="python ../no_ray/trait_estimate_inde.py "
def run_command(command):
    # Echo the command, then execute it in a shell subprocess (blocking).
    print(command)
    subprocess.run(command,shell=True)
def main():
    '''Fan out one trait_estimate_inde.py job per (image, trait) pair to a
    multiprocessing pool.

    argv[1]: JSON config file.
    argv[2]: number of images (must not exceed len(input_files)).
    argv[3]: number of trait models (must not exceed len(trait_models)).
    '''
    config_file = sys.argv[1]
    total_img_count = int(sys.argv[2])
    total_trait_count = int(sys.argv[3])
    total_count = total_img_count*total_trait_count
    # Fix: os.cpu_count() can return None, and on a single-core machine the
    # old cpu_count()-1 formula asked Pool for 0 workers (ValueError).
    worker_count = max(1, min((os.cpu_count() or 2) - 1, total_count))
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    if total_img_count > len(config_dict["input_files"]):
        print("Out of upper bound")
        return
    # Fix: the trait count was wrongly checked against input_files; it must
    # be checked against the trait model list.
    if total_trait_count > len(config_dict["trait_models"]):
        print("Out of upper bound")
        return
    # Fix: create the pool only after validation so an early return does
    # not leak worker processes.
    pool = multiprocessing.Pool(processes=worker_count)
    commands = [f"{exec_str} {config_file} {img_i} {trait_j}"
                for img_i in range(total_img_count)
                for trait_j in range(total_trait_count)]
    pool.map(run_command, commands)
    pool.close()
    pool.join() # Wait for all subprocesses to finish
    print('All traits are done.')

if __name__== "__main__":
    main()

View File

@ -0,0 +1,165 @@
import json
import os
import warnings
import sys
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.masks import mask_dict
warnings.filterwarnings("ignore")
def main():
    # Usage: script <config.json> <image index> <trait-model index>
    config_file = sys.argv[1]
    image_order = int(sys.argv[2])
    trait_order = int(sys.argv[3])
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    image= config_dict["input_files"][image_order]
    actor = ht.HyTools()
    # Load data
    if config_dict['file_type'] in ('envi','emit','ncav'):
        anc_file = config_dict["anc_files"][image]
        # A GLT (geometry lookup table) is attached only when configured.
        if "glt_files" in config_dict:
            if bool(config_dict["glt_files"]):
                actor.read_file(image,config_dict['file_type'],anc_path=anc_file,glt_path=config_dict["glt_files"][image]) # chunk_glt writing is not supported
            else:
                actor.read_file(image,config_dict['file_type'],anc_path=anc_file)
        else:
            actor.read_file(image,config_dict['file_type'],anc_path=anc_file)
    elif config_dict['file_type'] == 'neon':
        actor.read_file(image,config_dict['file_type'])
    # The model file is opened here only to print its name; it is re-read
    # inside apply_single_trait_models.
    trait = config_dict['trait_models'][trait_order]
    with open(trait, 'r') as json_file:
        trait_model = json.load(json_file)
    print("\t %s" % trait_model["name"])
    apply_single_trait_models(actor,config_dict,trait_order)
def apply_single_trait_models(hy_obj, config_dict, trait_order):
    '''Apply a single trait model to an image and export the result.

    Builds an ENVI trait map whose bands are: trait mean, trait standard
    deviation, a model-applicability range mask, and one band per custom
    mask listed in ``config_dict['masks']``.

    Args:
        hy_obj: HyTools image object with data already loaded.
        config_dict: Run configuration (bad bands, corrections, masks,
            resampling settings, trait model paths, output directory).
        trait_order: Index into ``config_dict['trait_models']`` selecting
            the model to apply.
    '''
    hy_obj.create_bad_bands(config_dict['bad_bands'])
    hy_obj.corrections = config_dict['corrections']

    # Load correction coefficients
    if 'topo' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['topo'][hy_obj.file_name], 'topo')
    if 'brdf' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['brdf'][hy_obj.file_name], 'brdf')

    hy_obj.resampler['type'] = config_dict["resampling"]['type']

    for trait in [config_dict['trait_models'][trait_order]]:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        coeffs = np.array(trait_model['model']['coefficients'])
        intercept = np.array(trait_model['model']['intercepts'])
        model_waves = np.array(trait_model['wavelengths'])

        # Resample only when the model wavelengths are not an exact
        # subset of the image wavelengths.
        resample = not all(x in hy_obj.wavelengths for x in model_waves)
        if resample:
            hy_obj.resampler['out_waves'] = model_waves
            hy_obj.resampler['out_fwhm'] = trait_model['fwhm']
        else:
            wave_mask = [np.argwhere(x == hy_obj.wavelengths)[0][0] for x in model_waves]

        # Build trait image header
        header_dict = hy_obj.get_header()
        header_dict['wavelength'] = []
        header_dict['data ignore value'] = -9999
        header_dict['data type'] = 4
        header_dict['trait unit'] = trait_model['units']
        header_dict['band names'] = ["%s_mean" % trait_model["name"],
                                     "%s_std" % trait_model["name"],
                                     'range_mask'] + [mask[0] for mask in config_dict['masks']]
        header_dict['bands'] = len(header_dict['band names'])

        # Generate masks
        for mask, args in config_dict['masks']:
            mask_function = mask_dict[mask]
            hy_obj.gen_mask(mask_function, mask, args)

        # Join directory and file name portably instead of relying on the
        # configured output_dir carrying a trailing separator (matches the
        # export helpers elsewhere in this project).
        output_name = os.path.join(
            config_dict['output_dir'],
            os.path.splitext(os.path.basename(hy_obj.file_name))[0] + "_%s" % trait_model["name"])
        writer = WriteENVI(output_name, header_dict)

        # Chunk sizes are tuned per input file type.
        if config_dict['file_type'] == 'envi' or config_dict['file_type'] == 'emit':
            iterator = hy_obj.iterate(by='chunk',
                                      chunk_size=(2, hy_obj.columns),
                                      corrections=hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'neon':
            iterator = hy_obj.iterate(by='chunk',
                                      chunk_size=(int(np.ceil(hy_obj.lines/32)), int(np.ceil(hy_obj.columns/32))),
                                      corrections=hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'ncav':
            iterator = hy_obj.iterate(by='chunk',
                                      chunk_size=(512, 512),
                                      corrections=hy_obj.corrections,
                                      resample=resample)

        while not iterator.complete:
            chunk = iterator.read_next()
            if not resample:
                chunk = chunk[:, :, wave_mask]
            trait_est = np.zeros((chunk.shape[0],
                                  chunk.shape[1],
                                  header_dict['bands']))

            # Apply spectrum transforms
            for transform in trait_model['model']["transform"]:
                if transform == "vector":  # vector norm
                    norm = np.linalg.norm(chunk, axis=2)
                    chunk = chunk/norm[:, :, np.newaxis]
                if transform == "absorb":
                    chunk = np.log(1/chunk)
                if transform == "mean":
                    mean = chunk.mean(axis=2)
                    chunk = chunk/mean[:, :, np.newaxis]

            # Ensemble prediction: one estimate per coefficient set.
            trait_pred = np.einsum('jkl,ml->jkm', chunk, coeffs, optimize='optimal')
            trait_pred = trait_pred + intercept
            trait_est[:, :, 0] = trait_pred.mean(axis=2)
            trait_est[:, :, 1] = trait_pred.std(ddof=1, axis=2)

            # Flag pixels whose mean falls inside the model's valid range.
            range_mask = (trait_est[:, :, 0] > trait_model["model_diagnostics"]['min']) & \
                         (trait_est[:, :, 0] < trait_model["model_diagnostics"]['max'])
            trait_est[:, :, 2] = range_mask.astype(int)

            # Subset and assign custom masks (renamed local avoids shadowing
            # the mask name used to index hy_obj.mask).
            for i, (mask, args) in enumerate(config_dict['masks']):
                custom_mask = hy_obj.mask[mask][iterator.current_line:iterator.current_line+chunk.shape[0],
                                                iterator.current_column:iterator.current_column+chunk.shape[1]]
                trait_est[:, :, 3+i] = custom_mask.astype(int)

            # Stamp no-data pixels with the ignore value.
            nd_mask = hy_obj.mask['no_data'][iterator.current_line:iterator.current_line+chunk.shape[0],
                                             iterator.current_column:iterator.current_column+chunk.shape[1]]
            trait_est[~nd_mask] = -9999
            writer.write_chunk(trait_est,
                               iterator.current_line,
                               iterator.current_column)
        writer.close()
# Script entry point: expects <config_file> <image_index> <trait_index> on argv.
if __name__== "__main__":
    main()

783
Flexbrdf/scripts/rad2geo.py Normal file
View File

@ -0,0 +1,783 @@
#!/usr/bin/env python3
"""
高光谱图像地理变换工具
功能:
- 读取地理校正后的高光谱航带图像 (dat格式11个波段前3个光谱波段、列号、行号、航带号、后5个光谱波段)
- 读取原始高光谱图像 (bil格式)
- 通过行列号进行地理变换匹配
- 保存变换后的高光谱图像
文件命名格式:
- 地理校正文件2025_9_2_3_53_45_202592_35252_0_rad_rgbxyz_geo1.dat
- 原始文件2025_9_2_3_53_45_202592_35252_0_rad.bil
依赖:
- numpy
- spectral (pip install spectral)
- pathlib
"""
import numpy as np
import os
from pathlib import Path
import re
from typing import Tuple, Optional, List
try:
import spectral
SPECTRAL_AVAILABLE = True
except ImportError:
SPECTRAL_AVAILABLE = False
print("警告: spectral库不可用请安装: pip install spectral")
# 可选GDAL库用于保存ENVI格式文件
try:
from osgeo import gdal
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
print("警告: GDAL不可用将使用numpy保存")
def _format_geo_info_for_envi(geo_value):
"""
直接返回地理信息的原始格式已经是从HDR文件直接复制的
"""
if not geo_value:
return None
return geo_value
def extract_file_key(filename: str) -> str:
    """Derive the matching key from a file name.

    The key is everything before the first ``_rad`` marker, e.g.
    ``"2025_9_2_3_53_45_202592_35252_2_rad_rgbxyz_geo1.dat"`` maps to
    ``"2025_9_2_3_53_45_202592_35252_2"``.  Names without ``_rad`` are
    returned unchanged.
    """
    head, marker, _tail = filename.partition('_rad')
    return head if marker else filename
def load_geo_corrected_dat(dat_file: str) -> Tuple[np.ndarray, dict]:
    """Read a geo-corrected 11-band ``.dat`` lookup image.

    Band layout (as produced upstream):
      - bands 0-2: first three spectral bands
      - band 3: source column number
      - band 4: source row number
      - band 5: flight-strip number
      - bands 6-10: last five spectral bands

    The file is read through the spectral package; if no ``.hdr`` exists,
    a temporary one is synthesized from inferred dimensions.  When the
    spectral path fails entirely, a raw NumPy read is attempted.

    Parameters
    ----------
    dat_file : str
        Path to the ``.dat`` file.

    Returns
    -------
    data : np.ndarray
        Image data flattened to (pixels, 11).
    metadata : dict
        Dimensions, band roles, geo-reference strings and flags.

    Raises
    ------
    RuntimeError
        If the spectral package is missing or both read paths fail.
    """
    if not SPECTRAL_AVAILABLE:
        raise RuntimeError("需要spectral库来读取dat文件请安装: pip install spectral")
    try:
        # Locate the companion header file.
        dat_path = Path(dat_file)
        hdr_file = dat_path.with_suffix('.hdr')
        if not hdr_file.exists():
            # No header: read the raw floats to infer the image dimensions.
            data_temp = np.fromfile(dat_file, dtype=np.float32)
            if data_temp.size % 11 != 0:
                raise ValueError(f"数据大小 {data_temp.size} 不能被11整除可能不是正确的dat文件格式")
            num_pixels = data_temp.size // 11
            # Try row counts (step 100) that divide the pixel count evenly.
            possible_rows = []
            for rows in range(1000, int(np.sqrt(num_pixels)) + 1000, 100):
                if num_pixels % rows == 0:
                    cols = num_pixels // rows
                    if cols > 0 and cols < 10000:
                        possible_rows.append((rows, cols))
            if not possible_rows:
                # Fall back to a roughly square shape.
                rows = int(np.sqrt(num_pixels))
                cols = (num_pixels + rows - 1) // rows
            else:
                rows, cols = possible_rows[0]
            # Synthesize a minimal ENVI header so spectral can open the file.
            hdr_content = f"""ENVI
description = {{
高光谱地理校正数据 - 11波段格式
前3个光谱波段、列号、行号、航带号、后5个光谱波段}}
samples = {cols}
lines = {rows}
bands = 11
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
band names = {{
波段1, 波段2, 波段3, 列号, 行号, 航带号, 波段7, 波段8, 波段9, 波段10, 波段11}}
"""
            with open(hdr_file, 'w', encoding='utf-8') as f:
                f.write(hdr_content)
        # Read through spectral.
        image_data = spectral.open_image(str(hdr_file))
        data = image_data.load()
        # Flatten a 3-D cube to (pixels, bands).
        if data.ndim == 3:
            data_reshaped = data.reshape(-1, data.shape[2])
        else:
            data_reshaped = data
        # The lookup format is fixed at 11 bands.
        if data_reshaped.shape[1] != 11:
            raise ValueError(f"数据波段数 {data_reshaped.shape[1]} 不等于期望的11波段")
        # Bands 3/4 hold column/row numbers; all-zero means they are unusable.
        all_coords_zero = (data_reshaped[:, 3].max() == 0 and data_reshaped[:, 4].max() == 0)
        # Warn when the coordinate bands carry no information.
        if all_coords_zero:
            print(f"警告: dat文件中的所有行列号都是0可能需要重新解释数据格式")
        # Copy geo-reference lines verbatim from the HDR file.
        geo_info = {}
        try:
            # Read the raw HDR text, preserving its original formatting.
            with open(str(hdr_file), 'r', encoding='utf-8', errors='ignore') as f:
                hdr_content = f.read()
            # Extract the "map info" line.
            import re
            map_info_match = re.search(r'^map info\s*=\s*(.+)$', hdr_content, re.MULTILINE | re.IGNORECASE)
            if map_info_match:
                geo_info['map_info'] = map_info_match.group(1).strip()
            # Extract the "coordinate system string" line.
            coord_sys_match = re.search(r'^coordinate system string\s*=\s*(.+)$', hdr_content, re.MULTILINE | re.IGNORECASE)
            if coord_sys_match:
                geo_info['coordinate_system_string'] = coord_sys_match.group(1).strip()
            # Extract the "projection info" line (if present).
            proj_info_match = re.search(r'^projection info\s*=\s*(.+)$', hdr_content, re.MULTILINE | re.IGNORECASE)
            if proj_info_match:
                geo_info['projection_info'] = proj_info_match.group(1).strip()
        except Exception as e:
            print(f"警告: 无法读取geo文件的HDR地理信息: {e}")
        metadata = {
            'file_path': dat_file,
            'hdr_file': str(hdr_file),
            'data_type': str(data_reshaped.dtype),
            'bands': 11,
            'pixels': len(data_reshaped),
            'lines': image_data.nrows,
            'samples': image_data.ncols,
            'band_names': ['band1', 'band2', 'band3', 'column', 'row', 'strip', 'band7', 'band8', 'band9', 'band10', 'band11'],
            'col_idx': 3,
            'row_idx': 4,
            'strip_idx': 5,
            'all_coords_zero': all_coords_zero,
            'wavelengths': getattr(image_data.bands, 'centers', None),
            'interleave': getattr(image_data, 'interleave', 'unknown'),
            'geo_info': geo_info
        }
        return data_reshaped, metadata
    except Exception as e:
        print(f"spectral读取失败: {e}")
        print("回退到numpy直接读取...")
        # Fallback: raw NumPy read without header interpretation.
        try:
            data = np.fromfile(dat_file, dtype=np.float32)
            if data.size % 11 != 0:
                raise ValueError(f"数据大小 {data.size} 不能被11整除可能不是正确的dat文件格式")
            num_pixels = data.size // 11
            data_reshaped = data.reshape(num_pixels, 11)
            print(f"numpy直接读取成功: {dat_file}")
            print(f"数据形状: {data_reshaped.shape} (pixels, bands)")
            # Check whether the coordinate bands carry any information.
            all_coords_zero = (data_reshaped[:, 3].max() == 0 and data_reshaped[:, 4].max() == 0)
            metadata = {
                'file_path': dat_file,
                'data_type': 'float32',
                'bands': 11,
                'pixels': num_pixels,
                'band_names': ['band1', 'band2', 'band3', 'column', 'row', 'strip', 'band7', 'band8', 'band9', 'band10', 'band11'],
                'col_idx': 3,
                'row_idx': 4,
                'strip_idx': 5,
                'all_coords_zero': all_coords_zero
            }
            return data_reshaped, metadata
        except Exception as e2:
            raise RuntimeError(f"读取dat文件失败: spectral错误: {e}, numpy错误: {e2}")
def load_original_bil(bil_file: str) -> Tuple[np.ndarray, dict]:
    """Read an original hyperspectral BIL strip via the spectral package.

    The ENVI data are accessed through their ``.hdr`` header; several
    header naming conventions are probed before giving up.

    Parameters
    ----------
    bil_file : str
        Path to the ``.bil`` data file.

    Returns
    -------
    data : np.ndarray
        Hyperspectral cube of shape (lines, samples, bands).
    metadata : dict
        Dimensions, wavelengths, dtype and interleave information.

    Raises
    ------
    RuntimeError
        If the spectral package is missing or the file cannot be read.
    """
    if not SPECTRAL_AVAILABLE:
        raise RuntimeError("需要spectral库来读取bil文件请安装: pip install spectral")
    try:
        # Probe the usual header-file naming conventions next to the data file.
        bil_path = Path(bil_file)
        hdr_file = bil_path.with_suffix('.hdr')
        if not hdr_file.exists():
            hdr_file = bil_path.parent / f"{bil_path.name}.hdr"
        if not hdr_file.exists():
            hdr_file = bil_path.parent / f"{bil_path.stem}.hdr"
        if not hdr_file.exists():
            # Last resort: any header whose name starts with the file stem.
            possible_hdrs = list(bil_path.parent.glob(f"{bil_path.stem}*.hdr"))
            if possible_hdrs:
                hdr_file = possible_hdrs[0]
            else:
                raise FileNotFoundError(f"未找到对应的头文件")
        # Let spectral parse the ENVI header and load the cube.
        image_data = spectral.open_image(str(hdr_file))
        data = image_data.load()
        # Collect basic metadata for downstream processing.
        metadata = {
            'file_path': bil_file,
            'hdr_file': str(hdr_file),
            'lines': image_data.nrows,
            'samples': image_data.ncols,
            'bands': image_data.nbands,
            'wavelengths': image_data.bands.centers if hasattr(image_data.bands, 'centers') else None,
            'data_type': str(data.dtype),
            'interleave': getattr(image_data, 'interleave', 'unknown')
        }
        return data, metadata
    except Exception as e:
        raise RuntimeError(f"读取高光谱文件失败: {bil_file}, 错误: {e}")
def perform_geometric_transform(original_data: np.ndarray,
                                geo_data: np.ndarray,
                                geo_metadata: dict,
                                output_shape: Optional[Tuple[int, int]] = None,
                                chunk_size: Optional[int] = None) -> np.ndarray:
    """Vectorized geometric transform from strip space to geo space.

    Core logic:
      1. ``geo_data[:, 3]`` / ``geo_data[:, 4]`` hold the source column/row
         in the original BIL (geo indices start at 0, BIL numbers at 1, so
         1 is subtracted before indexing).
      2. Only geo pixels whose non-georeference bands (the first three and
         the last five) are not all zero receive a spectrum; all other
         output pixels remain 0.
      3. The output grid is the geo image's size
         (``geo_metadata['lines']`` x ``geo_metadata['samples']``).
      4. Optional chunking caps peak memory usage.

    Parameters
    ----------
    original_data : np.ndarray
        Original cube (orig_lines, orig_samples, bands).
    geo_data : np.ndarray
        Flattened geo lookup (pixels, 11).
    geo_metadata : dict
        Metadata of the geo lookup (dimensions, band indices).
    output_shape : tuple, optional
        Output (lines, samples); inferred from ``geo_metadata`` if omitted.
    chunk_size : int, optional
        Valid pixels per chunk; process everything at once when None.

    Returns
    -------
    transformed_data : np.ndarray
        Transformed cube on the geo grid.
    """
    try:
        # Original BIL dimensions.
        orig_lines, orig_samples, bands = original_data.shape
        # Output dimensions default to the geo image's size.
        if output_shape is None:
            if 'lines' in geo_metadata and 'samples' in geo_metadata:
                lines_out, samples_out = int(geo_metadata['lines']), int(geo_metadata['samples'])
            else:
                # Fallback: derive a near-square shape from the pixel count.
                total = len(geo_data)
                lines_out = int(np.sqrt(total))
                samples_out = (total + lines_out - 1) // lines_out
        else:
            lines_out, samples_out = output_shape
        # Ensure the output grid can hold exactly the geo pixel count.
        expected_pixels = lines_out * samples_out
        if expected_pixels != len(geo_data):
            print(f"警告: 输出图像尺寸({lines_out}x{samples_out}={expected_pixels})与geo数据像素数({len(geo_data)})不匹配")
            print("将调整输出尺寸以匹配geo数据像素数")
            lines_out = len(geo_data) // samples_out
            if lines_out * samples_out < len(geo_data):
                lines_out += 1
        # Pre-allocate the output (ignore value = 0).
        out = np.zeros((lines_out, samples_out, bands), dtype=original_data.dtype)
        print(f"开始地理变换: 原始bil尺寸 {orig_lines}x{orig_samples}x{bands}")
        print(f"输出geo图像尺寸: {lines_out}x{samples_out}x{bands}")
        print(f"处理 {len(geo_data)} 个geo像素")
        # Band indices of the coordinate bands (defaults 3 and 4).
        col_idx = geo_metadata.get('col_idx', 3)
        row_idx = geo_metadata.get('row_idx', 4)
        # Validity mask:
        # 1) the non-georeference bands are not all zero
        #    (for 11 bands: bands 0-2 and 6-10).
        non_geo_bands = np.concatenate([np.arange(col_idx), np.arange(col_idx+3, geo_data.shape[1])])
        non_zero_mask = ~(np.all(geo_data[:, non_geo_bands] == 0, axis=1))
        # 2) source indices in range (geo counts from 0, BIL from 1 → minus 1).
        src_cols = geo_data[:, col_idx].astype(np.int64) - 1
        src_rows = geo_data[:, row_idx].astype(np.int64) - 1
        in_bounds = (src_rows >= 0) & (src_rows < orig_lines) & (src_cols >= 0) & (src_cols < orig_samples)
        valid = non_zero_mask & in_bounds
        if not np.any(valid):
            print("地理变换完成: 没有有效像素需要处理")
            return out
        # Flat indices of the valid geo pixels.
        valid_idx = np.flatnonzero(valid)
        # Flat source indices into the original BIL; gather spectra in bulk.
        src_rows_v = src_rows[valid]
        src_cols_v = src_cols[valid]
        src_lin = src_rows_v * orig_samples + src_cols_v  # (N_valid,)
        orig_flat = original_data.reshape(-1, bands)
        if chunk_size is None or len(valid_idx) <= chunk_size:
            # Single pass over all valid pixels.
            spectra = orig_flat[src_lin]  # (N_valid, bands)
            # Output positions (geo pixels are row-major flattened).
            out_rows = valid_idx // samples_out
            out_cols = valid_idx % samples_out
            out[out_rows, out_cols, :] = spectra
        else:
            # Chunked passes to cap peak memory.
            for s in range(0, len(valid_idx), chunk_size):
                e = min(s + chunk_size, len(valid_idx))
                src_lin_chunk = src_lin[s:e]
                spectra = orig_flat[src_lin_chunk]
                idx_chunk = valid_idx[s:e]
                out_rows = idx_chunk // samples_out
                out_cols = idx_chunk % samples_out
                out[out_rows, out_cols, :] = spectra
        valid_count = len(valid_idx)
        skipped_count = len(geo_data) - valid_count
        print(f"地理变换完成: 成功处理 {valid_count} 个像素,跳过 {skipped_count} 个像素无效或全0")
        return out
    except Exception as e:
        raise RuntimeError(f"地理变换失败: {e}")
def fill_nan_with_nearest(data: np.ndarray) -> np.ndarray:
    """Fill NaN values with the nearest valid value along each row.

    For every band and every row, each NaN pixel is replaced by the value
    of the nearest non-NaN pixel in the same row (ties resolve to the
    lower index).  Rows containing only NaN are left unchanged.

    Parameters
    ----------
    data : np.ndarray
        Array of shape (lines, samples, bands), possibly containing NaN.

    Returns
    -------
    filled_data : np.ndarray
        New array with NaNs filled.  The input is NOT modified — the
        previous implementation took views of ``data`` and wrote through
        them, silently clobbering the caller's array.
    """
    filled_data = data.copy()
    for band in range(data.shape[2]):
        # View into the COPY, so mutations never touch the caller's array.
        band_data = filled_data[:, :, band]
        if not np.any(~np.isnan(band_data)):
            continue  # whole band is NaN; nothing to fill from
        for i in range(band_data.shape[0]):
            row_data = band_data[i, :]
            valid_indices = np.where(~np.isnan(row_data))[0]
            if len(valid_indices) == 0:
                continue  # all-NaN row stays NaN
            for j in np.where(np.isnan(row_data))[0]:
                # Nearest valid pixel; argmin breaks ties toward smaller index.
                nearest_idx = valid_indices[np.argmin(np.abs(valid_indices - j))]
                row_data[j] = row_data[nearest_idx]
    return filled_data
def save_transformed_data(data: np.ndarray, output_file: str, wavelengths: Optional[np.ndarray] = None,
                          geo_info: Optional[dict] = None):
    """Persist the transformed hyperspectral cube as an ENVI file.

    Writes ``<output_file>.dat`` plus a matching ``.hdr`` header, using
    GDAL when available and a pure-NumPy writer otherwise.

    Parameters
    ----------
    data : np.ndarray
        Cube of shape (lines, samples, bands).
    output_file : str
        Output path; its suffix is replaced with ``.dat``/``.hdr``.
    wavelengths : np.ndarray, optional
        Per-band wavelengths to record in the header.
    geo_info : dict, optional
        Geo-reference strings copied from the source HDR.
    """
    lines, samples, bands = data.shape
    target = Path(output_file)
    target.parent.mkdir(parents=True, exist_ok=True)
    bil_file = str(target.with_suffix('.dat'))
    hdr_file = str(target.with_suffix('.hdr'))
    try:
        writer = save_with_gdal_envi if GDAL_AVAILABLE else None
        if writer is not None:
            writer(data, bil_file, wavelengths, geo_info)
        else:
            save_with_numpy_envi(data, bil_file, hdr_file, wavelengths, geo_info)
        for message in (
            f"✅ 成功保存地理变换结果:",
            f"  数据文件: {bil_file}",
            f"  头文件: {hdr_file}",
            f"  数据尺寸: {lines} x {samples} x {bands}",
            f"  保存方式: {'GDAL' if GDAL_AVAILABLE else 'NumPy'}",
        ):
            print(message)
    except Exception as e:
        raise RuntimeError(f"保存文件失败: {output_file}, 错误: {e}")
def save_with_gdal_envi(data: np.ndarray, bil_file: str, wavelengths: Optional[np.ndarray] = None,
                        geo_info: Optional[dict] = None):
    """Write the cube to an ENVI file via GDAL (BSQ interleave, uint16).

    After the dataset is flushed, the HDR that GDAL generated is replaced
    by :func:`create_envi_header` so the header carries the wavelength
    list and geo-reference strings.

    NOTE(review): the dataset is created as GDT_UInt16 while band arrays
    are passed as float32 with no explicit scaling — confirm the implicit
    conversion on WriteArray matches the intended radiometric encoding.
    """
    lines, samples, bands = data.shape
    hdr_file = bil_file.replace('.dat', '.hdr')
    # Create the GDAL ENVI driver.
    driver = gdal.GetDriverByName('ENVI')
    # Create the dataset (BSQ interleave, ENVI data type 12 == uint16).
    dataset = driver.Create(bil_file, samples, lines, bands, gdal.GDT_UInt16,
                            options=['INTERLEAVE=BSQ'])
    if dataset is None:
        raise RuntimeError(f"无法创建ENVI数据集: {bil_file}")
    try:
        # Dataset-level metadata.
        metadata = dataset.GetMetadata()
        metadata['DESCRIPTION'] = 'Geometrically transformed hyperspectral data'
        metadata['SENSOR_TYPE'] = 'Hyperspectral'
        metadata['DATA_UNITS'] = 'Reflectance'
        metadata['PROCESSING_ALGORITHM'] = 'Geometric Transformation'
        metadata['CREATION_DATE'] = str(np.datetime64('now'))
        # Record per-band wavelengths when they match the band count.
        if wavelengths is not None and len(wavelengths) == bands:
            metadata['wavelength_units'] = 'nm'
            for i, wl in enumerate(wavelengths):
                metadata[f'wavelength_{i+1}'] = str(wl)
        dataset.SetMetadata(metadata)
        # Write the band data.
        for band_idx in range(bands):
            band = dataset.GetRasterBand(band_idx + 1)
            band_data = data[:, :, band_idx].astype(np.float32)
            band.WriteArray(band_data)
            band.SetNoDataValue(0.0)  # background pixels are zero
            # Band description: wavelength label when available.
            if wavelengths is not None and band_idx < len(wavelengths):
                band.SetDescription(f'{wavelengths[band_idx]:.1f} nm')
            else:
                band.SetDescription(f'Band {band_idx + 1}')
    finally:
        # Close (flush) the dataset.
        dataset = None
    # Overwrite the HDR GDAL auto-created, after the dataset is fully closed.
    import time
    time.sleep(0.1)  # brief pause to make sure GDAL finished writing
    create_envi_header(hdr_file, lines, samples, bands, wavelengths, geo_info)
def save_with_numpy_envi(data: np.ndarray, bil_file: str, hdr_file: str,
                         wavelengths: Optional[np.ndarray] = None, geo_info: Optional[dict] = None):
    """NumPy fallback writer used when GDAL is unavailable.

    Streams the cube to ``bil_file`` band-by-band in BSQ order as uint16
    (values clipped to [0, 65535]) and writes a matching ENVI header.
    """
    lines, samples, bands = data.shape
    with open(bil_file, 'wb') as out:
        # BSQ layout: all pixels of one band, then the next band.
        for plane in np.moveaxis(data, 2, 0):
            np.clip(plane, 0, 65535).astype(np.uint16).tofile(out)
    create_envi_header(hdr_file, lines, samples, bands, wavelengths, geo_info)
def create_envi_header(hdr_file: str, lines: int, samples: int, bands: int,
                       wavelengths: Optional[np.ndarray] = None, geo_info: Optional[dict] = None):
    """Write an ENVI ``.hdr`` header describing a uint16 BSQ cube.

    Emits the mandatory ENVI keywords, optional geo-reference entries
    copied verbatim from ``geo_info``, and — when ``wavelengths`` matches
    the band count — a wavelength list in nanometres.
    """
    header = [
        "ENVI\n",
        "description = {\n",
        "  Geometrically transformed hyperspectral data\n",
        "  Processed with Python geometric transformation}\n",
        f"samples = {samples}\n",
        f"lines = {lines}\n",
        f"bands = {bands}\n",
        "header offset = 0\n",
        "file type = ENVI Standard\n",
        "data type = 12\n",   # ENVI code 12 == unsigned 16-bit integer
        "interleave = bsq\n",
        "sensor type = Hyperspectral\n",
        "byte order = 0\n",   # little-endian
        "data ignore value = 0\n",
    ]
    # Geo-reference lines are pre-formatted strings copied from the source HDR.
    if geo_info:
        for key, keyword in (("map_info", "map info"),
                             ("coordinate_system_string", "coordinate system string"),
                             ("projection_info", "projection info")):
            if geo_info.get(key):
                header.append(f"{keyword} = {geo_info[key]}\n")
    # Wavelength list, only when it matches the band count.
    if wavelengths is not None and len(wavelengths) == bands:
        header.append("wavelength units = nm\n")
        header.append("wavelength = { " + ",".join(f"{wl}" for wl in wavelengths) + " }\n")
    with open(hdr_file, 'w', encoding='utf-8') as f:
        f.writelines(header)
def process_file_pair(geo_dat_file: str, bil_file: str, output_dir: str) -> bool:
    """Run the full transform pipeline for one matched file pair.

    Loads the geo-corrected lookup image and the original strip, applies
    the geometric transform, and saves the result to ``output_dir`` as
    ``<bil stem>_geo_corrected.dat``.

    Returns
    -------
    bool
        True on success; False when any step failed (error is printed).
    """
    try:
        lookup, lookup_meta = load_geo_corrected_dat(geo_dat_file)
        cube, cube_meta = load_original_bil(bil_file)
        warped = perform_geometric_transform(cube, lookup, lookup_meta)
        target = Path(output_dir) / f"{Path(bil_file).stem}_geo_corrected.dat"
        save_transformed_data(warped, str(target),
                              cube_meta.get('wavelengths'),
                              lookup_meta.get('geo_info'))
        return True
    except Exception as e:
        print(f"处理失败: {e}")
        return False
def find_matching_files(geo_dir: str, bil_dir: str) -> List[Tuple[str, str]]:
    """Pair geo-corrected files with their original radiance strips.

    Scans ``geo_dir`` for ``*_rad_rgbxyz_geo_angles_registered.bip`` and
    ``bil_dir`` for ``*_rad.bil``, keys every file with
    :func:`extract_file_key`, and returns the ``(geo_file, bil_file)``
    tuples whose keys coincide.
    """
    geo_by_key = {extract_file_key(p.name): str(p)
                  for p in Path(geo_dir).glob('*_rad_rgbxyz_geo_angles_registered.bip')}
    bil_by_key = {extract_file_key(p.name): str(p)
                  for p in Path(bil_dir).glob('*_rad.bil')}
    return [(geo_by_key[key], bil_by_key[key])
            for key in geo_by_key if key in bil_by_key]
def batch_process(geo_dir: str, bil_dir: str, output_dir: str) -> dict:
    """Process every matched (geo, bil) pair and tally the outcomes.

    Returns
    -------
    dict
        Counters with keys ``total``, ``success`` and ``failed``.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    pairs = find_matching_files(geo_dir, bil_dir)
    if not pairs:
        return {'total': 0, 'success': 0, 'failed': 0}
    stats = {'total': len(pairs), 'success': 0, 'failed': 0}
    for geo_file, bil_file in pairs:
        outcome = 'success' if process_file_pair(geo_file, bil_file, output_dir) else 'failed'
        stats[outcome] += 1
    return stats
def main():
    """Example driver: run the batch geo-transform on hard-coded paths."""
    # Example paths — adjust to the actual data layout before running.
    src_geo = r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\jiaozhen"
    src_bil = r"D:\BaiduNetdiskDownload\20250902\_3_52_52\Geoout\Radout"
    dst_dir = r"D:\BaiduNetdiskDownload\20250902\_3_52_52\316\cube"

    # The spectral package is required for reading the input imagery.
    if not SPECTRAL_AVAILABLE:
        print("错误: 需要安装spectral库")
        return

    results = batch_process(src_geo, src_bil, dst_dir)
    if results['success'] > 0:
        print(f"处理完成!成功处理了 {results['success']} 对文件")
    else:
        print("未成功处理任何文件,请检查文件路径和格式")
# Script entry point.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,814 @@
#!/usr/bin/env python3
"""
高光谱反射率校正工具
功能:
- 读取ENVI ASCII Plot File格式的反射率校正文件
- 读取航带文件夹中的高光谱数据文件 (.bil/.bip/.bsq等ENVI格式)
注意spectral库通过读取对应的.hdr头文件来访问数据
- 根据波长匹配进行反射率校正(航带数据 / 校正值)
- 保存校正后的反射率文件
依赖:
- numpy
- GDAL - 用于读取和保存ENVI格式高光谱数据
- pathlib
使用方法:
python redlence.py <航带文件夹路径> <校正文件路径> [输出文件夹路径]
文件要求:
- 航带文件夹中应包含 .bil/.bip/.bsq 文件及其对应的 .hdr 头文件
- 校正文件为ENVI ASCII Plot File格式包含波长和校正值两列数据
"""
import numpy as np
import os
import sys
from pathlib import Path
from typing import Tuple, List, Dict, Optional
import argparse
# 可选GDAL库用于读取和保存ENVI格式文件
try:
from osgeo import gdal
# GDAL性能优化配置
GDAL_AVAILABLE = True
except ImportError:
GDAL_AVAILABLE = False
print("警告: GDAL不可用请安装GDAL")
def parse_correction_file(correction_file: str) -> Tuple[np.ndarray, np.ndarray]:
    """Parse an ENVI ASCII Plot File containing reflectance corrections.

    Locates the first row that parses as two floats, then bulk-loads the
    remainder with :func:`numpy.loadtxt`.  Any failure falls back to the
    slower line-by-line parser.

    Parameters
    ----------
    correction_file : str
        Path to the correction file.

    Returns
    -------
    wavelengths : np.ndarray
        Wavelengths in nm.
    corrections : np.ndarray
        Correction values aligned with ``wavelengths``.
    """
    try:
        with open(correction_file, 'r', encoding='utf-8') as handle:
            raw_lines = handle.readlines()

        # Index of the first row that looks like "<float> <float>".
        data_start_idx = 0
        for idx, raw in enumerate(raw_lines):
            text = raw.strip()
            if not text or text.startswith(';'):
                continue
            if ('ENVI ASCII Plot File' in text or 'Column 1:' in text
                    or 'Column 2:' in text):
                continue
            try:
                fields = text.split()
                if len(fields) >= 2:
                    float(fields[0]), float(fields[1])
                    data_start_idx = idx
                    break
            except (ValueError, IndexError):
                continue

        if data_start_idx == 0 and not raw_lines:
            raise ValueError("校正文件中未找到有效的数据行")

        # Bulk-read everything from the first data row onward.
        table = np.loadtxt(correction_file, skiprows=data_start_idx, dtype=float)
        if table.ndim == 1:
            # A single data row comes back as a flat vector.
            wavelengths = np.array([table[0]])
            corrections = np.array([table[1]])
        else:
            wavelengths = table[:, 0]
            corrections = table[:, 1]

        print(f"✅ 成功解析校正文件: {correction_file}")
        print(f"  数据点数: {len(wavelengths)}")
        print(f"  波长范围: {wavelengths.min():.1f} - {wavelengths.max():.1f} nm")
        print(f"  校正值范围: {corrections.min():.6f} - {corrections.max():.6f}")
        return wavelengths, corrections
    except Exception as e:
        print(f"⚠️ numpy加速读取失败回退到逐行读取: {e}")
        return parse_correction_file_fallback(correction_file)
def parse_correction_file_fallback(correction_file: str) -> Tuple[np.ndarray, np.ndarray]:
    """Line-by-line fallback parser for the correction file.

    Keeps every line that parses as a ``wavelength value`` pair and
    returns both columns as float arrays.  Raises ``RuntimeError`` when
    the file cannot be read or contains no numeric rows.
    """
    try:
        pairs = []
        with open(correction_file, 'r', encoding='utf-8') as handle:
            for raw in handle:
                text = raw.strip()
                # Skip blanks, ';' comments and the known header lines.
                if not text or text.startswith(';'):
                    continue
                if ('ENVI ASCII Plot File' in text or 'Column 1:' in text
                        or 'Column 2:' in text):
                    continue
                fields = text.split()
                if len(fields) >= 2:
                    try:
                        pairs.append((float(fields[0]), float(fields[1])))
                    except (ValueError, IndexError):
                        # Non-numeric row in some other format — ignore it.
                        continue
        if not pairs:
            raise ValueError("校正文件中未找到有效的数据行")
        wavelengths = np.array([w for w, _ in pairs], dtype=float)
        corrections = np.array([c for _, c in pairs], dtype=float)
        return wavelengths, corrections
    except Exception as e:
        raise RuntimeError(f"解析校正文件失败: {correction_file}, 错误: {e}")
def find_hyperspectral_files(folder_path: str) -> List[str]:
    """List hyperspectral data files (.bil/.bip/.bsq/.dat) in a folder.

    GDAL locates the matching ``.hdr`` headers on its own, so only the
    data files are returned (sorted, de-duplicated).

    Raises
    ------
    FileNotFoundError
        If ``folder_path`` does not exist.
    """
    folder = Path(folder_path)
    if not folder.exists():
        raise FileNotFoundError(f"文件夹不存在: {folder_path}")
    matches = set()
    for ext in ('.bil', '.bip', '.bsq', '.dat'):
        matches.update(str(p) for p in folder.glob(f'*{ext}'))
    if not matches:
        print(f"警告: 在文件夹 {folder_path} 中未找到高光谱数据文件")
        print("支持的格式: .bil, .bip, .bsq")
    return sorted(matches)
def load_hyperspectral_data(file_path: str) -> Tuple[np.ndarray, Dict]:
    """Open a hyperspectral file with GDAL for streaming access.

    The full cube is NOT loaded; the open GDAL dataset handle is returned
    so callers can read row blocks on demand.

    Parameters
    ----------
    file_path : str
        Path to the data file (.bil/.bip/.bsq).

    Returns
    -------
    dataset : gdal.Dataset
        Open GDAL dataset for streamed reads.
    metadata : dict
        Dimensions, wavelengths (parsed from the HDR when present), paths.

    Raises
    ------
    RuntimeError
        If GDAL is unavailable or the file cannot be opened.
    """
    if not GDAL_AVAILABLE:
        raise RuntimeError("需要GDAL库来读取高光谱文件请安装GDAL")
    file_path = Path(file_path)
    try:
        # Open read-only through GDAL (ENVI driver resolves the header).
        dataset = gdal.Open(str(file_path), gdal.GA_ReadOnly)
        if dataset is None:
            raise RuntimeError(f"无法打开文件: {file_path}")
        # Basic dimensions.
        lines = dataset.RasterYSize
        samples = dataset.RasterXSize
        bands = dataset.RasterCount
        # Probe the usual HDR naming conventions next to the data file.
        hdr_file = None
        hdr_candidates = [
            file_path.with_suffix('.hdr'),  # datafile.hdr
            file_path.with_suffix(file_path.suffix + '.hdr'),  # datafile.ext.hdr
            file_path.parent / f"{file_path.name}.hdr",  # datafile.ext.hdr (alternate spelling)
        ]
        for candidate in hdr_candidates:
            if candidate.exists():
                hdr_file = candidate
                break
        # Try to pull wavelength information out of the HDR file.
        wavelengths = None
        if hdr_file and hdr_file.exists():
            try:
                # Read the raw HDR text.
                with open(str(hdr_file), 'r', encoding='utf-8', errors='ignore') as f:
                    hdr_content = f.read()
                # Extract the "wavelength = { ... }" list.
                import re
                wavelength_match = re.search(r'wavelength\s*=\s*\{([^}]+)\}', hdr_content, re.IGNORECASE)
                if wavelength_match:
                    wavelength_str = wavelength_match.group(1)
                    # Parse the comma-separated values, skipping non-numeric entries.
                    wavelength_values = []
                    for val in wavelength_str.split(','):
                        val = val.strip()
                        try:
                            wavelength_values.append(float(val))
                        except ValueError:
                            continue
                    if wavelength_values:
                        wavelengths = np.array(wavelength_values, dtype=float)
            except Exception as e:
                print(f"⚠️ 无法从HDR文件读取波长信息: {e}")
        # Assemble metadata.
        metadata = {
            'file_path': str(file_path),  # original data file path
            'hdr_file': str(hdr_file) if hdr_file else None,  # matching hdr path
            'lines': lines,
            'samples': samples,
            'bands': bands,
            'wavelengths': wavelengths,
            'data_type': 'float32',  # working dtype while streaming
            'interleave': 'unknown'  # GDAL does not expose interleave directly
        }
        print(f"✅ 成功打开高光谱文件: {Path(file_path).name}")
        if hdr_file:
            print(f"  HDR文件: {Path(hdr_file).name}")
        print(f"  数据尺寸: {metadata['lines']} x {metadata['samples']} x {metadata['bands']}")
        print(f"  数据类型: {metadata['data_type']} (流式)")
        if metadata['wavelengths'] is not None:
            print(f"  波长范围: {metadata['wavelengths'][0]:.1f} - {metadata['wavelengths'][-1]:.1f} nm")
        return dataset, metadata
    except Exception as e:
        raise RuntimeError(f"读取高光谱文件失败: {file_path}, 错误: {e}")
def interpolate_corrections(wavelengths_data: np.ndarray, wavelengths_corr: np.ndarray,
                            corrections: np.ndarray) -> np.ndarray:
    """Resample correction values onto the data file's wavelength grid.

    Uses SciPy linear interpolation (with extrapolation) when available,
    and falls back to :func:`numpy.interp` otherwise.

    Raises
    ------
    ValueError
        If the data file carries no wavelength information.
    """
    if wavelengths_data is None:
        raise ValueError("数据文件缺少波长信息,无法进行校正")
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        # SciPy missing: numpy.interp clamps at the ends but is adequate.
        print("警告: scipy不可用使用numpy进行线性插值")
        interpolated = np.interp(wavelengths_data, wavelengths_corr, corrections)
        print(f"✅ 使用numpy插值校正值到数据波长")
        print(f"  插值范围: {interpolated.min():.3f} - {interpolated.max():.3f}")
        return interpolated
    sampler = interp1d(wavelengths_corr, corrections, kind='linear',
                       bounds_error=False, fill_value='extrapolate')
    interpolated = sampler(wavelengths_data)
    print(f"✅ 成功插值校正值到数据波长")
    print(f"  数据波段数: {len(wavelengths_data)}")
    print(f"  校正数据点数: {len(wavelengths_corr)}")
    print(f"  插值范围: {interpolated.min():.3f} - {interpolated.max():.3f}")
    return interpolated
def apply_reflectance_correction_streaming(input_dataset, output_dataset, corrections: np.ndarray, block_size: int = 1024):
    """Apply the reflectance correction in streaming fashion.

    Reads row blocks from ``input_dataset``, scales them in a vectorized
    pass, and writes uint16 results to ``output_dataset`` without ever
    holding the full cube in memory.

    Parameters
    ----------
    input_dataset : gdal.Dataset
        Source dataset.
    output_dataset : gdal.Dataset
        Destination dataset (same dimensions, uint16 bands).
    corrections : np.ndarray
        Per-band correction values (bands,).
    block_size : int
        Rows processed per block.

    Raises
    ------
    ValueError
        If the band count does not match ``len(corrections)``.
    """
    lines = input_dataset.RasterYSize
    samples = input_dataset.RasterXSize
    bands = input_dataset.RasterCount
    if bands != len(corrections):
        raise ValueError(f"数据波段数 ({bands}) 与校正值数量 ({len(corrections)}) 不匹配")
    print("🔢 正在应用反射率校正(流式处理,向量化加速)...")
    # Multiply by the reciprocal instead of dividing per element; invalid
    # correction values (0/NaN/Inf) get a zero scale, zeroing the output.
    corrections = np.asarray(corrections, dtype=np.float32)
    scale = np.zeros((bands,), dtype=np.float32)
    valid = np.isfinite(corrections) & (corrections != 0)
    scale[valid] = np.float32(10000.0) / corrections[valid]
    scale_3d = scale[:, None, None]
    # Cache band handles up front to avoid repeated GetRasterBand calls.
    output_bands = [output_dataset.GetRasterBand(i + 1) for i in range(bands)]
    total_blocks = (lines + block_size - 1) // block_size
    # Per block: read all bands at once and compute vectorized.
    for block_idx, y_start in enumerate(range(0, lines, block_size), start=1):
        y_end = min(y_start + block_size, lines)
        actual_block_size = y_end - y_start
        # Print progress sparsely so console I/O does not dominate runtime.
        if block_idx == 1 or block_idx == total_blocks or block_idx % 10 == 0:
            print(f"  处理块 {block_idx}/{total_blocks}: 行 {y_start}-{y_end-1} ({actual_block_size} 行)")
        block = input_dataset.ReadAsArray(0, y_start, samples, actual_block_size)
        if block is None:
            raise RuntimeError(f"读取数据块失败: y_start={y_start}, block_size={actual_block_size}")
        # Normalize to (bands, block_y, samples); single-band reads come back 2-D.
        if block.ndim == 2:
            block = block[np.newaxis, :, :]
        block_f = block.astype(np.float32, copy=False)
        np.multiply(block_f, scale_3d, out=block_f, casting='unsafe')
        np.clip(block_f, 0, 65535, out=block_f)
        block_u16 = block_f.astype(np.uint16, copy=False)
        # Write per band through GDAL.
        for band_idx in range(bands):
            output_bands[band_idx].WriteArray(block_u16[band_idx, :, :], 0, y_start)
    print("✅ 成功应用反射率校正(流式处理,向量化加速)")
def save_corrected_data_streaming(input_dataset, corrections: np.ndarray, output_file: str,
                                  wavelengths: Optional[np.ndarray] = None, source_hdr: Optional[str] = None):
    """Create the output ENVI dataset and stream the corrected data into it.

    Parameters
    ----------
    input_dataset : gdal.Dataset
        Source dataset to read from.
    corrections : np.ndarray
        Per-band correction values.
    output_file : str
        Output path (suffix replaced by ``.dat``/``.hdr``).
    wavelengths : np.ndarray, optional
        Wavelengths for the synthesized header (only used when
        ``source_hdr`` is not copied).
    source_hdr : str, optional
        Source HDR file; copied verbatim next to the output when present.

    Raises
    ------
    RuntimeError
        If the dataset cannot be created or any write step fails.
    """
    lines = input_dataset.RasterYSize
    samples = input_dataset.RasterXSize
    bands = input_dataset.RasterCount
    # Make sure the output directory exists.
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Output file paths.
    bil_file = str(output_path.with_suffix('.dat'))
    hdr_file = str(output_path.with_suffix('.hdr'))
    try:
        # Create the output dataset (BSQ interleave, uint16 samples).
        driver = gdal.GetDriverByName('ENVI')
        output_dataset = driver.Create(bil_file, samples, lines, bands, gdal.GDT_UInt16,
                                       options=['INTERLEAVE=BSQ'])
        if output_dataset is None:
            raise RuntimeError(f"无法创建ENVI数据集: {bil_file}")
        # Mark 0 as NoData on every band.
        for band_idx in range(bands):
            output_band = output_dataset.GetRasterBand(band_idx + 1)
            output_band.SetNoDataValue(0)
        # Stream the correction block-by-block into the new dataset.
        apply_reflectance_correction_streaming(input_dataset, output_dataset, corrections)
        # Close (flush) the output dataset.
        output_dataset = None
        # HDR handling: prefer copying the source header verbatim.
        if source_hdr and Path(source_hdr).exists():
            import shutil
            shutil.copy2(source_hdr, hdr_file)
            print(f"✅ 已复制源HDR文件: {Path(source_hdr).name}")
        else:
            # Otherwise synthesize a header from what we know.
            create_envi_header(hdr_file, lines, samples, bands, wavelengths, None)
        print(f"✅ 成功保存校正结果:")
        print(f"  数据文件: {bil_file}")
        print(f"  头文件: {hdr_file}")
        print(f"  数据尺寸: {lines} x {samples} x {bands}")
        print(f"  数据类型: uint16 (反射率x10000)")
        print(f"  处理方式: 流式处理")
    except Exception as e:
        raise RuntimeError(f"保存文件失败: {output_file}, 错误: {e}")
def save_with_gdal(data: np.ndarray, bil_file: str, wavelengths: Optional[np.ndarray] = None,
                   source_file: Optional[str] = None, source_hdr: Optional[str] = None):
    """
    Save an ENVI-format file using GDAL.

    Parameters
    ----------
    data : np.ndarray
        Reflectance cube shaped (lines, samples, bands); values are scaled
        by 10000 and stored as uint16.
    bil_file : str
        Output data file path (normally ending in ``.dat``).
    wavelengths : np.ndarray, optional
        Per-band wavelengths recorded in metadata and in the header.
    source_file : str, optional
        Original input file name recorded in metadata.
    source_hdr : str, optional
        Unused here; kept for signature compatibility with save_with_numpy.
    """
    lines, samples, bands = data.shape
    # Bug fix: derive the header path from the file *suffix* only. The
    # previous str.replace('.dat', '.hdr') replaced every occurrence of
    # ".dat" anywhere in the path (breaking e.g. "/data.dat/x.dat") and
    # produced hdr_file == bil_file when the name had a different extension,
    # which would clobber the data file with header text.
    bil_path = Path(bil_file)
    hdr_file = str(bil_path.with_suffix('.hdr')) if bil_path.suffix else str(bil_path) + '.hdr'
    # Create the GDAL ENVI dataset - uint16 keeps files compact and fast
    driver = gdal.GetDriverByName('ENVI')
    dataset = driver.Create(bil_file, samples, lines, bands, gdal.GDT_UInt16,
                            options=['INTERLEAVE=BSQ'])
    if dataset is None:
        raise RuntimeError(f"无法创建ENVI数据集: {bil_file}")
    try:
        # Dataset-level metadata
        metadata = dataset.GetMetadata()
        metadata['DESCRIPTION'] = 'Reflectance corrected hyperspectral data using Python'
        metadata['SENSOR_TYPE'] = 'Hyperspectral'
        metadata['DATA_UNITS'] = 'Reflectance'
        metadata['PROCESSING_ALGORITHM'] = 'Reflectance Correction'
        metadata['CREATION_DATE'] = str(np.datetime64('now'))
        if source_file:
            metadata['SOURCE_FILE'] = Path(source_file).name
        # Wavelength metadata, one entry per band
        if wavelengths is not None and len(wavelengths) == bands:
            metadata['wavelength_units'] = 'nm'
            for i, wl in enumerate(wavelengths):
                metadata[f'wavelength_{i+1}'] = str(wl)
        dataset.SetMetadata(metadata)
        # Write the pixel data, scaled to uint16
        print(f"💾 正在写入 {bands} 个波段的数据...")
        # Reflectance -> uint16: x10000 keeps 4 decimal digits, clipped to range
        data_uint16 = np.clip(data * 10000, 0, 65535).astype(np.uint16, copy=False)
        for band_idx in range(bands):
            band = dataset.GetRasterBand(band_idx + 1)
            band.WriteArray(data_uint16[:, :, band_idx])
            band.SetNoDataValue(0)  # NoData value for the uint16 output
            # Progress every 100 bands to limit console noise
            if bands >= 100 and (band_idx + 1) % 100 == 0:
                print(f" 已写入 {band_idx + 1}/{bands} 个波段")
        print(f"✅ 数据写入完成 ({bands} 个波段)")
        # Write the richer custom header.
        # NOTE(review): GDAL may rewrite the .hdr when the dataset is closed
        # below - confirm the custom header survives the close.
        create_envi_header(hdr_file, lines, samples, bands, wavelengths, source_file)
    finally:
        # Close (flush) the dataset
        dataset = None
def save_with_numpy(data: np.ndarray, bil_file: str, hdr_file: str,
                    wavelengths: Optional[np.ndarray] = None, source_file: Optional[str] = None, source_hdr: Optional[str] = None):
    """
    Save an ENVI-format file with plain numpy - the fallback path used when
    GDAL is unavailable (optimized version).
    """
    lines, samples, bands = data.shape
    print(f"💾 正在保存 {bands} 个波段的数据...")
    # Scale reflectance by 10000 (keeps 4 decimal digits of precision),
    # clamp to the uint16 range, then dump the raw binary in one shot.
    scaled = np.clip(data * 10000, 0, 65535).astype(np.uint16, copy=False)
    with open(bil_file, 'wb') as binary_out:
        scaled.tofile(binary_out)
    print(f"✅ 数据文件写入完成")
    if source_hdr and Path(source_hdr).exists():
        # A source header was supplied - reuse it verbatim.
        import shutil
        shutil.copy2(source_hdr, hdr_file)
        print(f"✅ 已复制源HDR文件: {Path(source_hdr).name}")
    else:
        # Otherwise synthesize a fresh ENVI header.
        create_envi_header(hdr_file, lines, samples, bands, wavelengths, source_file)
def create_envi_header(hdr_file: str, lines: int, samples: int, bands: int,
                       wavelengths: Optional[np.ndarray] = None, source_file: Optional[str] = None):
    """
    Create an ENVI-format HDR header file.
    """
    # Assemble the full header text first, then write it in one go.
    parts = [
        "ENVI\n",
        "description = {\n",
        " Reflectance corrected hyperspectral data\n",
        " Processed with Python reflectance correction}\n",
        f"samples = {samples}\n",
        f"lines = {lines}\n",
        f"bands = {bands}\n",
        "header offset = 0\n",
        "file type = ENVI Standard\n",
        "data type = 12\n",                       # 12 == uint16
        "interleave = bsq\n",
        "sensor type = Hyperspectral\n",
        "byte order = 0\n",                       # little-endian
        "reflectance scale factor = 10000\n",     # reflectance scaling factor
    ]
    # Wavelength and band-name sections only when the counts agree
    if wavelengths is not None and len(wavelengths) == bands:
        last = len(wavelengths) - 1
        parts.append("wavelength units = nm\n")
        parts.append("wavelength = {\n")
        for i, wl in enumerate(wavelengths):
            parts.append(f" {wl}")
            if i < last:
                parts.append(",")
            if (i + 1) % 10 == 0:                 # newline every 10 wavelengths
                parts.append("\n")
        parts.append("}\n")
        parts.append("band names = {\n")
        for i, wl in enumerate(wavelengths):
            parts.append(f" {wl:.1f} nm")
            if i < last:
                parts.append(",")
            if (i + 1) % 8 == 0:                  # newline every 8 band names
                parts.append("\n")
        parts.append("}\n")
    with open(hdr_file, 'w', encoding='utf-8') as hdr:
        hdr.write("".join(parts))
def process_single_file(hyp_file: str, wavelengths_corr: np.ndarray, corrections: np.ndarray,
                        output_dir: str) -> bool:
    """
    Process a single hyperspectral file.

    Parameters
    ----------
    hyp_file : str
        Path to the hyperspectral file.
    wavelengths_corr : np.ndarray
        Wavelengths of the correction file.
    corrections : np.ndarray
        Correction values.
    output_dir : str
        Output directory.

    Returns
    -------
    success : bool
        Whether processing succeeded.
    """
    file_label = Path(hyp_file).name
    try:
        print(f"\n🔄 处理文件: {file_label}")
        # Open the hyperspectral data (streaming - pixels stay on disk)
        input_dataset, metadata = load_hyperspectral_data(hyp_file)
        try:
            wavelengths_data = metadata.get('wavelengths')
            if wavelengths_data is None:
                print(f"⚠️ 跳过文件 {file_label}: 缺少波长信息")
                return False
            wavelengths_data = np.array(wavelengths_data)
            # Reuse the corrections directly when the wavelength grids match;
            # otherwise interpolate them onto the data's wavelengths.
            same_grid = (len(wavelengths_data) == len(wavelengths_corr)
                         and np.allclose(wavelengths_data, wavelengths_corr, rtol=1e-6))
            if same_grid:
                print("✅ 数据波长与校正波长完全一致,无需插值")
                interpolated_corrections = corrections
            else:
                print("🔄 数据波长与校正波长不一致,进行插值")
                interpolated_corrections = interpolate_corrections(
                    wavelengths_data, wavelengths_corr, corrections
                )
            # Derive the output path from the input file name
            output_file = Path(output_dir) / f"{Path(hyp_file).stem}_reflectance"
            # Stream the corrected data to disk (correction applied on the fly)
            save_corrected_data_streaming(input_dataset, interpolated_corrections, str(output_file),
                                          wavelengths_data, metadata.get('hdr_file'))
        finally:
            # Release the input dataset no matter what happened above
            input_dataset = None
        print(f"✅ 成功处理文件: {file_label}")
        return True
    except Exception as e:
        print(f"❌ 处理文件失败: {file_label}, 错误: {e}")
        return False
def batch_process(hyperspectral_dir: str, correction_file: str, output_dir: str) -> Dict[str, int]:
    """
    Batch-process every hyperspectral file in a folder.

    Parameters
    ----------
    hyperspectral_dir : str
        Folder containing the hyperspectral files.
    correction_file : str
        Path to the correction file.
    output_dir : str
        Output directory.

    Returns
    -------
    results : dict
        Processing statistics with keys 'total', 'success' and 'failed'.
    """
    banner = "=" * 60
    print(banner)
    print("🏁 开始高光谱反射率校正批量处理")
    print(banner)
    # GDAL is mandatory for the streaming reader/writer
    if not GDAL_AVAILABLE:
        print("❌ 错误: 需要安装GDAL库")
        return {'total': 0, 'success': 0, 'failed': 0}
    # Make sure the output directory exists
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    try:
        # Parse the correction file
        print("📖 解析校正文件...")
        wavelengths_corr, corrections = parse_correction_file(correction_file)
        # Locate candidate hyperspectral files
        print(f"\n📂 查找高光谱文件...")
        hyp_files = find_hyperspectral_files(hyperspectral_dir)
        if not hyp_files:
            print("❌ 未找到高光谱文件,处理终止")
            return {'total': 0, 'success': 0, 'failed': 0}
        print(f"找到 {len(hyp_files)} 个高光谱数据文件:")
        for f in hyp_files:
            print(f" - {Path(f).name} (需要对应的.hdr头文件)")
        # Process each file, tallying successes and failures
        results = {'total': len(hyp_files), 'success': 0, 'failed': 0}
        for hyp_file in hyp_files:
            if process_single_file(hyp_file, wavelengths_corr, corrections, output_dir):
                results['success'] += 1
            else:
                results['failed'] += 1
        # Summary
        print("\n" + banner)
        print("📊 处理完成总结:")
        print(f" 总文件数: {results['total']}")
        print(f" 成功处理: {results['success']}")
        print(f" 处理失败: {results['failed']}")
        print(f" 输出目录: {output_dir}")
        print(banner)
        return results
    except Exception as e:
        print(f"❌ 批量处理失败: {e}")
        return {'total': 0, 'success': 0, 'failed': 0}
def main():
    """
    Entry point - command-line interface.
    """
    parser = argparse.ArgumentParser(
        description='高光谱反射率校正工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
文件要求:
- 航带文件夹中应包含ENVI格式的高光谱数据文件(.bil/.bip/.bsq)和对应的头文件(.hdr)
- 校正文件应为ENVI ASCII Plot File格式包含波长和校正值两列数据
使用示例:
python redlence.py /path/to/hyperspectral/folder /path/to/correction.txt
python redlence.py /path/to/hyperspectral/folder /path/to/correction.txt /path/to/output/folder
"""
    )
    parser.add_argument('hyperspectral_dir', help='包含高光谱文件的文件夹路径')
    parser.add_argument('correction_file', help='反射率校正文件路径 (ENVI ASCII Plot File格式)')
    parser.add_argument('output_dir', nargs='?', default=None,
                        help='输出目录路径 (可选默认在输入文件夹下创建output文件夹)')
    args = parser.parse_args()
    # Default output directory lives inside the input folder
    if args.output_dir is None:
        args.output_dir = str(Path(args.hyperspectral_dir) / 'reflectance_output')
    # Validate inputs up front (guard clauses)
    if not Path(args.correction_file).exists():
        print(f"❌ 校正文件不存在: {args.correction_file}")
        return 1
    if not Path(args.hyperspectral_dir).exists():
        print(f"❌ 高光谱文件夹不存在: {args.hyperspectral_dir}")
        return 1
    # Run the batch and map the outcome to an exit code
    results = batch_process(args.hyperspectral_dir, args.correction_file, args.output_dir)
    if results['success'] > 0:
        print("✅ 处理完成!")
        return 0
    print("❌ 没有成功处理任何文件")
    return 1
if __name__ == "__main__":
exit(main())

View File

@ -0,0 +1,163 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.masks import mask_dict
warnings.filterwarnings("ignore")
def main():
    """Drive trait estimation: load each image into a ray actor, then run
    ``apply_trait_models`` on every actor in parallel."""
    config_file = sys.argv[1]
    with open(config_file, 'r') as config_json:
        config_dict = json.load(config_json)
    images = config_dict["input_files"]
    # Restart ray with the configured CPU budget
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for _ in images]
    # Load data (ancillary files are needed for the ENVI-style formats)
    file_type = config_dict['file_type']
    if file_type in ('envi', 'emit', 'ncav'):
        anc_files = config_dict["anc_files"]
        _ = ray.get([actor.read_file.remote(image, file_type, anc_files[image])
                     for actor, image in zip(actors, images)])
    elif file_type == 'neon':
        _ = ray.get([actor.read_file.remote(image, file_type)
                     for actor, image in zip(actors, images)])
    # Announce the traits that will be estimated
    print("Estimating %s traits:" % len(config_dict['trait_models']))
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        print("\t %s" % trait_model["name"])
    _ = ray.get([actor.do.remote(apply_trait_models, config_dict) for actor in actors])
    ray.shutdown()
def apply_trait_models(hy_obj,config_dict):
    '''Apply trait model(s) to image and export to file.

    For each configured trait model, predicts the trait per pixel as the
    mean/std over the model's coefficient ensemble and writes an ENVI image
    whose bands are [trait_mean, trait_std, range_mask, <one band per
    configured mask>].
    '''
    hy_obj.create_bad_bands(config_dict['bad_bands'])
    hy_obj.corrections = config_dict['corrections']
    # Load correction coefficients
    if 'topo' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['topo'][hy_obj.file_name],'topo')
    if 'brdf' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['brdf'][hy_obj.file_name],'brdf')
    hy_obj.resampler['type'] = config_dict["resampling"]['type']
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        # Ensemble of linear models: one row of coefficients per member
        coeffs = np.array(trait_model['model']['coefficients'])
        intercept = np.array(trait_model['model']['intercepts'])
        model_waves = np.array(trait_model['wavelengths'])
        # Check if wavelengths match; resample only when the image does not
        # already contain every model wavelength
        resample = not all(x in hy_obj.wavelengths for x in model_waves)
        if resample:
            hy_obj.resampler['out_waves'] = model_waves
            hy_obj.resampler['out_fwhm'] = trait_model['fwhm']
        else:
            # Indices of the model wavelengths within the image band list
            wave_mask = [np.argwhere(x==hy_obj.wavelengths)[0][0] for x in model_waves]
        # Build trait image file
        header_dict = hy_obj.get_header()
        header_dict['wavelength'] = []
        header_dict['data ignore value'] = -9999
        header_dict['data type'] = 4  # 4 == float32 in ENVI
        header_dict['trait unit'] = trait_model['units']
        header_dict['band names'] = ["%s_mean" % trait_model["name"],
                                     "%s_std" % trait_model["name"],
                                     'range_mask'] + [mask[0] for mask in config_dict['masks']]
        header_dict['bands'] = len(header_dict['band names'] )
        #Generate masks
        for mask,args in config_dict['masks']:
            mask_function = mask_dict[mask]
            hy_obj.gen_mask(mask_function,mask,args)
        output_name = config_dict['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0] + "_%s" % trait_model["name"]
        writer = WriteENVI(output_name,header_dict)
        # Chunk shape is tuned per input format
        if config_dict['file_type'] == 'envi' or config_dict['file_type'] == 'emit':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (2,hy_obj.columns),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'neon':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (int(np.ceil(hy_obj.lines/32)),int(np.ceil(hy_obj.columns/32))),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'ncav':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (256,hy_obj.columns),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        while not iterator.complete:
            chunk = iterator.read_next()
            if not resample:
                # Subset to the model wavelengths when no resampling was done
                chunk = chunk[:,:,wave_mask]
            trait_est = np.zeros((chunk.shape[0],
                                  chunk.shape[1],
                                  header_dict['bands']))
            # Apply spectrum transforms
            for transform in trait_model['model']["transform"]:
                if transform== "vector":
                    # Vector (L2) normalization of each spectrum
                    norm = np.linalg.norm(chunk,axis=2)
                    chunk = chunk/norm[:,:,np.newaxis]
                if transform == "absorb":
                    # Pseudo-absorbance: log(1/R)
                    chunk = np.log(1/chunk)
                if transform == "mean":
                    # Mean normalization of each spectrum
                    mean = chunk.mean(axis=2)
                    chunk = chunk/mean[:,:,np.newaxis]
            # One prediction per ensemble member, then summarize mean/std
            trait_pred = np.einsum('jkl,ml->jkm',chunk,coeffs, optimize='optimal')
            trait_pred = trait_pred + intercept
            trait_est[:,:,0] = trait_pred.mean(axis=2)
            trait_est[:,:,1] = trait_pred.std(ddof=1,axis=2)
            # Flag predictions inside the model's valid range
            range_mask = (trait_est[:,:,0] > trait_model["model_diagnostics"]['min']) & \
                         (trait_est[:,:,0] < trait_model["model_diagnostics"]['max'])
            trait_est[:,:,2] = range_mask.astype(int)
            # Subset and assign custom masks
            for i,(mask,args) in enumerate(config_dict['masks']):
                mask = hy_obj.mask[mask][iterator.current_line:iterator.current_line+chunk.shape[0],
                                         iterator.current_column:iterator.current_column+chunk.shape[1]]
                trait_est[:,:,3+i] = mask.astype(int)
            # No-data pixels get the ignore value in every band
            nd_mask = hy_obj.mask['no_data'][iterator.current_line:iterator.current_line+chunk.shape[0],
                                             iterator.current_column:iterator.current_column+chunk.shape[1]]
            trait_est[~nd_mask] = -9999
            writer.write_chunk(trait_est,
                               iterator.current_line,
                               iterator.current_column)
        writer.close()

# Script entry point
if __name__== "__main__":
    main()

View File

@ -0,0 +1,268 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.io.netcdf import *
from hytools.masks import mask_dict
warnings.filterwarnings("ignore")
def main():
    """Command-line driver.

    Usage: ``python script.py config.json [metadata.json]``

    Loads the JSON run configuration (and optional outside metadata), spins
    up one ray actor per input image, loads each image with its
    ancillary/GLT data, then runs ``apply_trait_models`` on every actor.
    """
    config_file = sys.argv[1]
    with open(config_file, 'r') as outfile:
        config_dict = json.load(outfile)
    # Optional second argument: external metadata JSON attached to outputs
    if len(sys.argv)>2:
        meta_file = sys.argv[2]
        with open(meta_file, 'r') as outfile:
            meta_dict = json.load(outfile)
        config_dict["outside_metadata"] = meta_dict
    else:
        if "outside_metadata" in config_dict:
            # The config may hold either a dict or a path to a JSON file
            if not isinstance(config_dict["outside_metadata"],dict):
                with open(config_dict["outside_metadata"], 'r') as outfile:
                    # load json and replace it by a dict
                    meta_dict = json.load(outfile)
                config_dict["outside_metadata"] = meta_dict
        else:
            config_dict["outside_metadata"] = None
    images= config_dict["input_files"]
    # Restart ray with the configured CPU budget
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus = config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for image in images]
    # Load data; each format needs a different read_file call
    if config_dict['file_type'] == 'envi':
        anc_files = config_dict["anc_files"]
        _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                        anc_files[image]) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'neon':
        _ = ray.get([a.read_file.remote(image,config_dict['file_type']) for a,image in zip(actors,images)])
    elif config_dict['file_type'] == 'emit' or config_dict['file_type'] == 'ncav':
        anc_files = config_dict["anc_files"]
        # GLT (geometry lookup table) files are optional for these formats
        if bool(config_dict["glt_files"]):
            glt_files = config_dict["glt_files"]
            _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                            anc_path=anc_files[image],glt_path=glt_files[image]) for a,image in zip(actors,images)])
        else:
            _ = ray.get([a.read_file.remote(image,config_dict['file_type'],
                                            anc_path=anc_files[image]) for a,image in zip(actors,images)])
    else:
        print("Image file type is not recognized.")
        return
    # Output format defaults to ENVI; "netcdf" is the other supported option
    default_export_type = "envi"
    if "export_type" in config_dict:
        if not config_dict["export_type"] in ["envi","netcdf"]:
            print("Image export file type is not recognized.")
            return
    else:
        config_dict["export_type"]=default_export_type
    if not "use_glt" in config_dict:
        config_dict["use_glt"]=False
    # Announce the traits that will be estimated
    print("Estimating %s traits:" % len( config_dict['trait_models']))
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        print("\t %s" % trait_model["name"])
    _ = ray.get([a.do.remote(apply_trait_models,config_dict) for a in actors])
    ray.shutdown()
def apply_trait_models(hy_obj,config_dict):
    '''Apply trait model(s) to image and export to file.

    For each configured trait model, predicts the trait per pixel as the
    mean/std over the model's coefficient ensemble. Output bands are
    [trait_mean, trait_std, range_mask, <one band per configured mask>],
    exported as ENVI or NetCDF (optionally orthorectified through a GLT).
    '''
    hy_obj.create_bad_bands(config_dict['bad_bands'])
    hy_obj.corrections = config_dict['corrections']
    # Load correction coefficients
    if 'topo' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['topo'][hy_obj.file_name],'topo')
    if 'brdf' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['brdf'][hy_obj.file_name],'brdf')
    hy_obj.resampler['type'] = config_dict["resampling"]['type']
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        # Ensemble of linear models: one row of coefficients per member
        coeffs = np.array(trait_model['model']['coefficients'])
        intercept = np.array(trait_model['model']['intercepts'])
        model_waves = np.array(trait_model['wavelengths'])
        # Check if wavelengths match; resample only when the image does not
        # already contain every model wavelength
        resample = not all(x in hy_obj.wavelengths for x in model_waves)
        if resample:
            hy_obj.resampler['out_waves'] = model_waves
            hy_obj.resampler['out_fwhm'] = trait_model['fwhm']
        else:
            # Indices of the model wavelengths within the image band list
            wave_mask = [np.argwhere(x==hy_obj.wavelengths)[0][0] for x in model_waves]
        # Optionally build the header on the orthorectified (GLT) grid
        use_glt_output_bool=False
        if 'use_glt' in config_dict:
            use_glt_output_bool = config_dict['use_glt']
            if use_glt_output_bool==True:
                header_dict = hy_obj.get_header(warp_glt=True)
            else:
                header_dict = hy_obj.get_header()
        else:
            header_dict = hy_obj.get_header()
        # Build trait image file
        header_dict['wavelength'] = []
        header_dict['data ignore value'] = -9999
        header_dict['data type'] = 4  # 4 == float32 in ENVI
        header_dict['trait unit'] = trait_model['units']
        header_dict['band names'] = ["%s_mean" % trait_model["name"],
                                     "%s_std" % trait_model["name"],
                                     'range_mask'] + [mask[0] for mask in config_dict['masks']]
        header_dict['bands'] = len(header_dict['band names'] )
        header_dict['file_type'] = config_dict['file_type']
        header_dict['transform'] = hy_obj.transform
        header_dict['projection'] = hy_obj.projection
        #Generate masks
        for mask,args in config_dict['masks']:
            mask_function = mask_dict[mask]
            hy_obj.gen_mask(mask_function,mask,args)
        output_name = config_dict['output_dir']
        if config_dict["export_type"]=="envi":
            output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0] + "_%s" % trait_model["name"]
            writer = WriteENVI(output_name,header_dict)
        else:
            output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0] + "_%s.nc" % trait_model["name"]
            header_dict['lines_glt'] = hy_obj.lines_glt
            header_dict['samples_glt'] = hy_obj.columns_glt
            writer = WriteNetCDF(output_name,header_dict,
                                 attr_dict=None,
                                 glt_bool=use_glt_output_bool,
                                 type_tag="trait",
                                 band_name=trait_model["name"])
            # Keep the GLT arrays alongside unwarped EMIT output
            if (not use_glt_output_bool) and config_dict['file_type'] == 'emit':
                writer.write_glt_dataset(hy_obj.glt_x,hy_obj.glt_y,dim_x_name="ortho_x",dim_y_name="ortho_y")
        # Chunk shape is tuned per input format
        if config_dict['file_type'] == 'envi' or config_dict['file_type'] == 'emit':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (2,hy_obj.columns),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'neon':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (int(np.ceil(hy_obj.lines/32)),int(np.ceil(hy_obj.columns/32))),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        elif config_dict['file_type'] == 'ncav':
            iterator = hy_obj.iterate(by = 'chunk',
                                      chunk_size = (256,hy_obj.columns),
                                      corrections = hy_obj.corrections,
                                      resample=resample)
        # Whole-image buffer, band-major: (bands, lines, samples)
        out_stack = np.zeros((header_dict['bands'],header_dict['lines'],header_dict['samples'])).astype(np.float32)
        while not iterator.complete:
            chunk = iterator.read_next()
            if not resample:
                # Subset to the model wavelengths when no resampling was done
                chunk = chunk[:,:,wave_mask]
            trait_est = np.zeros((chunk.shape[0],
                                  chunk.shape[1],
                                  header_dict['bands']))
            # Apply spectrum transforms
            for transform in trait_model['model']["transform"]:
                if transform== "vector":
                    # Vector (L2) normalization of each spectrum
                    norm = np.linalg.norm(chunk,axis=2)
                    chunk = chunk/norm[:,:,np.newaxis]
                if transform == "absorb":
                    # Pseudo-absorbance: log(1/R)
                    chunk = np.log(1/chunk)
                if transform == "mean":
                    # Mean normalization of each spectrum
                    mean = chunk.mean(axis=2)
                    chunk = chunk/mean[:,:,np.newaxis]
            # One prediction per ensemble member, then summarize mean/std
            trait_pred = np.einsum('jkl,ml->jkm',chunk,coeffs, optimize='optimal')
            trait_pred = trait_pred + intercept
            trait_est[:,:,0] = trait_pred.mean(axis=2)
            trait_est[:,:,1] = trait_pred.std(ddof=1,axis=2)
            # Flag predictions inside the model's valid range
            range_mask = (trait_est[:,:,0] > trait_model["model_diagnostics"]['min']) & \
                         (trait_est[:,:,0] < trait_model["model_diagnostics"]['max'])
            trait_est[:,:,2] = range_mask.astype(int)
            # Subset and assign custom masks
            for i,(mask,args) in enumerate(config_dict['masks']):
                mask = hy_obj.mask[mask][iterator.current_line:iterator.current_line+chunk.shape[0],
                                         iterator.current_column:iterator.current_column+chunk.shape[1]]
                trait_est[:,:,3+i] = mask.astype(int)
            # No-data pixels: ignore value in the trait bands, 255 in the masks
            nd_mask = hy_obj.mask['no_data'][iterator.current_line:iterator.current_line+chunk.shape[0],
                                             iterator.current_column:iterator.current_column+chunk.shape[1]]
            trait_est[~nd_mask,:2] = -9999
            trait_est[~nd_mask,2:] = 255
            # Accumulate the chunk into the full-image stack
            x_start = iterator.current_column
            x_end = iterator.current_column + trait_est.shape[1]
            y_start = iterator.current_line
            y_end = iterator.current_line + trait_est.shape[0]
            out_stack[:,y_start:y_end,x_start:x_end] = np.moveaxis(trait_est,-1,0)
        if use_glt_output_bool:
            # Warp the trait bands through the GLT, then each mask band gets
            # its own NetCDF writer carrying the outside metadata.
            for iband in range(2):
                writer.write_netcdf_band_glt(out_stack[iband,:,:],iband, (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
            writer.close()
            for iband in range(len(header_dict['band names'][2:])):
                writer = WriteNetCDF(output_name,header_dict,
                                     attr_dict=config_dict["outside_metadata"],
                                     glt_bool=use_glt_output_bool,
                                     type_tag="mask",
                                     band_name=header_dict['band names'][2:][iband])
                writer.write_mask_band_glt(out_stack[2+iband,:,:], (hy_obj.glt_y[hy_obj.fill_mask]-1,hy_obj.glt_x[hy_obj.fill_mask]-1),hy_obj.fill_mask)
                writer.close()
        else:
            # Unwarped output: write the trait bands, then the mask bands.
            # NOTE(review): mask bands are always written via WriteNetCDF,
            # even when export_type is "envi" - confirm this is intended.
            for iband in range(2):
                writer.write_band(out_stack[iband,:,:],iband)
            writer.close()
            for iband in range(len(header_dict['band names'][2:])):
                writer = WriteNetCDF(output_name,header_dict,
                                     attr_dict=config_dict["outside_metadata"],
                                     glt_bool=use_glt_output_bool,
                                     type_tag="mask",
                                     band_name=header_dict['band names'][2:][iband])
                writer.write_mask_band(out_stack[2+iband,:,:])
                writer.close()

# Script entry point
if __name__== "__main__":
    main()

View File

@ -0,0 +1,153 @@
import json
import os
import warnings
import sys
import ray
import numpy as np
import hytools as ht
from hytools.io.envi import *
from hytools.masks import mask_dict
warnings.filterwarnings("ignore")
def main():
    """Load images into ray actors and run trait estimation on each one."""
    config_file = sys.argv[1]
    with open(config_file, 'r') as config_json:
        config_dict = json.load(config_json)
    images = config_dict["input_files"]
    # Restart ray with the configured CPU count
    if ray.is_initialized():
        ray.shutdown()
    print("Using %s CPUs." % config_dict['num_cpus'])
    ray.init(num_cpus=config_dict['num_cpus'])
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for _ in images]
    # Load data; ancillary files are only required when corrections are on
    file_type = config_dict['file_type']
    if file_type == 'envi':
        needs_anc = ('topo' in config_dict['corrections']) or ('brdf' in config_dict['corrections'])
        if needs_anc:
            anc_files = config_dict["anc_files"]
            _ = ray.get([actor.read_file.remote(image, file_type, anc_files[image])
                         for actor, image in zip(actors, images)])
        else:
            _ = ray.get([actor.read_file.remote(image, file_type)
                         for actor, image in zip(actors, images)])
    elif file_type == 'neon':
        _ = ray.get([actor.read_file.remote(image, file_type)
                     for actor, image in zip(actors, images)])
    # Announce the traits that will be estimated
    print("Estimating %s traits:" % len(config_dict['trait_models']))
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        print("\t %s" % trait_model["name"])
    _ = ray.get([actor.do.remote(apply_trait_models, config_dict) for actor in actors])
    ray.shutdown()
def apply_trait_models(hy_obj,config_dict):
    '''Apply trait model(s) to image and export to file.

    For each configured trait model, predicts the trait per pixel as the
    mean/std over the model's coefficient ensemble and writes an ENVI image
    whose bands are [trait_mean, trait_std, range_mask, <one band per
    configured mask>].
    '''
    hy_obj.create_bad_bands(config_dict['bad_bands'])
    hy_obj.corrections = config_dict['corrections']
    # Load correction coefficients
    if 'topo' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['topo'][hy_obj.file_name],'topo')
    if 'brdf' in hy_obj.corrections:
        hy_obj.load_coeffs(config_dict['brdf'][hy_obj.file_name],'brdf')
    hy_obj.resampler['type'] = config_dict["resampling"]['type']
    for trait in config_dict['trait_models']:
        with open(trait, 'r') as json_file:
            trait_model = json.load(json_file)
        # Ensemble of linear models: one row of coefficients per member
        coeffs = np.array(trait_model['model']['coefficients'])
        intercept = np.array(trait_model['model']['intercepts'])
        model_waves = np.array(trait_model['wavelengths'])
        # Check if wavelengths match; resample only when the image does not
        # already contain every model wavelength
        resample = not all(x in hy_obj.wavelengths for x in model_waves)
        if resample:
            hy_obj.resampler['out_waves'] = model_waves
        else:
            # Indices of the model wavelengths within the image band list
            wave_mask = [np.argwhere(x==hy_obj.wavelengths)[0][0] for x in model_waves]
        # Build trait image file
        header_dict = hy_obj.get_header()
        header_dict['wavelength'] = []
        header_dict['data ignore value'] = -9999
        header_dict['data type'] = 4  # 4 == float32 in ENVI
        header_dict['band names'] = ["%s_mean" % trait_model["name"],
                                     "%s_std" % trait_model["name"],
                                     'range_mask'] + [mask[0] for mask in config_dict['masks']]
        header_dict['bands'] = len(header_dict['band names'] )
        #Generate masks
        for mask,args in config_dict['masks']:
            mask_function = mask_dict[mask]
            hy_obj.gen_mask(mask_function,mask,args)
        output_name = config_dict['output_dir']
        output_name += os.path.splitext(os.path.basename(hy_obj.file_name))[0] + "_%s" % trait_model["name"]
        writer = WriteENVI(output_name,header_dict)
        # Fixed 100x100 chunks for all input formats in this variant
        iterator = hy_obj.iterate(by = 'chunk',
                                  chunk_size = (100,100),
                                  corrections = hy_obj.corrections,
                                  resample=resample)
        while not iterator.complete:
            chunk = iterator.read_next()
            if not resample:
                # Subset to the model wavelengths when no resampling was done
                chunk = chunk[:,:,wave_mask]
            trait_est = np.zeros((chunk.shape[0],
                                  chunk.shape[1],
                                  header_dict['bands']))
            # Apply spectrum transforms
            for transform in trait_model['model']["transform"]:
                if transform== "vector":
                    # Vector (L2) normalization of each spectrum
                    norm = np.linalg.norm(chunk,axis=2)
                    chunk = chunk/norm[:,:,np.newaxis]
                if transform == "absorb":
                    # Pseudo-absorbance: log(1/R)
                    chunk = np.log(1/chunk)
                if transform == "mean":
                    # Mean normalization of each spectrum
                    mean = chunk.mean(axis=2)
                    chunk = chunk/mean[:,:,np.newaxis]
            # One prediction per ensemble member, then summarize mean/std
            trait_pred = np.einsum('jkl,ml->jkm',chunk,coeffs, optimize='optimal')
            trait_pred = trait_pred + intercept
            trait_est[:,:,0] = trait_pred.mean(axis=2)
            trait_est[:,:,1] = trait_pred.std(ddof=1,axis=2)
            # Flag predictions inside the model's valid range
            range_mask = (trait_est[:,:,0] > trait_model["model_diagnostics"]['min']) & \
                         (trait_est[:,:,0] < trait_model["model_diagnostics"]['max'])
            trait_est[:,:,2] = range_mask.astype(int)
            # Subset and assign custom masks
            for i,(mask,args) in enumerate(config_dict['masks']):
                mask = hy_obj.mask[mask][iterator.current_line:iterator.current_line+chunk.shape[0],
                                         iterator.current_column:iterator.current_column+chunk.shape[1]]
                trait_est[:,:,3+i] = mask.astype(int)
            # No-data pixels get the ignore value in every band
            nd_mask = hy_obj.mask['no_data'][iterator.current_line:iterator.current_line+chunk.shape[0],
                                             iterator.current_column:iterator.current_column+chunk.shape[1]]
            trait_est[~nd_mask] = -9999
            writer.write_chunk(trait_est,
                               iterator.current_line,
                               iterator.current_column)
        writer.close()

# Script entry point
if __name__== "__main__":
    main()

View File

@ -0,0 +1,147 @@
'''transform.py
TODO: Add MNF option
parser.add_argument("-t", help="Transform type", type = str)
'''
import argparse
import pickle
import os
from shutil import which
import ray
import numpy as np
from sklearn.decomposition import PCA
import hytools as ht
from hytools.io.envi import WriteENVI
def main():
    '''Export PCA-transformed images.

    A single image or a group of images can be provided as input. For a
    group, the PCA decomposition is fit on samples pooled from all images.
    All images must share a format (all ENVI or all NEON). Images can
    optionally be mosaicked to a GeoTIFF using gdal_merge.py, which requires
    GDAL to be installed; mosaicking won't work properly on images with a
    rotation.
    '''
    parser = argparse.ArgumentParser(description = "Perform a PCA")
    parser.add_argument('images',help="Input image pathnames", nargs='*')
    parser.add_argument('output_dir',help="Output directory", type = str)
    parser.add_argument("-comps", help="Number of components to export", type = int,required=False,default=10)
    parser.add_argument("-sample", help="Percent of data to subsample", type = float,required=False,default=0.1)
    parser.add_argument("-merge", help="Use gdal_merge.py to mosaic PCA images", required=False, action='store_true')
    parser.add_argument("-inv", help="Apply inverse transform", required=False, action='store_true')
    args = parser.parse_args()
    if not args.output_dir.endswith("/"):
        args.output_dir+="/"
    # One ray actor per image
    if ray.is_initialized():
        ray.shutdown()
    ray.init(num_cpus = len(args.images))
    hytool = ray.remote(ht.HyTools)
    actors = [hytool.remote() for image in args.images]
    # NEON images are HDF5; anything else is assumed to be ENVI
    if args.images[0].endswith('.h5'):
        file_type = 'neon'
    else:
        file_type = 'envi'
    _ = ray.get([a.read_file.remote(image,file_type) for a,image in zip(actors,args.images)])
    # Sample data from every image
    samples = ray.get([a.do.remote(subsample,args) for a in actors])
    # Center, scale and fit PCA transform
    X = np.concatenate(samples).astype(np.float32)
    x_mean = X.mean(axis=0)[np.newaxis,:]
    X -=x_mean
    x_std = X.std(axis=0,ddof=1)[np.newaxis,:]
    X /=x_std
    # Drop rows containing NaN/Inf before fitting
    X = X[~np.isnan(X.sum(axis=1)) & ~np.isinf(X.sum(axis=1)),:]
    print('Performing PCA decomposition')
    pca = PCA(n_components=args.comps)
    pca.fit(X)
    # Ship the fitted transform (and scaling) to the actors via args
    args.pca_pkl = pickle.dumps(pca)
    args.x_mean = x_mean
    args.x_std = x_std
    #Apply tranform and export
    _ = ray.get([a.do.remote(apply_transform,args) for a in actors])
    if args.merge and len(args.images) > 1:
        if which('gdal_merge.py') is not None:
            print('Mosaicking flightlines')
            # Bug fix: build the exact file names written by apply_transform
            # ("<base>_pca%03d" plus "_inv" for the inverse transform); the
            # previous "<base>_pca" pattern referenced nonexistent files.
            suffix = "_pca%03d" % args.comps + ("_inv" if args.inv else "")
            base_names = ray.get([a.do.remote(lambda x : x.base_name) for a in actors])
            output_files = ["%s%s%s" % (args.output_dir,base,suffix) for base in base_names]
            # Run gdal_merge.py without a shell to avoid quoting issues
            import subprocess
            cmd = ['gdal_merge.py','-o', '%stransform_mosaic.tif' % args.output_dir] + output_files
            subprocess.run(cmd, check=False)
        else:
            print('gdal_merge.py not found, exiting.')
def subsample(hy_obj, args):
    """Randomly flag ``args.sample`` fraction of valid pixels and return
    their spectra as a (pixels, good_bands) array."""
    print("Sampling %s" % os.path.basename(hy_obj.file_name))
    # Select 'sample' % of the non-nodata pixels as a boolean mask
    sample_mask = np.zeros((hy_obj.lines, hy_obj.columns)).astype(bool)
    valid = np.array(np.where(hy_obj.mask['no_data'])).T
    n_draw = int(len(valid) * args.sample)
    chosen = valid[np.random.choice(range(len(valid)), n_draw, replace=False)].T
    sample_mask[chosen[0], chosen[1]] = True
    hy_obj.mask['samples'] = sample_mask
    # Collect the sampled values from every good band
    spectra = []
    hy_obj.create_bad_bands([[300,400],[1300,1450],[1780,2000],[2450,2600]])
    for band_num, band in enumerate(hy_obj.bad_bands):
        if ~band:
            spectra.append(hy_obj.get_band(band_num, mask='samples'))
    # Transpose so rows are pixels and columns are bands
    return np.array(spectra).T
def apply_transform(hy_obj,args):
    """Apply the fitted PCA (or its inverse) to one image and export it as ENVI.

    Reads the image in 500x500 chunks, applies the same centering/scaling
    used when fitting, transforms each chunk, and writes the result. With
    ``args.inv`` the components are back-projected to spectral space.
    """
    print("Exporting %s PCA" % hy_obj.base_name)
    # Restore the PCA fitted by main(); trusted pickle produced in this run
    pca = pickle.loads(args.pca_pkl)
    # Default to the inverse-transform naming/header (spectral output)
    output_name = '%s/%s_pca%03d_inv' % (args.output_dir,hy_obj.base_name,pca.n_components)
    header_dict = hy_obj.get_header()
    header_dict['bands'] = (~hy_obj.bad_bands).sum()
    header_dict['wavelength'] = hy_obj.wavelengths[~hy_obj.bad_bands]
    header_dict['fwhm'] = hy_obj.fwhm[~hy_obj.bad_bands]
    header_dict['data type'] = 4  # float32
    header_dict['data ignore value'] = 0
    if not args.inv:
        # Forward transform: one band per component, no wavelength metadata
        header_dict['bands'] = pca.n_components
        output_name = '%s/%s_pca%03d' % (args.output_dir,hy_obj.base_name,pca.n_components)
        header_dict['wavelength'] = []
        header_dict['fwhm'] = []
    writer = WriteENVI(output_name,header_dict)
    iterator = hy_obj.iterate(by = 'chunk',chunk_size = (500,500))
    while not iterator.complete:
        chunk = iterator.read_next()
        # Flatten the chunk to (pixels, good_bands) for sklearn
        X_chunk = chunk[:,:,~hy_obj.bad_bands].astype(np.float32)
        X_chunk = X_chunk.reshape((X_chunk.shape[0]*X_chunk.shape[1],X_chunk.shape[2]))
        # Same centering/scaling as used during fitting
        X_chunk -=args.x_mean
        X_chunk /=args.x_std
        X_chunk[np.isnan(X_chunk) | np.isinf(X_chunk)] = 0
        pca_chunk= pca.transform(X_chunk)
        if args.inv:
            # Back-project to spectral space and undo the scaling
            pca_chunk = pca.inverse_transform(pca_chunk)
            pca_chunk *=args.x_std
            pca_chunk +=args.x_mean
        pca_chunk = pca_chunk.reshape((chunk.shape[0],chunk.shape[1],header_dict['bands']))
        # Mark no-data pixels with the ignore value
        pca_chunk[chunk[:,:,0] == hy_obj.no_data] =0
        writer.write_chunk(pca_chunk,
                           iterator.current_line,
                           iterator.current_column)
    # Bug fix: close the writer so the output is flushed to disk; every other
    # export function in these scripts calls writer.close(), this one didn't.
    writer.close()
if __name__== "__main__":
main()

File diff suppressed because it is too large Load Diff