from classification_model.Classification.ClassicCls import SVM, PLS_DA, RF, XGBoost, LightGBM, CatBoost, LogisticRegressionModel, AdaBoost, KNN
# from Classification.CNN import CNN
# from Classification.CNN_Transfomer import TransformerTrainAndTest
# from Classification.CNN_SAE import SAETrainAndTest
# from Classification.SAE import SAE
# from Classification.CNN_deepseek import CNN_deepseek
from multiprocessing import Pool, cpu_count


def _run_ann(X_train, X_test, y_train, y_test):
    """Run the ``ANN`` classifier from ``ClassicCls``, importing it on demand.

    ``ANN`` was referenced by ``QualitativeAnalysis`` but never imported, so
    selecting it always raised ``NameError``. It is imported lazily here so
    that a ``ClassicCls`` without ``ANN`` only affects this one model rather
    than breaking the module import for every model.

    Returns:
        Whatever ``ANN`` returns, or ``(None, None)`` when ``ANN`` cannot be
        imported (mirroring the unknown-model behaviour).
    """
    try:
        from classification_model.Classification.ClassicCls import ANN
    except ImportError:
        print("ANN is not available in classification_model.Classification.ClassicCls")
        return None, None
    return ANN(X_train, X_test, y_train, y_test)


def QualitativeAnalysis(model, X_train, X_test, y_train, y_test, n_jobs=-1):
    """Train and evaluate the classifier named *model*.

    Dispatches to the matching model wrapper and returns the evaluation
    metrics it produces for the training and test sets.

    Args:
        model: Name of the classifier: "PLS_DA", "ANN", "SVM", "RF",
            "LogisticRegression", "XGBoost", "LightGBM", "CatBoost",
            "AdaBoost" or "KNN".
        X_train, X_test: Feature data of the training and test sets.
        y_train, y_test: Labels of the training and test sets.
        n_jobs: Number of cores to use; only forwarded to models that accept
            it (currently "RF").

    Returns:
        ``(train_metrics, test_metrics)`` as returned by the selected wrapper.
        Per the original docstring, these are expected to be dicts containing
        accuracy, precision, recall and f1_score. Returns ``(None, None)``
        after printing a message if *model* is not recognised.
    """
    # Each entry defers the call, so only the selected model is trained.
    dispatch = {
        "PLS_DA": lambda: PLS_DA(X_train, X_test, y_train, y_test),
        "ANN": lambda: _run_ann(X_train, X_test, y_train, y_test),
        "SVM": lambda: SVM(X_train, X_test, y_train, y_test),
        "RF": lambda: RF(X_train, X_test, y_train, y_test, n_jobs=n_jobs),
        "LogisticRegression": lambda: LogisticRegressionModel(
            X_train, X_test, y_train, y_test,
            penalty='l2', C=1.0, solver='lbfgs'),
        "XGBoost": lambda: XGBoost(
            X_train, X_test, y_train, y_test,
            n_estimators=100, learning_rate=0.1, max_depth=3),
        "LightGBM": lambda: LightGBM(
            X_train, X_test, y_train, y_test,
            n_estimators=100, learning_rate=0.1, max_depth=-1, num_leaves=31),
        "CatBoost": lambda: CatBoost(
            X_train, X_test, y_train, y_test,
            iterations=500, learning_rate=0.1, depth=6),
        "AdaBoost": lambda: AdaBoost(
            X_train, X_test, y_train, y_test,
            n_estimators=50, learning_rate=1.0),
        "KNN": lambda: KNN(X_train, X_test, y_train, y_test, n_neighbors=5),
    }

    runner = dispatch.get(model)
    if runner is None:
        print("No such model for Qualitative Analysis")
        return None, None

    # Unpack explicitly so a wrapper returning the wrong arity fails loudly,
    # exactly as the original two-target assignment did.
    train_metrics, test_metrics = runner()
    return train_metrics, test_metrics