当前位置: 首页 > news >正文

Pytorch FSDP权重分片保存与合并

注:本文章方法只适用Pytorch FSDP1的模型,且切分策略为FULL_STATE_DICT场景。

在使用FSDP训练模型时,为了节省显存通常会把模型权重也进行切分,在保存权重时为了加速保存通常每个进程各自保存自己持有的部分权重,避免先汇聚到主进程再保存浪费大量时间的问题。保存成分片权重后,如果需要推理则还需要将分片权重进行合并。下面提供了保存分片权重以及将分片权重合并的代码示例,代码主要参考accelerate官方源码。

import osimport torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utilsdef save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):# refer accelerate/utils/fsdp_utils.py:save_fsdp_modelwith FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):os.makedirs(fsdp_ckpt_path, exist_ok=True)state_dict = {"model": model.state_dict()}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),planner=DefaultSavePlanner(),)def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):# refer accelerate/utils/fsdp_utils.py:merge_fsdp_weightsstate_dict = {}dist_cp_format_utils._load_state_dict(state_dict,storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),no_dist=True,)# To handle if state is a dict like {model: {...}}if len(state_dict.keys()) == 1:state_dict = state_dict[list(state_dict)[0]]torch.save(state_dict, save_path)
http://www.lryc.cn/news/619268.html

相关文章:

  • 【C语言强化训练16天】--从基础到进阶的蜕变之旅:Day3
  • 【Qt开发】常用控件(三) -> geometry
  • 疏老师-python训练营-Day44预训练模型
  • php7 太空船运算符
  • Linux 软件编程:文件IO、目录IO、时间函数
  • 适配安卓15(对应的sdk是35)
  • RxJava 在 Android 中的深入解析:使用、原理与最佳实践
  • 大牌点餐接口api对接全流程
  • 《吃透 C++ 类和对象(中):构造函数与析构函数的核心逻辑》
  • Ubuntu22.04轻松安装Qt与OpenCV库
  • 药房智能盘库系统的Python编程分析与实现—基于计算机视觉与时间序列预测的智能库存管理方案
  • 基于大数据spark的医用消耗选品采集数据可视化分析系统【Hadoop、spark、python】
  • 分段锁和限流的间接实现
  • 通信中间件 Fast DDS(一) :编译、安装和测试
  • 机器学习—— TF-IDF文本特征提取评估权重 + Jieba 库进行分词(以《红楼梦》为例)
  • CMake进阶: 使用FetchContent方法基于gTest的C++单元测试
  • LINUX812 shell脚本:if else,for 判断素数,创建用户
  • 【GESP】C++一级知识点之【集成开发环境】
  • TF-IDF:信息检索与文本挖掘的统计权重基石
  • [SC]如何使用sc_semaphore实现对共享资源的访问控制
  • 初识神经网络04——构建神经网络2
  • 【从零开始java学习|第四篇】IntelliJ IDEA 入门指南
  • Redis序列化配置类
  • uni-app实战教程 从0到1开发 画图软件 (学会画图)
  • 基于STC8单片机的RTC时钟实现:从原理到实践
  • 聚合搜索中的设计模式
  • 数据结构:中缀到后缀的转换(Infix to Postfix Conversion)
  • 开发避坑指南(23):Tomcat高版本URL特殊字符限制问题解决方案(RFC 7230 RFC 3986)
  • 一键设置 NTP 时区的脚本(亲测,适用于部署 K8S 的前置环境)
  • 数据结构:图