当前位置: 首页 > news >正文

【pytorch】多GPU同时训练模型

文章目录

  • 1. 基本原理
    • 单机多卡训练教程——DP模式
  • 2. Pytorch进行单机多卡训练步骤
    • 1. 指定GPU
    • 2. 更改模型训练方式
    • 3. 更改权重保存方式


摘要:多GPU同时训练,能够解决单张GPU显存不足问题,同时加快模型训练。

1. 基本原理

单机多卡训练教程——DP模式

(1)将模型复制到各个GPU中,并将一个batch的数据划分成mini_batch(平均分配) 并分发给每个GPU;
注意:这里的batch_size要大于device数。
(2)各个GPU独自完成mini_batch的前向传播,并把获得的output传递给GPU_0(主GPU) ;
(3) GPU_0整合各个GPU传递过来的output,并计算loss。此时GPU_0可以对这些loss进行一些聚合操作;
(4) GPU_0归并loss之后,并进行后向传播以及梯度下降从而完成模型参数的更新(此时只有GPU_0上的模型参数得到了更新),GPU_0将更新好的模型参数又传递给其余GPU;

以上就是DP模式下多卡GPU进行训练的方式。其实可以看到GPU_0不仅承担了前向传播的任务,还承担了收集loss,并进行梯度下降。因此在使用DP模式进行单机多卡GPU训练的时候会有一张卡的显存利用会比其他卡更多,那就是你设置的GPU_0。

2. Pytorch进行单机多卡训练步骤

只需要在你的代码中改三个地方就可实现

1. 指定GPU

在这里插入图片描述
如上所示,在导入各种库下面使用os.environ["CUDA_VISIBLE_DEVICES"]来指定可识别的GPU,该语句在程序开始前使用。
代码如下:

import torch.nn as nn
import os
os.environ["CUDA_VISIBLE_DEVICES"]= 2,3,1'#指定该程序可以识别的物理GPU编号,这里的你主机上的2号GPU就是训练程序中的主GPUO,这里最好—定要自己指定你自己可以用的gpu号。

2. 更改模型训练方式

在这里插入图片描述
平常的模型训练方式只需要model.cuda()语句即可,在单机多卡训练中,只需要在该语句下面添加一行nn.DataParallel语句即可。
代码如下

model.cuda()
model = nn.DataParallel(model,devise =[0,1,2])#在执行该语句之前最好加上model.cuda(),保证你的模型存在GPU上即可

3. 更改权重保存方式

对于数据,我们只需要按照平常的方式使用.cuda()放置在GPU上即可,内部batch的拆分已经被封装在了DataPanallel模块中。要注意的是,由于我们的model被nn.DataPanallel()包裹住了,所以如果想要储存模型的参数,需要使用:model.module.state_dict()的方式才能取出(不能直接是model.state_dict()
代码如下:

'''
使用单机多卡训练的模型权重保存方式
'''
torch.save(model.module.state_dict(),f'best.pth')  

作为参考,将平常的权重保存方式也写上:

'''
平常的权重保存方式
'''
torch.save(model.state_dict(),f'best.pth')  
http://www.lryc.cn/news/187222.html

相关文章:

  • Git 学习笔记 | Git 基本理论
  • 滚动表格封装
  • 【LeetCode高频SQL50题-基础版】打卡第3天:第16~20题
  • 系统压力测试:保障系统性能与稳定的重要措施
  • 常用数据结构和算法
  • C++中使用引用避免内存复制
  • 计算机网络(第8版)-第4章 网络层
  • chromadb 0.4.0 后的改动
  • Windows环境下下载安装Elasticsearch和Kibana
  • 机器学习:随机森林
  • ctfshow-web11(session绕过)
  • 状态模式:对象状态的变化
  • 解耦常用方法
  • 根据二叉树创建字符串--力扣
  • 代码事件派发机制(观察者模式)
  • 微服务技术栈-Nacos配置管理和Feign远程调用
  • 操作系统 OS
  • 基于ffmpeg给视频添加时间字幕
  • 爬虫基础知识点快速入门
  • 解释器模式 行为型模式之五
  • 2023年中国汽车座舱行业发展现状及趋势分析:高级人机交互(HMI)系统将逐步提升[图]
  • 常见的通用型项目管理软件推荐
  • 手机总是提醒系统更新,到底要不要更新呢?
  • 什么是API
  • RedissonClient 分布式锁 处理并发访问共享资源
  • Hadoop-2.5.2平台环境搭建遇到的问题
  • 基于WTMM算法的图像多重分形谱计算matlab仿真
  • VR全景展示带来旅游新体验,助力旅游业发展!
  • Xcode 15 编译出错问题解决
  • 基于指数趋近律的机器人滑模轨迹跟踪控制算法及MATLAB仿真