Pytorch官网上对DataLoader方法进行了详细的介绍,数据加载器。结合数据集和采样器,并提供给定数据集的可迭代对象。DataLoader
支持具有单进程或多进程加载、自定义加载顺序和可选的自动批处理(整理)和内存固定的地图样式和可迭代样式数据集。
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')
参数
1
)。True
在每个 epoch 重新洗牌数据(默认值:False
)。0
表示数据将在主进程中加载。(默认:0
)True
如果数据集大小不能被批次大小整除,则设置为丢弃最后一个不完整的批次。如果False
数据集的大小不能被批大小整除,那么最后一批将更小。(默认:False
)这里使用CIFAR10数据集,通过DataLoader
方法将数据集以64一组打包,在windows系统中num_workers=0
,最后在tensorboard中将打包好的图像展示。
注意,对于打包的图片展示,使用的方法是add_images()
方法,单张图片展示使用add_image()
方法
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
train_set = torchvision.datasets.CIFAR10(root='C:\\Users\\hp\\PycharmProjects\\pythonProject\\Pytorch_Learning\\p11-dataset_transform\\dataset',
train=True, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
writer = SummaryWriter("logs")
step = 0
for data in train_loader:
img, target = data
writer.add_images("test_data", img, step)
step += 1
writer.close()