• torch_scatter.scatter()的使用方法详解


    1. 参数

    在这里插入图片描述
    具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如summean等,然后将这些操作结果按照索引顺序进行拼接。下面我用具体的例子来进行讲解。

    2. 示例

    2.1 简单示例

    首先初始化src和index:

    src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
    index = torch.tensor([0, 0, 1], dtype=torch.int64)
    
    • 1
    • 2

    接着使用scatter函数:

    out = scatter(src, index, dim=0, reduce='mean')
    
    • 1

    我们观察index=[0, 0, 1],第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果index=[1, 0, 0],则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。

    那么src中第几个元素到底是如何定义的呢?这就需要用到dim参数了。

    dim=0意味着我们需要对src的维度0进行操作:

    tensor([[1., 2., 3.],
            [4., 5., 6.],
            [7., 8., 9.]])
    
    • 1
    • 2
    • 3

    即src中第0个元素为[1, 2, 3],第1个元素为[4, 5, 6],第2个元素为[7, 8, 9]

    而如果dim=1,则第0个元素为[1, 4, 7],第1个元素为[2, 5, 8],第2个元素为[3, 6, 9]

    因此,如果有以下代码:

    src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
    index = torch.tensor([0, 0, 1], dtype=torch.int64)
    out = scatter(src, index, dim=0, reduce='mean')
    
    • 1
    • 2
    • 3

    那么我们就应该将src中的第0个元素为[1, 2, 3]和第1个元素为[4, 5, 6]求平均为[2.5, 3.5, 4.5],然后第2个元素[7, 8, 9]保持不变,即:

    tensor([[2.5000, 3.5000, 4.5000],
            [7.0000, 8.0000, 9.0000]])
    
    • 1
    • 2

    2.2 顺序问题

    上面的例子中index=[0, 0, 1],最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。

    如果index=[1, 1, 0],结果为:

    tensor([[7.0000, 8.0000, 9.0000],
            [2.5000, 3.5000, 4.5000]])
    
    • 1
    • 2

    可以发现,上述结果是将src中第2个元素[7, 8, 9]保持不变放到了位置0,然后将src中第0个元素[1, 2, 3]和第1个元素[4, 5, 6]求平均保持不变放到了位置1。

    也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。

    2.3 维度问题

    如果src的维度为(4, 3),而我们需要对dim=0操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:

    src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
    index = torch.tensor([1, 1, 0], dtype=torch.int64)
    out = scatter(src, index, dim=0, reduce='mean')
    print(out)
    
    • 1
    • 2
    • 3
    • 4

    报错为:

    RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]
    
    • 1

    正确做法应该是:

    src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
    index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
    out = scatter(src, index, dim=0, reduce='mean')
    print(out)
    
    • 1
    • 2
    • 3
    • 4

    输出为:

    tensor([[ 7.0000,  8.0000,  9.0000],
            [ 2.5000,  3.5000,  4.5000],
            [10.0000, 11.0000, 12.0000]])
    
    • 1
    • 2
    • 3
  • 相关阅读:
    使用Harbor作为docker镜像仓库之安装运行Harbor
    oceanbase数据库安装和连接实战(阿里云服务器操作)
    docker——入门介绍、组件介绍、安装与启动、镜像相关命令、容器相关命令、应用部署
    英语单词和词组笔记
    【跟学C++】C++链表——List类(Study11)
    java八股文面试[设计模式]——23种设计模式
    【VulnHub靶场】Hackable: III
    焱融全闪系列科普|固态存储核心技术 SSD
    竞赛 机器学习股票大数据量化分析与预测系统 - python 竞赛
    etcd实现大规模服务治理应用实战
  • 原文地址:https://blog.csdn.net/Cyril_KI/article/details/125908710