• im2col代码解析


    前言

    数字图像处理专栏的很多博客里,当遇到sobel算子、均值滤波等算子时,我们使用的是传统的卷积方法(直接卷积),也就是将卷积核在输入图像上不断进行移动产生输出。直接计算时,由于输入图像矩阵存放在内存中地址有重叠且不连续的空间上,在计算时有可能需要多次访问内存。多次访问内存直接增加了数据传输时间,从而进一步影响了卷积计算速度。
    同样地在深度学习中,卷积层也需要对输入特征图进行卷积,如果还是使用直接卷积的方式,势必会影响模型训练以及推理的速度。因此,人们采用一些策略来加速卷积运算。

    im2col

    im2col算法的原理这里不再阐述,网上有很有优质的博客,大家自行阅读即可。这里我们主要关注im2col的具体C语言实现。
    总的来说,im2col需要把图像张量data_im转换为一个列表示矩阵data_col,期间不做任数值上的运算。如下图所示,左边是输入特征图data_im(三通道),右边为im2col后的矩阵data_col(单通道),这里为了简单,我们假设padding=0,stride=2,卷积核大小ksize=3。经过卷积后的输出特征图data_output(单通道)大小为2*2。
    在这里插入图片描述

    代码

    float im2col_get_pixel(float *im, int height, int width, int channels,
        int row, int col, int channel, int pad)
    {
        row -= pad;
        col -= pad;
    
        if (row < 0 || col < 0 ||
            row >= height || col >= width) return 0;  // 超过范围的直接返回pading的0
        return im[col + width * (row + height * channel)];  // 在内存中索引
    }
    
    //From Berkeley Vision's Caffe!
    //https://github.com/BVLC/caffe/blob/master/LICENSE
    void im2col_cpu(float* data_im,
        int channels, int height, int width,
        int ksize, int stride, int pad, float* data_col)
    {
        int c, h, w;
        // height_col和width_col本质为输出特征图的高和宽,它两相乘就为data_col的列数
        int height_col = (height + 2 * pad - ksize) / stride + 1;  // 2
        int width_col = (width + 2 * pad - ksize) / stride + 1;  // 2
    
        int channels_col = channels * ksize * ksize;  // 27
        for (c = 0; c < channels_col; ++c) {
            int w_offset = c % ksize;  // 计算data_col中每一行的初始坐标
            int h_offset = (c / ksize) % ksize;
            int c_im = c / ksize / ksize;  // 计算是data_im的第几个通道
            printf("(h_offset,w_offset)=(%d,%d)\n", h_offset, w_offset);
            for (h = 0; h < height_col; ++h) {
                for (w = 0; w < width_col; ++w) {
                    int im_row = h_offset + h * stride;  // 每一行的初始坐标加上步长
                    int im_col = w_offset + w * stride;
                    printf("(im_row,im_col)=(%d,%d)\n", im_row, im_col);
                    int col_index = (c * height_col + h) * width_col + w; // 计算在data_col中的索引
                    printf("col_index:%d\n", col_index);
                    data_col[col_index] = im2col_get_pixel(data_im, height, width, channels,
                        im_row, im_col, c_im, pad);  // 根据坐标和当前的通道取值
                }
            }
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41

    思路

    首先每一个通道im2col后的行数为卷积核的大小,这里为3*3=9;每一个通道im2col后的列数为data_output的元素个数,这里也就是4。然后我们根据对应关系给data_col填充数值。
    按行给data_col填充数值,可以从图中看到data_col左上角的元素5对应着data_im中的(0,0)点,对于data_col中的每一行,我们首先都要计算出对应data_im中的哪个坐标,也就是代码中如下这两行

    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    
    • 1
    • 2

    对于data_col中的第一行中剩下的三个元素(5,4,1)。我们要根据stride来计算它在data_im中对应的坐标,按行进行计算,如下

    int im_row = h_offset + h * stride;
    int im_col = w_offset + w * stride;
    
    // 假如求元素1在data_im中的坐标
    // im_row = 0 + 1 * 2 = 2
    // im_col = 0 + 1 * 2 = 2 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    以此类推,便可以计算出data_col中所有元素在data_im中的坐标,然后就能取出对应的值。
    因为在C语言中数组都是行主序存储的,所以我们还要计算data_col中的元素在内存中的索引index,代码如下

    int col_index = (c * height_col + h) * width_col + w;
    /*
    如果这里看不明白,可以将width_col带进去,就变为
    col_index = c * (height_col * width_col) + h * width_col + w
    其中height * width_col为每一行的元素个数(为固定值),c为第几行,h为这一行的第几段,width_col为一段有几个,w为这一段的第几个
    */
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在这里插入图片描述
    如上图所示,一行有2段,每段有2个。对于元素1来说,h=1,w=1,算出col_index = 3,刚好对应。

    最后将计算出的im_row, im_col, c_im 传入im2col_get_pixel函数索引像素值。

    结论

    总体来说,im2col的原理比较简单,但是在代码实现上还是有一定难度,最难的地方就在于如何将data_col中的坐标映射回data_im中的坐标,然后同时又因为是C语言编码,还要注意如何访问内存的问题。

    参考链接

    https://zhuanlan.zhihu.com/p/386052987
    https://blog.csdn.net/caicaiatnbu/article/details/100515321

  • 相关阅读:
    熊市里的大机构压力倍增,灰度、Tether、微策略等巨鲸会不会成为"巨雷"?
    Feign的超时时间如何设置,我研究了4种情况
    JS中计算时数据有误差解决方案
    BetaFlight飞控AOCODAF435V2MPU6500固件编译
    基于PHP+MySQL仓库管理系统的设计与实现
    【量化交易】 量化因子 风险类因子
    面试官:transient关键字修饰的变量当真不可序列化?我:烦请先生教我!
    使用 Linux 15 年后,我重新回到 Windows:感觉非常糟糕
    xftp打开时提示需要更新或使用新版本
    3.Python_创建型模式_抽象工厂模式
  • 原文地址:https://blog.csdn.net/qq_41596730/article/details/133014350