当前位置: 首页 > news >正文

bert-base-chinese模型的完整训练、推理和一些思考

前言

使用google-bert/bert-base-chinese模型进行中文文本分类任务,使用THUCNews中文数据集进行训练,训练完成后,可以导出模型,进行预测。

项目详细介绍和数据下载

数据集下载地址

Github完整代码

现记录训练过程中的一些感悟

1、训练时遇到的两个核心参数warmup_stepsweight_decay

代码片段如下
在这里插入图片描述

需要弄明白一些基础概念

epoch:指模型在训练过程中遍历完整个训练数据集一次。

step:指模型在训练过程中处理完一个batch的数据并完成一次梯度更新。

batch_size: 指在一次step中模型用于训练的数据量。

假设 训练数据集有 n 个样本,每个epoch的step计算方式
s t e p = n b a t c h _ s i z e step = \frac{n}{batch\_size} step=batch_sizen
训练过程的总步数为
s t e p s = s t e p × n u m _ t r a i n _ e p o c h s steps = step \times num\_train\_epochs steps=step×num_train_epochs

warmup_steps:主要目的是为了平稳地提升学习率,让模型在训练初期不会因为太高的学习率而跳过或远离全局最优解。

常见做法是将其设置为总训练步数的5%到10%的值。

此训练过程中warmup steps下限的计算方式如下,训练数据18w
w a r m u p _ s t e p s = 180000 32 × 5 × 5 % = 1406 warmup\_steps = \frac{180000}{32} \times 5 \times 5\% = 1406 warmup_steps=32180000×5×5%=1406

减少 warmup_steps 可能会导致模型更快地达到较高的学习率,从而错过或远离全局最优解。

weight_decay:是用于正则化模型权重的,实际上是 L2 正则化的一种形式

weight_decay的作用是在损失函数中添加一个惩罚项,该惩罚项与权重的平方成正比,这有助于抑制权重的大小,从而防止模型过拟合

weight_decay设置得过低,可能不足以防止过拟合;设置得过高,则可能导致模型欠拟合,即模型过于简单,无法很好地捕捉数据中的模式

2、通过tensorboard --logdir=./logs可视化训练过程

训练过程截图如下:

2.1、训练阶段

可以明显的看到训练时的学习率先逐渐上升之后在下降,这是我们想要的趋势。训练的损失值逐步下降,这也是我们希望的。但是当我们在分析评估数据数据集的损失时,我们会发现此时模型应该是过拟合了。

在这里插入图片描述

2.2、推理阶段

随着训练过程的增加,模型在评估数据集上的损失也是逐步减少,当在step=11250时,评估数据集上的损失开始逐渐增加,而训练数据的损失还在减少,那么可以肯定模型已经过拟合了。

模型已经充分的挖掘训练数据集中的语义特征,过分的学习到数据中的一些细枝末节。从而在新数据集上的表现越来越差。这种在训练数据集上表现优秀,在评估或测试数据集上表现较差现象,即模型出现了过拟合。

在这里插入图片描述

3、模型混淆矩阵的分析

混淆矩阵结果如下
在这里插入图片描述
指标如下

Accuracy0.9434
Precision0.9438
Recall0.9434

具体多分类任务指标和混淆矩阵分析参考这里非常详细。

4、如何解决模型过拟合的现象

【待更新】疯狂参数调节优化中…

http://www.lryc.cn/news/423419.html

相关文章:

  • JS基础5(JS的作用域和JS预解析)
  • Doris 夺命 30 连问!(中)
  • 书生.浦江大模型实战训练营——(四)书生·浦语大模型全链路开源开放体系
  • SpringBoot 整合 RabbitMQ 实现延迟消息
  • Cilium:基于开源 eBPF 的网络、安全性和可观察性
  • Axios 详解与使用指南
  • 深度学习 —— 个人学习笔记20(转置卷积、全卷积网络)
  • 解决Mac系统Python3.12版本pip安装报错error: externally-managed-environment的问题
  • lvm知识终结
  • ESP32S3 IDF 对 16路输入输出芯片MCP23017做了个简单的测试
  • 【技术前沿】Flux.1部署教程入门--Stable Diffusion团队最前沿、免费的开源AI图像生成器
  • Redis 的 STREAM 和 RocketMQ 是两种不同的消息队列和流处理解决方案,它们在设计理念、功能和用途上有显著区别。以下是它们的主要区别:
  • Visual Studio Code安装与C/C++语言运行(上)
  • 探索数据可视化,数据看板在各行业中的应用
  • haralyzer 半自动,一次性少量数据采集快捷方法
  • mall-admin-web-master前端项目下载依赖失败解决
  • 【07】JVM是怎么实现invokedynamic的
  • 使用API有效率地管理Dynadot域名,查看参与的拍卖列表
  • Linux 基本指令讲解
  • PRE_EMPHASIS
  • 【QT常用技术讲解】多线程处理+全局变量处理异步事件并获取多个线程返回的结果
  • 数组列表中的最大距离
  • C语言新手小白详细教程(7)指针和指针变量
  • Kafka保证消息不丢失
  • 数据结构+基数排序算法
  • C++ list【常用接口、模拟实现等】
  • 12.面试题——Spring Boot
  • 【前端VUE】npm i 出现版本错误等报错 简单直接解决命令
  • 精彩回顾 | 风丘科技亮相2024名古屋汽车工程博览会
  • 设计模式21-组合模式