当前位置: 首页 > article >正文

如何使用DeepSpeed来训练大模型

🔥 DeepSpeed是什么?

  • DeepSpeed 是微软开源的一个 分布式训练加速库

  • 能帮助我们:

    • 高效训练大模型(百亿、千亿参数规模)

    • 节省显存、加速训练

    • 支持 ZeRO 优化器、Offload、混合精度(FP16/BF16)、梯度累积

    • 快速启动多机多卡训练

总结一句话:

DeepSpeed = 大模型训练神器,尤其适合 SFT、预训练、微调阶段。


🌈 DeepSpeed 安装

1️⃣ 安装基础依赖

通常只需要:

pip install deepspeed

对于更大规模训练,可以加上:

pip install deepspeed[all]

确保安装了 PyTorch >= 1.12。


🚀 DeepSpeed 快速上手(训练脚本改造)

2️⃣ 修改训练脚本(以 PyTorch / Hugging Face 为例)

🧩 (1)DeepSpeed CLI 启动

假设你已经有一个 train.py(PyTorch训练脚本):

deepspeed train.py --deepspeed ds_config.json
  • ds_config.json:DeepSpeed配置文件(稍后详细讲)。

🧩 (2)代码适配(只需两步!)

✅ a. 导入 deepspeed

import deepspeed

✅ b. 替换优化器 & 模型初始化:

model_engine, optimizer, _, _ = deepspeed.initialize(args=your_args,model=model,optimizer=optimizer,model_parameters=model.parameters(),config="ds_config.json"
)

✅ c. 训练 loop 改为:

for batch in dataloader:outputs = model_engine(batch)loss = outputs.lossmodel_engine.backward(loss)model_engine.step()

🎯 小结:只需 initializemodel_engine 替换,几行代码搞定!


🔍 DeepSpeed配置文件(ds_config.json)详解

这是 DeepSpeed 的核心,控制训练的优化策略。常见配置如下:

{"train_batch_size": 32,"train_micro_batch_size_per_gpu": 4,"gradient_accumulation_steps": 8,"zero_optimization": {"stage": 2,"offload_optimizer": {"device": "cpu"},"offload_param": {"device": "cpu"}},"fp16": {"enabled": true},"gradient_clipping": 1.0,"steps_per_print": 100,"wall_clock_breakdown": false
}

⚙️ 常见配置解释:

参数含义推荐值 / 建议
train_batch_size全局 batch size必须设置
train_micro_batch_size_per_gpu每个GPU的 batch size看显存而定
gradient_accumulation_steps梯度累积步数train_batch_size / (num_gpus * micro_batch_size)
zero_optimizationZeRO 优化器stage 1/2/3
offload_optimizer优化器 offload省显存,慢一点
offload_param参数 offloadstage 3 时常用
fp16 / bf16混合精度true
gradient_clipping梯度裁剪1.0


📦 Hugging Face 🤗 集成 DeepSpeed

Hugging Face Transformers 已原生支持 DeepSpeed!
只需在 trainer 里加上 --deepspeed 参数即可!

✅ 步骤:
1️⃣ 准备 ds_config.json
2️⃣ 命令行运行:

accelerate config  # 配置训练
accelerate launch --multi_gpu --deepspeed ds_config.json train.py

✅ 代码示例:

from transformers import Trainer, TrainingArgumentstraining_args = TrainingArguments(output_dir="./results",per_device_train_batch_size=2,per_device_eval_batch_size=2,gradient_accumulation_steps=8,fp16=True,deepspeed="ds_config.json",  # 只需加这一行!
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,
)trainer.train()


💡 高级技巧

ZeRO-3 + Offload
最大化节省显存(即使只有 24GB 显卡也能训练 65B 模型!)
Activation Checkpointing
减少显存占用,开启方式:

"activation_checkpointing": {"partition_activations": true,"contiguous_memory_optimization": true
}

梯度累积
模拟大 batch size,显存不够时的必杀技。
DeepSpeed Inference Engine
支持推理加速,适合部署阶段。


🌳 项目结构

sft_project/
├── data/
│   ├── train.jsonl
│   └── val.jsonl
├── model/
│   └── (预训练模型文件夹,如LLaMA、Baichuan)
├── deepspeed_config/
│   └── ds_config.json
├── train.py
├── requirements.txt
└── README.md

🎓 总结

你想做什么?如何用DeepSpeed?
训练大模型deepspeed 启动,写好 ds_config.json
不想改代码Hugging Face Trainer + --deepspeed 参数
显存不够开启 ZeRO-3 + Offload + FP16/BF16
多机多卡训练deepspeed --num_gpus=8accelerate launch
部署DeepSpeed Inference 加速推理

http://www.lryc.cn/news/2393716.html

相关文章:

  • 道可云人工智能每日资讯|《北京市人工智能赋能新型工业化行动方案(2025年)》发布
  • Unity 中实现首尾无限循环的 ListView
  • mongodb集群之副本集
  • 基于微服务架构的社交学习平台WEB系统的设计与实现
  • window10下docker方式安装dify步骤
  • Spark SQL进阶:解锁大数据处理的新姿势
  • 放假带出门的充电宝买哪种好用耐用?倍思超能充35W了解一下!
  • 云原生DMZ架构实战:基于AWS CloudFormation的安全隔离区设计
  • 小工具合集
  • AI智能体策略FunctionCalling和ReAct有什么区别?
  • 改进自己的图片 app
  • docker不用dockerfile
  • Uniapp+UView+Uni-star打包小程序极简方案
  • 深度学习篇---Pytorch框架下OC-SORT实现
  • STM32 HAL库SPI读写W25Q128(软件模拟+硬件spi)
  • 算法题(159):快速幂
  • 【新品发布】嵌入式人工智能实验箱EDU-AIoT ELF 2正式发布
  • 基于javaweb的SpringBoot体检管理系统设计与实现(源码+文档+部署讲解)
  • Mac Python 安装依赖出错 error: externally-managed-environment
  • Docker Desktop for Windows 系统设置说明文档
  • C++高级编程深度指南:内存管理、安全函数、递归、错误处理、命令行参数解析、可变参数应用与未定义行为规避
  • 【下拉选项数据管理优化实践:从硬编码到高扩展性架构】
  • IPD的基础理论与框架——(四)矩阵型组织:打破部门壁垒,构建高效协同的底层
  • 深度学习篇---OC-SORT实际应用效果
  • 讲述我的plc自学之路 第十一章
  • OpenLayers 图形绘制
  • 小程序为什么要安装SSL安全证书
  • python打卡训练营打卡记录day40
  • 互联网大厂Java求职面试:Spring Boot 3.2+自动配置原理、AOT编译及原生镜像
  • 小型图书管理系统案例(用于spring mvc 实践)