作者单位:京东探索研究院
论文链接:https://arxiv.org/abs/2108.11048
代码链接:https://github.com/jiy173/MANA
笔者言: 如何对齐是VSR中具有挑战的任务,光流方法和可变性卷积在中等运动的视频中具有显著优势,然而在处理大运动视频时会失效。本文通过非局部注意方式跳过对齐来融合相邻帧,在大运动视频上实现了SOTA。
本文提出了一种内存增强非本地注意网络(MANA)。以前的方法主要利用相邻帧来辅助当前帧的超分。这些方法在空间帧对齐方面存在挑战,并且缺乏来自相邻帧的有用信息。相比之下,本文设计了一种跨帧非局部注意机制,允许视频在没有帧对齐的情况下实现超分,从而对视频中的大运动更加健壮。此外,为了获取相邻帧以外的一般先验信息,并补偿大运动造成的信息损失,本文设计了一种新的记忆增强注意模块,在训练中记忆一般视频细节。本文收集了Parkour数据集以验证MANA方法在大运动视频中的优越性。
VSR主要有两个挑战:
为了解决上述挑战,本文提出了一种内存增强非本地注意网络(MANA),它包含两个主要的模块。
下图展示了MANA的网络结构:
网络的第一阶段对每个输入帧应用相同的编码网络将所有视频帧嵌入到相同的特征空间中,第二阶段包括跨帧非局部注意和记忆增强注意两部分。跨帧非局部注意旨在从相邻帧特征中挖掘有用信息
X
t
X_t
Xt,记忆增强注意利用当前帧特性直接查询存储库,输出为
Y
t
Y_t
Yt。将
X
t
X_t
Xt和
Y
t
Y_t
Yt通过两个不同的卷积层进行卷积,作为残差加到输入帧特征
F
t
F_t
Ft中。解码器解码注意模块的输出,上采样模块对像素进行洗牌以生成高分辨率残差。残差为双线性上采样的LR帧增加了细节,从而得到清晰的高分辨率帧。
跨帧非局部注意的结构如下图所示:
首先使用GN对输入特征进行归一化得到k和v
F
‾
t
−
τ
,
.
.
.
,
F
‾
t
+
τ
\overline F_{t-τ},...,\overline F_{t+τ}
Ft−τ,...,Ft+τ,中心特征
F
t
F_t
Ft作为q。在传统的非局部注意设置中,相关矩阵
Γ
=
Q
T
K
Γ = Q^TK
Γ=QTK,
Γ
Γ
Γ的大小为HW × HWT,该矩阵对GPU内存的负担较大。为了提高网络的内存效率,对每个邻居帧分别进行非局部注意,即Γ的大小为HW × HW。为了减轻错误匹配像素的影响,在相关矩阵
Γ
Γ
Γ的第二个维度的每个切片上乘以一个高斯权重映射
G
∈
R
H
W
G\in \mathbb R^{HW}
G∈RHW,高斯映射的中心位于q的位置,输出为:
X
t
=
(
G
⨂
Γ
)
V
X_t = (G⨂Γ)V
Xt=(G⨂Γ)V其中⨂为Hadamard积。
记忆增强注意力结构如下图:
该模块维护一个全局存储库 M ∈ R C ′ × N M\in \mathbb R^{C '×N} M∈RC′×N。使用常规的非局部注意在全局存储库 M M M中查询当前帧特征 Q T Q^T QT,即相关矩阵为 Γ M = Q T M Γ_M = Q^TM ΓM=QTM。最终得到输出: Y ^ t = s o f t m a x ( Γ M ) M T \hat Y_t =softmax(Γ_M)M^T Y^t=softmax(ΓM)MT Y t Y_t Yt由 Y ^ t \hat Y_t Y^treshape而来。Nonlocal_attention的代码如下:
class nonlocal_attention(nn.Module):
def __init__(self, config, is_training=True):
super(nonlocal_attention, self).__init__()
self.in_channels = config['in_channels']
self.inter_channels = self.in_channels // 2
self.is_training=is_training
width=config['width']
height=config['height']
self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
self.W_z = nn.Conv2d(in_channels=self.inter_channels*7, out_channels=self.in_channels, kernel_size=1)
nn.init.constant_(self.W_z.weight, 0)
nn.init.constant_(self.W_z.bias, 0)
self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
x1=np.linspace(0,width-1,width)
y1=np.linspace(0,height-1,height)
x2=np.linspace(0,width-1,width)
y2=np.linspace(0,height-1,height)
X1,Y1,Y2,X2=np.meshgrid(x1,y1,y2,x2)
D=(X1-X2)**2+(Y1-Y2)**2
D=torch.from_numpy(D)
D=rearrange(D, 'a b c d -> (a b) (c d)')
if self.is_training:
D=D.float()
else:
D=D.half()
self.D=torch.nn.Parameter(D,requires_grad=False)
self.std=torch.nn.Parameter(4*torch.ones(1).float())
if self.is_training==False:
self.std=self.std.half()
self.W_z1 = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
nn.init.constant_(self.W_z1.weight, 0)
nn.init.constant_(self.W_z1.bias, 0)
self.mb = torch.nn.Parameter(torch.randn(self.inter_channels, 256))
def forward(self, x):
b, t, c, h, w = x.size()
q = x[:, 3, :, :, :]
weight=torch.exp(-0.5*(self.D/(self.std*self.std)))
weight=weight.unsqueeze(0).repeat(b,1,1)
reshaped_x = x.view(b*t , c, h, w).contiguous()
h_ = self.norm(reshaped_x)
q_=self.norm(q)
g_x = self.g(h_).view(b, t, self.inter_channels, h,w).contiguous()
theta_x = self.theta(h_).view(b, t, self.inter_channels, h,w).contiguous()
phi_x = self.phi(q_).view(b,self.inter_channels, -1)
phi_x_for_quant=phi_x.permute(0,2,1)
phi_x= phi_x.permute(0,2,1).contiguous()
corr_l = []
for i in range(t):
theta = theta_x[:, i, :, :, :]
g = g_x[:, i, :, :, :]
g = g.view(b, self.inter_channels, -1).permute(0,2,1).contiguous()
theta = theta.view(b, self.inter_channels, -1).contiguous()
if self.is_training:
f = torch.matmul(phi_x, theta)
else:
f = torch.matmul(phi_x.half(), theta.half())
f_div_C = F.softmax(f, dim=-1)*weight
if self.is_training:
y = torch.matmul(f_div_C, g).float()
else:
y = torch.matmul(f_div_C, g.half()).float()
y=y.permute(0,2,1).view(b,self.inter_channels,h,w)
corr_l.append(y)
corr_prob = torch.cat(corr_l, dim=1).view(b, -1, h, w)
W_y = self.W_z(corr_prob)
mbg = self.mb.unsqueeze(0).repeat(b, 1, 1)
f1 = torch.matmul(phi_x_for_quant, mbg)
f_div_C1 = F.softmax(f1 * (int(self.inter_channels) ** (-0.5)), dim=-1)
y1 = torch.matmul(f_div_C1, mbg.permute(0, 2, 1))
qloss=torch.mean(torch.abs(phi_x_for_quant-y1))
y1 = y1.permute(0, 2, 1).view(b, self.inter_channels, h, w).contiguous()
W_y1 = self.W_z1(y1)
z = W_y + q+W_y1
return z, qloss
No Mem为只有跨帧非局部注意模块的网络进行了实验。在这些配置中,N = 512的结果最好。使用较小的内存(N = 128和N = 256)会导致性能略有下降。当使用更大的内存(N = 1024)时,优势饱和,这意味着低分辨率帧的局部细节可以在低维空间中很好地表示。
上表显示了MANA在PSNR、SSIM和LPIPS评分方面的定量比较(Related Work中有BasicVSRQuantitative Comparisons却没有。。。)。Parkour数据集中的视频有非常大的运动,使帧的精确对齐困难。MANA方法不需要帧对齐,在很大程度上超过了所有VSR方法。这一观察结果表明,MANA能够处理视频中的大动作。有趣的是,帧对齐VSR方法的性能甚至不如SISR方法CSNLN。这是因为融合不对齐的帧经常会在结果中造成伪影。Vimeo90K-Motion由运动相对较大的常规视频组成。本文计算了Vimeo90K测试集中视频的光流量,并根据平均流量大小进行排序。选择前6%的视频组成Vimeo90K-Motion。结果进一步证实了我们的MANA在一些动作的视频中效果更好。