增加模块;增加主调用命令
This commit is contained in:
807
Anomaly_method/One_Class_SVM.py
Normal file
807
Anomaly_method/One_Class_SVM.py
Normal file
@ -0,0 +1,807 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spectral
|
||||
from sklearn.svm import OneClassSVM
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.model_selection import GridSearchCV, cross_val_score
|
||||
from sklearn.metrics import make_scorer, precision_score, recall_score, f1_score
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||||
from dataclasses import dataclass, field
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from scipy import stats
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
@dataclass
|
||||
class OneClassSVMConfig:
|
||||
"""One-Class SVM异常检测配置类"""
|
||||
|
||||
# 输入文件配置
|
||||
input_path: Optional[str] = None
|
||||
label_col: Optional[str] = None
|
||||
spectral_start: Optional[str] = None
|
||||
spectral_end: Optional[str] = None
|
||||
csv_has_header: bool = True
|
||||
|
||||
# 输出配置
|
||||
output_path: Optional[str] = None
|
||||
output_dir: Optional[str] = None
|
||||
|
||||
# SVM参数
|
||||
kernel: str = 'rbf' # 核函数: 'linear', 'rbf', 'poly'
|
||||
nu: float = 0.1 # 训练误差比例 (0.01-0.5)
|
||||
gamma: Optional[Union[str, float]] = 'auto' # 核函数参数: 'auto', 'scale', 或数值
|
||||
degree: int = 3 # 多项式核的阶数
|
||||
coef0: float = 0.0 # 核函数常数项
|
||||
tol: float = 1e-3 # 停止准则公差
|
||||
shrinking: bool = True # 是否使用shrinking heuristic
|
||||
cache_size: float = 200 # 核缓存大小(MB)
|
||||
max_iter: int = -1 # 最大迭代次数
|
||||
|
||||
# 数据预处理配置
|
||||
use_scaling: bool = True # 是否使用标准化
|
||||
use_pca: bool = True # 是否使用PCA降维
|
||||
n_components: Optional[int] = None # PCA降维维度,None表示自动选择
|
||||
|
||||
# 参数调优配置
|
||||
use_grid_search: bool = False # 是否使用网格搜索调优参数
|
||||
cv_folds: int = 3 # 交叉验证折数
|
||||
|
||||
def __post_init__(self):
|
||||
"""参数校验和默认值设置"""
|
||||
# 校验必需的文件路径
|
||||
if not self.input_path:
|
||||
raise ValueError("必须指定输入文件路径(input_path)")
|
||||
|
||||
# 校验文件存在性
|
||||
if not os.path.exists(self.input_path):
|
||||
raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
|
||||
|
||||
# 校验输出路径
|
||||
if not self.output_path and not self.output_dir:
|
||||
raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")
|
||||
|
||||
# 校验参数范围
|
||||
if not 0.01 <= self.nu <= 0.5:
|
||||
raise ValueError("nu必须在[0.01, 0.5]范围内")
|
||||
|
||||
# 统一核函数名为小写
|
||||
self.kernel = self.kernel.lower()
|
||||
|
||||
# 校验核函数类型
|
||||
supported_kernels = ['linear', 'rbf', 'poly']
|
||||
if self.kernel not in supported_kernels:
|
||||
raise ValueError(f"不支持的核函数: {self.kernel}。支持的核函数: {supported_kernels}")
|
||||
|
||||
# 校验gamma参数
|
||||
if isinstance(self.gamma, str):
|
||||
if self.gamma not in ['auto', 'scale']:
|
||||
raise ValueError("gamma字符串值必须是'auto'或'scale'")
|
||||
elif self.gamma is not None and self.gamma <= 0:
|
||||
raise ValueError("gamma数值必须大于0")
|
||||
|
||||
# 校验多项式阶数
|
||||
if self.degree < 1:
|
||||
raise ValueError("degree必须大于等于1")
|
||||
|
||||
|
||||
class OneClassSVMAnomalyDetector:
|
||||
"""
|
||||
One-Class SVM异常检测器
|
||||
|
||||
使用One-Class SVM算法进行无监督异常检测,只需要正常数据进行训练。
|
||||
支持多种核函数,能够学习复杂的数据分布边界。
|
||||
|
||||
算法原理:
|
||||
- 使用正常数据训练SVM,学习正常样本的决策边界
|
||||
- 测试样本落在边界内为正常(-1),落在边界外为异常(+1)
|
||||
"""
|
||||
|
||||
def __init__(self, config: OneClassSVMConfig):
|
||||
"""
|
||||
初始化异常检测器
|
||||
|
||||
Args:
|
||||
config: 配置对象
|
||||
"""
|
||||
self.config = config
|
||||
self.data = None
|
||||
self.labels = None
|
||||
self.wavelengths = None
|
||||
self.data_shape = None
|
||||
self.input_format = None
|
||||
|
||||
# 预处理组件
|
||||
self.scaler = None
|
||||
self.pca = None
|
||||
|
||||
# One-Class SVM模型
|
||||
self.svm_model = None
|
||||
|
||||
# 结果
|
||||
self.decision_scores = None # 决策函数值
|
||||
self.predictions = None # 预测结果 (+1正常, -1异常)
|
||||
|
||||
# 最佳参数(网格搜索后)
|
||||
self.best_params = None
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
加载高光谱数据文件
|
||||
|
||||
支持CSV和ENVI格式文件
|
||||
"""
|
||||
file_path = Path(self.config.input_path)
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
print(f"正在读取文件: {file_path}")
|
||||
|
||||
if suffix == '.csv':
|
||||
self._load_csv_data(file_path)
|
||||
elif suffix in ['.hdr']:
|
||||
self._load_envi_data(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||||
|
||||
def _load_csv_data(self, file_path: Path) -> None:
|
||||
"""加载CSV格式数据"""
|
||||
if self.config.spectral_start is None:
|
||||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||||
|
||||
self.input_format = 'csv'
|
||||
|
||||
# 读取数据
|
||||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||||
|
||||
# 提取标签列
|
||||
if self.config.label_col and self.config.label_col in df.columns:
|
||||
self.labels = df[self.config.label_col].values
|
||||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||||
else:
|
||||
self.labels = None
|
||||
spectral_cols = df.columns.tolist()
|
||||
|
||||
# 提取光谱数据
|
||||
if self.config.spectral_end:
|
||||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||||
else:
|
||||
end_idx = len(spectral_cols)
|
||||
|
||||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||||
|
||||
# 生成波长信息
|
||||
self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||||
|
||||
self.data = spectral_data.astype(np.float32)
|
||||
self.data_shape = self.data.shape
|
||||
|
||||
print(f"CSV数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||||
|
||||
def _load_envi_data(self, file_path: Path) -> None:
|
||||
"""加载ENVI格式数据"""
|
||||
self.input_format = 'envi'
|
||||
|
||||
try:
|
||||
# 确保文件路径存在且可访问
|
||||
file_path = Path(file_path)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
|
||||
|
||||
# 检查文件是否可读
|
||||
if not os.access(file_path, os.R_OK):
|
||||
raise PermissionError(f"没有读取权限: {file_path}")
|
||||
|
||||
# 使用规范化的路径字符串
|
||||
file_path_str = str(file_path)
|
||||
print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")
|
||||
|
||||
# 检查对应的头文件
|
||||
hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
|
||||
if not hdr_path.exists():
|
||||
# 尝试 .hdr 格式
|
||||
hdr_path = file_path.with_suffix('.hdr')
|
||||
if not hdr_path.exists():
|
||||
print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")
|
||||
|
||||
# 读取ENVI文件 - 使用规范化的路径
|
||||
envi_data = spectral.open_image(file_path_str)
|
||||
|
||||
# 获取数据数组
|
||||
print("正在加载数据到内存...")
|
||||
data_array = envi_data.load()
|
||||
|
||||
# 获取波长信息
|
||||
if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
|
||||
self.wavelengths = np.array(envi_data.bands.centers)
|
||||
print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
|
||||
else:
|
||||
self.wavelengths = np.arange(data_array.shape[2])
|
||||
print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")
|
||||
|
||||
# 重塑数据为 (samples, bands)
|
||||
rows, cols, bands = data_array.shape
|
||||
self.data = data_array.reshape(-1, bands).astype(np.float32)
|
||||
self.data_shape = (rows, cols, bands)
|
||||
|
||||
print(f"ENVI数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||||
|
||||
def preprocess_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Args:
|
||||
data: 输入数据数组
|
||||
|
||||
Returns:
|
||||
处理后的数据数组
|
||||
"""
|
||||
print("正在预处理数据...")
|
||||
|
||||
# 检查数据有效性
|
||||
if data is None:
|
||||
raise ValueError("请先调用load_data()加载数据")
|
||||
|
||||
# 数据清理:移除NaN和无穷大值
|
||||
data_clean = self._clean_data(data)
|
||||
|
||||
# 标准化
|
||||
if self.config.use_scaling:
|
||||
data_scaled = self._scale_data(data_clean)
|
||||
else:
|
||||
data_scaled = data_clean
|
||||
|
||||
# PCA降维
|
||||
if self.config.use_pca:
|
||||
data_processed = self._apply_pca(data_scaled)
|
||||
else:
|
||||
data_processed = data_scaled
|
||||
|
||||
print(f"数据预处理完成: {data_processed.shape}")
|
||||
return data_processed
|
||||
|
||||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""清理数据:移除无效值"""
|
||||
# 移除包含NaN或无穷大值的行
|
||||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||||
data_clean = data[valid_mask]
|
||||
|
||||
if np.sum(~valid_mask) > 0:
|
||||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||||
|
||||
# 移除全零行
|
||||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||||
data_clean = data_clean[non_zero_mask]
|
||||
|
||||
if np.sum(~non_zero_mask) > 0:
|
||||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||||
|
||||
return data_clean
|
||||
|
||||
def _scale_data(self, data: np.ndarray) -> np.ndarray:
|
||||
"""标准化数据"""
|
||||
if self.scaler is None:
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
data_scaled = self.scaler.fit_transform(data)
|
||||
return data_scaled
|
||||
|
||||
def _apply_pca(self, data: np.ndarray) -> np.ndarray:
|
||||
"""应用PCA降维"""
|
||||
n_samples, n_features = data.shape
|
||||
|
||||
# 确定PCA维度
|
||||
if self.config.n_components is None:
|
||||
# 自动选择维度:保留95%的方差或最小维度
|
||||
n_components = min(n_features, max(2, int(n_features * 0.95)))
|
||||
else:
|
||||
n_components = min(self.config.n_components, n_features)
|
||||
|
||||
if n_features <= n_components:
|
||||
print(f"特征数({n_features})小于等于目标维度({n_components}),跳过PCA降维")
|
||||
return data
|
||||
|
||||
if self.pca is None:
|
||||
self.pca = PCA(n_components=n_components, random_state=42)
|
||||
|
||||
data_pca = self.pca.fit_transform(data)
|
||||
|
||||
explained_var = np.sum(self.pca.explained_variance_ratio_)
|
||||
print(f"PCA降维完成: {n_features} -> {n_components} 维度,解释方差: {explained_var:.3f}")
|
||||
|
||||
return data_pca
|
||||
|
||||
def tune_parameters(self, data: np.ndarray) -> Dict[str, Any]:
|
||||
"""
|
||||
使用网格搜索调优参数
|
||||
|
||||
Args:
|
||||
data: 训练数据
|
||||
|
||||
Returns:
|
||||
最佳参数字典
|
||||
"""
|
||||
print("正在进行参数调优...")
|
||||
|
||||
# 定义参数网格
|
||||
param_grid = {
|
||||
'nu': [0.01, 0.05, 0.1, 0.15, 0.2]
|
||||
}
|
||||
|
||||
if self.config.kernel == 'rbf':
|
||||
param_grid['gamma'] = ['auto', 'scale', 0.001, 0.01, 0.1, 1.0]
|
||||
elif self.config.kernel == 'poly':
|
||||
param_grid['gamma'] = ['auto', 'scale', 0.001, 0.01, 0.1, 1.0]
|
||||
param_grid['degree'] = [2, 3, 4]
|
||||
param_grid['coef0'] = [0.0, 0.1, 1.0]
|
||||
|
||||
# 创建基础模型
|
||||
base_svm = OneClassSVM(
|
||||
kernel=self.config.kernel,
|
||||
degree=self.config.degree,
|
||||
coef0=self.config.coef0,
|
||||
tol=self.config.tol,
|
||||
shrinking=self.config.shrinking,
|
||||
cache_size=self.config.cache_size,
|
||||
max_iter=self.config.max_iter
|
||||
)
|
||||
|
||||
# 对于One-Class SVM,我们使用决策函数的负值作为评分
|
||||
# (决策函数值越负表示越可能是异常)
|
||||
def anomaly_score(estimator, X):
|
||||
scores = estimator.decision_function(X)
|
||||
return -np.mean(scores) # 负的平均决策函数值
|
||||
|
||||
scorer = make_scorer(anomaly_score)
|
||||
|
||||
# 网格搜索
|
||||
grid_search = GridSearchCV(
|
||||
base_svm,
|
||||
param_grid,
|
||||
cv=self.config.cv_folds,
|
||||
scoring=scorer,
|
||||
n_jobs=-1,
|
||||
verbose=1
|
||||
)
|
||||
|
||||
grid_search.fit(data)
|
||||
|
||||
self.best_params = grid_search.best_params_
|
||||
print(f"最佳参数: {self.best_params}")
|
||||
print(f"最佳评分: {grid_search.best_score_:.4f}")
|
||||
|
||||
return self.best_params
|
||||
|
||||
def train_model(self, data: np.ndarray) -> None:
|
||||
"""
|
||||
训练One-Class SVM模型
|
||||
|
||||
Args:
|
||||
data: 训练数据(正常样本)
|
||||
"""
|
||||
print("正在训练One-Class SVM模型...")
|
||||
|
||||
# 参数调优
|
||||
if self.config.use_grid_search:
|
||||
best_params = self.tune_parameters(data)
|
||||
# 更新配置参数
|
||||
for param, value in best_params.items():
|
||||
setattr(self.config, param, value)
|
||||
|
||||
# 设置gamma参数
|
||||
gamma_value = self.config.gamma
|
||||
if self.config.gamma == 'auto':
|
||||
gamma_value = 1.0 / data.shape[1] # 1/特征数
|
||||
elif self.config.gamma == 'scale':
|
||||
gamma_value = 1.0 / (data.shape[1] * data.var())
|
||||
elif self.config.gamma is None:
|
||||
gamma_value = 'scale' # 默认使用'scale'
|
||||
|
||||
# 创建One-Class SVM模型
|
||||
self.svm_model = OneClassSVM(
|
||||
kernel=self.config.kernel,
|
||||
nu=self.config.nu,
|
||||
gamma=gamma_value,
|
||||
degree=self.config.degree,
|
||||
coef0=self.config.coef0,
|
||||
tol=self.config.tol,
|
||||
shrinking=self.config.shrinking,
|
||||
cache_size=self.config.cache_size,
|
||||
max_iter=self.config.max_iter
|
||||
)
|
||||
|
||||
# 训练模型
|
||||
self.svm_model.fit(data)
|
||||
|
||||
# 计算训练集上的决策函数值
|
||||
train_scores = self.svm_model.decision_function(data)
|
||||
train_predictions = self.svm_model.predict(data)
|
||||
|
||||
n_outliers_train = np.sum(train_predictions == -1)
|
||||
outlier_ratio_train = n_outliers_train / len(data)
|
||||
|
||||
print("模型训练完成:")
|
||||
print(f" - 核函数: {self.config.kernel}")
|
||||
print(f" - nu: {self.config.nu}")
|
||||
print(f" - gamma: {gamma_value}")
|
||||
if self.config.kernel == 'poly':
|
||||
print(f" - degree: {self.config.degree}")
|
||||
print(f" - 训练集异常比例: {outlier_ratio_train:.3f}")
|
||||
print(".4f")
|
||||
|
||||
def detect_anomalies(self, data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
检测异常
|
||||
|
||||
Args:
|
||||
data: 测试数据,如果为None则使用训练数据
|
||||
|
||||
Returns:
|
||||
predictions: 预测结果 (+1为正常,-1为异常)
|
||||
scores: 决策函数值
|
||||
"""
|
||||
if self.svm_model is None:
|
||||
raise ValueError("请先调用train_model()训练模型")
|
||||
|
||||
if data is None:
|
||||
data = self.data
|
||||
|
||||
print("正在检测异常...")
|
||||
|
||||
# 预处理数据(使用训练时的预处理器)
|
||||
data_processed = self.preprocess_data(data)
|
||||
|
||||
# 预测异常
|
||||
predictions = self.svm_model.predict(data_processed)
|
||||
scores = self.svm_model.decision_function(data_processed)
|
||||
|
||||
self.predictions = predictions
|
||||
self.decision_scores = scores
|
||||
|
||||
# 统计信息
|
||||
n_anomalies = np.sum(predictions == -1)
|
||||
anomaly_ratio = n_anomalies / len(predictions)
|
||||
|
||||
print("异常检测完成:")
|
||||
print(f" - 测试样本数: {len(predictions)}")
|
||||
print(f" - 异常样本数: {n_anomalies}")
|
||||
print(f" - 异常比例: {anomaly_ratio:.3f}")
|
||||
print(".4f")
|
||||
print(".4f")
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
从配置运行完整的异常检测分析
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||||
"""
|
||||
# 预处理数据
|
||||
processed_data = self.preprocess_data(self.data)
|
||||
|
||||
# 训练模型
|
||||
self.train_model(processed_data)
|
||||
|
||||
# 执行异常检测
|
||||
predictions, scores = self.detect_anomalies()
|
||||
|
||||
return predictions, scores
|
||||
|
||||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||||
"""
|
||||
保存异常检测结果为ENVI格式文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||||
"""
|
||||
if output_path is None:
|
||||
if self.config.output_path:
|
||||
output_path = self.config.output_path
|
||||
elif self.config.output_dir:
|
||||
base_name = Path(self.config.input_path).stem
|
||||
output_path = os.path.join(self.config.output_dir, f"{base_name}_anomaly_ocsvm.bip")
|
||||
else:
|
||||
raise ValueError("必须指定输出路径")
|
||||
|
||||
print(f"正在保存结果到: {output_path}")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = os.path.dirname(output_path)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 根据输入格式决定输出格式
|
||||
if self.input_format == 'envi':
|
||||
self._save_envi_results(output_path)
|
||||
else:
|
||||
# 对于CSV输入,创建虚拟的2D图像
|
||||
self._save_csv_results_as_envi(output_path)
|
||||
|
||||
def _save_envi_results(self, output_path: str) -> None:
|
||||
"""保存ENVI格式结果"""
|
||||
# 重塑预测结果为原始图像尺寸
|
||||
rows, cols, bands = self.data_shape
|
||||
|
||||
# 创建异常检测结果图像
|
||||
# 结果包括:预测结果、决策函数值
|
||||
anomaly_image = np.zeros((rows, cols, 2), dtype=np.float32)
|
||||
|
||||
# 需要将1D结果映射回2D图像
|
||||
predictions_2d = self.predictions.reshape(rows, cols)
|
||||
scores_2d = self.decision_scores.reshape(rows, cols)
|
||||
|
||||
anomaly_image[:, :, 0] = predictions_2d # 预测结果 (-1, 1)
|
||||
anomaly_image[:, :, 1] = scores_2d # 决策函数值
|
||||
|
||||
# 波段名称
|
||||
band_names = ["Anomaly_Prediction", "Decision_Score"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||||
"""将CSV结果保存为ENVI格式"""
|
||||
# 对于CSV数据,创建1xN的图像
|
||||
n_samples = len(self.predictions)
|
||||
anomaly_image = np.zeros((1, n_samples, 2), dtype=np.float32)
|
||||
|
||||
anomaly_image[0, :, 0] = self.predictions # 预测结果
|
||||
anomaly_image[0, :, 1] = self.decision_scores # 决策函数值
|
||||
|
||||
# 波段名称
|
||||
band_names = ["Anomaly_Prediction", "Decision_Score"]
|
||||
|
||||
# 保存为ENVI格式,参考Covariance.py的方式
|
||||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||||
|
||||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||||
band_names: List[str], interleave: str = 'bip') -> None:
|
||||
"""
|
||||
保存数据为ENVI格式,参考Covariance.py的实现
|
||||
|
||||
参数:
|
||||
data: 要保存的数据数组 (height, width, channels)
|
||||
output_path: 输出文件路径
|
||||
band_names: 波段名称列表
|
||||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||||
"""
|
||||
output_path = Path(output_path)
|
||||
height, width, channels = data.shape
|
||||
|
||||
# 确保目录存在
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||||
|
||||
# 保存.dat文件(二进制格式)
|
||||
self._save_dat_file(data, str(output_path))
|
||||
|
||||
# 保存.hdr头文件
|
||||
hdr_path = str(output_path).replace('.dat', '.hdr')
|
||||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||||
|
||||
print(f"One-Class SVM异常检测结果已保存: {output_path}")
|
||||
print(f"头文件已保存: {hdr_path}")
|
||||
print(f"使用的交织格式: {interleave.upper()}")
|
||||
|
||||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||||
"""保存.dat文件(二进制格式),参考Covariance.py"""
|
||||
# 根据数据类型保存
|
||||
if data.dtype == np.float32:
|
||||
# 对于浮点数据,直接保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
else:
|
||||
# 对于其他数据类型,转换为float32保存
|
||||
with open(file_path, 'wb') as f:
|
||||
data.astype(np.float32).tofile(f)
|
||||
|
||||
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
|
||||
band_names: List[str], interleave: str) -> None:
|
||||
"""保存ENVI头文件,参考Covariance.py并适配One-Class SVM数据"""
|
||||
height, width, channels = data_shape
|
||||
|
||||
# 确定数据类型编码
|
||||
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
|
||||
data_type_code = 4 # 默认float32
|
||||
|
||||
header_content = f"""ENVI
|
||||
description = {{One-Class SVM Anomaly Detection Results - Generated by OneClassSVMAnomalyDetector}}
|
||||
samples = {width}
|
||||
lines = {height}
|
||||
bands = {channels}
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = {data_type_code}
|
||||
interleave = {interleave}
|
||||
byte order = 0
|
||||
"""
|
||||
|
||||
# 添加波段名称
|
||||
if band_names:
|
||||
header_content += "band names = {\n"
|
||||
for i, name in enumerate(band_names):
|
||||
header_content += f' "{name}"'
|
||||
if i < len(band_names) - 1:
|
||||
header_content += ","
|
||||
header_content += "\n"
|
||||
header_content += "}\n"
|
||||
|
||||
# 添加One-Class SVM相关元数据
|
||||
header_content += f"anomaly_detection_method = one_class_svm\n"
|
||||
header_content += f"kernel = {self.config.kernel}\n"
|
||||
header_content += f"nu = {self.config.nu}\n"
|
||||
if self.config.gamma is not None:
|
||||
header_content += f"gamma = {self.config.gamma}\n"
|
||||
if self.config.kernel == 'poly':
|
||||
header_content += f"degree = {self.config.degree}\n"
|
||||
header_content += f"coef0 = {self.config.coef0}\n"
|
||||
header_content += f"use_pca = {self.config.use_pca}\n"
|
||||
if self.config.use_pca and self.config.n_components is not None:
|
||||
header_content += f"n_components = {self.config.n_components}\n"
|
||||
if self.best_params:
|
||||
header_content += f"best_params = {self.best_params}\n"
|
||||
|
||||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||||
f.write(header_content)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取异常检测统计信息
|
||||
|
||||
Returns:
|
||||
包含各种统计信息的字典
|
||||
"""
|
||||
if self.predictions is None:
|
||||
raise ValueError("请先运行异常检测")
|
||||
|
||||
stats = {
|
||||
'total_samples': len(self.predictions),
|
||||
'n_anomalies': int(np.sum(self.predictions == -1)),
|
||||
'anomaly_ratio': float(np.sum(self.predictions == -1) / len(self.predictions)),
|
||||
'kernel': self.config.kernel,
|
||||
'nu': self.config.nu,
|
||||
'gamma': self.config.gamma,
|
||||
'mean_decision_score': float(np.mean(self.decision_scores)),
|
||||
'std_decision_score': float(np.std(self.decision_scores)),
|
||||
'min_decision_score': float(np.min(self.decision_scores)),
|
||||
'max_decision_score': float(np.max(self.decision_scores))
|
||||
}
|
||||
|
||||
if self.config.kernel == 'poly':
|
||||
stats['degree'] = self.config.degree
|
||||
|
||||
if self.best_params:
|
||||
stats['best_params'] = self.best_params
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数:命令行接口"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='One-Class SVM高光谱异常检测',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
使用示例:
|
||||
python One_Class_SVM.py --input data.csv --spectral-start 400 --kernel rbf --nu 0.1 --output results.bip
|
||||
python One_Class_SVM.py --input image.bip --kernel poly --degree 3 --grid-search --output results.bip
|
||||
python One_Class_SVM.py --input data.csv --spectral-start 400 --kernel linear --no-pca --output results.bip
|
||||
"""
|
||||
)
|
||||
|
||||
# 输入文件参数
|
||||
parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
|
||||
parser.add_argument('--label-col', help='标签列名 (CSV文件)')
|
||||
parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
|
||||
parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
|
||||
|
||||
# 输出参数
|
||||
parser.add_argument('--output', help='输出文件路径')
|
||||
parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')
|
||||
|
||||
# SVM参数
|
||||
parser.add_argument('--kernel', choices=['linear', 'rbf', 'poly'], default='rbf',
|
||||
help='核函数类型,默认rbf')
|
||||
parser.add_argument('--nu', type=float, default=0.1,
|
||||
help='训练误差比例 (0.01-0.5),默认0.1')
|
||||
parser.add_argument('--gamma', type=float,
|
||||
help='核函数gamma参数,默认auto (1/特征数)')
|
||||
parser.add_argument('--degree', type=int, default=3,
|
||||
help='多项式核的阶数,默认3')
|
||||
parser.add_argument('--coef0', type=float, default=0.0,
|
||||
help='核函数常数项,默认0.0')
|
||||
|
||||
# 预处理参数
|
||||
parser.add_argument('--no-scaling', action='store_true',
|
||||
help='禁用数据标准化')
|
||||
parser.add_argument('--no-pca', action='store_true',
|
||||
help='禁用PCA降维')
|
||||
parser.add_argument('--n-components', type=int,
|
||||
help='PCA降维维度,默认自动选择')
|
||||
|
||||
# 参数调优
|
||||
parser.add_argument('--grid-search', action='store_true',
|
||||
help='使用网格搜索调优参数')
|
||||
parser.add_argument('--cv-folds', type=int, default=3,
|
||||
help='交叉验证折数,默认3')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# 创建配置
|
||||
config = OneClassSVMConfig(
|
||||
input_path=args.input,
|
||||
label_col=args.label_col,
|
||||
spectral_start=args.spectral_start,
|
||||
spectral_end=args.spectral_end,
|
||||
output_path=args.output,
|
||||
output_dir=args.output_dir,
|
||||
kernel=args.kernel,
|
||||
nu=args.nu,
|
||||
gamma=args.gamma,
|
||||
degree=args.degree,
|
||||
coef0=args.coef0,
|
||||
use_scaling=not args.no_scaling,
|
||||
use_pca=not args.no_pca,
|
||||
n_components=args.n_components,
|
||||
use_grid_search=args.grid_search,
|
||||
cv_folds=args.cv_folds
|
||||
)
|
||||
|
||||
# 创建检测器
|
||||
detector = OneClassSVMAnomalyDetector(config)
|
||||
|
||||
# 执行异常检测流程
|
||||
print("=== One-Class SVM异常检测 ===")
|
||||
print(f"输入文件: {config.input_path}")
|
||||
print(f"核函数: {config.kernel}")
|
||||
print(f"nu: {config.nu}")
|
||||
print(f"gamma: {'auto' if config.gamma is None else config.gamma}")
|
||||
if config.kernel == 'poly':
|
||||
print(f"degree: {config.degree}")
|
||||
print(f"参数调优: {'启用' if config.use_grid_search else '禁用'}")
|
||||
print("-" * 50)
|
||||
|
||||
# 1. 加载数据
|
||||
detector.load_data()
|
||||
|
||||
# 2. 预处理数据
|
||||
processed_data = detector.preprocess_data(detector.data)
|
||||
|
||||
# 3. 训练模型
|
||||
detector.train_model(processed_data)
|
||||
|
||||
# 4. 执行异常检测
|
||||
predictions, scores = detector.detect_anomalies(processed_data)
|
||||
|
||||
# 5. 保存结果
|
||||
detector.save_results()
|
||||
|
||||
# 6. 显示统计信息
|
||||
stats = detector.get_statistics()
|
||||
print("\n=== 检测结果统计 ===")
|
||||
for key, value in stats.items():
|
||||
print(f"{key}: {value}")
|
||||
|
||||
print("\n处理完成!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user