当前位置: 首页 > news >正文

零基础-动手学深度学习-9.1. 门控循环单元(GRU)及代码实现

  前一章中我们介绍了循环神经网络的基础知识, 这种网络可以更好地处理序列数据。 我们在文本数据上实现了基于循环神经网络的语言模型, 但是对于当今各种各样的序列学习问题,这些技术可能并不够用。例如,循环神经网络在实践中一个常见问题是数值不稳定性。 尽管我们已经应用了梯度裁剪等技巧来缓解这个问题, 但是仍需要通过设计更复杂的序列模型来进一步处理它。 具体来说,我们将引入两个广泛使用的网络, 即门控循环单元(gated recurrent units,GRU)和 长短期记忆网络(long short-term memory,LSTM)。 然后,我们将基于一个单向隐藏层来扩展循环神经网络架构。 我们将描述具有多个隐藏层的深层架构, 并讨论基于前向和后向循环计算的双向设计。 现代循环网络经常采用这种扩展。 在解释这些循环神经网络的变体时, 我们将继续考虑 8节中的语言建模问题。事实上,语言建模只揭示了序列学习能力的冰山一角。 在各种序列学习问题中,如自动语音识别、文本到语音转换和机器翻译, 输入和输出都是任意长度的序列。 为了阐述如何拟合这种类型的数据, 我们将以机器翻译为例介绍基于循环神经网络的 “编码器-解码器”架构和束搜索,并用它们来生成序列

9.1. 门控循环单元(GRU)

在 8.7节中, 我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。

  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。

  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示

在学术界已经提出了许多方法来解决这类问题。 其中最早的方法是“长短期记忆”(long-short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997), 我们将在 9.2节中讨论。 门控循环单元(gated recurrent unit,GRU) (Cho et al., 2014) 是一个稍微简化的变体,通常能够提供同等的效果, 并且计算 (Chung et al., 2014)的速度明显更快。 由于门控循环单元更简单,我们从它开始解读。

9.1.1. 门控隐状态

对于重置门和更新门的理解为各司其职。重置门单方面控制自某个节点开始,之前的记忆(隐状态)不在乎了,直接清空影响,同时也需要更新门帮助它实现记忆的更新。更新们更多是用于处理梯度消失问题,可以选择一定程度地保留记忆,防止梯度消失。

9.1.2. 从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。 首先,我们读取 8.5节中使用的时间机器数据集:

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

9.1.2.1. 初始化模型参数

下一步是初始化模型参数。 我们从标准差为的高斯分布中提取权重, 并将偏置项设为,超参数num_hiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01
#前面两个def和rnn是一样的def three():#这个函数用来辅助定义两个w和一个breturn (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))
#是的就是方便在这个下面进行一个参数定义W_xz, W_hz, b_z = three()  # 更新门参数W_xr, W_hr, b_r = three()  # 重置门参数W_xh, W_hh, b_h = three()  # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 对所有parmes附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

9.1.2.2. 定义模型

现在我们将定义隐状态的初始化函数init_gru_state。 与 8.5节中定义的init_rnn_state函数一样, 此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值全部为零。

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )#这里还是返回一个tuple 忘了为啥来着 回去看来看因为lstm需要另一个参数嘻嘻

现在我们准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

@是按矩阵乘法,*是按元素乘法的意思

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:#对每个序列的时间长度Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

9.1.2.3. 训练与预测

训练和预测的工作方式与 8.5节完全相同。 训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)#gru放进去
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)输出:
perplexity 1.1, 19911.5 tokens/sec on cuda:0
time traveller firenis i heidfile sook at i jomer and sugard are
travelleryou can show black is white by argument said filby

9.1.3. 简洁实现

高级API包含了前文介绍的所有配置细节, 所以我们可以直接实例化门控循环单元模型。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)#矩阵乘法做到一起所以速度快
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)输出:
perplexity 1.0, 109423.8 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi

9.1.4. 小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。

  • 重置门有助于捕获序列中的短期依赖关系。

  • 更新门有助于捕获序列中的长期依赖关系。

  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

http://www.lryc.cn/news/613078.html

相关文章:

  • Docker国内可用镜像列表(长期免费)
  • 接入小甲鱼数字人API教程【详解】
  • Next.js 样式:CSS 模块、Sass 等
  • ENSP 中静态路由负载分担
  • vue3接收SSE流数据进行实时渲染日志
  • RabbitMQ-日常运维命令
  • CS231n2017 Assignment3 RNN、LSTM部分
  • 3深度学习Pytorch-神经网络--全连接神经网络、数据准备(构建数据类Dataset、TensorDataset 和数据加载器DataLoader)
  • PID基础知识
  • 关于其他副脑类 GPTs 市场现状及研究报告
  • mysql全屏终端全量、部分备份、恢复脚本
  • Python面试题及详细答案150道(16-30) -- 数据结构篇
  • 分布式微服务--GateWay(过滤器及使用Gateway注意点)
  • 告别YAML,在SpringBoot中用数据库配置替代配置文件
  • word生成问题总结
  • 【遥感图像入门】近三年遥感图像建筑物细粒度分类技术一览
  • Day116 若依融合mqtt
  • 界面组件DevExpress WPF中文教程:网格视图数据布局 - 紧凑模式
  • 音视频时间戳获取与同步原理详解
  • 【Docker】RustDesk远程控制-私有化部署开源版本
  • 生成式AI的“幽灵漏洞”:法律如何为技术的阴影划界
  • PCIe Base Specification解析(八)
  • 从配置到远程访问:如何用群晖NAS FTP+ Cpolar搭建稳定文件传输通道
  • 深入解析Three.js中的BufferAttribute:源码与实现机制
  • Linux下动态库链接的详细过程
  • C++位图(Bitmap)与布隆过滤器(Bloom Filter)详解及海量数据处理应用
  • vue3父组件把一个对象整体传入子组件,还是把一个对象的多个属性分成多个参数传入
  • C#中统计某个字符出现次数的最简单方法
  • Git `cherry-pick` 工具汇总
  • Numpy科学计算与数据分析:Numpy线性代数基础与实践