当前位置: 首页 > news >正文

深度学习-读写模型网络文件

模型网络文件是深度学习模型的存储形式,保存了模型的架构、参数等信息。

读写模型网络文件是深度学习流程中的关键环节,方便模型的训练、测试、部署与共享。

1. 主流框架读写方法

(一)TensorFlow

  • 保存模型

    • 可以使用 tf.saved_model.save 方法保存整个模型,包括架构、参数、编译信息等。例如: model.save('model_dir', save_format='tf'),将模型保存在文件夹 'model_dir' 中。

  • 加载模型

    • 使用 tf.keras.models.load_model 加载保存的模型。如:loaded_model = tf.keras.models.load_model('model_dir'),即可加载之前保存的模型进行预测、继续训练等操作。

(二)PyTorch

使用 torch.save 和 torch.load 来保存和加载 张量

  • 保存模型

    • 通常有两种方式:一种是保存整个模型对象,使用 torch.save(model, 'model.pth'),将模型结构和参数都保存下来。另一种是仅保存模型的参数状态字典,即 torch.save(model.state_dict(), 'model_state_dict.pth'),这种方式更常见,因为当模型架构修改时,只要能正确加载参数,就无需重新训练整个模型。

  • 加载模型

    • 对于保存整个模型的情况,直接使用 model = torch.load('model.pth')。对于仅保存参数的情况,先定义好模型架构,再用 model.load_state_dict(torch.load('model_state_dict.pth')) 加载参数,使模型具备相应的能力。

对于深度学习模型而言,通常只需保存其权重参数即可满足需求。在 PyTorch 框架中,可以使用 torch.save() 函数来保存网络的 state_dict 参数,这是保存模型权重的一种高效方式。

而在加载模型权重时,可以借助网络的 load_state_dict() 方法,搭配 torch.load() 函数来实现对网络参数的读取,从而恢复模型的训练状态和性能表现。

2. 模型保存示例

torch.save(model.state_dict(), path)只保存“参数”(一个纯字典),文件小、加载灵活。

torch.save(model.state_dict(), "best_model.pt")

1. 加载时必须先重新建网络,再把参数填进去:

new_model = MyNet()                      # 重新建图
new_model.load_state_dict(torch.load("best_model.pt"))
new_model.eval()                         # 记得切到推理模式

2. 优点

  • 文件 ≈ 仅参数大小,磁盘占用小
  • 不关心原始类定义,跨代码版本更稳

3. 缺点

        需要手动重建网络结构才能用

torch.save(model, path):把整个模型(结构+参数)序列化为一个 Pickle 对象,一步到位。

torch.save(model, "full_model.pt")

1. 加载极其简单:

model = torch.load("full_model.pt")      # 结构+参数全回来
model.eval()

2. 优点

        一行代码即可复现模型,适合快速分享、断点继续训练

3. 缺点

  • Pickle 会硬编码类定义路径,代码位置/类名一变就加载失败

  • 文件更大(含结构+参数)

选用建议

  • 生产/长期维护 → 用 state_dict(稳妥、小、可迁移)。

  • 临时 checkpoint / 本地快速实验 → 用 完整模型(省事)。

http://www.lryc.cn/news/609212.html

相关文章:

  • 03.一键编译安装Redis脚本
  • 07.config 命令实现动态修改配置和慢查询
  • ThinkPHP8.x控制器和模型的使用方法
  • VUE-第二季-01
  • 【实习总结】Qt通过Qt Linguist(语言家)实现多语言支持
  • Python-初学openCV——图像预处理(六)
  • 机器学习之决策树(二)
  • solidworks打开step报【警告!可用的窗口资源极低】的解决方法
  • 《C 语言内存函数深度剖析:从原理到实战(memcpy/memmove/memset/memcmp 全解析)》
  • 使用ACK Serverless容器化部署大语言模型FastChat
  • 【十九、Javaweb-day19-Linux概述】
  • 我的世界模组进阶教程——伤害(1)
  • 每日面试题20:spring和spring boot的区别
  • Linux 文件与目录操作命令宝典
  • Unity_数据持久化_IXmlSerializable接口
  • 【视频内容创作】PR的关键帧动画
  • SQL157 更新记录(一)
  • linux下jvm之jstack的使用
  • 代码随想录day53图论4
  • Java 大视界 -- Java 大数据在智能教育学习资源个性化推荐与学习路径动态调整中的深度应用(378)
  • 【LLM】 BaseModel的作用
  • 【0基础PS】PS工具详解--文字工具
  • Shell脚本-变量是什么
  • 思途JSP学习 0802(项目完整流程)
  • Linux网络编程 --- 多路转接select
  • Unity JobSystem 与 BurstCompiler 资料
  • 2025.8.3
  • webrtv弱网-QualityScalerResource 源码分析及算法原理
  • 【大模型实战】向量数据库实战 - Chroma Milvus
  • Linux mount挂载选项详解(重点关注nosuid)