• Pytorch,矩阵求和维度变化解析


    二维可以想象成一张纸,
    三维可以想象成多张纸叠在一块
    四维可以想成多沓纸
    求和时,如果没设定keepdim=True,则会消去相加的那一维度,否则则将维度变为1

    A = torch.arange(20).reshape(5, 4)
    A,A.shape, A.sum()
    
    • 1
    • 2
    (tensor([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11],
             [12, 13, 14, 15],
             [16, 17, 18, 19]]),
     torch.Size([5, 4]),
     tensor(190))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    指定求和汇总张量的轴

    A_sum_axis0 = A.sum(axis=0)
    A_sum_axis0, A_sum_axis0.shape
    
    • 1
    • 2
    (tensor([40, 45, 50, 55]), torch.Size([4]))
    
    • 1
    A_sum_axis1 = A.sum(axis=1)
    A_sum_axis1, A_sum_axis1.shape
    
    • 1
    • 2
    (tensor([ 6, 22, 38, 54, 70]), torch.Size([5]))
    
    • 1
    # 等价于A.SUM()
    A.sum(axis=[0, 1]), A.sum(axis=[0, 1]).shape
    
    • 1
    • 2
    (tensor(190), torch.Size([]))
    
    • 1
    # 三维 测试
    
    SA = torch.arange(20 * 2).reshape(2, 5, 4)
    SA, SA.shape
    
    • 1
    • 2
    • 3
    • 4
    (tensor([[[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11],
              [12, 13, 14, 15],
              [16, 17, 18, 19]],
     
             [[20, 21, 22, 23],
              [24, 25, 26, 27],
              [28, 29, 30, 31],
              [32, 33, 34, 35],
              [36, 37, 38, 39]]]),
     torch.Size([2, 5, 4]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    SA_sum_axis0 = SA.sum(axis=0)
    SA_sum_axis0, SA_sum_axis0.shape
    
    • 1
    • 2
    (tensor([[20, 22, 24, 26],
             [28, 30, 32, 34],
             [36, 38, 40, 42],
             [44, 46, 48, 50],
             [52, 54, 56, 58]]),
     torch.Size([5, 4]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    SA_sum_axis1 = SA.sum(axis=1)
    SA_sum_axis1, SA_sum_axis1.shape
    
    • 1
    • 2
    (tensor([[ 40,  45,  50,  55],
             [140, 145, 150, 155]]),
     torch.Size([2, 4]))
    
    • 1
    • 2
    • 3
    SA_sum_axis2 = SA.sum(axis=2)
    SA_sum_axis2, SA_sum_axis2.shape
    
    • 1
    • 2
    (tensor([[  6,  22,  38,  54,  70],
             [ 86, 102, 118, 134, 150]]),
     torch.Size([2, 5]))
    
    • 1
    • 2
    • 3

    四维矩阵求和

    A = torch.arange(20*2*2).reshape((2,2,5,4))
    A, A.shape
    
    • 1
    • 2
    (tensor([[[[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11],
               [12, 13, 14, 15],
               [16, 17, 18, 19]],
     
              [[20, 21, 22, 23],
               [24, 25, 26, 27],
               [28, 29, 30, 31],
               [32, 33, 34, 35],
               [36, 37, 38, 39]]],
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11


    [[[40, 41, 42, 43],
    [44, 45, 46, 47],
    [48, 49, 50, 51],
    [52, 53, 54, 55],
    [56, 57, 58, 59]],

              [[60, 61, 62, 63],
               [64, 65, 66, 67],
               [68, 69, 70, 71],
               [72, 73, 74, 75],
               [76, 77, 78, 79]]]]),
     torch.Size([2, 2, 5, 4]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    A_sum_axis0 = A.sum(axis=0)
    A_sum_axis0, A_sum_axis0.shape
    
    • 1
    • 2
    (tensor([[[ 40,  42,  44,  46],
              [ 48,  50,  52,  54],
              [ 56,  58,  60,  62],
              [ 64,  66,  68,  70],
              [ 72,  74,  76,  78]],
     
             [[ 80,  82,  84,  86],
              [ 88,  90,  92,  94],
              [ 96,  98, 100, 102],
              [104, 106, 108, 110],
              [112, 114, 116, 118]]]),
     torch.Size([2, 5, 4]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    A_sum_axis1 = A.sum(axis=1)
    A_sum_axis1, A_sum_axis1.shape
    
    • 1
    • 2
    (tensor([[[ 20,  22,  24,  26],
              [ 28,  30,  32,  34],
              [ 36,  38,  40,  42],
              [ 44,  46,  48,  50],
              [ 52,  54,  56,  58]],
     
             [[100, 102, 104, 106],
              [108, 110, 112, 114],
              [116, 118, 120, 122],
              [124, 126, 128, 130],
              [132, 134, 136, 138]]]),
     torch.Size([2, 5, 4]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
  • 相关阅读:
    web网页设计期末课程大作业:家乡旅游主题网站设计——河北8页HTML+CSS+JavaScript
    chrome浏览器关闭百度热搜——AdBlock插件
    vue2 elementui 封装一个动态表单复杂组件
    大模型的交互能力
    C++中的string类
    初识JVM(简单易懂),解开JVM神秘的面纱
    Unity实现简单的对象池
    好用移动APP自动化测试框架哪里找?收藏这份清单就好了!
    uni-app 瀑布流布局的实现
    创建-查看-使用-数据库
  • 原文地址:https://blog.csdn.net/BruceBorgia/article/details/133901167