# Day43 复习日 (Review day)
# 训练主模型 (Train the main model)
# Main training function (refactored): per-epoch train/validation loop with
# ReduceLROnPlateau scheduling, early stopping, and best-model checkpointing.
def _train_one_epoch(model: nn.Module, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, device: torch.device) -> tuple[float, float]:
    """Run a single training epoch; return (mean batch loss, accuracy %)."""
    model.train()  # enable Dropout/BatchNorm training behavior
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # gradients must not accumulate across batches
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)  # predicted class = argmax over logits
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return running_loss / len(train_loader), 100.0 * correct / total


def _validate(model: nn.Module, val_loader: DataLoader, criterion: nn.Module, device: torch.device) -> tuple[float, float]:
    """Evaluate on the validation set; return (mean batch loss, accuracy %)."""
    model.eval()  # disable Dropout, use running BatchNorm statistics
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed: saves memory and compute
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return running_loss / len(val_loader), 100.0 * correct / total


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler.ReduceLROnPlateau,
    epochs: int,
) -> tuple[list[float], list[float], list[float], list[float]]:
    """Train `model` for up to `epochs` epochs with per-epoch validation.

    The scheduler must be a ReduceLROnPlateau instance: its `step()` is fed
    the validation loss each epoch (a plain _LRScheduler would misinterpret
    that value as an epoch index). Early stopping aborts training after 5
    consecutive epochs without a new best validation loss; the best weights
    are saved to 'best_model.pth'.

    Returns:
        (train_losses, val_losses, train_accuracies, val_accuracies),
        one entry per completed epoch.
    """
    # Derive the compute device from the model itself instead of relying on a
    # module-level `device` global, so the function is self-contained.
    device = next(model.parameters()).device

    train_losses: list[float] = []       # per-epoch training loss
    val_losses: list[float] = []         # per-epoch validation loss
    train_accuracies: list[float] = []   # per-epoch training accuracy (%)
    val_accuracies: list[float] = []     # per-epoch validation accuracy (%)

    # Early-stopping state.
    best_val_loss = float('inf')
    early_stop_counter = 0
    early_stop_patience = 5  # stop after 5 epochs without improvement

    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        val_loss, val_accuracy = _validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f'Epoch {epoch+1}/{epochs}')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%')
        print('-' * 50)

        # ReduceLROnPlateau monitors the validation loss (mode='min').
        scheduler.step(val_loss)

        # Early stopping: reset the counter on any new best validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')  # keep best weights
        else:
            early_stop_counter += 1
            if early_stop_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Metrics for downstream analysis / visualization.
    return train_losses, val_losses, train_accuracies, val_accuracies
# Run training (call signature kept unchanged from the original version).
epochs = 20  # number of full passes over the training set
train_losses, val_losses, train_accuracies, val_accuracies = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs
)  # Visualize the training process (plotting function kept unchanged)
def plot_training(train_losses, val_losses, train_accuracies, val_accuracies):
    """Render side-by-side curves: loss (left panel) and accuracy (right panel)
    for both the training and validation sets, one point per epoch."""
    # Each panel spec: (subplot index, series+label pairs, y-axis label, title).
    panels = (
        (1, ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')),
         'Loss', 'Training and Validation Loss'),
        (2, ((train_accuracies, 'Train Accuracy'), (val_accuracies, 'Validation Accuracy')),
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )
    plt.figure(figsize=(12, 4))
    for index, series, y_label, title in panels:
        plt.subplot(1, 2, index)
        for values, label in series:
            plt.plot(values, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(title)
    plt.tight_layout()
    plt.show()


plot_training(train_losses, val_losses, train_accuracies, val_accuracies)
# @浙大疏锦行