TRL - Transformer Reinforcement Learning SFTTrainer 和 SFTConfig
TRL - Transformer Reinforcement Learning SFTTrainer 和 SFTConfig
flyfish
Name: trl
Version: 0.21.0
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl
SFTTrainer 的作用就是把微调这个过程变简单:
不用自己写复杂的训练代码(比如怎么加载数据、怎么处理对话格式、怎么保存模型);
它能自动处理各种数据,直接喂给模型;
轻松搭配 PEFT 这类 “省资源” 的工具(不用训模型全部参数,少花钱少占内存);
训练中的日志、保存、断点续训这些琐事也搞定了。
|
简单示例
from trl import SFTConfig, SFTTrainer
from datasets import load_datasetdataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,
)
trainer.train()
{'loss': 1.7356, 'grad_norm': 4.851604461669922, 'learning_rate': 1.9969635627530365e-05, 'num_tokens': 59528.0, 'mean_token_accuracy': 0.6122099697589874, 'epoch': 0.01}
{'loss': 1.624, 'grad_norm': 6.229388236999512, 'learning_rate': 1.9935897435897437e-05, 'num_tokens': 115219.0, 'mean_token_accuracy': 0.6266456484794617, 'epoch': 0.01}
{'loss': 1.4456, 'grad_norm': 5.725496292114258, 'learning_rate': 1.990215924426451e-05, 'num_tokens': 171787.0, 'mean_token_accuracy': 0.6584609568119049, 'epoch': 0.02}
{'loss': 1.6019, 'grad_norm': 5.118398189544678, 'learning_rate': 1.986842105263158e-05, 'num_tokens': 226067.0, 'mean_token_accuracy': 0.6274505913257599, 'epoch': 0.02}
{'loss': 1.5735, 'grad_norm': 4.509005546569824, 'learning_rate': 1.9834682860998653e-05, 'num_tokens': 284290.0, 'mean_token_accuracy': 0.6260855078697205, 'epoch': 0.03}
{'loss': 1.5226, 'grad_norm': 4.884885311126709, 'learning_rate': 1.9800944669365722e-05, 'num_tokens': 338761.0, 'mean_token_accuracy': 0.6402752816677093, 'epoch': 0.03}
{'loss': 1.5326, 'grad_norm': 5.511511325836182, 'learning_rate': 1.9767206477732795e-05, 'num_tokens': 397731.0, 'mean_token_accuracy': 0.6352024018764496, 'epoch': 0.04}
{'loss': 1.3588, 'grad_norm': 7.149945259094238, 'learning_rate': 1.9733468286099865e-05, 'num_tokens': 451526.0, 'mean_token_accuracy': 0.665032935142517, 'epoch': 0.04}
{'loss': 1.4091, 'grad_norm': 4.552429676055908, 'learning_rate': 1.9699730094466938e-05, 'num_tokens': 505287.0, 'mean_token_accuracy': 0.6482574462890625, 'epoch': 0.05}
{'loss': 1.4679, 'grad_norm': 4.194477081298828, 'learning_rate': 1.966599190283401e-05, 'num_tokens': 563276.0, 'mean_token_accuracy': 0.6456288278102875, 'epoch': 0.05}
{'loss': 1.3072, 'grad_norm': 4.873239994049072, 'learning_rate': 1.963225371120108e-05, 'num_tokens': 623927.0, 'mean_token_accuracy': 0.6719759583473206, 'epoch': 0.06}
{'loss': 1.4861, 'grad_norm': 5.325733661651611, 'learning_rate': 1.9598515519568153e-05, 'num_tokens': 678587.0, 'mean_token_accuracy': 0.6462028443813324, 'epoch': 0.06}
{'loss': 1.5824, 'grad_norm': 4.714112281799316, 'learning_rate': 1.9564777327935226e-05, 'num_tokens': 736510.0, 'mean_token_accuracy': 0.6345476865768432, 'epoch': 0.07}
{'loss': 1.4407, 'grad_norm': 4.857705116271973, 'learning_rate': 1.9531039136302295e-05, 'num_tokens': 794954.0, 'mean_token_accuracy': 0.6477592408657074, 'epoch': 0.07}
SFTTrainer 中加入 PEFT(参数高效微调)配置
使用 PEFT 库中的配置类(如 LoRA)并将其传递给peft_config参数
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import torch# 加载数据集
dataset = load_dataset("trl-lib/Capybara", split="train")# 配置PEFT (使用LoRA作为示例)
peft_config = LoraConfig(r=8, # LoRA注意力维度lora_alpha=32, # LoRA缩放参数target_modules=[ # Qwen模型的目标模块,不同模型可能不同"q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj"],lora_dropout=0.05, # Dropout概率bias="none", # 不训练偏置参数task_type="CAUSAL_LM", # 任务类型:因果语言模型inference_mode=False # 训练模式
)# 配置训练参数
training_args = SFTConfig(output_dir="./Qwen2.5-0.5B-SFT-PEFT", # 输出目录num_train_epochs=3, # 训练轮数per_device_train_batch_size=4, # 每个设备的批次大小gradient_accumulation_steps=2, # 梯度累积步数learning_rate=2e-4, # 学习率logging_steps=10, # 日志记录步数save_steps=100, # 模型保存步数fp16=True, # 使用混合精度训练optim="paged_adamw_8bit", # 使用8位优化器节省显存report_to="wandb" if "wandb" in locals() else "none", # 日志报告方式max_length=1024,
)# 初始化SFTTrainer并传入PEFT配置
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,peft_config=peft_config, # 添加PEFT配置)# 开始训练
trainer.train()# 保存PEFT模型
trainer.save_model()# 如果需要,可以将PEFT模型与基础模型合并(推理时使用)
# from peft import AutoPeftModelForCausalLM
# model = AutoPeftModelForCausalLM.from_pretrained(
# "./Qwen2.5-0.5B-SFT-PEFT",
# device_map="auto",
# torch_dtype=torch.bfloat16
# )
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained("./Qwen2.5-0.5B-SFT-merged")
[INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)
{'loss': 1.7566, 'grad_norm': 0.8341224193572998, 'learning_rate': 0.00019969635627530366, 'num_tokens': 59528.0, 'mean_token_accuracy': 0.6122347742319107, 'epoch': 0.01}
{'loss': 1.637, 'grad_norm': 0.9207030534744263, 'learning_rate': 0.00019935897435897437, 'num_tokens': 115219.0, 'mean_token_accuracy': 0.6304881483316421, 'epoch': 0.01}
{'loss': 1.4695, 'grad_norm': 0.7871270775794983, 'learning_rate': 0.0001990215924426451, 'num_tokens': 171787.0, 'mean_token_accuracy': 0.6504963368177414, 'epoch': 0.02}
{'loss': 1.6518, 'grad_norm': 0.7337052226066589, 'learning_rate': 0.0001986842105263158, 'num_tokens': 226067.0, 'mean_token_accuracy': 0.621559776365757, 'epoch': 0.02}
{'loss': 1.6317, 'grad_norm': 0.6506398916244507, 'learning_rate': 0.00019834682860998652, 'num_tokens': 284290.0, 'mean_token_accuracy': 0.6257665097713471, 'epoch': 0.03}
{'loss': 1.5889, 'grad_norm': 0.7434213161468506, 'learning_rate': 0.0001980094466936572, 'num_tokens': 338761.0, 'mean_token_accuracy': 0.6391594052314759, 'epoch': 0.03}
{'loss': 1.5915, 'grad_norm': 0.7964017987251282, 'learning_rate': 0.00019767206477732793, 'num_tokens': 397731.0, 'mean_token_accuracy': 0.6308055430650711, 'epoch': 0.04}
{'loss': 1.4446, 'grad_norm': 1.0523793697357178, 'learning_rate': 0.00019733468286099867, 'num_tokens': 451526.0, 'mean_token_accuracy': 0.6549594551324844, 'epoch': 0.04}
{'loss': 1.4834, 'grad_norm': 0.6530594229698181, 'learning_rate': 0.00019699730094466938, 'num_tokens': 505287.0, 'mean_token_accuracy': 0.6439991772174836, 'epoch': 0.05}
{'loss': 1.5385, 'grad_norm': 0.5796452164649963, 'learning_rate': 0.0001966599190283401, 'num_tokens': 563276.0, 'mean_token_accuracy': 0.6417025059461594, 'epoch': 0.05}
{'loss': 1.3649, 'grad_norm': 0.6477728486061096, 'learning_rate': 0.00019632253711201081, 'num_tokens': 623927.0, 'mean_token_accuracy': 0.6682371526956559, 'epoch': 0.06}
SFTTrainer参数
参数名称 | 类型 | 描述 |
---|---|---|
model | Union[str, PreTrainedModel] | 待训练的模型。可是huggingface模型ID、本地模型目录路径,或PreTrainedModel 对象(仅支持因果语言模型);通过AutoModelForCausalLM.from_pretrained 加载,支持args.model_init_kwargs 参数。 |
args | [SFTConfig ](可选,默认None ) | 训练器配置。若为None ,则使用默认配置。 |
data_collator | DataCollator (可选) | 从处理后的训练/评估数据集元素列表中构建批次的函数。默认使用自定义的DataCollatorForLanguageModeling 。 |
train_dataset | [~datasets.Dataset ] 或 [~datasets.IterableDataset ] | 训练数据集。支持语言建模型和提示-补全型,样本格式可为标准文本(纯文本)或对话格式(结构化消息,如角色+内容);也支持已分词数据集(需含input_ids 字段)。 |
eval_dataset | [~datasets.Dataset ]、[~datasets.IterableDataset ] 或 dict[...] | 评估数据集。需满足与train_dataset 相同的要求。 |
processing_class | PreTrainedTokenizerBase 等(可选,默认None ) | 数据处理类(如分词器)。若为None ,则通过AutoTokenizer.from_pretrained 从模型名称加载。 |
callbacks | 列表([~transformers.TrainerCallback] ,可选,默认None ) | 自定义训练循环的回调列表。会添加到默认回调列表中;可通过remove_callback 方法移除默认回调。 |
optimizers | tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] (可选,默认(None, None) ) | 包含优化器和调度器的元组。默认使用AdamW 优化器和get_linear_schedule_with_warmup 调度器(由args 控制)。 |
optimizer_cls_and_kwargs | Tuple[Type[torch.optim.Optimizer], Dict[str, Any]] (可选,默认None ) | 包含优化器类和关键字参数的元组。会覆盖args 中的optim 和optim_args ;与optimizers 参数不兼容,且无需提前将模型参数放到正确设备上。 |
preprocess_logits_for_metrics | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] (可选,默认None ) | 评估步骤中缓存logits前的预处理函数。需接收logits和标签(可为None ),返回处理后的logits;修改会反映在compute_metrics 接收的预测结果中。 |
peft_config | [~peft.PeftConfig ](可选,默认None ) | 用于封装模型的PEFT配置(如LoRA)。若为None ,则不封装模型。 |
formatting_func | Optional[Callable] | 分词前应用于数据集的格式化函数。会将数据集显式转换为语言建模类型。 |
SFTConfig
的参数
SFTConfig
继承自 transformers.TrainingArguments
作用
SFTConfig
是 SFTTrainer
的配置类,用于设置监督微调(SFT)的各项参数,包括模型加载、数据预处理、训练策略等细节,简化训练配置流程。通过细化模型加载、数据处理(如打包、填充)、训练策略(如损失计算范围)等参数,让监督微调更灵活适配不同场景(如对话训练、指令微调),同时继承了 TrainingArguments
的通用训练配置(如批次大小、训练轮数等),无需重复定义。
说明
参数名称 | 类型 | 默认值 | 描述 |
---|---|---|---|
覆盖父类默认值的参数 | |||
learning_rate | float | 2e-5 | AdamW 优化器的初始学习率(父类默认值不同)。 |
logging_steps | float | 10 | 每多少步记录一次日志(可设为整数或 [0,1) 之间的比例,代表总步数的占比)。 |
gradient_checkpointing | bool | True | 是否启用梯度检查点(以稍慢的反向传播为代价节省显存,父类默认值为 False)。 |
bf16 | Optional[bool] | None | 是否使用 bf16 混合精度训练。默认根据 fp16 自动设置(fp16 为 False 则默认启用)。 |
average_tokens_across_devices | bool | True | 是否跨设备平均 tokens 数量(用于精确计算损失,适配多设备训练)。 |
控制模型的参数 | |||
model_init_kwargs | Optional[dict] | None | 加载模型时的关键字参数(当 model 为字符串时,传给 AutoModelForCausalLM.from_pretrained )。 |
chat_template_path | Optional[str] | None | 模型聊天模板路径(可为分词器目录、Hugging Face 模型ID或 Jinja 模板文件),用于格式化对话。 |
控制数据预处理的参数 | |||
dataset_text_field | str | “text” | 数据集中存储文本的列名(如数据集里文本存在 text 列则无需修改)。 |
dataset_kwargs | Optional[dict] | None | 数据集准备的可选参数(仅支持 skip_prepare_dataset 键,用于跳过数据集预处理)。 |
dataset_num_proc | Optional[int] | None | 处理数据集的进程数(加速数据预处理)。 |
eos_token | Optional[str] | None | 序列结束符。默认使用处理类(如分词器)的 eos_token 。 |
pad_token | Optional[str] | None | 填充符。默认使用处理类的 pad_token ,若不存在则 fallback 到 eos_token 。 |
max_length | Optional[int] | 1024 | token 化后的最大序列长度,超过则从右侧截断;启用打包时,此值为固定块长度。 |
packing | bool | False | 是否将多个短序列打包成固定长度块(减少填充,提升效率),长度由 max_length 定义。 |
packing_strategy | str | “bfd” | 打包策略:"bfd" (最佳适配递减,默认)或 "wrapped" (包裹式)。 |
padding_free | bool | False | 是否无填充训练(将批次序列扁平化为单个连续序列,减少填充开销),需配合 FlashAttention 使用。 |
pad_to_multiple_of | Optional[int] | None | 序列填充到该值的倍数(如 8、16,提升硬件效率)。 |
eval_packing | Optional[bool] | None | 评估数据集是否启用打包,默认与 packing 保持一致。 |
控制训练的参数 | |||
completion_only_loss | Optional[bool] | None | 是否只计算“补全部分”的损失: - 对“提示-补全”型数据集,默认只算补全部分; - 对语言建模型数据集,默认算全序列。 |
assistant_only_loss | bool | False | 是否只计算“助手回复部分”的损失(仅支持对话型数据集)。 |
activation_offloading | bool | False | 是否将激活值卸载到 CPU(进一步节省 GPU 显存)。 |