当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output
http://www.lryc.cn/news/140973.html

相关文章:

  • 抽象工厂模式
  • 登录校验-Filter-详解
  • 堆栈方法区笔记记录
  • 新版微信小程序获取用户手机号
  • CSS实践 —— 悬浮盒子阴影加上移效果
  • 安全测试基础知识
  • 列表首屏毫秒级加载与自动滚动定位方案
  • 小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic
  • Fortran 微分方程求解 --ODEPACK
  • 8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45
  • 【Python】函数
  • centos安装MySQL 解压版完整教程(按步骤傻瓜式安装
  • 【后端速成 Vue】第一个 Vue 程序
  • Macbook pro M1 安装Ubuntu教程
  • 前端console.log打印内容与后端请求返回数据不一致
  • SQL入门:多表查询
  • 【C++】进一步认识模板
  • Mysql Oracle 区别
  • 华为OD-第K长的连续字母字符串长度
  • 【编程题】有效三角形的个数
  • 【mysql是怎样运行的】-EXPLAIN详解
  • 数据结构例题代码及其讲解-链表
  • [Open-source tool] 可搭配PHP和SQL的表單開源工具_Form tools(1):簡介和建置
  • 移动数据业务价值链的整合
  • 合并两个链表
  • 测试框架pytest教程(9)跳过测试skip和xfail
  • HTML <textarea> 标签
  • 探索图结构:从基础到算法应用
  • Redis之GEO类型解读
  • uniapp 微信小程序 路由跳转