当前位置: 首页 > news >正文

对比学习训练是如何进行的

对比学习(Contrastive Learning)是一种自监督学习的方法,旨在通过拉近相似样本的表示、拉远不相似样本的表示来学习特征表示。在训练过程中,模型并不依赖标签,而是通过样本之间的相似性进行学习。以下是对比学习的基本原理和具体的训练流程:

1. 基本原理

对比学习的核心目标是通过构造正样本对(相似样本)和负样本对(不相似样本),让模型学习到对相似样本的特征表示更接近,而对不相似样本的特征表示更远。常用的对比学习方法有 SimCLR、MoCo 等。

  • 正样本对:指的是经过不同增强方式得到的同一图像的不同视角,或在一些情况下是语义上相关的图像对。
  • 负样本对:指的是不同图像对,它们在语义上或像素空间上不相关。

2. 对比学习的训练流程

以MoCo为例:
在这里插入图片描述

其中momentum encoder是动量编码器,将encoder中的k的参数更新过程使用动量公式来约束,在MoCo中,作者将m设置为0.99(即momentum encoder中k的参数除了刚开始赋值给他,encoder不进行反向传播更新,往后全靠自己更新),这样就可以使得k的参数更新更依赖于之前k的参数了。

步骤1:样本增强

  • 对每个输入样本(例如图像),通过数据增强(如随机裁剪、旋转、颜色扰动等)生成多个视图。每个样本经过增强后形成一个正样本对,即该样本的两个不同增强版本。

步骤2:特征提取

  • 将增强后的样本输入到神经网络(如卷积神经网络或 Transformer)中,提取它们的特征表示。特征提取器通常不带标签地训练,模型在这个过程中学习到数据的潜在结构。

步骤3:相似性度量

  • 对每个样本对,计算它们的特征表示之间的相似度。通常使用余弦相似度(Cosine Similarity)来衡量特征向量之间的相似性。

    • 对于正样本对(相同样本的不同视图),希望它们的特征表示尽量接近,即相似度高。
    • 对于负样本对(不同样本),希望它们的特征表示尽量远,即相似度低。

步骤4:损失函数

  • 对比学习常用的损失函数是对比损失(Contrastive Loss)NCE(Noise Contrastive Estimation)损失,其中最常用的是 InfoNCE 损失。该损失函数通过极大化正样本对的相似性,极小化负样本对的相似性来优化模型。

    InfoNCE 损失函数的公式如下:
    在这里插入图片描述
    在反向传播过程中,L分别对q和k中的权重w微分来进行参数更新,使得权重作用于与q相似性高的k
    后的loss更低

步骤5:优化与更新

  • 利用梯度下降算法最小化对比损失,从而更新网络参数,使模型能够学到更好的特征表示。

3. 监督对比学习

在有标签的情况下,可以利用标签信息来构造更加有效的正负样本对。**监督对比学习(Supervised Contrastive Learning)**通过使用相同类别的样本作为正样本对,不同类别的样本作为负样本对,这种方式可以进一步提升模型的分类性能。具体步骤如下:

  1. 构建正样本对:对于每个样本,选择与其类别相同的其他样本作为正样本,而非只依赖数据增强生成正样本对。
  2. 构建负样本对:选择不同类别的样本作为负样本。

通过引入监督信息,监督对比学习可以更加有效地对齐同类别样本的特征表示,从而提升模型的泛化能力。

这种方式确保了模型能够更好地利用语言模式中的信息,增强对跨领域数据的泛化能力。相当于对正样本对最大化,负样本对最小化来使模型对于同类别的样本有着更好的辨识能力;在一些E2D的模型中,可以保留或冻结编码器部分当作预训练编码器,然后进行下游任务。

http://www.lryc.cn/news/451068.html

相关文章:

  • React 生命周期 - useEffect 介绍
  • OpenCV-指纹识别
  • IPD的核心思想
  • 如何在算家云搭建MVSEP-MDX23(音频分离)
  • 常用的Java安全框架
  • 使用 PHP 的 strip_tags函数保护您的应用安全
  • 您的计算机已被Lockbit3.0勒索病毒感染?恢复您的数据的方法在这里!
  • 经典sql题(十二)UDTF之Explode炸裂函数
  • 【AIGC】ChatGPT提示词解析:如何打造个人IP、CSDN爆款技术文案与高效教案设计
  • 【Ubuntu】Ubuntu常用命令
  • 架构设计笔记-5-软件工程基础知识-2
  • [网络]抓包工具介绍 tcpdump
  • 基于STM32和FPGA的射频数据采集系统设计流程
  • 自动变速箱系统(A/T)详细解析
  • 【Kubernetes】常见面试题汇总(四十三)
  • OpenCL 学习(1)---- OpenCL 基本概念
  • 自定义注解加 AOP 实现服务接口鉴权以及内部认证
  • 《软件工程概论》作业一:新冠疫情下软件产品设计(小区电梯实体按钮的软件替代方案)
  • 基于Ernie-Bot打造语音对话功能
  • 动手学深度学习(李沐)PyTorch 第 3 章 线性神经网络
  • ROS理论与实践学习笔记——2 ROS通信机制之服务通信
  • 技术成神之路:设计模式(十八)适配器模式
  • 图神经网络:处理复杂关系结构与图分类任务的强大工具
  • LeetCode: 1971. 寻找图中是否存在路径
  • mysql 查询表所有数据,分页的语句
  • TI DSP TMS320F280025 Note13:CPUtimer定时器原理分析与使用
  • Australis 相機率定軟體說明
  • C++入门(有C语言基础)
  • 第四届高性能计算与通信工程国际学术会议(HPCCE 2024)
  • 负载均衡架构解说