LeetCode - Google LLM 10 Questions, Day 2: Position Embedding (3 Questions)
Welcome to follow my CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/145454489
In the Transformer architecture, position embeddings are the key mechanism that lets the model understand the order of elements in a sequence. Absolute Positional Encoding (Absolute PE) is the most basic form: it injects position information by assigning each position in the sequence a fixed, position-dependent vector, usually generated with sine and cosine functions, so the model can clearly distinguish elements at different positions. Relative Positional Encoding (Relative PE) instead considers the relative distance between elements, so the model can dynamically capture relative position relationships while computing attention. Rotary Positional Encoding (RoPE) is a refinement of relative positional encoding: it injects position information into the Query and Key vectors by rotating them, which lets the model use position information more efficiently on long sequences while keeping the computation simple and scalable.
Background reading: Understanding the Advantages of Rotary Position Embedding (RoPE) over Absolute and Relative Positional Encodings
1. Rotary Position Embedding (RoPE)
The Rotary Position Embedding (RoPE) formula: in the Llama 3 source code, the hyperparameter is $\theta = 500000$, $pos$ is the position in the sequence $s$, and $i$ is the index along the model dimension $dim$ (or $dim/2$):
$$
PE_{(pos,i)} = \cos\left(\frac{pos}{500000^{\frac{i}{d_{m}}}}\right) + i \cdot \sin\left(\frac{pos}{500000^{\frac{i}{d_{m}}}}\right)
$$

$$
x_{i} = \frac{1}{500000^{\frac{i}{d_{m}}}}
$$

$$
PE_{(pos,i)} = \cos(pos \cdot x_{i}) + i \cdot \sin(pos \cdot x_{i}) = e^{i \cdot pos \cdot x_{i}}
$$
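Before the full solution, here is a minimal sketch (with `theta = 10000` and `dim = 8` chosen purely for readability) showing that `torch.polar` produces exactly the $e^{i \cdot pos \cdot x_{i}}$ form above:

```python
import torch

theta, dim, pos = 10000.0, 8, 3
x = 1.0 / theta ** (torch.arange(0, dim, 2) / dim)  # frequencies x_i, matching the code below
angle = pos * x                                     # pos * x_i
cis = torch.polar(torch.ones_like(angle), angle)    # complex64: cos(angle) + i*sin(angle)
print(torch.allclose(cis.real, torch.cos(angle)))   # True
print(torch.allclose(cis.imag, torch.sin(angle)))   # True
```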
```python
import math

import torch
import torch.nn.functional as F
from torch import nn


def precompute_freqs_cis(seq_len, dim, theta=10000.0):
    """Compute freqs_cis, i.e. the frequencies combined with cis (cos + i*sin)."""
    half_dim = dim // 2  # RoPE works on pairs of dims (polar form), so it uses dim/2
    freqs = 1.0 / (theta ** (torch.arange(0, half_dim) / half_dim))
    t = torch.arange(seq_len)  # type: ignore
    freqs = torch.outer(t, freqs)  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def apply_rotary_emb(q, k, freqs_cis):
    # [2, 8, 10, 64] -> [2, 8, 10, 32] (complex)
    xq = torch.view_as_complex(q.reshape(*q.shape[:-1], -1, 2))  # convert to complex form
    xk = torch.view_as_complex(k.reshape(*k.shape[:-1], -1, 2))  # convert to complex form
    # [2, 8, 10, 32, 2] -> [2, 8, 10, 64]
    xq_out = torch.view_as_real(xq * freqs_cis).flatten(3)  # flatten from dim 3
    xk_out = torch.view_as_real(xk * freqs_cis).flatten(3)
    return xq_out, xk_out


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # mask out padded (or future) positions so they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        s = q.size(1)
        # linear projections, then split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # transpose to [bs, h, s, d_k] = [2, 8, 10, 64]
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # precompute the RoPE frequencies
        freqs_cis = precompute_freqs_cis(s, self.d_k)  # output: [10, 32], i.e. [s, d_k // 2]
        # apply RoPE to q and k
        q, k = apply_rotary_emb(q, k, freqs_cis)
        # compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # concatenate the heads and pass through the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


def main():
    # hyperparameters
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # create a MultiHeadAttention instance
    attention = MultiHeadAttention(h, d, dropout_rate)
    # random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # check that the output shapes are as expected
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")


if __name__ == '__main__':
    main()
```
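As a quick follow-up sketch (not part of the original solution), the two helpers above can be used to check RoPE's defining property: the score between a rotated query at position $m$ and a rotated key at position $n$ depends only on the offset $m - n$.

```python
# Sketch: place the same q/k vectors at every position; after rotation, the
# score for a pair (m, n) should depend only on m - n.
dim, seq_len = 64, 16
freqs_cis = precompute_freqs_cis(seq_len, dim)
q = torch.randn(1, 1, 1, dim).repeat(1, 1, seq_len, 1)  # one query, repeated at all positions
k = torch.randn(1, 1, 1, dim).repeat(1, 1, seq_len, 1)  # one key, repeated at all positions
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1))[0, 0]      # [seq_len, seq_len]
print(torch.allclose(scores[2, 5], scores[7, 10], atol=1e-4))    # True, both offsets are -3
```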
2. Absolute Positional Encoding (Absolute PE)
The Transformer's Absolute Positional Encoding formula: in the original Transformer source code, the hyperparameter is $\theta = 10000$, $pos$ is the position in the sequence $s$, and $i$ is the index along the model dimension $dim$:
$$
PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{m}}}}\right)
$$

$$
PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{m}}}}\right)
$$

$$
Attention(Q, K, V) = Softmax\left(\frac{(Q+PE)(K+PE)^{\top}}{\sqrt{d_{m}}}\right)(V+PE)
$$
Note: in multi-head self-attention, the positional encoding has dimension $d_{m}$ and is added directly to the inputs $x$ (i.e., $q$, $k$, $v$); the linear projections are applied afterwards, splitting the result into multiple heads of dimension $d_{k}$.
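A quick worked example of the formula (a sketch; the values `pos = 1` and `d_m = 512` are just illustrative): the first sin/cos pair uses $2i = 0$, so the divisor is $10000^{0} = 1$.

```python
import math

pos, d_m = 1, 512
print(math.sin(pos / 10000 ** (0 / d_m)))  # PE(1, 0) = sin(1) ≈ 0.8415
print(math.cos(pos / 10000 ** (0 / d_m)))  # PE(1, 1) = cos(1) ≈ 0.5403
print(math.sin(pos / 10000 ** (2 / d_m)))  # PE(1, 2) = sin(1 / 10000^(2/512))
```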
```python
import math

import torch
import torch.nn.functional as F
from torch import nn


def get_positional_encoding(seq_len, dim, theta=10000.0):
    """Compute the sin-cos absolute positional encoding."""
    position = torch.arange(0, seq_len)
    # equivalent optimized form:
    # div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(theta) / dim))
    div_term = 1.0 / torch.pow(theta, torch.arange(0, dim, 2) / dim)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(torch.outer(position, div_term))
    pe[:, 1::2] = torch.cos(torch.outer(position, div_term))
    return pe


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # mask out padded (or future) positions so they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        s = q.size(1)
        # sin-cos absolute positional encoding: [1, 10, 512], broadcast to [2, 10, 512]
        pe = get_positional_encoding(s, self.d_model)
        pe = pe.unsqueeze(0)  # add a batch dimension to match the inputs
        # broadcasting adds the absolute positional encoding to q, k, v
        # (out-of-place, so the caller's tensors are not modified)
        q = q + pe
        k = k + pe
        v = v + pe
        # linear projections, then split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # transpose to [bs, h, s, d_k] = [2, 8, 10, 64]
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # concatenate the heads and pass through the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


def main():
    # hyperparameters
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # create a MultiHeadAttention instance
    attention = MultiHeadAttention(h, d, dropout_rate)
    # random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # check that the output shapes are as expected
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")


if __name__ == '__main__':
    main()
```
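As a follow-up sketch (not part of the original solution), `get_positional_encoding` can be checked against the closed form, and the commented-out `exp`/`log` formulation of `div_term` can be confirmed to match the `pow` form used above:

```python
import math
import torch

seq_len, dim, theta = 10, 512, 10000.0
pe = get_positional_encoding(seq_len, dim)
# PE(1, 0) = sin(1), PE(1, 1) = cos(1)
print(torch.isclose(pe[1, 0], torch.tensor(math.sin(1.0))))  # True
print(torch.isclose(pe[1, 1], torch.tensor(math.cos(1.0))))  # True
# the two div_term formulations agree numerically
a = 1.0 / torch.pow(theta, torch.arange(0, dim, 2) / dim)
b = torch.exp(torch.arange(0, dim, 2) * -(math.log(theta) / dim))
print(torch.allclose(a, b))  # True
```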
3. Relative Positional Encoding (Relative PE)
Relative Positional Encoding (RPE or RePE) is used in Transformer-XL and T5. There are many concrete implementations, but the core idea is the relative-position index matrix relative_indices. For a sequence length of 10 it is a $[10 \times 10]$ matrix whose entries range over $[0, 18]$, i.e., 19 distinct values:
```
[ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0],
[10,  9,  8,  7,  6,  5,  4,  3,  2,  1],
[11, 10,  9,  8,  7,  6,  5,  4,  3,  2],
[12, 11, 10,  9,  8,  7,  6,  5,  4,  3],
[13, 12, 11, 10,  9,  8,  7,  6,  5,  4],
[14, 13, 12, 11, 10,  9,  8,  7,  6,  5],
[15, 14, 13, 12, 11, 10,  9,  8,  7,  6],
[16, 15, 14, 13, 12, 11, 10,  9,  8,  7],
[17, 16, 15, 14, 13, 12, 11, 10,  9,  8],
[18, 17, 16, 15, 14, 13, 12, 11, 10,  9]
```
The index matrix is built with a single line: `relative_indices = torch.arange(s).unsqueeze(1) - torch.arange(s).unsqueeze(0) + s - 1`
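For example, with a toy sequence length of `s = 4` (a quick sketch, not from the original), the same one-liner gives:

```python
import torch

s = 4
relative_indices = torch.arange(s).unsqueeze(1) - torch.arange(s).unsqueeze(0) + s - 1
print(relative_indices)
# tensor([[3, 2, 1, 0],
#         [4, 3, 2, 1],
#         [5, 4, 3, 2],
#         [6, 5, 4, 3]])
```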
The solution below follows the approach of Tensor2Tensor's common_attention.py; note that this is only one of several variants:
$$
Attention(Q, K, V) = Softmax\left(\frac{QK^{\top}}{\sqrt{d_{m}}} + RePE_{ij}\right)V
$$
Implementation:
```python
import math

import torch
import torch.nn.functional as F
from torch import nn


def relative_positional_encoding(seq_len, dim, theta=10000.0):
    # relative position offsets, encoded with the Absolute PE sin-cos formula
    relative_positions = torch.arange(1 - seq_len, seq_len).unsqueeze(1)  # [-9, 9], 19 values in total
    # equivalent optimized form:
    # div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(theta) / dim))
    div_term = 1.0 / torch.pow(theta, torch.arange(0, dim, 2) / dim)
    pe = torch.zeros(2 * seq_len - 1, dim)
    pe[:, 0::2] = torch.sin(relative_positions * div_term)
    pe[:, 1::2] = torch.cos(relative_positions * div_term)
    return pe


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(q, k, v, d_k, mask=None, dropout=None):
        bs, h, s, _ = q.shape
        # content-based attention scores between queries and keys
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        # ---------- relative positional encoding (RePE) ---------- #
        re_pe = relative_positional_encoding(s, d_k)  # [19, 64]
        # relative_indices: a [s, s] matrix of values in 0..18
        re_indices = torch.arange(s).unsqueeze(1) - torch.arange(s).unsqueeze(0) + s - 1  # [10, 10]
        re = re_pe[re_indices]  # [10, 10, 64]
        # einsum breakdown: q [2, 8, 10, 64] -> [10, 2, 8, 64] -> [10, 16, 64];
        # re [10, 10, 64]; batched product [10, 16, 10] -> [10, 2, 8, 10] -> [2, 8, 10, 10]
        re_scores = torch.einsum('bhrd,rld->bhrl', q, re)  # [2, 8, 10, 10]
        scores = scores + re_scores
        # ---------- relative positional encoding (RePE) ---------- #
        # mask out padded (or future) positions so they become 0 after softmax
        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        s = q.size(1)
        # linear projections, then split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # transpose to [bs, h, s, d_k] = [2, 8, 10, 64]
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # compute attention
        attn = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # concatenate the heads and pass through the final linear layer
        concat = attn.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output


def main():
    # hyperparameters
    bs, s, h, d = 2, 10, 8, 512
    dropout_rate = 0.1
    # create a MultiHeadAttention instance
    attention = MultiHeadAttention(h, d, dropout_rate)
    # random input tensors
    q = torch.randn(bs, s, d)
    k = torch.randn(bs, s, d)
    v = torch.randn(bs, s, d)
    # optional: causal mask, a lower-triangular matrix
    mask = torch.tril(torch.ones(bs, s, s))
    # test without a mask
    output_no_mask = attention(q, k, v)
    print("Output shape without mask:", output_no_mask.shape)
    # test with a mask
    output_with_mask = attention(q, k, v, mask)
    print("Output shape with mask:", output_with_mask.shape)
    # check that the output shapes are as expected
    assert output_no_mask.shape == (bs, s, d), "Output shape is incorrect without mask"
    assert output_with_mask.shape == (bs, s, d), "Output shape is incorrect with mask"
    print("Test passed!")


if __name__ == '__main__':
    main()
```
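As a closing sketch (an illustration only, not part of the original solution), the einsum `'bhrd,rld->bhrl'` used in `attention()` can be checked against an explicit per-position loop: for each query position `r`, it takes the dot products between `q[..., r, :]` and the row of relative embeddings `re[r]`.

```python
import torch

bs, h, s, d_k = 2, 8, 10, 64
q = torch.randn(bs, h, s, d_k)
re = torch.randn(s, s, d_k)  # re[r, l]: embedding of the offset between positions r and l
fast = torch.einsum('bhrd,rld->bhrl', q, re)
slow = torch.stack([q[:, :, r, :] @ re[r].transpose(0, 1) for r in range(s)], dim=2)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```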