当前位置: 首页 > article >正文

pytorch LSTM 结构详解

最近项目用到了LSTM ,但是对LSTM 的输入输出不是很理解,对此,我详细查找了lstm 的资料

import torch.nn as nnclass LSTMModel(nn.Module):def __init__(self, input_size=1, hidden_size=50, num_layers=2):super(LSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, 1)  # 1 表示预测输出变量为1def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _ = self.lstm(x, (h0, c0))out = self.fc(out[:, -1, :])return out # out 形状为(batch_size,1)
  • input_size=1:输入特征的维度,适用于单变量时间序列。

  • hidden_size=50:LSTM 隐藏层的维度,决定了模型的记忆能力。

  • num_layers=2:堆叠的 LSTM 层数,增加层数可以提升模型的表达能力。

  • batch_first=True:指定输入和输出的张量形状为 (batch_size, seq_len, input_size)

  • self.fc:一个全连接层,将 LSTM 的输出映射到最终的预测值。

  • batch_size 表示批次、seq_len 表示窗口大小、input_size 表示输入尺寸,单变量输入为1 ,多变量要基于个数变化

  • 初始化隐藏状态和细胞状态

    • h0c0 分别表示初始的隐藏状态和细胞状态,形状为 (num_layers, batch_size, hidden_size)

    • 在每次前向传播时,初始化为零张量。

  • LSTM 层处理

    • self.lstm(x, (h0, c0)):将输入 x 和初始状态传入 LSTM 层,输出 out 和新的状态。

    • out 的形状为 (batch_size, seq_len, hidden_size),包含了每个时间步的输出。

  • 全连接层映射

    • out[:, -1, :]:提取序列中最后一个时间步的输出。

    • self.fc(...):将提取的输出通过全连接层,得到最终的预测结果。

http://www.lryc.cn/news/2384922.html

相关文章:

  • 流程自动化引擎:重塑企业数字神经回路
  • nginx web服务日志分析
  • VSCode+EIDE通过KeilC51编译,使VSCode+EIDE“支持”C和ASM混编
  • 5.23本日总结
  • 游戏引擎学习第298天:改进排序键 - 第1部分
  • Mysql篇-优化
  • Java 集合框架核心知识点全解析:从入门到高频面试题(含 JDK 源码剖析)
  • 一文详解生成式 AI:李宏毅《生成式 AI 导论》学习笔记
  • 什么是物联网 (IoT):2024 年物联网概述
  • 8级-数组
  • 大模型 Agent 就是文字艺术吗?
  • YOLOv8检测头代码详解(示例展示数据变换过程)
  • JUC并发编程1
  • 消息队列RabbitMQ与AMQP协议详解
  • Day 29 训练
  • STM32开发环境配置——VSCode+PlatformIO + CubeMX + FreeRTOS的集成环境配置
  • Profibus转Profinet网关赋能鼓式硫化机:智能化生产升级的关键突破
  • redis 缓存穿透,缓存雪崩,缓存击穿
  • JAVA8怎么使用9的List.of
  • 告别手动测试:AUTOSAR网络管理自动化测试实战
  • BUCK电路利用状态空间平均法和开关周期平均法推导
  • MongoDB 用户与权限管理完全指南
  • C++滑动门问题(附两种方法)
  • 基于ITcpServer/IHttpServer框架的HTTP服务器
  • 初识main函数
  • FPGA高效验证工具Solidify 8.0:全面重构图形用户界面
  • SIL2/PLd 认证 Inxpect毫米波安全雷达:3D 扫描 + 微小运动检测守护工业安全
  • java中string类型的list集合放到redis的5种数据类型的那种比较合适呢,可以用StringRedisTemplate实现
  • PyQt学习系列09-应用程序打包与部署
  • 实现图片自动压缩算法,canvas压缩图片方法