增加模块;增加主调用命令
This commit is contained in:
890
Anomaly_method/Covariance.py
Normal file
890
Anomaly_method/Covariance.py
Normal file
@ -0,0 +1,890 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# 尝试导入 spectral 库
|
||||
try:
|
||||
import spectral
|
||||
SPECTRAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
SPECTRAL_AVAILABLE = False
|
||||
print("警告: spectral库不可用,将无法处理ENVI格式文件")
|
||||
|
||||
from sklearn.covariance import EllipticEnvelope
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.preprocessing import StandardScaler, RobustScaler
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
class CovarianceAnomalyConfig:
    """
    Configuration for covariance-based anomaly detection.

    Supported input formats:
    - CSV files: the spectral column range must be specified.
    - ENVI format: .bip/.bil/.bsq data files with a matching .hdr header.

    ENVI file requirements:
    - Data file and header file live in the same directory.
    - Header extension is .hdr.
    - Supported data types include byte, int16, uint16, float32, etc.
    """

    # Input file configuration
    input_path: Optional[str] = None
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True

    # Output configuration
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # Anomaly detection parameters
    contamination: float = 0.1  # expected anomaly fraction (0.1 = 10%)
    support_fraction: Optional[float] = None  # MCD support fraction; None = automatic
    assume_centered: bool = False  # whether the data is already centered

    # Preprocessing configuration
    use_dimensionality_reduction: bool = True  # enable dimensionality reduction
    reduction_method: str = 'pca'  # 'pca' or 't-sne'
    n_components: int = 10  # target dimensionality
    use_scaling: bool = True  # enable feature scaling
    scaler_type: str = 'robust'  # 'standard' or 'robust'

    # Random seed
    random_state: int = 42

    def __post_init__(self):
        """Validate parameters and normalize paths.

        BUG FIX: in the original, everything from the support_fraction check
        onward was accidentally placed after the ``return`` inside
        ``_normalize_path`` and therefore never executed. Those validations
        now run here, as intended.
        """
        # Required input path
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        # Clean and normalize paths
        self.input_path = self._normalize_path(self.input_path)
        if self.output_path:
            self.output_path = self._normalize_path(self.output_path)
        if self.output_dir:
            self.output_dir = self._normalize_path(self.output_dir)

        # Existence check (on the normalized path)
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # At least one output destination is required
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")

        # Parameter range checks
        if not 0 < self.contamination <= 0.5:
            raise ValueError("contamination必须在(0, 0.5]范围内")

        if self.support_fraction is not None and not 0.5 <= self.support_fraction <= 1.0:
            raise ValueError("support_fraction必须在[0.5, 1.0]范围内")

        # Normalize method names to lowercase
        self.reduction_method = self.reduction_method.lower()
        self.scaler_type = self.scaler_type.lower()

        # Validate reduction method
        if self.reduction_method not in ['pca', 't-sne']:
            raise ValueError(f"不支持的降维方法: {self.reduction_method}。支持的方法: ['pca', 't-sne']")

        # Validate scaler type
        if self.scaler_type not in ['standard', 'robust']:
            raise ValueError(f"不支持的标准化类型: {self.scaler_type}。支持的类型: ['standard', 'robust']")

    def _normalize_path(self, path: str) -> str:
        """
        Normalize a path: strip quotes, expand the user directory, and make
        it absolute.

        Args:
            path: Input path string.

        Returns:
            Normalized absolute path.
        """
        if not path:
            return path

        # Strip possible surrounding quotes
        path = str(path).strip('\'"')

        # Expand user directory (~)
        path = os.path.expanduser(path)

        # Make absolute
        path = os.path.abspath(path)

        # On Windows, normalize to backslashes
        if os.name == 'nt':
            path = path.replace('/', '\\')

        return path
|
||||
|
||||
|
||||
class CovarianceAnomalyDetector:
|
||||
"""
|
||||
基于协方差估计的异常检测器
|
||||
|
||||
使用scikit-learn的EllipticEnvelope类实现,该类基于Minimum Covariance Determinant (MCD)
|
||||
提供稳健的协方差估计,能够有效处理包含异常值的数据集。
|
||||
"""
|
||||
|
||||
    def __init__(self, config: CovarianceAnomalyConfig):
        """
        Initialize the anomaly detector.

        Args:
            config: Validated configuration object.
        """
        self.config = config
        # Raw data and metadata, populated by load_data()
        self.data = None          # (n_samples, n_bands) float32 array
        self.labels = None        # optional per-sample labels (CSV input only)
        self.wavelengths = None   # band centers, or synthetic indices when absent
        self.data_shape = None    # original shape ((rows, cols, bands) for ENVI)
        self.input_format = None  # 'csv' or 'envi'

        # Preprocessing components, fitted lazily in preprocess_data()
        self.scaler = None
        self.reducer = None

        # Anomaly detection model (EllipticEnvelope), fitted in fit_predict()
        self.detector = None

        # Results, populated by fit_predict()
        self.anomaly_scores = None
        self.predictions = None
        self.mahalanobis_distances = None
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
从配置运行完整的异常检测分析
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||||
"""
|
||||
# 预处理数据
|
||||
processed_data = self.preprocess_data()
|
||||
|
||||
# 执行异常检测
|
||||
predictions, scores = self.fit_predict(processed_data)
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
加载高光谱数据文件
|
||||
|
||||
支持CSV和ENVI格式文件
|
||||
"""
|
||||
# 使用规范化后的路径字符串
|
||||
file_path_str = self.config.input_path
|
||||
file_path = Path(file_path_str)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取文件: {file_path_str}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path)
|
||||
elif suffix in ['.hdr']:
|
||||
self._load_envi_data(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||||
|
||||
def _load_csv_data(self, file_path: Path) -> None:
|
||||
"""加载CSV格式数据"""
|
||||
if self.config.spectral_start is None:
|
||||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||||
|
||||
self.input_format = 'csv'
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||||
|
||||
# 提取标签列
|
||||
if self.config.label_col and self.config.label_col in df.columns:
|
||||
self.labels = df[self.config.label_col].values
|
||||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||||
else:
|
||||
self.labels = None
|
||||
spectral_cols = df.columns.tolist()
|
||||
|
||||
# 提取光谱数据
|
||||
if self.config.spectral_end:
|
||||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||||
else:
|
||||
end_idx = len(spectral_cols)
|
||||
|
||||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||||
|
||||
# 生成波长信息
|
||||
self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||||
|
||||
self.data = spectral_data.astype(np.float32)
|
||||
self.data_shape = self.data.shape
|
||||
|
||||
print(f"CSV数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
def _load_envi_data(self, file_path: Path) -> None:
|
||||
"""加载ENVI格式数据"""
|
||||
self.input_format = 'envi'
|
||||
|
||||
try:
|
||||
# 确保文件路径存在且可访问
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
|
||||
|
||||
# 检查文件是否可读
|
||||
if not os.access(file_path, os.R_OK):
|
||||
raise PermissionError(f"没有读取权限: {file_path}")
|
||||
|
||||
# 使用规范化的路径字符串
|
||||
file_path_str = str(file_path)
|
||||
print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")
|
||||
|
||||
# 检查对应的头文件
|
||||
hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
|
||||
if not hdr_path.exists():
|
||||
# 尝试 .hdr 格式
|
||||
hdr_path = file_path.with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")
|
||||
|
||||
# 读取ENVI文件 - 使用规范化的路径
|
||||
envi_data = spectral.open_image(file_path_str)
|
||||
|
||||
# 获取数据数组
|
||||
print("正在加载数据到内存...")
|
||||
data_array = envi_data.load()
|
||||
|
||||
# 获取波长信息
|
||||
if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
|
||||
self.wavelengths = np.array(envi_data.bands.centers)
|
||||
print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
|
||||
else:
|
||||
self.wavelengths = np.arange(data_array.shape[2])
|
||||
print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")
|
||||
|
||||
# 重塑数据为 (samples, bands)
|
||||
rows, cols, bands = data_array.shape
|
||||
self.data = data_array.reshape(-1, bands).astype(np.float32)
|
||||
self.data_shape = (rows, cols, bands)
|
||||
|
||||
print(f"ENVI数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def preprocess_data(self) -> np.ndarray:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Returns:
|
||||
处理后的数据数组
|
||||
"""
|
||||
print("正在预处理数据...")
|
||||
|
||||
# 检查数据有效性
|
||||
if self.data is None:
|
||||
raise ValueError("请先调用load_data()加载数据")
|
||||
|
||||
# 检查数据维度要求
|
||||
n_samples, n_features = self.data.shape
|
||||
if n_samples <= n_features ** 2:
|
||||
print(f"警告: 样本数({n_samples})小于特征数平方({n_features**2}),建议使用降维")
|
||||
|
||||
# 数据清理:移除NaN和无穷大值
|
||||
data_clean = self._clean_data(self.data)
|
||||
|
||||
# 标准化
|
||||
if self.config.use_scaling:
|
||||
data_scaled = self._scale_data(data_clean)
|
||||
else:
|
||||
data_scaled = data_clean
|
||||
|
||||
# 降维
|
||||
if self.config.use_dimensionality_reduction:
|
||||
data_processed = self._reduce_dimensionality(data_scaled)
|
||||
else:
|
||||
data_processed = data_scaled
|
||||
|
||||
print(f"数据预处理完成: {data_processed.shape}")
|
||||
return data_processed
|
||||
|
||||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""清理数据:移除无效值"""
|
||||
# 移除包含NaN或无穷大值的行
|
||||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||||
data_clean = data[valid_mask]
|
||||
|
||||
if np.sum(~valid_mask) > 0:
|
||||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||||
|
||||
# 移除全零行
|
||||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||||
data_clean = data_clean[non_zero_mask]
|
||||
|
||||
if np.sum(~non_zero_mask) > 0:
|
||||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||||
|
||||
return data_clean
|
||||
|
||||
def _scale_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""标准化数据"""
|
||||
if self.config.scaler_type == 'robust':
|
||||
self.scaler = RobustScaler()
|
||||
else:
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
data_scaled = self.scaler.fit_transform(data)
|
||||
return data_scaled
|
||||
|
||||
    def _reduce_dimensionality(self, data: np.ndarray) -> np.ndarray:
        """Reduce feature dimensionality with PCA or t-SNE per the configuration.

        Returns the input unchanged when it is already at or below the
        target dimensionality.
        """
        n_samples, n_features = data.shape

        # No-op when reduction would not shrink the data
        if n_features <= self.config.n_components:
            print(f"特征数({n_features})小于等于目标维度({self.config.n_components}),跳过降维")
            return data

        if self.config.reduction_method == 'pca':
            self.reducer = PCA(
                n_components=self.config.n_components,
                random_state=self.config.random_state
            )
        elif self.config.reduction_method == 't-sne':
            # t-SNE needs substantially more samples than output dimensions
            if n_samples < 4 * self.config.n_components:
                print(f"警告: t-SNE需要至少4倍于目标维度的样本数,使用PCA代替")
                self.reducer = PCA(
                    n_components=self.config.n_components,
                    random_state=self.config.random_state
                )
            else:
                # NOTE(review): sklearn's TSNE with the default 'barnes_hut'
                # method only supports n_components < 4; the default
                # n_components=10 would raise at fit time — confirm.
                self.reducer = TSNE(
                    n_components=self.config.n_components,
                    random_state=self.config.random_state,
                    perplexity=min(30, n_samples // 3)
                )

        data_reduced = self.reducer.fit_transform(data)

        # Report explained variance when the reducer exposes it (PCA only)
        explained_var = None
        if hasattr(self.reducer, 'explained_variance_ratio_'):
            explained_var = np.sum(self.reducer.explained_variance_ratio_)
            print(f"降维完成,解释方差比例: {explained_var:.3f}")

        return data_reduced
|
||||
|
||||
def fit_predict(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
训练异常检测模型并进行预测
|
||||
|
||||
Args:
|
||||
data: 输入数据数组 (n_samples, n_features)
|
||||
|
||||
Returns:
|
||||
predictions: 预测结果 (+1为正常,-1为异常)
|
||||
scores: 异常分数 (马氏距离)
|
||||
"""
|
||||
print("正在训练异常检测模型...")
|
||||
|
||||
# 初始化异常检测器
|
||||
self.detector = EllipticEnvelope(
|
||||
contamination=self.config.contamination,
|
||||
support_fraction=self.config.support_fraction,
|
||||
assume_centered=self.config.assume_centered,
|
||||
random_state=self.config.random_state
|
||||
)
|
||||
|
||||
# 训练模型
|
||||
self.detector.fit(data)
|
||||
|
||||
# 预测异常
|
||||
predictions = self.detector.predict(data)
|
||||
|
||||
# 计算马氏距离作为异常分数
|
||||
# 使用predict_proba或decision_function获取分数
|
||||
try:
|
||||
# 尝试获取决策函数值
|
||||
scores = self.detector.decision_function(data)
|
||||
# 转换为马氏距离的近似值(负值表示异常)
|
||||
scores = -scores
|
||||
except:
|
||||
# 如果decision_function不可用,使用距离计算
|
||||
scores = self._compute_mahalanobis_distances(data)
|
||||
|
||||
self.predictions = predictions
|
||||
self.anomaly_scores = scores
|
||||
self.mahalanobis_distances = self._compute_mahalanobis_distances(data)
|
||||
|
||||
# 统计信息
|
||||
n_anomalies = np.sum(predictions == -1)
|
||||
anomaly_ratio = n_anomalies / len(predictions)
|
||||
|
||||
print(f"异常检测完成:")
|
||||
print(f" - 总样本数: {len(predictions)}")
|
||||
print(f" - 异常样本数: {n_anomalies}")
|
||||
print(f" - 异常比例: {anomaly_ratio:.3f}")
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def _compute_mahalanobis_distances(self, data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
计算马氏距离
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
马氏距离数组
|
||||
"""
|
||||
if self.detector is None:
|
||||
raise ValueError("请先调用fit_predict()训练模型")
|
||||
|
||||
# 获取协方差矩阵和均值
|
||||
cov_matrix = self.detector.covariance_
|
||||
center = self.detector.location_
|
||||
|
||||
# 计算马氏距离
|
||||
diff = data - center
|
||||
inv_cov = np.linalg.inv(cov_matrix)
|
||||
mahalanobis_sq = np.sum(diff * (diff @ inv_cov), axis=1)
|
||||
mahalanobis_dist = np.sqrt(mahalanobis_sq)
|
||||
|
||||
return mahalanobis_dist
|
||||
|
||||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||||
"""
|
||||
保存异常检测结果为ENVI格式文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||||
"""
|
||||
if output_path is None:
|
||||
if self.config.output_path:
|
||||
output_path = self.config.output_path
|
||||
elif self.config.output_dir:
|
||||
base_name = Path(self.config.input_path).stem
|
||||
output_path = str(Path(self.config.output_dir) / f"{base_name}_anomaly_covariance.bip")
|
||||
else:
|
||||
raise ValueError("必须指定输出路径")
|
||||
|
||||
# 清理路径中的可能问题字符
|
||||
output_path = str(output_path).strip('\'"')
|
||||
|
||||
print(f"正在保存结果到: {output_path}")
|
||||
|
||||
# 创建输出目录
|
||||
output_path_obj = Path(output_path)
|
||||
output_dir = output_path_obj.parent
|
||||
|
||||
if output_dir and not output_dir.exists():
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
print(f"创建输出目录: {output_dir}")
|
||||
|
||||
# 根据输入格式决定输出格式
|
||||
if self.input_format == 'envi':
|
||||
self._save_envi_results(output_path)
|
||||
else:
|
||||
# 对于CSV输入,创建虚拟的2D图像
|
||||
self._save_csv_results_as_envi(output_path)
|
||||
|
||||
def _save_envi_results(self, output_path: str) -> None:
|
||||
"""保存ENVI格式结果,使用spectral库"""
|
||||
# 重塑预测结果为原始图像尺寸
|
||||
rows, cols, bands = self.data_shape
|
||||
|
||||
# 创建异常检测结果图像
|
||||
# 结果包括:预测结果、马氏距离、异常分数
|
||||
anomaly_image = np.zeros((rows, cols, 3), dtype=np.float32)
|
||||
|
||||
# 需要将1D结果映射回2D图像
|
||||
# 假设数据是按行优先顺序展平的
|
||||
predictions_2d = self.predictions.reshape(rows, cols)
|
||||
distances_2d = self.mahalanobis_distances.reshape(rows, cols)
|
||||
scores_2d = self.anomaly_scores.reshape(rows, cols)
|
||||
|
||||
anomaly_image[:, :, 0] = predictions_2d # 预测结果 (-1, 1)
|
||||
anomaly_image[:, :, 1] = distances_2d # 马氏距离
|
||||
anomaly_image[:, :, 2] = scores_2d # 异常分数
|
||||
|
||||
# 波段名称
|
||||
band_names = ["Anomaly_Prediction", "Mahalanobis_Distance", "Anomaly_Score"]
|
||||
|
||||
# 保存为ENVI格式,参考spectral2cie2.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||||
band_names: List[str], interleave: str = 'bip') -> None:
|
||||
"""
|
||||
保存数据为ENVI格式,参考edge_detect.py的实现
|
||||
|
||||
参数:
|
||||
data: 要保存的数据数组 (height, width, channels)
|
||||
output_path: 输出文件路径
|
||||
band_names: 波段名称列表
|
||||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
height, width, channels = data.shape
|
||||
|
||||
# 确保目录存在
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||||
|
||||
# 使用edge_detect.py的方式保存
|
||||
output_base = str(output_path.with_suffix('')) # 移除扩展名并转换为字符串
|
||||
|
||||
# 保存.dat文件(二进制格式)
|
||||
self._save_dat_file(data, str(output_path))
|
||||
|
||||
# 保存.hdr头文件
|
||||
hdr_path = output_base + '.hdr'
|
||||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||||
|
||||
print(f"异常检测结果已保存: {output_path}")
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
print(f"使用的交织格式: {interleave.upper()}")
|
||||
|
||||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||||
"""保存.dat文件(二进制格式),参考edge_detect.py"""
|
||||
# 根据数据类型保存
|
||||
if data.dtype == np.float32:
|
||||
# 对于浮点数据,直接保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
else:
|
||||
# 对于其他数据类型,转换为float32保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
|
||||
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
|
||||
band_names: List[str], interleave: str) -> None:
|
||||
"""保存ENVI头文件,参考edge_detect.py并适配多波段数据"""
|
||||
height, width, channels = data_shape
|
||||
|
||||
# 确定数据类型编码
|
||||
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
|
||||
data_type_code = 4 # 默认float32
|
||||
|
||||
header_content = f"""ENVI
|
||||
description = {{Covariance Anomaly Detection Results - Generated by CovarianceAnomalyDetector}}
|
||||
samples = {width}
|
||||
lines = {height}
|
||||
bands = {channels}
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = {data_type_code}
|
||||
interleave = {interleave}
|
||||
byte order = 0
|
||||
"""
|
||||
|
||||
# 添加波段名称
|
||||
if band_names:
|
||||
header_content += "band names = {\n"
|
||||
for i, name in enumerate(band_names):
|
||||
header_content += f' "{name}"'
|
||||
if i < len(band_names) - 1:
|
||||
header_content += ","
|
||||
header_content += "\n"
|
||||
header_content += "}\n"
|
||||
|
||||
# 添加异常检测相关元数据
|
||||
header_content += f"anomaly_detection_method = covariance\n"
|
||||
header_content += f"contamination = {self.config.contamination}\n"
|
||||
if self.config.support_fraction is not None:
|
||||
header_content += f"support_fraction = {self.config.support_fraction}\n"
|
||||
header_content += f"dimensionality_reduction = {self.config.use_dimensionality_reduction}\n"
|
||||
if self.config.use_dimensionality_reduction:
|
||||
header_content += f"reduction_method = {self.config.reduction_method}\n"
|
||||
header_content += f"n_components = {self.config.n_components}\n"
|
||||
|
||||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||||
f.write(header_content)
|
||||
|
||||
    def _create_envi_hdr_file(self, bil_path: Union[str, Path], height: int, width: int,
                              bands: int, data_type: str, interleave: str,
                              band_names: List[str], input_hdr_path: Optional[Union[str, Path]] = None) -> None:
        """
        Create an ENVI header file next to a binary data file.

        Args:
            bil_path: Path of the binary data file; the header is written
                next to it as "<name>.<ext>.hdr".
            height: Image height (lines).
            width: Image width (samples).
            bands: Number of bands.
            data_type: NumPy-style dtype name ('float32', 'uint8', 'int16', ...).
            interleave: Interleave declared in the header ('bip', 'bil', 'bsq').
            band_names: One name per band.
            input_hdr_path: Optional source header whose geo/wavelength
                metadata is appended to the new header.
        """
        # Header uses the "<data file>.hdr" double-extension convention
        bil_path_obj = Path(bil_path)
        hdr_path = bil_path_obj.with_suffix(bil_path_obj.suffix + '.hdr')

        # dtype name -> ENVI data-type code
        dtype_map = {
            'uint8': '1',
            'int16': '2',
            'int32': '3',
            'float32': '4',
            'float64': '5',
            'complex64': '6',
            'complex128': '9',
            'uint16': '12',
            'uint32': '13',
            'int64': '14',
            'uint64': '15'
        }

        # Unknown dtype names fall back to float32
        envi_dtype = dtype_map.get(data_type, '4')

        with open(hdr_path, 'w') as f:
            f.write("ENVI\n")
            f.write("description = {\n")
            f.write("  Covariance Anomaly Detection Results - Generated by CovarianceAnomalyDetector\n")
            f.write(f"  Contamination: {self.config.contamination}, Support Fraction: {self.config.support_fraction}\n")
            f.write("}\n")
            f.write(f"samples = {width}\n")
            f.write(f"lines = {height}\n")
            f.write(f"bands = {bands}\n")
            f.write("header offset = 0\n")
            f.write("file type = ENVI Standard\n")
            f.write(f"data type = {envi_dtype}\n")
            f.write(f"interleave = {interleave}\n")
            f.write("sensor type = Unknown\n")
            f.write("byte order = 0\n")

            # Band-name block: comma after every entry except the last
            f.write("band names = {\n")
            for i, name in enumerate(band_names):
                f.write(f' "{name}"')
                if i < len(band_names) - 1:
                    f.write(",")
                f.write("\n")
            f.write("}\n")

            # Best-effort copy of metadata from the source header; failures
            # are reported but do not abort the write
            if input_hdr_path and Path(input_hdr_path).exists():
                try:
                    self._copy_hdr_metadata(input_hdr_path, f)
                except Exception as e:
                    print(f"复制HDR元数据失败: {e}")

        print(f"ENVI头文件创建完成: {hdr_path}")
|
||||
|
||||
    def _copy_hdr_metadata(self, input_hdr_path: Union[str, Path], output_file) -> None:
        """
        Copy selected metadata fields from a source ENVI header into an
        already-open output header file.

        Args:
            input_hdr_path: Path of the source .hdr file.
            output_file: Writable file object positioned at the end of the
                new header.
        """
        try:
            with open(input_hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()

            # Scan the source header line by line for fields worth keeping
            lines = content.split('\n')
            metadata_to_copy = []

            # Fields to carry over (matched as substrings of the key)
            fields_to_copy = [
                'wavelength units', 'wavelength', 'fwhm', 'bbl', 'map info',
                'coordinate system string', 'projection info', 'pixel size',
                'acquisition time', 'sensor type', 'radiance scale factor',
                'reflectance scale factor', 'data gain values', 'data offset values'
            ]

            i = 0
            while i < len(lines):
                line = lines[i].strip()
                if '=' in line:
                    key = line.split('=')[0].strip().lower()
                    if any(field in key for field in fields_to_copy):
                        # Copy this field plus any continuation lines
                        metadata_to_copy.append(lines[i])
                        i += 1
                        # Multi-line values ({...} blocks) continue until the
                        # next "key =" line that does not end with a comma.
                        # NOTE(review): this heuristic may over- or
                        # under-capture on unusual headers — confirm against
                        # real ENVI files.
                        while i < len(lines) and not ('=' in lines[i] and not lines[i].strip().endswith(',')):
                            if lines[i].strip():
                                metadata_to_copy.append(lines[i])
                            i += 1
                        if i >= len(lines):
                            break
                        continue
                i += 1

            # Append the collected metadata to the output header
            if metadata_to_copy:
                output_file.write("\n")
                for line in metadata_to_copy:
                    if line.strip():
                        output_file.write(line + "\n")

                print(f"已从输入HDR文件复制 {len(metadata_to_copy)} 行元数据")

        except Exception as e:
            # Best-effort: a bad source header must not abort the save
            print(f"读取输入HDR文件失败: {e}")
|
||||
|
||||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||||
"""将CSV结果保存为ENVI格式"""
|
||||
# 对于CSV数据,创建1xN的图像
|
||||
n_samples = len(self.predictions)
|
||||
anomaly_image = np.zeros((1, n_samples, 3), dtype=np.float32)
|
||||
|
||||
anomaly_image[0, :, 0] = self.predictions # 预测结果
|
||||
anomaly_image[0, :, 1] = self.mahalanobis_distances # 马氏距离
|
||||
anomaly_image[0, :, 2] = self.anomaly_scores # 异常分数
|
||||
|
||||
# 波段名称
|
||||
band_names = ["Anomaly_Prediction", "Mahalanobis_Distance", "Anomaly_Score"]
|
||||
|
||||
# 保存为ENVI格式,参考spectral2cie2.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取异常检测统计信息
|
||||
|
||||
Returns:
|
||||
包含各种统计信息的字典
|
||||
"""
|
||||
if self.predictions is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
stats = {
|
||||
'total_samples': len(self.predictions),
|
||||
'n_anomalies': int(np.sum(self.predictions == -1)),
|
||||
'anomaly_ratio': float(np.sum(self.predictions == -1) / len(self.predictions)),
|
||||
'mean_mahalanobis_distance': float(np.mean(self.mahalanobis_distances)),
|
||||
'std_mahalanobis_distance': float(np.std(self.mahalanobis_distances)),
|
||||
'min_mahalanobis_distance': float(np.min(self.mahalanobis_distances)),
|
||||
'max_mahalanobis_distance': float(np.max(self.mahalanobis_distances)),
|
||||
'contamination': self.config.contamination,
|
||||
'reduction_method': self.config.reduction_method if self.config.use_dimensionality_reduction else None,
|
||||
'n_components': self.config.n_components if self.config.use_dimensionality_reduction else None
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments, run detection, save results.

    Returns:
        0 on success, 1 on any error (printed to stdout).
    """
    parser = argparse.ArgumentParser(
        description='基于协方差估计的高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  python Covariance.py --input data.csv --spectral-start 400 --output results.bip
  python Covariance.py --input image.bip --contamination 0.15 --reduction-method t-sne --output results.bip
        """
    )

    # Input-file arguments
    parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式,如.bip/.bil/.bsq文件)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')

    # Output arguments (one of the two is required; validated by the config)
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # Detection parameters
    parser.add_argument('--contamination', type=float, default=0.1,
                        help='异常点比例 (0, 0.5],默认0.1')
    parser.add_argument('--support-fraction', type=float,
                        help='MCD支持比例 [0.5, 1.0],默认自动选择')

    # Preprocessing parameters
    parser.add_argument('--no-reduction', action='store_true',
                        help='禁用降维处理')
    parser.add_argument('--reduction-method', choices=['pca', 't-sne'], default='pca',
                        help='降维方法,默认pca')
    parser.add_argument('--n-components', type=int, default=10,
                        help='降维后的维度,默认10')
    parser.add_argument('--no-scaling', action='store_true',
                        help='禁用数据标准化')
    parser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
                        help='标准化类型,默认robust')

    # Miscellaneous
    parser.add_argument('--random-state', type=int, default=42,
                        help='随机种子,默认42')

    args = parser.parse_args()

    try:
        # Build the validated configuration (path checks happen here)
        config = CovarianceAnomalyConfig(
            input_path=args.input,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            contamination=args.contamination,
            support_fraction=args.support_fraction,
            use_dimensionality_reduction=not args.no_reduction,
            reduction_method=args.reduction_method,
            n_components=args.n_components,
            use_scaling=not args.no_scaling,
            scaler_type=args.scaler_type,
            random_state=args.random_state
        )

        # Instantiate the detector
        detector = CovarianceAnomalyDetector(config)

        # Run the detection pipeline
        print("=== 基于协方差估计的异常检测 ===")
        print(f"输入文件: {config.input_path}")
        print(f"异常比例: {config.contamination}")
        print(f"降维方法: {config.reduction_method if config.use_dimensionality_reduction else '无'}")
        print("-" * 50)

        # 1. Load data
        detector.load_data()

        # 2. Preprocess
        processed_data = detector.preprocess_data()

        # 3. Detect anomalies
        predictions, scores = detector.fit_predict(processed_data)

        # 4. Save results
        detector.save_results()

        # 5. Print summary statistics
        stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in stats.items():
            print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        # Top-level boundary: report and signal failure via exit code
        print(f"错误: {str(e)}")
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # FIX: use SystemExit instead of the bare exit() builtin — exit() is a
    # site-module convenience that may be absent (e.g. under `python -S`).
    raise SystemExit(main())
|
||||
807
Anomaly_method/One_Class_SVM.py
Normal file
807
Anomaly_method/One_Class_SVM.py
Normal file
@ -0,0 +1,807 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spectral
|
||||
from sklearn.svm import OneClassSVM
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.model_selection import GridSearchCV, cross_val_score
|
||||
from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from scipy import stats
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
class OneClassSVMConfig:
    """Configuration for One-Class SVM anomaly detection."""

    # Input file configuration
    input_path: Optional[str] = None
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True

    # Output configuration
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # SVM parameters
    kernel: str = 'rbf'  # kernel function: 'linear', 'rbf', 'poly'
    nu: float = 0.1  # upper bound on training-error fraction (0.01-0.5)
    gamma: Optional[Union[str, float]] = 'auto'  # kernel coefficient: 'auto', 'scale', or a number
    degree: int = 3  # polynomial-kernel degree
    coef0: float = 0.0  # independent term in the kernel
    tol: float = 1e-3  # stopping-criterion tolerance
    shrinking: bool = True  # use the shrinking heuristic
    cache_size: float = 200  # kernel cache size (MB)
    max_iter: int = -1  # iteration limit (-1 = no limit)

    # Preprocessing configuration
    use_scaling: bool = True  # enable feature scaling
    use_pca: bool = True  # enable PCA reduction
    n_components: Optional[int] = None  # PCA dimensionality; None = automatic

    # Hyperparameter-tuning configuration
    use_grid_search: bool = False  # enable grid-search tuning
    cv_folds: int = 3  # number of cross-validation folds

    def __post_init__(self):
        """Validate parameters and set derived defaults."""
        # Required input path
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        # Existence check
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # At least one output destination is required
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")

        # nu range check
        if not 0.01 <= self.nu <= 0.5:
            raise ValueError("nu必须在[0.01, 0.5]范围内")

        # Normalize kernel name to lowercase
        self.kernel = self.kernel.lower()

        # Validate kernel type
        supported_kernels = ['linear', 'rbf', 'poly']
        if self.kernel not in supported_kernels:
            raise ValueError(f"不支持的核函数: {self.kernel}。支持的核函数: {supported_kernels}")

        # Validate gamma: a string must be 'auto'/'scale', a number must be > 0
        if isinstance(self.gamma, str):
            if self.gamma not in ['auto', 'scale']:
                raise ValueError("gamma字符串值必须是'auto'或'scale'")
        elif self.gamma is not None and self.gamma <= 0:
            raise ValueError("gamma数值必须大于0")

        # Validate polynomial degree
        if self.degree < 1:
            raise ValueError("degree必须大于等于1")
|
||||
|
||||
|
||||
class OneClassSVMAnomalyDetector:
|
||||
"""
|
||||
One-Class SVM异常检测器
|
||||
|
||||
使用One-Class SVM算法进行无监督异常检测,只需要正常数据进行训练。
|
||||
支持多种核函数,能够学习复杂的数据分布边界。
|
||||
|
||||
算法原理:
|
||||
- 使用正常数据训练SVM,学习正常样本的决策边界
|
||||
- 测试样本落在边界内为正常(-1),落在边界外为异常(+1)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OneClassSVMConfig):
    """Initialize detector state from a validated configuration.

    Args:
        config: OneClassSVMConfig instance (already validated by its
            __post_init__).
    """
    self.config = config
    self.data = None  # (n_samples, n_bands) spectra, filled by load_data()
    self.labels = None  # optional labels from the CSV label column
    self.wavelengths = None  # band centers, or index fallback
    self.data_shape = None  # original shape; (rows, cols, bands) for ENVI input
    self.input_format = None  # 'csv' or 'envi'

    # Preprocessing components (created lazily on first use).
    self.scaler = None
    self.pca = None

    # One-Class SVM model (created in train_model()).
    self.svm_model = None

    # Detection results.
    self.decision_scores = None  # decision-function values
    self.predictions = None  # +1 normal, -1 anomaly

    # Best parameters found by grid search, if it was run.
    self.best_params = None
|
||||
|
||||
def load_data(self) -> None:
    """Load the hyperspectral input file (CSV or ENVI header) into memory.

    Dispatches to the CSV or ENVI loader based on the file extension.

    Raises:
        ValueError: for any extension other than .csv or .hdr.
    """
    src = Path(self.config.input_path)
    ext = src.suffix.lower()

    print(f"正在读取文件: {src}")

    # Dispatch on the extension; only .csv and .hdr are understood.
    if ext == '.csv':
        self._load_csv_data(src)
        return
    if ext == '.hdr':
        self._load_envi_data(src)
        return
    raise ValueError(f"不支持的文件格式: {ext}。支持的格式: .hdr")
|
||||
|
||||
def _load_csv_data(self, file_path: Path) -> None:
    """Load spectra from a CSV file into self.data / self.labels.

    Spectral columns run from config.spectral_start through (inclusive)
    config.spectral_end; when spectral_end is unset, through the last column.

    Args:
        file_path: path to the CSV file.

    Raises:
        ValueError: if config.spectral_start is not set, or (via list.index)
            if a configured column name is not present.
    """
    if self.config.spectral_start is None:
        raise ValueError("对于CSV文件,必须指定spectral_start参数")

    self.input_format = 'csv'

    # Read with or without a header row, per configuration.
    df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)

    # Split off the label column when configured and present; a configured
    # but missing label column is silently ignored.
    if self.config.label_col and self.config.label_col in df.columns:
        self.labels = df[self.config.label_col].values
        spectral_cols = [col for col in df.columns if col != self.config.label_col]
    else:
        self.labels = None
        spectral_cols = df.columns.tolist()

    # Resolve the (inclusive) spectral column range.
    if self.config.spectral_end:
        end_idx = spectral_cols.index(self.config.spectral_end) + 1
    else:
        end_idx = len(spectral_cols)

    start_idx = spectral_cols.index(self.config.spectral_start)
    spectral_data = df[spectral_cols[start_idx:end_idx]].values

    # CSVs carry no physical wavelengths; fall back to band indices.
    self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))

    self.data = spectral_data.astype(np.float32)
    self.data_shape = self.data.shape

    print(f"CSV数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
def _load_envi_data(self, file_path: Path) -> None:
    """Load an ENVI image via the spectral library into self.data.

    The image is flattened to (rows*cols, bands) in self.data; the original
    geometry is kept in self.data_shape for result reshaping.

    Args:
        file_path: path to the ENVI header (.hdr) file.

    Raises:
        ValueError: wrapping any underlying failure (missing file, missing
            read permission, spectral-library errors).
    """
    self.input_format = 'envi'

    try:
        # Validate path and permissions before handing off to spectral.
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"ENVI文件不存在: {file_path}")

        if not os.access(file_path, os.R_OK):
            raise PermissionError(f"没有读取权限: {file_path}")

        # spectral.open_image takes a plain string path.
        file_path_str = str(file_path)
        print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")

        # Look for a companion header: first '<name>.<ext>.hdr', then
        # '<name>.hdr'. A missing header only warns — spectral may still
        # resolve the file on its own.
        hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
        if not hdr_path.exists():
            hdr_path = file_path.with_suffix('.hdr')
            if not hdr_path.exists():
                print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")

        envi_data = spectral.open_image(file_path_str)

        # Materialize the full cube in memory.
        print("正在加载数据到内存...")
        data_array = envi_data.load()

        # Prefer real band centers from the header; otherwise index bands.
        if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
            self.wavelengths = np.array(envi_data.bands.centers)
            print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
        else:
            self.wavelengths = np.arange(data_array.shape[2])
            print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")

        # Flatten pixels to rows: (rows*cols, bands).
        rows, cols, bands = data_array.shape
        self.data = data_array.reshape(-1, bands).astype(np.float32)
        self.data_shape = (rows, cols, bands)

        print(f"ENVI数据加载完成: {rows}x{cols} 像素 x {bands} 波段")

    except Exception as e:
        raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def preprocess_data(self, data: np.ndarray) -> np.ndarray:
    """Run the preprocessing pipeline: clean -> optional scaling -> optional PCA.

    Args:
        data: raw spectral array of shape (n_samples, n_bands).

    Returns:
        The preprocessed array.

    Raises:
        ValueError: if *data* is None (load_data() was never called).
    """
    print("正在预处理数据...")

    if data is None:
        raise ValueError("请先调用load_data()加载数据")

    # Step 1: drop rows containing NaN/inf values or only zeros.
    processed = self._clean_data(data)

    # Step 2: optional feature standardization.
    if self.config.use_scaling:
        processed = self._scale_data(processed)

    # Step 3: optional PCA dimensionality reduction.
    if self.config.use_pca:
        processed = self._apply_pca(processed)

    print(f"数据预处理完成: {processed.shape}")
    return processed
|
||||
|
||||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""清理数据:移除无效值"""
|
||||
# 移除包含NaN或无穷大值的行
|
||||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||||
data_clean = data[valid_mask]
|
||||
|
||||
if np.sum(~valid_mask) > 0:
|
||||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||||
|
||||
# 移除全零行
|
||||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||||
data_clean = data_clean[non_zero_mask]
|
||||
|
||||
if np.sum(~non_zero_mask) > 0:
|
||||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||||
|
||||
return data_clean
|
||||
|
||||
def _scale_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""标准化数据"""
|
||||
if self.scaler is None:
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
data_scaled = self.scaler.fit_transform(data)
|
||||
return data_scaled
|
||||
|
||||
def _apply_pca(self, data: np.ndarray) -> np.ndarray:
|
||||
"""应用PCA降维"""
|
||||
n_samples, n_features = data.shape
|
||||
|
||||
# 确定PCA维度
|
||||
if self.config.n_components is None:
|
||||
# 自动选择维度:保留95%的方差或最小维度
|
||||
n_components = min(n_features, max(2, int(n_features * 0.95)))
|
||||
else:
|
||||
n_components = min(self.config.n_components, n_features)
|
||||
|
||||
if n_features <= n_components:
|
||||
print(f"特征数({n_features})小于等于目标维度({n_components}),跳过PCA降维")
|
||||
return data
|
||||
|
||||
if self.pca is None:
|
||||
self.pca = PCA(n_components=n_components, random_state=42)
|
||||
|
||||
data_pca = self.pca.fit_transform(data)
|
||||
|
||||
explained_var = np.sum(self.pca.explained_variance_ratio_)
|
||||
print(f"PCA降维完成: {n_features} -> {n_components} 维度,解释方差: {explained_var:.3f}")
|
||||
|
||||
return data_pca
|
||||
|
||||
def tune_parameters(self, data: np.ndarray) -> Dict[str, Any]:
    """Grid-search One-Class SVM hyper-parameters on *data*.

    Args:
        data: training data (already preprocessed).

    Returns:
        Dict of the best parameters found; also stored in self.best_params.

    NOTE(review): One-Class SVM is unsupervised, and scoring by the negative
    mean decision value in cross-validation is a heuristic rather than a
    validation metric — confirm this is the intended selection criterion.
    """
    print("正在进行参数调优...")

    # Candidate grid: nu always; kernel-specific parameters as applicable.
    param_grid = {
        'nu': [0.01, 0.05, 0.1, 0.15, 0.2]
    }

    if self.config.kernel == 'rbf':
        param_grid['gamma'] = ['auto', 'scale', 0.001, 0.01, 0.1, 1.0]
    elif self.config.kernel == 'poly':
        param_grid['gamma'] = ['auto', 'scale', 0.001, 0.01, 0.1, 1.0]
        param_grid['degree'] = [2, 3, 4]
        param_grid['coef0'] = [0.0, 0.1, 1.0]

    # Base estimator carries the non-searched parameters.
    base_svm = OneClassSVM(
        kernel=self.config.kernel,
        degree=self.config.degree,
        coef0=self.config.coef0,
        tol=self.config.tol,
        shrinking=self.config.shrinking,
        cache_size=self.config.cache_size,
        max_iter=self.config.max_iter
    )

    # Score by the negative mean decision value (more negative decision
    # values indicate more anomalous samples).
    def anomaly_score(estimator, X):
        scores = estimator.decision_function(X)
        return -np.mean(scores)  # negative mean decision value

    # NOTE(review): make_scorer expects score_func(y_true, y_pred); a
    # callable with signature (estimator, X, y) can be passed directly to
    # GridSearchCV's `scoring=` — verify this wrapping behaves as intended.
    scorer = make_scorer(anomaly_score)

    grid_search = GridSearchCV(
        base_svm,
        param_grid,
        cv=self.config.cv_folds,
        scoring=scorer,
        n_jobs=-1,
        verbose=1
    )

    grid_search.fit(data)

    self.best_params = grid_search.best_params_
    print(f"最佳参数: {self.best_params}")
    print(f"最佳评分: {grid_search.best_score_:.4f}")

    return self.best_params
|
||||
|
||||
def train_model(self, data: np.ndarray) -> None:
    """Fit the One-Class SVM on (preprocessed) normal training data.

    Optionally runs grid search first and adopts the winning parameters.

    Args:
        data: training array of shape (n_samples, n_features).
    """
    print("正在训练One-Class SVM模型...")

    # Optional hyper-parameter tuning; winners overwrite the config.
    if self.config.use_grid_search:
        best_params = self.tune_parameters(data)
        for param, value in best_params.items():
            setattr(self.config, param, value)

    # Resolve gamma keywords to concrete values (mirrors sklearn semantics:
    # 'auto' == 1/n_features, 'scale' == 1/(n_features * X.var())).
    gamma_value = self.config.gamma
    if self.config.gamma == 'auto':
        gamma_value = 1.0 / data.shape[1]
    elif self.config.gamma == 'scale':
        gamma_value = 1.0 / (data.shape[1] * data.var())
    elif self.config.gamma is None:
        gamma_value = 'scale'  # sklearn's own default

    self.svm_model = OneClassSVM(
        kernel=self.config.kernel,
        nu=self.config.nu,
        gamma=gamma_value,
        degree=self.config.degree,
        coef0=self.config.coef0,
        tol=self.config.tol,
        shrinking=self.config.shrinking,
        cache_size=self.config.cache_size,
        max_iter=self.config.max_iter
    )

    self.svm_model.fit(data)

    # Training-set diagnostics.
    train_scores = self.svm_model.decision_function(data)
    train_predictions = self.svm_model.predict(data)

    n_outliers_train = np.sum(train_predictions == -1)
    outlier_ratio_train = n_outliers_train / len(data)

    print("模型训练完成:")
    print(f" - 核函数: {self.config.kernel}")
    print(f" - nu: {self.config.nu}")
    print(f" - gamma: {gamma_value}")
    if self.config.kernel == 'poly':
        print(f" - degree: {self.config.degree}")
    print(f" - 训练集异常比例: {outlier_ratio_train:.3f}")
    # Bug fix: the original printed the literal string ".4f" here (a broken
    # f-string); report the mean training decision score instead.
    print(f" - 训练集平均决策分数: {np.mean(train_scores):.4f}")
|
||||
|
||||
def detect_anomalies(self, data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Predict anomalies with the trained One-Class SVM.

    Args:
        data: raw test data; defaults to the loaded training data. The data
            is passed through preprocess_data() here, so callers should hand
            in *unprocessed* spectra.

    Returns:
        predictions: +1 for normal samples, -1 for anomalies.
        scores: decision-function values (more negative == more anomalous).

    Raises:
        ValueError: if the model has not been trained yet.
    """
    if self.svm_model is None:
        raise ValueError("请先调用train_model()训练模型")

    if data is None:
        data = self.data

    print("正在检测异常...")

    # NOTE(review): preprocess_data() is run on every call; with lazily
    # fitted scaler/PCA the training transforms are reused, but callers must
    # never pass already-preprocessed data (it would be transformed twice).
    data_processed = self.preprocess_data(data)

    predictions = self.svm_model.predict(data_processed)
    scores = self.svm_model.decision_function(data_processed)

    self.predictions = predictions
    self.decision_scores = scores

    # Summary statistics.
    n_anomalies = np.sum(predictions == -1)
    anomaly_ratio = n_anomalies / len(predictions)

    print("异常检测完成:")
    print(f" - 测试样本数: {len(predictions)}")
    print(f" - 异常样本数: {n_anomalies}")
    print(f" - 异常比例: {anomaly_ratio:.3f}")
    # Bug fix: the original printed the literal ".4f" twice (broken
    # f-strings); report the score statistics that were evidently intended.
    print(f" - 决策分数均值: {np.mean(scores):.4f}")
    print(f" - 决策分数标准差: {np.std(scores):.4f}")

    return predictions, scores
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
    """End-to-end pipeline: preprocess -> train -> detect on loaded data.

    Returns:
        Tuple of (predictions, decision scores) from detect_anomalies().
    """
    # Fit the preprocessing components and train on the result.
    features = self.preprocess_data(self.data)
    self.train_model(features)

    # Detect on the loaded data (detect_anomalies preprocesses internally).
    return self.detect_anomalies()
|
||||
|
||||
def save_results(self, output_path: Optional[str] = None) -> None:
    """Persist detection results as an ENVI-format file.

    Args:
        output_path: explicit target path; falls back to config.output_path,
            then to an auto-generated name inside config.output_dir.

    Raises:
        ValueError: when no output destination can be resolved.
    """
    if output_path is None:
        if self.config.output_path:
            output_path = self.config.output_path
        elif self.config.output_dir:
            base_name = Path(self.config.input_path).stem
            output_path = os.path.join(self.config.output_dir, f"{base_name}_anomaly_ocsvm.bip")
        else:
            raise ValueError("必须指定输出路径")

    print(f"正在保存结果到: {output_path}")

    # Bug fix: only create the parent directory when the path actually has
    # one — os.makedirs('') raises FileNotFoundError for bare file names.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # ENVI input keeps its image geometry; CSV input becomes a 1xN image.
    if self.input_format == 'envi':
        self._save_envi_results(output_path)
    else:
        self._save_csv_results_as_envi(output_path)
|
||||
|
||||
def _save_envi_results(self, output_path: str) -> None:
    """Reshape per-pixel results back to image geometry and save as ENVI.

    Band 0 holds the +1/-1 prediction, band 1 the decision score.

    NOTE(review): this assumes len(self.predictions) == rows * cols; if the
    preprocessing cleaning step dropped any pixels, the reshape below will
    raise — confirm against the preprocessing path.

    Args:
        output_path: destination path of the binary file.
    """
    rows, cols, bands = self.data_shape

    # Two output channels: prediction and decision score.
    anomaly_image = np.zeros((rows, cols, 2), dtype=np.float32)

    # Map the flat result vectors back onto the 2-D pixel grid.
    predictions_2d = self.predictions.reshape(rows, cols)
    scores_2d = self.decision_scores.reshape(rows, cols)

    anomaly_image[:, :, 0] = predictions_2d  # prediction (-1, 1)
    anomaly_image[:, :, 1] = scores_2d  # decision-function value

    band_names = ["Anomaly_Prediction", "Decision_Score"]

    # Delegate the actual ENVI writing (same scheme as Covariance.py).
    self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||||
"""将CSV结果保存为ENVI格式"""
|
||||
# 对于CSV数据,创建1xN的图像
|
||||
n_samples = len(self.predictions)
|
||||
anomaly_image = np.zeros((1, n_samples, 2), dtype=np.float32)
|
||||
|
||||
anomaly_image[0, :, 0] = self.predictions # 预测结果
|
||||
anomaly_image[0, :, 1] = self.decision_scores # 决策函数值
|
||||
|
||||
# 波段名称
|
||||
band_names = ["Anomaly_Prediction", "Decision_Score"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||||
band_names: List[str], interleave: str = 'bip') -> None:
|
||||
"""
|
||||
保存数据为ENVI格式,参考Covariance.py的实现
|
||||
|
||||
参数:
|
||||
data: 要保存的数据数组 (height, width, channels)
|
||||
output_path: 输出文件路径
|
||||
band_names: 波段名称列表
|
||||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
height, width, channels = data.shape
|
||||
|
||||
# 确保目录存在
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||||
|
||||
# 保存.dat文件(二进制格式)
|
||||
self._save_dat_file(data, str(output_path))
|
||||
|
||||
# 保存.hdr头文件
|
||||
hdr_path = str(output_path).replace('.dat', '.hdr')
|
||||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||||
|
||||
print(f"One-Class SVM异常检测结果已保存: {output_path}")
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
print(f"使用的交织格式: {interleave.upper()}")
|
||||
|
||||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||||
"""保存.dat文件(二进制格式),参考Covariance.py"""
|
||||
# 根据数据类型保存
|
||||
if data.dtype == np.float32:
|
||||
# 对于浮点数据,直接保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
else:
|
||||
# 对于其他数据类型,转换为float32保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
|
||||
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
|
||||
band_names: List[str], interleave: str) -> None:
|
||||
"""保存ENVI头文件,参考Covariance.py并适配One-Class SVM数据"""
|
||||
height, width, channels = data_shape
|
||||
|
||||
# 确定数据类型编码
|
||||
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
|
||||
data_type_code = 4 # 默认float32
|
||||
|
||||
header_content = f"""ENVI
|
||||
description = {{One-Class SVM Anomaly Detection Results - Generated by OneClassSVMAnomalyDetector}}
|
||||
samples = {width}
|
||||
lines = {height}
|
||||
bands = {channels}
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = {data_type_code}
|
||||
interleave = {interleave}
|
||||
byte order = 0
|
||||
"""
|
||||
|
||||
# 添加波段名称
|
||||
if band_names:
|
||||
header_content += "band names = {\n"
|
||||
for i, name in enumerate(band_names):
|
||||
header_content += f' "{name}"'
|
||||
if i < len(band_names) - 1:
|
||||
header_content += ","
|
||||
header_content += "\n"
|
||||
header_content += "}\n"
|
||||
|
||||
# 添加One-Class SVM相关元数据
|
||||
header_content += f"anomaly_detection_method = one_class_svm\n"
|
||||
header_content += f"kernel = {self.config.kernel}\n"
|
||||
header_content += f"nu = {self.config.nu}\n"
|
||||
if self.config.gamma is not None:
|
||||
header_content += f"gamma = {self.config.gamma}\n"
|
||||
if self.config.kernel == 'poly':
|
||||
header_content += f"degree = {self.config.degree}\n"
|
||||
header_content += f"coef0 = {self.config.coef0}\n"
|
||||
header_content += f"use_pca = {self.config.use_pca}\n"
|
||||
if self.config.use_pca and self.config.n_components is not None:
|
||||
header_content += f"n_components = {self.config.n_components}\n"
|
||||
if self.best_params:
|
||||
header_content += f"best_params = {self.best_params}\n"
|
||||
|
||||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||||
f.write(header_content)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
    """Summarize the most recent detection run.

    Returns:
        Dict with sample counts, anomaly ratio, model parameters and
        decision-score statistics.

    Raises:
        ValueError: if detect_anomalies() has not been run yet.
    """
    if self.predictions is None:
        raise ValueError("请先运行异常检测")

    n_total = len(self.predictions)
    n_anomalies = int(np.sum(self.predictions == -1))

    stats = {
        'total_samples': n_total,
        'n_anomalies': n_anomalies,
        'anomaly_ratio': float(n_anomalies / n_total),
        'kernel': self.config.kernel,
        'nu': self.config.nu,
        'gamma': self.config.gamma,
        'mean_decision_score': float(np.mean(self.decision_scores)),
        'std_decision_score': float(np.std(self.decision_scores)),
        'min_decision_score': float(np.min(self.decision_scores)),
        'max_decision_score': float(np.max(self.decision_scores)),
    }

    # Kernel-specific and tuning extras.
    if self.config.kernel == 'poly':
        stats['degree'] = self.config.degree
    if self.best_params:
        stats['best_params'] = self.best_params

    return stats
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for One-Class SVM anomaly detection.

    Parses arguments, builds the configuration, then runs the
    load -> preprocess -> train -> detect -> save pipeline and prints
    summary statistics.

    Returns:
        0 on success, 1 on any failure (the exception is printed with a
        traceback).
    """
    parser = argparse.ArgumentParser(
        description='One-Class SVM高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  python One_Class_SVM.py --input data.csv --spectral-start 400 --kernel rbf --nu 0.1 --output results.bip
  python One_Class_SVM.py --input image.bip --kernel poly --degree 3 --grid-search --output results.bip
  python One_Class_SVM.py --input data.csv --spectral-start 400 --kernel linear --no-pca --output results.bip
"""
    )

    # Input options.
    parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')

    # Output options.
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # SVM options.
    parser.add_argument('--kernel', choices=['linear', 'rbf', 'poly'], default='rbf',
                        help='核函数类型,默认rbf')
    parser.add_argument('--nu', type=float, default=0.1,
                        help='训练误差比例 (0.01-0.5),默认0.1')
    parser.add_argument('--gamma', type=float,
                        help='核函数gamma参数,默认auto (1/特征数)')
    parser.add_argument('--degree', type=int, default=3,
                        help='多项式核的阶数,默认3')
    parser.add_argument('--coef0', type=float, default=0.0,
                        help='核函数常数项,默认0.0')

    # Preprocessing options.
    parser.add_argument('--no-scaling', action='store_true',
                        help='禁用数据标准化')
    parser.add_argument('--no-pca', action='store_true',
                        help='禁用PCA降维')
    parser.add_argument('--n-components', type=int,
                        help='PCA降维维度,默认自动选择')

    # Tuning options.
    parser.add_argument('--grid-search', action='store_true',
                        help='使用网格搜索调优参数')
    parser.add_argument('--cv-folds', type=int, default=3,
                        help='交叉验证折数,默认3')

    args = parser.parse_args()

    try:
        # Build and validate the configuration.
        config = OneClassSVMConfig(
            input_path=args.input,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            kernel=args.kernel,
            nu=args.nu,
            gamma=args.gamma,
            degree=args.degree,
            coef0=args.coef0,
            use_scaling=not args.no_scaling,
            use_pca=not args.no_pca,
            n_components=args.n_components,
            use_grid_search=args.grid_search,
            cv_folds=args.cv_folds
        )

        detector = OneClassSVMAnomalyDetector(config)

        print("=== One-Class SVM异常检测 ===")
        print(f"输入文件: {config.input_path}")
        print(f"核函数: {config.kernel}")
        print(f"nu: {config.nu}")
        print(f"gamma: {'auto' if config.gamma is None else config.gamma}")
        if config.kernel == 'poly':
            print(f"degree: {config.degree}")
        print(f"参数调优: {'启用' if config.use_grid_search else '禁用'}")
        print("-" * 50)

        # 1. Load the data.
        detector.load_data()

        # 2. Preprocess (fits scaler/PCA).
        processed_data = detector.preprocess_data(detector.data)

        # 3. Train on the preprocessed data.
        detector.train_model(processed_data)

        # 4. Detect anomalies. Bug fix: the original passed `processed_data`
        #    here, but detect_anomalies() preprocesses its input again, so
        #    the data was scaled/reduced twice while the model was trained on
        #    singly-processed data. Calling without arguments lets the
        #    detector preprocess the raw data exactly once.
        predictions, scores = detector.detect_anomalies()

        # 5. Save results.
        detector.save_results()

        # 6. Print summary statistics.
        stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in stats.items():
            print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use SystemExit directly: the exit() builtin is injected by the `site`
    # module and is not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(main())
|
||||
883
Anomaly_method/RX.py
Normal file
883
Anomaly_method/RX.py
Normal file
@ -0,0 +1,883 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spectral
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.covariance import EmpiricalCovariance, MinCovDet
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from scipy import stats, linalg
|
||||
import cv2
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
class RXConfig:
    """Configuration for RX (Reed-Xiaoli) anomaly detection.

    Paths and model parameters are validated in __post_init__.
    """

    # --- input files ---
    input_path: Optional[str] = None
    background_path: Optional[str] = None  # user-supplied background data (custom model)
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True

    # --- output ---
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # --- background model ---
    background_model: str = 'global'  # 'global', 'local_row', 'local_square', 'custom'
    window_size: int = 5  # local square-window size
    row_window: int = 3  # local row-window size (number of rows)

    # --- detection parameters ---
    threshold: Optional[float] = None  # None == no threshold segmentation
    contamination: float = 0.1  # expected anomaly fraction, used for auto thresholding
    use_robust_covariance: bool = False  # MCD instead of empirical covariance

    # --- preprocessing ---
    use_scaling: bool = False  # RX normally runs on unscaled data

    def __post_init__(self):
        """Validate configuration values; raise early on bad input.

        Raises:
            ValueError: on missing/invalid parameter values.
            FileNotFoundError: if a configured file does not exist.
        """
        # An input file is mandatory.
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        # Fail fast on missing files.
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        if self.background_path and not os.path.exists(self.background_path):
            raise FileNotFoundError(f"背景文件不存在: {self.background_path}")

        # At least one output destination must be configured.
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")

        # Only these background models are handled downstream.
        supported_models = ['global', 'local_row', 'local_square', 'custom']
        if self.background_model not in supported_models:
            raise ValueError(f"不支持的背景模型: {self.background_model}。支持的模型: {supported_models}")

        # Window sizes must be positive.
        if self.window_size <= 0:
            raise ValueError("window_size必须大于0")

        if self.row_window <= 0:
            raise ValueError("row_window必须大于0")

        # The custom model needs its own background file.
        if self.background_model == 'custom' and not self.background_path:
            raise ValueError("使用custom背景模型时,必须指定background_path")
|
||||
|
||||
|
||||
class RXAnomalyDetector:
|
||||
"""
|
||||
Reed-Xiaoli (RX) 异常检测器
|
||||
|
||||
RX算法基于Mahalanobis距离检测高光谱数据中的异常像素。
|
||||
算法计算每个像素相对于背景分布的距离,距离越大的像素越可能是异常。
|
||||
|
||||
数学基础:
|
||||
D(x) = (x - μ)ᵀ K⁻¹ (x - μ)
|
||||
|
||||
其中:
|
||||
- x: 测试像素的光谱向量
|
||||
- μ: 背景数据的均值向量
|
||||
- K: 背景数据的协方差矩阵
|
||||
"""
|
||||
|
||||
def __init__(self, config: RXConfig):
    """Initialize the RX detector with a validated configuration.

    Args:
        config: RXConfig instance.
    """
    self.config = config
    self.data = None  # (n_samples, n_bands) input spectra
    self.background_data = None  # custom background spectra, if loaded
    self.wavelengths = None  # band centers, or index fallback
    self.data_shape = None  # original shape; (rows, cols, bands) for ENVI input
    self.input_format = None  # 'csv' or 'envi'

    # Background statistics (filled by compute_background_statistics()).
    self.background_mean = None
    self.background_cov = None
    self.background_inv_cov = None

    # Results.
    self.distance_map = None  # Mahalanobis distance map
    self.anomaly_map = None  # thresholded anomaly result

    # Preprocessing components.
    self.scaler = None
|
||||
|
||||
def load_data(self) -> None:
    """Load the hyperspectral input and, for the 'custom' background model,
    the user-supplied background data.

    Raises:
        ValueError: for any input extension other than .csv or .hdr.
    """
    src = Path(self.config.input_path)
    ext = src.suffix.lower()

    print(f"正在读取输入数据: {src}")

    # Dispatch on the extension; only .csv and .hdr are understood.
    if ext == '.csv':
        self._load_csv_data(src, is_background=False)
    elif ext == '.hdr':
        self._load_envi_data(src, is_background=False)
    else:
        raise ValueError(f"不支持的文件格式: {ext}。支持的格式: .hdr")

    # The custom background model needs its own data file.
    if self.config.background_model == 'custom':
        self._load_background_data()
|
||||
|
||||
def _load_csv_data(self, file_path: Path, is_background: bool = False) -> None:
    """Load spectra from a CSV file.

    Args:
        file_path: CSV path.
        is_background: when True, store into self.background_data; metadata
            (shape, format, wavelengths) is only recorded for the main input.

    Raises:
        ValueError: if config.spectral_start is not set, or (via list.index)
            if a configured column name is not present.
    """
    if self.config.spectral_start is None:
        raise ValueError("对于CSV文件,必须指定spectral_start参数")

    # Read with or without a header row, per configuration.
    df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)

    # Split off the label column when configured and present.
    if self.config.label_col and self.config.label_col in df.columns:
        labels = df[self.config.label_col].values
        spectral_cols = [col for col in df.columns if col != self.config.label_col]
    else:
        labels = None
        spectral_cols = df.columns.tolist()

    # Resolve the (inclusive) spectral column range.
    if self.config.spectral_end:
        end_idx = spectral_cols.index(self.config.spectral_end) + 1
    else:
        end_idx = len(spectral_cols)

    start_idx = spectral_cols.index(self.config.spectral_start)
    spectral_data = df[spectral_cols[start_idx:end_idx]].values

    # CSVs carry no physical wavelengths; fall back to band indices.
    wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))

    # NOTE(review): `labels` is computed but never stored on self — confirm
    # whether RX is meant to keep labels for evaluation.
    if is_background:
        self.background_data = spectral_data.astype(np.float32)
        print(f"背景数据加载完成: {self.background_data.shape} 样本 x {self.background_data.shape[1]} 波段")
    else:
        self.data = spectral_data.astype(np.float32)
        self.data_shape = self.data.shape
        self.input_format = 'csv'
        self.wavelengths = wavelengths
        print(f"输入数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
def _load_envi_data(self, file_path: Path, is_background: bool = False) -> None:
    """Load an ENVI image via the spectral library.

    The cube is flattened to (rows*cols, bands). For the main input the
    geometry, format and wavelengths are recorded; background data only
    stores the flat spectra.

    Args:
        file_path: path to the ENVI file.
        is_background: when True, store into self.background_data.

    Raises:
        ValueError: wrapping any underlying failure (missing file, missing
            read permission, spectral-library errors).
    """
    try:
        # Validate path and permissions before handing off to spectral.
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"ENVI文件不存在: {file_path}")

        if not os.access(file_path, os.R_OK):
            raise PermissionError(f"没有读取权限: {file_path}")

        # spectral.open_image takes a plain string path.
        file_path_str = str(file_path)
        print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")

        # Look for a companion header: first '<name>.<ext>.hdr', then
        # '<name>.hdr'. A missing header only warns — spectral may still
        # resolve the file on its own.
        hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
        if not hdr_path.exists():
            hdr_path = file_path.with_suffix('.hdr')
            if not hdr_path.exists():
                print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")

        envi_data = spectral.open_image(file_path_str)

        # Materialize the full cube in memory.
        print("正在加载数据到内存...")
        data_array = envi_data.load()

        # Prefer real band centers from the header; otherwise index bands.
        if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
            wavelengths = np.array(envi_data.bands.centers)
            print(f"读取到波长信息: {len(wavelengths)} 个波段")
        else:
            wavelengths = np.arange(data_array.shape[2])
            print(f"未找到波长信息,使用默认波长: 0-{len(wavelengths)-1}")

        # Flatten pixels to rows: (rows*cols, bands).
        rows, cols, bands = data_array.shape
        spectral_data = data_array.reshape(-1, bands).astype(np.float32)

        # Store either as background or as the main input.
        if is_background:
            self.background_data = spectral_data
            print(f"背景数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
        else:
            self.data = spectral_data
            self.data_shape = (rows, cols, bands)
            self.input_format = 'envi'
            self.wavelengths = wavelengths
            print(f"输入数据加载完成: {rows}x{cols} 像素 x {bands} 波段")

    except Exception as e:
        raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def _load_background_data(self) -> None:
|
||||
"""加载用户指定的背景数据"""
|
||||
if not self.config.background_path:
|
||||
raise ValueError("background_path未指定")
|
||||
|
||||
file_path = Path(self.config.background_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取背景数据: {file_path}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path, is_background=True)
|
||||
elif suffix in ['.bil', '.bsq', '.bip', '.dat']:
|
||||
self._load_envi_data(file_path, is_background=True)
|
||||
else:
|
||||
raise ValueError(f"不支持的背景文件格式: {suffix}")
|
||||
|
||||
def preprocess_data(self) -> None:
    """Optionally standardize input and background data in place.

    RX typically works on raw radiance/reflectance values, so scaling is
    disabled unless explicitly requested; when enabled, the scaler is fitted
    on the input data and the same transform is applied to the background.
    """
    if not self.config.use_scaling:
        return

    print("正在标准化数据...")
    self.scaler = StandardScaler()

    # Fit on the input data...
    if self.data is not None:
        self.data = self.scaler.fit_transform(self.data)

    # ...and apply the identical transform to the background data.
    if self.background_data is not None:
        self.background_data = self.scaler.transform(self.background_data)
|
||||
|
||||
def compute_background_statistics(self) -> None:
    """Compute background mean/covariance for the configured model.

    Local models ('local_row', 'local_square') defer their statistics to
    detection time, so nothing is computed here for them.
    """
    print(f"正在计算背景统计量 (模型: {self.config.background_model})...")

    model = self.config.background_model
    if model == 'global':
        self._compute_global_background()
    elif model == 'custom':
        self._compute_custom_background()
    # Local models: statistics are derived per-pixel during detection.
|
||||
|
||||
def _compute_global_background(self) -> None:
|
||||
"""计算全局背景统计量"""
|
||||
# 使用整个数据集作为背景
|
||||
background_data = self.data
|
||||
|
||||
if self.config.use_robust_covariance:
|
||||
# 使用稳健协方差估计
|
||||
cov_estimator = MinCovDet()
|
||||
cov_estimator.fit(background_data)
|
||||
self.background_mean = cov_estimator.location_
|
||||
self.background_cov = cov_estimator.covariance_
|
||||
else:
|
||||
# 使用经验协方差
|
||||
self.background_mean = np.mean(background_data, axis=0)
|
||||
self.background_cov = np.cov(background_data.T)
|
||||
|
||||
# 计算协方差矩阵的逆
|
||||
try:
|
||||
self.background_inv_cov = linalg.inv(self.background_cov)
|
||||
except np.linalg.LinAlgError:
|
||||
print("警告: 协方差矩阵奇异,使用伪逆")
|
||||
self.background_inv_cov = linalg.pinv(self.background_cov)
|
||||
|
||||
print(f"全局背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
|
||||
|
||||
    def _compute_custom_background(self) -> None:
        """Estimate background statistics from user-supplied background data.

        Fills self.background_mean, self.background_cov and
        self.background_inv_cov.

        Raises:
            ValueError: if no custom background data was loaded.
        """
        if self.background_data is None:
            raise ValueError("自定义背景数据未加载")

        background_data = self.background_data

        if self.config.use_robust_covariance:
            # Robust estimate via Minimum Covariance Determinant.
            cov_estimator = MinCovDet()
            cov_estimator.fit(background_data)
            self.background_mean = cov_estimator.location_
            self.background_cov = cov_estimator.covariance_
        else:
            # Plain empirical mean and covariance.
            self.background_mean = np.mean(background_data, axis=0)
            self.background_cov = np.cov(background_data.T)

        # Invert the covariance; fall back to the pseudo-inverse if singular.
        try:
            self.background_inv_cov = linalg.inv(self.background_cov)
        except np.linalg.LinAlgError:
            print("警告: 协方差矩阵奇异,使用伪逆")
            self.background_inv_cov = linalg.pinv(self.background_cov)

        print(f"自定义背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
|
||||
|
||||
    def _compute_local_background_stats(self, data_window: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute background statistics for one local window.

        Args:
            data_window: Window pixels, shape (n_pixels, n_bands).

        Returns:
            (mean_vector, cov_matrix, inv_cov_matrix). Falls back to the
            global statistics when fewer than 2 valid (non-NaN) pixels remain.
        """
        # Drop window pixels that contain NaN in any band.
        valid_mask = ~np.isnan(data_window).any(axis=1)
        valid_data = data_window[valid_mask]

        if len(valid_data) < 2:
            # Too few samples for any covariance estimate.
            # NOTE(review): this fallback reads self.background_* which are
            # only set by compute_background_statistics(); verify they exist
            # when running a purely local model.
            return self.background_mean, self.background_cov, self.background_inv_cov

        if self.config.use_robust_covariance and len(valid_data) > self.background_cov.shape[0]:
            # Robust MCD estimate (requires more samples than bands).
            cov_estimator = MinCovDet()
            cov_estimator.fit(valid_data)
            mean_vec = cov_estimator.location_
            cov_mat = cov_estimator.covariance_
        else:
            # Plain empirical estimate.
            mean_vec = np.mean(valid_data, axis=0)
            cov_mat = np.cov(valid_data.T)

        # Invert; fall back to the pseudo-inverse when singular.
        try:
            inv_cov_mat = linalg.inv(cov_mat)
        except np.linalg.LinAlgError:
            inv_cov_mat = linalg.pinv(cov_mat)

        return mean_vec, cov_mat, inv_cov_mat
|
||||
|
||||
def compute_mahalanobis_distance(self, x: np.ndarray, mean_vec: np.ndarray,
|
||||
inv_cov_mat: np.ndarray) -> float:
|
||||
"""
|
||||
计算单个像素的Mahalanobis距离
|
||||
|
||||
Args:
|
||||
x: 像素光谱向量
|
||||
mean_vec: 均值向量
|
||||
inv_cov_mat: 协方差矩阵的逆
|
||||
|
||||
Returns:
|
||||
Mahalanobis距离
|
||||
"""
|
||||
diff = x - mean_vec
|
||||
distance = np.dot(np.dot(diff, inv_cov_mat), diff.T)
|
||||
|
||||
# 确保距离为非负实数
|
||||
return max(0, distance)
|
||||
|
||||
def detect_anomalies_global(self) -> np.ndarray:
|
||||
"""
|
||||
使用全局背景模型进行异常检测
|
||||
|
||||
Returns:
|
||||
距离图 (n_samples,)
|
||||
"""
|
||||
print("正在进行全局RX异常检测...")
|
||||
|
||||
n_samples = self.data.shape[0]
|
||||
distance_map = np.zeros(n_samples, dtype=np.float32)
|
||||
|
||||
# 对每个像素计算Mahalanobis距离
|
||||
for i in range(n_samples):
|
||||
distance_map[i] = self.compute_mahalanobis_distance(
|
||||
self.data[i], self.background_mean, self.background_inv_cov
|
||||
)
|
||||
|
||||
return distance_map
|
||||
|
||||
    def detect_anomalies_local_row(self) -> np.ndarray:
        """
        RX detection with a per-row sliding background window.

        For each image row, background statistics are estimated from the
        surrounding config.row_window rows; every pixel of the row is then
        scored against them.

        Returns:
            Flattened distance map of length rows*cols (row-major).
        """
        print("正在进行局部行RX异常检测...")

        rows, cols, bands = self.data_shape
        distance_map = np.zeros((rows, cols), dtype=np.float32)

        # Restore the 2-D image geometry from the flattened sample matrix.
        image_data = self.data.reshape(rows, cols, bands)

        half_window = self.config.row_window // 2

        for i in range(rows):
            # Row-window bounds, clipped at the image edges.
            row_start = max(0, i - half_window)
            row_end = min(rows, i + half_window + 1)

            # All pixels of the window rows form the background sample.
            row_window = image_data[row_start:row_end, :, :].reshape(-1, bands)

            # Local statistics (may fall back to global when data is sparse).
            local_mean, local_cov, local_inv_cov = self._compute_local_background_stats(row_window)

            # Score every pixel of the current row.
            for j in range(cols):
                pixel = image_data[i, j, :]
                distance_map[i, j] = self.compute_mahalanobis_distance(
                    pixel, local_mean, local_inv_cov
                )

        return distance_map.flatten()
|
||||
|
||||
    def detect_anomalies_local_square(self) -> np.ndarray:
        """
        RX detection with a square sliding background window.

        For each pixel a config.window_size x config.window_size window
        (clipped at image edges) supplies the background statistics.

        Returns:
            Flattened distance map of length rows*cols (row-major).
        """
        print("正在进行局部正方形RX异常检测...")

        rows, cols, bands = self.data_shape
        distance_map = np.zeros((rows, cols), dtype=np.float32)

        # Restore the 2-D image geometry from the flattened sample matrix.
        image_data = self.data.reshape(rows, cols, bands)

        half_window = self.config.window_size // 2

        # Per-pixel window; recomputes statistics for every pixel, so this
        # is O(rows*cols) covariance estimations — expensive on large images.
        for i in range(rows):
            for j in range(cols):
                # Window bounds, clipped at the image edges.
                row_start = max(0, i - half_window)
                row_end = min(rows, i + half_window + 1)
                col_start = max(0, j - half_window)
                col_end = min(cols, j + half_window + 1)

                # Window pixels as the background sample.
                window_data = image_data[row_start:row_end, col_start:col_end, :].reshape(-1, bands)

                # Local statistics (may fall back to global when sparse).
                local_mean, local_cov, local_inv_cov = self._compute_local_background_stats(window_data)

                # Score the center pixel.
                pixel = image_data[i, j, :]
                distance_map[i, j] = self.compute_mahalanobis_distance(
                    pixel, local_mean, local_inv_cov
                )

        return distance_map.flatten()
|
||||
|
||||
def detect_anomalies(self) -> np.ndarray:
|
||||
"""
|
||||
执行异常检测
|
||||
|
||||
Returns:
|
||||
Mahalanobis距离图
|
||||
"""
|
||||
if self.config.background_model == 'global':
|
||||
self.compute_background_statistics()
|
||||
distance_map = self.detect_anomalies_global()
|
||||
elif self.config.background_model == 'local_row':
|
||||
distance_map = self.detect_anomalies_local_row()
|
||||
elif self.config.background_model == 'local_square':
|
||||
distance_map = self.detect_anomalies_local_square()
|
||||
elif self.config.background_model == 'custom':
|
||||
self.compute_background_statistics()
|
||||
distance_map = self.detect_anomalies_global()
|
||||
|
||||
self.distance_map = distance_map
|
||||
|
||||
# 应用阈值分割(如果指定)
|
||||
if self.config.threshold is not None:
|
||||
self.anomaly_map = (distance_map > self.config.threshold).astype(np.int32)
|
||||
else:
|
||||
self.anomaly_map = None
|
||||
|
||||
# 统计信息
|
||||
mean_dist = np.mean(distance_map)
|
||||
std_dist = np.std(distance_map)
|
||||
max_dist = np.max(distance_map)
|
||||
|
||||
print("异常检测完成:")
|
||||
print(".4f")
|
||||
print(".4f")
|
||||
print(".4f")
|
||||
|
||||
if self.anomaly_map is not None:
|
||||
n_anomalies = np.sum(self.anomaly_map)
|
||||
anomaly_ratio = n_anomalies / len(distance_map)
|
||||
print(".3f")
|
||||
|
||||
return distance_map
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
从配置运行完整的异常检测分析
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||||
"""
|
||||
# 执行异常检测,获取距离图作为分数
|
||||
scores = self.detect_anomalies()
|
||||
|
||||
# 生成预测结果(基于阈值或百分位数)
|
||||
if self.config.threshold is not None:
|
||||
predictions = (scores > self.config.threshold).astype(np.int32)
|
||||
else:
|
||||
# 使用contamination参数确定阈值
|
||||
# contamination表示异常比例,所以使用(1-contamination)*100百分位数
|
||||
percentile = (1 - self.config.contamination) * 100
|
||||
threshold = np.percentile(scores, percentile)
|
||||
predictions = (scores > threshold).astype(np.int32)
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||||
"""
|
||||
保存异常检测结果为ENVI格式文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||||
"""
|
||||
if output_path is None:
|
||||
if self.config.output_path:
|
||||
output_path = self.config.output_path
|
||||
elif self.config.output_dir:
|
||||
base_name = Path(self.config.input_path).stem
|
||||
output_path = os.path.join(self.config.output_dir, f"{base_name}_rx_anomaly.bip")
|
||||
else:
|
||||
raise ValueError("必须指定输出路径")
|
||||
|
||||
print(f"正在保存结果到: {output_path}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.dirname(output_path)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 根据输入格式决定输出格式
|
||||
if self.input_format == 'envi':
|
||||
self._save_envi_results(output_path)
|
||||
else:
|
||||
# 对于CSV输入,创建虚拟的2D图像
|
||||
self._save_csv_results_as_envi(output_path)
|
||||
|
||||
def _save_envi_results(self, output_path: str) -> None:
|
||||
"""保存ENVI格式结果"""
|
||||
# 重塑距离图为原始图像尺寸
|
||||
rows, cols, bands = self.data_shape
|
||||
|
||||
# 创建单波段异常检测结果图像
|
||||
if self.anomaly_map is not None:
|
||||
# 保存异常检测结果 (0/1)
|
||||
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
|
||||
anomaly_image[:, :, 0] = self.anomaly_map.reshape(rows, cols)
|
||||
band_names = ["Anomaly_Map"]
|
||||
else:
|
||||
# 保存距离图
|
||||
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
|
||||
anomaly_image[:, :, 0] = self.distance_map.reshape(rows, cols)
|
||||
band_names = ["Mahalanobis_Distance"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||||
"""将CSV结果保存为ENVI格式"""
|
||||
# 对于CSV数据,创建1xN的单波段图像
|
||||
n_samples = len(self.distance_map)
|
||||
if self.anomaly_map is not None:
|
||||
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
|
||||
anomaly_image[0, :, 0] = self.anomaly_map
|
||||
band_names = ["Anomaly_Map"]
|
||||
else:
|
||||
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
|
||||
anomaly_image[0, :, 0] = self.distance_map
|
||||
band_names = ["Mahalanobis_Distance"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||||
band_names: List[str], interleave: str = 'bip') -> None:
|
||||
"""
|
||||
保存数据为ENVI格式,参考Covariance.py的实现
|
||||
|
||||
参数:
|
||||
data: 要保存的数据数组 (height, width, channels)
|
||||
output_path: 输出文件路径
|
||||
band_names: 波段名称列表
|
||||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
height, width, channels = data.shape
|
||||
|
||||
# 确保目录存在
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||||
|
||||
# 保存.dat文件(二进制格式)
|
||||
self._save_dat_file(data, str(output_path))
|
||||
|
||||
# 保存.hdr头文件
|
||||
hdr_path = str(output_path).replace('.dat', '.hdr')
|
||||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||||
|
||||
print(f"RX异常检测结果已保存: {output_path}")
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
print(f"使用的交织格式: {interleave.upper()}")
|
||||
|
||||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||||
"""保存.dat文件(二进制格式),参考Covariance.py"""
|
||||
# 根据数据类型保存
|
||||
if data.dtype == np.float32:
|
||||
# 对于浮点数据,直接保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
else:
|
||||
# 对于其他数据类型,转换为float32保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
|
||||
    def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
                       band_names: List[str], interleave: str) -> None:
        """Write an ENVI header (.hdr) describing the saved data file.

        Args:
            hdr_path: Header file path.
            data_shape: (height, width, channels) of the saved array.
            band_names: Band names written into the header.
            interleave: Interleave tag ('bip'/'bil'/'bsq').
        """
        height, width, channels = data_shape

        # ENVI data type codes: 4=float32, 5=float64, 1=uint8, 2=int16,
        # 3=int32, 12=uint16. The .dat writer always emits float32.
        data_type_code = 4  # default: float32

        header_content = f"""ENVI
description = {{RX Anomaly Detection Results - Generated by RXAnomalyDetector}}
samples = {width}
lines = {height}
bands = {channels}
header offset = 0
file type = ENVI Standard
data type = {data_type_code}
interleave = {interleave}
byte order = 0
"""

        # Band-names block.
        if band_names:
            header_content += "band names = {\n"
            for i, name in enumerate(band_names):
                header_content += f' "{name}"'
                if i < len(band_names) - 1:
                    header_content += ","
                header_content += "\n"
            header_content += "}\n"

        # RX-specific metadata appended as extra (non-standard) header keys.
        header_content += f"anomaly_detection_method = rx\n"
        header_content += f"background_model = {self.config.background_model}\n"
        if self.config.background_model in ['local_row', 'local_square']:
            header_content += f"window_size = {self.config.window_size}\n"
        if self.config.background_model == 'local_row':
            header_content += f"row_window = {self.config.row_window}\n"
        if self.config.threshold is not None:
            header_content += f"threshold = {self.config.threshold}\n"
        header_content += f"use_robust_covariance = {self.config.use_robust_covariance}\n"

        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取异常检测统计信息
|
||||
|
||||
Returns:
|
||||
包含各种统计信息的字典
|
||||
"""
|
||||
if self.distance_map is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
stats = {
|
||||
'total_pixels': len(self.distance_map),
|
||||
'background_model': self.config.background_model,
|
||||
'mean_distance': float(np.mean(self.distance_map)),
|
||||
'std_distance': float(np.std(self.distance_map)),
|
||||
'min_distance': float(np.min(self.distance_map)),
|
||||
'max_distance': float(np.max(self.distance_map)),
|
||||
'median_distance': float(np.median(self.distance_map))
|
||||
}
|
||||
|
||||
# 计算分位数
|
||||
for percentile in [25, 75, 90, 95, 99]:
|
||||
stats[f'percentile_{percentile}'] = float(np.percentile(self.distance_map, percentile))
|
||||
|
||||
if self.anomaly_map is not None:
|
||||
stats['threshold'] = self.config.threshold
|
||||
stats['n_anomalies'] = int(np.sum(self.anomaly_map))
|
||||
stats['anomaly_ratio'] = float(np.sum(self.anomaly_map) / len(self.distance_map))
|
||||
|
||||
return stats
|
||||
|
||||
def suggest_threshold(self, method: str = 'percentile') -> float:
|
||||
"""
|
||||
建议异常检测阈值
|
||||
|
||||
Args:
|
||||
method: 阈值建议方法 ('percentile', 'mean_std', 'auto')
|
||||
|
||||
Returns:
|
||||
建议的阈值
|
||||
"""
|
||||
if self.distance_map is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
if method == 'percentile':
|
||||
# 使用95%分位数作为阈值
|
||||
threshold = np.percentile(self.distance_map, 95)
|
||||
elif method == 'mean_std':
|
||||
# 使用均值+2倍标准差作为阈值
|
||||
mean_dist = np.mean(self.distance_map)
|
||||
std_dist = np.std(self.distance_map)
|
||||
threshold = mean_dist + 2 * std_dist
|
||||
elif method == 'auto':
|
||||
# 使用自动阈值选择(基于卡方分布)
|
||||
# 对于高维数据,自由度近似为波段数
|
||||
n_bands = self.data.shape[1]
|
||||
# 使用卡方分布的95%分位数
|
||||
threshold = stats.chi2.ppf(0.95, n_bands)
|
||||
else:
|
||||
raise ValueError(f"不支持的阈值建议方法: {method}")
|
||||
|
||||
return float(threshold)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for Reed-Xiaoli (RX) anomaly detection.

    Parses CLI arguments, builds an RXConfig, runs load -> preprocess ->
    detect, saves the result and prints statistics.

    Returns:
        0 on success, 1 on any error (a traceback is printed).
    """
    parser = argparse.ArgumentParser(
        description='Reed-Xiaoli (RX) 高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  python RX.py --input image.bip --background-model global --output results.bip
  python RX.py --input data.csv --spectral-start 400 --background-model local_square --window-size 7 --output results.bip
  python RX.py --input image.bip --background-model custom --background background.bip --threshold 10.0 --output results.bip

背景模型说明:
  global: 使用整个数据集作为背景
  local_row: 使用像素周围N行的窗口作为背景
  local_square: 使用固定大小的正方形窗口作为背景
  custom: 使用外部提供的背景数据
        """
    )

    # Input-file arguments.
    parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
    parser.add_argument('--background', help='背景数据路径 (custom模式时必需)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')

    # Output arguments.
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # Background-model arguments.
    parser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
                        default='global', help='背景模型类型,默认global')
    parser.add_argument('--window-size', type=int, default=5,
                        help='局部窗口大小,默认5')
    parser.add_argument('--row-window', type=int, default=3,
                        help='行窗口大小,默认3')

    # Detection arguments.
    parser.add_argument('--threshold', type=float,
                        help='异常检测阈值,不指定则只输出距离图')
    parser.add_argument('--robust-covariance', action='store_true',
                        help='使用稳健协方差估计')

    args = parser.parse_args()

    try:
        # Build the configuration object from CLI arguments.
        config = RXConfig(
            input_path=args.input,
            background_path=args.background,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            background_model=args.background_model,
            window_size=args.window_size,
            row_window=args.row_window,
            threshold=args.threshold,
            use_robust_covariance=args.robust_covariance
        )

        # Create the detector.
        detector = RXAnomalyDetector(config)

        # Echo the effective settings.
        print("=== Reed-Xiaoli (RX) 异常检测 ===")
        print(f"输入文件: {config.input_path}")
        print(f"背景模型: {config.background_model}")
        if config.background_model in ['local_row', 'local_square']:
            print(f"窗口大小: {config.window_size}")
        if config.threshold is not None:
            print(f"检测阈值: {config.threshold}")
        print("-" * 50)

        # 1. Load the data.
        detector.load_data()

        # 2. Preprocess.
        detector.preprocess_data()

        # 3. Run detection.
        distance_map = detector.detect_anomalies()

        # 4. Suggest a threshold when none was given.
        if config.threshold is None:
            suggested_threshold = detector.suggest_threshold('percentile')
            print(f"\n建议阈值 (95百分位数): {suggested_threshold:.4f}")

        # 5. Save results.
        detector.save_results()

        # 6. Print result statistics.
        stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in stats.items():
            print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
|
||||
|
||||
|
||||
# Script entry point: main() returns 0 on success, 1 on error.
if __name__ == "__main__":
    exit(main())
|
||||
823
Anomaly_method/squared_loss_probability.py
Normal file
823
Anomaly_method/squared_loss_probability.py
Normal file
@ -0,0 +1,823 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spectral
|
||||
from sklearn.linear_model import LinearRegression
|
||||
from sklearn.preprocessing import StandardScaler, RobustScaler
|
||||
from sklearn.metrics import mean_squared_error
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from scipy import stats
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
class SquaredLossAnomalyConfig:
    """Configuration for squared-loss self-reconstruction anomaly detection."""

    # --- input files ---
    input_path: Optional[str] = None       # path to input data (CSV or ENVI)
    label_col: Optional[str] = None        # label column name (CSV only)
    spectral_start: Optional[str] = None   # first spectral column (CSV only)
    spectral_end: Optional[str] = None     # last spectral column (CSV only)
    csv_has_header: bool = True

    # --- output ---
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # --- self-reconstruction ---
    reconstruction_method: str = 'linear'  # 'linear' or 'pca'
    n_components: Optional[int] = None     # PCA dimensionality (pca method)
    target_band: Optional[int] = None      # target band for linear method; None = last band

    # --- detection ---
    probability_flag: bool = True          # True: probabilities, False: binary labels
    threshold: float = 0.8                 # classification threshold (probability_flag=False)
    use_scaling: bool = True
    scaler_type: str = 'robust'            # 'standard' or 'robust'

    # --- model ---
    fit_intercept: bool = True
    normalize: bool = False                # deprecated, kept for compatibility

    def __post_init__(self):
        """Validate parameters and normalize string options.

        Raises:
            ValueError: on missing/invalid parameters.
            FileNotFoundError: when input_path does not exist.
        """
        if not self.input_path:
            raise ValueError("必须指定输入数据路径(input_path)")

        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # Fix: normalize case BEFORE validating. The original lower-cased
        # reconstruction_method only AFTER the membership check, so values
        # like 'Linear' were wrongly rejected (scaler_type already had the
        # correct lower-then-validate order).
        self.reconstruction_method = self.reconstruction_method.lower()
        self.scaler_type = self.scaler_type.lower()

        supported_methods = ['linear', 'pca']
        if self.reconstruction_method not in supported_methods:
            raise ValueError(f"不支持的重构方法: {self.reconstruction_method}。支持的方法: {supported_methods}")

        if not 0 <= self.threshold <= 1:
            raise ValueError("threshold必须在[0, 1]范围内")

        if self.scaler_type not in ['standard', 'robust']:
            raise ValueError(f"不支持的标准化类型: {self.scaler_type}。支持的类型: ['standard', 'robust']")
|
||||
|
||||
|
||||
class SquaredLossAnomalyDetector:
|
||||
"""
|
||||
基于最小二乘的自重构异常检测器
|
||||
|
||||
算法原理:
|
||||
1. 使用自重构方法学习数据的正常模式
|
||||
2. 计算每个样本的重构误差作为异常分数
|
||||
3. 将误差归一化为0-1的异常概率
|
||||
4. 根据阈值进行二元分类
|
||||
|
||||
支持的重构方法:
|
||||
- linear: 使用线性回归进行特征重构
|
||||
- pca: 使用PCA进行降维重构
|
||||
|
||||
参考:基于重构误差的无监督异常检测方法
|
||||
"""
|
||||
|
||||
def __init__(self, config: SquaredLossAnomalyConfig):
|
||||
"""
|
||||
初始化异常检测器
|
||||
|
||||
Args:
|
||||
config: 配置对象
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# 输入数据
|
||||
self.data = None
|
||||
self.labels = None
|
||||
self.wavelengths = None
|
||||
self.data_shape = None
|
||||
self.input_format = None
|
||||
|
||||
# 预处理组件
|
||||
self.scaler = None
|
||||
|
||||
# 重构模型
|
||||
self.model = None
|
||||
self.pca_model = None
|
||||
|
||||
# 结果
|
||||
self.reconstruction_errors = None
|
||||
self.anomaly_probabilities = None
|
||||
self.predictions = None
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
加载高光谱数据文件
|
||||
|
||||
支持CSV和ENVI格式文件
|
||||
"""
|
||||
file_path = Path(self.config.input_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取数据: {file_path}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path)
|
||||
elif suffix in ['.hdr']:
|
||||
self._load_envi_data(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||||
|
||||
def _load_csv_data(self, file_path: Path) -> None:
|
||||
"""加载CSV格式数据"""
|
||||
if self.config.spectral_start is None:
|
||||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||||
|
||||
self.input_format = 'csv'
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||||
|
||||
# 提取标签列
|
||||
if self.config.label_col and self.config.label_col in df.columns:
|
||||
self.labels = df[self.config.label_col].values
|
||||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||||
else:
|
||||
self.labels = None
|
||||
spectral_cols = df.columns.tolist()
|
||||
|
||||
# 提取光谱数据
|
||||
if self.config.spectral_end:
|
||||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||||
else:
|
||||
end_idx = len(spectral_cols)
|
||||
|
||||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||||
|
||||
# 生成波长信息
|
||||
self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||||
|
||||
self.data = spectral_data.astype(np.float32)
|
||||
self.data_shape = self.data.shape
|
||||
|
||||
print(f"数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
    def _load_envi_data(self, file_path: Path) -> None:
        """
        Load an ENVI image via the spectral library.

        Sets self.data (flattened to (n_pixels, n_bands) float32),
        self.data_shape = (rows, cols, bands), and self.wavelengths
        (band centers when available, else band indices).

        Args:
            file_path: Path to the ENVI file (typically the .hdr).

        Raises:
            ValueError: wrapping any underlying read error.
        """
        self.input_format = 'envi'

        try:
            # Validate path existence/readability before handing to spectral.
            file_path = Path(file_path)
            if not file_path.exists():
                raise FileNotFoundError(f"ENVI文件不存在: {file_path}")

            if not os.access(file_path, os.R_OK):
                raise PermissionError(f"没有读取权限: {file_path}")

            # Use a normalized path string for the spectral API.
            file_path_str = str(file_path)
            print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")

            # Look for a companion header: <name>.<ext>.hdr, then <name>.hdr.
            hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
            if not hdr_path.exists():
                # Try the plain .hdr variant.
                hdr_path = file_path.with_suffix('.hdr')
                if not hdr_path.exists():
                    print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")

            # Open the ENVI file with the normalized path.
            envi_data = spectral.open_image(file_path_str)

            # Fully load the image cube into memory.
            print("正在加载数据到内存...")
            data_array = envi_data.load()

            # Wavelengths: use band centers when the header provides them.
            if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
                self.wavelengths = np.array(envi_data.bands.centers)
                print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
            else:
                self.wavelengths = np.arange(data_array.shape[2])
                print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")

            # Flatten (rows, cols, bands) into a (samples, bands) matrix.
            rows, cols, bands = data_array.shape
            self.data = data_array.reshape(-1, bands).astype(np.float32)
            self.data_shape = (rows, cols, bands)

            print(f"数据加载完成: {rows}x{cols} 像素 x {bands} 波段")

        except Exception as e:
            raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def preprocess_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
处理后的数据
|
||||
"""
|
||||
# 数据清理
|
||||
data_clean = self._clean_data(data)
|
||||
|
||||
# 标准化
|
||||
if self.config.use_scaling:
|
||||
data_processed = self._scale_data(data_clean)
|
||||
else:
|
||||
data_processed = data_clean
|
||||
|
||||
return data_processed
|
||||
|
||||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""清理数据:移除无效值"""
|
||||
# 移除包含NaN或无穷大值的行
|
||||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||||
data_clean = data[valid_mask]
|
||||
|
||||
if np.sum(~valid_mask) > 0:
|
||||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||||
|
||||
# 移除全零行
|
||||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||||
data_clean = data_clean[non_zero_mask]
|
||||
|
||||
if np.sum(~non_zero_mask) > 0:
|
||||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||||
|
||||
return data_clean
|
||||
|
||||
    def _scale_data(self, data: np.ndarray) -> np.ndarray:
        """Fit (once) and apply the configured scaler.

        The scaler is created and fitted lazily on the first call; later
        calls reuse it, so training and scoring share the same statistics.

        Args:
            data: Sample matrix to scale.

        Returns:
            Scaled data.
        """
        if self.scaler is None:
            if self.config.scaler_type == 'robust':
                # Robust scaling (median/IQR) is less sensitive to the very
                # outliers we are trying to detect.
                self.scaler = RobustScaler()
            else:
                self.scaler = StandardScaler()

            # Fit on the first batch seen (normally the full training data).
            self.scaler.fit(data)

        data_scaled = self.scaler.transform(data)
        return data_scaled
|
||||
|
||||
def build_reconstruction_model(self) -> None:
|
||||
"""
|
||||
构建自重构模型
|
||||
|
||||
根据配置的重构方法构建相应的重构模型
|
||||
"""
|
||||
if self.data is None:
|
||||
raise ValueError("请先调用load_data()加载数据")
|
||||
|
||||
print(f"正在构建{self.config.reconstruction_method}自重构模型...")
|
||||
|
||||
# 预处理数据
|
||||
data_processed = self.preprocess_data(self.data)
|
||||
|
||||
if self.config.reconstruction_method == 'linear':
|
||||
self._build_linear_reconstruction_model(data_processed)
|
||||
elif self.config.reconstruction_method == 'pca':
|
||||
self._build_pca_reconstruction_model(data_processed)
|
||||
|
||||
    def _build_linear_reconstruction_model(self, data: np.ndarray) -> None:
        """
        Fit a linear regression that reconstructs one band from the others.

        Args:
            data: Preprocessed sample matrix (n_samples, n_features).

        Raises:
            ValueError: if fewer than 2 bands, or target_band out of range.
        """
        n_samples, n_features = data.shape

        if n_features < 2:
            raise ValueError("数据至少需要2个波段进行线性重构")

        # Resolve the target band; -1 means "last band" by default.
        if self.config.target_band is None:
            target_band = -1  # default: use the last band
        else:
            target_band = self.config.target_band
            if target_band >= n_features or target_band < 0:
                raise ValueError(f"目标波段索引 {target_band} 超出范围 [0, {n_features-1}]")

        # Features: every band except the target; label: the target band.
        X_train = np.delete(data, target_band, axis=1)
        y_train = data[:, target_band]

        # Ordinary least squares fit.
        self.model = LinearRegression(
            fit_intercept=self.config.fit_intercept
        )

        self.model.fit(X_train, y_train)

        # Training reconstruction error, for reporting only.
        y_pred = self.model.predict(X_train)
        train_mse = mean_squared_error(y_train, y_pred)
        train_rmse = np.sqrt(train_mse)

        print(f"线性重构模型构建完成:")
        print(f" - 目标波段: {target_band}")
        print(f" - 特征波段数: {X_train.shape[1]}")
        print(f" - 训练MSE: {train_mse:.4f}")
        print(f" - 训练RMSE: {train_rmse:.4f}")
|
||||
|
||||
    def _build_pca_reconstruction_model(self, data: np.ndarray) -> None:
        """
        Fit a PCA model and report its round-trip reconstruction error.

        Args:
            data: Preprocessed sample matrix (n_samples, n_features).
        """
        from sklearn.decomposition import PCA

        n_samples, n_features = data.shape

        # Component count: configured value (capped at n_features), else
        # half the bands with a floor of 2.
        if self.config.n_components is None:
            n_components = min(n_features, max(2, n_features // 2))
        else:
            n_components = min(self.config.n_components, n_features)

        # Fixed seed for reproducible fits.
        self.pca_model = PCA(n_components=n_components, random_state=42)
        data_pca = self.pca_model.fit_transform(data)

        # Round-trip reconstruction for the training report.
        data_reconstructed = self.pca_model.inverse_transform(data_pca)

        # Per-sample mean squared reconstruction error.
        reconstruction_errors = np.mean((data - data_reconstructed) ** 2, axis=1)
        mean_error = np.mean(reconstruction_errors)
        std_error = np.std(reconstruction_errors)

        explained_var = np.sum(self.pca_model.explained_variance_ratio_)

        print(f"PCA重构模型构建完成:")
        print(f" - 原始维度: {n_features}")
        print(f" - PCA维度: {n_components}")
        print(f" - 解释方差比例: {explained_var:.3f}")
        print(f" - 平均重构误差: {mean_error:.4f}")
        print(f" - 重构误差标准差: {std_error:.4f}")
|
||||
|
||||
    def detect_anomalies(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Score all samples by reconstruction error.

        Side effects: stores self.reconstruction_errors,
        self.anomaly_probabilities and self.predictions.

        Returns:
            predictions: probabilities (probability_flag=True) or 0/1 labels.
            probabilities: anomaly probabilities in [0, 1].

        Raises:
            ValueError: if no model was built or no data was loaded.
        """
        if self.model is None and self.pca_model is None:
            raise ValueError("请先调用build_reconstruction_model()构建重构模型")

        if self.data is None:
            raise ValueError("请先调用load_data()加载数据")

        print("正在进行自重构异常检测...")

        # Same preprocessing as at fit time (the scaler is reused).
        data_processed = self.preprocess_data(self.data)

        # Per-sample reconstruction error from the configured model.
        if self.config.reconstruction_method == 'linear':
            reconstruction_errors = self._compute_linear_reconstruction_errors(data_processed)
        elif self.config.reconstruction_method == 'pca':
            reconstruction_errors = self._compute_pca_reconstruction_errors(data_processed)

        # Map errors onto [0, 1] via the empirical CDF.
        anomaly_probabilities = self._errors_to_probabilities(reconstruction_errors)

        # Output mode: raw probabilities vs thresholded 0/1 labels.
        if self.config.probability_flag:
            # Probability mode: return the continuous values as-is.
            predictions = anomaly_probabilities
        else:
            # Classification mode: binary labels from the threshold.
            predictions = (anomaly_probabilities > self.config.threshold).astype(int)

        self.reconstruction_errors = reconstruction_errors
        self.anomaly_probabilities = anomaly_probabilities
        self.predictions = predictions

        # Reporting only: count samples above the threshold in either mode.
        if self.config.probability_flag:
            n_anomalies = np.sum(predictions > self.config.threshold)
            anomaly_ratio = n_anomalies / len(predictions) if len(predictions) > 0 else 0
        else:
            n_anomalies = np.sum(predictions == 1)
            anomaly_ratio = n_anomalies / len(predictions) if len(predictions) > 0 else 0

        print(f"异常检测完成:")
        print(f" - 样本数: {len(predictions)}")
        print(f" - 异常样本数: {n_anomalies}")
        print(f" - 异常比例: {anomaly_ratio:.3f}")

        return predictions, anomaly_probabilities
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run the full anomaly-detection pipeline described by the config.

    Fits the reconstruction model, then scores the loaded data with it.

    Returns:
        Tuple[np.ndarray, np.ndarray]: (predictions, scores)
    """
    # Model building must precede detection; detect_anomalies()
    # validates the rest of the preconditions itself.
    self.build_reconstruction_model()
    return self.detect_anomalies()
def _compute_linear_reconstruction_errors(self, data: np.ndarray) -> np.ndarray:
    """
    Compute per-sample linear reconstruction errors.

    The configured target band is predicted from all remaining bands by
    the fitted linear model; the absolute prediction residual is the
    reconstruction error.

    Args:
        data: preprocessed data array of shape (n_samples, n_bands)

    Returns:
        Absolute reconstruction error per sample (1-D array).
    """
    # Fall back to the last band when no target band is configured
    band_idx = -1 if self.config.target_band is None else self.config.target_band

    # Features: every band except the target; target: the band itself
    features = np.delete(data, band_idx, axis=1)
    target = data[:, band_idx]

    predicted = self.model.predict(features)

    # Absolute deviation between true and reconstructed target values
    return np.abs(target - predicted)
def _compute_pca_reconstruction_errors(self, data: np.ndarray) -> np.ndarray:
    """
    Compute per-sample PCA reconstruction errors.

    Samples are projected onto the retained principal components and
    reconstructed back; anomalies are represented poorly by those
    components and thus reconstruct with larger error.

    Args:
        data: preprocessed data array of shape (n_samples, n_bands)

    Returns:
        Mean squared reconstruction error per sample (1-D array).
    """
    # Forward projection followed by inverse mapping back to band space
    projected = self.pca_model.transform(data)
    restored = self.pca_model.inverse_transform(projected)

    # Per-sample MSE across all bands
    residual = data - restored
    return np.mean(residual ** 2, axis=1)
def _errors_to_probabilities(self, errors: np.ndarray) -> np.ndarray:
    """
    Map reconstruction errors to anomaly probabilities in [0, 1].

    Uses the empirical cumulative distribution function (ECDF) of the
    errors: each sample's probability is the fraction of samples whose
    error does not exceed its own, obtained by linear interpolation on
    the sorted error values.

    Args:
        errors: reconstruction error array

    Returns:
        Anomaly probability array of the same length as ``errors``.
    """
    n = len(errors)
    ordered = np.sort(errors)
    # ECDF values 1/n, 2/n, ..., 1 at the sorted error positions
    ecdf_values = np.arange(1, n + 1) / n

    # Interpolate each raw error onto the ECDF curve
    return np.interp(errors, ordered, ecdf_values)
def save_results(self, output_path: Optional[str] = None) -> None:
    """
    Save anomaly-detection results as a single-band ENVI file.

    Args:
        output_path: output file path; when None, falls back to the
            configured ``output_path`` or, if only ``output_dir`` is set,
            to an auto-generated name derived from the input file.

    Raises:
        ValueError: if no output location can be determined.
    """
    if output_path is None:
        if self.config.output_path:
            output_path = self.config.output_path
        elif self.config.output_dir:
            # BUG FIX: the config exposes `input_path` (see main() and the
            # config dataclass), not `test_path`/`train_path`; the old
            # attribute access raised AttributeError whenever only
            # output_dir was supplied.
            base_name = Path(self.config.input_path).stem
            output_path = os.path.join(self.config.output_dir, f"{base_name}_anomaly_squared_loss.bip")
        else:
            raise ValueError("必须指定输出路径")

    print(f"正在保存结果到: {output_path}")

    # Create the output directory only when the path has one;
    # os.makedirs('') raises FileNotFoundError for bare filenames.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Pick the writer matching the detected input format
    if self.input_format == 'envi':
        self._save_envi_results(output_path)
    else:
        # CSV input has no spatial layout: emit a synthetic 1xN image
        self._save_csv_results_as_envi(output_path)
def _save_envi_results(self, output_path: str) -> None:
    """Write the detection results as a single-band ENVI image.

    The flat prediction vector is mapped back onto the original
    (rows, cols) grid and stored as one float32 band.
    """
    rows, cols, _ = self.data_shape

    # Reshape the 1-D predictions into the original spatial layout and
    # add a trailing band axis so the writer sees (rows, cols, 1).
    result_plane = self.predictions.reshape(rows, cols).astype(np.float32)
    anomaly_image = result_plane[:, :, np.newaxis]

    # Delegate file writing (BIP interleave, single named band)
    self._save_envi_data(anomaly_image, output_path, ["Anomaly_Result"], 'bip')
def _save_csv_results_as_envi(self, output_path: str) -> None:
    """Write CSV-derived results as a 1xN single-band ENVI image.

    CSV input carries no spatial layout, so the N predictions are laid
    out as a single image row with one band.
    """
    flat = np.asarray(self.predictions, dtype=np.float32)
    anomaly_image = flat.reshape(1, -1, 1)

    # Same writer as the ENVI path: BIP interleave, single named band
    self._save_envi_data(anomaly_image, output_path, ["Anomaly_Result"], 'bip')
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
                    band_names: List[str], interleave: str = 'bip') -> None:
    """
    Save an array in ENVI format (binary data file + .hdr header).

    Args:
        data: array to save, shaped (height, width, channels)
        output_path: output file path for the binary data
        band_names: band name list, one entry per channel
        interleave: interleave layout ('bip', 'bil', 'bsq')
    """
    output_path = Path(output_path)

    # Ensure the target directory exists before writing
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")

    # Write the binary data file first
    self._save_dat_file(data, str(output_path))

    # BUG FIX: the old code used str(path).replace('.dat', '.hdr'),
    # which is a no-op for the usual '.bip' outputs, so the header text
    # overwrote the freshly written data file. Swap the extension
    # properly instead.
    hdr_path = str(output_path.with_suffix('.hdr'))
    self._save_hdr_file(hdr_path, data.shape, band_names, interleave)

    print(f"最小二乘异常检测结果已保存: {output_path}")
    print(f"头文件已保存: {hdr_path}")
    print(f"使用的交织格式: {interleave.upper()}")
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
    """
    Write the raw binary ENVI data file.

    Every array is cast to float32 before dumping, matching the
    ``data type = 4`` declaration in the header. The original dtype
    if/else had two byte-identical branches (dead code), so the check
    has been removed.
    """
    with open(file_path, 'wb') as f:
        data.astype(np.float32).tofile(f)
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
                   band_names: List[str], interleave: str) -> None:
    """Write the ENVI .hdr text header describing the saved data file.

    Args:
        hdr_path: path of the header file to write
        data_shape: (height, width, channels) of the saved array
        band_names: band names emitted as ENVI's brace-list block
        interleave: interleave layout recorded in the header
    """
    height, width, channels = data_shape

    # ENVI data type codes: 4=float32, 5=float64, 1=uint8, 2=int16,
    # 3=int32, 12=uint16. The data file is always written as float32.
    data_type_code = 4

    header_content = f"""ENVI
description = {{Squared Loss Anomaly Detection Results - Generated by SquaredLossAnomalyDetector}}
samples = {width}
lines = {height}
bands = {channels}
header offset = 0
file type = ENVI Standard
data type = {data_type_code}
interleave = {interleave}
byte order = 0
"""

    # Band names in ENVI brace-list syntax, comma-separated
    if band_names:
        header_content += "band names = {\n"
        for i, name in enumerate(band_names):
            header_content += f' "{name}"'
            if i < len(band_names) - 1:
                header_content += ","
            header_content += "\n"
        header_content += "}\n"

    # Custom metadata describing how the detection was configured,
    # so results remain reproducible from the header alone.
    header_content += f"anomaly_detection_method = self_reconstruction\n"
    header_content += f"reconstruction_method = {self.config.reconstruction_method}\n"
    if self.config.reconstruction_method == 'pca' and self.config.n_components is not None:
        header_content += f"n_components = {self.config.n_components}\n"
    if self.config.target_band is not None:
        header_content += f"target_band = {self.config.target_band}\n"
    header_content += f"probability_flag = {self.config.probability_flag}\n"
    if not self.config.probability_flag:
        header_content += f"threshold = {self.config.threshold}\n"
    header_content += f"use_scaling = {self.config.use_scaling}\n"
    header_content += f"scaler_type = {self.config.scaler_type}\n"
    header_content += f"fit_intercept = {self.config.fit_intercept}\n"

    with open(hdr_path, 'w', encoding='utf-8') as f:
        f.write(header_content)
def get_statistics(self) -> Dict[str, Any]:
    """
    Collect summary statistics of the last anomaly-detection run.

    Returns:
        Dictionary with the sample count, output mode, threshold (None
        in probability mode), reconstruction-error statistics, anomaly
        probability statistics, anomaly count and ratio.

    Raises:
        ValueError: if detection has not been run yet.
    """
    if self.predictions is None:
        raise ValueError("请先运行异常检测")

    errors = self.reconstruction_errors
    probs = self.anomaly_probabilities
    prob_mode = self.config.probability_flag

    stats: Dict[str, Any] = {
        'total_samples': len(self.predictions),
        'output_mode': 'probability' if prob_mode else 'classification',
        'threshold': None if prob_mode else self.config.threshold,
        'mean_reconstruction_error': float(np.mean(errors)),
        'std_reconstruction_error': float(np.std(errors)),
        'min_reconstruction_error': float(np.min(errors)),
        'max_reconstruction_error': float(np.max(errors)),
        'mean_anomaly_probability': float(np.mean(probs)),
        'std_anomaly_probability': float(np.std(probs)),
    }

    # Anomalies: above-threshold scores in probability mode,
    # exact 1-labels in classification mode.
    if prob_mode:
        anomaly_mask = self.predictions > self.config.threshold
    else:
        anomaly_mask = self.predictions == 1
    stats['n_anomalies'] = int(np.sum(anomaly_mask))
    stats['anomaly_ratio'] = float(np.sum(anomaly_mask) / len(self.predictions))

    return stats
def main():
    """Command-line entry point for squared-loss anomaly detection.

    Parses CLI arguments, builds a SquaredLossAnomalyConfig, runs the
    load -> build-model -> detect -> save pipeline and prints summary
    statistics.

    Returns:
        Process exit code: 0 on success, 1 on any error.
    """
    parser = argparse.ArgumentParser(
        description='基于最小二乘的高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
# 线性重构 + 概率模式
python squared_loss_probability.py --input data.csv --spectral-start 400 --reconstruction-method linear --probability --output results.bip

# PCA重构 + 分类模式
python squared_loss_probability.py --input image.bip --reconstruction-method pca --n-components 10 --threshold 0.8 --output results.bip

# 指定目标波段的线性重构
python squared_loss_probability.py --input data.csv --spectral-start 400 --target-band 50 --output results.bip
"""
    )

    # Input options
    parser.add_argument('--input', required=True, help='输入数据路径 (CSV或ENVI格式)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')

    # Output options
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # Self-reconstruction options
    parser.add_argument('--reconstruction-method', choices=['linear', 'pca'], default='linear',
                        help='重构方法 (linear: 线性回归, pca: PCA降维,默认linear)')
    parser.add_argument('--n-components', type=int,
                        help='PCA降维维度 (仅在reconstruction-method=pca时使用)')
    parser.add_argument('--target-band', type=int,
                        help='目标波段索引 (仅在reconstruction-method=linear时使用,默认使用最后一个波段)')

    # Detection options
    parser.add_argument('--probability', action='store_true', default=True,
                        help='概率模式输出 (0-1连续值,默认)')
    parser.add_argument('--classification', action='store_true',
                        help='分类模式输出 (0/1二元分类)')
    parser.add_argument('--threshold', type=float, default=0.8,
                        help='分类阈值 (分类模式时使用,默认0.8)')

    # Preprocessing options
    parser.add_argument('--no-scaling', action='store_true',
                        help='禁用数据标准化')
    parser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
                        help='标准化类型,默认robust')

    # Model options
    parser.add_argument('--no-intercept', action='store_true',
                        help='不拟合截距项')

    args = parser.parse_args()

    # Resolve the output mode: --classification overrides the
    # default-on --probability flag.
    if args.classification:
        probability_flag = False
    else:
        probability_flag = args.probability

    try:
        # Build the configuration object from the CLI arguments
        config = SquaredLossAnomalyConfig(
            input_path=args.input,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            reconstruction_method=args.reconstruction_method,
            n_components=args.n_components,
            target_band=args.target_band,
            probability_flag=probability_flag,
            threshold=args.threshold,
            use_scaling=not args.no_scaling,
            scaler_type=args.scaler_type,
            fit_intercept=not args.no_intercept
        )

        # Create the detector
        detector = SquaredLossAnomalyDetector(config)

        # Announce the run configuration
        print("=== 基于自重构的异常检测 ===")
        print(f"输入数据: {config.input_path}")
        print(f"重构方法: {config.reconstruction_method}")
        print(f"输出模式: {'概率' if config.probability_flag else '分类'}")
        if not config.probability_flag:
            print(f"分类阈值: {config.threshold}")
        print("-" * 50)

        # 1. Load data
        detector.load_data()

        # 2. Build the self-reconstruction model
        detector.build_reconstruction_model()

        # 3. Run the anomaly detection
        predictions, probabilities = detector.detect_anomalies()

        # 4. Save results
        detector.save_results()

        # 5. Show summary statistics
        stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in stats.items():
            if value is not None:
                print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via the
        # exit code instead of letting the traceback escape unformatted.
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the `exit` helper injected by the site
    # module, which is absent under `python -S` and in frozen builds.
    # Imported locally to keep the module header untouched.
    import sys
    sys.exit(main())
Reference in New Issue
Block a user