Files
water_content_retrieval/deglint_regression_slope.py
2025-01-06 10:18:08 +08:00

213 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from util import *
from osgeo import gdal
import json
from shapely.geometry import Point, Polygon
import xml.etree.ElementTree as ET
def parse_xml(roifilepath_xml):
    """Load polygon ROIs from an XML ROI file.

    Every ``Polygon`` element found below a ``Region`` element must contain a
    ``Coordinates`` element whose text is a whitespace-separated list of numbers
    ``x1 y1 x2 y2 ...``; each such list becomes one shapely ``Polygon``.

    The coordinate system is not read from the file: the result is always tagged
    ``CoorType.depend_on_image``. NOTE(review): if some XML files carry pixel
    coordinates instead, this tag would be wrong — confirm against the files in use.

    :param roifilepath_xml: path of the XML ROI file.
    :return: ``(CoorType.depend_on_image, polygons)`` with one Polygon per
        ``Polygon`` element, in document order.
    """
    document_root = ET.parse(roifilepath_xml).getroot()
    result = []
    for region_elem in document_root.findall(".//Region"):
        for polygon_elem in region_elem.findall(".//Polygon"):
            raw_text = polygon_elem.find(".//Coordinates").text
            numbers = [float(token) for token in raw_text.strip().split()]
            # Pair consecutive numbers into (x, y) vertices.
            vertices = [(numbers[k], numbers[k + 1]) for k in range(0, len(numbers), 2)]
            result.append(Polygon(vertices))
    return CoorType.depend_on_image, result
def create_json(filepath_json):
    """Write a sample ROI JSON file in the layout that ``parse_json`` reads.

    The sample declares pixel coordinates and contains two closed irregular
    polygons (the last vertex repeats the first). Both rings have a positive
    signed (shoelace) area, i.e. they are counter-clockwise with the y axis
    pointing up.

    :param filepath_json: destination path; an existing file is overwritten.
    """
    sample_rings = (
        ((120, 300), (250, 220), (340, 400), (200, 500)),
        ((400, 100), (550, 180), (480, 350), (370, 280)),
    )
    polygon_entries = []
    for poly_id, ring in enumerate(sample_rings, start=1):
        closed_ring = ring + ring[:1]  # repeat the first vertex to close the ring
        polygon_entries.append({
            "id": poly_id,
            "coordinates": [{"x": px, "y": py} for px, py in closed_ring],
        })
    document = {
        "Available_Coordinates_type": ["depend_on_image", "pixel"],
        "Coordinates_type": "pixel",
        "polygons": polygon_entries,
    }
    # Save as a JSON file
    with open(filepath_json, "w") as out_file:
        json.dump(document, out_file, indent=4)
def parse_json(roifilepath_json):
    """Load polygon ROIs from a JSON file in the layout written by ``create_json``.

    Expected keys: ``"Coordinates_type"`` (``"pixel"`` or ``"depend_on_image"``)
    and ``"polygons"``, a list of objects whose ``"coordinates"`` entry is a list
    of ``{"x": ..., "y": ...}`` vertices.

    :param roifilepath_json: path of the JSON ROI file (decoded as UTF-8, an
        optional BOM is accepted).
    :return: ``(coor_type, polygons)`` — the matching ``CoorType`` member and a
        list of shapely ``Polygon`` objects, one per entry of ``"polygons"``.
    :raises ValueError: if ``"Coordinates_type"`` is not a supported value
        (previously this printed a message and returned None, which only made the
        caller fail later while unpacking the result).
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # encoding (e.g. GBK on Chinese Windows). "utf-8-sig" also tolerates a BOM.
    with open(roifilepath_json, "r", encoding="utf-8-sig") as json_file:
        loaded_data = json.load(json_file)
    coor_type_tmp = loaded_data["Coordinates_type"]
    if coor_type_tmp == "pixel":
        coor_type = CoorType.pixel
    elif coor_type_tmp == "depend_on_image":
        coor_type = CoorType.depend_on_image
    else:
        raise ValueError(f"Error: unknown coordinates type!!! ({coor_type_tmp!r})")
    # Build one shapely polygon per entry from its {"x", "y"} vertex list.
    polygons = [
        Polygon([(point["x"], point["y"]) for point in poly["coordinates"]])
        for poly in loaded_data["polygons"]
    ]
    return coor_type, polygons
def parse_roifile(roifilepath):
    """Read an ROI file with the parser that matches its extension.

    Supported extensions (case-insensitive): ``.json`` → ``parse_json`` and
    ``.xml`` → ``parse_xml``.

    :param roifilepath: path of the ROI file.
    :return: ``(coor_type, polygons)`` as returned by the selected parser.
    :raises ValueError: for any other extension (previously this printed a
        message and returned None, which made the caller fail while unpacking).
    """
    # Lower-case so that e.g. "ROI.XML" is accepted as well.
    extension = os.path.splitext(roifilepath)[-1].removeprefix(".").lower()
    if extension == "json":
        coor_type, polygons = parse_json(roifilepath)
    elif extension == "xml":
        coor_type, polygons = parse_xml(roifilepath)
    else:
        raise ValueError(f"Error: unknown roi file type!!! ({roifilepath!r})")
    return coor_type, polygons
def get_coors_in_polygons(polygons):
    """Collect the integer pixel coordinates that lie strictly inside the polygons.

    Every integer point of each polygon's bounding box is tested with
    ``contains`` (points exactly on the boundary are excluded). A pixel inside
    several overlapping polygons is reported once per polygon. Points are emitted
    column-major per polygon (x in the outer loop, y in the inner loop).

    :param polygons: list of shapely polygons; their coordinates must already be
        pixel coordinates.
    :return: ``(x, y)`` — two parallel lists of int column and row indices.
    """
    import math

    from shapely.prepared import prep  # prepared geometry makes repeated tests cheaper

    x = []
    y = []
    for polygon in polygons:
        if polygon.is_empty:
            # Empty geometries have no (or NaN) bounds and contain no points.
            continue
        min_x, min_y, max_x, max_y = polygon.bounds
        prepared = prep(polygon)
        # Include floor(max) so the last column/row is scanned when the bounds are
        # fractional (the previous range(int(min), int(max)) silently skipped it).
        for x_tmp in range(math.ceil(min_x), math.floor(max_x) + 1):
            for y_tmp in range(math.ceil(min_y), math.floor(max_y) + 1):
                if prepared.contains(Point(x_tmp, y_tmp)):
                    x.append(x_tmp)
                    y.append(y_tmp)
    return x, y
@timeit
def get_pixel_coors_in_polygons(roifilepath, dataset):
    """Return the pixel coordinates of *dataset* that fall inside the ROI polygons.

    The ROI file is read with ``parse_roifile``. Polygons tagged
    ``CoorType.depend_on_image`` are mapped to pixel space through the inverse of
    the dataset's geotransform (only the exterior ring is kept). Polygons tagged
    ``CoorType.pixel`` are used unchanged. Any other tag yields no pixels.

    :param roifilepath: path of a .json or .xml ROI file.
    :param dataset: open GDAL dataset that supplies the geotransform.
    :return: ``(x, y)`` lists as produced by ``get_coors_in_polygons``.
    """
    coor_type, polygons = parse_roifile(roifilepath)
    if coor_type == CoorType.depend_on_image:
        geo_to_pixel = gdal.InvGeoTransform(dataset.GetGeoTransform())
        pixel_polygons = [
            Polygon([gdal.ApplyGeoTransform(geo_to_pixel, geo_x, geo_y)
                     for geo_x, geo_y in roi.exterior.coords])
            for roi in polygons
        ]
    elif coor_type == CoorType.pixel:
        pixel_polygons = polygons
    else:
        pixel_polygons = []
    return get_coors_in_polygons(pixel_polygons)
def get_valid_pixel_value(band, x, y):
    """Return the values of *band* at the pixels (x, y), sorted in ascending order.

    *band* is indexed as ``band[y, x]``, so *y* holds row indices and *x* column
    indices (parallel sequences, e.g. as returned by ``get_pixel_coors_in_polygons``).

    NOTE: sorting discards the pairing between each value and its pixel, so two
    arrays obtained this way from different bands no longer line up pixel-by-pixel.
    """
    sampled = band[y, x]
    return np.sort(sampled)
@timeit
def deglint_regression_slope(imgpath, water_mask, out_imgpath, start_nir_wave, end_nir_wave, filepath_json):
    """Remove sun glint from every band of *imgpath* with the NIR regression-slope method.

    For each band a least-squares line ``band = slope * nir + intercept`` is
    fitted over the pixels inside the ROI polygons of *filepath_json*. Here
    ``nir = average_bands(start_nir_wave, end_nir_wave, imgpath)``. The glint
    estimate ``(nir - min(nir over the ROI)) * slope`` is clipped at 0 and set to
    0 wherever the water mask equals 0. It is then subtracted from the band, and
    negative results are clipped to 0.

    :param imgpath: path of the input raster (opened with GDAL).
    :param water_mask: path of a raster whose first band is the water mask; pixels
        equal to 0 are left uncorrected. Presumably the same size as *imgpath* —
        not checked here.
    :param out_imgpath: output path; written with the ENVI driver as a BSQ Int16
        raster carrying the input's geotransform and projection.
    :param start_nir_wave: start of the NIR range, passed through to ``average_bands``.
    :param end_nir_wave: end of the NIR range, passed through to ``average_bands``.
    :param filepath_json: ROI file (.json or .xml) selecting the pixels used for
        the regression.
    :raises ValueError: if the ROI polygons contain no pixels.
    """
    nir = average_bands(start_nir_wave, end_nir_wave, imgpath)
    dataset = gdal.Open(imgpath)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    num_bands = dataset.RasterCount
    geotransform = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()

    roi_cols, roi_rows = get_pixel_coors_in_polygons(filepath_json, dataset)
    if len(roi_cols) == 0:
        raise ValueError(f"no pixels found inside the ROI polygons of {filepath_json!r}")
    roi_rows = np.asarray(roi_rows)
    roi_cols = np.asarray(roi_cols)
    # The regression samples must stay paired pixel-by-pixel (nir_roi[k] and the
    # k-th band sample come from the same pixel), so they are NOT sorted; sorting
    # each band independently would turn the fit into a quantile-quantile mapping.
    nir_roi = nir[roi_rows, roi_cols]
    # Loop-invariant part of the glint estimate: NIR above the ROI minimum.
    nir_excess = nir - np.min(nir_roi)

    dataset_water_mask = gdal.Open(water_mask)
    outside_water = dataset_water_mask.GetRasterBand(1).ReadAsArray() == 0
    del dataset_water_mask

    driver = gdal.GetDriverByName("ENVI")
    dst_ds = driver.Create(out_imgpath, im_width, im_height, num_bands, gdal.GDT_Int16,
                           options=["INTERLEAVE=BSQ"])
    dst_ds.SetGeoTransform(geotransform)
    dst_ds.SetProjection(im_proj)

    for i in range(num_bands):
        band = dataset.GetRasterBand(i + 1).ReadAsArray()
        slope = np.polyfit(nir_roi, band[roi_rows, roi_cols], 1)[0]
        glint = nir_excess * slope  # new array each band; nir_excess is not modified
        glint[glint < 0] = 0
        glint[outside_water] = 0
        band_deglint = band - glint
        # Ensure the corrected values are >= 0
        band_deglint[band_deglint < 0] = 0
        dst_ds.GetRasterBand(i + 1).WriteArray(band_deglint)
    # Drop the GDAL dataset references; GDAL flushes and closes a dataset when
    # its last Python reference goes away.
    del dataset, dst_ds
    write_fields_to_hdrfile(get_hdr_file_path(imgpath), get_hdr_file_path(out_imgpath))