当前位置: 首页 > news >正文

动手学深度学习(pytorch)学习记录12-激活函数[学习记录]

激活函数

激活函数(activation function)通过计算加权和并加上偏置来确定神经元是否应该被激活, 它们将输入信号转换为输出的可微运算。

import torch  
import matplotlib.pyplot as plt 

简单定义一个画图的函数

def graph_drawing(x_,y_,label_=None): plt.figure(figsize=(5, 2.5))  # 设置图形窗口的大小if label_ is None:plt.plot(x_, y_)else:plt.plot(x_, y_, label = label_)# plt.plot()里不要marker='x'更好看plt.legend()  # 显示图例      plt.show()

创建数据

# 创建 x 数据,并设置 requires_grad=True  
x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True) 

ReLU函数

ReLU(Rectified Linear Unit)函数是一种在深度学习中广泛使用的激活函数,其表达式为f(x) = max(0, x)。它简单地将所有的负值置为0,保持正值不变,有助于解决梯度消失问题,并加速神经网络的训练过程。
在这里插入图片描述

# 应用 ReLU 函数  
y = torch.relu(x)  
graph_drawing(x_=x.detach(),y_=y.detach(),label_='relu(x)')

在这里插入图片描述
当输入为负时,ReLU函数的导数为0,而当输入为正时,ReLU函数的导数为1。 当输入值精确等于0时,ReLU函数不可导。可以忽略这种情况,因为输入可能永远都不会是0.

# 绘制ReLU函数的导函数图像
y.backward(torch.ones_like(x), retain_graph=True)
# retain_graph=True:这是一个可选参数,用于控制梯度图(即用于计算梯度的图结构)的保留。在默认情况下,.backward()会清除梯度图以节省内存。
graph_drawing(x_=x.detach(),y_=x.grad.numpy(),label_='grad of ReLU')
#转换为 numpy 数组,似乎转不转都行

在这里插入图片描述

sigmoid函数

对于一个定义域在R中的输入, sigmoid函数将输入变换为区间(0, 1)上的输出。 因此,sigmoid通常称为挤压函数(squashing function): 它将范围(-inf, inf)中的任意输入压缩到区间(0, 1)中的某个值。
当我们想要将输出视作二元分类问题的概率时, sigmoid仍然被广泛用作输出单元上的激活函数 (sigmoid可以视为softmax的特例)。
当输入接近0时,sigmoid函数接近线性变换。
在这里插入图片描述

# 绘制sigmoid函数图像
y = torch.sigmoid(x)
graph_drawing(x_=x.detach(),y_=y.detach(),label_='sigmoid(x)')

在这里插入图片描述
sigmoid函数的导数当输入为0时,sigmoid函数的导数达到最大值0.25; 而输入在任一方向上越远离0点时,导数越接近0。
在这里插入图片描述

# 清除以前的梯度
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
graph_drawing(x_=x.detach(),y_=x.grad.numpy(),label_='grad of sigmoid')

在这里插入图片描述

tanh函数

与sigmoid函数类似, tanh(双曲正切)函数也能将其输入压缩转换到区间(-1, 1)上。
在这里插入图片描述

y = torch.tanh(x)
graph_drawing(x_=x.detach().numpy(), y_=y.detach().numpy(), label_='tanh(x)')

在这里插入图片描述
tanh函数的导数图像: 当输入接近0时,tanh函数的导数接近最大值1。 与sigmoid函数图像类似, 输入在任一方向上越远离0点,导数越接近0
在这里插入图片描述

# 清除以前的梯度
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
graph_drawing(x_=x.detach(),y_=x.grad.numpy(),label_='grad of tanh')

在这里插入图片描述
封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

http://www.lryc.cn/news/425942.html

相关文章:

  • 微服务实战系列之玩转Docker(十)
  • Mysql(四)---增删查改(进阶)
  • SOAP @WebService WSDL
  • 【Qt】QWidget的toolTip属性
  • 【操作系统】什么是进程?什么是线程?两者有什么区别(面试常考!!!)
  • AI -- Machine Learning
  • 了解交换机_1.交换机的技术发展
  • ubuntu 24.04 安装 Nvidia 显卡驱动 + CUDA + cuDNN,配置 AI 深度学习训练环境,简单易懂,一看就会!
  • 跟李沐学AI:目标检测的常用算法
  • 基于UE5和ROS2的激光雷达+深度RGBD相机小车的仿真指南(一)---UnrealCV获取深度+分割图像
  • Java算法解析一:二分算法及其衍生出来的问题
  • 数学建模预测类—【一元线性回归】
  • 配置更加美观的 Swagger UI
  • 软件测试 - 基础(软件测试的生命周期、测试报告、bug的级别、与开发人员产生争执的调解方式)
  • RTX 4070 GDDR6显存曝光:性能与成本的平衡之选
  • canvas的基础使用
  • Windows 常用网络命令之 telnet(测试端口是否连通)
  • x264 编码器像素运算系列:asd8函数
  • 什么是AR、VR、MR、XR?
  • Epic Games 商店面向欧盟 iPhone 用户上线
  • 【计算机毕设项目】2025级计算机专业小程序项目推荐 (小程序+后台管理)
  • Fast API + LangServe快速搭建 LLM 后台
  • CSS继承、盒子模型、float浮动、定位、diaplay
  • 使用百度文心智能体创建AI旅游助手
  • 斗破C++编程入门系列之四:运算符和表达式
  • CVPR2024 | PromptAD: 仅使用正常样本进行小样本异常检测的学习提示
  • 文件批量上传,oss使用时间戳解决同名问题 以及一些sql bug
  • 机器学习——线性回归(sklearn)
  • Go 语言切片(Slice) 15
  • 嵌入式开发--STM32G030C8T6,写片上FLASH死机CFGBSY和写入出错