import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # TensorFlow env var; no effect on this torch-only script


# Custom dataset: wraps pre-tensorized spectra and integer labels.
class MyDataset(Dataset):
    """Dataset over (spec, label) pairs with optional Gaussian-noise augmentation."""

    def __init__(self, specs, labels, augment=False):
        # specs: float tensor (N, 1, L); labels: long tensor (N,) — produced by ZspProcess.
        self.specs = specs
        self.labels = labels
        self.augment = augment

    def __getitem__(self, index):
        spec, target = self.specs[index], self.labels[index]
        if self.augment:
            # Train-time augmentation: small additive Gaussian noise.
            noise = 0.01 * torch.randn_like(spec)
            spec = spec + noise
        return spec, target

    def __len__(self):
        return len(self.specs)


# Data standardization / tensorization.
def ZspProcess(X_train, X_test, y_train, y_test, need=True):
    """Optionally z-score features (scaler fit on X_train, applied to X_test),
    add a channel axis, convert to tensors and wrap in MyDataset.

    Returns (data_train, data_test); only the train split is augmented.
    """
    if need:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
    # (N, L) -> (N, 1, L): single input channel for Conv1d.
    X_train = torch.tensor(X_train[:, np.newaxis, :], dtype=torch.float32)
    X_test = torch.tensor(X_test[:, np.newaxis, :], dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_test = torch.tensor(y_test, dtype=torch.long)
    data_train = MyDataset(X_train, y_train, augment=True)
    data_test = MyDataset(X_test, y_test, augment=False)
    return data_train, data_test


# Focal Loss (Lin et al., 2017) for class-imbalanced classification.
class FocalLoss(nn.Module):
    """loss_i = -alpha * (1 - p_i)^gamma * log(p_i), p_i = softmax prob of the true class.

    reduction: 'mean' | 'sum' | anything else returns per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # BUGFIX: the original computed -log(softmax(x)), which underflows to
        # log(0) = -inf for confident wrong predictions — a real hazard under
        # the autocast/fp16 training used in this file. log_softmax is the
        # numerically stable, mathematically identical formulation.
        log_probs = torch.log_softmax(inputs, dim=1)
        target_log_probs = log_probs[range(len(targets)), targets]
        target_probs = target_log_probs.exp()
        focal_weight = self.alpha * (1 - target_probs) ** self.gamma
        loss = -focal_weight * target_log_probs
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

# Positional encoding module
# Fixed sinusoidal positional encoding ("Attention Is All You Need").
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float()
                             * (-torch.log(torch.tensor(10000.0)) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, embed_dim)
        # Buffer, not a parameter: moves with .to(device), excluded from the optimizer.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, embed_dim); broadcast-add the first seq_len encodings.
        return x + self.pe[:x.size(0), :]


# Transformer encoder block: single-head self-attention + position-wise FFN,
# each with a residual connection and post-LayerNorm.
class TransformerBlockWithSAE(nn.Module):
    def __init__(self, embed_dim, ff_dim, dropout=0.1, max_len=5000):
        super(TransformerBlockWithSAE, self).__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5  # scaled dot-product attention denominator
        self.positional_encoding = PositionalEncoding(embed_dim, max_len)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (seq_len, batch, embed_dim)
        x = self.positional_encoding(x)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        x = self.layernorm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.layernorm2(x + self.dropout(ff_output))
        return x


# CNN feature extractor + Transformer encoder classifier.
# BUGFIX: the original file defined this class TWICE; the second definition
# silently shadowed the first and dropped the max_len parameter. Merged into a
# single definition; max_len is kept with default 5000, which is exactly what
# the shadowing definition used implicitly — effective behavior is unchanged.
class CNNWithSAE(nn.Module):
    def __init__(self, nls, embed_dim=96, ff_dim=192, dropout=0.1, max_len=5000):
        super(CNNWithSAE, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, embed_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2)
        )
        self.transformer = TransformerBlockWithSAE(embed_dim, ff_dim, dropout, max_len)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, nls)
        )

    def forward(self, x):
        # x: (batch, 1, length); each conv stage quarters the length (stride 2 + pool 2).
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.permute(2, 0, 1)  # -> (seq_len, batch, embed_dim), Transformer input layout
        x = self.transformer(x)
        x = x.mean(dim=0)  # mean-pool over the sequence dimension
        x = self.fc(x)
        return x


# Training function (with early stopping).
def TransformerTrain(X_train, X_val, y_train, y_val, BATCH_SIZE, n_epochs, nls, model_path, patience=10):
    """Train CNNWithSAE with mixed precision; save the best (lowest-val-loss)
    weights to model_path; stop early after `patience` epochs without improvement.

    Returns (model, train_metrics). NOTE: train_metrics are computed over
    predictions accumulated across ALL epochs run, not just the final epoch.
    """
    data_train, data_val = ZspProcess(X_train, X_val, y_train, y_val, need=True)
    train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False)

    model = CNNWithSAE(nls=nls).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    criterion = FocalLoss(alpha=1, gamma=2).to(device)
    scaler = GradScaler()

    best_val_loss = float('inf')
    early_stop_counter = 0
    y_true_train, y_pred_train = [], []
    for epoch in range(n_epochs):
        model.train()
        train_loss, train_acc = [], []
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            _, preds = torch.max(outputs, 1)
            y_true_train.extend(labels.cpu().numpy())
            y_pred_train.extend(preds.cpu().numpy())
            acc = accuracy_score(labels.cpu(), preds.cpu())
            train_loss.append(loss.item())
            train_acc.append(acc)

        # Validation pass
        model.eval()
        val_loss = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss.append(loss.item())

        avg_val_loss = np.mean(val_loss)
        avg_train_loss = np.mean(train_loss)
        avg_train_acc = np.mean(train_acc)
        # BUGFIX: the scheduler was constructed but never stepped, so the LR
        # was never reduced. ReduceLROnPlateau must be fed the monitored metric.
        scheduler.step(avg_val_loss)
        print(f"Epoch [{epoch+1}/{n_epochs}] - Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), model_path)
            print("Model improved and saved.")
        else:
            early_stop_counter += 1
            print(f"No improvement. Early stop counter: {early_stop_counter}/{patience}")
            if early_stop_counter >= patience:
                print("Early stopping triggered.")
                break

    # Training-set metrics (pooled over all epochs run before stopping).
    train_accuracy = accuracy_score(y_true_train, y_pred_train)
    train_precision = precision_score(y_true_train, y_pred_train, average='weighted')
    train_recall = recall_score(y_true_train, y_pred_train, average='weighted')
    train_f1 = f1_score(y_true_train, y_pred_train, average='weighted')
    train_cm = confusion_matrix(y_true_train, y_pred_train)
    train_metrics = {
        "accuracy": train_accuracy,
        "precision": train_precision,
        "recall": train_recall,
        "f1_score": train_f1,
        "confusion_matrix": train_cm
    }
    return model, train_metrics


# Test function
def TransformerTest(X_test, y_test, BATCH_SIZE, nls, model_path):
    """Load the saved weights and report loss/metrics on the held-out test set."""
    # NOTE(review): this re-fits the StandardScaler on the test data itself
    # (X_test is passed as both splits). Strictly, the scaler fitted on the
    # training data should be reused here — confirm with the data pipeline.
    data_test = ZspProcess(X_test, X_test, y_test, y_test, need=True)[1]
    test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)

    model = CNNWithSAE(nls=nls).to(device)
    # BUGFIX: map_location so a GPU-saved checkpoint loads on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    y_true, y_pred = [], []
    test_loss = []
    criterion = FocalLoss(alpha=1, gamma=2).to(device)  # same loss as training
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            test_loss.append(loss.item())

    # Test-set metrics
    test_accuracy = accuracy_score(y_true, y_pred)
    test_precision = precision_score(y_true, y_pred, average='weighted')
    test_recall = recall_score(y_true, y_pred, average='weighted')
    test_f1 = f1_score(y_true, y_pred, average='weighted')
    test_cm = confusion_matrix(y_true, y_pred)
    test_metrics = {
        "accuracy": test_accuracy,
        "precision": test_precision,
        "recall": test_recall,
        "f1_score": test_f1,
        "confusion_matrix": test_cm
    }
    print(f"Accuracy: {test_accuracy:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1 Score: {test_f1:.4f}")
    print(f"Confusion Matrix:\n{test_cm}")
    return test_metrics
def SAETrainAndTest(X, X_test, y, y_test, BATCH_SIZE, n_epochs, nls, model_path, val_split=0.2, patience=10):
    """Full pipeline: split off a validation set, train, then evaluate on the test set.

    Returns (train_metrics, test_metrics) as produced by TransformerTrain and
    TransformerTest; the trained model itself is persisted to model_path.
    """
    # Carve a validation split out of the training data (fixed seed for reproducibility).
    split = train_test_split(X, y, test_size=val_split, random_state=42)
    fit_X, hold_X, fit_y, hold_y = split
    # Train (best weights are saved to model_path inside TransformerTrain).
    _, train_metrics = TransformerTrain(fit_X, hold_X, fit_y, hold_y, BATCH_SIZE, n_epochs, nls, model_path, patience)
    # Evaluate the saved checkpoint on the untouched test set.
    return train_metrics, TransformerTest(X_test, y_test, BATCH_SIZE, nls, model_path)