Transformer的编码器与解码器模块深度解析及python实现完整案例
文章目录
- 一、Transformer 整体架构概览
- 二、编码器 深度解析
- 2.1 输入嵌入与位置编码
- 2.2 多头注意力机制
- 2.3 前馈神经网络
- 2.4 残差连接与层归一化
- 2.5 编码器层堆叠
- 三、解码器 深度解析
- 3.1 输入嵌入与位置编码
- 3.2 掩码多头自注意力
- 3.3 编码器-解码器注意力
- 3.4 解码器层堆叠
- 四、完整 Python 实现
- 4.1 环境准备
- 4.2 完整代码
- 4.3 执行结果
一、Transformer 整体架构概览
Transformer 模型由两个核心部分组成:编码器 和 解码器。
- 编码器:负责理解输入的源语言句子。它将输入序列转换为一组包含丰富上下文信息的向量表示(称为“思想向量”)。
- 解码器:负责生成目标语言句子。它接收编码器的输出,并一个词一个词地生成翻译结果。
它们都由多个相同的层堆叠而成(论文中为6层)。
二、编码器 深度解析
编码器由 N
个相同的编码器层堆叠而成(N=6)。每个编码器层包含两个子层:
- 多头自注意力机制
- 前馈神经网络
这两个子层都使用了 残差连接 和 层归一化。
2.1 输入嵌入与位置编码
原始的输入是词的 ID 序列。模型无法直接理解离散的 ID,所以第一步是 词嵌入,将每个 ID 映射到一个高维稠密向量。
然而,RNN/LSTM 的一个固有特性是它们天生包含位置信息(按顺序处理)。Transformer 没有这个特性,因此必须显式地向模型注入位置信息。这就是 位置编码 的作用。
- 实现:位置编码是一个与词嵌入维度相同的矩阵。对于位置
pos
和维度i
,其值PE(pos, i)
的计算公式如下:- PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos, 2i)} = \sin(\frac{pos}{10000^{2i/d_{model}}})PE(pos,2i)=sin(100002i/dmodelpos)
- PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos(\frac{pos}{10000^{2i/d_{model}}})PE(pos,2i+1)=cos(100002i/dmodelpos)
- 效果:这种设计使得模型能够轻松学习到位置信息,且不同长度的序列可以共享相同的位置编码模式。
最终输入:Input Embeddings + Positional Encodings
2.2 多头注意力机制
这是 Transformer 的核心。它允许模型同时关注序列中不同位置的信息。
- 核心思想:与其用一个单一的注意力函数,不如将
Q
(Query),K
(Key),V
(Value) 投影到h
个不同的子空间中,并行执行h
次注意力计算,然后将结果拼接起来。 - 步骤:
- 线性投影:对于每个头,输入向量
X
通过不同的权重矩阵 WiQ,WiK,WiVW_i^Q, W_i^K, W_i^VWiQ,WiK,WiV 投影,得到Q_i, K_i, V_i
。 - 缩放点积注意力:对每个头计算注意力分数。
$ \text{Attention}(Q_i, K_i, V_i) = \text{softmax}(\frac{Q_i K_i^T}{\sqrt{d_k}}) V_i $- 缩放:除以 dk\sqrt{d_k}dk(
d_k
是 K 的维度)是为了防止点积过大导致 softmax 梯度消失。 - 自注意力:在这里,
Q, K, V
都来自同一个输入(编码器的输出)。
- 缩放:除以 dk\sqrt{d_k}dk(
- 拼接与投影:将
h
个头的输出向量拼接起来,再通过一个最终的线性层 WOW^OWO 进行投影,得到最终的输出。
- 线性投影:对于每个头,输入向量
2.3 前馈神经网络
FFN 对每个位置的表示独立地进行相同的非线性变换。它由两个全连接层组成,中间有一个 ReLU 激活函数。
- 公式:FFN(x)=max(0,xW1+b1)W2+b2\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2FFN(x)=max(0,xW1+b1)W2+b2
- 作用:为模型增加非线性变换能力,增强其表达能力。它对每个位置的向量进行操作,不涉及跨位置的信息交互。
2.4 残差连接与层归一化
这是训练深层网络的关键技巧。
- 残差连接:将子层的输入直接加到其输出上。
Output=Sublayer(Input)+Input\text{Output} = \text{Sublayer}(\text{Input}) + \text{Input}Output=Sublayer(Input)+Input - 作用:解决了深度网络中的梯度消失/爆炸问题,使得训练更深的网络成为可能。
- 层归一化:在残差连接之后,对每个样本的所有特征进行归一化,稳定训练过程。
LayerNorm(x)=γx−μσ2+ϵ+β\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \betaLayerNorm(x)=γσ2+ϵx−μ+β
(其中 γ\gammaγ 和 β\betaβ 是可学习的参数)
2.5 编码器层堆叠
将上述的 多头注意力 -> 残差连接+层归一化 -> FFN -> 残差连接+层归一化 结构作为一个编码器层。然后,将 N
个这样的层堆叠起来,每一层的输出都是下一层的输入。
三、解码器 深度解析
解码器同样由 N
个相同的层堆叠而成。每个解码器层包含三个子层:
- 掩码多头自注意力
- 编码器-解码器多头注意力
- 前馈神经网络
同样,所有子层都使用残差连接和层归一化。
3.1 输入嵌入与位置编码
与编码器类似,解码器的输入(目标语言句子)也需要经过 词嵌入 和 位置编码。
3.2 掩码多头自注意力
这个子层与编码器的自注意力非常相似,但有一个关键区别:掩码。
- 目的:在训练时,解码器需要预测下一个词。例如,在预测 “world” 时,它只能看到 “I” 和 “hello”,而不能看到 “world” 及其之后的词。这被称为 “教师强制” 机制。
- 实现:通过在 softmax 计算之前,将未来位置对应的注意力分数设置为一个非常小的值(如
-1e9
),使得 softmax 后这些位置的权重接近于 0。
3.3 编码器-解码器注意力
这是连接编码器和解码器的桥梁。
- 区别:这里的
Q
来自解码器前一层的输出,而K
和V
来自编码器的最终输出。 - 作用:允许解码器在生成每个词时,有选择地关注输入源语言句子中的不同部分。例如,翻译 “cat” 时,解码器会高度关注输入中的 “猫”。
3.4 解码器层堆叠
将 掩码自注意力 -> 编码器-解码器注意力 -> FFN 结构作为一个解码器层,并堆叠 N
个。
四、完整 Python 实现
我们将使用 PyTorch 来构建 Transformer 模型。
4.1 环境准备
确保你已经安装了 PyTorch。
pip install torch
4.2 完整代码
我们将按照解析的顺序,一步步构建模型组件。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np# --- 1. 位置编码 ---
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)position = torch.arange(max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))# 使用 pe.register_buffer 使得 pe 成为模型的一部分,但不会被视为参数pe = torch.zeros(max_len, 1, d_model)pe[:, 0, 0::2] = torch.sin(position * div_term)pe[:, 0, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x):"""x: Tensor, shape [seq_len, batch_size, d_model]"""# x 的形状是 [seq_len, batch_size, d_model]# self.pe 的形状是 [max_len, 1, d_model]# 广播机制将 self.pe 的第1维扩展为 batch_sizex = x + self.pe[:x.size(0)]return self.dropout(x)# --- 2. 完整的 Transformer 模型 ---
class Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6,num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):super(Transformer, self).__init__()self.d_model = d_model# 嵌入层self.src_embedding = nn.Embedding(src_vocab_size, d_model)self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)# 位置编码self.positional_encoding = PositionalEncoding(d_model, dropout)# PyTorch 提供的官方 Transformer 模型# transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,# dim_feedforward, dropout, batch_first=False) # 默认 batch_first=Falseself.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,dim_feedforward, dropout, batch_first=True) # 使用 batch_first=True 更直观# 输出层self.fc_out = nn.Linear(d_model, tgt_vocab_size)# 初始化参数self._init_weights()def _init_weights(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):# src: [batch_size, src_len]# tgt: [batch_size, tgt_len]# 1. 嵌入和位置编码# 将词索引转换为词向量,并乘以 d_model 的平方根进行缩放src_emb = self.src_embedding(src) * math.sqrt(self.d_model)tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)# 2. 添加位置编码# nn.Transformer with batch_first=True expects [batch, seq_len, features]src_emb = self.positional_encoding(src_emb) # [batch, src_len, d_model]tgt_emb = self.positional_encoding(tgt_emb) # [batch, tgt_len, d_model]# 3. 调用官方 Transformer 模块# 输出形状: [batch_size, tgt_len, d_model]output = self.transformer(src_emb, tgt_emb,src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask,src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask)# 4. 通过线性层输出# 输出形状: [batch_size, tgt_len, tgt_vocab_size]return self.fc_out(output)# --- 3. 辅助函数 ---def generate_square_subsequent_mask(sz):"""生成一个用于解码器的因果掩码(上三角矩阵)。确保在位置 i,只能看到位置小于 i 的 token。"""# 返回一个形状为 (sz, sz) 的张量# torch.triu 返回一个上三角矩阵,对角线及以上为1,其余为0# 然后用 0 填充上三角,用 -inf 填充下三角(不包括对角线)return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)# --- 4. 模型训练与推理示例 ---# 假设的超参数
SRC_VOCAB_SIZE = 5000
TGT_VOCAB_SIZE = 5000
D_MODEL = 512
NHEAD = 8
NUM_ENCODER_LAYERS = 3 # 减少层数以加快演示速度
NUM_DECODER_LAYERS = 3
DIM_FEEDFORWARD = 512
DROPOUT = 0.1
BATCH_SIZE = 64
MAX_SEQ_LEN = 20# 1. 初始化模型
transformer_model = Transformer(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD,NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,DIM_FEEDFORWARD, DROPOUT)# 2. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=0) # 假设 0 是 padding token 的索引
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)# 3. 模拟数据
# 创建一些随机数据来模拟输入
src_data = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN))
tgt_data = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN))# 创建一个填充掩码,假设索引 0 是填充 token
# src_key_padding_mask: [batch_size, src_len]
src_key_padding_mask = (src_data == 0)
tgt_key_padding_mask = (tgt_data == 0)# --- 训练过程 ---
transformer_model.train()
optimizer.zero_grad()# 准备解码器的输入和输出
# 解码器的输入是目标序列向左平移一位
tgt_input = tgt_data[:, :-1]
tgt_output = tgt_data[:, 1:]# 创建解码器掩码 (因果掩码)
# tgt_mask: [tgt_len, tgt_len]
tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))# 前向传播
logits = transformer_model(src_data, tgt_input,src_key_padding_mask=src_key_padding_mask,tgt_key_padding_mask=(tgt_input == 0), # 对输入进行掩码tgt_mask=tgt_mask)# 计算损失
# logits: [batch_size, tgt_len, tgt_vocab_size]
# tgt_output: [batch_size, tgt_len]
# 需要将 logits 的维度从 [B, L, V] 变为 [B*L, V],将 tgt_output 展平为 [B*L]
loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))# 反向传播和优化
loss.backward()
optimizer.step()print(f"Training Loss: {loss.item()}")# --- 推理过程 (Greedy Decoding) ---
transformer_model.eval()
src_sample = torch.randint(1, SRC_VOCAB_SIZE, (1, MAX_SEQ_LEN)) # [1, src_len]
sos_token = torch.tensor([[1]]) # Start of sequence token
eos_token = torch.tensor([[2]]) # End of sequence token# 创建源序列的填充掩码
src_key_padding_mask_infer = (src_sample == 0)# 生成目标序列
for i in range(MAX_SEQ_LEN):# 在每次迭代中,tgt_input 的长度都会增加# tgt_input: [1, current_len]tgt_input = sos_token if i == 0 else torch.cat([sos_token, generated_tokens], dim=1)# 在每次迭代中都重新创建因果掩码,因为序列长度在变tgt_mask_infer = generate_square_subsequent_mask(tgt_input.size(1))# 获取模型输出# output: [1, current_len, tgt_vocab_size]output = transformer_model(src_sample, tgt_input,src_key_padding_mask=src_key_padding_mask_infer,tgt_key_padding_mask=(tgt_input == 0),tgt_mask=tgt_mask_infer)# 取最后一个时间步的输出last_word_logits = output[:, -1, :] # [1, tgt_vocab_size]# 获取概率最高的词的索引next_token = torch.argmax(last_word_logits, dim=-1).unsqueeze(0) # [1, 1]# 将新词添加到已生成的序列中generated_tokens = next_token if i == 0 else torch.cat([generated_tokens, next_token], dim=1)# 如果生成了 EOS token,则停止生成if next_token.item() == eos_token.item():breakprint("\n--- Inference ---")
print(f"Source: {src_sample}")
print(f"Generated Target: {generated_tokens}")
4.3 执行结果
Training Loss: 8.610507011413574--- Inference ---
Source: tensor([[4958, 4725, 800, 2363, 1056, 4051, 2208, 1702, 3965, 1113, 2956, 98,3260, 2197, 250, 1491, 1247, 682, 2149, 1405]])
Generated Target: tensor([[540, 540, 540, 540, 540, 540, 540, 540, 540, 540, 540, 540, 540, 540,540, 540, 540, 540, 540, 540]])
总结:
- 核心思想:Transformer 通过 自注意力机制 实现了序列的并行处理和长距离依赖的捕捉,彻底改变了 NLP 领域。
- 编码器:负责理解输入。它通过多层 自注意力 和 FFN 来提取和整合输入序列的上下文信息。
- 解码器:负责生成输出。它通过 掩码自注意力 来保证预测的“因果性”,并通过 编码器-解码器注意力 来获取源语言的信息。
- 关键组件:
- 位置编码:为模型注入序列顺序信息。
- 多头注意力:让模型从不同的“子空间”关注信息。
- 残差连接与层归一化:确保深层网络可以稳定高效地训练。
- Python 实现:通过 PyTorch,我们可以清晰地构建出 Transformer 的每一个模块,并最终组合成一个完整的模型。这个实现是理解和使用现代大型语言模型(如 GPT, BERT)的基石。