当前位置: 首页 > news >正文

PyTorch自动求导

1. 计算图构建过程

x = torch.ones(5, requires_grad=True)  # 定义叶子节点,启用梯度跟踪
y = x + 2                             # 加法操作,生成中间节点 y
z = y * y * 3                         # 平方与乘法操作,生成中间节点 z
out = z.mean()                        # 标量输出(损失函数)
  • 动态计算图构建​:

    每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):

    x → (Add) → y → (Pow + Mul) → z → (Mean) → out

    节点类型:

    • 叶子节点​:用户直接创建的 xx.is_leaf = True)。
    • 非叶子节点​:y, z, out由运算生成(grad_fn属性记录操作类型)
  • 梯度跟踪机制​:

    设置 requires_grad=True后,所有依赖 x的中间节点自动继承此属性(如 y.requires_grad=True


2. 反向传播与梯度计算

out.backward()  # 触发反向传播
  • 反向传播流程​:
    1. 1.out开始反向遍历​:因 out是标量(shape=()),无需额外指定梯度权重
    2. 2.

      链式法则应用​:

      • out = z.mean()→ ∂zi​∂out​=51​(z有 5 个元素)。
      • z = 3y^2→ ∂yi​∂zi​​=6yi​。
      • y = x + 2→ ∂xi​∂yi​​=1
    3. 3.​梯度计算​:

      ∂xi​∂out​=∂zi​∂out​⋅∂yi​∂zi​​⋅∂xi​∂yi​​=51​⋅6yi​⋅1=56​(xi​+2)。

  • •​梯度存储​:

    结果存入叶子节点 x.grad,非叶子节点(如 y, z)的梯度默认不保留以节省内存


3. 梯度结果验证

print(f"x 的梯度: {x.grad}")  # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
  • •​数学推导​:

    代入 xi​=1:

    ∂xi​∂out​=56​(1+2)=518​=3.6。

    与代码输出一致,验证了链式法则的正确性


4. 梯度累积问题

  • •​默认行为​:

    backward()计算的梯度会累加x.grad。若多次执行 out.backward(),梯度将叠加(如运行两次后 x.grad变为 [7.2, 7.2, ...]

  • 解决方案​:

    训练循环中需在每次反向传播前调用 x.grad.zero_()optimizer.zero_grad()清零梯度


关键概念总结

概念

说明

代码示例

叶子节点

用户直接创建的张量,梯度计算终点

x = torch.ones(5, requires_grad=True)

动态计算图

运行时动态构建的操作依赖图,反向传播后自动释放

y = x + 2生成 AddBackward节点

非标量反向传播

out非标量(如向量),需传入 gradient参数作为权重矩阵

z.backward(torch.ones_like(z))

梯度保留

设置 retain_graph=True可保留计算图,支持多次反向传播

out.backward(retain_graph=True)


提示​:理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)查看操作类型(如输出 <AddBackward0>),或使用 torchviz库可视化计算图

http://www.lryc.cn/news/626081.html

相关文章:

  • 开源 C++ QT Widget 开发(一)工程文件结构
  • vfs_statfs使用 查看当前文件系统一些信息情况
  • RocketMq消费者动态订阅topic
  • 聚合链路与软件网桥的原理及配置方法
  • 【LeetCode 热题 100】279. 完全平方数——(解法一)记忆化搜索
  • JVM原生的assert关键字
  • 手写C++ string类实现详解
  • 使用redis读写锁实现抢券功能
  • 怎样平衡NLP技术发展中数据质量和隐私保护的关系?
  • JVM 面试精选 20 题(续)
  • JVM对象创建和内存分配
  • SpringAI接入openAI配置出现的问题全解析
  • 今日行情明日机会——20250819
  • Java开发面试实战:Spring Boot微服务与数据库优化案例分析
  • 星图云开发者平台新功能速递 | 微服务管理器:无缝整合异构服务,释放云原生开发潜能
  • 微服务如何集成swagger3
  • 微服务-08.微服务拆分-拆分商品服务
  • UE5 使用RVT制作地形材质融合
  • idea如何设置tab为4个空格
  • CSS backdrop-filter:给元素背景添加模糊与色调的高级滤镜
  • Day08 Go语言学习
  • Ansible 中的文件包含与导入机制
  • 常见 GC 收集器与适用场景:从吞吐量到亚毫秒停顿的全景指南
  • NestJS 依赖注入方式全解
  • TDengine IDMP 运维指南(3. 使用 Ansible 部署)
  • 【上升跟庄买入】副图/选股指标,动态黄色线由下向上穿越绿色基准线时,发出买入信号
  • day32-进程与线程(5)
  • Ubuntu 下面安装搜狗输入法debug记录
  • Ubuntu一键安装harbor脚本
  • WSL虚拟机(我的是ubuntu20.04)将系统文件转移到E盘