搭建深度神经网络模型时,难免会遇到 torch.cat()函数,来进行tensor的拼接。
代码案例如下:
import torch
import numpy as np
array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape) # (4, 1, 28, 28)
tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape) # torch.Size([4, 1, 28, 28])
c = torch.cat(tensor1, tensor2, dim=0) # 报错
这样执行会报错,错误信息如下:
TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
* (tuple of Tensors tensors, name dim, Tensor out)
didn't match because some of the keywords were incorrect: dim
* (tuple of Tensors tensors, int dim, Tensor out)
didn't match because some of the keywords were incorrect: dim
torch.cat() 函数进行tensor的拼接,将要拼接的tensor组合成元组,即可解决该报错1。
代码修改如下:
import torch
import numpy as np
array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape) # (4, 1, 28, 28)
tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape) # torch.Size([4, 1, 28, 28])
# c = torch.cat(tensor1, tensor2, dim=0) # 报错
c = torch.cat((tensor1, tensor2), dim=0)
print("c.shape", c.shape)
c1 = torch.cat((tensor1, tensor2), dim=1)
print("c1.shape", c1.shape)
此时输出为:
array1.shape: (4, 1, 28, 28)
tensor1.shape torch.Size([4, 1, 28, 28])
c.shape torch.Size([8, 1, 28, 28])
c1.shape torch.Size([4, 2, 28, 28])
从该案例中,不仅可以学到dubug信息,还可以详细了解到 torch.cat()函数的具体用法,比如参数dim的含义等等。
如果我的这篇文章帮助到了你,那我也会感到很高兴,一个人能走多远,在于与谁同行。