• pytorch的使用:卷积神经网络模块


    1.读取数据

    • 分别构建训练集和测试集(验证集)
    • DataLoader来迭代取数据
    • 使用transforms将数据转换为tensor格式
    1. # 定义超参数
    2. input_size = 28 #图像的总尺寸28*28
    3. num_classes = 10 #标签的种类数
    4. num_epochs = 3 #训练的总循环周期
    5. batch_size = 64 #一个撮(批次)的大小,64张图片
    6. # 训练集
    7. train_dataset = datasets.MNIST(root='./data',
    8. train=True,
    9. transform=transforms.ToTensor(),
    10. download=True)
    11. # 测试集
    12. test_dataset = datasets.MNIST(root='./data',
    13. train=False,
    14. transform=transforms.ToTensor())
    15. # 构建batch数据
    16. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
    17. batch_size=batch_size,
    18. shuffle=True)
    19. test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
    20. batch_size=batch_size,
    21. shuffle=True)

     

    1.卷积神经网络模块

    pytorch与tensorflow 2相比,pytorch更注重过程,pytoch卷积模块需要指定输入通道数和输出通道数,卷积核的参数总数为卷积核K x 卷积核K x 输入通道数 x 输出通道数,卷积模块padding也需要自己计算,如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1,pytoch在计算下一层特征大小时,采用向下取整的原则,另外pytorch特征维度为batch*channels*h*w,channels在第二维度。

    1. class CNN(nn.Module):
    2. def __init__(self):
    3. super(CNN, self).__init__()
    4. self.conv1 = nn.Sequential( # 输入大小 (1, 28, 28)
    5. nn.Conv2d(
    6. in_channels=1, # 灰度图
    7. out_channels=16, # 要得到几多少个特征图
    8. kernel_size=5, # 卷积核大小
    9. stride=1, # 步长
    10. padding=2, # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
    11. ), # 输出的特征图为 (16, 28, 28)
    12. nn.ReLU(), # relu层
    13. nn.MaxPool2d(kernel_size=2), # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)
    14. )
    15. self.conv2 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)
    16. nn.Conv2d(16, 32, 5, 1, 2), # 输出 (32, 14, 14)
    17. nn.ReLU(), # relu层
    18. nn.Conv2d(32, 32, 5, 1, 2),
    19. nn.ReLU(),
    20. nn.MaxPool2d(2), # 输出 (32, 7, 7)
    21. )
    22. self.conv3 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)
    23. nn.Conv2d(32, 64, 5, 1, 2), # 输出 (32, 14, 14)
    24. nn.ReLU(), # 输出 (32, 7, 7)
    25. )
    26. self.out = nn.Linear(64 * 7 * 7, 10) # 全连接层得到的结果
    27. def forward(self, x):
    28. x = self.conv1(x)
    29. x = self.conv2(x)
    30. x = self.conv3(x)
    31. x = x.view(x.size(0), -1) # flatten操作,结果为:(batch_size, 32 * 7 * 7)
    32. output = self.out(x)
    33. return output

    3.训练网络模型 

    定义准确率作为验证集评估指标 

    1. def accuracy(predictions, labels):
    2. pred = torch.max(predictions.data, 1)[1]
    3. rights = pred.eq(labels.data.view_as(pred)).sum()
    4. return rights, len(labels)

     

    1. # 实例化
    2. net = CNN()
    3. #损失函数
    4. criterion = nn.CrossEntropyLoss()
    5. #优化器
    6. optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法
    7. #开始训练循环
    8. for epoch in range(num_epochs):
    9. #当前epoch的结果保存下来
    10. train_rights = []
    11. for batch_idx, (data, target) in enumerate(train_loader): #针对容器中的每一个批进行循环
    12. net.train()
    13. output = net(data)
    14. loss = criterion(output, target)
    15. optimizer.zero_grad()
    16. loss.backward()
    17. optimizer.step()
    18. right = accuracy(output, target)
    19. train_rights.append(right)
    20. if batch_idx % 100 == 0:
    21. net.eval()
    22. val_rights = []
    23. for (data, target) in test_loader:
    24. output = net(data)
    25. right = accuracy(output, target)
    26. val_rights.append(right)
    27. #准确率计算
    28. train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
    29. val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))
    30. print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
    31. epoch, batch_idx * batch_size, len(train_loader.dataset),
    32. 100. * batch_idx / len(train_loader),
    33. loss.data,
    34. 100. * train_r[0].numpy() / train_r[1],
    35. 100. * val_r[0].numpy() / val_r[1]))
    当前epoch: 0 [0/60000 (0%)]	损失: 2.300918	训练集准确率: 10.94%	测试集正确率: 10.10%
    当前epoch: 0 [6400/60000 (11%)]	损失: 0.204191	训练集准确率: 78.06%	测试集正确率: 93.31%
    当前epoch: 0 [12800/60000 (21%)]	损失: 0.039503	训练集准确率: 86.51%	测试集正确率: 96.69%
    当前epoch: 0 [19200/60000 (32%)]	损失: 0.057866	训练集准确率: 89.93%	测试集正确率: 97.54%
    当前epoch: 0 [25600/60000 (43%)]	损失: 0.069566	训练集准确率: 91.68%	测试集正确率: 97.68%
    当前epoch: 0 [32000/60000 (53%)]	损失: 0.228793	训练集准确率: 92.85%	测试集正确率: 98.18%
    当前epoch: 0 [38400/60000 (64%)]	损失: 0.111003	训练集准确率: 93.72%	测试集正确率: 98.16%
    当前epoch: 0 [44800/60000 (75%)]	损失: 0.110226	训练集准确率: 94.28%	测试集正确率: 98.44%
    当前epoch: 0 [51200/60000 (85%)]	损失: 0.014538	训练集准确率: 94.78%	测试集正确率: 98.60%
    当前epoch: 0 [57600/60000 (96%)]	损失: 0.051019	训练集准确率: 95.14%	测试集正确率: 98.45%
    当前epoch: 1 [0/60000 (0%)]	损失: 0.036383	训练集准确率: 98.44%	测试集正确率: 98.68%
    当前epoch: 1 [6400/60000 (11%)]	损失: 0.088116	训练集准确率: 98.50%	测试集正确率: 98.37%
    当前epoch: 1 [12800/60000 (21%)]	损失: 0.120306	训练集准确率: 98.59%	测试集正确率: 98.97%
    当前epoch: 1 [19200/60000 (32%)]	损失: 0.030676	训练集准确率: 98.63%	测试集正确率: 98.83%
    当前epoch: 1 [25600/60000 (43%)]	损失: 0.068475	训练集准确率: 98.59%	测试集正确率: 98.87%
    当前epoch: 1 [32000/60000 (53%)]	损失: 0.033244	训练集准确率: 98.62%	测试集正确率: 99.03%
    当前epoch: 1 [38400/60000 (64%)]	损失: 0.024162	训练集准确率: 98.67%	测试集正确率: 98.81%
    当前epoch: 1 [44800/60000 (75%)]	损失: 0.006713	训练集准确率: 98.69%	测试集正确率: 98.17%
    当前epoch: 1 [51200/60000 (85%)]	损失: 0.009284	训练集准确率: 98.69%	测试集正确率: 98.97%
    当前epoch: 1 [57600/60000 (96%)]	损失: 0.036536	训练集准确率: 98.68%	测试集正确率: 98.97%
    当前epoch: 2 [0/60000 (0%)]	损失: 0.125235	训练集准确率: 98.44%	测试集正确率: 98.73%
    当前epoch: 2 [6400/60000 (11%)]	损失: 0.028075	训练集准确率: 99.13%	测试集正确率: 99.17%
    当前epoch: 2 [12800/60000 (21%)]	损失: 0.029663	训练集准确率: 99.26%	测试集正确率: 98.39%
    当前epoch: 2 [19200/60000 (32%)]	损失: 0.073855	训练集准确率: 99.20%	测试集正确率: 98.81%
    当前epoch: 2 [25600/60000 (43%)]	损失: 0.018130	训练集准确率: 99.16%	测试集正确率: 99.09%
    当前epoch: 2 [32000/60000 (53%)]	损失: 0.006968	训练集准确率: 99.15%	测试集正确率: 99.11%

     

     

     

  • 相关阅读:
    从CentOS6升级到CentOS7要注意的一些事项
    碰瓷“一带一路”
    android Java工程配置kotlin环境
    Blog搭建:pycharm+虚拟环境+django
    【TES600】青翼科技基于XC7K325T与TMS320C6678的通用信号处理平台
    Zabbix预处理和数据开源节流
    六、c++代码中的安全风险-fopen
    提示工程101|与 AI 交谈的技巧和艺术
    Ubuntu-18.04本地化部署Rustdesk服务器
    01 容器端口映射导致 302 存在问题 以及 nginx 对于 302 的 Location 的重写
  • 原文地址:https://blog.csdn.net/qq_52053775/article/details/126162918