当前位置: 首页 > news >正文

lightning的hook顺序

结果

setup: 训练循环开始前设置数据加载器和模型。

configure_optimizers: 设置优化器和学习率调度器。

on_fit_start: 训练过程开始。

on_train_start: 训练开始。

on_train_epoch_start: 每个训练周期开始。

on_train_batch_start: 每个训练批次开始。

on_before_backward: 反向传播之前。

on_after_backward: 反向传播之后。

on_before_zero_grad: 清空梯度之前。

on_after_zero_grad: 清空梯度之后。

on_before_optimizer_step: 优化器步骤之前。

on_train_batch_end: 每个训练批次结束。

on_train_epoch_end: 每个训练周期结束。

on_train_end: 训练结束。

on_fit_end: 训练过程结束。

测试代码

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.callbacks import Callback# 定义一个简单的线性回归模型
class LinearRegression(LightningModule):def __init__(self):super().__init__()self.linear = torch.nn.Linear(1, 1)def forward(self, x):return self.linear(x)def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = torch.nn.functional.mse_loss(y_hat, y)return lossdef on_after_backward(self, *args, **kwargs):print("After backward is called!", args, kwargs)return super().on_after_backward(*args, **kwargs)def on_before_zero_grad(self, *args, **kwargs):print("Before zero grad is called!", args, kwargs)return super().on_before_zero_grad(*args, **kwargs)def on_after_zero_grad(self, *args, **kwargs):print("After zero grad is called!", args, kwargs)return super().on_after_zero_grad(*args, **kwargs)def on_before_backward(self, *args, **kwargs):print("Before backward is called!", args, kwargs)return super().on_before_backward(*args, **kwargs)def on_before_optimizer_step(self, *args, **kwargs):print("Before optimizer step is called!", args, kwargs)return super().on_before_optimizer_step(*args, **kwargs)def on_after_optimizer_step(self, *args, **kwargs):print("After optimizer step is called!", args, kwargs)return super().on_after_optimizer_step(*args, **kwargs)def on_fit_start(self, *args, **kwargs):print("Fit is starting!", args, kwargs)return super().on_fit_start(*args, **kwargs)def on_fit_end(self, *args, **kwargs):print("Fit is ending!", args, kwargs)return super().on_fit_end(*args, **kwargs)def setup(self, *args, **kwargs):print("Setup is called!", args, kwargs)return super().setup(*args, **kwargs)def configure_optimizers(self, *args, **kwargs):print("Configure Optimizers is called!", args, kwargs)return super().configure_optimizers(*args, **kwargs)def on_train_start(self, *args, **kwargs):print("Training is starting!", args, kwargs)return super().on_train_start(*args, **kwargs)def on_train_end(self, *args, **kwargs):print("Training is ending!", args, kwargs)return super().on_train_end(*args, **kwargs)def on_train_batch_start(self, *args, **kwargs):print(f"Training batch is starting!", args, kwargs)return super().on_train_batch_start(*args, **kwargs)def on_train_batch_end(self, *args, **kwargs):print(f"Training batch is ending!", args, kwargs)return super().on_train_batch_end(*args, **kwargs)def on_train_epoch_start(self, *args, **kwargs):print(f"Training epoch is starting!", args, kwargs)return super().on_train_epoch_start(*args, **kwargs)def on_train_epoch_end(self, *args, **kwargs):print(f"Training epoch is ending!", args, kwargs)return super().on_train_epoch_end(*args, **kwargs)# 创建数据集
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float)
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=2)# 创建模型和训练器
model = LinearRegression()
trainer = Trainer(max_epochs=2)# 开始训练
trainer.fit(model, train_loader)
http://www.lryc.cn/news/356919.html

相关文章:

  • 【ARFoundation自学03】AR Point Cloud 点云(参考点标记)功能详解
  • x264 码率控制中实现 VBV 算法源码分析
  • 宝兰德入选“鑫智奖·2024金融数据智能运维创新优秀解决方案”榜单
  • Unity3D雨雪粒子特效(Particle System)
  • 记录使用自定义编辑器做试题识别功能
  • MySQL索引和视图
  • Java单元测试Mock的用法,关于接口测试的用例
  • 《心理学报》文本分析技术最新进展总结盘点
  • json格式文件备份redis数据库 工具
  • JAVA系列:NIO
  • 偏微分方程算法之抛物型方程差分格式编程示例二
  • linux 查看 线程名, 线程数
  • python class __getattr__ 与 __getattribute__ 的区别
  • [ C++ ] 类和对象( 下 )
  • 这么多不同接口的固态硬盘,你选对了嘛!
  • 使用IDEA远程debug调试
  • 开源自定义表单系统源码 一键生成表单工具 可自由DIY表单模型+二开
  • 【java10】集合中新增copyof创建只读集合
  • python小甲鱼作业001-3讲
  • 做电商,错过了2020年的抖音!那2024一定要选择视频号小店!
  • 赛氪网与武汉外语外事职业学院签署校企合作,共创职业教育新篇章
  • 如何在文档中有效添加网格:技巧与实例
  • 设计模式10——装饰模式
  • 如果返回的json 中有 ‘///’ 转换
  • JAVA学习-练习试用Java实现“多线程问题”
  • SQOOP详细讲解
  • 【Unity入门】认识Unity编辑器
  • Spring控制重复请求
  • AWS安全性身份和合规性之Key Management Service(KMS)
  • esp32 固件备份 固件恢复