当前位置: 首页 > news >正文

Pytorch模型转ONNX部署

开始以为会很困难,但是其实非常方便,下边分两步走:1. pytorch模型转onnx;2. 使用onnx进行inference

0. 准备工作

0.1 安装onnx

安装onnx和onnxruntime,onnx貌似是个环境。。倒是没有直接使用,onnxruntime是一个onnx的架构,方便部署使用的

CPU版本:

pip install onnx -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com
pip install onnxruntime -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com

GPU版本:

pip install onnx -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com
pip install onnxruntime-gpu  -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com

1. pytorch模型转ONNX

### 导出onnx模型
torch.onnx.export(self.network, {'input dict': input dict}, 'home3/medcog/pbliu/test_onnx.onnx')
print('output a onnx model!!!!!!')

坑1:dummy input那里的那个dict:{'input_dict': input_dict},'input_dict'是我network中forward中的参数名字,后边的input_dict是实际的数据,batch size=1。

坑2:只是为了用的话,export三个参数就够了:网络,虚拟输入(bs=1),保存路径。这时候输入的名字会按照顺序被替换掉"onnx::Cast_*",所以你把输入对回去就可以了,我的数据格式修改如下。(并且onnx只接受numpy格式)

onnx_dict = {}
key_prefix = 'onnx::Cast__{}'
onnx_idx = 1
for idx, (k,v) in enumerate(input_dict.items()):if k.startswith('input'):onnx_dict[key_prefix.format(onnx_idx)] = v.numpy()onnx_idx += 1

2. 如何用onnx进行inference

import onnxruntime as rt  
import numpy as np  # 加载 ONNX 模型  
sess = rt.InferenceSession('my_model.onnx', providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])  # 准备好数据onnx_dict# 调用模型进行推理  
result = sess.run(None, onnx_dict)

坑3:这里的sess.run中的None应该类似于tf中希望得到的结果,我这里没有命名,所以就写None了,会默认返回你之前pytorch输出的变量

坑4:sess.run使用的数据onnx_dict就是'onnx::Cast_*'和np array的键值对儿了,你之前在pytorch中定义的输入格式都不重要了,不管你是dict还是啥。

坑5. onnxruntime gpu的时候可能会报错,一个可能是cuda版本不适配的问题,直接在虚拟环境中安装对应版本的cuda就可以

conda install cudatoolkit=10.1
# 版本对照参考https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

一些其他tips:

1. 实操时候遇到一个极蠢的问题,onnx比pytorch慢很多,后来发现是我把初始化写到运行代码中了,每次测试一个数据都会重新初始化一遍。

http://www.lryc.cn/news/167458.html

相关文章:

  • k8s优雅停服
  • 面试题五:computed的使用
  • 完美的分布式监控系统 Prometheus与优雅的开源可视化平台 Grafana
  • 黑马JVM总结(九)
  • 如何使用 RunwayML 进行创意 AI 创作
  • 【css】能被4整除 css :class,判断一个数能否被另外一个数整除,余数
  • ChatGPT与日本首相交流核废水事件-精准Prompt...
  • 关于 firefox 不能访问 http 的解决
  • 68、Spring Data JPA 的 方法名关键字查询
  • Brother CNC联网数采集和远程控制
  • Jenkins 编译 Maven 项目提示错误 version 17
  • 数据结构——排序算法——堆排序
  • 【Spring事务底层实现原理】
  • docker快速安装redis,mysql,minio,nacos等常用软件【持续更新】
  • SCRUM产品负责人(CSPO)认证培训课程
  • python连接mysql数据库的练习
  • 扩散模型在图像生成中的应用:从真实样例到逼真图像的奇妙转变
  • Windows 打包 Docker 提示环境错误: no DOCKER_HOST environment variable
  • 2023.9.8 基于传输层协议 UDP 和 TCP 编写网络通信程序
  • 单例模式,适用于对象唯一的情景(设计模式与开发实践 P4)
  • C语言实现三子棋游戏(详解)
  • javaee之黑马乐优商城3
  • Pytorch intermediate(二) ResNet
  • 【2023集创赛】加速科技杯作品:高光响应的二硫化铼光电探测器
  • 编写postcss插件,全局css文件px转vw
  • 精品SpringCloud的B2C模式在线学习网微服务分布式
  • 解决vue项目导出当前页Table为Excel
  • C++设计模式_04_Strategy 策略模式
  • 目标检测YOLO实战应用案例100讲-基于YOLOv3多模块融合的遥感目标检测(中)
  • element 表格fixed列高度无法100%