• TensorFlow2代码解读(4)


    1. import tensorflow as tf
    2. from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
    3. (x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
    4. x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    5. y = tf.convert_to_tensor(y, dtype=tf.int32)
    6. x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255.
    7. y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)
    8. # 将28x28的图像转换成长度为784的向量
    9. x = tf.reshape(x, [-1, 28*28])
    10. x_test = tf.reshape(x_test, [-1, 28*28])
    11. batch_size = 128
    12. db = tf.data.Dataset.from_tensor_slices((x, y))
    13. db = db.shuffle(10000).batch(batch_size)
    14. db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    15. db_test = db_test.batch(batch_size)
    16. model = Sequential([
    17. layers.Dense(256, activation=tf.nn.relu),
    18. layers.Dense(128, activation=tf.nn.relu),
    19. layers.Dense(64, activation=tf.nn.relu),
    20. layers.Dense(32, activation=tf.nn.relu),
    21. layers.Dense(16, activation=tf.nn.relu),
    22. layers.Dense(10)
    23. ])
    24. model.build(input_shape=[None, 28*28])
    25. optimizer = optimizers.Adam(lr=1e-3)
    26. model.compile(optimizer=optimizer,
    27. loss=tf.losses.MSE,
    28. metrics=[metrics.CategoricalAccuracy()])
    29. def main():
    30. model.fit(db, epochs=30, validation_data=db_test)
    31. if __name__ == '__main__':
    32. main()
    1. (x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
    2. 这行代码从TensorFlow的Fashion MNIST数据集中加载训练集(x, y)和测试集(x_test, y_test)。
    1. x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    2. y = tf.convert_to_tensor(y, dtype=tf.int32)
    3. x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255.
    4. y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)
    5. 这里将加载的数据转换为张量格式,并对输入图像进行预处理。x和x_test被转换为float32类型,并且通过除以255进行归一化,使像素值在01之间。而y和y_test则被转换为int32类型。
    1. x = tf.reshape(x, [-1, 28*28])
    2. x_test = tf.reshape(x_test, [-1, 28*28])
    3. 这两行代码将图像数据进行reshape操作,将每张图像从28x28的二维数组转换为长度为784的一维向量。-1表示根据原始数据的维度自动计算。
    1. batch_size = 128
    2. 这行代码定义了批量训练的批次大小,即每次送入模型训练的样本数量。
    1. db = tf.data.Dataset.from_tensor_slices((x, y))
    2. db = db.shuffle(10000).batch(batch_size)
    3. 这里使用tf.data.Dataset.from_tensor_slices()将训练数据(x, y)转换为数据集对象db。然后使用.shuffle(10000)对数据集进行混洗操作,打乱样本的顺序。最后使用.batch(batch_size)方法将数据集划分为批次,每个批次的样本数量为batch_size。
    1. db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    2. db_test = db_test.batch(batch_size)
    3. 这里使用tf.data.Dataset.from_tensor_slices()将测试数据(x_test, y_test)转换为数据集对象db_test。同样地,使用.batch(batch_size)方法将数据集划分为批次,每个批次的样本数量为batch_size。
    1. model = Sequential([
    2. layers.Dense(256, activation=tf.nn.relu),
    3. layers.Dense(128, activation=tf.nn.relu),
    4. layers.Dense(64, activation=tf.nn.relu),
    5. layers.Dense(32, activation=tf.nn.relu),
    6. layers.Dense(16, activation=tf.nn.relu),
    7. layers.Dense(10)
    8. ])
    9. model.build(input_shape=[None, 28*28])
    10. optimizer = optimizers.Adam(lr=1e-3)
    11. 这里定义了一个Sequential模型,并通过layers.Dense添加了一系列全连接层。每个全连接层都有不同的输出大小和激活函数。最后一个全连接层没有指定激活函数,因为这是一个多类别分类问题,概率可以通过softmax函数计算。使用model.build()指定了模型的输入形状为(None, 28*28),其中None表示可以接受任意批次大小。
    1. model.compile(optimizer=optimizer,
    2. loss=tf.losses.MSE,
    3. metrics=[metrics.CategoricalAccuracy()])
    4. 这行代码编译模型。使用指定的优化器optimizer、损失函数tf.losses.MSE和评估指标metrics.CategoricalAccuracy()来配置模型。
    1. def main():
    2. model.fit(db, epochs=30, validation_data=db_test)
    3. 这里定义了一个main()函数,并在其中使用.fit()方法对模型进行训练。训练数据集为db,训练30个epoch,并且使用验证数据集db_test进行模型的验证。
    1. if __name__ == '__main__':
    2. main()
    3. 这里使用if __name__ == '__main__':判断当前脚本是否直接运行,如果是,则调用main()函数开始执行代码。

  • 相关阅读:
    抓包工具mitmprox
    You must install .NET Desktop Runtime to run this application
    C# 实现 Linux 视频会议(源码,支持信创环境,银河麒麟,统信UOS)
    竞赛选题 深度学习中文汉字识别
    Android 远程调用服务之 AIDL
    【华为机试真题 JAVA】猴子爬山-100
    源码层面理解 LiveData 各种特性的实现原理
    今天安装mongodb,有许多心得记录一下
    高等数学(第七版)同济大学 习题10-4 (后7题)个人解答
    h2database BTree 设计实现与查询优化思考
  • 原文地址:https://blog.csdn.net/Victor_Li_/article/details/133855792