• pytorch中gather函数的理解


    pytorch函数gather理解

    torch.gather(input, dim, index, out=None) → Tensor 
    
    • 1

    Parameters:

    • input (Tensor) – 源张量
    • dim (int) – 索引的轴
    • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
    • out (Tensor, optional) – 目标张量

    公式含义

    这个函数的意义就是可以重新排列特定维度的信息。对一个三维张量,从公式来看,输出是下面这种,就是在特定维度上,用索引index下标代替所在位置的值。

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
    
    • 1
    • 2
    • 3

    直观理解

    原始tensor ,名称为a

    a = torch.randint(0, 30, (2, 3, 5))
    
    • 1

    以下以 CxHxW的维度讲述,其中C=2,H=3, W=5,
    在这里插入图片描述

    index = torch.LongTensor([[[0,1,2,0,2],
                              [0,0,0,0,0],
                              [1,1,1,1,1]],
                            [[1,2,2,2,2],
                             [0,0,0,0,0],
                             [2,2,2,2,2]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    指定dim = 1,也就是在第二个维度上H重新排列,

    b = torch.gather(a, 1,index)
    
    • 1

    此时,第一个维度C是不会改变的,还是存在两个通道C,分别是a[0]和a[1],
    针对a[0]或者a[1] , 在高度维度H上,分别是3行,a[0][0:2] a[1][0:2]。即

    a[0].shape == [3,5]
    
    • 1

    因此,如果选择dim=1,则index 张量里面的数必须在0-2之间,不然会越界,
    下一步就是选取数字了。
    针对每一个通道C,输出张量b,只需要按照index重新排列矩阵即可
    例如在第b[0,1,2]的位置,则选择a[0][index[0,1,2]][2]的值进行代替即可。

    同理在其他维度也是一样。

    注意点

    需要注意的是索引矩阵不能越界,例如针对上述a[2,3,5],
    如果指定dim=0,则index里面的数不能超过1,指定dim=1,则index不能超过2,指定dim=3,则index不能超过4

    本文参考https://www.jianshu.com/p/5d1f8cd5fe31

  • 相关阅读:
    【React】第六部分 生命周期
    基于STM32_DS18B20单总线传感器驱动
    爬虫之BeautifulSoup4
    MyBatis核心对象简介说明
    SqlPlus访问oracle
    CF1286E-Fedya the Potter Strikes Back【KMP,RMQ】
    HTML5基础入门
    gitlab 设置 分支只读
    探索Java面向对象编程的奇妙世界(四)
    基于JAVA医院管理系统计算机毕业设计源码+系统+数据库+lw文档+部署
  • 原文地址:https://blog.csdn.net/weixin_43707042/article/details/134539786