当前位置: 首页 > news >正文

【PyTorch单点知识】深入了解 nn.ModuleList和 nn.ParameterList模块:灵活构建动态网络结构

文章目录

      • 0. 前言
      • 1. 为什么需要 `nn.ModuleList` 和 `nn.ParameterList`?
      • 2. `nn.ModuleList`:管理模块的列表
        • 2.1 什么是 `nn.ModuleList`?
        • 2.2 创建 `nn.ModuleList`
        • 2.3 动态添加或删除层
      • 3. `nn.ParameterList`:管理参数列表
        • 3.1 什么是 `nn.ParameterList`?
        • 3.2 创建 `nn.ParameterList`
        • 3.3 动态添加或删除参数
      • 4. 自适应模型
      • 5. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在 PyTorch 中,nn.ModuleListnn.ParameterList 是两种非常有用的工具,可以让你以更加灵活的方式构建和管理动态网络结构。这两种列表允许你在构建模型时轻松地添加或删除层,这对于构建自适应模型、循环网络或其他需要动态调整结构的场景非常有用。

本文将详细介绍这两个类的使用方法及其应用场景,帮助你更好地理解和运用它们来构建复杂和灵活的神经网络模型。

1. 为什么需要 nn.ModuleListnn.ParameterList

在构建深度学习模型时,我们经常需要创建包含多个层的网络。传统的做法是显式地定义每一层,但这在某些情况下可能不够灵活。例如,当你需要根据输入数据动态决定网络结构时,就需要一种更加灵活的方式来组织和管理这些层。

nn.ModuleListnn.ParameterList 提供了这样的灵活性。它们允许你将多个层或参数集合组织在一起,并且可以方便地在运行时增加、删除或修改这些层或参数。

2. nn.ModuleList:管理模块的列表

2.1 什么是 nn.ModuleList

nn.ModuleList 是一个包含 nn.Module 子类实例的有序列表。它可以用于管理一个模型中的多个层,而且这些层可以是任意类型的 nn.Module 对象。

2.2 创建 nn.ModuleList

要创建一个 nn.ModuleList,你可以简单地将 nn.Module 的实例作为一个列表传递给构造函数。例如:

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])def forward(self, x):for layer in self.layers:x = layer(x)return xM = MyModel()
print(M)

输出为:

MyModel((layers): ModuleList((0-4): 5 x Linear(in_features=10, out_features=10, bias=True))
)

在这个例子中,MyModel 包含了一个由五个 nn.Linear 层组成的 ModuleList。每个层都将输入的维度从 10 映射到 10。

2.3 动态添加或删除层

nn.ModuleList 支持像 Python 列表那样的索引操作,因此可以轻松地添加、删除或替换其中的层:

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)])def forward(self, x):for layer in self.layers:x = layer(x)return xM = MyModel()M.layers.append(nn.Conv2d(in_channels=1,out_channels=3,kernel_size=3))
print(M.layers[5])

输出为:

Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))

3. nn.ParameterList:管理参数列表

3.1 什么是 nn.ParameterList

nn.ParameterList 类似于 nn.ModuleList,但它用于管理一组 nn.Parameter 对象。这些参数可以是权重矩阵、偏置向量等。

3.2 创建 nn.ParameterList

要创建一个 nn.ParameterList,你可以将 nn.Parameter 对象作为一个列表传递给构造函数:

import torch.nn as nn
import torchtorch.manual_seed(666)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.weights = nn.ParameterList([nn.Parameter(torch.randn(3, 3)) for _ in range(5)])def forward(self, x):for weight in self.weights:x = torch.mm(x, weight)return xM = MyModel()
print(M.weights)
print(M.weights[-1])

输出为:

ParameterList((0): Parameter containing: [torch.float32 of size 3x3](1): Parameter containing: [torch.float32 of size 3x3](2): Parameter containing: [torch.float32 of size 3x3](3): Parameter containing: [torch.float32 of size 3x3](4): Parameter containing: [torch.float32 of size 3x3]
)
Parameter containing:
tensor([[ 2.1743, -0.9672, -0.7672],[-0.5229, -2.2826,  0.1051],[-0.2497, -1.5241,  1.5813]], requires_grad=True)

在这个例子中,MyModel 包含了一个由五个随机权重矩阵组成的 ParameterList

3.3 动态添加或删除参数

nn.ParameterList 同样支持像 Python 列表那样的索引操作,因此你可以轻松地添加、删除或替换其中的参数:

import torch.nn as nn
import torchtorch.manual_seed(666)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.weights = nn.ParameterList([nn.Parameter(torch.randn(3, 3)) for _ in range(5)])def forward(self, x):for weight in self.weights:x = torch.mm(x, weight)return xM = MyModel()M.weights.append(torch.zeros(3,3))
print(M.weights)
print(M.weights[-1])

输出为:

ParameterList((0): Parameter containing: [torch.float32 of size 3x3](1): Parameter containing: [torch.float32 of size 3x3](2): Parameter containing: [torch.float32 of size 3x3](3): Parameter containing: [torch.float32 of size 3x3](4): Parameter containing: [torch.float32 of size 3x3](5): Parameter containing: [torch.float32 of size 3x3]
)
Parameter containing:
tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]], requires_grad=True)

4. 自适应模型

比如构建一个模型,其中的某些层可以根据输入数据动态决定是否使用。你可以使用 nn.ModuleList 来存储这些层,并在 .forward() 方法中根据条件决定是否使用它们:

import torch.nn as nn
class AdaptiveModel(nn.Module):def __init__(self, num_layers):super(AdaptiveModel, self).__init__()self.layers = nn.ModuleList([nn.Linear(5, 5) for _ in range(num_layers)])def forward(self, x):use_layers = [True, False, True, True, False]  # 示例:使用第0、2、3层for i, layer in enumerate(self.layers):if use_layers[i]:x = layer(x)return x

5. 总结

nn.ModuleListnn.ParameterList 提供了一种灵活的方式来构建和管理动态网络结构。通过这些工具,可以轻松地构建自适应模型、循环网络或其他需要动态调整结构的场景。

http://www.lryc.cn/news/436554.html

相关文章:

  • vscode创建Python虚拟环境无法激活问题处理
  • 【Go】Go语言中的基本数据类型与类型转换
  • 【Python中导入Tkinter模块创建计算器界面】
  • 中关村科金推出得助音视频鸿蒙SDK,助力金融业务系统鸿蒙化提速
  • 如何实现视频数据的PES打包和传输?
  • 【软考】程序设计语言基础
  • 野指针与空指针的异同
  • 虚拟存储器“大观”,讲解核心逻辑知识和408大题方法
  • 【AI赋能医学】基于深度学习和HRV特征的多类别心电图分类
  • 速盾:做外贸用高防cdn需要国外节点的吗?
  • 单片机中为什么要使用5v转3.3v,不直接使用3.3V电压
  • SpringBoot项目请求返回json空字段过滤
  • linux下进程详解
  • 春招审核流程优化:Spring Boot系统设计
  • QT:音视频播放器
  • 大模型入门 ch 03:注意力机制
  • STM32点亮第一个LED
  • [Linux]:动静态库
  • windows 显示进程地址空间
  • Android 12 SystemUI下拉状态栏禁止QuickQSPanel展开
  • 二分思想与相关问题(下)
  • 【算法专题】搜索算法
  • B2064 斐波那契数列
  • Spark的介绍
  • SpringBoot项目是如何启动
  • 科技之光,照亮未来之路“2024南京国际人工智能展会”
  • 在深度学习计算机视觉的语义分割中,Boundary和Edge的区别是?
  • 【JAVA入门】Day41 - 字节缓冲流和字符缓冲流
  • collocate join,bucket join,broadcast join,shuffle join对比分析
  • 微信自动通过好友和自动拉人进群,微加机器人这个功能太好用了