import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||
from sklearn.preprocessing import StandardScaler
|
||
from sklearn.model_selection import train_test_split
|
||
import matplotlib.pyplot as plt
|
||
from sklearn.manifold import TSNE
|
||
import pandas as pd
|
||
|
||
# Device configuration: prefer GPU when available, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Spectral dataset with optional on-the-fly data augmentation.
class SpectralDataset(Dataset):
    """1-D spectral dataset for Conv1d models.

    Wraps a (N, L) feature matrix and integer labels as tensors; samples
    are returned with shape (1, L).  When ``augment`` is True, each fetch
    applies random perturbations (additive noise, circular shift, local
    occlusion) to a copy of the sample.

    Args:
        X: array-like or DataFrame of shape (N, L) holding the spectra.
        y: array-like or Series of length N with integer class labels.
        augment: if True, apply random augmentation in ``__getitem__``.
        input_length: spectrum length L; bounds the occlusion window start.
    """

    def __init__(self, X, y, augment=False, input_length=462):
        # Accept pandas containers as well as plain numpy arrays.
        if isinstance(X, pd.DataFrame):
            X = X.values
        if isinstance(y, (pd.Series, pd.DataFrame)):
            y = y.values

        # X must be (N, L); insert a channel axis -> (N, 1, L).
        assert len(X.shape) == 2, f"Expected X to be 2D, got {X.shape}"
        self.X = torch.tensor(X[:, np.newaxis, :], dtype=torch.float32)  # (N, 1, L)
        self.y = torch.tensor(y, dtype=torch.long)  # y should be 1-D
        self.augment = augment
        self.input_length = input_length

    def __getitem__(self, index):
        x = self.X[index]  # (1, L) view into the stored tensor
        y = self.y[index]

        if self.augment:
            # BUG FIX: clone before augmenting.  The original in-place ops
            # (`x += ...`, slice assignment) wrote through the view into
            # self.X, so noise and occlusions accumulated across epochs.
            x = x.clone()

            # Additive Gaussian noise (p = 0.7).
            if torch.rand(1) < 0.7:
                noise_level = torch.rand(1) * 0.05
                x += noise_level * torch.randn_like(x)

            # Circular spectral shift of up to +/-5 points (p = 0.5).
            if torch.rand(1) < 0.5:
                shift = torch.randint(-5, 5, (1,)).item()
                x = torch.roll(x, shifts=shift, dims=-1)

            # Local occlusion: zero out a 10-point window (p = 0.3).
            if torch.rand(1) < 0.3:
                start = torch.randint(0, self.input_length - 10, (1,)).item()
                x[0, start:start + 10] = 0.0

        return x, y

    def __len__(self):
        return len(self.X)
# Squeeze-and-excitation style channel attention for 1-D spectra.
class SpectralAttention(nn.Module):
    """Reweight channels by a sigmoid gate learned from globally pooled features."""

    def __init__(self, channel, reduction=8):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.GELU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _ = x.size()
        # Squeeze: global average over the length dimension -> (B, C).
        squeezed = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel gate in (0, 1), broadcast back over length.
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate.expand_as(x)
# 1-D CNN for spectral classification with channel attention.
class AgroSpecCNN(nn.Module):
    """Three-stage Conv1d feature extractor followed by a small MLP classifier.

    Args:
        input_length: length L of the input spectrum; inputs are (B, 1, L).
        num_classes: number of output classes (logits, no softmax applied).
    """

    def __init__(self, input_length=462, num_classes=21):
        super().__init__()
        self.input_length = input_length
        # Module order is significant: it fixes the state_dict keys.
        feature_layers = [
            # Stage 1: wide kernel conv, attention, halve the length.
            nn.Conv1d(1, 64, 5, padding=2),
            nn.BatchNorm1d(64),
            nn.GELU(),
            SpectralAttention(64),
            nn.MaxPool1d(2),
            # Stage 2: deeper conv, attention, adaptive pool to L // 2.
            nn.Conv1d(64, 128, 5, padding=2),
            nn.BatchNorm1d(128),
            nn.GELU(),
            SpectralAttention(128),
            nn.AdaptiveAvgPool1d(self.input_length // 2),
            # Stage 3: final conv, collapse the length dimension to 1.
            nn.Conv1d(128, 256, 5, padding=2),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        self.features = nn.Sequential(*feature_layers)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        feats = self.features(x)          # (B, 256, 1)
        flat = feats.view(feats.size(0), -1)  # (B, 256)
        return self.classifier(flat)      # (B, num_classes)
# Training loop for AgroSpecCNN.
def CNNTrain(X_train, y_train, BATCH_SIZE, n_epochs, input_length, num_classes, model_path):
    """Train AgroSpecCNN with augmentation and save its weights.

    Trains for ``n_epochs`` with Adam (lr 1e-3, weight decay 1e-4), prints
    per-epoch mean loss and accuracy, and writes the final state_dict to
    ``model_path``.  Returns a dict with the last epoch's train loss and
    accuracy.
    """
    loader = DataLoader(
        SpectralDataset(X_train, y_train, augment=True, input_length=input_length),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )

    model = AgroSpecCNN(input_length, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    for epoch in range(n_epochs):
        model.train()
        running_loss = 0.0
        hits = 0
        seen = 0

        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = logits.max(1)
            seen += batch_y.size(0)
            hits += predicted.eq(batch_y).sum().item()

        print(f"Epoch {epoch+1}/{n_epochs} - Loss: {running_loss / len(loader):.4f}, Accuracy: {hits / seen:.4f}")

    torch.save(model.state_dict(), model_path)
    return {"train_loss": running_loss / len(loader), "train_accuracy": hits / seen}
# Evaluation of a saved AgroSpecCNN checkpoint.
def CNNTest(X_test, y_test, BATCH_SIZE, input_length, num_classes, model_path):
    """Evaluate the checkpoint at ``model_path`` on held-out data.

    Loads the saved state_dict into a fresh AgroSpecCNN, runs inference
    without augmentation, and returns a dict with test loss, accuracy,
    weighted precision/recall/F1, and the confusion matrix.
    """
    test_set = SpectralDataset(X_test, y_test, augment=False, input_length=input_length)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

    model = AgroSpecCNN(input_length, num_classes).to(device)
    # BUG FIX: map_location lets CPU-only hosts load checkpoints that were
    # saved from a CUDA device; without it torch.load raises on such hosts.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    total_loss, correct, total = 0, 0, 0
    all_preds, all_targets = [], []
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            outputs = model(x)
            loss = criterion(outputs, y)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

            # Collect on CPU for sklearn metrics.
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    metrics = {
        "test_loss": total_loss / len(test_loader),
        "test_accuracy": correct / total,
        "precision": precision_score(all_targets, all_preds, average='weighted'),
        "recall": recall_score(all_targets, all_preds, average='weighted'),
        "f1": f1_score(all_targets, all_preds, average='weighted'),
        "confusion_matrix": confusion_matrix(all_targets, all_preds)
    }
    return metrics
# Unified entry point: train, then evaluate, with shared settings.
def CNN_deepseek(X_train, X_test, y_train, y_test, BATCH_SIZE, n_epochs, input_length, num_classes, model_path):
    """Run CNNTrain then CNNTest and return (train_metrics, test_metrics)."""
    return (
        CNNTrain(X_train, y_train, BATCH_SIZE, n_epochs, input_length, num_classes, model_path),
        CNNTest(X_test, y_test, BATCH_SIZE, input_length, num_classes, model_path),
    )