• keras 识别手写数字mnist


    #加载库
    import keras
    from keras import layers
    import matplotlib.pyplot as plt
    import keras.datasets.mnist as mnist
    
    (train_image,train_label),(test_image,test_label)=mnist.load_data()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这里插入图片描述

    #建立模型
    model = keras.Sequential()
    model.add(layers.Flatten()) #展平 (60000,28,28)---》(60000,28*28)
    model.add(layers.Dense(64,activation='relu'))
    model.add(layers.Dense(10,activation='softmax'))
    #编译
    model.compile(optimizer = "adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=['acc']
    )
    #训练模型
    model.fit(train_image,train_label,epochs=50,batch_size=512)
    #评估结果
    model.evaluate(test_image,test_label)
    # [0.34606385231018066, 0.9524000287055969]
    model.evaluate(train_image,train_label)
    #[0.05486735701560974, 0.9823333621025085]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    模型优化
    增加隐藏层

    model = keras.Sequential()
    model.add(layers.Flatten()) #展平 (60000,28,28)---》(60000,28*28)
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(64,activation='relu'))# 隐藏层
    model.add(layers.Dense(10,activation='softmax'))
    
    model.compile(optimizer = "adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=['acc']
    )
    
    history = model.fit(train_image,train_label,epochs=50,batch_size=512,validation_data=(test_image,test_label))
    #画图
    epochs=range(len(history.history['acc']))
    plt.figure()
    plt.plot(epochs,history.history['acc'],'b',label='Training acc')
    plt.plot(epochs,history.history['val_acc'],'r',label='Validation acc')
    plt.title('Traing and Validation accuracy')
    plt.legend()
    
    
    plt.figure()
    plt.plot(epochs,history.history['loss'],'b',label='Training loss')
    plt.plot(epochs,history.history['val_loss'],'r',label='Validation val_loss')
    plt.title('Traing and Validation loss')
    plt.legend()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    在这里插入图片描述

  • 相关阅读:
    React基础语法
    硬件知识:U盘相关知识介绍,值得收藏
    CPU流水线与指令乱序执行
    惊了!10万字的Spark全文!
    全链路压测的步骤及重要性
    Spring Cloud Alibaba —— nacos配置中心管理数据库、gateway等配置项
    33.【C/C++ char 类型与Ascii大整合,少一个没考虑你打我】
    ZMQ之异步管家模式
    Sonatype Nexus 如何把多仓库合并在一起
    sqlmap获取目标
  • 原文地址:https://blog.csdn.net/qq_41431778/article/details/125503716