当前位置: 首页 > news >正文

pytorch多分类问题 CrossEntropyLoss()函数的输入size/shape不一致问题

在使用pytorch实现一个多分类任务的时候,许多多分类任务在训练过程中都会有如下的代码:

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
# output.size : [batch_size, class_num]
# target.size : [batch_size]

许多的初学者会卡在这里,生出这样一个疑问:为什么输入的output.size和target.size是不一样的,这样如何计算损失值呢?因为根据损失函数的设计原理,损失值的计算应该如下图所示:
在这里插入图片描述
output应该和target的size是一一对应,才可以实现损失值的计算。
包括在官网上查看CrossEntropyLoss()函数的例子
在这里插入图片描述
也可以看到,依旧有size不匹配的例子。这是为什么呢?
这是因为在CrossEntropyLoss()函数的内部,会将传入的target转化为独热编码的格式,这样就会使target的size从[batch_size] =》 [batch_size, class_num]了。

故究其原因,就是CrossEntropyLoss()函数的内部会将target转化为独热编码,所以输入的时候直接将[batch_size] 的target(存放的是batch_size个对应类别标签) 输入进去即可。

http://www.lryc.cn/news/94561.html

相关文章:

  • 硬盘或者U盘提示需要格式化的解决办法
  • Clip-Path
  • Matlab绘图系列教程-Matlab 34 种绘图函数示例(下)
  • 【Vue+Django】Training Management Platform Axios并发请求 - 20230703
  • smart Spring:自定义注解、拦截器的使用(更新中...)
  • php导出pdf
  • 【ECMAScript6_2】字符串
  • 37.RocketMQ之Broker消息存储源码分析
  • RabbitMq应用延时消息
  • 【WEB自动化测试】- 浏览器操作方法
  • VSCode设置鼠标滚轮滑动设置字体大小
  • Spring MVC是什么?详解它的组件、请求流程及注解
  • 基于Spring Boot的广告公司业务管理平台设计与实现(Java+spring boot+MySQL)
  • docker 基本命令安装流程
  • 尚硅谷大数据Flink1.17实战教程-笔记02【Flink部署】
  • 【LeetCode每日一题合集】2023.7.3-2023.7.9
  • java企业工程项目管理系统平台源码
  • 软件设计模式与体系结构-设计模式-行为型软件设计模式-访问者模式
  • 【LeetCode】503. 下一个更大元素 II
  • 使用infura创建以太坊网络
  • TCP/IP协议是什么?
  • python图像处理实战(三)—图像几何变换
  • 学习vue2笔记
  • 【SQL】查找多个表中相同的字段
  • “未来之光:揭秘创新科技下的挂灯魅力“
  • Spring boot MongoDB实现自增序列
  • MyBatis查询数据库【秘籍宝典】
  • 目标检测舰船数据集整合
  • 第一章 Android 基础--开发环境搭建
  • 【LeetCode周赛】2022上半年题目精选集——二分