- import torchvision
- from torch.utils.data import DataLoader
- from torch.utils.tensorboard import SummaryWriter
-
- # 准备的测试数据集 数据放在了CIFAR10文件夹下
-
- test_data = torchvision.datasets.CIFAR10("./CIFAR10",
- train=False, transform=torchvision.transforms.ToTensor())
- test_loader = DataLoader(dataset=test_data, batch_size=4,
- shuffle=True, num_workers=0, drop_last=False)
-
- # 测试数据集中第一张图片及target
- img, target = test_data[0]
- print(img.shape)
- print(target)
-
- # 在定义test_loader时,设置了batch_size=4,表示一次性从数据集中取出4个数据
- for data in test_loader:
- imgs, targets = data
- print(imgs.shape)
- print(targets)
-
- # 在定义test_loader时,设置了batch_size=4,表示一次性从数据集中取出4个数据
- writer = SummaryWriter("logs")
- for epoch in range(2):
- step = 0
- for data in test_loader:
- imgs, targets = data
- writer.add_images("Epoch: {}".format(epoch), imgs, step)
- step = step + 1
- writer.close()
把CIFAR10做成一个数据集,然后得到迭代器
每个迭代器包括图像和标签
下面是tensorboard的用法

