torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor
沿
对于一个三维张量
o u t p u t [ i ] [ j ] [ k ] = i n p u t [ i n d e x [ i ] [ j ] [ k ] ] [ j ] [ k ] # i f d i m == 0 " role="presentation" style="position: relative;">
o u t p u t [ i ] [ j ] [ k ] = i n p u t [ i ] [ i n d e x [ i ] [ j ] [ k ] ] [ k ] # i f d i m == 1 " role="presentation" style="position: relative;">
o u t p u t [ i ] [ j ] [ k ] = i n p u t [ i ] [ j ] [ i n d e x [ i ] [ j ] [ k ] ] # i f d i m == 2 " role="presentation" style="position: relative;">
下面用两个例子来解释一下具体的用法
- import torch
-
- dim = 0
- _input = torch.tensor([[10, 11, 12],
- [13, 14, 15],
- [16, 17, 18]])
- index = torch.tensor([[0, 1, 2],
- [1, 2, 0]])
-
- output = torch.gather(_input, dim, index)
-
- print(output)
- # tensor([[10, 14, 18],
- # [13, 17, 12]])

该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(2, 3)。
因为dim=0,index中的每个数其值代表dim=0即"行"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1, 2]分别表示第0、1、2行,而这三个数本身在dim=1维度的索引为0、1、2即第0、1、2列。因此第一个数0定位到_input中的第0行,而0本身在index中的第0列,因此又定位到_input的第0列,这样就找到了10这个数,同理找到14和18。
index中的第1行[1, 2, 0]分别表示_input中的第1、2、0行和第0、1、2列,因此找到_input中对应的数[13, 17, 12]。
- import torch
-
- dim = 1
- _input = torch.tensor([[10, 11, 12],
- [13, 14, 15],
- [16, 17, 18]])
- index = torch.tensor([[0, 1],
- [1, 2],
- [2, 0]])
-
- output = torch.gather(_input, dim, index)
- print(output)
- # tensor([[10, 11],
- # [14, 15],
- # [18, 16]])

该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(3, 2)。
因为dim=1,index中的每个数其值代表dim=1即"列"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1]分别表示第0、1列,而这三个数本身在dim=0维度的索引为0即第0行。因此第一个数0定位到_input中的第0列,而0本身在index中的第0行,因此又定位到_input的第0行,这样就找到了10这个数,同理找到11。
index中的第1行[1, 2]分别表示_input中的第1、2列和第1行,因此找到_input中对应的数[14, 15]。
index中的第2行[2, 0]分别表示_input中的第2、0列和第2行,因此找到_input中对应的数[18, 16]。
上面的示例是二维的情况,同理也可以推广到三维甚至更多维。总结来说,index中每个数其本身的值表示参数dim指定维度的索引,而其它的每个维度都由每个数在index中的对应维度的索引指定。
torch.gather — PyTorch 1.12 documentation
python - What does the gather function do in pytorch in layman terms? - Stack Overflow