Files
micro_plastic/classification_model/Classification/CNN_SAE.py
2026-02-25 09:42:51 +08:00

331 lines
12 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
import numpy as np
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# 自定义数据集
# Custom dataset
class MyDataset(Dataset):
    """Pairs spectrum tensors with labels; optionally jitters spectra on access."""

    def __init__(self, specs, labels, augment=False):
        self.specs = specs      # tensor of spectra, indexed along dim 0
        self.labels = labels    # tensor of integer class labels
        self.augment = augment  # when True, add Gaussian noise per access

    def __len__(self):
        return len(self.specs)

    def __getitem__(self, index):
        spectrum = self.specs[index]
        label = self.labels[index]
        if self.augment:
            # Additive Gaussian noise (sigma = 0.01), drawn fresh every access.
            spectrum = spectrum + 0.01 * torch.randn_like(spectrum)
        return spectrum, label
# 数据标准化
# Data standardization
def ZspProcess(X_train, X_test, y_train, y_test, need=True):
    """Optionally standardize the splits and wrap them as datasets.

    When `need` is True, fits a StandardScaler on X_train only and applies it
    to both splits.  A channel axis is inserted for Conv1d input, and
    (train_dataset, test_dataset) is returned; only the training set uses
    noise augmentation.
    """
    if need:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)
    # (samples, features) -> (samples, 1, features) float32 for Conv1d.
    train_specs = torch.tensor(X_train[:, np.newaxis, :], dtype=torch.float32)
    test_specs = torch.tensor(X_test[:, np.newaxis, :], dtype=torch.float32)
    train_labels = torch.tensor(y_train, dtype=torch.long)
    test_labels = torch.tensor(y_test, dtype=torch.long)
    return (
        MyDataset(train_specs, train_labels, augment=True),
        MyDataset(test_specs, test_labels, augment=False),
    )
# Focal Loss
# Focal Loss
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification (Lin et al., 2017).

    Per sample: loss = alpha * (1 - p)^gamma * (-log p), where p is the
    softmax probability of the true class.  Down-weights easy examples so
    training focuses on hard / minority-class samples.

    Args:
        alpha: global scaling factor.
        gamma: focusing exponent; gamma=0 recovers plain cross-entropy.
        reduction: 'mean', 'sum', or anything else for per-sample losses.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # BUG FIX: the original computed log(softmax(...)), which yields
        # -inf/NaN once a probability underflows to 0; log_softmax is the
        # numerically stable equivalent.
        log_probs = torch.log_softmax(inputs, dim=1)
        target_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        target_probs = target_log_probs.exp()
        focal_weight = self.alpha * (1.0 - target_probs) ** self.gamma
        loss = -focal_weight * target_log_probs
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
# 位置编码模块
# Positional encoding module
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a (max_len, 1, embed_dim) table of interleaved sin/cos values
    and adds the first seq_len rows to the input.

    Expects input shaped (seq_len, batch, embed_dim).
    """

    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        # BUG FIX / generalization: for odd embed_dim the cos slice has one
        # fewer column than div_term, which raised a shape-mismatch error.
        # Slicing div_term makes any embed_dim work; even dims are unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])  # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, embed_dim)
        # Buffer, not a parameter: moves with .to(device), excluded from grads.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Broadcast-add the first seq_len position rows across the batch dim.
        return x + self.pe[:x.size(0), :]
# Transformer模块
# Transformer module
class TransformerBlockWithSAE(nn.Module):
    """Single-head self-attention block with sinusoidal positions and an FFN.

    Input/output shape: (seq_len, batch, embed_dim).  Post-norm residual
    layout: LayerNorm is applied after each residual addition.
    """

    def __init__(self, embed_dim, ff_dim, dropout=0.1, max_len=5000):
        super(TransformerBlockWithSAE, self).__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5  # sqrt(d) attention temperature
        self.positional_encoding = PositionalEncoding(embed_dim, max_len)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.positional_encoding(x)
        # Scaled dot-product self-attention over the sequence dimension.
        scores = self.query(x) @ self.key(x).transpose(-2, -1) / self.scale
        attended = torch.softmax(scores, dim=-1) @ self.value(x)
        x = self.layernorm1(x + self.dropout(attended))
        # Position-wise feed-forward sub-layer with its own residual.
        return self.layernorm2(x + self.dropout(self.feed_forward(x)))
# Modified CNN+Transformer model
# NOTE(review): this class is DEAD CODE — it is immediately shadowed by the
# second `class CNNWithSAE` definition below (identical except this one
# exposes a max_len parameter).  Every caller gets the later definition;
# consider deleting one of the two.
class CNNWithSAE(nn.Module):
    # 1D-CNN feature extractor (two conv stages) + Transformer block + MLP head.
    def __init__(self, nls, embed_dim=96, ff_dim=192, dropout=0.1, max_len=5000):
        super(CNNWithSAE, self).__init__()
        # Stage 1: 1 -> 64 channels; stride-2 conv and stride-2 pool each halve the length.
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2)
        )
        # Stage 2: 64 -> embed_dim channels, shrinking the sequence another 4x.
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, embed_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2)
        )
        self.transformer = TransformerBlockWithSAE(embed_dim, ff_dim, dropout, max_len)
        # Classifier head producing nls logits.
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, nls)
        )
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # (batch, embed_dim, seq) -> (seq, batch, embed_dim) for the Transformer.
        x = x.permute(2, 0, 1)
        x = self.transformer(x)
        # Mean-pool over the sequence dimension, then classify.
        x = x.mean(dim=0)
        x = self.fc(x)
        return x
# Modified CNN+Transformer model (this definition shadows the one above)
class CNNWithSAE(nn.Module):
    """Two-stage 1D CNN front end + single-head Transformer + MLP classifier.

    Input:  (batch, 1, spectrum_len) float tensor.
    Output: (batch, nls) class logits.
    """

    def __init__(self, nls, embed_dim=96, ff_dim=192, dropout=0.1):
        super(CNNWithSAE, self).__init__()
        # Each conv stage quarters the length (stride-2 conv + stride-2 pool).
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, embed_dim, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.MaxPool1d(2, 2),
        )
        self.transformer = TransformerBlockWithSAE(embed_dim, ff_dim, dropout)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, nls),
        )

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        # (batch, embed_dim, seq) -> (seq, batch, embed_dim) for the Transformer.
        encoded = self.transformer(features.permute(2, 0, 1))
        # Mean-pool over the sequence dimension, then classify.
        return self.fc(encoded.mean(dim=0))
# Training function (with early stopping)
def TransformerTrain(X_train, X_val, y_train, y_val, BATCH_SIZE, n_epochs, nls, model_path, patience=10):
    """Train a CNNWithSAE classifier with mixed precision and early stopping.

    Args:
        X_train, X_val: raw 2-D feature arrays (samples x spectrum_len).
        y_train, y_val: integer class labels.
        BATCH_SIZE: mini-batch size for both loaders.
        n_epochs: maximum number of epochs.
        nls: number of output classes.
        model_path: file path where the best (lowest val-loss) weights are saved.
        patience: epochs without val-loss improvement before stopping.

    Returns:
        (model, train_metrics) where train_metrics holds accuracy / weighted
        precision / recall / F1 / confusion matrix of the LAST trained epoch.
    """
    data_train, data_val = ZspProcess(X_train, X_val, y_train, y_val, need=True)
    train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(data_val, batch_size=BATCH_SIZE, shuffle=False)
    model = CNNWithSAE(nls=nls).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    criterion = FocalLoss(alpha=1, gamma=2).to(device)
    # AMP is only meaningful on CUDA; disable it cleanly on CPU-only hosts.
    use_amp = torch.cuda.is_available()
    scaler = GradScaler(enabled=use_amp)
    best_val_loss = float('inf')
    early_stop_counter = 0
    y_true_train, y_pred_train = [], []
    for epoch in range(n_epochs):
        model.train()
        train_loss, train_acc = [], []
        # BUG FIX: reset each epoch — the original accumulated predictions
        # across ALL epochs, so the returned "training metrics" described an
        # average over the whole run rather than the final model.
        y_true_train, y_pred_train = [], []
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            _, preds = torch.max(outputs, 1)
            y_true_train.extend(labels.cpu().numpy())
            y_pred_train.extend(preds.cpu().numpy())
            train_loss.append(loss.item())
            train_acc.append(accuracy_score(labels.cpu(), preds.cpu()))
        # Validation pass
        model.eval()
        val_loss = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss.append(criterion(outputs, labels).item())
        avg_val_loss = np.mean(val_loss)
        avg_train_loss = np.mean(train_loss)
        avg_train_acc = np.mean(train_acc)
        # BUG FIX: the scheduler was constructed but never stepped, so the
        # learning rate never actually decayed on a validation-loss plateau.
        scheduler.step(avg_val_loss)
        print(f"Epoch [{epoch+1}/{n_epochs}] - Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, Val Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), model_path)
            print("Model improved and saved.")
        else:
            early_stop_counter += 1
            print(f"No improvement. Early stop counter: {early_stop_counter}/{patience}")
            if early_stop_counter >= patience:
                print("Early stopping triggered.")
                break
    # Training-set metrics (last trained epoch)
    train_metrics = {
        "accuracy": accuracy_score(y_true_train, y_pred_train),
        "precision": precision_score(y_true_train, y_pred_train, average='weighted'),
        "recall": recall_score(y_true_train, y_pred_train, average='weighted'),
        "f1_score": f1_score(y_true_train, y_pred_train, average='weighted'),
        "confusion_matrix": confusion_matrix(y_true_train, y_pred_train)
    }
    return model, train_metrics
# Test function
def TransformerTest(X_test, y_test, BATCH_SIZE, nls, model_path):
    """Evaluate a saved CNNWithSAE checkpoint on the test set.

    Prints and returns a dict with accuracy, weighted precision / recall / F1,
    and the confusion matrix.

    NOTE(review): ZspProcess below fits the StandardScaler on X_test itself
    (X_test is passed as both the "train" and "test" argument), so the test
    data is NOT scaled with training-set statistics.  The scaler fitted during
    training should be persisted and reused here — confirm against the caller.
    """
    data_test = ZspProcess(X_test, X_test, y_test, y_test, need=True)[1]
    test_loader = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False)
    model = CNNWithSAE(nls=nls).to(device)
    # BUG FIX: map_location lets a GPU-trained checkpoint load on a CPU-only
    # host; without it torch.load raises when CUDA is unavailable.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    y_true, y_pred = [], []
    test_loss = []
    criterion = FocalLoss(alpha=1, gamma=2).to(device)  # same loss as training
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            test_loss.append(loss.item())
    # Test-set metrics
    test_accuracy = accuracy_score(y_true, y_pred)
    test_precision = precision_score(y_true, y_pred, average='weighted')
    test_recall = recall_score(y_true, y_pred, average='weighted')
    test_f1 = f1_score(y_true, y_pred, average='weighted')
    test_cm = confusion_matrix(y_true, y_pred)
    test_metrics = {
        "accuracy": test_accuracy,
        "precision": test_precision,
        "recall": test_recall,
        "f1_score": test_f1,
        "confusion_matrix": test_cm
    }
    print(f"Accuracy: {test_accuracy:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1 Score: {test_f1:.4f}")
    print(f"Confusion Matrix:\n{test_cm}")
    return test_metrics
def SAETrainAndTest(X, X_test, y, y_test, BATCH_SIZE, n_epochs, nls, model_path, val_split=0.2, patience=10):
    """Split off a validation set, train the model, then evaluate on the test set.

    Returns:
        (train_metrics, test_metrics) from TransformerTrain / TransformerTest.
    """
    # Carve a validation split out of the training data (fixed seed: reproducible).
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=val_split, random_state=42
    )
    # Train (best weights are checkpointed to model_path) and collect metrics.
    _, train_metrics = TransformerTrain(
        X_train, X_val, y_train, y_val, BATCH_SIZE, n_epochs, nls, model_path, patience
    )
    # Evaluate the saved checkpoint on the held-out test set.
    test_metrics = TransformerTest(X_test, y_test, BATCH_SIZE, nls, model_path)
    return train_metrics, test_metrics