当前位置: 首页 > news >正文

模型优化之剪枝

文章目录

  • 什么是神经网络剪枝
  • 剪枝的好处
  • 不同粒度的剪枝
  • 剪枝的分类
    • 非结构化剪枝
    • 结构化剪枝
  • 哪些层的参数更容易被剪掉
  • 剪枝效果

什么是神经网络剪枝

神经网络剪枝

  • 在训练期间删除连接
  • 密集张量将变得稀疏(用零填充)
  • 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接

在这里插入图片描述

剪枝的好处

  • 减少过拟合
  • 稀疏性优势
  • 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
  • 模型更小,可以减少内存带宽消耗量。
  • 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

不同粒度的剪枝

在这里插入图片描述
什么时候做剪枝?

one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练
剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。
再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。

iterative pruning: 迭代式训练,特点如下:
初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。
剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。
再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。
迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。

automated gradual pruning: 自动化渐进剪枝,特点如下:
剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。
渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。
无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。
自动化:整个过程高度自动化,可以减少人为干预的需求

在这里插入图片描述

剪枝的分类

结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。

非结构化剪枝

不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。
基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:

# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,final_sparsity=0.80,begin_step=0,end_step=end_step)
}model_for_pruning = prune_low_magnitude(model, **pruning_params)# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model_for_pruning.summary()logdir = "./logs/mnist_pruning"callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]model_for_pruning.fit(train_images, train_labels,batch_size=batch_size, epochs=epochs, validation_split=validation_split,callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

结构化剪枝

结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。
注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。

model = keras.Sequential([prune_low_magnitude(keras.layers.Conv2D(32, 5, padding='same', activation='relu',input_shape=(28, 28, 1),name="pruning_sparsity_0_5"),**pruning_params_sparsity_0_5),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),prune_low_magnitude(keras.layers.Conv2D(64, 5, padding='same',name="structural_pruning"),**pruning_params_2_by_4),keras.layers.BatchNormalization(),keras.layers.ReLU(),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),keras.layers.Flatten(),prune_low_magnitude(keras.layers.Dense(1024, activation='relu',name="structural_pruning_dense"),**pruning_params_2_by_4),keras.layers.Dropout(0.4),keras.layers.Dense(10)
])

哪些层的参数更容易被剪掉

因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

剪枝效果

  • 一般50%-70%左右的稀疏性,准确率降低幅度并不大
  • 剪枝是独立于量化技巧,通常与量化配合效果不错
  • 可以通过微调尝试不同的参数组合
http://www.lryc.cn/news/427649.html

相关文章:

  • JVM的组成
  • 快速上手 iOS Protocol Buffer
  • 每天一个数据分析题(四百八十)- 线性回归建模
  • 电动汽车和混动汽车DC-DC转换器的创新设计与测试方法
  • OriginPro快速上手指南:数据可视化与分析的利器
  • 缓存学习
  • 亚世光电:消费电子年度表演
  • AI 工程应用 建筑表面检测及修复
  • Qt-Qt中的小事项(7)
  • Android MediaRecorder 视频录制及报错解决
  • HarmonyOS应用程序访问控制探究
  • 董卫民赴考拉悠然等企业调研,强调加快发展人工智能产业
  • MFC将类A中的事件在类B中处理采用回调函数实现
  • 公众号 微信登录
  • sanic + webSocket:股票实时行情推送服务实现
  • Unity动态给按钮各个状态下的图片赋值
  • xiaomi pad 6PRO 小米平板6 pro hyperOS降级 澎湃os 降级MIUI 14 教程 免解锁BL 降级,168小时解锁绑定
  • MySQL 备份一个表
  • 鸿蒙开发入门day10-组件导航
  • 虚拟机Linux的坑 | VMware无法从主机向虚拟机 跨系统复制粘贴拖动 文件/文本
  • Chat App 项目之解析(二)
  • 数据结构与算法 - 双指针
  • Python3网络爬虫开发实战(10)模拟登录(需补充账号池的构建)
  • SQL 调优最佳实践笔记
  • Eclipse的使用配置教程:必要设置、创建工程及可能遇到的问题(很详细,很全面,能解决90%的问题)
  • 遗传算法与深度学习实战(4)——遗传算法详解与实现
  • Nginx+Tomcat实现负载均衡、动静分离集群部署
  • 英语学习8月19日
  • 关于windows环境使用nginx的一些性能问题
  • “解决Windows电脑无法投影到其他屏幕的问题:尝试更新驱动程序或更换视频卡“