当前位置: 首页 > news >正文

PyTorch VGG16手写数字识别教程

手写数字识别教程:使用PyTorch和VGG16

1. 环境准备

确保你已安装以下库:

pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理

我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。

transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图像大小调整为224x224transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化处理,均值和标准差
])# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型

VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。

class VGG16(nn.Module):def __init__(self):super(VGG16, self).__init__()# 定义卷积层self.vgg = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 将输入通道设置为1(灰度图)nn.ReLU(),  # 激活函数nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层,减小特征图尺寸nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(256, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),  # 第一个全连接层nn.ReLU(),nn.Dropout(),  # 随机失活,防止过拟合nn.Linear(4096, 4096),  # 第二个全连接层nn.ReLU(),nn.Dropout(),nn.Linear(4096, 10)  # 输出层,10个类(数字0-9))def forward(self, x):x = self.vgg(x)  # 通过卷积层x = x.view(x.size(0), -1)  # 展平特征图x = self.classifier(x)  # 通过全连接层return x
5. 训练模型

我们将使用交叉熵损失函数和Adam优化器,并训练模型。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测可用的设备
model = VGG16().to(device)  # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练循环
for epoch in range(5):  # 训练5个epochmodel.train()  # 设置为训练模式for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')  # 输出当前epoch的损失
6. 测试模型

在测试阶段,我们将计算模型的准确率。

model.eval()  # 设置为评估模式
with torch.no_grad():  # 禁用梯度计算correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 移动到设备outputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)  # 统计总样本数correct += (predicted == labels).sum().item()  # 统计正确预测的数量print(f'Accuracy: {100 * correct / total:.2f}%')  # 输出准确率

总结

这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!

http://www.lryc.cn/news/444985.html

相关文章:

  • 安卓13删除下拉栏中的设置按钮 android13删除设置按钮
  • FDA辅料数据库在线免费查询-药用辅料
  • git pull 报错 refusing to merge unrelated histories
  • STM32G431RBT6(蓝桥杯)串口(发送)
  • 使用 typed-rest-client 进行 REST API 调用
  • 在Ubuntu 14.04上安装Solr的方法
  • LabVIEW提高开发效率技巧----使用LabVIEW工具
  • Pyspark dataframe基本内置方法(4)
  • 配置win10开电脑时显示可登录账号策略
  • 01-Mac OS系统如何下载安装Python解释器
  • 24 C 语言常用的字符串处理函数详解:strlen、strcat、strcpy、strcmp、strchr、strrchr、strstr、strtok
  • 数据驱动农业——农业中的大数据
  • 学习《分布式》必须清楚的《CAP理论》
  • navicat无法连接远程mysql数据库1130报错的解决方法
  • JetPack01- LifeCycle 监听Activity或Fragment的生命周期
  • OpenCSG推出StarShip SecScan:AI驱动的软件安全革新
  • 占道经营检测-目标检测数据集(包括VOC格式、YOLO格式)
  • 828华为云征文 | 云服务器Flexus X实例:RAG 开源项目 FastGPT 部署,玩转大模型
  • MySQL之基本查询(一)(insert || select)
  • 基于深度学习的多智能体协作
  • Nmap网络扫描器基础功能介绍
  • idea 编辑器常用插件集合
  • 如何优化Java商城系统的代码结构
  • 两数之和、三数之和、四数之和
  • 这几个方法轻松压缩ppt文件大小,操作起来很简单的压缩PPT方法
  • 【nvm管理多版本node】下载安装以及常见问题和解决方案
  • C++(学习)2024.9.23
  • 大数据处理从零开始————3.Hadoop伪分布式和分布式搭建
  • 跟着问题学12——GRU详解
  • 内核是如何接收网络包的