在看centernet中遇到gather函数,看了好几个博客也都没咋看懂,直到看了这个视频链接,在此感谢这位哔站up主。
直接先看代码:
import torch
a = torch.arange(15).view(3, 5)
b = torch.zeros_like(a)
b[1][2] = 1
b[0][0] = 1
print(a)
print(b)
输出:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
tensor([[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]])
代码:
c = a.gather(0, b) # dim=0
d = a.gather(1, b) # dim=1
print(c)
print(d)
输出:
tensor([[5, 1, 2, 3, 4],
[0, 1, 7, 3, 4],
[0, 1, 2, 3, 4]])
tensor([[ 1, 0, 0, 0, 0],
[ 5, 5, 6, 5, 5],
[10, 10, 10, 10, 10]])
首先解释一下什么是dim=0和dim=1,其实0和1是索引矩阵b在被索引矩阵a索引的位置,例如dim=0,以3维为例,
output
[
i
]
[
j
]
[
k
]
[i][j][k]
[i][j][k]=input
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
[
j
]
[
k
]
[index[i][j][k]]\ \ [j][k]
[index[i][j][k]] [j][k]
dim=1,
output
[
i
]
[
j
]
[
k
]
[i][j][k]
[i][j][k]=input
[
i
]
[
i
n
d
e
x
[
i
]
[
j
]
[
k
]
]
[
k
]
[i]\ [index[i][j][k]]\ [k]
[i] [index[i][j][k]] [k]
上面的代码中,c、d为output,a为input,b为index,
所以c = a.gather(0, b) # dim=0,
c[0][0]=a[b[0][0]] [0]=a[1][0]=5
c[0][1]=a[b[0][1]] [1]=a[0][1]=1
c[0][2]=a[b[0][2]] [2]=a[0][2]=2
。。。依此类推
c = a.gather(0, b) # dim=1,
c[0][0]=a[0][b[0][0]]=a[0][1]=1
c[0][1]=a[0][b[0][1]]=a[0][0]=0
c[0][2]=a[0][b[0][2]]=a[0][0]=0
c[1][0]=a[1][b[1][0]]=a[1][0]=5
c[1][1]=a[1][b[1][1]]=a[1][0]=5
c[1][2]=a[1][b[1][2]]=a[1][1]=6
。。。依此类推