用 PyTorch 实现全连接网络识别 MNIST 手写数字
目录
一、什么是全连接网络
二、代码实现步骤
1. 导入必要的库
2. 数据准备
3. 定义网络结构
4. 模型训练
5. 模型保存和加载
6. 预测单张图片
7. 主函数
三、运行结果说明
四、小结
一、什么是全连接网络
全连接神经网络(Fully Connected Neural Network)是一种最基础的神经网络结构,其特点是每一层的每个神经元都与上一层的所有神经元相连。
打个比方,就像公司里的部门架构:输入层是基层员工,隐藏层是中层管理,输出层是高层决策。基层的每个人都要向所有中层汇报,中层再向所有高层汇报,这样信息就能经过多层处理后得到最终结果。
但全连接网络处理图像时有个缺点:它会把图像的二维像素矩阵转换成一维向量,这就像把一张完整的图片撕成一条线,会丢失图像的空间特征。
二、代码实现步骤
1. 导入必要的库
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
这些库就像我们的工具包:
torch
是 PyTorch 的核心库nn
模块包含神经网络相关的工具optim
提供优化器torchvision
有现成的数据集和图像处理工具DataLoader
帮助我们批量加载数据PIL
用于处理图像
2. 数据准备
def build_data():transform = transforms.Compose([transforms.ToTensor(),])train_set = datasets.MNIST(root = '../dataset',train = True,download = True,transform = transform)test_set = datasets.MNIST(root = '../dataset',train = False,download = True,transform = transform)train_loader = DataLoader(dataset = train_set,batch_size = 128,shuffle = True)test_loader = DataLoader(dataset = test_set,batch_size = 64,shuffle = True)return train_loader, test_loader
这段代码做了三件事:
- 定义了数据转换方式,
ToTensor()
会把图像转换成张量并归一化 - 加载 MNIST 数据集(手写数字数据集,包含 0-9 共 10 类数字)
- 用
DataLoader
把数据分成批次,方便训练时批量处理
batch_size
表示每次处理多少张图片,shuffle=True
表示打乱数据顺序,让模型学习更全面。
3. 定义网络结构
class MNISTNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28 * 28, 256)self.relu1 = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28) # 把28x28的图像展平成784维向量x = self.relu1(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x
我们定义了一个 3 层的全连接网络:
- 输入层:MNIST 图像是 28x28 的,展平后是 784 个像素点
- 第一个隐藏层:256 个神经元,使用 ReLU 激活函数
- 第二个隐藏层:128 个神经元,同样使用 ReLU 激活函数
- 输出层:10 个神经元(对应 0-9 十个数字)
激活函数 ReLU 的作用是引入非线性,让网络能够学习复杂的模式,就像给计算器增加了更多运算功能。
4. 模型训练
def train(model, train_loader, epochs):criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,适合分类问题opt = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器for epoch in range(epochs):loss_sum = 0count = 0for x, y in train_loader:y_pred = model(x) # 前向传播,得到预测结果loss = criterion(y_pred, y) # 计算损失# 反向传播更新参数opt.zero_grad() # 清空梯度loss.backward() # 计算梯度opt.step() # 更新参数loss_sum += loss.item()_, pred = torch.max(y_pred, dim=1) # 找到概率最大的类别count += (pred == y).sum().item() # 统计正确的数量acc = count / len(train_loader.dataset) # 计算准确率print(f'epoch: {epoch+1}, Loss: {loss_sum:.4f}, Acc: {acc:.4f}')
训练过程就像学生做习题:
- 先用当前模型做预测(前向传播)
- 计算预测结果和正确答案的差距(损失函数)
- 分析哪里错了,怎么改进(反向传播计算梯度)
- 调整模型参数(优化器更新参数)
我们用交叉熵损失函数来衡量预测错误的程度,用随机梯度下降(SGD)来优化模型参数,学习率lr=0.01
控制每次调整的幅度。
5. 模型保存和加载
def save_model(model, model_path):torch.save(model.state_dict(), model_path) # 保存模型参数def load_model(model_path):model = MNISTNet()model.load_state_dict(torch.load(model_path)) # 加载模型参数return model
训练好的模型可以保存下来,下次用的时候直接加载,不用重新训练,就像保存游戏进度一样。
6. 预测单张图片
def predict(model, filePath):img = Image.open(filePath)# 图像预处理:调整大小、转成张量、归一化transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])t_img = transform(img)with torch.no_grad(): # 预测时不需要计算梯度y_pred = model(t_img)_, pred = torch.max(y_pred, dim=1)print(f'预测结果: {pred.item()}')
预测时需要对输入图片做和训练数据相同的预处理,with torch.no_grad()
可以加快计算速度,因为预测时不需要更新参数。
7. 主函数
if __name__ == '__main__':train_loader, test_loader = build_data()model = MNISTNet()# 训练模型train(model, train_loader, epochs=10)# 保存模型save_model(model, './mnist.pt')# 加载模型并预测model_pred = load_model('./mnist.pt')predict(model_pred, './img/3.png')
三、运行结果说明
训练过程中,我们会看到损失(Loss)逐渐减小,准确率(Acc)逐渐提高,这说明模型在不断进步。
对于 MNIST 这种简单数据集,用这个全连接网络通常能达到 97% 以上的准确率。如果想进一步提高性能,可以考虑使用卷积神经网络(CNN),它能更好地保留图像的空间特征。
四、小结
本文用 PyTorch 实现了一个全连接神经网络来识别 MNIST 手写数字,主要步骤包括:
- 准备数据:加载并预处理 MNIST 数据集
- 定义网络:设计 3 层全连接网络
- 训练模型:使用交叉熵损失和 SGD 优化器
- 保存和加载模型:方便复用
- 单张图片预测:实际应用模型
全连接网络虽然简单,但它是理解更复杂神经网络的基础。通过这个例子,我们可以了解神经网络的基本工作原理和 PyTorch 的使用方法。