Files
water_content_retrieval/retrieval.py
2025-01-06 10:18:08 +08:00

170 lines
7.5 KiB
Python

import numpy as np
import os
from osgeo import gdal
from util import *
from type_define import *
import math
from pyproj import CRS
from pyproj import Transformer
import argparse
def find_index(wavelength, array):
    """Return the position of the entry in ``array`` closest to ``wavelength``.

    Closeness is the absolute difference. On a tie the first matching
    position wins (``argmin`` semantics).
    """
    return np.abs(array - wavelength).argmin()
def get_mean_value(index, array, window):
    """Average every row except row 0 over a column window centred on ``index``.

    Callers in this module pass a coordinate/spectrum table whose row 0 is the
    wavelength axis. Row 0 is therefore excluded, and each remaining row is
    averaged over columns ``index - window`` .. ``index + window`` inclusive.

    Parameters
    ----------
    index : int
        Centre column (e.g. the result of ``find_index``).
    array : np.ndarray
        2-D array; row 0 is skipped.
    window : int or float
        Half-width of the window in columns; truncated with ``int()``.

    Returns
    -------
    np.ndarray
        One mean per row of ``array[1:]``.

    Raises
    ------
    ValueError
        If the window would start before column 0. Previously the negative
        slice start wrapped around to the end of the row, which silently
        produced an empty window (NaN means) or a wrong one.
    """
    window = int(window)
    start = index - window
    if start < 0:
        raise ValueError(
            f"averaging window of +/-{window} columns around column {index} "
            "extends before the first column"
        )
    # The upper bound may exceed the row width; slicing truncates it, so near
    # the right edge fewer columns are averaged.
    return array[1:, start:index + window + 1].mean(axis=1)
def calculate(x1, x2, coefficients):
    """Evaluate a two-feature linear model with intercept, sample by sample.

    The design matrix has columns ``[x1, x2, 1]``. ``coefficients`` is
    therefore ordered (weight of x1, weight of x2, intercept).
    """
    n_samples = x2.shape[0]
    intercept_column = np.ones((n_samples, 1))
    design = np.column_stack((x1, x2, intercept_column))
    return np.dot(design, coefficients)
def retrieval_chl_a(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Retrieve chlorophyll-a from a three-band index with a polynomial model.

    The index is ``x = (R651 - R707) / (R707 - R670)``. Each ``R`` is the
    per-row mean around the column nearest that wavelength (see
    ``get_mean_value``). The model coefficients are then applied with
    ``np.polyval``, highest power first.

    Parameters
    ----------
    model_info_path : str
        JSON model file, read with ``load_numpy_dict_from_json``. Only the
        second element of its returned 3-tuple (the coefficients) is used.
    coor_spectral_path : str
        Whitespace-delimited table loaded with ``np.loadtxt``. Row 0 is
        searched for wavelengths. Columns 0-1 of the remaining rows are copied
        to the output as positions.
    output_path : str or None
        If given, the result is written there as tab-delimited ``%.4f`` text.
        Optional for consistency with ``retrieval_nh3``; previously it was
        required.
    window : int or float
        Half-width, in columns, of the averaging window.

    Returns
    -------
    np.ndarray
        ``(n_rows, 3)`` array: the two position columns plus the retrieved value.
    """
    _, model_info, _ = load_numpy_dict_from_json(model_info_path)
    coor_spectral = np.loadtxt(coor_spectral_path)
    wavelengths = coor_spectral[0, :]

    def band_mean(wavelength):
        # Mean spectrum value around the column closest to ``wavelength``.
        return get_mean_value(find_index(wavelength, wavelengths), coor_spectral, window)

    band_651 = band_mean(651)
    band_707 = band_mean(707)
    band_670 = band_mean(670)
    # NOTE(review): rows where band_707 == band_670 give inf/nan (numpy warns).
    x = (band_651 - band_707) / (band_707 - band_670)
    retrieval_result = np.polyval(model_info, x)
    position_content = np.hstack((coor_spectral[1:, 0:2], retrieval_result.reshape((-1, 1))))
    if output_path is not None:
        np.savetxt(output_path, position_content, fmt='%.4f', delimiter="\t")
    return position_content
def retrieval_nh3(model_info_path, coor_spectral_path, output_path=None, window=5):
    """Apply the generic three-band linear model (used for NH3, MnO4, TN, TP, TSS).

    Band means are taken around 600, 500 and 850 (units of row 0 of the
    table). The two features are ``ln(R500 / R850)`` and ``exp(R600 / R500)``.
    They are combined by ``calculate`` with the model coefficients: x1
    weight, x2 weight, intercept.

    Parameters
    ----------
    model_info_path : str
        JSON model file. The coefficients are the second element of the
        3-tuple returned by ``load_numpy_dict_from_json``.
    coor_spectral_path : str
        Table for ``np.loadtxt``. Row 0 is searched for wavelengths, and
        columns 0-1 of later rows are copied out as positions.
    output_path : str or None
        When not None, the result is saved there as tab-separated ``%.4f``.
    window : int or float
        Half-width, in columns, of the averaging window.

    Returns
    -------
    np.ndarray
        Position columns with the model prediction appended as a last column.
    """
    _model_type, coefficients, _accuracy = load_numpy_dict_from_json(model_info_path)
    data = np.loadtxt(coor_spectral_path)
    axis = data[0, :]

    band_600, band_500, band_850 = (
        get_mean_value(find_index(target, axis), data, window)
        for target in (600, 500, 850)
    )

    predicted = calculate(
        np.log(band_500 / band_850),
        np.exp(band_600 / band_500),
        coefficients,
    )
    table = np.hstack((data[1:, 0:2], predicted.reshape((predicted.shape[0], 1))))

    if output_path is None:
        return table
    np.savetxt(output_path, table, fmt='%.4f', delimiter="\t")
    return table
def retrieval_tss(model_info_path, coor_spectral_path, output_path, window=5):
    """Retrieve total suspended solids with the generic three-band model.

    Runs ``retrieval_nh3`` without writing a file. It then replaces the last
    (prediction) column with its exponential, saves the table to
    ``output_path`` as tab-separated ``%.4f`` text, and returns it.

    NOTE(review): the exponential presumably undoes a model fitted on
    log-transformed TSS; confirm against the training code.
    """
    table = retrieval_nh3(model_info_path, coor_spectral_path, window=window)
    table[:, -1] = np.exp(table[:, -1])
    np.savetxt(output_path, table, fmt='%.4f', delimiter="\t")
    return table
def main():
    """Command-line entry point: pick a retrieval sub-command and run it.

    Every sub-command accepts the same options:
    ``-i1/--model_info_path``, ``-i2/--coor_spectral_path``,
    ``-i3/--wave_radius`` (default 5.0) and ``-o/--outpath``. The selected
    function is called as ``func(model_info_path, coor_spectral_path,
    outpath, wave_radius)``.
    """
    parser = argparse.ArgumentParser(description="此程序使用修正后的模型进行反演。")
    # Sub-command parser; a sub-command is mandatory.
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    # (sub-command, help text, retrieval function). mno4 / tn / tp share the
    # generic three-band model implemented by retrieval_nh3.
    algorithms = (
        ("chl_a", "叶绿素", retrieval_chl_a),        # chlorophyll-a
        ("nh3", "氨氮", retrieval_nh3),              # ammonia nitrogen
        ("mno4", "高锰酸盐", retrieval_nh3),         # permanganate (typo 猛 -> 锰 fixed)
        ("tn", "总氮", retrieval_nh3),               # total nitrogen
        ("tp", "总磷", retrieval_nh3),               # total phosphorus
        ("tss", "总悬浮物", retrieval_tss),          # total suspended solids
    )
    for name, help_text, func in algorithms:
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('-i1', '--model_info_path', type=str, required=True, help='输入模型信息文件的路径')
        sub.add_argument('-i2', '--coor_spectral_path', type=str, required=True,
                         help='输入坐标-光谱文件的路径')
        # NOTE(review): get_mean_value truncates this with int() and uses it as a
        # half-width in columns (bands), not in wavelength units; confirm the
        # intended unit of the "wavelength averaging radius".
        sub.add_argument('-i3', '--wave_radius', type=float, default=5.0, help='输入波长平均半径')
        sub.add_argument('-o', '--outpath', required=True, type=str, help='输出文件的路径')
        sub.set_defaults(func=func)

    args = parser.parse_args()
    # Every sub-command shares one call signature, so a single dispatch suffices.
    args.func(args.model_info_path, args.coor_spectral_path, args.outpath, args.wave_radius)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()