当前位置: 首页 > news >正文

PyTorch Lightning教程五:Debug调试

如果遇到了这样一个问题,当一次训练模型花了好几天,结果突然在验证或测试的时候崩掉了,这个时候其实是很奔溃的,主要还是由于没有提前知道哪些时候会出现什么问题,本节会引入Lightning的Debug方案

1.fast_dev_run参数

Trainer中的fast_dev_run参数通过你的训练器运行5批训练、验证、测试和预测数据,看看是否有任何错误,如下

Trainer(fast_dev_run=True)

如果fast_dev_run设置为7,则表示训练7个batch每次

⚠️注意:这个参数将禁用tuner、checkpoint callbacks, early stopping callbacks, loggers 和 logger callbacks(如 LearningRateMonitor和DeviceStatsMonitor)。

2.减少epoch长度

有时,我们只需要使用训练、val、测试或预测数据的一小部分(或一组批次),来看看是否有错误。例如,可以使用20%的训练集和1%的验证集。

在像Imagenet这样的大型数据集上,这可以帮助我们更快地调试或测试一些东西,而不是等待一个完整的epoch。

# 只使用10%的训练数据和1%的验证数据
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)# 使用10批次训练和5批次验证
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)

3.运行一次完整性验证

Lightning在训练开始时有2个验证的步骤。这避免了在验证循环中陷入冗长的训练循环。

trainer = Trainer(num_sanity_val_steps=2)

4.打印模型相关参数

每当调用.fit()函数时,训练器将打印LightningModule的权重摘要,例如

trainer.fit(...)

则出现

  | Name  | Type        | Params
----------------------------------
0 | net   | Sequential  | 132 K
1 | net.0 | Linear      | 131 K
2 | net.1 | BatchNorm1d | 1.0 K

需要将子模块添加到摘要中,添加一个ModelSummary,如下操作

# 方法1.引入回调函数
from lightning.pytorch.callbacks import ModelSummary
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # 回调函数ModelSummary
trainer.fit()
# 注:如果不打印,则可以运行 Trainer(enable_model_summary=False)# 当然也可以下面这样子,直接打印
# 方法2.不调用fit
model = LitModel()
summary = ModelSummary(model, max_depth=-1)
print(summary)

4.所有中间层的输入输出

另一个调试工具是通过在LightningModule中设置example_input_array属性来显示所有层的中间输入和输出大小。

class LitModel(LightningModule):def __init__(self, *args, **kwargs):self.example_input_array = torch.Tensor(32, 1, 28, 28)

当执行.fit()时,会打印如下

  | Name  | Type        | Params | In sizes  | Out sizes
--------------------------------------------------------------
0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | [10, 512] | [10, 512]
http://www.lryc.cn/news/111891.html

相关文章:

  • 末流211无科研保研经验分享
  • 日期选择器多选换行
  • NodeJS原型链污染ctfshow_nodejs
  • 18. SpringBoot 如何在 POM 中引入本地 JAR 包
  • vue2-$nextTick有什么作用?
  • python自动收集粘贴板
  • Vue3_语法糖—— <script setup>以及unplugin-auto-import自动引入插件
  • 2023-08-06力扣做过了的题
  • 进程间通信之管道
  • f12 CSS网页调试_css样式被划了黑线怎么办
  • vue-制作自动滚动效果
  • [国产MCU]-BL602-开发实例-DMA数据传输
  • Redis压缩列表
  • 【SA8295P 源码分析】62 - Android GVM Kernel 内核 make bootimage 过程分析
  • 机器学习——SMO算法推导与实践
  • mac的终端通过code .指令快速启动vscode
  • 前端系统使用iframe下载文件
  • RabbitMQ - 简单案例
  • 《吐血整理》高级系列教程-吃透Fiddler抓包教程(30)-Fiddler如何抓Android7.0以上的Https包-番外篇
  • 服务器被攻击了怎么办?
  • P1156 垃圾陷阱(背包变形)
  • [Docker实现测试部署CI/CD----构建成功后钉钉告警(7)]
  • UE5 半透明覆层材质
  • 在Raspberry Pi 4上安装Ubuntu 20.04 + ROS noetic(不带显示器)
  • CommStudio for .NET Crack
  • 蓝桥杯上岸考点清单 (冲刺版)!!!
  • AI一键生成短视频
  • 基于MATLAB长时间序列遥感数据分析(以MODIS数据处理为例)
  • postgresql之内存池-AllocsetContext
  • QT 使用单例模式