当前位置: 首页 > news >正文

Linux -- 使用多张gpu卡进行深度学习任务(以tensorflow为例)

在linux系统上进行多gpu卡的深度学习任务

  • 确保已安装最新的 TensorFlow GPU 版本。
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
  • 1、确保你已经正确安装了tensorflow和相关的GPU驱动,这里可以通过在命令行输入nvidia-smi来查看:
    在这里插入图片描述
    如果成功显示了类似上述的GPU信息和驱动版本信息,则说明NVIDIA驱动已经正确安装。

2、导入必要的库,设置可见的gpu设备列表:

import tensorflow as tf
# 设置可见的GPU设备列表(例如,使用GPU 0、1、2和3)
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpu_devices, 'GPU')

在这里插入图片描述

  • 3、创建一个MirroredStrategy对象,该对象将自动复制模型和数据到每个可见的GPU卡上:
strategy = tf.distribute.MirroredStrategy()
  • 4、在strategy范围内创建和训练模型:
with strategy.scope():# 创建和编译模型model = create_model()model.compile(...)# 加载数据train_dataset = load_train_data()test_dataset = load_test_data()# 训练模型model.fit(train_dataset, validation_data=test_dataset, ...)

以上,在MirroredStrategy范围内创建的模型将自动复制并分布到每个可见的GPU卡上,每个卡都将处理一部分数据。

使用多个 GPU 的最佳做法是使用 tf.distribute.Strategy

以下给出一个官网的简单示例:

tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_logical_devices('GPU')
strategy = tf.distribute.MirroredStrategy(gpus)
with strategy.scope():inputs = tf.keras.layers.Input(shape=(1,))predictions = tf.keras.layers.Dense(1)(inputs)model = tf.keras.models.Model(inputs=inputs, outputs=predictions)model.compile(loss='mse',optimizer=tf.keras.optimizers.SGD(learning_rate=0.2))

当然,也有手动的放置方法:

tf.debugging.set_log_device_placement(True)gpus = tf.config.list_logical_devices('GPU')
if gpus:# Replicate your computation on multiple GPUsc = []for gpu in gpus:with tf.device(gpu.name):a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])c.append(tf.matmul(a, b))with tf.device('/CPU:0'):matmul_sum = tf.add_n(c)print(matmul_sum)

在tensorflow上使用gpu:https://www.tensorflow.org/guide/gpu?hl=zh-cn

http://www.lryc.cn/news/178293.html

相关文章:

  • Mendix中的依赖管理:npm和Maven的应用
  • 自定义hooks之useLastState、useSafeState
  • 前端判断: []+[], []+{}, {}+[], {}+{}
  • el-input-number/el-input 实现实时输入数字转换千分位(失焦时展示千分位)
  • 一篇博客学会系列(2)—— C语言中的自定义类型 :结构体、位段、枚举、联合体
  • KongA 任意用户登录漏洞分析
  • 吉力宝:智能科技鞋品牌步力宝引领传统产业创新思维
  • 【IPC 通信】信号处理接口 Signal API(1)
  • 使用GDIView排查GDI对象泄漏导致的程序UI界面绘制异常问题
  • 蓝桥等考Python组别一级001
  • Unity之Hololens2开发 如何接入的MRTK OpenXR Plugin
  • Ubuntu系统Linux内核安装和使用
  • 数学术语之源——群同态的“核(kernel)”
  • defcon-quals 2023 crackme.tscript.dso wp
  • 前端开发 vs. 后端开发:编程之路的选择
  • 算法练习4——删除有序数组中的重复项 II
  • 【C++进阶(六)】STL大法--栈和队列深度剖析优先级队列适配器原理
  • linux opensuse使用mtk烧录工具flashtool
  • Visio如何对文本打下标、上标,以及插入公式编辑器等问题(已解决)
  • 快速将iPhone大量照片快速传输到电脑的办法!
  • TCP/IP协议簇包含的协议
  • 天地图绘制区域图层
  • git权限不够:Ask a project Owner or Maintainer to create a default branch
  • AI在材料科学中的应用
  • VSCode快速设置heder和main函数
  • JimuReport积木报表 v1.6.2 版本正式发布—开源免费的低代码报表
  • sqlsession对象为什么不能被共享?
  • MySQL MMM高可用架构
  • Spring Boot中配置文件介绍及其使用教程
  • Hobby脚本自动化工具