当前位置: 首页 > news >正文

神经网络 torch.nn---Non-Linear Activations (ReLU)

ReLU — PyTorch 2.3 documentation

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

非线性变换的目的

  • 非线性变换的目的是为神经网络引入一些非线性特征,使其训练出一些符合各种曲线或各种特征的模型。

  • 换句话来说,如果模型都是直线特征的话,它的泛化能力会不够好

torch.nn.ReLU

torch.nn.ReLU(inplace=False)torch.nn.modules.activation — PyTorch 2.3 documentation

inplace参数:

  • inplace=True,则会自动替换输入时的变量参数。如:input=-1,ReLU(input,implace=True),那么输出后,input=output=0

  • inplace=True,则不替换输入时的变量参数。如:input=-1,ReLU(input,implace=True),那么输出后,input=-1,output=0

作用:

  • input <= 0, output = 0
  • input  >  0,   output = input

计算公式:

程序代码:

示例1:

import torch
from torch import nn
from torch.nn import ReLUinput =torch.tensor([[1, -0.5],[-1, 3]
])
print(input.shape)input = torch.reshape(input,(-1,1,2,2))
print(input.shape)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.relu1 = ReLU()  #inplace bool   原数据是否被替换def forward(self, input):output = self.relu1(input)return outputtudui = Tudui()
output = tudui(input)
print(output)

输出:

示例2:

import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)
# shuffle 是否打乱   False不打乱
# drop_last 最后一轮数据不够时,是否舍弃 true舍弃class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.sigmoid1 = Sigmoid()  #inplace bool   原数据是否被替换def forward(self, input):output = self.sigmoid1(input)return outputtudui = Tudui()
step = 1
writer = SummaryWriter('logs')
for data in dataloader:imgs, targets = datawriter.add_images('inputs',imgs,step)outputs = tudui(imgs)writer.add_images("outputs",outputs,step)step += 1writer.close()

在TensorBoard上看输出内容:

http://www.lryc.cn/news/368033.html

相关文章:

  • 【微服务】使用kubekey部署k8s多节点及kubesphere
  • 目标检测数据集 - 垃圾桶满溢检测数据集下载「包含VOC、COCO、YOLO三种格式」
  • 6.9总结(省赛排位赛1)
  • 58.CountdownLatch
  • Java数据结构准备工作---常用类
  • SD 使用教程
  • Sylar---协程调度模块
  • iOS Hook 崩溃
  • 区间预测 | Matlab实现LSTM-ABKDE长短期记忆神经网络自适应带宽核密度估计多变量回归区间预测
  • linux内核下rapidio(TSI721)相关笔记汇总
  • 从GPT-4到GPT-4o:人工智能的进化与革命
  • 【Java】/*抽象类和接口*/
  • TCP/IP协议介绍——三次握手四次挥手
  • [C++]基于C++opencv结合vibe和sort tracker实现高空抛物实时检测
  • Apache Doris 基础 -- 数据表设计(模式更改)
  • 【机器学习】【遗传算法】【项目实战】药品分拣的优化策略【附Python源码】
  • 电子电气架构 ---车载安全防火墙
  • 解决selenium加载网页过慢影响程序运行时间的问题
  • 何为云防护?有何作用
  • 2024050402-重学 Java 设计模式《实战责任链模式》
  • centos7安装字体
  • Llama模型家族之使用 ReFT技术对 Llama-3 进行微调(三)为 ReFT 微调准备模型及数据集
  • 学习Canvas过程中2D的方法、注释及感悟一(通俗易懂)
  • 《TCP/IP网络编程》(第十三章)多种I/O函数(2)
  • Java集合汇总
  • 度小满金融大模型的应用创新
  • Android WebView上传文件/自定义弹窗技术,附件的解决方案
  • selenium 输入框、按钮,输入点击,获取元素属性等简单例子
  • 结构体构造函数
  • 基于单片机的电子万年历设计