当前位置: 首页 > news >正文

从RL的专业角度解惑 instruct GPT的目标函数

作为早期chatGPT背后的核心技术,instruct GPT一直被业界奉为里程碑式的著作。但是这篇论文关于RL的部分确写的非常模糊,几乎一笔带过。当我们去仔细审查它的目标函数的时候,心中不免有诸多困惑。特别是作者提到用PPO来做强化学习,但是那个目标函数却怎么看都和经典的PPO目标函数不大一样。网上关于这一点的解释资料也甚少,而且不免有理解错误的。所以,鉴于GPT技术在今天是如此的重要,我觉得有必要去把里面的一些误解澄清。这样,后人也可以更加透彻的理解这里面的核心思想,以及这篇文章所用的PPO和原始版本PPO之间的关联。

首先,我们来看原论文的目标函数(省略了pretrain约束的版本):

J(\theta)=E_{(x,y)\sim D_{\pi_\phi}}[r(x,y)-\beta log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})]

如果没有后面的惩罚项,这就是一个经典的策略梯度优化对象,我们可以直接把梯度算出来:

J(\phi)=E_{(x,y)\sim D_{\pi_\phi}}[r(x,y)]\approx E_{x\sim D_{\pi_\phi},y\sim \pi_\phi(\cdot|x)}[r(x,y)]=E_{x\sim D_{\pi_\phi}}[\sum_y\pi_\phi(y|x)r(x,y)]

\nabla_\phi J(\phi)=E_{x\sim D_{\pi_\phi}}[\sum_y\nabla_\phi\pi_\phi(y|x)r(x,y)]=E_{x\sim D_{\pi_\phi},y\sim \pi_\phi(\cdot|x)}[\nabla_\phi log \pi_\phi(y|x) r(x,y)]

接下来,经典的做法就是用采样来估计这个梯度,然后做梯度下降,用REINFORCE就可以优化。

但是REINFORCE和PPO最大的差异,在于对新老策略之间距离的约束,也就是KL项。这个项在某种意义上其实是改变了策略空间的度规,从而更自然的反应两个策略(概率分布)之间的真实距离(也就是自然梯度),而原始的REINFORCE之所以效果不好,是因为它默认选择用欧式度规,而这对描述概率分布之间的差异来说并不合适。

那么instruct GPT第一个令人困惑的问题来了,他的KL惩罚项在哪里?大多数人都是直觉上认为这个log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})就是KL项,但是这不够严谨,尽管KL的定义和两个分布的比值取对数有关。如果我们严格的把KL的定义写出来,它有如下形式:

KL[\pi_\phi(\cdot |x),\pi_{SFT}(\cdot |x)]=\sum_y\pi_\phi(y|x)log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})=E_{y\sim \pi_\phi(\cdot |x)}[log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})]

看到这里我们就发现了第一个端倪,这里其实是有一个近似的,而这个近似只有在抽样足够多的时候才成立:

E_{x\sim D_\phi,y\sim \pi_\phi(\cdot |x)}[log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})]\approx E_{(x,y)\sim D_\phi}[log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})]

所以这个KL项其实是被吸收到期望内部去了,而吸收的前提就是上面提到的这个近似。我们把这个KL项单独提出来,就得到了PPO的目标函数形式(注意,这里是KL形式,而非CLIP形式):

J(\phi)=E_{(x,y)\sim D_{\pi_\phi}}[r(x,y)]-\beta E_{(x,y)\sim D_{\pi_\phi}} [log(\frac{\pi_\phi(y|x)}{\pi_{SFT}(y|x)})]\approx E_{(x,y)\sim D_{\pi_\phi}}[r(x,y)]-\beta E_{x\sim D_{\pi_\phi}}[KL[\pi_\phi(\cdot|x), \pi_{SFT}(\cdot|x)]]

所以网络上所谓的“把KL惩罚直接加到reward上”的说法其实是不准确的,虽然在当前这个目标函数下这二者是等价的,但是一旦我们用类似于PPO中importance sampling的方法来处理这个目标函数,很多地方就说不通了。但是,当我们把它还原成这个标准形式后,我们就发现importance sampling其实不会作用在KL项上。

理解了上面说的,就会立马注意到另外一个令人困惑的地方:如果我们把\pi_{SFT}看作是PPO中的\pi_{old}, 那么这个KL惩罚项其实是和PPO中的KL惩罚项相反的

KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]\neq KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]

尽管这样并不会影响PPO算法的正确性,因为我们知道

KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]<\delta \Rightarrow KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]<\frac{\delta}{min_y\pi_{old}(y|x)}

尽管这两个KL都是衡量新策略​相对于旧策略的偏离程度,但是我们依然想搞清楚这二者之间的差异究竟是什么,我们又该在什么时候选择什么样的KL项呢?为了理解清楚这个问题,我们首先来需要注意到当新旧策略在单个数据点上出项差异的时候其实有两种情况:(\pi_{old}​:high,\pi_{\phi}​:low) 和 (\pi_{old}​:low,\pi_{\phi}​:high). 而这正好就对应了这两种KL惩罚项的作用对象。因为KL散度不具备对易性,所以一种KL只会对应的去作用于一种情况,而非二者兼备!

简单的说,当旧策略认为某个动作的概率高而新策略认为该动作的概率低时,KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]会对此进行惩罚,但是KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]却对此视而不见;同样的,当新策略认为某个动作的概率高而旧策略认为该动作的概率低时,KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]会进行惩罚, 但KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]会对此视而不见。

理解了这一点,我们就明白了KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]其实比KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]要更加严格且保守的,因为KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]主要惩罚新策略增加旧策略低概率动作的概率,从而确保新策略保守更新,保持旧策略的高质量特性。相对的,KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]主要惩罚新策略降低旧策略高概率动作的概率,但对新策略增加旧策略低概率动作的概率限制较少。所以说,PPO中的KL[\pi_{old}(\cdot|x), \pi_{\phi}(\cdot|x)]其实是更加鼓励新策略的exploration的,而instruct GPT中的KL[\pi_\phi(\cdot|x), \pi_{old}(\cdot|x)]则更侧重于保留经过监督微调策略的高质量特性,并不鼓励新策略过多的exploration和创新

http://www.lryc.cn/news/398654.html

相关文章:

  • location匹配的优先级和重定向
  • 观察矩阵(View Matrix)、投影矩阵(Projection Matrix)、视口矩阵(Window Matrix)及VPM矩阵及它们之间的关系
  • 谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码
  • 时序预测 | Matlab实现TCN-Transformer的时间序列预测
  • 没想到MySQL 9.0这么拉胯
  • 开源 Wiki 系统 InfoSphere 2024.01.1 发布
  • 1.Introduction to Spring Web MVC framework
  • Onnx 1-深度学习-概述1
  • 网络基础——udp协议
  • 分布式锁理解
  • Android Gradle 开发与应用 (十): Gradle 脚本最佳实践
  • c#获取本机的MAC地址(附源码)
  • sqlmap使用之-post注入、head注入(ua、cookie、referer)
  • XSS: 原理 反射型实例[入门]
  • Idea新增Module报错:sdk ‘1.8‘ type ‘JavaSDK‘ is not registered in ProjectJdkTable
  • 基于RHCE基础搭建简单服务
  • 威纶通触摸屏软件离线仿真时出现报错8000端口占用或服务器断线
  • CAS详解
  • 【笔记】虚拟机中的主从数据库连接实体数据库成功后的从数据库不同步问题解决方法2
  • 【每日一练】python类和对象现实举例详细讲解
  • 【学习css1】flex布局-页面footer部分保持在网页底部
  • Java中创建线程的几种方式
  • [A-04] ARMv8/ARMv9-Cache的相关策略
  • 【笔试常见编程题06】最近公共祖先、求最大连续bit数、二进制插入、查找组成一个偶数最接近的两个素数
  • 【工具分享】Gophish——网络钓鱼框架
  • “职业三大底层逻辑“是啥呢?
  • 飞睿智能无线高速uwb安全数据传输模块,低功耗、抗干扰超宽带uwb芯片传输速度技术新突破
  • 手把手教你从微信中取出聊天表情图片,以动态表情保存为gif为例
  • 【深度学习】图形模型基础(5):线性回归模型第三部分:线性回归模型拟合
  • 【Git 入门】初始化配置与新建仓库