Files
water_content_retrieval/extract_water_area.py
2025-01-06 10:18:08 +08:00

143 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from util import *
from osgeo import gdal
import argparse
def xml2shp():
    """Placeholder that is not implemented yet.

    The name suggests an XML-to-shapefile conversion (TODO: confirm intent).
    Currently it does nothing and returns ``None``.
    """
    return None
def rasterize_envi_xml(shp_filepath):
    """Placeholder that is not implemented yet.

    ``shp_filepath`` is accepted but ignored. The name suggests rasterizing an
    ENVI XML region file (TODO: confirm intent). Currently returns ``None``.
    """
    return None
@timeit
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=0):
    """Rasterize a vector file into a water mask aligned with a reference image.

    The vectors are burned (value 1, ``allTouched=True``) into a temporary ENVI
    raster at the image's pixel size. That temporary raster is then sampled
    (nearest neighbour, at pixel centres) onto the image grid. The resulting mask
    therefore has exactly the same rows and columns as ``img_path`` and can be
    used with it directly. The temporary raster and its sidecar files are always
    removed.

    Args:
        shp_filepath: Vector data source readable by ``gdal.OpenEx``.
        raster_fn_out: Output mask path, written with ``write_bands`` using
            ``img_path`` as the template.
        img_path: Reference image defining the output grid. Pixels whose first
            band is 0 stay 0 in the mask; this covers vectors that extend
            beyond the image.
        NoData_value: NoData / initial value of the temporary raster.

    Raises:
        RuntimeError: If an input cannot be opened, rasterization fails, or the
            temporary raster's geotransform is not invertible.
    """
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise RuntimeError(f"Cannot open image: {img_path}")
    geotransform = dataset.GetGeoTransform()
    imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
    dataset = None  # close the reference image

    source_ds = gdal.OpenEx(shp_filepath)
    if source_ds is None:
        raise RuntimeError(f"Cannot open vector file: {shp_filepath}")

    raster_fn_out_tmp = append2filename(raster_fn_out, "_tmp_delete")
    try:
        # Rasterize at the image resolution. gdal_rasterize expects positive
        # resolutions, while geotransform[5] is normally negative (north-up).
        tmp_ds = gdal.Rasterize(raster_fn_out_tmp, source_ds, format='envi',
                                outputType=gdal.GDT_Byte, noData=NoData_value,
                                initValues=NoData_value,
                                xRes=abs(geotransform[1]), yRes=abs(geotransform[5]),
                                allTouched=True, burnValues=1)
        source_ds = None
        if tmp_ds is None:
            raise RuntimeError(f"Rasterization of {shp_filepath} failed")
        geotransform_tmp = tmp_ds.GetGeoTransform()
        data_tmp = tmp_ds.GetRasterBand(1).ReadAsArray()
        tmp_ds = None  # close before the file is deleted below
    finally:
        # The ENVI driver also writes a .hdr (and possibly .aux.xml) sidecar.
        # Driver.Delete removes every file of the dataset, unlike os.remove.
        if os.path.exists(raster_fn_out_tmp):
            gdal.GetDriverByName('ENVI').Delete(raster_fn_out_tmp)

    inv_geotransform_tmp = gdal.InvGeoTransform(geotransform_tmp)
    if inv_geotransform_tmp is None:
        raise RuntimeError("Geotransform of the rasterized vector is not invertible")

    # Build a mask with the same rows/columns/resolution as the input image so
    # later processing can use it pixel-for-pixel.
    water_mask = _sample_onto_grid(data_tmp, inv_geotransform_tmp,
                                   geotransform, imgdata_in != 0)
    write_bands(img_path, raster_fn_out, water_mask)


def _sample_onto_grid(src_data, src_inv_geotransform, dst_geotransform, dst_valid):
    """Nearest-neighbour sample ``src_data`` at the centre of each destination pixel.

    Returns a float64 array shaped like ``dst_valid``. Destination pixels that
    map outside ``src_data``, or where ``dst_valid`` is False, are 0. The work
    is done one row at a time, vectorized over columns, so extra memory is
    O(width).
    """
    height, width = dst_valid.shape
    src_height, src_width = src_data.shape
    gt, inv = dst_geotransform, src_inv_geotransform
    out = np.zeros((height, width))
    col_centres = np.arange(width) + 0.5
    for row in range(height):
        row_centre = row + 0.5
        # Destination pixel centre -> georeferenced x/y.
        x = gt[0] + col_centres * gt[1] + row_centre * gt[2]
        y = gt[3] + col_centres * gt[4] + row_centre * gt[5]
        # Georeferenced x/y -> source pixel indices. floor (not int()) so that
        # slightly negative positions are rejected instead of snapping to 0.
        src_col = np.floor(inv[0] + x * inv[1] + y * inv[2]).astype(np.intp)
        src_row = np.floor(inv[3] + x * inv[4] + y * inv[5]).astype(np.intp)
        keep = (dst_valid[row]
                & (src_col >= 0) & (src_col < src_width)
                & (src_row >= 0) & (src_row < src_height))
        out[row, keep] = src_data[src_row[keep], src_col[keep]]
    return out
def calculate_NDWI(green_bandnumber, nir_bandnumber, filename):
    """Compute NDWI = (green - nir) / (green + nir) for every pixel of an image.

    Args:
        green_bandnumber: 0-based index of the green band. 1 is added because
            GDAL band numbering starts at 1.
        nir_bandnumber: 0-based index of the near-infrared band.
        filename: Path of a multi-band image readable by GDAL.

    Returns:
        A float64 array with the band's shape. Pixels where green + nir == 0
        are NaN, and no division-by-zero warning is emitted.

    Raises:
        RuntimeError: If the file cannot be opened.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise RuntimeError(f"Cannot open image: {filename}")
    # Work in float64: an int16 cast overflows for uint16 reflectances and for
    # green + nir sums above 32767, and truncates float reflectances to 0/1.
    band_green = dataset.GetRasterBand(green_bandnumber + 1).ReadAsArray().astype(np.float64)
    band_nir = dataset.GetRasterBand(nir_bandnumber + 1).ReadAsArray().astype(np.float64)
    dataset = None  # close the file

    numerator = band_green - band_nir
    denominator = band_green + band_nir
    # Zero denominators become NaN rather than +/-inf. NaN is never above a
    # threshold, so such pixels are not classified as water downstream.
    ndwi = np.full(denominator.shape, np.nan)
    np.divide(numerator, denominator, out=ndwi, where=denominator != 0)
    return ndwi
def extract_water(ndwi, threshold=0.3, data_ignore_value=0):
    """Binarise NDWI values into a water mask.

    Args:
        ndwi: NDWI values (an array supporting ``>`` against a float).
        threshold: Values strictly greater than this are classified as water.
        data_ignore_value: Value written for every non-water pixel. NaN
            compares False, so NaN pixels also receive this value.

    Returns:
        An array holding 1 for water and ``data_ignore_value`` elsewhere.
    """
    is_water = ndwi > threshold
    mask = np.where(is_water, 1, data_ignore_value)
    return mask
def ndwi(file_path, ndwi_threshold=0.4, output_path=None, data_ignore_value=0,
         green_wave=552.19, nir_wave=809.2890):
    """Extract a water mask from an image by thresholding its NDWI.

    Args:
        file_path: Multi-band image. The green and NIR bands are located with
            ``find_band_number`` from the given wavelengths.
        ndwi_threshold: Pixels with NDWI strictly greater than this are water.
        output_path: Output mask path. Defaults to ``file_path`` with
            ``"_waterarea"`` appended via ``append2filename``.
        data_ignore_value: Value written for non-water pixels.
        green_wave: Wavelength used to pick the green band. Presumably in nm,
            in whatever units ``find_band_number`` expects (TODO: confirm).
        nir_wave: Wavelength used to pick the near-infrared band.

    Returns:
        The path the mask was written to.
    """
    if output_path is None:
        output_path = append2filename(file_path, "_waterarea")
    # calculate_NDWI adds 1 to these, so find_band_number presumably returns
    # 0-based band indices (NOTE(review): confirm in util).
    green_band_number = find_band_number(green_wave, file_path)
    nir_band_number = find_band_number(nir_wave, file_path)
    ndwi_values = calculate_NDWI(green_band_number, nir_band_number, file_path)
    water_binary = extract_water(ndwi_values, threshold=ndwi_threshold,
                                 data_ignore_value=data_ignore_value)
    write_bands(file_path, output_path, water_binary)
    return output_path
def main():
    """Command-line entry point.

    Sub-commands (both write a water mask with the same rows/columns as the
    input image):
        rasterize_shp  -i1 IMAGE -i2 SHAPEFILE -o MASK
        ndwi           -i1 IMAGE -i2 THRESHOLD -o MASK
    """
    parser = argparse.ArgumentParser(description="此程序用于提取水域区域,输出的水域栅格和输入的影像具有相同的行列数。")
    # One sub-parser per algorithm; the chosen name is stored in args.algorithm.
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    rasterize_parser = subparsers.add_parser("rasterize_shp", help="栅格化shp文件生成水体掩膜")
    rasterize_parser.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
    rasterize_parser.add_argument('-i2', '--shp_path', type=str, required=True, help='输入shp文件的路径')
    rasterize_parser.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
    # Each sub-command binds an adapter that maps parsed args onto its function,
    # so dispatch below needs no per-algorithm branching.
    rasterize_parser.set_defaults(
        func=lambda a: rasterize_shp(a.shp_path, a.water_mask_outpath, a.img_path))

    ndwi_parser = subparsers.add_parser("ndwi", help="基于NDWI阈值提取水体掩膜")
    ndwi_parser.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
    # '-i2' is kept for CLI backward compatibility even though it is a threshold.
    ndwi_parser.add_argument('-i2', '--ndwi_threshold', type=float, required=True, help='输入ndwi水体阈值大于此值的为水域')
    ndwi_parser.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
    ndwi_parser.set_defaults(
        func=lambda a: ndwi(a.img_path, a.ndwi_threshold, a.water_mask_outpath))

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()