• PyTorch应用实战一:实现卷积操作


    实验环境

    python3.6 + pytorch1.8.0

    import torch
    print(torch.__version__)
    
    • 1
    • 2
    1.8.0
    
    • 1

    0.卷积定义

    卷积操作是指两个函数f和g之间的一种数学运算,它在信号处理、图像处理、机器学习等领域中广泛应用。在离散情况下,卷积操作可以表示为:

    ( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[m]g[n-m] (fg)[n]=m=f[m]g[nm]

    其中, f f f g g g是离散函数, ∗ * 表示卷积操作, n n n是离散的变量。卷积操作可以看作是将函数 g g g沿着 n n n轴翻转,然后平移,每次和函数 f f f相乘并求和,最后得到一个新的函数。这种操作可以实现信号的滤波、特征提取等功能,是数字信号处理中非常重要的基础操作。

    1.利用张量操作实现卷积

    1.1 unfold函数

    PyTorchunfold函数用于对张量进行展开操作。torch.unfold()可以理解为将一个高维的张量展开成一个二维矩阵的操作。即将原来的张量沿着指定的维度展开成一个二维矩阵,其中第一维对应原来张量的维度,第二维对应展开的位置。

    函数原型如下:

    torch.unfold(input, dimension, size, step)
    
    • 1

    参数说明:

    • input (Tensor) – 要展开的张量
    • dimension (int) – 沿着哪个维度展开
    • size (int) – 展开窗口的大小
    • step (int) – 两个相邻窗口之间的步长

    1.2 张量分片

    import torch
    
    • 1
    a = torch.arange(16).view(4, 4)
    a
    
    • 1
    • 2
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11],
            [12, 13, 14, 15]])
    
    • 1
    • 2
    • 3
    • 4
    b = a.unfold(0, 3, 1)
    b
    
    • 1
    • 2
    tensor([[[ 0,  4,  8],
             [ 1,  5,  9],
             [ 2,  6, 10],
             [ 3,  7, 11]],
    
            [[ 4,  8, 12],
             [ 5,  9, 13],
             [ 6, 10, 14],
             [ 7, 11, 15]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    b.shape
    
    • 1
    torch.Size([2, 4, 3])
    
    • 1
    c = b.unfold(1, 3, 1)
    c
    
    • 1
    • 2
    tensor([[[[ 0,  1,  2],
              [ 4,  5,  6],
              [ 8,  9, 10]],
    
             [[ 1,  2,  3],
              [ 5,  6,  7],
              [ 9, 10, 11]]],
    
    
            [[[ 4,  5,  6],
              [ 8,  9, 10],
              [12, 13, 14]],
    
             [[ 5,  6,  7],
              [ 9, 10, 11],
              [13, 14, 15]]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    c.shape
    
    • 1
    torch.Size([2, 2, 3, 3])
    
    • 1

    完整程序

    import torch
    a = torch.arange(16).view(4, 4)
    b = a.unfold(0, 3, 1)
    c = b.unfold(1, 3, 1)
    c.shape
    
    • 1
    • 2
    • 3
    • 4
    • 5
    torch.Size([2, 2, 3, 3])
    
    • 1

    这段代码定义了三个变量。假设我们将其分别命名为abc,则:

    • 变量a是一个4x4的张量,其中包含了0到15的整数值,它通过torch.arange(16).view(4, 4)两个函数调用来实现。
    • 变量b是通过对变量a进行折叠操作得到的一个张量,具体来说,它是将变量a沿着第0维(即行)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量b打印出来,会得到:
    tensor([[[ 0,  1,  2],
             [ 4,  5,  6],
             [ 8,  9, 10],
             [12, 13, 14]],
    
            [[ 1,  2,  3],
             [ 5,  6,  7],
             [ 9, 10, 11],
             [13, 14, 15]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    其中,第一个子张量的值为[[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]],第二个子张量的值为[[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]。注意,这个张量的形状为(2, 4, 3),即它包含2个子张量,每个子张量的形状为(4, 3)

    • 变量c是对变量b进行类似的操作得到的,但是它是在第1维(即列)上展开并取子张量。具体来说,它是将变量b沿着第1维(即列)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量c打印出来,会得到:
    tensor([[[[ 0,  1,  2],
              [ 4,  5,  6],
              [ 8,  9, 10]],
    
             [[ 1,  2,  3],
              [ 5,  6,  7],
              [ 9, 10, 11]]],
    
    
            [[[ 4,  5,  6],
              [ 8,  9, 10],
              [12, 13, 14]],
    
             [[ 5,  6,  7],
              [ 9, 10, 11],
              [13, 14, 15]]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    其中,第一个子张量的值为[[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]],第二个子张量的值为[[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]。注意,这个张量的形状为(2, 2, 3, 3),即它包含2个子张量,每个子张量的形状为(2, 3, 3)

    2.实现卷积操作

    2.1 编写卷积函数

    完整程序

    import torch
    def conv2d(x, weight, bias, stride, pad):
        n, c, h, w = x.shape
        d, c, k, j = weight.shape
        
        x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
        x_pad[:, :, pad:-pad, pad:-pad] = x
        
        x_pad = x_pad.unfold(2, k, stride)
        x_pad = x_pad.unfold(3, j, stride)
        
        out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
        out = out + bias.view(1, -1, 1, 1)
        return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    该函数实现了二维卷积操作。下面对函数进行详细分析:

    1. 输入参数:
    • x: 输入张量,维度为(batch_size,in_channels,input_height,input_width)。
    • weight: 卷积核张量,维度为(out_channels,in_channels,kernel_height,kernel_width)。
    • bias: 偏置项张量,维度为(out_channels,)。
    • stride: 卷积核移动的步长,可以是一个数或是一个长度为 2 的元组,分别表示水平方向和竖直方向的步长。
    • pad: 输入张量周围要填充的零的数量。
    1. 局部填充:
    • 在进行卷积操作之前,需要在输入张量的周围按照给定的 pad 进行填充,以避免卷积核在张量边缘处超出范围的情况发生。
    • 在函数中使用 x_pad 表示经过填充后的输入张量。
    • 具体实现:将输入张量 x 在第 2 和第 3 个维度(height 和 width 维度)上分别拆分成若干个形状为(kernel_height,kernel_width)的张量,每个张量之间的跳跃长度由 stride 决定,然后在第 2 和第 3 个维度上分别进行展开。这样每个展开后的张量就可以看作一个二维卷积核作用在 x 上的局部卷积结果,这些局部结果被按照第 2 和第 3 个维度重新拼接起来,得到新的张量 x_pad。
    1. 卷积操作:
    • 在新的张量 x_pad 上使用 einsum 函数对卷积核进行卷积操作。

    • einsum 的第一个参数表示操作的规则,其中 ndhw 表示最终输出的张量的维度为(batch_size,out_channels,output_height,output_width),nchw 和 dckj 表示两个输入张量 x_pad 和 weight 的维度,其中 c k j 分别表示 input_channels、kernel_height 和 kernel_width。

    • 最终得到的输出张量形状为(batch_size,out_channels,output_height,output_width),并在每个位置上加上偏置项 bias。

    • torch: PyTorch库

    • einsum: Einstein summation notation,爱因斯坦求和约定,一种张量求和的简便表示法。

    • 'nchwkj,dckj->ndhw': 爱因斯坦求和符号,左侧的张量为x_pad,右侧的张量为weight。在左侧张量中,n, c, h, w分别表示batch size、通道数、高度、宽度。在右侧张量中,d, c, k, j分别表示输出通道数、输入通道数、卷积核高度、卷积核宽度。这个式子的意义是将x_padweight执行卷积操作,并输出结果张量,其形状为(batch_size, output_channels, height, width)

    • x_pad: 输入的张量,形状为 (batch_size, input_channels, input_height, input_width)

    • weight: 卷积核张量,形状为 (output_channels, input_channels, kernel_height, kernel_width)

    1. 返回结果:
    • 返回卷积操作后得到的输出张量。

    2.2 对编写的卷积函数举例分析

    # 设置测试数据
    x = torch.randn(2, 3, 5, 5, requires_grad=True)
    weight = torch.randn(4, 3, 3, 3, requires_grad=True)
    bias = torch.randn(4, requires_grad = True)
    stride = 2
    pad = 2
    x, weight, bias
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    (tensor([[[[-0.4888,  1.0257,  0.0312, -0.9026, -0.9060],
               [ 0.2071, -0.4962, -0.1658,  1.0919,  0.3785],
               [-0.4654,  1.5442,  0.6005,  0.3594, -2.6207],
               [ 0.5830,  0.0533,  0.5719,  1.5413,  0.5949],
               [-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],
     
              [[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],
               [-0.2119,  0.3165, -2.2668, -0.8956,  1.0617],
               [-0.7809, -0.2120, -0.8592, -0.5057,  0.7954],
               [-2.8820, -0.6888,  0.4450, -0.3586, -0.9477],
               [ 0.6244,  0.4303,  1.4739,  0.2740,  1.6605]],
     
              [[-0.1501,  0.6234, -1.6086,  0.1693,  0.4932],
               [ 1.0611, -1.0938,  0.1695,  1.0193,  0.4263],
               [ 1.4681, -0.1552, -0.0667, -0.7293,  1.0816],
               [ 0.8972,  1.1683, -1.4757,  0.4421, -0.0355],
               [-2.1331,  1.4847,  0.1378, -1.6907, -0.1350]]],
     
     
             [[[-1.3853,  1.6396,  0.3436,  0.3841,  0.2355],
               [-0.2206, -0.5087, -1.6956,  1.3205,  0.7058],
               [ 0.0993,  0.3533, -0.2086,  0.2969,  0.2627],
               [ 0.3752,  0.0304,  1.2487,  1.3963, -0.0063],
               [-1.3758,  0.5088, -1.3849,  1.3050,  0.4150]],
     
              [[ 0.2824, -2.8634, -0.1016, -0.1627,  1.7081],
               [ 0.1406,  0.2220, -0.6005,  0.2997, -0.1846],
               [ 1.6700,  0.5787,  0.6561, -0.0236,  1.7743],
               [ 2.1429, -0.2838, -0.0527,  0.3504, -0.3444],
               [-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],
     
              [[-2.2152,  0.2104, -0.3302,  0.2036, -0.9443],
               [-0.6576, -0.4455,  0.5117, -2.0058, -1.3985],
               [-0.5688,  1.2338, -0.1832,  0.1760,  0.4506],
               [-0.6563,  0.4021, -1.6210,  0.5582, -0.9238],
               [-1.0506, -0.9638,  0.7453, -0.3535, -0.3536]]]], requires_grad=True),
     tensor([[[[ 0.3069,  0.2079, -0.2952],
               [ 1.7681,  1.1056, -1.0555],
               [ 1.5845,  0.8294,  0.6588]],
     
              [[ 0.2574,  0.5007,  0.2912],
               [-0.0210,  0.6593, -0.9691],
               [-0.2918,  0.5695, -1.1242]],
     
              [[ 0.7327, -0.3453,  0.7041],
               [-0.2236, -1.7762,  0.0190],
               [-1.0927, -2.9369,  0.1768]]],
     
     
             [[[-2.3830, -1.4807,  1.8573],
               [ 1.0097, -0.9640,  1.0361],
               [-0.5222, -1.0386, -0.4016]],
     
              [[ 0.5071,  1.1433, -0.1194],
               [-0.0133, -0.3878, -0.1853],
               [ 0.3456, -0.6502,  0.2221]],
     
              [[-1.7672, -0.0469, -0.5996],
               [-0.2080, -1.6209,  0.4120],
               [ 0.8404, -1.6748, -0.7170]]],
     
     
             [[[ 0.2850,  0.1691, -0.9228],
               [ 0.7234,  0.5582, -0.4327],
               [ 0.6563,  0.2941,  1.5549]],
     
              [[ 0.2642, -1.9061,  1.6212],
               [-0.5276, -0.5608,  0.3824],
               [ 0.4452, -2.5152,  0.4490]],
     
              [[-0.1276,  0.7784,  0.7998],
               [-0.3030, -0.9776,  0.9681],
               [ 1.0225,  0.8946, -0.8084]]],
     
     
             [[[-0.5087, -0.8345, -1.4763],
               [-0.4938,  1.1979, -0.1335],
               [ 0.5010,  0.2865,  0.0728]],
     
              [[-0.3177, -0.6937, -1.0327],
               [ 0.8147, -1.7101, -1.8257],
               [-0.1593, -1.3855, -0.0885]],
     
              [[-0.4687, -1.6307,  1.5791],
               [-1.3030,  0.2004, -0.7055],
               [ 0.0674, -0.8772,  0.1586]]]], requires_grad=True),
     tensor([ 1.5349, -0.5608,  0.5182,  0.3328], requires_grad=True))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    n, c, h, w = x.shape
    d, c, k, j = weight.shape
    
    • 1
    • 2
    n, c, h, w
    
    • 1
    (2, 3, 5, 5)
    
    • 1
    d, c, k, j
    
    • 1
    (4, 3, 3, 3)
    
    • 1
    # 补零
    x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
    x_pad
    
    • 1
    • 2
    • 3
    tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
    
             [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
    
             [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
    
    
            [[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
    
             [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]],
    
             [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    x_pad.shape
    
    • 1
    torch.Size([2, 3, 9, 9])
    
    • 1
    x_pad[:, :, pad:-pad, pad:-pad] = x
    x_pad
    
    • 1
    • 2
    tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.4888,  1.0257,  0.0312, -0.9026, -0.9060,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.2071, -0.4962, -0.1658,  1.0919,  0.3785,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.4654,  1.5442,  0.6005,  0.3594, -2.6207,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.5830,  0.0533,  0.5719,  1.5413,  0.5949,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]],
    
             [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.4706, -0.1108, -0.1563, -1.7946, -0.8533,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.2119,  0.3165, -2.2668, -0.8956,  1.0617,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.7809, -0.2120, -0.8592, -0.5057,  0.7954,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -2.8820, -0.6888,  0.4450, -0.3586, -0.9477,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.6244,  0.4303,  1.4739,  0.2740,  1.6605,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]],
    
             [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.1501,  0.6234, -1.6086,  0.1693,  0.4932,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  1.0611, -1.0938,  0.1695,  1.0193,  0.4263,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  1.4681, -0.1552, -0.0667, -0.7293,  1.0816,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.8972,  1.1683, -1.4757,  0.4421, -0.0355,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -2.1331,  1.4847,  0.1378, -1.6907, -0.1350,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]]],
    
    
            [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -1.3853,  1.6396,  0.3436,  0.3841,  0.2355,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.2206, -0.5087, -1.6956,  1.3205,  0.7058,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0993,  0.3533, -0.2086,  0.2969,  0.2627,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.3752,  0.0304,  1.2487,  1.3963, -0.0063,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -1.3758,  0.5088, -1.3849,  1.3050,  0.4150,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]],
    
             [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.2824, -2.8634, -0.1016, -0.1627,  1.7081,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.1406,  0.2220, -0.6005,  0.2997, -0.1846,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  1.6700,  0.5787,  0.6561, -0.0236,  1.7743,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  2.1429, -0.2838, -0.0527,  0.3504, -0.3444,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]],
    
             [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -2.2152,  0.2104, -0.3302,  0.2036, -0.9443,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.6576, -0.4455,  0.5117, -2.0058, -1.3985,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.5688,  1.2338, -0.1832,  0.1760,  0.4506,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -0.6563,  0.4021, -1.6210,  0.5582, -0.9238,
                0.0000,  0.0000],
              [ 0.0000,  0.0000, -1.0506, -0.9638,  0.7453, -0.3535, -0.3536,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000],
              [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                0.0000,  0.0000]]]], grad_fn=)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    # 卷积
    x_pad = x_pad.unfold(2, k, stride)
    x_pad.shape
    
    • 1
    • 2
    • 3
    torch.Size([2, 3, 4, 9, 3])
    
    • 1
    x_pad = x_pad.unfold(3, j, stride)
    x_pad.shape
    
    • 1
    • 2
    torch.Size([2, 3, 4, 4, 3, 3])
    
    • 1
    out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
    out.shape
    
    • 1
    • 2
    torch.Size([2, 4, 4, 4])
    
    • 1
    bias.view(1, -1, 1, 1).shape
    
    • 1
    torch.Size([1, 4, 1, 1])
    
    • 1
    # 偏置
    out = out + bias.view(1, -1, 1, 1)
    out
    
    • 1
    • 2
    • 3
    tensor([[[[ 0.6573, -0.3444,  1.5693, -0.1906],
              [ 2.5483,  5.1142, -2.3528, -3.6162],
              [ 2.9913, -6.1289,  6.8200,  0.9229],
              [ 0.4849,  0.1813,  3.2616,  1.5637]],
    
             [[-0.1524, -1.2003, -0.3415,  0.0318],
              [-1.7830,  2.5286, -1.6660,  3.1253],
              [ 1.3314, -8.2623, -5.0055,  5.7671],
              [-1.0563,  5.2751, -0.4214,  2.8473]],
    
             [[ 0.0908,  2.6704,  1.0336,  0.0481],
              [ 0.2077,  2.0459,  1.8095, -0.7039],
              [ 0.9519, -4.5551,  3.7108,  0.7446],
              [ 0.6689,  3.9448,  2.3968,  0.6958]],
    
             [[ 0.2318, -0.3356,  2.4320,  0.0480],
              [ 0.2101, -1.7177,  6.3956, -0.4108],
              [ 8.2352, -5.8456, 12.9459, -0.8763],
              [-2.3292, -1.5263,  2.1349,  0.3653]]],
    
    
            [[[-0.0869,  1.0713,  0.1655,  2.4414],
              [-1.3623, -3.0759,  0.2430,  2.3259],
              [-0.9281,  2.1402,  7.1618,  4.1895],
              [ 0.9273,  1.1176,  0.7792,  0.9265]],
    
             [[ 1.6465, -1.7187, -0.7251, -0.8871],
              [-1.6260, -0.8628, -1.0122,  3.2737],
              [ 0.5831,  2.1665, -0.5353, -2.0468],
              [-2.3738, -0.1232, -0.0771, -1.8642]],
    
             [[ 0.2819,  6.0978,  2.9618,  0.4676],
              [ 1.3592,  6.7231,  3.8100,  3.6118],
              [ 0.9885, -5.7760,  5.4375,  0.5480],
              [-0.5778,  1.4657, -2.8315,  0.1923]],
    
             [[-0.1444,  3.6788,  0.3721,  0.1150],
              [-1.4057,  0.1613, -2.5436,  1.3156],
              [-6.1195,  1.8325,  3.1565,  0.8296],
              [ 1.6766,  6.9403,  1.3986,  0.8758]]]], grad_fn=)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    2.3 验证编写卷积函数的正确性

    import torch.nn.functional as F
    x = torch.randn(2, 3, 5, 5, requires_grad=True)
    w = torch.randn(4, 3, 3, 3, requires_grad=True)
    b = torch.randn(4, requires_grad = True)
    stride = 2
    pad = 2
    torch_out = F.conv2d(x, w, b, stride, pad)
    my_out = conv2d(x, w, b, stride, pad)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    torch_out == my_out
    
    • 1
    tensor([[[[ True,  True,  True,  True],
              [ True, False, False,  True],
              [ True,  True, False,  True],
              [ True,  True, False,  True]],
    
             [[ True,  True,  True,  True],
              [ True, False,  True,  True],
              [False,  True,  True,  True],
              [ True,  True,  True,  True]],
    
             [[ True, False, False,  True],
              [False, False,  True,  True],
              [ True, False, False,  True],
              [ True,  True, False,  True]],
    
             [[ True,  True, False,  True],
              [False, False, False,  True],
              [ True, False, False, False],
              [ True,  True,  True,  True]]],
    
    
            [[[ True,  True,  True,  True],
              [False,  True, False,  True],
              [ True,  True, False,  True],
              [ True, False,  True,  True]],
    
             [[ True,  True, False,  True],
              [ True, False, False, False],
              [ True, False, False, False],
              [ True, False,  True,  True]],
    
             [[ True,  True, False,  True],
              [False, False, False,  True],
              [ True, False, False, False],
              [ True, False,  True,  True]],
    
             [[ True, False,  True,  True],
              [ True, False, False,  True],
              [False, False, False, False],
              [ True,  True, False,  True]]]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    torch.allclose(torch_out, my_out, atol=1e-5)
    
    • 1
    True
    
    • 1
    • torch.allclose是用于检查两个张量之间的数值是否相等的函数。

    • 在使用时,需要将第一个张量作为第一个参数传入(即torch_out),将第二个张量作为第二个参数传入(即my_out),并将允许的绝对误差(atol)作为第三个参数传入(默认值为1e-8)。

    • 函数将返回一个布尔值,表示两个张量是否具有相近的数值。如果返回True,则表示两个张量具有相近的数值,否则表示它们之间存在数值差异。

    grad_out = torch.randn(*torch_out.shape)
    grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
    my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
    
    • 1
    • 2
    • 3
    torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
    
    • 1
    True
    
    • 1
    grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
    my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
    
    • 1
    • 2
    torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
    
    • 1
    True
    
    • 1
    grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
    my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
    
    • 1
    • 2
    torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
    
    • 1
    True
    
    • 1

    全是True,表明编写的卷积函数在一定范围内与PyTorch内置的Conv2d函数结果相近,说明了实现的正确性

    附:系列文章

    序号文章目录直达链接
    1PyTorch应用实战一:实现卷积操作https://want595.blog.csdn.net/article/details/132575530
    2PyTorch应用实战二:实现卷积神经网络进行图像分类https://want595.blog.csdn.net/article/details/132575702
    3PyTorch应用实战三:构建神经网络https://want595.blog.csdn.net/article/details/132575758
    4PyTorch应用实战四:基于PyTorch构建复杂应用https://want595.blog.csdn.net/article/details/132625270
    5PyTorch应用实战五:实现二值化神经网络https://want595.blog.csdn.net/article/details/132625348
    6PyTorch应用实战六:利用LSTM实现文本情感分类https://want595.blog.csdn.net/article/details/132625382
  • 相关阅读:
    华为云ROMA Connect亮相Gartner®全球应用创新及商业解决方案峰会,助力企业应用集成和数字化转型
    iOS UWB——NI框架部分类
    Comparable接口与Comparator接口
    设计模式之迭代器模式
    Qt5开发从入门到精通——第二篇(控件篇)
    细胞膜杂化脂质体载紫杉醇/红细胞膜包被雷公藤甲素-红素仿生共载脂质体的研究制备
    Java 中那些绕不开的内置接口 -- Comparable 和 Comparator
    日期时间格式化 @JsonFormat与@DateTimeFormat
    【无标题】力扣报错:member access within null pointer of type ‘struct ListNode‘
    【操作教程】TSINGSEE青犀视频平台如何将旧数据库导入到新数据库?
  • 原文地址:https://blog.csdn.net/m0_68111267/article/details/132575530