当前位置: 首页 > news >正文

实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as npimport einops.layers.torch as elt#载入数据x_train = np.load("../dataset/mnist/x_train.npy")y_train_label = np.load("../dataset/mnist/y_train_label.npy")x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()#前置的特征提取模块self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层nn.ReLU(),nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层nn.ReLU(),nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层)#最终分类器层self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),nn.ReLU(),nn.Conv2d(12,24,kernel_size=5),nn.ReLU(),nn.Conv2d(24,6,kernel_size=3))self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):train_num = len(x_train)//128train_loss = 0.for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizex_batch = torch.tensor(x_train[start:end]).to(device)y_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(x_batch)loss = loss_fn(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

http://www.lryc.cn/news/151757.html

相关文章:

  • Ubuntu开启生成Core Dump的方法
  • git视频教程Jenkins持续集成视频教程Git Gitlab Sonar教程
  • 机器学习:Xgboost
  • 《Kubernetes部署篇:Ubuntu20.04基于二进制安装安装cri-containerd-cni》
  • [CISCN 2019初赛]Love Math
  • 运行命令出现错误 /bin/bash^M: bad interpreter: No such file or directory
  • 码农重装系统后需要安装的软件
  • Kotlin return 和 loop jump
  • 计算一组数据中的低中位数即如果一组数据中有两个中位数则较小的那个为低中位数statistics.median_low()
  • ChatGPT是否能够协助人们提高公共服务和社区建设能力?
  • 机器人中的数值优化(七)——修正阻尼牛顿法
  • 程序员自由创业周记#3:No1.作品
  • 固定资产制度怎么完善管理?
  • 神经网络--感知机
  • Java“牵手”1688图片识别商品接口数据,图片地址识别商品接口,图片识别相似商品接口,1688API申请指南
  • 科技资讯|微软获得AI双肩包专利,Find My防丢背包大火
  • 数学建模:多目标优化算法
  • arcmap 在oracle删除表重新创建提示表名存在解决放啊
  • 新版HBuilderX在uni_modules创建搜索search组件
  • Ubutnu允许ssh连接使用root与密码登录
  • MySQL中表的设计
  • UE4/5在蓝图细节面板中添加函数按钮(蓝图与c++的方法)
  • Python爬虫乱码问题之encoding和apparent_encoding的区别
  • Docker技术--Docker简介和架构
  • 废品回收功能文档
  • 【ARMv8 SIMD和浮点指令编程】NEON 乘法指令——asimdrdm
  • [SWPUCTF 2022]——Web方向 详细Writeup
  • Shell编程:流程控制与高级应用的深入解析
  • 一文讲通嵌入式现状
  • 设计模式-代理模式Proxy