
A script that automatically configures default dynamic shapes when converting ONNX to TRT

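When converting an ONNX model with dynamic shapes to TensorRT, you normally have to specify three shapes for every dynamic input: the minimum, the optimal and the maximum. For quick tests we would rather not write all of that by hand, so can the three shapes be configured automatically? Here is a first version that does it.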
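The core rule can be isolated from TensorRT. The sketch below uses a hypothetical helper, `default_profile` (not part of the original script), that takes an input shape in which dynamic dimensions are marked by a non-positive value (a symbolic ONNX dimension reads back as `dim_value == 0`) and returns the min/opt/max triple, keeping static dimensions and substituting 1, 256 and 512 for dynamic ones, the same defaults the script uses.

```python
def default_profile(shape, lo=1, opt=256, hi=512):
    """Return [min_shape, opt_shape, max_shape] for one input (hypothetical helper).

    Dimensions <= 0 are treated as dynamic; positive dimensions are kept as-is.
    """
    def fill(value):
        # Replace every dynamic dimension with `value`, keep static ones
        return [d if d > 0 else value for d in shape]
    return [fill(lo), fill(opt), fill(hi)]


# An input declared as [batch, 768, seq_len] with batch and seq_len symbolic
print(default_profile([0, 768, 0]))
# [[1, 768, 1], [256, 768, 256], [512, 768, 512]]
```

Every dynamic dimension, including a symbolic batch axis, receives the same 1/256/512 range, which is one reason the script also accepts manually specified shapes.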

```python
import os

import onnx
import tensorrt as trt


def build_engine(onnx_file_path, engine_dest_path, trt_engine_datatype=trt.DataType.HALF,
                 batch_size=1, silent=False, dynamic_shapes=None, max_mem=(1 << 30)):
    """Takes an ONNX file and creates a TensorRT engine to run inference with."""
    # dynamic_shapes maps an input name to [min_shape, opt_shape, max_shape]
    dynamic_shapes = dynamic_shapes or {}
    trt_logger = trt.Logger(trt.Logger.WARNING)
    # Compare the major version as an integer; a string test such as
    # trt.__version__[0] < '7' misfires on TensorRT 10.x ('1' < '7').
    trt_major = int(trt.__version__.split('.')[0])
    EXPLICIT_BATCH = [] if trt_major < 7 else [1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)]
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(*EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, trt_logger) as parser:
        if trt_major < 7:
            builder.max_batch_size = batch_size  # only used by implicit-batch networks
        config = builder.create_builder_config()
        if hasattr(config, 'set_memory_pool_limit'):  # TensorRT >= 8.4
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_mem)
        else:
            config.max_workspace_size = max_mem  # workspace size in bytes
        if trt_engine_datatype == trt.DataType.HALF:
            config.set_flag(trt.BuilderFlag.FP16)  # allow float16 kernels

        # Parse the model file
        if not os.path.exists(onnx_file_path):
            print('ONNX file {} not found.'.format(onnx_file_path))
            return None
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        print('Completed parsing of ONNX file')
        if not silent:
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))

        # Detect dynamic inputs from the ONNX graph
        dynamic_shapes_fin = {}
        mod = onnx.load(onnx_file_path)
        for inp in mod.graph.input:
            shape = []
            dynam = False
            for d in inp.type.tensor_type.shape.dim:
                shape.append(d.dim_value)
                if d.dim_param or d.dim_value <= 0:
                    dynam = True  # dynamic dimension
            # Automatic defaults: each dynamic dimension becomes 1 / 256 / 512,
            # static dimensions are kept as-is
            if dynam:
                shape_min = [(i if i > 0 else 1) for i in shape]
                shape_mid = [(i if i > 0 else 256) for i in shape]
                shape_max = [(i if i > 0 else 512) for i in shape]
                dynamic_shapes_fin[inp.name] = [shape_min, shape_mid, shape_max]
        # Manually specified shapes override the automatic defaults
        for k, v in dynamic_shapes.items():
            dynamic_shapes_fin[k] = v

        if len(dynamic_shapes_fin) > 0:
            print("===> using dynamic shapes!")
            profile = builder.create_optimization_profile()
            for binding_name, dynamic_shape in dynamic_shapes_fin.items():
                min_shape, opt_shape, max_shape = dynamic_shape
                profile.set_shape(binding_name, min_shape, opt_shape, max_shape)
            config.add_optimization_profile(profile)

        if hasattr(builder, 'build_serialized_network'):  # TensorRT >= 8.0
            buf = builder.build_serialized_network(network, config)
        else:
            trt_engine = builder.build_engine(network, config)
            buf = trt_engine.serialize() if trt_engine is not None else None
        if buf is None:
            print('ERROR: Failed to build the engine.')
            return None
        with open(engine_dest_path, 'wb') as f:
            f.write(buf)
```
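The last step of `build_engine` overlays the user's `dynamic_shapes` on the automatically detected ones and passes the result straight to `profile.set_shape`, so a malformed entry only surfaces inside TensorRT. The sketch below moves that merge into a hypothetical helper, `merge_profiles`, and adds two checks TensorRT requires of a profile: the three shapes have the same rank, and min <= opt <= max in every dimension. It does not check that static dimensions match the network's declared input, which TensorRT also requires.

```python
def merge_profiles(auto_shapes, manual_shapes):
    """Overlay manual profiles on the automatic ones and sanity-check them (hypothetical helper)."""
    merged = dict(auto_shapes)
    merged.update(manual_shapes)  # manual entries win, as in build_engine
    for name, (min_s, opt_s, max_s) in merged.items():
        if not (len(min_s) == len(opt_s) == len(max_s)):
            raise ValueError(f"{name}: min/opt/max must have the same rank")
        for lo, mid, hi in zip(min_s, opt_s, max_s):
            if not (lo <= mid <= hi):
                raise ValueError(f"{name}: need min <= opt <= max, got {lo}, {mid}, {hi}")
    return merged


auto = {"ref_seq": [[1, 1], [256, 256], [512, 512]]}
manual = {"ref_seq": [(1, 1), (1, 256), (1, 512)]}
print(merge_profiles(auto, manual)["ref_seq"])
# [(1, 1), (1, 256), (1, 512)]
```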

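Usage: the shapes can be passed in explicitly or omitted; when omitted, the defaults 1, 256 and 512 are used as test values for verification.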

```python
build_engine(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
             f"onnx/{project_name}/{project_name}_t2s_encoder.trt",
             # min_shape, opt_shape, max_shape
             dynamic_shapes={
                 "ref_seq": [(1, 1), (1, 256), (1, 512)],
                 "text_seq": [(1, 1), (1, 256), (1, 512)],
                 "ref_bert": [(1024, 1), (1024, 256), (1024, 512)],
                 "text_bert": [(1024, 1), (1024, 256), (1024, 512)],
                 "ssl_content": [(1, 768, 1), (1, 768, 256), (1, 768, 512)],
             })
build_engine(f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
             f"onnx/{project_name}/{project_name}_t2s_fsdec.trt")
```
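An engine built with such a profile only accepts runtime input shapes that lie, dimension by dimension, between the profile's minimum and maximum; with the automatic defaults, for example, a sequence longer than 512 is rejected. A small pre-check like the hypothetical `fits_profile` below can catch this before calling into TensorRT. It is a sketch of the range rule, not part of the TensorRT API.

```python
def fits_profile(shape, profile):
    """Return True if `shape` lies within [min_shape, max_shape] of `profile` (hypothetical helper)."""
    min_s, _opt_s, max_s = profile
    if len(shape) != len(min_s):
        return False  # rank must match the profile
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_s, max_s))


ssl_profile = [(1, 768, 1), (1, 768, 256), (1, 768, 512)]
print(fits_profile((1, 768, 300), ssl_profile))  # True
print(fits_profile((1, 768, 600), ssl_profile))  # False: 600 > 512
print(fits_profile((2, 768, 300), ssl_profile))  # False: batch fixed to 1
```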
http://www.lryc.cn/news/349813.html
