• 关于yolov8-class Pose(Detect)


    下面看一下代码:

    class Pose(Detect):
       """YOLOv8 Pose head for keypoints models."""
    
        def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
            """Initialize YOLO network with default parameters and Convolutional Layers."""
            super().__init__(nc, ch)
            self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
            self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total
            self.detect = Detect.forward
    
            c4 = max(ch[0] // 4, self.nk)
            self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
    
        def forward(self, x):
            """Perform forward pass through YOLO model and return predictions."""
            bs = x[0].shape[0]  # batch size
            kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
            x = self.detect(self, x)
            if self.training:
                return x, kpt
            pred_kpt = self.kpts_decode(bs, kpt)
            return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
    
        def kpts_decode(self, bs, kpts):
            """Decodes keypoints."""
            ndim = self.kpt_shape[1]
            print(self.anchors)
            if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
                y = kpts.view(bs, *self.kpt_shape, -1)
                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
                if ndim == 3:
                    a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
                return a.view(bs, self.nk, -1)
            else:
                y = kpts.clone()
                if ndim == 3:
                    y[:, 2::3] = y[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
                y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
                y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
                return y
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    对于:

    y[:, 2::3] = y[:, 2::3].sigmoid()
    
    • 1

    (1)执行的是对关键点解码过程中的一个特定步骤,特别针对预测关键点的第三个维度(如果存在)进行 sigmoid 激活函数的操作。这样做的目的是将预测值转换成概率值,通常用于处理置信度或其他需要被限制在 0 和 1 之间的值。

    (2)y[:, 2::3]: 这是 Python 切片操作的一个例子。这里 y 是一个多维数组(在这种情况下很可能是一个二维数组,代表批量的关键点预测),: 表示选择所有行,2::3 的意思是从索引 2 开始,每隔 3 个元素选取一个。具体到这里,假设每个关键点有三个值(例如,x 位置、y 位置和一个置信度或者其他某种指标),这个操作就是在选择关键点的第三个值。
    (3).sigmoid(): 这是对选中的元素应用 Sigmoid 函数。Sigmoid 函数是一种常见的激活函数,在深度学习中广泛使用,特别是在分类问题上。它能将任意值映射到 0 和 1 之间,适合用作将输出解释为概率。在这个上下文中,使用 Sigmoid 可能是为了将第三个值(例如置信度)转换为概率表示。

    对于:

    y[:, :, :2] * 2.0 + (self.anchors - 0.5) * self.strides
    
    • 1

    (1)self.anchors 表示锚点的中心位置。对于锚点,通常 (x, y) 坐标会被归一化在 [0, 1] 范围,表示这些中心点相对于原始图像大小的位置。
    (2)从 self.anchors 中减去 0.5 是为了把锚点中心转换到一个以网络预测位置为中心的相对坐标系统,然后放缩。
    self.strides 对应于从网络输入到特征图尺度的缩小比例。因此,(self.anchors - 0.5) * self.strides 会根据步长将锚点中心转换到特征图的尺度。
    (3)y[:, :, :2] 是网络对于每个锚点位置的偏移预测,通过乘以 2.0 将这个预测偏移放缩至预期的大小范围(因为网络输出通常是限制在 [0, 1] 之间的),使得这个偏移能够表示出更远的距离。

  • 相关阅读:
    搭建vue3项目并git管理
    网络原理(网络协议初识)
    ESP8266 使用 DRV8833驱动板驱动N20电机
    迪米特法则~
    nginx 基本使用、借助 nginx 和 mkcert 实现本地 https://localhost 测试。
    【医学分割】u2net
    kube-apiserver准入
    C语言数据结构-----单链表(无头单向不循环)
    [机缘参悟-76]:沟通技巧-职场中常见不合适语言的案例分析(尽量避免使用反问式语言)
    易基因:NAR:RCMS编辑系统在特定细胞RNA位点的靶向m5C甲基化和去甲基化研究|项目文章
  • 原文地址:https://blog.csdn.net/weixin_43269994/article/details/138164055