• pytorch -- 构建自己的Dateset,DataLoader如何使用


    目录

    1 torch.utils.data.Dataset类

    2 构建Dataset子类

    3 Dataloader类

    4 Dataset与Dataloader结合使用


             运行模型,使用他人构建的模型,主要是对自身数据dataset类的构造;

            最主要的是定义好数据集的特征X,和类别y;

    1 torch.utils.data.Dataset类

    首先先看Dataset的源码:

            torch.utils.data.Dataset类是pytorch中用来表示数据集的抽象类(只能被继承,不能被实例化,相当于从一堆类中抽取出的内容,包含数据属性和函数属性,只有抽象方法,只能被继承,且子类必须实现抽象方法);

            使用Dataset类来创建数据集;

    1. class Dataset(object):
    2. """An abstract class representing a Dataset.
    3. All other datasets should subclass it. All subclasses should override
    4. ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    5. supporting integer indexing in range from 0 to len(self) exclusive.
    6. """
    7. def __getitem__(self, index):
    8. raise NotImplementedError
    9. def __len__(self):
    10. raise NotImplementedError
    11. def __add__(self, other):
    12. return ConcatDataset([self, other])

    上述函数__getitem__(), __len__()是子类必须要继承的;

            __len__(): 使用该函数返回数据集的大小;

            __getitem__():通常其接收一个index, 用于查找数据和标签,这个index是指一个list的index,list中的每个元素包含数据和标签,其只有在用到的时候,才将数据读入;

            index的取值范围是根据__len__()的返回值确定的;

    2 构建Dataset子类

    1. class MyDataSet(Dataset):
    2. def __init__(self):
    3. # 将所需要的数据属性写在这个函数中
    4. self.data = ...
    5. self.label = ...
    6. def __getitem__(self, index):
    7. return self.data[index], self.label[index]
    8. def __len__(self):
    9. return len(self.data)

             值得注意的地方:

                    一般label值是Long整数类型的,所以标签的tensor,可以使用torch.LongTensor(数据)来转化成Long整数的形式;

                    使用pytorch的GPU训练的话,一般先判断cuda是否可用:

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

                    然后把数据和标签都使用.to()放到GPU显存上进行加速:

    1. for i,(data,label) in enumerate(mydataloader):
    2. data = data.to(device)
    3. label = label.to(device)
    4. print(data,label)

    3 Dataloader类

            这部分内容参考自:聊聊Pytorch中的dataloader - 知乎

            参数:

    1. torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \
    2. batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \
    3. drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

            dataset:定义的dataset类返回的结果;

            batchsize:每个bacth要加载的样本数,默认为1;

            shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False;

            sampler:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;

            batch_sample类似,表示一次返回一个batch的index;

            num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程;

            collate_fn:表示合并样本列表以形成小批量的Tensor对象;

            pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False;(pin_memory,通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高)

            drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False;

    对num_workers,sample和collate_fn分别进行说明:

            1 设置num_workers:

            pytorch中dataloader一次性创建num_workers个子线程,然后用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM,dataloader就可以直接从RAM中找本轮迭代要用的batch;

            如果num_worker设置得大,优点:是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮...迭代时已经加载好了;

            缺点:是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是进行CPU复制);

            如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤,只有当你需要的时候再加载相应的batch,当然速度就更慢;

            num_workers经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些,对于单机来说,单跑一个任务的话,直接设置为CPU的核心数最好;

            (这里标注服务器查看CPU信息的bash指令: 参考:查看服务器cpu核数信息_beetle_lzk的博客-CSDN博客_查看服务器cpu核数

                    一:查看cpu信息

                            cat /proc/cpuinfo | grep name | cut -f2 -d: | uniq -c

                    二:查看物理cpu个数,也就是实物cpu的个数

                            cat /proc/cpuinfo| grep "physical id"| sort| uniq| wc -l

                    三:查看每个cpu的core,也就是常说的核心数

                            cat /proc/cpuinfo| grep "cpu cores"| uniq

                    四:查看服务器总的核心数,也就是逻辑cpu个数   ==  (物理cpu个数 * 每个cpu的核心数 * 超线程数)  

                            cat /proc/cpuinfo| grep "processor"| wc -l

            2 定义sample:

                    PyTorch中提供的这个sampler模块,用来对数据进行采样;

                    默认采用SequentialSampler,它会按顺序一个一个进行采样,常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据;

                    这里使用另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样;

    1. from torch.utils.data.sampler import WeightedRandomSampler
    2. ## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
    3. weights = [2 if label == 1 else 1 for data, label in dataset]
    4. sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
    5. dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)

                    replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据;

            3 定义collate_fn:

                    使用dataloader时加入collate_fn参数,即可合并样本列表以形成小批量的Tensor对象;

                    如果你的标签不止一个的话,还可以支持自定义,在下面方法中再额外添加对应的label即可:

    1. def detection_collate(batch):
    2. # 自定义整理fn ,用于处理具有不同数量的关联对象注释(边界框)的批次图像
    3. """Custom collate fn for dealing with batches of images that have a different
    4. number of associated object annotations (bounding boxes).
    5. Arguments:
    6. batch: (tuple) A tuple of tensor images and lists of annotations
    7. batch: (tuple)张量图像和注释列表的元组
    8. Return:
    9. A tuple containing:
    10. 1) (tensor) batch of images stacked on their 0 dim
    11. 2) (list of tensors) annotations for a given image are stacked on
    12. 0 dim
    13. """
    14. targets = []
    15. imgs = []
    16. for sample in batch:
    17. imgs.append(sample[0])
    18. targets.append(torch.FloatTensor(sample[1]))
    19. return torch.stack(imgs, 0), targets

            这是Dataloader应该这么写:

    1. data_loader = torch.utils.data.DataLoader(dataset, args.batch_size,
    2. num_workers=args.num_workers, sampler=sampler, shuffle=False,
    3. collate_fn=detection_collate, pin_memory=True, drop_last=True)

    4 Dataset与Dataloader结合使用

            实例化MyDataSet类:

    dataset = MyDataSet()

            MyDataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签;

            在模型训练时,一般是将多个数据组成batch,这便使用到Dataloader迭代器进行组合;

    mydataloader = DataLoader(dataset = dataset, batch_size = 16,shuffle = True)

            之后训练的时候,使用for循环来遍历mydataloader:  

    1. for i,(data,label) in enumerate(mydataloader):
    2. print(data,label)

            通过上述两个类,可以迅速做出batch数据,修改batch_size, 和乱序使用都很方便;

           

  • 相关阅读:
    华为网络设备高频命令
    密码技术---密钥和SSL/TLS
    如何用记事本制作一个简陋的小网页(3)——注册信息表
    打包时,模块的大小写很重要
    [附源码]计算机毕业设计springboot家庭整理服务管理系统
    Unity的碰撞检测(总结篇)
    如何购买阿里云香港服务器?又有什么什么好处呢?
    【Spring-2】refresh方法中的invokeBeanFactoryPostProcessors方法
    添加Java服务
    SpringMvc+Spring+MyBatis+Maven+Ajax+Json注解开发 利用Maven的依赖导入不使用架包模式 (实操十)
  • 原文地址:https://blog.csdn.net/qq_40671063/article/details/126847290