• 深度学习03——手写数字识别实例


    目录

    0. 实验概述 

    1.利用Tensorflow自动加载mnist数据集 

    2. 手写数字识别体验

    2.1 准备网络结构与优化器

     2.2 计算损失函数与输出

     2.3 梯度计算与优化

     2.4 循环

    2.5 完整代码

    补充:os.environ['TF_CPP_MIN_LOG_LEVEL']


    0. 实验概述 

    (以图片中的二分类问题为例)

     

    1.利用Tensorflow自动加载mnist数据集 

     

     代码:

    1. import tensorflow as tf
    2. from tensorflow.keras import datasets, layers, optimizers
    3. (xs,ys),_ = datasets.mnist.load_data() # 自动下载mnist数据集
    4. print('datasets:',xs.shape,ys.shape)
    5. xs = tf.convert_to_tensor(xs,dtype=tf.float32)/255. # 将mnist中的数据转为tensorflow格式
    6. db = tf.data.Dataset.from_tensor_slices((xs,ys)) #将下载的数据存入datasets数据集
    7. for step,(x,y) in enumerate(db): #单个数据输出
    8. print(step,x.shape,y,y.shape)

    代码切割分析:

     

    2. 手写数字识别体验

    2.1 准备网络结构与优化器

     

    利用Sequential模块。 

    1. #准备网络结构与优化器
    2. model = keras.Sequential([
    3. #3层结构
    4. layers.Dense(512, activation='relu'),
    5. layers.Dense(256, activation='relu'),
    6. layers.Dense(10)])
    7. optimizer = optimizers.SGD(learning_rate=0.001)

     2.2 计算损失函数与输出

    1. with tf.GradientTape() as tape:
    2. # [b, 28, 28] => [b, 784]
    3. x = tf.reshape(x, (-1, 28*28))
    4. # Step1. compute output
    5. # [b, 784] => [b, 10]
    6. out = model(x)
    7. # Step2. compute loss
    8. loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]

     2.3 梯度计算与优化

    1. # Step3. optimize and update w1, w2, w3, b1, b2, b3
    2. grads = tape.gradient(loss, model.trainable_variables)
    3. # w' = w - lr * grad
    4. optimizer.apply_gradients(zip(grads, model.trainable_variables))

     

     2.4 循环

     

    2.5 完整代码

    1. import os
    2. import tensorflow as tf
    3. from tensorflow import keras
    4. from tensorflow.keras import layers, optimizers, datasets
    5. os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
    6. #数据集的加载
    7. (x, y), (x_val, y_val) = datasets.mnist.load_data()
    8. x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    9. y = tf.convert_to_tensor(y, dtype=tf.int32)
    10. y = tf.one_hot(y, depth=10)
    11. print(x.shape, y.shape)
    12. train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
    13. train_dataset = train_dataset.batch(200) #一次加载200张图片
    14. #准备网络结构与优化器
    15. model = keras.Sequential([
    16. #3层结构
    17. layers.Dense(512, activation='relu'),
    18. layers.Dense(256, activation='relu'),
    19. layers.Dense(10)])
    20. optimizer = optimizers.SGD(learning_rate=0.001)
    21. #计算迭代
    22. def train_epoch(epoch):
    23. # Step4.loop
    24. for step, (x, y) in enumerate(train_dataset):
    25. with tf.GradientTape() as tape:
    26. # [b, 28, 28] => [b, 784]
    27. x = tf.reshape(x, (-1, 28*28))
    28. # Step1. compute output
    29. # [b, 784] => [b, 10]
    30. out = model(x)
    31. # Step2. compute loss
    32. loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]
    33. # Step3. optimize and update w1, w2, w3, b1, b2, b3
    34. grads = tape.gradient(loss, model.trainable_variables)
    35. # w' = w - lr * grad
    36. optimizer.apply_gradients(zip(grads, model.trainable_variables))
    37. if step % 100 == 0:
    38. print(epoch, step, 'loss:', loss.numpy())
    39. def train():
    40. #计算迭代30次
    41. for epoch in range(30):
    42. train_epoch(epoch)
    43. if __name__ == '__main__':
    44. train()

    训练结果:

     

    补充:os.environ['TF_CPP_MIN_LOG_LEVEL']

     os.environ["TF_CPP_MIN_LOG_LEVEL"]的取值有四个:0,1,2,3,分别和log的四个等级挂钩:INFO,WARNING,ERROR,FATAL(重要性由左到右递增)

        当os.environ["TF_CPP_MIN_LOG_LEVEL"]=0的时候,输出信息:INFO + WARNING + ERROR + FATAL
        当os.environ["TF_CPP_MIN_LOG_LEVEL"]=1的时候,输出信息:WARNING + ERROR + FATAL
        当os.environ["TF_CPP_MIN_LOG_LEVEL"]=2的时候,输出信息:ERROR + FATAL
        当os.environ["TF_CPP_MIN_LOG_LEVEL"]=3的时候,输出信息:FATAL
     

  • 相关阅读:
    uniapp 阿里云点播 视频播放(加密版)(APP版)记录播放进度
    在hugging face上发布自己的模型 (ubuntu 19.0)
    FlyFish|前端数据可视化开发避坑指南(二)
    四大战略举措,亚马逊云科技加码中国市场背后的逻辑
    Godot游戏引擎有哪些优势
    用pymysql封装项目通用的连接和查询
    掌握Explain分析性能瓶颈、避免索引失效
    Mysql基本知识篇
    DOE认证是什么
    Spring的Environment
  • 原文地址:https://blog.csdn.net/m0_55196097/article/details/126356082