当前位置: 首页 > news >正文

pytorch网络的增删改

本文介绍对加载的网络的层进行增删改, 以alexnet网络为例进行介绍。

1. 加载网络

import torchvision.models as models  alexnet =models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
print(alexnet)

在这里插入图片描述

2. 删除网络

在做迁移学习的时候,我们通常是在分类网络的基础上进行修改的。一般会把网络最后的几层删除掉,主要是全局平均池化层、全连接层。只留前面的网络部分作为特征提取器,再次基础上进行其他的任务。

2.1 删除网络任意层

  • 将alexnet的classifier这一部分全删除掉

在这里插入图片描述

del  alexnet.classifer
print(alexnet)

删除classifer模块后,打印结果如下:
在这里插入图片描述
可以看到只剩下featuresavgpool这两个模块了。刚才的classifier就已经被我们删除掉了。

  • 删除classifier模块中的某一层

如果不想把classifier这一模块整体删除,只想删除classifier中比如第6个层

# del alexnet.classifier 
del alexnet.classifier[6]
print(alexnet)

在这里插入图片描述
可以看到classifier中第6层就已经被删除掉了。

2.2 删除网络的最后多层

如果想把网络的连续几层给删除掉,比如classifier中最后的几层删除掉

#------------------删除网络的最后多层--------------------------#
alexnet.classifier = alexnet.classifier[:-2]
print(alexnet)
#-------------------------------------------------------------#

打印信息如下:
在这里插入图片描述

  • 可以看出classifier看出最后2层(5,6)被删除掉了

可以使用切片的方式,保留不需要被删除的层重新赋给classifier模块,没有保留的就被删除了。

3. 修改网络的某一层

  • 没有修改之前alexnet.classifier的第6层是个全连接层,输入通道为4096, 输出通道为1000
    在这里插入图片描述
  • 假设此时,我们想最后一层全连接层的输出,改为1024。此时,你只需要重新定义这层全连接层。
#-----------------修改网络的某一层-----------------------------#
alexnet.classifier[6] = nn.Linear(in_features=4096,out_features=1024)
print(alexnet)
#-------------------------------------------------------------#

打印后,可以看到最后一层的输出由原来的4096改为了1024
在这里插入图片描述

4. 在网络中添加某一层

4.1 每次添加一层

假设我们想在网络最后输出中,再添加两层,分别为ReLUnn.Linear

#-----------------修改网络的某一层-----------------------------#
alexnet.classifier[6] = nn.Linear(in_features=4096,out_features=1024)
# print(alexnet)
#-------------------------------------------------------------##-------------网络添加层,每次添加一层--------------------------#
alexnet.classifier.add_module('7',nn.ReLU(inplace=True))
alexnet.classifier.add_module('8',nn.Linear(in_features=1024,out_features=20))
print(alexnet)
#-------------------------------------------------------------#
  • 利用add_module来添加层,第一个参数为层名称,第二个参数为定义layer的内容
  • 我们在alexnet.classifier这个block中进行添加的,添加后打印网络结构如下:
    在这里插入图片描述
  • 可以看到成功的添加了最后2层。

4.2 一次添加多层

如果觉得一层层的添加层比较麻烦,比如我们可以一次性添加一个大的模块new_block

block = nn.Sequential(nn.ReLU(inplace=True),nn.Linear(in_features=1024,out_features=20)
)alexnet.add_module('new_block',block)
print(alexnet)

在这里插入图片描述

  • 可以看到在alexnet网络中新增了new_block,该block中包括2层,分别是ReLU层以及Linear层。

以上就是对Pytorch网络增删改的方法,完整的代码如下:

import torchvision.models as models  
import torch.nn   as nn alexnet =models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
# print(alexnet)#1.-----------------删除网络的最后一层-------------------------#
# del alexnet.classifier 
# del alexnet.classifier[6]
# print(alexnet)
#-------------------------------------------------------------##------------------删除网络的最后多层--------------------------#
# alexnet.classifier = alexnet.classifier[:-2]
# print(alexnet)
#-------------------------------------------------------------##-----------------修改网络的某一层-----------------------------#
alexnet.classifier[6] = nn.Linear(in_features=4096,out_features=1024)
# print(alexnet)
#-------------------------------------------------------------##-------------网络添加层,每次添加一层--------------------------#
# alexnet.classifier.add_module('7',nn.ReLU(inplace=True))
# alexnet.classifier.add_module('8',nn.Linear(in_features=1024,out_features=20))
# print(alexnet)
#-------------------------------------------------------------##-----------------------网络添加层,一次性添加多层--------------#
block = nn.Sequential(nn.ReLU(inplace=True),nn.Linear(in_features=1024,out_features=20)
)alexnet.add_module('new_block',block)
print(alexnet)
#-------------------------------------------------------------#
http://www.lryc.cn/news/263618.html

相关文章:

  • Tomcat (Linux系统)详解全集
  • [德人合科技]——设计公司 \ 设计院图纸文件数据 | 资料透明加密防泄密软件
  • 数字化转型中的6S管理
  • Linux学习(1)——初识Linux
  • 2.5 - 网络协议 - HTTP协议工作原理,报文格式,抓包实战
  • 新增工具箱管理功能、重构网站证书管理功能,1Panel开源面板v1.9.0发布
  • 棋牌的电脑计时计费管理系统教程,棋牌灯控管理软件操作教程
  • 《Kotlin核心编程》笔记:设计模式
  • hive企业级调优策略之数据倾斜
  • MATLAB版本、labview版本、UHD版本 互相对应
  • 13 v-show指令
  • 23级新生C语言周赛(6)(郑州轻工业大学)
  • 关于“Python”的核心知识点整理大全24
  • Vue - 基于Element UI封装一个表格动态列组件
  • 计算机网络:DNS域名解析系统
  • java面试:==和equals有什么区别?
  • 数字人SaaS系统无限生成AI数字人!
  • 【MySQL】——数据类型及字符集
  • Redis cluster集群设置密码
  • Docker 核心技术
  • 15 使用v-model绑定单选框
  • go语言指针变量定义及说明
  • 基于“Galera+MariaDB”搭建多主数据库集群的实例
  • arcgis javascript api4.x加载天地图cgs2000坐标系
  • 算法学习——回溯算法
  • C语言—小小圣诞树
  • Android消息公告上下滚动切换轮播实现
  • tensorflow入门 自定义模型
  • 虚拟机启动 I/O error in “xfs_read_agi+0x95“
  • 【MYSQL】-库的操作