当前位置: 首页 > news >正文

五、CV_ResNet

五、ResNet

随着层数加深,特征图的尺寸是逐渐变小的,其通道是逐渐增多的

网络退化问题:理论上,网络越深,获取的信息就越多,特征也就越丰富。但在实践中,随着网络的加深,优化效果反而越差,测试数据和训练数据的准确率反而降低了

1.残差块

(1)作用

  • 缓解网络退化问题
  • 在模型中是用来降维的

(2)概念

F(x)F(x)F(x)代表某个只包含两层的映射函数,xxxF(x)F(x)F(x)具有相同维度。在训练过程中,我们的目标是修改F(x)F(x)F(x)中的wwwbbb逼近H(x)H(x)H(x),变换一下思路,用F(x)F(x)F(x)来逼近,则最终得到的输H(x)−xH(x)-xH(x)x出就变为F(x)+xF(x)+xF(x)+x,这里将直接从输入连接到输出的结构称为,shortcutshortcutshortcut整个结构就是残差块,ResNetResNetResNet的基础模块。

  • F(x)F(x)F(x)代表预测值
  • H(x)H(x)H(x)代表真实值
  • F(x)+xF(x)+xF(x)+x:这里的加指的是对应位置上的元素相加,也就是$element - wise $additionadditionaddition

ResNetResNetResNet沿用了VGG全3×33 \times 33×3卷积层的设计。残差块里首先有2个具有相同输出通道数的3×33 \times 33×3卷积层。每个卷积层后接BN层和ReLU激活函数,然后将输入直接加在最后的ReLU激活函数前,这种结构用于层数较少的神经网络中(resnet18, resnet34)

如果输入通道数(eg.eg.eg.图中有256通道数)比较多,就需要引入1×11 \times 11×1卷积层来调整输入的通道数,这种结构也叫做瓶颈模块,通常用关于网络层数较多的结构中(resnet50,resnet101,resnet152)

下面右图残差块的实现如下,可以设定输出通道数,是否使用1×11 \times 11×1的卷积及卷积层的步幅

import tensorflow as tf
from tensorflow.keras import layers, activations# 定义ResNet
class Residual(tf.keras.Model):# 指明残差块的通道数,是否使用1*1卷积,步长def __init__(self, num_channels,use_1x1convs = False, strides = 1):super(Residual, self).__init__()# 卷积层:指明卷积核个数,padding,卷积核大小,步长self.cov1 = layers.Conv2D(num_channels, padding = 'same', kernel_size = 3,strides = strides)# 卷积层:指明卷积核个数,padding,卷积核大小,步长self.conv2 = layers.Conv2D(num_channels,strides = 1,kernel_size = 3,padding = 'same')if use_1x1conv:self.conv3 = layers.Conv2D(num_channels,kernel_size = 1,strides = strides)else:self.conv3 = None# 指明BN层self.bn1 = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()# 定义正向传播过程def call(self, X):# 卷积,BN, 激活Y = activations.relu(self.bn1(self.conv1(x)))# 卷积,BNY = self.bn2(self.conv2(Y))# 对输出数据进行1*1卷积保证通道数相同if self.conv3:X = self.conv3(X)# 返回与输入相加后激活的结果return activation.relu(Y + X)
  • 1×11\times 11×1卷积是用来调整通道数
  • 降维
    • pooling层
    • 设置卷积 strides = 2

2.ResNet模型

ResNet模型构成如下:

ResNet网络中按照残差块的通道数分为不同的模块。第一个模块使用了步幅为2的最大池化层。则无需减小宽和高(第一个模块需进行特殊处理。即需要进行下采样,降维)。之后每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半(每个模块间均需进行通道调整)

(1)定义残差模块

  • 第一个模块做了特别处理
class ResnetBlock(tf.keras.layers.Layer):def __init__(self, num_channels, num_res, first_block = False):super(ResnetBlock, self).__init__()# 存储残差块self.listLayers = []# 遍历残差数目生成模块for i in range(num_res):if i == 0 and not first_block:self.listLayers.append(Residual(num_channels, use_1x1conv = True, strides = 2))else:self.listLayers.append(Residual(num_channels))# 前向传播def call(self, x):for layer in self.listLayers:x = layer(x)return x                

(2)构建Resnet网络

  • ResNet的前两层跟之前介绍的GoogLeNet中一样:在输出通道数为64,步幅为2的7×77 \times 77×7卷积层后接步幅为2的3×33\times33×3的最大池化层。不同之处在于ResNet每个卷积层后增加了BN层,接着是所有残差模块,最后,与GoogLeNet一样,加入全局平均池化层(GPA)后接上全连接层输出
class Resnet(tf.keras.Model):# 定义网络的构成def __init__(self, num_blocks):super(Resnet, self).__init__()# 输入层self.conv = layers.Conv2D(filter = 64,kernel_size = 7,padding = 'same',strides = 2)# BN层self.bn = layers.BatchNormalization()# 激活层self.relu = layers.Activation('relu')# 池化self.mp = layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'same')# 残差模块self.res_block1 = ResnetBlock(64, num_blocks[0], first_block = True)self.res_block2 = ResnetBlock(128, num_blocks[1])self.res_block3 = ResnetBlock(256, num_blocks[2])self.res_block4 = ResnetBlock(512, num_blocks[3])# GAPself.gap = layers.GlobalAvgPool2D()# 全连接层self.fc = layers.Dense(units = 10, activation = tf.keras.activations.softmax)# 定义前向传播过程def call(self,x):# 输入部分传输过程x = self.conv(x)x = self.bn(x)x = self.relu(x)x = self.mp(x)# blockx = self.res_block1(x)x = self.res_block2(x)x = self.res_block3(x)x = self.res_block4(x)# 输出部分的传输x = self.gap(x)x = self.fc(x)return x     

这里每个模块里有4个卷积层(不计算11卷积层),加上最开始的卷积层和最后的全连接层,共计18层。这个模型被称为ResNet-18。通过配置不同的通道数给模块里的残差块数可以得到不同的ResNet模型。虽然ResNet的主体架构跟GoogLeNet的类似,但ResNet结构更简单,修改也更方便。

# 实例化
mynet = Resnet([2, 2, 2, 2])
x = tf.random.uniform((1, 224, 224, 1))
y = mynet(x)
mynet.summary()

最终可以得到ResNet的架构

3.手写数字识别

(1)数据读取

获取数据并进行维度调整

import numpy as np
from tensorflow.keras.datasets import mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# N H W C
train_images = np.reshape(train_images, (train_images.shape[0], train_images.shape[1], train_images.shape[2], 1))test_images = np.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))

定义两个方法获取部分数据

# 定义两个方法随机抽取部分样本演示def get_train(size):index = np.random.randint(0, np.shape(train_images)[0], size)resize_images = tf.image.resize_with_pad(train_images[index], 224, 224, )return resize_images.numpy(), train_labels[index]def get_test(size):index = np.random.randint(0, np.shape(test_images)[0], size)resize_images = tf.image.resize_with_pad(test_images[index], 224, 224, )return resize_images.numpy(), test_labels[index]
# 获取训练样本和测试样本
train_image, train_label = get_train(256)
test_image, test_label = get_test(128)

(2)模型编译

# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.01, momentum = 0.0)mynet.compile(optimizer = optimizer,loss = 'sparse_categorical_crossentropy',metrics = ['accuracy']
)

(3)模型训练

# 模型训练:指定训练数据集,batchsize, epoch, 验证集
mynet.fit(train_images, train_labels, batch_size = 128, epochs = 3, verbose = 1, # 显示整个训练的logvalidation_split = 0.2) # 验证集

(4)模型评估

mynet.evaluate(test_images, test_labels, verbose = 1)
http://www.lryc.cn/news/614090.html

相关文章:

  • 腾讯iOA:数据安全的港湾
  • wordpress的wp-config.php文件的详解
  • proteus实现简易DS18B20温度计(stm32)
  • Linux软硬链接与动静态库
  • SQL的多表连接查询(难点)
  • 冷冻食材,鲜美生活的新选择
  • trae开发c#
  • 面试题:bable,plugin,loader,还有在打包过程中.vue/.react文件是如何转化为.js文件的
  • 解决Ollama外部服务器无法访问:配置 `OLLAMA_HOST=0.0.0.0` 指南
  • 【世纪龙科技】数智重构车身实训-汽车车身测量虚拟实训软件
  • 网络基础——网络层级
  • 库函数NTC采样温度的方法(STC8)
  • 大模型——部署体验gpt-oss-20b
  • 项目一系列-第3章 若依框架入门
  • SEABORN库函数(第十八节课内容总结)
  • 睿抗开发者大赛国赛-24
  • Java基础之匿名内部类与lambda表达式
  • DAY 39 图像数据与显存
  • 缓存投毒进阶 -- justctf 2025 Busy Traffic
  • docker缓存目录转移设置和生效过程
  • 总结运行CRMEB标准版(uniapp)微信小程序的问题
  • 站在Vue的角度,对比鸿蒙开发中的数据渲染二
  • 【ESP32-menuconfig(1) -- Build Type及Bootloader config】
  • 跨平台音乐管理新方案:Melody如何实现一站式音源整合
  • 76 模块编程之高精度定时器
  • 数据仓库知识
  • PBootcms网站模板伪静态配置教程
  • C++信息学奥赛一本通-第一部分-基础一-第2章-第5节
  • linux信号量和日志
  • 户外广告牌识别准确率↑32%:陌讯多模态融合算法实战解析