目录
4.2 tensor.unsqueeze指定dim插入新维度
Pytorch张量维度变化是在构建模型过程中常用且重要的操作,本文从实际应用触发,详细介绍常用的维度变化方法,这些方法包含view、reshap、squeeze、unsqueeze、transpose等。
Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。
view函数的操作对象是Tensor类型,返回的对象类型也为Tensor类型
def view(self, *size: _int) -> Tensor: ...
更便于理解的表示形式:
view(参数a,参数b,…),其中,总的参数个数表示将张量重构后的维度。
通过手工指定,将一个一维tensor变换为3*8维的tensor
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
-
- a2 = a1.view(3, 8)
- print(a1)
- print(a2)
- print(a1.shape)
- print(a2.shape)
运行程序显示如下:
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, 20, 21, 22, 23, 24])
- tensor([[ 1, 2, 3, 4, 5, 6, 7, 8],
- [ 9, 10, 11, 12, 13, 14, 15, 16],
- [17, 18, 19, 20, 21, 22, 23, 24]])
- torch.Size([24])
- torch.Size([3, 8])
如果某个参数为-1,则表示该维度取决于其它维度,由Pytorch自己补充
- import torch
-
- a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
-
- a4 = a3.view(4, -1)
- a5 = a3.view(2, 3, -1)
- a6 = a3.view(-1, 3, 2)
-
- print(a3)
- print(a4)
- print(a5)
- print(a6)
- print(a3.shape)
- print(a4.shape)
- print(a5.shape)
- print(a6.shape)
运行程序显示如下:
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, 20, 21, 22, 23, 24])
- tensor([[ 1, 2, 3, 4, 5, 6],
- [ 7, 8, 9, 10, 11, 12],
- [13, 14, 15, 16, 17, 18],
- [19, 20, 21, 22, 23, 24]])
- tensor([[[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]],
- [[13, 14, 15, 16],
- [17, 18, 19, 20],
- [21, 22, 23, 24]]])
- tensor([[[ 1, 2],
- [ 3, 4],
- [ 5, 6]],
- [[ 7, 8],
- [ 9, 10],
- [11, 12]],
- [[13, 14],
- [15, 16],
- [17, 18]],
- [[19, 20],
- [21, 22],
- [23, 24]]])
- torch.Size([24])
- torch.Size([4, 6])
- torch.Size([2, 3, 4])
- torch.Size([4, 3, 2])
- import torch
-
- a7 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
- a8 = a6.view(-1)
- print(a7)
- print(a8)
- print(a7.shape)
- print(a8.shape)
运行程序显示如下:
- tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, 20, 21, 22, 23, 24])
- torch.Size([2, 12])
- torch.Size([24])
返回与 input张量数据大小一样、给定 shape的张量。如果可能,返回的是input 张量的视图,否则返回的是其拷贝。
torch.reshape(input, shape) → [Tensor]
也可以直接在Tensor上使用reshape,形式如下:
tensor.reshape(shape) → [Tensor]
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = torch.reshape(a1, (3, 4))
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
运行程序显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- import torch
-
- a3 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a4 = torch.reshape(a1, (-1, 6))
- print(a3.shape)
- print(a3)
- print(a4.shape)
- print(a4)
运行程序显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([2, 6])
- tensor([[ 1, 2, 3, 4, 5, 6],
- [ 7, 8, 9, 10, 11, 12]])
- import torch
-
- a5 = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
- a6 = torch.reshape(a1, (-1,))
- print(a5.shape)
- print(a5)
- print(a6.shape)
- print(a6)
运行程序显示如下:
- torch.Size([2, 6])
- tensor([[ 1, 2, 3, 4, 5, 6],
- [ 7, 8, 9, 10, 11, 12]])
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- improt torch
-
- a7 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a8 = a7.reshape(6, 2)
- a9 = a7.reshape(-1, 3)
- a10 = a9.reshape(-1)
- print(a7.shape)
- print(a7)
- print(a8.shape)
- print(a8)
- print(a9.shape)
- print(a9)
- print(a10.shape)
- print(a10)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([6, 2])
- tensor([[ 1, 2],
- [ 3, 4],
- [ 5, 6],
- [ 7, 8],
- [ 9, 10],
- [11, 12]])
- torch.Size([4, 3])
- tensor([[ 1, 2, 3],
- [ 4, 5, 6],
- [ 7, 8, 9],
- [10, 11, 12]])
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
将input张量中所有维度数据为1的维度给移除掉。指定了dim,如果dim对应维度的值不为1 ,则保持不变,为1则移除该维度。
torch.squeeze(input, dim=None) → [Tensor]
也可以在tensor上直接使用squeeze,形式如下:
tensor.squeeze(dim=None) → [Tensor]
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = a1.reshape(3, 1, 4)
- a3 = torch.squeeze(a2)
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
运行结果显示如下:(a2的第二个维度被移除)
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- import torch
-
- a4 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a5 = a1.reshape(3, 1, 4)
- a6 = torch.squeeze(a5, 0)
- a7 = torch.squeeze(a5, 1)
-
- print(a4.shape)
- print(a4)
- print(a5.shape)
- print(a5)
- print(a6.shape)
- print(a6)
- print(a7.shape)
- print(a7)
运行结果显示如下:(a5的第一个维度不为1,所以保持不变;a5的第二个维度为1,所以被移除)
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- import torch
-
- a8 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a9 = a8.reshape(3, 1, 4)
- a10 = a9.squeeze()
- a11 = a9.squeeze(0)
- a12 = a9.squeeze(1)
-
- print(a8.shape)
- print(a8)
- print(a9.shape)
- print(a9)
- print(a10.shape)
- print(a10)
- print(a11.shape)
- print(a11)
- print(a12.shape)
- print(a12)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
在给定的 dim 维度位置插入一个新的维度,维度数值为 1,dim 的范围在 [-dim()-1, dim()+1),包首不包尾
torch.unsqueeze(input, dim) → [Tensor]
也可以在tensor上直接使用unsqueeze,形式如下:
torch.unsqueeze(dim) → [Tensor]
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = a1.reshape(3, 4)
- a3 = torch.unsqueeze(a2, 0)
- a4 = torch.unsqueeze(a2, 2)
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
- print(a4.shape)
- print(a4)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- torch.Size([1, 3, 4])
- tensor([[[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]]])
- torch.Size([3, 4, 1])
- tensor([[[ 1],
- [ 2],
- [ 3],
- [ 4]],
-
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
-
- [[ 9],
- [10],
- [11],
- [12]]])
- import torch
-
- a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a6 = a5.reshape(3, 4)
- a7 = a6.unsqueeze(0)
- a8 = a6.unsqueeze(1)
-
- print(a5.shape)
- print(a5)
- print(a6.shape)
- print(a6)
- print(a7.shape)
- print(a7)
- print(a8.shape)
- print(a8)
运行结果显示如下:
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 4])
- tensor([[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]])
- torch.Size([1, 3, 4])
- tensor([[[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]]])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
-
- [[ 5, 6, 7, 8]],
-
- [[ 9, 10, 11, 12]]])
返回 input 张量的转置,dim0与dim1交换位置
torch.transpose(input, dim0, dim1) → [Tensor]
也可以在tensor上直接使用unsqueeze,形式如下:
tensor.transpose(dim0, dim1) → [Tensor]
参数:
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = a1.reshape(4, 3, 1)
- a3 = torch.transpose(a2, 0, 1)
- a4 = torch.transpose(a2, 1, 2)
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
- print(a4.shape)
- print(a4)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([4, 3, 1])
- tensor([[[ 1],
- [ 2],
- [ 3]],
- [[ 4],
- [ 5],
- [ 6]],
- [[ 7],
- [ 8],
- [ 9]],
- [[10],
- [11],
- [12]]])
- torch.Size([3, 4, 1])
- tensor([[[ 1],
- [ 4],
- [ 7],
- [10]],
- [[ 2],
- [ 5],
- [ 8],
- [11]],
- [[ 3],
- [ 6],
- [ 9],
- [12]]])
- torch.Size([4, 1, 3])
- tensor([[[ 1, 2, 3]],
- [[ 4, 5, 6]],
- [[ 7, 8, 9]],
- [[10, 11, 12]]])
- import torch
-
- a5 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a6 = a1.reshape(4, 3, 1)
- a7 = a6.transpose(0, 1)
- a8 = a6.transpose(1, 2)
-
- print(a5.shape)
- print(a5)
- print(a6.shape)
- print(a6)
- print(a7.shape)
- print(a7)
- print(a8.shape)
- print(a8)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([4, 3, 1])
- tensor([[[ 1],
- [ 2],
- [ 3]],
- [[ 4],
- [ 5],
- [ 6]],
- [[ 7],
- [ 8],
- [ 9]],
- [[10],
- [11],
- [12]]])
- torch.Size([3, 4, 1])
- tensor([[[ 1],
- [ 4],
- [ 7],
- [10]],
- [[ 2],
- [ 5],
- [ 8],
- [11]],
- [[ 3],
- [ 6],
- [ 9],
- [12]]])
- torch.Size([4, 1, 3])
- tensor([[[ 1, 2, 3]],
- [[ 4, 5, 6]],
- [[ 7, 8, 9]],
- [[10, 11, 12]]])
返回张量的新视图,其某个维度 size 扩展到更大的 size,如果当前维度 size 为 -1 ,表示当前维度 size 保持不变。
Tensor也可以扩展到更多的维度,新的会追加在最前面。对于新维度,大小不能设置为 -1;
扩展张量不会分配新内存,而只会在现有张量上创建一个新视图。任何大小为1的维度都可以扩展为任意值,而无需分配新内存。
Tensor.expand( *sizes) → [Tensor]
参数:
- sizes (torch.Size or [int] – 指定维度复制的次数
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = a1.reshape(3, 1, 4, 1)
-
- # 维度为 1 的 size 可以扩展成什么任意的 size
- a3 = a2.expand(3, 5, 4, 2)
-
- # -1 表示对应的维度size不变,但如果第一个维度3扩展成6则会报错,维度不为1不能扩展
- a4 = a2.expand(-1, 5, -1, -1)
-
- # 可以扩展新的维度,但只会放到最前面,不能放到后面(会报错)且不能设置为-1
- a5 = a2.expand(2, -1, 5, -1, -1)
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
- print(a4.shape)
- print(a4)
- print(a5.shape)
- print(a5)
运行结果显示如下 :(维度不为1则不能扩展)
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4, 1])
- tensor([[[[ 1],
- [ 2],
- [ 3],
- [ 4]]],
- [[[ 5],
- [ 6],
- [ 7],
- [ 8]]],
- [[[ 9],
- [10],
- [11],
- [12]]]])
- torch.Size([3, 5, 4, 2])
- tensor([[[[ 1, 1],
- [ 2, 2],
- [ 3, 3],
- [ 4, 4]],
- [[ 1, 1],
- [ 2, 2],
- [ 3, 3],
- [ 4, 4]],
- [[ 1, 1],
- [ 2, 2],
- [ 3, 3],
- [ 4, 4]],
- [[ 1, 1],
- [ 2, 2],
- [ 3, 3],
- [ 4, 4]],
- [[ 1, 1],
- [ 2, 2],
- [ 3, 3],
- [ 4, 4]]],
- [[[ 5, 5],
- [ 6, 6],
- [ 7, 7],
- [ 8, 8]],
- [[ 5, 5],
- [ 6, 6],
- [ 7, 7],
- [ 8, 8]],
- [[ 5, 5],
- [ 6, 6],
- [ 7, 7],
- [ 8, 8]],
- [[ 5, 5],
- [ 6, 6],
- [ 7, 7],
- [ 8, 8]],
- [[ 5, 5],
- [ 6, 6],
- [ 7, 7],
- [ 8, 8]]],
- [[[ 9, 9],
- [10, 10],
- [11, 11],
- [12, 12]],
- [[ 9, 9],
- [10, 10],
- [11, 11],
- [12, 12]],
- [[ 9, 9],
- [10, 10],
- [11, 11],
- [12, 12]],
- [[ 9, 9],
- [10, 10],
- [11, 11],
- [12, 12]],
- [[ 9, 9],
- [10, 10],
- [11, 11],
- [12, 12]]]])
- torch.Size([3, 5, 4, 1])
- tensor([[[[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]]],
- [[[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]]],
- [[[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]]]])
- torch.Size([2, 3, 5, 4, 1])
- tensor([[[[[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]]],
- [[[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]]],
- [[[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]]]],
- [[[[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]],
- [[ 1],
- [ 2],
- [ 3],
- [ 4]]],
- [[[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]],
- [[ 5],
- [ 6],
- [ 7],
- [ 8]]],
- [[[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]],
- [[ 9],
- [10],
- [11],
- [12]]]]])
根据指定维度复制张量,与 expand 不同的是,该方法会拷贝原张量的数据
Tensor.repeat( *sizes) → [Tensor]
参数:
- sizes (torch.Size or [int] – 指定维度复制的次数
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- print(a1.storage().data_ptr())
- a2 = a1.reshape(3, 1, 4)
- print(a2.storage().data_ptr())
- a3 = a2.expand(3, 3, -1)
-
- # expand 操作后,张量的内存地址没变
- print(a3.storage().data_ptr())
-
- a4 = a2.repeat(2, 4, 1)
-
- # repeat 操作后,张量的内存地址会改变
- print(a4.storage().data_ptr())
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
- print(a4.shape)
运行结果显示如下:
- 1974461518528
- 1974461518528
- 1974461518528
- 1974462302208
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
- [[ 5, 6, 7, 8]],
- [[ 9, 10, 11, 12]]])
- torch.Size([3, 3, 4])
- tensor([[[ 1, 2, 3, 4],
- [ 1, 2, 3, 4],
- [ 1, 2, 3, 4]],
- [[ 5, 6, 7, 8],
- [ 5, 6, 7, 8],
- [ 5, 6, 7, 8]],
- [[ 9, 10, 11, 12],
- [ 9, 10, 11, 12],
- [ 9, 10, 11, 12]]])
- torch.Size([6, 4, 4])
- tensor([[[ 1, 2, 3, 4],
- [ 1, 2, 3, 4],
- [ 1, 2, 3, 4],
- [ 1, 2, 3, 4]],
- [[ 5, 6, 7, 8],
- [ 5, 6, 7, 8],
- [ 5, 6, 7, 8],
- [ 5, 6, 7, 8]],
- [[ 9, 10, 11, 12],
- [ 9, 10, 11, 12],
- [ 9, 10, 11, 12],
- [ 9, 10, 11, 12]],
- [[ 1, 2, 3, 4],
- [ 1, 2, 3, 4],
- [ 1, 2, 3, 4],
- [ 1, 2, 3, 4]],
- [[ 5, 6, 7, 8],
- [ 5, 6, 7, 8],
- [ 5, 6, 7, 8],
- [ 5, 6, 7, 8]],
- [[ 9, 10, 11, 12],
- [ 9, 10, 11, 12],
- [ 9, 10, 11, 12],
- [ 9, 10, 11, 12]]])
返回重新排列的张量
torch.permute(input, dims) → [Tensor]
也可以在tensor上直接使用permute,形式如下:
tensor.permute(dims) → [Tensor]
参数:
- input ([Tensor] 要重新排列的张量
- dims (tuple of python:int) 需要重排的维度索引数组
- import torch
-
- a1 = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- a2 = a1.reshape(3, 1, 4)
- a3 = torch.permute(a2, (2, 0, 1))
- a4 = torch.permute(a2, (1, 0, 2))
- a5 = a2.permute(1, 2, 0)
-
- print(a1.shape)
- print(a1)
- print(a2.shape)
- print(a2)
- print(a3.shape)
- print(a3)
- print(a4.shape)
- print(a4)
- print(a5.shape)
- print(a5)
运行结果显示如下:
- torch.Size([12])
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
- torch.Size([3, 1, 4])
- tensor([[[ 1, 2, 3, 4]],
- [[ 5, 6, 7, 8]],
- [[ 9, 10, 11, 12]]])
- torch.Size([4, 3, 1])
- tensor([[[ 1],
- [ 5],
- [ 9]],
- [[ 2],
- [ 6],
- [10]],
- [[ 3],
- [ 7],
- [11]],
- [[ 4],
- [ 8],
- [12]]])
- torch.Size([1, 3, 4])
- tensor([[[ 1, 2, 3, 4],
- [ 5, 6, 7, 8],
- [ 9, 10, 11, 12]]])
- torch.Size([1, 4, 3])
- tensor([[[ 1, 5, 9],
- [ 2, 6, 10],
- [ 3, 7, 11],
- [ 4, 8, 12]]])