Files
2026-02-25 09:42:51 +08:00

191 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torch.utils.data as data
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class MyDataset(data.Dataset):
    """Map-style dataset that pairs each sample with its label by position.

    Args:
        specs: indexable collection of samples.
        labels: indexable collection of labels, aligned with ``specs``.
    """

    def __init__(self, specs, labels):
        self.specs, self.labels = specs, labels

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.specs[index], self.labels[index]

    def __len__(self):
        # The dataset size is taken from ``specs``; ``labels`` is assumed aligned.
        return len(self.specs)
class AutoEncoder(nn.Module):
    """One-hidden-layer autoencoder used as a building block of the SAE.

    Args:
        inputDim: width of the input (and of the reconstruction).
        hiddenDim: width of the ReLU-activated hidden code.
    """

    def __init__(self, inputDim, hiddenDim):
        super().__init__()
        self.inputDim, self.hiddenDim = inputDim, hiddenDim
        # Encoder is created before decoder, which fixes parameter order and
        # the order of random initialisation draws.
        self.encoder = nn.Linear(inputDim, hiddenDim, bias=True)
        self.decoder = nn.Linear(hiddenDim, inputDim, bias=True)
        self.act = F.relu

    def forward(self, x, rep=False):
        """Encode ``x``; return the hidden code if ``rep`` else the reconstruction."""
        code = self.act(self.encoder(x))
        return code if rep else self.decoder(code)
class SAE(nn.Module):
    """Stacked-autoencoder classifier: two encoders followed by a linear layer.

    Args:
        encoderList: sequence of at least two encoder modules. The first two
            are applied in order, each called with ``rep=True`` so only the
            encoder half runs.
        output_dim: number of classes (size of the output scores).
        hidden_dim: input width of the classification layer. When omitted it
            is taken from ``encoderList[1].hiddenDim`` if that attribute
            exists, else 128 (the previously hard-coded value).
    """

    def __init__(self, encoderList, output_dim, hidden_dim=None):
        super().__init__()
        self.encoderList = encoderList
        self.en1 = encoderList[0]
        self.en2 = encoderList[1]
        if hidden_dim is None:
            # Match the second encoder's code width instead of assuming 128,
            # so encoder stacks with other hidden sizes also work.
            hidden_dim = getattr(self.en2, "hiddenDim", 128)
        # Classification layer: one output score per class (num_classes).
        self.fc = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, x):
        """Return unnormalised class scores for input ``x``."""
        out = self.en1(x, rep=True)
        out = self.en2(out, rep=True)
        return self.fc(out)
class SAE_net(object):
    """Stacked-autoencoder (SAE) classifier trained in two phases.

    Phase 1 (``trainAE``) greedily pre-trains the ``AutoEncoder`` layers
    ``input_dim -> hidden1_dim -> hidden2_dim`` one at a time on an MSE
    reconstruction loss. Phase 2 (``trainClassifier``) stacks the encoders
    under a linear classification layer and fine-tunes the whole network
    with cross-entropy loss.

    Args:
        AE_epoch: epochs used to pre-train each autoencoder layer.
        SAE_epoch: epochs used to fine-tune the stacked classifier.
        input_dim: number of features per sample.
        hidden1_dim: hidden width of the first autoencoder.
        hidden2_dim: hidden width of the second autoencoder.
        output_dim: number of classes (default 4; callers may pass num_classes).
        batch_size: mini-batch size for both phases.
    """

    # The module-level *function* ``SAE`` defined later in this file rebinds
    # the global name ``SAE`` and hides the ``SAE`` nn.Module class. Looking up
    # ``SAE`` inside ``fit`` would therefore call the function and fail. This
    # class body runs after the class is defined but before the function is,
    # so capture the class here. ``globals().get`` avoids a NameError if the
    # class is absent; ``fit`` then reports that clearly.
    _classifier_cls = globals().get("SAE")

    def __init__(self, AE_epoch=200, SAE_epoch=200,
                 input_dim=404, hidden1_dim=512,
                 hidden2_dim=128, output_dim=4,  # default 4 classes; pass num_classes when calling
                 batch_size=128):
        self.AE_epoch = AE_epoch
        self.SAE_epoch = SAE_epoch
        self.input_dim = input_dim
        self.hidden1_dim = hidden1_dim
        self.hidden2_dim = hidden2_dim
        self.output_dim = output_dim
        self.batch_size = batch_size
        # Built by trainAE and reused by trainClassifier.
        self.train_loader = None
        encoder1 = AutoEncoder(self.input_dim, self.hidden1_dim)
        encoder2 = AutoEncoder(self.hidden1_dim, self.hidden2_dim)
        self.encoder_list = [encoder1, encoder2]

    @staticmethod
    def _run_device(useCuda):
        """Return the module-level ``device`` when ``useCuda`` is true, else the CPU."""
        return device if useCuda else torch.device("cpu")

    @staticmethod
    def _metrics_dict(scores):
        """Convert an ``evaluate`` result tuple into a dict of named metrics."""
        accuracy, precision, recall, f1, cm = scores
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "confusion_matrix": cm,
        }

    def trainAE(self, x_train, y_train, encoderList, trainLayer, batchSize, epoch, useCuda=False):
        """Pre-train ``encoderList[trainLayer]`` as an autoencoder.

        Encoders ``0 .. trainLayer-1`` are applied without gradient tracking to
        map each batch to the input of layer ``trainLayer``. That layer is then
        trained to reconstruct its input with MSE loss.

        Side effect: stores the DataLoader built from ``(x_train, y_train)`` in
        ``self.train_loader``; ``trainClassifier`` reuses it.

        Args:
            x_train: samples indexable by position; each batch is flattened to
                ``(batch, -1)`` before encoding.
            y_train: labels aligned with ``x_train``. They are only carried
                through the loader here; the reconstruction loss ignores them.
            encoderList: the stacked encoders, ordered from input to output.
            trainLayer: index of the encoder to train.
            batchSize: mini-batch size.
            epoch: number of passes over the data.
            useCuda: if True, run on the module-level ``device``.
        """
        run_device = self._run_device(useCuda)
        for encoder in encoderList:
            encoder.to(run_device)
        layer = encoderList[trainLayer]
        layer.train()
        optimizer = optim.Adam(layer.parameters())
        criterion = nn.MSELoss()
        data_train = MyDataset(x_train, y_train)
        self.train_loader = torch.utils.data.DataLoader(data_train, batch_size=batchSize, shuffle=True)
        for _ in range(epoch):
            for x, _target in self.train_loader:
                # Keep data on the training device (the old FloatTensor cast
                # moved it back to the CPU and broke useCuda=True).
                x = x.to(run_device).float()
                x = x.view(x.size(0), -1)
                # Earlier layers are already trained: compute their output
                # without gradients so it is a fixed reconstruction target and
                # no gradient accumulates in their parameters.
                with torch.no_grad():
                    feats = x
                    for frozen in encoderList[:trainLayer]:
                        feats = frozen(feats, rep=True)
                optimizer.zero_grad()
                recon = layer(feats, rep=False)
                loss = criterion(recon, feats)
                loss.backward()
                optimizer.step()

    def trainClassifier(self, model, epoch, useCuda=False, X_train=None, y_train=None):
        """Fine-tune the full stacked classifier with cross-entropy loss.

        Uses the DataLoader stored in ``self.train_loader`` by ``trainAE``, so
        ``trainAE`` must be called first. Stores the trained model in
        ``self.model``.

        Args:
            model: module mapping ``(batch, input_dim)`` float input to
                ``(batch, num_classes)`` scores.
            epoch: number of passes over ``self.train_loader``.
            useCuda: if True, run on the module-level ``device``.
            X_train, y_train: optional training data. When both are given, the
                trained model is scored on them.

        Returns:
            The ``evaluate`` tuple ``(accuracy, precision, recall, f1, cm)``
            when ``X_train`` and ``y_train`` are given, otherwise ``None``.
        """
        run_device = self._run_device(useCuda)
        model = model.to(run_device)
        model.train()
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters())
        criterion = nn.CrossEntropyLoss()
        for _ in range(epoch):
            for x, target in self.train_loader:
                x = x.to(run_device).float().view(-1, self.input_dim)
                # CrossEntropyLoss needs int64 class indices; labels such as
                # int32 arrays would otherwise raise.
                target = target.to(run_device).long()
                optimizer.zero_grad()
                loss = criterion(model(x), target)
                loss.backward()
                optimizer.step()
        self.model = model
        if X_train is None or y_train is None:
            return None
        return self.evaluate(model, X_train, y_train)

    def fit(self, x_train=None, y_train=None, X_test=None, y_test=None, useCuda=False):
        """Pre-train the encoders, fine-tune the classifier, and score it.

        Args:
            x_train: training samples, array-like; each sample must flatten to
                ``input_dim`` features.
            y_train: integer class labels in ``[0, output_dim)``, aligned with
                ``x_train``.
            X_test, y_test: optional held-out samples and labels.
            useCuda: if True, train on the module-level ``device``.

        Returns:
            ``(train_metrics, test_metrics)``: dicts with keys ``"accuracy"``,
            ``"precision"``, ``"recall"``, ``"f1_score"`` and
            ``"confusion_matrix"``. ``test_metrics`` is ``None`` when
            ``X_test`` or ``y_test`` is omitted.

        Raises:
            ValueError: if ``x_train`` or ``y_train`` is missing.
            RuntimeError: if the ``SAE`` classifier class was not captured.
        """
        if x_train is None or y_train is None:
            raise ValueError("fit() requires x_train and y_train")
        model_cls = self._classifier_cls
        if not isinstance(model_cls, type):
            raise RuntimeError("SAE_net: the SAE nn.Module class was not defined before SAE_net")
        x_np = np.ascontiguousarray(x_train, dtype=np.float32)
        y_np = np.asarray(y_train)
        # One flattened feature row per sample.
        x_tensor = torch.from_numpy(x_np).reshape(len(x_np), -1)
        # Phase 1: greedy layer-wise pre-training of the autoencoders.
        for layer in range(len(self.encoder_list)):
            self.trainAE(x_train=x_tensor, y_train=y_np,
                         encoderList=self.encoder_list, trainLayer=layer,
                         batchSize=self.batch_size, epoch=self.AE_epoch,
                         useCuda=useCuda)
        # Phase 2: fine-tune the stacked classifier and score the training set.
        model = model_cls(encoderList=self.encoder_list, output_dim=self.output_dim)
        train_scores = self.trainClassifier(model=model, epoch=self.SAE_epoch,
                                            useCuda=useCuda,
                                            X_train=x_np, y_train=y_np)
        train_metrics = self._metrics_dict(train_scores)
        if X_test is None or y_test is None:
            return train_metrics, None
        # Score the held-out test set.
        test_metrics = self._metrics_dict(self.evaluate(model, X_test, y_test))
        return train_metrics, test_metrics

    def evaluate(self, model, X_test, y_test):
        """Score ``model`` on ``(X_test, y_test)``.

        Args:
            model: trained classifier returning per-class scores.
            X_test: samples as a tensor or array-like; reshaped to
                ``(-1, input_dim)``.
            y_test: true class labels aligned with ``X_test``.

        Returns:
            Tuple ``(accuracy, precision, recall, f1, confusion_matrix)``.
            Precision, recall and F1 use ``average='weighted'``.
        """
        if not torch.is_tensor(X_test):
            X_test = np.asarray(X_test, dtype=np.float32)
        X = torch.as_tensor(X_test, dtype=torch.float32)
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        X = X.to(model_device).reshape(-1, self.input_dim)
        # Inference only: no autograd graph, eval mode; restore the mode afterwards.
        was_training = model.training
        model.eval()
        with torch.no_grad():
            out = model(X)
        model.train(was_training)
        _, y_pred = torch.max(out, 1)
        y_pred = y_pred.cpu().numpy()
        y_true = y_test.cpu().numpy() if torch.is_tensor(y_test) else np.asarray(y_test)
        # Accuracy, precision, recall, F1 and confusion matrix.
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted')
        recall = recall_score(y_true, y_pred, average='weighted')
        f1 = f1_score(y_true, y_pred, average='weighted')
        cm = confusion_matrix(y_true, y_pred)
        return accuracy, precision, recall, f1, cm
def SAE(X_train, y_train, X_test, y_test, num_classes=4):
    """Train an ``SAE_net`` classifier and return its train/test metrics.

    Args:
        X_train, y_train: training samples and their class labels.
        X_test, y_test: held-out samples and their class labels.
        num_classes: number of output classes (passed as ``output_dim``).

    Returns:
        The ``(train_metrics, test_metrics)`` pair produced by ``SAE_net.fit``.

    NOTE(review): this module-level function rebinds the global name ``SAE``,
    replacing the ``SAE`` nn.Module class defined earlier in this file. Any
    code that looks up ``SAE`` at call time gets this function, not the class.
    """
    net = SAE_net(output_dim=num_classes)
    metrics_train, metrics_test = net.fit(X_train, y_train, X_test, y_test)
    return metrics_train, metrics_test