Dynamic Sparse No Training: Training-Free Fine-tuning for Sparse LLMs
大语言模型(LLM)在设备上部署道路上落下了一个令人生畏的障碍。本文关注于大语言模型的剪枝算法。
动态稀疏训练(Dynamic Sparse Training,DST)是一种近期收到广泛关注的剪枝算法。与之前大部分剪枝方法需要训练整个网络不同,DST选择性更新一部分网络参数并允许稀疏网络拓扑动态进化。然而先前研究显示其在小规模BERT级别语言模型上微调的失败。
本文方法
本文算法将剪枝视作为设计一个二进制掩码指示权重是否移除。在给定剪枝率 p p p 条件下的LLM剪枝问题可以描述为:
min M , W ∣ ∣ W ∗ A − ( M ⋅ W ) ∗ A ∣ ∣ 2 s . t . 1 − ∣ ∣ M ∣ ∣ 0 C o u t ⋅ C i n = p \min_{M,W}||W*A-(M\cdot W)*A||_{2}\quad s.t. \quad 1-\frac{||M||_{0}}{C_{out}\cdot C_{in}}=p M,Wmin∣∣W∗A−(M⋅W)∗A∣∣2s.t.1−Cout⋅Cin∣∣M∣∣0=p
该问题可以从两个互补角度求解,1. 设计标准来剪枝对模型影响最小的权重,2. 对于获得的稀疏网络,剩下的权重自然地进行微调以进一步减少重建误差。这些常规的求解方式需要大量训练资源,对于大容量的LLM模型并不适用。
本文关注如何高效减少给定剪枝稀疏网络与对应密集网络间重建损失。本文不使用全微调或部分更新的方法恢复性能,而是根据对重建损失贡献在见之后细化稀疏掩码。本文方法源于Rigging the lottery:
Making all tickets winners中动态稀疏训练使用的剪枝-生长操作。DST在稀疏网络训练中包含权重剪枝和权重生长过程。基于此方法,本文DSþT,一种稀疏LLM无训练微调方法,该方法剥离权重更新并通过将优化目标转化为每个全中行的重建误差保持剪枝和增长。剪枝-生长过程与网络独立进行,并使用迭代方式逐渐优化稀疏掩码
DSþT从一个稀疏的LLM网络开始,可以使用任何已有的评估标准剪枝。然后通过查看重建损失执行迭代剪枝与生长。
生长标准
给定稀疏权重行 M r ⊙ W r M_{r}\odot W_{r} Mr⊙Wr,尝试恢复剪枝权重使其在不同输入激活上获得最多的重建误差 Δ r \Delta_{r} Δr。这里的重建标准同时考虑重建误差变化的期望和方差。索引为i的恢复权重表示为:
i = { arg max k ¬ M r , k ⋅ W r , k ⋅ E [ A r ] / Var ( A r ) , if E [ Δ r ] > 0 , arg min k ¬ M r , k ⋅ W r , k ⋅ E [ A r ] / Var ( A r ) , otherwise, i=\left\{\begin{array}{l} \underset{k}{\arg \max } \neg \mathbf{M}_{r, k} \cdot \mathbf{W}_{r, k} \cdot \mathbb{E}\left[\mathbf{A}_{r}\right] / \operatorname{Var}\left(\mathbf{A}_{r}\right), \text { if } \mathbb{E}\left[\Delta_{r}\right]>0, \\ \underset{k}{\arg \min } \neg \mathbf{M}_{r, k} \cdot \mathbf{W}_{r, k} \cdot \mathbb{E}\left[\mathbf{A}_{r}\right] / \operatorname{Var}\left(\mathbf{A}_{r}\right), \text { otherwise, } \end{array}\right. i=⎩ ⎨ ⎧kargmax¬Mr,k⋅Wr,k⋅E[Ar]/Var(Ar), if E[Δr]>0,kargmin¬Mr,k⋅Wr,k⋅E[Ar]/Var(Ar), otherwise,
这里考虑引入输入激活的方差主要因为如果权重对 Δ r \Delta_{r} Δr 的影响在不同输入之间表现出很高的方差,那么恢复它可能不会导致稳定的误差减小。
剪枝标准
在选择恢复权重后,需要选择其余的权重进行剪枝以维护固定的稀疏率。对于剪枝标准,本文使用Wanda标准的变化版本。除了剪枝权重的标准指标,本文的标准强制要求所选权重在剪枝时应该为减少重建损失做出积极贡献。这有助于在不影响无训练微调重建损失稳定减少情况下保留关键权重。
i = { arg max k , M r , k ] ⋅ W r , k ⋅ E [ A r ] < 0 M r , k ⋅ ∣ W r , k ⋅ ∣ ∣ A r ∣ ∣ 2 , if E [ Δ r ] > 0 , arg max k , M r , k ] ⋅ W r , k ⋅ E [ A r ] > 0 M r , k ⋅ ∣ W r , k ⋅ ∣ ∣ A r ∣ ∣ 2 , if o t h e r w i s e , i=\left\{\begin{array}{l} \underset{k,M_{r,k}]\cdot W_{r,k}\cdot E[A_{r}]<0}{\arg \max } M_{r,k}\cdot |W_{r,k}\cdot ||A_{r}||_{2}, \text { if } \mathbb{E}\left[\Delta_{r}\right]>0, \\ \underset{k,M_{r,k}]\cdot W_{r,k}\cdot E[A_{r}]>0}{\arg \max } M_{r,k}\cdot |W_{r,k}\cdot ||A_{r}||_{2}, \text { if } otherwise, \\ \end{array}\right. i=⎩ ⎨ ⎧k,Mr,k]⋅Wr,k⋅E[Ar]<0argmaxMr,k⋅∣Wr,k⋅∣∣Ar∣∣2, if E[Δr]>0,k,Mr,k]⋅Wr,k⋅E[Ar]>0argmaxMr,k⋅∣Wr,k⋅∣∣Ar∣∣2, if otherwise,