当前位置: 首页 > news >正文

由浅入深学习大语言模型RLHF(PPO强化学习- v1浅浅的)

        最近,随着DeepSeek的爆火,GRPO也走进了视野中。为了更好的学习GRPO,需要对PPO的强化学习有一个深入的理解,那么写一篇文章加深理解吧。纵观网上的文章,要么说PPO原理,各种复杂的公式看了就晕,要么说各种方式命名的模型,再要么默认你是个NLPer。这导致RLer和NLPer之间学习大语言模型强化学习产生了巨大的gap。于是,我们单纯说说大语言模型里面的PPO吧。

        其实PPO也是在训练模型,和SFT一样,都是为了获得最终一个用于推理部署的模型。SFT训练模型时一般需要一个base【模型】和【损失函数】,这里先这么浅显地说,因为PPO也需要这些,我们通过这些相同的部分来弥补NLP和RL之间的gap吧。

一、模型

SFT(Only one model)

        SFT模型一般是一个已经预训练过的大语言模型(例如GPTs、BERT等),或者是一个未经训练的小模型(LSTM)。

PPO(Four Models)

        PPO训练时总共有四个模型分别是Policy Model(Actor)、Reward Model、Reference Model、Critic Model。这里和SFT模型相似的也就是Policy Model,这个模型也是经过预训练的模型且用于未来实际使用的模型。其他三个模型都是用来辅助Policy Model模型训练的。那么其他三个模型的作用是什么呢?Let's talk step by step.

        Reward Model

        Fine! 你肯定早就听过这个模型。我们说PPO是根据好的和坏的样本对来进行训练的,从来让模型产生输出好样本的偏好,那么如何识别好样本和坏样本呢?是的,依靠Reward Model。让我们看看下边的例子。

s1: 中国的首都是哪里?北京。->过于简洁,但正确,2分

s2: 中国的首都是哪里?中国的首都是北京。->比较中肯,3分

s3: 中国的首都是哪里?中国的首都不是广州和武汉,是北京。->很多废话,0分

s4: 中国的首都是哪里?中国的首都不是广州,是北京。->一点点废话,1分

        上面我按照自己的偏好给每个句子进行了打分。Reward Model在这里的作用就是学习我打分的风格,然后产生一个数值或者概率,这里可以用各种方法,先不说具体咋做,可以线按照你想的方法产生一个分数,然后我们接着往下走。好了,我们现在有一个模型可以产生奖励了,我们可以给Policy Model模型产生的输出打分了,然后对这个分数进行优化,即奖励较大时加大对损失的权重,奖励较大时给损失乘以一个较小的权重。这样,模型就可以达到L1级别(借用自动驾驶等级概念)的偏好学习了。

        看似我们的方法已经可以work了,但仅仅是看似。实际上在模型训练的过程中可能会因为Reward打分不准导致Policy Model训练出现偏差或者Policy Model过于追求奖励大的而出现性能下降的现象。

        所以这么不稳定的训练,需要再加入另外一个模型Reference Model来维护一下训练的稳定性。

        Reference Model

        Reference Model被用于维持训练的稳定性。我们知道PPO被用在大语言模型是为了维持模型回答的风格,这种风格应该是朝着某种方向去的。比如,我们需要训练一个模型,他的风格需要是安全型的,即在特定情况下,他应该对用户的输入做出拒答。但是当训练不稳定时,模型可能对用户所有的输入都做出拒答。显然,这不是我们想要的模型效果。所以,我们需要一个基准模型,这个模型给Policy Model当作参考,告诉他不要在训练的时候偏离基准模型太远,即保留基准模型的一些能力。

        于是Reference Model(参考模型)呼之欲出,那么这个与Reference Model控制距离的方法如何实现呢?我们简单的猜一下,控制距离的方法。我这里给出一个简单的猜测,我们可以将输入同时送入到Reference Model和Reward Model中,然后根据两个句子输出的logits计算距离,如果距离过大时应该被拉近,距离适当时可以保持。当然,作为Reference Model在训练的时候是不需要更新参数的,不然就被一起拉着跑偏了。

        看着我们的模型可以训练起来了,正式进入L2级别。

        Critic Model

        没错,我们的模型其实完全可以训练了,至于为什么要多次一举,加个Critix Model。我也不是特别的理解。那么,让我们问问DeepSeek吧。

 

        DeepSeek告诉我们Critic Model可以降低方差?那么为什么呢?为什么前面的方法会出现高方差,以及Critic Model是如何降低方差的呢?

        我们打个比方,比如我们日常在与人交流的时候,可能一不小心说出了话,让他人不开心,后来通过各种方法找补回来,让别人理解了我们的内心想法。先说错话找补回来让对方理解和直接让对方理解我们真实的想法,这二者最后的结果是一样的(即奖励,Reward Model的打分),但是过程是不一样的。这个Critic Model的作用可以类比为教我们如何正确的表达,而不仅仅是会说的对。

        至于优势估计和价值引导,可以看到优势估计可以衡量特定动作(输出特定token)对于平均情况的优势,价值引导是提供长期回报(即输出某个token的长期回报),这两者也都是面向token级别(Critic Model),而不是句子级别(Reward Model给整个句子打分)的优化。


       至此,我们理解了这四个模型的大致作用,下面我们从具体说说是PPO如何做的。

二、损失函数

        施工中...

http://www.lryc.cn/news/539839.html

相关文章:

  • 网络安全三件套
  • 瑞芯微RV1126部署YOLOv8全流程:环境搭建、pt-onnx-rknn模型转换、C++推理代码、错误解决、优化、交叉编译第三方库
  • 【ISO 14229-1:2023 UDS诊断(会话控制0x10服务)测试用例CAPL代码全解析⑤】
  • python-leetcode 35.二叉树的中序遍历
  • glob 用法技巧
  • CodeGPT 使用教程(适用于 VSCode)
  • 以下是MySQL中常见的增删改查语句
  • Vue3 与 TypeScript 实战:核心细节与最佳实践
  • 23种设计模式 - 解释器模式
  • 常用的 React Hooks 的介绍和示例
  • ChatGLM-6B模型
  • 编译安装php
  • 【JavaEE进阶】Spring MVC(3)
  • 30 款 Windows 和 Mac 下的复制粘贴软件对比
  • 【LLAMA】羊驼从LLAMA1到LLAMA3梳理
  • 【OS安装与使用】part3-ubuntu安装Nvidia显卡驱动+CUDA 12.4
  • 【蓝桥杯集训·每日一题2025】 AcWing 6123. 哞叫时间 python
  • JAVA中常用类型
  • 【办公类-90-02】】20250215大班周计划四类活动的写法(分散运动、户外游戏、个别化综合)(基础列表采用读取WORD表格单元格数据,非采用切片组合)
  • 求矩阵对角线元素的最大值
  • NoSQL之redis数据库
  • 【R语言】非参数检验
  • 【力扣Hot 100】栈
  • HTTP 与 HTTPS:协议详解与对比
  • C++编程语言:抽象机制:模板和层级结构(Bjarne Stroustrup)
  • 建筑兔零基础自学python记录22|实战人脸识别项目——视频人脸识别(下)11
  • 在使用export default 导出时,使用的components属性的作用?
  • 以太网交换基础(涵盖二层转发原理和MAC表的学习)
  • Vue 实现通过URL浏览器本地下载 PDF 和 图片
  • 【2025最新计算机毕业设计】基于SpringBoot+Vue非遗传承与保护研究系统【提供源码+答辩PPT+文档+项目部署】