import numpy as np
|
||
import pandas as pd
|
||
import spectral
|
||
from sklearn.decomposition import PCA, FastICA, FactorAnalysis
|
||
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
|
||
from sklearn.manifold import MDS, Isomap, TSNE, LocallyLinearEmbedding
|
||
from sklearn.preprocessing import StandardScaler
|
||
# import h5py
|
||
import os
|
||
from scipy.io import savemat
|
||
import warnings
|
||
from typing import Optional, Dict, Any, Tuple, List
|
||
from dataclasses import dataclass, field
|
||
warnings.filterwarnings('ignore')
|
||
|
||
# 可视化相关导入
|
||
import matplotlib.pyplot as plt
|
||
from mpl_toolkits.mplot3d import Axes3D
|
||
import matplotlib.cm as cm
|
||
import matplotlib.colors as mcolors
|
||
|
||
|
||
@dataclass
class DimensionalityReductionConfig:
    """Configuration for hyperspectral dimensionality reduction.

    Validates the input path and method names on construction
    (see __post_init__) and fills in per-method default hyperparameters.
    """

    # --- input file settings ---
    input_path: Optional[str] = None      # .csv or ENVI .dat/.img/.hdr file (required)
    label_col: Optional[str] = None       # CSV label column (name or index); None = unlabeled
    spectral_start: Optional[str] = None  # first spectral column (name or wavelength string)
    spectral_end: Optional[str] = None    # last spectral column (name or wavelength string)
    csv_has_header: bool = True           # whether the CSV has a header row

    # --- output settings ---
    output_path: Optional[str] = None     # single-method output file (optional)
    output_dir: Optional[str] = None      # batch-mode output directory

    # --- reduction method ---
    method: str = 'PCA'                   # normalized to lowercase in __post_init__
    n_components: int = 3                 # target dimensionality
    method_params: Dict[str, Any] = field(default_factory=dict)  # extra estimator kwargs

    # --- batch processing ---
    batch_methods: List[str] = field(default_factory=lambda: ['PCA', 'ICA', 'FA'])
    batch_components: List[int] = field(default_factory=lambda: [3, 3, 3])

    # --- preprocessing ---
    use_standardization: bool = True      # z-score the features before reduction

    # --- visualization ---
    generate_plots: bool = True           # whether to generate scatter plots
    color_by_label: bool = True           # whether to color points by class label
    save_plots: bool = True               # whether to write plot files

    def __post_init__(self):
        """Validate fields, normalize method names, and apply defaults.

        Raises:
            ValueError: missing input path, unsupported method, or
                mismatched batch_methods/batch_components lengths.
            FileNotFoundError: input_path does not exist on disk.
        """
        # A source file is mandatory.
        if not self.input_path:
            raise ValueError("必须指定输入文件路径(input_path)")

        # Fail fast if the file is missing.
        if not os.path.exists(self.input_path):
            raise FileNotFoundError(f"输入文件不存在: {self.input_path}")

        # Output path is only required for batch processing; for a single
        # method run it is optional, so it is not validated here.

        # Normalize method names to lowercase for lookup.
        self.method = self.method.lower()
        self.batch_methods = [m.lower() for m in self.batch_methods]

        # Reject unknown reduction methods.
        supported_methods = self._get_supported_methods()
        if self.method not in supported_methods:
            raise ValueError(f"不支持的降维方法: {self.method}。支持的方法: {list(supported_methods.keys())}")

        # Batch lists must pair one component count with each method.
        if len(self.batch_methods) != len(self.batch_components):
            raise ValueError("batch_methods和batch_components长度必须相等")

        for method in self.batch_methods:
            if method not in supported_methods:
                raise ValueError(f"批量处理中包含不支持的方法: {method}")

        # Fill in per-method default hyperparameters.
        self._set_method_default_params()

    def _get_supported_methods(self) -> Dict[str, str]:
        """Return the supported methods as {lowercase key: display name}."""
        methods = {
            'pca': 'Principal Component Analysis (PCA)',
            'ica': 'Independent Component Analysis (ICA)',
            'fa': 'Factor Analysis (FA)',
            'lda': 'Linear Discriminant Analysis (LDA)',
            'mds': 'Multidimensional Scaling (MDS)',
            'isomap': 'Isometric Mapping (Isomap)',
            'lle': 'Locally Linear Embedding (LLE)',
            't-sne': 't-Distributed Stochastic Neighbor Embedding (t-SNE)'
        }
        return methods

    def _set_method_default_params(self):
        """Seed method_params with per-method defaults (setdefault only,
        so user-supplied values always win)."""
        if self.method.lower() == 'pca':
            self.method_params.setdefault('whiten', False)
            self.method_params.setdefault('svd_solver', 'auto')
        elif self.method.lower() == 'ica':
            self.method_params.setdefault('algorithm', 'parallel')
            # 'unit-variance' whitening requires scikit-learn >= 1.1
            self.method_params.setdefault('whiten', 'unit-variance')
            self.method_params.setdefault('fun', 'logcosh')
        elif self.method.lower() == 'fa':
            self.method_params.setdefault('rotation', None)
            self.method_params.setdefault('svd_method', 'randomized')
        elif self.method.lower() == 'lda':
            self.method_params.setdefault('solver', 'svd')
            self.method_params.setdefault('shrinkage', None)
        elif self.method.lower() == 'mds':
            self.method_params.setdefault('metric', True)
            self.method_params.setdefault('dissimilarity', 'euclidean')
            self.method_params.setdefault('random_state', 42)
        elif self.method.lower() == 'isomap':
            self.method_params.setdefault('n_neighbors', 5)
            self.method_params.setdefault('eigen_solver', 'auto')
        elif self.method.lower() == 'lle':
            self.method_params.setdefault('n_neighbors', 5)
            self.method_params.setdefault('reg', 0.001)
            self.method_params.setdefault('eigen_solver', 'auto')
        elif self.method.lower() == 't-sne':
            self.method_params.setdefault('perplexity', 30.0)
            self.method_params.setdefault('early_exaggeration', 12.0)
            self.method_params.setdefault('learning_rate', 200.0)
            self.method_params.setdefault('random_state', 42)
|
||
|
||
|
||
class HyperspectralDimReduction:
|
||
"""高光谱数据降维处理系统"""
|
||
|
||
def __init__(self, config: DimensionalityReductionConfig):
|
||
self.config = config
|
||
|
||
# 初始化数据存储
|
||
self.data = None
|
||
self.labels = None
|
||
self.header = None
|
||
self.original_shape = None
|
||
self.data_type = None # 'image' or 'csv'
|
||
|
||
# 初始化处理器
|
||
self.scaler = StandardScaler() if config.use_standardization else None
|
||
|
||
# 保存模型引用(用于获取贡献百分比)
|
||
self._last_model = None
|
||
self._explained_variance_ratio = None
|
||
|
||
# 保存可视化数据
|
||
self._last_reduced_data = None
|
||
self._last_method = None
|
||
|
||
def load_data(self) -> None:
|
||
"""
|
||
加载数据文件(使用配置参数)
|
||
"""
|
||
print(f"正在加载数据文件: {self.config.input_path}")
|
||
|
||
file_ext = os.path.splitext(self.config.input_path)[1].lower()
|
||
|
||
if file_ext in ['.dat', '.img', '.hdr']:
|
||
self._load_image_data(self.config.input_path)
|
||
self.data_type = 'image'
|
||
elif file_ext == '.csv':
|
||
self._load_csv_data(self.config.input_path, self.config.label_col,
|
||
self.config.spectral_start, self.config.spectral_end,
|
||
self.config.csv_has_header)
|
||
self.data_type = 'csv'
|
||
else:
|
||
raise ValueError(f"不支持的文件格式: {file_ext}")
|
||
|
||
print("数据加载完成")
|
||
|
||
def _load_image_data(self, file_path):
|
||
"""加载高光谱图像数据"""
|
||
base_path = os.path.splitext(file_path)[0]
|
||
hdr_file = base_path + '.hdr'
|
||
|
||
# 读取ENVI格式图像
|
||
img = spectral.open_image(hdr_file)
|
||
self.data = img.load()
|
||
self.original_shape = self.data.shape
|
||
|
||
# 读取头文件信息
|
||
self.header = spectral.envi.read_envi_header(hdr_file)
|
||
|
||
# 转换为二维数组(像素×波段)
|
||
self.data = self.data.reshape(-1, self.data.shape[2])
|
||
|
||
    def _load_csv_data(self, file_path, label_col, spectral_start,
                       spectral_end, has_header):
        """Load spectra from a CSV file into self.data / self.labels.

        Args:
            file_path: CSV path.
            label_col: label column name (str) or positional index (int);
                None for unlabeled data.
            spectral_start: first spectral column — a column name or a
                numeric wavelength string matched to the closest column;
                None selects all numeric columns (minus the label column).
            spectral_end: last spectral column (same conventions); None
                means "from spectral_start to the last column".
            has_header: whether the CSV's first row is a header.
        """
        if has_header:
            df = pd.read_csv(file_path)
        else:
            df = pd.read_csv(file_path, header=None)

        # Resolve the spectral column range.
        def find_column_by_wavelength(df, wavelength_str):
            """Return the column whose embedded wavelength is closest to
            wavelength_str; fall back to wavelength_str as a literal name."""
            try:
                target_wavelength = float(wavelength_str)
                # Collect (column, wavelength) pairs from the column names.
                wavelength_cols = []
                for col in df.columns:
                    if isinstance(col, str):
                        # Try to pull a numeric value out of the name.
                        try:
                            # Names like "wavelength_400.5".
                            if 'wavelength' in col.lower():
                                wl = float(col.split('_')[-1])
                                wavelength_cols.append((col, wl))
                            # Names that are themselves numeric.
                            # NOTE(review): replace('.', '').isdigit() rejects
                            # negative numbers and scientific notation.
                            elif col.replace('.', '').isdigit():
                                wl = float(col)
                                wavelength_cols.append((col, wl))
                        except ValueError:
                            continue

                if wavelength_cols:
                    # Pick the closest wavelength column.
                    closest_col = min(wavelength_cols, key=lambda x: abs(x[1] - target_wavelength))[0]
                    return closest_col
                else:
                    # NOTE(review): this ValueError is caught by the outer
                    # except below, so the net effect is returning
                    # wavelength_str unchanged.
                    raise ValueError(f"无法在列名中找到波长信息: {list(df.columns)}")
            except ValueError:
                # Not numeric (or no match): use it as a literal column name.
                return wavelength_str

        if spectral_start is not None:
            if isinstance(spectral_start, str):
                try:
                    # If it parses as a number, treat it as a wavelength.
                    float(spectral_start)
                    spectral_start_col = find_column_by_wavelength(df, spectral_start)
                except ValueError:
                    # Otherwise treat it as a literal column name.
                    spectral_start_col = spectral_start
            else:
                spectral_start_col = spectral_start

            if spectral_end is not None:
                if isinstance(spectral_end, str):
                    try:
                        float(spectral_end)
                        spectral_end_col = find_column_by_wavelength(df, spectral_end)
                    except ValueError:
                        spectral_end_col = spectral_end
                else:
                    spectral_end_col = spectral_end

                # Label-based, inclusive range selection.
                # NOTE(review): if the label column lies inside this range it
                # is included in the spectral data — confirm with callers.
                spectral_cols = df.loc[:, spectral_start_col:spectral_end_col].columns
                spectral_data = df.loc[:, spectral_start_col:spectral_end_col].values
            else:
                # Only a start column: take everything from it to the end.
                start_idx = df.columns.get_loc(spectral_start_col)
                spectral_data = df.iloc[:, start_idx:].values
                spectral_cols = df.iloc[:, start_idx:].columns
        else:
            # No explicit range: use all numeric columns except the label.
            if label_col is not None:
                if isinstance(label_col, str) and label_col in df.columns:
                    spectral_data = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=[label_col]).select_dtypes(include=[np.number]).columns
                elif isinstance(label_col, int):
                    spectral_data = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).values
                    spectral_cols = df.drop(columns=df.columns[label_col]).select_dtypes(include=[np.number]).columns
                else:
                    spectral_data = df.select_dtypes(include=[np.number]).values
                    spectral_cols = df.select_dtypes(include=[np.number]).columns
            else:
                spectral_data = df.select_dtypes(include=[np.number]).values
                spectral_cols = df.select_dtypes(include=[np.number]).columns

        # Extract the labels.
        if label_col is not None:
            if isinstance(label_col, str):
                self.labels = df[label_col].values
            else:
                self.labels = df.iloc[:, label_col].values

            # Drop the label column from df.
            # NOTE(review): spectral_data was already extracted above, so
            # this drop no longer affects the output — apparently dead code.
            if label_col in df.columns:
                df = df.drop(columns=[label_col])
            elif isinstance(label_col, int):
                df = df.drop(columns=df.columns[label_col])

        self.data = spectral_data
        self.original_shape = self.data.shape

        # Fabricate a minimal ENVI-style header for downstream code.
        self.header = {
            'bands': spectral_data.shape[1],
            'lines': 1,
            'samples': spectral_data.shape[0],
            'interleave': 'bsq',
            'data type': 4,  # ENVI code 4 = float32
            'byte order': 0,
            'wavelength': [f'Band_{i}' for i in range(spectral_data.shape[1])]
        }
|
||
|
||
def apply_dim_reduction(self) -> Tuple[np.ndarray, List[str]]:
|
||
"""
|
||
应用降维方法(使用配置参数)
|
||
|
||
Returns:
|
||
Tuple[np.ndarray, List[str]]: (降维后的数据, 波段名称列表)
|
||
|
||
数据类型适用性说明:
|
||
高光谱图像数据 (data_type='image'):
|
||
- 推荐: PCA, ICA, FA (适合高维数据,计算效率高)
|
||
- 可用但需谨慎: LDA (需要标签), MDS
|
||
- 不推荐: Isomap, LLE, t-SNE (计算复杂度高,内存占用大)
|
||
|
||
CSV光谱数据 (data_type='csv'):
|
||
- 全部方法都可用,但根据数据量选择:
|
||
小数据集(<1000样本): 全部方法可用
|
||
大数据集(>10000样本): 避免t-SNE, MDS, Isomap, LLE
|
||
|
||
方法限制:
|
||
- LDA: 需要标签信息(self.labels不为None)
|
||
- t-SNE/MDS/Isomap/LLE: 对大数据计算复杂度高(O(n²)或更高)
|
||
"""
|
||
print(f"正在应用降维方法: {self.config.method} (维度: {self.config.n_components})")
|
||
|
||
# 检查方法适用性
|
||
self._check_method_compatibility(self.config.method)
|
||
|
||
# 数据大小检查
|
||
n_samples, n_features = self.data.shape
|
||
self._check_data_size_compatibility(self.config.method, n_samples, n_features)
|
||
|
||
# 数据标准化
|
||
if self.config.use_standardization and self.scaler is not None:
|
||
data_scaled = self.scaler.fit_transform(self.data)
|
||
print("已应用数据标准化")
|
||
else:
|
||
data_scaled = self.data
|
||
|
||
# 合并配置参数和额外参数
|
||
method_params = {**self.config.method_params}
|
||
|
||
# 应用不同的降维方法
|
||
if self.config.method.upper() == 'PCA':
|
||
reducer = PCA(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'PCA_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
# 保存模型引用和贡献百分比
|
||
self._last_model = reducer
|
||
self._explained_variance_ratio = reducer.explained_variance_ratio_
|
||
|
||
elif self.config.method.upper() == 'LDA':
|
||
if self.labels is None:
|
||
raise ValueError("LDA需要标签信息,请确保加载数据时指定了标签列")
|
||
reducer = LinearDiscriminantAnalysis(n_components=min(self.config.n_components,
|
||
len(np.unique(self.labels))-1),
|
||
**method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled, self.labels)
|
||
band_names = [f'LDA_{i+1}' for i in range(reduced_data.shape[1])]
|
||
|
||
elif self.config.method.upper() == 'ICA':
|
||
reducer = FastICA(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'ICA_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
elif self.config.method.upper() == 'FA':
|
||
reducer = FactorAnalysis(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'FA_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
elif self.config.method.upper() == 'MDS':
|
||
reducer = MDS(n_components=self.config.n_components,**method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'MDS_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
elif self.config.method.upper() == 'ISOMAP':
|
||
reducer = Isomap(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'Isomap_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
elif self.config.method.upper() == 'LLE':
|
||
reducer = LocallyLinearEmbedding(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'LLE_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
elif self.config.method.upper() == 'T-SNE':
|
||
reducer = TSNE(n_components=self.config.n_components, **method_params)
|
||
reduced_data = reducer.fit_transform(data_scaled)
|
||
band_names = [f'tSNE_{i+1}' for i in range(self.config.n_components)]
|
||
|
||
else:
|
||
raise ValueError(f"不支持的降维方法: {self.config.method}")
|
||
|
||
# 保存结果以供可视化使用
|
||
self._last_reduced_data = reduced_data
|
||
self._last_method = self.config.method
|
||
|
||
return reduced_data, band_names
|
||
|
||
def generate_visualization(self, reduced_data=None, method=None, output_path=None):
|
||
"""
|
||
独立的可视化方法调用
|
||
|
||
参数:
|
||
reduced_data: 降维后的数据,如果为None则使用上次处理的结果
|
||
method: 方法名称,如果为None则使用配置中的方法
|
||
output_path: 输出路径,如果为None则使用默认路径
|
||
"""
|
||
if reduced_data is None and hasattr(self, '_last_reduced_data'):
|
||
reduced_data = self._last_reduced_data
|
||
if method is None:
|
||
method = self.config.method
|
||
|
||
# 保存降维结果以供可视化使用
|
||
self._last_reduced_data = reduced_data
|
||
self._last_method = method
|
||
|
||
if output_path is None:
|
||
output_path = f'visualization_{method.lower()}.csv'
|
||
|
||
# 生成可视化图表
|
||
if self.config.generate_plots and self.data_type == 'csv':
|
||
self._generate_visualization_plots(reduced_data, method, output_path)
|
||
|
||
def _check_method_compatibility(self, method):
|
||
"""
|
||
检查降维方法与数据类型的兼容性
|
||
|
||
高光谱图像数据适用性:
|
||
- PCA, ICA, FA: 高度推荐,计算效率高,适合高维数据
|
||
- LDA: 需要标签,适用于有监督降维
|
||
- MDS: 可用但计算量较大
|
||
- Isomap, LLE, t-SNE: 不推荐,计算复杂度高,内存占用大
|
||
|
||
CSV数据适用性:
|
||
- 所有方法都可用,但大数据集应避免复杂度高的方法
|
||
"""
|
||
method = method.upper()
|
||
|
||
# LDA需要标签检查
|
||
if method == 'LDA' and self.labels is None:
|
||
raise ValueError("LDA方法需要标签信息,请确保数据包含标签列")
|
||
|
||
# 高光谱数据的限制
|
||
if self.data_type == 'image':
|
||
high_complexity_methods = ['ISOMAP', 'LLE', 'T-SNE']
|
||
if method in high_complexity_methods:
|
||
import warnings
|
||
warnings.warn(
|
||
f"警告: {method}方法对高光谱图像数据计算复杂度高,"
|
||
"可能导致内存不足或计算时间过长。建议使用PCA、ICA或FA方法。",
|
||
UserWarning
|
||
)
|
||
|
||
# CSV数据的小提示
|
||
elif self.data_type == 'csv':
|
||
if method in ['T-SNE', 'MDS', 'ISOMAP', 'LLE']:
|
||
import warnings
|
||
warnings.warn(
|
||
f"提示: {method}方法计算复杂度较高,"
|
||
"如果数据集较大,建议使用PCA、ICA或FA方法。",
|
||
UserWarning
|
||
)
|
||
|
||
def _check_data_size_compatibility(self, method, n_samples, n_features):
|
||
"""
|
||
根据数据大小给出方法适用性建议
|
||
|
||
参数:
|
||
method: 降维方法
|
||
n_samples: 样本数量
|
||
n_features: 特征数量
|
||
"""
|
||
method = method.upper()
|
||
high_complexity_methods = ['T-SNE', 'MDS', 'ISOMAP', 'LLE']
|
||
|
||
# 对于大数据集的警告
|
||
if method in high_complexity_methods:
|
||
if n_samples > 5000:
|
||
import warnings
|
||
warnings.warn(
|
||
f"警告: 数据集有{n_samples}个样本,{method}方法的计算复杂度为O(n²),"
|
||
"可能需要很长时间。建议使用PCA、ICA或FA方法。",
|
||
UserWarning
|
||
)
|
||
elif n_samples > 10000:
|
||
raise ValueError(
|
||
f"数据集过大({n_samples}样本),{method}方法不适合。"
|
||
"请使用PCA、ICA或FA方法。"
|
||
)
|
||
|
||
# 对于高维数据的建议
|
||
if n_features > 1000 and method not in ['PCA', 'ICA', 'FA']:
|
||
import warnings
|
||
warnings.warn(
|
||
f"提示: 数据维度很高({n_features}),建议优先考虑PCA、ICA或FA方法。",
|
||
UserWarning
|
||
)
|
||
|
||
def get_recommended_methods(self, data_type=None, n_samples=None, n_features=None,
|
||
has_labels=False, max_components=None):
|
||
"""
|
||
根据数据特征推荐合适的降维方法
|
||
|
||
参数:
|
||
data_type: 数据类型 ('image' 或 'csv')
|
||
n_samples: 样本数量
|
||
n_features: 特征数量
|
||
has_labels: 是否有标签
|
||
max_components: 最大降维维度
|
||
|
||
返回:
|
||
dict: 包含推荐方法和理由的字典
|
||
"""
|
||
if data_type is None:
|
||
data_type = self.data_type
|
||
if n_samples is None and self.data:
|
||
n_samples = self.data.shape[0]
|
||
if n_features is None and self.data:
|
||
n_features = self.data.shape[1]
|
||
if not has_labels and self.labels is not None:
|
||
has_labels = True
|
||
|
||
recommendations = {
|
||
'highly_recommended': [],
|
||
'recommended': [],
|
||
'use_with_caution': [],
|
||
'not_recommended': []
|
||
}
|
||
|
||
# 高光谱图像数据推荐
|
||
if data_type == 'image':
|
||
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
|
||
|
||
if has_labels:
|
||
recommendations['recommended'].append('LDA')
|
||
else:
|
||
recommendations['recommended'].append('MDS')
|
||
|
||
recommendations['use_with_caution'] = ['Isomap', 'LLE']
|
||
recommendations['not_recommended'] = ['t-SNE']
|
||
|
||
# CSV数据推荐
|
||
elif data_type == 'csv':
|
||
if n_samples and n_samples < 1000: # 小数据集
|
||
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
|
||
recommendations['recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']
|
||
if has_labels:
|
||
recommendations['recommended'].append('LDA')
|
||
elif n_samples and n_samples < 5000: # 中等数据集
|
||
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
|
||
recommendations['recommended'] = ['MDS']
|
||
if has_labels:
|
||
recommendations['recommended'].append('LDA')
|
||
recommendations['use_with_caution'] = ['Isomap', 'LLE', 't-SNE']
|
||
else: # 大数据集
|
||
recommendations['highly_recommended'] = ['PCA', 'ICA', 'FA']
|
||
if has_labels:
|
||
recommendations['recommended'].append('LDA')
|
||
recommendations['not_recommended'] = ['MDS', 'Isomap', 'LLE', 't-SNE']
|
||
|
||
# 高维数据额外建议
|
||
if n_features and n_features > 500:
|
||
# 对于高维数据,PCA通常是最佳选择
|
||
if 'PCA' not in recommendations['highly_recommended']:
|
||
recommendations['highly_recommended'].insert(0, 'PCA')
|
||
|
||
return recommendations
|
||
|
||
    @staticmethod
    def print_method_info():
        """Print a human-readable catalogue of every supported reduction
        method: description, strengths, limitations, typical use cases,
        applicable data types, and relative computational cost.
        """
        # Static catalogue; keys mirror the names used elsewhere in the tool.
        method_info = {
            'PCA': {
                'name': '主成分分析 (Principal Component Analysis)',
                'description': '通过正交变换将数据投影到新的坐标系,保留最大方差的方向',
                'advantages': ['计算效率高', '适合高维数据', '无参数', '保持全局结构'],
                'limitations': ['线性方法', '对异常值敏感'],
                'best_for': ['高光谱数据预处理', '噪声去除', '数据可视化'],
                'data_types': ['image', 'csv'],
                'complexity': '低'
            },
            'ICA': {
                'name': '独立成分分析 (Independent Component Analysis)',
                'description': '假设数据是多个独立源的混合,通过统计独立性分离信号',
                'advantages': ['能发现隐藏因素', '适合信号分离', '非线性'],
                'limitations': ['计算复杂度较高', '结果可能不稳定'],
                'best_for': ['信号分离', '去除混合噪声', '发现潜在成分'],
                'data_types': ['image', 'csv'],
                'complexity': '中'
            },
            'FA': {
                'name': '因子分析 (Factor Analysis)',
                'description': '假设观测变量由潜在因子和误差组成,建立因子模型',
                'advantages': ['解释性强', '适合社会科学数据', '处理测量误差'],
                'limitations': ['假设较多', '对数据分布敏感'],
                'best_for': ['探索潜在结构', '变量分组', '降维解释'],
                'data_types': ['image', 'csv'],
                'complexity': '中'
            },
            'LDA': {
                'name': '线性判别分析 (Linear Discriminant Analysis)',
                'description': '通过最大化类间距离和最小化类内距离进行降维',
                'advantages': ['监督学习', '保持类别可分性', '计算简单'],
                'limitations': ['需要标签', '线性假设', '对异常值敏感'],
                'best_for': ['分类任务', '有标签数据', '保持类别结构'],
                'data_types': ['csv'],  # mainly CSV: images rarely have per-pixel labels
                'complexity': '低'
            },
            'MDS': {
                'name': '多维尺度分析 (Multidimensional Scaling)',
                'description': '保持样本间距离关系在低维空间中的映射',
                'advantages': ['保持距离关系', '无分布假设', '直观'],
                'limitations': ['计算复杂度O(n²)', '对大数据不适用'],
                'best_for': ['保持相对距离', '数据可视化', '相似性分析'],
                'data_types': ['csv'],
                'complexity': '高'
            },
            'Isomap': {
                'name': '等度量映射 (Isometric Mapping)',
                'description': '保持测地线距离,在流形上计算最短路径',
                'advantages': ['处理非线性流形', '保持局部和全局结构'],
                'limitations': ['对噪声敏感', '计算复杂度高', '参数选择重要'],
                'best_for': ['非线性流形数据', '保持几何结构'],
                'data_types': ['csv'],
                'complexity': '很高'
            },
            'LLE': {
                'name': '局部线性嵌入 (Locally Linear Embedding)',
                'description': '保持局部邻域的线性关系进行降维',
                'advantages': ['处理非线性数据', '保持局部结构', '无优化参数'],
                'limitations': ['对大数据不适用', '对噪声敏感', '邻域参数敏感'],
                'best_for': ['非线性降维', '保持局部邻域关系'],
                'data_types': ['csv'],
                'complexity': '很高'
            },
            't-SNE': {
                'name': 't-分布随机邻域嵌入 (t-SNE)',
                'description': '将高维数据转换为低维空间,保持局部相似性',
                'advantages': ['优秀的可视化效果', '保持局部结构', '非线性'],
                'limitations': ['计算复杂度极高', '主要用于可视化', '结果不稳定'],
                'best_for': ['数据可视化', '聚类分析', '探索性分析'],
                'data_types': ['csv'],
                'complexity': '极高'
            }
        }

        print("\n=== 降维方法详细说明 ===\n")

        # One formatted section per method.
        for method, info in method_info.items():
            print(f"{method} - {info['name']}")
            print(f" 描述: {info['description']}")
            print(f" 优势: {', '.join(info['advantages'])}")
            print(f" 局限性: {', '.join(info['limitations'])}")
            print(f" 适用场景: {', '.join(info['best_for'])}")
            print(f" 数据类型: {', '.join(info['data_types'])}")
            print(f" 计算复杂度: {info['complexity']}")
            print()
|
||
|
||
def save_results(self, reduced_data, band_names, output_path, method):
|
||
"""
|
||
保存降维结果
|
||
|
||
参数:
|
||
reduced_data: 降维后的数据
|
||
band_names: 波段名称列表
|
||
output_path: 输出路径
|
||
method: 使用的降维方法
|
||
"""
|
||
output_base = os.path.splitext(output_path)[0]
|
||
|
||
if self.data_type == 'image':
|
||
# 恢复图像形状
|
||
if self.original_shape:
|
||
img_shape = (self.original_shape[0], self.original_shape[1],
|
||
reduced_data.shape[1])
|
||
reduced_img = reduced_data.reshape(img_shape)
|
||
else:
|
||
reduced_img = reduced_data
|
||
|
||
# 保存.dat文件
|
||
self._save_dat_file(reduced_img, output_path)
|
||
|
||
# 保存.hdr头文件
|
||
hdr_path = output_base + '.hdr'
|
||
self._save_hdr_file(hdr_path, reduced_img.shape, band_names, method)
|
||
|
||
print(f"图像已保存: {output_path}")
|
||
print(f"头文件已保存: {hdr_path}")
|
||
|
||
else: # CSV格式
|
||
# 只保存为CSV格式,不保存DAT和HDR文件
|
||
df_output = pd.DataFrame(reduced_data, columns=band_names)
|
||
|
||
# 如果有标签,将标签列放在第一列
|
||
if self.labels is not None:
|
||
df_output.insert(0, 'Label', self.labels)
|
||
|
||
# 如果output_path已经有.csv扩展名,直接使用;否则添加扩展名
|
||
if output_path.lower().endswith('.csv'):
|
||
csv_path = output_path
|
||
else:
|
||
csv_path = output_base + '.csv'
|
||
|
||
df_output.to_csv(csv_path, index=False)
|
||
|
||
print(f"CSV文件已保存: {csv_path}")
|
||
print("注意:CSV输入文件只保存CSV格式结果,不生成DAT和HDR文件")
|
||
|
||
# 生成可视化图表(仅对CSV数据)
|
||
if self.config.generate_plots:
|
||
self._generate_visualization_plots(reduced_data, method, csv_path)
|
||
|
||
def _save_dat_file(self, data, file_path):
|
||
"""保存.dat文件(二进制格式)"""
|
||
with open(file_path, 'wb') as f:
|
||
data.astype(np.float32).tofile(f)
|
||
|
||
    def _save_hdr_file(self, hdr_path, data_shape, band_names, method, is_csv=False):
        """Write an ENVI header (.hdr) describing the reduced data.

        Args:
            hdr_path: destination header path.
            data_shape: shape of the saved array; (lines, samples, bands)
                for 3-D image data, (samples, bands) for flat data.
            band_names: names recorded in the 'band names' header field.
            method: reduction method name, recorded in description/history.
            is_csv: if True, treat data_shape[0] as sample count with 1 line.
        """
        if is_csv:
            lines, samples = 1, data_shape[0]
        else:
            lines, samples = data_shape[0], data_shape[1]

        bands = data_shape[2] if len(data_shape) == 3 else data_shape[1]

        # Carry selected fields over from the original header, if any.
        original_info = {}
        if self.header and self.data_type == 'image':
            # Sensor/acquisition metadata worth preserving verbatim.
            preserve_keys = [
                'sensor type', 'wavelength units', 'sample binning',
                'spectral binning', 'line binning', 'shutter', 'gain',
                'framerate', 'imager serial number', 'rotation', 'label',
                'map info', 'coordinate system string', 'classes'
            ]

            for key in preserve_keys:
                if key in self.header:
                    original_info[key] = self.header[key]

            # Original wavelengths no longer apply after reduction — record
            # the new band names instead.
            # NOTE(review): both branches are identical; the if/else is
            # redundant but kept as-is.
            if 'wavelength' in self.header:
                original_info['band names'] = band_names
            else:
                original_info['band names'] = band_names

        header_content = f"""ENVI
description = {{{method} Result [{pd.Timestamp.now().strftime('%a %b %d %H:%M:%S %Y')}]}}
samples = {samples}
lines = {lines}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bip
byte order = 0
"""

        # Append the carried-over metadata fields.
        for key, value in original_info.items():
            if key == 'band names':
                header_content += f"band names = {{{', '.join(value)}}}\n"
            elif key == 'wavelength':
                # NOTE(review): appears unreachable — preserve_keys never
                # stores a 'wavelength' entry; kept as-is.
                wavelength_str = ', '.join([f'{w:.6f}' for w in value]) if isinstance(value, (list, np.ndarray)) else str(value)
                header_content += f"original wavelength = {{{wavelength_str}}}\n"
            elif key == 'rotation':
                # Rotation is rendered as four (x, y) corner pairs.
                if isinstance(value, list) and len(value) == 4:
                    rotation_str = f"[({value[0][0]}, {value[0][1]}), ({value[1][0]}, {value[1][1]}), ({value[2][0]}, {value[2][1]}), ({value[3][0]}, {value[3][1]})]"
                    header_content += f"rotation = {rotation_str}\n"
                else:
                    header_content += f"rotation = {value}\n"
            elif isinstance(value, str) and '\n' in value:
                # Multi-line values use the brace-wrapped block form.
                header_content += f"{key} = {{\n{value}\n}}\n"
            else:
                header_content += f"{key} = {value}\n"

        # Fallback defaults when no original metadata was available.
        if not original_info:
            header_content += f"""wavelength units = Unknown
band names = {{{', '.join(band_names)}}}
classes = 0
map info = {{Geographic Lat/Lon, 1.0000, 1.0000, 0.0, 0.0, 0.0, 0.0}}
"""

        # Record processing history, appending to any existing chain.
        if self.header and 'history' in self.header:
            history_info = f"{self.header['history']} -> {method}<SpecGroup>[<SpecInt label:'Components' value:{bands}>]"
        else:
            history_info = f"{method}<SpecGroup>[<SpecInt label:'Components' value:{bands}>]"

        header_content += f"history = {history_info}\n"

        with open(hdr_path, 'w', encoding='utf-8') as f:
            f.write(header_content)
|
||
|
||
def _generate_visualization_plots(self, reduced_data, method, csv_path):
|
||
"""生成降维结果的可视化散点图"""
|
||
try:
|
||
# 获取贡献百分比(仅对支持的方法)
|
||
explained_variance_ratio = self._get_explained_variance_ratio(method)
|
||
|
||
# 生成2D散点图(前两个特征)
|
||
if reduced_data.shape[1] >= 2:
|
||
self._plot_2d_scatter(reduced_data, method, explained_variance_ratio, csv_path)
|
||
|
||
# 生成3D散点图(前三个特征)
|
||
if reduced_data.shape[1] >= 3:
|
||
self._plot_3d_scatter(reduced_data, method, explained_variance_ratio, csv_path)
|
||
|
||
except Exception as e:
|
||
print(f"生成可视化图表时出错: {e}")
|
||
|
||
def _get_explained_variance_ratio(self, method):
|
||
"""获取降维方法的贡献百分比"""
|
||
method_lower = method.lower()
|
||
|
||
# 对于PCA,可以获取解释方差比
|
||
if method_lower == 'pca' and self._explained_variance_ratio is not None:
|
||
return self._explained_variance_ratio
|
||
|
||
# 对于其他方法,目前不支持贡献百分比
|
||
return None
|
||
|
||
def _plot_2d_scatter(self, data, method, explained_variance_ratio, csv_path):
|
||
"""生成2D散点图"""
|
||
plt.figure(figsize=(10, 8))
|
||
|
||
if self.config.color_by_label and self.labels is not None:
|
||
# 按标签着色
|
||
unique_labels = np.unique(self.labels)
|
||
colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
|
||
|
||
for i, label in enumerate(unique_labels):
|
||
mask = self.labels == label
|
||
plt.scatter(data[mask, 0], data[mask, 1],
|
||
c=[colors[i]], label=f'Class {label}',
|
||
alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
|
||
plt.legend()
|
||
else:
|
||
# 统一颜色
|
||
plt.scatter(data[:, 0], data[:, 1], alpha=0.7, s=50,
|
||
c='blue', edgecolors='black', linewidth=0.5)
|
||
|
||
# 设置坐标轴标签
|
||
xlabel = 'PC1'
|
||
ylabel = 'PC2'
|
||
if explained_variance_ratio is not None and len(explained_variance_ratio) >= 2:
|
||
xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
|
||
ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'
|
||
|
||
plt.xlabel(xlabel, fontsize=12, fontweight='bold')
|
||
plt.ylabel(ylabel, fontsize=12, fontweight='bold')
|
||
plt.title(f'{method.upper()} - 2D Scatter Plot', fontsize=14, fontweight='bold')
|
||
plt.grid(True, alpha=0.3)
|
||
|
||
# 保存图表
|
||
if self.config.save_plots:
|
||
plot_path = os.path.splitext(csv_path)[0] + '_2d_scatter.png'
|
||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||
print(f"2D散点图已保存: {plot_path}")
|
||
|
||
plt.close()
|
||
|
||
def _plot_3d_scatter(self, data, method, explained_variance_ratio, csv_path):
|
||
"""生成3D散点图"""
|
||
fig = plt.figure(figsize=(12, 10))
|
||
ax = fig.add_subplot(111, projection='3d')
|
||
|
||
if self.config.color_by_label and self.labels is not None:
|
||
# 按标签着色
|
||
unique_labels = np.unique(self.labels)
|
||
colors = cm.rainbow(np.linspace(0, 1, len(unique_labels)))
|
||
|
||
for i, label in enumerate(unique_labels):
|
||
mask = self.labels == label
|
||
ax.scatter(data[mask, 0], data[mask, 1], data[mask, 2],
|
||
c=[colors[i]], label=f'Class {label}',
|
||
alpha=0.7, s=50, edgecolors='black', linewidth=0.5)
|
||
ax.legend()
|
||
else:
|
||
# 统一颜色
|
||
ax.scatter(data[:, 0], data[:, 1], data[:, 2],
|
||
alpha=0.7, s=50, c='blue', edgecolors='black', linewidth=0.5)
|
||
|
||
# 设置坐标轴标签
|
||
xlabel = 'PC1'
|
||
ylabel = 'PC2'
|
||
zlabel = 'PC3'
|
||
if explained_variance_ratio is not None and len(explained_variance_ratio) >= 3:
|
||
xlabel = f'PC1 ({explained_variance_ratio[0]:.1%})'
|
||
ylabel = f'PC2 ({explained_variance_ratio[1]:.1%})'
|
||
zlabel = f'PC3 ({explained_variance_ratio[2]:.1%})'
|
||
|
||
ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
|
||
ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
|
||
ax.set_zlabel(zlabel, fontsize=12, fontweight='bold')
|
||
ax.set_title(f'{method.upper()} - 3D Scatter Plot', fontsize=14, fontweight='bold')
|
||
|
||
# 保存图表
|
||
if self.config.save_plots:
|
||
plot_path = os.path.splitext(csv_path)[0] + '_3d_scatter.png'
|
||
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
|
||
print(f"3D散点图已保存: {plot_path}")
|
||
|
||
plt.close()
|
||
|
||
def batch_process(self, methods, n_components_list, output_dir):
    """Run several dimensionality-reduction methods in sequence.

    Parameters:
        methods: list of method names (case-insensitive).
        n_components_list: target dimensionality for each method, paired
            positionally with ``methods`` (extra entries are ignored by zip).
        output_dir: directory where per-method result files are written
            (created if missing).

    Returns:
        dict mapping the lower-cased method name to a dict with keys
        'data', 'bands' and 'n_components'. Methods that raise an
        exception are reported on stdout and omitted from the result.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Print data info and method recommendations up front.
    if self.data is not None:
        n_samples, n_features = self.data.shape
        has_labels = self.labels is not None

        print(f"\n=== 数据信息 ===")
        print(f"数据类型: {self.data_type}")
        print(f"样本数量: {n_samples}")
        print(f"特征数量: {n_features}")
        print(f"是否有标签: {'是' if has_labels else '否'}")

        recommendations = self.get_recommended_methods(
            self.data_type, n_samples, n_features, has_labels
        )

        print(f"\n=== 方法推荐 ===")
        print(f"高度推荐: {', '.join(recommendations['highly_recommended'])}")
        if recommendations['recommended']:
            print(f"推荐: {', '.join(recommendations['recommended'])}")
        if recommendations['use_with_caution']:
            print(f"谨慎使用: {', '.join(recommendations['use_with_caution'])}")
        if recommendations['not_recommended']:
            print(f"不推荐: {', '.join(recommendations['not_recommended'])}")

        # Warn when a user-selected method is discouraged for this data.
        for method in methods:
            method_lower = method.lower()
            method_upper = method_lower.upper()
            if method_upper in recommendations['not_recommended']:
                print(f"⚠️ 警告: {method_lower}方法不适合当前数据类型,建议使用推荐方法")
            elif method_upper in recommendations['use_with_caution']:
                print(f"⚠️ 注意: {method_lower}方法对大数据计算复杂度较高,请确保内存充足")

    results = {}
    for method, n_comp in zip(methods, n_components_list):
        # Normalize the method name to lower case throughout.
        method_lower = method.lower()
        print(f"\n正在处理: {method_lower} (维度: {n_comp})")

        # Snapshot the config BEFORE mutating it. BUGFIX: previously the
        # config was only restored after apply_dim_reduction() succeeded,
        # so a failing method left a stale method/params in self.config
        # for every subsequent iteration. try/finally guarantees restore.
        original_method = self.config.method
        original_n_components = self.config.n_components
        original_method_params = self.config.method_params.copy()

        try:
            self.config.method = method_lower
            self.config.n_components = n_comp
            # Reset method params and apply the current method's defaults.
            self.config.method_params = {}
            self.config._set_method_default_params()

            reduced_data, band_names = self.apply_dim_reduction()

            # Choose the output format based on the input data type.
            if self.data_type == 'image':
                output_path = os.path.join(output_dir, f'reduced_{method_lower}.dat')
            else:  # CSV format
                output_path = os.path.join(output_dir, f'reduced_{method_lower}.csv')

            self.save_results(reduced_data, band_names, output_path, method_lower)

            # Also generate visualizations during batch processing when
            # enabled and the data came from a CSV file. A visualization
            # failure must not abort the whole batch run.
            if self.config.generate_plots and self.data_type == 'csv':
                try:
                    self._generate_visualization_plots(reduced_data, method_lower, output_path)
                except Exception as vis_error:
                    print(f"生成 {method_lower} 可视化时出错: {vis_error}")

            results[method_lower] = {
                'data': reduced_data,
                'bands': band_names,
                'n_components': n_comp
            }

        except Exception as e:
            print(f"处理 {method_lower} 时出错: {str(e)}")
        finally:
            # Always restore the original configuration.
            self.config.method = original_method
            self.config.n_components = original_n_components
            self.config.method_params = original_method_params

    return results
def main():
    """Command-line entry point for the hyperspectral reduction tool.

    Parses arguments, builds a DimensionalityReductionConfig, loads the
    input data and runs either a single reduction method or a batch of
    methods. Returns 0 on success and 1 on failure, so the result is
    suitable for ``sys.exit``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='高光谱数据降维处理工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  1. 单方法降维 (PCA, 3维):
     python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result

  2. 批量降维处理多种方法:
     python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -M pca ica fa -N 3 3 3 -d results

  3. 处理ENVI图像数据:
     python dimensionality_reduction.py image.hdr -m pca -n 5 -o reduced_image

  4. 生成可视化图表:
     python dimensionality_reduction.py input.csv -l Label -s 391.3225 -e 1003.175 -m pca -n 3 -o result --generate-plots --color-by-label

支持的降维方法:
  pca: Principal Component Analysis (PCA)
  ica: Independent Component Analysis (ICA)
  fa: Factor Analysis (FA)
  lda: Linear Discriminant Analysis (LDA)
  mds: Multidimensional Scaling (MDS)
  isomap: Isometric Mapping (Isomap)
  lle: Locally Linear Embedding (LLE)
  t-sne: t-Distributed Stochastic Neighbor Embedding (t-SNE)
"""
    )

    parser.add_argument('input_path', help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')

    # CSV file arguments
    parser.add_argument('-l', '--label-col', help='CSV文件的标签列名')
    parser.add_argument('-s', '--spectral-start', help='CSV文件的谱段起始列名')
    parser.add_argument('-e', '--spectral-end', help='CSV文件的谱段结束列名')

    # Single-method arguments
    parser.add_argument('-m', '--method', default='pca',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='降维方法 (默认: pca)')
    parser.add_argument('-n', '--n-components', type=int, default=3,
                        help='降维后的维度数 (默认: 3)')
    parser.add_argument('-o', '--output', help='输出文件路径或前缀')

    # Batch-processing arguments
    parser.add_argument('-M', '--batch-methods', nargs='+',
                        choices=['pca', 'ica', 'fa', 'lda', 'mds', 'isomap', 'lle', 't-sne'],
                        help='批量处理的降维方法列表')
    parser.add_argument('-N', '--batch-components', nargs='+', type=int,
                        help='对应批量方法的维度数列表')
    parser.add_argument('-d', '--output-dir', help='批量处理输出目录')

    # Visualization arguments
    parser.add_argument('--generate-plots', action='store_true',
                        help='生成可视化图表')
    parser.add_argument('--color-by-label', action='store_true',
                        help='按标签着色散点图')
    parser.add_argument('--save-plots', action='store_true',
                        help='保存图表文件')

    # Misc options
    parser.add_argument('--no-standardization', action='store_true',
                        help='不进行数据标准化')
    parser.add_argument('--show-methods', action='store_true',
                        help='显示所有可用方法的详细信息')

    args = parser.parse_args()

    # Just print method info and exit when requested.
    if args.show_methods:
        HyperspectralDimReduction.print_method_info()
        return 0

    try:
        print("=" * 60)
        print("高光谱数据降维处理工具")
        print("=" * 60)
        print(f"输入文件: {args.input_path}")

        # Single-method vs. batch processing.
        is_batch = args.batch_methods is not None

        if is_batch:
            # BUGFIX: fill in the default component counts (and validate
            # lengths) BEFORE printing, so the summary never shows 'None'.
            if not args.batch_components:
                args.batch_components = [3] * len(args.batch_methods)
            elif len(args.batch_components) != len(args.batch_methods):
                raise ValueError(f"批量方法数量 ({len(args.batch_methods)}) 与维度数量 ({len(args.batch_components)}) 不匹配")

            print(f"批量处理方法: {', '.join(args.batch_methods)}")
            print(f"对应维度: {args.batch_components}")
            print(f"输出目录: {args.output_dir}")
        else:
            print(f"降维方法: {args.method}")
            print(f"目标维度: {args.n_components}")
            print(f"输出路径: {args.output}")

        if args.label_col:
            print(f"标签列: {args.label_col}")
        if args.spectral_start:
            print(f"谱段范围: {args.spectral_start} - {args.spectral_end}")
        print()

        # Build the configuration object.
        config = DimensionalityReductionConfig(
            input_path=args.input_path,
            label_col=args.label_col,
            spectral_start=args.spectral_start,
            spectral_end=args.spectral_end,
            output_path=args.output,
            output_dir=args.output_dir,
            use_standardization=not args.no_standardization,
            generate_plots=args.generate_plots,
            color_by_label=args.color_by_label,
            save_plots=args.save_plots
        )

        # Create the processor and load the data.
        processor = HyperspectralDimReduction(config)
        print("正在加载数据...")
        processor.load_data()

        if is_batch:
            print("\n开始批量降维处理...")
            results = processor.batch_process(args.batch_methods, args.batch_components, args.output_dir or 'reduced_data_batch')

            # Print a result summary.
            print(f"\n=== 批量处理完成 ===")
            print(f"共处理了 {len(results)} 种降维方法:")
            for method_name, result in results.items():
                print(f"- {method_name}: {result['n_components']} 维")

            # Persist the summary as CSV. NOTE: pandas is already imported
            # at module level as `pd`; the redundant local import is gone.
            summary_data = []
            for method_name, result in results.items():
                summary_data.append({
                    'method': method_name,
                    'n_components': result['n_components'],
                    'data_shape': str(result['data'].shape),
                    'output_file': f'reduced_{method_name}'
                })

            summary_df = pd.DataFrame(summary_data)
            summary_path = os.path.join(args.output_dir or 'reduced_data_batch', 'batch_processing_summary.csv')
            summary_df.to_csv(summary_path, index=False)
            print(f"批量处理汇总已保存到: {summary_path}")

        else:
            # Single-method processing.
            print("\n开始降维处理...")
            config.method = args.method
            config.n_components = args.n_components
            config.output_path = args.output

            reduced_data, band_names = processor.apply_dim_reduction()

            print("✓ 降维处理完成")
            print(f"原始数据维度: {processor.data.shape if processor.data is not None else 'Unknown'}")
            print(f"降维后数据维度: {reduced_data.shape}")

            # Save results only when an output path was given.
            if args.output:
                processor.save_results(reduced_data, band_names, args.output, args.method)

        print("\n" + "=" * 60)
        print("处理完成!")
        print("=" * 60)

    except Exception as e:
        print(f"✗ 处理失败: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the site-provided exit() builtin: exit() is
    # intended for interactive sessions and may be unavailable in frozen
    # or -S interpreters. sys.exit propagates main()'s int return code.
    sys.exit(main())