当前位置: 首页 > news >正文

pytorch中常见的模型3种组织方式 nn.Sequential(OrderedDict)

在nn.Sequential中嵌套OrderedDict组织网络,以对层进行命名

import torch
import torch.nn as nn
from collections import OrderedDictclass OrderedDictCNN(nn.Module):def __init__(self):super(OrderedDictCNN, self).__init__()# 使用 OrderedDict 定义网络层self.model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)),  # 初始卷积层('bn1', nn.BatchNorm2d(64)),('relu1', nn.ReLU(inplace=True)),('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),('conv2', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)),  # 特征提取层('bn2', nn.BatchNorm2d(128)),('relu2', nn.ReLU(inplace=True)),('maxpool2', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),('flatten', nn.Flatten()),  # 展平层('fc1', nn.Linear(128 * 112 * 112, 1000)),  # 全连接层('relu3', nn.ReLU(inplace=True)),('fc2', nn.Linear(1000, 10))  # 输出层]))def forward(self, x):return self.model(x)

使用多个nn.Sequential组织网络

import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 初始卷积层self.stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(inplace=True))# 特征提取层self.feature_extraction = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))# 全连接层self.fc = nn.Sequential(nn.Flatten(),nn.Linear(128 * 112 * 112, 1000),nn.ReLU(inplace=True),nn.Linear(1000, 10))def forward(self, x):x = self.stem(x)x = self.feature_extraction(x)x = self.fc(x)return x

使用单个nn.Sequential组织网络

import torch
import torch.nn as nnclass SequentialCNN(nn.Module):def __init__(self):super(SequentialCNN, self).__init__()# 使用 nn.Sequential 定义网络层self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),  # 初始卷积层nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1),nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),  # 特征提取层nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0),nn.Flatten(),  # 展平层nn.Linear(128 * 112 * 112, 1000),  # 全连接层nn.ReLU(inplace=True),nn.Linear(1000, 10)  # 输出层)def forward(self, x):return self.model(x)
http://www.lryc.cn/news/405049.html

相关文章:

  • 达梦数据库DM8-索引篇
  • 【中项】系统集成项目管理工程师-第4章 信息系统架构-4.5技术架构
  • 随机梯度下降 (Stochastic Gradient Descent, SGD)
  • TDengine 3.3.2.0 发布:新增 UDT 及 Oracle、SQL Server 数据接入
  • Ubuntu 24.04 LTS 无法打开Chrome浏览器
  • linux中RocketMQ安装(单机版)及springboot中的使用
  • 亚信安全终端一体化解决方案入选应用创新典型案例
  • Django视图与URLs路由详解
  • 怎么关闭 Windows 安全中心,手动关闭 Windows Defender 教程
  • 洛谷看不了别人主页怎么办
  • 邮件安全篇:企业电子邮件安全涉及哪些方面?
  • 软件测试09 自动化测试技术(Selenium)
  • 记录解决springboot项目上传图片到本地,在html里不能回显的问题
  • C++ 中 const 关键字
  • 客梯自动监测识别摄像机
  • 为什么那么多人学习AI绘画?工资香啊!
  • 国产JS库(js-tool-big-box)7月度总结
  • c++ 高精度加法(只支持正整数)
  • python键盘操作工具:ctypes、pyautogui
  • 计算机网络发展历史
  • 记录安装android studio踩的坑 win7系统
  • Python图形编程-PyGame快速入门
  • 邦芒宝典:8种方法调整职场心态
  • 华为OD2024D卷机试题汇总,含D量50%+,按算法分类刷题,事半功倍
  • Unity UGUI 之 Graphic Raycaster
  • 类和对象——相关的零碎知识
  • 【hadoop大数据集群 1】
  • TQSDRPI开发板教程:实现PL端的UDP回环与GPSDO
  • array.some() ==> 查找数组list中,是否有包含与当前currKey的值不一样的misId
  • 最简单的typora+gitee+picgo配置图床