最近写代码又遇见了这个问题,又忘记了,于是写一篇博客记录一下。
一般我们使用pytorch获取CIFAR10数据集,一般这样写:
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
最后出来的结果都是小数和xxx数。
如果使用了ToTensoer
,那么会将原始数据都归一化到0~1的范围内,数据都将除以255。
归一化之后,就是标准化,我们使用Normalize并传入mean和std,公式是:
o
u
t
p
u
t
=
i
n
p
u
t
−
m
e
a
n
s
t
d
output = \frac{input -mean}{std}
output=stdinput−mean
注意!input已经被除255了。
这样就得到了最后的结果。
其实数据一直都没有被修改,当你使用
dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
得到一个训练集的时候,原始数据并没有被transform,数据其实一直保存在dst_train.data里
在迭代或者通过下标获取数据时,才会使用transform来修改数据。
这个类维持一个data原始数据,因此有时候如果要修改数据,其实没必要去修改标准化后的数据,直接修改.data即可。
如果有人做的是后门攻击,可以尝试一下重写CIFAR10数据集的类,重写__getitem__ 即可。