Files
water_content_retrieval/util.py
2025-01-06 10:18:08 +08:00

175 lines
4.9 KiB
Python

import os, spectral
import time
import numpy as np
from osgeo import gdal
from enum import Enum, unique
import math
import json
@unique
class CoorType(Enum):
    """Coordinate convention used when addressing positions in an image."""

    # Use whatever coordinate type the image itself uses.
    depend_on_image = 0
    # Presumably raw pixel (row/column) indices -- confirm against callers.
    pixel = 1
class Timer:  # Context Manager
    """Print how long the body of a ``with`` block took to run.

    On exit the duration in seconds is printed as ``Run Time: <seconds>`` and
    stored on ``self.elapsed``. If the block raised, the exception type, value
    and traceback are printed first; the exception is never suppressed.
    """

    def __enter__(self):
        # Wall-clock start time, kept for callers that read ``timer.start``.
        self.start = time.time()
        # Monotonic high-resolution clock for the duration itself; time.time()
        # can jump if the system clock is adjusted.
        self._t0 = time.perf_counter()
        self.elapsed = None
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed = time.perf_counter() - self._t0
        if exc_type is not None:
            # Only report exception details when the block actually failed;
            # previously "None None None" was printed on every normal exit.
            print(exc_type, exc_value, traceback)
        print(f"Run Time: {self.elapsed}")
        # Falsy return value: let any exception propagate to the caller.
        return False
def timeit(f):  # decorator
    """Decorate ``f`` so each call prints its run time in seconds.

    The printed line has the form ``"<name> run time: <seconds> s."`` with the
    duration rounded to 2 decimals. The wrapped function's return value is
    passed through unchanged.
    """
    import functools

    # functools.wraps keeps f's __name__, __doc__, etc. on the wrapper, so
    # decorated functions stay introspectable.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and meant for measuring durations.
        start = time.perf_counter()
        ret = f(*args, **kwargs)
        print(f.__name__ + " run time: " + str(round(time.perf_counter() - start, 2)) + " s.")
        return ret

    return wrapper
def get_hdr_file_path(file_path):
    """Return ``file_path`` with its extension replaced by ``.hdr``.

    A path without an extension simply gets ``.hdr`` appended.
    """
    stem, _extension = os.path.splitext(file_path)
    return "".join((stem, ".hdr"))
def find_band_number(wav1, imgpath):
    """Return the 0-based index of the band whose wavelength is closest to ``wav1``.

    The wavelength list is read from the ``wavelength`` field of the ENVI
    header next to ``imgpath`` (same path, ``.hdr`` extension). When two bands
    are equally close, the lower index wins.
    """
    header = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
    # Header values may come back as strings; convert to floats before comparing.
    band_wavelengths = np.array(header['wavelength']).astype('float64')
    distance = np.abs(band_wavelengths - wav1)
    return int(distance.argmin())
@timeit
def average_bands(start_wave, end_wave, imgpath):
    """Return the per-pixel mean of every band between two wavelengths (inclusive).

    Each wavelength is mapped to its nearest band with ``find_band_number``.
    The two wavelengths may be given in either order.

    Args:
        start_wave: wavelength selecting one end of the band range.
        end_wave: wavelength selecting the other end of the band range.
        imgpath: raster readable by GDAL, with an ENVI ``.hdr`` beside it.

    Returns:
        float64 array of one band's shape holding the arithmetic mean over
        the selected bands.
    """
    start_bandnumber = find_band_number(start_wave, imgpath)
    end_bandnumber = find_band_number(end_wave, imgpath)
    # Accept wavelengths in either order. Previously start > end gave an empty
    # loop and the function returned the placeholder value 1.
    first, last = sorted((start_bandnumber, end_bandnumber))

    dataset = gdal.Open(imgpath)
    band_sum = None
    for i in range(first, last + 1):
        # find_band_number is 0-based while GDAL band numbers start at 1.
        band = dataset.GetRasterBand(i + 1).ReadAsArray().astype(np.float64)
        if band_sum is None:
            band_sum = band
        else:
            band_sum += band
    del dataset

    # True mean: every band weighted equally. The old running "(avg + band) / 2"
    # halved the weight of each earlier band.
    return band_sum / (last - first + 1)
def exclude_by_mask(band, water_mask_path, ignore_value=0):
    """Zero the pixels of ``band`` where the mask raster equals ``ignore_value``.

    Only the first band of ``water_mask_path`` is used as the mask. ``band`` is
    modified in place and also returned for convenience.
    """
    mask_dataset = gdal.Open(water_mask_path)
    mask = mask_dataset.GetRasterBand(1).ReadAsArray()
    del mask_dataset
    # Index with the coordinates of every masked-out pixel.
    excluded = np.nonzero(mask == ignore_value)
    band[excluded] = 0
    return band
@timeit
def average_bands_in_mask(start_wave, end_wave, imgpath, water_mask_path):
    """Average the bands between two wavelengths, then zero pixels outside the mask.

    Combines ``average_bands`` with ``exclude_by_mask``, using that function's
    default ``ignore_value`` of 0 for the mask.
    """
    averaged = average_bands(start_wave, end_wave, imgpath)
    return exclude_by_mask(averaged, water_mask_path)
def get_average_value(dataset, x, y, band_number, window):
    """Return the mean of a window of bands in one pixel's spectrum.

    Args:
        dataset: opened raster; only ``ReadAsArray(xoff, yoff, xsize, ysize)``
            is used (e.g. a GDAL dataset).
        x: pixel column offset of the pixel to sample.
        y: pixel row offset of the pixel to sample.
        band_number: 0-based index of the centre band.
        window: half-width of the band window. Bands ``band_number - window``
            through ``band_number + window`` (inclusive) are averaged, with the
            lower edge clipped at the first band.

    Returns:
        The mean value of the selected bands (NaN if the window selects nothing).
    """
    # A 1x1 read gives (bands, 1, 1) for multi-band rasters but (1, 1) for a
    # single band. Flatten both to a 1-D spectrum.
    spectrum = np.asarray(dataset.ReadAsArray(x, y, 1, 1)).reshape(-1)
    # Clamp the lower edge: a negative slice start would wrap around to the end
    # of the spectrum and silently pick the wrong (often empty) range.
    lo = max(band_number - window, 0)
    # "+ 1" makes the window symmetric and inclusive. Previously window=0
    # selected an empty slice and produced NaN.
    hi = band_number + window + 1
    return spectrum[lo:hi].mean()
def get_valid_extent(dataset, data_ignore_value=0):
    """Placeholder -- not implemented yet.

    NOTE(review): judging by the name, this is meant to compute the extent of
    pixels not equal to ``data_ignore_value``; confirm the intended contract
    before implementing. Currently both arguments are ignored and None is
    returned.
    """
    return None
def write_bands(imgpath_in, imgpath_out, *args):
    """Write the given bands (variable number) to a new ENVI file.

    The output has the same raster size, geotransform and projection as
    ``imgpath_in``. Each array in ``args`` becomes one Float32 band, in order,
    stored with BSQ interleave.

    Args:
        imgpath_in: reference raster supplying size and georeferencing.
        imgpath_out: path of the ENVI file to create.
        *args: 2-D arrays, one per output band.
    """
    src_ds = gdal.Open(imgpath_in)
    im_width = src_ds.RasterXSize
    im_height = src_ds.RasterYSize
    geotransform = src_ds.GetGeoTransform()
    im_proj = src_ds.GetProjection()
    # Only the georeferencing is needed from the source; close it early.
    del src_ds

    driver = gdal.GetDriverByName("ENVI")
    dst_ds = driver.Create(imgpath_out, im_width, im_height, len(args), gdal.GDT_Float32,
                           options=["INTERLEAVE=BSQ"])
    dst_ds.SetGeoTransform(geotransform)
    dst_ds.SetProjection(im_proj)
    # GDAL band numbers are 1-based.
    for band_index, band in enumerate(args, start=1):
        dst_ds.GetRasterBand(band_index).WriteArray(band)
    dst_ds.FlushCache()
    # Dropping the last reference closes the dataset and finalises the file.
    del dst_ds
def append2filename(file_path, txt2add):
    """Insert ``"_" + txt2add`` between the file's stem and its extension.

    Example: ``append2filename("a/img.dat", "mask")`` -> ``"a/img_mask.dat"``.
    Only the last extension is split off (``"c.tar.gz"`` -> ``"c.tar_x.gz"``).
    """
    stem, extension = os.path.splitext(file_path)
    return "".join((stem, "_", txt2add, extension))
def write_fields_to_hdrfile(source_hdr_file, dest_hdr_file):
    """Append selected ENVI header fields from one .hdr file to another.

    Only ``data ignore value``, ``wavelength`` and ``wavelength units`` are
    copied, and only when ``dest_hdr_file`` does not already define them.
    Fields are appended to the end of ``dest_hdr_file`` in the order they
    appear in ``source_hdr_file``.
    """
    copied_keys = ("data ignore value", "wavelength", "wavelength units")
    source_fields = spectral.envi.read_envi_header(source_hdr_file)
    dest_fields = spectral.envi.read_envi_header(dest_hdr_file)
    with open(dest_hdr_file, "a", encoding='utf-8') as f:
        for key, value in source_fields.items():
            # Skip fields outside the whitelist and those the destination already has.
            if key not in copied_keys or key in dest_fields:
                continue
            if isinstance(value, list):
                # ENVI list syntax: key = {v1, v2, ...}. str() guards against
                # non-string items, which would make join() raise.
                f.write(key + " = {" + ", ".join(str(v) for v in value) + "}\n")
            else:
                f.write(key + " = " + str(value) + "\n")
def getnearest(m, invalid_value=0):
    """Find a valid cell of ``m`` nearest to its centre, searching outward.

    Square windows of growing radius around the centre (``m.shape[0] // 2`` on
    both axes) are scanned. The first radius containing a cell that is neither
    ``invalid_value`` nor non-finite wins, and within that window the first such
    cell in row-major order is returned. The result is the cell's (row, col)
    = (y, x) index in ``m``, or ``(None, None)`` if no valid cell exists.
    """
    center = m.shape[0] // 2
    for radius in range(center + 1):
        top = center - radius
        bottom = center + radius + 1
        window = m[top:bottom, top:bottom]
        hits = np.argwhere((window != invalid_value) & np.isfinite(window))
        if len(hits):
            first = hits[0]
            # Shift the window-local index back into m's coordinates.
            return int(first[0] + top), int(first[1] + top)  # (y ,x)
    return None, None
def load_numpy_dict_from_json(filename):
    """Load a saved model description from a JSON file.

    Returns:
        ``(model_type, model_info, precision)``: ``model_type`` is returned
        as stored; ``model_info`` and ``precision`` are NumPy arrays built from
        the ``"model_info"`` and ``"accuracy"`` keys.

    Raises:
        KeyError: if any of the three keys is missing.
    """
    # Explicit UTF-8 (as used elsewhere in this module) so reading does not
    # depend on the platform's default locale encoding.
    with open(filename, 'r', encoding='utf-8') as f:
        np_dict = json.load(f)
    # Convert the lists stored in the dict back into NumPy arrays.
    model_type = np_dict['model_type']
    model_info = np.array(np_dict['model_info'])
    precision = np.array(np_dict['accuracy'])
    return model_type, model_info, precision