上篇文章提到了RNN循环神经网络在处理文本数据时具有优秀的表现,然后RNN面对长距离时存在梯度爆炸、梯度消失的问题。梯度爆炸通常我们可以采用梯度截断的方法来缓解,然而梯度消失问题就需要我们接下来说的RNN变体——LSTM来解决了。
LSTM( Long Short Term Memory)长短期记忆网络,随着句子长度的不断增长,LSTM发挥了比普通RNN更优良的性能。下面将介绍LSTM的原理及paddle框架实现的代码。
这是LSTM的内部原理图,初学者看上去可能会觉得很抽象,我们把它具体分为几个模块逐一解决。
其计算公式为:
将上一个LSTM单元的输出和当前输入乘以权重Wf加上偏置bf后经过一个sigmoid激活函数。这里可以理解为一个橡皮擦,把不需要记忆的内容过滤掉。遗忘门用于和中间状态C相乘,过滤掉不重要的内容。
其计算公式为:
将上一个LSTM单元的输出和当前输入乘以权重Wf加上偏置bf后经过一个sigmoid激活函数与上一个LSTM单元的输出和当前输入乘以权重Wf加上偏置bf后经过一个tanh激活函数相乘。这里可以理解为一个铅笔,在细胞状态Ct上写需要记忆的内容。tanh获得了暂时的细胞状态C,然后使用输入门控制保留信息的多少。
其计算公式为:
将遗忘门输出内容乘以上一层细胞状态加上输入门的内容。Ct也可以理解成一个日记本,记录时刻的信息。
其计算公式为:
这一层LSTM输出为上一个LSTM单元的输出和当前输入乘以权重Wf加上偏置bf后经过一个sigmoid激活函数与当前细胞状态Ct经过tanh激活函数后相乘。控制输出保留中间状态C的多少。
sigmoid激活函数用于门控机制上决定了是否通过,相当于0表示忘记,1表示记忆。
tanh激活函数用在了状态和输出上,是对数据的处理。
因为每个LSTM都有门控机制,中间记忆细胞通过输入门和遗忘门之后进行累加,而不像RNN那样的累乘,从而缓解了RNN的梯度消失或梯度爆炸。