当前位置: 首页 > news >正文

神经网络 torch.nn---nn.RNN()

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

RNN — PyTorch 2.3 documentation

torch.nn---nn.RNN()

nn.RNN(input_size=input_x,hidden_size=hidden_num,num_layers=1,nonlinearity='tanh', #默认'tanh'bias=True,  #默认是Truebatch_first=False,dropout=0,bidirectional=False  #默认为False)

参数说明:

  • input_size – 输入x的特征数量。

  • hidden_size – 隐层的特征数量。

  • num_layers – RNN的层数。

  • nonlinearity – 激活函数。指定非线性函数使用tanh还是relu。默认是tanh

  • bias – 是否使用偏置。

  • batch_first – 如果True的话,那么输入Tensor的shape应该是[batch_size, time_step, feature],输出也是这样。默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位

  • dropout – 默认不使用,如若使用将其设置成一个0-1的数字即可。如果值非零,那么除了最后一层外,其它层的输出都会套上一个dropout层。

  • 是否使用双向的 rnn,默认是 False

输入输出shape

  1. RNN的输入:input_shape = [时间步数, 批量大小, 特征维度] = [num_steps(seq_length), batch_size, input_dim]=input (seq_len, batch, input_size)保存输入序列特征的tensor。
  2. RNN的隐藏层:h_0 (num_layers * num_directions, batch, hidden_size): 保存着初始隐状态的tensor。
  3. RNN的输出: (output, h_n)。在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输⼊。
  4. output (seq_len, batch, hidden_size * num_directions)形状为(时间步数, 批量大小, 隐藏单元个数): 保存着RNN最后一层的输出特征。如果输入是被填充过的序列,那么输出也是被填充的序列。
  5. 隐藏状态h的形状为(层数, 批量大小,隐藏单元个数)=h_n (num_layers * num_directions, batch, hidden_size): 保存着最后一个时刻隐状态。隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每⼀层的隐藏状态都会记录在该变量中。

RNN模型参数:

  • weight_ih_l[k] – 第k层的 input-hidden 权重, 可学习,形状是(input_size x hidden_size)

  • weight_hh_l[k] – 第k层的 hidden-hidden 权重, 可学习,形状是(hidden_size x hidden_size)

  • bias_ih_l[k] – 第k层的 input-hidden 偏置, 可学习,形状是(hidden_size)

  • bias_hh_l[k] – 第k层的 hidden-hidden 偏置, 可学习,形状是(hidden_size)

计算过程

​h_t是时刻t的隐状态。
x_t是上一层时刻t的隐状态,或者是第一层在时刻t的输入。
如果nonlinearity='relu',那么将使用relu代替tanh作为激活函数。

http://www.lryc.cn/news/378436.html

相关文章:

  • RocketMQ-记一次生产者发送消息存在超时异常
  • ls命令的参数选项
  • 网络安全:Web 安全 面试题.(文件上传漏洞)
  • 智源联合多所高校推出首个多任务长视频评测基准MLVU
  • Linux系统:线程概念 线程控制
  • LearnOpenGL - Android OpenGL ES 3.0 绘制纹理
  • 山东济南最出名的起名大师颜廷利:二十一世纪哲学的领航者
  • Nginx 负载均衡实现上游服务健康检查
  • 小程序使用接口wx.getLocation配置
  • Protobuf安装配置--附带每一步截图
  • 力扣1019.链表中的下一个更大节点
  • 查询mysql库表的几个语句
  • 【CT】LeetCode手撕—103. 二叉树的锯齿形层序遍历
  • 1958springboot VUE宿舍管理系统开发mysql数据库web结构java编程计算机网页源码maven项目
  • LVS DR模式
  • myslql事务示例
  • 解决Flutter应用程序的兼容性问题
  • 整合微信支付一篇就够了
  • 视创云展为企业虚拟展厅搭建,提供哪些功能?
  • c++ 常用的锁及用法介绍和示例
  • PostgreSQL源码分析——口令认证
  • Stability-AI(图片生成视频)
  • Linux机器通过Docker-Compose安装Jenkins发送Allure报告
  • 基于Gunicorn+Flask+Docker模型高并发部署
  • java:类型变量(TypeVariable)解析--基于TypeResolver实现将类型变量替换为实际类型
  • ru俄罗斯域名如何申请SSL证书?
  • python实现购物车的功能
  • 日元预计明年开始上涨
  • 8、PHP 实现二进制中1的个数、数值的整数次方
  • linux git凭证管理