当前位置: 首页 > news >正文

【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()  # 将模型设置为训练模式output = net(data)  # 使用模型进行前向传播loss = criterion(output, target)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数right = accuracy(output, target)  # 计算当前批次的准确率train_rights.append(right)  # 将准确率保存起来if batch_idx % 500 == 0:  # 每500个批次进行一次验证net.eval()  # 将模型设置为评估模式val_rights = []  # 存储验证集的准确率for (data, target) in test_loader:  # 在测试集上进行验证output = net(data)  # 使用模型进行前向传播right = accuracy(output, target)  # 计算验证集上的准确率val_rights.append(right)  # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

http://www.lryc.cn/news/268720.html

相关文章:

  • Flink1.17实战教程(第三篇:时间和窗口)
  • CSS 纵向扩展动画
  • Android 12 Token 机制
  • TCP与UDP是流式传输协议吗?
  • 61 贪心算法解救生艇问题
  • C#高级 01.Net多线程
  • Java---泛型讲解
  • 【论文阅读笔记】SegVol: Universal and Interactive Volumetric Medical Image Segmentation
  • Unix/Linux操作系统介绍
  • 什么是https证书?
  • C++ DAY2作业
  • RabbitMQ核心概念记录
  • 算法时间空间复杂度计算—空间复杂度
  • 计算机专业校招常见面试题目总结
  • 网络编程『简易TCP网络程序』
  • java itext5 生成PDF并填充数据导出
  • 如何配置TLSv1.2版本的ssl
  • 在CentOS 7上使用普通用户`minio`安装和配置MinIO
  • Vue3-27-路由-路径参数的简单使用
  • w7数据库基础之mysql函数
  • 智能优化算法应用:基于人工蜂鸟算法3D无线传感器网络(WSN)覆盖优化 - 附代码
  • Docker的基础使用
  • Sass(Scss)、Less的区别与选择 + 基本使用
  • GPT Zero 是什么?
  • c++学习笔记-提高篇-案例2-员工分组(vector/multimap)
  • TrustZone之问答
  • vue3中新增的组合式API:ref、reactive、toRefs、computed、watch、provide/inject、$ref
  • Flask 密码重设系统
  • HarmonyOS4.0开发应用(四)【ArkUI状态管理】
  • JS常见正则表达式写法(附案例)