当前位置: 首页 > news >正文

深度学习中的早停法

早停法(Early Stopping)是一种用于防止模型过拟合的技术,在训练过程中监视验证集(或者测试集)上的损失值。具体设立早停的限制包括两个主要参数:

  1. Patience(耐心):这是指验证集损失在连续多少个epoch没有显著改善时,才触发早停。当验证集损失连续几个epoch没有下降或者停止减少时,表示模型可能已经过拟合或者陷入局部最优点,这时候早停就会被触发。

  2. Best Loss(最佳损失):这是指在早停过程中保存的最低验证集损失值。当验证集损失值低于当前最佳损失时,更新最佳损失并重置耐心计数器。如果验证集损失连续不降,耐心计数器超过设定的耐心值时,早停就会被触发,训练过程停止。

    早停的具体设立是基于验证集上的损失值 val_loss。每次验证后,如果当前的 val_lossbest_loss 还要低,就更新 best_loss 并重置 patience_counter;否则,增加 patience_counter。当 patience_counter 达到设定的 patience 值时,早停被触发,即停止训练过程以防止模型过拟合。

    总结来说,早停的设立限制是基于耐心参数和最佳损失值,用来判断模型是否应该停止训练以避免过拟合。

# 训练模型
num_epochs = 200  # 总的训练轮数
best_loss = float('inf')  # 初始化最佳验证损失为正无穷大
patience = 10  # 早停的耐心值
patience_counter = 0  # 耐心计数器for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad()  # 梯度清零outputs = model(geno)  # 前向传播loss = criterion(outputs.squeeze(), pheno)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 优化模型参数model.eval()val_loss = 0with torch.no_grad():  # 不计算梯度for geno, pheno in test_loader:outputs = model(geno)  # 前向传播val_loss += criterion(outputs.squeeze(), pheno).item()  # 计算验证损失val_loss /= len(test_loader)  # 计算平均验证损失print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')scheduler.step(val_loss)  # 更新学习率# 早停法if val_loss < best_loss:best_loss = val_loss  # 更新最佳验证损失patience_counter = 0  # 重置耐心计数器else:patience_counter += 1  # 增加耐心计数器if patience_counter >= patience:  # 如果耐心计数器达到设定的耐心值print("Early stopping triggered")  # 触发早停break
  1. EarlyStopping
    • __init__ 方法初始化早停的参数,如 patience(耐心值)、verbose(是否打印消息)和 delta(损失改进的最小变化)。
    • __call__ 方法根据验证损失来决定是否更新 best_loss,以及是否增加计数器或者触发早停。
  2. 训练循环
    • 训练和验证过程与之前相同。
    • 每个epoch结束时,调用 early_stopping 对象,传入当前的验证损失。
    • 检查 early_stopping.early_stop 标志,如果为 True,则打印消息并停止训练。

通过使用 EarlyStopping 类,你可以更简洁和模块化地实现早停功能,使代码更易于维护和扩展。

import torch
import numpy as npclass EarlyStopping:def __init__(self, patience=10, verbose=False, delta=0):"""EarlyStopping 初始化.Args:patience (int): 当验证集损失在指定的epoch数内没有减少时触发早停.verbose (bool): 如果为True,则每次验证集损失改进时会打印一条消息.delta (float): 验证集损失改进的最小变化."""self.patience = patienceself.verbose = verboseself.delta = deltaself.best_loss = Noneself.counter = 0self.early_stop = Falsedef __call__(self, val_loss):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.delta:self.counter += 1if self.verbose:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0if self.verbose:print(f'Validation loss decreased to {self.best_loss:.6f}. Resetting counter.')# 初始化EarlyStopping对象
early_stopping = EarlyStopping(patience=10, verbose=True)# 训练模型
num_epochs = 200
for epoch in range(num_epochs):model.train()for geno, pheno in train_loader:optimizer.zero_grad()outputs = model(geno)loss = criterion(outputs.squeeze(), pheno)loss.backward()optimizer.step()model.eval()val_loss = 0with torch.no_grad():for geno, pheno in test_loader:outputs = model(geno)val_loss += criterion(outputs.squeeze(), pheno).item()val_loss /= len(test_loader)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}')scheduler.step(val_loss)# 检查是否触发早停early_stopping(val_loss)if early_stopping.early_stop:print("Early stopping triggered")break

 

 

http://www.lryc.cn/news/414945.html

相关文章:

  • 科普文:JUC系列之多线程门闩同步器CountDownLatch的使用和源码
  • foreach循环和for循环在PHP中各有什么优势
  • 巧用casaos共享挂载自己的外接硬盘为局域网共享
  • 标题:解码“八股文”:助力、阻力,还是空谈?
  • 语言无界,沟通无限:2024年好用在线翻译工具推荐
  • 【Golang 面试 - 进阶题】每日 3 题(十八)
  • 二分+dp,CF 1993D - Med-imize
  • 三十种未授权访问漏洞复现 合集( 三)
  • 数据湖和数据仓库核心概念与对比
  • 探索WebKit的奥秘:打造高效、兼容的现代网页应用
  • 【leetcode】平衡二叉树、对称二叉树、二叉树的层序遍历(广度优先遍历)(详解)
  • 最短路径算法:Floyd-Warshall算法
  • 3DM游戏运行库合集离线安装包2024最新版
  • 【Bigdata】什么是混合型联机分析处理
  • Java 并发编程:volatile 关键字介绍与使用
  • 【Spark计算引擎----第三篇(RDD)---《深入理解 RDD:依赖、Spark 流程、Shuffle 与缓存》】
  • 四、日志收集loki+ promtail+grafana
  • xdma的linux驱动编译给arm使用(中断检测-测试程序)
  • 探索之路——初识 Vue Router:构建单页面应用的完整指南
  • 传输层_计算机网络
  • 自动驾驶的六个级别是什么?
  • 深度学习复盘与论文复现F
  • 如何学习自动化测试工具!
  • 短信接口被恶意盗刷
  • 实验4-2-1 求e的近似值
  • 内网穿透--LCX+portmap转发实验
  • 缓存一致性问题
  • 【MYSQL】MYSQL逻辑架构
  • 【Python】数据类型之字符串
  • c++编写java模式的线程类