当前位置: 首页 > news >正文

Pytorch(5)-----梯度计算

一、问题

    如何使用Pytorch计算样本张量的基本梯度呢?考虑一个样本数据集,且有两个展示变量,在给定初始权重的基础上,如何在每次迭代中计算梯度呢?

二、如何运行

    假设有x_data 和 y_data 列表,计算两个列表需要计算损失函数,一个forward通道以及一个循环中的训练。

    forward函数计算权重矩阵和输入张量的乘积。

from torch import FloatTensor
from torch.autograd import Variable  # 引入Variable方法是为了计算变量的梯度
a = Variable(FloatTensor([5]))
weights = [Variable(FloatTensor([i]), requires_grad=True) for i in (12, 53, 91, 73)]w1, w2, w3, w4 = weights  #权重赋值
b = w1 * a
c = w2 * a
d = w3 * b + w4 * c
Loss = (10 - d)
Loss.backward() #从loss 开始反向传播for index, weight in enumerate(weights, start=1):gradient, *_ = weight.grad.data  #取出梯度print(f"Gradient of w{index} w.r.t to Loss: {gradient}")Gradient of w1 w.r.t to Loss: -455.0
Gradient of w2 w.r.t to Loss: -365.0
Gradient of w3 w.r.t to Loss: -60.0
Gradient of w4 w.r.t to Loss: -265.0# 使用forward
def forward(x):return x * w  #forwar过程import torch
from torch.autograd import Variable
x_data = [11.0, 22.0, 33.0]
y_data = [21.0, 14.0, 64.0]w = Variable(torch.Tensor([1.0]), requires_grad=True) # 初始化为任意值;# 训练前打印
print("predict (before training)", 4, forward(4).data[0])
# 定义损失函数
def loss(x, y):y_pred = forward(x)return (y_pred - y) * (y_pred - y)
#运行训练循环
for epoch in range(10):for x_val, y_val in zip(x_data, y_data):l = loss(x_val, y_val)l.backward()print("\tgrad: ", x_val, y_val, w.grad.data[0])w.data = w.data - 0.01 * w.grad.data# 训练后,人工设置梯度为0,否则梯度会累加;w.grad.data.zero_()print("progress:", epoch, l.data[0])#结果
grad: 11.0 21.0 tensor(-220.)
grad: 22.0 14.0 tensor(2481.6001)
grad: 33.0 64.0 tensor(-51303.6484)progress: 0 tensor(604238.8125)
progress: 1 …………………………………………
………………………………………………………………………………
#训练后的预测 权重已更新
print("predict (after training)", 4, forward(4).data[0])#结果
predict (after training) 4 tensor(-9.2687e+24)

   下面的程序展示了如何用Variable 变量从损失函数计算梯度:

a = Variable(FloatTensor([5]))
weights = [Variable(FloatTensor([i]), requires_grad=True) for i in (12, 53, 91, 73)]
w1, w2, w3, w4 = weights
b = w1 * a
c = w2 * a
d = w3 * b + w4 * c
Loss = (10 - d)
Loss.backward()

http://www.lryc.cn/news/380245.html

相关文章:

  • C#的膨胀之路:创新还是灭亡
  • SpringBoot 过滤器和拦截器的区别
  • 协程执行顺序引发的问题
  • android webview调用js滚动到指定位置
  • WPF 深入理解一、基础知识介绍
  • 腾讯云点播ugc upload | lack signature 问题处理
  • 计算机视觉实验二:基于支持向量机和随机森林的分类(Part one: 编程实现基于支持向量机的人脸识别分类 )
  • 5.什么是C语言
  • DINO-DETR
  • Representation RL:HarmonyDream: Task Harmonization Inside World Models
  • Centos7系统下Docker的安装与配置
  • 无人机校企合作
  • 八爪鱼现金流-028,个人网站访问数据统计分析,解决方案
  • 大厂面试官问我:布隆过滤器有不能扩容和删除的缺陷,有没有可以替代的数据结构呢?【后端八股文二:布隆过滤器八股文合集】
  • PHP米表域名出售管理源码带后台
  • 【开发12年码农教你】Android端简单易用的SPI框架-——-SPA
  • 以太坊==MetaMask获取测试币最新网址
  • 军用FPGA软件 Verilog语言的编码准测之触发器、锁存器
  • 智能汽车 UI 风格独具魅力
  • javafx例子笔记
  • 【ajax基础】回调函数地狱
  • SparkSQL的分布式执行引擎-Thrift服务:学习总结(第七天)
  • 联华集团:IT团队如何实现从成本中心提升至价值中心|OceanBase 《DB大咖说》(十)
  • 计算机系统基础实训五—CacheLab实验
  • PHP框架之CodeIgniter框架
  • 714. 买卖股票的最佳时机含手续费
  • Linux系统查看程序内存及CPU占用
  • 数据结构7---图
  • Excel 如何复制单元格而不换行
  • 前端 CSS 经典:mix-blend-mode 属性