• c++笔记--基于C++实现tensor合并


    1--问题描述

            给定两个 NCHW 维度的 Blob,在 H 维度上进行拼接
            (N1, C1, H1, W1), (N2, C2, H2, W2) → (N3, C3, H1 + H2, W3)
            其中要求 N1 = N2 = N3、C1 = C2 = C3、W1 = W2 = W3,只有 H 维度可以不同

    2--实例代码

    #include <algorithm>
    #include <cstddef>
    #include <cstdlib>
    #include <cstring>
    #include <ctime>
    #include <iostream>
    #include <utility>
    6. // c++实现tensor合并,在H维度
    // Dense 4-D float tensor stored contiguously in NCHW order (W varies fastest).
    //
    // The tensor owns the buffer behind `data`: it is allocated in the constructor
    // and released in the destructor. Copies are deep and moves transfer the
    // buffer, so two tensors never share (or double-free) the same storage.
    struct Tensor {
        // Allocates n*c*h*w floats and fills them with pseudo-random integer
        // values in [0, 9] taken from rand().
        Tensor(int n, int c, int h, int w)
            : data(new float[n * c * h * w]), N(n), C(c), H(h), W(w), num(n * c * h * w) {
            for (int i = 0; i < num; ++i) {
                data[i] = static_cast<float>(std::rand() % 10);
            }
        }

        // Deep copy: the new tensor gets its own buffer with the same contents.
        Tensor(const Tensor &other)
            : data(new float[other.num]),
              N(other.N), C(other.C), H(other.H), W(other.W), num(other.num) {
            std::copy_n(other.data, num, data);
        }

        // Move: takes over the buffer and leaves `other` empty (null data, all sizes 0).
        Tensor(Tensor &&other) noexcept
            : data(std::exchange(other.data, nullptr)),
              N(std::exchange(other.N, 0)),
              C(std::exchange(other.C, 0)),
              H(std::exchange(other.H, 0)),
              W(std::exchange(other.W, 0)),
              num(std::exchange(other.num, 0)) {}

        // Copy-and-swap assignment; serves both copy and move assignment because
        // the argument is constructed by value from either an lvalue or an rvalue.
        Tensor &operator=(Tensor other) noexcept {
            swap(other);
            return *this;
        }

        ~Tensor() { delete[] data; }

        // Exchanges the buffer and dimensions of the two tensors.
        void swap(Tensor &other) noexcept {
            std::swap(data, other.data);
            std::swap(N, other.N);
            std::swap(C, other.C);
            std::swap(H, other.H);
            std::swap(W, other.W);
            std::swap(num, other.num);
        }

        float *data = nullptr;  // owning pointer to num floats
        int N = 0;
        int C = 0;
        int H = 0;
        int W = 0;
        int num = 0;            // total element count, N*C*H*W

        // Returns a reference to the element at (ni, ci, hi, wi).
        // No bounds checking is performed. The method is const but hands out a
        // mutable reference, because only the pointer (not the pointee) is const.
        float &at(int ni, int ci, int hi, int wi) const {
            // idx = wi + hi*W + ci*H*W + ni*C*H*W, written in nested (Horner) form.
            return data[W * (H * (C * ni + ci) + hi) + wi];
        }
    };
    32. // 给定两个NCHW维度的Blob,在H维度上进行拼接
    33. void Concat1(const Tensor *a, const Tensor *b, Tensor *c){
    34. for(int ni = 0; ni < a->N; ++ni){
    35. for(int ci = 0; ci < a->C; ++ci){
    36. for(int hi = 0; hi < a->H; ++hi){
    37. for(int wi = 0; wi < a->W; ++wi){
    38. c->at(ni, ci, hi, wi) = a->at(ni, ci, hi, wi);
    39. }
    40. }
    41. for(int hi = 0; hi < b->H; ++hi){
    42. for(int wi = 0; wi < b->W; ++wi){
    43. c->at(ni, ci, a->H+hi, wi) = b->at(ni, ci, hi, wi);
    44. }
    45. }
    46. }
    47. }
    48. }
    49. void Concat2(const Tensor *a, const Tensor *b, Tensor *c){
    50. for(int ni = 0; ni < a->N; ++ni){
    51. for(int ci = 0; ci < a->C; ++ci){
    52. for(int hi = 0; hi < a->H; ++hi){
    53. int offseta = (hi * a->W + ci * a->H * a->W + ni * a->C * a->H * a->W)*sizeof(float);
    54. int offsetc = (hi * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
    55. memcpy(c->data+offsetc, a->data+offseta, a->W*sizeof(float));
    56. }
    57. for(int hi = 0; hi < b->H; ++hi){
    58. int offsetc = ((hi + a->H) * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
    59. int offsetb = (hi * b->W + ci*b->H*b->W + ni * b->C * b->H * b->W)*sizeof(float);
    60. memcpy(c->data+offsetc, c->data+offsetb, b->W*sizeof(float));
    61. }
    62. }
    63. }
    64. }
    65. void Concat3(const Tensor *a, const Tensor *b, Tensor *c){
    66. for(int ni = 0; ni < a->N; ++ni){
    67. for(int ci = 0; ci < a->C; ++ci){
    68. int offseta = (ci * a->H * a->W + ni * a->C * a->H * a->W)*sizeof(float);
    69. int offsetc1 = (ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
    70. memcpy(c->data+offsetc1, a->data+offseta, a->W*a->H*sizeof(float));
    71. int offsetc2 = offsetc1 + a->W*a->H*sizeof(float);
    72. int offsetb = (ci * b->H * b->W + ni * b->C * b->H * b->W)*sizeof(float);
    73. memcpy(c->data+offsetc2, c->data+offsetb, b->W*b->H*sizeof(float));
    74. }
    75. }
    76. }
    77. int main(int argc, char argv[]){
    78. srand(time(nullptr));
    79. int N1 = 1, C1 = 1, H1 = 2, W1 = 2;
    80. int N2 = 1, C2 = 1, H2 = 2, W2 = 2;
    81. Tensor *a = new Tensor(N1, C1, H1, W1);
    82. Tensor *b = new Tensor(N2, C2, H2, W2);
    83. Tensor *c = new Tensor(N1, C1, H1+H2, W1);
    84. // Tensor a(N1, C1, H1, W1);
    85. // Tensor b(N2, C2, H2, W2);
    86. // Tensor c(N1, C1, H1+H2, W1);
    87. Concat1(a, b, c);
    88. for(int n = 0; n < N1; n++){
    89. for(int channel = 0; channel < C1; channel++){
    90. for(int h = 0; h < H1+H2; h++){
    91. for(int w = 0; w < W1; w++){
    92. std::cout << c->at(n, channel, h, w) << " ";
    93. }
    94. std::cout << std::endl;
    95. }
    96. }
    97. }
    98. std::cout << "-------------" << std::endl;
    99. Concat2(a, b, c);
    100. for(int n = 0; n < N1; n++){
    101. for(int channel = 0; channel < C1; channel++){
    102. for(int h = 0; h < H1+H2; h++){
    103. for(int w = 0; w < W1; w++){
    104. std::cout << c->at(n, channel, h, w) << " ";
    105. }
    106. std::cout << std::endl;
    107. }
    108. }
    109. }
    110. std::cout << "-------------" << std::endl;
    111. Concat3(a, b, c);
    112. for(int n = 0; n < N1; n++){
    113. for(int channel = 0; channel < C1; channel++){
    114. for(int h = 0; h < H1+H2; h++){
    115. for(int w = 0; w < W1; w++){
    116. std::cout << c->at(n, channel, h, w) << " ";
    117. }
    118. std::cout << std::endl;
    119. }
    120. }
    121. }
    122. return 0;
    123. }

  • 相关阅读:
    【Spring入门学习】
    DDoS渗透与攻防实战 (一) : 初识DDoS
    《凤凰架构》-全局事务章节读书笔记
    批量获取CSDN文章对文章质量分进行检测,有助于优化文章质量
    校园二手交易小程序,微信小程序二手交易系统毕设作品
    前端规范——前端代码提交篇(2)
    day1:Node.js 简介
    python转yuyv422到jpg
    【Java】JDK动态代理实现原理
    数据结构之队的实现
  • 原文地址:https://blog.csdn.net/weixin_43863869/article/details/133897402