from multiprocessing import Pool, cpu_count

from classification_model.Classification.CNN_HYper import optimize_CNN
from classification_model.Classification.CNN_Transfomer import TransformerTrainAndTest
from classification_model.Classification.CNN_SAE import SAETrainAndTest
from classification_model.Classification.SAE import SAE
from classification_model.Classification.CNN_deepseek import CNN_deepseek
# Bayesian-optimization entry points for the classic classifiers.
from classification_model.Classification.ClassicClsHY import (
    optimize_SVM,
    optimize_KNN,
    optimize_XGBoost,
    optimize_RF,
    optimize_CatBoost,
    optimize_LogisticRegression,
    optimize_ANN,
)

# NOTE(review): machine-specific absolute path passed to optimize_CNN as
# ``model_path``; it will not exist on other machines -- consider making it
# configurable.
_CNN_MODEL_PATH = r'H:\arithmetic\python\opensa-main(local)\opensa-main\OpenSA\tensorboard_logs\model_best.pth'

# Model name -> optimizer. Every entry is called as
# ``optimizer(X_train, y_train, X_test, y_test)``. The CNN optimizer is not
# listed here because it takes its arguments in a different order.
_CLASSIC_OPTIMIZERS = {
    "ANN": optimize_ANN,
    "SVM": optimize_SVM,
    "RF": optimize_RF,
    "LogisticRegression": optimize_LogisticRegression,
    "XGBoost": optimize_XGBoost,
    "CatBoost": optimize_CatBoost,
    "KNN": optimize_KNN,
}


def QualitativeAnalysis(model, X_train, X_test, y_train, y_test, n_jobs=-1):
    """Run the optimizer selected by ``model`` and return its results.

    Args:
        model: Name of the classifier to run. Recognized names are
            "ANN", "SVM", "RF", "optimize_CNN", "LogisticRegression",
            "XGBoost", "CatBoost" and "KNN".
        X_train: Training-set features.
        X_test: Test-set features.
        y_train: Training-set labels.
        y_test: Test-set labels.
        n_jobs: Accepted for interface compatibility; it is currently not
            forwarded to any optimizer.

    Returns:
        A 3-tuple ``(best_params, train_metrics, test_metrics)`` exactly as
        produced by the selected optimizer (presumably the metrics are dicts
        of accuracy, precision, recall and f1_score -- confirm against the
        optimizer implementations). If ``model`` is not recognized, a
        message is printed and ``(None, None, None)`` is returned, so
        callers can always unpack three values.
    """
    if model == "optimize_CNN":
        # The CNN optimizer takes (X_train, X_test, y_train, y_test), unlike
        # the classic optimizers, and additionally needs a checkpoint path.
        return optimize_CNN(
            X_train, X_test, y_train, y_test, model_path=_CNN_MODEL_PATH
        )

    # Only look up string names so that an unhashable ``model`` falls through
    # to the "unknown model" path instead of raising TypeError.
    optimizer = _CLASSIC_OPTIMIZERS.get(model) if isinstance(model, str) else None
    if optimizer is None:
        print("No such model for Qualitative Analysis")
        # Same arity as the success path; the original returned only two
        # values here, which broke three-way unpacking in callers.
        return None, None, None

    best_params, train_metrics, test_metrics = optimizer(
        X_train, y_train, X_test, y_test
    )
    return best_params, train_metrics, test_metrics