当前位置: 首页 > news >正文

一个基于 PyTorch 的完整模型训练流程

一个基于 PyTorch 的完整模型训练流程

flyfish

训练步骤具体操作目的
1. 训练前准备设置随机种子、配置超参数(batch size、学习率等)、选择计算设备(CPU/GPU)确保实验可复现;统一控制训练关键参数;利用硬件加速训练
2. 数据预处理与加载对数据进行标准化/归一化、转换为张量;用DataLoader按batch加载数据统一输入格式,适配模型要求;高效分批读取数据,减少内存占用
3. 初始化组件定义模型结构并加载到计算设备;选择损失函数(如交叉熵)和优化器(如Adam)搭建训练核心框架:模型负责预测,损失函数量化误差,优化器负责参数更新
4. 训练循环(每个epoch)逐轮迭代优化模型参数
4.1 模型切换为训练模式model.train()启用dropout、批量归一化的训练模式,确保梯度计算有效
4.2 遍历训练数据(每个batch)逐批更新参数
4.2.1 清零梯度optimizer.zero_grad()消除历史梯度累积,确保当前batch的梯度计算独立
4.2.2 前向传播output = model(data)用当前模型参数对输入数据做预测,得到输出结果
4.2.3 计算损失loss = criterion(output, target)量化预测结果与真实标签的差距,作为优化目标
4.2.4 反向传播loss.backward()从损失值反向推导,计算所有可训练参数的梯度(参数对损失的影响程度)
4.2.5 参数更新optimizer.step()根据梯度,按优化器规则调整模型参数,减小损失
4.3 记录训练指标保存每个epoch的训练损失、准确率跟踪模型在训练集上的学习效果
5. 验证(每个epoch后)评估模型泛化能力
5.1 模型切换为评估模式model.eval()关闭dropout、固定批量归一化参数,确保评估稳定
5.2 关闭梯度计算with torch.no_grad():减少内存占用,加速验证过程(无需计算梯度)
5.3 计算验证指标计算验证损失、准确率评估模型在未见过的数据上的表现,判断泛化能力
6. 模型保存保存表现最优的模型参数(如验证准确率最高时)留存最佳模型,便于后续部署或继续训练
7. 训练后分析绘制损失/准确率曲线,统计训练时间直观展示训练过程,分析模型收敛状态和效率

前向传播→计算损失→反向传播→参数优化

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm
import time# 设置随机种子,保证结果可复现
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed(seed)np.random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = False# 定义超参数
class Config:def __init__(self):self.batch_size = 64self.learning_rate = 0.001self.epochs = 10self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.save_path = './models'self.log_interval = 100# 定义简单的卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)self.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)  # 展平x = self.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x# 准备数据
def prepare_data(config):# 定义数据变换transform = transforms.Compose([ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差])# 加载MNIST数据集train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True,num_workers=2)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False,num_workers=2)return train_loader, test_loader# 训练函数
def train(model, train_loader, criterion, optimizer, config, epoch):model.train()  # 设置为训练模式train_loss = 0.0correct = 0total = 0# 使用tqdm显示进度条pbar = tqdm(train_loader, desc=f'Train Epoch {epoch}')for batch_idx, (data, target) in enumerate(pbar):data, target = data.to(config.device), target.to(config.device)# 清零梯度optimizer.zero_grad()# 前向传播output = model(data)loss = criterion(output, target)# 反向传播和优化loss.backward()optimizer.step()# 统计训练信息train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印日志if batch_idx % config.log_interval == 0:pbar.set_postfix({'loss': f'{train_loss/(batch_idx+1):.6f}','accuracy': f'{100.*correct/total:.2f}%'})# 计算平均损失和准确率avg_loss = train_loss / len(train_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy# 验证函数
def validate(model, test_loader, criterion, config):model.eval()  # 设置为评估模式test_loss = 0.0correct = 0total = 0# 不计算梯度with torch.no_grad():for data, target in test_loader:data, target = data.to(config.device), target.to(config.device)output = model(data)test_loss += criterion(output, target).item()# 统计准确率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 计算平均损失和准确率avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalprint(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')return avg_loss, accuracy# 保存模型
def save_model(model, optimizer, epoch, loss, config):# 创建保存目录if not os.path.exists(config.save_path):os.makedirs(config.save_path)# 保存模型状态torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, f"{config.save_path}/model_epoch_{epoch}.pth")print(f"Model saved to {config.save_path}/model_epoch_{epoch}.pth")# 主函数
def main():# 初始化设置set_seed()config = Config()print(f"Using device: {config.device}")# 准备数据train_loader, test_loader = prepare_data(config)# 初始化模型、损失函数和优化器model = SimpleCNN().to(config.device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)# 记录训练过程中的指标history = {'train_loss': [],'train_acc': [],'val_loss': [],'val_acc': []}# 开始训练start_time = time.time()best_val_acc = 0.0for epoch in range(1, config.epochs + 1):print(f"\nEpoch {epoch}/{config.epochs}")print("-" * 50)# 训练train_loss, train_acc = train(model, train_loader, criterion, optimizer, config, epoch)history['train_loss'].append(train_loss)history['train_acc'].append(train_acc)# 验证val_loss, val_acc = validate(model, test_loader, criterion, config)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)# 保存最佳模型if val_acc > best_val_acc:best_val_acc = val_accsave_model(model, optimizer, epoch, val_loss, config)# 计算总训练时间end_time = time.time()total_time = end_time - start_timeprint(f"Training complete in {total_time:.0f}s ({total_time/config.epochs:.2f}s per epoch)")print(f"Best validation accuracy: {best_val_acc:.2f}%")# 绘制训练曲线plot_training_history(history)# 绘制训练历史
def plot_training_history(history):plt.figure(figsize=(12, 4))# 绘制损失曲线plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Training Loss')plt.plot(history['val_loss'], label='Validation Loss')plt.title('Loss Curves')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()# 绘制准确率曲线plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Training Accuracy')plt.plot(history['val_acc'], label='Validation Accuracy')plt.title('Accuracy Curves')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('training_history.png')print("Training history plot saved as 'training_history.png'")plt.show()if __name__ == '__main__':main()
......
--------------------------------------------------
Train Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 124.14it/s, loss=0.024222, accuracy=99.22%]Test set: Average loss: 0.0256, Accuracy: 9926/10000 (99.26%)Model saved to ./models/model_epoch_9.pthEpoch 10/10
--------------------------------------------------
Train Epoch 10: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 127.89it/s, loss=0.021473, accuracy=99.31%]Test set: Average loss: 0.0266, Accuracy: 9927/10000 (99.27%)Model saved to ./models/model_epoch_10.pth
Training complete in 85s (8.52s per epoch)
Best validation accuracy: 99.27%
Training history plot saved as 'training_history.png'

在这里插入图片描述
一、左侧:Loss Curves(损失曲线)
蓝色:训练损失(Training Loss)
橙色:验证损失(Validation Loss)

二、右侧:Accuracy Curves(准确率曲线)
蓝色:训练准确率(Training Accuracy)
橙色:验证准确率(Validation Accuracy)

http://www.lryc.cn/news/618324.html

相关文章:

  • 【测试】Bug+设计测试用例
  • MR一体机(VST)预算思路
  • 如何实现PostgreSQL的高可用性,包括主流的复制方案、负载均衡方法以及故障转移流程?
  • 深入理解机器学习之TF-IDF:文本特征提取的核心技术
  • 防御保护11
  • windows版本:Prometheus+Grafana(普罗米修斯+格拉法纳)监控 JVM
  • 《Redis集群故障转移与自动恢复》
  • Myqsl建立库表练习
  • 零基础渗透测试全程记录(打靶)——Prime
  • linux远程部署dify和mac本地部署dify
  • java python
  • c#联合Halcon进行OCR字符识别(含halcon-25.05 百度网盘)
  • Docker 101:面向初学者的综合教程
  • Go 语言中的结构体、切片与映射:构建高效数据模型的基石
  • 五、Nginx、RabbitMQ和Redis在Linux中的安装和部署
  • Homebrew 入门教程(2025 年最新版)
  • docker-compose搭建 redis 集群
  • ETCD的简介和使用
  • 通用同步/异步收发器USART串口
  • Qwen-OCR:开源OCR技术的演进与全面分析
  • 嵌入式学习(day25)文件IO:open read/write close
  • Baumer高防护相机如何通过YoloV8深度学习模型实现木板表面缺陷的检测识别(C#代码UI界面版)
  • iOS混淆工具有哪些?团队协作视角下的分工与防护方案
  • Unity DOTS(一):ECS 初探:大规模实体管理与高性能
  • 鸿蒙下载图片保存到相册,截取某个组件保存到相册
  • 数据库常用操作
  • Linux 可执行程序核心知识笔记:ELF、加载、虚拟地址与动态库
  • 鸿蒙本地与云端数据双向同步实战:从原理到可运行 Demo 的全流程指南
  • Web学习笔记5
  • Linux环境gitlab多种部署方式及具体使用