当前位置: 首页 > news >正文

深度学习损失计算

文章目录

  • 深度学习损失计算
    • 1.如何计算当前epoch的损失?
    • 2.为什么要计算样本平均损失,而不是计算批次平均损失?

深度学习损失计算

1.如何计算当前epoch的损失?

深度学习中的损失计算,通常为数据集的平均损失,即每个样本的平均损失值。计算步骤如下:

  • 计算单个批次的损失。每次迭代中,用当前模型预测值和真实值计算损失。假设 _loss 是这次迭代中计算得到的损失。
  • 转换为标量。利用item()方法将其转换为标量值。_loss.item()
  • 乘以批次大小。乘以批次大小的原因是,希望总损失是所有数据点的损失总和,而不是批次平均损失。
  • 累加损失loss += _loss.item() * batch_size 将当前批次的总损失累加到变量 loss 中。这样所有批次遍历结束后,就得到一个epoch的总损失。
  • 计算当前epoch的样本平均损失。通过总损失除以总的数据样本数,来得到平均损失。average_loss = loss/len(dataloader.dataset)【注意:除的是总的数据样本数(len(dataloader.dataset))!不是总的批次数(len(dataloader))!】

示例代码如下:

for epoch in total_epoch:  # epoch迭代total_loss = 0.0  # 初始化总损失for inputs, targets in dataloader:  # batch迭代outputs = model(inputs)  # 获取预测值_loss = criterion(outputs, targets)  # 计算当前批次损失,为批次平均损失batch_size = inputs.size(0)  # 获取批次大小total_loss += _loss.item() * batch_size  # 计算当前批次的总损失# 计算当前epoch的平均损失average_loss = total_loss / len(dataloader.dataset)  

2.为什么要计算样本平均损失,而不是计算批次平均损失?

由于每个批次的大小可能不一样,特别是在数据集的大小不是批次大小的整数倍时,所以使用 len(dataloader) 会导致错误的平均损失计算。

下面用一个简单的例子,解释这两种计算方式的不同:

假设数据集有 105 个样本,每个批次大小为 10,这样会有 11 个批次,其中最后一个批次只有 5 个样本。结合上面的伪代码,假设损失值 _loss.item() 是 1,对于 10 个批次的损失是 10,最后一个批次的损失是 5。那么:

  • t o t a l _ l o s s = ( 1 ∗ 10 ) ∗ 10 + ( 1 ∗ 5 ) ∗ 1 = 105 total\_loss = (1 * 10) * 10 + (1 * 5) * 1 = 105 total_loss=(110)10+(15)1=105
  • l e n ( d a t a l o a d e r . d a t a s e t ) = 105 len(dataloader.dataset) = 105 len(dataloader.dataset)=105
  • l e n ( d a t a l o a d e r ) = 11 len(dataloader) = 11 len(dataloader)=11

计算结果:

  • 样本平均损失计算:average_loss = total_loss / len(dataloader.dataset) 105 / 105 = 1 105/105 = 1 105/105=1
  • 批次平均损失计算:average_loss = total_loss / len(dataloader) 105 / 11 ≈ 9.545 105/11 \approx 9.545 105/119.545

显然,第一种方式是正确的,反映了每个样本的真实平均损失。

😃😃😃

http://www.lryc.cn/news/401284.html

相关文章:

  • 论文翻译:通过云计算对联网多智能体系统进行预测控制
  • Java核心(五)多线程
  • IDEA快速生成项目树形结构图
  • 【CPO-TCN-BiGRU-Attention回归预测】基于冠豪猪算法CPO优化时间卷积双向门控循环单元融合注意力机制
  • 面试高级 Java 工程师:2024 年的见闻与思考
  • 设计模式大白话之装饰者模式
  • 动手学深度学习6.3 填充和步幅-笔记练习(PyTorch)
  • 函数的形状怎么定义?
  • Windows 虚拟机服务器项目部署
  • JDBC(2)基础篇2——增删改查及常见问题
  • JVM知识点梳理
  • 产品经理-一份标准需求文档的8个模块(14)
  • 如何用一个例子向10岁小孩解释高并发实时服务的单线程事件循环架构
  • 如何为帕金森病患者选择合适的步行辅助设备?
  • 【排序算法】1.冒泡排序-C语言实现
  • Unity最新第三方开源插件《Stateful Component》管理中大型项目MonoBehaviour各种序列化字段 ,的高级解决方案
  • Spark SQL----INSERT TABLE
  • socket功能定义和一般模型
  • 如何在linux中给vim编辑器添加插件
  • Web 中POST为什么会发送两次请求
  • C语言经典程序100案例
  • 南京邮电大学统计学课程实验3 用EXCEL进行方差分析 指导
  • 2024-07-13 Unity AI状态机2 —— 项目介绍
  • shell脚本-linux如何在脚本中远程到一台linux机器并执行命令
  • Spring Data Redis + Redis数据缓存学习笔记
  • 在项目中,如何使用springboot+vue+springsecurity+redis缓存+Axios+MySQL数据库+mybatis
  • 微调 Florence-2 - 微软的尖端视觉语言模型
  • 【数据结构】二叉树全攻略,从实现到应用详解
  • 微信小程序加载动画文件
  • [计算机网络] VPN技术