torch.cat(tensors, dim=0, out=None) → Tensor
Concatenates the given sequence of tensors in `tensors` along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty, that is, a 1-D tensor of size (0,).
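The shape rule can be checked directly. This is a minimal sketch: two tensors that differ only in dimension 0 concatenate along dim 0, but not along dim 1. It assumes the mismatch surfaces as a `RuntimeError`, which is how PyTorch reports size mismatches in `torch.cat`.

```python
import torch

# Two tensors that differ only in dimension 0 (rows): 2x3 and 4x3.
a = torch.zeros(2, 3)
b = torch.ones(4, 3)

# Concatenating along dim 0 is allowed: the result has 2 + 4 = 6 rows.
rows = torch.cat((a, b), dim=0)
print(rows.shape)  # torch.Size([6, 3])

# Concatenating along dim 1 is not allowed, because the sizes in dim 0
# (2 vs 4) do not match.
mismatch_raised = False
try:
    torch.cat((a, b), dim=1)
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)
```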
torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
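A short sketch of the inverse relationship: splitting a tensor with `torch.split()` or `torch.chunk()` and concatenating the pieces along the same dimension recovers the original tensor. The tensor `x` here is purely illustrative.

```python
import torch

x = torch.arange(12).reshape(4, 3)

# torch.split with an int gives pieces of (at most) that size along dim.
pieces = torch.split(x, 2, dim=0)       # two 2x3 tensors
restored_split = torch.cat(pieces, dim=0)

# torch.chunk takes a number of chunks instead of a chunk size.
chunks = torch.chunk(x, 3, dim=1)       # three 4x1 tensors
restored_chunk = torch.cat(chunks, dim=1)

print(torch.equal(restored_split, x))  # True
print(torch.equal(restored_chunk, x))  # True
```

Note that the dimension passed to `torch.cat()` must match the one used for splitting; concatenating the row pieces along dim 1 would fail, since the row counts of the pieces would no longer line up with a single 4-row tensor.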
torch.cat() can be best understood via examples.
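Parameters:
- tensors (sequence of Tensors) – any Python sequence of tensors of the same type. Non-empty tensors must have the same shape, except in the concatenating dimension.
- dim (int, optional) – the dimension over which the tensors are concatenated. Default: 0.
- out (Tensor, optional) – the output tensor.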
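The `dim` and `out` arguments from the signature can be sketched as follows. This assumes standard PyTorch indexing, where a negative `dim` counts from the last dimension, and that `out` is preallocated with the exact result shape so that no resize takes place.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

# A negative dim counts from the end: for 2-D inputs, dim=-1 is dim=1.
same = torch.equal(torch.cat((a, b), dim=-1), torch.cat((a, b), dim=1))
print(same)  # True

# out= writes the result into a preallocated tensor of the right shape.
buf = torch.empty(4, 3)
result = torch.cat((a, b), dim=0, out=buf)
print(result.shape)                         # torch.Size([4, 3])
print(result.data_ptr() == buf.data_ptr())  # True: result lives in buf
```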
Example:
```python
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
```
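A common usage pattern, shown here as a sketch with a hypothetical batch-producing loop, is to collect tensors in a Python list and call `torch.cat()` once at the end. Because `torch.cat()` allocates a new tensor each time, growing a tensor with a `cat` on every iteration re-copies everything accumulated so far.

```python
import torch

# Hypothetical scenario: results produced in batches inside a loop.
batches = []
for step in range(3):
    batch = torch.full((2, 4), float(step))  # a 2x4 batch filled with `step`
    batches.append(batch)

# A single concatenation at the end instead of one per iteration.
all_rows = torch.cat(batches, dim=0)
print(all_rows.shape)  # torch.Size([6, 4])
print(all_rows[:, 0])  # tensor([0., 0., 1., 1., 2., 2.])
```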