当前位置: 首页 > news >正文

b站小土堆pytorch学习记录—— P25-P26 网络模型的使用和修改、保存和读取

文章目录

  • 一、修改
    • 1.方法
    • 2.代码
  • 二、保存和读取
    • 1.方法
    • 2.代码
      • (1)保存
      • (2)加载
    • 3.陷阱

一、修改

1.方法

add_module(name: str, module: Module) -> None

name 是要添加的子模块的名称。
module 是要添加的子模块。
调用 add_module 方法会向当前模块中添加一个子模块,并使用指定的名称进行标识。

2.代码

import torchvision
from torch import nn# 实例化一个未经过预训练的 VGG16 模型
vgg16_false = torchvision.models.vgg16(pretrained=False)# 实例化一个经过预训练的 VGG16 模型
vgg16_true = torchvision.models.vgg16(pretrained=True)print("ok")# 输出经过预训练的 VGG16 模型及修改后的模型
print(vgg16_true)
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)# 输出未经过预训练的 VGG16 模型及修改后的模型
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

修改前的vgg16_true:

在这里插入图片描述
修改后的vgg16_true:

在这里插入图片描述

修改前的vgg16_true:

在这里插入图片描述

修改后的vgg16_true:

在这里插入图片描述

二、保存和读取

1.方法

保存: torch.save(要保存的模型,“文件路径”)

加载: torch.load(“文件路径”)

2.代码

(1)保存

import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1:模型结构+模型参数
torch.save(vgg16, "vgg16_module1.pth")# 保存方式2:模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_module2.pth")

(2)加载

import torch
import torchvision# 方式1 加载模型
module1 = torch.load("vgg16_module1.pth")
print(module1)#
module2 = torch.load("vgg16_module2.pth")
print(module2)# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_module2.pth"))
print(vgg16)

运行加载的代码后,打印结果如下

module1:

在这里插入图片描述
module2:

在这里插入图片描述

vgg16:

在这里插入图片描述

可以看到,第二种方式保存的数据,加载后是向量形式,需要通过别的方法加载为模型

3.陷阱

第一种方式加载,在某些条件下可能会报错

例如:

假设自定义一个神经网络,保存:

import torch
import torchvision
from torch import nn# 陷阱
class Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self,x):x = self.conv1(x)return xguodong = Guodong()
torch.save(guodong,"guodong_method1.pth")

在另一个文件中加载:

import torch# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

就会报错:

AttributeError: Can’t get attribute ‘Guodong’ on <module ‘main’ from ‘E:\deepLearning\Pycharm\pytroch_project\theFirstFile\module_load.py’>

解决办法:

(1)把Guodong类放在这个文件里

import torch
from torch import nn
import torchvisionclass Guodong(nn.Module):def __init__(self):super(Guodong,self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self,x):x = self.conv1(x)return x# 陷阱
module = torch.load("guodong_method1.pth")
print(module)

(2)from module_save import *

(module_save)是保存自定义模型的文件

from module_save import *# 陷阱
module = torch.load("guodong_method1.pth")
print(module)
http://www.lryc.cn/news/313172.html

相关文章:

  • [数据结构]OJ用队列实现栈
  • 「优选算法刷题」:最长回文子串
  • Java项目:41 springboot大学生入学审核系统的设计与实现010
  • 【数据结构与算法】常见排序算法(Sorting Algorithm)
  • Unity3D学习之XLua实践——背包系统
  • 前端技术研究越深入,越觉得技术不是决定录用唯一条件。
  • vue组件的重新渲染的问题
  • opengl 学习(二)-----你好,三角形
  • mongodb4.2升级到5.0版本,升级到6.0版本, 升级到7.0版本案例
  • CPU处理器模式与异常
  • Day 53 |● 1143.最长公共子序列 ● 1035.不相交的线 ● 53. 最大子序和
  • ant-desgin charts双轴图DualAxes,柱状图无法立即显示,并且只有在调整页面大小(放大或缩小)后才开始显示
  • 获取别人店铺的所有商品API接口
  • 成都正信:亲戚借了钱一直不还怎么委婉的说
  • Truenas入门级教程
  • 窗口函数dense() over(条件)
  • 蓝牙APP开发实现汽车遥控钥匙解锁汽车智能时代
  • 第三天 Kubernetes进阶实践
  • redis小结
  • PHP伪协议详解
  • 进程:守护进程
  • 千里马平台项目管理理念
  • GB 2312字符集:中文编码的基石
  • 我的创作周年纪念日
  • MySQL为什么要用B+树?
  • 今天分享一个好看的输入法皮肤相信每个人心里住着一个少女心我们美化一下她吧
  • 力扣刷题Days11第二题--141. 环形链表(js)
  • 微信自动回复的设置
  • SpringBoot源码解读与原理分析(一)SpringBoot整体概述
  • 如何选择VR全景设备,才能拍摄高质量的VR全景?