• PyTorch 中的【高级索引】 或 【花式索引】


    PyTorch 中,“高级索引” 或 “花式索引” 允许用户使用整数数组或布尔数组进行索引,从而获取张量中的指定元素。

    具体规则:

    1. 整数数组索引:使用一个整数数组作为索引,可以获取张量中指定位置的元素。

    2. 布尔数组索引:使用一个布尔数组作为索引,可以根据布尔数组的 True/False 值获取张量中对应位置的元素。

    举例说明:

    整数数组索引:
    import torch
    
    # 创建一个张量
    tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
    
    # 使用整数数组索引获取指定位置的元素
    indices = torch.tensor([0, 2])  # 指定要获取的行索引
    result = tensor[indices]  # 获取指定行的元素
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    输出:

    tensor([[1, 2],
            [5, 6]])
    
    • 1
    • 2
    布尔数组索引:
    import torch
    
    # 创建一个张量
    tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
    
    # 创建一个布尔数组,用于选择元素
    mask = torch.tensor([True, False, True])  # 选择第1和第3行的元素
    result = tensor[mask]  # 根据布尔数组选择元素
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    输出:

    tensor([[1, 2],
            [5, 6]])
    
    • 1
    • 2

    这些是使用 PyTorch 中的高级索引进行元素选择的基本规则和示例。

    多维数组索引1
    import torch
    
    # 创建一个示例张量
    tensor = torch.tensor([
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]],
        [[13, 14, 15], [16, 17, 18]]
    ])
    
    # 创建两个索引数组,用于选择元素
    indices1 = torch.tensor([[0, 1], [2, 0]])  # 选择第0个维度(行)的索引
    indices2 = torch.tensor([[1, 2], [0, 2]])  # 选择第1个维度(列)的索引
    
    # 使用多维数据索引选择元素
    result = tensor[indices1, indices2]  # 根据索引数组选择元素
    print(result)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    输出:

    tensor([[[ 4,  5,  6],
             [ 7,  8,  9]],
    
            [[13, 14, 15],
             [ 4,  5,  6]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    多维数组索引2
    lq = 4
    lk = 5
    k = torch.rand((3,lq,lk))
    index1 = torch.randint(lq,(1,2*lk))
    index2 = torch.randint(lk,(lq,2*lk))
    print(index1.shape,index2.shape)
    print(k.shape,k[:, index1, index2].shape)
    print('index1:\n',index1)
    print('index2:\n',index2)
    print('k:\n',k)
    print('index k:\n',k[:,index1,index2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    其中,index1和index2可以互相broadcasting,并且index1和index2broadcasting的shape决定了索引结果的shape。

    输出为:

    torch.Size([1, 10]) torch.Size([4, 10])
    torch.Size([3, 4, 5]) torch.Size([3, 4, 10])
    index1:
     tensor([[0, 1, 2, 3, 3, 0, 3, 0, 2, 3]])
    index2:
     tensor([[1, 1, 0, 0, 0, 3, 4, 1, 4, 4],
            [3, 2, 2, 2, 1, 3, 4, 1, 4, 1],
            [1, 1, 3, 0, 2, 2, 0, 0, 1, 4],
            [4, 3, 3, 0, 4, 0, 3, 3, 3, 2]])
    k:
     tensor([[[0.8977, 0.2755, 0.6238, 0.4197, 0.9837],
             [0.1231, 0.4493, 0.2263, 0.9272, 0.6347],
             [0.3343, 0.3117, 0.6854, 0.2295, 0.6499],
             [0.8584, 0.3650, 0.2476, 0.6275, 0.4702]],
    
            [[0.8483, 0.6208, 0.6188, 0.4867, 0.1121],
             [0.7733, 0.3900, 0.5515, 0.8151, 0.0637],
             [0.3329, 0.7633, 0.1499, 0.2026, 0.0895],
             [0.0793, 0.1707, 0.5915, 0.1170, 0.2679]],
    
            [[0.4250, 0.5561, 0.2284, 0.8940, 0.1764],
             [0.8897, 0.2199, 0.1317, 0.6584, 0.7289],
             [0.3934, 0.3325, 0.7833, 0.7059, 0.7230],
             [0.4195, 0.0095, 0.9322, 0.5098, 0.5191]]])
    index k:
     tensor([[[0.2755, 0.4493, 0.3343, 0.8584, 0.8584, 0.4197, 0.4702, 0.2755,
              0.6499, 0.4702],
             [0.4197, 0.2263, 0.6854, 0.2476, 0.3650, 0.4197, 0.4702, 0.2755,
              0.6499, 0.3650],
             [0.2755, 0.4493, 0.2295, 0.8584, 0.2476, 0.6238, 0.8584, 0.8977,
              0.3117, 0.4702],
             [0.9837, 0.9272, 0.2295, 0.8584, 0.4702, 0.8977, 0.6275, 0.4197,
              0.2295, 0.2476]],
    
            [[0.6208, 0.3900, 0.3329, 0.0793, 0.0793, 0.4867, 0.2679, 0.6208,
              0.0895, 0.2679],
             [0.4867, 0.5515, 0.1499, 0.5915, 0.1707, 0.4867, 0.2679, 0.6208,
              0.0895, 0.1707],
             [0.6208, 0.3900, 0.2026, 0.0793, 0.5915, 0.6188, 0.0793, 0.8483,
              0.7633, 0.2679],
             [0.1121, 0.8151, 0.2026, 0.0793, 0.2679, 0.8483, 0.1170, 0.4867,
              0.2026, 0.5915]],
    
            [[0.5561, 0.2199, 0.3934, 0.4195, 0.4195, 0.8940, 0.5191, 0.5561,
              0.7230, 0.5191],
             [0.8940, 0.1317, 0.7833, 0.9322, 0.0095, 0.8940, 0.5191, 0.5561,
              0.7230, 0.0095],
             [0.5561, 0.2199, 0.7059, 0.4195, 0.9322, 0.2284, 0.4195, 0.4250,
              0.3325, 0.5191],
             [0.1764, 0.6584, 0.7059, 0.4195, 0.5191, 0.4250, 0.5098, 0.8940,
              0.7059, 0.9322]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    再比如这个例子:

    Q = torch.rand((B,H,L_Q,D))
    M_top = torch.randint(L_Q,(B,H,sample_Q))
    Q_reduce = Q[torch.arange(B)[:, None, None],
                       torch.arange(H)[None, :, None],
                       M_top, :]  # (B,H,sample_Q)
    Q_reduce = Q[torch.arrange(B).unsqueeze(-1).unsqueeze(-1),torch.arrange(H).unsqueeze(0).unsqueeze(-1),M_top,:]  # (B,H,sample_Q)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    两种Q_reduce的写法是一致的。

  • 相关阅读:
    SpringBoot-Rest 风格请求处理
    微服务实战 04 springCloud 集成 Sentinel 实战
    2022“杭电杯”中国大学生算法设计超级联赛(9)
    win10&阿里云实现内网穿透#frp
    IB课程 EE怎么写?
    条款1:视C++为一个语言联邦
    16. 文件上传
    异地远程访问内网BUG管理系统【Cpolar内网穿透】
    逆向分析 工具、加壳、安全防护篇
    jsp公交查询系统Myeclipse开发mysql数据库web结构java编程计算机网页项目
  • 原文地址:https://blog.csdn.net/qq_40940771/article/details/137600485