• 【Python深度学习】Python全栈体系(三十二)


    深度学习

    第十一章 Tensorflow 数据读取

    一、模型保存与加载

    1. 什么是模型保存与加载?
    • 模型加载可能是一个很长的过程,如果每次执行预测之前都重新训练,会非常耗时,所以几乎所有人工智能框架都提供了模型保存与加载功能,使得模型训练完成后,可以保存到文件中,供其它程序使用或继续训练。
    2. 模型保存与加载 API
    • 模型保存与加载通过 tf.train.Saver 对象完成,实例化对象:
      • saver = tf.train.Saver(var_list=None, max_to_keep=5)
        • var_list:要保存和还原的变量,可以是一个dict或一个列表
        • max_to_keep:要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件(如max_to_keep=5表示保留5个检查点文件)
    • 保存:saver.save(sess, ‘/tmp/ckpt/model’)
    • 加载:saver.restore(sess, ‘/tmp/ckpt/model’)
    # 线性回归示例
    import tensorflow as tf
    import os
    
    tf.compat.v1.disable_eager_execution()
    # 第一步:创建样本数据
    # 100行1列
    x = tf.compat.v1.random_normal([100, 1], mean=1.75, stddev=0.5, name="x_data")
    y_true = tf.matmul(x, [[2.0]]) + 5.0  # 计算 y = 2x + 5
    
    # 第二步:建立线性模型
    # 初始化权重(随机数)和偏置(固定设置为0),计算wx+b得到预测值
    weight = tf.Variable(tf.compat.v1.random_normal([1, 1], name="w"),
                         trainable=True)  # 训练过程中值是否允许变化
    bias = tf.Variable(0.0, name="b", trainable=True)  # 偏置
    y_predict = tf.matmul(x, weight) + bias  # 计算预测值
    # 第三步:创建损失函数
    loss = tf.reduce_mean(tf.square(y_true - y_predict))  # 均方差损失函数
    # 第四步:使用梯度下降进行训练
    # 0.1 学习率
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)
    # 收集损失函数的值
    tf.compat.v1.summary.scalar("losses", loss)
    merged = tf.compat.v1.summary.merge_all()  # 合并摘要操作
    
    init_op = tf.compat.v1.global_variables_initializer()
    
    saver = tf.compat.v1.train.Saver()  # 实例化一个saver
    with tf.compat.v1.Session() as sess:
        sess.run(init_op)  # 执行初始化op
    
        # 打印初始权重和偏置
        print("weight: ", weight.eval(), "bias: ", bias.eval())
    
        # 指定事件文件并记录图的信息
        fw = tf.compat.v1.summary.FileWriter("../summary/", graph=sess.graph)
    
        # 训练之前,检查是否已经有模型保存,如果有,则加载
        if os.path.exists("../model/linear_model/checkpoint"):
            saver.restore(sess, "../model/linear_model/")
        # 循环训练
        for i in range(200):
            sess.run(train_op)  # 执行训练
            summary = sess.run(merged)  # 执行摘要合并操作
            fw.add_summary(summary, i)  # 写入事件文件
            print(i, ":", " weight: ", weight.eval(), "bias: ", bias.eval())
    
        # 训练完成,保存模型
        saver.save(sess, "../model/linear_model/")
    
    # 第一次运行
    """
    weight:  [[0.35549]] bias:  0.0
    0 :  weight:  [[3.1358984]] bias:  1.5653534
    1 :  weight:  [[3.5814297]] bias:  1.8450298
    ...
    197 :  weight:  [[2.168407]] bias:  4.6981807
    198 :  weight:  [[2.1649537]] bias:  4.7005367
    199 :  weight:  [[2.1565037]] bias:  4.699788
    """
    # 第二次运行
    """
    weight:  [[0.05191222]] bias:  0.0
    0 :  weight:  [[2.1568353]] bias:  4.704247
    1 :  weight:  [[2.1516428]] bias:  4.7066517
    ...
    197 :  weight:  [[2.0153403]] bias:  4.9711003
    198 :  weight:  [[2.0157084]] bias:  4.971714
    199 :  weight:  [[2.015551]] bias:  4.9720135
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70

    二、数据读取

    1. 文件读取机制
    • TensorFlow 文件读取分为三个步骤:
      • 第一步:将要读取的文件放入文件名队列
      • 第二步:读取文件内容,并实行解码
      • 第三步:批处理,按照指定笔数构建成一个批次取出
        在这里插入图片描述
    2. 文件读取 API
    • 文件队列构造:生成一个先入先出的队列,文件阅读器会需要它来读取数据
      • tf.train.string_input_producer(string_tensor, shuffle=True)
        • string_tensor: 含有文件名的一阶张量
        • shuffle: 是否打乱文件顺序
      • 返回:文件队列
    • 文件读取:
      • 文本文件读取:tf.TextLineReader
        • 读取CSV文件,默认按行读取
      • 二进制文件读取:tfFixedLengthRecordReader(record_bytes)
        • 读取每个记录是固定字节的二进制文件
        • record_bytes:每次读取的字节数
      • 通用读取方法:read(file_queue)
        • 从队列中读取指定数量(行,字节)的内容
        • 返回值:一个tensor元组(文件名,value)
    • 文件内容解码:
      • 解码文本文件:tf.decode_csv(records, record_defaults)
        • 将CSV文件内容转换为张量,与tf.TextLineReader搭配使用
        • 参数:
          • records:字符串,对应文件中的一行
          • record_defaults:类型
        • 返回:tensor对象列表
      • 解码二进制文件:tf.decode_raw(input_bytes, out_type)
        • 将字节转换为由数字表示的张量,与tf.FixedLengthRecordReader搭配使用
        • 参数:
          • input_bytes:待转换字节
          • out_type:输出类型
        • 返回:转换结果
    # CSV文件读取示例
    import tensorflow as tf
    import os
    
    tf.compat.v1.disable_eager_execution()
    
    
    def csv_read(filelist):  # 从csv样本文件中读取数据
        # 构建文件队列
        file_queue = tf.compat.v1.train.string_input_producer(filelist)
        # 定义reader
        reader = tf.compat.v1.TextLineReader()
        k, v = reader.read(file_queue)  # 读取,返回文件名称和数据
        # 解码
        records = [["None"], ["None"]]
        example, label = tf.compat.v1.decode_csv(v, record_defaults=records)
        # 批处理
        example_bat, label_bat = tf.compat.v1.train.batch([example, label],  # 参与批处理的数据
                                                          batch_size=9,  # 批次大小
                                                          num_threads=1)  # 线程数量
        return example_bat, label_bat
    
    
    if __name__ == "__main__":
        # 构建文件列表
        dir_name = "../test_data/"
        file_names = os.listdir(dir_name)  # 列出目录下所有的文件
        file_list = []
        for f in file_names:
            # 将目录名称、文件名称拼接成完整路径,并添加到文件列表
            file_list.append(os.path.join(dir_name, f))
    
        example, label = csv_read(file_list)  # 调用自定义函数,读取指定文件列表中的数据
    
        # 开启Session,执行
        with tf.compat.v1.Session() as sess:
            coord = tf.train.Coordinator()  # 定义线程协调器
            threads = tf.compat.v1.train.start_queue_runners(sess, coord=coord)
            print(sess.run([example, label]))  # 执行操作
            # 等待线程停止,并回收资源
            coord.request_stop()
            coord.join(threads)
    
    """
    [array([b'CCC1', b'CCC2', b'CCC3', b'AAA1', b'AAA2', b'AAA3', b'BBB1',
           b'BBB2', b'BBB3'], dtype=object), array([b'C1', b'C2', b'C3', b'A1', b'A2', b'A3', b'B1', b'B2', b'B3'],
          dtype=object)]
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    3. 图片文件读取 API
    • 图像读取器:tf.WholeFileReader
      • 功能:将文件的全部内容作为值输出的reader
      • read方法:读取文件内容,返回文件名和文件内容
    • 图像解码器:
      • tf.image.decode_jpeg(constants):解码jpeg格式
      • tf.image.decode_png(constants):解码png格式
      • 返回值:3-D张量,[height, width, channels]
    • 修改图像大小:tf.image.resize_images(images, size)
      • images:图片数据,3-D或4-D张量
        • 3-D:[长,宽,通道]
        • 4-D:[数量,长,宽,通道]
      • size:1-D int32张量,[长,宽](不需要传通道数)
    # 图像样本读取示例
    import tensorflow as tf
    import os
    
    
    # 图像样本读取函数
    def img_read(filelist):
        # 构建文件队列
        file_queue = tf.train.string_input_producer(filelist)
        # 定义reader
        reader = tf.WholeFileReader()
        k, v = reader.read(file_queue)  # 读取整个文件内容
        # 解码
        img = tf.image.decode_jpeg(v)
        # 批处理
        img_resized = tf.image.resize(img, [200, 200])  # 将图像设置成200*200大小
        img_resized.set_shape([200, 200, 3])  # 固定样本形状,批处理时对数据形状有要求
        img_bat = tf.train.batch([img_resized],
                                 batch_size=10,
                                 num_threads=1)
        return img_bat
    
    
    if __name__ == "__main__":
        # 构建文件列表
        dir_name = "../test_img/"
        file_names = os.listdir(dir_name)
        file_list = []
        for f in file_names:
            # 将目录名、文件名拼接成完整路径放入文件列表中
            file_list.append(os.path.join(dir_name, f))
        imgs = img_read(file_list)
    
        with tf.Session() as sess:
            coord = tf.train.Coordinator()  # 线程协调器
            threads = tf.train.start_queue_runners(sess, coord=coord)
            result = imgs.eval()  # 调用函数,分批次读取样本
            # 等待线程结束,并回收资源
            coord.request_stop()
            coord.join(threads)
    
    # 显示图片
    import matplotlib.pyplot as plt
    
    plt.figure("Img Show", facecolor="lightgray")
    
    for i in range(10):  # 循环显示读取到的样本(批次读取,所以有多个样本)
        plt.subplot(2, 5, i + 1)  # 显示子图,2行5列的第i+1个子图
        plt.xticks([])
        plt.yticks([])
        plt.imshow(result[i].astype("int32"))
    
    plt.tight_layout()
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54

    第十二章 Tensorflow 手写体识别

    一、MINIST 数据集

    • 手写数字的数据集,来自美国国家标准与技术研究所(National Institute of Standards and Technology, NIST),发布于1998年
    • 样本来自250个不同人的手写数字,50%高中学生,50%是人口普查局的工作人员
    • 数字从0~9,图片大小是28x28像素,训练数据集包含60000个样本,测试数据集包含10000个样本。数据集的标签是长度为10的一维数组,数组中每个元素索引号表示对应数字出现的概率。
    • 下载地址:http://yann.lecun.com/exdb/mnist/

    二、任务目标

    • 根据训练集样本进行模型训练
    • 保存模型
    • 加载模型,用于新的手写体数字识别

    三、网络结构

    在这里插入图片描述

    四、相关 API

    • tf.matmul():执行矩阵乘法计算
    • tf.nn.softmax():softmax激活函数
    • tf.reduce_sum():指定维度上求张量和
    • tf.train.GradientDescentOptimizer():优化器,执行梯度下降
    • tf.argmax():返回张量最大元素的索引值

    五、代码

    # 手写体识别案例
    # 模型:全连接模型
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import pylab
    
    # 定义样本读取对象
    mnist = input_data.read_data_sets("MNIST_data/",  # 数据集所在的目录
                                      one_hot=True)  # 标签是否采用独热编码
    
    # 定义占位符,用于表图像数据、标签
    x = tf.placeholder(tf.float32, [None, 784])  # 图像数据,N行784列
    y = tf.placeholder(tf.float32, [None, 10])  # 标签(图像真实类别),N行784列
    
    # 定义权重、偏置
    w = tf.Variable(tf.random_normal([784, 10]))  # 权重,784行10列
    b = tf.Variable(tf.zeros([10]))  # 偏置,10个偏置
    
    # 构建模型,计算预测结果
    pred_y = tf.nn.softmax(tf.matmul(x, w) + b)
    # 损失函数
    cross_entropy = -tf.reduce_sum(y * tf.log(pred_y), reduction_indices=1)
    cost = tf.reduce_mean(cross_entropy)  # 求均值
    # 梯度下降优化器
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
    
    batch_size = 100  # 批次大小
    saver = tf.train.Saver()  # saver
    model_path = "../model/mnist/mnist_model.ckpt"  # 模型路径
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # 初始化
        # 开始训练
        for epoch in range(10):
            # 计算总批次
            total_batch = int(mnist.train.num_examples / batch_size)
            avg_cost = 0.0
    
            for i in range(total_batch):
                # 从训练集读取一个批次的样本
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                params = {x: batch_xs, y: batch_ys}  # 参数字典
                o, c = sess.run([optimizer, cost],  # 执行的op
                                feed_dict=params)  # 喂入参数
                avg_cost += (c / total_batch)  # 计算平均损失值
    
            print("epoch:%d, cost=%.9f" % (epoch + 1, avg_cost))
        print("训练结束。")
    
        # 模型评估
        # 比较预测结果和真实结果,返回布尔类型的数组
        correct_pred = tf.equal(tf.argmax(pred_y, 1),  # 求预测结果中最大值的索引
                                tf.argmax(y, 1))  # 求真实结果中最大值的索引
        # 将布尔类型数组转换为浮点数,并计算准确率
        # 因为计算均值、准确率公式相同,所以调用计算均值的函数计算准确率
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        print("accuracy:", accuracy.eval({x: mnist.test.images,  # 测试集下的图像数据
                                          y: mnist.test.labels}))  # 测试集下的图像的真实类别
        # 保存模型
        save_path = saver.save(sess, model_path)
        print("模型已保存:", save_path)
    
    # 从测试集中随机读取2张图像,执行预测
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, model_path)  # 加载模型
        # 从测试集中读取样本
        batch_xs, batch_ys = mnist.test.next_batch(2)
        output = tf.argmax(pred_y, 1)  # 直接取出预测结果中的最大值
        output_val, predv = sess.run([output, pred_y],  # 执行的OP
                                     feed_dict={x: batch_xs})  # 预测,所以不需要传入标签
        print("预测最终结果:\n", output_val, "\n")
        print("真实结果:\n", batch_ys, "\n")
        print("预测概率:\n", predv, "\n")
    
        # 显示图片
        im = batch_xs[0]  # 第一个测试样本
        im = im.reshape(-1, 28)  # 28列,-1表示经过计算的值
        pylab.imshow(im)
        pylab.show()
    
        # 显示图片
        im = batch_xs[1]  # 第一个测试样本
        im = im.reshape(-1, 28)  # 28列,-1表示经过计算的值
        pylab.imshow(im)
        pylab.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86

    第十三章 Tensorflow 服饰识别

    一、数据集介绍

    • 是来自 Zalando 文章的数据集,是时尚版的 MNIST。包括 60000 个训练集数据,10000 个测试集数据,每个数据为 28 x 28 灰度图像,一共有 10 类:
      在这里插入图片描述

    二、任务目标

    • 搭建模型
    • 根据训练集样本进行模型训练
    • 使用模型对新的服饰图片识别

    三、网络结构

    在这里插入图片描述

    四、代码

    # 使用卷积神经网络实现服饰识别
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
    
    
    # 定义FashionMnist类
    class FashionMnist():
        out_features1 = 12  # 第一组卷积层输出通道数量(即第一个卷积层卷积核数)
        out_features2 = 24  # 第二组卷积层输出通道数量(即第二个卷积层卷积核数)
        con_neurons = 512  # 全连接层神经元数量
    
        def __init__(self, path):
            """
            构造方法
            :param path: 指定数据集目录
            """
            self.sess = tf.Session()
            self.data = read_data_sets(path, one_hot=True)
    
        def init_weight_variable(self, shape):
            """
            根据指定形状初始化权重
            :param shape: 指定要初始化的变量的形状
            :return: 返回经过初始化的变量
            """
            initial = tf.truncated_normal(shape, stddev=0.1)  # 截尾正态分布
            return tf.Variable(initial)
    
        def init_bias_variable(self, shape):
            """
            初始化偏置
            :param shape: 指定要初始化的变量的形状
            :return: 返回经过初始化的变量
            """
            initial = tf.constant(1.0, shape=shape)
            return tf.Variable(initial)
    
        def conv2d(self, x, w):
            """
            二维卷积方法
            :param x: 原始数据
            :param w: 卷积核
            :return: 返回卷积运算的结果
            """
            # 卷积核:[高度,宽度,输入通道数,输出通道数]
            return tf.nn.conv2d(x,  # 原始数据
                                w,  # 卷积核
                                strides=[1, 1, 1, 1],  # 各维度上的步长值
                                padding="SAME")  # 输入矩阵和输出矩阵大小一样
    
        def max_pool_2x2(self, x):
            """
            定义池化方法
            :param x: 原始数据
            :return: 池化计算结果
            """
            return tf.nn.max_pool(x,
                                  ksize=[1, 2, 2, 1],  # 池化区域大小
                                  strides=[1, 2, 2, 1],  # 各个维度上的步长值
                                  padding="SAME")
    
        def create_conv_pool_layer(self, input, input_features, output_features):
            """
            定义卷积、激活、池化层
            :param input: 原始数据
            :param input_features: 输入特征数量
            :param output_features: 输出特征数量
            :return: 卷积、激活、池化层运算结果
            """
            filter = self.init_weight_variable([5, 5, input_features, output_features])
            b_conv = self.init_bias_variable([output_features])  # 偏置,卷积有多少输出就有多少个偏置
            h_conv = tf.nn.relu(self.conv2d(input, filter) + b_conv)  # 卷积,激活运算
            h_pool = self.max_pool_2x2(h_conv)  # 对卷积激活运算的结果做池化
            return h_pool
    
        def create_fc_layer(self, h_pool_flat, input_features, con_neurons):
            """
            创建全连接层
            :param h_pool_flat: 输入数据,经过拉伸后的一维张量
            :param input_features: 输入特征数量
            :param con_neurons: 神经元数量(输出特征数量)
            :return: 经过全连接计算后的结果
            """
            w_fc = self.init_weight_variable([input_features, con_neurons])  # 权重
            b_fc = self.init_bias_variable([con_neurons])  # 偏置
            h_fc1 = tf.nn.relu(tf.matmul(h_pool_flat, w_fc) + b_fc)  # 计算wx+b并做激活
            return h_fc1
    
        def build(self):
            """
            组建CNN
            :return:
            """
            # 定义输入数据、标签数据的占位符
            self.x = tf.placeholder(tf.float32, shape=[None, 784])
            x_image = tf.reshape(self.x, [-1, 28, 28, 1])  # 变维成28*28单通道图像数据
            self.y = tf.placeholder(tf.float32, shape=[None, 10])  # 标签,N个样本,每个样本10个类别对应的概率
    
            # 第一组卷积池化
            h_pool1 = self.create_conv_pool_layer(x_image, 1, self.out_features1)
    
            # 第二组卷积池化
            h_pool2 = self.create_conv_pool_layer(h_pool1,  # 以上一个卷积池化层的输出作为输入
                                                  self.out_features1,  # 输入特征数量,为上一层输出特征数量
                                                  self.out_features2)  # 输出特征数量
    
            # 全连接
            h_pool2_flat_features = 7 * 7 * self.out_features2  # 计算特征点的数量
            h_pool2_flat = tf.reshape(h_pool2, [-1, h_pool2_flat_features])  # 拉伸成一维
            h_fc = self.create_fc_layer(h_pool2_flat,  # 输入
                                        h_pool2_flat_features,  # 输入特征数量
                                        self.out_features1)  # 输出特征数量
    
            # dropout(通过随机丢弃一定比例神经元参数更新,防止过拟合)
            self.keep_prob = tf.placeholder("float")  # 保存率
            h_fcl_drop = tf.nn.dropout(h_fc, self.keep_prob)
    
            # 输出层
            w_fc = self.init_weight_variable([self.con_neurons, 10])  # 512行10列
            b_fc = self.init_bias_variable([10])  # 10个偏置
            y_conv = tf.matmul(h_fcl_drop, w_fc) + b_fc  # 计算wx+b
    
            # 计算准确率
            correct_prediction = tf.equal(tf.argmax(y_conv, 1),
                                          tf.argmax(self.y, 1))
            self.accuracy = tf.reduce.mean(tf.cast(correct_prediction, tf.float32))
            # 损失函数
            loss_func = tf.nn.softmax_cross_entropy_with_logits(labels=self.y,  # 真实值
                                                                logits=y_conv)  # 预测值
    
            cross_entropy = tf.reduce_mean(loss_func)
    
            # 优化器
            optimizer = tf.train.AdamOptimizer(0.001)
            self.train_step = optimizer.minimize(cross_entropy)
    
        def train(self):
            self.sess.run(tf.global_variables_initializer())
            batch_size = 100  # 批次大小
            print("begin training...")
    
            for i in range(10):
                total_batch = int(self.data.train.num_examples / batch_size)  # 计算批次数量
    
                for j in range(total_batch):
                    batch = self.data.train.next_batch(batch_size)  # 获取一个批次样本
                    params = {self.x: batch[0],  # 图像
                              self.y: batch[1],  # 标签
                              self.keep_prob: 0.5}  # 计算丢弃率
                    t, acc = self.sess.run([self.train_step, self.accuracy],  # 执行的op
                                           params)
                    if j % 100 == 0:
                        print("i: %d, j: %d  acc: %f" % (i, j, acc))
    
        def eval(self, x, y, keep_prob):
            params = {self.x: x, self.y: y, self.keep_prob: 0.5}
            test_acc = self.sess.run(self.accuracy, params)  # 计算准确率
            print("Test Accuracy: %f" % test_acc)
            return test_acc
    
        def close(self):
            self.sess.close()
    
    
    if __name__ == "__main__":
        mnist = FashionMnist("fashion_mnist/")
        mnist.build()  # 组件网络
        mnist.train()  # 训练
    
        # 评估
        xs, ys = mnist.data.test.next_batch(100)
        mnist.eval(xs, ys, 0.5)
        mnist.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
  • 相关阅读:
    【王道】计算机网络网络层(三)
    Prometheus详解(六)——Prometheus使用Exporter监控Redis
    Hexo+Github 快速搭建个人博客
    推荐算法效果不佳时的检查清单
    C# Onnx GFPGAN GPEN-BFR 人像修复
    Cobalt Strike(十三)内网隧道通信
    C++——智能指针
    【软考 系统架构设计师】系统可靠性分析与设计① 系统可靠性分析
    性能测试 —— Jmeter 常用三种定时器
    uniapp微信小程序解决上方刘海屏遮挡
  • 原文地址:https://blog.csdn.net/sgsgkxkx/article/details/126411420