当前位置: 首页 > news >正文

基于深度可分离卷积的MNIST手势识别

基于深度可分离膨胀卷积的MNIST手写体识别

Github链接

项目背景
MNIST手写体数据集是深度学习领域中一个经典的入门数据集,包含了从0到9的手写数字图像,用于评估不同模型在图像分类任务上的性能。在本项目中,我们通过设计一种基于深度可分离膨胀卷积的神经网络模型,解决模型参数量大与特征提取能力不足之间的矛盾,同时实现对MNIST手写数字的高效识别。

核心技术

  1. 深度卷积
    深度卷积(Depthwise Convolution)将标准卷积分解为每个通道独立的卷积操作,从而显著减少计算量。相比传统卷积操作,深度卷积只需要处理每个通道的卷积,不会引入通道间的冗余计算。
  2. 点卷积
    点卷积(Pointwise Convolution)采用1×1的卷积核,用于整合深度卷积生成的特征,将通道间信息融合,增强表达能力。点卷积是深度可分离卷积不可或缺的部分,负责重建多通道的全局信息。
  3. 膨胀卷积
    膨胀卷积(Dilated Convolution)通过在卷积核间插入空洞扩展感受野,允许模型捕获更大范围的上下文信息,特别适合处理具有稀疏特征的任务,同时避免了增加参数量的开销。

在这里插入图片描述

实现流程

  1. 数据准备
    • 加载MNIST数据集,进行标准化预处理。
    • 将数据分为训练集和测试集,保证模型的评估结果具备可靠性。
  2. 模型设计
    • 构建基于深度卷积、点卷积及膨胀卷积的神经网络结构,重点在于设计轻量化且具有良好表达能力的卷积模块。
    • 使用ReLU激活函数和全连接层对提取的特征进行分类处理。
  3. 模型训练与测试
    • 使用交叉熵损失函数和Adam优化器训练模型,记录损失值变化以监控收敛情况。
    • 测试阶段,评估模型在MNIST测试集上的分类准确率,并验证其泛化能力。
  4. 参数量对比分析
    • 对比标准卷积和深度可分离卷积在参数量上的差异,直观展示优化效果。
    • 在同等条件下,深度可分离卷积显著减少参数量,同时保持了分类性能的稳定性。

项目成果
通过本项目的实验,模型在MNIST数据集上的分类准确率达到了接近 98.7% 的水平。结合深度可分离膨胀卷积的轻量化设计,我们在大幅减少参数量的同时,实现了与传统卷积模型媲美的性能。此方法为资源受限的场景(如移动设备和嵌入式系统)提供了一种有效的解决方案。

结论与展望
本项目展示了深度可分离膨胀卷积在图像分类任务中的强大能力,特别是在参数量和计算量优化方面。未来的工作可以尝试将该方法应用于更复杂的数据集和任务场景,例如自然图像分类、目标检测或语义分割,从而进一步验证其通用性与潜力。

# @Time    : 28/12/2024 上午 10:00
# @Author  : Xuan
# @File    : 基于深度可分离膨胀卷积的MNIST手写体识别.py
# @Software: PyCharmimport torch
import torch.nn as nn
import einops.layers.torch as elt
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 下载并加载训练集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 下载并加载测试集
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)# 检查数据集大小
print(f'Train dataset size: {len(train_dataset)}') # Train dataset size: 60000
print(f'Test dataset size: {len(test_dataset)}') # Test dataset size: 10000class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 深度、可分离、膨胀卷积self.conv = nn.Sequential(nn.Conv2d(1, 12, kernel_size=7),nn.ReLU(),nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, groups=6, dilation=2),nn.Conv2d(in_channels=12, out_channels=24, kernel_size=1),nn.ReLU(),nn.Conv2d(24, 6, kernel_size=7),)self.logits_layer = nn.Linear(in_features=6 * 12 * 12, out_features=10)def forward(self, x):x = self.conv(x)x = elt.Rearrange('b c h w -> b (c h w)')(x)logits = self.logits_layer(x)return logitsdevice = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for epoch in range(epochs):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')# Save the model
torch.save(model.state_dict(), './modelsave/mnist.pth')# Load the model
model = Model().to(device)
model.load_state_dict(torch.load('./modelsave/mnist.pth'))# Evaluation
model.eval()
correct = 0
first_image_displayed = False
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()# Display the first image and its predictionif not first_image_displayed:plt.imshow(data[0].cpu().numpy().squeeze(), cmap='gray')plt.title(f'Predicted: {pred[0].item()}')plt.show()first_image_displayed = Trueprint(f'Test set: Accuracy: {correct / len(test_loader.dataset):.4f}') # Test set: Accuracy: 0.9874# 深度可分离卷积参数比较
# 普通卷积参数量
conv = nn.Conv2d(in_channels=3 ,out_channels=3, kernel_size=3) # in(3) * out(3) * k(3) * k(3) + out(3) = 84
conv_params = sum(p.numel() for p in conv.parameters())
print('conv_params:', conv_params) # conv_params: 84# 深度可分离卷积参数量
depthwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, groups=3) # in(3) * k(3) * k(3) + out(3) = 30
pointwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1) # in(3) * out(3) * k(1) * k(1) + out(3) = 12
depthwise_separable = nn.Sequential(depthwise, pointwise)
depthwise_separable_params = sum(p.numel() for p in depthwise_separable.parameters())
print('depthwise_separable_params:', depthwise_separable_params) # depthwise_separable_params: 42

在这里插入图片描述

http://www.lryc.cn/news/511188.html

相关文章:

  • Linux服务器pm2 运行chatgpt-on-wechat,搭建微信群ai机器人
  • Word批量更改题注
  • Springboot配置嵌入式服务器
  • 正交三角函数全面阐述
  • 《Vue3 四》Vue 的组件化
  • linux中,mysql数据库分片(分库分表)
  • springboot503基于Sringboot+Vue个人驾校预约管理系统(论文+源码)_kaic
  • Docker应用-项目部署及DockerCompose
  • 从0入门自主空中机器人-2-1【无人机硬件框架】
  • Kafka高性能设计
  • Redis字符串底层结构对数值型的支持常用数据结构和使用场景
  • uniapp 微信小程序 数据空白展示组件
  • 在vscode的ESP-IDF中使用自定义组件
  • 目标检测,语义分割标注工具--labelimg labelme
  • 发明专利与实用新型专利申请过程及自助与代办方式对比
  • Datawhale AI冬令营(第二期)动手学AI Agent task2--学Prompt工程,优化Agent效果
  • 基于python对网页进行爬虫简单教程
  • 【JavaEE进阶】@RequestMapping注解
  • 【WebAR-图像跟踪】在Unity中基于Imagine WebAR实现AR图像识别
  • 向bash shell脚本传参
  • Oracle中listagg与wm_concat函数的区别
  • 热更新与资源管理
  • Momentum Provably Improves Error Feedback!
  • Elasticsearch-脚本查询
  • 《Opencv》基础操作详解(3)
  • meshy的文本到3d的使用
  • C语言技巧之有条件的累加
  • 解释为什么fetch(JavaScript)无法将读取的数据存入外部变量
  • Windows Subsystem for Linux (WSL)
  • Go的Slice如何扩容