初始提交
This commit is contained in:
49
classification_model/Classification/Cls.py
Normal file
49
classification_model/Classification/Cls.py
Normal file
@ -0,0 +1,49 @@
|
||||
from classification_model.Classification.ClassicCls import SVM, PLS_DA, RF, XGBoost, LightGBM, CatBoost,LogisticRegressionModel,AdaBoost,KNN
|
||||
# from Classification.CNN import CNN
|
||||
# from Classification.CNN_Transfomer import TransformerTrainAndTest
|
||||
# from Classification.CNN_SAE import SAETrainAndTest
|
||||
# from Classification.SAE import SAE
|
||||
# from Classification.CNN_deepseek import CNN_deepseek
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
|
||||
def QualitativeAnalysis(model, X_train, X_test, y_train, y_test, n_jobs=-1):
|
||||
"""
|
||||
根据模型名称调用不同的分类模型,并返回训练集和测试集的评估指标。
|
||||
|
||||
参数:
|
||||
- model: 要使用的分类模型名称
|
||||
- X_train, X_test: 训练集和测试集的特征数据
|
||||
- y_train, y_test: 训练集和测试集的标签数据
|
||||
- n_jobs: 使用的核心数量,适用于支持多线程的模型
|
||||
|
||||
返回:
|
||||
- train_metrics: 包含训练集 accuracy, precision, recall, f1_score 的字典
|
||||
- test_metrics: 包含测试集 accuracy, precision, recall, f1_score 的字典
|
||||
"""
|
||||
|
||||
if model == "PLS_DA":
|
||||
train_metrics, test_metrics = PLS_DA(X_train, X_test, y_train, y_test)
|
||||
elif model == "ANN":
|
||||
train_metrics, test_metrics = ANN(X_train, X_test, y_train, y_test)
|
||||
elif model == "SVM":
|
||||
train_metrics, test_metrics = SVM(X_train, X_test, y_train, y_test)
|
||||
elif model == "RF":
|
||||
train_metrics, test_metrics = RF(X_train, X_test, y_train, y_test, n_jobs=n_jobs)
|
||||
elif model == "LogisticRegression":
|
||||
train_metrics, test_metrics = LogisticRegressionModel(X_train, X_test, y_train, y_test, penalty='l2', C=1.0, solver='lbfgs')
|
||||
elif model == "XGBoost":
|
||||
train_metrics, test_metrics = XGBoost(X_train, X_test, y_train, y_test, n_estimators=100, learning_rate=0.1, max_depth=3)
|
||||
elif model == "LightGBM":
|
||||
train_metrics, test_metrics = LightGBM(X_train, X_test, y_train, y_test, n_estimators=100, learning_rate=0.1, max_depth=-1, num_leaves=31)
|
||||
elif model == "CatBoost":
|
||||
train_metrics, test_metrics = CatBoost(X_train, X_test, y_train, y_test, iterations=500, learning_rate=0.1, depth=6)
|
||||
elif model == "AdaBoost":
|
||||
train_metrics, test_metrics = AdaBoost(X_train, X_test, y_train, y_test, n_estimators=50, learning_rate=1.0)
|
||||
elif model == 'KNN':
|
||||
train_metrics, test_metrics = KNN(X_train, X_test, y_train, y_test, n_neighbors=5)
|
||||
else:
|
||||
print("No such model for Qualitative Analysis")
|
||||
return None, None
|
||||
|
||||
return train_metrics, test_metrics
|
||||
Reference in New Issue
Block a user