当前位置: 首页 > news >正文

交叉熵Loss多分类问题实战(手写数字)

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as Ftrain_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() +b1x = F.relu(x)x = x@w2.t() +b2x = F.relu(x)x = x@w3.t() +b3x = F.relu(x)return xoptimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criterion(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if (batch_idx+1) % 150 == 0:print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(epoch, (batch_idx+1) * len(data), len(train_loader.dataset),100. * (batch_idx+1) / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28*28)logits = forward(data)test_loss += criterion(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader)print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
http://www.lryc.cn/news/189322.html

相关文章:

  • 如何看待Unity新的收费模式?(InsCode AI 创作助手)
  • Android Studio git 取消本地 commit(未Push)
  • ViewModifier/视图修饰符, ButtonStyle/按钮样式 的使用
  • 科技资讯|微软AR眼镜新专利曝光,可拆卸电池解决续航焦虑
  • idea系列---【上一次打开springboot项目还好好的,现在打开突然无法启动了】
  • 查询资源消耗
  • conda: error: argument COMMAND: invalid choice: ‘activate‘
  • 新鲜速递:Spring Cloud Alibaba环境在Spring Boot 3时代的快速搭建
  • 网络-网络状态网络速度
  • ACL访问控制列表的解析和配置
  • 记一次使用vue-markdown在vue中解析markdown格式文件,并自动生成目录大纲
  • 力扣每日一题35:搜索插入的位置
  • Iptabels的相关描述理解防火墙的必读文章
  • Maven 构建项目测试
  • 机器学习 - 似然函数:概念、应用与代码实例
  • LeetCode 热题 100-49. 字母异位词分组
  • TensorFlow入门(十九、softmax算法处理分类问题)
  • 刷题用到的非常有用的函数c++(持续更新)
  • 黑客技术(网络安全)——自学思路
  • lNmp安装:
  • Fisher辨别分析
  • 【Zookeeper专题】Zookeeper选举Leader源码解析
  • 机器学习之自训练协同训练
  • ubuntu 通过apt-get快速安装 docker
  • C++医院影像科PACS源码:三维重建、检查预约、胶片打印、图像处理、测量分析等
  • 企业聊天应用程序使用 Kubernetes
  • 记录用命令行将项目打包成war包
  • Linux基础知识笔记
  • Laya3.0 入门教程
  • 3D全景虚拟样板间展销系统扩展用户市场范围