Files
WQ_GUI/src/utils/kriging.py
2026-04-08 15:25:08 +08:00

458 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
克里金插值模块
提供基于PyKrige的普通克里金插值功能用于将离散的水质参数预测点
插值为连续的栅格图像。
主要功能:
1. 普通克里金插值
2. 多种变差函数模型支持
3. 自动参数优化
4. 栅格输出功能
"""
import numpy as np
from osgeo import gdal
import time
import os
import glob
from pathlib import Path
from typing import Optional, Tuple, Union, List
import warnings
warnings.filterwarnings('ignore')
# 导入util模块的timeit装饰器
try:
from src.utils.util import timeit
except ImportError:
# 如果导入失败定义一个简单的timeit装饰器
def timeit(f):
def wrapper(*args, **kwargs):
start = time.time()
ret = f(*args, **kwargs)
print(f"{f.__name__} run time: {round(time.time() - start, 2)} s.")
return ret
return wrapper
class KrigingInterpolator:
"""克里金插值器类"""
def __init__(self, variogram_models: Optional[List[str]] = None):
"""
初始化克里金插值器
Args:
variogram_models: 变差函数模型列表,默认为['spherical', 'exponential', 'gaussian', 'linear']
"""
self.variogram_models = variogram_models or ['spherical', 'exponential', 'gaussian', 'linear']
self.last_used_model = None
def validate_input_data(self, x: np.ndarray, y: np.ndarray, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, bool]:
"""
验证和预处理输入数据
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
Returns:
处理后的x, y, z数组和是否有效的标志
"""
# 确保输入为numpy数组
x = np.asarray(x)
y = np.asarray(y)
z = np.asarray(z)
# 检查数组长度一致性
if not (len(x) == len(y) == len(z)):
raise ValueError(f"输入数组长度不一致: x={len(x)}, y={len(y)}, z={len(z)}")
# 移除NaN值
mask = ~(np.isnan(x) | np.isnan(y) | np.isnan(z))
x = x[mask]
y = y[mask]
z = z[mask]
# 检查数据点数量
if len(x) < 3:
print(f"警告:有效数据点不足({len(x)}至少需要3个点进行Kriging插值")
return x, y, z, False
# 检查是否所有点重合
if np.all(x == x[0]) and np.all(y == y[0]):
print(f"警告:所有数据点位置相同({x[0]}, {y[0]}),无法进行空间插值")
return x, y, z, False
# 检查z值是否有变化
if np.all(z == z[0]):
print(f"警告:所有观测值相同({z[0]}),插值结果将为常数")
return x, y, z, True
def create_interpolation_grid(self, x: np.ndarray, y: np.ndarray, spatial_resolution: float) -> Tuple[np.ndarray, np.ndarray, int, int]:
"""
创建插值网格
Args:
x: X坐标数组
y: Y坐标数组
spatial_resolution: 空间分辨率
Returns:
网格x, 网格y, x方向步数, y方向步数
"""
# 计算空间范围,添加小的缓冲区
x_min, x_max = x.min(), x.max()
y_min, y_max = y.min(), y.max()
# 添加缓冲区以确保所有点都在网格内
buffer = spatial_resolution * 0.5
x_min -= buffer
x_max += buffer
y_min -= buffer
y_max += buffer
# 计算网格步数
step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1
# 限制网格大小以避免内存问题
max_grid_size = 10000
if step_x > max_grid_size or step_y > max_grid_size:
print(f"警告:网格尺寸过大 ({step_x}x{step_y}),将调整空间分辨率")
# 重新计算合适的分辨率
new_resolution_x = (x_max - x_min) / max_grid_size
new_resolution_y = (y_max - y_min) / max_grid_size
spatial_resolution = max(new_resolution_x, new_resolution_y, spatial_resolution)
step_x = int(np.ceil((x_max - x_min) / spatial_resolution)) + 1
step_y = int(np.ceil((y_max - y_min) / spatial_resolution)) + 1
print(f"调整后的空间分辨率: {spatial_resolution:.2f}, 网格尺寸: {step_x}x{step_y}")
# 创建网格
grid_x = np.linspace(x_min, x_max, step_x)
grid_y = np.linspace(y_min, y_max, step_y)
return grid_x, grid_y, step_x, step_y
@timeit
def interpolate(self, x: np.ndarray, y: np.ndarray, z: np.ndarray,
spatial_resolution: float = 1.0,
output_path: Optional[str] = None,
proj: Optional[str] = None) -> Optional[np.ndarray]:
"""
执行克里金插值
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
spatial_resolution: 空间分辨率
output_path: 输出文件路径
proj: 投影信息
Returns:
插值结果数组失败时返回None
"""
try:
from pykrige.ok import OrdinaryKriging
except ImportError:
print("错误未安装pykrige库请运行 'pip install pykrige'")
return None
# 验证输入数据
x, y, z, is_valid = self.validate_input_data(x, y, z)
if not is_valid:
return None
print(f"开始克里金插值,数据点数: {len(x)}")
# 创建插值网格
grid_x, grid_y, step_x, step_y = self.create_interpolation_grid(x, y, spatial_resolution)
print(f"插值网格尺寸: {step_x} x {step_y}")
print(f"空间范围: X=[{grid_x[0]:.2f}, {grid_x[-1]:.2f}], Y=[{grid_y[0]:.2f}, {grid_y[-1]:.2f}]")
# 尝试不同的变差函数模型
z_interpolated = None
successful_model = None
for model in self.variogram_models:
try:
print(f"尝试使用 {model} 变差函数模型...")
# 动态设置参数
nlags = min(20, max(6, len(x) // 3))
n_closest_points = min(12, max(4, len(x) // 2))
OK = OrdinaryKriging(
x, y, z,
variogram_model=model,
verbose=False,
enable_plotting=False,
coordinates_type="euclidean",
nlags=nlags
)
start_time = time.perf_counter()
z_interpolated, ss = OK.execute(
"grid", grid_x, grid_y,
backend="loop",
n_closest_points=n_closest_points
)
end_time = time.perf_counter()
successful_model = model
self.last_used_model = model
print(f"使用 {model} 模型插值成功,耗时: {end_time - start_time:.2f}")
break
except Exception as e:
print(f"模型 {model} 失败: {str(e)}")
continue
if z_interpolated is None:
print("错误:所有变差函数模型均失败,无法完成插值")
return None
# 检查插值结果
if np.all(np.isnan(z_interpolated)):
print("警告插值结果全为NaN值")
return None
nan_count = np.sum(np.isnan(z_interpolated))
total_count = z_interpolated.size
nan_percentage = (nan_count / total_count) * 100
print(f"插值完成,使用模型: {successful_model}")
print(f"结果统计: 总像元数={total_count}, NaN像元数={nan_count} ({nan_percentage:.1f}%)")
print(f"数值范围: [{np.nanmin(z_interpolated):.3f}, {np.nanmax(z_interpolated):.3f}]")
# 保存结果
if output_path and proj:
success = self.save_raster(z_interpolated, grid_x, grid_y, spatial_resolution, proj, output_path)
if success:
print(f"结果已保存至: {output_path}")
else:
print(f"保存失败: {output_path}")
return z_interpolated
def save_raster(self, data: np.ndarray, grid_x: np.ndarray, grid_y: np.ndarray,
spatial_resolution: float, proj: str, output_path: str) -> bool:
"""
保存插值结果为栅格文件
Args:
data: 插值结果数组
grid_x: X方向网格
grid_y: Y方向网格
spatial_resolution: 空间分辨率
proj: 投影信息
output_path: 输出路径
Returns:
是否保存成功
"""
try:
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 创建GDAL数据集
driver = gdal.GetDriverByName("GTiff")
step_x, step_y = data.shape[1], data.shape[0]
dataset = driver.Create(
output_path, step_x, step_y, 1, gdal.GDT_Float64,
options=["COMPRESS=LZW", "TILED=YES"]
)
if dataset is None:
print(f"错误:无法创建输出文件 {output_path}")
return False
# 设置地理变换参数
x_min, y_max = grid_x[0], grid_y[-1]
geotransform = (x_min, spatial_resolution, 0, y_max, 0, -spatial_resolution)
dataset.SetGeoTransform(geotransform)
# 设置投影
dataset.SetProjection(proj)
# 写入数据
band = dataset.GetRasterBand(1)
band.WriteArray(data)
band.SetNoDataValue(np.nan)
# 计算统计信息
band.ComputeStatistics(0)
# 清理资源
band.FlushCache()
dataset.FlushCache()
del dataset
return True
except Exception as e:
print(f"保存栅格文件时出错: {str(e)}")
return False
# 保持向后兼容性的函数接口
@timeit
def interpolate_kriging_pykrige(x, y, z, proj, spatial_resolution, output_path=None):
"""
执行克里金插值(向后兼容接口)
Args:
x: X坐标数组
y: Y坐标数组
z: 观测值数组
proj: 投影信息
spatial_resolution: 空间分辨率
output_path: 输出路径
Returns:
插值结果数组
"""
interpolator = KrigingInterpolator()
return interpolator.interpolate(x, y, z, spatial_resolution, output_path, proj)
def batch_kriging_interpolation(input_folder: str, ref_img_path: str,
output_folder: str, spatial_resolution: float = 1.0,
file_pattern: str = "*.csv") -> None:
"""
批量克里金插值处理
Args:
input_folder: 输入CSV文件夹路径
ref_img_path: 参考影像路径(用于获取投影信息)
output_folder: 输出文件夹路径
spatial_resolution: 空间分辨率
file_pattern: 文件匹配模式
"""
# 验证输入路径
if not os.path.exists(input_folder):
raise FileNotFoundError(f"输入文件夹不存在: {input_folder}")
if not os.path.exists(ref_img_path):
raise FileNotFoundError(f"参考影像不存在: {ref_img_path}")
# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)
# 获取参考影像的投影信息
try:
dataset = gdal.Open(ref_img_path)
if dataset is None:
raise ValueError(f"无法打开参考影像: {ref_img_path}")
im_proj = dataset.GetProjection()
del dataset
except Exception as e:
print(f"获取投影信息失败: {str(e)}")
return
# 查找CSV文件
csv_files = glob.glob(os.path.join(input_folder, file_pattern))
if not csv_files:
print(f"{input_folder} 中未找到匹配 {file_pattern} 的文件")
return
print(f"找到 {len(csv_files)} 个CSV文件待处理")
# 创建插值器
interpolator = KrigingInterpolator()
successful_count = 0
failed_count = 0
for i, csv_path in enumerate(csv_files, 1):
filename = os.path.basename(csv_path)
print(f"\n[{i}/{len(csv_files)}] 处理文件: {filename}")
try:
# 读取CSV文件
# 支持多种分隔符
try:
pos_content = np.loadtxt(csv_path, delimiter='\t')
except ValueError:
try:
pos_content = np.loadtxt(csv_path, delimiter=',')
except ValueError:
pos_content = np.loadtxt(csv_path, delimiter=';')
if pos_content.shape[1] < 3:
print(f"跳过文件 {filename}列数不足需要至少3列x, y, z")
failed_count += 1
continue
# 数据统计
total_points = len(pos_content)
nan_points = np.sum(np.isnan(pos_content[:, 2]))
print(f"数据点统计: 总计{total_points}个, NaN值{nan_points}")
# 构建输出路径
base_name = os.path.splitext(filename)[0]
output_filename = f"{base_name}_kriging.tif"
output_path = os.path.join(output_folder, output_filename)
# 执行插值
result = interpolator.interpolate(
pos_content[:, 0], # x
pos_content[:, 1], # y
pos_content[:, 2], # z
spatial_resolution,
output_path,
im_proj
)
if result is not None:
print(f"✓ 处理成功: {output_filename}")
successful_count += 1
else:
print(f"✗ 处理失败: {filename}")
failed_count += 1
except Exception as e:
print(f"✗ 处理文件 {filename} 时出错: {str(e)}")
failed_count += 1
# 输出总结
print(f"\n{'='*60}")
print(f"批量处理完成")
print(f"成功: {successful_count} 个文件")
print(f"失败: {failed_count} 个文件")
print(f"输出目录: {output_folder}")
if __name__ == '__main__':
# 示例用法
print("克里金插值模块示例")
# 配置参数(根据实际情况修改)
input_folder = r"data/processed/predictions" # CSV文件夹路径
ref_img_path = r"data/raw/reference_image.tif" # 参考影像路径
output_folder = r"data/processed/kriging_results" # 输出文件夹路径
spatial_resolution = 1.0 # 空间分辨率(米)
try:
# 执行批量插值
batch_kriging_interpolation(
input_folder=input_folder,
ref_img_path=ref_img_path,
output_folder=output_folder,
spatial_resolution=spatial_resolution
)
except Exception as e:
print(f"批量处理失败: {str(e)}")
print("\n请检查以下事项:")
print("1. 输入文件夹和参考影像路径是否正确")
print("2. CSV文件格式是否正确至少包含x, y, z三列")
print("3. 是否安装了必要的依赖库pykrige, gdal等")