• tensor.nozero(), mask, [mask]


    1. nozero()

    https://pytorch.org/docs/stable/generated/torch.nonzero.html?highlight=nonzero#torch.nonzero

    torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors
    
    • 1

    torch.nonzero(..., as_tuple=False) (default) returns a 2-D tensor where each row is the index for a nonzero value.
    torch.nonzero(..., as_tuple=False) (默认)返回一个2-D张量,其中每一行是一个非零值的索引。

    tensor.nozero()默认返回一个2维的tensor, 里面是符合条件的索引.

    举个例子:

    import torch
    
    
    x = torch.randint(low=2, high=3, size=[2, 3])
    idx = x.nonzero()
    
    print(f"x:\n{x}\n")
    print(f"idx:\n{idx}\n")
    
    print(f"x.size: {x.size()}")
    print(f"idx.size: {idx.size()}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    结果:

    x:
    tensor([[2, 2, 2],
            [2, 2, 2]])
    
    idx:
    tensor([[0, 0],
            [0, 1],
            [0, 2],
            [1, 0],
            [1, 1],
            [1, 2]])
    
    x.size: torch.Size([2, 3])
    idx.size: torch.Size([6, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    可以看到, 这里的idx是用2-D tensor来表示x中符合条件元素的索引.

    因为x有两个维度, 所以idx中的size()[1]也是2


    如果输入是多维的tensor, 那么表示也只会用2-D的tensor. 例子如下:

    import torch
    
    
    x = torch.randint(low=2, high=3, size=[2, 3, 4])
    idx = x.nonzero()
    
    print(f"x:\n{x}\n")
    print(f"idx:\n{idx}\n")
    
    print(f"x.size: {x.size()}")
    print(f"idx.size: {idx.size()}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    结果:

    x:
    tensor([[[2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2]],
    
            [[2, 2, 2, 2],
             [2, 2, 2, 2],
             [2, 2, 2, 2]]])
    
    idx:
    tensor([[0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 0, 3],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
            [0, 1, 3],
            [0, 2, 0],
            [0, 2, 1],
            [0, 2, 2],
            [0, 2, 3],
            [1, 0, 0],
            [1, 0, 1],
            [1, 0, 2],
            [1, 0, 3],
            [1, 1, 0],
            [1, 1, 1],
            [1, 1, 2],
            [1, 1, 3],
            [1, 2, 0],
            [1, 2, 1],
            [1, 2, 2],
            [1, 2, 3]])
    
    x.size: torch.Size([2, 3, 4])
    idx.size: torch.Size([24, 3])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37

    因为x有3个维度, 所以idx中的size()[1]也是3

    再举个例子:

    import torch
    
    
    x = torch.randint(low=2, high=3, size=[2, 3, 4, 2])
    idx = x.nonzero()
    
    print(f"x:\n{x}\n")
    print(f"idx:\n{idx}\n")
    
    print(f"x.size: {x.size()}")
    print(f"idx.size: {idx.size()}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    结果:

    x:
    tensor([[[[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]],
    
             [[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]],
    
             [[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]]],
    
    
            [[[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]],
    
             [[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]],
    
             [[2, 2],
              [2, 2],
              [2, 2],
              [2, 2]]]])
    
    idx:
    tensor([[0, 0, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 0, 1, 1],
            [0, 0, 2, 0],
            [0, 0, 2, 1],
            [0, 0, 3, 0],
            [0, 0, 3, 1],
            [0, 1, 0, 0],
            [0, 1, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 1, 1],
            [0, 1, 2, 0],
            [0, 1, 2, 1],
            [0, 1, 3, 0],
            [0, 1, 3, 1],
            [0, 2, 0, 0],
            [0, 2, 0, 1],
            [0, 2, 1, 0],
            [0, 2, 1, 1],
            [0, 2, 2, 0],
            [0, 2, 2, 1],
            [0, 2, 3, 0],
            [0, 2, 3, 1],
            [1, 0, 0, 0],
            [1, 0, 0, 1],
            [1, 0, 1, 0],
            [1, 0, 1, 1],
            [1, 0, 2, 0],
            [1, 0, 2, 1],
            [1, 0, 3, 0],
            [1, 0, 3, 1],
            [1, 1, 0, 0],
            [1, 1, 0, 1],
            [1, 1, 1, 0],
            [1, 1, 1, 1],
            [1, 1, 2, 0],
            [1, 1, 2, 1],
            [1, 1, 3, 0],
            [1, 1, 3, 1],
            [1, 2, 0, 0],
            [1, 2, 0, 1],
            [1, 2, 1, 0],
            [1, 2, 1, 1],
            [1, 2, 2, 0],
            [1, 2, 2, 1],
            [1, 2, 3, 0],
            [1, 2, 3, 1]])
    
    x.size: torch.Size([2, 3, 4, 2])
    idx.size: torch.Size([48, 4])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84

    因为x有4个维度, 所以idx中的size()[1]也是4

    终于, 我们可以发现结论了:

    • tensor.nozero()默认返回的是一个2-Dtensor, 行表示非零元素的索引, 列的大小 = 输入tensor维度的数量

    2. mask

    在手撸YOLOv3时, 需要写mask, 一般是根据置信度是否大于阈值来判断是否为正样本:

    • ≥ thresh: 张样本
    • < thresh: 负样本

    在YOLOv3中, 网络的前向推理结果一般为 [N, H, W, 3, 8], 其中:

    • N: batch size
    • H: height
    • W: weight
    • 3: 3种预测尺度
    • 8: conf + loc + cls = 1 + 4 + 3

    为了能够取到置信度conf, 因此output[..., 0], 由此就可以得到mask:

    output = net(input)  # [N, H, W, 3, 8]
    
    mask = output[..., 0] > thresh  # [N, H, W, 3]
    
    • 1
    • 2
    • 3

    那么问题来了: 为什么得到的mask的size为[N, H, W, 3]呢?

    我们讲一下mask这样的取法究竟是在做什么.


    import torch
    
    
    x = torch.randint(low=2, high=3, size=[2, 3])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    print(f"x:\n{x}\n")
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    x:
    tensor([[2, 2, 2],
            [2, 2, 2]])
    
    x[..., 0]:
    tensor([2, 2])
    
    x[..., 0].size:
    torch.Size([2])
    
    mask:
    tensor([True, True])
    
    mask.size:
    torch.Size([2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    或许这样看还不够明显, 那么我们将x的值修改一下:

    import torch
    
    
    x = torch.tensor(data=[[0, 2, 2],
                           [0, 2, 2]], dtype=torch.int8)
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    print(f"x:\n{x}\n")
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    结果:

    x:
    tensor([[0, 2, 2],
            [0, 2, 2]], dtype=torch.int8)
    
    x[..., 0]:
    tensor([0, 0], dtype=torch.int8)
    
    x[..., 0].size:
    torch.Size([2])
    
    mask:
    tensor([False, False])
    
    mask.size:
    torch.Size([2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    看起来, 这个函数取的是第一列(这里的描述并不准确). 那么我们在将x的维度扩充一下看看:

    import torch
    
    
    x = torch.randint(low=0, high=3, size=[2, 3, 2])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    print(f"x:\n{x}\n")
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    结果:

    x:
    tensor([[[2, 1],
             [2, 1],
             [1, 0]],
    
            [[2, 0],
             [1, 2],
             [0, 0]]])
    
    x[..., 0]:
    tensor([[2, 2, 1],
            [2, 1, 0]])
    
    x[..., 0].size:
    torch.Size([2, 3])
    
    mask:
    tensor([[ True,  True, False],
            [ True, False, False]])
    
    mask.size:
    torch.Size([2, 3])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    结论仍然是对的, 我们再将x扩充至 [N, H, W, 3]:

    import torch
    
    
    x = torch.randint(low=0, high=3, size=[2, 3, 2, 1])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    print(f"x:\n{x}\n")
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    x:
    tensor([[[[0],
              [0]],
    
             [[0],
              [1]],
    
             [[2],
              [2]]],
    
    
            [[[1],
              [2]],
    
             [[2],
              [2]],
    
             [[2],
              [2]]]])
    
    x[..., 0]:
    tensor([[[0, 0],
             [0, 1],
             [2, 2]],
    
            [[1, 2],
             [2, 2],
             [2, 2]]])
    
    x[..., 0].size:
    torch.Size([2, 3, 2])
    
    mask:
    tensor([[[False, False],
             [False, False],
             [ True,  True]],
    
            [[False,  True],
             [ True,  True],
             [ True,  True]]])
    
    mask.size:
    torch.Size([2, 3, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    结论是对的.

    3. [mask]

    上面我们分析了mask究竟是在取什么, 那么将mask应用到x中(x[mask])会是怎么样的呢?

    import torch
    
    
    x = torch.randint(low=0, high=3, size=[2, 3])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    filtered_x = x[mask]
    
    print(f"x:\n{x}\n")
    print(f"x.size:\n{x.size()}\n")
    
    print("-" * 50)
    
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    
    print("-" * 50)
    
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    print("-" * 50)
    
    print(f"filtered_x:\n{filtered_x}\n")
    print(f"filtered_x.size:\n{filtered_x.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    结果2:

    x:
    tensor([[2, 2, 1],
            [0, 0, 2]])
    
    x.size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([2, 0])
    
    x[..., 0].size:
    torch.Size([2])
    
    --------------------------------------------------
    mask:
    tensor([ True, False])
    
    mask.size:
    torch.Size([2])
    
    --------------------------------------------------
    filtered_x:
    tensor([[2, 2, 1]])
    
    filtered_x.size:
    torch.Size([1, 3])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    mask负责筛选"列", x[mask]负责取出符合条件的行, 且这个行是一个2D tensor.

    结果2:

    x:
    tensor([[0, 1, 0],
            [1, 0, 0]])
    
    x.size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([0, 1])
    
    x[..., 0].size:
    torch.Size([2])
    
    --------------------------------------------------
    mask:
    tensor([False, False])
    
    mask.size:
    torch.Size([2])
    
    --------------------------------------------------
    filtered_x:
    tensor([], size=(0, 3), dtype=torch.int64)
    
    filtered_x.size:
    torch.Size([0, 3])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    mask负责筛选"列", x[mask]负责取出符合条件的行, 且这个行是一个2D tensor.


    我们将x的维度提升:

    import torch
    
    
    x = torch.randint(low=0, high=3, size=[2, 3, 2])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    filtered_x = x[mask]
    
    print(f"x:\n{x}\n")
    print(f"x.size:\n{x.size()}\n")
    
    print("-" * 50)
    
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    
    print("-" * 50)
    
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    print("-" * 50)
    
    print(f"filtered_x:\n{filtered_x}\n")
    print(f"filtered_x.size:\n{filtered_x.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    结果1:

    x:
    tensor([[[0, 2],
             [0, 2],
             [0, 2]],
    
            [[2, 2],
             [1, 0],
             [1, 2]]])
    
    x.size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([[0, 0, 0],
            [2, 1, 1]])
    
    x[..., 0].size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    mask:
    tensor([[False, False, False],
            [ True, False, False]])
    
    mask.size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    filtered_x:
    tensor([[2, 2]])
    
    filtered_x.size:
    torch.Size([1, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    结果2:

    x:
    tensor([[[1, 0],
             [2, 0],
             [2, 1]],
    
            [[2, 0],
             [2, 0],
             [2, 1]]])
    
    x.size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([[1, 2, 2],
            [2, 2, 2]])
    
    x[..., 0].size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    mask:
    tensor([[False,  True,  True],
            [ True,  True,  True]])
    
    mask.size:
    torch.Size([2, 3])
    
    --------------------------------------------------
    filtered_x:
    tensor([[2, 0],
            [2, 1],
            [2, 0],
            [2, 0],
            [2, 1]])
    
    filtered_x.size:
    torch.Size([5, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    我们可以发现, 结论 mask负责筛选"列", x[mask]负责取出符合条件的行, 且这个行是一个2D tensor 依然是符合的, 我们再扩充一下维度:

    import torch
    
    
    x = torch.randint(low=0, high=3, size=[2, 3, 2, 2])
    
    mask = x[..., 0] > 1  # 最后一个维度中的第一个元素如果大于1, 则前面的维度返回True, 否则返回False
    
    filtered_x = x[mask]
    
    print(f"x:\n{x}\n")
    print(f"x.size:\n{x.size()}\n")
    
    print("-" * 50)
    
    print(f"x[..., 0]:\n{x[..., 0]}\n")
    print(f"x[..., 0].size:\n{x[..., 0].size()}\n")
    
    print("-" * 50)
    
    print(f"mask:\n{mask}\n")
    print(f"mask.size:\n{mask.size()}\n")
    
    print("-" * 50)
    
    print(f"filtered_x:\n{filtered_x}\n")
    print(f"filtered_x.size:\n{filtered_x.size()}\n")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    结果1:

    x:
    tensor([[[[0, 0],
              [0, 1]],
    
             [[1, 1],
              [0, 0]],
    
             [[0, 1],
              [2, 1]]],
    
    
            [[[2, 2],
              [0, 1]],
    
             [[0, 2],
              [0, 2]],
    
             [[0, 0],
              [2, 2]]]])
    
    x.size:
    torch.Size([2, 3, 2, 2])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([[[0, 0],
             [1, 0],
             [0, 2]],
    
            [[2, 0],
             [0, 0],
             [0, 2]]])
    
    x[..., 0].size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    mask:
    tensor([[[False, False],
             [False, False],
             [False,  True]],
    
            [[ True, False],
             [False, False],
             [False,  True]]])
    
    mask.size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    filtered_x:
    tensor([[2, 1],
            [2, 2],
            [2, 2]])
    
    filtered_x.size:
    torch.Size([3, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57

    结果2:

    x:
    tensor([[[[2, 2],
              [2, 2]],
    
             [[1, 0],
              [0, 2]],
    
             [[0, 2],
              [2, 2]]],
    
    
            [[[0, 0],
              [1, 0]],
    
             [[1, 1],
              [0, 1]],
    
             [[2, 1],
              [1, 0]]]])
    
    x.size:
    torch.Size([2, 3, 2, 2])
    
    --------------------------------------------------
    x[..., 0]:
    tensor([[[2, 2],
             [1, 0],
             [0, 2]],
    
            [[0, 1],
             [1, 0],
             [2, 1]]])
    
    x[..., 0].size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    mask:
    tensor([[[ True,  True],
             [False, False],
             [False,  True]],
    
            [[False, False],
             [False, False],
             [ True, False]]])
    
    mask.size:
    torch.Size([2, 3, 2])
    
    --------------------------------------------------
    filtered_x:
    tensor([[2, 2],
            [2, 2],
            [2, 2],
            [2, 1]])
    
    filtered_x.size:
    torch.Size([4, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58

    结论依然存在!


    4. 结论

    到现在我们终于明白了,

    output = net(input)  # [N, H, W, 3, 8]
    
    mask = output[..., 0] > thresh  # [N, H, W, 3]
    
    • 1
    • 2
    • 3
    • mask: 取出的是符合条件的列 (判断用的维度就消失了, 所以[N, H, W, 3, 8] -> [N, H, W, 3])
    • output[mask]: 取出符合条件的行 (是一个2-D tensor)

    举个例子:

    def get_idx_and_info(self, output, thresh):
        output = output.permute(0, 2, 3, 1)  # [N, 24, H, W] -> [N, H, W, 3*8]
        N, H, W, _ = output.size()
        output = output.reshape(N, H, W, 3, -1)  # N, H, W, 24] -> N, H, W, 3, 8]
        
        # 定义掩码(符合条件的行)
        mask = output[..., 0] > thresh  # 置信度 > thresh  # [N, H, W, 3] (这里应该是自动squeeze了)
        
        idx = mask.nonzero()  # size -> [符合条件数的个数, 4]  # 因为mask的维度个数有4个, 所以这里是4, 
        info = output[mask]  # size -> [符合条件数的个数, 8]  # 8 = c + loc + cls
        
        return idx, info
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
  • 相关阅读:
    Spring Cloud--从零开始搭建微服务基础环境【四】
    学习嵌入式可以胜任哪一些行业?
    Maven介绍、优缺点、生命周期、坐标、依赖...
    MATLAB算法实战应用案例精讲-【神经网络】激活函数:PRelu(附Java、C语言、Python和MATLAB代码)
    让Git自动忽略指定文件
    SaaS 营销怎么做?几点思考
    Qt Charts简介
    1.3 do...while实现1+...100 for实现1+...100
    Win11无法删除文件夹怎么办?Win11无法删除文件夹的解决方法
    第一行代码Android 第九章9.4-9.5(解析JSON格式,网络编程最佳实践:发送HTTP请求的代码)
  • 原文地址:https://blog.csdn.net/weixin_44878336/article/details/126117812