当前位置: 首页 > news >正文

pytorch的CrossEntropyLoss交叉熵损失函数默认是平均值

pytorch中使用nn.CrossEntropyLoss()创建出来的交叉熵损失函数计算损失默认是求平均值的,即多个样本输入后获取的是一个均值标量,而不是样本大小的向量。

net = nn.Linear(4, 2)
loss = nn.CrossEntropyLoss()
X = torch.rand(10, 4)
y = torch.ones(10, dtype=torch.long)
y_hat = net(X)
l = loss(y_hat, y)
print(l)

打印的结果:tensor(0.7075, grad_fn=<NllLossBackward0>)

以上是对10个样本做的均值的标量

net = nn.Linear(4, 2)
loss = nn.CrossEntropyLoss(reduction='none')
X = torch.rand(10, 4)
y = torch.ones(10, dtype=torch.long)
y_hat = net(X)
l = loss(y_hat, y)
print(l)

在构造CrossEntropyLoss时候加入 reduction='none',就把默认求平均取消掉了

打印结果:

tensor([0.6459, 0.7372, 0.6373, 0.6843, 0.6251, 0.6555, 0.5510, 0.7016, 0.6975,0.6849], grad_fn=<NllLossBackward0>)

以上是10个样本各自的loss值

上图是pytorch的CrossEntropyLoss的构造方法,默认是 reduction='mean'

http://www.lryc.cn/news/107985.html

相关文章:

  • 【力扣】206. 反转链表 <链表指针>
  • Java包装类(自动拆装箱)
  • 使用Golang反射技术实现一套有默认值的配置解析库
  • 数据安全能力框架模型-详细解读(二)
  • 【BASH】回顾与知识点梳理(八)
  • rust报错“Utf8Error { valid_up_to: 1, error_len: Some(1) } }”
  • 【Linux】节点之间配置免密登录
  • 【13】STM32·HAL库-正点原子SYSTEM文件夹 | SysTick工作原理、寄存器介绍 | printf函数使用、重定向
  • ansible配置文件案例
  • 【大数据】Flink 从入门到实践(一):初步介绍
  • 大数据课程F4——HIve的其他操作
  • React Native详解和代码实例
  • CAD随机球体颗粒过渡区3D插件
  • 【项目 进程12】2.25 sigprocmask函数使用 2.26sigaction信号捕捉函数 2.27SIGCHILD信号
  • 【无标题】面试题 02.07. 链表相交
  • Zotero ubuntu2023安装 关联 ubuntu文献翻译
  • Stable Diffusion教程(7) - PS安装AI绘画插件教程
  • 如何学技术
  • 【云存储】使用OSS快速搭建个人网盘教程(阿里云)
  • 微信小程序iconfont真机渲染失败
  • 万界星空/推出生产制造执行MES系统/开源MES/免费下载
  • 【VxWorks】Vxworks、QNX、Xenomai、Intime、Sylixos、Ucos等实时操作系统的性能特点
  • 17、YML配置文件及让springboot启动时加载我们自定义的yml配置文件的几种方式
  • 18、springboot默认的配置文件及导入额外配置文件
  • Conda换源(Linux)
  • 【C语言学习】数据类型转换
  • 深入了解PostgreSQL:高级查询和性能优化技巧
  • 【C#学习笔记】值类型(1)
  • 二十三种设计模式第二十二篇--中介者模式
  • 小研究 - 微服务系统服务依赖发现技术综述(二)