当前位置: 首页 > news >正文

人工智能基础部分13-LSTM网络:预测上证指数走势

大家好,我是微学AI,今天给大家介绍一下LSTM网络,主要运用于解决序列问题。

一、LSTM网络简单介绍

LSTM又称为:长短期记忆网络,它是一种特殊的 RNN。LSTM网络主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。对于相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

引入LSTM网络的原因:由于 RNN 网络主要问题是长期依赖,即隐藏状态在时间上传递过程中可能会丢失之前的信息。为了解决这个问题,引入了长短时记忆网络 (LSTM) 和门控循环单元 (GRU)。这两种网络结构在隐藏层中增加了门控机制,能够更好地控制信息的传递。

 其中符号及表示意思如下:

 LSTM中有三个门:
(1)遗忘门f:决定上一个时刻的记忆单元状态需要遗忘多少信息,保留多少信息到当前记忆单元状态。
(2)输入门i:控制当前时刻输入信息候选状态有多少信息需要保存到当前记忆单元状态。
(3)输出门o:控制当前时刻的记忆单元状态有多少信息需要输出给外部状态。

形象的例子让我们更好的理解LSTM的原理:

假设你是一个梦想远大的学生,你想通过学习一门课程获得更多的知识。在学习过程中,LSTM模型帮助你,它就像是一个老师,它的遗忘门就像是老师的提醒,它让你挑出不用的知识,以保持你对重要知识的清晰记忆。它的输入门就像是老师的指导,它会重新审视你学习过的知识,按照自己的逻辑把知识结合起来,进化出更多有用的知识。最后,它的输出门就像老师的监督,它会确保你学习到了有用的知识,不要浪费时间去学习无用的知识。

二、LSTM网络运用-预测上证指数走势

# 使用LSTM预测沪市指数
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
from pandas import DataFrame
from pandas import concat
from itertools import chain
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']# 转化为可以用于监督学习的数据
def get_train_set(data_set, timesteps_in, timesteps_out=1):train_data_set = np.array(data_set)reframed_train_data_set = np.array(series_to_supervised(train_data_set, timesteps_in, timesteps_out).values)train_x, train_y = reframed_train_data_set[:, :-timesteps_out], reframed_train_data_set[:, -timesteps_out:]# 将数据集重构为符合LSTM要求的数据格式,即 [样本数,时间步,特征]train_x = train_x.reshape((train_x.shape[0], timesteps_in, 1))return train_x, train_y"""
将时间序列数据转换为适用于监督学习的数据
给定输入、输出序列的长度
data: 观察序列
n_in: 观测数据input(X)的步长,范围[1, len(data)], 默认为1
n_out: 观测数据output(y)的步长, 范围为[0, len(data)-1], 默认为1
dropnan: 是否删除NaN行
返回值:适用于监督学习的 DataFrame
"""
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):print(data.shape)n_vars = 1 if type(data) is list else data.shape[1]df = DataFrame(data)cols, names = list(), list()# input sequence (t-n, ... t-1)for i in range(n_in, 0, -1):cols.append(df.shift(i))names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]# 预测序列 (t, t+1, ... t+n)for i in range(0, n_out):cols.append(df.shift(-i))if i == 0:names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]else:names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]# 拼接到一起agg = concat(cols, axis=1)agg.columns = names# 去掉NaN行if dropnan:agg.dropna(inplace=True)return agg# 使用LSTM进行预测
def lstm_model(source_data_set, train_x, label_y, input_epochs, input_batch_size, timesteps_out):model = Sequential()# 第一层, 隐藏层神经元节点个数为128, 返回整个序列model.add(LSTM(128, return_sequences=True, activation='tanh', input_shape=(train_x.shape[1], train_x.shape[2])))# 第二层,隐藏层神经元节点个数为128, 只返回序列最后一个输出model.add(LSTM(128, return_sequences=False))model.add(Dropout(0.5))# 第三层 因为是回归问题所以使用linearmodel.add(Dense(timesteps_out, activation='linear'))model.compile(loss='mean_squared_error', optimizer='adam')# LSTM训练 input_epochs次数res = model.fit(train_x, label_y, epochs=input_epochs, batch_size=input_batch_size, verbose=2, shuffle=False)# 模型预测train_predict = model.predict(train_x)#test_data_list = list(chain(*test_data))train_predict_list = list(chain(*train_predict))plt.plot(res.history['loss'], label='train')plt.show()#print(model.summary())plot_img(source_data_set, train_predict)# 呈现原始数据,训练结果,验证结果,预测结果
def plot_img(source_data_set, train_predict):plt.figure(figsize=(24, 8))# 原始数据蓝色plt.plot(source_data_set[:, -1], c='b',label = '标签')# 训练数据绿色plt.plot([x for x in train_predict], c='g')plt.legend()plt.show()# 设置观测数据input(X)的步长(时间步),epochs,batch_size
timesteps_in = 3
timesteps_out = 3
epochs = 1000
batch_size = 100
data = pd.read_csv('./shanghai_index_1990_12_19_to_2019_12_11.csv')
data_set = data[['Price']].values.astype('float64')
# 转化为可以用于监督学习的数据
train_x, label_y = get_train_set(data_set, timesteps_in=timesteps_in, timesteps_out=timesteps_out)print(train_x, label_y )
print(train_x.shape)
print(train_x.shape[1], train_x.shape[2])# 使用LSTM进行训练、预测
lstm_model(data_set, train_x, label_y, epochs, batch_size, timesteps_out=timesteps_out)

运行结果:

http://www.lryc.cn/news/15297.html

相关文章:

  • 内网穿透/组网/设备上云平台EasyNTS上云网关的安装操作指南
  • 易点天下基于 StarRocks 全面构建实时离线一体的湖仓方案
  • Tomcat的类加载机制
  • 【shell 编程大全】数组,逻辑判断以及循环
  • Android13 Bluetooth更新
  • 手工测试混了5年,年底接到了被裁员的消息....
  • Umi框架
  • 教你学git
  • 【工作笔记】syslog,kern.log大量写入invalid cookie错误信息问题
  • 【C++】多线程
  • 0202插入删除-算法第四版红黑树-红黑树-数据结构和算法(Java)
  • vue 生成二维码插件 vue-qr使用方法
  • 网络工程课(二)
  • Pytorch并行计算(三): 梯度累加
  • 蓝桥杯入门即劝退(十八)最小覆盖子串(滑动窗口解法)
  • Android一~
  • 一月券商金工精选
  • UML中常见的9种图
  • 使用SpringBoot实现无限级评论回复功能
  • Kafka 介绍和使用
  • [学习笔记]Rocket.Chat业务数据备份
  • 【ZOJ 1090】The Circumference of the Circle 题解(海伦公式+正弦定理推论)
  • 【go】slice原理
  • 【数据库】MySQL概念知识语法-基础篇(DQL),真的很详细,一篇文章你就会了
  • 博客界的至高神:属于自己的WordPress网站,你值得拥有!
  • 操作系统(day13)-- 虚拟内存;页面分配策略
  • SQL零基础入门学习(四)
  • 19岁就患老年痴呆!这些前兆别忽视!
  • 【C++】thread|mutex|atomic|condition_variable
  • 学成在线项目笔记