• 一文轻松掌握深度学习框架中的einsum


    60f2071ef8dbc616ffd6859b4bd04270.png

    导语:本文主要介绍了如何理解 PyTorch 中的爱因斯坦求和 (einsum) ,并结合实际例子讲解和 PyTorch C++实现代码解读,希望读者看完本文后掌握 einsum 的基本用法。

    撰文|梁德澎

    原文首发于公众号GiantpandaCV

     

    1

    爱因斯坦求和约定

    爱因斯坦求和约定(einsum)提供了一套既简洁又优雅的规则,可实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练运用 einsum 可以很方便地实现复杂的张量操作,而且不容易出错。

    三条基本规则

    首先看下 einsum 实现矩阵乘法的例子:

    1. a = torch.rand(2,3)
    2. b = torch.rand(3,4)
    3. c = torch.einsum("ik,kj->ij", [a, b])
    4. # 等价操作 torch.mm(a, b)

    其中需要重点关注的是 einsum 的第一个参数 "ik,kj->ij",该字符串(下文以 equation 表示)表示了输入和输出张量的维度。equation 中的箭头左边表示输入张量,以逗号分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是26个英文字母 'a' - 'z'。

    而 einsum 的第二个参数表示实际的输入张量列表,其数量要与 equation 中的输入数量对应。同时对应每个张量的 子 equation 的字符个数要与张量的真实维度对应,比如 "ik,kj->ij" 表示输入和输出张量都是两维的。

    equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。

    接着介绍两个基本概念,自由索引(Free indices)和求和索引(Summation indices):

    • 自由索引,出现在箭头右边的索引,比如上面的例子就是 i 和 j;

    • 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k。

    接着是介绍三条基本规则:

    • 规则一:equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;

    • 规则二:只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;

    • 规则三:equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

    特殊规则

    特殊规则有两条:

    • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;

    • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:

    1. a = torch.randn(2,3,5,7,9)
    2. # i = 7, j = 9
    3. b = torch.einsum('...ij->...ji', [a])

    2

    实际例子解读

    接下来将展示13个具体的例子,在这些例子中会将 PyTorch einsum 与对应的 PyTorch 张量接口和 Python 简单的循环展开实现做对比,希望读者看完这些例子之后能轻松掌握 einsum 的基本用法。

    实验代码github链接:

    https://github.com/Ldpe2G/CodingForFun/tree/master/einsum_ex

    1.提取矩阵对角线元素

    1. import torch
    2. import numpy as np
    3. a = torch.arange(9).reshape(3, 3)
    4. # i = 3
    5. torch_ein_out = torch.einsum('ii->i', [a]).numpy()
    6. torch_org_out = torch.diagonal(a, 0).numpy()
    7. np_a = a.numpy()
    8. # 循环展开实现
    9. np_out = np.empty((3,), dtype=np.int32)
    10. # 自由索引外循环
    11. for i in range(0, 3):
    12. # 求和索引内循环
    13. # 这个例子并没有求和索引,
    14. # 所以相当于是1
    15. sum_result = 0
    16. for inner in range(0, 1):
    17. sum_result += np_a[i, i]
    18. np_out[i] = sum_result
    19. print("input:\n", np_a)
    20. print("torch ein out: \n", torch_ein_out)
    21. print("torch org out: \n", torch_org_out)
    22. print("numpy out: \n", np_out)
    23. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
    25. # 终端打印结果
    26. # input:
    27. # [[0 1 2]
    28. # [3 4 5]
    29. # [6 7 8]]
    30. # torch ein out:
    31. # [0 4 8]
    32. # torch org out:
    33. # [0 4 8]
    34. # numpy out:
    35. # [0 4 8]
    36. # is np_out == torch_ein_out ? True
    37. # is torch_org_out == torch_ein_out ? True

    2. 矩阵转置

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. # i = 2, j = 3
    5. torch_ein_out = torch.einsum('ij->ji', [a]).numpy()
    6. torch_org_out = torch.transpose(a, 0, 1).numpy()
    7. np_a = a.numpy()
    8. # 循环展开实现
    9. np_out = np.empty((3, 2), dtype=np.int32)
    10. # 自由索引外循环
    11. for j in range(0, 3):
    12. for i in range(0, 2):
    13. # 求和索引内循环
    14. # 这个例子并没有求和索引
    15. # 所以相当于是1
    16. sum_result = 0
    17. for inner in range(0, 1):
    18. sum_result += np_a[i, j]
    19. np_out[j, i] = sum_result
    20. print("input:\n", np_a)
    21. print("torch ein out: \n", torch_ein_out)
    22. print("torch org out: \n", torch_org_out)
    23. print("numpy out: \n", np_out)
    24. print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
    25. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))
    26. # 终端打印结果
    27. # input:
    28. # [[0 1 2]
    29. # [3 4 5]]
    30. # torch ein out:
    31. # [[0 3]
    32. # [1 4]
    33. # [2 5]]
    34. # torch org out:
    35. # [[0 3]
    36. # [1 4]
    37. # [2 5]]
    38. # numpy out:
    39. # [[0 3]
    40. # [1 4]
    41. # [2 5]]
    42. # is np_out == torch_org_out ? True
    43. # is torch_ein_out == torch_org_out ? True

    3. permute 高维张量转置

    1. import torch
    2. import numpy as np
    3. a = torch.randn(2,3,5,7,9)
    4. # i = 7, j = 9
    5. torch_ein_out = torch.einsum('...ij->...ji', [a]).numpy()
    6. torch_org_out = a.permute(0, 1, 2, 4, 3).numpy()
    7. np_a = a.numpy()
    8. # 循环展开实现
    9. np_out = np.empty((2,3,5,9,7), dtype=np.float32)
    10. # 自由索引外循环
    11. for j in range(0, 9):
    12. for i in range(0, 7):
    13. # 求和索引内循环
    14. # 这个例子没有求和索引
    15. sum_result = 0
    16. for inner in range(0, 1):
    17. sum_result += np_a[..., i, j]
    18. np_out[..., j, i] = sum_result
    19. print("is np_out == torch_org_out ?", np.allclose(torch_ein_out, np_out))
    20. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out))
    21. # 终端打印结果
    22. # is np_out == torch_org_out ? True
    23. # is torch_ein_out == torch_org_out ? True

    4. reduce sum

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. # i = 2, j = 3
    5. torch_ein_out = torch.einsum('ij->', [a]).numpy()
    6. torch_org_out = torch.sum(a).numpy()
    7. np_a = a.numpy()
    8. # 循环展开实现
    9. np_out = np.empty((1, ), dtype=np.int32)
    10. # 自由索引外循环
    11. # 这个例子中没有自由索引
    12. # 相当于所有维度都加一起
    13. for o in range(0 ,1):
    14. # 求和索引内循环
    15. # 这个例子中,i 和 j
    16. # 都是求和索引
    17. sum_result = 0
    18. for i in range(0, 2):
    19. for j in range(0, 3):
    20. sum_result += np_a[i, j]
    21. np_out[o] = sum_result
    22. print("input:\n", np_a)
    23. print("torch ein out: \n", torch_ein_out)
    24. print("torch org out: \n", torch_org_out)
    25. print("numpy out: \n", np_out)
    26. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    27. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
    28. # 终端打印结果
    29. # input:
    30. # [[0 1 2]
    31. # [3 4 5]]
    32. # torch ein out:
    33. # 15
    34. # torch org out:
    35. # 15
    36. # numpy out:
    37. # [15]
    38. # is np_out == torch_ein_out ? True
    39. # is torch_org_out == torch_ein_out ? True

    5.矩阵按列求和

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. # i = 2, j = 3
    5. torch_ein_out = torch.einsum('ij->j', [a]).numpy()
    6. torch_org_out = torch.sum(a, dim=0).numpy()
    7. np_a = a.numpy()
    8. # 循环展开实现
    9. np_out = np.empty((3, ), dtype=np.int32)
    10. # 自由索引外循环
    11. # 这个例子中是 j
    12. for j in range(0, 3):
    13. # 求和索引内循环
    14. # 这个例子中是 i
    15. sum_result = 0
    16. for i in range(0, 2):
    17. sum_result += np_a[i, j]
    18. np_out[j] = sum_result
    19. print("input:\n", np_a)
    20. print("torch ein out: \n", torch_ein_out)
    21. print("torch org out: \n", torch_org_out)
    22. print("numpy out: \n", np_out)
    23. print("is np_out == torch_ein_out ?", np.allclose(torch_org_out, np_out))
    24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    25. # 终端打印输出
    26. # input:
    27. # [[0 1 2]
    28. # [3 4 5]]
    29. # torch ein out:
    30. # [3 5 7]
    31. # torch org out:
    32. # [3 5 7]
    33. # numpy out:
    34. # [3 5 7]
    35. # is np_out == torch_ein_out ? True
    36. # is torch_org_out == torch_ein_out ? True

    6. 矩阵向量乘法

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. b = torch.arange(3)
    5. # i = 2, k = 3
    6. torch_ein_out = torch.einsum('ik,k->i', [a, b]).numpy()
    7. # 等价形式,可以省略箭头和输出
    8. torch_ein_out2 = torch.einsum('ik,k', [a, b]).numpy()
    9. torch_org_out = torch.mv(a, b).numpy()
    10. np_a = a.numpy()
    11. np_b = b.numpy()
    12. # 循环展开实现
    13. np_out = np.empty((2, ), dtype=np.int32)
    14. # 自由索引外循环
    15. # 这个例子是 i
    16. for i in range(0, 2):
    17. # 求和索引内循环
    18. # 这个例子中是 k
    19. sum_result = 0
    20. for k in range(0, 3):
    21. sum_result += np_a[i, k] * np_b[k]
    22. np_out[i] = sum_result
    23. print("matrix a:\n", np_a)
    24. print("vector b:\n", np_b)
    25. print("torch ein out: \n", torch_ein_out)
    26. print("torch ein out2: \n", torch_ein_out2)
    27. print("torch org out: \n", torch_org_out)
    28. print("numpy out: \n", np_out)
    29. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    30. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
    31. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    32. # 终端打印输出
    33. # matrix a:
    34. # [[0 1 2]
    35. # [3 4 5]]
    36. # vector b:
    37. # [0 1 2]
    38. # torch ein out:
    39. # [ 5 14]
    40. # torch ein out2:
    41. # [ 5 14]
    42. # torch org out:
    43. # [ 5 14]
    44. # numpy out:
    45. # [ 5 14]
    46. # is np_out == torch_ein_out ? True
    47. # is torch_ein_out2 == torch_ein_out ? True
    48. # is torch_org_out == torch_ein_out ? True

    7. 矩阵乘法

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. b = torch.arange(15).reshape(3, 5)
    5. # i = 2, k = 3, j = 5
    6. torch_ein_out = torch.einsum('ik,kj->ij', [a, b]).numpy()
    7. # 等价形式,可以省略箭头和输出
    8. torch_ein_out2 = torch.einsum('ik,kj', [a, b]).numpy()
    9. torch_org_out = torch.mm(a, b).numpy()
    10. np_a = a.numpy()
    11. np_b = b.numpy()
    12. # 循环展开实现
    13. np_out = np.empty((2, 5), dtype=np.int32)
    14. # 自由索引外循环
    15. # 这个例子是 i 和 j
    16. for i in range(0, 2):
    17. for j in range(0, 5):
    18. # 求和索引内循环
    19. # 这个例子是 k
    20. sum_result = 0
    21. for k in range(0, 3):
    22. sum_result += np_a[i, k] * np_b[k, j]
    23. np_out[i, j] = sum_result
    24. print("matrix a:\n", np_a)
    25. print("matrix b:\n", np_b)
    26. print("torch ein out: \n", torch_ein_out)
    27. print("torch ein out2: \n", torch_ein_out2)
    28. print("torch org out: \n", torch_org_out)
    29. print("numpy out: \n", np_out)
    30. print("is numpy == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
    32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    33. # 终端打印输出
    34. # matrix a:
    35. # [[0 1 2]
    36. # [3 4 5]]
    37. # matrix b:
    38. # [[ 0 1 2 3 4]
    39. # [ 5 6 7 8 9]
    40. # [10 11 12 13 14]]
    41. # torch ein out:
    42. # [[ 25 28 31 34 37]
    43. # [ 70 82 94 106 118]]
    44. # torch ein out2:
    45. # [[ 25 28 31 34 37]
    46. # [ 70 82 94 106 118]]
    47. # torch org out:
    48. # [[ 25 28 31 34 37]
    49. # [ 70 82 94 106 118]]
    50. # numpy out:
    51. # [[ 25 28 31 34 37]
    52. # [ 70 82 94 106 118]]
    53. # is numpy == torch_ein_out ? True
    54. # is torch_ein_out2 == torch_ein_out ? True
    55. # is torch_org_out == torch_ein_out ? True

    8. 向量内积

    1. import torch
    2. import numpy as np
    3. a = torch.arange(3)
    4. b = torch.arange(3, 6) # [3, 4, 5]
    5. # i = 3
    6. torch_ein_out = torch.einsum('i,i->', [a, b]).numpy()
    7. # 等价形式,可以省略箭头和输出
    8. torch_ein_out2 = torch.einsum('i,i', [a, b]).numpy()
    9. torch_org_out = torch.dot(a, b).numpy()
    10. np_a = a.numpy()
    11. np_b = b.numpy()
    12. # 循环展开实现
    13. np_out = np.empty((1, ), dtype=np.int32)
    14. # 自由索引外循环
    15. # 这个例子没有自由索引
    16. for o in range(0, 1):
    17. # 求和索引内循环
    18. # 这个例子是 i
    19. sum_result = 0
    20. for i in range(0, 3):
    21. sum_result += np_a[i] * np_b[i]
    22. np_out[o] = sum_result
    23. print("vector a:\n", np_a)
    24. print("vector b:\n", np_b)
    25. print("torch ein out: \n", torch_ein_out)
    26. print("torch ein out2: \n", torch_ein_out2)
    27. print("torch org out: \n", torch_org_out)
    28. print("numpy out: \n", np_out)
    29. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    30. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
    31. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    32. # 终端打印输出
    33. # vector a:
    34. # [0 1 2]
    35. # vector b:
    36. # [3 4 5]
    37. # torch ein out:
    38. # 14
    39. # torch ein out2:
    40. # 14
    41. # torch org out:
    42. # 14
    43. # numpy out:
    44. # [14]
    45. # is np_out == torch_ein_out ? True
    46. # is torch_ein_out2 == torch_ein_out ? True
    47. # is torch_org_out == torch_ein_out ? True

    9. 矩阵元素对应相乘并求reduce sum

    1. import torch
    2. import numpy as np
    3. a = torch.arange(6).reshape(2, 3)
    4. b = torch.arange(6,12).reshape(2, 3)
    5. # i = 2, j = 3
    6. torch_ein_out = torch.einsum('ij,ij->', [a, b]).numpy()
    7. # 等价形式,可以省略箭头和输出
    8. torch_ein_out2 = torch.einsum('ij,ij', [a, b]).numpy()
    9. torch_org_out = (a * b).sum().numpy()
    10. np_a = a.numpy()
    11. np_b = b.numpy()
    12. # 循环展开实现
    13. np_out = np.empty((1, ), dtype=np.int32)
    14. # 自由索引外循环
    15. # 这个例子没有自由索引
    16. for o in range(0, 1):
    17. # 求和索引内循环
    18. # 这个例子是 i 和 j
    19. sum_result = 0
    20. for i in range(0, 2):
    21. for j in range(0, 3):
    22. sum_result += np_a[i,j] * np_b[i,j]
    23. np_out[o] = sum_result
    24. print("matrix a:\n", np_a)
    25. print("matrix b:\n", np_b)
    26. print("torch ein out: \n", torch_ein_out)
    27. print("torch ein out2: \n", torch_ein_out2)
    28. print("torch org out: \n", torch_org_out)
    29. print("numpy out: \n", np_out)
    30. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
    32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    33. # 终端打印输出
    34. # matrix a:
    35. # [[0 1 2]
    36. # [3 4 5]]
    37. # matrix b:
    38. # [[ 6 7 8]
    39. # [ 9 10 11]]
    40. # torch ein out:
    41. # 145
    42. # torch ein out2:
    43. # 145
    44. # torch org out:
    45. # 145
    46. # numpy out:
    47. # [145]
    48. # is np_out == torch_ein_out ? True
    49. # is torch_ein_out2 == torch_ein_out ? True
    50. # is torch_org_out == torch_ein_out ? True

    10. 向量外积

    1. import torch
    2. import numpy as np
    3. a = torch.arange(3)
    4. b = torch.arange(3,7) # [3, 4, 5, 6]
    5. # i = 3, j = 4
    6. torch_ein_out = torch.einsum('i,j->ij', [a, b]).numpy()
    7. # 等价形式,可以省略箭头和输出
    8. torch_ein_out2 = torch.einsum('i,j', [a, b]).numpy()
    9. torch_org_out = torch.outer(a, b).numpy()
    10. np_a = a.numpy()
    11. np_b = b.numpy()
    12. # 循环展开实现
    13. np_out = np.empty((3, 4), dtype=np.int32)
    14. # 自由索引外循环
    15. # 这个例子是 i 和 j
    16. for i in range(0, 3):
    17. for j in range(0, 4):
    18. # 求和索引内循环
    19. # 这个例子没有求和索引
    20. sum_result = 0
    21. for inner in range(0, 1):
    22. sum_result += np_a[i] * np_b[j]
    23. np_out[i, j] = sum_result
    24. print("vector a:\n", np_a)
    25. print("vector b:\n", np_b)
    26. print("torch ein out: \n", torch_ein_out)
    27. print("torch ein out2: \n", torch_ein_out2)
    28. print("torch org out: \n", torch_org_out)
    29. print("numpy out: \n", np_out)
    30. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    31. print("is torch_ein_out2 == torch_ein_out ?", np.allclose(torch_ein_out2, torch_ein_out))
    32. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_org_out, torch_ein_out))
    33. # 终端打印输出
    34. # vector a:
    35. # [0 1 2]
    36. # vector b:
    37. # [3 4 5 6]
    38. # torch ein out:
    39. # [[ 0 0 0 0]
    40. # [ 3 4 5 6]
    41. # [ 6 8 10 12]]
    42. # torch ein out2:
    43. # [[ 0 0 0 0]
    44. # [ 3 4 5 6]
    45. # [ 6 8 10 12]]
    46. # torch org out:
    47. # [[ 0 0 0 0]
    48. # [ 3 4 5 6]
    49. # [ 6 8 10 12]]
    50. # numpy out:
    51. # [[ 0 0 0 0]
    52. # [ 3 4 5 6]
    53. # [ 6 8 10 12]]
    54. # is np_out == torch_ein_out ? True
    55. # is torch_ein_out2 == torch_ein_out ? True
    56. # is torch_org_out == torch_ein_out ? True

    11. batch 矩阵乘法

    1. import torch
    2. import numpy as np
    3. a = torch.randn(2,3,5)
    4. b = torch.randn(2,5,4)
    5. # i = 2, j = 3, k = 5, l = 4
    6. torch_ein_out = torch.einsum('ijk,ikl->ijl', [a, b]).numpy()
    7. torch_org_out = torch.bmm(a, b).numpy()
    8. np_a = a.numpy()
    9. np_b = b.numpy()
    10. # 循环展开实现
    11. np_out = np.empty((2, 3, 4), dtype=np.float32)
    12. # 自由索引外循环
    13. # 这个例子是 i,j和l
    14. for i in range(0, 2):
    15. for j in range(0, 3):
    16. for l in range(0, 4):
    17. # 求和索引内循环
    18. # 这个例子是 k
    19. sum_result = 0
    20. for k in range(0, 5):
    21. sum_result += np_a[i, j, k] * np_b[i, k, l]
    22. np_out[i, j, l] = sum_result
    23. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    24. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
    25. # 终端打印输出
    26. # is np_out == torch_ein_out ? True
    27. # is torch_org_out == torch_ein_out ? True

    12. 张量收缩(tensor contraction)

    1. import torch
    2. import numpy as np
    3. a = torch.randn(2,3,5,7)
    4. b = torch.randn(11,13,3,17,5)
    5. # p = 2, q = 3, r = 5, s = 7
    6. # t = 11, u = 13, v = 17, r = 5
    7. torch_ein_out = torch.einsum('pqrs,tuqvr->pstuv', [a, b]).numpy()
    8. torch_org_out = torch.tensordot(a, b, dims=([1, 2], [2, 4])).numpy()
    9. np_a = a.numpy()
    10. np_b = b.numpy()
    11. # 循环展开实现
    12. np_out = np.empty((2, 7, 11, 13, 17), dtype=np.float32)
    13. # 自由索引外循环
    14. # 这里就是 p,s,t,u和v
    15. for p in range(0, 2):
    16. for s in range(0, 7):
    17. for t in range(0, 11):
    18. for u in range(0, 13):
    19. for v in range(0, 17):
    20. # 求和索引内循环
    21. # 这里是 q和r
    22. sum_result = 0
    23. for q in range(0, 3):
    24. for r in range(0, 5):
    25. sum_result += np_a[p, q, r, s] * np_b[t, u, q, v, r]
    26. np_out[p, s, t, u, v] = sum_result
    27. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out, atol=1e-6))
    28. print("is torch_ein_out == torch_org_out ?", np.allclose(torch_ein_out, torch_org_out, atol=1e-6))
    29. # 终端打印输出
    30. # is np_out == torch_ein_out ? True
    31. # is torch_ein_out == torch_org_out ? True

    13. 二次变换(bilinear transformation)

    1. import torch
    2. import numpy as np
    3. a = torch.randn(2,3)
    4. b = torch.randn(5,3,7)
    5. c = torch.randn(2,7)
    6. # i = 2, k = 3, j = 5, l = 7
    7. torch_ein_out = torch.einsum('ik,jkl,il->ij', [a, b, c]).numpy()
    8. m = torch.nn.Bilinear(3, 7, 5, bias=False)
    9. m.weight.data = b
    10. torch_org_out = m(a, c).detach().numpy()
    11. np_a = a.numpy()
    12. np_b = b.numpy()
    13. np_c = c.numpy()
    14. # 循环展开实现
    15. np_out = np.empty((2, 5), dtype=np.float32)
    16. # 自由索引外循环
    17. # 这里是 i 和 j
    18. for i in range(0, 2):
    19. for j in range(0, 5):
    20. # 求和索引内循环
    21. # 这里是 k 和 l
    22. sum_result = 0
    23. for k in range(0, 3):
    24. for l in range(0, 7):
    25. sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
    26. np_out[i, j] = sum_result
    27. # print("matrix a:\n", np_a)
    28. # print("matrix b:\n", np_b)
    29. print("torch ein out: \n", torch_ein_out)
    30. print("torch org out: \n", torch_org_out)
    31. print("numpy out: \n", np_out)
    32. print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
    33. print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))
    34. # 终端打印输出
    35. # torch ein out:
    36. # [[-2.9185116 0.17024004 -0.43915534 1.5860008 10.016678 ]
    37. # [-0.48688257 -3.5114982 -0.7543343 -0.46790922 1.4816089 ]]
    38. # torch org out:
    39. # [[-2.9185116 0.17024004 -0.43915534 1.5860008 10.016678 ]
    40. # [-0.48688257 -3.5114982 -0.7543343 -0.46790922 1.4816089 ]]
    41. # numpy out:
    42. # [[-2.9185114 0.17023998 -0.4391551 1.5860008 10.016678 ]
    43. # [-0.4868826 -3.5114982 -0.7543342 -0.4679092 1.4816089 ]]
    44. # is np_out == torch_ein_out ? True
    45. # is torch_org_out == torch_ein_out ? True

    从上面的13个例子可以看出,只要确定了自由索引和求和索引,einsum 的输出计算都可以用一套比较通用的多层循来实现,外层的循环对应自由索引,内层循环对应求和索引。

    3

    PyTorch einsum 实现简要解读

    C++ 代码解读

    Github 代码链接: 

    https://github.com/pytorch/pytorch/blob/53596cdb7359116e8c8ae18ffef06f2677ad1296/aten/src/ATen/native/Linear.cpp#L148

    我只读懂了大概的实现思路,然后按照我自己的理解添加了注释(仅供参考):

    1. // 为了方便理解,我简化了大部分代码,
    2. // 并把对于 "..." 省略号的处理去掉了
    3. /**
    4. * 代码实现主要分为3大步:
    5. * 1. 解析 equation,分别得到输入和输出对应的字符串
    6. * 2. 补全输出和输入张量的维度,通过 permute 操作对齐输入和输出的维度
    7. * 3. 将维度对齐之后的输入张量相乘,然后根据求和索引累加
    8. */
    9. Tensor einsum(std::string equation, TensorList operands) {
    10. // ......
    11. // 把 equation 按照箭头分割
    12. // 得到箭头左边输入的部分
    13. const auto arrow_pos = equation.find("->");
    14. const auto lhs = equation.substr(0, arrow_pos);
    15. // 获取输入操作数个数
    16. const auto num_ops = operands.size();
    17. // 下面循环主要作用是解析 equation 左边输入部分,
    18. // 按 ',' 号分割得到每个输入张量对应的字符串,
    19. // 并把并把每个 char 字符转成 int, 范围 [0, 25]
    20. // 新建 vector 保存每个输入张量对应的字符数组
    21. std::vector<std::vector<int>> op_labels(num_ops);
    22. std::size_t curr_op = 0;
    23. for (auto i = decltype(lhs.length()){0}; i < lhs.length(); ++i) {
    24. switch (lhs[i]) {
    25. // ......
    26. case ',':
    27. // 遇到逗号,接下来解析下一个输入张量的字符串
    28. ++curr_op;
    29. // ......
    30. break;
    31. default:
    32. // ......
    33. // 把 char 字符转成 int
    34. op_labels[curr_op].push_back(lhs[i] - 'a');
    35. }
    36. }
    37. // TOTAL_LABELS = 26
    38. constexpr int TOTAL_LABELS = 'z' - 'a' + 1;
    39. std::vector<int> label_count(TOTAL_LABELS, 0);
    40. // 遍历所有输入操作数
    41. // 统计 equation 中 'a' - 'z' 每个字符的出现次数
    42. for(const auto i : c10::irange(num_ops)) {
    43. const auto labels = op_labels[i];
    44. for (const auto& label : labels) {
    45. // ......
    46. ++label_count[label];
    47. }
    48. // ......
    49. }
    50. // 创建一个 vector 用于保存 equation
    51. // 箭头右边输出的字符到索引的映射
    52. std::vector<int64_t> label_perm_index(TOTAL_LABELS, -1);
    53. int64_t perm_index = 0;
    54. // ......
    55. // 接下来解析输出字符串
    56. if (arrow_pos == std::string::npos) {
    57. // 处理用户省略了箭头的情况,
    58. // ......
    59. } else {
    60. // 一般情况
    61. // 得到箭头右边的输出
    62. const auto rhs = equation.substr(arrow_pos + 2);
    63. // 遍历输出字符串并解析
    64. for (auto i = decltype(rhs.length()){0}; i < rhs.length(); ++i) {
    65. switch (rhs[i]) {
    66. // ......
    67. default:
    68. // ......
    69. const auto label = rhs[i] - 'a';
    70. // ......
    71. // 建立字符到索引的映射,perm_index从0开始
    72. label_perm_index[label] = perm_index++;
    73. }
    74. }
    75. }
    76. // 保存原始的输出维度大小
    77. const int64_t out_size = perm_index;
    78. // 对齐输出张量的维度,使得对齐之后的维度等于
    79. // 自由索引加上求和索引的个数
    80. // 对输出补全省略掉的求和索引
    81. // 也就是在输入等式中出现,但是没有在输出等式中出现的字符
    82. for (const auto label : c10::irange(TOTAL_LABELS)) {
    83. if (label_count[label] > 0 && label_perm_index[label] == -1) {
    84. label_perm_index[label] = perm_index++;
    85. }
    86. }
    87. // 对所有输入张量,同样补齐维度至与输出维度大小相同
    88. // 最后对输入做 permute 操作,使得输入张量的每一维
    89. // 与输出张量的每一维能对上
    90. std::vector<Tensor> permuted_operands;
    91. for (const auto i: c10::irange(num_ops)) {
    92. // 保存输入张量最终做 permute 时候的维度映射
    93. std::vector<int64_t> perm_shape(perm_index, -1);
    94. Tensor operand = operands[i];
    95. // 取输入张量对应的 equation
    96. const auto labels = op_labels[i];
    97. std::size_t j = 0;
    98. for (const auto& label : labels) {
    99. // ......
    100. // 建立当前遍历到的输入张量字符到
    101. // 输出张量的字符到的映射
    102. // label: 当前遍历到的字符
    103. // label_perm_index: 保存了输出字符对应的索引
    104. // 所以 perm_shape 就是建立了输入张量的每一维度
    105. // 与输出张量维度的对应关系
    106. perm_shape[label_perm_index[label]] = j++;
    107. }
    108. // 如果输入张量的维度小于补全后的输出
    109. // 那么 perm_shape 中一定存在值为 -1 的元素
    110. // 那么相当于需要扩充输入张量的维度
    111. // 扩充的维度添加在张量的尾部
    112. for (int64_t& index : perm_shape) {
    113. if (index == -1) {
    114. // 在张量尾部插入维度1
    115. operand = operand.unsqueeze(-1);
    116. // 修改了perm_shape中的index,
    117. // 因为是引用取值
    118. index = j++;
    119. }
    120. }
    121. // 把输入张量的维度按照输出张量的维度重排,采用 permute 操作
    122. permuted_operands.push_back(operand.permute(perm_shape));
    123. }
    124. // ......
    125. Tensor result = permuted_operands[0];
    126. // .....
    127. // 计算最终结果
    128. for (const auto i: c10::irange(1, num_ops)) {
    129. Tensor operand = permuted_operands[i];
    130. // 新建 vector 用于保存求和索引
    131. std::vector<int64_t> sum_dims;
    132. // ......
    133. // 详细的代码可以阅读 PyTorch 源码
    134. // 这里我还没有完全理解 sumproduct_pair 的实现,
    135. // 里面用的是 permute + bmm,
    136. // 不过我觉得可以简单理解为
    137. // 将张量做广播乘法,再根据求和索引做累加
    138. result = sumproduct_pair(result, operand, sum_dims, false);
    139. }
    140. return result;
    141. }

    图解实现

    下面还是用矩阵乘法来说明C++的实现思路,下图展示的是矩阵乘法的通用实现:

    363dba31e33c37d8554ed65b9d6f3b14.png

    接下来展示C++的实现思路:

    94493a50cdcae76c3f6994ac9928f7c0.png

    4

    总结

    通过上面的实际例子和代码解读,可以看到 einsum 非常灵活,可以方便地实现各种常用的张量操作。希望读者通过这篇文章也可以轻松掌握 einsum 的基本用法。文中对于 PyTorch C++实现代码的解析是基于作者自己的理解,如果觉得有误或者不理解的地方欢迎讨论。

    参考资料

    1.https://www.youtube.com/watch?v=pkVwUVEHmfI&ab_channel=AladdinPersson

    2.https://rockt.github.io/2018/04/30/einsum

    3.https://ajcr.net/Basic-guide-to-einsum/

    4.https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/

    其他人都在看

    欢迎下载体验OneFlow新一代开源深度学习框架:GitHub - Oneflow-Inc/oneflow: OneFlow is a performance-centered and open-source deep learning framework.icon-default.png?t=M1L8https://github.com/Oneflow-Inc/oneflow/

  • 相关阅读:
    C++之函数模板、类模板、模板的特化
    winlicense官方版是一款功能专业强大的编程软件
    ubuntu18.04与windows文件互传
    PC_输入输出系统/设备_I/O系统(IO接口)基础
    浅谈 DLL 导出函数中的转发器函数
    漏刻有时数据可视化Echarts组件开发(26):全国地图三级热力图下钻和对接api自动调用数据开发实录
    编译 gtsam
    idea本地运行正常,打包部署以后openFeign调用失败,返回为null,以及报错406
    package.json 依赖版本中的符号含义
    【VSCode】文件模板创建及使用.md
  • 原文地址:https://blog.csdn.net/OneFlow_Official/article/details/123124430