当前位置: 首页 > news >正文

抽丝剥茧,一步步推导“大模型强化学习的策略梯度公式”

大模型LLM强化学习的梯度公式是:
∇θJ(θ)=Eτ∼πθ[∑t=0T−1∇θlog⁡πθ(at∣st)A(st,at)] \nabla_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T-1} \nabla_{\theta} \log \pi_{\theta}(a_t | s_t) A(s_t, a_t) \right] θJ(θ)=Eτπθ[t=0T1θlogπθ(atst)A(st,at)]
第一眼看到让人感到不知所云,其实它源自一个非常简单的公式,如下:
J(θ)=Eτ∼πθR(τ) J(\theta) = \underset{\tau \sim \pi_{\theta}}{\mathbb{E}} R(\tau) J(θ)=τπθER(τ)
τ\tauτ 是大模型 πθ\pi_{\theta}πθ 生成的动作轨迹 (s0,a0,s1,a1,⋯ ,sT−1,aT−1)(s_0, a_0, s_1, a_1, \cdots, s_{T-1}, a_{T-1})(s0,a0,s1,a1,,sT1,aT1)θ\thetaθ 是网络参数,s0s_0s0 就是用户给的 prompt,每个aaa 代表模型输出的 token,它结合前面的 sss 就组成下一个 sss

强化学习的目标是要最大化轨迹分数 R(τ)R(\tau)R(τ) 的期望 ,也就是最大化目标函数 J(θ)J(\theta)J(θ)。下面从 J(θ)J(\theta)J(θ)θ\thetaθ 的梯度开始,一步步推导出最终可以用于模型训练实操的策略梯度公式,每一行的解读见后。

在这里插入图片描述

下面是每行推导的解读:

(1) 这里需要对网络参数 θ\thetaθ 求梯度,但是 θ\thetaθ 藏在概率分布里,不好求。

(2) 按照数学期望的定义展开,这样 θ\thetaθ 就能被看见了。

(3) 求导可以移到内部,因为只有 P(τ,θ)P(\tau,\theta)P(τ,θ) 里有 θ\thetaθ

(4) 上一步的公式不具有实操性,我们不可能遍历所有的轨迹,计算每一条轨迹的生成概率和分数然后求和。蒙特卡洛采样能接受的只有数学期望,我们要把公式重新变回数学期望的形式,所以这里提出来一个P(τ,θ)P(\tau,\theta)P(τ,θ)

(5) 这里出现了 log⁡\loglog,这是一个数学技巧:函数的导数除以自己可以表示成函数对数的导数。无论是SFT的交叉熵损失公式还是这里强化学习的梯度公式,都可以看到这个对数概率梯度。对数概率梯度是神经网络优化里非常常见的一个概念

(6) 这样就可以还原出数学期望的表达形式。现在,只需要大量采样轨迹 τ\tauτ,计算每一条轨迹的概率梯度(怎么计算后面有解释)并乘以轨迹的得分,就可以获得用于更新参数的梯度了。

仔细看,(6)已经能反映强化学习策略梯度的灵魂了,也就是“对数概率梯度×权重”。对数概率梯度代表的是“为了增加轨迹 τ\tauτ 的生成概率,θ\thetaθ 应该朝这个方向调整”,那么对于所有的轨迹,哪些轨迹要加强,哪些轨迹要削弱呢?权重 R(τ)R(\tau)R(τ) 就是评判标准,它对每一条轨迹的对数概率梯度进行加权。

所以说,强化学习其实就是加权的SFT。并且强化学习的权重可正可负,轨迹也是自己生成的,相比SFT引入了更多样的负样本。

考虑一种特殊情况:权重 R(τ)R(\tau)R(τ) 等于常数,也就是任何轨迹分数都一样。就算不推导也应该能猜到,剩下的对数概率梯度的期望 Eτ∼πθ[∇θlog⁡P(τ,θ)]=0\underset{\tau \sim \pi_{\theta}}{\mathbb{E}} \left[\nabla_{\theta}\log{P(\tau,\theta)} \right]=0τπθE[θlogP(τ,θ)]=0 ,因为已经不需要优化了,目标函数 J(θ)J(\theta)J(θ) 的梯度应该是0。

更一般地,如果变量 XXX 与轨迹 τ\tauτ 无关,那么对数概率梯度乘以 XXX 的期望还是0,即 Eτ∼πθ[∇θlog⁡P(τ,θ)⋅X]=0\underset{\tau \sim \pi_{\theta}}{\mathbb{E}} \left[\nabla_{\theta}\log{P(\tau,\theta)}\cdot X \right]=0τπθE[θlogP(τ,θ)X]=0。这是因为 XXX 和对数概率导数是两个独立变量,XXX 可以单独提出来。这是对数概率梯度的一个重要性质,后面会用到。

(7) 这一步具体讨论怎么计算轨迹的概率和分数。

还记得 τ\tauτ 是动作轨迹 (s0,a0,s1,a1,⋯ ,sT−1,aT−1)(s_0, a_0, s_1, a_1, \cdots, s_{T-1}, a_{T-1})(s0,a0,s1,a1,,sT1,aT1) 吗,所以轨迹的概率就是每个动作概率 πθ(at∣st)\pi_{\theta}(a_t|s_t)πθ(atst) 的乘积(其实还有初始概率 P(s0)P(s_0)P(s0) 和环境转移概率 P(st+1∣st,at)P(s_{t+1}|s_t,a_t)P(st+1st,at) ,但它们都和 θ\thetaθ 无关,所以求导后消失了),乘积取对数后变成求和。

至于奖励分数,一般来说只有轨迹执行完后才会生成一个分数,强化学习的重点就是把这个延迟的奖励分数分配到每一个动作上R(st,at)R(s_t,a_t)R(st,at) 表示 sts_tst 状态下做出动作 ata_tat 的奖励分数。

(8) 上一步是两个求和相乘,根据乘法分配律,等价于先用每一个动作的对数概率梯度乘以总分数,再求和。

(9) (10) 对于时间步 ttt 采取的动作 ata_tat,我们可以把总得分拆成“与它无关的前序动作的得分+与它相关的后序动作得分”。为什么要这么拆?还记得前面说的吗,一个独立变量乘以对数概率梯度,它的数学期望是0。这样,通过剥离期望为0的部分,就能在不改变期望值的条件下减小方差。减小方差对于实际训练过程很有用,它能防止梯度爆炸,提升训练稳定性。

(11) 因为前序动作分数和当前动作的独立性,它不会影响当前轨迹的对数概率期望(也就是0),这一项消掉。严谨的证明这里省略掉了。

(12) 一般把 ∑k=tT−1R(sk,ak)\sum_{k=t}^{T-1}R(s_k,a_k)k=tT1R(sk,ak) 叫做动作 ata_tat动作价值函数,代表采取此动作后该轨迹上的未来奖励之和,用 Q(st,at)Q(s_t,a_t)Q(st,at) 表示,这就是强化学习里动作价值函数的来源。

(13) Q(st,at)Q(s_t,a_t)Q(st,at) 它是一个绝对的值,不是相对的值,它的方差可能过大。比如分数是99,99是大是小?分数99代表动作的优秀程度究竟是多大?直接把99作为权重乘到这个动作的对数概率梯度上去,这不太合理。比较合理的方式是找到这个状态 sts_tst 的基准分数,也就是状态价值函数 V(st)V(s_t)V(st) 作为基线,理论上它是这个状态下可能采取的所有动作的最终分数的平均值。Q(st,at)Q(s_t,a_t)Q(st,at) 相对于基线的差值更能反映动作的好坏。

由于 V(st)V(s_t)V(st) 只和状态 sts_tst 有关,和动作 ata_tat 无关,所以它和该动作的对数概率梯度也是独立变量,根据前面的讨论,减去它并不会改变最终期望,反而可以有效减小方差。

(14) Q(st,at)−V(st)Q(s_t,a_t)-V(s_t)Q(st,at)V(st) 一般表示成优势函数 A(st,at)A(s_t,a_t)A(st,at),它反映的是动作的相对好坏。至于怎么从轨迹的最终打分 R(τ)R(\tau)R(τ) 推出每一个动作的优势值,不同策略有不同的方法:

PPO (Proximal Policy Optimization) 会老老实实地用一个单独的模型来预测 V(st)V(s_t)V(st) ,然后用蒙特卡洛采样或一个神经网络来估计 Q(st,at)Q(s_t,a_t)Q(st,at),最后相减得到动作的优势值。

DeepSeek-R1的 GRPO (Group Relative Policy Optimization) 则用了一个简单的方法:生成一批轨迹(也叫一个组),把它们得分的平均值作为这组的基准,每一条轨迹得分相对这个基准的差值就是优势值,一条轨迹内的所有动作都共享这个相同的优势值。这样就不用单独训练一个模型来模拟状态价值函数。

http://www.lryc.cn/news/585251.html

相关文章:

  • manifest.json只有源码视图没其他配置
  • Monorepo 与包管理工具:从幽灵依赖看 npm 与 pnpm 的架构差异
  • php的原生类
  • 云、实时、时序数据库混合应用:医疗数据管理的革新与展望(中)
  • 安全领域的 AI 采用:主要用例和需避免的错误
  • 将Blender、Three.js与Cesium集成构建物联网3D可视化系统
  • Redis数据库基础篇章学习
  • 2025年NSSCTF-青海民族大学 2025 新生赛WP
  • 【Spring Boot】Spring Boot 4.0 的颠覆性AI特性全景解析,结合智能编码实战案例、底层架构革新及Prompt工程手册
  • 《棒球规则介绍》领队和主教练谁说了算·棒球1号位
  • Lookahead:Trie 树(前缀树)
  • 关于List.of()
  • 深度对比扣子(Coze) vs n8n
  • PyTorch笔记5----------Autograd、nn库
  • Android Jetpack Compose 状态管理介绍
  • 流程图设计指南|从零到一优化生产流程(附模板)
  • MySQL的使用
  • 斯坦福 CS336 动手大语言模型 Assignment1 BPE Tokenizer TransformerLM
  • 高速路上的 “阳光哨兵”:分布式光伏监控系统守护能源高效运转
  • 250630课题进展
  • 电力自动化的通信中枢,为何工业交换机越来越重要?
  • C++——构造函数
  • 数据库迁移人大金仓数据库
  • stm32-modbus-rs485程序移植过程
  • 微算法科技基于格密码的量子加密技术,融入LSQb算法的信息隐藏与传输过程中,实现抗量子攻击策略强化
  • 【AI大模型】RAG系统组件:向量数据库(ChromaDB)
  • 新作品:吃啥好呢 - 个性化美食推荐
  • QT跨平台应用程序开发框架(4)—— 常用控件QWidget
  • 【机器学习】保序回归平滑校准算法
  • AI在医疗影像诊断中的应用前景与挑战