码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • NNDL 作业8:RNN - 简单循环网络


    简单循环网络 ( Simple Recurrent Network , SRN) 只有一个隐藏层的神经网络 .

    目录

    1. 使用Numpy实现SRN

    2. 在1的基础上,增加激活函数tanh

    3. 分别使用nn.RNNCell、nn.RNN实现SRN

    4. 分析“二进制加法” 源代码(选做)

    5. 实现“Character-Level Language Models”源代码(必做)

    6. 分析“序列到序列”源代码(选做)

    7. “编码器-解码器”的简单实现(必做)


    1. 使用Numpy实现SRN

    1. import numpy as np
    2. inputs = np.array([[1., 1.],
    3. [1., 1.],
    4. [2., 2.]]) # 初始化输入序列
    5. print('inputs is ', inputs)
    6. state_t = np.zeros(2, ) # 初始化存储器
    7. print('state_t is ', state_t)
    8. w1, w2, w3, w4, w5, w6, w7, w8 = 1., 1., 1., 1., 1., 1., 1., 1.
    9. U1, U2, U3, U4 = 1., 1., 1., 1.
    10. print('--------------------------------------')
    11. for input_t in inputs:
    12. print('inputs is ', input_t)
    13. print('state_t is ', state_t)
    14. in_h1 = np.dot([w1, w3], input_t) + np.dot([U2, U4], state_t)
    15. in_h2 = np.dot([w2, w4], input_t) + np.dot([U1, U3], state_t)
    16. state_t = in_h1, in_h2
    17. output_y1 = np.dot([w5, w7], [in_h1, in_h2])
    18. output_y2 = np.dot([w6, w8], [in_h1, in_h2])
    19. print('output_y is ', output_y1, output_y2)
    20. print('---------------')

    2. 在1的基础上,增加激活函数tanh

    1. import numpy as np
    2. inputs = np.array([[1., 1.],
    3. [1., 1.],
    4. [2., 2.]]) # 初始化输入序列
    5. print('inputs is ', inputs)
    6. state_t = np.zeros(2, ) # 初始化存储器
    7. print('state_t is ', state_t)
    8. w1, w2, w3, w4, w5, w6, w7, w8 = 1., 1., 1., 1., 1., 1., 1., 1.
    9. U1, U2, U3, U4 = 1., 1., 1., 1.
    10. print('--------------------------------------')
    11. for input_t in inputs:
    12. print('inputs is ', input_t)
    13. print('state_t is ', state_t)
    14. in_h1 = np.tanh(np.dot([w1, w3], input_t) + np.dot([U2, U4], state_t))
    15. in_h2 = np.tanh(np.dot([w2, w4], input_t) + np.dot([U1, U3], state_t))
    16. state_t = in_h1, in_h2
    17. output_y1 = np.dot([w5, w7], [in_h1, in_h2])
    18. output_y2 = np.dot([w6, w8], [in_h1, in_h2])
    19. print('output_y is ', output_y1, output_y2)
    20. print('---------------')

    3. 分别使用nn.RNNCell、nn.RNN实现SRN

    1. import torch
    2. batch_size = 1
    3. seq_len = 3 # 序列长度
    4. input_size = 2 # 输入序列维度
    5. hidden_size = 2 # 隐藏层维度
    6. output_size = 2 # 输出层维度
    7. # RNNCell
    8. cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)
    9. # 初始化参数 https://zhuanlan.zhihu.com/p/342012463
    10. for name, param in cell.named_parameters():
    11. if name.startswith("weight"):
    12. torch.nn.init.ones_(param)
    13. else:
    14. torch.nn.init.zeros_(param)
    15. # 线性层
    16. liner = torch.nn.Linear(hidden_size, output_size)
    17. liner.weight.data = torch.Tensor([[1, 1], [1, 1]])
    18. liner.bias.data = torch.Tensor([0.0])
    19. seq = torch.Tensor([[[1, 1]],
    20. [[1, 1]],
    21. [[2, 2]]])
    22. hidden = torch.zeros(batch_size, hidden_size)
    23. output = torch.zeros(batch_size, output_size)
    24. for idx, input in enumerate(seq):
    25. print('=' * 20, idx, '=' * 20)
    26. print('Input :', input)
    27. print('hidden :', hidden)
    28. hidden = cell(input, hidden)
    29. output = liner(hidden)
    30. print('output :', output)

    1. import torch
    2. batch_size = 1
    3. seq_len = 3
    4. input_size = 2
    5. hidden_size = 2
    6. num_layers = 1
    7. output_size = 2
    8. cell = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
    9. for name, param in cell.named_parameters(): # 初始化参数
    10. if name.startswith("weight"):
    11. torch.nn.init.ones_(param)
    12. else:
    13. torch.nn.init.zeros_(param)
    14. # 线性层
    15. liner = torch.nn.Linear(hidden_size, output_size)
    16. liner.weight.data = torch.Tensor([[1, 1], [1, 1]])
    17. liner.bias.data = torch.Tensor([0.0])
    18. inputs = torch.Tensor([[[1, 1]],
    19. [[1, 1]],
    20. [[2, 2]]])
    21. hidden = torch.zeros(num_layers, batch_size, hidden_size)
    22. out, hidden = cell(inputs, hidden)
    23. print('Input :', inputs[0])
    24. print('hidden:', 0, 0)
    25. print('Output:', liner(out[0]))
    26. print('--------------------------------------')
    27. print('Input :', inputs[1])
    28. print('hidden:', out[0])
    29. print('Output:', liner(out[1]))
    30. print('--------------------------------------')
    31. print('Input :', inputs[2])
    32. print('hidden:', out[1])
    33. print('Output:', liner(out[2]))

    4. 分析“二进制加法” 源代码(选做)

    Anyone Can Learn To Code an LSTM-RNN in Python (Part 1: RNN) - i am trask

    5. 实现“Character-Level Language Models”源代码(必做)

    翻译Character-Level Language Models 相关内容

    The Unreasonable Effectiveness of Recurrent Neural Networks

    编码实现该模型 

    6. 分析“序列到序列”源代码(选做)

     

    7. “编码器-解码器”的简单实现(必做)

     

    seq2seq的PyTorch实现_哔哩哔哩_bilibili

    Seq2Seq的PyTorch实现 - mathor

    REF:

    Hung-yi Lee (ntu.edu.tw)

    《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

    完全图解RNN、RNN变体、Seq2Seq、Attention机制 - 知乎 (zhihu.com)
  • 相关阅读:
    面试突击39:synchronized底层是如何实现的?
    语法基础(变量、输入输出、表达式与顺序语句)
    input修改checkbox复选框默认选中样式
    Java面试被问了几个简单的问题,却回答的不是很好
    年前端react面试打怪升级之路
    行业追踪,2023-09-15
    [RoarCTF 2019]Simple Upload
    js颜色调试器
    LeetCode 双周赛 99,纯纯送分场!
    DSPE-PEG-DBCO,DBCO-PEG-DSPE,磷脂-聚乙二醇-二苯并环辛炔科研实验用
  • 原文地址:https://blog.csdn.net/qq_38975453/article/details/127561213
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号