当前位置: 首页 > news >正文

paddlepaddle模型转换onnx指导文档

一、检查本机cuda版本

1、右键找到invdia控制面板

在这里插入图片描述

2、找到系统信息

在这里插入图片描述

3、点开“组件”选项卡, 可以看到cuda版本,我们这里是cuda11.7

在这里插入图片描述

cuda驱动版本为516.94
在这里插入图片描述

二、安装paddlepaddle环境

1、获取pip安装命令 ,我们到paddlepaddle官网,找到cuda对应的安装命令

在这里插入图片描述

因为安装 完成paddlepaddle后还需要安装其他依赖,所以我们加上 -i 指定国内的pip源

python -m pip install -i   https://mirror.baidu.com/pypi/simple  paddlepaddle-gpu==2.5.1.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html

2、在anaconda中新建一个python3.9的环境

conda create -n py39_paddle python=3.9

3、切换conda环境到我们新建的环境

conda activate py39_paddle

4、运行pip安装命令

python -m pip install -i   https://mirror.baidu.com/pypi/simple  paddlepaddle-gpu==2.5.1.post117 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.htmlInstalling collected packages: paddle-bfloat, sniffio, protobuf, Pillow, numpy, idna, h11, exceptiongroup, decorator, certifi, astor, opt-einsum, anyio, httpcore, httpx, paddlepaddle-gpu
Successfully installed Pillow-10.0.1 anyio-4.0.0 astor-0.8.1 certifi-2023.7.22 decorator-5.1.1 exceptiongroup-1.1.3 h11-0.14.0 httpcore-0.18.0 httpx-0.25.0 idna-3.4 numpy-1.26.0 opt-einsum-3.3.0 paddle-bfloat-0.1.7 paddlepaddle-gpu-2.5.1.post117 protobuf-3.20.2 sniffio-1.3.0

安装成功!!

三、模型转换

1、安装转换工具paddle2onnx

python -m pip install -i   https://mirror.baidu.com/pypi/simple  paddle2onnx

2.训练模型

import paddle
from paddle.vision.transforms import Normalizetransform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下载数据集并初始化 DataSet
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),paddle.nn.CrossEntropyLoss(),paddle.metric.Accuracy())# 模型训练
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型评估
model.evaluate(test_dataset, batch_size=64, verbose=1)

3.环境报错

在这里插入图片描述
报错内容: cudnn没有装!

4、安装cudnn,cudatookit,参考:cudnn安装指导

https://www.notion.so/3a4f57edc6e54e4eaa63ed86234cf533?pvs=25

5、训练成功!

在这里插入图片描述

6、模型转换

# export to ONNX
save_path = 'onnx.save/lenet1' # 需要保存的路径
x_spec = paddle.static.InputSpec([None, 1, 28, 28], 'float32', 'x') # 为模型指定输入的形状和数据类型,支持持 Tensor 或 InputSpec ,InputSpec 支持动态的 shape。
paddle.onnx.export(lenet, save_path, input_spec=[x_spec], opset_version=14)

在这里插入图片描述
成功生成onnx文件

7、检查转换结果,没有问题

# 导入 ONNX 库
import onnx
# 载入 ONNX 模型
onnx_model = onnx.load("onnx.save/lenet1.onnx")
# 使用 ONNX 库检查 ONNX 模型是否合理
check = onnx.checker.check_model(onnx_model)
# 打印检查结果
print('check: ', check)
check:  None

四、模型精度测试

1、paddlepaddle模型推理

import onnxruntime
import numpy as np
img = np.random.randn(1, 1, 28, 28).astype(np.float32)
lenet.eval()
paddle_input = paddle.to_tensor(img) 
pad_output = lenet(paddle_input)

2、onnx模型推理

ort_session = onnxruntime.InferenceSession('onnx.save/lenet1.onnx',providers=['CPUExecutionProvider', 'CUDAExecutionProvider'])
model_inputs = ort_session.get_inputs()
ort_inputs = {model_inputs[0].name: img}
onnx_output = ort_session.run(['linear_11.tmp_1'], ort_inputs)[0]

### 3、检查推理 结果

paddle.max(pad_output-onnx_output)
Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=False,0.00000381)
http://www.lryc.cn/news/332824.html

相关文章:

  • 图像处理与视觉感知---期末复习重点(6)
  • git 如何删除本地和远程分支
  • Kong基于QPS、IP限流
  • 基于springboot实现甘肃非物质文化网站系统项目【项目源码+论文说明】
  • 【瑞萨RA6M3】1. 基于 vscode 搭建开发环境
  • 使用pip install替代conda install将packet下载到anaconda虚拟环境
  • 【HTML】常用CSS属性
  • python中的print(f‘‘)具体用法
  • 《青少年成长管理2024》022 “成长七要素之三:文化”4/5
  • Linux(05) Debian 系统修改主机名
  • 之前翻硬币问题胡思乱想的完善
  • 前端与后端协同:实现Excel导入导出功能
  • Docker:探索容器化技术,重塑云计算时代应用交付与管理
  • 畅捷通T+ KeyInfoList.aspx SQL漏洞复现
  • 【面经】interrupt()、interrupted()和isInterrupted()的区别与使用
  • 了解这些技术:Flutter应用顺利登陆iOS平台的步骤与方法
  • 经济学 劳动市场 医疗经济学
  • vue + koa + Sequelize + 阿里云部署 + 宝塔:宝塔数据库连接
  • 华为昇腾认证考试内容有哪些
  • Spring Boot接收从前端传过来的数据常用方式以及处理的技巧
  • EFCore通用数据操作类
  • java Web 辅助学习管理系统idea开发mysql数据库web结构java编程计算机网页源码maven项目
  • 使用Python实现K近邻算法
  • Celery的任务流
  • 使用Arcpy进行数据批处理-批量裁剪
  • 【攻防世界】ics-05
  • VTK的交互器
  • ChatGPT(3.5版本)开放无需注册:算力背后的数据之战悄然打响
  • python项目练习——14.学生管理系统
  • 基于SpringBoot的公益慈善平台