    Classic Problems in CUDA High-Performance Computing ②: Prefix Sum



    Written by Will Zhang

    In the previous article, "Classic Problems in CUDA High-Performance Computing ①: Reduction", we discussed how to implement an efficient reduction in CUDA. This time we turn to the next classic problem, Prefix Sum, also known as Scan or Prefix Scan. Scan is a building block of important algorithms such as sorting, so it is essentially a must-learn topic for anyone moving past the basics.

    1

    Problem Definition

    First, let us define the problem informally: given an input array input[n], compute a new array output[n] such that every element output[i] satisfies:

    output[i] = input[0] + input[1] + ... + input[i]

    A concrete example: the input [1, 2, 3, 4] scans to [1, 3, 6, 10].

    On the CPU this can be implemented straightforwardly:

    #include <cstddef>
    #include <cstdint>

    void PrefixSum(const int32_t* input, size_t n, int32_t* output) {
      int32_t sum = 0;
      for (size_t i = 0; i < n; ++i) {
        sum += input[i];
        output[i] = sum;
      }
    }

    Now the question: how do we parallelize this, with thousands of threads cooperating gracefully? The problem contains an obvious dependency: each element depends on all the values before it. Seeing it for the first time, you might well wonder how it could be parallelized at all.

    Going further: how do we parallelize it in CUDA? How do we parallelize within a warp? How do we parallelize when the data fits in Shared Memory, and when it does not?

    2

    ScanThenFan

    Let us first assume that all the data fits in Global Memory; with even larger data the core logic stays the same.

    The first method we introduce is called ScanThenFan, and it is quite intuitive:

    • Split the data in Global Memory into multiple parts; each part is scanned internally by a single thread block, and the part's total sum is written to a PartSum array in Global Memory

    • Scan this PartSum array; we call the scanned result BaseSum

    • Add the corresponding BaseSum to every element of each part

    (Figure: the three ScanThenFan steps described above.)

    3

    Baseline

    Let us not worry yet about how to scan within a block; for now a single thread per block does the work naively, giving the following code:

    __global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                              int32_t* output, size_t n,
                                              size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        // this part processes input[part_begin:part_end]
        // store the sum to part[part_i], the scan to output[part_begin:part_end]
        size_t part_begin = part_i * blockDim.x;
        size_t part_end = min((part_i + 1) * blockDim.x, n);
        if (threadIdx.x == 0) {  // naive implementation
          int32_t acc = 0;
          for (size_t i = part_begin; i < part_end; ++i) {
            acc += input[i];
            output[i] = acc;
          }
          part[part_i] = acc;
        }
      }
    }

    __global__ void ScanPartSumKernel(int32_t* part, size_t part_num) {
      int32_t acc = 0;
      for (size_t i = 0; i < part_num; ++i) {
        acc += part[i];
        part[i] = acc;
      }
    }

    __global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                     size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        if (part_i == 0) {
          continue;
        }
        int32_t index = part_i * blockDim.x + threadIdx.x;
        if (index < n) {
          output[index] += part[part_i - 1];
        }
      }
    }

    // for i in range(n):
    //   output[i] = input[0] + input[1] + ... + input[i]
    void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                     size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      // use buffer[0:part_num] to save the sum of each part
      int32_t* part = buffer;
      // after the following step, part[i] = part_sum[i]
      ScanAndWritePartSumKernel<<<block_num, part_size>>>(input, part, output, n,
                                                          part_num);
      // after the following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
      ScanPartSumKernel<<<1, 1>>>(part, part_num);
      // make the final result
      AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
    }

    There is still a lot of naive code here, but the overall skeleton is now in place. Its runtime of 72390us serves as our baseline.
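
    To run and validate these kernels, a small host-side harness is needed; the sketch below is only an assumption about what such a harness could look like (the problem size and the generous scratch-buffer size are illustrative, not from the original post):

    #include <cassert>
    #include <cstdint>
    #include <vector>
    #include <cuda_runtime.h>

    int main() {
      const size_t n = 1 << 24;  // illustrative problem size
      std::vector<int32_t> h_in(n, 1), h_out(n);
      int32_t *d_in, *d_out, *d_buf;
      cudaMalloc(&d_in, n * sizeof(int32_t));
      cudaMalloc(&d_out, n * sizeof(int32_t));
      cudaMalloc(&d_buf, n * sizeof(int32_t));  // generous scratch space for part sums
      cudaMemcpy(d_in, h_in.data(), n * sizeof(int32_t), cudaMemcpyHostToDevice);

      ScanThenFan(d_in, d_buf, d_out, n);  // the function defined above
      cudaDeviceSynchronize();

      cudaMemcpy(h_out.data(), d_out, n * sizeof(int32_t), cudaMemcpyDeviceToHost);
      // verify against the CPU PrefixSum from the first section
      std::vector<int32_t> ref(n);
      PrefixSum(h_in.data(), n, ref.data());
      for (size_t i = 0; i < n; ++i) assert(h_out[i] == ref[i]);

      cudaFree(d_in);
      cudaFree(d_out);
      cudaFree(d_buf);
      return 0;
    }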

    4

    Shared Memory

    Next, look at ScanAndWritePartSumKernel. As a simple first optimization, we load each part into Shared Memory before running the same naive logic:

    __device__ void ScanBlock(int32_t* shm) {
      if (threadIdx.x == 0) {  // naive implementation
        int32_t acc = 0;
        for (size_t i = 0; i < blockDim.x; ++i) {
          acc += shm[i];
          shm[i] = acc;
        }
      }
      __syncthreads();
    }

    __global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                              int32_t* output, size_t n,
                                              size_t part_num) {
      extern __shared__ int32_t shm[];
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        // load this part's input into shm
        size_t index = part_i * blockDim.x + threadIdx.x;
        shm[threadIdx.x] = index < n ? input[index] : 0;
        __syncthreads();
        // scan on shared memory
        ScanBlock(shm);
        __syncthreads();
        // write result
        if (index < n) {
          output[index] = shm[threadIdx.x];
        }
        if (threadIdx.x == blockDim.x - 1) {
          part[part_i] = shm[threadIdx.x];
        }
      }
    }

    This simple change cuts the time from 72390us to 33726us, thanks to the batched, coalesced reads from Global Memory. Note that the kernel now uses dynamic shared memory, so its launch must pass part_size * sizeof(int32_t) as the third launch parameter.

    5

    ScanBlock

    Now we optimize the scan within a block properly. The block-level scan can be decomposed with the same idea:

    • Organize threads by warp; each warp first scans its own 32 elements and writes its warp total to Shared Memory, which we call WarpSum

    • Use a single warp to scan WarpSum

    • Each warp adds the WarpSum entry of the preceding warp to its final results

    The code is as follows:

    __device__ void ScanWarp(int32_t* shm_data, int32_t lane) {
      if (lane == 0) {  // naive implementation
        int32_t acc = 0;
        for (int32_t i = 0; i < 32; ++i) {
          acc += shm_data[i];
          shm_data[i] = acc;
        }
      }
    }

    __device__ void ScanBlock(int32_t* shm_data) {
      int32_t warp_id = threadIdx.x >> 5;
      int32_t lane = threadIdx.x & 31;  // 31 = 00011111
      __shared__ int32_t warp_sum[32];  // blockDim.x / WarpSize = 32
      // scan each warp
      ScanWarp(shm_data, lane);
      __syncthreads();
      // write the sum of each warp to warp_sum
      if (lane == 31) {
        warp_sum[warp_id] = *shm_data;
      }
      __syncthreads();
      // use a single warp to scan warp_sum
      if (warp_id == 0) {
        ScanWarp(warp_sum + lane, lane);
      }
      __syncthreads();
      // add base
      if (warp_id > 0) {
        *shm_data += warp_sum[warp_id - 1];
      }
      __syncthreads();
    }

    Note that this version of ScanBlock expects a per-thread pointer (each thread passes the address of its own element), so the kernel now calls it as ScanBlock(shm + threadIdx.x). This step reduces the time from 33726us to 9948us.

    6

    ScanWarp

    Next we optimize ScanWarp. To make the algorithm easier to explain, assume we scan 16 values; the algorithm is illustrated in the figure below:

    (Figure: the scan network over 16 values.)

    The 16 points along the horizontal axis represent the 16 values, and time flows from top to bottom. Every node with in-degree two performs an addition and broadcasts the result to its output nodes. For 32 values, the code is as follows:

    __device__ void ScanWarp(int32_t* shm_data) {
      int32_t lane = threadIdx.x & 31;
      volatile int32_t* vshm_data = shm_data;
      if (lane >= 1) {
        vshm_data[0] += vshm_data[-1];
      }
      __syncwarp();
      if (lane >= 2) {
        vshm_data[0] += vshm_data[-2];
      }
      __syncwarp();
      if (lane >= 4) {
        vshm_data[0] += vshm_data[-4];
      }
      __syncwarp();
      if (lane >= 8) {
        vshm_data[0] += vshm_data[-8];
      }
      __syncwarp();
      if (lane >= 16) {
        vshm_data[0] += vshm_data[-16];
      }
      __syncwarp();
    }
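
    For intuition, the same stride-doubling pattern (often called a Hillis-Steele scan) can be written out serially. The sketch below is only an illustration and not part of the original code; the function name is made up for this example:

    #include <cstdint>

    // Serial illustration of the stride-doubling inclusive scan on 16 values:
    // in round d (1, 2, 4, 8), every element i >= d adds the element d positions
    // to its left, using the values from the previous round.
    void HillisSteeleScan16(int32_t* data) {
      for (int32_t d = 1; d < 16; d <<= 1) {
        int32_t prev[16];
        for (int32_t i = 0; i < 16; ++i) prev[i] = data[i];
        for (int32_t i = d; i < 16; ++i) data[i] = prev[i] + prev[i - d];
      }
    }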

    With this algorithm there are no bank conflicts at any step, since in each step the 32 lanes of a warp access 32 consecutive shared-memory words, which fall in distinct banks. The time drops from 9948us to 7595us.

    7

    ZeroPadding

    Next we want to go further and remove the if statements in ScanWarp, i.e., not branch on lane at all, so that every thread in the warp executes exactly the same instructions. That means threads that previously skipped the update would now read out of bounds, so we add padding to keep those accesses in bounds.

    To set up the padding, look back at ScanBlock: its warp_sum array is statically declared rather than sized at kernel launch. To make it easier to change, we switch it to dynamic shared memory specified at launch:

    __device__ void ScanBlock(int32_t* shm_data) {
      int32_t warp_id = threadIdx.x >> 5;
      int32_t lane = threadIdx.x & 31;  // 31 = 00011111
      extern __shared__ int32_t warp_sum[];  // warp_sum[32]
      // scan each warp
      ScanWarp(shm_data);
      __syncthreads();
      // write the sum of each warp to warp_sum
      if (lane == 31) {
        warp_sum[warp_id] = *shm_data;
      }
      __syncthreads();
      // use a single warp to scan warp_sum
      if (warp_id == 0) {
        ScanWarp(warp_sum + lane);
      }
      __syncthreads();
      // add base
      if (warp_id > 0) {
        *shm_data += warp_sum[warp_id - 1];
      }
      __syncthreads();
    }

    __global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                              int32_t* output, size_t n,
                                              size_t part_num) {
      // the first 32 entries are used to save the warp sums
      extern __shared__ int32_t shm[];
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        // load this part's input into shm
        size_t index = part_i * blockDim.x + threadIdx.x;
        shm[32 + threadIdx.x] = index < n ? input[index] : 0;
        __syncthreads();
        // scan on shared memory
        ScanBlock(shm + 32 + threadIdx.x);
        __syncthreads();
        // write result
        if (index < n) {
          output[index] = shm[32 + threadIdx.x];
        }
        if (threadIdx.x == blockDim.x - 1) {
          part[part_i] = shm[32 + threadIdx.x];
        }
      }
    }

    __global__ void ScanPartSumKernel(int32_t* part, size_t part_num) {
      int32_t acc = 0;
      for (size_t i = 0; i < part_num; ++i) {
        acc += part[i];
        part[i] = acc;
      }
    }

    __global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                     size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        if (part_i == 0) {
          continue;
        }
        int32_t index = part_i * blockDim.x + threadIdx.x;
        if (index < n) {
          output[index] += part[part_i - 1];
        }
      }
    }

    // for i in range(n):
    //   output[i] = input[0] + input[1] + ... + input[i]
    void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                     size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      // use buffer[0:part_num] to save the sum of each part
      int32_t* part = buffer;
      // after the following step, part[i] = part_sum[i]
      size_t shm_size = (32 + part_size) * sizeof(int32_t);
      ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
          input, part, output, n, part_num);
      // after the following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
      ScanPartSumKernel<<<1, 1>>>(part, part_num);
      // make the final result
      AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
    }

    Note that the ScanAndWritePartSumKernel launch now recomputes the dynamic shared memory size. To add the padding we grow this shared memory further: each warp needs a 16-element zero padding in front of its data so that the unguarded ScanWarp never reads out of bounds. For part_size = 1024 (32 warps per block) this gives shm_size = (16 + 32 + 32 × (16 + 32)) × sizeof(int32_t) = 1584 × 4 = 6336 bytes. We therefore change ScanThenFan as follows:

    // for i in range(n):
    //   output[i] = input[0] + input[1] + ... + input[i]
    void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                     size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      // use buffer[0:part_num] to save the sum of each part
      int32_t* part = buffer;
      // after the following step, part[i] = part_sum[i]
      size_t warp_num = part_size / 32;
      size_t shm_size = (16 + 32 + warp_num * (16 + 32)) * sizeof(int32_t);
      ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
          input, part, output, n, part_num);
      // after the following step, part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
      ScanPartSumKernel<<<1, 1>>>(part, part_num);
      // make the final result
      AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
    }

    Note the computation of shm_size: we also reserve a 16-element zero padding for warp_sum itself. The kernels are rewritten accordingly:

    __device__ void ScanWarp(int32_t* shm_data) {
      volatile int32_t* vshm_data = shm_data;
      vshm_data[0] += vshm_data[-1];
      vshm_data[0] += vshm_data[-2];
      vshm_data[0] += vshm_data[-4];
      vshm_data[0] += vshm_data[-8];
      vshm_data[0] += vshm_data[-16];
    }

    __device__ void ScanBlock(int32_t* shm_data) {
      int32_t warp_id = threadIdx.x >> 5;
      int32_t lane = threadIdx.x & 31;
      extern __shared__ int32_t warp_sum[];  // 16 zero padding
      // scan each warp
      ScanWarp(shm_data);
      __syncthreads();
      // write the sum of each warp to warp_sum
      if (lane == 31) {
        warp_sum[16 + warp_id] = *shm_data;
      }
      __syncthreads();
      // use a single warp to scan warp_sum
      if (warp_id == 0) {
        ScanWarp(warp_sum + 16 + lane);
      }
      __syncthreads();
      // add base
      if (warp_id > 0) {
        *shm_data += warp_sum[16 + warp_id - 1];
      }
      __syncthreads();
    }

    __global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                              int32_t* output, size_t n,
                                              size_t part_num) {
      // the first 16 + 32 entries hold the zero padding and the warp sums
      extern __shared__ int32_t shm[];
      int32_t warp_id = threadIdx.x >> 5;
      int32_t lane = threadIdx.x & 31;
      // initialize the zero padding
      if (threadIdx.x < 16) {
        shm[threadIdx.x] = 0;
      }
      if (lane < 16) {
        shm[(16 + 32) + warp_id * (16 + 32) + lane] = 0;
      }
      __syncthreads();
      // process each part
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        // load this part's input into shm
        size_t index = part_i * blockDim.x + threadIdx.x;
        int32_t* myshm = shm + (16 + 32) + warp_id * (16 + 32) + 16 + lane;
        *myshm = index < n ? input[index] : 0;
        __syncthreads();
        // scan on shared memory
        ScanBlock(myshm);
        __syncthreads();
        // write result
        if (index < n) {
          output[index] = *myshm;
        }
        if (threadIdx.x == blockDim.x - 1) {
          part[part_i] = *myshm;
        }
      }
    }
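
    For reference, here is a sketch of the per-block shared-memory layout the kernel above assumes (offsets in int32_t elements; this summary is derived from the index arithmetic in the code, not taken verbatim from the original post):

    // [0, 16)      zero padding read by the warp_sum scan
    // [16, 48)     warp_sum[0..31]
    // for each warp w (0 <= w < 32):
    //   [48 + 48*w, 48 + 48*w + 16)       zero padding read by warp w's ScanWarp
    //   [48 + 48*w + 16, 48 + 48*(w+1))   the 32 data elements scanned by warp w
    // total: 48 + 32 * 48 = 1584 elements = 6336 bytes, matching shm_size above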

    The changes are fairly extensive, mostly in the index computation. This step brings the time from 7595us down to 7516us. The gain looks small mainly because it is hidden by the current bottleneck. ScanWarp can be optimized further with warp shuffles; to show that effect clearly we defer it and first attack the bottleneck.

    8

    Recursion

    The current bottleneck is that, for simplicity, the scan of PartSum is still done by a single thread. We can instead do it recursively:

    // for i in range(n):
    //   output[i] = input[0] + input[1] + ... + input[i]
    void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                     size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      // use buffer[0:part_num] to save the sum of each part
      int32_t* part = buffer;
      // after the following step, part[i] = part_sum[i]
      size_t warp_num = part_size / 32;
      size_t shm_size = (16 + 32 + warp_num * (16 + 32)) * sizeof(int32_t);
      ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
          input, part, output, n, part_num);
      if (part_num >= 2) {
        // after the following step
        //   part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
        ScanThenFan(part, buffer + part_num, part, part_num);
        // make the final result
        AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
      }
    }

    After removing that naive single-thread step, the time drops from 7516us to 3972us. (Note that each level of the recursion consumes part_num entries of buffer and recurses into buffer + part_num, so the scratch buffer must be sized to cover all levels.)

    9

    WarpShuffle

    Next we implement the warp scan with warp shuffles. __shfl_up_sync(mask, val, delta) returns the val held by the lane delta positions lower in the warp; the lowest delta lanes, which have no source lane, simply get their own val back, which is why the lane >= delta guards remain:

    __device__ int32_t ScanWarp(int32_t val) {
      int32_t lane = threadIdx.x & 31;
      int32_t tmp = __shfl_up_sync(0xffffffff, val, 1);
      if (lane >= 1) {
        val += tmp;
      }
      tmp = __shfl_up_sync(0xffffffff, val, 2);
      if (lane >= 2) {
        val += tmp;
      }
      tmp = __shfl_up_sync(0xffffffff, val, 4);
      if (lane >= 4) {
        val += tmp;
      }
      tmp = __shfl_up_sync(0xffffffff, val, 8);
      if (lane >= 8) {
        val += tmp;
      }
      tmp = __shfl_up_sync(0xffffffff, val, 16);
      if (lane >= 16) {
        val += tmp;
      }
      return val;
    }

    The time drops from 3972us to 3747us.
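
    As a quick sanity check of this shuffle-based version, one can run it on a single warp. The test kernel below is a hypothetical illustration and not part of the original code:

    // One warp scans the values 0..31; lane i should end up with i * (i + 1) / 2.
    __global__ void ScanWarpTestKernel(int32_t* out) {
      int32_t val = threadIdx.x;         // lane i holds the value i
      out[threadIdx.x] = ScanWarp(val);  // inclusive scan within the warp
    }

    // launch example: ScanWarpTestKernel<<<1, 32>>>(d_out);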

    10

    PTX

    We can go further and inspect the compiled PTX with cuobjdump. I have added some comments:

    // declare registers
    .reg .pred %p<11>;
    .reg .b32 %r<39>;
    // load the parameter into r35
    ld.param.u32 %r35, [_Z8ScanWarpi_param_0];
    // load threadIdx.x into r18
    mov.u32 %r18, %tid.x;
    // r1 holds lane = threadIdx.x & 31
    and.b32 %r1, %r18, 31;
    // r19 holds 0
    mov.u32 %r19, 0;
    // r20 holds 1
    mov.u32 %r20, 1;
    // r21 holds -1
    mov.u32 %r21, -1;
    // r2|p1 = __shfl_up_sync(val, delta=1, 0, membermask=-1)
    // if the source lane is in range, store the result in r2 and set p1 to true,
    // otherwise set p1 to false
    // r2 corresponds to tmp in our code
    shfl.sync.up.b32 %r2|%p1, %r35, %r20, %r19, %r21;
    // p6 = (lane == 0)
    setp.eq.s32 %p6, %r1, 0;
    // if p6 is true, branch to BB0_2
    @%p6 bra BB0_2;
    // val += tmp
    add.s32 %r35, %r2, %r35;
    // delta = 2
    BB0_2:
    mov.u32 %r23, 2;
    shfl.sync.up.b32 %r5|%p2, %r35, %r23, %r19, %r21;
    setp.lt.u32 %p7, %r1, 2;
    @%p7 bra BB0_4;
    add.s32 %r35, %r5, %r35;
    ...

    As we can see, the predicate register produced by __shfl_up_sync can be used directly to perform the conditional addition, avoiding the conditional branch instructions the compiler generated. The code is as follows:

    __device__ __forceinline__ int32_t ScanWarp(int32_t val) {
      int32_t result;
      asm("{"
          ".reg .s32 r<5>;"
          ".reg .pred p<5>;"
          "shfl.sync.up.b32 r0|p0, %1, 1, 0, -1;"
          "@p0 add.s32 r0, r0, %1;"
          "shfl.sync.up.b32 r1|p1, r0, 2, 0, -1;"
          "@p1 add.s32 r1, r1, r0;"
          "shfl.sync.up.b32 r2|p2, r1, 4, 0, -1;"
          "@p2 add.s32 r2, r2, r1;"
          "shfl.sync.up.b32 r3|p3, r2, 8, 0, -1;"
          "@p3 add.s32 r3, r3, r2;"
          "shfl.sync.up.b32 r4|p4, r3, 16, 0, -1;"
          "@p4 add.s32 r4, r4, r3;"
          "mov.s32 %0, r4;"
          "}"
          : "=r"(result)
          : "r"(val));
      return result;
    }

    In addition, since ScanWarp now works entirely on registers, we can drop the large shared-memory buffer it used to depend on:

    __device__ __forceinline__ int32_t ScanBlock(int32_t val) {
      int32_t warp_id = threadIdx.x >> 5;
      int32_t lane = threadIdx.x & 31;
      extern __shared__ int32_t warp_sum[];
      // scan each warp
      val = ScanWarp(val);
      __syncthreads();
      // write the sum of each warp to warp_sum
      if (lane == 31) {
        warp_sum[warp_id] = val;
      }
      __syncthreads();
      // use a single warp to scan warp_sum
      if (warp_id == 0) {
        warp_sum[lane] = ScanWarp(warp_sum[lane]);
      }
      __syncthreads();
      // add base
      if (warp_id > 0) {
        val += warp_sum[warp_id - 1];
      }
      __syncthreads();
      return val;
    }

    __global__ void ScanAndWritePartSumKernel(const int32_t* input, int32_t* part,
                                              int32_t* output, size_t n,
                                              size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        size_t index = part_i * blockDim.x + threadIdx.x;
        int32_t val = index < n ? input[index] : 0;
        val = ScanBlock(val);
        __syncthreads();
        if (index < n) {
          output[index] = val;
        }
        if (threadIdx.x == blockDim.x - 1) {
          part[part_i] = val;
        }
      }
    }

    __global__ void AddBaseSumKernel(int32_t* part, int32_t* output, size_t n,
                                     size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        if (part_i == 0) {
          continue;
        }
        int32_t index = part_i * blockDim.x + threadIdx.x;
        if (index < n) {
          output[index] += part[part_i - 1];
        }
      }
    }

    // for i in range(n):
    //   output[i] = input[0] + input[1] + ... + input[i]
    void ScanThenFan(const int32_t* input, int32_t* buffer, int32_t* output,
                     size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      // use buffer[0:part_num] to save the sum of each part
      int32_t* part = buffer;
      // after the following step, part[i] = part_sum[i]
      size_t shm_size = 32 * sizeof(int32_t);
      ScanAndWritePartSumKernel<<<block_num, part_size, shm_size>>>(
          input, part, output, n, part_num);
      if (part_num >= 2) {
        // after the following step
        //   part[i] = part_sum[0] + part_sum[1] + ... + part_sum[i]
        ScanThenFan(part, buffer + part_num, part, part_num);
        // make the final result
        AddBaseSumKernel<<<block_num, part_size>>>(part, output, n, part_num);
      }
    }

    The time now drops to 3442us.

    11

    ReduceThenScan

    Unlike ScanThenFan, which scans inside each part during the first pass, here the first pass only computes each part's sum, and the scan is done in the final pass. The code is as follows:

    #include <cub/cub.cuh>  // for cub::BlockReduce

    __global__ void ReducePartSumKernel(const int32_t* input, int32_t* part_sum,
                                        int32_t* output, size_t n,
                                        size_t part_num) {
      using BlockReduce = cub::BlockReduce<int32_t, 1024>;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        size_t index = part_i * blockDim.x + threadIdx.x;
        int32_t val = index < n ? input[index] : 0;
        int32_t sum = BlockReduce(temp_storage).Sum(val);
        if (threadIdx.x == 0) {
          part_sum[part_i] = sum;
        }
        __syncthreads();
      }
    }

    __global__ void ScanWithBaseSum(const int32_t* input, int32_t* part_sum,
                                    int32_t* output, size_t n, size_t part_num) {
      for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
        size_t index = part_i * blockDim.x + threadIdx.x;
        int32_t val = index < n ? input[index] : 0;
        val = ScanBlock(val);
        __syncthreads();
        if (part_i >= 1) {
          val += part_sum[part_i - 1];
        }
        if (index < n) {
          output[index] = val;
        }
      }
    }

    void ReduceThenScan(const int32_t* input, int32_t* buffer, int32_t* output,
                        size_t n) {
      size_t part_size = 1024;  // tuned
      size_t part_num = (n + part_size - 1) / part_size;
      size_t block_num = std::min<size_t>(part_num, 128);
      int32_t* part_sum = buffer;  // use buffer[0:part_num]
      if (part_num >= 2) {
        ReducePartSumKernel<<<block_num, part_size>>>(input, part_sum, output, n,
                                                      part_num);
        ReduceThenScan(part_sum, buffer + part_num, part_sum, part_num);
      }
      ScanWithBaseSum<<<block_num, part_size, 32 * sizeof(int32_t)>>>(
          input, part_sum, output, n, part_num);
    }

    For simplicity we use cub's BlockReduce here. This version takes 3503us, a slight increase.

    All the algorithms so far involve recursion. To eliminate it, we continue the ReduceThenScan idea but cut the parts larger, for example making the number of parts equal to the number of blocks, so that no recursion is needed:

    __global__ void ReducePartSumKernelSinglePass(const int32_t* input,
                                                  int32_t* g_part_sum, size_t n,
                                                  size_t part_size) {
      // this block processes input[part_begin:part_end]
      size_t part_begin = blockIdx.x * part_size;
      size_t part_end = min((blockIdx.x + 1) * part_size, n);
      // per-thread partial sum, then a block-wide reduction
      int32_t part_sum = 0;
      for (size_t i = part_begin + threadIdx.x; i < part_end; i += blockDim.x) {
        part_sum += input[i];
      }
      using BlockReduce = cub::BlockReduce<int32_t, 1024>;
      __shared__ typename BlockReduce::TempStorage temp_storage;
      part_sum = BlockReduce(temp_storage).Sum(part_sum);
      __syncthreads();
      if (threadIdx.x == 0) {
        g_part_sum[blockIdx.x] = part_sum;
      }
    }

    __global__ void ScanWithBaseSumSinglePass(const int32_t* input,
                                              int32_t* g_base_sum, int32_t* output,
                                              size_t n, size_t part_size,
                                              bool debug) {
      // base sum
      __shared__ int32_t base_sum;
      if (threadIdx.x == 0) {
        if (blockIdx.x == 0) {
          base_sum = 0;
        } else {
          base_sum = g_base_sum[blockIdx.x - 1];
        }
      }
      __syncthreads();
      // this block processes input[part_begin:part_end]
      size_t part_begin = blockIdx.x * part_size;
      size_t part_end = (blockIdx.x + 1) * part_size;
      for (size_t i = part_begin + threadIdx.x; i < part_end; i += blockDim.x) {
        int32_t val = i < n ? input[i] : 0;
        val = ScanBlock(val);
        if (i < n) {
          output[i] = val + base_sum;
        }
        __syncthreads();
        if (threadIdx.x == blockDim.x - 1) {
          base_sum += val;
        }
        __syncthreads();
      }
    }

    void ReduceThenScanTwoPass(const int32_t* input, int32_t* part_sum,
                               int32_t* output, size_t n) {
      size_t part_num = 1024;
      size_t part_size = (n + part_num - 1) / part_num;
      ReducePartSumKernelSinglePass<<<part_num, 1024>>>(input, part_sum, n,
                                                        part_size);
      ScanWithBaseSumSinglePass<<<1, 1024, 32 * sizeof(int32_t)>>>(
          part_sum, nullptr, part_sum, part_num, part_num, true);
      ScanWithBaseSumSinglePass<<<part_num, 1024, 32 * sizeof(int32_t)>>>(
          input, part_sum, output, n, part_size, false);
    }

    The time drops to 2467us.

    12

    Conclusion

    Even after all these optimizations, there is still a sizable gap to CUB, whose time is 1444us. But I have always believed in "if you can't beat them, join them", and CUB is open source, so when I find time I will dig into the CUB code and write a walkthrough.
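
    For reference, the CUB number above can presumably be reproduced with cub::DeviceScan::InclusiveSum; a minimal sketch (the wrapper name is illustrative, and the device pointers are assumed to be set up as in the harness earlier):

    #include <cub/cub.cuh>

    // Query the required temporary storage, allocate it, then run the scan.
    void CubInclusiveSum(const int32_t* d_in, int32_t* d_out, size_t n) {
      void* d_temp_storage = nullptr;
      size_t temp_storage_bytes = 0;
      cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, n);
      cudaMalloc(&d_temp_storage, temp_storage_bytes);
      cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, n);
      cudaFree(d_temp_storage);
    }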

    Reference

    The CUDA Handbook: A Comprehensive Guide to GPU Programming. https://www.amazon.com/CUDA-Handbook-Comprehensive-Guide-Programming/dp/0321809467

    (Original article: https://zhuanlan.zhihu.com/p/423992093)


    Welcome to download and try OneFlow, a new-generation open-source deep learning framework: https://github.com/Oneflow-Inc/oneflow/

