当前位置: 首页 > news >正文

【昇思初学入门】第八天打卡-模型保存与加载

模型保存与加载

学习心得

  • 保存 CheckPoint 格式文件,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。
    model = network()
    mindspore.save_checkpoint(model, "model.ckpt")
    
    可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。
    • save_checkpoint_steps表示每隔多少个step保存一次。
    • keep_checkpoint_max表示最多保留CheckPoint文件的数量。
    • prefix表示生成CheckPoint文件的前缀名。
    • directory表示存放文件的目录。
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
    model.train(epoch_num, dataset, callbacks=ckpoint_cb)
    
    要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。
    	model = network()param_dict = mindspore.load_checkpoint("model.ckpt")param_not_load, _ = mindspore.load_param_into_net(model, param_dict)print(param_not_load)
    
    param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。
    [] 
    
  1. 保存和加载MindIR,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。
    • 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
    • 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
    • 保存模型
    model = network()
    inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
    mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
    
    • 加载模型
    mindspore.set_context(mode=mindspore.GRAPH_MODE)
    graph = mindspore.load("model.mindir")
    model = nn.GraphCell(graph)
    outputs = model(inputs)
    print(outputs.shape)
    
http://www.lryc.cn/news/384094.html

相关文章:

  • 喜报!极限科技新获得一项国家发明专利授权:“搜索数据库的正排索引处理方法、装置、介质和设备”
  • 深入探讨:UART与USART在单片机中串口的实际应用与实现技巧
  • Windows上PyTorch3D安装踩坑记录
  • 操作符详解(上) (C语言)
  • 使用 audit2allow 工具添加SELinux权限的方法
  • 一文弄懂FPGA
  • Rust 中使用 :: 这种语法的几种情况
  • Ruby langchainrb gem and custom configuration for the model setup
  • 高校新生如何选择最优手机流量卡?
  • QT QML 生成二维码
  • IDEA中Maven--下载安装自己适配的版本---理解
  • 【osgEarth】Ubuntu 22.04 源码编译osgEarth 3.5
  • ASP.NET Core 6.0 使用 资源过滤器和行为过滤器
  • 电脑屏幕花屏怎么办?5个方法解决问题!
  • git 初基本使用-----------笔记
  • Redis-数据类型-Bit的基本操作-getbit-setbit-Bitmap
  • 统信UOS上鼠标右键菜单中添加自定义内容
  • 学习入门 chatgpt原理 一
  • 生命在于学习——Python人工智能原理(4.7)
  • 经典游戏案例:仿植物大战僵尸
  • [Day 18] 區塊鏈與人工智能的聯動應用:理論、技術與實踐
  • 【Mac】DMG Canvas for mac(DMG镜像制作工具)软件介绍
  • RAG分块方法 从固定大小到自然语言处理分块——深入研究文本分块技术
  • FFmpeg 系列
  • 240626_昇思学习打卡-Day8-稀疏矩阵
  • Docker: 使用容器化数据库
  • Oracle对用户敏感数据进行编码处理
  • VXLAN详解:概念、架构、原理、搭建过程、常用命令与实战案例
  • Redis-数据类型-Hash
  • 基于redisson实现tomcat集群session共享