• torch.gather() 用法解读


    torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor

    沿dim" role="presentation" style="position: relative;">dim指定的轴和index" role="presentation" style="position: relative;">index指定的索引从input" role="presentation" style="position: relative;">input中提取对应的值。

    对于一个三维张量

    output[i][j][k]=input[index[i][j][k]][j][k]#ifdim==0" role="presentation" style="position: relative;">output[i][j][k]=input[index[i][j][k]][j][k]#ifdim==0 

    output[i][j][k]=input[i][index[i][j][k]][k]#ifdim==1" role="presentation" style="position: relative;">output[i][j][k]=input[i][index[i][j][k]][k]#ifdim==1 

    output[i][j][k]=input[i][j][index[i][j][k]]#ifdim==2" role="presentation" style="position: relative;">output[i][j][k]=input[i][j][index[i][j][k]]#ifdim==2 

    input" role="presentation" style="position: relative;">inputindex" role="presentation" style="position: relative;">indexdimensions" role="presentation" style="position: relative;">dimensions数目必须相同。 out" role="presentation" style="position: relative;">outindex" role="presentation" style="position: relative;">indexshape" role="presentation" style="position: relative;">shape是相同的。(注意dimensions" role="presentation" style="position: relative;">dimensionsshape" role="presentation" style="position: relative;">shape的区别)

    示例

    下面用两个例子来解释一下具体的用法

    例1

    1. import torch
    2. dim = 0
    3. _input = torch.tensor([[10, 11, 12],
    4. [13, 14, 15],
    5. [16, 17, 18]])
    6. index = torch.tensor([[0, 1, 2],
    7. [1, 2, 0]])
    8. output = torch.gather(_input, dim, index)
    9. print(output)
    10. # tensor([[10, 14, 18],
    11. # [13, 17, 12]])

    该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(2, 3)。

    因为dim=0,index中的每个数其值代表dim=0即"行"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1, 2]分别表示第0、1、2行,而这三个数本身在dim=1维度的索引为0、1、2即第0、1、2列。因此第一个数0定位到_input中的第0行,而0本身在index中的第0列,因此又定位到_input的第0列,这样就找到了10这个数,同理找到14和18。

    index中的第1行[1, 2, 0]分别表示_input中的第1、2、0行和第0、1、2列,因此找到_input中对应的数[13, 17, 12]。

    例2

    1. import torch
    2. dim = 1
    3. _input = torch.tensor([[10, 11, 12],
    4. [13, 14, 15],
    5. [16, 17, 18]])
    6. index = torch.tensor([[0, 1],
    7. [1, 2],
    8. [2, 0]])
    9. output = torch.gather(_input, dim, index)
    10. print(output)
    11. # tensor([[10, 11],
    12. # [14, 15],
    13. # [18, 16]])

    该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(3, 2)。

    因为dim=1,index中的每个数其值代表dim=1即"列"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1]分别表示第0、1列,而这三个数本身在dim=0维度的索引为0即第0行。因此第一个数0定位到_input中的第0列,而0本身在index中的第0行,因此又定位到_input的第0行,这样就找到了10这个数,同理找到11。

    index中的第1行[1, 2]分别表示_input中的第1、2列和第1行,因此找到_input中对应的数[14, 15]。

    index中的第2行[2, 0]分别表示_input中的第2、0列和第2行,因此找到_input中对应的数[18, 16]。

    总结

    上面的示例是二维的情况,同理也可以推广到三维甚至更多维。总结来说,index中每个数其本身的值表示参数dim指定维度的索引,而其它的每个维度都由每个数在index中的对应维度的索引指定。

    参考

    torch.gather — PyTorch 1.12 documentation

    python - What does the gather function do in pytorch in layman terms? - Stack Overflow

  • 相关阅读:
    【三、centOS安装后的基本配置】
    车联网解决方案-最新全套文件
    C# Thread.Sleep(0)有什么用?
    Java代码审计——文件操作漏洞
    Kafka为什么性能这么快?4大核心原因详解
    linux开机自动启动java的jar包项目及开机自动启动Nacos的配置
    软件工程第六周之服务层与API调用
    【Nacos】源码之服务端服务注册
    devpi
    postman接口测试工具发起webservice请求
  • 原文地址:https://blog.csdn.net/ooooocj/article/details/126046919