# HSI/Anomaly_method/RX.py
# Reed-Xiaoli (RX) anomaly detection for hyperspectral imagery.
import numpy as np
import pandas as pd
import spectral
from sklearn.preprocessing import StandardScaler
from sklearn.covariance import EmpiricalCovariance, MinCovDet
import os
import warnings
from typing import Optional, Dict, Any, Tuple, List, Union
from dataclasses import dataclass, field
import argparse
from pathlib import Path
from scipy import stats, linalg
import cv2
warnings.filterwarnings('ignore')
@dataclass
class RXConfig:
    """Configuration for RX anomaly detection.

    All file paths, the background-model choice and the numeric parameters
    are validated in ``__post_init__`` so errors surface before any data
    is loaded.
    """
    # --- input files ---
    input_path: Optional[str] = None
    background_path: Optional[str] = None  # user-supplied background data (for 'custom')
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True
    # --- output ---
    output_path: Optional[str] = None
    output_dir: Optional[str] = None
    # --- background model ---
    background_model: str = 'global'  # 'global', 'local_row', 'local_square', 'custom'
    window_size: int = 5  # side length of the local square window
    row_window: int = 3   # number of rows in the local row window
    # --- detection parameters ---
    threshold: Optional[float] = None  # None -> no binary segmentation, distances only
    contamination: float = 0.1  # expected anomaly fraction, used to auto-derive a threshold
    use_robust_covariance: bool = False  # MinCovDet instead of empirical covariance
    # --- preprocessing ---
    use_scaling: bool = False  # standardize bands (RX is normally run on raw values)

    def __post_init__(self):
        """Validate paths and parameter ranges; raise early on bad input."""
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
        if self.background_path and not os.path.exists(self.background_path):
            raise FileNotFoundError(f"背景文件不存在: {self.background_path}")
        if not self.output_path and not self.output_dir:
            raise ValueError("必须指定输出路径(output_path)或输出目录(output_dir)")
        supported_models = ['global', 'local_row', 'local_square', 'custom']
        if self.background_model not in supported_models:
            raise ValueError(f"不支持的背景模型: {self.background_model}。支持的模型: {supported_models}")
        if self.window_size <= 0:
            raise ValueError("window_size必须大于0")
        if self.row_window <= 0:
            raise ValueError("row_window必须大于0")
        # contamination feeds np.percentile((1 - c) * 100); outside (0, 1)
        # the derived threshold is meaningless, so reject it here (new check).
        if not 0.0 < self.contamination < 1.0:
            raise ValueError("contamination必须在(0, 1)之间")
        if self.background_model == 'custom' and not self.background_path:
            raise ValueError("使用custom背景模型时必须指定background_path")
class RXAnomalyDetector:
    """
    Reed-Xiaoli (RX) anomaly detector.

    RX scores each pixel of a hyperspectral image by its Mahalanobis
    distance to a background distribution:

        D(x) = (x - mu)^T K^-1 (x - mu)

    where x is the test spectrum, mu the background mean vector and K the
    background covariance matrix. Larger distances indicate pixels that are
    more likely anomalous.
    """

    def __init__(self, config: RXConfig):
        """Bind the detector to *config*; no data is loaded here.

        Args:
            config: validated configuration object.
        """
        self.config = config
        # raw spectra, flattened to (n_samples, n_bands) once loaded
        self.data = None
        self.background_data = None
        self.wavelengths = None
        self.data_shape = None
        self.input_format = None
        # background statistics (filled for 'global'/'custom' models)
        self.background_mean = None
        self.background_cov = None
        self.background_inv_cov = None
        # detection outputs
        self.distance_map = None  # Mahalanobis distance per pixel
        self.anomaly_map = None   # binary 0/1 map when a threshold is set
        # preprocessing component (only used when config.use_scaling)
        self.scaler = None
def load_data(self) -> None:
    """Load the input hyperspectral file (CSV or ENVI .hdr).

    Dispatches on the file suffix; for the 'custom' background model the
    user-supplied background file is loaded as well.

    Raises:
        ValueError: for unsupported file suffixes.
    """
    file_path = Path(self.config.input_path)
    suffix = file_path.suffix.lower()
    print(f"正在读取输入数据: {file_path}")
    if suffix == '.csv':
        self._load_csv_data(file_path, is_background=False)
    elif suffix == '.hdr':
        self._load_envi_data(file_path, is_background=False)
    else:
        # BUG FIX: the old message omitted .csv although CSV is supported.
        raise ValueError(f"不支持的文件格式: {suffix}。支持的格式: .csv, .hdr")
    if self.config.background_model == 'custom':
        self._load_background_data()
def _load_csv_data(self, file_path: Path, is_background: bool = False) -> None:
"""加载CSV格式数据"""
if self.config.spectral_start is None:
raise ValueError("对于CSV文件必须指定spectral_start参数")
# 读取数据
df = pd.read_csv(file_path, header=0 if self.config.csv_has_header else None)
# 提取标签列
if self.config.label_col and self.config.label_col in df.columns:
labels = df[self.config.label_col].values
spectral_cols = [col for col in df.columns if col != self.config.label_col]
else:
labels = None
spectral_cols = df.columns.tolist()
# 提取光谱数据
if self.config.spectral_end:
end_idx = spectral_cols.index(self.config.spectral_end) + 1
else:
end_idx = len(spectral_cols)
start_idx = spectral_cols.index(self.config.spectral_start)
spectral_data = df[spectral_cols[start_idx:end_idx]].values
# 生成波长信息
wavelengths = np.arange(len(spectral_cols[start_idx:end_idx]))
# 存储数据
if is_background:
self.background_data = spectral_data.astype(np.float32)
print(f"背景数据加载完成: {self.background_data.shape} 样本 x {self.background_data.shape[1]} 波段")
else:
self.data = spectral_data.astype(np.float32)
self.data_shape = self.data.shape
self.input_format = 'csv'
self.wavelengths = wavelengths
print(f"输入数据加载完成: {self.data.shape} 样本 x {self.data.shape[1]} 波段")
def _load_envi_data(self, file_path: Path, is_background: bool = False) -> None:
"""加载ENVI格式数据"""
try:
# 确保文件路径存在且可访问
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"ENVI文件不存在: {file_path}")
# 检查文件是否可读
if not os.access(file_path, os.R_OK):
raise PermissionError(f"没有读取权限: {file_path}")
# 使用规范化的路径字符串
file_path_str = str(file_path)
print(f"尝试使用spectral库读取ENVI文件: {file_path_str}")
# 检查对应的头文件
hdr_path = file_path.with_suffix(file_path.suffix + '.hdr')
if not hdr_path.exists():
# 尝试 .hdr 格式
hdr_path = file_path.with_suffix('.hdr')
if not hdr_path.exists():
print(f"警告: 未找到头文件 {file_path.with_suffix(file_path.suffix + '.hdr')}{file_path.with_suffix('.hdr')}")
# 读取ENVI文件 - 使用规范化的路径
envi_data = spectral.open_image(file_path_str)
# 获取数据数组
print("正在加载数据到内存...")
data_array = envi_data.load()
# 获取波长信息
if hasattr(envi_data, 'bands') and hasattr(envi_data.bands, 'centers'):
wavelengths = np.array(envi_data.bands.centers)
print(f"读取到波长信息: {len(wavelengths)} 个波段")
else:
wavelengths = np.arange(data_array.shape[2])
print(f"未找到波长信息,使用默认波长: 0-{len(wavelengths)-1}")
# 重塑数据为 (samples, bands)
rows, cols, bands = data_array.shape
spectral_data = data_array.reshape(-1, bands).astype(np.float32)
# 存储数据
if is_background:
self.background_data = spectral_data
print(f"背景数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
else:
self.data = spectral_data
self.data_shape = (rows, cols, bands)
self.input_format = 'envi'
self.wavelengths = wavelengths
print(f"输入数据加载完成: {rows}x{cols} 像素 x {bands} 波段")
except Exception as e:
raise ValueError(f"读取ENVI文件失败: {str(e)}。请确保spectral库可用且文件格式正确。")
def _load_background_data(self) -> None:
"""加载用户指定的背景数据"""
if not self.config.background_path:
raise ValueError("background_path未指定")
file_path = Path(self.config.background_path)
suffix = file_path.suffix.lower()
print(f"正在读取背景数据: {file_path}")
if suffix == '.csv':
self._load_csv_data(file_path, is_background=True)
elif suffix in ['.bil', '.bsq', '.bip', '.dat']:
self._load_envi_data(file_path, is_background=True)
else:
raise ValueError(f"不支持的背景文件格式: {suffix}")
def preprocess_data(self) -> None:
    """Optionally standardize the spectra.

    RX is normally run on raw values; scaling happens only when
    ``config.use_scaling`` is set. The scaler is fitted on the input data
    and reused for the background data so both share one transform.
    """
    if not self.config.use_scaling:
        return
    print("正在标准化数据...")
    self.scaler = StandardScaler()
    if self.data is not None:
        self.data = self.scaler.fit_transform(self.data)
    if self.background_data is not None:
        self.background_data = self.scaler.transform(self.background_data)
def compute_background_statistics(self) -> None:
    """Precompute background mean/covariance for models that need them.

    Only 'global' and 'custom' have fixed background statistics; the local
    models estimate theirs per window during detection, so nothing is done
    for them here.
    """
    print(f"正在计算背景统计量 (模型: {self.config.background_model})...")
    model = self.config.background_model
    if model == 'global':
        self._compute_global_background()
    elif model == 'custom':
        self._compute_custom_background()
def _compute_global_background(self) -> None:
"""计算全局背景统计量"""
# 使用整个数据集作为背景
background_data = self.data
if self.config.use_robust_covariance:
# 使用稳健协方差估计
cov_estimator = MinCovDet()
cov_estimator.fit(background_data)
self.background_mean = cov_estimator.location_
self.background_cov = cov_estimator.covariance_
else:
# 使用经验协方差
self.background_mean = np.mean(background_data, axis=0)
self.background_cov = np.cov(background_data.T)
# 计算协方差矩阵的逆
try:
self.background_inv_cov = linalg.inv(self.background_cov)
except np.linalg.LinAlgError:
print("警告: 协方差矩阵奇异,使用伪逆")
self.background_inv_cov = linalg.pinv(self.background_cov)
print(f"全局背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
def _compute_custom_background(self) -> None:
"""计算用户指定背景统计量"""
if self.background_data is None:
raise ValueError("自定义背景数据未加载")
background_data = self.background_data
if self.config.use_robust_covariance:
cov_estimator = MinCovDet()
cov_estimator.fit(background_data)
self.background_mean = cov_estimator.location_
self.background_cov = cov_estimator.covariance_
else:
self.background_mean = np.mean(background_data, axis=0)
self.background_cov = np.cov(background_data.T)
# 计算协方差矩阵的逆
try:
self.background_inv_cov = linalg.inv(self.background_cov)
except np.linalg.LinAlgError:
print("警告: 协方差矩阵奇异,使用伪逆")
self.background_inv_cov = linalg.pinv(self.background_cov)
print(f"自定义背景统计量计算完成,协方差矩阵形状: {self.background_cov.shape}")
def _compute_local_background_stats(self, data_window: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
计算局部背景统计量
Args:
data_window: 局部窗口数据 (n_pixels, n_bands)
Returns:
mean_vector, cov_matrix, inv_cov_matrix
"""
# 移除窗口中的NaN值
valid_mask = ~np.isnan(data_window).any(axis=1)
valid_data = data_window[valid_mask]
if len(valid_data) < 2:
# 数据不足,返回全局统计量
return self.background_mean, self.background_cov, self.background_inv_cov
if self.config.use_robust_covariance and len(valid_data) > self.background_cov.shape[0]:
# 使用稳健协方差估计
cov_estimator = MinCovDet()
cov_estimator.fit(valid_data)
mean_vec = cov_estimator.location_
cov_mat = cov_estimator.covariance_
else:
# 使用经验协方差
mean_vec = np.mean(valid_data, axis=0)
cov_mat = np.cov(valid_data.T)
# 计算协方差矩阵的逆
try:
inv_cov_mat = linalg.inv(cov_mat)
except np.linalg.LinAlgError:
inv_cov_mat = linalg.pinv(cov_mat)
return mean_vec, cov_mat, inv_cov_mat
def compute_mahalanobis_distance(self, x: np.ndarray, mean_vec: np.ndarray,
                                 inv_cov_mat: np.ndarray) -> float:
    """Return the RX (squared Mahalanobis) distance of one pixel.

    Args:
        x: pixel spectrum vector.
        mean_vec: background mean vector.
        inv_cov_mat: inverse of the background covariance matrix.

    Returns:
        (x - mu)^T K^-1 (x - mu), clamped to be non-negative against
        numerical round-off.
    """
    centered = x - mean_vec
    distance = centered @ inv_cov_mat @ centered.T
    return max(0, distance)
def detect_anomalies_global(self) -> np.ndarray:
    """Run RX with the precomputed global/custom background statistics.

    Returns:
        float32 distance map of shape (n_samples,).
    """
    print("正在进行全局RX异常检测...")
    # Vectorized Mahalanobis distances for all pixels at once:
    # d_i = (x_i - mu)^T K^-1 (x_i - mu). This replaces the original
    # per-pixel Python loop with a single einsum — identical values,
    # orders of magnitude faster on large images.
    centered = self.data - self.background_mean
    distances = np.einsum('ij,jk,ik->i', centered, self.background_inv_cov, centered)
    # Clamp tiny negative round-off, mirroring the old per-pixel max(0, d).
    return np.maximum(distances, 0).astype(np.float32)
def detect_anomalies_local_row(self) -> np.ndarray:
    """Run RX with a sliding row-band background.

    For each image row, background statistics are estimated from a window
    of ``config.row_window`` rows centered on it (clipped at the borders),
    then all pixels of the row are scored against those statistics.

    Returns:
        float32 distance map flattened to (rows * cols,).
    """
    print("正在进行局部行RX异常检测...")
    rows, cols, bands = self.data_shape
    distance_map = np.zeros((rows, cols), dtype=np.float32)
    image_data = self.data.reshape(rows, cols, bands)
    half_window = self.config.row_window // 2
    for i in range(rows):
        # Row window clipped at the image borders.
        row_start = max(0, i - half_window)
        row_end = min(rows, i + half_window + 1)
        window_pixels = image_data[row_start:row_end, :, :].reshape(-1, bands)
        local_mean, _, local_inv_cov = self._compute_local_background_stats(window_pixels)
        # Vectorized Mahalanobis distances for the whole row (the original
        # scored the row pixel by pixel; the values are identical).
        centered = image_data[i, :, :] - local_mean
        d = np.einsum('ij,jk,ik->i', centered, local_inv_cov, centered)
        distance_map[i, :] = np.maximum(d, 0)
    return distance_map.flatten()
def detect_anomalies_local_square(self) -> np.ndarray:
    """Run RX with a sliding square-window background.

    For each pixel, statistics are estimated from the surrounding
    ``config.window_size`` x ``config.window_size`` window (clipped at the
    image borders) and the pixel is scored against them. Note the window
    includes the pixel under test itself.

    Returns:
        float32 distance map flattened to (rows * cols,).
    """
    print("正在进行局部正方形RX异常检测...")
    rows, cols, bands = self.data_shape
    scores = np.zeros((rows, cols), dtype=np.float32)
    cube = self.data.reshape(rows, cols, bands)
    half = self.config.window_size // 2
    for r in range(rows):
        # Row bounds depend only on r, so they are hoisted out of the
        # inner loop; the values are unchanged.
        r0, r1 = max(0, r - half), min(rows, r + half + 1)
        for c in range(cols):
            c0, c1 = max(0, c - half), min(cols, c + half + 1)
            neighborhood = cube[r0:r1, c0:c1, :].reshape(-1, bands)
            mu, _, inv_cov = self._compute_local_background_stats(neighborhood)
            scores[r, c] = self.compute_mahalanobis_distance(cube[r, c, :], mu, inv_cov)
    return scores.flatten()
def detect_anomalies(self) -> np.ndarray:
    """Dispatch detection to the configured background model.

    Also applies the optional threshold to build ``self.anomaly_map`` and
    prints summary statistics of the distance map.

    Returns:
        Mahalanobis distance map (flattened, one score per pixel).

    Raises:
        ValueError: for an unrecognized background model.
    """
    model = self.config.background_model
    if model in ('global', 'custom'):
        self.compute_background_statistics()
        distance_map = self.detect_anomalies_global()
    elif model in ('local_row', 'local_square'):
        # BUG FIX: precompute global statistics so the local estimator has
        # a valid fallback when a window holds fewer than two usable pixels
        # (previously the fallback returned None and crashed downstream).
        self._compute_global_background()
        if model == 'local_row':
            distance_map = self.detect_anomalies_local_row()
        else:
            distance_map = self.detect_anomalies_local_square()
    else:
        # BUG FIX: an unknown model previously left distance_map unbound
        # and surfaced as a confusing NameError below.
        raise ValueError(f"不支持的背景模型: {model}")
    self.distance_map = distance_map
    # Optional binary segmentation.
    if self.config.threshold is not None:
        self.anomaly_map = (distance_map > self.config.threshold).astype(np.int32)
    else:
        self.anomaly_map = None
    # Summary statistics. BUG FIX: the original printed the literal
    # string ".4f" four times (mangled f-strings) instead of the values.
    mean_dist = np.mean(distance_map)
    std_dist = np.std(distance_map)
    max_dist = np.max(distance_map)
    print("异常检测完成:")
    print(f"平均距离: {mean_dist:.4f}")
    print(f"距离标准差: {std_dist:.4f}")
    print(f"最大距离: {max_dist:.4f}")
    if self.anomaly_map is not None:
        n_anomalies = np.sum(self.anomaly_map)
        anomaly_ratio = n_anomalies / len(distance_map)
        print(f"异常比例: {anomaly_ratio:.3f}")
    return distance_map
def run_analysis_from_config(self) -> Tuple[np.ndarray, np.ndarray]:
    """Run detection and derive binary predictions from the scores.

    Returns:
        (predictions, scores): predictions are int32 0/1, cut at the
        configured threshold when given, otherwise at the
        (1 - contamination) percentile of the score distribution.
    """
    scores = self.detect_anomalies()
    cutoff = self.config.threshold
    if cutoff is None:
        # contamination is the expected anomaly fraction, so the cut-off
        # sits at the (1 - contamination)*100 percentile of the scores.
        cutoff = np.percentile(scores, (1 - self.config.contamination) * 100)
    predictions = (scores > cutoff).astype(np.int32)
    return predictions, scores
def save_results(self, output_path: Optional[str] = None) -> None:
    """Persist detection results as an ENVI file.

    Args:
        output_path: explicit target path; when None, falls back to
            config.output_path, then to an auto-named file inside
            config.output_dir.

    Raises:
        ValueError: when no output location can be determined.
    """
    if output_path is None:
        if self.config.output_path:
            output_path = self.config.output_path
        elif self.config.output_dir:
            base_name = Path(self.config.input_path).stem
            output_path = os.path.join(self.config.output_dir, f"{base_name}_rx_anomaly.bip")
        else:
            raise ValueError("必须指定输出路径")
    print(f"正在保存结果到: {output_path}")
    # BUG FIX: guard against a bare filename, where dirname() returns ''
    # and os.makedirs('') raises FileNotFoundError.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # CSV input has no image geometry, so it is written as a 1xN image.
    if self.input_format == 'envi':
        self._save_envi_results(output_path)
    else:
        self._save_csv_results_as_envi(output_path)
def _save_envi_results(self, output_path: str) -> None:
"""保存ENVI格式结果"""
# 重塑距离图为原始图像尺寸
rows, cols, bands = self.data_shape
# 创建单波段异常检测结果图像
if self.anomaly_map is not None:
# 保存异常检测结果 (0/1)
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
anomaly_image[:, :, 0] = self.anomaly_map.reshape(rows, cols)
band_names = ["Anomaly_Map"]
else:
# 保存距离图
anomaly_image = np.zeros((rows, cols, 1), dtype=np.float32)
anomaly_image[:, :, 0] = self.distance_map.reshape(rows, cols)
band_names = ["Mahalanobis_Distance"]
# 保存为ENVI格式参考Covariance.py的方式
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
def _save_csv_results_as_envi(self, output_path: str) -> None:
"""将CSV结果保存为ENVI格式"""
# 对于CSV数据创建1xN的单波段图像
n_samples = len(self.distance_map)
if self.anomaly_map is not None:
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
anomaly_image[0, :, 0] = self.anomaly_map
band_names = ["Anomaly_Map"]
else:
anomaly_image = np.zeros((1, n_samples, 1), dtype=np.float32)
anomaly_image[0, :, 0] = self.distance_map
band_names = ["Mahalanobis_Distance"]
# 保存为ENVI格式参考Covariance.py的方式
self._save_envi_data(anomaly_image, output_path, band_names, 'bip')
def _save_envi_data(self, data: np.ndarray, output_path: Union[str, Path],
band_names: List[str], interleave: str = 'bip') -> None:
"""
保存数据为ENVI格式参考Covariance.py的实现
参数:
data: 要保存的数据数组 (height, width, channels)
output_path: 输出文件路径
band_names: 波段名称列表
interleave: 交织方式 ('bip', 'bil', 'bsq')
"""
output_path = Path(output_path)
height, width, channels = data.shape
# 确保目录存在
output_path.parent.mkdir(parents=True, exist_ok=True)
print(f"保存ENVI文件 - 形状: {data.shape}, 交织方式: {interleave}")
# 保存.dat文件二进制格式
self._save_dat_file(data, str(output_path))
# 保存.hdr头文件
hdr_path = str(output_path).replace('.dat', '.hdr')
self._save_hdr_file(hdr_path, data.shape, band_names, interleave)
print(f"RX异常检测结果已保存: {output_path}")
print(f"头文件已保存: {hdr_path}")
print(f"使用的交织格式: {interleave.upper()}")
def _save_dat_file(self, data: np.ndarray, file_path: str) -> None:
"""保存.dat文件二进制格式参考Covariance.py"""
# 根据数据类型保存
if data.dtype == np.float32:
# 对于浮点数据,直接保存
with open(file_path, 'wb') as f:
data.astype(np.float32).tofile(f)
else:
# 对于其他数据类型转换为float32保存
with open(file_path, 'wb') as f:
data.astype(np.float32).tofile(f)
def _save_hdr_file(self, hdr_path: str, data_shape: Tuple[int, ...],
band_names: List[str], interleave: str) -> None:
"""保存ENVI头文件参考Covariance.py并适配RX数据"""
height, width, channels = data_shape
# 确定数据类型编码
# ENVI数据类型: 4=float32, 5=float64, 1=uint8, 2=int16, 3=int32, 12=uint16
data_type_code = 4 # 默认float32
header_content = f"""ENVI
description = {{RX Anomaly Detection Results - Generated by RXAnomalyDetector}}
samples = {width}
lines = {height}
bands = {channels}
header offset = 0
file type = ENVI Standard
data type = {data_type_code}
interleave = {interleave}
byte order = 0
"""
# 添加波段名称
if band_names:
header_content += "band names = {\n"
for i, name in enumerate(band_names):
header_content += f' "{name}"'
if i < len(band_names) - 1:
header_content += ","
header_content += "\n"
header_content += "}\n"
# 添加RX相关元数据
header_content += f"anomaly_detection_method = rx\n"
header_content += f"background_model = {self.config.background_model}\n"
if self.config.background_model in ['local_row', 'local_square']:
header_content += f"window_size = {self.config.window_size}\n"
if self.config.background_model == 'local_row':
header_content += f"row_window = {self.config.row_window}\n"
if self.config.threshold is not None:
header_content += f"threshold = {self.config.threshold}\n"
header_content += f"use_robust_covariance = {self.config.use_robust_covariance}\n"
with open(hdr_path, 'w', encoding='utf-8') as f:
f.write(header_content)
def get_statistics(self) -> Dict[str, Any]:
    """Summarize the distance map (and anomaly map, when present).

    Returns:
        Dict with pixel count, model name, distance moments, selected
        percentiles, and — when a threshold was applied — anomaly counts.

    Raises:
        ValueError: if detection has not been run yet.
    """
    if self.distance_map is None:
        raise ValueError("请先运行异常检测")
    # Local name 'summary' avoids shadowing the module-level scipy.stats
    # import, which the original code shadowed with a local 'stats'.
    summary = {
        'total_pixels': len(self.distance_map),
        'background_model': self.config.background_model,
        'mean_distance': float(np.mean(self.distance_map)),
        'std_distance': float(np.std(self.distance_map)),
        'min_distance': float(np.min(self.distance_map)),
        'max_distance': float(np.max(self.distance_map)),
        'median_distance': float(np.median(self.distance_map)),
    }
    for percentile in (25, 75, 90, 95, 99):
        summary[f'percentile_{percentile}'] = float(np.percentile(self.distance_map, percentile))
    if self.anomaly_map is not None:
        summary['threshold'] = self.config.threshold
        summary['n_anomalies'] = int(np.sum(self.anomaly_map))
        summary['anomaly_ratio'] = float(np.sum(self.anomaly_map) / len(self.distance_map))
    return summary
def suggest_threshold(self, method: str = 'percentile') -> float:
    """Suggest a detection threshold from the score distribution.

    Args:
        method: 'percentile' (95th percentile of the distances),
            'mean_std' (mean + 2*std), or 'auto' (chi-square 95% quantile
            with one degree of freedom per band).

    Returns:
        Suggested threshold as a plain float.

    Raises:
        ValueError: if detection has not been run or *method* is unknown.
    """
    if self.distance_map is None:
        raise ValueError("请先运行异常检测")
    if method == 'percentile':
        threshold = np.percentile(self.distance_map, 95)
    elif method == 'mean_std':
        threshold = np.mean(self.distance_map) + 2 * np.std(self.distance_map)
    elif method == 'auto':
        # The RX score is compared against a chi-square quantile with the
        # degrees of freedom approximated by the number of bands.
        n_bands = self.data.shape[1]
        threshold = stats.chi2.ppf(0.95, n_bands)
    else:
        raise ValueError(f"不支持的阈值建议方法: {method}")
    return float(threshold)
def main():
    """Command-line entry point: parse args, run RX detection, save results.

    Returns:
        Process exit code: 0 on success, 1 on any error (with traceback).
    """
    parser = argparse.ArgumentParser(
        description='Reed-Xiaoli (RX) 高光谱异常检测',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
python RX.py --input image.bip --background-model global --output results.bip
python RX.py --input data.csv --spectral-start 400 --background-model local_square --window-size 7 --output results.bip
python RX.py --input image.bip --background-model custom --background background.bip --threshold 10.0 --output results.bip
背景模型说明:
global: 使用整个数据集作为背景
local_row: 使用像素周围N行的窗口作为背景
local_square: 使用固定大小的正方形窗口作为背景
custom: 使用外部提供的背景数据
"""
    )
    # Input options
    parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
    parser.add_argument('--background', help='背景数据路径 (custom模式时必需)')
    parser.add_argument('--label-col', help='标签列名 (CSV文件)')
    parser.add_argument('--spectral-start', help='光谱数据起始列名或波长 (CSV文件)')
    parser.add_argument('--spectral-end', help='光谱数据结束列名或波长 (CSV文件)')
    # Output options
    parser.add_argument('--output', help='输出文件路径')
    parser.add_argument('--output-dir', help='输出目录 (将自动生成文件名)')
    # Background model options
    parser.add_argument('--background-model', choices=['global', 'local_row', 'local_square', 'custom'],
                        default='global', help='背景模型类型默认global')
    parser.add_argument('--window-size', type=int, default=5,
                        help='局部窗口大小默认5')
    parser.add_argument('--row-window', type=int, default=3,
                        help='行窗口大小默认3')
    # Detection options
    parser.add_argument('--threshold', type=float,
                        help='异常检测阈值,不指定则只输出距离图')
    # NEW: expose the existing RXConfig.contamination field on the CLI
    # (backward compatible; default matches the config default).
    parser.add_argument('--contamination', type=float, default=0.1,
                        help='异常点比例,用于自动确定阈值默认0.1')
    parser.add_argument('--robust-covariance', action='store_true',
                        help='使用稳健协方差估计')
    args = parser.parse_args()
    try:
        # Build and validate the configuration.
        config = RXConfig(
            input_path=args.input,
            background_path=args.background,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            background_model=args.background_model,
            window_size=args.window_size,
            row_window=args.row_window,
            threshold=args.threshold,
            contamination=args.contamination,
            use_robust_covariance=args.robust_covariance
        )
        detector = RXAnomalyDetector(config)
        print("=== Reed-Xiaoli (RX) 异常检测 ===")
        print(f"输入文件: {config.input_path}")
        print(f"背景模型: {config.background_model}")
        if config.background_model in ['local_row', 'local_square']:
            print(f"窗口大小: {config.window_size}")
        if config.threshold is not None:
            print(f"检测阈值: {config.threshold}")
        print("-" * 50)
        # 1. Load, 2. preprocess, 3. detect (the returned distance map is
        # also kept on the detector, so no local variable is needed).
        detector.load_data()
        detector.preprocess_data()
        detector.detect_anomalies()
        # 4. Suggest a threshold when none was given.
        if config.threshold is None:
            suggested_threshold = detector.suggest_threshold('percentile')
            print(f"\n建议阈值 (95百分位数): {suggested_threshold:.4f}")
        # 5. Save results and 6. report statistics. Local name avoids
        # shadowing scipy.stats imported at module level.
        detector.save_results()
        result_stats = detector.get_statistics()
        print("\n=== 检测结果统计 ===")
        for key, value in result_stats.items():
            print(f"{key}: {value}")
        print("\n处理完成!")
    except Exception as e:
        print(f"错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Use SystemExit rather than the site-provided exit() builtin, which is
    # not guaranteed to exist when Python runs without the site module.
    raise SystemExit(main())