从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)
200行代码理解LLM Attention+自解码推理
- 从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)
-
- 一、整体流程
- 二、核心组件实现
-
- 1. 正弦位置编码
- 2. Layer Normalization
- 3. 数值稳定的 Softmax
- 4. 多头自注意力(含因果掩码)
- 5. 前馈网络(MLP + GELU)
- 6. 单层 Transformer 的前向函数
- 三、运行与验证
- 四、总结
从零用 NumPy 实现单层 Transformer 解码器(Decoder-Only)
近年来,大语言模型(LLM)大多基于 Transformer Decoder 架构,例如 GPT、LLaMA 等。在这篇文章中,我们将用 纯 NumPy 实现一个单层的 Decoder-Only Transformer,并支持因果掩码、多头注意力、GELU 激活等核心特性。
一、整体流程
单层 Pre-LN Decoder 的标准计算步骤为:
- 添加位置编码(Positional Encoding)
- LayerNorm(预归一化)
- 多头自注意力(Multi-Head Attention)
- 输出投影 W O W_O WO
- 残差连接①
- LayerNorm(第二次)
- 前馈网络(MLP + GELU)
- 残差连接②
- 词表投影预测下一个 Token
流程示意图如下(省略 Batch):
X_in → +PE → LN → MHA(+mask) → W_O → +残差①→ LN → MLP(GELU) → +残差② → W_vocab^T → softmax
二、核心组件实现
1. 正弦位置编码
我们使用原论文《Attention Is All You Need》中的正弦位置编码,为每个位置生成固定的向量并加到嵌入上:
def add_cos_embedding(X_l):seq_len, embed_dim = X_l.shapefor i in range(seq_len):for j in range(embed_dim):if j % 2 == 0:X_l[i, j] += np.sin(float(i) / (10000 ** (j / embed_dim)))else:X_l[i, j] += np.cos(float(i-1) / (10000 ** ((j-1) / embed_dim)))return X_l
这里用到了不同频率的正弦/余弦,偶数维用 sin
,奇数维用 cos
,从而编码位置信息。
2. Layer Normalization
LayerNorm 是 Transformer 的标配归一化方法,这里我们用每个 token 自身的均值和方差来归一化:
def layer_norm(X_l, beta, gamma, eps=1e-6):mean = np.mean(X_l, axis=-1, keepdims=True)std = np.std(X_l, axis=-1, keepdims=True)X_l = (X_l - mean) / (std + eps)return X_l * gamma + beta
注意,这里 gamma
和 beta
是可学习参数,对所有 token 共享。
3. 数值稳定的 Softmax
为了避免指数溢出,我们在计算 exp
前先减去行最大值:
def softmax(X_l, eps=1e-6):X_l = X_l - np.max(X_l, axis=-1, keepdims=True)exp_x = np.exp(X_l)return exp_x / (np.sum(exp_x, axis=-1, keepdims=True) + eps)
4. 多头自注意力(含因果掩码)
在 Decoder 中,我们必须确保当前 token 不能看到未来的信息,因此需要因果掩码(Causal Mask):
def compute_attention(X_q, X_k, X_v, heads, attn_dim, mask):head_outputs = []for i in range(heads):scores = np.dot(X_q[i], X_k[i].T) / np.sqrt(attn_dim)scores = scores + mask # 将不可见位置加上 -1e9attn = softmax(scores)out = np.dot(attn, X_v[i])head_outputs.append(out)return np.concatenate(head_outputs, axis=-1) # 拼接所有头
其中 mask
是一个 [seq_len, seq_len]
的矩阵,上三角(未来位置)为 -1e9
,下三角为 0。
5. 前馈网络(MLP + GELU)
MLP 部分使用了常见的 GELU 激活函数:
def gelu(x):return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))def mlp