891 lines
33 KiB
Python
891 lines
33 KiB
Python
import numpy as np
|
||
import pandas as pd
|
||
# 尝试导入 spectral 库
|
||
try:
|
||
import spectral
|
||
SPECTRAL_AVAILABLE = True
|
||
except ImportError:
|
||
SPECTRAL_AVAILABLE = False
|
||
print("警告: spectral库不可用,将无法处理ENVI格式文件")
|
||
|
||
from sklearn.covariance import EllipticEnvelope
|
||
from sklearn.decomposition import PCA
|
||
from sklearn.manifold import TSNE
|
||
from sklearn.preprocessing import StandardScaler, RobustScaler
|
||
import os
|
||
import warnings
|
||
from typing import Optional, Dict, Any, Tuple, List, Union
|
||
from dataclasses import dataclass, field
|
||
import argparse
|
||
from pathlib import Path
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
@dataclass
|
||
class CovarianceAnomalyConfig:
|
||
"""
|
||
基于协方差估计的异常检测配置类
|
||
|
||
支持的文件格式:
|
||
- CSV文件: 需要指定光谱数据列范围
|
||
- ENVI格式: .bip/.bil/.bsq文件,需要对应的.hdr头文件
|
||
|
||
ENVI文件要求:
|
||
- 数据文件和头文件在同一目录
|
||
- 头文件扩展名为 .hdr
|
||
- 支持的数据类型包括: byte, int16, uint16, float32等
|
||
"""
|
||
|
||
# 输入文件配置
|
||
input_path: Optional[str] = None
|
||
label_col: Optional[str] = None
|
||
spectral_start: Optional[str] = None
|
||
spectral_end: Optional[str] = None
|
||
csv_has_header: bool = True
|
||
|
||
# 输出配置
|
||
output_path: Optional[str] = None
|
||
output_dir: Optional[str] = None
|
||
|
||
# 异常检测参数
|
||
contamination: float = 0.1 # 异常点比例 (0.1 = 10%)
|
||
support_fraction: Optional[float] = None # MCD支持比例,默认为None使用自动选择
|
||
assume_centered: bool = False # 数据是否已中心化
|
||
|
||
# 数据预处理配置
|
||
use_dimensionality_reduction: bool = True # 是否使用降维
|
||
reduction_method: str = 'pca' # 降维方法: 'pca' 或 't-sne'
|
||
n_components: int = 10 # 降维后的维度
|
||
use_scaling: bool = True # 是否使用标准化
|
||
scaler_type: str = 'robust' # 标准化类型: 'standard' 或 'robust'
|
||
|
||
# 随机种子
|
||
random_state: int = 42
|
||
|
||
def __post_init__(self):
|
||
"""参数校验和默认值设置"""
|
||
# 校验必需的文件路径
|
||
if not self.input_path:
|
||
raise ValueError("必须指定输入文件路径(input_path)")
|
||
|
||
# 清理和规范化路径
|
||
self.input_path = self._normalize_path(self.input_path)
|
||
if self.output_path:
|
||
self.output_path = self._normalize_path(self.output_path)
|
||
if self.output_dir:
|
||
self.output_dir = self._normalize_path(self.output_dir)
|
||
|
||
# 校验文件存在性(使用规范化后的路径)
|
||
if not os.path.exists(self.input_path):
|
||
raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
|
||
|
||
# 校验输出路径
|
||
if not self.output_path and not self.output_dir:
|
||
raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")
|
||
|
||
# 校验参数范围
|
||
if not 0 < self.contamination <= 0.5:
|
||
raise ValueError("contamination必须在(0, 0.5]范围内")
|
||
|
||
def _normalize_path(self, path: str) -> str:
|
||
"""
|
||
规范化路径:移除引号、展开用户目录、转换为绝对路径
|
||
|
||
Args:
|
||
path: 输入路径字符串
|
||
|
||
Returns:
|
||
规范化后的绝对路径
|
||
"""
|
||
if not path:
|
||
return path
|
||
|
||
# 移除可能的引号和其他包装字符
|
||
path = str(path).strip('\'"')
|
||
|
||
# 展开用户目录 (~)
|
||
path = os.path.expanduser(path)
|
||
|
||
# 转换为绝对路径
|
||
path = os.path.abspath(path)
|
||
|
||
# 在 Windows 上,确保使用反斜杠
|
||
if os.name == 'nt':
|
||
path = path.replace('/', '\\')
|
||
|
||
return path
|
||
|
||
if self.support_fraction is not None and not 0.5 <= self.support_fraction <= 1.0:
|
||
raise ValueError("support_fraction必须在[0.5, 1.0]范围内")
|
||
|
||
# 统一方法名为小写
|
||
self.reduction_method = self.reduction_method.lower()
|
||
self.scaler_type = self.scaler_type.lower()
|
||
|
||
# 校验降维方法
|
||
if self.reduction_method not in ['pca', 't-sne']:
|
||
raise ValueError(f"不支持的降维方法: {self.reduction_method}。支持的方法: ['pca', 't-sne']")
|
||
|
||
# 校验标准化类型
|
||
if self.scaler_type not in ['standard', 'robust']:
|
||
raise ValueError(f"不支持的标准化类型: {self.scaler_type}。支持的类型: ['standard', 'robust']")
|
||
|
||
|
||
class CovarianceAnomalyDetector:
|
||
"""
|
||
基于协方差估计的异常检测器
|
||
|
||
使用scikit-learn的EllipticEnvelope类实现,该类基于Minimum Covariance Determinant (MCD)
|
||
提供稳健的协方差估计,能够有效处理包含异常值的数据集。
|
||
"""
|
||
|
||
def __init__(self, config: CovarianceAnomalyConfig):
|
||
"""
|
||
初始化异常检测器
|
||
|
||
Args:
|
||
config: 配置对象
|
||
"""
|
||
self.config = config
|
||
self.data = None
|
||
self.labels = None
|
||
self.wavelengths = None
|
||
self.data_shape = None
|
||
self.input_format = None
|
||
|
||
# 预处理组件
|
||
self.scaler = None
|
||
self.reducer = None
|
||
|
||
# 异常检测模型
|
||
self.detector = None
|
||
|
||
# 结果
|
||
self.anomaly_scores = None
|
||
self.predictions = None
|
||
self.mahalanobis_distances = None
|
||
|
||
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""
|
||
从配置运行完整的异常检测分析
|
||
|
||
Returns:
|
||
Tuple[np.ndarray, np.ndarray]: (predictions, scores)
|
||
"""
|
||
# 预处理数据
|
||
processed_data = self.preprocess_data()
|
||
|
||
# 执行异常检测
|
||
predictions, scores = self.fit_predict(processed_data)
|
||
|
||
return predictions, scores
|
||
|
||
def load_data(self) -> None:
|
||
"""
|
||
加载高光谱数据文件
|
||
|
||
支持CSV和ENVI格式文件
|
||
"""
|
||
# 使用规范化后的路径字符串
|
||
file_path_str = self.config.input_path
|
||
file_path = Path(file_path_str)
|
||
suffix = file_path.suffix.lower()
|
||
|
||
print(f"正在读取文件: {file_path_str}")
|
||
|
||
if suffix == '.csv':
|
||
self._load_csv_data(file_path)
|
||
elif suffix in ['.hdr']:
|
||
self._load_envi_data(file_path)
|
||
else:
|
||
raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .hdr")
|
||
|
||
def _load_csv_data(self, file_path: Path) -> None:
|
||
"""加载CSV格式数据"""
|
||
if self.config.spectral_start is None:
|
||
raise ValueError("对于CSV文件,必须指定spectral_start参数")
|
||
|
||
self.input_format = 'csv'
|
||
|
||
# 读取数据
|
||
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
|
||
|
||
# 提取标签列
|
||
if self.config.label_col and self.config.label_col in df.columns:
|
||
self.labels = df[self.config.label_col].values
|
||
spectral_cols = [col for col in df.columns if col != self.config.label_col]
|
||
else:
|
||
self.labels = None
|
||
spectral_cols = df.columns.tolist()
|
||
|
||
# 提取光谱数据
|
||
if self.config.spectral_end:
|
||
end_idx = spectral_cols.index(self.config.spectral_end) + 1
|
||
else:
|
||
end_idx = len(spectral_cols)
|
||
|
||
start_idx = spectral_cols.index(self.config.spectral_start)
|
||
spectral_data = df[spectral_cols[start_idx:end_idx]].values
|
||
|
||
# 生成波长信息
|
||
self.wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
|
||
|
||
self.data = spectral_data.astype(np.float32)
|
||
self.data_shape = self.data.shape
|
||
|
||
print(f"CSV数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
|
||
|
||
def _load_envi_data(self, file_path: Path) -> None:
|
||
"""加载ENVI格式数据"""
|
||
self.input_format = 'envi'
|
||
|
||
try:
|
||
# 确保文件路径存在且可访问
|
||
file_path = Path(file_path)
|
||
if not file_path.exists():
|
||
raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
|
||
|
||
# 检查文件是否可读
|
||
if not os.access(file_path, os.R_OK):
|
||
raise PermissionError(f"没有读取权限: {file_path}")
|
||
|
||
# 使用规范化的路径字符串
|
||
file_path_str = str(file_path)
|
||
print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")
|
||
|
||
# 检查对应的头文件
|
||
hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
|
||
if not hdr_path.exists():
|
||
# 尝试 .hdr 格式
|
||
hdr_path = file_path.with_suffix('.hdr')
|
||
if not hdr_path.exists():
|
||
print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')} 或 {file_path.with_suffix('.hdr')}")
|
||
|
||
# 读取ENVI文件 - 使用规范化的路径
|
||
envi_data = spectral.open_image(file_path_str)
|
||
|
||
# 获取数据数组
|
||
print("正在加载数据到内存...")
|
||
data_array = envi_data.load()
|
||
|
||
# 获取波长信息
|
||
if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
|
||
self.wavelengths = np.array(envi_data.bands.centers)
|
||
print(f"读取到波长信息: {len(self.wavelengths)} 个波段")
|
||
else:
|
||
self.wavelengths = np.arange(data_array.shape[2])
|
||
print(f"未找到波长信息,使用默认波长: 0-{len(self.wavelengths)-1}")
|
||
|
||
# 重塑数据为 (samples, bands)
|
||
rows, cols, bands = data_array.shape
|
||
self.data = data_array.reshape(-1, bands).astype(np.float32)
|
||
self.data_shape = (rows, cols, bands)
|
||
|
||
print(f"ENVI数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
|
||
|
||
except Exception as e:
|
||
raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
|
||
|
||
def preprocess_data(self) -> np.ndarray:
|
||
"""
|
||
数据预处理
|
||
|
||
Returns:
|
||
处理后的数据数组
|
||
"""
|
||
print("正在预处理数据...")
|
||
|
||
# 检查数据有效性
|
||
if self.data is None:
|
||
raise ValueError("请先调用load_data()加载数据")
|
||
|
||
# 检查数据维度要求
|
||
n_samples, n_features = self.data.shape
|
||
if n_samples <= n_features ** 2:
|
||
print(f"警告: 样本数({n_samples})小于特征数平方({n_features**2}),建议使用降维")
|
||
|
||
# 数据清理:移除NaN和无穷大值
|
||
data_clean = self._clean_data(self.data)
|
||
|
||
# 标准化
|
||
if self.config.use_scaling:
|
||
data_scaled = self._scale_data(data_clean)
|
||
else:
|
||
data_scaled = data_clean
|
||
|
||
# 降维
|
||
if self.config.use_dimensionality_reduction:
|
||
data_processed = self._reduce_dimensionality(data_scaled)
|
||
else:
|
||
data_processed = data_scaled
|
||
|
||
print(f"数据预处理完成: {data_processed.shape}")
|
||
return data_processed
|
||
|
||
def _clean_data(self, data: np.ndarray) -> np.ndarray:
|
||
"""清理数据:移除无效值"""
|
||
# 移除包含NaN或无穷大值的行
|
||
valid_mask = np.all(np.isfinite(data), axis=1)
|
||
data_clean = data[valid_mask]
|
||
|
||
if np.sum(~valid_mask) > 0:
|
||
print(f"移除了 {np.sum(~valid_mask)} 个包含无效值的样本")
|
||
|
||
# 移除全零行
|
||
non_zero_mask = np.any(data_clean != 0, axis=1)
|
||
data_clean = data_clean[non_zero_mask]
|
||
|
||
if np.sum(~non_zero_mask) > 0:
|
||
print(f"移除了 {np.sum(~non_zero_mask)} 个全零样本")
|
||
|
||
return data_clean
|
||
|
||
def _scale_data(self, data: np.ndarray) -> np.ndarray:
|
||
"""标准化数据"""
|
||
if self.config.scaler_type == 'robust':
|
||
self.scaler = RobustScaler()
|
||
else:
|
||
self.scaler = StandardScaler()
|
||
|
||
data_scaled = self.scaler.fit_transform(data)
|
||
return data_scaled
|
||
|
||
def _reduce_dimensionality(self, data: np.ndarray) -> np.ndarray:
|
||
"""降维处理"""
|
||
n_samples, n_features = data.shape
|
||
|
||
# 检查是否需要降维
|
||
if n_features <= self.config.n_components:
|
||
print(f"特征数({n_features})小于等于目标维度({self.config.n_components}),跳过降维")
|
||
return data
|
||
|
||
if self.config.reduction_method == 'pca':
|
||
self.reducer = PCA(
|
||
n_components=self.config.n_components,
|
||
random_state=self.config.random_state
|
||
)
|
||
elif self.config.reduction_method == 't-sne':
|
||
# t-SNE需要更多的样本
|
||
if n_samples < 4 * self.config.n_components:
|
||
print(f"警告: t-SNE需要至少4倍于目标维度的样本数,使用PCA代替")
|
||
self.reducer = PCA(
|
||
n_components=self.config.n_components,
|
||
random_state=self.config.random_state
|
||
)
|
||
else:
|
||
self.reducer = TSNE(
|
||
n_components=self.config.n_components,
|
||
random_state=self.config.random_state,
|
||
perplexity=min(30, n_samples // 3)
|
||
)
|
||
|
||
data_reduced = self.reducer.fit_transform(data)
|
||
|
||
explained_var = None
|
||
if hasattr(self.reducer, 'explained_variance_ratio_'):
|
||
explained_var = np.sum(self.reducer.explained_variance_ratio_)
|
||
print(f"降维完成,解释方差比例: {explained_var:.3f}")
|
||
|
||
return data_reduced
|
||
|
||
def fit_predict(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""
|
||
训练异常检测模型并进行预测
|
||
|
||
Args:
|
||
data: 输入数据数组 (n_samples, n_features)
|
||
|
||
Returns:
|
||
predictions: 预测结果 (+1为正常,-1为异常)
|
||
scores: 异常分数 (马氏距离)
|
||
"""
|
||
print("正在训练异常检测模型...")
|
||
|
||
# 初始化异常检测器
|
||
self.detector = EllipticEnvelope(
|
||
contamination=self.config.contamination,
|
||
support_fraction=self.config.support_fraction,
|
||
assume_centered=self.config.assume_centered,
|
||
random_state=self.config.random_state
|
||
)
|
||
|
||
# 训练模型
|
||
self.detector.fit(data)
|
||
|
||
# 预测异常
|
||
predictions = self.detector.predict(data)
|
||
|
||
# 计算马氏距离作为异常分数
|
||
# 使用predict_proba或decision_function获取分数
|
||
try:
|
||
# 尝试获取决策函数值
|
||
scores = self.detector.decision_function(data)
|
||
# 转换为马氏距离的近似值(负值表示异常)
|
||
scores = -scores
|
||
except:
|
||
# 如果decision_function不可用,使用距离计算
|
||
scores = self._compute_mahalanobis_distances(data)
|
||
|
||
self.predictions = predictions
|
||
self.anomaly_scores = scores
|
||
self.mahalanobis_distances = self._compute_mahalanobis_distances(data)
|
||
|
||
# 统计信息
|
||
n_anomalies = np.sum(predictions == -1)
|
||
anomaly_ratio = n_anomalies / len(predictions)
|
||
|
||
print(f"异常检测完成:")
|
||
print(f" - 总样本数: {len(predictions)}")
|
||
print(f" - 异常样本数: {n_anomalies}")
|
||
print(f" - 异常比例: {anomaly_ratio:.3f}")
|
||
|
||
return predictions, scores
|
||
|
||
def _compute_mahalanobis_distances(self, data: np.ndarray) -> np.ndarray:
|
||
"""
|
||
计算马氏距离
|
||
|
||
Args:
|
||
data: 输入数据
|
||
|
||
Returns:
|
||
马氏距离数组
|
||
"""
|
||
if self.detector is None:
|
||
raise ValueError("请先调用fit_predict()训练模型")
|
||
|
||
# 获取协方差矩阵和均值
|
||
cov_matrix = self.detector.covariance_
|
||
center = self.detector.location_
|
||
|
||
# 计算马氏距离
|
||
diff = data - center
|
||
inv_cov = np.linalg.inv(cov_matrix)
|
||
mahalanobis_sq = np.sum(diff * (diff @ inv_cov), axis=1)
|
||
mahalanobis_dist = np.sqrt(mahalanobis_sq)
|
||
|
||
return mahalanobis_dist
|
||
|
||
def save_results(self, output_path: Optional[str] = None) -> None:
|
||
"""
|
||
保存异常检测结果为ENVI格式文件
|
||
|
||
Args:
|
||
output_path: 输出文件路径,如果为None则使用配置中的路径
|
||
"""
|
||
if output_path is None:
|
||
if self.config.output_path:
|
||
output_path = self.config.output_path
|
||
elif self.config.output_dir:
|
||
base_name = Path(self.config.input_path).stem
|
||
output_path = str(Path(self.config.output_dir) / f"{base_name}_anomaly_covariance.bip")
|
||
else:
|
||
raise ValueError("必须指定输出路径")
|
||
|
||
# 清理路径中的可能问题字符
|
||
output_path = str(output_path).strip('\'"')
|
||
|
||
print(f"正在保存结果到: {output_path}")
|
||
|
||
# 创建输出目录
|
||
output_path_obj = Path(output_path)
|
||
output_dir = output_path_obj.parent
|
||
|
||
if output_dir and not output_dir.exists():
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
print(f"创建输出目录: {output_dir}")
|
||
|
||
# 根据输入格式决定输出格式
|
||
if self.input_format == 'envi':
|
||
self._save_envi_results(output_path)
|
||
else:
|
||
# 对于CSV输入,创建虚拟的2D图像
|
||
self._save_csv_results_as_envi(output_path)
|
||
|
||
def _save_envi_results(self, output_path: str) -> None:
|
||
"""保存ENVI格式结果,使用spectral库"""
|
||
# 重塑预测结果为原始图像尺寸
|
||
rows, cols, bands = self.data_shape
|
||
|
||
# 创建异常检测结果图像
|
||
# 结果包括:预测结果、马氏距离、异常分数
|
||
anomaly_image = np.zeros((rows, cols, 3), dtype=np.float32)
|
||
|
||
# 需要将1D结果映射回2D图像
|
||
# 假设数据是按行优先顺序展平的
|
||
predictions_2d = self.predictions.reshape(rows, cols)
|
||
distances_2d = self.mahalanobis_distances.reshape(rows, cols)
|
||
scores_2d = self.anomaly_scores.reshape(rows, cols)
|
||
|
||
anomaly_image[:, :, 0] = predictions_2d # 预测结果 (-1, 1)
|
||
anomaly_image[:, :, 1] = distances_2d # 马氏距离
|
||
anomaly_image[:, :, 2] = scores_2d # 异常分数
|
||
|
||
# 波段名称
|
||
band_names = ["Anomaly_Prediction", "Mahalanobis_Distance", "Anomaly_Score"]
|
||
|
||
# 保存为ENVI格式,参考spectral2cie2.py的方式
|
||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||
|
||
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
|
||
band_names: List[str], interleave: str = 'bip') -> None:
|
||
"""
|
||
保存数据为ENVI格式,参考edge_detect.py的实现
|
||
|
||
参数:
|
||
data: 要保存的数据数组 (height, width, channels)
|
||
output_path: 输出文件路径
|
||
band_names: 波段名称列表
|
||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||
"""
|
||
output_path = Path(output_path)
|
||
height, width, channels = data.shape
|
||
|
||
# 确保目录存在
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
|
||
|
||
# 使用edge_detect.py的方式保存
|
||
output_base = str(output_path.with_suffix('')) # 移除扩展名并转换为字符串
|
||
|
||
# 保存.dat文件(二进制格式)
|
||
self._save_dat_file(data, str(output_path))
|
||
|
||
# 保存.hdr头文件
|
||
hdr_path = output_base + '.hdr'
|
||
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
|
||
|
||
print(f"异常检测结果已保存: {output_path}")
|
||
print(f"头文件已保存: {hdr_path}")
|
||
print(f"使用的交织格式: {interleave.upper()}")
|
||
|
||
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
|
||
"""保存.dat文件(二进制格式),参考edge_detect.py"""
|
||
# 根据数据类型保存
|
||
if data.dtype == np.float32:
|
||
# 对于浮点数据,直接保存
|
||
with open(file_path, 'wb') as f:
|
||
data.astype(np.float32).tofile(f)
|
||
else:
|
||
# 对于其他数据类型,转换为float32保存
|
||
with open(file_path, 'wb') as f:
|
||
data.astype(np.float32).tofile(f)
|
||
|
||
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
|
||
band_names: List[str], interleave: str) -> None:
|
||
"""保存ENVI头文件,参考edge_detect.py并适配多波段数据"""
|
||
height, width, channels = data_shape
|
||
|
||
# 确定数据类型编码
|
||
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
|
||
data_type_code = 4 # 默认float32
|
||
|
||
header_content = f"""ENVI
|
||
description = {{Covariance Anomaly Detection Results - Generated by CovarianceAnomalyDetector}}
|
||
samples = {width}
|
||
lines = {height}
|
||
bands = {channels}
|
||
header offset = 0
|
||
file type = ENVI Standard
|
||
data type = {data_type_code}
|
||
interleave = {interleave}
|
||
byte order = 0
|
||
"""
|
||
|
||
# 添加波段名称
|
||
if band_names:
|
||
header_content += "band names = {\n"
|
||
for i, name in enumerate(band_names):
|
||
header_content += f' "{name}"'
|
||
if i < len(band_names) - 1:
|
||
header_content += ","
|
||
header_content += "\n"
|
||
header_content += "}\n"
|
||
|
||
# 添加异常检测相关元数据
|
||
header_content += f"anomaly_detection_method = covariance\n"
|
||
header_content += f"contamination = {self.config.contamination}\n"
|
||
if self.config.support_fraction is not None:
|
||
header_content += f"support_fraction = {self.config.support_fraction}\n"
|
||
header_content += f"dimensionality_reduction = {self.config.use_dimensionality_reduction}\n"
|
||
if self.config.use_dimensionality_reduction:
|
||
header_content += f"reduction_method = {self.config.reduction_method}\n"
|
||
header_content += f"n_components = {self.config.n_components}\n"
|
||
|
||
with open(hdr_path, 'w', encoding='utf-8') as f:
|
||
f.write(header_content)
|
||
|
||
def _create_envi_hdr_file(self, bil_path: Union[str, Path], height: int, width: int,
|
||
bands: int, data_type: str, interleave: str,
|
||
band_names: List[str], input_hdr_path: Optional[Union[str, Path]] = None) -> None:
|
||
"""
|
||
创建ENVI头文件,参考spectral2cie2.py的实现
|
||
|
||
Args:
|
||
bil_path: BIL文件路径
|
||
height: 图像高度
|
||
width: 图像宽度
|
||
bands: 波段数
|
||
data_type: 数据类型 ('float32', 'uint8', 'int16', 等)
|
||
interleave: 交织方式 ('bip', 'bil', 'bsq')
|
||
band_names: 波段名称列表
|
||
input_hdr_path: 输入ENVI文件的HDR路径,用于复制元数据
|
||
"""
|
||
# 使用 .bip.hdr 格式的头文件
|
||
bil_path_obj = Path(bil_path)
|
||
hdr_path = bil_path_obj.with_suffix(bil_path_obj.suffix + '.hdr')
|
||
|
||
# 数据类型映射 (ENVI格式)
|
||
dtype_map = {
|
||
'uint8': '1',
|
||
'int16': '2',
|
||
'int32': '3',
|
||
'float32': '4',
|
||
'float64': '5',
|
||
'complex64': '6',
|
||
'complex128': '9',
|
||
'uint16': '12',
|
||
'uint32': '13',
|
||
'int64': '14',
|
||
'uint64': '15'
|
||
}
|
||
|
||
envi_dtype = dtype_map.get(data_type, '4') # 默认为float32
|
||
|
||
with open(hdr_path, 'w') as f:
|
||
f.write("ENVI\n")
|
||
f.write("description = {\n")
|
||
f.write(" Covariance Anomaly Detection Results - Generated by CovarianceAnomalyDetector\n")
|
||
f.write(f" Contamination: {self.config.contamination}, Support Fraction: {self.config.support_fraction}\n")
|
||
f.write("}\n")
|
||
f.write(f"samples = {width}\n")
|
||
f.write(f"lines = {height}\n")
|
||
f.write(f"bands = {bands}\n")
|
||
f.write("header offset = 0\n")
|
||
f.write("file type = ENVI Standard\n")
|
||
f.write(f"data type = {envi_dtype}\n")
|
||
f.write(f"interleave = {interleave}\n")
|
||
f.write("sensor type = Unknown\n")
|
||
f.write("byte order = 0\n")
|
||
|
||
# 波段名称
|
||
f.write("band names = {\n")
|
||
for i, name in enumerate(band_names):
|
||
f.write(f' "{name}"')
|
||
if i < len(band_names) - 1:
|
||
f.write(",")
|
||
f.write("\n")
|
||
f.write("}\n")
|
||
|
||
# 如果有输入HDR文件,尝试复制相关的元数据
|
||
if input_hdr_path and Path(input_hdr_path).exists():
|
||
try:
|
||
self._copy_hdr_metadata(input_hdr_path, f)
|
||
except Exception as e:
|
||
print(f"复制HDR元数据失败: {e}")
|
||
|
||
print(f"ENVI头文件创建完成: {hdr_path}")
|
||
|
||
def _copy_hdr_metadata(self, input_hdr_path: Union[str, Path], output_file) -> None:
|
||
"""
|
||
从输入HDR文件复制元数据到输出HDR文件
|
||
|
||
Args:
|
||
input_hdr_path: 输入HDR文件路径
|
||
output_file: 输出文件对象
|
||
"""
|
||
try:
|
||
with open(input_hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||
content = f.read()
|
||
|
||
# 解析输入HDR文件,提取有用的元数据
|
||
lines = content.split('\n')
|
||
metadata_to_copy = []
|
||
|
||
# 需要复制的元数据字段
|
||
fields_to_copy = [
|
||
'wavelength units', 'wavelength', 'fwhm', 'bbl', 'map info',
|
||
'coordinate system string', 'projection info', 'pixel size',
|
||
'acquisition time', 'sensor type', 'radiance scale factor',
|
||
'reflectance scale factor', 'data gain values', 'data offset values'
|
||
]
|
||
|
||
i = 0
|
||
while i < len(lines):
|
||
line = lines[i].strip()
|
||
if '=' in line:
|
||
key = line.split('=')[0].strip().lower()
|
||
if any(field in key for field in fields_to_copy):
|
||
# 复制这个字段及其可能的后续行
|
||
metadata_to_copy.append(lines[i])
|
||
i += 1
|
||
# 如果是多行值,继续读取
|
||
while i < len(lines) and not ('=' in lines[i] and not lines[i].strip().endswith(',')):
|
||
if lines[i].strip():
|
||
metadata_to_copy.append(lines[i])
|
||
i += 1
|
||
if i >= len(lines):
|
||
break
|
||
continue
|
||
i += 1
|
||
|
||
# 将复制的元数据写入输出文件
|
||
if metadata_to_copy:
|
||
output_file.write("\n")
|
||
for line in metadata_to_copy:
|
||
if line.strip():
|
||
output_file.write(line + "\n")
|
||
|
||
print(f"已从输入HDR文件复制 {len(metadata_to_copy)} 行元数据")
|
||
|
||
except Exception as e:
|
||
print(f"读取输入HDR文件失败: {e}")
|
||
|
||
def _save_csv_results_as_envi(self, output_path: str) -> None:
|
||
"""将CSV结果保存为ENVI格式"""
|
||
# 对于CSV数据,创建1xN的图像
|
||
n_samples = len(self.predictions)
|
||
anomaly_image = np.zeros((1, n_samples, 3), dtype=np.float32)
|
||
|
||
anomaly_image[0, :, 0] = self.predictions # 预测结果
|
||
anomaly_image[0, :, 1] = self.mahalanobis_distances # 马氏距离
|
||
anomaly_image[0, :, 2] = self.anomaly_scores # 异常分数
|
||
|
||
# 波段名称
|
||
band_names = ["Anomaly_Prediction", "Mahalanobis_Distance", "Anomaly_Score"]
|
||
|
||
# 保存为ENVI格式,参考spectral2cie2.py的方式
|
||
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
|
||
|
||
def get_statistics(self) -> Dict[str, Any]:
|
||
"""
|
||
获取异常检测统计信息
|
||
|
||
Returns:
|
||
包含各种统计信息的字典
|
||
"""
|
||
if self.predictions is None:
|
||
raise ValueError("请先运行异常检测")
|
||
|
||
stats = {
|
||
'total_samples': len(self.predictions),
|
||
'n_anomalies': int(np.sum(self.predictions == -1)),
|
||
'anomaly_ratio': float(np.sum(self.predictions == -1) / len(self.predictions)),
|
||
'mean_mahalanobis_distance': float(np.mean(self.mahalanobis_distances)),
|
||
'std_mahalanobis_distance': float(np.std(self.mahalanobis_distances)),
|
||
'min_mahalanobis_distance': float(np.min(self.mahalanobis_distances)),
|
||
'max_mahalanobis_distance': float(np.max(self.mahalanobis_distances)),
|
||
'contamination': self.config.contamination,
|
||
'reduction_method': self.config.reduction_method if self.config.use_dimensionality_reduction else None,
|
||
'n_components': self.config.n_components if self.config.use_dimensionality_reduction else None
|
||
}
|
||
|
||
return stats
|
||
|
||
|
||
def main():
|
||
"""主函数:命令行接口"""
|
||
parser = argparse.ArgumentParser(
|
||
description='基于协方差估计的高光谱异常检测',
|
||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||
epilog="""
|
||
使用示例:
|
||
python Covariance.py --input data.csv --spectral-start 400 --output results.bip
|
||
python Covariance.py --input image.bip --contamination 0.15 --reduction-method t-sne --output results.bip
|
||
"""
|
||
)
|
||
|
||
# 输入文件参数
|
||
parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式,如.bip/.bil/.bsq文件)')
|
||
parser.add_argument('--label-col', help='标签列名 (CSV文件)')
|
||
parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
|
||
parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
|
||
|
||
# 输出参数
|
||
parser.add_argument('--output', help='输出文件路径')
|
||
parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')
|
||
|
||
# 异常检测参数
|
||
parser.add_argument('--contamination', type=float, default=0.1,
|
||
help='异常点比例 (0, 0.5],默认0.1')
|
||
parser.add_argument('--support-fraction', type=float,
|
||
help='MCD支持比例 [0.5, 1.0],默认自动选择')
|
||
|
||
# 预处理参数
|
||
parser.add_argument('--no-reduction', action='store_true',
|
||
help='禁用降维处理')
|
||
parser.add_argument('--reduction-method', choices=['pca', 't-sne'], default='pca',
|
||
help='降维方法,默认pca')
|
||
parser.add_argument('--n-components', type=int, default=10,
|
||
help='降维后的维度,默认10')
|
||
parser.add_argument('--no-scaling', action='store_true',
|
||
help='禁用数据标准化')
|
||
parser.add_argument('--scaler-type', choices=['standard', 'robust'], default='robust',
|
||
help='标准化类型,默认robust')
|
||
|
||
# 其他参数
|
||
parser.add_argument('--random-state', type=int, default=42,
|
||
help='随机种子,默认42')
|
||
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
# 创建配置
|
||
config = CovarianceAnomalyConfig(
|
||
input_path=args.input,
|
||
label_col=args.label_col,
|
||
spectral_start=args.spectral_start,
|
||
spectral_end=args.spectral_end,
|
||
output_path=args.output,
|
||
output_dir=args.output_dir,
|
||
contamination=args.contamination,
|
||
support_fraction=args.support_fraction,
|
||
use_dimensionality_reduction=not args.no_reduction,
|
||
reduction_method=args.reduction_method,
|
||
n_components=args.n_components,
|
||
use_scaling=not args.no_scaling,
|
||
scaler_type=args.scaler_type,
|
||
random_state=args.random_state
|
||
)
|
||
|
||
# 创建检测器
|
||
detector = CovarianceAnomalyDetector(config)
|
||
|
||
# 执行异常检测流程
|
||
print("=== 基于协方差估计的异常检测 ===")
|
||
print(f"输入文件: {config.input_path}")
|
||
print(f"异常比例: {config.contamination}")
|
||
print(f"降维方法: {config.reduction_method if config.use_dimensionality_reduction else '无'}")
|
||
print("-" * 50)
|
||
|
||
# 1. 加载数据
|
||
detector.load_data()
|
||
|
||
# 2. 预处理数据
|
||
processed_data = detector.preprocess_data()
|
||
|
||
# 3. 执行异常检测
|
||
predictions, scores = detector.fit_predict(processed_data)
|
||
|
||
# 4. 保存结果
|
||
detector.save_results()
|
||
|
||
# 5. 显示统计信息
|
||
stats = detector.get_statistics()
|
||
print("\n=== 检测结果统计 ===")
|
||
for key, value in stats.items():
|
||
print(f"{key}: {value}")
|
||
|
||
print("\n处理完成!")
|
||
|
||
except Exception as e:
|
||
print(f"错误: {str(e)}")
|
||
return 1
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
exit(main())
|