当前位置: 首页 > news >正文

nn.Module模块介绍

nn.Module是 PyTorch 中所有神经网络模块的基类,用于构建可训练的模型,即构建一个新结构的模型。它是 PyTorch 神经网络的核心抽象,一个抽象类,使用时必须实现必要的抽象函数。

1. 使用方法:

举一个例子:手写数字图像识别,建立一个深度学习的框架,输入是28\times28的图像,输出是一个1\times10的向量,表示0~9的各个类别的可能性。

网络的架构:

使用nn.Module构建这个网络

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super().__init__()# 子模块定义self.conv1 = nn.Conv2d(1, 16, 3)  # 输入通道1,输出通道16,卷积核3x3self.pool = nn.MaxPool2d(2)       # 2x2最大池化self.fc = nn.Linear(16*13*13, 10) # 全连接层(假设输入图像为28x28)# 自定义参数(非子模块)self.scale = nn.Parameter(torch.tensor(1.0))  # 可训练标量参数,动态缩放输出结果def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 卷积 -> ReLU -> 池化x = x.view(-1, 16*13*13)             # 展平x = self.fc(x) * self.scale           # 全连接层 + 自定义参数缩放return x

模型在创建时,必须包含__init__和__forward__两个方法。

2. 特点

(1)参数管理

当在 nn.Module 的子类中将 nn.Parameter 或子模块(如 nn.Conv2d)赋值给类属性时,PyTorch 会记录这些对象到内部的 _parameters 或 _modules 字典中,确保它们参与梯度计算、设备移动(CPU/GPU)、参数保存/加载等关键操作。例如,例子中成员变量 self.conv1,self.fc , self.scale。

可以实现参数的自动跟踪

model = SimpleCNN()
print(list(model.named_parameters()))

输出:

scale; conv1.weight、conv1.bias;  fc_weight、fc_bias如下:

[('scale', Parameter containing:
tensor(1., requires_grad=True)), ('conv1.weight', Parameter containing:
tensor([[[[ 0.3146, -0.2337,  0.2631],[ 0.1649,  0.2865,  0.2307],[-0.0522, -0.2642, -0.1696]]],[[[ 0.0158,  0.3199,  0.0063],[ 0.0858,  0.1410, -0.0497],[-0.1104,  0.2964,  0.2612]]],[[[-0.1222, -0.1469,  0.0314],[-0.2020, -0.3159, -0.0970],[ 0.2853,  0.1428,  0.0119]]],[[[ 0.1217, -0.0545, -0.1806],[-0.0048,  0.1158,  0.1185],[-0.0908,  0.0012, -0.0098]]],[[[ 0.1017, -0.0518,  0.1661],[-0.1580, -0.0326,  0.3247],[-0.3255, -0.2731, -0.2454]]],[[[ 0.2273, -0.1849, -0.1432],[-0.3186,  0.0621, -0.2068],[ 0.0756, -0.3076, -0.2667]]],[[[ 0.2341,  0.2008, -0.0361],[-0.3005, -0.1754, -0.3298],[-0.2160, -0.3142,  0.3064]]],[[[-0.2293, -0.1122, -0.1528],[ 0.2064,  0.0754, -0.2762],[ 0.2740, -0.0463, -0.1822]]],[[[ 0.2774,  0.0322, -0.1532],[-0.0482, -0.0678, -0.2401],[-0.0318,  0.2358, -0.2187]]],[[[ 0.1396,  0.1801,  0.1789],[-0.1797, -0.1715, -0.3309],[-0.1572,  0.0549,  0.0577]]],[[[-0.3022,  0.2383,  0.1073],[-0.0813,  0.2904, -0.2532],[-0.0321,  0.0273, -0.2783]]],[[[ 0.2397,  0.3167, -0.2939],[-0.2852, -0.2542,  0.1281],[ 0.0433,  0.2920,  0.2629]]],[[[ 0.0573, -0.0992, -0.2561],[ 0.1158,  0.2102, -0.1286],[-0.3075,  0.0806,  0.2279]]],[[[-0.2582,  0.2342, -0.2332],[-0.2627,  0.2822,  0.2278],[ 0.1213, -0.1526, -0.1611]]],[[[-0.0150,  0.3245, -0.1438],[ 0.0012,  0.1359,  0.2652],[ 0.1046,  0.1012, -0.2422]]],[[[-0.0178,  0.3177,  0.1215],[ 0.0338, -0.1513,  0.2207],[ 0.1846,  0.0616, -0.0704]]]], requires_grad=True)), ('conv1.bias', Parameter containing:
tensor([-0.3028,  0.2742,  0.0908,  0.0770,  0.0357,  0.1591,  0.1625, -0.0185,0.0871,  0.2598,  0.2732, -0.0111,  0.2493, -0.1319, -0.1072, -0.0537],requires_grad=True)), ('fc.weight', Parameter containing:
tensor([[ 0.0039, -0.0049, -0.0023,  ..., -0.0102, -0.0178, -0.0031],[-0.0039, -0.0058, -0.0025,  ..., -0.0030, -0.0131,  0.0092],[ 0.0077, -0.0068,  0.0059,  ...,  0.0078,  0.0055,  0.0096],...,[-0.0038,  0.0079, -0.0186,  ..., -0.0171, -0.0047,  0.0003],[-0.0056, -0.0179,  0.0017,  ..., -0.0092, -0.0189,  0.0128],[ 0.0144, -0.0057,  0.0038,  ...,  0.0152, -0.0043,  0.0025]],requires_grad=True)), ('fc.bias', Parameter containing:
tensor([-0.0178, -0.0058,  0.0016,  0.0112,  0.0151, -0.0164,  0.0127,  0.0060,-0.0175, -0.0156], requires_grad=True))]进程已结束,退出代码为 0

 设备移动统一管理

model.to('cuda')  # 所有参数和子模块自动移至GPU
print(model.weight.is_cuda)  # True

梯度计算自动启用

loss = model(x).sum()
loss.backward()  # 所有注册的参数自动计算梯度
print(model.weight.grad is not None)  # True

有些参数是需要手动注册的,才能实现自动的管理:参数/模块是通过列表或字典动态生成的,需要手动注册,在__init__部分进行注册。

class DynamicModel(nn.Module):def __init__(self):super().__init__()self.params_list = nn.ParameterList([nn.Parameter(torch.randn(10)) for _ in range(5)])  # 自动注册self.params_dict = nn.ModuleDict({'p1': nn.Linear(10, 5)})                              # 自动注册# 普通Python容器内的参数需手动注册self.custom_list = [nn.Parameter(torch.randn(10))]for i, param in enumerate(self.custom_list):self.register_parameter(f'custom_{i}', param)  # 手动注册

(2)模块嵌套

新构建的网络模型要在此处进行定义。

class ComplexModel(nn.Module):def __init__(self):super().__init__()self.conv_block = nn.Sequential(nn.Conv2d(3, 16, 3),nn.ReLU(),nn.MaxPool2d(2))self.classifier = nn.Linear(16*13*13, 10)  # 假设输入图像为28x28def forward(self, x):x = self.conv_block(x)x = x.view(x.size(0), -1)  # 展平return self.classifier(x)

(3)模型保存与加载

# 保存
torch.save(model.state_dict(), 'model.pth')# 加载
new_model = MyModel()
new_model.load_state_dict(torch.load('model.pth'))

(4)钩子(Hooks)

可进行调试或特征提取。

def forward_hook(module, input, output):print(f"Layer {module.__class__.__name__} output shape: {output.shape}")model.conv_block.register_forward_hook(forward_hook)  # 注册钩子

3. 注意事项

不要直接调用 forward()应该用 model(x)(PyTorch 会自动处理钩子和梯度)。模块命名唯一,子模块名称不能重复(如两个 self.fc 会覆盖)。

http://www.lryc.cn/news/623898.html

相关文章:

  • 计算机视觉(一):nvidia与cuda介绍
  • OpenMemory MCP发布!AI记忆本地共享,Claude、Cursor一键同步效率翻倍!
  • 【Linux】文件基础IO
  • Agent开发进阶路线:从基础响应到自主决策的架构演进
  • Python使用数据类dataclasses管理数据对象
  • 【C2000】C2000例程使用介绍
  • Python进行中文分词
  • MySQL定时任务详解 - Event Scheduler 事件调度器从基础到实战
  • Blender模拟结构光3D Scanner(二)投影仪内参数匹配
  • 火狐(Mozilla Firefox)浏览器离线安装包下载
  • 学习Python中Selenium模块的基本用法(5:程序基本步骤)
  • Python数据类型转换详解:从基础到实践
  • Python 基础语法(二)
  • 0️⃣基础 认识Python操作文件夹(初学者)
  • Linux:TCP协议
  • RK3568平台开发系列讲解:PCIE trainning失败怎么办
  • 深入解析函数指针及其数组、typedef关键字应用技巧
  • 0-12岁幼儿启蒙与教育
  • CF2121C Those Who Are With Us
  • 2001-2024年中国玉米种植分布数据集
  • 【牛客刷题】01字符串按递增长度截取并转换为十进制数值
  • Day07 缓存商品 购物车
  • 14.web api 5
  • LEA(Load Effective Address)指令
  • 19.5 「4步压缩大模型:GPTQ量化实战让OPT-1.3B显存直降75%」
  • 混沌工程(Chaos engineering):系统韧性保障之道
  • 图解希尔排序C语言实现
  • 【Java】多线程Thread类
  • 2025年- H97-Lc205--23.合并k个升序链表(链表、小根堆、优先队列)--Java版
  • 【撸靶笔记】第二关:GET -Error based -Intiger based