当前位置: 首页 > news >正文

Pytorch神经网络的模型架构(nn.Module和nn.Sequential的用法)

一、层和块

       在构造自定义块之前,我们先回顾一下多层感知机的代码。下面的代码生成一个网络,其中包含一个具有256个单元和ReLU激活函数的全连接隐藏层,然后是一个具有10个隐藏单元且不带激活函数的全连接输出层。

import torch
from torch import nn
from torch.nn import functional as Fnet = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))X = torch.rand(2, 20)
net(X)
tensor([[ 0.0748, -0.1284,  0.0661,  0.1824,  0.1819, -0.0896, -0.0444,  0.0611,-0.1083, -0.2545],[ 0.0015, -0.1136,  0.0300,  0.2422,  0.1924, -0.1676, -0.1643,  0.0208,-0.1123, -0.1084]], grad_fn=<AddmmBackward0>)

       `nn.Sequential`定义了一种特殊的`Module`,即在PyTorch中表示一个块的类,它维护了一个由`Module`组成的有序列表。注意,两个全连接层都是`Linear`类的实例,`Linear`类本身就是`Module`的子类。另外,到目前为止,我们一直在通过`net(X)`调用我们的模型来获得模型的输出。这实际上是`net.__call__(X)`的简写。这个前向传播函数非常简单:它将列表中的每个块连接在一起,将每个块的输出作为下一个块的输入。

二、自定义块

       Pytorch中任何一个层或者一个神经网络基本都是nn.Module的子类。下面是一个自定义的MLP类,功能和前面代码相同。

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数paramssuper().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))

       所有的Module有两个重要的函数,一个是init()函数,在里面定义需要哪些类和参数,另外一个是forward()函数,定义了模型的前向传播。

       实例化多层感知机的层,然后在每次调用前向传播函数时调用这些层。

net = MLP()
net(X)
tensor([[ 0.0617, -0.0381,  0.0605, -0.2711, -0.0481, -0.1107,  0.2265, -0.0549,0.2573,  0.0887],[-0.0170, -0.0350,  0.1438, -0.2079, -0.0148, -0.0230,  0.0590,  0.0136,0.3161,  0.0014]], grad_fn=<AddmmBackward0>)

三、顺序块

       现在我们可以更仔细地看看`Sequential`类是如何工作的,回想一下`Sequential`的设计是为了把其他模块串起来。为了构建我们自己的简化的`MySequential`,我们只需要定义两个关键函数:

  1. 一种将块逐个追加到列表中的函数;
  2. 一种前向传播函数,用于将输入按追加块的顺序传递给块组成的“链条”。

       下面的`MySequential`类提供了与默认`Sequential`类相同的功能。

class MySequential(nn.Module):def __init__(self, *args):  # *args: list of input argumentssuper().__init__()for idx, module in enumerate(args):# 这里,module是Module子类的一个实例。我们把它保存在'Module'类的成员# 变量_modules中。_module的类型是OrderedDict(有序字典)self._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X

       当`MySequential`的前向传播函数被调用时,每个添加的块都按照它们被添加的顺序执行。现在可以使用我们的`MySequential`类重新实现多层感知机。

net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
net(X)
tensor([[ 0.0425,  0.2652, -0.1381,  0.0156, -0.1683,  0.0906, -0.2825,  0.0234,0.0289,  0.0594],[ 0.0372,  0.2065, -0.1196,  0.0681, -0.1791,  0.1555, -0.4214,  0.1164,-0.0223,  0.0265]], grad_fn=<AddmmBackward0>)

四、在前向传播函数中执行代码

       下面这段代码相比于nn.Sequential更加灵活,能够灵活定义前向计算:

class FixedHiddenMLP(nn.Module):def __init__(self):super().__init__()# 不计算梯度的随机权重参数。因此其在训练期间保持不变self.rand_weight = torch.rand((20, 20), requires_grad=False)self.linear = nn.Linear(20, 20)def forward(self, X):X = self.linear(X)# 使用创建的常量参数以及relu和mm函数X = F.relu(torch.mm(X, self.rand_weight) + 1)# 复用全连接层。这相当于两个全连接层共享参数X = self.linear(X)# 控制流while X.abs().sum() > 1:X /= 2return X.sum()net = FixedHiddenMLP()
net(X)
tensor(0.0402, grad_fn=<SumBackward0>)

五、嵌套使用

       我们可以混合搭配各种组合块的方法。在下面的例子中,我们以一些想到的方法嵌套块。

class NestMLP(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),nn.Linear(64, 32), nn.ReLU())self.linear = nn.Linear(32, 16)def forward(self, X):return self.linear(self.net(X))chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())
chimera(X)
tensor(-0.0394, grad_fn=<SumBackward0>)

六、总结

  • 一个块可以由许多层组成;一个块可以由许多块组成。
  • 块可以包含代码。
  • 块负责大量的内部处理,包括参数初始化和反向传播。
  • 层和块的顺序连接由`Sequential`块处理。
http://www.lryc.cn/news/262332.html

相关文章:

  • JS数组之展开运算符
  • 读书笔记:《汽车构造与原理》
  • INS 量测更新
  • 【ssh基础知识】
  • 04 开发第一个组件
  • 【Unity】如何让Unity程序一打开就运行命令行命令
  • Web前端-HTML(表格与表单)
  • Android RecycleView实现平滑滚动置顶和调整滚动速度
  • 跳跃游戏 + 45. 跳跃游戏 II
  • 在Django中使用多语言(i18n)
  • 高性价比AWS Lambda无服务体验
  • 【物联网】EMQX(二)——docker快速搭建EMQX 和 MQTTX客户端使用
  • 2023 亚马逊云科技 re:lnvent 大会探秘: Amazon Connect 全渠道云联络中心
  • 鸿蒙开发之用户隐私权限申请
  • Docker笔记:简单部署 nodejs 项目和 golang 项目
  • java内置的数据结构
  • 轻松搭建FPGA开发环境:第三课——Vivado 库编译与设置说明
  • 【PostgreSQL】从零开始:(十一)PostgreSQL-Dropdb命令删除数据库
  • UDP网络编程其他相关事项
  • Redhat LINUX 9.3 + PG 16.1 搭建主备流复制
  • kafka设置消费者组
  • Worker-Thread设计模式
  • npm 安装包遇到问题的常用脚本(RequestError: socket hang up)
  • 活动 | Mint Blockchain 将于 2024 年 1 月 10 号启动 MintPass 限时铸造活动
  • Android动画(四)——属性动画ValueAnimator的妙用
  • C语言飞机大战
  • js 原型 和 原型链
  • 如何利用SD-WAN节省运维成本和简化运维工作?
  • 在工作中使用CHAT提高效率
  • Maven 项目的三种打包方式与 pom.xml 文件中项目描述