• pytorch张量维度变换详解:view、squeeze、transpose


    目录

    1 view函数

    1.1 指定变换后的维度

    1.2 自动推理变换后的维度

    1.3 将tensor展平成一维

    2 reshape函数

    2.1 指定变换后的维度

    2.2 自动推理转换后的维度

    2.3 将tensor展平成一维

    2.4 使用tensor.reshape变换

    3 squeeze函数

    3.1 torch.squeeze去除所有为1的维度

    3.2 torch.squeeze指定dim去除

    3.3 tensor.squeeze去除为1的维度

    4 unsqueeze函数

    4.1 torch.unsqueeze指定dim插入新维度

    4.2 tensor.unsqueeze指定dim插入新维度

    5 transpose函数

    5.1 torch.transpose转置指定维度

    5.2 tensor.transpose转置指定维度

    6 expand函数

    7 repeat函数

    8 permute函数


     

    Pytorch张量维度变化是在构建模型过程中常用且重要的操作,本文从实际应用触发,详细介绍常用的维度变化方法,这些方法包含view、reshap、squeeze、unsqueeze、transpose等。

    1 view函数

    Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。

    view函数的操作对象是Tensor类型,返回的对象类型也为Tensor类型

        def view(self, *size: _int) -> Tensor: ...

    更便于理解的表示形式:

    view(参数a,参数b,…),其中,总的参数个数表示将张量重构后的维度。

    1.1 指定变换后的维度

    通过手工指定,将一个一维tensor变换为3*8维的tensor

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
    3. 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
    4. a2 = a1.view(3, 8)
    5. print(a1)
    6. print(a2)
    7. print(a1.shape)
    8. print(a2.shape)

    运行程序显示如下:

    1. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    2. 19, 20, 21, 22, 23, 24])
    3. tensor([[ 1, 2, 3, 4, 5, 6, 7, 8],
    4. [ 9, 10, 11, 12, 13, 14, 15, 16],
    5. [17, 18, 19, 20, 21, 22, 23, 24]])
    6. torch.Size([24])
    7. torch.Size([3, 8])

    1.2 自动推理变换后的维度

    如果某个参数为-1,则表示该维度取决于其它维度,由Pytorch自己补充

    1. import torch
    2. a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
    3. 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
    4. a4 = a3.view(4, -1)
    5. a5 = a3.view(2, 3, -1)
    6. a6 = a3.view(-1, 3, 2)
    7. print(a3)
    8. print(a4)
    9. print(a5)
    10. print(a6)
    11. print(a3.shape)
    12. print(a4.shape)
    13. print(a5.shape)
    14. print(a6.shape)

     运行程序显示如下:

    1. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    2. 19, 20, 21, 22, 23, 24])
    3. tensor([[ 1, 2, 3, 4, 5, 6],
    4. [ 7, 8, 9, 10, 11, 12],
    5. [13, 14, 15, 16, 17, 18],
    6. [19, 20, 21, 22, 23, 24]])
    7. tensor([[[ 1, 2, 3, 4],
    8. [ 5, 6, 7, 8],
    9. [ 9, 10, 11, 12]],
    10. [[13, 14, 15, 16],
    11. [17, 18, 19, 20],
    12. [21, 22, 23, 24]]])
    13. tensor([[[ 1, 2],
    14. [ 3, 4],
    15. [ 5, 6]],
    16. [[ 7, 8],
    17. [ 9, 10],
    18. [11, 12]],
    19. [[13, 14],
    20. [15, 16],
    21. [17, 18]],
    22. [[19, 20],
    23. [21, 22],
    24. [23, 24]]])
    25. torch.Size([24])
    26. torch.Size([4, 6])
    27. torch.Size([2, 3, 4])
    28. torch.Size([4, 3, 2])

    1.3 将tensor展平成一维

    1. import torch
    2. a7 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    3. [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
    4. a8 = a6.view(-1)
    5. print(a7)
    6. print(a8)
    7. print(a7.shape)
    8. print(a8.shape)

     运行程序显示如下:

    1. tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    2. [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
    3. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
    4. 19, 20, 21, 22, 23, 24])
    5. torch.Size([2, 12])
    6. torch.Size([24])

    reshape函数

    返回与 input张量数据大小一样、给定 shape的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。

    torch.reshape(input, shape) → [Tensor]

    也可以直接在Tensor上使用reshape,形式如下:

    tensor.reshape(shape) → [Tensor]

    2.1 指定变换后的维度

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = torch.reshape(a1, (3, 4))
    4. print(a1.shape)
    5. print(a1)
    6. print(a2.shape)
    7. print(a2)

    运行程序显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 4])
    4. tensor([[ 1, 2, 3, 4],
    5. [ 5, 6, 7, 8],
    6. [ 9, 10, 11, 12]])

    2.2 自动推理转换后的维度

    1. import torch
    2. a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a4 = torch.reshape(a1, (-1, 6))
    4. print(a3.shape)
    5. print(a3)
    6. print(a4.shape)
    7. print(a4)

    运行程序显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([2, 6])
    4. tensor([[ 1, 2, 3, 4, 5, 6],
    5. [ 7, 8, 9, 10, 11, 12]])

    2.3 将tensor展平成一维

    1. import torch
    2. a5 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
    3. a6 = torch.reshape(a1, (-1,))
    4. print(a5.shape)
    5. print(a5)
    6. print(a6.shape)
    7. print(a6)

    运行程序显示如下:

    1. torch.Size([2, 6])
    2. tensor([[ 1, 2, 3, 4, 5, 6],
    3. [ 7, 8, 9, 10, 11, 12]])
    4. torch.Size([12])
    5. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

    2.4 使用tensor.reshape变换

    1. improt torch
    2. a7 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a8 = a7.reshape(6, 2)
    4. a9 = a7.reshape(-1, 3)
    5. a10 = a9.reshape(-1)
    6. print(a7.shape)
    7. print(a7)
    8. print(a8.shape)
    9. print(a8)
    10. print(a9.shape)
    11. print(a9)
    12. print(a10.shape)
    13. print(a10)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([6, 2])
    4. tensor([[ 1, 2],
    5. [ 3, 4],
    6. [ 5, 6],
    7. [ 7, 8],
    8. [ 9, 10],
    9. [11, 12]])
    10. torch.Size([4, 3])
    11. tensor([[ 1, 2, 3],
    12. [ 4, 5, 6],
    13. [ 7, 8, 9],
    14. [10, 11, 12]])
    15. torch.Size([12])
    16. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

    3 squeeze函数

    将input张量中所有维度数据为1的维度给移除掉。指定了dim,如果dim对应维度的值不为1 ,则保持不变,为1则移除该维度。

     torch.squeeze(input, dim=None) → [Tensor]

     也可以在tensor上直接使用squeeze,形式如下:

     tensor.squeeze(dim=None) → [Tensor]

    3.1 torch.squeeze去除所有为1的维度

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = a1.reshape(3, 1, 4)
    4. a3 = torch.squeeze(a2)
    5. print(a1.shape)
    6. print(a1)
    7. print(a2.shape)
    8. print(a2)
    9. print(a3.shape)
    10. print(a3)

     运行结果显示如下:(a2的第二个维度被移除)

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 1, 4])
    4. tensor([[[ 1, 2, 3, 4]],
    5. [[ 5, 6, 7, 8]],
    6. [[ 9, 10, 11, 12]]])
    7. torch.Size([3, 4])
    8. tensor([[ 1, 2, 3, 4],
    9. [ 5, 6, 7, 8],
    10. [ 9, 10, 11, 12]])

    3.2 torch.squeeze指定dim去除为1的维度

    1. import torch
    2. a4 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a5 = a1.reshape(3, 1, 4)
    4. a6 = torch.squeeze(a5, 0)
    5. a7 = torch.squeeze(a5, 1)
    6. print(a4.shape)
    7. print(a4)
    8. print(a5.shape)
    9. print(a5)
    10. print(a6.shape)
    11. print(a6)
    12. print(a7.shape)
    13. print(a7)

    运行结果显示如下:(a5的第一个维度不为1,所以保持不变;a5的第二个维度为1,所以被移除)

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 1, 4])
    4. tensor([[[ 1, 2, 3, 4]],
    5. [[ 5, 6, 7, 8]],
    6. [[ 9, 10, 11, 12]]])
    7. torch.Size([3, 1, 4])
    8. tensor([[[ 1, 2, 3, 4]],
    9. [[ 5, 6, 7, 8]],
    10. [[ 9, 10, 11, 12]]])
    11. torch.Size([3, 4])
    12. tensor([[ 1, 2, 3, 4],
    13. [ 5, 6, 7, 8],
    14. [ 9, 10, 11, 12]])

    3.3 tensor.squeeze指定dim去除为1的维度

    1. import torch
    2. a8 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a9 = a8.reshape(3, 1, 4)
    4. a10 = a9.squeeze()
    5. a11 = a9.squeeze(0)
    6. a12 = a9.squeeze(1)
    7. print(a8.shape)
    8. print(a8)
    9. print(a9.shape)
    10. print(a9)
    11. print(a10.shape)
    12. print(a10)
    13. print(a11.shape)
    14. print(a11)
    15. print(a12.shape)
    16. print(a12)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 1, 4])
    4. tensor([[[ 1, 2, 3, 4]],
    5. [[ 5, 6, 7, 8]],
    6. [[ 9, 10, 11, 12]]])
    7. torch.Size([3, 4])
    8. tensor([[ 1, 2, 3, 4],
    9. [ 5, 6, 7, 8],
    10. [ 9, 10, 11, 12]])
    11. torch.Size([3, 1, 4])
    12. tensor([[[ 1, 2, 3, 4]],
    13. [[ 5, 6, 7, 8]],
    14. [[ 9, 10, 11, 12]]])
    15. torch.Size([3, 4])
    16. tensor([[ 1, 2, 3, 4],
    17. [ 5, 6, 7, 8],
    18. [ 9, 10, 11, 12]])

    4 unsqueeze函数

    在给定的 dim 维度位置插入一个新的维度,维度数值为 1,dim 的范围在 [-dim()-1, dim()+1),包首不包尾

    torch.unsqueeze(input, dim) → [Tensor]

     也可以在tensor上直接使用unsqueeze,形式如下:

    torch.unsqueeze(dim) → [Tensor]

    4.1 torch.unsqueeze指定dim插入新维度

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = a1.reshape(3, 4)
    4. a3 = torch.unsqueeze(a2, 0)
    5. a4 = torch.unsqueeze(a2, 2)
    6. print(a1.shape)
    7. print(a1)
    8. print(a2.shape)
    9. print(a2)
    10. print(a3.shape)
    11. print(a3)
    12. print(a4.shape)
    13. print(a4)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 4])
    4. tensor([[ 1, 2, 3, 4],
    5. [ 5, 6, 7, 8],
    6. [ 9, 10, 11, 12]])
    7. torch.Size([1, 3, 4])
    8. tensor([[[ 1, 2, 3, 4],
    9. [ 5, 6, 7, 8],
    10. [ 9, 10, 11, 12]]])
    11. torch.Size([3, 4, 1])
    12. tensor([[[ 1],
    13. [ 2],
    14. [ 3],
    15. [ 4]],
    16. [[ 5],
    17. [ 6],
    18. [ 7],
    19. [ 8]],
    20. [[ 9],
    21. [10],
    22. [11],
    23. [12]]])

    4.2 tensor.unsqueeze指定dim插入新维度

    1. import torch
    2. a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a6 = a5.reshape(3, 4)
    4. a7 = a6.unsqueeze(0)
    5. a8 = a6.unsqueeze(1)
    6. print(a5.shape)
    7. print(a5)
    8. print(a6.shape)
    9. print(a6)
    10. print(a7.shape)
    11. print(a7)
    12. print(a8.shape)
    13. print(a8)

    运行结果显示如下:

    1. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    2. torch.Size([3, 4])
    3. tensor([[ 1, 2, 3, 4],
    4. [ 5, 6, 7, 8],
    5. [ 9, 10, 11, 12]])
    6. torch.Size([1, 3, 4])
    7. tensor([[[ 1, 2, 3, 4],
    8. [ 5, 6, 7, 8],
    9. [ 9, 10, 11, 12]]])
    10. torch.Size([3, 1, 4])
    11. tensor([[[ 1, 2, 3, 4]],
    12. [[ 5, 6, 7, 8]],
    13. [[ 9, 10, 11, 12]]])

    5 transpose函数

    返回 input 张量的转置,dim0与dim1交换位置

    torch.transpose(input, dim0, dim1) → [Tensor]

      也可以在tensor上直接使用unsqueeze,形式如下:

    tensor.transpose(dim0, dim1) → [Tensor]

    参数:

    • input ([Tensor] 输入的张量
    • dim0 ([int] 第一个要转置的维度
    • dim1 ([int] 第二个要转置的维度

    5.1 torch.transpose转置指定维度

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = a1.reshape(4, 3, 1)
    4. a3 = torch.transpose(a2, 0, 1)
    5. a4 = torch.transpose(a2, 1, 2)
    6. print(a1.shape)
    7. print(a1)
    8. print(a2.shape)
    9. print(a2)
    10. print(a3.shape)
    11. print(a3)
    12. print(a4.shape)
    13. print(a4)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([4, 3, 1])
    4. tensor([[[ 1],
    5. [ 2],
    6. [ 3]],
    7. [[ 4],
    8. [ 5],
    9. [ 6]],
    10. [[ 7],
    11. [ 8],
    12. [ 9]],
    13. [[10],
    14. [11],
    15. [12]]])
    16. torch.Size([3, 4, 1])
    17. tensor([[[ 1],
    18. [ 4],
    19. [ 7],
    20. [10]],
    21. [[ 2],
    22. [ 5],
    23. [ 8],
    24. [11]],
    25. [[ 3],
    26. [ 6],
    27. [ 9],
    28. [12]]])
    29. torch.Size([4, 1, 3])
    30. tensor([[[ 1, 2, 3]],
    31. [[ 4, 5, 6]],
    32. [[ 7, 8, 9]],
    33. [[10, 11, 12]]])

    5.2 tensor.transpose转置指定维度

    1. import torch
    2. a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a6 = a1.reshape(4, 3, 1)
    4. a7 = a6.transpose(0, 1)
    5. a8 = a6.transpose(1, 2)
    6. print(a5.shape)
    7. print(a5)
    8. print(a6.shape)
    9. print(a6)
    10. print(a7.shape)
    11. print(a7)
    12. print(a8.shape)
    13. print(a8)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([4, 3, 1])
    4. tensor([[[ 1],
    5. [ 2],
    6. [ 3]],
    7. [[ 4],
    8. [ 5],
    9. [ 6]],
    10. [[ 7],
    11. [ 8],
    12. [ 9]],
    13. [[10],
    14. [11],
    15. [12]]])
    16. torch.Size([3, 4, 1])
    17. tensor([[[ 1],
    18. [ 4],
    19. [ 7],
    20. [10]],
    21. [[ 2],
    22. [ 5],
    23. [ 8],
    24. [11]],
    25. [[ 3],
    26. [ 6],
    27. [ 9],
    28. [12]]])
    29. torch.Size([4, 1, 3])
    30. tensor([[[ 1, 2, 3]],
    31. [[ 4, 5, 6]],
    32. [[ 7, 8, 9]],
    33. [[10, 11, 12]]])

    6 expand函数

    返回张量的新视图,其某个维度 size 扩展到更大的 size,如果当前维度 size 为 -1 ,表示当前维度 size 保持不变。

    Tensor也可以扩展到更多的维度,新的会追加在最前面。对于新维度,大小不能设置为 -1;

    扩展张量不会分配新内存,而只会在现有张量上创建一个新视图。任何大小为1的维度都可以扩展为任意值,而无需分配新内存。

    Tensor.expand( *sizes) → [Tensor]

    参数:

    • sizes (torch.Size or [int] – 指定维度复制的次数

     

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = a1.reshape(3, 1, 4, 1)
    4. # 维度为 1 的 size 可以扩展成什么任意的 size
    5. a3 = a2.expand(3, 5, 4, 2)
    6. # -1 表示对应的维度size不变,但如果第一个维度3扩展成6则会报错,维度不为1不能扩展
    7. a4 = a2.expand(-1, 5, -1, -1)
    8. # 可以扩展新的维度,但只会放到最前面,不能放到后面(会报错)且不能设置为-1
    9. a5 = a2.expand(2, -1, 5, -1, -1)
    10. print(a1.shape)
    11. print(a1)
    12. print(a2.shape)
    13. print(a2)
    14. print(a3.shape)
    15. print(a3)
    16. print(a4.shape)
    17. print(a4)
    18. print(a5.shape)
    19. print(a5)

    运行结果显示如下 :(维度不为1则不能扩展)

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 1, 4, 1])
    4. tensor([[[[ 1],
    5. [ 2],
    6. [ 3],
    7. [ 4]]],
    8. [[[ 5],
    9. [ 6],
    10. [ 7],
    11. [ 8]]],
    12. [[[ 9],
    13. [10],
    14. [11],
    15. [12]]]])
    16. torch.Size([3, 5, 4, 2])
    17. tensor([[[[ 1, 1],
    18. [ 2, 2],
    19. [ 3, 3],
    20. [ 4, 4]],
    21. [[ 1, 1],
    22. [ 2, 2],
    23. [ 3, 3],
    24. [ 4, 4]],
    25. [[ 1, 1],
    26. [ 2, 2],
    27. [ 3, 3],
    28. [ 4, 4]],
    29. [[ 1, 1],
    30. [ 2, 2],
    31. [ 3, 3],
    32. [ 4, 4]],
    33. [[ 1, 1],
    34. [ 2, 2],
    35. [ 3, 3],
    36. [ 4, 4]]],
    37. [[[ 5, 5],
    38. [ 6, 6],
    39. [ 7, 7],
    40. [ 8, 8]],
    41. [[ 5, 5],
    42. [ 6, 6],
    43. [ 7, 7],
    44. [ 8, 8]],
    45. [[ 5, 5],
    46. [ 6, 6],
    47. [ 7, 7],
    48. [ 8, 8]],
    49. [[ 5, 5],
    50. [ 6, 6],
    51. [ 7, 7],
    52. [ 8, 8]],
    53. [[ 5, 5],
    54. [ 6, 6],
    55. [ 7, 7],
    56. [ 8, 8]]],
    57. [[[ 9, 9],
    58. [10, 10],
    59. [11, 11],
    60. [12, 12]],
    61. [[ 9, 9],
    62. [10, 10],
    63. [11, 11],
    64. [12, 12]],
    65. [[ 9, 9],
    66. [10, 10],
    67. [11, 11],
    68. [12, 12]],
    69. [[ 9, 9],
    70. [10, 10],
    71. [11, 11],
    72. [12, 12]],
    73. [[ 9, 9],
    74. [10, 10],
    75. [11, 11],
    76. [12, 12]]]])
    77. torch.Size([3, 5, 4, 1])
    78. tensor([[[[ 1],
    79. [ 2],
    80. [ 3],
    81. [ 4]],
    82. [[ 1],
    83. [ 2],
    84. [ 3],
    85. [ 4]],
    86. [[ 1],
    87. [ 2],
    88. [ 3],
    89. [ 4]],
    90. [[ 1],
    91. [ 2],
    92. [ 3],
    93. [ 4]],
    94. [[ 1],
    95. [ 2],
    96. [ 3],
    97. [ 4]]],
    98. [[[ 5],
    99. [ 6],
    100. [ 7],
    101. [ 8]],
    102. [[ 5],
    103. [ 6],
    104. [ 7],
    105. [ 8]],
    106. [[ 5],
    107. [ 6],
    108. [ 7],
    109. [ 8]],
    110. [[ 5],
    111. [ 6],
    112. [ 7],
    113. [ 8]],
    114. [[ 5],
    115. [ 6],
    116. [ 7],
    117. [ 8]]],
    118. [[[ 9],
    119. [10],
    120. [11],
    121. [12]],
    122. [[ 9],
    123. [10],
    124. [11],
    125. [12]],
    126. [[ 9],
    127. [10],
    128. [11],
    129. [12]],
    130. [[ 9],
    131. [10],
    132. [11],
    133. [12]],
    134. [[ 9],
    135. [10],
    136. [11],
    137. [12]]]])
    138. torch.Size([2, 3, 5, 4, 1])
    139. tensor([[[[[ 1],
    140. [ 2],
    141. [ 3],
    142. [ 4]],
    143. [[ 1],
    144. [ 2],
    145. [ 3],
    146. [ 4]],
    147. [[ 1],
    148. [ 2],
    149. [ 3],
    150. [ 4]],
    151. [[ 1],
    152. [ 2],
    153. [ 3],
    154. [ 4]],
    155. [[ 1],
    156. [ 2],
    157. [ 3],
    158. [ 4]]],
    159. [[[ 5],
    160. [ 6],
    161. [ 7],
    162. [ 8]],
    163. [[ 5],
    164. [ 6],
    165. [ 7],
    166. [ 8]],
    167. [[ 5],
    168. [ 6],
    169. [ 7],
    170. [ 8]],
    171. [[ 5],
    172. [ 6],
    173. [ 7],
    174. [ 8]],
    175. [[ 5],
    176. [ 6],
    177. [ 7],
    178. [ 8]]],
    179. [[[ 9],
    180. [10],
    181. [11],
    182. [12]],
    183. [[ 9],
    184. [10],
    185. [11],
    186. [12]],
    187. [[ 9],
    188. [10],
    189. [11],
    190. [12]],
    191. [[ 9],
    192. [10],
    193. [11],
    194. [12]],
    195. [[ 9],
    196. [10],
    197. [11],
    198. [12]]]],
    199. [[[[ 1],
    200. [ 2],
    201. [ 3],
    202. [ 4]],
    203. [[ 1],
    204. [ 2],
    205. [ 3],
    206. [ 4]],
    207. [[ 1],
    208. [ 2],
    209. [ 3],
    210. [ 4]],
    211. [[ 1],
    212. [ 2],
    213. [ 3],
    214. [ 4]],
    215. [[ 1],
    216. [ 2],
    217. [ 3],
    218. [ 4]]],
    219. [[[ 5],
    220. [ 6],
    221. [ 7],
    222. [ 8]],
    223. [[ 5],
    224. [ 6],
    225. [ 7],
    226. [ 8]],
    227. [[ 5],
    228. [ 6],
    229. [ 7],
    230. [ 8]],
    231. [[ 5],
    232. [ 6],
    233. [ 7],
    234. [ 8]],
    235. [[ 5],
    236. [ 6],
    237. [ 7],
    238. [ 8]]],
    239. [[[ 9],
    240. [10],
    241. [11],
    242. [12]],
    243. [[ 9],
    244. [10],
    245. [11],
    246. [12]],
    247. [[ 9],
    248. [10],
    249. [11],
    250. [12]],
    251. [[ 9],
    252. [10],
    253. [11],
    254. [12]],
    255. [[ 9],
    256. [10],
    257. [11],
    258. [12]]]]])

    7 repeat函数

    根据指定维度复制张量,与 expand 不同的是,该方法会拷贝原张量的数据

    Tensor.repeat( *sizes) → [Tensor]

    参数:

    • sizes (torch.Size or [int] – 指定维度复制的次数 
    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. print(a1.storage().data_ptr())
    4. a2 = a1.reshape(3, 1, 4)
    5. print(a2.storage().data_ptr())
    6. a3 = a2.expand(3, 3, -1)
    7. # expand 操作后,张量的内存地址没变
    8. print(a3.storage().data_ptr())
    9. a4 = a2.repeat(2, 4, 1)
    10. # repeat 操作后,张量的内存地址会改变
    11. print(a4.storage().data_ptr())
    12. print(a1.shape)
    13. print(a1)
    14. print(a2.shape)
    15. print(a2)
    16. print(a3.shape)
    17. print(a3)
    18. print(a4.shape)

     运行结果显示如下:

    1. 1974461518528
    2. 1974461518528
    3. 1974461518528
    4. 1974462302208
    5. torch.Size([12])
    6. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    7. torch.Size([3, 1, 4])
    8. tensor([[[ 1, 2, 3, 4]],
    9. [[ 5, 6, 7, 8]],
    10. [[ 9, 10, 11, 12]]])
    11. torch.Size([3, 3, 4])
    12. tensor([[[ 1, 2, 3, 4],
    13. [ 1, 2, 3, 4],
    14. [ 1, 2, 3, 4]],
    15. [[ 5, 6, 7, 8],
    16. [ 5, 6, 7, 8],
    17. [ 5, 6, 7, 8]],
    18. [[ 9, 10, 11, 12],
    19. [ 9, 10, 11, 12],
    20. [ 9, 10, 11, 12]]])
    21. torch.Size([6, 4, 4])
    22. tensor([[[ 1, 2, 3, 4],
    23. [ 1, 2, 3, 4],
    24. [ 1, 2, 3, 4],
    25. [ 1, 2, 3, 4]],
    26. [[ 5, 6, 7, 8],
    27. [ 5, 6, 7, 8],
    28. [ 5, 6, 7, 8],
    29. [ 5, 6, 7, 8]],
    30. [[ 9, 10, 11, 12],
    31. [ 9, 10, 11, 12],
    32. [ 9, 10, 11, 12],
    33. [ 9, 10, 11, 12]],
    34. [[ 1, 2, 3, 4],
    35. [ 1, 2, 3, 4],
    36. [ 1, 2, 3, 4],
    37. [ 1, 2, 3, 4]],
    38. [[ 5, 6, 7, 8],
    39. [ 5, 6, 7, 8],
    40. [ 5, 6, 7, 8],
    41. [ 5, 6, 7, 8]],
    42. [[ 9, 10, 11, 12],
    43. [ 9, 10, 11, 12],
    44. [ 9, 10, 11, 12],
    45. [ 9, 10, 11, 12]]])

    8 permute函数

    返回重新排列的张量

    torch.permute(input, dims) → [Tensor]

     也可以在tensor上直接使用permute,形式如下: 

    tensor.permute(dims) → [Tensor]

    参数:

    • input ([Tensor] 要重新排列的张量
    • dims (tuple of python:int) 需要重排的维度索引数组

     

    1. import torch
    2. a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. a2 = a1.reshape(3, 1, 4)
    4. a3 = torch.permute(a2, (2, 0, 1))
    5. a4 = torch.permute(a2, (1, 0, 2))
    6. a5 = a2.permute(1, 2, 0)
    7. print(a1.shape)
    8. print(a1)
    9. print(a2.shape)
    10. print(a2)
    11. print(a3.shape)
    12. print(a3)
    13. print(a4.shape)
    14. print(a4)
    15. print(a5.shape)
    16. print(a5)

    运行结果显示如下:

    1. torch.Size([12])
    2. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    3. torch.Size([3, 1, 4])
    4. tensor([[[ 1, 2, 3, 4]],
    5. [[ 5, 6, 7, 8]],
    6. [[ 9, 10, 11, 12]]])
    7. torch.Size([4, 3, 1])
    8. tensor([[[ 1],
    9. [ 5],
    10. [ 9]],
    11. [[ 2],
    12. [ 6],
    13. [10]],
    14. [[ 3],
    15. [ 7],
    16. [11]],
    17. [[ 4],
    18. [ 8],
    19. [12]]])
    20. torch.Size([1, 3, 4])
    21. tensor([[[ 1, 2, 3, 4],
    22. [ 5, 6, 7, 8],
    23. [ 9, 10, 11, 12]]])
    24. torch.Size([1, 4, 3])
    25. tensor([[[ 1, 5, 9],
    26. [ 2, 6, 10],
    27. [ 3, 7, 11],
    28. [ 4, 8, 12]]])

     

     

  • 相关阅读:
    发布订阅模式
    教你1分钟搞定2小时字幕
    解决:使用WileyNJDv5_Template模板时,无法生成pdf文件。
    UOS服务器操作系统搭建离线yum仓库
    迈向数字化发展新阶段,某商业银行数据存储创新方案及实践经验
    Java随笔-反射
    Selenium 模拟浏览器操作案例
    多线程的创建及状态描述
    pytorch的searchsorted解释
    【一行记录】达梦timestamp转yyyy-mm-dd
  • 原文地址:https://blog.csdn.net/lsb2002/article/details/132905346