• tensorflow2.0 mnist手写数字识别 并验证几张图片以查看效果


    目的:

    tensorflow2.0+mnist已经算是ML界的“HELLO WORLD”了吧。网络上这方面的内容也比较多,可是有很多教程都是讲到训练完成就结束,如果只从数据上看准确率,对我们的主观感受不深。这里讲一下如何拿一张图片验证一下,以便直观感受。

    环境:miniconda下tensorflow2.0

    1、以下为训练:

    1. import tensorflow as tf
    2. import pandas as pd
    3. import matplotlib.pyplot as plt
    4. %matplotlib inline
    5. import image
    6. import numpy as np
    7. data = tf.keras.datasets.mnist
    8. (x_train,y_train),(x_test,y_test) = data.load_data()
    9. x_train,x_test = x_train/255.0,x_test/255.0
    10. #print(x_train.shape)
    11. #plt.imshow(y_train[0])
    12. #print(x_train[0])
    13. #y_test[0]
    14. model = tf.keras.models.Sequential()
    15. model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
    16. model.add(tf.keras.layers.Dense(128, activation="relu"))
    17. model.add(tf.keras.layers.Dropout(0.2))
    18. model.add(tf.keras.layers.Dense(10))
    19. loss_fn = loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    20. model.compile(
    21. optimizer="adam",
    22. loss=loss_fn,
    23. metrics=['accuracy']
    24. )
    25. model.fit(x_train,y_train,epochs=5)

    2、以下为验证:

     1、需要注意的是model.predict()输入参数应该是一个2维1列的矩阵。

    2、这样做的好处是可以一次输入多个待验证数据,一并取得答案。

    3、

    而图片为28*28的矩阵,取其中一个:

                                                    img1_ = x_test[INDEX]

    将矩阵展开成1维数组:

                                                    img1_ = img1_.flatten()

    将1维数组展转换成2维1列的矩阵,-1代表由函数自动确认个数:

                                                    img1  = img1_.reshape(1,-1)

    1. #以下为训练,输入训练集的索引为3
    2. INDEX = 3
    3. img1_ = x_test[INDEX]
    4. img1_ = img1_.flatten()
    5. img1 = img1_.reshape(1,-1)
    6. #print(img1)
    7. ret = model.predict(img1)
    8. #显示预测值,其实对应的是0-9共10个数分别的概率,最大的为预测值
    9. print(ret)
    10. #显示真正值
    11. print(y_test[INDEX])

    3、结果如下:

    ---------------------------------------------------------------分割线---------------------------------------------------------

    那辛苦训练好的模型如何保存下来呢?:

                                                         model.save("./mnist_model/mnist01.h5")   

                                                         #保存在了路径:./mnist_model/mnist01.h5 里。

    保存下来的模型如何加载呢?:

                                 model = tf.keras.models.load_model("./mnist_model/mnist01.h5")

  • 相关阅读:
    驱动开发 linux内核GPIO子系统、及其新版API的概念和使用,linux内核定时器
    ROSIntegration ROSIntegrationVision与虚幻引擎4(Unreal Engine 4)的配置
    重载和重写什么区别?
    Hibernate 一对多关系映射
    卫浴服务信息展示预约小程序的作用如何
    HTTPS报文分析(Wireshark)
    CPU版本的pytorch安装
    Django数据表修改方法
    DarkGate恶意软件通过消息服务传播
    在linux下的vim中使用内联函数时,会有未定义的引用错误解决办法
  • 原文地址:https://blog.csdn.net/c_1969/article/details/127638690