当前位置: 首页 > news >正文

动手学深度学习—卷积神经网络LeNet(代码详解)

1. LeNet

LeNet由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述

  1. 每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层;
  2. 每个卷积层使用5×5卷积核和一个sigmoid激活函数;
  3. 这些层将输入映射到多个二维特征输出,通常同时增加通道的数量;
  4. 每个4×4池操作(步幅2)通过空间下采样将维数减少4倍。
import torch
from torch import nn
from d2l import torch as d2l# 定义模型net
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

该模型去掉了最后一层的高斯激活,下面将一个大小为28×28的单通道(黑白)图像通过LeNet,打印每一层输出的形状。

# 观察各层的输入输出通道数,宽度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

  1. 第一个卷积层使用2个像素的填充,来补偿5×5卷积核导致的特征减少;
  2. 第二个卷积层没有填充,因此高度和宽度都减少了4个像素;
  3. 随着层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个;
  4. 每个汇聚层的高度和宽度都减半;
  5. 每个全连接层减少维数,最终输出一个维数与结果分类数相匹配的输出。

2. 模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

3. 小结

  1. 卷积神经网络(CNN)是一类使用卷积层的网络;
  2. 卷积神经网络中,可以组合使用卷积层、非线性激活函数和汇聚层;
  3. 为了构造高性能的卷积神经网络,通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数;
  4. 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。
http://www.lryc.cn/news/125746.html

相关文章:

  • 腾讯面经总结
  • matlab机器人工具箱基础使用
  • 利用WonderLeak进行内存泄露检测【一】
  • 二刷LeetCode--155. 最小栈(C++版本),思维题
  • 进程的状态与转换
  • 用MariaDB创建数据库,SQL练习,MarialDB安装和使用
  • 【Docker】 使用Docker-Compose 搭建基于 WordPress 的博客网站
  • Hlang社区-前端社区宣传首页实现
  • 【LeetCode-Medium】833. 字符串中的查找与替换
  • 数据结构中公式前中后缀表达式-二叉树应用
  • Visual Studio 2022连接远程系统进行C/C++开发
  • TiDB数据库从入门到精通系列之二:TiDB数据库的简介
  • opencv视频截取每一帧并保存为图片python代码CV2实现练习
  • 虹科方案 | 汽车总线协议转换解决方案(二)
  • [Android] 通过JNI 让 JAVA 调用 android native 接口
  • MySQL高可用MHA
  • DoIP学习笔记系列:(五)“安全认证”的.dll从何而来?
  • 205、仿真-51单片机直流数字电流表多档位切换Proteus仿真设计(程序+Proteus仿真+原理图+流程图+元器件清单+配套资料等)
  • 服务器如何防止cc攻击
  • 解读注解@Value占位符替换过程
  • 浅谈5G技术会给视频监控行业带来的一些变革情况
  • Java常用API---快速达到Java工作水准系列(1)
  • Python中使用隧道爬虫ip提升数据爬取效率
  • 深入源码分析kubernetes informer机制(四)DeltaFIFO
  • UI设计师个人工作总结范文
  • explicit关键字 和 static成员
  • 安装Linux操作系统CentOS 6详细图文步骤
  • 新增守护进程管理、支持添加MySQL远程数据库,支持PHP版本切换,1Panel开源面板v1.5.0发布
  • 十、接口(1)
  • percentile_approx 聚合函数