当前位置: 首页 > news >正文

LSTM实战之预测股票

📈 用PyTorch搭建LSTM模型,轻松预测股票价格!🚀

Hey小伙伴们,今天给大家带来一个超级实用的项目教程——如何用PyTorch和LSTM模型来预测股票价格!🌟

🔍 项目背景

我们都知道股市是个风云变幻的地方,而预测股价则是很多投资者梦寐以求的能力。今天,我们就来尝试一下用机器学习的方法来预测股价,让数据说话!

📑 准备工作

首先,我们要准备好开发环境,确保安装了以下Python库:

  • numpy: 数组处理
  • pandas: 数据处理
  • matplotlib: 数据可视化
  • scikit-learn: 数据预处理
  • torch: 构建LSTM模型

💻 实战演练

1️⃣ 导入库 & 加载数据

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim# 加载数据
df = pd.read_csv('stock_data.csv')
# 只保留收盘价
data = df.filter(['Close'])
# 将数据转换为numpy数组
dataset = data.values
# 归一化数据
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(dataset)

2️⃣ 创建数据集

# 训练集和测试集划分
training_data_len = int(np.ceil(len(dataset) * .8))# 创建训练数据集
def create_dataset(data, time_step=1):X_train, y_train = [], []for i in range(len(data)-time_step-1):X_train.append(data[i:(i+time_step), 0])y_train.append(data[i + time_step, 0])return np.array(X_train), np.array(y_train)time_step = 60
X_train, y_train = create_dataset(scaled_data[:training_data_len], time_step)
X_test, y_test = create_dataset(scaled_data[training_data_len-time_step:], time_step)# 调整数据形状以适应LSTM
X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))# 转换为PyTorch张量
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()

3️⃣ 构建LSTM模型

class LSTMModel(nn.Module):def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):super(LSTMModel, self).__init__()self.hidden_dim = hidden_dimself.layer_dim = layer_dimself.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))out = self.fc(out[:, -1, :]) return outinput_dim = 1
hidden_dim = 50
layer_dim = 1
output_dim = 1model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

4️⃣ 训练模型

num_epochs = 100for epoch in range(num_epochs):outputs = model(X_train)optimizer.zero_grad()# 获取损失loss = criterion(outputs, y_train)# 反向传播和优化loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

5️⃣ 预测和评估

# 获取模型预测值
train_predictions = model(X_train).detach().numpy()
test_predictions = model(X_test).detach().numpy()# 反归一化预测值
train_predictions = scaler.inverse_transform(train_predictions)
test_predictions = scaler.inverse_transform(test_predictions)# 计算均方根误差(RMSE)
rmse = np.sqrt(np.mean(((test_predictions - y_test.numpy()) ** 2)))
print("Root Mean Squared Error:", rmse)# 可视化结果
train = data[:training_data_len+1]
valid = data[training_data_len+1:]
valid['Predictions'] = test_predictionsplt.figure(figsize=(16,8))
plt.title('Model')
plt.xlabel('Date', fontsize=18)
plt.ylabel('Close Price', fontsize=18)
plt.plot(train['Close'])
plt.plot(valid[['Close', 'Predictions']])
plt.legend(['Train', 'Val', 'Predictions'], loc='upper right')
plt.show()

📊 结果展示

最后,我们来看看预测结果。可以看到,我们的模型虽然不是完美无缺,但在一定程度上还是能够捕捉到股价的变化趋势。这为投资者提供了非常有价值的信息哦!👀
在这里插入图片描述

🏆 结语

今天的分享就到这里啦!希望这篇教程能帮到你,也欢迎小伙伴们在评论区分享你的经验或者遇到的问题,我们一起探讨学习!🌟


如果你在运行过程中遇到任何问题,或者想要了解更多细节,随时可以问我哦!💡
如果你喜欢这篇教程,请给我点个赞哦!💖
也可以收藏,关注我了解更多人工智能知识哦!😉


📌 附录:常见问题解答

  • Q: 如何获取股票数据?

  • A: 你可以从雅虎财经、tushare等数据源获取股票数据。

  • Q: 为什么我的模型预测效果不好?

  • A: 可能是因为数据不足、模型结构不够复杂或者超参数设置不当。尝试增加数据量、调整模型架构或优化超参数。

  • Q: 我可以在哪里找到更多关于LSTM的知识?

  • A: 有很多在线资源和书籍可以学习LSTM,比如官方文档、博客文章和教程视频。

希望这篇文章对你有所帮助!如果有任何疑问,记得留言哦!👋

#PyTorch #LSTM #股票预测 #时间序列分析 #机器学习 #数据科学 #Python编程

http://www.lryc.cn/news/417805.html

相关文章:

  • 30-50K|抖音大模型|社招3轮面经
  • ChatGPT首次被植入人类大脑:帮助残障人士开启对话
  • 数据结构-常见排序的七大排序
  • 程序员学CFA——财务报告与分析(四)
  • 【消息队列】kafka如何保证消息不丢失?
  • 不同随机数生成的含义
  • Jar工具完全指南:从入门到精通
  • 前端使用docx-preview展示docx + 后端doc转docx
  • Vue3 组件通信
  • 如何在Ubuntu 14.04上安装、配置和部署Rocket.Chat
  • ISO 26262中的失效率计算:IEC TR 62380-Section 15-Switches and keyboards
  • Linux安全与高级应用(五)深入探讨Linux Shell脚本应用:从基础到高级
  • Java中等题-解码方法(力扣)
  • 【Git】git 从入门到实战系列(二)—— Git 介绍以及安装方法
  • 【QT 5 QT 6 构建工具qmake-cmake-和-软件编译器MSVCxxxvs MinGWxxx说明】
  • SD卡参数错误:深度解析与数之寻软件恢复实战
  • 深入理解和应用RabbitMQ的Work Queues模型
  • 嵌入式面试八股文(三)·野指针产生原因和解决方法、指针函数和函数指针的区别
  • OpenCV 中 CV_8UC1,CV_32FC3,CV_32S等参数的含义
  • v 3 + vite + ts 自适应布局(postcss-pxtorem)
  • (MTK)java文件添加简单接口并配置相应的SELinux avc 权限笔记2
  • Linux安全与高级应用(六)Linux Shell脚本编程的高级应用:条件测试与if语句的妙用
  • 升级MacOS(Mojave)后使用git问题
  • 基于PFC和ECN搭建无损RoCE网络的工作流程分析
  • 射频功率放大器调测简略
  • Linux使用docker搭建Redis 哨兵模式
  • springboot给类进行赋初值的四种方式
  • Day32 | 1049. 最后一块石头的重量 II 494. 目标和 474.一和零
  • linux 查看一个端口是否被占用
  • 【Git】5. 配置 Git