代码:
import torch
from torch.utils import data
class MyDataset(torch.utils.data.Dataset):
def __init__(self):
self.data=torch.arange(0,20)
def __getitem__(self,index):
x=self.data[index]
y=x * 2
return y
def __len__(self):
return len(self.data)
# 定义DataLoader
dataset=MyDataset()
print(len(dataset))
print(dataset[3])
打印结果:
import torch
from torch.utils import data
class MyDateset(torch.utils.data.Dataset):
def __init__(self):
self.data=torch.arange(0,20)
def __getitem__(self,index):
x=self.data[index]
y=x * 2
return y
def __len__(self):
return len(self.data)
# 定义DataLoader
dataset=MyDateset()
print(len(dataset))
print(dataset[3])
dataloader=torch.utils.data.DataLoader(dataset,shuffle=True,batch_size=4)
print(len(dataloader))
for x in dataloader:
print(x)
打印结果如下所示:
1、pin_memory
2、num_workers
3、collate_fn
分类任务
目标检测任务
在进行目标检测任务中,需要重写collate方法,并将方法传入到DataLoader中