当前位置: 首页 > news >正文

【LLM学习笔记】第四篇:模型压缩方法——量化、剪枝、蒸馏、分解

文章目录

      • 1. 为什么要进行模型压缩
      • 2. 模型量化
        • 2.1 常见数据类型
        • 2.2 浮点数表示
        • 2.3 线性量化
        • 2.4 非线性量化
        • 2.5 挑战
        • 2.6 实际应用
      • 3. 模型剪枝
      • 4. 模型蒸馏
        • 4.1 模型蒸馏的基本流程
        • 4.2 模型蒸馏的优势
        • 4.3 实际应用
      • 5. 低秩分解(低秩近似)
        • 5.1 基本概念
        • 5.2 实际应用
      • 6. 总结

1. 为什么要进行模型压缩

模型压缩是通过压缩模型参数使用更高效的表示方式,减少模型所需的存储空间的方法。可以减少模型计算过程中的乘法和加法操作,降低计算开销。在模型压缩过程中,应尽量减小对模型性能的影响,保持模型在任务上的精度损失最小化

在这里插入图片描述

下面讲分别详细介绍模型压缩的4种常用方法。

2. 模型量化

模型量化是模型压缩技术中的一种重要方法,其主要目标是通过减少模型参数的表示精度来降低模型的存储空间和计算复杂度:

具体来说模型量化是指将模型中的浮点数参数转换为低精度的数据类型,如从32位浮点数(FP32)转换为16位浮点数(FP16)、8位整数(INT8)甚至是更低的精度(如4位整数、1位二值网络)。这种转换可以显著减少模型的存储空间和计算开销。
在这里插入图片描述

模型量化有下面几个目的:

  1. 减少模型显存占用:通过降低参数的表示精度,减少模型所需的存储空间。
  2. 加快推理速度:低精度数据类型的计算通常更快,可以提高模型的推理速度。
  3. 降低内存带宽需求:减少数据传输量,降低内存带宽需求。
  4. 降低功耗:减少数据传输和计算所需的能量,降低功耗。
2.1 常见数据类型
  1. FP32:32位浮点数,是最常用的高精度表示方式。
  2. FP16:16位浮点数,数值范围比FP32小,但占用内存较少。
  3. BF16:16位截断的FP32,增加指数位,数值范围更广,常用于深度学习。
  4. INT8:8位整数,位数仅为FP32的1/4,适用于模型参数的数据范围映射。
  5. INT4:4位整数,进一步减少位数,适用于极端资源受限的场景。
  6. 二值网络(Binary Network):1位二值网络,参数只能取0或1,计算效率极高但精度损失较大。
2.2 浮点数表示

浮点数通常遵循IEEE-754标准,由符号位、指数位和尾数位组成:

  • 符号位:表示数值的正负。
  • 指数位:表示数值的大小范围。
  • 尾数位:表示数值的精度。

这里为了方便理解,我将通过一个实例说明浮点数在计算机中的存储方式:如何将十进制数0.6转换成 IEEE 754 标准下的单精度浮点数(32位)。

第一步:十进制转换为二进制

首先,将 0.6 转换成二进制小数。

  1. 乘以 2

    • 0.6 * 2 = 1.2,取整数部分 1,留下小数部分 0.2
    • 0.2 * 2 = 0.4,取整数部分 0,留下小数部分 0.4
    • 0.4 * 2 = 0.8,取整数部分 0,留下小数部分 0.8
    • 0.8 * 2 = 1.6,取整数部分 1,留下小数部分 0.6
    • 0.6 * 2 = 1.2,取整数部分 1,留下小数部分 0.2 (这里开始循环)
  2. 结果

    • 经过上述步骤,我们发现 0.6 的二进制表示是一个无限循环小数 0.1001100110011...。为了简化,我们可以取前几位进行近似表示,例如 0.10011001100110011001100

第二步:二进制科学计数法

接下来,将二进制小数转换成二进制科学计数法的形式。

  1. 标准化

    • 0.10011001100110011001100 可以写成 1.0011001100110011001100 × 2^-1
  2. 指数

    • 指数部分为 -1

第三步:浮点数在内存中的表示方式

根据 IEEE 754 单精度格式,一个浮点数由三部分组成:

  • 符号位(S,1位):表示数的正负,0 表示正数,1 表示负数。
  • 指数位(E,8位):表示科学计数法中的指数,但需要加上一个偏移量(对于单精度,偏移量为 127)。
  • 尾数位(M,23位):表示小数点后面的数字,由于二进制科学计数法中总是有前导位 1,这个 1 在存储时不保存,只保存后续的位。

对于 0.6

  1. 符号位 S0(因为 0.6 是正数)。
  2. 指数位 E-1 + 127 = 126126 的二进制形式是 01111110
  3. 尾数位 M0011001100110011001100。(因为开头都是1所以只存储尾数位就可以了)

因此,0.6 的单精度浮点数表示为:

  • 符号位 S0
  • 指数位 E01111110
  • 尾数位 M0011001100110011001100

把这三部分组合起来,得到的二进制序列是:

0 01111110 0011001100110011001100

这就是 0.6 在内存中的二进制表示形式。

2.3 线性量化

线性量化是最简单的量化方法,通过线性映射将浮点数转换为低精度的整数。具体步骤如下:

  1. 确定量化范围:找到输入浮点数据的最大值和最小值。
  2. 计算缩放因子和零点
    • 缩放因子 S = max − min max_int − min_int S = \frac{\text{max} - \text{min}}{\text{max\_int} - \text{min\_int}} S=max_intmin_intmaxmin
    • 零点 Z = min − min_int × S S Z = \frac{\text{min} - \text{min\_int} \times S}{S} Z=Sminmin_int×S
  3. 量化公式
    • 定点数据 Q = round ( F − Z S ) Q = \text{round}\left(\frac{F - Z}{S}\right) Q=round(SFZ)
    • 浮点数据 F = S × Q + Z F = S \times Q + Z F=S×Q+Z
2.4 非线性量化

非线性量化方法包括但不限于:

  • 对数量化:使用对数函数进行映射。
  • KL散度量化:通过计算KL散度找到最佳的量化阈值。
2.5 挑战
  1. 精度损失

    • 低比特量化:比特数越低,精度损失越大。
    • 任务复杂度:任务越复杂,精度损失越大。
    • 模型大小:模型越小,精度损失越大。
  2. 硬件支持

    • 不同硬件支持的低比特指令不同。
    • 不同硬件提供的低比特指令计算方式不同。
    • 不同硬件体系结构的内核优化方式不同。
  3. 软件算法加速

    • 混合比特量化需要进行量化和反量化,插入Cast算子影响内核执行性能。
    • 降低运行时内存占用与降低模型参数量的差异。
    • 模型参数量小,压缩比高,不代表执行内存占用少。
2.6 实际应用
  1. 移动设备:在移动设备上部署模型时,存储空间和计算资源受限,模型量化可以显著提高模型的运行效率。
  2. 物联网设备:存储和计算资源极为有限,模型量化有助于将模型部署到这些设备上。
  3. 在线服务系统:实时处理大量用户数据,模型量化可以提高系统的响应速度和吞吐量。
  4. 大模型压缩:大语言模型参数量巨大,模型量化可以适应资源受限的部署环境。
  5. 自动驾驶:对实时性能和计算资源要求高,模型量化有助于优化模型以适应场景需求。

通过模型量化,可以在保持模型性能的同时,显著降低模型的存储和计算成本,适用于多种应用场景。

3. 模型剪枝

模型剪枝是一种用于减少神经网络模型大小和计算复杂度的技术,同时尽量保持模型的性能。通过去除网络中的一些冗余或不重要的连接或节点,可以显著减少模型的参数数量,从而提高推理速度和降低存储需求。

在这里插入图片描述

关于模型剪枝的具体分类以及基于PyTorch的实现方法,我在下面文章中已经详细说明过:

  • 【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上,非结构化剪枝)
  • 【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(下,结构化剪枝)

4. 模型蒸馏

模型蒸馏(Model Distillation),也称为知识蒸馏(Knowledge Distillation),是一种将大型复杂模型的知识转移到小型简单模型的技术。这种方法的主要目的是在保持模型性能的同时,减少模型的大小和计算复杂度,从而使其更适合部署在资源受限的设备上。模型蒸馏的核心思想是让一个学生模型(小模型)从一个教师模型(大模型)中学习,而不是直接从原始数据中学习。

在这里插入图片描述

4.1 模型蒸馏的基本流程
  1. 教师模型训练

    • 训练一个高性能的大型模型(教师模型)。这个模型通常具有复杂的结构和大量的参数,能够在任务上取得很好的性能。
  2. 学生模型初始化

    • 初始化一个结构更简单、参数更少的小模型(学生模型)。学生模型的目标是从教师模型中学习知识,同时保持较高的性能。
  3. 蒸馏过程

    • 使用教师模型的输出作为软标签(soft labels)来训练学生模型。软标签是指教师模型对每个样本的预测概率分布,而不是原始数据的硬标签(hard labels)。
    • 学生模型不仅学习硬标签,还学习教师模型的软标签。这样可以让学生模型捕捉到教师模型的内部表示和泛化能力。
  4. 损失函数设计

    • 设计一个合适的损失函数,通常包括两部分:
      • 交叉熵损失:学生模型的预测与硬标签之间的交叉熵损失。
      • 蒸馏损失:学生模型的预测与教师模型的软标签之间的交叉熵损失。
    • 损失函数可以表示为:
      Loss = α ⋅ CrossEntropy ( y student , y true ) + ( 1 − α ) ⋅ CrossEntropy ( y student , y teacher ) \text{Loss} = \alpha \cdot \text{CrossEntropy}(y_{\text{student}}, y_{\text{true}}) + (1 - \alpha) \cdot \text{CrossEntropy}(y_{\text{student}}, y_{\text{teacher}}) Loss=αCrossEntropy(ystudent,ytrue)+(1α)CrossEntropy(ystudent,yteacher)
      其中, y student y_{\text{student}} ystudent 是学生模型的预测, y true y_{\text{true}} ytrue 是硬标签, y teacher y_{\text{teacher}} yteacher 是教师模型的软标签, α \alpha α 是一个平衡因子。
  5. 学生模型训练

    • 使用上述损失函数训练学生模型,直到收敛。
  6. 评估和微调

    • 评估学生模型的性能,如果有必要,可以进行微调以进一步提升性能。
4.2 模型蒸馏的优势
  1. 模型压缩:学生模型通常比教师模型小得多,占用的存储空间和计算资源更少。
  2. 推理加速:学生模型的推理速度更快,适合在资源受限的设备上部署。
  3. 性能保持:通过蒸馏过程,学生模型能够继承教师模型的一部分知识,保持较高的性能。
4.3 实际应用
  • 移动设备:在智能手机、IoT 设备等资源受限的环境中,部署轻量级的学生模型。
  • 边缘计算:在边缘服务器上运行高效的学生模型,减少数据传输延迟。
  • 实时应用:在需要快速响应的应用中,如自动驾驶、语音识别等,使用高效的学生模型。

5. 低秩分解(低秩近似)

5.1 基本概念

(一般是权重的)低秩分解(Low-Rank Decomposition)是一种用于降维和矩阵近似的技术,常用于减少矩阵的存储和计算成本,同时保留矩阵的主要信息。低秩分解在机器学习、数据压缩、推荐系统等领域有广泛应用。

低秩分解的基本概念是:给定一个 m × n m \times n m×n 的矩阵 A A A,低秩分解的目标是找到两个较小的矩阵 U U U V V V,使得 A ≈ U V T A \approx UV^T AUVT,其中 U U U m × k m \times k m×k的矩阵, V V V n × k n \times k n×k的矩阵, k k k 是一个小于 min ⁡ ( m , n ) \min(m, n) min(m,n)的正整数。这里的 k k k 称为秩(rank)。

在这里插入图片描述

关于秩以及低秩近似,我在此前也有介绍过:深度学习中的常用线性代数知识汇总——第一篇:基础概念、秩、奇异值

5.2 实际应用
  1. 推荐系统

    • 在用户-物品评分矩阵中,使用低秩分解可以找到用户的潜在兴趣和物品的潜在特征,从而进行个性化推荐。
  2. 图像压缩

    • 使用 SVD 对图像矩阵进行低秩近似,可以显著减少存储空间,同时保持图像的主要特征。
  3. 文本挖掘

    • 使用 NMF 对文档-词频矩阵进行分解,可以发现文档的主题和词的主题分布。

6. 总结

  1. 模型量化(Quantization):减少模型参数的表示精度,降低存储空间和计算复杂度。
  2. 参数剪枝(Pruning):删除模型中的不重要连接或参数,减少模型的大小和计算量。
  3. 知识蒸馏(Knowledge Distillation):通过构建一个轻量化的小模型,利用性能更好的教师模型的信息来监督训练学生模型。
  4. 低秩分解(Low-rank Factorization):将模型中执行计算的矩阵分解为低秩的子矩阵,减少模型参数的数量和计算复杂度。
http://www.lryc.cn/news/492268.html

相关文章:

  • python3 自动更新的缓存类
  • 英语知识网站开发:Spring Boot框架应用
  • 文件上传upload-labs-docker通关
  • git(Linux)
  • Doris实战—构建日志存储与分析平台
  • 【vue3+Typescript】unapp+stompsj模式下替代plus-websocket的封装模块
  • Tcon技术和Tconless技术介绍
  • C#-利用反射自动绑定请求标志类和具体执行命令类
  • 高中数学练习:初探均值换元法
  • 数据结构单链表,顺序表,广义表,多重链表,堆栈的学习
  • 【保姆级教程】使用lora微调LLM并在truthfulQA数据集评估(Part 2.在truthfulQA上评估LLM)
  • thinkphp中对请求封装
  • leetcode hot100【LeetCode 215.数组中的第K个最大元素】java实现
  • 簡單易懂:如何在Windows系統中修改IP地址?
  • Python中的23种设计模式:详细分类与总结
  • 日历使用及汉化——fullcalendar前端
  • 视频截断,使用 FFmpeg
  • 使用系统内NCCL环境重新编译Pytorch
  • 1. Klipper从安装到运行
  • docker 卸载与安装
  • 跨部门文件共享安全:平衡协作与风险的关键策略
  • 基于单片机的智慧小区人脸识别门禁系统
  • 【es6】原生js在页面上画矩形及删除的实现方法
  • 【git实践】分享一个适用于敏捷开发的分支管理策略
  • Redis与MySQL如何保证数据一致性
  • 基于微信小程序的教室预约系统+LW示例参考
  • Linux 安装 Git 服务器
  • 总结:Yarn资源管理
  • Python学习34天
  • 深入浅出 WebSocket:构建实时数据大屏的高级实践