当前位置: 首页 > news >正文

Pytorch深度学习框架实战教程10:Pytorch模型保存详解和指南

 相关文章 + 视频教程

《Pytorch深度学习框架实战教程01》《视频教程

Pytorch深度学习框架实战教程02:开发环境部署》《视频教程

Pytorch深度学习框架实战教程03:Tensor 的创建、属性、操作与转换详解》《视频教程

《Pytorch深度学习框架实战教程04:Pytorch数据集和数据导入器》《视频教程

《Pytorch深度学习框架实战教程05:Pytorch构建神经网络模型》《视频教程》

《Pytorch深度学习框架实战教程06:Pytorch模型训练和评估》《视频教程》

Pytorch深度学习框架实战教程09:模型的保存和加载》《视频教程》

Pytorch深度学习框架实战教程-番外篇01-卷积神经网络概念定义、工作原理和作用

Pytorch深度学习框架实战教程-番外篇02-Pytorch池化层概念定义、工作原理和作用

Pytorch深度学习框架实战教程-番外篇03-什么是激活函数,激活函数的作用和常用激活函数

《PyTorch 深度学习框架实战教程-番外篇04:卷积层详解与实战指南》

《Pytorch深度学习框架实战教程-番外篇05-Pytorch全连接层概念定义、工作原理和作用》

《Pytorch深度学习框架实战教程-番外篇06:Pytorch损失函数原理、类型和案例》

Pytorch深度学习框架实战教程-番外篇10-PyTorch中的nn.Linear详解

引言:

pytorch模型保存的方式有哪些?为什么,我们推荐只保存模型权重的方式,进行模型的保存。

在 PyTorch 中,模型保存是将训练好的模型参数或整个模型状态持久化到磁盘的过程,便于后续加载使用(如推理、继续训练等)。下面详细介绍相关内容:

一、模型保存的方式及差异

PyTorch 主要有两种模型保存方式,核心区别在于保存的内容不同:

1. 保存模型参数(推荐)

仅保存模型的状态字典(state_dict),这是一个包含模型所有可学习参数(权重和偏置)的 Python 字典。
保存代码

import torch
import torch.nn as nn# 定义一个简单模型
class MyModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)model = MyModel()
# 保存状态字典
torch.save(model.state_dict(), "model_params.pth")

2. 保存整个模型(不推荐)

保存完整的模型对象(包括模型结构和参数),本质是使用 Python 的pickle模块序列化整个对象。
保存代码

# 保存整个模型
torch.save(model, "entire_model.pth")
两种方式的差异:
维度保存参数(state_dict)保存整个模型
保存内容仅参数(权重、偏置)模型结构 + 参数
文件大小较小(仅参数)较大(包含结构)
灵活性高(可加载到不同结构的模型)低(依赖原始模型定义)
版本兼容性好(参数格式稳定)差(pickle 序列化依赖环境)
推荐场景绝大多数情况(推理、续训)临时保存(不推荐长期使用)

二、推荐的保存模式

优先推荐保存模型的state_dict,原因如下:

  1. 灵活性高:加载时可灵活调整模型结构(如冻结部分层、修改输入输出维度等)。
  2. 版本兼容性好state_dict是纯参数字典,不受 PyTorch 版本或模型定义代码变动的影响(只要参数名称匹配)。
  3. 节省空间:仅保存必要的参数,文件体积更小。

三、模型加载的具体要求

加载模型时需根据保存方式对应处理,核心要求是保证参数与模型结构匹配

1. 加载state_dict(对应 “保存参数” 方式)

步骤与要求

  • 必须先定义与原模型结构一致的模型类(至少需要匹配待加载参数的层名称和形状)。
  • 通过model.load_state_dict()方法加载参数,需注意strict参数(控制是否严格匹配所有参数)。

示例代码

# 1. 重新定义模型结构(必须与保存参数时的结构兼容)
model = MyModel()# 2. 加载状态字典
state_dict = torch.load("model_params.pth")
# 严格匹配(默认,参数名称和数量必须完全一致)
model.load_state_dict(state_dict)
# 非严格匹配(允许部分参数不匹配,如加载预训练模型时冻结部分层)
# model.load_state_dict(state_dict, strict=False)
2. 加载整个模型(对应 “保存整个模型” 方式)

步骤与要求

  • 无需提前定义模型类(但依赖pickle反序列化,风险较高)。
  • 必须保证加载环境中存在模型定义的依赖(如自定义层、导入路径等),否则会报错。

示例代码

# 直接加载整个模型(不推荐)
model = torch.load("entire_model.pth")

四、其他注意事项

  1. 设备兼容性:加载时需注意模型参数所在设备(CPU/GPU),可通过map_location参数指定:
    # 加载到CPU(无论保存时在哪个设备)
    state_dict = torch.load("model_params.pth", map_location=torch.device('cpu'))
    
  2. 训练状态:加载后如需推理,需通过model.eval()切换到评估模式(关闭 dropout、BN 层等)。
  3. 文件格式:推荐使用.pth.pt作为后缀(PyTorch 默认格式)。

综上,实际应用中应优先选择保存state_dict,加载时确保模型结构与参数匹配,以保证灵活性和兼容性。

http://www.lryc.cn/news/616068.html

相关文章:

  • Spring Boot集成WebSocket
  • Spring Boot与WebSocket构建物联网实时通信系统
  • Android Intent 解析
  • Leetcode 3644. Maximum K to Sort a Permutation
  • 数学建模——回归分析
  • 香橙派 RK3588 部署 DeepSeek
  • 【2025CVPR-图象分类方向】ProAPO:视觉分类的渐进式自动提示优化
  • 【Linux】通俗易懂讲解-正则表达式
  • WAIC2025逛展分享·AI鉴伪技术洞察“看不见”的伪造痕迹
  • Jetpack系列教程(二):Hilt——让依赖注入像吃蛋糕一样简单
  • JavaWeb(苍穹外卖)--学习笔记17(Apache Echarts)
  • 【鸿蒙/OpenHarmony/NDK】什么是NDK? 为啥要用NDK?
  • 【图像算法 - 11】基于深度学习 YOLO 与 ByteTrack 的目标检测与多目标跟踪系统(系统设计 + 算法实现 + 代码详解 + 扩展调优)
  • 机器学习——DBSCAN 聚类算法 + 标准化
  • Python 实例属性和类属性
  • 安卓录音方法
  • Java 后端性能优化实战:从 SQL 到 JVM 调优
  • 深入解析React Diff 算法
  • Word XML 批注范围克隆处理器
  • React:useEffect 与副作用
  • MyBatis的xml中字符串类型判空与非字符串类型判空处理方式
  • 秋招春招实习百度笔试百度管培生笔试题库百度非技术岗笔试|笔试解析和攻略|题库分享
  • wordpress语言包制作工具
  • python正则表达式里面有特殊符号如何处理
  • 亚麻云之静态资源管家——S3存储服务实战
  • Day41--动态规划--121. 买卖股票的最佳时机,122. 买卖股票的最佳时机 II,123. 买卖股票的最佳时机 III
  • LeetCode 组合总数
  • AI质检数据准备利器:基于Qt/QML 5.14的图像批量裁剪工具开发实战
  • Python 2025:最新技术趋势与展望
  • Text2SQL 自助式数据报表开发(Chat BI)