当前位置: 首页 > news >正文

动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

29残差网络ResNet

在这里插入图片描述

import torch  
from torch import nn  
from torch.nn import functional as F 
import liliPytorch as lp  
import matplotlib.pyplot as plt# 定义一个继承自nn.Module的残差块类
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()# 第一个卷积层,使用3x3的卷积核,填充为1,步幅为指定值self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)# 第二个卷积层,使用3x3的卷积核,填充为1self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)# 可选的1x1卷积层,用于匹配输入输出通道数和步幅if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)# 为什么需要两个不同的批量归一化层?# 1.不同的位置,不同的输入特征# 2.独立的参数和统计数据def forward(self, X):# 先通过第一个卷积层、批量归一化层和ReLU激活函数Y = F.relu(self.bn1(self.conv1(X)))# 然后通过第二个卷积层和批量归一化层Y = self.bn2(self.conv2(Y))# 如果定义了conv3,则通过conv3调整Xif self.conv3:X = self.conv3(X)# 将输入X加到输出Y上实现残差连接Y += X# 通过ReLU激活函数return F.relu(Y)# 创建一个包含输入和输出形状一致的残差块实例,并测试其输出形状
# blk = Residual(3, 3)
# X = torch.rand(4, 3, 6, 6)
# Y = blk(X)
# print(Y.shape)  # 预期输出形状:torch.Size([4, 3, 6, 6])# 创建一个包含1x1卷积和步幅为2的残差块实例,并测试其输出形状
# blk = Residual(3, 6, use_1x1conv=True, strides=2)
# print(blk(X).shape)  # 预期输出形状:torch.Size([4, 6, 3, 3])# 定义一个包含初始卷积层、批量归一化层、ReLU激活函数和最大池化层的顺序容器
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)# 定义一个函数,用于创建由多个残差块组成的模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):# 如果是第一个残差块且不是第一个模块,则使用1x1卷积和步幅为2if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 创建由残差块组成的各个模块
# *符号有多种用途,但在函数调用时,*符号主要用于将列表或元组解包。
# *resnet_block()的作用是将列表中的元素逐个传递给nn.Sequential
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 创建整个ResNet模型
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),  # 自适应平均池化层nn.Flatten(),  # 展平层nn.Linear(512, 176)  # 全连接层,输出10类
)# 测试整个网络的输出形状
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:    torch.Size([1, 512])
# Linear output shape:     torch.Size([1, 10])# 设置训练参数
lr, num_epochs, batch_size = 0.05, 10, 256
# 加载训练和测试数据
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=96)
# 训练模型
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# 显示训练结果
plt.show()# loss 0.009, train acc 0.998, test acc 0.920
# 2306.3 examples/sec on cuda:0

运行结果:
在这里插入图片描述

http://www.lryc.cn/news/386535.html

相关文章:

  • 解锁音乐潮流:使用TikTok API获取平台音乐信息
  • 基于yolo的物体识别坐标转换
  • STM32第七课:KQM6600空气质量传感器
  • 任务4.8.4 利用Spark SQL实现分组排行榜
  • 五线谱与简谱有什么区别 五线谱简谱混排怎么打 吉他谱软件哪个好
  • [C#][opencvsharp]C#使用opencvsharp进行年龄和性别预测支持视频图片检测
  • pdf拆分,pdf拆分在线使用,pdf拆分多个pdf
  • VScode Python debug:hydra.run.dir 写入launch.json
  • ExVideo: 提升5倍性能-用于视频合成模型的新型后调谐方法
  • laravel Dcat Admin 入门应用(三)Grid 之 Column
  • 掌握Llama 2分词器:填充、提示格式及更多
  • pdf合并,pdf合并成一个pdf,pdf合并在线网页版
  • 算法基础--------【图论】
  • x86和x64架构的区别及应用
  • 2024年度总结:不可错过的隧道IP网站评估推荐
  • Linux下VSCode的安装和基本使用
  • C# 实现websocket双向通信
  • Spring Boot结合FFmpeg实现视频会议系统视频流处理与优化
  • 扫扫地,搞搞卫生 ≠ 车间5S管理
  • ES(笔记)
  • 开箱即用的fastposter海报生成器
  • 力扣每日一题 6/28 动态规划/数组
  • [数据集][目标检测]游泳者溺水检测数据集VOC+YOLO格式8275张4类别
  • 若依 ruoyi 分离版 vue 简单的行内编辑实现
  • 【工具】API文档生成DocFX
  • 在 JavaScript 中处理异步操作和临时事件处理程序
  • [Cocos Creator] v3.8开发知识点记录(持续更新)
  • Excel_VBA编程
  • Java中的Path类使用详解及最佳实践
  • 生成和查看预定义宏