当前位置: 首页 > news >正文

torch.nn中Sequential的使用

1、torch.nn中的Sequential介绍

结构:
torch.nn–>Containers–>Sequential

class torch.nn.Sequential(*args: Module)
class torch.nn.Sequential(arg: OrderedDict[str, Module])

  一种顺序容器。模块将按照它们在构造函数中传递的顺序添加到其中。或者,可以传入模块的 OrderedDict。Sequential 的 forward() 方法接受任何输入并将其转发到它包含的第一个模块。然后,它将输出按顺序“链接”到每个后续模块的输入,最后返回最后一个模块的输出。
  Sequential 提供的值相对于手动调用序列 的模块是它允许将整个容器视为 单个模块,以便在 Sequential 适用于它存储的每个模块(每个模块都是 Sequential 的注册子模块)。
  Sequential 和 torch.nn.ModuleList的区别?ModuleList 顾名思义——一个用于存储模块的列表。另一方面, Sequential 中的层以级联方式连接。

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
)# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([("conv1", nn.Conv2d(1, 20, 5)),("relu1", nn.ReLU()),("conv2", nn.Conv2d(20, 64, 5)),("relu2", nn.ReLU()),])
)

1.1 Sequential的方法

  1. append(module)——将给定模块附加到末尾。
    参数:module (nn.Module) – 要附加的模块
    返回值:Self
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
n.append(nn.Linear(3, 4))
Sequential((0): Linear(in_features=1, out_features=2, bias=True)(1): Linear(in_features=2, out_features=3, bias=True)(2): Linear(in_features=3, out_features=4, bias=True)
)
  1. extend(sequential)——使用另一个顺序容器中的层扩展当前顺序容器。
    参数:sequential (Sequential) – 一个顺序容器,其层将添加到当前容器中。
    返回值:Self
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
n.extend(other) # or `n + other`
Sequential((0): Linear(in_features=1, out_features=2, bias=True)(1): Linear(in_features=2, out_features=3, bias=True)(2): Linear(in_features=3, out_features=4, bias=True)(3): Linear(in_features=4, out_features=5, bias=True)
)
  1. insert(index, module)——将模块插入指定索引处的顺序容器中。
    参数: - index (int) – 要插入模块的索引。 - module (Module) – 要插入的模块。
    返回值:Self
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
n.insert(0, nn.Linear(3, 4))
Sequential((0): Linear(in_features=3, out_features=4, bias=True)(1): Linear(in_features=1, out_features=2, bias=True)(2): Linear(in_features=2, out_features=3, bias=True)
)

2、Pytorch实战

2.1 参数设置

  这里以CIFAR10数据集为例,使用如下网络模型:
在这里插入图片描述
卷积层的参数可以由torch.nn.Conv2d的介绍计算得到:
在这里插入图片描述

最大池化层的参数可以由torch.nn.MaxPool2d的介绍计算得到:
在这里插入图片描述

2.2 建立网络模型并验证结构

# pytorch实战——sequential practicefrom torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linearclass Seq(nn.Module):def __init__(self):super(Seq,self).__init__()self.conv1=Conv2d(3,32,5,1,2)self.maxpool1=MaxPool2d(2)self.conv2=Conv2d(32,32,5,1,2)self.maxpool2=MaxPool2d(2)self.conv3=Conv2d(32,64,5,1,2)self.maxpool3=MaxPool2d(2)self.flatten=Flattenself.linear1=Linear(1024,64)self.linear2=Linear(64,10)def forward(self,x):x=self.conv1(x)x=self.maxpool1(x)x=self.conv2(x)x=self.maxpool2(x)x=self.conv3(x)x=self.maxpool3(x)x=self.flatten(x)x=self.linear1(x)x=self.linear2(x)return xseq=Seq()
print(seq)
input=torch.ones((64,3,32,32))
output=seq(input)
print(output.shape)

结果:

Seq((conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(linear1): Linear(in_features=1024, out_features=64, bias=True)(linear2): Linear(in_features=64, out_features=10, bias=True)
)
torch.Size([64, 10])

2.3 使用Sequential重新建立网络模型

# pytorch实战——sequential practice
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linearclass Seq(nn.Module):def __init__(self):super(Seq,self).__init__()self.module=Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 64, 5, 1, 2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x=self.module(x)return xseq=Seq()

2.4 使用Tensorboard可视化网络结构

Tensorboard学习笔记:Pytorch中Tensorboard的学习

# pytorch实战——sequential practice
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriterclass Seq(nn.Module):def __init__(self):super(Seq,self).__init__()self.module=Sequential(Conv2d(3, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 32, 5, 1, 2),MaxPool2d(2),Conv2d(32, 64, 5, 1, 2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self,x):x=self.module(x)return xseq=Seq()
print(seq)
input=torch.ones((64,3,32,32))
output=seq(input)
print(output.shape)writer=SummaryWriter(".\logs_seq")
writer.add_graph(seq,input)
writer.close()

在终端进入当前环境,输入命令:

(mypytorch) PS E:\my_pycharm_projects\project1> tensorboard --logdir=logs_seq
#结果:
TensorFlow installation not found - running with reduced feature set.
W0811 23:16:22.249774 31868 plugin_event_accumulator.py:369] Found more than one graph event p
er run, or there was a metagraph containing a graph_def, as well as one or more graph events.  Overwriting the graph with the newest event.
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

打开网址:
在这里插入图片描述
双击放大查看模型结构细节:
在这里插入图片描述
在这里插入图片描述

http://www.lryc.cn/news/621232.html

相关文章:

  • 【代码随想录day 20】 力扣 538.把二叉搜索树转换为累加树
  • CMake语法与Bash语法的区别
  • 扩展用例-失败的嵌套
  • 流式数据服务端怎么传给前端,前端怎么接收?
  • jenkins在windows配置sshpass
  • 设计模式笔记_行为型_状态模式
  • 【JavaEE】多线程 -- 线程状态
  • 纸箱拆垛:物流自动化中的“开箱密码”与3D视觉的智能革命
  • 面试题之项目中灰度发布是怎么做的
  • Flink on YARN启动全流程深度解析
  • 会议通信系统核心流程详解(底稿1)
  • Linux软件编程:进程和线程
  • C#面试题及详细答案120道(01-10)-- 基础语法与数据类型
  • Flink Stream API 源码走读 - socketTextStream
  • 2025H1手游市场:SLG领涨、休闲爆发,何为出海新航道?
  • 广告灯的左移右移
  • Day43 复习日
  • FPGA+护理:跨学科发展的探索(五)
  • Kotlin Data Classes 快速上手
  • 【深度学习】深度学习基础概念与初识PyTorch
  • 报数游戏(我将每文更新tips)
  • IPTV系统:开启视听与管理的全新篇章
  • 14 ABP Framework 文档管理
  • 【软考中级网络工程师】知识点之入侵防御系统:筑牢网络安全防线
  • SpringMVC(详细版从入门到精通)未完
  • P5967 [POI 2016] Korale 题解
  • 【数据分享】2014-2023年长江流域 (0.05度)5.5km分辨率的每小时日光诱导叶绿素荧光SIF数据
  • stm32项目(28)——基于stm32的环境监测并上传至onenet云平台
  • LT3045EDD#TRPBF ADI亚德诺 超低噪声LDO稳压器 电子元器件IC
  • web网站开发,在线%射击比赛成绩管理%系统开发demo,基于html,css,jquery,python,django,model,orm,mysql数据库