• C++ Notes -- Implementing Tensor Concatenation in C++


    1--Problem Description

            Given two blobs in NCHW layout, concatenate them along the H dimension:
            (N1, C1, H1, W1) + (N2, C2, H2, W2) → (N1, C1, H1 + H2, W1), which requires N1 = N2, C1 = C2, W1 = W2
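
            Because the data is stored in row-major NCHW order, the element (n, c, h, w) sits at linear index ((n*C + c)*H + h)*W + w, so concatenating along H just means copying a's rows into rows 0..H1-1 and b's rows into rows H1..H1+H2-1 of each (n, c) plane. A minimal sketch of that index computation (the helper name Index is illustrative, not part of the original code):

            // Row-major NCHW linear index: w varies fastest, n slowest.
            // Hypothetical helper shown only to illustrate the memory layout.
            inline int Index(int n, int c, int h, int w, int C, int H, int W){
                return ((n * C + c) * H + h) * W + w;
            }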

    2--Example Code

    #include <iostream>
    #include <cstdlib>
    #include <cstring>
    #include <ctime>

    // Concatenate two tensors along the H dimension in C++
    struct Tensor{
        Tensor(int n, int c, int h, int w){
            this->N = n;
            this->C = c;
            this->H = h;
            this->W = w;
            this->num = n * c * h * w;
            this->data = new float[this->num];
            // Fill the buffer with random values
            for(int i = 0; i < this->num; i++){
                data[i] = rand() % 10;
            }
        }
        ~Tensor(){
            delete[] data;
        }
        float *data = nullptr;
        int N = 0;
        int C = 0;
        int H = 0;
        int W = 0;
        int num = 0;
        // Return a reference to the element at index (ni, ci, hi, wi)
        float &at(int ni, int ci, int hi, int wi) const {
            // idx = wi + hi*W + ci*H*W + ni*C*H*W, similar to a CUDA thread-index computation
            return data[W * (H * (C * ni + ci) + hi) + wi];
        }
    };
    // Given two NCHW blobs, concatenate them along the H dimension,
    // copying one element at a time through at()
    void Concat1(const Tensor *a, const Tensor *b, Tensor *c){
        for(int ni = 0; ni < a->N; ++ni){
            for(int ci = 0; ci < a->C; ++ci){
                // a's rows fill the top of c
                for(int hi = 0; hi < a->H; ++hi){
                    for(int wi = 0; wi < a->W; ++wi){
                        c->at(ni, ci, hi, wi) = a->at(ni, ci, hi, wi);
                    }
                }
                // b's rows follow, shifted down by a->H
                for(int hi = 0; hi < b->H; ++hi){
                    for(int wi = 0; wi < b->W; ++wi){
                        c->at(ni, ci, a->H + hi, wi) = b->at(ni, ci, hi, wi);
                    }
                }
            }
        }
    }
    // Same concatenation, copying one row (W elements) at a time with memcpy.
    // The offsets are element counts: pointer arithmetic on float* already scales
    // by sizeof(float), so only the byte count passed to memcpy uses sizeof(float).
    void Concat2(const Tensor *a, const Tensor *b, Tensor *c){
        for(int ni = 0; ni < a->N; ++ni){
            for(int ci = 0; ci < a->C; ++ci){
                for(int hi = 0; hi < a->H; ++hi){
                    int offseta = hi * a->W + ci * a->H * a->W + ni * a->C * a->H * a->W;
                    int offsetc = hi * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W;
                    memcpy(c->data + offsetc, a->data + offseta, a->W * sizeof(float));
                }
                for(int hi = 0; hi < b->H; ++hi){
                    int offsetc = (hi + a->H) * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W;
                    int offsetb = hi * b->W + ci * b->H * b->W + ni * b->C * b->H * b->W;
                    memcpy(c->data + offsetc, b->data + offsetb, b->W * sizeof(float));
                }
            }
        }
    }
    // Same concatenation, copying a whole H*W plane at a time with memcpy,
    // which works because each (n, c) plane is contiguous in NCHW layout
    void Concat3(const Tensor *a, const Tensor *b, Tensor *c){
        for(int ni = 0; ni < a->N; ++ni){
            for(int ci = 0; ci < a->C; ++ci){
                int offseta = ci * a->H * a->W + ni * a->C * a->H * a->W;
                int offsetc1 = ci * c->H * c->W + ni * c->C * c->H * c->W;
                memcpy(c->data + offsetc1, a->data + offseta, a->W * a->H * sizeof(float));
                int offsetc2 = offsetc1 + a->W * a->H;
                int offsetb = ci * b->H * b->W + ni * b->C * b->H * b->W;
                memcpy(c->data + offsetc2, b->data + offsetb, b->W * b->H * sizeof(float));
            }
        }
    }
    int main(int argc, char *argv[]){
        srand(static_cast<unsigned>(time(nullptr)));
        int N1 = 1, C1 = 1, H1 = 2, W1 = 2;
        int N2 = 1, C2 = 1, H2 = 2, W2 = 2;
        Tensor *a = new Tensor(N1, C1, H1, W1);
        Tensor *b = new Tensor(N2, C2, H2, W2);
        Tensor *c = new Tensor(N1, C1, H1 + H2, W1);
        // Tensor a(N1, C1, H1, W1);
        // Tensor b(N2, C2, H2, W2);
        // Tensor c(N1, C1, H1+H2, W1);
        Concat1(a, b, c);
        for(int n = 0; n < N1; n++){
            for(int channel = 0; channel < C1; channel++){
                for(int h = 0; h < H1 + H2; h++){
                    for(int w = 0; w < W1; w++){
                        std::cout << c->at(n, channel, h, w) << " ";
                    }
                    std::cout << std::endl;
                }
            }
        }
        std::cout << "-------------" << std::endl;
        Concat2(a, b, c);
        for(int n = 0; n < N1; n++){
            for(int channel = 0; channel < C1; channel++){
                for(int h = 0; h < H1 + H2; h++){
                    for(int w = 0; w < W1; w++){
                        std::cout << c->at(n, channel, h, w) << " ";
                    }
                    std::cout << std::endl;
                }
            }
        }
        std::cout << "-------------" << std::endl;
        Concat3(a, b, c);
        for(int n = 0; n < N1; n++){
            for(int channel = 0; channel < C1; channel++){
                for(int h = 0; h < H1 + H2; h++){
                    for(int w = 0; w < W1; w++){
                        std::cout << c->at(n, channel, h, w) << " ";
                    }
                    std::cout << std::endl;
                }
            }
        }
        delete a;
        delete b;
        delete c;
        return 0;
    }
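
    The listing above manages the buffer with a raw new[]/delete[] pair. As a minimal alternative sketch (not part of the original post; the struct name VecTensor is illustrative), the same NCHW layout can be held in a std::vector so the memory is released automatically:

    #include <vector>

    // Sketch of an RAII-style tensor: same NCHW indexing as above,
    // but the buffer is owned by std::vector instead of a raw pointer.
    struct VecTensor{
        int N, C, H, W;
        std::vector<float> data;
        VecTensor(int n, int c, int h, int w)
            : N(n), C(c), H(h), W(w), data(static_cast<size_t>(n) * c * h * w, 0.0f) {}
        float &at(int ni, int ci, int hi, int wi){
            return data[W * (H * (C * ni + ci) + hi) + wi];
        }
    };

    Because std::vector stores its elements contiguously (reachable through data.data()), the row-wise and plane-wise memcpy strategies of Concat2 and Concat3 carry over unchanged.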

  • Original article: https://blog.csdn.net/weixin_43863869/article/details/133897402