• pytorch实现数据集读取/下载



    #pic_center =400x
    系列文章:



    数据集读取本地

    Dataset基类介绍

    在 torch.utils.data.Dataset 提供了数据集的基类,我们只需要继承这个基类重写里面的方法即可完成数据集的加载与读取。
    重写两个方法:

    1. __len__方法,能够实现通过全局的len()方法获取其中的元素个数
    2. __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据

    数据加载案例

    根据自己数据集的情况修改一下三个方法

    1. __init__方法可以用来设置读取数据集等初始化数据集的基本操作
    2. __getitem__方法通常用来根据索引来返回一条对应的数据内容
    3. __len__方法通常用来返回数据总数
    from torch.utils.data import DataLoader, Dataset
    class MyDataset(Dataset):
        def __init__(self):
            self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
    
        def __len__(self):
            return  len(self.lines)
    
        def __getitem__(self, index):
            # strip取消换行
            # cur_line = self.lines[index].strip()
            # label = cur_line[:4].strip()
            # content = cur_line[4:].strip()
            cur_line = self.lines[index]
            label = cur_line[:4]
            content = cur_line[4:]
            return label, content
    
    my_dataset = MyDataset()
    print(my_dataset[0])
    print(len(my_dataset))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    在这里插入图片描述
    在这里插入图片描述
    本文使用的数据集为开源的本文分类数据集SMS Spam Collection Data Set,下载地址为https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
    数据集是从 Grumbletext 网站手动提取了 425 条 SMS 垃圾邮件的集合,由一个文本文件构成,其中每一行都是有一个类别和后面的原始消息构成。

    DataLoader类的使用详解

    DataLoader的主要作用是将Dataset处理后的数据集进行加载整合成batch用于后续训练,使用方法如下

    from torch.utils.data import DataLoader, Dataset
    class MyDataset(Dataset):
        def __init__(self):
            self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
    
        def __len__(self):
            return  len(self.lines)
    
        def __getitem__(self, index):
            # strip取消换行
            # cur_line = self.lines[index].strip()
            # label = cur_line[:4].strip()
            # content = cur_line[4:].strip()
            cur_line = self.lines[index]
            label = cur_line[:4]
            content = cur_line[4:]
            return label, content
    
    my_dataset = MyDataset()
    
    if __name__ == '__main__':
    
        data_loader = DataLoader(dataset=my_dataset, batch_size=5, shuffle=True, num_workers=1)
        for i in  data_loader:
            print(i)
            break
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    在这里插入图片描述
    实际项目中通常使用enumerate方法 读取每一个batch内容的同时也返回 batch的索引

    from torch.utils.data import DataLoader, Dataset
    class MyDataset(Dataset):
        def __init__(self):
            self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
    
        def __len__(self):
            return  len(self.lines)
    
        def __getitem__(self, index):
            # strip取消换行
            # cur_line = self.lines[index].strip()
            # label = cur_line[:4].strip()
            # content = cur_line[4:].strip()
            cur_line = self.lines[index]
            label = cur_line[:4]
            content = cur_line[4:]
            return label, content
    
    my_dataset = MyDataset()
    
    if __name__ == '__main__':
    
        data_loader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True, num_workers=1)
        for index, (label, content) in enumerate(data_loader):
            print(index)
            print(label)
            print(content)
            break
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    在这里插入图片描述
    PyTorch数据加载方法

    数据集在线下载

    一篇学会Pytorch数据集下载

  • 相关阅读:
    Qt QMetaObject::invokeMethod
    主成分分析用于ERP研究的实用教程-机遇和挑战(附代码)
    优化SOCKS5的方法
    Golang go-redis cluster模式下不断创建新连接,效率下降问题解决
    mac苹果电脑使用耳机听不到声音
    Linux常见的指令合集
    三、appender分析
    SpringBoot+MybatisPlus Restful示例
    Java---SSM---SpringMVC(2)
    Python基础篇(十一)-- 模块和包
  • 原文地址:https://blog.csdn.net/weixin_42382758/article/details/126112072