当前位置: 首页 > news >正文

实战:用 PyTorch 复现一个 3 层全连接网络,训练 MNIST,达到 95%+ 准确率

1. 使用 Anaconda 创建一个新环境,包括 python 和 与你显卡对应的 torch

2. PyCharm(2025.1.3.1)绑定 Conda 环境-CSDN博客

3. 

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm# 一次给模型看多少张图片
BATCH_SIZE = 64
# 把全部训练数据重复看多少遍
EPOCHS = 10
LR = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)# 原始数据集中,一张 MNIST 图片的形状是 (1, 28, 28) ← 1 个通道(灰度),高 28,宽 28。
# 当 DataLoader 按 batch_size=64 打包后,它把 64 张这样的图片堆在一起,形成一个新的 4 维张量,形状变成 (64, 1, 28, 28)
# shuffle = True 的作用:在每个 epoch 开始时,把训练集里的 60 000 张图片顺序彻底打乱一次。
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE)# 搭建神经网络:把图片拉成一条长条 → 过 128 个神经元 → 再过 64 个神经元 → 最后给出 10 个数字的得分
class Net(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(784, 128), nn.ReLU(),nn.Linear(128, 64),  nn.ReLU(),nn.Linear(64, 10))def forward(self, x):return self.net(x)model = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)# 训练
for epoch in range(1, EPOCHS + 1):model.train()pbar = tqdm(train_loader, desc=f"Epoch {epoch}")for x, y in pbar:x, y = x.to(DEVICE), y.to(DEVICE)optimizer.zero_grad()loss = criterion(model(x), y)loss.backward()optimizer.step()pbar.set_postfix(loss=loss.item())model.eval()
correct = total = 0
with torch.no_grad():for x, y in test_loader:x, y = x.to(DEVICE), y.to(DEVICE)pred = model(x).argmax(1)correct += (pred == y).sum().item()total += y.size(0)
print(f"Test Accuracy: {100*correct/total:.2f}%")

4. 运行

http://www.lryc.cn/news/617313.html

相关文章:

  • 软件测试关于搜索方面的测试用例
  • DeepCompare文件深度对比软件:权限管理与安全功能全面解析
  • Android Audio实战——获取活跃音频类型(十五)
  • 安全合规4--下一代防火墙组网
  • 企业内外网物理隔离时文件怎么传输更安全
  • ChatML vs Harmony:深度解析OpenAI全新对话结构格式的变化
  • Linux 流编辑器 sed 详解
  • C#使用EPPlus读写Excel
  • Elasticsearch Node.js 客户端的安装
  • 【Node.js从 0 到 1:入门实战与项目驱动】1.3 Node.js 的应用场景(附案例与代码实现)
  • Flutter Dialog、BottomSheet
  • RabbitMQ 消息转换器详解
  • windows上RabbitMQ 启动时报错:发生系统错误 1067。 进程意外终止。
  • 内存问题排查工具ASan初探
  • 嵌入式Linnux学习 -- 软件编程2
  • uart通信中出现乱码,可能的原因是什么 ?
  • 借助 ChatGPT 快速实现 TinyMCE 段落间距与行间距调节
  • Nmap 渗透测试弹药库:精准扫描与隐蔽渗透技术手册
  • 什么是结构化思维?什么是结构化编程?
  • 计算机网络(一)——TCP
  • Vue脚手架模式与环境变量
  • 变频器实习DAY26 CDN 测试中心使用方法
  • Android16新特性速记
  • C语言如何安全的进行字符串拷贝
  • 从 GPT-2 到 gpt-oss:架构进步分析
  • 北京JAVA基础面试30天打卡07
  • Nacos-1--什么是Nacos?
  • 5G NR 非地面网络 (NTN)
  • JVM运维
  • C#(vs2015)利用unity实现弯管机仿真