当前位置: 首页 > news >正文

GRU(门控循环单元)的原理与代码实现

1.GRU的原理

1.1重置门和更新门

1.2候选隐藏状态

 

1.3隐状态 

2. GRU的代码实现

#导包
import torch
from torch import nn
import dltools#加载数据
batch_size, num_steps = 32, 35
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps)#封装函数:实现初始化模型参数
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))# 更新门参数W_xz, W_hz, b_z = three()# 重置门W_xr, W_hr, b_r = three()# 候选隐藏状态参数W_xh, W_hh, b_h = three()# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params#定义函数:初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device))#定义函数:构建GRU网络结构
def gru(inputs, state, params):[W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, )#训练和预测
vocab_size, num_hiddens, device = len(vocab), 256, dltools.try_gpu()
num_epochs, lr = 500, 5
model = dltools.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

3.pytorch 简洁实现版_GRU调包实现 

num_inputs = vocab_size
#创建网络层
gru_layer = nn.GRU(num_inputs, num_hiddens)
#建模
model = dltools.RNNModel(gru_layer, len(vocab))
#将模型转到device上
model = model.to(device)
#模型训练
dltools.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

 

4.知识点个人理解

 

 

http://www.lryc.cn/news/442730.html

相关文章:

  • 【医疗大数据】医疗保健领域的大数据管理:采用挑战和影响
  • gevent + flask 接口会卡住
  • SpringCloud Alibaba五大组件之——Sentinel
  • brpc之io事件分发器
  • MySQL | 知识 | 从底层看清 InnoDB 数据结构
  • es的封装
  • 写一个自动化记录鼠标/键盘的动作,然后可以重复执行的python程序
  • Spring Boot-热部署问题
  • 深度学习——管理模型的参数
  • 芯片验证板卡设计原理图:372-基于XC7VX690T的万兆光纤、双FMC扩展的综合计算平台 RISCV 芯片验证平台
  • 【软设】 系统开发基础
  • Linux移植之系统烧写
  • 【数据结构与算法】LeetCode:双指针法
  • Istio下载及安装
  • Redis基础数据结构之 Sorted Set 有序集合 源码解读
  • 蓝队技能-应急响应篇Web内存马查杀JVM分析Class提取诊断反编译日志定性
  • 递归快速获取机构树型图
  • [Web安全 网络安全]-XSS跨站脚本攻击
  • 数据库数据恢复—SQL Server附加数据库出现“错误823”怎么恢复数据?
  • Vscode 中新手小白使用 Open With Live Server 的坑
  • 【深度学习 transformer】Transformer与ResNet50在自定义数据集图像分类中的效果比较
  • 【系统架构设计师】专业英语90题(附答案详解)
  • ItemXItemEffect | ItemEffect
  • web 动画库
  • 我的AI工具箱Tauri版-MicrosoftTTS文本转语音
  • 【Webpack--013】SourceMap源码映射设置
  • 创新驱动,技术引领:2025年广州见证汽车电子技术新高度
  • Spring Boot框架在心理教育辅导系统中的应用案例
  • Shiro-550—漏洞分析(CVE-2016-4437)
  • 【例题】lanqiao4425 咖啡馆订单系统