Pytorch深度学习框架实战教程10:Pytorch模型保存详解和指南
相关文章 + 视频教程
《Pytorch深度学习框架实战教程01》《视频教程》
《Pytorch深度学习框架实战教程02:开发环境部署》《视频教程》
《Pytorch深度学习框架实战教程03:Tensor 的创建、属性、操作与转换详解》《视频教程》
《Pytorch深度学习框架实战教程04:Pytorch数据集和数据导入器》《视频教程》
《Pytorch深度学习框架实战教程05:Pytorch构建神经网络模型》《视频教程》
《Pytorch深度学习框架实战教程06:Pytorch模型训练和评估》《视频教程》
《Pytorch深度学习框架实战教程09:模型的保存和加载》《视频教程》
《Pytorch深度学习框架实战教程-番外篇01-卷积神经网络概念定义、工作原理和作用》
《Pytorch深度学习框架实战教程-番外篇02-Pytorch池化层概念定义、工作原理和作用》
《Pytorch深度学习框架实战教程-番外篇03-什么是激活函数,激活函数的作用和常用激活函数》
《PyTorch 深度学习框架实战教程-番外篇04:卷积层详解与实战指南》
《Pytorch深度学习框架实战教程-番外篇05-Pytorch全连接层概念定义、工作原理和作用》
《Pytorch深度学习框架实战教程-番外篇06:Pytorch损失函数原理、类型和案例》
《Pytorch深度学习框架实战教程-番外篇10-PyTorch中的nn.Linear详解》
引言:
pytorch模型保存的方式有哪些?为什么,我们推荐只保存模型权重的方式,进行模型的保存。
在 PyTorch 中,模型保存是将训练好的模型参数或整个模型状态持久化到磁盘的过程,便于后续加载使用(如推理、继续训练等)。下面详细介绍相关内容:
一、模型保存的方式及差异
PyTorch 主要有两种模型保存方式,核心区别在于保存的内容不同:
1. 保存模型参数(推荐)
仅保存模型的状态字典(state_dict
),这是一个包含模型所有可学习参数(权重和偏置)的 Python 字典。
保存代码:
import torch
import torch.nn as nn# 定义一个简单模型
class MyModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)model = MyModel()
# 保存状态字典
torch.save(model.state_dict(), "model_params.pth")
2. 保存整个模型(不推荐)
保存完整的模型对象(包括模型结构和参数),本质是使用 Python 的pickle
模块序列化整个对象。
保存代码:
# 保存整个模型
torch.save(model, "entire_model.pth")
两种方式的差异:
维度 | 保存参数(state_dict) | 保存整个模型 |
---|---|---|
保存内容 | 仅参数(权重、偏置) | 模型结构 + 参数 |
文件大小 | 较小(仅参数) | 较大(包含结构) |
灵活性 | 高(可加载到不同结构的模型) | 低(依赖原始模型定义) |
版本兼容性 | 好(参数格式稳定) | 差(pickle 序列化依赖环境) |
推荐场景 | 绝大多数情况(推理、续训) | 临时保存(不推荐长期使用) |
二、推荐的保存模式
优先推荐保存模型的state_dict
,原因如下:
- 灵活性高:加载时可灵活调整模型结构(如冻结部分层、修改输入输出维度等)。
- 版本兼容性好:
state_dict
是纯参数字典,不受 PyTorch 版本或模型定义代码变动的影响(只要参数名称匹配)。 - 节省空间:仅保存必要的参数,文件体积更小。
三、模型加载的具体要求
加载模型时需根据保存方式对应处理,核心要求是保证参数与模型结构匹配。
1. 加载state_dict
(对应 “保存参数” 方式)
步骤与要求:
- 必须先定义与原模型结构一致的模型类(至少需要匹配待加载参数的层名称和形状)。
- 通过
model.load_state_dict()
方法加载参数,需注意strict
参数(控制是否严格匹配所有参数)。
示例代码:
# 1. 重新定义模型结构(必须与保存参数时的结构兼容)
model = MyModel()# 2. 加载状态字典
state_dict = torch.load("model_params.pth")
# 严格匹配(默认,参数名称和数量必须完全一致)
model.load_state_dict(state_dict)
# 非严格匹配(允许部分参数不匹配,如加载预训练模型时冻结部分层)
# model.load_state_dict(state_dict, strict=False)
2. 加载整个模型(对应 “保存整个模型” 方式)
步骤与要求:
- 无需提前定义模型类(但依赖
pickle
反序列化,风险较高)。 - 必须保证加载环境中存在模型定义的依赖(如自定义层、导入路径等),否则会报错。
示例代码:
# 直接加载整个模型(不推荐)
model = torch.load("entire_model.pth")
四、其他注意事项
- 设备兼容性:加载时需注意模型参数所在设备(CPU/GPU),可通过
map_location
参数指定:# 加载到CPU(无论保存时在哪个设备) state_dict = torch.load("model_params.pth", map_location=torch.device('cpu'))
- 训练状态:加载后如需推理,需通过
model.eval()
切换到评估模式(关闭 dropout、BN 层等)。 - 文件格式:推荐使用
.pth
或.pt
作为后缀(PyTorch 默认格式)。
综上,实际应用中应优先选择保存state_dict
,加载时确保模型结构与参数匹配,以保证灵活性和兼容性。