抽丝剥茧,一步步推导“大模型强化学习的策略梯度公式”
大模型LLM强化学习的梯度公式是:
∇θJ(θ)=Eτ∼πθ[∑t=0T−1∇θlogπθ(at∣st)A(st,at)]
\nabla_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T-1} \nabla_{\theta} \log \pi_{\theta}(a_t | s_t) A(s_t, a_t) \right]
∇θJ(θ)=Eτ∼πθ[t=0∑T−1∇θlogπθ(at∣st)A(st,at)]
第一眼看到让人感到不知所云,其实它源自一个非常简单的公式,如下:
J(θ)=Eτ∼πθR(τ)
J(\theta) = \underset{\tau \sim \pi_{\theta}}{\mathbb{E}} R(\tau)
J(θ)=τ∼πθER(τ)
τ\tauτ 是大模型 πθ\pi_{\theta}πθ 生成的动作轨迹 (s0,a0,s1,a1,⋯ ,sT−1,aT−1)(s_0, a_0, s_1, a_1, \cdots, s_{T-1}, a_{T-1})(s0,a0,s1,a1,⋯,sT−1,aT−1),θ\thetaθ 是网络参数,s0s_0s0 就是用户给的 prompt,每个aaa 代表模型输出的 token,它结合前面的 sss 就组成下一个 sss。
强化学习的目标是要最大化轨迹分数 R(τ)R(\tau)R(τ) 的期望 ,也就是最大化目标函数 J(θ)J(\theta)J(θ)。下面从 J(θ)J(\theta)J(θ) 对 θ\thetaθ 的梯度开始,一步步推导出最终可以用于模型训练实操的策略梯度公式,每一行的解读见后。
下面是每行推导的解读:
(1) 这里需要对网络参数 θ\thetaθ 求梯度,但是 θ\thetaθ 藏在概率分布里,不好求。
(2) 按照数学期望的定义展开,这样 θ\thetaθ 就能被看见了。
(3) 求导可以移到内部,因为只有 P(τ,θ)P(\tau,\theta)P(τ,θ) 里有 θ\thetaθ。
(4) 上一步的公式不具有实操性,我们不可能遍历所有的轨迹,计算每一条轨迹的生成概率和分数然后求和。蒙特卡洛采样能接受的只有数学期望,我们要把公式重新变回数学期望的形式,所以这里提出来一个P(τ,θ)P(\tau,\theta)P(τ,θ) 。
(5) 这里出现了 log\loglog,这是一个数学技巧:函数的导数除以自己可以表示成函数对数的导数。无论是SFT的交叉熵损失公式还是这里强化学习的梯度公式,都可以看到这个对数概率梯度。对数概率梯度是神经网络优化里非常常见的一个概念。
(6) 这样就可以还原出数学期望的表达形式。现在,只需要大量采样轨迹 τ\tauτ,计算每一条轨迹的概率梯度(怎么计算后面有解释)并乘以轨迹的得分,就可以获得用于更新参数的梯度了。
仔细看,(6)已经能反映强化学习策略梯度的灵魂了,也就是“对数概率梯度×权重”。对数概率梯度代表的是“为了增加轨迹 τ\tauτ 的生成概率,θ\thetaθ 应该朝这个方向调整”,那么对于所有的轨迹,哪些轨迹要加强,哪些轨迹要削弱呢?权重 R(τ)R(\tau)R(τ) 就是评判标准,它对每一条轨迹的对数概率梯度进行加权。
所以说,强化学习其实就是加权的SFT。并且强化学习的权重可正可负,轨迹也是自己生成的,相比SFT引入了更多样的负样本。
考虑一种特殊情况:权重 R(τ)R(\tau)R(τ) 等于常数,也就是任何轨迹分数都一样。就算不推导也应该能猜到,剩下的对数概率梯度的期望 Eτ∼πθ[∇θlogP(τ,θ)]=0\underset{\tau \sim \pi_{\theta}}{\mathbb{E}} \left[\nabla_{\theta}\log{P(\tau,\theta)} \right]=0τ∼πθE[∇θlogP(τ,θ)]=0 ,因为已经不需要优化了,目标函数 J(θ)J(\theta)J(θ) 的梯度应该是0。
更一般地,如果变量 XXX 与轨迹 τ\tauτ 无关,那么对数概率梯度乘以 XXX 的期望还是0,即 Eτ∼πθ[∇θlogP(τ,θ)⋅X]=0\underset{\tau \sim \pi_{\theta}}{\mathbb{E}} \left[\nabla_{\theta}\log{P(\tau,\theta)}\cdot X \right]=0τ∼πθE[∇θlogP(τ,θ)⋅X]=0。这是因为 XXX 和对数概率导数是两个独立变量,XXX 可以单独提出来。这是对数概率梯度的一个重要性质,后面会用到。
(7) 这一步具体讨论怎么计算轨迹的概率和分数。
还记得 τ\tauτ 是动作轨迹 (s0,a0,s1,a1,⋯ ,sT−1,aT−1)(s_0, a_0, s_1, a_1, \cdots, s_{T-1}, a_{T-1})(s0,a0,s1,a1,⋯,sT−1,aT−1) 吗,所以轨迹的概率就是每个动作概率 πθ(at∣st)\pi_{\theta}(a_t|s_t)πθ(at∣st) 的乘积(其实还有初始概率 P(s0)P(s_0)P(s0) 和环境转移概率 P(st+1∣st,at)P(s_{t+1}|s_t,a_t)P(st+1∣st,at) ,但它们都和 θ\thetaθ 无关,所以求导后消失了),乘积取对数后变成求和。
至于奖励分数,一般来说只有轨迹执行完后才会生成一个分数,强化学习的重点就是把这个延迟的奖励分数分配到每一个动作上,R(st,at)R(s_t,a_t)R(st,at) 表示 sts_tst 状态下做出动作 ata_tat 的奖励分数。
(8) 上一步是两个求和相乘,根据乘法分配律,等价于先用每一个动作的对数概率梯度乘以总分数,再求和。
(9) (10) 对于时间步 ttt 采取的动作 ata_tat,我们可以把总得分拆成“与它无关的前序动作的得分+与它相关的后序动作得分”。为什么要这么拆?还记得前面说的吗,一个独立变量乘以对数概率梯度,它的数学期望是0。这样,通过剥离期望为0的部分,就能在不改变期望值的条件下减小方差。减小方差对于实际训练过程很有用,它能防止梯度爆炸,提升训练稳定性。
(11) 因为前序动作分数和当前动作的独立性,它不会影响当前轨迹的对数概率期望(也就是0),这一项消掉。严谨的证明这里省略掉了。
(12) 一般把 ∑k=tT−1R(sk,ak)\sum_{k=t}^{T-1}R(s_k,a_k)∑k=tT−1R(sk,ak) 叫做动作 ata_tat 的动作价值函数,代表采取此动作后该轨迹上的未来奖励之和,用 Q(st,at)Q(s_t,a_t)Q(st,at) 表示,这就是强化学习里动作价值函数的来源。
(13) Q(st,at)Q(s_t,a_t)Q(st,at) 它是一个绝对的值,不是相对的值,它的方差可能过大。比如分数是99,99是大是小?分数99代表动作的优秀程度究竟是多大?直接把99作为权重乘到这个动作的对数概率梯度上去,这不太合理。比较合理的方式是找到这个状态 sts_tst 的基准分数,也就是状态价值函数 V(st)V(s_t)V(st) 作为基线,理论上它是这个状态下可能采取的所有动作的最终分数的平均值。Q(st,at)Q(s_t,a_t)Q(st,at) 相对于基线的差值更能反映动作的好坏。
由于 V(st)V(s_t)V(st) 只和状态 sts_tst 有关,和动作 ata_tat 无关,所以它和该动作的对数概率梯度也是独立变量,根据前面的讨论,减去它并不会改变最终期望,反而可以有效减小方差。
(14) Q(st,at)−V(st)Q(s_t,a_t)-V(s_t)Q(st,at)−V(st) 一般表示成优势函数 A(st,at)A(s_t,a_t)A(st,at),它反映的是动作的相对好坏。至于怎么从轨迹的最终打分 R(τ)R(\tau)R(τ) 推出每一个动作的优势值,不同策略有不同的方法:
PPO (Proximal Policy Optimization) 会老老实实地用一个单独的模型来预测 V(st)V(s_t)V(st) ,然后用蒙特卡洛采样或一个神经网络来估计 Q(st,at)Q(s_t,a_t)Q(st,at),最后相减得到动作的优势值。
DeepSeek-R1的 GRPO (Group Relative Policy Optimization) 则用了一个简单的方法:生成一批轨迹(也叫一个组),把它们得分的平均值作为这组的基准,每一条轨迹得分相对这个基准的差值就是优势值,一条轨迹内的所有动作都共享这个相同的优势值。这样就不用单独训练一个模型来模拟状态价值函数。