Torchvision 官方文档 Torchvision 中的 torchvision.datasets
就是 Torchvision 提供的标准数据集,其中有以下内容:
我们以 CIFAR 为例,该数据集包括了60000张32*32像素的图像,总共有10个类别,每个类别有6000张图像,其中有50000张图像为训练图像,10000张为测试图像。其使用说明如下图所示:
root
:数据集存放的路径。train
:如果为 True,创建的数据集就为训练集,否则创建的数据集就为测试集。transform
:使用 transforms
中的变换操作对数据集进行变换。target_transform
:对 target 进行 transform。download
:如果为 True,就会自动从网上下载这个数据集,否则就不会下载。例如:
import torchvision
train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, download=True)
print(train_set[0]) # (, 6)
刚开始运行时可以看到正在从网上下载数据集,如果下载速度非常慢可以复制链接去迅雷之类的地方下载,下载好后自己创建设定的路径,将数据集放过来即可:
然后设置断点,用 Debug 模式运行一下代码,我们来查看一下数据集的内容:
可以看到 classes
表示图像的种类,classes_to_idx
表示将种类映射为整数,targets
表示每张图像对应的种类编号,试着输出一下第一张图的信息:
img, target = train_set[0]
print(img) #
print(target) # 6
print(train_set.classes[target]) # frog
img.show() # 图像显示为青蛙
现在展示如何使用 transform
参数,假设我们需要将数据集的图像都转换成 tensor 类型:
trans_dataset = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=True, transform=trans_dataset, download=True)
test_set = torchvision.datasets.CIFAR10(root='dataset/CIFAR10', train=False, transform=trans_dataset, download=True)
img, target = train_set[0]
print(type(img)) #