• TensorFlow入门(十四、数据读取机制(1))


    TensorFlow的数据读取方式

    TensorFlow的数据读取方式共有三种,分别是:

            ①预加载数据(Preloaded data)

                    预加载数据的方式,其实就是静态图(Graph)的模式。即将数据直接内嵌到Graph中,再把Graph传入Session中运行。

    示例代码如下:

    1. import tensorflow.compat.v1 as tf
    2. tf.disable_v2_behavior()
    3. a = tf.constant([[5,2]])
    4. b = tf.constant([[1],[3]])
    5. c = tf.matmul(a,b)
    6. init = tf.global_variables_initializer()
    7. with tf.Session() as sess:
    8. sess.run(init)
    9. print(sess.run(c))

            ②先产生数据,再喂数据(Feeding)

                    先生产数据,通过feed_dict喂数据(Feeding)的方式。

    示例代码如下:

    1. import tensorflow.compat.v1 as tf
    2. tf.disable_v2_behavior()
    3. a = tf.placeholder(tf.int16)
    4. b = tf.placeholder(tf.int16)
    5. c = tf.add(a,b)
    6. a1 = 6
    7. b1 = 8
    8. init = tf.global_variables_initializer()
    9. with tf.Session() as sess:
    10. sess.run(init)
    11. print(sess.run(c,feed_dict = {a:a1,b:b1}))

            ③直接从文件中读取(Reading from file)

                    前两种方法虽然方便,但无法满足大型数据集训练时对速度要高效、内存消耗要小的要求。因此,TensorFlow提供了第三种方式,即在静态图(Graph)中定义好文件读取的方法,TensorFlow自动从文件(也就是文本或图片)中读取数据,然后解码成可用的样本集。

                    从文件读取数据的流程主要分为四个步骤:

                            ①创建文件,准备数据

                            ②创建文件名队列,将已准备的文件,按照随机顺序放入队列

                            ③创建Reader,读取文件

                            ④将读取的内容解码后输出

    示例代码如下:

            生成文件

    1. import csv
    2. file_name = "file.csv"
    3. with open(file_name,"w",newline = "") as csvfile:
    4. writer = csv.writer(csvfile, dialect = "excel")
    5. with open("data1.txt","r") as file_txt:
    6. for line in file_txt:
    7. line_datas = str(line).strip("\n").split(",")
    8. print(line_datas)
    9. writer.writerow(line_datas)

    data1.txt的存放位置如下图,代码执行后会生成file.csv文件

            读取文件

    1. import tensorflow.compat.v1 as tf
    2. tf.disable_v2_behavior()
    3. #要保存后csv格式的文件名
    4. file_name_string = "file.csv"
    5. filename_queue = tf.train.string_input_producer([file_name_string])
    6. #定义reader,每次一行
    7. reader = tf.TextLineReader()
    8. key,value = reader.read(filename_queue)
    9. #定义decoder
    10. var1,var2 = tf.decode_csv(value,record_defaults = [[1.0],[1.0]])
    11. #运行图
    12. init = tf.global_variables_initializer()
    13. with tf.Session() as sess:
    14. sess.run(init)
    15. sess.run(tf.local_variables_initializer())
    16. #创建一个协调器,管理线程
    17. coord = tf.train.Coordinator()
    18. #启动QueueRunner,此时文件名队列已经进队
    19. threads = tf.train.start_queue_runners(coord = coord)
    20. for row in enumerate(open(file_name_string,"r")):
    21. e_val,l_val = sess.run([var1,var2])
    22. print(e_val,l_val)
    23. coord.request_stop()
    24. coord.join(threads)

    读取图片

    示例代码如下:

    1. import tensorflow.compat.v1 as tf
    2. tf.disable_v2_behavior()
    3. import os
    4. import matplotlib.pyplot as plt
    5. file_name = os.listdir("./image")
    6. file_list = [os.path.join("./image",file) for file in file_name]
    7. #创建输入队列,默认顺序打乱
    8. filename_queue = tf.train.string_input_producer(file_list,shuffle = True,num_epochs = 2)
    9. key,image = tf.WholeFileReader().read(filename_queue)
    10. #解码成tf中图像格式
    11. image = tf.image.decode_jpeg(image)
    12. with tf.Session() as sess:
    13. sess.run(tf.local_variables_initializer())
    14. #创建一个协调器,管理线程
    15. coord = tf.train.Coordinator()
    16. threads = tf.train.start_queue_runners(coord = coord)
    17. for _ in file_list:
    18. #执行
    19. img = image.eval()
    20. plt.figure(1)
    21. plt.imshow(img)
    22. plt.show()
    23. coord.request_stop()
    24. coord.join(threads)

  • 相关阅读:
    P3051 [USACO12MAR]Haybale Restacking G
    python-(6-2)爬虫---小试牛刀,获得网页页面内容
    【EI检索】2022年电子、通信与控制工程国际会议SECCE 2022
    Git的标签管理
    51单片机项目(13)——基于51单片机的智能台灯protues仿真
    Cluster API 检索从未如此简单
    玩转Mysql系列 - 第24篇:如何正确的使用索引?
    【蓝桥杯选拔赛真题33】python回文数升级 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析
    [VIM] MiniBufExplorer插件
    洛谷P4316 绿豆蛙的归宿
  • 原文地址:https://blog.csdn.net/Victor_Li_/article/details/133698859