• ECCV2022 论文 Contrastive Deep Supervision


    论文链接https://arxiv.org/pdf/2207.05306.pdf

    代码链接GitHub - ArchipLab-LinfengZhang/contrastive-deep-supervision: Codes for ECCV2022 paper - contrastive deep supervision

    动机

    近年来,由于大量数据的出现以及计算机算力的提升,深度学习统治了计算机视觉领域。然而,随着神经网络深度增加的同时,也带来了一些挑战。传统的有监督方法仅对模型的最后一层进行监督,然后再将误差反向传播到中间层。由于反向传播过程中可能会出现梯度消失、爆炸及弥漫等问题,怎么优化好模型中间层的参数成为了一个难点。

    近期,深度监督被用于解决上述问题,它的做法是在中间层中添加辅助的分类器。在训练期间,辅助分类器与最终的分类器一同优化。大量实验证明,深度监督加速了模型的收敛。然而,通常来说,不同深度的特征学到的信息不同,底层特征往往含有丰富的纹理及颜色等信息,而深层特征往往含有丰富的语义信息,简单地将辅助分类器应用到中间层特征显然存在问题,因为底层特征没有丰富的语义信息,不适合进行分类 (底层特征往往用于目标定位,因为它含有较多的空间位置信息)。基于这些理论,就有了这篇文章 《Contrastive Deep Supervision》,以下简称 CDS。

    创新点

    这篇文章的作者认为:相比于有监督的任务损失,对比学习能给中间层的特征提供更好的监督。对比学习通常在同一张图片中使用两种不同的数据增强 (增强方法可以相同,但其中的参数不同),随后将增强后的两张图片视为正样本对,与其余图片构成负样本对。作者提出的方法如下图中的 (d) 所示,几个投影头会附件在中间层的后面,用于将特征映射到嵌入空间,以便进行对比学习,这些投影头在推理期间会被 kill 掉,这样就避免了额外的计算及额外的存储空间。与训练中间层特征去学习特定任务知识的深度监督不同,CDS 学习的是图片中的本质信息,这些信息不受数据增强的影响,这也使神经网络能更好地泛化。此外,由于对比学习可以在未标记的数据上进行,CDS 也可应用到半监督任务中。这篇文章的主要创新点如下:

    (1) 提出了 CDS,这是一种神经网络训练方法,其中中间层直接通过对比学习进行优化。它使神经网络能够学习更好的视觉表示,且无需在推理过程中增加额外的开销

    (2) 从深度监督的角度来看,作者第一个表明除了有监督任务损失之外,中间层还可以通过其他方式进行训练

    (3) 从表示学习的角度来看,作者首个表明对比学习和监督学习可以以一阶段的深度监督的方式联合训练模型,而不是两阶段的 “pretrain-finetune” 方案 (先预训练,后微调)

    方法论

    CDS

    假定一个 minibatch 有 N 张图片,对每张图片都进行两次随机的数据增强,增强后就有 2N 张图片。为了方便,作者把 x_{i} 和 x_{N+i} 作为来自同一图像的两个增强表示,这两张图片也被视为一个正样本对。z=c(x) 为经过投影层并标准化后的输出,对比学习的公式如下:

    L_{Contra} 鼓励编码器网络从同一图像中学习不同增强的相似表示,同时增加来自不同图像的增强表示之间的差异。

    CDS 与深度监督之间的主要区别在于深度监督通过交叉熵损失来训练辅助分类器,而 CDS 则通过对比学习来训练。CDS 整体损失函数公式如下:

    这个公式表示有 K-1 个中间层使用了对比学习来训练,最后一层使用交叉熵损失来训练。

    CDS 还可以推广到半监督学习和知识蒸馏中:

    在半监督学习中,作者假设有 X_{1} 个带标签的图片,对应的标签为 Y_{1},无标签的数据为 X_{2}。在有标签数据中,可以直接使用 CDS。在无标签数据中,只能进行对比学习。整体的损失公式如下:

    在知识蒸馏中,作者进一步提出通过将教师模型学到的图像在数据增强中的不变性传递给学生模型,来改进具有 CDS 的知识蒸馏。f^{S} 和 f^{T} 分别表示知识蒸馏中的学生模型和教师模型,原始的知识蒸馏直接最小化了学生和教师模型的骨干特征之间的距离,可以表示为:

    与原始知识蒸馏不同,带有 CDS 的知识蒸馏最小化的是两个模型的嵌入向量 (经投影层得到) 之间的距离,公式如下:

    知识蒸馏中的整体损失函数公式如下:

    一些细节和 tricks

    投影层的设计

    在 CDS 的训练期间,将几个投影头添加到神经网络的中间层。这些投影头将骨干特征映射到归一化的嵌入空间,其中应用了对比学习损失。通常,投影头是由两个全连接层和一个 ReLU 函数堆叠而成的非线性投影。然而,在 CDS 中,输入特征来自中间层而不是最终层,因此需要修改投影层的设计。作者通过在非线性投影之前添加卷积层来增加这些投影头的复杂性。

    对比学习

    CDS 是一个通用的训练框架,不依赖于特定的对比学习方法。在这篇文章中,作者在大多数实验中采用 SimCLR 和 SupCon 作为对比学习的方法。如果使用更好的对比学习算法,模型最终的性能也会进一步提升。

    负样本

    以前的研究表明,负样本的数量对对比学习的表现有着重要的影响,因此在对比学习中通常使用大的 batch size。但在 CDS 中,作者认为诸如交叉熵之类的损失已经足以防止对比学习收敛到崩溃的解决方案。

    实验结果

    在 CIFAR100 和 CIFAR10 上的分类结果如下:

    ImageNet 上的分类结果如下:

    在目标检测数据集 COCO2017 上的结果如下:

    在细粒度数据集上的结果如下:

    代码

    代码也比较简单,拿 resnet18 来举例:

    1. import torch.nn as nn
    2. import torch.utils.model_zoo as model_zoo
    3. import torch.nn.functional as F
    4. import torch
    5. __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]
    6. # model_urls = {
    7. # "resnet18": "./pretrain/resnet18-5c106cde.pth",
    8. # "resnet34": "./pretrain/resnet34-333f7ec4.pth",
    9. # "resnet50": "./pretrain/resnet50-19c8e357.pth",
    10. # "resnet101": "./pretrain/resnet101-5d3b4d8f.pth",
    11. # "resnet152": "./pretrain/resnet152-b121ed2d.pth",
    12. # }
    13. model_urls = {
    14. "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    15. "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    16. "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    17. "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    18. "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    19. }
    20. def conv3x3(in_planes, out_planes, stride=1):
    21. """3x3 convolution with padding"""
    22. return nn.Conv2d(
    23. in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    24. )
    25. def conv1x1(in_planes, out_planes, stride=1):
    26. """1x1 convolution"""
    27. return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    28. class SepConv(nn.Module):
    29. def __init__(
    30. self, channel_in, channel_out, kernel_size=3, stride=2, padding=1, affine=True
    31. ):
    32. # depthwise and pointwise convolution, downsample by 2
    33. super(SepConv, self).__init__()
    34. self.op = nn.Sequential(
    35. nn.Conv2d(
    36. channel_in,
    37. channel_in,
    38. kernel_size=kernel_size,
    39. stride=stride,
    40. padding=padding,
    41. groups=channel_in,
    42. bias=False,
    43. ),
    44. nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False),
    45. nn.BatchNorm2d(channel_in, affine=affine),
    46. nn.ReLU(inplace=False),
    47. nn.Conv2d(
    48. channel_in,
    49. channel_in,
    50. kernel_size=kernel_size,
    51. stride=1,
    52. padding=padding,
    53. groups=channel_in,
    54. bias=False,
    55. ),
    56. nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False),
    57. nn.BatchNorm2d(channel_out, affine=affine),
    58. nn.ReLU(inplace=False),
    59. )
    60. def forward(self, x):
    61. return self.op(x)
    62. class BasicBlock(nn.Module):
    63. expansion = 1
    64. def __init__(self, inplanes, planes, stride=1, downsample=None):
    65. super(BasicBlock, self).__init__()
    66. self.conv1 = conv3x3(inplanes, planes, stride)
    67. self.bn1 = nn.BatchNorm2d(planes)
    68. self.relu = nn.ReLU(inplace=True)
    69. self.conv2 = conv3x3(planes, planes)
    70. self.bn2 = nn.BatchNorm2d(planes)
    71. self.downsample = downsample
    72. self.stride = stride
    73. def forward(self, x):
    74. identity = x
    75. out = self.conv1(x)
    76. out = self.bn1(out)
    77. out = self.relu(out)
    78. out = self.conv2(out)
    79. out = self.bn2(out)
    80. if self.downsample is not None:
    81. identity = self.downsample(x)
    82. out += identity
    83. out = self.relu(out)
    84. return out
    85. class ResNet(nn.Module):
    86. def __init__(
    87. self, block, layers, num_classes=100, zero_init_residual=False, align="CONV"
    88. ):
    89. super(ResNet, self).__init__()
    90. self.inplanes = 64
    91. self.align = align
    92. # reduce the kernel-size and stride of ResNet on cifar datasets.
    93. self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    94. self.bn1 = nn.BatchNorm2d(64)
    95. self.relu = nn.ReLU(inplace=True)
    96. # remove maxpooling layer for ResNet on cifar datasets.
    97. # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    98. self.layer1 = self._make_layer(block, 64, layers[0])
    99. self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    100. self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    101. self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
    102. self.auxiliary1 = nn.Sequential(
    103. SepConv(channel_in=64 * block.expansion, channel_out=128 * block.expansion),
    104. SepConv(
    105. channel_in=128 * block.expansion, channel_out=256 * block.expansion
    106. ),
    107. SepConv(
    108. channel_in=256 * block.expansion, channel_out=512 * block.expansion
    109. ),
    110. nn.AvgPool2d(4, 4),
    111. )
    112. self.auxiliary2 = nn.Sequential(
    113. SepConv(
    114. channel_in=128 * block.expansion,
    115. channel_out=256 * block.expansion,
    116. ),
    117. SepConv(
    118. channel_in=256 * block.expansion,
    119. channel_out=512 * block.expansion,
    120. ),
    121. nn.AvgPool2d(4, 4),
    122. )
    123. self.auxiliary3 = nn.Sequential(
    124. SepConv(
    125. channel_in=256 * block.expansion,
    126. channel_out=512 * block.expansion,
    127. ),
    128. nn.AvgPool2d(4, 4),
    129. )
    130. self.auxiliary4 = nn.AvgPool2d(4, 4)
    131. self.fc = nn.Linear(512 * block.expansion, num_classes)
    132. for m in self.modules():
    133. if isinstance(m, nn.Conv2d):
    134. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    135. elif isinstance(m, nn.BatchNorm2d):
    136. nn.init.constant_(m.weight, 1)
    137. nn.init.constant_(m.bias, 0)
    138. # Zero-initialize the last BN in each residual branch,
    139. # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    140. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    141. if zero_init_residual:
    142. for m in self.modules():
    143. if isinstance(m, Bottleneck):
    144. nn.init.constant_(m.bn3.weight, 0)
    145. elif isinstance(m, BasicBlock):
    146. nn.init.constant_(m.bn2.weight, 0)
    147. def _make_layer(self, block, planes, blocks, stride=1):
    148. downsample = None
    149. if stride != 1 or self.inplanes != planes * block.expansion:
    150. downsample = nn.Sequential(
    151. conv1x1(self.inplanes, planes * block.expansion, stride),
    152. nn.BatchNorm2d(planes * block.expansion),
    153. )
    154. layers = []
    155. layers.append(block(self.inplanes, planes, stride, downsample))
    156. self.inplanes = planes * block.expansion
    157. for _ in range(1, blocks):
    158. layers.append(block(self.inplanes, planes))
    159. return nn.Sequential(*layers)
    160. def forward(self, x):
    161. feature_list = []
    162. x = self.conv1(x)
    163. x = self.bn1(x)
    164. x = self.relu(x)
    165. x = self.layer1(x)
    166. feature_list.append(x)
    167. x = self.layer2(x)
    168. feature_list.append(x)
    169. x = self.layer3(x)
    170. feature_list.append(x)
    171. x = self.layer4(x)
    172. feature_list.append(x)
    173. out1_feature = self.auxiliary1(feature_list[0]).view(x.size(0), -1)
    174. out2_feature = self.auxiliary2(feature_list[1]).view(x.size(0), -1)
    175. out3_feature = self.auxiliary3(feature_list[2]).view(x.size(0), -1)
    176. out4_feature = self.auxiliary4(feature_list[3]).view(x.size(0), -1)
    177. out = self.fc(out4_feature)
    178. feat_list = [out4_feature, out3_feature, out2_feature, out1_feature]
    179. for index in range(len(feat_list)):
    180. feat_list[index] = F.normalize(feat_list[index], dim=1)
    181. if self.training:
    182. return out, feat_list
    183. else:
    184. return out
    185. def resnet18(pretrained=False, **kwargs):
    186. """Constructs a ResNet-18 model.
    187. Args:
    188. pretrained (bool): If True, returns a model pre-trained on ImageNet
    189. """
    190. model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    191. if pretrained:
    192. model.load_state_dict(
    193. model_zoo.load_url(model_urls["resnet50"])
    194. )
    195. return model

    就是在 resnet 的4个 layer 后添加了 auxiliary head,而 auxiliary head 又由深度可分离卷积与平均池化层构成,用于进一步提取特征 (因为作者认为 resnet 提取的特征的表达能力还不够强,需要进一步提取)

    对比学习的损失函数代码如下:

    1. import torch
    2. import torch.nn as nn
    3. import torch.nn.functional as F
    4. class SupConLoss(nn.Module):
    5. """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    6. It also supports the unsupervised contrastive loss in SimCLR"""
    7. def __init__(self, temperature=0.07, contrast_mode='all',
    8. base_temperature=0.07):
    9. super(SupConLoss, self).__init__()
    10. self.temperature = temperature
    11. self.contrast_mode = contrast_mode
    12. self.base_temperature = base_temperature
    13. def forward(self, features, labels=None, mask=None):
    14. """Compute loss for model. If both `labels` and `mask` are None,
    15. it degenerates to SimCLR unsupervised loss:
    16. https://arxiv.org/pdf/2002.05709.pdf
    17. Args:
    18. features: hidden vector of shape [bsz, n_views, ...].
    19. labels: ground truth of shape [bsz].
    20. mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
    21. has the same class as sample i. Can be asymmetric.
    22. Returns:
    23. A loss scalar.
    24. """
    25. device = (torch.device('cuda')
    26. if features.is_cuda
    27. else torch.device('cpu'))
    28. if len(features.shape) < 3:
    29. raise ValueError('`features` needs to be [bsz, n_views, ...],'
    30. 'at least 3 dimensions are required')
    31. if len(features.shape) > 3:
    32. features = features.view(features.shape[0], features.shape[1], -1)
    33. batch_size = features.shape[0]
    34. if labels is not None and mask is not None:
    35. raise ValueError('Cannot define both `labels` and `mask`')
    36. elif labels is None and mask is None:
    37. mask = torch.eye(batch_size, dtype=torch.float32).to(device)
    38. elif labels is not None:
    39. labels = labels.contiguous().view(-1, 1)
    40. if labels.shape[0] != batch_size:
    41. raise ValueError('Num of labels does not match num of features')
    42. mask = torch.eq(labels, labels.T).float().to(device)
    43. else:
    44. mask = mask.float().to(device)
    45. contrast_count = features.shape[1] # 2
    46. contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    47. # 256 x 512
    48. if self.contrast_mode == 'one':
    49. anchor_feature = features[:, 0]
    50. anchor_count = 1
    51. elif self.contrast_mode == 'all':
    52. anchor_feature = contrast_feature # 256 x 512
    53. anchor_count = contrast_count # 2
    54. else:
    55. raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
    56. # compute logits
    57. anchor_dot_contrast = torch.div(
    58. torch.matmul(anchor_feature, contrast_feature.T),
    59. self.temperature)
    60. # for numerical stability
    61. # print (anchor_dot_contrast.size()) 256 x 256
    62. logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    63. logits = anchor_dot_contrast - logits_max.detach()
    64. # tile mask
    65. mask = mask.repeat(anchor_count, contrast_count)
    66. # mask-out self-contrast cases
    67. logits_mask = torch.scatter(
    68. torch.ones_like(mask),
    69. 1,
    70. torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    71. 0
    72. )
    73. mask = mask * logits_mask
    74. # compute log_prob
    75. exp_logits = torch.exp(logits) * logits_mask
    76. log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    77. # compute mean of log-likelihood over positive
    78. mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    79. loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
    80. loss = loss.view(anchor_count, batch_size).mean()
    81. return loss

    在 CIFAR100 上,我使用 resnet18 复现的结果为 80.54%,与论文中的 80.84% 差别不大

  • 相关阅读:
    QT控件无法获取焦点问题
    Stable diffusion的一些参数意义及常规设置
    TMS320F28335使用多个串口时,SCIRXST Register出现错误
    做推特群发群推“三不要怕”
    《架构风清扬-Java面试系列第25讲》聊聊ArrayBlockingQueue的特点及使用场景
    序列化和反序列化指令在PLC通信上的应用
    《Ai企业知识库》-模型实践-rasa开源学习框架-搭建简易机器人-环境准备(针对windows)-02
    数据结构试题 20-21
    计算机网络1
    被315点名的流氓下载器,又回来了…
  • 原文地址:https://blog.csdn.net/qq_38964360/article/details/127077976