• learn掩码张量


    目录

    1、什么是掩码张量

    2、掩码张量的作用

    3、代码演示

    (1)、定义一个上三角矩阵,k=0或者 k默认为 0

    (2)、k=1

    (3)、k=-1

    4、掩码张量代码实现

    (1)、输出效果

    (2)、输出效果分析


    1、什么是掩码张量

    • 掩就是代表遮掩,码就是张量中的数值,它的尺寸不定,里面只有 1 和 0 的元素,代表的位置被遮掩或者不被遮掩,至于是 0 位置被遮掩还是 1 位置被遮掩可以自己定义,因此它的作用就是让另外一个张量中的数值被遮掩,也可以说成是被替换,它的表现形式是一个张量

    2、掩码张量的作用

    • 在transformers中,掩码张量的主要作用应用在 attention时,有一些生成的attention张量中的值计算有可能已知了未来信息而得到的,未来信息被看到是因为训练时会把整个输出结果都一次性进行 Embedding,但是理论上解码器的输出却不是一次就能产生最终结果的,而是一次次通过上次结果综合得出的。因此,未来的信息可能被提前利用,所以,我们会进行遮掩

    3、代码演示

    (1)、定义一个上三角矩阵,k=0或者 k默认为 0

    1. attn_shape = (1,3,3) # 定义掩码张量的形状
    2. sub_mask = np.triu(np.ones(attn_shape), k = 0).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
    3. print(sub_mask)

    [[[1 1 1]
      [0 1 1]
      [0 0 1]]]

    (2)、k=1

    1. attn_shape = (1,3,3) # 定义掩码张量的形状
    2. sub_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
    3. print(sub_mask)

    [[[0 1 1]
      [0 0 1]
      [0 0 0]]]

    (3)、k=-1

    1. attn_shape = (1,3,3) # 定义掩码张量的形状
    2. sub_mask = np.triu(np.ones(attn_shape), k = -1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
    3. print(sub_mask)

    [[[1 1 1]
      [1 1 1]
      [0 1 1]]]

    4、掩码张量代码实现

    1. import numpy as np
    2. import torch
    3. def subsequent_mask(size):
    4. """
    5. :param size: 生成向后遮掩的掩码张量,参数 size 是掩码张量的最后两个维度大小,它的最后两个维度形成一个方阵
    6. :return:
    7. """
    8. attn_shape = (1,size,size) # 定义掩码张量的形状
    9. subsequent_mask = np.triu(np.ones(attn_shape),k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形
    10. return torch.from_numpy(1 - subsequent_mask) # 先将numpy 类型转化为 tensor,再做三角的翻转,将位置为 0 的地方变为 1,将位置为 1 的方变为 0
    1. size = 5
    2. sm = subsequent_mask(size)
    3. print("sm :",sm)
    4. # 掩码张量的可视化
    5. import matplotlib.pyplot as plt
    6. plt.figure(figsize=(5,5))
    7. plt.imshow(subsequent_mask(20)[0])

    (1)、输出效果

    (2)、输出效果分析

    • 通过观察可视化方阵,黄色是 1 的部分,这里代表被遮掩,紫色代表没有被遮掩的信息,横坐标代表目标词汇的位置,纵坐标代表可查看的位置
    • 我们看到,在 0 的位置我们以看望过去都是黄色的,都被遮掩了,1的位置一眼望过去还是黄色,说明第一次词还没有产生,从第二个位置看过去,就能看到位置 1 的词,其他位置看不到,以此类推

  • 相关阅读:
    java springboot在测试类中构建虚拟MVC环境并发送请求
    MMKV(1)
    【EI会议征稿】第三届应用力学与先进材料国际学术会议(ICAMAM 2024)
    Vuejs设计与实现 —— 渲染器核心 Diff 算法
    加载报错 Unsupported version (not an attribute), or file does not exist
    Tomcat服务(部署、虚拟主机配置、优化)
    Hadoop生态之Hive(二)
    flutter系列之:flutter架构什么的,看完这篇文章就全懂了
    VXLAN间通信
    C++类型转换运算符的重载,自增自减运算符的重载
  • 原文地址:https://blog.csdn.net/qq_51691366/article/details/133524954