• Pytorch搭建循环神经网络RNN(简单实战)


    Pytorch搭建循环神经网络RNN(简单实战)

    去年写了篇《循环神经网络》,里面主要介绍了循环神经网络的结构与Tensorflow实现。而本篇博客主要介绍基于Pytorch搭建RNN。

    通过Sin预测Cos

    import torch
    import torch.nn as nn
    import numpy as np
    from matplotlib import pyplot as plt
    
    • 1
    • 2
    • 3
    • 4

    首先,我们定义一些超参数

    TIME_STEP = 10  # rnn 时序步长数
    INPUT_SIZE = 1  # rnn 的输入维度
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    H_SIZE = 64  # of rnn 隐藏单元个数
    EPOCHS = 100  # 总共训练次数
    h_state = None  # 隐藏层状态
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    使用Numpy生成Sin和Cos函数

    steps = np.linspace(0, np.pi*2, 256, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    
    • 1
    • 2
    • 3

    可视化数据

    plt.figure(1)
    plt.suptitle('Sin and Cos', fontsize='18')
    plt.plot(steps, y_np, 'r-', label='target (cos)')
    plt.plot(steps, x_np, 'b-', label='input (sin)')
    plt.legend(loc='best')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    定义网络结构

    class RNN(nn.Module):
        def __init__(self):
            super(RNN, self).__init__()
            self.rnn = nn.RNN(
                input_size=INPUT_SIZE,
                hidden_size=H_SIZE,
                num_layers=1,
                batch_first=True,
            )
            self.out = nn.Linear(H_SIZE, 1)
    
        def forward(self, x, h_state):
            r_out, h_state = self.rnn(x, h_state)
            outs = []  # 保存所有的预测值
            for time_step in range(r_out.size(1)):  # 计算每一步长的预测值
                outs.append(self.out(r_out[:, time_step, :]))
            return torch.stack(outs, dim=1), h_state
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    rnn = RNN().to(DEVICE)
    optimizer = torch.optim.Adam(rnn.parameters())  # Adam优化,几乎不用调参
    criterion = nn.MSELoss()  # 因为最终的结果是一个数值,所以损失函数用均方误差
    
    rnn.train()
    plt.figure(2)
    for step in range(EPOCHS):
        start, end = step * np.pi, (step+1)*np.pi  # 一个时间周期
        steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
        x_np = np.sin(steps)
        y_np = np.cos(steps)
        x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])  # shape (batch, time_step, input_size)
        y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
        x = x.to(DEVICE)
        prediction, h_state = rnn(x, h_state) # rnn output
        # 这一步非常重要
        h_state = h_state.data  # 重置隐藏层的状态, 切断和前一次迭代的链接
        loss = criterion(prediction.cpu(), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step+1) % 20 == 0:  # 每训练20个批次可视化一下效果,并打印一下loss
            print("EPOCHS: {},Loss:{:4f}".format(step, loss))
            plt.plot(steps, y_np.flatten(), 'r-')
            plt.plot(steps, prediction.cpu().data.numpy().flatten(), 'b-')
            plt.draw()
            plt.pause(0.01)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    运行结果如下:

    EPOCHS: 19,Loss:0.052745

    EPOCHS: 39,Loss:0.016266

    EPOCHS: 59,Loss:0.005471

    EPOCHS: 79,Loss:0.001329

    EPOCHS: 99,Loss:0.002216

  • 相关阅读:
    论文笔记:Skeleton Key: Image Captioning by Skeleton-Attribute Decomposition
    基于python-CNN深度学习的食物识别-含数据集+pyqt界面
    Linux 怎样通过win 远程桌面连接链接Linux后台服务器的可视化图形界面
    黑马JVM总结(三十二)
    新能源分布式资产上链 数字新云南启航
    多线程事务(仅保证原子性)
    WEB安全基础 - - -Tomcat弱口令漏洞
    Linux:系统安全及应用
    Direct Sparse Mapping reading notes -- Initialization
    Java实现redis缓存效果变量过期
  • 原文地址:https://blog.csdn.net/weixin_53065229/article/details/132984316