整体框架如图1,encoder包含3个stage,输入图像分辨率假设为(H,W,3),3个阶段输出空间分辨率分别为
(
H
/
4
,
W
/
4
)
,
(
H
/
8.
W
/
8
)
,
(
H
/
16
,
W
/
16
)
(H/4,W/4),(H/8.W/8),(H/16,W/16)
(H/4,W/4),(H/8.W/8),(H/16,W/16),前两个stage使用卷积层将输入转换为token Embeedings,分别为
E
1
∈
R
H
4
∗
W
4
∗
C
1
E_1 \in\mathbb{R}^{\frac{H}{4}*\frac{W}{4}*C_1}
E1∈R4H∗4W∗C1,
E
2
∈
R
H
8
∗
W
8
∗
C
2
E_2 \in\mathbb{R}^{\frac{H}{8}*\frac{W}{8}*C_2}
E2∈R8H∗8W∗C2.卷积block的设计遵循transformer block的原则,仅仅将self-attention操作用
5
×
5
5\times5
5×5的卷积替代。第3个stage使用通用的self-attention blocks获取token Embeedings
E
3
∈
R
H
16
×
W
16
×
C
3
E_3\in \mathbb{R}^{\frac{H}{16}\times\frac{W}{16}\times C_3}
E3∈R16H×16W×C3。在每个stage之间,stride为2的卷积被用来下采样tokens到之前分辨率的一半。
给定输入图像
I
∈
R
3
×
H
×
W
I \in \mathbb{R}^{3×H×W}
I∈R3×H×W, ConvMAE编码器的第1阶段首先使用非重叠4 × 4卷积生成一个高分辨率的token embeeding,
E
1
∈
R
C
1
×
H
4
×
W
4
E_1∈R^{C_1 × \frac{H}{4} × \frac{W}{4}}
E1∈RC1×4H×4W。然后将
E
1
E_1
E1送入堆叠的卷积块中,重复
L
1
L_1
L1次,
L
1
L_1
L1表示第1阶段的层数。
与stage 1相似,stage 2使用非重叠2×2卷积进一步下采样特征映射到token embeedings :
E
2
∈
R
C
2
×
H
8
×
W
8
E_2 \in \mathbb{R}^{C_2× \frac{H}{8} × \frac{W}{8}}
E2∈RC2×8H×8W。
E
2
E_2
E2被
L
2
L_2
L2层卷积块再次处理。在第1阶段和第2阶段进行局部信息融合后,第3阶段利用transformer block进行全局特征融合。利用非重叠2 × 2卷积将
E
2
E_2
E2投影到token embeedings
E
3
∈
R
(
H
16
×
W
16
)
×
C
3
E_3\in \mathbb{R}^{(\frac{H}{16} × \frac{W}{16})×C_3}
E3∈R(16H×16W)×C3。将E3与中间位置嵌入(IPL)混合送入具有
L
3
L_3
L3层的纯transformer block。我们将阶段3中注意头的数量表示为
H
a
H_a
Ha。不同时期FFN的mlp-比值分别记为
P
1
、
P
2
、
P
3
P_1、P_2、P_3
P1、P2、P3。
MAE decoder的输入部分是来源于Encoder的的可视化token
E
d
E_d
Ed以及mask tokens,在transformer blocks中进行组合用于图像重建任务。
ConvMAE encoder获取多尺度特征
E
1
,
E
2
,
E
3
E_1,E_2,E_3
E1,E2,E3,捕获了细粒度和粗粒度图像信息。为了更好的监督这种多粒度表示的预训练,通过stride-4 和stride-2 卷积将
E
1
,
E
2
E_1,E_2
E1,E2下采样到与
E
3
E_3
E3相同的尺寸大小,通过一个linear 层融合多粒度tokens以获取用于输入到decoder的可视化token。
E
d
=
L
i
n
e
a
r
(
S
t
r
i
d
e
C
o
n
v
(
E
1
,
4
)
+
S
t
r
i
d
e
C
o
n
v
(
E
2
,
2
)
+
E
3
)
E_d=Linear(StrideConv(E_1,4)+StrideConv(E_2,2)+E_3)
Ed=Linear(StrideConv(E1,4)+StrideConv(E2,2)+E3) StrideConv(.,K)表示stride-K的卷积。
损失函数与MAE一样考虑重建任务损失,目标函数只计算masked patches。
T
M
T_M
TM是一系列masked token 而且t是token索引,重建目标
I
I
I为输入图像的归一化像素值,
I
^
\hat{I}
I^是重建图像。