• keras 识别手写数字mnist


    #加载库
    import keras
    from keras import layers
    import matplotlib.pyplot as plt
    import keras.datasets.mnist as mnist
    
    (train_image,train_label),(test_image,test_label)=mnist.load_data()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    #建立模型
    model = keras.Sequential()
    model.add(layers.Flatten()) #展平 (60000,28,28)---》(60000,28*28)
    model.add(layers.Dense(64,activation='relu'))
    model.add(layers.Dense(10,activation='softmax'))
    #编译
    model.compile(optimizer = "adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=['acc']
    )
    #训练模型
    model.fit(train_image,train_label,epochs=50,batch_size=512)
    #评估结果
    model.evaluate(test_image,test_label)
    # [0.34606385231018066, 0.9524000287055969]
    model.evaluate(train_image,train_label)
    #[0.05486735701560974, 0.9823333621025085]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    模型优化
    增加隐藏层

    model = keras.Sequential()
    model.add(layers.Flatten()) #展平 (60000,28,28)---》(60000,28*28)
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(10,activation='softmax'))
    
    model.compile(optimizer = "adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=['acc']
    )
    
    history = model.fit(train_image,train_label,epochs=50,batch_size=512,validation_data=(test_image,test_label))
    #画图
    epochs=range(len(history.history['acc']))
    plt.figure()
    plt.plot(epochs,history.history['acc'],'b',label='Training acc')
    plt.plot(epochs,history.history['val_acc'],'r',label='Validation acc')
    plt.title('Traing and Validation accuracy')
    plt.legend()
    
    
    plt.figure()
    plt.plot(epochs,history.history['loss'],'b',label='Training loss')
    plt.plot(epochs,history.history['val_loss'],'r',label='Validation val_loss')
    plt.title('Traing and Validation loss')
    plt.legend()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    在这里插入图片描述

  • 相关阅读:
    Spring IoC 容器生命周期:Ioc容器启停过程发生了什么-13
    PHP 使用 PHPRedis 与 Predis
    【数据结构】线性表与顺序表
    包装类与基本类型的区别
    设计模式之中介者模式
    如何将「知识」体系化管理
    web网页设计期末课程大作业——汉中印象旅游景点介绍网页设计与实现19页面HTML+CSS+JavaScript
    20个Python random模块的代码示例
    SpringBoot自动配置
    计算机网络中的封装和分用,五层协议
  • 原文地址:https://blog.csdn.net/qq_41431778/article/details/125503716