当前位置: 首页 > news >正文

PyTorch学习笔记:nn.MarginRankingLoss——排序损失

PyTorch学习笔记:nn.MarginRankingLoss——排序损失

torch.nn.MarginRankingLoss(margin=0.0, size_average=None, reduce=None, reduction='mean')

功能:创建一个排序损失函数,用于衡量输入x1x_1x1x2x_2x2之间的排序损失(Ranking Loss),输入的第三个参数yyy控制顺序还是逆序,因此yyy的取值范围为y∈{1,−1}y\in\{1,-1\}y{1,1}

损失函数
loss(x1,x2,y)=max⁡(0,−y∗(x1−x2)+margin)loss(x_1,x_2,y)=\max(0,-y*(x_1-x_2)+\text{margin}) loss(x1,x2,y)=max(0,y(x1x2)+margin)
当期望x1>x2x_1>x_2x1>x2,即排序为顺序时,应该传入y=1y=1y=1;当期望x1<x2x_1<x_2x1<x2,即排序为逆序时,应该传入y=−1y=-1y=1

输入:

  • margin:差额值,具体用法如公式所示,如果该值越大,则表示期望x1x_1x1x2x_2x2越远,即差额越大。输入数据类型为float,默认为0;
  • size_averagereduce已被弃用,具体功能由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数

注意:

  • 输入的x1x_1x1x2x_2x2yyy必须是一维的数据,并且三个数据长度必须一致,数据长度表示batchbatchbatch大小

代码案例

一般用法

import torch
import torch.nn as nn# reduction设为none便于查看每个位置损失计算的结果
rankloss = nn.MarginRankingLoss(reduction='none')
x1 = torch.randn(10)
x2 = torch.randn(10)
# 随机生成10个0,1数据
y = torch.randint(0, 2, [10])
# 将y中数据为0的位置赋值为-1
y[y==0] = -1
loss = rankloss(x1, x2, y)
print(x1)
print(x2)
print(y)
print(loss)

输出

# x1
tensor([-1.2248, -1.4788,  0.1703,  0.1072, -0.2147,  0.7527, -1.2443,  0.8361, 0.3679,  0.5935])
# x2
tensor([-0.3616, -0.0333,  0.8483,  0.9880,  0.6980, -0.5157,  0.1767,  0.2060, -0.4908,  1.1774])
# y
tensor([ 1, -1,  1, -1,  1, -1,  1, -1,  1, -1])
# 对应位置的损失计算结果
tensor([0.8631, 0.0000, 0.6780, 0.0000, 0.9127, 1.2684, 1.4210, 0.6301, 0.0000, 0.0000])

官方文档

nn.MarginRankingLoss:https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html#torch.nn.MarginRankingLoss

初步完稿于:2022年2月6日

http://www.lryc.cn/news/2327.html

相关文章:

  • 【JavaScript】34_Date对象 ,日期的格式化
  • 计算机视觉 对比学习13篇经典论文、解读、代码
  • MySQL 选择数据库
  • 雅思经验(9)
  • java面试题(二十)中间件redis
  • JavaWEB必知必会-Servlet
  • oralce查找返回不同的值,寻找不同的表(原创)
  • Python-第四天 Python循环语句
  • spring中bean的生命周期(简单5步)
  • 10 个最难理解的 Python 概念
  • 【linux】线程概念
  • Leg转Goh引擎和架设单机+配置登陆器教程
  • idea整合svn
  • 字节青训前端笔记 | 数据可视化基础
  • ROS运行机C++程序,移动
  • C++中编译静态库与动态库
  • shell中sed命令用法
  • 【VictoriaMetrics】VictoriaMetrics启停脚本
  • 高性能网络SIG月度动态:SMC 与 IBM 就扩展协议达成一致,virtio 支持 XDP 新特性
  • 【正点原子FPGA连载】第七章程序固化实验摘自【正点原子】DFZU2EG_4EV MPSoC之嵌入式Vitis开发指南
  • LeetCode-2335. 装满杯子需要的最短总时长【贪心,数学】
  • 基于 oss 框架的音频驱动
  • 【golang】如何定制化zap日志库以及如何使用
  • 如何将 Ubuntu 升级到 22.04 LTS Jammy Jellyfish
  • ubuntu20.04安装docker与docker-compose
  • 笔试题-2023-加特兰-数字IC设计【纯净题目版】
  • 动态内存管理
  • Unsupervised Question Answering 简单综述
  • 智慧物流管理系统
  • 单表查询--实例