• 【手搓模型】亲手实现 Vision Transformer


    🚩前言

    • 🐳博客主页:😚睡晚不猿序程😚
    • ⌚首发时间:2023.3.17,首发于博客园
    • ⏰最近更新时间:2023.3.17
    • 🙆本文由 睡晚不猿序程 原创
    • 🤡作者是蒻蒟本蒟,如果文章里有任何错误或者表述不清,请 tt 我,万分感谢!orz

    相关文章目录 :无


    目录

    1. 内容简介

    最近在准备使用 Transformer 系列作为 backbone 完成自己的任务,感觉自己打代码的次数也比较少,正好直接用别人写的代码进行训练的同时,自己看着 ViT 的论文以及别人实现的代码自己实现一下 ViT

    感觉 ViT 相对来说实现还是比较简单的,也算是对自己代码能力的一次练习吧,好的,我们接下来开始手撕 ViT


    2. Vision Transformer 总览

    我这里默认大家都理解了 Transformer 的构造了!如果有需要我可以再发一下 Transformer 相关的内容

    ViT 的总体架构和 Transformer 一致,因为它的目标就是希望保证 Transformer 的总体架构不变,并将其应用到 CV 任务中,它可以分为以下几个部分:

    1. 预处理

      包括以下几个步骤:

      1. 划分 patch
      2. 线性嵌入
      3. 添加 CLS Token
      4. 添加位置编码
    2. 使用 Transformer Block 进行处理

    3. MLP 分类头基于 CLS Token 进行分类

    上面讲述的是大框架,接下来我们深入 ViT 的Transformer Block 去看一下和原本的 Transformer 有什么区别

    Transformer Block

    和 Transformer 基本一致,但是使用的是 Pre-Norm,也就是先进行 LayerNorm 然后再做自注意力/MLP,而 Transformer 选择的是 Pose-Norm,也就是先做自注意力/MLP 然后再做 LayerNorm

    Pre-Norm 和 Pose-Norm 各有优劣:

    • Pre-Norm 可以不使用 warmup,训练更简单
    • Pose-Norm 必须使用 warmup 以及其他技术,训练较难,但是完成预训练后泛化能力更好

    ViT 选择了 Pre-Norm,所以训练更为简单

    3. 手撕 Transformer

    接下来我们一部分一部分的来构建 ViT,由一个个组件最后拼合成 ViT

    3.1 预处理部分

    这一部分我们将会构建:

    1. 划分 patch
    2. 线性嵌入
    3. 插入 CLS Token
    4. 嵌入位置编码信息

    我们先把整个部分的代码放在这里,之后我们再详细讲解

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops import rearrange, repeat
    from einops.layers.torch import Rearrange
    
    
    class pre_proces(nn.Module):
        def __init__(self, image_size, patch_size, patch_dim, dim):
            super().__init__()
            self.patch_size = patch_size
            self.dim = dim
            self.patch_num = (image_size//patch_size)**2
            self.linear_embedding = nn.Linear(patch_dim, dim)
            self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用广播
            self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 别忘了维度要和 (B,L,C) 对齐
    
        def forward(self, x):
            x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)
            x = self.linear_embedding(x)
            b, l, c = x.shape   # 获取 token 的形状 (B,L,c)
            CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置编码复制 B 份
            x = torch.concat((CLS_token, x), dim=1)
            x = x+self.position_embedding
            return x
    

    可以先大概浏览一下,也不是很难看懂啦!

    3.1.1 patch 划分

    x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)
    

    我们直接使用 einops 库中的 rearrange 函数来划分 patch,我们输入的 x 的数组表示为 (B,C,H,W),我们要把它划分成 (B,L,C),其中 L=WWp×HHp" role="presentation">L=WWp×HHp,也就是 patch 的个数,最后 C=Wp×Hp×channels" role="presentation">C=Wp×Hp×channels

    这个函数就把原先的 (B,C,H,W) 表示方式拆开了,很轻易的就能够做到我们想要的 patch 划分,注意 h 和 p1 和 p2 的顺序不能乱

    3.1.2 线性嵌入

    首先我们要先定义一个全连接层

    self.linear_embedding = nn.Linear(patch_dim, dim)
    

    使用这个函数将 patch 映射到 Transformer 处理的维度

    x = self.linear_embedding(x)
    

    接着使用这个函数来执行线性嵌入,将其映射到维度 dim

    3.1.3 插入 CLS Token

    CLS Token 是最后分类头处理的依据,这个思想好像是来源于 BERT,可以看作是一种 池化 方式,CLS Token 在 Transformer 中会和其他元素进行交互,最后的输出时可以认为它拥有了所有 patch 信息,如果不使用 CLS Token 也可以选择平均池化等方式来进行分类

    首先我们要定义 CLS Token,他是一个可学习的向量,所以需要注册为 nn.Parameter ,其维度和 Transformer 处理维度一致,以便于后面进行级联

    self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 别忘了维度要和 (B,L,C) 对齐
    

    我们得到了一个大小为 (1,1,dim) 的向量,但是我们的输入的是一个 batch,所以我们要对他进行复制,我们可以使用 einops 库中的 repeat 函数来进行复制,然后再进行级联

    CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置编码复制 B 份
    x = torch.concat((CLS_token, x), dim=1)
    

    其中 b 是 batch 大小

    可以发现 einops 库可以很方便的进行矩阵的重排

    3.1.4 嵌入位置信息

    ViT 使用可学习的位置编码,而 Transformer 使用的是 sin/cos 函数进行编码,使用可学习位置编码显然更为方便

    self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用广播
    

    可学习的参数一定要注册为 nn.Parameter

    向量的个数为 patch 的个数+1,因为因为在头部还加上了一个 CLS Token 呢,最后使用加法进行位置嵌入

    x = x+self.position_embedding
    

    好了每个模块都讲解完成,我们将他拼合

    class pre_proces(nn.Module):
        def __init__(self, image_size, patch_size, patch_dim, dim):
            super().__init__()
            self.patch_size = patch_size	# patch 的大小
            self.dim = dim	# Transformer 使用的维度,Transformer 的特性是输入输出大小不变
            self.patch_num = (image_size//patch_size)**2	# patch 的个数
            self.linear_embedding = nn.Linear(patch_dim, dim)	# 线性嵌入层
            self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num+1, self.dim))  # 使用广播
            self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim))  # 别忘了维度要和 (B,L,C) 对齐
    
        def forward(self, x):
            x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)  # (B,L,C)
            x = self.linear_embedding(x)	# 线性嵌入
            b, l, c = x.shape   # 获取 token 的形状 (B,L,c)
            CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b)  # 位置编码复制 B 份
            x = torch.concat((CLS_token, x), dim=1)	# 级联 CLS Token
            x = x+self.position_embedding	# 位置嵌入
            return x
    

    3.2 Transformer

    这一部分将会是我们的重点,建议大家手推一下自注意力计算,不然可能会有点难理解

    3.2.1 多头自注意力

    首先来回忆一下自注意力公式:

    Output=softmax(QKTDk)V" role="presentation">Output=softmax(QKTDk)V

    输入通过 Wq,Wk,Wv" role="presentation">Wq,Wk,Wv 映射为 QKV,然后经过上述计算得到输出,多头注意力就是使用多个映射权重进行映射,然后最后拼接成为一个大的矩阵,再使用一个映射矩阵映射为输出函数

    还是一样,我们先把整个代码放上来,我们接着在逐行讲解

    class Multihead_self_attention(nn.Module):
        def __init__(self, heads, head_dim, dim):
            super().__init__()
            self.head_dim = head_dim    # 每一个注意力头的维度
            self.heads = heads  # 注意力头个数
            self.inner_dim = self.heads*self.head_dim  # 多头自注意力最后的输出维度
            self.scale = self.head_dim**-0.5   # 正则化系数
            self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一个矩阵的维度和由自注意力头的维度以及头的个数决定
            self.to_output = nn.Linear(self.inner_dim, dim)
            self.norm = nn.LayerNorm(dim)
            self.softmax = nn.Softmax(dim=-1)
    
        def forward(self, x):
            x = self.norm(x)    # PreNorm
            qkv = self.to_qkv(x).chunk(3, dim=-1)  # 划分 QKV,返回一个列表,其中就包含了 QKV
            Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
            K_T = K.transpose(-1, -2)
            att_score = Q@K_T*self.scale
            att = self.softmax(att_score)
            out = att@V   # (B,H,L,dim)
            out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
            output = self.to_output(out)
            return output
    

    我们先用图来表示一下多头自注意力,也就是用多个不同的权重来映射,然后再计算自注意力,这样就得到了多组的输出,最后再进行拼接,使用一个大的矩阵来把多头自注意力输出映射回输入大小

    我们如何构造这多个权重矩阵来进行矩阵运算更快呢?答案是——写成一个线性映射,然后再通过矩阵重排来得到多组 QKV,然后计算自注意力,我们来看图:

    首先输入是一个 (N,dim) 的张量,我们可以把多头的映射横着排列变成一个大矩阵,这样使用一次矩阵运算就可以得到多个输出

    我这里假设了四个头,并且每一个头的维度是 2

    经过映射,我们得到了一个 (N,heads×head_dim)" role="presentation">(N,heads×head_dim) 大小的张量,这时候我们对其重新排列,形成 (heads,N,head_dim)" role="presentation">(heads,N,head_dim) 大小的张量,这样就把每一个头给分离出来了

    接着就是做自注意力,我们现在的张量当作 Q,K 就需要进行转置,其张量大小是 (heads,head_dim,N)" role="presentation">(heads,head_dim,N) ,二者进行相乘,得到的输出为 (heads,N,N)" role="presentation">(heads,N,N),这就是我们的注意力得分,经过 softmax 就可以和 V 相乘了

    这里省略了 softmax,重点看矩阵的维度变化

    计算自注意力输出,就是和 V 相乘,V 的张量大小为 (heads,N,head_dim)" role="presentation">(heads,N,head_dim) ,最后得到输出大小为 (heads,N,head_dim)" role="presentation">(heads,N,head_dim)

    我们把上一步的张量 (heads,N,head_dim)" role="presentation">(heads,N,head_dim) 重排为 (N,heads×head_dim)" role="presentation">(N,heads×head_dim),然后使用一个大小为 (heads×_dim,dim)" role="presentation">(heads×_dim,dim) 的矩阵映射回和输入相同的大小,这样多头自注意力就计算完成了

    大家可以像我一样把过程给写出来,可以清晰非常多,接下来我们再看一下代码实现:

    首先定义我们需要的映射矩阵以及 softmax 函数以及 layernorm 函数

    self.head_dim = head_dim    # 每一个注意力头的维度
    self.heads = heads  # 注意力头个数
    self.inner_dim = self.heads*self.head_dim  # 多头自注意力输出级联后的输出维度
    self.scale = self.head_dim**-0.5   # 正则化系数
    self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一个矩阵的维度由自注意力头的维度以及头的个数决定
    self.to_output = nn.Linear(self.inner_dim, dim)	# 输出映射矩阵
    self.norm = nn.LayerNorm(dim)	# layerNorm
    self.softmax = nn.Softmax(dim=-1)	# softmax
    

    有了这些,我们可以开始 MHSA 的计算

        def forward(self, x):
            x = self.norm(x)    # PreNorm
            qkv = self.to_qkv(x).chunk(3, dim=-1)  # 按照最后一个维度均分为三分,也就是划分 QKV,返回一个列表,其中就包含了 QKV
            Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)	# 对 QKV 的多头映射进行拆分,得到(B,head,L,head_dim)
            K_T = K.transpose(-1, -2)	# K 进行转置,用于计算自注意力
            att_score = Q@K_T*self.scale	# 计算自注意力得分
            att = self.softmax(att_score)	# softmax
            out = att@V   # (B,H,L,dim); 自注意力输出
            out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
            output = self.to_output(out)	#输出映射
            return output
    

    上面的部分进行组合

    class Multihead_self_attention(nn.Module):
        def __init__(self, heads, head_dim, dim):
            super().__init__()
            self.head_dim = head_dim    # 每一个注意力头的维度
            self.heads = heads  # 注意力头个数
            self.inner_dim = self.heads*self.head_dim  # 多头自注意力最后的输出维度
            self.scale = self.head_dim**-0.5   # 正则化系数
            self.to_qkv = nn.Linear(dim, self.inner_dim*3)  # 生成 qkv,每一个矩阵的维度和由自注意力头的维度以及头的个数决定
            self.to_output = nn.Linear(self.inner_dim, dim)
            self.norm = nn.LayerNorm(dim)
            self.softmax = nn.Softmax(dim=-1)
    
        def forward(self, x):
            x = self.norm(x)    # PreNorm
            qkv = self.to_qkv(x).chunk(3, dim=-1)  # 划分 QKV,返回一个列表,其中就包含了 QKV
            Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
            K_T = K.transpose(-1, -2)
            att_score = Q@K_T*self.scale
            att = self.softmax(att_score)
            out = att@V   # (B,H,L,dim)
            out = rearrange(out, 'b h l dim -> b l (h dim)')  # 拼接
            output = self.to_output(out)
            return output
    
    

    3.2.2 FeedForward

    构建后面的 FeedForward 模块,这个模块就是一个 MLP,中间夹着非线性激活,所以我们直接看代码吧

    class FeedForward(nn.Module):
        def __init__(self, dim, mlp_dim):
            super().__init__()
            self.fc1 = nn.Linear(dim, mlp_dim)
            self.fc2 = nn.Linear(mlp_dim, dim)
            self.norm = nn.LayerNorm(dim)
    
        def forward(self, x):
            x = self.norm(x)
            x = F.gelu(self.fc1(x))
            x = self.fc2(x)
            return x
    

    3.2.3 Transformer Block

    有了 MHSA 以及 FeedForward,我们可以来构建 Transformer Block,这是 Transformer 的基本单元,只需要把我们构建的模块进行组装,然后添加残差连接即可,不会很难

    class Transformer_block(nn.Module):
        def __init__(self, dim, heads, head_dim, mlp_dim):
            super().__init__()
            self.MHA = Multihead_self_attention(heads=heads, head_dim=head_dim, dim=dim)
            self.FeedForward = FeedForward(dim=dim, mlp_dim=mlp_dim)
    
        def forward(self, x):
            x = self.MHA(x)+x
            x = self.FeedForward(x)+x
            return x
    

    添加了一个参数 depth ,用来定义 Transformer 的层数

    ViT

    祝贺大家,走到最后一步啦!我们把上面的东西组装起来,构建 ViT 吧

    class ViT(nn.Module):
        def __init__(self, image_size, channels, patch_size, dim, heads, head_dim, mlp_dim, depth, num_class):
            super().__init__()
            self.to_patch_embedding = pre_proces(image_size=image_size, patch_size=patch_size, patch_dim=channels*patch_size**2, dim=dim)
            self.transformer = Transformer(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim, depth=depth)
            self.MLP_head = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, num_class)
            )
            self.softmax = nn.Softmax(dim=-1)
    
        def forward(self, x):
            token = self.to_patch_embedding(x)
            output = self.transformer(token)
            CLS_token = output[:, 0, :]	# 提取出 CLS Token
            out = self.softmax(self.MLP_head(CLS_token))
            return out
    

    总结

    这里我们手动实现了 ViT 的构建,不知道大家有没有对 Transformer 的架构有更深入的理解呢?我也是动手实现了才理解其各种细节,刚开始觉得自己不可能实现,但是最后还是成功的,感觉好开心:D

    参考

    [1] lucidrains/vit-pytorch

    [2] 全网最强ViT (Vision Transformer)原理及代码解析

    [3] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.


    __EOF__

  • 本文作者: 睡晚不猿序程
  • 本文链接: https://www.cnblogs.com/whp135/p/17228561.html
  • 关于博主: 评论和私信会在第一时间回复。或者直接私信我。
  • 版权声明: 本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
  • 声援博主: 如果您觉得文章对您有帮助,可以点击文章右下角推荐一下。
  • 相关阅读:
    LC516. 最长回文子序列
    Java设计模式之桥接模式
    C语言中文网 - Shell脚本 - 6
    一、openCV+TensorFlow环境搭建
    02、MongoDB -- MongoDB 的安全配置(创建用户、设置用户权限、启动安全控制、操作数据库命令演示、mongodb 的帮助系统介绍)
    Guava-EventBus 源码解析
    深入解析:AntSK 0.1.7版本的技术革新与多模型管理策略
    数据库云管平台 zCloud v3.5发布,智能化和国产数据库支持能力再增强
    2022java面试总结,1000道(集合+JVM+并发编程+Spring+Mybatis)的Java高频面试题
    【蓝桥杯选拔赛真题30】python开关灯 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析
  • 原文地址:https://www.cnblogs.com/whp135/p/17228561.html