【多模态大模型】FlashAttention in NeurIPS 2022
一、引言
论文: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
作者: Stanford University
代码: FlashAttention
特点: 该方法提出将Q、K、V拆分为若干小块,使执行注意力时不需要频繁进行读写操作,而是每个小块只进行一次读写,从而提升注意力的执行速度。
⚠️ 在学习该方法前,建议补充Attention的相关知识。
二、详情
GPU中SRAM和HBM的计算和存储能力如下图:

可见,SRAM计算能力强(17TB/s),HBM的存储容量大(40GB)。因此,GPU的运算通常在SRAM上进行,如果运算结果的内存占用太大,系统会把运算结果先写入HBM,然后从HBM读出来再在SRAM上进行下一步的运算。
于是,我们就得到原始Attention的执行过程:

其中,Q、K、V分别是Query、Key、Value矩阵,S是相似度矩阵,P是权重矩阵,O是输出矩阵。
这里没写除以 d k \sqrt{d_k} dk的操作,不过无伤大雅,因为它对运算的影响并不大。
可见,计算S、P、O时都要进行读取,计算完成后也都要进行写入。然而,运算速度领先于读写速度导致SRAM运算完了要等数据过来才能进行下一步运算,这就拖慢了整体的速度。
2.1 拆分
FlashAttention提出将Q、K、V拆分成若干小块,这样每个小块的S、P矩阵不至于太大到需要写入HBM中,这样就能只在最开始读取Q、K、V、O(之前的运算结果),在SRAM中完成所有运算后,再将新的O写入HBM。
如果没有SoftMax操作,该过程很容易实现,如下图:


分别循环Q和K、V的小块,循环结果求和就是我们所有期望的O。但是,SoftMax阻碍了它的实现,回顾原始SoftMax公式:
s o f t m a x ( s ) j = e s j ∑ k = 1 N e s k softmax(\boldsymbol{s})_j=\frac{e^{s_j}}{\sum_{k=1}^{N}e^{s_k}} softmax(s)j=∑k=1Neskesj
可见,它要把相似度矩阵S的每一行转为一个概率分布。但是分块策略无法一次性获得完整的S中的行,于是FlashAttention在SoftMax中引入了 m ( s ) m(\boldsymbol{s}) m(s),新的SoftMax公式如下:
s o f t m a x ( s ) i = e s i − m ( s ) ∑ j = 1 N e s j − m ( s ) = f i l ( s ) softmax(\boldsymbol{s})_i=\frac{e^{s_i-m(\boldsymbol{s})}}{\sum_{j=1}^{N}e^{s_j-m(\boldsymbol{s})}}=\frac{f_i}{l(\boldsymbol{s})} softmax(s)i=∑j=1Nesj−m(s)esi−m(s)=l(s)fi
其中,最大值 m ( s ) = max i s i m(\boldsymbol{s})=\max_i s_i m(s)=maxisi,指数和 l ( s ) = ∑ i f i l(\boldsymbol{s})=\sum_i f_i l(s)=∑ifi。事实上,该操作不会影响SoftMax的结果,如下:
s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) = [ e 1 e 1 + e 2 + e 3 + e 10 , e 2 e 1 + e 2 + e 3 + e 10 , e 3 e 1 + e 2 + e 3 + e 10 , e 10 e 1 + e 2 + e 3 + e 10 ] = [ e 1 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 2 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 3 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 , e 10 − 10 e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 ] softmax([1,2,3,10])=[\frac{e^{1}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{2}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{3}}{e^{1}+e^{2}+e^{3}+e^{10}},\frac{e^{10}}{e^{1}+e^{2}+e^{3}+e^{10}}]\\=[\frac{e^{1-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{2-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{3-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}},\frac{e^{10-10}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}] softmax([1,2,3,10])=[e1+e2+e3+e10e1,e1+e2+e3+e10e2,e1+e2+e3+e10e3,e1+e2+e3+e10e10]=[e1−10+e2−10+e3−10+e10−10e1−10,e1−10+e2−10+e3−10+e10−10e2−10,e1−10+e2−10+e3−10+e10−10e3−10,e1−10+e2−10+e3−10+e10−10e10−10]
可见,上下同乘 e 10 e^{10} e10即可还原为原公式。
此时,我们分 T r = 2 T_r=2 Tr=2块分别计算上述SoftMax,有:
s o f t m a x ( [ 1 , 2 ] ) = [ e 1 − m 1 e 1 − m 1 + e 2 − m 1 , e 2 − m 1 e 1 − m 1 + e 2 − m 1 ] = [ f 1 l 1 , f 2 l 1 ] , m 1 = 2 s o f t m a x ( [ 3 , 10 ] ) = [ e 3 − m 2 e 3 − m 2 + e 10 − m 2 , e 10 − m 2 e 3 − m 2 + e 10 − m 2 ] = [ f 3 l 2 , f 4 l 2 ] , m 2 = 10 softmax([1,2])=[\frac{e^{1-m_1}}{e^{1-m_1}+e^{2-m_1}},\frac{e^{2-m_1}}{e^{1-m_1}+e^{2-m_1}}]=[\frac{f_1}{l_1},\frac{f_{2}}{l_1}],m_1=2\\ softmax([3,10])=[\frac{e^{3-m_2}}{e^{3-m_2}+e^{10-m_2}},\frac{e^{10-m_2}}{e^{3-m_2}+e^{10-m_2}}]=[\frac{f_3}{l_2},\frac{f_4}{l_2}],m_2=10 softmax([1,2])=[e1−m1+e2−m1e1−m1,e1−m1+e2−m1e2−m1]=[l1f1,l1f2],m1=2softmax([3,10])=[e3−m2+e10−m2e3−m2,e3−m2+e10−m2e10−m2]=[l2f3,l2f4],m2=10
其中,每个小块里减去的是当前块的最大值,记为 m i m_i mi;当前块的分子,记为 p i \boldsymbol{p}_i pi(是多个 f i f_i fi组成的向量);当前块的分母指数和,记为 l i l_i li。对应地,当前块的输出 p i / l i \boldsymbol{p}_i/l_i pi/li,记为 o \boldsymbol{o} o。
在不同块的遍历计算过程中,我们可以不断更新最大值 m ( s ) m(\boldsymbol{s}) m(s)(初始为负无穷)、指数和 l ( s ) l(\boldsymbol{s}) l(s)(初始为0)。
对于 m ( s ) m(\boldsymbol{s}) m(s),更新公式为 m ( s ) n e w = max ( m ( s ) , m i ) m(\boldsymbol{s})^{new}=\max(m(\boldsymbol{s}),m_i) m(s)new=max(m(s),mi)。
对于 l ( s ) l(\boldsymbol{s}) l(s),更新公式为 l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m i − m ( s ) n e w × l i l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_i-m(\boldsymbol{s})^{new}}\times l_i l(s)new=em(s)−m(s)new×l(s)+emi−m(s)new×li。
在第一块中,
- m ( s ) n e w = max ( − inf , m 1 ) = 2 m(\boldsymbol{s})^{new}=\max(-\inf,m_1)=2 m(s)new=max(−inf,m1)=2
- l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 1 − m ( s ) n e w × l 1 = e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_1-m(\boldsymbol{s})^{new}}\times l_1=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2}) l(s)new=em(s)−m(s)new×l(s)+em1−m(s)new×l1=e−inf−2×0+e2−2×(e1−2+e2−2)
- 令 m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)←m(s)new, l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)←l(s)new
在第二块中,
- m ( s ) n e w = max ( 2 , m 2 ) = 10 m(\boldsymbol{s})^{new}=\max(2,m_2)=10 m(s)new=max(2,m2)=10
- l ( s ) n e w = e m ( s ) − m ( s ) n e w × l ( s ) + e m 2 − m ( s ) n e w × l 2 l(\boldsymbol{s})^{new}=e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times l(\boldsymbol{s})+e^{m_2-m(\boldsymbol{s})^{new}}\times l_2 l(s)new=em(s)−m(s)new×l(s)+em2−m(s)new×l2
= e 2 − 10 × ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × ( e 3 − 10 + e 10 − 10 ) = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 =e^{2-10}\times(e^{1-2}+e^{2-2})+e^{10-10}\times(e^{3-10}+e^{10-10})=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} =e2−10×(e1−2+e2−2)+e10−10×(e3−10+e10−10)=e1−10+e2−10+e3−10+e10−10
可见,最后的输出结果 m ( s ) m(\boldsymbol{s}) m(s)和 l ( s ) l(\boldsymbol{s}) l(s)已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) softmax([1,2,3,10]) softmax([1,2,3,10])中的一致。
m ( s ) m(\boldsymbol{s}) m(s)的更新公式能使 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new始终为当前行的最大值, l ( s ) l(\boldsymbol{s}) l(s)的更新公式能使 l ( s ) n e w l(\boldsymbol{s})^{new} l(s)new的指数项始终减的是 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new。
同样地,在遍历过程中,我们也可以根据新的 m ( s ) m(\boldsymbol{s}) m(s)和 l ( s ) l(\boldsymbol{s}) l(s)计算和更新当前的 o \boldsymbol{o} o(初始为0向量)。
对于 o \boldsymbol{o} o,更新公式为
o n e w = l ( s ) × e m ( s ) − m ( s ) n e w × o + e m i − m ( s ) n e w × p i × V i l ( s ) n e w \boldsymbol{o}^{new}=\frac{l(\boldsymbol{s})\times e^{m(\boldsymbol{s})-m(\boldsymbol{s})^{new}}\times \boldsymbol{o}+e^{m_i-m(\boldsymbol{s})^{new}}\times \boldsymbol{p}_i\times\boldsymbol{V}_i}{l(\boldsymbol{s})^{new}} onew=l(s)newl(s)×em(s)−m(s)new×o+emi−m(s)new×pi×Vi
其中, p i = [ f i ∗ B r , ⋯ , f ( i + 1 ) ∗ B r ] \boldsymbol{p}_i=[f_{i*Br},\cdots,f_{(i+1)*B_r}] pi=[fi∗Br,⋯,f(i+1)∗Br], V i \boldsymbol{V}_i Vi为V矩阵的第 i i i块。
我们假设 V = [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] , [ 7 , 8 ] ] \boldsymbol{V}=[[1,2],[3,4],[5,6],[7,8]] V=[[1,2],[3,4],[5,6],[7,8]],则有
在第一块中,
- m ( s ) n e w = 2 m(\boldsymbol{s})^{new}=2 m(s)new=2
- l ( s ) n e w = e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = e 1 − 2 + e 2 − 2 l(\boldsymbol{s})^{new}=e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})=e^{1-2}+e^{2-2} l(s)new=e−inf−2×0+e2−2×(e1−2+e2−2)=e1−2+e2−2
- o n e w = 0 × e − inf − 2 × 0 + e 2 − 2 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] e − inf − 2 × 0 + e 2 − 2 × ( e 1 − 2 + e 2 − 2 ) = [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) \boldsymbol{o}^{new}=\frac{0\times e^{-\inf-2}\times 0+e^{2-2}\times [e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{e^{-\inf-2}\times 0+e^{2-2}\times(e^{1-2}+e^{2-2})}=\frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})} onew=e−inf−2×0+e2−2×(e1−2+e2−2)0×e−inf−2×0+e2−2×[e1−2,e2−2]×[1324]=(e1−2+e2−2)[e1−2,e2−2]×[1324]
- 令 m ( s ) ← m ( s ) n e w m(\boldsymbol{s})\leftarrow m(\boldsymbol{s})^{new} m(s)←m(s)new, l ( s ) ← l ( s ) n e w l(\boldsymbol{s})\leftarrow l(\boldsymbol{s})^{new} l(s)←l(s)new, o ← o n e w \boldsymbol{o}\leftarrow \boldsymbol{o}^{new} o←onew
在第二块中,
- m ( s ) n e w = 10 m(\boldsymbol{s})^{new}=10 m(s)new=10
- l ( s ) n e w = e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 l(\boldsymbol{s})^{new}=e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10} l(s)new=e1−10+e2−10+e3−10+e10−10
- o n e w = ( e 1 − 2 + e 2 − 2 ) × e 2 − 10 × [ e 1 − 2 , e 2 − 2 ] × [ 1 2 3 4 ] ( e 1 − 2 + e 2 − 2 ) + e 10 − 10 × [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 = [ e 1 − 10 , e 2 − 10 ] × [ 1 2 3 4 ] + [ e 3 − 10 , e 10 − 10 ] × [ 5 6 7 8 ] e 1 − 10 + e 2 − 10 + e 3 − 10 + e 10 − 10 \boldsymbol{o}^{new}=\frac{(e^{1-2}+e^{2-2})\times e^{2-10}\times \frac{[e^{1-2},e^{2-2}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}}{(e^{1-2}+e^{2-2})}+e^{10-10}\times [e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}}\\=\frac{[e^{1-10},e^{2-10}]\times\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}+[e^{3-10},e^{10-10}]\times \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}}{e^{1-10}+e^{2-10}+e^{3-10}+e^{10-10}} onew=e1−10+e2−10+e3−10+e10−10(e1−2+e2−2)×e2−10×(e1−2+e2−2)[e1−2,e2−2]×[1324]+e10−10×[e3−10,e10−10]×[5768]=e1−10+e2−10+e3−10+e10−10[e1−10,e2−10]×[1324]+[e3−10,e10−10]×[5768]
可见,最后的结果已经与实际 s o f t m a x ( [ 1 , 2 , 3 , 10 ] ) × V softmax([1,2,3,10])\times\boldsymbol{V} softmax([1,2,3,10])×V一致。
o \boldsymbol{o} o的更新公式能使各块分子指数项上减去最新的 m ( s ) n e w m(\boldsymbol{s})^{new} m(s)new,并使各块的最新的指数和合并。
致谢:
本博客仅做记录使用,无任何商业用途,参考内容如下:
Flash Attention 为什么那么快?原理讲解
Flash Attention论文解读