# (extraction residue removed: file metadata "824 lines, 30 KiB, Python")
import numpy as np
|
||
import pandas as pd
|
||
import spectral
|
||
from sklearn.linear_model import LinearRegression
|
||
from sklearn.preprocessing import StandardScaler, RobustScaler
|
||
from sklearn.metrics import mean_squared_error
|
||
import os
|
||
import warnings
|
||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||
from dataclasses import dataclass, field
|
||
import argparse
|
||
from pathlib import Path
|
||
from scipy import stats
|
||
# Globally silence all warnings (sklearn/spectral emit noisy deprecation
# warnings).  NOTE(review): this also hides warnings from downstream user
# code importing this module — consider scoping to specific categories.
warnings.filterwarnings('ignore')
@dataclass
class SquaredLossAnomalyConfig:
    """Configuration for least-squares self-reconstruction anomaly detection.

    Validates paths and option values in __post_init__ and normalizes the
    string options to lower case so they are accepted case-insensitively.
    """

    # --- input ---
    input_path: Optional[str] = None      # path to input data (.csv or ENVI .hdr)
    label_col: Optional[str] = None       # label column name (CSV only)
    spectral_start: Optional[str] = None  # first spectral column name (CSV only)
    spectral_end: Optional[str] = None    # last spectral column name (CSV only)
    csv_has_header: bool = True           # whether the CSV has a header row

    # --- output ---
    output_path: Optional[str] = None     # explicit output file path
    output_dir: Optional[str] = None      # output directory (name auto-generated)

    # --- self-reconstruction ---
    reconstruction_method: str = 'linear'  # 'linear' or 'pca'
    n_components: Optional[int] = None     # PCA dimensionality (pca method only)
    target_band: Optional[int] = None      # target band index (linear method; None = last band)

    # --- detection ---
    probability_flag: bool = True  # True: probability output, False: 0/1 classification
    threshold: float = 0.8         # classification threshold (used when probability_flag=False)
    use_scaling: bool = True       # standardize data before modelling
    scaler_type: str = 'robust'    # 'standard' or 'robust'

    # --- model ---
    fit_intercept: bool = True     # fit an intercept in the linear model
    normalize: bool = False        # deprecated; kept only for backward compatibility

    def __post_init__(self):
        """Validate parameters and normalize string options.

        Raises:
            ValueError: missing input path, unknown method/scaler, bad threshold.
            FileNotFoundError: input_path does not exist.
        """
        if not self.input_path:
            raise ValueError("必须指定输入数据路径(input_path)")

        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # BUG FIX: normalize case BEFORE validating.  The original lower-cased
        # reconstruction_method only after the membership check, so e.g.
        # 'Linear' was rejected even though scaler_type was matched
        # case-insensitively.
        self.reconstruction_method = self.reconstruction_method.lower()
        self.scaler_type = self.scaler_type.lower()

        supported_methods = ['linear', 'pca']
        if self.reconstruction_method not in supported_methods:
            raise ValueError(f"不支持的重构方法: {self.reconstruction_method}。支持的方法: {supported_methods}")

        if not 0 <= self.threshold <= 1:
            raise ValueError("threshold必须在[0, 1]范围内")

        if self.scaler_type not in ['standard', 'robust']:
            raise ValueError(f"不支持的标准化类型: {self.scaler_type}。支持的类型: ['standard', 'robust']")
class SquaredLossAnomalyDetector:
    """
    Least-squares self-reconstruction anomaly detector.

    Algorithm:
        1. Learn the data's normal pattern with a self-reconstruction model.
        2. Use each sample's reconstruction error as its anomaly score.
        3. Normalize errors to [0, 1] anomaly probabilities (empirical CDF).
        4. Optionally binarize with a threshold.

    Supported reconstruction methods:
        - linear: linear regression predicting one band from the others
        - pca:    PCA dimensionality-reduction reconstruction

    Reference: unsupervised anomaly detection based on reconstruction error.
    """

    def __init__(self, config: SquaredLossAnomalyConfig):
        """
        Initialize the detector.

        Args:
            config: validated configuration object
        """
        self.config = config

        # Input data
        self.data = None          # (n_samples, n_bands) float32 spectra
        self.labels = None        # optional labels (CSV input only)
        self.wavelengths = None   # per-band wavelengths, or band indices
        self.data_shape = None    # (rows, cols, bands) for ENVI; (n, bands) for CSV
        self.input_format = None  # 'csv' or 'envi'

        # Preprocessing
        self.scaler = None        # lazily created StandardScaler/RobustScaler

        # Reconstruction models
        self.model = None         # LinearRegression ('linear' method)
        self.pca_model = None     # PCA ('pca' method)

        # Results (populated by detect_anomalies)
        self.reconstruction_errors = None
        self.anomaly_probabilities = None
        self.predictions = None
def load_data(self) -> None:
|
||
"""
|
||
加载高光谱数据文件
|
||
|
||
支持CSV和ENVI格式文件
|
||
"""
|
||
file_path = Path(self.config.input_path)
|
||
suffix = file_path.suffix.lower()
|
||
|
||
print(f"正在读取数据: {file_path}")
|
||
|
||
if suffix == '.csv':
|
||
self._load_csv_data(file_path)
|
||
elif suffix in ['.hdr']:
|
||
self._load_envi_data(file_path)
|
||
else:
|
||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||
|
||
def _load_csv_data(self, file_path: Path) -> None:
|
||
"""加载CSV格式数据"""
|
||
if self.config.spectral_start is None:
|
||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||
|
||
self.input_format = 'csv'
|
||
|
||
# 读取数据
|
||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||
|
||
# 提取标签列
|
||
if self.config.label_col and self.config.label_col in df.columns:
|
||
self.labels = df[self.config.label_col].values
|
||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||
else:
|
||
self.labels = None
|
||
spectral_cols = df.columns.tolist()
|
||
|
||
# 提取光谱数据
|
||
if self.config.spectral_end:
|
||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||
else:
|
||
end_idx = len(spectral_cols)
|
||
|
||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||
|
||
# 生成波长信息
|
||
self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||
|
||
self.data = spectral_data.astype(np.float32)
|
||
self.data_shape = self.data.shape
|
||
|
||
print(f"数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||
|
||
    def _load_envi_data(self, file_path: Path) -> None:
        """
        Load an ENVI image via the spectral library.

        Fills self.data (pixels flattened to (rows*cols, bands) float32),
        self.wavelengths, and self.data_shape.  Any failure is re-raised
        as ValueError.
        """
        self.input_format = 'envi'

        try:
            # Pre-check existence/readability ourselves: the spectral
            # library's own errors are less descriptive.
            file_path = Path(file_path)
            if not file_path.exists():
                raise FileNotFoundError(f"ENVI文件不存在: {file_path}")

            if not os.access(file_path, os.R_OK):
                raise PermissionError(f"没有读取权限: {file_path}")

            # Hand a normalized path string to the spectral library.
            file_path_str = str(file_path)
            print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")

            # Look for the companion header: "name.ext.hdr" first,
            # then "name.hdr"; warn (but continue) if neither exists.
            hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
            if not hdr_path.exists():
                hdr_path = file_path.with_suffix('.hdr')
                if not hdr_path.exists():
                    print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")

            envi_data = spectral.open_image(file_path_str)

            # Materialize the whole image in memory.
            print("正在加载数据到内存...")
            data_array = envi_data.load()

            # Band-center wavelengths if the header provides them;
            # otherwise fall back to band indices.
            if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
                self.wavelengths = np.array(envi_data.bands.centers)
                print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
            else:
                self.wavelengths = np.arange(data_array.shape[2])
                print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")

            # Flatten (rows, cols, bands) -> (rows*cols, bands) for modelling;
            # keep the original 3-D shape so results can be written back out.
            rows, cols, bands = data_array.shape
            self.data = data_array.reshape(-1, bands).astype(np.float32)
            self.data_shape = (rows, cols, bands)

            print(f"数据加载完成: {rows}x{cols} 像素 x {bands} 波段")

        except Exception as e:
            raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
def preprocess_data(self, data: np.ndarray) -> np.ndarray:
|
||
"""
|
||
数据预处理
|
||
|
||
Args:
|
||
data: 输入数据
|
||
|
||
Returns:
|
||
处理后的数据
|
||
"""
|
||
# 数据清理
|
||
data_clean = self._clean_data(data)
|
||
|
||
# 标准化
|
||
if self.config.use_scaling:
|
||
data_processed = self._scale_data(data_clean)
|
||
else:
|
||
data_processed = data_clean
|
||
|
||
return data_processed
|
||
|
||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||
"""清理数据:移除无效值"""
|
||
# 移除包含NaN或无穷大值的行
|
||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||
data_clean = data[valid_mask]
|
||
|
||
if np.sum(~valid_mask) > 0:
|
||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||
|
||
# 移除全零行
|
||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||
data_clean = data_clean[non_zero_mask]
|
||
|
||
if np.sum(~non_zero_mask) > 0:
|
||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||
|
||
return data_clean
|
||
|
||
    def _scale_data(self, data: np.ndarray) -> np.ndarray:
        """
        Lazily create the configured scaler, fit it on *data*, and transform.

        NOTE(review): the scaler object is created once but fitted on every
        call; both current call sites pass the same self.data, so results
        are unaffected — confirm before reusing with different inputs.
        """
        if self.scaler is None:
            # 'robust' (median/IQR) resists outliers; otherwise z-score.
            if self.config.scaler_type == 'robust':
                self.scaler = RobustScaler()
            else:
                self.scaler = StandardScaler()

        # Fit the scaler on the input data
        self.scaler.fit(data)

        data_scaled = self.scaler.transform(data)
        return data_scaled
def build_reconstruction_model(self) -> None:
|
||
"""
|
||
构建自重构模型
|
||
|
||
根据配置的重构方法构建相应的重构模型
|
||
"""
|
||
if self.data is None:
|
||
raise ValueError("请先调用load_data()加载数据")
|
||
|
||
print(f"正在构建{self.config.reconstruction_method}自重构模型...")
|
||
|
||
# 预处理数据
|
||
data_processed = self.preprocess_data(self.data)
|
||
|
||
if self.config.reconstruction_method == 'linear':
|
||
self._build_linear_reconstruction_model(data_processed)
|
||
elif self.config.reconstruction_method == 'pca':
|
||
self._build_pca_reconstruction_model(data_processed)
|
||
|
||
def _build_linear_reconstruction_model(self, data: np.ndarray) -> None:
|
||
"""
|
||
构建线性重构模型
|
||
|
||
使用前n-1个波段预测第n个波段
|
||
"""
|
||
n_samples, n_features = data.shape
|
||
|
||
if n_features < 2:
|
||
raise ValueError("数据至少需要2个波段进行线性重构")
|
||
|
||
# 确定目标波段
|
||
if self.config.target_band is None:
|
||
target_band = -1 # 默认使用最后一个波段
|
||
else:
|
||
target_band = self.config.target_band
|
||
if target_band >= n_features or target_band < 0:
|
||
raise ValueError(f"目标波段索引 {target_band} 超出范围 [0, {n_features-1}]")
|
||
|
||
# 准备训练数据
|
||
X_train = np.delete(data, target_band, axis=1) # 移除目标波段作为特征
|
||
y_train = data[:, target_band] # 目标波段
|
||
|
||
# 训练线性回归模型
|
||
self.model = LinearRegression(
|
||
fit_intercept=self.config.fit_intercept
|
||
)
|
||
|
||
self.model.fit(X_train, y_train)
|
||
|
||
# 计算训练重构误差
|
||
y_pred = self.model.predict(X_train)
|
||
train_mse = mean_squared_error(y_train, y_pred)
|
||
train_rmse = np.sqrt(train_mse)
|
||
|
||
print(f"线性重构模型构建完成:")
|
||
print(f" - 目标波段: {target_band}")
|
||
print(f" - 特征波段数: {X_train.shape[1]}")
|
||
print(f" - 训练MSE: {train_mse:.4f}")
|
||
print(f" - 训练RMSE: {train_rmse:.4f}")
|
||
|
||
def _build_pca_reconstruction_model(self, data: np.ndarray) -> None:
|
||
"""
|
||
构建PCA重构模型
|
||
|
||
使用PCA进行降维重构
|
||
"""
|
||
from sklearn.decomposition import PCA
|
||
|
||
n_samples, n_features = data.shape
|
||
|
||
# 确定PCA维度
|
||
if self.config.n_components is None:
|
||
n_components = min(n_features, max(2, n_features // 2))
|
||
else:
|
||
n_components = min(self.config.n_components, n_features)
|
||
|
||
# 训练PCA模型
|
||
self.pca_model = PCA(n_components=n_components, random_state=42)
|
||
data_pca = self.pca_model.fit_transform(data)
|
||
|
||
# 重构数据
|
||
data_reconstructed = self.pca_model.inverse_transform(data_pca)
|
||
|
||
# 计算重构误差
|
||
reconstruction_errors = np.mean((data - data_reconstructed) ** 2, axis=1)
|
||
mean_error = np.mean(reconstruction_errors)
|
||
std_error = np.std(reconstruction_errors)
|
||
|
||
explained_var = np.sum(self.pca_model.explained_variance_ratio_)
|
||
|
||
print(f"PCA重构模型构建完成:")
|
||
print(f" - 原始维度: {n_features}")
|
||
print(f" - PCA维度: {n_components}")
|
||
print(f" - 解释方差比例: {explained_var:.3f}")
|
||
print(f" - 平均重构误差: {mean_error:.4f}")
|
||
print(f" - 重构误差标准差: {std_error:.4f}")
|
||
|
||
    def detect_anomalies(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run anomaly detection with the fitted self-reconstruction model.

        Returns:
            predictions: continuous probabilities (probability mode) or
                0/1 labels (classification mode)
            probabilities: anomaly probability per sample

        Raises:
            ValueError: if the model or the data has not been prepared.
        """
        if self.model is None and self.pca_model is None:
            raise ValueError("请先调用build_reconstruction_model()构建重构模型")

        if self.data is None:
            raise ValueError("请先调用load_data()加载数据")

        print("正在进行自重构异常检测...")

        # Re-apply the same preprocessing used when the model was built.
        data_processed = self.preprocess_data(self.data)

        # Per-sample reconstruction error from the configured model.
        # NOTE(review): config validation guarantees 'linear' or 'pca';
        # otherwise reconstruction_errors would be unbound below.
        if self.config.reconstruction_method == 'linear':
            reconstruction_errors = self._compute_linear_reconstruction_errors(data_processed)
        elif self.config.reconstruction_method == 'pca':
            reconstruction_errors = self._compute_pca_reconstruction_errors(data_processed)

        # Normalize errors to [0, 1] via the empirical CDF.
        anomaly_probabilities = self._errors_to_probabilities(reconstruction_errors)

        if self.config.probability_flag:
            # Probability mode: emit continuous scores.
            predictions = anomaly_probabilities
        else:
            # Classification mode: threshold into 0/1 labels.
            predictions = (anomaly_probabilities > self.config.threshold).astype(int)

        # Persist results on the instance for save_results()/get_statistics().
        self.reconstruction_errors = reconstruction_errors
        self.anomaly_probabilities = anomaly_probabilities
        self.predictions = predictions

        # Summary statistics (threshold applied to probabilities in both modes).
        if self.config.probability_flag:
            n_anomalies = np.sum(predictions > self.config.threshold)
            anomaly_ratio = n_anomalies / len(predictions) if len(predictions) > 0 else 0
        else:
            n_anomalies = np.sum(predictions == 1)
            anomaly_ratio = n_anomalies / len(predictions) if len(predictions) > 0 else 0

        print(f"异常检测完成:")
        print(f"  - 样本数: {len(predictions)}")
        print(f"  - 异常样本数: {n_anomalies}")
        print(f"  - 异常比例: {anomaly_ratio:.3f}")

        return predictions, anomaly_probabilities
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""
|
||
从配置运行完整的异常检测分析
|
||
|
||
Returns:
|
||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||
"""
|
||
# 构建重构模型
|
||
self.build_reconstruction_model()
|
||
|
||
# 执行异常检测
|
||
predictions, scores = self.detect_anomalies()
|
||
|
||
return predictions, scores
|
||
|
||
def _compute_linear_reconstruction_errors(self, data: np.ndarray) -> np.ndarray:
|
||
"""
|
||
计算线性重构误差
|
||
|
||
Args:
|
||
data: 预处理后的数据
|
||
|
||
Returns:
|
||
重构误差数组
|
||
"""
|
||
# 确定目标波段
|
||
if self.config.target_band is None:
|
||
target_band = -1
|
||
else:
|
||
target_band = self.config.target_band
|
||
|
||
# 准备特征数据
|
||
X_data = np.delete(data, target_band, axis=1) # 移除目标波段作为特征
|
||
y_true = data[:, target_band] # 真实目标值
|
||
|
||
# 预测
|
||
y_pred = self.model.predict(X_data)
|
||
|
||
# 计算重构误差(使用绝对误差)
|
||
reconstruction_errors = np.abs(y_true - y_pred)
|
||
|
||
return reconstruction_errors
|
||
|
||
def _compute_pca_reconstruction_errors(self, data: np.ndarray) -> np.ndarray:
|
||
"""
|
||
计算PCA重构误差
|
||
|
||
Args:
|
||
data: 预处理后的数据
|
||
|
||
Returns:
|
||
重构误差数组
|
||
"""
|
||
# PCA降维
|
||
data_pca = self.pca_model.transform(data)
|
||
|
||
# 重构
|
||
data_reconstructed = self.pca_model.inverse_transform(data_pca)
|
||
|
||
# 计算重构误差(使用均方误差)
|
||
reconstruction_errors = np.mean((data - data_reconstructed) ** 2, axis=1)
|
||
|
||
return reconstruction_errors
|
||
|
||
def _errors_to_probabilities(self, errors: np.ndarray) -> np.ndarray:
|
||
"""
|
||
将重构误差转换为异常概率
|
||
|
||
使用经验累积分布函数(ECDF)将误差映射到[0,1]概率区间
|
||
|
||
Args:
|
||
errors: 重构误差数组
|
||
|
||
Returns:
|
||
异常概率数组
|
||
"""
|
||
# 计算经验累积分布
|
||
sorted_errors = np.sort(errors)
|
||
ecdf = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors)
|
||
|
||
# 使用线性插值将误差映射到概率
|
||
probabilities = np.interp(errors, sorted_errors, ecdf)
|
||
|
||
return probabilities
|
||
|
||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||
"""
|
||
保存异常检测结果为单波段ENVI格式文件
|
||
|
||
Args:
|
||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||
"""
|
||
if output_path is None:
|
||
if self.config.output_path:
|
||
output_path = self.config.output_path
|
||
elif self.config.output_dir:
|
||
base_name = Path(self.config.test_path or self.config.train_path).stem
|
||
output_path = os.path.join(self.config.output_dir, f"{base_name}_anomaly_squared_loss.bip")
|
||
else:
|
||
raise ValueError("必须指定输出路径")
|
||
|
||
print(f"正在保存结果到: {output_path}")
|
||
|
||
# 创建输出目录
|
||
output_dir = os.path.dirname(output_path)
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 根据输入格式决定输出格式
|
||
if self.input_format == 'envi':
|
||
self._save_envi_results(output_path)
|
||
else:
|
||
# 对于CSV输入,创建虚拟的2D图像
|
||
self._save_csv_results_as_envi(output_path)
|
||
|
||
def _save_envi_results(self, output_path: str) -> None:
|
||
"""保存ENVI格式结果为单波段图像"""
|
||
# 重塑预测结果为原始图像尺寸
|
||
rows, cols, bands = self.data_shape
|
||
|
||
# 创建单波段异常检测结果图像
|
||
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
|
||
|
||
# 将1D结果映射回2D图像
|
||
result_2d = self.predictions.reshape(rows, cols)
|
||
anomaly_image[:, :, 0] = result_2d
|
||
|
||
# 波段名称
|
||
band_names = ["Anomaly_Result"]
|
||
|
||
# 保存为ENVI格式,参考Covariance.py的方式
|
||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||
|
||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||
"""将CSV结果保存为单波段ENVI格式"""
|
||
# 对于CSV数据,创建1xN的单波段图像
|
||
n_samples = len(self.predictions)
|
||
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
|
||
|
||
anomaly_image[0, :, 0] = self.predictions
|
||
|
||
# 波段名称
|
||
band_names = ["Anomaly_Result"]
|
||
|
||
# 保存为ENVI格式,参考Covariance.py的方式
|
||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||
|
||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||
band_names: List[str], interleave: str = 'bip') -> None:
|
||
"""
|
||
保存数据为ENVI格式,参考Covariance.py的实现
|
||
|
||
参数:
|
||
data: 要保存的数据数组 (height, width, channels)
|
||
output_path: 输出文件路径
|
||
band_names: 波段名称列表
|
||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||
"""
|
||
output_path = Path(output_path)
|
||
height, width, channels = data.shape
|
||
|
||
# 确保目录存在
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||
|
||
# 保存.dat文件(二进制格式)
|
||
self._save_dat_file(data, str(output_path))
|
||
|
||
# 保存.hdr头文件
|
||
hdr_path = str(output_path).replace('.dat', '.hdr')
|
||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||
|
||
print(f"最小二乘异常检测结果已保存: {output_path}")
|
||
print(f"头文件已保存: {hdr_path}")
|
||
print(f"使用的交织格式: {interleave.upper()}")
|
||
|
||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||
"""保存.dat文件(二进制格式),参考Covariance.py"""
|
||
# 根据数据类型保存
|
||
if data.dtype == np.float32:
|
||
# 对于浮点数据,直接保存
|
||
with open(file_path, 'wb') as f:
|
||
data.astype(np.float32).tofile(f)
|
||
else:
|
||
# 对于其他数据类型,转换为float32保存
|
||
with open(file_path, 'wb') as f:
|
||
data.astype(np.float32).tofile(f)
|
||
|
||
    def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
                       band_names: List[str], interleave: str) -> None:
        """
        Write an ENVI header describing the saved raster plus detector metadata.

        Args:
            hdr_path: destination of the '.hdr' text file.
            data_shape: (height, width, channels) of the saved raster.
            band_names: names recorded in the 'band names' block.
            interleave: interleave keyword recorded in the header.
        """
        height, width, channels = data_shape

        # ENVI data type codes: 4=float32, 5=float64, 1=uint8, 2=int16,
        # 3=int32, 12=uint16.  The .dat writer always emits float32.
        data_type_code = 4  # default float32

        header_content = f"""ENVI
description = {{Squared Loss Anomaly Detection Results - Generated by SquaredLossAnomalyDetector}}
samples = {width}
lines = {height}
bands = {channels}
header offset = 0
file type = ENVI Standard
data type = {data_type_code}
interleave = {interleave}
byte order = 0
"""

        # Band-names block (comma-separated, braces per ENVI convention).
        if band_names:
            header_content += "band names = {\n"
            for i, name in enumerate(band_names):
                header_content += f'  "{name}"'
                if i < len(band_names) - 1:
                    header_content += ","
                header_content += "\n"
            header_content += "}\n"

        # Detector configuration metadata (non-standard keys; ignored by
        # readers that do not understand them).
        header_content += f"anomaly_detection_method = self_reconstruction\n"
        header_content += f"reconstruction_method = {self.config.reconstruction_method}\n"
        if self.config.reconstruction_method == 'pca' and self.config.n_components is not None:
            header_content += f"n_components = {self.config.n_components}\n"
        if self.config.target_band is not None:
            header_content += f"target_band = {self.config.target_band}\n"
        header_content += f"probability_flag = {self.config.probability_flag}\n"
        if not self.config.probability_flag:
            header_content += f"threshold = {self.config.threshold}\n"
        header_content += f"use_scaling = {self.config.use_scaling}\n"
        header_content += f"scaler_type = {self.config.scaler_type}\n"
        header_content += f"fit_intercept = {self.config.fit_intercept}\n"

        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)
def get_statistics(self) -> Dict[str, Any]:
|
||
"""
|
||
获取异常检测统计信息
|
||
|
||
Returns:
|
||
包含各种统计信息的字典
|
||
"""
|
||
if self.predictions is None:
|
||
raise ValueError("请先运行异常检测")
|
||
|
||
stats = {
|
||
'total_samples': len(self.predictions),
|
||
'output_mode': 'probability' if self.config.probability_flag else 'classification',
|
||
'threshold': self.config.threshold if not self.config.probability_flag else None,
|
||
'mean_reconstruction_error': float(np.mean(self.reconstruction_errors)),
|
||
'std_reconstruction_error': float(np.std(self.reconstruction_errors)),
|
||
'min_reconstruction_error': float(np.min(self.reconstruction_errors)),
|
||
'max_reconstruction_error': float(np.max(self.reconstruction_errors)),
|
||
'mean_anomaly_probability': float(np.mean(self.anomaly_probabilities)),
|
||
'std_anomaly_probability': float(np.std(self.anomaly_probabilities)),
|
||
}
|
||
|
||
if self.config.probability_flag:
|
||
stats['n_anomalies'] = int(np.sum(self.predictions > self.config.threshold))
|
||
stats['anomaly_ratio'] = float(np.sum(self.predictions > self.config.threshold) / len(self.predictions))
|
||
else:
|
||
stats['n_anomalies'] = int(np.sum(self.predictions == 1))
|
||
stats['anomaly_ratio'] = float(np.sum(self.predictions == 1) / len(self.predictions))
|
||
|
||
return stats
|
||
|
||
|
||
def main():
    """Command-line entry point: parse args, run detection, save results.

    Returns:
        0 on success, 1 on any failure (with traceback printed).
    """
    parser = argparse.ArgumentParser(
        description='基于最小二乘的高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  # 线性重构 + 概率模式
  python squared_loss_probability.py --input data.csv --spectral-start 400 --reconstruction-method linear --probability --output results.bip

  # PCA重构 + 分类模式
  python squared_loss_probability.py --input image.bip --reconstruction-method pca --n-components 10 --threshold 0.8 --output results.bip

  # 指定目标波段的线性重构
  python squared_loss_probability.py --input data.csv --spectral-start 400 --target-band 50 --output results.bip
        """
    )

    # Input options
    parser.add_argument('--input', required=True, help='输入数据路径 (CSV或ENVI格式)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')

    # Output options
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')

    # Self-reconstruction options
    parser.add_argument('--reconstruction-method', choices=['linear', 'pca'], default='linear',
                        help='重构方法 (linear: 线性回归, pca: PCA降维,默认linear)')
    parser.add_argument('--n-components', type=int,
                        help='PCA降维维度 (仅在reconstruction-method=pca时使用)')
    parser.add_argument('--target-band', type=int,
                        help='目标波段索引 (仅在reconstruction-method=linear时使用,默认使用最后一个波段)')

    # Detection options
    # NOTE(review): store_true with default=True means args.probability is
    # always True; the effective mode switch is --classification below.
    parser.add_argument('--probability', action='store_true', default=True,
                        help='概率模式输出 (0-1连续值,默认)')
    parser.add_argument('--classification', action='store_true',
                        help='分类模式输出 (0/1二元分类)')
    parser.add_argument('--threshold', type=float, default=0.8,
                        help='分类阈值 (分类模式时使用,默认0.8)')

    # Preprocessing options
    parser.add_argument('--no-scaling', action='store_true',
                        help='禁用数据标准化')
    parser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
                        help='标准化类型,默认robust')

    # Model options
    parser.add_argument('--no-intercept', action='store_true',
                        help='不拟合截距项')

    args = parser.parse_args()

    # Resolve output mode: --classification overrides the probability default.
    if args.classification:
        probability_flag = False
    else:
        probability_flag = args.probability

    try:
        # Build the validated configuration (raises on bad paths/options).
        config = SquaredLossAnomalyConfig(
            input_path=args.input,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            reconstruction_method=args.reconstruction_method,
            n_components=args.n_components,
            target_band=args.target_band,
            probability_flag=probability_flag,
            threshold=args.threshold,
            use_scaling=not args.no_scaling,
            scaler_type=args.scaler_type,
            fit_intercept=not args.no_intercept
        )

        detector = SquaredLossAnomalyDetector(config)

        # Banner with the effective settings.
        print("=== 基于自重构的异常检测 ===")
        print(f"输入数据: {config.input_path}")
        print(f"重构方法: {config.reconstruction_method}")
        print(f"输出模式: {'概率' if config.probability_flag else '分类'}")
        if not config.probability_flag:
            print(f"分类阈值: {config.threshold}")
        print("-" * 50)

        # 1. Load data
        detector.load_data()

        # 2. Build the self-reconstruction model
        detector.build_reconstruction_model()

        # 3. Run anomaly detection
        predictions, probabilities = detector.detect_anomalies()

        # 4. Save results
        detector.save_results()

        # 5. Print summary statistics
        stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in stats.items():
            if value is not None:
                print(f"{key}: {value}")

        print("\n处理完成!")

    except Exception as e:
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())