index_select
是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值
torch.index_select(input, dim, index, out=None) -> Tensor
input | 从中选择数据的源张量 |
dim | 从中选择数据的维度 |
index | 一个 1D 张量,包含你想要从 此张量应该是 |
out | 一个可选的参数,用于指定输出张量。 如果没有提供,将创建一个新的张量。 |
- import torch
- import numpy as np
-
- x = torch.tensor(np.arange(16).reshape(4,4))
- index=torch.LongTensor([1,3])
- x
- '''
- tensor([[ 0, 1, 2, 3],
- [ 4, 5, 6, 7],
- [ 8, 9, 10, 11],
- [12, 13, 14, 15]], dtype=torch.int32)
- '''
-
- torch.index_select(x,dim=0,index=index)
- '''
- tensor([[ 4, 5, 6, 7],
- [12, 13, 14, 15]], dtype=torch.int32)
- '''
-
- torch.index_select(x,dim=1,index=index)
- '''
- tensor([[ 1, 3],
- [ 5, 7],
- [ 9, 11],
- [13, 15]], dtype=torch.int32)
- '''
- import torch
- import numpy as np
-
- x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
- index=torch.LongTensor([1,3])
- x
- '''
- tensor([[ 0., 1., 2., 3.],
- [ 4., 5., 6., 7.],
- [ 8., 9., 10., 11.],
- [12., 13., 14., 15.]], requires_grad=True)
- '''
-
- torch.index_select(x,dim=0,index=index)
- '''
- tensor([[ 4., 5., 6., 7.],
- [12., 13., 14., 15.]], grad_fn=
) - '''
-
- torch.index_select(x,dim=1,index=index)
- '''
- tensor([[ 1., 3.],
- [ 5., 7.],
- [ 9., 11.],
- [13., 15.]], grad_fn=
) - '''