当前位置: 首页 > news >正文

基础神经网络模型搭建

nn 包提供通用深度学习网络的模块集合,接收输入张量,计算输出张量,并保存权重。通常使用两种途径搭建 PyTorch 中的模型:nn.Sequential和 nn.Module。

nn.Sequential通过线性层有序组合搭建模型;nn.Module通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

目录

搭建线性层

通过nn.Sequential搭建

通过nn.Module搭建

获取模型摘要


搭建线性层

使用 nn 包搭建线性层。线性层接收 64*1000 维的输入,保存 1000*100 维的权重,并计算 64*100 维的输出。

import torch
from torch import nn
input_tensor = torch.randn(64, 1000)
linear_layer = nn.Linear(1000, 100)
output = linear_layer(input_tensor)
print(input_tensor.size())
print(output.size())

通过nn.Sequential搭建

考虑一个两层的神经网络,四个节点作为输入,五个节点在隐藏层,一个节点作为输出

from torch import nn
model = nn.Sequential(nn.Linear(4, 5),nn.ReLU(),nn.Linear(5, 1),
)
print(model)

通过nn.Module搭建

在 PyTorch 中搭建模型的另一种方法是对 nn.Module 类进行子类化,通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

考虑两个卷积层和两个完全连接层搭建的模型:

import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass

定义__init__ 函数和forward 函数

def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

重写两个类函数并打印模型

重写:子类中实现一个与父类的成员函数原型完全相同的函数

Net.__init__ = __init__
Net.forward = forward
model = Net()
print(model)

 查看模型位置

print(next(model.parameters()).device)

 

将模型移动至CUDA设备 

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

获取模型摘要

借助torchsummary包查获取模型摘要

pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

http://www.lryc.cn/news/594559.html

相关文章:

  • 【Linux】3. Shell语言
  • 双8无碳小车“cad【17张】三维图+设计说名书
  • XTTS实现语音克隆:精确控制音频格式与生成流程【TTS的实战指南】
  • XSS GAME靶场
  • 仙盟数据库应用-外贸标签打印系统 前端数据库-V8--毕业论文-—-—仙盟创梦IDE
  • Apache基础配置
  • ESMFold 安装教程
  • 深度相机的工作模式(以奥比中光深度相机为例)
  • 近期工作感想:职业规划篇
  • 【RAG Agent】Deep Searcher实现逻辑解析
  • 尚庭公寓--------登陆流程介绍以及功能代码
  • Linux:线程控制
  • API获取及调用(以豆包为例实现图像分析)
  • 《计算机网络》实验报告三 UDP协议分析
  • 单线程 Reactor 模式
  • 【PyTorch】图像二分类项目
  • SSE和WebSocket区别到底是什么
  • 渗透笔记(XSS跨站脚本攻击)
  • `MYSQL`、`MYSQL_RES` 和 `MYSQL_FIELD`的含义与使用案例
  • [硬件电路-59]:电源:电子存储的仓库,电能的发生地,电场的动力场所
  • 2025最新 PostgreSQL17 安装及配置(Windows原生版)
  • BST(二叉搜索树)的笔试大题(C语言)
  • 【web安全】SQL注入与认证绕过
  • 【算法300题】:双指针
  • c#转python第四天:生态系统与常用库
  • XSS的介绍
  • Linux主机 ->多机器登录
  • 从零到精通:用DataBinding解锁MVVM的开发魔法
  • 【JS逆向基础】数据库之MongoDB
  • Django接口自动化平台实现(四)