当前位置: 首页 > news >正文

门控循环单元GRU

目录

  • 一、GRU提出的背景:
    • 1.RNN存在的问题:
    • 2.GRU的思想:
  • 二、更新门和重置门:
  • 三、GRU网络架构:
    • 1.更新门和重置门如何发挥作用:
      • 1.1候选隐藏状态H~t:
      • 1.2隐藏状态Ht:
    • 2.GRU:
  • 四、底层源码:
  • 五、Pytorch版代码:

一、GRU提出的背景:

1.RNN存在的问题:

循环神经网络讲解文章

由于RNN的隐藏状态ht用于记录之前的所有序列信息,而对于长序列问题来说ht会记录太多序列信息导致序列时序特征区分度很差(最前面的序列特征因为进行了太多轮迭代往往不太好从ht中提取),因此一些比较靠前但很重要的序列特征在ht中可能就不太被重视,而一些比较靠后但没用的序列特征在ht中被过于关注。

2.GRU的思想:

GRU的思想是如何将隐藏状态ht中一些重要的序列信息给予高的关注,而一些不重要的序列信息给予低的关注。

  • 对于需要关注的序列信息,使用更新门来提高关注度
  • 对于需要遗忘的序列信息,使用遗忘门来降低关注度

二、更新门和重置门:

GRU提出更新门和重置门的思想来改变隐藏状态ht中不同序列信息的关注度。
在这里插入图片描述
更新门和重置门可以分别看做一个全连接层的隐藏层,这样的话上图就等价于两个并排的隐藏层,其中:

  • 每个隐藏层都接收之前时间步的隐藏状态Ht-1和当前时间步的输入batch。
  • 更新门和重置门有各自的可学习权重参数和偏置值,公式含义类似传统RNN。
  • Rt 和 Zt 都是根据过去的隐藏状态 Ht-1 和当前输入 Xt 计算得到的 [0,1] 之间的量(激活函数)。

三、GRU网络架构:

1.更新门和重置门如何发挥作用:

重置门对过去t个时间步的序列信息(Ht-1)进行选择,更新门对当前一个时间步的序列信息(Xt)进行选择。具体原理如下:

1.1候选隐藏状态H~t:

候选隐藏状态既保留了之前的隐藏状态Ht-1,又保留了当前一个时间步的序列信息Xt。
在这里插入图片描述
因为Rt是一个[0,1] 之间的量,所以Rt×Ht-1是对之前的隐藏状态Ht-1进行一次选择:Rt 在某个位置的值越趋近于0,则表示Ht-1这个位置的序列信息越倾向于被丢弃,反之保留。

综上,重置门的作用是对过去的序列信息Ht-1进行选择,Ht-1中哪些序列信息当前的输出是有用的,应该被保存下来,而哪些序列信息是不重要的,应该被遗忘。

1.2隐藏状态Ht:

在这里插入图片描述
因为Zt是一个[0,1] 之间的量,如果Zt全为0,则当前隐藏状态Ht为当前候选隐藏状态,该候选隐藏状态不仅保留了之前的序列信息,还保留了当前时间步batch的序列信息;如果Zt全为1,则当前隐藏状态Ht为上一个时间步的隐藏状态。

综上,更新门的作用是决定当前一个时间步的序列信息是否保留,如果Zt全为0,则说明当前时间步batch的序列信息是有用的(候选隐藏状态包含之前的序列信息和当前一个时间步的序列信息),保留下来加入到隐藏状态Ht中;如果Zt全为1,则说明当前时间步batch的序列信息是没有用的,丢弃当前batch的序列信息,直接使用上一个时间步的隐藏状态Ht-1作为当前的隐藏状态Ht。(Ht-1仅包含之前的序列信息,不包含当前一个时间步的序列信息)

2.GRU:

GRU网络架构如下,可以看做是三个隐藏层并排的架构。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四、底层源码:

代码中num_hiddens表示隐藏层神经元个数,由于重置门、更新门的输出维度相同,所以重置门和更新门两个隐藏层的神经元个数也是一样的=num_hiddens。

import torch
from torch import nn
from d2l import torch as d2l# 数据预处理,获取datalodaer和字典
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)# 初始化可学习参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()W_xr, W_hr, b_r = three()W_xh, W_hh, b_h = three()W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),)# 定义门控循环单元模型
def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)# 训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

五、Pytorch版代码:

num_inputs = vocab_size
# 调用pytorch构建网络结构
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
http://www.lryc.cn/news/411715.html

相关文章:

  • 程序员修炼之路
  • PHP时间相关函数
  • python进阶——python面向对象
  • 【无标题】vue2鼠标悬停(hover)时切换图片
  • 每天一个数据分析题(四百五十九)- 分析法
  • 英语:十、助动词和情态动词
  • DB2-Db2DefaultValueConverter
  • (自适应手机端)行业协会机构网站模板
  • 视频理解调研笔记 | 2021年前视频动作分类发展脉络
  • 怎么通过 ssh 访问远程设备
  • linux Ubuntu 安装mysql-8.0.39 二进制版本
  • ZooKeeper日志自动清理实用脚本
  • KVM+GFS分布式存储系统构建高可用
  • CIFAR-10 数据集图像分类与可视化
  • 没有了高项!!2024软考下半年软考高级哪个最容易考过?
  • 用户自定义Table API Connector(Sources Sinks)
  • 自闭症儿童能否摘帽?摘帽成功的秘诀揭秘
  • 主题巴巴WordPress主题合辑打包下载+主题巴巴SEO插件
  • git把本地文件上传远程仓库的流程
  • 基于springboot+vue+uniapp的养老院管理系统小程序
  • el-popover实现点击空白区域关闭,弹窗区域不关闭
  • Disjoint Set Union
  • 手写 Hibernate ORM 框架 05-基本效果测试
  • Unity材质球自动遍历所需贴图
  • C++那些事之结构化绑定
  • ECRS工时分析软件:工业工程精益生产的智慧引擎
  • 大语言模型的核心岗位及其要求
  • 【屏驱MCU】RT-Thread 文件系统接口解析
  • 进程管理工具top ps
  • 2年社招冲击字节,一天三面斩获offer