• Momentum Contrast for Unsupervised Visual Representation Learning 论文学习


    1. 解决了什么问题?

    非监督学习在自然语言处理非常成功,如 GPT 和 BERT。但在计算机视觉任务上,监督预训练方法要领先于非监督的方法。这种差异可能是因为各自的信号空间不同,语言任务有着离散的信号空间(单词、短语等)来构建非监督学习所需的字典。而计算机视觉则很难构建一个字典,因为原始信号位于连续的高维空间,不像单词一样是结构化的。

    最近的非监督表征学习方法使用对比损失取得了不错的效果,它们基本是构建了一个动态字典。从数据中采样,产生字典的 keys/tokens,由编码器网络表征。非监督学习训练编码器来进行字典查询:query 应该与匹配到的 key 距离近,而与其它 keys 距离远。通过最小化对比损失来进行训练。

    本文假设字典的构建应该满足两个条件,一是足够大,二是在训练过程中不断地更新,是连续的。字典足够大,能更好地从连续的高维空间中采样,keys 由近似的编码器表征,这样 key 和 query 的比较才是连续的。而目前的方法只能满足上述两个条件中的一个。

    非监督/自监督学习一般包括两个方面:pretext 任务和损失函数。Pretext 的意思是待解决的任务不是我们真正关心的,我们真正的目的是学习好的数据表征。损失函数可以独立于 pretext 任务来研究,MoCo 聚焦在损失函数上。

    2. 提出了什么方法?

    针对非监督视觉表征学习任务,提出了 Momentum Contrast,构建了一个包含队列和滑动平均编码器的动态字典。MoCo 针对非监督学习,利用对比损失构建了一个足够大且连续的字典,该字典用一个队列维护:当前 mini-batch 的表征加入队列,最早的 mini-batch 表征从队列中剔除。字典的 keys 来自于之前的多个 mini-batches,通过一个基于动量的滑动平均编码器实现该缓慢演进的 key 编码器,保证连续性。

    2.1 Contrastive Learning as Dictionary Look-up

    对比学习就是针对字典查询任务,训练一个编码器。给定一个编码后的 query q q q和字典的一组编码样本 { k 0 , k 1 , k 2 , . . . } \lbrace k_0,k_1,k_2,...\rbrace {k0,k1,k2,...}。假设字典中有一个 q q q匹配到的 key,记做 k + k_+ k+。对比损失中,当 q q q k + k_+ k+相似而与其它 keys 不相似时,损失值就小。相似度用点积表示,是对比损失函数的一种形式,叫做 InfoNCE:

    L q = − log ⁡ exp ⁡ ( q ⋅ k + / τ ) ∑ i = 0 K exp ⁡ ( q ⋅ k i / τ ) \mathcal{L}_q = -\log \frac{\exp\left(q\cdot k_+/\tau\right)}{\sum_{i=0}^K \exp \left(q\cdot k_i /\tau\right)} Lq=logi=0Kexp(qki/τ)exp(qk+/τ)

    τ \tau τ是一个调节超参数。除数是对一个正样本和 K K K个负样本求和。该损失是基于 Softmax 的 ( K + 1 ) − way (K+1)-\text{way} (K+1)way分类器,将 q q q分类为 k + k_+ k+。对比损失函数也可基于其它形式,比如 margin-based 损失和 NCE 损失变体。

    对比损失是训练编码器的非监督目标函数,该编码器表征 query 和 key。通常,query 表征为 q = f q ( x q ) q=f_q(x^q) q=fq(xq) f q f_q fq是编码器网络, x q x^q xq是 query 样本。同样, k = f k ( x k ) k=f_k(x^k) k=fk(xk)。输入 x q x^q xq x k x^k xk可以是图像、图块或图块构成的 context。网络 f q f_q fq f k f_k fk可以是一样的,也可以部分共享的,也可以是完全不同的。

    2.2 Momentum Contrast

    对比学习从高维连续输入(如图像)中构建离散字典。该字典是动态的,keys 通过随机采样得到,在训练过程中 key 编码器不断地更新。本文假设,如果一个字典足够大,涵盖了丰富的负样本,就能用该字典学习好的特征。而且编码器在更新过程中是连续的。

    Dictionary as a queue

    本方法的核心就是,字典用一个样本队列来维护。这样我们就可复用不久前 mini-batches 的 keys。该字典的大小可以远大于 mini-batch 的大小,作为一个超参灵活地设定。

    字典中的样本被逐步替换掉。当前 mini-batch 样本加入到字典中,最早的 mini-batch 则被剔除。字典只代表了数据集的一个子集,维持这个字典的计算量是可以控制的。剔除最早的 mini-batch,它所编码的 keys 过时了,与最新的 mini-batch 连续性最低。

    Momentum update

    队列表示能让字典很大,但无法通过反向传播来更新编码器(梯度应该回传给队列中所有的样本)。简单的办法就是复制 query 编码器 f q f_q fq到 key 编码器 f k f_k fk,不管梯度。但这个办法实验效果不行。作者认为,编码器的迅速变化,降低了 key 表征的连续性。于是提出了动量更新,解决这个问题。

    f k f_k fk的参数记做 θ k \theta_k θk f q f_q fq的参数为 θ q \theta_q θq,更新 θ k \theta_k θk

    θ k ← m θ k + ( 1 − m ) θ q \theta_k \leftarrow m\theta_k + (1-m)\theta_q θkmθk+(1m)θq

    m ∈ [ 0 , 1 ) m\in [0,1) m[0,1)是动量系数。反向传播只用更新 θ q \theta_q θq。动量更新使 θ k \theta_k θk的更新更加平滑。这样,尽管队列中的 keys 是用不同的编码器(不同的 mini-batches)编码的,这些编码器的差异很小。在实验中,大动量系数(比如 m = 0.9 m=0.9 m=0.9)的表现要好于小的系数,表明缓慢更新的 key 编码器是使用队列的关键。

    Relations to previous mechanisms

    MoCo 对于对比损失是通用的。在下图中,MoCo 和现有的两种对比损失机制进行了比较,在字典大小和连续性方面有着不同的特性。它们差异体现在 keys 是如何维护的,以及 key 编码器是如何更新的。

    1. 计算 query 和 key 的编码器通过端到端的反向传播更新,两个编码器可以不一样。它使用当前 mini-batch
      的样本作为字典,keys 是连续编码的,因为编码器参数是一样的。但是字典大小与 mini-batch 大小是耦合的,受 GPU
      显存大小限制。

    2. 从 memory bank 中采样得到 keys 表征。Memory bank 包括了数据集所有样本的表征。每个
      mini-batch 的字典都是从 memory bank 中随机采样得到,无需反向传播,因此字典规模可以很大。但是 memory
      bank 的样本表征是看到了才会更新,因此采样的 keys 是 epoch 中不同步骤的编码器生成的,彼此缺乏连续性。

    3. MoCo 使用动量更新编码器来编码新的 keys,维护了一个 keys 队列。Moco
      并不记录每个样本,因此对内存更加有效,可以在数以亿计的数据上训练。

    在这里插入图片描述

    2.3 Pretext Task

    如果一对 query 和 key 来自于同一图像,则是正样本对,否则为负样本对。使用数据增强得到同一图像的两个随机视角,产生正样本对。用各自的编码器 f q f_q fq f k f_k fk编码得到 query 和 key。算法1 是该 pretext 任务的 MoCo 伪码。对于当前 mini-batch,编码 queries 和相应的 keys,得到正样本对。负样本对则来自于队列。

    在这里插入图片描述

    技术细节

    编码器采用 ResNet,全局平均池化层后的最后一个全连接层有固定维度( 128 128 128维)的输出。用 L 2 − norm L2-\text{norm} L2norm对输出向量归一化。这就是 query 或 key 的表征。 τ \tau τ设为 0.07 0.07 0.07。数据增强方法如下:对随机缩放的图像裁剪一块 224 × 224 224\times 224 224×224大小的区域,然后使用随机色彩变动、随机水平翻转和随机灰度转换。

    Shuffling BN

    f q f_q fq f k f_k fk都在 ResNet 中使用了 BN。BN 会阻碍模型学习高质量表征。模型似乎在作弊,欺骗 pretext 任务,很容易就找到了低损失值的方案。这可能是因为 BN 在 batch 内部交流信息造成了信息泄露。

    于是作者使用了 shuffling BN。使用了多个 GPU 训练,对于每个 GPU 的样本独立完成 BN 操作。对于 key 编码器 f k f_k fk,shuffle 当前 mini-batch 样本的顺序,然后再将其分配到各个 GPU,编码后再 shuffle 回来。Query 编码器 f q f_q fq的 mini-batch 样本顺序不变。这保证了计算 query 和它的正样本 key 所需的 batch 统计信息来自于两个不同的子集。这就有效解决了作弊问题。

    在上图(a)和©中,作者使用了 shuffling BN,(b) 中没有用,因为 memory bank 提供的正样本 keys 来自于之前产生的、不同的 mini-batches。

  • 相关阅读:
    Vue.js 中的异步组件是什么?
    (2.2w字)前端单元测试之Jest详解篇
    element ui富文本编辑器的使用(quill-editor)
    Flutter实战-请求封装(六)之设置抓包Proxy
    dialogx,给大家推荐一个开源安卓弹窗组件。
    通过监控Nginx日志来实时屏蔽高频恶意访问的IP
    LeetCode 212. 单词搜索 II -- 字典树+dfs
    linux-网站服务
    浏览器网页截屏妙用Capture node screenshot
    2023开学礼山东财经大学《乡村振兴战略下传统村落文化旅游设计》许少辉新财经图书馆
  • 原文地址:https://blog.csdn.net/calvinpaean/article/details/133269762