import numpy as np
import os
from osgeo import gdal
from util import *
from type_define import *
import math
from pyproj import CRS
from pyproj import Transformer
import argparse
import json

# NOTE: the import order above is intentionally unchanged: the star imports from util /
# type_define precede the later imports, so reordering could change which object a name binds to.


def _open_raster(path):
    """Open *path* with GDAL, raising FileNotFoundError instead of returning None."""
    dataset = gdal.Open(path)
    if dataset is None:
        raise FileNotFoundError(f"GDAL could not open raster: {path}")
    return dataset


def _lonlat_to_image_transformer(im_proj, lon0, lat0):
    """Build a transformer from WGS84 (lon, lat) to the image's map coordinates (x, y).

    The image's own CRS (WKT from GDAL) is used when it is projected or geographic. Otherwise
    the WGS84 UTM zone containing (lon0, lat0) is assumed (EPSG 326xx north / 327xx south).
    ``always_xy=True`` makes the transformer take (lon, lat) and return (x, y) in GIS order,
    which is the order GDAL geotransforms use.
    """
    from pyproj.exceptions import CRSError

    target = None
    if im_proj:
        try:
            target = CRS.from_wkt(im_proj)
        except CRSError:
            target = None
    if target is None or not (target.is_projected or target.is_geographic):
        zone = math.floor(lon0 / 6) + 31
        target = CRS.from_epsg((32600 if lat0 >= 0 else 32700) + zone)
    return Transformer.from_crs(CRS.from_epsg(4326), target, always_xy=True)


def _read_window_padded(band, x_off, y_off, size, fill_value):
    """Read a size x size window of *band* whose top-left pixel is (x_off, y_off).

    Parts of the window outside the raster are filled with *fill_value*, so the window stays
    centred on the requested pixel even at the image border. Returns None when the window
    does not overlap the raster at all.
    """
    x0 = max(x_off, 0)
    y0 = max(y_off, 0)
    x1 = min(x_off + size, band.XSize)
    y1 = min(y_off + size, band.YSize)
    if x1 <= x0 or y1 <= y0:
        return None
    part = band.ReadAsArray(x0, y0, x1 - x0, y1 - y0)
    window = np.full((size, size), fill_value, dtype=part.dtype)
    window[y0 - y_off:y1 - y_off, x0 - x_off:x1 - x_off] = part
    return window


def _sample_on_grid(data, inv_geotransform_data, geotransform, x_off, y_off, size, fill_value=0):
    """Sample the 2-D array *data* at the pixel centres of a size x size window of another grid.

    Each window pixel centre (column x_off.., row y_off..) is mapped to georeferenced
    coordinates with *geotransform*. It is then mapped to a pixel of *data* with
    *inv_geotransform_data* (nearest pixel), so the two rasters may differ in resolution and
    origin. Centres that fall outside *data* receive *fill_value*.
    """
    cols = np.arange(x_off, x_off + size) + 0.5
    rows = np.arange(y_off, y_off + size) + 0.5
    col_grid, row_grid = np.meshgrid(cols, rows)
    geo_x = geotransform[0] + col_grid * geotransform[1] + row_grid * geotransform[2]
    geo_y = geotransform[3] + col_grid * geotransform[4] + row_grid * geotransform[5]
    inv = inv_geotransform_data
    data_col = np.floor(inv[0] + geo_x * inv[1] + geo_y * inv[2]).astype(int)
    data_row = np.floor(inv[3] + geo_x * inv[4] + geo_y * inv[5]).astype(int)
    inside = ((data_col >= 0) & (data_col < data.shape[1])
              & (data_row >= 0) & (data_row < data.shape[0]))
    sampled = np.full((size, size), fill_value, dtype=data.dtype)
    sampled[inside] = data[data_row[inside], data_col[inside]]
    return sampled


def get_coor_base_lang_lat(img_path, water_mask, severe_glint, csv_pos, coor_type, radius,
                           point_pos_strategy=PointPosStrategy.nearest_single, img_type=ImgType.ref):
    """Locate measured sample points on the image and snap each one to a usable water pixel.

    For every point in ``csv_pos``:

    1. A (2r+1) x (2r+1) window of one image band is read around the point.
    2. Pixels flagged as severe glint (value 1 in ``severe_glint``) or as non-water
       (value 0 in ``water_mask``) are set to the invalid value 0.
    3. ``getnearest`` (from util) chooses the pixel. It presumably returns the valid pixel
       nearest the window centre; confirm in util.

    :param img_path: input image; its no-data value should be 0. The .hdr file name should
        not carry a suffix such as .bil (original author's note).
    :param water_mask: single-band raster, 0 = not water. NOTE(review): it is read with the
        image's pixel offsets, so it must share the image's pixel grid.
    :param severe_glint: single-band raster, 1 = severe sun glint. It may differ from the
        image in resolution or origin; it is resampled by nearest pixel centre. Areas outside
        it count as "no glint".
    :param csv_pos: tab-delimited text file.
        - Columns 0 and 1 hold the point coordinates: lon, lat for CoorType.latlong, or map
          x, y in the image CRS for CoorType.utm.
        - The model functions in this module read column 2 as the measured value.
    :param coor_type: CoorType member (latlong or utm).
    :param radius: search radius in metres.
    :param point_pos_strategy: PointPosStrategy member; only ``nearest_single`` is implemented.
    :param img_type: ImgType member.
        - ``ref``: search the band closest to 560 nm.
        - ``content``: search band 1.
        - anything else: search band 2.
    :return: the rows of ``csv_pos`` that were matched to a valid pixel and contain no NaN.
        Each row gets 4 extra columns:
        - map x and map y of the chosen pixel's top-left corner;
        - pixel column;
        - pixel row.
    :raises NotImplementedError: for any strategy other than ``nearest_single``.
    :raises ValueError: for an unsupported ``coor_type``.
    :raises FileNotFoundError: if GDAL cannot open one of the rasters.
    """
    if point_pos_strategy != PointPosStrategy.nearest_single:
        # four_quadrant was only a stub and used to silently yield an empty result.
        raise NotImplementedError(f"point_pos_strategy {point_pos_strategy} is not implemented")
    if coor_type not in (CoorType.latlong, CoorType.utm):
        raise ValueError(f"unsupported coor_type: {coor_type}")

    dataset = _open_raster(img_path)
    dataset_water_mask = _open_raster(water_mask)
    dataset_severe_glint = _open_raster(severe_glint)

    geotransform = dataset.GetGeoTransform()
    inv_geotransform_input = gdal.InvGeoTransform(geotransform)

    band_number_used = 1
    if img_type == ImgType.ref:
        green_wave = 560  # the original author chose green as the strongest water signal
        band_number_used = find_band_number(green_wave, img_path)
    elif img_type == ImgType.content:
        band_number_used = 0
    # band_number_used is 0-based here; GDAL band numbers are 1-based.
    data_band = dataset.GetRasterBand(band_number_used + 1)
    water_mask_band = dataset_water_mask.GetRasterBand(1)

    geotransform_severe_glint = dataset_severe_glint.GetGeoTransform()
    inv_geotransform_severe_glint = gdal.InvGeoTransform(geotransform_severe_glint)
    data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()

    new_columns_num = 4  # map x, map y, pixel column, pixel row
    # ndmin=2 keeps a single-point file 2-D.
    position_content = np.loadtxt(csv_pos, delimiter="\t", ndmin=2)
    n_points = position_content.shape[0]
    position_content = np.hstack((position_content, np.zeros((n_points, new_columns_num))))
    if n_points == 0:
        return position_content
    matched = np.zeros(n_points, dtype=bool)

    if coor_type == CoorType.latlong:
        transformer = _lonlat_to_image_transformer(dataset.GetProjection(),
                                                   position_content[0, 0], position_content[0, 1])

    # NOTE(review): dividing by 2 makes the half-width radius/2 metres; confirm whether
    # `radius` is meant as a diameter.
    radius_pixel = math.ceil(radius / geotransform[1] / 2)
    window_size = 2 * radius_pixel + 1
    invalid_value = 0

    for i in range(n_points):
        if coor_type == CoorType.latlong:
            map_x, map_y = transformer.transform(position_content[i, 0], position_content[i, 1])
        else:
            map_x, map_y = position_content[i, 0], position_content[i, 1]
        pixel_x, pixel_y = gdal.ApplyGeoTransform(inv_geotransform_input, map_x, map_y)
        if not (math.isfinite(pixel_x) and math.isfinite(pixel_y)):
            continue  # missing or unprojectable coordinates
        x_off = math.floor(pixel_x) - radius_pixel
        y_off = math.floor(pixel_y) - radius_pixel

        data_local = _read_window_padded(data_band, x_off, y_off, window_size, invalid_value)
        water_mask_local = _read_window_padded(water_mask_band, x_off, y_off, window_size, 0)
        if data_local is None or water_mask_local is None:
            continue  # search window lies completely outside the image
        severe_glint_local = _sample_on_grid(data_severe_glint, inv_geotransform_severe_glint,
                                             geotransform, x_off, y_off, window_size, fill_value=0)

        data_local[severe_glint_local == 1] = invalid_value
        data_local[water_mask_local == 0] = invalid_value

        coor_y_local, coor_x_local = getnearest(data_local, invalid_value)
        if coor_x_local is None:
            continue
        coor_x = coor_x_local + x_off
        coor_y = coor_y_local + y_off
        pixel_map_x, pixel_map_y = gdal.ApplyGeoTransform(geotransform, float(coor_x), float(coor_y))
        position_content[i, -new_columns_num:] = (pixel_map_x, pixel_map_y, int(coor_x), int(coor_y))
        matched[i] = True

    # Track matches explicitly: a valid pixel may lie on row 0.
    result = position_content[matched]
    return result[~np.isnan(result).any(axis=1)]


def _design_matrix(x1, x2):
    """Return the [x1, x2, 1] design matrix of the model y = a*x1 + b*x2 + c."""
    return np.column_stack((x1, x2, np.ones((x2.shape[0], 1))))


def fit(x1, x2, y):
    """Least-squares fit of y = a*x1 + b*x2 + c; returns the coefficients [a, b, c]."""
    coefficients, _, _, _ = np.linalg.lstsq(_design_matrix(x1, x2), y, rcond=None)
    return coefficients


def accuracy_evaluation(x1, x2, y_real, coefficients):
    """Absolute relative error in percent of the linear model, per sample.

    Despite the name, the value is an error: 0 means a perfect prediction.
    """
    y_pred = _design_matrix(x1, x2).dot(coefficients)
    return np.absolute((y_real - y_pred) / y_real * 100)


def accuracy_evaluation_tss(x1, x2, y_real, coefficients):
    """Like accuracy_evaluation, for a model fitted to log(y): predictions are exp(model)."""
    y_pred = np.exp(_design_matrix(x1, x2).dot(coefficients))
    return np.absolute((y_real - y_pred) / y_real * 100)


def get_x_in_coor(coor, *args):
    """Append, for each 2-D feature array in *args*, its value at every point of *coor*.

    The pixel position is taken from the last two columns of *coor*: row = coor[:, -1] and
    column = coor[:, -2]. These are the pixel row/column appended by get_coor_base_lang_lat.
    Values are truncated to int.
    """
    n_original = coor.shape[1]
    coor_extend = np.hstack((coor, np.zeros((coor.shape[0], len(args)))))
    if coor.shape[0]:
        rows = coor[:, -1].astype(int)
        cols = coor[:, -2].astype(int)
        for j, feature in enumerate(args):
            coor_extend[:, n_original + j] = feature[rows, cols]
    return coor_extend


def write_model_info(model_type, coefficients, accuracy, long, lat, outpath):
    """Write the model type, coefficients, per-point error and point coordinates as JSON."""
    np_dict = {
        'model_type': model_type,
        'model_info': coefficients.tolist(),
        'accuracy': accuracy.tolist(),
        'long': long.tolist(),
        'lat': lat.tolist()
    }
    with open(outpath, 'w', encoding='utf-8') as f:
        json.dump(np_dict, f, indent=4)


def _band_mean(img_path, wavelength, window):
    """Call average_bands (from util) on [wavelength - window, wavelength + window]."""
    return average_bands(wavelength - window, wavelength + window, img_path)


def _fit_two_feature_model(model_type, coor, outpath_coeff, feature1, feature2, log_target=False):
    """Fit y = a*f1 + b*f2 + c (or log(y) = ... when *log_target*) and write the model JSON.

    ``coor`` must come from get_coor_base_lang_lat; column 2 is the measured value.
    Returns the coefficients [a, b, c].
    """
    coor_extend = get_x_in_coor(coor, feature1, feature2)
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    if log_target:
        coefficients = fit(x1, x2, np.log(y_real))
        accuracy = accuracy_evaluation_tss(x1, x2, y_real, coefficients)
    else:
        coefficients = fit(x1, x2, y_real)
        accuracy = accuracy_evaluation(x1, x2, y_real, coefficients)
    write_model_info(model_type, coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    # Read the file back as the original did (result unused).
    # NOTE(review): presumably a round-trip check; drop if unnecessary.
    load_numpy_dict_from_json(outpath_coeff)
    return coefficients


def _nitrogen_phosphorus_features(img_path, window):
    """Features shared by the nh3/tn/tp models: ln(R500/R850) and exp(R600/R500)."""
    band_600 = _band_mean(img_path, 600, window)
    band_500 = _band_mean(img_path, 500, window)
    band_850 = _band_mean(img_path, 850, window)
    return np.log(band_500 / band_850), np.exp(band_600 / band_500)


def chl_a(img_path, coor, outpath_coeff, window=5):
    """Chlorophyll-a: degree-1 polynomial fit of y on (R651 - R707) / (R707 - R670)."""
    band_651 = _band_mean(img_path, 651, window)
    band_707 = _band_mean(img_path, 707, window)
    band_670 = _band_mean(img_path, 670, window)
    x = (band_651 - band_707) / (band_707 - band_670)
    coor_extend = get_x_in_coor(coor, x)

    x = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = np.polyfit(x, y_real, 1)
    y_pred = np.polyval(coefficients, x)
    accuracy = np.absolute((y_real - y_pred) / y_real * 100)
    write_model_info("chl-a", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    load_numpy_dict_from_json(outpath_coeff)  # read-back kept from the original; result unused
    return coefficients


def nh3(img_path, coor, outpath_coeff, window=5):
    """Ammonia nitrogen: y = a*ln(R500/R850) + b*exp(R600/R500) + c."""
    x13, x23 = _nitrogen_phosphorus_features(img_path, window)
    return _fit_two_feature_model("nh3", coor, outpath_coeff, x13, x23)


def mno4(img_path, coor, outpath_coeff, window=5):
    """Permanganate index: y = a*(R500/R440) + b*(R610/R800) + c."""
    band_500 = _band_mean(img_path, 500, window)
    band_440 = _band_mean(img_path, 440, window)
    band_610 = _band_mean(img_path, 610, window)
    band_800 = _band_mean(img_path, 800, window)
    return _fit_two_feature_model("mno4", coor, outpath_coeff,
                                  band_500 / band_440, band_610 / band_800)


def tn(img_path, coor, outpath_coeff, window=5):
    """Total nitrogen: same features and model form as nh3."""
    x13, x23 = _nitrogen_phosphorus_features(img_path, window)
    return _fit_two_feature_model("tn", coor, outpath_coeff, x13, x23)


def tp(img_path, coor, outpath_coeff, window=5):
    """Total phosphorus: same features and model form as nh3."""
    x13, x23 = _nitrogen_phosphorus_features(img_path, window)
    return _fit_two_feature_model("tp", coor, outpath_coeff, x13, x23)


def tss(img_path, coor, outpath_coeff, window=5):
    """Total suspended solids: ln(y) = a*(R555 + R670) + b*(R490/R555) + c."""
    band_555 = _band_mean(img_path, 555, window)
    band_670 = _band_mean(img_path, 670, window)
    band_490 = _band_mean(img_path, 490, window)
    return _fit_two_feature_model("tss", coor, outpath_coeff,
                                  band_555 + band_670, band_490 / band_555, log_target=True)


def main():
    """CLI: pick a water-quality parameter, match measured points to the image, fit its model."""
    parser = argparse.ArgumentParser(description="此程序用于通过实测数据修正模型参数。")
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    # (sub-command, help text, model function); every sub-command takes the same arguments.
    algorithms = (
        ("chl_a", "叶绿素", chl_a),
        ("nh3", "氨氮", nh3),
        ("mno4", "高锰酸盐", mno4),
        ("tn", "总氮", tn),
        ("tp", "总磷", tp),
        ("tss", "总悬浮物", tss),
    )
    for name, help_text, func in algorithms:
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('-i1', '--img', type=str, required=True, help='输入影像文件的路径')
        sub.add_argument('-i2', '--water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
        sub.add_argument('-i3', '--severe_glint', type=str, required=True, help='输入耀斑严重文件的路径')
        sub.add_argument('-i4', '--measured_data', type=str, required=True, help='输入实测含量数据的路径')
        sub.add_argument('-i5', '--radius', type=float, default=2.0, help='输入实测坐标半径')
        sub.add_argument('-o', '--model_info_outpath', required=True, type=str, help='输出模型信息文件的路径')
        sub.set_defaults(func=func)

    args = parser.parse_args()
    coor = get_coor_base_lang_lat(args.img, args.water_mask, args.severe_glint, args.measured_data,
                                  CoorType.latlong, args.radius)
    args.func(args.img, coor, args.model_info_outpath)


if __name__ == '__main__':
    main()