当前位置: 首页 > news >正文

常用torch.nn

目录

  • 一、torch.nn和torch.nn.functional
  • 二、nn.Linear
  • 三、nn.Embedding
  • 四、nn.Identity
  • 五、Pytorch非线性激活函数
  • 六、nn.Conv2d
  • 七、nn.Sequential
  • 八、nn.ModuleList
  • 九、torch.outer torch.cat

一、torch.nn和torch.nn.functional

Pytorch中torch.nn和torch.nn.functional的区别及实例详解
Pytorch中,nn与nn.functional有哪些区别?
相同之处:

两者都继承于nn.Module
nn.x与nn.functional.x的实际功能相同,比如nn.Conv3d和nn.functional.conv3d都是进行3d卷积
运行效率几乎相同
不同之处:
nn.x是nn.functional.x的类封装,nn.functional.x是具体的函数接口
nn.x除了具有nn.functional.x功能之外,还具有nn.Module相关的属性和方法,比如:train(),eval()等
nn.functional.x直接传入参数调用,nn.x需要先实例化再传参调用
nn.x能很好的与nn.Sequential结合使用,而nn.functional.x无法与nn.Sequential结合使用
nn.x不需要自定义和管理weight,而nn.functional.x需自定义weight,作为传入的参数

二、nn.Linear

torch.nn.functional 中的 Linear 函数
torch.nn.functional.linear 是 PyTorch 框架中的一个功能模块,主要用于实现线性变换。这个函数对于构建神经网络中的全连接层(或称为线性层)至关重要。它能够将输入数据通过一个线性公式(y = xA^T + b)转换为输出数据,其中 A 是权重,b 是偏置项。

用途
神经网络构建:在构建神经网络时,linear 函数用于添加线性层。
特征变换:在数据预处理和特征工程中,使用它进行线性特征变换。

基本用法如下:

output = torch.nn.functional.linear(input, weight, bias=None)
input:输入数据
weight:权重
bias:偏置(可选)
class torch.nn.Linear(in_features, out_features, bias=True)
参数:
in_features - 每个输入样本的大小
out_features - 每个输出样本的大小
bias - 若设置为False,这层不会学习偏置。默认值:True

对输入数据做线性变换:y=Ax+b
形状:
输入: (N,in_features)
输出: (N,out_features)
变量:
weight -形状为(out_features x in_features)的模块中可学习的权值
bias -形状为(out_features)的模块中可学习的偏置
例子:

>>> m = nn.Linear(20, 30)
>>> input = autograd.Variable(torch.randn(128, 20))
>>> output = m(input)
>>> print(output.size())

三、nn.Embedding

pytorch nn.Embedding详解
nn.Embedding作用
nn.Embedding是PyTorch中的一个常用模块,其主要作用是将输入的整数序列转换为密集向量表示。在自然语言处理(NLP)任务中,可以将每个单词表示成一个向量,从而方便进行下一步的计算和处理。
nn.Embedding词向量转化
nn.Embedding是将输入向量化,定义如下:

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
参数说明:
num_embeddings :字典中词的个数
embedding_dim:embedding的维度
padding_idx(索引指定填充):如果给定,则遇到padding_idx中的索引,则将其位置填00是默认值,事实上随便填充什么值都可以)。

在PyTorch中,nn.Embedding用来实现词与词向量的映射。nn.Embedding具有一个权重(.weight),形状是(num_words, embedding_dim)。例如一共有100个词,每个词用16维向量表征,对应的权重就是一个100×16的矩阵。
Embedding的输入形状N×W,N是batch size,W是序列的长度,输出的形状是N×W×embedding_dim。
Embedding输入必须是LongTensor,FloatTensor需通过tensor.long()方法转成LongTensor。
Embedding的权重是可以训练的,既可以采用随机初始化,也可以采用预训练好的词向量初始化。

四、nn.Identity

nn.Identity() 是 PyTorch 中的一个层(layer)。它实际上是一个恒等映射,不对输入进行任何变换或操作,只是简单地将输入返回作为输出。

五、Pytorch非线性激活函数

torch.nn.functional非线性激活函数

六、nn.Conv2d

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

对几个输入特征图组成的输入信号应用2D卷积。
参数:
input – 输入张量 (minibatch x in_channels x iH x iW)
weight – 过滤器张量 (out_channels, in_channels/groups, kH, kW)
bias – 可选偏置张量 (out_channels) -
stride – 卷积核的步长,可以是单个数字或一个元组 (sh x sw)。默认为1
padding – 输入上隐含零填充。可以是单个数字或元组。 默认值:0
groups – 将输入分成组,in_channels应该被组数除尽

例子:

>>> # With square kernels and equal stride
>>> filters = autograd.Variable(torch.randn(8,4,3,3))
>>> inputs = autograd.Variable(torch.randn(1,4,5,5))
>>> F.conv2d(inputs, filters, padding=1)

七、nn.Sequential

nn.Sequential()介绍
一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到nn.Sequential()容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。
nn.Sequential()和torch.nn.ModuleList的区别在于:torch.nn.ModuleList只是一个储存网络模块的list,其中的网络模块之间没有连接关系和顺序关系。而nn.Sequential()内的网络模块之间是按照添加的顺序级联的。

八、nn.ModuleList

九、torch.outer torch.cat

**torch.outer:**实现张量的点积,张量都需要是一维向量

import torch
v1 = torch.arange(1., 5.)
print(v1)
v2 = torch.arange(1., 4.)
print(v2)
print(torch.outer(v1, v2))tensor([1., 2., 3., 4.])
tensor([1., 2., 3.])
tensor([[ 1.,  2.,  3.],[ 2.,  4.,  6.],[ 3.,  6.,  9.],[ 4.,  8., 12.]])

**torch.cat:**函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

http://www.lryc.cn/news/351130.html

相关文章:

  • 力扣226.翻转二叉树101.对称二叉树
  • word如何按照原本页面审阅文档
  • 前端基础入门三大核心之HTML篇:探索WebAssembly —— 开启网页高性能应用新时代
  • NDIS小端口驱动(四)
  • 用户态网络缓冲区设计
  • Linux运维工程师基础面试题整理(三)
  • 基于单片机与传感器技术的汽车起动线路设计
  • C#如何通过反射获取外部dll的函数
  • 从零开始傅里叶变换
  • 解决1万条数据前端渲染不卡的问题
  • 如何编写一个API——Python代码示例及拓展
  • UMPNet: Universal Manipulation Policy Network for Articulated Objects
  • 高通 Android 12/13冻结屏幕
  • C++实现图的存储和遍历
  • AI--构建检索增强生成 (RAG) 应用程序
  • QT7_视频知识点笔记_4_文件操作,Socket通信:TCP/UDP
  • 智慧社区管理系统:打造便捷、安全、和谐的新型社区生态
  • CustomTkinter:便捷美化Tkinter的UI界面(附模板)
  • 使用MicroPython和pyboard开发板(15):使用LCD和触摸传感器
  • c++20 std::jthread 源码简单赏析与应用
  • 自动化测试里的数据驱动和关键字驱动思路的理解
  • 【30天精通Prometheus:一站式监控实战指南】第6天:mysqld_exporter从入门到实战:安装、配置详解与生产环境搭建指南,超详细
  • 浅析智能体开发(第二部分):智能体设计模式和软件架构
  • Unity学习笔记---Transform组件
  • springboot+jsp校园理发店美容美发店信息管理系统0h29g
  • css - sass or scss ?
  • html5 笔记01
  • E5063A是德科技e5063a网络分析仪
  • 【星海随笔】微信小程序(二)
  • Python采集安居客租房信息