• torch.utils.data.DataLoader


    1. #设置数据增强方法
    2. transform = transforms.Compose(
    3. [transforms.ToTensor(),
    4. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    5. #加载数据集的数据,返回所有样本的img和label
    6. trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
    7. download=True, transform=transform)
    8. #对数据进行batch采样
    9. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
    10. shuffle=True, num_workers=2)

    1.加载数据集的是数据,返回所有样本的img和label

    通过数据加载类完成这一操作

    数据加载类包括三个函数:__init__()、__getitem__()、__len()__()

    (1)__init__()

    1. __init__(
    2. self,
    3. root: str,
    4. train: bool = True,
    5. transform: Optional[Callable] = None,
    6. target_transform: Optional[Callable] = None,
    7. download: bool = False,
    8. )

    返回所有样本的img和label

    (2)__getitem__()

    这个函数在进行epoch训练时才会运行,根据给出的index确定样本,并进行数据增强操作。

    返回数据增强后的样本。

    1. def __getitem__(self, index: int) -> Tuple[Any, Any]:
    2. img, target = self.data[index], self.targets[index]
    3. img = Image.fromarray(img)
    4. if self.transform is not None:
    5. img = self.transform(img)
    6. if self.target_transform is not None:
    7. target = self.target_transform(target)
    8. return img, target

    (3)__len()__()

    返回数据的数量

    1. def __len__(self) -> int:
    2. return len(self.data)

    2.确定训练时的数据加载方式

    torch.utils.data.DataLoader,结合了数据集和取样器,并且可以提供多个线程处理数据集。用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

    参数:

    dataset:包含所有数据的数据集

    batch_size :每一小组所包含数据的数量

    Shuffle : 是否打乱数据位置,当为Ture时打乱数据,全部抛出数据后再次dataloader时重新打乱。

    sampler : 自定义从数据集中采样的策略,如果制定了采样策略,shuffle则必须为False.

    Batch_sampler:和sampler一样,但是每次返回一组的索引,和batch_size, shuffle, sampler, drop_last 互斥。

    num_workers : 使用线程的数量,当为0时数据直接加载到主程序,默认为0。

    collate_fn:不太了解

    pin_memory:s 是一个布尔类型,为T时将会把数据在返回前先复制到CUDA的固定内存中

    drop_last:布尔类型,为T时将会把最后不足batch_size的数据丢掉,为F将会把剩余的数据作为最后一小组。

    timeout:默认为0。当为正数的时候,这个数值为时间上限,每次取一个batch超过这个值的时候会报错。此参数必须为正数。

    worker_init_fn:和进程有关系,暂时用不到

    torch.utils.data.DataLoader中有采样器、迭代器、__len__()。

  • 相关阅读:
    数据库系统原理与应用教程(069)—— MySQL 练习题:操作题 95-100(十三):分组查询与聚合函数的使用
    学习Autodock分子对接
    常见的4种Bug 出现原因和解决方案
    Apache Doris 基础 -- 部分数据类型及操作
    AHR亚马逊账户健康评级多久更新,如何查看解决
    从零实现Web框架Geo教程-错误恢复-07
    eclipse / sts 设置类注释模板
    大数据开发(Hadoop面试真题-卷九)
    Android 10.0 禁用adb remount功能的实现
    Jenkins中Node节点与构建任务
  • 原文地址:https://blog.csdn.net/baidu_38262850/article/details/126203420