当前位置: 首页 > news >正文

libtorch的c++,加载*.pth

一、转换模型为TorchScript

前提:python只保存了参数,没存结构

要在C++中使用libtorch(PyTorch的C++接口),读取和加载通过torch.save保存的模型(    torch.save(pdn.state_dict()这种方式,只保存了参数,没存结构),需要转换模型为TorchScript。在python下实现。

def get_pdn_small(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4))def get_pdn_medium(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=4),nn.ReLU(inplace=True),nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=1))
import torch# 假设你有一个已训练的模型
model = get_pdn_small()# 加载模型的state_dict
model.load_state_dict(torch.load('teacher_small.pth'))
model.eval()  # 设置模型为评估模式# 将模型转化为TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('teacher_small.pt')

二、在C++中加载TorchScript模型

在C++中,你可以使用torch::jit::load来加载.pt文件,如下所示:

#include <torch/script.h>  // One-stop header for loading TorchScript models
#include <iostream>
#include <memory>int main() {// 加载TorchScript模型try {// 加载模型std::shared_ptr<torch::jit::Module> model = std::make_shared<torch::jit::Module>(torch::jit::load("teacher_small.pt"));std::cout << "Model loaded successfully!" << std::endl;// 你可以在这里使用模型进行推理,比如输入一个张量// 例如,如果输入是一个3x224x224的图像,你需要创建一个相应的Tensortorch::Tensor input = torch::randn({1, 3, 224, 224});  // 示例输入std::vector<torch::jit::IValue> inputs;inputs.push_back(input);// 执行模型推理at::Tensor output = model->forward(inputs).toTensor();std::cout << "Output tensor: " << output << std::endl;}catch (const c10::Error& e) {std::cerr << "Error loading the model: " << e.what() << std::endl;return -1;}
}

http://www.lryc.cn/news/533313.html

相关文章:

  • 去除 RequestTemplate 对象中的指定请求头
  • b s架构 网络安全 网络安全架构分析
  • 【DeepSeek论文精读】2. DeepSeek LLM:以长期主义扩展开源语言模型
  • Spring Boot和SpringMVC的关系
  • java基础4(黑马)
  • nodejs - vue 视频切片上传,本地正常,线上环境导致磁盘爆满bug
  • 注意力机制(Attention Mechanism)和Transformer模型的区别与联系
  • C++,设计模式,【单例模式】
  • C++:类和对象初识
  • 官网下载Redis指南
  • 活动预告 |【Part1】 Azure 在线技术公开课:迁移和保护 Windows Server 和 SQL Server 工作负载
  • 【Linux系统编程】五、进程创建 -- fork()
  • 深入解析 STM32 GPIO:结构、配置与应用实践
  • 深入探究 C++17 std::is_invocable
  • Vmware网络模式
  • 神经辐射场(NeRF):从2D图像到3D场景的革命性重建
  • 深入解析AI技术原理
  • PDF 2.0 的新特性
  • Matlab机械手碰撞检测应用
  • (root) Additional property include:is not allowed
  • react 18父子组件通信
  • FastReport 加载Load(Stream) 模板内包含换行符不能展示
  • Maven 中常用的 scope 类型及其解析
  • vue3:点击子组件进行父子通信
  • Composo:企业级AI应用的质量守门员
  • Jackson扁平化处理对象
  • Java即时编译器(JIT)的原理及在美团的实践经验
  • 使用 Ollama 在 Windows 环境部署 DeepSeek 大模型实战指南
  • 算法基础之八大排序
  • 使用TensorFlow和Keras构建卷积神经网络:图像分类实战指南