- import numpy as np
- import torch
- from torchvision.datasets import mnist
- import torchvision.transforms as transforms
- from torch.utils.data import DataLoader
- import torch.nn.functional as F
- import torch.optim as optim
- from torch import nn
- from sklearn.metrics import confusion_matrix
- import matplotlib.pyplot as plt
- import seaborn as sns
- import csv
- import pandas as pd
# Hyperparameters.
train_batch_size = test_batch_size = 64  # same batch size for both loaders
learning_rate = 1e-3  # step size passed to the Adam optimizer
num_epochs = 10  # full passes over the training set
-
# Preprocessing: ToTensor() scales pixels to [0, 1], then Normalize applies
# (x - 0.5) / 0.5 to map them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Load the MNIST train/test splits from ./data. Only the train split passes
# download=True; presumably that download also fetches the test files —
# NOTE(review): confirm for the installed torchvision version.
train_dataset = mnist.MNIST(root='data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST(root='data', train=False, transform=transform)

# Batch the splits; only the training data is shuffled each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
# Define the CNN model.
class CNN(nn.Module):
    """Two convolutional + two fully connected layers for 1x28x28 images.

    Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, so the
    flattened feature vector has 64 * 4 * 4 = 1024 entries (fc1's input).
    The forward pass returns per-class log-probabilities over 10 classes.

    NOTE: when paired with nn.CrossEntropyLoss, log_softmax is applied a
    second time inside the loss; log_softmax is idempotent, so the loss
    value equals NLLLoss on these outputs.
    """

    def __init__(self):
        super().__init__()
        # Construction order fixes both the state_dict key order and the
        # order in which the RNG is consumed for weight initialisation.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(64 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(F.max_pool2d(conv(x), kernel_size=2))
        features = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(features))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
# Initialise the model, optimizer and loss function.
model = CNN()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Per-epoch loss and accuracy history for training and testing.
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []
def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean_loss, accuracy_percent). The loss is averaged per *sample*
        rather than per batch, so a short final batch is not over-weighted.
        Accuracy is the percentage of samples whose arg-max prediction
        matches the target.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Assumes criterion uses reduction='mean' (the nn.CrossEntropyLoss
        # default); scale the batch mean back to a sum so the epoch average
        # is weighted by sample count.
        loss_sum += loss.item() * batch_size
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (mean_loss, accuracy_percent, labels, preds). The loss is averaged
        per sample. ``labels`` and ``preds`` are flat lists of ints, one
        entry per sample, in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    labels = []
    preds = []
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # Same reduction='mean' assumption as in _train_one_epoch.
            loss_sum += criterion(output, target).item() * target.size(0)
            pred = output.argmax(dim=1)  # 1-D, so preds stays a flat int list
            correct += pred.eq(target).sum().item()
            labels.extend(target.tolist())
            preds.extend(pred.tolist())
    count = len(labels)
    if count == 0:
        raise ValueError("evaluation loader yielded no samples")
    return loss_sum / count, 100.0 * correct / count, labels, preds


for epoch in range(num_epochs):
    train_loss, train_accuracy = _train_one_epoch(model, train_loader, optimizer, criterion)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # all_labels / all_preds are replaced every epoch, so after the loop they
    # hold the final epoch's test-set labels and predictions.
    test_loss, test_accuracy, all_labels, all_preds = _evaluate(model, test_loader, criterion)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the training history: one row per epoch, with columns in the order
# named by the header. np.savetxt writes the header as a "# "-prefixed line,
# so np.loadtxt (comments='#' by default) still reads the numeric table.
data = np.column_stack((train_losses, test_losses, train_accuracies, test_accuracies))
np.savetxt(
    "results.txt",
    data,
    header="train_loss test_loss train_accuracy test_accuracy",
)
# Plot the loss curves. The x values are 1-based epoch numbers, matching the
# "Epoch [k/N]" numbering printed during training; without explicit x values
# matplotlib would plot against indices 0..N-1.
epoch_axis = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 2))
plt.plot(epoch_axis, train_losses, label='Train Loss', color='blue')
plt.plot(epoch_axis, test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()
-
# Plot the accuracy curves against 1-based epoch numbers, matching the
# "Epoch [k/N]" numbering printed during training. Accuracies are stored as
# percentages (100 * correct / total), so the y-axis label carries the unit.
epoch_axis = range(1, len(train_accuracies) + 1)
plt.figure(figsize=(10, 2))
plt.plot(epoch_axis, train_accuracies, label='Train Accuracy', color='red')
plt.plot(epoch_axis, test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()
# Confusion matrix of the final epoch's test-set predictions
# (sklearn convention: rows are true labels, columns are predicted labels).
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
ax = sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()