• 模型部署——CenterPoint转ONNX(自定义onnx算子)


    CenterPoint基于OpenPcDet导出一个完整的ONNX,并用TensorRT推理,部署几个难点如下:

    1.计算pillar中每个点相对几何中心的偏移,取下标方式进行计算是的整个计算图变得复杂,同时这种赋值方式导致运行在pytorch为浅拷贝,而在一些推理后端上表现为深拷贝

    • 修改代码,使用矩阵切片代替原先的操作,使导出的模型在推理后端上的行为结果和pytorch一致,并简化计算图,同时,计算网格坐标也需要修改,修改代码如下:
              # points_xyz = points[:, [0, 1, 2]].contiguous() 
              points_xyz = points[..., :3] 
              points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
      
              points_mean = scatter_mean(points_xyz,unq_inv)
              # # 每个点相对voxel质心的偏移
              f_cluster = points_xyz - points_mean[unq_inv, :] # torch.Size([1067877, 3])
              f_center = torch.zeros_like(points_xyz).to()
              # 每个点相对几何中心的偏移
              # f_center[:, 0] = points_xyz[:, 0] - (points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset)
              # f_center[:, 1] = points_xyz[:, 1] - (points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset)
              # f_center[:, 2] = points_xyz[:, 2] - self.z_offset
              device = points_xyz.device
              f_center = points_xyz - (points_coords * torch.tensor([self.voxel_x, self.voxel_y, self.voxel_z]).to(device) + torch.tensor([self.z_offset, self.y_offset, self.x_offset]).to(device))
      
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    2.torch_scatterscatter_meanscatter_max onnx不支持,需人为自定义onnx节点,后续并自定义tensorRTScatterMeanPluginScatterMaxPlugin算子

    自定义onnx ScatterMax 算子如下,这里ScatterMax算子没有具体实现,仅为了增加相应的onnx节点,好导出onnx计算图,方便后续自定义实现TensorRT算子,实际上导出onnx并不能用onnxruntime来推理,这样做好处:我们可以只需要自定义实现TensorRT算子,对onnx增加相应节点就行,而不需要管具体的onnx算子实现。

    class ScatterMax(torch.autograd.Function):
        @staticmethod
        def forward(ctx,src,index):
        	  # 调unique仅为了输出对应的维度信息
            temp = torch.unique(src)
            out = torch.zeros((temp.shape[0],src.shape[1]),dtype=torch.float32,device=src.device)
            return out
        @staticmethod
        def symbolic(g,src,index):
            return g.op("xiaohu::ScatterMaxPlugin",src,index)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    ScatterMeanPluginScatterBevPlugin节点和ScatterMaxPlugin节点定义方式是类似的

    3.torch.stack算子 onnx不支持,导出onnx计算图很乱,将torch.stack和后续PointPillarScatter操作合并,一起定义为ScatterBevPlugin算子,自定义onnx节点和TensorRT算子来实现,ScatterBevPlugin实现功能和以下代码功能一致:

            voxel_coords = torch.stack((unq_coords // self.scale_xy, (unq_coords % self.scale_xy) // self.scale_y, unq_coords % self.scale_y,
                                       torch.zeros(unq_coords.shape[0]).to(unq_coords.device).int()), dim=1)
            # 将voxel_coords
            voxel_coords = voxel_coords[:, [0, 3, 2, 1]] # index,z,y,x
    
            pillars_feature = features.t()  # float32[64,pillar_num]
            spatial_feature = torch.zeros(64, 468 * 468,dtype=features.dtype, device=features.device)
            indices =  voxel_coords[:, 2] * 468 + voxel_coords[:, 3] #468 * y + x
            # indices = indices.type(torch.long)
            # tensors used as indices must be long, byte or bool tensors
    
            indices = indices.long()
            spatial_feature[:, indices] = pillars_feature
            spatial_feature = spatial_feature.view(1,64, 468, 468) # 对应onnx resahap
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    4.由于基于OpenPcDetCenterPoint用了动态体素化,计算体素信息调用torch.unique,而torch.unique算子 TensorRT 不支持,

    torch.unique可以成功导出onnx,点击onnx 的unique节点,可以看出torch.unique输出有4个,而实际只有无重复的网格坐标unq_coords, 原始张量每个元素在处理后无重复数据中的索引unq_inv两个输出在后面用到了
    在这里插入图片描述

    在这里插入图片描述
    onnx不支持torch all函数,实现TensorRT算子的本质用cuda/cpp实现Plugin::enqueue 函数,将下面python对应的一系列小型操作放在预处理实现,用cuda单独实现会更好点

            points_coords = torch.floor((points[:, [0,1,2]] - self.point_cloud_range[[0,1,2]]) / self.voxel_size[[0,1,2]]).int()
            # onnx不支持all
            # 如果张量中的所有元素为True,才返回True
            mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0,1]])).all(dim=1)
    
            mask = torch.rand(150000).bool()
            # 会调用onnx里的GatherND算子
            points = points[mask]
            points_coords = points_coords[mask]
    
            merge_coords = points_coords[:, 0] * self.scale_y + points_coords[:, 1] 
            # sorted:是否返回无重复张量按照数值进行排序,默认是升序排列,sorted并非表示降序
            # return_inverse:是否返回原始张量中每个元素在处理后的无重复张量中对应的索引
            # return_counts:统计原始张量中每个独立元素的个数
            # dim:值沿那个维度进行unique的处理
            unq_coords, unq_inv, _ = torch.unique(merge_coords, return_inverse=True, return_counts=True, dim=0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    修改后,onnx输入有4个:原始点云points,无重复一维体素网格坐标unq_coords,原始张量中每个元素在处理后的无重复张量中对应的索引unq_inv,网格坐标coords

    下面看CenterPoint转换出的onnx计算图:自定义onnx节点有 ScatterMaxPlugin,ScatterMeanPlugin,ScatterBevPlugin,用tensorRT实现就需要自定义ScatterMaxPlugin,ScatterMeanPlugin,ScatterBevPlugin 3个算子,后续会写下tenorRT自定义算子,并用cuda实现CenterPoint预处理和后处理,从而完成整个CenterPoint部署

    onnx太小看不清,自定义onnx节点如下:
    在这里插入图片描述

    完整的onnx如下:

    在这里插入图片描述

  • 相关阅读:
    超越BERT:多语言大模型的最新进展与挑战
    掌握Perl并发:线程与进程编程全攻略
    (部分不懂,笔记整理未完成)【图论】差分约束
    6-Mysql子查询,多表连接(内连接,外连接,交叉连接)
    python常用进制转换
    【机器学习】线性回归算法:原理、公式推导、损失函数、似然函数、梯度下降
    详解strstr函数:查找子字符串函数及其模拟实现
    高等数学(第七版)同济大学 习题9-10 个人解答
    Java基于SSM+JSP的服装定制系统
    Flutter 常见异常分析
  • 原文地址:https://blog.csdn.net/weixin_42905141/article/details/127545123