• Pytorch中 permute、transpose 和 view 、resize函数


    1、transpose与permute

    transpose() 和 permute() 都是返回转置后矩阵,在pytorch中转置用的函数就只有这两个 ,这两个函数都是交换维度的操作

    transpose用法:tensor.transpose(dim0, dim1) → Tensor
    只能操作2D矩阵的转置, transpose每次只能交换两个维度, 这是相比于permute的一个不同点,每次输入两个index,实现转置,,参数顺序无所谓。
    permute用法:tensor.permute(dim0, dim1, ...., dimn)
    permute可以进行多维度转置, permute每次可以交换多个维度,且必须传入所有维度数,参数顺序表示交换结果是原值的哪个维。

    permute操作可以有1至多步的Transpose操作实现

    注意:使用transpose或permute之后,若要使用view,必须先contiguous()

    1. # 创造二维数据x,dim=0时候2,dim=1时候3
    2. x = torch.randn(2,3) 'x.shape → [2,3]'
    3. # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
    4. y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
    5. """
    6. 操作dim不同:
    7. transpose()只能一次操作两个维度;permute()可以一次操作多维数据,
    8. 且必须传入所有维度数,因为permute()的参数是int*。
    9. """
    10. # 对于transpose
    11. x.transpose(0,1) 'shape→[3,2] '
    12. x.transpose(1,0) 'shape→[3,2] '
    13. y.transpose(0,1) 'shape→[3,2,4]'
    14. y.transpose(0,2,1) 'error,操作不了多维'
    15. # 对于permute()
    16. x.permute(0,1) 'shape→[2,3]'
    17. x.permute(1,0) 'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
    18. y.permute(0,1) "error 没有传入所有维度数"
    19. y.permute(1,0,2) 'shape→[3,2,4]'
    20. """
    21. 操作dim不同:
    22. transpose()只能一次操作两个维度, 维度的顺序不影响结果;permute()可以一次操作多维数据,
    23. 且必须传入所有维度数,因为permute()的参数是int*。
    24. """
    25. # 对于transpose, (0,1) 和 (1,0) 都是指变换 维度 0 和 1, 输入顺序不影响
    26. x1 = x.transpose(0,1) 'shape→[3,2] '
    27. x2 = x.transpose(1,0) '也变换了,shape→[3,2] '
    28. # 对于permute(),
    29. x1 = x.permute(0,1) '保持原理tensor不变, 不同transpose,shape→[2,3] '
    30. x2 = x.permute(1,0) 'shape→[3,2] '
    31. y1 = y.permute(0,1,2) '保持不变,shape→[2,3,4] '
    32. y2 = y.permute(1,0,2) 'shape→[3,2,4] '
    33. y3 = y.permute(1,2,0) 'shape→[3,4,2] '

    2、关于连续contiguous()

    用view()函数改变通过转置后的数据结构,导致报错

    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

    这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False。
    虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor结构,但是view()不可以

    1. x = torch.rand(3,4)
    2. x = x.transpose(0,1)
    3. print(x.is_contiguous()) # 是否连续
    4. 'False'
    5. # 会发现
    6. x.view(3,4)
    7. '''
    8. RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
    9. 就是不连续导致的
    10. '''
    11. # 但是这样是可以的。
    12. x = x.contiguous()
    13. x.view(3,4)
    14. x = torch.rand(3,4)
    15. x = x.permute(1,0) # 等价x = x.transpose(0,1)
    16. x.reshape(3,4)
    17. '''这就不报错了
    18. 说明x.reshape(3,4) 这个操作
    19. 等于x = x.contiguous().view()
    20. 尽管如此,但是我们还是不推荐使用reshape
    21. 除非为了获取完全不同但是数据相同的克隆体
    22. '''

    调用contiguous()时,会强制拷贝一份tensor,让它的布局从头到尾创建的一毛一样。
    只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()和permute()这两个函数一定要contiguous().

    transpose与permute会实实在在的根据需求(要交换的dim)把相应的Tensor元素的位置进行调整, 而view 会将Tensor所有维度拉平成一维 (即按行,这也是为什么view操作要求Tensor是contiguous的原因),然后再根据传入的的维度(只要保证各维度的乘积=总元素个数即可)信息重构出一个Tensor。

    1. a = torch.Tensor([[[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]],
    2. [[-1,-2,-3,-4,-5], [-6,-7,-8,-9,-10], [-11,-12,-13,-14,-15]]])
    3. >>> a.shape
    4. torch.Size([2, 3, 5])
    5. # 还是上面的Tensor a
    6. >>> print(a.shape)
    7. torch.Size([2, 3, 5])
    8. >>> print(a.view(2,5,3))
    9. tensor([[[ 1., 2., 3.],
    10. [ 4., 5., 6.],
    11. [ 7., 8., 9.],
    12. [ 10., 11., 12.],
    13. [ 13., 14., 15.]],
    14. [[ -1., -2., -3.],
    15. [ -4., -5., -6.],
    16. [ -7., -8., -9.],
    17. [-10., -11., -12.],
    18. [-13., -14., -15.]]])
    19. >>> c = a.transpose(1,2)
    20. >>> print(c, c.shape)
    21. (tensor([[[ 1., 6., 11.],
    22. [ 2., 7., 12.],
    23. [ 3., 8., 13.],
    24. [ 4., 9., 14.],
    25. [ 5., 10., 15.]],
    26. [[ -1., -6., -11.],
    27. [ -2., -7., -12.],
    28. [ -3., -8., -13.],
    29. [ -4., -9., -14.],
    30. [ -5., -10., -15.]]]),
    31. torch.Size([2, 5, 3]))

    即使view()transpose()最终得到的Tensor的shape是一样的,但二者内容并不相同。view函数只是按照给定的(2,5,3)的Tensor维度,将元素按顺序一个个填进去;而transpose函数,则的确是在进行第一个第二维度的转置

    3、view与reshape的区别

    view()具有跟reshape()相同的功能,都能去重塑矩阵的形状

    不同点:

    reshape()方法不受此限制;如果对 tensor 调用过 transpose, permute等操作的话会使该 tensor 在内存中变得不再连续。

    view():

    作用:将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor共享存储区。view()方法只适用于满足连续性(contiguous)条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。也就是说view不会改变原来数据的存放方式,并且,也不会产生数据的副本,view返回的是视图。

    如果tensor 不满足连续性条件,需要先调用 contiguous()方法,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

    view()可以通过在某一维度输入为-1,来动态调整这个矩阵的维度的size, 而 reshape且无动态调整的功能。而且 view()用于pytorch中对张量进行处理,

    view方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与源tensor共享内存,即更改其中一个,另外一个也会跟着改变

    reshape():

    作用:与view方法类似,将输入tensor转换为新的shape格式。

    reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()

    reshape()方法的返回值既可以是视图,也可以是副本。即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

    PyTorch:view() 与 reshape() 区别详解_地球被支点撬走啦的博客-CSDN博客_reshape view

    4、torch.flatten()

    torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

    input (Tensor) – 输入为Tensor 
    start_dim (int) – 展平的开始维度 
    end_dim (int) – 展平的结束维度

    展平一个连续范围的维度,输出类型为Tensor, flatten函数就是对tensor类型进行扁平化处理,也就是在不同维度上进行堆叠操作,a.flatten(m),这个意思是将a这个tensor,从第m维度开始堆叠,一直堆叠到最后一个维度

    1. import torch
    2. # t 是三维张量 torch.Size([3, 2, 2])
    3. t = torch.tensor([[[1, 2],
    4. [3, 4]],
    5. [[5, 6],
    6. [7, 8]],
    7. [[9, 10],
    8. [11, 12]]])
    9. #如果不传入参数,默认开始维度为0,最后维度为-1,展开为一维
    10. result_0 = torch.flatten(t)
    11. print(result_0)
    12. '''
    13. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    14. '''
    15. #当开始维度为1,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
    16. result_1 = torch.flatten(t, start_dim=1)
    17. print(result_1)
    18. '''
    19. tensor([[ 1, 2, 3, 4],
    20. [ 5, 6, 7, 8],
    21. [ 9, 10, 11, 12]])
    22. '''
    23. torch.flatten(t, start_dim=1).size()
    24. # torch.Size([3, 4])
    25. #下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
    26. #前面的就会合并
    27. result_3 = torch.flatten(t, start_dim=0, end_dim=1)
    28. print(result_3)
    29. '''
    30. tensor([[ 1, 2],
    31. [ 3, 4],
    32. [ 5, 6],
    33. [ 7, 8],
    34. [ 9, 10],
    35. [11, 12]])
    36. '''
    37. torch.flatten(t, start_dim=0, end_dim=1).size()
    38. # torch.Size([6, 2])

    示例:

    1. import torch
    2. # 随机产生了一个tensor,它的Batchsize是2,C是3,H是2,W是3
    3. a=torch.rand(2,3,2,3)
    4. print(a)
    5. '''
    6. tensor([[[[0.5521, 0.2547, 0.5242],
    7. [0.8248, 0.4500, 0.2413]],
    8. [[0.7759, 0.1261, 0.0090],
    9. [0.0197, 0.6191, 0.0422]],
    10. [[0.0896, 0.1731, 0.5484],
    11. [0.7927, 0.0752, 0.2176]]],
    12. [[[0.0118, 0.3865, 0.9587],
    13. [0.6599, 0.2464, 0.0728]],
    14. [[0.2858, 0.3772, 0.8215],
    15. [0.3267, 0.2859, 0.4329]],
    16. [[0.7329, 0.4436, 0.4246],
    17. [0.4162, 0.8688, 0.5286]]]])
    18. '''
    19. ##########################################################################
    20. result_0 = a.flatten(0)
    21. print(result_0.shape)
    22. print(result_0)
    23. '''
    24. torch.Size([36])
    25. tensor([0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,
    26. 0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176,
    27. 0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,
    28. 0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286])
    29. '''
    30. ##########################################################################
    31. result_1 = a.flatten(1)
    32. print(result_1.shape)
    33. print(result_1)
    34. '''
    35. torch.Size([2, 18])
    36. tensor([[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,
    37. 0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176],
    38. [0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,
    39. 0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]])
    40. '''
    41. ##########################################################################
    42. result_2 = a.flatten(2)
    43. print(result_2.shape)
    44. print(result_2)
    45. '''
    46. torch.Size([2, 3, 6])
    47. tensor([[[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413],
    48. [0.7759, 0.1261, 0.0090, 0.0197, 0.6191, 0.0422],
    49. [0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176]],
    50. [[0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728],
    51. [0.2858, 0.3772, 0.8215, 0.3267, 0.2859, 0.4329],
    52. [0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]]])
    53. '''
    54. ##########################################################################
    55. result_3 = a.flatten(3)
    56. print(result_3.shape)
    57. print(result_3)
    58. '''
    59. torch.Size([2, 3, 2, 3])
    60. tensor([[[[0.5521, 0.2547, 0.5242],
    61. [0.8248, 0.4500, 0.2413]],
    62. [[0.7759, 0.1261, 0.0090],
    63. [0.0197, 0.6191, 0.0422]],
    64. [[0.0896, 0.1731, 0.5484],
    65. [0.7927, 0.0752, 0.2176]]],
    66. [[[0.0118, 0.3865, 0.9587],
    67. [0.6599, 0.2464, 0.0728]],
    68. [[0.2858, 0.3772, 0.8215],
    69. [0.3267, 0.2859, 0.4329]],
    70. [[0.7329, 0.4436, 0.4246],
    71. [0.4162, 0.8688, 0.5286]]]])
    72. '''
    73. ##########################################################################
    74. result_4 = a.flatten(0, 1)
    75. print(result_4.shape)
    76. print(result_4)
    77. '''
    78. torch.Size([6, 2, 3])
    79. tensor([[[0.5521, 0.2547, 0.5242],
    80. [0.8248, 0.4500, 0.2413]],
    81. [[0.7759, 0.1261, 0.0090],
    82. [0.0197, 0.6191, 0.0422]],
    83. [[0.0896, 0.1731, 0.5484],
    84. [0.7927, 0.0752, 0.2176]],
    85. [[0.0118, 0.3865, 0.9587],
    86. [0.6599, 0.2464, 0.0728]],
    87. [[0.2858, 0.3772, 0.8215],
    88. [0.3267, 0.2859, 0.4329]],
    89. [[0.7329, 0.4436, 0.4246],
    90. [0.4162, 0.8688, 0.5286]]])
    91. '''

    a.flatten(0)的意思就是从batchsize这个维度开始堆叠,直到W结束,那最后就是成一维的了,也就是只剩W这个维度,那当然就是只有一条这样子

    a.flatten(1)的意思就是从C(channel)这个维度开始堆叠,直到W结束,Batchsize这个维度没有参与运算,因此还是有B这个维度的,这样的话就是相当于将三维的数据堆叠成只有一个维度W的数据,那当然就变成了两条

    a.flatten(2)的意思就是从H(Height)这个维度开始堆叠,直到W结束,B和C这两个维度都没有参与运算,因此将H这个维度堆叠到W上去,就是将原本的平面变成了一个长条

    最后a.flatten(3)的意思就是将H这个维度堆叠到H这个维度上去,自己堆叠自己就是没有堆叠

    a.flatten(0,1), 将B的维度叠加到C的维度上,就是将两个batch叠加合并了

    5、flatten函数的用法及其与reshape函数的区别

    深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别_清泉_流响的博客-CSDN博客_flatten函数

  • 相关阅读:
    Gut(IF=31.793)重磅综述|肠道微生物组如何影响宿主健康
    【考研数学】线性代数第六章 —— 二次型(3,正定矩阵与正定二次型)
    Milvus+Attu
    嵌入式Android系统耳机驱动基本知识
    记录第一次开源流计算框架Flink代码的贡献
    v-model表单数据双向绑定-表单提交示例
    Matlab:合并不同的整数类型
    DataBinding双向绑定简介
    Spring cloud 集成 SkyWalking 实现性能监控、链路追踪、日志收集
    【STM32】RTC(实时时钟)
  • 原文地址:https://blog.csdn.net/ytusdc/article/details/126453243