Deep Learning Day 20: Optimizer Comparison Experiment

1. Preface

In the previous data-augmentation experiment we upgraded TensorFlow to 2.4.0. Some libraries may raise incompatibility errors against this version, so make sure your package versions match.

In this post we compare optimizers used in deep learning, training the same model with Adam and with SGD.

2. My Environment

• OS: Windows 11
• Language environment: Python 3.8.5
• IDE: DataSpell 2022.2
• Deep learning environment: TensorFlow 2.4.0
• GPU: RTX 3070 (8 GB VRAM)

3. Preliminary Work

1. Setting up the GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # if there are multiple GPUs, use only the first one
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings, os, PIL, pathlib

warnings.filterwarnings("ignore")                # suppress warnings
plt.rcParams['font.sans-serif']    = ['SimHei']  # render CJK labels correctly
plt.rcParams['axes.unicode_minus'] = False       # render minus signs correctly
    

2. Importing the Data

This post uses the same dataset as the earlier Hollywood celebrity recognition experiment.

data_dir    = "/content/gdrive/MyDrive/data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)

Total number of images: 1800
    batch_size = 16
    img_height = 336
    img_width  = 336
    
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=12,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    Found 1800 files belonging to 17 classes.
    Using 1440 files for training.
    Found 1800 files belonging to 17 classes.
    Using 360 files for validation.
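The file counts reported above follow directly from the 0.2 validation split, and they also account for the 90 steps per epoch seen later during training (plain arithmetic, shown here only as a sanity check):

```python
total_images = 1800
val_count    = int(total_images * 0.2)   # 360 files for validation
train_count  = total_images - val_count  # 1440 files for training
steps_per_epoch = train_count // 16      # 90 batches per epoch at batch_size=16
print(train_count, val_count, steps_per_epoch)  # 1440 360 90
```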
    

Let's take a look at the class labels:

    class_names = train_ds.class_names
    print(class_names)
    
    ['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']
    

3. Configuring the Dataset

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image, label):
    return (image / 255.0, label)  # scale pixel values to [0, 1]

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # preprocessing function goes here
#     .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # preprocessing function goes here
#     .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)
    

4. Data Visualization

plt.figure(figsize=(10, 8))  # figure is 10 wide and 8 tall
plt.suptitle("Sample images")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # show the image
        plt.imshow(images[i])
        # show the label (labels are 0-based indices into class_names,
        # so no offset is needed)
        plt.xlabel(class_names[labels[i]])

plt.show()
    

(Figure: a grid of sample training images with their class names)

4. Building the Model

from tensorflow.keras.layers import Dropout, Dense, BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # load the pretrained model
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                         include_top=False,
                                                         input_shape=(img_width, img_height, 3),
                                                         pooling='avg')
    # freeze the convolutional base
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output

    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()
    

The printed network structure:

    Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
    58892288/58889256 [==============================] - 8s 0us/step
    Model: "model_1"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    input_2 (InputLayer)         [(None, 336, 336, 3)]     0         
    _________________________________________________________________
    block1_conv1 (Conv2D)        (None, 336, 336, 64)      1792      
    _________________________________________________________________
    block1_conv2 (Conv2D)        (None, 336, 336, 64)      36928     
    _________________________________________________________________
    block1_pool (MaxPooling2D)   (None, 168, 168, 64)      0         
    _________________________________________________________________
    block2_conv1 (Conv2D)        (None, 168, 168, 128)     73856     
    _________________________________________________________________
    block2_conv2 (Conv2D)        (None, 168, 168, 128)     147584    
    _________________________________________________________________
    block2_pool (MaxPooling2D)   (None, 84, 84, 128)       0         
    _________________________________________________________________
    block3_conv1 (Conv2D)        (None, 84, 84, 256)       295168    
    _________________________________________________________________
    block3_conv2 (Conv2D)        (None, 84, 84, 256)       590080    
    _________________________________________________________________
    block3_conv3 (Conv2D)        (None, 84, 84, 256)       590080    
    _________________________________________________________________
    block3_pool (MaxPooling2D)   (None, 42, 42, 256)       0         
    _________________________________________________________________
    block4_conv1 (Conv2D)        (None, 42, 42, 512)       1180160   
    _________________________________________________________________
    block4_conv2 (Conv2D)        (None, 42, 42, 512)       2359808   
    _________________________________________________________________
    block4_conv3 (Conv2D)        (None, 42, 42, 512)       2359808   
    _________________________________________________________________
    block4_pool (MaxPooling2D)   (None, 21, 21, 512)       0         
    _________________________________________________________________
    block5_conv1 (Conv2D)        (None, 21, 21, 512)       2359808   
    _________________________________________________________________
    block5_conv2 (Conv2D)        (None, 21, 21, 512)       2359808   
    _________________________________________________________________
    block5_conv3 (Conv2D)        (None, 21, 21, 512)       2359808   
    _________________________________________________________________
    block5_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
    _________________________________________________________________
    global_average_pooling2d_1 ( (None, 512)               0         
    _________________________________________________________________
    dense_2 (Dense)              (None, 170)               87210     
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 170)               680       
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 170)               0         
    _________________________________________________________________
    dense_3 (Dense)              (None, 17)                2907      
    =================================================================
    Total params: 14,805,485
    Trainable params: 90,457
    Non-trainable params: 14,715,028
    _________________________________________________________________
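As a sanity check, the trainable-parameter count in the summary above can be reproduced by hand from the sizes of the new head layers (512-dim VGG16 pooled features, a 170-unit Dense layer, BatchNormalization, and a 17-class classifier):

```python
dense_head   = 512 * 170 + 170   # Dense(170) on 512-dim features: weights + biases = 87210
bn_trainable = 2 * 170           # BatchNormalization trains gamma and beta = 340
                                 # (its moving mean/variance, another 2*170, are non-trainable)
classifier   = 170 * 17 + 17     # Dense(17) classifier: weights + biases = 2907
print(dense_head + bn_trainable + classifier)  # 90457, matching "Trainable params: 90,457"
```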
    

Here we download the VGG16 weights directly from the internet, and the download may fail, for example:

Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5: None -- [WinError 10054] An existing connection was forcibly closed by the remote host.
    

This is a network problem. Retry a few times; if the download keeps failing, open the URL from the error message (https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5) directly in a browser, which will download the weights file. Save it in the project folder and change the `weights` argument of the VGG16 call above to the path of the downloaded file.
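For example, assuming the file was saved next to the notebook (the filename below is the one from the error message; adjust the path to wherever you actually saved it), the VGG16 call becomes:

```python
# hypothetical local path to the manually downloaded weights file
local_weights = "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5"

vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights=local_weights,  # a file path works in place of 'imagenet'
                                                     include_top=False,
                                                     input_shape=(img_width, img_height, 3),
                                                     pooling='avg')
```

Keras accepts either the string `'imagenet'` or a path to a weights file for the `weights` argument, so no other code needs to change.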

Here we compare two optimizers, Adam and SGD. A brief introduction to each:

    • Adam

tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False)
      

It adapts each parameter's learning rate using first- and second-moment estimates of the gradients. Adam's main advantage is that, after bias correction, the effective step size at each iteration stays within a bounded range, which keeps parameter updates stable. Its memory footprint is small, and it also suits large datasets and high-dimensional models.
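To make the update rule concrete, here is a minimal, hypothetical pure-Python sketch of one Adam step for a single scalar parameter (not the Keras implementation; the bias-correction terms are what keep the effective step size bounded, as described above):

```python
import math

def adam_update(theta, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam step for a scalar parameter theta at iteration t (1-based)."""
    m = beta1 * m + (1 - beta1) * grad         # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment (uncentered variance) estimate
    m_hat = m / (1 - beta1 ** t)               # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# minimize f(x) = x**2 starting from x = 5.0
theta, m, v = 5.0, 0.0, 0.0
for t in range(1, 2001):
    theta, m, v = adam_update(theta, 2 * theta, m, v, t, lr=0.01)
```

Note that the very first step has magnitude close to `lr` regardless of the raw gradient scale, which is the bounded-step behavior mentioned above.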

    • SGD

tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)
      

It is the stochastic gradient descent optimizer: each iteration computes the gradient on a mini-batch and then updates the parameters with it. It is the most common optimization method.
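A matching sketch of SGD with momentum (again a hypothetical scalar version, not the Keras code; with `momentum=0.0` it reduces to the plain SGD update `theta -= lr * grad`):

```python
def sgd_momentum_update(theta, grad, velocity, lr=0.01, momentum=0.9):
    """One SGD-with-momentum step for a scalar parameter."""
    velocity = momentum * velocity - lr * grad  # accumulate an exponentially decaying velocity
    return theta + velocity, velocity

# minimize f(x) = x**2 starting from x = 5.0
theta, vel = 5.0, 0.0
for _ in range(500):
    theta, vel = sgd_momentum_update(theta, 2 * theta, vel)
```

Unlike Adam, the step size here scales with the raw gradient magnitude, which is why SGD is more sensitive to the choice of learning rate.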

5. Training the Models

    NO_EPOCHS = 50
    
    history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
    history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
    
    Epoch 1/50
    90/90 [==============================] - 113s 1s/step - loss: 2.8072 - accuracy: 0.1535 - val_loss: 2.7235 - val_accuracy: 0.0556
    Epoch 2/50
    90/90 [==============================] - 20s 221ms/step - loss: 2.0860 - accuracy: 0.3243 - val_loss: 2.4607 - val_accuracy: 0.2833
    Epoch 3/50
    90/90 [==============================] - 21s 238ms/step - loss: 1.8125 - accuracy: 0.4132 - val_loss: 2.2316 - val_accuracy: 0.2972
    Epoch 4/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.5680 - accuracy: 0.5146 - val_loss: 1.9419 - val_accuracy: 0.4361
    Epoch 5/50
    90/90 [==============================] - 20s 225ms/step - loss: 1.4038 - accuracy: 0.5681 - val_loss: 1.6831 - val_accuracy: 0.4833
    Epoch 6/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.2327 - accuracy: 0.6153 - val_loss: 1.6376 - val_accuracy: 0.4944
    Epoch 7/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.1563 - accuracy: 0.6486 - val_loss: 1.6727 - val_accuracy: 0.4417
    Epoch 8/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.0707 - accuracy: 0.6694 - val_loss: 1.4806 - val_accuracy: 0.5250
    Epoch 9/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.9549 - accuracy: 0.7125 - val_loss: 1.6010 - val_accuracy: 0.4889
    Epoch 10/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.8829 - accuracy: 0.7347 - val_loss: 1.7179 - val_accuracy: 0.4611
    Epoch 11/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.8417 - accuracy: 0.7389 - val_loss: 1.7174 - val_accuracy: 0.4833
    Epoch 12/50
    90/90 [==============================] - 20s 225ms/step - loss: 0.7601 - accuracy: 0.7708 - val_loss: 1.5996 - val_accuracy: 0.4833
    Epoch 13/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.7254 - accuracy: 0.7757 - val_loss: 1.6183 - val_accuracy: 0.5278
    Epoch 14/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6863 - accuracy: 0.8014 - val_loss: 1.7551 - val_accuracy: 0.4722
    Epoch 15/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6336 - accuracy: 0.8069 - val_loss: 1.8830 - val_accuracy: 0.4639
    Epoch 16/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.5819 - accuracy: 0.8319 - val_loss: 1.4917 - val_accuracy: 0.5389
    Epoch 17/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.5748 - accuracy: 0.8340 - val_loss: 1.8751 - val_accuracy: 0.4694
    Epoch 18/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.5219 - accuracy: 0.8396 - val_loss: 2.0875 - val_accuracy: 0.4861
    Epoch 19/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.4934 - accuracy: 0.8556 - val_loss: 1.9038 - val_accuracy: 0.5028
    Epoch 20/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.4942 - accuracy: 0.8514 - val_loss: 1.6452 - val_accuracy: 0.5444
    Epoch 21/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.4933 - accuracy: 0.8431 - val_loss: 2.1585 - val_accuracy: 0.4472
    Epoch 22/50
    90/90 [==============================] - 20s 225ms/step - loss: 0.4514 - accuracy: 0.8701 - val_loss: 2.0218 - val_accuracy: 0.4972
    Epoch 23/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.4458 - accuracy: 0.8694 - val_loss: 1.6499 - val_accuracy: 0.5417
    Epoch 24/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3927 - accuracy: 0.8917 - val_loss: 2.3310 - val_accuracy: 0.4222
    Epoch 25/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3870 - accuracy: 0.8854 - val_loss: 1.6200 - val_accuracy: 0.5583
    Epoch 26/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3800 - accuracy: 0.8861 - val_loss: 1.9285 - val_accuracy: 0.5361
    Epoch 27/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3792 - accuracy: 0.8771 - val_loss: 2.3675 - val_accuracy: 0.4806
    Epoch 28/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3321 - accuracy: 0.8986 - val_loss: 1.7445 - val_accuracy: 0.5500
    Epoch 29/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3185 - accuracy: 0.9076 - val_loss: 1.7202 - val_accuracy: 0.5639
    Epoch 30/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3436 - accuracy: 0.8958 - val_loss: 1.6614 - val_accuracy: 0.5667
    Epoch 31/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2917 - accuracy: 0.9118 - val_loss: 2.0079 - val_accuracy: 0.5500
    Epoch 32/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.3325 - accuracy: 0.8868 - val_loss: 2.0677 - val_accuracy: 0.5028
    Epoch 33/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2879 - accuracy: 0.9146 - val_loss: 1.6412 - val_accuracy: 0.6028
    Epoch 34/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2856 - accuracy: 0.9111 - val_loss: 2.1213 - val_accuracy: 0.5222
    Epoch 35/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2645 - accuracy: 0.9153 - val_loss: 2.0940 - val_accuracy: 0.5222
    Epoch 36/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2528 - accuracy: 0.9160 - val_loss: 1.8489 - val_accuracy: 0.5389
    Epoch 37/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2553 - accuracy: 0.9208 - val_loss: 1.8388 - val_accuracy: 0.5583
    Epoch 38/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2362 - accuracy: 0.9285 - val_loss: 1.8624 - val_accuracy: 0.5667
    Epoch 39/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.2245 - accuracy: 0.9229 - val_loss: 1.9156 - val_accuracy: 0.5639
    Epoch 40/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2198 - accuracy: 0.9333 - val_loss: 2.2192 - val_accuracy: 0.5556
    Epoch 41/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2144 - accuracy: 0.9278 - val_loss: 1.8951 - val_accuracy: 0.5833
    Epoch 42/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2074 - accuracy: 0.9389 - val_loss: 2.0159 - val_accuracy: 0.5500
    Epoch 43/50
    90/90 [==============================] - 20s 225ms/step - loss: 0.2166 - accuracy: 0.9257 - val_loss: 2.2641 - val_accuracy: 0.5111
    Epoch 44/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2312 - accuracy: 0.9264 - val_loss: 2.0438 - val_accuracy: 0.5750
    Epoch 45/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.2248 - accuracy: 0.9257 - val_loss: 2.2686 - val_accuracy: 0.5472
    Epoch 46/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2102 - accuracy: 0.9375 - val_loss: 2.2441 - val_accuracy: 0.5583
    Epoch 47/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.2120 - accuracy: 0.9340 - val_loss: 2.3860 - val_accuracy: 0.5361
    Epoch 48/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.1959 - accuracy: 0.9354 - val_loss: 2.4052 - val_accuracy: 0.5167
    Epoch 49/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.1699 - accuracy: 0.9521 - val_loss: 2.5167 - val_accuracy: 0.5250
    Epoch 50/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.1645 - accuracy: 0.9528 - val_loss: 2.1405 - val_accuracy: 0.5722
(model1 with Adam finishes above; the output below is model2 with SGD)

Epoch 1/50
    90/90 [==============================] - 21s 226ms/step - loss: 3.0785 - accuracy: 0.0986 - val_loss: 2.7949 - val_accuracy: 0.1000
    Epoch 2/50
    90/90 [==============================] - 20s 223ms/step - loss: 2.5472 - accuracy: 0.1924 - val_loss: 2.6379 - val_accuracy: 0.1583
    Epoch 3/50
    90/90 [==============================] - 20s 225ms/step - loss: 2.2651 - accuracy: 0.2694 - val_loss: 2.4596 - val_accuracy: 0.2528
    Epoch 4/50
    90/90 [==============================] - 20s 224ms/step - loss: 2.0612 - accuracy: 0.3389 - val_loss: 2.2347 - val_accuracy: 0.3389
    Epoch 5/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.9508 - accuracy: 0.3653 - val_loss: 2.0695 - val_accuracy: 0.3972
    Epoch 6/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.8406 - accuracy: 0.4021 - val_loss: 1.9282 - val_accuracy: 0.3917
    Epoch 7/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.7565 - accuracy: 0.4451 - val_loss: 1.8469 - val_accuracy: 0.4111
    Epoch 8/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.6587 - accuracy: 0.4667 - val_loss: 1.7935 - val_accuracy: 0.4306
    Epoch 9/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.5934 - accuracy: 0.4889 - val_loss: 1.6561 - val_accuracy: 0.4528
    Epoch 10/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.5516 - accuracy: 0.4854 - val_loss: 1.7235 - val_accuracy: 0.3944
    Epoch 11/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.4753 - accuracy: 0.5403 - val_loss: 1.6903 - val_accuracy: 0.4333
    Epoch 12/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.4309 - accuracy: 0.5389 - val_loss: 1.6633 - val_accuracy: 0.4556
    Epoch 13/50
    90/90 [==============================] - 20s 225ms/step - loss: 1.4168 - accuracy: 0.5437 - val_loss: 1.6759 - val_accuracy: 0.4667
    Epoch 14/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.3726 - accuracy: 0.5701 - val_loss: 1.7004 - val_accuracy: 0.4667
    Epoch 15/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.2890 - accuracy: 0.5924 - val_loss: 1.6371 - val_accuracy: 0.4639
    Epoch 16/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.2669 - accuracy: 0.6139 - val_loss: 1.5207 - val_accuracy: 0.4806
    Epoch 17/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.2238 - accuracy: 0.6097 - val_loss: 1.5294 - val_accuracy: 0.4972
    Epoch 18/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.1582 - accuracy: 0.6375 - val_loss: 1.4838 - val_accuracy: 0.5111
    Epoch 19/50
    90/90 [==============================] - 20s 223ms/step - loss: 1.1518 - accuracy: 0.6271 - val_loss: 1.5244 - val_accuracy: 0.5111
    Epoch 20/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.1324 - accuracy: 0.6438 - val_loss: 1.5217 - val_accuracy: 0.4917
    Epoch 21/50
    90/90 [==============================] - 21s 237ms/step - loss: 1.0931 - accuracy: 0.6590 - val_loss: 1.4744 - val_accuracy: 0.5056
    Epoch 22/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.0524 - accuracy: 0.6667 - val_loss: 1.4386 - val_accuracy: 0.5167
    Epoch 23/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.0196 - accuracy: 0.6729 - val_loss: 1.4282 - val_accuracy: 0.5278
    Epoch 24/50
    90/90 [==============================] - 20s 224ms/step - loss: 1.0143 - accuracy: 0.6924 - val_loss: 1.5158 - val_accuracy: 0.5361
    Epoch 25/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.9708 - accuracy: 0.6875 - val_loss: 1.5623 - val_accuracy: 0.4806
    Epoch 26/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.9651 - accuracy: 0.6875 - val_loss: 1.3693 - val_accuracy: 0.5611
    Epoch 27/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.9384 - accuracy: 0.7076 - val_loss: 1.4377 - val_accuracy: 0.5556
    Epoch 28/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.8951 - accuracy: 0.7285 - val_loss: 1.4171 - val_accuracy: 0.5222
    Epoch 29/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.8706 - accuracy: 0.7340 - val_loss: 1.6458 - val_accuracy: 0.5167
    Epoch 30/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.8520 - accuracy: 0.7375 - val_loss: 1.4419 - val_accuracy: 0.5139
    Epoch 31/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.8547 - accuracy: 0.7188 - val_loss: 1.2940 - val_accuracy: 0.5889
    Epoch 32/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.8222 - accuracy: 0.7424 - val_loss: 1.4509 - val_accuracy: 0.5528
    Epoch 33/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.8406 - accuracy: 0.7299 - val_loss: 1.4598 - val_accuracy: 0.5306
    Epoch 34/50
    90/90 [==============================] - 20s 225ms/step - loss: 0.7983 - accuracy: 0.7528 - val_loss: 1.5114 - val_accuracy: 0.5472
    Epoch 35/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.7992 - accuracy: 0.7403 - val_loss: 1.4475 - val_accuracy: 0.5750
    Epoch 36/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.7557 - accuracy: 0.7569 - val_loss: 1.5024 - val_accuracy: 0.5389
    Epoch 37/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.7298 - accuracy: 0.7681 - val_loss: 1.4272 - val_accuracy: 0.5389
    Epoch 38/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.7378 - accuracy: 0.7632 - val_loss: 1.3973 - val_accuracy: 0.5778
    Epoch 39/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.7025 - accuracy: 0.7875 - val_loss: 1.3738 - val_accuracy: 0.5500
    Epoch 40/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6812 - accuracy: 0.7958 - val_loss: 1.5651 - val_accuracy: 0.5361
    Epoch 41/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6646 - accuracy: 0.7854 - val_loss: 1.4765 - val_accuracy: 0.5667
    Epoch 42/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6477 - accuracy: 0.8021 - val_loss: 1.5985 - val_accuracy: 0.5361
    Epoch 43/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6508 - accuracy: 0.8042 - val_loss: 1.3467 - val_accuracy: 0.5667
    Epoch 44/50
    90/90 [==============================] - 20s 225ms/step - loss: 0.6539 - accuracy: 0.7889 - val_loss: 1.3919 - val_accuracy: 0.5778
    Epoch 45/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6402 - accuracy: 0.8104 - val_loss: 1.3426 - val_accuracy: 0.5917
    Epoch 46/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6178 - accuracy: 0.8076 - val_loss: 1.4094 - val_accuracy: 0.5833
    Epoch 47/50
    90/90 [==============================] - 20s 223ms/step - loss: 0.6083 - accuracy: 0.8000 - val_loss: 1.3747 - val_accuracy: 0.5750
    Epoch 48/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6079 - accuracy: 0.8028 - val_loss: 1.5148 - val_accuracy: 0.5583
    Epoch 49/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.6115 - accuracy: 0.8000 - val_loss: 1.9661 - val_accuracy: 0.4556
    Epoch 50/50
    90/90 [==============================] - 20s 224ms/step - loss: 0.5785 - accuracy: 0.8146 - val_loss: 1.4971 - val_accuracy: 0.5500
    

6. Model Evaluation

1. Accuracy and Loss Curves

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300  # DPI of saved figures
plt.rcParams['figure.dpi']  = 300  # display DPI

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# put an x-axis tick at every epoch
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

# put an x-axis tick at every epoch
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()
    

(Figure: training/validation accuracy and loss curves for Adam vs. SGD)

2. Evaluating the Model

def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)  # returns [loss, accuracy]
    print('Loss: %s, accuracy: %s' % (score[0], score[1]))

test_accuracy_report(model2)

Loss: 1.49705171585083, accuracy: 0.550000011920929
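Beyond the single `evaluate()` call, the two history objects can be compared directly. A small helper (hypothetical, not from the original code) that reports the best validation accuracy and the epoch it occurred at:

```python
def best_val_accuracy(history_dict):
    """Return (best validation accuracy, 1-based epoch) from a Keras-style history dict."""
    accs = history_dict['val_accuracy']
    best = max(accs)
    return best, accs.index(best) + 1

# usage with the real objects would be:
#   best_val_accuracy(history_model1.history)  # Adam
#   best_val_accuracy(history_model2.history)  # SGD
print(best_val_accuracy({'val_accuracy': [0.40, 0.55, 0.52]}))  # (0.55, 2)
```

From the logs above this would report roughly 0.6028 at epoch 33 for Adam and 0.5917 at epoch 45 for SGD, which is a fairer comparison than the final-epoch numbers alone.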
    

7. Final Thoughts

That's it for this post. My computer has recently had problems that prevent it from training complex models, so I'm considering reinstalling the system and cleaning it up. I ran this experiment on Google Colaboratory; the platform has worked well for me so far, and the free compute it provides is enough for my needs. If your own machine isn't powerful enough, it's worth a try.

• Original article: https://blog.csdn.net/qq_52417436/article/details/128042415