• pytorch使用cat()和stack()拼接tensors


    有时我们在处理数据时,需要对指定的tensor按照指定维度进行拼接,对于这个需求,pytorch中提供了两个函数供我们使用,一个是torch.cat(),另外一个是torch.stack(),这两者都可以拼接tensor,但是这二者又有一些区别。

    二者相同点就是都可以实现拼接tensor,不同之处就是是否是在新的维度上进行拼接(是否产生新的维度)。

    一、torch.cat()

    该方法可以将任意个tensor按照指定维度进行拼接,需要传入两个参数,一个参数是需要拼接的tensor,需要以列表的形式进行传入,第二个参数就是需要拼接的维度。

    a = torch.randn(3, 4)
    b = torch.randn(3, 4)
    
    c = torch.cat([a, b], dim=0)
    print(c)
    print(c.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    tensor([[ 0.1040, -0.3168, -1.3974, -1.2703],
            [ 0.4375,  1.4254,  0.2875, -0.2420],
            [-0.9663, -1.8022, -1.2352,  0.7283],
            [-0.4226,  0.0375, -0.3861,  1.3939],
            [ 1.6275, -0.1319, -0.7143,  0.3624],
            [ 0.2245, -1.7482, -0.7933, -0.1008]])
    torch.Size([6, 4])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    该例子中我们定义了两个tensor,维度分别都是【3,4】,我们使用cat进行拼接,传入的维度是0,那么我们得到的结果就是会将两个tensor按照第一个维度进行拼接,可以理解为按行堆叠,把每一行想成一个样本,那么我们拼接后就会得到6个样本,维度变成【6,4】。

    二、torch.stack()

    第二种方法就是torch.stack()了,该方法也可以进行拼接,但是与cat有一些不同。

    对于传入的参数列表和torch.cat是一样的,但是stack指定的dim是一个新的维度,最终是在这个新的维度上进行拼接。

    a = torch.randn(3, 4)
    b = torch.randn(3, 4)
    
    c = torch.stack([a, b], dim=0)
    print(c)
    print(c.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
             [ 0.4375,  1.4254,  0.2875, -0.2420],
             [-0.9663, -1.8022, -1.2352,  0.7283]],
    
            [[-0.4226,  0.0375, -0.3861,  1.3939],
             [ 1.6275, -0.1319, -0.7143,  0.3624],
             [ 0.2245, -1.7482, -0.7933, -0.1008]]])
    torch.Size([2, 3, 4])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    上面我们指定拼接的dim为0,那么我们会新产生一个维度,得到结果【2,3,4】,原来两个tensor的维度不变,新生成一个维度2,代表拼接后维度。

    c = torch.stack([a, b], dim=1)
    print(c)
    print(c.shape)
    
    • 1
    • 2
    • 3
    tensor([[[ 0.1040, -0.3168, -1.3974, -1.2703],
             [-0.4226,  0.0375, -0.3861,  1.3939]],
    
            [[ 0.4375,  1.4254,  0.2875, -0.2420],
             [ 1.6275, -0.1319, -0.7143,  0.3624]],
    
            [[-0.9663, -1.8022, -1.2352,  0.7283],
             [ 0.2245, -1.7482, -0.7933, -0.1008]]])
    torch.Size([3, 2, 4])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    如果我们设置为1,那么就会新产生1一个维度在第二位,得到结果【3,2,4】。

  • 相关阅读:
    Java爬虫详解
    在线漫画app开发,更好地保证用户的个性化体验
    贪心算法(活动安排问题)
    springSecurity认证功能初体验
    微信小程序开发学习笔记《17》uni-app框架-tabBar
    ros中对move_base的调用
    算法刷题日志——回溯算法
    自己部署 Docker Kong
    【虚拟文件系统】文件系统 API 解读(1)
    英语学习笔记
  • 原文地址:https://blog.csdn.net/m0_47256162/article/details/127822995