当前位置: 首页 > news >正文

optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用

optimizer.zero_grad,loss.backward,optimizer.step

  • 用法介绍
  • optimizer.zero_grad():
  • loss.backward():
  • optimizer.step():

用法介绍

这三个函数的作用是将梯度归零(optimizer.zero_grad()),然后反向传播计算得到每个参数的梯度值(loss.backward()),最后通过梯度下降执行一步参数更新(optimizer.step())。

简单的说就是进来一个batch的数据,先将梯度归零,计算一次梯度,更新一次网络。

model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)for epoch in range(1, epochs):for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.to(device=device)labels= labels.to(device=device)# forwardoutput= model(inputs)loss = criterion(output, labels)# backwardoptimizer.zero_grad()loss.backward()# gardient descent or adam stepoptimizer.step()

另外一种:将optimizer.zero_grad() 放在 optimizer.step() 后面,即梯度累加。

  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. loss.backward() 反向传播,计算当前梯度;
  3. 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
  4. 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize ‘变相’ 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。
参考链接:https://blog.csdn.net/weixin_36670529/article/details/108630740

optimizer.zero_grad():

param_groups:Optimizer类在实例化时会在构造函数中创建一个param_groups列表,列表中有num_groups个长度为6的param_group字典(num_groups取决于你定义optimizer时传入了几组参数),每个param_group包含了 [‘params’, ‘lr’, ‘momentum’, ‘dampening’, ‘weight_decay’, ‘nesterov’] 这6组键值对。

param_group[‘params’]:由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。

def zero_grad(self):r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()p.grad.zero_()

optimizer.zero_grad()函数会遍历模型的所有参数,通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。

因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。

loss.backward():

PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。

具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。

更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。

如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。

optimizer.step():

以SGD为例,torch.optim.SGD().step()源码如下:

def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss

step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。

注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。

http://www.lryc.cn/news/5876.html

相关文章:

  • 融资、量产和一栈式布局,这家Tier 1如此备战高阶智驾决赛圈
  • centos7.8安装oralce11g
  • 【蓝桥杯集训·每日一题】AcWing 3956. 截断数组
  • 万丈高楼平地起:Linux常用命令
  • Linux(Linux的连接使用)
  • Unity中画2D图表(2)——用XChart包绘制散点分布图 + 一条直线方程
  • Go 排序包 sort
  • Java Email 发HTML邮件工具 采用 freemarker模板引擎渲染
  • CNI 网络流量分析(六)Calico 介绍与原理(二)
  • 短视频标题的几种类型和闭坑注意事项
  • 操作系统——1.操作系统的概念、定义和目标
  • 【html弹框拖拽和div拖拽功能】原生html页面引入vue语法后通过自定义指令简单实现div和弹框拖拽功能
  • 2023新华为OD机试题 - 计算网络信号(JavaScript) | 刷完必过
  • 27.边缘系统的架构
  • 机器学习强基计划8-1:图解主成分分析PCA算法(附Python实现)
  • Hudi-集成Spark之spark-shell 方式
  • Python爬虫:从js逆向了解西瓜视频的下载链接的生成
  • Numpy-如何对数组进行切割
  • Python之字符串精讲(下)
  • Python图像卡通化animegan2-pytorch实例演示
  • 谢希仁版《计算机网络》期末总复习【完结】
  • 问:React的useState和setState到底是同步还是异步呢?
  • 深度理解机器学习16-门控循环单元
  • Python中Generators教程
  • 数据结构与算法基础-学习-10-线性表之栈的清理、销毁、压栈、弹栈
  • Leetcode 每日一题 1234. 替换子串得到平衡字符串
  • 【MYSQL中级篇】数据库数据查询学习
  • 华为OD机试真题JAVA实现【火星文计算】真题+解题思路+代码(20222023)
  • Linux基础知识
  • Linux 游戏性能谁的 更优秀X.Org还是Wayland!