import os
import warnings
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple, List

import numpy as np
import pandas as pd
import spectral
from scipy.io import savemat
from sklearn.decomposition import PCA, FastICA, FactorAnalysis
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')

# Visualization imports
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the 3D projection)
import matplotlib.cm as cm
import matplotlib.colors as mcolors


@dataclass
class DimensionalityReductionConfig:
    """Configuration for hyperspectral dimensionality reduction."""

    # Input file settings
    input_path: Optional[str] = None
    label_col: Optional[str] = None
    spectral_start: Optional[str] = None
    spectral_end: Optional[str] = None
    csv_has_header: bool = True

    # Output settings
    output_path: Optional[str] = None
    output_dir: Optional[str] = None

    # Reduction method settings
    method: str = 'PCA'
    n_components: int = 3
    method_params: Dict[str, Any] = field(default_factory=dict)

    # Batch processing settings
    batch_methods: List[str] = field(default_factory=lambda: ['PCA', 'ICA', 'FA'])
    batch_components: List[int] = field(default_factory=lambda: [3, 3, 3])

    # Preprocessing settings
    use_standardization: bool = True

    # Visualization settings
    generate_plots: bool = True    # generate scatter plots of the reduced data
    color_by_label: bool = True    # color scatter points by class label
    save_plots: bool = True        # save plot files to disk

    def __post_init__(self):
        """Validate parameters and fill in method-specific defaults."""
        # The input path is mandatory
        if not self.input_path:
            raise ValueError("An input file path (input_path) must be specified")

        # The input file must exist
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"Input file does not exist: {self.input_path}")

        # Output paths are required for batch processing only;
        # for single-method runs they are optional.

        # Normalize method names to lower case
        self.method = self.method.lower()
        self.batch_methods = [m.lower() for m in self.batch_methods]

        # Validate the reduction method
        supported_methods = self._get_supported_methods()
        if self.method not in supported_methods:
            raise ValueError(
                f"Unsupported reduction method: {self.method}. "
                f"Supported methods: {list(supported_methods.keys())}"
            )

        # Validate batch-processing parameters
        if len(self.batch_methods) != len(self.batch_components):
            raise ValueError("batch_methods and batch_components must have the same length")
        for method in self.batch_methods:
            if method not in supported_methods:
                raise ValueError(f"Batch processing contains an unsupported method: {method}")

        # Fill in method-specific default parameters
        self._set_method_default_params()

    def _get_supported_methods(self) -> Dict[str, str]:
        """Return the supported dimensionality reduction methods."""
        return {
            'pca': 'Principal Component Analysis (PCA)',
            'ica': 'Independent Component Analysis (ICA)',
            'fa': 'Factor Analysis (FA)',
            'lda': 'Linear Discriminant Analysis (LDA)',
            'mds': 'Multidimensional Scaling (MDS)',
            'isomap': 'Isometric Mapping (Isomap)',
            'lle': 'Locally Linear Embedding (LLE)',
            't-sne': 't-Distributed Stochastic Neighbor Embedding (t-SNE)',
        }

    def _set_method_default_params(self):
        """Set default hyperparameters for the selected method."""
        method = self.method.lower()
        if method == 'pca':
            self.method_params.setdefault('whiten', False)
            self.method_params.setdefault('svd_solver', 'auto')
        elif method == 'ica':
            self.method_params.setdefault('algorithm', 'parallel')
            self.method_params.setdefault('whiten', 'unit-variance')
            self.method_params.setdefault('fun', 'logcosh')
        elif method == 'fa':
            self.method_params.setdefault('rotation', None)
            self.method_params.setdefault('svd_method', 'randomized')
        elif method == 'lda':
            self.method_params.setdefault('solver', 'svd')
            self.method_params.setdefault('shrinkage', None)
        elif method == 'mds':
            self.method_params.setdefault('metric', True)
            self.method_params.setdefault('dissimilarity', 'euclidean')
            self.method_params.setdefault('random_state', 42)
        elif method == 'isomap':
            self.method_params.setdefault('n_neighbors', 5)
            self.method_params.setdefault('eigen_solver', 'auto')
        elif method == 'lle':
            self.method_params.setdefault('n_neighbors', 5)
            self.method_params.setdefault('reg', 0.001)
            self.method_params.setdefault('eigen_solver', 'auto')
        elif method == 't-sne':
            self.method_params.setdefault('perplexity', 30.0)
            self.method_params.setdefault('early_exaggeration', 12.0)
            self.method_params.setdefault('learning_rate', 200.0)
            self.method_params.setdefault('random_state', 42)
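
# --- Illustrative usage sketch (not part of the processing pipeline) ---
# A minimal example of building a configuration for a CSV spectral table.
# The file name and column names below are hypothetical placeholders; the
# input file must exist, otherwise __post_init__ raises FileNotFoundError.
def _example_build_config() -> DimensionalityReductionConfig:
    """Sketch: single-method PCA configuration for a CSV input."""
    return DimensionalityReductionConfig(
        input_path='spectra.csv',                    # hypothetical CSV file
        label_col='Label',                           # hypothetical label column name
        spectral_start='400', spectral_end='1000',   # wavelengths given as strings
        method='pca',
        n_components=3,
        output_path='reduced_pca.csv',
    )
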
class HyperspectralDimReduction:
    """Hyperspectral data dimensionality reduction pipeline."""

    def __init__(self, config: DimensionalityReductionConfig):
        self.config = config

        # Data containers
        self.data = None
        self.labels = None
        self.header = None
        self.original_shape = None
        self.data_type = None  # 'image' or 'csv'

        # Preprocessing
        self.scaler = StandardScaler() if config.use_standardization else None

        # Model reference (used to retrieve explained-variance ratios)
        self._last_model = None
        self._explained_variance_ratio = None

        # Cached results for visualization
        self._last_reduced_data = None
        self._last_method = None

    def load_data(self) -> None:
        """Load the input file specified in the configuration."""
        print(f"Loading data file: {self.config.input_path}")

        file_ext = os.path.splitext(self.config.input_path)[1].lower()
        if file_ext in ['.dat', '.img', '.hdr']:
            self._load_image_data(self.config.input_path)
            self.data_type = 'image'
        elif file_ext == '.csv':
            self._load_csv_data(self.config.input_path, self.config.label_col,
                                self.config.spectral_start, self.config.spectral_end,
                                self.config.csv_has_header)
            self.data_type = 'csv'
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")

        print("Data loading finished")

    def _load_image_data(self, file_path):
        """Load a hyperspectral image in ENVI format."""
        base_path = os.path.splitext(file_path)[0]
        hdr_file = base_path + '.hdr'

        # Read the ENVI image and its header
        img = spectral.open_image(hdr_file)
        self.data = img.load()
        self.original_shape = self.data.shape
        self.header = spectral.envi.read_envi_header(hdr_file)

        # Flatten to a 2-D array (pixels x bands)
        self.data = self.data.reshape(-1, self.data.shape[2])

    def _load_csv_data(self, file_path, label_col, spectral_start, spectral_end, has_header):
        """Load spectra from a CSV file."""
        if has_header:
            df = pd.read_csv(file_path)
        else:
            df = pd.read_csv(file_path, header=None)

        def find_column_by_wavelength(df, wavelength_str):
            """Find the column whose wavelength is closest to the requested value."""
            try:
                target_wavelength = float(wavelength_str)
                # Try to extract wavelength values from the column names
                wavelength_cols = []
                for col in df.columns:
                    if isinstance(col, str):
                        try:
                            # Names such as "wavelength_400.5"
                            if 'wavelength' in col.lower():
                                wl = float(col.split('_')[-1])
                                wavelength_cols.append((col, wl))
                            # Names that are plain numbers
                            elif col.replace('.', '').isdigit():
                                wl = float(col)
                                wavelength_cols.append((col, wl))
                        except ValueError:
                            continue
                if wavelength_cols:
                    # Pick the column with the closest wavelength
                    closest_col = min(wavelength_cols,
                                      key=lambda x: abs(x[1] - target_wavelength))[0]
                    return closest_col
                raise ValueError(f"No wavelength information found in column names: {list(df.columns)}")
            except ValueError:
                # Not numeric: treat the argument as a literal column name
                return wavelength_str

        if spectral_start is not None:
            if isinstance(spectral_start, str):
                try:
                    # If the value parses as a number, look up the matching wavelength column
                    float(spectral_start)
                    spectral_start_col = find_column_by_wavelength(df, spectral_start)
                except ValueError:
                    # Otherwise use it directly as a column name
                    spectral_start_col = spectral_start
            else:
                spectral_start_col = spectral_start

            if spectral_end is not None:
                if isinstance(spectral_end, str):
                    try:
                        float(spectral_end)
                        spectral_end_col = find_column_by_wavelength(df, spectral_end)
                    except ValueError:
                        spectral_end_col = spectral_end
                else:
                    spectral_end_col = spectral_end

                # Select the spectral columns by column-name range
                spectral_cols = df.loc[:, spectral_start_col:spectral_end_col].columns
                spectral_data = df.loc[:, spectral_start_col:spectral_end_col].values
            else:
                # Only a start column was given: take everything from it to the end
                start_idx = df.columns.get_loc(spectral_start_col)
                spectral_data = df.iloc[:, start_idx:].values
                spectral_cols = df.iloc[:, start_idx:].columns
        else:
            # No start column: use every numeric column except the label column
            if label_col is not None:
                if isinstance(label_col, str) and label_col in df.columns:
                    spectral_data = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).columns
                elif isinstance(label_col, int):
                    spectral_data = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).columns
                else:
                    spectral_data = df.select_dtypes(include=[np.number]).values
                    spectral_cols = df.select_dtypes(include=[np.number]).columns
            else:
                spectral_data = df.select_dtypes(include=[np.number]).values
                spectral_cols = df.select_dtypes(include=[np.number]).columns

        # Extract labels
        if label_col is not None:
            if isinstance(label_col, str):
                self.labels = df[label_col].values
            else:
                self.labels = df.iloc[:, label_col].values
            # Drop the label column from the frame
            if label_col in df.columns:
                df = df.drop(columns=[label_col])
            elif isinstance(label_col, int):
                df = df.drop(columns=df.columns[label_col])

        self.data = spectral_data
        self.original_shape = self.data.shape

        # Build a minimal pseudo ENVI header for CSV input
        self.header = {
            'bands': spectral_data.shape[1],
            'lines': 1,
            'samples': spectral_data.shape[0],
            'interleave': 'bsq',
            'data type': 4,  # float32
            'byte order': 0,
            'wavelength': [f'Band_{i}' for i in range(spectral_data.shape[1])]
        }
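
    # Illustrative note (an assumption about the expected CSV layout): the loader
    # above targets tables whose spectral columns are named by wavelength, e.g. a
    # hypothetical file such as
    #
    #     Label,391.3225,397.45,...,1003.175
    #     1,0.123,0.130,...,0.087
    #     2,0.098,0.101,...,0.075
    #
    # so that spectral_start='391.3225' and spectral_end='1003.175' resolve to the
    # closest matching column names via find_column_by_wavelength().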
    def apply_dim_reduction(self) -> Tuple[np.ndarray, List[str]]:
        """
        Apply the configured dimensionality reduction method.

        Returns:
            Tuple[np.ndarray, List[str]]: (reduced data, list of band names)

        Applicability notes:

        Hyperspectral image data (data_type='image'):
            - Recommended: PCA, ICA, FA (efficient on high-dimensional data)
            - Usable with care: LDA (needs labels), MDS
            - Not recommended: Isomap, LLE, t-SNE (high complexity and memory use)

        CSV spectral data (data_type='csv'):
            - All methods are available, but choose according to data size:
              small sets (<1000 samples): any method;
              large sets (>10000 samples): avoid t-SNE, MDS, Isomap and LLE.

        Method constraints:
            - LDA requires labels (self.labels must not be None)
            - t-SNE/MDS/Isomap/LLE scale poorly (O(n^2) or worse)
        """
        print(f"Applying reduction method: {self.config.method} "
              f"(components: {self.config.n_components})")

        # Check method/data compatibility
        self._check_method_compatibility(self.config.method)

        # Check data size against the method's complexity
        n_samples, n_features = self.data.shape
        self._check_data_size_compatibility(self.config.method, n_samples, n_features)

        # Standardize the data if requested
        if self.config.use_standardization and self.scaler is not None:
            data_scaled = self.scaler.fit_transform(self.data)
            print("Data standardization applied")
        else:
            data_scaled = self.data

        # Copy the configured method parameters
        method_params = {**self.config.method_params}
        method = self.config.method.upper()

        # Dispatch to the selected reduction method
        if method == 'PCA':
            reducer = PCA(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'PCA_{i+1}' for i in range(self.config.n_components)]
            # Keep the model and its explained-variance ratios
            self._last_model = reducer
            self._explained_variance_ratio = reducer.explained_variance_ratio_
        elif method == 'LDA':
            if self.labels is None:
                raise ValueError("LDA requires labels; make sure a label column was specified")
            reducer = LinearDiscriminantAnalysis(
                n_components=min(self.config.n_components, len(np.unique(self.labels)) - 1),
                **method_params)
            reduced_data = reducer.fit_transform(data_scaled, self.labels)
            band_names = [f'LDA_{i+1}' for i in range(reduced_data.shape[1])]
        elif method == 'ICA':
            reducer = FastICA(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'ICA_{i+1}' for i in range(self.config.n_components)]
        elif method == 'FA':
            reducer = FactorAnalysis(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'FA_{i+1}' for i in range(self.config.n_components)]
        elif method == 'MDS':
            reducer = MDS(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'MDS_{i+1}' for i in range(self.config.n_components)]
        elif method == 'ISOMAP':
            reducer = Isomap(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'Isomap_{i+1}' for i in range(self.config.n_components)]
        elif method == 'LLE':
            reducer = LocallyLinearEmbedding(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'LLE_{i+1}' for i in range(self.config.n_components)]
        elif method == 'T-SNE':
            reducer = TSNE(n_components=self.config.n_components, **method_params)
            reduced_data = reducer.fit_transform(data_scaled)
            band_names = [f'tSNE_{i+1}' for i in range(self.config.n_components)]
        else:
            raise ValueError(f"Unsupported reduction method: {self.config.method}")

        # Cache results for later visualization
        self._last_reduced_data = reduced_data
        self._last_method = self.config.method

        return reduced_data, band_names
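
    # Minimal standalone sketch of the PCA branch above (illustrative only; uses a
    # random array in place of loaded spectra):
    #
    #     X = np.random.rand(100, 200)                 # 100 samples x 200 bands
    #     X_std = StandardScaler().fit_transform(X)    # same preprocessing step
    #     pca = PCA(n_components=3, whiten=False, svd_solver='auto')
    #     X_red = pca.fit_transform(X_std)             # shape (100, 3)
    #     print(pca.explained_variance_ratio_)         # per-component variance share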
    def generate_visualization(self, reduced_data=None, method=None, output_path=None):
        """
        Generate visualizations independently of save_results().

        Args:
            reduced_data: reduced data; defaults to the last processed result
            method: method name; defaults to the configured method
            output_path: output path; defaults to a generated file name
        """
        if reduced_data is None and hasattr(self, '_last_reduced_data'):
            reduced_data = self._last_reduced_data
        if method is None:
            method = self.config.method

        # Cache the result for later use
        self._last_reduced_data = reduced_data
        self._last_method = method

        if output_path is None:
            output_path = f'visualization_{method.lower()}.csv'

        # Only CSV data is plotted
        if self.config.generate_plots and self.data_type == 'csv':
            self._generate_visualization_plots(reduced_data, method, output_path)

    def _check_method_compatibility(self, method):
        """
        Check whether the method suits the loaded data type.

        Hyperspectral images:
            - PCA, ICA, FA: highly recommended, efficient on high-dimensional data
            - LDA: requires labels, suited to supervised reduction
            - MDS: usable but expensive
            - Isomap, LLE, t-SNE: not recommended (high complexity and memory use)

        CSV data:
            - All methods are available, but avoid the expensive ones on large sets.
        """
        method = method.upper()

        # LDA needs labels
        if method == 'LDA' and self.labels is None:
            raise ValueError("LDA requires label information; make sure the data has a label column")

        # Restrictions for hyperspectral images
        if self.data_type == 'image':
            high_complexity_methods = ['ISOMAP', 'LLE', 'T-SNE']
            if method in high_complexity_methods:
                warnings.warn(
                    f"Warning: {method} is computationally expensive on hyperspectral images "
                    "and may run out of memory or take very long. Consider PCA, ICA or FA.",
                    UserWarning
                )
        # Hints for CSV data
        elif self.data_type == 'csv':
            if method in ['T-SNE', 'MDS', 'ISOMAP', 'LLE']:
                warnings.warn(
                    f"Note: {method} has high computational complexity; "
                    "for large data sets consider PCA, ICA or FA instead.",
                    UserWarning
                )

    def _check_data_size_compatibility(self, method, n_samples, n_features):
        """
        Warn or fail when the data size does not suit the chosen method.

        Args:
            method: reduction method
            n_samples: number of samples
            n_features: number of features
        """
        method = method.upper()
        high_complexity_methods = ['T-SNE', 'MDS', 'ISOMAP', 'LLE']

        # Expensive methods on large data sets
        # (check the hard limit first so it is actually reachable)
        if method in high_complexity_methods:
            if n_samples > 10000:
                raise ValueError(
                    f"The data set is too large ({n_samples} samples) for {method}. "
                    "Please use PCA, ICA or FA instead."
                )
            elif n_samples > 5000:
                warnings.warn(
                    f"Warning: the data set has {n_samples} samples and {method} scales as O(n^2); "
                    "this may take very long. Consider PCA, ICA or FA.",
                    UserWarning
                )

        # High-dimensional data
        if n_features > 1000 and method not in ['PCA', 'ICA', 'FA']:
            warnings.warn(
                f"Note: the data is very high-dimensional ({n_features} features); "
                "PCA, ICA or FA are usually preferable.",
                UserWarning
            )

    def get_recommended_methods(self, data_type=None, n_samples=None,
                                n_features=None, has_labels=False, max_components=None):
        """
        Recommend reduction methods based on the data characteristics.

        Args:
            data_type: 'image' or 'csv'
            n_samples: number of samples
            n_features: number of features
            has_labels: whether labels are available
            max_components: maximum number of components

        Returns:
            dict: recommended methods grouped by suitability
        """
        if data_type is None:
            data_type = self.data_type
        # Use explicit None checks: truth-testing a NumPy array is ambiguous
        if n_samples is None and self.data is not None:
            n_samples = self.data.shape[0]
        if n_features is None and self.data is not None:
            n_features = self.data.shape[1]
        if not has_labels and self.labels is not None:
            has_labels = True

        recommendations = {
            'highly_recommended': [],
            'recommended': [],
            'use_with_caution': [],
            'not_recommended': []
        }

        # Hyperspectral images
        if data_type == 'image':
            recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
            if has_labels:
                recommendations['recommended'].append('LDA')
            else:
                recommendations['recommended'].append('MDS')
            recommendations['use_with_caution'] = ['Isomap', 'LLE']
            recommendations['not_recommended'] = ['t-SNE']

        # CSV data
        elif data_type == 'csv':
            if n_samples and n_samples < 1000:
                # Small data set
                recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
                recommendations['recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']
                if has_labels:
                    recommendations['recommended'].append('LDA')
            elif n_samples and n_samples < 5000:
                # Medium data set
                recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
                recommendations['recommended'] = ['MDS']
                if has_labels:
                    recommendations['recommended'].append('LDA')
                recommendations['use_with_caution'] = ['Isomap', 'LLE', 't-SNE']
            else:
                # Large data set
                recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
                if has_labels:
                    recommendations['recommended'].append('LDA')
                recommendations['not_recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']

        # Extra advice for high-dimensional data: PCA is usually the best first choice
        if n_features and n_features > 500:
            if 'PCA' not in recommendations['highly_recommended']:
                recommendations['highly_recommended'].insert(0, 'PCA')

        return recommendations
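
    # Illustrative example of the structure returned by get_recommended_methods()
    # (values shown for a hypothetical labelled CSV set with ~800 samples):
    #
    #     {'highly_recommended': ['PCA', 'ICA', 'FA'],
    #      'recommended': ['MDS', 'Isomap', 'LLE', 't-SNE', 'LDA'],
    #      'use_with_caution': [],
    #      'not_recommended': []}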
    @staticmethod
    def print_method_info():
        """Print detailed information about every available reduction method."""
        method_info = {
            'PCA': {
                'name': 'Principal Component Analysis',
                'description': 'Orthogonal projection onto the directions of maximum variance',
                'advantages': ['computationally efficient', 'suits high-dimensional data',
                               'parameter-free', 'preserves global structure'],
                'limitations': ['linear method', 'sensitive to outliers'],
                'best_for': ['hyperspectral preprocessing', 'noise reduction', 'visualization'],
                'data_types': ['image', 'csv'],
                'complexity': 'low'
            },
            'ICA': {
                'name': 'Independent Component Analysis',
                'description': 'Models the data as a mixture of statistically independent sources '
                               'and separates them',
                'advantages': ['reveals hidden factors', 'well suited to signal separation',
                               'uses nonlinear contrast functions'],
                'limitations': ['higher computational cost', 'results can be unstable'],
                'best_for': ['signal separation', 'removing mixed noise', 'finding latent components'],
                'data_types': ['image', 'csv'],
                'complexity': 'medium'
            },
            'FA': {
                'name': 'Factor Analysis',
                'description': 'Models observed variables as latent factors plus noise',
                'advantages': ['highly interpretable', 'well established for observational data',
                               'accounts for measurement error'],
                'limitations': ['strong assumptions', 'sensitive to the data distribution'],
                'best_for': ['exploring latent structure', 'variable grouping', 'interpretable reduction'],
                'data_types': ['image', 'csv'],
                'complexity': 'medium'
            },
            'LDA': {
                'name': 'Linear Discriminant Analysis',
                'description': 'Maximizes between-class scatter while minimizing within-class scatter',
                'advantages': ['supervised', 'preserves class separability', 'simple to compute'],
                'limitations': ['requires labels', 'linearity assumption', 'sensitive to outliers'],
                'best_for': ['classification tasks', 'labelled data', 'preserving class structure'],
                # Mainly for CSV data, since images rarely have per-pixel labels
                'data_types': ['csv'],
                'complexity': 'low'
            },
            'MDS': {
                'name': 'Multidimensional Scaling',
                'description': 'Maps samples to a low-dimensional space while preserving pairwise distances',
                'advantages': ['preserves distance relations', 'no distributional assumptions', 'intuitive'],
                'limitations': ['O(n^2) complexity', 'unsuitable for large data'],
                'best_for': ['preserving relative distances', 'visualization', 'similarity analysis'],
                'data_types': ['csv'],
                'complexity': 'high'
            },
            'Isomap': {
                'name': 'Isometric Mapping',
                'description': 'Preserves geodesic distances computed along the data manifold',
                'advantages': ['handles nonlinear manifolds', 'preserves local and global structure'],
                'limitations': ['noise-sensitive', 'computationally expensive', 'parameter choice matters'],
                'best_for': ['nonlinear manifold data', 'preserving geometric structure'],
                'data_types': ['csv'],
                'complexity': 'very high'
            },
            'LLE': {
                'name': 'Locally Linear Embedding',
                'description': 'Preserves local linear neighbourhood relations during reduction',
                'advantages': ['handles nonlinear data', 'preserves local structure',
                               'closed-form eigen-solution (no iterative optimization)'],
                'limitations': ['unsuitable for large data', 'noise-sensitive',
                                'sensitive to the neighbourhood size'],
                'best_for': ['nonlinear reduction', 'preserving local neighbourhoods'],
                'data_types': ['csv'],
                'complexity': 'very high'
            },
            't-SNE': {
                'name': 't-Distributed Stochastic Neighbor Embedding',
                'description': 'Embeds high-dimensional data in a low-dimensional space while '
                               'preserving local similarities',
                'advantages': ['excellent visualizations', 'preserves local structure', 'nonlinear'],
                'limitations': ['extremely expensive', 'mainly for visualization', 'results are not stable'],
                'best_for': ['visualization', 'cluster analysis', 'exploratory analysis'],
                'data_types': ['csv'],
                'complexity': 'extremely high'
            }
        }

        print("\n=== Dimensionality reduction methods ===\n")
        for method, info in method_info.items():
            print(f"{method} - {info['name']}")
            print(f"  Description: {info['description']}")
            print(f"  Advantages: {', '.join(info['advantages'])}")
            print(f"  Limitations: {', '.join(info['limitations'])}")
            print(f"  Best for: {', '.join(info['best_for'])}")
            print(f"  Data types: {', '.join(info['data_types'])}")
            print(f"  Complexity: {info['complexity']}")
            print()
    def save_results(self, reduced_data, band_names, output_path, method):
        """
        Save the reduction results.

        Args:
            reduced_data: reduced data
            band_names: list of band names
            output_path: output path
            method: reduction method used
        """
        output_base = os.path.splitext(output_path)[0]

        if self.data_type == 'image':
            # Restore the image shape
            if self.original_shape:
                img_shape = (self.original_shape[0], self.original_shape[1], reduced_data.shape[1])
                reduced_img = reduced_data.reshape(img_shape)
            else:
                reduced_img = reduced_data

            # Write the binary .dat file and the ENVI .hdr header
            self._save_dat_file(reduced_img, output_path)
            hdr_path = output_base + '.hdr'
            self._save_hdr_file(hdr_path, reduced_img.shape, band_names, method)

            print(f"Image saved: {output_path}")
            print(f"Header saved: {hdr_path}")
        else:
            # CSV input: save the result as CSV only (no DAT/HDR files)
            df_output = pd.DataFrame(reduced_data, columns=band_names)

            # If labels exist, put them in the first column
            if self.labels is not None:
                df_output.insert(0, 'Label', self.labels)

            # Reuse the .csv extension if present, otherwise add it
            if output_path.lower().endswith('.csv'):
                csv_path = output_path
            else:
                csv_path = output_base + '.csv'

            df_output.to_csv(csv_path, index=False)
            print(f"CSV file saved: {csv_path}")
            print("Note: CSV input produces CSV output only; no DAT or HDR files are written")

            # Generate visualizations (CSV data only)
            if self.config.generate_plots:
                self._generate_visualization_plots(reduced_data, method, csv_path)

    def _save_dat_file(self, data, file_path):
        """Save a binary .dat file."""
        with open(file_path, 'wb') as f:
            data.astype(np.float32).tofile(f)

    def _save_hdr_file(self, hdr_path, data_shape, band_names, method, is_csv=False):
        """Save an ENVI header file."""
        if is_csv:
            lines, samples = 1, data_shape[0]
        else:
            lines, samples = data_shape[0], data_shape[1]
        bands = data_shape[2] if len(data_shape) == 3 else data_shape[1]

        # Preserve key metadata from the original header
        original_info = {}
        if self.header and self.data_type == 'image':
            preserve_keys = [
                'sensor type', 'wavelength units', 'sample binning', 'spectral binning',
                'line binning', 'shutter', 'gain', 'framerate', 'imager serial number',
                'rotation', 'label', 'map info', 'coordinate system string', 'classes'
            ]
            for key in preserve_keys:
                if key in self.header:
                    original_info[key] = self.header[key]

            # For reduced results, always use the new band names rather than the
            # original wavelength list
            original_info['band names'] = band_names

        header_content = f"""ENVI
description = {{{method} Result [{pd.Timestamp.now().strftime('%a %b %d %H:%M:%S %Y')}]}}
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
"""

        # Append the preserved metadata
        for key, value in original_info.items():
            if key == 'band names':
                header_content += f"band names = {{{', '.join(value)}}}\n"
            elif key == 'wavelength':
                # Format the original wavelengths
                wavelength_str = (', '.join([f'{w:.6f}' for w in value])
                                  if isinstance(value, (list, np.ndarray)) else str(value))
                header_content += f"original wavelength = {{{wavelength_str}}}\n"
            elif key == 'rotation':
                # Special formatting for rotation values
                if isinstance(value, list) and len(value) == 4:
                    rotation_str = (f"[({value[0][0]}, {value[0][1]}), ({value[1][0]}, {value[1][1]}), "
                                    f"({value[2][0]}, {value[2][1]}), ({value[3][0]}, {value[3][1]})]")
                    header_content += f"rotation = {rotation_str}\n"
                else:
                    header_content += f"rotation = {value}\n"
            elif isinstance(value, str) and '\n' in value:
                # Multi-line strings
                header_content += f"{key} = {{\n{value}\n}}\n"
            else:
                header_content += f"{key} = {value}\n"

        # Fall back to default metadata if nothing was preserved
        if not original_info:
            header_content += f"""wavelength units = Unknown
band names = {{{', '.join(band_names)}}}
classes = 0
map info = {{Geographic Lat/Lon, 1.0000, 1.0000, 0.0, 0.0, 0.0, 0.0}}
"""

        # Append the processing history
        if self.header and 'history' in self.header:
            history_info = f"{self.header['history']} -> {method}[]"
        else:
            history_info = f"{method}[]"
        header_content += f"history = {history_info}\n"

        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)
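
    # Illustrative sketch of the header text written above for a hypothetical
    # 3-band PCA result of a 100 x 200 pixel image (the metadata lines vary with
    # whatever was preserved from the original header; the timestamp is made up):
    #
    #     ENVI
    #     description = {pca Result [Mon Jan 01 12:00:00 2024]}
    #     samples = 200
    #     lines = 100
    #     bands = 3
    #     header offset = 0
    #     file type = ENVI Standard
    #     data type = 4
    #     interleave = bip
    #     byte order = 0
    #     band names = {PCA_1, PCA_2, PCA_3}
    #     history = pca[]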
    def _generate_visualization_plots(self, reduced_data, method, csv_path):
        """Generate scatter plots of the reduced data."""
        try:
            # Explained-variance ratios are only available for some methods
            explained_variance_ratio = self._get_explained_variance_ratio(method)

            # 2-D scatter plot of the first two components
            if reduced_data.shape[1] >= 2:
                self._plot_2d_scatter(reduced_data, method, explained_variance_ratio, csv_path)

            # 3-D scatter plot of the first three components
            if reduced_data.shape[1] >= 3:
                self._plot_3d_scatter(reduced_data, method, explained_variance_ratio, csv_path)
        except Exception as e:
            print(f"Error while generating visualization plots: {e}")

    def _get_explained_variance_ratio(self, method):
        """Return the explained-variance ratios for the given method, if available."""
        method_lower = method.lower()

        # PCA exposes explained-variance ratios
        if method_lower == 'pca' and self._explained_variance_ratio is not None:
            return self._explained_variance_ratio

        # Other methods do not provide variance contributions
        return None

    def _plot_2d_scatter(self, data, method, explained_variance_ratio, csv_path):
        """Create a 2-D scatter plot."""
        plt.figure(figsize=(10, 8))

        if self.config.color_by_label and self.labels is not None:
            # Color the points by class label
            unique_labels = np.unique(self.labels)
            colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
            for i, label in enumerate(unique_labels):
                mask = self.labels == label
                plt.scatter(data[mask, 0], data[mask, 1],
                            c=[colors[i]], label=f'Class {label}',
                            alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
            plt.legend()
        else:
            # Single color
            plt.scatter(data[:, 0], data[:, 1], alpha=0.7, s=50,
                        c='blue', edgecolors='black', linewidth=0.5)

        # Axis labels (annotated with explained variance where available)
        xlabel = 'PC1'
        ylabel = 'PC2'
        if explained_variance_ratio is not None and len(explained_variance_ratio) >= 2:
            xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
            ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'

        plt.xlabel(xlabel, fontsize=12, fontweight='bold')
        plt.ylabel(ylabel, fontsize=12, fontweight='bold')
        plt.title(f'{method.upper()} - 2D Scatter Plot', fontsize=14, fontweight='bold')
        plt.grid(True, alpha=0.3)

        # Save the figure
        if self.config.save_plots:
            plot_path = os.path.splitext(csv_path)[0] + '_2d_scatter.png'
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"2D scatter plot saved: {plot_path}")

        plt.close()

    def _plot_3d_scatter(self, data, method, explained_variance_ratio, csv_path):
        """Create a 3-D scatter plot."""
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        if self.config.color_by_label and self.labels is not None:
            # Color the points by class label
            unique_labels = np.unique(self.labels)
            colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
            for i, label in enumerate(unique_labels):
                mask = self.labels == label
                ax.scatter(data[mask, 0], data[mask, 1], data[mask, 2],
                           c=[colors[i]], label=f'Class {label}',
                           alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
            ax.legend()
        else:
            # Single color
            ax.scatter(data[:, 0], data[:, 1], data[:, 2], alpha=0.7, s=50,
                       c='blue', edgecolors='black', linewidth=0.5)

        # Axis labels (annotated with explained variance where available)
        xlabel = 'PC1'
        ylabel = 'PC2'
        zlabel = 'PC3'
        if explained_variance_ratio is not None and len(explained_variance_ratio) >= 3:
            xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
            ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'
            zlabel = f'PC3 ({explained_variance_ratio[2]:.1%})'

        ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.set_zlabel(zlabel, fontsize=12, fontweight='bold')
        ax.set_title(f'{method.upper()} - 3D Scatter Plot', fontsize=14, fontweight='bold')

        # Save the figure
        if self.config.save_plots:
            plot_path = os.path.splitext(csv_path)[0] + '_3d_scatter.png'
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f"3D scatter plot saved: {plot_path}")

        plt.close()
    def batch_process(self, methods, n_components_list, output_dir):
        """
        Run several dimensionality reduction methods in one batch.

        Args:
            methods: list of reduction methods
            n_components_list: number of components for each method
            output_dir: output directory
        """
        os.makedirs(output_dir, exist_ok=True)

        # Print data information and method recommendations
        if self.data is not None:
            n_samples, n_features = self.data.shape
            has_labels = self.labels is not None

            print("\n=== Data information ===")
            print(f"Data type: {self.data_type}")
            print(f"Samples: {n_samples}")
            print(f"Features: {n_features}")
            print(f"Labels available: {'yes' if has_labels else 'no'}")

            # Get method recommendations
            recommendations = self.get_recommended_methods(
                self.data_type, n_samples, n_features, has_labels
            )

            print("\n=== Method recommendations ===")
            print(f"Highly recommended: {', '.join(recommendations['highly_recommended'])}")
            if recommendations['recommended']:
                print(f"Recommended: {', '.join(recommendations['recommended'])}")
            if recommendations['use_with_caution']:
                print(f"Use with caution: {', '.join(recommendations['use_with_caution'])}")
            if recommendations['not_recommended']:
                print(f"Not recommended: {', '.join(recommendations['not_recommended'])}")

            # Check the requested methods against the recommendations
            # (compare case-insensitively; the recommendation lists use mixed case)
            not_recommended = [m.upper() for m in recommendations['not_recommended']]
            use_with_caution = [m.upper() for m in recommendations['use_with_caution']]
            for method in methods:
                method_lower = method.lower()
                method_upper = method_lower.upper()
                if method_upper in not_recommended:
                    print(f"⚠️ Warning: {method_lower} is not suited to this data type; "
                          "consider one of the recommended methods")
                elif method_upper in use_with_caution:
                    print(f"⚠️ Note: {method_lower} is expensive on large data; "
                          "make sure enough memory is available")

        results = {}
        for method, n_comp in zip(methods, n_components_list):
            # Normalize the method name to lower case
            method_lower = method.lower()
            print(f"\nProcessing: {method_lower} (components: {n_comp})")
            try:
                # Temporarily adjust the configuration for the current method
                original_method = self.config.method
                original_n_components = self.config.n_components
                original_method_params = self.config.method_params.copy()

                self.config.method = method_lower
                self.config.n_components = n_comp
                # Reset the parameters to the defaults of the current method
                self.config.method_params = {}
                self.config._set_method_default_params()

                reduced_data, band_names = self.apply_dim_reduction()

                # Restore the original configuration
                self.config.method = original_method
                self.config.n_components = original_n_components
                self.config.method_params = original_method_params

                # Save the result
                if self.data_type == 'image':
                    output_path = os.path.join(output_dir, f'reduced_{method_lower}.dat')
                else:
                    output_path = os.path.join(output_dir, f'reduced_{method_lower}.csv')
                self.save_results(reduced_data, band_names, output_path, method_lower)

                # Also generate visualizations for CSV data in batch mode
                if self.config.generate_plots and self.data_type == 'csv':
                    try:
                        self._generate_visualization_plots(reduced_data, method_lower, output_path)
                    except Exception as vis_error:
                        print(f"Error generating visualization for {method_lower}: {vis_error}")

                results[method_lower] = {
                    'data': reduced_data,
                    'bands': band_names,
                    'n_components': n_comp
                }
            except Exception as e:
                print(f"Error while processing {method_lower}: {str(e)}")

        return results
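
# --- Illustrative usage sketch (assumption: a real CSV file is available) ---
# Shows the intended single-method flow: build the config, load the data,
# reduce it and save the result. The path and column names are hypothetical.
def _example_single_run():
    """Sketch of a programmatic PCA run on a CSV spectral table."""
    config = DimensionalityReductionConfig(
        input_path='spectra.csv',   # hypothetical file; must exist
        label_col='Label',
        method='pca',
        n_components=3,
        output_path='reduced_pca.csv',
    )
    processor = HyperspectralDimReduction(config)
    processor.load_data()
    reduced, bands = processor.apply_dim_reduction()
    processor.save_results(reduced, bands, config.output_path, config.method)
    return reduced, bands
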
def main():
    """Command-line interface."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Hyperspectral dimensionality reduction tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:

1. Single-method reduction (PCA, 3 components):
   python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result

2. Batch processing with several methods:
   python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -M pca ica fa -N 3 3 3 -d results

3. Processing an ENVI image:
   python dimensionality_reduction.py image.hdr -m pca -n 5 -o reduced_image

4. Generating visualization plots:
   python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result --generate-plots --color-by-label

Supported methods:
  pca:    Principal Component Analysis (PCA)
  ica:    Independent Component Analysis (ICA)
  fa:     Factor Analysis (FA)
  lda:    Linear Discriminant Analysis (LDA)
  mds:    Multidimensional Scaling (MDS)
  isomap: Isometric Mapping (Isomap)
  lle:    Locally Linear Embedding (LLE)
  t-sne:  t-Distributed Stochastic Neighbor Embedding (t-SNE)
        """
    )

    parser.add_argument('input_path', help='input file path (.hdr hyperspectral image or .csv file)')

    # CSV options
    parser.add_argument('-l', '--label-col', help='label column name in the CSV file')
    parser.add_argument('-s', '--spectral-start', help='first spectral column (name or wavelength)')
    parser.add_argument('-e', '--spectral-end', help='last spectral column (name or wavelength)')

    # Single-method options
    parser.add_argument('-m', '--method', default='pca',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='reduction method (default: pca)')
    parser.add_argument('-n', '--n-components', type=int, default=3,
                        help='number of output components (default: 3)')
    parser.add_argument('-o', '--output', help='output file path or prefix')

    # Batch-processing options
    parser.add_argument('-M', '--batch-methods', nargs='+',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='list of reduction methods for batch processing')
    parser.add_argument('-N', '--batch-components', nargs='+', type=int,
                        help='number of components for each batch method')
    parser.add_argument('-d', '--output-dir', help='output directory for batch processing')

    # Visualization options
    parser.add_argument('--generate-plots', action='store_true', help='generate visualization plots')
    parser.add_argument('--color-by-label', action='store_true', help='color scatter points by label')
    parser.add_argument('--save-plots', action='store_true', help='save plot files')

    # Other options
    parser.add_argument('--no-standardization', action='store_true', help='disable data standardization')
    parser.add_argument('--show-methods', action='store_true', help='show details for all available methods')

    args = parser.parse_args()

    # Print method information only
    if args.show_methods:
        HyperspectralDimReduction.print_method_info()
        return 0

    try:
        print("=" * 60)
        print("Hyperspectral dimensionality reduction tool")
        print("=" * 60)
        print(f"Input file: {args.input_path}")

        # Single-method or batch processing?
        is_batch = args.batch_methods is not None
        if is_batch:
            print(f"Batch methods: {', '.join(args.batch_methods)}")
            print(f"Components: {args.batch_components}")
            print(f"Output directory: {args.output_dir}")
        else:
            print(f"Method: {args.method}")
            print(f"Components: {args.n_components}")
            print(f"Output path: {args.output}")

        if args.label_col:
            print(f"Label column: {args.label_col}")
        if args.spectral_start:
            print(f"Spectral range: {args.spectral_start} - {args.spectral_end}")
        print()

        # Build the configuration. Pass method and n_components here so that
        # __post_init__ fills in the matching default method_params.
        config = DimensionalityReductionConfig(
            input_path=args.input_path,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            method=args.method,
            n_components=args.n_components,
            output_path=args.output,
            output_dir=args.output_dir,
            use_standardization=not args.no_standardization,
            generate_plots=args.generate_plots,
            color_by_label=args.color_by_label,
            save_plots=args.save_plots
        )

        # Create the processor and load the data
        processor = HyperspectralDimReduction(config)
        print("Loading data...")
        processor.load_data()
        if is_batch:
            # Batch processing
            if not args.batch_components:
                args.batch_components = [3] * len(args.batch_methods)
            elif len(args.batch_components) != len(args.batch_methods):
                raise ValueError(
                    f"Number of batch methods ({len(args.batch_methods)}) does not match "
                    f"number of component counts ({len(args.batch_components)})"
                )

            print("\nStarting batch dimensionality reduction...")
            results = processor.batch_process(args.batch_methods, args.batch_components,
                                              args.output_dir or 'reduced_data_batch')

            # Print a summary of the results
            print("\n=== Batch processing finished ===")
            print(f"Processed {len(results)} reduction methods:")
            for method_name, result in results.items():
                print(f"- {method_name}: {result['n_components']} components")

            # Save the summary to CSV
            summary_data = []
            for method_name, result in results.items():
                summary_data.append({
                    'method': method_name,
                    'n_components': result['n_components'],
                    'data_shape': str(result['data'].shape),
                    'output_file': f'reduced_{method_name}'
                })
            summary_df = pd.DataFrame(summary_data)
            summary_path = os.path.join(args.output_dir or 'reduced_data_batch',
                                        'batch_processing_summary.csv')
            summary_df.to_csv(summary_path, index=False)
            print(f"Batch processing summary saved to: {summary_path}")
        else:
            # Single-method processing (the method and component count were already
            # set on the configuration above)
            print("\nStarting dimensionality reduction...")
            reduced_data, band_names = processor.apply_dim_reduction()

            print("✓ Dimensionality reduction finished")
            print(f"Original data shape: {processor.data.shape if processor.data is not None else 'Unknown'}")
            print(f"Reduced data shape: {reduced_data.shape}")

            # Save the result
            if args.output:
                processor.save_results(reduced_data, band_names, args.output, args.method)

        print("\n" + "=" * 60)
        print("Done!")
        print("=" * 60)

    except Exception as e:
        print(f"✗ Processing failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    exit(main())