当前位置: 首页 > news >正文

Pytorch训练时报nan

0. 引言

Pytorch训练时在batch=N时loss为nan。经过断点检查发现在batch=N-1时,网络参数非nan,输出非nan,但梯度为nan,导致网络参数已经全部被更新为nan,遇到这种情况应该如何排查,如何避免?由于导致nan的情况较为繁多,本文给出的不是一个个例的解决方案,而是一种通用的抽象解决方案。

1. 排查

最简单的排查的方式就是检查parameter的参数值:

# model
for name, param in model.named_parameters(recurse=True):if not torch.isfinite(param.mean()):print(name)

通过该种方法可以打印出网络参数中数值非有限值的参数所在层。

第二种方法是检查parameter的梯度值,该方法需要retain_graph=True (Pytorch默认不保存图结构以节省GPU内存)

# compute loss
loss.backward(retain_graph=True)
# model
for name, param in model.named_parameters(recurse=True):if not torch.isfinite(param.grad.mean()):print(name)

检查梯度和参数值的方式都是从后往前查(和反向传播的顺序一致),子节点出现问题会导致其根节点必定出现问题,因此优先排查子节点是否是导致nan的原因。

最后提醒一下,如果nan排查成功,别忘了把retain_graph=True给删了,因为这条命令占用额外的GPU内存。

2. 规避

在这里介绍的方法是基于Pytorch 1.13的,Pytorch 2.x的用户也不想要担心,因为本教程中设置的参数在Pytorch 2.x里面已经设为默认参数,完全兼容。

# compute loss
# optimizer, model
clip_grad = 1.0 # maximum value to clip grad_norm
try:nn.utils.clip_grad_norm_(model.parameters(), clip_grad, norm_type=2, error_if_nonfinite=True) # 遇到nonfinite的梯度报错optimizer.step()
except:print("nan detected in grad, skip batch")optimizer.zero_grad()  # 所有梯度置0,保证下一个batch的正常训练continue  # 跳过这个batch的训练

这个代码的思想就是利用clip_grad_norm_自带的梯度检查功能在反向传播前对model的每个参数梯度进行检查,如若出现梯度异常值,则跳过batch(且不会对网络进行梯度更新)。需要的注意的是,optimizer.zero_grad()除了在本代码中出现,应该在主循环里面也另外有一个,但是此处省略了。

http://www.lryc.cn/news/479533.html

相关文章:

  • JavaScript定时器详解:setTimeout与setInterval的使用与注意事项
  • CSS——选择器、PxCook软件、盒子模型
  • Mysql 大表limit查询优化原理实战
  • 在vscode中开发运行uni-app项目
  • 【JavaEE初阶 — 多线程】Thread的常见构造方法&属性
  • ctfshow(316)--XSS漏洞--反射性XSS
  • ubuntu22.04安装conda
  • D58【python 接口自动化学习】- python基础之异常
  • Java项目实战II基于Spring Boot的便利店信息管理系统(开发文档+数据库+源码)
  • Java-日期计算工具类DateCalculator
  • 单片机串口接收状态机STM32
  • ipv6的 fc00(FC00::/7) 和 fec0(FEC0::/10)
  • Chat GPT英文学术写作指令
  • 超详细Pycharm安装汉化教程,Python环境配置和使用指南,Python零基础入门看这个就够了!
  • react-native:解决使用webView后部分场景在安卓10崩溃闪退问题
  • 大数据工具 flume 的安装配置与使用 (详细版)
  • 智慧城市智慧城市项目方案-大数据平台建设技术方案(原件Word)
  • C语言比较两个字符串是否相同
  • 丹摩征文活动|FLUX.1图像生成模型:AI工程师的创新实践
  • ZABBIX API获取监控服务器OS层信息
  • SpringBoot基础系列学习(五):JdbcTemplate 访问数据库
  • JavaEE-多线程初阶(3)
  • 从入门到精通:如何在Vue项目中有效运用el-image-viewer
  • uniapp组件实现省市区三级联动选择
  • 【C++】异常处理机制(对运行时错误的处理)
  • C++ boost steady_timer使用介绍
  • JVM 由多个模块组成,每个模块负责特定的功能
  • ORACLE批量插入更新如何拆分大事务?
  • kafka+zookeeper的搭建
  • Spark中的宽窄依赖