【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作
【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作
文章目录
- 【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作
- 1 引言
- 2 Top-k的“微分困境”
- 3 核心思想:当Top-k遇上最优传输(OT)
- 3.1 将Top-k问题参数化为OT问题
- 3.2 通过熵正则化实现平滑
- 4 SOFT Top-k 算子详解
- 5 应用与效果
- 6 总结
- 参考文献
1 引言
在机器学习的世界里,top-k
操作无处不在。无论是在推荐系统中筛选出用户最可能喜欢的 k 个商品,在 k-NN 算法中寻找 k 个最近的邻居 ,还是在自然语言处理的束搜索(Beam Search)中保留 k 个最有可能的序列,top-k
都扮演着至关重要的角色。
然而,这个基础且强大的操作有一个“阿喀琉斯之踵”:它通常是不可微的。这意味着无法将它像普通层一样直接嵌入到深度神经网络中,然后使用梯度下降法进行端到端的训练。这极大地限制了许多新颖模型的设计。
为了绕开这个障碍,研究者们通常采用“两阶段训练”等妥协方案:先用一个代理损失(如交叉熵)训练特征提取网络,然后再将提取的特征用于 top-k
相关的任务。这种做法导致了训练目标和最终任务之间的不一致,往往会损害模型的性能表现。
那么,有没有办法让 top-k
变得可导呢?来自 Google 和佐治亚理工学院的研究者们在论文 Differentiable Top-k Operator with Optimal Transport 中提出了一种名为 SOFT (Scalable Optimal transport-based differenTiable) top-k 的算子,基于“最优传输”(Optimal Transport)的思想解决了这个问题。
2 Top-k的“微分困境”
为什么常规的 top-k
操作是不可导的?
-
算法实现的障碍:
top-k
的标准算法,如冒泡排序或快速选择(QuickSelect),都涉及到大量的索引交换和比较操作。这些基于逻辑和顺序的操作,其梯度要么无法定义,要么处处为零,无法为梯度下降提供有效信息。 -
数学本质的非连续性:从数学角度看,
top-k
可以被视为一个映射,它将一组输入分数x
映射到一个由 0 和 1 组成的指示向量A
(1代表该元素属于top-k,0则不属于)。这个映射是分段常数函数,因此是非连续的。
让我们以一个最简单的 top-1
例子来说明(即找出两个数 x1,x2x_1, x_2x1,x2 中较大的一个)。指示向量 A1A_1A1 的值(表示 x1x_1x1 是否为最大值)关于 x1x_1x1 的函数图像如下:
当 x1<x2x_1 < x_2x1<x2 时,A1=0A_1=0A1=0。当 x1x_1x1 刚刚超过 x2x_2x2 时,A1A_1A1 会从 0 瞬间跳变到 1 。在 x1=x2x_1 = x_2x1=x2 这个点,函数是不可导的;而在其他所有地方,它的导数都是 0 。这样的梯度对于模型训练是毫无用处的。
3 核心思想:当Top-k遇上最优传输(OT)
既然直接微分此路不通,换一个思路:将 top-k
问题重构为一个最优传输(Optimal Transport, OT)问题 。
3.1 将Top-k问题参数化为OT问题
最优传输旨在以最低的“运输成本”将一个概率分布的“质量”转移到另一个概率分布上。将 top-k
重新定义为这样一个运输问题:
- 源分布 μ\muμ:我们有 n 个输入分数 {xi}i=1n{\{x_i\}}_{i=1}^n{xi}i=1n,我们将它们看作 n 个质量均为 1/n1/n1/n 的源点。
- 目标分布 ν\nuν:我们设立两个目标点 0,1{0, 1}0,1。我们希望将 k 个单位的质量运到点 0(代表“top-k集合”),剩下的 n-k 个单位的质量运到点 1(代表“非top-k集合”)。
- 运输成本 CCC:从源点 xix_ixi 运输到目标点 yjy_jyj (其中yj∈0,1y_j \in {0, 1}yj∈0,1) 的成本定义为它们之间的欧氏距离的平方,即 Cij=(xi−yj)2C_{ij} = (x_i - y_j)^2Cij=(xi−yj)2 。
直观上,为了最小化总运输成本,模型会自发地将那些离 0 最近的 k 个 xix_ixi(也就是 k 个最小的数)运输到目标点 0,将其余离 1 最近的 n-k 个 xix_ixi(也就是 n-k 个最大的数)运输到目标点 1。
这个 OT 问题的解,即“最优运输方案” Γ∗\Gamma^*Γ∗, 是一个 n×2n \times 2n×2 的矩阵。Γi,1∗\Gamma^*_{i,1}Γi,1∗ 代表从 xix_ixi 运到 0 的质量。可以证明,这个运输方案恰好可以用来表示 top-k
的结果。
3.2 通过熵正则化实现平滑
虽然 OT 的框架很优雅,但标准 OT 问题的解对于输入的变化仍然不具备可微性。
关键的第二步来了:为 OT 问题引入熵正则化(Entropy Regularization)。具体来说,在最小化运输成本的同时,我们还希望最大化运输方案 Γ\GammaΓ 的熵 H(Γ)H(\Gamma)H(Γ)。目标函数变为:
Γ∗,ϵ=argminΓ⟨C,Γ⟩+ϵH(Γ)\Gamma^{*,\epsilon} = \arg\min_{\Gamma} \langle C, \Gamma \rangle + \epsilon H(\Gamma)Γ∗,ϵ=argΓmin⟨C,Γ⟩+ϵH(Γ)
其中 ϵ\epsilonϵ 是一个大于0的正则化系数。
熵正则化就像一个“平滑器”。它使得运输方案 Γ∗,ϵ\Gamma^{*,\epsilon}Γ∗,ϵ 不再是“非此即彼”的硬分配,而是变成了一个“软”的、模糊的分配。每个输入 xix_ixi 都会将一部分质量分配给 0,一部分分配给 1。最终,top-k
的指示向量 AϵA^\epsilonAϵ 中的值不再是刚性的 0 或 1,而是位于 (0, 1) 之间的平滑数值。这个熵正则化最优传输(EOT)问题的解 AϵA^\epsilonAϵ 对于输入分数 XXX 是可微的。
4 SOFT Top-k 算子详解
基于上述思想,SOFT top-k 算子诞生了。它的工作流程分为前向传播和反向传播。
注:关于前向传播(Forward Pass)与反向传播(Backward Pass)这两个概念的介绍,可以参见我的这一篇文章:【深度学习】一文彻底搞懂前向传播(Forward Pass)与反向传播(Backward Pass)。
-
前向传播:给定输入分数 XXX,通过高效的 Sinkhorn 算法 来求解熵正则化 OT 问题,从而计算出平滑后的
top-k
指示向量 AϵA^\epsilonAϵ。 -
反向传播:计算 AϵA^\epsilonAϵ 相对于输入 XXX 的雅可比矩阵(梯度)。如果直接对 Sinkhorn 算法的迭代过程应用自动微分,会占用巨大的内存。因此,作者们利用了 EOT 问题的 KKT (Karush-Kuhn-Tucker) 最优性条件,通过隐式微分技术,推导出了计算雅可比矩阵的解析表达式。这种方法在计算上十分高效,其时间和空间复杂度仅为 O(n)\mathcal{O}(n)O(n)。
-
超参数 ϵ\epsilonϵ 的权衡:ϵ\epsilonϵ 控制着近似的程度。
- 小 ϵ\epsilonϵ:近似偏差小,结果更接近真实的
top-k
,但函数更“陡峭”,平滑效果弱。 - 大 ϵ\epsilonϵ:平滑效果好,梯度计算更稳定,但近似偏差大,可能影响模型性能。
- 有趣的是,偏差也和第 k 个元素与第 k+1 个元素之间的差距(gap)有关。如果差距很大,即使 ϵ\epsilonϵ 较大,偏差也可以很小。
- 小 ϵ\epsilonϵ:近似偏差小,结果更接近真实的
此外,该框架还能被轻松扩展为 Sorted SOFT Top-k 算子,用于需要对 top-k 结果进行排序的场景,例如 Beam Search 。
5 应用与效果
为了验证 SOFT top-k 的威力,作者在三个典型的应用场景中进行了实验。
-
k-NN 图像分类
通过将 SOFT top-k 算子整合进 k-NN 分类器,作者们实现了一个可以端到端训练的神经网络 k-NN 模型。在 MNIST 和 CIFAR-10 数据集上的实验结果表明,该方法显著优于传统的两阶段训练方法和其他可微top-k
的基线模型。算法 MNIST CIFAR10 kNN+pretrained CNN 98.4% 91.1% CE+CNN 99.0% 91.3% kNN+Softmax k times 99.3% 92.2% kNN+SOFT Top-k 99.4% 92.6%
-
机器翻译中的束搜索 (Beam Search)
在序列生成任务中,训练和推理之间的不一致(也称作“暴露偏差”)是一个长期存在的问题。作者将 Sorted SOFT Top-k 算子应用于 Beam Search 过程,使得搜索过程本身可以被整合到训练循环中。在 WMT’14 英法翻译任务上,该方法取得了约 0.9 BLEU 值的提升。 -
机器翻译中的 Top-k Attention
传统的 Soft Attention 机制会考虑所有源端词语,可能导致注意力分散和冗余。作者使用 SOFT top-k 来选择性地关注最重要的 k 个源端词语,从而实现了一种稀疏的注意力机制。
在 WMT’16 英德翻译任务上,这种稀疏注意力机制也带来了约 0.8 BLEU 值的提升。
6 总结
top-k
操作的不可微性一直是深度学习领域的一个痛点。论文 Differentiable Top-k Operator with Optimal Transport 通过将 top-k
与熵正则化的最优传输理论相结合,提出了一种名为 SOFT top-k 的优雅解决方案:将离散的 top-k
选择问题转化为一个连续平滑的最优传输问题,实现了完全可微,并且具有高效、可扩展的前向和后向传播算法,并解锁了将 top-k
依赖的操作(如 k-NN、Beam Search)进行端到端训练的能力,在多个任务上取得了显著的效果提升。
参考文献
- Xie, Y., Dai, H., Chen, M., Dai, B., Zhao, T., Zha, H., Wei, W., & Pfister, T. (2020). Differentiable top-k operator with optimal transport. Advances in Neural Information Processing Systems, 33