当前位置: 首页 > news >正文

【LLM】扩散模型与自回归模型:文本生成的未来对决

  🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎

📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃

🎁欢迎各位→点赞👍 + 收藏⭐️ + 留言📝​

📣系列专栏 - 机器学习【ML】 自然语言处理【NLP】  深度学习【DL】

 🖍foreword

✔说明⇢本人讲解主要包括Python、机器学习(ML)、深度学习(DL)、自然语言处理(NLP)等内容。

如果你对这个系列感兴趣的话,可以关注订阅哟👋

目录

简介

理解自回归语言模型

数学公式

模型架构

AR 语言模型中的训练和推理

探索基于扩散的语言模型

数学公式

模型架构

训练和推理

正面交锋:比较分析

混合模型

未来趋势与结论


简介

现代语言模型 (LM) 一直以自回归 (AR)方法为主,该方法以从左到右的方式逐个生成文本标记。在这种范式中,模型学习根据前一个单词最大化下一个单词的似然值。这种方法已在 GPT-3、LLaMA 3 等模型中得到应用,并已被证明非常成功。然而,顺序生成意味着错误会传播,并且并行性受到限制。如果 AR 模型早期出现错误,则会影响所有后续标记,并且它无法轻松修改之前的标记。

相比之下,基于扩散的语言模型 (DLM)是一种借鉴于图像生成领域的较新理念。在扩散过程中,整个序列会通过多步噪声损坏,然后训练一个模型来逆转这种损坏。在视觉领域,扩散模型(如 DDPM)会向图像添加高斯噪声,然后学习去噪网络。对于文本,我们对此理念进行了调整,使文本逐渐加噪(例如,替换标记或扰动嵌入),然后恢复文本。简而言之,扩散模型通过逐步对损坏的句子进行去噪来同时生成整个文本,从而允许在序列中的任何位置进行并行更新和纠错。

理解自回归语言模型

自回归模型是现代自然语言处理(NLP)的基石,尤其是在文本生成领域。它们的设计本质上是顺序的,就像我们通常构建句子和叙述的方式一样。

自回归模型(仅解码器)的实际应用

自回归(AR)语言模型通过将标记序列分解为条件概率的乘积来学习该标记序列的概率:

下一个 token 的概率表达式

训练通过教师强制训练数据中的先前标记来最大化数据可能性(或等效地最小化交叉熵)。这种从左到右的因式分解是大型LM训练中的常见方法。AR模型通常使用RNN/LSTM或(现在更常见的是)Transformer解码器架构来实现这些条件概率。

数学公式

在 AR 模型中,序列的联合概率被逐步分解。对于 token x_1,…,x_T

下一个 token 的概率表达式

等效地,在训练中,最小化负对数似然(NLL)或交叉熵损失:

交叉熵损失表达式

其中p(θ)由模型的softmax输出给出。该公式假设每个令牌仅依赖于之前的令牌(因果掩码)。通过最大化对数似然,AR模型有效地学习预测下一个单词。这种方法允许并行训练,但在推理时强制顺序生成。

模型架构

现代 AR 语言模型几乎总是使用Transformer 解码器。典型的 GPT 风格模型的工作原理如下:每个输入 token 首先被映射到一个连续的嵌入向量,并添加一个固定的位置编码。这些嵌入经过 N 个堆叠的Transformer 块(每个块都包含带掩码的自注意力机制和前馈层),生成语境化的表示。最后,线性层和 softmax 为下一个 token 生成词汇的概率分布。

仅解码器模型

 成分

  • 嵌入和位置编码:输入标记(作为整数)经过嵌入查找;我们添加正弦或学习的位置向量,以便模型知道标记顺序。
  • 解码器层:每层由因果自注意力组成(标记t仅关注位置≤t
  • 输出投影:对最终位置的顶层输出进行投影和softmax处理,以预测下一个标记概率。

该架构自然地强化了AR 属性(通过因果注意机制),并且可以扩展到数十亿个参数。它还允许使用固定上下文窗口(例如2048 个标记),如果超出,则必须通过截断或缓存来处理。在生成过程中,模型在每个步骤中都会采样(或贪婪地选择)下一个标记并将其附加到上下文中。该过程迭代重复,一次构建一个标记,直到满足预定义的停止条件,例如生成序列结束 (EOS)标记或达到最大长度。

AR 语言模型中的训练和推理

 训练

训练使用教师强制算法:在每一步t,模型都会看到真实的先前标记x_<t并预测x_t。损失是交叉熵的总和。像Adam这样的优化器会调整模型权重,以尽量减少训练语料库的损失。这种最大似然过程简单明了,适用于大规模数据/计算,这解释了为什么AR LM能够带来强大的上下文学习和下游性能。

推理

在推理时,令牌是按顺序生成的。常见的策略包括贪婪解码(每次选择最高概率的令牌)或采样方法(例如,top-k/核采样)来引入多样性。波束搜索也用于一些AR设置(特别是翻译等任务),以探索多个高概率序列。
请注意,所有这些策略每次前向传递都会生成一个令牌,这使得长序列的生成相对较慢。自回归模型不能修改已经生成的令牌;每个选择都是最终的,并作为下一步的输入反馈。错误倾向于向前传播,严格的从左到右的性质意味着生成不能跨位置并行化。

探索基于扩散的语言模型

扩散模型的灵感源自非平衡热力学,这类似于观察油漆在水中逐渐扩散,直至均匀地着色的过程。在生成式人工智能的背景下,这个过程被逆转,从而产生数据。

不同类型的生成模型概述

扩散模型通过逐渐添加噪声来破坏数据,然后学习一个反向的去噪过程。在文本领域,基于扩散的语言模型将这一思想应用于标记序列。其核心概念是定义一个正向(加噪)过程q,逐步将数据转化为噪声,以及一个学习到的反向过程p_θ,用于恢复原始数据。

该过程包括两个主要阶段:

  • 前向(加噪)过程 ( q ):这是一个扩散阶段,其中噪声在一系列时间步长内逐渐添加到原始数据和干净数据中。对于文本,这意味着逐渐将干净的文本嵌入转换为更简单的分布,通常是高斯噪声。此过程系统地降低数据质量,直到其与随机噪声无法区分。
  • 逆向(去噪)过程 ( p_θ ):这是生成阶段,也是扩散模型创建内容的核心。模型学习逆向正向过程,逐步从噪声输入中去除噪声。通过迭代去噪,模型可以重建原始数据,或生成与学习到的数据分布一致的新的高质量样本。

与自回归模型不同,扩散 LM 可以在生成过程中联合更新所有标记。

数学公式

通过缓慢添加(去除)噪声来生成样本的正向(逆向)扩散过程的马尔可夫链。

在扩散语言模型中,每个文本序列(长度为L)首先被嵌入到一个连续空间中(例如,通过词嵌入)。前向马尔可夫链会逐渐破坏这个嵌入序列。例如,我们可以像在连续扩散模型中那样定义高斯步骤:

连续扩散中的高斯阶跃

因此,经过T步之后,数据几乎就成了纯噪声。我们可以将上述等式重写如下:

在哪里:

这里,z_0是干净的 嵌入序列,z_T是近随机高斯 噪声。在离散/标记扩散变体中,前向过程可能会通过随机替换或在词汇单纯形上使用可学习的转换来破坏标记,但其思想类似于通过许多小步骤添加噪声。

逆向(生成)模型p_θ 试图反转正向过程。通常,对于每个时间步t,神经网络(主要基于 Transformer)会获取噪声状态z_t ,并预测p_θ(z_{t-1} | z_t) 的参数(例如平均噪声)训练的目标是最小化p_θ与真实逆向条件q(z_{t-1} | z_t)之间的差异在实践中,这可以通过对目标函数进行去噪(例如,预测添加的高斯噪声)或最大化似然函数的变分界限来实现。

模型架构

扩散语言模型 (LM) 的架构多种多样,但它们都包含一些共同的元素:token 嵌入、噪声调度和去噪网络。典型的流程如下:嵌入 token,添加时间步长信息,然后运行一系列Transformer或U-net 层来预测去噪后的输出。

一些关键点:

  • 连续 vs. 离散:一些模型(例如 Plaid LM)将 token 嵌入到连续空间并使用高斯噪声。其他模型则直接对 token 概率的离散单纯形(独热向量)进行操作。例如,SSD-LM在自然词汇空间(单纯形)上进行扩散,而不是在潜在空间中进行扩散。
  • 时间嵌入:模型通常会对当前噪声时间戳t进行编码(通常通过正弦嵌入或学习嵌入),以便知道添加了多少噪声。这类似于Transformer 中的位置嵌入,但表示的是噪声级别。
  • 网络主干:许多扩散语言模型 (LM) 都是用于降噪器的Transformer 模块(类似于 AR 模型)。一些研究(在视觉领域)采用了U-Net 架构,但对于文本而言,主干通常是一个双向甚至是自回归的 Transformer,它将整个噪声序列作为输入。

总体而言,扩散语言模型 (LM) 的架构类似于一个编码器网络,它处理噪声序列并输出去噪过程的残差或下一步预测。该网络经过训练,以便在经过T 个去噪步骤(通常为50-200 步)后,恢复的序列能够与原始文本匹配。与增强现实 (AR) 解码器不同,扩散网络在训练和采样期间会同时查看(并更新)所有标记。

训练和推理

LLaDA 概念概述。(a)预训练。LLaDA 在文本上进行训练,所有标记均以相同的比例 t ∼ U[0, 1] 独立应用随机掩码。(b)SFT。只有响应标记可能被掩码。采样。LLaDA 模拟从 t = 1(完全掩码)到 t = 0(未掩码)的扩散过程,并在每一步同时预测所有掩码,并采用灵活的重新掩码策略。

训练

扩散语言模型的训练目标通常是最大似然或去噪分数匹配的形式。人们可以推导出数据似然的变分下限,但常见的简化方法是直接最小化去噪误差。本质上,在每个步骤t中,都会为模型提供示例x的噪声版本z_t,并学习预测原始x或添加的噪声。这可以实现为预测噪声的均方误差损失或离散级别的交叉熵。因此,该模型会学习在每一步中逆转前向损坏。由于前向过程是固定且已知的,因此梯度可以在所有时间步长中流经去噪网络,从而实现并行硬件上的端到端训练。

推理

要生成文本,我们需要从随机噪声向量z_T开始,并应用学习到的逆过程T步。具体来说,在步骤t处,我们有一个带噪声的嵌入z_t,模型预测z_{T-1}的参数(均值或方差) 。我们对其进行采样或取均值得到z_{T-1},并持续到z_0 。然后将结果z_0解码(通过最近标记或 softmax 方法)为离散文本。与 AR 模型不同,所有标记在每个去噪步骤中都会并行更新。

值得注意的是,扩散推理允许前瞻:模型可以在每一步修改序列的任何部分。理论上,这为规划或全局一致性提供了更大的灵活性。然而,每个序列的神经评估次数为T次,而增强现实 (AR)仅执行与标记数量相同的步骤。实际上,除非大幅减少步数,否则对于长文本,简单的扩散采样可能比增强现实(AR)更慢。

正面交锋:比较分析

自回归模型和扩散模型在文本生成方法上的根本区别导致了不同的性能特征和权衡。

A. 生成过程的根本差异

  • 自回归 (AR)[顺序和逐个标记]:AR 模型通常被描述为顺序讲故事的人。它们严格地从左到右构建文本,每次一个标记,每个新标记都根据之前出现的整个标记序列进行预测。
  • 扩散[并行和迭代细化]:扩散模型以迭代细化器的方式运行。它们不是按顺序生成,而是从整个序列的噪声表示开始,并通过多个步骤逐步进行去噪。

B. 速度和效率

  • 固定长度输出 vs 任意长度输出:对于非常短的响应,AR 模型速度更快,因为它们仅按顺序生成必要数量的标记。扩散模型虽然需要多个去噪步骤,但由于其并行特性,可以同时处理整个输出块,从而更快地生成固定长度的输出(尤其是较长的输出)。
  • 长上下文:AR 模型在处理长输入上下文时效率更高,因为它们可以利用键值缓存技术。相比之下,扩散模型在处理长上下文时会遇到困难,因为它们对整个标记块的迭代细化需要反复计算针对完整上下文的注意力,因为块的内容在每次去噪过程中都可能发生变化。

C. 质量和多样性

  • 连贯性和流畅性:自回归模型由于其严格的顺序条件,通常在局部流畅性和语法正确性方面表现出色。扩散模型通过迭代改进,可以实现卓越的全局连贯性和语境感知,因为它们可以完善整个输出。
  • 多样性:扩散模型理论上提供了更大的多样性和更好的模型覆盖率,因为它们的生成过程允许它们更广泛地探索数据分布,从而避免更多的崩溃。

D. 差异总结

DLM 与 AR-LM

混合模型

自回归模型和扩散模型各自固有的优缺点自然而然地推动了混合架构的发展。自回归模型在长上下文中表现出色,流畅性和效率更高,而扩散模型则具有强大的全局一致性、多样性潜力和细粒度控制能力。混合方法的目标是结合这些互补的优势,从而减轻各自的 局限性。

最近的几个模型介于这两个范式之间:

  • AR-Diffusion:由微软研究人员提出的 AR-Diffusion 算法,它为右侧的 token 分配了更多的去噪步骤,而为左侧的 token 分配了更少的去噪步骤。这使得较早出现的 token 先出现,然后再对较晚出现的 token 进行条件化,从而有效地将自回归依赖性重新引入到扩散过程中。
  • LongTextAR:该模型解决了一个特定的挑战,即在图像中渲染连贯的长文本。纯扩散模型由于上下文窗口的限制,往往难以完成这项任务。LongTextAR 利用自回归模型的能力,处理更长的文本序列,并将其与视觉生成功能相结合。

未来趋势与结论

展望未来,我们预计增强现实 (AR)与扩散技术之间将出现更多交叉融合。未来可能出现的趋势包括:

  • 减少步骤扩散:知识提炼或一致性模型等方法可以将扩散减少到几十步甚至单步(就像在图像扩散中所做的那样)
  • 文本的潜在扩散:类似于图像中的潜在扩散,在扩散之前将文本映射到紧凑的潜在空间可以提高效率。
  • 统一架构:我们可能会看到在顺序和并行生成模式之间流畅切换的架构,或者仅对深度模型的某些层使用扩散

总体而言,虽然自回归变换语言模型 (LM) 目前仍占主导地位,但基于扩散的方法正在迅速发展。它们在并行、全局、相干生成和模块化控制方面的潜力使其成为一个令人兴奋的前沿领域。融合了自回归和扩散优势的混合模型(例如 AR-Diffusion)已经预示着两者兼具的优势。未来的语言模型很可能不会仅仅局限于其中之一,而是会将扩散原理(例如基于噪声的扰动、分数匹配目标)融入到以自回归为主的框架中。这可能会催生出具有更快采样速度、更佳多样性和全新控制能力的模型。

http://www.lryc.cn/news/612813.html

相关文章:

  • 分布式事务与分布式锁
  • “物联网+职业本科”:VR虚拟仿真实训室的发展前景
  • USB枚举介绍 以及linux USBFFS应用demo
  • 抖音、快手、视频号等多平台视频解析下载 + 磁力嗅探下载、视频加工(提取音频 / 压缩等)
  • Go语言Ebiten坦克大战
  • JVM类加载
  • Redis中间件(三):Redis存储原理与数据模型
  • Spring MVC拦截器与过滤器的区别详解
  • Ubuntu24.04的“errors from xkbcomp are not fatal to the X server”终极修复方案
  • Ethereum:如何优雅部署 NPM 包中的第三方智能合约?
  • SpringBoot学习日记 Day5:解锁企业级开发核心技能
  • 90-基于Flask的中国博物馆数据可视化分析系统
  • 8- 知识图谱 — 应用案例怎么 “落地” 才有效?构建流程与行业实践全解析
  • LoRaWAN的网络拓扑
  • Kong vs. NGINX:从反向代理到云原生网关的全景对比
  • PCL提取平面上的圆形凸台特征
  • 阿里系bx_et加密分析
  • 构造函数:C++对象初始化的核心机制
  • 天猫商品评论API技术指南
  • uni-app X能成为下一个Flutter吗?
  • Flutter报错...Unsupported class file major version 65
  • C# 异步编程(async_await特性的结构)
  • PyTorch 核心三件套:Tensor、Module、Autograd
  • `/dev/vdb` 是一个新挂载的 4TB 硬盘,但目前尚未对其进行分区和格式化。
  • vscode 打开设置
  • Flutter 三棵树
  • 【物联网】基于树莓派的物联网开发【25】——树莓派安装Grafana与Influxdb无缝集成
  • CentOS 7 下通过 Anaconda3 运行llm大模型、deepseek大模型的完整指南
  • 人工智能的20大应用
  • 从Centos 9 Stream 版本切换到 Rocky Linux 9