• Transformer 中 Positional Encoding 实现


    参考博文:

    https://www.cnblogs.com/nickchen121/p/16470736.html

    解决问题 

    位置编码的主要目的是确保模型能够理解序列中的元素之间的相对位置和顺序,从而更好地捕捉到语义信息。在Transformer模型中,位置编码通常与词嵌入(word embeddings)相加,以形成模型的输入表示。这有助于模型在处理序列数据时更好地理解元素的位置和顺序,从而提高其性能,特别是在自然语言处理任务中。

    原理

    这里就是拿经典款transformer举例了

    这个i是维度,2i这块是告诉你是sin还是cos的,是0~dimension/2

    详细过程:

    sin(pos+k) = sin(pos)*cos(k)+cos(pos)*sin(k) #sin表示偶数维度

    cos(pos+k) = cos(pos)cos(k) +sin(pos)sin(k) #cos表示奇数维度

    !pos+k可是pos和k的线性组合!

    例如

    pos+K=5, 当我计算第五个单词的位置编码时:

    pos=1, k=4; pos=2, k=3;

    这样就可以得知几个位置之间的相对关系

    代码实现

    Transformer

    一维绝对的位置编码

    1. def create_1d_absolute_sincos_embeddings(n_pos_vec,dim):
    2. assert dim % 2 == 0, "wrong dimension" # dim must be even
    3. # 初始化position embedding
    4. position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) #numel()返回数组元素个数
    5. # omega是对i进行遍历
    6. omega = torch.arange(dim//2, dtype=torch.float) #//是整除
    7. omega /= dim/2.
    8. omega = 1./(10000**omega)
    9. out = n_pos_vec[:, None]@omega[None, :] # 先把n_pos_vec变成列向量,一个维度加上None相当于扩了一维;接下来是把omega拓成一个行向量, @是矩阵乘法
    10. emb_sin = torch.sin(out)
    11. emb_cos = torch.cos(out)
    12. # 接下来是偶数位用sin赋值,奇数位用cos去赋值
    13. position_embedding[:, 0::2] = emb_sin
    14. position_embedding[:, 1::2] = emb_cos
    15. return position_embedding
    16. if __name__ == '__main__':
    17. n_pos = 4
    18. dim = 4
    19. n_pos_vec = torch.arange(n_pos, dtype=torch.float)
    20. print(n_pos_vec)
    21. pe = create_1d_absolute_sincos_embeddings(n_pos_vec,dim)
    22. print("pe", pe)

    Vision Transformer

    一维的,绝对的,可训练的

    这里用的也是一维的位置编码因为论文里做了实验表明二维的位置编码对模型效果并没有提升

    1. def create_1d_absolute_trainable_embeddings(n_pos_vec,dim):
    2. # 传入索引
    3. # n_pos_vec: torch.aramge(n_pos, dtype=torch.float)
    4. # 因为可学习所以用nn.embedding来实现
    5. position_embedding = nn.Embedding(n_pos_vec.numel(), dim)
    6. # 初始化weight(parameter class)
    7. nn.init.constant_(position_embedding.weight, 0.)
    8. return position_embedding # 一维的,绝对的,可学习的embedding

     Swin Transformer

    二维的,相对的,基于位置偏差的

    相对位置,可学习

    1. def create_2d_relative_bias_trainable_embeddings(n_head,height,width,dim):
    2. # embeddings的行数就是bias的个数,列数就是num_heads
    3. # 横轴取值 width:5[0,1,2,3,4] bias ={-width+1, width-1 }{-4,4} 4-(-4)+1 = 9
    4. # 纵轴取值 height:5[0,1,2,3,4] bias ={-height+1, height-1} 1-(-1)+1 = 3
    5. position_embedding = nn.Embedding((2*width-1)*(2*height-1), n_head)
    6. # 初始化weight(parameter class)
    7. nn.init.constant_(position_embedding.weight, 0.)
    8. # 获取window中二维的,两两之间的位置偏差
    9. # step1:算出横轴和纵轴各自的位置偏差,用网格法把横轴的位置索引和纵轴的位置索引定义出来
    10. def get_2d_relative_position_index(height, width):
    11. m1, m2 = torch.meshgrid(torch.arange(height), torch.arange(width)) # m1行一样,m2列一样
    12. coords = torch.stack([m1, m2]) # 把m1和m2拼接起来,dim=-1表示最后一个维度 #2*height*width
    13. coords_flatten = torch.flatten(coords,1) # 把coords压缩成一维,dim=1表示第一个维度,得到2*【height*width】
    14. ralative_coords_bias = coords_flatten[:, :, :None]- coords_flatten[:, None, :]#得到网格里任意两点横轴纵轴坐标的差值,[2,height*width,height*width]
    15. # 把它们都变成正数
    16. ralative_coords_bias[0, :, :] += height-1 # 横轴坐标的差值,0代表高度维
    17. ralative_coords_bias[1, :, :] += width-1 # 纵轴坐标的差值 1代表宽度维
    18. # 把两个方向上的坐标转化成一个方向上的坐标,类似于把一个2dtensor赋值到1dtensor
    19. # A;2d,B:1d B[i*cols+j] = A[i,j]
    20. ralative_coords_bias[0,:,:] += ralative_coords_bias[1, :, :].max()+1 # 把横轴坐标的差值转化成一维坐标,即i*cols
    21. # 相对位置索引
    22. return ralative_coords_bias.sum(0) # [height*width,height*width] # 两个方向上的坐标相加,得到相对位置索引
    23. relative_position_bias = get_2d_relative_position_index(height, width) # [height*width,height*width]
    24. bias_embedding = position_embedding(torch.flatten(relative_position_bias)).reshape(height*width,height*width,n_head) # [height*width,height*width,n_head]
    25. bias_embedding.permute(2,0,1).unsqueeze(0) # [1, n_head,height*width,height*width]
    26. return bias_embedding # 二维的,相对的,可学习的embedding
  • 相关阅读:
    WPF自定义控件与样式(4)-CheckBox/RadioButton自定义样式
    07_瑞萨GUI(LVGL)移植实战教程之LVGL对接EC11旋转编码器驱动
    【JAVA】Scanner的next()、nextInt()、nextLine()读取机制
    Java:缓存行和伪共享
    神仙打架!腾讯云阿里云谁更棋高一着?
    Lyx使用对中文进行编译
    实例解释遇到前端报错时如何排查问题
    Volatile和CAS
    Spring实例化源码解析之FactoryBean(十一)
    dubbo学习(一)dubbo简介与原理
  • 原文地址:https://blog.csdn.net/Scabbards_/article/details/133808098