当前位置: 首页 > news >正文

4.5.门控循环单元GRU

门控循环单元GRU

​ 对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况:

  1. 早期观测值对预测所有未来观测值都具有非常重要的意义。

    考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别校验和事否正确,我们希望有某些机制在一个记忆元里存储重要的早期信息。如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度。

  2. 一些词元没有相关的观测值

    在对网页内容进行情感分析时,可能一些辅助的HTML代码与网页传达的情绪无关,我们希望有一些机制来跳过隐状态中的此类词元

  3. 序列的各个部分存在逻辑中断

    书的章节之间可能也会有过渡,证券的熊市,牛市之间可能会有过渡。这种情况下, 最好有一种方法来重置我们的内部状态表示

​ 有很多方法来解决这类问题,最早的方法是"长短期记忆"(long-short-term memory,LSTM)。门控循环单元(gated recurrent unit,GRU)是一个稍微简化的变体,通常能提供同等的效果,并且计算速度更快。

1.门控隐状态

​ 门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。这些机制是可学习的。

1.1 重置门和更新门

在这里插入图片描述

​ 重置门和更新门的输入如图所示。重置门允许我们控制”可能还想记住“的过去状态的数量;更新门将允许我们控制新状态中有多少个是旧状态的副本。

​ 其中输入是由当前时间步的输入和前一时间步的隐状态给出,两个门的输出由使用sigmoid激活函数的两个全连接层给出。

​ 假设输入是一个小批量 X t ∈ R n × d X_t\in \R^{n\times d} XtRn×d(样本数量 n n n,输入个数 d d d),上一个时间步的隐状态是 H t − 1 ∈ R n × h H_{t-1}\in \R^{n\times h} Ht1Rn×h(隐藏单元个数 h h h)。那么重置门 R t R_t Rt和更新门 Z t Z_t Zt(均为 R n × h \R^{n\times h} Rn×h)的计算如下所示:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t = \sigma(X_t W_{xz}+H_{t-1}W_{hz}+b_z) Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)
​ 其中 W x r , W x z ∈ R d × h W_{xr},W_{xz}\in \R^{d\times h} Wxr,WxzRd×h W h r , W h z ∈ R h × h W_{hr},W_{hz}\in \R^{h\times h} Whr,WhzRh×h是权重参数, b r , b z ∈ R 1 × h b_r,b_z\in \R^{1\times h} br,bzR1×h是偏置参数。求和过程中会触发广播机制。 我们使用sigmoid函数将输入值转换到区间¥(0,1)$。

1.2 候选隐状态

在这里插入图片描述

​ 将重置门 R t R_t Rt与常规隐状态更新机制集成,得到在时间步 t t t的候选隐状态 H ^ t ∈ R n × h \hat{H}_t\in\R ^{n\times h} H^tRn×h
H ^ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) \hat{H}_t = tanh(X_tW_{xh}+(R_t\odot H_{t-1})W_{hh}+b_h) H^t=tanh(XtWxh+(RtHt1)Whh+bh)
​ 其中 W x h ∈ R d × h W_{xh}\in\R^{d\times h} WxhRd×h W h h ∈ R h × h W_{hh}\in \R ^{h\times h} WhhRh×h是权重参数, b h ∈ R 1 × h b_h\in \R^{1\times h} bhR1×h是偏置项,符号 ⊙ \odot 是Hadamard积(按元素乘积)运算符,此处使用tanh非线性激活函数确保候选隐状态中的值保持在区间 ( − 1 , 1 ) (-1,1) (1,1)中。。

R t ⊙ H t − 1 R_t\odot H_{t-1} RtHt1的元素相乘可以减少以往状态的影响,每当重置门 R t R_t Rt中的项接近1时,我们恢复一个普通的循环神经网络,如果 R t R_t Rt全为0,则之前的信息全部遗忘。重置门是可以学习的,通过学习,可以根据目前的输入决定哪些东西需要遗忘。

1.3 隐状态

在这里插入图片描述

​ 1.2中得出的是候选隐状态,真正的隐状态需要结合更新门的效果。这一步确定新的隐状态 H t ∈ R n × h H_t\in \R^{n\times h} HtRn×h在多大程度上来自旧的状态 H t − 1 H_{t-1} Ht1和新的候选状态 H t ^ \hat{H_t} Ht^。更新门 Z t Z_t Zt仅需要在 H t − 1 H_{t-1} Ht1 H ^ t \hat{H}_t H^t之间进行按元素的凸组合就可以实现,于是得出了最终的更新公式:
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ^ t H_t =Z_t \odot H_{t-1}+(1-Z_t)\odot \hat{H}_t Ht=ZtHt1+(1Zt)H^t
​ 容易看出,更新门 Z t Z_t Zt越趋近1,模型就倾向只保留旧状态,此时来自输入 X t X_t Xt的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t。相反,当 Z t Z_t Zt接近0时,新的隐状态 H t H_t Ht就会接近候选隐状态 H t ^ \hat {H_t} Ht^

2.代码实现

2.1 从零开始

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

2.2 训练与预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

2.3 简洁实现

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
http://www.lryc.cn/news/416181.html

相关文章:

  • 10种 Python数据结构,从入门到精通
  • 【AI】人工智能时代,程序员如何保持核心竞争力?
  • WPF学习(3)- WrapPanel控件(瀑布流布局)+DockPanel控件(停靠布局)
  • 【python】Python中实现定时任务常见的几种方式原理分析与应用实战
  • 老公请喝茶,2024年老婆必送老公的养生茶,暖暖的很贴心
  • 3d打印相关资料
  • MySQL1 DDL语言
  • el-tree懒加载状态下实现搜索筛选(纯前端)
  • NLP——Transfromer 架构详解
  • 大模型算法面试题(二十)
  • 2024最新最全面的Selenium 3.0 + Python自动化测试框架
  • 海运中的甩柜是怎么回事❓怎么才能避免❓
  • Win11+docker+gpu+vscode+pytorch配置anomalib(2)
  • AI在招聘市场趋势分析中的应用
  • AMEYA360:太阳诱电应对 165℃的叠层金属类功率电感器实现商品化!
  • Nginx进阶-常见配置(三)
  • 开源协作式书签管理器推荐
  • 【线性代数】【二】2.2极大线性无关组与向量空间的基
  • STM32常见的下载方式有三种
  • RK3568-npu模型转换推理
  • 《C语言程序设计 第4版》笔记和代码 第十二章 数据体和数据结构基础
  • 学习记录——day26 进程间的通信 无名管道 无名管道 信号通信 特殊的信号处理
  • WHAT - xmlhttprequest vs fetch vs wretch
  • 吴恩达老师机器学习作业-ex7(聚类)
  • lombok 驼峰命名缺陷,导致后台获取参数为null的解决办法
  • 【dockerpython】亲测有效!适合新手!docker创建conda镜像+容器使用(挂载、端口映射、gpu使用)+云镜像仓库教程
  • 矩阵,求矩阵秩、逆矩阵
  • 指针和const
  • 基于C#调用文心一言大模型制作桌面软件(可改装接口)
  • VScode插件安装