使用PyTorch实现Softmax回归(Mnist手写数字识别)
1. 引言
Softmax回归是机器学习中最基础但强大的分类算法之一,特别适用于多分类问题。本文将详细讲解如何使用PyTorch框架实现Softmax回归模型,对经典的MNIST手写数字数据集进行分类。通过本教程,您将掌握:
- Softmax回归的数学原理
- PyTorch实现Softmax回归的完整流程
- 使用TensorBoard进行训练可视化
- 模型保存与加载的最佳实践
- 预测结果的可视化展示
2. Softmax回归原理
2.1 数学基础
Softmax回归是逻辑回归的多分类扩展,其核心是将线性变换的输出转换为概率分布:
Softmax函数的数学表达式为:
σ ( z ) j = e z j ∑ k = 1 K e z k for j = 1 , … , K \sigma(\mathbf{z})_j = \frac{e^{z_j}}{\sum_{k=1}^K e^{z_k}} \quad \text{for} \quad j = 1, \dots, K σ(z)j=∑k=1Kezkezjforj=1,…,K
2.2 损失函数
我们使用负对数似然损失(NLL Loss),它与Softmax结合等价于交叉熵损失:
Loss = − ∑ i = 1 N ∑ k = 1 K y i k log ( p i k ) \text{Loss} = -\sum_{i=1}^{N} \sum_{k=1}^{K} y_{ik} \log(p_{ik}) Loss=−i=1∑Nk=1∑Kyiklog(pik)
其中:
- N N N是样本数量
- K K K是类别数量(MNIST中为10)
- y i k y_{ik} yik是样本 i i i属于类别 k k k的真实标签
- p i k p_{ik} pik是模型预测样本 i i i属于类别 k k k的概率
3.网络设计
确定网络结构及其形状
- 第一次参数:输入: x :[None,784] ;权重:[784,64] ;偏置:[64] ,输出:[None,64]
- 第二层参数:输入: x :[None,64] ;权重:[64,10] ;偏置:[10] ,输出:[None,10]
流程
- 获取数据
- 前向传播:网络结构定义
- 损失计算
- 反向传播:梯度下降优化
完善功能
- 准确率计算
- 添加 Tensorboard 观察变量
- 训练模型保存

4. 完整代码实现
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import os
import numpy as np
from datetime import datetime# 解决Matplotlib中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei'] # 使用黑体显示中文
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号# 检查CUDA设备是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 确保日志目录存在
log_dir = "runs"
os.makedirs(log_dir, exist_ok=True)# 创建TensorBoard记录器
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_subdir = os.path.join(log_dir, f"softmax_mnist_{timestamp}")
writer = SummaryWriter(log_dir=log_subdir)class SoftmaxRegression(torch.nn.Module):"""简单的Softmax回归模型"""def __init__(self):super().__init__()# 单层网络结构:784输入 -> 10输出self.linear = torch.nn.Linear(28 * 28, 10)def forward(self, x):# 应用log_softmax到线性层输出return torch.nn.functional.log_softmax(self.linear(x), dim=1)def get_data_loader(is_train, batch_size=128):"""获取数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # MNIST的均值和标准差])dataset = MNIST(root='./data', train=is_train, download=True, transform=transform)return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, pin_memory=True)def evaluate(test_data, net):"""评估模型准确率"""net.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_data:images, labels = images.to(device), labels.to(device)outputs = net(images.view(-1, 28 * 28))_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return correct / totaldef save_model(net, filename='softmax_model.pth'):"""保存模型"""torch.save(net.state_dict(), filename)print(f"模型已保存至 {filename}")def load_model(net, filename='softmax_model.pth'):"""加载模型"""if os.path.exists(filename):net.load_state_dict(torch.load(filename, map_location=device))print(f"模型已从 {filename} 加载")else:print(f"警告: 未找到模型文件 {filename}")return netdef visualize_predictions(model, test_loader, num_images=12):"""可视化模型预测结果"""model.eval()images, labels = next(iter(test_loader))images, labels = images.to(device), labels.to(device)with torch.no_grad():outputs = model(images.view(-1, 28 * 28))_, predictions = torch.max(outputs, 1)plt.figure(figsize=(12, 8))for i in range(num_images):plt.subplot(3, 4, i + 1)img = images[i].cpu().numpy().squeeze()plt.imshow(img, cmap='gray')plt.title(f"预测: {predictions[i].item()} (真实: {labels[i].item()})")plt.axis('off')plt.tight_layout()plt.savefig('softmax_predictions.png', dpi=150)plt.show()# 将图像添加到TensorBoard(使用PIL图像替代)from PIL import Imagefrom torchvision.utils import save_image# 创建临时图像文件temp_img_path = "temp_grid.png"save_image(images[:num_images], temp_img_path, nrow=4)# 读取并添加到TensorBoardimg = Image.open(temp_img_path)img_array = np.array(img)writer.add_image('predictions', img_array, dataformats='HWC')def main():# 获取数据加载器train_loader = get_data_loader(is_train=True, batch_size=128)test_loader = get_data_loader(is_train=False, batch_size=512)# 创建模型model = SoftmaxRegression().to(device)print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")# 尝试加载现有模型model = load_model(model)# 评估初始准确率init_acc = evaluate(test_loader, model)print(f"初始准确率: {init_acc:.4f}")writer.add_scalar('Accuracy/test', init_acc, 0)# 使用Adam优化器optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 训练循环total_step = 0for epoch in range(10):model.train()total_loss = 0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.to(device)# 前向传播outputs = model(images.view(-1, 28 * 28))loss = torch.nn.functional.nll_loss(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()# 统计信息total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 每100个batch记录一次if (i + 1) % 100 == 0:train_acc = correct / totalavg_loss = total_loss / (i + 1)print(f"Epoch [{epoch + 1}/10], Step [{i + 1}/{len(train_loader)}], "f"Loss: {avg_loss:.4f}, Accuracy: {train_acc:.4f}")# 记录到TensorBoardwriter.add_scalar('Loss/train', avg_loss, total_step)writer.add_scalar('Accuracy/train', train_acc, total_step)total_step += 1# 每个epoch结束后评估测试集test_acc = evaluate(test_loader, model)print(f"Epoch [{epoch + 1}/10], 测试准确率: {test_acc:.4f}")writer.add_scalar('Accuracy/test', test_acc, epoch)# 训练后保存模型save_model(model)# 最终评估final_acc = evaluate(test_loader, model)print(f"最终测试准确率: {final_acc:.4f}")# 可视化预测结果visualize_predictions(model, test_loader)# 在TensorBoard中添加模型图dummy_input = torch.randn(1, 784).to(device)writer.add_graph(model, dummy_input)# 关闭TensorBoard写入器writer.close()print(f"TensorBoard日志保存在: {log_subdir}")print("使用命令查看TensorBoard: tensorboard --logdir=runs")if __name__ == '__main__':main()
5. 代码详解
5.1 环境设置与设备选择
# 解决Matplotlib中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 检查CUDA设备是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
这部分代码解决了Matplotlib显示中文乱码的问题,并自动选择最佳计算设备(优先使用GPU)。
5.2 数据加载与预处理
def get_data_loader(is_train, batch_size=128):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准化])dataset = MNIST(root='./data', train=is_train, download=True, transform=transform)return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, pin_memory=True)
数据预处理包括:
- 转换为张量
- 标准化处理(使用MNIST的全局均值0.1307和标准差0.3081)
- 批处理(训练集batch_size=128,测试集batch_size=512)
- 使用
pin_memory=True
加速GPU数据传输
5.3 Softmax回归模型
class SoftmaxRegression(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(28 * 28, 10)def forward(self, x):return torch.nn.functional.log_softmax(self.linear(x), dim=1)
模型结构非常简单:
- 输入层:28×28=784个特征(展平的图像像素)
- 输出层:10个神经元(对应0-9的数字类别)
- 激活函数:Log_Softmax(提高数值稳定性)
5.4 训练流程
训练过程的核心循环:
关键代码段:
# 前向传播
outputs = model(images.view(-1, 28 * 28))
loss = torch.nn.functional.nll_loss(outputs, labels)# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()# 记录到TensorBoard
writer.add_scalar('Loss/train', avg_loss, total_step)
writer.add_scalar('Accuracy/train', train_acc, total_step)
5.5 模型保存与加载
def save_model(net, filename='softmax_model.pth'):torch.save(net.state_dict(), filename)def load_model(net, filename='softmax_model.pth'):if os.path.exists(filename):net.load_state_dict(torch.load(filename, map_location=device))return net
最佳实践:
- 只保存模型参数(state_dict),而非整个模型
- 使用
.pth
扩展名 - 指定
map_location
确保跨设备加载 - 文件大小仅约40KB
5.6 可视化技术
5.6.1 TensorBoard集成
# 创建记录器
writer = SummaryWriter(log_dir=log_subdir)# 添加标量数据
writer.add_scalar('Accuracy/test', test_acc, epoch)# 添加模型图
writer.add_graph(model, dummy_input)# 添加图像预测
writer.add_image('predictions', img_array, dataformats='HWC')
TensorBoard提供以下可视化:
- 训练/测试损失曲线
- 准确率变化趋势
- 模型计算图
- 预测样本图像
5.6.2 Matplotlib可视化
def visualize_predictions(model, test_loader, num_images=12):# ...获取预测结果...plt.figure(figsize=(12, 8))for i in range(num_images):plt.subplot(3, 4, i + 1)plt.imshow(img, cmap='gray')plt.title(f"预测: {predictions[i].item()} (真实: {labels[i].item()})")plt.show()
6. 训练结果分析
6.1 性能指标(示例数据)
训练轮数 | 训练准确率 | 测试准确率 |
---|---|---|
初始 | ~10% | ~10% |
5轮 | ~85% | ~88% |
10轮 | ~91% | ~91% |
6.2 可视化结果
-
TensorBoard损失曲线:观察指数衰减趋势
TensorBoard损失曲线图
-
准确率曲线:训练/测试准确率对比
准确率曲线图
- 预测样本:

- 模型图:

7. 结论
本文详细介绍了使用PyTorch实现Softmax回归进行MNIST手写数字分类的完整流程。通过本教程,我们可以掌握:
- Softmax回归的数学原理和实现方式
- PyTorch数据加载、模型构建和训练的最佳实践
- TensorBoard可视化训练过程
- 模型保存与加载技术
- 预测结果的可视化展示
尽管Softmax回归模型简单,但它在MNIST数据集上能达到约92%的准确率,为理解更复杂的深度学习模型奠定了基础。