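• Recognizing four weather conditions with TensorFlow (CNN, model accuracy: 93.78%)


    Event: CSDN 21-Day Learning Challenge

    1. Data processing

    Dataset link: portal (extraction code: hqij)

    Set up the GPU environment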

    import tensorflow as tf
    
    gpus = tf.config.list_physical_devices("GPU")
    
    if gpus:
        gpu0 = gpus[0]                                        # if there are several GPUs, use only GPU 0
        tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
        tf.config.set_visible_devices([gpu0],"GPU")
    

    Import libraries

    import os
    import pathlib

    import matplotlib.pyplot as plt
    import numpy as np
    import PIL.Image  # import the submodule explicitly; PIL.Image.open is used below
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers, models

    # Set the random seeds so that results are as reproducible as possible
    np.random.seed(1)
    tf.random.set_seed(1)
    

    Set the data path

    # Raw string, so the backslashes in the Windows path are not read as escape sequences
    data_dir = r"E:\demo_study\jupyter\Jupyter_notebook\Weather-recognition-based-on-CNN\weather_photos"
    data_dir = pathlib.Path(data_dir)
    

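    Inspect the data

    The dataset has four classes (cloudy, rain, shine, and sunrise), each stored in a subfolder of weather_photos named after the class

    Count the images:

    image_count = len(list(data_dir.glob('*/*.jpg')))
    print("Total number of images:", image_count)

    Output:

    Total number of images: 1125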
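
    A per-class count is a useful companion to the total, because an imbalanced dataset can make accuracy misleading. The following is a minimal sketch under stated assumptions: count_images_per_class is a hypothetical helper (not part of the original notebook), and the demo runs on a throwaway folder that only mimics the weather_photos layout.

```python
import pathlib
import tempfile


def count_images_per_class(root, pattern="*.jpg"):
    """Return {subfolder name: number of files matching pattern}."""
    root = pathlib.Path(root)
    return {
        sub.name: len(list(sub.glob(pattern)))
        for sub in sorted(root.iterdir())
        if sub.is_dir()
    }


# Demo on a temporary folder with two made-up classes (hypothetical data)
with tempfile.TemporaryDirectory() as tmp:
    for name, n_files in [("cloudy", 3), ("rain", 2)]:
        folder = pathlib.Path(tmp) / name
        folder.mkdir()
        for i in range(n_files):
            (folder / f"{name}{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

    In the notebook, the same helper could be called as count_images_per_class(data_dir).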
    

    View the first sunrise image:

    sunrise_images = list(data_dir.glob('sunrise/*.jpg'))
    PIL.Image.open(str(sunrise_images[0]))

    Output:

    [Image: the first photo in the sunrise folder]

    1.1. Data preprocessing

    Load the data:

    Use the image_dataset_from_directory method to load the images on disk into a tf.data.Dataset

    Set the data parameters:

    batch_size = 64
    img_height = 180
    img_width = 180
    
    Load the training set:

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    Output:

    Found 1125 files belonging to 4 classes.
    Using 900 files for training.
    
    Load the validation set:

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    
    Output:

    Found 1125 files belonging to 4 classes.
    Using 225 files for validation.
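
    The 900/225 split in the log follows directly from validation_split=0.2. A small sketch of the arithmetic, assuming the validation count is the truncated product of the total and the split fraction (split_counts is a hypothetical helper name):

```python
def split_counts(total, validation_split):
    """Return (training files, validation files) for a validation fraction."""
    n_val = int(total * validation_split)  # assumed: truncate toward zero
    return total - n_val, n_val


n_train, n_val = split_counts(1125, 0.2)
print(n_train, n_val)  # 900 225
```

    Both calls to image_dataset_from_directory use the same seed (123), which keeps the training and validation subsets disjoint.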
    
    class_names lists the dataset labels; they correspond to the directory names in alphabetical order:

    class_names = train_ds.class_names
    print(class_names)
    
    Output:

    ['cloudy', 'rain', 'shine', 'sunrise']
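
    Since the labels are integers, it helps to see how a folder name turns into a class index. A minimal sketch, assuming the index is simply the position of the folder name in the sorted list (folder_names and label_of are illustrative names):

```python
# Folder names in an arbitrary order, as a file system might return them
folder_names = ["sunrise", "cloudy", "shine", "rain"]

# Sorting gives the class order; the position in that list is the integer label
class_names = sorted(folder_names)
label_of = {name: index for index, name in enumerate(class_names)}

print(class_names)
print(label_of)
```

    With this mapping, an integer label of 3 in a batch stands for sunrise.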
    
    Inspect the type of train_ds:

    train_ds
    
    Output:

    <PrefetchDataset element_spec=(TensorSpec(shape=(None, 180, 180, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
    
    1.2. Data visualization

    # The images shown differ from run to run
    
    plt.figure(figsize=(12, 10))
    
    for images, labels in train_ds.take(1):
        for i in range(30):
            ax = plt.subplot(5, 6, i + 1)
    
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
    
    plt.savefig('pic1.jpg', dpi=600)  # save once, after all subplots are drawn, at the given resolution
    
    Output:
    [Image: a 5x6 grid of 30 training images titled with their class labels]

    Check the batch shapes:

    for image_batch, labels_batch in train_ds:
        print(image_batch.shape)
        print(labels_batch.shape)
        break
    
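    • image_batch is a tensor of shape (64, 180, 180, 3): a batch of 64 images of size 180x180x3 (the last dimension is the RGB color channels).
    • labels_batch is a tensor of shape (64,); these labels correspond to the 64 images.

    Output: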

    (64, 180, 180, 3)
    (64,)
    
    1.3. Configure the dataset

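    shuffle(): shuffles the data; for details see the article "Understanding buffer_size in the dataset shuffle method"

    prefetch(): prefetches data to speed up training; for details see: Better performance with the tf.data API

    cache(): caches the dataset in memory to speed up training

    A recommended blog post: "[Study notes] Optimizing the preprocessing pipeline with tf.data"

    prefetch() in more detail: it lets the preprocessing of one training step overlap with the model execution of another. Without it:

    [Image: pipeline timeline without prefetch()]

    With prefetch():

    [Image: pipeline timeline with prefetch()]
    This step is optional; training also works without it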

    AUTOTUNE = tf.data.AUTOTUNE
    
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
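
    To make the role of the shuffle buffer concrete, here is a simplified pure-Python model of buffer-based shuffling. It is a sketch of the idea only, not TensorFlow's implementation, and buffered_shuffle is a hypothetical name. Note that train_ds is already batched, so the elements being shuffled above are whole batches, and a buffer of 1000 is large enough to hold all of them.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in randomized order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)      # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]          # emit a random buffered element
        buffer[index] = item         # and put the new element in its slot
    rng.shuffle(buffer)              # input exhausted: drain what is left
    yield from buffer


small = list(buffered_shuffle(range(10), buffer_size=1, seed=0))
large = list(buffered_shuffle(range(10), buffer_size=1000, seed=0))
print(small)  # a buffer of 1 cannot reorder anything
print(large)  # a buffer larger than the data gives a full shuffle
```

    The smaller the buffer relative to the dataset, the closer the output stays to the original order.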
    
    2. Build the CNN

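    The input to a convolutional neural network (CNN) is a tensor of shape (image_height, image_width, color_channels), which carries the image height, width, and color information (color_channels is (R, G, B), the three RGB color channels)

    The images are first processed with layers.experimental.preprocessing.Rescaling (official description: rescales and offsets the values of a batch of images, e.g. go from inputs in the [0, 255] range to inputs in the [0, 1] range)

    For details on this layer see: tf.keras.layers.Rescaling

    Setting scale=1./255 rescales inputs from the [0, 255] range to the [0, 1] range
    The input_shape parameter specifies the size and layout of the images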
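
    The rescaling itself is a single multiply-and-add per pixel. A minimal sketch of that arithmetic on a plain list (rescale is an illustrative function, not the Keras layer):

```python
def rescale(pixels, scale=1.0 / 255, offset=0.0):
    """Apply output = input * scale + offset to every pixel value."""
    return [p * scale + offset for p in pixels]


to_unit = rescale([0, 51, 255])                                # [0, 255] -> [0, 1]
to_signed = rescale([0, 255], scale=1.0 / 127.5, offset=-1.0)  # [0, 255] -> [-1, 1]
print(to_unit)
print(to_signed)
```

    Doing this inside the model, as the first layer, means the same scaling is applied automatically at inference time.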

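    2.1. Pooling

    About pooling layers:

    In a CNN, a pooling layer is usually inserted between adjacent convolutional layers. Pooling shrinks the feature maps, which reduces the number of parameters in the final fully connected layers

    The two common types of pooling are shown in the figure below (adapted from Zhihu: portal):
    [Image: illustration of the two common pooling types]

    What pooling does:

    • An obvious benefit of shrinking the feature map is a lower parameter count: it reduces dimensionality, removes redundant information, compresses features, simplifies the network, and cuts computation and memory use
    • The most important effect of pooling is to enlarge the receptive field of a neuron, that is, of one pixel in the feature map. In shallow convolutional layers the feature map is still large, and one pixel sees only a small area of the actual image. Stacked convolutions let neighboring neurons exchange information to some extent, but the effect is limited. Pooling bluntly merges several neighboring neurons, so that each neuron in the reduced feature map receives information from a larger region of the original image and higher-level features can be extracted
    • Average pooling and max pooling are two strategies for merging the information of neighboring neurons. Average pooling retains information from all neurons and reflects the average response of a region; max pooling retains detail by keeping the strongest response of a region, so that a strong response is not diluted by its neighbors.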
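
    The difference between the two strategies is easy to see on a tiny feature map. The following sketch implements non-overlapping 2x2 pooling on nested lists; pool2d and the 4x4 values are made up for illustration.

```python
def pool2d(feature_map, size=2, mode="avg"):
    """Non-overlapping size x size pooling (stride equals size, no padding)."""
    rows = len(feature_map) // size
    cols = len(feature_map[0]) // size
    pooled = []
    for r in range(rows):
        row = []
        for c in range(cols):
            # Collect the values inside the current window
            window = [
                feature_map[r * size + i][c * size + j]
                for i in range(size)
                for j in range(size)
            ]
            row.append(max(window) if mode == "max" else sum(window) / len(window))
        pooled.append(row)
    return pooled


feature_map = [
    [1, 3, 2, 0],
    [5, 7, 0, 2],
    [0, 0, 8, 0],
    [0, 4, 0, 0],
]
avg_pooled = pool2d(feature_map, mode="avg")
max_pooled = pool2d(feature_map, mode="max")
print(avg_pooled)  # [[4.0, 1.0], [1.0, 2.0]]
print(max_pooled)  # [[7, 2], [4, 8]]
```

    In the bottom-right window the single strong response 8 survives max pooling unchanged, while average pooling dilutes it to 2.0.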

    Average pooling is used here (official docs: tf.keras.layers.AveragePooling2D); its signature is:

    tf.keras.layers.AveragePooling2D(
        pool_size=(2, 2),
        strides=None,
        padding='valid',
        data_format=None,
        **kwargs
    )
    
    Here we use the default padding mode valid, which adds no padding; the output size is $\left\lfloor \dfrac{input\_shape - pool\_size}{strides} \right\rfloor + 1$ (for $input\_shape \geq pool\_size$)

    With padding mode same, padding is added, and the output size is $\left\lfloor \dfrac{input\_shape - 1}{strides} \right\rfloor + 1$ (in particular, with a stride of 1 the output shape equals the input shape)
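
    Both formulas can be checked against the pooling layers used later in this post. A minimal sketch (pool_output_size is a hypothetical helper; strides falls back to pool_size when it is None, as in the Keras layer):

```python
def pool_output_size(input_size, pool_size, strides=None, padding="valid"):
    """Output length of a pooling layer along one spatial dimension."""
    if strides is None:
        strides = pool_size  # Keras default: strides = pool_size
    if padding == "valid":
        return (input_size - pool_size) // strides + 1
    if padding == "same":
        return (input_size - 1) // strides + 1
    raise ValueError("padding must be 'valid' or 'same'")


# The 2x2 'valid' pooling layers of the network in this post
print(pool_output_size(178, 2))  # 89
print(pool_output_size(87, 2))   # 43
print(pool_output_size(41, 2))   # 20
# For comparison: 'same' padding rounds up instead of down
print(pool_output_size(87, 2, padding="same"))  # 44
```

    The odd sizes 87 and 41 show the floor at work: with valid padding the last row and column are dropped.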

    2.2. Convolutional layers

    Input image matrix $I$ of size $w \times w$
    Convolution kernel $K$ of size $k \times k$
    Stride $S$ of size $s$
    Padding $P$ of size $p$

    The convolution output size is: $o = \dfrac{w - k + 2p}{s} + 1$
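
    Applied to this post's 3x3 convolutions (stride 1, no padding), the formula removes two pixels per dimension. A short sketch; conv_output_size is an illustrative helper, and using integer division for the case where s does not divide evenly is an assumption that mirrors the floor in the pooling formulas.

```python
def conv_output_size(w, k, s=1, p=0):
    """Output size o = (w - k + 2p) / s + 1, rounded down."""
    return (w - k + 2 * p) // s + 1


print(conv_output_size(180, 3))       # 178
print(conv_output_size(89, 3))        # 87
print(conv_output_size(43, 3))        # 41
print(conv_output_size(180, 3, p=1))  # 180, one pixel of padding preserves the size
```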

    A convolutional layer can be built with tf.keras.layers.Conv2D() (official docs: tf.keras.layers.Conv2D); its signature is:

    tf.keras.layers.Conv2D(
        filters,
        kernel_size,
        strides=(1, 1),
        padding='valid',
        data_format=None,
        dilation_rate=(1, 1),
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    )
    
    The network here uses three convolutional layers with 3x3 kernels (16, 32, and 64 filters) and the ReLU activation function

    num_classes = 4
    
    model = models.Sequential([
        layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
        
        layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # conv layer 1, 3x3 kernel
        layers.AveragePooling2D((2, 2)),               # pooling layer 1, 2x2 downsampling
        layers.Conv2D(32, (3, 3), activation='relu'),  # conv layer 2, 3x3 kernel
        layers.AveragePooling2D((2, 2)),               # pooling layer 2, 2x2 downsampling
        layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 3, 3x3 kernel
        layers.Dropout(0.3),                    # reduces overfitting and improves generalization
        
        layers.Flatten(),                       # Flatten layer, connects the conv layers to the dense layers
        layers.Dense(128, activation='relu'),   # fully connected layer, further feature extraction
        layers.Dense(num_classes)               # output layer, one logit per class
    ])
    
    model.summary()  # print the network structure
    
    Output:

    Model: "sequential"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     rescaling (Rescaling)       (None, 180, 180, 3)       0         
                                                                     
     conv2d (Conv2D)             (None, 178, 178, 16)      448       
                                                                     
     average_pooling2d (AverageP  (None, 89, 89, 16)       0         
     ooling2D)                                                       
                                                                     
     conv2d_1 (Conv2D)           (None, 87, 87, 32)        4640      
                                                                     
     average_pooling2d_1 (Averag  (None, 43, 43, 32)       0         
     ePooling2D)                                                     
                                                                     
     conv2d_2 (Conv2D)           (None, 41, 41, 64)        18496     
                                                                     
     dropout (Dropout)           (None, 41, 41, 64)        0         
                                                                     
     flatten (Flatten)           (None, 107584)            0         
                                                                     
     dense (Dense)               (None, 128)               13770880  
                                                                     
     dense_1 (Dense)             (None, 4)                 516       
                                                                     
    =================================================================
    Total params: 13,794,980
    Trainable params: 13,794,980
    Non-trainable params: 0
    _________________________________________________________________
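
    The parameter counts in the summary can be reproduced by hand: a convolutional layer has one weight per kernel position and input channel plus one bias, per filter, and a dense layer has one weight per input plus one bias, per unit. A sketch with hypothetical helper names:

```python
def conv2d_params(kernel, in_channels, filters):
    """(kernel * kernel * in_channels weights + 1 bias) per filter."""
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(in_features, units):
    """(in_features weights + 1 bias) per unit."""
    return (in_features + 1) * units


layer_params = {
    "conv2d": conv2d_params(3, 3, 16),
    "conv2d_1": conv2d_params(3, 16, 32),
    "conv2d_2": conv2d_params(3, 32, 64),
    "dense": dense_params(41 * 41 * 64, 128),  # Flatten output: 107584 features
    "dense_1": dense_params(128, 4),
}
total_params = sum(layer_params.values())
print(layer_params)
print(total_params)  # 13794980
```

    Almost all of the parameters sit in the first dense layer, because the feature map is still 41x41x64 when it is flattened.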
    
    2.3. Compile settings

    • Loss function (loss): measures how far the model's predictions are from the labels during training. Here we use sparse_categorical_crossentropy, which works like categorical_crossentropy (multi-class cross-entropy loss) except that the true labels are integer-encoded (class 0 is the number 0, class 3 is the number 3; official docs: tf.keras.losses.SparseCategoricalCrossentropy)
    • Optimizer (optimizer): decides how the model is updated based on the data it sees and its loss function; here it is Adam (official docs: tf.keras.optimizers.Adam)
    • Metrics (metrics): used to monitor the training and testing steps. This run uses accuracy, the fraction of images that are classified correctly (official docs: tf.keras.metrics.Accuracy)

    # Set up the optimizer
    opt = tf.keras.optimizers.Adam(learning_rate=0.001)
    
    model.compile(optimizer=opt,
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
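
    Because the output layer has no softmax, the loss is configured with from_logits=True. The sketch below shows what that computes for one sample in plain Python; it illustrates the math only and is not the Keras implementation, and the function name is illustrative.

```python
import math


def sparse_categorical_crossentropy(logits, label):
    """Cross-entropy for one sample: raw logits and an integer class label."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]  # equals -log(softmax(logits)[label])


uniform_loss = sparse_categorical_crossentropy([0.0, 0.0, 0.0, 0.0], 3)
confident_loss = sparse_categorical_crossentropy([0.0, 0.0, 0.0, 10.0], 3)
print(uniform_loss)    # ln(4), about 1.386: the model has no preference yet
print(confident_loss)  # close to 0: high logit on the correct class
```

    A model that guesses uniformly among four classes has a loss of ln 4, about 1.386, which is a handy reference point when reading the first epochs of a training log.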
    
    2.4. Model training

    epochs = 30
    
    history = model.fit(
      train_ds,
      validation_data=val_ds,
      epochs=epochs
    )
    
    The number of steps (iterations) per epoch is: $step = \left\lceil \dfrac{exampleNums \times epoch}{batchsize} \right\rceil = \left\lceil \dfrac{(1125 - 225) \times 1}{64} \right\rceil = \lceil 14.0625 \rceil = 15$
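
    The same arithmetic in code, which also shows how small the final batch is (steps_per_epoch is a hypothetical helper):

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to see every example once."""
    return math.ceil(num_examples / batch_size)


num_train = 1125 - 225
steps = steps_per_epoch(num_train, 64)
last_batch = num_train - (steps - 1) * 64
print(steps)       # 15
print(last_batch)  # 4
```

    Fourteen full batches of 64 cover 896 images, so the fifteenth batch holds only the remaining 4.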

    Output:

    Epoch 1/30
    15/15 [==============================] - 12s 293ms/step - loss: 1.3749 - accuracy: 0.5144 - val_loss: 0.6937 - val_accuracy: 0.6533
    Epoch 2/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.6076 - accuracy: 0.7800 - val_loss: 0.4826 - val_accuracy: 0.7956
    Epoch 3/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.3909 - accuracy: 0.8589 - val_loss: 0.4557 - val_accuracy: 0.7911
    Epoch 4/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.2943 - accuracy: 0.8878 - val_loss: 0.4268 - val_accuracy: 0.8356
    Epoch 5/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.2307 - accuracy: 0.9056 - val_loss: 0.4260 - val_accuracy: 0.8400
    Epoch 6/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.2000 - accuracy: 0.9267 - val_loss: 0.3143 - val_accuracy: 0.8711
    Epoch 7/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.1496 - accuracy: 0.9367 - val_loss: 0.3277 - val_accuracy: 0.8844
    Epoch 8/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0894 - accuracy: 0.9678 - val_loss: 0.2851 - val_accuracy: 0.9200
    Epoch 9/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0638 - accuracy: 0.9800 - val_loss: 0.4995 - val_accuracy: 0.8578
    Epoch 10/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.1132 - accuracy: 0.9622 - val_loss: 0.4961 - val_accuracy: 0.8356
    Epoch 11/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0576 - accuracy: 0.9800 - val_loss: 0.3318 - val_accuracy: 0.8844
    Epoch 12/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0417 - accuracy: 0.9878 - val_loss: 0.5433 - val_accuracy: 0.8756
    Epoch 13/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0301 - accuracy: 0.9944 - val_loss: 0.3797 - val_accuracy: 0.8978
    Epoch 14/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0401 - accuracy: 0.9833 - val_loss: 0.3982 - val_accuracy: 0.8489
    Epoch 15/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0260 - accuracy: 0.9922 - val_loss: 0.4777 - val_accuracy: 0.8844
    Epoch 16/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0139 - accuracy: 0.9978 - val_loss: 0.3858 - val_accuracy: 0.8978
    Epoch 17/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0067 - accuracy: 0.9989 - val_loss: 0.3942 - val_accuracy: 0.9156
    Epoch 18/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0059 - accuracy: 0.9989 - val_loss: 0.4101 - val_accuracy: 0.8844
    Epoch 19/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0039 - accuracy: 1.0000 - val_loss: 0.5176 - val_accuracy: 0.8889
    Epoch 20/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0068 - accuracy: 0.9989 - val_loss: 0.3836 - val_accuracy: 0.9156
    Epoch 21/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0116 - accuracy: 0.9989 - val_loss: 0.4635 - val_accuracy: 0.8800
    Epoch 22/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0062 - accuracy: 0.9989 - val_loss: 0.4315 - val_accuracy: 0.9022
    Epoch 23/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0047 - accuracy: 1.0000 - val_loss: 0.5728 - val_accuracy: 0.8933
    Epoch 24/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0090 - accuracy: 0.9989 - val_loss: 0.5049 - val_accuracy: 0.8889
    Epoch 25/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0210 - accuracy: 0.9911 - val_loss: 0.5712 - val_accuracy: 0.8756
    Epoch 26/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0354 - accuracy: 0.9878 - val_loss: 0.6332 - val_accuracy: 0.8889
    Epoch 27/30
    15/15 [==============================] - 2s 146ms/step - loss: 0.0781 - accuracy: 0.9789 - val_loss: 0.5726 - val_accuracy: 0.8578
    Epoch 28/30
    15/15 [==============================] - 2s 144ms/step - loss: 0.0713 - accuracy: 0.9767 - val_loss: 0.5084 - val_accuracy: 0.8889
    Epoch 29/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0439 - accuracy: 0.9822 - val_loss: 0.5302 - val_accuracy: 0.8711
    Epoch 30/30
    15/15 [==============================] - 2s 145ms/step - loss: 0.0343 - accuracy: 0.9878 - val_loss: 0.4488 - val_accuracy: 0.8933
    
    Save the model:

    model.save('model')
    
    Log output:

    WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 3 of 3). These functions will not be directly callable after loading.
    INFO:tensorflow:Assets written to: model\assets
    INFO:tensorflow:Assets written to: model\assets
    
    3. Model evaluation (best validation accuracy: 92.00%)

    model = tf.keras.models.load_model('model')
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    epochs_range = range(epochs)
    
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend()
    plt.title('Training and Validation Accuracy')
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig('pic2.jpg', dpi=600)  # save at the given resolution
    plt.show()
    
    Output:
    [Image: training and validation accuracy (left) and loss (right) curves]
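
    Besides reading the curves, the best epoch can be picked out programmatically. In the sketch below the list is transcribed from the 30-epoch training log above; in the notebook it would be history.history['val_accuracy']. best_epoch is a hypothetical helper.

```python
# val_accuracy per epoch, transcribed from the 30-epoch log above
val_accuracy = [
    0.6533, 0.7956, 0.7911, 0.8356, 0.8400, 0.8711, 0.8844, 0.9200, 0.8578, 0.8356,
    0.8844, 0.8756, 0.8978, 0.8489, 0.8844, 0.8978, 0.9156, 0.8844, 0.8889, 0.9156,
    0.8800, 0.9022, 0.8933, 0.8889, 0.8756, 0.8889, 0.8578, 0.8889, 0.8711, 0.8933,
]


def best_epoch(values):
    """Return (1-based epoch, value) of the first maximum."""
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


epoch, value = best_epoch(val_accuracy)
print(epoch, value)  # 8 0.92
```

    The peak of 92.00% occurs at epoch 8, while the final epoch ends at 89.33%, so the saved model is not the best one seen during training.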

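    4. Model optimization (acc: 93.78%)

    Optimization idea:

    The training loss keeps falling while the validation loss levels off, which indicates that the network is overfitting. We try making the network structure somewhat more complex and lowering the learning rate slightly

    4.1. Modify the network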

    num_classes = 4
    
    
    model = models.Sequential([
        layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
        
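        layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # conv layer 1, 3x3 kernel
        layers.AveragePooling2D((2, 2)),               # pooling layer 1, 2x2 downsampling
        layers.Conv2D(32, (3, 3), activation='relu'),  # conv layer 2, 3x3 kernel
        layers.AveragePooling2D((2, 2)),               # pooling layer 2, 2x2 downsampling
        layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 3, 3x3 kernel
        layers.AveragePooling2D((2, 2)),               # pooling layer 3, 2x2 downsampling
        layers.Conv2D(128, (3, 3), activation='relu'), # conv layer 4, 3x3 kernel
        layers.Dropout(0.3),                           # reduces overfitting and improves generalization
        
        layers.Flatten(),                       # Flatten layer, connects the conv layers to the dense layers
        layers.Dense(256, activation='relu'),   # fully connected layer, further feature extraction
        layers.Dense(128, activation='relu'),   # fully connected layer, further feature extraction
        layers.Dense(num_classes)               # output layer, one logit per class
    ])
    
    model.summary()  # print the network structure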
    
    Output:

    Model: "sequential_7"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     rescaling_7 (Rescaling)     (None, 180, 180, 3)       0         
                                                                     
     conv2d_25 (Conv2D)          (None, 178, 178, 16)      448       
                                                                     
     average_pooling2d_18 (Avera  (None, 89, 89, 16)       0         
     gePooling2D)                                                    
                                                                     
     conv2d_26 (Conv2D)          (None, 87, 87, 32)        4640      
                                                                     
     average_pooling2d_19 (Avera  (None, 43, 43, 32)       0         
     gePooling2D)                                                    
                                                                     
     conv2d_27 (Conv2D)          (None, 41, 41, 64)        18496     
                                                                     
     average_pooling2d_20 (Avera  (None, 20, 20, 64)       0         
     gePooling2D)                                                    
                                                                     
     conv2d_28 (Conv2D)          (None, 18, 18, 128)       73856     
                                                                     
     dropout_8 (Dropout)         (None, 18, 18, 128)       0         
                                                                     
     flatten_7 (Flatten)         (None, 41472)             0         
                                                                     
     dense_16 (Dense)            (None, 256)               10617088  
                                                                     
     dense_17 (Dense)            (None, 128)               32896     
                                                                     
     dense_18 (Dense)            (None, 4)                 516       
                                                                     
    =================================================================
    Total params: 10,747,940
    Trainable params: 10,747,940
    Non-trainable params: 0
    _________________________________________________________________
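
    Although this network has more layers, its total parameter count is lower than that of the first model (10,747,940 versus 13,794,980). The sketch below traces the spatial size through the layers to show why: the extra pooling and convolution shrink the feature map before Flatten. The helper names are illustrative.

```python
def conv3x3(size):
    """3x3 kernel, stride 1, 'valid' padding: two pixels are lost."""
    return size - 2


def pool2x2(size):
    """2x2 window, stride 2, 'valid' padding: size is halved, rounded down."""
    return size // 2


size = 180
trace = [size]
for layer in (conv3x3, pool2x2, conv3x3, pool2x2, conv3x3, pool2x2, conv3x3):
    size = layer(size)
    trace.append(size)

flatten_optimized = size * size * 128                   # 18 * 18 * 128
first_dense_optimized = (flatten_optimized + 1) * 256   # weights + biases
first_dense_original = (41 * 41 * 64 + 1) * 128         # first model, for comparison

print(trace)                  # [180, 178, 89, 87, 43, 41, 20, 18]
print(flatten_optimized)      # 41472
print(first_dense_optimized)  # 10617088
print(first_dense_original)   # 13770880
```

    The first dense layer dominates the parameter count in both models, so shrinking the flattened feature vector from 107584 to 41472 features outweighs the added layers.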
    
    4.2. Tune the learning rate

    # Set up the optimizer
    opt = tf.keras.optimizers.Adam(learning_rate=0.0008)
    
    model.compile(optimizer=opt,
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    
    Train the model:

    epochs = 50
    
    history = model.fit(
      train_ds,
      validation_data=val_ds,
      epochs=epochs
    )
    
    Output:

    Epoch 1/50
    15/15 [==============================] - 2s 131ms/step - loss: 1.0564 - accuracy: 0.4989 - val_loss: 0.7628 - val_accuracy: 0.5689
    Epoch 2/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.7389 - accuracy: 0.6589 - val_loss: 0.8257 - val_accuracy: 0.6711
    Epoch 3/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.5849 - accuracy: 0.7667 - val_loss: 0.5710 - val_accuracy: 0.7422
    Epoch 4/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.4292 - accuracy: 0.8389 - val_loss: 0.6052 - val_accuracy: 0.7689
    Epoch 5/50
    15/15 [==============================] - 2s 140ms/step - loss: 0.3802 - accuracy: 0.8511 - val_loss: 0.7504 - val_accuracy: 0.7556
    Epoch 6/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.3367 - accuracy: 0.8667 - val_loss: 0.4836 - val_accuracy: 0.8311
    Epoch 7/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.2773 - accuracy: 0.8889 - val_loss: 0.3823 - val_accuracy: 0.8622
    Epoch 8/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.2457 - accuracy: 0.9067 - val_loss: 0.3668 - val_accuracy: 0.8622
    Epoch 9/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.2333 - accuracy: 0.9144 - val_loss: 0.4030 - val_accuracy: 0.8489
    Epoch 10/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.2526 - accuracy: 0.9089 - val_loss: 0.6440 - val_accuracy: 0.8044
    Epoch 11/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.2331 - accuracy: 0.9122 - val_loss: 0.4930 - val_accuracy: 0.8444
    Epoch 12/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.1934 - accuracy: 0.9311 - val_loss: 0.3481 - val_accuracy: 0.8844
    Epoch 13/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.1471 - accuracy: 0.9367 - val_loss: 0.3174 - val_accuracy: 0.9022
    Epoch 14/50
    15/15 [==============================] - 2s 122ms/step - loss: 0.1141 - accuracy: 0.9656 - val_loss: 0.4393 - val_accuracy: 0.8578
    Epoch 15/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0878 - accuracy: 0.9700 - val_loss: 0.4360 - val_accuracy: 0.8978
    Epoch 16/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0718 - accuracy: 0.9744 - val_loss: 0.3478 - val_accuracy: 0.8978
    Epoch 17/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0561 - accuracy: 0.9844 - val_loss: 0.3561 - val_accuracy: 0.9200
    Epoch 18/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.1152 - accuracy: 0.9544 - val_loss: 0.3755 - val_accuracy: 0.9111
    Epoch 19/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0642 - accuracy: 0.9778 - val_loss: 0.3634 - val_accuracy: 0.8978
    Epoch 20/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0347 - accuracy: 0.9922 - val_loss: 0.3544 - val_accuracy: 0.8978
    Epoch 21/50
    15/15 [==============================] - 2s 132ms/step - loss: 0.0432 - accuracy: 0.9844 - val_loss: 0.7549 - val_accuracy: 0.8222
    Epoch 22/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0641 - accuracy: 0.9811 - val_loss: 0.4202 - val_accuracy: 0.8933
    Epoch 23/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0295 - accuracy: 0.9900 - val_loss: 0.4618 - val_accuracy: 0.9200
    Epoch 24/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0131 - accuracy: 0.9978 - val_loss: 0.4210 - val_accuracy: 0.9067
    Epoch 25/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0172 - accuracy: 0.9944 - val_loss: 0.4878 - val_accuracy: 0.8978
    Epoch 26/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0086 - accuracy: 0.9978 - val_loss: 0.4908 - val_accuracy: 0.9111
    Epoch 27/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0096 - accuracy: 0.9967 - val_loss: 0.5744 - val_accuracy: 0.8978
    Epoch 28/50
    15/15 [==============================] - 2s 122ms/step - loss: 0.0058 - accuracy: 0.9989 - val_loss: 0.5868 - val_accuracy: 0.8889
    Epoch 29/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.0196 - accuracy: 0.9911 - val_loss: 0.5722 - val_accuracy: 0.8578
    Epoch 30/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.0108 - accuracy: 0.9967 - val_loss: 0.5498 - val_accuracy: 0.9022
    Epoch 31/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0041 - accuracy: 1.0000 - val_loss: 0.5006 - val_accuracy: 0.9111
    Epoch 32/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0020 - accuracy: 1.0000 - val_loss: 0.5010 - val_accuracy: 0.9067
    Epoch 33/50
    15/15 [==============================] - 2s 130ms/step - loss: 0.0067 - accuracy: 0.9967 - val_loss: 0.5418 - val_accuracy: 0.9156
    Epoch 34/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0092 - accuracy: 0.9967 - val_loss: 0.6289 - val_accuracy: 0.9067
    Epoch 35/50
    15/15 [==============================] - 2s 122ms/step - loss: 0.0152 - accuracy: 0.9933 - val_loss: 0.5908 - val_accuracy: 0.8978
    Epoch 36/50
    15/15 [==============================] - 2s 123ms/step - loss: 0.0359 - accuracy: 0.9911 - val_loss: 0.5404 - val_accuracy: 0.8756
    Epoch 37/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0492 - accuracy: 0.9811 - val_loss: 0.5009 - val_accuracy: 0.9022
    Epoch 38/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0395 - accuracy: 0.9844 - val_loss: 0.4718 - val_accuracy: 0.9244
    Epoch 39/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0357 - accuracy: 0.9956 - val_loss: 0.6038 - val_accuracy: 0.8844
    Epoch 40/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0309 - accuracy: 0.9889 - val_loss: 0.4189 - val_accuracy: 0.9200
    Epoch 41/50
    15/15 [==============================] - 2s 126ms/step - loss: 0.0093 - accuracy: 0.9989 - val_loss: 0.5180 - val_accuracy: 0.9067
    Epoch 42/50
    15/15 [==============================] - 2s 125ms/step - loss: 0.0033 - accuracy: 1.0000 - val_loss: 0.4415 - val_accuracy: 0.9244
    Epoch 43/50
    15/15 [==============================] - 2s 124ms/step - loss: 0.0016 - accuracy: 1.0000 - val_loss: 0.4622 - val_accuracy: 0.9378
    Epoch 44/50
    15/15 [==============================] - 2s 125ms/step - loss: 5.7711e-04 - accuracy: 1.0000 - val_loss: 0.4805 - val_accuracy: 0.9333
    Epoch 45/50
    15/15 [==============================] - 2s 125ms/step - loss: 4.1283e-04 - accuracy: 1.0000 - val_loss: 0.4820 - val_accuracy: 0.9244
    Epoch 46/50
    15/15 [==============================] - 2s 126ms/step - loss: 3.2792e-04 - accuracy: 1.0000 - val_loss: 0.4859 - val_accuracy: 0.9333
    Epoch 47/50
    15/15 [==============================] - 2s 127ms/step - loss: 2.7573e-04 - accuracy: 1.0000 - val_loss: 0.4932 - val_accuracy: 0.9289
    Epoch 48/50
    15/15 [==============================] - 2s 123ms/step - loss: 2.7769e-04 - accuracy: 1.0000 - val_loss: 0.4877 - val_accuracy: 0.9333
    Epoch 49/50
    15/15 [==============================] - 2s 124ms/step - loss: 2.6387e-04 - accuracy: 1.0000 - val_loss: 0.5107 - val_accuracy: 0.9289
    Epoch 50/50
    15/15 [==============================] - 2s 122ms/step - loss: 2.1140e-04 - accuracy: 1.0000 - val_loss: 0.4979 - val_accuracy: 0.9378
    
    val_accuracy peaks at 93.78%, first at epoch 43 and again at the final epoch 50

    Training curves:
    [Image: training and validation curves of the optimized model]

  • Original article: https://blog.csdn.net/qq_45550375/article/details/126325300