• PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)


    全部源码请点赞关注收藏后评论区留言即可~~~

    下面使用torchvision.datasets.MNIST构建手写数字数据集。

    1:数据预处理

    PyTorch提供了torchvision.transforms用于处理数据及数据增强,它可以将数据从[0,255]映射到[0,1]

    2:读取训练数据

    准备好处理数据的流程后,就可以读取用于训练的数据了,torch.util.data.DataLoader提供了迭代数据,随机抽取数据,批量化数据等等功能 读取效果如下

    预处理过后的数据如下

     

    3:构建神经网络模型 

     下面构建用于识别手写数字的神经网络模型

    1. class MLP(nn.Module):
    2. def __init__(self):
    3. super(MLP,self).__init__()
    4. self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
    5. self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
    6. self.outputlayer=nn.Sequential(nn.Linear(256,10))
    7. def forward(self,x):
    8. x=x.view(x.size(0),-1)
    9. x=self.inputlayer(x)
    10. x=self.hiddenlayer(x)
    11. x=self.outputlayer(x)
    12. return x

    可以直接通过打印nn.Module的对象看到其网络结构

    4:模型评估 

     在准备好数据和模型后,就可以训练模型了,下面分别定义了数据处理和加载流程,模型,优化器,损失函数以及用准确率评估模型能力。

    得到的结果如下

    训练一次 可以看出比较混乱 没有说明规律可言 

    训练五次的损失函数如下 可见随着训练次数的增加是逐渐收敛的,规律也非常明显

     

     

     准确率图像如下

    最后 部分源码如下

    1. import torch
    2. import torchvision
    3. import torch.nn as nn
    4. from torch import optim
    5. from tqdm import tqdm
    6. import torch.utils.data.dataset
    7. mnist=torchvision.datasets.MNIST(root='~',train=True,download=True)
    8. for i,j in enumerate(np.random.randint(0,len(mnist),(10,))):
    9. data,label=mnist[j]
    10. plt.subplot(2,5,i+1)
    11. plt.show()
    12. trans=transforms.Compose(
    13. [
    14. transforms.ToTensor(),
    15. transforms.Normalize((0.1307,),(0.3081,))
    16. ]
    17. )
    18. normalized=trans(mnist[0][0])
    19. from torchvision import transforms
    20. mnist=torchvision.datasets.MNIST(root='~',train=True,download=True,transform=trans)
    21. def imshow(img):
    22. img=img*0.3081+0.1307
    23. npimg=img.numpy()
    24. plt.imshow(np.transpose(npimg,(1,2,0)))
    25. dataloader=DataLoader(mnist,batch_size=4,shuffle=True,num_workers=0)
    26. images,labels=next(iter(dataloader))
    27. imshow(torchvision.utils.make_grid(images))
    28. class MLP(nn.Module):
    29. def __init__(self):
    30. super(MLP,self).__init__()
    31. self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
    32. self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
    33. self.outputlayer=nn.Sequential(nn.Linear(256,10))
    34. def forward(self,x):
    35. x=x.view(x.size(0),-1)
    36. x=self.inputlayer(x)
    37. x=self.hiddenlayer(x)
    38. x=self.outputlayer(x)
    39. return x
    40. print(MLP())
    41. trans=transforms.Compose(
    42. [
    43. transforms.ToTensor(),
    44. transforms.Normalize((0.1307,),(0.3081,))
    45. ]
    46. )
    47. al=torchvision.datasets.MNIST(root='~',train=False,download=True,transform=trans)
    48. trainloader=DataLoader(mnist_train,batch_size=16,shuffle=True,num_workers=0)
    49. valloader=DataLoader(mnist_val,batch_size=16,shuffle=True,num_workers=0)
    50. #模型
    51. model=MLP()
    52. #优化器
    53. optimizer=oD(model.parameters(),lr=0.01,momentum=0.9)
    54. #损失函数
    55. celoss=nn.ssEntropyLoss()
    56. best_acc=0
    57. #计算准确率
    58. def accuracy(pred,target):
    59. pred_label=torch.amax(pred,1)
    60. correct=sum(pred_label==target).to(torch.float)
    61. return correct,len(pred)
    62. acc={'train':[],"val}
    63. loss_all={'train':[],"val":[]}
    64. for epoch in tqdm(range(5)):
    65. model.eval()
    66. numer_val,denumer_val,loss_tr=0.,0.,0.
    67. with torch.no_grad():
    68. for data,target in valloader:
    69. output=model(data)
    70. loss=celoss(output,target)
    71. loss_tr+=loss.data
    72. num,denum=accuracy(output,target)
    73. numer_val+=num
    74. denumer_val+=denum
    75. #设置为训练模式
    76. model.train()
    77. numer_tr,denumer_tr,loss_val=0.,0.,0.
    78. for data,target in trainloader:
    79. optizer.zero_grad()
    80. output=model(data)
    81. loss=celoss(output,target)
    82. loss_val+=loss.data
    83. loss.backward()
    84. optimer.step()
    85. num,denum=accuracy(output,target)
    86. numer_tr+=num
    87. denumer_tr+=denum
    88. loss_all['train'].append(loss_tr/len(trainloader))
    89. loss_all['val'].aend(lss_val/len(valloader))
    90. acc['train'].pend(numer_tr/denumer_tr)
    91. acc['val'].append(numer_val/denumer_val)
    92. """
    93. plt.plot(loss_all['train'])
    94. plt.plot(loss_all['val'])
    95. """
    96. plt.plot(acc['train'])
    97. plt.plot(acc['val'])
    98. plt.show()

     创作不易 觉得有帮助请点赞关注收藏~~~

  • 相关阅读:
    奇瑞新能源无界Pro 高科技配置+强续航实力征战纯电市场
    如何通过SK集成chatGPT实现DotNet项目工程化?
    【牛客网-面试必刷TOP101】二分查找题目
    SQL数据库的基本操作流程
    事务隔离级别
    配置与管理Samba服务器实例
    基于 SpringBoot + MyBatis 的在线音乐播放器
    数据库问题记录(粗略版)oracle、mysql等主流数据库通用
    代码随想录算法训练营第二十八天|LeetCode93 复原IP地址、LeetCode78 子集
    2022.5.29-参加工信部蓝桥杯青少组国赛(二等奖)
  • 原文地址:https://blog.csdn.net/jiebaoshayebuhui/article/details/127783771