• Tensorflow图像识别 Tensorflow手写体识别(二)


    资源介绍

    我们从

    MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

    这条链接(MNIST官网)中下载好数据集,如下:

            下载下来以后整理成包含四个压缩包的文件MNIST_data(不要解压),然后把数据集直接拷贝到我们的代码目录下面,执行一个复制,粘贴到当前目录下面

    这是本次项目中要用到的所有文件,您可以下载我的链接:

    https://download.csdn.net/download/llf000000/86922664

    数据集

    数据集介绍

            之前我以为这60000个样本都是png图片,但是如果真的下载下来的话会很占内存。

            MNIST数据库由NIST的特殊数据库3和特殊数据库1构成,其中包含手写数字的二进制图像。
    NIST最初指定SD-3为训练集,SD-1为测试集。然而,SD-3比SD-1更干净、更容易识别。究其原因,可以发现SD-3是在人口普查局员工中收集的,而SD-1是在高中生中收集的。从学习实验中得出合理的结论要求实验结果不依赖于训练集的选择和样本集的测试。因此,有必要通过混合NIST的数据集来建立一个新的数据库。

            MNIST训练集由SD-3的30000个模式和SD-1的30000个模式组成。我们的测试集由来自SD-3的5000个模式和SD-1的5000个模式组成。60000个模式训练集包含了大约250个作家的例子。我们确保训练集和测试集的编写器集是不相交的。

            SD-1包含了58527位数字图像,由500位不同的作家编写。与SD-3不同,SD-3中来自每个writer的数据块按顺序出现,SD-1中的数据被置乱。SD-1的Writer标识是可用的,我们使用这些信息来解读Writer。然后我们把SD-1分成两部分:第一批250名作家写的字符进入了我们的新训练集。剩下的250个作家被放在我们的测试集中。因此,我们有两组,每一组有近30000个例子。(SD1加起来总数是6000)

    数据集格式

            文件中的所有整数都以多数非英特尔处理器使用的MSB first(高端)格式存储。英特尔处理器和其他低端计算机的用户必须翻转标头的字节。

    有4个文件:
    train-images-idx3-ubyte:训练集图像
    train-labels-idx1-ubyte:训练集标签
    t10k-images-idx3-ubyte:测试集图像
    t10k-labels-idx1-ubyte:测试集标签

    (1)训练集包含60000个示例。
    (2)测试集包含10000个示例。测试集的前5000个示例取自原始的NIST训练集。最后5000个是从最初的NIST测试集中提取的。前5000个比后5000个更干净、更简单。

    训练集是有60000个用例的,也就是说这个文件里面包含了60000个标签内容,每一个标签的值为0到9之间的一个数;

    参考链接:

    MNIST数据集的图片读取显示,并保存图片(python代码)_weixin_43094275的博客-CSDN博客_mnist图片显示

    代码查看手写体识别案例

            由于数据集每个图片直接下载下来不现实,,而这些数据集已经被神秘力量整理成可被代码识别的压缩包,通过查阅资料得知我们可以通过编写代码可视化这四个数据集。

    1. #!/usr/bin/env python
    2. # -*- coding:utf-8 -*-
    3. import tensorflow as tf
    4. from tensorflow.examples.tutorials.mnist import input_data
    5. import matplotlib.pyplot as plt
    6. # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签
    7. mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)
    8. # load data
    9. train_X = mnist.train.images
    10. train_Y = mnist.train.labels
    11. print(train_X.shape, train_Y.shape) # 输出训练集样本和标签的大小
    12. # 查看数据,例如训练集中第一个样本的内容和标签
    13. print(train_X[0]) # 是一个包含784个元素且值在[0,1]之间的向量
    14. print(train_Y[0])
    15. # 可视化样本,下面是输出了训练集中前4个样本
    16. fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
    17. ax = ax.flatten()
    18. for i in range(4):
    19. img = train_X[i].reshape(28, 28)
    20. # ax[i].imshow(img,cmap='Greys')
    21. ax[i].imshow(img)
    22. ax[0].set_xticks([])
    23. ax[0].set_yticks([])
    24. plt.tight_layout()
    25. plt.show()

     

    手写体识别案例

    1. # 03_mnist.py
    2. # 手写体识别案例
    3. # 模型:全连接模型
    4. import tensorflow as tf
    5. from tensorflow.examples.tutorials.mnist import input_data
    6. import pylab # 用于显示图片
    7. # 定义样本读取对象
    8. # 这个就是定义一个专门用于mnist数据集的读取的对象
    9. mnist = input_data.read_data_sets("MNIST_data/", # 数据集所在目录
    10. one_hot=True) # 标签是否采用独热编码
    11. # 定义占位符,用于表示图像数据、标签
    12. # 因为这些数据都要从样本中读进来,穿进来,所以我们要定义一个占位符
    13. x = tf.placeholder(tf.float32, [None, 784]) # 图像数据,N行784列
    14. y = tf.placeholder(tf.float32, [None, 10]) # 标签(图像真实类别), N行784列
    15. # 定义权重、偏置
    16. w = tf.Variable(tf.random_normal([784, 10])) # 权重,784行10列
    17. b = tf.Variable(tf.zeros([10])) # 偏置, 10个偏置 十路输出所以又10个偏置
    18. # 构建模型,计算预测结果
    19. pred_y = tf.nn.softmax(tf.matmul(x, w) + b)
    20. '''
    21. 把x和w相乘,n行784列的矩阵×784行10列的矩阵,产生一个n行十列的输出,输出的10个值加上偏置,
    22. 这就是神经网络的计算公式,然后把它交给softmax函数进行挤压,转换成0到1的相对概率,
    23. 这个就作为我们的预测值'''
    24. # 损失函数
    25. cross_entropy = -tf.reduce_sum(y * tf.log(pred_y), reduction_indices=1)
    26. '''
    27. 真实的值×预测的值求对数,然后求和,reduce_sum在指定的维度上求和
    28. '''
    29. cost = tf.reduce_mean(cross_entropy) # 求均值
    30. # 梯度下降优化器
    31. optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
    32. '''
    33. 学习率为0.01,然后调用这个对象的minimize方法,把损失函数的值优化到最小
    34. ,优化的目标函数就是cost
    35. '''
    36. batch_size = 100 # 批次大小
    37. saver = tf.train.Saver() # saver
    38. '''
    39. 用于模型的保存和加载
    40. '''
    41. model_path = "model/mnist/mnist_model.ckpt" # 模型路径
    42. '''
    43. checkpoint
    44. '''
    45. with tf.Session() as sess:
    46. sess.run(tf.global_variables_initializer()) # 初始化
    47. # 开始训练
    48. for epoch in range(10):
    49. # 计算总批次
    50. total_batch = int(mnist.train.num_examples / batch_size)
    51. avg_cost = 0.0
    52. for i in range(total_batch):
    53. # 从训练集读取一个批次的样本
    54. batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    55. '''
    56. xs是图像数据,ys是标签
    57. '''
    58. params = {x: batch_xs, y: batch_ys} # 参数字典
    59. o, c = sess.run([optimizer, cost], # 执行的op
    60. feed_dict=params) # 喂入参数
    61. '''
    62. 第一个op 执行optimizer执行梯度下降
    63. 第二个op 执行cost取得损失函数的值
    64. '''
    65. avg_cost += (c / total_batch) # 计算平均损失值
    66. print("epoch:%d, cost=%.9f" % (epoch + 1, avg_cost))
    67. print("训练结束.")
    68. # 模型评估
    69. # 比较预测结果和真实结果,返回布尔类型的数组
    70. correct_pred = tf.equal(tf.argmax(pred_y, 1), # 求预测结果中最大值的索引
    71. tf.argmax(y, 1)) # 求真实结果中最大的索引
    72. # 将布尔类型数组转换为浮点数,并计算准确率
    73. # 因为计算均值、准确率公式相同,所以调用计算均值的函数计算准确率
    74. accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) # cast将correct_pred转换成浮点类型
    75. print("accuracy:", accuracy.eval({x: mnist.test.images, # 测试集下的图像数据
    76. y: mnist.test.labels})) # 测试集下图像的真实类别
    77. '''
    78. eval等价于放到session里面去run()
    79. '''
    80. # 保存模型
    81. save_path = saver.save(sess, model_path)
    82. print("模型已保存:", save_path)
    83. # 从测试集中随机读取2张图像,执行预测
    84. with tf.Session() as sess:
    85. sess.run(tf.global_variables_initializer())
    86. saver.restore(sess, model_path) # 加载模型
    87. # 从测试集中读取样本
    88. batch_xs, batch_ys = mnist.test.next_batch(2)
    89. output = tf.argmax(pred_y, 1) # 直接取出预测结果中的最大值
    90. output_val, predv = sess.run([output, pred_y], # 执行的op
    91. feed_dict={x: batch_xs}) # 预测,所以不需要传入标签
    92. print("预测最终结果:\n", output_val, "\n")
    93. print("真实结果:\n", batch_ys, "\n")
    94. print("预测概率:\n", predv, "\n")
    95. # 显示图片
    96. im = batch_xs[0] # 第一个测试样本
    97. im = im.reshape(-1, 28) # 28列,-1表示经过计算的值(是多少就是多少),行数根据图形的大小来算,算出来有多少行就有多少行
    98. pylab.imshow(im) # 显示图像
    99. pylab.show()
    100. im = batch_xs[1] # 第二个测试样本
    101. im = im.reshape(-1, 28) # 28列,-1表示经过计算的值
    102. pylab.imshow(im)
    103. pylab.show()

     

  • 相关阅读:
    DHorse操作手册
    关于Web应用和容器的指纹收集以及自动化软件的制作
    Python arcpy创建栅格、批量拼接栅格
    Differential (mathematics)
    java基础巩固-宇宙第一AiYWM:为了维持生计,做项目经验之~SSM项目错误集锦Part5(页面好卡呀、反应好慢呀)~整起
    java类的学习
    react 中组件的传参 怎么设置为可选的,比如加上?
    Milvus Cloud——Agent 框架工作方式
    Windows版 PostgreSQL 利用 pg_upgrade 进行大版升级操作
    Vue3 Vite3 状态管理 pinia 基本使用、持久化、在路由守卫中的使用
  • 原文地址:https://blog.csdn.net/llf000000/article/details/127725876