• Shuffling, Splitting and Reading TFRecords


    Shuffling a dataset requires setting the buffer_size parameter, which effectively loads that many samples into memory, and that memory stays occupied for the whole training run. A full shuffle would require reading the entire dataset into memory, which is impractical for large datasets. Instead, randomly split the TFRecord file into several subfiles, sized according to device memory and batch size, and then apply a local shuffle to the dataset (i.e., a relatively small buffer_size that is still no smaller than the number of samples in any single subfile).
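    As a minimal sketch of what a local shuffle looks like (the toy dataset and buffer size here are illustrative, not from the original post): with buffer_size=3, each emitted element is drawn uniformly from a sliding 3-element buffer, so the output is only locally randomized.

    import tensorflow as tf
    
    # With buffer_size=3, elements can only move a short distance from their
    # original position -- a local shuffle, not a full permutation.
    ds = tf.data.Dataset.range(10).shuffle(buffer_size=3)
    print(list(ds.as_numpy_iterator()))  # e.g. [1, 0, 3, 2, 5, 4, 7, 6, 9, 8]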

    Shuffle and Split

    The example below uses an anomaly-detection dataset (imbalanced positive and negative samples). When generating the first batch of TFRecords, I wrote the positive and negative samples into separate TFRecord files, so that if positive and negative samples later need different handling, there is no need to parse each example_proto again. For instance, in the code below the positive and negative samples get different validation-set ratios and are written to separate validation files.

    import numpy as np
    import tensorflow as tf
    from tqdm.notebook import tqdm
    
    # Assumed values -- set these to your dataset's actual sizes.
    SUBFILE_NUM = 9               # produces SUBFILE_NUM + 1 training subfiles
    LEN_NORMAL_DATASET = 100000   # number of normal (positive) samples
    LEN_ANOMALY_DATASET = 1000    # number of anomaly (negative) samples
    
    # Split the TFRecords
    raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords", "GZIP")
    raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords", "GZIP")
    normal_val_writer = tf.io.TFRecordWriter(r'ex_1/' + 'normal_val_16_256.tfrecords', "GZIP")
    anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/' + 'anomaly_val_16_256.tfrecords', "GZIP")
    train_writer_list = [tf.io.TFRecordWriter(r'ex_1/' + 'train_16_256_{}.tfrecords'.format(i), "GZIP")
                         for i in range(SUBFILE_NUM + 1)]
    with tqdm(total=LEN_NORMAL_DATASET + LEN_ANOMALY_DATASET) as pbar:
        for example_proto in raw_normal_dataset:
            # Split into training and validation sets:
            # ~1% of the normal samples go to the validation set
            if np.random.random() > 0.99:
                normal_val_writer.write(example_proto.numpy())
            else:
                # assign the rest to a random training subfile
                train_writer_list[np.random.randint(0, SUBFILE_NUM + 1)].write(example_proto.numpy())
            pbar.update(1)
    
        for example_proto in raw_anomaly_dataset:
            # ~30% of the anomaly samples go to the validation set
            if np.random.random() > 0.7:
                anomaly_val_writer.write(example_proto.numpy())
            else:
                train_writer_list[np.random.randint(0, SUBFILE_NUM + 1)].write(example_proto.numpy())
            pbar.update(1)
    normal_val_writer.close()
    anomaly_val_writer.close()
    for train_writer in train_writer_list:
        train_writer.close()
    
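    Since the local shuffle's buffer_size must be no smaller than the largest subfile, it is worth checking the per-subfile record counts after splitting. A quick sketch, using the file naming from the snippet above:

    # Count records per training subfile to pick a safe shuffle buffer_size.
    counts = [sum(1 for _ in tf.data.TFRecordDataset(
                  r'ex_1/' + 'train_16_256_{}.tfrecords'.format(i), "GZIP"))
              for i in range(SUBFILE_NUM + 1)]
    print(max(counts))  # buffer_size should be >= this value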

    Reading

    BATCH_SIZE = 64  # assumed value -- set to your actual batch size
    
    # Reading a list of files concatenates them in order, so the shuffle buffer
    # must cover at least one whole subfile for a proper local shuffle.
    raw_train_dataset = tf.data.TFRecordDataset(
        [r'ex_1/' + 'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM + 1)], "GZIP")
    raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
    parsed_train_dataset = raw_train_dataset.map(map_func=map_func)
    
    raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/' + 'normal_val_16_256.tfrecords', "GZIP")
    raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/' + 'anomaly_val_16_256.tfrecords', "GZIP")
    parsed_normal_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
    parsed_anomaly_val_dataset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
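    The original post does not show map_func, so here is a hypothetical sketch of one. The feature names, dtypes and shapes are assumptions (the "16_256" in the file names suggests a (16, 256) feature); adjust them to however your examples were actually serialized. Since .batch() runs before .map() in the pipelines above, tf.io.parse_example parses a whole batch of serialized examples at once.

    # Hypothetical feature spec -- names, dtypes and shapes are assumptions.
    feature_spec = {
        "data": tf.io.FixedLenFeature([16, 256], tf.float32),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    
    def map_func(example_protos):
        # example_protos is a batch of serialized tf.train.Example protos,
        # because batching happens before this map in the pipelines above.
        parsed = tf.io.parse_example(example_protos, feature_spec)
        return parsed["data"], parsed["label"]

    In practice it is also common to end the training pipeline with .prefetch(tf.data.AUTOTUNE) so that parsing overlaps with training.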
    
  • Original post: https://www.cnblogs.com/yc0806/p/16526114.html