当前位置: 首页 > news >正文

重生学AI第十九集:VGG16的使用以及模型的保存与加载

1. VGG模型的认识与使用

1.1 认识VGG

1.1.1 VGG是什么?

VGG 是一种经典的卷积神经网络(CNN)结构,全名叫 Visual Geometry Group Network,是牛津大学视觉几何组(VGG)提出的。
最著名的是 VGG16VGG19,数字代表网络的层数(只包含卷积层和全连接层)。

1.1.2 为什么要用它?

VGG 在 图像分类任务 中取得了很好的效果,是很多现代网络的灵感来源。我们可以用它作为网络模型的基础,在此基础上进行修改添加来完成我们需要的模型。

1.1.3 他有什么特点?

VGG 结构非常简洁统一

  • 所有卷积层都用 3×3 小卷积核(方便堆叠、提取细节)

  • 所有池化层都是 2×2 的最大池化

  • 网络越深,特征表达能力越强

1.1.4 官方文档

VGG — Torchvision 0.22 documentation

1.2 VGG16的使用

1.2.1 网络结构

  • 13 个卷积层

  • 5 个最大池化层(不参与学习)

  • 3 个全连接层(最后一个是分类输出)

1.2.2 参数

  • weights :是否使用官方训练好的权重
    • weights = VGG16_Weights.DEFAULT (使用官方训练过的权重)
    • 不填写则无权重,原始模型参数
  • progress :
    • True : 显示进度条(默认)
    • False : 不显示进度条

1.2.3 模型初始化

import torchvision.models
from torchvision.models import VGG16_Weights#初始化VGG16模型,eval表示切换到模型的评估模式,用于测试,模型更稳定
model_vgg16 = torchvision.models.vgg16(weights= VGG16_Weights.DEFAULT).eval()
print(model_vgg16)

控制台输出:

1.2.4 增加模型

从控制台的输出可以看到,最终输出1000个结果,现在让他输出10个,那么我们只需要再加一个全连接层即可

#增加一个全连接层
model_vgg16.add_module("add_linear",nn.Linear(1000, 10))
print(model_vgg16)

控制台输出:

1.2.5 增加模型到容器

# 添加模型到容器
model_vgg16.classifier.add_module("add_linear_seq",nn.Linear(1000, 10))
print(model_vgg16)

控制台输出:

1.2.6 修改原有的模型

# 修改原有模型
model_vgg16.classifier[6]=nn.Linear(4096, 10)
print(model_vgg16)

控制台输出

2. 网络模型的保存与加载

当我们训练好模型,就可以将模型以文件的方式保存下来,下一次使用时,直接加载模型即可恢复到上一次保存模型时的状态。

2.1 网络模型的保存与加载方式一

2.1.1 保存模型方式一

保存模型使用的函数是torch.save(),需要填写两个参数

  • 参数一:训练好的模型名称
  • 参数二:要保存的模型文件名,以 .pth 结尾
#保存模型 方式一
torch.save(model_vgg16,"vgg16_method1.pth")

运行之后,当前文件夹就会出现一个文件

2.1.2 加载模型方式一

加载模型的方式和保存的方式是一一对应的,用方式一保存的就要用方式一来加载

加载模型需要用到torch.load(),同样需要两个参数

  • 参数一:需要加载的模型文件名
  • 参数二:是否只加载模型的权重,因为默认是只加载权重的(True),所以需要显式声明为False,完整的加载整个模型

到这里我们就可以想到,这种方式其实是不太方便的,因为要保存整个模型文件,所以这个官方并不推荐

#加载模型 方式一
model1 = torch.load("vgg16_method1.pth",weights_only=False)
print(model1)

运行结果:

打印后就可以看到,跟我们之前训练的模型一模一样

2.2 网络模型的保存与加载方式二

2.2.1 保存方式二

方式二使用的函数和方式一一样,只不过参数略有调整,第一个参数不再直接传入模型,而是将模型转换为一个字典,这个字典保存的是模型的权重

#保存模型 方式二
torch.save(model_vgg16.state_dict(),"vgg16_method2.pth")

2.2.2 加载方式二

当然了,我们保存的是权重,加载的时候加载的也只是权重罢了

#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
print(model2)

所以打印后是这样的

这种方式我们就需要先初始化模型

#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
#初始化模型 因为我们要加载上次的权重,所以这里初始化的权重就可以不用了
vgg16 = torchvision.models.vgg16()
#加载我们保存的权重
vgg16.load_state_dict(model2)
print(vgg16)

这里运行后报错了,因为我们之前的模型,添加了三个层,而初始化的模型是没有这三个层的,所以这里也要加三个,跟之前的模型层保持一致

#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
#初始化模型 因为我们要加载上次的权重,所以这里初始化的权重就可以不用了
vgg16 = torchvision.models.vgg16()#添加网络层 与加载的模型层保持一致
vgg16.classifier[6] = nn.Linear(4096, 10)
vgg16.add_module("add_linear",nn.Linear(1000, 10))
vgg16.classifier.add_module("add_linear_seq",nn.Linear(1000, 10))#加载我们保存的权重
vgg16.load_state_dict(model2)
print(vgg16)

控制台输出结果:

成功,拜拜

http://www.lryc.cn/news/596545.html

相关文章:

  • 【期末考试复习】计算机组成原理 - 直接补码阵列乘法器
  • 【接口自动化】pytest的基本使用
  • CSS+JavaScript 禁用浏览器复制功能的几种方法
  • web登录页面
  • 黑马点评练习题-给店铺类型查询业务添加缓存(String和List实现)
  • kafka4.0集群部署
  • 数据结构01:链表
  • docker compose 安装使用笔记
  • Docker实战:使用Docker部署TeamMapper思维导图工具
  • 【实时Linux实战系列】基于实时Linux的传感器网络设计
  • Spring Boot音乐服务器项目-登录模块
  • 【论文阅读】Fast-BEV: A Fast and Strong Bird’s-Eye View Perception Baseline
  • 基于VU13P的百G光纤FMC高性能处理板
  • Rust实战:决策树与随机森林实现
  • 板凳-------Mysql cookbook学习 (十二--------5)
  • 【RAG优化】PDF复杂表格解析问题分析
  • 阶段1--Linux中的文件服务器(FTP、NAS、SSH)
  • 从差异到协同:OKR 与 KPI 的管理逻辑,Moka 让适配更简单
  • 苹果app应用ipa文件程序开发后如何运行到苹果iOS真机上测试?
  • C# 析构函数
  • 【论文阅读 | TIV 2024 | CDC-YOLOFusion:利用跨尺度动态卷积融合实现可见光-红外目标检测】
  • 2025年07月22日Github流行趋势
  • 坑机介绍学习研究
  • 激活函数Focal Loss 详解​
  • 数组——初识数据结构
  • DMZ网络安全基础知识
  • [3-02-02].第04节:开发应用 - RequestMapping注解的属性2
  • Fluent许可与网络安全策略
  • 【kubernetes】-2 K8S的资源管理
  • Java数据结构——ArrayList