当前位置: 首页 > news >正文

【实战1】手写字识别 Pytoch(更新中)

1 数据集

引用

import torch
from torch import nn ##nn创建神经网络
from torch.utils.data import DataLoader #DataLoader加载数据
from torchvision import datasets #datasets加载数据集
from torchvision.transforms import ToTensor #ToTensor将数据转换为张量

下载数据集 

# 下载数据
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor()
)testing_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor()
)

导入数据

# 加载数据
batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(testing_data, batch_size=batch_size)# 打印第一个批次的数据形状 
for X, y in train_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

使用keras构建神经网络的方法

构建神经网络

定义神经网络类,继承自nn.Moudule

Pytorch使用nn.Sequentail对象创建神经网络

神经网络必须实现前向传播forward()方法

使用nn.Flatten对象将图片转成列列向量【对比Kares ,pytorch是面向对象的思路】

nn.to方法可以让模型运行在不同的硬件上

device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available() else "cpu")# 打印设备信息
print(f"Using {device} device")# 构建神经网络 2层 输入28*28,第一层128个神经元,激活函数relu,第二层10个神经元
class NeunalNetwork(nn.Module):def __init__(self):super().__init__()self.nmodel = nn.Sequential(nn.Linear(28*28, 128), # 输入层到隐藏层 nn.ReLU(), # 激活函数nn.Linear(128, 10) # 隐藏层到输出层)def forward(self, x):x = nn.Flatten(1)(x) # 将输入展平output = self.nmodel(x) #对模型训练return outputmodel = NeunalNetwork().to(device)
print(model)

编译神经网络

训练神经网络

评估神经网络

http://www.lryc.cn/news/594394.html

相关文章:

  • RTC外设详解
  • Vuex 核心知识详解:Vue2Vue3 状态管理指南
  • Qt--Widget类对象的构造函数分析
  • 【vue-7】Vue3 响应式数据声明:深入理解 reactive()
  • 2024年青少年信息素养大赛图形化编程小低组初赛真题(含答案)
  • ZooKeeper学习专栏(二):深入 Watch 机制与会话管理
  • C语言:深入理解指针(2)
  • 网络地址和主机地址之间进行转换的类
  • 剑指offer66_不用加减乘除做加法
  • Spring Boot 订单超时自动取消的 3 种主流实现方案
  • 腾讯二面手撕题:BatchNorm和LayerNorm
  • 08_Opencv_基本图形绘制
  • 学成在线项目
  • Eureka+LoadBalancer实现服务注册与发现
  • 限流算法与实现
  • Shell脚本-tee工具
  • Kafka 在分布式系统中的关键特性与机制深度解析
  • kotlin Flow快速学习2025
  • PostgreSQL实战:高效SQL技巧
  • 【LeetCode刷题指南】--反转链表,链表的中间结点,合并两个有序链表
  • 基于单片机无线防丢/儿童防丢报警器
  • 数据结构 | 栈:构建高效数据处理的基石
  • 【2025最新版】PDFelement全能PDF编辑器
  • [硬件电路-58]:根据电子元器件的控制信号的类型分为:电平控制型和脉冲控制型两大类。
  • LockFile简要分析
  • 《镜语者》
  • RocketMQ学习系列之——MQ入门概念
  • 【基础】——股票市场基础知识宏观
  • 无 sudo 权限的环境下将 nvcc (CUDA Toolkit) 安装到个人目录 linux
  • 【c++】200*200 01灰度矩阵求所有的连通区域坐标集合