当前位置: 首页 > news >正文

pytorch使用DataParallel并行化保存和加载模型(单卡、多卡各种情况讲解)

话不多说,直接进入正题。

!!!不过要注意一点,本文保存模型采用的都是只保存模型参数的情况,而不是保存整个模型的情况。一定要看清楚再用啊!

1 单卡训练,单卡加载

#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()#MyModel()是你定义的创建模型的函数,就是先初始化得到一个模型实例,之后再将模型参数加载到该实例上
model.load_state_dict(torch.load('model.pt'))

2 单卡训练,多卡加载

保存模型的过程同第一种情况一样,但是要注意,多卡加载模型时, 是先加载模型参数,再对模型做并行化处理。

#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))model=nn.DataParallel(model)#将模型进行并行化处理

3 多卡保存,单卡加载

方法一:

考虑到之后可能需要单卡加载你多卡训练的模型,所以建议在保存的时候,要去除模型参数字典里面的module,即使用model.module.state_dict()代替model.state_dict()来进行去除。

因为是单卡加载,所以还是要先加载 模型参数,再对模型做并行化处理。

#保存模型
torch.save(model.module.state_dict(),'modle.pt')#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))model=nn.DataParallel(model)

方法二:

仍然使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行),加载的时候要注意:由于我们保存到 方式是以多卡方式保存的,所以无论加载之后的模型是 在答案卡上运行还是在多卡上运行,都要先把模型并行化处理,然后再去加载模型。

#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()model=nn.DataParallel(model)model.load_state_dict(torch.load('model.pt'))

4 多卡保存,多卡加载

这里保存模型采用”多卡保存,单卡加载“的第二种方法,加载的时候,要先把模型做并行化(在多卡上并行),然后再加载。

#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()model=nn.DataParallel(model)model.load_state_dict(torch.load('model.pt'))

希望以上内容能够帮助到你,这里是希望你能越来越好的 小白冲鸭 ~~~

http://www.lryc.cn/news/366336.html

相关文章:

  • PS初级|写在纸上的字怎么抠成透明背景?
  • Docker面试整理-Docker的网络是如何工作的?
  • 获得抖音商品评论 API 返回值
  • Qt | QtBluetooth(蓝牙电脑当服务端+手机当客户端) 配对成功啦
  • 我找到了全网最低价买服务器的 bug !!!
  • 聚类的外部指标(Purity, ARI, NMI, ACC) 和内部指标(NCC,Entropy,Compactness,Silhouette Index)
  • 国标GB/T 28181详解:国标GBT28181-2022的客户端主动发起历史视音频回放流程
  • Vue项目安装axios报错npm error code ERESOLVE npm error ERESOLVE could not resolve解决方法
  • 【Linux】Centos7升级内核的方法:yum更新(ELRepo)
  • 【CSS】object-fit 和 object-position 属性详解
  • 【算法专题--栈】最小栈--高频面试题(图文详解,小白一看就会!!)
  • Vite项目构建chrome extension,实现多入口
  • 【vector模拟实现】附加代码讲解
  • 本地运行ChatTTS
  • 应用解析 | 面向智能网联汽车的产教融合解决方案
  • 华为设备动态路由OSPF(单区域+多区域)实验
  • R语言探索与分析19-CPI的分析和研究
  • 【C++ | 拷贝构造函数】一文了解C++的 拷贝(复制)构造函数
  • 【工具】Vmware17 安装mac(13.6.7)虚拟机
  • mac node版本切换 nvm install nvm ls-remote N/A问题
  • 牛客小白月赛95
  • Python实现调用并执行Linux系统命令
  • 古字画3d立体在线数字展览馆更高效便捷
  • 编写程序,提示用户输入以米/秒(m/s)为单位的速度v和以米/秒的平方(m/s)为单位的加速度 a,然后显示最短跑道长度。
  • k8s 对外发布(ingress)
  • FL Studio21.2.7最新中文破解版免费激活,音乐制作全掌握!
  • 2 - 寻找用户推荐人(高频 SQL 50 题基础版)
  • 高考志愿填报有哪些技巧和方法
  • codereview时通常需要关注哪些
  • DSP28335模块配置模板系列——定时器中断配置模板