当前位置: 首页 > news >正文

对比学习损失—InfoNCE理论理解

InfoNoise的理解

InfoNCE loss

最近在看对比学习的东西,记录点基础的东西

「对比学习」 属于无监督学习的一种,给一堆数据,没有标签,自己学习出一种特征表示。

InfoNCE 这个损失是来自于论文:Momentum Contrast for Unsupervised Visual Representation Learning.

MoCo提出,我们可以把对比学习看成是一个字典查询的任务,即训练一个编码器从而去做字典查询的任务。假设已经有一个编码好的query q q q(一个特征),以及一系列编码好的样本 k 0 , k 1 , k 2 , . . . k_0, k_1, k_2,... k0,k1,k2,...,那么 k 0 , k 1 , k 2 , . . . k_0, k_1, k_2,... k0,k1,k2,...可以看作是字典里的key。假设字典里只有一个key k + k_+ k+(称为 positive)是跟 q q q 匹配的,它们就互为正样本对,其余的key为 q q q 的负样本。一旦定义好了正负样本对,就需要一个对比学习的损失函数来指导模型进行学习。

在这里插入图片描述
这个损失函数显然要满足要求:

  • q q q 和唯一的正样本 k + k_+ k+相似,并且和其他所有负样本key都不相似的时候,这个loss的值应该比较低。
  • q q q k + k_+ k+ 不相似,或者和其他负样本的key相似了,那么loss就应该大,从而惩罚模型。

(嗯,合情合理,符合逻辑)

InfoNCE loss公式如下: L q = − l o g e x p ( q ⋅ k + / τ ) ∑ i = 0 k e x p ( q ⋅ k i / τ ) L_q=-log\frac{exp(q\cdot k_+ / \tau)}{\sum_{i=0}^k exp(q\cdot k_i / \tau)} Lq=logi=0kexp(qki/τ)exp(qk+/τ)Info NCE loss其实是NCE的一个简单变体,它认为如果只把问题看作是一个二分类,只有数据样本和噪声样本的话,可能对模型学习不友好,因为很多噪声样本可能本就不是一个类,因此还是把它看成一个多分类问题比较合理(但这里的多分类 k k k 指代的是负采样之后负样本的数量)。于是就有了InfoNCE loss

先看一下softmax公式: y ^ = s o f t m a x ( z ) = e x p ( z ) ∑ i = 0 k e x p ( z i ) \hat y=softmax(z)=\frac{exp(z)}{\sum_{i=0}^k exp(z_i)} y^=softmax(z)=i=0kexp(zi)exp(z)而交叉熵损失函数为: L ( y ^ ) = − ∑ i = 0 k y i l o g ( y ^ i ) L(\hat y)=-\sum_{i=0}^ky_ilog(\hat y_i) L(y^)=i=0kyilog(y^i)仔细观察上面的交叉熵的计算公式可以知道,因为 y i y_i yi的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。

在监督学习下,ground truth是一个one-hot向量,softmax的 y ^ \hat y y^结果取 − l o g -log log,再与ground truth相乘,即得到如下交叉熵损失: − l o g e x p ( z ) ∑ i = 0 k e x p ( z i ) -log\frac{exp(z)}{\sum_{i=0}^k exp(z_i)} logi=0kexp(zi)exp(z)

上式中, q ⋅ k q\cdot k qk 是模型出来的logits,相当于softmax公式中的 z z z τ \tau τ是一个温度超参,是个标量,假设我们忽略,那么infoNCE loss其实就是cross entropy loss。唯一的区别是,在cross entropy loss里, k k k 指代的是数据集里类别的数量,而在对比学习InfoNCE loss里,这个 k k k 指的是负样本的数量。上式分母中的 ∑ \sum 是在1个正样本和 k k k个负样本上做的,从0到k,所以共 k + 1 k+1 k+1 个样本,也就是字典里所有的key。MoCo里提到,InfoNCE loss其实就是一个cross entropy loss,做的是一个k+1类的分类任务,目的就是想把这个 q q q 图片分到 k + k_+ k+这个类。

温度系数 τ \tau τ

再来说一下这个温度系数 τ \tau τ,虽然只是一个超参数,但它的设置是非常讲究的,直接影响了模型的效果。

上式Info NCE loss中的相当于是logits,温度系数可以用来控制logits的分布形状。对于既定的logits分布的形状,当 τ \tau τ值变大,则 1 / τ 1/\tau 1/τ就变小,则 q ⋅ k / τ q\cdot k/\tau qk/τ 会使得原来logits分布里的数值都变小,且经过指数运算之后,就变得更小了,导致原来的logits分布变得更平滑。相反,如果 τ \tau τ取得值小,就 1 / τ 1/\tau 1/τ 变大,原来的logits分布里的数值就相应的变大,经过指数运算之后,就变得更大,使得这个分布变得更集中,更加的peak。
如果温度系数设的越大,logits分布变得越平滑,那么对比损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。

总之,温度系数的作用就是控制模型对负样本的区分度

http://www.lryc.cn/news/129554.html

相关文章:

  • 贝锐蒲公英助力电子公交站牌联网远程运维,打造智慧出行新趋势
  • SpringBoot + Vue 微人事(十)
  • 【Redis】Redis哨兵模式
  • 系统架构师---软件重用、基于架构的软件设计、软件模型
  • 【Web开发指南】MyEclipse XML编辑器的高级功能简介
  • 设计模式-观察者模式(观察者模式的需求衍变过程详解,关于监听的理解)
  • vue+electron中实现文件下载打开wps预览
  • 第4章 性能分析中的术语和指标
  • 数字化转型能带来哪些价值?_光点科技
  • 适用于Android™的Windows子系统Windows Subsystem fo r Android™Win11安装指南
  • hive高频使用的拼接函数及“避坑”
  • windows ipv4 多ip地址设置,默认网关跃点和自动跃点是什么意思?(跃点数)
  • java_免费文本翻译API_小牛翻译
  • flink消费kafka数据,按照指定时间开始消费
  • 【SpringCloud】Feign使用
  • WebApIs 第五天
  • 按斤称的C++散知识
  • C++策略模式
  • 如何在网页下载腾讯视频为本地MP4格式
  • opencv-yolov8-目标检测
  • CRYPTO 密码学-笔记
  • 基于YOLOv8模型的五类动物目标检测系统(PyTorch+Pyside6+YOLOv8模型)
  • Java课题笔记~ SpringBoot基础配置
  • vue实现文件上传,前后端
  • OJ练习第151题——克隆图
  • keepalived+lvs实现高可用
  • 【Let‘s make it big】英语合集61~70
  • python实现图像的二分类
  • 8.深浅拷贝和异常处理
  • Element Plus el-table 数据为空时自定义内容【默认为 No Data】