当前位置: 首页 > news >正文

Pytorch笔记之分类

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、模型评估
  • 总结


前言

使用Pytorch进行MNIST分类,使用TensorDataset与DataLoader封装、加载本地数据集。


一、导入库

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader # 数据集工具
from load_mnist import load_mnist # 本地数据集

二、数据处理

1、导入本地数据集,将标签值设置为int类型,构建张量
2、使用TensorDataset与DataLoader封装训练集与测试集

# 构建数据
x_train, y_train, x_test, y_test = \load_mnist(normalize=True, flatten=False, one_hot_label=False)
# 数据处理
x_train = torch.from_numpy(x_train.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.int64))
x_test = torch.from_numpy(x_test.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.int64))
# 数据集封装
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
batch_size = 64
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

三、构建模型

输入到全连接层之前需要把(batch_size,28,28)展平为(batch_size,784)
交叉熵损失函数整合了Softmax,在模型中可以不添加Softmax

# 继承模型
class FC(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):y = self.fc1(x.view(x.shape[0],-1))y = self.softmax(y)return y
# 定义模型
model = FC()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

四、迭代训练

从DataLoader中取出x和y,进行前向和反向的计算

for epoch in range(10):print('Epoch:', epoch)for i,data in enumerate(train_loader):x, y = datay_pred = model.forward(x)loss = loss_function(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()

五、模型评估

在测试集中进行验证
使用.item()获得tensor的取值

	correct = 0for i,data in enumerate(test_loader):x, y = datay_pred = model.forward(x)_, y_pred = torch.max(y_pred, 1)correct += (y_pred == y).sum().item()acc = correct / len(test_dataset)print('Accuracy:{:.2%}'.format(acc))


总结

记录了TensorDataset与DataLoader的使用方法,模型的构建与训练和上一篇Pytorch笔记之回归相似。

http://www.lryc.cn/news/183761.html

相关文章:

  • 【目标检测】——PE-YOLO精读
  • Java 数组转集合
  • Elasticsearch:ES|QL 查询语言简介
  • qt qml中listview出现卡顿情况时的常用处理方法
  • Elasticsearch基础操作演示总结
  • Spring 作用域解析器AnnotationScopeMetadataResolver
  • 如何发布一个 NPM 包
  • Flask小项目教程(含MySQL与前端部分)
  • Eureka
  • STM32G070RBT6-MCU温度测量(ADC)
  • 数据结构之带头双向循环链表
  • adb详细教程(四)-使用adb启动应用、关闭应用、清空应用数据、获取设备已安装应用列表
  • 【Spring Boot】日志文件
  • 图像处理与计算机视觉--第五章-图像分割-Canny算子
  • LabVIEW开发教学实验室自动化INL和DNL测试系统
  • 数据结构: 数组与链表
  • unity 控制玩家物体
  • 指数分布优化器(EDO)(含MATLAB代码)
  • Java 时间的加减处理
  • 基于A4988/DRV8825的四路步进电机驱动器
  • 万字总结网络原理
  • 【AI视野·今日CV 计算机视觉论文速览 第262期】Fri, 6 Oct 2023
  • 一文搞懂Jenkins持续集成解决的是什么问题
  • 微信小程序去除默认滚动条展示
  • 3.02 创建订单操作详细-订单创建与回滚 (创建订单操作详细)
  • 需求放缓、价格战升级、利润率持续恶化对小鹏汽车造成了严重影响
  • 《算法通关之路》chapter19解题技巧和面试技巧
  • 什么是TF-A项目的长期支持?
  • 【LinuxC】时间、时区,相关命令、函数
  • mac清理垃圾的软件有哪些?这三款我最推荐