1 概念速览
术语 | 定义 | 关键点 |
---|
编码器 (Encoder) | 将输入序列 $\mathbf x={x_1\dots x_n}$ 压缩为 上下文表示 $\mathbf c$(向量或张量) | 提炼关键信息,支持变长输入 |
解码器 (Decoder) | 在 $\mathbf c$ 的条件下自回归地产生输出序列 $\mathbf y={y_1\dots y_m}$ | 生成、翻译、预测等;可视为条件语言模型 |
训练目标 | 最大化对数似然 $\log p_\theta(\mathbf y\mid\mathbf x)$ | 典型损失:交叉熵 |
为什么有效 | 原生支持「变长输入 → 变长输出」,并能通过注意力显式对齐 | 机器翻译、摘要等 Seq2Seq 任务的基础 |
2 网络形态与“同构”迷思
Encoder | Decoder | 场景示例 | 备注 |
---|
RNN/LSTM/GRU | RNN/LSTM/GRU | 早期 NMT、时间序列预测 | 纵向依赖强,训练难度大 |
卷积 CNN | 反卷积或 CNN | U-Net 图像分割 | 本地感受野,建模全局需扩张卷积 |
Transformer | Transformer | 主流文本/多模态生成 | 并行化、长依赖;显存吃紧 |
Hybrid | Hybrid | 长序列、流式语音 ASR | 编码器和解码器可异构 |
结论: 编码器与解码器完全可以使用不同类型网络。
- RNN → Transformer:先压缩时序,再高效全局注意力解码
- CNN → CTC 解码器:流式语音,低延迟
- ViT → 文本 Transformer:图像字幕(BLIP-2)
3 注意力 & 上下文
- 固定上下文向量瓶颈
早期 Seq2Seq 仅传递单向量 $\mathbf c$,长句信息易丢失。 - Bahdanau / Luong 注意力
解码时对编码隐藏态打分,动态读取相关信息。 - Transformer
编码器和解码器均以多头自注意力为核心,完全抛弃循环结构。 - 跨模态注意力
视觉 token ↔ 字幕 token,或语音特征 ↔ 文本 token。
4 典型任务与落地框架
任务 | 输入 ➜ 输出 | 主流开源模型 / 库 |
---|
机器翻译 | 句子 ➜ 句子 | Transformer、mBART、MarianMT |
文本摘要 | 长文 ➜ 简短摘要 | BART、Pegasus、T5 |
对话生成 | 历史对话 ➜ 回复 | DialoGPT、LLaMA-Chat |
语音识别 | 声谱图 ➜ 文本 | Whisper、RNN-T |
图像字幕 | 图像特征 ➜ 文字 | BLIP-2、PaLI |
时间序列预测 | 历史序列 ➜ 未来序列 | Informer、Seq2Seq RNN |
代码补全 | 代码上下文 ➜ 续写 | CodeT5、StarCoder |
5 极简 PyTorch 模板
import torch, torch.nn as nn
from random import randomclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, sos_id, eos_id, max_len=128):super().__init__()self.encoder, self.decoder = encoder, decoderself.sos, self.eos, self.max_len = sos_id, eos_id, max_lendef forward(self, src_ids, tgt_ids=None, teacher_forcing=0.5):memory = self.encoder(src_ids)B = src_ids.size(0)ys = torch.full((B, 1), self.sos, dtype=torch.long, device=src_ids.device)outputs = []for t in range(self.max_len):logits = self.decoder(ys, memory) next_token = logits[:, -1].argmax(-1, keepdim=True)outputs.append(next_token)if tgt_ids is not None and random() < teacher_forcing:ys = torch.cat([ys, tgt_ids[:, t:t+1]], dim=1)else:ys = torch.cat([ys, next_token], dim=1)if (next_token == self.eos).all():breakreturn torch.cat(outputs, dim=1)
用法示例
enc = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6)
dec = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model=512, nhead=8), num_layers=6)
model = Seq2Seq(enc, dec, sos_id=0, eos_id=2)
6 进阶主题
方向 | 思路 | 代表工作 |
---|
长上下文 | 稀疏/线性注意力(Performer, Longformer) | LongT5, Flash-Attention |
检索增强 (RAG) | 外部向量数据库返回候选段落,拼接进解码器输入 | RETRO, Atlas, LlamaIndex |
多模态对齐 | 视觉/音频编码器 + 文本解码器;对比学习统一 token 空间 | BLIP-2, Gemini, GPT-4o |
效率优化 | 混合精度、蒸馏、小模型教师、KV 缓存、模型并行 | DeepSpeed ZeRO-3, Flash-Decoding |
7 小结与实践建议
- 架构是方法论:编码器负责理解,解码器负责表达,二者可自由组合。
- 先跑通,再混搭:先用官方 Transformer 教程跑 NMT baseline,再尝试 LSTM-Enc + Transformer-Dec 等混搭,体会差异。
- 关注长上下文与检索增强:RAG 正成为工业搜索-生成系统的主流范式。
- 做项目,反推理论:挑一项真实业务(如 PDF 摘要、邮 件分类),落地一条端到端流水线,遇到痛点再查论文,理解会更深。