目录
- attn_shape = (1,3,3) # 定义掩码张量的形状
- sub_mask = np.triu(np.ones(attn_shape), k = 0).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
- print(sub_mask)
[[[1 1 1]
[0 1 1]
[0 0 1]]]
- attn_shape = (1,3,3) # 定义掩码张量的形状
- sub_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
- print(sub_mask)
[[[0 1 1]
[0 0 1]
[0 0 0]]]
- attn_shape = (1,3,3) # 定义掩码张量的形状
- sub_mask = np.triu(np.ones(attn_shape), k = -1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
- print(sub_mask)
[[[1 1 1]
[1 1 1]
[0 1 1]]]
- import numpy as np
- import torch
- def subsequent_mask(size):
- """
- :param size: 生成向后遮掩的掩码张量,参数 size 是掩码张量的最后两个维度大小,它的最后两个维度形成一个方阵
- :return:
- """
- attn_shape = (1,size,size) # 定义掩码张量的形状
- subsequent_mask = np.triu(np.ones(attn_shape),k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形
- return torch.from_numpy(1 - subsequent_mask) # 先将numpy 类型转化为 tensor,再做三角的翻转,将位置为 0 的地方变为 1,将位置为 1 的方变为 0
- size = 5
- sm = subsequent_mask(size)
- print("sm :",sm)
- # 掩码张量的可视化
- import matplotlib.pyplot as plt
- plt.figure(figsize=(5,5))
- plt.imshow(subsequent_mask(20)[0])
