当前位置: 首页 > news >正文

MATLAB实现LSTM时间序列预测

        LSTM模型可以在一定程度上学习和预测非平稳的时间序列,其具有强大的记忆和非线性建模能力,可以捕捉到时间序列中的复杂模式和趋势[4]。在这种情况下,LSTM模型可能会自动学习到时间序列的非平稳性,并在预测中进行适当的调整。其作为循环神经网络(RNN)的特殊形式,继承了循环神经网络的优点。首先,利用记忆机制,可以有效提取时间序列数据的时间依赖性。其次,在模型训练时,学习到的权重参数在时间步骤之间是共享的,故对长时间序列的训练具有一定的可拓展性,而且比起传统的神经网络模型,它所需参数数量较少,降低了模型的复杂度。最后,它也具有LSTM神经网络特有的优势,对训练时权重变化不稳定而产生梯度消失和梯度爆炸问题有着不错的改善效果。LSTM单元的主要结构由3个门控制器和记忆细胞组成。其中,输入门控制特征的流向信息,输出门控制特征的输出信息,遗忘门控制特征的去除与遗忘,记忆细胞负责存储细胞状态信息。通过不同功能门的控制,从而解决RNN存在的长期依赖问题[5]。LSTM单元内的计算过程为:

clc
clear
load('data.mat')
data=RTS'
%% 序列的前485个用于训练,后10个用于验证神经网络,然后往后预测10个数据。
dataTrain = data(1:485);    %定义训练集
dataTest = data(486:495);    %该数据是用来在最后与预测值进行对比的
%% 数据预处理
mu = mean(dataTrain);    %求均值 
sig = std(dataTrain);      %求均差 
dataTrainStandardized = (dataTrain - mu) / sig;    
%% 输入的每个时间步,LSTM网络学习预测下一个时间步,这里交错一个时间步效果最好。
XTrain = dataTrainStandardized(1:end-1);  
YTrain = dataTrainStandardized(2:end);  
%% 一维特征lstm网络训练
numFeatures = 1;   %特征为一维
numResponses = 1;  %输出也是一维
numHiddenUnits = 200;   %创建LSTM回归网络,指定LSTM层的隐含单元个数200。可调layers = [ ...sequenceInputLayer(numFeatures)    %输入层lstmLayer(numHiddenUnits)  % lstm层,如果是构建多层的LSTM模型,可以修改。fullyConnectedLayer(numResponses)    %为全连接层,是输出的维数。regressionLayer];      %其计算回归问题的半均方误差模块 。即说明这不是在进行分类问题。%指定训练选项,求解器设置为adam, 1000轮训练。
%梯度阈值设置为 1。指定初始学习率 0.01,在 125 轮训练后通过乘以因子 0.2 来降低学习率。
options = trainingOptions('adam', ...'MaxEpochs',1000, ...'GradientThreshold',1, ...'InitialLearnRate',0.01, ...      'LearnRateSchedule','piecewise', ...%每当经过一定数量的时期时,学习率就会乘以一个系数。'LearnRateDropPeriod',400, ...      %乘法之间的纪元数由“ LearnRateDropPeriod”控制。可调'LearnRateDropFactor',0.15, ...      %乘法因子由参“ LearnRateDropFactor”控制,可调'Verbose',0,  ...  %如果将其设置为true,则有关训练进度的信息将被打印到命令窗口中。默认值为true。'Plots','training-progress');    %构建曲线图 将'training-progress'替换为none
net = trainNetwork(XTrain,YTrain,layers,options); 
net = predictAndUpdateState(net,XTrain);  %将新的XTrain数据用在网络上进行初始化网络状态
[net,YPred] = predictAndUpdateState(net,YTrain(end));  %用训练的最后一步来进行预测第一个预测值,给定一个初始值。这是用预测值更新网络状态特有的。
%% 进行用于验证神经网络的数据预测(用预测值更新网络状态)
for i = 2:20  %从第二步开始,这里进行20次单步预测(10为用于验证的预测值,10为往后预测的值。一共20个)[net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i-1),'ExecutionEnvironment','cpu');  %predictAndUpdateState函数是一次预测一个值并更新网络状态
end
%% 验证神经网络
YPred = sig*YPred + mu;      %使用先前计算的参数对预测去标准化。
rmse = sqrt(mean((YPred(1:10)-dataTest).^2)) ;     %计算均方根误差 (RMSE)。
subplot(2,1,1)
plot(dataTrain(1:end))   %先画出前面485个数据,是训练数据。
hold on
idx = 486:(485+20);   %为横坐标
plot(idx,YPred(1:20),'.-')  %显示预测值
hold off
xlabel("Time")
ylabel("Case")
title("Forecast")
legend(["Observed" "Forecast"])
subplot(2,1,2)
plot(data)
xlabel("Time")
ylabel("Case")
title("Dataset")
%% 继续往后预测2023年的数据
figure(2)
idx = 486:(485+20);   %为横坐标
plot(idx,YPred(1:20),'.-')  %显示预测值
hold off
net = resetState(net);

 MATLAB运行结果如下:

 

http://www.lryc.cn/news/297510.html

相关文章:

  • Kubernetes CNI Calico:Route Reflector 模式(RR) calico IPIP切换RR网络模式
  • 探索Gin框架:Golang Gin框架请求参数的获取
  • 极值图论基础
  • word导出链接
  • (delphi11最新学习资料) Object Pascal 学习笔记---第4章第2.5节(重载和模糊调用)
  • ElementUI Data:Table 表格
  • 11.2 OpenGL可编程顶点处理:细分着色器
  • 微软正在偷走你的浏览记录,Edge浏览器偷疯了
  • 什么是数据库软删除,什么场景下要用软删除?(go GORM硬删除)
  • 计算机设计大赛 深度学习+python+opencv实现动物识别 - 图像识别
  • 我主编的电子技术实验手册(02)——仪表与电源
  • C语言----内存函数
  • 【力扣】快乐数,哈希集合 + 快慢指针 + 数学
  • c实现顺序表
  • 微软为新闻编辑行业推出 AI 辅助项目,记者参加免费课程
  • openssl3.2 - exp - buffer to BIO
  • Android 13.0 系统framework修改低电量关机值为3%
  • 【EAI 013】BC-Z: Zero-Shot Task Generalization with Robotic Imitation Learning
  • 一文讲透ast.literal_eval() eval() json.loads()
  • 微软.NET6开发的C#特性——类、结构体和联合体
  • naiveui 上传图片遇到的坑 Upload
  • 安全之护网(HVV)、红蓝对抗
  • Leetcode 213 打家劫舍 II
  • 【C语言】三子棋游戏实现代码
  • docker常用10条容器操作命令
  • 《MySQL 简易速速上手小册》第2章:数据库设计最佳实践(2024 最新版)
  • 利用YOLOv8 pose estimation 进行 人的 头部等马赛克
  • 【Python 千题 —— 基础篇】查找年龄
  • 前后端通讯:前端调用后端接口的五种方式,优劣势和场景
  • Mysql大表添加字段失败解决方案