• [2022-11-28]神经网络与深度学习 hw10 - LSTM和GRU


    hw10 - LSTM 和GRU相关习题

    task 1

    题目内容

    当使用公式 h t = h t − 1 + g ( x t , h t − 1 ; Θ ) h_t=h_{t-1}+g(x_t,h_{t-1};\Theta) ht=ht1+g(xt,ht1;Θ)作为循环神经网络的状态更新公式时,分析其可能存在梯度爆炸的原因并给出解决方法。

    题目分析

    梯度爆炸是深度的神经网络所会发生的普遍问题。当计算图变得极深时,由于变深的结构使模型丧失了学习到先前信息的能力,让优化变得极其困难。深层的计算图存在于之后介绍的循环网络中,因为循环网络要在很长时间序列的各个时刻重复应用相同操作来构建非常深的计算图,并且模型参数共享,这使问题更加凸显。
    我们将问题抽象,这个过程就是在计算图中出现一条反复与权重相乘(假设最简单的情况)的路径。那么在 t t t步之后就变成了权重的 t t t次方次计算。通过高中数学我们可知,
    W t = V e c ⋅ d i a g ( λ ) t ⋅ V e c − 1 W^t=Vec \cdot diag(\lambda)^t \cdot Vec^{-1} Wt=Vecdiag(λ)tVec1
    显然可见,结果关于 λ \lambda λ的值极度敏感,由此反向传播得到的结果也因此敏感,非常容易导致梯度爆炸(梯度消失也如此)。
    此时将公式代入,计算分析即可得到梯度爆炸及其产生原因 。
    解决方法见下面的题目解答。

    题目解答

    由循环神经网络更新公式 h t = h t − 1 + g ( x t , h t − 1 ; Θ ) h_t=h_{t-1}+g(x_t,h_{t-1};\Theta) ht=ht1+g(xt,ht1;Θ)和前面推导 W t W^t Wt次权重相乘可知:
    h t = h t − 1 + g ( x t , W T h t − 1 ; Θ ) t = h t − 1 + g ( x t , ( W t ) T h 0 ; Θ ) t h_t=h_{t-1}+g(x_t,W^Th_{t-1};\Theta)^t=h_{t-1}+g(x_t,(W^t)^Th_0;\Theta)^t ht=ht1+g(xt,WTht1;Θ)t=ht1+g(xt,(Wt)Th0;Θ)t
    W W W符合如下特征值分解:
    W = Q Λ Q T W=Q \Lambda Q^T W=QΛQT
    由此可将循环转化为
    h t = h t − 1 + g ( x t , Q T Λ t Q h 0 , Θ ) t h_t=h_{t-1}+g(x_t,Q^T\Lambda^tQh_0,\Theta)^t ht=ht1+g(xt,QTΛtQh0,Θ)t
    显然可见,当 t → ∞ t→∞ t时,幅值不到1的特征值衰减 → 0 →0 0,幅值大于1的特征值 → ∞ →∞
    由此结合第 t t t时刻权重计算有 ∏ t w t \prod_{t}w^t twt,显然反向求梯度的值也是非常大的。
    对于解决梯度爆炸的方法,其一是简单的使用梯度截断方式,暴力高效;另一个是增加门控装置,避免产生这类现象。

    题目总结

    本题考查的是RNN的梯度计算和公式分析,以及对于梯度爆炸的解决方法。

    task 2

    题目内容

    推导LSTM网络中参数的梯度,并分析其避免梯度消失的效果。

    题目分析

    在这里插入图片描述
    首先还是要清楚一下什么是LSTM:
    在这里插入图片描述
    这题只需要计算一下梯度,并且分析即可。

    题目解答

    f t = σ ( W f ∗ ( h t − 1 , x t ) + b f ) f_t=\sigma(W_f*(h_{t-1},x_t)+b_f) ft=σ(Wf(ht1,xt)+bf)
    i t = σ ( W i ∗ ( h t − 1 , x t ) + b i ) i_t=\sigma(W_i*(h_{t-1},x_t)+b_i) it=σ(Wi(ht1,xt)+bi)
    C t ~ = t a n h ( W c ∗ ( h t − 1 , x t ) + b c ) \tilde{C_t}=tanh(W_c*(h_{t-1},x_t)+b_c) Ct~=tanh(Wc(ht1,xt)+bc)
    C t = f t ∗ C t − 1 + i t ∗ C t ~ C_t=f_t*C_{t-1}+i_t*\tilde{C_t} Ct=ftCt1+itCt~
    O t = σ ( W o ∗ ( h t − 1 , x t ) + b o ) O_t = \sigma(W_o*(h_{t-1},x_t)+b_o) Ot=σ(Wo(ht1,xt)+bo)
    h t = O t ∗ t a n h ( C t ) h_t = O_t * tanh(C_t) ht=Ottanh(Ct)
    ∂ E k ∂ W = ∂ E k ∂ h k ∗ ∂ h k ∂ C k ∗ ⋯ ∗ ∂ C 2 ∂ C 1 ∗ ∂ C 1 ∂ W = ∂ E k ∂ h k ∗ ∂ h k ∂ C k ∗ ( ∏ t = 2 k ∂ C t ∂ C t − 1 ) ∗ ∂ C 1 ∂ W \frac{\partial E_k}{\partial W}=\frac{\partial E_k}{\partial h_k}* \frac{\partial h_k}{\partial C_k} * \cdots * \frac{\partial C_2}{\partial C_1} * \frac{\partial C_1}{\partial W} =\frac{\partial E_k}{\partial h_k}* \frac{\partial h_k}{\partial C_k} *( \prod_{t=2}^{k}\frac{\partial C_t}{\partial C_{t-1}})*\frac{\partial C_1}{\partial W} WEk=hkEkCkhkC1C2WC1=hkEkCkhk(t=2kCt1Ct)WC1
    展开并求导可得:
    ∂ C t ∂ C t − 1 = σ ′ ( W f ∗ ( h t − 1 , x t ) ) W f ∗ O t − 1 ⊗ t a n h ′ ( C t − 1 ) ∗ C t − 1 + f t + σ ′ ( W i ∗ ( h t − 1 , x t ) ) W i ∗ O t − 1 ⊗ t a n h ′ ( C t − 1 ) ∗ C t − 1 + σ ′ ( W c ∗ ( h t − 1 , x t ) ) ∗ O t − 1 ⊗ t a n h ′ ( C t − 1 ) ∗ i t \frac{\partial C_t}{\partial C_{t-1}}=\sigma'(W_f*(h_{t-1},x_t))W_f*O_{t-1} \otimes tanh'(C_{t-1})*C_{t-1}+f_t+ \sigma'(W_i*(h_{t-1},x_t))W_i*O_{t-1}\otimes tanh'(C_{t-1})*C_{t-1}+\sigma'(W_c*(h_{t-1},x_t))*O_{t-1}\otimes tanh'(C_{t-1})*i_t Ct1Ct=σ(Wf(ht1,xt))WfOt1tanh(Ct1)Ct1+ft+σ(Wi(ht1,xt))WiOt1tanh(Ct1)Ct1+σ(Wc(ht1,xt))Ot1tanh(Ct1)it
    由梯度求导公式可知,由于门控的存在且门控所在只有0、1两种值,同时三者为相加关系,梯度不会有过大的起伏,有效避免了梯度消失的问题。

    题目总结

    本题考查的是LSTM的求导方法。要求对于LSTM的结构有所记忆,同时考察了对于门控存在意义的理解。

    task 3

    题目内容

    什么时候应该用GRU? 什么时候用LSTM?

    题目分析

    在这里插入图片描述
    首先我们要知道LSTM和GRU之间的区别在哪里,然后分析其对应特点。
    在这里插入图片描述

    题目解答

    我们了解了 RNN、LSTM 和 GRU 单元之间的基本区别。 从两个层(即 LSTM 和 GRU)的工作来看,GRU 使用更少的训练参数,因此使用更少的内存并且比 LSTM 执行得更快,而 LSTM 在更大的数据集上更准确。 如果你处理大序列并且关注准确性,可以选择 LSTM,当你内存消耗较少并且想要更快的结果时使用 GRU。

    题目总结

    本题考查对于LSTM和GRU的综合掌握,通过分析其原理组成和计算过程,能够很好地确定两种结构的使用场景。

    task 4

    题目内容

    LSTMBP,并使用numpy实现

    题目分析

    对LSTM链式求导即可,写成代码形式。

    题目解答

    代码如下:

    import numpy as np
    import torch
     
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))
     
    class LSTMCell:
        def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh):
            self.weight_ih = weight_ih
            self.weight_hh = weight_hh
            self.bias_ih = bias_ih
            self.bias_hh = bias_hh
     
            self.dc_prev = None
            self.dh_prev = None
     
            self.weight_ih_grad_stack = []
            self.weight_hh_grad_stack = []
            self.bias_ih_grad_stack = []
            self.bias_hh_grad_stack = []
     
            self.x_stack = []
            self.dx_list = []
            self.dh_prev_stack = []
     
            self.h_prev_stack = []
            self.c_prev_stack = []
     
            self.h_next_stack = []
            self.c_next_stack = []
     
            self.input_gate_stack = []
            self.forget_gate_stack = []
            self.output_gate_stack = []
            self.cell_memory_stack = []
        def __call__(self, x, h_prev, c_prev):
            a_vector = np.dot(x, self.weight_ih.T) + np.dot(h_prev, self.weight_hh.T)
            a_vector += self.bias_ih + self.bias_hh
     
            h_size = np.shape(h_prev)[1]
            a_i = a_vector[:, h_size * 0:h_size * 1]
            a_f = a_vector[:, h_size * 1:h_size * 2]
            a_c = a_vector[:, h_size * 2:h_size * 3]
            a_o = a_vector[:, h_size * 3:]
     
            input_gate = sigmoid(a_i)
            forget_gate = sigmoid(a_f)
            cell_memory = np.tanh(a_c)
            output_gate = sigmoid(a_o)
     
            c_next = (forget_gate * c_prev) + (input_gate * cell_memory)
            h_next = output_gate * np.tanh(c_next)
     
            self.x_stack.append(x)
     
            self.h_prev_stack.append(h_prev)
            self.c_prev_stack.append(c_prev)
     
            self.c_next_stack.append(c_next)
            self.h_next_stack.append(h_next)
     
            self.input_gate_stack.append(input_gate)
            self.forget_gate_stack.append(forget_gate)
            self.output_gate_stack.append(output_gate)
            self.cell_memory_stack.append(cell_memory)
     
            self.dc_prev = np.zeros_like(c_next)
            self.dh_prev = np.zeros_like(h_next)
     
            return h_next, c_next
     
        def backward(self, dh_next):
            x_stack = self.x_stack.pop()
     
            h_prev = self.h_prev_stack.pop()
            c_prev = self.c_prev_stack.pop()
     
            c_next = self.c_next_stack.pop()
     
            input_gate = self.input_gate_stack.pop()
            forget_gate = self.forget_gate_stack.pop()
            output_gate = self.output_gate_stack.pop()
            cell_memory = self.cell_memory_stack.pop()
     
            dh = dh_next + self.dh_prev
     
            d_tanh_c = dh * output_gate * (1 - np.square(np.tanh(c_next)))
            dc = d_tanh_c + self.dc_prev
     
            dc_prev = dc * forget_gate
            self.dc_prev = dc_prev
     
            d_input_gate = dc * cell_memory
            d_forget_gate = dc * c_prev
            d_cell_memory = dc * input_gate
     
            d_output_gate = dh * np.tanh(c_next)
     
            d_ai = d_input_gate * input_gate * (1 - input_gate)
            d_af = d_forget_gate * forget_gate * (1 - forget_gate)
            d_ao = d_output_gate * output_gate * (1 - output_gate)
            d_ac = d_cell_memory * (1 - np.square(cell_memory))
     
            da = np.concatenate((d_ai, d_af, d_ac, d_ao), axis=1)
     
            dx = np.dot(da, self.weight_ih)
            dh_prev = np.dot(da, self.weight_hh)
            self.dh_prev = dh_prev
     
            self.dx_list.insert(0, dx)
            self.dh_prev_stack.append(dh_prev)
     
            self.weight_ih_grad_stack.append(np.dot(da.T, x_stack))
            self.weight_hh_grad_stack.append(np.dot(da.T, h_prev))
     
            db = np.sum(da, axis=0)
            self.bias_ih_grad_stack.append(db)
            self.bias_hh_grad_stack.append(db)
     
            return dh_prev
     
     
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)
     
    lstm_torch = torch.nn.LSTMCell(2, 3).double()
    lstm_numpy = LSTMCell(lstm_torch.weight_ih.data.numpy(),
                          lstm_torch.weight_hh.data.numpy(),
                          lstm_torch.bias_ih.data.numpy(),
                          lstm_torch.bias_hh.data.numpy())
     
    x_numpy = np.random.random((4, 2))
    x_torch = torch.tensor(x_numpy, requires_grad=True)
     
    h_numpy = np.random.random((4, 3))
    h_torch = torch.tensor(h_numpy, requires_grad=True)
     
    c_numpy = np.random.random((4, 3))
    c_torch = torch.tensor(c_numpy, requires_grad=True)
     
    dh_numpy = np.random.random((4, 3))
    dh_torch = torch.tensor(dh_numpy, requires_grad=True)
     
    h_numpy, c_numpy = lstm_numpy(x_numpy, h_numpy, c_numpy)
    h_torch, c_torch = lstm_torch(x_torch, (h_torch, c_torch))
    h_torch.backward(dh_torch)
     
    dh_numpy = lstm_numpy.backward(dh_numpy)
     
    print("h_numpy :\n", h_numpy)
    print("h_torch :\n", h_torch.data.numpy())
     
    print("---------------------------------")
    print("c_numpy :\n", c_numpy)
    print("c_torch :\n", c_torch.data.numpy())
     
    print("---------------------------------")
    print("dx_numpy :\n", np.sum(lstm_numpy.dx_list, axis=0))
    print("dx_torch :\n", x_torch.grad.data.numpy())
     
    print("---------------------------------")
    print("w_ih_grad_numpy :\n",
          np.sum(lstm_numpy.weight_ih_grad_stack, axis=0))
    print("w_ih_grad_torch :\n",
          lstm_torch.weight_ih.grad.data.numpy())
     
    print("---------------------------------")
    print("w_hh_grad_numpy :\n",
          np.sum(lstm_numpy.weight_hh_grad_stack, axis=0))
    print("w_hh_grad_torch :\n",
          lstm_torch.weight_hh.grad.data.numpy())
     
    print("---------------------------------")
    print("b_ih_grad_numpy :\n",
          np.sum(lstm_numpy.bias_ih_grad_stack, axis=0))
    print("b_ih_grad_torch :\n",
          lstm_torch.bias_ih.grad.data.numpy())
     
    print("---------------------------------")
    print("b_hh_grad_numpy :\n",
          np.sum(lstm_numpy.bias_hh_grad_stack, axis=0))
    print("b_hh_grad_torch :\n",
          lstm_torch.bias_hh.grad.data.numpy())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184

    输出结果如下:

    h_numpy :
     [[ 0.055856  0.234159  0.138457]
     [ 0.094461  0.245843  0.224411]
     [ 0.020396  0.086745  0.082545]
     [-0.003794  0.040677  0.063094]]
    h_torch :
     [[ 0.055856  0.234159  0.138457]
     [ 0.094461  0.245843  0.224411]
     [ 0.020396  0.086745  0.082545]
     [-0.003794  0.040677  0.063094]]
    ---------------------------------
    c_numpy :
     [[ 0.092093  0.384992  0.213364]
     [ 0.151362  0.424671  0.318313]
     [ 0.033245  0.141979  0.120822]
     [-0.0061    0.062946  0.094999]]
    c_torch :
     [[ 0.092093  0.384992  0.213364]
     [ 0.151362  0.424671  0.318313]
     [ 0.033245  0.141979  0.120822]
     [-0.0061    0.062946  0.094999]]
    ---------------------------------
    dx_numpy :
     [[-0.144016  0.029775]
     [-0.229789  0.140921]
     [-0.246041 -0.009354]
     [-0.088844  0.036652]]
    dx_torch :
     [[-0.144016  0.029775]
     [-0.229789  0.140921]
     [-0.246041 -0.009354]
     [-0.088844  0.036652]]
    ---------------------------------
    w_ih_grad_numpy :
     [[-0.056788 -0.036448]
     [ 0.018742  0.014428]
     [ 0.007827  0.024828]
     [ 0.07856   0.05437 ]
     [ 0.061267  0.045952]
     [ 0.083886  0.0655  ]
     [ 0.229755  0.156008]
     [ 0.345218  0.251984]
     [ 0.430385  0.376664]
     [ 0.014239  0.011767]
     [ 0.054866  0.044531]
     [ 0.04654   0.048565]]
    w_ih_grad_torch :
     [[-0.056788 -0.036448]
     [ 0.018742  0.014428]
     [ 0.007827  0.024828]
     [ 0.07856   0.05437 ]
     [ 0.061267  0.045952]
     [ 0.083886  0.0655  ]
     [ 0.229755  0.156008]
     [ 0.345218  0.251984]
     [ 0.430385  0.376664]
     [ 0.014239  0.011767]
     [ 0.054866  0.044531]
     [ 0.04654   0.048565]]
    ---------------------------------
    w_hh_grad_numpy :
     [[-0.037698 -0.048568 -0.021069]
     [ 0.016749  0.016277  0.007556]
     [ 0.035743  0.02156   0.000111]
     [ 0.060824  0.069505  0.029101]
     [ 0.060402  0.051634  0.025643]
     [ 0.068116  0.06966   0.035544]
     [ 0.168965  0.217076  0.075904]
     [ 0.248277  0.290927  0.138279]
     [ 0.384974  0.401949  0.167006]
     [ 0.015448  0.0139    0.005158]
     [ 0.057147  0.048975  0.022261]
     [ 0.057297  0.048308  0.017745]]
    w_hh_grad_torch :
     [[-0.037698 -0.048568 -0.021069]
     [ 0.016749  0.016277  0.007556]
     [ 0.035743  0.02156   0.000111]
     [ 0.060824  0.069505  0.029101]
     [ 0.060402  0.051634  0.025643]
     [ 0.068116  0.06966   0.035544]
     [ 0.168965  0.217076  0.075904]
     [ 0.248277  0.290927  0.138279]
     [ 0.384974  0.401949  0.167006]
     [ 0.015448  0.0139    0.005158]
     [ 0.057147  0.048975  0.022261]
     [ 0.057297  0.048308  0.017745]]
    ---------------------------------
    b_ih_grad_numpy :
     [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
      0.539519  0.761838  0.027649  0.103695  0.099405]
    b_ih_grad_torch :
     [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
      0.539519  0.761838  0.027649  0.103695  0.099405]
    ---------------------------------
    b_hh_grad_numpy :
     [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
      0.539519  0.761838  0.027649  0.103695  0.099405]
    b_hh_grad_torch :
     [-0.084682  0.032588  0.046412  0.126449  0.111421  0.139337  0.361956
      0.539519  0.761838  0.027649  0.103695  0.099405]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100

    问题总结

    本题考查的是对于LSTM的掌握,通过自己编写代码,我们能够最终对LSTM有一个较为清楚的认识。

  • 相关阅读:
    Operator SDK
    牛客网sql练习
    工作中一定要学会拒绝?
    Kruskal算法
    深入理解rtmp(一)之开发环境搭建
    某网吧网络布线规划设计
    VLAN的TRUNK协议(VTP)
    STM32F1定时器-PWM输出
    SpringSecurity整合JWT
    手把手教你实现buffer(二)——内存管理及移动语义
  • 原文地址:https://blog.csdn.net/LupnisJ/article/details/128100764