当前位置: 首页 > news >正文

【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作

【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作

文章目录

  • 【深度学习】SOFT Top-k:用最优传输解锁可微的 Top-k 操作
    • 1 引言
    • 2 Top-k的“微分困境”
    • 3 核心思想:当Top-k遇上最优传输(OT)
      • 3.1 将Top-k问题参数化为OT问题
      • 3.2 通过熵正则化实现平滑
    • 4 SOFT Top-k 算子详解
    • 5 应用与效果
    • 6 总结
    • 参考文献

1 引言

在机器学习的世界里,top-k 操作无处不在。无论是在推荐系统中筛选出用户最可能喜欢的 k 个商品,在 k-NN 算法中寻找 k 个最近的邻居 ,还是在自然语言处理的束搜索(Beam Search)中保留 k 个最有可能的序列,top-k 都扮演着至关重要的角色。

然而,这个基础且强大的操作有一个“阿喀琉斯之踵”:它通常是不可微的。这意味着无法将它像普通层一样直接嵌入到深度神经网络中,然后使用梯度下降法进行端到端的训练。这极大地限制了许多新颖模型的设计。

为了绕开这个障碍,研究者们通常采用“两阶段训练”等妥协方案:先用一个代理损失(如交叉熵)训练特征提取网络,然后再将提取的特征用于 top-k 相关的任务。这种做法导致了训练目标和最终任务之间的不一致,往往会损害模型的性能表现。

那么,有没有办法让 top-k 变得可导呢?来自 Google 和佐治亚理工学院的研究者们在论文 Differentiable Top-k Operator with Optimal Transport 中提出了一种名为 SOFT (Scalable Optimal transport-based differenTiable) top-k 的算子,基于“最优传输”(Optimal Transport)的思想解决了这个问题。

2 Top-k的“微分困境”

为什么常规的 top-k 操作是不可导的?

  1. 算法实现的障碍top-k 的标准算法,如冒泡排序或快速选择(QuickSelect),都涉及到大量的索引交换和比较操作。这些基于逻辑和顺序的操作,其梯度要么无法定义,要么处处为零,无法为梯度下降提供有效信息。

  2. 数学本质的非连续性:从数学角度看,top-k 可以被视为一个映射,它将一组输入分数 x 映射到一个由 0 和 1 组成的指示向量 A(1代表该元素属于top-k,0则不属于)。这个映射是分段常数函数,因此是非连续的。

让我们以一个最简单的 top-1 例子来说明(即找出两个数 x1,x2x_1, x_2x1,x2 中较大的一个)。指示向量 A1A_1A1 的值(表示 x1x_1x1 是否为最大值)关于 x1x_1x1 的函数图像如下:

图1

图1 左侧为标准top-k算子,右侧为SOFT top-k算子。可以看出标准算子的输出是突变的,而SOFT算子是平滑的

x1<x2x_1 < x_2x1<x2 时,A1=0A_1=0A1=0。当 x1x_1x1 刚刚超过 x2x_2x2 时,A1A_1A1 会从 0 瞬间跳变到 1 。在 x1=x2x_1 = x_2x1=x2 这个点,函数是不可导的;而在其他所有地方,它的导数都是 0 。这样的梯度对于模型训练是毫无用处的。

3 核心思想:当Top-k遇上最优传输(OT)

既然直接微分此路不通,换一个思路:top-k 问题重构为一个最优传输(Optimal Transport, OT)问题

3.1 将Top-k问题参数化为OT问题

最优传输旨在以最低的“运输成本”将一个概率分布的“质量”转移到另一个概率分布上。将 top-k 重新定义为这样一个运输问题:

  • 源分布 μ\muμ:我们有 n 个输入分数 {xi}i=1n{\{x_i\}}_{i=1}^n{xi}i=1n,我们将它们看作 n 个质量均为 1/n1/n1/n 的源点。
  • 目标分布 ν\nuν:我们设立两个目标点 0,1{0, 1}0,1。我们希望将 k 个单位的质量运到点 0(代表“top-k集合”),剩下的 n-k 个单位的质量运到点 1(代表“非top-k集合”)。
  • 运输成本 CCC:从源点 xix_ixi 运输到目标点 yjy_jyj (其中yj∈0,1y_j \in {0, 1}yj0,1) 的成本定义为它们之间的欧氏距离的平方,即 Cij=(xi−yj)2C_{ij} = (x_i - y_j)^2Cij=(xiyj)2

直观上,为了最小化总运输成本,模型会自发地将那些离 0 最近的 k 个 xix_ixi(也就是 k 个最小的数)运输到目标点 0,将其余离 1 最近的 n-k 个 xix_ixi(也就是 n-k 个最大的数)运输到目标点 1。

这个 OT 问题的解,即“最优运输方案” Γ∗\Gamma^*Γ, 是一个 n×2n \times 2n×2 的矩阵。Γi,1∗\Gamma^*_{i,1}Γi,1 代表从 xix_ixi 运到 0 的质量。可以证明,这个运输方案恰好可以用来表示 top-k 的结果。

3.2 通过熵正则化实现平滑

虽然 OT 的框架很优雅,但标准 OT 问题的解对于输入的变化仍然不具备可微性。

关键的第二步来了:为 OT 问题引入熵正则化(Entropy Regularization)。具体来说,在最小化运输成本的同时,我们还希望最大化运输方案 Γ\GammaΓ 的熵 H(Γ)H(\Gamma)H(Γ)。目标函数变为:

Γ∗,ϵ=arg⁡min⁡Γ⟨C,Γ⟩+ϵH(Γ)\Gamma^{*,\epsilon} = \arg\min_{\Gamma} \langle C, \Gamma \rangle + \epsilon H(\Gamma)Γ,ϵ=argΓminC,Γ+ϵH(Γ)
其中 ϵ\epsilonϵ 是一个大于0的正则化系数。

图2

图2 不同平滑参数 epsilon 下的SOFT算子输出

熵正则化就像一个“平滑器”。它使得运输方案 Γ∗,ϵ\Gamma^{*,\epsilon}Γ,ϵ 不再是“非此即彼”的硬分配,而是变成了一个“软”的、模糊的分配。每个输入 xix_ixi 都会将一部分质量分配给 0,一部分分配给 1。最终,top-k 的指示向量 AϵA^\epsilonAϵ 中的值不再是刚性的 0 或 1,而是位于 (0, 1) 之间的平滑数值。这个熵正则化最优传输(EOT)问题的解 AϵA^\epsilonAϵ 对于输入分数 XXX 是可微的

4 SOFT Top-k 算子详解

基于上述思想,SOFT top-k 算子诞生了。它的工作流程分为前向传播和反向传播。

注:关于前向传播(Forward Pass)与反向传播(Backward Pass)这两个概念的介绍,可以参见我的这一篇文章:【深度学习】一文彻底搞懂前向传播(Forward Pass)与反向传播(Backward Pass)。

  • 前向传播:给定输入分数 XXX,通过高效的 Sinkhorn 算法 来求解熵正则化 OT 问题,从而计算出平滑后的 top-k 指示向量 AϵA^\epsilonAϵ

  • 反向传播:计算 AϵA^\epsilonAϵ 相对于输入 XXX 的雅可比矩阵(梯度)。如果直接对 Sinkhorn 算法的迭代过程应用自动微分,会占用巨大的内存。因此,作者们利用了 EOT 问题的 KKT (Karush-Kuhn-Tucker) 最优性条件,通过隐式微分技术,推导出了计算雅可比矩阵的解析表达式。这种方法在计算上十分高效,其时间和空间复杂度仅为 O(n)\mathcal{O}(n)O(n)

  • 超参数 ϵ\epsilonϵ 的权衡ϵ\epsilonϵ 控制着近似的程度。

    • ϵ\epsilonϵ:近似偏差小,结果更接近真实的 top-k,但函数更“陡峭”,平滑效果弱。
    • ϵ\epsilonϵ:平滑效果好,梯度计算更稳定,但近似偏差大,可能影响模型性能。
    • 有趣的是,偏差也和第 k 个元素与第 k+1 个元素之间的差距(gap)有关。如果差距很大,即使 ϵ\epsilonϵ 较大,偏差也可以很小。

此外,该框架还能被轻松扩展为 Sorted SOFT Top-k 算子,用于需要对 top-k 结果进行排序的场景,例如 Beam Search 。

5 应用与效果

为了验证 SOFT top-k 的威力,作者在三个典型的应用场景中进行了实验。

  1. k-NN 图像分类
    通过将 SOFT top-k 算子整合进 k-NN 分类器,作者们实现了一个可以端到端训练的神经网络 k-NN 模型。在 MNIST 和 CIFAR-10 数据集上的实验结果表明,该方法显著优于传统的两阶段训练方法和其他可微 top-k 的基线模型。

    算法MNISTCIFAR10
    kNN+pretrained CNN98.4%91.1%
    CE+CNN99.0%91.3%
    kNN+Softmax k times99.3%92.2%
    kNN+SOFT Top-k99.4%92.6%
表1:k-NN分类准确率对比
  1. 机器翻译中的束搜索 (Beam Search)
    在序列生成任务中,训练和推理之间的不一致(也称作“暴露偏差”)是一个长期存在的问题。作者将 Sorted SOFT Top-k 算子应用于 Beam Search 过程,使得搜索过程本身可以被整合到训练循环中。在 WMT’14 英法翻译任务上,该方法取得了约 0.9 BLEU 值的提升。

  2. 机器翻译中的 Top-k Attention
    传统的 Soft Attention 机制会考虑所有源端词语,可能导致注意力分散和冗余。作者使用 SOFT top-k 来选择性地关注最重要的 k 个源端词语,从而实现了一种稀疏的注意力机制。

图3

图3 Top-k Attention可视化(示意图),注意力变得稀疏且对齐清晰

在 WMT’16 英德翻译任务上,这种稀疏注意力机制也带来了约 0.8 BLEU 值的提升。

6 总结

top-k 操作的不可微性一直是深度学习领域的一个痛点。论文 Differentiable Top-k Operator with Optimal Transport 通过将 top-k 与熵正则化的最优传输理论相结合,提出了一种名为 SOFT top-k 的优雅解决方案:将离散的 top-k 选择问题转化为一个连续平滑的最优传输问题,实现了完全可微,并且具有高效、可扩展的前向和后向传播算法,并解锁了将 top-k 依赖的操作(如 k-NN、Beam Search)进行端到端训练的能力,在多个任务上取得了显著的效果提升。


参考文献

  1. Xie, Y., Dai, H., Chen, M., Dai, B., Zhao, T., Zha, H., Wei, W., & Pfister, T. (2020). Differentiable top-k operator with optimal transport. Advances in Neural Information Processing Systems, 33
http://www.lryc.cn/news/602347.html

相关文章:

  • 应急响应案例处置(下)
  • 应急响应处置案例(上)
  • 【LeetCode 热题 100】(一)哈希
  • 绿算技术携手昇腾发布高性能全闪硬盘缓存设备,推动AI大模型降本增效
  • 零基础部署网站?使用天翼云服务搭建语音听写应用系统
  • Angular 依赖注入
  • 谷歌浏览器深入用法全解析:解锁高效网络之旅
  • 图像处理第三篇:初级篇(续)—— 照明的理论知识
  • C++算法之单调栈
  • 达梦数据库获取每个数据库表的总条数及业务实战
  • 提取excel中的年月日
  • window显示驱动开发—Direct3D 11 视频播放改进
  • 你的连接不是专用连接
  • NI Ettus USRP X440 软件无线电
  • 28天0基础前端工程师完成Flask接口编写
  • Go 语言-->指针
  • Java-数构排序
  • WAIC看点:可交付AI登场,场景智能、专属知识将兑现下一代AI价值
  • vue怎么实现导入excel表功能
  • 基于开源AI智能名片链动2+1模式与S2B2C商城小程序的微商品牌规范化运营研究
  • IDEA 手动下载安装数据库驱动,IDEA无法下载数据库驱动问题解决方案,IDEA无法连接数据库解决方案(通用,Oracle为例)
  • idea启动java应用报错
  • 设计模式十二:门面模式 (FaçadePattern)
  • 结合项目阐述 设计模式:单例、工厂、观察者、代理
  • 记一次IDEA启动微服务卡住导致内存溢出问题
  • Java设计模式之<建造者模式>
  • idea编译报错 java: 非法字符: ‘\ufeff‘ 解决方案
  • 解决windows系统下 idea、CLion 控制台中文乱码问题
  • 机器学习sklearn:不纯度与决策树构建
  • Rust实战:AI与机器学习自动炒饭机器学习