• LSTM长短期记忆网络


    LSTM的matlab详解

    LSTM(Long Short-Term Memory) 是一种特殊的循环神经网络,它通过引入门机制(gate)来解决传统循环神经网络长序列训练过程中的梯度消失和爆炸问题,从而能够更好地处理序列数据。

        LSTM在传统的循环神经网络基础上增加了三个门结构:遗忘门、输入门和输出门。这些门结构允许网络选择性地忘记或记住特定事物,并控制信息的流动,从而增强了网络的记忆能力和泛化能力。具体来说,

    • 遗忘门控制是否清除先前保存的记忆状态,避免无关信息对当前任务造成干扰;
    • 输入门控制当前要存储的信息,提高网络对重要信息的敏感度;
    • 输出门控制哪些记忆状态被输出,使网络能够更好地适应具体的任务需求。

        LSTM网络利用这些门结构学习长时间依赖关系,比如在语音识别、自然语言处理和时间序列预测等领域中有着广泛的应用。

    以下是一个用LSTM神经网络进行单维度时间序列的示例:

    单维度的时间序列预测

    clc
    clear
    load oxygen.mat %加载数据(double型,剔除了2023年的的第一个数据,总共为2191个。时序预测没有实际时间,只有事情发生的顺序)
    data=data';  %不转置的话,无法训练lstm网络,显示维度不对。
    %% 序列的前17727个用于训练,后200个用于验证神经网络,然后往后预测200个数据。
    dataTrain = data(1:17727);    %定义训练集
    dataTest = data(17728:end);    %该数据是用来在最后与预测值进行对比的
    
    %% 数据预处理
    mu = mean(dataTrain);    %求均值 
    sig = std(dataTrain);      %求均差 
    dataTrainStandardized = (dataTrain - mu) / sig;    
    
    %% 输入的每个时间步,LSTM网络学习预测下一个时间步,这里交错一个时间步效果最好。
    XTrain = dataTrainStandardized(1:end-1);  
    YTrain = dataTrainStandardized(2:end);  
    
    %% 一维特征lstm网络训练
    numFeatures = 1;   %特征为一维
    numResponses = 1;  %输出也是一维
    numHiddenUnits = 200;   %创建LSTM回归网络,指定LSTM层的隐含单元个数200。可调
     
    layers = [ ...
        sequenceInputLayer(numFeatures)    %输入层
        lstmLayer(numHiddenUnits)  % lstm层,如果是构建多层的LSTM模型,可以修改。
        fullyConnectedLayer(numResponses)    %为全连接层,是输出的维数。
        regressionLayer];      %其计算回归问题的半均方误差模块 。即说明这不是在进行分类问题。
     
    %指定训练选项,求解器设置为adam, 1000轮训练。
    %梯度阈值设置为 1。指定初始学习率 0.01,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
    options = trainingOptions('adam', ...
        'MaxEpochs',1000, ...
        'GradientThreshold',1, ...
        'InitialLearnRate',0.01, ...      
        'LearnRateSchedule','piecewise', ...%每当经过一定数量的时期时,学习率就会乘以一个系数。
        'LearnRateDropPeriod',400, ...      %乘法之间的纪元数由“ LearnRateDropPeriod”控制。可调
        'LearnRateDropFactor',0.15, ...      %乘法因子由参“ LearnRateDropFactor”控制,可调
        'Verbose',0,  ...  %如果将其设置为true,则有关训练进度的信息将被打印到命令窗口中。默认值为true。
        'Plots','training-progress');    %构建曲线图 将'training-progress'替换为none
    net = trainNetwork(XTrain,YTrain,layers,options); 
    
    %% 神经网络初始化
    net = predictAndUpdateState(net,XTrain);  %将新的XTrain数据用在网络上进行初始化网络状态
    [net,YPred] = predictAndUpdateState(net,YTrain(end));  %用训练的最后一步来进行预测第一个预测值,给定一个初始值。这是用预测值更新网络状态特有的。
    
    %% 进行用于验证神经网络的数据预测(用预测值更新网络状态)
    for i = 2:291  %从第二步开始,这里进行191次单步预测(191为用于验证的预测值,100为往后预测的值。一共291个)
        [net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');  %predictAndUpdateState函数是一次预测一个值并更新网络状态
    end
    %% 验证神经网络
    YPred = sig*YPred + mu;      %使用先前计算的参数对预测去标准化。
    rmse = sqrt(mean((YPred(1:191)-dataTest).^2)) ;     %计算均方根误差 (RMSE)。
    subplot(2,1,1)
    plot(dataTrain(1:end))   %先画出前面2000个数据,是训练数据。
    hold on
    idx = 2001:(2000+191);   %为横坐标
    plot(idx,YPred(1:191),'.-')  %显示预测值
    hold off
    xlabel("Time")
    ylabel("Case")
    title("Forecast")
    legend(["Observed" "Forecast"])
    subplot(2,1,2)
    plot(data)
    xlabel("Time")
    ylabel("Case")
    title("Dataset")
    %% 继续往后预测2023年的数据
    figure(2)
    idx = 2001:(2000+291);   %为横坐标
    plot(idx,YPred(1:291),'.-')  %显示预测值
    hold off
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
  • 相关阅读:
    PAL8像素格式
    Java 集合 - Set 接口
    软件测试怎么学?App自动化、Web自动化、性能测试怎么学?一文总结
    【数据结构】二叉树
    python 之numpy 之随机生成数
    八大排序(四)--------直接插入排序
    Linux项目实战——五子棋(单机人人对战版)
    C#程序全局异常处理—WPF和Web API两种模式
    ArrayList源码解析
    基于JavaWeb的学生住宿管理系统
  • 原文地址:https://blog.csdn.net/Lov1_BYS/article/details/131923284