Files
micro_plastic/classification_model/WaveSelect/WaveSelcet.py
2026-02-25 09:42:51 +08:00

94 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
from classification_model.WaveSelect.Lar import Lar
from classification_model.WaveSelect.Spa import SPA
from classification_model.WaveSelect.Spa_acc import SPA_acc
from classification_model.WaveSelect.Uve import UVE
from classification_model.WaveSelect.Cars import CARS_Cloud
from classification_model.WaveSelect.Pca import Pca
from classification_model.WaveSelect.GA import GA
from classification_model.WaveSelect.ReliefF import ReliefF
from sklearn.model_selection import train_test_split
# from WaveSelect.MRMR import MRMRFeatureSelection
import os
import matplotlib.pyplot as plt
def SpctrumFeatureSelcet(method, X, y, name='', result_dir='', column_names=None):
    """Select or reduce spectral features with the requested method.

    :param method: Name of the wavelength-selection / dimensionality-reduction
        method. Recognised values are "None", "Cars", "Lars", "Uve", "Spa",
        "Spa_acc", "GA", "Pca" and "ReliefF".
    :param X: Spectral data, either a pandas DataFrame or an array-like of
        shape (n_samples, n_features).
    :param y: Labels for the rows of X, shape (n_samples,).
    :param name: Prefix of the figure file name used by the methods that are
        given a ``save_path`` ("Cars", "Spa", "Spa_acc").
    :param result_dir: Directory that is joined with the figure file name.
    :param column_names: Column names to use when X is not a DataFrame.
        Defaults to "0", "1", ... when omitted.
    :return: A 3-tuple ``(X_Feature, y, selected_columns)``:
        - X_Feature: the selected / reduced data (n_samples, n_features).
        - y: the labels, passed through unchanged.
        - selected_columns: labels of the selected columns. For "Cars" these
          are the positional indices instead, and for "Pca" they are the
          generated names "PC1", "PC2", ...
        If ``method`` is not recognised, a message is printed and
        ``(None, None, None)`` is returned.
    """
    # Normalise the input to a DataFrame so every branch can rely on
    # .values / .iloc / .columns.
    if isinstance(X, pd.DataFrame):
        X_df = X
    else:
        X_arr = np.asarray(X)
        if column_names is None:
            column_names = [f"{i}" for i in range(X_arr.shape[1])]  # default names
        X_df = pd.DataFrame(X_arr, columns=column_names)

    if method == "None":
        # No selection: hand the data back untouched.
        X_Feature = X_df
        selected_columns = X_df.columns
    elif method == "Cars":
        save_path = os.path.join(result_dir, f"{name}_cars.png")
        Featuresecletidx = CARS_Cloud(
            X_df.values, y, N=50, f=20, cv=10, save_fig=True, save_path=save_path)
        # Cast to int so the result can be used as positional indices in .iloc.
        Featuresecletidx = Featuresecletidx.astype(int)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        # NOTE(review): unlike the other branches this returns positional
        # indices rather than column labels; kept as-is for existing callers.
        selected_columns = Featuresecletidx
    elif method == "Lars":
        Featuresecletidx = Lar(X_df.values, y)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    elif method == "Uve":
        uve = UVE(X_df.values, y, 20)
        uve.calcCriteria()
        uve.evalCriteria(cv=5)
        Featuresecletidx = uve.cutFeature()  # indices of the selected features
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    elif method == "Spa":
        save_path = os.path.join(result_dir, f"{name}_spa.png")
        # NOTE: the split has no fixed random_state, so the selection may
        # differ between runs.
        Xcal, Xval, ycal, yval = train_test_split(X_df, y, test_size=0.3)
        Featuresecletidx, _ = SPA().spa(
            Xcal, ycal, m_min=2, m_max=50, Xval=Xval, yval=yval, autoscaling=1,
            save_path=save_path)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    elif method == "Spa_acc":
        save_path = os.path.join(result_dir, f"{name}_spa_acc.png")
        # NOTE: the split has no fixed random_state, so the selection may
        # differ between runs.
        Xcal, Xval, ycal, yval = train_test_split(X_df, y, test_size=0.3)
        Featuresecletidx, _ = SPA_acc().spa(
            Xcal, ycal, m_min=2, m_max=50, Xval=Xval, yval=yval, autoscaling=1,
            save_path=save_path)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    elif method == "GA":
        Featuresecletidx = GA(X_df.values, y, 10)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    elif method == "Pca":
        # Pca produces new components, so there are no original columns to
        # report; synthesise component names instead.
        X_Feature = Pca(X_df.values)
        selected_columns = [f"PC{i+1}" for i in range(X_Feature.shape[1])]
    elif method == "ReliefF":
        relieff = ReliefF(n_neighbors=20, n_features_to_keep=20)
        Featuresecletidx = relieff.fit(X_df.values, y)
        X_Feature = X_df.iloc[:, Featuresecletidx]
        selected_columns = X_df.columns[Featuresecletidx]
    else:
        print("没有这个波长筛选方法!")
        # Same arity as the success path so callers can always unpack three
        # values and test the first one for None.
        return None, None, None

    return X_Feature, y, selected_columns