• 循环神经网络(RNN)之门控循环单元(GRU)


            在实现门控循环单元的循环神经网络之前,可以先熟悉论文:Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling
    另外一篇关于RNN编码-解码的论文,大家有兴趣也可以看下,其中一位是来自“深度学习三巨头”之一的约书亚·本吉奥:
    Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
            我们在上一篇文章:RNN模型参数与变量的依赖以及时间反向传播梯度的推导,熟悉了RNN的一些相关知识,我们发现当时间步较大的时候容易出现梯度爆炸,这个我们用了裁剪梯度来应对,但对于时间步较小时,无法应对梯度出现的衰减问题。通常因为这样的原因,就较难捕捉到时间序列中时间步距离较大的依赖关系。

            于是就提出了带门控的循环单元的RNN,通过学习门来控制信息的流动。在一个门控循环单元中,包含着重置门和更新门,两者有什么特点和优势,如下:

    1、重置门(reset gate):当重置门中元素值接近1,就保留上一时间步的隐藏状态,如果元素值接近0时,将忽略上一步的隐藏状态,仅使用当前输入进行复位,有效地丢弃与预测无关的历史信息,从而允许更紧凑的表示。有助于捕捉时间序列里短期的依赖关系。
    2、更新门(update gate):控制从上一步的隐藏状态有多少信息传递到当前的隐藏状态。有助于捕捉时间序列里长期的依赖关系。
    除了上述两个控制门之外,还有一个候选隐藏状态(candidate hidden state),这个设计主要是应对梯度衰减了,因为它可以保存较早时刻的隐藏状态一直通过时间保存并传递到当前的时间步来。
    门控循环单元最终的计算输出是来自上一步的隐藏状态与更新门按元素乘法,再跟当前时间的候选隐藏状态做组合(相加)。

    对于这样一个门控循环单元,在流程中如何进行计算的,我个人还是比较喜欢使用画图来直观表现,如下:

     代码即解释,我们来实现它:

    1. import d2lzh as d2l
    2. from mxnet import nd
    3. from mxnet.gluon import rnn
    4. (corpus_indices,char_to_idx,idx_to_char,vocab_size)=d2l.load_data_jay_lyrics()
    5. num_inputs,num_hiddens,num_outputs=vocab_size,256,vocab_size
    6. ctx=d2l.try_gpu()
    7. #ctx=None
    8. def get_params():
    9. def _one(shape):
    10. return nd.random.normal(scale=0.01,shape=shape,ctx=ctx)
    11. def _three():
    12. return (_one((num_inputs,num_hiddens)),_one((num_hiddens,num_hiddens)),nd.zeros(num_hiddens,ctx=ctx))
    13. W_xz,W_hz,b_z=_three()#更新门参数
    14. W_xr,W_hr,b_r=_three()#重置门参数
    15. W_xh,W_hh,b_h=_three()#候选隐藏状态参数
    16. #输出层参数
    17. W_hq=_one((num_hiddens,num_outputs))
    18. b_q=nd.zeros(num_outputs,ctx=ctx)
    19. #附上梯度
    20. params=[W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q]
    21. for param in params:
    22. param.attach_grad()
    23. return params
    24. #定义模型
    25. #隐藏状态初始化函数
    26. def init_gru_state(batch_size,num_hiddens,ctx):
    27. return (nd.zeros(shape=(batch_size,num_hiddens),ctx=ctx),)
    28. def gru(inputs,state,params):
    29. W_xz,W_hz,b_z,W_xr,W_hr,b_r,W_xh,W_hh,b_h,W_hq,b_q=params
    30. H,=state
    31. outputs=[]
    32. for X in inputs:
    33. Z=nd.sigmoid(nd.dot(X,W_xz)+nd.dot(H,W_hz)+b_z)
    34. R=nd.sigmoid(nd.dot(X,W_xr)+nd.dot(H,W_hr)+b_r)
    35. H_tilda=nd.tanh(nd.dot(X,W_xh)+nd.dot(R*H,W_hh)+b_h)
    36. H=Z*H+(1-Z)*H_tilda
    37. Y=nd.dot(H,W_hq)+b_q
    38. outputs.append(Y)
    39. return outputs,(H,)
    40. #训练模型(相邻采样)
    41. num_epochs,num_steps,batch_size,lr,clipping_theta=200,35,32,1e2,1e-2
    42. pred_period,pred_len,prefixes=40,50,['分开','不分开']
    43. #d2l.train_and_predict_rnn(gru,get_params,init_gru_state,num_hiddens,vocab_size,ctx,corpus_indices,idx_to_char,char_to_idx,False,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)
    44. #简洁实现
    45. gru_layer=rnn.GRU(num_hiddens)
    46. model=d2l.RNNModel(gru_layer,vocab_size)
    47. d2l.train_and_predict_rnn_gluon(model,num_hiddens,vocab_size,ctx,corpus_indices,idx_to_char,char_to_idx,num_epochs,num_steps,lr,clipping_theta,batch_size,pred_period,pred_len,prefixes)

    epoch 40, perplexity 155.968500, time 0.12 sec
     - 分开 我不的让我 我不的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想
     - 不分开 我不的让我 我不的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想你的让我 我想
    epoch 80, perplexity 34.249728, time 0.12 sec
     - 分开 我想要你的微笑 一定在美不人 你的让我有多多 爱你在我不多 我爱你的爱笑 让我想这样 我不要再想你
     - 不分开 我想要你的微笑 一定在美不人 你的让我有多多 爱你在我不多 我爱你的爱笑 让我想这样 我不要再想你
    epoch 120, perplexity 5.254933, time 0.12 sec
     - 分开 我想带这样打 但知后觉 你已经离不舍 不知不觉 我跟了这节奏 后知后觉 我该好好生活 我该好好生活
     - 不分开 我已经这样奏 后知后觉 我该好好生活 我该好好生活 不知不觉 我跟了这节奏 后知后觉 我该好好生活
    epoch 160, perplexity 1.513565, time 0.11 sec
     - 分开 我想轻这里 我妈著好恼 我后 这样 我不要再想要你 不知不觉 你已经离开我 不知不觉 我跟了这节奏
     - 不分开 我已天这样奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活 不知不觉 你已经离开我
    epoch 200, perplexity 1.072794, time 0.12 sec
     - 分开 让弄堂的太快否听的见 它一定实现它一定实现 载著你 彷彿载著阳光 不管到哪里都是晴天 蝴蝶自在飞
     - 不分开 我已 这样的玩奏就像龙卷风 离不开暴风圈来不及逃 我不能再想 我不能再想 我不 我不 我不能 爱情

  • 相关阅读:
    Echarts可视化项目,Echarts-社区热销排行top效果,销售统计,销售统计sales 线形图,
    内置数据库H2和内置Redis(测试结果来啦)
    开发者分享 | Ascend C算子开发及单算子调用
    4.圆角边框
    Windows查看端口、结束任务进程(Nginx)
    Windows MongoDB详细安装与配置
    Cholesterol-PEG-Acid CLS-PEG-COOH 胆固醇-聚乙二醇-羧基修饰肽类化合物
    android 通过adb shell 命令获取最大打开文件数和当前打开文件数及串口使用率
    linux用vim编写1到100的求和
    测试十大法则
  • 原文地址:https://blog.csdn.net/weixin_41896770/article/details/126481231