当前位置: 首页 > news >正文

关于Pytorch转换为MindSpore的一点建议

一、事先准备

必须要对Mindspore有一些了解,因为这个框架确实有些和其它流程不一样的地方,比如算子计算、训练过程中的自动微分,所以这两个课程要好好过一遍,官网介绍文档最好也要过一遍
1、零基础Mindspore:https://www.bilibili.com/video/BV1CS4y1z72r/?spm_id_from=333.337.search-card.all.click在这里插入图片描述 2、MindSpore进阶课程:https://www.bilibili.com/video/BV12W4y1t7yn/?spm_id_from=333.337.search-card.all.click
在这里插入图片描述

3、Mindspore教程:MindSpore教程 — MindSpore master documentation
在这里插入图片描述

对这些课程和文档过一遍后,可以去看几个数据加载和模型训练的案例
最好是自定义数据集加载,因为大多数据集都是表格或者其它,图像分类案例较少
跑一下几个案例,理解他们的这个过程

二、框架转换过程注意事项

框架转换主要有以下基本,拿转换医学影像分割的来讲述(pytorch-》Mindspore)
官网也是有给网络迁移部分的要点说明的,也可以好好看看
在这里插入图片描述

转换之前一定要理解自己原有网络当中的每一部分的处理、每一部分的数据形态和类型,这样转换起来比较容易

1、数据集导入

判断好数据集是什么类型,能否用快捷方式加载,如果不能就自定义数据集,然后用GeneratorDataset进行加载
数据加载类,注意最后返回的要是两部分值,前者为数据,后者为标签
在这里插入图片描述

一定要这样,因为GeneratorDataset需要这种形式,期间的计算,每一步可以看看有无问题,形态和原有网络保持一致

2、网络结构搭建

2.1 如果已经有算法,也有网络,那就一层的对比着看,保证每层输入输出一样

在这里插入图片描述

2.2 对应的网络中的API计算,大多都能对应上,主要有部分会有细节差异,需要去官网查询对应API,填写适应参数

如这里和pytorch的就不一样,mindspore中的scale_factor不能和bilinear一起,所有要替换为其它插值方式,另外插值法方式也会影响padding的值
在这里插入图片描述

就是要保证每层的输入输出都一致,计算要正常,如这里mindspore不写stride=2就会导致后面的计算出问题
在这里插入图片描述

2.3 一点一点的对比和尝试,必须要保证网络重每一步的计算前和计算后的数据形态一样

最终的输出也是要保持一致,数据经过网络得到预测值,预测值的shape注意保持一致

3、模型训练

一定要保证数据的准确,在pytorch内是什么形式在mindspore内也要是
对于梯度和loos的计算,多打印出来看一看,虽然pytorch和mindspore训练过程有所不同,但整体还是相似的
在这里插入图片描述

注意label的shape要和模型输出的logit一样,这样才能计算loss,这里可能会有维度不相同,那就去掉无关维度即可,mindspore里也有squeeze,多看看文档
流程就是,训练step内使用gard_fn,进行自动微分计算(这里mindspore用了这就不用梯度清零了),自动微分计算value_and_grad中又会调用前向传播函数,前向传播中涉及到loos的计算,一般只要loss输出没有问题,那么其它都是小事情
注意各项的形式,很容易理解的还是

4、训练和评估

这个过程就很简单了,只要前面定义好训练step和其他的什么优化器、损失函数还有前向传播网络什么的,那么这就很简单了,获取可迭代数据进行一个batch一个batch的训练就行了,loss可以计算可以输出,模型的评估上mindspore里面也有提供一些自定义的评估,看需要用到什么,先去搜搜看,看看如何使用的,直接套用即可
在这里插入图片描述

5、模型保存和调用推理

这部分就很简单了,按照格式定义即可

在这里插入图片描述

三、总结

整体来说,只要数据集构建没有问题,网络结构没有问题(需要计算测试)
那么框架转换就很简单了,因为训练的流程都大致相同,虽然mindspore里面没有梯度清零什么的
但是也有独特的自动微分梯度求导,这个多看几个案例,其实也是一套流程

http://www.lryc.cn/news/380514.html

相关文章:

  • JetBrains IDEA 新旧UI切换
  • iOS KeychainAccess的了解与使用
  • STM32 Customer BootLoader 刷新项目 (二) 方案介绍
  • 2-14 基于matlab的GA优化算法优化车间调度问题
  • Program-of-Thoughts(PoT):结合Python工具和CoT提升大语言模型数学推理能力
  • ansible setup模块
  • 【2024最新华为OD-C/D卷试题汇总】[支持在线评测] LYA的测试用例执行计划(100分) - 三语言AC题解(Python/Java/Cpp)
  • NSIS 入门教程 (一)
  • cve-2015-3306-proftpd-vulfocus
  • 超详细!想进华为od的请疯狂看我!
  • MQTT协议与TCP/IP协议在性能上的区别
  • LeetCode 每日一题 2024/6/17-2024/6/23
  • FlinkCDC pipeline模式 mysql-to-paimon.yaml
  • mysql数据库入门手册
  • 增强大型语言模型(LLM)可访问性:深入探究在单块AMD GPU上通过QLoRA微调Llama 2的过程
  • 空间复杂度 线性表,顺序表尾插。
  • linux创建用户、切换用户、删除用户
  • BC64 牛牛的快递(c++)
  • 离线linux通过USB连接并使用手机网络
  • I2C总线8位IO扩展器PCF8574
  • webClient + fastJSON2 获取json格式的数据,同时解析至java class 并 下划线转驼峰
  • 4、SpringMVC 实战小项目【加法计算器、用户登录、留言板、图书管理系统】
  • OpenCV--形态学
  • 【LinuxC语言】IP地址相关的函数
  • QT事件处理系统之五:自定义事件的发送案例 sendEvent和postEvent接口
  • 模版与策略模式
  • SQL-Python
  • mysql索引以及优化
  • 【pytorch06】 维度变换
  • 移动Web开发实战内容要点!!!