• 深度学习实战(1):花的分类任务


    写在前面:

    实验目的:通过搭建 AlexNet 神经网络建立分类模型,并使用训练数据训练该模型,从而实现对一张花卉图片进行类别分类。

    Python版本:Python3

    IDE:VSCode

    系统:MacOS

    数据集以及代码的资源放在文章末尾了 有需要请自取~

    目录

    写在前面:

    前言

    数据集 

    训练模型代码 (附有注释)

    训练集数据量展示

    训练迭代过程展示

    训练结果 Accuracy展示 

    训练结果 Loss展示 

    测试集 

    预测结果代码 

    预测结果展示

    结语 


     

    前言

    本文仅作为学习训练 不涉及任何商业用途 如有错误或不足之处还请指出

    数据集 

    数据集一共包含五种花的类别,但本次实验仅使用 rose(玫瑰)和 sunflower(向日葵)两种类别进行分类测试。

    五种花的类别:

     Rose:

    Sunflower: 

    训练模型代码 (附有注释)

    1. import os , glob
    2. from sklearn.model_selection import train_test_split
    3. import tensorflow as tf
    4. from tensorflow import keras
    5. from tensorflow.keras import layers
    6. import matplotlib.pyplot as plt
    7. # 变量
    8. resize = 224 # 图片尺寸参数
    9. epochs = 8 # 迭代次数
    10. batch_size = 5 # 每次训练多少张
    11. #——————————————————————————————————————————————————————————————————————————————————
    12. # 训练集路径
    13. train_data_path = '/Users/liqun/Desktop/KS/MyPython/DataSet/flowers/Training'
    14. # 玫瑰花文件夹路径
    15. rose_path = os.path.join(train_data_path,'rose')
    16. # 太阳花文件夹路径
    17. sunflower_path = os.path.join(train_data_path,'sunflower')
    18. # 将文件夹内的图片读取出来
    19. fpath_rose = [os.path.abspath(fp) for fp in glob.glob(os.path.join(rose_path,'*.jpg'))]
    20. fpath_sunflower = [os.path.abspath(fp) for fp in glob.glob(os.path.join(sunflower_path,'*.jpg'))]
    21. #文件数量
    22. num_rose = len(fpath_rose)
    23. num_sunflower = len(fpath_sunflower)
    24. # 设置标签
    25. label_rose = [0] * num_rose
    26. label_sunflower = [1] * num_sunflower
    27. # 展示
    28. print('rose: ', num_rose)
    29. print('sunflower: ', num_sunflower)
    30. # 划分为多少验证集
    31. RATIO_TEST = 0.1
    32. num_rose_test = int(num_rose * RATIO_TEST)
    33. num_sunflower_test = int(num_sunflower * RATIO_TEST)
    34. # train
    35. fpath_train = fpath_rose[num_rose_test:] + fpath_sunflower[num_sunflower_test:]
    36. label_train = label_rose[num_rose_test:] + label_sunflower[num_sunflower_test:]
    37. # validation
    38. fpath_vali = fpath_rose[:num_rose_test] + fpath_sunflower[:num_sunflower_test]
    39. label_vali = label_rose[:num_rose_test] + label_sunflower[:num_sunflower_test]
    40. num_train = len(fpath_train)
    41. num_vali = len(fpath_vali)
    42. # 展示
    43. print('num_train: ', num_train)
    44. print('num_label: ', num_vali)
    45. # 预处理函数
    46. def preproc(fpath, label):
    47. image_byte = tf.io.read_file(fpath) # 读取文件
    48. image = tf.io.decode_image(image_byte) # 检测图像是否为BMP,GIF,JPEG或PNG,并执行相应的操作将输入字节string转换为类型uint8的Tensor
    49. image_resize = tf.image.resize_with_pad(image, 224, 224) #缩放到224*224
    50. image_norm = tf.cast(image_resize, tf.float32) / 255. #归一化
    51. label_onehot = tf.one_hot(label, 2)
    52. return image_norm, label_onehot
    53. dataset_train = tf.data.Dataset.from_tensor_slices((fpath_train, label_train)) #将数据进行预处理
    54. dataset_train = dataset_train.shuffle(num_train).repeat() #打乱顺序
    55. dataset_train = dataset_train.map(preproc, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    56. dataset_train = dataset_train.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) #一批次处理多少份
    57. dataset_vali = tf.data.Dataset.from_tensor_slices((fpath_vali, label_vali))
    58. dataset_vali = dataset_vali.shuffle(num_vali).repeat()
    59. dataset_vali = dataset_vali.map(preproc, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    60. dataset_vali = dataset_vali.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    61. #——————————————————————————————————————————————————————————————————————————————————
    62. # 建立模型 卷积神经网络
    63. model = tf.keras.Sequential(name='Alexnet')
    64. # 第一层
    65. model.add(layers.Conv2D(filters=96, kernel_size=(11,11),
    66. strides=(4,4), padding='valid',
    67. input_shape=(resize,resize,3),
    68. activation='relu'))
    69. model.add(layers.BatchNormalization())
    70. # 第一层池化层:最大池化层
    71. model.add(layers.MaxPooling2D(pool_size=(3,3),
    72. strides=(2,2),
    73. padding='valid'))
    74. #第二层
    75. model.add(layers.Conv2D(filters=256, kernel_size=(5,5),
    76. strides=(1,1), padding='same',
    77. activation='relu'))
    78. model.add(layers.BatchNormalization())
    79. #第二层池化层
    80. model.add(layers.MaxPooling2D(pool_size=(3,3),
    81. strides=(2,2),
    82. padding='valid'))
    83. #第三层
    84. model.add(layers.Conv2D(filters=384, kernel_size=(3,3),
    85. strides=(1,1), padding='same',
    86. activation='relu'))
    87. #第四层
    88. model.add(layers.Conv2D(filters=384, kernel_size=(3,3),
    89. strides=(1,1), padding='same',
    90. activation='relu'))
    91. #第五层
    92. model.add(layers.Conv2D(filters=256, kernel_size=(3,3),
    93. strides=(1,1), padding='same',
    94. activation='relu'))
    95. #池化层
    96. model.add(layers.MaxPooling2D(pool_size=(3,3),
    97. strides=(2,2), padding='valid'))
    98. #第6,7,8层
    99. model.add(layers.Flatten())
    100. model.add(layers.Dense(4096, activation='relu'))
    101. model.add(layers.Dropout(0.5))
    102. model.add(layers.Dense(4096, activation='relu'))
    103. model.add(layers.Dropout(0.5))
    104. model.add(layers.Dense(1000, activation='relu'))
    105. model.add(layers.Dropout(0.5))
    106. # Output Layer
    107. model.add(layers.Dense(2, activation='softmax'))
    108. # Training 优化器 随机梯度下降算法
    109. model.compile(loss='categorical_crossentropy',
    110. optimizer='sgd', #梯度下降法
    111. metrics=['accuracy'])
    112. history = model.fit(dataset_train,
    113. steps_per_epoch = num_train//batch_size,
    114. epochs = epochs, #迭代次数
    115. validation_data = dataset_vali,
    116. validation_steps = num_vali//batch_size,
    117. verbose = 1)
    118. # 评分标准
    119. scores_train = model.evaluate(dataset_train, steps=num_train//batch_size, verbose=1)
    120. print(scores_train)
    121. scores_vali = model.evaluate(dataset_vali, steps=num_vali//batch_size, verbose=1)
    122. print(scores_vali)
    123. #保存模型
    124. model.save('/Users/liqun/Desktop/KS/MyPython/project/flowerModel.h5')
    125. '''
    126. history对象的history内容(history.history)是字典类型,
    127. 键的内容受metrics的设置影响,值的长度与epochs值一致。
    128. '''
    129. history_dict = history.history
    130. train_loss = history_dict['loss']
    131. train_accuracy = history_dict['accuracy']
    132. val_loss = history_dict['val_loss']
    133. val_accuracy = history_dict['val_accuracy']
    134. # Draw loss
    135. plt.figure()
    136. plt.plot(range(epochs), train_loss, label='train_loss')
    137. plt.plot(range(epochs), val_loss, label='val_loss')
    138. plt.legend()
    139. plt.xlabel('epochs')
    140. plt.ylabel('loss')
    141. # Draw accuracy
    142. plt.figure()
    143. plt.plot(range(epochs), train_accuracy, label='train_accuracy')
    144. plt.plot(range(epochs), val_accuracy, label='val_accuracy')
    145. plt.legend()
    146. plt.xlabel('epochs')
    147. plt.ylabel('accuracy')
    148. # Display
    149. plt.show()
    150. print('Train has finished')

    训练集数据量展示

    训练迭代过程展示

    训练结果 Accuracy展示 

    训练结果 Loss展示 

    测试集 

    预测结果代码 

    1. import cv2
    2. from tensorflow.keras.models import load_model
    3. resize = 224
    4. label = ('rose', 'sunflower')
    5. image = cv2.resize(cv2.imread('/Users/liqun/Desktop/KS/MyPython/DataSet/flowers/Training/sunflower/23286304156_3635f7de05.jpg'),(resize,resize))
    6. image = image.astype("float") / 255.0 # 归一化
    7. image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
    8. # 加载模型
    9. model = load_model('/Users/liqun/Desktop/KS/MyPython/project/flowerModel.h5')
    10. predict = model.predict(image)
    11. i = predict.argmax(axis=1)[0]
    12. # 展示结果
    13. print('——————————————————————')
    14. print('Predict result')
    15. print(label[i],':',max(predict[0])*100,'%')

    预测结果展示

    结语 

    模型到这里就训练并检测完毕了 如有需要的小伙伴可以下载下方的数据集测试集及源代码

    链接: https://pan.baidu.com/s/1OJfwcF1PvX9qkZwT7MXd_Q?pwd=i0bt 提取码: i0bt

    如果我的文章对你有帮助 麻烦点个赞再走呀 

  • 相关阅读:
    算法训练 第八周
    3.1_2 覆盖与交换
    【vue】牛客专题训练01
    MySql安全加固:无关或匿名帐号&是否更改root用户&避免空口令用户&是否加密数据库密码
    【前端】CSS
    园子开店记:被智能的淘宝处罚,说是“预防性的违规”
    易基因|RNA m6A甲基化测序(MeRIP-seq)技术介绍
    vscode终端npm install报错
    创建个人github.io主页(基础版)//吐槽:很多国内教程已经失效了
    Eureka: Netflix开源的服务发现框架
  • 原文地址:https://blog.csdn.net/m0_54689021/article/details/126075918