170 lines
7.5 KiB
Python
170 lines
7.5 KiB
Python
import numpy as np
|
|
import os
|
|
from osgeo import gdal
|
|
from util import *
|
|
from type_define import *
|
|
import math
|
|
from pyproj import CRS
|
|
from pyproj import Transformer
|
|
import argparse
|
|
|
|
|
|
def find_index(wavelength, array):
    """Return the index of the element of ``array`` closest to ``wavelength``.

    Args:
        wavelength: Scalar target value.
        array: Array-like of candidate values. NumPy arrays as well as plain
            lists and tuples are accepted.

    Returns:
        The index (as returned by ``np.argmin``) of the element with the
        smallest absolute difference to ``wavelength``. When several elements
        are equally close, the first one wins.

    Raises:
        ValueError: If ``array`` is empty (raised by ``np.argmin``).
    """
    # Coerce first: ``list - scalar`` is a TypeError, ``ndarray - scalar`` is not.
    candidates = np.asarray(array)
    return np.argmin(np.abs(candidates - wavelength))
|
|
|
|
|
|
def get_mean_value(index, array, window):
    """Average the columns around ``index`` for every row except the first.

    Args:
        index: Centre column of the averaging window.
        array: 2-D array. Row 0 is skipped; the callers in this file pass a
            table whose first row they search for wavelengths.
        window: Half-width of the window, in columns. Converted with
            ``int()``, so a float such as ``5.0`` is accepted.

    Returns:
        1-D array holding, for each row of ``array[1:]``, the mean of the
        columns ``index - window`` .. ``index + window`` (inclusive),
        truncated at the array edges.
    """
    window = int(window)
    # Clamp the left edge at 0. A negative slice start is interpreted by NumPy
    # as "count from the end", which produced an empty slice (NaN means) or
    # silently averaged the wrong columns when index < window.
    # NOTE(review): callers also keep two non-spectral columns at the left of
    # the table; confirm whether the lower bound should skip them as well.
    start = max(index - window, 0)
    # A stop beyond the last column is truncated by slicing, so no clamp needed.
    stop = index + window + 1
    return array[1:, start:stop].mean(axis=1)
|
|
|
|
|
|
def calculate(x1, x2, coefficients):
    """Apply a two-predictor linear model with an intercept term.

    The predictors are stacked column-wise together with a column of ones
    into a design matrix ``[x1, x2, 1]`` which is then multiplied by
    ``coefficients``.

    Args:
        x1: First predictor, one value per sample.
        x2: Second predictor; its ``shape[0]`` fixes the number of samples.
        coefficients: Model coefficients, ordered to match the columns
            ``[x1, x2, 1]`` (the intercept comes last).

    Returns:
        The product of the design matrix and ``coefficients``.
    """
    sample_count = x2.shape[0]
    intercept_column = np.ones((sample_count, 1))
    design_matrix = np.column_stack((x1, x2, intercept_column))
    return np.dot(design_matrix, coefficients)
|
|
|
|
|
|
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path, window=5):
    """Run the chlorophyll-a retrieval and write the result table.

    The spectral table is read with ``np.loadtxt``. Its first row is searched
    for the columns nearest to 651, 707 and 670; every following row is one
    sample whose first two columns are copied unchanged into the output.
    For each sample the band index ``(R651 - R707) / (R707 - R670)`` is
    evaluated with the polynomial stored in the model file.

    Args:
        model_info_path: Path handed to ``load_numpy_dict_from_json``; the
            second value it returns is used as polynomial coefficients.
        coor_spectral_path: Path of the text table described above.
        output_path: Destination of the tab-separated result file.
        window: Half-width, in columns, of the band-averaging window.

    Returns:
        2-D array whose rows are the two leading columns of each sample
        followed by the retrieved value.
    """
    _, poly_coefficients, _ = load_numpy_dict_from_json(model_info_path)
    table = np.loadtxt(coor_spectral_path)
    wavelengths = table[0, :]

    def band_mean(target):
        # Mean of the columns around the one closest to ``target``.
        return get_mean_value(find_index(target, wavelengths), table, window)

    r651 = band_mean(651)
    r707 = band_mean(707)
    r670 = band_mean(670)

    band_index = (r651 - r707) / (r707 - r670)
    estimates = np.polyval(poly_coefficients, band_index)

    estimate_column = estimates.reshape(len(estimates), 1)
    rows = np.hstack((table[1:, 0:2], estimate_column))
    np.savetxt(output_path, rows, fmt='%.4f', delimiter="\t")

    return rows
|
|
|
|
|
|
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Run the two-predictor linear retrieval (registered for nh3, mno4, tn, tp).

    The spectral table is read with ``np.loadtxt``. Its first row is searched
    for the columns nearest to 600, 500 and 850; every following row is one
    sample whose first two columns are copied unchanged into the output.
    The predictors ``ln(R500 / R850)`` and ``exp(R600 / R500)`` are passed to
    ``calculate`` together with the coefficients from the model file.

    Args:
        model_info_path: Path handed to ``load_numpy_dict_from_json``; the
            second value it returns is used as the model coefficients.
        coor_spectral_path: Path of the text table described above.
        output_path: Destination of the tab-separated result file. Nothing is
            written when it is ``None``.
        window: Half-width, in columns, of the band-averaging window.

    Returns:
        2-D array whose rows are the two leading columns of each sample
        followed by the retrieved value.
    """
    _, coefficients, _ = load_numpy_dict_from_json(model_info_path)
    table = np.loadtxt(coor_spectral_path)
    wavelengths = table[0, :]

    # Window-averaged values around each target wavelength.
    r600, r500, r850 = (
        get_mean_value(find_index(target, wavelengths), table, window)
        for target in (600, 500, 850)
    )

    log_ratio = np.log(r500 / r850)
    exp_ratio = np.exp(r600 / r500)
    estimates = calculate(log_ratio, exp_ratio, coefficients)

    estimate_column = estimates.reshape((estimates.shape[0], 1))
    rows = np.hstack((table[1:, 0:2], estimate_column))
    if output_path is not None:
        np.savetxt(output_path, rows, fmt='%.4f', delimiter="\t")

    return rows
|
|
|
|
|
|
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
    """Run the total-suspended-solids retrieval and write the result table.

    Reuses ``retrieval_nh3`` without an output path (so it writes nothing),
    then replaces the last column of its result with the exponential of that
    column before saving.

    Args:
        model_info_path: Path of the model file, forwarded to ``retrieval_nh3``.
        coor_spectral_path: Path of the spectral table, forwarded as well.
        output_path: Destination of the tab-separated result file.
        window: Half-width, in columns, of the band-averaging window.

    Returns:
        The array returned by ``retrieval_nh3`` with its last column
        exponentiated.
    """
    rows = retrieval_nh3(model_info_path, coor_spectral_path, window=window)

    # NOTE(review): presumably the model was fitted on the logarithm of the
    # target quantity, hence the exp() - confirm against the fitting code.
    rows[:, -1] = np.exp(rows[:, -1])

    np.savetxt(output_path, rows, fmt='%.4f', delimiter="\t")

    return rows
|
|
|
|
|
|
def main():
    """Command-line entry point: parse the arguments and run one retrieval.

    One sub-command is registered per water-quality parameter. All of them
    take the same four options and differ only in the retrieval function
    they dispatch to:

        chl_a            -> retrieval_chl_a
        nh3, mno4, tn, tp -> retrieval_nh3
        tss              -> retrieval_tss

    The selected function is called with the model file path, the
    coordinate/spectrum file path, the output path and the averaging radius.
    """
    parser = argparse.ArgumentParser(description="此程序使用修正后的模型进行反演。")

    # One sub-parser per algorithm; choosing one is mandatory.
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    # (sub-command, help text, retrieval function). Keeping this in a table
    # avoids six copies of the same argument definitions drifting apart.
    algorithms = (
        ("chl_a", "叶绿素", retrieval_chl_a),
        ("nh3", "氨氮", retrieval_nh3),
        ("mno4", "高锰酸盐", retrieval_nh3),
        ("tn", "总氮", retrieval_nh3),
        ("tp", "总磷", retrieval_nh3),
        ("tss", "总悬浮物", retrieval_tss),
    )

    for name, help_text, retrieval_func in algorithms:
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
        sub.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
                         help='输入坐标-光谱文件的路径')
        sub.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
        sub.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
        sub.set_defaults(func=retrieval_func)

    # Every sub-command shares one call signature, so a single dispatch is
    # enough; argparse has already rejected unknown or missing sub-commands.
    args = parser.parse_args()
    args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|