• PyTorch 中张量运算广播


    TLDR

    右对齐,空补一,从左往右依维运算
    [m] + [x, y] = [m +x, m + y]

    正文

    以如下 a b 两个 tensor 计算为例

    a = torch.tensor([
        [1],
        [2],
        [3],
    ])
    b = torch.tensor([
        [
            [1, 2, 3],
        ],
        [
            [4, 5, 6],
        ],
        [
            [7, 8, 9],
        ],
    ])
    # a.shape = (3, 1)
    # b.shape = (3, 1, 3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    首先将两个 tensor 的 shape 右对齐
    a( , 3, 1)
    b(3, 1, 3)

    判断两个 tensor 是否满足广播规则

    • tensor 至少有一个维度(比如 torch.tensor((0,)) 便不符合本要求)
    • 检查上一步对齐的 tensor shape,要求两个 tensor 对应维度的大小:要么相同;要么其中一个为 1;要么其中一个为空
    • 如果满足上述规则,则继续,否则报错

    将对齐后空缺的维度设置为 1
    a(1, 3, 1)
    b(3, 1, 3)
    其实就是对 a 进行了扩维,此时两个 tensor 为:

    a = torch.tensor([
        [
            [1],
            [2],
            [3],
        ],
    ])
    b = torch.tensor([
        [
            [1, 2, 3],
        ],
        [
            [4, 5, 6],
        ],
        [
            [7, 8, 9],
        ],
    ])
    # a.shape = (1, 3, 1)
    # b.shape = (3, 1, 3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    从左往右对两个 tensor 的每一个维度进行运算,按照以下规则

    • 如果大小相同,则直接进行运算即可(一一对应)
    • 如果其中一个大小为 1,则使用这个元素与另一个 tensor 当前维度下的每个元素进行运算(本质是一个递归操作)

    例如计算 a + b (这两个 tensor 已经经过上述步骤处理,即维度已经相同)

    # 1. 因为 a.shape[0] == 1,所以将 a[0] 分别与 b[0]、b[1]、b[2] 相加
    [
    	a[0] + b[0],
    	a[0] + b[1],
    	a[0] + b[2],
    ]
    
    # 2. 接下来继续往后计算,以 a[0] + b[0] 为例
    #    因为 a[0].shape[0] = 3, b[0].shape[0] = 1,
    #    所以将 b[0][0] 分别与 a[0][0]、a[0][1]、a[0][2] 相加
    [
    	[	# a[0] + b[0]
    		a[0][0] + b[0][0],
    		a[0][1] + b[0][0],
    		a[0][2] + b[0][0],
    	],
    	[	# a[0] + b[1]
    		a[0][0] + b[1][0],
    		a[0][1] + b[1][0],
    		a[0][2] + b[1][0],
    	],
    	[	# a[0] + b[2]
    		a[0][0] + b[2][0],
    		a[0][1] + b[2][0],
    		a[0][2] + b[2][0],
    	],
    ]
    
    # 3. 继续往后计算,以 a[0][0] + b[0][0] 为例
    #    因为 a[0][0].shape[0] == 1,
    #    所以将 a[0][0][0] 分别与 b[0][0][0]、b[0][0][1]、b[0][0][2] 相加
    [
    	[	# a[0] + b[0]
    		[ 	# a[0][0] + b[0][0]
    			a[0][0][0] + b[0][0][0],
    			a[0][0][0] + b[0][0][1],
    			a[0][0][0] + b[0][0][2],
    		],
    		[ 	# a[0][1] + b[0][0]
    			a[0][1][0] + b[0][0][0],
    			a[0][1][0] + b[0][0][1],
    			a[0][1][0] + b[0][0][2],
    		],
    		[ 	# a[0][2] + b[0][0]
    			a[0][2][0] + b[0][0][0],
    			a[0][2][0] + b[0][0][1],
    			a[0][2][0] + b[0][0][2],
    		],
    	],
    	[	# a[0] + b[1]
    		[ 	# a[0][0] + b[1][0]
    			a[0][0][0] + b[1][0][0],
    			a[0][0][0] + b[1][0][1],
    			a[0][0][0] + b[1][0][2],
    		],
    		[ 	# a[0][1] + b[1][0]
    			a[0][1][0] + b[1][0][0],
    			a[0][1][0] + b[1][0][1],
    			a[0][1][0] + b[1][0][2],
    		],
    		[ 	# a[0][2] + b[1][0]
    			a[0][2][0] + b[1][0][0],
    			a[0][2][0] + b[1][0][1],
    			a[0][2][0] + b[1][0][2],
    		],
    	],
    	[	# a[0] + b[2]
    		[ 	# a[0][0] + b[2][0]
    			a[0][0][0] + b[2][0][0],
    			a[0][0][0] + b[2][0][1],
    			a[0][0][0] + b[2][0][2],
    		],
    		[ 	# a[0][1] + b[2][0]
    			a[0][1][0] + b[2][0][0],
    			a[0][1][0] + b[2][0][1],
    			a[0][1][0] + b[2][0][2],
    		],
    		[ 	# a[0][2] + b[2][0]
    			a[0][2][0] + b[2][0][0],
    			a[0][2][0] + b[2][0][1],
    			a[0][2][0] + b[2][0][2],
    		],
    	],
    ]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84

    总结

    右对齐空补一,从左往右依维递归)运算。
    一个 tensor 的某个维度大小为 1 时的计算规则:[1] + [2, 3, 4] = [1 + 2, 1 + 3, 1 + 4]

    《PyTorch 官方文档:BROADCASTING SEMANTICS》

  • 相关阅读:
    2.06_python+Django+mysql实现pdf转word项目_项目开发-创建模型
    php沿河农产品特卖网站的设计与实现毕业设计源码201521
    某电商网站的数据库设计(3)
    zookeeper、Dubbo
    python+vue+elementui心理健康测试教育系统django339
    python进阶(29)单例模式
    RabbitMQ死信队列、延时队列
    stm32f334定时器配置详细解释
    Vuex源码解析
    如何入驻抖音本地生活服务商,附上便捷流程!
  • 原文地址:https://blog.csdn.net/gd920129/article/details/133821512