当前位置: 首页 > news >正文

Lnton羚通关于PyTorch的保存和加载模型基础知识

SAVE AND LOAD THE MODEL (保存和加载模型)

PyTorch 模型存储学习到的参数在内部状态字典中,称为 state_dict, 他们的持久化通过 torch.save 方法。

model = models.shufflenet_v2_x0_5(pretrained=True)
torch.save(model, "../../data/ShuffleNetV2_X0.5.pth")

如果要加载模型的话,首先需要实例化一个同类型的模型对象,然后用 load_state_dict() 方法加载参数。

model = models.shufflenet_v2_x0_5()
model.load_state_dict(torch.load("../../data/ShuffleNetV2_X0.5.pth"))
model.eval()
Output exceeds the size limit. Open the full output data in a text editor
ShuffleNetV2((conv1): Sequential((0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(stage2): Sequential((0): InvertedResidual((branch1): Sequential((0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): ReLU(inplace=True))(branch2): Sequential((0): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)(4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)(6): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): ReLU(inplace=True)
...(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(fc): Linear(in_features=1024, out_features=1000, bias=True)
)

Saving and Loading Models with Shapes
当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能想要保存类的结构以及模型,在这种情况下,我们可以将 model (而不是 model.state_dict() ) 传递给保存函数:
 

torch.save(model, "../../data/ShuffleNetV2_X0.5_eval2.pth")

加载模型如这样:

model = torch.load("../../data/ShuffleNetV2_X0.5_eval2.pth")
print(model)

这种方法在序列化模型时使用 Python pickle 模块,因此它依赖于加载模型时可用的实际类定义。

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

http://www.lryc.cn/news/129831.html

相关文章:

  • python+django+mysql项目实践四(信息修改+用户登陆)
  • sCrypt编程马拉松于8月13日在复旦大学成功举办
  • Selenium手动和自动两种方式启动Chrome驱动
  • 《PostgreSQL 开发指南》第32篇 物化视图
  • 【RocketMQ】快速入门
  • AB跳转轮询:让你的独立站收款智能化
  • 所有用户都能使用sudo吗
  • 【广州华锐视点】VR警务教育实训系统模拟真实场景进行实践训练
  • 【深入浅出C#】章节 7: 文件和输入输出操作:处理文本和二进制数据
  • Matlab中图例的位置(图例放在图的上方、下方、左方、右方、图外面)等
  • 【算法学习】两数之和II - 输入有序数组
  • 聚观早报|京东称在技术投入没有止境;木蚁机器人完成B2轮融资
  • C语言:选择+编程(每日一练)
  • 信道数据传输速率、码元传输速率、调制速度,信号传播速度之间的关系
  • docker的使用方法总结
  • 【C#】条码管理操作手册
  • RabbitMq-发布确认高级(避坑指南版)
  • Blender增强现实3D模型制作指南【AR】
  • Java查看https证书过期时间(JKS,CERT)
  • 关于vue,记录一次修饰符.stop和.once的使用,以及猜想。
  • 解决git reset --soft HEAD^撤销commit时报错
  • 【BASH】回顾与知识点梳理(三十四)
  • Python可视化在量化交易中的应用(11)_Seaborn折线图
  • 无涯教程-TensorFlow - TensorBoard可视化
  • [uni-app] uview封装Popup组件,处理props及v-model的传值问题
  • 【C++】int a;和int *p=new int;有什么区别?
  • redis事务管理
  • TPS_C++版本及功能支持备注
  • 同步jenkinsfile流水线(sync-job)
  • STM32单片机WIFI-APP智能温室大棚系统CO2土壤湿度空气温湿度补光