• 【深度学习理论】(7) 长短时记忆网络 LSTM


    大家好,今天和各位分享一下长短时记忆网络 LSTM 的原理,并使用 Pytorch 从公式上实现 LSTM 层

    上一节介绍了循环神经网络 RNN,感兴趣的可以看一下:https://blog.csdn.net/dgvv4/article/details/125424902

    我的这个专栏中有许多 LSTM 的实战案例,便于大家巩固知识:https://blog.csdn.net/dgvv4/category_11712004.html


    1. 引言

    循环神经网络的记忆功能在处理时间序列问题上存在很大优势,但随着训练的不断进行,RNN 网络一直在不断的扩充记忆,致使 RNN 产生梯度消失以及梯度爆炸

    为了解决RNN难以有效训练的问题,拥有选择记忆功能的 LSTM模型被提出。LSTM 是在 RNN 的基础上进行的改进,其既能学习数据中的长期依赖,又能解决梯度消失。LSTM 包含一个记忆单元和三个门结构,其中门结构分别是输入门、输出门和遗忘门。

    LSTM 的工作过程如下:

    首先由输入数据 X_t 前一时刻隐藏层的输出数据 h_t-1 共同作用于遗忘门,遗忘门对上述信息进行筛选,记忆时间序列中的重要特征信息,丢弃无关紧要的信息;然后将输入数据 x_t 以及前一时刻隐藏层的输出数据 h_t-1 作为输入门的输入信息,进行更新;其次记忆单元通过输入数据 X_t、前一时刻隐藏层的输出数据 h_t-1 以及前一时刻的记忆单元状态 C_t-1 对自身状态进行更新;最后将输入数据 X_t前一时刻隐藏层的输出数据 h_t-1 以及当前时刻的记忆单元状态 C_t 共同作用于输出门,输出当前时刻的隐藏层信息 h_t

    LSTM 的结构图如下:


    2. 原理解析

    2.1 遗忘门

    上一时刻的输出 h_t-1当前时刻的输入 X_t 结合,并通过 Sigmoid 函数计算得到一个阈值为 [0,1] 的张量 f_t,该 f_t 可以看作是对上一时刻的状态 C_t-1 的权重项,f_t 负责控制上一时刻状态需要被遗忘的程度

    计算公式:

    f_t = \sigma (W_f \cdot [h_t-1, x_t] + b_f)

    将公式展开,其中 W_if 是对当前时刻输入的特征提取,W_hf 是对前一时刻状态的特征提取,@ 代表矩阵相乘。

    f_t = \sigma ( W_{if}@X_{t} + b_{if} + W_{hf}@h_{t-1} + b_{hf} )


    2.2 输入门

    输入门是与 tanh 函数配合控制新信息加入的程度。在这个过程中,tanh 函数会给出一个新的候选向量 \tilde{C_t},输入门为 \tilde{C_t} 中的每一项产生一个在 [0,1] 之间的值 i_t,控制新信息被加入的多少。

    计算公式:

    i_t = \sigma (W_i \cdot [h_t-1, x_t] + b_i)

    \tilde{C}_t = tanh (W_c \cdot [h_t-1, x_t] + b_c)

    公式展开,其中 W_i 是对当前时刻输入的特征提取,W_h 是对前一时刻状态的特征提取,@ 代表矩阵相乘。

    i_t = \sigma ( W_{ii}@X_{t} + b_{ii} + W_{hi}@h_{t-1} + b_{hi} )

    g_t = \tilde{C_t} = tanh ( W_{ig}@X_{t} + b_{ig} + W_{hg}@h_{t-1} + b_{hg} )

    至此,模型已经计算了遗忘门的输出 f_t,和输入门的输出 i_t分别用来控制上一时刻的状态需要被遗忘的程度,和新增信息的规模,接下来可以根据这两个输出更新当前时刻的状态 C_t

    计算公式,其中 * 代表张量之间逐元素相乘。

    C_t = f_t * C_{t-1} + i_t * \tilde{C_t}


    2.3 输出门

    输出门用来过滤当前状态的某些信息,将其舍去。输出门的计算过程,输入数据 X_t前一时刻隐藏层的输出数据 h_t-1 经过 sigmoid 函数,把每一项的值压缩到 [0-1] 之间作为过滤信息的权重项。然后与更新后的当前状态 C_t 逐元素相乘,

    计算公式:

    o_t = \sigma (W_o \cdot [h_t-1, x_t] + b_o)

    h_t = o_t * tanh(C_t)

    公式展开:

    o_t = tanh ( W_{io}@X_{t} + b_{io} + W_{ho}@h_{t-1} + b_{ho} )


    3. 代码实现

    3.1 官方 API

    torch.nn.LSTM() 参数如下:

    1. lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False)
    2. '''
    3. input_size: 每个单词使用多少长的向量来表示
    4. hidden_size: 隐含层,经过LSTM层后每个单词用多长的向量来表示
    5. num_layers: LSTM的层数
    6. bias: 是否使用偏置项,默认为True,即 w@x+b
    7. batch_first: 对于输入是否将batch放在axis=0的位置,默认False,即[seq_len, batch, feature_len]
    8. '''

    实例化单层 LSTM,做一次前向传播,查看输出信息

    1. import torch
    2. from torch import nn
    3. # 定义参数
    4. batch = 3 # 现在有3个句子
    5. seq_len = 10 # 每个句子有10个单词
    6. feature_len = 100 # 每个单词用长度为100的向量来表示
    7. hidden_len = 20 # 经过LSTM层后每个单词用长度为20的向量来表示
    8. # 当前时刻的输入 [batch, seq_len, feature_len]
    9. inputs = torch.randn(batch, seq_len, feature_len)
    10. # 上一时刻的状态 [batch, hidden_len]
    11. h0 = torch.randn(batch, hidden_len)
    12. c0 = torch.randn(batch, hidden_len)
    13. # 实例化LSTM层
    14. lstm = nn.LSTM(input_size=feature_len, hidden_size=hidden_len,
    15. num_layers=1, batch_first=True)
    16. # c:最后一个单词更新的状态,[num_layer, batch, hidden_size]
    17. # h:最后一个单词的输出,[num_layer, batch, hidden_size]
    18. # out: 整体输出结果,[batch, seq_len, hidden_size]
    19. out, (h,c) = lstm(inputs)
    20. print('out:', out.shape, # [3, 10, 20]
    21. 'h:', h.shape, # [1, 3, 20]
    22. 'c:', c.shape) # [1, 3, 20]
    23. # 查看权重信息
    24. for k,v in lstm.named_parameters():
    25. print(k, v.shape)
    26. '''
    27. weight_ih_l0 torch.Size([80, 100])
    28. weight_hh_l0 torch.Size([80, 20])
    29. bias_ih_l0 torch.Size([80])
    30. bias_hh_l0 torch.Size([80])
    31. '''

    3.2 自定义函数

    接下来根据第二小节解释过的公式,从原理上实现一个 LSTM 层,主要就是6个公式的计算,还要注意张量的shape变化。

    代码实现如下:

    1. import torch
    2. from torch import nn
    3. '''
    4. inputs: 当前时刻的输入 [batch, seq_len, feature_len]
    5. c0: 上一时刻的状态,[batch, hidden_len]
    6. h0: 上一时刻的输出,[batch, hidden_len]
    7. w_ih, b_ih: 对当前时刻输入的特征矩阵和偏置
    8. w_hh, b_hh: 对上一时刻状态的特征矩阵和偏置
    9. w_ih.shape=[4*hdiien_size, feature_len]
    10. w_hh.shape=[4*hdiien_size, hidden_len]
    11. b.shape=[4*hidden_size]
    12. '''
    13. # ------------------------------------------------------------- #
    14. #(1)自定义LSTM模型
    15. # ------------------------------------------------------------- #
    16. def lstm_forward(inputs, initial_states, w_ih, w_hh, b_ih, b_hh):
    17. h0, c0 = initial_states # 获取初始状态
    18. # batch代表序列个数,seq_len代表每个序列有多少个样本,feature_len代表每个样本有多少个特征
    19. batch, seq_len, feature_len = inputs.shape # 获取输入的形状
    20. # 获取隐含层个数,根据公式由4个W拼接而成
    21. hidden_len = w_ih.shape[0] // 4 # weight_ih_l0 torch.Size([80, 100])
    22. # 初始化输出层 [batch, seq_len, hidden_len]
    23. outputs = torch.zeros(batch, seq_len, hidden_len)
    24. # 在LSTM中不断更新上一时刻的状态
    25. pre_h, pre_c = h0, c0
    26. # 扩充w的维度==>[b, 4*hdiien_size, feature_len]
    27. batch_w_ih = w_ih.unsqueeze(0).tile(batch, 1, 1)
    28. # ==>[b, 4*hdiien_size, hidden_len]
    29. batch_w_hh = w_hh.unsqueeze(0).tile(batch, 1, 1)
    30. # 遍历每个序列中的每个单词
    31. for t in range(seq_len):
    32. # 获取当前时刻的输入张量
    33. x = inputs[:, t, :] # [b, feature_len]
    34. # 三维矩阵相乘 [b, 4*hdiien_size, feature_len] @ [b, feature_len, 1]
    35. w_time_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)) # [b, 4*hidden_len, 1]
    36. w_time_x = w_time_x.squeeze(-1) # [b, 4*hidden_len]
    37. # 状态的矩阵相乘 [b, 4*hdiien_size, hidden_len] @ [b, hidden_len, 1]
    38. w_time_h_pre = torch.bmm(batch_w_hh, pre_h.unsqueeze(-1)) # [b, 4*hidden_size, 1]
    39. w_time_h_pre = w_time_h_pre.squeeze(-1) # [b, 4*hidden_size]
    40. # 取前1/4用作输入门(i)
    41. i_t = w_time_x[:, :hidden_len] + b_ih[:hidden_len] + w_time_h_pre[:, :hidden_len] + b_hh[:hidden_len]
    42. i_t = torch.sigmoid(i_t)
    43. # 遗忘门(f)
    44. f_t = w_time_x[:, hidden_len:hidden_len*2] + b_ih[hidden_len:hidden_len*2] + w_time_h_pre[:, hidden_len:hidden_len*2] + b_hh[hidden_len:hidden_len*2]
    45. f_t = torch.sigmoid(f_t)
    46. # 细胞门(g)
    47. g_t = w_time_x[:, hidden_len*2:hidden_len*3] + b_ih[hidden_len*2:hidden_len*3] + w_time_h_pre[:, hidden_len*2:hidden_len*3] + b_hh[hidden_len*2:hidden_len*3]
    48. g_t = torch.tanh(g_t)
    49. # 输出门(o)
    50. o_t = w_time_x[:, hidden_len*3:] + b_ih[hidden_len*3:] + w_time_h_pre[:, hidden_len*3:] + b_hh[hidden_len*3:]
    51. o_t = torch.tanh(o_t)
    52. # 状态(c)
    53. pre_c = f_t * pre_c + i_t * g_t
    54. # 当前时刻lstm的输出(h)
    55. pre_h = o_t * torch.tanh(pre_c)
    56. # 更新输出层
    57. outputs[:, t, :] = pre_h
    58. # 返回输出、最后一个时刻的输出h,状态c
    59. return outputs, (pre_h, pre_c)
    60. # ------------------------------------------------------------- #
    61. #(2)前向传播
    62. # ------------------------------------------------------------- #
    63. batch = 3 # 3个句子
    64. seq_len = 10 # 序列长度,每个句子有10个单词
    65. feature_len = 100 # 特征个数,一个单词用长度为100的向量来表示
    66. hidden_len = 20 # 隐含层,经过LSTM层后用长度为20的向量来表示
    67. # 构造输入层 [batch, seq_len, feature_len]
    68. inputs = torch.randn(batch, seq_len, feature_len)
    69. # 初始状态,不需要训练 [batch, hidden_len]
    70. h0 = torch.randn(batch, hidden_len)
    71. c0 = torch.randn(batch, hidden_len)
    72. # 构造权重
    73. w_ih = torch.randn(hidden_len*4, feature_len) # [80, 100]
    74. w_hh = torch.randn(hidden_len*4, hidden_len) # [80, 100]
    75. # 构造偏执
    76. b_ih = torch.randn(hidden_len*4) # [80]
    77. b_hh = torch.randn(hidden_len*4) # [80]
    78. # lstm层计算结果
    79. outputs, (final_h, final_c) = lstm_forward(inputs, (h0, c0), w_ih, w_hh, b_ih, b_hh)
    80. '''
    81. outputs: 所有句子的输出,[batch,seq_len, hidden_len]
    82. pre_h: 最后一次个单词的输出,[batch, hidden_len]
    83. pre_c: 最后一个单词的状态,[batch, hidden_len]
    84. '''
    85. print('outputs.shape:', outputs.shape, # [3, 10, 20]
    86. 'pre_h.shape:', final_h.shape, # [3, 20]
    87. 'pre_c.shape:', final_c.shape) # [3, 20]
  • 相关阅读:
    Jsp 学习笔记
    【设计模式】Java设计模式 - 状态模式
    数据分析 - matplotlib示例代码
    Kubecost - Kubernetes 开支监控和管理
    一文看懂 ZooKeeper ,面试再也不用背八股
    ffplay使用dxva2实现硬解渲染
    Linux网络环境配置:(内含:随机ip和固定ip设置方式)
    给电瓶车“消消火”——TSINGSEE青犀智能电瓶车棚监控方案
    软考2021高级架构师下午案例分析第4题:关于反规范化设计、数据不一致问题
    力扣题解8/10
  • 原文地址:https://blog.csdn.net/dgvv4/article/details/125456578