48 lines
2.3 KiB
Python
48 lines
2.3 KiB
Python
|
|
# from Classification.CNN_HYper import
|
|
from classification_model.Classification.CNN_Transfomer import TransformerTrainAndTest
|
|
from classification_model.Classification.CNN_SAE import SAETrainAndTest
|
|
from classification_model.Classification.SAE import SAE
|
|
from classification_model.Classification.CNN_deepseek import CNN_deepseek
|
|
from multiprocessing import Pool, cpu_count
|
|
|
|
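# Hyperparameter-optimized classic classifiers (the comment originally said "Bayesian optimization", but the source module is named 网格搜索 / grid search — confirm which strategy the helpers use)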
|
|
from classification_model.Classification.ClassicCls_网格搜索 import optimize_SVM, optimize_KNN, optimize_XGBoost, optimize_RF, optimize_CatBoost, optimize_LogisticRegression
|
|
|
|
def QualitativeAnalysis(model, X_train, X_test, y_train, y_test, n_jobs=-1):
    """Tune the named classic classifier and return its parameters and metrics.

    Dispatches on ``model`` to one of the ``optimize_*`` helpers imported from
    ``classification_model.Classification.ClassicCls_网格搜索``.

    Args:
        model: Classifier name. One of "SVM", "RF", "LogisticRegression",
            "XGBoost", "CatBoost" or "KNN" (case-sensitive).
        X_train, X_test: Feature data for the training and test sets.
        y_train, y_test: Labels for the training and test sets.
        n_jobs: Intended number of cores for models that support
            parallelism. NOTE(review): this value is currently not forwarded
            to any ``optimize_*`` helper; it is kept only so existing callers
            keep working.

    Returns:
        A 3-tuple ``(best_params, train_metrics, test_metrics)``, unpacked
        from the selected ``optimize_*`` helper. The metric dicts are
        described as holding accuracy, precision, recall and f1_score; confirm
        against the helpers.

        For an unknown ``model``, a message is printed and
        ``(None, None, None)`` is returned. The tuple has the same length as
        the success case, so callers can always unpack three values.
    """
    # The helpers take (X_train, y_train, X_test, y_test). Note that this
    # order differs from this function's own signature.
    if model == "SVM":
        best_params, train_metrics, test_metrics = optimize_SVM(X_train, y_train, X_test, y_test)
    elif model == "RF":
        best_params, train_metrics, test_metrics = optimize_RF(X_train, y_train, X_test, y_test)
    elif model == "LogisticRegression":
        best_params, train_metrics, test_metrics = optimize_LogisticRegression(X_train, y_train, X_test, y_test)
    elif model == "XGBoost":
        best_params, train_metrics, test_metrics = optimize_XGBoost(X_train, y_train, X_test, y_test)
    elif model == "CatBoost":
        best_params, train_metrics, test_metrics = optimize_CatBoost(X_train, y_train, X_test, y_test)
    elif model == "KNN":
        best_params, train_metrics, test_metrics = optimize_KNN(X_train, y_train, X_test, y_test)
    else:
        print("No such model for Qualitative Analysis")
        # Must match the arity of the success path below. A 2-tuple here made
        # `a, b, c = QualitativeAnalysis(...)` raise ValueError.
        return None, None, None

    return best_params, train_metrics, test_metrics
|