在图像分类实验中,经常能看到对数据集进行数据增强操作,其中包括transforms.Normalize(),这个函数的定义如下:
torchvision.transforms.Normalize(mean, std, inplace=False)
功能:针对RGB3个 channel 分别对图像进行标准化
output = ( input - mean ) / std
通常ImageNet有自己的标准化参数,是通过抽样统计图像的均值方差得到的,那么针对本地特定数据集,如何获取到适合的参数呢?我参考了PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差_紫芝的博客-CSDN博客_pytorch 数据归一化
原文代码有一处错误,需要先把transform设置为transforms.ToTensor(),而不是None,否则会运行错误。以下是改正后的代码:
- def getStat(train_data):
- '''
- Compute mean and variance for training data
- :param train_data: 自定义类Dataset(或ImageFolder即可)
- :return: (mean, std)
- '''
- print('Compute mean and variance for training data.')
- print(len(train_data))
- train_loader = torch.utils.data.DataLoader(
- train_data, batch_size=1, shuffle=False, num_workers=0,
- pin_memory=True)
- mean = torch.zeros(3)
- std = torch.zeros(3)
- for X, _ in train_loader:
- for d in range(3):
- mean[d] += X[:, d, :, :].mean()
- std[d] += X[:, d, :, :].std()
- mean.div_(len(train_data))
- std.div_(len(train_data))
- return list(mean.numpy()), list(std.numpy())
-
-
- if __name__ == '__main__':
- train_dataset = ImageFolder(root=r'/data1/sharedata/leafseg/', transform=transforms.ToTensor())
- print(getStat(train_dataset))
Compute mean and variance for training data.
3257
([0.059938803, 0.08676067, 0.041085023], [0.10522498, 0.1488454, 0.07508467])
将结果写入transform列表中即可。
data_transforms = {
'train': transforms.Compose([
transforms.Resize(640),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751])
]),
'val': transforms.Compose([
transforms.Resize(640),
transforms.ToTensor(),
transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751])
]),
}