当前位置: 首页 > news >正文

transfomer中attention为什么要除以根号d_k

简介

得到矩阵 Q, K, V之后就可以计算出 Self-Attention 的输出了,计算的公式如下:
A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=Softmax(dk QKT)V

好处

除以维度的开方,可以将数据向0方向集中,使得经过softmax后的梯度更大.
从数学上分析,可以使得QK的分布和Q/K保持一致,

推导

对于两个独立的正态分布而言,两者的加法的期望和方差就是两个独立分布的期望和方差。
qk_T的计算过程为[len_q,dim][dim,len_k]=[len_q,len_k],qk的元素等于dim个乘积的和。对于0-1分布表乘积不会影响期望和方差,但是求和操作会使得方差乘以dim,因此对qk元素除以sqrt(dim)把标准差压回1.

这里展示一个不严谨的采样可视化过程
假设在query在(0,1)分布,key在(0,1)分布,随机采样lengthdim个点,然后统计querykey_T的散点的分布

import math
import numpy as np
import matplotlib.pyplot as pltdef plot_curve(mu=0, sigma =1):import numpy as npimport matplotlib.pyplot as pltfrom scipy.stats import norm# 设置正态分布的参数# mu, sigma = 0, 1  # 均值和标准差# 创建一个x值的范围,覆盖正态分布的整个区间x = np.linspace(mu - 4 * sigma, mu + 4 * sigma, 1000)# 计算对应的正态分布的概率密度值y = norm.pdf(x, mu, sigma)# 我们可以选择y值较高的点来绘制散点图,以模拟概率密度的分布# 这里我们可以设置一个阈值,只绘制y值大于某个值的点threshold = 0.01  # 可以根据需要调整这个阈值selected_points = y > thresholdplt.plot(x, y, 'r-', lw=2, label='Normal dist. (mu={}, sigma={})'.format(mu, sigma))plt.title('Normal Distribution Scatter Approximation')plt.xlabel('Value')plt.ylabel('Probability Density')plt.legend()plt.grid(True)plt.show()def plot_poins(x):# 因为这是一个一维的正态分布,我们通常只绘制x轴上的点# 但为了模拟二维散点图,我们可以简单地将y轴设置为与x轴相同或固定值(例如0)y = np.zeros_like(x)# 绘制散点图plt.figure(figsize=(8, 6))plt.scatter(x, y, alpha=0.5)  # alpha控制点的透明度plt.title('Normal (0, 1) Distribution Scatter Plot')plt.xlabel('Value')plt.ylabel('Value (or Frequency if binned)')plt.grid(True)plt.show()if __name__ == '__main__':# 设置随机种子以便结果可复现np.random.seed(0)len = 10000dim = 100query = np.random.normal(0, 1, len*dim).reshape(len,dim)key = np.random.normal(0, 1, len*dim).reshape(dim,len)qk = np.matmul(query,key) / math.sqrt(dim)mean_query = query.mean()std_query = np.std(query,ddof=1)mean_key = key.mean()std_key = np.std(key,ddof=1)mean_qk = qk.mean()std_qk = np.std(qk,ddof=1)plot_poins(query)plot_curve(mean_query,std_query)

在这里插入图片描述

http://www.lryc.cn/news/360494.html

相关文章:

  • iperf3带宽压测工具使用
  • [数据集][目标检测]焊接处缺陷检测数据集VOC+YOLO格式3400张8类别
  • 2024华为OD机试真题-剩余银饰的重量-C++(C卷D卷)
  • 糖果促销【百度之星】/思维
  • 【python学习】安装Anaconda后,如何进行环境管理(命令行操作及图形化操作Anaconda Navigator)及包管理
  • HTML大雪纷飞
  • 问界新M7 Ultra仅售28.98万元起,上市即交付
  • 【Java数据结构】详解LinkedList与链表(四)
  • ssm汉服文化平台网站
  • 如何让 LightRoom 每次导入照片后不自动弹出 SD 卡 LR
  • elasticdump和ESM
  • Java扩展机制:SPI与Spring.factories详解
  • iPhone 语言编程:深入探索与无限可能
  • css动态导航栏鼠标悬停特效
  • Vue中使用axios先获取头像上传参数然后上传图片到服务器-demo
  • Win11环境下Android Studio中Flutter开发环境构建(逐步解决)
  • Thread Servlet思考
  • 电源滤波器怎么选用
  • 终于更新了!时隔一年niushop多商户b2b2c的新补丁v5.0.2终于发布了,一起看看有啥新变化
  • google的chromedriver最新版下载地址
  • Gitee的原理及应用详解(四)
  • IP 协议的相关特性
  • C++11 在 Windows 环境下的多线程编程指南
  • [数据集][目标检测]旋风检测数据集VOC+YOLO格式157张1类别
  • 智慧商砼搅拌车安监运营管理的创新实践
  • 渗透测试框架提权
  • tcp链接中的三次挥手是什么原因
  • 运维相关知识
  • 网络安全基础技术扫盲篇名词解释之“证书“
  • [数据集][目标检测]老鼠检测数据集VOC+YOLO格式4107张1类别