• 03- LSTM 的从零开始实现


    一 长短期记忆算法 简介

    LSTM(Long Short-Term Memory)算法是一种常见的循环神经网络RNN)算法,用于处理序列数据,并且在处理时间序列数据时效果非常好。

    LSTM算法的主要思路是在RNN的基础上增加一个记忆单元,该记忆单元可以帮助网络记忆过去的状态,并在需要时更新它们。同时,LSTM还通过三个门控制信息的流动:遗忘门、输入门和输出门。这些门允许网络选择何时忘记旧的状态,何时接受新的输入,并输出当前状态。

    二 算法实现

    2.1 导包

    1. # 从零实现
    2. import torch
    3. from torch import nn
    4. import dltools

    2.2 导入训练数据

    1. batch_size, num_steps = 32, 35
    2. train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps=num_steps)

    2.3 初始化模型参数

    1. # 初始化模型参数
    2. def get_lstm_params(vocab_size, num_hiddens, device):
    3. num_inputs = num_outputs = vocab_size
    4. def normal(shape):
    5. return torch.randn(size=shape, device=device) * 0.01
    6. def three():
    7. return (normal((num_inputs, num_hiddens)),
    8. normal((num_hiddens, num_hiddens)),
    9. torch.zeros(num_hiddens, device=device))
    10. W_xi, W_hi, b_i = three() # 输入参数
    11. W_xf, W_hf, b_f = three() # 遗忘门参数
    12. W_xo, W_ho, b_o = three() # 输出门参数
    13. W_xc, W_hc, b_c = three() # 候选记忆元参数
    14. # 输出层
    15. W_hq = normal((num_hiddens, num_outputs))
    16. b_q = torch.zeros(num_outputs, device=device)
    17. # 附加梯度
    18. params = [W_xi, W_hi, b_i,
    19. W_xf, W_hf, b_f,
    20. W_xo, W_ho, b_o,
    21. W_xc, W_hc, b_c, W_hq, b_q]
    22. for param in params:
    23. param.requires_grad_(True)
    24. return params

    初始化LSTM模型的参数。首先,根据词汇表的大小和隐藏单元的数量,计算出输入和输出的维度。接着,定义了一个normal函数,用于生成服从标准正态分布的随机数,并将随机数乘以0.01,以控制参数的初始值范围。

    接下来,定义了一个 内部函数three,用于生成三个参数。每个参数都是一个元组,包含了输入和隐藏单元之间的权重矩阵和偏置向量。这里使用了初始化方式normal来生成权重矩阵,偏置向量初始化为全零

    然后,调用内部函数three分别 初始化输入、遗忘门、输出门和候选记忆元的参数

    接着,使用normal函数初始化了输出层的参数,包括隐藏单元和输出的权重矩阵和偏置向量。

    最后,将所有的参数放入一个列表中,并设置requires_grad为True,表示需要计算参数的梯度。

    返回参数列表。

    2.4 初始化隐藏状态

    1. # 初始化隐藏状态和记忆元
    2. def init_lstm_state(batch_size, num_hiddens, device):
    3. return (torch.zeros((batch_size, num_hiddens), device=device),
    4. torch.zeros((batch_size, num_hiddens), device=device))

    初始化长短期记忆网络(LSTM)的隐藏状态和记忆元。输入参数包括批量大小(batch_size)、隐藏单元数量(num_hiddens)和计算设备(device)。

    函数返回一个元组,其中包含两个张量,分别代表隐藏状态和记忆元。这些张量的大小为(batch_size, num_hiddens),并且初始化为全零。

    2.5 定义 LSTM 结构

    1. # 定义 LSTM 主体结构
    2. def lstm(inputs, state, params):
    3. [W_xi, W_hi, b_i,
    4. W_xf, W_hf, b_f,
    5. W_xo, W_ho, b_o,
    6. W_xc, W_hc, b_c, W_hq, b_q] = params
    7. (H, C) = state
    8. outputs = []
    9. # 准备开始进行前向传播
    10. for X in inputs:
    11. I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
    12. F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
    13. O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
    14. C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
    15. C = F * C + I * C_tilda
    16. H = O * torch.tanh(C)
    17. Y = (H @ W_hq) + b_q
    18. outputs.append(Y)
    19. return torch.cat(outputs, dim=0), (H, C)

    三 模型训练

    3.1 测试代码

    1. # 测试代码
    2. X = torch.arange(10).reshape((2, 5))
    3. num_hiddens = 512
    4. net = dltools.RNNModelScratch(len(vocab), num_hiddens, dltools.try_gpu(), get_lstm_params, init_lstm_state, lstm)
    5. state = net.begin_state(X.shape[0], dltools.try_gpu())
    6. Y, new_state = net(X.to(dltools.try_gpu()), state)

    3.2 在线训练

    1. # 训练和预测
    2. vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
    3. num_epochs, lr = 500, 1
    4. model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params, init_lstm_state, lstm)
    5. dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

    四 pytorch 实现

    1. num_inputs = vocab_size
    2. lstm_layer = nn.LSTM(num_inputs, num_hiddens)
    3. model = dltools.RNNModel(lstm_layer, len(vocab))
    4. model = model.to(device)
    5. dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

  • 相关阅读:
    面试宝典-【redis】
    实训十八:RIP2邻居认证
    C++ 学习 ::【基础篇:07】:C++ C11 标准中 关键字 auto 的基本介绍与使用
    M1 芯片 MacBook 结合 MAMP 集成环境配置 PHP 环境变量
    MySQL 表的约束
    JAVA构造方法(与类名相同的方法),类方法与实例方法
    数据结构之Trie树
    Tomcat架构设计&源码剖析
    期货自动止损止盈 易盛极星
    php-java-net-python-报修修改计算机毕业设计程序
  • 原文地址:https://blog.csdn.net/March_A/article/details/132841428