模型学习系列之精度
背景
- 模型配置中,精度BF16,FP8
BF16
-
定义:16bit浮点,1符号+8指数+7位数,比FP32少16位尾数,但保留完整指数范围。
-
特点:显存BF16占2字节(Byte);吞吐量BF16是FP32的2倍;数值稳定性与FP32相当(动态范围 ≈ 10⁻³⁸~10³⁸)
-
场景:BF16训练任何规模的Transformer/LLM, 默认半精度
-
硬件要求:
- NVIDIA A100 或 H100
- AMD MI250 或 MI300
- Intel Gaudi 2/3
- Google TPU v4/v5
-
框架:PyTorch 2.1+、TensorFlow 2.15+均支持BF16
-
用法:
- 训练:开启开关
# PyTorch # 训练脚本里一行即可 model = model.to(dtype=torch.bfloat16) # 或者直接用 amp with torch.cuda.amp.autocast(dtype=torch.bfloat16):loss = model(input)# Transformers training_args = TrainingArguments(bf16=True, # 一行打开... )
-
推理:精度策略(几乎不管)
- 前向、反向、权重、梯度、优化器状态 全部可以 BF16。
- 只有极少数场景(如梯度累加求和、Adam 的 exp_avg/exp_avg_sq)框架会内部用 FP32 累加,用户无感。
-
验证
- 直接跑训练500-1000 step, 看loss曲线与FP32是否重合。
- 推理可直接用同一份BF16权重,无需额外量化脚本;
- 精度掉点<0.1%即视为无损。
-
总结:BF16数值稳定,半精度,训练首选。
FP8
-
定义:8 bit浮点,分两种子格式(E4M3、E5M2), 只有2-3位尾数,指数位缩到4~5位。
-
特点:显存FP8占1字节(byte), FP8峰值算力是BF16的2倍;数值容易不稳,FP8范围小,需要混合精度保护。动态范围 ≈ 10⁻⁶~10⁴(E4M3)或 10⁻¹⁴~10⁴(E5M2)
-
场景
- 训练超大模型(>70B)且GPU支持FP8 Tensor Core
- 推理高Batch、高并发场景
- 量化实验验证对精度影响可接受时。
-
硬件要求:
- NVIDIA H100
- Grace Hopper
- Intel Gaudi3
- Graphcore Bow-2000
-
框架:PyTorch 2.1+、TensorFlow 2.15+均支持FP8
-
使用
-
训练
- 开启开关
# PyTorch(H100) torch.backends.cuda.enable_flash_sdp(True) # 开启 Flash-Attention torch.backends.cudnn.allow_tf32 = True torch.set_float8_matmul_precision('high') # 全局启用 FP8 GEMM with torch.fp8_autocast(enabled=True): # 上下文loss = model(input)
- 精度策略
计算阶段 建议格式 备注 前向 / 反向 GEMM FP8 (E4M3) 权重、激活、梯度 梯度累加 FP32 累加器 防止误差累积 优化器状态 BF16 / FP32 Adam 一阶、二阶矩用 BF16,主权重 FP32 敏感算子 (LayerNorm, Softmax, Embedding, MoE-Gate) FP16 或 BF16 直接跳过 FP8,避免精度崩 - 动态缩放(必配)
# 延迟缩放(Delayed Scaling) from torch.cuda.amp import GradScaler scaler = torch.cuda.amp.GradScaler(init_scale=2.**14, growth_interval=1000)# 按 Tile/Block 缩放(NVIDIA MXFP8) 每 128×128 权重块、1×128 激活 tile 单独 amax 缩放
- 验证
# 训练 1000 step,对比 FP32/BF16 loss 曲线; # 每 100 step 做一次 amax 分布直方图,确保无饱和。
-
-
推理
- 离线量化(一次转权重)
# NVIDIA AMMO ammo quantize --model llama3_70b_fp16 \--output llama3_70b_fp8 \--dtype fp8_e4m3 \--calib-dataset wikitext-2-raw# Graphcore PopRTpython -m poprt.cli --input_model model.onnx \--output_model model_fp8.onnx \--precision fp8 --quantize
- 在线量化(无需更改模型)
# TensorRT-LLM from tensorrt_llm import LLM, QuantMode llm = LLM(model_dir, quant_mode=QuantMode.FP8)# StableDiffusion WebUI 在设置页勾选 “Enable FP8 precision for SDXL” → 保存 → 重启即可
- 缩放因子加载
# TensorRT-LLM 会自动寻找 model.fp8_scales.json;如缺失,需手动调用 config.set_flag(trt.BuilderFlag.FP8) config.add_optimization_profile(fp8_scale=scale_tensor)
-
验证
- 跑 1 k sample,比较 FP16 vs FP8 的 BLEU/ROUGE 或 CLIP-SIM 指标;
- 显存下降 ≥ 1.8 GB、吞吐提升 ≥ 1.3× 视为合格。
-
总结:把「FP8 开关 + 缩放策略 + 高精度累加」三板斧配好,就能在训练把显存和算力各砍半,在推理把显存再砍半。
总结
- BF16:稳半精,训练首选;FP8:激进省,超大模型和推理神器。