当前位置: 首页 > news >正文

attention 注意力机制 学习笔记-GPT2

注意力机制

这可能是比较核心的地方了。

gpt2 是一个decoder-only模型,也就是仅仅使用decoder层而没有encoder层。

decoder层中使用了masked-attention 来进行注意力计算。在看代码之前,先了解attention-forward的相关背景知识。

在普通的self-attention 中,对于一个长为T的句子,对其中第t个单词。需要计算t和句子中所有T个单词的注意力。也就是使用词t的Q向量 q t q_t qt 和 T中的所有单词的key向量 k j , 0 < = j < = T k_j, 0<=j<=T kj,0<=j<=T相乘。得到词t和句子中其他单词的注意力得分。

在这里插入图片描述

于是对于词t和当前句子S, 得到了注意力得分向量,而后对该向量使用softmax. 标准化的同时得到softmax后的注意力得分。

然后使用 每个词对应的值向量与注意力得分相乘之后再求和
( v 1 , v 2 , . . . , v T ) ( s c o r e t 1 s c o r e t 2 . . . s c o r e t T ) = o u t t (v_1, v_2, ..., v_T) \begin{pmatrix}score_{t1}\\score_{t2}\\... \\score_{tT}\end{pmatrix} = out_t (v1,v2,...,vT) scoret1scoret2...scoretT =outt
这里要注意, s o c r e t i socre_{ti} socreti 是一个标量值,但是 v t v_t vt 是 一个向量,长度和词嵌入向量长度相同,相加时,对每个向量位置元素对应相加。

在这里插入图片描述

对于masked-attention呢,实际上就是计算注意力得分时候,对第t个单词,仅仅计算0到t单词的注意力得分,t~T 部分的注意力得分不计算,计算softmaxs时t之后的部分以初值0代替。

在这里插入图片描述

在这里插入图片描述

multi-head attention

前面了解了attention基本知识,就很好理解多头注意力了。多头注意力实际上就是将单个Q,K,V向量,分裂为多个头,然后和self-attention一样流程计算每个头的注意力,最后得到一个输出向量,然后将多个头的输出向量拼接到一起,得到最后的输出结果。

在这里插入图片描述

比如,原本的一个向量长度为 l e n g t h Q = = l e n g t h K = = l e n g t h V = = 168 length_Q == length_K == length_V == 168 lengthQ==lengthK==lengthV==168 分裂为12个注意力头之后,每个注意力头的QKV向量长度为 l e n g t h Q i = = l e n g t h K i = = l e n g t h V i = 64 , i ∈ [ 0 , 12 ] length_{Q_i} == length_{K_i} == length_{V_i} = 64, i \in [0,12] lengthQi==lengthKi==lengthVi=64,i[0,12]

然后和分裂的self-attention一样,对每个词t的第i个头的Q向量 Q t i Q_{t_i} Qti,与其他词的第i个头的K向量 K j i , 0 < = j < = t , i ∈ [ 0 , 12 ] K_{j_i}, 0<=j<=t, i\in[0,12] Kji,0<=j<=t,i[0,12] 内积,得到注意力得分。

而后和self-attention一样的,每一个注意力头的Value向量和该头的注意力得分相乘,得到该注意力头的结果。

对于12个头长度为64的attention,最后得到12个64长的注意力结果

再将其拼接,得到长为768的注意attention forward结果,和单个注意力头但是长为768的attention结果相同。

在这里插入图片描述

http://www.lryc.cn/news/484206.html

相关文章:

  • 什么是HTTP,什么是HTTPS?HTTP和HTTPS都有哪些区别?
  • SkyWalking-安装
  • RabbitMQ运维
  • Go语言并发精髓:深入理解和运用go语句
  • 基于STM32的智能家居系统:MQTT、AT指令、TCP\HTTP、IIC技术
  • 分糖果(相等分配)
  • docker构建jdk11
  • 唐帕科技校园语音报警系统:通过关键词识别,阻止校园霸凌事件
  • 酒店行业数据仓库
  • A029-基于Spring Boot的物流管理系统的设计与实现
  • Python Day5 进阶语法(列表表达式/三元/断言/with-as/异常捕获/字符串方法/lambda函数
  • 一文了解Android的核心系统服务
  • Scala的Array(1)
  • [Linux] Linux信号捕捉
  • Elasticsearch的查询语法——DSL 查询
  • 开发语言中,堆区和栈区的区别
  • 驾校增加无人机培训项目可行性技术分析
  • JavaWeb后端开发知识储备1
  • ISUP协议视频平台EasyCVR视频设备轨迹回放平台智慧农业视频远程监控管理方案
  • 大数据新视界 -- 大数据大厂之 Impala 存储格式转换:从原理到实践,开启大数据性能优化星际之旅(下)(20/30)
  • 百度搜索AI探索版多线程批量生成TXT原创文章软件-可生成3种类型文章
  • ubuntu20.04 解决Pytorch默认安装CPU版本的问题
  • 名词解释-2-形状算数实验、潜在空间、3D生成模型
  • Android 使用python统计getevent按键
  • NVIDIA jetson查看资源占用情况,打印/保存资源使用情况日志
  • ssm102“魅力”繁峙宣传网站的设计与实现+vue(论文+源码)_kaic
  • 逐行加载 HTML 内容并实时显示效果:使用 wxPython 的实现
  • UE4 Cook 从UAT传递参数给UE4Editor
  • 【学习日记】notebook添加JAVA支持
  • 以太坊系地址衍生算法分层确定性生成逻辑