当前位置: 首页 > news >正文

机器学习基础-手写数字识别

  1. 手写数字识别,计算机视觉领域的Hello World
  2. 利用MNIST数据集,55000训练集,5000验证集。
  3. Pytorch实现神经网络手写数字识别
  4. 感知机与神经元、权重和偏置、神经网络、输入层、隐藏层、输出层
  5. mac gpu的使用
  6. 本节就是对Pytorch可以做的事情有个直观的理解,先理解表面,把大概知识打通,然后再研究细节的东西
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
# Check that MPS is available
if not torch.backends.mps.is_available():if not torch.backends.mps.is_built():print("MPS not available because the current PyTorch install was not ""built with MPS enabled.")else:print("MPS not available because the current MacOS version is not 12.3+ ""and/or you do not have an MPS-enabled device on this machine.")
else:device = torch.device("mps")
class Net(nn.Module):def __init__(self):super().__init__()# 28*28 = 784为输入,100为输出self.fcl = nn.Linear(784,100)self.fc2 = nn.Linear(100,10)def forward(self,x):x = torch.flatten(x,start_dim = 1)x = torch.relu(self.fcl(x))x = self.fc2(x)return x
# 当前模型对数据集学几次
max_epochs = 5
# 每次训练模型对多少张图片进行训练
batch_size = 16# data
# ToTensor 把当前数据类型转换为 Tensor
# Compose是组合多个转换操作的类
transform = transforms.Compose([transforms.ToTensor()])# 55000
trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transform)
train_loader = torch.utils.data.DataLoader(trainset,batch_size=batch_size,shuffle=True)
testset = torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=transform)
test_loader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=True)
# net init
net = Net()
net.to(device)# nn.MSE
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.0001)def train():acc_num=0for epoch in range(max_epochs):for i,(data,label) in enumerate(train_loader):data = data.to(device)label = label.to(device)optimizer.zero_grad()output = net(data)Loss = loss(output,label)Loss.backward()optimizer.step()pred_class = torch.max(output,dim=1)[1]acc_num += torch.eq(pred_class,label.to(device)).sum().item()train_acc = acc_num / len(trainset)net.eval()acc_num = 0.0best_acc = 0with torch.no_grad():for val_data in test_loader:val_image,val_label = val_dataoutput = net(val_image.to(device))predict_y = torch.max(output , dim=1)[1]acc_num += torch.eq(predict_y,val_label.to(device)).sum().item()val_acc = acc_num/len(testset)print(train_acc,val_acc)if val_acc > best_acc:torch.save(net.state_dict(),'./minst.pth')best_acc = val_accacc_num = 0train_acc = 0test_acc = 0print('done')train()
0.1348 0.3007
done
0.4361 0.5548
done
0.5870666666666666 0.6335
done
0.6435333333333333 0.672
done
0.67915 0.7011
done
http://www.lryc.cn/news/188868.html

相关文章:

  • idea 插件推荐(持续更新)
  • 实现Promise所有核心功能和方法
  • 学习总结1
  • 使用 Apache Camel 和 Quarkus 的微服务(二)
  • pid-limit参数实验
  • jvm--执行引擎
  • day13|二叉树理论
  • php+html+js+ajax实现文件上传
  • 日期时间参数,格式配置(SpringBoot)
  • JAVA 泛型的定义以及使用
  • Day-08 基于 Docker安装 Nginx 镜像-负载均衡
  • 3、在 CentOS 8 系统上安装 PostgreSQL 15.4
  • sap 一次性供应商 供应商账户组 临时供应商 <转载>
  • 总结html5中常见的选择器
  • Java基础面试-JDK JRE JVM
  • OpenCV实现图像傅里叶变换
  • 快手新版本sig3参数算法还原
  • Linux 安全 - LSM机制
  • uni-app:实现简易自定义下拉列表
  • 排序算法——直接插入排序
  • 手动抄表和自动抄表优缺点对比
  • HiSilicon352 android9.0 emmc添加新分区
  • networkX-04-查找k短路
  • Linux虚拟机搭建RabbitMQ集群
  • C之fopen/fclose/fread/fwrite/flseek
  • 3D机器视觉:解锁未来的立体视野
  • 大端字节序存储 | 小端字节序存储介绍
  • ASP.Core3.1 WebAPI 发布到IIS
  • MyBatisPlus属性自动填充和乐观锁插件+查询删除操作+整合SpringBoot出现问题解决
  • 软件测试/测试开发丨App自动化—CSS 定位与原生定位