当前位置: 首页 > news >正文

【pytorch】ModuleList 与 ModuleDict

ModuleList 与 ModuleDict

  • 1、ModuleList
  • 2、ModuleDict
  • 3、总结


1、ModuleList

1)ModuleList 接收一个子模块的列表作为输入,然后也可以类似 List 那样进行 append 和 extend 操作:

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 可使用类似List的索引访问
print(net)
# net(torch.zeros(1, 784)) # 会报NotImplementedError# 输出:
# Linear(in_features=256, out_features=10, bias=True)
# ModuleList(
#   (0): Linear(in_features=784, out_features=256, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=256, out_features=10, bias=True)
# )

\quad
2)nn.Sequentialnn.ModuleList 二者的区别:

  • nn.ModuleList 仅仅是一个储存各种模块的列表,这些模块之间没有联系(所以不用保证相邻层的输入输出维度匹配), 而 nn.Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配
  • nn.ModuleList 没有实现 forward 功能需要自己实现,所以上面执行 net(torch.zeros(1, 784)) 会报 NotImplementedError;而nn.Sequential 内部 forward 功能已经实现。

ModuleList 的出现只是让网络定义前向传播时更加灵活,见下面官网的例子:

class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])def forward(self, x):# ModuleList can act as an iterable, or be indexed using intsfor i, l in enumerate(self.linears):x = self.linears[i // 2](x) + l(x)return x

\quad
3)另外,nn.ModuleList 不同于一般的 Python 的 list,加入到 nn.ModuleList 里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。

import torch
import torch.nn as nnclass Module_ModuleList(nn.Module):def __init__(self):super(Module_ModuleList, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10)])class Module_List(nn.Module):def __init__(self):super(Module_List, self).__init__()self.linears = [nn.Linear(10, 10)]net1 = Module_ModuleList()
net2 = Module_List()print(net1)
for p in net1.parameters():print(p.size())print('*'*20)
print(net2)
for p in net2.parameters():print(p)

输出

Module_ModuleList((linears): ModuleList((0): Linear(in_features=10, out_features=10, bias=True))
)
torch.Size([10, 10])
torch.Size([10])
********************
Module_List()

2、ModuleDict

ModuleDict接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:

net = nn.ModuleDict({'linear': nn.Linear(784, 256),'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
# net(torch.zeros(1, 784)) # 会报NotImplementedError# 输出:
# Linear(in_features=784, out_features=256, bias=True)
# Linear(in_features=256, out_features=10, bias=True)
# ModuleDict(
#   (act): ReLU()
#   (linear): Linear(in_features=784, out_features=256, bias=True)
#   (output): Linear(in_features=256, out_features=10, bias=True)
# )

(1)和 nn.ModuleList 一样,nn.ModuleDict 实例仅仅是存放了一些模块的字典,并没有定义 forward函数 需要自己定义。
(2)同样,nn.ModuleDict 也与 Python 的 Dict 有所不同,nn.ModuleDict 里的所有模块的参数会被自动添加到整个网络中。


3、总结

  1. SequentialModuleListModuleDict 类都继承自 Module 类。
  2. Sequential 不同,ModuleListModuleDict 并没有定义一个完整的网络,它们只是将不同的模块存放在一起,需要自己定义 forward 函数。
  3. 虽然 Sequential 等类可以使模型构造更加简单,但直接继承 Module 类可以极大地拓展模型构造的灵活性。
http://www.lryc.cn/news/14346.html

相关文章:

  • Hive窗口函数语法规则、窗口聚合函数、窗口表达式、窗口排序函数 - ROW NUMBER 、口排序函数 - NTILE、窗口分析函数
  • Go设计模式之函数选项模式
  • ClickHouse 数据类型、函数大小写敏感性
  • nodejs基于vue 网上商城购物系统
  • 掌握MySQL分库分表(一)数据库性能优化思路、分库分表优缺点
  • 何为小亚细亚?
  • 【mircopython】ESP32配置与烧录版本
  • Yaml:通过extrac进行传参,关联---接口关联封装(基于一个独立YAML的文件)
  • vue - vue中对Vant日历组件(calendar)的二次封装
  • 详解C++的类型转换
  • NLP文本自动生成介绍及Char-RNN中文文本自动生成训练demo
  • Teradata 离场,企业数据分析平台如何应对变革?
  • QWebEngineView-官翻
  • 网络安全高级攻击
  • 优思学院:六西格玛中的水平对比方法是什么?
  • UVa 690 Pipeline Scheduling 流水线调度 二进制表示状态 DFS 剪枝
  • 【ArcGIS Pro二次开发】(6):工程(Project)的基本操作
  • Qt OpenGL(四十)——Qt OpenGL 核心模式-雷达扫描效果
  • 群智能优化算法求解标准测试函数F1~F23之种群动态分布图(视频)
  • vue-axios封装与使用
  • 重要节点排序方法
  • 【2.20】动态规划 +项目 + 存储引擎
  • 触摸屏单个按键远程控制led
  • JVM12 class文件
  • 等保三级认证基本要求
  • Python 基本数据类型(一)
  • win10 环境变量及其作用大全
  • @Valid与@Validated的区别
  • 【LeetCode】剑指 Offer 09. 用两个栈实现队列 p68 -- Java Version
  • Java并发编程面试题——JUC专题