
[ResNet: Controlling the Input/Output Shape of a Residual Block]


  • 1 A program for controlling the input/output shape of a residual block
  • 2 Inspecting the classic ResNet18 model

1 A program for controlling the input/output shape of a residual block

(Figure: the basic residual block, from the ResNet paper)
Reference: https://arxiv.org/pdf/1512.03385.pdf
This is a basic residual block. The main (forward) path consists of two convolutional layers; the shortcut path consists of one convolution plus batch normalization, whose job is to reshape the input so it matches the main path's output. The two branches are then added, producing an output with the same shape as the main path.

import torch
from torch import nn
from torch.nn import functional as F

class ResNetBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        # Main path: two 3x3 convolutions; only the first one applies
        # the stride, so any downsampling happens exactly once.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Shortcut path: a convolution + batch norm that adjusts the
        # input's channels and spatial size to match the main path.
        self.residual = nn.Conv2d(in_channels, out_channels,
                                  kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(out), inplace=True)
        out = self.conv2(out)
        out = self.bn2(out)
        res = self.residual(x)
        res = self.bn3(res)
        out += res                 # shortcut (skip) connection
        return F.relu(out)

Test code:

imgs_batch = torch.randn((8, 3, 224, 224))
resnet_block = ResNetBasicBlock(3, 16, 1)
pred_batch = resnet_block(imgs_batch)
print(pred_batch.shape)

Output:

torch.Size([8, 16, 224, 224])
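The spatial size of each convolution's output follows the usual formula out = ⌊(in + 2·padding − kernel_size) / stride⌋ + 1, so a 3×3 convolution with padding=1 and stride=1 preserves the spatial size, which is why the block above maps 224×224 to 224×224. A small sketch (the helper name is illustrative, not part of PyTorch):

```python
def conv2d_out_size(size, kernel_size=3, stride=1, padding=1):
    """Spatial output size of a Conv2d along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

# 3x3 conv, padding=1, stride=1: size preserved (as in the block above).
print(conv2d_out_size(224, kernel_size=3, stride=1, padding=1))  # 224
# Same conv with stride=2: size halved.
print(conv2d_out_size(224, kernel_size=3, stride=2, padding=1))  # 112
# ResNet's 7x7 stem conv with padding=3, stride=2: also halved.
print(conv2d_out_size(224, kernel_size=7, stride=2, padding=3))  # 112
```

This is the same rule the shortcut branch relies on: giving it the same stride as the main path's first convolution keeps the two branches' spatial sizes equal.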

Code to view the structure graph with TensorBoard:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('my_log/ResNetBasicBlock')
writer.add_graph(resnet_block, imgs_batch)
# In a command prompt, run: tensorboard --logdir path --host=127.0.0.1
# (path is the absolute path to the log directory, without quotes),
# then open TensorBoard at the address it prints.

(Screenshot: the residual block's computation graph in TensorBoard)

2 Inspecting the classic ResNet18 model

resnet_model = torchvision.models.resnet18(pretrained=False)
print(resnet_model)

Output:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)