# micro_plastic/classification_model/Classification/CNN_deepseek.py
#
# 1-D CNN spectral classifier: dataset with on-the-fly augmentation,
# channel-attention CNN model, and training/testing entry points.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas as pd
# Device configuration: prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Spectral dataset with optional on-the-fly data augmentation.
class SpectralDataset(Dataset):
    """1-D spectra dataset yielding ``(1, L)`` float tensors and long labels.

    When ``augment`` is True, each sample is randomly perturbed on access:
    additive Gaussian noise (p=0.7), a circular spectral shift (p=0.5), and
    a 10-point local occlusion (p=0.3).
    """

    def __init__(self, X, y, augment=False, input_length=462):
        # Accept pandas containers and normalize them to numpy arrays.
        if isinstance(X, pd.DataFrame):
            X = X.values
        if isinstance(y, (pd.Series, pd.DataFrame)):
            y = y.values
        # X must be (N, L); add a channel axis -> (N, 1, L) for Conv1d.
        if len(X.shape) != 2:
            raise ValueError(f"Expected X to be 2D, got {X.shape}")
        self.X = torch.tensor(X[:, np.newaxis, :], dtype=torch.float32)  # (N, 1, L)
        self.y = torch.tensor(y, dtype=torch.long)  # y is expected to be 1-D
        self.augment = augment
        self.input_length = input_length

    def __getitem__(self, index):
        x = self.X[index]  # shape (1, L); this is a view into self.X
        y = self.y[index]
        if self.augment:
            # BUGFIX: clone before in-place edits. Previously `x += noise`
            # and the occlusion write hit the view and permanently corrupted
            # the stored dataset, compounding across epochs.
            x = x.clone()
            # Additive Gaussian noise.
            if torch.rand(1) < 0.7:
                noise_level = torch.rand(1) * 0.05
                x += noise_level * torch.randn_like(x)
            # Circular spectral shift. NOTE(review): randint's upper bound is
            # exclusive, so shifts land in [-5, 4] — confirm symmetric shifts
            # were not intended.
            if torch.rand(1) < 0.5:
                shift = torch.randint(-5, 5, (1,)).item()
                x = torch.roll(x, shifts=shift, dims=-1)
            # Local occlusion: zero out a random 10-point window.
            if torch.rand(1) < 0.3:
                start = torch.randint(0, self.input_length - 10, (1,)).item()
                x[0, start:start + 10] = 0.0
        return x, y

    def __len__(self):
        return len(self.X)
# Squeeze-and-excitation style channel attention for 1-D spectra.
class SpectralAttention(nn.Module):
    """Reweight channels with gates from global average pooling + a small MLP.

    Each channel is squeezed to a scalar, passed through a bottleneck MLP
    (reduction factor ``reduction``) with a sigmoid output, and the resulting
    per-channel gate in (0, 1) rescales the input.
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.GELU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _length = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)        # (B, C)
        gates = self.fc(squeezed).view(batch, channels, 1)       # (B, C, 1)
        return x * gates  # broadcast over the length dimension
# CNN model for spectral classification.
class AgroSpecCNN(nn.Module):
    """Three-stage Conv1d feature extractor with channel attention and an
    MLP head producing ``num_classes`` logits.

    Input is ``(B, 1, input_length)``; output is ``(B, num_classes)``.
    The ``features``/``classifier`` Sequential layout is preserved so saved
    state_dicts remain loadable.
    """

    def __init__(self, input_length=462, num_classes=21):
        super().__init__()
        self.input_length = input_length
        feature_layers = [
            nn.Conv1d(1, 64, 5, padding=2),   # wide kernel over the spectrum
            nn.BatchNorm1d(64),
            nn.GELU(),
            SpectralAttention(64),
            nn.MaxPool1d(2),                  # halve the spectral length
            nn.Conv1d(64, 128, 5, padding=2),
            nn.BatchNorm1d(128),
            nn.GELU(),
            SpectralAttention(128),
            # Adaptive pooling fixes the length regardless of input size.
            nn.AdaptiveAvgPool1d(self.input_length // 2),
            nn.Conv1d(128, 256, 5, padding=2),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.AdaptiveAvgPool1d(1),          # global pooling -> (B, 256, 1)
        ]
        self.features = nn.Sequential(*feature_layers)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        feats = self.features(x)
        flat = torch.flatten(feats, 1)  # (B, 256)
        return self.classifier(flat)
# Training entry point.
def CNNTrain(X_train, y_train, BATCH_SIZE, n_epochs, input_length, num_classes,
             model_path, num_workers=4):
    """Train an AgroSpecCNN on (X_train, y_train) and save its weights.

    Parameters
    ----------
    X_train, y_train : array-like / pandas — (N, input_length) spectra and
        integer class labels.
    BATCH_SIZE : mini-batch size.
    n_epochs : number of training epochs.
    input_length : spectrum length L.
    num_classes : number of target classes.
    model_path : path where the trained ``state_dict`` is saved.
    num_workers : DataLoader worker processes (default 4, matching the
        previous hard-coded value; pass 0 for single-process loading).

    Returns
    -------
    dict with the final epoch's ``train_loss`` and ``train_accuracy``.
    """
    train_set = SpectralDataset(X_train, y_train, augment=True,
                                input_length=input_length)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=num_workers)
    model = AgroSpecCNN(input_length, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # BUGFIX: initialize the reported metrics before the loop; previously the
    # return statement raised NameError when n_epochs == 0.
    avg_loss, accuracy = 0.0, 0.0
    for epoch in range(n_epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
        avg_loss = total_loss / len(train_loader)
        accuracy = correct / total
        print(f"Epoch {epoch+1}/{n_epochs} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    torch.save(model.state_dict(), model_path)
    return {"train_loss": avg_loss, "train_accuracy": accuracy}
# Evaluation entry point.
def CNNTest(X_test, y_test, BATCH_SIZE, input_length, num_classes, model_path):
    """Evaluate a saved AgroSpecCNN checkpoint on (X_test, y_test).

    Parameters
    ----------
    X_test, y_test : array-like / pandas — (N, input_length) spectra and
        integer class labels.
    BATCH_SIZE : evaluation batch size.
    input_length : spectrum length L.
    num_classes : number of target classes.
    model_path : path to a ``state_dict`` saved by ``CNNTrain``.

    Returns
    -------
    dict with test loss/accuracy, weighted precision/recall/F1, and the
    confusion matrix.
    """
    test_set = SpectralDataset(X_test, y_test, augment=False,
                               input_length=input_length)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
    model = AgroSpecCNN(input_length, num_classes).to(device)
    # BUGFIX: map_location lets checkpoints trained on GPU load on
    # CPU-only machines (and vice versa).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_preds, all_targets = [], []
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss = criterion(outputs, y)
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(y.cpu().numpy())
    # zero_division=0 matches sklearn's default result for classes with no
    # predicted samples, but without emitting UndefinedMetricWarning.
    metrics = {
        "test_loss": total_loss / len(test_loader),
        "test_accuracy": correct / total,
        "precision": precision_score(all_targets, all_preds, average='weighted',
                                     zero_division=0),
        "recall": recall_score(all_targets, all_preds, average='weighted',
                               zero_division=0),
        "f1": f1_score(all_targets, all_preds, average='weighted'),
        "confusion_matrix": confusion_matrix(all_targets, all_preds)
    }
    return metrics
# Unified train-then-test driver for the CNN.
def CNN_deepseek(X_train, X_test, y_train, y_test, BATCH_SIZE, n_epochs,
                 input_length, num_classes, model_path):
    """Train AgroSpecCNN, then evaluate the saved checkpoint.

    Returns a ``(train_metrics, test_metrics)`` tuple — the dicts produced
    by ``CNNTrain`` and ``CNNTest`` respectively.
    """
    training_results = CNNTrain(X_train, y_train, BATCH_SIZE, n_epochs,
                                input_length, num_classes, model_path)
    evaluation_results = CNNTest(X_test, y_test, BATCH_SIZE,
                                 input_length, num_classes, model_path)
    return training_results, evaluation_results