Yamini Bansal, Preetum Nakkiran, and Boaz Barak. Revisiting model stitching to compare neural representations. Advances in Neural Information Processing Systems, 35, 2021.
此外还有CCA CKA SVCCA PWCCA等工作,衡量相似性
正文
如下图所示,其实就是做了这样一件事:把一个网络的某一层特征,接一个stitching layer(文章定义为1x1的卷积),送到另一个网络中去。固定两个网络的参数,只训练stitching layer,看看接起来的网络的准确率如何。而对stitching layer的训练文章介绍了两种方式:task loss matching 和 direct matching
其实就是直接用 model 2 的特征图
B
∈
R
n
×
p
B\in R^{n\times p}
B∈Rn×p 作为target,用 model 1 的特征图
A
∈
R
n
×
p
A \in R^{n\times p}
A∈Rn×p 作为input,用最小二乘法找到最优的
M
0
∈
C
M_0 \in C
M0∈C,使得
∥
A
M
0
−
B
∥
F
=
m
i
n
M
∈
C
∥
A
M
−
B
∥
F
\|AM_0-B\|_F = min_{M\in C}\|AM-B\|_F
∥AM0−B∥F=minM∈C∥AM−B∥F 其中 p 为特征图的维度
p
=
c
p=c
p=c,所以这个其实就是个1x1的卷积,对所有所有图片所有位置的像素都是相同处理,n 为样本数,也可以理解为图像数量乘以hw。
此外其实还有一种理论是用
∥
A
M
0
−
B
∥
F
=
m
i
n
M
∈
C
{
∥
A
M
−
B
∥
F
+
α
⋅
∥
M
∥
1
}
\|AM_0-B\|_F = min_{M\in C}\{\|AM-B\|_F+\alpha \cdot \|M\|_1 \}
∥AM0−B∥F=minM∈C{∥AM−B∥F+α⋅∥M∥1}作为目标函数,目的是求得一个稀疏的M矩阵,这里的stitching layer的L1范数可以对 direct matching 也可以对 task loss 两种求最优化的方式去添加
前面说了,stitching layer可以是不满秩的,不满秩的网络可以通过 bottleneck 结构(即通道数先减小再增加,其实就是把
p
×
p
p\times p
p×p 的矩阵分解为
p
×
k
p \times k
p×k 的矩阵和
k
×
p
k \times p
k×p 的矩阵的乘积,从而使得乘积的秩最多为k,可以通过SVD分解来实现)的模块来实现。