增加模块;增加主调用命令
This commit is contained in:
883
Anomaly_method/RX.py
Normal file
883
Anomaly_method/RX.py
Normal file
@ -0,0 +1,883 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spectral
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.covariance import EmpiricalCovariance, MinCovDet
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from scipy import stats, linalg
|
||||
import cv2
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
|
||||
class RXConfig:
|
||||
"""RX异常检测配置类"""
|
||||
|
||||
# 输入文件配置
|
||||
input_path: Optional[str] = None
|
||||
background_path: Optional[str] = None # 用户指定背景数据
|
||||
label_col: Optional[str] = None
|
||||
spectral_start: Optional[str] = None
|
||||
spectral_end: Optional[str] = None
|
||||
csv_has_header: bool = True
|
||||
|
||||
# 输出配置
|
||||
output_path: Optional[str] = None
|
||||
output_dir: Optional[str] = None
|
||||
|
||||
# 背景模型配置
|
||||
background_model: str = 'global' # 'global', 'local_row', 'local_square', 'custom'
|
||||
window_size: int = 5 # 局部窗口大小
|
||||
row_window: int = 3 # 行窗口大小(行数)
|
||||
|
||||
# 检测参数
|
||||
threshold: Optional[float] = None # 异常检测阈值,None表示不进行阈值分割
|
||||
contamination: float = 0.1 # 异常点比例,用于自动确定阈值
|
||||
use_robust_covariance: bool = False # 是否使用稳健协方差估计
|
||||
|
||||
# 数据预处理配置
|
||||
use_scaling: bool = False # 是否使用标准化(RX算法通常不使用)
|
||||
|
||||
def __post_init__(self):
|
||||
"""参数校验和默认值设置"""
|
||||
# 校验必需的文件路径
|
||||
if not self.input_path:
|
||||
raise ValueError("必须指定输入文件路径(input_path)")
|
||||
|
||||
# 校验文件存在性
|
||||
if not os.path.exists(self.input_path):
|
||||
raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
|
||||
|
||||
if self.background_path and not os.path.exists(self.background_path):
|
||||
raise FileNotFoundError(f"背景文件不存在: {self.background_path}")
|
||||
|
||||
# 校验输出路径
|
||||
if not self.output_path and not self.output_dir:
|
||||
raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")
|
||||
|
||||
# 校验背景模型类型
|
||||
supported_models = ['global', 'local_row', 'local_square', 'custom']
|
||||
if self.background_model not in supported_models:
|
||||
raise ValueError(f"不支持的背景模型: {self.background_model}。支持的模型: {supported_models}")
|
||||
|
||||
# 校验参数范围
|
||||
if self.window_size <= 0:
|
||||
raise ValueError("window_size必须大于0")
|
||||
|
||||
if self.row_window <= 0:
|
||||
raise ValueError("row_window必须大于0")
|
||||
|
||||
if self.background_model == 'custom' and not self.background_path:
|
||||
raise ValueError("使用custom背景模型时,必须指定background_path")
|
||||
|
||||
|
||||
class RXAnomalyDetector:
|
||||
"""
|
||||
Reed-Xiaoli (RX) 异常检测器
|
||||
|
||||
RX算法基于Mahalanobis距离检测高光谱数据中的异常像素。
|
||||
算法计算每个像素相对于背景分布的距离,距离越大的像素越可能是异常。
|
||||
|
||||
数学基础:
|
||||
D(x) = (x - μ)ᵀ K⁻¹ (x - μ)
|
||||
|
||||
其中:
|
||||
- x: 测试像素的光谱向量
|
||||
- μ: 背景数据的均值向量
|
||||
- K: 背景数据的协方差矩阵
|
||||
"""
|
||||
|
||||
def __init__(self, config: RXConfig):
|
||||
"""
|
||||
初始化RX异常检测器
|
||||
|
||||
Args:
|
||||
config: 配置对象
|
||||
"""
|
||||
self.config = config
|
||||
self.data = None
|
||||
self.background_data = None
|
||||
self.wavelengths = None
|
||||
self.data_shape = None
|
||||
self.input_format = None
|
||||
|
||||
# 背景统计量
|
||||
self.background_mean = None
|
||||
self.background_cov = None
|
||||
self.background_inv_cov = None
|
||||
|
||||
# 结果
|
||||
self.distance_map = None # Mahalanobis距离图
|
||||
self.anomaly_map = None # 异常检测结果
|
||||
|
||||
# 预处理组件
|
||||
self.scaler = None
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
加载高光谱数据文件
|
||||
|
||||
支持CSV和ENVI格式文件
|
||||
"""
|
||||
file_path = Path(self.config.input_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取输入数据: {file_path}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path, is_background=False)
|
||||
elif suffix in ['.hdr']:
|
||||
self._load_envi_data(file_path, is_background=False)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||||
|
||||
# 加载背景数据
|
||||
if self.config.background_model == 'custom':
|
||||
self._load_background_data()
|
||||
|
||||
def _load_csv_data(self, file_path: Path, is_background: bool = False) -> None:
|
||||
"""加载CSV格式数据"""
|
||||
if self.config.spectral_start is None:
|
||||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||||
|
||||
# 提取标签列
|
||||
if self.config.label_col and self.config.label_col in df.columns:
|
||||
labels = df[self.config.label_col].values
|
||||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||||
else:
|
||||
labels = None
|
||||
spectral_cols = df.columns.tolist()
|
||||
|
||||
# 提取光谱数据
|
||||
if self.config.spectral_end:
|
||||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||||
else:
|
||||
end_idx = len(spectral_cols)
|
||||
|
||||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||||
|
||||
# 生成波长信息
|
||||
wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||||
|
||||
# 存储数据
|
||||
if is_background:
|
||||
self.background_data = spectral_data.astype(np.float32)
|
||||
print(f"背景数据加载完成: {self.background_data.shape} 样本 x {self.background_data.shape[1]} 波段")
|
||||
else:
|
||||
self.data = spectral_data.astype(np.float32)
|
||||
self.data_shape = self.data.shape
|
||||
self.input_format = 'csv'
|
||||
self.wavelengths = wavelengths
|
||||
print(f"输入数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
def _load_envi_data(self, file_path: Path, is_background: bool = False) -> None:
|
||||
"""加载ENVI格式数据"""
|
||||
try:
|
||||
# 确保文件路径存在且可访问
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
|
||||
|
||||
# 检查文件是否可读
|
||||
if not os.access(file_path, os.R_OK):
|
||||
raise PermissionError(f"没有读取权限: {file_path}")
|
||||
|
||||
# 使用规范化的路径字符串
|
||||
file_path_str = str(file_path)
|
||||
print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")
|
||||
|
||||
# 检查对应的头文件
|
||||
hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
|
||||
if not hdr_path.exists():
|
||||
# 尝试 .hdr 格式
|
||||
hdr_path = file_path.with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")
|
||||
|
||||
# 读取ENVI文件 - 使用规范化的路径
|
||||
envi_data = spectral.open_image(file_path_str)
|
||||
|
||||
# 获取数据数组
|
||||
print("正在加载数据到内存...")
|
||||
data_array = envi_data.load()
|
||||
|
||||
# 获取波长信息
|
||||
if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
|
||||
wavelengths = np.array(envi_data.bands.centers)
|
||||
print(f"读取到波长信息: {len(wavelengths)} 个波段")
|
||||
else:
|
||||
wavelengths = np.arange(data_array.shape[2])
|
||||
print(f"未找到波长信息,使用默认波长: 0-{len(wavelengths)-1}")
|
||||
|
||||
# 重塑数据为 (samples, bands)
|
||||
rows, cols, bands = data_array.shape
|
||||
spectral_data = data_array.reshape(-1, bands).astype(np.float32)
|
||||
|
||||
# 存储数据
|
||||
if is_background:
|
||||
self.background_data = spectral_data
|
||||
print(f"背景数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
|
||||
else:
|
||||
self.data = spectral_data
|
||||
self.data_shape = (rows, cols, bands)
|
||||
self.input_format = 'envi'
|
||||
self.wavelengths = wavelengths
|
||||
print(f"输入数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def _load_background_data(self) -> None:
|
||||
"""加载用户指定的背景数据"""
|
||||
if not self.config.background_path:
|
||||
raise ValueError("background_path未指定")
|
||||
|
||||
file_path = Path(self.config.background_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取背景数据: {file_path}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path, is_background=True)
|
||||
elif suffix in ['.bil', '.bsq', '.bip', '.dat']:
|
||||
self._load_envi_data(file_path, is_background=True)
|
||||
else:
|
||||
raise ValueError(f"不支持的背景文件格式: {suffix}")
|
||||
|
||||
def preprocess_data(self) -> None:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
RX算法通常不需要标准化,除非指定
|
||||
"""
|
||||
if self.config.use_scaling:
|
||||
print("正在标准化数据...")
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
# 对输入数据进行标准化
|
||||
if self.data is not None:
|
||||
self.data = self.scaler.fit_transform(self.data)
|
||||
|
||||
# 对背景数据进行相同的标准化
|
||||
if self.background_data is not None:
|
||||
self.background_data = self.scaler.transform(self.background_data)
|
||||
|
||||
def compute_background_statistics(self) -> None:
|
||||
"""
|
||||
计算背景统计量(均值和协方差矩阵)
|
||||
|
||||
根据不同的背景模型计算相应的统计量
|
||||
"""
|
||||
print(f"正在计算背景统计量 (模型: {self.config.background_model})...")
|
||||
|
||||
if self.config.background_model == 'global':
|
||||
self._compute_global_background()
|
||||
elif self.config.background_model == 'custom':
|
||||
self._compute_custom_background()
|
||||
# 局部模型的统计量在检测过程中计算
|
||||
|
||||
def _compute_global_background(self) -> None:
|
||||
"""计算全局背景统计量"""
|
||||
# 使用整个数据集作为背景
|
||||
background_data = self.data
|
||||
|
||||
if self.config.use_robust_covariance:
|
||||
# 使用稳健协方差估计
|
||||
cov_estimator = MinCovDet()
|
||||
cov_estimator.fit(background_data)
|
||||
self.background_mean = cov_estimator.location_
|
||||
self.background_cov = cov_estimator.covariance_
|
||||
else:
|
||||
# 使用经验协方差
|
||||
self.background_mean = np.mean(background_data, axis=0)
|
||||
self.background_cov = np.cov(background_data.T)
|
||||
|
||||
# 计算协方差矩阵的逆
|
||||
try:
|
||||
self.background_inv_cov = linalg.inv(self.background_cov)
|
||||
except np.linalg.LinAlgError:
|
||||
print("警告: 协方差矩阵奇异,使用伪逆")
|
||||
self.background_inv_cov = linalg.pinv(self.background_cov)
|
||||
|
||||
print(f"全局背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
|
||||
|
||||
def _compute_custom_background(self) -> None:
|
||||
"""计算用户指定背景统计量"""
|
||||
if self.background_data is None:
|
||||
raise ValueError("自定义背景数据未加载")
|
||||
|
||||
background_data = self.background_data
|
||||
|
||||
if self.config.use_robust_covariance:
|
||||
cov_estimator = MinCovDet()
|
||||
cov_estimator.fit(background_data)
|
||||
self.background_mean = cov_estimator.location_
|
||||
self.background_cov = cov_estimator.covariance_
|
||||
else:
|
||||
self.background_mean = np.mean(background_data, axis=0)
|
||||
self.background_cov = np.cov(background_data.T)
|
||||
|
||||
# 计算协方差矩阵的逆
|
||||
try:
|
||||
self.background_inv_cov = linalg.inv(self.background_cov)
|
||||
except np.linalg.LinAlgError:
|
||||
print("警告: 协方差矩阵奇异,使用伪逆")
|
||||
self.background_inv_cov = linalg.pinv(self.background_cov)
|
||||
|
||||
print(f"自定义背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
|
||||
|
||||
def _compute_local_background_stats(self, data_window: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
计算局部背景统计量
|
||||
|
||||
Args:
|
||||
data_window: 局部窗口数据 (n_pixels, n_bands)
|
||||
|
||||
Returns:
|
||||
mean_vector, cov_matrix, inv_cov_matrix
|
||||
"""
|
||||
# 移除窗口中的NaN值
|
||||
valid_mask = ~np.isnan(data_window).any(axis=1)
|
||||
valid_data = data_window[valid_mask]
|
||||
|
||||
if len(valid_data) < 2:
|
||||
# 数据不足,返回全局统计量
|
||||
return self.background_mean, self.background_cov, self.background_inv_cov
|
||||
|
||||
if self.config.use_robust_covariance and len(valid_data) > self.background_cov.shape[0]:
|
||||
# 使用稳健协方差估计
|
||||
cov_estimator = MinCovDet()
|
||||
cov_estimator.fit(valid_data)
|
||||
mean_vec = cov_estimator.location_
|
||||
cov_mat = cov_estimator.covariance_
|
||||
else:
|
||||
# 使用经验协方差
|
||||
mean_vec = np.mean(valid_data, axis=0)
|
||||
cov_mat = np.cov(valid_data.T)
|
||||
|
||||
# 计算协方差矩阵的逆
|
||||
try:
|
||||
inv_cov_mat = linalg.inv(cov_mat)
|
||||
except np.linalg.LinAlgError:
|
||||
inv_cov_mat = linalg.pinv(cov_mat)
|
||||
|
||||
return mean_vec, cov_mat, inv_cov_mat
|
||||
|
||||
def compute_mahalanobis_distance(self, x: np.ndarray, mean_vec: np.ndarray,
|
||||
inv_cov_mat: np.ndarray) -> float:
|
||||
"""
|
||||
计算单个像素的Mahalanobis距离
|
||||
|
||||
Args:
|
||||
x: 像素光谱向量
|
||||
mean_vec: 均值向量
|
||||
inv_cov_mat: 协方差矩阵的逆
|
||||
|
||||
Returns:
|
||||
Mahalanobis距离
|
||||
"""
|
||||
diff = x - mean_vec
|
||||
distance = np.dot(np.dot(diff, inv_cov_mat), diff.T)
|
||||
|
||||
# 确保距离为非负实数
|
||||
return max(0, distance)
|
||||
|
||||
def detect_anomalies_global(self) -> np.ndarray:
|
||||
"""
|
||||
使用全局背景模型进行异常检测
|
||||
|
||||
Returns:
|
||||
距离图 (n_samples,)
|
||||
"""
|
||||
print("正在进行全局RX异常检测...")
|
||||
|
||||
n_samples = self.data.shape[0]
|
||||
distance_map = np.zeros(n_samples, dtype=np.float32)
|
||||
|
||||
# 对每个像素计算Mahalanobis距离
|
||||
for i in range(n_samples):
|
||||
distance_map[i] = self.compute_mahalanobis_distance(
|
||||
self.data[i], self.background_mean, self.background_inv_cov
|
||||
)
|
||||
|
||||
return distance_map
|
||||
|
||||
def detect_anomalies_local_row(self) -> np.ndarray:
|
||||
"""
|
||||
使用局部行背景模型进行异常检测
|
||||
|
||||
Returns:
|
||||
距离图 (rows, cols)
|
||||
"""
|
||||
print("正在进行局部行RX异常检测...")
|
||||
|
||||
rows, cols, bands = self.data_shape
|
||||
distance_map = np.zeros((rows, cols), dtype=np.float32)
|
||||
|
||||
# 将数据重塑为2D图像
|
||||
image_data = self.data.reshape(rows, cols, bands)
|
||||
|
||||
half_window = self.config.row_window // 2
|
||||
|
||||
for i in range(rows):
|
||||
# 计算行窗口范围
|
||||
row_start = max(0, i - half_window)
|
||||
row_end = min(rows, i + half_window + 1)
|
||||
|
||||
# 提取行窗口数据
|
||||
row_window = image_data[row_start:row_end, :, :].reshape(-1, bands)
|
||||
|
||||
# 计算局部背景统计量
|
||||
local_mean, local_cov, local_inv_cov = self._compute_local_background_stats(row_window)
|
||||
|
||||
# 计算当前行的距离
|
||||
for j in range(cols):
|
||||
pixel = image_data[i, j, :]
|
||||
distance_map[i, j] = self.compute_mahalanobis_distance(
|
||||
pixel, local_mean, local_inv_cov
|
||||
)
|
||||
|
||||
return distance_map.flatten()
|
||||
|
||||
def detect_anomalies_local_square(self) -> np.ndarray:
|
||||
"""
|
||||
使用局部正方形背景模型进行异常检测
|
||||
|
||||
Returns:
|
||||
距离图 (rows, cols)
|
||||
"""
|
||||
print("正在进行局部正方形RX异常检测...")
|
||||
|
||||
rows, cols, bands = self.data_shape
|
||||
distance_map = np.zeros((rows, cols), dtype=np.float32)
|
||||
|
||||
# 将数据重塑为2D图像
|
||||
image_data = self.data.reshape(rows, cols, bands)
|
||||
|
||||
half_window = self.config.window_size // 2
|
||||
|
||||
# 使用向量化计算优化性能
|
||||
for i in range(rows):
|
||||
for j in range(cols):
|
||||
# 计算窗口范围
|
||||
row_start = max(0, i - half_window)
|
||||
row_end = min(rows, i + half_window + 1)
|
||||
col_start = max(0, j - half_window)
|
||||
col_end = min(cols, j + half_window + 1)
|
||||
|
||||
# 提取窗口数据
|
||||
window_data = image_data[row_start:row_end, col_start:col_end, :].reshape(-1, bands)
|
||||
|
||||
# 计算局部背景统计量
|
||||
local_mean, local_cov, local_inv_cov = self._compute_local_background_stats(window_data)
|
||||
|
||||
# 计算当前像素的距离
|
||||
pixel = image_data[i, j, :]
|
||||
distance_map[i, j] = self.compute_mahalanobis_distance(
|
||||
pixel, local_mean, local_inv_cov
|
||||
)
|
||||
|
||||
return distance_map.flatten()
|
||||
|
||||
def detect_anomalies(self) -> np.ndarray:
|
||||
"""
|
||||
执行异常检测
|
||||
|
||||
Returns:
|
||||
Mahalanobis距离图
|
||||
"""
|
||||
if self.config.background_model == 'global':
|
||||
self.compute_background_statistics()
|
||||
distance_map = self.detect_anomalies_global()
|
||||
elif self.config.background_model == 'local_row':
|
||||
distance_map = self.detect_anomalies_local_row()
|
||||
elif self.config.background_model == 'local_square':
|
||||
distance_map = self.detect_anomalies_local_square()
|
||||
elif self.config.background_model == 'custom':
|
||||
self.compute_background_statistics()
|
||||
distance_map = self.detect_anomalies_global()
|
||||
|
||||
self.distance_map = distance_map
|
||||
|
||||
# 应用阈值分割(如果指定)
|
||||
if self.config.threshold is not None:
|
||||
self.anomaly_map = (distance_map > self.config.threshold).astype(np.int32)
|
||||
else:
|
||||
self.anomaly_map = None
|
||||
|
||||
# 统计信息
|
||||
mean_dist = np.mean(distance_map)
|
||||
std_dist = np.std(distance_map)
|
||||
max_dist = np.max(distance_map)
|
||||
|
||||
print("异常检测完成:")
|
||||
print(".4f")
|
||||
print(".4f")
|
||||
print(".4f")
|
||||
|
||||
if self.anomaly_map is not None:
|
||||
n_anomalies = np.sum(self.anomaly_map)
|
||||
anomaly_ratio = n_anomalies / len(distance_map)
|
||||
print(".3f")
|
||||
|
||||
return distance_map
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
从配置运行完整的异常检测分析
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||||
"""
|
||||
# 执行异常检测,获取距离图作为分数
|
||||
scores = self.detect_anomalies()
|
||||
|
||||
# 生成预测结果(基于阈值或百分位数)
|
||||
if self.config.threshold is not None:
|
||||
predictions = (scores > self.config.threshold).astype(np.int32)
|
||||
else:
|
||||
# 使用contamination参数确定阈值
|
||||
# contamination表示异常比例,所以使用(1-contamination)*100百分位数
|
||||
percentile = (1 - self.config.contamination) * 100
|
||||
threshold = np.percentile(scores, percentile)
|
||||
predictions = (scores > threshold).astype(np.int32)
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||||
"""
|
||||
保存异常检测结果为ENVI格式文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||||
"""
|
||||
if output_path is None:
|
||||
if self.config.output_path:
|
||||
output_path = self.config.output_path
|
||||
elif self.config.output_dir:
|
||||
base_name = Path(self.config.input_path).stem
|
||||
output_path = os.path.join(self.config.output_dir, f"{base_name}_rx_anomaly.bip")
|
||||
else:
|
||||
raise ValueError("必须指定输出路径")
|
||||
|
||||
print(f"正在保存结果到: {output_path}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.dirname(output_path)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 根据输入格式决定输出格式
|
||||
if self.input_format == 'envi':
|
||||
self._save_envi_results(output_path)
|
||||
else:
|
||||
# 对于CSV输入,创建虚拟的2D图像
|
||||
self._save_csv_results_as_envi(output_path)
|
||||
|
||||
def _save_envi_results(self, output_path: str) -> None:
|
||||
"""保存ENVI格式结果"""
|
||||
# 重塑距离图为原始图像尺寸
|
||||
rows, cols, bands = self.data_shape
|
||||
|
||||
# 创建单波段异常检测结果图像
|
||||
if self.anomaly_map is not None:
|
||||
# 保存异常检测结果 (0/1)
|
||||
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
|
||||
anomaly_image[:, :, 0] = self.anomaly_map.reshape(rows, cols)
|
||||
band_names = ["Anomaly_Map"]
|
||||
else:
|
||||
# 保存距离图
|
||||
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
|
||||
anomaly_image[:, :, 0] = self.distance_map.reshape(rows, cols)
|
||||
band_names = ["Mahalanobis_Distance"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||||
"""将CSV结果保存为ENVI格式"""
|
||||
# 对于CSV数据,创建1xN的单波段图像
|
||||
n_samples = len(self.distance_map)
|
||||
if self.anomaly_map is not None:
|
||||
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
|
||||
anomaly_image[0, :, 0] = self.anomaly_map
|
||||
band_names = ["Anomaly_Map"]
|
||||
else:
|
||||
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
|
||||
anomaly_image[0, :, 0] = self.distance_map
|
||||
band_names = ["Mahalanobis_Distance"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||||
band_names: List[str], interleave: str = 'bip') -> None:
|
||||
"""
|
||||
保存数据为ENVI格式,参考Covariance.py的实现
|
||||
|
||||
参数:
|
||||
data: 要保存的数据数组 (height, width, channels)
|
||||
output_path: 输出文件路径
|
||||
band_names: 波段名称列表
|
||||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
height, width, channels = data.shape
|
||||
|
||||
# 确保目录存在
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||||
|
||||
# 保存.dat文件(二进制格式)
|
||||
self._save_dat_file(data, str(output_path))
|
||||
|
||||
# 保存.hdr头文件
|
||||
hdr_path = str(output_path).replace('.dat', '.hdr')
|
||||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||||
|
||||
print(f"RX异常检测结果已保存: {output_path}")
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
print(f"使用的交织格式: {interleave.upper()}")
|
||||
|
||||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||||
"""保存.dat文件(二进制格式),参考Covariance.py"""
|
||||
# 根据数据类型保存
|
||||
if data.dtype == np.float32:
|
||||
# 对于浮点数据,直接保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
else:
|
||||
# 对于其他数据类型,转换为float32保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
|
||||
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
|
||||
band_names: List[str], interleave: str) -> None:
|
||||
"""保存ENVI头文件,参考Covariance.py并适配RX数据"""
|
||||
height, width, channels = data_shape
|
||||
|
||||
# 确定数据类型编码
|
||||
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
|
||||
data_type_code = 4 # 默认float32
|
||||
|
||||
header_content = f"""ENVI
|
||||
description = {{RX Anomaly Detection Results - Generated by RXAnomalyDetector}}
|
||||
samples = {width}
|
||||
lines = {height}
|
||||
bands = {channels}
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = {data_type_code}
|
||||
interleave = {interleave}
|
||||
byte order = 0
|
||||
"""
|
||||
|
||||
# 添加波段名称
|
||||
if band_names:
|
||||
header_content += "band names = {\n"
|
||||
for i, name in enumerate(band_names):
|
||||
header_content += f' "{name}"'
|
||||
if i < len(band_names) - 1:
|
||||
header_content += ","
|
||||
header_content += "\n"
|
||||
header_content += "}\n"
|
||||
|
||||
# 添加RX相关元数据
|
||||
header_content += f"anomaly_detection_method = rx\n"
|
||||
header_content += f"background_model = {self.config.background_model}\n"
|
||||
if self.config.background_model in ['local_row', 'local_square']:
|
||||
header_content += f"window_size = {self.config.window_size}\n"
|
||||
if self.config.background_model == 'local_row':
|
||||
header_content += f"row_window = {self.config.row_window}\n"
|
||||
if self.config.threshold is not None:
|
||||
header_content += f"threshold = {self.config.threshold}\n"
|
||||
header_content += f"use_robust_covariance = {self.config.use_robust_covariance}\n"
|
||||
|
||||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||||
f.write(header_content)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取异常检测统计信息
|
||||
|
||||
Returns:
|
||||
包含各种统计信息的字典
|
||||
"""
|
||||
if self.distance_map is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
stats = {
|
||||
'total_pixels': len(self.distance_map),
|
||||
'background_model': self.config.background_model,
|
||||
'mean_distance': float(np.mean(self.distance_map)),
|
||||
'std_distance': float(np.std(self.distance_map)),
|
||||
'min_distance': float(np.min(self.distance_map)),
|
||||
'max_distance': float(np.max(self.distance_map)),
|
||||
'median_distance': float(np.median(self.distance_map))
|
||||
}
|
||||
|
||||
# 计算分位数
|
||||
for percentile in [25, 75, 90, 95, 99]:
|
||||
stats[f'percentile_{percentile}'] = float(np.percentile(self.distance_map, percentile))
|
||||
|
||||
if self.anomaly_map is not None:
|
||||
stats['threshold'] = self.config.threshold
|
||||
stats['n_anomalies'] = int(np.sum(self.anomaly_map))
|
||||
stats['anomaly_ratio'] = float(np.sum(self.anomaly_map) / len(self.distance_map))
|
||||
|
||||
return stats
|
||||
|
||||
def suggest_threshold(self, method: str = 'percentile') -> float:
|
||||
"""
|
||||
建议异常检测阈值
|
||||
|
||||
Args:
|
||||
method: 阈值建议方法 ('percentile', 'mean_std', 'auto')
|
||||
|
||||
Returns:
|
||||
建议的阈值
|
||||
"""
|
||||
if self.distance_map is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
if method == 'percentile':
|
||||
# 使用95%分位数作为阈值
|
||||
threshold = np.percentile(self.distance_map, 95)
|
||||
elif method == 'mean_std':
|
||||
# 使用均值+2倍标准差作为阈值
|
||||
mean_dist = np.mean(self.distance_map)
|
||||
std_dist = np.std(self.distance_map)
|
||||
threshold = mean_dist + 2 * std_dist
|
||||
elif method == 'auto':
|
||||
# 使用自动阈值选择(基于卡方分布)
|
||||
# 对于高维数据,自由度近似为波段数
|
||||
n_bands = self.data.shape[1]
|
||||
# 使用卡方分布的95%分位数
|
||||
threshold = stats.chi2.ppf(0.95, n_bands)
|
||||
else:
|
||||
raise ValueError(f"不支持的阈值建议方法: {method}")
|
||||
|
||||
return float(threshold)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:命令行接口"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Reed-Xiaoli (RX) 高光谱异常检测',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
使用示例:
|
||||
python RX.py --input image.bip --background-model global --output results.bip
|
||||
python RX.py --input data.csv --spectral-start 400 --background-model local_square --window-size 7 --output results.bip
|
||||
python RX.py --input image.bip --background-model custom --background background.bip --threshold 10.0 --output results.bip
|
||||
|
||||
背景模型说明:
|
||||
global: 使用整个数据集作为背景
|
||||
local_row: 使用像素周围N行的窗口作为背景
|
||||
local_square: 使用固定大小的正方形窗口作为背景
|
||||
custom: 使用外部提供的背景数据
|
||||
"""
|
||||
)
|
||||
|
||||
# 输入文件参数
|
||||
parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
|
||||
parser.add_argument('--background', help='背景数据路径 (custom模式时必需)')
|
||||
parser.add_argument('--label-col', help='标签列名 (CSV文件)')
|
||||
parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
|
||||
parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
|
||||
|
||||
# 输出参数
|
||||
parser.add_argument('--output', help='输出文件路径')
|
||||
parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')
|
||||
|
||||
# 背景模型参数
|
||||
parser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
|
||||
default='global', help='背景模型类型,默认global')
|
||||
parser.add_argument('--window-size', type=int, default=5,
|
||||
help='局部窗口大小,默认5')
|
||||
parser.add_argument('--row-window', type=int, default=3,
|
||||
help='行窗口大小,默认3')
|
||||
|
||||
# 检测参数
|
||||
parser.add_argument('--threshold', type=float,
|
||||
help='异常检测阈值,不指定则只输出距离图')
|
||||
parser.add_argument('--robust-covariance', action='store_true',
|
||||
help='使用稳健协方差估计')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# 创建配置
|
||||
config = RXConfig(
|
||||
input_path=args.input,
|
||||
background_path=args.background,
|
||||
label_col=args.label_col,
|
||||
spectral_start=args.spectral_start,
|
||||
spectral_end=args.spectral_end,
|
||||
output_path=args.output,
|
||||
output_dir=args.output_dir,
|
||||
background_model=args.background_model,
|
||||
window_size=args.window_size,
|
||||
row_window=args.row_window,
|
||||
threshold=args.threshold,
|
||||
use_robust_covariance=args.robust_covariance
|
||||
)
|
||||
|
||||
# 创建检测器
|
||||
detector = RXAnomalyDetector(config)
|
||||
|
||||
# 执行异常检测流程
|
||||
print("=== Reed-Xiaoli (RX) 异常检测 ===")
|
||||
print(f"输入文件: {config.input_path}")
|
||||
print(f"背景模型: {config.background_model}")
|
||||
if config.background_model in ['local_row', 'local_square']:
|
||||
print(f"窗口大小: {config.window_size}")
|
||||
if config.threshold is not None:
|
||||
print(f"检测阈值: {config.threshold}")
|
||||
print("-" * 50)
|
||||
|
||||
# 1. 加载数据
|
||||
detector.load_data()
|
||||
|
||||
# 2. 预处理数据
|
||||
detector.preprocess_data()
|
||||
|
||||
# 3. 执行异常检测
|
||||
distance_map = detector.detect_anomalies()
|
||||
|
||||
# 4. 阈值建议
|
||||
if config.threshold is None:
|
||||
suggested_threshold = detector.suggest_threshold('percentile')
|
||||
print(f"\n建议阈值 (95百分位数): {suggested_threshold:.4f}")
|
||||
|
||||
# 5. 保存结果
|
||||
detector.save_results()
|
||||
|
||||
# 6. 显示统计信息
|
||||
stats = detector.get_statistics()
|
||||
print("\n=== 检测结果统计 ===")
|
||||
for key, value in stats.items():
|
||||
print(f"{key}: {value}")
|
||||
|
||||
print("\n处理完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user