当前位置: 首页 > news >正文

Day 40:训练和测试的规范写法

1. 单通道图片的规范准备工作

首先,回顾一下基础设置。用 MNIST 数据集(单通道灰度图像)作为例子。代码开头导入必要的库,并设置设备和随机种子,确保结果可复现。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

这里,transforms.Compose 用于数据预处理:转为张量、归一化(MNIST 的均值和标准差是固定的)。然后加载数据集,分成训练集和测试集,用 DataLoader 打包成批次(batch_size=64)。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

模型定义是个简单的 MLP:展平图像、两层全连接 + ReLU。损失函数用交叉熵,优化器选 Adam。

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu(x)x = self.layer2(x)return xmodel = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

这些都是老生常谈,但规范写法从这里开始体现:数据和模型分离,便于后续扩展。

2. 训练函数的封装:逻辑隔离,参数复用

以前写鸢尾花 MLP 时,训练代码直接扔在主程序里,看起来乱糟糟的。现在,我们用函数 train 封装整个过程。为什么这么做?一是参数(如 epochs)调整方便,不用翻代码;二是复用性强,以后多模型对比时,直接调用就好。
函数里记录每个 batch 的损失(iteration 级),每 100 batch 打印一次,epoch 结束时测试并绘图。注意,早停策略暂不加(需要验证集),但测试函数独立封装。

def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)return epoch_test_acc

测试函数 test 也很干净:eval 模式、无梯度计算,计算平均损失和准确率。

def test(model, test_loader, criterion, device):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

绘图函数绘制 iteration 损失曲线,更直观地看训练过程。

def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()

3. 运行与效果分析

直接调用训练:epochs=2(实际可调大),看看效果。

epochs = 2
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

运行后,会看到每 100 batch 的损失打印,epoch 结束的准确率,以及损失曲线图。MNIST 上 MLP 能到 96%+,但如果换 CIFAR-10,准确率可能卡在 50% 左右。为什么?MLP 忽略图像空间结构,全连接层参数爆炸,容易过拟合。

@浙大疏锦行

http://www.lryc.cn/news/626306.html

相关文章:

  • 008.Redis Cluster集群架构实践
  • RabbitMQ:SpringAMQP Topic Exchange(主题交换机)
  • Linux中Cobbler服务部署与配置(快速部署和管理 Linux 系统)
  • mac电脑软件左上角的关闭/最小化/最大化按钮菜单的宽度和高度是多少像素
  • Mac 4步 安装 Jenv 管理多版本JDK
  • Mac 上安装并使用 frpc(FRP 内网穿透客户端)指南
  • 第四章:大模型(LLM)】07.Prompt工程-(4)思维链(CoT, Chain-of-Thought)Prompt
  • 第四章:大模型(LLM)】07.Prompt工程-(5)self-consistency prompt
  • 编译安装 Nginx
  • 从AI小智固件到人类智能:计算技术的层级跃迁
  • Linux-----《Linux系统管理速通:界面切换、远程连接、目录权限与用户管理一网打尽》
  • JavaScript 检查给定的四个点是否形成正方形(Check if given four points form a square)
  • [特殊字符] 小豆包 API 聚合平台:让 AI 接入更简单、更高效
  • PyTorch API 7
  • Linux 文件系统权限管理(补充)
  • pinctrl和gpio子系统实验
  • 前后端联合实现文件上传,实现 SQL Server image 类型文件上传
  • LeetCode热题100--101. 对称二叉树--简单
  • 【Kafka】常见简单八股总结
  • 力扣 30 天 JavaScript 挑战 第36天 第8题笔记 深入了解reduce,this
  • Linux Shell 常用操作与脚本示例详解
  • CNN 在故障诊断中的应用:原理、案例与优势
  • DAY 50 预训练模型+CBAM模块
  • 排查Redis数据倾斜引发的性能瓶颈
  • VScode ROS文件相关配置
  • 什么是大数据平台?大数据平台和数据中台有什么关系?
  • 网络间的通用语言TCP/IP-网络中的通用规则3
  • A股大盘数据-20250819 分析
  • 【PyTorch】单对象分割项目
  • Arthas 全面使用指南:离线安装 + Docker/K8s 集成 + 集中管理