• CVPR2022 | 无需对齐就能胜任大运动超分的内存增强非局部注意方法


    在这里插入图片描述
    作者单位:京东探索研究院
    论文链接:https://arxiv.org/abs/2108.11048
    代码链接:https://github.com/jiy173/MANA
    笔者言: 如何对齐是VSR中具有挑战的任务,光流方法和可变性卷积在中等运动的视频中具有显著优势,然而在处理大运动视频时会失效。本文通过非局部注意方式跳过对齐来融合相邻帧,在大运动视频上实现了SOTA。

    看点

    本文提出了一种内存增强非本地注意网络(MANA)。以前的方法主要利用相邻帧来辅助当前帧的超分。这些方法在空间帧对齐方面存在挑战,并且缺乏来自相邻帧的有用信息。相比之下,本文设计了一种跨帧非局部注意机制,允许视频在没有帧对齐的情况下实现超分,从而对视频中的大运动更加健壮。此外,为了获取相邻帧以外的一般先验信息,并补偿大运动造成的信息损失,本文设计了一种新的记忆增强注意模块,在训练中记忆一般视频细节。本文收集了Parkour数据集以验证MANA方法在大运动视频中的优越性。
    在这里插入图片描述

    动机

    VSR主要有两个挑战:

    • 第一个来自于视频的动态特性。由于帧间运动,融合前需要对相邻帧进行对齐,之前方法采用光流进行显式扭曲或可变形卷积学习隐式对齐。然而,它们高度依赖于相邻帧空间对齐的精度,很难胜任大运动视频。
    • 第二个来自于LR中高频细节的不可逆丢失和有用信息的缺乏。大多数VSR方法试图融合来自相邻帧的信息进行重建。然而,获取的信息仍然有限,特别是大运动视频。在这种情况下,由于相邻帧的相似性降低,相关性变小,从而从本质上退化为SISR。

    为了解决上述挑战,本文提出了一种内存增强非本地注意网络(MANA),它包含两个主要的模块。

    • 跨帧非局部注意模块用来解决帧对齐问题。该模块允许在不与当前帧对齐的情况下融合相邻帧。传统的非局部注意计算q和k中每个像素之间的成对相关。然而,对所有空间位置的像素一视同仁是不合适的。由于连续性的性质,查询附近的像素会有更好的对应。为此,本文使用一个以q为中心的可训练高斯映射来对相关性进行加权,高斯加权跨帧非局部注意绕开了帧对齐操作。
    • 为了解决相邻帧信息缺乏的问题,本文寻求融合现有视频之外的有用先验信息。这意味着网络在对训练集中的其他视频进行超分时要记住以前的经验。基于这一原理,本文在网络中引入了记忆增强注意模块。这个模块中拥有一个二维记忆库,它是在训练过程中学习到的。目的是总结整个训练集中具有代表性的局部细节,并将其作为当前超分的外部参考。

    方法

    Overview

    下图展示了MANA的网络结构:
    在这里插入图片描述
    网络的第一阶段对每个输入帧应用相同的编码网络将所有视频帧嵌入到相同的特征空间中,第二阶段包括跨帧非局部注意和记忆增强注意两部分。跨帧非局部注意旨在从相邻帧特征中挖掘有用信息 X t X_t Xt,记忆增强注意利用当前帧特性直接查询存储库,输出为 Y t Y_t Yt。将 X t X_t Xt Y t Y_t Yt通过两个不同的卷积层进行卷积,作为残差加到输入帧特征 F t F_t Ft中。解码器解码注意模块的输出,上采样模块对像素进行洗牌以生成高分辨率残差。残差为双线性上采样的LR帧增加了细节,从而得到清晰的高分辨率帧。

    跨帧非局部注意

    跨帧非局部注意的结构如下图所示:
    在这里插入图片描述
    首先使用GN对输入特征进行归一化得到k和v F ‾ t − τ , . . . , F ‾ t + τ \overline F_{t-τ},...,\overline F_{t+τ} Ftτ,...,Ft+τ,中心特征 F t F_t Ft作为q。在传统的非局部注意设置中,相关矩阵 Γ = Q T K Γ = Q^TK Γ=QTK Γ Γ Γ的大小为HW × HWT,该矩阵对GPU内存的负担较大。为了提高网络的内存效率,对每个邻居帧分别进行非局部注意,即Γ的大小为HW × HW。为了减轻错误匹配像素的影响,在相关矩阵 Γ Γ Γ的第二个维度的每个切片上乘以一个高斯权重映射 G ∈ R H W G\in \mathbb R^{HW} GRHW,高斯映射的中心位于q的位置,输出为: X t = ( G ⨂ Γ ) V X_t = (G⨂Γ)V Xt=(GΓ)V其中⨂为Hadamard积。

    记忆增强注意力

    记忆增强注意力结构如下图:
    在这里插入图片描述

    该模块维护一个全局存储库 M ∈ R C ′ × N M\in \mathbb R^{C '×N} MRC×N。使用常规的非局部注意在全局存储库 M M M中查询当前帧特征 Q T Q^T QT,即相关矩阵为 Γ M = Q T M Γ_M = Q^TM ΓM=QTM。最终得到输出: Y ^ t = s o f t m a x ( Γ M ) M T \hat Y_t =softmax(Γ_M)M^T Y^t=softmax(ΓM)MT Y t Y_t Yt Y ^ t \hat Y_t Y^treshape而来。Nonlocal_attention的代码如下:

    class nonlocal_attention(nn.Module):
        def __init__(self, config, is_training=True):
            super(nonlocal_attention, self).__init__()
            self.in_channels = config['in_channels']
            self.inter_channels = self.in_channels // 2
            self.is_training=is_training
            width=config['width']
            height=config['height']
                    
            self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.W_z = nn.Conv2d(in_channels=self.inter_channels*7, out_channels=self.in_channels, kernel_size=1)
            nn.init.constant_(self.W_z.weight, 0)
            nn.init.constant_(self.W_z.bias, 0)
            self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
    
            self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
            
            x1=np.linspace(0,width-1,width)
            y1=np.linspace(0,height-1,height)
            x2=np.linspace(0,width-1,width)
            y2=np.linspace(0,height-1,height)
            X1,Y1,Y2,X2=np.meshgrid(x1,y1,y2,x2)
            D=(X1-X2)**2+(Y1-Y2)**2
            D=torch.from_numpy(D)
            D=rearrange(D, 'a b c d -> (a b) (c d)')
            if self.is_training:
                D=D.float()
            else:
                D=D.half()
            self.D=torch.nn.Parameter(D,requires_grad=False)
            self.std=torch.nn.Parameter(4*torch.ones(1).float())
            if self.is_training==False:
                self.std=self.std.half()
            self.W_z1 = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
            nn.init.constant_(self.W_z1.weight, 0)
            nn.init.constant_(self.W_z1.bias, 0)
            self.mb = torch.nn.Parameter(torch.randn(self.inter_channels, 256))
        
        def forward(self, x):
            b, t, c, h, w = x.size()
            q = x[:, 3, :, :, :]
            
            weight=torch.exp(-0.5*(self.D/(self.std*self.std)))
            weight=weight.unsqueeze(0).repeat(b,1,1)
    
            reshaped_x = x.view(b*t , c, h, w).contiguous()
            h_ = self.norm(reshaped_x)
            q_=self.norm(q)
    
            g_x = self.g(h_).view(b, t, self.inter_channels, h,w).contiguous()
            theta_x = self.theta(h_).view(b, t, self.inter_channels,  h,w).contiguous()
            phi_x = self.phi(q_).view(b,self.inter_channels, -1)
            phi_x_for_quant=phi_x.permute(0,2,1)
            phi_x= phi_x.permute(0,2,1).contiguous()
    
            corr_l = []
            for i in range(t):
                theta = theta_x[:, i, :, :, :]
                g = g_x[:, i, :, :, :]
    
                g = g.view(b, self.inter_channels, -1).permute(0,2,1).contiguous()
                theta = theta.view(b, self.inter_channels, -1).contiguous()
                
                if self.is_training:
                    f = torch.matmul(phi_x, theta)
                else:
                    f = torch.matmul(phi_x.half(), theta.half())
                
                f_div_C = F.softmax(f, dim=-1)*weight
                if self.is_training:
                    y = torch.matmul(f_div_C, g).float()
                else:
                    y = torch.matmul(f_div_C, g.half()).float()
                y=y.permute(0,2,1).view(b,self.inter_channels,h,w)
                corr_l.append(y)
                
    
            corr_prob = torch.cat(corr_l, dim=1).view(b, -1, h, w)
            W_y = self.W_z(corr_prob)
            
            mbg = self.mb.unsqueeze(0).repeat(b, 1, 1)
            f1 = torch.matmul(phi_x_for_quant, mbg)
            f_div_C1 = F.softmax(f1 * (int(self.inter_channels) ** (-0.5)), dim=-1)
            y1 = torch.matmul(f_div_C1, mbg.permute(0, 2, 1))
            qloss=torch.mean(torch.abs(phi_x_for_quant-y1))
            y1 = y1.permute(0, 2, 1).view(b, self.inter_channels, h, w).contiguous()
            W_y1 = self.W_z1(y1)
            
            z = W_y + q+W_y1
    
            return z, qloss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92

    实验

    消融实验

    在这里插入图片描述
    No Mem为只有跨帧非局部注意模块的网络进行了实验。在这些配置中,N = 512的结果最好。使用较小的内存(N = 128和N = 256)会导致性能略有下降。当使用更大的内存(N = 1024)时,优势饱和,这意味着低分辨率帧的局部细节可以在低维空间中很好地表示。

    定量实验

    在这里插入图片描述
    上表显示了MANA在PSNR、SSIM和LPIPS评分方面的定量比较(Related Work中有BasicVSRQuantitative Comparisons却没有。。。)。Parkour数据集中的视频有非常大的运动,使帧的精确对齐困难。MANA方法不需要帧对齐,在很大程度上超过了所有VSR方法。这一观察结果表明,MANA能够处理视频中的大动作。有趣的是,帧对齐VSR方法的性能甚至不如SISR方法CSNLN。这是因为融合不对齐的帧经常会在结果中造成伪影。Vimeo90K-Motion由运动相对较大的常规视频组成。本文计算了Vimeo90K测试集中视频的光流量,并根据平均流量大小进行排序。选择前6%的视频组成Vimeo90K-Motion。结果进一步证实了我们的MANA在一些动作的视频中效果更好。

    定量评估

    在这里插入图片描述

  • 相关阅读:
    通过TPT的FUSION平台实现联合测试
    ROS参数服务器(Param):通信模型、Hello World与拓展
    【Android -- 开源库】fastjson 基本使用
    算法刷题记录 Day52
    flutter报错: library “libflutter.so“ not found
    vue+cesium 点击矢量图层获取geoserver信息
    R 语言入门学习笔记:软件安装踩坑记录——删除所有包以及彻底解决库包被安装到 C 盘用户目录下的问题,以及一些其他需要注意的点
    一文彻底搞懂 JS 闭包
    YaRN: Efficient Context Window Extension of Large Language Models
    大灰狼远控木马分析
  • 原文地址:https://blog.csdn.net/Srhyme/article/details/126458799