• 深度学习手写字符识别:推理过程


    说明

    本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别
    第一个深度学习实例手写字符识别

    深度学习环境配置

    可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
    Windows11搭建GPU版本PyTorch环境详细过程

    数据集

    手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
    作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

    序号说明
    发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
    发布时间1998
    背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
    简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

    手写字符识别模型训练

    可以参考下篇博客:
    深度学习手写字符识别:训练模型

    手写字符识别推理过程

    1. 选用训练好的模型output/params_yl.pth
      在这里插入图片描述

    2. Pycharm运行AI_course/classify_pytorch/test_mnist.py文件,输入的手写字符图片里的数字是“4”。
      在这里插入图片描述

    3. 推理源码如下:

    import torch
    import cv2
    from torch.autograd import Variable
    from torchvision import transforms
    from models.cnn import Net
    from toonnx import to_onnx
    
    use_cuda = False
    model = Net(10)
    # 注意:此处应把pth文件改为你训练出来的params_x.pth,x为epoch编号,
    # 一般来讲,编号越大,且训练集(train)和验证集(val)上准确率差别越小的(避免过拟合),效果越好。
    model.load_state_dict(torch.load('output/params_yl.pth'))
    # model = torch.load('output/model.pth')
    model.eval()
    if use_cuda and torch.cuda.is_available():
        model.cuda()
    
    #to_onnx(model, 3, 28, 28, 'output/params.onnx')
    
    img = cv2.imread('4_00440.jpg')
    img = cv2.resize(img, (28, 28))
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0)
    if use_cuda and torch.cuda.is_available():
        prediction = model(Variable(img_tensor.cuda()))
    else:
        prediction = model(Variable(img_tensor))
    pred = torch.max(prediction, 1)[1]
    print(prediction)
    print(pred)
    cv2.imshow("image", img)
    cv2.waitKey(0)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    1. 运行结果:打印其张量,可以看到用训练模型output/params_yl.pth的推理后结果,输入一张手下字4,最终推理结果是4;打印出0-9数字的概率,可以看到“4”的概率最高。
      在这里插入图片描述
    2. 验证推理有效性:为了验证其推理的真实性,重新手写一个手写字符。注意,得和训练集里的字符一样,黑底白字形式。
    • 手写“0”,识别出来的是“0”
      在这里插入图片描述
      在这里插入图片描述
    • 手写“3”,识别出来的是“3”
      在这里插入图片描述
      在这里插入图片描述
    • 手写“5”,识别出来的是“7”,可以看到识别错了。
      在这里插入图片描述
      在这里插入图片描述
    1. 验证推理结果,额外手写了3个字符,未使用测试集里的手写字符验证,对了2个,错了1个;识别率有待提高,可能需要更多次的epoch。

    后续

    • 下一篇章跟着视频进行手写字符识别的代码解析。
  • 相关阅读:
    LeetCode每日一题——795. 区间子数组个数
    想做WMS仓库管理系统,找了好久才找到云表
    Elasticsearch:使用最新的 Python client 8.0 来创建索引并搜索
    java中static
    【Qt之QMap】介绍及示例
    二维数组的动态创建和释放
    SpringCloud学习笔记(四)
    驱动程序开发:Linux内核自带LED使能
    Java实验案例(一)
    定时执行专家V6.1版发布,附更新日志及新功能截图
  • 原文地址:https://blog.csdn.net/yanceyxin/article/details/136286679