当前位置: 首页 > news >正文

循环神经网络中的梯度消失或梯度爆炸问题产生原因分析(二)

上一篇中讨论了一般性的原则,这里我们具体讨论通过时间反向传播(backpropagation through time,BPTT)的细节。我们将展示目标函数对于所有模型参数的梯度计算方法。

出于简单的目的,我们以一个没有偏置参数的循环神经网络为例说明,其在隐藏层中的激活函数使用恒等函数(\phi \left ( x \right )=x)。

对于时间步t,单个样本的输入及其标签分别为\mathbf{x}_{t}\in \mathbb{R}^{d}y_{t}。计算隐状态\mathbf{h}_{t}\in \mathbb{R}^{h}和输出\mathbf{o}_{t}\in \mathbb{R}^{q}的公式为

\mathbf{h}_{t}=\mathbf{W}_{hx}\mathbf{x}_{t}+\mathbf{W}_{hh}\textbf{h}_{t-1}

\mathbf{o}_{t}=\mathbf{W}_{qh}\mathbf{h}_{t}

其中,权重参数为\mathbf{W}_{hx}\in \mathbb{R}^{h\times d}\mathbf{W}_{hh}\in \mathbb{R}^{h\times h}\mathbf{W}_{qh}\in \mathbb{R}^{q\times h}

目标函数为:

L=\frac{1}{T}\sum_{t=1}^{T}l\left ( y_{t} ,\mathbf{o}_{t}\right )

通常,训练这个模型需要对这些参数分别进行梯度计算:\partial L/\partial \textbf{W}_{hx}\partial L/\partial \textbf{W}_{hh}\partial L/\partial \textbf{W}_{qh}

\frac{\partial L}{\partial \textbf{o}_{t}}=\frac{\partial l\left ( \textbf{o}_{t},y_{t} \right )}{T\cdot \partial o_{t}}\in \mathbb{R}^{q}

\frac{\partial L}{\partial \mathbf{W}_{qh}}=\sum_{t=1}^{T}\frac{\partial L}{\partial \textbf{o}_{t}}\textbf{h}_{t}^{\top }

\frac{\partial L}{\partial \mathbf{W}_{hx}}=\sum_{t=1}^{T}\frac{\partial L}{\partial \textbf{h}_{t}}\textbf{x}_{t}^{\top }

\frac{\partial L}{\partial \mathbf{W}_{hh}}=\sum_{t=1}^{T}\frac{\partial L}{\partial \textbf{h}_{t}}\textbf{h}_{t-1}^{\top }

其中:\frac{\partial L}{\partial \mathbf{h}_{t}}=\sum_{i=t}^{T}\left (\textbf{W} _{hh}^{\top } \right )^{T-i}\textbf{W}_{qh}^{\top }\frac{\partial L}{\partial \textbf{o}_{T+t-i}}

\frac{\partial L}{\partial \mathbf{h}_{t}}中可以看到,这个简单的线性例子已经展现出长序列模型的一些关键问题:

它陷入到了\textbf{W} _{hh}^{\top }的潜在的非常大的指数幂。在这个指数幂中,小于1的特征值将会消失(出现梯度消失),大于1的特征值将会发散(出现梯度爆炸)。

http://www.lryc.cn/news/266676.html

相关文章:

  • JWT signature does not match locally computed signature
  • vitepress项目使用github的action自动部署到github-pages中,理论上可以通用所有
  • Python爬虫---解析---JSONPath
  • 路由器介绍和命令操作
  • Hadoop——分布式计算
  • LaTeX引用参考文献 | Texstudio引用参考文献
  • 如何在Go中使用模板
  • 云原生之深入解析基于FunctionGraph在Serverless领域的FinOps的探索和实践
  • 电子电器架构(E/E)演化 —— 主流主机厂域集中架构概述
  • Python常用的几个函数
  • 【Linux系统基础】(2)在Linux上部署MySQL、RabbitMQ、ElasticSearch等各类软件
  • HarmonyOS4.0系统性深入开发01应用模型的构成要素
  • 线下终端门店调研包含哪些内容
  • 倾斜摄影三维模型数据在行业应用分析
  • Apache Flink 进阶教程(七):网络流控及反压剖析
  • k8s学习 — (DevOps实践)第十三章 DevOps 环境搭建
  • Java_Stream流
  • delphi中,tstringlist使用方法示例
  • 【飞凌 OK113i-C 全志T113-i开发板】视频编解码测试
  • 全部没有问题 (一.5)
  • C++归并排序详解以及代码实现
  • springboot整合JPA 多表关联 :一对多 多对多
  • Python 数据分析 Matplotlib篇 plt.rcParams 字典(第5讲)
  • DeamonSet详解
  • TwIST算法MALTLAB主程序详解
  • Flutter 三: Dart
  • redis基本用法学习(C#调用FreeRedis操作redis)
  • Postman接口测试(超详细整理)
  • 【深入解析spring cloud gateway】12 gateway参数调优与分析
  • Java继承,父类没有无参构造方法时,子类必须要显式调用父类的构造方法