当前位置: 首页 > news >正文

huggingface 笔记: Trainer

  • Trainer 是一个为 Transformers 中 PyTorch 模型设计的完整训练与评估循环
  • 只需将模型、预处理器、数据集和训练参数传入 Trainer,其余交给它处理,即可快速开始训练
  • 自动处理以下训练流程:
    • 根据 batch 计算 loss

    • 使用 backward() 计算梯度

    • 根据梯度更新权重

    • 重复上述流程直到达到指定的 epoch 数

1 配置TrainingArguments

使用 TrainingArguments 定义训练超参数与功能选项

from transformers import TrainingArgumentstraining_args = TrainingArguments(。。。
)

1.1 主要参数

output_dir模型预测结果与检查点的保存目录
do_train是否执行训练(bool,可选,默认 False)
do_eval是否在验证集上执行评估
do_predict

(bool,可选,默认 False)

是否在测试集上执行预测

eval_strategy

训练期间的评估策略。可选值:

  • "no":不进行评估;

  • "steps":每隔 eval_steps 步评估一次;

  • "epoch":每个 epoch 结束时评估。

prediction_loss_only

(bool,可选,默认 False)

评估和预测阶段是否只返回损失

per_device_train_batch_size

(int,可选,默认 8)

训练时每个设备的 batch 大小

per_device_eval_batch_size

(int,可选,默认 8)

评估时每个设备的 batch 大小

gradient_accumulation_steps

(int,可选,默认 1)

累计多少步的梯度后执行一次反向传播和参数更新

eval_accumulation_steps

(int,可选)

在将预测结果移动到 CPU 前累积多少步输出。如果未设置,则所有预测保留在设备内存中,速度快但占用显存多。

eval_delay训练开始后延迟多少个 step 或 epoch 才进行第一次评估
torch_empty_cache_steps每隔多少步调用 torch.<device>.empty_cache() 清空缓存,以减少 CUDA OOM,代价是性能可能下降 10%。
优化器相关learning_rate
weight_decay
adam_beta1 / adam_beta2:AdamW 的 beta 参数,默认 0.9 / 0.999
adam_epsilon:AdamW 的 epsilon,默认 1e-8
max_grad_norm:梯度裁剪最大范数,默认 1.0
训练周期与步数控制num_train_epochs:训练 epoch 总数

max_steps:训练的最大 step 数

                        若设为正数,将覆盖 num_train_epochs

学习率调度器相关lr_scheduler_type:调度器类型(如 linear、cosine、polynomial)
lr_scheduler_kwargs:调度器额外参数

warmup_ratio:总步数中用于 warmup 的比例

或者

warmup_steps:warmup 阶段的步数(优先于 warmup_ratio)

日志与保存设置log_level / log_level_replica:主进程和副本日志等级(如 info、warning)
log_on_each_node:多节点训练中每个节点是否都记录日志
logging_dir:TensorBoard 日志目录,默认在 output_dir/runs
logging_strategy:日志记录策略("no"、"steps"、"epoch")
logging_first_step:是否记录第一个 global step 的日志
logging_steps:记录日志的步数间隔
logging_nan_inf_filter:是否过滤 NaN/inf 损失值
模型保存设置save_strategy:模型保存策略("no"、"epoch"、"steps"、"best")
save_steps:按步保存间隔
save_total_limit:保留的最大检查点数
save_safetensors:是否使用 safetensors 格式保存模型
save_on_each_node:是否在每个节点都保存模型
save_only_model:是否仅保存模型,不保存优化器、调度器等状态(省空间但不可恢复)
use_cpu是否强制使用 CPU。如果设为 False,将优先使用 CUDA 或 MPS(如可用)
seed(int,可选,默认 42)
设置训练开始时的随机种子。
data_seed

(int,可选)
数据采样使用的随机种子

如果未设置,将与 seed 相同

bf16(bool,可选,默认 False)
是否启用 bfloat16 精度训练
fp16(bool,可选,默认 False)
是否启用 float16 精度训练(混合精度训练)
DataLoader 设置dataloader_drop_last(bool,可选,默认 False)
数据长度不整除 batch size 时是否丢弃最后一个 batch
eval_steps(int 或 float,可选)
如果 eval_strategy="steps",两个评估之间的步数间隔。默认与 logging_steps 相同
dataloader_num_workers(int,可选,默认 0)
用于数据加载的子进程数量。0 表示使用主进程加载
metric_for_best_model(str,可选)
指定用于判断最佳模型的指标名称。应为 eval_ 开头的某个评估指标
greater_is_better(bool,可选)
指标值越大是否表示模型越好。若 metric_for_best_model"loss" 结尾,则默认为 False,否则为 True
optim优化器名称,如 "adamw_torch""adafactor""galore_adamw"
optim_args传给优化器的附加参数字符串
group_by_length是否将长度相近的样本分组(减少 padding)
dataloader_pin_memory是否将数据固定在内存中(可加快 CPU 到 GPU 的传输)
eval_on_start

是否在训练开始前先评估一次(sanity check)

在训练开始前进行一次验证,以确保评估流程无误

gradient_checkpointing(bool,可选,默认 False)
启用梯度检查点以节省内存,代价是反向传播会更慢

 

2 传入Trainer

然后将模型、数据集、预处理器和 TrainingArguments 传入 Trainer,调用 train() 启动训练:

from transformers import Trainertrainer = Trainer(model=model,args=training_args,train_dataset=dataset["train"],eval_dataset=dataset["test"],processing_class=tokenizer,data_collator=data_collator,compute_metrics=compute_metrics,
)trainer.train()

2.1 传入自定义模型时,需要的条件 

  • 模型始终返回元组(tuple)或 ModelOutput 的子类
  • 若传入了 labels 参数,你的模型能够计算损失,并将损失作为返回元组的第一个元素

2.2 参数说明

modelPreTrainedModel 或 torch.nn.Module
argsTrainingArguments类型,用于配置训练的参数
data_collator用于将 train_dataseteval_dataset 中的元素列表整理成 batch 的函数
train_dataset用于训练的数据集 Dataset 类型
eval_dataset

用于评估的数据集。

若是字典类型,将对每个键值对分别进行评估,并在指标名前添加键名作为前缀。

compute_loss_func用于计算损失的函数
compute_metrics评估阶段使用的评估指标计算函数
optimizers

(Optimizer, Scheduler) 元组

优化器和学习率调度器的元组

如果未提供,将使用模型上的 AdamW 优化器和由 get_linear_schedule_with_warmup() 创建的调度器

 2.3 compute_loss 方法

model

nn.Module

用于计算损失的模型

inputs输入模型的数据字典,键通常是 'input_ids''attention_mask''labels'
return_outputs(bool,可选,默认 False)
是否将模型的输出与损失一起返回。若为 True,则返回一个元组 (loss, outputs);否则仅返回 loss
num_items_in_batch

当前 batch 中的样本数量。若未传入此参数,将自动从 inputs 中推断。

用于在梯度累积时确保损失缩放正确

3 checkpoints 检查点

  • Trainer 会将检查点(默认为不保存优化器状态)保存到 TrainingArguments 中的 output_dir 路径下的子目录 checkpoint-000
    • 结尾数字表示保存该检查点的训练步数
  • 保存检查点有助于在中断后恢复训练。
    • 可通过在 train() 中设置 resume_from_checkpoint 参数来从最近或指定的检查点恢复:
    • trainer.train(resume_from_checkpoint=True)
      

 4 日志记录

  • 默认情况下,Trainer 使用 logging.INFO 级别输出错误、警告和基本信息
  • 可以通过 log_level() 修改日志级别
import logging
import syslogger = logging.getLogger(__name__)
#使用当前模块名 __name__ 创建一个 loggerlogging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",handlers=[logging.StreamHandler(sys.stdout)],
)
'''
format: 控制日志的格式,包括:%(asctime)s:时间戳%(levelname)s:日志级别(INFO、WARNING、ERROR 等)%(name)s:logger 的名字(这里是模块名)%(message)s:日志正文内容datefmt: 时间格式为“月/日/年 时:分:秒”handlers: 使用 sys.stdout,把日志输出到终端(而不是默认的 stderr)类似于这样的格式:07/09/2025 10:30:21 - INFO - __main__ - 正在加载模型
'''log_level = training_args.get_process_log_level()
#从 TrainingArguments 中获取当前训练进程应使用的日志级别logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
'''
这三行分别设置:当前模块的日志等级;datasets 库的日志等级;transformers 库的日志等级。
'''trainer = Trainer(...)

5 自定义功能

  • 可以通过继承 Trainer 或重写其方法,来添加所需功能,而无需从零开始编写训练循环
get_train_dataloader()创建训练数据加载器
get_eval_dataloader()创建评估数据加载器
get_test_dataloader()创建测试数据加载器
log()记录训练过程中的信息
create_optimizer_and_scheduler()创建优化器和学习率调度器
compute_loss()计算训练批次的损失
training_step()执行训练步骤
prediction_step()执行预测步骤
evaluate()评估模型并返回指标
predict()进行预测并返回指标(如果有标签)

5.1 举例:自定义compute_loss

from torch import nn
from transformers import Trainerclass CustomTrainer(Trainer):def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):labels = inputs.pop("labels")#从输入中取出 labels,其余留作模型输入outputs = model(**inputs)logits = outputs.get("logits")#通过模型获取输出。#模型返回一个字典,里面包含 logits(每个token的概率)。reduction = "mean" if num_items_in_batch is not None else "sum"'''reduction 控制 loss 的聚合方式:"sum":默认对所有样本 loss 求和;"mean":如果指定了 num_items_in_batch,最终会除以样本数,等价于平均 loss。'''loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device), reduction=reduction)#三类标签设置不同权重。第0类权重为1.0,第1类为2.0,第2类为3.0loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))'''将 logits 和 labels reshape 为二维,以满足 CrossEntropyLoss 的输入要求:logits: [batch_size, num_labels]labels: [batch_size]'''if num_items_in_batch is not None:loss = loss / num_items_in_batchreturn (loss, outputs) if return_outputs else loss'''若 return_outputs=True,则返回 (loss, outputs);否则只返回 loss'''

其他函数自动继承

http://www.lryc.cn/news/583063.html

相关文章:

  • 打造自己的组件库(二)CSS工程化方案
  • 跨服务sqlplus连接oracle数据库
  • 54页|PPT|新型数字政府综合解决方案:“一网 一云 一中台 N应用”平台体系 及“安全+运营”服务体系
  • 人工智能的基石:TensorFlow与PyTorch在图像识别和NLP中的应用
  • 影石(insta360)X4运动相机视频删除的恢复方法
  • 【视频观看系统】- 需求分析
  • 【DB2】load报错SQL3501W、SQL3109N、SQL2036N
  • Tensorflow的安装记录
  • django 一个表中包括id和parentid,如何通过parentid找到全部父爷id
  • react+ts 移动端页面分页,触底加载下一页
  • 板凳-------Mysql cookbook学习 (十一--------6)
  • 安卓设备信息查看器 - 源码编译
  • Android-重学kotlin(协程源码第二阶段)新学习总结
  • 中望CAD2026亮点速递(5):【相似查找】高效自动化识别定位
  • uniapp AndroidiOS 定位权限检查
  • Android ViewModel机制与底层原理详解
  • upload-labs靶场通关详解:第19关 条件竞争(二)
  • 池化思想-Mysql异步连接池
  • 5.注册中心横向对比:Nacos vs Eureka vs Consul —— 深度解析与科学选型指南
  • Web 前端框架选型:React、Vue 和 Angular 的对比与实践
  • 华为静态路由配置
  • 小米路由器3C刷OpenWrt,更换系统/变砖恢复 指南
  • 语音识别核心模型的数学原理和公式
  • 从互联网电脑迁移Dify到内网部署Dify方法记录
  • 【编程史】IDE 是谁发明的?从 punch cards 到 VS Code
  • 计算机网络实验——访问H3C网络设备
  • Java项目集成Log4j2全攻略
  • Using Spring for Apache Pulsar:Publishing and Consuming Partitioned Topics
  • 飞算 JavaAI 智能编程助手 - 重塑编程新模态
  • bash 判断 /opt/wslibs-cuda11.8 是否为软连接, 如果是,获取连接目的目录并自动创建