当前位置: 首页 > news >正文

PyTorch - 模型训练损失 (Loss) NaN 问题的解决方案

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/133378367

在模型训练中,如果出现 NaN 的问题,严重影响 Loss 的反传过程,因此,需要加入一些微小值进行处理,避免影响模型的训练结果。

例如,交叉熵损失 sigmoid_cross_entropy,包括对数函数(log) ,当计算 log 值时,当输入为0时,则会导致溢出,因此,需要加入极小值 (例如 1e-8) 约束,避免溢出。

交叉熵公式:

L ( y , y ^ ) = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] L(y, \hat{y}) = -\frac{1}{N} \sum_{i=1}^N [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] L(y,y^)=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

Log 曲线:

log

即:

# 额外增加 eps,可以避免数值溢出
def sigmoid_cross_entropy(logits, labels, eps=1e-8):logits = logits.float()log_p = torch.log(torch.sigmoid(logits)+eps)log_not_p = torch.log(torch.sigmoid(-logits)+eps)loss = -labels * log_p - (1 - labels) * log_not_preturn loss

Sigmoid Cross Entropy 是一种常用的损失函数,用于衡量二分类问题中模型的预测结果和真实标签之间的差异,作用是优化模型的参数,使得模型能够更好地拟合数据,提高分类的准确性。

参考:How to solve the loss become nan because of using torch.log()

http://www.lryc.cn/news/180675.html

相关文章:

  • 8、Nacos服务注册服务端源码分析(七)
  • MySQL使用Xtrabackup在线做主从
  • scala基础入门
  • 【Java-LangChain:面向开发者的提示工程-5】推断
  • 【C++】手撕vector(vector的模拟实现)
  • 智能指针那些事
  • Fiddler抓取手机https包的步骤
  • idea没有maven工具栏解决方法
  • levelDB引擎
  • IM同步服务
  • MySQL 运维常用脚本
  • ABC322刷题记
  • visual studio的安装及scanf报错的解决
  • React生命周期
  • SpringBoot整合RocketMQ笔记
  • 【【萌新的RiscV学习之在写代码之前对于关键路径的分析-11】】
  • A. Sequence with Digits
  • gitlab配置webhook限制提交注释
  • 蓝桥杯Python scratch C++选拔赛stema个人如何报名?
  • Cesium实现动态旋转四棱锥(2023.9.11)
  • 2023最新PS(photoshop)Win+Mac免费下载安装包及教程内置AI绘画-网盘下载
  • 【JAVA】为什么要使用封装以及如何封装
  • 18.示例程序(编码器接口测速)
  • 【超详细】Fastjson 1.2.24 命令执行漏洞复现-JNDI简单实现反弹shell(CVE-2017-18349)
  • 【牛客网】JZ39 数组中出现次数超过一半的数字
  • 【Mysql】Lock wait timeout exceeded; try restarting transaction
  • python生成中金所期权行权价
  • CentOS7.9 安装postgresql
  • qt线程介绍
  • 记一次用dataframe进行数据清理