当前位置: 首页 > news >正文

pytorch学习(十二):对现有的模型进行修改

以VGG16为例:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

特征提取部分(features

  • 卷积层与ReLU激活:网络的前半部分主要由卷积层(Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。

  • 卷积层配置

    • 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(MaxPool2d),用于降低特征图的尺寸并增加感受野。
    • 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
    • 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
  • 最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。

全连接层部分(classifier

  • 自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。

  • 全连接层

    • 第一个全连接层(Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。
    • 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。

可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。

 

vgg16_true.add_module('add_linear',nn.Linear(1000,10))

运行上述代码,在末尾加一层线性层,也就是全连接层。

还有一种方式是对原有的全连接层进行修改,将1000改为10。

vgg16_false.classifier[6]=nn.Linear(4096,10)

附上所有源代码;

# -*- coding: utf-8 -*-  
# File created on 2024/8/9 
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
# 
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

http://www.lryc.cn/news/420756.html

相关文章:

  • 服务器虚拟内存是什么?虚拟内存怎么设置?
  • 深度学习入门指南(1) - 从chatgpt入手
  • Python学习笔记(六)
  • 大数据安全规划总体方案(45页PPT)
  • 第20周:Pytorch文本分类入门
  • 记一次 SpringBoot2.x 配置 Fastjson请求报 internal server 500
  • OSPF笔记
  • IOC容器初始化流程
  • 第二季度云计算市场份额榜单:微软下滑,谷歌上升,AWS仍保持领先
  • 三点确定圆心算法推导
  • 神经网络 (NN) TensorFlow Playground在线应用程序
  • 腾讯课堂 离线m3u8.sqlite转成视频
  • Linux多路转接
  • IDEA导入Maven项目的流程配置以常见问题解决
  • 【数据分析---- Pandas进阶指南:核心计算方法、缺失值处理及数据类型管理】
  • 2024世界机器人大会将于8月21日至25日在京举行
  • 【Linux】lvm被删除或者lvm丢失了怎么办
  • 疫情防控管理系统
  • 永久删除的Android 文件去哪了?在Android上恢复误删除的消息和照片方法?
  • 宠物服务小程序多生态转化
  • 今天细说一下工业制造行业MES系统
  • C++ 知识点(长期更新)
  • Spring AI + 通义千问 入门学习
  • 38.【C语言】指针(重难点)(C)
  • Vue-05.指令-v-for
  • 自动驾驶的一些大白话讲解
  • Python学习笔记--参数
  • 刷题——大数加法
  • Pytorch人体姿态骨架生成图像
  • 前端面试常考的HTML标签知识!!!