60 lines
2.3 KiB
Python
60 lines
2.3 KiB
Python
from deap import base, creator, tools, algorithms
|
|
import numpy as np
|
|
from sklearn.datasets import make_classification
|
|
from sklearn.model_selection import cross_val_score
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
|
|
def GA(X, y, n_generations=20, population_size=50, crossover_prob=0.7,
       mutation_prob=0.2, indpb=0.05, cv=5, random_state=None):
    """Select features with a genetic algorithm; return the chosen column indices.

    Each individual is a 0/1 mask over the columns of ``X`` (1 = keep the
    column). Its fitness is the mean ``cross_val_score`` of a
    ``RandomForestClassifier(random_state=42)`` trained on the kept columns
    with ``cv`` folds; a mask that keeps no column scores 0. Evolution runs
    ``algorithms.eaSimple`` with two-point crossover, bit-flip mutation and
    size-3 tournament selection.

    Parameters:
        X (ndarray): feature matrix; columns are selected via ``X[:, cols]``.
        y (ndarray): labels.
        n_generations (int): number of generations.
        population_size (int): number of individuals; must be >= 1.
        crossover_prob (float): probability that a pair of offspring is mated.
        mutation_prob (float): probability that an offspring is mutated.
        indpb (float): per-bit flip probability when an individual is
            mutated (previously hard-coded to 0.05).
        cv (int): number of cross-validation folds (previously hard-coded
            to 5).
        random_state (int | None): if given, seeds Python's ``random``
            module (used by the DEAP operators) and NumPy's global RNG (used
            to build the initial population) so runs are reproducible. This
            reseeds global RNG state. ``None`` leaves both RNGs untouched.

    Returns:
        list: indices of the features kept by the fittest individual seen in
        any generation (not only the final one). Empty if that individual
        keeps no feature.

    Raises:
        ValueError: if ``population_size`` is less than 1.
    """
    import random

    if population_size < 1:
        raise ValueError("population_size must be at least 1")

    if random_state is not None:
        random.seed(random_state)
        np.random.seed(random_state)

    # creator.create registers classes on the module-global `creator`;
    # re-creating them on every call triggers a RuntimeWarning and replaces
    # the classes, so only create them the first time.
    if not hasattr(creator, "FitnessMax"):
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.FitnessMax)

    n_features = X.shape[1]

    toolbox = base.Toolbox()
    # Store plain ints (not np.int64) so masks are simple to hash and compare.
    toolbox.register("attr_bool", lambda: int(np.random.randint(0, 2)))
    toolbox.register("individual", tools.initRepeat, creator.Individual,
                     toolbox.attr_bool, n=n_features)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    # The classifier uses a fixed random_state and an integer `cv` gives
    # unshuffled folds, so a given mask always gets the same score. Cache it
    # to avoid refitting the forest for duplicate individuals.
    score_cache = {}

    def evaluate(individual):
        mask = tuple(individual)
        if mask not in score_cache:
            selected = [i for i, bit in enumerate(mask) if bit == 1]
            if not selected:
                score_cache[mask] = (0.0,)  # keeping no feature scores 0
            else:
                clf = RandomForestClassifier(random_state=42)
                scores = cross_val_score(clf, X[:, selected], y, cv=cv)
                score_cache[mask] = (scores.mean(),)
        # DEAP fitness values are tuples, one entry per weight.
        return score_cache[mask]

    toolbox.register("evaluate", evaluate)
    toolbox.register("mate", tools.cxTwoPoint)
    toolbox.register("mutate", tools.mutFlipBit, indpb=indpb)
    toolbox.register("select", tools.selTournament, tournsize=3)

    population = toolbox.population(n=population_size)

    # cxTwoPoint needs at least 2 genes (it calls random.randint(1, size - 1));
    # with fewer features crossover is meaningless, so disable it.
    cxpb = crossover_prob if n_features >= 2 else 0.0

    # eaSimple has no elitism, so the best individual can be lost between
    # generations; the hall of fame keeps the best one ever evaluated.
    hall_of_fame = tools.HallOfFame(1)
    algorithms.eaSimple(population, toolbox, cxpb=cxpb,
                        mutpb=mutation_prob, ngen=n_generations,
                        halloffame=hall_of_fame, verbose=False)

    best_individual = hall_of_fame[0]
    return [i for i, bit in enumerate(best_individual) if bit == 1]
|