• TensorFlow 2.0 Notes - Training a VGG13-like CNN on the CIFAR-100 Dataset


            These notes cover training a CNN on the CIFAR-100 dataset. The code uses a VGG13-like architecture built as two Sequential models (the convolutional stack and the fully connected head); instead of a Flatten layer, a reshape operation bridges the CNN output into the fully connected layers. Because the network is fairly deep and has far more parameters than the earlier networks in this series, training was stopped after 10 epochs (on an RTX 4090), at which point accuracy was around 33.8%.
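            The reshape stands in for a Flatten layer; as a minimal standalone sketch of the equivalence (batch size 4 chosen arbitrarily):

    import tensorflow as tf
    from tensorflow.keras import layers

    # A [b, 1, 1, 512] tensor, shaped like the CNN output in the code below
    out = tf.random.normal([4, 1, 1, 512])

    flat_a = layers.Flatten()(out)       # Keras Flatten layer
    flat_b = tf.reshape(out, [-1, 512])  # reshape, as used in this note

    print(flat_a.shape, flat_b.shape)    # (4, 512) (4, 512)
    print(bool(tf.reduce_all(flat_a == flat_b)))  # True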

    import os
    import time
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    #tf.random.set_seed(12345)
    tf.__version__

    # If the download is slow, fetch the archive with a download manager; the
    # official URL works directly:
    # https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
    # Once downloaded, put cifar-100-python.tar.gz in the .keras\datasets
    # directory (on my machine: C:\Users\Administrator\.keras\datasets).
    # Reference: https://blog.csdn.net/zy_like_study/article/details/104219259
    (x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()
    print("Train data shape:", x_train.shape)
    print("Train label shape:", y_train.shape)
    print("Test data shape:", x_test.shape)
    print("Test label shape:", y_test.shape)
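    # Expected shapes for CIFAR-100 as packaged by Keras:
    # x_train (50000, 32, 32, 3), y_train (50000, 1),
    # x_test (10000, 32, 32, 3), y_test (10000, 1)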
    def preprocess(x, y):
        # Scale pixels to [0, 1] and cast labels to int32
        x = tf.cast(x, dtype=tf.float32) / 255.
        y = tf.cast(y, dtype=tf.int32)
        return x, y
    # Labels come as shape (N, 1); squeeze them to shape (N,)
    y_train = tf.squeeze(y_train, axis=1)
    y_test = tf.squeeze(y_test, axis=1)
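    # Without the squeeze, tf.one_hot(y, depth=100) would yield shape
    # (b, 1, 100) instead of (b, 100), and tf.equal(pred, y) in the test
    # loop would broadcast (b,) against (b, 1) into a (b, b) matrix.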
    batch_size = 128
    train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)
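    # Note: shuffle(1000) only shuffles within a sliding 1000-example buffer;
    # a buffer as large as the training set (50000) would give a full shuffle
    # at the cost of extra memory.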
    test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_db = test_db.map(preprocess).batch(batch_size)
    sample = next(iter(train_db))
    print("Train data sample:", sample[0].shape, sample[1].shape,
          tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
    # Build the CNN: 5 units in total, each made of two conv layers
    # followed by a max-pooling layer
    cnn_layers = [
        # unit 1
        layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
        layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2,2], strides=2),
        # unit 2
        layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
        layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2,2], strides=2),
        # unit 3
        layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
        layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2,2], strides=2),
        # unit 4
        layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
        layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2,2], strides=2),
        # unit 5
        layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
        layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2,2], strides=2),
    ]
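    # With padding='same' the conv layers keep the spatial size, and each
    # 2x2/stride-2 pool halves it: 32 -> 16 -> 8 -> 4 -> 2 -> 1. The stack
    # therefore maps [b, 32, 32, 3] to [b, 1, 1, 512] and holds roughly
    # 9.4M trainable parameters.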
    def main():
        # [b, 32, 32, 3] => [b, 1, 1, 512]
        cnn_net = Sequential(cnn_layers)
        cnn_net.build(input_shape=[None, 32, 32, 3])
        # Sanity check of the conv-stack output:
        #x = tf.random.normal([4, 32, 32, 3])
        #out = cnn_net(x)
        #print(out.shape)

        # Fully connected head; the final layer outputs 100-way logits
        fc_net = Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(128, activation='relu'),
            layers.Dense(100, activation=None),
        ])
        fc_net.build(input_shape=[None, 512])
        # Optimizer
        optimizer = optimizers.Adam(learning_rate=1e-4)
        # Gather the trainable variables of both networks; `+` concatenates
        # the two lists, e.g. [1, 2] + [3, 4] => [1, 2, 3, 4]
        variables = cnn_net.trainable_variables + fc_net.trainable_variables
        # Training loop
        num_epoches = 10
        for epoch in range(num_epoches):
            for step, (x, y) in enumerate(train_db):
                with tf.GradientTape() as tape:
                    # [b, 32, 32, 3] => [b, 1, 1, 512]
                    out = cnn_net(x)
                    # Flatten => [b, 512]
                    out = tf.reshape(out, [-1, 512])
                    # 100-way logits from the fully connected head
                    # [b, 512] => [b, 100]
                    logits = fc_net(out)
                    # One-hot encode the labels
                    y_onehot = tf.one_hot(y, depth=100)
                    # Compute the loss
                    loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                    loss = tf.reduce_mean(loss)
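                    # from_logits=True makes the loss apply softmax internally,
                    # which is numerically more stable than feeding it
                    # probabilities; this is why the last Dense layer uses
                    # activation=None.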
                # Compute gradients
                grads = tape.gradient(loss, variables)
                # Update parameters
                optimizer.apply_gradients(zip(grads, variables))
                if step % 100 == 0:
                    print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))
            # Evaluate on the test set
            total_samples = 0
            total_correct = 0
            for x, y in test_db:
                out = cnn_net(x)
                out = tf.reshape(out, [-1, 512])
                logits = fc_net(out)
                prob = tf.nn.softmax(logits, axis=1)
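                # softmax is monotonic, so argmax over the raw logits would
                # give the same predictions; the softmax is kept for clarity.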
                pred = tf.argmax(prob, axis=1)
                pred = tf.cast(pred, dtype=tf.int32)
                correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
                correct = tf.reduce_sum(correct)
                total_samples += x.shape[0]
                total_correct += int(correct)
            # Accuracy over the whole test set
            acc = total_correct / total_samples
            print("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
    if __name__ == '__main__':
        main()

    Run output: (the training-log screenshot is not reproduced here; as noted above, accuracy reached about 33.8% after 10 epochs.)
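    As a side note, the two hand-chained Sequential models could be folded into one; a sketch of an equivalent single-model setup (assuming the cnn_layers list from the code above, with layers.Flatten() replacing the manual tf.reshape):

    model = Sequential(cnn_layers + [
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(100, activation=None),
    ])
    model.build(input_shape=[None, 32, 32, 3])
    model.summary()

    With this form, the hand-written training loop above would work unchanged, with model.trainable_variables in place of the concatenated variable list.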

  • Original post: https://blog.csdn.net/vivo01/article/details/137977454