当前位置: 首页 > news >正文

P16 激活函数与Loss 的梯度

参考:

https://www.ngui.cc/el/507608.html?action=onClick

这里面简单回顾一下PyTorch 里面的两个常用的梯度自动计算的API

autoGrad 和 Backward, 最后结合 softmax 简单介绍一下一下应用场景。

目录:

1 autoGrad

2 Backward

3 softmax


一 autoGrad

输入

x

输出

损失函数

参数更新

# -*- coding: utf-8 -*-
"""
Created on Mon Feb 13 21:28:26 2023@author: cxf
"""import torch
import torch.nn.functional as Fdef grad():x = torch.tensor([[1.0,2.0]]).view(2,1)w = torch.full([2,1], 1.0,requires_grad= True)target = torch.ones((1,1))out = torch.matmul(w.T, x)print(out)mse = F.mse_loss(out, target)print("\n mse",mse)grad_w = torch.autograd.grad(mse,[w])    print(grad_w)if __name__ == "__main__":grad()


二 Backward

求梯度另一种方法,可以通过backward

在创建动态图后,直接调用backward,更加方便

import torch
import torch.nn.functional as Fdef grad():x = torch.tensor([[1.0,2.0]]).view(2,1)w = torch.full([2,1], 1.0,requires_grad= True)target = torch.ones((1,1))out = torch.matmul(w.T, x)print(out)mse = F.mse_loss(out, target)print("\n mse",mse)mse.backward()   print(w.grad)if __name__ == "__main__":grad()

三 softmax

多分类模型常用的激活函数

这种模型通常用交叉熵做损失函数

因为标签中只有一个为1,其它都为0,假设为

则:

(j=i)

则写成向量形式为

import torch
import torch.nn.functional as F
from torch import nn#自己实现该梯度计算
def calcGrad(a,target):grad =a -targetprint("\n 直接计算",grad)# 直接计算 tensor([[ 0.0900, -0.7553,  0.6652]], grad_fn=<SubBackward0>)#调用API 方式实现
def grad():CEL =  nn.CrossEntropyLoss()z = torch.tensor([[1.0,2.0,3.0]],requires_grad=True)a = F.softmax(z,dim=1)print("\n 神经元输出",a)target = torch.tensor([[0.0,1.0,0.0]])loss =CEL(z,target)loss.backward()print("\n API 计算",z.grad)# API 计算 tensor([[ 0.0900, -0.7553,  0.6652]])calcGrad(a,target)if __name__ == "__main__":grad()

这里面要注意nn.CrossEntropyLoss

是相当于对z 先做softmax,得到a, 然后再做交叉熵

http://www.lryc.cn/news/8672.html

相关文章:

  • ThinkPHP5美食商城系统
  • Vue3 - $refs 使用教程,父组件调用获取子组件数据和方法(setup() / <script setup>)
  • 华为OD机试 - 众数和中位数(Python)| 真题+思路+考点+代码+岗位
  • 一眼万年的 Keychron 无线机械键盘
  • 自动化测试高频面试题(含答案)
  • 3、按键扫描检测处理
  • 集中式存储和分布式存储
  • 【机器学习数据集】如何获得机器学习的练习数据?
  • 【编程实践】使用 Kotlin HTTP 框架 Fuel 实现 GET,POST 接口 kittinunf.fuel【极简教程】
  • 大数据DataX(一):DataX的框架设计和插件体系
  • 软考高级信息系统项目管理师系列之十一:项目进度管理
  • vue2版本《后台管理模式》(下)
  • 软考中级-程序设计语言
  • Sphinx : 高性能SQL全文检索引擎
  • ansible实战应用系列教程6:管理ansible变量
  • java8新特性Stream流中anyMatch和allMatch和noneMatch的区别详解
  • 双网卡(有线和wifi)同时连接内网和外网
  • 如何赋能智能运维,迈出数字化黑匣子第一步?
  • 消息称索尼计划为PS5推出两款蓝牙耳机,Find My蓝牙耳机用途广
  • 状态管理VueX
  • i.MX8MP平台开发分享(clock篇)- PLL14xx驱动
  • 课程规范性要求
  • 华为OD机试 - 优秀学员统计(Python)| 真题+思路+考点+代码+岗位
  • 布林线(BOLL)计算公式详解,开口收口代表什么
  • 模糊的照片能修复吗?
  • 【Java|多线程与高并发】详解start()方法和run()方法的区别
  • mysql 一些有意思的sql语句,备忘
  • hive自定义函数
  • 数仓理论【范式】【维度建模】
  • 卷积神经网络