当前位置: 首页 > news >正文

PyTorch入门之【CNN】

参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
书接上回的MLP故本章就不详细解释了

目录

  • train
  • test

train

import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn# load MNIST dataset
training_data = datasets.MNIST(root='../02_dataset/data',train=True,download=True,transform=ToTensor()
)train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()# train the model
num_epochs = 20for epoch in range(num_epochs):print(f'Epoch {epoch+1}\n-------------------------------')for idx, (img, label) in enumerate(train_data_loader):size = len(train_data_loader.dataset)img, label = img.to(device), label.to(device)# compute prediction errorpred = cnn(img)loss = loss_fn(pred, label)# backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if idx % 400 == 0:loss, current = loss.item(), idx*len(img)print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')

test

import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn# load test data
test_data = datasets.MNIST(root='../02_dataset/data',train=False,download=True,transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)transform = transforms.Compose([transforms.Grayscale(),transforms.RandomRotation(10),transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0with torch.no_grad():for img, label in test_data_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0with torch.no_grad():for img, label in my_mnist_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
http://www.lryc.cn/news/183643.html

相关文章:

  • 马斯洛需求层次模型之安全需求之云安全浅谈
  • Pikachu靶场——远程命令执行漏洞(RCE)
  • 【WSN】无线传感器网络 X-Y 坐标到图形视图和位字符串前缀嵌入方法研究(Matlab代码实现)
  • Linux定时任务
  • 【Overload游戏引擎分析】画场景网格的Shader
  • 【JavaEE】多线程进阶(一)饿汉模式和懒汉模式
  • C++树详解
  • 支付环境安全漏洞介绍
  • 抄写Linux源码(Day16:内存管理)
  • Cookie和Session详解以及结合生成登录效果
  • Spring基础以及核心概念(IoC和DIQ)
  • 《C和指针》笔记32:多维数组初始化
  • 零食食品经营小程序商城的作用是什么
  • Java泛型--什么是泛型?
  • LabVIEW工业虚拟仪器的标准化实施
  • JavaScript系列从入门到精通系列第十七篇:JavaScript中的全局作用域
  • 汇编指令集合
  • TinyWebServer整体流程
  • 【Java项目推荐之黑马头条】自媒体文章实现异步上下架(使用Kafka中间件实现)
  • 自学(黑客)技术方法————网络安全
  • python+playwright 学习-84 Response 接口返回对象
  • GCN详解
  • 总结二:linux面经
  • 12、【Qlib】【主要组件】Qlib Recorder:实验管理
  • 三一充填泵:煤矿矸石无害化充填,煤炭绿色高效开采的破局利器
  • 医疗器械标准目录汇编2022版共178页(文中附下载链接!)
  • C#和Excel文件的读写交互
  • Pytorch目标分类深度学习自定义数据集训练
  • 2023 年 Web 安全最详细学习路线指南,从入门到入职(含书籍、工具包)【建议收藏】
  • qt常用控件1