当前位置: 首页 > news >正文

PyTorch学习笔记:nn.MSELoss——MSE损失

PyTorch学习笔记:nn.MSELoss——MSE损失

torch.nn.MSELoss(size_average = Nonereduce = None,reduction = 'mean')

功能:创建一个平方误差(MSE)损失函数,又称为L2损失:
l(x,y)=L={l1,…,lN}T,ln=(xn−yn)2l(x,y)=L=\{l_1,\dots,l_N\}^T,l_n=(x_n-y_n)^2 l(x,y)=L={l1,,lN}T,ln=(xnyn)2
其中,NNN表示batch size。

函数图像:

在这里插入图片描述

输入:

  • size_averagereduce已经被弃用,具体功能可由reduction替代
  • reduction:指定损失输出的形式,有三种选择:none|mean|sumnone:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出,会输出一个数;sum:将得到的损失求和再输出,会输出一个数

注意:

  • 输入的xxxyyy可以是任意维数的数组,但是二者形状必须一致

代码案例

对比reduction不同时,输出损失的差异

import torch.nn as nn
import torchx = torch.rand(10, dtype=torch.float)
y = torch.rand(10, dtype=torch.float)
mse_none = nn.MSELoss(reduction='none')
mse_mean = nn.MSELoss(reduction='mean')
mse_sum = nn.MSELoss(reduction='sum')
out_none = mse_none(x, y)
out_mean = mse_mean(x, y)
out_sum = mse_sum(x, y)
print(x)
print(y)
print(out_none)
print(out_mean)
print(out_sum)

输出

# 用于输入的x
tensor([0.4138, 0.1747, 0.9259, 0.2938, 0.5557, 0.9708, 0.0649, 0.6155, 0.3192, 0.1918])
# 用于输入的y
tensor([0.1024, 0.9160, 0.8386, 0.0783, 0.1479, 0.9933, 0.8791, 0.4219, 0.7586, 0.2212])
# 当reduction设置为none时,输出一个数组
# 该数组上的元素为x,y对应每个元素的平方误差损失,即对应元素做差求平方
tensor([9.6983e-02, 5.4955e-01, 7.6214e-03, 4.6433e-02, 1.6630e-01, 5.0293e-04, 6.6287e-01, 3.7512e-02, 1.9310e-01, 8.6344e-04])
# 当reduction设置为mean时,输出所有损失的平均值
tensor(0.1762)
# 当reduction设置为sum时,输出所有损失的和
tensor(1.7617)

注:绘图程序

import torch.nn as nn
import torch
import numpy as np
import matplotlib.pyplot as pltloss = nn.MSELoss(reduction='none')
x = torch.tensor([0]*100)
y = torch.from_numpy(np.linspace(-3,3,100))
loss_value = loss(x,y)
plt.plot(y, loss_value)
plt.savefig('MSELoss.jpg')

官方文档

nn.MSELoss:https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss

初步完稿于:2022年1月29日

http://www.lryc.cn/news/2483.html

相关文章:

  • apache和nginx的TLS1.0和TLS1.1禁用处理方案
  • K_A12_002 基于STM32等单片机采集光敏电阻传感器参数串口与OLED0.96双显示
  • 《机器学习》学习笔记
  • 前端卷算法系列(一)
  • 【机器学习】聚类算法(理论)
  • Docker-用Jenkins发版Java项目-(1)Docke安装Jenkins
  • java集合框架内容整理
  • win10系统安装Nginx
  • 数据库学习笔记(2)——workbench和SQL语言
  • 测量学期末考试之名词解释总结
  • TDengine时序数据库的简单使用
  • 记录每日LeetCode 2335.装满被子需要的最短总时长 Java实现
  • 了解线程池newFixedTheadPool
  • IP分片和TCP分段解析--之IP分片
  • 物联网方向常见通信方式有哪些?
  • windows wireshark抓到未加入组的组播消息
  • 【PTA Advanced】1156 Sexy Primes(C++)
  • 项目(今日指数)
  • 适配器模式(Adapter Pattern)
  • 网易一面:select分页要调优100倍,说说你的思路? (内含Mysql的36军规)
  • 二叉树的遍历 (2023-02-11)
  • string的深浅拷贝问题
  • C++中的万能头文件
  • Java 8 Lambda 表达式 Stream
  • 【VictoriaMetrics】VictoriaMetrics单机版部署(二进制版)
  • SCI论文阅读-使用基于图像的机器学习模型对FTIR光谱进行功能组识别
  • 双11大型互动游戏“喵果总动员” 质量保障方案总结
  • 剑指Offer专项突击版题解一
  • Django框架之模型
  • OSACN-Net:使用深度学习和Gabor心电图信号谱图进行睡眠呼吸暂停分类