增加模块;增加主调用命令
This commit is contained in:
59
Feature_Selection_method/GA.py
Normal file
59
Feature_Selection_method/GA.py
Normal file
@ -0,0 +1,59 @@
|
||||
from deap import base, creator, tools, algorithms
|
||||
import numpy as np
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.model_selection import cross_val_score
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
|
||||
def GA(X, y, n_generations=20, population_size=50, crossover_prob=0.7, mutation_prob=0.2):
|
||||
"""
|
||||
使用遗传算法进行特征选择,返回选择的特征索引。
|
||||
|
||||
参数:
|
||||
X (ndarray): 特征矩阵
|
||||
y (ndarray): 标签
|
||||
n_generations (int): 迭代次数
|
||||
population_size (int): 种群大小
|
||||
crossover_prob (float): 交叉概率
|
||||
mutation_prob (float): 变异概率
|
||||
|
||||
返回:
|
||||
list: 选择的特征索引
|
||||
"""
|
||||
# 创建适应度和个体
|
||||
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
|
||||
creator.create("Individual", list, fitness=creator.FitnessMax)
|
||||
|
||||
toolbox = base.Toolbox()
|
||||
toolbox.register("attr_bool", lambda: np.random.randint(0, 2))
|
||||
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attr_bool, n=X.shape[1])
|
||||
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
|
||||
|
||||
# 定义适应度函数
|
||||
def evaluate(individual):
|
||||
selected_features = [index for index, val in enumerate(individual) if val == 1]
|
||||
if not selected_features:
|
||||
return 0, # 没有特征时适应度为 0
|
||||
X_selected = X[:, selected_features]
|
||||
clf = RandomForestClassifier(random_state=42)
|
||||
score = cross_val_score(clf, X_selected, y, cv=5).mean() # 5 折交叉验证
|
||||
return score,
|
||||
|
||||
toolbox.register("evaluate", evaluate)
|
||||
toolbox.register("mate", tools.cxTwoPoint)
|
||||
toolbox.register("mutate", tools.mutFlipBit, indpb=0.05)
|
||||
toolbox.register("select", tools.selTournament, tournsize=3)
|
||||
|
||||
# 初始化种群
|
||||
population = toolbox.population(n=population_size)
|
||||
|
||||
# 运行遗传算法
|
||||
result_population, _ = algorithms.eaSimple(population, toolbox, cxpb=crossover_prob,
|
||||
mutpb=mutation_prob, ngen=n_generations,
|
||||
verbose=False)
|
||||
|
||||
# 获取最优个体
|
||||
best_individual = tools.selBest(result_population, k=1)[0]
|
||||
selected_features = [index for index, val in enumerate(best_individual) if val == 1]
|
||||
|
||||
return selected_features
|
||||
Reference in New Issue
Block a user