• 如何使用Tensorflow的VGG16预置模型


    参考列表

    1. 如何用VGG16训练MNIST数据集?
    2. Tensorflow官方文档

    1 机器和环境

    使用的是矩池云+Pycharm进行测试,镜像内容为:

    • Ubuntu18.04
    • Python 3.9
    • CUDA 11.2
    • cuDNN 8
    • NVCC
    • Tensorflow 2.8.0
    • VNC

    预置模型参数解释

    在官方给定的文档中,模型的预置参数如下

    tf.keras.applications.vgg16.VGG16(
        include_top=True,
        weights='imagenet',
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000,
        classifier_activation='softmax'
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    1. weights = ‘imagenet’:这表示模型使用了ImageNet中的预训练好的权值,直接拿来用可以大幅度降低训练成本,即“迁移学习”,否则把这个变量置为None,一切都从头开始训练。(笔者以MNIST手写数字识别为例,使用ImageNet预训练好的权值是第一个epoch即达到了87%,而不适用的话就只有20%);
    2. include_top :是否需要它的全连接网络(头部网络),由于Image Net训练的是千分类的任务(classes=1000),如果需要使用自己的数据,则需要把这个变量置为False,并且把classes这个变量缺省;
    3. input_shape = None:默认的值(224,244,3),也需要根据自己的数据集来更改;
    4. pooling:用于特征提取的可选池化模式,None意味着模型的输出将是最后一个卷积块的 4D 张量输出,'avg’意味着全局平均池化将应用于最后一个卷积块的输出,因此模型的输出将是一个 2D 张量,'max’表示将应用全局最大池化。
    5. input_tensor:可选的参数,即tensor作为模型的图像输入,保持默认即可;
    6. classifier_activation:输入字符串或者回调,激活函数用作最后的全连接层,当include_top 为False的时候可以忽略该选项,当加载预训练模型的weights 时,即weights不为None时,该选项只能为None或者’Softmax’。

    实例——应用MNIST数据集训练

    (1) 导入包

    import tensorflow as tf
    
    • 1

    (2) 加载数据集

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    
    • 1

    (3) 处理数据集的形状,参考如何用VGG16训练MNIST数据集?
    这里测试了一下,不可以忽略该步骤,因为网络要求最小的size是32 × 32 × 3

    现在,我们需要28×28 填充为32×32(当然也可以填充为48×48,它可视为超参数,为了提升性能,你可以任意”折腾“,经过测试,32×32性能较好)。之所以填充,是因为VGG16有5层池化(汇聚)层,如果不加以扩充,可能多层池化后图片变成了1*1大小。

        x_train = tf.pad(x_train, [[0, 0], [2, 2], [2, 2]]) / 255  # 填充,除以255表示归一化
        x_test = tf.pad(x_test, [[0, 0], [2, 2], [2, 2]]) / 255
    
    • 1
    • 2

    (4) 升级数据维度,参考如何用VGG16训练MNIST数据集?

    由于VGG16是适配3通道图片的,而MNIST是单通道的图片,所以需要将其”升级“为3单通道。最简单的升级就是对单通道的图片连续复制3次。

        x_train = tf.stack((x_train, x_train, x_train), axis=-1)
        x_test = tf.stack((x_test, x_test, x_test), axis=-1)
    
    • 1
    • 2

    (5) 模型实例化

       model = tf.keras.applications.VGG16(
            include_top=False,
            weights= 'imagenet',
            input_tensor=None,
            input_shape=(32, 32, 3),
            pooling=None,
            classes=1000,
            classifier_activation="softmax",
        )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    (6)冻结其他层
    由于使用了预训练模型,VGG16的基础部分(卷积、池化等)的权值都不用训练。所以逐层设定layer.trainable = False将其冻结。

     for layer in model.layers:
            layer.trainable = False
        )
    
    • 1
    • 2
    • 3

    (7)由于没有使用模型自带的全连接层,因此我们需要自己重写全连接层

        x = tf.keras.layers.Flatten()(model.output)  # 展平
        x = tf.keras.layers.Dense(4096, activation='relu')(x)  # 定义全连接
        x = tf.keras.layers.Dropout(0.5)(x) # Dropout
        x = tf.keras.layers.Dense(4096, activation='relu')(x)
        x = tf.keras.layers.Dropout(0.5)(x)
        predictions = tf.keras.layers.Dense(10, activation='softmax')(x)  # softmax回归,10分类
        head_model = tf.keras.Model(inputs=model.input, outputs=predictions)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    (8)定义优化器,损失函数,训练的batch,epoch,并将训练集和测试集喂给网络

     head_model.compile(optimizer='adam',
                           loss=tf.keras.losses.sparse_categorical_crossentropy,
                           metrics=['accuracy'])
        history = head_model.fit(x_train, y_train,
                                 batch_size=64,
                                 epochs=10,
                                 validation_data=(x_test, y_test))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    (9) 全部代码如下

    import tensorflow as tf
    
    
    if __name__ == '__main__':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    
        x_train = tf.pad(x_train, [[0, 0], [2, 2], [2, 2]]) / 255  # 填充,除以255表示归一化
        x_test = tf.pad(x_test, [[0, 0], [2, 2], [2, 2]]) / 255
    
        x_train = tf.stack((x_train, x_train, x_train), axis=-1)
        x_test = tf.stack((x_test, x_test, x_test), axis=-1)
    
        model = tf.keras.applications.VGG16(
            include_top=False,
            weights= 'imagenet',
            input_tensor=None,
            input_shape=(32, 32, 3),
            pooling=None,
            classes=1000,
            classifier_activation="softmax",
        )
        for layer in model.layers:
            layer.trainable = False
    
        model.summary()
    
        x = tf.keras.layers.Flatten()(model.output)  # 展平
        x = tf.keras.layers.Dense(4096, activation='relu')(x)  # 定义全连接
        x = tf.keras.layers.Dropout(0.5)(x)
        x = tf.keras.layers.Dense(4096, activation='relu')(x)
        x = tf.keras.layers.Dropout(0.5)(x)
        predictions = tf.keras.layers.Dense(10, activation='softmax')(x)  # softmax回归,10分类
        head_model = tf.keras.Model(inputs=model.input, outputs=predictions)
        # 搭配训练参数
        head_model.compile(optimizer='adam',
                           loss=tf.keras.losses.sparse_categorical_crossentropy,
                           metrics=['accuracy'])
        history = head_model.fit(x_train, y_train,
                                 batch_size=64,
                                 epochs=10,
                                 validation_data=(x_test, y_test))
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
  • 相关阅读:
    模拟算法及其优化
    带你掌握如何使用CANN 算子ST测试工具msopst
    pycharm 里面安装 codeium 插件的时候,不能够弹出登录界面
    XSS-labs靶场实战(一)——第1-3关
    软考高级系统架构设计师系列论文之:论软件系统架构风格
    Python之文件处理-JSON文件
    PostgreSQL:查询元数据(表 、字段)信息、库表导入导出命令
    【从零开始的大数据学习】Flink官方教程学习笔记(一)
    代码随想录学习记录——栈与队列篇
    快速部署 微软开源的 Garnet 键值数据库
  • 原文地址:https://blog.csdn.net/qq_42635142/article/details/125523230