当前位置: 首页 > news >正文

LSTM 学习笔记 之pytorch调包每个参数的解释

0、 LSTM 原理

整理优秀的文章
LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现)
[干货]深入浅出LSTM及其Python代码实现
整理视频
李毅宏手撕LSTM
[双语字幕]吴恩达深度学习deeplearning.ai

1 Pytorch 代码

这里直接调用了nn.lstm

 self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nn

下面作为初学者解释一下里面的3个参数
input_size: 这个就是输入的向量的长度or 维度,如一个单词可能占用20个维度。
hidden_size: 这个是隐藏层,其实我感觉有点全连接的意思,这个层的维度影响LSTM 网络输入的维度,换句话说,LSTM接收的数据维度不是输入什么维度就是什么维度,而是经过了隐藏层,做了一个维度的转化。
num_layers: 这里就是说堆叠了几个LSMT 结构。

2 网络定义

class LstmRNN(nn.Module):"""Parameters:- input_size: feature size- hidden_size: number of hidden units- output_size: number of output- num_layers: layers of LSTM to stack"""def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):super().__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nnself.forwardCalculation = nn.Linear(hidden_size, output_size)def forward(self, _x):x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)x = x.view(s * b, h)x = self.forwardCalculation(x)x = x.view(s, b, -1)return x

3 网络初始化

我们定义一个网络导出onnx ,观察 网络的具体结构

INPUT_FEATURES_NUM = 100
OUTPUT_FEATURES_NUM = 13
lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=2)  # 16 hidden units
print(lstm_model)
save_onnx_path= "weights/lstm_16.onnx"
input_data = torch.randn(1,150,100)input_names = ["images"] + ["called_%d" % i for i in range(2)]
output_names = ["prob"]
torch.onnx.export(lstm_model,input_data,save_onnx_path,verbose=True,input_names=input_names,output_names=output_names,opset_version=12)

在这里插入图片描述
可以看到 LSTM W 是1x64x100;这个序列150没有了 是不是说150序列是一次一次的送的呢,所以在网络中没有体现;16是hidden,LSTM里面的W是64,这里存在一个4倍的关系。
我想这个关系和LSTM的3个门(输入+输出+遗忘+C^)有联系。
在这里插入图片描述
在这里插入图片描述
这里输出我们设置的13,如图 onnx 网络结构可视化显示也是13,至于这个150,或许就是输入有150个词,输出也是150个词吧。

在这里插入图片描述
至于LSTM的层数设置为2,则表示有2个LSTM堆叠。
在这里插入图片描述

4 网络提取

另外提取 网络方便看 每一层的维度,代码如下。

import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse
model = "./weights/lstm_16.onnx"
output_model_path = "./weights/lstm_16_e.onnx"onnx_model = onnx.load(model)
#Flatten
onnx.utils.extract_model(model, output_model_path, ['images'],['prob'])
http://www.lryc.cn/news/535497.html

相关文章:

  • ASUS/华硕飞行堡垒9 FX506H FX706H 原厂Win10系统 工厂文件 带ASUS Recovery恢复
  • Unity使用iTextSharp导出PDF-04图形
  • JDBC如何连接数据库
  • Unity URP的2D光照简介
  • 【IC】AI处理器核心--第二部分 用于处理 DNN 的硬件设计
  • 从 0 开始本地部署 DeepSeek:详细步骤 + 避坑指南 + 构建可视化(安装在D盘)
  • 如何本地部署DeepSeek集成Word办公软件
  • Centos10 Stream 基础配置
  • 时间序列分析(三)——白噪声检验
  • ThinkPHP8视图赋值与渲染
  • 对贵司需求的PLC触摸的远程调试的解决方案
  • 2.12寒假作业
  • 记使用AScript自动化操作ios苹果手机
  • 【Apache Paimon】-- 16 -- 利用 paimon-flink-action 同步 kafka 数据到 hive paimon 表中
  • 基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试
  • 算法之 数论
  • Java 大视界 -- 人工智能驱动下 Java 大数据的技术革新与应用突破(83)
  • 【04】RUST特性
  • PlantUml常用语法
  • 保存字典类型的文件用什么格式比较好
  • 开源模型应用落地-Qwen1.5-MoE-A2.7B-Chat与vllm实现推理加速的正确姿势(一)
  • 一竞技瓦拉几亚S4预选:YB 2-0击败GG
  • deepseek+kimi一键生成PPT
  • mybatis 是否支持延迟加载?延迟加载的原理是什么?
  • 【Android开发】安卓手机APP拍照并使用机器学习进行OCR文字识别
  • 力扣 15.三数之和
  • 机器学习:二分类和多分类
  • 安科瑞光伏发电防逆流解决方案——守护电网安全,提升能源效率
  • ml5.js框架实现AI图片识别
  • HDFS应用-后端存储cephfs-文件存储和对象存储数据双向迁移