增加模块;增加主调用命令
This commit is contained in:
594
Feature_Selection_method/feture_select.py
Normal file
594
Feature_Selection_method/feture_select.py
Normal file
@ -0,0 +1,594 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from Feature_Selection_method.Lar import Lar
|
||||
from Feature_Selection_method.Spa import SPA
|
||||
from Feature_Selection_method.Uve import UVE
|
||||
from Feature_Selection_method.Cars import CARS_Cloud
|
||||
from Feature_Selection_method.GA import GA
|
||||
from Feature_Selection_method.ReliefF import ReliefF
|
||||
from Feature_Selection_method.random_fog import shuffled_frog_leaping_selection
|
||||
from Feature_Selection_method.sipls import sipls_feature_selection
|
||||
from sklearn.model_selection import train_test_split
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Optional, Union, List, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
def _get_x_axis_values(feature_names: List[str]) -> Tuple[Optional[np.ndarray], str]:
|
||||
"""
|
||||
从特征名称中提取x轴数值(通常是波长)
|
||||
|
||||
Args:
|
||||
feature_names: 特征名称列表
|
||||
|
||||
Returns:
|
||||
(x_values, x_label): x轴数值数组和标签,如果无法提取则返回(None, "")
|
||||
"""
|
||||
if not feature_names:
|
||||
return None, ""
|
||||
|
||||
# 尝试从列名中提取数值
|
||||
x_values = []
|
||||
for name in feature_names:
|
||||
try:
|
||||
# 尝试将列名转换为浮点数
|
||||
if isinstance(name, (int, float)):
|
||||
x_values.append(float(name))
|
||||
elif isinstance(name, str):
|
||||
# 尝试提取字符串中的数值
|
||||
# 处理类似 "400.5", "Band_400", "Wavelength_400.5nm" 的格式
|
||||
import re
|
||||
# 查找浮点数模式
|
||||
match = re.search(r'(\d+\.?\d*)', str(name))
|
||||
if match:
|
||||
x_values.append(float(match.group(1)))
|
||||
else:
|
||||
# 如果找不到数值,返回None
|
||||
return None, ""
|
||||
else:
|
||||
return None, ""
|
||||
except (ValueError, TypeError):
|
||||
return None, ""
|
||||
|
||||
# 检查是否所有值都是唯一的(避免重复的波长)
|
||||
if len(set(x_values)) != len(x_values):
|
||||
return None, ""
|
||||
|
||||
# 检查波长范围是否合理(假设是nm单位,范围在200-2500nm之间)
|
||||
x_array = np.array(x_values)
|
||||
if np.min(x_array) < 200 or np.max(x_array) > 2500:
|
||||
return None, ""
|
||||
|
||||
# 确定标签
|
||||
x_label = "Wavelength (nm)"
|
||||
|
||||
return x_array, x_label
|
||||
|
||||
|
||||
def plot_feature_selection_results(X: Union[pd.DataFrame, np.ndarray],
|
||||
selected_indices: Union[List[int], np.ndarray],
|
||||
method_name: str,
|
||||
save_path: Optional[str] = None,
|
||||
figsize: Tuple[int, int] = (12, 6)) -> plt.Figure:
|
||||
"""
|
||||
绘制特征选择结果的可视化图
|
||||
|
||||
Args:
|
||||
X: 特征数据矩阵 (n_samples, n_features)
|
||||
selected_indices: 选择的特征索引列表
|
||||
method_name: 特征选择方法名称
|
||||
save_path: 图片保存路径,如果为None则不保存
|
||||
figsize: 图片尺寸
|
||||
|
||||
Returns:
|
||||
matplotlib Figure对象
|
||||
"""
|
||||
# 转换为numpy数组
|
||||
if isinstance(X, pd.DataFrame):
|
||||
X_array = X.values
|
||||
feature_names = X.columns.tolist()
|
||||
else:
|
||||
X_array = X
|
||||
feature_names = [f"Feature_{i}" for i in range(X.shape[1])]
|
||||
|
||||
# 计算平均光谱
|
||||
mean_spectrum = np.mean(X_array, axis=0)
|
||||
n_features = X_array.shape[1]
|
||||
|
||||
# 创建x轴 - 尝试使用波长值而不是索引
|
||||
x_values, x_label = _get_x_axis_values(feature_names)
|
||||
if x_values is None:
|
||||
# 如果无法提取波长值,使用特征索引
|
||||
x_values = np.arange(n_features)
|
||||
x_label = "Feature Index"
|
||||
|
||||
# 创建图形
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# 绘制平均光谱曲线
|
||||
ax.plot(x_values, mean_spectrum, 'b-', linewidth=1.5, alpha=0.8, label='Mean Spectrum')
|
||||
|
||||
# 标注选择的特征点
|
||||
if len(selected_indices) > 0:
|
||||
# 确保selected_indices是有效的numpy数组
|
||||
selected_indices = np.asarray(selected_indices, dtype=int)
|
||||
|
||||
# 检查索引范围
|
||||
valid_indices = selected_indices[(selected_indices >= 0) & (selected_indices < len(x_values))]
|
||||
|
||||
if len(valid_indices) > 0:
|
||||
selected_x = x_values[valid_indices]
|
||||
selected_y = mean_spectrum[valid_indices]
|
||||
|
||||
ax.scatter(selected_x, selected_y, color='red', s=60, alpha=0.9,
|
||||
edgecolors='darkred', linewidth=1.5, label='Selected Features', zorder=5)
|
||||
|
||||
# 添加选择的特征数量信息
|
||||
ax.text(0.02, 0.98, f'Selected: {len(selected_indices)}/{n_features} features',
|
||||
transform=ax.transAxes, fontsize=10, verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
|
||||
|
||||
# 设置标题和标签
|
||||
ax.set_title(f'Feature Selection Results - {method_name}', fontsize=14, fontweight='bold')
|
||||
ax.set_xlabel(x_label, fontsize=12)
|
||||
ax.set_ylabel('Intensity', fontsize=12)
|
||||
|
||||
# 设置网格和图例
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(loc='upper right', fontsize=10)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
print(f"Visualization saved to: {save_path}")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureSelectionConfig:
|
||||
"""特征选择配置类"""
|
||||
# CSV文件相关配置
|
||||
csv_file_path: Optional[str] = None
|
||||
label_column: Optional[str] = None
|
||||
spectral_columns: Optional[List[str]] = None
|
||||
|
||||
# 特征选择方法配置
|
||||
method: str = "None"
|
||||
method_params: dict = field(default_factory=dict)
|
||||
|
||||
# 输出配置
|
||||
output_csv: bool = False
|
||||
output_dir: str = ""
|
||||
output_filename: str = "selected_features"
|
||||
|
||||
# 可视化配置
|
||||
save_plots: bool = True
|
||||
plot_name_prefix: str = ""
|
||||
plot_dir: Optional[str] = None # 可视化图片保存目录,如果为None则使用output_dir
|
||||
|
||||
def __post_init__(self):
|
||||
"""参数校验和默认值设置"""
|
||||
if self.csv_file_path and not os.path.exists(self.csv_file_path):
|
||||
raise FileNotFoundError(f"CSV文件不存在: {self.csv_file_path}")
|
||||
|
||||
if self.csv_file_path and not self.label_column:
|
||||
raise ValueError("指定CSV文件时必须提供标签列名(label_column)")
|
||||
|
||||
if self.csv_file_path and not self.spectral_columns:
|
||||
raise ValueError("指定CSV文件时必须提供光谱列名列表(spectral_columns)")
|
||||
|
||||
# 设置默认的方法参数
|
||||
self._set_default_method_params()
|
||||
|
||||
def _set_default_method_params(self):
|
||||
"""根据方法设置默认参数"""
|
||||
if self.method == "Cars":
|
||||
self.method_params.setdefault('N', 50)
|
||||
self.method_params.setdefault('f', 20)
|
||||
self.method_params.setdefault('cv', 10)
|
||||
elif self.method == "Uve":
|
||||
self.method_params.setdefault('ncomp', 20)
|
||||
self.method_params.setdefault('cv', 5)
|
||||
elif self.method == "Spa":
|
||||
self.method_params.setdefault('m_min', 2)
|
||||
self.method_params.setdefault('m_max', 50)
|
||||
self.method_params.setdefault('autoscaling', 1)
|
||||
elif self.method == "GA":
|
||||
self.method_params.setdefault('population_size', 10)
|
||||
elif self.method == "ReliefF":
|
||||
self.method_params.setdefault('n_neighbors', 20)
|
||||
self.method_params.setdefault('n_features_to_keep', 20)
|
||||
elif self.method == "RandomFrog":
|
||||
self.method_params.setdefault('n_frogs', 50)
|
||||
self.method_params.setdefault('n_memeplexes', 5)
|
||||
self.method_params.setdefault('n_evolution_steps', 10)
|
||||
self.method_params.setdefault('n_shuffle_iterations', 10)
|
||||
self.method_params.setdefault('cv', 5)
|
||||
elif self.method == "SiPLS":
|
||||
self.method_params.setdefault('n_intervals_list', [10, 15, 20])
|
||||
self.method_params.setdefault('n_combinations_list', [2, 3, 4])
|
||||
self.method_params.setdefault('max_components', 15)
|
||||
self.method_params.setdefault('cv_folds', 5)
|
||||
|
||||
|
||||
class SpectrumFeatureSelector:
|
||||
"""光谱特征选择器"""
|
||||
|
||||
def __init__(self, config: FeatureSelectionConfig):
|
||||
self.config = config
|
||||
|
||||
def load_csv_data(self) -> Tuple[pd.DataFrame, np.ndarray]:
|
||||
"""从CSV文件加载数据"""
|
||||
if not self.config.csv_file_path:
|
||||
raise ValueError("未指定CSV文件路径")
|
||||
|
||||
df = pd.read_csv(self.config.csv_file_path)
|
||||
|
||||
# 验证列是否存在
|
||||
if self.config.label_column not in df.columns:
|
||||
raise ValueError(f"标签列 '{self.config.label_column}' 不存在于CSV文件中")
|
||||
|
||||
missing_cols = [col for col in self.config.spectral_columns if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"以下光谱列不存在于CSV文件中: {missing_cols}")
|
||||
|
||||
# 提取特征和标签
|
||||
X = df[self.config.spectral_columns]
|
||||
y = df[self.config.label_column].values
|
||||
|
||||
return X, y
|
||||
|
||||
def save_selected_features_csv(self, X_selected: pd.DataFrame, y: np.ndarray,
|
||||
selected_columns: Union[List[str], np.ndarray]):
|
||||
"""保存选定的特征到CSV文件"""
|
||||
if not self.config.output_csv:
|
||||
return
|
||||
|
||||
os.makedirs(self.config.output_dir, exist_ok=True)
|
||||
|
||||
# 创建结果DataFrame
|
||||
if isinstance(selected_columns, np.ndarray):
|
||||
selected_col_names = [f"feature_{i}" for i in selected_columns]
|
||||
else:
|
||||
selected_col_names = selected_columns
|
||||
|
||||
result_df = pd.DataFrame(X_selected.values, columns=selected_col_names)
|
||||
result_df[self.config.label_column] = y
|
||||
|
||||
output_path = os.path.join(self.config.output_dir,
|
||||
f"{self.config.output_filename}.csv")
|
||||
result_df.to_csv(output_path, index=False)
|
||||
print(f"Selected features saved to: {output_path}")
|
||||
|
||||
def plot_feature_selection(self, X: pd.DataFrame,
|
||||
selected_indices: Union[List[int], np.ndarray]) -> Optional[plt.Figure]:
|
||||
"""绘制特征选择结果可视化"""
|
||||
if not self.config.save_plots:
|
||||
return None
|
||||
|
||||
# 确定保存目录
|
||||
plot_dir = self.config.plot_dir if self.config.plot_dir else self.config.output_dir
|
||||
if not plot_dir:
|
||||
return None
|
||||
|
||||
os.makedirs(plot_dir, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
filename = f"{self.config.plot_name_prefix}_{self.config.method}_feature_selection.png"
|
||||
save_path = os.path.join(plot_dir, filename)
|
||||
|
||||
# 绘制可视化图
|
||||
fig = plot_feature_selection_results(
|
||||
X=X,
|
||||
selected_indices=selected_indices,
|
||||
method_name=self.config.method,
|
||||
save_path=save_path
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def _convert_to_indices(self, X: pd.DataFrame, selected_columns) -> List[int]:
|
||||
"""
|
||||
将selected_columns转换为原始DataFrame X的索引列表
|
||||
|
||||
Args:
|
||||
X: 原始DataFrame
|
||||
selected_columns: 选择的列,可以是索引数组、列名列表等
|
||||
|
||||
Returns:
|
||||
索引列表
|
||||
"""
|
||||
try:
|
||||
# 处理pandas Index对象
|
||||
if hasattr(selected_columns, 'tolist'): # pandas Index or Series
|
||||
selected_columns = selected_columns.tolist()
|
||||
|
||||
if isinstance(selected_columns, np.ndarray):
|
||||
# 如果是numpy数组,直接作为索引
|
||||
return selected_columns.tolist()
|
||||
elif isinstance(selected_columns, list) and len(selected_columns) > 0:
|
||||
if isinstance(selected_columns[0], str):
|
||||
# 如果是列名列表,转换为索引
|
||||
indices = []
|
||||
for col in selected_columns:
|
||||
try:
|
||||
# 首先尝试精确匹配
|
||||
idx = X.columns.get_loc(col)
|
||||
indices.append(idx)
|
||||
except KeyError:
|
||||
# 如果精确匹配失败,尝试数值近似匹配(处理小数点精度问题)
|
||||
try:
|
||||
target_value = float(col)
|
||||
# 找到最接近的列名
|
||||
best_match = None
|
||||
best_diff = float('inf')
|
||||
best_idx = None
|
||||
|
||||
for i, col_name in enumerate(X.columns):
|
||||
try:
|
||||
col_value = float(col_name)
|
||||
diff = abs(col_value - target_value)
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_match = col_name
|
||||
best_idx = i
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
if best_match is not None and best_diff < 1.0: # 允许1.0以内的误差
|
||||
print(f"Approximate match: '{col}' -> '{best_match}' (diff: {best_diff:.3f})")
|
||||
indices.append(best_idx)
|
||||
else:
|
||||
print(f"Warning: No suitable match found for column '{col}' in DataFrame columns")
|
||||
continue
|
||||
except (ValueError, TypeError):
|
||||
print(f"Warning: Cannot parse column name '{col}' as numeric")
|
||||
continue
|
||||
return indices
|
||||
else:
|
||||
# 如果是数字列表,直接作为索引
|
||||
return [int(idx) for idx in selected_columns]
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Error converting selected_columns to indices: {e}")
|
||||
return []
|
||||
|
||||
def select_features(self, X: Optional[pd.DataFrame] = None, y: Optional[np.ndarray] = None,
|
||||
column_names: Optional[List[str]] = None) -> Tuple[pd.DataFrame, np.ndarray, Union[List[str], np.ndarray]]:
|
||||
"""
|
||||
执行特征选择
|
||||
|
||||
Args:
|
||||
X: 特征数据,如果为None则从CSV文件加载
|
||||
y: 标签数据,如果为None则从CSV文件加载
|
||||
column_names: 列名,用于numpy数组输入
|
||||
|
||||
Returns:
|
||||
X_selected: 选定的特征数据
|
||||
y: 标签数据
|
||||
selected_columns: 选定的列名或索引
|
||||
"""
|
||||
# 如果没有提供数据,从CSV加载
|
||||
if X is None or y is None:
|
||||
X, y = self.load_csv_data()
|
||||
|
||||
# 确保X是DataFrame格式
|
||||
if isinstance(X, np.ndarray):
|
||||
if column_names is not None:
|
||||
X = pd.DataFrame(X, columns=column_names)
|
||||
else:
|
||||
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
|
||||
|
||||
# 执行特征选择
|
||||
X_selected, y_selected, selected_columns = SpctrumFeatureSelcet(
|
||||
method=self.config.method,
|
||||
X=X,
|
||||
y=y,
|
||||
name=self.config.plot_name_prefix,
|
||||
result_dir=self.config.output_dir if self.config.save_plots else '',
|
||||
column_names=None # 已经转换为DataFrame,不再需要column_names
|
||||
)
|
||||
|
||||
# 保存结果到CSV(如果配置了)
|
||||
self.save_selected_features_csv(X_selected, y_selected, selected_columns)
|
||||
|
||||
# 生成可视化图(如果配置了)
|
||||
if self.config.save_plots:
|
||||
# 转换selected_columns为原始数据集X中的索引列表
|
||||
# selected_columns对应X_selected中的列,我们需要找到它们在原始数据集X中的位置
|
||||
selected_indices = self._convert_to_indices(X, selected_columns)
|
||||
|
||||
if len(selected_indices) > 0:
|
||||
self.plot_feature_selection(X, selected_indices)
|
||||
else:
|
||||
print(f"Warning: No valid indices found for plotting. selected_columns: {selected_columns}")
|
||||
print(f"Available columns in X: {list(X.columns[:5])}...") # 显示前5个列名用于调试
|
||||
|
||||
return X_selected, y_selected, selected_columns
|
||||
|
||||
|
||||
def SpctrumFeatureSelcet(method, X, y, name='', result_dir='', column_names=None, method_params=None):
|
||||
"""
|
||||
核心特征选择函数(保持原有业务逻辑不变)
|
||||
|
||||
:param method: 波长筛选/降维的方法,包括:Cars, Lars, Uve, Spa, GA, ReliefF, RandomFrog, SiPLS。
|
||||
:param X: 光谱数据,可以是 pandas DataFrame 或 numpy array (n_samples, n_features)。
|
||||
:param y: 光谱数据对应的标签 (n_samples,)。
|
||||
:param name: 结果图像的文件名。
|
||||
:param result_dir: 保存结果的文件夹路径。
|
||||
:param column_names: 如果 X 是 numpy array,需要提供列名列表。
|
||||
:param method_params: 方法特定的参数字典。
|
||||
:return:
|
||||
- X_Feature: 选择/降维后的数据 (n_samples, n_features)。
|
||||
- y: 对应的标签。
|
||||
- selected_columns: 选择的特征列名或索引。
|
||||
"""
|
||||
if method_params is None:
|
||||
method_params = {}
|
||||
|
||||
global X_Feature
|
||||
|
||||
# 判断输入数据类型并转换为 DataFrame(如有必要)
|
||||
if isinstance(X, np.ndarray):
|
||||
if column_names is None:
|
||||
column_names = [f"{i}" for i in range(X.shape[1])] # 默认列名
|
||||
X_df = pd.DataFrame(X, columns=column_names)
|
||||
else:
|
||||
X_df = X
|
||||
|
||||
# 根据所选方法执行特征选择
|
||||
if method == "None":
|
||||
X_Feature = X_df
|
||||
selected_columns = X_df.columns
|
||||
elif method == "Cars":
|
||||
save_path = os.path.join(result_dir, f"{name}_cars.png") if result_dir else None
|
||||
# 调用 CARS_Cloud 并获取结果,使用配置的参数
|
||||
N = method_params.get('N', 50)
|
||||
f = method_params.get('f', 20)
|
||||
cv = method_params.get('cv', 10)
|
||||
|
||||
Featuresecletidx = CARS_Cloud(X_df.values, y, N=N, f=f, cv=cv,
|
||||
save_fig=bool(save_path), save_path=save_path)
|
||||
Featuresecletidx = Featuresecletidx.astype(int)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = Featuresecletidx
|
||||
|
||||
elif method == "Lars":
|
||||
Featuresecletidx = Lar(X_df.values, y)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "Uve":
|
||||
ncomp = method_params.get('ncomp', 20)
|
||||
cv = method_params.get('cv', 5)
|
||||
|
||||
uve = UVE(X_df.values, y, ncomp)
|
||||
uve.calcCriteria()
|
||||
uve.evalCriteria(cv=cv)
|
||||
Featuresecletidx = uve.cutFeature() # 返回所选特征的索引
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "Spa":
|
||||
save_path = os.path.join(result_dir, f"{name}_spa.png") if result_dir else None
|
||||
|
||||
Xcal, Xval, ycal, yval = train_test_split(X_df, y, test_size=0.3)
|
||||
|
||||
m_min = method_params.get('m_min', 2)
|
||||
m_max = method_params.get('m_max', 50)
|
||||
autoscaling = method_params.get('autoscaling', 1)
|
||||
|
||||
Featuresecletidx, var_sel_phase2 = SPA().spa(
|
||||
Xcal, ycal, m_min=m_min, m_max=m_max, Xval=Xval, yval=yval,
|
||||
autoscaling=autoscaling, save_path=save_path)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "GA":
|
||||
population_size = method_params.get('population_size', 10)
|
||||
Featuresecletidx = GA(X_df.values, y, population_size)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "ReliefF":
|
||||
n_neighbors = method_params.get('n_neighbors', 20)
|
||||
n_features_to_keep = method_params.get('n_features_to_keep', 20)
|
||||
|
||||
relieff = ReliefF(n_neighbors=n_neighbors, n_features_to_keep=n_features_to_keep)
|
||||
Featuresecletidx = relieff.fit(X_df.values, y)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "RandomFrog":
|
||||
n_frogs = method_params.get('n_frogs', 50)
|
||||
n_memeplexes = method_params.get('n_memeplexes', 5)
|
||||
n_evolution_steps = method_params.get('n_evolution_steps', 10)
|
||||
n_shuffle_iterations = method_params.get('n_shuffle_iterations', 10)
|
||||
cv = method_params.get('cv', 5)
|
||||
|
||||
Featuresecletidx = shuffled_frog_leaping_selection(
|
||||
X_df.values, y,
|
||||
n_frogs=n_frogs,
|
||||
n_memeplexes=n_memeplexes,
|
||||
n_evolution_steps=n_evolution_steps,
|
||||
n_shuffle_iterations=n_shuffle_iterations,
|
||||
cv=cv
|
||||
)
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
elif method == "SiPLS":
|
||||
n_intervals_list = method_params.get('n_intervals_list', [10, 15, 20])
|
||||
n_combinations_list = method_params.get('n_combinations_list', [2, 3, 4])
|
||||
max_components = method_params.get('max_components', 15)
|
||||
cv_folds = method_params.get('cv_folds', 5)
|
||||
|
||||
result = sipls_feature_selection(
|
||||
X_df.values, y,
|
||||
n_intervals_list=n_intervals_list,
|
||||
n_combinations_list=n_combinations_list,
|
||||
max_components=max_components,
|
||||
cv_folds=cv_folds
|
||||
)
|
||||
|
||||
if result and 'selected_wavelengths' in result:
|
||||
Featuresecletidx = result['selected_wavelengths']
|
||||
X_Feature = X_df.iloc[:, Featuresecletidx]
|
||||
selected_columns = X_df.columns[Featuresecletidx]
|
||||
else:
|
||||
raise ValueError("SiPLS算法未能找到有效的特征选择结果")
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的特征选择方法: {method}。支持的方法包括: None, Cars, Lars, Uve, Spa, GA, ReliefF, RandomFrog, SiPLS")
|
||||
|
||||
return X_Feature, y, selected_columns # 返回所选特征数据、标签和列名
|
||||
|
||||
|
||||
# 便捷函数,用于向后兼容和简化使用
|
||||
def select_features_from_csv(config: FeatureSelectionConfig) -> Tuple[pd.DataFrame, np.ndarray, Union[List[str], np.ndarray]]:
|
||||
"""
|
||||
从CSV文件进行特征选择的主要接口函数
|
||||
|
||||
Args:
|
||||
config: 特征选择配置对象
|
||||
|
||||
Returns:
|
||||
X_selected: 选定的特征数据
|
||||
y: 标签数据
|
||||
selected_columns: 选定的列名或索引
|
||||
"""
|
||||
selector = SpectrumFeatureSelector(config)
|
||||
return selector.select_features()
|
||||
|
||||
|
||||
def select_features_from_data(X: pd.DataFrame, y: np.ndarray, method: str,
|
||||
method_params: Optional[dict] = None,
|
||||
name: str = '', result_dir: str = '',
|
||||
column_names: Optional[List[str]] = None) -> Tuple[pd.DataFrame, np.ndarray, Union[List[str], np.ndarray]]:
|
||||
"""
|
||||
直接从数据进行特征选择的便捷函数
|
||||
|
||||
Args:
|
||||
X: 特征数据
|
||||
y: 标签数据
|
||||
method: 特征选择方法
|
||||
method_params: 方法参数
|
||||
name: 输出文件名前缀
|
||||
result_dir: 输出目录
|
||||
column_names: 列名
|
||||
|
||||
Returns:
|
||||
X_selected: 选定的特征数据
|
||||
y: 标签数据
|
||||
selected_columns: 选定的列名或索引
|
||||
"""
|
||||
config = FeatureSelectionConfig(
|
||||
method=method,
|
||||
method_params=method_params or {},
|
||||
output_csv=False, # 直接数据输入不输出CSV
|
||||
save_plots=bool(result_dir),
|
||||
plot_name_prefix=name
|
||||
)
|
||||
|
||||
selector = SpectrumFeatureSelector(config)
|
||||
return selector.select_features(X=X, y=y, column_names=column_names)
|
||||
Reference in New Issue
Block a user