当前位置: 首页 > news >正文

LLaMA中ROPE位置编码实现源码解析

1、Attention中q,经下式,生成新的q。m为句长length,d为embedding_dim/head
θ i = 1 1000 0 2 i d \theta_i=\frac{1}{10000^\frac{2i}{d}} θi=10000d2i1
在这里插入图片描述

2、LLaMA中RoPE源码

import torchdef precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0):'''计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx:param dim: q,k,v的最后一维,一般为emb_dim/head_num:param end: 句长length:param constant: 这里指10000:return:复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t))'''# freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta# 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)]freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2]# 计算mt = torch.arange(end, device=freqs.device)  # [length]# 计算m*thetafreqs = torch.outer(t, freqs).float()  # [length, d]# freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1# 计算cos(m*theta)+j*sin(m*theta)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64# freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0),  cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))]# 其中j为虚数单位, m=0,1,...,length-1return freqs_cis # [length, d]def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):ndim = x.ndimassert 0 <= 1 < ndimassert freqs_cis.shape == (x.shape[1], x.shape[-1])shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2)return freqs_cis.view(*shape) # [1, length, 1, d/2]def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor,):# 先将xq维度变为[bs, length, head,  d/2, 2], 利用torch.view_as_complex转变为复数# xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)]xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2]# 同样的,xk_:[k0+j*k1, k2+j*k3, ..., k(d-2)+j*k(d-1)]xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2]# 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下# (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0))# 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0)# 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部# 最后经flatten函数将维度拉平,即[bs, length, head, d]# 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)]xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d]# 即为新生成的qxk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)if __name__=='__main__':# (bs, length, head, d)q = torch.randn((2, 10, 12, 32))  # q=[q0, q1, .., qd-1]k = torch.randn((2, 10, 12, 32))v = torch.randn((2, 10, 12, 32))freqs_cis= precompute_freqs_cis(dim=32, end=10, constant= 10000.0)# print(freqs_cis.detach().numpy())q_new, k_new = apply_rotary_emb(xq=q, xk=k, freqs_cis=freqs_cis)print()

3、表示
看1中的公式表示,q0和q1相互作用,得到新的q0和q1,也即
q 0 n e w = q 0 ∗ c o s ( m θ 0 ) − q 1 ∗ s i n ( m θ 0 ) q^{new}_0=q_0*cos(m\theta_0)-q_1*sin(m\theta_0) q0new=q0cos(mθ0)q1sin(mθ0)
q 1 n e w = q 1 ∗ c o s ( m θ 0 ) + q 0 ∗ s i n ( m θ 0 ) q^{new}_1=q_1*cos(m\theta_0)+q_0*sin(m\theta_0) q1new=q1cos(mθ0)+q0sin(mθ0)
这里将 ( q 0 , q 1 ) (q_0,q_1) (q0,q1) ( q 0 n e w , q 1 n e w ) (q^{new}_0,q^{new}_1) (q0new,q1new)看做向量,很明显上式是向量旋转,旋转角度为逆时针 m θ 0 m\theta_0 mθ0

在这里插入图片描述
可与PalM中ROPE实现方式做对比
PaLM中ROPE位置编码实现源码解析

http://www.lryc.cn/news/142742.html

相关文章:

  • 在c++ 20下使用微软的proxy库替代传统的virtual动态多态
  • Spring MVC:@RequestMapping
  • 【vue3+ts项目】配置eslint校验代码工具,eslint+prettier+stylelint
  • PHP之ZipArchive打包压缩文件
  • 面试之快速学习C++14
  • 【算法专题突破】双指针 - 快乐数(3)
  • 【javaweb】学习日记Day4 - Maven 依赖管理 Web入门
  • C++信息学奥赛1144:单词翻转
  • qt检查文件夹是否有写权限
  • LSF 安装目录,快速参考 LSF 命令、守护程序、配置文件、日志文件和重要集群配置参数
  • 在Mybatis中写动态sql这些标签:if、where、set、trim、foreach、choose的作用是什么,怎么用?
  • 7 Python的模块和包
  • 【JavaWeb 篇】使用Servlet、JdbcTemplate和Durid连接池实现用户登录功能与测试
  • 【Unity3D赛车游戏】【六】如何在Unity中为汽车添加发动机和手动挡变速?
  • 【Go 基础篇】切片:Go语言中的灵活数据结构
  • 龙芯2K1000LA移植交叉编译环境以及QT
  • javaee spring依赖注入之spel方式
  • 【Java集合学习1】ArrayList集合学习及集合概述分析
  • TouchGFX之调试
  • C# winform加载yolov8模型测试(附例程)
  • 浙大陈越何钦铭数据结构07-图6 旅游规划
  • VUE笔记(七)项目登录
  • 大语言模型之六- LLM之企业私有化部署
  • Python3 列表
  • OpenCV基础知识(8)— 图形检测
  • Java虚拟机
  • c++学习 之 函数重载注意事项
  • 2023-08-27 LeetCode每日一题(合并区间)
  • C#,数值计算——调适数值积分法(adaptive quadrature)的计算方法与源程序
  • 微信小程序发布迭代版本后如何提示用户强制更新新版本