当前位置: 首页 > news >正文

PyTorch学习笔记:nn.SmoothL1Loss——平滑L1损失

PyTorch学习笔记:nn.SmoothL1Loss——平滑L1损失

torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction='mean', beta=1.0)

功能:创建一个平滑后的L1L_1L1损失函数,即Smooth L1:
l(x,y)=L={l1,…,lN}Tl(x,y)=L=\{l_1,\dots,l_N\}^T l(x,y)=L={l1,,lN}T
其中,
ln={12β(xn,yn)2,∣xn−yn∣<β∣xn−yn∣−12β,otherwise\begin{aligned} l_n=\left\{ \begin{matrix} & \frac{1}{2\beta}(x_n,y_n)^2, \quad |x_n-y_n|<\beta\\ &|x_n-y_n|-\frac12\beta,\quad \text{otherwise} \end{matrix} \right. \end{aligned} ln={2β1(xn,yn)2,xnyn<βxnyn21βotherwise

  如果绝对值误差低于β\betaβ,则创建一个平方项的损失(L2L_2L2),否则使用绝对值损失(L1L_1L1),此损失对异常值的敏感性低于L2L_2L2损失,即当xxxyyy相差过大时,该损失数值要小于L2L_2L2损失数值,在某些情况下该损失可以防止梯度爆炸,损失图如下所示:

在这里插入图片描述

输入:

  • size_averagereduce已经被弃用,具体功能可由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数
  • beta:指定该损失在L1L_1L1L2L_2L2之间变化的阈值,默认1.01.01.0

注意:

  • Smooth L1损失与L1L_1L1损失类似,但是随着∣x−y∣<β|x-y|<\betaxy<β,即随着xxxyyy的靠近,损失形式逐渐向L2L_2L2损失的形式靠近

代码案例

一般用法

import torch.nn as nn
import torch# reduction设为none便于逐元素对比损失值
loss = nn.SmoothL1Loss(reduction='none')
x = torch.randn(10)
y = torch.randn(10)
loss_value = loss(x, y)
print(x)
print(y)
print(loss_value)

输出

# x
tensor([ 0.7584,  1.0724,  0.8966, -1.0947, -1.8141, -1.8305, -1.5329, -0.3077,0.6814, -0.2394])
# y
tensor([ 0.5081, -0.1718,  0.7817, -0.8019, -0.6405, -1.4802,  2.3039,  1.4522,1.1861, -0.2443])
# loss
tensor([3.1319e-02, 7.4427e-01, 6.6015e-03, 4.2872e-02, 6.7358e-01, 6.1354e-02,3.3368e+00, 1.2598e+00, 1.2736e-01, 1.1723e-05])

注:画图程序

import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as pltloss = nn.SmoothL1Loss(reduction='none')
x = torch.tensor([0]*100)
y = torch.from_numpy(np.linspace(-3,3,100))
loss_value = loss(x,y)
plt.plot(y, loss_value)
plt.savefig('SmoothL1Loss.jpg')

官方文档

nn.SmoothL1Loss:https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html#torch.nn.SmoothL1Loss

http://www.lryc.cn/news/1807.html

相关文章:

  • 2年时间,涨薪20k,想拿高薪还真不能老老实实的工作...
  • Spark - Spark SQL中RBO, CBO与AQE简单介绍
  • NeurIPS/ICLR/ICML AI三大会国内高校和企业近年中稿量完整统计
  • Android IO 框架 Okio 的实现原理,到底哪里 OK?
  • 一文讲解Linux 设备模型 kobject,kset
  • linux配置密码过期的安全策略(/etc/login.defs的解读)
  • c_character_string 字符串----我认真的弄明白了,也希望你们也是。
  • spring面试题 一
  • C++中char *,char a[ ]的特殊应用
  • 【Windows10】电脑副屏无法调节屏幕亮度?解决方法
  • Paper简读 - ProGen2: Exploring the Boundaries of Protein Language Models
  • leaflet 加载WKT数据(示例代码050)
  • 设计模式-组合模式和建筑者模式详解
  • Pcap文件的magic_number
  • MDS75-16-ASEMI三相整流模块MDS75-16
  • 基本TCP编程
  • 【沁恒WCH CH32V307V-R1开发板读取板载温度实验】
  • 学习SpringCloudAlibaba(二)微服务的拆分与编写
  • 通过对HashMap的源码分析解决部分关于HashMap的问题
  • 【无标题】
  • 渗透测试 -- 网站信息收集
  • Windows 搭建ARM虚拟机 UOS系统
  • day58每日温度_下一个更大元素1
  • 超清遥感影像语义分割处理
  • RabbitMQ安装及配置
  • 网络协议(四):网络互联模型、物理层、数据链路层
  • 请问有没有关于数据预测的方法?
  • [CVPR 2021] Your “Flamingo“ is My “Bird“: Fine-Grained, or Not
  • clickHouse笔记
  • 10.jQuery中请求预处理 $.ajaxPrefilter()