当前位置: 首页 > news >正文

PyTorch基本使用-自动微分模块

学习目的:掌握自动微分模块的使用

训练神经网络时,最常用的算法就是反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。为了计算这些梯度,PyTorch 内置了名为 torch.autograd的微分引擎。它支持任意计算图的自动梯度计算:
在这里插入图片描述

接下来我们使用这个结构进行自动微分模块的介绍。我们使用 backward 方法、grad 属性来实现梯度的计算和访问。

  • 当X为标量时梯度的计算

    import torch
    # 1. 当X为标量时梯度的计算
    def test01():x = torch.tensor(5)# 目标值y = torch.tensor(0.)# 设置要更新的权重和偏置的初始值w = torch.tensor(1.0,requires_grad=True,dtype=torch.float32)b = torch.tensor(3.0,requires_grad=True,dtype=torch.float32)#设置网络的输出值z = x*w + b #矩阵乘法# 设置损失函数,并进行损失计算loss = torch.nn.MSELoss()loss = loss(z,y)# 自动微分loss.backward()# 打印w,b变量的梯度# backward 函数计算的梯度值会存储在张量的grad 变量中print('W的梯度:',w.grad)print('B的梯度:',b.grad)test01()
    

    输出结果:

    W的梯度: tensor(80.)
    B的梯度: tensor(16.)
    
  • 当X为多维张量时梯度计算

    import torch
    def test02():# 输入张量 2*5x = torch.ones(2,5)# 目标张量 2*3y = torch.zeros(2,3)# 设置要更新的权重和偏置的初始值w = torch.randn(5,3,requires_grad=True)b = torch.randn(3,requires_grad=True)#设置网络的输出值z = torch.matmul(x,w)+ b #矩阵乘法# 设置损失函数,并进行损失计算loss = torch.nn.MSELoss()loss = loss(z,y)# 自动微分loss.backward()# 打印w,b变量的梯度# backward 函数计算的梯度值会存储在张量的grad 变量中print('W的梯度:',w.grad)print('B的梯度:',b.grad)test02()
    

    输出结果:

    W的梯度: tensor([[-1.7502,  0.8537,  0.6175],[-1.7502,  0.8537,  0.6175],[-1.7502,  0.8537,  0.6175],[-1.7502,  0.8537,  0.6175],[-1.7502,  0.8537,  0.6175]])
    B的梯度: tensor([-1.7502,  0.8537,  0.6175])
    
http://www.lryc.cn/news/503362.html

相关文章:

  • libevent-Reactor设计模式【1】
  • 奇奇怪怪的错误-Tag和space不兼容
  • 29.攻防世界ics-06
  • 强化学习路径规划:基于SARSA算法的移动机器人路径规划,可以更改地图大小及起始点,可以自定义障碍物,MATLAB代码
  • 【MFC】如何读取rtf文件并进行展示
  • Vulhub:Log4j[漏洞复现]
  • 面向预测性维护的TinyML技术栈全面综述
  • 沈阳理工大学《2024年811自动控制原理真题》 (完整版)
  • 用前端html如何实现2024烟花效果
  • Redis应用-在用户数据里的应用
  • C++ 中面向对象编程如实现数据隐藏
  • JavaEE 【知识改变命运】04 多线程(3)
  • gz中生成模型
  • 前端(Axios和Promis)
  • AI Agent:重塑业务流程自动化的未来力量(2/30)
  • 前端页面导出word
  • 【考前预习】1.计算机网络概述
  • ubuntu20.04复现 Leg-KILO
  • Ensembl数据库下载参考基因组(常见模式植物)bioinfomatics 工具37
  • 简单介绍web开发和HTML CSS_web网站开发流程
  • Docker 中使用 PHP 通过 Canal 同步 Mysql 数据到 ElasticSearch
  • 数据结构之五:排序
  • 科研绘图系列:R语言绘制热图和散点图以及箱线图(pheatmap, scatterplot boxplot)
  • 基于 webRTC Vue 的局域网 文件传输工具
  • LeetCode 718. 最长重复子数组 java题解
  • 算法知识-15-深搜
  • 区块链dapp 开发详解(VUE3.0)
  • Plugin [id: ‘flutter‘] was not found in any of the following sources解决方法
  • 专升本-高数 1
  • 【考前预习】3.计算机网络—数据链路层