如图所示,使用linspace生成如下(3 × 4)的矩阵,共12个数
a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)
运行结果:
a:
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
说明:dim = 0,表示行;dim=1,表示列
以下代码的效果是:获取维度为0的第0行和第2行元素,运行结果如下图所示:
b = torch.index_select(a, dim=0, index=torch.tensor([0, 2]))
print(b)
b:
tensor([[ 1., 2., 3., 4.],
[ 9., 10., 11., 12.]])
相同运行结果的不同实现方式
print(a.index_select(0, torch.tensor([0, 2])))
获取a中dim=1的第一列和第三列
c = torch.index_select(a, dim=1, index = torch.tensor([1, 3]))
print(c)
运行结果:
tensor
([[ 2., 4.],
[ 6., 8.],
[10., 12.]])