Transformer的并行计算与长序列处理瓶颈
Transformer相比RNN(循环神经网络)的核心优势之一是天然支持并行计算,这源于其自注意力机制和网络结构的设计.并行计算能力和长序列处理瓶颈是其架构特性的两个关键表现:
- 并行计算:指 Transformer 在训练 / 推理时通过矩阵运算并行化、模块独立性实现高效计算的能力;
- 长序列处理瓶颈:指当输入序列长度(n)增加时,自注意力机制的计算 / 内存复杂度呈O(n²)增长,导致效率骤降的问题。
1. 并行计算
1. 自注意力机制的并行性
自注意力的计算公式为:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V
对于序列长度为nnn的输入,自注意力中每个位置的计算不依赖其他位置的中间结果:
- 计算Q、K、VQ、K、VQ、K、V的线性变换时,所有token的qi、ki、viq_i、k_i、v_iqi、ki、vi可同时生成(并行);
- 计算QKTQK^TQKT(n×nn×nn×n的分数矩阵)时,每个元素score(i,j)score(i,j)score(i,j)的计算独立于其他元素(可并行);
- 即使是softmax和加权求和步骤,也可对整个序列的所有位置同时执行(并行)。
而RNN需要按序列顺序计算(hih_ihi依赖hi−1h_{i-1}hi−1),完全串行,无法并行。
2. 网络结构的并行性
- 编码器/解码器层的并行:编码器的每一层(多头注意力+前馈网络)对整个序列的处理是“批量”的,所有token共享层参数,可同时更新;
- 训练时的并行优化:结合数据并行(同一模型在不同样本上并行训练)、模型并行(将网络层拆分到不同设备),可充分利用GPU/TPU的并行计算能力,大幅加速训练。
核心观点:Transformer的并行能力源于模块独立性和矩阵运算的可并行性。
- 底层:矩阵运算天然支持并行(GPU的SIMD架构可并行处理矩阵元素);
- 中层:模块独立(前馈网络对每个位置的计算独立;多头注意力的“头”之间无依赖);
- 顶层:训练时可通过批处理(batch维度)、序列分片进一步提升并行效率。
根本原理:并行能力源于“计算单元的独立性”和“矩阵运算的可拆分性”。
- 前馈网络:对序列中每个位置的计算是独立函数(FFN(x_i) = W2·ReLU(W1·x_i + b1) + b2),无跨位置依赖,可完全并行;
- 多头注意力:每个“头”的计算独立(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)),头之间可并行;
- 矩阵运算:QKT的每个元素(QKT)[i][j] = Q[i]·K[j],元素间无依赖,可由GPU并行计算。
1. 长序列瓶颈
长序列处理的核心瓶颈
当序列长度nnn增大(如文档级文本、长视频帧、基因组序列,nnn可达10410^4104甚至10510^5105),Transformer的性能会急剧下降,核心瓶颈来自自注意力的O(n2)O(n²)O(n2)复杂度:
1. 计算复杂度瓶颈
自注意力的核心步骤(QKTQK^TQKT矩阵乘法)的计算量为O(n2⋅d)O(n²·d)O(n2⋅d)(ddd为隐藏层维度):
- 当n=1000n=1000n=1000时,计算量约为106⋅d10^6·d106⋅d;
- 当n=10000n=10000n=10000时,计算量增至108⋅d10^8·d108⋅d(是前者的100倍)。
这种平方级增长会导致: - 单次前向/反向传播时间大幅增加(训练/推理变慢);
- 难以利用并行计算优势(过多计算量超出硬件算力上限)。
2. 内存瓶颈
自注意力过程中需要存储多个n×nn×nn×n或n×dn×dn×d的中间张量:
- Q、K、VQ、K、VQ、K、V的形状为(n,d)(n,d)(n,d),总内存为O(3nd)O(3nd)O(3nd);
- QKTQK^TQKT的分数矩阵形状为(n,n)(n,n)(n,n),内存为O(n2)O(n²)O(n2);
- 注意力权重矩阵(softmax结果)同样为(n,n)(n,n)(n,n),内存O(n2)O(n²)O(n2)。
当n=10000n=10000n=10000时,n2=108n²=10^8n2=108,若每个元素为4字节(float32),仅分数矩阵就需要400MB内存,加上其他张量,单头注意力就可能占用数GB内存,远超普通GPU的显存上限(如16GB GPU难以处理n=20000n=20000n=20000的序列)。
3. 优化器的额外负担
训练时,优化器(如Adam)需要存储所有参数的梯度和动量信息,长序列会导致中间变量(如注意力权重的梯度)的内存占用也随n2n²n2增长,进一步加剧内存压力。
三、长序列处理的解决方案
为突破O(n2)O(n²)O(n2)瓶颈,研究者提出了多种优化思路,核心是用“稀疏注意力”或“线性复杂度注意力”替代全局注意力:
- 稀疏注意力(Sparse Attention)
仅计算部分位置的注意力,将复杂度降至O(n⋅w)O(n·w)O(n⋅w)(www为局部窗口大小):
- 滑动窗口注意力(如Longformer):每个位置仅关注左右www个相邻位置(总窗口2w+12w+12w+1),适合时序相关的长序列;
- 固定稀疏模式(如BigBird):每个位置关注“局部窗口+随机采样+全局标记”,兼顾局部相关性和全局信息;
- 轴向注意力(如Axial Transformer):将长序列拆分为多个维度(如文本拆分为“句-词”),在每个维度单独计算注意力,复杂度降至O(n⋅n)O(n·\sqrt{n})O(n⋅n)。
- 线性注意力(Linear Attention)
用“核函数”替换QKTQK^TQKT的矩阵乘法,将复杂度降至O(n⋅d)O(n·d)O(n⋅d):
- 核心思路:将softmax(QKT/d)V\text{softmax}(QK^T/\sqrt{d})Vsoftmax(QKT/d)V改写为KT(softmax(QKT/d)TV)Z\frac{K^T(\text{softmax}(QK^T/\sqrt{d})^T V)}{Z}ZKT(softmax(QKT/d)TV)(ZZZ为归一化项),通过核函数(如exp(q⋅k)\exp(q·k)exp(q⋅k))的性质,将矩阵乘法转化为逐元素操作;
- 代表模型:Performer(用随机特征映射近似核函数)、Linformer(用低秩矩阵近似K、VK、VK、V)。
- 分层/压缩注意力
通过“序列压缩”减少有效长度:
- ** hierarchical Attention**:先对长序列分块,计算块内注意力得到“块表示”,再计算块间注意力(如文档先分句子,再对句子表示计算注意力);
- Downsampling:用池化(如平均池化)或卷积将长序列压缩为短序列(如ViT中的Patch Embedding将图像压缩为n=14×14n=14×14n=14×14的patch序列)。
核心观点:长序列处理瓶颈源于自注意力的全连接关联特性,导致复杂度随长度平方增长。分层展开:
- 底层:自注意力需计算“每个位置与所有位置”的关联(QK^T矩阵为n×n);
- 中层:计算复杂度O(n²d)(d为隐藏维度)、内存占用O(n²)(存储注意力权重);
- 顶层:当n过大(如n>10k),计算耗时、内存溢出,效率骤降。
根本原理:自注意力的“全关联定义”导致复杂度随长度平方增长,是机制固有属性。
自注意力的核心公式为:
Attention(Q,K,V) = softmax((QK^T)/√d_k)·V
其中QK^T是n×n矩阵(n为序列长度),其计算/存储复杂度必然是O(n²);即使优化实现(如稀疏化),也只能降低系数,无法改变O(n²)的本质(因“注意力”定义本身要求衡量位置间的关联)。