当前位置: 首页 > news >正文

解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:
先在训练文件中给trainer加一个callback

from smoe.callbacks.save_model import SchedulerStateCallback
trainer.add_callback(SchedulerStateCallback)
class SchedulerStateCallback(TrainerCallback):def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):if os.environ.get("RANK", "0") == "0":#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = scheduler.state_dict()#save_path = os.path.join(args.output_dir, SCHEDULER_NAME)# 使用 PREFIX_CHECKPOINT_DIR 和 global_step 创建检查点目录名checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"# 完整的检查点目录路径checkpoint_path = os.path.join(args.output_dir, checkpoint_folder)# 如果目录不存在,则创建它if not os.path.exists(checkpoint_path):os.makedirs(checkpoint_path)# 完整的保存路径save_path = os.path.join(checkpoint_path, SCHEDULER_NAME)# 保存scheduler状态torch.save(scheduler_state, save_path)def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):# 如果resume_from_checkpoint设置了有效路径if args.resume_from_checkpoint is not None:load_path = os.path.join(args.resume_from_checkpoint, SCHEDULER_NAME)# 如果该路径下有保存的调度器状态,则加载它if os.path.exists(load_path):#scheduler = kwargs['lr_scheduler']scheduler = kwargs.get("lr_scheduler")if scheduler is None:return scheduler_state = torch.load(load_path)scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

http://www.lryc.cn/news/162388.html

相关文章:

  • Linux ipc通信(消息对列)
  • 【计算机网络】 ARP协议和DNS协议
  • 【逐步剖C++】-第一章-C++类和对象(上)
  • 索尼 toio™ 应用创意开发征文|探索创新的玩乐世界——索尼 toio™
  • 企业架构LNMP学习笔记23
  • 第六章 图 五、图的深度优先遍历(DFS算法)
  • React 中的 useLayoutEffect 钩子函数
  • upload-labs1-21关文件上传通关手册
  • MATLAB遗传算法求解生鲜货损制冷时间窗碳排放多成本车辆路径规划问题
  • 界面控件DevExpress .NET应用安全 Web API v23.1亮点:支持Swagger模式
  • SpringMVC之CRUD------增删改查
  • 微信小程序开发教学系列(4)- 抖音小程序组件开发
  • RabbitMQ反序列化失败:Failed to convert message
  • CTFSHOW 年CTF
  • 肖sir__设计测试用例方法之状态迁移法05_(黑盒测试)
  • 无涯教程-JavaScript - IMPRODUCT函数
  • yapi以及gitlab的容器化部署
  • TCP、UDP 协议的区别,各自的应用场景
  • C高级 DAY3
  • Linux CentOS7命令及命令行
  • 【C++入门到精通】C++入门 ——搜索二叉树(二叉树进阶)
  • 学成在线-网站搭建
  • stm32同芯片但不同flash工程更换Device出现报错
  • Element UI实现每次只弹出一个Message消息提示
  • 「网页开发|前端开发|Vue」04 快速掌握开发网站需要的Vue基础知识
  • 解决Redis分布式锁主从架构锁失效问题的终极方案 含面试题
  • 建站系列(三)--- 网络协议
  • jetson orin nx无显示器启动
  • 【APUE】标准I/O库
  • es6---模块化