• 第十三天Tensorflow数据读取和复杂模型的构建


    Tensorflow数据读取

    Tensorflow的数据读取大概分成一下几步:

    1. 根据文件列表将文件打乱成文件队列
    2. 依据文件不同生成不同的文件阅读器
    3. 将文件阅读器读到的东西进行解码
    4. 批处理读出文件

    读Csv文件

    import tensorflow as tf
    import os
    
    def read_csv(filelist):
        #构建文件队列
        file_queue = tf.train.string_input_producer(string_tensor=filelist,#文件名列表
                                                    shuffle=True)#打乱文件顺序
    
        #构建文件读取器
        reader = tf.TextLineReader()
    
        #读取数据
        name, val = reader.read(file_queue=file_queue) #name是文件名字,val是读到的值
    
        #解码
        records = [['None'], ['None']]
        example, label = tf.decode_csv(records=val, 
                                       record_defaults=records)#说明数据拆分格式,一般是一个列表包括两个列表对象,内部第一个列表表示X的数据格式,第二个列表表示是Y的数据格式
    
        #批处理
        exam_bat, lab_bat = tf.train.batch([example, label], #返回批处理数据
                                           batch_size=5, #批大小
                                           num_threads=1) #线程数量
    
        return exam_bat, lab_bat
    
    #构建文件列表
    dir_name = './test_data/'
    file_names = os.listdir(dir_name)
    file_list = []
    for f in file_names:
        file_list.append(os.path.join(dir_name, f))
    
    example, label = read_csv(file_list) #这里需要注意,example, label是函数返回的操作对象,所以需要交给会话
    
    with tf.Session() as sess:
        #线程协调器
        coord = tf.train.Coordinator()
        #开启读取文件线程
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)  #线程协调器
    
        print(sess.run([example, label])) #多个操作的运行用列表传入
    
        #等待线程停止,回收资源
        coord.request_stop()
        coord.join(threads)  #回收
    
    [array([b'AAAAAAAAAA1', b'AAAAAAAAAA2', b'AAAAAAAAAA3', b'AAAAAAAAAA4',
           b'AAAAAAAAAA5'], dtype=object), array([b'A1', b'A2', b'A3', b'A4', b'A5'], dtype=object)]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    在这里插入图片描述

    读图片文件(Jpeg为例子)

    import tensorflow as tf
    import os
    import matplotlib.pyplot as plt
    
    def read_img(filelist):
        #构建文件队列
        file_queue = tf.train.string_input_producer(filelist)
    
        #构建文件读取器
        reader = tf.WholeFileReader() #图像读取器
    
        #读取数据
        name, val = reader.read(file_queue) #name是文件名字,val是读到的值
    
        #解码
        img = tf.image.decode_jpeg(val)
        #还有png解码器
        #tf.image.decode_png(val)
    
        #批处理
        img_resiezed = tf.image.resize(img, [200, 200])  #修改图像大小,颜色通道大小默认
        img_resiezed.set_shape([200, 200, 3])  #修改张量大小
        img_bat = tf.train.batch([img_resiezed], batch_size=10, num_threads=1)
    
        return img_bat
    
    #构建文件列表
    dir_name = './test_img/'
    file_names = os.listdir(dir_name)
    file_list = []
    for f in file_names:
        file_list.append(os.path.join(dir_name, f))
    
    imgs = read_img(file_list)
    
    with tf.Session() as sess:
        #线程协调器
        coord = tf.train.Coordinator()
        #开启读取文件线程
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)  #线程协调器
    
        print(sess.run(imgs))
        res=sess.run(imgs)
    
        imgs = imgs.eval()
    
        #等待线程停止,回收资源
        coord.request_stop()
        coord.join(threads)  #回收
    
    fig = plt.figure('Imshow', facecolor='lightgray')
    for i in range(10):
        plt.subplot(2, 5, i + 1)
        plt.imshow(res[i].astype("int32")) #没有imgs = imgs.eval(),imgs是tensor类型,就会报错
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()  #紧凑式布局
    plt.show()
    
    type(res)
    #输出
    numpy.ndarray
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62

    读到的图片实际上一个个三纬列表,10个图片套在一个大列表里面,就成了四维列表。
    在这里插入图片描述
    在这里插入图片描述

    Tensorflow实现手写字体识别

    将要构建一个单层10个神经元的简单网络:
    在这里插入图片描述

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import os
    import pylab
    import numpy as np
    import ssl  
    
    
    #下载数据忽略证书
    ssl._create_default_https_context = ssl._create_unverified_context 
    
    #数据准备
    mnist = input_data.read_data_sets(train_dir='./MNIST_data', one_hot=True)
    
    #整理数据
    x = tf.placeholder(tf.float32, [None, 784])  # 占位符,输入,样本数不确定所以用None
    y = tf.placeholder(tf.float32, [None, 10])  # 占位符,输出,样本数不确定所以用None
    
    # 构建模型
    W = tf.Variable(tf.random_normal([784, 10]))  # 权重
    b = tf.Variable(tf.zeros([10]))  # 偏置值
    
    pred_y = tf.nn.softmax(tf.matmul(x, W) + b)  # softmax分类
    print("pred_y.shape:", pred_y.shape)
    
    # 损失函数
    cross_entropy = -tf.reduce_sum(y * tf.log(pred_y),
                                   reduction_indices=1)  # 求交叉熵,相当于axis=1
    
    cost = tf.reduce_mean(cross_entropy)  # 求损失函数平均值
    
    #梯度下降
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)
    
    #模型保存对象
    saver = tf.train.Saver()
    model_path = "/Users/呆呆网友/Documents/Artificial-Intelligence-Note/8_12/model/"
    
    #超参数
    batch_size = 100
    
    #定义变量初始化操作
    #这里的样本初始化直接在对话中完成
    
    #执行
    with tf.Session() as sess:
        #初始化变量
        sess.run(tf.global_variables_initializer())
    
        #在训练前检查是否有模型保存
        if os.path.exists('/Users/呆呆网友/Documents/Artificial-Intelligence-Note/8_12/model/checkpoint'):
            saver.restore(sess, '/Users/呆呆网友/Documents/Artificial-Intelligence-Note/8_12/model/')
    
        for epoch in range(10):
            total_batch = int(mnist.train.num_examples / batch_size)
            total_cost = 0 #求平均损失值需要的变量
            for i in range(total_batch):
                #拿到一个批次的样本数据
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                params = {x: batch_xs, y: batch_ys}
                o, cost_res = sess.run([optimizer, cost], #进行梯度下降,计算损失
                                       feed_dict=params)
                #求平均损失值
                total_cost += cost_res
    
            #平均损失值
            avg_cost = total_cost / total_batch
            print('epoch:{},cost:{}'.format(epoch + 1, avg_cost))
    
        print('训练完成')
    
        #模型评估
        corr_pred = tf.equal(tf.argmax(y, 1),  #真实类别
                             tf.argmax(pred_y, 1))  #判断水平方向索引是否相等, 会得到一个bool数组
        #精度
        acc = tf.reduce_mean(tf.cast(corr_pred, 'float32'))
        print('精度:', acc.eval({x: mnist.test.images, y: mnist.test.labels}))  #eval()传feed_dict方式
    
        #保存模型
        save_path = saver.save(sess, '/Users/limuyuan/Documents/Artificial-Intelligence-Note/8_12/model/')
        print('保存成功:', save_path)
    
    #从测试集中抽取两张图像进行预测
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  #初始化
        saver.restore(sess, '/Users/limuyuan/Documents/Artificial-Intelligence-Note/8_12/model/')  #加载模型
    
        #从测试集读两张图像
        test_images, test_labels = mnist.test.next_batch(batch_size=2)
        out_put = tf.arg_max(pred_y, 1)  #axis=1
        out_val, predv = sess.run([out_put, pred_y], feed_dict={x: test_images})  #feed_dict为pred_y传参数
        print('预测类别为:', out_val)
        print('真实类别:', np.argmax(test_labels, axis=1))  #这个返回最大值的索引
        print("预测概率为:", np.max(predv, axis=1))  #这个返回最大的值
    
        #显示图像
        img = test_images[0]
        img = img.reshape(-1, 28)
        pylab.imshow(img)
        pylab.show()
    
        #显示图像
        img = test_images[1]
        img = img.reshape(-1, 28)
        pylab.imshow(img)
        pylab.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106

    在这里插入图片描述

  • 相关阅读:
    7.java三大特征之一:多态
    CVPR 2023 | UniMatch: 重新审视半监督语义分割中的强弱一致性
    MyBatis笔记03------XXXMapper.xml文件解析
    iptables知识手册
    『 CSS实战』CSS3 实现一些好玩的效果(2)
    Leetcode刷题详解——在排序数组中查找元素的第一个和最后一个位置
    【基础】Java面试题
    Java基础之变量
    Keras深度学习实战(25)——使用skip-gram和CBOW模型构建单词向量
    利用CloudCompare进行点云过滤去噪(统计滤波)
  • 原文地址:https://blog.csdn.net/weixin_45256637/article/details/126309395