In PyTorch, "advanced indexing" (also called "fancy indexing") lets you index a tensor with integer arrays or boolean arrays to pick out specific elements.
Integer array indexing: an integer array used as an index selects the elements at the positions it names.
Boolean array indexing: a boolean array used as an index selects the elements whose corresponding mask value is True.
import torch
# Create a tensor
tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Use an integer array to fetch specific rows
indices = torch.tensor([0, 2])  # row indices to select
result = tensor[indices]        # rows 0 and 2
print(result)
Output:
tensor([[1, 2],
        [5, 6]])
import torch
# Create a tensor
tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Build a boolean mask for selecting rows
mask = torch.tensor([True, False, True])  # select rows 0 and 2
result = tensor[mask]  # keep the rows where the mask is True
print(result)
Output:
tensor([[1, 2],
        [5, 6]])
These are the basic rules and examples of selecting elements with advanced indexing in PyTorch.
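A boolean mask can also have the same shape as the tensor itself, in which case it selects individual elements rather than whole rows, and the result is a flattened 1-D tensor. A minimal sketch:

```python
import torch

tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
# An element-wise mask (same shape as the tensor) picks individual
# entries; they come back in row-major order as a 1-D tensor.
mask = tensor > 3
result = tensor[mask]
print(result)  # tensor([4, 5, 6])
```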
import torch
# Create an example tensor of shape (3, 2, 3)
tensor = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]]
])
# Two index arrays used to select elements
indices1 = torch.tensor([[0, 1], [2, 0]])  # indices into dim 0
indices2 = torch.tensor([[1, 0], [0, 1]])  # indices into dim 1 (size 2)
# Multidimensional indexing: the index arrays are paired element-wise
result = tensor[indices1, indices2]  # result[i, j] = tensor[indices1[i, j], indices2[i, j]]
print(result)
Output:
tensor([[[ 4,  5,  6],
         [ 7,  8,  9]],

        [[13, 14, 15],
         [ 4,  5,  6]]])
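The pairing rule behind this result can be checked directly: the two index arrays are walked in lockstep, so `result[i, j]` equals `tensor[indices1[i, j], indices2[i, j]]`. A minimal sketch (using in-bounds indices for dim 1, which has size 2):

```python
import torch

tensor = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
])
indices1 = torch.tensor([[0, 1], [2, 0]])
indices2 = torch.tensor([[1, 0], [0, 1]])
result = tensor[indices1, indices2]

# Each output position pairs the two index arrays element-wise;
# the remaining dim of size 3 is carried along, giving shape (2, 2, 3).
for i in range(2):
    for j in range(2):
        assert torch.equal(result[i, j],
                           tensor[indices1[i, j], indices2[i, j]])
```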
import torch
lq = 4  # size of dim 1
lk = 5  # size of dim 2
k = torch.rand((3, lq, lk))
index1 = torch.randint(lq, (1, 2 * lk))   # shape (1, 10), indexes dim 1
index2 = torch.randint(lk, (lq, 2 * lk))  # shape (4, 10), indexes dim 2
print(index1.shape, index2.shape)
print(k.shape, k[:, index1, index2].shape)
print('index1:\n', index1)
print('index2:\n', index2)
print('k:\n', k)
print('index k:\n', k[:, index1, index2])
Here index1 and index2 are broadcast against each other, and their broadcast shape, (4, 10) in this case, combined with the sliced `:` dimension kept in front, determines the shape of the result: (3, 4, 10).
The output is:
torch.Size([1, 10]) torch.Size([4, 10])
torch.Size([3, 4, 5]) torch.Size([3, 4, 10])
index1:
 tensor([[0, 1, 2, 3, 3, 0, 3, 0, 2, 3]])
index2:
 tensor([[1, 1, 0, 0, 0, 3, 4, 1, 4, 4],
        [3, 2, 2, 2, 1, 3, 4, 1, 4, 1],
        [1, 1, 3, 0, 2, 2, 0, 0, 1, 4],
        [4, 3, 3, 0, 4, 0, 3, 3, 3, 2]])
k:
 tensor([[[0.8977, 0.2755, 0.6238, 0.4197, 0.9837],
         [0.1231, 0.4493, 0.2263, 0.9272, 0.6347],
         [0.3343, 0.3117, 0.6854, 0.2295, 0.6499],
         [0.8584, 0.3650, 0.2476, 0.6275, 0.4702]],

        [[0.8483, 0.6208, 0.6188, 0.4867, 0.1121],
         [0.7733, 0.3900, 0.5515, 0.8151, 0.0637],
         [0.3329, 0.7633, 0.1499, 0.2026, 0.0895],
         [0.0793, 0.1707, 0.5915, 0.1170, 0.2679]],

        [[0.4250, 0.5561, 0.2284, 0.8940, 0.1764],
         [0.8897, 0.2199, 0.1317, 0.6584, 0.7289],
         [0.3934, 0.3325, 0.7833, 0.7059, 0.7230],
         [0.4195, 0.0095, 0.9322, 0.5098, 0.5191]]])
index k:
 tensor([[[0.2755, 0.4493, 0.3343, 0.8584, 0.8584, 0.4197, 0.4702, 0.2755,
          0.6499, 0.4702],
         [0.4197, 0.2263, 0.6854, 0.2476, 0.3650, 0.4197, 0.4702, 0.2755,
          0.6499, 0.3650],
         [0.2755, 0.4493, 0.2295, 0.8584, 0.2476, 0.6238, 0.8584, 0.8977,
          0.3117, 0.4702],
         [0.9837, 0.9272, 0.2295, 0.8584, 0.4702, 0.8977, 0.6275, 0.4197,
          0.2295, 0.2476]],

        [[0.6208, 0.3900, 0.3329, 0.0793, 0.0793, 0.4867, 0.2679, 0.6208,
          0.0895, 0.2679],
         [0.4867, 0.5515, 0.1499, 0.5915, 0.1707, 0.4867, 0.2679, 0.6208,
          0.0895, 0.1707],
         [0.6208, 0.3900, 0.2026, 0.0793, 0.5915, 0.6188, 0.0793, 0.8483,
          0.7633, 0.2679],
         [0.1121, 0.8151, 0.2026, 0.0793, 0.2679, 0.8483, 0.1170, 0.4867,
          0.2026, 0.5915]],

        [[0.5561, 0.2199, 0.3934, 0.4195, 0.4195, 0.8940, 0.5191, 0.5561,
          0.7230, 0.5191],
         [0.8940, 0.1317, 0.7833, 0.9322, 0.0095, 0.8940, 0.5191, 0.5561,
          0.7230, 0.0095],
         [0.5561, 0.2199, 0.7059, 0.4195, 0.9322, 0.2284, 0.4195, 0.4250,
          0.3325, 0.5191],
         [0.1764, 0.6584, 0.7059, 0.4195, 0.5191, 0.4250, 0.5098, 0.8940,
          0.7059, 0.9322]]])
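The element-wise rule implied by this broadcasting can be verified explicitly: for `k[:, index1, index2]`, each output entry `[b, i, j]` comes from `k[b, index1[0, j], index2[i, j]]`, because index1's leading dimension of size 1 is broadcast over index2's rows. A quick check with the same shapes:

```python
import torch

lq, lk = 4, 5
k = torch.rand((3, lq, lk))
index1 = torch.randint(lq, (1, 2 * lk))   # (1, 10)
index2 = torch.randint(lk, (lq, 2 * lk))  # (4, 10)
result = k[:, index1, index2]             # (3, 4, 10)
assert result.shape == (3, lq, 2 * lk)

# index1 is broadcast over index2's first dimension, so its single
# row is reused for every i.
for b in range(3):
    for i in range(lq):
        for j in range(2 * lk):
            assert result[b, i, j] == k[b, index1[0, j], index2[i, j]]
```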
Another example:
Q = torch.rand((B, H, L_Q, D))
M_top = torch.randint(L_Q, (B, H, sample_Q))
Q_reduce = Q[torch.arange(B)[:, None, None],
             torch.arange(H)[None, :, None],
             M_top, :]  # (B, H, sample_Q, D)
Q_reduce = Q[torch.arange(B).unsqueeze(-1).unsqueeze(-1), torch.arange(H).unsqueeze(0).unsqueeze(-1), M_top, :]  # (B, H, sample_Q, D)
The two ways of writing Q_reduce are equivalent: `[:, None, None]` and the two `unsqueeze` calls both reshape `torch.arange(B)` to (B, 1, 1), and likewise `torch.arange(H)` to (1, H, 1), so the three index tensors broadcast to (B, H, sample_Q) and the trailing `:` keeps the D dimension.
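A runnable check of this equivalence; B, H, L_Q, D, and sample_Q are not fixed in the snippet above, so small hypothetical values are assumed here:

```python
import torch

# Hypothetical sizes, assumed only for this check
B, H, L_Q, D, sample_Q = 2, 3, 8, 4, 5
Q = torch.rand((B, H, L_Q, D))
M_top = torch.randint(L_Q, (B, H, sample_Q))

# Indexing-expression form: None inserts singleton dims
r1 = Q[torch.arange(B)[:, None, None],
       torch.arange(H)[None, :, None],
       M_top, :]
# unsqueeze form: same singleton dims, built call by call
r2 = Q[torch.arange(B).unsqueeze(-1).unsqueeze(-1),
       torch.arange(H).unsqueeze(0).unsqueeze(-1),
       M_top, :]

assert r1.shape == (B, H, sample_Q, D)
assert torch.equal(r1, r2)
```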