当前位置: 首页 > news >正文

什么?穷哥们没钱RLHF?跟我一起DPO吧,丐版一样用

本次DPO训练采用TRL的方式来进行训练

Huggingface TRL是一个基于peft的库,它可以让RL步骤变得更灵活、简单,你可以使用这个算法finetune一个模型去生成积极的评论、减少毒性等等。

本次进行DPO的模型是一个500M的GPT-2,目的是训练快,少占资源,快速看到结果。

下载Tokenizer:

from transformers import AutoTokenizer

AutoTokenizer.from_pretrained('gpt2').save_pretrained('tokenizer/gpt2')

  下载Datasets:

from datasets import load_dataset

load_dataset('b-mc2/sql-create-context').save_to_disk(

'dataset/b-mc2/sql-create-context')

下载Model:

from transformers import AutoModelForCausalLM

AutoModelForCausalLM.from_pretrained('gpt2').save_pretrained('model/gpt2')

图片

图 下载Tokenizer,model,数据

首先我们看一下原始数据集,原始数据集的构成分为3部分,一个是question,代表想提出的问题,一个是answer代表回答,第三部分是context代表参考的表结构。

图片

图 原始数据集

图片

图 数据集样例

实际数据样例,我们进一步规范了三种数据类型:

·第一个prompt,包含了context表结构和问题。

·第二个chose,表示希望训练之后的模型按着什么范式来回答问题。

·第三个reject,表示不希望用什么方式来回答,这里就留空了,代表隐式确认,如果有条件也可以整理不喜欢的回答范式。

这个训练的目的就是不管回答什么问题,都要用SQL语句的形式来回答,强调一种受欢迎回答的范式,这也是RLHF/DPO训练的主要目的。

下面开始训练部分,首先load tokenizer。

图片

图8-9 load tokenizer

按照需求来整理数据格式。

图片

图 整理数据格式

读取模型。

from transformers import AutoTokenizer

import random

import torch

tokenizer = AutoTokenizer.from_pretrained('/data2/DPO/tokenizer/gpt2')

tokenizer.pad_token_id = 0

tokenizer

from transformers import AutoModelForCausalLM

model_dpo = AutoModelForCausalLM.from_pretrained('/data2/DPO/model/gpt2').to('cuda')

model_dpo_ref = AutoModelForCausalLM.from_pretrained('/data2/DPO/model/gpt2').to('cuda')

先做个测试看看模型目前是怎么回答的。

图片

图 训练前的回答方式

如上图所示,很显然这个回答方式不是我们要求的方式,我们需要它把问题都按着SQL语句来进行回答。

最后一步就是正式训练了。

图片

图片

图片

如上图所示,随着训练的开展,模型回复对话的方式,基本就越来越向着正规SQL的方向演进。

这就是DPO训练所达成的目的。

图片

也没有多废资源,我是点auto-map技能点了,正常也就一张A100够了。

http://www.lryc.cn/news/331666.html

相关文章:

  • 【Leetcode笔记】102.二叉树的层序遍历
  • 进程的状态
  • spring-boot集成websocket
  • 【Python】【Flask】提交表单后报500错误
  • Golang vs Java
  • HomePlug AV
  • 【面试八股总结】超文本传输协议HTTP(二)
  • SQL Server中视图使用子查询的性能影响与优化方案
  • Adaboost集成学习 | Matlab实现基于SVM-Adaboost支持向量机结合Adaboost集成学习时间序列预测(股票价格预测)
  • Apache DolphinScheduler 【安装部署】
  • 【随笔】Git -- 高级命令(上篇)(六)
  • java中Date类,SimpleDateFormat类和Calendar类
  • 施耐德 PLC 控制系统 产品 + 软件总体介绍 2020
  • UniApp 应用发布到苹果商店指南
  • KamaCoder 46. 携带研究材料(第六期模拟笔试)
  • MySQL的基本操作(超详细)
  • 自动驾驶之心规划控制笔记
  • Linux中部署Java jar 包 shell 脚本
  • auto.js v1.4.4 实现自动打卡
  • 【Linux实验室】NFS、DHCP的搭建
  • Samba 总是需要输入网络凭证
  • 图像处理_积分图
  • B/S架构SaaS模式 医院云HIS系统源码,自主研发,支持电子病历4级
  • (C)1005 继续(3n+1)猜想
  • 编译好的C++应用程序拷贝到其它电脑,提示dll未找到依赖项的解决方法。
  • wps 开发插件
  • C语言----数据在内存中的存储
  • 【Linux学习】Linux 的虚拟化和容器化技术
  • Delphi 是一种内存安全的语言吗?
  • golang语言系列:Scrum、Kanban等敏捷管理策略