• tensorflow之TFRcords文件读取


    TFRcords文件读取与储存

    (1)TFRecords分析、存取:

    TFRecords是Tensorflow设计的一种内置文件格式,是一种二进制文件,
    它能更好的利用内存,更方便复制和移动

    为了将二进制数据和标签(训练的类别标签)数据存储在同一个文件中

    文件格式:*.tfrecords

    写入文件内容:Example协议块

    (2)TFRecords存储API

    1、建立TFRecord存储器
    tf.python_io.TFRecordWriter(path)
    写入tfrecords文件
    path: TFRecords文件的路径
    return:写文件
    
    method
    write(record):向文件中写入一个字符串记录
    
    close():关闭文件写入器
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    :字符串为一个序列化的Example,Example.SerializeToString()

    (3)TFRecords读取方法API:

    同文件阅读器流程,中间需要解析过程
    
    解析TFRecords的example协议内存块
    tf.parse_single_example(serialized,features=None,name=None)
    解析一个单一的Example原型
    serialized:标量字符串Tensor,一个序列化的Example
    features:dict字典数据,键为读取的名字,值为FixedLenFeature
    return:一个键值对组成的字典,键为读取的名字
    
    
    tf.FixedLenFeature(shape,dtype)
    shape:输入数据的形状,一般不指定,为空列表
    dtype:输入数据类型,与存储进文件的类型要一致
    类型只能是float32,int64,string
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    (4)CIFAR-10批处理结果存入tfrecords流程:

    1、构造存储器

    2、构造每一个样本的Example

    3、写入序列化的Example

    (5)读取tfrecords流程:

    1、构造TFRecords阅读器

    2、解析Example

    3、转换格式,bytes解码

    (6)代码实现:

       def write_ro_tfrecords(self, image_batch, label_batch):
            """
            将图片的特征值和目标值存进tfrecords
            :param image_batch: 10张图片的特征值
            :param label_batch: 10张图片的目标值
            :return: None
            """
            # 1、建立TFRecord存储器
            writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
    
            # 2、循环将所有样本写入文件,每张图片样本都要构造example协议
            for i in range(10):
                # 取出第i个图片数据的特征值和目标值
                image = image_batch[i].eval().tostring()
    
                label = int(label_batch[i].eval()[0])
    
                # 构造一个样本的example
                example =  tf.train.Example(features=tf.train.Features(feature={
                    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                }))
    
                # 写入单独的样本
                writer.write(example.SerializeToString())
    
            # 关闭
            writer.close()
            return None
    
        def read_from_tfrecords(self):
    
            # 1、构造文件队列
            file_queue = tf.train.string_input_producer([FLAGS.cifar_tfrecords])
    
            # 2、构造文件阅读器,读取内容example,value=一个样本的序列化example
            reader = tf.TFRecordReader()
    
            key, value = reader.read(file_queue)
    
            # 3、解析example
            features = tf.parse_single_example(value, features={
                "image": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.int64),
            })
    
            # 4、解码内容, 如果读取的内容格式是string需要解码, 如果是int64,float32不需要解码
            image = tf.decode_raw(features["image"], tf.uint8)
    
            # 固定图片的形状,方便与批处理
            image_reshape = tf.reshape(image, [self.height, self.width, self.channel])
    
            label = tf.cast(features["label"], tf.int32)
    
            print(image_reshape, label)
    
            # 进行批处理
            image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    
            return image_batch, label_batch
    
    
    if __name__ == "__main__":
        # 1、找到文件,放入列表   路径+名字  ->列表当中
        file_name = os.listdir(FLAGS.cifar_dir)
    
        filelist = [os.path.join(FLAGS.cifar_dir, file) for file in file_name if file[-3:] == "bin"]
    
        # print(file_name)
        cf = CifarRead(filelist)
    
        # image_batch, label_batch = cf.read_and_decode()
    
        image_batch, label_batch = cf.read_from_tfrecords()
    
        # 开启会话运行结果
        with tf.Session() as sess:
            # 定义一个线程协调器
            coord = tf.train.Coordinator()
    
            # 开启读文件的线程
            threads = tf.train.start_queue_runners(sess, coord=coord)
    
            # 存进tfrecords文件
            # print("开始存储")
            #
            # cf.write_ro_tfrecords(image_batch, label_batch)
            #
            # print("结束存储")
    
            # 打印读取的内容
            print(sess.run([image_batch, label_batch]))
    
            # 回收子线程
            coord.request_stop()
    
            coord.join(threads)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
  • 相关阅读:
    面经-框架-事务失效的几种场景
    机器学习-sklearn-高斯混合模型-学习笔记
    大数据技术之Hadoop:HDFS存储原理篇(五)
    【Taro3踩坑日记】不存在全局配置文件:C:\Users\TYW\.taro-global-config\index.json
    TCP 加速小记
    动态设置原生swiper的配置项
    《HelloGitHub》第 77 期
    EL表达式内置对象initParam
    知识产权与标准化
    2022-09-16 第二小组 张明旭 Java学习记录
  • 原文地址:https://blog.csdn.net/XST1520203418/article/details/121598588