当前位置: 首页 > news >正文

Pytorch图像分类模型模型实时在线验证代码

1.训练并保存自己的模型

保存的模型格式为:XXX.pth

torch.save(model, "./weight/last.pth")if best_acc <(validation_acc / len_val):torch.save(model, "./weight/best.pth")

2.转化为ONNX格式

2.1环境安装(window10)

pip install onnx
pip install onnxruntime#验证安装配置是否成功
import torch
print('PyTorch 版本', torch.__version__)import onnx
print('ONNX 版本', onnx.__version__)import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)

2.2.pth格式转ONNX格式

import torch
from torchvision import models# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)model = torch.load('best.pth')
model = model.eval().to(device)
x = torch.randn(1, 3, 256, 256).to(device)  #这里要构造一个数据,保证和自己输入的图片大小一致3*256*256
output = model(x)  #output.shape = torch.Size([1, 10])  这是一个10分类问题#Pytorch模型转ONNX模型
x = torch.randn(1, 3, 256, 256).to(device)with torch.no_grad():torch.onnx.export(model,                   # 要转换的模型x,                       # 模型的任意一组输入'best.onnx', # 导出的 ONNX 文件名opset_version=11,        # ONNX 算子集版本input_names=['input'],   # 输入 Tensor 的名称(自己起名字)output_names=['output']  # 输出 Tensor 的名称(自己起名字)import onnx# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_fruit30.onnx')# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print('无报错,onnx模型载入成功')

这是project中就出现了“best.onnx”文件,表示转化ONNX格式成功!

3.可视化实时检测

3.1在PC电脑端查看

3.1.1环境安装(待补充)

pip install onnxruntime
需要提前保存一个类别ID和类别名称对应的文件

3.1.2 摄像头实时捕捉并分类

import onnxruntime
import torch
from torchvision import transforms
import torch.nn.functional as F
import pandas as pd
import numpy as np
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 32)#载入ONNX模型,获取ONNX Runtime推力器
ort_session = onnxruntime.InferenceSession('best.onnx')#载入类别和ID对应字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(256),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
http://www.lryc.cn/news/347206.html

相关文章:

  • Java高并发场景(银行转账问题)
  • TypeScript 工具类型
  • [Kotlin]创建一个私有包并使用
  • 鸿蒙应用开发者高级认证指南及参考资料整理(含详细参考答案)
  • 数据匿名化技术
  • HTML学习笔记汇总
  • 初始JSVMP
  • 【机器学习数据可视化-04】Pyecharts数据可视化宝典
  • 通过 Java 操作 redis -- zset 有序集合基本命令
  • 力扣 516. 最长回文子序列 python AC
  • 数据库编程
  • (docker)进入容器后如何使用本机gpu
  • java基础知识点总结2024版(8万字超详细整理)
  • vue中使用element的i18n语言转换(保姆式教程-保证能用)
  • 01 设计模式--单例模式
  • css backdrop-filter 实现背景滤镜
  • AR人脸道具SDK解决方案,实现道具与人脸的自然融合
  • Windows安装RabbitMQ教程(附安装包)
  • 这个问题无人能解,菜鸟勿进
  • 揭秘高效引流获客的艺术:转化技巧大公开
  • 【Unity 鼠标输入检测】
  • LeetCode hot100-33-Y
  • C++和Python通信引文道路社评电商大规模行为图结构数据模型
  • 单片机-点亮第一盏灯
  • C++组合类
  • Linux学习笔记3
  • 免费证件照一键换底色
  • 使用 FFmpeg 从音视频中提取音频
  • GraphQL在现代Web应用中的应用与优势
  • socket编程 学习笔记 理解