当前位置: 首页 > news >正文

【大模型LLM】梯度累积(Gradient Accumulation)原理详解

在这里插入图片描述

梯度累积(Gradient Accumulation)原理详解

梯度累积是一种在深度学习训练中常用的技术,特别适用于显存有限但希望使用较大批量大小(batch size)的情况。通过梯度累积,可以在不增加单个批次大小的情况下模拟较大的批量大小,从而提高模型的稳定性和收敛速度。

基本概念

在标准的随机梯度下降(SGD)及其变体(如Adam、RMSprop等)中,每次更新模型参数时都需要计算整个批次数据的损失函数梯度,并立即用这个梯度来更新模型参数。然而,在处理大规模数据集或使用非常大的模型时,单个批次的数据量可能会超出GPU显存的容量。此时,梯度累积技术就可以发挥作用。

工作原理

梯度累积的核心思想是:将多个小批次(mini-batch)的梯度累加起来,然后一次性执行一次参数更新。具体步骤如下:

  1. 初始化梯度累积器:在每个训练步骤开始时,初始化一个梯度累积器(通常为零)。
  2. 前向传播与梯度计算
    • 对于每一个小批次 i(从 1 到 k),执行前向传播计算损失。
    • 执行反向传播计算该小批次的梯度。
  3. 累积梯度:将当前小批次的梯度累加到梯度累积器中。
  4. 参数更新:当累积了 k 个小批次的梯度后,使用累积的梯度来更新模型参数,并重置梯度累积器。
详细步骤

假设我们希望使用的批量大小是 N,但由于显存限制只能使用较小的批量大小 n(其中 N = k * n),那么我们可以进行 k 次前向和后向传播,每次都计算一个小批次的梯度并将其累加,直到累积了 k 个小批次的梯度之后,再进行一次参数更新。

示例代码

以下是一个简单的PyTorch示例,展示了如何实现梯度累积:

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 设置梯度累积步数
accumulation_steps = 4
optimizer.zero_grad()  # 清空梯度for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)# 将损失除以累积步数,使得总的损失不变loss = loss / accumulation_steps# 反向传播计算梯度loss.backward()if (i + 1) % accumulation_steps == 0:# 累积足够步数后,执行优化步骤optimizer.step()optimizer.zero_grad()  # 清空梯度
关键点解释
  1. 损失缩放:由于我们将一个大批次分成多个小批次,并且每次只计算一个小批次的损失,因此需要将每个小批次的损失除以累积步数 accumulation_steps,以确保总的损失值保持不变。

  2. 梯度累积:每次反向传播后,梯度会被累加而不是立即用于更新参数。只有当累积了足够的步数后,才会使用累积的梯度进行一次参数更新。

  3. 参数更新:在累积了足够的梯度后,调用 optimizer.step() 来更新模型参数,并清空梯度累积器(即调用 optimizer.zero_grad())。

优点
  • 突破显存限制:通过使用较小的批量大小,可以有效地减少每一步所需的显存量,从而允许在有限的硬件资源上训练更大的模型或使用更大的批量大小。
  • 模拟大批次训练效果:梯度累积实际上模拟了使用较大批量大小的效果,有助于提高模型训练的稳定性和收敛速度。
  • 灵活性:可以根据实际硬件条件灵活调整累积步数,适应不同的训练需求。
注意事项
  • 学习率调整:由于梯度累积实际上是将多个小批次的梯度累加起来进行一次更新,因此需要相应地调整学习率。例如,如果原始设置的学习率为 lr,并且使用了 k 步梯度累积,则新的有效学习率应为 lr * k
  • 随机性影响:梯度累积可能会引入一定的随机性,因为不同小批次之间的顺序可能会影响最终的梯度累积结果。不过,在实践中这种影响通常是可以接受的。
总结

梯度累积是一种非常实用的技术,特别是在显存受限但希望利用更大批量大小的情况下。它不仅帮助克服了硬件限制,还能够保持甚至提升模型训练的质量。通过合理配置梯度累积步数和学习率,可以显著改善训练效率和效果。

http://www.lryc.cn/news/602877.html

相关文章:

  • 微服务架构中 gRPC 的应用
  • Rust 最短路径、Tide、Partial、Yew、Leptos、数独实践案例
  • Hugging Face-环境配置
  • 洛谷 P10448 组合型枚举-普及-
  • HTML响应式SEO公司网站源码
  • 归雁思维:解锁自然规律与人类智慧的桥梁
  • 疯狂星期四文案网第22天运营日记
  • CFIHL: 水培生菜的多种叶绿素 a 荧光瞬态图像数据集
  • 递归算法的一些具体应用
  • TDSQL 技术详解
  • go‑cdc‑chunkers:用 CDC 实现智能分块 强力去重
  • Apache Ignite 的 JDBC Client Driver(JDBC 客户端驱动)
  • 利用frp实现内网穿透功能(服务器)Linux、(内网)Windows
  • OpenGL进阶系列22 - OpenGL SuperBible - bumpmapping 例子学习
  • 短剧系统开发上线全流程攻略:从架构设计到性能优化
  • 页面性能优化
  • Go性能优化深度指南:从原理到实战
  • C++-关于协程的一些思考
  • Linux 远程连接与文件传输:从基础到高级配置
  • 多系统集成前端困境:老旧工控设备与新型Web应用的兼容性突围方案
  • Docker笔记(基本命令、挂载本地gpu、Dockerfile文件配置、数据挂载、docker换源)
  • 3Dmax模型位置归零
  • [机缘参悟-237]:AI人工神经网络与人类的神经网络工作原理的相似性
  • Java项目:基于SSM框架实现的进销存管理系统【ssm+B/S架构+源码+数据库+毕业论文+远程部署】
  • Java Collections工具类
  • Mac查看本机ip地址
  • 【密码学】3. 流密码
  • 互信息:理论框架、跨学科应用与前沿进展
  • 【实时Linux实战系列】实时运动分析系统的构建
  • 表征学习:机器认知世界的核心能力与前沿突破