Files
HSI/Dimensionality_Reduction_method/dimensionality_reduction.py

1203 lines
51 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import pandas as pd
import spectral  # third-party "spectral" package: ENVI hyperspectral image I/O
from sklearn.decomposition import PCA, FastICA, FactorAnalysis
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
from sklearn.preprocessing import StandardScaler
# import h5py
import os
from scipy.io import savemat  # NOTE(review): savemat appears unused in this module — confirm before removing
import warnings
from typing import Optional, Dict, Any, Tuple, List
from dataclasses import dataclass, field
# NOTE(review): this blanket filter also suppresses the UserWarning calls that
# the compatibility/size checks below deliberately emit via warnings.warn —
# confirm that silencing them is intended.
warnings.filterwarnings('ignore')
# Visualization-related imports
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm
import matplotlib.colors as mcolors
@dataclass
class DimensionalityReductionConfig:
    """Configuration for hyperspectral dimensionality reduction.

    Paths and method names are validated at construction time, and
    per-method default hyper-parameters are filled into ``method_params``
    (user-supplied values always win over defaults).
    """

    # --- input file settings ---
    input_path: Optional[str] = None
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True
    # --- output settings ---
    output_path: Optional[str] = None
    output_dir: Optional[str] = None
    # --- reduction method settings ---
    method: str = 'PCA'
    n_components: int = 3
    method_params: Dict[str, Any] = field(default_factory=dict)
    # --- batch-processing settings ---
    batch_methods: List[str] = field(default_factory=lambda: ['PCA', 'ICA', 'FA'])
    batch_components: List[int] = field(default_factory=lambda: [3, 3, 3])
    # --- preprocessing ---
    use_standardization: bool = True
    # --- visualization ---
    generate_plots: bool = True   # whether to produce scatter plots
    color_by_label: bool = True   # whether to color points by class label
    save_plots: bool = True       # whether to write plot files to disk

    def __post_init__(self):
        """Validate fields and populate per-method defaults.

        Raises:
            ValueError: missing input path or unsupported method name(s).
            FileNotFoundError: the input file does not exist.
        """
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")
        # Output paths are only required for batch runs; single-method runs
        # may omit them, so no check is performed here.
        # Normalize every method name to lowercase before validating.
        self.method = self.method.lower()
        self.batch_methods = [name.lower() for name in self.batch_methods]
        supported_methods = self._get_supported_methods()
        if self.method not in supported_methods:
            raise ValueError(f"不支持的降维方法: {self.method}。支持的方法: {list(supported_methods.keys())}")
        if len(self.batch_methods) != len(self.batch_components):
            raise ValueError("batch_methods和batch_components长度必须相等")
        for method in self.batch_methods:
            if method not in supported_methods:
                raise ValueError(f"批量处理中包含不支持的方法: {method}")
        self._set_method_default_params()

    def _get_supported_methods(self) -> Dict[str, str]:
        """Return a mapping of supported method keys (lowercase) to display names."""
        return {
            'pca': 'Principal Component Analysis (PCA)',
            'ica': 'Independent Component Analysis (ICA)',
            'fa': 'Factor Analysis (FA)',
            'lda': 'Linear Discriminant Analysis (LDA)',
            'mds': 'Multidimensional Scaling (MDS)',
            'isomap': 'Isometric Mapping (Isomap)',
            'lle': 'Locally Linear Embedding (LLE)',
            't-sne': 't-Distributed Stochastic Neighbor Embedding (t-SNE)',
        }

    def _set_method_default_params(self):
        """Fill ``method_params`` with defaults for the selected method.

        Existing user-provided keys are preserved (``setdefault`` semantics).
        """
        defaults_by_method = {
            'pca': {'whiten': False, 'svd_solver': 'auto'},
            'ica': {'algorithm': 'parallel', 'whiten': 'unit-variance',
                    'fun': 'logcosh'},
            'fa': {'rotation': None, 'svd_method': 'randomized'},
            'lda': {'solver': 'svd', 'shrinkage': None},
            'mds': {'metric': True, 'dissimilarity': 'euclidean',
                    'random_state': 42},
            'isomap': {'n_neighbors': 5, 'eigen_solver': 'auto'},
            'lle': {'n_neighbors': 5, 'reg': 0.001, 'eigen_solver': 'auto'},
            't-sne': {'perplexity': 30.0, 'early_exaggeration': 12.0,
                      'learning_rate': 200.0, 'random_state': 42},
        }
        for key, value in defaults_by_method.get(self.method.lower(), {}).items():
            self.method_params.setdefault(key, value)
class HyperspectralDimReduction:
    """Hyperspectral dimensionality-reduction pipeline.

    Loads ENVI images or CSV spectra, applies one of several sklearn
    reducers according to the supplied config, and saves/visualizes
    the results.
    """
    def __init__(self, config: DimensionalityReductionConfig):
        """Store the validated config and initialize processing state."""
        self.config = config
        # Data storage
        self.data = None            # 2-D array (samples x bands) after load_data()
        self.labels = None          # per-sample labels from the CSV label column, if any
        self.header = None          # ENVI header dict (or a synthetic one for CSV input)
        self.original_shape = None  # shape of the data as loaded, before flattening
        self.data_type = None       # 'image' or 'csv'
        # Preprocessor (only built when standardization is enabled)
        self.scaler = StandardScaler() if config.use_standardization else None
        # Last fitted model and its explained-variance ratios (set for PCA only)
        self._last_model = None
        self._explained_variance_ratio = None
        # Cached results for re-use by the visualization helpers
        self._last_reduced_data = None
        self._last_method = None
    def load_data(self) -> None:
        """Load the input file named in the config.

        Dispatches on file extension: '.dat'/'.img'/'.hdr' are treated as
        ENVI hyperspectral images, '.csv' as tabular spectra.

        Raises:
            ValueError: for any other file extension.
        """
        print(f"正在加载数据文件: {self.config.input_path}")
        file_ext = os.path.splitext(self.config.input_path)[1].lower()
        if file_ext in ['.dat', '.img', '.hdr']:
            self._load_image_data(self.config.input_path)
            self.data_type = 'image'
        elif file_ext == '.csv':
            self._load_csv_data(self.config.input_path, self.config.label_col,
                                self.config.spectral_start, self.config.spectral_end,
                                self.config.csv_has_header)
            self.data_type = 'csv'
        else:
            raise ValueError(f"不支持的文件格式: {file_ext}")
        print("数据加载完成")
    def _load_image_data(self, file_path):
        """Load an ENVI image plus its header and flatten it to (pixels, bands)."""
        base_path = os.path.splitext(file_path)[0]
        hdr_file = base_path + '.hdr'
        # Always open via the companion .hdr file, whichever file was passed in.
        img = spectral.open_image(hdr_file)
        self.data = img.load()
        self.original_shape = self.data.shape
        # Keep the raw header so save_results() can preserve sensor metadata.
        self.header = spectral.envi.read_envi_header(hdr_file)
        # Reshape (rows, cols, bands) -> (rows*cols, bands) for the sklearn reducers.
        self.data = self.data.reshape(-1, self.data.shape[2])
    def _load_csv_data(self, file_path, label_col, spectral_start,
                       spectral_end, has_header):
        """Load spectra from a CSV file.

        Args:
            file_path: CSV path.
            label_col: label column name (str) or positional index (int); optional.
            spectral_start: first spectral column, given as a column name or a
                wavelength value (matched to the closest wavelength-like column).
            spectral_end: last spectral column (same conventions); optional.
            has_header: whether the CSV has a header row.

        Side effects:
            Sets self.data, self.labels, self.original_shape and a synthetic
            ENVI-like self.header.
        """
        if has_header:
            df = pd.read_csv(file_path)
        else:
            df = pd.read_csv(file_path, header=None)
        # Resolve a wavelength-like string to an actual column name.
        def find_column_by_wavelength(df, wavelength_str):
            """Return the column whose wavelength is closest to *wavelength_str*."""
            try:
                target_wavelength = float(wavelength_str)
                # Collect (column, wavelength) pairs from parseable column names.
                wavelength_cols = []
                for col in df.columns:
                    if isinstance(col, str):
                        try:
                            # Names like "wavelength_400.5"
                            if 'wavelength' in col.lower():
                                wl = float(col.split('_')[-1])
                                wavelength_cols.append((col, wl))
                            # Purely numeric column names.
                            # NOTE(review): replace('.','').isdigit() rejects
                            # negative/scientific-notation names — confirm OK.
                            elif col.replace('.', '').isdigit():
                                wl = float(col)
                                wavelength_cols.append((col, wl))
                        except ValueError:
                            continue
                if wavelength_cols:
                    # Pick the column whose wavelength is nearest the target.
                    closest_col = min(wavelength_cols, key=lambda x: abs(x[1] - target_wavelength))[0]
                    return closest_col
                else:
                    # NOTE(review): this ValueError is raised inside the same
                    # try block and is therefore swallowed by the handler
                    # below, which silently returns the raw string — confirm
                    # whether the error was meant to propagate.
                    raise ValueError(f"无法在列名中找到波长信息: {list(df.columns)}")
            except ValueError:
                # Not a number (or no match): use the string as a column name.
                return wavelength_str
        if spectral_start is not None:
            if isinstance(spectral_start, str):
                try:
                    # If it parses as a number, treat it as a wavelength value.
                    float(spectral_start)
                    spectral_start_col = find_column_by_wavelength(df, spectral_start)
                except ValueError:
                    # Otherwise treat it as a literal column name.
                    spectral_start_col = spectral_start
            else:
                spectral_start_col = spectral_start
            if spectral_end is not None:
                if isinstance(spectral_end, str):
                    try:
                        float(spectral_end)
                        spectral_end_col = find_column_by_wavelength(df, spectral_end)
                    except ValueError:
                        spectral_end_col = spectral_end
                else:
                    spectral_end_col = spectral_end
                # Inclusive label-based column-range selection.
                spectral_cols = df.loc[:, spectral_start_col:spectral_end_col].columns
                spectral_data = df.loc[:, spectral_start_col:spectral_end_col].values
            else:
                # Start column only: take everything from it to the last column.
                start_idx = df.columns.get_loc(spectral_start_col)
                spectral_data = df.iloc[:, start_idx:].values
                spectral_cols = df.iloc[:, start_idx:].columns
        else:
            # No explicit range: use all numeric columns except the label column.
            if label_col is not None:
                if isinstance(label_col, str) and label_col in df.columns:
                    spectral_data = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).columns
                elif isinstance(label_col, int):
                    spectral_data = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).columns
                else:
                    spectral_data = df.select_dtypes(include=[np.number]).values
                    spectral_cols = df.select_dtypes(include=[np.number]).columns
            else:
                spectral_data = df.select_dtypes(include=[np.number]).values
                spectral_cols = df.select_dtypes(include=[np.number]).columns
        # Extract labels
        if label_col is not None:
            if isinstance(label_col, str):
                self.labels = df[label_col].values
            else:
                self.labels = df.iloc[:, label_col].values
            # Drop the label column from the frame.
            # NOTE(review): spectral_data was already materialized above, so
            # this drop no longer affects the stored data — confirm intent.
            if label_col in df.columns:
                df = df.drop(columns=[label_col])
            elif isinstance(label_col, int):
                df = df.drop(columns=df.columns[label_col])
        self.data = spectral_data
        self.original_shape = self.data.shape
        # Synthetic ENVI-style header describing the CSV as a one-line image.
        self.header = {
            'bands': spectral_data.shape[1],
            'lines': 1,
            'samples': spectral_data.shape[0],
            'interleave': 'bsq',
            'data type': 4,  # float32
            'byte order': 0,
            'wavelength': [f'Band_{i}' for i in range(spectral_data.shape[1])]
        }
    def apply_dim_reduction(self) -> Tuple[np.ndarray, List[str]]:
        """Apply the configured dimensionality-reduction method.

        Returns:
            Tuple[np.ndarray, List[str]]: (reduced data, band-name list).

        Applicability notes (translated from the original author):
          Hyperspectral image data (data_type='image'):
            - Recommended: PCA, ICA, FA (efficient for high-dimensional data)
            - Usable with care: LDA (needs labels), MDS
            - Not recommended: Isomap, LLE, t-SNE (high complexity / memory)
          CSV spectral data (data_type='csv'):
            - All methods usable; choose by dataset size:
              small (<1000 samples): any method
              large (>10000 samples): avoid t-SNE, MDS, Isomap, LLE
          Constraints:
            - LDA requires labels (self.labels must not be None)
            - t-SNE/MDS/Isomap/LLE are O(n^2) or worse in the sample count
        """
        print(f"正在应用降维方法: {self.config.method} (维度: {self.config.n_components})")
        # Warn/raise early for incompatible method/data combinations.
        self._check_method_compatibility(self.config.method)
        # Size-based warnings for high-complexity methods.
        n_samples, n_features = self.data.shape
        self._check_data_size_compatibility(self.config.method, n_samples, n_features)
        # Optional standardization
        if self.config.use_standardization and self.scaler is not None:
            data_scaled = self.scaler.fit_transform(self.data)
            print("已应用数据标准化")
        else:
            data_scaled = self.data
        # Work on a shallow copy of the configured hyper-parameters.
        method_params = {**self.config.method_params}
        # Dispatch to the chosen reducer
        if self.config.method.upper() == 'PCA':
            reducer = PCA(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'PCA_{i+1}' for i in range(self.config.n_components)]
            # Keep the fitted model so plots can annotate explained variance.
            self._last_model = reducer
            self._explained_variance_ratio = reducer.explained_variance_ratio_
        elif self.config.method.upper() == 'LDA':
            if self.labels is None:
                raise ValueError("LDA需要标签信息请确保加载数据时指定了标签列")
            # LDA yields at most (n_classes - 1) components.
            reducer = LinearDiscriminantAnalysis(n_components=min(self.config.n_components,
                                                                  len(np.unique(self.labels))-1),
                                                 **method_params)
            reduced_data = reducer.fit_transform(data_scaled, self.labels)
            band_names = [f'LDA_{i+1}' for i in range(reduced_data.shape[1])]
        elif self.config.method.upper() == 'ICA':
            reducer = FastICA(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'ICA_{i+1}' for i in range(self.config.n_components)]
        elif self.config.method.upper() == 'FA':
            reducer = FactorAnalysis(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'FA_{i+1}' for i in range(self.config.n_components)]
        elif self.config.method.upper() == 'MDS':
            reducer = MDS(n_components=self.config.n_components,**method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'MDS_{i+1}' for i in range(self.config.n_components)]
        elif self.config.method.upper() == 'ISOMAP':
            reducer = Isomap(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'Isomap_{i+1}' for i in range(self.config.n_components)]
        elif self.config.method.upper() == 'LLE':
            reducer = LocallyLinearEmbedding(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'LLE_{i+1}' for i in range(self.config.n_components)]
        elif self.config.method.upper() == 'T-SNE':
            reducer = TSNE(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'tSNE_{i+1}' for i in range(self.config.n_components)]
        else:
            raise ValueError(f"不支持的降维方法: {self.config.method}")
        # Cache the result for generate_visualization().
        self._last_reduced_data = reduced_data
        self._last_method = self.config.method
        return reduced_data, band_names
    def generate_visualization(self, reduced_data=None, method=None, output_path=None):
        """Standalone entry point for producing scatter plots.

        Args:
            reduced_data: reduced array; defaults to the last reduction result.
            method: method name; defaults to the configured method.
            output_path: base path for plot files; defaults to
                'visualization_<method>.csv'.
        """
        if reduced_data is None and hasattr(self, '_last_reduced_data'):
            reduced_data = self._last_reduced_data
        if method is None:
            method = self.config.method
        # Cache the inputs for any later calls.
        self._last_reduced_data = reduced_data
        self._last_method = method
        if output_path is None:
            output_path = f'visualization_{method.lower()}.csv'
        # Plots are only produced for CSV-type data.
        if self.config.generate_plots and self.data_type == 'csv':
            self._generate_visualization_plots(reduced_data, method, output_path)
def _check_method_compatibility(self, method):
"""
检查降维方法与数据类型的兼容性
高光谱图像数据适用性:
- PCA, ICA, FA: 高度推荐,计算效率高,适合高维数据
- LDA: 需要标签,适用于有监督降维
- MDS: 可用但计算量较大
- Isomap, LLE, t-SNE: 不推荐,计算复杂度高,内存占用大
CSV数据适用性:
- 所有方法都可用,但大数据集应避免复杂度高的方法
"""
method = method.upper()
# LDA需要标签检查
if method == 'LDA' and self.labels is None:
raise ValueError("LDA方法需要标签信息请确保数据包含标签列")
# 高光谱数据的限制
if self.data_type == 'image':
high_complexity_methods = ['ISOMAP', 'LLE', 'T-SNE']
if method in high_complexity_methods:
import warnings
warnings.warn(
f"警告: {method}方法对高光谱图像数据计算复杂度高,"
"可能导致内存不足或计算时间过长。建议使用PCA、ICA或FA方法。",
UserWarning
)
# CSV数据的小提示
elif self.data_type == 'csv':
if method in ['T-SNE', 'MDS', 'ISOMAP', 'LLE']:
import warnings
warnings.warn(
f"提示: {method}方法计算复杂度较高,"
"如果数据集较大建议使用PCA、ICA或FA方法。",
UserWarning
)
def _check_data_size_compatibility(self, method, n_samples, n_features):
"""
根据数据大小给出方法适用性建议
参数:
method: 降维方法
n_samples: 样本数量
n_features: 特征数量
"""
method = method.upper()
high_complexity_methods = ['T-SNE', 'MDS', 'ISOMAP', 'LLE']
# 对于大数据集的警告
if method in high_complexity_methods:
if n_samples > 5000:
import warnings
warnings.warn(
f"警告: 数据集有{n_samples}个样本,{method}方法的计算复杂度为O(n²)"
"可能需要很长时间。建议使用PCA、ICA或FA方法。",
UserWarning
)
elif n_samples > 10000:
raise ValueError(
f"数据集过大({n_samples}样本){method}方法不适合。"
"请使用PCA、ICA或FA方法。"
)
# 对于高维数据的建议
if n_features > 1000 and method not in ['PCA', 'ICA', 'FA']:
import warnings
warnings.warn(
f"提示: 数据维度很高({n_features})建议优先考虑PCA、ICA或FA方法。",
UserWarning
)
def get_recommended_methods(self, data_type=None, n_samples=None, n_features=None,
has_labels=False, max_components=None):
"""
根据数据特征推荐合适的降维方法
参数:
data_type: 数据类型 ('image''csv')
n_samples: 样本数量
n_features: 特征数量
has_labels: 是否有标签
max_components: 最大降维维度
返回:
dict: 包含推荐方法和理由的字典
"""
if data_type is None:
data_type = self.data_type
if n_samples is None and self.data:
n_samples = self.data.shape[0]
if n_features is None and self.data:
n_features = self.data.shape[1]
if not has_labels and self.labels is not None:
has_labels = True
recommendations = {
'highly_recommended': [],
'recommended': [],
'use_with_caution': [],
'not_recommended': []
}
# 高光谱图像数据推荐
if data_type == 'image':
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
if has_labels:
recommendations['recommended'].append('LDA')
else:
recommendations['recommended'].append('MDS')
recommendations['use_with_caution'] = ['Isomap', 'LLE']
recommendations['not_recommended'] = ['t-SNE']
# CSV数据推荐
elif data_type == 'csv':
if n_samples and n_samples < 1000: # 小数据集
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
recommendations['recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']
if has_labels:
recommendations['recommended'].append('LDA')
elif n_samples and n_samples < 5000: # 中等数据集
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
recommendations['recommended'] = ['MDS']
if has_labels:
recommendations['recommended'].append('LDA')
recommendations['use_with_caution'] = ['Isomap', 'LLE', 't-SNE']
else: # 大数据集
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
if has_labels:
recommendations['recommended'].append('LDA')
recommendations['not_recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']
# 高维数据额外建议
if n_features and n_features > 500:
# 对于高维数据PCA通常是最佳选择
if 'PCA' not in recommendations['highly_recommended']:
recommendations['highly_recommended'].insert(0, 'PCA')
return recommendations
    @staticmethod
    def print_method_info():
        """Print a human-readable catalogue of every supported method."""
        # NOTE(review): several 'complexity' values below are empty strings —
        # they look like characters lost in an encoding/paste step; confirm
        # against the original source before relying on them.
        method_info = {
            'PCA': {
                'name': '主成分分析 (Principal Component Analysis)',
                'description': '通过正交变换将数据投影到新的坐标系,保留最大方差的方向',
                'advantages': ['计算效率高', '适合高维数据', '无参数', '保持全局结构'],
                'limitations': ['线性方法', '对异常值敏感'],
                'best_for': ['高光谱数据预处理', '噪声去除', '数据可视化'],
                'data_types': ['image', 'csv'],
                'complexity': ''
            },
            'ICA': {
                'name': '独立成分分析 (Independent Component Analysis)',
                'description': '假设数据是多个独立源的混合,通过统计独立性分离信号',
                'advantages': ['能发现隐藏因素', '适合信号分离', '非线性'],
                'limitations': ['计算复杂度较高', '结果可能不稳定'],
                'best_for': ['信号分离', '去除混合噪声', '发现潜在成分'],
                'data_types': ['image', 'csv'],
                'complexity': ''
            },
            'FA': {
                'name': '因子分析 (Factor Analysis)',
                'description': '假设观测变量由潜在因子和误差组成,建立因子模型',
                'advantages': ['解释性强', '适合社会科学数据', '处理测量误差'],
                'limitations': ['假设较多', '对数据分布敏感'],
                'best_for': ['探索潜在结构', '变量分组', '降维解释'],
                'data_types': ['image', 'csv'],
                'complexity': ''
            },
            'LDA': {
                'name': '线性判别分析 (Linear Discriminant Analysis)',
                'description': '通过最大化类间距离和最小化类内距离进行降维',
                'advantages': ['监督学习', '保持类别可分性', '计算简单'],
                'limitations': ['需要标签', '线性假设', '对异常值敏感'],
                'best_for': ['分类任务', '有标签数据', '保持类别结构'],
                'data_types': ['csv'],  # mainly CSV: images rarely have per-pixel labels
                'complexity': ''
            },
            'MDS': {
                'name': '多维尺度分析 (Multidimensional Scaling)',
                'description': '保持样本间距离关系在低维空间中的映射',
                'advantages': ['保持距离关系', '无分布假设', '直观'],
                'limitations': ['计算复杂度O(n²)', '对大数据不适用'],
                'best_for': ['保持相对距离', '数据可视化', '相似性分析'],
                'data_types': ['csv'],
                'complexity': ''
            },
            'Isomap': {
                'name': '等度量映射 (Isometric Mapping)',
                'description': '保持测地线距离,在流形上计算最短路径',
                'advantages': ['处理非线性流形', '保持局部和全局结构'],
                'limitations': ['对噪声敏感', '计算复杂度高', '参数选择重要'],
                'best_for': ['非线性流形数据', '保持几何结构'],
                'data_types': ['csv'],
                'complexity': '很高'
            },
            'LLE': {
                'name': '局部线性嵌入 (Locally Linear Embedding)',
                'description': '保持局部邻域的线性关系进行降维',
                'advantages': ['处理非线性数据', '保持局部结构', '无优化参数'],
                'limitations': ['对大数据不适用', '对噪声敏感', '邻域参数敏感'],
                'best_for': ['非线性降维', '保持局部邻域关系'],
                'data_types': ['csv'],
                'complexity': '很高'
            },
            't-SNE': {
                'name': 't-分布随机邻域嵌入 (t-SNE)',
                'description': '将高维数据转换为低维空间,保持局部相似性',
                'advantages': ['优秀的可视化效果', '保持局部结构', '非线性'],
                'limitations': ['计算复杂度极高', '主要用于可视化', '结果不稳定'],
                'best_for': ['数据可视化', '聚类分析', '探索性分析'],
                'data_types': ['csv'],
                'complexity': '极高'
            }
        }
        print("\n=== 降维方法详细说明 ===\n")
        for method, info in method_info.items():
            print(f"{method} - {info['name']}")
            print(f" 描述: {info['description']}")
            print(f" 优势: {', '.join(info['advantages'])}")
            print(f" 局限性: {', '.join(info['limitations'])}")
            print(f" 适用场景: {', '.join(info['best_for'])}")
            print(f" 数据类型: {', '.join(info['data_types'])}")
            print(f" 计算复杂度: {info['complexity']}")
            print()
    def save_results(self, reduced_data, band_names, output_path, method):
        """Save the reduction result.

        For image input, writes an ENVI .dat + .hdr pair; for CSV input,
        writes a single CSV (with the labels, if any, as the first column)
        and optionally generates scatter plots.

        Args:
            reduced_data: reduced 2-D array (samples x components).
            band_names: names of the reduced components.
            output_path: destination path (extension adjusted per data type).
            method: reduction method name (recorded in the header / plots).
        """
        output_base = os.path.splitext(output_path)[0]
        if self.data_type == 'image':
            # Restore the image's spatial shape: (rows, cols, components).
            if self.original_shape:
                img_shape = (self.original_shape[0], self.original_shape[1],
                             reduced_data.shape[1])
                reduced_img = reduced_data.reshape(img_shape)
            else:
                reduced_img = reduced_data
            # Raw float32 payload
            self._save_dat_file(reduced_img, output_path)
            # Matching ENVI header
            hdr_path = output_base + '.hdr'
            self._save_hdr_file(hdr_path, reduced_img.shape, band_names, method)
            print(f"图像已保存: {output_path}")
            print(f"头文件已保存: {hdr_path}")
        else:  # CSV input: CSV output only, no .dat/.hdr pair
            df_output = pd.DataFrame(reduced_data, columns=band_names)
            # Labels, when present, go in the first column.
            if self.labels is not None:
                df_output.insert(0, 'Label', self.labels)
            # Reuse the given .csv path or derive one from the base name.
            if output_path.lower().endswith('.csv'):
                csv_path = output_path
            else:
                csv_path = output_base + '.csv'
            df_output.to_csv(csv_path, index=False)
            print(f"CSV文件已保存: {csv_path}")
            print("注意CSV输入文件只保存CSV格式结果不生成DAT和HDR文件")
            # NOTE(review): batch_process() also calls
            # _generate_visualization_plots after save_results, so plots may
            # be rendered twice in batch mode — confirm intent.
            if self.config.generate_plots:
                self._generate_visualization_plots(reduced_data, method, csv_path)
def _save_dat_file(self, data, file_path):
"""保存.dat文件二进制格式"""
with open(file_path, 'wb') as f:
data.astype(np.float32).tofile(f)
    def _save_hdr_file(self, hdr_path, data_shape, band_names, method, is_csv=False):
        """Write an ENVI header file describing the saved reduction result.

        Args:
            hdr_path: destination .hdr path.
            data_shape: shape of the saved array (3-D image or 2-D table).
            band_names: names of the reduced components.
            method: reduction method name (used in description/history).
            is_csv: treat the data as a one-line "image" (CSV input).
        """
        if is_csv:
            lines, samples = 1, data_shape[0]
        else:
            lines, samples = data_shape[0], data_shape[1]
        bands = data_shape[2] if len(data_shape) == 3 else data_shape[1]
        # Carry selected acquisition metadata over from the original header.
        original_info = {}
        if self.header and self.data_type == 'image':
            preserve_keys = [
                'sensor type', 'wavelength units', 'sample binning',
                'spectral binning', 'line binning', 'shutter', 'gain',
                'framerate', 'imager serial number', 'rotation', 'label',
                'map info', 'coordinate system string', 'classes'
            ]
            for key in preserve_keys:
                if key in self.header:
                    original_info[key] = self.header[key]
            # Reduced bands get new names; original wavelengths no longer apply.
            # NOTE(review): both branches are identical, so the 'wavelength'
            # check has no effect — confirm whether the original wavelengths
            # were meant to be preserved in the if-branch.
            if 'wavelength' in self.header:
                original_info['band names'] = band_names
            else:
                original_info['band names'] = band_names
        header_content = f"""ENVI
description = {{{method} Result [{pd.Timestamp.now().strftime('%a %b %d %H:%M:%S %Y')}]}}
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
"""
        # Append the preserved metadata, formatting each key appropriately.
        for key, value in original_info.items():
            if key == 'band names':
                header_content += f"band names = {{{', '.join(value)}}}\n"
            elif key == 'wavelength':
                # Comma-separated list of wavelengths, 6 decimal places.
                wavelength_str = ', '.join([f'{w:.6f}' for w in value]) if isinstance(value, (list, np.ndarray)) else str(value)
                header_content += f"original wavelength = {{{wavelength_str}}}\n"
            elif key == 'rotation':
                # Rotation may be stored as four coordinate pairs.
                if isinstance(value, list) and len(value) == 4:
                    rotation_str = f"[({value[0][0]}, {value[0][1]}), ({value[1][0]}, {value[1][1]}), ({value[2][0]}, {value[2][1]}), ({value[3][0]}, {value[3][1]})]"
                    header_content += f"rotation = {rotation_str}\n"
                else:
                    header_content += f"rotation = {value}\n"
            elif isinstance(value, str) and '\n' in value:
                # Multi-line values go into a brace block.
                header_content += f"{key} = {{\n{value}\n}}\n"
            else:
                header_content += f"{key} = {value}\n"
        # Generic fallback metadata for CSV input / missing original header.
        if not original_info:
            header_content += f"""wavelength units = Unknown
band names = {{{', '.join(band_names)}}}
classes = 0
map info = {{Geographic Lat/Lon, 1.0000, 1.0000, 0.0, 0.0, 0.0, 0.0}}
"""
        # Append a processing-history entry in the tool's own format.
        if self.header and 'history' in self.header:
            history_info = f"{self.header['history']} -> {method}<SpecGroup>[<SpecInt label:'Components' value:{bands}>]"
        else:
            history_info = f"{method}<SpecGroup>[<SpecInt label:'Components' value:{bands}>]"
        header_content += f"history = {history_info}\n"
        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)
    def _generate_visualization_plots(self, reduced_data, method, csv_path):
        """Produce 2-D/3-D scatter plots of the reduced data (best effort)."""
        try:
            # Explained-variance annotations are only available for PCA.
            explained_variance_ratio = self._get_explained_variance_ratio(method)
            # 2-D scatter of the first two components
            if reduced_data.shape[1] >= 2:
                self._plot_2d_scatter(reduced_data, method, explained_variance_ratio, csv_path)
            # 3-D scatter of the first three components
            if reduced_data.shape[1] >= 3:
                self._plot_3d_scatter(reduced_data, method, explained_variance_ratio, csv_path)
        except Exception as e:
            # Plotting failures must not abort the processing pipeline.
            print(f"生成可视化图表时出错: {e}")
def _get_explained_variance_ratio(self, method):
"""获取降维方法的贡献百分比"""
method_lower = method.lower()
# 对于PCA可以获取解释方差比
if method_lower == 'pca' and self._explained_variance_ratio is not None:
return self._explained_variance_ratio
# 对于其他方法,目前不支持贡献百分比
return None
    def _plot_2d_scatter(self, data, method, explained_variance_ratio, csv_path):
        """Draw (and optionally save) a 2-D scatter of the first two components."""
        plt.figure(figsize=(10, 8))
        if self.config.color_by_label and self.labels is not None:
            # One color per class label
            unique_labels = np.unique(self.labels)
            colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
            for i, label in enumerate(unique_labels):
                mask = self.labels == label
                plt.scatter(data[mask, 0], data[mask, 1],
                            c=[colors[i]], label=f'Class {label}',
                            alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
            plt.legend()
        else:
            # Single uniform color
            plt.scatter(data[:, 0], data[:, 1], alpha=0.7, s=50,
                        c='blue', edgecolors='black', linewidth=0.5)
        # Axis labels, annotated with explained variance when available (PCA).
        xlabel = 'PC1'
        ylabel = 'PC2'
        if explained_variance_ratio is not None and len(explained_variance_ratio) >= 2:
            xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
            ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'
        plt.xlabel(xlabel, fontsize=12, fontweight='bold')
        plt.ylabel(ylabel, fontsize=12, fontweight='bold')
        plt.title(f'{method.upper()} - 2D Scatter Plot', fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3)
        # Save next to the CSV output.
        if self.config.save_plots:
            plot_path = os.path.splitext(csv_path)[0] + '_2d_scatter.png'
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"2D散点图已保存: {plot_path}")
        plt.close()
    def _plot_3d_scatter(self, data, method, explained_variance_ratio, csv_path):
        """Draw (and optionally save) a 3-D scatter of the first three components."""
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')
        if self.config.color_by_label and self.labels is not None:
            # One color per class label
            unique_labels = np.unique(self.labels)
            colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
            for i, label in enumerate(unique_labels):
                mask = self.labels == label
                ax.scatter(data[mask, 0], data[mask, 1], data[mask, 2],
                           c=[colors[i]], label=f'Class {label}',
                           alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
            ax.legend()
        else:
            # Single uniform color
            ax.scatter(data[:, 0], data[:, 1], data[:, 2],
                       alpha=0.7, s=50, c='blue', edgecolors='black', linewidth=0.5)
        # Axis labels, annotated with explained variance when available (PCA).
        xlabel = 'PC1'
        ylabel = 'PC2'
        zlabel = 'PC3'
        if explained_variance_ratio is not None and len(explained_variance_ratio) >= 3:
            xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
            ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'
            zlabel = f'PC3 ({explained_variance_ratio[2]:.1%})'
        ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.set_zlabel(zlabel, fontsize=12, fontweight='bold')
        ax.set_title(f'{method.upper()} - 3D Scatter Plot', fontsize=14, fontweight='bold')
        # Save next to the CSV output.
        if self.config.save_plots:
            plot_path = os.path.splitext(csv_path)[0] + '_3d_scatter.png'
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"3D散点图已保存: {plot_path}")
        plt.close()
def batch_process(self, methods, n_components_list, output_dir):
"""
批量处理多个降维方法
参数:
methods: 降维方法列表
n_components_list: 各方法的维度列表
output_dir: 输出目录
"""
os.makedirs(output_dir, exist_ok=True)
# 显示数据信息和方法推荐
if self.data is not None:
n_samples, n_features = self.data.shape
has_labels = self.labels is not None
print(f"\n=== 数据信息 ===")
print(f"数据类型: {self.data_type}")
print(f"样本数量: {n_samples}")
print(f"特征数量: {n_features}")
print(f"是否有标签: {'' if has_labels else ''}")
# 获取推荐方法
recommendations = self.get_recommended_methods(
self.data_type, n_samples, n_features, has_labels
)
print(f"\n=== 方法推荐 ===")
print(f"高度推荐: {', '.join(recommendations['highly_recommended'])}")
if recommendations['recommended']:
print(f"推荐: {', '.join(recommendations['recommended'])}")
if recommendations['use_with_caution']:
print(f"谨慎使用: {', '.join(recommendations['use_with_caution'])}")
if recommendations['not_recommended']:
print(f"不推荐: {', '.join(recommendations['not_recommended'])}")
# 检查用户选择的方法是否在推荐列表中
for method in methods:
method_lower = method.lower()
method_upper = method_lower.upper()
if method_upper in recommendations['not_recommended']:
print(f"⚠️ 警告: {method_lower}方法不适合当前数据类型,建议使用推荐方法")
elif method_upper in recommendations['use_with_caution']:
print(f"⚠️ 注意: {method_lower}方法对大数据计算复杂度较高,请确保内存充足")
results = {}
for method, n_comp in zip(methods, n_components_list):
# 统一方法名为小写
method_lower = method.lower()
print(f"\n正在处理: {method_lower} (维度: {n_comp})")
try:
# 临时修改配置以适应当前方法
original_method = self.config.method
original_n_components = self.config.n_components
original_method_params = self.config.method_params.copy()
self.config.method = method_lower
self.config.n_components = n_comp
# 清空方法参数,重新设置当前方法的默认参数
self.config.method_params = {}
self.config._set_method_default_params()
reduced_data, band_names = self.apply_dim_reduction()
# 恢复原始配置
self.config.method = original_method
self.config.n_components = original_n_components
self.config.method_params = original_method_params
# 保存结果
if self.data_type == 'image':
output_path = os.path.join(output_dir, f'reduced_{method_lower}.dat')
else: # CSV格式
output_path = os.path.join(output_dir, f'reduced_{method_lower}.csv')
self.save_results(reduced_data, band_names, output_path, method_lower)
# 如果启用可视化且是CSV数据在批量处理时也生成可视化
if self.config.generate_plots and self.data_type == 'csv':
try:
self._generate_visualization_plots(reduced_data, method_lower, output_path)
except Exception as vis_error:
print(f"生成 {method_lower} 可视化时出错: {vis_error}")
results[method_lower] = {
'data': reduced_data,
'bands': band_names,
'n_components': n_comp
}
except Exception as e:
print(f"处理 {method_lower} 时出错: {str(e)}")
return results
def main():
    """Command-line entry point for the dimensionality-reduction tool.

    Returns:
        int: process exit code (0 on success, 1 on failure).
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='高光谱数据降维处理工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
1. 单方法降维 (PCA, 3维):
python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result
2. 批量降维处理多种方法:
python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -M pca ica fa -N 3 3 3 -d results
3. 处理ENVI图像数据:
python dimensionality_reduction.py image.hdr -m pca -n 5 -o reduced_image
4. 生成可视化图表:
python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result --generate-plots --color-by-label
支持的降维方法:
pca: Principal Component Analysis (PCA)
ica: Independent Component Analysis (ICA)
fa: Factor Analysis (FA)
lda: Linear Discriminant Analysis (LDA)
mds: Multidimensional Scaling (MDS)
isomap: Isometric Mapping (Isomap)
lle: Locally Linear Embedding (LLE)
t-sne: t-Distributed Stochastic Neighbor Embedding (t-SNE)
"""
    )
    parser.add_argument('input_path', help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
    # CSV-specific arguments
    parser.add_argument('-l', '--label-col', help='CSV文件的标签列名')
    parser.add_argument('-s', '--spectral-start', help='CSV文件的谱段起始列名')
    parser.add_argument('-e', '--spectral-end', help='CSV文件的谱段结束列名')
    # Single-method arguments
    parser.add_argument('-m', '--method', default='pca',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='降维方法 (默认: pca)')
    parser.add_argument('-n', '--n-components', type=int, default=3,
                        help='降维后的维度数 (默认: 3)')
    parser.add_argument('-o', '--output', help='输出文件路径或前缀')
    # Batch-processing arguments
    parser.add_argument('-M', '--batch-methods', nargs='+',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='批量处理的降维方法列表')
    parser.add_argument('-N', '--batch-components', nargs='+', type=int,
                        help='对应批量方法的维度数列表')
    parser.add_argument('-d', '--output-dir', help='批量处理输出目录')
    # Visualization arguments
    parser.add_argument('--generate-plots', action='store_true',
                        help='生成可视化图表')
    parser.add_argument('--color-by-label', action='store_true',
                        help='按标签着色散点图')
    parser.add_argument('--save-plots', action='store_true',
                        help='保存图表文件')
    # Misc options
    parser.add_argument('--no-standardization', action='store_true',
                        help='不进行数据标准化')
    parser.add_argument('--show-methods', action='store_true',
                        help='显示所有可用方法的详细信息')
    args = parser.parse_args()
    # --show-methods prints the catalogue and exits immediately.
    if args.show_methods:
        HyperspectralDimReduction.print_method_info()
        return 0
    try:
        print("=" * 60)
        print("高光谱数据降维处理工具")
        print("=" * 60)
        print(f"输入文件: {args.input_path}")
        # Batch mode is selected by the presence of -M/--batch-methods.
        is_batch = args.batch_methods is not None
        if is_batch:
            print(f"批量处理方法: {', '.join(args.batch_methods)}")
            print(f"对应维度: {args.batch_components}")
            print(f"输出目录: {args.output_dir}")
        else:
            print(f"降维方法: {args.method}")
            print(f"目标维度: {args.n_components}")
            print(f"输出路径: {args.output}")
        if args.label_col:
            print(f"标签列: {args.label_col}")
        if args.spectral_start:
            print(f"谱段范围: {args.spectral_start} - {args.spectral_end}")
        print()
        # BUG FIX: method/n_components are passed to the constructor so that
        # __post_init__ fills method_params with defaults for the *selected*
        # method. The old code mutated config.method after construction,
        # leaving PCA's default params (whiten/svd_solver) behind and
        # crashing every non-PCA reducer.
        config = DimensionalityReductionConfig(
            input_path=args.input_path,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            method=args.method,
            n_components=args.n_components,
            output_path=args.output,
            output_dir=args.output_dir,
            use_standardization=not args.no_standardization,
            generate_plots=args.generate_plots,
            color_by_label=args.color_by_label,
            save_plots=args.save_plots
        )
        processor = HyperspectralDimReduction(config)
        print("正在加载数据...")
        processor.load_data()
        if is_batch:
            # Batch mode
            if not args.batch_components:
                # Default every batch method to 3 components.
                args.batch_components = [3] * len(args.batch_methods)
            elif len(args.batch_components) != len(args.batch_methods):
                raise ValueError(f"批量方法数量 ({len(args.batch_methods)}) 与维度数量 ({len(args.batch_components)}) 不匹配")
            print("\n开始批量降维处理...")
            results = processor.batch_process(args.batch_methods, args.batch_components, args.output_dir or 'reduced_data_batch')
            # Result summary
            print(f"\n=== 批量处理完成 ===")
            print(f"共处理了 {len(results)} 种降维方法:")
            for method_name, result in results.items():
                print(f"- {method_name}: {result['n_components']}")
            # Persist a summary CSV of what was produced.
            summary_data = [
                {
                    'method': method_name,
                    'n_components': result['n_components'],
                    'data_shape': str(result['data'].shape),
                    'output_file': f'reduced_{method_name}'
                }
                for method_name, result in results.items()
            ]
            summary_df = pd.DataFrame(summary_data)
            summary_path = os.path.join(args.output_dir or 'reduced_data_batch', 'batch_processing_summary.csv')
            summary_df.to_csv(summary_path, index=False)
            print(f"批量处理汇总已保存到: {summary_path}")
        else:
            # Single-method mode (config already carries method/n_components)
            print("\n开始降维处理...")
            reduced_data, band_names = processor.apply_dim_reduction()
            print("✓ 降维处理完成")
            print(f"原始数据维度: {processor.data.shape if processor.data is not None else 'Unknown'}")
            print(f"降维后数据维度: {reduced_data.shape}")
            if args.output:
                processor.save_results(reduced_data, band_names, args.output, args.method)
        print("\n" + "=" * 60)
        print("处理完成!")
        print("=" * 60)
    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Raise SystemExit directly: the `exit()` helper is injected by the
    # `site` module and is not guaranteed to exist in every environment.
    raise SystemExit(main())