实现水体含量反演程序。
This commit is contained in:
174
util.py
Normal file
174
util.py
Normal file
@ -0,0 +1,174 @@
|
||||
import os, spectral
|
||||
import time
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
from enum import Enum, unique
|
||||
import math
|
||||
import json
|
||||
|
||||
|
||||
class CoorType(Enum):
|
||||
depend_on_image = 0 # 影像是啥类型坐标就是啥坐标
|
||||
pixel = 1
|
||||
|
||||
|
||||
class Timer: # Context Manager
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
print(exc_type, exc_value, traceback)
|
||||
print(f"Run Time: {time.time() - self.start}")
|
||||
|
||||
|
||||
def timeit(f): # decorator
|
||||
def wraper(*args, **kwargs):
|
||||
start = time.time()
|
||||
ret = f(*args, **kwargs)
|
||||
print(f.__name__ + " run time: " + str(round(time.time() - start, 2)) + " s.")
|
||||
return ret
|
||||
|
||||
return wraper
|
||||
|
||||
|
||||
def get_hdr_file_path(file_path):
|
||||
return os.path.splitext(file_path)[0] + ".hdr"
|
||||
|
||||
|
||||
def find_band_number(wav1, imgpath):
|
||||
in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
|
||||
|
||||
wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
|
||||
|
||||
differences = np.abs(wavelengths - wav1)
|
||||
min_position = np.argmin(differences)
|
||||
|
||||
return int(min_position)
|
||||
|
||||
|
||||
@timeit
|
||||
def average_bands(start_wave, end_wave, imgpath):
|
||||
start_bandnumber = find_band_number(start_wave, imgpath)
|
||||
end_bandnumber = find_band_number(end_wave, imgpath)
|
||||
|
||||
dataset = gdal.Open(imgpath)
|
||||
|
||||
averaged_band = 1
|
||||
for i in range(start_bandnumber, end_bandnumber + 1):
|
||||
if i == start_bandnumber:
|
||||
averaged_band = dataset.GetRasterBand(i + 1).ReadAsArray()
|
||||
else:
|
||||
tmp = dataset.GetRasterBand(i + 1).ReadAsArray()
|
||||
averaged_band = (averaged_band + tmp) / 2
|
||||
|
||||
del dataset
|
||||
|
||||
return averaged_band
|
||||
|
||||
|
||||
def exclude_by_mask(band, water_mask_path, ignore_value=0):
|
||||
dataset = gdal.Open(water_mask_path)
|
||||
data_tmp = dataset.GetRasterBand(1).ReadAsArray()
|
||||
|
||||
del dataset
|
||||
|
||||
band[np.where(data_tmp == ignore_value)] = 0
|
||||
|
||||
return band
|
||||
|
||||
|
||||
@timeit
|
||||
def average_bands_in_mask(start_wave, end_wave, imgpath, water_mask_path):
|
||||
tmp = average_bands(start_wave, end_wave, imgpath)
|
||||
|
||||
tmp = exclude_by_mask(tmp, water_mask_path)
|
||||
|
||||
# raster_fn_out_tmp = append2filename(imgpath, "glint_delete")
|
||||
# write_bands(imgpath, raster_fn_out_tmp, tmp)
|
||||
|
||||
return tmp
|
||||
|
||||
|
||||
def get_average_value(dataset, x, y, band_number, window):
|
||||
spectral_tmp = dataset.ReadAsArray(x, y, 1, 1)
|
||||
average_value = spectral_tmp[band_number - window:band_number + window, :, :].mean()
|
||||
|
||||
return average_value
|
||||
|
||||
|
||||
def get_valid_extent(dataset, data_ignore_value=0):
|
||||
pass
|
||||
|
||||
|
||||
def write_bands(imgpath_in, imgpath_out, *args):
|
||||
# 将输入的波段(可变)写入文件
|
||||
dataset = gdal.Open(imgpath_in)
|
||||
im_width = dataset.RasterXSize
|
||||
im_height = dataset.RasterYSize
|
||||
num_bands = dataset.RasterCount
|
||||
geotransform = dataset.GetGeoTransform()
|
||||
im_proj = dataset.GetProjection()
|
||||
|
||||
format = "ENVI"
|
||||
driver = gdal.GetDriverByName(format)
|
||||
dst_ds = driver.Create(imgpath_out, im_width, im_height, len(args), gdal.GDT_Float32,
|
||||
options=["INTERLEAVE=BSQ"])
|
||||
dst_ds.SetGeoTransform(geotransform)
|
||||
dst_ds.SetProjection(im_proj)
|
||||
|
||||
for i in range(len(args)):
|
||||
dst_ds.GetRasterBand(i + 1).WriteArray(args[i])
|
||||
|
||||
del dataset, dst_ds
|
||||
|
||||
|
||||
def append2filename(file_path, txt2add):
|
||||
imgfile_out_tmp = os.path.splitext(file_path)
|
||||
new_file_path = imgfile_out_tmp[0] + "_" + txt2add + imgfile_out_tmp[1]
|
||||
|
||||
return new_file_path
|
||||
|
||||
|
||||
def write_fields_to_hdrfile(source_hdr_file, dest_hdr_file):
|
||||
source_fields = spectral.envi.read_envi_header(source_hdr_file)
|
||||
dest_fields = spectral.envi.read_envi_header(dest_hdr_file)
|
||||
|
||||
with open(dest_hdr_file, "a", encoding='utf-8') as f:
|
||||
for key in source_fields.keys():
|
||||
if key in dest_fields or key == "description":
|
||||
continue
|
||||
|
||||
if key == "data ignore value" or key == "wavelength" or key == "wavelength units":
|
||||
if type(source_fields[key]) == list:
|
||||
f.write(key + " = {" + ", ".join(source_fields[key]) + "}\n")
|
||||
else:
|
||||
f.write(key + " = " + source_fields[key] + "\n")
|
||||
|
||||
|
||||
def getnearest(m, invalid_value=0):
|
||||
layer_number = math.floor(m.shape[0] / 2)
|
||||
center = layer_number
|
||||
|
||||
for i in range(layer_number + 1):
|
||||
orig = (center - i, center - i)
|
||||
|
||||
tmp = m[center - i:center + i + 1, center - i:center + i + 1]
|
||||
valid_indices = np.where((tmp != invalid_value) & np.isfinite(tmp))
|
||||
|
||||
if valid_indices[0].shape[0] != 0:
|
||||
return int(valid_indices[0][0] + orig[0]), int(valid_indices[1][0] + orig[0]) # (y ,x)
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def load_numpy_dict_from_json(filename):
|
||||
with open(filename, 'r') as f:
|
||||
np_dict = json.load(f)
|
||||
|
||||
# 将字典中的列表转换回 NumPy 数组
|
||||
model_type = np_dict['model_type']
|
||||
model_info = np.array(np_dict['model_info'])
|
||||
precision = np.array(np_dict['accuracy'])
|
||||
|
||||
return model_type, model_info, precision
|
Reference in New Issue
Block a user