当前位置: 首页 > news >正文

cast提前!最简单有效的神经网络优化方法,没有之一!

做优化有时候真的很头疼,绞尽脑汁的想怎么做算法等价,怎么把神经网络各层指令流水起来,在确保整网精度的同时,又有高性能。

但有时做了半天,却发现流水根本就流不起来,总是莫名其妙地被卡住。

真的是一顿操作猛如虎,回头一看原地杵。

今天介绍一种神经网络的性能优化方法。它不需要懂特深奥的算法知识,就能做到整个优化系统,大到网络,小到算子的性能的成倍提升。

而且绝对是成倍的性能提升,并且显而易见的算法等价。

怎么做呢?很简单,只需要改一下算子的先后调用顺序就行。

先说下背景。

在做AI推理或者训练时,大部分情况下一个神经网络中的所有层(Layer)的计算数据类型是相同的。

比如为了网络有更好的识别精度,神经网络中的运算可以使用高精度的浮点数,如 float32,简称 FP32。

但有时为了性能,稍微损失一点识别精度也能接受,此时可能会使用 float16,简称FP16, 也就是半精度数据类型来做运算。

FP32 和 FP16 的区别在于,前者数据位宽是后者的两倍,因此表示相同的数据的时候,前者的精度更高,但内存占用也更大。

比如同时存储一张图片,如果使用 FP32的话,可能会占用1MB的内存,但如果使用FP16来存储,只占0.5MB的内存。

我们可能听说过混合精度推理、混合精度训练。这里说的混合,指的就是精度混合。比如一个神经网络中存在多种数据类型。

为什么可以做混合精度的推理或训练呢?

一个神经网络就像是一个大厦,由一层一层的算法搭建而成,每一层的算法可能不同。不同的算法对数据精度的敏感程度不同。

有很多算法对数据精度不敏感,比如 transpose, gather, scatter等,这类算法都是数据搬运操作,也就是纯IO操作。他们不需要进行数据计算,无需考虑数据在做加法时候的溢出处理等情况。

而有些算法对数据精度很敏感,典型的比如conv2d算法,它需要做大量的乘累加操作,数据的累加很容易出现溢出,此时需要用更高位宽的数据来接收累加结果。

如果把操作 FP32 比作需要搬运32块砖的话,那么 FP16 就是只需要搬运 16块砖。很明显,搬运16块砖比搬运32块砖,在其他条件不变的情况下,要省时省力。

因此,在神经网络尤其是混合训练或推理的网络中,如果遇到了一些数据搬运算法搬运的是 FP32,那么是很有机会只让他搬16块砖(FP16)的。

那么具体怎么做呢?

首先简化一个神经网络,假设一个神经网络有如下结构:

在这个假想的网络中,卷积层(conv2d)计算的输出是 FP32,然后送给transpose 层进行数据搬运,transpose由于是纯IO算法,因此它的输出也是FP32。

transpose的输出送给下一层cast,cast负责将FP32的数据转换为FP16, 因此cast 的输出是FP16。然后FP16的数据送给接下来的层进行运算。

不知有没有发现,在这个网络中,transpose 算法先搬运了FP32的数据,然后交给了 cast 进行数据类型转换,转换成了更低位宽的 FP16。

但是由于 transpose 是纯IO运算,对数据类型不敏感,因此,我们完全可以将cast算子提前到 transpose 之前,如此的话,transpose 只需要做 FP16 的数据搬运。

转换之后的网络如下:

这样做的结果就是:整个网络的计算是等价的,但是 transpose 算子却由原来进行 FP32 的数据搬运,变成了 FP16 的数据搬运。对 transpose而言,其IO性能表现是成倍的提升。

这只是举一个很简单的例子。

而实际上,在真实的网络中,使用此方法可以优化成功的算法有时不仅仅是一个简单的 transpose,而是一个很大的网络片段。

由此可见,仅仅将 cast 提前这一个简单的操作,就能使整网的性能提升一倍。

这个方法很简单,很有效,也很容易实施。但是在实际进行网络优化的时候,有时却会被忽略。

能够使用这一优化的网络必须满足以下两个条件:

  • 必须是混合精度的网络

  • 由高位宽转低位宽的cast 算子前存在 IO 型算子

在我们绞尽脑汁使用一些高级的技巧,如模型并行、层层流水来做网络优化的同时,不妨放大视角,着眼全图,看看整网是否满足上面的条件,没准只一眼,就能发现这一最简单有效的优化点,从此百分比的提升网络性能,不是梦!

http://www.lryc.cn/news/24144.html

相关文章:

  • LeetCode刷题——动态规划(C/C++)
  • 车载智能终端TBOX
  • 技术分担产品之忧(上):挑选有业务专家潜力的人
  • UVa 12569 Planning mobile robot on Tree (EASY Version) 树上机器人规划(简单版) BFS 二进制
  • intel的集成显卡(intel(r) uhd graphics) 配置stable diffusion
  • 【数据库的基础知识(2)】
  • Docker部署实战
  • RestTemplate 相关使用
  • 新手小白亚马逊注册最全教程在此
  • 二分查找重复情况 找最左边或最右边的位置下标
  • 智慧扫码点餐系统源码
  • 分布式环境并发场景下,如何操作抢红包(或者减少库存)
  • 明星的孩子也在做的感统训练,真的有用吗?
  • 守护进程与TCP通讯
  • 在线文本翻译能力新增14个直译模型,打造以中文为轴心语言的翻译系统
  • CVE-2022-42889 Apache Commons Text 漏洞
  • 20- widedeep及函数式构建模型 (TensorFlow系列) (深度学习)
  • 大家一起做测试的,凭什么你现在拿20k,我却还只有10k?...
  • >>数据管理:DAMA简介「考试和续期」
  • React的生命周期详细讲解
  • 蓝蓝算法二期工程day3,一万年太久,只争朝夕
  • 程序代码的自动化生成方案设计
  • Go 稀疏数组学习与实现
  • MySQL 学习笔记(借鉴黑马程序员MySQL)
  • 中级工程师职称申报到底需要参加答辩不?
  • MM32开发教程(LED灯)
  • win10安装docker
  • 设计模式系列 - 代理模式及动态代理详解
  • 【分享】订阅集简云畅捷通T+cloud连接器自动同步财务费用单至畅捷通
  • GPT的发展历程