当前位置: 首页 > news >正文

PyTorch 2.0 中设置默认使用 GPU 的方法

PyTorch 2.0 中设置默认使用 GPU 的方法

在 PyTorch 2.0 中,默认情况下仍然是使用 CPU 进行计算,除非明确指定使用 GPU。torch.set_default_device 是 PyTorch 2.0 引入的新功能,用于设置默认设备,使得所有后续张量和模块在没有明确指定设备的情况下,会被创建在这个默认设备上。这在代码中提供了一种更简洁的方式来指定设备,而无需在每次创建张量或模型时手动指定。

  1. 检查 PyTorch 版本
    确保使用的是 PyTorch 2.0 或更高版本:

    import torch
    print(torch.__version__)  # 必须是 2.0 或更高版本
    
  2. 检查 CUDA 是否可用
    在设置 GPU 为默认设备之前,确认 CUDA 可用性:

    print(torch.cuda.is_available())  # True 表示可用
    
  3. 设置默认设备为 GPU
    使用 torch.set_default_device 将默认设备设置为 GPU:

    import torch# 确保 CUDA 可用
    if torch.cuda.is_available():# 设置默认设备为 GPUtorch.set_default_device('cuda')print("默认设备已设置为 GPU")
    else:print("CUDA 不可用,无法设置 GPU 为默认设备")
    
  4. 验证默认设备设置
    创建一个张量,验证其是否在 GPU 上:

    x = torch.tensor([1.0, 2.0, 3.0])
    print(x.device)  # 输出:cuda:0
    
  5. 模型自动加载到 GPU
    如果设置了默认设备,模型的参数和新建的张量会自动加载到 GPU:

    class MyModel(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(10, 1)def forward(self, x):return self.linear(x)model = MyModel()
    print(next(model.parameters()).device)  # 输出:cuda:0
    
全局设置代码示例

以下代码展示如何在脚本中全局设置默认设备为 GPU:

import torch# 检查并设置默认设备
if torch.cuda.is_available():torch.set_default_device('cuda')print("默认设备已设置为 GPU")
else:raise RuntimeError("CUDA 不可用,请检查环境配置")# 示例:自动使用 GPU 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
print(f"x device: {x.device}")  # 输出:cuda:0# 示例:自动将模型参数放到 GPU
model = torch.nn.Linear(5, 2)
print(f"Model parameters device: {next(model.parameters()).device}")  # 输出:cuda:0
注意事项
  1. 与设备显式管理的代码兼容性
    如果代码中显式指定了设备(如 tensor.to(device)),torch.set_default_device 不会影响这些张量。建议在全局设置后,尽量减少显式设备管理操作。

  2. 多 GPU 环境
    如果有多个 GPU,可以指定具体设备,比如 'cuda:1'。示例:

    torch.set_default_device('cuda:1')  # 使用第二块 GPU
    
  3. 性能调优
    默认将所有操作转移到 GPU 可能并不适合所有场景,尤其是小规模任务时,GPU 的初始化开销可能超过性能提升。根据需求灵活调整设备。

http://www.lryc.cn/news/506545.html

相关文章:

  • 如何在 Ubuntu 22.04 服务器上安装 Jenkins
  • 【一篇搞定配置】如何在Ubuntu上配置单机/伪分布式Hadoop
  • 利用Map集合设计程序,存储城市和对应等级相关信息
  • 【自动驾驶】单目摄像头实现自动驾驶3D目标检测
  • 21 go语言(golang) - gin框架安装及使用(二)
  • Intel(R) Iris(R) Xe Graphics安装Anaconda、Pytorch(CPU版本)
  • 【Unity3D】实现可视化链式结构数据(节点数据)
  • Three.js推荐-可以和Three.js结合的动画库
  • 增强现实(AR)和虚拟现实(VR)的应用
  • 告别机器人味:如何让ChatGPT写出有灵魂的内容
  • 【Threejs】从零开始(六)--GUI调试开发3D效果
  • Cocos Creator 试玩广告开发
  • 快速解决oracle 11g中exp无法导出空表的问题
  • selenium 报错 invalid argument: invalid locator
  • Flink2.0未来趋势中需要注意的一些问题
  • 机械鹦鹉与真正的智能:大语言模型推理能力的迷思
  • 本地电脑使用命令行上传文件至远程服务器
  • 【系统】Windows11更新解决办法,一键暂停
  • 34. Three.js案例-创建球体与模糊阴影
  • Qt同步读取串口
  • 如何用上AI视频工具Sora,基于ChatGPT升级Plus使用指南
  • 对象的状态变化处理与工厂模式实现
  • 关于IP代理API,我应该了解哪些功能特性?以及如何安全有效地使用它来隐藏我的网络位置?
  • 在Linux上将 `.sh` 脚本、`.jar` 包或其他脚本文件添加到开机自启动
  • [Maven]构建项目与高级特性
  • 【系统架构设计师】真题论文: 论数据分片技术及其应用(包括解题思路和素材)
  • 【bWAPP】XSS跨站脚本攻击实战
  • Ubuntu 22.04,Rime / luna_pinyin.schema 输入法:外挂词库,自定义词库 (****) OK
  • 多协议视频监控汇聚/视频安防系统Liveweb搭建智慧园区视频管理平台
  • 如何高效获取Twitter数据:Apify平台上的推特数据采集解决方案