当前位置: 首页 > news >正文

【DeepLearning-1】 注意力机制(Attention Mechanism)

1.1注意力机制的基本原理:

  1. 计算注意力权重

    注意力权重是通过计算输入数据中各个部分之间的相关性来得到的。这些权重表示在给定上下文下,数据的某个部分相对于其他部分的重要性。
  2. 加权求和

    使用这些注意力权重对输入数据进行加权求和,以生成一个紧凑的表示,该表示集中了输入数据的关键信息。

1.2数学原理:

假设我们有一个输入序列 X=[x1​,x2​,...,xn​] ,其中 xi​ 是序列中的元素。在自注意力机制中,我们首先将输入转换为查询(Q)、键(K)和值(V):

变体:

  • 多头注意力(Multi-Head Attention)
    • 在 Transformer 模型中,使用了多头注意力机制,它将 Q、K、V 分割为多个“头”,每个头在不同的表示子空间中学习注意力:

1.3代码实现: 

class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5 #缩放因子,用于调整注意力得分的规模,通常是 dim_head 的平方根的倒数self.attend = nn.Softmax(dim = -1) #Softmax 函数,用于计算注意力权重self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b p h n d -> b p n (h d)')return self.to_out(out)

forward(self, x):

  • 生成查询(Q)、键(K)和值(V):

    • qkv = self.to_qkv(x).chunk(3, dim=-1): 这行代码使用一个线性变换(self.to_qkv)将输入 x 转换为查询(Q)、键(K)和值(V)这三组向量,然后将其分割成三个部分。
  • 重排为多头格式:

    • q, k, v = map(...): 这里使用 rearrange 函数将 Q、K 和 V 的形状转换为多头格式。原始的扁平形状被重排为一个具有多个头部的形状,以便独立进行自注意力运算。
  • 计算注意力得分:

    • dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale: 这里计算查询(Q)和键(K)之间的点积,以得到注意力得分。得分通过 self.scale(一个基于头维度 dim_head 的缩放因子)进行缩放,以防止梯度消失或爆炸。
  • 应用 Softmax 获取注意力权重:

    • attn = self.attend(dots): 使用 Softmax 函数对注意力得分进行归一化,得到每个键对应的注意力权重。
  • 加权和以得到输出:

    • out = torch.matmul(attn, v): 将注意力权重应用于值(V),得到加权和,这是自注意力的输出。
  • 重排并通过输出层:

    • out = rearrange(out, 'b p h n d -> b p n (h d)'): 将输出重排回原始格式,并通过可能存在的输出线性层和 dropout 层。
http://www.lryc.cn/news/286560.html

相关文章:

  • c++:string相关的oj题(415. 字符串相加、125. 验证回文串、541. 反转字符串 II、557. 反转字符串中的单词 III)
  • HuoCMS|免费开源可商用CMS建站系统HuoCMS 2.0下载(thinkphp内核)
  • VsCode + CMake构建项目 C/C++连接Mysql数据库 | 数据库增删改查C++封装 | 信息管理系统通用代码 ---- 课程笔记
  • HackTheBox - Medium - Linux - Ransom
  • 柠檬微趣面试准备
  • uniapp嵌套webview,无法返回上一级?
  • 【优先级队列 之 堆的实现】
  • Vue中$watch()方法和watch属性的区别
  • openssl3.2 - 官方demo学习 - test - certs - 001 - Primary root: root-cert
  • 小程序商城能不能自己开发?
  • GPTBots:利用FlowBot中的卡片和表单信息,提供丰富的客服体验
  • ERC20 解读
  • C#,入门教程(31)——预处理指令的基础知识与使用方法
  • Java SE:面向对象(下)
  • 搭建开源数据库中间件MyCat2-配置mysql数据库双主双从
  • Oracle 19c rac集群管理 -------- 集群启停操作过程
  • 【Java】HttpServlet类中前后端交互三种方式(query string、form表单、JSON字符串)
  • 【深蓝学院】移动机器人运动规划--第2章 基于搜索的路径规划--笔记
  • 安装向量数据库milvus可视化工具attu
  • STM32标准库开发——串口发送/单字节接收
  • jdk17新特性——文本块(即多行的字符串)增强
  • 阿里云ECS使用docker搭建mysql服务
  • Windows给docker设置阿里源
  • 安裝火狐和穀歌流覽器插件FoxyProxy管理海外動態IP代理
  • C++重新入门-函数重载
  • niushop靶场漏洞查找-文件上传漏洞等(超详细)
  • Bit Extraction and Bootstrapping for BGV/BFV
  • 七八分钟快速用k8s部署springboot前后端分离项目
  • 中移(苏州)软件技术有限公司面试问题与解答(2)—— Linux内核内存初始化的完整流程1
  • 蓝桥杯、编程考级、NOC、全国青少年信息素养大赛—scratch列表考点