码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • Lstm多变量时间序列预测框架|pytorch


    这是目前我看国内总结时序预测对小白很友好的博客教程,先推荐一下

    PyTorch搭建CNN实现时间序列预测(风速预测)_Cyril_KI的博客-CSDN博客_cnn回归预测pytorch 

    代码:单步预测

    ## 如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:

    ## input(seq_len, batch_size, input_size)

    ## output(seq_len, batch_size, num_directions * hidden_size)

    ## 变为

    ## input(batch_size, seq_len, input_size)

    ## output(batch_size, seq_len, num_directions * hidden_size)

    ## self.num_directions = 1 # 单向LSTM 2为双向LSTM

    1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    2. class LSTM(nn.Module):
    3. def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
    4. super().__init__()
    5. self.input_size = input_size #channel
    6. self.hidden_size = hidden_size #输出维度 也就是输出通道
    7. self.num_layers = num_layers
    8. self.output_size = output_size #输出个数
    9. self.num_directions = 1 # 单向LSTM
    10. self.batch_size = batch_size
    11. self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
    12. self.linear = nn.Linear(self.hidden_size, self.output_size)
    13. # self.linear = nn.Linear()
    14. def forward(self, input_seq):
    15. batch_size, seq_len = input_seq.shape[0], input_seq.shape[1]
    16. h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    17. c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    18. # output(batch_size, seq_len, num_directions * hidden_size)
    19. output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
    20. pred = self.linear(output) # (5, 30, 1)
    21. pred = pred[:, -1, :] # (5, 1)
    22. return pred
    23. lstm = LSTM(input_size=7,hidden_size=64, output_size=1,batch_size=5,num_layers=5).to(device)
    24. tensor = torch.rand(5, 30, 7).to(device)
    25. result = lstm(tensor)
    26. print("result.shape:",result.shape)

    参考资料

    深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)_Cyril_KI的博客-CSDN博客_lstm输入

    //写的非常得好 

    多步预测的讨论

    我的意见如下:首先,要理清楚两个概念:一是多变量一般是指你在预测时考虑了主变量以外的其他变量,而不是说你要预测多个变量,预测多个变量当然是可以的,但效果特别差,这个我之前还和清华一个搞新能源预测的PhD讨论过,所以我们一般是多变量预测单变量,如果要预测多变量那就训练多个LSTM;二是步长的问题,我这里预测了4步,所以是多步。你所纠结的无非是预测的是下一时刻,然而我却直接经过转换变成了预测四个时刻,但这种写法其实是合理的。多步长预测一般有以下几种方法:第一种就是我这种,直接取最后一步,然后接一个MLP来转换成多步,这种优点是简单,可以直接输入多个预测值,这种相当于是把最后的时间序列预测变成了一个纯粹的神经网络非线性预测,也就是这种最后的预测结果要看你MLP的性能了,所以这种我们一般会多加几个线性层;第二种是滚动预测,也就是预测单步然后把预测值加入继续滚动预测多次,这种可能会把误差传递,效果也一般;第三种是多个单步预测,也就是你要预测接下来n步,那么就训练n个模型分别来预测每一步,这种比较耗时。所以用哪种看你自己,你想要理解我这种想法,你可以就把最后这个转换看成是一个简单的MLP转换,仅此而已。

    在这里插入图片描述

     

     

  • 相关阅读:
    在低容错业务场景下落地微服务的实践经验
    笔者近期感想
    java如何创建一个只读集合呢?
    【C++】运算符重载 ⑥ ( 一元运算符重载 | 后置运算符重载 | 前置运算符重载 与 后置运算符重载 的区别 | 后置运算符重载添加 int 占位参数 )
    国产大模型参加高考,同写2024年高考作文,及格分(通义千问、Kimi、智谱清言、Gemini Advanced、Claude-3-Sonnet、GPT-4o)
    【专栏】基础篇04| Redis 该怎么保证数据不丢失(上)
    锂热电池检测设备 你一定没见过这种检测方式!
    如何设置和解除PDF文件保护?
    浏览器页面被禁用 F12(dev tools)
    JAVA设计模式-责任链模式
  • 原文地址:https://blog.csdn.net/weixin_43332715/article/details/127741022
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号