训练模型时梯度出现NAN或者INF(禁用amp的不同level)
判断参数梯度位nan或inf的代码:
for name, param in model.named_parameters():if param.grad is not None:if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():print(f"grad layer [{name}] is NaN or Inf")
首先来说可能得原因:
1. 模型中存在未初始化或未更新的参数(层)
2. 除以0或者log引起
3.输入数据存在你nan或者inf
4. 学习率过大造成梯度不稳定
5.数据类型问题
这里着重讲下第5点。
我的错误是前1-2个epoch的grad norm出现 nan, 后面又稳定了,偶尔又会出现inf,有点随机。
可参考类似的情况和回答:
After several iterations gradient norm and loss becomes nan · Issue #287 · microsoft/Swin-Transformer · GitHub
Got a nan loss and gradient norm when training swin-l on imagenet22k with O1 · Issue #82 · microsoft/Swin-Transformer · GitHub
因为debug了发现都不是1,2,3,4的问题所以最后调试问题出在数据类型上。
PyTorch 的 AMP(自动混合精度) 默认支持动态切换精度。它会在前向和后向传播中自动判断是否切换为 float16
精度,以节省显存并加速计算。在使用 AMP 时,通常采用以下几种机制来选择精度:
-
按操作动态调整精度:AMP 会根据具体操作的数值稳定性来选择
float32
或float16
,对于稳定性较好的操作(如矩阵乘法)使用float16
,对精度要求较高的操作(如归一化)则保留float32
。 -
GradScaler
动态调整梯度缩放:AMP 默认使用GradScaler
对梯度进行缩放,以避免因float16
造成的数值下溢(过小梯度被舍去)。
这种自动化过程旨在最大程度保持数值稳定性,并降低显存需求。只需使用 torch.cuda.amp.autocast
上下文管理器和 GradScaler
,AMP 就能完成动态精度切换
回到我的错误中来,若主函数里面有两个参数:
parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],help='mixed precision opt level, if O0, no amp is used (deprecated!)')
解决方案
1. 禁用自动混合精度(AMP): 如果你不依赖于 bfloat16
的性能优化,可以选择禁用 AMP。你可以在你的主函数中设置 --disable_amp
参数,或者在代码中直接注释掉与 GradScaler
和 autocast
相关的代码。这将避免因 bfloat16
引起的问题。
. --disable_amp
- 类型: 布尔型(
action='store_true'
) - 功能: 如果指定了这个参数,将会禁用 PyTorch 的自动混合精度功能。在训练过程中,这意味着模型将会使用全精度(通常是
float32
)进行计算,而不使用混合精度。 - 适用场景: 在调试或遇到精度问题时,可以选择禁用 AMP。
2. 使用 bfloat16
:
需要设置这3处
[1]
model = model.to(torch.bfloat16)[2]
samples = samples.to(torch.bfloat16)
targets = targets.to(torch.bfloat16)[3]
with torch.cuda.amp.autocast(dtype=torch.bfloat16): outputs = model(samples)
在 PyTorch 中,使用 torch.bfloat16
时,可能会遇到与 torch.cuda.amp
(自动混合精度)相关的问题,特别是关于梯度不稳定性和 unscale
操作的支持。
3. --amp-opt-level(推荐使用)
:参数用于指定自动混合精度(AMP)的优化级别。而不是
主要有以下几种可选的优化级别:
-
O0:
- 含义: 不使用混合精度,所有计算都在全精度(
float32
)下进行。 - 适用场景: 当模型需要最高的数值精度或在调试时。
- 含义: 不使用混合精度,所有计算都在全精度(
-
O1:
- 含义: 使用混合精度,但在关键操作中保持全精度(例如,反向传播)。大多数操作在
float16
下进行,只有在必要时(例如,loss 计算)切换回float32
。 - 适用场景: 在保证较高性能的同时,尽量减少数值不稳定性,适合大多数场景。
- 含义: 使用混合精度,但在关键操作中保持全精度(例如,反向传播)。大多数操作在
-
O2:
- 含义: 尽可能多地使用
float16
,仅在少数操作中使用float32
。几乎所有的计算都是在float16
下进行,可能会导致数值稳定性的问题。 - 适用场景: 性能优先的情况下使用,但需确保模型能够在
float16
下稳定运行。
- 含义: 尽可能多地使用
-
O3:
- 含义: 强制所有操作都使用
float16
,可能会导致更高的计算性能,但也可能引入更多的数值不稳定性。 - 适用场景: 仅适用于对性能要求极高且能够处理数值不稳定性的模型。
- 含义: 强制所有操作都使用
选择建议
- O1: 通常是最推荐的选项,适合大多数任务。
- O0: 如果遇到数值不稳定性或调试问题,可以选择。
- O2 和 O3: 适合在确保模型稳定性的前提下追求性能的高级用户。