• Pytorch 中 tensor的维度拼接


    torch.stack() 和 torch.cat() 都可以按照指定的维度进行拼接,但是两者也有区别,torch.satck() 是增加新的维度进行堆叠,即其维度拼接后会增加一个维度;而torch.cat() 是在原维度上进行堆叠,即其维度拼接后的维度个数和原来一致。具体说明如下:

    torch.stack(input,dim)

    input: 待拼接的张量序列组(list or tuple),拼接的tensor的维度必须要相等,即tensor1.shape = tensor2.shape

    dim: 在哪个新增的维度上进行拼接,不能超过拼接后的张量数据的维度大小,默认为 0

    复制代码
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    er-hljs
    import torch x1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) x2 = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) print(torch.stack((x1,x2),dim=0).shape) print(torch.stack((x1,x2),dim=1).shape) print(torch.stack((x1,x2),dim=2).shape) print(torch.stack((x1,x2),dim=0)) print(torch.stack((x1,x2),dim=1)) print(torch.stack((x1,x2),dim=2)) >> torch.Size([2, 3, 3]) # 2 表示是有两个tensor的拼接,且在第一个维度的位置拼接 >> torch.Size([3, 2, 3]) >> torch.Size([3, 3, 2]) >> tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) >> tensor([[[ 1, 2, 3], [10, 20, 30]], [[ 4, 5, 6], [40, 50, 60]], [[ 7, 8, 9], [70, 80, 90]]]) >> tensor([[[ 1, 10], [ 2, 20], [ 3, 30]], [[ 4, 40], [ 5, 50], [ 6, 60]], [[ 7, 70], [ 8, 80], [ 9, 90]]])
    折叠

    torch.cat(input, dim)

    input: 待拼接的张量序列组(list or tuple),拼接的tensor的维度必须要相等,即tensor1.shape = tensor2.shape

    dim: 在哪个已存在的维度上进行拼接,不能超过拼接后的张量数据的维度大小(即原来的维度大小),默认为 0

    复制代码
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    er-hljs
    import torch x1 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) x2 = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]) print(torch.cat((x1,x2),dim=0).shape) print(torch.cat((x1,x2),dim=1).shape) print(torch.cat((x1,x2),dim=0)) print(torch.cat((x1,x2),dim=1)) >> torch.Size([6, 3]) >> torch.Size([3, 6]) >> tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 20, 30], [40, 50, 60], [70, 80, 90]]) >> tensor([[ 1, 2, 3, 10, 20, 30], [ 4, 5, 6, 40, 50, 60], [ 7, 8, 9, 70, 80, 90]])
    折叠
  • 相关阅读:
    云计算平台建设总体技术方案参考
    测试开发春招
    前后端分离项目,vue+uni-app+php+mysql订座预约小程序系统设计与实现
    数据结构:平衡二叉树
    一张逻辑图讲清楚OS在做什么:浅谈OS
    第三方库并不是必须的
    数据结构与算法3-数组
    面试官:今天要不来聊聊Redis基础吧?
    搭建极简GB28181 网守和网关服务器,建立AI推理和3d服务场景,然后开源代码(一)
    【C语言】程序环境深度剖析
  • 原文地址:https://www.cnblogs.com/jack-nie-23/p/16479560.html