当前位置: 首页 > news >正文

讲透 RNN 到 Transformer !!!

哈喽,我是我不是小upper~​

在深度学习领域,序列数据处理一直是重要的研究方向,从语音识别到自然语言处理,从时间序列分析到机器翻译,数据都以序列的形式存在。今天咱们就来聊聊从 RNN 到 Transformer 的演变历程,一起探究为什么 “Attention Is All You Need”?​

为什么一开始我们用 RNN?​

在深度学习发展早期,面对序列数据,传统神经网络难以捕捉数据间的时序依赖关系,而 RNN(循环神经网络)的出现解决了这一难题。想象一下,当咱们在看一部小说的时候,每一页的内容都和前面发生的事情有关,理解当前情节需要结合之前的故事脉络。RNN 就像一个人在读小说,每读一句话就记住重点,然后带着记忆去读下一句。它通过隐藏层的循环连接,按顺序处理信息,一步一步传递上下文信息。​

具体来说,RNN 在每个时间步都会接收当前输入和上一个时间步的隐藏状态,通过特定的计算方式更新隐藏状态,将之前积累的 “记忆” 与新信息融合。这种机制使得 RNN 能够处理变长序列数据,在语音识别、文本生成等任务中崭露头角,比如在语言模型中,RNN 可以根据前文预测下一个可能出现的单词,在机器翻译里,能按顺序将源语言逐词转换为目标语言 。

RNN 的困境与局限​

虽然 RNN 开创了序列数据处理的新方向,但随着研究的深入和应用场景复杂度的提升,它的问题逐渐暴露出来。​

首先,RNN 处理数据的速度太慢。由于它必须按顺序,一个时间步一个时间步地处理,就像你得一句一句等着看小说,不能跳着看,无法并行计算。在处理长文本、长语音序列等大规模数据时,这种顺序处理方式会导致计算效率极低,训练时间大幅增加,难以满足实时性要求高的应用场景。​

其次,RNN 存在严重的 “记忆力不好” 问题,也就是梯度消失问题。在反向传播过程中,梯度需要从序列的最后一个时间步反向传递到第一个时间步来更新网络参数。随着序列长度增加,梯度在传递过程中会不断衰减,就像声音在长长的隧道里传播,越传越弱,导致网络难以学习到长距离的依赖关系,前面发生的事容易被 “忘记” 。这使得 RNN 在处理长序列时,对早期信息的利用能力很差,比如在生成一篇长文章时,可能会出现前后逻辑不一致、忘记前文设定等问题。​

此外,RNN 结构的局限性还体现在难以处理复杂的语义关系,它对信息的整合方式较为单一,无法有效捕捉序列中不同位置信息之间的复杂关联。​

Transformer 出现前的 “革命”:Attention 机制​

在 Transformer 横空出世之前,Attention(注意力机制)的出现已然是深度学习领域的一场 “革命”。它的灵感其实源于人类的认知习惯,就好比咱们复习的时候,会把小说的精彩段落划重点。当我们阅读文本时,大脑并不会机械地逐字处理,而是会根据当前内容,灵活地回顾其他关键句子,并决定 “我现在需要关注谁”。​

从技术角度来讲,Attention 是一种 “加权平均” 思想。在传统的 RNN 中,每一步的输出依赖上一步的隐藏状态,信息传递具有很强的顺序性;而 Attention 机制打破了这种局限,它不再只关注上一个时间步,而是能够 “纵观全局”,对所有的输入进行考量,并根据相关性赋予不同的权重。​

以自然语言处理中的指代消解为例,当句子是 “张三去了公园。后来他……”,当前输入是 “他”,想要知道 “他” 指的是谁,Attention 机制就会去 “翻阅” 前文,通过计算语义的相关性,发现 “张三” 和 “他” 最相关,从而赋予 “张三” 更高的权重,准确理解 “他” 的指代对象 。​

用数学公式来描述 Attention 的基本原理,假设我们有一个 “查询” 向量 ​Q,以及多个 “键 - 值对”,其中键组成矩阵 ​K,值组成矩阵 ​V。Attention 的计算过程就是要从这些 “键 - 值对” 中找出与 “查询” 最相关的内容,并进行加权求和。具体公式如下:

其中,​d_k​ 是键 ​K 的维度,除以 \sqrt{d_k} 这一步骤被称为缩放(Scaling),目的是避免在计算 ​QK^T 时,由于维度过高导致数值过大,进而引发梯度爆炸的问题 。通过这个公式,我们可以根据 “查询” 和 “键” 之间的相似度,计算出每个 “值” 对应的权重,再用这些权重对 “值” 进行加权求和,得到最终的输出。

Transformer 的诞生与革新​

为了解决 RNN 的诸多问题,Transformer 应运而生。它直接放弃了 RNN 顺序处理的模式,将 Attention 机制发挥到极致,其核心思想正如论文标题所言 ——“Attention Is All You Need”,即仅依靠注意力机制就能高效完成序列数据处理任务 。​

Transformer 引入了 Self-Attention(自注意力)机制,这种机制可以理解为:在处理一个序列中的每个词时,该词都能 “看到” 序列中的其他所有词(包括自己),并通过计算彼此之间的注意力权重,来确定在生成当前词的表示时,其他词的重要程度 。同时,Transformer 丢掉了循环结构,采用位置编码(Positional Encoding)来补充顺序信息,弥补了没有循环结构可能丢失的位置信息。​

Transformer 的核心构件​

        1. Self-Attention(自注意力)

假设输入是由词向量组成的矩阵 ​X,Transformer 首先通过三个不同的线性变换,将 ​X 分别映射为三个新的矩阵:查询矩阵 ​Q、键矩阵 ​K和值矩阵 ​V,即:​

​其中,​W^Q、​W^K、​W^V是可学习的权重矩阵 。​

得到 ​Q、​K、​V后,就可以按照 Attention 的基本公式计算自注意力:​

​经过这一系列计算,最终输出依然是与输入矩阵 ​X 维度相同的矩阵,只不过这个矩阵中的每个元素,都融合了序列中各个位置的信息,且根据相关性进行了加权 。

        2. 多头注意力(Multi-Head Attention)​

为了让模型能够从不同角度捕捉信息,Transformer 并没有只用一次注意力计算,而是采用了多头注意力机制。它并行地进行多次(通常是 ​h 次)自注意力计算,也就是多个 “头”(head) 。每个 “头” 都有自己独立的参数,对输入进行不同的变换和计算。​

对于每个 “头” ​i,其计算过程如下:​

head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

​其中,W_i^QW_i^KW_i^V是第 ​i个 “头” 对应的权重矩阵 。​完成所有 “头” 的计算后,将各个 “头” 的输出结果拼接起来,再通过一个线性变换得到最终的多头注意力输出:​

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O

​其中,​W^O 是用于线性变换的权重矩阵 。多头注意力机制使得模型能够捕捉到更丰富的语义信息和不同层面的依赖关系 。

        3. Transformer Block 的结构​

Transformer 的每一层由两部分组成:​

(1) 首先是多头自注意力层,在计算多头自注意力后,采用残差连接(Residual Connection)将输入直接加到多头自注意力的输出上,这样可以有效缓解梯度消失问题,帮助网络更好地训练。残差连接后再进行 LayerNorm(层归一化)操作,LayerNorm 的作用是对每个样本的特征进行归一化,使网络训练更加稳定 。​

(2) 接着是前馈神经网络(Feed Forward Network, FFN),它由两个线性变换和一个非线性激活函数组成,对多头自注意力的输出进一步处理 。同样,在前馈神经网络输出后,也会进行残差连接和 LayerNorm 。

        4. 位置编码(Positional Encoding)​

由于 Transformer 没有像 RNN 那样的循环结构来体现数据的顺序,所以引入了位置编码来表示序列中元素的顺序信息。位置编码是一个与输入词向量维度相同的向量,通过将其加到输入的词向量上,就可以为模型提供位置信息 。常用的位置编码公式如下:

PE(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{d_{model}}}})

PE(pos, 2i + 1) = cos(\frac{pos}{10000^{\frac{2i}{d_{model}}}})

其中,​pos 表示位置,i 表示维度索引,​d_{model} 是模型的维度 。通过这种方式生成的位置编码,能够让模型区分不同位置的元素,并且不同位置的编码之间具有一定的数学关系,便于模型学习 。

为什么 “Attention Is All You Need”?​

Transformer 之所以能够喊出 “Attention Is All You Need” 这句口号,是因为它凭借强大的注意力机制,实现了对传统序列处理模型的超越:​

  • 结构简化:丢掉了 RNN 复杂的循环结构,不再受限于顺序处理,极大地简化了模型架构,降低了模型训练和优化的难度 。​
  • 高效的信息交互:用注意力机制完成所有信息交互,每个元素都能直接关注到序列中的其他元素,能够更精准地捕捉长距离依赖关系和复杂的语义关联 。​
  • 并行计算:支持并行计算,不再像 RNN 那样一个时间步一个时间步地处理数据,能够一次性处理整个序列,大幅提升了训练和推理速度,尤其在处理大规模数据时优势明显 。​
  • 长文本处理能力:在处理长文本时,不会出现像 RNN 那样的梯度消失和长距离依赖难以捕捉的问题,对长文本的处理表现更加稳定和出色 。​

从 RNN 到 Attention 机制,再到 Transformer,深度学习领域在序列数据处理方向的每一次突破,都让我们离人工智能的 “智能” 本质更近一步。Transformer 凭借其创新性和高效性,开启了人工智能技术发展的新篇章,也为后续更多的研究和应用奠定了坚实的基础 。

完整案例

我们来构造一个简单的字符级序列任务:输入是表示两位数字加法的字符序列(如 "13+35"),输出是其结果(如 "48")。该任务可以便于模拟 Seq2Seq 场景,并能直观对比模型表现。

# comments: 数据集生成脚本
import randomdef generate_example():a = random.randint(0, 99)b = random.randint(0, 99)x = f"{a:02d}+{b:02d}"  # 输入例如 "06+25"y = str(a + b)         # 输出 "31"return x, y# 生成数据集
dataset = [generate_example() for _ in range(10000)]
train_set = dataset[:8000]
valid_set = dataset[8000:9000]
test_set = dataset[9000:]

代码实现 Seq2Seq RNN 和 Transformer 模型~

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader# comments: 定义字符映射
CHARS = ['0','1','2','3','4','5','6','7','8','9','+']
PAD_IDX = 0
CHAR2IDX = {c:i+1 for i,c in enumerate(CHARS)}
IDX2CHAR = {i:c for c,i in CHAR2IDX.items()}
VOCAB_SIZE = len(CHAR2IDX) + 1  # 包含 PADclass SeqAdditionDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):x, y = self.data[idx]x_idx = [CHAR2IDX[c] for c in x]y_idx = [CHAR2IDX[c] for c in y]return torch.tensor(x_idx), torch.tensor(y_idx)# comments: collate_fn for padding
def collate_fn(batch):xs, ys = zip(*batch)xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=PAD_IDX)ys = nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=PAD_IDX)return xs, ys# RNN Encoder
class RNNEncoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, n_layers):super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=PAD_IDX)self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, batch_first=True)def forward(self, src):embedded = self.embedding(src)outputs, hidden = self.rnn(embedded)return outputs, hidden# RNN Decoder
class RNNDecoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, n_layers):super().__init__()self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=PAD_IDX)self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, batch_first=True)self.fc_out = nn.Linear(hid_dim, output_dim)def forward(self, input, hidden):# input: [batch]input = input.unsqueeze(1)embedded = self.embedding(input)output, hidden = self.rnn(embedded, hidden)prediction = self.fc_out(output.squeeze(1))return prediction, hidden# comments: Seq2Seq 包装
class Seq2SeqRNN(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = src.size(0)trg_len = trg.size(1)trg_vocab_size = self.decoder.fc_out.out_featuresoutputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)enc_outputs, hidden = self.encoder(src)input = trg[:,0]for t in range(1, trg_len):output, hidden = self.decoder(input, hidden)outputs[:,t] = outputteacher_force = random.random() < teacher_forcing_ratiotop1 = output.argmax(1)input = trg[:,t] if teacher_force else top1return outputs

到这里,大家可以完整体验从 RNN 到 Transformer 的设计理念,并通过可视化体验 Attention 带来的优势。

大家在真实项目中,可以进一步结合 BERT、GPT 等预训练模型,实现更强大的性能提升噢。

http://www.lryc.cn/news/571236.html

相关文章:

  • k8s 收集event事件至Loki
  • Kafka 简介(附电子教程资料)
  • 云计算-Raft算法报告-raft与paxos对比
  • 【MySQL基础】表的功能实现:增删查改详细讲解
  • 第十七届山东省职业院校技能大赛中职组网络建设与运维赛项
  • php在线生成pdf选民证系统支持中文(小工具)
  • 【前端基础】摩天之建的艺术:html(下)
  • 数据库的查询
  • 游戏技能编辑器开发完全指南系统架构设计之技能编辑器整体架构
  • RISC-V向量扩展与GPU协处理:开源加速器设计新范式——对比NVDLA与香山架构的指令集融合方案
  • 【开源工具】Windows屏幕控制大师:息屏+亮度调节+快捷键一体化解决方案
  • 数字化零售如何全面优化顾客体验
  • 【SpringBoot】Spring Boot实现SSE实时推送实战
  • TDMQ CKafka 版事务:分布式环境下的消息一致性保障
  • 工业视觉应用开发教程(一)
  • KingbaseES在线体验平台:开启国产数据库学习新征程
  • Mybatis(XML映射文件、动态SQL)
  • 有趣的git
  • 机器学习项目微服务离线移植
  • 洪水风险图制作全流程:HEC-RAS 与 ArcGIS 的耦合应用
  • Rocky Linux 9 系统初始化与安全加固脚本
  • MySQL的Sql优化经验总结
  • 大模型知识库RAG框架,比如LangChain、ChatChat、FastGPT等等,哪个效果比较好
  • 执行 PGPT_PROFILES=ollama make run下面报错,
  • HTML知识全解析:从入门到精通的前端指南(上)
  • (OSGB转3DTiles强大工具)ModelSer--强大的实景三维数据分布式管理平台
  • 内测分发平台应用的异地容灾和负载均衡处理和实现思路?
  • 【前端基础】摩天之建的艺术:html(上)
  • 点云提取车道线 识别车道线
  • Rust 学习笔记:关于 OOP 和 trait 对象的练习题