Initial commit of WQ_GUI
This commit is contained in:
458
src/utils/kriging.py
Normal file
458
src/utils/kriging.py
Normal file
@ -0,0 +1,458 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
克里金插值模块
|
||||
|
||||
提供基于PyKrige的普通克里金插值功能,用于将离散的水质参数预测点
|
||||
插值为连续的栅格图像。
|
||||
|
||||
主要功能:
|
||||
1. 普通克里金插值
|
||||
2. 多种变差函数模型支持
|
||||
3. 自动参数优化
|
||||
4. 栅格输出功能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
import time
|
||||
import os
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, List
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Import the timeit decorator from the project's util module.
try:
    from src.utils.util import timeit
except ImportError:
    # The project helper is unavailable (e.g. the module is run stand-alone),
    # so fall back to a minimal local implementation.
    from functools import wraps

    def timeit(f):
        """Decorate ``f`` so that every call prints its wall-clock run time.

        Args:
            f: The callable to time.

        Returns:
            A wrapper that forwards all arguments to ``f``, prints the elapsed
            time in seconds (rounded to two decimals) and returns ``f``'s result.
        """
        @wraps(f)  # keep the wrapped function's __name__/__doc__ intact
        def wrapper(*args, **kwargs):
            start = time.time()
            ret = f(*args, **kwargs)
            print(f"{f.__name__} run time: {round(time.time() - start, 2)} s.")
            return ret

        return wrapper
|
||||
|
||||
|
||||
class KrigingInterpolator:
    """Ordinary-kriging interpolator that turns scattered points into a raster.

    The interpolator validates the input samples, builds a regular grid that
    covers them, tries each configured variogram model in order until one
    succeeds, and can optionally write the result to a GeoTIFF.
    """

    def __init__(self, variogram_models: Optional[List[str]] = None):
        """Create an interpolator.

        Args:
            variogram_models: Variogram model names to try, in order. Defaults
                to ``['spherical', 'exponential', 'gaussian', 'linear']``.
        """
        self.variogram_models = variogram_models or ['spherical', 'exponential', 'gaussian', 'linear']
        # Name of the variogram model used by the most recent successful run.
        self.last_used_model = None

    def validate_input_data(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:
        """Validate and clean the input samples.

        The inputs are converted to flat float arrays and every sample that
        has a non-finite (NaN or infinite) coordinate or value is dropped.

        Args:
            x: X coordinates.
            y: Y coordinates.
            z: Observed values.

        Returns:
            The cleaned ``x``, ``y`` and ``z`` arrays followed by a flag that
            is ``False`` when fewer than three usable samples remain or when
            all samples share one location.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        # Float conversion lets integer/list inputs go through the NaN checks.
        x = np.asarray(x, dtype=float).ravel()
        y = np.asarray(y, dtype=float).ravel()
        z = np.asarray(z, dtype=float).ravel()

        if not (len(x) == len(y) == len(z)):
            raise ValueError(f"输入数组长度不一致: x={len(x)}, y={len(y)}, z={len(z)}")

        # Drop samples with NaN or infinite entries; infinities cannot be
        # kriged any more than NaNs can.
        usable = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
        x, y, z = x[usable], y[usable], z[usable]

        if len(x) < 3:
            print(f"警告:有效数据点不足({len(x)}个),至少需要3个点进行Kriging插值")
            return x, y, z, False

        # All samples at one location: there is no spatial structure to model.
        if np.all(x == x[0]) and np.all(y == y[0]):
            print(f"警告:所有数据点位置相同({x[0]}, {y[0]}),无法进行空间插值")
            return x, y, z, False

        # Constant observations are still valid input; only warn about them.
        if np.all(z == z[0]):
            print(f"警告:所有观测值相同({z[0]}),插值结果将为常数")

        return x, y, z, True

    def create_interpolation_grid(self, x: np.ndarray, y: np.ndarray, spatial_resolution: float,
                                  max_grid_size: int = 10000) -> Tuple[np.ndarray, np.ndarray, int, int]:
        """Build the regular grid on which values are estimated.

        The sample extent is padded by half a cell on every side. Grid nodes
        are produced with ``np.linspace``, so the actual node spacing can be
        slightly smaller than ``spatial_resolution``.

        Args:
            x: X coordinates of the samples.
            y: Y coordinates of the samples.
            spatial_resolution: Requested cell size, in coordinate units.
            max_grid_size: Node count per axis above which the resolution is
                coarsened to limit memory use.

        Returns:
            ``(grid_x, grid_y, step_x, step_y)``: the node coordinates along
            each axis and the number of nodes along each axis.

        Raises:
            ValueError: If ``spatial_resolution`` is not a positive finite
                number.
        """
        if not np.isfinite(spatial_resolution) or spatial_resolution <= 0:
            raise ValueError(f"空间分辨率必须为正的有限值: {spatial_resolution}")

        # Pad the extent by half a cell so every sample lies inside the grid.
        half_cell = spatial_resolution * 0.5
        x_min, x_max = x.min() - half_cell, x.max() + half_cell
        y_min, y_max = y.min() - half_cell, y.max() + half_cell

        step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
        step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1

        # Cap the grid size to avoid exhausting memory.
        if step_x > max_grid_size or step_y > max_grid_size:
            print(f"警告:网格尺寸过大 ({step_x}x{step_y}),将调整空间分辨率")
            # Coarsen the resolution just enough for the larger axis to fit.
            new_resolution_x = (x_max - x_min) / max_grid_size
            new_resolution_y = (y_max - y_min) / max_grid_size
            spatial_resolution = max(new_resolution_x, new_resolution_y, spatial_resolution)

            step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
            step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1
            print(f"调整后的空间分辨率: {spatial_resolution:.2f}, 网格尺寸: {step_x}x{step_y}")

        grid_x = np.linspace(x_min, x_max, step_x)
        grid_y = np.linspace(y_min, y_max, step_y)

        return grid_x, grid_y, step_x, step_y

    @timeit
    def interpolate(self, x: np.ndarray, y: np.ndarray, z: np.ndarray,
                    spatial_resolution: float = 1.0,
                    output_path: Optional[str] = None,
                    proj: Optional[str] = None) -> Optional[np.ndarray]:
        """Run ordinary kriging on the samples.

        Each model in ``self.variogram_models`` is tried in order and the
        first one that completes is used; its name is stored in
        ``self.last_used_model``. The raster is written only when both
        ``output_path`` and ``proj`` are truthy.

        Args:
            x: X coordinates.
            y: Y coordinates.
            z: Observed values.
            spatial_resolution: Requested cell size.
            output_path: Output raster path.
            proj: Projection string passed to the raster writer.

        Returns:
            The interpolated array, with rows following ``grid_y`` in
            ascending order, or ``None`` on failure (PyKrige missing, invalid
            input, every model failing, or an all-NaN result).
        """
        try:
            from pykrige.ok import OrdinaryKriging
        except ImportError:
            print("错误:未安装pykrige库,请运行 'pip install pykrige'")
            return None

        x, y, z, is_valid = self.validate_input_data(x, y, z)
        if not is_valid:
            return None

        n_points = len(x)
        print(f"开始克里金插值,数据点数: {n_points}")

        grid_x, grid_y, step_x, step_y = self.create_interpolation_grid(x, y, spatial_resolution)

        print(f"插值网格尺寸: {step_x} x {step_y}")
        print(f"空间范围: X=[{grid_x[0]:.2f}, {grid_x[-1]:.2f}], Y=[{grid_y[0]:.2f}, {grid_y[-1]:.2f}]")

        # Parameters derived from the sample count; identical for every model.
        nlags = min(20, max(6, n_points // 3))
        # Never request more neighbours than there are samples: validation
        # accepts as few as three points, fewer than the floor of 4.
        n_closest_points = min(n_points, 12, max(4, n_points // 2))

        z_interpolated = None
        successful_model = None

        for model in self.variogram_models:
            try:
                print(f"尝试使用 {model} 变差函数模型...")

                OK = OrdinaryKriging(
                    x, y, z,
                    variogram_model=model,
                    verbose=False,
                    enable_plotting=False,
                    coordinates_type="euclidean",
                    nlags=nlags
                )

                start_time = time.perf_counter()
                z_interpolated, _variance = OK.execute(
                    "grid", grid_x, grid_y,
                    backend="loop",
                    n_closest_points=n_closest_points
                )
                end_time = time.perf_counter()

                successful_model = model
                self.last_used_model = model
                print(f"使用 {model} 模型插值成功,耗时: {end_time - start_time:.2f}秒")
                break

            except Exception as e:
                # Best effort: report the failure and fall through to the next model.
                print(f"模型 {model} 失败: {str(e)}")
                continue

        if z_interpolated is None:
            print("错误:所有变差函数模型均失败,无法完成插值")
            return None

        if np.all(np.isnan(z_interpolated)):
            print("警告:插值结果全为NaN值")
            return None

        nan_count = np.sum(np.isnan(z_interpolated))
        total_count = z_interpolated.size
        nan_percentage = (nan_count / total_count) * 100

        print(f"插值完成,使用模型: {successful_model}")
        print(f"结果统计: 总像元数={total_count}, NaN像元数={nan_count} ({nan_percentage:.1f}%)")
        print(f"数值范围: [{np.nanmin(z_interpolated):.3f}, {np.nanmax(z_interpolated):.3f}]")

        if output_path and proj:
            success = self.save_raster(z_interpolated, grid_x, grid_y, spatial_resolution, proj, output_path)
            if success:
                print(f"结果已保存至: {output_path}")
            else:
                print(f"保存失败: {output_path}")

        return z_interpolated

    @staticmethod
    def _node_spacing(nodes: np.ndarray, fallback: float) -> float:
        """Return the spacing between consecutive grid nodes, or ``fallback`` for a degenerate axis."""
        if len(nodes) > 1:
            spacing = abs(float(nodes[-1]) - float(nodes[0])) / (len(nodes) - 1)
            if spacing > 0:
                return spacing
        return float(fallback)

    def save_raster(self, data: np.ndarray, grid_x: np.ndarray, grid_y: np.ndarray,
                    spatial_resolution: float, proj: str, output_path: str) -> bool:
        """Write an interpolated grid to a single-band Float64 GeoTIFF.

        ``data[i, j]`` is taken to be the value at ``(grid_x[j], grid_y[i])``,
        with the grid nodes treated as pixel centres. The raster is written
        north-up, so rows are reversed when ``grid_y`` is ascending.

        Args:
            data: 2-D array of interpolated values.
            grid_x: Node coordinates along X.
            grid_y: Node coordinates along Y.
            spatial_resolution: Cell size used only when the node spacing
                cannot be derived from the grid (a single-node axis).
            proj: Projection string for the dataset.
            output_path: Destination file path.

        Returns:
            ``True`` when the file was written, ``False`` otherwise.
        """
        try:
            # A bare file name has no directory part; makedirs('') would raise.
            out_dir = os.path.dirname(output_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            driver = gdal.GetDriverByName("GTiff")
            n_cols, n_rows = data.shape[1], data.shape[0]

            dataset = driver.Create(
                output_path, n_cols, n_rows, 1, gdal.GDT_Float64,
                options=["COMPRESS=LZW", "TILED=YES"]
            )

            if dataset is None:
                print(f"错误:无法创建输出文件 {output_path}")
                return False

            # Use the real node spacing: linspace-built grids (and grids whose
            # resolution was coarsened) do not match the requested resolution.
            pixel_w = self._node_spacing(grid_x, spatial_resolution)
            pixel_h = self._node_spacing(grid_y, spatial_resolution)

            # Orient the array so that row 0 is the northernmost row and
            # column 0 the westernmost column, as the geotransform expects.
            raster = np.asarray(data, dtype=np.float64)
            if len(grid_y) > 1 and grid_y[0] < grid_y[-1]:
                raster = raster[::-1, :]
            if len(grid_x) > 1 and grid_x[0] > grid_x[-1]:
                raster = raster[:, ::-1]
            raster = raster.copy()  # contiguous buffer for the writer

            # The geotransform origin is the outer corner of the top-left
            # pixel, i.e. half a pixel beyond the top-left node centre.
            x_left = float(min(grid_x[0], grid_x[-1])) - pixel_w / 2.0
            y_top = float(max(grid_y[0], grid_y[-1])) + pixel_h / 2.0
            dataset.SetGeoTransform((x_left, pixel_w, 0, y_top, 0, -pixel_h))

            dataset.SetProjection(proj)

            band = dataset.GetRasterBand(1)
            band.SetNoDataValue(np.nan)
            band.WriteArray(raster)

            band.ComputeStatistics(0)

            # Flush and release the handles so the file is fully written.
            band.FlushCache()
            dataset.FlushCache()
            band = None
            dataset = None

            return True

        except Exception as e:
            print(f"保存栅格文件时出错: {str(e)}")
            return False
|
||||
|
||||
|
||||
# Function-style entry point kept for backward compatibility.
@timeit
def interpolate_kriging_pykrige(x, y, z, proj, spatial_resolution, output_path=None):
    """Run kriging interpolation through the legacy function interface.

    Builds a default ``KrigingInterpolator`` and delegates to its
    ``interpolate`` method.

    Args:
        x: X coordinates.
        y: Y coordinates.
        z: Observed values.
        proj: Projection information.
        spatial_resolution: Spatial resolution.
        output_path: Output path.

    Returns:
        Whatever ``KrigingInterpolator.interpolate`` returns for these inputs.
    """
    return KrigingInterpolator().interpolate(
        x,
        y,
        z,
        spatial_resolution=spatial_resolution,
        output_path=output_path,
        proj=proj,
    )
|
||||
|
||||
|
||||
def _load_xyz_table(csv_path: str) -> np.ndarray:
    """Load a numeric table as a 2-D array, trying tab, comma and semicolon delimiters in turn."""
    last_error = None
    for delimiter in ('\t', ',', ';'):
        try:
            # ndmin=2 keeps a single-row file two-dimensional so that column
            # indexing works for it as well.
            return np.loadtxt(csv_path, delimiter=delimiter, ndmin=2)
        except ValueError as err:
            last_error = err
    raise last_error


def batch_kriging_interpolation(input_folder: str, ref_img_path: str,
                                output_folder: str, spatial_resolution: float = 1.0,
                                file_pattern: str = "*.csv") -> None:
    """Krige every matching table in a folder and write one raster per file.

    Each input file must be a purely numeric table whose first three columns
    are x, y and z, delimited by tabs, commas or semicolons. Results are
    written to ``<output_folder>/<name>_kriging.tif``. Files are processed in
    sorted order; a failure on one file is reported and does not stop the
    batch.

    Args:
        input_folder: Folder containing the input CSV files.
        ref_img_path: Reference image from which the projection is read.
        output_folder: Folder that receives the output rasters (created if
            missing).
        spatial_resolution: Spatial resolution passed to the interpolator.
        file_pattern: Glob pattern used to select input files.

    Raises:
        FileNotFoundError: If ``input_folder`` or ``ref_img_path`` does not
            exist.
    """
    if not os.path.exists(input_folder):
        raise FileNotFoundError(f"输入文件夹不存在: {input_folder}")

    if not os.path.exists(ref_img_path):
        raise FileNotFoundError(f"参考影像不存在: {ref_img_path}")

    os.makedirs(output_folder, exist_ok=True)

    # Read the projection from the reference image.
    try:
        dataset = gdal.Open(ref_img_path)
        if dataset is None:
            raise ValueError(f"无法打开参考影像: {ref_img_path}")
        im_proj = dataset.GetProjection()
        dataset = None  # release the GDAL handle
    except Exception as e:
        print(f"获取投影信息失败: {str(e)}")
        return

    # Sorted so the processing order is the same on every run.
    csv_files = sorted(glob.glob(os.path.join(input_folder, file_pattern)))
    if not csv_files:
        print(f"在 {input_folder} 中未找到匹配 {file_pattern} 的文件")
        return

    print(f"找到 {len(csv_files)} 个CSV文件待处理")

    interpolator = KrigingInterpolator()

    successful_count = 0
    failed_count = 0

    for i, csv_path in enumerate(csv_files, 1):
        filename = os.path.basename(csv_path)
        print(f"\n[{i}/{len(csv_files)}] 处理文件: {filename}")

        try:
            pos_content = _load_xyz_table(csv_path)

            if pos_content.shape[1] < 3:
                print(f"跳过文件 {filename}:列数不足(需要至少3列:x, y, z)")
                failed_count += 1
                continue

            total_points = len(pos_content)
            nan_points = np.sum(np.isnan(pos_content[:, 2]))
            print(f"数据点统计: 总计{total_points}个, NaN值{nan_points}个")

            base_name = os.path.splitext(filename)[0]
            output_filename = f"{base_name}_kriging.tif"
            output_path = os.path.join(output_folder, output_filename)

            result = interpolator.interpolate(
                pos_content[:, 0],  # x
                pos_content[:, 1],  # y
                pos_content[:, 2],  # z
                spatial_resolution,
                output_path,
                im_proj
            )

            if result is not None:
                print(f"✓ 处理成功: {output_filename}")
                successful_count += 1
            else:
                print(f"✗ 处理失败: {filename}")
                failed_count += 1

        except Exception as e:
            # Keep the batch going: report this file and move on to the next.
            print(f"✗ 处理文件 {filename} 时出错: {str(e)}")
            failed_count += 1

    # Summary
    print(f"\n{'='*60}")
    print("批量处理完成")
    print(f"成功: {successful_count} 个文件")
    print(f"失败: {failed_count} 个文件")
    print(f"输出目录: {output_folder}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Example usage of the batch interface.
    print("克里金插值模块示例")

    # Configuration -- adjust to the actual data layout.
    input_folder = r"data/processed/predictions"       # folder with the CSV files
    ref_img_path = r"data/raw/reference_image.tif"     # reference image
    output_folder = r"data/processed/kriging_results"  # folder for the results
    spatial_resolution = 1.0                           # spatial resolution (metres)

    try:
        # Run the batch interpolation.
        batch_kriging_interpolation(
            input_folder,
            ref_img_path,
            output_folder,
            spatial_resolution,
        )
    except Exception as e:
        # Top-level boundary: report the error and print a troubleshooting checklist.
        troubleshooting = (
            f"批量处理失败: {str(e)}",
            "\n请检查以下事项:",
            "1. 输入文件夹和参考影像路径是否正确",
            "2. CSV文件格式是否正确(至少包含x, y, z三列)",
            "3. 是否安装了必要的依赖库(pykrige, gdal等)",
        )
        for message in troubleshooting:
            print(message)
|
||||
Reference in New Issue
Block a user