• 【Pytorch笔记】5.DataLoader、Dataset、自定义Dataset


    参考
    深度之眼官方账号 - 02-01 Dataloader与Dataset.mp4

    torch.utils.data.DataLoader

    功能:构建可迭代的数据装载器。

    data.DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,
                    sampler=None,
                    batch_sampler=None,
                    num_workers=0,
                    collate_fn=None,
                    pin_memory=False,
                    drop_last=False,
                    timeout=0,
                    worker_init_fn=None,
                    multiprocessing_context=None)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    dataset:Dataset类,决定数据从哪里读取;
    batch_size:批大小
    shuffle:每个epoch是否乱序;
    sampler:定义从数据集中抽取数据的策略,指明sampler时不能指明shuffle。
    batch_sampler:和sampler相似,但是一次返回一个批大小的下标。和batch_sizeshufflesamplerdrop_last互斥;
    num_workers:使用该DataLoader加载数据的子进程的数量。如果该值为0,意味着只有主进程使用该DataLoader加载数据。
    collate_fn:合并样本列表,形成一个小型tensor批次。在使用map-style dataset批量加载时使用。
    pin_memory:True时,DataLoader在返回tensor数据时会先将这些数据复制到device/CUDA的pinned memory中。
    drop_last:当样本数不能被batch_size整除时,是否舍弃最后一批数据。
    timeout:如果是正数,则表示获取批次的超时值。要求设置为非负数。
    worker_init_fn:如果不是None,那么这个函数会调用所有带有worker_id的worker的子进程。

    Epoch:所有训练样本都已输入到模型中,称为一个Epoch。
    Iteration:一批样本输入到模型中,称为一个Iteration。
    BatchSize:批大小,决定一个Epoch有多少个Iteration。

    torch.utils.data.Dataset

    功能:Dataset抽象类,所有自定义的Dataset需要继承他,并且复写__getitem__()

    __getitem__():接收一个索引,返回一个样本。

    torch.utils.data.TensorDataset

    由tensor封装的数据集,传入一些在第1维长度相同的tensor,形成dataset。

    DataLoader如何使用?

    import torch
    from torch.utils import data
    
    
    def data_gen(num_examples):
        X = torch.normal(0, 1, (num_examples, 2))
        y = torch.normal(0, 1, (num_examples, 1))
        return X, y
    
    
    def load_array(data_arrays, batch_size, need_shuffle=True):
        dataset = data.TensorDataset(*data_arrays)
        return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)
    
    
    data_A, data_B = data_gen(30)
    # 定义一个数据迭代器data_iter
    data_iter = load_array((data_A, data_B), 10, False)
    
    # 使用for循环,iter每次会指向下一组数据
    # 每组数据返回我们封装进去的那些tensor,封装几个就返回几个
    for X, y in iter(data_iter):
        print("X\n", X)
        print("y\n", y)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    输出:

    X
     tensor([[-0.2338, -0.4072],
            [ 1.6080, -0.0880],
            [-1.0243, -0.6245],
            [ 0.6012,  1.9620],
            [ 1.1876,  0.9539],
            [-0.5972,  0.9251],
            [-0.4918, -0.1340],
            [ 2.3297,  0.1833],
            [-0.8487,  0.3370],
            [-0.0600,  2.3769]])
    y
     tensor([[ 1.0454],
            [ 0.5992],
            [ 2.0075],
            [ 0.2727],
            [-1.6845],
            [ 0.0845],
            [ 1.0992],
            [ 0.5103],
            [-0.6727],
            [-0.1900]])
    X
     tensor([[-0.1419, -1.5535],
            [-1.6436,  0.8680],
            [-0.4432, -0.7703],
            [ 0.3822,  0.4675],
            [ 0.4000,  1.3471],
            [ 0.9776,  2.0103],
            [ 0.1298,  2.7382],
            [ 0.2664, -0.6223],
            [-1.0774,  0.0734],
            [-0.1904, -1.3299]])
    y
     tensor([[-0.5979],
            [-0.5432],
            [ 0.2951],
            [ 0.2811],
            [-0.5997],
            [ 0.8073],
            [ 1.4356],
            [ 1.1555],
            [-0.3368],
            [-0.0626]])
    X
     tensor([[-1.4326, -0.3407],
            [-1.1878, -1.5619],
            [ 0.3498,  1.5307],
            [-0.8174,  0.6017],
            [ 0.8076,  0.8295],
            [ 2.6239,  1.1669],
            [-1.2598,  1.4309],
            [ 0.3365,  0.1765],
            [-0.4472, -0.6882],
            [ 0.6732, -0.0742]])
    y
     tensor([[ 0.5114],
            [ 1.0669],
            [-1.5565],
            [ 0.4512],
            [ 3.2071],
            [ 0.4752],
            [-1.5981],
            [ 0.0035],
            [-0.2723],
            [-1.3634]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66

    自定义Dataset

    import torch
    from torch.utils import data
    
    # 自定义的Dataset
    # 每一条数据包含X的一行数据(1x2的tensor)和y的一行数据(1x1的tensor)
    class MyTensorDataset(data.Dataset):
        def __init__(self, tensor_list):
            # 初始化超类 (data.Dataset)
            super(MyTensorDataset, self).__init__()
            self.data = tensor_list
    
        
        def __getitem__(self, index):
            # 这个函数必须重写
            # 返回dataset中下标编号为index的数据
            sample = {'X': self.data[0][index], 'y': self.data[1][index]}
            return sample
        
    
        def __len__(self):
            # 返回这个dataset中的数据个数
            return len(self.data[0])
    
    
    # 随机生成数据
    def data_gen(num_examples):
        X = torch.normal(0, 1, (num_examples, 2))
        y = torch.normal(0, 1, (num_examples, 1))
        return X, y
    
    
    # 使用自己的Dataset
    def load_array(data_arrays, batch_size, need_shuffle=True):
        X, y = data_arrays  
        dataset = MyTensorDataset((X, y)) 
        return data.DataLoader(dataset, batch_size, shuffle=need_shuffle)
    
    
    data_A, data_B = data_gen(30)
    data_iter = load_array((data_A, data_B), 10, False)
    
    for sample in iter(data_iter):
        print("sample:\n", sample)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    输出:

    sample:
     {'X': tensor([[ 0.2149,  0.6216],
            [ 0.4691,  0.1862],
            [-1.7705, -1.5983],
            [-0.0196, -0.5903],
            [-0.1313,  0.3206],
            [ 0.1898, -0.2575],
            [ 1.5934, -0.4720],
            [-0.8343, -0.2181],
            [ 1.6159, -0.5473],
            [-1.2662, -0.0218]]), 'y': tensor([[ 0.8886],
            [ 0.1653],
            [ 0.8054],
            [-0.0725],
            [-0.4806],
            [-0.4661],
            [-0.4040],
            [ 0.6192],
            [-0.2522],
            [ 1.4091]])}
    sample:
     {'X': tensor([[ 0.6441, -0.5759],
            [-0.7285, -1.0021],
            [ 0.1250, -0.2333],
            [ 0.3196,  0.7762],
            [ 0.1429,  0.4667],
            [ 1.0751, -0.4867],
            [ 0.1664,  0.3489],
            [ 0.1616, -0.1998],
            [ 0.6707, -0.4678],
            [-1.7778, -2.4658]]), 'y': tensor([[-0.2342],
            [ 0.1402],
            [ 0.5768],
            [ 1.7898],
            [-0.6802],
            [-2.3584],
            [ 0.7048],
            [ 0.1848],
            [ 0.1225],
            [ 0.5535]])}
    sample:
     {'X': tensor([[ 0.1694,  0.2863],
            [ 0.3062, -0.7494],
            [ 0.6844,  1.9278],
            [ 0.9141,  0.3842],
            [-1.2314, -1.4933],
            [ 0.1568, -1.4182],
            [ 1.7723, -0.4890],
            [ 0.5734,  1.0614],
            [ 0.9536, -0.2866],
            [ 0.2510,  0.3375]]), 'y': tensor([[ 0.9802],
            [-0.5557],
            [ 0.7763],
            [ 1.1688],
            [-1.0067],
            [-0.4044],
            [-0.2745],
            [-0.3661],
            [-0.6058],
            [ 0.6905]])}
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
  • 相关阅读:
    SpringMVC(二)
    LeetCode每日一题:1620. 网络信号最好的坐标 (中等) 暴力枚举
    PHP下富文本HTML过滤器:HTMLPurifier使用教程
    一言不合就重构
    C/C++ 泛型模板约束
    【Java八股文之进阶篇(五)】多线程编程核心之并发容器
    【Linux】简易ftp文件传输项目(云盘)
    vant_vant引入
    HarmonyOS--状态管理--装饰器
    Fiddle日常运用手册(3)-对移动端产品进行数据接口抓包
  • 原文地址:https://blog.csdn.net/xhyu61/article/details/133553713