• HRNet 源码分析


    论文 Deep High-Resolution Representation Learning for Human Pose Estimation

    也是 一个 top-down得对于人体姿态估计得检测方法。和Stack hourglass netword, CPN, MSPN等都大致一样。需要先学习一个人体检测器,将每个人都检测出来,然后在送进单个人体姿态估计模型。

    从论文名字可以看出,HIgh-Resolution 高分辨率。

    Stack hourglass netword, CPN, MSPN 模型结构都有一定得相似性,类型和Unet结构相似,加上一些残差。都经历一个下采样然后在进行上采样得过程。

    然后HRNet 稍微有点不同,保持相同大小进行特征传递。每经过一个Transition是多出一个下采样得分支

    如图:

     模型得代码如下:

    1. class PoseHighResolutionNet(nn.Module):
    2. def __init__(self, cfg, **kwargs):
    3. pass
    4. def forward(self, x):
    5. # 下面将x缩小了4倍 两个conv得s=2
    6. x = self.conv1(x)
    7. x = self.bn1(x)
    8. x = self.relu(x)
    9. x = self.conv2(x)
    10. x = self.bn2(x)
    11. x = self.relu(x)
    12. x = self.layer1(x)
    13. x_list = []
    14. # 每经过一个transition 产生一个下采样分支
    15. for i in range(self.stage2_cfg['NUM_BRANCHES']):
    16. if self.transition1[i] is not None:
    17. x_list.append(self.transition1[i](x))
    18. else:
    19. x_list.append(x)
    20. y_list = self.stage2(x_list)
    21. x_list = []
    22. # 每经过一个transition 产生一个下采样分支
    23. for i in range(self.stage3_cfg['NUM_BRANCHES']):
    24. if self.transition2[i] is not None:
    25. x_list.append(self.transition2[i](y_list[-1]))
    26. else:
    27. x_list.append(y_list[i])
    28. y_list = self.stage3(x_list)
    29. x_list = []
    30. # 每经过一个transition 产生一个下采样分支
    31. for i in range(self.stage4_cfg['NUM_BRANCHES']):
    32. if self.transition3[i] is not None:
    33. x_list.append(self.transition3[i](y_list[-1]))
    34. else:
    35. x_list.append(y_list[i])
    36. y_list = self.stage4(x_list)
    37. # 最终模型只取 最后一个stage得 第一层得输出来
    38. x = self.final_layer(y_list[0])
    39. return x

     对于heatmaplabel的生成 和其他网络一样采用 2D高斯函数生成

    1. def generate_target(self, joints, joints_vis):
    2. '''
    3. :param joints: [num_joints, 3]
    4. :param joints_vis: [num_joints, 3]
    5. :return: target, target_weight(1: visible, 0: invisible)
    6. '''
    7. target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
    8. target_weight[:, 0] = joints_vis[:, 0]
    9. assert self.target_type == 'gaussian', \
    10. 'Only support gaussian map now!'
    11. if self.target_type == 'gaussian':
    12. target = np.zeros((self.num_joints,
    13. self.heatmap_size[1],
    14. self.heatmap_size[0]),
    15. dtype=np.float32)
    16. tmp_size = self.sigma * 3
    17. for joint_id in range(self.num_joints):
    18. feat_stride = self.image_size / self.heatmap_size
    19. mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
    20. mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
    21. # Check that any part of the gaussian is in-bounds
    22. ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
    23. br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
    24. if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
    25. or br[0] < 0 or br[1] < 0:
    26. # If not, just return the image as is
    27. target_weight[joint_id] = 0
    28. continue
    29. # # Generate gaussian
    30. # 生成高斯函数进行赋值
    31. size = 2 * tmp_size + 1
    32. x = np.arange(0, size, 1, np.float32)
    33. y = x[:, np.newaxis]
    34. x0 = y0 = size // 2
    35. # The gaussian is not normalized, we want the center value to equal 1
    36. g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
    37. # Usable gaussian range
    38. g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
    39. g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
    40. # Image range
    41. img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
    42. img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
    43. v = target_weight[joint_id]
    44. if v > 0.5:
    45. target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
    46. g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    47. if self.use_different_joints_weight:
    48. target_weight = np.multiply(target_weight, self.joints_weight)
    49. return target, target_weight

     对于人体姿态估计网络的损失函数 基本都是一样的 MSE

    1. class JointsMSELoss(nn.Module):
    2. def __init__(self, use_target_weight):
    3. super(JointsMSELoss, self).__init__()
    4. # 采用 MSE
    5. self.criterion = nn.MSELoss(reduction='mean')
    6. self.use_target_weight = use_target_weight
    7. def forward(self, output, target, target_weight):
    8. batch_size = output.size(0)
    9. num_joints = output.size(1)
    10. heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
    11. heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
    12. loss = 0
    13. for idx in range(num_joints):
    14. heatmap_pred = heatmaps_pred[idx].squeeze()
    15. heatmap_gt = heatmaps_gt[idx].squeeze()
    16. if self.use_target_weight:
    17. # 损失计算
    18. loss += 0.5 * self.criterion(
    19. heatmap_pred.mul(target_weight[:, idx]),
    20. heatmap_gt.mul(target_weight[:, idx])
    21. )
    22. else:
    23. loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
    24. return loss / num_joints

    从代码上看 其实和其他的top-down网络基本相似,貌似仅仅是在网络结构上进行了一定调整。

    网络结构采用并行的传递方式。

    基于inference代码 基本上和其他top-down网络代码相似。

    有兴趣可以看看其他几篇top-down网络的源码分析

    Rethinking on Multi-Stage Networks for Human Pose Estimation 源码分析_那时那月那人的博客-CSDN博客Rethinking on Multi-Stage Networks for Human Pose Estimation 源码分析https://blog.csdn.net/xiaoxu1025/article/details/127840623CPN-Cascaded Pyramid Network for Multi-Person Pose Estimation 源码分析_那时那月那人的博客-CSDN博客CPN 源码分析https://blog.csdn.net/xiaoxu1025/article/details/127838074Stacked Hourglass Networks for Human Pose Estimation 源码分析_那时那月那人的博客-CSDN博客Stacked Hourglass Networks for Human Pose Estimation 源码分析从源码分析 Stacked Hourglass Networks 在人体检测方向得具体实现https://blog.csdn.net/xiaoxu1025/article/details/127835690

    到此 基于 top-down 方法的人体姿态检测 网络模型告一段落。

    如果对采用bottom-up方式的HigherHRNet感兴趣可以移步下面链接

    HigherHRNet 源码分析_那时那月那人的博客-CSDN博客

  • 相关阅读:
    目标检测论文解读复现之八:基于YOLOv5s的滑雪人员检测研究
    Linux命令(85)之mkdir
    ros2机器人上位机与下位机连接方式(转载)
    3D激光SLAM:LIO-SAM整体介绍与安装编译
    使用 MoveIt 控制自己的真实机械臂【4】——了解 MoveIt 的轨迹规划实现机制
    基于C语言的使用checksum进行差错检测
    合肥大厂校招
    CF487C Prefix Product Sequence 题解
    数据结构 | (二) List
    【非正式协议 Objective-C语言】
  • 原文地址:https://blog.csdn.net/xiaoxu1025/article/details/127843498