当前位置: 首页 > news >正文

【PyTorch单点知识】PyTorch中的自动混合精度(AMP)模块详解

文章目录

      • 0. 前言
      • 1. 什么是自动混合精度?
      • 2. PyTorch AMP 模块
      • 3. 如何使用 PyTorch AMP
        • 3.1 环境准备
        • 3.2 代码实例
        • 3.3 代码解析
      • 4. 结论

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习领域,训练大型神经网络往往需要大量的计算资源。为了提高训练效率和减少内存占用,研究人员和工程师们不断探索新的技术手段。其中,自动混合精度(Automatic Mixed Precision, AMP)是一种非常有效的技术,它能够在保证模型准确性的同时显著提高训练速度和降低内存使用。

PyTorch 1.6 版本引入了对自动混合精度的支持,通过 torch.cuda.amp 模块来实现。本文将详细介绍 PyTorch 中的 AMP 模块,并提供一个示例来演示如何使用它。

1. 什么是自动混合精度?

自动混合精度是一种训练技巧,它允许在训练过程中使用低于32位浮点的数值格式(如16位浮点数),从而节省内存并加速训练过程。PyTorch 的 AMP 模块能够自动识别哪些操作可以安全地使用16位精度,而哪些操作需要保持32位精度以保证数值稳定性和准确性。这种方法的主要好处包括:

  1. 加速训练:在现代GPU上,对于16位浮点数的算术运算比32位浮点数更快。因此,使用混合精度训练可以显著提高训练速度;
  2. 减少内存使用:16位浮点数占用的空间是32位浮点数的一半,这意味着模型可以在有限的GPU内存中处理更大的批次大小,或者可以将更多的数据缓存到内存中,从而进一步加速训练。
  3. 提高计算效率:通过减少数据类型转换的需求,可以减少计算开销。在某些情况下,使用16位浮点数的运算可以利用特定硬件(如NVIDIA Tensor Cores)的优势,这些硬件专门为低精度运算进行了优化。
  4. 数值稳定性:虽然16位浮点数的动态范围较小,但通过适当的缩放策略(例如使用GradScaler)可以维持数值稳定性,从而避免梯度消失或爆炸的问题。
  5. 易于集成:PyTorch等框架提供的自动混合精度(Automatic Mixed Precision, AMP)工具使得混合精度训练变得非常简单,通常只需要添加几行代码即可实现。

2. PyTorch AMP 模块

PyTorch 的 AMP 模块主要包含两个核心组件:autocastGradScaler

  • autocast:这是一个上下文管理器,它会自动将张量转换为合适的精度。当张量被传递给运算符时,它们会被转换为16位浮点数(如果支持的话),这有助于提高计算速度并减少内存使用。

  • GradScaler:这是一个用于放大梯度的类,因为在混合精度训练中,梯度可能会非常小,以至于导致数值稳定性问题。GradScaler 可以帮助解决这个问题,它在反向传播之前放大损失,然后在更新权重之后还原梯度的尺度。

3. 如何使用 PyTorch AMP

接下来,将通过一个简单的示例来演示如何使用 PyTorch 的 AMP 模块来训练一个神经网络。

3.1 环境准备

确保安装了 PyTorch 1.6 或更高版本。可以使用以下命令安装:

pip install torch==1.10.0+cu111 torchvision==0.11.1+cu111 torchaudio===0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3.2 代码实例

下面的示例代码演示了如何使用 PyTorch 的 AMP 模块来训练一个简单的多层感知器(MLP)。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast# 设置随机种子以保证结果的一致性
torch.manual_seed(0)# 创建一个简单的多层感知器模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.linear1 = nn.Linear(10, 100)self.linear2 = nn.Linear(100, 10)def forward(self, x):x = torch.relu(self.linear1(x))x = self.linear2(x)return x# 初始化模型、损失函数和优化器
model = MLP().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建 GradScaler
scaler = GradScaler()# 生成一些随机数据
inputs = torch.randn(100, 10).cuda()
targets = torch.randint(0, 10, (100,)).cuda()# 训练循环
for epoch in range(1):print(f"inputs dtype:{inputs.dtype}")# 使用 autocast 上下文管理器with autocast():  #尝试去掉这行再看下# 前向传播outputs = model(inputs)print(f"outputs dtype:{outputs.dtype}")loss = criterion(outputs, targets)print(f"loss dtype:{loss.dtype}")# 清除梯度optimizer.zero_grad(set_to_none=True)# 使用 GradScaler 缩放损失scaler.scale(loss).backward()# 更新权重scaler.step(optimizer)# 更新 GradScalerscaler.update()print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")
3.3 代码解析

上面实例输出为:

inputs dtype:torch.float32
outputs dtype:torch.float16   
loss dtype:torch.float32
Epoch 1, Loss: 2.2972

这里可以注意到outputs的类型自动变成了float16

  1. 模型定义:我们定义了一个简单的多层感知器模型,包含两个线性层。
  2. 初始化:初始化模型、损失函数和优化器,并创建 GradScaler 对象。
  3. 数据准备:生成一些随机输入数据和目标标签。
  4. 训练循环
    • 使用 with autocast() 上下文管理器自动转换张量精度。
    • 前向传播计算输出和损失。
    • 使用 scaler.scale(loss) 放大损失以确保数值稳定性。
    • 反向传播和梯度更新。
    • 更新 GradScaler 状态。

4. 结论

通过使用 PyTorch 的自动混合精度模块,我们可以显著提高模型的训练速度并减少内存使用,尤其是在 GPU 上训练大型神经网络时。上述示例展示了如何轻松地将 AMP 集成到现有训练流程中,只需几行代码即可启用这一功能。

http://www.lryc.cn/news/438343.html

相关文章:

  • 数据结构 --- 哈希表
  • Linux相关:在阿里云下载centos系统镜像
  • 24. 线模型对象
  • EasyExcel 快速入门
  • Sparse4D v1
  • 速盾:你知道高防 IP 和高防 CDN 的区别吗?
  • HTML和CSS网页制作成品
  • Ai+若依(集成easyexcel实现excel表格增强)
  • 钻机、塔吊等大型工程设备,如何远程维护、实时采集运行数据?
  • 【AutoX.js】选择器 UiSelector - 查找包名
  • ERP进销存多仓库管理系统源码 带完整的安装代码包以及搭建部署教程
  • 数据清洗-缺失值填充-对XGBoost参数优化填充
  • Qt_按钮类控件
  • union 的定义和基本结构以及用途
  • 混合整数规划及其MATLAB实现
  • 【数据结构】6——图1,概念
  • 技术周总结 09.09~09.15周日(C# WinForm WPF)
  • 4K投影仪选购全攻略:全玻璃镜头的当贝F6,画面细节纤毫毕现
  • 除了字符串前导的*号之外,将串中其它*号全部删除
  • SpringBoot开发——使用@Slf4j注解实现日志输出
  • VSCode拉取远程项目
  • 【已解决】SpringBoot3项目整合Druid依赖:Druid监控页面404报错
  • 【算法】滑动窗口—找所有字母异位词
  • Vue安装及环境配置【图解版】
  • 绕过CDN查找真实IP方法
  • Qt与MQTT交互通信
  • dd 命令:复制和转换文件
  • 文件系统(磁盘 磁盘文件 inode)
  • ThreeJs创建圆环
  • React实现类似Vue的路由监听Hook