When a Python class defines the __getitem__ method, its instances can be indexed with square brackets (e.g. obj[key]) to retrieve data.
Code:

import numpy as np

# Define a class whose instances support indexing via __getitem__
class Example:
    def __getitem__(self, index):
        data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        return data[index]

# Create an instance of Example
example1 = Example()

# Access data by index
print('example1[0][0]:', example1[0][0])
print('example1[0]:', example1[0])

# Access data by slicing
print('example1[0:2]:\n', example1[0:2])
Output:

example1[0][0]: 1
example1[0]: [1 2 3]
example1[0:2]:
 [[1 2 3]
 [4 5 6]]
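Why slicing works here: Python passes whatever appears inside the brackets straight to __getitem__, so example1[0] hands it an integer while example1[0:2] hands it a slice object, which numpy then interprets. The short sketch below makes the passed argument visible; the class name SliceDemo is a hypothetical name used only for this illustration.

# Minimal sketch (hypothetical class SliceDemo): echo back whatever __getitem__ receives
class SliceDemo:
    def __getitem__(self, index):
        return index

demo = SliceDemo()
print(demo[1])     # 1                  -> an int
print(demo[0:2])   # slice(0, 2, None)  -> a slice object
print(demo[0, 1])  # (0, 1)             -> a tuple, as in numpy-style 2-D indexing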
PyTorch's Dataset relies on the same mechanism: DataLoader fetches each sample by calling the dataset's __getitem__ with an index, so a custom Dataset must implement __getitem__ (and __len__, so the loader knows how many samples there are).

Code:

import torch
import numpy as np
from torch.utils.data import Dataset

# Define the MyDataset class
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.data = torch.from_numpy(x).float()
        self.label = torch.LongTensor(y)

    def __getitem__(self, idx):
        # Return the sample, its label, and the index itself
        return self.data[idx], self.label[idx], idx

    def __len__(self):
        return len(self.data)

Train_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Train_label = np.array([10, 11, 12, 13])
TrainDataset = MyDataset(Train_data, Train_label)  # create an instance
print('len:', len(TrainDataset))

# Create the DataLoader
loader = torch.utils.data.DataLoader(
    dataset=TrainDataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False)

# Print the data batch by batch
for batch_idx, (data, label, index) in enumerate(loader):
    print('batch_idx:', batch_idx, '\ndata:', data, '\nlabel:', label, '\nindex:', index)
    print('---------')
Output:

len: 4
batch_idx: 0
data: tensor([[1., 2., 3.],
        [4., 5., 6.]])
label: tensor([10, 11])
index: tensor([0, 1])
---------
batch_idx: 1
data: tensor([[ 7.,  8.,  9.],
        [10., 11., 12.]])
label: tensor([12, 13])
index: tensor([2, 3])
---------
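Because __getitem__ also returns idx, every batch carries the original positions of its samples. That becomes useful once shuffling is enabled: the index tensor then shows which rows of Train_data landed in each batch. Below is a minimal sketch reusing the TrainDataset defined above; the fixed random seed and the name shuffled_loader are assumptions for illustration, not part of the original example.

# Sketch: same dataset, but drawn in random order (seed fixed only for reproducibility)
torch.manual_seed(0)

shuffled_loader = torch.utils.data.DataLoader(
    dataset=TrainDataset,
    batch_size=2,
    shuffle=True)

for batch_idx, (data, label, index) in enumerate(shuffled_loader):
    # index reports which original rows of Train_data are in this batch
    print('batch', batch_idx, 'contains original samples', index.tolist())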