当前位置: 首页 > news >正文

使用 GaLore 预训练LLaMA-7B

项目代码:

https://github.com/jiaweizzhao/galoreicon-default.png?t=O83Ahttps://github.com/jiaweizzhao/galore

参考博客:

https://zhuanlan.zhihu.com/p/686686751

创建环境

基础环境配置如下:

  • 操作系统: CentOS 7
  • CPUs: 单个节点具有 1TB 内存的 Intel CPU,物理CPU个数为64,每颗CPU核数为16
  • GPUs: 8 卡 A800 80GB GPUs
  • Python: 3.10 (需要先升级OpenSSL到1.1.1t版本(点击下载OpenSSL),然后再编译安装Python),点击下载Python
  • NVIDIA驱动程序
  • 版本: 515.125.06,根据不同型号选择不同的驱动程序,点击下载。
  • CUDA工具包: 11.8,点击下载

conda create -n GaLore python=3.10

安装依赖包

pip install -r requirements.txt

其中,requirements.txt 文件为:

torch==2.1.0
transformers==4.31.0
tokenizers
datasets==2.14.6
peft
wandb
loguru
nvitop
lion-pytorch
matplotlib
bitsandbytes
scipy
scikit-learn
evaluate

pip install tensorly

注意:Pytorch 需确保2.1.0以上,不然会报错。

数据集准备

本文使用 C4 数据集进行训练,C4 数据集是由 Google 提供的一个大型预训练数据集,用于训练语言模型。C4 数据集包含了数十亿个网页的内容,经过处理和清洗后形成了一个适合用于训练大型语言模型的数据集。这个数据集可以用于训练各种自然语言处理任务,如文本生成、文本分类

、语言建模等。语言建模

  • 下载地址:https://huggingface.co/datasets/allenai/c4/tree/main/en

由于数据集太大,这里只下载了一个文件大约356317条数据。

wandb 启用离线模式

启用离线模式后,wandb 将不会上传数据,但仍然会记录实验过程中的数据和结果。

wandb  offline
# W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B.

单张 4090 消费级显卡预训练 LLaMA-7B

接下来,使用单个 GPU(例如:NVIDIA RTX 4090)训练 7B 模型,您所需要做的就是指定 --optimizer=galore_adamw8bit_per_layer ,这会启用 GaLoreAdamW8bit 并进行每层权重更新。通过激活(梯度)检查点(activation checkpointing),您可以将在 NVIDIA RTX 4090 上测试的批量大小保持为 16。

执行命令:

CUDA_VISIBLE_DEVICES=3 torchrun --standalone --nproc_per_node 1 torchrun_main.py \--model_config configs/llama_7b.json \--lr 0.005 \--galore_scale 0.25 \--rank 1024 \--update_proj_gap 500 \--batch_size 16 \--total_batch_size 512 \--activation_checkpointing \--num_training_steps 150000 \--warmup_steps 15000 \--weight_decay 0 \--grad_clipping 1.0 \--dtype bfloat16 \--eval_every 1000 \--single_gpu \--optimizer galore_adamw8bit_per_layer

CUDA_VISIBLE_DEVICES=3 torchrun --standalone --nproc_per_node 1 torchrun_main.py --model_config configs/llama_7b.json --lr 0.005 --galore_scale 0.25 --rank 1024 --update_proj_gap 500 --batch_size 16 --total_batch_size 512 --activation_checkpointing --num_training_steps 150000 --warmup_steps 15000 --weight_decay 0 --grad_clipping 1.0 --dtype bfloat16 --eval_every 1000 --single_gpu --optimizer galore_adamw8bit_per_layer

好像是因为连不了外网所以没找到数据集:

解决方法,手动下载数据集,上传到服务器:

 下载地址:https://huggingface.co/datasets/allenai/c4/tree/main/en

同样,模型也要提前下好,放在指定位置:

http://www.lryc.cn/news/439353.html

相关文章:

  • gitlab无法push(pre-receive hook declined)
  • 物品识别——基于python语言
  • 【PostgreSQL】安装及使用(Navicat/Arcgis),连接(C#)
  • 第L6周:机器学习-随机森林(RF)
  • 【电路笔记】-差分运算放大器
  • git 命令---想要更改远程仓库
  • LeetCode:2848. 与车的相交点 一次遍历,时间复杂度O(n)
  • Xcode 16 RC (16A242) 发布下载,正式版下周公布
  • git 更换远程地址的方法
  • 9. 什么是 Beam Search?深入理解模型生成策略
  • Spring自定义注解
  • 微信小程序:wx.login或调用uni.login时报错the code is a mock one
  • URL的执行流程
  • 双指针算法专题(2)
  • 加密与安全_优雅存储用户密码的最佳实践
  • 【多线程】深入剖析线程池的应用
  • 『功能项目』切换职业面板【48】
  • 【EasyExcel】@ColumnWidth(value = 20) EasyExcel设置列宽不生效
  • CPU 和 GPU:为什么GPU更适合深度学习?
  • 【机器学习】:解锁数据背后的智慧宝藏——深度探索与未来展望
  • 【Kubernetes】常见面试题汇总(十八)
  • 无限边界:现代整合安全如何保护云
  • HTML贪吃蛇游戏
  • HTML 揭秘:HTML 编码快速入门
  • Ubuntu22.04系统安装opencv步骤简述及问题解决方法
  • 移情别恋c++ ദ്ദി˶ー̀֊ー́ ) ——13.mapset
  • 【webpack4系列】webpack基础用法(二)
  • Python Pyvis库创建交互式网络图 高级功能详解
  • Linux服务器上安装git lfs命令
  • S100A9:鸡支原体感染中的免疫调控“双面间谍”【AbMole】