Pytorch-MLP-CIFAR10

Table of Contents

  • model.py
  • main.py
  • Parameter Settings
  • Notes
  • Run Output

model.py

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class MLP_cls(nn.Module):
    def __init__(self, in_dim=3*32*32):
        super(MLP_cls, self).__init__()
        self.lin1 = nn.Linear(in_dim, 128)
        self.lin2 = nn.Linear(128, 64)
        self.lin3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        # Xavier initialization for all three linear layers
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self, x):
        x = x.view(-1, 3*32*32)   # flatten each 3x32x32 image into a 3072-dim vector
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)          # return raw logits; CrossEntropyLoss applies softmax internally
        return x
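As a quick sanity check (not part of the original post, just a hedged sketch assuming the class above is saved as model.py), you can instantiate the model and push a dummy batch through it to confirm that each image comes out as 10 class logits:

import torch
from model import MLP_cls              # assumes the class above is saved as model.py

net = MLP_cls()
dummy = torch.randn(8, 3, 32, 32)      # a fake batch of 8 CIFAR-10-sized images
logits = net(dummy)
print(logits.shape)                    # torch.Size([8, 10]) -- one logit per class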

main.py

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls

seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = MLP_cls()

train_loader = DataLoader(
    torchvision.datasets.CIFAR10(
        './data/', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # CIFAR-10 images have 3 channels, so give a per-channel mean/std
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])),
    batch_size=batch_size_train, shuffle=True)

test_loader = DataLoader(
    torchvision.datasets.CIFAR10(
        './data/', train=False, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])),
    batch_size=batch_size_test, shuffle=False)

optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    run_loss = 0.0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = net(data)
        _, pred = torch.max(out, dim=1)
        optimizer.zero_grad()
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        run_loss += loss.item()                      # accumulate a Python float, not the graph-carrying tensor
        correct_num += torch.sum(pred == target).item()
    print('epoch', epoch,
          'loss {:.2f}'.format(run_loss / len(train_loader)),
          'accuracy {:.2f}'.format(correct_num / len(train_loader.dataset)))

print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
with torch.no_grad():                                # no gradients needed for evaluation
    for batch_idx, (data, target) in enumerate(test_loader):
        out = net(data)
        _, pred = torch.max(out, dim=1)
        test_loss += criterion(out, target).item()
        test_correct_num += torch.sum(pred == target).item()

print('loss {:.2f}'.format(test_loss / len(test_loader)),
      'accuracy {:.2f}'.format(test_correct_num / len(test_loader.dataset)))
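The script above keeps everything on the CPU and throws the trained weights away when it exits. If you want to reuse the model, a minimal follow-up (an assumption, not in the original code; the filename mlp_cifar10.pth is made up) is to save and reload the state dict after testing:

# Save the trained weights, then restore them into a fresh model instance.
torch.save(net.state_dict(), 'mlp_cifar10.pth')

net_restored = MLP_cls()
net_restored.load_state_dict(torch.load('mlp_cifar10.pth'))
net_restored.eval()                  # switch to evaluation mode before inference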

Parameter Settings

'./data/'  # path where the dataset is downloaded / cached
seed = 42  # random seed
batch_size_train = 64
batch_size_test  = 64
epochs = 10
optim --> SGD
learning_rate = 0.01
momentum = 0.5
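If you prefer not to hard-code these values, a small command-line wrapper is one option; the sketch below is only an assumption (the original script hard-codes everything) and uses Python's standard argparse module with the same defaults:

import argparse

def get_args():
    # Hypothetical CLI front-end mirroring the hard-coded settings above.
    parser = argparse.ArgumentParser(description='MLP classifier on CIFAR-10')
    parser.add_argument('--data-dir', default='./data/', help='dataset download / cache path')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size-train', type=int, default=64)
    parser.add_argument('--batch-size-test', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--momentum', type=float, default=0.5)
    return parser.parse_args()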

Notes

CIFAR-10 images are color images, each of size 3*32*32, so the input has to be flattened to a 3*32*32 = 3072-dimensional vector with view before the first linear layer.
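A quick shape check (a standalone sketch, not taken from the original scripts) makes the flattening concrete: a batch of CIFAR-10 tensors has shape (N, 3, 32, 32), and view(-1, 3*32*32) turns it into (N, 3072) rows that nn.Linear(3*32*32, 128) can consume:

import torch

x = torch.randn(64, 3, 32, 32)    # a dummy batch of 64 CIFAR-10 images
flat = x.view(-1, 3 * 32 * 32)    # flatten channels and pixels into one vector per image
print(flat.shape)                 # torch.Size([64, 3072])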

Run Output

[Figure: console output from the training and test runs]

