当前位置: 首页 > news >正文

Day43 复习日

训练主模型

# 训练模型主函数(优化版)
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, scheduler: optim.lr_scheduler._LRScheduler, epochs: int
) -> tuple[list[float], list[float], list[float], list[float]]:# 初始化训练和验证过程中的监控指标train_losses: list[float] = []  # 存储每个epoch的训练损失val_losses: list[float] = []    # 存储每个epoch的验证损失train_accuracies: list[float] = []  # 存储每个epoch的训练准确率val_accuracies: list[float] = []    # 存储每个epoch的验证准确率# 新增:早停相关变量(可选)best_val_loss: float = float('inf')early_stop_counter: int = 0early_stop_patience: int = 5  # 连续5个epoch无提升则停止# 主训练循环 - 遍历指定轮数for epoch in range(epochs):# 设置模型为训练模式(启用Dropout和BatchNorm等训练特定层)model.train()train_loss: float = 0.0  # 累积训练损失correct: int = 0         # 正确预测的样本数total: int = 0           # 总样本数# 批次训练循环 - 遍历训练数据加载器中的所有批次for inputs, targets in train_loader:# 将数据移至计算设备(GPU或CPU)inputs, targets = inputs.to(device), targets.to(device)# 梯度清零 - 防止梯度累积(每个批次独立计算梯度)optimizer.zero_grad()# 前向传播 - 通过模型获取预测结果outputs = model(inputs)# 计算损失 - 使用预定义的损失函数(如交叉熵)loss = criterion(outputs, targets)# 反向传播 - 计算梯度loss.backward()# 参数更新 - 根据优化器(如Adam)更新模型权重optimizer.step()# 统计训练指标train_loss += loss.item()  # 累积批次损失_, predicted = outputs.max(1)  # 获取预测类别total += targets.size(0)  # 累积总样本数correct += predicted.eq(targets).sum().item()  # 累积正确预测数# 计算当前epoch的平均训练损失和准确率train_loss /= len(train_loader)  # 平均批次损失train_accuracy = 100.0 * correct / total  # 计算准确率百分比train_losses.append(train_loss)  # 记录损失train_accuracies.append(train_accuracy)  # 记录准确率# 模型验证部分model.eval()  # 设置模型为评估模式(禁用Dropout等)val_loss: float = 0.0  # 累积验证损失correct = 0   # 正确预测的样本数total = 0     # 总样本数# 禁用梯度计算 - 验证过程不需要计算梯度,节省内存和计算资源with torch.no_grad():# 遍历验证数据加载器中的所有批次for inputs, targets in val_loader:# 将数据移至计算设备inputs, targets = inputs.to(device), targets.to(device)# 前向传播 - 获取验证预测结果outputs = model(inputs)# 计算验证损失loss = criterion(outputs, targets)# 统计验证指标val_loss += loss.item()  # 累积验证损失_, predicted = outputs.max(1)  # 获取预测类别total += targets.size(0)  # 累积总样本数correct += predicted.eq(targets).sum().item()  # 累积正确预测数# 计算当前epoch的平均验证损失和准确率val_loss /= len(val_loader)  # 平均验证损失val_accuracy = 100.0 * correct / total  # 计算验证准确率val_losses.append(val_loss)  # 记录验证损失val_accuracies.append(val_accuracy)  # 记录验证准确率# 打印当前epoch的训练和验证指标print(f'Epoch {epoch+1}/{epochs}')print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%')print('-' * 50)# 更新学习率调度器(修正mode为min,匹配验证损失)scheduler.step(val_loss)  # 传入验证损失,mode='min'# 新增:早停逻辑(可选)if val_loss < best_val_loss:best_val_loss = val_lossearly_stop_counter = 0# 可选:保存最佳模型权重torch.save(model.state_dict(), 'best_model.pth')else:early_stop_counter += 1if early_stop_counter >= early_stop_patience:print(f"Early stopping at epoch {epoch+1}")break# 返回训练和验证过程中的所有指标,用于后续分析和可视化return train_losses, val_losses, train_accuracies, val_accuracies# 训练模型(保持调用方式不变)
epochs = 20  
train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs
)# 可视化训练过程(保持原函数不变)
def plot_training(train_losses, val_losses, train_accuracies, val_accuracies):plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(val_losses, label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.title('Training and Validation Loss')plt.subplot(1, 2, 2)plt.plot(train_accuracies, label='Train Accuracy')plt.plot(val_accuracies, label='Validation Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.title('Training and Validation Accuracy')plt.tight_layout()plt.show()plot_training(train_losses, val_losses, train_accuracies, val_accuracies)

@浙大疏锦行

http://www.lryc.cn/news/621214.html

相关文章:

  • FPGA+护理:跨学科发展的探索(五)
  • Kotlin Data Classes 快速上手
  • 【深度学习】深度学习基础概念与初识PyTorch
  • 报数游戏(我将每文更新tips)
  • IPTV系统:开启视听与管理的全新篇章
  • 14 ABP Framework 文档管理
  • 【软考中级网络工程师】知识点之入侵防御系统:筑牢网络安全防线
  • SpringMVC(详细版从入门到精通)未完
  • P5967 [POI 2016] Korale 题解
  • 【数据分享】2014-2023年长江流域 (0.05度)5.5km分辨率的每小时日光诱导叶绿素荧光SIF数据
  • stm32项目(28)——基于stm32的环境监测并上传至onenet云平台
  • LT3045EDD#TRPBF ADI亚德诺 超低噪声LDO稳压器 电子元器件IC
  • web网站开发,在线%射击比赛成绩管理%系统开发demo,基于html,css,jquery,python,django,model,orm,mysql数据库
  • 模型选择与调优
  • 0814 TCP和DUP通信协议
  • 2021睿抗决赛 猛犸不上 Ban
  • 十分钟学会一个算法 —— 快速排序
  • ASCII与Unicode:编码世界的奥秘
  • 【前端工具】使用 Node.js 脚本实现项目打包后自动压缩
  • C#WPF实战出真汁02--登录界面设计
  • 微服务从0到1
  • 在Ubuntu上安装Google Chrome的详细教程
  • Ubuntu下载、安装、编译指定版本python
  • 大规模调用淘宝商品详情 API 的分布式请求调度实践
  • 大规模分布式光伏并网后对电力系统的影响
  • 自动驾驶与人形机器人的技术分水岭
  • dolphinscheduler中任务输出变量的问题出现ArrayIndexOutOfBoundsException
  • 【记录】Apache SeaTunnel 系统监控信息
  • 反射在Spring IOC容器中的应用——动态创建Bean (补充)
  • Linux 上手 UDP Socket 程序编写(含完整具体demo)