# File listing metadata: 884 lines, 32 KiB, Python.
import argparse
import os
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import pandas as pd
import spectral
from scipy import linalg, stats
from sklearn.covariance import EmpiricalCovariance, MinCovDet
from sklearn.preprocessing import StandardScaler
# Install an "ignore everything" filter in the global warnings filter list as
# soon as this module is imported.
# NOTE(review): this also hides numerical warnings (e.g. from covariance
# estimation); consider scoping it with warnings.catch_warnings() instead.
warnings.simplefilter('ignore')
@dataclass
class RXConfig:
    """Configuration for RX anomaly detection.

    Every field is validated in ``__post_init__`` so that an invalid
    configuration fails at construction time, before any data is read.
    """

    # Input files
    input_path: Optional[str] = None
    background_path: Optional[str] = None  # user-supplied background data (required by 'custom')
    label_col: Optional[str] = None  # CSV column excluded from the spectra
    spectral_start: Optional[str] = None  # first CSV spectral column (required for CSV input)
    spectral_end: Optional[str] = None  # last CSV spectral column, inclusive; None = last column
    csv_has_header: bool = True

    # Output
    output_path: Optional[str] = None
    output_dir: Optional[str] = None  # used to derive a file name when output_path is unset

    # Background model
    background_model: str = 'global'  # 'global', 'local_row', 'local_square', 'custom'
    window_size: int = 5  # square window side for 'local_square'
    row_window: int = 3  # number of rows in the window for 'local_row'

    # Detection
    threshold: Optional[float] = None  # fixed distance threshold; None disables thresholding
    contamination: float = 0.1  # expected anomaly fraction, used to derive a threshold automatically
    use_robust_covariance: bool = False  # use a robust (MinCovDet) covariance estimate

    # Preprocessing
    use_scaling: bool = False  # standardize bands first (RX usually works on raw values)

    def __post_init__(self):
        """Validate the configuration.

        Raises:
            ValueError: a required field is missing or a value is out of range.
            FileNotFoundError: the input or background file does not exist.
        """
        # Required input file
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
        if self.background_path and not os.path.exists(self.background_path):
            raise FileNotFoundError(f"背景文件不存在: {self.background_path}")

        # Somewhere to write the results
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")

        supported_models = ['global', 'local_row', 'local_square', 'custom']
        if self.background_model not in supported_models:
            raise ValueError(f"不支持的背景模型: {self.background_model}。支持的模型: {supported_models}")

        # Parameter ranges
        if self.window_size <= 0:
            raise ValueError("window_size必须大于0")
        if self.row_window <= 0:
            raise ValueError("row_window必须大于0")
        # contamination becomes the percentile (1 - contamination) * 100, which
        # must lie in [0, 100]; the negated chained comparison also rejects NaN.
        if not 0.0 <= self.contamination <= 1.0:
            raise ValueError(f"contamination必须在[0, 1]区间内: {self.contamination}")

        if self.background_model == 'custom' and not self.background_path:
            raise ValueError("使用custom背景模型时,必须指定background_path")
class RXAnomalyDetector:
    """
    Reed-Xiaoli (RX) anomaly detector.

    RX scores every pixel by its squared Mahalanobis distance to a background
    distribution; the larger the score, the more likely the pixel is anomalous.

    Mathematical basis:
        D(x) = (x - mu)^T K^-1 (x - mu)

    where:
        - x:  spectral vector of the pixel under test
        - mu: mean vector of the background data
        - K:  covariance matrix of the background data

    Typical pipeline: ``load_data()`` -> ``preprocess_data()`` ->
    ``detect_anomalies()`` -> ``save_results()``.
    """

    # Suffixes accepted as ENVI input: the header itself or a raw data file next to it.
    _ENVI_SUFFIXES = ('.hdr', '.bil', '.bsq', '.bip', '.dat', '.img', '.raw')

    # Axis order for each ENVI interleave, starting from (lines, samples, bands).
    _INTERLEAVE_AXES = {'bip': (0, 1, 2), 'bil': (0, 2, 1), 'bsq': (2, 0, 1)}

    def __init__(self, config: "RXConfig"):
        """
        Create a detector; no data is read until ``load_data`` is called.

        Args:
            config: configuration object (an ``RXConfig``).
        """
        self.config = config
        self.data = None  # (n_samples, n_bands) spectra
        self.background_data = None  # (n_background_samples, n_bands), 'custom' model only
        self.wavelengths = None
        self.data_shape = None  # (rows, cols, bands) for ENVI input, (n_samples, n_bands) for CSV
        self.input_format = None  # 'csv' or 'envi'

        # Background statistics (global / custom models, or the local-model fallback)
        self.background_mean = None
        self.background_cov = None
        self.background_inv_cov = None

        # Results
        self.distance_map = None  # flat array of squared Mahalanobis distances
        self.anomaly_map = None  # 0/1 int array, only when config.threshold is set

        # Preprocessing
        self.scaler = None

    # ------------------------------------------------------------------ loading

    def load_data(self) -> None:
        """
        Load the input data, plus the background data for the 'custom' model.

        CSV files and ENVI images (either the ``.hdr`` header or the raw data
        file) are supported.

        Raises:
            ValueError: unsupported suffix or unreadable file.
        """
        file_path = Path(self.config.input_path)
        suffix = file_path.suffix.lower()

        print(f"正在读取输入数据: {file_path}")

        if suffix == '.csv':
            self._load_csv_data(file_path, is_background=False)
        elif suffix in self._ENVI_SUFFIXES:
            self._load_envi_data(file_path, is_background=False)
        else:
            raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .csv, {', '.join(self._ENVI_SUFFIXES)}")

        if self.config.background_model == 'custom':
            self._load_background_data()

    @staticmethod
    def _column_index(columns: List[str], name: str) -> int:
        """Return the position of ``name`` in ``columns``, with a readable error if absent."""
        try:
            return columns.index(str(name))
        except ValueError:
            raise ValueError(f"CSV中找不到光谱列: {name}") from None

    def _load_csv_data(self, file_path: Path, is_background: bool = False) -> None:
        """
        Load spectra from a CSV file (one row per sample).

        After dropping ``config.label_col`` (when present), the columns from
        ``config.spectral_start`` through ``config.spectral_end`` inclusive
        (or through the last column when ``spectral_end`` is unset) are the bands.

        Args:
            file_path: CSV file to read.
            is_background: store into ``background_data`` instead of ``data``.

        Raises:
            ValueError: ``spectral_start`` unset, a named column is missing, or
                the selected column range is empty.
        """
        if self.config.spectral_start is None:
            raise ValueError("对于CSV文件,必须指定spectral_start参数")

        df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)

        # Compare labels as strings: a headerless CSV has integer column labels,
        # while the config (and the CLI) always supply strings.
        label = self.config.label_col
        spectral_cols = [col for col in df.columns if not (label and str(col) == str(label))]
        col_names = [str(col) for col in spectral_cols]

        start_idx = self._column_index(col_names, self.config.spectral_start)
        if self.config.spectral_end:
            end_idx = self._column_index(col_names, self.config.spectral_end) + 1
        else:
            end_idx = len(spectral_cols)
        selected_cols = spectral_cols[start_idx:end_idx]
        if not selected_cols:
            raise ValueError(f"光谱列范围为空: {self.config.spectral_start} - {self.config.spectral_end}")

        spectral_data = df[selected_cols].to_numpy(dtype=np.float32)
        # CSV input carries no wavelength metadata; bands are simply numbered.
        wavelengths = np.arange(len(selected_cols))

        if is_background:
            self.background_data = spectral_data
            print(f"背景数据加载完成: {self.background_data.shape} 样本 x {self.background_data.shape[1]} 波段")
        else:
            self.data = spectral_data
            self.data_shape = self.data.shape
            self.input_format = 'csv'
            self.wavelengths = wavelengths
            print(f"输入数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")

    @staticmethod
    def _find_envi_header(file_path: Path) -> Path:
        """Locate the header of a raw ENVI file: ``name.ext.hdr`` first, then ``name.hdr``."""
        candidates = [file_path.with_name(file_path.name + '.hdr'), file_path.with_suffix('.hdr')]
        for candidate in candidates:
            if candidate.exists():
                return candidate
        raise FileNotFoundError(f"未找到头文件 {candidates[0]} 或 {candidates[1]}")

    def _load_envi_data(self, file_path: Path, is_background: bool = False) -> None:
        """
        Load an ENVI image given either its ``.hdr`` header or its raw data file.

        The cube is flattened to (rows * cols, bands) float32.

        Args:
            file_path: header or raw data file.
            is_background: store into ``background_data`` instead of ``data``.

        Raises:
            ValueError: wraps any failure (missing file/header, permissions,
                spectral errors); the original exception is chained.
        """
        try:
            file_path = Path(file_path)
            if not file_path.exists():
                raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
            if not os.access(file_path, os.R_OK):
                raise PermissionError(f"没有读取权限: {file_path}")

            print(f"尝试使用spectral库读取ENVI文件: {file_path}")

            if file_path.suffix.lower() == '.hdr':
                envi_data = spectral.open_image(str(file_path))
            else:
                # open_image only recognises header files, so pair the located
                # header with the explicit data file.
                hdr_path = self._find_envi_header(file_path)
                envi_data = spectral.envi.open(str(hdr_path), str(file_path))

            print("正在加载数据到内存...")
            data_array = np.asarray(envi_data.load())
            rows, cols, bands = data_array.shape

            # spectral leaves bands.centers as None when the header has no wavelengths.
            centers = getattr(getattr(envi_data, 'bands', None), 'centers', None)
            if centers is not None and len(centers) == bands:
                wavelengths = np.asarray(centers, dtype=float)
                print(f"读取到波长信息: {len(wavelengths)} 个波段")
            else:
                wavelengths = np.arange(bands)
                print(f"未找到波长信息,使用默认波长: 0-{len(wavelengths)-1}")

            spectral_data = data_array.reshape(-1, bands).astype(np.float32)

            if is_background:
                self.background_data = spectral_data
                print(f"背景数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
            else:
                self.data = spectral_data
                self.data_shape = (rows, cols, bands)
                self.input_format = 'envi'
                self.wavelengths = wavelengths
                print(f"输入数据加载完成: {rows}x{cols} 像素 x {bands} 波段")

        except Exception as e:
            raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。") from e

    def _load_background_data(self) -> None:
        """
        Load the user-supplied background data (CSV or ENVI) for the 'custom' model.

        Raises:
            ValueError: ``background_path`` unset or its suffix unsupported.
        """
        if not self.config.background_path:
            raise ValueError("background_path未指定")

        file_path = Path(self.config.background_path)
        suffix = file_path.suffix.lower()

        print(f"正在读取背景数据: {file_path}")

        if suffix == '.csv':
            self._load_csv_data(file_path, is_background=True)
        elif suffix in self._ENVI_SUFFIXES:
            self._load_envi_data(file_path, is_background=True)
        else:
            raise ValueError(f"不支持的背景文件格式: {suffix}")

    # ------------------------------------------------------------ preprocessing

    def preprocess_data(self) -> None:
        """
        Optionally standardize the data (only when ``config.use_scaling``).

        The scaler is fitted on the input data and the same transform is applied
        to the background data so both live in the same space.
        """
        if not self.config.use_scaling:
            return

        print("正在标准化数据...")
        self.scaler = StandardScaler()
        if self.data is not None:
            self.data = self.scaler.fit_transform(self.data)
        if self.background_data is not None:
            self.background_data = self.scaler.transform(self.background_data)

    # --------------------------------------------------------------- statistics

    @staticmethod
    def _invert_covariance(cov_mat: np.ndarray, n_samples: Optional[int] = None,
                           warn: bool = False) -> np.ndarray:
        """
        Invert a covariance matrix, falling back to the pseudo-inverse.

        With ``n_samples <= n_bands`` the sample covariance has rank below
        n_bands, so the pseudo-inverse is used directly. Otherwise ``inv`` is
        tried and ``pinv`` is used only if it raises ``LinAlgError``.
        """
        if n_samples is not None and n_samples <= cov_mat.shape[0]:
            if warn:
                print("警告: 协方差矩阵奇异,使用伪逆")
            return linalg.pinv(cov_mat)
        try:
            return linalg.inv(cov_mat)
        except np.linalg.LinAlgError:
            if warn:
                print("警告: 协方差矩阵奇异,使用伪逆")
            return linalg.pinv(cov_mat)

    def _estimate_statistics(self, samples: np.ndarray,
                             warn: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Estimate mean, covariance and inverse covariance of (n_samples, n_bands) data.

        The robust (MinCovDet) estimator is used only when requested and when
        there are more samples than bands; otherwise the empirical estimate is used.
        """
        n_samples, n_bands = samples.shape
        if self.config.use_robust_covariance and n_samples > n_bands:
            estimator = MinCovDet()
            estimator.fit(samples)
            mean_vec = estimator.location_
            cov_mat = estimator.covariance_
        else:
            mean_vec = np.mean(samples, axis=0)
            # rowvar=False: columns are bands; atleast_2d keeps a 1x1 matrix for single-band data.
            cov_mat = np.atleast_2d(np.cov(samples, rowvar=False))
        inv_cov_mat = self._invert_covariance(cov_mat, n_samples, warn=warn)
        return mean_vec, cov_mat, inv_cov_mat

    def _fit_background(self, samples: np.ndarray, label: str) -> None:
        """Store background statistics fitted on ``samples`` and report the covariance shape."""
        self.background_mean, self.background_cov, self.background_inv_cov = \
            self._estimate_statistics(samples, warn=True)
        print(f"{label}背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")

    def compute_background_statistics(self) -> None:
        """
        Compute the background mean and covariance for the configured model.

        Only 'global' and 'custom' are fitted here; local models estimate their
        statistics per window during detection.
        """
        print(f"正在计算背景统计量 (模型: {self.config.background_model})...")

        if self.config.background_model == 'global':
            self._compute_global_background()
        elif self.config.background_model == 'custom':
            self._compute_custom_background()

    def _compute_global_background(self) -> None:
        """Fit background statistics on the whole input data set."""
        if self.data is None:
            raise ValueError("输入数据未加载,请先调用load_data()")
        self._fit_background(self.data, "全局")

    def _compute_custom_background(self) -> None:
        """Fit background statistics on the user-supplied background data."""
        if self.background_data is None:
            raise ValueError("自定义背景数据未加载")
        if self.data is not None and self.background_data.shape[1] != self.data.shape[1]:
            raise ValueError(
                f"背景数据波段数({self.background_data.shape[1]})与输入数据波段数({self.data.shape[1]})不一致")
        self._fit_background(self.background_data, "自定义")

    def _compute_local_background_stats(self, data_window: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute background statistics for one local window.

        Args:
            data_window: window samples, shape (n_pixels, n_bands); rows with
                any NaN are ignored.

        Returns:
            (mean_vector, cov_matrix, inv_cov_matrix). With fewer than two
            valid rows, the global statistics are returned instead; they are
            computed on first use, since local models do not fit them up front.
        """
        valid_mask = ~np.isnan(data_window).any(axis=1)
        valid_data = data_window[valid_mask]

        if len(valid_data) < 2:
            if self.background_inv_cov is None:
                self._compute_global_background()
            return self.background_mean, self.background_cov, self.background_inv_cov

        return self._estimate_statistics(valid_data)

    # ---------------------------------------------------------------- distances

    def compute_mahalanobis_distance(self, x: np.ndarray, mean_vec: np.ndarray,
                                     inv_cov_mat: np.ndarray) -> float:
        """
        Squared Mahalanobis distance of a single pixel (the RX score).

        Args:
            x: spectral vector of the pixel.
            mean_vec: background mean vector.
            inv_cov_mat: inverse of the background covariance matrix.

        Returns:
            ``(x - mean_vec)^T inv_cov_mat (x - mean_vec)`` clamped at 0;
            negative round-off results and NaN both map to 0.0.
        """
        diff = np.asarray(x) - mean_vec
        distance = float(diff @ inv_cov_mat @ diff)
        # max() keeps its first argument when the comparison is False, so NaN -> 0.0.
        return max(0.0, distance)

    @staticmethod
    def _mahalanobis_batch(samples: np.ndarray, mean_vec: np.ndarray,
                           inv_cov_mat: np.ndarray) -> np.ndarray:
        """Row-wise version of ``compute_mahalanobis_distance`` for (n, n_bands) samples."""
        diff = samples - mean_vec
        distances = np.sum((diff @ inv_cov_mat) * diff, axis=1)
        # Same clamping as the scalar version: negatives and NaN become 0.
        return np.where(distances > 0, distances, 0.0)

    def _image_cube(self) -> Tuple[np.ndarray, int, int, int]:
        """Return the data as a (rows, cols, bands) cube; local models need spatial layout."""
        if self.data_shape is None or len(self.data_shape) != 3:
            raise ValueError(
                f"背景模型 {self.config.background_model} 需要具有空间结构的ENVI影像输入,"
                f"CSV数据请使用global或custom模型")
        rows, cols, bands = self.data_shape
        return self.data.reshape(rows, cols, bands), rows, cols, bands

    def detect_anomalies_global(self) -> np.ndarray:
        """
        Score every sample against the fitted global/custom background.

        Returns:
            float32 distance array of shape (n_samples,).

        Raises:
            ValueError: background statistics have not been computed.
        """
        print("正在进行全局RX异常检测...")
        if self.background_inv_cov is None:
            raise ValueError("背景统计量未计算,请先调用compute_background_statistics()")
        distances = self._mahalanobis_batch(self.data, self.background_mean, self.background_inv_cov)
        return distances.astype(np.float32)

    def detect_anomalies_local_row(self) -> np.ndarray:
        """
        Score pixels against a background of ``row_window`` rows centred on their row.

        Returns:
            float32 distance array flattened from (rows, cols) to (rows * cols,).

        Raises:
            ValueError: the input has no spatial layout (e.g. CSV data).
        """
        print("正在进行局部行RX异常检测...")

        image_data, rows, cols, bands = self._image_cube()
        distance_map = np.zeros((rows, cols), dtype=np.float32)
        # Even window sizes behave like the next odd size (symmetric half-window).
        half_window = self.config.row_window // 2

        for i in range(rows):
            row_start = max(0, i - half_window)
            row_end = min(rows, i + half_window + 1)
            window = image_data[row_start:row_end].reshape(-1, bands)
            local_mean, _, local_inv_cov = self._compute_local_background_stats(window)
            # One set of statistics per row, so the whole row is scored at once.
            distance_map[i] = self._mahalanobis_batch(image_data[i], local_mean, local_inv_cov)

        return distance_map.flatten()

    def detect_anomalies_local_square(self) -> np.ndarray:
        """
        Score each pixel against a ``window_size`` x ``window_size`` window centred on it.

        The window includes the pixel under test and is clipped at image edges.

        Returns:
            float32 distance array flattened from (rows, cols) to (rows * cols,).

        Raises:
            ValueError: the input has no spatial layout (e.g. CSV data).
        """
        print("正在进行局部正方形RX异常检测...")

        image_data, rows, cols, bands = self._image_cube()
        distance_map = np.zeros((rows, cols), dtype=np.float32)
        # Even window sizes behave like the next odd size (symmetric half-window).
        half_window = self.config.window_size // 2

        for i in range(rows):
            row_start = max(0, i - half_window)
            row_end = min(rows, i + half_window + 1)
            for j in range(cols):
                col_start = max(0, j - half_window)
                col_end = min(cols, j + half_window + 1)
                window = image_data[row_start:row_end, col_start:col_end].reshape(-1, bands)
                local_mean, _, local_inv_cov = self._compute_local_background_stats(window)
                distance_map[i, j] = self.compute_mahalanobis_distance(
                    image_data[i, j], local_mean, local_inv_cov
                )

        return distance_map.flatten()

    def detect_anomalies(self) -> np.ndarray:
        """
        Run detection with the configured background model.

        Sets ``distance_map``. It also sets ``anomaly_map`` (0/1 where
        distance > threshold) when ``config.threshold`` is set, and None otherwise.

        Returns:
            flat array of squared Mahalanobis distances.

        Raises:
            ValueError: no data loaded or unknown background model.
        """
        if self.data is None:
            raise ValueError("输入数据未加载,请先调用load_data()")

        model = self.config.background_model
        if model in ('global', 'custom'):
            self.compute_background_statistics()
            distance_map = self.detect_anomalies_global()
        elif model == 'local_row':
            distance_map = self.detect_anomalies_local_row()
        elif model == 'local_square':
            distance_map = self.detect_anomalies_local_square()
        else:
            raise ValueError(f"不支持的背景模型: {model}")

        self.distance_map = distance_map

        if self.config.threshold is not None:
            self.anomaly_map = (distance_map > self.config.threshold).astype(np.int32)
        else:
            self.anomaly_map = None

        print("异常检测完成:")
        if distance_map.size:
            print(f"  平均距离: {np.mean(distance_map):.4f}")
            print(f"  距离标准差: {np.std(distance_map):.4f}")
            print(f"  最大距离: {np.max(distance_map):.4f}")
            if self.anomaly_map is not None:
                n_anomalies = int(np.sum(self.anomaly_map))
                print(f"  异常像素: {n_anomalies} ({n_anomalies / len(distance_map):.3f})")

        return distance_map

    def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run the full analysis described by the configuration.

        Data are loaded and preprocessed first if that has not happened yet.

        Returns:
            (predictions, scores): 0/1 int32 labels and the distance scores.
            Labels use ``config.threshold`` when set; otherwise scores above the
            ``(1 - contamination) * 100`` percentile are flagged.
        """
        if self.data is None:
            self.load_data()
            self.preprocess_data()

        scores = self.detect_anomalies()

        if self.config.threshold is not None:
            threshold = self.config.threshold
        else:
            threshold = np.percentile(scores, (1 - self.config.contamination) * 100)
        predictions = (scores > threshold).astype(np.int32)

        return predictions, scores

    # ------------------------------------------------------------------- saving

    def save_results(self, output_path: Optional[str] = None) -> None:
        """
        Save the detection result as an ENVI file.

        Args:
            output_path: raw data file to write. Defaults to ``config.output_path``,
                or ``<output_dir>/<input stem>_rx_anomaly.bip``.

        Raises:
            ValueError: detection has not run or no output location is known.
        """
        if self.distance_map is None:
            raise ValueError("请先运行异常检测")

        if output_path is None:
            if self.config.output_path:
                output_path = self.config.output_path
            elif self.config.output_dir:
                base_name = Path(self.config.input_path).stem
                output_path = os.path.join(self.config.output_dir, f"{base_name}_rx_anomaly.bip")
            else:
                raise ValueError("必须指定输出路径")

        print(f"正在保存结果到: {output_path}")

        # dirname() is '' for a bare file name, and os.makedirs('') would raise.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        if self.input_format == 'envi':
            self._save_envi_results(output_path)
        else:
            # CSV input has no spatial layout: save a 1 x N image instead.
            self._save_csv_results_as_envi(output_path)

    def _result_band(self) -> Tuple[np.ndarray, str]:
        """Return the band to save: the 0/1 anomaly map when thresholded, else the distances."""
        if self.anomaly_map is not None:
            return np.asarray(self.anomaly_map), "Anomaly_Map"
        return np.asarray(self.distance_map), "Mahalanobis_Distance"

    def _save_envi_results(self, output_path: str) -> None:
        """Save the result band as a (rows, cols, 1) float32 image matching the ENVI input."""
        rows, cols, _ = self.data_shape
        values, band_name = self._result_band()
        image = values.reshape(rows, cols, 1).astype(np.float32)
        self._save_envi_data(image, output_path, [band_name], 'bip')

    def _save_csv_results_as_envi(self, output_path: str) -> None:
        """Save the result band of CSV input as a 1 x n_samples single-band float32 image."""
        values, band_name = self._result_band()
        image = values.reshape(1, -1, 1).astype(np.float32)
        self._save_envi_data(image, output_path, [band_name], 'bip')

    def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
                        band_names: List[str], interleave: str = 'bip') -> None:
        """
        Write ``data`` as an ENVI raw file plus a ``.hdr`` header.

        The header is written next to the data file with the same stem
        (``result.bip`` -> ``result.hdr``). When ``output_path`` itself ends in
        ``.hdr``, the raw data goes to the ``.dat`` sibling instead.

        Args:
            data: array of shape (height, width, channels).
            output_path: path of the raw data file.
            band_names: one name per channel.
            interleave: on-disk layout, one of 'bip', 'bil', 'bsq'.

        Raises:
            ValueError: ``data`` is not 3-D or ``interleave`` is unknown.
        """
        if data.ndim != 3:
            raise ValueError(f"ENVI输出数据必须为三维 (height, width, channels),实际形状: {data.shape}")
        interleave = interleave.lower()
        if interleave not in self._INTERLEAVE_AXES:
            raise ValueError(f"不支持的交织方式: {interleave}")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        if output_path.suffix.lower() == '.hdr':
            data_path, hdr_path = output_path.with_suffix('.dat'), output_path
        else:
            # with_suffix always yields a path distinct from the data file, so the
            # header can never overwrite the binary data.
            data_path, hdr_path = output_path, output_path.with_suffix('.hdr')

        print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")

        self._save_dat_file(data, str(data_path), interleave)
        self._save_hdr_file(str(hdr_path), data.shape, band_names, interleave)

        print(f"RX异常检测结果已保存: {data_path}")
        print(f"头文件已保存: {hdr_path}")
        print(f"使用的交织格式: {interleave.upper()}")

    def _save_dat_file(self, data: np.ndarray, file_path: str, interleave: str = 'bip') -> None:
        """
        Write ``data`` (height, width, channels) as raw little-endian float32.

        Args:
            data: 3-D array; any numeric dtype is converted to float32.
            file_path: destination file (overwritten).
            interleave: 'bip' (default), 'bil' or 'bsq'; axes are reordered to match.
        """
        axes = self._INTERLEAVE_AXES[interleave.lower()]
        # '<f4' matches 'data type = 4' and 'byte order = 0' in the header on any host.
        ordered = np.ascontiguousarray(np.transpose(data, axes), dtype='<f4')
        with open(file_path, 'wb') as f:
            ordered.tofile(f)

    def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
                       band_names: List[str], interleave: str) -> None:
        """
        Write the ENVI header for a float32 result image, plus RX metadata.

        Args:
            hdr_path: header file to write (overwritten, UTF-8).
            data_shape: (height, width, channels) of the saved image.
            band_names: names listed under ``band names`` (omitted when empty).
            interleave: layout recorded in the header; must match the data file.
        """
        height, width, channels = data_shape

        # ENVI data type codes: 1=uint8, 2=int16, 3=int32, 4=float32, 5=float64, 12=uint16.
        data_type_code = 4  # _save_dat_file always writes float32

        lines = [
            "ENVI",
            "description = {RX Anomaly Detection Results - Generated by RXAnomalyDetector}",
            f"samples = {width}",
            f"lines = {height}",
            f"bands = {channels}",
            "header offset = 0",
            "file type = ENVI Standard",
            f"data type = {data_type_code}",
            f"interleave = {interleave}",
            "byte order = 0",  # little-endian, as written by _save_dat_file
        ]

        if band_names:
            # ENVI lists are comma-separated inside braces; names are written unquoted.
            lines.append("band names = {")
            lines.append(",\n".join(f" {name}" for name in band_names))
            lines.append("}")

        # RX-specific metadata: only the window parameter the model actually used.
        lines.append("anomaly_detection_method = rx")
        lines.append(f"background_model = {self.config.background_model}")
        if self.config.background_model == 'local_square':
            lines.append(f"window_size = {self.config.window_size}")
        elif self.config.background_model == 'local_row':
            lines.append(f"row_window = {self.config.row_window}")
        if self.config.threshold is not None:
            lines.append(f"threshold = {self.config.threshold}")
        lines.append(f"use_robust_covariance = {self.config.use_robust_covariance}")

        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines) + "\n")

    # ---------------------------------------------------------------- reporting

    def get_statistics(self) -> Dict[str, Any]:
        """
        Summarize the last detection run.

        Returns:
            Dict with total_pixels, background_model, mean/std/min/max/median
            distance, and percentile_{25,75,90,95,99}. When a threshold was
            applied it also holds threshold, n_anomalies and anomaly_ratio.

        Raises:
            ValueError: ``detect_anomalies`` has not been run.
        """
        if self.distance_map is None:
            raise ValueError("请先运行异常检测")

        distances = self.distance_map
        summary = {
            'total_pixels': len(distances),
            'background_model': self.config.background_model,
            'mean_distance': float(np.mean(distances)),
            'std_distance': float(np.std(distances)),
            'min_distance': float(np.min(distances)),
            'max_distance': float(np.max(distances)),
            'median_distance': float(np.median(distances)),
        }

        levels = [25, 75, 90, 95, 99]
        for level, value in zip(levels, np.percentile(distances, levels)):
            summary[f'percentile_{level}'] = float(value)

        if self.anomaly_map is not None:
            n_anomalies = int(np.sum(self.anomaly_map))
            summary['threshold'] = self.config.threshold
            summary['n_anomalies'] = n_anomalies
            summary['anomaly_ratio'] = float(n_anomalies / len(distances))

        return summary

    def suggest_threshold(self, method: str = 'percentile', quantile: float = 0.95) -> float:
        """
        Suggest a detection threshold from the last run.

        Args:
            method:
                - 'percentile': the ``quantile`` quantile of the distances.
                - 'mean_std': mean + 2 standard deviations.
                - 'auto': chi-square quantile with n_bands degrees of freedom.
            quantile: quantile in (0, 1) used by 'percentile' and 'auto'.

        Returns:
            the suggested threshold.

        Raises:
            ValueError: detection has not run, or the method/quantile is invalid.
        """
        if self.distance_map is None:
            raise ValueError("请先运行异常检测")
        if not 0.0 < quantile < 1.0:
            raise ValueError(f"quantile必须在(0, 1)区间内: {quantile}")

        if method == 'percentile':
            threshold = np.percentile(self.distance_map, quantile * 100)
        elif method == 'mean_std':
            threshold = np.mean(self.distance_map) + 2 * np.std(self.distance_map)
        elif method == 'auto':
            # Under a Gaussian background the squared Mahalanobis distance follows
            # a chi-square distribution with n_bands degrees of freedom.
            threshold = stats.chi2.ppf(quantile, self.data.shape[1])
        else:
            raise ValueError(f"不支持的阈值建议方法: {method}")

        return float(threshold)
def main():
    """
    Command-line entry point for RX anomaly detection.

    Parses the arguments and builds an ``RXConfig``. It then runs
    load -> preprocess -> detect -> save and prints summary statistics.

    Returns:
        Process exit code: 0 on success, 1 on any error (the traceback is printed).
    """
    parser = argparse.ArgumentParser(
        description='Reed-Xiaoli (RX) 高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
python RX.py --input image.bip --background-model global --output results.bip
python RX.py --input data.csv --spectral-start 400 --background-model global --output results.bip
python RX.py --input image.bip --background-model local_square --window-size 7 --output results.bip
python RX.py --input image.bip --background-model custom --background background.bip --threshold 10.0 --output results.bip

背景模型说明:
global: 使用整个数据集作为背景
local_row: 使用像素周围N行的窗口作为背景
local_square: 使用固定大小的正方形窗口作为背景
custom: 使用外部提供的背景数据
""",
    )

    # Input files
    parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
    parser.add_argument('--background', help='背景数据路径 (custom模式时必需)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
    parser.add_argument('--no-header', action='store_true', help='CSV文件没有表头行')

    # Output
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # Background model
    parser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
                        default='global', help='背景模型类型,默认global')
    parser.add_argument('--window-size', type=int, default=5,
                        help='局部窗口大小,默认5')
    parser.add_argument('--row-window', type=int, default=3,
                        help='行窗口大小,默认3')

    # Detection
    parser.add_argument('--threshold', type=float,
                        help='异常检测阈值,不指定则只输出距离图')
    parser.add_argument('--robust-covariance', action='store_true',
                        help='使用稳健协方差估计')
    parser.add_argument('--use-scaling', action='store_true',
                        help='检测前对各波段进行标准化')

    args = parser.parse_args()

    try:
        config = RXConfig(
            input_path=args.input,
            background_path=args.background,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            csv_has_header=not args.no_header,
            output_path=args.output,
            output_dir=args.output_dir,
            background_model=args.background_model,
            window_size=args.window_size,
            row_window=args.row_window,
            threshold=args.threshold,
            use_robust_covariance=args.robust_covariance,
            use_scaling=args.use_scaling,
        )

        detector = RXAnomalyDetector(config)

        print("=== Reed-Xiaoli (RX) 异常检测 ===")
        print(f"输入文件: {config.input_path}")
        print(f"背景模型: {config.background_model}")
        # Report the window parameter the chosen model actually uses.
        if config.background_model == 'local_square':
            print(f"窗口大小: {config.window_size}")
        elif config.background_model == 'local_row':
            print(f"行窗口大小: {config.row_window}")
        if config.threshold is not None:
            print(f"检测阈值: {config.threshold}")
        print("-" * 50)

        # 1. Load data
        detector.load_data()

        # 2. Preprocess
        detector.preprocess_data()

        # 3. Detect
        detector.detect_anomalies()

        # 4. Suggest a threshold when none was given
        if config.threshold is None:
            suggested_threshold = detector.suggest_threshold('percentile')
            print(f"\n建议阈值 (95百分位数): {suggested_threshold:.4f}")

        # 5. Save results
        detector.save_results()

        # 6. Summary statistics
        summary = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in summary.items():
            print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        # Top-level CLI boundary: report everything and signal failure via the exit code.
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # sys.exit does not depend on the interactive ``exit`` helper that the
    # ``site`` module injects (absent under ``python -S`` or in frozen builds).
    sys.exit(main())