• 卷积神经网络——vgg16网络及其python实现


    1、介绍     

            VGG-16网络包括13个卷积层和3个全连接层,网络结构较LeNet-5等网络变得十分复杂,但同时也有不错的效果。VGG16有强大的拟合能力在当时取得了非常的效果,但同时VGG也有部分不足:
    1、巨大参数量导致训练时间过长,调参难度较大;
    2、模型所需内存容量大,VGG的权值文件很大,用到实际应用会比较困难。

    2、结构原理 

    这是经典的vgg网络,输入图片大小为224*224。

    下面这为官方给出的几种VGG结构图。

     

     现在多用的为D模型。

    简单介绍下过程,输入224*224大小的图片,然后用两次64个3*3的卷积核进行全采集,也就是补零采集,保证特征不丢失,得到64*224*224的特征;池化层得到64*112*112;再利用128个3*3的卷积核进行特征采集两次,得到特征112*112*128;池化得到56*56*128大小特征.........反复这样操作,最后卷积完得到7*7*512的特征,然后利用全连接层进行展开,最后得到1000个特征,随后进行概率分类操作。

    3、python实现

            选用的数据集为fashion数据集,具体请另外了解。数据可直接在库中导入,本文用class网络编写神经网络程序。

    1. class VGG16(Model):
    2. def __init__(self):
    3. super(VGG16, self).__init__()
    4. self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding='same')
    5. self.b1 = BatchNormalization()
    6. self.a1 = Activation('relu')
    7. self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding='same', )
    8. self.b2 = BatchNormalization()
    9. self.a2 = Activation('relu')
    10. self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    11. self.d1 = Dropout(0.2)
    12. self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
    13. self.b3 = BatchNormalization()
    14. self.a3 = Activation('relu')
    15. self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding='same')
    16. self.b4 = BatchNormalization()
    17. self.a4 = Activation('relu')
    18. self.p2 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    19. self.d2 = Dropout(0.2)
    20. self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    21. self.b5 = BatchNormalization()
    22. self.a5 = Activation('relu')
    23. self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    24. self.b6 = BatchNormalization()
    25. self.a6 = Activation('relu')
    26. self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding='same')
    27. self.b7 = BatchNormalization()
    28. self.a7 = Activation('relu')
    29. self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    30. self.d3 = Dropout(0.2)
    31. self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    32. self.b8 = BatchNormalization()
    33. self.a8 = Activation('relu')
    34. self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    35. self.b9 = BatchNormalization()
    36. self.a9 = Activation('relu')
    37. self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    38. self.b10 = BatchNormalization()
    39. self.a10 = Activation('relu')
    40. self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    41. self.d4 = Dropout(0.2)
    42. self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    43. self.b11 = BatchNormalization()
    44. self.a11 = Activation('relu')
    45. self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    46. self.b12 = BatchNormalization()
    47. self.a12 = Activation('relu')
    48. self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding='same')
    49. self.b13 = BatchNormalization()
    50. self.a13 = Activation('relu')
    51. self.p5 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
    52. self.d5 = Dropout(0.2)
    53. self.flatten = Flatten()
    54. self.f1 = Dense(512, activation='relu')
    55. self.d6 = Dropout(0.2)
    56. self.f2 = Dense(512, activation='relu')
    57. self.d7 = Dropout(0.2)
    58. self.f3 = Dense(10, activation='softmax')
    59. def call(self, x):
    60. x = self.c1(x)
    61. x = self.b1(x)
    62. x = self.a1(x)
    63. x = self.c2(x)
    64. x = self.b2(x)
    65. x = self.a2(x)
    66. x = self.p1(x)
    67. x = self.d1(x)
    68. x = self.c3(x)
    69. x = self.b3(x)
    70. x = self.a3(x)
    71. x = self.c4(x)
    72. x = self.b4(x)
    73. x = self.a4(x)
    74. x = self.p2(x)
    75. x = self.d2(x)
    76. x = self.c5(x)
    77. x = self.b5(x)
    78. x = self.a5(x)
    79. x = self.c6(x)
    80. x = self.b6(x)
    81. x = self.a6(x)
    82. x = self.c7(x)
    83. x = self.b7(x)
    84. x = self.a7(x)
    85. x = self.p3(x)
    86. x = self.d3(x)
    87. x = self.c8(x)
    88. x = self.b8(x)
    89. x = self.a8(x)
    90. x = self.c9(x)
    91. x = self.b9(x)
    92. x = self.a9(x)
    93. x = self.c10(x)
    94. x = self.b10(x)
    95. x = self.a10(x)
    96. x = self.p4(x)
    97. x = self.d4(x)
    98. x = self.c11(x)
    99. x = self.b11(x)
    100. x = self.a11(x)
    101. x = self.c12(x)
    102. x = self.b12(x)
    103. x = self.a12(x)
    104. x = self.c13(x)
    105. x = self.b13(x)
    106. x = self.a13(x)
    107. x = self.p5(x)
    108. x = self.d5(x)
    109. x = self.flatten(x)
    110. x = self.f1(x)
    111. x = self.d6(x)
    112. x = self.f2(x)
    113. x = self.d7(x)
    114. y = self.f3(x)
    115. return y
    116. model = VGG16()

    读取数据

    1. import tensorflow as tf
    2. import os
    3. import numpy as np
    4. from matplotlib import pyplot as plt
    5. from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
    6. from tensorflow.keras import Model
    7. np.set_printoptions(threshold=np.inf)
    8. fashion = tf.keras.datasets.fashion_mnist
    9. (x_train, y_train), (x_test, y_test) = fashion.load_data()
    10. x_train, x_test = x_train / 255.0, x_test / 255.0
    11. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
    12. x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

    迭代训练

    1. model.compile(optimizer='adam',
    2. loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    3. metrics=['sparse_categorical_accuracy'])
    4. cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
    5. save_weights_only=True,
    6. save_best_only=True)
    7. history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1,
    8. callbacks=[cp_callback])

    绘制结果图

    1. acc = history.history['sparse_categorical_accuracy']
    2. val_acc = history.history['val_sparse_categorical_accuracy']
    3. loss = history.history['loss']
    4. val_loss = history.history['val_loss']
    5. plt.subplot(1, 2, 1)
    6. plt.plot(acc, label='Training Accuracy')
    7. plt.plot(val_acc, label='Validation Accuracy')
    8. plt.title('Training and Validation Accuracy')
    9. plt.legend()
    10. plt.subplot(1, 2, 2)
    11. plt.plot(loss, label='Training Loss')
    12. plt.plot(val_loss, label='Validation Loss')
    13. plt.title('Training and Validation Loss')
    14. plt.legend()
    15. plt.show()

     

     虽然不是很稳定,但总的来说准确率还可以。

  • 相关阅读:
    Unity3D XML与Properties配置文件读取详解
    Sprinig Boot优雅实现接口幂等性
    JVM【类加载与GC垃圾回收机制】
    Qt事件的详细介绍和原理
    UML--类图的表示
    新款 锐科达 SV-2402VP SIP广播音频模块 支持RTP流音频广播
    【虹科干货】逻辑数据库可能已经无法满足需求了!
    大白话说Python+Flask入门(六)Flask SQLAlchemy操作mysql数据库
    【STM32单片机】贪吃蛇游戏设计
    剑指 Offer 46. 把数字翻译成字符串(DP)
  • 原文地址:https://blog.csdn.net/abc1234abcdefg/article/details/125495965