• pytorch view()和reshape() 详解


    前言

    如果没有时间看下去,这里直接告诉你结论:

    • 两者都是用来重塑tensor的shape的。

    • view只适合对满足连续性条件(contiguous)的tensor进行操作,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。

    • reshape对适合对满足连续性条件(contiguous)的tensor进行操作返回值是视图,否则返回副本(此时等价于先调用contiguous()方法在使用view())

    • 考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view

    • view能干的reshape都能干 如果只是重塑一个tensor的shape 那就无脑选择reshape

    pytorch Tensor 介绍

    想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的头信息区(Tensor)和存储区 (Storage)以及tensor的步长Stride

    Tensor 文档链接

    Tensor 存储结构介绍

    tensor数据采用头信息区(Tensor)和存储区 (Storage)分开存储的形式,如图1所示。变量名以及其存储的数据是分为两个区域分别存储的。比如,我们定义并初始化一个tensor,tensor名为A,A的形状size、步长stride、数据的索引等信息都存储在头信息区,而A所存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或修改等操作后赋值给B,则B的数据共享A的存储区,存储区的数据数量没变,变化的只是B的头信息区对数据的索引方式。如果听说过浅拷贝和深拷贝的话,很容易明白这种方式其实就是浅拷贝。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N2OUa60I-1661605448596)(https://note.youdao.com/yws/res/4/WEBRESOURCEc24383dc6161107a6217f220c1813a44)]

    代码示例如下:

    import torch
    a = torch.arange(5)  # 初始化张量 a 为 [0, 1, 2, 3, 4]
    b = a[2:]            # 截取张量a的部分值并赋值给b,b其实只是改变了a对数据的索引方式
    print('a:', a)
    print('b:', b)
    print('ptr of storage of a:', a.storage().data_ptr())  # 打印a的存储区地址
    print('ptr of storage of b:', b.storage().data_ptr())  # 打印b的存储区地址,可以发现两者是共用存储区
     
    print('==================================================================')
     
    b[1] = 0    # 修改b中索引为1,即a中索引为3的数据为0
    print('a:', a)
    print('b:', b)
    print('ptr of storage of a:', a.storage().data_ptr())  # 打印a的存储区地址,可以发现a的相应位置的值也跟着改变,说明两者是共用存储区
    print('ptr of storage of b:', b.storage().data_ptr())  # 打印b的存储区地址
     
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d5robjjk-1661605448598)(https://note.youdao.com/yws/res/c/WEBRESOURCEd715646616ddbd9ab2e857e3934bea1c)]

    Tensor的步长(stride)属性

    torch的tensor也是有步长属性的,说起stride属性是不是很耳熟?是的,卷积神经网络中卷积核对特征图的卷积操作也是有stride属性的,但这两个stride可完全不是一个意思哦。tensor的步长可以理解为从索引中的一个维度跨到下一个维度中间的跨度

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2RVZKAQT-1661605448599)(https://note.youdao.com/yws/res/7/WEBRESOURCE9995364b2e5d8cf3a377a8952e9b68b7)]

    我们看下如下例子:

    
    import torch
    a = torch.arange(6).reshape(2, 3)  # 初始化张量 a
    b = torch.arange(6).view(3, 2)     # 初始化张量 b
    print('a:', a)
    print('stride of a:', a.stride())  # 打印a的stride
    print('b:', b)
    print('stride of b:', b.stride())  # 打印b的stride
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Z9CgkX3S-1661605448600)(https://note.youdao.com/yws/res/8/WEBRESOURCEa4349796e7e709732c62fb42c6635888)]

    Tensor View 理解

    参考链接

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yJJRFyYh-1661605448601)(https://note.youdao.com/yws/res/d/WEBRESOURCE71d76e815ae8727f47ab028bfa1263ad)]

    大致意思是:

    返回的张量共享相同的数据,并且必须具有相同数量的元素,但可能具有不同的大小。对于要查看的张量,新视图大小必须与其原始大小和步幅兼容,即每个新视图维度必须是原始维度的子空间,或者满足以下连续条件:
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g3br8pUJ-1661605448602)(https://note.youdao.com/yws/res/8/WEBRESOURCEece0d25023a8856d1343a7c255a9c8c8)]

    否则需要先使用contiguous()方法将原始tensor转换为满足连续条件的tensor,然后就可以使用view方法进行shape变换了。或者直接使用reshape方法进行维度变换,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

    如何理解tensor是否满足连续条件呐?下面通过一系列例子来慢慢理解下

    查看tensor的stride、size属性
    如下例子:
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4ylVLZbx-1661605448603)(https://note.youdao.com/yws/res/4/WEBRESOURCEf56c32ec5f8974b6a2656f7a506cb534)]

    我们可以看到结果是满足连续性的 stride[0] = 3 = 1X3

    下面我们看看不满足连续性的例子:

    import torch
    a = torch.arange(9).reshape(3, 3)     # 初始化张量a
    b = a.permute(1, 0)  # 对a进行转置
    print('struct of b:\n', b)
    print('size   of b:', b.size())    # 查看b的shape
    print('stride of b:', b.stride())  # 查看b的stride
     
    '''   运行结果   '''
    struct of b:
    tensor([[0, 3, 6],
            [1, 4, 7],
            [2, 5, 8]])
    size   of b: torch.Size([3, 3])
    stride of b: (1, 3)   # 注:此时不满足连续性条件
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    输出a和b的存储区来看一下有没有什么不同:

    import torch
    a = torch.arange(9).reshape(3, 3)             # 初始化张量a
    print('ptr of storage of a: ', a.storage().data_ptr())  # 查看a的storage区的地址
    print('storage of a: \n', a.storage())        # 查看a的storage区的数据存放形式
    b = a.permute(1, 0)                           # 转置
    print('ptr of storage of b: ', b.storage().data_ptr())  # 查看b的storage区的地址
    print('storage of b: \n', b.storage())        # 查看b的storage区的数据存放形式
     
    '''   运行结果   '''
    ptr of storage of a:  1899603060672
    storage of a: 
      0
     1
     2
     3
     4
     5
     6
     7
     8
    [torch.LongStorage of size 9]
    ptr of storage of b:  1899603060672
    storage of b: 
      0
     1
     2
     3
     4
     5
     6
     7
     8
    [torch.LongStorage of size 9]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    由结果可以看出,张量a、b仍然共用存储区,并且存储区数据存放的顺序没有变化,这也充分说明了b与a共用存储区,b只是改变了数据的索引方式。那么为什么b就不符合连续性条件了呐(T-T)?其实原因很简单,我们结合图3来解释下:
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6SrlS5US-1661605448603)(https://note.youdao.com/yws/res/2/WEBRESOURCE1ae714a0b792817b02987eb595d77e02)]

    Torch.reshape

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SG9BbXZu-1661605448604)(https://note.youdao.com/yws/res/c/WEBRESOURCE3162ae9c77fe9339b2b6cd326a21e5cc)]

    作用:与view方法类似,将输入tensor转换为新的shape格式。
    但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()。
    即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同

    参考文章和链接

    https://blog.csdn.net/Flag_ing/article/details/109129752
    https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
    https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch

  • 相关阅读:
    CUDA编程入门系列(二) GPU硬件架构综述
    软体机器人与拓扑优化
    Kotlin编程实战——开始(03)
    【Linux入门指北】Linux磁盘扩容
    第15届蓝桥杯题解
    【C/C++】C语言太细了
    element中的el-input只能输入数字,输入其他内容清空
    PHP Hyperf框架 RPC调用内存泄露
    欢乐钓鱼大师一键钓鱼,解放双手!
    C代码实现循环队列
  • 原文地址:https://blog.csdn.net/BXD1314/article/details/126562501