当前位置: 首页 > news >正文

pytorch 自动微分

自动微分

  • 1. 基础概念
    • 1.1. **张量**
    • 1.2. **计算图**:
    • 1.3. **反向传播**
    • 1.4. **梯度**
  • 2. 计算梯度
      • 2.1 标量梯度计算
      • 2.2 向量梯度计算
      • 2.3 多标量梯度计算
      • 2.4 多向量梯度计算
  • 3. 梯度上下文控制
      • 3.1 控制梯度计算(with torch.no_grad())
      • 3.2 累计梯度
      • 3.3 梯度清零(torch.zero_())

自动微分模块torch.autograd负责 自动计算张量操作的梯度,具有自动求导功能。自动微分模块是构成神经网络训练的必要模块,可以实现网络权重参数的更新,使得反向传播算法的实现变得简单而高效。

1. 基础概念

1.1. 张量

Torch中一切皆为张量,属性requires_grad决定是否对其进行梯度计算。默认是 False,如需计算梯度则设置为True。

1.2. 计算图

torch.autograd通过创建一个动态计算图来跟踪张量的操作,每个张量是计算图中的一个节点,节点之间的操作构成图的边。

在 PyTorch 中,当张量的 requires_grad=True 时,PyTorch 会自动跟踪与该张量相关的所有操作,并构建计算图。**每个操作都会生成一个新的张量,并记录其依赖关系。**当设置为 True 时,表示该张量在计算图中需要参与梯度计算,即在反向传播(Backpropagation)过程中会自动计算其梯度;当设置为 False 时,不会计算梯度。

  1. 例如:
    z=x∗yloss=z.sum()z = x * y\\loss = z.sum() z=xyloss=z.sum()
    在上述代码中,x 和 y 是输入张量,即叶子节点,z 是中间结果,loss 是最终输出。每一步操作都会记录依赖关系:

z = x * y:z 依赖于 x 和 y。

loss = z.sum():loss 依赖于 z。

这些依赖关系形成了一个动态计算图,如下所示:

	  x       y\     /\   /\ /z||vloss

叶子节点

在 PyTorch 的自动微分机制中,叶子节点(leaf node) 是计算图中

  • 由用户直接创建的张量,并且它的 requires_grad=True。
  • 这些张量是计算图的起始点,通常作为模型参数或输入变量。

特征:

  • 1.没有由其他张量通过操作生成。
  • 2.如果参与了计算,其梯度会存储在 leaf_tensor.grad 中。
    1. 默认情况下,叶子节点的梯度不会自动清零,需要显式调用 optimizer.zero_grad() 或 x.grad.zero_() 清除。

如何判断一个张量是否是叶子节点?

通过 tensor.is_leaf 属性,可以判断一个张量是否是叶子节点。

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # 叶子节点
y = x ** 2  # 非叶子节点(通过计算生成)
z = y.sum()print(x.is_leaf)  # True
print(y.is_leaf)  # False
print(z.is_leaf)  # False

叶子节点与非叶子节点的区别

特性叶子节点非叶子节点
创建方式用户直接创建的张量通过其他张量的运算生成
is_leaf 属性TrueFalse
梯度存储梯度存储在 .grad 属性中梯度不会存储在 .grad,只能通过反向传播传递
是否参与计算图是计算图的起点是计算图的中间或终点
删除条件默认不会被删除在反向传播后,默认被释放(除非 retain_graph=True)

detach():张量 x 从计算图中分离出来,返回一个新的张量,与 x 共享数据,但不包含计算图(即不会追踪梯度)。

特点

  • 返回的张量是一个新的张量,与原始张量共享数据。
  • 对 x.detach() 的操作不会影响原始张量的梯度计算。
  • 推荐使用 detach(),因为它更安全,且在未来版本的 PyTorch 中可能会取代 data。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.detach()  # y 是一个新张量,不追踪梯度y += 1  # 修改 y 不会影响 x 的梯度计算
print(x)  # tensor([1., 2., 3.], requires_grad=True)
print(y)  # tensor([2., 3., 4.])

1.3. 反向传播

使用tensor.backward()方法执行反向传播,从而计算张量的梯度。这个过程会自动计算每个张量对损失函数的梯度。例如:调用 loss.backward() 从输出节点 loss 开始,沿着计算图反向传播,计算每个节点的梯度。

1.4. 梯度

计算得到的梯度通过tensor.grad访问,这些梯度用于优化模型参数,以最小化损失函数。

2. 计算梯度

使用tensor.backward()方法执行反向传播,从而计算张量的梯度

2.1 标量梯度计算

参考代码如下:

  x=torch.tensor([1.0],requires_grad=True)y=x**2y.backward()print(x.grad)#tensor([2.])

2.2 向量梯度计算

案例:

 x=torch.tensor([1.0,2.0,3.0],requires_grad=True)y=x**2# y.backward(torch.tensor([1,1,1]))# print(x.grad)#tensor([2., 4., 6.])# 2.利用中间量将将y变为标量z=y.sum()z.backward()print(x.grad)#tensor([2., 4., 6.])

错误预警:RuntimeError: grad can be implicitly created only for scalar outputs
注意:

  • 由于 y 是一个向量,我们需要提供一个与 y 形状相同的向量作为 backward() 的参数,这个参数通常被称为 梯度张量(gradient tensor),它表示 y 中每个元素的梯度。上述代码(torch.tensor([1,1,1]))就是梯度张量

2.3 多标量梯度计算

参考代码如下

  '''多标量梯度计算'''x=torch.tensor([1.0],requires_grad=True)y=torch.tensor([2.0],requires_grad=True)z=x**2+y*4+7z.backward()print(x.grad,y.grad)'''tensor([2.]) tensor([4.])'''

2.4 多向量梯度计算

代码参考如下:

 '''多向量梯度计算'''x=torch.tensor([1.0,2.0,3.0],requires_grad=True)y=torch.tensor([2.0,3.0,4.0],requires_grad=True)z = x ** 2 + y * 4 + 7loss=z.sum()loss.backward()print(x.grad,y.grad,z.grad) #tensor([2., 4., 6.]) tensor([4., 4., 4.]) None

注意:

  • 不是叶子结点(即不是直接创建的张量)最后梯度会被清除

3. 梯度上下文控制

梯度计算的上下文控制和设置对于管理计算图、内存消耗、以及计算效率至关重要。下面我们学习下Torch中与梯度计算相关的一些主要设置方式。

3.1 控制梯度计算(with torch.no_grad())

梯度计算是有性能开销的,有些时候我们只是简单的运算,并不需要梯度

import torchdef test001():x = torch.tensor(10.5, requires_grad=True)print(x.requires_grad)  # True# 1. 默认y的requires_grad=Truey = x**2 + 2 * x + 3print(y.requires_grad)  # True# 2. 如果不需要y计算梯度-with进行上下文管理with torch.no_grad():y = x**2 + 2 * x + 3print(y.requires_grad)  # False# 3. 如果不需要y计算梯度-使用装饰器@torch.no_grad()def y_fn(x):return x**2 + 2 * x + 3y = y_fn(x)print(y.requires_grad)  # False# 4. 如果不需要y计算梯度-全局设置,需要谨慎torch.set_grad_enabled(False)y = x**2 + 2 * x + 3print(y.requires_grad)  # Falseif __name__ == "__main__":test001()

3.2 累计梯度

默认情况下,当我们重复对一个自变量进行梯度计算时,梯度是累加的

import torchdef test002():# 1. 创建张量:必须为浮点类型x = torch.tensor([1.0, 2.0, 5.3], requires_grad=True)# 2. 累计梯度:每次计算都会累计梯度for i in range(3):y = x**2 + 2 * x + 7z = y.mean()z.backward()print(x.grad)if __name__ == "__main__":test002()

输出结果:

tensor([1.3333, 2.0000, 4.2000])
tensor([2.6667, 4.0000, 8.4000])
tensor([ 4.0000,  6.0000, 12.6000])

思考:如果把 y = x**2 + 2 * x + 7放在循环外,会是什么结果?

会报错:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

PyTorch 的自动求导机制在调用 backward() 时,会计算梯度并将中间结果存储在计算图中。默认情况下,这些中间结果在第一次调用 backward() 后会被释放,以节省内存。如果再次调用 backward(),由于中间结果已经被释放,就会抛出这个错误。

3.3 梯度清零(torch.zero_())

大多数情况下是不需要梯度累加的,奇葩的事情还是需要解决的

    '''梯度清零'''x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)y = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)# no_grad():使z不参与梯度计算for i in range(3):z = x ** 2 + y * 4 + 7loss = z.sum()if x.grad is not None and y.grad is not None:x.grad.zero_()y.grad.zero_()loss.backward()print(x.grad,y.grad)'''没清零前tensor([2., 4., 6.]) tensor([4., 4., 4.])tensor([ 4.,  8., 12.]) tensor([8., 8., 8.])tensor([ 6., 12., 18.]) tensor([12., 12., 12.])清零后tensor([2., 4., 6.]) tensor([4., 4., 4.])
tensor([2., 4., 6.]) tensor([4., 4., 4.])
tensor([2., 4., 6.]) tensor([4., 4., 4.])'''

注意:

  • 叶子结点的梯度默认是会进行累加的,如果迭代几次之后他们的每次梯度都会保留而后再与新梯度相加
  • 通过tensor.grad.zero_()方法使每次梯度清零,就不会影响下一次迭代后的梯度
http://www.lryc.cn/news/583328.html

相关文章:

  • Git 详解:从概念,常用命令,版本回退到工作流
  • sqlplus表结构查询
  • 3.常⽤控件
  • 跨平台ROS2视觉数据流:服务器运行IsaacSim+Foxglove本地可视化全攻略
  • 【动手学深度学习】4.9. 环境和分布偏移
  • MyBatis之数据操作增删改查基础全解
  • tinyxml2 开源库与 VS2010 结合使用
  • MySQL8.0基于GTID的组复制分布式集群的环境部署
  • 如何通过配置gitee实现Claude Code的版本管理
  • SpringBoot校园疫情防控系统源码
  • Flink1.20.1集成Paimon遇到的问题
  • stm32Cubmax的配置
  • 微信小程序91~100
  • Pycharm 报错 Environment location directory is not empty 如何解决
  • 基于Spring Boot+Vue的巴彦淖尔旅游网站(AI问答、腾讯地图API、WebSocket及时通讯、支付宝沙盒支付)
  • Ragas的Prompt Object
  • NHibernate案例
  • SAP ERP与Oracle EBS对比,两个ERP系统有什么区别?
  • aichat-core简化 LLM 与 MCP 集成的前端核心库(TypeScript)
  • C#项目 在Vue/React前端项目中 使用使用wkeWebBrowser引用并且内部使用iframe网页外链 页面部分白屏
  • Spring IoC 如何实现条件化装配 Bean?
  • HUAWEI HiCar6.0的新变化
  • 一条Redis命令是如何执行的?
  • C++随机打乱函数:简化源码与原理深度剖析
  • 从零开始学前端html篇2
  • 微信小程序控制空调之微信小程序篇
  • 双esp8266-01s间TCP通讯
  • 图像硬解码和软解码
  • RAM带宽计算及分析
  • 区块链系统开发技术应用构建可信数字生态链