143 lines
5.5 KiB
Python
143 lines
5.5 KiB
Python
from util import *
|
||
from osgeo import gdal
|
||
import argparse
|
||
|
||
|
||
def xml2shp():
    """Unimplemented placeholder.

    Judging only by its name, this is presumably meant to convert an XML
    annotation file into a shapefile (unconfirmed). It currently does
    nothing and returns ``None``.
    """
    return None
|
||
|
||
|
||
def rasterize_envi_xml(shp_filepath):
    """Unimplemented placeholder.

    Presumably intended to rasterize an ENVI XML region file (unconfirmed).
    ``shp_filepath`` is accepted but ignored, and ``None`` is returned.
    """
    return None
|
||
|
||
|
||
@timeit
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=0):
    """Rasterize a vector file into a mask on the exact pixel grid of an image.

    The vector features are burned with value 1 (``allTouched=True``) into a
    temporary ENVI raster at the image's resolution. That raster's extent
    comes from the vector, not the image. It is therefore sampled by nearest
    neighbour at the centre of every image pixel, so the resulting mask has
    the same rows and columns as ``img_path`` and can be used alongside it.

    Args:
        shp_filepath: Vector data source readable by ``gdal.OpenEx``.
        raster_fn_out: Output path passed to ``write_bands``.
        img_path: Reference image whose size and geotransform define the
            output grid. Pixels whose band-1 value is 0 are not sampled and
            stay at ``NoData_value``. Presumably 0 is the image's own fill
            value where the vector extends beyond it; TODO confirm.
        NoData_value: Value for pixels that are not burned, fall outside the
            vector extent, or are skipped as described above.

    Raises:
        OSError: If GDAL returns no dataset for the image or vector, or if
            rasterization returns nothing. With ``gdal.UseExceptions()``
            enabled, GDAL raises its own error first.
        ValueError: If the temporary raster's geotransform is not invertible.
    """
    dataset = gdal.Open(img_path)
    if dataset is None:
        raise OSError(f"GDAL could not open image: {img_path}")
    geotransform = dataset.GetGeoTransform()
    imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
    dataset = None  # close the image; only its grid and band 1 are needed

    source_ds = gdal.OpenEx(shp_filepath)
    if source_ds is None:
        raise OSError(f"GDAL could not open vector data source: {shp_filepath}")

    raster_fn_out_tmp = append2filename(raster_fn_out, "_tmp_delete")
    tmp_ds = None
    try:
        # The -tr resolution must be positive. Using both geotransform terms
        # also keeps non-square pixels at the image's true resolution.
        tmp_ds = gdal.Rasterize(raster_fn_out_tmp, source_ds, format='envi',
                                outputType=gdal.GDT_Byte,
                                noData=NoData_value, initValues=NoData_value,
                                xRes=abs(geotransform[1]),
                                yRes=abs(geotransform[5]),
                                allTouched=True, burnValues=1)
        if tmp_ds is None:
            raise OSError(f"gdal.Rasterize failed for: {shp_filepath}")
        geotransform_tmp = tmp_ds.GetGeoTransform()
        data_tmp = tmp_ds.GetRasterBand(1).ReadAsArray()
    finally:
        # Release every GDAL handle before deleting, and always delete, so no
        # temporary files (data or sidecars) survive a failure.
        tmp_ds = None
        source_ds = None
        _remove_tmp_raster(raster_fn_out_tmp)

    inv_geotransform_tmp = gdal.InvGeoTransform(geotransform_tmp)
    if inv_geotransform_tmp is None:
        raise ValueError(f"Non-invertible geotransform for rasterized {shp_filepath}")

    water_mask = _sample_onto_grid(data_tmp, inv_geotransform_tmp, geotransform,
                                   imgdata_in != 0, NoData_value)
    write_bands(img_path, raster_fn_out, water_mask)


def _sample_onto_grid(src_data, inv_src_geotransform, dst_geotransform,
                      dst_valid, fill_value):
    """Nearest-neighbour sample ``src_data`` at each destination pixel centre.

    Args:
        src_data: 2-D array of the source raster (rows, columns).
        inv_src_geotransform: Inverse geotransform mapping georeferenced
            x/y to source pixel/line coordinates.
        dst_geotransform: Geotransform of the destination grid.
        dst_valid: 2-D boolean array with the destination shape. Pixels
            where it is False are not sampled.
        fill_value: Value for unsampled or out-of-extent destination pixels.

    Returns:
        A float64 array with the shape of ``dst_valid``.
    """
    height, width = dst_valid.shape
    src_rows, src_cols = src_data.shape
    gt = dst_geotransform
    inv = inv_src_geotransform

    out = np.full((height, width), fill_value, dtype=np.float64)
    # Sample at pixel centres (+0.5). Sampling at the upper-left corner would
    # bias the result by up to one pixel toward the upper left.
    col_centres = np.arange(width, dtype=np.float64) + 0.5
    for row in range(height):
        row_centre = row + 0.5
        x = gt[0] + col_centres * gt[1] + row_centre * gt[2]
        y = gt[3] + col_centres * gt[4] + row_centre * gt[5]
        # floor (not int()) so coordinates in (-1, 0) are rejected as outside
        # instead of being truncated onto index 0.
        px = np.floor(inv[0] + x * inv[1] + y * inv[2]).astype(np.int64)
        py = np.floor(inv[3] + x * inv[4] + y * inv[5]).astype(np.int64)
        keep = ((px >= 0) & (px < src_cols) & (py >= 0) & (py < src_rows)
                & dst_valid[row])
        out[row, keep] = src_data[py[keep], px[keep]]
    return out


def _remove_tmp_raster(path):
    """Best-effort removal of the temporary ENVI raster at ``path``.

    ``Driver.Delete`` removes every file GDAL associates with the dataset
    (e.g. the ENVI ``.hdr`` header), which a plain ``os.remove`` would leave
    behind. If the driver cannot delete it, fall back to removing the main
    file only.
    """
    import os

    if not os.path.exists(path):
        return
    try:
        gdal.GetDriverByName('ENVI').Delete(path)
    except RuntimeError:
        # Raised only when gdal.UseExceptions() is active; handled by the
        # os.remove fallback below.
        pass
    if os.path.exists(path):
        os.remove(path)
|
||
|
||
|
||
def calculate_NDWI(green_bandnumber, nir_bandnumber, filename):
    """Compute NDWI = (green - nir) / (green + nir) for a raster file.

    Args:
        green_bandnumber: Band index of the green band. One is added before
            calling ``GetRasterBand`` (which counts from 1), so this is
            presumably zero-based.
        nir_bandnumber: Band index of the near-infrared band, using the same
            convention.
        filename: Path of a raster readable by ``gdal.Open``.

    Returns:
        A float64 array shaped like one band. Pixels where
        ``green + nir == 0`` are NaN.

    Raises:
        OSError: If ``gdal.Open`` returns no dataset for ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise OSError(f"GDAL could not open raster: {filename}")
    try:
        band_green = dataset.GetRasterBand(green_bandnumber + 1).ReadAsArray()
        band_nir = dataset.GetRasterBand(nir_bandnumber + 1).ReadAsArray()
    finally:
        dataset = None  # release the GDAL handle even if reading fails

    return _normalized_difference(band_green, band_nir)


def _normalized_difference(band_a, band_b):
    """Return ``(band_a - band_b) / (band_a + band_b)`` as a float64 array.

    Both inputs are promoted to float64 before any arithmetic. The previous
    int16 cast wrapped uint16 values above 32767, and int16 sums overflowed.
    Where the denominator is zero, the result is NaN rather than a
    divide-by-zero warning or +/-inf.
    """
    a = np.asarray(band_a, dtype=np.float64)
    b = np.asarray(band_b, dtype=np.float64)
    numerator = a - b
    denominator = a + b
    result = np.full_like(numerator, np.nan)
    np.divide(numerator, denominator, out=result, where=denominator != 0)
    return result
|
||
|
||
|
||
def extract_water(ndwi, threshold=0.3, data_ignore_value=0):
    """Binarize an NDWI array into a water mask.

    Args:
        ndwi: Array of NDWI values.
        threshold: Pixels with NDWI strictly greater than this are water.
        data_ignore_value: Value written for every non-water pixel. This
            includes NaN pixels, because NaN never compares greater.

    Returns:
        An array shaped like ``ndwi`` holding 1 for water and
        ``data_ignore_value`` elsewhere.
    """
    is_water = ndwi > threshold
    return np.where(is_water, 1, data_ignore_value)
|
||
|
||
|
||
def ndwi(file_path, ndwi_threshold=0.4, output_path=None, data_ignore_value=0,
         green_wave=552.19, nir_wave=809.2890):
    """Extract a water mask from an image by thresholding its NDWI.

    Args:
        file_path: Input image path.
        ndwi_threshold: Pixels with NDWI strictly greater than this are water.
        output_path: Destination of the mask. Defaults to ``file_path`` with
            ``"_waterarea"`` appended via ``append2filename``.
        data_ignore_value: Value written for non-water pixels. Previously this
            parameter was accepted but never used (0 was always written).
        green_wave: Wavelength used to look up the green band via
            ``find_band_number``. The unit is presumably nm; confirm against
            that helper.
        nir_wave: Wavelength used to look up the near-infrared band.

    Returns:
        The path the mask was written to.
    """
    if output_path is None:
        output_path = append2filename(file_path, "_waterarea")

    green_band_number = find_band_number(green_wave, file_path)
    nir_band_number = find_band_number(nir_wave, file_path)

    # Named ndwi_index so it does not shadow this function's own name.
    ndwi_index = calculate_NDWI(green_band_number, nir_band_number, file_path)

    water_binary = extract_water(ndwi_index, threshold=ndwi_threshold,
                                 data_ignore_value=data_ignore_value)

    write_bands(file_path, output_path, water_binary)

    return output_path
|
||
|
||
|
||
def main():
    """Command-line entry point for the water-extraction tools.

    Sub-commands:
        rasterize_shp: Burn a shapefile into a water mask with the same rows
            and columns as the input image (calls ``rasterize_shp``).
        ndwi: Threshold the image's NDWI to obtain a water mask (calls
            ``ndwi``).

    Each sub-parser registers its own handler via ``set_defaults(func=...)``,
    so dispatch needs no separate if/elif chain that could fall out of sync.
    """
    parser = argparse.ArgumentParser(description="此程序用于提取水域区域,输出的水域栅格和输入的影像具有相同的行列数。")

    # Sub-command parsers: one per extraction algorithm.
    subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")

    rasterize_shp_ = subparsers.add_parser("rasterize_shp", help="将shp矢量栅格化为与输入影像行列号相同的水体掩膜")
    rasterize_shp_.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
    rasterize_shp_.add_argument('-i2', '--shp_path', type=str, required=True, help='输入shp文件的路径')
    rasterize_shp_.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
    rasterize_shp_.set_defaults(
        func=lambda a: rasterize_shp(a.shp_path, a.water_mask_outpath, a.img_path))

    ndwi_ = subparsers.add_parser("ndwi", help="根据NDWI阈值从输入影像中提取水体掩膜")
    ndwi_.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
    ndwi_.add_argument('-i2', '--ndwi_threshold', type=float, required=True, help='输入ndwi水体阈值,大于此值的为水域')
    ndwi_.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
    ndwi_.set_defaults(
        func=lambda a: ndwi(a.img_path, a.ndwi_threshold, a.water_mask_outpath))

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()
|