Initial commit of WQ_GUI

This commit is contained in:
2026-04-08 15:25:08 +08:00
commit 91e36407ae
302 changed files with 40872 additions and 0 deletions

1
src/utils/__init__.py Normal file
View File

@ -0,0 +1 @@
# -*- coding: utf-8 -*-

226
src/utils/band_math.py Normal file
View File

@ -0,0 +1,226 @@
import pandas as pd
import numpy as np
import re
class BandMathCalculator:
def __init__(self, csv_file):
"""
初始化计算器
csv_file: 包含光谱反射率的CSV文件路径
"""
self.df = pd.read_csv(csv_file)
self.wavelengths = self._extract_wavelengths()
def _extract_wavelengths(self):
"""从列名中提取波长信息"""
wavelengths = []
for col in self.df.columns:
# 尝试从列名中提取数字(波长)
numbers = re.findall(r'\d+\.?\d*', str(col))
if numbers:
wavelengths.append(float(numbers[0]))
else:
wavelengths.append(None)
return wavelengths
def _find_closest_wavelength(self, target_wavelength):
"""找到最接近目标波长的列索引"""
valid_indices = [i for i, wl in enumerate(self.wavelengths) if wl is not None]
if not valid_indices:
raise ValueError("未找到有效的波长列")
# 计算与目标波长的差值
differences = [abs(self.wavelengths[i] - target_wavelength) for i in valid_indices]
min_diff_index = np.argmin(differences)
closest_index = valid_indices[min_diff_index]
closest_wavelength = self.wavelengths[closest_index]
print(
f"目标波长 {target_wavelength}nm -> 最接近波长 {closest_wavelength}nm (列: {self.df.columns[closest_index]})")
return closest_index
def _parse_expression(self, expression):
"""解析表达式,提取所有波段变量 - 支持大小写"""
# 匹配 w或W后面跟着数字的格式的变量
pattern = r'[wW](\d+\.?\d*)'
matches = re.findall(pattern, expression)
return matches # 返回字符串列表,如 ['686', '672', '715', '672']
def _create_substitution_dict(self, variables, row_index=0):
"""创建变量替换字典 - 支持大小写"""
substitution_dict = {}
for var in variables:
wavelength = float(var) # 将字符串转换为浮点数
col_index = self._find_closest_wavelength(wavelength)
value = self.df.iloc[row_index, col_index]
# 同时添加小写和大写版本的变量
substitution_dict[f'w{var}'] = value
substitution_dict[f'W{var}'] = value
return substitution_dict
def calculate(self, expression, row_index=0):
"""
计算自定义波段表达式
参数:
expression: 波段计算表达式,如 'chl=w560/w760'
row_index: 要计算的数据行索引默认为第0行
返回:
计算结果
"""
try:
# 提取表达式中的计算部分
if '=' in expression:
# 如果包含赋值,只取等号右边的计算部分
calc_part = expression.split('=')[1].strip()
var_name = expression.split('=')[0].strip()
else:
calc_part = expression.strip()
var_name = None
# 解析变量
variables = self._parse_expression(calc_part)
print(f"解析到的波长变量: {variables}")
# 创建替换字典
sub_dict = self._create_substitution_dict(variables, row_index)
print(f"变量值: {sub_dict}")
# 替换表达式中的变量 - 使用安全的字符串替换
calc_expression = calc_part
for var_pattern, value in sub_dict.items():
# 确保替换完整的变量名,避免部分匹配
calc_expression = re.sub(r'\b' + re.escape(var_pattern) + r'\b', f"({value})", calc_expression)
print(f"计算表达式: {calc_expression}")
# 安全地计算表达式
result = eval(calc_expression)
# 返回结果
if var_name:
return {var_name: result}
else:
return result
except Exception as e:
print(f"计算错误: {e}")
import traceback
traceback.print_exc()
return None
def calculate_all_rows(self, expression):
"""为所有行计算表达式"""
results = []
for i in range(len(self.df)):
print(f"\n--- 计算第 {i} 行 ---")
result = self.calculate(expression, i)
if result is not None:
if isinstance(result, dict):
results.append(list(result.values())[0])
else:
results.append(result)
else:
# 如果计算失败添加NaN值以保持结果数量一致
results.append(np.nan)
print(f"{i} 行计算失败使用NaN填充")
return results
def process_formulas_from_csv(self, formula_csv_file, formula_names=None, output_file=None):
"""
从公式CSV文件中批量计算并添加到数据文件中
参数:
formula_csv_file: 公式CSV文件路径第一列为公式名称第三列为具体公式
formula_names: 要计算的公式名称列表如果为None则计算所有公式
output_file: 输出文件路径如果为None则自动生成
返回:
包含计算结果的新DataFrame
"""
# 读取公式CSV文件
try:
formulas_df = pd.read_csv(formula_csv_file)
print(f"读取到 {len(formulas_df)} 个公式")
# 检查CSV格式假设第一列为公式名称第三列为具体公式
if len(formulas_df.columns) < 3:
raise ValueError("公式CSV文件需要至少3列")
formula_name_col = formulas_df.columns[0] # 第一列:公式名称
formula_expr_col = formulas_df.columns[2] # 第三列:具体公式
# 创建结果DataFrame的副本
result_df = self.df.copy()
# 如果指定了公式名称,则只计算这些公式
if formula_names is not None:
if isinstance(formula_names, str):
formula_names = [formula_names] # 转换为列表
# 筛选出指定的公式
selected_formulas = formulas_df[formulas_df[formula_name_col].isin(formula_names)]
print(f"找到 {len(selected_formulas)} 个指定的公式")
if len(selected_formulas) == 0:
print(f"警告: 未找到指定的公式: {formula_names}")
return result_df
formulas_to_process = selected_formulas
else:
# 计算所有公式
formulas_to_process = formulas_df
# 为每个公式计算所有行
for _, row in formulas_to_process.iterrows():
formula_name = row[formula_name_col]
formula_expr = row[formula_expr_col]
if pd.isna(formula_name) or pd.isna(formula_expr):
print(f"跳过空公式: {row}")
continue
print(f"\n计算公式: {formula_name} = {formula_expr}")
# 计算所有行的结果
results = self.calculate_all_rows(formula_expr)
# 将结果添加到DataFrame
result_df[formula_name] = results
print(f"公式 '{formula_name}' 计算完成,添加到数据中")
# 保存结果
if output_file is None:
# 自动生成输出文件名
import os
base_name = os.path.splitext(os.path.basename(formula_csv_file))[0]
output_file = f"band_math_results_{base_name}.csv"
result_df.to_csv(output_file, index=False)
print(f"结果已保存到: {output_file}")
return result_df
except Exception as e:
print(f"处理公式CSV文件时出错: {e}")
import traceback
traceback.print_exc()
return None
# 更新使用示例
if __name__ == "__main__":
# 创建计算器实例
calculator = BandMathCalculator(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv")
# 示例1: 计算所有公式
# result_df = calculator.process_formulas_from_csv(r"E:\code\WQ\封装\sub\水质参数.csv", "enhanced_data.csv")
# 示例2: 计算指定公式
result_df = calculator.process_formulas_from_csv(
r"E:\code\WQ\封装\sub\水质参数.csv",
formula_names=["BGA_Am09KBBI", "BGA_Be162B643sub629"],
output_file=r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\enhanced_data.csv"
)

View File

@ -0,0 +1,172 @@
from src.utils.util import *
from osgeo import gdal, ogr
import argparse
gdal.UseExceptions()
ogr.UseExceptions()
def xml2shp():
pass
def rasterize_envi_xml(shp_filepath):
pass
@timeit
def rasterize_shp(shp_filepath, raster_fn_out, img_path, NoData_value=None):
dataset = gdal.Open(img_path)
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
geotransform = dataset.GetGeoTransform()
imgdata_in = dataset.GetRasterBand(1).ReadAsArray()
del dataset
# Open the data source and read in the extent
source_ds = gdal.OpenEx(shp_filepath, gdal.OF_VECTOR)
if source_ds is None:
raise ValueError(f"无法打开shapefile: {shp_filepath}")
# 检查图层数量,如果有多层,指定使用第一层
layer_count = source_ds.GetLayerCount()
layer_name = None
if layer_count > 1:
print(f"警告: shapefile包含{layer_count}个图层,将使用第一个图层进行栅格化")
# 获取第一个图层
layer = source_ds.GetLayer(0)
layer_name = layer.GetName()
# about 25 metres(ish) use 0.001 if you want roughly 100m
pixel_size_x = abs(geotransform[1]) # 像素宽度X方向
pixel_size_y = abs(geotransform[5]) # 像素高度Y方向通常是负值需要取绝对值
raster_fn_out_tmp = append2filename(raster_fn_out, "_tmp_delete")
# 构建栅格化参数
rasterize_kwargs = {
'format': 'envi',
'outputType': gdal.GDT_Byte,
'noData': NoData_value,
'initValues': NoData_value,
'xRes': pixel_size_x,
'yRes': pixel_size_y,
'allTouched': True,
'burnValues': 1
}
# 如果有多层,指定使用第一层
if layer_name is not None:
rasterize_kwargs['layers'] = [layer_name]
# 执行栅格化
gdal.Rasterize(raster_fn_out_tmp, source_ds, **rasterize_kwargs)
dataset_tmp = gdal.Open(raster_fn_out_tmp)
geotransform_tmp = dataset_tmp.GetGeoTransform()
inv_geotransform_tmp = gdal.InvGeoTransform(geotransform_tmp)
data_tmp = dataset_tmp.GetRasterBand(1).ReadAsArray()
del dataset_tmp
# 创建和输入影像相同行列号、相同分辨率的水域掩膜,方便后续使用
water_mask = np.zeros((im_height, im_width))
for row in range(im_height):
for column in range(im_width):
coor = gdal.ApplyGeoTransform(geotransform, column, row)
coor_pixel = gdal.ApplyGeoTransform(inv_geotransform_tmp, coor[0], coor[1])
coor_pixel = [int(num) for num in coor_pixel]
if coor_pixel[0] < 0 or coor_pixel[0] >= data_tmp.shape[1]:
continue
if coor_pixel[1] < 0 or coor_pixel[1] >= data_tmp.shape[0]:
continue
if imgdata_in[row, column] == 0: # 当shp区域比影像区域大时略过
continue
water_mask[row, column] = data_tmp[coor_pixel[1], coor_pixel[0]]
write_bands(img_path, raster_fn_out, water_mask)
os.remove(raster_fn_out_tmp)
def calculate_NDWI(green_bandnumber, nir_bandnumber, filename):
dataset = gdal.Open(filename) # 打开文件
num_bands = dataset.RasterCount # 栅格矩阵的波段数
im_geotrans = dataset.GetGeoTransform() # 仿射矩阵
im_proj = dataset.GetProjection() # 地图投影信息
tmp = dataset.GetRasterBand(green_bandnumber + 1) # 波段计数从1开始
band_green = tmp.ReadAsArray().astype(np.int16)
tmp = dataset.GetRasterBand(nir_bandnumber + 1) # 波段计数从1开始
band_nir = tmp.ReadAsArray().astype(np.int16)
ndwi = (band_green - band_nir) / (band_green + band_nir)
del dataset
return ndwi
def extract_water(ndwi, threshold=0.3, data_ignore_value=0):
water_region = np.where(ndwi > threshold, 1, data_ignore_value)
return water_region
def ndwi(file_path, ndwi_threshold=0.4, output_path=None, data_ignore_value=0):
if output_path is None:
output_path = append2filename(file_path, "_waterarea")
dataset_in = gdal.Open(file_path)
im_width_in = dataset_in.RasterXSize # 栅格矩阵的列数
im_height_in = dataset_in.RasterYSize # 栅格矩阵的行数
num_bands_in = dataset_in.RasterCount # 栅格矩阵的波段数
geotrans_in = dataset_in.GetGeoTransform() # 仿射矩阵
proj_in = dataset_in.GetProjection() # 地图投影信息
del dataset_in
green_wave = 552.19
nir_wave = 809.2890
green_band_number = find_band_number(green_wave, file_path)
nir_band_number = find_band_number(nir_wave, file_path)
ndwi = calculate_NDWI(green_band_number, nir_band_number, file_path)
water_binary = extract_water(ndwi, threshold=ndwi_threshold) # 0.4
write_bands(file_path, output_path, water_binary)
return output_path
def main():
parser = argparse.ArgumentParser(description="此程序用于提取水域区域,输出的水域栅格和输入的影像具有相同的行列数。")
# parser.add_argument("--global_arg", type=str, help="A global argument for all modes", required=True)
# 创建子命令解析器
subparsers = parser.add_subparsers(dest="algorithm", required=True, help="Choose a mode")
rasterize_shp_ = subparsers.add_parser("rasterize_shp", help="Mode 1 description")
rasterize_shp_.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
rasterize_shp_.add_argument('-i2', '--shp_path', type=str, required=True, help='输入shp文件的路径')
rasterize_shp_.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
rasterize_shp_.set_defaults(func=rasterize_shp)
ndwi_ = subparsers.add_parser("ndwi", help="Mode 2 description")
ndwi_.add_argument('-i1', '--img_path', type=str, required=True, help='输入影像文件的路径')
ndwi_.add_argument('-i2', '--ndwi_threshold', type=float, required=True, help='输入ndwi水体阈值大于此值的为水域')
ndwi_.add_argument('-o', '--water_mask_outpath', required=True, type=str, help='输出水体掩膜文件的路径')
ndwi_.set_defaults(func=ndwi)
# 解析参数
args = parser.parse_args()
if args.algorithm == "rasterize_shp":
args.func(args.shp_path, args.water_mask_outpath, args.img_path)
elif args.algorithm == "ndwi":
args.func(args.img_path, args.ndwi_threshold, args.water_mask_outpath)
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
main()

View File

@ -0,0 +1,765 @@
from src.utils.util import *
from osgeo import gdal, ogr
import argparse
import cv2
def percentile_stretch(img, data_water_mask, lower_percentile=2, upper_percentile=98, output_range=(0, 255)):
"""
使用百分位数裁剪进行归一化,适用于低反射率数据
通过排除极值,更好地利用数据的动态范围
Args:
img: 输入图像数组反射率值通常在0-1之间
data_water_mask: 水域掩膜
lower_percentile: 下百分位数用于裁剪最小值默认2
upper_percentile: 上百分位数用于裁剪最大值默认98
output_range: 输出范围,默认(0, 255)
Returns:
归一化后的图像数组(整数类型)
"""
# 只在水域掩膜区域计算百分位数
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于百分位数计算,使用原始值")
return img.astype(np.int32)
# 计算百分位数
p_lower = np.percentile(valid_pixels, lower_percentile)
p_upper = np.percentile(valid_pixels, upper_percentile)
# 如果上下界相同,使用最大值作为上界
if p_lower >= p_upper:
p_lower = np.percentile(valid_pixels, 1)
p_upper = np.percentile(valid_pixels, 99)
if p_lower >= p_upper:
p_upper = valid_pixels.max()
p_lower = valid_pixels.min()
print(f"百分位数拉伸: {lower_percentile}%={p_lower:.6f}, {upper_percentile}%={p_upper:.6f}, "
f"数据范围=[{img.min():.6f}, {img.max():.6f}]")
# 裁剪到百分位数范围
img_clipped = np.clip(img, p_lower, p_upper)
# 线性拉伸到输出范围
if p_upper > p_lower:
img_stretched = (img_clipped - p_lower) / (p_upper - p_lower) * (output_range[1] - output_range[0]) + output_range[0]
else:
img_stretched = np.full_like(img, output_range[0], dtype=np.float32)
return img_stretched.astype(np.int32)
@timeit
def otsu(img, max_value, data_water_mask, ignore_value=0, foreground=1, background=0):
height = img.shape[0]
width = img.shape[1]
hist = np.zeros([max_value], np.float32)
# 计算直方图
invalid_counter = 0
for i in range(height):
for j in range(width):
if img[i, j] == ignore_value or img[i, j] < 0 or data_water_mask[i, j] == 0:
invalid_counter = invalid_counter + 1
continue
hist[img[i, j]] += 1
hist /= (height * width - invalid_counter)
threshold = 0
deltaMax = 0
# 遍历像素值,计算最大类间方差
for i in range(max_value):
wA = 0
wB = 0
uAtmp = 0
uBtmp = 0
uA = 0
uB = 0
u = 0
for j in range(max_value):
if j <= i:
wA += hist[j]
uAtmp += j * hist[j]
else:
wB += hist[j]
uBtmp += j * hist[j]
if wA == 0:
wA = 1e-10
if wB == 0:
wB = 1e-10
uA = uAtmp / wA
uB = uBtmp / wB
u = uAtmp + uBtmp
# 计算类间方差
deltaTmp = wA * ((uA - u)**2) + wB * ((uB - u)**2)
# 找出最大类间方差以及阈值
if deltaTmp > deltaMax:
deltaMax = deltaTmp
threshold = i
# 二值化
det_img = img.copy()
det_img[img > threshold] = foreground
det_img[img <= threshold] = background
det_img[np.where(data_water_mask == 0)] = background
return det_img
@timeit
def zscore_threshold(img, data_water_mask, z_threshold=2.5, foreground=1, background=0):
"""
基于Z-score标准化分数的耀斑检测方法
使用统计方法识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
z_threshold: Z-score阈值默认2.5即超过均值2.5个标准差)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
mean_val = np.mean(valid_pixels)
std_val = np.std(valid_pixels)
if std_val == 0:
print("警告: 标准差为0无法使用Z-score方法")
return np.zeros_like(img, dtype=np.int32)
# 计算Z-score对无效值进行保护
z_scores = np.zeros_like(img, dtype=np.float32)
valid_mask = (data_water_mask > 0) & np.isfinite(img)
z_scores[valid_mask] = (img[valid_mask] - mean_val) / std_val
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[z_scores > z_threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"Z-score方法: 均值={mean_val:.2f}, 标准差={std_val:.2f}, 阈值={mean_val + z_threshold * std_val:.2f}")
return det_img
@timeit
def percentile_threshold(img, data_water_mask, percentile=95, foreground=1, background=0):
"""
基于百分位数的耀斑检测方法
使用百分位数作为阈值,对异常值更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
percentile: 百分位数阈值默认95即超过95%的像素值)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算百分位数,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
threshold = np.percentile(valid_pixels, percentile)
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > threshold] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"百分位数方法: {percentile}%分位数为 {threshold:.2f}")
return det_img
@timeit
def multi_band_glint_detection(dataset, img_path, water_mask, glint_waves, weights=None, method='zscore',
z_threshold=2.5, percentile=95, foreground=1, background=0):
"""
多波段融合的耀斑检测方法
结合多个波段的耀斑特征,提高检测的稳健性
Args:
dataset: GDAL数据集
img_path: 影像文件路径(用于获取波长信息)
water_mask: 水域掩膜数组
glint_waves: 用于检测的波长列表,如[750, 800, 850]
weights: 各波段的权重如果为None则使用等权重
method: 使用的检测方法 ('zscore', 'percentile', 'otsu')
z_threshold: Z-score阈值当method='zscore'时使用)
percentile: 百分位数阈值当method='percentile'时使用)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
num_bands = dataset.RasterCount
if weights is None:
weights = [1.0 / len(glint_waves)] * len(glint_waves)
if len(weights) != len(glint_waves):
raise ValueError("权重数量必须与波长数量相同")
# 读取多个波段并加权融合使用float32保持精度
fused_band = None
for i, wave in enumerate(glint_waves):
band_num = find_band_number(wave, img_path)
if band_num >= num_bands:
print(f"警告: 波段号 {band_num} 超出范围,跳过波长 {wave}")
continue
tmp = dataset.GetRasterBand(band_num + 1).ReadAsArray().astype(np.float32)
if fused_band is None:
fused_band = (tmp * weights[i]).astype(np.float32)
else:
fused_band = (fused_band + tmp * weights[i]).astype(np.float32)
if fused_band is None:
raise ValueError("没有有效的波段可以融合")
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile直接使用原始反射率值
# 对于Otsu方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
fused_band_stretch = percentile_stretch(fused_band, water_mask,
lower_percentile=2, upper_percentile=98)
return otsu(fused_band_stretch, fused_band_stretch.max() + 1, water_mask,
foreground=foreground, background=background)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
return zscore_threshold(fused_band, water_mask, z_threshold, foreground, background)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
return percentile_threshold(fused_band, water_mask, percentile, foreground, background)
else:
raise ValueError(f"不支持的方法: {method}")
@timeit
def adaptive_threshold(img, data_water_mask, window_size=15, percentile=90, foreground=1, background=0):
"""
自适应阈值方法
基于局部统计特性进行阈值分割,对光照变化更稳健
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
window_size: 局部窗口大小(奇数)
percentile: 局部百分位数阈值
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
height, width = img.shape
# 确保窗口大小为奇数
if window_size % 2 == 0:
window_size += 1
half_window = window_size // 2
# 创建输出图像
det_img = np.zeros_like(img, dtype=np.int32)
# 对每个像素计算局部阈值
for i in range(half_window, height - half_window):
for j in range(half_window, width - half_window):
# 只在水域掩膜内处理
if data_water_mask[i, j] == 0:
continue
# 提取局部窗口
local_window = img[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
local_mask = data_water_mask[i - half_window:i + half_window + 1,
j - half_window:j + half_window + 1]
# 只考虑有效像素
valid_pixels = local_window[local_mask > 0]
if len(valid_pixels) > 0:
local_threshold = np.percentile(valid_pixels, percentile)
if img[i, j] > local_threshold:
det_img[i, j] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"自适应阈值方法: 窗口大小={window_size}, 局部百分位数={percentile}%")
return det_img
@timeit
def iqr_outlier_detection(img, data_water_mask, iqr_multiplier=1.5, foreground=1, background=0):
"""
基于IQR四分位距的异常值检测方法
使用四分位距识别异常高亮的像素,对数据分布不敏感
Args:
img: 输入图像数组
data_water_mask: 水域掩膜
iqr_multiplier: IQR倍数默认1.5(标准异常值检测)
foreground: 前景值
background: 背景值
Returns:
二值化检测结果
"""
# 只在水域掩膜区域计算统计量,排除无效值
valid_pixels = img[(data_water_mask > 0) & (img > 0) & np.isfinite(img)]
if len(valid_pixels) == 0:
print("警告: 没有有效像素用于统计计算")
return np.zeros_like(img, dtype=np.int32)
q1 = np.percentile(valid_pixels, 25)
q3 = np.percentile(valid_pixels, 75)
iqr = q3 - q1
# 上界 = Q3 + 1.5 * IQR
upper_bound = q3 + iqr_multiplier * iqr
# 二值化
det_img = np.zeros_like(img, dtype=np.int32)
det_img[img > upper_bound] = foreground
det_img[np.where(data_water_mask == 0)] = background
print(f"IQR方法: Q1={q1:.2f}, Q3={q3:.2f}, IQR={iqr:.2f}, 上界={upper_bound:.2f}")
return det_img
@timeit
def create_shoreline_buffer(water_mask, buffer_size=5, foreground=1, background=0):
"""
创建岸边缓冲区掩膜(向内缓冲)
用于去除岸边附近的错误耀斑检测区域
方法:对水域掩膜进行腐蚀,然后用原始水域减去腐蚀后的水域,得到水域边缘向内缓冲的区域
Args:
water_mask: 水域掩膜数组(水域=1非水域=0
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
岸边缓冲区掩膜(缓冲区区域=1其他=0
"""
if buffer_size <= 0:
print("缓冲区大小为0或负数不创建岸边缓冲区")
return np.zeros_like(water_mask, dtype=np.int32)
# 将水域掩膜转换为二值图像
water_binary = (water_mask > 0).astype(np.int32)
# 创建结构元素(方形结构元素)
# 结构元素大小由buffer_size决定确保是奇数
structure_size = buffer_size * 2 + 1
structure = np.ones((structure_size, structure_size), dtype=np.int32)
# 对水域进行腐蚀,得到缩小后的水域
# 使用OpenCV替代scipy.ndimage.binary_erosion
eroded_water = cv2.erode(water_binary.astype(np.uint8), structure.astype(np.uint8)).astype(np.int32)
# 岸边缓冲区 = 原始水域 - 腐蚀后的水域
# 这给出了水域边缘向内buffer_size像素宽的缓冲区区域
buffer_mask = (water_binary - eroded_water).astype(np.int32)
buffer_pixels = np.sum(buffer_mask > 0)
print(f"岸边缓冲区: 创建了 {buffer_size} 像素宽的内向缓冲区,共 {buffer_pixels} 个像素")
return buffer_mask
@timeit
def remove_shoreline_buffer(glint_mask, water_mask, buffer_size=5, foreground=1, background=0):
"""
从耀斑掩膜中去除岸边缓冲区内的区域
Args:
glint_mask: 耀斑掩膜数组
water_mask: 水域掩膜数组
buffer_size: 缓冲区大小像素数默认5像素
foreground: 前景值
background: 背景值
Returns:
去除岸边缓冲区后的耀斑掩膜
"""
if buffer_size <= 0:
print("缓冲区大小为0不进行岸边缓冲区去除")
return glint_mask
# 创建岸边缓冲区掩膜
buffer_mask = create_shoreline_buffer(water_mask, buffer_size, foreground, background)
# 从耀斑掩膜中去除缓冲区内的区域
cleaned_glint_mask = glint_mask.copy()
cleaned_glint_mask[buffer_mask > 0] = background
removed_pixels = np.sum((glint_mask > 0) & (buffer_mask > 0))
remaining_pixels = np.sum(cleaned_glint_mask > 0)
if removed_pixels > 0:
print(f"岸边缓冲区去除: 从耀斑掩膜中移除了 {removed_pixels} 个岸边向内缓冲区域的像素,"
f"剩余 {remaining_pixels} 个像素")
else:
print(f"岸边缓冲区去除: 缓冲区区域没有耀斑掩膜,无需移除")
return cleaned_glint_mask
@timeit
def filter_large_components(binary_img, max_area=None, foreground=1, background=0):
"""
过滤掉面积超过阈值的连通域
用于去除大面积区域(如岸边、浅水、水华等),保留小面积的耀斑区域
Args:
binary_img: 二值化图像
max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被去除
如果为None则不进行过滤
foreground: 前景值
background: 背景值
Returns:
过滤后的二值化图像
"""
if max_area is None or max_area <= 0:
return binary_img
# 连通域标记
# 使用OpenCV替代scipy.ndimage.label
binary_for_label = (binary_img == foreground).astype(np.uint8)
num_features, labeled_array, stats, centroids = cv2.connectedComponentsWithStats(binary_for_label, connectivity=8)
if num_features == 0:
print("没有检测到连通域")
return binary_img
# 使用OpenCV返回的stats信息直接获取连通域面积
# stats[:, cv2.CC_STAT_AREA] 包含每个连通域的面积(包括背景)
# 跳过索引0背景的面积从索引1开始获取连通域面积
component_sizes = stats[1:, cv2.CC_STAT_AREA]
# 找出需要保留的连通域(面积 <= max_area
keep_labels = np.where(component_sizes <= max_area)[0] + 1 # +1 因为标签从1开始
# 使用布尔索引一次性过滤(高效方法)
# 创建一个mask标记所有需要保留的连通域
keep_mask = np.isin(labeled_array, keep_labels)
# 创建输出图像
filtered_img = np.zeros_like(binary_img, dtype=binary_img.dtype)
filtered_img[keep_mask] = foreground
# 统计信息
removed_count = num_features - len(keep_labels)
kept_count = len(keep_labels)
total_removed_pixels = np.sum(component_sizes[component_sizes > max_area])
if removed_count > 0:
print(f"连通域面积过滤: 移除了 {removed_count} 个大面积连通域(面积 > {max_area} 像素),"
f"共移除 {total_removed_pixels} 个像素;保留了 {kept_count} 个小面积连通域")
else:
print(f"连通域面积过滤: 所有 {kept_count} 个连通域面积均小于阈值 {max_area},全部保留")
return filtered_img
def find_overexposure_area(img_path, threhold=4095):
# 第一步通过某个像素的光谱找到信号最强的波段
# 根据上步所得的波段号检测过曝区域
pass
def create_water_mask_from_shp(shp_file, reference_raster):
"""
从shp文件创建水体掩膜栅格数组内存中不保存到磁盘
参数:
shp_file: str - shp文件路径
reference_raster: str - 参考栅格文件路径(用于获取空间范围和分辨率)
返回:
numpy.ndarray - 水体掩膜数组
"""
try:
# 打开参考栅格获取空间信息
ref_dataset = gdal.Open(reference_raster)
if ref_dataset is None:
raise ValueError(f"无法打开参考栅格文件: {reference_raster}")
geotransform = ref_dataset.GetGeoTransform()
projection = ref_dataset.GetProjection()
width = ref_dataset.RasterXSize
height = ref_dataset.RasterYSize
# 创建内存中的栅格数据集
mem_driver = gdal.GetDriverByName('MEM')
mask_dataset = mem_driver.Create('', width, height, 1, gdal.GDT_Byte)
mask_dataset.SetGeoTransform(geotransform)
mask_dataset.SetProjection(projection)
# 初始化为0
mask_band = mask_dataset.GetRasterBand(1)
mask_band.Fill(0)
# 打开shp文件
shp_dataset = ogr.Open(shp_file)
if shp_dataset is None:
raise ValueError(f"无法打开shp文件: {shp_file}")
layer = shp_dataset.GetLayer()
# 栅格化shp文件
gdal.RasterizeLayer(mask_dataset, [1], layer, burn_values=[1])
# 读取栅格化结果
water_mask = mask_band.ReadAsArray()
# 清理
ref_dataset = None
mask_dataset = None
shp_dataset = None
return water_mask
except Exception as e:
print(f"创建水体掩膜时发生错误: {str(e)}")
raise
@timeit
def find_severe_glint_area(img_path, water_mask, glint_wave=750, output_path=None,
method='otsu', multi_band_waves=None, **kwargs):
"""
找到严重耀斑区域的主函数
注意对于低反射率数据如水面反射率约0.02),本函数采用了改进的归一化策略:
- 统计方法zscore, percentile, iqr直接使用原始反射率值无需归一化
- Otsu和adaptive方法使用百分位数裁剪拉伸2%-98%分位数),避免极值影响
Args:
img_path: 输入影像路径
water_mask: 水域掩膜路径(支持栅格文件如.dat/.tif或SHP文件如.shp如果为None或空字符串则使用全图进行检测
glint_wave: 用于检测的波长(单个波段方法使用)
output_path: 输出路径
method: 检测方法,可选:
- 'otsu': Otsu阈值分割默认使用百分位数拉伸
- 'zscore': Z-score统计方法直接使用原始反射率
- 'percentile': 百分位数阈值方法(直接使用原始反射率)
- 'iqr': IQR异常值检测直接使用原始反射率
- 'adaptive': 自适应阈值方法(使用百分位数拉伸)
- 'multi_band': 多波段融合方法
multi_band_waves: 多波段方法的波长列表,如[750, 800, 850]
**kwargs: 其他方法特定参数
- z_threshold: Z-score阈值默认2.5
- percentile: 百分位数默认95
- iqr_multiplier: IQR倍数默认1.5
- window_size: 自适应阈值窗口大小默认15
- weights: 多波段方法的权重列表
- sub_method: 多波段方法的子方法('otsu', 'zscore', 'percentile'
- max_area: 最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉
用于去除岸边、浅水、水华等大面积区域默认None表示不过滤
- buffer_size: 岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜
默认None表示不进行岸边缓冲区去除设置为正整数时启用
Returns:
输出文件路径
"""
if output_path is None:
output_path = append2filename(img_path, "_severe_glint_area")
dataset = gdal.Open(img_path)
num_bands = dataset.RasterCount
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
# 读取水域掩膜如果water_mask为None或空字符串则创建全图掩膜
if water_mask is None or water_mask == "":
print("注意: water_mask为空使用全图进行检测")
data_water_mask = np.ones((im_height, im_width), dtype=np.int32)
else:
# 检查是否为SHP文件
water_mask_lower = water_mask.lower()
if water_mask_lower.endswith('.shp'):
# 直接使用SHP文件在内存中栅格化
print(f"检测到SHP文件正在从 {water_mask} 创建水体掩膜...")
data_water_mask = create_water_mask_from_shp(water_mask, img_path)
else:
# 使用栅格文件
dataset_water_mask = gdal.Open(water_mask)
if dataset_water_mask is None:
raise ValueError(f"无法打开水域掩膜文件: {water_mask}")
data_water_mask = dataset_water_mask.GetRasterBand(1).ReadAsArray()
del dataset_water_mask
print(f"使用检测方法: {method}")
# 根据方法选择检测算法
if method == 'multi_band':
if multi_band_waves is None:
# 默认使用几个常见NIR波段
multi_band_waves = [glint_wave, glint_wave + 50, glint_wave + 100]
print(f"多波段方法: 使用默认波长 {multi_band_waves}")
else:
print(f"多波段方法: 使用波长 {multi_band_waves}")
sub_method = kwargs.get('sub_method', 'zscore')
weights = kwargs.get('weights', None)
z_threshold = kwargs.get('z_threshold', 2.5)
percentile = kwargs.get('percentile', 95)
flare_binary = multi_band_glint_detection(
dataset, img_path, data_water_mask, multi_band_waves, weights,
method=sub_method, z_threshold=z_threshold, percentile=percentile
)
else:
# 单波段方法
glint_band_number = find_band_number(glint_wave, img_path)
tmp = dataset.GetRasterBand(glint_band_number + 1)
band_flare = tmp.ReadAsArray().astype(np.float32)
# 根据方法选择是否需要归一化
# 对于统计方法zscore, percentile, iqr直接使用原始反射率值
# 对于Otsu和adaptive方法需要归一化到整数范围
if method == 'otsu':
# Otsu方法需要整数范围使用百分位数拉伸
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
flare_binary = otsu(band_flare_stretch, band_flare_stretch.max() + 1, data_water_mask)
elif method == 'zscore':
# Z-score方法直接使用原始反射率值
z_threshold = kwargs.get('z_threshold', 2.5)
flare_binary = zscore_threshold(band_flare, data_water_mask, z_threshold)
elif method == 'percentile':
# 百分位数方法直接使用原始反射率值
percentile = kwargs.get('percentile', 95)
flare_binary = percentile_threshold(band_flare, data_water_mask, percentile)
elif method == 'iqr':
# IQR方法直接使用原始反射率值
iqr_multiplier = kwargs.get('iqr_multiplier', 1.5)
flare_binary = iqr_outlier_detection(band_flare, data_water_mask, iqr_multiplier)
elif method == 'adaptive':
# 自适应阈值方法需要归一化
band_flare_stretch = percentile_stretch(band_flare, data_water_mask,
lower_percentile=2, upper_percentile=98)
window_size = kwargs.get('window_size', 15)
percentile = kwargs.get('percentile', 90)
flare_binary = adaptive_threshold(band_flare_stretch, data_water_mask, window_size, percentile)
else:
raise ValueError(f"不支持的方法: {method}。可选方法: otsu, zscore, percentile, iqr, adaptive, multi_band")
# 过滤掉面积超过阈值的连通域(用于去除岸边、浅水、水华等大面积区域)
max_area = kwargs.get('max_area', None)
if max_area is not None and max_area > 0:
print(f"应用连通域面积过滤,最大面积阈值: {max_area} 像素")
flare_binary = filter_large_components(flare_binary, max_area=max_area)
# 去除岸边缓冲区内的耀斑掩膜(用于去除岸边的错误检测)
buffer_size = kwargs.get('buffer_size', None)
if buffer_size is not None and buffer_size > 0:
print(f"应用岸边缓冲区去除,缓冲区大小: {buffer_size} 像素")
flare_binary = remove_shoreline_buffer(flare_binary, data_water_mask, buffer_size=buffer_size)
write_bands(img_path, output_path, flare_binary)
del dataset
return output_path
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
img_path = r"D:\PycharmProjects\0water_rlx\test_data\ref_mosaic_1m_bsq"
parser = argparse.ArgumentParser(
description="此程序通过多种算法分割图像,提取耀斑最严重的区域。"
"支持的算法: otsu, zscore, percentile, iqr, adaptive, multi_band"
)
parser.add_argument('-i1', '--input', type=str, required=True, help='输入影像文件的路径')
parser.add_argument('-i2', '--input_water_mask', type=str, required=True, help='输入水域掩膜文件的路径')
parser.add_argument('-gw', '--glint_wave', type=float, default=750.0,
help='用于提取耀斑严重区域的波段波长(单波段方法使用)')
parser.add_argument('-m', '--method', type=str, default='otsu',
choices=['otsu', 'zscore', 'percentile', 'iqr', 'adaptive', 'multi_band'],
help='检测方法: otsu(默认), zscore, percentile, iqr, adaptive, multi_band')
parser.add_argument('-o', '--output', type=str, help='输出文件的路径')
# 方法特定参数
parser.add_argument('-zt', '--z_threshold', type=float, default=2.5,
help='Z-score方法的阈值默认2.5')
parser.add_argument('-p', '--percentile', type=float, default=95.0,
help='百分位数阈值默认95')
parser.add_argument('-iqr', '--iqr_multiplier', type=float, default=1.5,
help='IQR方法的倍数默认1.5')
parser.add_argument('-ws', '--window_size', type=int, default=15,
help='自适应阈值方法的窗口大小默认15')
parser.add_argument('-mbw', '--multi_band_waves', type=str, default=None,
help='多波段方法的波长列表,用逗号分隔,如: 750,800,850')
parser.add_argument('-sm', '--sub_method', type=str, default='zscore',
choices=['otsu', 'zscore', 'percentile'],
help='多波段方法的子方法默认zscore')
parser.add_argument('-ma', '--max_area', type=int, default=None,
help='最大连通域面积阈值(像素数),超过此面积的连通域将被过滤掉,'
'用于去除岸边、浅水、水华等大面积区域默认None表示不过滤')
parser.add_argument('-bs', '--buffer_size', type=int, default=None,
help='岸边缓冲区大小(像素数),用于去除岸边附近的错误耀斑掩膜'
'默认None表示不进行岸边缓冲区去除设置为正整数时启用')
parser.add_argument('-v', '--verbose', action='store_true', help='启用详细模式')
args = parser.parse_args()
# 解析多波段波长列表
multi_band_waves = None
if args.multi_band_waves:
multi_band_waves = [float(x.strip()) for x in args.multi_band_waves.split(',')]
# 构建kwargs
kwargs = {
'z_threshold': args.z_threshold,
'percentile': args.percentile,
'iqr_multiplier': args.iqr_multiplier,
'window_size': args.window_size,
'sub_method': args.sub_method,
'max_area': args.max_area,
'buffer_size': args.buffer_size
}
find_severe_glint_area(
args.input, args.input_water_mask, args.glint_wave, args.output,
method=args.method, multi_band_waves=multi_band_waves, **kwargs
)

458
src/utils/kriging.py Normal file
View File

@ -0,0 +1,458 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
克里金插值模块
提供基于PyKrige的普通克里金插值功能用于将离散的水质参数预测点
插值为连续的栅格图像。
主要功能:
1. 普通克里金插值
2. 多种变差函数模型支持
3. 自动参数优化
4. 栅格输出功能
"""
import numpy as np
from osgeo import gdal
import time
import os
import glob
from pathlib import Path
from typing import Optional, Tuple, Union, List
import warnings
warnings.filterwarnings('ignore')
# 导入util模块的timeit装饰器
try:
from src.utils.util import timeit
except ImportError:
# 如果导入失败定义一个简单的timeit装饰器
def timeit(f):
def wrapper(*args, **kwargs):
start = time.time()
ret = f(*args, **kwargs)
print(f"{f.__name__} run time: {round(time.time() - start, 2)} s.")
return ret
return wrapper
class KrigingInterpolator:
"""克里金插值器类"""
def __init__(self, variogram_models: Optional[List[str]] = None):
"""
初始化克里金插值器
Args:
variogram_models: 变差函数模型列表,默认为['spherical', 'exponential', 'gaussian', 'linear']
"""
self.variogram_models = variogram_models or ['spherical', 'exponential', 'gaussian', 'linear']
self.last_used_model = None
def validate_input_data(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:
"""
验证和预处理输入数据
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
Returns:
处理后的x, y, z数组和是否有效的标志
"""
# 确保输入为numpy数组
x = np.asarray(x)
y = np.asarray(y)
z = np.asarray(z)
# 检查数组长度一致性
if not (len(x) == len(y) == len(z)):
raise ValueError(f"输入数组长度不一致: x={len(x)}, y={len(y)}, z={len(z)}")
# 移除NaN值
mask = ~(np.isnan(x) | np.isnan(y) | np.isnan(z))
x = x[mask]
y = y[mask]
z = z[mask]
# 检查数据点数量
if len(x) < 3:
print(f"警告:有效数据点不足({len(x)}至少需要3个点进行Kriging插值")
return x, y, z, False
# 检查是否所有点重合
if np.all(x == x[0]) and np.all(y == y[0]):
print(f"警告:所有数据点位置相同({x[0]}, {y[0]}),无法进行空间插值")
return x, y, z, False
# 检查z值是否有变化
if np.all(z == z[0]):
print(f"警告:所有观测值相同({z[0]}),插值结果将为常数")
return x, y, z, True
def create_interpolation_grid(self, x: np.ndarray, y: np.ndarray, spatial_resolution: float) -> Tuple[np.ndarray, np.ndarray, int, int]:
"""
创建插值网格
Args:
x: X坐标数组
y: Y坐标数组
spatial_resolution: 空间分辨率
Returns:
网格x, 网格y, x方向步数, y方向步数
"""
# 计算空间范围,添加小的缓冲区
x_min, x_max = x.min(), x.max()
y_min, y_max = y.min(), y.max()
# 添加缓冲区以确保所有点都在网格内
buffer = spatial_resolution * 0.5
x_min -= buffer
x_max += buffer
y_min -= buffer
y_max += buffer
# 计算网格步数
step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1
# 限制网格大小以避免内存问题
max_grid_size = 10000
if step_x > max_grid_size or step_y > max_grid_size:
print(f"警告:网格尺寸过大 ({step_x}x{step_y}),将调整空间分辨率")
# 重新计算合适的分辨率
new_resolution_x = (x_max - x_min) / max_grid_size
new_resolution_y = (y_max - y_min) / max_grid_size
spatial_resolution = max(new_resolution_x, new_resolution_y, spatial_resolution)
step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1
print(f"调整后的空间分辨率: {spatial_resolution:.2f}, 网格尺寸: {step_x}x{step_y}")
# 创建网格
grid_x = np.linspace(x_min, x_max, step_x)
grid_y = np.linspace(y_min, y_max, step_y)
return grid_x, grid_y, step_x, step_y
@timeit
def interpolate(self, x: np.ndarray, y: np.ndarray, z: np.ndarray,
spatial_resolution: float = 1.0,
output_path: Optional[str] = None,
proj: Optional[str] = None) -> Optional[np.ndarray]:
"""
执行克里金插值
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
spatial_resolution: 空间分辨率
output_path: 输出文件路径
proj: 投影信息
Returns:
插值结果数组失败时返回None
"""
try:
from pykrige.ok import OrdinaryKriging
except ImportError:
print("错误未安装pykrige库请运行 'pip install pykrige'")
return None
# 验证输入数据
x, y, z, is_valid = self.validate_input_data(x, y, z)
if not is_valid:
return None
print(f"开始克里金插值,数据点数: {len(x)}")
# 创建插值网格
grid_x, grid_y, step_x, step_y = self.create_interpolation_grid(x, y, spatial_resolution)
print(f"插值网格尺寸: {step_x} x {step_y}")
print(f"空间范围: X=[{grid_x[0]:.2f}, {grid_x[-1]:.2f}], Y=[{grid_y[0]:.2f}, {grid_y[-1]:.2f}]")
# 尝试不同的变差函数模型
z_interpolated = None
successful_model = None
for model in self.variogram_models:
try:
print(f"尝试使用 {model} 变差函数模型...")
# 动态设置参数
nlags = min(20, max(6, len(x) // 3))
n_closest_points = min(12, max(4, len(x) // 2))
OK = OrdinaryKriging(
x, y, z,
variogram_model=model,
verbose=False,
enable_plotting=False,
coordinates_type="euclidean",
nlags=nlags
)
start_time = time.perf_counter()
z_interpolated, ss = OK.execute(
"grid", grid_x, grid_y,
backend="loop",
n_closest_points=n_closest_points
)
end_time = time.perf_counter()
successful_model = model
self.last_used_model = model
print(f"使用 {model} 模型插值成功,耗时: {end_time - start_time:.2f}")
break
except Exception as e:
print(f"模型 {model} 失败: {str(e)}")
continue
if z_interpolated is None:
print("错误:所有变差函数模型均失败,无法完成插值")
return None
# 检查插值结果
if np.all(np.isnan(z_interpolated)):
print("警告插值结果全为NaN值")
return None
nan_count = np.sum(np.isnan(z_interpolated))
total_count = z_interpolated.size
nan_percentage = (nan_count / total_count) * 100
print(f"插值完成,使用模型: {successful_model}")
print(f"结果统计: 总像元数={total_count}, NaN像元数={nan_count} ({nan_percentage:.1f}%)")
print(f"数值范围: [{np.nanmin(z_interpolated):.3f}, {np.nanmax(z_interpolated):.3f}]")
# 保存结果
if output_path and proj:
success = self.save_raster(z_interpolated, grid_x, grid_y, spatial_resolution, proj, output_path)
if success:
print(f"结果已保存至: {output_path}")
else:
print(f"保存失败: {output_path}")
return z_interpolated
def save_raster(self, data: np.ndarray, grid_x: np.ndarray, grid_y: np.ndarray,
spatial_resolution: float, proj: str, output_path: str) -> bool:
"""
保存插值结果为栅格文件
Args:
data: 插值结果数组
grid_x: X方向网格
grid_y: Y方向网格
spatial_resolution: 空间分辨率
proj: 投影信息
output_path: 输出路径
Returns:
是否保存成功
"""
try:
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 创建GDAL数据集
driver = gdal.GetDriverByName("GTiff")
step_x, step_y = data.shape[1], data.shape[0]
dataset = driver.Create(
output_path, step_x, step_y, 1, gdal.GDT_Float64,
options=["COMPRESS=LZW", "TILED=YES"]
)
if dataset is None:
print(f"错误:无法创建输出文件 {output_path}")
return False
# 设置地理变换参数
x_min, y_max = grid_x[0], grid_y[-1]
geotransform = (x_min, spatial_resolution, 0, y_max, 0, -spatial_resolution)
dataset.SetGeoTransform(geotransform)
# 设置投影
dataset.SetProjection(proj)
# 写入数据
band = dataset.GetRasterBand(1)
band.WriteArray(data)
band.SetNoDataValue(np.nan)
# 计算统计信息
band.ComputeStatistics(0)
# 清理资源
band.FlushCache()
dataset.FlushCache()
del dataset
return True
except Exception as e:
print(f"保存栅格文件时出错: {str(e)}")
return False
# 保持向后兼容性的函数接口
@timeit
def interpolate_kriging_pykrige(x, y, z, proj, spatial_resolution, output_path=None):
"""
执行克里金插值(向后兼容接口)
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
proj: 投影信息
spatial_resolution: 空间分辨率
output_path: 输出路径
Returns:
插值结果数组
"""
interpolator = KrigingInterpolator()
return interpolator.interpolate(x, y, z, spatial_resolution, output_path, proj)
def batch_kriging_interpolation(input_folder: str, ref_img_path: str,
output_folder: str, spatial_resolution: float = 1.0,
file_pattern: str = "*.csv") -> None:
"""
批量克里金插值处理
Args:
input_folder: 输入CSV文件夹路径
ref_img_path: 参考影像路径(用于获取投影信息)
output_folder: 输出文件夹路径
spatial_resolution: 空间分辨率
file_pattern: 文件匹配模式
"""
# 验证输入路径
if not os.path.exists(input_folder):
raise FileNotFoundError(f"输入文件夹不存在: {input_folder}")
if not os.path.exists(ref_img_path):
raise FileNotFoundError(f"参考影像不存在: {ref_img_path}")
# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)
# 获取参考影像的投影信息
try:
dataset = gdal.Open(ref_img_path)
if dataset is None:
raise ValueError(f"无法打开参考影像: {ref_img_path}")
im_proj = dataset.GetProjection()
del dataset
except Exception as e:
print(f"获取投影信息失败: {str(e)}")
return
# 查找CSV文件
csv_files = glob.glob(os.path.join(input_folder, file_pattern))
if not csv_files:
print(f"{input_folder} 中未找到匹配 {file_pattern} 的文件")
return
print(f"找到 {len(csv_files)} 个CSV文件待处理")
# 创建插值器
interpolator = KrigingInterpolator()
successful_count = 0
failed_count = 0
for i, csv_path in enumerate(csv_files, 1):
filename = os.path.basename(csv_path)
print(f"\n[{i}/{len(csv_files)}] 处理文件: {filename}")
try:
# 读取CSV文件
# 支持多种分隔符
try:
pos_content = np.loadtxt(csv_path, delimiter='\t')
except ValueError:
try:
pos_content = np.loadtxt(csv_path, delimiter=',')
except ValueError:
pos_content = np.loadtxt(csv_path, delimiter=';')
if pos_content.shape[1] < 3:
print(f"跳过文件 {filename}列数不足需要至少3列x, y, z")
failed_count += 1
continue
# 数据统计
total_points = len(pos_content)
nan_points = np.sum(np.isnan(pos_content[:, 2]))
print(f"数据点统计: 总计{total_points}个, NaN值{nan_points}")
# 构建输出路径
base_name = os.path.splitext(filename)[0]
output_filename = f"{base_name}_kriging.tif"
output_path = os.path.join(output_folder, output_filename)
# 执行插值
result = interpolator.interpolate(
pos_content[:, 0], # x
pos_content[:, 1], # y
pos_content[:, 2], # z
spatial_resolution,
output_path,
im_proj
)
if result is not None:
print(f"✓ 处理成功: {output_filename}")
successful_count += 1
else:
print(f"✗ 处理失败: {filename}")
failed_count += 1
except Exception as e:
print(f"✗ 处理文件 {filename} 时出错: {str(e)}")
failed_count += 1
# 输出总结
print(f"\n{'='*60}")
print(f"批量处理完成")
print(f"成功: {successful_count} 个文件")
print(f"失败: {failed_count} 个文件")
print(f"输出目录: {output_folder}")
if __name__ == '__main__':
# 示例用法
print("克里金插值模块示例")
# 配置参数(根据实际情况修改)
input_folder = r"data/processed/predictions" # CSV文件夹路径
ref_img_path = r"data/raw/reference_image.tif" # 参考影像路径
output_folder = r"data/processed/kriging_results" # 输出文件夹路径
spatial_resolution = 1.0 # 空间分辨率(米)
try:
# 执行批量插值
batch_kriging_interpolation(
input_folder=input_folder,
ref_img_path=ref_img_path,
output_folder=output_folder,
spatial_resolution=spatial_resolution
)
except Exception as e:
print(f"批量处理失败: {str(e)}")
print("\n请检查以下事项:")
print("1. 输入文件夹和参考影像路径是否正确")
print("2. CSV文件格式是否正确至少包含x, y, z三列")
print("3. 是否安装了必要的依赖库pykrige, gdal等")

417
src/utils/lapulasi_otsu.py Normal file
View File

@ -0,0 +1,417 @@
import numpy as np
import cv2
from osgeo import gdal
from collections import Counter
from typing import Optional, Union, Tuple
gdal.UseExceptions()
def laplacian_filter(image):
"""
拉普拉斯算子纹理提取
使用二阶偏微分对图像进行卷积,获取纹理图像
公式: η²g(m,n) = g(m,n+1) + g(m,n-1) + g(m+1,n) + g(m-1,n) - 4g(m,n)
参数:
image: 输入图像全波段叠加后的灰度图像F
返回:
L: 拉普拉斯滤波后的纹理图像
"""
# 拉普拉斯算子核4邻域中心5个像元
kernel = np.array([[0, 1, 0],
[1, -4, 1],
[0, 1, 0]], dtype=np.float32)
filtered_image = cv2.filter2D(image, -1, kernel)
return filtered_image
def apply_threshold(image, threshold):
"""
二值化处理(分割耀光区域)
公式:
- 对于L: W = {1, L>0; 0, L≤0}
- 对于F: S = {1, F>N; 0, F≤N}
参数:
image: 输入图像
threshold: 阈值
返回:
binary_image: 二值化图像0或1
"""
_, binary_image = cv2.threshold(image, threshold, 1, cv2.THRESH_BINARY)
return binary_image
def morphological_dilation(image, iterations=1):
"""
形态学处理(膨胀操作)
对耀光纹理区域进行扩展,使得检测到的耀光区域更加连贯
参数:
image: 输入二值图像
iterations: 膨胀迭代次数
返回:
dilated_image: 膨胀后的图像
"""
kernel = np.ones((3, 3), np.uint8)
dilated_image = cv2.dilate(image, kernel, iterations=iterations)
return dilated_image
def calculate_area_difference(W, S):
"""
计算面积差值
公式: rq = area(W) - area(S)
其中 area(W) 表示拉普拉斯纹理提取区域的面积(用于定位耀光位置)
area(S) 表示通过阈值分割全波段叠加图像得到的耀光区域面积
参数:
W: 纹理提取区域(二值图像)
S: 阈值分割区域(二值图像)
返回:
rq: 面积差值
area_W: 纹理提取区域面积
area_S: 阈值分割区域面积
"""
area_W = np.sum(W)
area_S = np.sum(S)
rq = area_W - area_S
return rq, area_W, area_S
def multi_band_weighted_sum(image_bands, water_mask):
"""
全波段叠加:将所有波段的遥感反射率加权叠加形成一个灰度图像
公式: F = Σ(i=1 to n) R(λi) × G
其中 G 为水体二值化图像F 为灰度图像
参数:
image_bands: 多波段图像数组,形状为 (rows, cols, bands)
water_mask: 水体二值化图像 G形状为 (rows, cols)值为0或1
返回:
F: 全波段加权叠加后的灰度图像
"""
# 确保water_mask是二值化的0或1
if water_mask.dtype != np.float32 and water_mask.dtype != np.float64:
water_mask = water_mask.astype(np.float32)
# 对每个波段进行加权叠加F = Σ R(λi) × G
F = np.zeros((image_bands.shape[0], image_bands.shape[1]), dtype=np.float32)
for band_idx in range(image_bands.shape[2]):
F += image_bands[:, :, band_idx] * water_mask
return F
def find_optimal_threshold(F, u=0.1, q=50, r=20, max_iterations=None):
"""
通过迭代找到最佳阈值
算法原理:
1. 使用拉普拉斯算子提取耀光纹理信息(用于定位耀光的位置)
2. 通过全波段叠加图像F进行阈值分割提取耀光面积
3. 当纹理提取区域的面积与通过阈值分割得到的耀光区域面积差最小时,确定最佳的阈值
步骤:
1. 对全波段叠加图像F进行拉普拉斯纹理提取得到L
2. 对L进行二值化W = {1, L>0; 0, L≤0}(用于定位耀光位置)
3. 对W和S进行形态学膨胀r次
4. 设定阈值N的初始值为F的最小值表示为Nf
5. 对初始值叠加数值u更新阈值迭代q次
6. 每次迭代记录W与S的面积差值到r行、q列的数组R
7. 寻找R中每列数组中最小值所在行数形成新的数组aind
8. 统计aind的众数得到M即为最佳的叠加次数
9. 最佳阈值Nf = min(F) + u·M
参数:
F: 全波段加权叠加后的灰度图像
u: 阈值更新步长默认0.1(论文参数)
q: 迭代次数默认50论文参数
r: 形态学膨胀次数默认20论文参数
max_iterations: 最大迭代次数如果指定则使用此值替代q
返回:
optimal_threshold: 最佳阈值Nf
optimal_S: 最佳阈值对应的耀光区域S通过阈值分割F得到
optimal_W: 最佳阈值对应的纹理区域W通过拉普拉斯提取得到
M: 最佳迭代次数索引
area_differences: 面积差值矩阵R (r行, q列)
thresholds: 每次迭代的阈值列表
"""
# 如果指定了max_iterations使用它替代q
if max_iterations is not None:
q = max_iterations
# 初始化阈值为影像非0最小值
F_nonzero = F[F > 0]
if len(F_nonzero) > 0:
min_value = np.min(F_nonzero)
else:
# 如果所有值都为0使用一个很小的正数作为最小值
min_value = np.finfo(np.float32).eps
print("警告: F中所有值都为0使用极小值作为最小值")
# 步骤1: 对全波段叠加图像F进行拉普拉斯纹理提取用于定位耀光位置
print("进行拉普拉斯纹理提取...")
L = laplacian_filter(F)
# 步骤2: 对L进行二值化得到纹理区域W
# W = {1, L>0; 0, L≤0}
W = apply_threshold(L, 0.0) # 阈值为0即L>0为1L≤0为0
# 步骤3: 对W进行形态学膨胀r次
print(f"对纹理区域W进行形态学膨胀{r}次...")
W_dilated = morphological_dilation(W, iterations=r)
# 存储每次迭代的面积差值r行q列
# 注意论文中r是膨胀次数但这里R矩阵的r行应该对应不同的膨胀次数
# 根据论文描述应该是迭代q次每次记录面积差值
# 但论文提到"叠加q次"和"一共迭代r次"这里理解为迭代q次每次对W和S都膨胀r次
area_differences = [] # 存储每次迭代的面积差值
thresholds = []
W_masks = []
S_masks = []
print(f"开始迭代计算最佳阈值(迭代{q}步长u={u}...")
# 迭代更新阈值迭代q次
for i in range(q):
# 当前阈值Nf = min(F) + u·(i+1)
current_threshold = min_value + u * (i + 1)
thresholds.append(current_threshold)
# 步骤4: 对全波段叠加图像F进行阈值分割得到耀光区域S用于提取耀光面积
# S = {1, F>N; 0, F≤N}
S = apply_threshold(F, current_threshold)
# 步骤5: 对S进行形态学膨胀r次
S_dilated = morphological_dilation(S, iterations=r)
# 步骤6: 计算面积差值:纹理提取区域面积 vs 阈值分割区域面积
rq, area_W, area_S = calculate_area_difference(W_dilated, S_dilated)
area_differences.append(rq)
W_masks.append(W_dilated.copy())
S_masks.append(S_dilated.copy())
if (i + 1) % 10 == 0:
print(f" 迭代 {i+1}/{q}: 阈值={current_threshold:.4f}, 面积差值={rq:.2f}")
# 步骤7: 寻找R中每列数组中最小值所在行数
# 注意:论文中提到"r行、q列的数组R"但根据算法描述应该是q次迭代
# 这里理解为将area_differences重新组织成矩阵形式如果需要
# 但根据论文描述,应该是直接找到最小面积差值对应的迭代次数
area_differences_array = np.array(area_differences)
# 步骤8: 找到最小面积差值所在的行数(索引)
# argmin_q(R(r,q)):找到最小面积差值所在的行数
min_indices = np.where(area_differences_array == np.min(area_differences_array))[0]
# 步骤9: 通过众数统计找到最频繁出现的行数确定最佳的叠加次数M
# Mode(aind):通过众数统计找到最频繁出现的行数
if len(min_indices) > 0:
# 如果最小值出现多次,使用众数统计
counter = Counter(min_indices)
most_common = counter.most_common(1)[0]
M = most_common[0] # 最佳迭代次数索引从0开始
else:
M = 0
# 步骤10: 计算最终阈值
# Nf = min(F) + u·M
# 注意M是索引从0开始所以实际迭代次数是M+1
optimal_threshold = min_value + u * (M + 1)
# 获取最佳阈值对应的掩膜
optimal_W = W_masks[M]
optimal_S = S_masks[M]
print(f"最佳迭代次数索引: M={M} (第{M+1}次迭代)")
print(f"最佳阈值: Nf={optimal_threshold:.4f}")
print(f"最小面积差值: {np.min(area_differences_array):.2f}")
return optimal_threshold, optimal_S, optimal_W, M, area_differences, thresholds
def generate_glint_mask(bsq_file, water_mask=None, u=0.1, q=50, r=20, max_iterations=None, output_file=None):
"""
生成耀光掩膜
算法流程:
1. 全波段叠加F = Σ(i=1 to n) R(λi) × GG为水体二值化图像
2. 拉普拉斯算子提取纹理信息(用于定位耀光位置)
3. 通过阈值分割全波段叠加图像提取耀光面积
4. 当纹理提取区域面积与阈值分割区域面积差最小时,确定最佳阈值
参数:
bsq_file: 输入的BSQ文件路径
water_mask: 水体二值化图像G可以是
- None: 自动生成基于所有像素即全为1的掩膜
- numpy数组: 直接使用数组作为掩膜,形状为 (rows, cols)值为0或1
- 文件路径: 栅格文件路径(.tif/.dat),将自动读取
u: 阈值更新步长默认0.1(论文参数)
q: 迭代次数默认50论文参数
r: 形态学膨胀次数默认20论文参数
max_iterations: 最大迭代次数如果指定则使用此值替代q
output_file: 输出文件路径如果为None则自动生成
返回:
tuple: (耀光掩膜文件路径, 纹理提取图像文件路径)
- 耀光掩膜文件路径: 通过阈值分割全波段叠加图像得到的最终掩膜S掩膜
- 纹理提取图像文件路径: 拉普拉斯纹理提取后的二值化图像W掩膜
"""
# 读取BSQ文件
bsq_dataset = gdal.Open(bsq_file)
if bsq_dataset is None:
raise ValueError(f"无法打开文件: {bsq_file}")
# 获取影像数据
bands = bsq_dataset.RasterCount
rows = bsq_dataset.RasterYSize
cols = bsq_dataset.RasterXSize
print(f"影像尺寸: {rows} x {cols}, 波段数: {bands}")
# 读取所有波段
print("正在读取所有波段数据...")
image_bands = np.zeros((rows, cols, bands), dtype=np.float32)
for band in range(bands):
image_bands[:, :, band] = bsq_dataset.GetRasterBand(band + 1).ReadAsArray().astype(np.float32)
if (band + 1) % 20 == 0:
print(f" 已读取 {band+1}/{bands} 个波段")
# 处理水体掩膜G
if water_mask is None:
# 如果没有提供水体掩膜使用全图所有像素为1
print("未提供水体掩膜,使用全图进行处理")
G = np.ones((rows, cols), dtype=np.float32)
elif isinstance(water_mask, np.ndarray):
# 如果直接提供了numpy数组
if water_mask.shape != (rows, cols):
raise ValueError(f"水体掩膜尺寸 {water_mask.shape} 与影像尺寸 {(rows, cols)} 不匹配")
G = water_mask.astype(np.float32)
# 确保是二值化的0或1
G = np.where(G > 0, 1.0, 0.0)
elif isinstance(water_mask, str):
# 如果是文件路径,读取文件
print(f"从文件读取水体掩膜: {water_mask}")
water_dataset = gdal.Open(water_mask)
if water_dataset is None:
raise ValueError(f"无法打开水体掩膜文件: {water_mask}")
if water_dataset.RasterXSize != cols or water_dataset.RasterYSize != rows:
raise ValueError(f"水体掩膜尺寸与影像尺寸不匹配")
G = water_dataset.GetRasterBand(1).ReadAsArray().astype(np.float32)
water_dataset = None
# 确保是二值化的0或1
G = np.where(G > 0, 1.0, 0.0)
else:
raise ValueError(f"不支持的水体掩膜类型: {type(water_mask)}")
print(f"水体掩膜统计: 水体像素数={np.sum(G)}, 总像素数={rows*cols}, 水体比例={np.sum(G)/(rows*cols)*100:.2f}%")
# 步骤1: 全波段叠加 F = Σ(i=1 to n) R(λi) × G
print("开始全波段叠加...")
F = multi_band_weighted_sum(image_bands, G)
print(f"全波段叠加完成F值范围: [{np.min(F):.4f}, {np.max(F):.4f}]")
print("开始计算最佳阈值...")
# 找到最佳阈值
optimal_threshold, glint_mask, texture_mask, optimal_iteration, area_diffs, thresholds = find_optimal_threshold(
F, u=u, q=q, r=r, max_iterations=max_iterations
)
print(f"\n=== 最佳阈值计算结果 ===")
print(f"最佳阈值: {optimal_threshold:.4f}")
print(f"最佳迭代次数: {optimal_iteration + 1}")
print(f"最小面积差值: {np.min(area_diffs):.4f}")
print(f"纹理提取区域面积: {np.sum(texture_mask)}")
print(f"阈值分割区域面积: {np.sum(glint_mask)}")
# 设置输出文件路径
if output_file is None:
output_file = 'glint_mask.tif'
# 生成纹理提取图像输出路径
texture_output_file = output_file.replace('.tif', '_texture.tif')
if texture_output_file == output_file: # 如果没有.tif扩展名
texture_output_file = output_file + '_texture.tif'
# 保存纹理提取图像W掩膜拉普拉斯提取的纹理区域
print(f"\n保存输出文件...")
driver = gdal.GetDriverByName('GTiff')
texture_dataset = driver.Create(texture_output_file, cols, rows, 1, gdal.GDT_Byte)
texture_dataset.SetGeoTransform(bsq_dataset.GetGeoTransform())
texture_dataset.SetProjection(bsq_dataset.GetProjection())
texture_mask_uint8 = (texture_mask * 255).astype(np.uint8)
texture_dataset.GetRasterBand(1).WriteArray(texture_mask_uint8)
texture_dataset = None
print(f"纹理提取图像已保存至: {texture_output_file}")
# 保存耀光掩膜S掩膜通过阈值分割全波段叠加图像得到
out_dataset = driver.Create(output_file, cols, rows, 1, gdal.GDT_Byte)
# 设置地理变换和投影信息
out_dataset.SetGeoTransform(bsq_dataset.GetGeoTransform())
out_dataset.SetProjection(bsq_dataset.GetProjection())
# 写入掩膜数据转换为0-255范围
glint_mask_uint8 = (glint_mask * 255).astype(np.uint8)
out_dataset.GetRasterBand(1).WriteArray(glint_mask_uint8)
# 关闭数据集
out_dataset = None
bsq_dataset = None
print(f"耀光掩膜已保存至: {output_file}")
return output_file, texture_output_file
# 使用示例
if __name__ == "__main__":
bsq_file = r"D:\BaiduNetdiskDownload\yaobao\test_glint.bsq" # 输入的BSQ文件
output_file = r'D:\BaiduNetdiskDownload\yaobao\glint\lapulas_otsu_glint_mask.tif'
# water_mask_file = r'path/to/water_mask.tif' # 可选:水体掩膜文件路径
# 示例1: 使用论文默认参数q=50, r=20, u=0.1
mask_file, texture_file = generate_glint_mask(
bsq_file,
water_mask=None,
u=0.1,
q=50,
r=20,
output_file=output_file
)
print(f"\n处理完成,耀光掩膜保存在: {mask_file}")
print(f"纹理提取图像保存在: {texture_file}")
# 示例2: 使用水体掩膜文件
# mask_file, texture_file = generate_glint_mask(
# bsq_file,
# water_mask=water_mask_file,
# u=0.1,
# q=50,
# r=20,
# output_file=output_file
# )
# 示例3: 使用numpy数组作为水体掩膜
# import numpy as np
# water_mask_array = np.ones((rows, cols), dtype=np.float32) # 示例全为1的掩膜
# mask_file, texture_file = generate_glint_mask(
# bsq_file,
# water_mask=water_mask_array,
# u=0.1,
# q=50,
# r=20,
# output_file=output_file
# )

1061
src/utils/sampling.py Normal file

File diff suppressed because it is too large Load Diff

15
src/utils/type_define.py Normal file
View File

@ -0,0 +1,15 @@
from enum import Enum
class CoorType(Enum):
latlong = 0 # 经纬度坐标
utm = 1 # UTM坐标
depend_on_image = 2 # 依赖影像坐标
pixel = 3 # 像素坐标
class ImgType(Enum):
ref = 0 # 反射率图像
content = 1 # 含量图像
class PointPosStrategy(Enum):
nearest_single = 0 # 最近单点
four_quadrant = 1 # 四象限

174
src/utils/util.py Normal file
View File

@ -0,0 +1,174 @@
import os, spectral
import time
import numpy as np
from osgeo import gdal
from enum import Enum, unique
import math
import json
class CoorType(Enum):
depend_on_image = 0 # 影像是啥类型坐标就是啥坐标
pixel = 1
class Timer: # Context Manager
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, exc_type, exc_value, traceback):
print(exc_type, exc_value, traceback)
print(f"Run Time: {time.time() - self.start}")
def timeit(f): # decorator
def wraper(*args, **kwargs):
start = time.time()
ret = f(*args, **kwargs)
print(f.__name__ + " run time: " + str(round(time.time() - start, 2)) + " s.")
return ret
return wraper
def get_hdr_file_path(file_path):
return os.path.splitext(file_path)[0] + ".hdr"
def find_band_number(wav1, imgpath):
in_hdr_dict = spectral.envi.read_envi_header(get_hdr_file_path(imgpath))
wavelengths = np.array(in_hdr_dict['wavelength']).astype('float64')
differences = np.abs(wavelengths - wav1)
min_position = np.argmin(differences)
return int(min_position)
@timeit
def average_bands(start_wave, end_wave, imgpath):
start_bandnumber = find_band_number(start_wave, imgpath)
end_bandnumber = find_band_number(end_wave, imgpath)
dataset = gdal.Open(imgpath)
averaged_band = 1
for i in range(start_bandnumber, end_bandnumber + 1):
if i == start_bandnumber:
averaged_band = dataset.GetRasterBand(i + 1).ReadAsArray()
else:
tmp = dataset.GetRasterBand(i + 1).ReadAsArray()
averaged_band = (averaged_band + tmp) / 2
del dataset
return averaged_band
def exclude_by_mask(band, water_mask_path, ignore_value=0):
dataset = gdal.Open(water_mask_path)
data_tmp = dataset.GetRasterBand(1).ReadAsArray()
del dataset
band[np.where(data_tmp == ignore_value)] = 0
return band
@timeit
def average_bands_in_mask(start_wave, end_wave, imgpath, water_mask_path):
tmp = average_bands(start_wave, end_wave, imgpath)
tmp = exclude_by_mask(tmp, water_mask_path)
# raster_fn_out_tmp = append2filename(imgpath, "glint_delete")
# write_bands(imgpath, raster_fn_out_tmp, tmp)
return tmp
def get_average_value(dataset, x, y, band_number, window):
spectral_tmp = dataset.ReadAsArray(x, y, 1, 1)
average_value = spectral_tmp[band_number - window:band_number + window, :, :].mean()
return average_value
def get_valid_extent(dataset, data_ignore_value=0):
pass
def write_bands(imgpath_in, imgpath_out, *args):
# 将输入的波段(可变)写入文件
dataset = gdal.Open(imgpath_in)
im_width = dataset.RasterXSize
im_height = dataset.RasterYSize
num_bands = dataset.RasterCount
geotransform = dataset.GetGeoTransform()
im_proj = dataset.GetProjection()
format = "ENVI"
driver = gdal.GetDriverByName(format)
dst_ds = driver.Create(imgpath_out, im_width, im_height, len(args), gdal.GDT_Float32,
options=["INTERLEAVE=BSQ"])
dst_ds.SetGeoTransform(geotransform)
dst_ds.SetProjection(im_proj)
for i in range(len(args)):
dst_ds.GetRasterBand(i + 1).WriteArray(args[i])
del dataset, dst_ds
def append2filename(file_path, txt2add):
imgfile_out_tmp = os.path.splitext(file_path)
new_file_path = imgfile_out_tmp[0] + "_" + txt2add + imgfile_out_tmp[1]
return new_file_path
def write_fields_to_hdrfile(source_hdr_file, dest_hdr_file):
source_fields = spectral.envi.read_envi_header(source_hdr_file)
dest_fields = spectral.envi.read_envi_header(dest_hdr_file)
with open(dest_hdr_file, "a", encoding='utf-8') as f:
for key in source_fields.keys():
if key in dest_fields or key == "description":
continue
if key == "data ignore value" or key == "wavelength" or key == "wavelength units":
if type(source_fields[key]) == list:
f.write(key + " = {" + ", ".join(source_fields[key]) + "}\n")
else:
f.write(key + " = " + source_fields[key] + "\n")
def getnearest(m, invalid_value=0):
layer_number = math.floor(m.shape[0] / 2)
center = layer_number
for i in range(layer_number + 1):
orig = (center - i, center - i)
tmp = m[center - i:center + i + 1, center - i:center + i + 1]
valid_indices = np.where((tmp != invalid_value) & np.isfinite(tmp))
if valid_indices[0].shape[0] != 0:
return int(valid_indices[0][0] + orig[0]), int(valid_indices[1][0] + orig[0]) # (y ,x)
return None, None
def load_numpy_dict_from_json(filename):
with open(filename, 'r') as f:
np_dict = json.load(f)
# 将字典中的列表转换回 NumPy 数组
model_type = np_dict['model_type']
model_info = np.array(np_dict['model_info'])
precision = np.array(np_dict['accuracy'])
return model_type, model_info, precision

862
src/utils/water_index.py Normal file
View File

@ -0,0 +1,862 @@
import pandas as pd
import numpy as np
import re
from typing import Dict, List, Optional, Union
class WaterQualityIndexCalculator:
"""
水质光谱指数计算器
为每个算法创建单独的函数,自动查找最接近的波长列
"""
def __init__(self):
self.references = {}
def find_closest_wavelength(self, df_columns: List[str], target_wl: int) -> str:
"""
在数据框列名中查找最接近目标波长的列
Args:
df_columns: 数据框的所有列名
target_wl: 目标波长
Returns:
最接近的列名
"""
# 提取所有数值型波长
numeric_wavelengths = []
for col in df_columns:
try:
# 从列名中提取数字
numbers = re.findall(r'\d+', col)
if numbers:
wavelength = int(numbers[0])
numeric_wavelengths.append((col, wavelength))
except:
continue
if not numeric_wavelengths:
raise ValueError(f"无法从列名中提取波长信息: {df_columns}")
# 找到最接近的波长
closest_col, closest_wl = min(numeric_wavelengths,
key=lambda x: abs(x[1] - target_wl))
print(f"为波长 {target_wl}nm 找到最接近的列: {closest_col} ({closest_wl}nm)")
return closest_col
# =========================================================================
# 叶绿素算法
# =========================================================================
def chl_Al10SABI(self, df: pd.DataFrame) -> pd.Series:
"""
Surface Algal Bloom Index (SABI) for chlorophyll detection
参考文献: Alawadi, F. Detection of surface algal blooms using the newly
developed algorithm surface algal bloom index (SABI). Proc. SPIE 2010, 7825.
"""
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
w529 = df[self.find_closest_wavelength(df.columns, 529)]
result = (w857 - w644) / (w458 + w529)
return result
def chl_Am092Bsub(self, df: pd.DataFrame) -> pd.Series:
"""
Baseline subtraction algorithm for chlorophyll
参考文献: Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.
Novel optical techniques for detecting and classifying toxic dinoflagellate
Karenia brevis blooms using satellite imagery. Opt. Express 2009, 17, 91269144.
"""
w681 = df[self.find_closest_wavelength(df.columns, 681)]
w665 = df[self.find_closest_wavelength(df.columns, 665)]
result = w681 - w665
return result
def chl_Be16FLHblue(self, df: pd.DataFrame) -> pd.Series:
"""
Fluorescence Line Height algorithm with blue baseline for chlorophyll
参考文献: Beck, R.A. and 22 others; Comparison of satellite reflectance
algorithms for estimating chlorophyll-a in a temperate reservoir using
coincident hyperspectral aircraft imagery and dense coincident surface
observations, Remote Sens. Environ., 2016, 178, 15-30.
"""
w529 = df[self.find_closest_wavelength(df.columns, 529)]
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
result = w529 - (w644 + (w458 - w644))
return result
def chl_Be16FLHviolet(self, df: pd.DataFrame) -> pd.Series:
"""
Fluorescence Line Height algorithm with violet baseline for chlorophyll
参考文献: Beck, R.A. and 22 others; Comparison of satellite reflectance
algorithms for estimating chlorophyll-a in a temperate reservoir using
coincident hyperspectral aircraft imagery and dense coincident surface
observations, Remote Sens. Environ., 2016, 178, 15-30.
"""
w529 = df[self.find_closest_wavelength(df.columns, 529)]
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w429 = df[self.find_closest_wavelength(df.columns, 429)]
result = w529 - (w644 + (w429 - w644))
return result
def chl_Be16NDTIblue(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Turbidity Index with blue band for chlorophyll
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 543.
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
result = (w658 - w458) / (w658 + w458)
return result
def chl_Be16NDTIviolet(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Turbidity Index with violet band for chlorophyll
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 544.
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w444 = df[self.find_closest_wavelength(df.columns, 444)]
result = (w658 - w444) / (w658 + w444)
return result
def chl_De933BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Band difference algorithm for chlorophyll
参考文献: Dekker, A.; Detection of the optical water quality parameters for
eutrophic waters by high resolution remote sensing, Ph.D. thesis, 1993,
Free University, Amsterdam.
"""
w600 = df[self.find_closest_wavelength(df.columns, 600)]
w648 = df[self.find_closest_wavelength(df.columns, 648)]
w625 = df[self.find_closest_wavelength(df.columns, 625)]
result = w600 - w648 - w625
return result
def chl_Gi033BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Gitelson algorithm for chlorophyll estimation
参考文献: Gitelson, A.A.; U. Gritz, and M. N. Merzlyak.; Relationships between
leaf chlorophyll content and spectral reflectance and algorithms for
non-destructive chlorophyll assessment in higher plant leaves. J. Plant Phys. 2003, 160, 271-282.
"""
w672 = df[self.find_closest_wavelength(df.columns, 672)]
w715 = df[self.find_closest_wavelength(df.columns, 715)]
w757 = df[self.find_closest_wavelength(df.columns, 757)]
result = ((1 / w672) - (1 / w715)) * w757
return result
def chl_Kn07KIVU(self, df: pd.DataFrame) -> pd.Series:
"""
Kneubuhler algorithm for chlorophyll in Lake Kivu
参考文献: Kneubuhler, M.; Frank T.; Kellenberger, T.W; Pasche N.; Schmid M.;
Mapping chlorophyll-a in Lake Kivu with remote sensing methods. 2007,
Proceedings of the Envisat Symposium 2007, Montreux, Switzerland 2327 April 2007 (ESA SP-636, July 2007).
"""
w458 = df[self.find_closest_wavelength(df.columns, 458)]
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w529 = df[self.find_closest_wavelength(df.columns, 529)]
result = (w458 - w644) / w529
return result
def chl_MM12NDCI(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Chlorophyll Index
参考文献: Mishra, S.; and Mishra, D.R. Normalized difference chlorophyll
index: A novel model for remote estimation of chlorophyll-a concentration
in turbid productive waters, Remote Sens. Environ., 2012, 117, 394-406
"""
w715 = df[self.find_closest_wavelength(df.columns, 715)]
w686 = df[self.find_closest_wavelength(df.columns, 686)]
result = (w715 - w686) / (w715 + w686)
return result
def chl_Zh10FLH(self, df: pd.DataFrame) -> pd.Series:
"""
Zhao fluorescence line height algorithm for chlorophyll
参考文献: Zhao, D.Z.; Xing, X.G.; Liu, Y.G.; Yang, J.H.; Wang, L. The relation of
chlorophyll-a concentration with the reflectance peak near 700 nm in
algae-dominated waters and sensitivity of fluorescence algorithms for
detecting algal bloom. Int. J. Remote Sens. 2010, 31, 39-48
"""
w686 = df[self.find_closest_wavelength(df.columns, 686)]
w715 = df[self.find_closest_wavelength(df.columns, 715)]
w672 = df[self.find_closest_wavelength(df.columns, 672)]
w751 = df[self.find_closest_wavelength(df.columns, 751)]
result = w686 - (w715 + (w672 - w751))
return result
# =========================================================================
# 蓝藻/藻蓝蛋白算法 (BGA/PC)
# =========================================================================
def BGA_Am09KBBI(self, df: pd.DataFrame) -> pd.Series:
"""
Karenia Brevis Bloom Index for cyanobacteria/phycocyanin
参考文献: Amin, R.; Zhou, J.; Gilerson, A.; Gross, B.; Moshary, F.; Ahmed, S.;
Novel optical techniques for detecting and classifying toxic dinoflagellate
Karenia brevis blooms using satellite imagery, Optics Express, 2009, 17, 11, 1-13.
"""
w686 = df[self.find_closest_wavelength(df.columns, 686)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
result = (w686 - w658) / (w686 + w658)
return result
def BGA_Be162B643sub629(self, df: pd.DataFrame) -> pd.Series:
"""
Band subtraction algorithm for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538.
"""
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w629 = df[self.find_closest_wavelength(df.columns, 629)]
result = w644 - w629
return result
def BGA_Be162B700sub601(self, df: pd.DataFrame) -> pd.Series:
"""
Band subtraction algorithm for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539.
"""
w700 = df[self.find_closest_wavelength(df.columns, 700)]
w601 = df[self.find_closest_wavelength(df.columns, 601)]
result = w700 - w601
return result
def BGA_Be162BsubPhy(self, df: pd.DataFrame) -> pd.Series:
"""
Band subtraction algorithm for phytoplankton/phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540.
"""
w715 = df[self.find_closest_wavelength(df.columns, 715)]
w615 = df[self.find_closest_wavelength(df.columns, 615)]
result = w715 - w615
return result
def BGA_Be16FLHBlueRedNIR(self, df: pd.DataFrame) -> pd.Series:
"""
Fluorescence Line Height with Blue-Red-NIR baseline for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538.
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
result = w658 - (w857 + (w458 - w857))
return result
def BGA_Be16FLHGreenRedNIR(self, df: pd.DataFrame) -> pd.Series:
"""
Fluorescence Line Height with Green-Red-NIR baseline for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539.
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w558 = df[self.find_closest_wavelength(df.columns, 558)]
result = w658 - (w857 + (w558 - w857))
return result
def BGA_Be16FLHVioletRedNIR(self, df: pd.DataFrame) -> pd.Series:
"""
Fluorescence Line Height with Violet-Red-NIR baseline for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538.
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w444 = df[self.find_closest_wavelength(df.columns, 444)]
result = w658 - (w857 + (w444 - w857))
return result
def BGA_Be16MPI(self, df: pd.DataFrame) -> pd.Series:
"""
Maximum Peak Index for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539.
"""
w615 = df[self.find_closest_wavelength(df.columns, 615)]
w601 = df[self.find_closest_wavelength(df.columns, 601)]
w644 = df[self.find_closest_wavelength(df.columns, 644)]
result = (w615 - w601) - (w644 - w601)
return result
def BGA_Be16NDPhyI(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Phytoplankton Index for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 540.
"""
w700 = df[self.find_closest_wavelength(df.columns, 700)]
w622 = df[self.find_closest_wavelength(df.columns, 622)]
result = (w700 - w622) / (w700 + w622)
return result
def BGA_Be16NDPhyI644over615(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Phytoplankton Index (644/615) for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 541.
"""
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w615 = df[self.find_closest_wavelength(df.columns, 615)]
result = (w644 - w615) / (w644 + w615)
return result
def BGA_Be16NDPhyI644over629(self, df: pd.DataFrame) -> pd.Series:
"""
Normalized Difference Phytoplankton Index (644/629) for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 542.
"""
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w629 = df[self.find_closest_wavelength(df.columns, 629)]
result = (w644 - w629) / (w644 + w629)
return result
def BGA_Be16Phy2BDA644over629(self, df: pd.DataFrame) -> pd.Series:
"""
Band ratio algorithm (644/629) for phycocyanin
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 545.
"""
w644 = df[self.find_closest_wavelength(df.columns, 644)]
w629 = df[self.find_closest_wavelength(df.columns, 629)]
result = w644 / w629
return result
def BGA_Da052BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Band ratio algorithm (714/672) for phycocyanin
参考文献: Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A.,
Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial
blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672.
"""
w714 = df[self.find_closest_wavelength(df.columns, 714)]
w672 = df[self.find_closest_wavelength(df.columns, 672)]
result = w714 / w672
return result
def BGA_Go04MCI(self, df: pd.DataFrame) -> pd.Series:
"""
Maximum Chlorophyll Index for phycocyanin
参考文献: Gower, J.F.R.; Brown,L.; Borstad, G.A.; Observation of chlorophyll
fluorescence in west coast waters of Canada using the MODIS satellite sensor.
Can. J. Remote Sens., 2004, 30 (1), 1725.
"""
w709 = df[self.find_closest_wavelength(df.columns, 709)]
w681 = df[self.find_closest_wavelength(df.columns, 681)]
w753 = df[self.find_closest_wavelength(df.columns, 753)]
result = w709 - w681 - (w753 - w681)
return result
def BGA_HU103BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Hunter algorithm for phycocyanin
参考文献: Hunter, P.D.; Tyler, A.N.; Willby, N.J.; Gilvear, D.J.; The spatial
dynamics of vertical migration by Microcystis aeruginosa in a eutrophic
shallow lake: A case study using high spatial resolution time-series
airborne remote sensing. Limn. Oceanogr. 2008, 53, 2391-2406
"""
w615 = df[self.find_closest_wavelength(df.columns, 615)]
w600 = df[self.find_closest_wavelength(df.columns, 600)]
w725 = df[self.find_closest_wavelength(df.columns, 725)]
result = (((1 / w615) - (1 / w600)) - w725)
return result
def BGA_Ku15PhyCI(self, df: pd.DataFrame) -> pd.Series:
"""
Kudela Phytoplankton Community Index for phycocyanin
参考文献: Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K.,
Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial
blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-10.
"""
w681 = df[self.find_closest_wavelength(df.columns, 681)]
w665 = df[self.find_closest_wavelength(df.columns, 665)]
w709 = df[self.find_closest_wavelength(df.columns, 709)]
result = -1 * (w681 - w665 - (w709 - w665))
return result
def BGA_Ku15SLH(self, df: pd.DataFrame) -> pd.Series:
"""
Kudela Scattering Line Height for phycocyanin
参考文献: Kudela, R.M., Palacios, S.L., Austerberry, D.C., Accorsi, E.K.,
Guild, L.S.; Application of hyperspectral remote sensing to cyanobacterial
blooms in inland waters, Torres-Perez, J., 2015, Remote Sens. Environ., 2015, 167, 1-11.
"""
w715 = df[self.find_closest_wavelength(df.columns, 715)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
result = (w715 - w658) + (w715 - w658)
return result
def BGA_MI092BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Mishra band ratio algorithm (700/600) for phycocyanin
参考文献: Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for
predicting PC concentrations in cyanobacteria: A proximal hyperspectral
remote sensing approach. Remote Sens., 2009, 1, 758775.
"""
w700 = df[self.find_closest_wavelength(df.columns, 700)]
w600 = df[self.find_closest_wavelength(df.columns, 600)]
result = w700 / w600
return result
def BGA_MM092BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Mishra band ratio algorithm (724/600) for phycocyanin
参考文献: Mishra, S.; Mishra, D.R.; Schluchter, W. M., A novel algorithm for
predicting PC concentrations in cyanobacteria: A proximal hyperspectral
remote sensing approach. Remote Sens., 2009, 1, 758776.
"""
w724 = df[self.find_closest_wavelength(df.columns, 724)]
w600 = df[self.find_closest_wavelength(df.columns, 600)]
result = w724 / w600
return result
def BGA_MM12NDCIalt(self, df: pd.DataFrame) -> pd.Series:
"""
Alternative Normalized Difference Chlorophyll Index for phycocyanin
参考文献: Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to
quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett.,
2014, 9 (11), DOI:10.1088/1748-9326/9/11/114003
"""
w700 = df[self.find_closest_wavelength(df.columns, 700)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
result = (w700 - w658) / (w700 + w658)
return result
def BGA_MM143BDAopt(self, df: pd.DataFrame) -> pd.Series:
"""
Optimized band algorithm for phycocyanin
参考文献: Mishra, S.; Mishra, D.R.; A novel remote sensing algorithm to
quantify phycocyanin in cyanobacterial algal blooms, Env. Res. Lett.,
2014, 9 (11), DOI:10.1088/1748-9326/9/11/114004
"""
w629 = df[self.find_closest_wavelength(df.columns, 629)]
w659 = df[self.find_closest_wavelength(df.columns, 659)]
w724 = df[self.find_closest_wavelength(df.columns, 724)]
result = ((1 / w629) - (1 / w659)) * w724
return result
def BGA_SI052BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Simis band ratio algorithm (709/620) for phycocyanin
参考文献: Simis, S. G. H.; Peters, S.W. M.; Gons, H. J.; Remote sensing of
the cyanobacteria pigment phycocyanin in turbid inland water. Limn. Oceanogr., 2005, 50, 237245
"""
w709 = df[self.find_closest_wavelength(df.columns, 709)]
w620 = df[self.find_closest_wavelength(df.columns, 620)]
result = w709 / w620
return result
def BGA_SM122BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Mishra band ratio algorithm (709/600) for phycocyanin
参考文献: Mishra, S. Remote sensing of cyanobacteria in turbid productive
waters, PhD Dissertation. Mississippi State University, USA. 2012.
"""
w709 = df[self.find_closest_wavelength(df.columns, 709)]
w600 = df[self.find_closest_wavelength(df.columns, 600)]
result = w709 / w600
return result
def BGA_SY002BDA(self, df: pd.DataFrame) -> pd.Series:
"""
Schalles-Yacobi band ratio algorithm (650/625) for phycocyanin
参考文献: Schalles, J.; Yacobi, Y. Remote detection and seasonal patterns of
phycocyanin, carotenoid and chlorophyll-a pigments in eutrophic waters.
Archiv fur Hydrobiologie, Special Issues Advances in Limnology, 2000, 55,153168
"""
w650 = df[self.find_closest_wavelength(df.columns, 650)]
w625 = df[self.find_closest_wavelength(df.columns, 625)]
result = w650 / w625
return result
def BGA_Wy08CI(self, df: pd.DataFrame) -> pd.Series:
"""
Wynne Cyanobacteria Index for phycocyanin
参考文献: Wynne, T. T., Stumpf, R. P., Tomlinson, M. C., Warner, R. A.,
Tester, P. A., Dyble, J.; Relating spectral shape to cyanobacterial
blooms in the Laurentian Great Lakes. Int. J. Remote Sens., 2008, 29, 3665-3672.
"""
w686 = df[self.find_closest_wavelength(df.columns, 686)]
w672 = df[self.find_closest_wavelength(df.columns, 672)]
w715 = df[self.find_closest_wavelength(df.columns, 715)]
result = -1 * (w686 - w672 - (w715 - w672))
return result
# =========================================================================
# 浊度算法
# =========================================================================
def Turb_Be16GreenPlusRedBothOverViolet(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: (Green + Red) / Violet
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 538
"""
w558 = df[self.find_closest_wavelength(df.columns, 558)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w444 = df[self.find_closest_wavelength(df.columns, 444)]
result = (w558 + w658) / w444
return result
def Turb_Be16RedOverViolet(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: Red / Violet
参考文献: Beck, R.; Xu, M.; Zhan, S.; Liu, H.; Johansen, R.A.; Tong, S.;
Yang, B.; Shu, S.; Wu, Q.; Wang, S.; Berling, K.; Murray, A.; Emery, E.;
Reif, M.; Harwood, J.; Young, J.; Martin, M.; Stillings, G.; Stumpf, R.;
Su, H.; Ye, Z.; Huang, Y. Comparison of Satellite Reflectance Algorithms
for Estimating Phycocyanin Values and Cyanobacterial Total Biovolume in
a Temperate Reservoir Using Coincident Hyperspectral Aircraft Imagery
and Dense Coincident Surface Observations. Remote Sens. 2017, 9, 539
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w444 = df[self.find_closest_wavelength(df.columns, 444)]
result = w658 / w444
return result
def Turb_Bow06RedOverGreen(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: Red / Green
参考文献: Bowers, D. G., and C. E. Binding. 2006. "The Optical Properties of
Mineral Suspended Particles: A Review and Synthesis." Estuarine Coastal
and Shelf Science 67 (12): 219230. doi:10.1016/j.ecss.2005.11.010
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w558 = df[self.find_closest_wavelength(df.columns, 558)]
result = w658 / w558
return result
def Turb_Chip09NIROverGreen(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: NIR / Green
参考文献: Chipman, J. W.; Olmanson, L.G.; Gitelson, A.A.; Remote sensing
methods for lake management: A guide for resource managers and decision-makers. 2009.
"""
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w558 = df[self.find_closest_wavelength(df.columns, 558)]
result = w857 / w558
return result
def Turb_Dox02NIRoverRed(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: NIR / Red
参考文献: Doxaran, D., Froidefond, J.-M.; Castaing, P. ; A reflectance band
ratio used to estimate suspended matter concentrations in sediment-dominated
coastal waters, Remote Sens., 2002, 23, 5079-5085
"""
w857 = df[self.find_closest_wavelength(df.columns, 857)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
result = w857 / w658
return result
def Turb_Frohn09GreenPlusRedBothOverBlue(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: (Green + Red) / Blue
参考文献: Frohn, R. C., & Autrey, B. C. (2009). Water quality assessment in
the Ohio River using new indices for turbidity and chlorophyll-a with
Landsat-7 Imagery. Draft Internal Report, US Environmental Protection Agency.
"""
w558 = df[self.find_closest_wavelength(df.columns, 558)]
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
result = (w558 + w658) / w458
return result
def Turb_Harr92NIR(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: NIR reflectance
参考文献: Schiebe F.R., Harrington J.A., Ritchie J.C. Remote-Sensing of
Suspended Sediments—the Lake Chicot, Arkansas Project. Int. J. Remote Sens. 1992;13:14871509
"""
w857 = df[self.find_closest_wavelength(df.columns, 857)]
result = w857
return result
def Turb_Lath91RedOverBlue(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: Red / Blue
参考文献: Lathrop, R. G., Jr., T. M. Lillesand, and B. S. Yandell, 1991.
Testing the utility of simple multi-date Thematic Mapper calibration
algorithms for monitoring turbid inland waters. International Journal of Remote Sensing
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
w458 = df[self.find_closest_wavelength(df.columns, 458)]
result = w658 / w458
return result
def Turb_Moore80Red(self, df: pd.DataFrame) -> pd.Series:
"""
Turbidity algorithm: Red reflectance
参考文献: Moore, G.K., Satellite remote sensing of water turbidity,
Hydrological Sciences, 1980, 25, 4, 407-422
"""
w658 = df[self.find_closest_wavelength(df.columns, 658)]
result = w658
return result
def calculate_all_indices(
self,
input_file: str,
output_file: str = None,
selected_indices: Optional[List[str]] = None
) -> pd.DataFrame:
"""
计算所有水质指数
Args:
input_file: 输入CSV文件路径
output_file: 输出CSV文件路径可选
selected_indices: 可选的算法列表,仅计算指定的指数
Returns:
包含计算结果的数据框
"""
# 读取数据
df = pd.read_csv(input_file)
print(f"成功读取数据,形状: {df.shape}")
print(f"数据列: {list(df.columns)}")
results = df.copy()
# 获取所有算法方法
algorithm_methods = [
method for method in dir(self)
if not method.startswith('_') and method not in ['find_closest_wavelength', 'calculate_all_indices']
]
print(f"\n找到 {len(algorithm_methods)} 个算法")
if selected_indices is not None:
filtered = []
missing = []
for name in selected_indices:
if name in algorithm_methods:
filtered.append(name)
else:
missing.append(name)
if missing:
print(f"警告: 以下算法未找到,将被忽略: {', '.join(missing)}")
algorithm_methods = filtered
if not algorithm_methods:
raise ValueError("未找到可用算法,请检查 selected_indices 参数")
# 按算法类型分类计算
algorithm_categories = {
'chlorophyll': [],
'BGA/PC': [],
'turbidity': []
}
for method_name in algorithm_methods:
if method_name.startswith('Turb'):
algorithm_categories['turbidity'].append(method_name)
elif any(keyword in method_name for keyword in ['BDA', 'FLH', 'ND', 'sub', 'CI', 'SLH', 'MPI', 'Phy']):
if method_name not in algorithm_categories['turbidity']:
algorithm_categories['BGA/PC'].append(method_name)
else:
algorithm_categories['chlorophyll'].append(method_name)
# 计算每个类别的算法
for category, algorithms in algorithm_categories.items():
print(f"\n=== 计算 {category} 相关指数 ({len(algorithms)}个算法) ===")
for algo_name in algorithms:
try:
print(f"计算: {algo_name}")
method = getattr(self, algo_name)
results[algo_name] = method(df)
print(f"✓ 成功计算 {algo_name}")
except Exception as e:
print(f"✗ 计算 {algo_name} 时出错: {str(e)}")
results[algo_name] = np.nan
# 保存结果
if output_file:
results.to_csv(output_file, index=False)
print(f"\n结果已保存到: {output_file}")
return results
def main():
"""主函数"""
calculator = WaterQualityIndexCalculator()
print("=" * 80)
print("水质光谱指数计算器")
print("=" * 80)
# 计算指数
input_file = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv" # 修改为您的输入文件路径
output_file = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\water_quality_results.csv"
try:
# 设置为 None 时默认计算所有已实现的算法;也可以设置为算法名称列表,例如 ['Al10SABI', 'TurbBe16RedOverViolet']
selected_algorithms = None
results = calculator.calculate_all_indices(input_file, output_file, selected_algorithms)
# 显示结果统计
print("\n" + "=" * 80)
print("计算结果统计:")
print("=" * 80)
# 只显示计算出的指数列的统计信息
original_columns = pd.read_csv(input_file).columns
calculated_columns = [col for col in results.columns if col not in original_columns]
if calculated_columns:
stats = results[calculated_columns].describe()
print(stats)
# 按类别显示统计
categories = {
'叶绿素算法': [col for col in calculated_columns if not col.startswith('Turb') and not any(x in col for x in ['BDA', 'FLH', 'ND', 'sub', 'CI', 'SLH', 'MPI', 'Phy']) or col in ['Al10SABI', 'Am092Bsub', 'Be16FLHblue', 'Be16FLHviolet', 'Be16NDTIblue', 'Be16NDTIviolet', 'De933BDA', 'Gi033BDA', 'Kn07KIVU', 'MM12NDCI', 'Zh10FLH']],
'蓝藻/藻蓝蛋白算法': [col for col in calculated_columns if col not in ['Al10SABI', 'Am092Bsub', 'Be16FLHblue', 'Be16FLHviolet', 'Be16NDTIblue', 'Be16NDTIviolet', 'De933BDA', 'Gi033BDA', 'Kn07KIVU', 'MM12NDCI', 'Zh10FLH'] and not col.startswith('Turb')],
'浊度算法': [col for col in calculated_columns if col.startswith('Turb')]
}
for category, algo_list in categories.items():
if algo_list:
print(f"\n{category}统计:")
print(results[algo_list].describe())
else:
print("没有成功计算任何指数")
except FileNotFoundError:
print(f"\n错误: 找不到输入文件 {input_file}")
print("请确保文件存在,或修改 input_file 变量为正确的文件路径")
except Exception as e:
print(f"\n计算过程中发生错误: {str(e)}")
if __name__ == "__main__":
main()