当前位置: 首页 > news >正文

TensorRT量化工具pytorch_quantization代码解析(一)

量化工具箱pytorch_quantization 通过提供一个方便的 PyTorch 库来补充 TensorRT ,该库有助于生成可优化的 QAT 模型。该工具包提供了一个 API 来自动或手动为 QAT 或 PTQ 准备模型。

API 的核心是 TensorQuantizer 模块,它可以量化、伪量化或收集张量的统计信息。它与 QuantDescriptor 一起使用,后者描述了如何量化张量。在 TensorQuantizer 之上的是量化模块,这些模块被设计为 PyTorch 全精度模块的替代品。这些是使用 TensorQuantizer 对模块的权重和输入进行伪量化或收集统计信息的方便模块。

API 支持将 PyTorch 模块自动转换为其量化版本。转换也可以使用 API 手动完成,这允许在不想量化所有模块的情况下进行部分量化。例如,一些层可能对量化更敏感,并且使其未量化可提高任务精度。

量化第一步是将量化器模块添加到神经网络图中。该包提供了许多量化层模块,其中包含用于输入和权重的量化器。例如quant_nn.QuantLinear,它可以用来代替nn.Linear。这些量化层可以通过猴子修补或手动修改模型定义来自动替换。自动层替换是使用quant_module完成的。这应该在创建模型之前调用。

首先看以下代码:

from pytorch_quantization import quant_modules
quant_modules.initialize()

initialize()会动态地修改 PyTorch 代码,适用于每个模块的所有实例,将 torch.nn.module 的一些子类替换为对应的量化版本。如果不希望所有模块都量化,则应手动替换量化模块。独立量化器也可以添加到带有quant_nn.TensorQuantizer的模型中。

initialize()位于:tools\pytorch-quantization\pytorch_quantization\quant_modules.py,作用使用使用monkey patching进行动态模块更换为量化版本

什么是猴子补丁

  • Python是一种典型的动态脚本语言。它不仅具有 动态类型(dynamic type) ,而且它的 对象模型(object model) 也是动态的。Python的类是可变的(mutable),方法(methods)只是类的属性(attributes);这允许我们在 运行时(run time) 修改其行为。这被称为猴子补丁(Monkey Patching), 它指的是偷偷地更改代码。
  • Monkey Patching只是在 运行时(run time) 动态替换属性(attributes)。而在Python中,术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改。
def initialize(float_module_list=None, custom_quant_modules=None):"""用量化版本动态地替换模块。在内部,状态由helper类对象维护,该对象有助于将原始模块替换回去。参数:float_module_list:列表,用户提供的列表,其中指明哪些模块不可执行替换custom_quant_modules:一个字典。用户提供的映射,用于指示除torch.nn及其相应量化版本之外的任何其他模块。Returns:空"""# 准备monkey patching中使用的内部变量quant_map和orginal_func_map_quant_module_helper_object.prepare_state(float_module_list, custom_quant_modules)#执行量化模块替换_quant_module_helper_object.apply_quant_modules()def deactivate():"""动态模块更换,可逆转monkey patching使用维护状态的helper类对象动态地替换回先前在initialize()函数调用中被monkey patching的原始模块。"""_quant_module_helper_object.restore_float_modules()# 维护被替换模块状态的全局对象。
_quant_module_helper_object = QuantModuleReplacementHelper()

自定义量化模块使用示例:

# torch.nn模块定义不可执行替换列表
float_module_list = ["Linear"]
# torch.nn以外的模块自定义映射
custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]
# Monkey修补模块
pytorch_quantization.quant_modules.initialize(float_module_list, custom_modules)
# 使用量化模块
pytorch_quantization.quant_modules.deactivate()

继续看helperQuantModuleReplacementHelper

class QuantModuleReplacementHelper():"""帮助量化版本替换torch.nn模块术语monkey patch指的是对函数(function)、类(class)或模块(module)的动态(或运行时)修改该模块用工具内部实现或任何其他用户提供的自定义模块提供的量化版 替换(通过monkey patching)torch.nn模块属性:orginal_func_map:一个dict.维护原始torch.nn模块字典quant_support_list:列表,包含工具提供的量化版本的模块名称quant_map:一个字典,包含模块名称及其量化版本的字典quant_switch_opt:一个字典,用于指示哪些模块不能替换其量化版本。该dict由用户提供的列表更新,该列表指示在monkey patching中要忽略的模块"""def __init__(self):# 保留要更换的原始模块self.orginal_func_map = set()# 默认情况下,维护工具支持的量化模块列表self.default_quant_map = _DEFAULT_QUANT_MAP# 保存最终量化模块。self.quant_map = set()

_DEFAULT_QUANT_MAP是包含量化模块映射的文件的全局成员

_DEFAULT_QUANT_MAP = [_quant_entry(torch.nn, "Conv1d", quant_nn.QuantConv1d),_quant_entry(torch.nn, "Conv2d", quant_nn.QuantConv2d),_quant_entry(torch.nn, "Conv3d", quant_nn.QuantConv3d),_quant_entry(torch.nn, "ConvTranspose1d", quant_nn.QuantConvTranspose1d),_quant_entry(torch.nn, "ConvTranspose2d", quant_nn.QuantConvTranspose2d),_quant_entry(torch.nn, "ConvTranspose3d", quant_nn.QuantConvTranspose3d),_quant_entry(torch.nn, "Linear", quant_nn.QuantLinear),_quant_entry(torch.nn, "LSTM", quant_nn.QuantLSTM),_quant_entry(torch.nn, "LSTMCell", quant_nn.QuantLSTMCell),_quant_entry(torch.nn, "AvgPool1d", quant_nn.QuantAvgPool1d),_quant_entry(torch.nn, "AvgPool2d", quant_nn.QuantAvgPool2d),_quant_entry(torch.nn, "AvgPool3d", quant_nn.QuantAvgPool3d),_quant_entry(torch.nn, "AdaptiveAvgPool1d", quant_nn.QuantAdaptiveAvgPool1d),_quant_entry(torch.nn, "AdaptiveAvgPool2d", quant_nn.QuantAdaptiveAvgPool2d),_quant_entry(torch.nn, "AdaptiveAvgPool3d", quant_nn.QuantAdaptiveAvgPool3d),]

_quant_entry定义命名元组,用于存储量化模块映射,它拥有三个属性orig_mod mod_name replace_mod

_quant_entry = namedtuple('quant_entry', 'orig_mod mod_name replace_mod')

QuantModuleReplacementHelper类的属性方法:

  • prepare_state 准备稍后在monkey patching机制中使用的量化模块的命名字典quant_map和更换为原始模块orginal_func_map
    • 设置torch.nn工具支持的量化模块列表
    • 为torch.nn以外的模块设置自定义映射
    • 使用float_module_list关闭用户指示模块的monkey patching替换
    def prepare_state(self, float_module_list=None, custom_map=None):""""""# 对于支持的默认量化模块,生成quant_mapfor item in self.default_quant_map:if float_module_list is not None and item.mod_name in float_module_list:# 如果float_module_list中存在此模块,则跳过此模块continueelse:# 将模块追加到将在monkey patching中使用的变量中self.quant_map.add(item)# 存储要在反向monkey patching中使用的原始模块self.orginal_func_map.add(_quant_entry(item.orig_mod, item.mod_name,getattr(item.orig_mod, item.mod_name)))# 将自定义模块添加到quant_mapif custom_map is not None:for item in custom_map:# 将自定义模块附加到将在monkey补丁中使用的列表中# 将元组转换为命名元组self.quant_map.add(_quant_entry(item[0], item[1], item[2]))# 将原始模块存储在另一个列表中,该列表将用于反向monkey patchingself.orginal_func_map.add(_quant_entry(item[0], item[1], getattr(item[0], item[1])))
  • apply_quant_modules:根据quant_map,执行替换为量化模块
    def apply_quant_modules(self):for entry in self.quant_map:# 用于设置属性值,该属性不一定是存在的,对应函数 getattr()setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
  • restore_float_modules:通过使用orginal_func_map替换回原始模块,反转monkey patch的效果
    def restore_float_modules(self):for entry in self.orginal_func_map:setattr(entry.orig_mod, entry.mod_name, entry.replace_mod)
http://www.lryc.cn/news/40051.html

相关文章:

  • 【Kubernetes】第二十七篇 - 布署前端项(下)
  • 【MFC】两个ListBox控件数据交互
  • sklearn库学习--SelectKBest 、f_regression
  • 蓝桥杯刷题第十三天
  • CPU 和带宽之间的时空权衡
  • ES+Redis+MySQL,这个高可用架构设计太顶了!
  • 【Maven】Maven的常用命令
  • python的循环结构
  • 五种Python中字典的高级用法
  • [蓝桥杯单片机]——八到十一届初赛决赛客观题
  • 多线程(初阶)
  • 【Vue从入门到进阶】Node.js安装与配置
  • python 正则使用详解
  • 一个深度学习项目需要什么
  • 【Java进阶篇】—— 常用类和基础API
  • 手敲Mybatis(六)-反射工具天花板
  • 内含18禁~~关于自学\跳槽\转行做网络安全行业的一些建议
  • 春分策划×运维老王主讲:CMDB数据运营精准化公开课启动报名啦!
  • 制作INCA和CANape通用的A2L
  • Python人脸识别
  • 我用Python写了一个下载网站所有内容的软件,可见即可下,室友表示非常好用
  • 【M365运维】扩充OneDrive存储空间
  • hashcat(爆破工具,支持GPU,精)
  • 【机器学习】什么是监督学习、半监督学习、无监督学习、自监督学习以及弱监督学习
  • HashiCorp packer 制作AWS AMI镜像示例
  • 【java基础】根据泛型动态构造jackson的TypeReference(json反序列化为带泛型的类的对象)
  • 为什么VMware会给我多创建了两个网络呢?Windows和Linux为什么可以彼此ping的通呢
  • 服务器带宽承载多少人同时访问计算方法-浏览器中查看当前网页所有资源数据大小-客服系统高并发承载人数【唯一客服】...
  • 给新手----编译VSOMEIP保姆级别教程
  • MarkDown设置上下标