当前位置: 首页 > news >正文

语言模型 RLHF 实践指南(一):策略网络、价值网络与 PPO 损失函数


在使用 Proximal Policy Optimization(PPO)对语言模型进行强化学习微调(如 RLHF)时,大家经常会问:

  • 策略网络的动作概率是怎么来的?
  • 价值网络的得分是如何计算的?
  • 奖励从哪里来?损失函数怎么构建?
  • 微调后的旧轨迹还能用吗?

这篇文章将以语言模型强化学习微调为例,结合实际实现和数学公式,深入解析 PPO 的关键计算流程。


1️⃣ 策略网络:如何计算动作概率?

策略网络 πθ(a∣s)\pi_\theta(a|s)πθ(as) 用于给出状态 sss 下采取动作 aaa 的概率。

对于语言模型(如 GPT)来说:

  • 状态 sss:Prompt(如“请介绍量子计算”)
  • 动作 aaa:生成的回答(如“量子计算是一种…”)

策略网络的输出是 token 级别的 logits,经 softmax 后得到概率:

outputs = model(input_ids)
logits = outputs.logits                         # [batch_size, seq_len, vocab_size]
probs = torch.softmax(logits, dim=-1)           # 得到 token 概率

对于一个完整回答,其概率为:

πθ(a1:T∣s)=∏t=1Tπθ(at∣s,a<t) \pi_\theta(a_{1:T} | s) = \prod_{t=1}^T \pi_\theta(a_t | s, a_{<t}) πθ(a1:Ts)=t=1Tπθ(ats,a<t)

该概率在 PPO 中用于计算策略概率比:

rt=πθ(at∣st)πθold(at∣st) r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt=πθold(atst)πθ(atst)


2️⃣ 价值网络:如何计算状态得分?

价值网络 Vϕ(s)V_\phi(s)Vϕ(s) 预测的是状态 sss 的期望累计奖励,即该 prompt + 回复的“好坏”。

实现方式通常是共享模型底座 + 线性输出层:

hidden_states = outputs.hidden_states         # [batch_size, seq_len, hidden_dim]
value = value_head(hidden_states).squeeze(-1) # 每个 token 对应一个值

通常使用最后一个 token 的 value 作为整段文本的状态值:

Vϕ(s)=value(last_token) V_\phi(s) = \text{value}(\text{last\_token}) Vϕ(s)=value(last_token)
也可以做 mean pooling 等方式。


3️⃣ 奖励函数:怎么定义?

PPO 是一个基于奖励优化的强化学习算法。对于语言模型,一般使用人工偏好、打分器、奖励模型(RM)来计算奖励 RRR

示例方式:

  • 高质量回答奖励高,例如 R=+4R = +4R=+4
  • 差的回答奖励低,例如 R=+1R = +1R=+1
  • 或者使用两个回复的相对排序值差距(ranking loss)

PPO 使用奖励和预测值来计算优势函数(Advantage):

A^t=Rt−Vϕ(st) \hat{A}_t = R_t - V_\phi(s_t) A^t=RtVϕ(st)

也可以用 GAE(广义优势估计)进一步平滑优势值。


4️⃣ PPO 策略损失函数:如何构建?

核心损失函数如下(Clipped Surrogate Objective):

Lpolicy=−Et[min⁡(rtA^t,clip(rt,1−ϵ,1+ϵ)A^t)] L^{\text{policy}} = -\mathbb{E}_t \left[ \min \left( r_t \hat{A}_t, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] Lpolicy=Et[min(rtA^t,clip(rt,1ϵ,1+ϵ)A^t)]

解释:

  • rtr_trt 是策略概率比
  • A^t\hat{A}_tA^t 是优势函数
  • ϵ\epsilonϵ 是截断系数(常用 0.2)

这个损失保证了策略更新不能偏离旧策略太远,防止训练不稳定。

🔍 第一次微调时,rt=1r_t = 1rt=1

由于初始时当前策略与旧策略相同,有:

rt=πθ(at∣st)πθold(at∣st)=1 r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} = 1 rt=πθold(atst)πθ(atst)=1

所以第一次策略更新实际变成:

Lpolicy=−A^t L^{\text{policy}} = -\hat{A}_t Lpolicy=A^t

相当于标准的策略梯度算法。


5️⃣ PPO 价值损失函数:如何构建?

价值网络使用均方误差损失来拟合奖励:

Lvalue=Et[(Vϕ(st)−Rt)2] L^{\text{value}} = \mathbb{E}_t \left[ \left( V_\phi(s_t) - R_t \right)^2 \right] Lvalue=Et[(Vϕ(st)Rt)2]

也可以加上 value clipping:

Lvalue-clipped=max⁡((Vϕ(st)−Rt)2,(clip(Vϕ(st),Vold−ϵ,Vold+ϵ)−Rt)2) L^{\text{value-clipped}} = \max\left( (V_\phi(s_t) - R_t)^2, (\text{clip}(V_\phi(s_t), V_{\text{old}} - \epsilon, V_{\text{old}} + \epsilon) - R_t)^2 \right) Lvalue-clipped=max((Vϕ(st)Rt)2,(clip(Vϕ(st),Voldϵ,Vold+ϵ)Rt)2)


6️⃣ 总损失函数:包含 entropy 奖励

完整的 PPO 损失函数通常为:

L=Lpolicy+c1⋅Lvalue−c2⋅H(πθ) L = L^{\text{policy}} + c_1 \cdot L^{\text{value}} - c_2 \cdot H(\pi_\theta) L=Lpolicy+c1Lvaluec2H(πθ)

  • H(πθ)H(\pi_\theta)H(πθ):策略的熵,用于鼓励探索(entropy bonus)
  • c1,c2c_1, c_2c1,c2:超参数,通常取 0.5 和 0.01

熵越高表示策略更随机,防止策略过早收敛到确定动作。


7️⃣ 微调后,旧轨迹还能继续用吗?

不能。PPO 是 on-policy 算法。

每轮策略更新后,旧轨迹(state, action, reward, old prob)就过时了,必须重新采样:

  • 旧策略生成的样本反映不了当前策略的行为
  • 若继续使用,会引入策略偏移(policy mismatch)

因此,PPO 的标准训练循环是:

  1. 用当前策略生成轨迹
  2. 固定轨迹,训练 N 个 epoch
  3. 更新策略后丢弃旧轨迹
  4. 重复采样新数据

✅ 总结回顾

项目内容说明
策略概率模型输出 logits → softmax 得到 token 概率
策略损失PPO clipped objective,基于概率比和优势函数
价值得分Value head 输出一个实数,预测状态期望奖励
奖励函数来自人工打分或奖励模型,指导优势函数计算
是否复用轨迹❌ 不能复用旧轨迹,策略更新后必须重新采样

🔚 写在最后

理解 PPO 中策略概率、价值得分、损失函数之间的关系,是成功实现 RLHF、SFT + RL 微调语言模型的基础。

这些原理不只是公式,更影响着你训练是否稳定、样本是否有效、微调是否收敛。


http://www.lryc.cn/news/582961.html

相关文章:

  • MySQL索引面试问题梳理
  • 【Android】组件及布局介绍
  • Flutter基础(前端教程②-卡片列表)
  • 【牛客刷题】小红的v三元组
  • 从单体到微服务:Spring Cloud 开篇与微服务设计
  • 音频主动降噪技术
  • 暑假算法日记第四天
  • Spring AI:检索增强生成(RAG)
  • 工作中的思考
  • Java教程:【程序调试技巧】入门
  • 项目Win系统下可正常获取Header字段,但是到了linux、docker部署后无法获取
  • 数据湖技术之Iceberg-03 Iceberg整合Flink 实时写入与增量读取
  • 【HarmonyOS】鸿蒙端云一体化开发入门详解 (一)
  • 深度剖析 Linux ip neigh:邻居表项的查看与添加实践
  • RabbitMQ第二章(RocketMQ的五大工作模式)
  • 二进制安全-汇编语言-04-第一个程序
  • 为什么elementui的<el-table-column label=“名称“ prop=“name“ label不用写成:label
  • Docker快速部署Hive服务
  • C++ 遍历可变参数的几种方法
  • 零基础|宝塔面板|frp内网穿透|esp32cam远程访问|微信小程序
  • 链表算法之【移除链表元素】
  • 【深度学习新浪潮】什么是上下文长度?
  • C++异步编程入门
  • 猿人学js逆向比赛第一届第十五题
  • Java面试基础:概念
  • 部署并运行Vim/Vmamba在ImageNet上的训练与测试
  • JavaScript之数组方法详解
  • (C++)list列表相关基础用法(C++教程)(STL库基础教程)
  • HTTP/3.x协议详解:基于QUIC的下一代Web传输协议
  • 音频被动降噪技术