• PyTorch使用神经网络进行手写数字识别实战(附源码,包括损失图像和准确率图像)


    全部源码请点赞关注收藏后评论区留言即可~~~

    下面使用torchvision.datasets.MNIST构建手写数字数据集。

    1:数据预处理

    PyTorch提供了torchvision.transforms用于处理数据及数据增强,它可以将数据从[0,255]映射到[0,1]

    2:读取训练数据

    准备好处理数据的流程后,就可以读取用于训练的数据了,torch.util.data.DataLoader提供了迭代数据,随机抽取数据,批量化数据等等功能 读取效果如下

    预处理过后的数据如下

     

    3:构建神经网络模型 

     下面构建用于识别手写数字的神经网络模型

    1. class MLP(nn.Module):
    2. def __init__(self):
    3. super(MLP,self).__init__()
    4. self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
    5. self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
    6. self.outputlayer=nn.Sequential(nn.Linear(256,10))
    7. def forward(self,x):
    8. x=x.view(x.size(0),-1)
    9. x=self.inputlayer(x)
    10. x=self.hiddenlayer(x)
    11. x=self.outputlayer(x)
    12. return x

    可以直接通过打印nn.Module的对象看到其网络结构

    4:模型评估 

     在准备好数据和模型后,就可以训练模型了,下面分别定义了数据处理和加载流程,模型,优化器,损失函数以及用准确率评估模型能力。

    得到的结果如下

    训练一次 可以看出比较混乱 没有说明规律可言 

    训练五次的损失函数如下 可见随着训练次数的增加是逐渐收敛的,规律也非常明显

     

     

     准确率图像如下

    最后 部分源码如下

    1. import torch
    2. import torchvision
    3. import torch.nn as nn
    4. from torch import optim
    5. from tqdm import tqdm
    6. import torch.utils.data.dataset
    7. mnist=torchvision.datasets.MNIST(root='~',train=True,download=True)
    8. for i,j in enumerate(np.random.randint(0,len(mnist),(10,))):
    9. data,label=mnist[j]
    10. plt.subplot(2,5,i+1)
    11. plt.show()
    12. trans=transforms.Compose(
    13. [
    14. transforms.ToTensor(),
    15. transforms.Normalize((0.1307,),(0.3081,))
    16. ]
    17. )
    18. normalized=trans(mnist[0][0])
    19. from torchvision import transforms
    20. mnist=torchvision.datasets.MNIST(root='~',train=True,download=True,transform=trans)
    21. def imshow(img):
    22. img=img*0.3081+0.1307
    23. npimg=img.numpy()
    24. plt.imshow(np.transpose(npimg,(1,2,0)))
    25. dataloader=DataLoader(mnist,batch_size=4,shuffle=True,num_workers=0)
    26. images,labels=next(iter(dataloader))
    27. imshow(torchvision.utils.make_grid(images))
    28. class MLP(nn.Module):
    29. def __init__(self):
    30. super(MLP,self).__init__()
    31. self.inputlayer=nn.Sequential(nn.Linear(28*28,256),nn.ReLU(),nn.Dropout(0.2))
    32. self.hiddenlayer=nn.Sequential(nn.Linear(256,256),nn.ReLU(),nn.Dropout(0.2))
    33. self.outputlayer=nn.Sequential(nn.Linear(256,10))
    34. def forward(self,x):
    35. x=x.view(x.size(0),-1)
    36. x=self.inputlayer(x)
    37. x=self.hiddenlayer(x)
    38. x=self.outputlayer(x)
    39. return x
    40. print(MLP())
    41. trans=transforms.Compose(
    42. [
    43. transforms.ToTensor(),
    44. transforms.Normalize((0.1307,),(0.3081,))
    45. ]
    46. )
    47. al=torchvision.datasets.MNIST(root='~',train=False,download=True,transform=trans)
    48. trainloader=DataLoader(mnist_train,batch_size=16,shuffle=True,num_workers=0)
    49. valloader=DataLoader(mnist_val,batch_size=16,shuffle=True,num_workers=0)
    50. #模型
    51. model=MLP()
    52. #优化器
    53. optimizer=oD(model.parameters(),lr=0.01,momentum=0.9)
    54. #损失函数
    55. celoss=nn.ssEntropyLoss()
    56. best_acc=0
    57. #计算准确率
    58. def accuracy(pred,target):
    59. pred_label=torch.amax(pred,1)
    60. correct=sum(pred_label==target).to(torch.float)
    61. return correct,len(pred)
    62. acc={'train':[],"val}
    63. loss_all={'train':[],"val":[]}
    64. for epoch in tqdm(range(5)):
    65. model.eval()
    66. numer_val,denumer_val,loss_tr=0.,0.,0.
    67. with torch.no_grad():
    68. for data,target in valloader:
    69. output=model(data)
    70. loss=celoss(output,target)
    71. loss_tr+=loss.data
    72. num,denum=accuracy(output,target)
    73. numer_val+=num
    74. denumer_val+=denum
    75. #设置为训练模式
    76. model.train()
    77. numer_tr,denumer_tr,loss_val=0.,0.,0.
    78. for data,target in trainloader:
    79. optizer.zero_grad()
    80. output=model(data)
    81. loss=celoss(output,target)
    82. loss_val+=loss.data
    83. loss.backward()
    84. optimer.step()
    85. num,denum=accuracy(output,target)
    86. numer_tr+=num
    87. denumer_tr+=denum
    88. loss_all['train'].append(loss_tr/len(trainloader))
    89. loss_all['val'].aend(lss_val/len(valloader))
    90. acc['train'].pend(numer_tr/denumer_tr)
    91. acc['val'].append(numer_val/denumer_val)
    92. """
    93. plt.plot(loss_all['train'])
    94. plt.plot(loss_all['val'])
    95. """
    96. plt.plot(acc['train'])
    97. plt.plot(acc['val'])
    98. plt.show()

     创作不易 觉得有帮助请点赞关注收藏~~~

  • 相关阅读:
    webpack基础版及其常用插件分享超详细~~
    upload-labs关卡11(双写后缀名绕过)通关思路
    Linux | 关于入门Linux你有必要了解的指令
    gdb连接qemu调试stm32程序
    win11安装mysql-8.0.28
    排查K8S的WSS内存一致升高
    python如何实现数据可视化,如何用python做可视化
    Java序列化与反序列化
    Windows系统Maven下载安装
    Spring boot 自定义 Starter 及 自动配置
  • 原文地址:https://blog.csdn.net/jiebaoshayebuhui/article/details/127783771