当前位置: 首页 > news >正文

tensorrt

engine

/*tensorrt创建builder1. 创建builder2. 创建网络定义:builder-->network3. 配置参数:builder-->config4. 生成engine:builder-->engine()5. 序列化保存:engine-->serialize6. 释放资源:delete
*/
#include<iostream>
#include<NvInfer.h>
#include <fstream>
#include <assert.h>
class TRTLogger : public nvinfer1::ILogger {void log(Severity severity, const char *msg) noexcept override {if (severity != Severity::kINFO) {std::cout << msg << std::endl;}}
}gLogger;int main() {// 1. 创建builderTRTLogger logger;nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(logger);// 2. 创建网络定义auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);nvinfer1::INetworkDefinition *network = builder->createNetworkV2(explicitBatch);// 定义网络结构// 3. 配置参数// 添加配置参数,告诉tensorRT如何优化网络nvinfer1::IBuilderConfig *config = builder->createBuilderConfig();//设置最大工作空间,单位:字节config->setMaxWorkspaceSize(1 << 20);// 4. 生成enginenvinfer1::ICudaEngine *engine = builder->buildEngineWithConfig(*network, *config);if (!engine) {std::cout << "创建失败" << std::endl;return -1;}//5. 序列化nvinfer1::IHostMemory *serialized_engine = engine->serialize();// 存入文件std::ofstream outfile("model/mlp.engine", std::ios::binary);assert(outfile.is_open() && "打开失败");outfile.write((char *)serialized_engine->data(), serialized_engine->size());// 释放资源outfile.close();}

runtime推理

/*
使用cu文件时希望使用cuda的编译器,会自动链接cuda库
runtime推理过程
1. 创建一个runtime对象
2. 反序列化申城engine:runtime-->engine
3. 创建一个执行上下文ExecutionContext:engine-->context
4. 填充数据
5. 执行推理
6. 释放资源
*/#include<iostream>
#include<vector>
#include<fstream>
#include<cassert>#include"cuda_runtime.h"
#include"NvInfer.h"class TRTLogger : public nvinfer1::ILogger {void log(Severity severity, const char *msg) noexcept override {if (severity != Severity::kINFO) {std::cout << msg << std::endl;}}
}gLogger;// 加载模型
std::vector<unsigned char>loadEngineModel(const std::string &filename) {std::ifstream file(filename, std::ios::binary);   //二进制形式读取assert(file.is_open && "打开文件失败");// 定位到文件末尾file.seekg(0, std::ios::end);size_t size = file.tellg();		//获取文件大小std::vector<unsigned char> data(size);		//创建一个vector,大小为sizefile.seekg(0, std::ios::beg);				//定位到文件开头file.read((char *)data.data(), size);		// 读取文件内容到datefile.close();return data;}int main() {TRTLogger logger;nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(logger);// 反序列化生成engineauto engineModel = loadEngineModel("/mlp.engine");/*调用runtime反序列化engineModel.data():模型数据地址engineModel.size():模型大小nullptr:pluginFactory*/nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(engineModel.data(),engineModel.size(),nullptr);if (!engine) {std::cout << "反序列化失败" << std::endl;return -1;}// 创建一个执行上下文nvinfer1::IExecutionContext *context = engine->createExecutionContext();//填充数据:host-->device-->inference-->host//输入数据float *host_input_data = new float[3]{ 2,4,8 };		//host输入数据int input_data_size = 3 * sizeof(float);			//输入数据大小float *device_input_data = nullptr;					//device输入数据float *host_output_data = new float[2];				//输出数据int output_data_size = 2 * sizeof(float);			//输出数据大小float *device_output_data = nullptr;					//device输出数据cudaMalloc((void **)&device_input_data, input_data_size);cudaMalloc((void **)&device_output_data, output_data_size);cudaStream_t stream = nullptr;cudaStreamCreate(&stream);/*host-->devicedevice_input_data目的地址host_input_data源地址input_data_size数据大小cudaMemcpyHostToDevice拷贝方式stream*/cudaMemcpyAsync(device_input_data, host_input_data, input_data_size, cudaMemcpyHostToDevice,stream);//bindings告诉context输入输出数据位置float * bindings[] = { device_input_data,device_output_data };// 进行推理bool success = context->enqueueV2((void **)bindings, stream, nullptr);// device-->hostcudaMemcpyAsync(host_output_data, device_output_data, output_data_size, cudaMemcpyDeviceToHost, stream);cudaStreamSynchronize(stream);std::cout << host_output_data << std::endl;//释放资源
}
http://www.lryc.cn/news/498299.html

相关文章:

  • 利用Grounding DINO进行自动标注——目标检测任务——YOLO格式
  • 网际协议(IP)与其三大配套协议(ARP、ICMP、IGMP)
  • uniapp 添加loading
  • cocotb pytest
  • docker run 设置启动命令
  • docker入门 自记录
  • css实现圆周运动效果
  • 【NoSQL数据库】MongoDB数据库——集合和文档的基本操作(创建、删除、更新、查询)
  • Dart 学习笔记(一)
  • 安防视频监控平台Liveweb视频汇聚管理系统管理方案
  • 十八(GIT)、GIT基本命令、axios别名方法、黑马就业数据平台(axios基地址、轻提示函数、注册及登录功能)
  • Linux查看系统基本信息
  • Word处理表格的一些宏
  • 将本地项目文件推送到Git仓库中
  • 2024-12-05OpenCV高级-滤波与增强
  • vue3中 axios 发送请求 刷新token 封装axios
  • aardio - 汉字笔顺处理 - json转sqlite转png
  • 数据结构学习笔记 双向链表
  • 深度学习作业十 BPTT
  • html+css+JavaScript实现轮播图
  • Python+onlyoffice 实现在线word编辑
  • PostgreSQLt二进制安装-contos7
  • Neo4j启动时指定JDK版本
  • kanzi3.6.10 窗口插件-美化绑定内容
  • 利用tablesaw库简化表格数据分析
  • 记录一下,解决js内存溢出npm ERR! code ELIFECYCLEnpm ERR! errno 134 以及 errno 9009
  • 【JavaWeb后端学习笔记】MySQL的数据查询语言(Data Query Language,DQL)
  • 360 最新Android面试题及参考答案
  • 《操作系统 - 清华大学》6 -3:局部页面置换算法:最近最久未使用算法 (LRU, Least Recently Used)
  • ES6新增了哪些特性(待更新)