• 张量的连续性、contiguous函数


            在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,“行优先”进行存储。

    tensor的连续性

             tensor连续(contiguous)是指tensor的storage元素排列顺序与其按行优先时的元素排列顺序相同。如下图所示:

            上图中,tensor b是tensor a经过转置而来的,即使用了 tensor.t() 方法。

            出现不连续现象,本质上是由于pytorch中不同tensor可能共用同一个storage导致的。
        pytorch的很多操作都会导致tensor不连续,如tensor.transpose()(tensor.t())、tensor.narrow()、tensor.expand()。
            以转置为例,因为转置操作前后共用同一个storage,但显然转置后的tensor按照行优先排列成1维后与原storage不同了,因此转置后结果属于不连续(见下例)。

    2. tensor.is_contiguous()

       is_contiguous直观的解释是Tensor底层一维数组元素的存储顺序与Tensor按行优先一维展开的元素顺序是否一致

            如果想要变得连续使用contiguous方法,如果Tensor不是连续的,则会重新开辟一块内存空间保证数据是在内存中是连续的;如果Tensor是连续的,则contiguous无操作。

            tensor.is_contiguous()用于判断tensor是否连续,以转置为例说明:

    1. import torch
    2. a = torch.tensor([[1, 2, 3],
    3. [4, 5, 6],
    4. [7, 8, 9]])
    5. print(a)
    6. print(a.storage())
    7. print(a.is_contiguous()) # a是连续的
    8. """
    9. tensor([[1, 2, 3],
    10. [4, 5, 6],
    11. [7, 8, 9]])
    12. 1
    13. 2
    14. 3
    15. 4
    16. 5
    17. 6
    18. 7
    19. 8
    20. 9
    21. [torch.LongStorage of size 9]
    22. True
    23. """
    24. b = a.t() # b是a的转置
    25. print(b)
    26. print(b.storage())
    27. print(b.is_contiguous()) # b是不连续的
    28. """
    29. tensor([[1, 4, 7],
    30. [2, 5, 8],
    31. [3, 6, 9]])
    32. 1
    33. 2
    34. 3
    35. 4
    36. 5
    37. 6
    38. 7
    39. 8
    40. 9
    41. [torch.LongStorage of size 9]
    42. False
    43. """

    3. tensor不连续的后果

            tensor不连续会导致某些操作无法进行,比如view()就无法进行。在上面的例子中:由于 b 是不连续的,所以对其进行view()操作会报错;b.view(3,3)没报错,因为b本身的shape就是(3,3)。

    1. print(b.view(3, 3))
    2. """
    3. tensor([[1, 4, 7],
    4. [2, 5, 8],
    5. [3, 6, 9]])
    6. """
    7. print(b.view(1, 9))# 报错
    8. print(b.view(-1))# 报错

     4. tensor.contiguous()


            tensor.contiguous()返回一个与原始tensor有相同元素的 “连续”tensor,如果原始tensor本身就是连续的,则返回原始tensor。
            注意:tensor.contiguous()函数不会对原始数据做任何修改,他不仅返回一个新tensor,还为这个新tensor创建了一个新的storage,在这个storage上,该新的tensor是连续的。
    继续使用上面的例子:

    1. c = b.contiguous()
    2. print(b)
    3. print(c)
    4. print(b.storage())
    5. print(c.storage())

     输出结果:

    1. # b
    2. tensor([[1, 4, 7],
    3. [2, 5, 8],
    4. [3, 6, 9]])
    5. # c
    6. tensor([[1, 4, 7],
    7. [2, 5, 8],
    8. [3, 6, 9]])
    9. # b.storage
    10. 1
    11. 2
    12. 3
    13. 4
    14. 5
    15. 6
    16. 7
    17. 8
    18. 9
    19. [torch.LongStorage of size 9]
    20. #c.storage
    21. 1
    22. 4
    23. 7
    24. 2
    25. 5
    26. 8
    27. 3
    28. 6
    29. 9
    30. [torch.LongStorage of size 9]

    接着运行如下代码: 

    1. print(b.is_contiguous()) # False
    2. print(c.is_contiguous()) # True
    3. print(c.view(1, 9)) # tensor([[1, 4, 7, 2, 5, 8, 3, 6, 9]])


    参考自:https://blog.csdn.net/baidu_41774120/article/details/128666944

  • 相关阅读:
    Android平台GB28181设备接入模块之SmartGBD
    【JavaSE】基础笔记 - 类和对象(下)
    Shell 实现文件基本操作(切割、排序、去重)
    Visual Studio 2022 版本 17.5 预览版 正式上线,有你期待的功能吗?
    大数据必学Java基础(七十七):线程的生命周期和常见方法
    Mybatis的collection三层嵌套查询(验证通过)
    【ViT(Vision Transformer)】(一) 中英双语
    数据库索引详解
    离线量化(后量化)算法研究-----脉络梳理
    怎么统计 20 亿用户的登录状态 | bitmap
  • 原文地址:https://blog.csdn.net/m0_48241022/article/details/132804698