• PyTorch 迭代器读取数据


    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
      
      def set_iter(self):
        self.iterator = self._iter()
    
      def _iter(self):
        i = 0
        while True:
          yield self.value + i
          i += 1
    
      def __getitem__(self, i):
        return next(self.iterator)
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      dataset.set_iter()
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    1
    2
    3
    4
    5
    

    Version 1:想着在__getitem__里面调用迭代器_iter(),每次getitem就取一次值,结果发现拿到的是一个function,没法用。

    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
      
      def _iter(self):
        i = 0
        while True:
          yield self.value + i
          i += 1
    
      def __getitem__(self, i):
        v = self._iter()
        return v
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    <generator object MetaDataset._iter at 0x7effac314d60>
    <generator object MetaDataset._iter at 0x7effac314d60>
    <generator object MetaDataset._iter at 0x7effac314d60>
    <generator object MetaDataset._iter at 0x7effac314d60>
    <generator object MetaDataset._iter at 0x7effac314d60>
    

    在这里插入图片描述
    Version 2:查了一下之后,发现应该用next才能取到迭代器的值,于是加了next,现在能拿到值了,但是只能取第一个值。为啥?因为把迭代器的初始化操作v=self.iter()放在了getitem里面,那么每次getitem实际上都会重新初始化迭代器。

    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
      
      def _iter(self):
        i = 0
        while True:
          yield self.value + i
          i += 1
    
      def __getitem__(self, i):
        v = self._iter()
        return next(v)    # 使用next()
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    1
    1
    1
    1
    1
    

    在这里插入图片描述
    Version 3:既然这样,那就把迭代器的初始化放到__init__的时候去做,然后发现果然work

    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
      	
        def _iter():
            i = 0
            while True:
                yield self.value + i
                i += 1
                
        self.iterator = _iter()
    
      def __getitem__(self, i):
        return next(self.iterator)
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    1
    2
    3
    4
    5
    

    在这里插入图片描述
    Version 4:我想着实际代码里面肯定不能这么写,因为初始化的时候很多函数都在里面,所以就加了一个init,专门用来初始化迭代器。

    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
      
      # 初始化迭代器
      def set_iter(self):
        self.iterator = self._iter()
    
      def _iter(self):
        i = 0
        while True:
          yield self.value + i
          i += 1
    
      def __getitem__(self, i):
        return next(self.iterator)
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      dataset.set_iter()
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    1
    2
    3
    4
    5
    

    在这里插入图片描述
    Version 5:好像直接在__init__里面初始化iterator也可以……??事实证明没毛病。

    from torch.utils.data import Dataset
    
    class MetaDataset(Dataset):
      def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
        self.iterator = self._iter()   # 直接在init里初始化迭代器
      
      def _iter(self):
        i = 0
        while True:
          yield self.value + i
          i += 1
    
      def __getitem__(self, i):
        return next(self.iterator)
    
      def __len__(self):
        return self.n_episode
    
    if __name__ == "__main__":
      dataset = MetaDataset(5, 1)
      print("dataset length:", len(dataset))
      for i in range(len(dataset)):
        print(dataset[i])
    

    输出:

    dataset length: 5
    1
    2
    3
    4
    5
    

    在这里插入图片描述

    Version 6:

    这种方式不管是enumerate还是data_loader[i]都可以拿到那个iterate_dataset()生成的元素!
    在这里插入图片描述

    本质上就是dataset类可以视为list,既能enumerate访问也能dataset[index]访问

    dataset类只需要改写__getitem__和__len__方法,TensorDataset是这两个方法
    IterableDataset是__iter__和__getitem__

    但是pytorch的dataloader类只是一个迭代器,他只能enumerate访问,不能dataloader[index]访问(我才发现这一点)。dataloader只重写了__iter__,没有重写其他的

    在这里插入图片描述
    在这里插入图片描述

  • 相关阅读:
    前端和后端是Web开发选哪个好?
    # 磁盘引导方式相关知识之BIOS、msdos、MBR、UEFI、gpt、esp、csm
    c语言-自研定时器计划任务语法
    MES管理系统的生产模块与ERP有何差异
    常用Windows快捷键大全
    【排序算法】选择排序
    PGL图学习之图游走类metapath2vec模型[系列五]
    Python 中的邻接矩阵
    TSUMU58CDT9-1显示器芯片方案
    如何使用SMS向客户传递服务信息?指南在这里!
  • 原文地址:https://blog.csdn.net/qq_31347869/article/details/125777859