• 趋动云模型--猫狗识别


    写在前面:感谢Datawhale各位大大和各位群友的答疑解惑!

    目录

    代码注释

    函数详解


    代码注释

    1. ###导入模块 ###
    2. #argparse是一个Python模块:命令行选项、参数和子命令解析器。
    3. #os模块提供的就是各种 Python 程序与操作系统进行交互的接口
    4. import argparse
    5. import tensorflow as tf #导入tensorflow 重命名为tf
    6. import os
    7. #line now+1 创建解析器,line now+2-line now+7 添加参数,line now+8解析参数。详见argparse模块官方文档
    8. parser = argparse.ArgumentParser(description='Process some integers')
    9. parser.add_argument('--mode', default='train', help='train or test')
    10. parser.add_argument("--num_epochs", default=5, type=int)
    11. parser.add_argument("--batch_size", default=32, type=int)
    12. parser.add_argument("--learning_rate", default=0.001)
    13. parser.add_argument("--data_dir", default="/gemini/data-1")
    14. parser.add_argument("--train_dir", default="/gemini/output")
    15. args = parser.parse_args()
    16. ###定义一个decode_and_resize函数,函数的输入参数为filename和label,即文件名和标签
    17. def _decode_and_resize(filename, label):
    18. image_string = tf.io.read_file(filename)#读取图像
    19. image_decoded = tf.image.decode_jpeg(image_string, channels=3)#对图像进行解码,RGB图像是三通道
    20. image_resized = tf.image.resize(image_decoded, [150, 150]) / 255.0 #resize图像大小,除255猜测是进行了归一化,后续看是否可以论证
    21. return image_resized, label #函数返回值 resize图像和label
    22. if __name__ == "__main__":
    23. ###设置数据集###
    24. train_dir = args.data_dir + "/train" #训练数据集路径
    25. cats = []
    26. dogs = []
    27. for file in os.listdir(train_dir): #添加训练图像路径
    28. if file.startswith("dog"):
    29. dogs.append(train_dir + "/" + file)
    30. else:
    31. cats.append(train_dir + "/" + file)
    32. print("dogSize:%d catSize:%d" % (len(cats), len(dogs))) #分别输出dog和cat两类训练数据大小,这里应该是% (len(dogs), len(cats)),但该数据集dogs数量=cats数量,问题不大。
    33. train_cat_filenames = tf.constant(cats[:10000])
    34. train_dog_filenames = tf.constant(dogs[:10000])
    35. train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)
    36. train_labels = tf.concat([
    37. tf.zeros(train_cat_filenames.shape, dtype=tf.int32),
    38. tf.ones(train_dog_filenames.shape, dtype=tf.int32)
    39. ], axis=-1)
    40. train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))#假设我们现在有两组数据,分别是特征和标签,为了简化说明问题,我们假设每两个特征对应一个标签。之后把特征和标签组合成一个tuple
    41. train_dataset = train_dataset.map(map_func=_decode_and_resize,
    42. num_parallel_calls=tf.data.experimental.AUTOTUNE)
    43. #map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset
    44. train_dataset = train_dataset.shuffle(buffer_size=20000)#shuffle的功能为打乱dataset中的元素
    45. train_dataset = train_dataset.batch(args.batch_size)
    46. train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)#数据预读取,提高IO性能
    47. ###搭建模型###
    48. model = tf.keras.Sequential([
    49. tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(150, 150, 3)),#卷积层
    50. tf.keras.layers.MaxPool2D(), #最大池化
    51. tf.keras.layers.Conv2D(64, 3, activation="relu"),
    52. tf.keras.layers.MaxPool2D(),
    53. tf.keras.layers.Conv2D(128, 3, activation="relu"),
    54. tf.keras.layers.MaxPool2D(),
    55. tf.keras.layers.Conv2D(128, 3, activation="relu"),
    56. tf.keras.layers.MaxPool2D(),
    57. tf.keras.layers.Flatten(), #将输入层的数据压成一维的数据,一般用再卷积层和全连接层之间
    58. tf.keras.layers.Dropout(0.5), #dropout层
    59. tf.keras.layers.Dense(512, activation="relu"),#全连接层
    60. tf.keras.layers.Dense(2, activation="softmax")
    61. ])
    62. model.compile(
    63. optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
    64. loss=tf.keras.losses.sparse_categorical_crossentropy,
    65. metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    66. )#在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
    67. model.fit(train_dataset, epochs=args.num_epochs)#将训练数据在模型中训练一定次数,返回loss和测量指标
    68. model.save(args.train_dir)#保存模型
    69. ### 构建测试数据集
    70. test_cat_filenames = tf.constant(cats[10000:])
    71. test_dog_filenames = tf.constant(dogs[10000:])
    72. test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)
    73. test_labels = tf.concat([
    74. tf.zeros(test_cat_filenames.shape, dtype=tf.int32),
    75. tf.ones(test_dog_filenames.shape, dtype=tf.int32)
    76. ], axis=-1)
    77. test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
    78. test_dataset = test_dataset.map(_decode_and_resize)
    79. test_dataset = test_dataset.batch(args.batch_size)
    80. sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    81. for images, label in test_dataset:
    82. y_pred = model.predict(images)
    83. sparse_categorical_accuracy.update_state(y_true=label, y_pred=y_pred)
    84. print("test accuracy:%f" % sparse_categorical_accuracy.result())

    函数详解

  • 相关阅读:
    【Java开发】 Springboot集成Mybatis-Flex
    Elasticsearch:与多个 PDF 聊天 | LangChain Python 应用教程(免费 LLMs 和嵌入)
    VXLAN基础
    Jenkins环境配置篇-邮件发送
    解决Could not find artifact *** in alimaven的问题
    四旋翼飞行器基本模型(Matlab&Simulink)
    挑战来了!如何应对大商家订单多小商家没有订单的数据倾斜问题?
    java 将字符串转为Base64格式与将Base64内容解析出来
    pollFirst(),pollLast(),peekFirst(),peekLast()
    这么分页,小心有坑
  • 原文地址:https://blog.csdn.net/LYLYC_3/article/details/133220883