当前位置: 首页 > news >正文

Transformer详解(3)-多头自注意力机制

attention

在这里插入图片描述

在这里插入图片描述

multi-head attention

在这里插入图片描述
在这里插入图片描述

pytorch代码实现

import math
import torch
from torch import nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, heads=8, d_model=128, droput=0.1):super().__init__()self.d_model = d_model  # 128self.d_k = d_model // heads  # 128//8=16self.h = heads  # 8self.q_linear = nn.Linear(d_model, d_model)  # (50,128)*(128,128)=(50,128),其中(128*128)属于权重,在网络训练中学习。self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(droput)self.out = nn.Linear(d_model, d_model)def attention(self, q, k, v, d_k, mask=None, dropout=None):scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # 矩阵乘法 (32,8,50,16)*(32,8,16,50)->(32,8,50,50)if mask is not None:mask = mask.unsqueeze(1)scores = scores.masked_fill(mask == 0, -1e9)scores = F.softmax(scores, dim=-1)if dropout is not None:scores = dropout(scores)output = torch.matmul(scores, v)  # (32,8,50,50)*(32,8,50,16)->(32,8,50,16)return outputdef forward(self, q, k, v, mask=None):bs = q.size(0)  # batch_size 大小  这里的例子是32k = self.k_linear(k).view(bs, -1, self.h, self.d_k)q = self.k_linear(q).view(bs, -1, self.h, self.d_k)v = self.k_linear(v).view(bs, -1, self.h, self.d_k)# (32,50,128)->(32,50,128)->(32,50,8,16)  8*16=128 每个embedding拆成的8份,也就是8个头k = k.transpose(1, 2)  # (32,50,8,16)->(32,8,50,16)q = q.transpose(1, 2)v = v.transpose(1, 2)scores = self.attention(q, k, v, self.d_k, mask, self.dropout)  # (32,8,50,16)concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)  # (32,50,128)output = self.out(concat)  # (32,50,128)return outputif __name__ == '__main__':multi_head_attention = MultiHeadAttention(8, 128)normal_tensor = torch.randn(32, 50, 128)  # 随机生成均值为0,方差为1的正态分布。batch_size=32,序列长度=50,embedding维度=128。x = torch.sigmoid(normal_tensor)  # 把每个数缩放到(0,1)output = multi_head_attention(x, x, x)print('done')
http://www.lryc.cn/news/352614.html

相关文章:

  • 运用HTML、CSS设计Web网页——“西式甜品网”图例及代码
  • 大语言模型是通用人工智能的实现路径吗?【文末有福利】
  • c语言——宏offsetof
  • C#串口通信-串口相关参数介绍
  • 节省时间与精力:用BAT文件和任务计划器自动执行重复任务
  • 一年前的Java作业,模拟游戏玩家战斗
  • C++ 学习 关于引用
  • BERT ner 微调参数的选择
  • 【MySQL精通之路】系统变量-持久化系统变量
  • fdk-aac将aac格式转为pcm数据
  • 【C语言深度解剖】(15):动态内存管理和柔性数组
  • 力扣每日一题 5/25
  • (1)无线电失控保护(一)
  • 基于51单片机的多功能万年历温度计—可显示农历
  • 【软件设计师】下午题总结-数据流图、数据库、统一建模语言
  • CSDN 自动评论互动脚本
  • Tomcat端口配置
  • SpringBoot中使用AOP实现日志记录功能
  • kubernetes(k8s) v1.30.1 helm 集群安装 Dashboard v7.4.0 可视化管理工具 图形化管理工具
  • CS144(所有lab解析)
  • LeetCode 热题 100 介绍
  • Flutter 中的 AnimatedPhysicalModel 小部件:全面指南
  • 第二十届文博会沙井艺立方分会场启幕!大咖齐打卡!
  • 【Vue】computed 和 methods 的区别
  • HarmonyOS 鸿蒙应用开发 - 创建自定义组件
  • 【Vue3】封装axios请求(cli和vite)
  • Java8 Optional常用方法使用场景
  • isscc2024 short course4 In-memory Computing Architectures
  • ubuntu 安装 kvm 启动虚拟机
  • [OpenGL] opengl切线空间