重生学AI第十九集:VGG16的使用以及模型的保存与加载
1. VGG模型的认识与使用
1.1 认识VGG
1.1.1 VGG是什么?
VGG 是一种经典的卷积神经网络(CNN)结构,全名叫 Visual Geometry Group Network,是牛津大学视觉几何组(VGG)提出的。
最著名的是 VGG16 和 VGG19,数字代表网络的层数(只包含卷积层和全连接层)。
1.1.2 为什么要用它?
VGG 在 图像分类任务 中取得了很好的效果,是很多现代网络的灵感来源。我们可以用它作为网络模型的基础,在此基础上进行修改和添加来完成我们需要的模型。
1.1.3 他有什么特点?
VGG 结构非常简洁统一:
-
所有卷积层都用 3×3 小卷积核(方便堆叠、提取细节)
-
所有池化层都是 2×2 的最大池化
-
网络越深,特征表达能力越强
1.1.4 官方文档
VGG — Torchvision 0.22 documentation
1.2 VGG16的使用
1.2.1 网络结构
-
13 个卷积层
-
5 个最大池化层(不参与学习)
-
3 个全连接层(最后一个是分类输出)
1.2.2 参数
- weights :是否使用官方训练好的权重
- weights = VGG16_Weights.DEFAULT (使用官方训练过的权重)
- 不填写则无权重,原始模型参数
- progress :
- True : 显示进度条(默认)
- False : 不显示进度条
1.2.3 模型初始化
import torchvision.models
from torchvision.models import VGG16_Weights#初始化VGG16模型,eval表示切换到模型的评估模式,用于测试,模型更稳定
model_vgg16 = torchvision.models.vgg16(weights= VGG16_Weights.DEFAULT).eval()
print(model_vgg16)
控制台输出:
1.2.4 增加模型
从控制台的输出可以看到,最终输出1000个结果,现在让他输出10个,那么我们只需要再加一个全连接层即可
#增加一个全连接层
model_vgg16.add_module("add_linear",nn.Linear(1000, 10))
print(model_vgg16)
控制台输出:
1.2.5 增加模型到容器
# 添加模型到容器
model_vgg16.classifier.add_module("add_linear_seq",nn.Linear(1000, 10))
print(model_vgg16)
控制台输出:
1.2.6 修改原有的模型
# 修改原有模型
model_vgg16.classifier[6]=nn.Linear(4096, 10)
print(model_vgg16)
控制台输出
2. 网络模型的保存与加载
当我们训练好模型,就可以将模型以文件的方式保存下来,下一次使用时,直接加载模型即可恢复到上一次保存模型时的状态。
2.1 网络模型的保存与加载方式一
2.1.1 保存模型方式一
保存模型使用的函数是torch.save(),需要填写两个参数
- 参数一:训练好的模型名称
- 参数二:要保存的模型文件名,以 .pth 结尾
#保存模型 方式一
torch.save(model_vgg16,"vgg16_method1.pth")
运行之后,当前文件夹就会出现一个文件
2.1.2 加载模型方式一
加载模型的方式和保存的方式是一一对应的,用方式一保存的就要用方式一来加载
加载模型需要用到torch.load(),同样需要两个参数
- 参数一:需要加载的模型文件名
- 参数二:是否只加载模型的权重,因为默认是只加载权重的(True),所以需要显式声明为False,完整的加载整个模型
到这里我们就可以想到,这种方式其实是不太方便的,因为要保存整个模型文件,所以这个官方并不推荐
#加载模型 方式一
model1 = torch.load("vgg16_method1.pth",weights_only=False)
print(model1)
运行结果:
打印后就可以看到,跟我们之前训练的模型一模一样
2.2 网络模型的保存与加载方式二
2.2.1 保存方式二
方式二使用的函数和方式一一样,只不过参数略有调整,第一个参数不再直接传入模型,而是将模型转换为一个字典,这个字典保存的是模型的权重
#保存模型 方式二
torch.save(model_vgg16.state_dict(),"vgg16_method2.pth")
2.2.2 加载方式二
当然了,我们保存的是权重,加载的时候加载的也只是权重罢了
#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
print(model2)
所以打印后是这样的
这种方式我们就需要先初始化模型
#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
#初始化模型 因为我们要加载上次的权重,所以这里初始化的权重就可以不用了
vgg16 = torchvision.models.vgg16()
#加载我们保存的权重
vgg16.load_state_dict(model2)
print(vgg16)
这里运行后报错了,因为我们之前的模型,添加了三个层,而初始化的模型是没有这三个层的,所以这里也要加三个,跟之前的模型层保持一致
#加载模型 方式二
model2 = torch.load("vgg16_method2.pth")
#初始化模型 因为我们要加载上次的权重,所以这里初始化的权重就可以不用了
vgg16 = torchvision.models.vgg16()#添加网络层 与加载的模型层保持一致
vgg16.classifier[6] = nn.Linear(4096, 10)
vgg16.add_module("add_linear",nn.Linear(1000, 10))
vgg16.classifier.add_module("add_linear_seq",nn.Linear(1000, 10))#加载我们保存的权重
vgg16.load_state_dict(model2)
print(vgg16)
控制台输出结果:
成功,拜拜