• 插帧中grid_sample函数详解


    从之前VSR到后来做MEMC,基本都要用到该函数,但是VSR后期后很多工作很多抛弃了warp操作,因此没有深入研究。但是MEMC是必须用的,否则就要用超级大的网络直接端到端的生成。认准原创https://blog.csdn.net/longshaonihaoa/article/details/125964061

    MEMC系列文章:
    运动估计运动补偿(Motion estimation and motion compensation,MEMC)入门总结
    深度学习MEMC插帧论文列表paper list
    光流估计中cost volume详解
    插帧中grid_sample函数详解

    1、grid_sample基本功能讲解

    官方讲解
    https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

    函数原型

    torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)
    
    • 1

    参数选择
    函数有两个输入项,三个可选参数项。
    input:输入,原始图像。维度[B,3,H,W]
    grid:映射表。维度[B,H,W,2],值归一化为[-1, 1]
    mode: 插值模式,可选双线性‘bilinear’,最近邻‘nearest’。
    padding_mode: 补边模式,可选反射‘reflection’,边缘‘border’,零‘zero’。
    align_corners: 对齐模式,是否选择对齐。

    函数功能:
    首先我们区分一下坐标的区别。比如一张图片,坐标是指某个位置,如(2,3)就是指定图像的第2行第3列那个位置。是说这个位置上的像素值。
    对应到grid上,他每个坐标处会有两个值,对应的是映射后的坐标。所以grid的最后一维是2,分别对应X,Y。这里XY的值归一化到了[-1,1],在应用是需注意,在函数内部实现中会映射到原始尺寸。下面例子中为了形象讲grid时用非归一化的值。(为啥要归一化,开始我觉得蛮多此一举,最近我看图形学也有类似的归一化,应该有一样的原理?)当对输入图像进行处理时,比如需要处理(2,3)这个坐标。那就查grid中坐标为(2,3)的值,假设为(3,3),那就把原图中(2,3)这个坐标上的值 赋给 输出(3,3)这个坐标。

    参数介绍:
    padding_mode:当grid的值超出了宽高界限,该怎么选择值。
    reflection: 用关于边界的对称点的值,直到坐标落在界内。
    border:用边界的值代替
    zeros:用0代替。

    align_corner: 双线性插值的固有参数,是否对其。
    这两个参数在下文代码中会更详细介绍。

    2、ATen代码实现

    完整的代码可参考官方实现

    基本逻辑如下:

    # 逐像素循环处理
    for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
        	...
        	// 对坐标进行处理,接下来会讲这个函数
        	scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
            scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
            if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
                // 双线性插值操作
                ... 
                }
            else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
                // 最近邻插值操作
                int64_t ix_nearest = static_cast(std::nearbyint(ix));
                int64_t iy_nearest = static_cast(std::nearbyint(iy));
                ...
                }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    其实代码中对函数理解最重要的就是grid_sampler_compute_source_index 函数,其代码可见官方地址

    从以下可以看出它调用了两个函数,一个是unnormalize,一个是计算坐标。

    scalar_t grid_sampler_compute_source_index(...) {
      coord = grid_sampler_unnormalize(coord, size, align_corners);
      coord = compute_coordinates(coord, size, padding_mode, align_corners);
      return coord;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5

    unnormalize 实现如下。根据align_corner的设置得到不同运算。当align_corner为True时,原来的[-1,1]映射为[0, size - 1]。False则将[-1, 1] to [-0.5, size - 0.5]。具体代码如下

    scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
      if (align_corners) {
        // unnormalize coord from [-1, 1] to [0, size - 1]
        return ((coord + 1.f) / 2) * (size - 1);
      } else {
        // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
        return ((coord + 1.f) * size - 1) / 2;
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    注意align_corner并非只有此处使用。
    计算坐标主要是说对padding_mode的处理。主要可以看以下这部分代码:

    scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
      ...
      scalar_t min = static_cast(twice_low) / 2;
      scalar_t span = static_cast(twice_high - twice_low) / 2;
      in = ::fabs(in - min);
      scalar_t extra = ::fmod(in, span);
      int flips = static_cast(::floor(in / span));
      if (flips % 2 == 0) {    // return略有修改,因为我觉得这样更清楚
        return min + extra;
      } else {
        return min + (span - extra);
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    3、CUDA实现

    cuda官方实现的核函数在这里
    感觉cuda比上面写的更清楚,区别在于没有循环。因为cuda核函数是对某一个位置进行操作的。

    4、注意点

    grid给定的事归一化的坐标值,而非偏移量。区别在于,坐标值直接通过unnormalize得到目标坐标。而偏移量需要加上当前坐标才能的到目标坐标。

  • 相关阅读:
    MCDF--lab03
    AI批量写文章伪原创:基于ChatGPT长文本模型,实现批量改写文章、批量回答问题(长期更新)
    信号处理之巴特沃斯滤波器的理解----2022/11/30
    如何在图片上添加水印?快把这些方法收好
    高德德图进去不显示地图或者刷新页面地图丢失解决方法
    Pytorch实现的LSTM、RNN模型结构
    gin 模版
    多线程bind二次封装
    C //例6.4 将一个二维数组行和列的元素互换,存到另一个二维数组中。
    WIN2012远程桌面授权过期
  • 原文地址:https://blog.csdn.net/longshaonihaoa/article/details/125964061