import numpy as np
import os
from osgeo import gdal
from util import *
from type_define import *
import math
from pyproj import CRS
from pyproj import Transformer
import argparse


def find_index(wavelength, array):
    """Return the index of the element of ``array`` closest to ``wavelength``.

    Ties resolve to the first matching position (``np.argmin`` semantics).
    """
    differences = np.abs(array - wavelength)
    return np.argmin(differences)


def get_mean_value(index, array, window):
    """Average every row of ``array`` except row 0 over columns ``index - window .. index + window``.

    Row 0 is skipped; in this module it is the wavelength header row that
    ``find_index`` searches.

    Args:
        index: Centre column index.
        array: 2-D array whose rows 1.. are averaged.
        window: Half-width of the averaging window, in columns; truncated to int.

    Returns:
        1-D array with one mean per averaged row.

    Raises:
        ValueError: If ``window`` is negative.
    """
    window = int(window)
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")
    # Clamp the start: a negative slice start would wrap around to the end of
    # the row and silently yield an empty (NaN) or wrong window.
    start = max(index - window, 0)
    return array[1:, start:index + window + 1].mean(axis=1)


def calculate(x1, x2, coefficients):
    """Evaluate ``coefficients[0]*x1 + coefficients[1]*x2 + coefficients[2]`` per sample."""
    design = np.column_stack((x1, x2, np.ones(len(x2))))
    return design.dot(coefficients)


def _band_mean(coor_spectral, wavelength, window):
    """Row means around the header column whose wavelength is closest to ``wavelength``."""
    index = find_index(wavelength, coor_spectral[0, :])
    return get_mean_value(index, coor_spectral, window)


def _attach_coordinates(coor_spectral, values):
    """Return columns 0-1 of rows 1.. of ``coor_spectral`` with ``values`` appended as a third column."""
    return np.hstack((coor_spectral[1:, 0:2], values.reshape((values.shape[0], 1))))


def _save(output_path, position_content):
    """Write ``position_content`` as tab-separated text (4 decimals) unless ``output_path`` is None."""
    if output_path is not None:
        np.savetxt(output_path, position_content, fmt='%.4f', delimiter="\t")


def retrieval_chl_a(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Retrieve chlorophyll-a with a polynomial of a 651/707/670 band ratio.

    The predictor is ``(b651 - b707) / (b707 - b670)``, where each band value is
    a window mean around the closest header wavelength. The prediction is
    ``np.polyval(model_info, x)``.

    Args:
        model_info_path: JSON model file read by ``load_numpy_dict_from_json``.
        coor_spectral_path: Text matrix; row 0 is the wavelength header, and
            columns 0-1 of the other rows are copied to the output.
        output_path: Where to save the result table; ``None`` skips saving.
        window: Averaging half-width in columns.

    Returns:
        Array of shape (rows, 3): two copied columns plus the retrieved value.
    """
    _, model_info, _ = load_numpy_dict_from_json(model_info_path)
    coor_spectral = np.loadtxt(coor_spectral_path)

    band_651 = _band_mean(coor_spectral, 651, window)
    band_707 = _band_mean(coor_spectral, 707, window)
    band_670 = _band_mean(coor_spectral, 670, window)

    # NOTE(review): yields inf/NaN where band_707 == band_670; not guarded here.
    x = (band_651 - band_707) / (band_707 - band_670)
    retrieval_result = np.polyval(model_info, x)

    position_content = _attach_coordinates(coor_spectral, retrieval_result)
    _save(output_path, position_content)
    return position_content


def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Retrieve a parameter with a linear model on log/exp band ratios.

    Predictors are ``log(b500 / b850)`` and ``exp(b600 / b500)``. They are
    combined by ``calculate`` with the coefficients from the model file. The
    CLI also uses this function for mno4, tn and tp.

    Args:
        model_info_path: JSON model file read by ``load_numpy_dict_from_json``.
        coor_spectral_path: Text matrix; row 0 is the wavelength header.
        output_path: Where to save the result table; ``None`` skips saving.
        window: Averaging half-width in columns.

    Returns:
        Array of shape (rows, 3): two copied columns plus the retrieved value.
    """
    _, model_info, _ = load_numpy_dict_from_json(model_info_path)
    coor_spectral = np.loadtxt(coor_spectral_path)

    band_600 = _band_mean(coor_spectral, 600, window)
    band_500 = _band_mean(coor_spectral, 500, window)
    band_850 = _band_mean(coor_spectral, 850, window)

    x13 = np.log(band_500 / band_850)
    x23 = np.exp(band_600 / band_500)
    retrieval_result = calculate(x13, x23, model_info)

    position_content = _attach_coordinates(coor_spectral, retrieval_result)
    _save(output_path, position_content)
    return position_content


def retrieval_tss(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Retrieve total suspended solids: the ``retrieval_nh3`` result, exponentiated.

    Args:
        model_info_path: JSON model file read by ``load_numpy_dict_from_json``.
        coor_spectral_path: Text matrix; row 0 is the wavelength header.
        output_path: Where to save the result table; ``None`` skips saving.
        window: Averaging half-width in columns.

    Returns:
        Array of shape (rows, 3) with the last column exponentiated.
    """
    position_content = retrieval_nh3(model_info_path, coor_spectral_path, window=window)
    # Presumably the TSS model was fitted on log(TSS), hence the exp here — confirm.
    position_content[:, -1] = np.exp(position_content[:, -1])
    _save(output_path, position_content)
    return position_content


def main():
    """Command-line entry point: choose a retrieval algorithm and run it."""
    parser = argparse.ArgumentParser(description="此程序使用修正后的模型进行反演。")
    # Sub-command parser: one sub-command per retrieved parameter.
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    # (sub-command, help text, retrieval function). mno4/tn/tp reuse the nh3
    # model form; their coefficients come from the -i1 model file.
    algorithms = (
        ("chl_a", "叶绿素", retrieval_chl_a),
        ("nh3", "氨氮", retrieval_nh3),
        ("mno4", "高锰酸盐", retrieval_nh3),
        ("tn", "总氮", retrieval_nh3),
        ("tp", "总磷", retrieval_nh3),
        ("tss", "总悬浮物", retrieval_tss),
    )
    for name, help_text, func in algorithms:
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
        sub.add_argument('-i2', '--coor_spectral_path', type=str, required=True, help='输入坐标-光谱文件的路径')
        sub.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
        sub.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
        sub.set_defaults(func=func)

    args = parser.parse_args()
    # All retrieval functions share one signature, so no per-algorithm branching is needed.
    args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)


if __name__ == '__main__':
    main()