Files
water_content_retrieval/model_correction.py
2025-01-06 10:18:08 +08:00

462 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import os
from osgeo import gdal
from util import *
from type_define import *
import math
from pyproj import CRS
from pyproj import Transformer
import argparse
import json
def _sample_glint_on_window(data_severe_glint, inv_geotransform_severe_glint, geotransform,
                            x_off, y_off, shape):
    """Resample the glint mask (nearest neighbour) onto a window of the input-image grid.

    The centre of every pixel of the image window whose top-left pixel is (x_off, y_off) and whose
    size is ``shape`` = (rows, cols) is converted to map coordinates with ``geotransform`` and then
    into glint-raster pixel coordinates with ``inv_geotransform_severe_glint``. Window pixels that
    fall outside the glint raster get 0 (treated as "not glint").

    :return: array of ``shape`` with the glint value for every window pixel.
    """
    rows, cols = shape
    # Pixel-centre coordinates of the window in the input image, both of shape (rows, cols).
    px, py = np.meshgrid(x_off + np.arange(cols) + 0.5, y_off + np.arange(rows) + 0.5)
    gt = geotransform
    map_x = gt[0] + px * gt[1] + py * gt[2]
    map_y = gt[3] + px * gt[4] + py * gt[5]
    igt = inv_geotransform_severe_glint
    gx = np.floor(igt[0] + map_x * igt[1] + map_y * igt[2]).astype(int)
    gy = np.floor(igt[3] + map_x * igt[4] + map_y * igt[5]).astype(int)
    inside = ((gx >= 0) & (gx < data_severe_glint.shape[1])
              & (gy >= 0) & (gy < data_severe_glint.shape[0]))
    glint_local = np.zeros(shape, dtype=data_severe_glint.dtype)
    glint_local[inside] = data_severe_glint[gy[inside], gx[inside]]
    return glint_local


def get_coor_base_lang_lat(img_path, water_mask, severe_glint, csv_pos, coor_type, radius,
                           point_pos_strategy=PointPosStrategy.nearest_single, img_type=ImgType.ref):
    """Snap each measured sample position to a usable (water, glint-free, non-zero) image pixel.

    For every sample a square window around its position is read from one band of the image.
    Pixels that are 0 in the water mask, 1 in the severe-glint mask or already 0 in the image are
    marked invalid. ``getnearest`` then picks a pixel from what remains; presumably the valid
    pixel closest to the window centre (see util).

    :param img_path: input image. Its no-data value should be 0, and its .hdr file should not
        carry a .bil-style suffix.
    :param water_mask: single-band water mask. Pixels equal to 0 are excluded. It is read with
        the same pixel offsets as ``img_path``, so it is assumed to share the image's grid.
    :param severe_glint: single-band raster in which 1 marks severe sun glint (excluded). It may
        have a different resolution than ``img_path``.
    :param csv_pos: tab-delimited text file with one sample per row. Columns 0/1 hold the
        position: longitude/latitude for CoorType.latlong, or map x/y for CoorType.utm. All
        columns (e.g. the measured value in column 2) are passed through to the output.
    :param coor_type: CoorType enum (latlong or utm).
    :param radius: search radius in metres (see the NOTE on the window size in the body).
    :param point_pos_strategy: PointPosStrategy enum. Only nearest_single is implemented.
    :param img_type: ImgType enum. For ImgType.ref the band returned by
        ``find_band_number(560, ...)`` (green) is used; for ImgType.content the first band is used.
    :return: 2-D array holding the input columns plus 4 appended columns: map x, map y, pixel
        column and pixel row of the selected pixel. A sample is dropped when no valid pixel is
        found, when its window is not fully inside the image/water mask, or when its row
        contains NaN.
    :raises NotImplementedError: for any point_pos_strategy other than nearest_single.
    :raises ValueError: for an unsupported coor_type.
    :raises RuntimeError: if GDAL cannot open one of the rasters.
    """
    if point_pos_strategy != PointPosStrategy.nearest_single:
        # four_quadrant was never implemented. It used to leave every row empty and silently
        # return no samples, so fail loudly instead.
        raise NotImplementedError(f"point_pos_strategy {point_pos_strategy} is not implemented")
    if coor_type not in (CoorType.latlong, CoorType.utm):
        raise ValueError(f"unsupported coor_type: {coor_type}")

    dataset = gdal.Open(img_path)
    dataset_water_mask = gdal.Open(water_mask)
    dataset_severe_glint = gdal.Open(severe_glint)
    try:
        for path, ds in ((img_path, dataset), (water_mask, dataset_water_mask),
                         (severe_glint, dataset_severe_glint)):
            if ds is None:
                raise RuntimeError(f"GDAL could not open {path}")

        im_width = dataset.RasterXSize  # number of columns
        im_height = dataset.RasterYSize  # number of rows
        geotransform = dataset.GetGeoTransform()  # affine pixel -> map transform
        inv_geotransform_input = gdal.InvGeoTransform(geotransform)

        # Band used to decide which pixels hold valid (non-zero) data.
        band_number_used = 1
        if img_type == ImgType.ref:
            green_wave = 560  # green: strongest signal over water
            band_number_used = find_band_number(green_wave, img_path)
        elif img_type == ImgType.content:
            band_number_used = 0

        # ndmin=2 keeps a single-sample file as a 1 x N table instead of a 1-D row.
        position_content = np.loadtxt(csv_pos, delimiter="\t", ndmin=2)
        n_samples = position_content.shape[0]
        new_columns_num = 4  # map x, map y, pixel column, pixel row
        position_content = np.hstack((position_content, np.zeros((n_samples, new_columns_num))))
        first_new_column = position_content.shape[1] - new_columns_num
        # Rows for which a pixel was selected. Tracked explicitly because 0 is a valid pixel row.
        found = np.zeros(n_samples, dtype=bool)

        geotransform_severe_glint = dataset_severe_glint.GetGeoTransform()
        inv_geotransform_severe_glint = gdal.InvGeoTransform(geotransform_severe_glint)
        data_severe_glint = dataset_severe_glint.GetRasterBand(1).ReadAsArray()

        if coor_type == CoorType.latlong and n_samples > 0:
            im_proj = dataset.GetProjection()
            if im_proj:
                # Project the GPS points straight into the image's own CRS. The window lookup
                # below needs coordinates in the image's CRS, whatever its zone or datum.
                target_crs = CRS.from_wkt(im_proj)
            else:
                # No projection metadata: fall back to the WGS84 UTM zone of the first sample
                # (north 326xx / south 327xx).
                lon0, lat0 = position_content[0, 0], position_content[0, 1]
                utm_zone = math.floor(lon0 / 6) + 31
                target_crs = CRS.from_epsg((32600 if lat0 >= 0 else 32700) + utm_zone)
            # always_xy=True: inputs are (longitude, latitude) regardless of the CRS axis order.
            transformer = Transformer.from_crs(CRS.from_epsg(4326), target_crs, always_xy=True)

        # NOTE(review): the half-width is ceil(radius / pixel_size / 2) pixels, so `radius` acts
        # more like a window diameter. Kept as-is; confirm the intended meaning.
        radius_pixel_data = math.ceil(radius / geotransform[1] / 2)
        window = 2 * radius_pixel_data + 1
        max_width = min(im_width, dataset_water_mask.RasterXSize)
        max_height = min(im_height, dataset_water_mask.RasterYSize)
        data_band = dataset.GetRasterBand(band_number_used + 1)
        water_band = dataset_water_mask.GetRasterBand(1)
        same_resolution = geotransform[1] == geotransform_severe_glint[1]
        invalid_value = 0

        for i in range(n_samples):
            if coor_type == CoorType.latlong:
                map_x, map_y = transformer.transform(position_content[i, 0], position_content[i, 1])
            else:
                map_x, map_y = position_content[i, 0], position_content[i, 1]
            pixel_x, pixel_y = gdal.ApplyGeoTransform(inv_geotransform_input, map_x, map_y)
            if not (math.isfinite(pixel_x) and math.isfinite(pixel_y)):
                continue
            # floor (not int) so points just left of / above the image do not snap onto pixel 0.
            x_off = math.floor(pixel_x) - radius_pixel_data
            y_off = math.floor(pixel_y) - radius_pixel_data
            # ReadAsArray cannot read outside the raster, so skip windows that are not fully inside.
            if x_off < 0 or y_off < 0 or x_off + window > max_width or y_off + window > max_height:
                continue
            data_local = data_band.ReadAsArray(x_off, y_off, window, window)
            water_mask_local = water_band.ReadAsArray(x_off, y_off, window, window)

            severe_glint_local = None
            if same_resolution:
                # Fast path: cut the same-sized window around the sample in the glint raster.
                glint_x, glint_y = gdal.ApplyGeoTransform(inv_geotransform_severe_glint, map_x, map_y)
                glint_x0 = math.floor(glint_x) - radius_pixel_data
                glint_y0 = math.floor(glint_y) - radius_pixel_data
                if glint_x0 >= 0 and glint_y0 >= 0:
                    candidate = data_severe_glint[glint_y0:glint_y0 + window,
                                                  glint_x0:glint_x0 + window]
                    if candidate.shape == data_local.shape:
                        severe_glint_local = candidate
            if severe_glint_local is None:
                # Different resolution, or the fast-path window hits the glint raster edge.
                severe_glint_local = _sample_glint_on_window(
                    data_severe_glint, inv_geotransform_severe_glint, geotransform,
                    x_off, y_off, data_local.shape)

            data_local[severe_glint_local == 1] = invalid_value
            data_local[water_mask_local == 0] = invalid_value

            coor_y_local, coor_x_local = getnearest(data_local, invalid_value)
            if coor_x_local is None:
                continue
            coor_x = coor_x_local + x_off
            coor_y = coor_y_local + y_off
            # Map coordinates of the selected pixel's top-left corner.
            utm_x, utm_y = gdal.ApplyGeoTransform(geotransform, coor_x, coor_y)
            position_content[i, first_new_column:first_new_column + new_columns_num] = (
                utm_x, utm_y, int(coor_x), int(coor_y))
            found[i] = True

        result = position_content[found]
        return result[~np.isnan(result).any(axis=1)]
    finally:
        # Release the references so GDAL closes the files even if an exception was raised.
        dataset = dataset_water_mask = dataset_severe_glint = None
def fit(x1, x2, y):
    """Least-squares fit of ``y ≈ c0 * x1 + c1 * x2 + c2``.

    :param x1: first feature, one value per sample (1-D array).
    :param x2: second feature, same length as x1 (1-D array).
    :param y: target values, same length as x1.
    :return: coefficient array [c0, c1, c2].
    """
    n_samples = x2.shape[0]
    design = np.column_stack([x1, x2, np.ones(n_samples)])
    solution = np.linalg.lstsq(design, y, rcond=None)
    return solution[0]
def accuracy_evaluation(x1, x2, y_real, coefficients):
    """Per-sample absolute percentage error of the linear two-feature model.

    The prediction is ``coefficients[0] * x1 + coefficients[1] * x2 + coefficients[2]``, and the
    error is ``|(y_real - prediction) / y_real * 100|``. It is inf/NaN where y_real is 0.
    """
    design = np.column_stack([x1, x2, np.ones(x2.shape[0])])
    prediction = design @ coefficients
    return np.abs((y_real - prediction) / y_real * 100)
def accuracy_evaluation_tss(x1, x2, y_real, coefficients):
    """Per-sample absolute percentage error of the log-linear TSS model.

    The model predicts ``exp(coefficients[0] * x1 + coefficients[1] * x2 + coefficients[2])``,
    and the error is ``|(y_real - prediction) / y_real * 100|``.
    """
    design = np.column_stack([x1, x2, np.ones(x2.shape[0])])
    prediction = np.exp(design @ coefficients)
    return np.abs((y_real - prediction) / y_real * 100)
def get_x_in_coor(coor, *args):
    """Append, for each image in ``args``, its value at every sample's pixel position.

    The last column of ``coor`` is the pixel row and the second-to-last is the pixel column.
    Both are truncated to int, so each image is sampled as ``image[row, col]``.

    :param coor: 2-D sample table whose last two columns are pixel column and pixel row.
    :param args: 2-D images to sample.
    :return: ``coor`` with one extra float column per image, in the order given.
    """
    n_rows = coor.shape[0]
    extra = np.zeros((n_rows, len(args)))
    if args and n_rows:
        pixel_rows = [int(v) for v in coor[:, -1]]
        pixel_cols = [int(v) for v in coor[:, -2]]
        for column, image in enumerate(args):
            extra[:, column] = image[pixel_rows, pixel_cols]
    return np.hstack((coor, extra))
def write_model_info(model_type, coefficients, accuracy, long, lat, outpath):
    """Write a fitted model's description to a JSON file (4-space indentation).

    :param model_type: model name stored under 'model_type'.
    :param coefficients: NumPy array of fitted coefficients, stored under 'model_info'.
    :param accuracy: NumPy array of per-sample errors, stored under 'accuracy'.
    :param long: NumPy array of sample x / longitude positions, stored under 'long'.
    :param lat: NumPy array of sample y / latitude positions, stored under 'lat'.
    :param outpath: destination file path (overwritten).
    """
    # NumPy arrays are not JSON-serialisable, so convert each to a plain list first.
    payload = dict(
        model_type=model_type,
        model_info=coefficients.tolist(),
        accuracy=accuracy.tolist(),
        long=long.tolist(),
        lat=lat.tolist(),
    )
    with open(outpath, 'w') as handle:
        handle.write(json.dumps(payload, indent=4))
def chl_a(img_path, coor, outpath_coeff, window=5):  # chlorophyll-a
    """Re-fit the chlorophyll-a model against measured samples and save it as JSON.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Feature: x = (R651 - R707) / (R707 - R670).
    Model: y = a * x + b, fitted by least squares with numpy.polyfit.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the feature is sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: polyfit coefficients [a, b] (slope, intercept).
    :raises ValueError: if no sample has a finite feature and a finite measured value.
    """
    wave1 = 651
    band_651 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 707
    band_707 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 670
    band_670 = average_bands(wave3 - window, wave3 + window, img_path)
    x = (band_651 - band_707) / (band_707 - band_670)
    coor_extend = get_x_in_coor(coor, x)
    # Drop samples whose feature or measured value is NaN/inf (e.g. where R707 == R670). A single
    # such row makes the least-squares fit fail.
    coor_extend = coor_extend[np.isfinite(coor_extend[:, [2, -1]]).all(axis=1)]
    if coor_extend.shape[0] == 0:
        raise ValueError("chl-a: no sample with a finite feature and measured value to fit")
    # Fit the model parameters and write them out.
    x = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = np.polyfit(x, y_real, 1)
    y_pred = np.polyval(coefficients, x)
    accuracy = np.absolute((y_real - y_pred) / y_real * 100)
    write_model_info("chl-a", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def nh3(img_path, coor, outpath_coeff, window=5):  # ammonia nitrogen
    """Re-fit the ammonia-nitrogen model against measured samples and save it as JSON.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Features: x1 = ln(R500 / R850), x2 = exp(R600 / R500).
    Model: y = c0 * x1 + c1 * x2 + c2, fitted by least squares (fit). The per-sample absolute
    percentage error comes from accuracy_evaluation.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the features are sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: coefficients [c0, c1, c2].
    :raises ValueError: if no sample has finite features and a finite measured value.
    """
    wave1 = 600
    band_600 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 500
    band_500 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 850
    band_850 = average_bands(wave3 - window, wave3 + window, img_path)
    x13 = np.log(band_500 / band_850)
    x23 = np.exp(band_600 / band_500)
    coor_extend = get_x_in_coor(coor, x13, x23)
    # Drop samples whose features or measured value are NaN/inf. Causes include a non-positive
    # ratio inside the log, R500 == 0, or exp overflow. One such row breaks the fit.
    coor_extend = coor_extend[np.isfinite(coor_extend[:, [2, -2, -1]]).all(axis=1)]
    if coor_extend.shape[0] == 0:
        raise ValueError("nh3: no sample with finite features and measured value to fit")
    # Fit the model parameters and write them out.
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = fit(x1, x2, y_real)
    accuracy = accuracy_evaluation(x1, x2, y_real, coefficients)
    write_model_info("nh3", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def mno4(img_path, coor, outpath_coeff, window=5):  # permanganate
    """Re-fit the permanganate model against measured samples and save it as JSON.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Features: x1 = R500 / R440, x2 = R610 / R800.
    Model: y = c0 * x1 + c1 * x2 + c2, fitted by least squares (fit). The per-sample absolute
    percentage error comes from accuracy_evaluation.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the features are sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: coefficients [c0, c1, c2].
    :raises ValueError: if no sample has finite features and a finite measured value.
    """
    wave1 = 500
    band_500 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 440
    band_440 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 610
    band_610 = average_bands(wave3 - window, wave3 + window, img_path)
    wave4 = 800
    band_800 = average_bands(wave4 - window, wave4 + window, img_path)
    x3 = band_500 / band_440
    x6 = band_610 / band_800
    coor_extend = get_x_in_coor(coor, x3, x6)
    # Drop samples whose features or measured value are NaN/inf (e.g. a zero denominator band).
    # One such row breaks the least-squares fit.
    coor_extend = coor_extend[np.isfinite(coor_extend[:, [2, -2, -1]]).all(axis=1)]
    if coor_extend.shape[0] == 0:
        raise ValueError("mno4: no sample with finite features and measured value to fit")
    # Fit the model parameters and write them out.
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = fit(x1, x2, y_real)
    accuracy = accuracy_evaluation(x1, x2, y_real, coefficients)
    write_model_info("mno4", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def tn(img_path, coor, outpath_coeff, window=5):  # total nitrogen
    """Re-fit the total-nitrogen model against measured samples and save it as JSON.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Features: x1 = ln(R500 / R850), x2 = exp(R600 / R500).
    Model: y = c0 * x1 + c1 * x2 + c2, fitted by least squares (fit). The per-sample absolute
    percentage error comes from accuracy_evaluation.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the features are sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: coefficients [c0, c1, c2].
    :raises ValueError: if no sample has finite features and a finite measured value.
    """
    wave1 = 600
    band_600 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 500
    band_500 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 850
    band_850 = average_bands(wave3 - window, wave3 + window, img_path)
    x13 = np.log(band_500 / band_850)
    x23 = np.exp(band_600 / band_500)
    coor_extend = get_x_in_coor(coor, x13, x23)
    # Drop samples whose features or measured value are NaN/inf. Causes include a non-positive
    # ratio inside the log, R500 == 0, or exp overflow. One such row breaks the fit.
    coor_extend = coor_extend[np.isfinite(coor_extend[:, [2, -2, -1]]).all(axis=1)]
    if coor_extend.shape[0] == 0:
        raise ValueError("tn: no sample with finite features and measured value to fit")
    # Fit the model parameters and write them out.
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = fit(x1, x2, y_real)
    accuracy = accuracy_evaluation(x1, x2, y_real, coefficients)
    write_model_info("tn", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def tp(img_path, coor, outpath_coeff, window=5):  # total phosphorus
    """Re-fit the total-phosphorus model against measured samples and save it as JSON.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Features: x1 = ln(R500 / R850), x2 = exp(R600 / R500).
    Model: y = c0 * x1 + c1 * x2 + c2, fitted by least squares (fit). The per-sample absolute
    percentage error comes from accuracy_evaluation.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the features are sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: coefficients [c0, c1, c2].
    :raises ValueError: if no sample has finite features and a finite measured value.
    """
    wave1 = 600
    band_600 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 500
    band_500 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 850
    band_850 = average_bands(wave3 - window, wave3 + window, img_path)
    x13 = np.log(band_500 / band_850)
    x23 = np.exp(band_600 / band_500)
    coor_extend = get_x_in_coor(coor, x13, x23)
    # Drop samples whose features or measured value are NaN/inf. Causes include a non-positive
    # ratio inside the log, R500 == 0, or exp overflow. One such row breaks the fit.
    coor_extend = coor_extend[np.isfinite(coor_extend[:, [2, -2, -1]]).all(axis=1)]
    if coor_extend.shape[0] == 0:
        raise ValueError("tp: no sample with finite features and measured value to fit")
    # Fit the model parameters and write them out.
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    coefficients = fit(x1, x2, y_real)
    accuracy = accuracy_evaluation(x1, x2, y_real, coefficients)
    write_model_info("tp", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def tss(img_path, coor, outpath_coeff, window=5):  # total suspended solids
    """Re-fit the total-suspended-solids (TSS) model against measured samples and save it as JSON.

    NOTE(review): the original author marked this model with "??????", so it may not be validated.

    Each R<λ> below is ``average_bands(λ - window, λ + window, img_path)``, presumably the mean of
    the image bands within that wavelength range (see util).
    Features: x1 = R555 + R670, x2 = R490 / R555.
    Model: ln(y) = c0 * x1 + c1 * x2 + c2, fitted by least squares on ln(y) (fit). The error is
    evaluated on y = exp(...) with accuracy_evaluation_tss.

    :param img_path: image passed to average_bands.
    :param coor: sample table from get_coor_base_lang_lat. Columns 0/1 are the sample position
        (saved as 'long'/'lat'), column 2 is the measured value, and the last two columns are the
        pixel column/row at which the features are sampled.
    :param outpath_coeff: path of the JSON model-info file to write (see write_model_info).
    :param window: half-width of the wavelength window around each centre wavelength.
    :return: coefficients [c0, c1, c2] of the log-linear model.
    :raises ValueError: if no sample has finite features and a positive, finite measured value.
    """
    wave1 = 555
    band_555 = average_bands(wave1 - window, wave1 + window, img_path)
    wave2 = 670
    band_670 = average_bands(wave2 - window, wave2 + window, img_path)
    wave3 = 490
    band_490 = average_bands(wave3 - window, wave3 + window, img_path)
    x1 = band_555 + band_670
    x2 = band_490 / band_555
    coor_extend = get_x_in_coor(coor, x1, x2)
    # Keep only samples with finite features and a positive, finite measured value. ln(y) is
    # fitted, so y <= 0 would give -inf/NaN and break the least-squares fit.
    finite = np.isfinite(coor_extend[:, [2, -2, -1]]).all(axis=1)
    coor_extend = coor_extend[finite & (coor_extend[:, 2] > 0)]
    if coor_extend.shape[0] == 0:
        raise ValueError("tss: no sample with finite features and positive measured value to fit")
    # Fit the model parameters and write them out.
    x1 = coor_extend[:, -2]
    x2 = coor_extend[:, -1]
    y_real = coor_extend[:, 2]
    y = np.log(y_real)
    coefficients = fit(x1, x2, y)
    accuracy = accuracy_evaluation_tss(x1, x2, y_real, coefficients)
    write_model_info("tss", coefficients, accuracy, coor_extend[:, 0], coor_extend[:, 1], outpath_coeff)
    return coefficients
def main():
    """Command-line entry point: re-fit one water-quality model against measured data.

    Every sub-command (chl_a, nh3, mno4, tn, tp, tss) takes the same options:
    - the image, the water mask and the severe-glint mask;
    - the measured-data file (longitude/latitude in columns 0/1);
    - a search radius;
    - the output path of the model-info JSON file.
    Sample positions are snapped to image pixels with get_coor_base_lang_lat, always as
    CoorType.latlong. The sub-command's fitting function is then run on the result.
    """
    parser = argparse.ArgumentParser(description="此程序用于通过实测数据修正模型参数。")
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")
    # (sub-command, help text, fitting function). All sub-commands share the same options.
    algorithms = (
        ("chl_a", "叶绿素", chl_a),    # chlorophyll-a
        ("nh3", "氨氮", nh3),          # ammonia nitrogen
        ("mno4", "高锰酸盐", mno4),    # permanganate (help text was misspelled as 高猛酸盐)
        ("tn", "总氮", tn),            # total nitrogen
        ("tp", "总磷", tp),            # total phosphorus
        ("tss", "总悬浮物", tss),      # total suspended solids
    )
    for name, help_text, func in algorithms:
        sub = subparsers.add_parser(name, help=help_text)
        sub.add_argument('-i1', '--img', type=str, required=True, help='输入影像文件的路径')
        sub.add_argument('-i2', '--water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
        sub.add_argument('-i3', '--severe_glint', type=str, required=True, help='输入耀斑严重文件的路径')
        sub.add_argument('-i4', '--measured_data', type=str, required=True, help='输入实测含量数据的路径')
        sub.add_argument('-i5', '--radius', type=float, default=2.0, help='输入实测坐标半径')
        sub.add_argument('-o', '--model_info_outpath', required=True, type=str, help='输出模型信息文件的路径')
        sub.set_defaults(func=func)

    args = parser.parse_args()
    samples = get_coor_base_lang_lat(args.img, args.water_mask, args.severe_glint, args.measured_data,
                                     CoorType.latlong, args.radius)
    # Every fitting function has the signature (img_path, coor, outpath_coeff), so no per-mode
    # dispatch is needed.
    args.func(args.img, samples, args.model_info_outpath)
# Run the command-line interface only when executed as a script, not when imported.
if __name__ == "__main__":
    main()