当前位置: 首页 > news >正文

一起深度学习

CIFAR-10 卷积神经网络

  • 下载数据集
  • 构建网络
  • 运行测试

下载数据集

 batchsz = 32cifar_train= datasets.CIFAR10('data',train=True,transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()]),download=True)cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)cifar_test= datasets.CIFAR10('data',train=False,transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()]),download=True)cifar_test = DataLoader(cifar_test,batch_size=batchsz,shuffle=True)

构建网络

新建一个lenet5

import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):def __init__(self):super(Lenet5,self).__init__()self.conv_unit = nn.Sequential(# x :[b,3,32,32] => [b,6,]nn.Conv2d(3,6,5,1),  #卷积层#subsamping  池化层nn.AvgPool2d(kernel_size=2,stride=2,padding=0),#nn.Conv2d(6,16,5,1,0),nn.AvgPool2d(kernel_size=2,stride=2,padding=0))#flatten#fc_unitself.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))# self.criten = nn.CrossEntropyLoss()def forward(self,x):bachsz = x.size(0) #获取样本数量x = self.conv_unit(x)x = x.view(bachsz,16*5*5)logits = self.fc_unit(x)  #获取输出标签return logits

运行测试

   device  = torch.device('cuda') #使用gpu运行model = Lenet5().to(device)  #实例化网络criten = nn.CrossEntropyLoss().to(device)  #使用交叉熵optimizer = optim.Adam(model.parameters(),lr=1e-3)  #采用Adam及逆行优化参数for epoch in range(1000):for batchidx,(x,lable) in enumerate(cifar_train):x,lable = x.to(device),lable.to(device)logits = model(x)  #获得预测输出标签值loss = criten(logits,lable) #计算损失值optimizer.zero_grad() #将梯度归零loss.backward() #方向传播optimizer.step() #优化参数print(epoch,loss.item())total_correct = 0total_num = 0model.eval()  with torch.no_grad():  #表示不需要求梯度for x,label in cifar_test:x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)  获取预测值total_correct += torch.eq(pred,label).float().sum().item()total_num += x.size(0)acc = total_correct /total_numprint(epoch,acc)

网络图如下:
在这里插入图片描述

http://www.lryc.cn/news/343418.html

相关文章:

  • servlet-会话(cookie与session)
  • windows11忘记登录密码怎么办?
  • C#里如何设置输出路径,不要net7.0-windows
  • 知名员工上网行为管理系统推荐榜单
  • 第12章 软件测试基础(第三部分)测试类型、测试工具
  • open-vm-tools使用虚机的拷贝/粘切
  • CKEditor编辑器的简单使用方法,取值,赋值
  • 创建一个线程对象需要花费多少内存空间
  • Java -- (part23)
  • 1. C++入门:命名空间及输入输出
  • 【Kotlin】Java三目运算转成 kotlin 表达
  • 如何安全可控地进行内外网跨网络传输文件?
  • Python Json数据解析
  • pyinstaller打包pytorch和transformers程序
  • 西门子数控网络IP设定配置
  • [Unity]备份许可文件
  • 第十五届蓝桥杯省赛大学B组(c++)
  • Python Flask框架(一)初识Flask
  • VS2022 .Net6.0 无法打开窗体设计器
  • Linux学习之高级IO
  • 一分钟了解Polysciences PEI 40K转染试剂的原理
  • Clickhouse IP 函数
  • 【Python】numpy.ptp()
  • The provided password or token is incorrect or your account
  • 常见的shell命令
  • 堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41),framework修改,去除第三方app的倒计时页面
  • 只允许内网访问时,如何设置hosts
  • nature《自然》期刊文献怎么在家查看下载
  • python作业五
  • 经典的设计模式和Python示例(一)