当前位置: 首页 > news >正文

训练一个中文gpt2模型

前言

  1. 这是我的github上的一个介绍,关于如何训练中文版本的gpt2的。
  2. 链接为: https://github.com/yuanzhoulvpi2017/zero_nlp

介绍

  1. 本文,将介绍如何使用中文语料,训练一个gpt2
  2. 可以使用你自己的数据训练,用来:写新闻、写古诗、写对联等
  3. 我这里也训练了一个中文gpt2模型,使用了612万个样本,每个样本有512个tokens,总共相当于大约31亿个tokens

⚠️安装包

需要准备好环境,也就是安装需要的包

pip install -r requirements.txt

像是pytorch这种基础的包肯定也是要安装的,就不提了。

数据

数据来源

  1. 获得数据:数据链接,关注公众号【统计学人】,然后回复【gpt2】即可获得。
  2. 获得我训练好的模型(使用了15GB的数据(31亿个tokens),在一张3090上,训练了60多小时)

数据格式

  1. 数据其实就是一系列文件夹📁,然后每一个文件夹里面有大量的文件,每一个文件都是.csv格式的文件。其中有一列数据是content
  2. 每一行的content就代表一句话,截图如下
  3. 虽然数据有15GB那么大,但是处理起来一点也不复杂,使用 datasets
    包,可以很轻松的处理大数据,而我只需要传递所有的文件路径即可,这个使用 glob 包就能完成。

代码

⚙️训练代码train_chinese_gpt2.ipynb

⚠️注意

  1. 现在训练一个gpt2代码,其实很简单的。抛开处理数据问题,技术上就三点:tokenizergpt2_modelTrainer
  2. tokenizer使用的是bert-base-chinese
    ,然后再添加一下bos_tokeneos_tokenpad_token
  3. gpt2_model使用的是gpt2,这里的gpt2我是从0开始训练的。而不是使用别人的预训练的gpt2模型。
  4. Trainer训练器使用的就是transformersTrainer模块。(支撑多卡并行,tensorboard等,都写好的,直接调用就行了,非常好用)

📤推理代码infer.ipynb

⚠️注意

这个是chinese-gpt2的推理代码

  1. 将代码中的model_name_or_path = "checkpoint-36000"里面的"checkpoint-36000",修改为模型所在的路径。
  2. 然后运行下面一个代码块,即可输出文本生成结果
  3. 可以参考这个代码,制作一个api,或者打包成一个函数或者类。

🤖交互机器人界面chatbot.py

⚠️注意

  1. 修改代码里面的第4行,这一行值为模型所在的位置,修改为我分享的模型文件路径。
model_name_or_path = "checkpoint-36000"
  1. 运行
python chatbot.py
  1. 点击链接,即可在浏览器中打开机器人对话界面

更多

  1. 这个完整的项目下来,其实我都是全靠huggingface文档、教程度过来的.
  2. 我做的东西,也就是把Tokenizer改成中文的了,然后也整理了数据,别的大部分东西,都不是我做的了.
  3. 原文链接为https://huggingface.co/course/zh-CN/chapter7/6?fw=pt.

一起学习

其实,我更喜欢做应用,但是也要理解相关的背后原理,目前还在研究相关的gpt2原理还有相关的推理细节,这是我整理的链接,希望可以共同进步

  1. https://huggingface.co/blog/how-to-generate
  2. https://huggingface.co/gpt2
  3. https://huggingface.co/gpt2-large
http://www.lryc.cn/news/7322.html

相关文章:

  • python文件头规范和函数注释自动生成(pycharm)
  • Fluent Python 笔记 第 17 章 使用 future 处理并发
  • Android进阶之路 - StringUtils、NumberUtils 场景源码
  • 装备制造业数字化转型CRM系统解决方案(信息图)
  • CGAL 二维剖分
  • node.js+vue婚纱影楼摄影婚庆管理系统vscode项目
  • C语言 指针的新理解
  • 【向每个应用View中增加子控件 Objective-C语言】
  • 【FPGA】Verilog:组合电路设计 | 三输入 | 多数表决器
  • 【安全等保】安全等保二级和三级哪个高?哪个费用更高?
  • C++ STL学习记录(v1)
  • 开发中遇到的问题
  • Javascript笔记
  • Elasticsearch(ES)配置及优化
  • 一文看懂Java语言与Java生态圈
  • GitHub 上有什么嵌入式方面的项目?
  • 【C语言进阶】结构体、位段、枚举和联合
  • markdown和latex常用部分参考@注脚@链接跳转@csdn
  • Java 在二叉树中增加一行
  • kubernetes(k8s) 知识总结(第2期)
  • windows-Mysql的主从数据库同步设置
  • Docker逃逸
  • k8s项目部署
  • Modbus通信协议学习笔记
  • ubuntu重启、关机命令
  • Xshell 7 连接云服务器的步骤和出现的错误
  • Python多进程同步——文件锁
  • 实现 element-plus 表格多选时按 shift 进行连选的功能
  • 华为OD机试真题JAVA实现【考古学家】真题+解题思路+代码(20222023)
  • Spring3之基于Aspect实现AOP