当前位置: 首页 > news >正文

PyTorch学习笔记:nn.TripletMarginLoss——三元组损失

PyTorch学习笔记:nn.TripletMarginLoss——三元组损失

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean')

功能:创建一个三元组损失函数(triplet loss),用于衡量输入数据x1,x2,x3x_1,x_2,x_3x1,x2,x3之间的相对相似性,其中输入样本又分别称为中立样本、正样本以及负样本,具体介绍可见论文《Learning shallow convolutional feature descriptors with triplet losses》

损失函数
L(x1,x2,x3)=max⁡{d(x1,x2)−d(x1,x3)+margin,0}L(x_1,x_2,x_3)=\max\{d(x_1,x_2)-d(x_1,x_3)+\text{margin},0\} L(x1,x2,x3)=max{d(x1,x2)d(x1,x3)+margin,0}
其中:
d(xi,yi)=∣∣xi−yi∣∣pd(x_i,y_i)=||x_i-y_i||_p d(xi,yi)=∣∣xiyip
该函数的作用就是拉进x1x_1x1x2x_2x2的距离,使它们更加相似,同时推离x1x_1x1x3x_3x3的距离,即使它们更加不同。

输入:

  • margin:边界距离,具体含义如公式所示,如果该值越大,则表明x1x_1x1x2x_2x2期望距离越近,x1x_1x1x3x_3x3期望距离越远。输入数据类型为浮点数(float),默认1.0;
  • p:用于计算两个向量距离的范数,具体含义如公式所示。输入数据类型为整数(int),默认2,即欧氏距离;
  • swap:是否使用距离交换,具体功能可见论文《Learning shallow convolutional feature descriptors with triplet losses》;
  • size_averagereduce已被弃用,具体功能由reduction替代;
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数。

注意:

  • 输入的三个样本数据维数必须为二维(N,D)(N,D)(N,D),其中第二个维度DDD表示向量长度;
  • 如果reduction设置为none,则输出的数组维数为1,尺寸为(N)(N)(N)

代码案例

一般用法

import torch
import torch.nn as nn# reduction设为none便于查看损失计算的结果
triplet_loss = nn.TripletMarginLoss(reduction='none')
x1 = torch.randn(20).reshape(2,10)
x2 = torch.randn(20).reshape(2,10)
x3 = torch.randn(20).reshape(2,10)
loss = triplet_loss(x1, x2, x3)
print(x1)
print(x2)
print(x3)
print(loss)

输出

tensor([[-0.1419,  0.0550, -0.2996, -1.7194,  0.5485, -0.9163, -0.6983,  0.0239,1.2940, -0.4858],[ 1.8544, -0.2349, -0.2523, -1.6167,  0.7861, -1.7627,  0.3139, -1.5112,-0.3378,  0.0059]])
tensor([[-1.5967,  0.4007,  0.1468, -1.0085, -1.4989,  1.7531,  0.0865, -0.9080,-0.4046,  0.5229],[-1.8673, -0.4958,  1.0122, -1.8696,  0.1974, -0.8017, -1.0562, -2.1461,1.7112, -0.6001]])
tensor([[-1.0008,  1.5316,  0.0078,  1.1405, -0.0629,  0.4934, -1.8050, -1.0302,0.8676, -0.1988],[ 1.3015, -0.2786,  0.4215, -0.6413, -0.0760, -0.8138,  0.2173,  1.5132,-0.6389,  0.7173]])
tensor([1.4133, 2.2473])

官方文档

nn.TripletMarginLoss:https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html?highlight=tripletmarginloss#torch.nn.TripletMarginLoss

初步完稿于:2022年3月30日

http://www.lryc.cn/news/2397.html

相关文章:

  • 冒泡排序详解
  • git极快上手指南超级精简版
  • 蓝桥杯-最长公共子序列(线性dp)
  • GO的并发模式Context
  • 《Redis实战篇》六、秒杀优化
  • 《C++ Primer Plus》第16章:string类和标准模板库(11)
  • 声明和定义
  • Python获取最小路径,查找元素在list中的坐标
  • 数据采集协同架构,集成马扎克、西门子、海德汉、广数、凯恩帝、三菱、海德汉、兄弟、哈斯、宝元、新代、发那科、华中各类数控以及各类PLC数据采集软件
  • Allegro172版本如何用自带的功能实现快速在1MMBGA下方等距放置电容
  • 一种简单的统计pytorch模型参数量的方法
  • 【PyTorch】教程:对抗学习实例生成
  • 中国区使用Open AI账号试用Chat GPT指南
  • STM32开发(9)----CubeMX配置外部中断
  • Nextjs了解内容
  • 从事功能测试1年,裸辞1个月,找不到工作的“我”怎么办?
  • 机器学习基本原理总结
  • JVET-AC0315:用于色度帧内预测的跨分量Merge模式
  • Session与Cookie的区别(二)
  • 疫情开发,软件测试行情趋势是怎么样的?
  • Java中间件描述与使用,面试可以用
  • [OpenMMLab]AI实战营第七节课
  • 面向对象的设计模式
  • 里氏替换原则|SOLID as a rock
  • 【C++】右左法则,指针、函数与数组
  • 打通数据价值链,百分点数据科学基础平台实现数据到决策的价值转换 | 爱分析调研
  • C++之多态【详细总结】
  • ThingsBoard-RPC
  • java分治算法
  • 【Flutter】【Unity】使用 Flutter + Unity 构建(AR 体验工具包)