torch.stack() 和 torch.cat() 都可以按照指定的维度进行拼接,但是两者也有区别,torch.satck() 是增加新的维度进行堆叠,即其维度拼接后会增加一个维度;而torch.cat() 是在原维度上进行堆叠,即其维度拼接后的维度个数和原来一致。具体说明如下:
torch.stack(input,dim)
input: 待拼接的张量序列组(list or tuple),拼接的tensor的维度必须要相等,即tensor1.shape = tensor2.shape
dim: 在哪个新增的维度上进行拼接,不能超过拼接后的张量数据的维度大小,默认为 0
复制代码
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
er-hljs
import torch
x1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((x1,x2),dim=0).shape)
print(torch.stack((x1,x2),dim=1).shape)
print(torch.stack((x1,x2),dim=2).shape)
print(torch.stack((x1,x2),dim=0))
print(torch.stack((x1,x2),dim=1))
print(torch.stack((x1,x2),dim=2))
>> torch.Size([2, 3, 3]) # 2 表示是有两个tensor的拼接,且在第一个维度的位置拼接
>> torch.Size([3, 2, 3])
>> torch.Size([3, 3, 2])
>> tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
>> tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
>> tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
折叠
torch.cat(input, dim)
input: 待拼接的张量序列组(list or tuple),拼接的tensor的维度必须要相等,即tensor1.shape = tensor2.shape
dim: 在哪个已存在的维度上进行拼接,不能超过拼接后的张量数据的维度大小(即原来的维度大小),默认为 0
复制代码
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
er-hljs
import torch
x1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
x2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.cat((x1,x2),dim=0).shape)
print(torch.cat((x1,x2),dim=1).shape)
print(torch.cat((x1,x2),dim=0))
print(torch.cat((x1,x2),dim=1))
>> torch.Size([6, 3])
>> torch.Size([3, 6])
>> tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
>> tensor([[ 1, 2, 3, 10, 20, 30],
[ 4, 5, 6, 40, 50, 60],
[ 7, 8, 9, 70, 80, 90]])
折叠