• 详解torch.nn.functional.grid_sample函数(通俗易懂):可实现对特征图的水平/垂直翻转


    一、函数介绍

    Pytorch中grid_sample函数的接口声明如下,具体网址可以点这里

    torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)

    为了简单起见,以下讨论都是基于如下参数进行实验及讲解的:

    torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘border’, align_corners=True)

    给定 维度为(N,C,Hin,Win) 的input,维度为(N,Hout,Wout,2) 的grid,则该函数output的维度为(N,C,Hout,Wout)

    • 其实,input(N,C,Hin,Win)可以理解为一批特征图。其中,N可以理解为批大小(batch size),C可以理解为特征图的通道数,Hin可以理解为特征图的高,Win可以理解为特征图的宽

    • gird(N,Hout,Wout,2)的作用在于提供一批用于在输入特征图上进行元素采样的位置坐标grid的元素值通常在[-1,1]之间(-1,-1) 表示取输入特征图左上角的元素,(1,1) 表示取输入特征图右下角的元素。

    • output(N,C,Hout,Wout)表示函数输出的一批特征图,其批大小依然为N,特征图的通道数依然为C,但特征图的高已经变成了Hout,宽变成了Wout,并且输出特征图中的元素值是从根据grid所提供的位置坐标在输入特征图中采样得到的

    因此,一般来说,我们 首先需要根据Hin和Win的大小,对输入特征图元素坐标位置进行规范化

    假设我们此时有一个1 × 2 × 5 × 9的特征图,即N=1,C=2,Hin=5,Win=9。如下:

    在这里插入图片描述

    那么对输入特征图根据其高(Hin=5)和宽(Win=9)进行元素坐标位置规范化如下:

    在这里插入图片描述

    如果我们想实现对 输入特征图 input(维度大小为1 × 2 × 5 × 9)的 水平翻转 ,则 grid (维度大小应该为1 × 5 × 9 × 2)应该设定为对上述 坐标位置规范化结果的水平翻转 形式,如下:【注意,这里的两个2表示的含义完全不同,input中的2表示的是通道数为2,而grid中的2表示的是坐标,众所周知,二维坐标是2个数。这里之所以举例比较巧合,就是想通过这里的解释,让大家深刻理解上述参数的具体含义。】

    在这里插入图片描述

    二、代码验证(输入特征图和输出特征图大小相同)

    根据代码运行结果可知,当grid设定为对输入特征图元素坐标位置规范化结果的水平翻转形式时,也就实现了对输入特征图的水平翻转

    import torch
    import torch.nn.functional as F
    
    input_data = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                            [-1,-2,-3,-4,-5,4,3,2,1],
                            [1,3,5,7,9,11,13,15,17],
                            [0,2,4,6,8,10,12,14,16],
                            [3,6,9,12,15,16,19,21,24]],
    
                           [[9,8,7,6,5,4,3,2,1],
                           [1,2,3,4,5,6,7,8,9],
                           [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                           [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                           [0,2,4,6,8,1,3,5,7]]]]).float()
    print(input_data.shape) # torch.Size([1, 2, 5, 9])
    
    grid = torch.tensor([[[[1,-1],
                          [0.75,-1],
                          [0.5,-1],
                          [0.25,-1],
                          [0,-1],
                          [-0.25,-1],
                          [-0.5,-1],
                          [-0.75,-1],
                          [-1,-1]],
    
                          [[1,-0.5],
                          [0.75,-0.5],
                          [0.5,-0.5],
                          [0.25,-0.5],
                          [0,-0.5],
                          [0.25,-0.5],
                          [0.5,-0.5],
                          [0.75,-0.5],
                          [1,-0.5]],
    
                          [[1,0],
                          [0.75,0],
                          [0.5,0],
                          [0.25,0],
                          [0,0],
                          [0.25,0],
                          [0.5,0],
                          [0.75,0],
                          [1,0]],
    
                          [[1,0.5],
                          [0.75,0.5],
                          [0.5,0.5],
                          [0.25,0.5],
                          [0,0.5],
                          [0.25,0.5],
                          [0.5,0.5],
                          [0.75,0.5],
                          [1,0.5]],
    
                          [[1,1],
                          [0.75,1],
                          [0.5,1],
                          [0.25,1],
                          [0,1],
                          [0.25,1],
                          [0.5,1],
                          [0.75,1],
                          [1,1]]]])
    print(grid.shape) # torch.Size([1, 5, 9, 2])
    
    output = F.grid_sample(input_data, grid, mode='bilinear', padding_mode='border', align_corners=True)
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69

    在这里插入图片描述

    在上述例子中,批大小为1。如过批大小为N,则grid应该为N个特征图提供相应的N种采样方式,比如对某些特征图进行水平翻转,对某些特征图进行上下翻转…,当然也可以对N个特征图提供N种相同的采样方式。注意:虽然可以对N个特征图提供N种相同的采样方式,但是对于每个特征图中的所有通道,采样方式都是一致的

    另外,在本例中,输入特征图和输出特征图大小相同。如果我们想输出和输入特征图不同大小的特征图,也是可以的,只需要对grid进行改变即可,参见第三部分。

    三、代码验证(输入特征图和输出特征图大小不同)

    假定,输入特征图与上述保持一致,即N=1,C=2,Hin=5,Win=9,如下:

    在这里插入图片描述

    然而我们 只想采样黄色区域的元素,则相应地, grid应该只选择对输入特征图元素坐标位置规范化结果的对应坐标位置,如下:

    在这里插入图片描述
    代码及结果如下:

    import torch
    import torch.nn.functional as F
    
    input_data = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                            [-1,-2,-3,-4,-5,4,3,2,1],
                            [1,3,5,7,9,11,13,15,17],
                            [0,2,4,6,8,10,12,14,16],
                            [3,6,9,12,15,16,19,21,24]],
    
                           [[9,8,7,6,5,4,3,2,1],
                           [1,2,3,4,5,6,7,8,9],
                           [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                           [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                           [0,2,4,6,8,1,3,5,7]]]]).float()
    print(input_data.shape) # torch.Size([1, 2, 5, 9])
    
    grid = torch.tensor([[[[-0.75,-1],
                           [-0.25,-1],
                           [0.25,-1],
                           [0.75,-1]],
    
                          [[-0.75,0],
                           [-0.25,0],
                           [0.25,0],
                           [0.75,0]],
    
                           [[-0.75,1],
                           [-0.25,1],
                           [0.25,1],
                           [0.75,1]]]])
    print(grid.shape) # torch.Size([1, 3, 4, 2])
    
    output = F.grid_sample(input_data, grid, mode='bilinear', padding_mode='border', align_corners=True)
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    在这里插入图片描述

    四、自动对输入特征图中的元素坐标位置进行规范化操作

    看到这里,相信大家已经基本知道 torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)函数是做什么的了。

    总之一句话,该函数可以根据grid中的坐标顺序对input进行重新采样,从而生成新的ouput

    在上述分析中,我们对输入特征图元素坐标位置的规范化是手动计算的,那么能不能让程序自动对输入特征图中的元素坐标位置进行规范化操作呢?

    当然是可以的,具体分析可以参考如下代码:首先,定义一个函数 generate_flip_grid(w, h)自动对输入特征图中的元素坐标位置进行规范化操作,然后,对规范化后的结果进行相关变化,比如水平翻转或上下翻转,即可实现对输入特征图的水平翻转或上下翻转。在下例中,我们对给定的一批输入特征图均执行了水平翻转的操作。

    # 参考链接:
    # https://cloud.tencent.com/developer/article/1781060
    # https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
    
    
    import torch
    import torch.nn.functional as F
    
    
    # w和h分别为torch.nn.functional.grid_sample函数中input参数的宽和高
    def generate_flip_grid(w, h):
        x_ = torch.arange(w).unsqueeze(0).expand(h, -1) # torch.Size([h, w])
        # expand(*size)函数可以实现对张量中单维度上数据的复制操作。
        # 其中,*size分别指定了每个维度上复制的倍数。
        # 对于不需要(或非单维度)进行复制的维度,对应位置上可以写上原始维度的大小或者直接写-1。
    
        # 单维度怎么理解呢?
        # 将张量中大小为1的维度称为单维度。例如,shape为[2,3]的张量就没有单维度,
        # shape为[1,3]的张量,其第0个维度上的大小为1,因此第0个维度为张量的单维度。
    
        # 例如,torch.arange(7)结果的shape为[7],没有单维度,因此需要先通过unsqueeze()进行维度增加,
        # 参数为0表示在第0个维度进行维度增加操作,即在张量最外层加一个中括号变成第一维。
    
        y_ = torch.arange(h).unsqueeze(1).expand(-1, w) # torch.Size([h, w])
        grid = torch.stack([x_, y_], dim=0).float() # torch.Size([2, h, w])
        # 将x_和y_沿维度0进行堆叠
    
        grid = grid.unsqueeze(0) # torch.Size([1,2, h, w])
        grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1 # 相当于对x轴坐标进行规范化操作 torch.Size([1, 2, h, w])
        grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1 # 相当于对y轴坐标进行规范化操作 torch.Size([1, 2, h, w])
        grid = grid.permute(0,2,3,1) # 交换维度 转换为 torch.nn.functional.grid_sample函数中grid规定的形式[1,h,w,2]
    
        return grid # torch.Size([1,h,w,2])
    
    
    
    
    # w和h分别为torch.nn.functional.grid_sample函数中input参数的宽和高
    w = 9
    h = 5
    N = 2 # 这里的N相当于batch size
    
    grid = generate_flip_grid(w,h) # 获取输入特征图中元素位置的规范化结果
    
    grid = grid.expand(N, -1, -1, -1).clone() # torch.Size([N, h, w, 2])
    # expand()函数并不会重新分配内存,返回的结果仅仅是原始张量上的一个视图,无法对原始张量进行修改。
    # 因此,如果expand之后直接在下面对grid张量进行元素改变,就会发生错误。
    # clone()函数为复制函数, 可以返回一个完全相同的张量,与原张量不共享内存,从而可以实现下面对张量的修改。
    
    grid[:, :, :, 0] = -grid[:, :, :, 0] # 对x轴坐标取反,相当于实现了水平/左右翻转
    # grid[:, :, :, 1] = -grid[:, :, :, 1] # 对y轴坐标取反,相当于实现了上下翻转
    
    input = torch.tensor([[[[1,2,3,4,5,-4,-3,-2,-1],
                            [-1,-2,-3,-4,-5,4,3,2,1],
                            [1,3,5,7,9,11,13,15,17],
                            [0,2,4,6,8,10,12,14,16],
                            [3,6,9,12,15,16,19,21,24]],
    
                           [[9,8,7,6,5,4,3,2,1],
                            [1,2,3,4,5,6,7,8,9],
                            [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                            [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                            [0,2,4,6,8,1,3,5,7]]],
    
                           [[[9,8,7,6,5,4,3,2,1],
                             [1,2,3,4,5,6,7,8,9],
                             [-9,-8,-7,-6,-5,-4,-3,-2,-1],
                             [-1,-2,-3,-4,-5,-6,-7,-8,-9],
                             [0,2,4,6,8,1,3,5,7]],
    
                            [[1,2,3,4,5,-4,-3,-2,-1],
                             [-1,-2,-3,-4,-5,4,3,2,1],
                             [1,3,5,7,9,11,13,15,17],
                             [0,2,4,6,8,10,12,14,16],
                             [3,6,9,12,15,16,19,21,24]]]]).float()
    # print(input.shape) # torch.Size([2, 2, 5, 9])
    
    output = F.grid_sample(input, grid, mode='bilinear', padding_mode='border', align_corners=True)
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79

    在这里插入图片描述

    五、关于参数

    torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zeros’, align_corners=None)

    通过上面的介绍,相信大家对input和grid以及函数的输出都已经了解得差不多了。

    这里,主要说一下其它的三个参数。

    • padding_mode表示当grid中的坐标位置超出边界时像素值的填充方式,如果为zeros,则表示一旦grid坐标超出边界,则用0去填充输出特征图的相应位置元素,如果为border,则表示利用输入特征图对应的边缘元素去填充输出特征图的相应位置元素。想了解更多选择,可以去官网进一步了解。笔者目前只研究了zeros和border两种情况。

    • mode表示插值方式,对于四维数据的话,大家一般选择bilinear即可。想了解更多选择,可以去官网进一步了解。这里说明一下什么时候会用到插值,如果grid中的某个坐标直接对应于输入特征图元素位置的规范化结果中的某个坐标,则直接把对应的值取过来就行。但如果grid中的某个坐标不能直接对应于输入特征图元素位置的规范化结果中的所有坐标,则需要根据不同的插值方式(比如bilinear)在输入特征图中进行插值。

    • 至于align_corners这个参数,一般和插值方式mode搭配使用,表示在插值时像素的对齐方式,有两种选择,分别是True和False。如果把一个像素点看做一个正方形的话,True表示角像素点位于对应正方形的中心。False表示角像素点位于对应正方形的角点坐标。笔者目前只研究明白了align_corner=True的含义。

    六、参考文献

    总之,这个函数,就是根据grid所提供的坐标在input中进行重新采样,然后生成新的output。用这个函数来实现特征图的水平翻转或垂直翻转,特别容易理解,因为直接把input的元素坐标都水平翻转或垂直翻转一下就行了。这种情况下,grid中的坐标不会超出界限,因此就不用考虑padding_mode参数。另外,这种情况下,grid中的坐标在input中均能找到映射,因此也不用考虑详细的插值情况,只需要注意将align_corners设为True,因为我们grid边缘的点的位置坐标在相应的轴上都是等距的,与align_corners为True一致(align_corners为False则不等距)。

    这篇博客被我断断续续写了三四天,如果大家觉得有所帮助的话,麻烦点个赞鼓励一下吧😭。大家有任何问题,欢迎评论区留言,我看到都会尽量回复的~

  • 相关阅读:
    华为---- RIP路由协议基本配置
    【软件】Ubuntu16.04安装repo全纪录,构建自己的repo仓库,最详细的步骤大全,以及踩坑大全
    vue项目中常用解决跨域的方法
    二维随机向量的数学期望E与协方差σ
    计算机视觉代码学习
    最新WooCommerce教程指南-如何搭建B2C外贸独立站
    RabbitMQ消息队列快速入门
    ISS点云内部形状特征描述子
    前端如何将自定义组件注册到全局
    【2023】Git版本控制-远程仓库详解
  • 原文地址:https://blog.csdn.net/qq_40968179/article/details/128093033