当前位置: 首页 > news >正文

【学习笔记】2.1注意力机制

参考资料:https://github.com/datawhalechina/happy-llm 

2.1.1 什么是注意力机制

注意力机制最初源于计算机视觉领域,其核心思想是通过集中关注重点部分来高效处理信息。在自然语言处理中,注意力机制通过聚焦于关键的 token(如单词或短语),可以实现更高效和高质量的计算。其三个核心变量为:Query(查询值)、Key(键值)和 Value(真值)。例如,在查找新闻报道中的时间时,Query 可以是“时间”或“日期”等向量,Key 和 Value 是整个文本。通过计算 Query 和 Key 的相关性得到权重,再将权重与 Value 结合,最终得到对文本的注意力加权结果。注意力机制通过这种方式拟合序列中每个词与其他词的相关关系。

2.1.2 深入理解注意力机制 

注意力机制的核心变量是 Query(查询值)Key(键值)Value(真值)。通过类比字典查询的过程,可以理解注意力机制的计算逻辑:

  1. 字典查询类比

    • 字典的 键(Key)值(Value) 对应于注意力机制中的 Key 和 Value。

    • 查询(Query)通过与 Key 的匹配来获取对应的 Value。

    • 当 Query 匹配多个 Key 时,可以通过为每个 Key 分配权重(注意力分数)来组合多个 Value。

  2. 注意力分数的计算

    • 使用 点积 计算 Query 和 Key 的相似度: x = qK^T

    • 通过 softmax 函数 将点积结果归一化为权重:\text{softmax}(x)i = \frac{e^{xi}}{\sum{j}e^{x_j}}​​。

    • 权重反映了 Query 和每个 Key 的相似程度,且权重之和为 1。

  3. 注意力机制的公式

    • 基本公式:attention(Q,K,V) = softmax(QK^T)V

    • 为了处理高维数据并保持梯度稳定,引入放缩因子:attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

最终,注意力机制通过计算 Query 和 Key 的相似度,为每个 Key 分配权重,并结合 Value 得到加权结果。

2.1.3 注意力机制的实现

'''注意力计算函数'''
def attention(query, key, value, dropout=None):'''args:query: 查询值矩阵key: 键值矩阵value: 真值矩阵'''# 获取键向量的维度,键向量的维度和值向量的维度相同d_k = query.size(-1) # 计算Q与K的内积并除以根号dk# transpose——相当于转置scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)# Softmaxp_attn = scores.softmax(dim=-1)if dropout is not None:p_attn = dropout(p_attn)# 采样# 根据计算结果对value进行加权求和return torch.matmul(p_attn, value), p_attn

2.1.4 自注意力

  • 定义

    • 自注意力是注意力机制的变种,用于计算同一序列中每个元素对其他所有元素的注意力分布。

  • 计算过程

    • Q、K、V 都由同一个输入通过不同的参数矩阵 Wq​、Wk​、Wv​ 计算得到。

    • 通过自注意力机制,可以建模文本中每个 token 与其他所有 token 的依赖关系。

  • 应用场景

    • 在 Transformer 的 Encoder 中,输入通过参数矩阵 Wq​、Wk​、Wv​ 分别得到 Q、K、V,从而拟合输入语句中每个 token 对其他所有 token 的关系。

  • 代码实现

    • 在代码中,自注意力机制通过将 Q、K、V 的输入设置为同一个参数来实现。

# attention 为上文定义的注意力计算函数
attention(x, x, x)

2.1.5 掩码自注意力

掩码自注意力(Mask Self-Attention) 是一种在自注意力机制中引入掩码的技术,用于遮蔽特定位置的 token,使模型在学习过程中只能使用历史信息进行预测,而不能看到未来信息。这种方法的核心动机是实现并行计算,提高 Transformer 模型的效率。

  1. 生成掩码矩阵

    • 使用上三角矩阵作为掩码,其中上三角部分的值为 −∞,其余部分为 0。

    • 掩码矩阵的维度通常为 (1, \text{seq_len}, \text{seq_len}),通过广播机制应用于整个输入序列。

  2. 掩码的应用

    • 在计算注意力分数时,将掩码矩阵与注意力分数相加。

    • 通过 Softmax 操作,将上三角部分的 −∞ 转换为 0,从而忽略这些位置的注意力分数。

示例:

假设待学习的文本序列为 【BOS】I like you【EOS】,掩码自注意力的输入如下:

<BOS> 【MASK】【MASK】【MASK】【MASK】
<BOS>    I   【MASK】 【MASK】【MASK】
<BOS>    I     like  【MASK】【MASK】
<BOS>    I     like    you  【MASK】
<BOS>    I     like    you   </EOS>

  • 每个输入样本只看到前面的 token,预测下一个 token。

  • 通过并行处理,模型可以同时处理整个序列,而不是逐个步骤串行处理。

代码实现:
# 创建一个上三角矩阵,用于遮蔽未来信息
mask = torch.full((1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)# 在注意力计算时,将掩码与注意力分数相加
scores = scores + mask[:, :seqlen, :seqlen]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
  • 掩码矩阵:上三角部分为 −∞,其余部分为 0。

  • Softmax 操作:将 −∞ 转换为 0,忽略上三角区域的注意力分数。

2.1.6 多头注意力

多头注意力机制(Multi-Head Attention) 是 Transformer 模型的核心组件,用于更全面地拟合语句序列中的相关关系。它通过同时进行多次注意力计算,每次拟合不同的关系,然后将结果拼接并线性变换,从而更深入地建模语言信息。

核心动机:
  1. 单一注意力的局限性:一次注意力计算只能拟合一种相关关系,难以全面捕捉语句中的复杂依赖。

  2. 多头注意力的优势:通过多个注意力头同时计算,每个头可以捕捉不同的信息,从而更全面地拟合语句关系。

多头注意力机制的工作原理:
  1. 公式表示

    \mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O \\

    其中:

    \mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
    • Q,K,V:输入的查询、键和值矩阵。

    • W^Q_i, W^K_i, W^V_i​:每个头的参数矩阵。

    • W^O:输出权重矩阵,用于将拼接后的结果投影回原始维度。

  2. 多头注意力的实现

    • 将输入序列通过不同的参数矩阵 W^Q, W^K, W^V分别计算得到 Q,K,V。

    • 将 Q,K,V 分成多个头。

    • 对每个头分别进行注意力计算,然后将结果拼接。

    • 最后通过一个线性层 WO 将拼接后的结果投影回原始维度。

代码实现:

 

import torch.nn as nn
import torch'''多头自注意力计算模块'''
class MultiHeadAttention(nn.Module):def __init__(self, args: ModelArgs, is_causal=False):# 构造函数# args: 配置对象super().__init__()# 隐藏层维度必须是头数的整数倍,因为后面我们会将输入拆成头数个矩阵assert args.n_embd % args.n_heads == 0# 模型并行处理大小,默认为1。model_parallel_size = 1# 本地计算头数,等于总头数除以模型并行处理大小。self.n_local_heads = args.n_heads // model_parallel_size# 每个头的维度,等于模型维度除以头的总数。self.head_dim = args.dim // args.n_heads# Wq, Wk, Wv 参数矩阵,每个参数矩阵为 n_embd x n_embd# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,# 不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)# 输出权重矩阵,维度为 n_embd x n_embd(head_dim = n_embeds / n_heads)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)# 注意力的 dropoutself.attn_dropout = nn.Dropout(args.dropout)# 残差连接的 dropoutself.resid_dropout = nn.Dropout(args.dropout)# 创建一个上三角矩阵,用于遮蔽未来信息# 注意,因为是多头注意力,Mask 矩阵比之前我们定义的多一个维度if is_causal:mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))mask = torch.triu(mask, diagonal=1)# 注册为模型的缓冲区self.register_buffer("mask", mask)def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):# 获取批次大小和序列长度,[batch_size, seq_len, dim]bsz, seqlen, _ = q.shape# 计算查询(Q)、键(K)、值(V),输入通过参数矩阵层,维度为 (B, T, n_embed) x (n_embed, n_embed) -> (B, T, n_embed)xq, xk, xv = self.wq(q), self.wk(k), self.wv(v)# 将 Q、K、V 拆分成多头,维度为 (B, T, n_head, C // n_head),然后交换维度,变成 (B, n_head, T, C // n_head)# 因为在注意力计算中我们是取了后两个维度参与计算# 为什么要先按B*T*n_head*C//n_head展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,# 然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)xq = xq.transpose(1, 2)xk = xk.transpose(1, 2)xv = xv.transpose(1, 2)# 注意力计算# 计算 QK^T / sqrt(d_k),维度为 (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)# 掩码自注意力必须有注意力掩码if self.is_causal:assert hasattr(self, 'mask')# 这里截取到序列长度,因为有些序列可能比 max_seq_len 短scores = scores + self.mask[:, :, :seqlen, :seqlen]# 计算 softmax,维度为 (B, nh, T, T)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# 做 Dropoutscores = self.attn_dropout(scores)# V * Score,维度为(B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)output = torch.matmul(scores, xv)# 恢复时间维度并合并头。# 将多头的结果拼接起来, 先交换维度为 (B, T, n_head, C // n_head),再拼接成 (B, T, n_head * C // n_head)# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,# 因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)# 最终投影回残差流。output = self.wo(output)output = self.resid_dropout(output)return output

http://www.lryc.cn/news/571664.html

相关文章:

  • C#开发MES管理系统源码工业生产线数据采集WPF上位机产线执行系统源码
  • crackme010
  • 01初始uni-app+tabBar+首页
  • 关于球面投影SphericalProjector的介绍以及代码开发
  • 分治算法之归并排序
  • webpack+vite前端构建工具 - 3webpack处理js
  • 深入ZGC并发处理的原理
  • 固态硬盘的加装和初始化
  • 电路图识图基础知识-摇臂钻床识图(三十一)
  • 27.自连接
  • 你的下一把量化“瑞士军刀”?KHQuant适用场景全解析【AI量化第32篇】
  • 数据集笔记:宣城轨迹
  • 权重遍历及Delong‘s test | 已完成单调性检验?
  • 键盘 AK35I Pro V2 分析
  • ABP vNext + Azure Application Insights:APM 监控与性能诊断最佳实践
  • 零基础设计模式——总结与进阶 - 1. 设计模式的综合应用
  • 利用cpolar实现Talebook数字图书馆的实时访问
  • ZYNQ学习记录FPGA(五)高频信号中的亚稳态问题
  • VMware vSphere Foundation 9.0 技术手册 —— Ⅰ 安装 ESXi 9.0 (虚拟机)
  • 数据库char字段做trim之后查询很慢的解决方式
  • 需要做一款小程序,用来发券,后端如何进行设计能够保证足够安全?
  • 微信原生小程序转uniapp过程及错误总结
  • 环卫车辆定位与监管:安心联车辆监控管理平台--科技赋能城市环境卫生管理
  • 【力扣 中等 C】2. 两数相加
  • chili3d笔记18 出三视图调整
  • 数据结构——选择题—查漏补缺
  • Could not locate zlibwapi.dll. Please make sure it is in your library path!
  • 功耗高?加密弱?爱普特APT32F1023H8S6单片机 2μA待机+AES硬件加密破局
  • Vue3 + TypeScript 本地存储 localStorage 的用法
  • 【时时三省】(C语言基础)内部函数和外部函数