• A reduce operator for on-device (mobile) GPUs, implemented with OpenCL


    All threads of a single work-group are used to compute the reduction over the innermost dimension; the work-group size is fixed at 64 threads.

    The accumulation within a warp/sub-group uses sub_group_reduce_add, which is analogous to a CUDA warp-shuffle reduction and should, in principle, help performance.

    The OpenCL kernel and test code are listed below; on an Arm Mali GPU this implementation significantly outperforms the TFLite GPU delegate. Atomics are not used to accumulate the results of different warps, because OpenCL atomics have no native float support; a different scheme (local memory plus a second sub-group reduction) is used instead.
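
    For reference, when a floating-point atomic add is really needed, the usual workaround is a compare-and-swap loop on the value's bit pattern. A minimal sketch of such a helper (the name is made up for illustration; it is not used in the kernel below and is shown only to illustrate what the local-memory scheme avoids):

    // Hypothetical helper in OpenCL C: emulates atomic float addition using the
    // baseline 32-bit atomic_cmpxchg on the raw bits of the float.
    inline void atomic_add_float(volatile __global float* addr, float val) {
        union { uint u; float f; } old_val, new_val;
        do {
            old_val.f = *addr;            // snapshot the current value
            new_val.f = old_val.f + val;  // compute the updated value
        } while (atomic_cmpxchg((volatile __global uint*)addr,
                                old_val.u, new_val.u) != old_val.u);
    }
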

    #include <cstring>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>
    #include "mem_helper.h"
    #define CL_HPP_TARGET_OPENCL_VERSION 300
    #include <CL/opencl.hpp>

    using TEST_DTYPE = float;
    using namespace std;

    std::string kernel_source{R"(
    // we use all threads in a work-group to compute the reduce mean/sum over the last tensor axis
    kernel void reduce_kernel(__global const float* d_in, __global float* d_out, int channel_size, int block_size) {
        int gid = get_global_id(0);
        int lid = get_local_id(0);
        int batch = get_group_id(0);
        int warp_id = get_sub_group_id();
        int lane_id = get_sub_group_local_id(); // id of this thread within its sub-group
        int sg_num = get_num_sub_groups();
        // for sub-group size 8, block_size should be <= 8 * 8 = 64
        // for sub-group size 16, block_size should be <= 16 * 16 = 256
        int addr_offset = batch * channel_size;
        float mean = 0.0f;
        // each thread accumulates a strided slice of the row
        for (int i = lid; i < channel_size; i += block_size) {
            float x0 = d_in[addr_offset + i];
            mean += x0;
        }
        // all values within a sub-group are reduced together
        mean = sub_group_reduce_add(mean);
        // local memory is used to exchange data between sub-groups:
        // lane 0 of each sub-group writes its partial sum to local memory,
        // then the partial sums are read back and reduced once more within a sub-group.
        // the number of sub-groups (work_group_size / sub_group_size) must not exceed
        // the sub-group size (typically 8 or 16), and MAX_WARP_NUM must be at least
        // the number of sub-groups.
        #define MAX_WARP_NUM 16
        local float shared_sums[MAX_WARP_NUM];
        if (lane_id == 0) {
            // atomic_store(&shared_sums[warp_id], mean);
            shared_sums[warp_id] = mean;
        }
        barrier(CLK_LOCAL_MEM_FENCE);
        // mean = (lid < sg_num) ? shared_sums[lid] : 0; // only the first sub-group gets the mean
        mean = (lane_id < sg_num) ? shared_sums[lane_id] : 0; // every sub-group gets the mean
        mean = sub_group_reduce_add(mean);
        mean /= channel_size;
        if (lid == 0) {
            d_out[batch] = mean;
        }
    }
    )"};
    int main() {
        std::vector<cl::Platform> platforms;
        cl::Platform::get(&platforms);
        std::cout << "get platform num:" << platforms.size() << std::endl;
        cl::Platform plat;
        for (auto& p : platforms) {
            std::string platver = p.getInfo<CL_PLATFORM_VERSION>();
            if (platver.find("OpenCL 2.") != std::string::npos || platver.find("OpenCL 3.") != std::string::npos) {
                // Note: an OpenCL 3.x platform may not support all required features!
                plat = p;
            }
        }
        if (plat() == 0) {
            std::cout << "No OpenCL 2.0 or newer platform found.\n";
            return -1;
        }
        std::cout << "platform name:" << plat.getInfo<CL_PLATFORM_NAME>() << std::endl;
        cl::Platform newP = cl::Platform::setDefault(plat);
        if (newP != plat) {
            std::cout << "Error setting default platform.\n";
            return -1;
        }
        // get the devices (CPUs, GPUs) of the default platform
        std::vector<cl::Device> all_devices;
        newP.getDevices(CL_DEVICE_TYPE_GPU, &all_devices); // CL_DEVICE_TYPE_ALL
        std::cout << "get all_devices num:" << all_devices.size() << std::endl;
        if (all_devices.size() == 0) {
            std::cout << " No devices found. Check OpenCL installation!\n";
            exit(1);
        }
        // cl::Device default_device = cl::Device::getDefault();
        cl::Device default_device = all_devices[0];
        std::cout << "device name: " << default_device.getInfo<CL_DEVICE_NAME>() << std::endl;
        // a context is like a "runtime link" to the device and platform;
        // i.e. communication is possible
        cl::Context context({default_device});
        cl::CommandQueue queue(context, default_device);

        int batch = 512;
        int channel_size = 768;
        vector<int> shape1 = {batch, channel_size};
        vector<int> shape2 = {batch,};
        MemoryHelper<TEST_DTYPE> mem_in(shape1);
        MemoryHelper<TEST_DTYPE> mem_out(shape2);
        mem_in.StepInit();
        // CL_MEM_WRITE_ONLY CL_MEM_READ_ONLY CL_MEM_READ_WRITE
        cl::Buffer d_in = cl::Buffer(context, CL_MEM_READ_WRITE, mem_in.bytes);
        cl::Buffer d_out = cl::Buffer(context, CL_MEM_READ_WRITE, mem_out.bytes);
        memset(mem_out.Mem(), 0, mem_out.bytes);
        // push write commands to queue
        queue.enqueueWriteBuffer(d_in, CL_TRUE, 0, mem_in.bytes, mem_in.Mem());

        std::vector<std::string> programStrings;
        programStrings.push_back(kernel_source);
        cl::Program program(context, programStrings);
        if (program.build({default_device}, "-cl-std=CL3.0") != CL_SUCCESS) {
            std::cout << "Error building: " << program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(default_device) << std::endl;
            exit(1);
        }
        auto cl_kernel = cl::KernelFunctor<cl::Buffer, cl::Buffer, int, int>(program, "reduce_kernel");

        int block_size = 64;
        if (block_size > channel_size) {
            block_size = channel_size;
            block_size = (block_size / 8) * 8; // round down to a multiple of 8
        }
        int local_thread_num = block_size;
        int total_thread_num = batch * local_thread_num;
        // EnqueueArgs overloads: global, or global + local, or offset + global + local
        cl::EnqueueArgs kernel_args(queue, cl::NDRange(total_thread_num), cl::NDRange(local_thread_num));
        cl_kernel(kernel_args, d_in, d_out, channel_size, block_size);

        queue.enqueueReadBuffer(d_out, CL_TRUE, 0, mem_out.bytes, mem_out.Mem());
        std::cout << "results:" << std::endl;
        TEST_DTYPE* h_c = mem_out.Mem();
        for (size_t i = 0; i < mem_out.elem_num; i++) {
            std::cout << float(h_c[i]) << " ";
        }
        std::cout << std::endl;
        return 0;
    }
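
    To sanity-check the result, a CPU reference can be computed and compared against the kernel output. A minimal sketch that could be placed just before the return 0; statement in main() above; it only relies on variables already in scope at that point:

    // Reference check (sketch): recompute each row mean on the CPU and track the max error.
    TEST_DTYPE* h_in = mem_in.Mem();
    double max_err = 0.0;
    for (int b = 0; b < batch; b++) {
        double ref = 0.0;
        for (int c = 0; c < channel_size; c++) {
            ref += h_in[b * channel_size + c];
        }
        ref /= channel_size;
        double err = ref - double(h_c[b]);
        if (err < 0.0) err = -err;
        if (err > max_err) max_err = err;
    }
    std::cout << "max abs error vs CPU reference: " << max_err << std::endl;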

    mem_helper.h:

    #include <cstdlib>
    #include <memory>
    #include <string>
    #include <vector>
    using namespace std;

    template <class T>
    class MemoryHelper {
    public:
        const vector<int> shape;
        const size_t elem_num = 0;
        const string name;
        const size_t bytes = 0;
        std::unique_ptr<T[]> h_mem = nullptr;

    public:
        MemoryHelper(const vector<int>& shape, const string& name = "")
            : shape(shape),
              elem_num(GetElemNum(shape)),
              name(name),
              bytes(elem_num * sizeof(T)) {
            h_mem = std::make_unique<T[]>(elem_num);
        }
        void RandInit(int seed = 0) {
            srand(seed);
            for (size_t i = 0; i < elem_num; i++) {
                h_mem[i] = T(rand() % 100);
            }
        }
        void StepInit(float ratio = 0.01f, float bias = 0.0f) {
            for (size_t i = 0; i < elem_num; i++) {
                h_mem[i] = i * ratio + bias;
            }
        }
        T* Mem() {
            return h_mem.get();
        }
        size_t GetBytes() {
            return bytes;
        }

    public:
        static size_t GetElemNum(const vector<int>& shape) {
            size_t elem_num = 1;
            for (auto elem : shape) {
                elem_num *= elem;
            }
            return elem_num;
        }
    };

    Additional information

    The Mali GPU sub-group size corresponds to the CUDA warp size; it can also be queried with the built-in functions listed below (a host-side query sketch follows the quoted excerpts):

    G.2.2 OpenCL 2.1 built-in functions

    Several new built-in functions are added in OpenCL 2.1.
    The new functions are:
    • get_enqueued_num_sub_groups
    • get_kernel_max_sub_group_size_for_ndrange
    • get_kernel_sub_group_count_for_ndrange
    • get_max_sub_group_size
    • get_num_sub_groups
    • get_sub_group_size
    • get_sub_group_local_id
    • get_sub_group_id
    • sub_group_all
    • sub_group_any
    • sub_group_barrier
    • sub_group_broadcast
    • sub_group_commit_read_pipe
    • sub_group_commit_write_pipe
    • sub_group_reduce_<op>
    • sub_group_reserve_read_pipe
    • sub_group_reserve_write_pipe
    • sub_group_scan_exclusive_<op>
    • sub_group_scan_inclusive_<op>

    The <op> in sub_group_reduce_<op>, sub_group_scan_inclusive_<op> and sub_group_scan_exclusive_<op> defines the operator and can be add, min or max.

    For the sub_group_reduce, sub_group_scan_exclusive, and sub_group_scan_inclusive functions, gentype is int, uint, long, ulong, or float.

    If cl_khr_fp16 is supported, gentype also includes half.

    If cl_khr_fp64 or doubles are supported, gentype also includes double.
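
    On the host side, the actual sub-group size for a given work-group size can be queried as well. A sketch using the OpenCL 2.1+ C API; it assumes the program and default_device objects from the host code above (older drivers that only expose cl_khr_subgroups provide clGetKernelSubGroupInfoKHR instead):

    // Query the maximum sub-group size the kernel will use for a 64-thread work-group.
    cl::Kernel k(program, "reduce_kernel");
    size_t local_size[1] = {64};   // intended work-group size
    size_t max_sg_size = 0;        // typically 8 or 16 on Mali
    clGetKernelSubGroupInfo(k(), default_device(),
                            CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
                            sizeof(local_size), local_size,
                            sizeof(max_sg_size), &max_sg_size, NULL);
    std::cout << "max sub-group size: " << max_sg_size << std::endl;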

    Other warp-shuffle-related functions

    // These functions are available to devices supporting cl_khr_subgroup_extended_types:
    // Note: Existing functions supporting additional data types.

    gentype sub_group_broadcast( gentype value, uint index )
    gentype sub_group_reduce_add( gentype value )
    gentype sub_group_reduce_min( gentype value )
    gentype sub_group_reduce_max( gentype value )

    gentype sub_group_scan_inclusive_add( gentype value )
    gentype sub_group_scan_inclusive_min( gentype value )
    gentype sub_group_scan_inclusive_max( gentype value )

    gentype sub_group_scan_exclusive_add( gentype value )
    gentype sub_group_scan_exclusive_min( gentype value )
    gentype sub_group_scan_exclusive_max( gentype value )

    // These functions are available to devices supporting cl_khr_subgroup_shuffle:
    gentype sub_group_shuffle( gentype value, uint index )
    gentype sub_group_shuffle_xor( gentype value, uint mask )
    // These functions are available to devices supporting cl_khr_subgroup_shuffle_relative:
    gentype sub_group_shuffle_up( gentype value, uint delta )
    gentype sub_group_shuffle_down( gentype value, uint delta )
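
    With cl_khr_subgroup_shuffle_relative, the intra-sub-group reduction can also be written by hand in the CUDA warp-shuffle style. A minimal OpenCL C sketch, equivalent in effect to sub_group_reduce_add, assuming a full sub-group whose size is a power of two (8 or 16 on typical Mali GPUs); the helper name is made up for illustration:

    inline float sub_group_sum_shuffle(float v) {
        // fold partial sums downward with halving offsets; only in-range lanes feed lane 0
        for (uint offset = get_max_sub_group_size() / 2; offset > 0; offset /= 2) {
            v += sub_group_shuffle_down(v, offset);
        }
        return v; // lane 0 now holds the sum over the whole sub-group
    }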

    References

    The OpenCL™ Extension Specification

    Arm® Mali™ Bifrost and Valhall OpenCL Developer Guide

    https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Ext.html#_extended_subgroup_functions

    For atomic operations, see the book 《OpenCL异构并行计算:原理、机制与优化实践》 (OpenCL Heterogeneous Parallel Computing: Principles, Mechanisms, and Optimization in Practice).

  • Original article: https://blog.csdn.net/u013701860/article/details/126141236