175 lines
4.9 KiB
Python
175 lines
4.9 KiB
Python
import os, spectral
|
|
import time
|
|
import numpy as np
|
|
from osgeo import gdal
|
|
from enum import Enum, unique
|
|
import math
|
|
import json
|
|
|
|
|
|
@unique
class CoorType(Enum):
    """Kinds of coordinates a caller can work with.

    ``@unique`` makes a duplicated member value fail at import time instead of
    silently turning the second name into an alias of the first.
    """

    # Coordinates are whatever type the image itself uses.
    depend_on_image = 0
    # Presumably pixel (column/row) coordinates -- confirm against callers.
    pixel = 1
|
|
|
|
|
|
class Timer:
    """Context manager that prints the wall-clock time spent in its ``with`` block.

    On exit it first prints the exception triple it was handed
    (``None None None`` when the block finished normally) and then the
    elapsed seconds. Exceptions are not suppressed.
    """

    def __enter__(self):
        """Record the entry timestamp and return the timer itself."""
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Print the exception info followed by the elapsed time."""
        print(exc_type, exc_value, traceback)
        elapsed = time.time() - self.start
        print(f"Run Time: {elapsed}")
|
|
|
|
|
|
def timeit(f):
    """Decorator that prints how long each call to ``f`` takes.

    The wrapper forwards all positional and keyword arguments, returns the
    wrapped function's result unchanged, and prints
    ``"<name> run time: <seconds> s."`` with the duration rounded to two
    decimals. Nothing is printed when ``f`` raises.

    ``functools.wraps`` keeps ``f``'s ``__name__``, ``__doc__`` and other
    metadata on the returned wrapper.
    """
    # Imported here so the decorator does not rely on a module-level import.
    import functools

    @functools.wraps(f)
    def wraper(*args, **kwargs):
        start = time.time()
        ret = f(*args, **kwargs)
        print(f"{f.__name__} run time: {round(time.time() - start, 2)} s.")
        return ret

    return wraper
|
|
|
|
|
|
def get_hdr_file_path(file_path):
    """Return ``file_path`` with its extension replaced by ``.hdr``.

    A path without an extension simply gets ``.hdr`` appended.
    """
    stem, _extension = os.path.splitext(file_path)
    return stem + ".hdr"
|
|
|
|
|
|
def find_band_number(wav1, imgpath):
    """Return the zero-based index of the band whose wavelength is closest to ``wav1``.

    The wavelengths come from the ``wavelength`` field of the ENVI header that
    shares ``imgpath``'s base name (``.hdr`` extension). When several bands
    are equally close, the lowest index is returned.
    """
    header = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))

    # Header values are converted to float64 before comparing.
    band_centers = np.array(header['wavelength']).astype('float64')
    nearest = np.abs(band_centers - wav1).argmin()

    return int(nearest)
|
|
|
|
|
|
@timeit
def average_bands(start_wave, end_wave, imgpath):
    """Return the per-pixel arithmetic mean of the bands between two wavelengths.

    The first and last band are the ones whose header wavelengths are nearest
    to ``start_wave`` and ``end_wave`` (see ``find_band_number``); both ends
    are included. The wavelengths may be given in either order.

    The result is always a float64 array. Earlier versions folded the bands
    with ``(running + band) / 2``, which is only a true mean for one or two
    bands and weights later bands more heavily beyond that; they also summed
    in the raster's native dtype, which can overflow for integer data.
    """
    first = find_band_number(start_wave, imgpath)
    last = find_band_number(end_wave, imgpath)
    # A reversed range used to skip the loop entirely and return the int 1.
    if first > last:
        first, last = last, first

    dataset = gdal.Open(imgpath)

    total = None
    for index in range(first, last + 1):
        # GDAL band numbers are 1-based while the header index is 0-based.
        # astype() returns a fresh float64 copy, so in-place accumulation
        # below never touches GDAL-owned data.
        band = dataset.GetRasterBand(index + 1).ReadAsArray().astype(np.float64)
        if total is None:
            total = band
        else:
            total += band

    del dataset

    total /= last - first + 1
    return total
|
|
|
|
|
|
def exclude_by_mask(band, water_mask_path, ignore_value=0):
    """Set ``band`` to 0 wherever the mask raster equals ``ignore_value``.

    Only band 1 of the raster at ``water_mask_path`` is read. ``band`` is
    modified in place and the same array is returned.
    """
    mask_dataset = gdal.Open(water_mask_path)
    mask_values = mask_dataset.GetRasterBand(1).ReadAsArray()
    # The pixel values are in memory now; drop our reference to the dataset.
    del mask_dataset

    excluded = np.where(mask_values == ignore_value)
    band[excluded] = 0

    return band
|
|
|
|
|
|
@timeit
def average_bands_in_mask(start_wave, end_wave, imgpath, water_mask_path):
    """Average the bands between two wavelengths, then zero the masked-out pixels.

    Combines ``average_bands`` and ``exclude_by_mask`` (called with its
    default ``ignore_value``) and returns the resulting array.
    """
    averaged = average_bands(start_wave, end_wave, imgpath)

    # To inspect the result on disk, it can be written out with
    # write_bands(imgpath, append2filename(imgpath, "glint_delete"), result).
    return exclude_by_mask(averaged, water_mask_path)
|
|
|
|
|
|
def get_average_value(dataset, x, y, band_number, window):
    """Return the mean of one pixel's values over a range of bands.

    A 1x1 window is read at pixel offset (``x``, ``y``) and the values whose
    first-axis index lies in ``[band_number - window, band_number + window)``
    are averaged; the upper end is exclusive.

    The lower bound is clamped to 0. Without the clamp a negative start index
    wraps around to the end of the array, the slice selects nothing, and the
    mean comes back as NaN.

    Assumes ``dataset.ReadAsArray`` returns a (bands, 1, 1) array, i.e. a
    multi-band raster -- TODO confirm for single-band inputs.
    """
    pixel_spectrum = dataset.ReadAsArray(x, y, 1, 1)

    lower = max(band_number - window, 0)
    upper = band_number + window

    return pixel_spectrum[lower:upper, :, :].mean()
|
|
|
|
|
|
def get_valid_extent(dataset, data_ignore_value=0):
    """Placeholder that is not implemented yet.

    Both arguments are currently ignored and the function always returns
    ``None``.
    """
    return None
|
|
|
|
|
|
def write_bands(imgpath_in, imgpath_out, *args):
    """Write a variable number of bands to a new ENVI raster.

    The output at ``imgpath_out`` takes its width, height, geotransform and
    projection from the raster at ``imgpath_in``. Each positional array in
    ``args`` becomes one output band, in the order given, stored as Float32
    with BSQ interleave.

    Assumes every array matches the input raster's width and height -- TODO
    confirm against callers.
    """
    dataset = gdal.Open(imgpath_in)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    geotransform = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()

    driver = gdal.GetDriverByName("ENVI")
    dst_ds = driver.Create(
        imgpath_out, im_width, im_height, len(args), gdal.GDT_Float32,
        options=["INTERLEAVE=BSQ"],
    )
    dst_ds.SetGeoTransform(geotransform)
    dst_ds.SetProjection(im_proj)

    # GDAL band numbers start at 1.
    for band_index, band_data in enumerate(args, start=1):
        dst_ds.GetRasterBand(band_index).WriteArray(band_data)

    # In the GDAL Python bindings a dataset is closed when its last reference
    # goes away, so both handles are dropped explicitly here.
    del dataset, dst_ds
|
|
|
|
|
|
def append2filename(file_path, txt2add):
    """Insert ``_<txt2add>`` between the base name and extension of ``file_path``.

    Example: ``("a/img.dat", "glint")`` gives ``"a/img_glint.dat"``.
    """
    stem, extension = os.path.splitext(file_path)
    return "_".join((stem, txt2add)) + extension
|
|
|
|
|
|
def write_fields_to_hdrfile(source_hdr_file, dest_hdr_file):
    """Append selected ENVI header fields from one header file to another.

    Only ``data ignore value``, ``wavelength`` and ``wavelength units`` are
    copied, and only when the destination header does not already contain
    that field. Lines are appended in the order the fields appear in the
    source header. List values are written as ``key = {a, b, ...}``, other
    values as ``key = value``.
    """
    copied_keys = {"data ignore value", "wavelength", "wavelength units"}

    source_fields = spectral.envi.read_envi_header(source_hdr_file)
    # The destination is parsed before it is reopened for appending.
    dest_fields = spectral.envi.read_envi_header(dest_hdr_file)

    with open(dest_hdr_file, "a", encoding='utf-8') as f:
        for key, value in source_fields.items():
            if key not in copied_keys or key in dest_fields:
                continue

            if isinstance(value, list):
                f.write(key + " = {" + ", ".join(value) + "}\n")
            else:
                f.write(key + " = " + value + "\n")
|
|
|
|
|
|
def getnearest(m, invalid_value=0):
    """Find a valid cell of the 2-D array ``m`` that is closest to its centre.

    A cell is valid when it differs from ``invalid_value`` and is finite. The
    search grows a square around the centre cell ``(rows // 2, cols // 2)``
    one ring at a time and stops at the first ring size that contains a valid
    cell; within that square the first hit in row-major order wins.

    Rows and columns are handled independently, so non-square windows are
    searched around their real centre and out to their full extent. Square
    windows give the same result as before.

    Returns:
        ``(y, x)`` indices into ``m`` as Python ints, or ``(None, None)`` when
        no cell is valid.
    """
    center_row = m.shape[0] // 2
    center_col = m.shape[1] // 2

    for radius in range(max(center_row, center_col) + 1):
        # Clamp so that the shorter axis does not wrap to negative indices
        # once the radius exceeds its half-size.
        top = max(center_row - radius, 0)
        left = max(center_col - radius, 0)

        square = m[top:center_row + radius + 1, left:center_col + radius + 1]
        rows, cols = np.where((square != invalid_value) & np.isfinite(square))

        if rows.size:
            # Translate from square-local back to indices into m.
            return int(rows[0] + top), int(cols[0] + left)  # (y, x)

    return None, None
|
|
|
|
|
|
def load_numpy_dict_from_json(filename):
    """Load a model description from a JSON file.

    The file must hold an object with the keys ``model_type``, ``model_info``
    and ``accuracy``. It is read as UTF-8 so the result does not depend on the
    platform's default encoding.

    Returns:
        A ``(model_type, model_info, precision)`` tuple: ``model_type`` as
        stored, ``model_info`` and ``precision`` (the ``accuracy`` entry)
        converted to NumPy arrays.

    Raises:
        KeyError: if one of the three keys is missing.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        np_dict = json.load(f)

    # Convert the lists stored in the dict back to NumPy arrays.
    model_type = np_dict['model_type']
    model_info = np.array(np_dict['model_info'])
    precision = np.array(np_dict['accuracy'])

    return model_type, model_info, precision
|