混合精度加快前向传播的速度
一方面,16 位计算比 32 位计算速度更快。但另一方面,精度的损失会随着时间的推移、一次又一次的运算而累积,从而导致数值问题。或许我们可以鱼(32 位)与熊掌(16 位)兼得?
混合精度(计算)登场!
“加载模型” 摘要
如果你的 GPU 支持,在所有 16 位运算相关场景中,使用 torch.bfloat16 而非 torch.float16。
supported = torch.cuda.is_bf16_supported(including_emulation=False)
dtype16 = (torch.bfloat16 if supported else