对比损失(Contrastive Loss)中的参数 τ \tau τ是一个神秘的参数,大部分论文都默认采用较小的值来进行自监督对比学习(例如0.05),但是很少有文章详细讲解参数 τ \tau τ的作用,本文将详解对比损失中的超参数 τ \tau τ,并借此分析对比学习的核心机制。首先总结下本文的发现:
首先给出自监督学习广泛使用的对比损失(InfoNCE Loss)的形式:
L ( x i ) = − log [ exp ( s i , i / τ ) ∑ k ≠ i exp ( s i , k / τ ) + exp ( s i , i / τ ) ] (1) \mathcal{L}({x}_i) = -\log \left[\frac{\exp (s_{i,i}/\tau)}{\sum_{k\neq i} \exp(s_{i,k}/\tau) + \exp (s_{i, i}/\tau)}\right]\tag{1} L(xi)=−log[∑k=iexp(si,k/τ)+exp(si,i/τ)exp(si,i/τ)](1)
直观来说,该损失函数要求第 i i i个样本和它另一个扩增的(正)样本之间的相似度 s i , i s_{i,i} si,i之间尽可能大,而与其它实例(负样本)之间的相似度 s i , k s_{i,k} si,k之间尽可能小。然而,其实还有很多损失函数可以满足这个要求,例如下面最简单的形式 L simple \mathcal{L}_{\text{simple}} Lsimple:
L simple ( x i ) = − s i , i + λ ∑ i ≠ j s i , j (2) \mathcal{L}_{\text{simple}}({x}_i) = -s_{i,i} + \lambda \sum_{i\neq j}s_{i,j}\tag{2} Lsimple(xi)=−si,i+λi=j∑si,j(2)
然而实际训练时,采用 L sample \mathcal{L}_{\text{sample}} Lsample作为损失函数效果非常不好,论文给出了使用式(1)和式(2)的性能对比( τ = 0.07 \tau=0.07 τ=0.07):
数据集
Contrastive Loss
Simple Loss
CIFAR-10
79.75
74.83
CIFAR-100
51.82
39.31
ImageNet-100
71.53
48.09
SVHN
92.55
70.83
上面的结果显示,在所有数据集上Contrastive Loss都要远好于Simple Loss。作者通过研究发现,不同于Simple Loss,Contrastive Loss是一个困难样本自发现的损失函数。通过公式(2)可以发现,Simple Loss对所有的负样本给予了相同权重的惩罚 ∂ L simple ∂ s i , k = λ \frac{\partial \mathcal{L}_{\text{simple}}}{\partial s_{i,k}}=\lambda ∂si,k∂Lsimple=λ。而Contrastive Loss则会自动给相似度更高的负样本比较高的惩罚,这一点可以通过对Contrastive Loss中不同负样本的相似度惩罚梯度观察得到:
对正样本的梯度:
∂
L
(
x
i
)
∂
s
i
,
i
=
−
1
τ
∑
k
≠
i
P
i
,
k
对负样本的梯度:
∂
L
(
x
i
)
∂
s
i
,
j
=
1
τ
P
i
,
j
\text{对正样本的梯度:}\frac{\partial \mathcal{L}(x_i)}{\partial s_{i,i}}=-\frac{1}{\tau}\sum_{k\neq i} P_{i,k}\\ \text{对负样本的梯度:}\frac{\partial \mathcal{L}(x_i)}{\partial s_{i,j}}=\frac{1}{\tau}P_{i,j}
对正样本的梯度:∂si,i∂L(xi)=−τ1k=i∑Pi,k对负样本的梯度:∂si,j∂L(xi)=τ1Pi,j
其中
P
i
,
j
=
exp
(
s
i
,
j
/
τ
)
∑
k
≠
i
exp
(
s
i
,
k
/
τ
)
+
exp
(
s
i
,
i
/
τ
)
P_{i,j}=\frac{\exp(s_{i,j/}\tau)}{\sum_{k\neq i} \exp(s_{i,k}/\tau) + \exp({s_{i,i}/\tau})}
Pi,j=∑k=iexp(si,k/τ)+exp(si,i/τ)exp(si,j/τ)。对于所有的负样本来说,
P
i
,
j
P_{i,j}
Pi,j的分母项都是相同的,那么
s
i
,
j
s_{i,j}
si,j越大,则负样本的梯度项
∂
L
(
x
i
)
∂
s
i
,
j
=
1
τ
P
i
,
j
\frac{\partial \mathcal{L}(x_i)}{\partial s_{i,j}}=\frac{1}{\tau}P_{i,j}
∂si,j∂L(xi)=τ1Pi,j也越大。也就是说,Contrastive Loss给予了更相似(困难)负样本更大的远离该样本的梯度。可以把不同的负样本想象成同极电荷在不同距离处的受力情况,距离越近的电荷受到的斥力越大,而越远的电荷受到的斥力越小。Contrastive Loss也是这样,这种性质有利于形成在超球面均匀分布的特性
为了更具体的解释超参数 τ \tau τ的作用,作者计算了两种极端情况,即 τ \tau τ趋近于0和无穷大。当 τ \tau τ趋近于0时:
lim
τ
→
0
+
−
log
[
exp
(
s
i
,
i
/
τ
)
∑
k
≠
i
exp
(
s
i
,
k
/
τ
)
+
exp
(
s
i
,
i
/
τ
)
]
=
lim
τ
→
0
+
log
[
1
+
∑
k
≠
i
exp
(
(
s
i
,
k
−
s
i
,
i
)
/
τ
)
]
(3)
简单点,我们仅考虑那些困难的负样本,即如果存在负样本 x k x_k xk,有 Sim ( x i , x k ) ≥ Sim ( x i , x i + ) \text{Sim}(x_i, x_k)\ge \text{Sim}(x_i,x_i^+) Sim(xi,xk)≥Sim(xi,xi+),则称 x k x_k xk为困难的负样本。此时式(3)可以改写为:
lim τ → 0 + log [ 1 + ∑ s i , k ≥ s i , i k exp ( ( s i , k − s i , i ) / τ ) ] (4) \lim_{\tau \to 0^+}\log \left[1+\sum_{\color{red}{s_{i,k} \ge s_{i,i}}}^k \exp((s_{i,k} - s_{i,i})/\tau)\right]\tag{4} τ→0+limlog⎣⎡1+si,k≥si,i∑kexp((si,k−si,i)/τ)⎦⎤(4)
因为 τ → 0 + \tau \to 0^+ τ→0+,我们直接省略常数1,并且将求和直接改为最大的 s i , k s_{i,k} si,k这一项,记作 s max s_{\text{max}} smax,则式(4)可以改写为:
lim τ → 0 + 1 τ max [ s max − s i , i , 0 ] (5) \lim_{\tau \to 0^+} \frac{1}{\tau} \max[s_{\text{max}} - s_{i,i},0]\tag{5} τ→0+limτ1max[smax−si,i,0](5)
可以看出,此时Contrastive Loss退化为只关注最困难的负样本的损失函数。而当 τ \tau τ趋近于无穷大时:
lim
τ
→
+
∞
−
log
[
exp
(
s
i
,
i
/
τ
)
∑
k
≠
i
exp
(
s
i
,
k
/
τ
)
+
exp
(
s
i
,
i
/
τ
)
]
=
lim
τ
→
+
∞
−
1
τ
s
i
,
i
+
log
∑
k
exp
(
s
i
,
k
/
τ
)
=
lim
τ
→
+
∞
−
1
τ
s
i
,
i
+
log
[
N
(
1
+
(
1
N
∑
k
exp
(
s
i
,
k
/
τ
)
−
1
)
)
]
=
lim
τ
→
+
∞
−
1
τ
s
i
,
i
+
log
[
1
+
(
1
N
∑
k
exp
(
s
i
,
k
/
τ
)
−
1
)
]
+
log
N
=
lim
τ
→
+
∞
−
1
τ
s
i
,
i
+
(
1
N
∑
k
exp
(
s
i
,
k
/
τ
)
−
1
)
+
log
N
=
lim
τ
→
+
∞
−
1
τ
s
i
,
i
+
1
N
τ
∑
k
s
i
,
k
+
log
N
=
lim
τ
→
+
∞
1
−
N
N
τ
s
i
,
i
+
1
N
τ
∑
k
≠
i
s
i
,
k
+
log
N
(6)
上述等式推导利用了 ln ( 1 + x ) \ln(1+x) ln(1+x)和 e x e^x ex的泰勒展开,或者说等价无穷小
此时Contrastive Loss对所有负样本的权重都相同( 1 N τ \frac{1}{N\tau} Nτ1),即Contrastive Loss失去了对困难样本关注的特性。有趣的是,当 τ = N − 1 N \tau = \frac{N-1}{N} τ=NN−1时,对比损失 L ( x i ) \mathcal{L}(x_i) L(xi)与前面提到的 L simple \mathcal{L}_{\text{simple}} Lsimple几乎一样
论文作者通过上面两个极限情况分析了超参数 τ \tau τ的作用:随着 τ \tau τ的增大,Contrastive Loss倾向于“一视同仁”;随着 τ \tau τ的减小,Contrastive Loss变得倾向于关注最困难的样本