import functools
import json
import math
import os
import time
from enum import Enum, unique

import numpy as np
import spectral
from osgeo import gdal


@unique
class CoorType(Enum):
    """How coordinates passed to the raster helpers are interpreted."""

    depend_on_image = 0  # use whatever coordinate type the image itself uses
    pixel = 1


class Timer:
    """Context manager that prints the elapsed time spent inside its ``with`` body.

    Exceptions raised in the body are reported and then propagated (never suppressed,
    because ``__exit__`` returns None).
    """

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # On a normal exit all three values are None; only report them when the body raised.
        if exc_type is not None:
            print(exc_type, exc_value, traceback)
        print(f"Run Time: {time.perf_counter() - self.start}")


def timeit(f):
    """Decorator that prints how long each call to *f* takes, rounded to 0.01 s."""

    @functools.wraps(f)  # keep f's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        ret = f(*args, **kwargs)
        print(f.__name__ + " run time: " + str(round(time.perf_counter() - start, 2)) + " s.")
        return ret

    return wrapper


def get_hdr_file_path(file_path):
    """Return *file_path* with its extension replaced by ``.hdr``."""
    return os.path.splitext(file_path)[0] + ".hdr"


def find_band_number(wav1, imgpath):
    """Return the 0-based index of the band whose wavelength is closest to *wav1*.

    Wavelengths come from the ``wavelength`` field of the ENVI header next to
    *imgpath*; *wav1* must be expressed in the same units. Ties resolve to the
    lower index.
    """
    in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
    wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
    return int(np.argmin(np.abs(wavelengths - wav1)))


@timeit
def average_bands(start_wave, end_wave, imgpath):
    """Return the per-pixel arithmetic mean of all bands between two wavelengths.

    The bands nearest to *start_wave* and *end_wave* and every band between them
    (inclusive) are averaged with equal weight. The wavelengths may be given in
    either order. The result is a float64 2-D array.
    """
    start_band = find_band_number(start_wave, imgpath)
    end_band = find_band_number(end_wave, imgpath)
    if start_band > end_band:
        start_band, end_band = end_band, start_band

    dataset = gdal.Open(imgpath)
    total = None
    for i in range(start_band, end_band + 1):
        # GDAL band numbers are 1-based; accumulate in float64 so integer
        # rasters (e.g. uint16) cannot overflow while summing.
        band = dataset.GetRasterBand(i + 1).ReadAsArray().astype(np.float64)
        if total is None:
            total = band
        else:
            total += band
    del dataset

    total /= end_band - start_band + 1
    return total


def exclude_by_mask(band, water_mask_path, ignore_value=0):
    """Set pixels of *band* to 0 wherever the mask's first band equals *ignore_value*.

    *band* is modified in place and also returned. The mask raster is expected to
    have the same dimensions as *band* (TODO confirm at call sites).
    """
    dataset = gdal.Open(water_mask_path)
    mask = dataset.GetRasterBand(1).ReadAsArray()
    del dataset
    band[mask == ignore_value] = 0
    return band


@timeit
def average_bands_in_mask(start_wave, end_wave, imgpath, water_mask_path):
    """Average the bands between two wavelengths, then zero pixels outside the mask."""
    averaged = average_bands(start_wave, end_wave, imgpath)
    return exclude_by_mask(averaged, water_mask_path)


def get_average_value(dataset, x, y, band_number, window):
    """Return the mean value of one pixel over the bands around *band_number*.

    Reads the 1x1 block at pixel offset (*x*, *y*) across all bands and averages
    bands ``band_number - window`` through ``band_number + window`` inclusive
    (0-based indices, clipped at the first band).
    """
    pixel_spectrum = dataset.ReadAsArray(x, y, 1, 1)  # indexed as (band, row, col) below
    start = max(band_number - window, 0)  # a negative start would wrap to the last bands
    stop = band_number + window + 1  # inclusive end, so window=0 selects band_number itself
    return pixel_spectrum[start:stop, :, :].mean()


def get_valid_extent(dataset, data_ignore_value=0):
    """Not implemented yet; currently always returns None."""
    # TODO: compute the extent of pixels that differ from data_ignore_value.
    pass


def write_bands(imgpath_in, imgpath_out, *args):
    """Write a variable number of 2-D arrays to a new Float32 ENVI raster (BSQ).

    Raster size, geotransform and projection are copied from *imgpath_in*; each
    array in *args* becomes one output band, in order.
    """
    src_ds = gdal.Open(imgpath_in)
    driver = gdal.GetDriverByName("ENVI")
    dst_ds = driver.Create(imgpath_out, src_ds.RasterXSize, src_ds.RasterYSize, len(args),
                           gdal.GDT_Float32, options=["INTERLEAVE=BSQ"])
    dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
    dst_ds.SetProjection(src_ds.GetProjection())
    for band_index, array in enumerate(args, start=1):
        dst_ds.GetRasterBand(band_index).WriteArray(array)
    # Releasing the references closes both datasets so the output is written out.
    del src_ds, dst_ds


def append2filename(file_path, txt2add):
    """Insert ``_<txt2add>`` before the extension of *file_path*."""
    stem, ext = os.path.splitext(file_path)
    return stem + "_" + txt2add + ext


def write_fields_to_hdrfile(source_hdr_file, dest_hdr_file):
    """Append selected ENVI header fields from *source_hdr_file* to *dest_hdr_file*.

    Only ``data ignore value``, ``wavelength`` and ``wavelength units`` are copied,
    and only when *dest_hdr_file* does not already define them. List values are
    written in ENVI brace syntax: ``key = {a, b, c}``.
    """
    copied_keys = ("data ignore value", "wavelength", "wavelength units")
    source_fields = spectral.envi.read_envi_header(source_hdr_file)
    dest_fields = spectral.envi.read_envi_header(dest_hdr_file)
    with open(dest_hdr_file, "a", encoding='utf-8') as f:
        for key, value in source_fields.items():
            if key not in copied_keys or key in dest_fields:
                continue
            if isinstance(value, list):
                f.write(key + " = {" + ", ".join(map(str, value)) + "}\n")
            else:
                f.write(key + " = " + str(value) + "\n")


def getnearest(m, invalid_value=0):
    """Find a valid cell close to the centre of the 2-D array *m*.

    Searches centred square windows of growing size (1x1, 3x3, 5x5, ...). A cell
    is valid when it is finite and != *invalid_value*. Within the smallest window
    that contains a valid cell, the first one in row-major order is returned, so
    the result is not necessarily the Euclidean-nearest valid cell.
    Only ``m.shape[0]`` is used to locate the centre, so *m* is assumed square.

    Returns ``(row, col)`` (i.e. ``(y, x)``), or ``(None, None)`` if none is found.
    """
    center = m.shape[0] // 2
    for radius in range(center + 1):
        top = left = center - radius
        window = m[top:center + radius + 1, left:center + radius + 1]
        rows, cols = np.where((window != invalid_value) & np.isfinite(window))
        if rows.size:
            return int(rows[0] + top), int(cols[0] + left)  # (y, x)
    return None, None


def load_numpy_dict_from_json(filename):
    """Load a saved model description from a JSON file.

    Returns ``(model_type, model_info, precision)``: ``model_type`` is the raw
    ``"model_type"`` value, while ``"model_info"`` and ``"accuracy"`` are
    converted back to NumPy arrays.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        np_dict = json.load(f)
    # Convert the JSON lists back to NumPy arrays.
    model_type = np_dict['model_type']
    model_info = np.array(np_dict['model_info'])
    precision = np.array(np_dict['accuracy'])
    return model_type, model_info, precision