当前位置: 首页 > news >正文

RNN介绍及Pytorch源码解析

介绍一下RNN模型的结构以及源码,用作自己复习的材料。 

RNN模型所对应的源码在:\PyTorch\Lib\site-packages\torch\nn\modules\RNN.py文件中。

RNN的模型图如下:

源码注释中写道,RNN的数学公式:

h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

h_t表示在t时刻的隐藏状态,

x_t表示在t时刻的输入,

h_{t-1}表示前一层在时间t-1的隐藏状态,或者是在时间“0”的初始隐藏状态。

接下来我们看一下源码中RNN类的初始化(只介绍几个重要的参数):

torch.nn.RNN(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)
  • input_size:输入数据中的特征数(可以理解为嵌入维度 embedding_dim)。
  • hidden_size:处于隐藏状态 h 的特征数(可以理解为输出的特征维度)。
  • num_layers:代表着RNN的层数,默认是1(层),当该参数大于零时,又称为多层RNN。
  • bidirectional:即是否启用双向RNN,默认关闭。

下面是输入部分:

这是Pytorch官方文档中给出的解释:

输入分为input和h_0,当没有提供h_0的时候,h_0默认为0。

当batch_size == Ture时,输入的维度一般为(batch_size * seq_len * emb_dim)。

下面是输出部分:

其中output的维度为(batch_size * seq_len * hidden_size * bidirectional

其中bidirectional表示RNN是双向还是单向的,单向为1,双向为2。

下面使用代码举例:

import torch
import torch.nn as nn
rnn1 = nn.RNN(input_size=20,hidden_size=40,num_layers=4,bidirectional=True)
rnn2 = nn.RNN(input_size=20,hidden_size=40,num_layers=4,bidirectional=False)
tensor1 = torch.randn(5,10,20)
tensor2 = torch.randn(5,10,20)
out1,h_n = rnn1(tensor1)
out2,h_n = rnn2(tensor2)
print(out1.shape) # torch.Size([5, 10, 80])
print(out2.shape) # torch.Size([5, 10, 40])

可以看到当bidirectional=True时,输出的特征维度是hidden_size * 2;

可以看到当bidirectional=False时,输出的特征维度是hidden_size * 1;

http://www.lryc.cn/news/260065.html

相关文章:

  • Qt 文字描边(基础篇)
  • .360勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复
  • Nginx(四层+七层代理)+Tomcat实现负载均衡、动静分离
  • 【前端】vscode 相关插件
  • 【MySQL】MySQL库的增删查改
  • 基于基于深度学习的表情识别人脸打分系统
  • Linux|操作系统|Error: Could not create the Java Virtual Machine 报错的解决思路
  • K8S学习指南-minikube的安装
  • 恒创科技:有哪些免费的CDN加速服务
  • Kibana搜索数据利器:KQL与Lucene
  • float32、int8、uint8、int32、uint32之间的区别
  • 百度搜索展现服务重构:进步与优化
  • icmp协议、ip数据包 基础
  • es6从url中获取想要的参数
  • 【elementui笔记:el-table表格的输入校验】
  • 每天五分钟计算机视觉:GoogLeNet的核心模型结构——Inception
  • 卡片C语言(2021年蓝桥杯B)
  • 数据库动态视图和存储过程报表数据管理功能设计
  • css+js 选项卡动画效果
  • [C错题本]转义字符/指针与首元素/运算
  • Layui继续学习
  • react+datav+echarts实现可视化数据大屏
  • CSS新手入门笔记整理:CSS浮动布局
  • 微服务组件Sentinel的学习(1)
  • 小程序 -网络请求post/get
  • Elasticsearch 8.10之前同义词最佳实践
  • 芯知识 | 什么是OTP语音芯片?唯创知音WTN6xxx系列:低成本智能语音解决方案
  • Linux内核密钥环
  • web前端之正弦波浪动功能、repeat、calc
  • 使用工具 NVM来管理不同版本的 Node.js启动vue项目