• Stacked Hourglass Networks for Human Pose Estimation 源码分析


    基于top-down方法的人体姿态估计模型源码解析

    Rethinking on Multi-Stage Networks for Human Pose Estimation 源码分析_那时那月那人的博客-CSDN博客

    HRNet 源码分析_那时那月那人的博客-CSDN博客

    CPN-Cascaded Pyramid Network for Multi-Person Pose Estimation 源码分析_那时那月那人的博客-CSDN博客
    论文地址https://arxiv.org/pdf/1603.06937.pdf


    GitHub - princeton-vl/pytorch_stacked_hourglass: Pytorch implementation of "Stacked Hourglass Networks for Human Pose Estimation"Pytorch implementation of "Stacked Hourglass Networks for Human Pose Estimation" - GitHub - princeton-vl/pytorch_stacked_hourglass: Pytorch implementation of "Stacked Hourglass Networks for Human Pose Estimation"https://github.com/princeton-vl/pytorch_stacked_hourglass

     人体姿态估计 一般分为两个方向: Top-down 和 bottom-up

    top-down 方法依赖于 目标检测 需要先检测出一个个人  然后对单个人进行后续姿态估计

    比如: Stacked Hourglass Networks  、Cascaded Pyramid Network、CPN 、MSPN、HRNet等等。

    bottom-up 和 top-down 相反   先确定人 然后在进行分组  比如  open-pose、HigherHRNet等等。

    首先我们分析下网络结构:

    1. class PoseNet(nn.Module):
    2. def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
    3. super(PoseNet, self).__init__()
    4. self.nstack = nstack
    5. self.pre = nn.Sequential(
    6. Conv(3, 64, 7, 2, bn=True, relu=True),
    7. Residual(64, 128),
    8. Pool(2, 2),
    9. Residual(128, 128),
    10. Residual(128, inp_dim)
    11. )
    12. self.hgs = nn.ModuleList( [
    13. nn.Sequential(
    14. Hourglass(4, inp_dim, bn, increase),
    15. ) for i in range(nstack)] )
    16. self.features = nn.ModuleList( [
    17. nn.Sequential(
    18. Residual(inp_dim, inp_dim),
    19. Conv(inp_dim, inp_dim, 1, bn=True, relu=True)
    20. ) for i in range(nstack)] )
    21. self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
    22. self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
    23. self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
    24. self.nstack = nstack
    25. self.heatmapLoss = HeatmapLoss()
    26. def forward(self, imgs):
    27. ## our posenet
    28. # shape (B, H, W, C) -> (B, C, H, W)
    29. x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
    30. # 图片缩小四倍 经过一个 k=7, s=2得卷积核缩小2倍 接一个 残差块
    31. # 然后经过一个池化层在缩小2倍 通道输变成 256 然后在接 两个残差块
    32. # shape (B, 256, H // 4, W // 4)
    33. x = self.pre(x)
    34. combined_hm_preds = []
    35. # 堆叠得 stack hourglasses 层数 可以自己设置 这里 是 8 个
    36. for i in range(self.nstack):
    37. # 这里 就是一个 类似 unet结构得残差连接
    38. # 先下采样到 H // (4 * 8) * W // (4 * 8) 然后在上采用到 (H // 4, W // 4)
    39. # shape: (B, H // 4, W // 4, 256)
    40. hg = self.hgs[i](x)
    41. # shape: (B, 256, H // 4, W // 4)
    42. feature = self.features[i](hg)
    43. # shape: (B, num_joints, H // 4, W // 4) 连接数
    44. preds = self.outs[i](feature)
    45. combined_hm_preds.append(preds)
    46. # 对于前面得堆叠让其进行 预测 也就是 论文中所述: 中间监督 让网络越来越好
    47. if i < self.nstack - 1:
    48. # 让 heatmap 和 feature 融合 进入下一个 hourglasses
    49. x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
    50. # 将所有堆叠得 hourglasse 输出返回
    51. return torch.stack(combined_hm_preds, 1)

    现在网络结构和输出都有了 我们来看下 网络得损失函数 损失函数很简单 就是 用MSE

    1. def calc_loss(self, combined_hm_preds, heatmaps):
    2. combined_loss = []
    3. # 对每个堆叠块 进行损失计算
    4. for i in range(self.nstack):
    5. # 计算每个堆叠块得损失
    6. combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps))
    7. combined_loss = torch.stack(combined_loss, dim=1)
    8. return combined_loss
    9. class HeatmapLoss(torch.nn.Module):
    10. """
    11. loss for detection heatmap
    12. """
    13. def __init__(self):
    14. super(HeatmapLoss, self).__init__()
    15. def forward(self, pred, gt):
    16. """
    17. pred: shape (B, num_joints, H, W)
    18. gt shape (B, num_joints, H, W)
    19. """
    20. # 就是 简单得平方差 计算
    21. l = ((pred - gt)**2)
    22. l = l.mean(dim=3).mean(dim=2).mean(dim=1)
    23. return l ## l of dim bsize

     最后分析下 heatmap 对应得groundtruth怎么生成得。其实很简单就是 用二维高斯函数来计算周围点到 关键点得距离 。当然heatmap如果很大 计算所有点没必要。所有只需要就算距离关键点(x,y)一定范围内得距离即可。为什么不直接 设置关键点为1 其他点都为0 这样导致 负样本过多,正样本只有一个,模型无法学习。

    1. class GenerateHeatmap():
    2. def __init__(self, output_res, num_parts):
    3. self.output_res = output_res
    4. self.num_parts = num_parts
    5. # 计算一个一定范围得二维高斯函数
    6. sigma = self.output_res/64 # 这里 sigma = 1
    7. self.sigma = sigma
    8. size = 6*sigma + 3 # size = 9 一般都是取一奇数 这样有中心点 也就是对应关键点得位置
    9. x = np.arange(0, size, 1, float)
    10. y = x[:, np.newaxis]
    11. x0, y0 = 3*sigma + 1, 3*sigma + 1
    12. # 得到一个 (size, size) 得 二维高斯函数 中心点值为1 其余点按照高斯分布降低
    13. self.g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    14. def __call__(self, keypoints):
    15. # (num_joints, H, W)
    16. hms = np.zeros(shape = (self.num_parts, self.output_res, self.output_res), dtype = np.float32)
    17. sigma = self.sigma
    18. for p in keypoints:
    19. for idx, pt in enumerate(p):
    20. if pt[0] > 0:
    21. x, y = int(pt[0]), int(pt[1])
    22. if x<0 or y<0 or x>=self.output_res or y>=self.output_res:
    23. continue
    24. # 取一个 范围 来用上面计算得高斯函数来进行覆盖
    25. ul = int(x - 3*sigma - 1), int(y - 3*sigma - 1)
    26. br = int(x + 3*sigma + 2), int(y + 3*sigma + 2)
    27. c,d = max(0, -ul[0]), min(br[0], self.output_res) - ul[0]
    28. a,b = max(0, -ul[1]), min(br[1], self.output_res) - ul[1]
    29. cc,dd = max(0, ul[0]), min(br[0], self.output_res)
    30. aa,bb = max(0, ul[1]), min(br[1], self.output_res)
    31. # 用 self.g 来进行赋值
    32. hms[idx, aa:bb,cc:dd] = np.maximum(hms[idx, aa:bb,cc:dd], self.g[a:b,c:d])
    33. return hms

    这是一个 经典得 top-down 形式 人体姿态估计网络。

  • 相关阅读:
    [计算机入门] Windows附件程序介绍(办公类)
    Linux常见指令(1)
    python基础教程视频学习如何使用Python编程语言
    【论文阅读笔记】NTIRE 2022 Burst Super-Resolution Challenge
    【畅购商城】购物车模块之查看购物车
    Android拖放startDragAndDrop拖拽onDrawShadow动态添加View,Kotlin(3)
    探索安全之道 | 企业漏洞管理:从理念到行动
    Ajax技术【Ajax技术详解、 Ajax 的使用、Ajax请求、 JSON详解、JACKSON 的使用 】(一)-全面详解(学习总结---从入门到深化)
    java(JVM)
    国科云:什么是DHCP?DHCP是怎么工作的?
  • 原文地址:https://blog.csdn.net/xiaoxu1025/article/details/127835690