• 时序预测 | Pytorch实现TCN-Transformer的时间序列预测


    时序预测 | Pytorch实现TCN-Transformer的时间序列预测

    效果一览

    在这里插入图片描述

    基本介绍

    基于TCN-Transformer模型的时间序列预测,可以用于做光伏发电功率预测,风速预测,风力发电功率预测,负荷预测等,python程序
    python代码,pytorch

    在这里插入图片描述

    程序设计

    数据集描述
     通过预览数据,可知此次实验的数据属性为date(日期)open(开盘价)、high(最高价)、low(最低价)、close(收盘价)以及volume(成交量)
     其中,我们要实现股票预测,需要着重对close(收盘价)一列进行探索性分析。
    """
    from torch import nn
    
    # 1.导入库 对数据集进行处理
    import pandas as pd
    import numpy as np
    from sklearn.metrics import mean_absolute_error, mean_squared_error
    from torch.utils.data import DataLoader, Dataset
    import torch
    from model import TCN_transfomer
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.metrics import r2_score
    import matplotlib.pyplot as plt
    plt.rcParams['font.family'] = 'SimHei'#绘图正常显示中文
    plt.rcParams['axes.unicode_minus']=False#用来正常显示负号#有中文出现的情况,
    
    
    from tqdm import tqdm
    
    epoch = 100
    totall_loss = []  # 记录损失值
    batch_size=32
    num_inputs=5
    sequence_length=32
    num_channels=[64,16,4,1]
    kernel_size=3
    dropout=0.3
    nb_unites=sequence_length
    
    # 需要u'内容'
    # 2.定义获取数据函数,数据预处理。去除ID,股票代码,
    # 前一天的收盘价,交易日期等对训练集无用的数据
    def getData(root, sequence_length, batch_size):
        stock_data = pd.read_csv(root)
        print(stock_data.info())
        print(stock_data.head().to_string())
    
        #首先删除一些对预测close无用的信息
        stock_data.drop('id', axis=1, inplace=True)  # 删除date
        stock_data.drop(labels="ts_code", axis=1, inplace=True)
        stock_data.drop(labels="trade_date", axis=1, inplace=True)
        stock_data.drop(labels="pre_close", axis=1, inplace=True)
        stock_data.drop(labels="change", axis=1, inplace=True)
        stock_data.drop(labels="pct_chg", axis=1, inplace=True)
        stock_data.drop(labels="amount", axis=1, inplace=True)
        print("整理后\n", stock_data.head())
    
        #获取收盘价的最大值与最下值
        close_max = stock_data["close"].max()  # 收盘价的最大值
        close_min = stock_data["close"].min()  # s收盘价的最小值
        # 2.1对数据进行标准化min-max
        scaler = MinMaxScaler()
        df = scaler.fit_transform(stock_data)
        print("整理后\n", df)
        # 2.2构造X,Y
        # 根据前n天的数据,预测未来一天的收盘价(close),
        # 例如根据1月1日、1月2日、1月3日、1月4日、1月5日的数据
        # (每一天的数据包含8个特征),预测1月6日的收盘价。
        sequence = sequence_length
        x = []
        y = []
        for i in range(df.shape[0] - sequence):
            x.append(df[i:i + sequence, :])
            y.append(df[i + sequence, 3])
        x = np.array(x, dtype=np.float32)
        y = np.array(y, dtype=np.float32).reshape(-1, 1)
    
        print("x.shape=", x.shape)
        x=np.transpose(x,(0,2,1))
        print("转置后x.shape=", x.shape)
        print("y.shape", y.shape)
        # 2.3构造batch,构造训练集train与测试集test
        total_len = len(y)
        print("total_len=", total_len)
        trainx, trainy = x[:int(0.90 * total_len), ], y[:int(0.90 * total_len), ]
        testx, testy = x[int(0.90 * total_len):, ], y[int(0.90 * total_len):, ]
        train_loader = DataLoader(dataset=Mydataset(trainx, trainy), shuffle=True, batch_size=batch_size)
        test_loader = DataLoader(dataset=Mydataset(testx, testy), shuffle=True, batch_size=batch_size)
        return [close_max, close_min, train_loader, test_loader]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83

    [1] https://blog.csdn.net/kjm13182345320/article/details/129036772?spm=1001.2014.3001.5502
    [2] https://blog.csdn.net/kjm13182345320/article/details/128690229

  • 相关阅读:
    GORM夜谈
    《设计模式:可复用面向对象软件的基础》——行为模式(笔记)
    rust编程-通用编程概念(chapter 3.1)
    蓝桥杯刷题day13——乘飞机【算法赛】
    一款强大的子域名收集工具(OneForAll)
    React@16.x(28)useMemo
    vue3 + ts 项目实站 【二】 vue-router 安装. (后台管理系统)
    《会计信息系统》课程期末复习题与参考答案
    python学习一(基础语句)
    冒泡排序
  • 原文地址:https://blog.csdn.net/kjm13182345320/article/details/134543904