当前位置: 首页 > news >正文

2.11.ResNet

ResNet

动机:我们总是想加更多层,但加更多层并不总是能改进精度

在这里插入图片描述

可以看出F1到F6模型越来越大,但F6距离最优解却总变远了,反而效果不好,通俗的来说就是学偏了,实际上我们希望是这样的:
在这里插入图片描述

​ 更大的模型总是包含之前的小模型,则结果至少不会更差。

​ 这也是残差网络(ResNet)的核心思想:每个附加层都应该更容易地包含原始函数作为其元素之一。

1.残差块

​ 我们希望能串联一个层能改变函数类,让它变大:

在这里插入图片描述

​ 右侧是残差块,通过加入快速通道来得到 f ( x ) = x + g ( x ) f(x)=x+g(x) f(x)=x+g(x)的结构,如果 g ( x ) g(x) g(x)没有学到任何东西,就等价于恒等映射,可以直接跳过这个层,先去拟合小网络。显然如果使用了参拆快,那么 f ( x ) f(x) f(x)的范围肯定比 x x x大,且对输入的改变比较敏感。

ResNet块的具体细节

在这里插入图片描述

​ 可以使用1*1的卷积层来变换输出通道。箭头的位置可以随便选取,看具体效果吧

在这里插入图片描述

​ 效果都差不多

ResNet块可分为两类

  1. 高宽减半的ResNet块,即步幅为2,有1*1卷积层(步幅也设置为2)的,将输入高宽减半,输出通道自然增加,那么x需要通过1*1卷积层来变换输出通道
  2. 高宽不变的,即步幅为1,不需要使用1*1卷积层的ResNet块

2.ResNet架构

在这里插入图片描述

​ 如图所示为ResNet-18架构,类似VGG和GoogLeNet的总体架构,但替换成了ResNet块,基本架构也是这样的5阶段

  • 残差快使得很深的网络更加容易训练,甚至可以训练一千层的网络
  • 残差网络对随后的深度神经网络设计产生了深远影响

3.ResNet如何处理梯度消失


y = f ( x ) 梯度 ∂ y ∂ w w = w − D ∂ y ∂ w y= f(x)\\ 梯度\frac{\partial y}{\partial w}\\ w = w- D\frac{\partial y}{\partial w}\\ y=f(x)梯度wyw=wDwy
​ 不希望梯度变得很小,但如果又新嵌套很多层:
y ′ = g ( f ( x ) ) ∂ y ′ ∂ w = ∂ y ′ ∂ y ⋅ ∂ y ∂ w = ∂ g ( y ) ∂ y ⋅ ∂ y ∂ w y'=g(f(x))\\ \frac{\partial y'}{\partial w}=\frac{\partial y'}{\partial y}\cdot\frac{\partial y}{\partial w} =\frac{\partial g(y)}{\partial y}\cdot \frac{\partial y}{\partial w} y=g(f(x))wy=yywy=yg(y)wy
​ 如果新加的层拟合得很好,那么 ∂ g ( y ) ∂ y \frac{\partial g(y)}{\partial y} yg(y)就会很小,那么 ∂ y ′ ∂ w \frac{\partial y'}{\partial w} wy会很小,这时候我们只能增大学习率,但这样会导致顶部梯度爆炸,反之则底部梯度消失。

​ ResNet:
y ′ ′ = y + y ′ = f ( x ) + g ( f ( x ) ) ∂ y ′ ′ ∂ w = ∂ y ∂ w + ∂ y ′ ∂ w y'' = y+y' =f(x)+g(f(x))\\ \frac{\partial y''}{\partial w} =\frac{\partial y}{\partial w}+\frac{\partial y'}{\partial w} y′′=y+y=f(x)+g(f(x))wy′′=wy+wy
​ 将乘法变为了加法,这样大数加一个小数也是一个大数,这样在底部(靠近数据端的)在初始时也可以有较大的梯度(因为可以通过快速通道传递),会得到比较好的训练效果。

4.代码实现

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):  # @savedef __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)self.relu = nn.ReLU(inplace=True)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += X  # 相加后再ReLUreturn F.relu(Y)blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
print('输入和输出形状一致:', Y.shape)blk = Residual(3, 6, use_1x1conv=True, strides=2)
print('使用步幅为2的1*1卷积层,输出通道翻倍,高宽减半:', blk(X).shape)'''ResNet块'''
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):# num_residuals表示这里面有多少个resnet块# first_block用于特判第一个,之前在第一阶段b1块时就已经减少了很多,所以第一个残差块不减半blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
http://www.lryc.cn/news/409299.html

相关文章:

  • GitLab添加TortoiseGIT生成SSH Key
  • 20240729 大模型评测
  • 基于微信小程序的校园警务系统/校园安全管理系统/校园出入管理系统
  • 达梦数据库归档介绍
  • OpenAI推出AI搜索引擎SearchGPT
  • elementplus菜单组件的那些事
  • 【VSCode实战】Golang无法跳转问题竟是如此简单
  • three.js中加载ply格式的文件,并使用tween.js插件按照json姿态文件运动
  • 性能对比:Memcached 与 Redis 的关键差异
  • app-routing.module.ts 简单介绍
  • 基于JSP的水果销售管理网站
  • web3d值得学习并长期发展,性价比高吗?
  • 【大数据面试题】38 说说 Hive 怎么行转列
  • C语言中的二维数组
  • Android12 添加屏幕方向旋转方案
  • Harmony-(1)-TypeScript-ArkTs
  • TC8:SOMEIP_ETS_007-008
  • [网络编程】网络编程的基础使用
  • Postman中的Cookie和会话管理:掌握API测试的关键环节
  • python脚本,识别pdf数据,转换成表格形式
  • Linux环境安装KubeSphere容器云平台并实现远程访问Web UI 界面
  • jumpserver web资源--远程应用发布机
  • Linux环境docker部署Firefox结合内网穿透远程使用浏览器测试
  • 人工智能与机器学习原理精解【8】
  • 关于Protobuf 输入输出中文到文件中的一系列问题
  • 后端笔记(1)--javaweb简介
  • 便携式气象监测系统的优势:精准高效,随行监测
  • uniapp App判断是否安装某个app
  • C/C++大雪纷飞代码
  • 【linux】【设备树】具有 GPIO 控制器和连接器的硬件配置的备树(Device Tree)代码讲解