
[LLM Learning | Quantization] PyTorch Quantization Basics (1)

PyTorch Quantization

[!note]

  • Official definition: performing computations and storing tensors at lower bitwidths than floating point precision.
  • INT8 quantization is supported; it reduces model size and GPU memory requirements by roughly 4x and speeds up inference by 2-4x.
  • In plain terms: lower the precision of weights and activations (FP32 → INT8) to shrink the model size and memory footprint (a quick size estimate follows below).
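
A minimal sketch of where the "roughly 4x smaller" figure comes from, using a toy model (the layer sizes are arbitrary and purely illustrative):

```python
import torch

# a toy FP32 model; any nn.Module works the same way
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.ReLU(),
    torch.nn.Linear(4096, 1024),
)

num_params = sum(p.numel() for p in model.parameters())
fp32_mb = num_params * 4 / 2**20   # 4 bytes per FP32 value
int8_mb = num_params * 1 / 2**20   # 1 byte per INT8 value
print(f"params: {num_params}, FP32: {fp32_mb:.1f} MiB, INT8: {int8_mb:.1f} MiB (~4x smaller)")
```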

1. Prerequisites

1.1 Operator Fusion

Operator fusion merges the computations of several consecutive layers into a single composite operator, reducing the number of memory accesses.

e.g. fuse Conv → BN → ReLU into a single ConvBnReLU operator

| Execution | Memory accesses | Arithmetic intensity |
| --- | --- | --- |
| Unfused (3 operators) | 6 | lower |
| Fused (1 operator) | 2 | higher |

NVIDIA GPU view (pseudocode):

```cuda
// Unfused: three separate kernel launches, each reading/writing global memory
conv_kernel<<<...>>>(input, weight, temp1);
bias_kernel<<<...>>>(temp1, bias, temp2);
relu_kernel<<<...>>>(temp2, output);

// Fused: a single kernel keeps intermediate values in registers (pseudocode)
fused_kernel<<<...>>>(input, weight, bias, output) {
    float val = conv2d(input, weight);
    val += bias;
    output = max(val, 0.0f);
}
```
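
The same idea is exposed at the PyTorch module level through torch.ao.quantization.fuse_modules. A minimal sketch (the toy block and the layer names conv/bn/relu below are illustrative, not from the original article):

```python
import torch

class ConvBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.bn = torch.nn.BatchNorm2d(16)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBlock().eval()                    # fusion requires eval mode
fused = torch.ao.quantization.fuse_modules(m, [['conv', 'bn', 'relu']])
print(type(fused.conv))                   # fused ConvReLU2d (BN folded into the conv weights)
print(type(fused.bn), type(fused.relu))   # both replaced by Identity
```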

2. Quantization Basics

2.1 Symmetric vs. Asymmetric Quantization

⚙️ Differences

  • Symmetric Quantization

$$X_{int} = \mathrm{round}\!\left(\frac{X_{float}}{scale}\right), \qquad scale = \frac{\max(|X|)}{2^{n-1}-1}$$

  • Asymmetric (Affine) Quantization

$$X_{int} = \mathrm{round}\!\left(\frac{X_{float}}{scale}\right) + zero\_point, \qquad scale = \frac{\max(x)-\min(x)}{2^{n}-1}$$

$$zero\_point = \mathrm{round}\!\left(\frac{-\min(x)}{scale}\right)$$
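
A minimal numeric sketch of the two formulas above (plain PyTorch, assuming n = 8 bits):

```python
import torch

x = torch.tensor([-0.62, -0.1, 0.0, 0.35, 1.2])   # toy FP32 values

# symmetric (int8): zero_point is fixed at 0
scale_sym = x.abs().max() / (2**7 - 1)             # max(|x|) / 127
x_int8 = torch.clamp(torch.round(x / scale_sym), -127, 127)

# asymmetric / affine (uint8): zero_point shifts the range
scale_aff = (x.max() - x.min()) / (2**8 - 1)       # (max - min) / 255
zero_point = torch.round(-x.min() / scale_aff)
x_uint8 = torch.clamp(torch.round(x / scale_aff) + zero_point, 0, 255)

# dequantize to see the rounding error introduced by each scheme
print(x_int8, (x_int8 * scale_sym - x).abs().max())
print(x_uint8, ((x_uint8 - zero_point) * scale_aff - x).abs().max())
```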

| Property | Symmetric Quantization | Asymmetric (Affine) Quantization |
| --- | --- | --- |
| Zero point | fixed at 0 | computed dynamically (zero_point) |
| Value range | [-127, 127] (int8) | [0, 255] (uint8) |
| Compute overhead | lower (no zero_point handling) | higher |
| Accuracy loss | sensitive to skewed distributions | more robust; handles skewed data distributions better |
| Typical use | weight quantization (roughly zero-centered) | activation quantization |
| Hardware support | widely supported (e.g. GPU/TPU) | requires extra zero_point handling |

🤖 Engineering perspective: why PTQ typically uses asymmetric quantization while QAT uses symmetric

| Mode | Recommended default | Rationale |
| --- | --- | --- |
| PTQ | weights: symmetric; activations: asymmetric | Activations are quantized statically without training, so an asymmetric scheme adapts better to non-negative distributions (e.g. post-ReLU). |
| QAT | weights: symmetric; activations: symmetric (by design) | Activations are seen during training, so training can push them toward a "symmetric" range, keeping the accuracy loss under control. |
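
A sketch of how this convention maps onto PyTorch's qconfig machinery: symmetric per-tensor int8 for weights, affine uint8 for activations. The observer choices here are illustrative defaults, not the only option:

```python
import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.observer import MinMaxObserver

ptq_qconfig = QConfig(
    # activations: affine (asymmetric) uint8, zero_point computed from min/max
    activation=MinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    # weights: symmetric int8, zero_point fixed at 0
    weight=MinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
```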
2.2 PTQ & QAT

[!note]

PTQ quantizes the parameters of an already-trained model directly, so it is suited to fast deployment. QAT inserts fake-quantization nodes and simulates quantization error during training to reach higher accuracy, so it requires retraining.

⚙️ Differences

| Property | PTQ (Post-Training Quantization) | QAT (Quantization-Aware Training) |
| --- | --- | --- |
| Training stage | FP32 training only | training with fake-quantization nodes inserted |
| Backpropagation | ❌ not supported | ✅ supported via STE |
| Accuracy loss | larger (especially for small models) | usually smaller |
| Compute cost | low (calibration only) | high (full training run) |
| Typical use | fast deployment | scenarios with strict accuracy requirements |

[!tip]

QAT fake-quantization nodes

  • Purpose: simulate quantization error during training. Weights and activations remain FP32 throughout training, but at each layer the values are "quantized and then dequantized", so the forward pass carries realistic quantization error.
  • Because the round function used in quantization is not differentiable, a FakeQuant module that approximates the gradient with a Straight-Through Estimator (STE) is required (see the sketch below).
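
A minimal sketch of what a fake-quant node computes, with the STE implemented via the usual detach trick (a simplification for illustration, not PyTorch's actual FakeQuantize code):

```python
import torch

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # forward: quantize then immediately dequantize, so downstream layers
    # see values carrying real quantization error (still FP32 tensors)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    x_q = (q - zero_point) * scale
    # STE: use x_q in the forward pass, but let gradients flow as if the
    # round/clamp were the identity function
    return x + (x_q - x).detach()

x = torch.randn(4, requires_grad=True)
y = fake_quantize(x, scale=torch.tensor(0.05), zero_point=torch.tensor(0.0))
y.sum().backward()
print(x.grad)   # all ones: the gradient passes straight through the rounding
```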

3. Three Ways to Implement Quantization in PyTorch

Reference: Quantization — PyTorch 2.7 documentation

| Property | Eager Mode QAT | FX Graph QAT | Export QAT |
| --- | --- | --- | --- |
| Implementation | dynamic (eager) execution | symbolic rewriting | compiler-based optimization |
| Control-flow support | | | |
| Operator fusion | ❌ (manual fusion only) | ✅ | 🌟 |
| Typical API | prepare_qat | prepare_fx | export |
| Type | modules only | modules & functions | modules & functions |

[!note]

Whether you use PTQ or QAT, every implementation path follows the same two-step pattern: a prepare call (e.g. prepare_fx) followed by a convert call (e.g. convert_fx).

```python
model_prepared = quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs)
model_quantized = quantize_fx.convert_fx(model_prepared)
```

🎯 Core function: at every quantization location specified by qconfig_mapping (e.g. Conv2d, Linear), insert the corresponding observer or fake_quant node.

📦 Two kinds of modules are inserted:

| Type | Used by | Description |
| --- | --- | --- |
| Observer | PTQ (prepare) | records min/max during calibration and computes scale and zero_point |
| FakeQuantize | QAT (prepare_qat) | simulates quantization error while keeping gradients flowing, so the model can be trained |
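
A small sketch of what an Observer does in isolation, assuming the built-in MinMaxObserver: feed it calibration data, then read back the scale/zero_point derived from the observed min/max.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(3):                 # "calibration": just run data through the observer
    obs(torch.randn(16) * 2 + 1)
scale, zero_point = obs.calculate_qparams()
print(obs.min_val, obs.max_val, scale, zero_point)
```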
3.1 Eager Mode Quantization

```python
import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.bn = torch.nn.BatchNorm2d(1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval for fusion to work
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse the activations to preceding layers, where applicable;
# this needs to be done manually depending on the model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(
    model_fp32, [['conv', 'bn', 'relu']])

# Prepare the model for QAT. This inserts observers and fake_quants in
# the model that will observe weight and activation tensors during calibration.
# The model needs to be set to train for QAT logic to work.
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop (not shown)
training_loop(model_fp32_prepared)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, fuses modules where appropriate,
# and replaces key operators with quantized implementations.
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
```
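
The official snippet above leaves training_loop and input_fp32 undefined; a purely illustrative stand-in (define these before running the snippet) could look like:

```python
input_fp32 = torch.randn(4, 1, 8, 8)

def training_loop(model, num_steps=10):
    # toy objective: just push outputs toward zero so the fake-quant
    # observers see some activations during "training"
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    for _ in range(num_steps):
        opt.zero_grad()
        loss = model(input_fp32).pow(2).mean()
        loss.backward()
        opt.step()
```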
3.2 FX Graph Mode Quantization (maintenance mode)

```python
import torch
from torch.ao.quantization import (
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

model_fp = UserModel()   # UserModel is a placeholder for your own FP32 model

#
# post training dynamic/weight_only quantization
#
# we need to deepcopy if we still want to keep model_fp unchanged after
# quantization, since the quantization APIs change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs is needed to trace the model
example_inputs = (input_fp32,)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# post training static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# quantization aware training for static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
```
3.3 PyTorch 2 Export Quantization

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.export import export_for_training
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)

# initialize a floating point model
float_model = M().eval()
example_inputs = (torch.randn(1, 5),)   # example input matching Linear(5, 10)

# define calibration function
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# Step 1. program capture
# NOTE: this API will be updated to the torch.export API in the future, but the
# captured result should mostly stay the same
m = export_for_training(float_model, example_inputs).module()
# we get a model with aten ops

# Step 2. quantization
# backend developers write their own Quantizer and expose methods that let
# users express how they want the model to be quantized
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
# or prepare_qat_pt2e for Quantization Aware Training
m = prepare_pt2e(m, quantizer)

# run calibration
# calibrate(m, sample_inference_data)
m = convert_pt2e(m)
```
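
After convert_pt2e, the returned GraphModule carries explicit quantize/dequantize ops and can still be called like a regular module; a quick sanity check, assuming the example_inputs tuple defined above:

```python
# sanity check: run the converted (reference quantized) model on the captured example input
out = m(*example_inputs)
print(out.shape)   # torch.Size([1, 10]) for the Linear(5, 10) toy model
```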