官方解释:
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
所以,该怎么用呢?
拿维度1举列子:
如果我们要取input指定行的值,则要求index[]的每一行的值是一样的,比如
index = [[2,2,2,2,2],[4,4,4,4,4]]
out[i][j][k] = input[i][index[i][j][k]][k]
这样的话,output[i][j][k] = input[i][index[i][j][k]][k] k从0到5的话,则index[i][j][k]是始终不变的,
所以output[i][j][k] = input[i][2][k],也就是把原来输入的第j行改成了第2行的值,但要注意,这个取值的方法是要求index的一行的值是一样的。